1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved. *
3 * This file is part of the LIBXSMM library. *
4 * *
5 * For information on the license, see the LICENSE file. *
6 * Further information: https://github.com/hfp/libxsmm/ *
7 * SPDX-License-Identifier: BSD-3-Clause *
8 ******************************************************************************/
9 /* Kunal Banerjee (Intel Corp.)
10 ******************************************************************************/
11 #include <libxsmm.h>
12 #include <libxsmm_intrinsics_x86.h>
13
14 #if defined(LIBXSMM_OFFLOAD_TARGET)
15 # pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
16 #endif
17 #include <stdlib.h>
18 #include <string.h>
19 #include <stdio.h>
20 #include <math.h>
21 #if defined(_OPENMP)
22 # include <omp.h>
23 #endif
24 #if defined(LIBXSMM_OFFLOAD_TARGET)
25 # pragma offload_attribute(pop)
26 #endif
27
28 /* include c-based dnn library */
29 #include "../common/dnn_common.h"
30
/* Evaluate a LIBXSMM DNN status expression exactly once; on failure, print the
 * library's error string to stderr and record the failing code in the variable
 * `global_status` visible at the expansion site (declared in main()).
 * Wrapped in do { ... } while (0) so the macro behaves as a single statement
 * and remains safe inside unbraced if/else bodies (CERT PRE10-C). */
