1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved. *
3 * This file is part of the LIBXSMM library. *
4 * *
5 * For information on the license, see the LICENSE file. *
6 * Further information: https://github.com/hfp/libxsmm/ *
7 * SPDX-License-Identifier: BSD-3-Clause *
8 ******************************************************************************/
9 /* Kunal Banerjee (Intel Corp.)
10 ******************************************************************************/
11 #include <libxsmm.h>
12 #include <libxsmm_intrinsics_x86.h>
13
14 #if defined(LIBXSMM_OFFLOAD_TARGET)
15 # pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
16 #endif
17 #include <stdlib.h>
18 #include <string.h>
19 #include <stdio.h>
20 #include <math.h>
21 #if defined(_OPENMP)
22 # include <omp.h>
23 #endif
24 #if defined(LIBXSMM_OFFLOAD_TARGET)
25 # pragma offload_attribute(pop)
26 #endif
27
28 /* include c-based dnn library */
29 #include "../common/dnn_common.h"
30
/* CHKERR_LIBXSMM_DNN(A): evaluate A (a LIBXSMM-DNN status value or a call
 * returning one) exactly once. If it differs from LIBXSMM_DNN_SUCCESS, print
 * libxsmm_dnn_get_error() for it to stderr and store the code in the caller's
 * variable `global_status`, which must be in scope at every use site.
 * Wrapped in do { } while (0) so that "CHKERR_LIBXSMM_DNN(x);" is a single
 * statement and stays correct inside an unbraced if/else.
 * NOTE(review): this macro is also used inside `#pragma omp parallel` regions,
 * where global_status is shared; concurrent failing threads then write it
 * without synchronization (a data race) -- consider an omp critical/atomic. */
