1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Kunal Banerjee (Intel Corp.)
10 ******************************************************************************/
11 #include <libxsmm.h>
12 #include <libxsmm_intrinsics_x86.h>
13 
14 #if defined(LIBXSMM_OFFLOAD_TARGET)
15 # pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
16 #endif
17 #include <stdlib.h>
18 #include <string.h>
19 #include <stdio.h>
20 #include <math.h>
21 #if defined(_OPENMP)
22 # include <omp.h>
23 #endif
24 #if defined(LIBXSMM_OFFLOAD_TARGET)
25 # pragma offload_attribute(pop)
26 #endif
27 
28 /* include c-based dnn library */
29 #include "../common/dnn_common.h"
30 
/* Evaluate the LIBXSMM DNN status expression A exactly once. If it is not
 * LIBXSMM_DNN_SUCCESS, print the library's error text to stderr and store the
 * code in `global_status`, which must be an lvalue in scope at the use site.
 * A failure is logged and recorded but is NOT fatal: execution continues.
 * The do { } while (0) wrapper makes `CHKERR_LIBXSMM_DNN(x);` a single
 * statement, so it is safe in an unbraced if/else (a bare { } block followed
 * by ';' would detach a following `else`). */
#define CHKERR_LIBXSMM_DNN(A) do { \
  const int chkerr_libxsmm_dnn_ = (A); \
  if (LIBXSMM_DNN_SUCCESS != chkerr_libxsmm_dnn_) { \
    fprintf(stderr, "%s\n", libxsmm_dnn_get_error(chkerr_libxsmm_dnn_)); \
    global_status = chkerr_libxsmm_dnn_; \
  } \
} while (0)
34 
main(int argc,char * argv[])35 int main(int argc, char* argv[])
36 {
37   /* Arrays related to FWD pass */
38   float *wgold, *xgoldt, *ugold, *hpgold, *hgoldt, *z1gold, *z2gold, *zgoldt, *bgold, *bmgold;
39   float *w, *xt, *u, *hp, *ht, *htest, *b;
40   /* Arrays related to BWD and UPD pass */
41   float *djdhgoldt, *deltagoldt, *djdugold, *djdwgold, *djdxgoldt, *djdbgold;
42   float *zigold, *di1gold, *di2gold, *ugoldTp, *wgoldTp, *hgoldTp, *xgoldTp;
43   float *djdht, *djdu, *djdw, *djdxt, *djdb, *djdxtestt, *djdwtest, *djdutest;
44 
45   const char transa = 'N', transb = 'N'; /* no transposes */
46   const float alpha = 1, beta = 1, beta0 = 0;
47   void *scratch, *internalstate;
48   size_t scratch_size = 0, internalstate_size = 0;
49 
50   int iters = 10; /* repetitions of benchmark */
51   int pass = 0;   /* pass: 0--FWD, 1--BWD, 2--UPD, 3--BWD+UPD */
52   int nonlin = 2; /* nonlin=1 denotes ReLU, 2 denotes sigmoid, 3 denotes tanh */
53   int N = 168;    /* size of mini-batch */
54   int C = 512;    /* number of inputs */
55   int K = 256;    /* number of outputs */
56   int t = 4;      /* number of time steps (>= 1) */
57   int bn = 24;
58   int bc = 64;
59   int bk = 64;
60 
61   const char *const env_check = getenv("CHECK");
62   const double check = LIBXSMM_ABS(0 == env_check ? 1/*enable by default*/ : atof(env_check));
63 
64 #if defined(_OPENMP)
65   int nThreads = omp_get_max_threads(); /* number of threads */
66 #else
67   int nThreads = 1; /* number of threads */
68 #endif
69 
70   unsigned long long l_start, l_end;
71   double l_total = 0.0;
72   double flops = 0.0, tempflops = 0.0;
73   const double tflops = 12; /* transcendental flops */
74   int i, j, it;
75 
76   libxsmm_dnn_rnncell_desc rnncell_desc;
77   libxsmm_dnn_rnncell* libxsmm_handle;
78   libxsmm_dnn_tensor* libxsmm_input;
79   libxsmm_dnn_tensor* libxsmm_hidden_state_prev;
80   libxsmm_dnn_tensor* libxsmm_weight;
81   libxsmm_dnn_tensor* libxsmm_recur_weight;
82   libxsmm_dnn_tensor* libxsmm_bias;
83   libxsmm_dnn_tensor* libxsmm_hidden_state;
84   libxsmm_dnn_tensor* libxsmm_dinput;
85   libxsmm_dnn_tensor* libxsmm_dweight;
86   libxsmm_dnn_tensor* libxsmm_drecur_weight;
87   libxsmm_dnn_tensor* libxsmm_dbias;
88   libxsmm_dnn_tensor* libxsmm_dhidden_state;
89 
90   libxsmm_dnn_tensor_datalayout* libxsmm_layout;
91   libxsmm_dnn_err_t status;
92   libxsmm_dnn_err_t global_status = LIBXSMM_DNN_SUCCESS;
93 
94   libxsmm_matdiff_info norms_fwd, norms_bwd, norms_upd_w, norms_upd_u, norms_upd_b, diff;
95   libxsmm_matdiff_clear(&norms_fwd);
96   libxsmm_matdiff_clear(&norms_bwd);
97   libxsmm_matdiff_clear(&norms_upd_w);
98   libxsmm_matdiff_clear(&norms_upd_u);
99   libxsmm_matdiff_clear(&norms_upd_b);
100   libxsmm_matdiff_clear(&diff);
101 
102   if (argc > 1 && !strncmp(argv[1], "-h", 3)) {
103     printf("\nUsage: ./rnndriver [reps] [pass: 0--FWD, 1--BWD, 2--UPD, 3--BWD+UPD] [nonlin: 1--ReLU, 2--sigmoid, 3--tanh] [N] [C] [K] [time_steps > 0]\n\n");
104     return 0;
105   }
106   libxsmm_rng_set_seed(1);
107 
108   /* reading new values from cli */
109   i = 1;
110   if (argc > i) iters = atoi(argv[i++]);
111   if (argc > i) pass  = atoi(argv[i++]);
112   if (argc > i) nonlin= atoi(argv[i++]);
113   if (argc > i) N     = atoi(argv[i++]);
114   if (argc > i) C     = atoi(argv[i++]);
115   if (argc > i) K     = atoi(argv[i++]);
116   if (argc > i) t     = atoi(argv[i++]);
117   if (argc > i) bn    = atoi(argv[i++]);
118   if (argc > i) bc    = atoi(argv[i++]);
119   if (argc > i) bk    = atoi(argv[i++]);
120 
121   if (t <= 0) {
122     printf("time_steps %d should be greater than 0\n\n", t);
123     return 0;
124   }
125   if (!(pass == 0 || pass == 1 || pass == 2 || pass == 3 || pass == 4)) {
126     printf("Unknown pass: %d, valid arguments for pass = {0(FWD), 1(BWD), 2(UPD), 3(BWD+UPD)\n\n", pass);
127     return 0;
128   }
129   if (nonlin != 1 && nonlin != 2 && nonlin != 3) {
130     printf("Unsupported non-linear function used [1--ReLU, 2--sigmoid, 3--tanh]\n\n");
131     return 0;
132   }
133 
134 #if defined(__SSE3__)
135   _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
136   _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
137   _MM_SET_ROUNDING_MODE(_MM_ROUND_NEAREST);
138 #endif
139 
140   /* print some summary */
141   printf("##########################################\n");
142   printf("#          Setting Up (Common)           #\n");
143   printf("##########################################\n");
144   printf("PARAMS: N:%d  C:%d  K:%d  T:%d\n", N, C, K, t);
145   printf("PARAMS: ITERS:%d", iters); if (LIBXSMM_FEQ(0, check)) printf("  Threads:%d\n", nThreads); else printf("\n");
146   printf("SIZE Weight (MB): %10.2f MiB\n", (double)(C*K*sizeof(float))/(1024.0*1024.0) );
147   printf("SIZE Input (MB): %10.2f MiB\n", (double)(N*C*sizeof(float))/(1024.0*1024.0) );
148   printf("SIZE Hidden State: %10.2f MiB\n", (double)(K*N*sizeof(float))/(1024.0*1024.0) );
149 
150   /* allocate data */
151   xgoldt = (float*)libxsmm_aligned_malloc(N*C*t*sizeof(float), 2097152);
152   hpgold = (float*)libxsmm_aligned_malloc(K*N*sizeof(float), 2097152);
153   wgold  = (float*)libxsmm_aligned_malloc(C*K*sizeof(float), 2097152);
154   ugold  = (float*)libxsmm_aligned_malloc(K*K*sizeof(float), 2097152);
155   bgold  = (float*)libxsmm_aligned_malloc(K*sizeof(float), 2097152);
156   hgoldt = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
157   zgoldt = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
158   bmgold = (float*)libxsmm_aligned_malloc(K*N*sizeof(float), 2097152);
159   z1gold = (float*)libxsmm_aligned_malloc(K*N*sizeof(float), 2097152);
160   z2gold = (float*)libxsmm_aligned_malloc(K*N*sizeof(float), 2097152);
161   djdxgoldt  = (float*)libxsmm_aligned_malloc(N*C*t*sizeof(float), 2097152);
162   djdwgold   = (float*)libxsmm_aligned_malloc(C*K*sizeof(float), 2097152);
163   djdugold   = (float*)libxsmm_aligned_malloc(K*K*sizeof(float), 2097152);
164   djdbgold   = (float*)libxsmm_aligned_malloc(K*sizeof(float), 2097152);
165   djdhgoldt  = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
166   deltagoldt = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
167   zigold     = (float*)libxsmm_aligned_malloc(K*N*sizeof(float), 2097152);
168   di1gold    = (float*)libxsmm_aligned_malloc(K*N*sizeof(float), 2097152);
169   di2gold    = (float*)libxsmm_aligned_malloc(K*N*sizeof(float), 2097152);
170   xgoldTp    = (float*)libxsmm_aligned_malloc(N*C*sizeof(float), 2097152);
171   wgoldTp    = (float*)libxsmm_aligned_malloc(C*K*sizeof(float), 2097152);
172   ugoldTp    = (float*)libxsmm_aligned_malloc(K*K*sizeof(float), 2097152);
173   hgoldTp    = (float*)libxsmm_aligned_malloc(K*N*sizeof(float), 2097152);
174   xt     = (float*)libxsmm_aligned_malloc(N*C*t*sizeof(float), 2097152);
175   hp     = (float*)libxsmm_aligned_malloc(K*N*sizeof(float), 2097152);
176   w      = (float*)libxsmm_aligned_malloc(C*K*sizeof(float), 2097152);
177   u      = (float*)libxsmm_aligned_malloc(K*K*sizeof(float), 2097152);
178   ht     = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
179   b      = (float*)libxsmm_aligned_malloc(K*sizeof(float), 2097152);
180   djdxt  = (float*)libxsmm_aligned_malloc(N*C*t*sizeof(float), 2097152);
181   djdw   = (float*)libxsmm_aligned_malloc(C*K*sizeof(float), 2097152);
182   djdu   = (float*)libxsmm_aligned_malloc(K*K*sizeof(float), 2097152);
183   djdb   = (float*)libxsmm_aligned_malloc(K*sizeof(float), 2097152);
184   djdht  = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
185   htest  = (float*)libxsmm_aligned_malloc(K*N*sizeof(float), 2097152);
186   djdxtestt = (float*)libxsmm_aligned_malloc(N*C*t*sizeof(float), 2097152);
187   djdwtest  = (float*)libxsmm_aligned_malloc(C*K*sizeof(float), 2097152);
188   djdutest  = (float*)libxsmm_aligned_malloc(K*K*sizeof(float), 2097152);
189   LIBXSMM_VLA_DECL(2, float, xgold, xgoldt, N*C);
190   LIBXSMM_VLA_DECL(2, float, hgold, hgoldt, K*N);
191   LIBXSMM_VLA_DECL(2, float, zgold, zgoldt, K*N);
192   LIBXSMM_VLA_DECL(2, float, djdxgold, djdxgoldt, N*C);
193   LIBXSMM_VLA_DECL(2, float, djdhgold, djdhgoldt, K*N);
194   LIBXSMM_VLA_DECL(2, float, deltagold, deltagoldt, K*N);
195 
196   /* initialize data */
197   /* All data in gold is considered to be in column-major format */
198   for (it = 0; it < t; ++it) {
199     init_buf(&LIBXSMM_VLA_ACCESS(2, xgold, it, 0, N*C), N*C, 0, 0);
200   }
201   init_buf(hpgold, N*K, 0, 0);
202   init_buf(wgold,  C*K, 0, 0);
203   init_buf(ugold,  K*K, 0, 0);
204   init_buf(bgold,  K,   0, 0);
205   for (j = 0; j < N; j++) {
206     matrix_copy(K, bgold, &(bmgold[j*K]));
207   }
208   zero_buf(hgoldt, K*N*t);
209   zero_buf(zgoldt, K*N*t);
210   zero_buf(z1gold, K*N);
211   zero_buf(z2gold, K*N);
212   for (it = 0; it < t; ++it) {
213     init_buf(&LIBXSMM_VLA_ACCESS(2, djdhgold, it, 0, K*N), N*K, 0, 0);
214   }
215   zero_buf(djdxgoldt, N*C*t);
216   zero_buf(djdwgold, C*K);
217   zero_buf(djdugold, K*K);
218   zero_buf(djdbgold, K);
219   zero_buf(deltagoldt, K*N*t);
220   zero_buf(zigold, K*N);
221   zero_buf(di1gold, K*N);
222   zero_buf(di2gold, K*N);
223   zero_buf(xgoldTp, N*C);
224   zero_buf(ugoldTp, K*K);
225   zero_buf(wgoldTp, C*K);
226   zero_buf(hgoldTp, K*N);
227 
228   /* first touch LIBXSMM */
229   zero_buf(xt, N*C*t);
230   zero_buf(hp, K*N);
231   zero_buf(w,  C*K);
232   zero_buf(u,  K*K);
233   zero_buf(b,  K);
234   zero_buf(ht, K*N*t);
235   zero_buf(djdxt,N*C*t);
236   zero_buf(djdw, C*K);
237   zero_buf(djdu, K*K);
238   zero_buf(djdb, K);
239   zero_buf(djdht, K*N*t);
240   LIBXSMM_VLA_DECL(2, float, h, ht, K*N);
241 
242   if (LIBXSMM_NEQ(0, check)) {
243     printf("##########################################\n");
244     printf("#         Computing Reference ...        #\n");
245     printf("##########################################\n");
246     for (i = 0; i < t; ++i) {
247       LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &N, &C, &alpha, wgold, &K, &LIBXSMM_VLA_ACCESS(2, xgold, i, 0, N*C), &C, &beta0, z1gold, &K);
248       if (0 == i) {
249         LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &N, &K, &alpha, ugold, &K, hpgold, &K, &beta0, z2gold, &K);
250       } else {
251         LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &N, &K, &alpha, ugold, &K, &LIBXSMM_VLA_ACCESS(2, hgold, i-1, 0, K*N), &K, &beta0, z2gold, &K);
252       }
253       matrix_add(K*N, z1gold, z2gold, &LIBXSMM_VLA_ACCESS(2, zgold, i, 0, K*N));
254       matrix_add(K*N, &LIBXSMM_VLA_ACCESS(2, zgold, i, 0, K*N), bmgold, &LIBXSMM_VLA_ACCESS(2, zgold, i, 0, K*N));
255       if (1 == nonlin) {
256         matrix_relu(K*N, &LIBXSMM_VLA_ACCESS(2, zgold, i, 0, K*N), &LIBXSMM_VLA_ACCESS(2, hgold, i, 0, K*N));
257       } else if (2 == nonlin) {
258         matrix_sigmoid(K*N, &LIBXSMM_VLA_ACCESS(2, zgold, i, 0, K*N), &LIBXSMM_VLA_ACCESS(2, hgold, i, 0, K*N));
259       } else {
260         matrix_tanh(K*N, &LIBXSMM_VLA_ACCESS(2, zgold, i, 0, K*N), &LIBXSMM_VLA_ACCESS(2, hgold, i, 0, K*N));
261       }
262     }
263     /* Conceptually, delta iterates over 0 ... t-1, whereas, djdh and z iterates over 1 ... t */
264     /* Hence these have identical array indices */
265     if (1 == nonlin) {
266       matrix_relu_inverse(K*N, &LIBXSMM_VLA_ACCESS(2, zgold, t-1, 0, K*N), zigold);
267     } else if (2 == nonlin) {
268       matrix_sigmoid_inverse(K*N, &LIBXSMM_VLA_ACCESS(2, zgold, t-1, 0, K*N), zigold);
269     } else {
270       matrix_tanh_inverse(K*N, &LIBXSMM_VLA_ACCESS(2, zgold, t-1, 0, K*N), zigold);
271     }
272     matrix_eltwise_mult(K*N, zigold, &LIBXSMM_VLA_ACCESS(2, djdhgold, t-1, 0, K*N), &LIBXSMM_VLA_ACCESS(2, deltagold, t-1, 0, K*N));
273     matrix_transpose(K, K, ugold, ugoldTp);
274     for (i = t-2; i >= 0; --i) {
275       if (1 == nonlin) {
276         matrix_relu_inverse(K*N, &LIBXSMM_VLA_ACCESS(2, zgold, i, 0, K*N), zigold);
277       } else if (2 == nonlin) {
278         matrix_sigmoid_inverse(K*N, &LIBXSMM_VLA_ACCESS(2, zgold, i, 0, K*N), zigold);
279       } else {
280         matrix_tanh_inverse(K*N, &LIBXSMM_VLA_ACCESS(2, zgold, i, 0, K*N), zigold);
281       }
282       LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &N, &K, &alpha, ugoldTp, &K, &LIBXSMM_VLA_ACCESS(2, deltagold, i+1, 0, K*N), &K, &beta0, di1gold, &K);
283       matrix_add(K*N, &LIBXSMM_VLA_ACCESS(2, djdhgold, i, 0, K*N), di1gold, di2gold);
284       matrix_eltwise_mult(K*N, zigold, di2gold, &LIBXSMM_VLA_ACCESS(2, deltagold, i, 0, K*N));
285     }
286     if (pass == 1 || pass == 3) {
287       matrix_transpose(C, K, wgold, wgoldTp);
288       for (i = 0; i < t; ++i) {
289         LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &C, &N, &K, &alpha, wgoldTp, &C, &LIBXSMM_VLA_ACCESS(2, deltagold, i, 0, K*N), &K, &beta0, &LIBXSMM_VLA_ACCESS(2, djdxgold, i, 0, N*C), &C);
290       }
291     }
292     if (pass == 2 || pass == 3) {
293       for (i = 0; i < t; ++i) {
294         if (0 == i) {
295           matrix_transpose(N, K, hpgold, hgoldTp);
296         } else {
297           matrix_transpose(N, K, &LIBXSMM_VLA_ACCESS(2, hgold, i-1, 0, K*N), hgoldTp);
298         }
299         LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &K, &N, &alpha, &LIBXSMM_VLA_ACCESS(2, deltagold, i, 0, K*N), &K, hgoldTp, &N, &beta, djdugold, &K);
300         matrix_transpose(N, C, &LIBXSMM_VLA_ACCESS(2, xgold, i, 0, N*C), xgoldTp);
301         LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &C, &N, &alpha, &LIBXSMM_VLA_ACCESS(2, deltagold, i, 0, K*N), &K, xgoldTp, &N, &beta, djdwgold, &K);
302         for (j = 0; j < K*N; j++) {
303           djdbgold[j%K] += LIBXSMM_VLA_ACCESS(2, deltagold, i, j, K*N);
304         }
305       }
306     }
307     printf("##########################################\n");
308     printf("#      Computing Reference ... done      #\n");
309     printf("##########################################\n");
310   }
311 
312   if (1 /* format == 'A' || format == 'L' */) {
313     printf("\n");
314     printf("##########################################\n");
315     printf("#      Setting Up  (custom-Storage)      #\n");
316     printf("##########################################\n");
317 
318     if ( N % bn != 0 ) {
319       bn = N;
320     }
321     if ( C % bc != 0 ) {
322       bc = C;
323     }
324     if ( K % bk != 0 ) {
325       bk = K;
326     }
327 
328     /* setup LIBXSMM handle */
329     rnncell_desc.threads = nThreads;
330     rnncell_desc.N = N;
331     rnncell_desc.C = C;
332     rnncell_desc.K = K;
333     rnncell_desc.bn = bn;
334     rnncell_desc.bk = bk;
335     rnncell_desc.bc = bc;
336     rnncell_desc.max_T = t;
337 
338     if ( nonlin == 1 ) {
339       rnncell_desc.cell_type = LIBXSMM_DNN_RNNCELL_RNN_RELU;
340     } else if ( nonlin == 2 ) {
341       rnncell_desc.cell_type = LIBXSMM_DNN_RNNCELL_RNN_SIGMOID;
342     } else if ( nonlin == 3 ) {
343       rnncell_desc.cell_type = LIBXSMM_DNN_RNNCELL_RNN_TANH;
344     } else {
345       /* should not happen */
346     }
347     rnncell_desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
348     rnncell_desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
349     rnncell_desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NC;
350     rnncell_desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_CK;
351 
352     libxsmm_handle = libxsmm_dnn_create_rnncell( rnncell_desc, &status );
353     CHKERR_LIBXSMM_DNN( status );
354 
355     /* setup LIBXSMM buffers and filter */
356     libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_INPUT, &status ); CHKERR_LIBXSMM_DNN( status );
357     libxsmm_input = libxsmm_dnn_link_tensor( libxsmm_layout, xt, &status ); CHKERR_LIBXSMM_DNN( status );
358     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
359 
360     libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV, &status ); CHKERR_LIBXSMM_DNN( status );
361     libxsmm_hidden_state_prev = libxsmm_dnn_link_tensor( libxsmm_layout, hp, &status ); CHKERR_LIBXSMM_DNN( status );
362     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
363 
364     libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_WEIGHT, &status ); CHKERR_LIBXSMM_DNN( status );
365     libxsmm_weight = libxsmm_dnn_link_tensor( libxsmm_layout, w, &status ); CHKERR_LIBXSMM_DNN( status );
366     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
367 
368     libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT, &status ); CHKERR_LIBXSMM_DNN( status );
369     libxsmm_recur_weight = libxsmm_dnn_link_tensor( libxsmm_layout, u, &status ); CHKERR_LIBXSMM_DNN( status );
370     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
371 
372     libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_BIAS, &status ); CHKERR_LIBXSMM_DNN( status );
373     libxsmm_bias = libxsmm_dnn_link_tensor( libxsmm_layout, b, &status ); CHKERR_LIBXSMM_DNN( status );
374     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
375 
376     libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE, &status ); CHKERR_LIBXSMM_DNN( status );
377     libxsmm_hidden_state = libxsmm_dnn_link_tensor( libxsmm_layout, ht, &status ); CHKERR_LIBXSMM_DNN( status );
378     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
379 
380     libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_INPUT, &status ); CHKERR_LIBXSMM_DNN( status );
381     libxsmm_dinput = libxsmm_dnn_link_tensor( libxsmm_layout, djdxt, &status ); CHKERR_LIBXSMM_DNN( status );
382     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
383 
384     libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_WEIGHT, &status ); CHKERR_LIBXSMM_DNN( status );
385     libxsmm_dweight = libxsmm_dnn_link_tensor( libxsmm_layout, djdw, &status ); CHKERR_LIBXSMM_DNN( status );
386     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
387 
388     libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT, &status ); CHKERR_LIBXSMM_DNN( status );
389     libxsmm_drecur_weight = libxsmm_dnn_link_tensor( libxsmm_layout, djdu, &status ); CHKERR_LIBXSMM_DNN( status );
390     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
391 
392     libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_BIAS, &status ); CHKERR_LIBXSMM_DNN( status );
393     libxsmm_dbias = libxsmm_dnn_link_tensor( libxsmm_layout, djdb, &status ); CHKERR_LIBXSMM_DNN( status );
394     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
395 
396     libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE, &status ); CHKERR_LIBXSMM_DNN( status );
397     libxsmm_dhidden_state = libxsmm_dnn_link_tensor( libxsmm_layout, djdht, &status ); CHKERR_LIBXSMM_DNN( status );
398     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
399 
400     /* copy in data to LIBXSMM format */
401     matrix_copy( t*N*C, xgoldt, xt );
402     matrix_copy( K*N, hpgold, hp );
403     matrix_copy( C*K, wgold, w );
404     matrix_copy( K*K, ugold, u );
405     matrix_copy( K, bgold, b );
406     matrix_copy( t*K*N, djdhgoldt, djdht );
407 
408     /* bind buffers and filter to handle */
409     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_input, LIBXSMM_DNN_RNN_REGULAR_INPUT ) );
410     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_hidden_state_prev, LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV ) );
411     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_weight, LIBXSMM_DNN_RNN_REGULAR_WEIGHT ) );
412     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_recur_weight, LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT ) );
413     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_bias, LIBXSMM_DNN_RNN_REGULAR_BIAS ) );
414     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_hidden_state, LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE ) );
415     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_dinput, LIBXSMM_DNN_RNN_GRADIENT_INPUT ) );
416     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_dweight, LIBXSMM_DNN_RNN_GRADIENT_WEIGHT ) );
417     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_drecur_weight, LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT ) );
418     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_dbias, LIBXSMM_DNN_RNN_GRADIENT_BIAS ) );
419     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_dhidden_state, LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE ) );
420 
421     /* let's allocate and bind scratch */
422     if (pass == 0) {
423       scratch_size = libxsmm_dnn_rnncell_get_scratch_size( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, &status );
424       CHKERR_LIBXSMM_DNN( status );
425       scratch = libxsmm_aligned_malloc( scratch_size, 2097152 );
426       CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_scratch( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, scratch ) );
427     } else {
428       scratch_size = libxsmm_dnn_rnncell_get_scratch_size( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, &status );
429       CHKERR_LIBXSMM_DNN( status );
430       scratch = libxsmm_aligned_malloc( scratch_size, 2097152 );
431       CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_scratch( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch ) );
432     }
433     zero_buf( (float*)scratch, scratch_size/4 );
434 
435     /* let's allocate and bind internalstate */
436     if (pass == 0) {
437       internalstate_size = libxsmm_dnn_rnncell_get_internalstate_size( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, &status );
438       CHKERR_LIBXSMM_DNN( status );
439       internalstate = libxsmm_aligned_malloc( internalstate_size, 2097152 );
440       CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_internalstate( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, internalstate ) );
441     } else {
442       internalstate_size = libxsmm_dnn_rnncell_get_internalstate_size( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, &status );
443       CHKERR_LIBXSMM_DNN( status );
444       internalstate = libxsmm_aligned_malloc( internalstate_size, 2097152 );
445       CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_internalstate( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, internalstate ) );
446     }
447     zero_buf( (float*)internalstate, internalstate_size/4 );
448 
449     if ((pass == 0) && LIBXSMM_NEQ(0, check)) {
450       printf("##########################################\n");
451       printf("#   Correctness - FWD (custom-Storage)   #\n");
452       printf("##########################################\n");
453       /* run LIBXSMM RNN */
454 #if defined(_OPENMP)
455 #     pragma omp parallel
456 #endif
457       {
458 #if defined(_OPENMP)
459         const int tid = omp_get_thread_num();
460 #else
461         const int tid = 0;
462 #endif
463         CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, 0, tid ) );
464       }
465       matrix_copy( N*K, &LIBXSMM_VLA_ACCESS(2, h, t-1, 0, K*N), htest );
466 
467       /* compare */
468       libxsmm_matdiff(&norms_fwd, LIBXSMM_DATATYPE_F32, K*N, 1, &LIBXSMM_VLA_ACCESS(2, hgold, t-1, 0, K*N), htest, 0, 0);
469       printf("L1 reference  : %.25g\n", norms_fwd.l1_ref);
470       printf("L1 test       : %.25g\n", norms_fwd.l1_tst);
471       printf("L2 abs.error  : %.24f\n", norms_fwd.l2_abs);
472       printf("L2 rel.error  : %.24f\n", norms_fwd.l2_rel);
473       printf("Linf abs.error: %.24f\n", norms_fwd.linf_abs);
474       printf("Linf rel.error: %.24f\n", norms_fwd.linf_rel);
475       printf("Check-norm    : %.24f\n", norms_fwd.normf_rel);
476       libxsmm_matdiff_reduce(&diff, &norms_fwd);
477     } else {
478       /* We need to always run FWD pass once to populate zt, ht */
479 #if defined(_OPENMP)
480 #     pragma omp parallel
481 #endif
482       {
483 #if defined(_OPENMP)
484         const int tid = omp_get_thread_num();
485 #else
486         const int tid = 0;
487 #endif
488         CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, 0, tid ) );
489       }
490     }
491 
492     if ( (pass == 1) && LIBXSMM_NEQ(0, check) ) {
493       printf("##########################################\n");
494       printf("#   Correctness - BWD (custom-Storage)   #\n");
495       printf("##########################################\n");
496       /* run LIBXSMM RNN */
497 #if defined(_OPENMP)
498 #     pragma omp parallel
499 #endif
500       {
501 #if defined(_OPENMP)
502         const int tid = omp_get_thread_num();
503 #else
504         const int tid = 0;
505 #endif
506         CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_BWD, 0, tid ) );
507       }
508 
509       /* copy out data */
510       matrix_copy(N*C*t, djdxt, djdxtestt);
511 
512       /* compare */
513       libxsmm_matdiff(&norms_bwd, LIBXSMM_DATATYPE_F32, N*C*t, 1, djdxgoldt, djdxtestt, 0, 0);
514       printf("L1 reference  : %.25g\n", norms_bwd.l1_ref);
515       printf("L1 test       : %.25g\n", norms_bwd.l1_tst);
516       printf("L2 abs.error  : %.24f\n", norms_bwd.l2_abs);
517       printf("L2 rel.error  : %.24f\n", norms_bwd.l2_rel);
518       printf("Linf abs.error: %.24f\n", norms_bwd.linf_abs);
519       printf("Linf rel.error: %.24f\n", norms_bwd.linf_rel);
520       printf("Check-norm    : %.24f\n", norms_bwd.normf_rel);
521       libxsmm_matdiff_reduce(&diff, &norms_bwd);
522     }
523 
524     if ( (pass == 2) && LIBXSMM_NEQ(0, check) ) {
525       printf("##########################################\n");
526       printf("#   Correctness - UPD (custom-Storage)   #\n");
527       printf("##########################################\n");
528       /* run LIBXSMM RNN */
529 #if defined(_OPENMP)
530 #     pragma omp parallel
531 #endif
532       {
533 #if defined(_OPENMP)
534         const int tid = omp_get_thread_num();
535 #else
536         const int tid = 0;
537 #endif
538         CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_UPD, 0, tid ) );
539       }
540 
541       /* copy out data */
542       matrix_copy(C*K, djdw, djdwtest);
543       matrix_copy(K*K, djdu, djdutest);
544 
545       /* compare */
546       libxsmm_matdiff(&norms_upd_w, LIBXSMM_DATATYPE_F32, C*K, 1, djdwgold, djdwtest, 0, 0);
547       printf("Delta weight\n");
548       printf("L1 reference  : %.25g\n", norms_upd_w.l1_ref);
549       printf("L1 test       : %.25g\n", norms_upd_w.l1_tst);
550       printf("L2 abs.error  : %.24f\n", norms_upd_w.l2_abs);
551       printf("L2 rel.error  : %.24f\n", norms_upd_w.l2_rel);
552       printf("Linf abs.error: %.24f\n", norms_upd_w.linf_abs);
553       printf("Linf rel.error: %.24f\n", norms_upd_w.linf_rel);
554       printf("Check-norm    : %.24f\n", norms_upd_w.normf_rel);
555       libxsmm_matdiff_reduce(&diff, &norms_upd_w);
556 
557       libxsmm_matdiff(&norms_upd_u, LIBXSMM_DATATYPE_F32, K*K, 1, djdugold, djdutest, 0, 0);
558       printf("Delta recurrent weight\n");
559       printf("L1 reference  : %.25g\n", norms_upd_u.l1_ref);
560       printf("L1 test       : %.25g\n", norms_upd_u.l1_tst);
561       printf("L2 abs.error  : %.24f\n", norms_upd_u.l2_abs);
562       printf("L2 rel.error  : %.24f\n", norms_upd_u.l2_rel);
563       printf("Linf abs.error: %.24f\n", norms_upd_u.linf_abs);
564       printf("Linf rel.error: %.24f\n", norms_upd_u.linf_rel);
565       printf("Check-norm    : %.24f\n", norms_upd_u.normf_rel);
566       libxsmm_matdiff_reduce(&diff, &norms_upd_u);
567 
568       libxsmm_matdiff(&norms_upd_b, LIBXSMM_DATATYPE_F32, K, 1, djdbgold, djdb, 0, 0);
569       printf("Delta bias\n");
570       printf("L1 reference  : %.25g\n", norms_upd_b.l1_ref);
571       printf("L1 test       : %.25g\n", norms_upd_b.l1_tst);
572       printf("L2 abs.error  : %.24f\n", norms_upd_b.l2_abs);
573       printf("L2 rel.error  : %.24f\n", norms_upd_b.l2_rel);
574       printf("Linf abs.error: %.24f\n", norms_upd_b.linf_abs);
575       printf("Linf rel.error: %.24f\n", norms_upd_b.linf_rel);
576       printf("Check-norm    : %.24f\n", norms_upd_b.normf_rel);
577       libxsmm_matdiff_reduce(&diff, &norms_upd_b);
578     }
579 
580     if ( (pass == 3) && LIBXSMM_NEQ(0, check) ) {
581       printf("##########################################\n");
582       printf("# Correctness - BWD+UPD (custom-Storage) #\n");
583       printf("##########################################\n");
584       /* run LIBXSMM RNN */
585 #if defined(_OPENMP)
586 #     pragma omp parallel
587 #endif
588       {
589 #if defined(_OPENMP)
590         const int tid = omp_get_thread_num();
591 #else
592         const int tid = 0;
593 #endif
594         CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_BWDUPD, 0, tid ) );
595       }
596 
597       /* copy out data */
598       matrix_copy(N*C*t, djdxt, djdxtestt);
599       matrix_copy(C*K, djdw, djdwtest);
600       matrix_copy(K*K, djdu, djdutest);
601 
602       /* compare */
603       libxsmm_matdiff(&norms_bwd, LIBXSMM_DATATYPE_F32, N*C*t, 1, djdxgoldt, djdxtestt, 0, 0);
604       printf("Delta input\n");
605       printf("L1 reference  : %.25g\n", norms_bwd.l1_ref);
606       printf("L1 test       : %.25g\n", norms_bwd.l1_tst);
607       printf("L2 abs.error  : %.24f\n", norms_bwd.l2_abs);
608       printf("L2 rel.error  : %.24f\n", norms_bwd.l2_rel);
609       printf("Linf abs.error: %.24f\n", norms_bwd.linf_abs);
610       printf("Linf rel.error: %.24f\n", norms_bwd.linf_rel);
611       printf("Check-norm    : %.24f\n", norms_bwd.normf_rel);
612       libxsmm_matdiff_reduce(&diff, &norms_bwd);
613 
614       libxsmm_matdiff(&norms_upd_w, LIBXSMM_DATATYPE_F32, C*K, 1, djdwgold, djdwtest, 0, 0);
615       printf("Delta weight\n");
616       printf("L1 reference  : %.25g\n", norms_upd_w.l1_ref);
617       printf("L1 test       : %.25g\n", norms_upd_w.l1_tst);
618       printf("L2 abs.error  : %.24f\n", norms_upd_w.l2_abs);
619       printf("L2 rel.error  : %.24f\n", norms_upd_w.l2_rel);
620       printf("Linf abs.error: %.24f\n", norms_upd_w.linf_abs);
621       printf("Linf rel.error: %.24f\n", norms_upd_w.linf_rel);
622       printf("Check-norm    : %.24f\n", norms_upd_w.normf_rel);
623       libxsmm_matdiff_reduce(&diff, &norms_upd_w);
624 
625       libxsmm_matdiff(&norms_upd_u, LIBXSMM_DATATYPE_F32, K*K, 1, djdugold, djdutest, 0, 0);
626       printf("Delta recurrent weight\n");
627       printf("L1 reference  : %.25g\n", norms_upd_u.l1_ref);
628       printf("L1 test       : %.25g\n", norms_upd_u.l1_tst);
629       printf("L2 abs.error  : %.24f\n", norms_upd_u.l2_abs);
630       printf("L2 rel.error  : %.24f\n", norms_upd_u.l2_rel);
631       printf("Linf abs.error: %.24f\n", norms_upd_u.linf_abs);
632       printf("Linf rel.error: %.24f\n", norms_upd_u.linf_rel);
633       printf("Check-norm    : %.24f\n", norms_upd_u.normf_rel);
634       libxsmm_matdiff_reduce(&diff, &norms_upd_u);
635 
636       libxsmm_matdiff(&norms_upd_b, LIBXSMM_DATATYPE_F32, K, 1, djdbgold, djdb, 0, 0);
637       printf("Delta bias\n");
638       printf("L1 reference  : %.25g\n", norms_upd_b.l1_ref);
639       printf("L1 test       : %.25g\n", norms_upd_b.l1_tst);
640       printf("L2 abs.error  : %.24f\n", norms_upd_b.l2_abs);
641       printf("L2 rel.error  : %.24f\n", norms_upd_b.l2_rel);
642       printf("Linf abs.error: %.24f\n", norms_upd_b.linf_abs);
643       printf("Linf rel.error: %.24f\n", norms_upd_b.linf_rel);
644       printf("Check-norm    : %.24f\n", norms_upd_b.normf_rel);
645       libxsmm_matdiff_reduce(&diff, &norms_upd_b);
646     }
647 
648     if ( pass == 0 ) {
649       printf("##########################################\n");
650       printf("#   Performance - FWD (custom-Storage)   #\n");
651       printf("##########################################\n");
652       /* run LIBXSMM RNN for performance */
653       l_start = libxsmm_timer_tick();
654 
655 #if defined(_OPENMP)
656 #     pragma omp parallel private(i)
657 #endif
658       {
659 #if defined(_OPENMP)
660         const int tid = omp_get_thread_num();
661 #else
662         const int tid = 0;
663 #endif
664         for (i = 0; i < iters; ++i) {
665           libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, 0, tid );
666         }
667       }
668       l_end = libxsmm_timer_tick();
669       l_total = libxsmm_timer_duration(l_start, l_end);
670       flops = ((2.0 * K*N*C) + (2.0 * K*N*K) + (K*N) + (tflops * K*N)) * (double)t * (double)iters;
671 
672       printf("GFLOP  = %.5g\n", flops*1e-9/(double)iters);
673       printf("fp time = %.5g\n", ((double)(l_total/iters)));
674       printf("GFLOPS  = %.5g\n", (flops*1e-9)/l_total);
675 
676       printf("PERFDUMP,FP,%s,%i,%i,%i,%i,%i,%.5g,%.5g\n", LIBXSMM_VERSION, nThreads, N, C, K, t, ((double)(l_total/iters)), (flops*1e-9)/l_total);
677     }
678 
679     if ( pass == 1 ) {
680       printf("##########################################\n");
681       printf("#   Performance - BWD (custom-Storage)   #\n");
682       printf("##########################################\n");
683       /* run LIBXSMM RNN for performance */
684       l_start = libxsmm_timer_tick();
685 
686 #if defined(_OPENMP)
687 #     pragma omp parallel private(i)
688 #endif
689       {
690 #if defined(_OPENMP)
691         const int tid = omp_get_thread_num();
692 #else
693         const int tid = 0;
694 #endif
695         for (i = 0; i < iters; ++i) {
696           libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_BWD, 0, tid );
697         }
698       }
699       l_end = libxsmm_timer_tick();
700       l_total = libxsmm_timer_duration(l_start, l_end);
701       flops = K*K; /* U^T */
702       flops += (2.0 * K*N*K); /* U^T * delta */
703       flops += (K*N); /* dJdh + (U^T * delta) */
704       flops += (tflops * K*N); /* sigma'(Z) */
705       flops += (K*N); /* sigma'(Z) * (dJdh + (U^T * delta)) */
706       flops *= t; /* for t time steps */
707       tempflops = C*K; /* W^T */
708       tempflops += (2.0 * K*N*C); /* W^T * delta */
709       tempflops *= t; /* for t time steps of input */
710       flops += tempflops;
711       flops *= iters;
712 
713       printf("GFLOP  = %.5g\n", flops*1e-9/(double)iters);
714       printf("bp time = %.5g\n", ((double)(l_total/iters)));
715       printf("GFLOPS  = %.5g\n", (flops*1e-9)/l_total);
716 
717       printf("PERFDUMP,BP,%s,%i,%i,%i,%i,%i,%.5g,%.5g\n", LIBXSMM_VERSION, nThreads, N, C, K, t, ((double)(l_total/iters)), (flops*1e-9)/l_total);
718     }
719 
720     if ( pass == 2 ) {
721       printf("##########################################\n");
722       printf("#   Performance - UPD (custom-Storage)   #\n");
723       printf("##########################################\n");
724       /* run LIBXSMM RNN for performance */
725       l_start = libxsmm_timer_tick();
726 
727 #if defined(_OPENMP)
728 #     pragma omp parallel private(i)
729 #endif
730       {
731 #if defined(_OPENMP)
732         const int tid = omp_get_thread_num();
733 #else
734         const int tid = 0;
735 #endif
736         for (i = 0; i < iters; ++i) {
737           libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_UPD, 0, tid );
738         }
739       }
740       l_end = libxsmm_timer_tick();
741       l_total = libxsmm_timer_duration(l_start, l_end);
742       flops = K*K; /* U^T */
743       flops += (2.0 * K*N*K); /* U^T * delta */
744       flops += (K*N); /* dJdh + (U^T * delta) */
745       flops += (tflops * K*N); /* sigma'(Z) */
746       flops += (K*N); /* sigma'(Z) * (dJdh + (U^T * delta)) */
747       flops *= t; /* for t time steps */
748       tempflops = K*N; /* h^T */
749       tempflops += (2.0 * K*N*K); /* delta * h^T */
750       tempflops *= t; /* for t time steps */
751       tempflops += (K*K * (t-1)); /* for summation of dJdU */
752       flops += tempflops;
753       tempflops = N*C; /* x^T */
754       tempflops += (2.0 * K*N*C); /* delta * x^T */
755       tempflops *= t; /* for t time steps */
756       tempflops += (C*K * (t-1)); /* for summation of dJdW */
757       flops += tempflops;
758       flops *= iters;
759 
760       printf("GFLOP  = %.5g\n", flops*1e-9/(double)iters);
761       printf("wu time = %.5g\n", ((double)(l_total/iters)));
762       printf("GFLOPS  = %.5g\n", (flops*1e-9)/l_total);
763 
764       printf("PERFDUMP,WU,%s,%i,%i,%i,%i,%i,%.5g,%.5g\n", LIBXSMM_VERSION, nThreads, N, C, K, t, ((double)(l_total/iters)), (flops*1e-9)/l_total);
765     }
766 
767     if ( pass == 3 ) {
768       printf("##########################################\n");
769       printf("# Performance - BWD+UPD (custom-Storage) #\n");
770       printf("##########################################\n");
771       /* run LIBXSMM RNN for performance */
772       l_start = libxsmm_timer_tick();
773 
774 #if defined(_OPENMP)
775 #     pragma omp parallel private(i)
776 #endif
777       {
778 #if defined(_OPENMP)
779         const int tid = omp_get_thread_num();
780 #else
781         const int tid = 0;
782 #endif
783         for (i = 0; i < iters; ++i) {
784           libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_BWDUPD, 0, tid );
785         }
786       }
787       l_end = libxsmm_timer_tick();
788       l_total = libxsmm_timer_duration(l_start, l_end);
789       flops = K*K; /* U^T */
790       flops += (2.0 * K*N*K); /* U^T * delta */
791       flops += (K*N); /* dJdh + (U^T * delta) */
792       flops += (tflops * K*N); /* sigma'(Z) */
793       flops += (K*N); /* sigma'(Z) * (dJdh + (U^T * delta)) */
794       flops *= t; /* for t time steps */
795       tempflops = K*N; /* h^T */
796       tempflops += (2.0 * K*N*K); /* delta * h^T */
797       tempflops *= t; /* for t time steps */
798       tempflops += (K*K * (t-1)); /* for summation of dJdU */
799       flops += tempflops;
800       tempflops = N*C; /* x^T */
801       tempflops += (2.0 * K*N*C); /* delta * x^T */
802       tempflops *= t; /* for t time steps */
803       tempflops += (C*K * (t-1)); /* for summation of dJdW */
804       flops += tempflops;
805       tempflops = C*K; /* W^T */
806       tempflops += (2.0 * K*N*C); /* W^T * delta */
807       tempflops *= t; /* for t time steps of input */
808       flops += tempflops;
809       flops *= iters;
810 
811       printf("GFLOP  = %.5g\n", flops*1e-9/(double)iters);
812       printf("bp+wu time = %.5g\n", ((double)(l_total/iters)));
813       printf("GFLOPS  = %.5g\n", (flops*1e-9)/l_total);
814 
815       printf("PERFDUMP,BP+WU,%s,%i,%i,%i,%i,%i,%.5g,%.5g\n", LIBXSMM_VERSION, nThreads, N, C, K, t, ((double)(l_total/iters)), (flops*1e-9)/l_total);
816     }
817 
818     if ( pass == 4 ) {
819       printf("#############################################\n");
820       printf("# Performance - FWD+BWD+UPD (nc-ck Storage) #\n");
821       printf("#############################################\n");
822       /* run LIBXSMM RNN for performance */
823       l_start = libxsmm_timer_tick();
824 
825 #if defined(_OPENMP)
826 #     pragma omp parallel private(i)
827 #endif
828       {
829 #if defined(_OPENMP)
830         const int tid = omp_get_thread_num();
831 #else
832         const int tid = 0;
833 #endif
834         for (i = 0; i < iters; ++i) {
835           libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, 0, tid );
836           libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_BWDUPD, 0, tid );
837         }
838       }
839       l_end = libxsmm_timer_tick();
840       l_total = libxsmm_timer_duration(l_start, l_end);
841       flops = (2.0 * K*N*K); /* U^T * delta */
842       flops += (K*N); /* dJdh + (U^T * delta) */
843       flops += (tflops * K*N); /* sigma'(Z) */
844       flops += (K*N); /* sigma'(Z) * (dJdh + (U^T * delta)) */
845       flops *= t; /* for t time steps */
846       tempflops = (2.0 * K*N*K); /* delta * h^T */
847       tempflops *= t; /* for t time steps */
848       tempflops += (K*K * (t-1)); /* for summation of dJdU */
849       flops += tempflops;
850       tempflops = (2.0 * K*N*C); /* delta * x^T */
851       tempflops *= t; /* for t time steps */
852       tempflops += (C*K * (t-1)); /* for summation of dJdW */
853       flops += tempflops;
854       tempflops = (2.0 * K*N*C); /* W^T * delta */
855       tempflops *= t; /* for t time steps of input */
856       flops += tempflops;
857       flops *= iters;
858       flops += ((2.0 * K*N*C) + (2.0 * K*N*K) + (K*N) + (tflops * K*N)) * (double)t * (double)iters;
859 
860       printf("GFLOP  = %.5g\n", flops*1e-9/(double)iters);
861       printf("fp+bp+wu time = %.5g\n", ((double)(l_total/iters)));
862       printf("GFLOPS  = %.5g\n", (flops*1e-9)/l_total);
863 
864       printf("PERFDUMP,FP+BP+WU,%s,%i,%i,%i,%i,%i,%.5g,%.5g\n", LIBXSMM_VERSION, nThreads, N, C, K, t, ((double)(l_total/iters)), (flops*1e-9)/l_total);
865     }
866 
867     /* clean-up */
868     if (pass == 0) {
869       CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_scratch( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD ) );
870       CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_internalstate( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD ) );
871     } else {
872       CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_scratch( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL ) );
873       CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_internalstate( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL ) );
874     }
875     libxsmm_free(scratch);
876     libxsmm_free(internalstate);
877     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_INPUT ) );
878     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV ) );
879     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_WEIGHT ) );
880     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT ) );
881     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_BIAS ) );
882     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE ) );
883     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_INPUT ) );
884     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_WEIGHT ) );
885     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT ) );
886     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_BIAS ) );
887     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE ) );
888     CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_input ) );
889     CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_hidden_state_prev ) );
890     CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_weight ) );
891     CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_recur_weight ) );
892     CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_bias ) );
893     CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_hidden_state ) );
894     CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_dinput ) );
895     CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_dweight ) );
896     CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_drecur_weight ) );
897     CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_dbias ) );
898     CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_dhidden_state ) );
899     CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_rnncell( libxsmm_handle ) );
900   }
901 
902   /* deallocate data */
903   libxsmm_free(xgoldt);
904   libxsmm_free(hpgold);
905   libxsmm_free(wgold);
906   libxsmm_free(ugold);
907   libxsmm_free(bgold);
908   libxsmm_free(hgoldt);
909   libxsmm_free(zgoldt);
910   libxsmm_free(bmgold);
911   libxsmm_free(z1gold);
912   libxsmm_free(z2gold);
913   libxsmm_free(djdxgoldt);
914   libxsmm_free(djdwgold);
915   libxsmm_free(djdugold);
916   libxsmm_free(djdbgold);
917   libxsmm_free(djdhgoldt);
918   libxsmm_free(deltagoldt);
919   libxsmm_free(zigold);
920   libxsmm_free(di1gold);
921   libxsmm_free(di2gold);
922   libxsmm_free(xgoldTp);
923   libxsmm_free(wgoldTp);
924   libxsmm_free(ugoldTp);
925   libxsmm_free(hgoldTp);
926   libxsmm_free(xt);
927   libxsmm_free(hp);
928   libxsmm_free(w);
929   libxsmm_free(u);
930   libxsmm_free(b);
931   libxsmm_free(ht);
932   libxsmm_free(djdxt);
933   libxsmm_free(djdw);
934   libxsmm_free(djdu);
935   libxsmm_free(djdb);
936   libxsmm_free(djdht);
937   libxsmm_free(htest);
938   libxsmm_free(djdxtestt);
939   libxsmm_free(djdwtest);
940   libxsmm_free(djdutest);
941 
942   { const char *const env_check_scale = getenv("CHECK_SCALE");
943     const double check_scale = LIBXSMM_ABS(0 == env_check_scale ? 1.0 : atof(env_check_scale));
944     if (LIBXSMM_NEQ(0, check) && (check < 100.0 * check_scale * diff.normf_rel) && (global_status == LIBXSMM_DNN_SUCCESS)) {
945       fprintf(stderr, "FAILED with an error of %f%%!\n", 100.0 * diff.normf_rel);
946       exit(EXIT_FAILURE);
947     }
948   }
949 
950   /* some empty lines at the end */
951   printf("\n\n\n");
952 
953   return global_status;
954 }
955 
956