#define CHKERR_LIBXSMM_DNN(A) do { \
  const int chkerr_libxsmm_dnn_ = (A); \
  if (LIBXSMM_DNN_SUCCESS != chkerr_libxsmm_dnn_) { \
    fprintf(stderr, "%s\n", libxsmm_dnn_get_error(chkerr_libxsmm_dnn_)); \
    global_status = chkerr_libxsmm_dnn_; \
  } \
} while (0)
34
int main(int argc, char* argv[])
36 {
37 /* Arrays related to FWD pass */
38 float *wgold, *xgoldt, *ugold, *hpgold, *hgoldt, *z1gold, *z2gold, *zgoldt, *bgold, *bmgold;
39 float *w, *xt, *u, *hp, *ht, *htest, *b;
40 /* Arrays related to BWD and UPD pass */
41 float *djdhgoldt, *deltagoldt, *djdugold, *djdwgold, *djdxgoldt, *djdbgold;
42 float *zigold, *di1gold, *di2gold, *ugoldTp, *wgoldTp, *hgoldTp, *xgoldTp;
43 float *djdht, *djdu, *djdw, *djdxt, *djdb, *djdxtestt, *djdwtest, *djdutest;
44
45 const char transa = 'N', transb = 'N'; /* no transposes */
46 const float alpha = 1, beta = 1, beta0 = 0;
47 void *scratch, *internalstate;
48 size_t scratch_size = 0, internalstate_size = 0;
49
50 int iters = 10; /* repetitions of benchmark */
51 int pass = 0; /* pass: 0--FWD, 1--BWD, 2--UPD, 3--BWD+UPD */
52 int nonlin = 2; /* nonlin=1 denotes ReLU, 2 denotes sigmoid, 3 denotes tanh */
53 int N = 168; /* size of mini-batch */
54 int C = 512; /* number of inputs */
55 int K = 256; /* number of outputs */
56 int t = 4; /* number of time steps (>= 1) */
57 int bn = 24;
58 int bc = 64;
59 int bk = 64;
60
61 const char *const env_check = getenv("CHECK");
62 const double check = LIBXSMM_ABS(0 == env_check ? 1/*enable by default*/ : atof(env_check));
63
64 #if defined(_OPENMP)
65 int nThreads = omp_get_max_threads(); /* number of threads */
66 #else
67 int nThreads = 1; /* number of threads */
68 #endif
69
70 unsigned long long l_start, l_end;
71 double l_total = 0.0;
72 double flops = 0.0, tempflops = 0.0;
73 const double tflops = 12; /* transcendental flops */
74 int i, j, it;
75
76 libxsmm_dnn_rnncell_desc rnncell_desc;
77 libxsmm_dnn_rnncell* libxsmm_handle;
78 libxsmm_dnn_tensor* libxsmm_input;
79 libxsmm_dnn_tensor* libxsmm_hidden_state_prev;
80 libxsmm_dnn_tensor* libxsmm_weight;
81 libxsmm_dnn_tensor* libxsmm_recur_weight;
82 libxsmm_dnn_tensor* libxsmm_bias;
83 libxsmm_dnn_tensor* libxsmm_hidden_state;
84 libxsmm_dnn_tensor* libxsmm_dinput;
85 libxsmm_dnn_tensor* libxsmm_dweight;
86 libxsmm_dnn_tensor* libxsmm_drecur_weight;
87 libxsmm_dnn_tensor* libxsmm_dbias;
88 libxsmm_dnn_tensor* libxsmm_dhidden_state;
89
90 libxsmm_dnn_tensor_datalayout* libxsmm_layout;
91 libxsmm_dnn_err_t status;
92 libxsmm_dnn_err_t global_status = LIBXSMM_DNN_SUCCESS;
93
94 libxsmm_matdiff_info norms_fwd, norms_bwd, norms_upd_w, norms_upd_u, norms_upd_b, diff;
95 libxsmm_matdiff_clear(&norms_fwd);
96 libxsmm_matdiff_clear(&norms_bwd);
97 libxsmm_matdiff_clear(&norms_upd_w);
98 libxsmm_matdiff_clear(&norms_upd_u);
99 libxsmm_matdiff_clear(&norms_upd_b);
100 libxsmm_matdiff_clear(&diff);
101
102 if (argc > 1 && !strncmp(argv[1], "-h", 3)) {
103 printf("\nUsage: ./rnndriver [reps] [pass: 0--FWD, 1--BWD, 2--UPD, 3--BWD+UPD] [nonlin: 1--ReLU, 2--sigmoid, 3--tanh] [N] [C] [K] [time_steps > 0]\n\n");
104 return 0;
105 }
106 libxsmm_rng_set_seed(1);
107
108 /* reading new values from cli */
109 i = 1;
110 if (argc > i) iters = atoi(argv[i++]);
111 if (argc > i) pass = atoi(argv[i++]);
112 if (argc > i) nonlin= atoi(argv[i++]);
113 if (argc > i) N = atoi(argv[i++]);
114 if (argc > i) C = atoi(argv[i++]);
115 if (argc > i) K = atoi(argv[i++]);
116 if (argc > i) t = atoi(argv[i++]);
117 if (argc > i) bn = atoi(argv[i++]);
118 if (argc > i) bc = atoi(argv[i++]);
119 if (argc > i) bk = atoi(argv[i++]);
120
121 if (t <= 0) {
122 printf("time_steps %d should be greater than 0\n\n", t);
123 return 0;
124 }
125 if (!(pass == 0 || pass == 1 || pass == 2 || pass == 3 || pass == 4)) {
126 printf("Unknown pass: %d, valid arguments for pass = {0(FWD), 1(BWD), 2(UPD), 3(BWD+UPD)\n\n", pass);
127 return 0;
128 }
129 if (nonlin != 1 && nonlin != 2 && nonlin != 3) {
130 printf("Unsupported non-linear function used [1--ReLU, 2--sigmoid, 3--tanh]\n\n");
131 return 0;
132 }
133
134 #if defined(__SSE3__)
135 _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
136 _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
137 _MM_SET_ROUNDING_MODE(_MM_ROUND_NEAREST);
138 #endif
139
140 /* print some summary */
141 printf("##########################################\n");
142 printf("# Setting Up (Common) #\n");
143 printf("##########################################\n");
144 printf("PARAMS: N:%d C:%d K:%d T:%d\n", N, C, K, t);
145 printf("PARAMS: ITERS:%d", iters); if (LIBXSMM_FEQ(0, check)) printf(" Threads:%d\n", nThreads); else printf("\n");
146 printf("SIZE Weight (MB): %10.2f MiB\n", (double)(C*K*sizeof(float))/(1024.0*1024.0) );
147 printf("SIZE Input (MB): %10.2f MiB\n", (double)(N*C*sizeof(float))/(1024.0*1024.0) );
148 printf("SIZE Hidden State: %10.2f MiB\n", (double)(K*N*sizeof(float))/(1024.0*1024.0) );
149
150 /* allocate data */
151 xgoldt = (float*)libxsmm_aligned_malloc(N*C*t*sizeof(float), 2097152);
152 hpgold = (float*)libxsmm_aligned_malloc(K*N*sizeof(float), 2097152);
153 wgold = (float*)libxsmm_aligned_malloc(C*K*sizeof(float), 2097152);
154 ugold = (float*)libxsmm_aligned_malloc(K*K*sizeof(float), 2097152);
155 bgold = (float*)libxsmm_aligned_malloc(K*sizeof(float), 2097152);
156 hgoldt = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
157 zgoldt = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
158 bmgold = (float*)libxsmm_aligned_malloc(K*N*sizeof(float), 2097152);
159 z1gold = (float*)libxsmm_aligned_malloc(K*N*sizeof(float), 2097152);
160 z2gold = (float*)libxsmm_aligned_malloc(K*N*sizeof(float), 2097152);
161 djdxgoldt = (float*)libxsmm_aligned_malloc(N*C*t*sizeof(float), 2097152);
162 djdwgold = (float*)libxsmm_aligned_malloc(C*K*sizeof(float), 2097152);
163 djdugold = (float*)libxsmm_aligned_malloc(K*K*sizeof(float), 2097152);
164 djdbgold = (float*)libxsmm_aligned_malloc(K*sizeof(float), 2097152);
165 djdhgoldt = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
166 deltagoldt = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
167 zigold = (float*)libxsmm_aligned_malloc(K*N*sizeof(float), 2097152);
168 di1gold = (float*)libxsmm_aligned_malloc(K*N*sizeof(float), 2097152);
169 di2gold = (float*)libxsmm_aligned_malloc(K*N*sizeof(float), 2097152);
170 xgoldTp = (float*)libxsmm_aligned_malloc(N*C*sizeof(float), 2097152);
171 wgoldTp = (float*)libxsmm_aligned_malloc(C*K*sizeof(float), 2097152);
172 ugoldTp = (float*)libxsmm_aligned_malloc(K*K*sizeof(float), 2097152);
173 hgoldTp = (float*)libxsmm_aligned_malloc(K*N*sizeof(float), 2097152);
174 xt = (float*)libxsmm_aligned_malloc(N*C*t*sizeof(float), 2097152);
175 hp = (float*)libxsmm_aligned_malloc(K*N*sizeof(float), 2097152);
176 w = (float*)libxsmm_aligned_malloc(C*K*sizeof(float), 2097152);
177 u = (float*)libxsmm_aligned_malloc(K*K*sizeof(float), 2097152);
178 ht = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
179 b = (float*)libxsmm_aligned_malloc(K*sizeof(float), 2097152);
180 djdxt = (float*)libxsmm_aligned_malloc(N*C*t*sizeof(float), 2097152);
181 djdw = (float*)libxsmm_aligned_malloc(C*K*sizeof(float), 2097152);
182 djdu = (float*)libxsmm_aligned_malloc(K*K*sizeof(float), 2097152);
183 djdb = (float*)libxsmm_aligned_malloc(K*sizeof(float), 2097152);
184 djdht = (float*)libxsmm_aligned_malloc(K*N*t*sizeof(float), 2097152);
185 htest = (float*)libxsmm_aligned_malloc(K*N*sizeof(float), 2097152);
186 djdxtestt = (float*)libxsmm_aligned_malloc(N*C*t*sizeof(float), 2097152);
187 djdwtest = (float*)libxsmm_aligned_malloc(C*K*sizeof(float), 2097152);
188 djdutest = (float*)libxsmm_aligned_malloc(K*K*sizeof(float), 2097152);
189 LIBXSMM_VLA_DECL(2, float, xgold, xgoldt, N*C);
190 LIBXSMM_VLA_DECL(2, float, hgold, hgoldt, K*N);
191 LIBXSMM_VLA_DECL(2, float, zgold, zgoldt, K*N);
192 LIBXSMM_VLA_DECL(2, float, djdxgold, djdxgoldt, N*C);
193 LIBXSMM_VLA_DECL(2, float, djdhgold, djdhgoldt, K*N);
194 LIBXSMM_VLA_DECL(2, float, deltagold, deltagoldt, K*N);
195
196 /* initialize data */
197 /* All data in gold is considered to be in column-major format */
198 for (it = 0; it < t; ++it) {
199 init_buf(&LIBXSMM_VLA_ACCESS(2, xgold, it, 0, N*C), N*C, 0, 0);
200 }
201 init_buf(hpgold, N*K, 0, 0);
202 init_buf(wgold, C*K, 0, 0);
203 init_buf(ugold, K*K, 0, 0);
204 init_buf(bgold, K, 0, 0);
205 for (j = 0; j < N; j++) {
206 matrix_copy(K, bgold, &(bmgold[j*K]));
207 }
208 zero_buf(hgoldt, K*N*t);
209 zero_buf(zgoldt, K*N*t);
210 zero_buf(z1gold, K*N);
211 zero_buf(z2gold, K*N);
212 for (it = 0; it < t; ++it) {
213 init_buf(&LIBXSMM_VLA_ACCESS(2, djdhgold, it, 0, K*N), N*K, 0, 0);
214 }
215 zero_buf(djdxgoldt, N*C*t);
216 zero_buf(djdwgold, C*K);
217 zero_buf(djdugold, K*K);
218 zero_buf(djdbgold, K);
219 zero_buf(deltagoldt, K*N*t);
220 zero_buf(zigold, K*N);
221 zero_buf(di1gold, K*N);
222 zero_buf(di2gold, K*N);
223 zero_buf(xgoldTp, N*C);
224 zero_buf(ugoldTp, K*K);
225 zero_buf(wgoldTp, C*K);
226 zero_buf(hgoldTp, K*N);
227
228 /* first touch LIBXSMM */
229 zero_buf(xt, N*C*t);
230 zero_buf(hp, K*N);
231 zero_buf(w, C*K);
232 zero_buf(u, K*K);
233 zero_buf(b, K);
234 zero_buf(ht, K*N*t);
235 zero_buf(djdxt,N*C*t);
236 zero_buf(djdw, C*K);
237 zero_buf(djdu, K*K);
238 zero_buf(djdb, K);
239 zero_buf(djdht, K*N*t);
240 LIBXSMM_VLA_DECL(2, float, h, ht, K*N);
241
242 if (LIBXSMM_NEQ(0, check)) {
243 printf("##########################################\n");
244 printf("# Computing Reference ... #\n");
245 printf("##########################################\n");
246 for (i = 0; i < t; ++i) {
247 LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &N, &C, &alpha, wgold, &K, &LIBXSMM_VLA_ACCESS(2, xgold, i, 0, N*C), &C, &beta0, z1gold, &K);
248 if (0 == i) {
249 LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &N, &K, &alpha, ugold, &K, hpgold, &K, &beta0, z2gold, &K);
250 } else {
251 LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &N, &K, &alpha, ugold, &K, &LIBXSMM_VLA_ACCESS(2, hgold, i-1, 0, K*N), &K, &beta0, z2gold, &K);
252 }
253 matrix_add(K*N, z1gold, z2gold, &LIBXSMM_VLA_ACCESS(2, zgold, i, 0, K*N));
254 matrix_add(K*N, &LIBXSMM_VLA_ACCESS(2, zgold, i, 0, K*N), bmgold, &LIBXSMM_VLA_ACCESS(2, zgold, i, 0, K*N));
255 if (1 == nonlin) {
256 matrix_relu(K*N, &LIBXSMM_VLA_ACCESS(2, zgold, i, 0, K*N), &LIBXSMM_VLA_ACCESS(2, hgold, i, 0, K*N));
257 } else if (2 == nonlin) {
258 matrix_sigmoid(K*N, &LIBXSMM_VLA_ACCESS(2, zgold, i, 0, K*N), &LIBXSMM_VLA_ACCESS(2, hgold, i, 0, K*N));
259 } else {
260 matrix_tanh(K*N, &LIBXSMM_VLA_ACCESS(2, zgold, i, 0, K*N), &LIBXSMM_VLA_ACCESS(2, hgold, i, 0, K*N));
261 }
262 }
263 /* Conceptually, delta iterates over 0 ... t-1, whereas, djdh and z iterates over 1 ... t */
264 /* Hence these have identical array indices */
265 if (1 == nonlin) {
266 matrix_relu_inverse(K*N, &LIBXSMM_VLA_ACCESS(2, zgold, t-1, 0, K*N), zigold);
267 } else if (2 == nonlin) {
268 matrix_sigmoid_inverse(K*N, &LIBXSMM_VLA_ACCESS(2, zgold, t-1, 0, K*N), zigold);
269 } else {
270 matrix_tanh_inverse(K*N, &LIBXSMM_VLA_ACCESS(2, zgold, t-1, 0, K*N), zigold);
271 }
272 matrix_eltwise_mult(K*N, zigold, &LIBXSMM_VLA_ACCESS(2, djdhgold, t-1, 0, K*N), &LIBXSMM_VLA_ACCESS(2, deltagold, t-1, 0, K*N));
273 matrix_transpose(K, K, ugold, ugoldTp);
274 for (i = t-2; i >= 0; --i) {
275 if (1 == nonlin) {
276 matrix_relu_inverse(K*N, &LIBXSMM_VLA_ACCESS(2, zgold, i, 0, K*N), zigold);
277 } else if (2 == nonlin) {
278 matrix_sigmoid_inverse(K*N, &LIBXSMM_VLA_ACCESS(2, zgold, i, 0, K*N), zigold);
279 } else {
280 matrix_tanh_inverse(K*N, &LIBXSMM_VLA_ACCESS(2, zgold, i, 0, K*N), zigold);
281 }
282 LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &N, &K, &alpha, ugoldTp, &K, &LIBXSMM_VLA_ACCESS(2, deltagold, i+1, 0, K*N), &K, &beta0, di1gold, &K);
283 matrix_add(K*N, &LIBXSMM_VLA_ACCESS(2, djdhgold, i, 0, K*N), di1gold, di2gold);
284 matrix_eltwise_mult(K*N, zigold, di2gold, &LIBXSMM_VLA_ACCESS(2, deltagold, i, 0, K*N));
285 }
286 if (pass == 1 || pass == 3) {
287 matrix_transpose(C, K, wgold, wgoldTp);
288 for (i = 0; i < t; ++i) {
289 LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &C, &N, &K, &alpha, wgoldTp, &C, &LIBXSMM_VLA_ACCESS(2, deltagold, i, 0, K*N), &K, &beta0, &LIBXSMM_VLA_ACCESS(2, djdxgold, i, 0, N*C), &C);
290 }
291 }
292 if (pass == 2 || pass == 3) {
293 for (i = 0; i < t; ++i) {
294 if (0 == i) {
295 matrix_transpose(N, K, hpgold, hgoldTp);
296 } else {
297 matrix_transpose(N, K, &LIBXSMM_VLA_ACCESS(2, hgold, i-1, 0, K*N), hgoldTp);
298 }
299 LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &K, &N, &alpha, &LIBXSMM_VLA_ACCESS(2, deltagold, i, 0, K*N), &K, hgoldTp, &N, &beta, djdugold, &K);
300 matrix_transpose(N, C, &LIBXSMM_VLA_ACCESS(2, xgold, i, 0, N*C), xgoldTp);
301 LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &C, &N, &alpha, &LIBXSMM_VLA_ACCESS(2, deltagold, i, 0, K*N), &K, xgoldTp, &N, &beta, djdwgold, &K);
302 for (j = 0; j < K*N; j++) {
303 djdbgold[j%K] += LIBXSMM_VLA_ACCESS(2, deltagold, i, j, K*N);
304 }
305 }
306 }
307 printf("##########################################\n");
308 printf("# Computing Reference ... done #\n");
309 printf("##########################################\n");
310 }
311
312 if (1 /* format == 'A' || format == 'L' */) {
313 printf("\n");
314 printf("##########################################\n");
315 printf("# Setting Up (custom-Storage) #\n");
316 printf("##########################################\n");
317
318 if ( N % bn != 0 ) {
319 bn = N;
320 }
321 if ( C % bc != 0 ) {
322 bc = C;
323 }
324 if ( K % bk != 0 ) {
325 bk = K;
326 }
327
328 /* setup LIBXSMM handle */
329 rnncell_desc.threads = nThreads;
330 rnncell_desc.N = N;
331 rnncell_desc.C = C;
332 rnncell_desc.K = K;
333 rnncell_desc.bn = bn;
334 rnncell_desc.bk = bk;
335 rnncell_desc.bc = bc;
336 rnncell_desc.max_T = t;
337
338 if ( nonlin == 1 ) {
339 rnncell_desc.cell_type = LIBXSMM_DNN_RNNCELL_RNN_RELU;
340 } else if ( nonlin == 2 ) {
341 rnncell_desc.cell_type = LIBXSMM_DNN_RNNCELL_RNN_SIGMOID;
342 } else if ( nonlin == 3 ) {
343 rnncell_desc.cell_type = LIBXSMM_DNN_RNNCELL_RNN_TANH;
344 } else {
345 /* should not happen */
346 }
347 rnncell_desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
348 rnncell_desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
349 rnncell_desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NC;
350 rnncell_desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_CK;
351
352 libxsmm_handle = libxsmm_dnn_create_rnncell( rnncell_desc, &status );
353 CHKERR_LIBXSMM_DNN( status );
354
355 /* setup LIBXSMM buffers and filter */
356 libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_INPUT, &status ); CHKERR_LIBXSMM_DNN( status );
357 libxsmm_input = libxsmm_dnn_link_tensor( libxsmm_layout, xt, &status ); CHKERR_LIBXSMM_DNN( status );
358 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
359
360 libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV, &status ); CHKERR_LIBXSMM_DNN( status );
361 libxsmm_hidden_state_prev = libxsmm_dnn_link_tensor( libxsmm_layout, hp, &status ); CHKERR_LIBXSMM_DNN( status );
362 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
363
364 libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_WEIGHT, &status ); CHKERR_LIBXSMM_DNN( status );
365 libxsmm_weight = libxsmm_dnn_link_tensor( libxsmm_layout, w, &status ); CHKERR_LIBXSMM_DNN( status );
366 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
367
368 libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT, &status ); CHKERR_LIBXSMM_DNN( status );
369 libxsmm_recur_weight = libxsmm_dnn_link_tensor( libxsmm_layout, u, &status ); CHKERR_LIBXSMM_DNN( status );
370 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
371
372 libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_BIAS, &status ); CHKERR_LIBXSMM_DNN( status );
373 libxsmm_bias = libxsmm_dnn_link_tensor( libxsmm_layout, b, &status ); CHKERR_LIBXSMM_DNN( status );
374 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
375
376 libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE, &status ); CHKERR_LIBXSMM_DNN( status );
377 libxsmm_hidden_state = libxsmm_dnn_link_tensor( libxsmm_layout, ht, &status ); CHKERR_LIBXSMM_DNN( status );
378 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
379
380 libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_INPUT, &status ); CHKERR_LIBXSMM_DNN( status );
381 libxsmm_dinput = libxsmm_dnn_link_tensor( libxsmm_layout, djdxt, &status ); CHKERR_LIBXSMM_DNN( status );
382 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
383
384 libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_WEIGHT, &status ); CHKERR_LIBXSMM_DNN( status );
385 libxsmm_dweight = libxsmm_dnn_link_tensor( libxsmm_layout, djdw, &status ); CHKERR_LIBXSMM_DNN( status );
386 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
387
388 libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT, &status ); CHKERR_LIBXSMM_DNN( status );
389 libxsmm_drecur_weight = libxsmm_dnn_link_tensor( libxsmm_layout, djdu, &status ); CHKERR_LIBXSMM_DNN( status );
390 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
391
392 libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_BIAS, &status ); CHKERR_LIBXSMM_DNN( status );
393 libxsmm_dbias = libxsmm_dnn_link_tensor( libxsmm_layout, djdb, &status ); CHKERR_LIBXSMM_DNN( status );
394 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
395
396 libxsmm_layout = libxsmm_dnn_rnncell_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE, &status ); CHKERR_LIBXSMM_DNN( status );
397 libxsmm_dhidden_state = libxsmm_dnn_link_tensor( libxsmm_layout, djdht, &status ); CHKERR_LIBXSMM_DNN( status );
398 libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
399
400 /* copy in data to LIBXSMM format */
401 matrix_copy( t*N*C, xgoldt, xt );
402 matrix_copy( K*N, hpgold, hp );
403 matrix_copy( C*K, wgold, w );
404 matrix_copy( K*K, ugold, u );
405 matrix_copy( K, bgold, b );
406 matrix_copy( t*K*N, djdhgoldt, djdht );
407
408 /* bind buffers and filter to handle */
409 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_input, LIBXSMM_DNN_RNN_REGULAR_INPUT ) );
410 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_hidden_state_prev, LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV ) );
411 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_weight, LIBXSMM_DNN_RNN_REGULAR_WEIGHT ) );
412 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_recur_weight, LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT ) );
413 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_bias, LIBXSMM_DNN_RNN_REGULAR_BIAS ) );
414 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_hidden_state, LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE ) );
415 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_dinput, LIBXSMM_DNN_RNN_GRADIENT_INPUT ) );
416 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_dweight, LIBXSMM_DNN_RNN_GRADIENT_WEIGHT ) );
417 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_drecur_weight, LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT ) );
418 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_dbias, LIBXSMM_DNN_RNN_GRADIENT_BIAS ) );
419 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_tensor( libxsmm_handle, libxsmm_dhidden_state, LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE ) );
420
421 /* let's allocate and bind scratch */
422 if (pass == 0) {
423 scratch_size = libxsmm_dnn_rnncell_get_scratch_size( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, &status );
424 CHKERR_LIBXSMM_DNN( status );
425 scratch = libxsmm_aligned_malloc( scratch_size, 2097152 );
426 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_scratch( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, scratch ) );
427 } else {
428 scratch_size = libxsmm_dnn_rnncell_get_scratch_size( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, &status );
429 CHKERR_LIBXSMM_DNN( status );
430 scratch = libxsmm_aligned_malloc( scratch_size, 2097152 );
431 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_scratch( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch ) );
432 }
433 zero_buf( (float*)scratch, scratch_size/4 );
434
435 /* let's allocate and bind internalstate */
436 if (pass == 0) {
437 internalstate_size = libxsmm_dnn_rnncell_get_internalstate_size( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, &status );
438 CHKERR_LIBXSMM_DNN( status );
439 internalstate = libxsmm_aligned_malloc( internalstate_size, 2097152 );
440 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_internalstate( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, internalstate ) );
441 } else {
442 internalstate_size = libxsmm_dnn_rnncell_get_internalstate_size( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, &status );
443 CHKERR_LIBXSMM_DNN( status );
444 internalstate = libxsmm_aligned_malloc( internalstate_size, 2097152 );
445 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_bind_internalstate( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, internalstate ) );
446 }
447 zero_buf( (float*)internalstate, internalstate_size/4 );
448
449 if ((pass == 0) && LIBXSMM_NEQ(0, check)) {
450 printf("##########################################\n");
451 printf("# Correctness - FWD (custom-Storage) #\n");
452 printf("##########################################\n");
453 /* run LIBXSMM RNN */
454 #if defined(_OPENMP)
455 # pragma omp parallel
456 #endif
457 {
458 #if defined(_OPENMP)
459 const int tid = omp_get_thread_num();
460 #else
461 const int tid = 0;
462 #endif
463 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, 0, tid ) );
464 }
465 matrix_copy( N*K, &LIBXSMM_VLA_ACCESS(2, h, t-1, 0, K*N), htest );
466
467 /* compare */
468 libxsmm_matdiff(&norms_fwd, LIBXSMM_DATATYPE_F32, K*N, 1, &LIBXSMM_VLA_ACCESS(2, hgold, t-1, 0, K*N), htest, 0, 0);
469 printf("L1 reference : %.25g\n", norms_fwd.l1_ref);
470 printf("L1 test : %.25g\n", norms_fwd.l1_tst);
471 printf("L2 abs.error : %.24f\n", norms_fwd.l2_abs);
472 printf("L2 rel.error : %.24f\n", norms_fwd.l2_rel);
473 printf("Linf abs.error: %.24f\n", norms_fwd.linf_abs);
474 printf("Linf rel.error: %.24f\n", norms_fwd.linf_rel);
475 printf("Check-norm : %.24f\n", norms_fwd.normf_rel);
476 libxsmm_matdiff_reduce(&diff, &norms_fwd);
477 } else {
478 /* We need to always run FWD pass once to populate zt, ht */
479 #if defined(_OPENMP)
480 # pragma omp parallel
481 #endif
482 {
483 #if defined(_OPENMP)
484 const int tid = omp_get_thread_num();
485 #else
486 const int tid = 0;
487 #endif
488 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, 0, tid ) );
489 }
490 }
491
492 if ( (pass == 1) && LIBXSMM_NEQ(0, check) ) {
493 printf("##########################################\n");
494 printf("# Correctness - BWD (custom-Storage) #\n");
495 printf("##########################################\n");
496 /* run LIBXSMM RNN */
497 #if defined(_OPENMP)
498 # pragma omp parallel
499 #endif
500 {
501 #if defined(_OPENMP)
502 const int tid = omp_get_thread_num();
503 #else
504 const int tid = 0;
505 #endif
506 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_BWD, 0, tid ) );
507 }
508
509 /* copy out data */
510 matrix_copy(N*C*t, djdxt, djdxtestt);
511
512 /* compare */
513 libxsmm_matdiff(&norms_bwd, LIBXSMM_DATATYPE_F32, N*C*t, 1, djdxgoldt, djdxtestt, 0, 0);
514 printf("L1 reference : %.25g\n", norms_bwd.l1_ref);
515 printf("L1 test : %.25g\n", norms_bwd.l1_tst);
516 printf("L2 abs.error : %.24f\n", norms_bwd.l2_abs);
517 printf("L2 rel.error : %.24f\n", norms_bwd.l2_rel);
518 printf("Linf abs.error: %.24f\n", norms_bwd.linf_abs);
519 printf("Linf rel.error: %.24f\n", norms_bwd.linf_rel);
520 printf("Check-norm : %.24f\n", norms_bwd.normf_rel);
521 libxsmm_matdiff_reduce(&diff, &norms_bwd);
522 }
523
524 if ( (pass == 2) && LIBXSMM_NEQ(0, check) ) {
525 printf("##########################################\n");
526 printf("# Correctness - UPD (custom-Storage) #\n");
527 printf("##########################################\n");
528 /* run LIBXSMM RNN */
529 #if defined(_OPENMP)
530 # pragma omp parallel
531 #endif
532 {
533 #if defined(_OPENMP)
534 const int tid = omp_get_thread_num();
535 #else
536 const int tid = 0;
537 #endif
538 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_UPD, 0, tid ) );
539 }
540
541 /* copy out data */
542 matrix_copy(C*K, djdw, djdwtest);
543 matrix_copy(K*K, djdu, djdutest);
544
545 /* compare */
546 libxsmm_matdiff(&norms_upd_w, LIBXSMM_DATATYPE_F32, C*K, 1, djdwgold, djdwtest, 0, 0);
547 printf("Delta weight\n");
548 printf("L1 reference : %.25g\n", norms_upd_w.l1_ref);
549 printf("L1 test : %.25g\n", norms_upd_w.l1_tst);
550 printf("L2 abs.error : %.24f\n", norms_upd_w.l2_abs);
551 printf("L2 rel.error : %.24f\n", norms_upd_w.l2_rel);
552 printf("Linf abs.error: %.24f\n", norms_upd_w.linf_abs);
553 printf("Linf rel.error: %.24f\n", norms_upd_w.linf_rel);
554 printf("Check-norm : %.24f\n", norms_upd_w.normf_rel);
555 libxsmm_matdiff_reduce(&diff, &norms_upd_w);
556
557 libxsmm_matdiff(&norms_upd_u, LIBXSMM_DATATYPE_F32, K*K, 1, djdugold, djdutest, 0, 0);
558 printf("Delta recurrent weight\n");
559 printf("L1 reference : %.25g\n", norms_upd_u.l1_ref);
560 printf("L1 test : %.25g\n", norms_upd_u.l1_tst);
561 printf("L2 abs.error : %.24f\n", norms_upd_u.l2_abs);
562 printf("L2 rel.error : %.24f\n", norms_upd_u.l2_rel);
563 printf("Linf abs.error: %.24f\n", norms_upd_u.linf_abs);
564 printf("Linf rel.error: %.24f\n", norms_upd_u.linf_rel);
565 printf("Check-norm : %.24f\n", norms_upd_u.normf_rel);
566 libxsmm_matdiff_reduce(&diff, &norms_upd_u);
567
568 libxsmm_matdiff(&norms_upd_b, LIBXSMM_DATATYPE_F32, K, 1, djdbgold, djdb, 0, 0);
569 printf("Delta bias\n");
570 printf("L1 reference : %.25g\n", norms_upd_b.l1_ref);
571 printf("L1 test : %.25g\n", norms_upd_b.l1_tst);
572 printf("L2 abs.error : %.24f\n", norms_upd_b.l2_abs);
573 printf("L2 rel.error : %.24f\n", norms_upd_b.l2_rel);
574 printf("Linf abs.error: %.24f\n", norms_upd_b.linf_abs);
575 printf("Linf rel.error: %.24f\n", norms_upd_b.linf_rel);
576 printf("Check-norm : %.24f\n", norms_upd_b.normf_rel);
577 libxsmm_matdiff_reduce(&diff, &norms_upd_b);
578 }
579
580 if ( (pass == 3) && LIBXSMM_NEQ(0, check) ) {
581 printf("##########################################\n");
582 printf("# Correctness - BWD+UPD (custom-Storage) #\n");
583 printf("##########################################\n");
584 /* run LIBXSMM RNN */
585 #if defined(_OPENMP)
586 # pragma omp parallel
587 #endif
588 {
589 #if defined(_OPENMP)
590 const int tid = omp_get_thread_num();
591 #else
592 const int tid = 0;
593 #endif
594 CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_BWDUPD, 0, tid ) );
595 }
596
597 /* copy out data */
598 matrix_copy(N*C*t, djdxt, djdxtestt);
599 matrix_copy(C*K, djdw, djdwtest);
600 matrix_copy(K*K, djdu, djdutest);
601
602 /* compare */
603 libxsmm_matdiff(&norms_bwd, LIBXSMM_DATATYPE_F32, N*C*t, 1, djdxgoldt, djdxtestt, 0, 0);
604 printf("Delta input\n");
605 printf("L1 reference : %.25g\n", norms_bwd.l1_ref);
606 printf("L1 test : %.25g\n", norms_bwd.l1_tst);
607 printf("L2 abs.error : %.24f\n", norms_bwd.l2_abs);
608 printf("L2 rel.error : %.24f\n", norms_bwd.l2_rel);
609 printf("Linf abs.error: %.24f\n", norms_bwd.linf_abs);
610 printf("Linf rel.error: %.24f\n", norms_bwd.linf_rel);
611 printf("Check-norm : %.24f\n", norms_bwd.normf_rel);
612 libxsmm_matdiff_reduce(&diff, &norms_bwd);
613
614 libxsmm_matdiff(&norms_upd_w, LIBXSMM_DATATYPE_F32, C*K, 1, djdwgold, djdwtest, 0, 0);
615 printf("Delta weight\n");
616 printf("L1 reference : %.25g\n", norms_upd_w.l1_ref);
617 printf("L1 test : %.25g\n", norms_upd_w.l1_tst);
618 printf("L2 abs.error : %.24f\n", norms_upd_w.l2_abs);
619 printf("L2 rel.error : %.24f\n", norms_upd_w.l2_rel);
620 printf("Linf abs.error: %.24f\n", norms_upd_w.linf_abs);
621 printf("Linf rel.error: %.24f\n", norms_upd_w.linf_rel);
622 printf("Check-norm : %.24f\n", norms_upd_w.normf_rel);
623 libxsmm_matdiff_reduce(&diff, &norms_upd_w);
624
625 libxsmm_matdiff(&norms_upd_u, LIBXSMM_DATATYPE_F32, K*K, 1, djdugold, djdutest, 0, 0);
626 printf("Delta recurrent weight\n");
627 printf("L1 reference : %.25g\n", norms_upd_u.l1_ref);
628 printf("L1 test : %.25g\n", norms_upd_u.l1_tst);
629 printf("L2 abs.error : %.24f\n", norms_upd_u.l2_abs);
630 printf("L2 rel.error : %.24f\n", norms_upd_u.l2_rel);
631 printf("Linf abs.error: %.24f\n", norms_upd_u.linf_abs);
632 printf("Linf rel.error: %.24f\n", norms_upd_u.linf_rel);
633 printf("Check-norm : %.24f\n", norms_upd_u.normf_rel);
634 libxsmm_matdiff_reduce(&diff, &norms_upd_u);
635
636 libxsmm_matdiff(&norms_upd_b, LIBXSMM_DATATYPE_F32, K, 1, djdbgold, djdb, 0, 0);
637 printf("Delta bias\n");
638 printf("L1 reference : %.25g\n", norms_upd_b.l1_ref);
639 printf("L1 test : %.25g\n", norms_upd_b.l1_tst);
640 printf("L2 abs.error : %.24f\n", norms_upd_b.l2_abs);
641 printf("L2 rel.error : %.24f\n", norms_upd_b.l2_rel);
642 printf("Linf abs.error: %.24f\n", norms_upd_b.linf_abs);
643 printf("Linf rel.error: %.24f\n", norms_upd_b.linf_rel);
644 printf("Check-norm : %.24f\n", norms_upd_b.normf_rel);
645 libxsmm_matdiff_reduce(&diff, &norms_upd_b);
646 }
647
  /* Timed performance run of the forward pass only (pass == 0). */
  if ( pass == 0 ) {
    printf("##########################################\n");
    printf("# Performance - FWD (custom-Storage) #\n");
    printf("##########################################\n");
    /* run LIBXSMM RNN for performance */
    l_start = libxsmm_timer_tick();

#if defined(_OPENMP)
#   pragma omp parallel private(i)
#endif
    {
#if defined(_OPENMP)
      /* each worker thread drives the cell with its own thread id */
      const int tid = omp_get_thread_num();
#else
      const int tid = 0;
#endif
      for (i = 0; i < iters; ++i) {
        libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, 0, tid );
      }
    }
    l_end = libxsmm_timer_tick();
    l_total = libxsmm_timer_duration(l_start, l_end);
    /* analytic flop model per time step: 2*K*N*C (W*x) + 2*K*N*K (U*h) + K*N (bias add)
       + tflops*K*N (activation) — presumably; mapping of terms to W*x/U*h follows the
       BWD/UPD comments below. Scaled by t steps and iters repetitions. */
    flops = ((2.0 * K*N*C) + (2.0 * K*N*K) + (K*N) + (tflops * K*N)) * (double)t * (double)iters;

    printf("GFLOP = %.5g\n", flops*1e-9/(double)iters);
    printf("fp time = %.5g\n", ((double)(l_total/iters)));
    printf("GFLOPS = %.5g\n", (flops*1e-9)/l_total);

    /* machine-parsable summary line: version, threads, problem sizes, time, rate */
    printf("PERFDUMP,FP,%s,%i,%i,%i,%i,%i,%.5g,%.5g\n", LIBXSMM_VERSION, nThreads, N, C, K, t, ((double)(l_total/iters)), (flops*1e-9)/l_total);
  }
