1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved. *
3 * This file is part of the LIBXSMM library. *
4 * *
5 * For information on the license, see the LICENSE file. *
6 * Further information: https://github.com/hfp/libxsmm/ *
7 * SPDX-License-Identifier: BSD-3-Clause *
8 ******************************************************************************/
9 /* Alexander Heinecke (Intel Corp.)
10 ******************************************************************************/
11
#include <libxsmm.h>
#include <libxsmm_intrinsics_x86.h>

#include <string.h>

#if defined(_OPENMP)
# include <omp.h>
#endif
18
/* Shape/geometry descriptor for the naive (reference) convolution code. */
typedef struct {
  int nImg;      /* minibatch size (number of images) */
  int nIfm;      /* number of input feature maps */
  int nOfm;      /* number of output feature maps */
  int ifhp;      /* input height incl. physical padding */
  int ifwp;      /* input width incl. physical padding */
  int ifh;       /* input height (unpadded) */
  int ifw;       /* input width (unpadded) */
  int ofhp;      /* output height incl. physical padding */
  int ofwp;      /* output width incl. physical padding */
  int ofh;       /* output height (unpadded) */
  int ofw;       /* output width (unpadded) */
  int pad_h;     /* logical padding, height */
  int pad_w;     /* logical padding, width */
  int pad_h_in;  /* physical input padding, height */
  int pad_w_in;  /* physical input padding, width */
  int pad_h_out; /* physical output padding, height */
  int pad_w_out; /* physical output padding, width */
  int kh;        /* kernel height */
  int kw;        /* kernel width */
  int stride_h;  /* stride, height */
  int stride_w;  /* stride, width */
} naive_conv_t;
42
/* Shape descriptor for the naive fused batch-normalization reference code. */
typedef struct {
  int N;         /* minibatch size */
  int C;         /* number of channels */
  int H;         /* spatial height */
  int W;         /* spatial width */
  int stride_h;  /* stride, height */
  int stride_w;  /* stride, width */
  int norm_type; /* 0: full batchnorm, 1: batch scaling only */
  int fuse_type; /* 0: nothing fused, 1: relu fused, 2: elementwise fused, 3: relu and elementwise fused */
} naive_fusedbatchnorm_t;
53
/* Shape descriptor for the naive fused group-normalization reference code. */
typedef struct {
  int N;        /* minibatch size */
  int C;        /* number of channels */
  int G;        /* number of groups */
  int H;        /* spatial height */
  int W;        /* spatial width */
  int stride_h; /* stride, height */
  int stride_w; /* stride, width */
  int fuse_type; /* 0: nothing fused, 1: relu fused, 2: elementwise fused, 3: relu and elementwise fused */
} naive_fusedgroupnorm_t;
64
/* Shape descriptor for the naive fully-connected-layer reference code. */
typedef struct {
  int N; /* minibatch size */
  int C; /* input channels/features */
  int K; /* output channels/features */
  int fuse_type; /* 0: nothing fused */
} naive_fullyconnected_t;
71
/* Shape descriptor for the naive pooling reference code. */
typedef struct {
  int N;        /* minibatch size */
  int C;        /* number of channels */
  int H;        /* input height */
  int W;        /* input width */
  int R;        /* pooling window height */
  int S;        /* pooling window width */
  int pad_h;    /* padding, height */
  int pad_w;    /* padding, width */
  int stride_h; /* stride, height */
  int stride_w; /* stride, width */
  int type;     /* pooling-type selector; NOTE(review): presumably max vs. avg — confirm against users */
} naive_pooling_t;
85
86 /* it's fine to alias in and out */
truncate_mask_fp32_bf16(float * in,float * out,unsigned int len)87 LIBXSMM_INLINE void truncate_mask_fp32_bf16(float* in, float* out, unsigned int len) {
88 unsigned int i = 0;
89
90 /* truncate buffer to bf16 */
91 for ( i = 0; i < len; ++i ) {
92 union libxsmm_bfloat16_hp t;
93
94 t.f = in[i];
95 t.i[0] = 0;
96 out[i] = t.f;
97 }
98 }
99
100 /* it's fine to alias in and out */
rnaz_mask_fp32_bf16(float * in,float * out,unsigned int len)101 LIBXSMM_INLINE void rnaz_mask_fp32_bf16(float* in, float* out, unsigned int len) {
102 unsigned int i = 0;
103
104 /* rnaz buffer to bf16 */
105 for ( i = 0; i < len; ++i ) {
106 unsigned int int_round = 0;
107 unsigned int do_round = 1;
108 const void *const ptr = &int_round;
109
110 int_round = *((unsigned int*)&(in[i]));
111
112 /* we don't round NaN and inf */
113 if ( (int_round & 0x7f800000) == 0x7f800000 ) {
114 do_round = 0;
115 }
116
117 /* perform round nearest tie away from zero */
118 if ( do_round != 0 ) {
119 int_round = int_round + 0x00008000;
120 }
121
122 /* chop bits to create BFP16 in FP32 */
123 int_round = int_round & 0xffff0000;
124
125 out[i] = *((float*)ptr);
126 }
127 }
128
129 /* it's fine to alias in and out */
rne_mask_fp32_bf16(float * in,float * out,unsigned int len)130 LIBXSMM_INLINE void rne_mask_fp32_bf16(float* in, float* out, unsigned int len) {
131 unsigned int i = 0;
132
133 /* rnaz buffer to bf16 */
134 for ( i = 0; i < len; ++i ) {
135 unsigned int int_round = 0;
136 unsigned int do_round = 1;
137 const void *const ptr = &int_round;
138
139 int_round = *((unsigned int*)&(in[i]));
140
141 /* we don't round NaN and inf */
142 if ( (int_round & 0x7f800000) == 0x7f800000 ) {
143 do_round = 0;
144 }
145
146 /* perform round nearest tie even */
147 if ( do_round != 0 ) {
148 unsigned int fixup = (int_round >> 16) & 1;
149 int_round = int_round + 0x00007fff + fixup;
150 }
151
152 /* chop bits to create BFP16 in FP32 */
153 int_round = int_round & 0xffff0000;
154
155 out[i] = *((float*)ptr);
156 }
157 }
158
zero_buf(float * buf,size_t size)159 LIBXSMM_INLINE void zero_buf(float* buf, size_t size) {
160 int i;
161 #if defined(_OPENMP)
162 LIBXSMM_OMP_VAR(i);
163 # pragma omp parallel for private(i)
164 #endif
165 for (i = 0; i < (int)size; ++i) {
166 buf[i] = 0.0f;
167 }
168 }
169
zero_buf_bf16(libxsmm_bfloat16 * buf,size_t size)170 LIBXSMM_INLINE void zero_buf_bf16(libxsmm_bfloat16* buf, size_t size) {
171 int i;
172 #if defined(_OPENMP)
173 # pragma omp parallel for private(i)
174 #endif
175 for (i = 0; i < (int)size; ++i) {
176 buf[i] = 0;
177 }
178 }
179
zero_buf_int16(short * buf,size_t size)180 LIBXSMM_INLINE void zero_buf_int16(short* buf, size_t size) {
181 int i;
182 #if defined(_OPENMP)
183 LIBXSMM_OMP_VAR(i);
184 # pragma omp parallel for private(i)
185 #endif
186 for (i = 0; i < (int)size; ++i) {
187 buf[i] = 0;
188 }
189 }
190
zero_buf_int32(int * buf,size_t size)191 LIBXSMM_INLINE void zero_buf_int32(int* buf, size_t size) {
192 int i;
193 #if defined(_OPENMP)
194 LIBXSMM_OMP_VAR(i);
195 # pragma omp parallel for private(i)
196 #endif
197 for (i = 0; i < (int)size; ++i) {
198 buf[i] = 0;
199 }
200 }
201
zero_buf_int8(char * buf,size_t size)202 LIBXSMM_INLINE void zero_buf_int8(char* buf, size_t size) {
203 int i;
204 #if defined(_OPENMP)
205 LIBXSMM_OMP_VAR(i);
206 # pragma omp parallel for private(i)
207 #endif
208 for (i = 0; i < (int)size; ++i) {
209 buf[i] = 0;
210 }
211 }
212
zero_buf_uint8(unsigned char * buf,size_t size)213 LIBXSMM_INLINE void zero_buf_uint8(unsigned char* buf, size_t size) {
214 int i;
215 #if defined(_OPENMP)
216 LIBXSMM_OMP_VAR(i);
217 # pragma omp parallel for private(i)
218 #endif
219 for (i = 0; i < (int)size; ++i) {
220 buf[i] = 0;
221 }
222 }
223
copy_buf(float * src,float * dst,size_t size)224 LIBXSMM_INLINE void copy_buf(float* src, float* dst, size_t size) {
225 int i;
226 #if defined(_OPENMP)
227 LIBXSMM_OMP_VAR(i);
228 # pragma omp parallel for private(i)
229 #endif
230 for (i = 0; i < (int)size; ++i) {
231 dst[i] = src[i];
232 }
233 }
234
copy_buf_int16(short * src,short * dst,size_t size)235 LIBXSMM_INLINE void copy_buf_int16(short* src, short* dst, size_t size) {
236 int i;
237 #if defined(_OPENMP)
238 LIBXSMM_OMP_VAR(i);
239 # pragma omp parallel for private(i)
240 #endif
241 for (i = 0; i < (int)size; ++i) {
242 dst[i] = src[i];
243 }
244 }
245
copy_buf_int8(char * src,char * dst,size_t size)246 LIBXSMM_INLINE void copy_buf_int8(char* src, char* dst, size_t size) {
247 int i;
248 #if defined(_OPENMP)
249 LIBXSMM_OMP_VAR(i);
250 # pragma omp parallel for private(i)
251 #endif
252 for (i = 0; i < (int)size; ++i) {
253 dst[i] = src[i];
254 }
255 }
256
copy_buf_uint8(unsigned char * src,unsigned char * dst,size_t size)257 LIBXSMM_INLINE void copy_buf_uint8(unsigned char* src, unsigned char* dst, size_t size) {
258 int i;
259 #if defined(_OPENMP)
260 LIBXSMM_OMP_VAR(i);
261 # pragma omp parallel for private(i)
262 #endif
263 for (i = 0; i < (int)size; ++i) {
264 dst[i] = src[i];
265 }
266 }
267
init_buf(float * buf,size_t size,int initPos,int initOne)268 LIBXSMM_INLINE void init_buf(float* buf, size_t size, int initPos, int initOne)
269 {
270 int i;
271 zero_buf(buf, size);
272 #if defined(_OPENMP)
273 # pragma omp parallel for private(i)
274 #endif
275 for (i = 0; i < (int)size; ++i) {
276 buf[i] = (float)((initOne != 0) ? 1.0 : ((initPos != 0) ? libxsmm_rng_f64() : (0.05 - libxsmm_rng_f64()/10.0)));
277 }
278 }
279
init_buf_bf16(libxsmm_bfloat16 * buf,size_t size,int initPos,int initOne)280 LIBXSMM_INLINE void init_buf_bf16(libxsmm_bfloat16* buf, size_t size, int initPos, int initOne)
281 {
282 int i;
283 zero_buf_bf16(buf, size);
284 #if defined(_OPENMP)
285 # pragma omp parallel for private(i)
286 #endif
287 for (i = 0; i < (int)size; ++i) {
288 libxsmm_bfloat16_hp tmp;
289 tmp.f = (float)((initOne != 0) ? 1.0 : ((initPos != 0) ? libxsmm_rng_f64() : (0.05 - libxsmm_rng_f64()/10.0)));
290 buf[i] = tmp.i[1];
291 }
292 }
293
libxsmm_dnn_dequantize_int8(char * in_buffer,float * out_buffer,int length,unsigned char scf)294 LIBXSMM_INLINE void libxsmm_dnn_dequantize_int8( char* in_buffer, float* out_buffer, int length, unsigned char scf ) {
295 const float val_exp = libxsmm_sexp2_i8i(-scf);
296 int i = 0;
297 #ifdef _OPENMP
298 # pragma omp parallel for private(i)
299 #endif
300 for ( i = 0; i < length; ++i ) {
301 out_buffer[i] = ((float)in_buffer[i])*val_exp;
302 }
303 }
304
libxsmm_internal_get_max_common(float * in_buffer,int length)305 LIBXSMM_INLINE float libxsmm_internal_get_max_common( float* in_buffer, int length ) {
306 float absmax_value = LIBXSMM_ABS(in_buffer[0]);
307 int i = 0;
308 for (i = 1; i < length; ++i ) {
309 if (LIBXSMM_ABS(in_buffer[i]) > absmax_value) {
310 absmax_value = LIBXSMM_ABS(in_buffer[i]);
311 }
312 }
313 return absmax_value;
314 }
315
/* Quantizes an FP32 buffer to signed int8 using a power-of-two scale derived
 * from the buffer's absolute maximum: the exponent is chosen so the largest
 * magnitude fits in (7 - add_shift) bits. The negated exponent is written to
 * *scf for use by libxsmm_dnn_dequantize_int8. */
LIBXSMM_INLINE void quantize_buffer_char(float *in_buffer, char *out_buffer, int size, unsigned char add_shift, unsigned char* scf) {
  int i;
  const float max_value = libxsmm_internal_get_max_common(in_buffer, size);
  int maxexp = 0;
  /* take return value of LIBXSMM_FREXPF to mute static analysis issue */
  float scfq = LIBXSMM_FREXPF(max_value, &maxexp);
  /* shift the exponent so the max value maps into the int8 range */
  maxexp -= (7 - add_shift);
  scfq = libxsmm_sexp2_i8i(-maxexp);
  for (i=0; i<size; i++) {
    out_buffer[i] = (char)LIBXSMM_ROUNDF(in_buffer[i]*scfq);
  }
  *scf = (unsigned char)(-maxexp);
}
329
/* Quantizes an FP32 buffer to unsigned int8 using a power-of-two scale derived
 * from the buffer's absolute maximum (see quantize_buffer_char for the scheme);
 * the negated exponent is written to *scf.
 * NOTE(review): negative inputs wrap in the unsigned cast — presumably callers
 * only pass non-negative data; confirm at the call sites. */
