1 /****************************************************************************** 2 * Copyright (c) Intel Corporation - All rights reserved. * 3 * This file is part of the LIBXSMM library. * 4 * * 5 * For information on the license, see the LICENSE file. * 6 * Further information: https://github.com/hfp/libxsmm/ * 7 * SPDX-License-Identifier: BSD-3-Clause * 8 ******************************************************************************/ 9 /* Alexander Heinecke, Kunal Banerjee (Intel Corp.) 10 ******************************************************************************/ 11 #ifndef LIBXSMM_DNN_RNNCELL_H 12 #define LIBXSMM_DNN_RNNCELL_H 13 14 #include "libxsmm_dnn.h" 15 #include "libxsmm_dnn_tensor.h" 16 17 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_rnncell libxsmm_dnn_rnncell; 18 19 /** Type of algorithm used for convolutions. */ 20 typedef enum libxsmm_dnn_rnncell_type { 21 /** simple RNN cell with ReLU as activation function */ 22 LIBXSMM_DNN_RNNCELL_RNN_RELU, 23 /** simple RNN cell with sigmoid as activation function */ 24 LIBXSMM_DNN_RNNCELL_RNN_SIGMOID, 25 /** simple RNN cell with tanh as activation function */ 26 LIBXSMM_DNN_RNNCELL_RNN_TANH, 27 /** LSTM cell */ 28 LIBXSMM_DNN_RNNCELL_LSTM, 29 /** GRU cell */ 30 LIBXSMM_DNN_RNNCELL_GRU 31 } libxsmm_dnn_rnncell_type; 32 33 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_rnncell_desc { 34 int threads; 35 libxsmm_blasint K; /* number of outputs */ 36 libxsmm_blasint N; /* size of the minibatch */ 37 libxsmm_blasint C; /* number of inputs */ 38 libxsmm_blasint max_T; /* number of time steps */ 39 libxsmm_blasint bk; 40 libxsmm_blasint bn; 41 libxsmm_blasint bc; 42 libxsmm_dnn_rnncell_type cell_type; /* cell type RNN ReLU, RNN Sigmoid, RNN Tanh, LSTM, GRU */ 43 libxsmm_dnn_datatype datatype_in; /* datatypes used for all input related buffer */ 44 libxsmm_dnn_datatype datatype_out; /* datatypes used for all output related buffer */ 45 libxsmm_dnn_tensor_format buffer_format; /* format which is for activation buffers */ 46 libxsmm_dnn_tensor_format filter_format; /* format which is for filter buffers */ 47 } libxsmm_dnn_rnncell_desc; 48 49 LIBXSMM_API libxsmm_dnn_rnncell* libxsmm_dnn_create_rnncell(libxsmm_dnn_rnncell_desc rnncell_desc, libxsmm_dnn_err_t* status); 50 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_rnncell(const libxsmm_dnn_rnncell* handle); 51 52 LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_rnncell_create_tensor_datalayout(const libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status); 53 54 LIBXSMM_API size_t libxsmm_dnn_rnncell_get_scratch_size(const libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, libxsmm_dnn_err_t* status); 55 LIBXSMM_API void* libxsmm_dnn_rnncell_get_scratch_ptr (const libxsmm_dnn_rnncell* handle, libxsmm_dnn_err_t* status); 56 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_bind_scratch(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, const void* scratch); 57 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_release_scratch(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind); 58 59 LIBXSMM_API size_t libxsmm_dnn_rnncell_get_internalstate_size(const libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, libxsmm_dnn_err_t* status); 60 LIBXSMM_API void* libxsmm_dnn_rnncell_get_internalstate_ptr (const libxsmm_dnn_rnncell* handle, libxsmm_dnn_err_t* status); 61 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_bind_internalstate(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, const void* internalstate); 62 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_release_internalstate(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind); 63 64 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_allocate_forget_bias(libxsmm_dnn_rnncell* handle, const float forget_bias); 65 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_bind_tensor(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type); 66 LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_rnncell_get_tensor(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status); 67 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_release_tensor(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor_type type); 68 69 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_set_sequence_length( libxsmm_dnn_rnncell* handle, const libxsmm_blasint T ); 70 LIBXSMM_API libxsmm_blasint libxsmm_dnn_rnncell_get_sequence_length( libxsmm_dnn_rnncell* handle, libxsmm_dnn_err_t* status ); 71 72 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_execute_st(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind, 73 /*unsigned*/int start_thread, /*unsigned*/int tid); 74 75 #endif /*LIBXSMM_DNN_RNNCELL_H*/ 76 77