678
  /* Timed performance run of the backward-data pass only (pass == 1). */
  if ( pass == 1 ) {
    printf("##########################################\n");
    printf("# Performance - BWD (custom-Storage) #\n");
    printf("##########################################\n");
    /* run LIBXSMM RNN for performance */
    l_start = libxsmm_timer_tick();

#if defined(_OPENMP)
#   pragma omp parallel private(i)
#endif
    {
#if defined(_OPENMP)
      const int tid = omp_get_thread_num();
#else
      const int tid = 0;
#endif
      for (i = 0; i < iters; ++i) {
        libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_BWD, 0, tid );
      }
    }
    l_end = libxsmm_timer_tick();
    l_total = libxsmm_timer_duration(l_start, l_end);
    /* analytic flop model of the backward-data recurrence, term by term */
    flops = K*K; /* U^T */
    flops += (2.0 * K*N*K); /* U^T * delta */
    flops += (K*N); /* dJdh + (U^T * delta) */
    flops += (tflops * K*N); /* sigma'(Z) */
    flops += (K*N); /* sigma'(Z) * (dJdh + (U^T * delta)) */
    flops *= t; /* for t time steps */
    tempflops = C*K; /* W^T */
    tempflops += (2.0 * K*N*C); /* W^T * delta */
    tempflops *= t; /* for t time steps of input */
    flops += tempflops;
    flops *= iters; /* total work over all timed repetitions */

    printf("GFLOP = %.5g\n", flops*1e-9/(double)iters);
    printf("bp time = %.5g\n", ((double)(l_total/iters)));
    printf("GFLOPS = %.5g\n", (flops*1e-9)/l_total);

    printf("PERFDUMP,BP,%s,%i,%i,%i,%i,%i,%.5g,%.5g\n", LIBXSMM_VERSION, nThreads, N, C, K, t, ((double)(l_total/iters)), (flops*1e-9)/l_total);
  }