LIBXSMM_INLINE void quantize_buffer_uchar(float *in_buffer, unsigned char *out_buffer, int size, unsigned char add_shift, unsigned char* scf) {
  int i;
  const float max_value = libxsmm_internal_get_max_common(in_buffer, size);
  int maxexp = 0;
  /* take return value of LIBXSMM_FREXPF to mute static analysis issue */
  float scfq = LIBXSMM_FREXPF(max_value, &maxexp);
  /* shift the exponent so the max value maps into the 8-bit range */
  maxexp -= (7 - add_shift);
  scfq = libxsmm_sexp2_i8i(-maxexp);
  for (i=0; i<size; i++) {
    out_buffer[i] = (unsigned char)LIBXSMM_ROUNDF(in_buffer[i]*scfq);
  }
  *scf = (unsigned char)(-maxexp);
}
343
init_buf_range(float * buf,size_t size,float low,float high)344 LIBXSMM_INLINE void init_buf_range(float* buf, size_t size, float low, float high)
345 {
346 int i;
347 float range = high - low;
348 zero_buf(buf, size);
349 for (i = 0; i < (int)size; ++i) {
350 buf[i] = (((float)rand())/RAND_MAX)*range+low;
351 }
352 }
353
init_buf_int16(short * buf,size_t size,int initPos,int initOne)354 LIBXSMM_INLINE void init_buf_int16(short* buf, size_t size, int initPos, int initOne)
355 {
356 int i;
357 zero_buf_int16(buf, size);
358 #if defined(_OPENMP)
359 # pragma omp parallel for private(i)
360 #endif
361 for (i = 0; i < (int)size; ++i) {
362 buf[i] = (short)((initOne != 0) ? 1 : ((initPos != 0) ? (rand()%7) : (rand()%7-3)));
363 }
364 }
365
init_buf_int32(int * buf,size_t size,int initPos,int initOne)366 LIBXSMM_INLINE void init_buf_int32(int* buf, size_t size, int initPos, int initOne)
367 {
368 int i;
369 zero_buf_int32(buf, size);
370 #if defined(_OPENMP)
371 # pragma omp parallel for private(i)
372 #endif
373 for (i = 0; i < (int)size; ++i) {
374 buf[i] = (int)((initOne != 0) ? 1 : ((initPos != 0) ? (rand()%7) : (rand()%7-3)));
375 }
376 }
377
init_buf_int8(char * buf,size_t size,int initPos,int initOne)378 LIBXSMM_INLINE void init_buf_int8(char* buf, size_t size, int initPos, int initOne)
379 {
380 int i;
381 zero_buf_int8(buf, size);
382 #if defined(_OPENMP)
383 # pragma omp parallel for private(i)
384 #endif
385 for (i = 0; i < (int)size; ++i) {
386 buf[i] = (char)((initOne != 0) ? 1 : ((initPos != 0) ? (rand()%3) : (rand()%3)-1));
387 }
388 }
389
init_buf_uint8(unsigned char * buf,size_t size,int initPos,int initOne)390 LIBXSMM_INLINE void init_buf_uint8(unsigned char* buf, size_t size, int initPos, int initOne)
391 {
392 int i;
393 LIBXSMM_UNUSED(initPos);
394 zero_buf_uint8(buf, size);
395 #if defined(_OPENMP)
396 # pragma omp parallel for private(i)
397 #endif
398 for (i = 0; i < (int)size; ++i) {
399 buf[i] = (unsigned char)((initOne != 0) ? 1 : (rand()%3));
400 }
401 }
402
set_zeropad_nchw(float * nchw,int N,int C,int H,int W,int pad_h,int pad_w)403 LIBXSMM_INLINE void set_zeropad_nchw(float* nchw, int N, int C, int H, int W, int pad_h, int pad_w)
404 {
405 LIBXSMM_VLA_DECL(4, float, input, nchw, C, H, W);
406 int n, h, w, c;
407
408 #if defined(_OPENMP)
409 LIBXSMM_OMP_VAR(c); LIBXSMM_OMP_VAR(h); LIBXSMM_OMP_VAR(w);
410 # pragma omp parallel for private(n,c,h,w)
411 #endif
412 for ( n = 0; n < N; n++ ) {
413 for ( c = 0; c < C; c++ ) {
414 for ( h = 0; h < H; h++ ) {
415 for ( w = 0; w < W; w++ ) {
416 if (h < pad_h || h >= H-pad_h || w < pad_w || w >= W-pad_w) {
417 LIBXSMM_VLA_ACCESS(4, input, n, c, h, w, C, H, W) = 0.0;
418 }
419 }
420 }
421 }
422 }
423 }
424
set_zeropad_nchw_int16(short * nchw,int N,int C,int H,int W,int pad_h,int pad_w)425 LIBXSMM_INLINE void set_zeropad_nchw_int16(short* nchw, int N, int C, int H, int W, int pad_h, int pad_w)
426 {
427 LIBXSMM_VLA_DECL(4, short, input, nchw, C, H, W);
428 int n, h, w, c;
429
430 #if defined(_OPENMP)
431 LIBXSMM_OMP_VAR(c); LIBXSMM_OMP_VAR(h); LIBXSMM_OMP_VAR(w);
432 # pragma omp parallel for private(n,c,h,w)
433 #endif
434 for ( n = 0; n < N; n++ ) {
435 for ( c = 0; c < C; c++ ) {
436 for ( h = 0; h < H; h++ ) {
437 for ( w = 0; w < W; w++ ) {
438 if (h < pad_h || h >= H-pad_h || w < pad_w || w >= W-pad_w) {
439 LIBXSMM_VLA_ACCESS(4, input, n, c, h, w, C, H, W) = 0;
440 }
441 }
442 }
443 }
444 }
445 }
446
set_zeropad_nchw_int32(int * nchw,int N,int C,int H,int W,int pad_h,int pad_w)447 LIBXSMM_INLINE void set_zeropad_nchw_int32(int* nchw, int N, int C, int H, int W, int pad_h, int pad_w)
448 {
449 LIBXSMM_VLA_DECL(4, int, input, nchw, C, H, W);
450 int n, h, w, c;
451
452 #if defined(_OPENMP)
453 LIBXSMM_OMP_VAR(c); LIBXSMM_OMP_VAR(h); LIBXSMM_OMP_VAR(w);
454 # pragma omp parallel for private(n,c,h,w)
455 #endif
456 for ( n = 0; n < N; n++ ) {
457 for ( c = 0; c < C; c++ ) {
458 for ( h = 0; h < H; h++ ) {
459 for ( w = 0; w < W; w++ ) {
460 if (h < pad_h || h >= H-pad_h || w < pad_w || w >= W-pad_w) {
461 LIBXSMM_VLA_ACCESS(4, input, n, c, h, w, C, H, W) = 0;
462 }
463 }
464 }
465 }
466 }
467 }
468
set_zeropad_nchw_uint8(unsigned char * nchw,int N,int C,int H,int W,int pad_h,int pad_w)469 LIBXSMM_INLINE void set_zeropad_nchw_uint8(unsigned char* nchw, int N, int C, int H, int W, int pad_h, int pad_w)
470 {
471 LIBXSMM_VLA_DECL(4, unsigned char, input, nchw, C, H, W);
472 int n, h, w, c;
473
474 #if defined(_OPENMP)
475 LIBXSMM_OMP_VAR(c); LIBXSMM_OMP_VAR(h); LIBXSMM_OMP_VAR(w);
476 # pragma omp parallel for private(n,c,h,w)
477 #endif
478 for ( n = 0; n < N; n++ ) {
479 for ( c = 0; c < C; c++ ) {
480 for ( h = 0; h < H; h++ ) {
481 for ( w = 0; w < W; w++ ) {
482 if (h < pad_h || h >= H-pad_h || w < pad_w || w >= W-pad_w) {
483 LIBXSMM_VLA_ACCESS(4, input, n, c, h, w, C, H, W) = 0;
484 }
485 }
486 }
487 }
488 }
489 }
490
copy_internal_nchw(float * dst,float * src,int N,int C,int H,int W,int pad_h,int pad_w)491 LIBXSMM_INLINE void copy_internal_nchw(float* dst , float* src, int N, int C, int H, int W, int pad_h, int pad_w)
492 {
493 LIBXSMM_VLA_DECL(4, float, input, src, C, H, W);
494 LIBXSMM_VLA_DECL(4, float, new_input, dst, C, H+2*pad_h, W+2*pad_w);
495 int n, h, w, c;
496
497 #if defined(_OPENMP)
498 LIBXSMM_OMP_VAR(c); LIBXSMM_OMP_VAR(h); LIBXSMM_OMP_VAR(w);
499 # pragma omp parallel for private(n,c,h,w)
500 #endif
501 for ( n = 0; n < N; n++ ) {
502 for ( c = 0; c < C; c++ ) {
503 for ( h = 0; h < H; h++ ) {
504 for ( w = 0; w < W; w++ ) {
505 LIBXSMM_VLA_ACCESS(4, new_input, n, c, h+pad_h, w+pad_w, C, H+2*pad_h, W+2*pad_w) = LIBXSMM_VLA_ACCESS(4, input, n, c, h, w, C, H, W);
506 }
507 }
508 }
509 }
510 }
511
copy_internal_nchw_int16(short * dst,short * src,int N,int C,int H,int W,int pad_h,int pad_w)512 LIBXSMM_INLINE void copy_internal_nchw_int16(short* dst , short* src, int N, int C, int H, int W, int pad_h, int pad_w)
513 {
514 LIBXSMM_VLA_DECL(4, short, input, src, C, H, W);
515 LIBXSMM_VLA_DECL(4, short, new_input, dst, C, H+2*pad_h, W+2*pad_w);
516 int n, h, w, c;
517
518 #if defined(_OPENMP)
519 LIBXSMM_OMP_VAR(c); LIBXSMM_OMP_VAR(h); LIBXSMM_OMP_VAR(w);
520 # pragma omp parallel for private(n,c,h,w)
521 #endif
522 for ( n = 0; n < N; n++ ) {
523 for ( c = 0; c < C; c++ ) {
524 for ( h = 0; h < H; h++ ) {
525 for ( w = 0; w < W; w++ ) {
526 LIBXSMM_VLA_ACCESS(4, new_input, n, c, h+pad_h, w+pad_w, C, H+2*pad_h, W+2*pad_w) = LIBXSMM_VLA_ACCESS(4, input, n, c, h, w, C, H, W);
527 }
528 }
529 }
530 }
531 }
532
copy_internal_nchw_uint8(unsigned char * dst,unsigned char * src,int N,int C,int H,int W,int pad_h,int pad_w)533 LIBXSMM_INLINE void copy_internal_nchw_uint8(unsigned char* dst , unsigned char* src, int N, int C, int H, int W, int pad_h, int pad_w)
534 {
535 LIBXSMM_VLA_DECL(4, unsigned char, input, src, C, H, W);
536 LIBXSMM_VLA_DECL(4, unsigned char, new_input, dst, C, H+2*pad_h, W+2*pad_w);
537 int n, h, w, c;
538
539 #if defined(_OPENMP)
540 LIBXSMM_OMP_VAR(c); LIBXSMM_OMP_VAR(h); LIBXSMM_OMP_VAR(w);
541 # pragma omp parallel for private(n,c,h,w)
542 #endif
543 for ( n = 0; n < N; n++ ) {
544 for ( c = 0; c < C; c++ ) {
545 for ( h = 0; h < H; h++ ) {
546 for ( w = 0; w < W; w++ ) {
547 LIBXSMM_VLA_ACCESS(4, new_input, n, c, h+pad_h, w+pad_w, C, H+2*pad_h, W+2*pad_w) = LIBXSMM_VLA_ACCESS(4, input, n, c, h, w, C, H, W);
548 }
549 }
550 }
551 }
552 }
553
naive_copy_NCHW_to_NHWC(const float * nchw,float * nhwc,int N,int H,int W,int C)554 LIBXSMM_INLINE void naive_copy_NCHW_to_NHWC(const float* nchw, float* nhwc, int N, int H, int W, int C)
555 {
556 LIBXSMM_VLA_DECL(4, float, output, nhwc, H, W, C);
557 LIBXSMM_VLA_DECL(4, const float, input, nchw, C, H, W);
558 int n, h, w, c;
559
560 #if defined(_OPENMP)
561 LIBXSMM_OMP_VAR(c); LIBXSMM_OMP_VAR(h); LIBXSMM_OMP_VAR(w);
562 # pragma omp parallel for private(n,c,h,w)
563 #endif
564 for ( n = 0; n < N; n++ ) {
565 for ( h = 0; h < H; h++ ) {
566 for ( w = 0; w < W; w++ ) {
567 for ( c = 0; c < C; c++ ) {
568 LIBXSMM_VLA_ACCESS(4, output, n, h, w, c, H, W, C) =
569 LIBXSMM_VLA_ACCESS(4, input, n, c, h, w, C, H, W);
570 }
571 }
572 }
573 }
574 }
575
naive_copy_NHWC_to_NCHW(const float * nhwc,float * nchw,int N,int H,int W,int C)576 LIBXSMM_INLINE void naive_copy_NHWC_to_NCHW(const float* nhwc, float* nchw, int N, int H, int W, int C)
577 {
578 LIBXSMM_VLA_DECL(4, float, output, nchw, C, H, W);
579 LIBXSMM_VLA_DECL(4, const float, input, nhwc, H, W, C);
580 int n, h, w, c;
581
582 #if defined(_OPENMP)
583 LIBXSMM_OMP_VAR(c); LIBXSMM_OMP_VAR(h); LIBXSMM_OMP_VAR(w);
584 # pragma omp parallel for private(n,c,h,w)
585 #endif
586 for ( n = 0; n < N; n++ ) {
587 for ( h = 0; h < H; h++ ) {
588 for ( w = 0; w < W; w++ ) {
589 for ( c = 0; c < C; c++ ) {
590 LIBXSMM_VLA_ACCESS(4, output, n, c, h, w, C, H, W) =
591 LIBXSMM_VLA_ACCESS(4, input, n, h, w, c, H, W, C);
592 }
593 }
594 }
595 }
596 }
597
naive_copy_KCRS_to_RSCK(const float * kcrs,float * rsck,int R,int S,int C,int K)598 LIBXSMM_INLINE void naive_copy_KCRS_to_RSCK(const float* kcrs, float* rsck, int R, int S, int C, int K)
599 {
600 LIBXSMM_VLA_DECL(4, float, output, rsck, S, C, K);
601 LIBXSMM_VLA_DECL(4, const float, input, kcrs, C, R, S);
602 int r, s, c, k;
603
604 #if defined(_OPENMP)
605 LIBXSMM_OMP_VAR(s); LIBXSMM_OMP_VAR(c); LIBXSMM_OMP_VAR(k);
606 # pragma omp parallel for private(r,s,c,k)
607 #endif
608 for ( r = 0; r < R; r++ ) {
609 for ( s = 0; s < S; s++ ) {
610 for ( c = 0; c < C; c++ ) {
611 for ( k = 0; k < K; k++ ) {
612 LIBXSMM_VLA_ACCESS(4, output, r, s, c, k, S, C, K) =
613 LIBXSMM_VLA_ACCESS(4, input, k, c, r, s, C, R, S);
614 }
615 }
616 }
617 }
618 }
619
620
naive_copy_RSCK_to_KCRS(const float * rsck,float * kcrs,int R,int S,int C,int K)621 LIBXSMM_INLINE void naive_copy_RSCK_to_KCRS(const float* rsck, float* kcrs, int R, int S, int C, int K)
622 {
623 LIBXSMM_VLA_DECL(4, const float, input, rsck, S, C, K);
624 LIBXSMM_VLA_DECL(4, float, output, kcrs, C, R, S);
625 int r, s, c, k;
626
627 #if defined(_OPENMP)
628 LIBXSMM_OMP_VAR(s); LIBXSMM_OMP_VAR(c); LIBXSMM_OMP_VAR(k);
629 # pragma omp parallel for private(r,s,c,k)
630 #endif
631 for ( r = 0; r < R; r++ ) {
632 for ( s = 0; s < S; s++ ) {
633 for ( c = 0; c < C; c++ ) {
634 for ( k = 0; k < K; k++ ) {
635 LIBXSMM_VLA_ACCESS(4, output, k, c, r, s, C, R, S) =
636 LIBXSMM_VLA_ACCESS(4, input, r, s, c, k, S, C, K);
637 }
638 }
639 }
640 }
641 }
642
matrix_copy_NC_to_NCNC(float * src,float * dst,int T,int N,int C,int bn,int bc)643 LIBXSMM_INLINE void matrix_copy_NC_to_NCNC(float *src, float *dst, int T, int N, int C, int bn, int bc)
644 {
645 int t, n1, n2, c1, c2;
646 int nBlocks = N/bn;
647 int cBlocks = C/bc;
648 LIBXSMM_VLA_DECL(3, float, real_src, src, N, C);
649 LIBXSMM_VLA_DECL(5, float, real_dst, dst, nBlocks, cBlocks, bn, bc);
650
651 #if defined(_OPENMP)
652 LIBXSMM_OMP_VAR(n1); LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(n2); LIBXSMM_OMP_VAR(c2);
653 # pragma omp parallel for private(t,n1,c1,n2,c2)
654 #endif
655 for (t = 0; t < T; t++) {
656 for (n1 = 0; n1 < nBlocks; n1++) {
657 for (c1 = 0; c1 < cBlocks; c1++) {
658 for (n2 = 0; n2 < bn; n2++) {
659 for (c2 = 0; c2 < bc; c2++) {
660 LIBXSMM_VLA_ACCESS(5, real_dst, t, n1, c1, n2, c2, nBlocks, cBlocks, bn, bc) =
661 LIBXSMM_VLA_ACCESS(3, real_src, t, n1*bn+n2, c1*bc+c2, N, C);
662 }
663 }
664 }
665 }
666 }
667 }
668
matrix_copy_NCNC_to_NC(float * src,float * dst,int T,int N,int C,int bn,int bc)669 LIBXSMM_INLINE void matrix_copy_NCNC_to_NC(float *src, float *dst, int T, int N, int C, int bn, int bc)
670 {
671 int t, n1, n2, c1, c2;
672 int nBlocks = N/bn;
673 int cBlocks = C/bc;
674 LIBXSMM_VLA_DECL(3, float, real_dst, dst, N, C);
675 LIBXSMM_VLA_DECL(5, float, real_src, src, nBlocks, cBlocks, bn, bc);
676
677 #if defined(_OPENMP)
678 LIBXSMM_OMP_VAR(n1); LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(n2); LIBXSMM_OMP_VAR(c2);
679 # pragma omp parallel for private(t,n1,c1,n2,c2)
680 #endif
681 for (t = 0; t < T; t++) {
682 for (n1 = 0; n1 < nBlocks; n1++) {
683 for (c1 = 0; c1 < cBlocks; c1++) {
684 for (n2 = 0; n2 < bn; n2++) {
685 for (c2 = 0; c2 < bc; c2++) {
686 LIBXSMM_VLA_ACCESS(3, real_dst, t, n1*bn+n2, c1*bc+c2, N, C) =
687 LIBXSMM_VLA_ACCESS(5, real_src, t, n1, c1, n2, c2, nBlocks, cBlocks, bn, bc);
688 }
689 }
690 }
691 }
692 }
693 }
694
matrix_copy_NC_to_NCNC_bf16(libxsmm_bfloat16 * src,libxsmm_bfloat16 * dst,int T,int N,int C,int bn,int bc)695 LIBXSMM_INLINE void matrix_copy_NC_to_NCNC_bf16(libxsmm_bfloat16 *src, libxsmm_bfloat16 *dst, int T, int N, int C, int bn, int bc)
696 {
697 int t, n1, n2, c1, c2;
698 int nBlocks = N/bn;
699 int cBlocks = C/bc;
700 LIBXSMM_VLA_DECL(3, libxsmm_bfloat16, real_src, src, N, C);
701 LIBXSMM_VLA_DECL(5, libxsmm_bfloat16, real_dst, dst, nBlocks, cBlocks, bn, bc);
702
703 #if defined(_OPENMP)
704 LIBXSMM_OMP_VAR(n1); LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(n2); LIBXSMM_OMP_VAR(c2);
705 # pragma omp parallel for private(t,n1,c1,n2,c2)
706 #endif
707 for (t = 0; t < T; t++) {
708 for (n1 = 0; n1 < nBlocks; n1++) {
709 for (c1 = 0; c1 < cBlocks; c1++) {
710 for (n2 = 0; n2 < bn; n2++) {
711 for (c2 = 0; c2 < bc; c2++) {
712 LIBXSMM_VLA_ACCESS(5, real_dst, t, n1, c1, n2, c2, nBlocks, cBlocks, bn, bc) =
713 LIBXSMM_VLA_ACCESS(3, real_src, t, n1*bn+n2, c1*bc+c2, N, C);
714 }
715 }
716 }
717 }
718 }
719 }
720
matrix_copy_NCNC_to_NC_bf16(libxsmm_bfloat16 * src,libxsmm_bfloat16 * dst,int T,int N,int C,int bn,int bc)721 LIBXSMM_INLINE void matrix_copy_NCNC_to_NC_bf16(libxsmm_bfloat16 *src, libxsmm_bfloat16 *dst, int T, int N, int C, int bn, int bc)
722 {
723 int t, n1, n2, c1, c2;
724 int nBlocks = N/bn;
725 int cBlocks = C/bc;
726 LIBXSMM_VLA_DECL(3, libxsmm_bfloat16, real_dst, dst, N, C);
727 LIBXSMM_VLA_DECL(5, libxsmm_bfloat16, real_src, src, nBlocks, cBlocks, bn, bc);
728
729 #if defined(_OPENMP)
730 LIBXSMM_OMP_VAR(n1); LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(n2); LIBXSMM_OMP_VAR(c2);
731 # pragma omp parallel for private(t,n1,c1,n2,c2)
732 #endif
733 for (t = 0; t < T; t++) {
734 for (n1 = 0; n1 < nBlocks; n1++) {
735 for (c1 = 0; c1 < cBlocks; c1++) {
736 for (n2 = 0; n2 < bn; n2++) {
737 for (c2 = 0; c2 < bc; c2++) {
738 LIBXSMM_VLA_ACCESS(3, real_dst, t, n1*bn+n2, c1*bc+c2, N, C) =
739 LIBXSMM_VLA_ACCESS(5, real_src, t, n1, c1, n2, c2, nBlocks, cBlocks, bn, bc);
740 }
741 }
742 }
743 }
744 }
745 }
746
matrix_copy_CK_to_KCCK(float * src,float * dst,int C,int K,int bc,int bk)747 LIBXSMM_INLINE void matrix_copy_CK_to_KCCK(float *src, float *dst, int C, int K, int bc, int bk)
748 {
749 int k1, k2, c1, c2;
750 int kBlocks = K/bk;
751 int cBlocks = C/bc;
752 LIBXSMM_VLA_DECL(2, float, real_src, src, K);
753 LIBXSMM_VLA_DECL(4, float, real_dst, dst, cBlocks, bc, bk);
754
755 #if defined(_OPENMP)
756 LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(c2); LIBXSMM_OMP_VAR(k2);
757 # pragma omp parallel for private(k1,c1,c2,k2)
758 #endif
759 for (k1 = 0; k1 < kBlocks; k1++) {
760 for (c1 = 0; c1 < cBlocks; c1++) {
761 for (c2 = 0; c2 < bc; c2++) {
762 for (k2 = 0; k2 < bk; k2++) {
763 LIBXSMM_VLA_ACCESS(4, real_dst, k1, c1, c2, k2, cBlocks, bc, bk) =
764 LIBXSMM_VLA_ACCESS(2, real_src, c1*bc+c2, k1*bk+k2, K);
765 }
766 }
767 }
768 }
769 }
770
matrix_copy_CK_to_CKKC(float * src,float * dst,int C,int K,int bc,int bk)771 LIBXSMM_INLINE void matrix_copy_CK_to_CKKC(float *src, float *dst, int C, int K, int bc, int bk)
772 {
773 int k1, k2, c1, c2;
774 int kBlocks = K/bk;
775 int cBlocks = C/bc;
776 LIBXSMM_VLA_DECL(2, float, real_src, src, K);
777 LIBXSMM_VLA_DECL(4, float, real_dst, dst, kBlocks, bk, bc);
778
779 #if defined(_OPENMP)
780 LIBXSMM_OMP_VAR(k1); LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(c2); LIBXSMM_OMP_VAR(k2);
781 # pragma omp parallel for private(k1,c1,c2,k2)
782 #endif
783 for (c1 = 0; c1 < cBlocks; c1++) {
784 for (k1 = 0; k1 < kBlocks; k1++) {
785 for (k2 = 0; k2 < bk; k2++) {
786 for (c2 = 0; c2 < bc; c2++) {
787 LIBXSMM_VLA_ACCESS(4, real_dst, c1, k1, k2, c2, kBlocks, bk, bc) =
788 LIBXSMM_VLA_ACCESS(2, real_src, c1*bc+c2, k1*bk+k2, K);
789 }
790 }
791 }
792 }
793 }
794
matrix_copy_KC_to_KCCK(float * src,float * dst,int C,int K,int bc,int bk)795 LIBXSMM_INLINE void matrix_copy_KC_to_KCCK(float *src, float *dst, int C, int K, int bc, int bk)
796 {
797 int k1, k2, c1, c2;
798 int kBlocks = K/bk;
799 int cBlocks = C/bc;
800 LIBXSMM_VLA_DECL(2, float, real_src, src, C);
801 LIBXSMM_VLA_DECL(4, float, real_dst, dst, cBlocks, bc, bk);
802
803 #if defined(_OPENMP)
804 LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(c2); LIBXSMM_OMP_VAR(k2);
805 # pragma omp parallel for private(k1,c1,c2,k2)
806 #endif
807 for (k1 = 0; k1 < kBlocks; k1++) {
808 for (c1 = 0; c1 < cBlocks; c1++) {
809 for (c2 = 0; c2 < bc; c2++) {
810 for (k2 = 0; k2 < bk; k2++) {
811 LIBXSMM_VLA_ACCESS(4, real_dst, k1, c1, c2, k2, cBlocks, bc, bk) =
812 LIBXSMM_VLA_ACCESS(2, real_src, k1*bk+k2, c1*bc+c2, C);
813 }
814 }
815 }
816 }
817 }
818
matrix_copy_KCCK_to_KC(float * src,float * dst,int C,int K,int bc,int bk)819 LIBXSMM_INLINE void matrix_copy_KCCK_to_KC(float *src, float *dst, int C, int K, int bc, int bk)
820 {
821 int k1, k2, c1, c2;
822 int kBlocks = K/bk;
823 int cBlocks = C/bc;
824 LIBXSMM_VLA_DECL(2, float, real_dst, dst, C);
825 LIBXSMM_VLA_DECL(4, float, real_src, src, cBlocks, bc, bk);
826
827 #if defined(_OPENMP)
828 LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(c2); LIBXSMM_OMP_VAR(k2);
829 # pragma omp parallel for private(k1,c1,c2,k2)
830 #endif
831 for (k1 = 0; k1 < kBlocks; k1++) {
832 for (c1 = 0; c1 < cBlocks; c1++) {
833 for (c2 = 0; c2 < bc; c2++) {
834 for (k2 = 0; k2 < bk; k2++) {
835 LIBXSMM_VLA_ACCESS(2, real_dst, k1*bk+k2, c1*bc+c2, C) =
836 LIBXSMM_VLA_ACCESS(4, real_src, k1, c1, c2, k2, cBlocks, bc, bk);
837 }
838 }
839 }
840 }
841 }
842
matrix_copy_KCCK_to_CK(float * src,float * dst,int C,int K,int bc,int bk)843 LIBXSMM_INLINE void matrix_copy_KCCK_to_CK(float *src, float *dst, int C, int K, int bc, int bk)
844 {
845 int k1, k2, c1, c2;
846 int kBlocks = K/bk;
847 int cBlocks = C/bc;
848 LIBXSMM_VLA_DECL(2, float, real_dst, dst, K);
849 LIBXSMM_VLA_DECL(4, float, real_src, src, cBlocks, bc, bk);
850
851 #if defined(_OPENMP)
852 LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(c2); LIBXSMM_OMP_VAR(k2);
853 # pragma omp parallel for private(k1,c1,c2,k2)
854 #endif
855 for (k1 = 0; k1 < kBlocks; k1++) {
856 for (c1 = 0; c1 < cBlocks; c1++) {
857 for (c2 = 0; c2 < bc; c2++) {
858 for (k2 = 0; k2 < bk; k2++) {
859 LIBXSMM_VLA_ACCESS(2, real_dst, c1*bc+c2, k1*bk+k2, K) =
860 LIBXSMM_VLA_ACCESS(4, real_src, k1, c1, c2, k2, cBlocks, bc, bk);
861 }
862 }
863 }
864 }
865 }
866
/* Repacks a row-major [C][K] bf16 matrix into blocked [K/bk][C/bc][bc/2][bk][2]
 * layout, pairing consecutive C-elements innermost (VNNI-style packing).
 * NOTE(review): assumes bc is even — confirm against callers. */
LIBXSMM_INLINE void matrix_copy_CK_to_KCCK_bf16(libxsmm_bfloat16 *src, libxsmm_bfloat16 *dst, int C, int K, int bc, int bk)
{
  int k1, k2, c1, c2;
  int kBlocks = K/bk;
  int cBlocks = C/bc;
  LIBXSMM_VLA_DECL(2, libxsmm_bfloat16, real_src, src, K);
  LIBXSMM_VLA_DECL(5, libxsmm_bfloat16, real_dst, dst, cBlocks, bc/2, bk, 2);

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(c2); LIBXSMM_OMP_VAR(k2);
# pragma omp parallel for private(k1,c1,c2,k2)
#endif
  for (k1 = 0; k1 < kBlocks; k1++) {
    for (c1 = 0; c1 < cBlocks; c1++) {
      for (c2 = 0; c2 < bc; c2++) {
        for (k2 = 0; k2 < bk; k2++) {
          /* c2/2 selects the C-pair, c2%2 the position within the pair */
          LIBXSMM_VLA_ACCESS(5, real_dst, k1, c1, c2/2, k2, c2%2, cBlocks, bc/2, bk, 2) =
            LIBXSMM_VLA_ACCESS(2, real_src, c1*bc+c2, k1*bk+k2, K);
        }
      }
    }
  }
}
890
/* Repacks a row-major [C][K] bf16 matrix into blocked [C/bc][K/bk][bk/2][bc][2]
 * layout, pairing consecutive K-elements innermost (VNNI-style packing).
 * NOTE(review): assumes bk is even — confirm against callers. */
LIBXSMM_INLINE void matrix_copy_CK_to_CKKC_bf16(libxsmm_bfloat16 *src, libxsmm_bfloat16 *dst, int C, int K, int bc, int bk)
{
  int k1, k2, c1, c2;
  int kBlocks = K/bk;
  int cBlocks = C/bc;
  LIBXSMM_VLA_DECL(2, libxsmm_bfloat16, real_src, src, K);
  LIBXSMM_VLA_DECL(5, libxsmm_bfloat16, real_dst, dst, kBlocks, bk/2, bc, 2);

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(k1); LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(c2); LIBXSMM_OMP_VAR(k2);
# pragma omp parallel for private(k1,c1,c2,k2)
#endif
  for (c1 = 0; c1 < cBlocks; c1++) {
    for (k1 = 0; k1 < kBlocks; k1++) {
      for (k2 = 0; k2 < bk; k2++) {
        for (c2 = 0; c2 < bc; c2++) {
          /* k2/2 selects the K-pair, k2%2 the position within the pair */
          LIBXSMM_VLA_ACCESS(5, real_dst, c1, k1, k2/2, c2, k2%2, kBlocks, bk/2, bc, 2) =
            LIBXSMM_VLA_ACCESS(2, real_src, c1*bc+c2, k1*bk+k2, K);
        }
      }
    }
  }
}
914
/* Repacks a row-major [K][C] bf16 matrix into blocked [K/bk][C/bc][bc/2][bk][2]
 * layout, pairing consecutive C-elements innermost (VNNI-style packing).
 * NOTE(review): assumes bc is even — confirm against callers. */
