1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Alexander Heinecke, Kunal Banerjee (Intel Corp.)
10 ******************************************************************************/
11 #ifndef LIBXSMM_DNN_RNNCELL_H
12 #define LIBXSMM_DNN_RNNCELL_H
13 
14 #include "libxsmm_dnn.h"
15 #include "libxsmm_dnn_tensor.h"
16 
/** Opaque handle for an RNN cell; created/destroyed via the API below. */
LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_rnncell libxsmm_dnn_rnncell;
18 
/** Type of RNN cell (was mislabeled "convolutions" — this enum selects the recurrent cell variant). */
typedef enum libxsmm_dnn_rnncell_type {
  /** simple RNN cell with ReLU as activation function */
  LIBXSMM_DNN_RNNCELL_RNN_RELU,
  /** simple RNN cell with sigmoid as activation function */
  LIBXSMM_DNN_RNNCELL_RNN_SIGMOID,
  /** simple RNN cell with tanh as activation function */
  LIBXSMM_DNN_RNNCELL_RNN_TANH,
  /** LSTM cell */
  LIBXSMM_DNN_RNNCELL_LSTM,
  /** GRU cell */
  LIBXSMM_DNN_RNNCELL_GRU
} libxsmm_dnn_rnncell_type;
32 
/** Descriptor passed to libxsmm_dnn_create_rnncell; fully specifies the cell's geometry and datatypes. */
LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_rnncell_desc {
  int threads;               /* number of threads used to execute the cell */
  libxsmm_blasint K;         /* number of outputs (hidden-state size) */
  libxsmm_blasint N;         /* size of the minibatch */
  libxsmm_blasint C;         /* number of inputs (input-feature size) */
  libxsmm_blasint max_T;     /* maximum number of time steps (runtime length set via libxsmm_dnn_rnncell_set_sequence_length) */
  libxsmm_blasint bk;        /* presumably the blocking factor for K — confirm against implementation */
  libxsmm_blasint bn;        /* presumably the blocking factor for N — confirm against implementation */
  libxsmm_blasint bc;        /* presumably the blocking factor for C — confirm against implementation */
  libxsmm_dnn_rnncell_type cell_type;       /* cell type: RNN ReLU, RNN Sigmoid, RNN Tanh, LSTM, GRU */
  libxsmm_dnn_datatype datatype_in;         /* datatype used for all input-related buffers */
  libxsmm_dnn_datatype datatype_out;        /* datatype used for all output-related buffers */
  libxsmm_dnn_tensor_format buffer_format;  /* format used for activation buffers */
  libxsmm_dnn_tensor_format filter_format;  /* format used for filter buffers */
} libxsmm_dnn_rnncell_desc;
48 
/** Creates an RNN cell handle from the given descriptor; on error, NULL is returned and *status holds the error code. Ownership: release with libxsmm_dnn_destroy_rnncell. */
LIBXSMM_API libxsmm_dnn_rnncell* libxsmm_dnn_create_rnncell(libxsmm_dnn_rnncell_desc rnncell_desc, libxsmm_dnn_err_t* status);
/** Destroys a cell handle previously created by libxsmm_dnn_create_rnncell. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_rnncell(const libxsmm_dnn_rnncell* handle);
51 
/** Creates the data layout expected by this cell for the given tensor type; *status receives the error code. NOTE(review): caller presumably frees the layout via the generic datalayout destroy API — confirm. */
LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_rnncell_create_tensor_datalayout(const libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status);
53 
/** Queries the scratch size (in bytes) required for the given compute kind; *status receives the error code. */
LIBXSMM_API size_t libxsmm_dnn_rnncell_get_scratch_size(const libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, libxsmm_dnn_err_t* status);
/** Returns the scratch pointer currently bound to the handle (NULL if none); *status receives the error code. */
LIBXSMM_API void*  libxsmm_dnn_rnncell_get_scratch_ptr (const libxsmm_dnn_rnncell* handle, libxsmm_dnn_err_t* status);
/** Binds caller-allocated scratch memory to the handle for the given compute kind; the caller retains ownership of the memory. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_bind_scratch(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, const void* scratch);
/** Detaches previously bound scratch memory from the handle; the memory itself is not freed. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_release_scratch(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind);
58 
/** Queries the internal-state size (in bytes) required for the given compute kind; *status receives the error code. */
LIBXSMM_API size_t libxsmm_dnn_rnncell_get_internalstate_size(const libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, libxsmm_dnn_err_t* status);
/** Returns the internal-state pointer currently bound to the handle (NULL if none); *status receives the error code. */
LIBXSMM_API void*  libxsmm_dnn_rnncell_get_internalstate_ptr (const libxsmm_dnn_rnncell* handle, libxsmm_dnn_err_t* status);
/** Binds caller-allocated internal-state memory to the handle for the given compute kind; the caller retains ownership of the memory. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_bind_internalstate(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, const void* internalstate);
/** Detaches previously bound internal-state memory from the handle; the memory itself is not freed. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_release_internalstate(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind);
63 
/** Sets the forget-gate bias value used by the cell. NOTE(review): name says "allocate" but the parameter is a plain float value — presumably it initializes the forget-bias buffer to this constant; confirm against implementation. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_allocate_forget_bias(libxsmm_dnn_rnncell* handle, const float forget_bias);
/** Binds a user tensor (input, weights, bias, hidden state, ...) to the handle under the given tensor type; the caller retains ownership of the tensor. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_bind_tensor(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type);
/** Returns the tensor currently bound under the given type (NULL if none); *status receives the error code. */
LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_rnncell_get_tensor(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status);
/** Detaches the tensor bound under the given type; the tensor itself is not destroyed. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_release_tensor(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor_type type);
68 
/** Sets the runtime sequence length T; presumably must not exceed the descriptor's max_T — confirm against implementation. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_set_sequence_length( libxsmm_dnn_rnncell* handle, const libxsmm_blasint T );
/** Returns the currently configured sequence length; *status receives the error code. */
LIBXSMM_API libxsmm_blasint libxsmm_dnn_rnncell_get_sequence_length( libxsmm_dnn_rnncell* handle, libxsmm_dnn_err_t* status );
71 
/** Executes the cell (single-threaded entry point, "_st"): each participating thread calls this with its own tid; start_thread is the first thread id of the team. Requires tensors and scratch to be bound beforehand. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_execute_st(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind,
  /*unsigned*/int start_thread, /*unsigned*/int tid);
74 
75 #endif /*LIBXSMM_DNN_RNNCELL_H*/
76 
77