719
  /* Timed performance run of the weight-update pass only (pass == 2). */
  if ( pass == 2 ) {
    printf("##########################################\n");
    printf("# Performance - UPD (custom-Storage) #\n");
    printf("##########################################\n");
    /* run LIBXSMM RNN for performance */
    l_start = libxsmm_timer_tick();

#if defined(_OPENMP)
#   pragma omp parallel private(i)
#endif
    {
#if defined(_OPENMP)
      const int tid = omp_get_thread_num();
#else
      const int tid = 0;
#endif
      for (i = 0; i < iters; ++i) {
        libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_UPD, 0, tid );
      }
    }
    l_end = libxsmm_timer_tick();
    l_total = libxsmm_timer_duration(l_start, l_end);
    /* delta recurrence (same terms as the BWD model above) ... */
    flops = K*K; /* U^T */
    flops += (2.0 * K*N*K); /* U^T * delta */
    flops += (K*N); /* dJdh + (U^T * delta) */
    flops += (tflops * K*N); /* sigma'(Z) */
    flops += (K*N); /* sigma'(Z) * (dJdh + (U^T * delta)) */
    flops *= t; /* for t time steps */
    /* ... plus dJdU accumulation ... */
    tempflops = K*N; /* h^T */
    tempflops += (2.0 * K*N*K); /* delta * h^T */
    tempflops *= t; /* for t time steps */
    tempflops += (K*K * (t-1)); /* for summation of dJdU */
    flops += tempflops;
    /* ... plus dJdW accumulation */
    tempflops = N*C; /* x^T */
    tempflops += (2.0 * K*N*C); /* delta * x^T */
    tempflops *= t; /* for t time steps */
    tempflops += (C*K * (t-1)); /* for summation of dJdW */
    flops += tempflops;
    flops *= iters; /* total work over all timed repetitions */

    printf("GFLOP = %.5g\n", flops*1e-9/(double)iters);
    printf("wu time = %.5g\n", ((double)(l_total/iters)));
    printf("GFLOPS = %.5g\n", (flops*1e-9)/l_total);

    printf("PERFDUMP,WU,%s,%i,%i,%i,%i,%i,%.5g,%.5g\n", LIBXSMM_VERSION, nThreads, N, C, K, t, ((double)(l_total/iters)), (flops*1e-9)/l_total);
  }