LIBXSMM_INLINE void matrix_copy_KC_to_KCCK_bf16(libxsmm_bfloat16 *src, libxsmm_bfloat16 *dst, int C, int K, int bc, int bk)
{
  int k1, k2, c1, c2;
  int kBlocks = K/bk;
  int cBlocks = C/bc;
  LIBXSMM_VLA_DECL(2, libxsmm_bfloat16, real_src, src, C);
  LIBXSMM_VLA_DECL(5, libxsmm_bfloat16, real_dst, dst, cBlocks, bc/2, bk, 2);

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(c2); LIBXSMM_OMP_VAR(k2);
# pragma omp parallel for private(k1,c1,c2,k2)
#endif
  for (k1 = 0; k1 < kBlocks; k1++) {
    for (c1 = 0; c1 < cBlocks; c1++) {
      for (c2 = 0; c2 < bc; c2++) {
        for (k2 = 0; k2 < bk; k2++) {
          /* c2/2 selects the C-pair, c2%2 the position within the pair */
          LIBXSMM_VLA_ACCESS(5, real_dst, k1, c1, c2/2, k2, c2%2, cBlocks, bc/2, bk, 2) =
            LIBXSMM_VLA_ACCESS(2, real_src, k1*bk+k2, c1*bc+c2, C);
        }
      }
    }
  }
}
938
/* Unpacks a blocked, C-pair-packed [K/bk][C/bc][bc/2][bk][2] bf16 matrix back
 * into row-major [K][C] (inverse of matrix_copy_KC_to_KCCK_bf16).
 * NOTE(review): assumes bc is even — confirm against callers. */
LIBXSMM_INLINE void matrix_copy_KCCK_to_KC_bf16(libxsmm_bfloat16 *src, libxsmm_bfloat16 *dst, int C, int K, int bc, int bk)
{
  int k1, k2, c1, c2;
  int kBlocks = K/bk;
  int cBlocks = C/bc;
  LIBXSMM_VLA_DECL(2, libxsmm_bfloat16, real_dst, dst, C);
  LIBXSMM_VLA_DECL(5, libxsmm_bfloat16, real_src, src, cBlocks, bc/2, bk, 2);

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(c2); LIBXSMM_OMP_VAR(k2);
# pragma omp parallel for private(k1,c1,c2,k2)
#endif
  for (k1 = 0; k1 < kBlocks; k1++) {
    for (c1 = 0; c1 < cBlocks; c1++) {
      for (c2 = 0; c2 < bc; c2++) {
        for (k2 = 0; k2 < bk; k2++) {
          /* c2/2 selects the C-pair, c2%2 the position within the pair */
          LIBXSMM_VLA_ACCESS(2, real_dst, k1*bk+k2, c1*bc+c2, C) =
            LIBXSMM_VLA_ACCESS(5, real_src, k1, c1, c2/2, k2, c2%2, cBlocks, bc/2, bk, 2);
        }
      }
    }
  }
}
962
/* Unpacks the blocked/VNNI-paired [kBlocks][cBlocks][bc/2][bk][2] layout into
 * a row-major C x K matrix ("CK", i.e. the transpose of the KC view).
 * Assumes K % bk == 0, C % bc == 0 and bc even -- TODO confirm at call sites. */
LIBXSMM_INLINE void matrix_copy_KCCK_to_CK_bf16(libxsmm_bfloat16 *src, libxsmm_bfloat16 *dst, int C, int K, int bc, int bk)
{
  int k1, k2, c1, c2;
  int kBlocks = K/bk;
  int cBlocks = C/bc;
  LIBXSMM_VLA_DECL(2, libxsmm_bfloat16, real_dst, dst, K); /* dst viewed as [C][K] */
  LIBXSMM_VLA_DECL(5, libxsmm_bfloat16, real_src, src, cBlocks, bc/2, bk, 2);

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(c2); LIBXSMM_OMP_VAR(k2);
# pragma omp parallel for private(k1,c1,c2,k2)
#endif
  for (k1 = 0; k1 < kBlocks; k1++) {
    for (c1 = 0; c1 < cBlocks; c1++) {
      for (c2 = 0; c2 < bc; c2++) {
        for (k2 = 0; k2 < bk; k2++) {
          /* note the swapped row/column roles vs. matrix_copy_KCCK_to_KC_bf16 */
          LIBXSMM_VLA_ACCESS(2, real_dst, c1*bc+c2, k1*bk+k2, K) =
            LIBXSMM_VLA_ACCESS(5, real_src, k1, c1, c2/2, k2, c2%2, cBlocks, bc/2, bk, 2);
        }
      }
    }
  }
}
986
/* Converts between the two blocked/VNNI layouts: source is
 * [kBlocks][cBlocks][bc/2][bk][2] ("KCCK"), destination is
 * [cBlocks][kBlocks][bk/2][bc][2] ("CKKC") -- i.e. a blocked transpose where
 * the VNNI pairing moves from the C dimension to the K dimension.
 * Assumes K % bk == 0, C % bc == 0 and bc, bk even -- TODO confirm. */
LIBXSMM_INLINE void matrix_copy_KCCK_to_CKKC_bf16(libxsmm_bfloat16 *src, libxsmm_bfloat16 *dst, int C, int K, int bc, int bk)
{
  int k1, k2, c1, c2;
  int kBlocks = K/bk;
  int cBlocks = C/bc;
  LIBXSMM_VLA_DECL(5, libxsmm_bfloat16, real_dst, dst, kBlocks, bk/2, bc, 2);
  LIBXSMM_VLA_DECL(5, libxsmm_bfloat16, real_src, src, cBlocks, bc/2, bk, 2);

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(c2); LIBXSMM_OMP_VAR(k2);
# pragma omp parallel for private(k1,c1,c2,k2)
#endif
  for (k1 = 0; k1 < kBlocks; k1++) {
    for (c1 = 0; c1 < cBlocks; c1++) {
      for (c2 = 0; c2 < bc; c2++) {
        for (k2 = 0; k2 < bk; k2++) {
          /* destination pairs along k (k2/2, k2%2); source pairs along c (c2/2, c2%2) */
          LIBXSMM_VLA_ACCESS(5, real_dst, c1, k1, k2/2, c2, k2%2, kBlocks, bk/2, bc, 2) =
            LIBXSMM_VLA_ACCESS(5, real_src, k1, c1, c2/2, k2, c2%2, cBlocks, bc/2, bk, 2);
        }
      }
    }
  }
}
1010
/* Elementwise vector sum: c[i] = a[i] + b[i] for i in [0, size). */
LIBXSMM_INLINE void matrix_add(int size, float *a, float *b, float *c)
{
  int idx;
#if defined(_OPENMP)
# pragma omp parallel for private(idx)
#endif
  for (idx = 0; idx < size; idx++) {
    const float sum = a[idx] + b[idx];
    c[idx] = sum;
  }
}
1021
/* Elementwise (Hadamard) product: c[i] = a[i] * b[i] for i in [0, size). */
LIBXSMM_INLINE void matrix_eltwise_mult(int size, float *a, float *b, float *c)
{
  int idx;
#if defined(_OPENMP)
# pragma omp parallel for private(idx)
#endif
  for (idx = 0; idx < size; idx++) {
    const float prod = a[idx] * b[idx];
    c[idx] = prod;
  }
}
1032
/* Elementwise multiply-accumulate: c[i] += a[i] * b[i]; c must be initialized. */
LIBXSMM_INLINE void matrix_eltwise_fma(int size, float *a, float *b, float *c)
{
  int idx;
#if defined(_OPENMP)
# pragma omp parallel for private(idx)
#endif
  for (idx = 0; idx < size; idx++) {
    c[idx] = c[idx] + a[idx] * b[idx];
  }
}
1043
/* Elementwise product where only `a` is strided: a is an n x m matrix stored
 * with leading dimension ld, while b and c are dense n x m (row-major, m
 * contiguous elements per row). */
LIBXSMM_INLINE void matrix_eltwise_mult_ld_a(int m, int n, int ld, float *a, float *b, float *c)
{
  int idx;
#if defined(_OPENMP)
# pragma omp parallel for private(idx)
#endif
  for (idx = 0; idx < m*n; idx++) {
    const int r = idx / m;
    const int q = idx % m;
    c[idx] = a[r*ld + q] * b[idx];
  }
}
1056
/* Elementwise product where both inputs are strided: a and b are n x m with
 * leading dimension ld; the dense n x m result is written contiguously to c. */
LIBXSMM_INLINE void matrix_eltwise_mult_ld_ab(int m, int n, int ld, float *a, float *b, float *c)
{
  int idx;
#if defined(_OPENMP)
# pragma omp parallel for private(idx)
#endif
  for (idx = 0; idx < m*n; idx++) {
    const int r = idx / m;
    const int q = idx % m;
    const int pos = r*ld + q;
    c[idx] = a[pos] * b[pos];
  }
}
1069
/* Elementwise product where only the output is strided: dense n x m inputs a
 * and b, result scattered into c with leading dimension ld. */
LIBXSMM_INLINE void matrix_eltwise_mult_ld_c(int m, int n, int ld, float *a, float *b, float *c)
{
  int idx;
#if defined(_OPENMP)
# pragma omp parallel for private(idx)
#endif
  for (idx = 0; idx < m*n; idx++) {
    const int r = idx / m;
    const int q = idx % m;
    c[r*ld + q] = a[idx] * b[idx];
  }
}
1082
matrix_sigmoid(int size,float * src,float * dst)1083 LIBXSMM_INLINE void matrix_sigmoid(int size, float *src, float *dst)
1084 {
1085 int i;
1086 #if defined(_OPENMP)
1087 # pragma omp parallel for private(i)
1088 #endif
1089 for (i = 0; i < size; i++) {
1090 const float exp_value = (float)exp((double) -src[i]);
1091 dst[i] = 1.0f / (1.0f + exp_value);
1092 }
1093 }
1094
matrix_sigmoid_ld(int m,int n,int ld,float * src,float * dst)1095 LIBXSMM_INLINE void matrix_sigmoid_ld(int m, int n, int ld, float *src, float *dst)
1096 {
1097 int i;
1098 #if defined(_OPENMP)
1099 # pragma omp parallel for private(i)
1100 #endif
1101 for (i = 0; i < m*n; i++) {
1102 int row = i / m;
1103 int col = i % m;
1104 const float exp_value = (float)exp((double) -src[row*ld + col]);
1105 dst[row*ld + col] = 1.0f / (1.0f + exp_value);
1106 }
1107 }
1108
matrix_tanh(int size,float * src,float * dst)1109 LIBXSMM_INLINE void matrix_tanh(int size, float *src, float *dst)
1110 {
1111 int i;
1112 #if defined(_OPENMP)
1113 # pragma omp parallel for private(i)
1114 #endif
1115 for (i = 0; i < size; i++) {
1116 dst[i] = (float)tanh((double)src[i]);
1117 }
1118 }
1119
matrix_tanh_ld(int m,int n,int ld,float * src,float * dst)1120 LIBXSMM_INLINE void matrix_tanh_ld(int m, int n, int ld, float *src, float *dst)
1121 {
1122 int i;
1123 #if defined(_OPENMP)
1124 # pragma omp parallel for private(i)
1125 #endif
1126 for (i = 0; i < m*n; i++) {
1127 int row = i / m;
1128 int col = i % m;
1129 dst[row*ld + col] = (float)tanh((double)src[row*ld + col]);
1130 }
1131 }
1132
/* Rectified linear unit applied elementwise: dst[i] = max(src[i], 0).
 * Predicate kept as a strict "greater than zero" test so NaN maps to 0,
 * matching the original formulation. */
LIBXSMM_INLINE void matrix_relu(int size, float *src, float *dst)
{
  int k;
#if defined(_OPENMP)
# pragma omp parallel for private(k)
#endif
  for (k = 0; k < size; k++) {
    const float v = src[k];
    dst[k] = (0.0f < v) ? v : 0.0f;
  }
}
1143
matrix_sigmoid_inverse(int size,float * src,float * dst)1144 LIBXSMM_INLINE void matrix_sigmoid_inverse(int size, float *src, float *dst)
1145 {
1146 int i;
1147 #if defined(_OPENMP)
1148 # pragma omp parallel for private(i)
1149 #endif
1150 for (i = 0; i < size; i++) {
1151 const float exp_value = (float)exp((double) -src[i]);
1152 const float sig_exp = 1.0f / (1.0f + exp_value);
1153 dst[i] = (1.0f - sig_exp)*sig_exp;
1154 }
1155 }
1156
matrix_tanh_inverse(int size,float * src,float * dst)1157 LIBXSMM_INLINE void matrix_tanh_inverse(int size, float *src, float *dst)
1158 {
1159 int i;
1160 #if defined(_OPENMP)
1161 # pragma omp parallel for private(i)
1162 #endif
1163 for (i = 0; i < size; i++) {
1164 const float tanh_value = (float)tanh((double)src[i]);
1165 dst[i] = 1.0f - (tanh_value * tanh_value);
1166 }
1167 }
1168
/* Derivative (step function) of ReLU evaluated at the pre-activation src:
 * dst[i] = 1 if src[i] > 0, else 0 (NaN maps to 0, as in the original). */
LIBXSMM_INLINE void matrix_relu_inverse(int size, float *src, float *dst)
{
  int k;
#if defined(_OPENMP)
# pragma omp parallel for private(k)
#endif
  for (k = 0; k < size; k++) {
    dst[k] = (0.0f < src[k]) ? 1.0f : 0.0f;
  }
}
1179
/* Out-of-place transpose of a rows x cols float matrix into dst using
 * LIBXSMM's OpenMP-parallel transpose primitive.
 * NOTE(review): arguments are passed as (out, in, typesize, m=cols, n=rows,
 * ldi=cols, ldo=rows) -- verify against libxsmm_otrans_omp's documented
 * row/column-major convention before relying on non-square shapes. */
LIBXSMM_INLINE void matrix_transpose(int rows, int cols, float *src, float *dst)
{
  libxsmm_otrans_omp(dst, src, sizeof(float), cols, rows, cols/*ldi*/, rows/*ldo*/);
}
1184
/* Copies `size` consecutive floats from src to dst (parallel memcpy analogue). */
LIBXSMM_INLINE void matrix_copy(int size, float *src, float *dst)
{
  int k;
#if defined(_OPENMP)
# pragma omp parallel for private(k)
#endif
  for (k = 0; k < size; k++) {
    dst[k] = src[k];
  }
}
1195
/* Converts fp32 values to bf16 by truncation: keeps the upper 16 bits of the
 * fp32 bit pattern and drops the lower mantissa bits (no rounding).
 * The t.i[1] access picks the high half of the 32-bit word and therefore
 * assumes a little-endian host -- TODO confirm for portability. */
LIBXSMM_INLINE void matrix_copy_f32_bf16(int size, float *src, libxsmm_bfloat16 *dst)
{
  int i;
#if defined(_OPENMP)
# pragma omp parallel for private(i)
#endif
  for (i = 0; i < size; i++) {
    libxsmm_bfloat16_hp t;
    t.f = src[i];
    dst[i] = t.i[1]; /* high 16 bits of the fp32 representation */
  }
}
1208
/* Widens bf16 values to fp32 exactly: the bf16 bits become the upper half of
 * the fp32 word and the lower mantissa bits are zeroed (lossless conversion).
 * As in matrix_copy_f32_bf16, the i[1]/i[0] split assumes a little-endian
 * host -- TODO confirm for portability. */
LIBXSMM_INLINE void matrix_copy_bf16_f32(int size, libxsmm_bfloat16 *src, float *dst)
{
  int i;
#if defined(_OPENMP)
# pragma omp parallel for private(i)
#endif
  for (i = 0; i < size; i++) {
    libxsmm_bfloat16_hp t;
    t.i[1] = src[i]; /* bf16 payload -> high 16 bits */
    t.i[0] = 0;      /* zero the low mantissa bits */
    dst[i] = t.f;
  }
}
1222
/* Gathers an n x m matrix stored with leading dimension ld into a dense,
 * contiguous n x m destination (row-major, m elements per row). */
LIBXSMM_INLINE void matrix_copy_ld(int m, int n, int ld, float *src, float *dst)
{
  int idx;
#if defined(_OPENMP)
# pragma omp parallel for private(idx)
#endif
  for (idx = 0; idx < m*n; idx++) {
    const int r = idx / m;
    const int q = idx % m;
    dst[idx] = src[r*ld + q];
  }
}
1235
/* Broadcasts an m-element bias row into every row of an n x m destination
 * matrix stored with leading dimension ld. */
LIBXSMM_INLINE void matrix_copy_bias(int m, int n, int ld, float *src, float *dst)
{
  int idx;
#if defined(_OPENMP)
# pragma omp parallel for private(idx)
#endif
  for (idx = 0; idx < m*n; idx++) {
    const int r = idx / m;
    const int q = idx % m;
    dst[r*ld + q] = src[q];
  }
}
1248
/* Broadcasts an m-element bias row plus a constant forget-gate offset into
 * every row of an n x m destination with leading dimension ld
 * (LSTM convention: forget-gate bias is shifted by forget_bias). */
LIBXSMM_INLINE void matrix_copy_forget_bias(int m, int n, int ld, float *src, float *dst, float forget_bias)
{
  int idx;
#if defined(_OPENMP)
# pragma omp parallel for private(idx)
#endif
  for (idx = 0; idx < m*n; idx++) {
    const int r = idx / m;
    const int q = idx % m;
    dst[r*ld + q] = src[q] + forget_bias;
  }
}
1261
/* Elementwise complement: dst[i] = 1 - src[i]. */
LIBXSMM_INLINE void matrix_complement(int size, float *src, float *dst)
{
  int k;
#if defined(_OPENMP)
# pragma omp parallel for private(k)
#endif
  for (k = 0; k < size; k++) {
    dst[k] = 1.0f - src[k];
  }
}
1272
/* Elementwise complement of a strided n x m source (leading dimension ld)
 * into a dense, contiguous n x m destination: dst = 1 - src. */
LIBXSMM_INLINE void matrix_complement_ld(int m, int n, int ld, float *src, float *dst)
{
  int idx;
#if defined(_OPENMP)
# pragma omp parallel for private(idx)
#endif
  for (idx = 0; idx < m*n; idx++) {
    const int r = idx / m;
    const int q = idx % m;
    dst[idx] = 1.0f - src[r*ld + q];
  }
}
1285
1286
/* Elementwise 1 - x^2 (the tanh derivative applied to a stored tanh output). */
LIBXSMM_INLINE void matrix_complement_square(int size, float *src, float *dst)
{
  int k;
#if defined(_OPENMP)
# pragma omp parallel for private(k)
#endif
  for (k = 0; k < size; k++) {
    const float v = src[k];
    dst[k] = 1.0f - (v * v);
  }
}
1297
/* Elementwise 1 - x^2 over a strided n x m source (leading dimension ld),
 * writing the dense, contiguous n x m result to dst. */
LIBXSMM_INLINE void matrix_complement_square_ld(int m, int n, int ld, float *src, float *dst)
{
  int idx;
#if defined(_OPENMP)
# pragma omp parallel for private(idx)
#endif
  for (idx = 0; idx < m*n; idx++) {
    const int r = idx / m;
    const int q = idx % m;
    const float v = src[r*ld + q];
    dst[idx] = 1.0f - (v * v);
  }
}
1310
/* Copies a C x K matrix into gate slot `offset` of an interleaved C x 4K
 * buffer: each destination row holds the four K-wide gate sections side by
 * side (offsets: i--0, c--1, f--2, o--3). */
LIBXSMM_INLINE void convert_ck_c4k_offset(int C, int K, int offset, float *src, float *dst)
{
  int row, col;
#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(col);
# pragma omp parallel for private(col, row)
#endif
  for (row = 0; row < C; row++) {
    for (col = 0; col < K; col++) {
      dst[(row*4 + offset)*K + col] = src[row*K + col];
    }
  }
}
1325
/* Convenience wrapper: copies a C x K matrix into gate slot 0 (the "i" gate)
 * of an interleaved C x 4K buffer; see convert_ck_c4k_offset. */
LIBXSMM_INLINE void convert_ck_c4k(int C, int K, float *src, float *dst)
{
  convert_ck_c4k_offset(C, K, 0, src, dst);
}
1330
convert_ck_f32_to_c4k_bf16(int C,int K,float * src,libxsmm_bfloat16 * dst)1331 LIBXSMM_INLINE void convert_ck_f32_to_c4k_bf16(int C, int K, float *src, libxsmm_bfloat16 *dst)
1332 {
1333 int x, y;
1334 #if defined(_OPENMP)
1335 LIBXSMM_OMP_VAR(x);
1336 # pragma omp parallel for private(x, y)
1337 #endif
1338 for (y = 0; y < C; y++) {
1339 for (x = 0; x < K; x++) {
1340 libxsmm_bfloat16_hp t;
1341 t.f = src[y*K + x];
1342 dst[y*4*K + x] = t.i[1];
1343 }
1344 }
1345 }
1346
/* Scatters an interleaved C x 4K gate matrix (per row the four K-wide gate
 * sections i,c,f,o lie side by side) into four contiguous C x K matrices
 * stacked one after another in dst (offsets: i--0, c--1, f--2, o--3). */
LIBXSMM_INLINE void convert_c4k_4ck(int C, int K, float *src, float *dst)
{
  int gate, row, col;
#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(col); LIBXSMM_OMP_VAR(row);
# pragma omp parallel for private(col, row, gate)
#endif
  for (gate = 0; gate < 4; gate++) {
    for (row = 0; row < C; row++) {
      for (col = 0; col < K; col++) {
        dst[(gate*C + row)*K + col] = src[(row*4 + gate)*K + col];
      }
    }
  }
}
1363
/* Copies a C x K matrix into slot 0 of an interleaved C x 3K buffer
 * (three K-wide gate sections per row, GRU convention); the other two
 * sections are left untouched. */
LIBXSMM_INLINE void convert_ck_c3k(int C, int K, float *src, float *dst)
{
  int row, col;
#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(col);
# pragma omp parallel for private(col, row)
#endif
  for (row = 0; row < C; row++) {
    for (col = 0; col < K; col++) {
      dst[(row*3)*K + col] = src[row*K + col];
    }
  }
}
1377
/* Embeds a dense N x K matrix into the first K columns of an N x CK buffer
 * (row stride CK >= K); columns K..CK-1 are left untouched. */
LIBXSMM_INLINE void convert_nk_nck(int N, int K, int CK, float *src, float *dst)
{
  int row, col;
#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(col);
# pragma omp parallel for private(col, row)
#endif
  for (row = 0; row < N; row++) {
    for (col = 0; col < K; col++) {
      dst[row*CK + col] = src[row*K + col];
    }
  }
}
1391
/* Naive reference forward convolution in fp32.
 * Layouts: input [nImg][nIfm][ifhp][ifwp], output [nImg][nOfm][ofhp][ofwp],
 * filter [nOfm][nIfm][kh][kw]. pad_h/pad_w are the logical pads (handled by
 * skipping out-of-range taps); pad_*_in/out are physical pads baked into the
 * buffer shapes, compensated by offsetting the base pointers below.
 * Unless USE_FUSED_BIAS* is defined, output is only accumulated into --
 * presumably the caller zero-initializes it; verify at call sites.
 * USE_FUSED_BIAS[_RELU] seeds the output with bias; USE_FUSED_RELU /
 * USE_FUSED_BIAS_RELU clamps negatives after accumulation. */
