/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Alexander Heinecke, Kunal Banerjee (Intel Corp.)
******************************************************************************/
11 #include "libxsmm_dnn_rnncell_forward.h"
12 #include "libxsmm_dnn_elementwise.h"
13 #include "libxsmm_main.h"
14 
15 
16 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
17 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
18 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
19 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
20 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
21 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
22 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
23 
24 
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)25 LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
26 libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
27 {
28   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
29 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
30   typedef float element_input_type;
31   typedef float element_output_type;
32   typedef float element_filter_type;
33   if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
34 # define LIBXSMM_DNN_RNN_RELU_FWD
35 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c"
36 # undef LIBXSMM_DNN_RNN_RELU_FWD
37   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
38 # define LIBXSMM_DNN_RNN_SIGMOID_FWD
39 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c"
40 # undef LIBXSMM_DNN_RNN_SIGMOID_FWD
41   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
42 # define LIBXSMM_DNN_RNN_TANH_FWD
43 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c"
44 # undef LIBXSMM_DNN_RNN_TANH_FWD
45   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
46 #define LIBXSMM_RNN_CELL_AVX512
47 # include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic.tpl.c"
48 #undef LIBXSMM_RNN_CELL_AVX512
49   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
50 # include "template/libxsmm_dnn_rnncell_st_gru_fwd_nc_ck_generic.tpl.c"
51   } else {
52     /* should not happen */
53   }
54 #else /* should not happen */
55   LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
56   status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
57 #endif
58   return status;
59 }
60 
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)61 LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
62 libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
63 {
64   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
65 #if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__, __AVX512BW__, __AVX512DQ__*/
66   typedef libxsmm_bfloat16 element_input_type;
67   typedef libxsmm_bfloat16 element_output_type;
68   typedef libxsmm_bfloat16 element_filter_type;
69 
70   /* some portable macrros fof BF16 <-> FP32 */
71 # include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
72 
73   if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
74     status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
75   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
76     status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
77   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
78     status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
79   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
80 #define LIBXSMM_RNN_CELL_AVX512
81 # include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic_bf16.tpl.c"
82 #undef LIBXSMM_RNN_CELL_AVX512
83   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
84     status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
85   } else {
86     /* should not happen */
87   }
88 
89 # include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
90 #else /* should not happen */
91   LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
92   status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
93 #endif
94   return status;
95 }
96 
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)97 LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
98 libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
99 {
100   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
101 #if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__, __AVX512BW__, __AVX512DQ__, __AVX512BF16__*/
102   typedef libxsmm_bfloat16 element_input_type;
103   typedef libxsmm_bfloat16 element_output_type;
104   typedef libxsmm_bfloat16 element_filter_type;
105 
106 #define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
107   /* some portable macrros fof BF16 <-> FP32 */
108 # include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
109 
110   if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
111     status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
112   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
113     status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
114   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
115     status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
116   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
117 #define LIBXSMM_RNN_CELL_AVX512
118 # include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic_bf16.tpl.c"
119 #undef LIBXSMM_RNN_CELL_AVX512
120   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
121     status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
122   } else {
123     /* should not happen */
124   }
125 
126 # include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
127 #undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
128 #else /* should not happen */
129   LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
130   status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
131 #endif
132   return status;
133 }
134 
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)135 LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
136 libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
137 {
138   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
139 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
140   typedef float element_input_type;
141   typedef float element_output_type;
142   typedef float element_filter_type;
143   if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
144 # define LIBXSMM_DNN_RNN_RELU_FWD
145 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c"
146 # undef LIBXSMM_DNN_RNN_RELU_FWD
147   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
148 # define LIBXSMM_DNN_RNN_SIGMOID_FWD
149 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c"
150 # undef LIBXSMM_DNN_RNN_SIGMOID_FWD
151   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
152 # define LIBXSMM_DNN_RNN_TANH_FWD
153 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c"
154 # undef LIBXSMM_DNN_RNN_TANH_FWD
155   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
156     status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
157   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
158     status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
159   } else {
160     /* should not happen */
161   }
162 #else /* should not happen */
163   LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
164   status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
165 #endif
166   return status;
167 }
168 
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)169 LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
170 libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
171 {
172   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
173 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
174   typedef float element_input_type;
175   typedef float element_output_type;
176   typedef float element_filter_type;
177   if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
178 # define LIBXSMM_DNN_RNN_RELU_FWD
179 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c"
180 # undef LIBXSMM_DNN_RNN_RELU_FWD
181   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
182 # define LIBXSMM_DNN_RNN_SIGMOID_FWD
183 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c"
184 # undef LIBXSMM_DNN_RNN_SIGMOID_FWD
185   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
186 # define LIBXSMM_DNN_RNN_TANH_FWD
187 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c"
188 # undef LIBXSMM_DNN_RNN_TANH_FWD
189   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
190 #define LIBXSMM_RNN_CELL_AVX512
191 # include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck.tpl.c"
192 #undef LIBXSMM_RNN_CELL_AVX512
193   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
194 # include "template/libxsmm_dnn_rnncell_st_gru_fwd_nc_kcck.tpl.c"
195   } else {
196     /* should not happen */
197   }
198 #else /* should not happen */
199   LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
200   status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
201 #endif
202   return status;
203 }
204 
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)205 LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
206 libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
207 {
208   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
209 #if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
210   typedef libxsmm_bfloat16 element_input_type;
211   typedef libxsmm_bfloat16 element_output_type;
212   typedef libxsmm_bfloat16 element_filter_type;
213 
214   /* some portable macrros fof BF16 <-> FP32 */
215 # include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
216 
217   if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
218     status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
219   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
220     status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
221   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
222     status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
223   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
224 #define LIBXSMM_RNN_CELL_AVX512
225 # include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_bf16.tpl.c"
226 #undef LIBXSMM_RNN_CELL_AVX512
227   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
228     status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
229   } else {
230     /* should not happen */
231   }
232 
233 # include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
234 #else /* should not happen */
235   LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
236   status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
237 #endif
238   return status;
239 }
240 
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)241 LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
242 libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
243 {
244   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
245 #if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
246   typedef libxsmm_bfloat16 element_input_type;
247   typedef libxsmm_bfloat16 element_output_type;
248   typedef libxsmm_bfloat16 element_filter_type;
249 
250 #define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
251   /* some portable macrros fof BF16 <-> FP32 */
252 # include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
253 
254   if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
255     status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
256   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
257     status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
258   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
259     status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
260   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
261 #define LIBXSMM_RNN_CELL_AVX512
262 # include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_bf16.tpl.c"
263 #undef LIBXSMM_RNN_CELL_AVX512
264   } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
265     status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
266   } else {
267     /* should not happen */
268   }
269 
270 # include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
271 #undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
272 #else /* should not happen */
273   LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
274   status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
275 #endif
276   return status;
277 }
278 
libxsmm_dnn_rnncell_st_fwd_nc_ck(libxsmm_dnn_rnncell * handle,int start_thread,int tid)279 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
280 {
281   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
282 
283   /* check if we have input, output and filter */
284 #if 0
285   if (handle->? == 0 ) {
286     status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
287     return status;
288   }
289 #endif
290 
291   /* check if we are on AVX512 */
292 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
293   if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) {
294     if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
295       status = libxsmm_dnn_rnncell_st_fwd_nc_ck_f32_f32( handle, start_thread, tid);
296     }
297 #if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
298     else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE && libxsmm_target_archid < LIBXSMM_X86_AVX512_CPX ) {
299       status = libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_emu( handle, start_thread, tid);
300     } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CPX ) {
301       status = libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16( handle, start_thread, tid);
302     }
303 #elif defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
304     else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE ) {
305       status = libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_emu( handle, start_thread, tid);
306     }
307 #endif
308     else {
309       status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
310       return status;
311     }
312   } else
313 #endif
314   {
315     if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
316       typedef float element_input_type;
317       typedef float element_output_type;
318       typedef float element_filter_type;
319       if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
320 #define LIBXSMM_DNN_RNN_RELU_FWD
321 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c"
322 #undef LIBXSMM_DNN_RNN_RELU_FWD
323       } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
324 #define LIBXSMM_DNN_RNN_SIGMOID_FWD
325 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c"
326 #undef LIBXSMM_DNN_RNN_SIGMOID_FWD
327       } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
328 #define LIBXSMM_DNN_RNN_TANH_FWD
329 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c"
330 #undef LIBXSMM_DNN_RNN_TANH_FWD
331       } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
332 # include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic.tpl.c"
333       } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
334 # include "template/libxsmm_dnn_rnncell_st_gru_fwd_nc_ck_generic.tpl.c"
335       } else {
336         /* should not happen */
337       }
338     } else {
339       status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
340       return status;
341     }
342   }
343 
344   return status;
345 }
346 
347 
libxsmm_dnn_rnncell_st_fwd_ncnc_kcck(libxsmm_dnn_rnncell * handle,int start_thread,int tid)348 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_ncnc_kcck(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
349 {
350   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
351 
352   /* check if we have input, output and filter */
353 #if 0
354   if (handle->? == 0 ) {
355     status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
356     return status;
357   }
358 #endif
359 
360   /* check if we are on AVX512 */
361 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
362   if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) {
363     if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
364       status = libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_f32_f32( handle, start_thread, tid);
365     } else {
366       status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
367       return status;
368     }
369   } else
370 #endif
371   {
372     if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
373       typedef float element_input_type;
374       typedef float element_output_type;
375       typedef float element_filter_type;
376       if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
377 #define LIBXSMM_DNN_RNN_RELU_FWD
378 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c"
379 #undef LIBXSMM_DNN_RNN_RELU_FWD
380       } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
381 #define LIBXSMM_DNN_RNN_SIGMOID_FWD
382 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c"
383 #undef LIBXSMM_DNN_RNN_SIGMOID_FWD
384       } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
385 #define LIBXSMM_DNN_RNN_TANH_FWD
386 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c"
387 #undef LIBXSMM_DNN_RNN_TANH_FWD
388       } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
389         status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
390       } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
391         status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
392       } else {
393         /* should not happen */
394       }
395     } else {
396       status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
397       return status;
398     }
399   }
400 
401   return status;
402 }
403 
libxsmm_dnn_rnncell_st_fwd_nc_kcck(libxsmm_dnn_rnncell * handle,int start_thread,int tid)404 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
405 {
406   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
407 
408   /* check if we have input, output and filter */
409 #if 0
410   if (handle->? == 0 ) {
411     status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
412     return status;
413   }
414 #endif
415 
416   /* check if we are on AVX512 */
417 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
418   if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) {
419     if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
420       status = libxsmm_dnn_rnncell_st_fwd_nc_kcck_f32_f32( handle, start_thread, tid);
421     }
422 #if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
423     else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE && libxsmm_target_archid < LIBXSMM_X86_AVX512_CPX ) {
424       status = libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_emu( handle, start_thread, tid);
425     } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CPX ) {
426       status = libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16( handle, start_thread, tid);
427     }
428 #elif defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
429     else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE  ) {
430       status = libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_emu( handle, start_thread, tid);
431     }
432 #endif
433     else {
434       status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
435       return status;
436     }
437   } else
438 #endif
439   {
440     if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
441       typedef float element_input_type;
442       typedef float element_output_type;
443       typedef float element_filter_type;
444       if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
445 #define LIBXSMM_DNN_RNN_RELU_FWD
446 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c"
447 #undef LIBXSMM_DNN_RNN_RELU_FWD
448       } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
449 #define LIBXSMM_DNN_RNN_SIGMOID_FWD
450 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c"
451 #undef LIBXSMM_DNN_RNN_SIGMOID_FWD
452       } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
453 #define LIBXSMM_DNN_RNN_TANH_FWD
454 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c"
455 #undef LIBXSMM_DNN_RNN_TANH_FWD
456       } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
457 # include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck.tpl.c"
458       } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
459 # include "template/libxsmm_dnn_rnncell_st_gru_fwd_nc_kcck.tpl.c"
460       } else {
461         /* should not happen */
462       }
463     } else {
464       status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
465       return status;
466     }
467   }
468 
469   return status;
470 }
471