766
  /* Timed performance run of the fused backward-data + weight-update pass (pass == 3). */
  if ( pass == 3 ) {
    printf("##########################################\n");
    printf("# Performance - BWD+UPD (custom-Storage) #\n");
    printf("##########################################\n");
    /* run LIBXSMM RNN for performance */
    l_start = libxsmm_timer_tick();

#if defined(_OPENMP)
#   pragma omp parallel private(i)
#endif
    {
#if defined(_OPENMP)
      const int tid = omp_get_thread_num();
#else
      const int tid = 0;
#endif
      for (i = 0; i < iters; ++i) {
        libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_BWDUPD, 0, tid );
      }
    }
    l_end = libxsmm_timer_tick();
    l_total = libxsmm_timer_duration(l_start, l_end);
    /* flop model = BWD delta recurrence + dJdU + dJdW + dJdx terms (union of passes 1 and 2) */
    flops = K*K; /* U^T */
    flops += (2.0 * K*N*K); /* U^T * delta */
    flops += (K*N); /* dJdh + (U^T * delta) */
    flops += (tflops * K*N); /* sigma'(Z) */
    flops += (K*N); /* sigma'(Z) * (dJdh + (U^T * delta)) */
    flops *= t; /* for t time steps */
    tempflops = K*N; /* h^T */
    tempflops += (2.0 * K*N*K); /* delta * h^T */
    tempflops *= t; /* for t time steps */
    tempflops += (K*K * (t-1)); /* for summation of dJdU */
    flops += tempflops;
    tempflops = N*C; /* x^T */
    tempflops += (2.0 * K*N*C); /* delta * x^T */
    tempflops *= t; /* for t time steps */
    tempflops += (C*K * (t-1)); /* for summation of dJdW */
    flops += tempflops;
    tempflops = C*K; /* W^T */
    tempflops += (2.0 * K*N*C); /* W^T * delta */
    tempflops *= t; /* for t time steps of input */
    flops += tempflops;
    flops *= iters; /* total work over all timed repetitions */

    printf("GFLOP = %.5g\n", flops*1e-9/(double)iters);
    printf("bp+wu time = %.5g\n", ((double)(l_total/iters)));
    printf("GFLOPS = %.5g\n", (flops*1e-9)/l_total);

    printf("PERFDUMP,BP+WU,%s,%i,%i,%i,%i,%i,%.5g,%.5g\n", LIBXSMM_VERSION, nThreads, N, C, K, t, ((double)(l_total/iters)), (flops*1e-9)/l_total);
  }
