1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Alexander Heinecke, Evangelos Georganas, Kunal Banerjee (Intel Corp.)
10 ******************************************************************************/
11 #include "libxsmm_dnn_rnncell_forward.h"
12 #include "libxsmm_dnn_rnncell_backward_weight_update.h"
13 #include "libxsmm_dnn_elementwise.h"
14 #include "libxsmm_main.h"
15 
16 #if defined(LIBXSMM_OFFLOAD_TARGET)
17 # pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
18 #endif
19 #include <math.h>
20 #if defined(LIBXSMM_OFFLOAD_TARGET)
21 # pragma offload_attribute(pop)
22 #endif
23 
libxsmm_dnn_create_rnncell(libxsmm_dnn_rnncell_desc rnncell_desc,libxsmm_dnn_err_t * status)24 LIBXSMM_API libxsmm_dnn_rnncell* libxsmm_dnn_create_rnncell(libxsmm_dnn_rnncell_desc rnncell_desc, libxsmm_dnn_err_t* status)
25 {
26   libxsmm_dnn_rnncell* handle = 0;
27 
28   /* init libxsmm */
29   LIBXSMM_INIT
30 
31   /* some check we can do before allocating the handle */
32   if ( (rnncell_desc.datatype_in != rnncell_desc.datatype_out) ||
33        ( (rnncell_desc.datatype_in != LIBXSMM_DNN_DATATYPE_BF16) && (rnncell_desc.datatype_in != LIBXSMM_DNN_DATATYPE_F32) ) ) {
34     *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
35     return NULL;
36   }
37   /* let's do some simple checks for BF16 as this limits the cell and architecture */
38   if ( (rnncell_desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) || (rnncell_desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
39     if ( (LIBXSMM_X86_AVX512_CORE > libxsmm_target_archid) || (rnncell_desc.C % 16 != 0) || (rnncell_desc.K % 16 != 0) ) {
40       *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
41       return NULL;
42     }
43   }
44   /* we need at least one timestep */
45   if (rnncell_desc.max_T < 1) {
46     *status = LIBXSMM_DNN_ERR_TIME_STEPS_TOO_SMALL;
47     return NULL;
48   }
49 
50   handle = (libxsmm_dnn_rnncell*)malloc(sizeof(libxsmm_dnn_rnncell));
51   if (0 != handle) {
52     *status = LIBXSMM_DNN_SUCCESS;
53     /* zero entire content; not only safer but also sets data and code pointers to NULL */
54     memset(handle, 0, sizeof(*handle));
55     /* initialize known handle components */
56     handle->desc = rnncell_desc;
57   /* set current seq length to max length */
58     handle->T = rnncell_desc.max_T;
59     /* set blocking factors */
60     handle->bk = (handle->desc.bk == 0) ? 64 : handle->desc.bk;
61     handle->bn = (handle->desc.bn == 0) ? 64 : handle->desc.bn;
62     handle->bc = (handle->desc.bc == 0) ? 64 : handle->desc.bc;
63     if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
64       handle->lpb = 2;
65     } else {
66       handle->lpb = 1;
67     }
68    /* validate blocking factors */
69     if ( handle->desc.N % handle->bn != 0 ) {
70       handle->bn = handle->desc.N;
71       *status = LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_N_BLOCKING;
72     }
73     if ( handle->desc.C % handle->bc != 0 ) {
74       handle->bc = handle->desc.C;
75       *status = LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_C_BLOCKING;
76     }
77     if ( handle->desc.K % handle->bk != 0 ) {
78       handle->bk = handle->desc.K;
79       *status = LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_K_BLOCKING;
80     }
81 
82      /* In case of BF16 for now hoist the BRGEMM and make them to use STRIDED variant by default */
83     if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
84       const int typesize_in = (int)libxsmm_dnn_typesize(handle->desc.datatype_in);
85       const libxsmm_blasint K =  handle->desc.K;
86       const libxsmm_blasint N =  handle->desc.N;
87       const libxsmm_blasint C =  handle->desc.C;
88       const libxsmm_blasint bk = handle->bk;
89       const libxsmm_blasint bn = handle->bn;
90       const libxsmm_blasint bc = handle->bc;
91       const libxsmm_blasint cBlocks = C/bc;
92       const libxsmm_blasint kBlocks = K/bk;
93       const libxsmm_blasint nBlocks = N/bn;
94       libxsmm_blasint BF, CB_BLOCKS, KB_BLOCKS;
95       libxsmm_blasint stride_a, stride_b;
96       int kernel_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
97 
98       /* Blocking reduction domain if it is too large */
99       BF = 1;
100       if ((C > 1024 && C <= 2048) || (K > 1024 && K <= 2048)) {
101         BF = 8;
102         while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) {
103           BF--;
104         }
105       }
106       if (C > 2048 || K > 2048) {
107         BF = 16;
108         while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) {
109           BF--;
110         }
111       }
112       if (C == 2048 && K == 1024) {
113         BF = 2;
114       }
115       CB_BLOCKS = cBlocks/BF;
116       KB_BLOCKS = kBlocks/BF;
117 
118       /* define batch-reduce gemm kernels */
119       stride_a = bc * bk * typesize_in;
120       stride_b = bc * typesize_in;
121       handle->fwd_kernela = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bn, bc, stride_a, stride_b, CB_BLOCKS, &bk, &C, &K, NULL, NULL, &kernel_flags, NULL );
122       stride_a = bk * bk * typesize_in;
123       stride_b = bk * typesize_in;
124       handle->fwd_kernelb = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bn, bk, stride_a, stride_b, KB_BLOCKS, &bk, &K, &K, NULL, NULL, &kernel_flags, NULL );
125 
126       KB_BLOCKS = kBlocks/BF;
127 
128       stride_a = bc * bk * typesize_in;
129       stride_b = bk * typesize_in;
130       handle->bwdupd_kernela = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bc, bn, bk, stride_a, stride_b, KB_BLOCKS, &bc, &K, &C, NULL, NULL, &kernel_flags, NULL);
131       stride_a = bn * bk * typesize_in;
132       stride_b = bn * typesize_in;
133       handle->bwdupd_kernelb = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bk, bn, stride_a, stride_b, nBlocks, &bk, &N, &bk, NULL, NULL, &kernel_flags, NULL);
134       stride_a = bn * bk * typesize_in;
135       stride_b = bn * typesize_in;
136       handle->bwdupd_kernelc = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bc, bn, stride_a, stride_b, nBlocks, &bk, &N, &bk, NULL, NULL, &kernel_flags, NULL);
137       stride_a = bk * bk * typesize_in;
138       stride_b = bk * typesize_in;
139       handle->bwdupd_kerneld = libxsmm_bsmmdispatch_reducebatch_strd_unroll( bk, bn, bk, stride_a, stride_b, KB_BLOCKS, &bk, &K, &K, NULL, NULL, &kernel_flags, NULL);
140     }
141 
142     /* Need to allocate space for scratch libxsmm_dnn_tensor's, let's set all pointers to zero */
143     handle->internal_z = 0;
144     handle->scratch_wT = 0;
145     handle->scratch_rT = 0;
146     handle->scratch_xT = 0;
147     handle->scratch_hT = 0;
148     handle->scratch_deltat = 0;
149     handle->scratch_di = 0;
150     handle->scratch_df = 0;
151     handle->scratch_do = 0;
152     handle->scratch_dci = 0;
153     handle->scratch_diB = 0;
154     handle->scratch_dfB = 0;
155     handle->scratch_dpB = 0;
156     handle->scratch_dciB = 0;
157     /* initialize a high-performant barrier */
158     handle->barrier = libxsmm_barrier_create(handle->desc.threads, 1);
159     if (NULL == handle->barrier)
160     {
161       *status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
162       free(handle);
163       return NULL;
164     }
165   } else {
166     *status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
167   }
168   return handle;
169 }
170 
171 
libxsmm_dnn_destroy_rnncell(const libxsmm_dnn_rnncell * handle)172 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_rnncell(const libxsmm_dnn_rnncell* handle)
173 {
174   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
175   if (0 != handle) {
176     /* Deallocate barrier */
177     if (handle->barrier != 0 ) { libxsmm_barrier_release((const libxsmm_barrier*)handle->barrier); }
178     /* deallocate handle structure */
179     free(/*remove constness*/(libxsmm_dnn_rnncell*)handle);
180   } else {
181     status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
182   }
183   return status;
184 }
185 
186 
libxsmm_dnn_rnncell_create_tensor_datalayout(const libxsmm_dnn_rnncell * handle,const libxsmm_dnn_tensor_type type,libxsmm_dnn_err_t * status)187 LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_rnncell_create_tensor_datalayout(const libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status)
188 {
189   libxsmm_dnn_tensor_datalayout* layout;
190   *status = LIBXSMM_DNN_SUCCESS;
191   layout = 0;
192   if (handle != 0) {
193     layout = (libxsmm_dnn_tensor_datalayout*) malloc(sizeof(libxsmm_dnn_tensor_datalayout));
194     if (layout != 0) {
195       memset(layout, 0, sizeof(libxsmm_dnn_tensor_datalayout));
196       if ( (type == LIBXSMM_DNN_RNN_REGULAR_INPUT)             || (type == LIBXSMM_DNN_RNN_GRADIENT_INPUT)             ||
197            (type == LIBXSMM_DNN_RNN_REGULAR_CS_PREV)           || (type == LIBXSMM_DNN_RNN_GRADIENT_CS_PREV)           ||
198            (type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV) || (type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV) ||
199            (type == LIBXSMM_DNN_RNN_REGULAR_CS)                || (type == LIBXSMM_DNN_RNN_GRADIENT_CS)                ||
200            (type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE)      || (type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE)      ||
201            (type == LIBXSMM_DNN_RNN_INTERNAL_I)                || (type == LIBXSMM_DNN_RNN_INTERNAL_F)                 ||
202            (type == LIBXSMM_DNN_RNN_INTERNAL_O)                || (type == LIBXSMM_DNN_RNN_INTERNAL_CI)                ||
203            (type == LIBXSMM_DNN_RNN_INTERNAL_CO) ) {
204         layout->format = handle->desc.buffer_format;
205         layout->tensor_type = LIBXSMM_DNN_ACTIVATION;
206         if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) {
207           if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) {
208             layout->datatype = handle->desc.datatype_in;
209             layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
210             layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
211 
212             if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
213               layout->num_dims = 5;
214 
215               if ( (type == LIBXSMM_DNN_RNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_RNN_GRADIENT_INPUT) ) {
216                 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
217                 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
218                 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
219                 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
220                 layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_T;
221                 layout->dim_size[0] = (unsigned int)handle->bc;
222                 layout->dim_size[1] = (unsigned int)handle->bn;
223                 layout->dim_size[2] = (unsigned int)(handle->desc.C / handle->bc);
224                 layout->dim_size[3] = (unsigned int)(handle->desc.N / handle->bn);
225                 layout->dim_size[4] = (unsigned int)handle->desc.max_T;
226               } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_CS_PREV)           || (type == LIBXSMM_DNN_RNN_GRADIENT_CS_PREV)           ||
227                           (type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV) || (type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV) ||
228                           (type == LIBXSMM_DNN_RNN_REGULAR_CS)                || (type == LIBXSMM_DNN_RNN_GRADIENT_CS)                ||
229                           (type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE)      || (type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE)      ||
230                           (type == LIBXSMM_DNN_RNN_INTERNAL_I)                || (type == LIBXSMM_DNN_RNN_INTERNAL_F)                 ||
231                           (type == LIBXSMM_DNN_RNN_INTERNAL_O)                || (type == LIBXSMM_DNN_RNN_INTERNAL_CI)                ||
232                           (type == LIBXSMM_DNN_RNN_INTERNAL_CO) ) {
233                 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
234                 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
235                 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
236                 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
237                 layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_T;
238                 layout->dim_size[0] = (unsigned int)handle->bk;
239                 layout->dim_size[1] = (unsigned int)handle->bn;
240                 layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk);
241                 layout->dim_size[3] = (unsigned int)(handle->desc.N / handle->bn);
242                 layout->dim_size[4] = (unsigned int)handle->desc.max_T;
243               } else {
244                 free(layout->dim_type);
245                 free(layout->dim_size);
246                 free(layout);
247                 layout = 0; /* make sure a NULL is returned */
248                 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
249               }
250             } else {
251               free(layout);
252               layout = 0; /* make sure a NULL is returned */
253               *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
254             }
255           } else {
256             free(layout);
257             layout = 0; /* make sure a NULL is returned */
258             *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
259           }
260         } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NC) > 0) {
261           if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) {
262             layout->datatype = handle->desc.datatype_in;
263             layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(3*sizeof(libxsmm_dnn_tensor_dimtype));
264             layout->dim_size = (unsigned int*) malloc(3*sizeof(unsigned int));
265 
266             if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
267               layout->num_dims = 3;
268 
269               if ( (type == LIBXSMM_DNN_RNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_RNN_GRADIENT_INPUT) ) {
270                 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
271                 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
272                 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_T;
273                 layout->dim_size[0] = (unsigned int)handle->desc.C;
274                 layout->dim_size[1] = (unsigned int)handle->desc.N;
275                 layout->dim_size[2] = (unsigned int)handle->desc.max_T;
276               } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_CS_PREV)           || (type == LIBXSMM_DNN_RNN_GRADIENT_CS_PREV)           ||
277                           (type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV) || (type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV) ||
278                           (type == LIBXSMM_DNN_RNN_REGULAR_CS)                || (type == LIBXSMM_DNN_RNN_GRADIENT_CS)                ||
279                           (type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE)      || (type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE)      ||
280                           (type == LIBXSMM_DNN_RNN_INTERNAL_I)                || (type == LIBXSMM_DNN_RNN_INTERNAL_F)                 ||
281                           (type == LIBXSMM_DNN_RNN_INTERNAL_O)                || (type == LIBXSMM_DNN_RNN_INTERNAL_CI)                ||
282                           (type == LIBXSMM_DNN_RNN_INTERNAL_CO) ) {
283                 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
284                 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
285                 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_T;
286                 layout->dim_size[0] = (unsigned int)handle->desc.K;
287                 layout->dim_size[1] = (unsigned int)handle->desc.N;
288                 layout->dim_size[2] = (unsigned int)handle->desc.max_T;
289               } else {
290                 free(layout->dim_type);
291                 free(layout->dim_size);
292                 free(layout);
293                 layout = 0; /* make sure a NULL is returned */
294                 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
295               }
296             } else {
297               free(layout);
298               layout = 0; /* make sure a NULL is returned */
299               *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
300             }
301           } else {
302             free(layout);
303             layout = 0; /* make sure a NULL is returned */
304             *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
305           }
306         } else {
307           free(layout);
308           layout = 0; /* make sure a NULL is returned */
309           *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
310         }
311       } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT)       || (type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) ||
312                   (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) ) {
313         layout->format = handle->desc.filter_format;
314         layout->tensor_type = LIBXSMM_DNN_FILTER;
315         if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) > 0) {
316           if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
317             layout->datatype = handle->desc.datatype_in;
318             if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM || handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
319               layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
320               layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
321 
322               if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
323                 layout->num_dims = 5;
324 
325                 if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) ) {
326                   layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
327                   layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
328                   layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
329                   layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
330                   layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
331                   layout->dim_size[0] = (unsigned int)handle->bk;
332                   layout->dim_size[1] = (unsigned int)handle->bc;
333                   layout->dim_size[2] = (unsigned int)(handle->desc.C / handle->bc);
334                   layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
335                   if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
336                     layout->dim_size[4] = 4;
337                   } else {
338                     layout->dim_size[4] = 3;
339                   }
340                 } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) ) {
341                   layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
342                   layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
343                   layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
344                   layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
345                   layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
346                   layout->dim_size[0] = (unsigned int)handle->bk;
347                   layout->dim_size[1] = (unsigned int)handle->bk;
348                   layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk);
349                   layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
350                   if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
351                     layout->dim_size[4] = 4;
352                   } else {
353                     layout->dim_size[4] = 3;
354                   }
355                 } else {
356                   free(layout->dim_type);
357                   free(layout->dim_size);
358                   free(layout);
359                   layout = 0; /* make sure a NULL is returned */
360                   *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
361                 }
362               } else {
363                 free(layout);
364                 layout = 0; /* make sure a NULL is returned */
365                 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
366               }
367             } else {
368               layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
369               layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
370 
371               if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
372                 layout->num_dims = 4;
373 
374                 if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) ) {
375                   layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
376                   layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
377                   layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
378                   layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
379                   layout->dim_size[0] = (unsigned int)handle->bk;
380                   layout->dim_size[1] = (unsigned int)handle->bc;
381                   layout->dim_size[2] = (unsigned int)(handle->desc.C / handle->bc);
382                   layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
383                 } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) ) {
384                   layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
385                   layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
386                   layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
387                   layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
388                   layout->dim_size[0] = (unsigned int)handle->bk;
389                   layout->dim_size[1] = (unsigned int)handle->bk;
390                   layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk);
391                   layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
392                 } else {
393                   free(layout->dim_type);
394                   free(layout->dim_size);
395                   free(layout);
396                   layout = 0; /* make sure a NULL is returned */
397                   *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
398                 }
399               } else {
400                 free(layout);
401                 layout = 0; /* make sure a NULL is returned */
402                 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
403               }
404             }
405           } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
406             layout->datatype = handle->desc.datatype_in;
407             if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM || handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
408               layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(6*sizeof(libxsmm_dnn_tensor_dimtype));
409               layout->dim_size = (unsigned int*) malloc(6*sizeof(unsigned int));
410 
411               if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
412                 layout->num_dims = 6;
413 
414                 if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) ) {
415                   layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
416                   layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
417                   layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
418                   layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
419                   layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
420                   layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
421                   layout->dim_size[0] = (unsigned int)handle->lpb;
422                   layout->dim_size[1] = (unsigned int)handle->bk;
423                   layout->dim_size[2] = (unsigned int)(handle->bc / handle->lpb);
424                   layout->dim_size[3] = (unsigned int)(handle->desc.C / handle->bc);
425                   layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk);
426                   if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
427                     layout->dim_size[5] = 4;
428                   } else {
429                     layout->dim_size[5] = 3;
430                   }
431                 } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) ) {
432                   layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
433                   layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
434                   layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
435                   layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
436                   layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
437                   layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
438                   layout->dim_size[0] = (unsigned int)handle->lpb;
439                   layout->dim_size[1] = (unsigned int)handle->bk;
440                   layout->dim_size[2] = (unsigned int)(handle->bk / handle->lpb);
441                   layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
442                   layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk);
443                   if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
444                     layout->dim_size[5] = 4;
445                   } else {
446                     layout->dim_size[5] = 3;
447                   }
448                 } else {
449                   free(layout->dim_type);
450                   free(layout->dim_size);
451                   free(layout);
452                   layout = 0; /* make sure a NULL is returned */
453                   *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
454                 }
455               } else {
456                 free(layout);
457                 layout = 0; /* make sure a NULL is returned */
458                 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
459               }
460             } else {
461               layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
462               layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
463 
464               if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
465                 layout->num_dims = 5;
466 
467                 if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) ) {
468                   layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
469                   layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
470                   layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
471                   layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
472                   layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
473                   layout->dim_size[0] = (unsigned int)handle->lpb;
474                   layout->dim_size[1] = (unsigned int)handle->bk;
475                   layout->dim_size[2] = (unsigned int)(handle->bc / handle->lpb);
476                   layout->dim_size[3] = (unsigned int)(handle->desc.C / handle->bc);
477                   layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk);
478                 } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) ) {
479                   layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
480                   layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
481                   layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
482                   layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
483                   layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
484                   layout->dim_size[0] = (unsigned int)handle->lpb;
485                   layout->dim_size[1] = (unsigned int)handle->bk;
486                   layout->dim_size[2] = (unsigned int)(handle->bk / handle->lpb);
487                   layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
488                   layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk);
489                 } else {
490                   free(layout->dim_type);
491                   free(layout->dim_size);
492                   free(layout);
493                   layout = 0; /* make sure a NULL is returned */
494                   *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
495                 }
496               } else {
497                 free(layout);
498                 layout = 0; /* make sure a NULL is returned */
499                 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
500               }
501             }
502 
503           } else {
504             free(layout);
505             layout = 0; /* make sure a NULL is returned */
506             *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
507           }
508         } else if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CK) > 0) {
509           if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) {
510             layout->datatype = handle->desc.datatype_in;
511             layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype));
512             layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int));
513             if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
514               layout->num_dims = 2;
515 
516               if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT) ) {
517                 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
518                 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
519                 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
520                   layout->dim_size[0] = (unsigned int)(handle->desc.K * 4);
521                   layout->dim_size[1] = (unsigned int)handle->desc.C;
522                 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
523                   layout->dim_size[0] = (unsigned int)(handle->desc.K * 3);
524                   layout->dim_size[1] = (unsigned int)handle->desc.C;
525                 } else {
526                   layout->dim_size[0] = (unsigned int)handle->desc.K;
527                   layout->dim_size[1] = (unsigned int)handle->desc.C;
528                 }
529               } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT) || (type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT) ) {
530                 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
531                 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
532                 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
533                   layout->dim_size[0] = (unsigned int)(handle->desc.K * 4);
534                   layout->dim_size[1] = (unsigned int)handle->desc.K;
535                 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
536                   layout->dim_size[0] = (unsigned int)(handle->desc.K * 3);
537                   layout->dim_size[1] = (unsigned int)handle->desc.K;
538                 } else {
539                   layout->dim_size[0] = (unsigned int)handle->desc.K;
540                   layout->dim_size[1] = (unsigned int)handle->desc.K;
541                 }
542               } else {
543                 free(layout->dim_type);
544                 free(layout->dim_size);
545                 free(layout);
546                 layout = 0; /* make sure a NULL is returned */
547                 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
548               }
549             } else {
550               free(layout);
551               layout = 0; /* make sure a NULL is returned */
552               *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
553             }
554           } else {
555             free(layout);
556             layout = 0; /* make sure a NULL is returned */
557             *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
558           }
559         } else {
560           free(layout);
561           layout = 0; /* make sure a NULL is returned */
562           *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
563         }
564       } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) || (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) ) {
565         layout->format = handle->desc.filter_format;
566         layout->tensor_type = LIBXSMM_DNN_FILTER;
567         if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) > 0) {
568           if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
569             layout->datatype = handle->desc.datatype_in;
570             if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM || handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
571               layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
572               layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
573 
574               if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
575                 layout->num_dims = 5;
576 
577                 if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) ) {
578                   layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
579                   layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
580                   layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
581                   layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
582                   layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
583                   layout->dim_size[0] = (unsigned int)handle->bc;
584                   layout->dim_size[1] = (unsigned int)handle->bk;
585                   layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk);
586                   layout->dim_size[3] = (unsigned int)(handle->desc.C / handle->bc);
587                   if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
588                     layout->dim_size[4] = 4;
589                   } else {
590                     layout->dim_size[4] = 3;
591                   }
592                 } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) ) {
593                   layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
594                   layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
595                   layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
596                   layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
597                   layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
598                   layout->dim_size[0] = (unsigned int)handle->bk;
599                   layout->dim_size[1] = (unsigned int)handle->bk;
600                   layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk);
601                   layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
602                   if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
603                     layout->dim_size[4] = 4;
604                   } else {
605                     layout->dim_size[4] = 3;
606                   }
607                 } else {
608                   free(layout->dim_type);
609                   free(layout->dim_size);
610                   free(layout);
611                   layout = 0; /* make sure a NULL is returned */
612                   *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
613                 }
614               } else {
615                 free(layout);
616                 layout = 0; /* make sure a NULL is returned */
617                 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
618               }
619             } else {
620               layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
621               layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
622 
623               if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
624                 layout->num_dims = 4;
625 
626                 if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) ) {
627                   layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
628                   layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
629                   layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
630                   layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
631                   layout->dim_size[0] = (unsigned int)handle->bc;
632                   layout->dim_size[1] = (unsigned int)handle->bk;
633                   layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk);
634                   layout->dim_size[3] = (unsigned int)(handle->desc.C / handle->bc);
635                 } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) ) {
636                   layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
637                   layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
638                   layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
639                   layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
640                   layout->dim_size[0] = (unsigned int)handle->bk;
641                   layout->dim_size[1] = (unsigned int)handle->bk;
642                   layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk);
643                   layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
644                 } else {
645                   free(layout->dim_type);
646                   free(layout->dim_size);
647                   free(layout);
648                   layout = 0; /* make sure a NULL is returned */
649                   *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
650                 }
651               } else {
652                 free(layout);
653                 layout = 0; /* make sure a NULL is returned */
654                 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
655               }
656             }
657           } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
658             layout->datatype = handle->desc.datatype_in;
659             if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM || handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
660               layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(6*sizeof(libxsmm_dnn_tensor_dimtype));
661               layout->dim_size = (unsigned int*) malloc(6*sizeof(unsigned int));
662 
663               if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
664                 layout->num_dims = 6;
665 
666                 if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) ) {
667                   layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
668                   layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
669                   layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
670                   layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
671                   layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
672                   layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
673                   layout->dim_size[0] = (unsigned int)handle->lpb;
674                   layout->dim_size[1] = (unsigned int)handle->bc;
675                   layout->dim_size[2] = (unsigned int)(handle->bk / handle->lpb);
676                   layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
677                   layout->dim_size[4] = (unsigned int)(handle->desc.C / handle->bc);
678                   if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
679                     layout->dim_size[5] = 4;
680                   } else {
681                     layout->dim_size[5] = 3;
682                   }
683                 } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) ) {
684                   layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
685                   layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
686                   layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
687                   layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
688                   layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
689                   layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
690                   layout->dim_size[0] = (unsigned int)handle->lpb;
691                   layout->dim_size[1] = (unsigned int)handle->bk;
692                   layout->dim_size[2] = (unsigned int)(handle->bk / handle->lpb);
693                   layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
694                   layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk);
695                   if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
696                     layout->dim_size[5] = 4;
697                   } else {
698                     layout->dim_size[5] = 3;
699                   }
700                 } else {
701                   free(layout->dim_type);
702                   free(layout->dim_size);
703                   free(layout);
704                   layout = 0; /* make sure a NULL is returned */
705                   *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
706                 }
707               } else {
708                 free(layout);
709                 layout = 0; /* make sure a NULL is returned */
710                 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
711               }
712             } else {
713               layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
714               layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
715 
716               if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
717                 layout->num_dims = 5;
718 
719                 if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) ) {
720                   layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
721                   layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
722                   layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
723                   layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
724                   layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
725                   layout->dim_size[0] = (unsigned int)handle->lpb;
726                   layout->dim_size[1] = (unsigned int)handle->bc;
727                   layout->dim_size[2] = (unsigned int)(handle->bk / handle->lpb);
728                   layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
729                   layout->dim_size[4] = (unsigned int)(handle->desc.C / handle->bc);
730                 } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) ) {
731                   layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
732                   layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
733                   layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
734                   layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
735                   layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
736                   layout->dim_size[0] = (unsigned int)handle->lpb;
737                   layout->dim_size[1] = (unsigned int)handle->bk;
738                   layout->dim_size[2] = (unsigned int)(handle->bk / handle->lpb);
739                   layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
740                   layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk);
741                 } else {
742                   free(layout->dim_type);
743                   free(layout->dim_size);
744                   free(layout);
745                   layout = 0; /* make sure a NULL is returned */
746                   *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
747                 }
748               } else {
749                 free(layout);
750                 layout = 0; /* make sure a NULL is returned */
751                 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
752               }
753             }
754           } else {
755             free(layout);
756             layout = 0; /* make sure a NULL is returned */
757             *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
758           }
759         } else if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CK) > 0) {
760           if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) {
761             layout->datatype = handle->desc.datatype_in;
762             layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype));
763             layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int));
764 
765             if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
766               layout->num_dims = 2;
767 
768               if ( (type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS) ) {
769                 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
770                 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
771                 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
772                   layout->dim_size[0] = (unsigned int)handle->desc.C;
773                   layout->dim_size[1] = (unsigned int)(handle->desc.K * 4);
774                 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
775                   layout->dim_size[0] = (unsigned int)handle->desc.C;
776                   layout->dim_size[1] = (unsigned int)(handle->desc.K * 3);
777                 } else {
778                   layout->dim_size[0] = (unsigned int)handle->desc.C;
779                   layout->dim_size[1] = (unsigned int)handle->desc.K;
780                 }
781               } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) ) {
782                 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
783                 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
784                 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
785                   layout->dim_size[0] = (unsigned int)handle->desc.K;
786                   layout->dim_size[1] = (unsigned int)(handle->desc.K * 4);
787                 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
788                   layout->dim_size[0] = (unsigned int)handle->desc.K;
789                   layout->dim_size[1] = (unsigned int)(handle->desc.K * 3);
790                 } else {
791                   layout->dim_size[0] = (unsigned int)handle->desc.K;
792                   layout->dim_size[1] = (unsigned int)handle->desc.K;
793                 }
794               } else {
795                 free(layout->dim_type);
796                 free(layout->dim_size);
797                 free(layout);
798                 layout = 0; /* make sure a NULL is returned */
799                 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
800               }
801             } else {
802               free(layout);
803               layout = 0; /* make sure a NULL is returned */
804               *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
805             }
806           } else {
807             free(layout);
808             layout = 0; /* make sure a NULL is returned */
809             *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
810           }
811         } else {
812           free(layout);
813           layout = 0; /* make sure a NULL is returned */
814           *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
815         }
816       } else if ( (type == LIBXSMM_DNN_RNN_REGULAR_BIAS) || (type == LIBXSMM_DNN_RNN_GRADIENT_BIAS) ) {
817         layout->format = handle->desc.buffer_format;
818         layout->tensor_type = LIBXSMM_DNN_CHANNEL_SCALAR;
819 
820 
821         if ( ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NC) > 0) || ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) ) {
822           if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) || ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) {
823             layout->datatype = handle->desc.datatype_in;
824             layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(1*sizeof(libxsmm_dnn_tensor_dimtype));
825             layout->dim_size = (unsigned int*) malloc(1*sizeof(unsigned int));
826 
827             if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
828               layout->num_dims = 1;
829 
830               if ( (type == LIBXSMM_DNN_RNN_REGULAR_BIAS) || (type == LIBXSMM_DNN_RNN_GRADIENT_BIAS) ) {
831                 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
832                 if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_LSTM ) {
833                   layout->dim_size[0] = (unsigned int)(handle->desc.K * 4);
834                 } else if ( handle->desc.cell_type == LIBXSMM_DNN_RNNCELL_GRU ) {
835                   layout->dim_size[0] = (unsigned int)(handle->desc.K * 3);
836                 } else {
837                   layout->dim_size[0] = (unsigned int)handle->desc.K;
838                 }
839               } else { /* coverity[dead_error_begin] */
840                 free(layout->dim_type);
841                 free(layout->dim_size);
842                 free(layout);
843                 layout = 0; /* make sure a NULL is returned */
844                 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
845               }
846             } else {
847               free(layout);
848               layout = 0; /* make sure a NULL is returned */
849               *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
850             }
851           } else {
852             free(layout);
853             layout = 0; /* make sure a NULL is returned */
854             *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
855           }
856         } else {
857           free(layout);
858           layout = 0; /* make sure a NULL is returned */
859           *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
860         }
861       } else {
862         free(layout);
863         layout = 0; /* make sure a NULL is returned */
864         *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
865       }
866     } else {
867       *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT;
868     }
869   } else {
870     *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
871   }
872   return layout;
873 }
874 
875 
libxsmm_dnn_rnncell_get_scratch_size(const libxsmm_dnn_rnncell * handle,const libxsmm_dnn_compute_kind kind,libxsmm_dnn_err_t * status)876 LIBXSMM_API size_t libxsmm_dnn_rnncell_get_scratch_size(const libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, libxsmm_dnn_err_t* status)
877 {
878   size_t size = 0;
879   *status = LIBXSMM_DNN_SUCCESS;
880 
881   if (0 != handle) {
882     const size_t typesize_in = libxsmm_dnn_typesize(handle->desc.datatype_in);
883     const size_t dwdr_typesize = (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ? sizeof(float) : typesize_in;
884 
885     switch (handle->desc.cell_type) {
886       case LIBXSMM_DNN_RNNCELL_RNN_RELU:
887       case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID:
888       case LIBXSMM_DNN_RNNCELL_RNN_TANH: {
889         switch (kind) {
890           case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
891             size += 0;
892           } break;
893           case LIBXSMM_DNN_COMPUTE_KIND_BWD:
894           case LIBXSMM_DNN_COMPUTE_KIND_UPD:
895           case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
896           case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
897             size += (size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in  + 64; /* wT */
898             size += (size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in  + 64; /* rT */
899             size += (size_t)handle->desc.C * (size_t)handle->desc.N * typesize_in  + 64; /* xT */
900             size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* hT */
901             size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) * (size_t)handle->desc.max_T + 64; /* deltat */
902 
903           } break;
904           default: {
905             *status = LIBXSMM_DNN_ERR_INVALID_KIND;
906           }
907         }
908       } break;
909       case  LIBXSMM_DNN_RNNCELL_LSTM: {
910         switch (kind) {
911           case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
912             size += (size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in * 4 + 4 * 64; /* w */
913             size += (size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in * 4 + 4 * 64; /* r */
914             /*  The scratches below are needed only for BF16 code for the intermediate results  */
915             if (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) {
916               size += (size_t)7 *((size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64); /* intermediate scratches */
917               size += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) + 64;                                           /* intermediate scratches */
918             }
919           } break;
920           case LIBXSMM_DNN_COMPUTE_KIND_BWD:
921           case LIBXSMM_DNN_COMPUTE_KIND_UPD:
922           case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
923           case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
924             size += (size_t)handle->desc.C * (size_t)handle->desc.K * dwdr_typesize * 4 + 4 * 64; /* w */
925             size += (size_t)handle->desc.K * (size_t)handle->desc.K * dwdr_typesize * 4 + 4 * 64; /* r */
926             size += (size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in * 4 + 4 * 64; /* wT */
927             size += (size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in * 4 + 4 * 64; /* rT */
928             size += (size_t)handle->desc.C * (size_t)handle->desc.N * typesize_in  + 64; /* xT */
929             size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* hT */
930             size += (size_t)handle->desc.K * (size_t)handle->desc.N * dwdr_typesize + 64; /* deltat */
931             size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* di */
932             size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* df */
933             size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* do */
934             size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* dci */
935             size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* diB */
936             size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* dfB */
937             size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* dpB */
938             size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* dciB */
939             size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* t1 */
940             size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* t2 */
941             /*  The scratches below are needed only for BF16 code for the intermediate results  */
942             if (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) {
943               size += (size_t)4 *((size_t)handle->desc.K * sizeof(float) + 64); /* intermediate db scratch */
944               size += (size_t)handle->desc.C * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64; /* intermediate dx scratches */
945               size += (size_t)7 *((size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64); /* intermediate scratches */
946               size += (size_t)2 *((size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) + 64); /* intermediate scratches */
947             }
948           } break;
949           default: {
950             *status = LIBXSMM_DNN_ERR_INVALID_KIND;
951           }
952         }
953       } break;
954       case  LIBXSMM_DNN_RNNCELL_GRU: {
955         switch (kind) {
956           case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
957             size += (size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in * 3 + 3 * 64; /* w */
958             size += (size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in * 3 + 3 * 64; /* r */
959           } break;
960           case LIBXSMM_DNN_COMPUTE_KIND_BWD:
961           case LIBXSMM_DNN_COMPUTE_KIND_UPD:
962           case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
963           case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
964             size += (size_t)handle->desc.C * (size_t)handle->desc.K * dwdr_typesize * 3 + 3 * 64; /* w */
965             size += (size_t)handle->desc.K * (size_t)handle->desc.K * dwdr_typesize * 3 + 3 * 64; /* r */
966             size += (size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in * 3 + 3 * 64; /* wT */
967             size += (size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in * 3 + 3 * 64; /* rT */
968             size += (size_t)handle->desc.C * (size_t)handle->desc.N * typesize_in  + 64; /* xT */
969             size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* hT */
970             size += (size_t)handle->desc.K * (size_t)handle->desc.N * dwdr_typesize + 64; /* deltat */
971             size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* di */
972             size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* dc */
973             size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* df */
974             size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* do */
975             size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* diB */
976             size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* dcB */
977             size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* dfB */
978             size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* oT */
979             size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* t1 */
980             size += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64; /* t2 */
981           } break;
982           default: {
983             *status = LIBXSMM_DNN_ERR_INVALID_KIND;
984           }
985         }
986       } break;
987       default: {
988         *status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE;
989       }
990     }
991   } else {
992     *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
993   }
994 
995   return size;
996 }
997 
998 
libxsmm_dnn_rnncell_get_scratch_ptr(const libxsmm_dnn_rnncell * handle,libxsmm_dnn_err_t * status)999 LIBXSMM_API void* libxsmm_dnn_rnncell_get_scratch_ptr(const libxsmm_dnn_rnncell* handle, libxsmm_dnn_err_t* status)
1000 {
1001   *status = LIBXSMM_DNN_SUCCESS;
1002 
1003   if (0 != handle) {
1004     return handle->scratch_base;
1005   } else {
1006     *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1007   }
1008 
1009   return NULL;
1010 }
1011 
1012 
libxsmm_dnn_rnncell_bind_scratch(libxsmm_dnn_rnncell * handle,const libxsmm_dnn_compute_kind kind,const void * scratch)1013 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_bind_scratch(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, const void* scratch)
1014 {
1015   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
1016 
1017   if (NULL != handle) {
1018     const size_t typesize_in = libxsmm_dnn_typesize(handle->desc.datatype_in);
1019     const size_t dwdr_typesize = (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ? sizeof(float) : typesize_in;
1020     uintptr_t address = (uintptr_t)scratch;
1021     size_t offset = 0;
1022 
1023     switch (handle->desc.cell_type) {
1024       case LIBXSMM_DNN_RNNCELL_RNN_RELU:
1025       case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID:
1026       case LIBXSMM_DNN_RNNCELL_RNN_TANH: {
1027         switch (kind) {
1028           case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1029             /* forward only has no scratch need */
1030           } break;
1031           case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1032           case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1033           case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1034           case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1035             if (scratch == 0) {
1036               status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
1037               return status;
1038             }
1039             handle->scratch_base = (void*)address;
1040             /* wT */
1041             if (address % 64 == 0) {
1042               handle->scratch_wT = (void*)address;
1043             } else {
1044               offset = (64 - address % 64);
1045               handle->scratch_wT = (void*)(address+offset);
1046             }
1047             address += ((size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in) + 64;
1048             /* rT */
1049             if (address % 64 == 0) {
1050               handle->scratch_rT = (void*)address;
1051             } else {
1052               offset = (64 - address % 64);
1053               handle->scratch_rT = (void*)(address+offset);
1054             }
1055             address += ((size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in) + 64;
1056             /* xT */
1057             if (address % 64 == 0) {
1058               handle->scratch_xT = (void*)address;
1059             } else {
1060               offset = (64 - address % 64);
1061               handle->scratch_xT = (void*)(address+offset);
1062             }
1063             address += ((size_t)handle->desc.C * (size_t)handle->desc.N * typesize_in) + 64;
1064             /* hT */
1065             if (address % 64 == 0) {
1066               handle->scratch_hT = (void*)address;
1067             } else {
1068               offset = (64 - address % 64);
1069               handle->scratch_hT = (void*)(address+offset);
1070             }
1071             address += ((size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out)) + 64;
1072             /* deltat */
1073             if (address % 64 == 0) {
1074               handle->scratch_deltat = (void*)address;
1075             } else {
1076               offset = (64 - address % 64);
1077               handle->scratch_deltat = (void*)(address+offset);
1078             }
1079             address += ((size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) * (size_t)handle->desc.max_T) + 64;
1080           } break;
1081           default: {
1082             status = LIBXSMM_DNN_ERR_INVALID_KIND;
1083           }
1084         }
1085       } break;
1086       case LIBXSMM_DNN_RNNCELL_LSTM: {
1087         switch (kind) {
1088           case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1089             if (scratch == 0) {
1090               status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
1091               return status;
1092             }
1093             handle->scratch_base = (void*)address;
1094             /* w scratch */
1095             if (address % 64 == 0) {
1096               handle->scratch_w = (void*)address;
1097             } else {
1098               offset = (64 - address % 64);
1099               handle->scratch_w = (void*)(address+offset);
1100             }
1101             address += ((size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in) * 4 + 64;
1102             /* r scratch */
1103             if (address % 64 == 0) {
1104               handle->scratch_r = (void*)address;
1105             } else {
1106               offset = (64 - address % 64);
1107               handle->scratch_r = (void*)(address+offset);
1108             }
1109             address += ((size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in) * 4 + 64;
1110             /*  The scratches below are needed only for BF16 code for the intermediate results  */
1111             if (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) {
1112               /* cst scratch */
1113               if (address % 64 == 0) {
1114                 handle->cst_scratch = (void*)address;
1115               } else {
1116                 offset = (64 - address % 64);
1117                 handle->cst_scratch = (void*)(address+offset);
1118               }
1119               address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1120               /* ht scratch */
1121               if (address % 64 == 0) {
1122                 handle->ht_scratch = (void*)address;
1123               } else {
1124                 offset = (64 - address % 64);
1125                 handle->ht_scratch = (void*)(address+offset);
1126               }
1127               address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1128               /* it scratch */
1129               if (address % 64 == 0) {
1130                 handle->it_scratch = (void*)address;
1131               } else {
1132                 offset = (64 - address % 64);
1133                 handle->it_scratch = (void*)(address+offset);
1134               }
1135               address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1136               /* ft scratch */
1137               if (address % 64 == 0) {
1138                 handle->ft_scratch = (void*)address;
1139               } else {
1140                 offset = (64 - address % 64);
1141                 handle->ft_scratch = (void*)(address+offset);
1142               }
1143               address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1144               /* ot scratch */
1145               if (address % 64 == 0) {
1146                 handle->ot_scratch = (void*)address;
1147               } else {
1148                 offset = (64 - address % 64);
1149                 handle->ot_scratch = (void*)(address+offset);
1150               }
1151               address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1152               /* cit scratch */
1153               if (address % 64 == 0) {
1154                 handle->cit_scratch = (void*)address;
1155               } else {
1156                 offset = (64 - address % 64);
1157                 handle->cit_scratch = (void*)(address+offset);
1158               }
1159               address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1160               /* cot scratch */
1161               if (address % 64 == 0) {
1162                 handle->cot_scratch = (void*)address;
1163               } else {
1164                 offset = (64 - address % 64);
1165                 handle->cot_scratch = (void*)(address+offset);
1166               }
1167               address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1168               /* csp scratch */
1169               if (address % 64 == 0) {
1170                 handle->csp_scratch = (void*)address;
1171               } else {
1172                 offset = (64 - address % 64);
1173                 handle->csp_scratch = (void*)(address+offset);
1174               }
1175               address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) + 64;
1176             }
1177           } break;
1178           case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1179           case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1180           case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1181           case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1182             if (scratch == 0) {
1183               status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
1184               return status;
1185             }
1186             handle->scratch_base = (void*)address;
1187             /* w scratch */
1188             if (address % 64 == 0) {
1189               handle->scratch_w = (void*)address;
1190             } else {
1191               offset = (64 - address % 64);
1192               handle->scratch_w = (void*)(address+offset);
1193             }
1194             address += ((size_t)handle->desc.C * (size_t)handle->desc.K * dwdr_typesize) * 4 + 64;
1195             /* r scratch */
1196             if (address % 64 == 0) {
1197               handle->scratch_r = (void*)address;
1198             } else {
1199               offset = (64 - address % 64);
1200               handle->scratch_r = (void*)(address+offset);
1201             }
1202             address += ((size_t)handle->desc.K * (size_t)handle->desc.K * dwdr_typesize) * 4 + 64;
1203             /* wT */
1204             if (address % 64 == 0) {
1205               handle->scratch_wT = (void*)address;
1206             } else {
1207               offset = (64 - address % 64);
1208               handle->scratch_wT = (void*)(address+offset);
1209             }
1210             address += ((size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in) * 4 + 64;
1211             /* rT */
1212             if (address % 64 == 0) {
1213               handle->scratch_rT = (void*)address;
1214             } else {
1215               offset = (64 - address % 64);
1216               handle->scratch_rT = (void*)(address+offset);
1217             }
1218             address += ((size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in) * 4 + 64;
1219             /* xT */
1220             if (address % 64 == 0) {
1221               handle->scratch_xT = (void*)address;
1222             } else {
1223               offset = (64 - address % 64);
1224               handle->scratch_xT = (void*)(address+offset);
1225             }
1226             address += (size_t)handle->desc.C * (size_t)handle->desc.N * typesize_in + 64;
1227             /* hT */
1228             if (address % 64 == 0) {
1229               handle->scratch_hT = (void*)address;
1230             } else {
1231               offset = (64 - address % 64);
1232               handle->scratch_hT = (void*)(address+offset);
1233             }
1234             address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1235             /* deltat */
1236             if (address % 64 == 0) {
1237               handle->scratch_deltat = (void*)address;
1238             } else {
1239               offset = (64 - address % 64);
1240               handle->scratch_deltat = (void*)(address+offset);
1241             }
1242             address += (size_t)handle->desc.K * (size_t)handle->desc.N * dwdr_typesize + 64;
1243             /* di */
1244             if (address % 64 == 0) {
1245               handle->scratch_di = (void*)address;
1246             } else {
1247               offset = (64 - address % 64);
1248               handle->scratch_di = (void*)(address+offset);
1249             }
1250             address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1251             /* df */
1252             if (address % 64 == 0) {
1253               handle->scratch_df = (void*)address;
1254             } else {
1255               offset = (64 - address % 64);
1256               handle->scratch_df = (void*)(address+offset);
1257             }
1258             address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1259             /* do */
1260             if (address % 64 == 0) {
1261               handle->scratch_do = (void*)address;
1262             } else {
1263               offset = (64 - address % 64);
1264               handle->scratch_do = (void*)(address+offset);
1265             }
1266             address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1267             /* dci */
1268             if (address % 64 == 0) {
1269               handle->scratch_dci = (void*)address;
1270             } else {
1271               offset = (64 - address % 64);
1272               handle->scratch_dci = (void*)(address+offset);
1273             }
1274             address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1275             /* diB */
1276             if (address % 64 == 0) {
1277               handle->scratch_diB = (void*)address;
1278             } else {
1279               offset = (64 - address % 64);
1280               handle->scratch_diB = (void*)(address+offset);
1281             }
1282             address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1283             /* dfB */
1284             if (address % 64 == 0) {
1285               handle->scratch_dfB = (void*)address;
1286             } else {
1287               offset = (64 - address % 64);
1288               handle->scratch_dfB = (void*)(address+offset);
1289             }
1290             address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1291             /* dpB */
1292             if (address % 64 == 0) {
1293               handle->scratch_dpB = (void*)address;
1294             } else {
1295               offset = (64 - address % 64);
1296               handle->scratch_dpB = (void*)(address+offset);
1297             }
1298             address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1299             /* dciB */
1300             if (address % 64 == 0) {
1301               handle->scratch_dciB = (void*)address;
1302             } else {
1303               offset = (64 - address % 64);
1304               handle->scratch_dciB = (void*)(address+offset);
1305             }
1306             address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1307             /* t1 */
1308             if (address % 64 == 0) {
1309               handle->scratch_t1 = (void*)address;
1310             } else {
1311               offset = (64 - address % 64);
1312               handle->scratch_t1 = (void*)(address+offset);
1313             }
1314             address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1315             /* t2 */
1316             if (address % 64 == 0) {
1317               handle->scratch_t2 = (void*)address;
1318             } else {
1319               offset = (64 - address % 64);
1320               handle->scratch_t2 = (void*)(address+offset);
1321             }
1322             address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1323             /*  The scratches below are needed only for BF16 code for the intermediate results  */
1324             if (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) {
1325               /* dx scratch */
1326               if (address % 64 == 0) {
1327                 handle->scratch_dx = (void*)address;
1328               } else {
1329                 offset = (64 - address % 64);
1330                 handle->scratch_dx = (void*)(address+offset);
1331               }
1332               address += (size_t)handle->desc.C * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1333               /* dhp scratch */
1334               if (address % 64 == 0) {
1335                 handle->scratch_dhp = (void*)address;
1336               } else {
1337                 offset = (64 - address % 64);
1338                 handle->scratch_dhp = (void*)(address+offset);
1339               }
1340               address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) + 64;
1341               /* db scratch */
1342               if (address % 64 == 0) {
1343                 handle->scratch_db = (void*)address;
1344               } else {
1345                 offset = (64 - address % 64);
1346                 handle->scratch_db = (void*)(address+offset);
1347               }
1348               address += (size_t)handle->desc.K * 4 * sizeof(float) + 64;
1349               /* cst scratch */
1350               if (address % 64 == 0) {
1351                 handle->cst_scratch = (void*)address;
1352               } else {
1353                 offset = (64 - address % 64);
1354                 handle->cst_scratch = (void*)(address+offset);
1355               }
1356               address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1357               /* ht scratch */
1358               if (address % 64 == 0) {
1359                 handle->ht_scratch = (void*)address;
1360               } else {
1361                 offset = (64 - address % 64);
1362                 handle->ht_scratch = (void*)(address+offset);
1363               }
1364               address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1365               /* it scratch */
1366               if (address % 64 == 0) {
1367                 handle->it_scratch = (void*)address;
1368               } else {
1369                 offset = (64 - address % 64);
1370                 handle->it_scratch = (void*)(address+offset);
1371               }
1372               address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1373               /* ft scratch */
1374               if (address % 64 == 0) {
1375                 handle->ft_scratch = (void*)address;
1376               } else {
1377                 offset = (64 - address % 64);
1378                 handle->ft_scratch = (void*)(address+offset);
1379               }
1380               address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1381               /* ot scratch */
1382               if (address % 64 == 0) {
1383                 handle->ot_scratch = (void*)address;
1384               } else {
1385                 offset = (64 - address % 64);
1386                 handle->ot_scratch = (void*)(address+offset);
1387               }
1388               address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1389               /* cit scratch */
1390               if (address % 64 == 0) {
1391                 handle->cit_scratch = (void*)address;
1392               } else {
1393                 offset = (64 - address % 64);
1394                 handle->cit_scratch = (void*)(address+offset);
1395               }
1396               address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1397               /* cot scratch */
1398               if (address % 64 == 0) {
1399                 handle->cot_scratch = (void*)address;
1400               } else {
1401                 offset = (64 - address % 64);
1402                 handle->cot_scratch = (void*)(address+offset);
1403               }
1404               address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) * (size_t)handle->desc.max_T + 64;
1405               /* csp scratch */
1406               if (address % 64 == 0) {
1407                 handle->csp_scratch = (void*)address;
1408               } else {
1409                 offset = (64 - address % 64);
1410                 handle->csp_scratch = (void*)(address+offset);
1411               }
1412               address += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof(float) + 64;
1413             }
1414           } break;
1415           default: {
1416             status = LIBXSMM_DNN_ERR_INVALID_KIND;
1417           }
1418         }
1419       } break;
1420       case LIBXSMM_DNN_RNNCELL_GRU: {
1421         switch (kind) {
1422           case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1423             if (scratch == 0) {
1424               status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
1425               return status;
1426             }
1427             handle->scratch_base = (void*)address;
1428             /* w scratch */
1429             if (address % 64 == 0) {
1430               handle->scratch_w = (void*)address;
1431             } else {
1432               offset = (64 - address % 64);
1433               handle->scratch_w = (void*)(address+offset);
1434             }
1435             address += ((size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in) * 3 + 64;
1436             /* r scratch */
1437             if (address % 64 == 0) {
1438               handle->scratch_r = (void*)address;
1439             } else {
1440               offset = (64 - address % 64);
1441               handle->scratch_r = (void*)(address+offset);
1442             }
1443             address += ((size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in) * 3 + 64;
1444           } break;
1445           case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1446           case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1447           case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1448           case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1449             if (scratch == 0) {
1450               status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
1451               return status;
1452             }
1453             handle->scratch_base = (void*)address;
1454             /* w scratch */
1455             if (address % 64 == 0) {
1456               handle->scratch_w = (void*)address;
1457             } else {
1458               offset = (64 - address % 64);
1459               handle->scratch_w = (void*)(address+offset);
1460             }
1461             address += ((size_t)handle->desc.C * (size_t)handle->desc.K * dwdr_typesize) * 3 + 64;
1462             /* r scratch */
1463             if (address % 64 == 0) {
1464               handle->scratch_r = (void*)address;
1465             } else {
1466               offset = (64 - address % 64);
1467               handle->scratch_r = (void*)(address+offset);
1468             }
1469             address += ((size_t)handle->desc.K * (size_t)handle->desc.K * dwdr_typesize) * 3 + 64;
1470             /* wT */
1471             if (address % 64 == 0) {
1472               handle->scratch_wT = (void*)address;
1473             } else {
1474               offset = (64 - address % 64);
1475               handle->scratch_wT = (void*)(address+offset);
1476             }
1477             address += ((size_t)handle->desc.C * (size_t)handle->desc.K * typesize_in) * 3 + 64;
1478             /* rT */
1479             if (address % 64 == 0) {
1480               handle->scratch_rT = (void*)address;
1481             } else {
1482               offset = (64 - address % 64);
1483               handle->scratch_rT = (void*)(address+offset);
1484             }
1485             address += ((size_t)handle->desc.K * (size_t)handle->desc.K * typesize_in) * 3 + 64;
1486             /* xT */
1487             if (address % 64 == 0) {
1488               handle->scratch_xT = (void*)address;
1489             } else {
1490               offset = (64 - address % 64);
1491               handle->scratch_xT = (void*)(address+offset);
1492             }
1493             address += (size_t)handle->desc.C * (size_t)handle->desc.N * typesize_in + 64;
1494             /* hT */
1495             if (address % 64 == 0) {
1496               handle->scratch_hT = (void*)address;
1497             } else {
1498               offset = (64 - address % 64);
1499               handle->scratch_hT = (void*)(address+offset);
1500             }
1501             address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1502             /* deltat */
1503             if (address % 64 == 0) {
1504               handle->scratch_deltat = (void*)address;
1505             } else {
1506               offset = (64 - address % 64);
1507               handle->scratch_deltat = (void*)(address+offset);
1508             }
1509             address += (size_t)handle->desc.K * (size_t)handle->desc.N * dwdr_typesize + 64;
1510             /* di */
1511             if (address % 64 == 0) {
1512               handle->scratch_di = (void*)address;
1513             } else {
1514               offset = (64 - address % 64);
1515               handle->scratch_di = (void*)(address+offset);
1516             }
1517             address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1518             /* dc */
1519             if (address % 64 == 0) {
1520               handle->scratch_dci = (void*)address;
1521             } else {
1522               offset = (64 - address % 64);
1523               handle->scratch_dci = (void*)(address+offset);
1524             }
1525             address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1526             /* df */
1527             if (address % 64 == 0) {
1528               handle->scratch_df = (void*)address;
1529             } else {
1530               offset = (64 - address % 64);
1531               handle->scratch_df = (void*)(address+offset);
1532             }
1533             address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1534             /* do */
1535             if (address % 64 == 0) {
1536               handle->scratch_do = (void*)address;
1537             } else {
1538               offset = (64 - address % 64);
1539               handle->scratch_do = (void*)(address+offset);
1540             }
1541             address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1542             /* diB */
1543             if (address % 64 == 0) {
1544               handle->scratch_diB = (void*)address;
1545             } else {
1546               offset = (64 - address % 64);
1547               handle->scratch_diB = (void*)(address+offset);
1548             }
1549             address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1550             /* dcB */
1551             if (address % 64 == 0) {
1552               handle->scratch_dciB = (void*)address;
1553             } else {
1554               offset = (64 - address % 64);
1555               handle->scratch_dciB = (void*)(address+offset);
1556             }
1557             address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1558             /* dfB */
1559             if (address % 64 == 0) {
1560               handle->scratch_dfB = (void*)address;
1561             } else {
1562               offset = (64 - address % 64);
1563               handle->scratch_dfB = (void*)(address+offset);
1564             }
1565             address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1566             /* doB (repurposed for oT) */
1567             if (address % 64 == 0) {
1568               handle->scratch_dpB = (void*)address;
1569             } else {
1570               offset = (64 - address % 64);
1571               handle->scratch_dpB = (void*)(address+offset);
1572             }
1573             address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1574             /* t1 */
1575             if (address % 64 == 0) {
1576               handle->scratch_t1 = (void*)address;
1577             } else {
1578               offset = (64 - address % 64);
1579               handle->scratch_t1 = (void*)(address+offset);
1580             }
1581             address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1582             /* t2 */
1583             if (address % 64 == 0) {
1584               handle->scratch_t2 = (void*)address;
1585             } else {
1586               offset = (64 - address % 64);
1587               handle->scratch_t2 = (void*)(address+offset);
1588             }
1589             address += (size_t)handle->desc.K * (size_t)handle->desc.N * libxsmm_dnn_typesize(handle->desc.datatype_out) + 64;
1590           } break;
1591           default: {
1592             status = LIBXSMM_DNN_ERR_INVALID_KIND;
1593           }
1594         }
1595       } break;
1596       default: {
1597         status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE;
1598       }
1599     }
1600   } else {
1601     status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1602   }
1603 
1604   return status;
1605 }
1606 
1607 
libxsmm_dnn_rnncell_release_scratch(libxsmm_dnn_rnncell * handle,const libxsmm_dnn_compute_kind kind)1608 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_release_scratch(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind)
1609 {
1610   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
1611 
1612   if (0 != handle) {
1613     switch (handle->desc.cell_type) {
1614       case LIBXSMM_DNN_RNNCELL_RNN_RELU:
1615       case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID:
1616       case LIBXSMM_DNN_RNNCELL_RNN_TANH: {
1617         switch (kind) {
1618           case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1619             /* forward only has no scratch need */
1620           } break;
1621           case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1622           case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1623           case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1624           case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1625             handle->scratch_wT = 0;
1626             handle->scratch_rT = 0;
1627             handle->scratch_xT = 0;
1628             handle->scratch_hT = 0;
1629             handle->scratch_deltat = 0;
1630           } break;
1631           default: {
1632             status = LIBXSMM_DNN_ERR_INVALID_KIND;
1633           }
1634         }
1635       } break;
1636       case LIBXSMM_DNN_RNNCELL_LSTM: {
1637         switch (kind) {
1638           case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1639             handle->scratch_w  = 0;
1640             handle->scratch_r  = 0;
1641             handle->csp_scratch  = 0;
1642             handle->cst_scratch  = 0;
1643             handle->ht_scratch  = 0;
1644             handle->it_scratch  = 0;
1645             handle->ft_scratch  = 0;
1646             handle->ot_scratch  = 0;
1647             handle->cit_scratch  = 0;
1648             handle->cot_scratch  = 0;
1649           } break;
1650           case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1651           case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1652           case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1653           case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1654             handle->scratch_w = 0;
1655             handle->scratch_r = 0;
1656             handle->scratch_wT = 0;
1657             handle->scratch_rT = 0;
1658             handle->scratch_xT = 0;
1659             handle->scratch_hT = 0;
1660             handle->scratch_deltat = 0;
1661             handle->scratch_di = 0;
1662             handle->scratch_df = 0;
1663             handle->scratch_do = 0;
1664             handle->scratch_dci = 0;
1665             handle->scratch_diB = 0;
1666             handle->scratch_dfB = 0;
1667             handle->scratch_dpB = 0;
1668             handle->scratch_dciB = 0;
1669             handle->scratch_t1 = 0;
1670             handle->scratch_t2 = 0;
1671             handle->csp_scratch = 0;
1672             handle->cst_scratch = 0;
1673             handle->ht_scratch = 0;
1674             handle->it_scratch = 0;
1675             handle->ft_scratch = 0;
1676             handle->ot_scratch = 0;
1677             handle->cit_scratch = 0;
1678             handle->cot_scratch = 0;
1679           } break;
1680           default: {
1681             status = LIBXSMM_DNN_ERR_INVALID_KIND;
1682           }
1683         }
1684       } break;
1685       case LIBXSMM_DNN_RNNCELL_GRU: {
1686         switch (kind) {
1687           case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1688             handle->scratch_w   = 0;
1689             handle->scratch_r   = 0;
1690             handle->ht_scratch  = 0;
1691             handle->it_scratch  = 0;
1692             handle->cit_scratch = 0;
1693             handle->ft_scratch  = 0;
1694             handle->ot_scratch  = 0;
1695           } break;
1696           case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1697           case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1698           case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1699           case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1700             handle->scratch_w  = 0;
1701             handle->scratch_r  = 0;
1702             handle->scratch_wT = 0;
1703             handle->scratch_rT = 0;
1704             handle->scratch_xT = 0;
1705             handle->scratch_hT = 0;
1706             handle->scratch_deltat = 0;
1707             handle->scratch_di = 0;
1708             handle->scratch_dci = 0;
1709             handle->scratch_df  = 0;
1710             handle->scratch_do  = 0;
1711             handle->scratch_diB = 0;
1712             handle->scratch_dciB = 0;
1713             handle->scratch_dfB = 0;
1714             handle->scratch_dpB = 0;
1715             handle->scratch_t1  = 0;
1716             handle->scratch_t2  = 0;
1717             handle->ht_scratch  = 0;
1718             handle->it_scratch  = 0;
1719             handle->ft_scratch  = 0;
1720             handle->ot_scratch  = 0;
1721             handle->cit_scratch = 0;
1722           } break;
1723           default: {
1724             status = LIBXSMM_DNN_ERR_INVALID_KIND;
1725           }
1726         }
1727       } break;
1728       default: {
1729         status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE;
1730       }
1731     }
1732   } else {
1733     status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1734   }
1735 
1736   return status;
1737 }
1738 
1739 
libxsmm_dnn_rnncell_get_internalstate_size(const libxsmm_dnn_rnncell * handle,const libxsmm_dnn_compute_kind kind,libxsmm_dnn_err_t * status)1740 LIBXSMM_API size_t libxsmm_dnn_rnncell_get_internalstate_size(const libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, libxsmm_dnn_err_t* status)
1741 {
1742   size_t size = 0;
1743   *status = LIBXSMM_DNN_SUCCESS;
1744 
1745   if (0 != handle) {
1746     const size_t sizeof_datatype = sizeof(float);
1747 
1748     switch (handle->desc.cell_type) {
1749       case LIBXSMM_DNN_RNNCELL_RNN_RELU:
1750       case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID:
1751       case LIBXSMM_DNN_RNNCELL_RNN_TANH: {
1752         switch (kind) {
1753           case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1754             size += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof_datatype * (size_t)handle->desc.max_T + 64; /* zt */
1755           } break;
1756           case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1757           case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1758           case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1759           case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1760             size += (size_t)handle->desc.K * (size_t)handle->desc.N * sizeof_datatype * (size_t)handle->desc.max_T + 64; /* zt */
1761           } break;
1762           default: {
1763             *status = LIBXSMM_DNN_ERR_INVALID_KIND;
1764           }
1765         }
1766       } break;
1767       case LIBXSMM_DNN_RNNCELL_LSTM: {
1768         switch (kind) {
1769           case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1770             /* with i, f, o, ci, co, cs exposed as i/o, there is currently no need for internal state */
1771           } break;
1772           case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1773           case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1774           case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1775           case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1776             /* with i, f, o, ci, co, cs exposed as i/o, there is currently no need for internal state */
1777           } break;
1778           default: {
1779             *status = LIBXSMM_DNN_ERR_INVALID_KIND;
1780           }
1781         }
1782       } break;
1783       case LIBXSMM_DNN_RNNCELL_GRU: {
1784         switch (kind) {
1785           case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1786             /* with i, f, c, o exposed as i/o, there is currently no need for internal state */
1787           } break;
1788           case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1789           case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1790           case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1791           case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1792             /* with i, f, c, o exposed as i/o, there is currently no need for internal state */
1793           } break;
1794           default: {
1795             *status = LIBXSMM_DNN_ERR_INVALID_KIND;
1796           }
1797         }
1798       } break;
1799       default: {
1800         *status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE;
1801       }
1802     }
1803   } else {
1804     *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1805   }
1806 
1807   return size;
1808 }
1809 
1810 
libxsmm_dnn_rnncell_get_internalstate_ptr(const libxsmm_dnn_rnncell * handle,libxsmm_dnn_err_t * status)1811 LIBXSMM_API void* libxsmm_dnn_rnncell_get_internalstate_ptr(const libxsmm_dnn_rnncell* handle, libxsmm_dnn_err_t* status)
1812 {
1813   *status = LIBXSMM_DNN_SUCCESS;
1814 
1815   if (0 != handle) {
1816     return handle->internal_z;
1817   } else {
1818     *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1819   }
1820 
1821   return NULL;
1822 }
1823 
1824 
libxsmm_dnn_rnncell_bind_internalstate(libxsmm_dnn_rnncell * handle,const libxsmm_dnn_compute_kind kind,const void * internalstate)1825 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_bind_internalstate(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind, const void* internalstate)
1826 {
1827   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
1828   uintptr_t address = (uintptr_t)internalstate;
1829   size_t offset = 0;
1830 
1831   if (0 != handle) {
1832     switch (handle->desc.cell_type) {
1833       case LIBXSMM_DNN_RNNCELL_RNN_RELU:
1834       case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID:
1835       case LIBXSMM_DNN_RNNCELL_RNN_TANH: {
1836         if (internalstate == 0) {
1837           status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
1838           return status;
1839         }
1840         switch (kind) {
1841           case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1842             if (address % 64 == 0) {
1843               handle->internal_z = (void*)address;
1844             } else {
1845               offset = (64 - address % 64);
1846               handle->internal_z = (void*)(address+offset);
1847             }
1848           } break;
1849           case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1850           case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1851           case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1852           case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1853             if (address % 64 == 0) {
1854               handle->internal_z = (void*)address;
1855             } else {
1856               offset = (64 - address % 64);
1857               handle->internal_z = (void*)(address+offset);
1858             }
1859           } break;
1860           default: {
1861             status = LIBXSMM_DNN_ERR_INVALID_KIND;
1862           }
1863         }
1864       } break;
1865       case LIBXSMM_DNN_RNNCELL_LSTM: {
1866         switch (kind) {
1867           case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1868           } break;
1869           case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1870           case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1871           case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1872           case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1873           } break;
1874           default: {
1875             status = LIBXSMM_DNN_ERR_INVALID_KIND;
1876           }
1877         }
1878       } break;
1879       case LIBXSMM_DNN_RNNCELL_GRU: {
1880         switch (kind) {
1881           case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1882           } break;
1883           case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1884           case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1885           case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1886           case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1887           } break;
1888           default: {
1889             status = LIBXSMM_DNN_ERR_INVALID_KIND;
1890           }
1891         }
1892       } break;
1893       default: {
1894         status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE;
1895       }
1896     }
1897   } else {
1898     status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1899   }
1900 
1901   return status;
1902 }
1903 
1904 
libxsmm_dnn_rnncell_release_internalstate(libxsmm_dnn_rnncell * handle,const libxsmm_dnn_compute_kind kind)1905 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_release_internalstate(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_compute_kind kind)
1906 {
1907   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
1908 
1909   if (0 != handle) {
1910     switch (handle->desc.cell_type) {
1911       case LIBXSMM_DNN_RNNCELL_RNN_RELU:
1912       case LIBXSMM_DNN_RNNCELL_RNN_SIGMOID:
1913       case LIBXSMM_DNN_RNNCELL_RNN_TANH: {
1914         switch (kind) {
1915           case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1916             handle->internal_z = 0;
1917           } break;
1918           case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1919           case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1920           case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1921           case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1922             handle->internal_z = 0;
1923           } break;
1924           default: {
1925             status = LIBXSMM_DNN_ERR_INVALID_KIND;
1926           }
1927         }
1928       } break;
1929       case LIBXSMM_DNN_RNNCELL_LSTM: {
1930         switch (kind) {
1931           case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1932           } break;
1933           case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1934           case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1935           case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1936           case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1937           } break;
1938           default: {
1939             status = LIBXSMM_DNN_ERR_INVALID_KIND;
1940           }
1941         }
1942       } break;
1943       case LIBXSMM_DNN_RNNCELL_GRU: {
1944         switch (kind) {
1945           case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1946           } break;
1947           case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1948           case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1949           case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD:
1950           case LIBXSMM_DNN_COMPUTE_KIND_ALL: {
1951           } break;
1952           default: {
1953             status = LIBXSMM_DNN_ERR_INVALID_KIND;
1954           }
1955         }
1956       } break;
1957       default: {
1958         status = LIBXSMM_DNN_ERR_INVALID_RNN_TYPE;
1959       }
1960     }
1961   } else {
1962     status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1963   }
1964 
1965   return status;
1966 }
1967 
1968 
libxsmm_dnn_rnncell_allocate_forget_bias(libxsmm_dnn_rnncell * handle,const float forget_bias)1969 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_allocate_forget_bias(libxsmm_dnn_rnncell* handle, const float forget_bias)
1970 {
1971   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
1972 
1973   if (handle != 0) {
1974     handle->forget_bias = forget_bias;
1975   } else {
1976     status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
1977   }
1978 
1979   return status;
1980 }
1981 
1982 
libxsmm_dnn_rnncell_bind_tensor(libxsmm_dnn_rnncell * handle,const libxsmm_dnn_tensor * tensor,const libxsmm_dnn_tensor_type type)1983 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_bind_tensor(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type)
1984 {
1985   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
1986 
1987   /* check for tensor type */
1988   if ( (type != LIBXSMM_DNN_RNN_REGULAR_INPUT)             && (type != LIBXSMM_DNN_RNN_GRADIENT_INPUT)             &&
1989        (type != LIBXSMM_DNN_RNN_REGULAR_CS_PREV)           && (type != LIBXSMM_DNN_RNN_GRADIENT_CS_PREV)           &&
1990        (type != LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV) && (type != LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV) &&
1991        (type != LIBXSMM_DNN_RNN_REGULAR_WEIGHT)            && (type != LIBXSMM_DNN_RNN_GRADIENT_WEIGHT)            &&
1992        (type != LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT)      && (type != LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT)      &&
1993        (type != LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS)      && (type != LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) &&
1994        (type != LIBXSMM_DNN_RNN_REGULAR_BIAS)              && (type != LIBXSMM_DNN_RNN_GRADIENT_BIAS)              &&
1995        (type != LIBXSMM_DNN_RNN_REGULAR_CS)                && (type != LIBXSMM_DNN_RNN_GRADIENT_CS)                &&
1996        (type != LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE)      && (type != LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE)      &&
1997        (type != LIBXSMM_DNN_RNN_INTERNAL_I)                && (type != LIBXSMM_DNN_RNN_INTERNAL_F)                 &&
1998        (type != LIBXSMM_DNN_RNN_INTERNAL_O)                && (type != LIBXSMM_DNN_RNN_INTERNAL_CI)                &&
1999        (type != LIBXSMM_DNN_RNN_INTERNAL_CO) ) {
2000     status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
2001     return status;
2002   }
2003 
2004   if (handle != 0 && tensor != 0) {
2005     libxsmm_dnn_tensor_datalayout* handle_layout = libxsmm_dnn_rnncell_create_tensor_datalayout(handle, type, &status);
2006 
2007     if ( libxsmm_dnn_compare_tensor_datalayout(handle_layout, tensor->layout, &status) == 0 ) {
2008       if ( type == LIBXSMM_DNN_RNN_REGULAR_INPUT ) {
2009         handle->xt = (libxsmm_dnn_tensor*)tensor;
2010       } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_INPUT ) {
2011         handle->dxt = (libxsmm_dnn_tensor*)tensor;
2012       } else if ( type == LIBXSMM_DNN_RNN_REGULAR_CS_PREV ) {
2013         handle->csp = (libxsmm_dnn_tensor*)tensor;
2014       } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_CS_PREV ) {
2015         handle->dcsp = (libxsmm_dnn_tensor*)tensor;
2016       } else if ( type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV ) {
2017         handle->hp = (libxsmm_dnn_tensor*)tensor;
2018       } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV ) {
2019         handle->dhp = (libxsmm_dnn_tensor*)tensor;
2020       } else if ( type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT ) {
2021         handle->w = (libxsmm_dnn_tensor*)tensor;
2022       } else if ( type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS ) {
2023         handle->wt = (libxsmm_dnn_tensor*)tensor;
2024       } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT ) {
2025         handle->dw = (libxsmm_dnn_tensor*)tensor;
2026       } else if ( type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT ) {
2027         handle->r = (libxsmm_dnn_tensor*)tensor;
2028       } else if ( type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS ) {
2029         handle->rt = (libxsmm_dnn_tensor*)tensor;
2030       } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT ) {
2031         handle->dr = (libxsmm_dnn_tensor*)tensor;
2032       } else if ( type == LIBXSMM_DNN_RNN_REGULAR_BIAS ) {
2033         handle->b = (libxsmm_dnn_tensor*)tensor;
2034       } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_BIAS ) {
2035         handle->db = (libxsmm_dnn_tensor*)tensor;
2036       } else if ( type == LIBXSMM_DNN_RNN_REGULAR_CS ) {
2037         handle->cst = (libxsmm_dnn_tensor*)tensor;
2038       } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_CS ) {
2039         handle->dcs = (libxsmm_dnn_tensor*)tensor;
2040       } else if ( type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE ) {
2041         handle->ht = (libxsmm_dnn_tensor*)tensor;
2042       } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE ) {
2043         handle->dht = (libxsmm_dnn_tensor*)tensor;
2044       } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_I ) {
2045         handle->it = (libxsmm_dnn_tensor*)tensor;
2046       } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_F ) {
2047         handle->ft = (libxsmm_dnn_tensor*)tensor;
2048       } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_O ) {
2049         handle->ot = (libxsmm_dnn_tensor*)tensor;
2050       } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_CI ) {
2051         handle->cit = (libxsmm_dnn_tensor*)tensor;
2052       } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_CO ) {
2053         handle->cot = (libxsmm_dnn_tensor*)tensor;
2054       } else {
2055         /* cannot happen */
2056       }
2057     } else {
2058       status = LIBXSMM_DNN_ERR_MISMATCH_TENSOR;
2059     }
2060 
2061     libxsmm_dnn_destroy_tensor_datalayout( handle_layout );
2062   }
2063   else {
2064     status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
2065   }
2066 
2067   return status;
2068 }
2069 
2070 
libxsmm_dnn_rnncell_get_tensor(libxsmm_dnn_rnncell * handle,const libxsmm_dnn_tensor_type type,libxsmm_dnn_err_t * status)2071 LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_rnncell_get_tensor(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status)
2072 {
2073   libxsmm_dnn_tensor* tensor = 0;
2074   LIBXSMM_UNUSED(status/*TODO*/);
2075 
2076   /* check for tensor type */
2077   if ( (type != LIBXSMM_DNN_RNN_REGULAR_INPUT)             && (type != LIBXSMM_DNN_RNN_GRADIENT_INPUT)             &&
2078        (type != LIBXSMM_DNN_RNN_REGULAR_CS_PREV)           && (type != LIBXSMM_DNN_RNN_GRADIENT_CS_PREV)           &&
2079        (type != LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV) && (type != LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV) &&
2080        (type != LIBXSMM_DNN_RNN_REGULAR_WEIGHT)            && (type != LIBXSMM_DNN_RNN_GRADIENT_WEIGHT)            &&
2081        (type != LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT)      && (type != LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT)      &&
2082        (type != LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS)      && (type != LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) &&
2083        (type != LIBXSMM_DNN_RNN_REGULAR_BIAS)              && (type != LIBXSMM_DNN_RNN_GRADIENT_BIAS)              &&
2084        (type != LIBXSMM_DNN_RNN_REGULAR_CS)                && (type != LIBXSMM_DNN_RNN_GRADIENT_CS)                &&
2085        (type != LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE)      && (type != LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE)      &&
2086        (type != LIBXSMM_DNN_RNN_INTERNAL_I)                && (type != LIBXSMM_DNN_RNN_INTERNAL_F)                 &&
2087        (type != LIBXSMM_DNN_RNN_INTERNAL_O)                && (type != LIBXSMM_DNN_RNN_INTERNAL_CI)                &&
2088        (type != LIBXSMM_DNN_RNN_INTERNAL_CO) ) {
2089     return tensor;
2090   }
2091 
2092   if (handle != 0) {
2093     if ( type == LIBXSMM_DNN_RNN_REGULAR_INPUT ) {
2094       tensor = handle->xt;
2095     } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_INPUT ) {
2096       tensor = handle->dxt;
2097     } else if ( type == LIBXSMM_DNN_RNN_REGULAR_CS_PREV ) {
2098       tensor = handle->csp;
2099     } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_CS_PREV ) {
2100       tensor = handle->dcsp;
2101     } else if ( type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV ) {
2102       tensor = handle->hp;
2103     } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV ) {
2104       tensor = handle->dhp;
2105     } else if ( type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT ) {
2106       tensor = handle->w;
2107     } else if ( type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS ) {
2108       tensor = handle->wt;
2109     } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT ) {
2110       tensor = handle->dw;
2111     } else if ( type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT ) {
2112       tensor = handle->r;
2113     } else if ( type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS ) {
2114       tensor = handle->rt;
2115     } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT ) {
2116       tensor = handle->dr;
2117     } else if ( type == LIBXSMM_DNN_RNN_REGULAR_BIAS ) {
2118       tensor = handle->b;
2119     } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_BIAS ) {
2120       tensor = handle->db;
2121     } else if ( type == LIBXSMM_DNN_RNN_REGULAR_CS ) {
2122       tensor = handle->cst;
2123     } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_CS ) {
2124       tensor = handle->dcs;
2125     } else if ( type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE ) {
2126       tensor = handle->ht;
2127     } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE ) {
2128       tensor = handle->dht;
2129     } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_I ) {
2130       tensor = handle->it;
2131     } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_F ) {
2132       tensor = handle->ft;
2133     } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_O ) {
2134       tensor = handle->ot;
2135     } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_CI ) {
2136       tensor = handle->cit;
2137     } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_CO ) {
2138       tensor = handle->cot;
2139     } else {
2140       /* cannot happen */
2141     }
2142   }
2143 
2144   return tensor;
2145 }
2146 
2147 
libxsmm_dnn_rnncell_release_tensor(libxsmm_dnn_rnncell * handle,const libxsmm_dnn_tensor_type type)2148 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_release_tensor(libxsmm_dnn_rnncell* handle, const libxsmm_dnn_tensor_type type)
2149 {
2150   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
2151 
2152   /* check for tensor type */
2153   if ( (type != LIBXSMM_DNN_RNN_REGULAR_INPUT)             && (type != LIBXSMM_DNN_RNN_GRADIENT_INPUT)             &&
2154        (type != LIBXSMM_DNN_RNN_REGULAR_CS_PREV)           && (type != LIBXSMM_DNN_RNN_GRADIENT_CS_PREV)           &&
2155        (type != LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV) && (type != LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV) &&
2156        (type != LIBXSMM_DNN_RNN_REGULAR_WEIGHT)            && (type != LIBXSMM_DNN_RNN_GRADIENT_WEIGHT)            &&
2157        (type != LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT)      && (type != LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT)      &&
2158        (type != LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS)      && (type != LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS) &&
2159        (type != LIBXSMM_DNN_RNN_REGULAR_BIAS)              && (type != LIBXSMM_DNN_RNN_GRADIENT_BIAS)              &&
2160        (type != LIBXSMM_DNN_RNN_REGULAR_CS)                && (type != LIBXSMM_DNN_RNN_GRADIENT_CS)                &&
2161        (type != LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE)      && (type != LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE)      &&
2162        (type != LIBXSMM_DNN_RNN_INTERNAL_I)                && (type != LIBXSMM_DNN_RNN_INTERNAL_F)                 &&
2163        (type != LIBXSMM_DNN_RNN_INTERNAL_O)                && (type != LIBXSMM_DNN_RNN_INTERNAL_CI)                &&
2164        (type != LIBXSMM_DNN_RNN_INTERNAL_CO) ) {
2165     status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
2166     return status;
2167   }
2168 
2169   if (handle != 0) {
2170     if ( type == LIBXSMM_DNN_RNN_REGULAR_INPUT ) {
2171       handle->xt = 0;
2172     } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_INPUT ) {
2173       handle->dxt = 0;
2174     } else if ( type == LIBXSMM_DNN_RNN_REGULAR_CS_PREV ) {
2175       handle->csp = 0;
2176     } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_CS_PREV ) {
2177       handle->dcsp = 0;
2178     } else if ( type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV ) {
2179       handle->hp = 0;
2180     } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV ) {
2181       handle->dhp = 0;
2182     } else if ( type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT ) {
2183       handle->w = 0;
2184     } else if ( type == LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS ) {
2185       handle->wt = 0;
2186     } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_WEIGHT ) {
2187       handle->dw = 0;
2188     } else if ( type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT ) {
2189       handle->r = 0;
2190     } else if ( type == LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS ) {
2191       handle->rt = 0;
2192     } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT ) {
2193       handle->dr = 0;
2194     } else if ( type == LIBXSMM_DNN_RNN_REGULAR_BIAS ) {
2195       handle->b = 0;
2196     } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_BIAS ) {
2197       handle->db = 0;
2198     } else if ( type == LIBXSMM_DNN_RNN_REGULAR_CS ) {
2199       handle->cst = 0;
2200     } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_CS ) {
2201       handle->dcs = 0;
2202     } else if ( type == LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE ) {
2203       handle->ht = 0;
2204     } else if ( type == LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE ) {
2205       handle->dht = 0;
2206     } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_I ) {
2207       handle->it = 0;
2208     } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_F ) {
2209       handle->ft = 0;
2210     } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_O ) {
2211       handle->ot = 0;
2212     } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_CI ) {
2213       handle->cit = 0;
2214     } else if ( type == LIBXSMM_DNN_RNN_INTERNAL_CO ) {
2215       handle->cot = 0;
2216     } else {
2217       /* cannot happen */
2218     }
2219   }
2220   else {
2221     status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
2222   }
2223 
2224   return status;
2225 }
2226 
2227 
libxsmm_dnn_rnncell_set_sequence_length(libxsmm_dnn_rnncell * handle,const libxsmm_blasint T)2228 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_set_sequence_length( libxsmm_dnn_rnncell* handle, const libxsmm_blasint T ) {
2229   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
2230 
2231   if (0 != handle) {
2232     if ( handle->desc.max_T < T ) {
2233       status = LIBXSMM_DNN_ERR_RNN_INVALID_SEQ_LEN;
2234     } else {
2235       handle->T = T;
2236     }
2237   } else {
2238     status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
2239   }
2240 
2241   return status;
2242 }
2243 
2244 
libxsmm_dnn_rnncell_get_sequence_length(libxsmm_dnn_rnncell * handle,libxsmm_dnn_err_t * status)2245 LIBXSMM_API libxsmm_blasint libxsmm_dnn_rnncell_get_sequence_length( libxsmm_dnn_rnncell* handle, libxsmm_dnn_err_t* status ) {
2246   *status = LIBXSMM_DNN_SUCCESS;
2247 
2248   if (0 != handle) {
2249     return handle->T;
2250   } else {
2251     *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
2252   }
2253 
2254   return 0;
2255 }
2256 
2257 
libxsmm_dnn_rnncell_execute_st(libxsmm_dnn_rnncell * handle,libxsmm_dnn_compute_kind kind,int start_thread,int tid)2258 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_rnncell_execute_st(libxsmm_dnn_rnncell* handle, libxsmm_dnn_compute_kind kind,
2259   /*unsigned*/int start_thread, /*unsigned*/int tid)
2260 {
2261   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
2262 
2263   if (0 != handle) {
2264     switch (kind) {
2265       case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
2266         if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NC) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CK) ) {
2267           status = libxsmm_dnn_rnncell_st_fwd_nc_ck( handle, start_thread, tid );
2268         } else if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED)  ) {
2269           status = libxsmm_dnn_rnncell_st_fwd_ncnc_kcck( handle, start_thread, tid );
2270         } else if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NC) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED)  ) {
2271           status = libxsmm_dnn_rnncell_st_fwd_nc_kcck( handle, start_thread, tid );
2272         } else {
2273           status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
2274         }
2275       } break;
2276       case LIBXSMM_DNN_COMPUTE_KIND_BWD:
2277       case LIBXSMM_DNN_COMPUTE_KIND_UPD:
2278       case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: {
2279         if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NC) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CK) ) {
2280           status = libxsmm_dnn_rnncell_st_bwdupd_nc_ck( handle, kind, start_thread, tid );
2281         } else if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NC) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED)  ) {
2282           status = libxsmm_dnn_rnncell_st_bwdupd_nc_kcck( handle, kind, start_thread, tid );
2283         } else {
2284           status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
2285         }
2286       } break;
2287       default: {
2288         status = LIBXSMM_DNN_ERR_INVALID_KIND;
2289       }
2290     }
2291   } else {
2292     status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
2293   }
2294 
2295   return status;
2296 }
2297 
2298