1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved. *
3 * This file is part of the LIBXSMM library. *
4 * *
5 * For information on the license, see the LICENSE file. *
6 * Further information: https://github.com/hfp/libxsmm/ *
7 * SPDX-License-Identifier: BSD-3-Clause *
8 ******************************************************************************/
9 /* Alexander Heinecke, Evangelos Georganas, Kunal Banerjee (Intel Corp.)
10 ******************************************************************************/
11 #include "libxsmm_dnn_rnncell_forward.h"
12 #include "libxsmm_dnn_rnncell_backward_weight_update.h"
13 #include "libxsmm_dnn_elementwise.h"
14 #include "libxsmm_main.h"
15
16 #if defined(LIBXSMM_OFFLOAD_TARGET)
17 # pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
18 #endif
19 #include <math.h>
20 #if defined(LIBXSMM_OFFLOAD_TARGET)
21 # pragma offload_attribute(pop)
22 #endif
23
libxsmm_dnn_create_rnncell(libxsmm_dnn_rnncell_desc rnncell_desc,libxsmm_dnn_err_t * status)24 LIBXSMM_API libxsmm_dnn_rnncell* libxsmm_dnn_create_rnncell(libxsmm_dnn_rnncell_desc rnncell_desc, libxsmm_dnn_err_t* status)
25 {
26 libxsmm_dnn_rnncell* handle = 0;
27
28 /* init libxsmm */
29 LIBXSMM_INIT
30
31 /* some check we can do before allocating the handle */
32 if ( (rnncell_desc.datatype_in != rnncell_desc.datatype_out) ||
33 ( (rnncell_desc.datatype_in != LIBXSMM_DNN_DATATYPE_BF16) && (rnncell_desc.datatype_in != LIBXSMM_DNN_DATATYPE_F32) ) ) {
34 *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
35 return NULL;
36 }
37 /* let's do some simple checks for BF16 as this limits the cell and architecture */
38 if ( (rnncell_desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (rnncell_desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
39 if ( (LIBXSMM_X86_AVX512_CORE > libxsmm_target_archid) || (rnncell_desc.C % 16 != 0) || (rnncell_desc.K % 16 != 0) ) {
40 *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
41 return NULL;
42 }
43 }
44 /* we need at least one timestep */
45 if (rnncell_desc.max_T < 1) {
46 *status = LIBXSMM_DNN_ERR_TIME_STEPS_TOO_SMALL;
47 return NULL;
48 }
49
50 handle = (libxsmm_dnn_rnncell*)malloc(sizeof(libxsmm_dnn_rnncell));
51 if (0 != handle) {
52 *status = LIBXSMM_DNN_SUCCESS;
53 /* zero entire content; not only safer but also sets data and code pointers to NULL */
54 memset(handle, 0, sizeof(*handle));
55 /* initialize known handle components */
56 handle->desc = rnncell_desc;
57 /* set current seq length to max length */
58 handle->T = rnncell_desc.max_T;
59 /* set blocking factors */
60 handle->bk = (handle->desc.bk == 0) ? 64 : handle->desc.bk;
61 handle->bn = (handle->desc.bn == 0) ? 64 : handle->desc.bn;
62 handle->bc = (handle->desc.bc == 0) ? 64 : handle->desc.bc;
63 if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
64 handle->lpb = 2;
65 } else {
66 handle->lpb = 1;
67 }
68 /* validate blocking factors */
69 if ( handle->desc.N % handle->bn != 0 ) {
70 handle->bn = handle->desc.N;
71 *status = LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_N_BLOCKING;
72 }
73 if ( handle->desc.C % handle->bc != 0 ) {
74 handle->bc = handle->desc.C;
75 *status = LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_C_BLOCKING;
76 }
77 if ( handle->desc.K % handle->bk != 0 ) {
78 handle->bk = handle->desc.K;
79 *status = LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_K_BLOCKING;
80 }
81
82 /* In case of BF16 for now hoist the BRGEMM and make them to use STRIDED variant by default */
83 if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
84 const int typesize_in = (int)libxsmm_dnn_typesize(handle->desc.datatype_in);
85 const libxsmm_blasint K = handle->desc.K;
86 const libxsmm_blasint N = handle->desc.N;
87 const libxsmm_blasint C = handle->desc.C;
88 const libxsmm_blasint bk = handle->bk;
89 const libxsmm_blasint bn = handle->bn;
90 const libxsmm_blasint bc = handle->bc;
91 const libxsmm_blasint cBlocks = C/bc;
92 const libxsmm_blasint kBlocks = K/bk;
93 const libxsmm_blasint nBlocks = N/bn;
94 libxsmm_blasint BF, CB_BLOCKS, KB_BLOCKS;
95 libxsmm_blasint stride_a, stride_b;
96 int kernel_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
97
98 /* Blocking reduction domain if it is too large */
99 BF = 1;
100 if ((C > 1024 && C <= 2048) || (K > 1024 && K <= 2048)) {
101 BF = 8;
102 while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) {
103 BF--;
104 }
105 }
106 if (C > 2048 || K > 2048) {
107 BF = 16;
108 while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) {
109 BF--;
110 }
111 }
112 if (C == 2048 && K == 1024) {
113 BF = 2;
114 }
115 CB_BLOCKS = cBlocks/BF;
116 KB_BLOCKS = kBlocks/BF;
117
118 /* define batch-reduce gemm kernels */
119 stride_a = bc * bk * typesize_in;
120 stride_b = bc * typesize_in;
121 handle->fwd_kernela = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bn, bc, stride_a, stride_b, CB_BLOCKS, &bk, &C, &K, NULL, NULL, &kernel_flags, NULL );
122 stride_a = bk * bk * typesize_in;
123 stride_b = bk * typesize_in;
124 handle->fwd_kernelb = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bn, bk, stride_a, stride_b, KB_BLOCKS, &bk, &K, &K, NULL, NULL, &kernel_flags, NULL );
125
126 KB_BLOCKS = kBlocks/BF;
127
128 stride_a = bc * bk * typesize_in;
129 stride_b = bk * typesize_in;
130 handle->bwdupd_kernela = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bc, bn, bk, stride_a, stride_b, KB_BLOCKS, &bc, &K, &C, NULL, NULL, &kernel_flags, NULL);
131 stride_a = bn * bk * typesize_in;
132 stride_b = bn * typesize_in;
133 handle->bwdupd_kernelb = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bk, bn, stride_a, stride_b, nBlocks, &bk, &N, &bk, NULL, NULL, &kernel_flags, NULL);
134 stride_a = bn * bk * typesize_in;
135 stride_b = bn * typesize_in;
136 handle->bwdupd_kernelc = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bc, bn, stride_a, stride_b, nBlocks, &bk, &N, &bk, NULL, NULL, &kernel_flags, NULL);
137 stride_a = bk * bk * typesize_in;
138 stride_b = bk * typesize_in;
139 handle->bwdupd_kerneld = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bn, bk, stride_a, stride_b, KB_BLOCKS, &bk, &K, &K, NULL, NULL, &kernel_flags, NULL);
140 }
141
142 /* Need to allocate space for scratch libxsmm_dnn_tensor's, let's set all pointers to zero */
143 handle->internal_z = 0;
144 handle->scratch_wT = 0;
145 handle->scratch_rT = 0;
146 handle->scratch_xT = 0;
147 handle->scratch_hT = 0;
148 handle->scratch_deltat = 0;
149 handle->scratch_di = 0;
150 handle->scratch_df = 0;
151 handle->scratch_do = 0;
152 handle->scratch_dci = 0;
153 handle->scratch_diB = 0;
154 handle->scratch_dfB = 0;
155 handle->scratch_dpB = 0;
156 handle->scratch_dciB = 0;
157 /* initialize a high-performant barrier */
158 handle->barrier = libxsmm_barrier_create(handle->desc.threads, 1);
159 if (NULL == handle->barrier)
160 {
161 *status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
162 free(handle);
163 return NULL;
164 }
165 } else {
166 *status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
167 }
168 return handle;
169 }
170
171
libxsmm_dnn_destroy_rnncell(const libxsmm_dnn_rnncell * handle)172 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_rnncell(const libxsmm_dnn_rnncell* handle)
173 {
174 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
175 if (0 != handle) {
176 /* Deallocate barrier */
177 if (handle->barrier != 0 ) { libxsmm_barrier_release((const libxsmm_barrier*)handle->barrier); }
178 /* deallocate handle structure */
179 free(/*remove constness*/(libxsmm_dnn_rnncell*)handle);
180 } else {
181 status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
182 }
183 return status;
184 }
185
186
libxsmm_dnn_rnncell_create_tensor_datalayout(const libxsmm_dnn_rnncell * handle,const libxsmm_dnn_tensor_type type,libxsmm_dnn_err_t * status)187 LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_rnncell_create_tensor_datalayout(const libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status)
188 {
189 libxsmm_dnn_tensor_datalayout* layout;
190 *status = LIBXSMM_DNN_SUCCESS;
191 layout = 0;
192 if (handle != 0) {
193 layout = (libxsmm_dnn_tensor_datalayout*) malloc(sizeof(libxsmm_dnn_tensor_datalayout));
194 if (layout != 0) {
195 memset(layout, 0, sizeof(libxsmm_dnn_tensor_datalayout));
196 if ( (type == LIBXSMM_DNN_RNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_RNN_GRADIENT_INPUT) ||
197 (type == LIBXSMM_DNN_RNN_REGULAR_CS_PREV) || (type == LIBXSMM_DNN_RNN_GRADIENT_CS_PREV) ||
198 (type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV) || (type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV) ||
199 (type == LIBXSMM_DNN_RNN_REGULAR_CS) || (type == LIBXSMM_DNN_RNN_GRADIENT_CS) ||
200 (type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE) || (type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE) ||
201 (type == LIBXSMM_DNN_RNN_INTERNAL_I) || (type == LIBXSMM_DNN_RNN_INTERNAL_F) ||
202 (type == LIBXSMM_DNN_RNN_INTERNAL_O) || (type == LIBXSMM_DNN_RNN_INTERNAL_CI) ||
203 (type == LIBXSMM_DNN_RNN_INTERNAL_CO) ) {
204 layout->format = handle->desc.buffer_format;
205 layout->tensor_type = LIBXSMM_DNN_ACTIVATION;
206 if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) {
207 if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) {
208 layout->datatype = handle->desc.datatype_in;
209 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
210 layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
211
212 if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
213 layout->num_dims = 5;
214
215 if ( (type == LIBXSMM_DNN_RNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_RNN_GRADIENT_INPUT) ) {
216 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
217 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
218 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
219 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
220 layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_T;
221 layout->dim_size[0] = (unsigned int)handle->bc;
222 layout->dim_size[1] = (unsigned int)handle->bn;
223 layout->dim_size[2] = (unsigned int)(handle->desc.C / handle->bc);
224 layout->dim_size[3] = (unsigned int)(handle->desc.N / handle->bn);
225 layout->dim_size[4] = (unsigned int)handle->desc.max_T;
226 } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_CS_PREV) || (type == LIBXSMM_DNN_RNN_GRADIENT_CS_PREV) ||
227 (type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV) || (type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV) ||
228 (type == LIBXSMM_DNN_RNN_REGULAR_CS) || (type == LIBXSMM_DNN_RNN_GRADIENT_CS) ||
229 (type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE) || (type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE) ||
230 (type == LIBXSMM_DNN_RNN_INTERNAL_I) || (type == LIBXSMM_DNN_RNN_INTERNAL_F) ||
231 (type == LIBXSMM_DNN_RNN_INTERNAL_O) || (type == LIBXSMM_DNN_RNN_INTERNAL_CI) ||
232 (type == LIBXSMM_DNN_RNN_INTERNAL_CO) ) {
233 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
234 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
235 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
236 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
237 layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_T;
238 layout->dim_size[0] = (unsigned int)handle->bk;
239 layout->dim_size[1] = (unsigned int)handle->bn;
240 layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk);
241 layout->dim_size[3] = (unsigned int)(handle->desc.N / handle->bn);
242 layout->dim_size[4] = (unsigned int)handle->desc.max_T;
243 } else {
244 free(layout->dim_type);
245 free(layout->dim_size);
246 free(layout);
247 layout = 0; /* make sure a NULL is returned */
248 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
249 }
250 } else {
251 free(layout);
252 layout = 0; /* make sure a NULL is returned */
253 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
254 }
255 } else {
256 free(layout);
257 layout = 0; /* make sure a NULL is returned */
258 *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
259 }
260 } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NC) > 0) {
261 if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) {
262 layout->datatype = handle->desc.datatype_in;
263 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(3*sizeof(libxsmm_dnn_tensor_dimtype));
264 layout->dim_size = (unsigned int*) malloc(3*sizeof(unsigned int));
265
266 if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
267 layout->num_dims = 3;
268
269 if ( (type == LIBXSMM_DNN_RNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_RNN_GRADIENT_INPUT) ) {
270 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
271 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
272 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_T;
273 layout->dim_size[0] = (unsigned int)handle->desc.C;
274 layout->dim_size[1] = (unsigned int)handle->desc.N;
275 layout->dim_size[2] = (unsigned int)handle->desc.max_T;
276 } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_CS_PREV) || (type == LIBXSMM_DNN_RNN_GRADIENT_CS_PREV) ||
277 (type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV) || (type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV) ||
278 (type == LIBXSMM_DNN_RNN_REGULAR_CS) || (type == LIBXSMM_DNN_RNN_GRADIENT_CS) ||
279 (type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE) || (type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE) ||
280 (type == LIBXSMM_DNN_RNN_INTERNAL_I) || (type == LIBXSMM_DNN_RNN_INTERNAL_F) ||
281 (type == LIBXSMM_DNN_RNN_INTERNAL_O) || (type == LIBXSMM_DNN_RNN_INTERNAL_CI) ||
282 (type == LIBXSMM_DNN_RNN_INTERNAL_CO) ) {
283 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
284 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
285 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_T;
286 layout->dim_size[0] = (unsigned int)handle->desc.K;
287 layout->dim_size[1] = (unsigned int)handle->desc.N;
288 layout->dim_size[2] = (unsigned int)handle->desc.max_T;
289 } else {
290 free(layout->dim_type);
291 free(layout->dim_size);
292 free(layout);
293 layout = 0; /* make sure a NULL is returned */
294 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
295 }
296 } else {
297 free(layout);
298 layout = 0; /* make sure a NULL is returned */
299 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
300 }
301 } else {
302 free(layout);
303 layout = 0; /* make sure a NULL is returned */
304 *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
305 }
306 } else {
307 free(layout);
308 layout = 0; /* make sure a NULL is returned */
309 *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
310 }
311 } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) ||
312 (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) ) {
313 layout->format = handle->desc.filter_format;
314 layout->tensor_type = LIBXSMM_DNN_FILTER;
315 if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) > 0) {
316 if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
317 layout->datatype = handle->desc.datatype_in;
318 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM || handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
319 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
320 layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
321
322 if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
323 layout->num_dims = 5;
324
325 if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) ) {
326 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
327 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
328 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
329 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
330 layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
331 layout->dim_size[0] = (unsigned int)handle->bk;
332 layout->dim_size[1] = (unsigned int)handle->bc;
333 layout->dim_size[2] = (unsigned int)(handle->desc.C / handle->bc);
334 layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
335 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
336 layout->dim_size[4] = 4;
337 } else {
338 layout->dim_size[4] = 3;
339 }
340 } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) ) {
341 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
342 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
343 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
344 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
345 layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
346 layout->dim_size[0] = (unsigned int)handle->bk;
347 layout->dim_size[1] = (unsigned int)handle->bk;
348 layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk);
349 layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
350 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
351 layout->dim_size[4] = 4;
352 } else {
353 layout->dim_size[4] = 3;
354 }
355 } else {
356 free(layout->dim_type);
357 free(layout->dim_size);
358 free(layout);
359 layout = 0; /* make sure a NULL is returned */
360 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
361 }
362 } else {
363 free(layout);
364 layout = 0; /* make sure a NULL is returned */
365 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
366 }
367 } else {
368 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
369 layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
370
371 if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
372 layout->num_dims = 4;
373
374 if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) ) {
375 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
376 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
377 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
378 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
379 layout->dim_size[0] = (unsigned int)handle->bk;
380 layout->dim_size[1] = (unsigned int)handle->bc;
381 layout->dim_size[2] = (unsigned int)(handle->desc.C / handle->bc);
382 layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
383 } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) ) {
384 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
385 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
386 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
387 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
388 layout->dim_size[0] = (unsigned int)handle->bk;
389 layout->dim_size[1] = (unsigned int)handle->bk;
390 layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk);
391 layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
392 } else {
393 free(layout->dim_type);
394 free(layout->dim_size);
395 free(layout);
396 layout = 0; /* make sure a NULL is returned */
397 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
398 }
399 } else {
400 free(layout);
401 layout = 0; /* make sure a NULL is returned */
402 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
403 }
404 }
405 } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
406 layout->datatype = handle->desc.datatype_in;
407 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM || handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
408 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(6*sizeof(libxsmm_dnn_tensor_dimtype));
409 layout->dim_size = (unsigned int*) malloc(6*sizeof(unsigned int));
410
411 if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
412 layout->num_dims = 6;
413
414 if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) ) {
415 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
416 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
417 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
418 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
419 layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
420 layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
421 layout->dim_size[0] = (unsigned int)handle->lpb;
422 layout->dim_size[1] = (unsigned int)handle->bk;
423 layout->dim_size[2] = (unsigned int)(handle->bc / handle->lpb);
424 layout->dim_size[3] = (unsigned int)(handle->desc.C / handle->bc);
425 layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk);
426 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
427 layout->dim_size[5] = 4;
428 } else {
429 layout->dim_size[5] = 3;
430 }
431 } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) ) {
432 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
433 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
434 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
435 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
436 layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
437 layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
438 layout->dim_size[0] = (unsigned int)handle->lpb;
439 layout->dim_size[1] = (unsigned int)handle->bk;
440 layout->dim_size[2] = (unsigned int)(handle->bk / handle->lpb);
441 layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
442 layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk);
443 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
444 layout->dim_size[5] = 4;
445 } else {
446 layout->dim_size[5] = 3;
447 }
448 } else {
449 free(layout->dim_type);
450 free(layout->dim_size);
451 free(layout);
452 layout = 0; /* make sure a NULL is returned */
453 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
454 }
455 } else {
456 free(layout);
457 layout = 0; /* make sure a NULL is returned */
458 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
459 }
460 } else {
461 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
462 layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
463
464 if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
465 layout->num_dims = 5;
466
467 if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) ) {
468 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
469 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
470 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
471 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
472 layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
473 layout->dim_size[0] = (unsigned int)handle->lpb;
474 layout->dim_size[1] = (unsigned int)handle->bk;
475 layout->dim_size[2] = (unsigned int)(handle->bc / handle->lpb);
476 layout->dim_size[3] = (unsigned int)(handle->desc.C / handle->bc);
477 layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk);
478 } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) ) {
479 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
480 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
481 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
482 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
483 layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
484 layout->dim_size[0] = (unsigned int)handle->lpb;
485 layout->dim_size[1] = (unsigned int)handle->bk;
486 layout->dim_size[2] = (unsigned int)(handle->bk / handle->lpb);
487 layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
488 layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk);
489 } else {
490 free(layout->dim_type);
491 free(layout->dim_size);
492 free(layout);
493 layout = 0; /* make sure a NULL is returned */
494 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
495 }
496 } else {
497 free(layout);
498 layout = 0; /* make sure a NULL is returned */
499 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
500 }
501 }
502
503 } else {
504 free(layout);
505 layout = 0; /* make sure a NULL is returned */
506 *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
507 }
508 } else if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CK) > 0) {
509 if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) {
510 layout->datatype = handle->desc.datatype_in;
511 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype));
512 layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int));
513 if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
514 layout->num_dims = 2;
515
516 if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) ) {
517 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
518 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
519 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
520 layout->dim_size[0] = (unsigned int)(handle->desc.K * 4);
521 layout->dim_size[1] = (unsigned int)handle->desc.C;
522 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
523 layout->dim_size[0] = (unsigned int)(handle->desc.K * 3);
524 layout->dim_size[1] = (unsigned int)handle->desc.C;
525 } else {
526 layout->dim_size[0] = (unsigned int)handle->desc.K;
527 layout->dim_size[1] = (unsigned int)handle->desc.C;
528 }
529 } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) ) {
530 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
531 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
532 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
533 layout->dim_size[0] = (unsigned int)(handle->desc.K * 4);
534 layout->dim_size[1] = (unsigned int)handle->desc.K;
535 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
536 layout->dim_size[0] = (unsigned int)(handle->desc.K * 3);
537 layout->dim_size[1] = (unsigned int)handle->desc.K;
538 } else {
539 layout->dim_size[0] = (unsigned int)handle->desc.K;
540 layout->dim_size[1] = (unsigned int)handle->desc.K;
541 }
542 } else {
543 free(layout->dim_type);
544 free(layout->dim_size);
545 free(layout);
546 layout = 0; /* make sure a NULL is returned */
547 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
548 }
549 } else {
550 free(layout);
551 layout = 0; /* make sure a NULL is returned */
552 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
553 }
554 } else {
555 free(layout);
556 layout = 0; /* make sure a NULL is returned */
557 *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
558 }
559 } else {
560 free(layout);
561 layout = 0; /* make sure a NULL is returned */
562 *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
563 }
564 } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) || (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) ) {
565 layout->format = handle->desc.filter_format;
566 layout->tensor_type = LIBXSMM_DNN_FILTER;
567 if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) > 0) {
568 if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
569 layout->datatype = handle->desc.datatype_in;
570 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM || handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
571 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
572 layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
573
574 if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
575 layout->num_dims = 5;
576
577 if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) ) {
578 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
579 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
580 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
581 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
582 layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
583 layout->dim_size[0] = (unsigned int)handle->bc;
584 layout->dim_size[1] = (unsigned int)handle->bk;
585 layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk);
586 layout->dim_size[3] = (unsigned int)(handle->desc.C / handle->bc);
587 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
588 layout->dim_size[4] = 4;
589 } else {
590 layout->dim_size[4] = 3;
591 }
592 } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) ) {
593 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
594 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
595 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
596 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
597 layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
598 layout->dim_size[0] = (unsigned int)handle->bk;
599 layout->dim_size[1] = (unsigned int)handle->bk;
600 layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk);
601 layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
602 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
603 layout->dim_size[4] = 4;
604 } else {
605 layout->dim_size[4] = 3;
606 }
607 } else {
608 free(layout->dim_type);
609 free(layout->dim_size);
610 free(layout);
611 layout = 0; /* make sure a NULL is returned */
612 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
613 }
614 } else {
615 free(layout);
616 layout = 0; /* make sure a NULL is returned */
617 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
618 }
619 } else {
620 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
621 layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
622
623 if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
624 layout->num_dims = 4;
625
626 if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) ) {
627 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
628 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
629 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
630 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
631 layout->dim_size[0] = (unsigned int)handle->bc;
632 layout->dim_size[1] = (unsigned int)handle->bk;
633 layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk);
634 layout->dim_size[3] = (unsigned int)(handle->desc.C / handle->bc);
635 } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) ) {
636 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
637 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
638 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
639 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
640 layout->dim_size[0] = (unsigned int)handle->bk;
641 layout->dim_size[1] = (unsigned int)handle->bk;
642 layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk);
643 layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
644 } else {
645 free(layout->dim_type);
646 free(layout->dim_size);
647 free(layout);
648 layout = 0; /* make sure a NULL is returned */
649 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
650 }
651 } else {
652 free(layout);
653 layout = 0; /* make sure a NULL is returned */
654 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
655 }
656 }
657 } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
658 layout->datatype = handle->desc.datatype_in;
659 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM || handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
660 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(6*sizeof(libxsmm_dnn_tensor_dimtype));
661 layout->dim_size = (unsigned int*) malloc(6*sizeof(unsigned int));
662
663 if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
664 layout->num_dims = 6;
665
666 if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) ) {
667 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
668 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
669 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
670 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
671 layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
672 layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
673 layout->dim_size[0] = (unsigned int)handle->lpb;
674 layout->dim_size[1] = (unsigned int)handle->bc;
675 layout->dim_size[2] = (unsigned int)(handle->bk / handle->lpb);
676 layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
677 layout->dim_size[4] = (unsigned int)(handle->desc.C / handle->bc);
678 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
679 layout->dim_size[5] = 4;
680 } else {
681 layout->dim_size[5] = 3;
682 }
683 } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) ) {
684 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
685 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
686 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
687 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
688 layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
689 layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
690 layout->dim_size[0] = (unsigned int)handle->lpb;
691 layout->dim_size[1] = (unsigned int)handle->bk;
692 layout->dim_size[2] = (unsigned int)(handle->bk / handle->lpb);
693 layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
694 layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk);
695 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
696 layout->dim_size[5] = 4;
697 } else {
698 layout->dim_size[5] = 3;
699 }
700 } else {
701 free(layout->dim_type);
702 free(layout->dim_size);
703 free(layout);
704 layout = 0; /* make sure a NULL is returned */
705 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
706 }
707 } else {
708 free(layout);
709 layout = 0; /* make sure a NULL is returned */
710 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
711 }
712 } else {
713 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
714 layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
715
716 if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
717 layout->num_dims = 5;
718
719 if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) ) {
720 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
721 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
722 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
723 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
724 layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
725 layout->dim_size[0] = (unsigned int)handle->lpb;
726 layout->dim_size[1] = (unsigned int)handle->bc;
727 layout->dim_size[2] = (unsigned int)(handle->bk / handle->lpb);
728 layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
729 layout->dim_size[4] = (unsigned int)(handle->desc.C / handle->bc);
730 } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) ) {
731 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
732 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
733 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
734 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
735 layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
736 layout->dim_size[0] = (unsigned int)handle->lpb;
737 layout->dim_size[1] = (unsigned int)handle->bk;
738 layout->dim_size[2] = (unsigned int)(handle->bk / handle->lpb);
739 layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
740 layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk);
741 } else {
742 free(layout->dim_type);
743 free(layout->dim_size);
744 free(layout);
745 layout = 0; /* make sure a NULL is returned */
746 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
747 }
748 } else {
749 free(layout);
750 layout = 0; /* make sure a NULL is returned */
751 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
752 }
753 }
754 } else {
755 free(layout);
756 layout = 0; /* make sure a NULL is returned */
757 *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
758 }
759 } else if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CK) > 0) {
760 if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) {
761 layout->datatype = handle->desc.datatype_in;
762 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype));
763 layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int));
764
765 if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
766 layout->num_dims = 2;
767
768 if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) ) {
769 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
770 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
771 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
772 layout->dim_size[0] = (unsigned int)handle->desc.C;
773 layout->dim_size[1] = (unsigned int)(handle->desc.K * 4);
774 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
775 layout->dim_size[0] = (unsigned int)handle->desc.C;
776 layout->dim_size[1] = (unsigned int)(handle->desc.K * 3);
777 } else {
778 layout->dim_size[0] = (unsigned int)handle->desc.C;
779 layout->dim_size[1] = (unsigned int)handle->desc.K;
780 }
781 } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) ) {
782 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
783 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
784 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
785 layout->dim_size[0] = (unsigned int)handle->desc.K;
786 layout->dim_size[1] = (unsigned int)(handle->desc.K * 4);
787 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
788 layout->dim_size[0] = (unsigned int)handle->desc.K;
789 layout->dim_size[1] = (unsigned int)(handle->desc.K * 3);
790 } else {
791 layout->dim_size[0] = (unsigned int)handle->desc.K;
792 layout->dim_size[1] = (unsigned int)handle->desc.K;
793 }
794 } else {
795 free(layout->dim_type);
796 free(layout->dim_size);
797 free(layout);
798 layout = 0; /* make sure a NULL is returned */
799 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
800 }
801 } else {
802 free(layout);
803 layout = 0; /* make sure a NULL is returned */
804 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
805 }
806 } else {
807 free(layout);
808 layout = 0; /* make sure a NULL is returned */
809 *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
810 }
811 } else {
812 free(layout);
813 layout = 0; /* make sure a NULL is returned */
814 *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
815 }
816 } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_BIAS) || (type == LIBXSMM_DNN_RNN_GRADIENT_BIAS) ) {
817 layout->format = handle->desc.buffer_format;
818 layout->tensor_type = LIBXSMM_DNN_CHANNEL_SCALAR;
819
820
821 if ( ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NC) > 0) || ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) ) {
822 if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) {
823 layout->datatype = handle->desc.datatype_in;
824 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(1*sizeof(libxsmm_dnn_tensor_dimtype));
825 layout->dim_size = (unsigned int*) malloc(1*sizeof(unsigned int));
826
827 if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
828 layout->num_dims = 1;
829
830 if ( (type == LIBXSMM_DNN_RNN_REGULAR_BIAS) || (type == LIBXSMM_DNN_RNN_GRADIENT_BIAS) ) {
831 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
832 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
833 layout->dim_size[0] = (unsigned int)(handle->desc.K * 4);
834 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
835 layout->dim_size[0] = (unsigned int)(handle->desc.K * 3);
836 } else {
837 layout->dim_size[0] = (unsigned int)handle->desc.K;
838 }
839 } else { /* coverity[dead_error_begin] */
840 free(layout->dim_type);
841 free(layout->dim_size);
842 free(layout);
843 layout = 0; /* make sure a NULL is returned */
844 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
845 }
846 } else {
847 free(layout);
848 layout = 0; /* make sure a NULL is returned */
849 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
850 }
851 } else {
852 free(layout);
853 layout = 0; /* make sure a NULL is returned */
854 *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
855 }
856 } else {
857 free(layout);
858 layout = 0; /* make sure a NULL is returned */
859 *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
860 }
861 } else {
862 free(layout);
863 layout = 0; /* make sure a NULL is returned */
864 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
865 }
866 } else {
867 *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT;
868 }
869 } else {
870 *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
871 }
872 return layout;
873 }
874
875
libxsmm_dnn_rnncell_get_scratch_size(const libxsmm_dnn_rnncell * handle,const libxsmm_dnn_compute_kind kind,libxsmm_dnn_err_t * status)876 LIBXSMM_API size_t libxsmm_dnn_rnncell_get_scratch_size(const libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, libxsmm_dnn_err_t* status)
877 {
878 size_t size = 0;
879 *status = LIBXSMM_DNN_SUCCESS;
880
881 if (0 != handle) {
882 const size_t typesize_in = libxsmm_dnn_typesize(handle->desc.datatype_in);
883 const size_t dwdr_typesize = (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ? sizeof(float) : typesize_in;
884
885 switch (handle->desc.cell_type) {
886 case LIBXSMM_DNN_RNNCELL_RNN_RELU:
887 case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID:
888 case LIBXSMM_DNN_RNNCELL_RNN_TANH: {
889 switch (kind) {
890 case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
891 size += 0;
892 } break;
893 case LIBXSMM_DNN_COMPUTE_KIND_BWD:
894 case LIBXSMM_DNN_COMPUTE_KIND_UPD:
895 case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
896 case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
897 size += (size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in + 64; /* wT */
898 size += (size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in + 64; /* rT */
899 size += (size_t)handle->desc.C * (size_t)handle->desc.N * typesize_in + 64; /* xT */
900 size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* hT */
901 size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) * (size_t)handle->desc.max_T + 64; /* deltat */
902
903 } break;
904 default: {
905 *status = LIBXSMM_DNN_ERR_INVALID_KIND;
906 }
907 }
908 } break;
909 case LIBXSMM_DNN_RNNCELL_LSTM: {
910 switch (kind) {
911 case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
912 size += (size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in * 4 + 4 * 64; /* w */
913 size += (size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in * 4 + 4 * 64; /* r */
914 /* The scratches below are needed only for BF16 code for the intermediate results */
915 if (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) {
916 size += (size_t)7 *((size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64); /* intermediate scratches */
917 size += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) + 64; /* intermediate scratches */
918 }
919 } break;
920 case LIBXSMM_DNN_COMPUTE_KIND_BWD:
921 case LIBXSMM_DNN_COMPUTE_KIND_UPD:
922 case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
923 case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
924 size += (size_t)handle->desc.C * (size_t)handle->desc.K * dwdr_typesize * 4 + 4 * 64; /* w */
925 size += (size_t)handle->desc.K * (size_t)handle->desc.K * dwdr_typesize * 4 + 4 * 64; /* r */
926 size += (size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in * 4 + 4 * 64; /* wT */
927 size += (size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in * 4 + 4 * 64; /* rT */
928 size += (size_t)handle->desc.C * (size_t)handle->desc.N * typesize_in + 64; /* xT */
929 size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* hT */
930 size += (size_t)handle->desc.K * (size_t)handle->desc.N * dwdr_typesize + 64; /* deltat */
931 size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* di */
932 size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* df */
933 size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* do */
934 size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* dci */
935 size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* diB */
936 size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* dfB */
937 size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* dpB */
938 size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* dciB */
939 size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* t1 */
940 size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* t2 */
941 /* The scratches below are needed only for BF16 code for the intermediate results */
942 if (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) {
943 size += (size_t)4 *((size_t)handle->desc.K * sizeof(float) + 64); /* intermediate db scratch */
944 size += (size_t)handle->desc.C * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64; /* intermediate dx scratches */
945 size += (size_t)7 *((size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64); /* intermediate scratches */
946 size += (size_t)2 *((size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) + 64); /* intermediate scratches */
947 }
948 } break;
949 default: {
950 *status = LIBXSMM_DNN_ERR_INVALID_KIND;
951 }
952 }
953 } break;
954 case LIBXSMM_DNN_RNNCELL_GRU: {
955 switch (kind) {
956 case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
957 size += (size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in * 3 + 3 * 64; /* w */
958 size += (size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in * 3 + 3 * 64; /* r */
959 } break;
960 case LIBXSMM_DNN_COMPUTE_KIND_BWD:
961 case LIBXSMM_DNN_COMPUTE_KIND_UPD:
962 case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
963 case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
964 size += (size_t)handle->desc.C * (size_t)handle->desc.K * dwdr_typesize * 3 + 3 * 64; /* w */
965 size += (size_t)handle->desc.K * (size_t)handle->desc.K * dwdr_typesize * 3 + 3 * 64; /* r */
966 size += (size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in * 3 + 3 * 64; /* wT */
967 size += (size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in * 3 + 3 * 64; /* rT */
968 size += (size_t)handle->desc.C * (size_t)handle->desc.N * typesize_in + 64; /* xT */
969 size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* hT */
970 size += (size_t)handle->desc.K * (size_t)handle->desc.N * dwdr_typesize + 64; /* deltat */
971 size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* di */
972 size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* dc */
973 size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* df */
974 size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* do */
975 size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* diB */
976 size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* dcB */
977 size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* dfB */
978 size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* oT */
979 size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* t1 */
980 size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* t2 */
981 } break;
982 default: {
983 *status = LIBXSMM_DNN_ERR_INVALID_KIND;
984 }
985 }
986 } break;
987 default: {
988 *status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE;
989 }
990 }
991 } else {
992 *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
993 }
994
995 return size;
996 }
997
998
libxsmm_dnn_rnncell_get_scratch_ptr(const libxsmm_dnn_rnncell * handle,libxsmm_dnn_err_t * status)999 LIBXSMM_API void* libxsmm_dnn_rnncell_get_scratch_ptr(const libxsmm_dnn_rnncell* handle, libxsmm_dnn_err_t* status)
1000 {
1001 *status = LIBXSMM_DNN_SUCCESS;
1002
1003 if (0 != handle) {
1004 return handle->scratch_base;
1005 } else {
1006 *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1007 }
1008
1009 return NULL;
1010 }
1011
1012
libxsmm_dnn_rnncell_bind_scratch(libxsmm_dnn_rnncell * handle,const libxsmm_dnn_compute_kind kind,const void * scratch)1013 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_bind_scratch(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, const void* scratch)
1014 {
1015 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
1016
1017 if (NULL != handle) {
1018 const size_t typesize_in = libxsmm_dnn_typesize(handle->desc.datatype_in);
1019 const size_t dwdr_typesize = (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ? sizeof(float) : typesize_in;
1020 uintptr_t address = (uintptr_t)scratch;
1021 size_t offset = 0;
1022
1023 switch (handle->desc.cell_type) {
1024 case LIBXSMM_DNN_RNNCELL_RNN_RELU:
1025 case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID:
1026 case LIBXSMM_DNN_RNNCELL_RNN_TANH: {
1027 switch (kind) {
1028 case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1029 /* forward only has no scratch need */
1030 } break;
1031 case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1032 case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1033 case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1034 case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1035 if (scratch == 0) {
1036 status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
1037 return status;
1038 }
1039 handle->scratch_base = (void*)address;
1040 /* wT */
1041 if (address % 64 == 0) {
1042 handle->scratch_wT = (void*)address;
1043 } else {
1044 offset = (64 - address % 64);
1045 handle->scratch_wT = (void*)(address+offset);
1046 }
1047 address += ((size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in) + 64;
1048 /* rT */
1049 if (address % 64 == 0) {
1050 handle->scratch_rT = (void*)address;
1051 } else {
1052 offset = (64 - address % 64);
1053 handle->scratch_rT = (void*)(address+offset);
1054 }
1055 address += ((size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in) + 64;
1056 /* xT */
1057 if (address % 64 == 0) {
1058 handle->scratch_xT = (void*)address;
1059 } else {
1060 offset = (64 - address % 64);
1061 handle->scratch_xT = (void*)(address+offset);
1062 }
1063 address += ((size_t)handle->desc.C * (size_t)handle->desc.N * typesize_in) + 64;
1064 /* hT */
1065 if (address % 64 == 0) {
1066 handle->scratch_hT = (void*)address;
1067 } else {
1068 offset = (64 - address % 64);
1069 handle->scratch_hT = (void*)(address+offset);
1070 }
1071 address += ((size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out)) + 64;
1072 /* deltat */
1073 if (address % 64 == 0) {
1074 handle->scratch_deltat = (void*)address;
1075 } else {
1076 offset = (64 - address % 64);
1077 handle->scratch_deltat = (void*)(address+offset);
1078 }
1079 address += ((size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) * (size_t)handle->desc.max_T) + 64;
1080 } break;
1081 default: {
1082 status = LIBXSMM_DNN_ERR_INVALID_KIND;
1083 }
1084 }
1085 } break;
1086 case LIBXSMM_DNN_RNNCELL_LSTM: {
1087 switch (kind) {
1088 case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1089 if (scratch == 0) {
1090 status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
1091 return status;
1092 }
1093 handle->scratch_base = (void*)address;
1094 /* w scratch */
1095 if (address % 64 == 0) {
1096 handle->scratch_w = (void*)address;
1097 } else {
1098 offset = (64 - address % 64);
1099 handle->scratch_w = (void*)(address+offset);
1100 }
1101 address += ((size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in) * 4 + 64;
1102 /* r scratch */
1103 if (address % 64 == 0) {
1104 handle->scratch_r = (void*)address;
1105 } else {
1106 offset = (64 - address % 64);
1107 handle->scratch_r = (void*)(address+offset);
1108 }
1109 address += ((size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in) * 4 + 64;
1110 /* The scratches below are needed only for BF16 code for the intermediate results */
1111 if (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) {
1112 /* cst scratch */
1113 if (address % 64 == 0) {
1114 handle->cst_scratch = (void*)address;
1115 } else {
1116 offset = (64 - address % 64);
1117 handle->cst_scratch = (void*)(address+offset);
1118 }
1119 address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1120 /* ht scratch */
1121 if (address % 64 == 0) {
1122 handle->ht_scratch = (void*)address;
1123 } else {
1124 offset = (64 - address % 64);
1125 handle->ht_scratch = (void*)(address+offset);
1126 }
1127 address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1128 /* it scratch */
1129 if (address % 64 == 0) {
1130 handle->it_scratch = (void*)address;
1131 } else {
1132 offset = (64 - address % 64);
1133 handle->it_scratch = (void*)(address+offset);
1134 }
1135 address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1136 /* ft scratch */
1137 if (address % 64 == 0) {
1138 handle->ft_scratch = (void*)address;
1139 } else {
1140 offset = (64 - address % 64);
1141 handle->ft_scratch = (void*)(address+offset);
1142 }
1143 address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1144 /* ot scratch */
1145 if (address % 64 == 0) {
1146 handle->ot_scratch = (void*)address;
1147 } else {
1148 offset = (64 - address % 64);
1149 handle->ot_scratch = (void*)(address+offset);
1150 }
1151 address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1152 /* cit scratch */
1153 if (address % 64 == 0) {
1154 handle->cit_scratch = (void*)address;
1155 } else {
1156 offset = (64 - address % 64);
1157 handle->cit_scratch = (void*)(address+offset);
1158 }
1159 address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1160 /* cot scratch */
1161 if (address % 64 == 0) {
1162 handle->cot_scratch = (void*)address;
1163 } else {
1164 offset = (64 - address % 64);
1165 handle->cot_scratch = (void*)(address+offset);
1166 }
1167 address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1168 /* csp scratch */
1169 if (address % 64 == 0) {
1170 handle->csp_scratch = (void*)address;
1171 } else {
1172 offset = (64 - address % 64);
1173 handle->csp_scratch = (void*)(address+offset);
1174 }
1175 address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) + 64;
1176 }
1177 } break;
1178 case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1179 case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1180 case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1181 case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1182 if (scratch == 0) {
1183 status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
1184 return status;
1185 }
1186 handle->scratch_base = (void*)address;
1187 /* w scratch */
1188 if (address % 64 == 0) {
1189 handle->scratch_w = (void*)address;
1190 } else {
1191 offset = (64 - address % 64);
1192 handle->scratch_w = (void*)(address+offset);
1193 }
1194 address += ((size_t)handle->desc.C * (size_t)handle->desc.K * dwdr_typesize) * 4 + 64;
1195 /* r scratch */
1196 if (address % 64 == 0) {
1197 handle->scratch_r = (void*)address;
1198 } else {
1199 offset = (64 - address % 64);
1200 handle->scratch_r = (void*)(address+offset);
1201 }
1202 address += ((size_t)handle->desc.K * (size_t)handle->desc.K * dwdr_typesize) * 4 + 64;
1203 /* wT */
1204 if (address % 64 == 0) {
1205 handle->scratch_wT = (void*)address;
1206 } else {
1207 offset = (64 - address % 64);
1208 handle->scratch_wT = (void*)(address+offset);
1209 }
1210 address += ((size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in) * 4 + 64;
1211 /* rT */
1212 if (address % 64 == 0) {
1213 handle->scratch_rT = (void*)address;
1214 } else {
1215 offset = (64 - address % 64);
1216 handle->scratch_rT = (void*)(address+offset);
1217 }
1218 address += ((size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in) * 4 + 64;
1219 /* xT */
1220 if (address % 64 == 0) {
1221 handle->scratch_xT = (void*)address;
1222 } else {
1223 offset = (64 - address % 64);
1224 handle->scratch_xT = (void*)(address+offset);
1225 }
1226 address += (size_t)handle->desc.C * (size_t)handle->desc.N * typesize_in + 64;
1227 /* hT */
1228 if (address % 64 == 0) {
1229 handle->scratch_hT = (void*)address;
1230 } else {
1231 offset = (64 - address % 64);
1232 handle->scratch_hT = (void*)(address+offset);
1233 }
1234 address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1235 /* deltat */
1236 if (address % 64 == 0) {
1237 handle->scratch_deltat = (void*)address;
1238 } else {
1239 offset = (64 - address % 64);
1240 handle->scratch_deltat = (void*)(address+offset);
1241 }
1242 address += (size_t)handle->desc.K * (size_t)handle->desc.N * dwdr_typesize + 64;
1243 /* di */
1244 if (address % 64 == 0) {
1245 handle->scratch_di = (void*)address;
1246 } else {
1247 offset = (64 - address % 64);
1248 handle->scratch_di = (void*)(address+offset);
1249 }
1250 address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1251 /* df */
1252 if (address % 64 == 0) {
1253 handle->scratch_df = (void*)address;
1254 } else {
1255 offset = (64 - address % 64);
1256 handle->scratch_df = (void*)(address+offset);
1257 }
1258 address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1259 /* do */
1260 if (address % 64 == 0) {
1261 handle->scratch_do = (void*)address;
1262 } else {
1263 offset = (64 - address % 64);
1264 handle->scratch_do = (void*)(address+offset);
1265 }
1266 address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1267 /* dci */
1268 if (address % 64 == 0) {
1269 handle->scratch_dci = (void*)address;
1270 } else {
1271 offset = (64 - address % 64);
1272 handle->scratch_dci = (void*)(address+offset);
1273 }
1274 address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1275 /* diB */
1276 if (address % 64 == 0) {
1277 handle->scratch_diB = (void*)address;
1278 } else {
1279 offset = (64 - address % 64);
1280 handle->scratch_diB = (void*)(address+offset);
1281 }
1282 address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1283 /* dfB */
1284 if (address % 64 == 0) {
1285 handle->scratch_dfB = (void*)address;
1286 } else {
1287 offset = (64 - address % 64);
1288 handle->scratch_dfB = (void*)(address+offset);
1289 }
1290 address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1291 /* dpB */
1292 if (address % 64 == 0) {
1293 handle->scratch_dpB = (void*)address;
1294 } else {
1295 offset = (64 - address % 64);
1296 handle->scratch_dpB = (void*)(address+offset);
1297 }
1298 address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1299 /* dciB */
1300 if (address % 64 == 0) {
1301 handle->scratch_dciB = (void*)address;
1302 } else {
1303 offset = (64 - address % 64);
1304 handle->scratch_dciB = (void*)(address+offset);
1305 }
1306 address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1307 /* t1 */
1308 if (address % 64 == 0) {
1309 handle->scratch_t1 = (void*)address;
1310 } else {
1311 offset = (64 - address % 64);
1312 handle->scratch_t1 = (void*)(address+offset);
1313 }
1314 address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1315 /* t2 */
1316 if (address % 64 == 0) {
1317 handle->scratch_t2 = (void*)address;
1318 } else {
1319 offset = (64 - address % 64);
1320 handle->scratch_t2 = (void*)(address+offset);
1321 }
1322 address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1323 /* The scratches below are needed only for BF16 code for the intermediate results */
1324 if (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) {
1325 /* dx scratch */
1326 if (address % 64 == 0) {
1327 handle->scratch_dx = (void*)address;
1328 } else {
1329 offset = (64 - address % 64);
1330 handle->scratch_dx = (void*)(address+offset);
1331 }
1332 address += (size_t)handle->desc.C * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1333 /* dhp scratch */
1334 if (address % 64 == 0) {
1335 handle->scratch_dhp = (void*)address;
1336 } else {
1337 offset = (64 - address % 64);
1338 handle->scratch_dhp = (void*)(address+offset);
1339 }
1340 address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) + 64;
1341 /* db scratch */
1342 if (address % 64 == 0) {
1343 handle->scratch_db = (void*)address;
1344 } else {
1345 offset = (64 - address % 64);
1346 handle->scratch_db = (void*)(address+offset);
1347 }
1348 address += (size_t)handle->desc.K * 4 * sizeof(float) + 64;
1349 /* cst scratch */
1350 if (address % 64 == 0) {
1351 handle->cst_scratch = (void*)address;
1352 } else {
1353 offset = (64 - address % 64);
1354 handle->cst_scratch = (void*)(address+offset);
1355 }
1356 address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1357 /* ht scratch */
1358 if (address % 64 == 0) {
1359 handle->ht_scratch = (void*)address;
1360 } else {
1361 offset = (64 - address % 64);
1362 handle->ht_scratch = (void*)(address+offset);
1363 }
1364 address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1365 /* it scratch */
1366 if (address % 64 == 0) {
1367 handle->it_scratch = (void*)address;
1368 } else {
1369 offset = (64 - address % 64);
1370 handle->it_scratch = (void*)(address+offset);
1371 }
1372 address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1373 /* ft scratch */
1374 if (address % 64 == 0) {
1375 handle->ft_scratch = (void*)address;
1376 } else {
1377 offset = (64 - address % 64);
1378 handle->ft_scratch = (void*)(address+offset);
1379 }
1380 address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1381 /* ot scratch */
1382 if (address % 64 == 0) {
1383 handle->ot_scratch = (void*)address;
1384 } else {
1385 offset = (64 - address % 64);
1386 handle->ot_scratch = (void*)(address+offset);
1387 }
1388 address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1389 /* cit scratch */
1390 if (address % 64 == 0) {
1391 handle->cit_scratch = (void*)address;
1392 } else {
1393 offset = (64 - address % 64);
1394 handle->cit_scratch = (void*)(address+offset);
1395 }
1396 address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1397 /* cot scratch */
1398 if (address % 64 == 0) {
1399 handle->cot_scratch = (void*)address;
1400 } else {
1401 offset = (64 - address % 64);
1402 handle->cot_scratch = (void*)(address+offset);
1403 }
1404 address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1405 /* csp scratch */
1406 if (address % 64 == 0) {
1407 handle->csp_scratch = (void*)address;
1408 } else {
1409 offset = (64 - address % 64);
1410 handle->csp_scratch = (void*)(address+offset);
1411 }
1412 address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) + 64;
1413 }
1414 } break;
1415 default: {
1416 status = LIBXSMM_DNN_ERR_INVALID_KIND;
1417 }
1418 }
1419 } break;
1420 case LIBXSMM_DNN_RNNCELL_GRU: {
1421 switch (kind) {
1422 case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1423 if (scratch == 0) {
1424 status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
1425 return status;
1426 }
1427 handle->scratch_base = (void*)address;
1428 /* w scratch */
1429 if (address % 64 == 0) {
1430 handle->scratch_w = (void*)address;
1431 } else {
1432 offset = (64 - address % 64);
1433 handle->scratch_w = (void*)(address+offset);
1434 }
1435 address += ((size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in) * 3 + 64;
1436 /* r scratch */
1437 if (address % 64 == 0) {
1438 handle->scratch_r = (void*)address;
1439 } else {
1440 offset = (64 - address % 64);
1441 handle->scratch_r = (void*)(address+offset);
1442 }
1443 address += ((size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in) * 3 + 64;
1444 } break;
1445 case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1446 case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1447 case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1448 case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1449 if (scratch == 0) {
1450 status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
1451 return status;
1452 }
1453 handle->scratch_base = (void*)address;
1454 /* w scratch */
1455 if (address % 64 == 0) {
1456 handle->scratch_w = (void*)address;
1457 } else {
1458 offset = (64 - address % 64);
1459 handle->scratch_w = (void*)(address+offset);
1460 }
1461 address += ((size_t)handle->desc.C * (size_t)handle->desc.K * dwdr_typesize) * 3 + 64;
1462 /* r scratch */
1463 if (address % 64 == 0) {
1464 handle->scratch_r = (void*)address;
1465 } else {
1466 offset = (64 - address % 64);
1467 handle->scratch_r = (void*)(address+offset);
1468 }
1469 address += ((size_t)handle->desc.K * (size_t)handle->desc.K * dwdr_typesize) * 3 + 64;
1470 /* wT */
1471 if (address % 64 == 0) {
1472 handle->scratch_wT = (void*)address;
1473 } else {
1474 offset = (64 - address % 64);
1475 handle->scratch_wT = (void*)(address+offset);
1476 }
1477 address += ((size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in) * 3 + 64;
1478 /* rT */
1479 if (address % 64 == 0) {
1480 handle->scratch_rT = (void*)address;
1481 } else {
1482 offset = (64 - address % 64);
1483 handle->scratch_rT = (void*)(address+offset);
1484 }
1485 address += ((size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in) * 3 + 64;
1486 /* xT */
1487 if (address % 64 == 0) {
1488 handle->scratch_xT = (void*)address;
1489 } else {
1490 offset = (64 - address % 64);
1491 handle->scratch_xT = (void*)(address+offset);
1492 }
1493 address += (size_t)handle->desc.C * (size_t)handle->desc.N * typesize_in + 64;
1494 /* hT */
1495 if (address % 64 == 0) {
1496 handle->scratch_hT = (void*)address;
1497 } else {
1498 offset = (64 - address % 64);
1499 handle->scratch_hT = (void*)(address+offset);
1500 }
1501 address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1502 /* deltat */
1503 if (address % 64 == 0) {
1504 handle->scratch_deltat = (void*)address;
1505 } else {
1506 offset = (64 - address % 64);
1507 handle->scratch_deltat = (void*)(address+offset);
1508 }
1509 address += (size_t)handle->desc.K * (size_t)handle->desc.N * dwdr_typesize + 64;
1510 /* di */
1511 if (address % 64 == 0) {
1512 handle->scratch_di = (void*)address;
1513 } else {
1514 offset = (64 - address % 64);
1515 handle->scratch_di = (void*)(address+offset);
1516 }
1517 address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1518 /* dc */
1519 if (address % 64 == 0) {
1520 handle->scratch_dci = (void*)address;
1521 } else {
1522 offset = (64 - address % 64);
1523 handle->scratch_dci = (void*)(address+offset);
1524 }
1525 address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1526 /* df */
1527 if (address % 64 == 0) {
1528 handle->scratch_df = (void*)address;
1529 } else {
1530 offset = (64 - address % 64);
1531 handle->scratch_df = (void*)(address+offset);
1532 }
1533 address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1534 /* do */
1535 if (address % 64 == 0) {
1536 handle->scratch_do = (void*)address;
1537 } else {
1538 offset = (64 - address % 64);
1539 handle->scratch_do = (void*)(address+offset);
1540 }
1541 address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1542 /* diB */
1543 if (address % 64 == 0) {
1544 handle->scratch_diB = (void*)address;
1545 } else {
1546 offset = (64 - address % 64);
1547 handle->scratch_diB = (void*)(address+offset);
1548 }
1549 address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1550 /* dcB */
1551 if (address % 64 == 0) {
1552 handle->scratch_dciB = (void*)address;
1553 } else {
1554 offset = (64 - address % 64);
1555 handle->scratch_dciB = (void*)(address+offset);
1556 }
1557 address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1558 /* dfB */
1559 if (address % 64 == 0) {
1560 handle->scratch_dfB = (void*)address;
1561 } else {
1562 offset = (64 - address % 64);
1563 handle->scratch_dfB = (void*)(address+offset);
1564 }
1565 address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1566 /* doB (repurposed for oT) */
1567 if (address % 64 == 0) {
1568 handle->scratch_dpB = (void*)address;
1569 } else {
1570 offset = (64 - address % 64);
1571 handle->scratch_dpB = (void*)(address+offset);
1572 }
1573 address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1574 /* t1 */
1575 if (address % 64 == 0) {
1576 handle->scratch_t1 = (void*)address;
1577 } else {
1578 offset = (64 - address % 64);
1579 handle->scratch_t1 = (void*)(address+offset);
1580 }
1581 address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1582 /* t2 */
1583 if (address % 64 == 0) {
1584 handle->scratch_t2 = (void*)address;
1585 } else {
1586 offset = (64 - address % 64);
1587 handle->scratch_t2 = (void*)(address+offset);
1588 }
1589 address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1590 } break;
1591 default: {
1592 status = LIBXSMM_DNN_ERR_INVALID_KIND;
1593 }
1594 }
1595 } break;
1596 default: {
1597 status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE;
1598 }
1599 }
1600 } else {
1601 status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1602 }
1603
1604 return status;
1605 }
1606
1607
libxsmm_dnn_rnncell_release_scratch(libxsmm_dnn_rnncell * handle,const libxsmm_dnn_compute_kind kind)1608 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_release_scratch(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind)
1609 {
1610 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
1611
1612 if (0 != handle) {
1613 switch (handle->desc.cell_type) {
1614 case LIBXSMM_DNN_RNNCELL_RNN_RELU:
1615 case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID:
1616 case LIBXSMM_DNN_RNNCELL_RNN_TANH: {
1617 switch (kind) {
1618 case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1619 /* forward only has no scratch need */
1620 } break;
1621 case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1622 case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1623 case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1624 case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1625 handle->scratch_wT = 0;
1626 handle->scratch_rT = 0;
1627 handle->scratch_xT = 0;
1628 handle->scratch_hT = 0;
1629 handle->scratch_deltat = 0;
1630 } break;
1631 default: {
1632 status = LIBXSMM_DNN_ERR_INVALID_KIND;
1633 }
1634 }
1635 } break;
1636 case LIBXSMM_DNN_RNNCELL_LSTM: {
1637 switch (kind) {
1638 case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1639 handle->scratch_w = 0;
1640 handle->scratch_r = 0;
1641 handle->csp_scratch = 0;
1642 handle->cst_scratch = 0;
1643 handle->ht_scratch = 0;
1644 handle->it_scratch = 0;
1645 handle->ft_scratch = 0;
1646 handle->ot_scratch = 0;
1647 handle->cit_scratch = 0;
1648 handle->cot_scratch = 0;
1649 } break;
1650 case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1651 case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1652 case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1653 case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1654 handle->scratch_w = 0;
1655 handle->scratch_r = 0;
1656 handle->scratch_wT = 0;
1657 handle->scratch_rT = 0;
1658 handle->scratch_xT = 0;
1659 handle->scratch_hT = 0;
1660 handle->scratch_deltat = 0;
1661 handle->scratch_di = 0;
1662 handle->scratch_df = 0;
1663 handle->scratch_do = 0;
1664 handle->scratch_dci = 0;
1665 handle->scratch_diB = 0;
1666 handle->scratch_dfB = 0;
1667 handle->scratch_dpB = 0;
1668 handle->scratch_dciB = 0;
1669 handle->scratch_t1 = 0;
1670 handle->scratch_t2 = 0;
1671 handle->csp_scratch = 0;
1672 handle->cst_scratch = 0;
1673 handle->ht_scratch = 0;
1674 handle->it_scratch = 0;
1675 handle->ft_scratch = 0;
1676 handle->ot_scratch = 0;
1677 handle->cit_scratch = 0;
1678 handle->cot_scratch = 0;
1679 } break;
1680 default: {
1681 status = LIBXSMM_DNN_ERR_INVALID_KIND;
1682 }
1683 }
1684 } break;
1685 case LIBXSMM_DNN_RNNCELL_GRU: {
1686 switch (kind) {
1687 case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1688 handle->scratch_w = 0;
1689 handle->scratch_r = 0;
1690 handle->ht_scratch = 0;
1691 handle->it_scratch = 0;
1692 handle->cit_scratch = 0;
1693 handle->ft_scratch = 0;
1694 handle->ot_scratch = 0;
1695 } break;
1696 case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1697 case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1698 case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1699 case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1700 handle->scratch_w = 0;
1701 handle->scratch_r = 0;
1702 handle->scratch_wT = 0;
1703 handle->scratch_rT = 0;
1704 handle->scratch_xT = 0;
1705 handle->scratch_hT = 0;
1706 handle->scratch_deltat = 0;
1707 handle->scratch_di = 0;
1708 handle->scratch_dci = 0;
1709 handle->scratch_df = 0;
1710 handle->scratch_do = 0;
1711 handle->scratch_diB = 0;
1712 handle->scratch_dciB = 0;
1713 handle->scratch_dfB = 0;
1714 handle->scratch_dpB = 0;
1715 handle->scratch_t1 = 0;
1716 handle->scratch_t2 = 0;
1717 handle->ht_scratch = 0;
1718 handle->it_scratch = 0;
1719 handle->ft_scratch = 0;
1720 handle->ot_scratch = 0;
1721 handle->cit_scratch = 0;
1722 } break;
1723 default: {
1724 status = LIBXSMM_DNN_ERR_INVALID_KIND;
1725 }
1726 }
1727 } break;
1728 default: {
1729 status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE;
1730 }
1731 }
1732 } else {
1733 status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1734 }
1735
1736 return status;
1737 }
1738
1739
libxsmm_dnn_rnncell_get_internalstate_size(const libxsmm_dnn_rnncell * handle,const libxsmm_dnn_compute_kind kind,libxsmm_dnn_err_t * status)1740 LIBXSMM_API size_t libxsmm_dnn_rnncell_get_internalstate_size(const libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, libxsmm_dnn_err_t* status)
1741 {
1742 size_t size = 0;
1743 *status = LIBXSMM_DNN_SUCCESS;
1744
1745 if (0 != handle) {
1746 const size_t sizeof_datatype = sizeof(float);
1747
1748 switch (handle->desc.cell_type) {
1749 case LIBXSMM_DNN_RNNCELL_RNN_RELU:
1750 case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID:
1751 case LIBXSMM_DNN_RNNCELL_RNN_TANH: {
1752 switch (kind) {
1753 case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1754 size += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof_datatype * (size_t)handle->desc.max_T + 64; /* zt */
1755 } break;
1756 case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1757 case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1758 case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1759 case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1760 size += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof_datatype * (size_t)handle->desc.max_T + 64; /* zt */
1761 } break;
1762 default: {
1763 *status = LIBXSMM_DNN_ERR_INVALID_KIND;
1764 }
1765 }
1766 } break;
1767 case LIBXSMM_DNN_RNNCELL_LSTM: {
1768 switch (kind) {
1769 case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1770 /* with i, f, o, ci, co, cs exposed as i/o, there is currently no need for internal state */
1771 } break;
1772 case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1773 case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1774 case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1775 case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1776 /* with i, f, o, ci, co, cs exposed as i/o, there is currently no need for internal state */
1777 } break;
1778 default: {
1779 *status = LIBXSMM_DNN_ERR_INVALID_KIND;
1780 }
1781 }
1782 } break;
1783 case LIBXSMM_DNN_RNNCELL_GRU: {
1784 switch (kind) {
1785 case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1786 /* with i, f, c, o exposed as i/o, there is currently no need for internal state */
1787 } break;
1788 case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1789 case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1790 case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1791 case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1792 /* with i, f, c, o exposed as i/o, there is currently no need for internal state */
1793 } break;
1794 default: {
1795 *status = LIBXSMM_DNN_ERR_INVALID_KIND;
1796 }
1797 }
1798 } break;
1799 default: {
1800 *status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE;
1801 }
1802 }
1803 } else {
1804 *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1805 }
1806
1807 return size;
1808 }
1809
1810
libxsmm_dnn_rnncell_get_internalstate_ptr(const libxsmm_dnn_rnncell * handle,libxsmm_dnn_err_t * status)1811 LIBXSMM_API void* libxsmm_dnn_rnncell_get_internalstate_ptr(const libxsmm_dnn_rnncell* handle, libxsmm_dnn_err_t* status)
1812 {
1813 *status = LIBXSMM_DNN_SUCCESS;
1814
1815 if (0 != handle) {
1816 return handle->internal_z;
1817 } else {
1818 *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1819 }
1820
1821 return NULL;
1822 }
1823
1824
libxsmm_dnn_rnncell_bind_internalstate(libxsmm_dnn_rnncell * handle,const libxsmm_dnn_compute_kind kind,const void * internalstate)1825 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_bind_internalstate(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, const void* internalstate)
1826 {
1827 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
1828 uintptr_t address = (uintptr_t)internalstate;
1829 size_t offset = 0;
1830
1831 if (0 != handle) {
1832 switch (handle->desc.cell_type) {
1833 case LIBXSMM_DNN_RNNCELL_RNN_RELU:
1834 case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID:
1835 case LIBXSMM_DNN_RNNCELL_RNN_TANH: {
1836 if (internalstate == 0) {
1837 status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
1838 return status;
1839 }
1840 switch (kind) {
1841 case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1842 if (address % 64 == 0) {
1843 handle->internal_z = (void*)address;
1844 } else {
1845 offset = (64 - address % 64);
1846 handle->internal_z = (void*)(address+offset);
1847 }
1848 } break;
1849 case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1850 case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1851 case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1852 case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1853 if (address % 64 == 0) {
1854 handle->internal_z = (void*)address;
1855 } else {
1856 offset = (64 - address % 64);
1857 handle->internal_z = (void*)(address+offset);
1858 }
1859 } break;
1860 default: {
1861 status = LIBXSMM_DNN_ERR_INVALID_KIND;
1862 }
1863 }
1864 } break;
1865 case LIBXSMM_DNN_RNNCELL_LSTM: {
1866 switch (kind) {
1867 case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1868 } break;
1869 case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1870 case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1871 case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1872 case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1873 } break;
1874 default: {
1875 status = LIBXSMM_DNN_ERR_INVALID_KIND;
1876 }
1877 }
1878 } break;
1879 case LIBXSMM_DNN_RNNCELL_GRU: {
1880 switch (kind) {
1881 case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1882 } break;
1883 case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1884 case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1885 case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1886 case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1887 } break;
1888 default: {
1889 status = LIBXSMM_DNN_ERR_INVALID_KIND;
1890 }
1891 }
1892 } break;
1893 default: {
1894 status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE;
1895 }
1896 }
1897 } else {
1898 status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1899 }
1900
1901 return status;
1902 }
1903
1904
libxsmm_dnn_rnncell_release_internalstate(libxsmm_dnn_rnncell * handle,const libxsmm_dnn_compute_kind kind)1905 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_release_internalstate(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind)
1906 {
1907 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
1908
1909 if (0 != handle) {
1910 switch (handle->desc.cell_type) {
1911 case LIBXSMM_DNN_RNNCELL_RNN_RELU:
1912 case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID:
1913 case LIBXSMM_DNN_RNNCELL_RNN_TANH: {
1914 switch (kind) {
1915 case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1916 handle->internal_z = 0;
1917 } break;
1918 case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1919 case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1920 case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1921 case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1922 handle->internal_z = 0;
1923 } break;
1924 default: {
1925 status = LIBXSMM_DNN_ERR_INVALID_KIND;
1926 }
1927 }
1928 } break;
1929 case LIBXSMM_DNN_RNNCELL_LSTM: {
1930 switch (kind) {
1931 case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1932 } break;
1933 case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1934 case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1935 case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1936 case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1937 } break;
1938 default: {
1939 status = LIBXSMM_DNN_ERR_INVALID_KIND;
1940 }
1941 }
1942 } break;
1943 case LIBXSMM_DNN_RNNCELL_GRU: {
1944 switch (kind) {
1945 case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1946 } break;
1947 case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1948 case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1949 case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1950 case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1951 } break;
1952 default: {
1953 status = LIBXSMM_DNN_ERR_INVALID_KIND;
1954 }
1955 }
1956 } break;
1957 default: {
1958 status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE;
1959 }
1960 }
1961 } else {
1962 status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1963 }
1964
1965 return status;
1966 }
1967
1968
libxsmm_dnn_rnncell_allocate_forget_bias(libxsmm_dnn_rnncell * handle,const float forget_bias)1969 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_allocate_forget_bias(libxsmm_dnn_rnncell* handle, const float forget_bias)
1970 {
1971 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
1972
1973 if (handle != 0) {
1974 handle->forget_bias = forget_bias;
1975 } else {
1976 status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
1977 }
1978
1979 return status;
1980 }
1981
1982
libxsmm_dnn_rnncell_bind_tensor(libxsmm_dnn_rnncell * handle,const libxsmm_dnn_tensor * tensor,const libxsmm_dnn_tensor_type type)1983 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_bind_tensor(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type)
1984 {
1985 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
1986
1987 /* check for tensor type */
1988 if ( (type != LIBXSMM_DNN_RNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_RNN_GRADIENT_INPUT) &&
1989 (type != LIBXSMM_DNN_RNN_REGULAR_CS_PREV) && (type != LIBXSMM_DNN_RNN_GRADIENT_CS_PREV) &&
1990 (type != LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV) && (type != LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV) &&
1991 (type != LIBXSMM_DNN_RNN_REGULAR_WEIGHT) && (type != LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) &&
1992 (type != LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) && (type != LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) &&
1993 (type != LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) && (type != LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) &&
1994 (type != LIBXSMM_DNN_RNN_REGULAR_BIAS) && (type != LIBXSMM_DNN_RNN_GRADIENT_BIAS) &&
1995 (type != LIBXSMM_DNN_RNN_REGULAR_CS) && (type != LIBXSMM_DNN_RNN_GRADIENT_CS) &&
1996 (type != LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE) && (type != LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE) &&
1997 (type != LIBXSMM_DNN_RNN_INTERNAL_I) && (type != LIBXSMM_DNN_RNN_INTERNAL_F) &&
1998 (type != LIBXSMM_DNN_RNN_INTERNAL_O) && (type != LIBXSMM_DNN_RNN_INTERNAL_CI) &&
1999 (type != LIBXSMM_DNN_RNN_INTERNAL_CO) ) {
2000 status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
2001 return status;
2002 }
2003
2004 if (handle != 0 && tensor != 0) {
2005 libxsmm_dnn_tensor_datalayout* handle_layout = libxsmm_dnn_rnncell_create_tensor_datalayout(handle, type, &status);
2006
2007 if ( libxsmm_dnn_compare_tensor_datalayout(handle_layout, tensor->layout, &status) == 0 ) {
2008 if ( type == LIBXSMM_DNN_RNN_REGULAR_INPUT ) {
2009 handle->xt = (libxsmm_dnn_tensor*)tensor;
2010 } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_INPUT ) {
2011 handle->dxt = (libxsmm_dnn_tensor*)tensor;
2012 } else if ( type == LIBXSMM_DNN_RNN_REGULAR_CS_PREV ) {
2013 handle->csp = (libxsmm_dnn_tensor*)tensor;
2014 } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_CS_PREV ) {
2015 handle->dcsp = (libxsmm_dnn_tensor*)tensor;
2016 } else if ( type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV ) {
2017 handle->hp = (libxsmm_dnn_tensor*)tensor;
2018 } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV ) {
2019 handle->dhp = (libxsmm_dnn_tensor*)tensor;
2020 } else if ( type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT ) {
2021 handle->w = (libxsmm_dnn_tensor*)tensor;
2022 } else if ( type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS ) {
2023 handle->wt = (libxsmm_dnn_tensor*)tensor;
2024 } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT ) {
2025 handle->dw = (libxsmm_dnn_tensor*)tensor;
2026 } else if ( type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT ) {
2027 handle->r = (libxsmm_dnn_tensor*)tensor;
2028 } else if ( type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS ) {
2029 handle->rt = (libxsmm_dnn_tensor*)tensor;
2030 } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT ) {
2031 handle->dr = (libxsmm_dnn_tensor*)tensor;
2032 } else if ( type == LIBXSMM_DNN_RNN_REGULAR_BIAS ) {
2033 handle->b = (libxsmm_dnn_tensor*)tensor;
2034 } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_BIAS ) {
2035 handle->db = (libxsmm_dnn_tensor*)tensor;
2036 } else if ( type == LIBXSMM_DNN_RNN_REGULAR_CS ) {
2037 handle->cst = (libxsmm_dnn_tensor*)tensor;
2038 } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_CS ) {
2039 handle->dcs = (libxsmm_dnn_tensor*)tensor;
2040 } else if ( type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE ) {
2041 handle->ht = (libxsmm_dnn_tensor*)tensor;
2042 } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE ) {
2043 handle->dht = (libxsmm_dnn_tensor*)tensor;
2044 } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_I ) {
2045 handle->it = (libxsmm_dnn_tensor*)tensor;
2046 } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_F ) {
2047 handle->ft = (libxsmm_dnn_tensor*)tensor;
2048 } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_O ) {
2049 handle->ot = (libxsmm_dnn_tensor*)tensor;
2050 } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_CI ) {
2051 handle->cit = (libxsmm_dnn_tensor*)tensor;
2052 } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_CO ) {
2053 handle->cot = (libxsmm_dnn_tensor*)tensor;
2054 } else {
2055 /* cannot happen */
2056 }
2057 } else {
2058 status = LIBXSMM_DNN_ERR_MISMATCH_TENSOR;
2059 }
2060
2061 libxsmm_dnn_destroy_tensor_datalayout( handle_layout );
2062 }
2063 else {
2064 status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
2065 }
2066
2067 return status;
2068 }
2069
2070
libxsmm_dnn_rnncell_get_tensor(libxsmm_dnn_rnncell * handle,const libxsmm_dnn_tensor_type type,libxsmm_dnn_err_t * status)2071 LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_rnncell_get_tensor(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status)
2072 {
2073 libxsmm_dnn_tensor* tensor = 0;
2074 LIBXSMM_UNUSED(status/*TODO*/);
2075
2076 /* check for tensor type */
2077 if ( (type != LIBXSMM_DNN_RNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_RNN_GRADIENT_INPUT) &&
2078 (type != LIBXSMM_DNN_RNN_REGULAR_CS_PREV) && (type != LIBXSMM_DNN_RNN_GRADIENT_CS_PREV) &&
2079 (type != LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV) && (type != LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV) &&
2080 (type != LIBXSMM_DNN_RNN_REGULAR_WEIGHT) && (type != LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) &&
2081 (type != LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) && (type != LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) &&
2082 (type != LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) && (type != LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) &&
2083 (type != LIBXSMM_DNN_RNN_REGULAR_BIAS) && (type != LIBXSMM_DNN_RNN_GRADIENT_BIAS) &&
2084 (type != LIBXSMM_DNN_RNN_REGULAR_CS) && (type != LIBXSMM_DNN_RNN_GRADIENT_CS) &&
2085 (type != LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE) && (type != LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE) &&
2086 (type != LIBXSMM_DNN_RNN_INTERNAL_I) && (type != LIBXSMM_DNN_RNN_INTERNAL_F) &&
2087 (type != LIBXSMM_DNN_RNN_INTERNAL_O) && (type != LIBXSMM_DNN_RNN_INTERNAL_CI) &&
2088 (type != LIBXSMM_DNN_RNN_INTERNAL_CO) ) {
2089 return tensor;
2090 }
2091
2092 if (handle != 0) {
2093 if ( type == LIBXSMM_DNN_RNN_REGULAR_INPUT ) {
2094 tensor = handle->xt;
2095 } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_INPUT ) {
2096 tensor = handle->dxt;
2097 } else if ( type == LIBXSMM_DNN_RNN_REGULAR_CS_PREV ) {
2098 tensor = handle->csp;
2099 } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_CS_PREV ) {
2100 tensor = handle->dcsp;
2101 } else if ( type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV ) {
2102 tensor = handle->hp;
2103 } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV ) {
2104 tensor = handle->dhp;
2105 } else if ( type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT ) {
2106 tensor = handle->w;
2107 } else if ( type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS ) {
2108 tensor = handle->wt;
2109 } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT ) {
2110 tensor = handle->dw;
2111 } else if ( type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT ) {
2112 tensor = handle->r;
2113 } else if ( type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS ) {
2114 tensor = handle->rt;
2115 } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT ) {
2116 tensor = handle->dr;
2117 } else if ( type == LIBXSMM_DNN_RNN_REGULAR_BIAS ) {
2118 tensor = handle->b;
2119 } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_BIAS ) {
2120 tensor = handle->db;
2121 } else if ( type == LIBXSMM_DNN_RNN_REGULAR_CS ) {
2122 tensor = handle->cst;
2123 } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_CS ) {
2124 tensor = handle->dcs;
2125 } else if ( type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE ) {
2126 tensor = handle->ht;
2127 } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE ) {
2128 tensor = handle->dht;
2129 } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_I ) {
2130 tensor = handle->it;
2131 } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_F ) {
2132 tensor = handle->ft;
2133 } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_O ) {
2134 tensor = handle->ot;
2135 } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_CI ) {
2136 tensor = handle->cit;
2137 } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_CO ) {
2138 tensor = handle->cot;
2139 } else {
2140 /* cannot happen */
2141 }
2142 }
2143
2144 return tensor;
2145 }
2146
2147
libxsmm_dnn_rnncell_release_tensor(libxsmm_dnn_rnncell * handle,const libxsmm_dnn_tensor_type type)2148 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_release_tensor(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor_type type)
2149 {
2150 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
2151
2152 /* check for tensor type */
2153 if ( (type != LIBXSMM_DNN_RNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_RNN_GRADIENT_INPUT) &&
2154 (type != LIBXSMM_DNN_RNN_REGULAR_CS_PREV) && (type != LIBXSMM_DNN_RNN_GRADIENT_CS_PREV) &&
2155 (type != LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV) && (type != LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV) &&
2156 (type != LIBXSMM_DNN_RNN_REGULAR_WEIGHT) && (type != LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) &&
2157 (type != LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) && (type != LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) &&
2158 (type != LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) && (type != LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) &&
2159 (type != LIBXSMM_DNN_RNN_REGULAR_BIAS) && (type != LIBXSMM_DNN_RNN_GRADIENT_BIAS) &&
2160 (type != LIBXSMM_DNN_RNN_REGULAR_CS) && (type != LIBXSMM_DNN_RNN_GRADIENT_CS) &&
2161 (type != LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE) && (type != LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE) &&
2162 (type != LIBXSMM_DNN_RNN_INTERNAL_I) && (type != LIBXSMM_DNN_RNN_INTERNAL_F) &&
2163 (type != LIBXSMM_DNN_RNN_INTERNAL_O) && (type != LIBXSMM_DNN_RNN_INTERNAL_CI) &&
2164 (type != LIBXSMM_DNN_RNN_INTERNAL_CO) ) {
2165 status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
2166 return status;
2167 }
2168
2169 if (handle != 0) {
2170 if ( type == LIBXSMM_DNN_RNN_REGULAR_INPUT ) {
2171 handle->xt = 0;
2172 } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_INPUT ) {
2173 handle->dxt = 0;
2174 } else if ( type == LIBXSMM_DNN_RNN_REGULAR_CS_PREV ) {
2175 handle->csp = 0;
2176 } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_CS_PREV ) {
2177 handle->dcsp = 0;
2178 } else if ( type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV ) {
2179 handle->hp = 0;
2180 } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV ) {
2181 handle->dhp = 0;
2182 } else if ( type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT ) {
2183 handle->w = 0;
2184 } else if ( type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS ) {
2185 handle->wt = 0;
2186 } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT ) {
2187 handle->dw = 0;
2188 } else if ( type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT ) {
2189 handle->r = 0;
2190 } else if ( type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS ) {
2191 handle->rt = 0;
2192 } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT ) {
2193 handle->dr = 0;
2194 } else if ( type == LIBXSMM_DNN_RNN_REGULAR_BIAS ) {
2195 handle->b = 0;
2196 } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_BIAS ) {
2197 handle->db = 0;
2198 } else if ( type == LIBXSMM_DNN_RNN_REGULAR_CS ) {
2199 handle->cst = 0;
2200 } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_CS ) {
2201 handle->dcs = 0;
2202 } else if ( type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE ) {
2203 handle->ht = 0;
2204 } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE ) {
2205 handle->dht = 0;
2206 } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_I ) {
2207 handle->it = 0;
2208 } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_F ) {
2209 handle->ft = 0;
2210 } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_O ) {
2211 handle->ot = 0;
2212 } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_CI ) {
2213 handle->cit = 0;
2214 } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_CO ) {
2215 handle->cot = 0;
2216 } else {
2217 /* cannot happen */
2218 }
2219 }
2220 else {
2221 status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
2222 }
2223
2224 return status;
2225 }
2226
2227
libxsmm_dnn_rnncell_set_sequence_length(libxsmm_dnn_rnncell * handle,const libxsmm_blasint T)2228 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_set_sequence_length( libxsmm_dnn_rnncell* handle, const libxsmm_blasint T ) {
2229 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
2230
2231 if (0 != handle) {
2232 if ( handle->desc.max_T < T ) {
2233 status = LIBXSMM_DNN_ERR_RNN_INVALID_SEQ_LEN;
2234 } else {
2235 handle->T = T;
2236 }
2237 } else {
2238 status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
2239 }
2240
2241 return status;
2242 }
2243
2244
libxsmm_dnn_rnncell_get_sequence_length(libxsmm_dnn_rnncell * handle,libxsmm_dnn_err_t * status)2245 LIBXSMM_API libxsmm_blasint libxsmm_dnn_rnncell_get_sequence_length( libxsmm_dnn_rnncell* handle, libxsmm_dnn_err_t* status ) {
2246 *status = LIBXSMM_DNN_SUCCESS;
2247
2248 if (0 != handle) {
2249 return handle->T;
2250 } else {
2251 *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
2252 }
2253
2254 return 0;
2255 }
2256
2257
libxsmm_dnn_rnncell_execute_st(libxsmm_dnn_rnncell * handle,libxsmm_dnn_compute_kind kind,int start_thread,int tid)2258 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_execute_st(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind,
2259 /*unsigned*/int start_thread, /*unsigned*/int tid)
2260 {
2261 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
2262
2263 if (0 != handle) {
2264 switch (kind) {
2265 case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
2266 if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NC) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CK) ) {
2267 status = libxsmm_dnn_rnncell_st_fwd_nc_ck( handle, start_thread, tid );
2268 } else if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) ) {
2269 status = libxsmm_dnn_rnncell_st_fwd_ncnc_kcck( handle, start_thread, tid );
2270 } else if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NC) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) ) {
2271 status = libxsmm_dnn_rnncell_st_fwd_nc_kcck( handle, start_thread, tid );
2272 } else {
2273 status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
2274 }
2275 } break;
2276 case LIBXSMM_DNN_COMPUTE_KIND_BWD:
2277 case LIBXSMM_DNN_COMPUTE_KIND_UPD:
2278 case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: {
2279 if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NC) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CK) ) {
2280 status = libxsmm_dnn_rnncell_st_bwdupd_nc_ck( handle, kind, start_thread, tid );
2281 } else if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NC) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) ) {
2282 status = libxsmm_dnn_rnncell_st_bwdupd_nc_kcck( handle, kind, start_thread, tid );
2283 } else {
2284 status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
2285 }
2286 } break;
2287 default: {
2288 status = LIBXSMM_DNN_ERR_INVALID_KIND;
2289 }
2290 }
2291 } else {
2292 status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
2293 }
2294
2295 return status;
2296 }
2297
2298