1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Kunal Banerjee (Intel Corp.)
10 ******************************************************************************/
11 #include <libxsmm.h>
12 #include <libxsmm_intrinsics_x86.h>
13 
14 #if defined(LIBXSMM_OFFLOAD_TARGET)
15 # pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
16 #endif
17 #include <stdlib.h>
18 #include <string.h>
19 #include <stdio.h>
20 #include <math.h>
21 #if defined(_OPENMP)
22 # include <omp.h>
23 #endif
24 #if defined(LIBXSMM_OFFLOAD_TARGET)
25 # pragma offload_attribute(pop)
26 #endif
27 
28 /* include c-based dnn library */
29 #include "../common/dnn_common.h"
30 
/* Evaluate a libxsmm_dnn status expression exactly once. If it is not
 * LIBXSMM_DNN_SUCCESS, print the library's error string to stderr and store
 * the code in `global_status`, a variable that must be in scope wherever the
 * macro is expanded. Execution continues after a failure (report, not abort).
 * The do { } while (0) wrapper makes `CHKERR_LIBXSMM_DNN(x);` a single
 * statement, so it composes safely with an unbraced if/else.
 * NOTE(review): main() expands this inside `#pragma omp parallel` regions;
 * concurrent failures there write global_status without synchronization. */
#define CHKERR_LIBXSMM_DNN(A) do { \
  const int chkerr_libxsmm_dnn_ = (A); \
  if (LIBXSMM_DNN_SUCCESS != chkerr_libxsmm_dnn_) { \
    fprintf(stderr, "%s\n", libxsmm_dnn_get_error(chkerr_libxsmm_dnn_)); \
    global_status = chkerr_libxsmm_dnn_; \
  } \
} while (0)
34 
main(int argc,char * argv[])35 int main(int argc, char* argv[])
36 {
37   float *wigold, *wcgold, *wfgold, *rigold, *rcgold, *rfgold, *bigold, *bcgold, *bfgold;
38   float *xgoldt, *hpgold, *hgoldt;
39   float *dwgold, *drgold, *dbgold;
40   float *dxgoldt, *dhpgold, *dhgoldt;
41   float *igoldt, *cgoldt, *fgoldt, *ogoldt;
42   float *xt, *hp, *w, *r, *b, *ht;
43   float *it, *ct, *ft, *ot;
44   float *dxt, *dhp, *dw, *dr, *db, *dht;
45   float *scratch_bu, *dwtest, *drtest;
46 
47   void *scratch, *internalstate;
48   size_t scratch_size = 0, internalstate_size = 0;
49 
50   int iters = 10;   /* repetitions of benchmark */
51   int pass = 0;     /* pass: 0--FWD, 1--BWD, 2--UPD, 3--BWD+UPD */
52   int N = 168;      /* size of mini-batch */
53   int C = 512;      /* number of inputs */
54   int K = 256;      /* number of outputs */
55   int t = 50;       /* number of time steps (>= 1) */
56   int bn = 24;
57   int bc = 64;
58   int bk = 64;
59 
60   const char *const env_check = getenv("CHECK");
61   const double check = LIBXSMM_ABS(0 == env_check ? 0/*disabled by default*/ : atof(env_check));
62 
63 #if defined(_OPENMP)
64   int nThreads = omp_get_max_threads(); /* number of threads */
65 #else
66   int nThreads = 1; /* number of threads */
67 #endif
68 
69   unsigned long long l_start, l_end;
70   double l_total = 0.0;
71   double flops = 0.0;
72   const double tflops = 12; /* transcendental flops */
73   int j;
74 
75   libxsmm_dnn_rnncell_desc grucell_desc;
76   libxsmm_dnn_rnncell* libxsmm_handle;
77   libxsmm_dnn_tensor* libxsmm_input;
78   libxsmm_dnn_tensor* libxsmm_hidden_state_prev;
79   libxsmm_dnn_tensor* libxsmm_weight;
80   libxsmm_dnn_tensor* libxsmm_recur_weight;
81   libxsmm_dnn_tensor* libxsmm_bias;
82   libxsmm_dnn_tensor* libxsmm_hidden_state;
83   libxsmm_dnn_tensor* libxsmm_i;
84   libxsmm_dnn_tensor* libxsmm_c;
85   libxsmm_dnn_tensor* libxsmm_f;
86   libxsmm_dnn_tensor* libxsmm_o;
87   libxsmm_dnn_tensor* libxsmm_dinput;
88   libxsmm_dnn_tensor* libxsmm_dhidden_state_prev;
89   libxsmm_dnn_tensor* libxsmm_dweight;
90   libxsmm_dnn_tensor* libxsmm_drecur_weight;
91   libxsmm_dnn_tensor* libxsmm_dbias;
92   libxsmm_dnn_tensor* libxsmm_dhidden_state;
93 
94   libxsmm_dnn_tensor_datalayout* libxsmm_layout;
95   libxsmm_dnn_err_t status;
96   libxsmm_dnn_err_t global_status = LIBXSMM_DNN_SUCCESS;
97 
98   libxsmm_matdiff_info norms_fwd, norms_bwd, norms_upd_w, norms_upd_r, norms_upd_b, diff;
99   libxsmm_matdiff_clear(&norms_fwd);
100   libxsmm_matdiff_clear(&norms_bwd);
101   libxsmm_matdiff_clear(&norms_upd_w);
102   libxsmm_matdiff_clear(&norms_upd_r);
103   libxsmm_matdiff_clear(&norms_upd_b);
104   libxsmm_matdiff_clear(&diff);
105 
106   if (argc > 1 && !strncmp(argv[1], "-h", 3)) {
107     printf("\nUsage: ./grudriver [reps] [pass: 0--FWD, 1--BWD, 2--UPD, 3--BWD+UPD] [N] [C] [K] [time_steps > 0]\n\n");
108     return 0;
109   }
110   libxsmm_rng_set_seed(1);
111 
112   /* reading new values from cli */
113   j = 1;
114   if (argc > j) iters = atoi(argv[j++]);
115   if (argc > j) pass  = atoi(argv[j++]);
116   if (argc > j) N     = atoi(argv[j++]);
117   if (argc > j) C     = atoi(argv[j++]);
118   if (argc > j) K     = atoi(argv[j++]);
119   if (argc > j) t     = atoi(argv[j++]);
120   if (argc > j) bn    = atoi(argv[j++]);
121   if (argc > j) bc    = atoi(argv[j++]);
122   if (argc > j) bk    = atoi(argv[j++]);
123 
124   if (t <= 0) {
125     printf("time_steps %d should be greater than or equal to 1\n\n", t);
126     return 0;
127   }
128   if (!(pass == 0 || pass == 1 || pass == 2 || pass == 3)) {
129     printf("Unknown pass: %d, valid arguments for pass = {0(FWD), 1(BWD), 2(UPD), 3(BWD+UPD)\n\n", pass);
130     return 0;
131   }
132 
133 #if defined(__SSE3__)
134   _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
135   _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
136   _MM_SET_ROUNDING_MODE(_MM_ROUND_NEAREST);
137 #endif
138 
139   /* print some summary */
140   printf("##########################################\n");
141   printf("#          Setting Up (Common)           #\n");
142   printf("##########################################\n");
143   printf("PARAMS: N:%d  C:%d  K:%d  T:%d\n", N, C, K, t);
144   printf("PARAMS: ITERS:%d", iters); if (LIBXSMM_FEQ(0, check)) printf("  Threads:%d\n", nThreads); else printf("\n");
145   printf("SIZE Weight (MB): %10.2f MiB\n", (double)(C*K*sizeof(float))/(1024.0*1024.0) );
146   printf("SIZE Input (MB): %10.2f MiB\n", (double)(N*C*sizeof(float))/(1024.0*1024.0) );
147   printf("SIZE Hidden State: %10.2f MiB\n", (double)(K*N*sizeof(float))/(1024.0*1024.0) );
148 
149   /* allocate data */
150   xgoldt     = (float*)libxsmm_aligned_malloc(N*C*t*sizeof(float), 2097152);
151   hpgold     = (float*)libxsmm_aligned_malloc(K*N*sizeof(float),   2097152);
152   wigold     = (float*)libxsmm_aligned_malloc(C*K*sizeof(float),   2097152);
153   wcgold     = (float*)libxsmm_aligned_malloc(C*K*sizeof(float),   2097152);
154   wfgold     = (float*)libxsmm_aligned_malloc(C*K*sizeof(float),   2097152);
155   rigold     = (float*)libxsmm_aligned_malloc(K*K*sizeof(float),   2097152);
156   rcgold     = (float*)libxsmm_aligned_malloc(K*K*sizeof(float),   2097152);
157   rfgold     = (float*)libxsmm_aligned_malloc(K*K*sizeof(float),   2097152);
158   bigold     = (float*)libxsmm_aligned_malloc(K*sizeof(float),     2097152);
159   bcgold     = (float*)libxsmm_aligned_malloc(K*sizeof(float),     2097152);
160   bfgold     = (float*)libxsmm_aligned_malloc(K*sizeof(float),     2097152);
161   hgoldt     = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
162   igoldt     = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
163   cgoldt     = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
164   fgoldt     = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
165   ogoldt     = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
166   dxgoldt    = (float*)libxsmm_aligned_malloc(N*C*t*sizeof(float), 2097152);
167   dhpgold    = (float*)libxsmm_aligned_malloc(K*N*sizeof(float),   2097152);
168   dwgold     = (float*)libxsmm_aligned_malloc(C*K*3*sizeof(float), 2097152);
169   drgold     = (float*)libxsmm_aligned_malloc(K*K*3*sizeof(float), 2097152);
170   dbgold     = (float*)libxsmm_aligned_malloc(K*3*sizeof(float),   2097152);
171   dhgoldt    = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
172   scratch_bu = (float*)libxsmm_aligned_malloc(K*N*6*sizeof(float), 2097152);
173   xt         = (float*)libxsmm_aligned_malloc(N*C*t*sizeof(float), 2097152);
174   hp         = (float*)libxsmm_aligned_malloc(K*N*sizeof(float),   2097152);
175   w          = (float*)libxsmm_aligned_malloc(C*K*3*sizeof(float), 2097152);
176   r          = (float*)libxsmm_aligned_malloc(K*K*3*sizeof(float), 2097152);
177   b          = (float*)libxsmm_aligned_malloc(K*3*sizeof(float),   2097152);
178   ht         = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
179   it         = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
180   ct         = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
181   ft         = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
182   ot         = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
183   dxt        = (float*)libxsmm_aligned_malloc(N*C*t*sizeof(float), 2097152);
184   dhp        = (float*)libxsmm_aligned_malloc(K*N*sizeof(float),   2097152);
185   dw         = (float*)libxsmm_aligned_malloc(C*K*3*sizeof(float), 2097152);
186   dr         = (float*)libxsmm_aligned_malloc(K*K*3*sizeof(float), 2097152);
187   db         = (float*)libxsmm_aligned_malloc(K*3*sizeof(float),   2097152);
188   dht        = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
189   dwtest     = (float*)libxsmm_aligned_malloc(C*K*3*sizeof(float), 2097152);
190   drtest     = (float*)libxsmm_aligned_malloc(K*K*3*sizeof(float), 2097152);
191   LIBXSMM_VLA_DECL(2, float, xgold, xgoldt, N * C);
192   LIBXSMM_VLA_DECL(2, float, hgold, hgoldt, N * K);
193   /*LIBXSMM_VLA_DECL(2, float, igold, igoldt, N * K);*/
194   /*LIBXSMM_VLA_DECL(2, float, cgold, cgoldt, N * K);*/
195   /*LIBXSMM_VLA_DECL(2, float, fgold, fgoldt, N * K);*/
196   /*LIBXSMM_VLA_DECL(2, float, ogold, ogoldt, N * K);*/
197   /*LIBXSMM_VLA_DECL(2, float, dxgold, dxgoldt, N * C);*/
198   LIBXSMM_VLA_DECL(2, float, dhgold, dhgoldt, N * K);
199   LIBXSMM_VLA_DECL(2, float, h, ht, N * K);
200 
201   /* initialize data */
202   /* FWD */
203   for (j = 0; j < t; ++j) {
204     LIBXSMM_MATINIT_OMP(float, 24, &LIBXSMM_VLA_ACCESS(2, xgold, j, 0, N * C), N, C, N, 1.0);
205   }
206   LIBXSMM_MATINIT_OMP(float, 24, hpgold, N, K, N, 1.0);
207   LIBXSMM_MATINIT_OMP(float, 42, wigold, C, K, C, 1.0);
208   LIBXSMM_MATINIT_OMP(float, 42, wcgold, C, K, C, 1.0);
209   LIBXSMM_MATINIT_OMP(float, 42, wfgold, C, K, C, 1.0);
210   LIBXSMM_MATINIT_OMP(float, 42, rigold, K, K, K, 1.0);
211   LIBXSMM_MATINIT_OMP(float, 42, rcgold, K, K, K, 1.0);
212   LIBXSMM_MATINIT_OMP(float, 42, rfgold, K, K, K, 1.0);
213   LIBXSMM_MATINIT_OMP(float, 24, bigold, 1, K, 1, 1.0);
214   LIBXSMM_MATINIT_OMP(float, 24, bcgold, 1, K, 1, 1.0);
215   LIBXSMM_MATINIT_OMP(float, 24, bfgold, 1, K, 1, 1.0);
216   zero_buf(hgoldt, N*K*t);
217   /* BWD/UPD */
218   for (j = 0; j < t; ++j) {
219     LIBXSMM_MATINIT_OMP(float, 24, &LIBXSMM_VLA_ACCESS(2, dhgold, j, 0, K * N), N, K, N, 1.0);
220   }
221   zero_buf(dxgoldt, N*C*t);
222   zero_buf(dhpgold, K*N);
223   zero_buf(dwgold,  C*K*3);
224   zero_buf(drgold,  K*K*3);
225   zero_buf(dbgold,  K*3);
226 
227   /* first touch LIBXSMM */
228   zero_buf(xt,  N*C*t);
229   zero_buf(hp,  K*N);
230   zero_buf(w,   C*K*3);
231   zero_buf(r,   K*K*3);
232   zero_buf(b,   K*3);
233   zero_buf(ht,  N*K*t);
234   zero_buf(it,  K*N*t);
235   zero_buf(ct,  K*N*t);
236   zero_buf(ft,  K*N*t);
237   zero_buf(ot,  K*N*t);
238   zero_buf(dxt, N*C*t);
239   zero_buf(dhp, K*N);
240   zero_buf(dw,  C*K*3);
241   zero_buf(dr,  K*K*3);
242   zero_buf(db,  K*3);
243   zero_buf(dht, K*N*t);
244 
245   if (LIBXSMM_NEQ(0, check)) {
246     printf("##########################################\n");
247     printf("#         Computing Reference ...        #\n");
248     printf("##########################################\n");
249 
250     gru_ref_fwd( N, C, K, t,
251                  wigold, wcgold, wfgold,
252                  rigold, rcgold, rfgold,
253                  bigold, bcgold, bfgold,
254                  xgoldt, hpgold, hgoldt,
255                  igoldt, cgoldt, fgoldt, ogoldt );
256 
257     gru_ref_bwd_upd( N, C, K, t,
258                      xgoldt, hpgold, hgoldt,
259                      igoldt, cgoldt, fgoldt, ogoldt,
260                      wigold, wcgold, wfgold,
261                      rigold, rcgold, rfgold,
262                      dhgoldt, dwgold, drgold, dbgold,
263                      dxgoldt, dhpgold, scratch_bu );
264 
265     printf("##########################################\n");
266     printf("#      Computing Reference ... done      #\n");
267     printf("##########################################\n");
268   }
269 
270   if (1 /* format == 'A' || format == 'L' */) {
271     printf("\n");
272     printf("##########################################\n");
273     printf("#      Setting Up  (custom-Storage)      #\n");
274     printf("##########################################\n");
275 
276     /* setup LIBXSMM handle */
277     grucell_desc.threads = nThreads;
278     grucell_desc.N = N;
279     grucell_desc.C = C;
280     grucell_desc.K = K;
281     grucell_desc.max_T = t;
282     grucell_desc.bn = bn;
283     grucell_desc.bc = bc;
284     grucell_desc.bk = bk;
285     grucell_desc.cell_type = LIBXSMM_DNN_RNNCELL_GRU;
286     grucell_desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
287     grucell_desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
288     grucell_desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NC;
289     grucell_desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_CK;
290 
291     libxsmm_handle = libxsmm_dnn_create_rnncell( grucell_desc, &status );
292     CHKERR_LIBXSMM_DNN( status );
293 
294     /* setup LIBXSMM buffers and filter */
295     libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_INPUT, &status ); CHKERR_LIBXSMM_DNN( status );
296     libxsmm_input = libxsmm_dnn_link_tensor( libxsmm_layout, xt, &status ); CHKERR_LIBXSMM_DNN( status );
297     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
298 
299     libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV, &status ); CHKERR_LIBXSMM_DNN( status );
300     libxsmm_hidden_state_prev = libxsmm_dnn_link_tensor( libxsmm_layout, hp, &status ); CHKERR_LIBXSMM_DNN( status );
301     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
302 
303     libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_WEIGHT, &status ); CHKERR_LIBXSMM_DNN( status );
304     libxsmm_weight = libxsmm_dnn_link_tensor( libxsmm_layout, w, &status ); CHKERR_LIBXSMM_DNN( status );
305     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
306 
307     libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT, &status ); CHKERR_LIBXSMM_DNN( status );
308     libxsmm_recur_weight = libxsmm_dnn_link_tensor( libxsmm_layout, r, &status ); CHKERR_LIBXSMM_DNN( status );
309     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
310 
311     libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_BIAS, &status ); CHKERR_LIBXSMM_DNN( status );
312     libxsmm_bias = libxsmm_dnn_link_tensor( libxsmm_layout, b, &status ); CHKERR_LIBXSMM_DNN( status );
313     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
314 
315     libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE, &status ); CHKERR_LIBXSMM_DNN( status );
316     libxsmm_hidden_state = libxsmm_dnn_link_tensor( libxsmm_layout, ht, &status ); CHKERR_LIBXSMM_DNN( status );
317     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
318 
319     libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_INTERNAL_I, &status ); CHKERR_LIBXSMM_DNN( status );
320     libxsmm_i = libxsmm_dnn_link_tensor( libxsmm_layout, it, &status ); CHKERR_LIBXSMM_DNN( status );
321     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
322 
323     libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_INTERNAL_CI, &status ); CHKERR_LIBXSMM_DNN( status );
324     libxsmm_c = libxsmm_dnn_link_tensor( libxsmm_layout, ct, &status ); CHKERR_LIBXSMM_DNN( status );
325     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
326 
327     libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_INTERNAL_F, &status ); CHKERR_LIBXSMM_DNN( status );
328     libxsmm_f = libxsmm_dnn_link_tensor( libxsmm_layout, ft, &status ); CHKERR_LIBXSMM_DNN( status );
329     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
330 
331     libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_INTERNAL_O, &status ); CHKERR_LIBXSMM_DNN( status );
332     libxsmm_o = libxsmm_dnn_link_tensor( libxsmm_layout, ot, &status ); CHKERR_LIBXSMM_DNN( status );
333     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
334 
335     libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_INPUT, &status ); CHKERR_LIBXSMM_DNN( status );
336     libxsmm_dinput = libxsmm_dnn_link_tensor( libxsmm_layout, dxt, &status ); CHKERR_LIBXSMM_DNN( status );
337     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
338 
339     libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV, &status ); CHKERR_LIBXSMM_DNN( status );
340     libxsmm_dhidden_state_prev = libxsmm_dnn_link_tensor( libxsmm_layout, dhp, &status ); CHKERR_LIBXSMM_DNN( status );
341     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
342 
343     libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_WEIGHT, &status ); CHKERR_LIBXSMM_DNN( status );
344     libxsmm_dweight = libxsmm_dnn_link_tensor( libxsmm_layout, dw, &status ); CHKERR_LIBXSMM_DNN( status );
345     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
346 
347     libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT, &status ); CHKERR_LIBXSMM_DNN( status );
348     libxsmm_drecur_weight = libxsmm_dnn_link_tensor( libxsmm_layout, dr, &status ); CHKERR_LIBXSMM_DNN( status );
349     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
350 
351     libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_BIAS, &status ); CHKERR_LIBXSMM_DNN( status );
352     libxsmm_dbias = libxsmm_dnn_link_tensor( libxsmm_layout, db, &status ); CHKERR_LIBXSMM_DNN( status );
353     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
354 
355     libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE, &status ); CHKERR_LIBXSMM_DNN( status );
356     libxsmm_dhidden_state = libxsmm_dnn_link_tensor( libxsmm_layout, dht, &status ); CHKERR_LIBXSMM_DNN( status );
357     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
358 
359     /* copy in data to LIBXSMM format */
360     matrix_copy(N*C*t, xgoldt, xt);
361     matrix_copy(K*N, hpgold, hp);
362     convert_ck_c3k(C, K, wigold, w);
363     convert_ck_c3k(C, K, wcgold, &(w[K]));
364     convert_ck_c3k(C, K, wfgold, &(w[2*K]));
365     convert_ck_c3k(K, K, rigold, r);
366     convert_ck_c3k(K, K, rcgold, &(r[K]));
367     convert_ck_c3k(K, K, rfgold, &(r[2*K]));
368     matrix_copy(K, bigold, b);
369     matrix_copy(K, bcgold, &(b[K]));
370     matrix_copy(K, bfgold, &(b[2*K]));
371     matrix_copy(K*N*t, dhgoldt, dht);
372 
373     /* bind buffers and filter to handle */
374     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_input, LIBXSMM_DNN_RNN_REGULAR_INPUT ) );
375     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_hidden_state_prev, LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV ) );
376     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_weight, LIBXSMM_DNN_RNN_REGULAR_WEIGHT ) );
377     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_recur_weight, LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT ) );
378     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_bias, LIBXSMM_DNN_RNN_REGULAR_BIAS ) );
379     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_hidden_state, LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE ) );
380     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_i, LIBXSMM_DNN_RNN_INTERNAL_I ) );
381     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_c, LIBXSMM_DNN_RNN_INTERNAL_CI ) );
382     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_f, LIBXSMM_DNN_RNN_INTERNAL_F ) );
383     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_o, LIBXSMM_DNN_RNN_INTERNAL_O ) );
384     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_dinput, LIBXSMM_DNN_RNN_GRADIENT_INPUT ) );
385     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_dhidden_state_prev, LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV ) );
386     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_dweight, LIBXSMM_DNN_RNN_GRADIENT_WEIGHT ) );
387     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_drecur_weight, LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT ) );
388     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_dbias, LIBXSMM_DNN_RNN_GRADIENT_BIAS ) );
389     CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_dhidden_state, LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE ) );
390 
391     /* let's allocate and bind scratch */
392     if (pass == 0) {
393       scratch_size = libxsmm_dnn_rnncell_get_scratch_size( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, &status );
394       CHKERR_LIBXSMM_DNN( status );
395       scratch = libxsmm_aligned_malloc( scratch_size, 2097152 );
396       CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_scratch( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, scratch ) );
397     } else {
398       scratch_size = libxsmm_dnn_rnncell_get_scratch_size( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, &status );
399       CHKERR_LIBXSMM_DNN( status );
400       scratch = libxsmm_aligned_malloc( scratch_size, 2097152 );
401       CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_scratch( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch ) );
402     }
403     zero_buf( (float*)scratch, scratch_size/4 );
404 
405     /* let's allocate and bind internalstate */
406     if (pass == 0) {
407       internalstate_size = libxsmm_dnn_rnncell_get_internalstate_size( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, &status );
408       CHKERR_LIBXSMM_DNN( status );
409       internalstate = libxsmm_aligned_malloc( internalstate_size, 2097152 );
410       CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_internalstate( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, internalstate ) );
411     } else {
412       internalstate_size = libxsmm_dnn_rnncell_get_internalstate_size( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, &status );
413       CHKERR_LIBXSMM_DNN( status );
414       internalstate = libxsmm_aligned_malloc( internalstate_size, 2097152 );
415       CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_internalstate( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, internalstate ) );
416     }
417     zero_buf( (float*)internalstate, internalstate_size/4 );
418 
419     if ((pass == 0) && LIBXSMM_NEQ(0, check)) {
420       printf("##########################################\n");
421       printf("#   Correctness - FWD (custom-Storage)   #\n");
422       printf("##########################################\n");
423       /* run LIBXSMM GRU */
424 #if defined(_OPENMP)
425 #     pragma omp parallel
426 #endif
427       {
428 #if defined(_OPENMP)
429         const int tid = omp_get_thread_num();
430 #else
431         const int tid = 0;
432 #endif
433         CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, 0, tid ) );
434       }
435 
436       /* compare */
437       libxsmm_matdiff(&norms_fwd, LIBXSMM_DATATYPE_F32, K*N, 1, &LIBXSMM_VLA_ACCESS(2, hgold, t-1, 0, K * N), &LIBXSMM_VLA_ACCESS(2, h, t-1, 0, K * N), 0, 0);
438       printf("L1 reference  : %.25g\n", norms_fwd.l1_ref);
439       printf("L1 test       : %.25g\n", norms_fwd.l1_tst);
440       printf("L2 abs.error  : %.24f\n", norms_fwd.l2_abs);
441       printf("L2 rel.error  : %.24f\n", norms_fwd.l2_rel);
442       printf("Linf abs.error: %.24f\n", norms_fwd.linf_abs);
443       printf("Linf rel.error: %.24f\n", norms_fwd.linf_rel);
444       printf("Check-norm    : %.24f\n", norms_fwd.normf_rel);
445       libxsmm_matdiff_reduce(&diff, &norms_fwd);
446     } else {
447       /* We need to always run FWD pass once to populate i, c, f, o, h */
448 #if defined(_OPENMP)
449 #     pragma omp parallel
450 #endif
451       {
452 #if defined(_OPENMP)
453         const int tid = omp_get_thread_num();
454 #else
455         const int tid = 0;
456 #endif
457         CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, 0, tid ) );
458       }
459     }
460 
461     if ( (pass == 1) && LIBXSMM_NEQ(0, check) ) {
462       printf("##########################################\n");
463       printf("#   Correctness - BWD (custom-Storage)   #\n");
464       printf("##########################################\n");
465       /* run LIBXSMM GRU */
466 #if defined(_OPENMP)
467 #     pragma omp parallel
468 #endif
469       {
470 #if defined(_OPENMP)
471         const int tid = omp_get_thread_num();
472 #else
473         const int tid = 0;
474 #endif
475         CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_BWD, 0, tid ) );
476       }
477 
478       /* compare */
479       libxsmm_matdiff(&norms_bwd, LIBXSMM_DATATYPE_F32, N*C*t, 1, dxgoldt, dxt, 0, 0);
480       printf("L1 reference  : %.25g\n", norms_bwd.l1_ref);
481       printf("L1 test       : %.25g\n", norms_bwd.l1_tst);
482       printf("L2 abs.error  : %.24f\n", norms_bwd.l2_abs);
483       printf("L2 rel.error  : %.24f\n", norms_bwd.l2_rel);
484       printf("Linf abs.error: %.24f\n", norms_bwd.linf_abs);
485       printf("Linf rel.error: %.24f\n", norms_bwd.linf_rel);
486       printf("Check-norm    : %.24f\n", norms_bwd.normf_rel);
487       libxsmm_matdiff_reduce(&diff, &norms_bwd);
488     }
489 
490     if ( (pass == 2) && LIBXSMM_NEQ(0, check) ) {
491       printf("##########################################\n");
492       printf("#   Correctness - UPD (custom-Storage)   #\n");
493       printf("##########################################\n");
494       /* run LIBXSMM GRU */
495 #if defined(_OPENMP)
496 #     pragma omp parallel
497 #endif
498       {
499 #if defined(_OPENMP)
500         const int tid = omp_get_thread_num();
501 #else
502         const int tid = 0;
503 #endif
504         CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_UPD, 0, tid ) );
505       }
506 
507       convert_ck_c3k(C, K, &(dwgold[0]),     &(dwtest[0]));
508       convert_ck_c3k(C, K, &(dwgold[C*K]),   &(dwtest[K]));
509       convert_ck_c3k(C, K, &(dwgold[2*C*K]), &(dwtest[2*K]));
510       convert_ck_c3k(K, K, &(drgold[0]),     &(drtest[0]));
511       convert_ck_c3k(K, K, &(drgold[K*K]),   &(drtest[K]));
512       convert_ck_c3k(K, K, &(drgold[2*K*K]), &(drtest[2*K]));
513       /* compare */
514       libxsmm_matdiff(&norms_upd_w, LIBXSMM_DATATYPE_F32, C*K*3, 1, dwtest, dw, 0, 0);
515       printf("Delta weight\n");
516       printf("L1 reference  : %.25g\n", norms_upd_w.l1_ref);
517       printf("L1 test       : %.25g\n", norms_upd_w.l1_tst);
518       printf("L2 abs.error  : %.24f\n", norms_upd_w.l2_abs);
519       printf("L2 rel.error  : %.24f\n", norms_upd_w.l2_rel);
520       printf("Linf abs.error: %.24f\n", norms_upd_w.linf_abs);
521       printf("Linf rel.error: %.24f\n", norms_upd_w.linf_rel);
522       printf("Check-norm    : %.24f\n", norms_upd_w.normf_rel);
523       libxsmm_matdiff_reduce(&diff, &norms_upd_w);
524 
525       libxsmm_matdiff(&norms_upd_r, LIBXSMM_DATATYPE_F32, K*K*3, 1, drtest, dr, 0, 0);
526       printf("Delta recurrent weight\n");
527       printf("L1 reference  : %.25g\n", norms_upd_r.l1_ref);
528       printf("L1 test       : %.25g\n", norms_upd_r.l1_tst);
529       printf("L2 abs.error  : %.24f\n", norms_upd_r.l2_abs);
530       printf("L2 rel.error  : %.24f\n", norms_upd_r.l2_rel);
531       printf("Linf abs.error: %.24f\n", norms_upd_r.linf_abs);
532       printf("Linf rel.error: %.24f\n", norms_upd_r.linf_rel);
533       printf("Check-norm    : %.24f\n", norms_upd_r.normf_rel);
534       libxsmm_matdiff_reduce(&diff, &norms_upd_r);
535 
536       libxsmm_matdiff(&norms_upd_b, LIBXSMM_DATATYPE_F32, K*3, 1, dbgold, db, 0, 0);
537       printf("Delta bias\n");
538       printf("L1 reference  : %.25g\n", norms_upd_b.l1_ref);
539       printf("L1 test       : %.25g\n", norms_upd_b.l1_tst);
540       printf("L2 abs.error  : %.24f\n", norms_upd_b.l2_abs);
541       printf("L2 rel.error  : %.24f\n", norms_upd_b.l2_rel);
542       printf("Linf abs.error: %.24f\n", norms_upd_b.linf_abs);
543       printf("Linf rel.error: %.24f\n", norms_upd_b.linf_rel);
544       printf("Check-norm    : %.24f\n", norms_upd_b.normf_rel);
545       libxsmm_matdiff_reduce(&diff, &norms_upd_b);
546     }
547 
548     if ( (pass == 3) && LIBXSMM_NEQ(0, check) ) {
549       printf("##########################################\n");
550       printf("# Correctness - BWD+UPD (custom-Storage) #\n");
551       printf("##########################################\n");
552       /* run LIBXSMM GRU */
553 #if defined(_OPENMP)
554 #     pragma omp parallel
555 #endif
556       {
557 #if defined(_OPENMP)
558         const int tid = omp_get_thread_num();
559 #else
560         const int tid = 0;
561 #endif
562         CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_BWDUPD, 0, tid ) );
563       }
564 
565       convert_ck_c3k(C, K, &(dwgold[0]),     &(dwtest[0]));
566       convert_ck_c3k(C, K, &(dwgold[C*K]),   &(dwtest[K]));
567       convert_ck_c3k(C, K, &(dwgold[2*C*K]), &(dwtest[2*K]));
568       convert_ck_c3k(K, K, &(drgold[0]),     &(drtest[0]));
569       convert_ck_c3k(K, K, &(drgold[K*K]),   &(drtest[K]));
570       convert_ck_c3k(K, K, &(drgold[2*K*K]), &(drtest[2*K]));
571       /* compare */
572       libxsmm_matdiff(&norms_bwd, LIBXSMM_DATATYPE_F32, N*C*t, 1, dxgoldt, dxt, 0, 0);
573       printf("Delta input\n");
574       printf("L1 reference  : %.25g\n", norms_bwd.l1_ref);
575       printf("L1 test       : %.25g\n", norms_bwd.l1_tst);
576       printf("L2 abs.error  : %.24f\n", norms_bwd.l2_abs);
577       printf("L2 rel.error  : %.24f\n", norms_bwd.l2_rel);
578       printf("Linf abs.error: %.24f\n", norms_bwd.linf_abs);
579       printf("Linf rel.error: %.24f\n", norms_bwd.linf_rel);
580       printf("Check-norm    : %.24f\n", norms_bwd.normf_rel);
581       libxsmm_matdiff_reduce(&diff, &norms_bwd);
582 
583       libxsmm_matdiff(&norms_upd_w, LIBXSMM_DATATYPE_F32, C*K*3, 1, dwtest, dw, 0, 0);
584       printf("Delta weight\n");
585       printf("L1 reference  : %.25g\n", norms_upd_w.l1_ref);
586       printf("L1 test       : %.25g\n", norms_upd_w.l1_tst);
587       printf("L2 abs.error  : %.24f\n", norms_upd_w.l2_abs);
588       printf("L2 rel.error  : %.24f\n", norms_upd_w.l2_rel);
589       printf("Linf abs.error: %.24f\n", norms_upd_w.linf_abs);
590       printf("Linf rel.error: %.24f\n", norms_upd_w.linf_rel);
591       printf("Check-norm    : %.24f\n", norms_upd_w.normf_rel);
592       libxsmm_matdiff_reduce(&diff, &norms_upd_w);
593 
594       libxsmm_matdiff(&norms_upd_r, LIBXSMM_DATATYPE_F32, K*K*3, 1, drtest, dr, 0, 0);
595       printf("Delta recurrent weight\n");
596       printf("L1 reference  : %.25g\n", norms_upd_r.l1_ref);
597       printf("L1 test       : %.25g\n", norms_upd_r.l1_tst);
598       printf("L2 abs.error  : %.24f\n", norms_upd_r.l2_abs);
599       printf("L2 rel.error  : %.24f\n", norms_upd_r.l2_rel);
600       printf("Linf abs.error: %.24f\n", norms_upd_r.linf_abs);
601       printf("Linf rel.error: %.24f\n", norms_upd_r.linf_rel);
602       printf("Check-norm    : %.24f\n", norms_upd_r.normf_rel);
603       libxsmm_matdiff_reduce(&diff, &norms_upd_r);
604 
605       libxsmm_matdiff(&norms_upd_b, LIBXSMM_DATATYPE_F32, K*3, 1, dbgold, db, 0, 0);
606       printf("Delta bias\n");
607       printf("L1 reference  : %.25g\n", norms_upd_b.l1_ref);
608       printf("L1 test       : %.25g\n", norms_upd_b.l1_tst);
609       printf("L2 abs.error  : %.24f\n", norms_upd_b.l2_abs);
610       printf("L2 rel.error  : %.24f\n", norms_upd_b.l2_rel);
611       printf("Linf abs.error: %.24f\n", norms_upd_b.linf_abs);
612       printf("Linf rel.error: %.24f\n", norms_upd_b.linf_rel);
613       printf("Check-norm    : %.24f\n", norms_upd_b.normf_rel);
614       libxsmm_matdiff_reduce(&diff, &norms_upd_b);
615     }
616 
617     if ( pass == 0 ) {
618       printf("##########################################\n");
619       printf("#   Performance - FWD (custom-Storage)   #\n");
620       printf("##########################################\n");
621       /* run LIBXSMM GRU for performance */
622       l_start = libxsmm_timer_tick();
623 
624 #if defined(_OPENMP)
625 #     pragma omp parallel private(j)
626 #endif
627       {
628 #if defined(_OPENMP)
629         const int tid = omp_get_thread_num();
630 #else
631         const int tid = 0;
632 #endif
633         for (j = 0; j < iters; ++j) {
634           libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, 0, tid );
635         }
636       }
637       l_end = libxsmm_timer_tick();
638       l_total = libxsmm_timer_duration(l_start, l_end);
639       flops = (((2.0 * K * N * C) + (2.0 * K * N * K) + (2.0 * K * N) + (tflops * K * N)) * 2.0 + (K * N) + (2.0 * K * N * C) + (2.0 * K * N * K) + (tflops * K * N) + 4.0 * (K * N)) * (double)t * (double)iters;
640 
641       printf("GFLOP  = %.5g\n", flops*1e-9/(double)iters);
642       printf("fp time = %.5g\n", ((double)(l_total/iters)));
643       printf("GFLOPS  = %.5g\n", (flops*1e-9)/l_total);
644 
645       printf("PERFDUMP,FP,%s,%i,%i,%i,%i,%i,%i,%i,%i,%.5g,%.5g\n", LIBXSMM_VERSION, nThreads, N, C, K, t, bn, bc, bk, ((double)(l_total/iters)), (flops*1e-9)/l_total);
646     }
647 
648     if ( pass == 1 ) {
649       printf("##########################################\n");
650       printf("#   Performance - BWD (custom-Storage)   #\n");
651       printf("##########################################\n");
652       /* run LIBXSMM GRU for performance */
653       l_start = libxsmm_timer_tick();
654 
655 #if defined(_OPENMP)
656 #     pragma omp parallel private(j)
657 #endif
658       {
659 #if defined(_OPENMP)
660         const int tid = omp_get_thread_num();
661 #else
662         const int tid = 0;
663 #endif
664         for (j = 0; j < iters; ++j) {
665           libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_BWD, 0, tid );
666         }
667       }
668       l_end = libxsmm_timer_tick();
669       l_total = libxsmm_timer_duration(l_start, l_end);
670       flops = K * N; /* d3 = djdh + d23 (delta) */
671       flops += 2.0 * K * N; /* d4 = (1 - z).d3 */
672       flops += K * N; /* d5 = d3.h */
673       flops += K * N; /* d6 = -d5 */
674       flops += K * N; /* d7 = d3.g */
675       flops += K * N; /* d8 = d3.z */
676       flops += K * N; /* d9 = d7 + d8 */
677       flops += 3.0 * K * N; /* d10 = d8.tanh'(g) */
678       flops += 3.0 * K * N; /* d11 = d9.sig'(z) */
679       flops += (2.0 * K * K * N + K * K); /* d13 = Wg^T * d10 (including transpose) */
680       flops += (2.0 * K * K * N + K * K); /* d15 = Wz^T * d11 (including transpose) */
681       flops += K * N; /* d16 = d13.z */
682       flops += K * N; /* d17 = d13.r */
683       flops += 3.0 * K * N; /* d18 = d16.sig'(r) */
684       flops += K * N; /* d19 = d17 + d4 */
685       flops += (2.0 * K * K * N + K * K); /* d21 = Wr^T * d18 (including transpose) */
686       flops += K * N; /* d22 = d21 + d15 */
687       flops += K * N; /* d23 = d19 + d22 */
688       flops += (2.0 * K * C * N + K * C); /* d12 = Ug^T * d10 (including transpose) */
689       flops += (2.0 * K * C * N + K * C); /* d14 = Uz^T * d11 (including transpose) */
690       flops += (2.0 * K * C * N + K * C); /* d20 = Ur^T * d18 (including transpose) */
691       flops += 2.0 * K * N; /* djdx = d12 + d14 + d20 */
692       flops *= t; /* for t time steps */
693       flops *= iters;
694 
695       printf("GFLOP  = %.5g\n", flops*1e-9/(double)iters);
696       printf("bp time = %.5g\n", ((double)(l_total/iters)));
697       printf("GFLOPS  = %.5g\n", (flops*1e-9)/l_total);
698 
699       printf("PERFDUMP,BP,%s,%i,%i,%i,%i,%i,%i,%i,%i,%.5g,%.5g\n", LIBXSMM_VERSION, nThreads, N, C, K, t, bn, bc, bk, ((double)(l_total/iters)), (flops*1e-9)/l_total);
700     }
701 
702     if ( pass == 2 ) {
703       printf("##########################################\n");
704       printf("#   Performance - UPD (custom-Storage)   #\n");
705       printf("##########################################\n");
706       /* run LIBXSMM GRU for performance */
707       l_start = libxsmm_timer_tick();
708 
709 #if defined(_OPENMP)
710 #     pragma omp parallel private(j)
711 #endif
712       {
713 #if defined(_OPENMP)
714         const int tid = omp_get_thread_num();
715 #else
716         const int tid = 0;
717 #endif
718         for (j = 0; j < iters; ++j) {
719           libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_UPD, 0, tid );
720         }
721       }
722       l_end = libxsmm_timer_tick();
723       l_total = libxsmm_timer_duration(l_start, l_end);
724       flops = K * N; /* d3 = djdh + d23 (delta) */
725       flops += 2.0 * K * N; /* d4 = (1 - z).d3 */
726       flops += K * N; /* d5 = d3.h */
727       flops += K * N; /* d6 = -d5 */
728       flops += K * N; /* d7 = d3.g */
729       flops += K * N; /* d8 = d3.z */
730       flops += K * N; /* d9 = d7 + d8 */
731       flops += 3.0 * K * N; /* d10 = d8.tanh'(g) */
732       flops += 3.0 * K * N; /* d11 = d9.sig'(z) */
733       flops += (2.0 * K * K * N + K * K); /* d13 = Wg^T * d10 (including transpose) */
734       flops += (2.0 * K * K * N + K * K); /* d15 = Wz^T * d11 (including transpose) */
735       flops += K * N; /* d16 = d13.z */
736       flops += K * N; /* d17 = d13.r */
737       flops += 3.0 * K * N; /* d18 = d16.sig'(r) */
738       flops += K * N; /* d19 = d17 + d4 */
739       flops += (2.0 * K * K * N + K * K); /* d21 = Wr^T * d18 (including transpose) */
740       flops += K * N; /* d22 = d21 + d15 */
741       flops += K * N; /* d23 = d19 + d22 */
742       flops += (2.0 * K * N * K + K * N + K * K); /* djdwr = djdwr + d18 * h^T */
743       flops += (2.0 * K * N * K + K * N + K * K); /* djdwz = djdwz + d11 * h^T */
744       flops += (2.0 * K * N * K + 2.0 * K * N + K * K); /* djdwg = djdwg + d10 * (h.r)^T */
745       flops += (2.0 * K * N * C + C * N + K * C); /* djdur = djdur + d18 * x^T */
746       flops += (2.0 * K * N * C + C * N + K * C); /* djduz = djduz + d11 * x^T */
747       flops += (2.0 * K * N * C + C * N + K * C); /* djdug = djdug + d10 * x^T */
748       flops += K * N; /* djdbr = djdbr + d18 */
749       flops += K * N; /* djdbz = djdbz + d11 */
750       flops += K * N; /* djdbg = djdbg + d10 */
751       flops *= t; /* for t time steps */
752       flops *= iters;
753 
754       printf("GFLOP  = %.5g\n", flops*1e-9/(double)iters);
755       printf("wu time = %.5g\n", ((double)(l_total/iters)));
756       printf("GFLOPS  = %.5g\n", (flops*1e-9)/l_total);
757 
758       printf("PERFDUMP,WU,%s,%i,%i,%i,%i,%i,%i,%i,%i,%.5g,%.5g\n", LIBXSMM_VERSION, nThreads, N, C, K, t, bn, bc, bk, ((double)(l_total/iters)), (flops*1e-9)/l_total);
759     }
760 
761     if ( pass == 3 ) {
762       printf("##########################################\n");
763       printf("# Performance - BWD+UPD (custom-Storage) #\n");
764       printf("##########################################\n");
765       /* run LIBXSMM GRU for performance */
766       l_start = libxsmm_timer_tick();
767 
768 #if defined(_OPENMP)
769 #     pragma omp parallel private(j)
770 #endif
771       {
772 #if defined(_OPENMP)
773         const int tid = omp_get_thread_num();
774 #else
775         const int tid = 0;
776 #endif
777         for (j = 0; j < iters; ++j) {
778           libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_BWDUPD, 0, tid );
779         }
780       }
781       l_end = libxsmm_timer_tick();
782       l_total = libxsmm_timer_duration(l_start, l_end);
783       flops = K * N; /* d3 = djdh + d23 (delta) */
784       flops += 2.0 * K * N; /* d4 = (1 - z).d3 */
785       flops += K * N; /* d5 = d3.h */
786       flops += K * N; /* d6 = -d5 */
787       flops += K * N; /* d7 = d3.g */
788       flops += K * N; /* d8 = d3.z */
789       flops += K * N; /* d9 = d7 + d8 */
790       flops += 3.0 * K * N; /* d10 = d8.tanh'(g) */
791       flops += 3.0 * K * N; /* d11 = d9.sig'(z) */
792       flops += (2.0 * K * K * N + K * K); /* d13 = Wg^T * d10 (including transpose) */
793       flops += (2.0 * K * K * N + K * K); /* d15 = Wz^T * d11 (including transpose) */
794       flops += K * N; /* d16 = d13.z */
795       flops += K * N; /* d17 = d13.r */
796       flops += 3.0 * K * N; /* d18 = d16.sig'(r) */
797       flops += K * N; /* d19 = d17 + d4 */
798       flops += (2.0 * K * K * N + K * K); /* d21 = Wr^T * d18 (including transpose) */
799       flops += K * N; /* d22 = d21 + d15 */
800       flops += K * N; /* d23 = d19 + d22 */
801       flops += (2.0 * K * C * N + K * C); /* d12 = Ug^T * d10 (including transpose) */
802       flops += (2.0 * K * C * N + K * C); /* d14 = Uz^T * d11 (including transpose) */
803       flops += (2.0 * K * C * N + K * C); /* d20 = Ur^T * d18 (including transpose) */
804       flops += 2.0 * K * N; /* djdx = d12 + d14 + d20 */
805       flops += (2.0 * K * N * K + K * N + K * K); /* djdwr = djdwr + d18 * h^T */
806       flops += (2.0 * K * N * K + K * N + K * K); /* djdwz = djdwz + d11 * h^T */
807       flops += (2.0 * K * N * K + 2.0 * K * N + K * K); /* djdwg = djdwg + d10 * (h.r)^T */
808       flops += (2.0 * K * N * C + C * N + K * C); /* djdur = djdur + d18 * x^T */
809       flops += (2.0 * K * N * C + C * N + K * C); /* djduz = djduz + d11 * x^T */
810       flops += (2.0 * K * N * C + C * N + K * C); /* djdug = djdug + d10 * x^T */
811       flops += K * N; /* djdbr = djdbr + d18 */
812       flops += K * N; /* djdbz = djdbz + d11 */
813       flops += K * N; /* djdbg = djdbg + d10 */
814       flops *= t; /* for t time steps */
815       flops *= iters;
816 
817       printf("GFLOP  = %.5g\n", flops*1e-9/(double)iters);
818       printf("bp+wu time = %.5g\n", ((double)(l_total/iters)));
819       printf("GFLOPS  = %.5g\n", (flops*1e-9)/l_total);
820 
821       printf("PERFDUMP,BP+WU,%s,%i,%i,%i,%i,%i,%i,%i,%i,%.5g,%.5g\n", LIBXSMM_VERSION, nThreads, N, C, K, t, bn, bc, bk, ((double)(l_total/iters)), (flops*1e-9)/l_total);
822     }
823 
    /* clean-up: release the scratch and internal-state bindings from the handle.
     * Pass 0 releases only the FWD-kind bindings; every other pass releases
     * LIBXSMM_DNN_COMPUTE_KIND_ALL (presumably mirroring how they were bound
     * earlier in main -- not visible in this chunk, confirm there).
     * Every call goes through CHKERR_LIBXSMM_DNN, which prints an error and
     * records it in global_status but does not abort, so clean-up continues. */
    if (pass == 0) {
      CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_scratch( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD ) );
      CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_internalstate( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD ) );
    } else {
      CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_scratch( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL ) );
      CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_internalstate( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL ) );
    }
    /* the driver owns the raw buffers; they are freed only after being released from the handle */
    libxsmm_free(scratch);
    libxsmm_free(internalstate);
    /* unbind every tensor from the handle: forward inputs/outputs, the four
     * internal-state tensors, then the gradient tensors */
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_INPUT ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_WEIGHT ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_BIAS ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_INTERNAL_I ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_INTERNAL_CI ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_INTERNAL_F ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_INTERNAL_O ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_INPUT ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_WEIGHT ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_BIAS ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE ) );
    /* destroy the (now unbound) tensor objects in the same order ... */
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_input ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_hidden_state_prev ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_weight ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_recur_weight ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_bias ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_hidden_state ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_i ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_c ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_f ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_o ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_dinput ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_dhidden_state_prev ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_dweight ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_drecur_weight ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_dbias ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_dhidden_state ) );
    /* ... and the RNN cell handle last */
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_rnncell( libxsmm_handle ) );
867   }
868 
869   /* deallocate data */
870   libxsmm_free(xgoldt);
871   libxsmm_free(hpgold);
872   libxsmm_free(wigold);
873   libxsmm_free(wcgold);
874   libxsmm_free(wfgold);
875   libxsmm_free(rigold);
876   libxsmm_free(rcgold);
877   libxsmm_free(rfgold);
878   libxsmm_free(bigold);
879   libxsmm_free(bcgold);
880   libxsmm_free(bfgold);
881   libxsmm_free(hgoldt);
882   libxsmm_free(igoldt);
883   libxsmm_free(cgoldt);
884   libxsmm_free(fgoldt);
885   libxsmm_free(ogoldt);
886   libxsmm_free(dxgoldt);
887   libxsmm_free(dhpgold);
888   libxsmm_free(dwgold);
889   libxsmm_free(drgold);
890   libxsmm_free(dbgold);
891   libxsmm_free(dhgoldt);
892   libxsmm_free(xt);
893   libxsmm_free(hp);
894   libxsmm_free(w);
895   libxsmm_free(r);
896   libxsmm_free(b);
897   libxsmm_free(ht);
898   libxsmm_free(it);
899   libxsmm_free(ct);
900   libxsmm_free(ft);
901   libxsmm_free(ot);
902   libxsmm_free(dxt);
903   libxsmm_free(dhp);
904   libxsmm_free(dw);
905   libxsmm_free(dr);
906   libxsmm_free(db);
907   libxsmm_free(dht);
908   libxsmm_free(dwtest);
909   libxsmm_free(drtest);
910 
911   { const char *const env_check_scale = getenv("CHECK_SCALE");
912     const double check_scale = LIBXSMM_ABS(0 == env_check_scale ? 1.0 : atof(env_check_scale));
913     if (LIBXSMM_NEQ(0, check) && (check < 100.0 * check_scale * diff.normf_rel) && (global_status == LIBXSMM_DNN_SUCCESS)) {
914       fprintf(stderr, "FAILED with an error of %f%%!\n", 100.0 * diff.normf_rel);
915       exit(EXIT_FAILURE);
916     }
917   }
918 
919   /* some empty lines at the end */
920   printf("\n\n\n");
921 
922   return global_status;
923 }
924 
925