LIBXSMM_INLINE void naive_conv_fp(naive_conv_t* param, const float* input, float* output, const float* filter, const float* bias)
{
  int nImg = param->nImg;
  int nIfm = param->nIfm;
  int nOfm = param->nOfm;
  int ifhp = param->ifhp;
  int ifwp = param->ifwp;
  int ofhp = param->ofhp;
  int ofwp = param->ofwp;
  int ifh = param->ifh;
  int ifw = param->ifw;
  int ofh = param->ofh;
  int ofw = param->ofw;
  int pad_h = param->pad_h;
  int pad_w = param->pad_w;
  int pad_h_in = param->pad_h_in;
  int pad_w_in = param->pad_w_in;
  int pad_h_out = param->pad_h_out;
  int pad_w_out = param->pad_w_out;
  int kh = param->kh;
  int kw = param->kw;
  int stride_h = param->stride_h;
  int stride_w = param->stride_w;
  /* loop counters */
  int img, ofm, ifm, oj, oi, ij, ii, kj, ki;

  /* base pointers advanced past the physical padding: row pad scales the row stride */
  LIBXSMM_VLA_DECL(4, float, output_t, output + (pad_h_out * ofwp + pad_w_out), nOfm, ofhp, ofwp);
  LIBXSMM_VLA_DECL(4, const float, input_t, input + (pad_h_in * ifwp + pad_w_in), nIfm, ifhp, ifwp);
  LIBXSMM_VLA_DECL(4, const float, filter_t, filter, nIfm, kh, kw);

#if defined(USE_FUSED_BIAS) || defined(USE_FUSED_BIAS_RELU)
#if defined(_OPENMP)
# pragma omp parallel for LIBXSMM_OPENMP_COLLAPSE(2) private(img, ofm, ifm, oj, oi, ij, ii, kj, ki)
#endif
  /* seed every output pixel with the per-channel bias before accumulating */
  for (img = 0; img < nImg; ++img) {
    for (ofm = 0; ofm < nOfm; ++ofm) {
      for (oj = 0; oj < ofh; ++oj) {
        for (oi = 0; oi < ofw; ++oi) {
          LIBXSMM_VLA_ACCESS(  4, output_t, img, ofm, oj, oi, nOfm, ofhp, ofwp) = bias[ofm];
        }
      }
    }
  }
#else
  LIBXSMM_UNUSED(bias);
#endif

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(ofm); LIBXSMM_OMP_VAR(oj); LIBXSMM_OMP_VAR(oi);
  LIBXSMM_OMP_VAR(ifm); LIBXSMM_OMP_VAR(ij); LIBXSMM_OMP_VAR(ii); LIBXSMM_OMP_VAR(kj); LIBXSMM_OMP_VAR(ki);
# pragma omp parallel for LIBXSMM_OPENMP_COLLAPSE(2) private(img, ofm, ifm, oj, oi, ij, ii, kj, ki)
#endif
  for (img = 0; img < nImg; ++img) {
    for (ofm = 0; ofm < nOfm; ++ofm) {
      for (ifm = 0; ifm < nIfm; ++ifm) {
        for (oj = 0; oj < ofh; ++oj) {
          ij = oj * stride_h - pad_h; /* top-left input row of this output pixel */
          for (oi = 0; oi < ofw; ++oi) {
            ii = oi * stride_w - pad_w;
            for (kj = 0; kj < kh; ++kj) {
              if (ij+kj < 0 || ij+kj >= ifh) continue; /* logical padding: skip out-of-range rows */
              for (ki = 0; ki < kw; ++ki) {
                if (ii+ki < 0 || ii+ki >= ifw) continue; /* logical padding: skip out-of-range cols */
                LIBXSMM_VLA_ACCESS(  4, output_t, img, ofm, oj, oi, nOfm, ofhp, ofwp) +=
                  LIBXSMM_VLA_ACCESS(4, input_t, img, ifm, ij + kj, ii + ki, nIfm, ifhp, ifwp)
                * LIBXSMM_VLA_ACCESS(4, filter_t, ofm, ifm, kj, ki, nIfm, kh, kw);
              }
            }
          }
        }
      }
#if defined(USE_FUSED_RELU) || defined(USE_FUSED_BIAS_RELU)
      /* fused ReLU: clamp negatives once all input channels are accumulated */
      for (oj = 0; oj < ofh; ++oj) {
        for (oi = 0; oi < ofw; ++oi) {
          LIBXSMM_VLA_ACCESS(  4, output_t, img, ofm, oj, oi, nOfm, ofhp, ofwp) =
            (LIBXSMM_VLA_ACCESS(  4, output_t, img, ofm, oj, oi, nOfm, ofhp, ofwp) < 0.0f) ? 0.0f : LIBXSMM_VLA_ACCESS(  4, output_t, img, ofm, oj, oi, nOfm, ofhp, ofwp);
        }
      }
#endif
    }
  }
}
1474
/* Naive reference backward-by-data convolution in fp32: accumulates the
 * input gradient from the output gradient and the filter (transposed
 * convolution). Gradients are accumulated into `input` -- presumably
 * zero-initialized by the caller; verify at call sites.
 * When USE_FUSED_RELU_BWD / USE_FUSED_BATCH_STATS_BWD is defined, the saved
 * forward input (`naive_input_save`) is used to mask the gradient where the
 * forward activation was zero (ReLU backward). */
LIBXSMM_INLINE void naive_conv_bp(naive_conv_t* param, float* input, const float* output, const float* filter, const float* naive_input_save)
{
  int nImg = param->nImg;
  int nIfm = param->nIfm;
  int nOfm = param->nOfm;
  int ifhp = param->ifhp;
  int ifwp = param->ifwp;
  int ofhp = param->ofhp;
  int ofwp = param->ofwp;
  int ifh = param->ifh;
  int ifw = param->ifw;
  int ofh = param->ofh;
  int ofw = param->ofw;
  int pad_h = param->pad_h;
  int pad_w = param->pad_w;
  int pad_h_in = param->pad_h_in;
  int pad_w_in = param->pad_w_in;
  int pad_h_out = param->pad_h_out;
  int pad_w_out = param->pad_w_out;
  int kh = param->kh;
  int kw = param->kw;
  int stride_h = param->stride_h;
  int stride_w = param->stride_w;
  /* loop counters */
  int img, ofm, ifm, oj, oi, ij, ii, kj, ki;

  /* base pointers advanced past the physical padding of the padded buffers */
  LIBXSMM_VLA_DECL(4, const float, output_t, output + (pad_h_out * ofwp + pad_w_out), nOfm, ofhp, ofwp);
  LIBXSMM_VLA_DECL(4, float, input_t, input + (pad_h_in * ifwp + pad_w_in), nIfm, ifhp, ifwp);
  LIBXSMM_VLA_DECL(4, const float, filter_t, filter, nIfm, kh, kw);
#if (defined(USE_FUSED_RELU_BWD) || defined(USE_FUSED_BATCH_STATS_BWD))
  LIBXSMM_VLA_DECL(4, const float, naive_input_t, naive_input_save + (pad_h_in * ifwp + pad_w_in), nIfm, ifhp, ifwp);
#else
  LIBXSMM_UNUSED(naive_input_save);
#endif

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(ofm); LIBXSMM_OMP_VAR(oj); LIBXSMM_OMP_VAR(oi);
  LIBXSMM_OMP_VAR(ifm); LIBXSMM_OMP_VAR(ij); LIBXSMM_OMP_VAR(ii); LIBXSMM_OMP_VAR(kj); LIBXSMM_OMP_VAR(ki);
# pragma omp parallel for LIBXSMM_OPENMP_COLLAPSE(2) private(img, ofm, ifm, oj, oi, ij, ii, kj, ki)
#endif
  for (img = 0; img < nImg; ++img) {
    for (ifm = 0; ifm < nIfm; ++ifm) { /* parallel over (img, ifm): each thread owns its d-input slice */
      for (ofm = 0; ofm < nOfm; ++ofm) {
        for (oj = 0; oj < ofh; ++oj) {
          ij = oj * stride_h - pad_h;
          for (oi = 0; oi < ofw; ++oi) {
            ii = oi * stride_w - pad_w;
            for (kj = 0; kj < kh; ++kj) {
              if (ij+kj < 0 || ij+kj >= ifh) continue; /* logical padding: no gradient outside the image */
              for (ki = 0; ki < kw; ++ki) {
                if (ii+ki < 0 || ii+ki >= ifw) continue;
                LIBXSMM_VLA_ACCESS(4, input_t, img, ifm, ij + kj, ii + ki, nIfm, ifhp, ifwp) +=
                  LIBXSMM_VLA_ACCESS(4, output_t, img, ofm, oj, oi, nOfm, ofhp, ofwp)
                * LIBXSMM_VLA_ACCESS(4, filter_t, ofm, ifm, kj, ki, nIfm, kh, kw);
              }
            }
          }
        }
      }
#if (defined(USE_FUSED_RELU_BWD) || defined(USE_FUSED_BATCH_STATS_BWD))
      /* ReLU backward: zero the gradient wherever the saved forward input was zero */
      for (ij = 0; ij < ifh; ij++) {
        for (ii = 0; ii < ifw; ii++) {
          if ( LIBXSMM_VLA_ACCESS(4, naive_input_t, img, ifm, ij, ii , nIfm, ifhp, ifwp) == 0.0 ) {
            LIBXSMM_VLA_ACCESS(4, input_t, img, ifm, ij, ii , nIfm, ifhp, ifwp) = 0.0;
          }
        }
      }
#endif
    }
  }
}
1546
/* Naive reference weight-update (backward-by-filter) convolution in fp32:
 * accumulates the filter gradient from the forward input and the output
 * gradient. Gradients are accumulated into `filter` -- presumably
 * zero-initialized by the caller; verify at call sites. */
LIBXSMM_INLINE void naive_conv_wu(naive_conv_t* param, const float* input, const float* output, float* filter)
{
  int nImg = param->nImg;
  int nIfm = param->nIfm;
  int nOfm = param->nOfm;
  int ifhp = param->ifhp;
  int ifwp = param->ifwp;
  int ofhp = param->ofhp;
  int ofwp = param->ofwp;
  int ifh = param->ifh;
  int ifw = param->ifw;
  int ofh = param->ofh;
  int ofw = param->ofw;
  int pad_h = param->pad_h;
  int pad_w = param->pad_w;
  int pad_h_in = param->pad_h_in;
  int pad_w_in = param->pad_w_in;
  int pad_h_out = param->pad_h_out;
  int pad_w_out = param->pad_w_out;
  int kh = param->kh;
  int kw = param->kw;
  int stride_h = param->stride_h;
  int stride_w = param->stride_w;
  /* loop counters */
  int img, ofm, ifm, oj, oi, ij, ii, kj, ki;

  /* base pointers advanced past the physical padding of the padded buffers */
  LIBXSMM_VLA_DECL(4, const float, output_t, output + (pad_h_out * ofwp + pad_w_out), nOfm, ofhp, ofwp);
  LIBXSMM_VLA_DECL(4, const float, input_t, input + (pad_h_in * ifwp + pad_w_in), nIfm, ifhp, ifwp);
  LIBXSMM_VLA_DECL(4, float, filter_t, filter, nIfm, kh, kw);

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(ofm); LIBXSMM_OMP_VAR(oj); LIBXSMM_OMP_VAR(oi);
  LIBXSMM_OMP_VAR(ifm); LIBXSMM_OMP_VAR(ij); LIBXSMM_OMP_VAR(ii); LIBXSMM_OMP_VAR(kj); LIBXSMM_OMP_VAR(ki);
# pragma omp parallel for LIBXSMM_OPENMP_COLLAPSE(2) private(img, ofm, ifm, oj, oi, ij, ii, kj, ki)
#endif
  for (ofm = 0; ofm < nOfm; ++ofm) {
    for (ifm = 0; ifm < nIfm; ++ifm) { /* parallel over (ofm, ifm): each thread owns its filter slice */
      for (img = 0; img < nImg; ++img) {
        for (oj = 0; oj < ofh; ++oj) {
          ij = oj * stride_h - pad_h;
          for (oi = 0; oi < ofw; ++oi) {
            ii = oi * stride_w - pad_w;
            for (kj = 0; kj < kh; ++kj) {
              if (ij+kj < 0 || ij+kj >= ifh) continue; /* logical padding: tap falls outside the image */
              for (ki = 0; ki < kw; ++ki) {
                if (ii+ki < 0 || ii+ki >= ifw) continue;
                LIBXSMM_VLA_ACCESS(4, filter_t, ofm, ifm, kj, ki, nIfm, kh, kw) +=
                  LIBXSMM_VLA_ACCESS(4, input_t, img, ifm, ij + kj, ii + ki, nIfm, ifhp, ifwp)
                * LIBXSMM_VLA_ACCESS(4, output_t, img, ofm, oj, oi, nOfm, ofhp, ofwp);
              }
            }
          }
        }
      }
    }
  }
}
1604
naive_conv_fp_int16fp32(naive_conv_t * param,const short * input,float * output,const short * filter)1605 LIBXSMM_INLINE void naive_conv_fp_int16fp32(naive_conv_t* param, const short* input, float* output, const short* filter)
1606 {
1607 int nImg = param->nImg;
1608 int nIfm = param->nIfm;
1609 int nOfm = param->nOfm;
1610 int ifhp = param->ifhp;
1611 int ifwp = param->ifwp;
1612 int ofhp = param->ofhp;
1613 int ofwp = param->ofwp;
1614 int ifh = param->ifh;
1615 int ifw = param->ifw;
1616 int ofh = param->ofh;
1617 int ofw = param->ofw;
1618 int pad_h = param->pad_h;
1619 int pad_w = param->pad_w;
1620 int pad_h_in = param->pad_h_in;
1621 int pad_w_in = param->pad_w_in;
1622 int pad_h_out = param->pad_h_out;
1623 int pad_w_out = param->pad_w_out;
1624 int kh = param->kh;
1625 int kw = param->kw;
1626 int stride_h = param->stride_h;
1627 int stride_w = param->stride_w;
1628 /* loop counters */
1629 int img, ofm, ifm, oj, oi, ij, ii, kj, ki;
1630
1631 LIBXSMM_VLA_DECL(4, float, output_t, output + (pad_w_out * ofwp + pad_h_out), nOfm, ofhp, ofwp);
1632 LIBXSMM_VLA_DECL(4, const short, input_t, input + (pad_w_in * ifwp + pad_h_in), nIfm, ifhp, ifwp);
1633 LIBXSMM_VLA_DECL(4, const short, filter_t, filter, nIfm, kh, kw);
1634
1635
1636 #if defined(_OPENMP)
1637 LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(ofm); LIBXSMM_OMP_VAR(oj); LIBXSMM_OMP_VAR(oi);
1638 LIBXSMM_OMP_VAR(ifm); LIBXSMM_OMP_VAR(ij); LIBXSMM_OMP_VAR(ii); LIBXSMM_OMP_VAR(kj); LIBXSMM_OMP_VAR(ki);
1639 # pragma omp parallel for LIBXSMM_OPENMP_COLLAPSE(2) private(img, ofm, ifm, oj, oi, ij, ii, kj, ki)
1640 #endif
1641 for (img = 0; img < nImg; ++img) {
1642 for (ofm = 0; ofm < nOfm; ++ofm) {
1643 for (ifm = 0; ifm < nIfm; ++ifm) {
1644 for (oj = 0; oj < ofh; ++oj) {
1645 ij = oj * stride_h - pad_h;
1646 for (oi = 0; oi < ofw; ++oi) {
1647 ii = oi * stride_w - pad_w;
1648 for (kj = 0; kj < kh; ++kj) {
1649 if (ij+kj < 0 || ij+kj >= ifh) continue;
1650 for (ki = 0; ki < kw; ++ki) {
1651 if (ii+ki < 0 || ii+ki >= ifw) continue;
1652 LIBXSMM_VLA_ACCESS(4, output_t, img, ofm, oj, oi, nOfm, ofhp, ofwp) +=
1653 (1.f * LIBXSMM_VLA_ACCESS(4, input_t, img, ifm, ij + kj, ii + ki, nIfm, ifhp, ifwp))
1654 * (1.f * LIBXSMM_VLA_ACCESS(4, filter_t, ofm, ifm, kj, ki, nIfm, kh, kw));
1655 }
1656 }
1657 }
1658 }
1659 }
1660 }
1661 }
1662 }
1663
naive_conv_fp_int16int32(naive_conv_t * param,const short * input,int * output,const short * filter)1664 LIBXSMM_INLINE void naive_conv_fp_int16int32(naive_conv_t* param, const short* input, int* output, const short* filter)
1665 {
1666 int nImg = param->nImg;
1667 int nIfm = param->nIfm;
1668 int nOfm = param->nOfm;
1669 int ifhp = param->ifhp;
1670 int ifwp = param->ifwp;
1671 int ofhp = param->ofhp;
1672 int ofwp = param->ofwp;
1673 int ifh = param->ifh;
1674 int ifw = param->ifw;
1675 int ofh = param->ofh;
1676 int ofw = param->ofw;
1677 int pad_h = param->pad_h;
1678 int pad_w = param->pad_w;
1679 int pad_h_in = param->pad_h_in;
1680 int pad_w_in = param->pad_w_in;
1681 int pad_h_out = param->pad_h_out;
1682 int pad_w_out = param->pad_w_out;
1683 int kh = param->kh;
1684 int kw = param->kw;
1685 int stride_h = param->stride_h;
1686 int stride_w = param->stride_w;
1687 /* loop counters */
1688 int img, ofm, ifm, oj, oi, ij, ii, kj, ki;
1689
1690 LIBXSMM_VLA_DECL(4, int, output_t, output + (pad_w_out * ofwp + pad_h_out), nOfm, ofhp, ofwp);
1691 LIBXSMM_VLA_DECL(4, const short, input_t, input + (pad_w_in * ifwp + pad_h_in), nIfm, ifhp, ifwp);
1692 LIBXSMM_VLA_DECL(4, const short, filter_t, filter, nIfm, kh, kw);
1693
1694
1695 #if defined(_OPENMP)
1696 LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(ofm); LIBXSMM_OMP_VAR(oj); LIBXSMM_OMP_VAR(oi);
1697 LIBXSMM_OMP_VAR(ifm); LIBXSMM_OMP_VAR(ij); LIBXSMM_OMP_VAR(ii); LIBXSMM_OMP_VAR(kj); LIBXSMM_OMP_VAR(ki);
1698 # pragma omp parallel for LIBXSMM_OPENMP_COLLAPSE(2) private(img, ofm, ifm, oj, oi, ij, ii, kj, ki)
1699 #endif
1700 for (img = 0; img < nImg; ++img) {
1701 for (ofm = 0; ofm < nOfm; ++ofm) {
1702 for (ifm = 0; ifm < nIfm; ++ifm) {
1703 for (oj = 0; oj < ofh; ++oj) {
1704 ij = oj * stride_h - pad_h;
1705 for (oi = 0; oi < ofw; ++oi) {
1706 ii = oi * stride_w - pad_w;
1707 for (kj = 0; kj < kh; ++kj) {
1708 if (ij+kj < 0 || ij+kj >= ifh) continue;
1709 for (ki = 0; ki < kw; ++ki) {
1710 if (ii+ki < 0 || ii+ki >= ifw) continue;
1711 LIBXSMM_VLA_ACCESS( 4, output_t, img, ofm, oj, oi, nOfm, ofhp, ofwp) += (int)
1712 ( (int)LIBXSMM_VLA_ACCESS(4, input_t, img, ifm, ij + kj, ii + ki, nIfm, ifhp, ifwp))
1713 * ( (int) LIBXSMM_VLA_ACCESS(4, filter_t, ofm, ifm, kj, ki, nIfm, kh, kw));
1714 }
1715 }
1716 }
1717 }
1718 }
1719 }
1720 }
1721 }
1722
naive_conv_fp_int8int32(naive_conv_t * param,const unsigned char * input,int * output,const char * filter)1723 LIBXSMM_INLINE void naive_conv_fp_int8int32(naive_conv_t* param, const unsigned char* input, int* output, const char* filter)
1724 {
1725 int nImg = param->nImg;
1726 int nIfm = param->nIfm;
1727 int nOfm = param->nOfm;
1728 int ifhp = param->ifhp;
1729 int ifwp = param->ifwp;
1730 int ofhp = param->ofhp;
1731 int ofwp = param->ofwp;
1732 int ifh = param->ifh;
1733 int ifw = param->ifw;
1734 int ofh = param->ofh;
1735 int ofw = param->ofw;
1736 int pad_h = param->pad_h;
1737 int pad_w = param->pad_w;
1738 int pad_h_in = param->pad_h_in;
1739 int pad_w_in = param->pad_w_in;
1740 int pad_h_out = param->pad_h_out;
1741 int pad_w_out = param->pad_w_out;
1742 int kh = param->kh;
1743 int kw = param->kw;
1744 int stride_h = param->stride_h;
1745 int stride_w = param->stride_w;
1746 /* loop counters */
1747 int img, ofm, ifm, oj, oi, ij, ii, kj, ki;
1748
1749 LIBXSMM_VLA_DECL(4, int, output_t, output + (pad_w_out * ofwp + pad_h_out), nOfm, ofhp, ofwp);
1750 LIBXSMM_VLA_DECL(4, const unsigned char, input_t, input + (pad_w_in * ifwp + pad_h_in), nIfm, ifhp, ifwp);
1751 LIBXSMM_VLA_DECL(4, const char, filter_t, filter, nIfm, kh, kw);
1752
1753
1754 #if defined(_OPENMP)
1755 LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(ofm); LIBXSMM_OMP_VAR(oj); LIBXSMM_OMP_VAR(oi);
1756 LIBXSMM_OMP_VAR(ifm); LIBXSMM_OMP_VAR(ij); LIBXSMM_OMP_VAR(ii); LIBXSMM_OMP_VAR(kj); LIBXSMM_OMP_VAR(ki);
1757 # pragma omp parallel for LIBXSMM_OPENMP_COLLAPSE(2) private(img, ofm, ifm, oj, oi, ij, ii, kj, ki)
1758 #endif
1759 for (img = 0; img < nImg; ++img) {
1760 for (ofm = 0; ofm < nOfm; ++ofm) {
1761 for (ifm = 0; ifm < nIfm; ++ifm) {
1762 for (oj = 0; oj < ofh; ++oj) {
1763 ij = oj * stride_h - pad_h;
1764 for (oi = 0; oi < ofw; ++oi) {
1765 ii = oi * stride_w - pad_w;
1766 for (kj = 0; kj < kh; ++kj) {
1767 if (ij+kj < 0 || ij+kj >= ifh) continue;
1768 for (ki = 0; ki < kw; ++ki) {
1769 if (ii+ki < 0 || ii+ki >= ifw) continue;
1770 LIBXSMM_VLA_ACCESS(4, output_t, img, ofm, oj, oi, nOfm, ofhp, ofwp) += (int)
1771 LIBXSMM_VLA_ACCESS(4, input_t, img, ifm, ij + kj, ii + ki, nIfm, ifhp, ifwp)
1772 * LIBXSMM_VLA_ACCESS(4, filter_t, ofm, ifm, kj, ki, nIfm, kh, kw);
1773 }
1774 }
1775 }
1776 }
1777 }
1778 }
1779 }
1780 }
1781
/* Naive reference fully-connected forward pass: output = input * filter^T.
 * input is [N][C], filter is [K][C], output is [N][K]; the output is
 * zero-initialized here, so no caller-side init is needed.
 * Parallelized over the output-feature dimension. */
