1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved. *
3 * This file is part of the LIBXSMM library. *
4 * *
5 * For information on the license, see the LICENSE file. *
6 * Further information: https://github.com/hfp/libxsmm/ *
7 * SPDX-License-Identifier: BSD-3-Clause *
8 ******************************************************************************/
9 /* Alexander Heinecke, Kunal Banerjee (Intel Corp.)
10 ******************************************************************************/
11 #include "libxsmm_dnn_rnncell_forward.h"
12 #include "libxsmm_dnn_elementwise.h"
13 #include "libxsmm_main.h"
14
15
16 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
17 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
18 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
19 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
20 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
21 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
22 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, int start_thread, int tid);
23
24
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)25 LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
26 libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
27 {
28 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
29 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
30 typedef float element_input_type;
31 typedef float element_output_type;
32 typedef float element_filter_type;
33 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
34 # define LIBXSMM_DNN_RNN_RELU_FWD
35 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c"
36 # undef LIBXSMM_DNN_RNN_RELU_FWD
37 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
38 # define LIBXSMM_DNN_RNN_SIGMOID_FWD
39 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c"
40 # undef LIBXSMM_DNN_RNN_SIGMOID_FWD
41 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
42 # define LIBXSMM_DNN_RNN_TANH_FWD
43 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c"
44 # undef LIBXSMM_DNN_RNN_TANH_FWD
45 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
46 #define LIBXSMM_RNN_CELL_AVX512
47 # include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic.tpl.c"
48 #undef LIBXSMM_RNN_CELL_AVX512
49 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
50 # include "template/libxsmm_dnn_rnncell_st_gru_fwd_nc_ck_generic.tpl.c"
51 } else {
52 /* should not happen */
53 }
54 #else /* should not happen */
55 LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
56 status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
57 #endif
58 return status;
59 }
60
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)61 LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
62 libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
63 {
64 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
65 #if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__, __AVX512BW__, __AVX512DQ__*/
66 typedef libxsmm_bfloat16 element_input_type;
67 typedef libxsmm_bfloat16 element_output_type;
68 typedef libxsmm_bfloat16 element_filter_type;
69
70 /* some portable macrros fof BF16 <-> FP32 */
71 # include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
72
73 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
74 status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
75 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
76 status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
77 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
78 status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
79 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
80 #define LIBXSMM_RNN_CELL_AVX512
81 # include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic_bf16.tpl.c"
82 #undef LIBXSMM_RNN_CELL_AVX512
83 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
84 status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
85 } else {
86 /* should not happen */
87 }
88
89 # include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
90 #else /* should not happen */
91 LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
92 status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
93 #endif
94 return status;
95 }
96
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)97 LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
98 libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
99 {
100 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
101 #if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__, __AVX512BW__, __AVX512DQ__, __AVX512BF16__*/
102 typedef libxsmm_bfloat16 element_input_type;
103 typedef libxsmm_bfloat16 element_output_type;
104 typedef libxsmm_bfloat16 element_filter_type;
105
106 #define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
107 /* some portable macrros fof BF16 <-> FP32 */
108 # include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
109
110 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
111 status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
112 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
113 status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
114 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
115 status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
116 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
117 #define LIBXSMM_RNN_CELL_AVX512
118 # include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic_bf16.tpl.c"
119 #undef LIBXSMM_RNN_CELL_AVX512
120 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
121 status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
122 } else {
123 /* should not happen */
124 }
125
126 # include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
127 #undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
128 #else /* should not happen */
129 LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
130 status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
131 #endif
132 return status;
133 }
134
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)135 LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
136 libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
137 {
138 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
139 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
140 typedef float element_input_type;
141 typedef float element_output_type;
142 typedef float element_filter_type;
143 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
144 # define LIBXSMM_DNN_RNN_RELU_FWD
145 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c"
146 # undef LIBXSMM_DNN_RNN_RELU_FWD
147 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
148 # define LIBXSMM_DNN_RNN_SIGMOID_FWD
149 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c"
150 # undef LIBXSMM_DNN_RNN_SIGMOID_FWD
151 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
152 # define LIBXSMM_DNN_RNN_TANH_FWD
153 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c"
154 # undef LIBXSMM_DNN_RNN_TANH_FWD
155 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
156 status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
157 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
158 status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
159 } else {
160 /* should not happen */
161 }
162 #else /* should not happen */
163 LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
164 status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
165 #endif
166 return status;
167 }
168
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)169 LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
170 libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_f32_f32(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
171 {
172 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
173 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
174 typedef float element_input_type;
175 typedef float element_output_type;
176 typedef float element_filter_type;
177 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
178 # define LIBXSMM_DNN_RNN_RELU_FWD
179 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c"
180 # undef LIBXSMM_DNN_RNN_RELU_FWD
181 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
182 # define LIBXSMM_DNN_RNN_SIGMOID_FWD
183 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c"
184 # undef LIBXSMM_DNN_RNN_SIGMOID_FWD
185 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
186 # define LIBXSMM_DNN_RNN_TANH_FWD
187 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c"
188 # undef LIBXSMM_DNN_RNN_TANH_FWD
189 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
190 #define LIBXSMM_RNN_CELL_AVX512
191 # include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck.tpl.c"
192 #undef LIBXSMM_RNN_CELL_AVX512
193 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
194 # include "template/libxsmm_dnn_rnncell_st_gru_fwd_nc_kcck.tpl.c"
195 } else {
196 /* should not happen */
197 }
198 #else /* should not happen */
199 LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
200 status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
201 #endif
202 return status;
203 }
204
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)205 LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
206 libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_emu(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
207 {
208 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
209 #if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
210 typedef libxsmm_bfloat16 element_input_type;
211 typedef libxsmm_bfloat16 element_output_type;
212 typedef libxsmm_bfloat16 element_filter_type;
213
214 /* some portable macrros fof BF16 <-> FP32 */
215 # include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
216
217 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
218 status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
219 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
220 status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
221 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
222 status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
223 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
224 #define LIBXSMM_RNN_CELL_AVX512
225 # include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_bf16.tpl.c"
226 #undef LIBXSMM_RNN_CELL_AVX512
227 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
228 status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
229 } else {
230 /* should not happen */
231 }
232
233 # include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
234 #else /* should not happen */
235 LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
236 status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
237 #endif
238 return status;
239 }
240
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)241 LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
242 libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
243 {
244 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
245 #if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
246 typedef libxsmm_bfloat16 element_input_type;
247 typedef libxsmm_bfloat16 element_output_type;
248 typedef libxsmm_bfloat16 element_filter_type;
249
250 #define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
251 /* some portable macrros fof BF16 <-> FP32 */
252 # include "template/libxsmm_dnn_bf16_macros_define.tpl.c"
253
254 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
255 status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
256 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
257 status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
258 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
259 status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
260 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
261 #define LIBXSMM_RNN_CELL_AVX512
262 # include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_bf16.tpl.c"
263 #undef LIBXSMM_RNN_CELL_AVX512
264 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
265 status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
266 } else {
267 /* should not happen */
268 }
269
270 # include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
271 #undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
272 #else /* should not happen */
273 LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
274 status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
275 #endif
276 return status;
277 }
278
libxsmm_dnn_rnncell_st_fwd_nc_ck(libxsmm_dnn_rnncell * handle,int start_thread,int tid)279 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_ck(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
280 {
281 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
282
283 /* check if we have input, output and filter */
284 #if 0
285 if (handle->? == 0 ) {
286 status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
287 return status;
288 }
289 #endif
290
291 /* check if we are on AVX512 */
292 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
293 if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) {
294 if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
295 status = libxsmm_dnn_rnncell_st_fwd_nc_ck_f32_f32( handle, start_thread, tid);
296 }
297 #if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
298 else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE && libxsmm_target_archid < LIBXSMM_X86_AVX512_CPX ) {
299 status = libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_emu( handle, start_thread, tid);
300 } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CPX ) {
301 status = libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16( handle, start_thread, tid);
302 }
303 #elif defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
304 else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE ) {
305 status = libxsmm_dnn_rnncell_st_fwd_nc_ck_bf16_bf16_emu( handle, start_thread, tid);
306 }
307 #endif
308 else {
309 status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
310 return status;
311 }
312 } else
313 #endif
314 {
315 if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
316 typedef float element_input_type;
317 typedef float element_output_type;
318 typedef float element_filter_type;
319 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
320 #define LIBXSMM_DNN_RNN_RELU_FWD
321 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c"
322 #undef LIBXSMM_DNN_RNN_RELU_FWD
323 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
324 #define LIBXSMM_DNN_RNN_SIGMOID_FWD
325 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c"
326 #undef LIBXSMM_DNN_RNN_SIGMOID_FWD
327 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
328 #define LIBXSMM_DNN_RNN_TANH_FWD
329 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_ck_generic.tpl.c"
330 #undef LIBXSMM_DNN_RNN_TANH_FWD
331 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
332 # include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_ck_generic.tpl.c"
333 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
334 # include "template/libxsmm_dnn_rnncell_st_gru_fwd_nc_ck_generic.tpl.c"
335 } else {
336 /* should not happen */
337 }
338 } else {
339 status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
340 return status;
341 }
342 }
343
344 return status;
345 }
346
347
libxsmm_dnn_rnncell_st_fwd_ncnc_kcck(libxsmm_dnn_rnncell * handle,int start_thread,int tid)348 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_ncnc_kcck(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
349 {
350 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
351
352 /* check if we have input, output and filter */
353 #if 0
354 if (handle->? == 0 ) {
355 status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
356 return status;
357 }
358 #endif
359
360 /* check if we are on AVX512 */
361 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
362 if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) {
363 if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
364 status = libxsmm_dnn_rnncell_st_fwd_ncnc_kcck_f32_f32( handle, start_thread, tid);
365 } else {
366 status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
367 return status;
368 }
369 } else
370 #endif
371 {
372 if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
373 typedef float element_input_type;
374 typedef float element_output_type;
375 typedef float element_filter_type;
376 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
377 #define LIBXSMM_DNN_RNN_RELU_FWD
378 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c"
379 #undef LIBXSMM_DNN_RNN_RELU_FWD
380 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
381 #define LIBXSMM_DNN_RNN_SIGMOID_FWD
382 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c"
383 #undef LIBXSMM_DNN_RNN_SIGMOID_FWD
384 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
385 #define LIBXSMM_DNN_RNN_TANH_FWD
386 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_ncnc_kcck.tpl.c"
387 #undef LIBXSMM_DNN_RNN_TANH_FWD
388 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
389 status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
390 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
391 status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
392 } else {
393 /* should not happen */
394 }
395 } else {
396 status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
397 return status;
398 }
399 }
400
401 return status;
402 }
403
libxsmm_dnn_rnncell_st_fwd_nc_kcck(libxsmm_dnn_rnncell * handle,int start_thread,int tid)404 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_rnncell_st_fwd_nc_kcck(libxsmm_dnn_rnncell* handle, int start_thread, int tid)
405 {
406 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
407
408 /* check if we have input, output and filter */
409 #if 0
410 if (handle->? == 0 ) {
411 status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
412 return status;
413 }
414 #endif
415
416 /* check if we are on AVX512 */
417 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
418 if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) {
419 if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
420 status = libxsmm_dnn_rnncell_st_fwd_nc_kcck_f32_f32( handle, start_thread, tid);
421 }
422 #if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
423 else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE && libxsmm_target_archid < LIBXSMM_X86_AVX512_CPX ) {
424 status = libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_emu( handle, start_thread, tid);
425 } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CPX ) {
426 status = libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16( handle, start_thread, tid);
427 }
428 #elif defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
429 else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE ) {
430 status = libxsmm_dnn_rnncell_st_fwd_nc_kcck_bf16_bf16_emu( handle, start_thread, tid);
431 }
432 #endif
433 else {
434 status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
435 return status;
436 }
437 } else
438 #endif
439 {
440 if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
441 typedef float element_input_type;
442 typedef float element_output_type;
443 typedef float element_filter_type;
444 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_RELU ) {
445 #define LIBXSMM_DNN_RNN_RELU_FWD
446 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c"
447 #undef LIBXSMM_DNN_RNN_RELU_FWD
448 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_SIGMOID ) {
449 #define LIBXSMM_DNN_RNN_SIGMOID_FWD
450 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c"
451 #undef LIBXSMM_DNN_RNN_SIGMOID_FWD
452 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_RNN_TANH ) {
453 #define LIBXSMM_DNN_RNN_TANH_FWD
454 # include "template/libxsmm_dnn_rnncell_st_rnn_fwd_nc_kcck.tpl.c"
455 #undef LIBXSMM_DNN_RNN_TANH_FWD
456 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
457 # include "template/libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck.tpl.c"
458 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
459 # include "template/libxsmm_dnn_rnncell_st_gru_fwd_nc_kcck.tpl.c"
460 } else {
461 /* should not happen */
462 }
463 } else {
464 status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
465 return status;
466 }
467 }
468
469 return status;
470 }
471