817
  /* Timed performance run of the full training step: FWD followed by fused BWD+UPD (pass == 4). */
  if ( pass == 4 ) {
    printf("#############################################\n");
    printf("# Performance - FWD+BWD+UPD (nc-ck Storage) #\n");
    printf("#############################################\n");
    /* run LIBXSMM RNN for performance */
    l_start = libxsmm_timer_tick();

#if defined(_OPENMP)
#   pragma omp parallel private(i)
#endif
    {
#if defined(_OPENMP)
      const int tid = omp_get_thread_num();
#else
      const int tid = 0;
#endif
      for (i = 0; i < iters; ++i) {
        /* one full training iteration per loop trip: forward, then backward+update */
        libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, 0, tid );
        libxsmm_dnn_rnncell_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_BWDUPD, 0, tid );
      }
    }
    l_end = libxsmm_timer_tick();
    l_total = libxsmm_timer_duration(l_start, l_end);
    /* flop model: BWD+UPD terms (note: unlike passes 1-3, the explicit transpose
       terms K*K, K*N, N*C and C*K are not counted here) ... */
    flops = (2.0 * K*N*K); /* U^T * delta */
    flops += (K*N); /* dJdh + (U^T * delta) */
    flops += (tflops * K*N); /* sigma'(Z) */
    flops += (K*N); /* sigma'(Z) * (dJdh + (U^T * delta)) */
    flops *= t; /* for t time steps */
    tempflops = (2.0 * K*N*K); /* delta * h^T */
    tempflops *= t; /* for t time steps */
    tempflops += (K*K * (t-1)); /* for summation of dJdU */
    flops += tempflops;
    tempflops = (2.0 * K*N*C); /* delta * x^T */
    tempflops *= t; /* for t time steps */
    tempflops += (C*K * (t-1)); /* for summation of dJdW */
    flops += tempflops;
    tempflops = (2.0 * K*N*C); /* W^T * delta */
    tempflops *= t; /* for t time steps of input */
    flops += tempflops;
    flops *= iters;
    /* ... plus the FWD flops (same expression as the pass == 0 model above) */
    flops += ((2.0 * K*N*C) + (2.0 * K*N*K) + (K*N) + (tflops * K*N)) * (double)t * (double)iters;

    printf("GFLOP = %.5g\n", flops*1e-9/(double)iters);
    printf("fp+bp+wu time = %.5g\n", ((double)(l_total/iters)));
    printf("GFLOPS = %.5g\n", (flops*1e-9)/l_total);

    printf("PERFDUMP,FP+BP+WU,%s,%i,%i,%i,%i,%i,%.5g,%.5g\n", LIBXSMM_VERSION, nThreads, N, C, K, t, ((double)(l_total/iters)), (flops*1e-9)/l_total);
  }