LIBXSMM_INLINE void naive_fullyconnected_fp(naive_fullyconnected_t* param, const float* input_ptr, float* output_ptr, const float* filter_ptr)
{
  const int nImg = param->N;
  const int nIFm = param->C;
  const int nOFm = param->K;

  int img, ifm, ofm;

  LIBXSMM_VLA_DECL(2, const float, input, input_ptr, nIFm);
  LIBXSMM_VLA_DECL(2, const float, filter, filter_ptr, nIFm);
  LIBXSMM_VLA_DECL(2, float, output, output_ptr, nOFm);

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(ifm); LIBXSMM_OMP_VAR(ofm);
# pragma omp parallel for private(img, ofm, ifm)
#endif
  for (ofm = 0; ofm < nOFm; ++ofm) {
    for(img = 0; img < nImg; ++img) {
      LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) = (float)0; /* fresh accumulator per (img, ofm) */
      for (ifm = 0; ifm < nIFm; ++ifm) {
        LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) +=
          LIBXSMM_VLA_ACCESS(2, filter, ofm, ifm, nIFm) * LIBXSMM_VLA_ACCESS(2, input, img, ifm, nIFm);
      }
    }
  }
}
1808
/* Naive reference fully-connected backward-by-data pass:
 * dinput = doutput * filter, with doutput [N][K], filter [K][C],
 * dinput [N][C]; dinput is zero-initialized here.
 * Parallelized over the input-feature dimension. */
LIBXSMM_INLINE void naive_fullyconnected_bp(naive_fullyconnected_t* param, float* delinput_ptr, const float* deloutput_ptr, const float* filter_ptr)
{
  const int nImg = param->N;
  const int nIFm = param->C;
  const int nOFm = param->K;

  int img, ifm, ofm;

  LIBXSMM_VLA_DECL(2, float, dinput, delinput_ptr, nIFm);
  LIBXSMM_VLA_DECL(2, const float, filter, filter_ptr, nIFm);
  LIBXSMM_VLA_DECL(2, const float, doutput, deloutput_ptr, nOFm);

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(ofm); LIBXSMM_OMP_VAR(ifm);
# pragma omp parallel for private(img, ofm, ifm)
#endif
  for (ifm = 0; ifm < nIFm; ++ifm) {
    for(img = 0; img < nImg; ++img) {
      LIBXSMM_VLA_ACCESS(2, dinput, img, ifm, nIFm) = (float)0; /* fresh accumulator per (img, ifm) */
      for (ofm = 0; ofm < nOFm; ++ofm) {
        LIBXSMM_VLA_ACCESS(2, dinput, img, ifm, nIFm) +=
          LIBXSMM_VLA_ACCESS(2, filter, ofm, ifm, nIFm) * LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm);
      }
    }
  }
}
1835
/* Naive reference fully-connected forward pass with fused post-ops selected
 * by param->fuse_type:
 *   1: add bias;  2: ReLU;  3: sigmoid (via the identity
 *      sigmoid(x) = (tanh(x/2)+1)/2);  4: bias then ReLU;  5: bias then sigmoid.
 * Any other value applies no post-op. NOTE(review): these codes differ from
 * the fuse_type comment on naive_fullyconnected_t ("0: nothing fused") --
 * confirm the intended encoding with the callers. */