#define CHKERR_LIBXSMM_DNN(A) do { \
  const int chkerr_libxsmm_dnn_ = (A); \
  if (LIBXSMM_DNN_SUCCESS != chkerr_libxsmm_dnn_) { \
    fprintf(stderr, "%s\n", libxsmm_dnn_get_error(chkerr_libxsmm_dnn_)); \
    global_status = chkerr_libxsmm_dnn_; \
  } \
} while (0)
34
main(int argc,char * argv[])35 int main(int argc, char* argv[])
36 {
37 float *wigold, *wcgold, *wfgold, *rigold, *rcgold, *rfgold, *bigold, *bcgold, *bfgold;
38 float *xgoldt, *hpgold, *hgoldt;
39 float *dwgold, *drgold, *dbgold;
40 float *dxgoldt, *dhpgold, *dhgoldt;
41 float *igoldt, *cgoldt, *fgoldt, *ogoldt;
42 float *xt, *hp, *w, *r, *b, *ht;
43 float *it, *ct, *ft, *ot;
44 float *dxt, *dhp, *dw, *dr, *db, *dht;
45 float *scratch_bu, *dwtest, *drtest;
46
47 void *scratch, *internalstate;
48 size_t scratch_size = 0, internalstate_size = 0;
49
50 int iters = 10; /* repetitions of benchmark */
51 int pass = 0; /* pass: 0--FWD, 1--BWD, 2--UPD, 3--BWD+UPD */
52 int N = 168; /* size of mini-batch */
53 int C = 512; /* number of inputs */
54 int K = 256; /* number of outputs */
55 int t = 50; /* number of time steps (>= 1) */
56 int bn = 24;
57 int bc = 64;
58 int bk = 64;
59
60 const char *const env_check = getenv("CHECK");
61 const double check = LIBXSMM_ABS(0 == env_check ? 0/*disabled by default*/ : atof(env_check));
62
63 #if defined(_OPENMP)
64 int nThreads = omp_get_max_threads(); /* number of threads */
65 #else
66 int nThreads = 1; /* number of threads */
67 #endif
68
69 unsigned long long l_start, l_end;
70 double l_total = 0.0;
71 double flops = 0.0;
72 const double tflops = 12; /* transcendental flops */
73 int j;
74
75 libxsmm_dnn_rnncell_desc grucell_desc;
76 libxsmm_dnn_rnncell* libxsmm_handle;
77 libxsmm_dnn_tensor* libxsmm_input;
78 libxsmm_dnn_tensor* libxsmm_hidden_state_prev;
79 libxsmm_dnn_tensor* libxsmm_weight;
80 libxsmm_dnn_tensor* libxsmm_recur_weight;
81 libxsmm_dnn_tensor* libxsmm_bias;
82 libxsmm_dnn_tensor* libxsmm_hidden_state;
83 libxsmm_dnn_tensor* libxsmm_i;
84 libxsmm_dnn_tensor* libxsmm_c;
85 libxsmm_dnn_tensor* libxsmm_f;
86 libxsmm_dnn_tensor* libxsmm_o;
87 libxsmm_dnn_tensor* libxsmm_dinput;
88 libxsmm_dnn_tensor* libxsmm_dhidden_state_prev;
89 libxsmm_dnn_tensor* libxsmm_dweight;
90 libxsmm_dnn_tensor* libxsmm_drecur_weight;
91 libxsmm_dnn_tensor* libxsmm_dbias;
92 libxsmm_dnn_tensor* libxsmm_dhidden_state;
93
94 libxsmm_dnn_tensor_datalayout* libxsmm_layout;
95 libxsmm_dnn_err_t status;
96 libxsmm_dnn_err_t global_status = LIBXSMM_DNN_SUCCESS;
97
98 libxsmm_matdiff_info norms_fwd, norms_bwd, norms_upd_w, norms_upd_r, norms_upd_b, diff;
99 libxsmm_matdiff_clear(&norms_fwd);
100 libxsmm_matdiff_clear(&norms_bwd);
101 libxsmm_matdiff_clear(&norms_upd_w);
102 libxsmm_matdiff_clear(&norms_upd_r);
103 libxsmm_matdiff_clear(&norms_upd_b);
104 libxsmm_matdiff_clear(&diff);
105
106 if (argc > 1 && !strncmp(argv[1], "-h", 3)) {
107 printf("\nUsage: ./grudriver [reps] [pass: 0--FWD, 1--BWD, 2--UPD, 3--BWD+UPD] [N] [C] [K] [time_steps > 0]\n\n");
108 return 0;
109 }
110 libxsmm_rng_set_seed(1);
111
112 /* reading new values from cli */
113 j = 1;
114 if (argc > j) iters = atoi(argv[j++]);
115 if (argc > j) pass = atoi(argv[j++]);
116 if (argc > j) N = atoi(argv[j++]);
117 if (argc > j) C = atoi(argv[j++]);
118 if (argc > j) K = atoi(argv[j++]);
119 if (argc > j) t = atoi(argv[j++]);
120 if (argc > j) bn = atoi(argv[j++]);
121 if (argc > j) bc = atoi(argv[j++]);
122 if (argc > j) bk = atoi(argv[j++]);
123
124 if (t <= 0) {
125 printf("time_steps %d should be greater than or equal to 1\n\n", t);
126 return 0;
127 }
128 if (!(pass == 0 || pass == 1 || pass == 2 || pass == 3)) {
129 printf("Unknown pass: %d, valid arguments for pass = {0(FWD), 1(BWD), 2(UPD), 3(BWD+UPD)\n\n", pass);
130 return 0;
131 }
132
133 #if defined(__SSE3__)
134 _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
135 _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
136 _MM_SET_ROUNDING_MODE(_MM_ROUND_NEAREST);
137 #endif
138
139 /* print some summary */
140 printf("##########################################\n");
141 printf("# Setting Up (Common) #\n");
142 printf("##########################################\n");
143 printf("PARAMS: N:%d C:%d K:%d T:%d\n", N, C, K, t);
144 printf("PARAMS: ITERS:%d", iters); if (LIBXSMM_FEQ(0, check)) printf(" Threads:%d\n", nThreads); else printf("\n");
145 printf("SIZE Weight (MB): %10.2f MiB\n", (double)(C*K*sizeof(float))/(1024.0*1024.0) );
146 printf("SIZE Input (MB): %10.2f MiB\n", (double)(N*C*sizeof(float))/(1024.0*1024.0) );
147 printf("SIZE Hidden State: %10.2f MiB\n", (double)(K*N*sizeof(float))/(1024.0*1024.0) );
148
149 /* allocate data */
150 xgoldt = (float*)libxsmm_aligned_malloc(N*C*t*sizeof(float), 2097152);
151 hpgold = (float*)libxsmm_aligned_malloc(K*N*sizeof(float), 2097152);
152 wigold = (float*)libxsmm_aligned_malloc(C*K*sizeof(float), 2097152);
153 wcgold = (float*)libxsmm_aligned_malloc(C*K*sizeof(float), 2097152);
154 wfgold = (float*)libxsmm_aligned_malloc(C*K*sizeof(float), 2097152);
155 rigold = (float*)libxsmm_aligned_malloc(K*K*sizeof(float), 2097152);
156 rcgold = (float*)libxsmm_aligned_malloc(K*K*sizeof(float), 2097152);
157 rfgold = (float*)libxsmm_aligned_malloc(K*K*sizeof(float), 2097152);
158 bigold = (float*)libxsmm_aligned_malloc(K*sizeof(float), 2097152);
159 bcgold = (float*)libxsmm_aligned_malloc(K*sizeof(float), 2097152);
160 bfgold = (float*)libxsmm_aligned_malloc(K*sizeof(float), 2097152);
161 hgoldt = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
162 igoldt = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
163 cgoldt = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
164 fgoldt = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
165 ogoldt = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
166 dxgoldt = (float*)libxsmm_aligned_malloc(N*C*t*sizeof(float), 2097152);
167 dhpgold = (float*)libxsmm_aligned_malloc(K*N*sizeof(float), 2097152);
168 dwgold = (float*)libxsmm_aligned_malloc(C*K*3*sizeof(float), 2097152);
169 drgold = (float*)libxsmm_aligned_malloc(K*K*3*sizeof(float), 2097152);
170 dbgold = (float*)libxsmm_aligned_malloc(K*3*sizeof(float), 2097152);
171 dhgoldt = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
172 scratch_bu = (float*)libxsmm_aligned_malloc(K*N*6*sizeof(float), 2097152);
173 xt = (float*)libxsmm_aligned_malloc(N*C*t*sizeof(float), 2097152);
174 hp = (float*)libxsmm_aligned_malloc(K*N*sizeof(float), 2097152);
175 w = (float*)libxsmm_aligned_malloc(C*K*3*sizeof(float), 2097152);
176 r = (float*)libxsmm_aligned_malloc(K*K*3*sizeof(float), 2097152);
177 b = (float*)libxsmm_aligned_malloc(K*3*sizeof(float), 2097152);
178 ht = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
179 it = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
180 ct = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
181 ft = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
182 ot = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
183 dxt = (float*)libxsmm_aligned_malloc(N*C*t*sizeof(float), 2097152);
184 dhp = (float*)libxsmm_aligned_malloc(K*N*sizeof(float), 2097152);
185 dw = (float*)libxsmm_aligned_malloc(C*K*3*sizeof(float), 2097152);
186 dr = (float*)libxsmm_aligned_malloc(K*K*3*sizeof(float), 2097152);
187 db = (float*)libxsmm_aligned_malloc(K*3*sizeof(float), 2097152);
188 dht = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
189 dwtest = (float*)libxsmm_aligned_malloc(C*K*3*sizeof(float), 2097152);
190 drtest = (float*)libxsmm_aligned_malloc(K*K*3*sizeof(float), 2097152);
191 LIBXSMM_VLA_DECL(2, float, xgold, xgoldt, N * C);
192 LIBXSMM_VLA_DECL(2, float, hgold, hgoldt, N * K);
193 /*LIBXSMM_VLA_DECL(2, float, igold, igoldt, N * K);*/
194 /*LIBXSMM_VLA_DECL(2, float, cgold, cgoldt, N * K);*/
195 /*LIBXSMM_VLA_DECL(2, float, fgold, fgoldt, N * K);*/
196 /*LIBXSMM_VLA_DECL(2, float, ogold, ogoldt, N * K);*/
197 /*LIBXSMM_VLA_DECL(2, float, dxgold, dxgoldt, N * C);*/
198 LIBXSMM_VLA_DECL(2, float, dhgold, dhgoldt, N * K);
199 LIBXSMM_VLA_DECL(2, float, h, ht, N * K);
200
201 /* initialize data */
202 /* FWD */
203 for (j = 0; j < t; ++j) {
204 LIBXSMM_MATINIT_OMP(float, 24, &LIBXSMM_VLA_ACCESS(2, xgold, j, 0, N * C), N, C, N, 1.0);
205 }
206 LIBXSMM_MATINIT_OMP(float, 24, hpgold, N, K, N, 1.0);
207 LIBXSMM_MATINIT_OMP(float, 42, wigold, C, K, C, 1.0);
208 LIBXSMM_MATINIT_OMP(float, 42, wcgold, C, K, C, 1.0);
209 LIBXSMM_MATINIT_OMP(float, 42, wfgold, C, K, C, 1.0);
210 LIBXSMM_MATINIT_OMP(float, 42, rigold, K, K, K, 1.0);
211 LIBXSMM_MATINIT_OMP(float, 42, rcgold, K, K, K, 1.0);
212 LIBXSMM_MATINIT_OMP(float, 42, rfgold, K, K, K, 1.0);
213 LIBXSMM_MATINIT_OMP(float, 24, bigold, 1, K, 1, 1.0);
214 LIBXSMM_MATINIT_OMP(float, 24, bcgold, 1, K, 1, 1.0);
215 LIBXSMM_MATINIT_OMP(float, 24, bfgold, 1, K, 1, 1.0);
216 zero_buf(hgoldt, N*K*t);
217 /* BWD/UPD */
218 for (j = 0; j < t; ++j) {
219 LIBXSMM_MATINIT_OMP(float, 24, &LIBXSMM_VLA_ACCESS(2, dhgold, j, 0, K * N), N, K, N, 1.0);
220 }
221 zero_buf(dxgoldt, N*C*t);
222 zero_buf(dhpgold, K*N);
223 zero_buf(dwgold, C*K*3);
224 zero_buf(drgold, K*K*3);
225 zero_buf(dbgold, K*3);
226
227 /* first touch LIBXSMM */
228 zero_buf(xt, N*C*t);
229 zero_buf(hp, K*N);
230 zero_buf(w, C*K*3);
231 zero_buf(r, K*K*3);
232 zero_buf(b, K*3);
233 zero_buf(ht, N*K*t);
234 zero_buf(it, K*N*t);
235 zero_buf(ct, K*N*t);
236 zero_buf(ft, K*N*t);
237 zero_buf(ot, K*N*t);
238 zero_buf(dxt, N*C*t);
239 zero_buf(dhp, K*N);
240 zero_buf(dw, C*K*3);
241 zero_buf(dr, K*K*3);
242 zero_buf(db, K*3);
243 zero_buf(dht, K*N*t);
244
245 if (LIBXSMM_NEQ(0, check)) {
246 printf("##########################################\n");
247 printf("# Computing Reference ... #\n");
248 printf("##########################################\n");
249
250 gru_ref_fwd( N, C, K, t,
251 wigold, wcgold, wfgold,
252 rigold, rcgold, rfgold,
253 bigold, bcgold, bfgold,
254 xgoldt, hpgold, hgoldt,
255 igoldt, cgoldt, fgoldt, ogoldt );
256
257 gru_ref_bwd_upd( N, C, K, t,
258 xgoldt, hpgold, hgoldt,
259 igoldt, cgoldt, fgoldt, ogoldt,
260 wigold, wcgold, wfgold,
261 rigold, rcgold, rfgold,
262 dhgoldt, dwgold, drgold, dbgold,
263 dxgoldt, dhpgold, scratch_bu );
264
265 printf("##########################################\n");
266 printf("# Computing Reference ... done #\n");
267 printf("##########################################\n");
268 }
269
270 if (1 /* format == 'A' || format == 'L' */) {
271 printf("\n");
272 printf("##########################################\n");
273 printf("# Setting Up (custom-Storage) #\n");
274 printf("##########################################\n");
275
276 /* setup LIBXSMM handle */
277 grucell_desc.threads = nThreads;
278 grucell_desc.N = N;
279 grucell_desc.C = C;
280 grucell_desc.K = K;
281 grucell_desc.max_T = t;
282 grucell_desc.bn = bn;
283 grucell_desc.bc = bc;
284 grucell_desc.bk = bk;
285 grucell_desc.cell_type = LIBXSMM_DNN_RNNCELL_GRU;
286 grucell_desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
287 grucell_desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
288 grucell_desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NC;
289 grucell_desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_CK;
290
291 libxsmm_handle = libxsmm_dnn_create_rnncell( grucell_desc, &status );
292 CHKERR_LIBXSMM_DNN( status );
293
294 /* setup LIBXSMM buffers and filter */
295 libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_INPUT, &status ); CHKERR_LIBXSMM_DNN( status );
296 libxsmm_input = libxsmm_dnn_link_tensor( libxsmm_layout, xt, &status ); CHKERR_LIBXSMM_DNN( status );
297 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
298
299 libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV, &status ); CHKERR_LIBXSMM_DNN( status );
300 libxsmm_hidden_state_prev = libxsmm_dnn_link_tensor( libxsmm_layout, hp, &status ); CHKERR_LIBXSMM_DNN( status );
301 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
302
303 libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_WEIGHT, &status ); CHKERR_LIBXSMM_DNN( status );
304 libxsmm_weight = libxsmm_dnn_link_tensor( libxsmm_layout, w, &status ); CHKERR_LIBXSMM_DNN( status );
305 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
306
307 libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT, &status ); CHKERR_LIBXSMM_DNN( status );
308 libxsmm_recur_weight = libxsmm_dnn_link_tensor( libxsmm_layout, r, &status ); CHKERR_LIBXSMM_DNN( status );
309 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
310
311 libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_BIAS, &status ); CHKERR_LIBXSMM_DNN( status );
312 libxsmm_bias = libxsmm_dnn_link_tensor( libxsmm_layout, b, &status ); CHKERR_LIBXSMM_DNN( status );
313 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
314
315 libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE, &status ); CHKERR_LIBXSMM_DNN( status );
316 libxsmm_hidden_state = libxsmm_dnn_link_tensor( libxsmm_layout, ht, &status ); CHKERR_LIBXSMM_DNN( status );
317 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
318
319 libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_INTERNAL_I, &status ); CHKERR_LIBXSMM_DNN( status );
320 libxsmm_i = libxsmm_dnn_link_tensor( libxsmm_layout, it, &status ); CHKERR_LIBXSMM_DNN( status );
321 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
322
323 libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_INTERNAL_CI, &status ); CHKERR_LIBXSMM_DNN( status );
324 libxsmm_c = libxsmm_dnn_link_tensor( libxsmm_layout, ct, &status ); CHKERR_LIBXSMM_DNN( status );
325 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
326
327 libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_INTERNAL_F, &status ); CHKERR_LIBXSMM_DNN( status );
328 libxsmm_f = libxsmm_dnn_link_tensor( libxsmm_layout, ft, &status ); CHKERR_LIBXSMM_DNN( status );
329 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
330
331 libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_INTERNAL_O, &status ); CHKERR_LIBXSMM_DNN( status );
332 libxsmm_o = libxsmm_dnn_link_tensor( libxsmm_layout, ot, &status ); CHKERR_LIBXSMM_DNN( status );
333 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
334
335 libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_INPUT, &status ); CHKERR_LIBXSMM_DNN( status );
336 libxsmm_dinput = libxsmm_dnn_link_tensor( libxsmm_layout, dxt, &status ); CHKERR_LIBXSMM_DNN( status );
337 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
338
339 libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV, &status ); CHKERR_LIBXSMM_DNN( status );
340 libxsmm_dhidden_state_prev = libxsmm_dnn_link_tensor( libxsmm_layout, dhp, &status ); CHKERR_LIBXSMM_DNN( status );
341 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
342
343 libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_WEIGHT, &status ); CHKERR_LIBXSMM_DNN( status );
344 libxsmm_dweight = libxsmm_dnn_link_tensor( libxsmm_layout, dw, &status ); CHKERR_LIBXSMM_DNN( status );
345 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
346
347 libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT, &status ); CHKERR_LIBXSMM_DNN( status );
348 libxsmm_drecur_weight = libxsmm_dnn_link_tensor( libxsmm_layout, dr, &status ); CHKERR_LIBXSMM_DNN( status );
349 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
350
351 libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_BIAS, &status ); CHKERR_LIBXSMM_DNN( status );
352 libxsmm_dbias = libxsmm_dnn_link_tensor( libxsmm_layout, db, &status ); CHKERR_LIBXSMM_DNN( status );
353 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
354
355 libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE, &status ); CHKERR_LIBXSMM_DNN( status );
356 libxsmm_dhidden_state = libxsmm_dnn_link_tensor( libxsmm_layout, dht, &status ); CHKERR_LIBXSMM_DNN( status );
357 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
358
359 /* copy in data to LIBXSMM format */
360 matrix_copy(N*C*t, xgoldt, xt);
361 matrix_copy(K*N, hpgold, hp);
362 convert_ck_c3k(C, K, wigold, w);
363 convert_ck_c3k(C, K, wcgold, &(w[K]));
364 convert_ck_c3k(C, K, wfgold, &(w[2*K]));
365 convert_ck_c3k(K, K, rigold, r);
366 convert_ck_c3k(K, K, rcgold, &(r[K]));
367 convert_ck_c3k(K, K, rfgold, &(r[2*K]));
368 matrix_copy(K, bigold, b);
369 matrix_copy(K, bcgold, &(b[K]));
370 matrix_copy(K, bfgold, &(b[2*K]));
371 matrix_copy(K*N*t, dhgoldt, dht);
372
373 /* bind buffers and filter to handle */
374 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_input, LIBXSMM_DNN_RNN_REGULAR_INPUT ) );
375 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_hidden_state_prev, LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV ) );
376 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_weight, LIBXSMM_DNN_RNN_REGULAR_WEIGHT ) );
377 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_recur_weight, LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT ) );
378 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_bias, LIBXSMM_DNN_RNN_REGULAR_BIAS ) );
379 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_hidden_state, LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE ) );
380 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_i, LIBXSMM_DNN_RNN_INTERNAL_I ) );
381 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_c, LIBXSMM_DNN_RNN_INTERNAL_CI ) );
382 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_f, LIBXSMM_DNN_RNN_INTERNAL_F ) );
383 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_o, LIBXSMM_DNN_RNN_INTERNAL_O ) );
384 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_dinput, LIBXSMM_DNN_RNN_GRADIENT_INPUT ) );
385 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_dhidden_state_prev, LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV ) );
386 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_dweight, LIBXSMM_DNN_RNN_GRADIENT_WEIGHT ) );
387 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_drecur_weight, LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT ) );
388 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_dbias, LIBXSMM_DNN_RNN_GRADIENT_BIAS ) );
389 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_dhidden_state, LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE ) );
390
391 /* let's allocate and bind scratch */
392 if (pass == 0) {
393 scratch_size = libxsmm_dnn_rnncell_get_scratch_size( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, &status );
394 CHKERR_LIBXSMM_DNN( status );
395 scratch = libxsmm_aligned_malloc( scratch_size, 2097152 );
396 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_scratch( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, scratch ) );
397 } else {
398 scratch_size = libxsmm_dnn_rnncell_get_scratch_size( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, &status );
399 CHKERR_LIBXSMM_DNN( status );
400 scratch = libxsmm_aligned_malloc( scratch_size, 2097152 );
401 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_scratch( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch ) );
402 }
403 zero_buf( (float*)scratch, scratch_size/4 );
404
405 /* let's allocate and bind internalstate */
406 if (pass == 0) {
407 internalstate_size = libxsmm_dnn_rnncell_get_internalstate_size( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, &status );
408 CHKERR_LIBXSMM_DNN( status );
409 internalstate = libxsmm_aligned_malloc( internalstate_size, 2097152 );
410 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_internalstate( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, internalstate ) );
411 } else {
412 internalstate_size = libxsmm_dnn_rnncell_get_internalstate_size( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, &status );
413 CHKERR_LIBXSMM_DNN( status );
414 internalstate = libxsmm_aligned_malloc( internalstate_size, 2097152 );
415 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_internalstate( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, internalstate ) );
416 }
417 zero_buf( (float*)internalstate, internalstate_size/4 );
418
419 if ((pass == 0) && LIBXSMM_NEQ(0, check)) {
420 printf("##########################################\n");
421 printf("# Correctness - FWD (custom-Storage) #\n");
422 printf("##########################################\n");
423 /* run LIBXSMM GRU */
424 #if defined(_OPENMP)
425 # pragma omp parallel
426 #endif
427 {
428 #if defined(_OPENMP)
429 const int tid = omp_get_thread_num();
430 #else
431 const int tid = 0;
432 #endif
433 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, 0, tid ) );
434 }
435
436 /* compare */
437 libxsmm_matdiff(&norms_fwd, LIBXSMM_DATATYPE_F32, K*N, 1, &LIBXSMM_VLA_ACCESS(2, hgold, t-1, 0, K * N), &LIBXSMM_VLA_ACCESS(2, h, t-1, 0, K * N), 0, 0);
438 printf("L1 reference : %.25g\n", norms_fwd.l1_ref);
439 printf("L1 test : %.25g\n", norms_fwd.l1_tst);
440 printf("L2 abs.error : %.24f\n", norms_fwd.l2_abs);
441 printf("L2 rel.error : %.24f\n", norms_fwd.l2_rel);
442 printf("Linf abs.error: %.24f\n", norms_fwd.linf_abs);
443 printf("Linf rel.error: %.24f\n", norms_fwd.linf_rel);
444 printf("Check-norm : %.24f\n", norms_fwd.normf_rel);
445 libxsmm_matdiff_reduce(&diff, &norms_fwd);
446 } else {
447 /* We need to always run FWD pass once to populate i, c, f, o, h */
448 #if defined(_OPENMP)
449 # pragma omp parallel
450 #endif
451 {
452 #if defined(_OPENMP)
453 const int tid = omp_get_thread_num();
454 #else
455 const int tid = 0;
456 #endif
457 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, 0, tid ) );
458 }
459 }
460
461 if ( (pass == 1) && LIBXSMM_NEQ(0, check) ) {
462 printf("##########################################\n");
463 printf("# Correctness - BWD (custom-Storage) #\n");
464 printf("##########################################\n");
465 /* run LIBXSMM GRU */
466 #if defined(_OPENMP)
467 # pragma omp parallel
468 #endif
469 {
470 #if defined(_OPENMP)
471 const int tid = omp_get_thread_num();
472 #else
473 const int tid = 0;
474 #endif
475 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_BWD, 0, tid ) );
476 }
477
478 /* compare */
479 libxsmm_matdiff(&norms_bwd, LIBXSMM_DATATYPE_F32, N*C*t, 1, dxgoldt, dxt, 0, 0);
480 printf("L1 reference : %.25g\n", norms_bwd.l1_ref);
481 printf("L1 test : %.25g\n", norms_bwd.l1_tst);
482 printf("L2 abs.error : %.24f\n", norms_bwd.l2_abs);
483 printf("L2 rel.error : %.24f\n", norms_bwd.l2_rel);
484 printf("Linf abs.error: %.24f\n", norms_bwd.linf_abs);
485 printf("Linf rel.error: %.24f\n", norms_bwd.linf_rel);
486 printf("Check-norm : %.24f\n", norms_bwd.normf_rel);
487 libxsmm_matdiff_reduce(&diff, &norms_bwd);
488 }
489
490 if ( (pass == 2) && LIBXSMM_NEQ(0, check) ) {
491 printf("##########################################\n");
492 printf("# Correctness - UPD (custom-Storage) #\n");
493 printf("##########################################\n");
494 /* run LIBXSMM GRU */
495 #if defined(_OPENMP)
496 # pragma omp parallel
497 #endif
498 {
499 #if defined(_OPENMP)
500 const int tid = omp_get_thread_num();
501 #else
502 const int tid = 0;
503 #endif
504 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_UPD, 0, tid ) );
505 }
506
507 convert_ck_c3k(C, K, &(dwgold[0]), &(dwtest[0]));
508 convert_ck_c3k(C, K, &(dwgold[C*K]), &(dwtest[K]));
509 convert_ck_c3k(C, K, &(dwgold[2*C*K]), &(dwtest[2*K]));
510 convert_ck_c3k(K, K, &(drgold[0]), &(drtest[0]));
511 convert_ck_c3k(K, K, &(drgold[K*K]), &(drtest[K]));
512 convert_ck_c3k(K, K, &(drgold[2*K*K]), &(drtest[2*K]));
513 /* compare */
514 libxsmm_matdiff(&norms_upd_w, LIBXSMM_DATATYPE_F32, C*K*3, 1, dwtest, dw, 0, 0);
515 printf("Delta weight\n");
516 printf("L1 reference : %.25g\n", norms_upd_w.l1_ref);
517 printf("L1 test : %.25g\n", norms_upd_w.l1_tst);
518 printf("L2 abs.error : %.24f\n", norms_upd_w.l2_abs);
519 printf("L2 rel.error : %.24f\n", norms_upd_w.l2_rel);
520 printf("Linf abs.error: %.24f\n", norms_upd_w.linf_abs);
521 printf("Linf rel.error: %.24f\n", norms_upd_w.linf_rel);
522 printf("Check-norm : %.24f\n", norms_upd_w.normf_rel);
523 libxsmm_matdiff_reduce(&diff, &norms_upd_w);
524
525 libxsmm_matdiff(&norms_upd_r, LIBXSMM_DATATYPE_F32, K*K*3, 1, drtest, dr, 0, 0);
526 printf("Delta recurrent weight\n");
527 printf("L1 reference : %.25g\n", norms_upd_r.l1_ref);
528 printf("L1 test : %.25g\n", norms_upd_r.l1_tst);
529 printf("L2 abs.error : %.24f\n", norms_upd_r.l2_abs);
530 printf("L2 rel.error : %.24f\n", norms_upd_r.l2_rel);
531 printf("Linf abs.error: %.24f\n", norms_upd_r.linf_abs);
532 printf("Linf rel.error: %.24f\n", norms_upd_r.linf_rel);
533 printf("Check-norm : %.24f\n", norms_upd_r.normf_rel);
534 libxsmm_matdiff_reduce(&diff, &norms_upd_r);
535
536 libxsmm_matdiff(&norms_upd_b, LIBXSMM_DATATYPE_F32, K*3, 1, dbgold, db, 0, 0);
537 printf("Delta bias\n");
538 printf("L1 reference : %.25g\n", norms_upd_b.l1_ref);
539 printf("L1 test : %.25g\n", norms_upd_b.l1_tst);
540 printf("L2 abs.error : %.24f\n", norms_upd_b.l2_abs);
541 printf("L2 rel.error : %.24f\n", norms_upd_b.l2_rel);
542 printf("Linf abs.error: %.24f\n", norms_upd_b.linf_abs);
543 printf("Linf rel.error: %.24f\n", norms_upd_b.linf_rel);
544 printf("Check-norm : %.24f\n", norms_upd_b.normf_rel);
545 libxsmm_matdiff_reduce(&diff, &norms_upd_b);
546 }
547
548 if ( (pass == 3) && LIBXSMM_NEQ(0, check) ) {
549 printf("##########################################\n");
550 printf("# Correctness - BWD+UPD (custom-Storage) #\n");
551 printf("##########################################\n");
552 /* run LIBXSMM GRU */
553 #if defined(_OPENMP)
554 # pragma omp parallel
555 #endif
556 {
557 #if defined(_OPENMP)
558 const int tid = omp_get_thread_num();
559 #else
560 const int tid = 0;
561 #endif
562 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_BWDUPD, 0, tid ) );
563 }
564
565 convert_ck_c3k(C, K, &(dwgold[0]), &(dwtest[0]));
566 convert_ck_c3k(C, K, &(dwgold[C*K]), &(dwtest[K]));
567 convert_ck_c3k(C, K, &(dwgold[2*C*K]), &(dwtest[2*K]));
568 convert_ck_c3k(K, K, &(drgold[0]), &(drtest[0]));
569 convert_ck_c3k(K, K, &(drgold[K*K]), &(drtest[K]));
570 convert_ck_c3k(K, K, &(drgold[2*K*K]), &(drtest[2*K]));
571 /* compare */
572 libxsmm_matdiff(&norms_bwd, LIBXSMM_DATATYPE_F32, N*C*t, 1, dxgoldt, dxt, 0, 0);
573 printf("Delta input\n");
574 printf("L1 reference : %.25g\n", norms_bwd.l1_ref);
575 printf("L1 test : %.25g\n", norms_bwd.l1_tst);
576 printf("L2 abs.error : %.24f\n", norms_bwd.l2_abs);
577 printf("L2 rel.error : %.24f\n", norms_bwd.l2_rel);
578 printf("Linf abs.error: %.24f\n", norms_bwd.linf_abs);
579 printf("Linf rel.error: %.24f\n", norms_bwd.linf_rel);
580 printf("Check-norm : %.24f\n", norms_bwd.normf_rel);
581 libxsmm_matdiff_reduce(&diff, &norms_bwd);
582
583 libxsmm_matdiff(&norms_upd_w, LIBXSMM_DATATYPE_F32, C*K*3, 1, dwtest, dw, 0, 0);
584 printf("Delta weight\n");
585 printf("L1 reference : %.25g\n", norms_upd_w.l1_ref);
586 printf("L1 test : %.25g\n", norms_upd_w.l1_tst);
587 printf("L2 abs.error : %.24f\n", norms_upd_w.l2_abs);
588 printf("L2 rel.error : %.24f\n", norms_upd_w.l2_rel);
589 printf("Linf abs.error: %.24f\n", norms_upd_w.linf_abs);
590 printf("Linf rel.error: %.24f\n", norms_upd_w.linf_rel);
591 printf("Check-norm : %.24f\n", norms_upd_w.normf_rel);
592 libxsmm_matdiff_reduce(&diff, &norms_upd_w);
593
594 libxsmm_matdiff(&norms_upd_r, LIBXSMM_DATATYPE_F32, K*K*3, 1, drtest, dr, 0, 0);
595 printf("Delta recurrent weight\n");
596 printf("L1 reference : %.25g\n", norms_upd_r.l1_ref);
597 printf("L1 test : %.25g\n", norms_upd_r.l1_tst);
598 printf("L2 abs.error : %.24f\n", norms_upd_r.l2_abs);
599 printf("L2 rel.error : %.24f\n", norms_upd_r.l2_rel);
600 printf("Linf abs.error: %.24f\n", norms_upd_r.linf_abs);
601 printf("Linf rel.error: %.24f\n", norms_upd_r.linf_rel);
602 printf("Check-norm : %.24f\n", norms_upd_r.normf_rel);
603 libxsmm_matdiff_reduce(&diff, &norms_upd_r);
604
605 libxsmm_matdiff(&norms_upd_b, LIBXSMM_DATATYPE_F32, K*3, 1, dbgold, db, 0, 0);
606 printf("Delta bias\n");
607 printf("L1 reference : %.25g\n", norms_upd_b.l1_ref);
608 printf("L1 test : %.25g\n", norms_upd_b.l1_tst);
609 printf("L2 abs.error : %.24f\n", norms_upd_b.l2_abs);
610 printf("L2 rel.error : %.24f\n", norms_upd_b.l2_rel);
611 printf("Linf abs.error: %.24f\n", norms_upd_b.linf_abs);
612 printf("Linf rel.error: %.24f\n", norms_upd_b.linf_rel);
613 printf("Check-norm : %.24f\n", norms_upd_b.normf_rel);
614 libxsmm_matdiff_reduce(&diff, &norms_upd_b);
615 }
616
617 if ( pass == 0 ) {
618 printf("##########################################\n");
619 printf("# Performance - FWD (custom-Storage) #\n");
620 printf("##########################################\n");
621 /* run LIBXSMM GRU for performance */
622 l_start = libxsmm_timer_tick();
623
624 #if defined(_OPENMP)
625 # pragma omp parallel private(j)
626 #endif
627 {
628 #if defined(_OPENMP)
629 const int tid = omp_get_thread_num();
630 #else
631 const int tid = 0;
632 #endif
633 for (j = 0; j < iters; ++j) {
634 libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, 0, tid );
635 }
636 }
637 l_end = libxsmm_timer_tick();
638 l_total = libxsmm_timer_duration(l_start, l_end);
639 flops = (((2.0 * K * N * C) + (2.0 * K * N * K) + (2.0 * K * N) + (tflops * K * N)) * 2.0 + (K * N) + (2.0 * K * N * C) + (2.0 * K * N * K) + (tflops * K * N) + 4.0 * (K * N)) * (double)t * (double)iters;
640
641 printf("GFLOP = %.5g\n", flops*1e-9/(double)iters);
642 printf("fp time = %.5g\n", ((double)(l_total/iters)));
643 printf("GFLOPS = %.5g\n", (flops*1e-9)/l_total);
644
645 printf("PERFDUMP,FP,%s,%i,%i,%i,%i,%i,%i,%i,%i,%.5g,%.5g\n", LIBXSMM_VERSION, nThreads, N, C, K, t, bn, bc, bk, ((double)(l_total/iters)), (flops*1e-9)/l_total);
646 }
647
648 if ( pass == 1 ) {
649 printf("##########################################\n");
650 printf("# Performance - BWD (custom-Storage) #\n");
651 printf("##########################################\n");
652 /* run LIBXSMM GRU for performance */
653 l_start = libxsmm_timer_tick();
654
655 #if defined(_OPENMP)
656 # pragma omp parallel private(j)
657 #endif
658 {
659 #if defined(_OPENMP)
660 const int tid = omp_get_thread_num();
661 #else
662 const int tid = 0;
663 #endif
664 for (j = 0; j < iters; ++j) {
665 libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_BWD, 0, tid );
666 }
667 }
668 l_end = libxsmm_timer_tick();
669 l_total = libxsmm_timer_duration(l_start, l_end);
670 flops = K * N; /* d3 = djdh + d23 (delta) */
671 flops += 2.0 * K * N; /* d4 = (1 - z).d3 */
672 flops += K * N; /* d5 = d3.h */
673 flops += K * N; /* d6 = -d5 */
674 flops += K * N; /* d7 = d3.g */
675 flops += K * N; /* d8 = d3.z */
676 flops += K * N; /* d9 = d7 + d8 */
677 flops += 3.0 * K * N; /* d10 = d8.tanh'(g) */
678 flops += 3.0 * K * N; /* d11 = d9.sig'(z) */
679 flops += (2.0 * K * K * N + K * K); /* d13 = Wg^T * d10 (including transpose) */
680 flops += (2.0 * K * K * N + K * K); /* d15 = Wz^T * d11 (including transpose) */
681 flops += K * N; /* d16 = d13.z */
682 flops += K * N; /* d17 = d13.r */
683 flops += 3.0 * K * N; /* d18 = d16.sig'(r) */
684 flops += K * N; /* d19 = d17 + d4 */
685 flops += (2.0 * K * K * N + K * K); /* d21 = Wr^T * d18 (including transpose) */
686 flops += K * N; /* d22 = d21 + d15 */
687 flops += K * N; /* d23 = d19 + d22 */
688 flops += (2.0 * K * C * N + K * C); /* d12 = Ug^T * d10 (including transpose) */
689 flops += (2.0 * K * C * N + K * C); /* d14 = Uz^T * d11 (including transpose) */
690 flops += (2.0 * K * C * N + K * C); /* d20 = Ur^T * d18 (including transpose) */
691 flops += 2.0 * K * N; /* djdx = d12 + d14 + d20 */
692 flops *= t; /* for t time steps */
693 flops *= iters;
694
695 printf("GFLOP = %.5g\n", flops*1e-9/(double)iters);
696 printf("bp time = %.5g\n", ((double)(l_total/iters)));
697 printf("GFLOPS = %.5g\n", (flops*1e-9)/l_total);
698
699 printf("PERFDUMP,BP,%s,%i,%i,%i,%i,%i,%i,%i,%i,%.5g,%.5g\n", LIBXSMM_VERSION, nThreads, N, C, K, t, bn, bc, bk, ((double)(l_total/iters)), (flops*1e-9)/l_total);
700 }
701
702 if ( pass == 2 ) {
703 printf("##########################################\n");
704 printf("# Performance - UPD (custom-Storage) #\n");
705 printf("##########################################\n");
706 /* run LIBXSMM GRU for performance */
707 l_start = libxsmm_timer_tick();
708
709 #if defined(_OPENMP)
710 # pragma omp parallel private(j)
711 #endif
712 {
713 #if defined(_OPENMP)
714 const int tid = omp_get_thread_num();
715 #else
716 const int tid = 0;
717 #endif
718 for (j = 0; j < iters; ++j) {
719 libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_UPD, 0, tid );
720 }
721 }
722 l_end = libxsmm_timer_tick();
723 l_total = libxsmm_timer_duration(l_start, l_end);
724 flops = K * N; /* d3 = djdh + d23 (delta) */
725 flops += 2.0 * K * N; /* d4 = (1 - z).d3 */
726 flops += K * N; /* d5 = d3.h */
727 flops += K * N; /* d6 = -d5 */
728 flops += K * N; /* d7 = d3.g */
729 flops += K * N; /* d8 = d3.z */
730 flops += K * N; /* d9 = d7 + d8 */
731 flops += 3.0 * K * N; /* d10 = d8.tanh'(g) */
732 flops += 3.0 * K * N; /* d11 = d9.sig'(z) */
733 flops += (2.0 * K * K * N + K * K); /* d13 = Wg^T * d10 (including transpose) */
734 flops += (2.0 * K * K * N + K * K); /* d15 = Wz^T * d11 (including transpose) */
735 flops += K * N; /* d16 = d13.z */
736 flops += K * N; /* d17 = d13.r */
737 flops += 3.0 * K * N; /* d18 = d16.sig'(r) */
738 flops += K * N; /* d19 = d17 + d4 */
739 flops += (2.0 * K * K * N + K * K); /* d21 = Wr^T * d18 (including transpose) */
740 flops += K * N; /* d22 = d21 + d15 */
741 flops += K * N; /* d23 = d19 + d22 */
742 flops += (2.0 * K * N * K + K * N + K * K); /* djdwr = djdwr + d18 * h^T */
743 flops += (2.0 * K * N * K + K * N + K * K); /* djdwz = djdwz + d11 * h^T */
744 flops += (2.0 * K * N * K + 2.0 * K * N + K * K); /* djdwg = djdwg + d10 * (h.r)^T */
745 flops += (2.0 * K * N * C + C * N + K * C); /* djdur = djdur + d18 * x^T */
746 flops += (2.0 * K * N * C + C * N + K * C); /* djduz = djduz + d11 * x^T */
747 flops += (2.0 * K * N * C + C * N + K * C); /* djdug = djdug + d10 * x^T */
748 flops += K * N; /* djdbr = djdbr + d18 */
749 flops += K * N; /* djdbz = djdbz + d11 */
750 flops += K * N; /* djdbg = djdbg + d10 */
751 flops *= t; /* for t time steps */
752 flops *= iters;
753
754 printf("GFLOP = %.5g\n", flops*1e-9/(double)iters);
755 printf("wu time = %.5g\n", ((double)(l_total/iters)));
756 printf("GFLOPS = %.5g\n", (flops*1e-9)/l_total);
757
758 printf("PERFDUMP,WU,%s,%i,%i,%i,%i,%i,%i,%i,%i,%.5g,%.5g\n", LIBXSMM_VERSION, nThreads, N, C, K, t, bn, bc, bk, ((double)(l_total/iters)), (flops*1e-9)/l_total);
759 }
760
761 if ( pass == 3 ) {
762 printf("##########################################\n");
763 printf("# Performance - BWD+UPD (custom-Storage) #\n");
764 printf("##########################################\n");
765 /* run LIBXSMM GRU for performance */
766 l_start = libxsmm_timer_tick();
767
768 #if defined(_OPENMP)
769 # pragma omp parallel private(j)
770 #endif
771 {
772 #if defined(_OPENMP)
773 const int tid = omp_get_thread_num();
774 #else
775 const int tid = 0;
776 #endif
777 for (j = 0; j < iters; ++j) {
778 libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_BWDUPD, 0, tid );
779 }
780 }
781 l_end = libxsmm_timer_tick();
782 l_total = libxsmm_timer_duration(l_start, l_end);
783 flops = K * N; /* d3 = djdh + d23 (delta) */
784 flops += 2.0 * K * N; /* d4 = (1 - z).d3 */
785 flops += K * N; /* d5 = d3.h */
786 flops += K * N; /* d6 = -d5 */
787 flops += K * N; /* d7 = d3.g */
788 flops += K * N; /* d8 = d3.z */
789 flops += K * N; /* d9 = d7 + d8 */
790 flops += 3.0 * K * N; /* d10 = d8.tanh'(g) */
791 flops += 3.0 * K * N; /* d11 = d9.sig'(z) */
792 flops += (2.0 * K * K * N + K * K); /* d13 = Wg^T * d10 (including transpose) */
793 flops += (2.0 * K * K * N + K * K); /* d15 = Wz^T * d11 (including transpose) */
794 flops += K * N; /* d16 = d13.z */
795 flops += K * N; /* d17 = d13.r */
796 flops += 3.0 * K * N; /* d18 = d16.sig'(r) */
797 flops += K * N; /* d19 = d17 + d4 */
798 flops += (2.0 * K * K * N + K * K); /* d21 = Wr^T * d18 (including transpose) */
799 flops += K * N; /* d22 = d21 + d15 */
800 flops += K * N; /* d23 = d19 + d22 */
801 flops += (2.0 * K * C * N + K * C); /* d12 = Ug^T * d10 (including transpose) */
802 flops += (2.0 * K * C * N + K * C); /* d14 = Uz^T * d11 (including transpose) */
803 flops += (2.0 * K * C * N + K * C); /* d20 = Ur^T * d18 (including transpose) */
804 flops += 2.0 * K * N; /* djdx = d12 + d14 + d20 */
805 flops += (2.0 * K * N * K + K * N + K * K); /* djdwr = djdwr + d18 * h^T */
806 flops += (2.0 * K * N * K + K * N + K * K); /* djdwz = djdwz + d11 * h^T */
807 flops += (2.0 * K * N * K + 2.0 * K * N + K * K); /* djdwg = djdwg + d10 * (h.r)^T */
808 flops += (2.0 * K * N * C + C * N + K * C); /* djdur = djdur + d18 * x^T */
809 flops += (2.0 * K * N * C + C * N + K * C); /* djduz = djduz + d11 * x^T */
810 flops += (2.0 * K * N * C + C * N + K * C); /* djdug = djdug + d10 * x^T */
811 flops += K * N; /* djdbr = djdbr + d18 */
812 flops += K * N; /* djdbz = djdbz + d11 */
813 flops += K * N; /* djdbg = djdbg + d10 */
814 flops *= t; /* for t time steps */
815 flops *= iters;
816
817 printf("GFLOP = %.5g\n", flops*1e-9/(double)iters);
818 printf("bp+wu time = %.5g\n", ((double)(l_total/iters)));
819 printf("GFLOPS = %.5g\n", (flops*1e-9)/l_total);
820
821 printf("PERFDUMP,BP+WU,%s,%i,%i,%i,%i,%i,%i,%i,%i,%.5g,%.5g\n", LIBXSMM_VERSION, nThreads, N, C, K, t, bn, bc, bk, ((double)(l_total/iters)), (flops*1e-9)/l_total);
822 }
823
824 /* clean-up */
825 if (pass == 0) {
826 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_scratch( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD ) );
827 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_internalstate( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD ) );
828 } else {
829 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_scratch( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL ) );
830 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_internalstate( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL ) );
831 }
832 libxsmm_free(scratch);
833 libxsmm_free(internalstate);
834 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_INPUT ) );
835 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV ) );
836 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_WEIGHT ) );
837 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT ) );
838 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_BIAS ) );
839 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE ) );
840 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_INTERNAL_I ) );
841 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_INTERNAL_CI ) );
842 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_INTERNAL_F ) );
843 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_INTERNAL_O ) );
844 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_INPUT ) );
845 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV ) );
846 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_WEIGHT ) );
847 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT ) );
848 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_BIAS ) );
849 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE ) );
850 CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_input ) );
851 CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_hidden_state_prev ) );
852 CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_weight ) );
853 CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_recur_weight ) );
854 CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_bias ) );
855 CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_hidden_state ) );
856 CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_i ) );
857 CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_c ) );
858 CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_f ) );
859 CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_o ) );
860 CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_dinput ) );
861 CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_dhidden_state_prev ) );
862 CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_dweight ) );
863 CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_drecur_weight ) );
864 CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_dbias ) );
865 CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_dhidden_state ) );
866 CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_rnncell( libxsmm_handle ) );
867 }
868
869 /* deallocate data */
870 libxsmm_free(xgoldt);
871 libxsmm_free(hpgold);
872 libxsmm_free(wigold);
873 libxsmm_free(wcgold);
874 libxsmm_free(wfgold);
875 libxsmm_free(rigold);
876 libxsmm_free(rcgold);
877 libxsmm_free(rfgold);
878 libxsmm_free(bigold);
879 libxsmm_free(bcgold);
880 libxsmm_free(bfgold);
881 libxsmm_free(hgoldt);
882 libxsmm_free(igoldt);
883 libxsmm_free(cgoldt);
884 libxsmm_free(fgoldt);
885 libxsmm_free(ogoldt);
886 libxsmm_free(dxgoldt);
887 libxsmm_free(dhpgold);
888 libxsmm_free(dwgold);
889 libxsmm_free(drgold);
890 libxsmm_free(dbgold);
891 libxsmm_free(dhgoldt);
892 libxsmm_free(xt);
893 libxsmm_free(hp);
894 libxsmm_free(w);
895 libxsmm_free(r);
896 libxsmm_free(b);
897 libxsmm_free(ht);
898 libxsmm_free(it);
899 libxsmm_free(ct);
900 libxsmm_free(ft);
901 libxsmm_free(ot);
902 libxsmm_free(dxt);
903 libxsmm_free(dhp);
904 libxsmm_free(dw);
905 libxsmm_free(dr);
906 libxsmm_free(db);
907 libxsmm_free(dht);
908 libxsmm_free(dwtest);
909 libxsmm_free(drtest);
910
911 { const char *const env_check_scale = getenv("CHECK_SCALE");
912 const double check_scale = LIBXSMM_ABS(0 == env_check_scale ? 1.0 : atof(env_check_scale));
913 if (LIBXSMM_NEQ(0, check) && (check < 100.0 * check_scale * diff.normf_rel) && (global_status == LIBXSMM_DNN_SUCCESS)) {
914 fprintf(stderr, "FAILED with an error of %f%%!\n", 100.0 * diff.normf_rel);
915 exit(EXIT_FAILURE);
916 }
917 }
918
919 /* some empty lines at the end */
920 printf("\n\n\n");
921
922 return global_status;
923 }
924
925