866
  /* clean-up: release LIBXSMM-owned buffers and tensor bindings from the cell handle */
  if (pass == 0) {
    /* forward-only run: only FWD scratch/internal state was allocated */
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_scratch( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_internalstate( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD ) );
  } else {
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_scratch( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_internalstate( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL ) );
  }
  /* release_* only detaches the buffers from the handle; the memory itself is freed here */
  libxsmm_free(scratch);
  libxsmm_free(internalstate);
  /* unbind every tensor from the handle before destroying the tensor objects */
  CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_INPUT ) );
  CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV ) );
  CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_WEIGHT ) );
  CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT ) );
  CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_BIAS ) );
  CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE ) );
  CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_INPUT ) );
  CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_WEIGHT ) );
  CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT ) );
  CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_BIAS ) );
  CHKERR_LIBXSMM_DNN( libxsmm_dnn_rnncell_release_tensor( libxsmm_handle, LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE ) );
  /* destroy the tensor wrapper objects (the user data buffers are freed separately below) */
  CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_input ) );
  CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_hidden_state_prev ) );
  CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_weight ) );
  CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_recur_weight ) );
  CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_bias ) );
  CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_hidden_state ) );
  CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_dinput ) );
  CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_dweight ) );
  CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_drecur_weight ) );
  CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_dbias ) );
  CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_dhidden_state ) );
  CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_rnncell( libxsmm_handle ) );
  } /* NOTE(review): closes a scope opened before this chunk (handle-setup block) — confirm against full file */