LIBXSMM_INLINE void naive_fullyconnected_fused_fp(naive_fullyconnected_t* param, const float* input_ptr, float* output_ptr, const float* filter_ptr, const float* bias_ptr)
{
  const int nImg = param->N;
  const int nIFm = param->C;
  const int nOFm = param->K;

  int img, ifm, ofm;

  LIBXSMM_VLA_DECL(2, const float, input, input_ptr, nIFm);
  LIBXSMM_VLA_DECL(2, const float, filter, filter_ptr, nIFm);
  LIBXSMM_VLA_DECL(2, float, output, output_ptr, nOFm);

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(ifm); LIBXSMM_OMP_VAR(ofm);
# pragma omp parallel for private(img, ofm, ifm)
#endif
  for (ofm = 0; ofm < nOFm; ++ofm) {
    for(img = 0; img < nImg; ++img) {
      LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) = (float)0;
      for (ifm = 0; ifm < nIFm; ++ifm) {
        LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) +=
          LIBXSMM_VLA_ACCESS(2, filter, ofm, ifm, nIFm) * LIBXSMM_VLA_ACCESS(2, input, img, ifm, nIFm);
      }
      /* fused post-op dispatch (see table above) */
      if ( param->fuse_type == 1 ) {
        LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) += bias_ptr[ofm];
      } else if ( param->fuse_type == 2 ) {
        LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) = ( LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) > 0 ) ? LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) : 0;
      } else if ( param->fuse_type == 3 ) {
        LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) = ((float)tanh((double)LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm)/2.0)+1.0f)/2.0f;
      } else if ( param->fuse_type == 4 ) {
        LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) += bias_ptr[ofm];
        LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) = ( LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) > 0 ) ? LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) : 0;
      } else if ( param->fuse_type == 5 ) {
        LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) += bias_ptr[ofm];
        LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) = ((float)tanh((double)LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm)/2.0)+1.0f)/2.0f;
      }
    }
  }
}
1875
naive_fullyconnected_fused_bp(naive_fullyconnected_t * param,float * delinput_ptr,float * deloutput_ptr,const float * filter_ptr,float * delbias_ptr,const float * output_ptr)1876 LIBXSMM_INLINE void naive_fullyconnected_fused_bp(naive_fullyconnected_t* param, float* delinput_ptr, float* deloutput_ptr, const float* filter_ptr, float* delbias_ptr, const float* output_ptr)
1877 {
1878 const int nImg = param->N;
1879 const int nIFm = param->C;
1880 const int nOFm = param->K;
1881
1882 int img, ifm, ofm;
1883
1884 LIBXSMM_VLA_DECL(2, float, dinput, delinput_ptr, nIFm);
1885 LIBXSMM_VLA_DECL(2, const float, filter, filter_ptr, nIFm);
1886 LIBXSMM_VLA_DECL(2, float, doutput, deloutput_ptr, nOFm);
1887 LIBXSMM_VLA_DECL(2, const float, output, output_ptr, nOFm);
1888
1889 if ( param->fuse_type != 0 ) {
1890 #if defined(_OPENMP)
1891 LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(ofm);
1892 # pragma omp parallel for private(img, ofm)
1893 #endif
1894 for (ofm = 0; ofm < nOFm; ++ofm) {
1895 float dbias = 0.0f;
1896 for(img = 0; img < nImg; ++img) {
1897 if ( param->fuse_type == 1 ) {
1898 dbias += LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm);
1899 } else if ( param->fuse_type == 2 ) {
1900 LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm) = ( LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) > 0 ) ? LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm) : 0;
1901 } else if ( param->fuse_type == 3 ) {
1902 LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm) = LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm)*(1.0f-LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm));
1903 } else if ( param->fuse_type == 4 ) {
1904 LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm) = ( LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) > 0 ) ? LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm) : 0;
1905 dbias += LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm);
1906 } else if ( param->fuse_type == 5 ) {
1907 LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm) = LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm)*(1.0f-LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm));
1908 dbias += LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm);
1909 }
1910 }
1911 delbias_ptr[ofm] = dbias;
1912 }
1913 }
1914
1915 #if defined(_OPENMP)
1916 LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(ofm); LIBXSMM_OMP_VAR(ifm);
1917 # pragma omp parallel for private(img, ofm, ifm)
1918 #endif
1919 for (ifm = 0; ifm < nIFm; ++ifm) {
1920 for(img = 0; img < nImg; ++img) {
1921 LIBXSMM_VLA_ACCESS(2, dinput, img, ifm, nIFm) = (float)0;
1922 for (ofm = 0; ofm < nOFm; ++ofm) {
1923 LIBXSMM_VLA_ACCESS(2, dinput, img, ifm, nIFm) +=
1924 LIBXSMM_VLA_ACCESS(2, filter, ofm, ifm, nIFm) * LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm);
1925 }
1926 }
1927 }
1928 }
1929
naive_fullyconnected_wu(naive_fullyconnected_t * param,const float * input_ptr,const float * deloutput_ptr,float * delfilter_ptr)1930 LIBXSMM_INLINE void naive_fullyconnected_wu(naive_fullyconnected_t* param, const float* input_ptr, const float* deloutput_ptr, float* delfilter_ptr)
1931 {
1932 const int nImg = param->N;
1933 const int nIFm = param->C;
1934 const int nOFm = param->K;
1935
1936 int img, ifm, ofm;
1937
1938 LIBXSMM_VLA_DECL(2, const float, input, input_ptr, nIFm);
1939 LIBXSMM_VLA_DECL(2, float, dfilter, delfilter_ptr, nIFm);
1940 LIBXSMM_VLA_DECL(2, const float, doutput, deloutput_ptr, nOFm);
1941
1942 #if defined(_OPENMP)
1943 LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(ofm); LIBXSMM_OMP_VAR(ifm);
1944 # pragma omp parallel for private(img, ofm, ifm)
1945 #endif
1946 for (ofm = 0; ofm < nOFm; ++ofm) {
1947 for (ifm = 0; ifm < nIFm; ++ifm) {
1948 LIBXSMM_VLA_ACCESS(2, dfilter, ofm, ifm, nIFm) = (float)0;
1949 for(img = 0; img < nImg; ++img) {
1950 LIBXSMM_VLA_ACCESS(2, dfilter, ofm, ifm, nIFm) +=
1951 LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm) * LIBXSMM_VLA_ACCESS(2, input, img, ifm, nIFm);
1952 }
1953 }
1954 }
1955 }
1956
naive_pooling_fp(naive_pooling_t * param,const float * input_ptr,float * output_ptr,int * mask_ptr)1957 LIBXSMM_INLINE void naive_pooling_fp(naive_pooling_t* param, const float* input_ptr, float* output_ptr, int* mask_ptr)
1958 {
1959 const int nImg = param->N;
1960 const int nFm = param->C;
1961 const int ifh = param->H;
1962 const int ifw = param->W;
1963 const int sh = param->stride_h;
1964 const int sw = param->stride_w;
1965 const int r = param->R;
1966 const int s = param->S;
1967 const int pad_h = param->pad_h;
1968 const int pad_w = param->pad_w;
1969 const int ofh = (ifh + 2*pad_h - r)/sh + 1;
1970 const int ofw = (ifw + 2*pad_w - s)/sw + 1;
1971
1972
1973 int img, fm;
1974
1975 LIBXSMM_VLA_DECL(4, const float, input, input_ptr, nFm, ifh, ifw);
1976 LIBXSMM_VLA_DECL(4, int, mask, mask_ptr, nFm, ofh, ofw);
1977 LIBXSMM_VLA_DECL(4, float, output, output_ptr, nFm, ofh, ofw);
1978
1979 #if defined(_OPENMP)
1980 float* tmp_buffer = (float*)malloc(sizeof(float)*ofh*ofw*omp_get_max_threads());
1981 LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(fm);
1982 # pragma omp parallel for private(img, fm)
1983 #else
1984 float* tmp_buffer = (float*)malloc(sizeof(float)*ofh*ofw);
1985 #endif
1986 for (img = 0; img < nImg; img++) {
1987 for (fm = 0; fm < nFm; fm++) {
1988 #if defined(_OPENMP)
1989 float* lcl_buffer_ptr = tmp_buffer + (ofh*ofw*omp_get_thread_num());
1990 #else
1991 float* lcl_buffer_ptr = tmp_buffer;
1992 #endif
1993 LIBXSMM_VLA_DECL(2, float, lcl_buffer, lcl_buffer_ptr, ofw);
1994 int i, ho, wo, hi, wi, kh, kw;
1995
1996 if (param->type == 0 ) {
1997 for ( i = 0; i < ofh*ofw; i++ ) {
1998 lcl_buffer_ptr[i] = -FLT_MAX;
1999 }
2000 } else if (param->type == 1) {
2001 for ( i = 0; i < ofh*ofw; i++ ) {
2002 lcl_buffer_ptr[i] = 0.0;
2003 }
2004 } else {
2005 /* shouldn't happen */
2006 }
2007
2008 for( ho = 0; ho < ofh; ho++ ) {
2009 hi = (ho * sh) - pad_h;
2010 for( wo = 0; wo < ofw; wo++ ) {
2011 wi = (wo * sw) - pad_w;
2012 for( kh = 0; kh < r; kh++ ) {
2013 if (hi+kh < 0 || hi+kh >= ifh) continue;
2014 for( kw = 0; kw < s; kw++ ) {
2015 if (wi+kw < 0 || wi+kw >= ifw) continue;
2016 if ( param->type == 0 ) {
2017 const int index = (hi+kh)*ifw + wi+kw;
2018 if ( LIBXSMM_VLA_ACCESS(4, input, img, fm, hi+kh, wi+kw, nFm, ifh, ifw) > LIBXSMM_VLA_ACCESS(2, lcl_buffer, ho, wo, ofw) ) {
2019 LIBXSMM_VLA_ACCESS(2, lcl_buffer, ho, wo, ofw) = LIBXSMM_VLA_ACCESS(4, input, img, fm, hi+kh, wi+kw, nFm, ifh, ifw);
2020 LIBXSMM_VLA_ACCESS(4, mask, img, fm, ho, wo, nFm, ofh, ofw) = index;
2021 }
2022 } else if ( param->type == 1 ) {
2023 LIBXSMM_VLA_ACCESS(2, lcl_buffer, ho, wo, ofw) += LIBXSMM_VLA_ACCESS(4, input, img, fm, hi+kh, wi+kw, nFm, ifh, ifw);
2024 } else {
2025 /* shouldn't happen */
2026 }
2027 }
2028 }
2029 }
2030 }
2031
2032 if (param->type == 0 ) {
2033 for( ho = 0; ho < ofh; ho++ ) {
2034 for( wo = 0; wo < ofw; wo++ ) {
2035 LIBXSMM_VLA_ACCESS(4, output, img, fm, ho, wo, nFm, ofh, ofw) = LIBXSMM_VLA_ACCESS(2, lcl_buffer, ho, wo, ofw);
2036 }
2037 }
2038 } else if (param->type == 1) {
2039 for( ho = 0; ho < ofh; ho++ ) {
2040 for( wo = 0; wo < ofw; wo++ ) {
2041 LIBXSMM_VLA_ACCESS(4, output, img, fm, ho, wo, nFm, ofh, ofw) = LIBXSMM_VLA_ACCESS(2, lcl_buffer, ho, wo, ofw) * (1.0f/(((float)r) * ((float)s)));
2042 }
2043 }
2044 } else {
2045 /* shouldn't happen */
2046 }
2047 }
2048 }
2049
2050 free( tmp_buffer );
2051 }
2052
naive_pooling_bp(naive_pooling_t * param,float * dinput_ptr,const float * doutput_ptr,const int * mask_ptr)2053 LIBXSMM_INLINE void naive_pooling_bp(naive_pooling_t* param, float* dinput_ptr, const float* doutput_ptr, const int* mask_ptr)
2054 {
2055 const int nImg = param->N;
2056 const int nFm = param->C;
2057 const int ifh = param->H;
2058 const int ifw = param->W;
2059 const int sh = param->stride_h;
2060 const int sw = param->stride_w;
2061 const int r = param->R;
2062 const int s = param->S;
2063 const int pad_h = param->pad_h;
2064 const int pad_w = param->pad_w;
2065 const int ofh = (ifh + 2*pad_h - r)/sh + 1;
2066 const int ofw = (ifw + 2*pad_w - s)/sw + 1;
2067
2068 int img, fm;
2069
2070 LIBXSMM_VLA_DECL(4, float, dinput, dinput_ptr, nFm, ifh, ifw);
2071 LIBXSMM_VLA_DECL(4, const int , mask, mask_ptr, nFm, ofh, ofw);
2072 LIBXSMM_VLA_DECL(4, const float, doutput, doutput_ptr, nFm, ofh, ofw);
2073
2074 #if defined(_OPENMP)
2075 float* tmp_buffer = (float*)malloc(sizeof(float)*ifh*ifw*omp_get_max_threads());
2076 LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(fm);
2077 # pragma omp parallel for private(img, fm)
2078 #else
2079 float* tmp_buffer = (float*)malloc(sizeof(float)*ofh*ofw);
2080 #endif
2081 for (img = 0; img < nImg; img++) {
2082 for (fm = 0; fm < nFm; fm++) {
2083 #if defined(_OPENMP)
2084 float* lcl_buffer_ptr = tmp_buffer + (ifh*ifw*omp_get_thread_num());
2085 #else
2086 float* lcl_buffer_ptr = tmp_buffer;
2087 #endif
2088 LIBXSMM_VLA_DECL(2, float, lcl_buffer, lcl_buffer_ptr, ifw);
2089 int i, ho, wo, hi, wi, kh, kw;
2090
2091 for ( i = 0; i < ifh*ifw; i++ ) {
2092 lcl_buffer_ptr[i] = 0.0;
2093 }
2094
2095 if (param->type == 0 ) {
2096 for( ho = 0; ho < ofh; ho++ ) {
2097 for( wo = 0; wo < ofw; wo++ ) {
2098 lcl_buffer_ptr[LIBXSMM_VLA_ACCESS(4, mask, img, fm, ho, wo, nFm, ofh, ofw)] += LIBXSMM_VLA_ACCESS(4, doutput, img, fm, ho, wo, nFm, ofh, ofw);
2099 }
2100 }
2101 } else if ( param->type == 1 ) {
2102 for( ho = 0; ho < ofh; ho++ ) {
2103 hi = (ho * sh) - pad_h;
2104 for( wo = 0; wo < ofw; wo++ ) {
2105 wi = (wo * sw) - pad_w;
2106 for( kh = 0; kh < r; kh++ ) {
2107 if (hi+kh < 0 || hi+kh >= ifh) continue;
2108 for( kw = 0; kw < s; kw++ ) {
2109 if (wi+kw < 0 || wi+kw >= ifw) continue;
2110 LIBXSMM_VLA_ACCESS(2, lcl_buffer, hi+kh, wi+kw, ifw) += ( LIBXSMM_VLA_ACCESS(4, doutput, img, fm, ho, wo, nFm, ofh, ofw) * (1.0f/(((float)r) * ((float)s))) );
2111 }
2112 }
2113 }
2114 }
2115 } else {
2116 /* shouldn't happen */
2117 }
2118
2119 for( hi = 0; hi < ifh; hi++ ) {
2120 for( wi = 0; wi < ifw; wi++ ) {
2121 LIBXSMM_VLA_ACCESS(4, dinput, img, fm, hi, wi, nFm, ifh, ifw) = LIBXSMM_VLA_ACCESS(2, lcl_buffer, hi, wi, ifw);
2122 }
2123 }
2124 }
2125 }
2126
2127 free( tmp_buffer );
2128 }
2129
naive_fusedbatchnorm_fp(naive_fusedbatchnorm_t * param,const float * input_ptr,float * output_ptr,const float * input_add_ptr,const float * beta_ptr,const float * gamma_ptr,float * expectval_ptr,float * rcpstddev_ptr,float * variance_ptr)2130 LIBXSMM_INLINE void naive_fusedbatchnorm_fp(naive_fusedbatchnorm_t* param, const float* input_ptr, float* output_ptr, const float* input_add_ptr,
2131 const float* beta_ptr, const float* gamma_ptr, float* expectval_ptr, float* rcpstddev_ptr, float* variance_ptr)
2132 {
2133 const int nImg = param->N;
2134 const int nFm = param->C;
2135 const int ifh = param->H;
2136 const int ifw = param->W;
2137 const int sh = param->stride_h;
2138 const int sw = param->stride_w;
2139 const int ofh = ifh/sh;
2140 const int ofw = ifw/sw;
2141 const float nhw = (float)(nImg * ifh * ifw);
2142 const float recp_nhw = 1.0f/nhw;
2143 const float sqrt_eps = 1e-7f;
2144
2145 int img, fm, hi, wi, ho, wo;
2146
2147 LIBXSMM_VLA_DECL(4, const float, input, input_ptr, nFm, ifh, ifw);
2148 LIBXSMM_VLA_DECL(4, const float, input_add, input_add_ptr, nFm, ifh, ifw);
2149 LIBXSMM_VLA_DECL(4, float, output, output_ptr, nFm, ofh, ofw);
2150
2151 if ( param->norm_type == 0 ) {
2152 #if defined(_OPENMP)
2153 LIBXSMM_OMP_VAR(wi); LIBXSMM_OMP_VAR(hi);
2154 # pragma omp parallel for private(img, fm, hi, wi)
2155 #endif
2156 for (fm = 0; fm < nFm; fm++) {
2157 float ch_sum = 0.0f;
2158 float ch_sumsq = 0.0f;
2159 float tbmean = 0.0f;
2160 float tbmeansq = 0.0f;
2161 float tsqbmean = 0.0f;
2162 float tbrstd = 0.0f;
2163 float tvariance = 0.0f;
2164
2165 for ( img = 0; img < nImg; img++ ) {
2166 for ( hi = 0; hi < ifh; hi++ ) {
2167 for ( wi = 0; wi < ifw; wi++ ) {
2168 const float input_val = LIBXSMM_VLA_ACCESS(4, input, img, fm, hi, wi, nFm, ifh, ifw);
2169 ch_sum += input_val;
2170 ch_sumsq += (input_val * input_val);
2171 }
2172 }
2173 }
2174
2175 tbmean = recp_nhw * ch_sum;
2176 tbmeansq = tbmean * tbmean;
2177 tsqbmean = recp_nhw * ch_sumsq;
2178 tvariance = tsqbmean - tbmeansq;
2179 tbrstd = (float)(1.0/sqrt(tvariance + sqrt_eps));
2180 expectval_ptr[fm] = tbmean;
2181 rcpstddev_ptr[fm] = tbrstd;
2182 variance_ptr[fm] = tvariance;
2183 }
2184 }
2185
2186 #if defined(_OPENMP)
2187 LIBXSMM_OMP_VAR(ho); LIBXSMM_OMP_VAR(wo);
2188 # pragma omp parallel for private(img, fm, hi, wi, ho, wo)
2189 #endif
2190 for ( img = 0; img < nImg; img++ ) {
2191 for ( fm = 0; fm < nFm; fm++ ) {
2192 for ( hi = 0, ho = 0; hi < ifh; hi += sh, ho++ ) {
2193 for ( wi = 0, wo = 0; wi < ifw; wi += sw, wo++ ) {
2194 const float input_val = LIBXSMM_VLA_ACCESS(4, input, img, fm, hi, wi, nFm, ifh, ifw);
2195 const float input_add_val = LIBXSMM_VLA_ACCESS(4, input_add, img, fm, hi, wi, nFm, ifh, ifw);
2196 float* output_ptr2 = &LIBXSMM_VLA_ACCESS(4, output, img, fm, ho, wo, nFm, ofh, ofw);
2197
2198 /* BN + scale (gamma, beta) */
2199 float o = gamma_ptr[fm]*(input_val - expectval_ptr[fm])*rcpstddev_ptr[fm] + beta_ptr[fm];
2200 /* Eltwise */
2201 if ( (param->fuse_type == 2) || (param->fuse_type == 3) || (param->fuse_type == 5) ) {
2202 o += input_add_val;
2203 }
2204 /* ReLU */
2205 if ( (param->fuse_type == 1) || (param->fuse_type == 3) || (param->fuse_type == 4) || (param->fuse_type == 5) ) {
2206 o = ( o < 0.0f ) ? 0.0f : o;
2207 }
2208 *output_ptr2 = o;
2209 }
2210 }
2211 }
2212 }
2213 }
2214
naive_fusedbatchnorm_bp(naive_fusedbatchnorm_t * param,const float * input_ptr,float * dinput_ptr,const float * output_ptr,float * doutput_ptr,float * dinput_add_ptr,const float * beta_ptr,float * del_beta_ptr,const float * gamma_ptr,float * del_gamma_ptr,const float * expectval_ptr,const float * rcpstddev_ptr)2215 LIBXSMM_INLINE void naive_fusedbatchnorm_bp(naive_fusedbatchnorm_t* param, const float* input_ptr, float* dinput_ptr, const float* output_ptr, float* doutput_ptr, float* dinput_add_ptr,
2216 const float* beta_ptr, float* del_beta_ptr, const float* gamma_ptr, float* del_gamma_ptr,
2217 const float* expectval_ptr, const float* rcpstddev_ptr)
2218 {
2219 const int nImg = param->N;
2220 const int nFm = param->C;
2221 const int ifh = param->H;
2222 const int ifw = param->W;
2223 const int sh = param->stride_h;
2224 const int sw = param->stride_w;
2225 const int ofh = ifh/sh;
2226 const int ofw = ifw/sw;
2227 const float nhw = (float)(nImg * ifh * ifw);
2228 const float recp_nhw = 1.0f/nhw;
2229
2230 int img, fm, hi, wi, ho, wo;
2231
2232 LIBXSMM_VLA_DECL(4, const float, input, input_ptr, nFm, ifh, ifw);
2233 LIBXSMM_VLA_DECL(4, float, dinput, dinput_ptr, nFm, ifh, ifw);
2234 LIBXSMM_VLA_DECL(4, float, dinput_add, dinput_add_ptr, nFm, ifh, ifw);
2235 LIBXSMM_VLA_DECL(4, const float, output, output_ptr, nFm, ofh, ofw);
2236 LIBXSMM_VLA_DECL(4, float, doutput, doutput_ptr, nFm, ofh, ofw);
2237 LIBXSMM_UNUSED(beta_ptr);
2238
2239 if ( param->norm_type == 0 ) {
2240 #if defined(_OPENMP)
2241 LIBXSMM_OMP_VAR(hi); LIBXSMM_OMP_VAR(wi); LIBXSMM_OMP_VAR(ho); LIBXSMM_OMP_VAR(wo);
2242 # pragma omp parallel for private(img, fm, hi, wi, ho, wo)
2243 #endif
2244 for ( fm = 0; fm < nFm; fm++ ) {
2245 del_gamma_ptr[fm] = 0.0f;
2246 del_beta_ptr[fm] = 0.0f;
2247
2248 for ( img = 0; img < nImg; img++ ) {
2249 for ( hi = 0, ho = 0; hi < ifh; hi += sh, ho++ ) {
2250 for ( wi = 0, wo = 0; wi < ifw; wi += sw, wo++ ) {
2251 float* del_input_add_ptr = &LIBXSMM_VLA_ACCESS(4, dinput_add, img, fm, hi, wi, fm, ifh, ifw);
2252 const float output_val = LIBXSMM_VLA_ACCESS(4, output, img, fm, ho, wo, fm, ofh, ofw);
2253 const float input_val = LIBXSMM_VLA_ACCESS(4, input, img, fm, hi, wi, fm, ifh, ifw);
2254 float* del_output_ptr = &LIBXSMM_VLA_ACCESS(4, doutput, img, fm, ho, wo, fm, ofh, ofw);
2255
2256 /* ReLU */
2257 if ( (param->fuse_type == 1) || (param->fuse_type == 3) || (param->fuse_type == 4) || (param->fuse_type == 5) ) {
2258 *del_output_ptr = (output_val == 0) ? 0 : *del_output_ptr;
2259 }
2260 /* elementwise */
2261 if ( (param->fuse_type == 2) || (param->fuse_type == 3) || (param->fuse_type == 5) ) {
2262 *del_input_add_ptr = *del_output_ptr;
2263 }
2264 del_gamma_ptr[fm] += (input_val - expectval_ptr[fm]) * (*del_output_ptr) * rcpstddev_ptr[fm];
2265 del_beta_ptr[fm] += *del_output_ptr;
2266 }
2267 }
2268 }
2269 }
2270 }
2271
2272 #if defined(_OPENMP)
2273 # pragma omp parallel for private(img, fm, hi, wi, ho, wo)
2274 #endif
2275 for ( img = 0; img < nImg; img++ ) {
2276 for ( fm = 0; fm < nFm; fm++ ) {
2277 for ( hi = 0, ho = 0; hi < ifh; hi += sh, ho++ ) {
2278 for ( wi = 0, wo = 0; wi < ifw; wi += sw, wo++) {
2279 float* del_input_ptr = &LIBXSMM_VLA_ACCESS(4, dinput, img, fm, hi, wi, fm, ifh, ifw);
2280 const float input_val = LIBXSMM_VLA_ACCESS(4, input, img, fm, hi, wi, fm, ifh, ifw);
2281 const float del_output_val = LIBXSMM_VLA_ACCESS(4, doutput, img, fm, ho, wo, fm, ofh, ofw);
2282
2283 *del_input_ptr = gamma_ptr[fm] * rcpstddev_ptr[fm] * recp_nhw * (nhw * del_output_val -
2284 (del_beta_ptr[fm] + (input_val - expectval_ptr[fm]) * del_gamma_ptr[fm] * rcpstddev_ptr[fm]));
2285 }
2286 }
2287 }
2288 }
2289 }
2290
naive_fusedgroupnorm_fp(naive_fusedgroupnorm_t * param,const float * input_ptr,float * output_ptr,const float * input_add_ptr,const float * beta_ptr,const float * gamma_ptr,float * expectval_ptr,float * rcpstddev_ptr,float * variance_ptr)2291 LIBXSMM_INLINE void naive_fusedgroupnorm_fp(naive_fusedgroupnorm_t* param, const float* input_ptr, float* output_ptr, const float* input_add_ptr,
2292 const float* beta_ptr, const float* gamma_ptr, float* expectval_ptr, float* rcpstddev_ptr, float* variance_ptr)
2293 {
2294 const int nImg = param->N;
2295 const int nFm = param->C;
2296 const int ifh = param->H;
2297 const int ifw = param->W;
2298 const int sh = param->stride_h;
2299 const int sw = param->stride_w;
2300 const int ofh = ifh/sh;
2301 const int ofw = ifw/sw;
2302 const int nG = param->G;
2303 const int nFMG = nFm/nG;
2304 const float ghw = (float)(nFMG * ifh * ifw);
2305 const float recp_ghw = 1.0f/ghw;
2306 const float sqrt_eps = 1e-7f;
2307
2308 int img, g, fmg, hi, wi, ho, wo;
2309
2310 LIBXSMM_VLA_DECL(5, const float, input, input_ptr, nG, nFMG, ifh, ifw);
2311 LIBXSMM_VLA_DECL(5, const float, input_add, input_add_ptr, nG, nFMG, ifh, ifw);
2312 LIBXSMM_VLA_DECL(5, float, output, output_ptr, nG, nFMG, ofh, ofw);
2313
2314 #if defined(_OPENMP)
2315 LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(g); LIBXSMM_OMP_VAR(fmg); LIBXSMM_OMP_VAR(hi); LIBXSMM_OMP_VAR(wi);
2316 # pragma omp parallel for private(img, g, fmg, hi, wi)
2317 #endif
2318 for ( img = 0; img < nImg; img++ ) {
2319 for (g = 0; g < nG; g++) {
2320 float ch_sum = 0.0f;
2321 float ch_sumsq = 0.0f;
2322 float tbmean = 0.0f;
2323 float tbmeansq = 0.0f;
2324 float tsqbmean = 0.0f;
2325 float tbrstd = 0.0f;
2326 float tvariance = 0.0f;
2327
2328 for ( fmg = 0; fmg < nFMG; fmg++) {
2329 for ( hi = 0; hi < ifh; hi++ ) {
2330 for ( wi = 0; wi < ifw; wi++ ) {
2331 const float input_val = LIBXSMM_VLA_ACCESS(5, input, img, g, fmg, hi, wi, nG, nFMG, ifh, ifw);
2332 ch_sum += input_val;
2333 ch_sumsq += (input_val * input_val);
2334 }
2335 }
2336 }
2337
2338 tbmean = recp_ghw * ch_sum;
2339 tbmeansq = tbmean * tbmean;
2340 tsqbmean = recp_ghw * ch_sumsq;
2341 tvariance = tsqbmean - tbmeansq;
2342 tbrstd = (float)(1.0/sqrt(tvariance + sqrt_eps));
2343 expectval_ptr[img*nG+g] = tbmean;
2344 rcpstddev_ptr[img*nG+g] = tbrstd;
2345 variance_ptr[img*nG+g] = tvariance;
2346 }
2347 }
2348
2349 #if defined(_OPENMP)
2350 LIBXSMM_OMP_VAR(ho); LIBXSMM_OMP_VAR(wo);
2351 # pragma omp parallel for private(img, g, fmg, hi, wi, ho, wo)
2352 #endif
2353 for ( img = 0; img < nImg; img++ ) {
2354 for ( g = 0; g < nG; g++ ) {
2355 for ( fmg = 0; fmg < nFMG; fmg++ ) {
2356 for ( hi = 0, ho = 0; hi < ifh; hi += sh, ho++ ) {
2357 for ( wi = 0, wo = 0; wi < ifw; wi += sw, wo++ ) {
2358 const float input_val = LIBXSMM_VLA_ACCESS(5, input, img, g, fmg, hi, wi, nG, nFMG, ifh, ifw);
2359 const float input_add_val = LIBXSMM_VLA_ACCESS(5, input_add, img, g, fmg, hi, wi, nG, nFMG, ifh, ifw);
2360 float* output_ptr2 = &LIBXSMM_VLA_ACCESS(5, output, img, g, fmg, ho, wo, nG, nFMG, ofh, ofw);
2361
2362 /* BN + scale (gamma, beta) */
2363 float o = gamma_ptr[g*nFMG+fmg]*(input_val - expectval_ptr[img*nG+g])*rcpstddev_ptr[img*nG+g] + beta_ptr[g*nFMG+fmg];
2364 /* Eltwise */
2365 if ( (param->fuse_type == 2) || (param->fuse_type == 3) || (param->fuse_type == 5) ) {
2366 o += input_add_val;
2367 }
2368 /* ReLU */
2369 if ( (param->fuse_type == 1) || (param->fuse_type == 3) || (param->fuse_type == 4) || (param->fuse_type == 5) ) {
2370 o = ( o < 0.0f ) ? 0.0f : o;
2371 }
2372 *output_ptr2 = o;
2373 }
2374 }
2375 }
2376 }
2377 }
2378 }
2379
naive_fusedgroupnorm_bp(naive_fusedgroupnorm_t * param,const float * input_ptr,float * dinput_ptr,const float * output_ptr,float * doutput_ptr,float * dinput_add_ptr,const float * beta_ptr,float * del_beta_ptr,const float * gamma_ptr,float * del_gamma_ptr,const float * expectval_ptr,const float * rcpstddev_ptr,const float * variance_ptr)2380 LIBXSMM_INLINE void naive_fusedgroupnorm_bp(naive_fusedgroupnorm_t* param, const float* input_ptr, float* dinput_ptr, const float* output_ptr, float* doutput_ptr, float* dinput_add_ptr,
2381 const float* beta_ptr, float* del_beta_ptr, const float* gamma_ptr, float* del_gamma_ptr,
2382 const float* expectval_ptr, const float* rcpstddev_ptr, const float* variance_ptr)
2383 {
2384 const int nImg = param->N;
2385 const int nFm = param->C;
2386 const int ifh = param->H;
2387 const int ifw = param->W;
2388 const int sh = param->stride_h;
2389 const int sw = param->stride_w;
2390 const int ofh = ifh/sh;
2391 const int ofw = ifw/sw;
2392 const int nG = param->G;
2393 const int nFMG = nFm/nG;
2394 const float ghw = (float)(nFMG * ifh * ifw);
2395 const float recp_ghw = 1.0f/ghw;
2396 const float eps = 1e-7f;
2397
2398 int img, g, fmg, fm, hi, wi, ho, wo;
2399
2400 LIBXSMM_VLA_DECL(5, const float, input, input_ptr, nG, nFMG, ifh, ifw);
2401 LIBXSMM_VLA_DECL(5, float, dinput, dinput_ptr, nG, nFMG, ifh, ifw);
2402 /*LIBXSMM_VLA_DECL(5, const float, output, output_ptr, nG, nFMG, ofh, ofw);*/
2403 LIBXSMM_VLA_DECL(5, float, doutput, doutput_ptr, nG, nFMG, ofh, ofw);
2404
2405 LIBXSMM_VLA_DECL(4, const float, input_gb, input_ptr, nFm, ifh, ifw);
2406 LIBXSMM_VLA_DECL(4, const float, output_gb, output_ptr, nFm, ofh, ofw);
2407 LIBXSMM_VLA_DECL(4, float, doutput_gb, doutput_ptr, nFm, ofh, ofw);
2408 LIBXSMM_VLA_DECL(4, float, dinput_add, dinput_add_ptr, nFm, ifh, ifw);
2409
2410 LIBXSMM_UNUSED(beta_ptr);
2411
2412 #if defined(_OPENMP)
2413 LIBXSMM_OMP_VAR(hi); LIBXSMM_OMP_VAR(wi); LIBXSMM_OMP_VAR(ho); LIBXSMM_OMP_VAR(wo); LIBXSMM_OMP_VAR(g);
2414 # pragma omp parallel for private(img, fm, hi, wi, ho, wo, g)
2415 #endif
2416 for ( fm = 0; fm < nFm; fm++ ) {
2417 del_gamma_ptr[fm] = 0.0f;
2418 del_beta_ptr[fm] = 0.0f;
2419
2420 for ( img = 0; img < nImg; img++ ) {
2421 for ( hi = 0, ho = 0; hi < ifh; hi += sh, ho++ ) {
2422 for ( wi = 0, wo = 0; wi < ifw; wi += sw, wo++ ) {
2423 float* del_input_add_ptr = &LIBXSMM_VLA_ACCESS(4, dinput_add, img, fm, hi, wi, nFm, ifh, ifw);
2424 const float output_val = LIBXSMM_VLA_ACCESS(4, output_gb, img, fm, ho, wo, nFm, ofh, ofw);
2425 const float input_val = LIBXSMM_VLA_ACCESS(4, input_gb, img, fm, hi, wi, nFm, ifh, ifw);
2426 float* del_output_ptr = &LIBXSMM_VLA_ACCESS(4, doutput_gb, img, fm, ho, wo, nFm, ofh, ofw);
2427
2428 /* ReLU */
2429 if ( (param->fuse_type == 1) || (param->fuse_type == 3) || (param->fuse_type == 4) || (param->fuse_type == 5) ) {
2430 *del_output_ptr = (output_val == 0) ? 0 : *del_output_ptr;
2431 }
2432 /* elementwise */
2433 if ( (param->fuse_type == 2) || (param->fuse_type == 3) || (param->fuse_type == 5) ) {
2434 *del_input_add_ptr = *del_output_ptr;
2435 }
2436 g = fm/nFMG;
2437 del_gamma_ptr[fm] += (input_val - expectval_ptr[img*nG+g]) * (*del_output_ptr) * rcpstddev_ptr[img*nG+g];
2438 del_beta_ptr[fm] += *del_output_ptr;
2439 }
2440 }
2441 }
2442 }
2443
2444 #if defined(_OPENMP)
2445 LIBXSMM_OMP_VAR(fmg);
2446 # pragma omp parallel for private(img, g, fmg, hi, wi, ho, wo)
2447 #endif
2448 for ( img = 0; img < nImg; img++ ) {
2449 for ( g = 0; g < nG; g++ ) {
2450 float d1_val = 0.0;
2451 float d2_val = 0.0;
2452
2453 for ( fmg = 0; fmg < nFMG; fmg++ ) {
2454 for ( hi = 0, ho = 0; hi < ifh; hi += sh, ho++ ) {
2455 for ( wi = 0, wo = 0; wi < ifw; wi += sw, wo++) {
2456 const float input_val = LIBXSMM_VLA_ACCESS(5, input, img, g, fmg, hi, wi, nG, nFMG, ifh, ifw);
2457 const float del_output_val = LIBXSMM_VLA_ACCESS(5, doutput, img, g, fmg, ho, wo, nG, nFMG, ofh, ofw);
2458
2459 d1_val += del_output_val * (input_val - expectval_ptr[img*nG+g]) * gamma_ptr[g*nFMG+fmg];
2460 d2_val += del_output_val * gamma_ptr[g*nFMG+fmg];
2461 }
2462 }
2463 }
2464
2465 for ( fmg = 0; fmg < nFMG; fmg++ ) {
2466 for ( hi = 0, ho = 0; hi < ifh; hi += sh, ho++ ) {
2467 for ( wi = 0, wo = 0; wi < ifw; wi += sw, wo++) {
2468 const float input_val = LIBXSMM_VLA_ACCESS(5, input, img, g, fmg, hi, wi, nG, nFMG, ifh, ifw);
2469 const float del_output_val = LIBXSMM_VLA_ACCESS(5, doutput, img, g, fmg, ho, wo, nG, nFMG, ofh, ofw);
2470 float* del_input_ptr = &LIBXSMM_VLA_ACCESS(5, dinput, img, g, fmg, hi, wi, nG, nFMG, ifh, ifw);
2471
2472 float t0_val = rcpstddev_ptr[img*nG+g] * recp_ghw;
2473 *del_input_ptr = t0_val * ((gamma_ptr[g*nFMG+fmg] * ghw * del_output_val) - d2_val - ((input_val - expectval_ptr[img*nG+g]) * d1_val * (1.0f/(variance_ptr[img*nG+g]+eps))));
2474 }
2475 }
2476 }
2477 }
2478 }
2479 }
2480
lstm_fwd_copy_bias(int N,int K,float * bigold,float * bcgold,float * bfgold,float * bogold,float forget_bias,float * icfogoldt,int j)2481 LIBXSMM_INLINE void lstm_fwd_copy_bias(int N, int K, float *bigold, float *bcgold, float *bfgold, float *bogold, float forget_bias, float *icfogoldt, int j)
2482 {
2483 LIBXSMM_VLA_DECL(3, float, icfogold, icfogoldt, N, 4 * K);
2484 int i, l;
2485 #if defined(_OPENMP)
2486 LIBXSMM_OMP_VAR(i); LIBXSMM_OMP_VAR(l);
2487 # pragma omp parallel for private(i, l) LIBXSMM_OPENMP_COLLAPSE(2)
2488 #endif
2489 for (i = 0; i < N; i++) {
2490 for (l = 0; l < K; l++) {
2491 LIBXSMM_VLA_ACCESS(3, icfogold, j, i, l, N, 4 * K) = bigold[l];
2492 LIBXSMM_VLA_ACCESS(3, icfogold, j, i, l+K, N, 4 * K) = bcgold[l];
2493 LIBXSMM_VLA_ACCESS(3, icfogold, j, i, l+2*K, N, 4 * K) = bfgold[l] + forget_bias;
2494 LIBXSMM_VLA_ACCESS(3, icfogold, j, i, l+3*K, N, 4 * K) = bogold[l];
2495 }
2496 }
2497 }
2498
lstm_fwd_eltwise_merged(int N,int K,float * i,float * c,float * f,float * o,float * csp,float * cs,float * co,float * h)2499 LIBXSMM_INLINE void lstm_fwd_eltwise_merged(int N, int K, float *i, float *c, float *f, float *o, float *csp, float *cs, float *co, float *h)
2500 {
2501 int j;
2502 #if defined(__AVX512F__)
2503 int l;
2504 int rem = (K/16)*16;
2505 __m512 minus1 = _mm512_set1_ps (-1.0f);
2506 __m512 plus1 = _mm512_set1_ps (1.0f);
2507 #if defined(_OPENMP)
2508 # pragma omp parallel for private(j, l) LIBXSMM_OPENMP_COLLAPSE(2)
2509 #endif
2510 for (j = 0; j < N; j++) {
2511 for (l = 0; l < rem; l+=16) {
2512 __m512 iv = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(i[j*4*K + l]));
2513 __m512 cv = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(c[j*4*K + l]));
2514 __m512 fv = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(f[j*4*K + l]));
2515 __m512 ov = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(o[j*4*K + l]));
2516 __m512 cspv = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(csp[j*K + l]));
2517 __m512 csv, cov, hv;
2518 /* i = sigmoid(i) */
2519 iv = _mm512_mul_ps (iv, minus1);
2520 iv = LIBXSMM_INTRINSICS_MM512_EXP_PS (iv);
2521 iv = _mm512_add_ps (iv, plus1);
2522 iv = _mm512_div_ps (plus1, iv);
2523 /* c = tanh(c) */
2524 cv = LIBXSMM_INTRINSICS_MM512_TANH_PS (cv);
2525 /* f = sigmoid(f) */
2526 fv = _mm512_mul_ps (fv, minus1);
2527 fv = LIBXSMM_INTRINSICS_MM512_EXP_PS (fv);
2528 fv = _mm512_add_ps (fv, plus1);
2529 fv = _mm512_div_ps (plus1, fv);
2530 /* o = sigmoid(o) */
2531 ov = _mm512_mul_ps (ov, minus1);
2532 ov = LIBXSMM_INTRINSICS_MM512_EXP_PS (ov);
2533 ov = _mm512_add_ps (ov, plus1);
2534 ov = _mm512_div_ps (plus1, ov);
2535 /* cs = f.csp + i.c */
2536 csv = _mm512_mul_ps (fv, cspv);
2537 csv = _mm512_fmadd_ps (iv, cv, csv);
2538 /* co = tanh(cs) */
2539 cov = LIBXSMM_INTRINSICS_MM512_TANH_PS (csv);
2540 /* h = o.co */
2541 hv = _mm512_mul_ps (ov, cov);
2542 _mm512_storeu_ps (&(i[j*4*K + l]), iv);
2543 _mm512_storeu_ps (&(c[j*4*K + l]), cv);
2544 _mm512_storeu_ps (&(f[j*4*K + l]), fv);
2545 _mm512_storeu_ps (&(o[j*4*K + l]), ov);
2546 _mm512_storeu_ps (&(cs[j*K + l]), csv);
2547 _mm512_storeu_ps (&(co[j*K + l]), cov);
2548 _mm512_storeu_ps (&(h[j*K + l]), hv);
2549 }
2550 }
2551 #if defined(_OPENMP)
2552 # pragma omp parallel for private(j, l) LIBXSMM_OPENMP_COLLAPSE(2)
2553 #endif
2554 for (j = 0; j < N; j++) {
2555 for (l = rem; l < K; l++) {
2556 float exp_value;
2557 /* i = sigmoid(i) */
2558 exp_value = (float)exp((double) -i[j*4*K + l]);
2559 i[j*4*K + l] = 1.0f / (1.0f + exp_value);
2560 /* c = tanh(c) */
2561 c[j*4*K + l] = (float)tanh((double)c[j*4*K + l]);
2562 /* f = sigmoid(f) */
2563 exp_value = (float)exp((double) -f[j*4*K + l]);
2564 f[j*4*K + l] = 1.0f / (1.0f + exp_value);
2565 /* o = sigmoid(o) */
2566 exp_value = (float)exp((double) -o[j*4*K + l]);
2567 o[j*4*K + l] = 1.0f / (1.0f + exp_value);
2568 /* cs = f.csp + i.c */
2569 cs[j*K + l] = f[j*4*K + l]*csp[j*K + l] + i[j*4*K + l]*c[j*4*K + l];
2570 /* co = tanh(cs) */
2571 co[j*K + l] = (float)tanh((double)cs[j*K + l]);
2572 /* h = o.co */
2573 h[j*K + l] = o[j*4*K + l] * co[j*K + l];
2574 }
2575 }
2576 #else
2577 #if defined(_OPENMP)
2578 # pragma omp parallel for private(j)
2579 #endif
2580 for (j = 0; j < N*K; j++) {
2581 const int row = j / K;
2582 const int col = j % K;
2583 float exp_value;
2584 /* i = sigmoid(i) */
2585 exp_value = (float)exp((double) -i[row*4*K + col]);
2586 i[row*4*K + col] = 1.0f / (1.0f + exp_value);
2587 /* c = tanh(c) */
2588 c[row*4*K + col] = (float)tanh((double)c[row*4*K + col]);
2589 /* f = sigmoid(f) */
2590 exp_value = (float)exp((double) -f[row*4*K + col]);
2591 f[row*4*K + col] = 1.0f / (1.0f + exp_value);
2592 /* o = sigmoid(o) */
2593 exp_value = (float)exp((double) -o[row*4*K + col]);
2594 o[row*4*K + col] = 1.0f / (1.0f + exp_value);
2595 /* cs = f.csp + i.c */
2596 cs[j] = f[row*4*K + col]*csp[j] + i[row*4*K + col]*c[row*4*K + col];
2597 /* co = tanh(cs) */
2598 co[j] = (float)tanh((double)cs[j]);
2599 /* h = o.co */
2600 h[j] = o[row*4*K + col] * co[j];
2601 }
2602 #endif
2603 }
2604
/* Merged element-wise portion of the LSTM backward + weight-update pass (reference code).
 * Layout: the four forward gate activations i/c/f/o of one batch row are stored
 * interleaved as [i | c | f | o] with a row stride of 4*K; csp (previous cell state),
 * co (tanh of the current cell state), dh, dout, dcs and dcsp are dense N x K
 * matrices with row stride K.  dout may be NULL (last time step); then
 * delta = dh, otherwise delta = dh + dout.
 * Outputs written: gate gradients di/dc/df/dp ("dp" is d-output-gate, i.e. "do",
 * which is a reserved word) and the cell-state gradient dcsp (first used as a
 * temporary for d(cs_prev) before being scaled by f at the end).
 * Three code paths: 16-wide AVX-512 main loop, scalar remainder loop for the
 * trailing K%16 columns, and a plain scalar fallback without AVX-512. */
LIBXSMM_INLINE void lstm_bwd_upd_eltwise_merged(int N, int K, float *i, float *c, float *f, float *o, float *csp, float *co,
                                                float *dh, float *dout, float *di, float *dc, float *df, float *dp, float *dcsp, float *dcs)
{
  int j;
#if defined(__AVX512F__)
  int l;
  int rem = (K/16)*16; /* number of columns covered by full 16-lane SIMD iterations */
  __m512 plus1 = _mm512_set1_ps (1.0f);
#if defined(_OPENMP)
# pragma omp parallel for private(j, l) LIBXSMM_OPENMP_COLLAPSE(2)
#endif
  for (j = 0; j < N; j++) {
    for (l = 0; l < rem; l+=16) {
      __m512 iv = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(i[j*4*K + l]));
      __m512 cv = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(c[j*4*K + l]));
      __m512 fv = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(f[j*4*K + l]));
      __m512 ov = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(o[j*4*K + l]));
      __m512 cspv = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(csp[j*K + l]));
      __m512 cov = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(co[j*K + l]));
      __m512 dcsv = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(dcs[j*K + l]));
      __m512 dhv, doutv, div, dcv, dfv, dov, dcspv, deltav, tv;
      /* compute delta (NULL dout signals the last time step) */
      if (NULL == dout) {
        deltav = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(dh[j*K + l]));
      } else {
        dhv = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(dh[j*K + l]));
        doutv = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(dout[j*K + l]));
        deltav = _mm512_add_ps (dhv, doutv);
      }
      /* compute dcsp */
      /* dcsp = delta.o.(1 - (co.co)) + dcs */
      tv = _mm512_mul_ps (cov, cov);
      tv = _mm512_sub_ps (plus1, tv);
      dcspv = _mm512_mul_ps (deltav, ov);
      dcspv = _mm512_fmadd_ps (dcspv, tv, dcsv);
      /* compute di */
      /* di = dcsp.c.i.(1 - i) */
      tv = _mm512_sub_ps (plus1, iv);
      tv = _mm512_mul_ps (iv, tv);
      div = _mm512_mul_ps (dcspv, cv);
      div = _mm512_mul_ps (div, tv);
      /* compute dc */
      /* dc = dcsp.i.(1 - (c.c)) */
      tv = _mm512_mul_ps (cv, cv);
      tv = _mm512_sub_ps (plus1, tv);
      dcv = _mm512_mul_ps (dcspv, iv);
      dcv = _mm512_mul_ps (dcv, tv);
      /* compute df */
      /* df = dcsp.csp.f.(1 - f) */
      tv = _mm512_sub_ps (plus1, fv);
      tv = _mm512_mul_ps (fv, tv);
      dfv = _mm512_mul_ps (dcspv, cspv);
      dfv = _mm512_mul_ps (dfv, tv);
      /* compute do */
      /* do = delta.co.o.(1 - o) */
      tv = _mm512_sub_ps (plus1, ov);
      tv = _mm512_mul_ps (ov, tv);
      dov = _mm512_mul_ps (deltav, cov);
      dov = _mm512_mul_ps (dov, tv);
      /* update dcsp */
      /* dcsp = dcsp.f */
      dcspv = _mm512_mul_ps (dcspv, fv);
      _mm512_storeu_ps (&(di[j*4*K + l]), div);
      _mm512_storeu_ps (&(dc[j*4*K + l]), dcv);
      _mm512_storeu_ps (&(df[j*4*K + l]), dfv);
      _mm512_storeu_ps (&(dp[j*4*K + l]), dov);
      _mm512_storeu_ps (&(dcsp[j*K + l]), dcspv);
    }
  }
  /* scalar remainder loop for the trailing K%16 columns */
#if defined(_OPENMP)
# pragma omp parallel for private(j, l) LIBXSMM_OPENMP_COLLAPSE(2)
#endif
  for (j = 0; j < N; j++) {
    for (l = rem; l < K; l++) {
      float delta;
      /* compute delta */
      if (NULL == dout) {
        delta = dh[j*K + l];
      } else {
        delta = dh[j*K + l] + dout[j*K + l];
      }
      /* compute dcsp */
      dcsp[j*K + l] = delta * o[j*4*K + l] * (1.0f - (co[j*K + l]*co[j*K + l])) + dcs[j*K + l];
      /* compute di */
      di[j*4*K + l] = dcsp[j*K + l] * c[j*4*K + l] * i[j*4*K + l] * (1.0f - i[j*4*K + l]);
      /* compute dc */
      dc[j*4*K + l] = dcsp[j*K + l] * i[j*4*K + l] * (1.0f - (c[j*4*K + l]*c[j*4*K + l]));
      /* compute df */
      df[j*4*K + l] = dcsp[j*K + l] * csp[j*K + l] * f[j*4*K + l] * (1.0f - f[j*4*K + l]);
      /* compute do */
      dp[j*4*K + l] = delta * co[j*K + l] * o[j*4*K + l] * (1.0f - o[j*4*K + l]);
      /* update dcsp */
      dcsp[j*K + l] = dcsp[j*K + l] * f[j*4*K + l];
    }
  }
#else
  /* scalar fallback: one flat loop over all N*K elements */
#if defined(_OPENMP)
# pragma omp parallel for private(j)
#endif
  for (j = 0; j < N*K; j++) {
    const int row = j / K;
    const int col = j % K;
    float delta;
    /* compute delta */
    if (NULL == dout) {
      delta = dh[j];
    } else {
      delta = dh[j] + dout[j];
    }
    /* compute dcsp */
    dcsp[j] = delta * o[row*4*K + col] * (1.0f - (co[j]*co[j])) + dcs[j];
    /* compute di */
    di[row*4*K + col] = dcsp[j] * c[row*4*K + col] * i[row*4*K + col] * (1.0f - i[row*4*K + col]);
    /* compute dc */
    dc[row*4*K + col] = dcsp[j] * i[row*4*K + col] * (1.0f - (c[row*4*K + col]*c[row*4*K + col]));
    /* compute df */
    df[row*4*K + col] = dcsp[j] * csp[j] * f[row*4*K + col] * (1.0f - f[row*4*K + col]);
    /* compute do */
    dp[row*4*K + col] = delta * co[j] * o[row*4*K + col] * (1.0f - o[row*4*K + col]);
    /* update dcsp */
    dcsp[j] = dcsp[j] * f[row*4*K + col];
  }
#endif
}
2729
/* Reference (BLAS-based) LSTM forward pass over t time steps.
 * N: batch, C: input feature size, K: hidden size, forget_bias is added to the
 * forget-gate bias.  The four per-gate weight matrices w* (C x K) and recurrent
 * matrices r* (K x K) are first packed into a fused [i|c|f|o] layout with leading
 * dimension 4*K (one fused W and one fused R with TWO_GEMMS; a single stacked
 * [W;R] of size (C+K) x 4K otherwise, with x and h concatenated into xhgold from
 * the scratch area).  Per step: icfo = bias; icfo += W*x (+ R*h_{t-1}); then the
 * merged element-wise kernel produces cs, co and h.  PROFILE sections accumulate
 * timings into the Gbl_* globals. */
LIBXSMM_INLINE void lstm_ref_fwd( int N, int C, int K, int t, float forget_bias,
                                  float *wigold, float *wcgold, float *wfgold, float *wogold,
                                  float *rigold, float *rcgold, float *rfgold, float *rogold,
                                  float *bigold, float *bcgold, float *bfgold, float *bogold,
                                  float *xgoldt, float *cspgold, float *hpgold,
                                  float *csgoldt, float *cogoldt, float *hgoldt,
                                  float *icfogoldt, float *wgold, float *rgold, float *scratch )
{
#if !defined(TWO_GEMMS)
  float *xhgold = scratch; /* (C+K) x N buffer holding the concatenation [x; h] */
#endif
  const char transa = 'N', transb = 'N'; /* no transposes */
  const float alpha = 1, beta = 1;       /* beta = 1: GEMMs accumulate onto the bias-initialized icfo */
  int j;
  int K4 = K * 4;
  int CK = C + K;
  LIBXSMM_VLA_DECL(2, float, xgold, xgoldt, N * C);
  LIBXSMM_VLA_DECL(2, float, csgold, csgoldt, K * N);
  LIBXSMM_VLA_DECL(2, float, cogold, cogoldt, K * N);
  LIBXSMM_VLA_DECL(2, float, hgold, hgoldt, K * N);
  LIBXSMM_VLA_DECL(3, float, icfogold, icfogoldt, N, 4 * K);
#if defined(PROFILE)
  Gbl_conv_start = libxsmm_timer_tick();
#endif
  /* pack the per-gate weights into the fused [i|c|f|o] layout (ld = 4*K) */
#if defined(TWO_GEMMS)
  convert_ck_c4k(C, K, wigold, wgold);
  convert_ck_c4k(C, K, wcgold, &(wgold[K]));
  convert_ck_c4k(C, K, wfgold, &(wgold[2*K]));
  convert_ck_c4k(C, K, wogold, &(wgold[3*K]));
  convert_ck_c4k(K, K, rigold, rgold);
  convert_ck_c4k(K, K, rcgold, &(rgold[K]));
  convert_ck_c4k(K, K, rfgold, &(rgold[2*K]));
  convert_ck_c4k(K, K, rogold, &(rgold[3*K]));
#else
  /* single stacked weight [W; R]: recurrent part starts at row C (offset C*K*4) */
  LIBXSMM_UNUSED(rgold);
  convert_ck_c4k(C, K, wigold, wgold);
  convert_ck_c4k(C, K, wcgold, &(wgold[K]));
  convert_ck_c4k(C, K, wfgold, &(wgold[2*K]));
  convert_ck_c4k(C, K, wogold, &(wgold[3*K]));
  convert_ck_c4k(K, K, rigold, &(wgold[C*K*4]));
  convert_ck_c4k(K, K, rcgold, &(wgold[C*K*4 + K]));
  convert_ck_c4k(K, K, rfgold, &(wgold[C*K*4 + 2*K]));
  convert_ck_c4k(K, K, rogold, &(wgold[C*K*4 + 3*K]));
#endif
#if defined(PROFILE)
  Gbl_conv_end = libxsmm_timer_tick();
  Gbl_conv_total += libxsmm_timer_duration(Gbl_conv_start, Gbl_conv_end);
#endif
  for (j = 0; j < t; ++j) {
    /* Initialization with bias (forget_bias is folded into the f-gate bias) */
#if defined(PROFILE)
    Gbl_copy_bias_start = libxsmm_timer_tick();
#endif
    lstm_fwd_copy_bias(N, K, bigold, bcgold, bfgold, bogold, forget_bias, icfogoldt, j);
#if defined(PROFILE)
    Gbl_copy_bias_end = libxsmm_timer_tick();
    Gbl_copy_bias_total += libxsmm_timer_duration(Gbl_copy_bias_start, Gbl_copy_bias_end);
    Gbl_blas_start = libxsmm_timer_tick();
#endif
#if defined(TWO_GEMMS)
    /* icfo += W * x */
    LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K4, &N, &C, &alpha, wgold, &K4, &LIBXSMM_VLA_ACCESS(2, xgold, j, 0, N * C), &C, &beta, &LIBXSMM_VLA_ACCESS(3, icfogold, j, 0, 0, N, 4 * K), &K4);
    /* icfo += R * h (h_{t-1} is the external hpgold at the first step) */
    if (j == 0) {
      LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K4, &N, &K, &alpha, rgold, &K4, hpgold, &K, &beta, &LIBXSMM_VLA_ACCESS(3, icfogold, 0, 0, 0, N, 4 * K), &K4);
    } else {
      LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K4, &N, &K, &alpha, rgold, &K4, &LIBXSMM_VLA_ACCESS(2, hgold, j-1, 0, K * N), &K, &beta, &LIBXSMM_VLA_ACCESS(3, icfogold, j, 0, 0, N, 4 * K), &K4);
    }
#else
    /* Concatenate x and h */
    convert_nk_nck(N, C, C+K, &LIBXSMM_VLA_ACCESS(2, xgold, j, 0, N * C), xhgold);
    if (j == 0) {
      convert_nk_nck(N, K, C+K, hpgold, &(xhgold[C]));
    } else {
      convert_nk_nck(N, K, C+K, &LIBXSMM_VLA_ACCESS(2, hgold, j-1, 0, K * N), &(xhgold[C]));
    }
    /* icfo += (W * x) + (R * h) */
    LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K4, &N, &CK, &alpha, wgold, &K4, xhgold, &CK, &beta, &LIBXSMM_VLA_ACCESS(3, icfogold, j, 0, 0, N, 4 * K), &K4);
#endif
#if defined(PROFILE)
    Gbl_blas_end = libxsmm_timer_tick();
    Gbl_blas_total += libxsmm_timer_duration(Gbl_blas_start, Gbl_blas_end);
    Gbl_eltwise_start = libxsmm_timer_tick();
#endif
    /* element-wise gate math; previous cell state is cspgold at the first step */
    if (j == 0) {
      lstm_fwd_eltwise_merged( N, K,
                               &LIBXSMM_VLA_ACCESS(3, icfogold, 0, 0, 0, N, 4 * K),
                               &LIBXSMM_VLA_ACCESS(3, icfogold, 0, 0, K, N, 4 * K),
                               &LIBXSMM_VLA_ACCESS(3, icfogold, 0, 0, 2*K, N, 4 * K),
                               &LIBXSMM_VLA_ACCESS(3, icfogold, 0, 0, 3*K, N, 4 * K),
                               cspgold,
                               &LIBXSMM_VLA_ACCESS(2, csgold, 0, 0, K * N),
                               &LIBXSMM_VLA_ACCESS(2, cogold, 0, 0, K * N),
                               &LIBXSMM_VLA_ACCESS(2, hgold, 0, 0, K * N) );
    } else {
      lstm_fwd_eltwise_merged( N, K,
                               &LIBXSMM_VLA_ACCESS(3, icfogold, j, 0, 0, N, 4 * K),
                               &LIBXSMM_VLA_ACCESS(3, icfogold, j, 0, K, N, 4 * K),
                               &LIBXSMM_VLA_ACCESS(3, icfogold, j, 0, 2*K, N, 4 * K),
                               &LIBXSMM_VLA_ACCESS(3, icfogold, j, 0, 3*K, N, 4 * K),
                               &LIBXSMM_VLA_ACCESS(2, csgold, j-1, 0, K * N),
                               &LIBXSMM_VLA_ACCESS(2, csgold, j, 0, K * N),
                               &LIBXSMM_VLA_ACCESS(2, cogold, j, 0, K * N),
                               &LIBXSMM_VLA_ACCESS(2, hgold, j, 0, K * N) );
    }
#if defined(PROFILE)
    Gbl_eltwise_end = libxsmm_timer_tick();
    Gbl_eltwise_total += libxsmm_timer_duration(Gbl_eltwise_start, Gbl_eltwise_end);
#endif
  }
}
2841
/* Reference (BLAS-based) LSTM backward + weight-update pass, iterating the time
 * steps in reverse.  Produces dxgoldt (input gradients), dwgold/drgold (weight
 * gradients, accumulated with beta = 1), dbgold (bias gradients, accumulated),
 * and dhpgold (gradient w.r.t. the initial hidden state).
 * Scratch layout: dicfogoldt = scratch[0 .. 4*K*N*t), doutgoldt starts at
 * K*N*t*4; without TWO_GEMMS, xhgold/dxhgold ((C+K) x N each) follow at
 * K*N*t*5.  dcspgold is a K*N work buffer carried between steps. */
LIBXSMM_INLINE void lstm_ref_bwd_upd( int N, int C, int K, int t,
                                      float *xgoldt, float *cspgold, float *hpgold,
                                      float *csgoldt, float *cogoldt, float *hgoldt,
                                      float *icfogoldt, float *wgold, float *rgold,
                                      float *dcsgold, float *dhgoldt,
                                      float *dwgold, float *drgold, float *dbgold,
                                      float *dxgoldt, float *dcspgold, float *dhpgold, float *scratch )
{
#if !defined(TWO_GEMMS)
  float *xhgold = &(scratch[K*N*t*5]);
  float *dxhgold = &(scratch[K*N*t*5 + (C+K)*N]);
#endif
  float *dicfogoldt = scratch;
  float *doutgoldt = &(scratch[K*N*t*4]);
  float *dout, *dcs, *csp;
  const char transa = 'N', transb = 'N'; /* no transposes */
  const char transaT = 'T', transbT = 'T'; /* transposes */
  const float alpha = 1, beta = 1, beta0 = 0;
  int j, l, p;
  int K4 = K * 4;
  int CK = C + K;
  LIBXSMM_VLA_DECL(2, float, xgold, xgoldt, N * C);
  LIBXSMM_VLA_DECL(2, float, csgold, csgoldt, K * N);
  LIBXSMM_VLA_DECL(2, float, cogold, cogoldt, K * N);
  LIBXSMM_VLA_DECL(2, float, hgold, hgoldt, K * N);
  LIBXSMM_VLA_DECL(3, float, icfogold, icfogoldt, N, 4 * K);
  LIBXSMM_VLA_DECL(2, float, dxgold, dxgoldt, N * C);
  LIBXSMM_VLA_DECL(2, float, dhgold, dhgoldt, K * N);
  LIBXSMM_VLA_DECL(3, float, dicfogold, dicfogoldt, N, 4 * K);
  LIBXSMM_VLA_DECL(2, float, doutgold, doutgoldt, K * N);
  for (j = t-1; j >= 0; --j) {
#if defined(PROFILE)
    Gbl_eltwise_start = libxsmm_timer_tick();
#endif
    /* last step: no dout from a later step, dcs comes from outside */
    if (t-1 == j) {
      dout = NULL;
      dcs = dcsgold;
    } else {
      dout = &LIBXSMM_VLA_ACCESS(2, doutgold, j, 0, K * N);
      dcs = dcspgold;
    }
    /* previous cell state: external cspgold at step 0 */
    if (0 == j) {
      csp = cspgold;
    } else {
      csp = &LIBXSMM_VLA_ACCESS(2, csgold, j-1, 0, K * N);
    }
    lstm_bwd_upd_eltwise_merged( N, K,
                                 &LIBXSMM_VLA_ACCESS(3, icfogold, j, 0, 0, N, 4 * K),
                                 &LIBXSMM_VLA_ACCESS(3, icfogold, j, 0, K, N, 4 * K),
                                 &LIBXSMM_VLA_ACCESS(3, icfogold, j, 0, 2*K, N, 4 * K),
                                 &LIBXSMM_VLA_ACCESS(3, icfogold, j, 0, 3*K, N, 4 * K),
                                 csp,
                                 &LIBXSMM_VLA_ACCESS(2, cogold, j, 0, K * N),
                                 &LIBXSMM_VLA_ACCESS(2, dhgold, j, 0, K * N),
                                 dout,
                                 &LIBXSMM_VLA_ACCESS(3, dicfogold, j, 0, 0, N, 4 * K),
                                 &LIBXSMM_VLA_ACCESS(3, dicfogold, j, 0, K, N, 4 * K),
                                 &LIBXSMM_VLA_ACCESS(3, dicfogold, j, 0, 2*K, N, 4 * K),
                                 &LIBXSMM_VLA_ACCESS(3, dicfogold, j, 0, 3*K, N, 4 * K),
                                 dcspgold, dcs);
#if defined(PROFILE)
    Gbl_eltwise_end = libxsmm_timer_tick();
    Gbl_eltwise_total += libxsmm_timer_duration(Gbl_eltwise_start, Gbl_eltwise_end);
    Gbl_blas_start = libxsmm_timer_tick();
#endif
#if defined(TWO_GEMMS)
    if (j > 0) {
      /* compute dout (gradient flowing to step j-1), overwriting (beta0) */
      LIBXSMM_XBLAS_SYMBOL(float)(&transaT, &transb, &K, &N, &K4, &alpha, rgold, &K4, &LIBXSMM_VLA_ACCESS(3, dicfogold, j, 0, 0, N, 4 * K), &K4, &beta0, &LIBXSMM_VLA_ACCESS(2, doutgold, j-1, 0, K * N), &K);
    } else {
      /* compute dhp (gradient w.r.t. the initial hidden state) */
      LIBXSMM_XBLAS_SYMBOL(float)(&transaT, &transb, &K, &N, &K4, &alpha, rgold, &K4, &LIBXSMM_VLA_ACCESS(3, dicfogold, 0, 0, 0, N, 4 * K), &K4, &beta0, dhpgold, &K);
    }

    /* compute dx (accumulating, beta = 1) */
    LIBXSMM_XBLAS_SYMBOL(float)(&transaT, &transb, &C, &N, &K4, &alpha, wgold, &K4, &LIBXSMM_VLA_ACCESS(3, dicfogold, j, 0, 0, N, 4 * K), &K4, &beta, &LIBXSMM_VLA_ACCESS(2, dxgold, j, 0, N * C), &C);

    /* compute dw (accumulated over all time steps) */
    LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transbT, &K4, &C, &N, &alpha, &LIBXSMM_VLA_ACCESS(3, dicfogold, j, 0, 0, N, 4 * K), &K4, &LIBXSMM_VLA_ACCESS(2, xgold, j, 0, N * C), &C, &beta, dwgold, &K4);

    /* compute dr (accumulated; h_{t-1} is hpgold at the first step) */
    if (j == 0) {
      LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transbT, &K4, &K, &N, &alpha, &LIBXSMM_VLA_ACCESS(3, dicfogold, j, 0, 0, N, 4 * K), &K4, hpgold, &K, &beta, drgold, &K4);
    } else {
      LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transbT, &K4, &K, &N, &alpha, &LIBXSMM_VLA_ACCESS(3, dicfogold, j, 0, 0, N, 4 * K), &K4, &LIBXSMM_VLA_ACCESS(2, hgold, j-1, 0, K * N), &K, &beta, drgold, &K4);
    }