901
  /* deallocate data: gold (reference) buffers first, then the LIBXSMM-side copies */
  libxsmm_free(xgoldt);
  libxsmm_free(hpgold);
  libxsmm_free(wgold);
  libxsmm_free(ugold);
  libxsmm_free(bgold);
  libxsmm_free(hgoldt);
  libxsmm_free(zgoldt);
  libxsmm_free(bmgold);
  libxsmm_free(z1gold);
  libxsmm_free(z2gold);
  libxsmm_free(djdxgoldt);
  libxsmm_free(djdwgold);
  libxsmm_free(djdugold);
  libxsmm_free(djdbgold);
  libxsmm_free(djdhgoldt);
  libxsmm_free(deltagoldt);
  libxsmm_free(zigold);
  libxsmm_free(di1gold);
  libxsmm_free(di2gold);
  libxsmm_free(xgoldTp);
  libxsmm_free(wgoldTp);
  libxsmm_free(ugoldTp);
  libxsmm_free(hgoldTp);
  libxsmm_free(xt);
  libxsmm_free(hp);
  libxsmm_free(w);
  libxsmm_free(u);
  libxsmm_free(b);
  libxsmm_free(ht);
  libxsmm_free(djdxt);
  libxsmm_free(djdw);
  libxsmm_free(djdu);
  libxsmm_free(djdb);
  libxsmm_free(djdht);
  libxsmm_free(htest);
  libxsmm_free(djdxtestt);
  libxsmm_free(djdwtest);
  libxsmm_free(djdutest);

  /* Final pass/fail gate: fail when the aggregated relative error exceeds the
     threshold; CHECK_SCALE (env var, default 1.0) loosens/tightens the bound.
     Skipped when an earlier LIBXSMM-DNN error already set global_status. */
  { const char *const env_check_scale = getenv("CHECK_SCALE");
    const double check_scale = LIBXSMM_ABS(0 == env_check_scale ? 1.0 : atof(env_check_scale));
    if (LIBXSMM_NEQ(0, check) && (check < 100.0 * check_scale * diff.normf_rel) && (global_status == LIBXSMM_DNN_SUCCESS)) {
      fprintf(stderr, "FAILED with an error of %f%%!\n", 100.0 * diff.normf_rel);
      exit(EXIT_FAILURE);
    }
  }
949
  /* some empty lines at the end */
  printf("\n\n\n");

  /* nonzero iff a LIBXSMM-DNN call failed earlier (set via CHKERR_LIBXSMM_DNN) */
  return global_status;
} /* end of main */
955
956