#else
    /* single stacked weight [W; R]: one GEMM yields [dx; dout] in dxhgold */
    LIBXSMM_UNUSED(rgold); LIBXSMM_UNUSED(drgold);
    LIBXSMM_XBLAS_SYMBOL(float)(&transaT, &transb, &CK, &N, &K4, &alpha, wgold, &K4, &LIBXSMM_VLA_ACCESS(3, dicfogold, j, 0, 0, N, 4 * K), &K4, &beta0, dxhgold, &CK);
    matrix_copy_ld(C, N, C+K, dxhgold, &LIBXSMM_VLA_ACCESS(2, dxgold, j, 0, N * C));
    if (j > 0) {
      matrix_copy_ld(K, N, C+K, &(dxhgold[C]), &LIBXSMM_VLA_ACCESS(2, doutgold, j-1, 0, K * N));
    } else {
      matrix_copy_ld(K, N, C+K, &(dxhgold[C]), dhpgold);
    }

    /* Concatenate x and h */
    convert_nk_nck(N, C, C+K, &LIBXSMM_VLA_ACCESS(2, xgold, j, 0, N * C), xhgold);
    if (j == 0) {
      convert_nk_nck(N, K, C+K, hpgold, &(xhgold[C]));
    } else {
      convert_nk_nck(N, K, C+K, &LIBXSMM_VLA_ACCESS(2, hgold, j-1, 0, K * N), &(xhgold[C]));
    }
    LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transbT, &K4, &CK, &N, &alpha, &LIBXSMM_VLA_ACCESS(3, dicfogold, j, 0, 0, N, 4 * K), &K4, xhgold, &CK, &beta, dwgold, &K4);
#endif
#if defined(PROFILE)
    Gbl_blas_end = libxsmm_timer_tick();
    Gbl_blas_total += libxsmm_timer_duration(Gbl_blas_start, Gbl_blas_end);
#endif
    /* compute db (parallel over l: each thread owns distinct bias entries) */
#if defined(_OPENMP)
    LIBXSMM_OMP_VAR(p);
#   pragma omp parallel for private(l, p)
#endif
    for (l = 0; l < K; l++) {
      for (p = 0; p < N; p++) {
        dbgold[l] += LIBXSMM_VLA_ACCESS(3, dicfogold, j, p, l, N, 4 * K);
        dbgold[l + K] += LIBXSMM_VLA_ACCESS(3, dicfogold, j, p, l + K, N, 4 * K);
        dbgold[l + 2*K] += LIBXSMM_VLA_ACCESS(3, dicfogold, j, p, l + 2*K, N, 4 * K);
        dbgold[l + 3*K] += LIBXSMM_VLA_ACCESS(3, dicfogold, j, p, l + 3*K, N, 4 * K);
      }
    }
  }
}
2966
gru_ref_fwd(int N,int C,int K,int t,float * wi,float * wc,float * wf,float * ri,float * rc,float * rf,float * bi,float * bc,float * bf,float * xt,float * hp,float * ht,float * it,float * ct,float * ft,float * ot)2967 LIBXSMM_INLINE void gru_ref_fwd( int N, int C, int K, int t,
2968 float *wi, float *wc, float *wf,
2969 float *ri, float *rc, float *rf,
2970 float *bi, float *bc, float *bf,
2971 float *xt, float *hp, float *ht,
2972 float *it, float *ct, float *ft, float *ot )
2973 {
2974 const char transa = 'N', transb = 'N'; /* no transposes */
2975 const float alpha = 1, beta = 1;
2976 int j;
2977 LIBXSMM_VLA_DECL(2, float, x, xt, N * C);
2978 LIBXSMM_VLA_DECL(2, float, h, ht, K * N);
2979 LIBXSMM_VLA_DECL(2, float, i, it, K * N);
2980 LIBXSMM_VLA_DECL(2, float, c, ct, K * N);
2981 LIBXSMM_VLA_DECL(2, float, f, ft, K * N);
2982 LIBXSMM_VLA_DECL(2, float, o, ot, K * N);
2983 for (j = 0; j < t; ++j) {
2984 /* i_t = b_i */
2985 matrix_copy_bias(K, N, K, bi, &LIBXSMM_VLA_ACCESS(2, i, j, 0, K * N));
2986 /* i_t += W_i * x_t */
2987 LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &N, &C, &alpha, wi, &K, &LIBXSMM_VLA_ACCESS(2, x, j, 0, N * C), &C, &beta, &LIBXSMM_VLA_ACCESS(2, i, j, 0, K * N), &K);
2988 /* i_t += R_i * h_{t-1} */
2989 if (0 == j) {
2990 LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &N, &K, &alpha, ri, &K, hp, &K, &beta, &LIBXSMM_VLA_ACCESS(2, i, j, 0, K * N), &K);
2991 } else {
2992 LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &N, &K, &alpha, ri, &K, &LIBXSMM_VLA_ACCESS(2, h, j-1, 0, K * N), &K, &beta, &LIBXSMM_VLA_ACCESS(2, i, j, 0, K * N), &K);
2993 }
2994 /* i_t = sigmoid(i_t) */
2995 matrix_sigmoid(N*K, &LIBXSMM_VLA_ACCESS(2, i, j, 0, K * N), &LIBXSMM_VLA_ACCESS(2, i, j, 0, K * N));
2996 /* c_t = b_c */
2997 matrix_copy_bias(K, N, K, bc, &LIBXSMM_VLA_ACCESS(2, c, j, 0, K * N));
2998 /* c_t += W_c * x_t */
2999 LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &N, &C, &alpha, wc, &K, &LIBXSMM_VLA_ACCESS(2, x, j, 0, N * C), &C, &beta, &LIBXSMM_VLA_ACCESS(2, c, j, 0, K * N), &K);
3000 /* c_t += R_c * h_{t-1} */
3001 if (0 == j) {
3002 LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &N, &K, &alpha, rc, &K, hp, &K, &beta, &LIBXSMM_VLA_ACCESS(2, c, j, 0, K * N), &K);
3003 } else {
3004 LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &N, &K, &alpha, rc, &K, &LIBXSMM_VLA_ACCESS(2, h, j-1, 0, K * N), &K, &beta, &LIBXSMM_VLA_ACCESS(2, c, j, 0, K * N), &K);
3005 }
3006 /* c_t = sigmoid(c_t) */
3007 matrix_sigmoid(N*K, &LIBXSMM_VLA_ACCESS(2, c, j, 0, K * N), &LIBXSMM_VLA_ACCESS(2, c, j, 0, K * N));
3008 /* o_t = h_{t-1} . i_t */
3009 if (0 == j) {
3010 matrix_eltwise_mult(N*K, hp, &LIBXSMM_VLA_ACCESS(2, i, j, 0, K * N), &LIBXSMM_VLA_ACCESS(2, o, j, 0, K * N));
3011 } else {
3012 matrix_eltwise_mult(N*K, &LIBXSMM_VLA_ACCESS(2, h, j-1, 0, K * N), &LIBXSMM_VLA_ACCESS(2, i, j, 0, K * N), &LIBXSMM_VLA_ACCESS(2, o, j, 0, K * N));
3013 }
3014 /* f_t = b_f */
3015 matrix_copy_bias(K, N, K, bf, &LIBXSMM_VLA_ACCESS(2, f, j, 0, K * N));
3016 /* f_t += W_f * x_t */
3017 LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &N, &C, &alpha, wf, &K, &LIBXSMM_VLA_ACCESS(2, x, j, 0, N * C), &C, &beta, &LIBXSMM_VLA_ACCESS(2, f, j, 0, K * N), &K);
3018 /* f_t += R_f * o_t */
3019 LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &N, &K, &alpha, rf, &K, &LIBXSMM_VLA_ACCESS(2, o, j, 0, K * N), &K, &beta, &LIBXSMM_VLA_ACCESS(2, f, j, 0, K * N), &K);
3020 /* f_t = tanh(f_t) */
3021 matrix_tanh(N*K, &LIBXSMM_VLA_ACCESS(2, f, j, 0, K * N), &LIBXSMM_VLA_ACCESS(2, f, j, 0, K * N));
3022 /* h_t = (1 - c_t) . f_t */
3023 matrix_complement (N*K, &LIBXSMM_VLA_ACCESS(2, c, j, 0, K * N), &LIBXSMM_VLA_ACCESS(2, h, j, 0, K * N));
3024 matrix_eltwise_mult(N*K, &LIBXSMM_VLA_ACCESS(2, h, j, 0, K * N), &LIBXSMM_VLA_ACCESS(2, f, j, 0, K * N), &LIBXSMM_VLA_ACCESS(2, h, j, 0, K * N));
3025 /* h_t += c_t . h_{t-1} */
3026 if (0 == j) {
3027 matrix_eltwise_fma(N*K, &LIBXSMM_VLA_ACCESS(2, c, j, 0, K * N), hp, &LIBXSMM_VLA_ACCESS(2, h, j, 0, K * N));
3028 } else {
3029 matrix_eltwise_fma(N*K, &LIBXSMM_VLA_ACCESS(2, c, j, 0, K * N), &LIBXSMM_VLA_ACCESS(2, h, j-1, 0, K * N), &LIBXSMM_VLA_ACCESS(2, h, j, 0, K * N));
3030 }
3031 }
3032 }
3033
/* Reference (BLAS-based) GRU backward + weight-update pass, iterating the time
 * steps in reverse.  Inputs are the forward activations it/ct/ft/ot, the
 * hidden states ht (initial state hpD), the weights, and the output gradients
 * dht.  Outputs: dxt (input gradients), dw/dr/db (weight/recurrent/bias
 * gradients, accumulated with beta resp. +=), and dhpD (gradient w.r.t. the
 * initial hidden state).  The gate naming matches gru_ref_fwd: i = update
 * gate, c = reset-like gate, f = candidate state, o = h_{t-1} . i_t.
 * Scratch layout (6 consecutive N*K buffers): delta, dout, di, dc, df, do. */
LIBXSMM_INLINE void gru_ref_bwd_upd( int N, int C, int K, int t,
                                     float *xt, float *hpD, float *ht,
                                     float *it, float *ct, float *ft, float *ot,
                                     float *wi, float *wc, float *wf,
                                     float *ri, float *rc, float *rf,
                                     float *dht, float *dw, float *dr, float *db,
                                     float *dxt, float *dhpD, float *scratch )
{
  const char transa = 'N', transb = 'N'; /* no transposes */
  const char transaT = 'T', transbT = 'T'; /* transposes */
  const float alpha = 1, beta = 1, beta0 = 0;
  int j, l, p;
  /* dw/dr/db each pack the three per-gate gradients back-to-back */
  float *dwi = dw;
  float *dwc = &(dw[C*K]);
  float *dwf = &(dw[2*C*K]);
  float *dri = dr;
  float *drc = &(dr[K*K]);
  float *drf = &(dr[2*K*K]);
  float *dbi = db;
  float *dbc = &(db[K]);
  float *dbf = &(db[2*K]);
  float *deltaD = scratch;
  float *doutD = &(scratch[N*K]);
  float *diD = &(scratch[2*N*K]);
  float *dcD = &(scratch[3*N*K]);
  float *dfD = &(scratch[4*N*K]);
  float *doD = &(scratch[5*N*K]);
  LIBXSMM_VLA_DECL(3, float, x, xt, N, C);
  LIBXSMM_VLA_DECL(2, float, hp, hpD, K);
  LIBXSMM_VLA_DECL(3, float, h, ht, N, K);
  LIBXSMM_VLA_DECL(3, float, i, it, N, K);
  LIBXSMM_VLA_DECL(3, float, c, ct, N, K);
  LIBXSMM_VLA_DECL(3, float, f, ft, N, K);
  LIBXSMM_VLA_DECL(3, float, o, ot, N, K);
  LIBXSMM_VLA_DECL(3, float, dx, dxt, N, C);
  LIBXSMM_VLA_DECL(2, float, dhp, dhpD, K);
  LIBXSMM_VLA_DECL(3, float, dh, dht, N, K);
  LIBXSMM_VLA_DECL(2, float, di, diD, K);
  LIBXSMM_VLA_DECL(2, float, dc, dcD, K);
  LIBXSMM_VLA_DECL(2, float, df, dfD, K);
  LIBXSMM_VLA_DECL(2, float, dp, doD, K); /* dp == "do" (reserved word) */
  LIBXSMM_VLA_DECL(2, float, dout, doutD, K);
  LIBXSMM_VLA_DECL(2, float, delta, deltaD, K);
  for (j = t-1; j >= 0; j--) {
    /* delta = dh (+ dout carried from step j+1); then df and dc element-wise */
#if defined(_OPENMP)
    LIBXSMM_OMP_VAR(p);
#   pragma omp parallel for private(l, p) LIBXSMM_OPENMP_COLLAPSE(2)
#endif
    for (l = 0; l < N; l++) {
      for (p = 0; p < K; p++) {
        if (t-1 == j) {
          LIBXSMM_VLA_ACCESS(2, delta, l, p, K) = LIBXSMM_VLA_ACCESS(3, dh, t-1, l, p, N, K);
        } else {
          LIBXSMM_VLA_ACCESS(2, delta, l, p, K) = LIBXSMM_VLA_ACCESS(3, dh, j, l, p, N, K) + LIBXSMM_VLA_ACCESS(2, dout, l, p, K);
        }
        /* df = delta . (1 - c_t) . (1 - (f_t . f_t)) */
        LIBXSMM_VLA_ACCESS(2, df, l, p, K) = LIBXSMM_VLA_ACCESS(2, delta, l, p, K) * (1.0f - LIBXSMM_VLA_ACCESS(3, c, j, l, p, N, K)) * (1.0f - (LIBXSMM_VLA_ACCESS(3, f, j, l, p, N, K) * LIBXSMM_VLA_ACCESS(3, f, j, l, p, N, K)));
        /* dc = delta . (h_{t-1} - f_t) . c_t . (1 - c_t) */
        if (0 == j) {
          LIBXSMM_VLA_ACCESS(2, dc, l, p, K) = LIBXSMM_VLA_ACCESS(2, delta, l, p, K) * (LIBXSMM_VLA_ACCESS(2, hp, l, p, K) - LIBXSMM_VLA_ACCESS(3, f, j, l, p, N, K)) * LIBXSMM_VLA_ACCESS(3, c, j, l, p, N, K) * (1.0f - LIBXSMM_VLA_ACCESS(3, c, j, l, p, N, K));
        } else {
          LIBXSMM_VLA_ACCESS(2, dc, l, p, K) = LIBXSMM_VLA_ACCESS(2, delta, l, p, K) * (LIBXSMM_VLA_ACCESS(3, h, j-1, l, p, N, K) - LIBXSMM_VLA_ACCESS(3, f, j, l, p, N, K)) * LIBXSMM_VLA_ACCESS(3, c, j, l, p, N, K) * (1.0f - LIBXSMM_VLA_ACCESS(3, c, j, l, p, N, K));
        }
      }
    }
    /* do = {R_f}^T * df (overwriting, beta0) */
    LIBXSMM_XBLAS_SYMBOL(float)(&transaT, &transb, &K, &N, &K, &alpha, rf, &K, dfD, &K, &beta0, doD, &K);
    /* di = do . h_{t-1} . i_t . (1 - i_t) */
    if (0 == j) {
#if defined(_OPENMP)
#     pragma omp parallel for private(l, p) LIBXSMM_OPENMP_COLLAPSE(2)
#endif
      for (l = 0; l < N; l++) {
        for (p = 0; p < K; p++) {
          LIBXSMM_VLA_ACCESS(2, di, l, p, K) = LIBXSMM_VLA_ACCESS(2, dp, l, p, K) * LIBXSMM_VLA_ACCESS(2, hp, l, p, K) * LIBXSMM_VLA_ACCESS(3, i, 0, l, p, N, K) * (1.0f - LIBXSMM_VLA_ACCESS(3, i, 0, l, p, N, K));
        }
      }
    } else {
#if defined(_OPENMP)
#     pragma omp parallel for private(l, p) LIBXSMM_OPENMP_COLLAPSE(2)
#endif
      for (l = 0; l < N; l++) {
        for (p = 0; p < K; p++) {
          LIBXSMM_VLA_ACCESS(2, di, l, p, K) = LIBXSMM_VLA_ACCESS(2, dp, l, p, K) * LIBXSMM_VLA_ACCESS(3, h, j-1, l, p, N, K) * LIBXSMM_VLA_ACCESS(3, i, j, l, p, N, K) * (1.0f - LIBXSMM_VLA_ACCESS(3, i, j, l, p, N, K));
        }
      }
    }
    /* dx_t = {W_i}^T * di */
    LIBXSMM_XBLAS_SYMBOL(float)(&transaT, &transb, &C, &N, &K, &alpha, wi, &K, diD, &K, &beta0, &LIBXSMM_VLA_ACCESS(3, dx, j, 0, 0, N, C), &C);
    /* dx_t += {W_c}^T * dc */
    LIBXSMM_XBLAS_SYMBOL(float)(&transaT, &transb, &C, &N, &K, &alpha, wc, &K, dcD, &K, &beta, &LIBXSMM_VLA_ACCESS(3, dx, j, 0, 0, N, C), &C);
    /* dx_t += {W_f}^T * df */
    LIBXSMM_XBLAS_SYMBOL(float)(&transaT, &transb, &C, &N, &K, &alpha, wf, &K, dfD, &K, &beta, &LIBXSMM_VLA_ACCESS(3, dx, j, 0, 0, N, C), &C);
    /* dh_{t-1} = {R_i}^T * di */
    /* dh_{t-1} += {R_c}^T * dc */
    /* (target is dhp at step 0, otherwise the dout buffer for step j-1) */
    if (0 == j) {
      LIBXSMM_XBLAS_SYMBOL(float)(&transaT, &transb, &K, &N, &K, &alpha, ri, &K, diD, &K, &beta0, dhpD, &K);
      LIBXSMM_XBLAS_SYMBOL(float)(&transaT, &transb, &K, &N, &K, &alpha, rc, &K, dcD, &K, &beta, dhpD, &K);
    } else {
      LIBXSMM_XBLAS_SYMBOL(float)(&transaT, &transb, &K, &N, &K, &alpha, ri, &K, diD, &K, &beta0, doutD, &K);
      LIBXSMM_XBLAS_SYMBOL(float)(&transaT, &transb, &K, &N, &K, &alpha, rc, &K, dcD, &K, &beta, doutD, &K);
    }
    /* dh_{t-1} += do * i_t + delta * c_t */
    if (0 == j) {
#if defined(_OPENMP)
#     pragma omp parallel for private(l, p) LIBXSMM_OPENMP_COLLAPSE(2)
#endif
      for (l = 0; l < N; l++) {
        for (p = 0; p < K; p++) {
          LIBXSMM_VLA_ACCESS(2, dhp, l, p, K) += LIBXSMM_VLA_ACCESS(2, dp, l, p, K) * LIBXSMM_VLA_ACCESS(3, i, j, l, p, N, K) + LIBXSMM_VLA_ACCESS(2, delta, l, p, K) * LIBXSMM_VLA_ACCESS(3, c, j, l, p, N, K);
        }
      }
    } else {
#if defined(_OPENMP)
#     pragma omp parallel for private(l, p) LIBXSMM_OPENMP_COLLAPSE(2)
#endif
      for (l = 0; l < N; l++) {
        for (p = 0; p < K; p++) {
          LIBXSMM_VLA_ACCESS(2, dout, l, p, K) += LIBXSMM_VLA_ACCESS(2, dp, l, p, K) * LIBXSMM_VLA_ACCESS(3, i, j, l, p, N, K) + LIBXSMM_VLA_ACCESS(2, delta, l, p, K) * LIBXSMM_VLA_ACCESS(3, c, j, l, p, N, K);
        }
      }
    }
    /* dw_i += di * {x_t}^T */
    /* dw_c += dc * {x_t}^T */
    /* dw_f += df * {x_t}^T */
    LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transbT, &K, &C, &N, &alpha, diD, &K, &LIBXSMM_VLA_ACCESS(3, x, j, 0, 0, N, C), &C, &beta, dwi, &K);
    LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transbT, &K, &C, &N, &alpha, dcD, &K, &LIBXSMM_VLA_ACCESS(3, x, j, 0, 0, N, C), &C, &beta, dwc, &K);
    LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transbT, &K, &C, &N, &alpha, dfD, &K, &LIBXSMM_VLA_ACCESS(3, x, j, 0, 0, N, C), &C, &beta, dwf, &K);
    /* dr_i += di * {o_t}^T */
    /* dr_c += dc * {o_t}^T */
    /* dr_f += df * {h_{t-1}}^T (R_f multiplies o_t in fwd, but its update uses h_{t-1}?
     * NOTE(review): dr_i/dr_c use o_t as fwd input of R_i/R_c was h_{t-1} —
     * this mirrors the original reference; confirm against the fwd pass. */
    LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transbT, &K, &K, &N, &alpha, diD, &K, &LIBXSMM_VLA_ACCESS(3, o, j, 0, 0, N, K), &K, &beta, dri, &K);
    LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transbT, &K, &K, &N, &alpha, dcD, &K, &LIBXSMM_VLA_ACCESS(3, o, j, 0, 0, N, K), &K, &beta, drc, &K);
    if (0 == j) {
      LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transbT, &K, &K, &N, &alpha, dfD, &K, &LIBXSMM_VLA_ACCESS(2, hp, 0, 0, K), &K, &beta, drf, &K);
    } else {
      LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transbT, &K, &K, &N, &alpha, dfD, &K, &LIBXSMM_VLA_ACCESS(3, h, j-1, 0, 0, N, K), &K, &beta, drf, &K);
    }
    /* compute db (parallel over l: each thread owns distinct bias entries) */
#if defined(_OPENMP)
#   pragma omp parallel for private(l, p)
#endif
    for (l = 0; l < K; l++) {
      for (p = 0; p < N; p++) {
        dbi[l] += LIBXSMM_VLA_ACCESS(2, di, p, l, K);
        dbc[l] += LIBXSMM_VLA_ACCESS(2, dc, p, l, K);
        dbf[l] += LIBXSMM_VLA_ACCESS(2, df, p, l, K);
      }
    }
  }
}
3185
3186