1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Alexander Heinecke (Intel Corp.)
10 ******************************************************************************/
11 
#include <libxsmm.h>
#include <libxsmm_intrinsics_x86.h>
#include <string.h>

#if defined(_OPENMP)
# include <omp.h>
#endif
18 
/* Shape/stride descriptor consumed by the naive convolution reference code.
 * NOTE(review): field semantics inferred from the naming convention of the
 * naive_* helpers (nImg = minibatch, nIfm/nOfm = input/output feature maps,
 * ifh/ifw vs. ifhp/ifwp = input extents without/with physical padding,
 * likewise ofh/ofw vs. ofhp/ofwp for the output, kh/kw = kernel extents);
 * confirm against the code that fills this struct. */
typedef struct {
  int nImg;
  int nIfm;
  int nOfm;
  int ifhp;
  int ifwp;
  int ifh;
  int ifw;
  int ofhp;
  int ofwp;
  int ofh;
  int ofw;
  int pad_h;      /* logical padding, height */
  int pad_w;      /* logical padding, width */
  int pad_h_in;   /* physical input padding, height */
  int pad_w_in;   /* physical input padding, width */
  int pad_h_out;  /* physical output padding, height */
  int pad_w_out;  /* physical output padding, width */
  int kh;
  int kw;
  int stride_h;
  int stride_w;
} naive_conv_t;
42 
/* Descriptor for the naive fused batch-normalization reference code
 * (N = minibatch, C = channels, H/W = spatial extents). */
typedef struct {
  int N;
  int C;
  int H;
  int W;
  int stride_h;
  int stride_w;
  int norm_type;  /* 0: full batchnorm, 1: batch scaling only */
  int fuse_type;  /* 0: nothing fused, 1: relu fused, 2: elementwise fused, 3: relu and elementwise fused */
} naive_fusedbatchnorm_t;
53 
/* Descriptor for the naive fused group-normalization reference code
 * (G = number of groups the C channels are split into). */
typedef struct {
  int N;
  int C;
  int G;
  int H;
  int W;
  int stride_h;
  int stride_w;
  int fuse_type;  /* 0: nothing fused, 1: relu fused, 2: elementwise fused, 3: relu and elementwise fused */
} naive_fusedgroupnorm_t;
64 
/* Descriptor for the naive fully-connected reference code
 * (N = minibatch, C = input features, K = output features). */
typedef struct {
  int N;
  int C;
  int K;
  int fuse_type;  /* 0: nothing fused */
} naive_fullyconnected_t;
71 
/* Descriptor for the naive pooling reference code.
 * NOTE(review): R/S presumably are the pooling-window extents and `type`
 * selects max vs. average pooling -- confirm against the pooling kernels. */
typedef struct {
  int N;
  int C;
  int H;
  int W;
  int R;
  int S;
  int pad_h;
  int pad_w;
  int stride_h;
  int stride_w;
  int type;
} naive_pooling_t;
85 
86 /* it's fine to alias in and out */
truncate_mask_fp32_bf16(float * in,float * out,unsigned int len)87 LIBXSMM_INLINE void truncate_mask_fp32_bf16(float* in, float* out, unsigned int len) {
88   unsigned int i = 0;
89 
90   /* truncate buffer to bf16 */
91   for ( i = 0; i < len; ++i ) {
92     union libxsmm_bfloat16_hp t;
93 
94     t.f = in[i];
95     t.i[0] = 0;
96     out[i] = t.f;
97   }
98 }
99 
100 /* it's fine to alias in and out */
rnaz_mask_fp32_bf16(float * in,float * out,unsigned int len)101 LIBXSMM_INLINE void rnaz_mask_fp32_bf16(float* in, float* out, unsigned int len) {
102   unsigned int i = 0;
103 
104   /* rnaz buffer to bf16 */
105   for ( i = 0; i < len; ++i ) {
106     unsigned int int_round = 0;
107     unsigned int do_round = 1;
108     const void *const ptr = &int_round;
109 
110     int_round = *((unsigned int*)&(in[i]));
111 
112     /* we don't round NaN and inf */
113     if ( (int_round & 0x7f800000) == 0x7f800000 ) {
114       do_round = 0;
115     }
116 
117     /* perform round nearest tie away from zero */
118     if ( do_round != 0 ) {
119       int_round = int_round + 0x00008000;
120     }
121 
122     /* chop bits to create BFP16 in FP32 */
123     int_round = int_round & 0xffff0000;
124 
125     out[i] = *((float*)ptr);
126   }
127 }
128 
129 /* it's fine to alias in and out */
rne_mask_fp32_bf16(float * in,float * out,unsigned int len)130 LIBXSMM_INLINE void rne_mask_fp32_bf16(float* in, float* out, unsigned int len) {
131   unsigned int i = 0;
132 
133   /* rnaz buffer to bf16 */
134   for ( i = 0; i < len; ++i ) {
135     unsigned int int_round = 0;
136     unsigned int do_round = 1;
137     const void *const ptr = &int_round;
138 
139     int_round = *((unsigned int*)&(in[i]));
140 
141     /* we don't round NaN and inf */
142     if ( (int_round & 0x7f800000) == 0x7f800000 ) {
143       do_round = 0;
144     }
145 
146     /* perform round nearest tie even */
147     if ( do_round != 0 ) {
148       unsigned int fixup = (int_round >> 16) & 1;
149       int_round = int_round + 0x00007fff + fixup;
150     }
151 
152     /* chop bits to create BFP16 in FP32 */
153     int_round = int_round & 0xffff0000;
154 
155     out[i] = *((float*)ptr);
156   }
157 }
158 
zero_buf(float * buf,size_t size)159 LIBXSMM_INLINE void zero_buf(float* buf, size_t size) {
160   int i;
161 #if defined(_OPENMP)
162   LIBXSMM_OMP_VAR(i);
163 # pragma omp parallel for private(i)
164 #endif
165   for (i = 0; i < (int)size; ++i) {
166     buf[i] = 0.0f;
167   }
168 }
169 
zero_buf_bf16(libxsmm_bfloat16 * buf,size_t size)170 LIBXSMM_INLINE void zero_buf_bf16(libxsmm_bfloat16* buf, size_t size) {
171   int i;
172 #if defined(_OPENMP)
173 # pragma omp parallel for private(i)
174 #endif
175   for (i = 0; i < (int)size; ++i) {
176     buf[i] = 0;
177   }
178 }
179 
zero_buf_int16(short * buf,size_t size)180 LIBXSMM_INLINE void zero_buf_int16(short* buf, size_t size) {
181   int i;
182 #if defined(_OPENMP)
183   LIBXSMM_OMP_VAR(i);
184 # pragma omp parallel for private(i)
185 #endif
186   for (i = 0; i < (int)size; ++i) {
187     buf[i] = 0;
188   }
189 }
190 
zero_buf_int32(int * buf,size_t size)191 LIBXSMM_INLINE void zero_buf_int32(int* buf, size_t size) {
192   int i;
193 #if defined(_OPENMP)
194   LIBXSMM_OMP_VAR(i);
195 # pragma omp parallel for private(i)
196 #endif
197   for (i = 0; i < (int)size; ++i) {
198     buf[i] = 0;
199   }
200 }
201 
zero_buf_int8(char * buf,size_t size)202 LIBXSMM_INLINE void zero_buf_int8(char* buf, size_t size) {
203   int i;
204 #if defined(_OPENMP)
205   LIBXSMM_OMP_VAR(i);
206 # pragma omp parallel for private(i)
207 #endif
208   for (i = 0; i < (int)size; ++i) {
209     buf[i] = 0;
210   }
211 }
212 
zero_buf_uint8(unsigned char * buf,size_t size)213 LIBXSMM_INLINE void zero_buf_uint8(unsigned char* buf, size_t size) {
214   int i;
215 #if defined(_OPENMP)
216   LIBXSMM_OMP_VAR(i);
217 # pragma omp parallel for private(i)
218 #endif
219   for (i = 0; i < (int)size; ++i) {
220     buf[i] = 0;
221   }
222 }
223 
copy_buf(float * src,float * dst,size_t size)224 LIBXSMM_INLINE void copy_buf(float* src, float* dst, size_t size) {
225   int i;
226 #if defined(_OPENMP)
227   LIBXSMM_OMP_VAR(i);
228 # pragma omp parallel for private(i)
229 #endif
230   for (i = 0; i < (int)size; ++i) {
231     dst[i] = src[i];
232   }
233 }
234 
copy_buf_int16(short * src,short * dst,size_t size)235 LIBXSMM_INLINE void copy_buf_int16(short* src, short* dst, size_t size) {
236   int i;
237 #if defined(_OPENMP)
238   LIBXSMM_OMP_VAR(i);
239 # pragma omp parallel for private(i)
240 #endif
241   for (i = 0; i < (int)size; ++i) {
242     dst[i] = src[i];
243   }
244 }
245 
copy_buf_int8(char * src,char * dst,size_t size)246 LIBXSMM_INLINE void copy_buf_int8(char* src, char* dst, size_t size) {
247   int i;
248 #if defined(_OPENMP)
249   LIBXSMM_OMP_VAR(i);
250 # pragma omp parallel for private(i)
251 #endif
252   for (i = 0; i < (int)size; ++i) {
253     dst[i] = src[i];
254   }
255 }
256 
copy_buf_uint8(unsigned char * src,unsigned char * dst,size_t size)257 LIBXSMM_INLINE void copy_buf_uint8(unsigned char* src, unsigned char* dst, size_t size) {
258   int i;
259 #if defined(_OPENMP)
260   LIBXSMM_OMP_VAR(i);
261 # pragma omp parallel for private(i)
262 #endif
263   for (i = 0; i < (int)size; ++i) {
264     dst[i] = src[i];
265   }
266 }
267 
init_buf(float * buf,size_t size,int initPos,int initOne)268 LIBXSMM_INLINE void init_buf(float* buf, size_t size, int initPos, int initOne)
269 {
270   int i;
271   zero_buf(buf, size);
272 #if defined(_OPENMP)
273 # pragma omp parallel for private(i)
274 #endif
275   for (i = 0; i < (int)size; ++i) {
276     buf[i] = (float)((initOne != 0) ? 1.0 : ((initPos != 0) ? libxsmm_rng_f64() : (0.05 - libxsmm_rng_f64()/10.0)));
277   }
278 }
279 
init_buf_bf16(libxsmm_bfloat16 * buf,size_t size,int initPos,int initOne)280 LIBXSMM_INLINE void init_buf_bf16(libxsmm_bfloat16* buf, size_t size, int initPos, int initOne)
281 {
282   int i;
283   zero_buf_bf16(buf, size);
284 #if defined(_OPENMP)
285 # pragma omp parallel for private(i)
286 #endif
287   for (i = 0; i < (int)size; ++i) {
288     libxsmm_bfloat16_hp tmp;
289     tmp.f = (float)((initOne != 0) ? 1.0 : ((initPos != 0) ? libxsmm_rng_f64() : (0.05 - libxsmm_rng_f64()/10.0)));
290     buf[i] = tmp.i[1];
291   }
292 }
293 
libxsmm_dnn_dequantize_int8(char * in_buffer,float * out_buffer,int length,unsigned char scf)294 LIBXSMM_INLINE void libxsmm_dnn_dequantize_int8( char* in_buffer, float* out_buffer, int length, unsigned char scf ) {
295   const float val_exp = libxsmm_sexp2_i8i(-scf);
296   int i = 0;
297 #ifdef _OPENMP
298 # pragma omp parallel for private(i)
299 #endif
300   for ( i = 0; i < length; ++i ) {
301     out_buffer[i] = ((float)in_buffer[i])*val_exp;
302   }
303 }
304 
libxsmm_internal_get_max_common(float * in_buffer,int length)305 LIBXSMM_INLINE float libxsmm_internal_get_max_common( float* in_buffer, int length ) {
306   float absmax_value = LIBXSMM_ABS(in_buffer[0]);
307   int i = 0;
308   for (i = 1; i < length; ++i ) {
309     if (LIBXSMM_ABS(in_buffer[i]) > absmax_value) {
310       absmax_value = LIBXSMM_ABS(in_buffer[i]);
311     }
312   }
313   return absmax_value;
314 }
315 
quantize_buffer_char(float * in_buffer,char * out_buffer,int size,unsigned char add_shift,unsigned char * scf)316 LIBXSMM_INLINE void quantize_buffer_char(float *in_buffer, char *out_buffer, int size, unsigned char add_shift, unsigned char* scf) {
317   int i;
318   const float max_value = libxsmm_internal_get_max_common(in_buffer, size);
319   int maxexp = 0;
320   /* take return value of LIBXSMM_FREXPF to mute static analysis issue */
321   float scfq = LIBXSMM_FREXPF(max_value, &maxexp);
322   maxexp -= (7 - add_shift);
323   scfq = libxsmm_sexp2_i8i(-maxexp);
324   for (i=0; i<size; i++) {
325     out_buffer[i] = (char)LIBXSMM_ROUNDF(in_buffer[i]*scfq);
326   }
327   *scf = (unsigned char)(-maxexp);
328 }
329 
quantize_buffer_uchar(float * in_buffer,unsigned char * out_buffer,int size,unsigned char add_shift,unsigned char * scf)330 LIBXSMM_INLINE void quantize_buffer_uchar(float *in_buffer, unsigned char *out_buffer, int size, unsigned char add_shift, unsigned char* scf) {
331   int i;
332   const float max_value = libxsmm_internal_get_max_common(in_buffer, size);
333   int maxexp = 0;
334   /* take return value of LIBXSMM_FREXPF to mute static analysis issue */
335   float scfq = LIBXSMM_FREXPF(max_value, &maxexp);
336   maxexp -= (7 - add_shift);
337   scfq = libxsmm_sexp2_i8i(-maxexp);
338   for (i=0; i<size; i++) {
339     out_buffer[i] = (unsigned char)LIBXSMM_ROUNDF(in_buffer[i]*scfq);
340   }
341   *scf = (unsigned char)(-maxexp);
342 }
343 
init_buf_range(float * buf,size_t size,float low,float high)344 LIBXSMM_INLINE void init_buf_range(float* buf, size_t size, float low, float high)
345 {
346   int i;
347   float range = high - low;
348   zero_buf(buf, size);
349   for (i = 0; i < (int)size; ++i) {
350     buf[i] = (((float)rand())/RAND_MAX)*range+low;
351   }
352 }
353 
init_buf_int16(short * buf,size_t size,int initPos,int initOne)354 LIBXSMM_INLINE void init_buf_int16(short* buf, size_t size, int initPos, int initOne)
355 {
356   int i;
357   zero_buf_int16(buf, size);
358 #if defined(_OPENMP)
359 # pragma omp parallel for private(i)
360 #endif
361   for (i = 0; i < (int)size; ++i) {
362     buf[i] = (short)((initOne != 0) ? 1 : ((initPos != 0) ? (rand()%7) : (rand()%7-3)));
363   }
364 }
365 
init_buf_int32(int * buf,size_t size,int initPos,int initOne)366 LIBXSMM_INLINE void init_buf_int32(int* buf, size_t size, int initPos, int initOne)
367 {
368   int i;
369   zero_buf_int32(buf, size);
370 #if defined(_OPENMP)
371 # pragma omp parallel for private(i)
372 #endif
373   for (i = 0; i < (int)size; ++i) {
374     buf[i] = (int)((initOne != 0) ? 1 : ((initPos != 0) ? (rand()%7) : (rand()%7-3)));
375   }
376 }
377 
init_buf_int8(char * buf,size_t size,int initPos,int initOne)378 LIBXSMM_INLINE void init_buf_int8(char* buf, size_t size, int initPos, int initOne)
379 {
380   int i;
381   zero_buf_int8(buf, size);
382 #if defined(_OPENMP)
383 # pragma omp parallel for private(i)
384 #endif
385   for (i = 0; i < (int)size; ++i) {
386     buf[i] = (char)((initOne != 0) ? 1 : ((initPos != 0) ? (rand()%3) : (rand()%3)-1));
387   }
388 }
389 
init_buf_uint8(unsigned char * buf,size_t size,int initPos,int initOne)390 LIBXSMM_INLINE void init_buf_uint8(unsigned char* buf, size_t size, int initPos, int initOne)
391 {
392   int i;
393   LIBXSMM_UNUSED(initPos);
394   zero_buf_uint8(buf, size);
395 #if defined(_OPENMP)
396 # pragma omp parallel for private(i)
397 #endif
398   for (i = 0; i < (int)size; ++i) {
399     buf[i] = (unsigned char)((initOne != 0) ? 1 : (rand()%3));
400   }
401 }
402 
/* Zero the pad_h/pad_w-wide border ring of every HxW plane of an NCHW fp32
 * tensor; interior elements are left untouched. */
LIBXSMM_INLINE void set_zeropad_nchw(float* nchw, int N, int C, int H, int W, int pad_h, int pad_w)
{
  LIBXSMM_VLA_DECL(4, float, input, nchw, C, H, W);
  int n, h, w, c;

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(c); LIBXSMM_OMP_VAR(h); LIBXSMM_OMP_VAR(w);
# pragma omp parallel for private(n,c,h,w)
#endif
  for ( n = 0; n < N; n++ ) {
    for ( c = 0; c < C; c++ ) {
      for ( h = 0; h < H; h++ ) {
        for ( w = 0; w < W; w++ ) {
          /* border test: first/last pad_h rows and pad_w columns */
          if (h < pad_h || h >= H-pad_h || w < pad_w || w >= W-pad_w) {
            LIBXSMM_VLA_ACCESS(4, input, n, c, h, w, C, H, W) = 0.0;
          }
        }
      }
    }
  }
}
424 
/* Zero the pad_h/pad_w-wide border ring of every HxW plane of an NCHW int16
 * tensor; interior elements are left untouched. */
LIBXSMM_INLINE void set_zeropad_nchw_int16(short* nchw, int N, int C, int H, int W, int pad_h, int pad_w)
{
  LIBXSMM_VLA_DECL(4, short, input, nchw, C, H, W);
  int n, h, w, c;

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(c); LIBXSMM_OMP_VAR(h); LIBXSMM_OMP_VAR(w);
# pragma omp parallel for private(n,c,h,w)
#endif
  for ( n = 0; n < N; n++ ) {
    for ( c = 0; c < C; c++ ) {
      for ( h = 0; h < H; h++ ) {
        for ( w = 0; w < W; w++ ) {
          if (h < pad_h || h >= H-pad_h || w < pad_w || w >= W-pad_w) {
            LIBXSMM_VLA_ACCESS(4, input, n, c, h, w, C, H, W) = 0;
          }
        }
      }
    }
  }
}
446 
/* Zero the pad_h/pad_w-wide border ring of every HxW plane of an NCHW int32
 * tensor; interior elements are left untouched. */
LIBXSMM_INLINE void set_zeropad_nchw_int32(int* nchw, int N, int C, int H, int W, int pad_h, int pad_w)
{
  LIBXSMM_VLA_DECL(4, int, input, nchw, C, H, W);
  int n, h, w, c;

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(c); LIBXSMM_OMP_VAR(h); LIBXSMM_OMP_VAR(w);
# pragma omp parallel for private(n,c,h,w)
#endif
  for ( n = 0; n < N; n++ ) {
    for ( c = 0; c < C; c++ ) {
      for ( h = 0; h < H; h++ ) {
        for ( w = 0; w < W; w++ ) {
          if (h < pad_h || h >= H-pad_h || w < pad_w || w >= W-pad_w) {
            LIBXSMM_VLA_ACCESS(4, input, n, c, h, w, C, H, W) = 0;
          }
        }
      }
    }
  }
}
468 
/* Zero the pad_h/pad_w-wide border ring of every HxW plane of an NCHW uint8
 * tensor; interior elements are left untouched. */
LIBXSMM_INLINE void set_zeropad_nchw_uint8(unsigned char* nchw, int N, int C, int H, int W, int pad_h, int pad_w)
{
  LIBXSMM_VLA_DECL(4, unsigned char, input, nchw, C, H, W);
  int n, h, w, c;

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(c); LIBXSMM_OMP_VAR(h); LIBXSMM_OMP_VAR(w);
# pragma omp parallel for private(n,c,h,w)
#endif
  for ( n = 0; n < N; n++ ) {
    for ( c = 0; c < C; c++ ) {
      for ( h = 0; h < H; h++ ) {
        for ( w = 0; w < W; w++ ) {
          if (h < pad_h || h >= H-pad_h || w < pad_w || w >= W-pad_w) {
            LIBXSMM_VLA_ACCESS(4, input, n, c, h, w, C, H, W) = 0;
          }
        }
      }
    }
  }
}
490 
/* Copy an NCHW fp32 tensor of extent HxW into the interior of a larger
 * (H+2*pad_h)x(W+2*pad_w) destination; the padding border of dst is NOT
 * written by this function (callers are expected to zero it separately). */
LIBXSMM_INLINE void copy_internal_nchw(float* dst , float* src, int N, int C, int H, int W, int pad_h, int pad_w)
{
  LIBXSMM_VLA_DECL(4, float, input, src, C, H, W);
  LIBXSMM_VLA_DECL(4, float, new_input, dst, C, H+2*pad_h, W+2*pad_w);
  int n, h, w, c;

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(c); LIBXSMM_OMP_VAR(h); LIBXSMM_OMP_VAR(w);
# pragma omp parallel for private(n,c,h,w)
#endif
  for ( n = 0; n < N; n++ ) {
    for ( c = 0; c < C; c++ ) {
      for ( h = 0; h < H; h++ ) {
        for ( w = 0; w < W; w++ ) {
          LIBXSMM_VLA_ACCESS(4, new_input, n, c, h+pad_h, w+pad_w, C, H+2*pad_h, W+2*pad_w) =  LIBXSMM_VLA_ACCESS(4, input, n, c, h, w, C, H, W);
        }
      }
    }
  }
}
511 
/* Copy an NCHW int16 tensor of extent HxW into the interior of a larger
 * (H+2*pad_h)x(W+2*pad_w) destination; the padding border of dst is NOT
 * written by this function. */
LIBXSMM_INLINE void copy_internal_nchw_int16(short* dst , short* src, int N, int C, int H, int W, int pad_h, int pad_w)
{
  LIBXSMM_VLA_DECL(4, short, input, src, C, H, W);
  LIBXSMM_VLA_DECL(4, short, new_input, dst, C, H+2*pad_h, W+2*pad_w);
  int n, h, w, c;

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(c); LIBXSMM_OMP_VAR(h); LIBXSMM_OMP_VAR(w);
# pragma omp parallel for private(n,c,h,w)
#endif
  for ( n = 0; n < N; n++ ) {
    for ( c = 0; c < C; c++ ) {
      for ( h = 0; h < H; h++ ) {
        for ( w = 0; w < W; w++ ) {
          LIBXSMM_VLA_ACCESS(4, new_input, n, c, h+pad_h, w+pad_w, C, H+2*pad_h, W+2*pad_w) =  LIBXSMM_VLA_ACCESS(4, input, n, c, h, w, C, H, W);
        }
      }
    }
  }
}
532 
/* Copy an NCHW uint8 tensor of extent HxW into the interior of a larger
 * (H+2*pad_h)x(W+2*pad_w) destination; the padding border of dst is NOT
 * written by this function. */
LIBXSMM_INLINE void copy_internal_nchw_uint8(unsigned char* dst , unsigned char* src, int N, int C, int H, int W, int pad_h, int pad_w)
{
  LIBXSMM_VLA_DECL(4, unsigned char, input, src, C, H, W);
  LIBXSMM_VLA_DECL(4, unsigned char, new_input, dst, C, H+2*pad_h, W+2*pad_w);
  int n, h, w, c;

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(c); LIBXSMM_OMP_VAR(h); LIBXSMM_OMP_VAR(w);
# pragma omp parallel for private(n,c,h,w)
#endif
  for ( n = 0; n < N; n++ ) {
    for ( c = 0; c < C; c++ ) {
      for ( h = 0; h < H; h++ ) {
        for ( w = 0; w < W; w++ ) {
          LIBXSMM_VLA_ACCESS(4, new_input, n, c, h+pad_h, w+pad_w, C, H+2*pad_h, W+2*pad_w) =  LIBXSMM_VLA_ACCESS(4, input, n, c, h, w, C, H, W);
        }
      }
    }
  }
}
553 
/* Layout transpose: copy an NCHW fp32 tensor into NHWC order. */
LIBXSMM_INLINE void naive_copy_NCHW_to_NHWC(const float* nchw, float* nhwc, int N, int H, int W, int C)
{
  LIBXSMM_VLA_DECL(4,       float, output, nhwc, H, W, C);
  LIBXSMM_VLA_DECL(4, const float,  input, nchw, C, H, W);
  int n, h, w, c;

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(c); LIBXSMM_OMP_VAR(h); LIBXSMM_OMP_VAR(w);
# pragma omp parallel for private(n,c,h,w)
#endif
  for ( n = 0; n < N; n++ ) {
    for ( h = 0; h < H; h++ ) {
      for ( w = 0; w < W; w++ ) {
        for ( c = 0; c < C; c++ ) {
          LIBXSMM_VLA_ACCESS(4, output, n, h, w, c, H, W, C) =
          LIBXSMM_VLA_ACCESS(4, input, n, c, h, w, C, H, W);
        }
      }
    }
  }
}
575 
/* Layout transpose: copy an NHWC fp32 tensor into NCHW order
 * (inverse of naive_copy_NCHW_to_NHWC). */
LIBXSMM_INLINE void naive_copy_NHWC_to_NCHW(const float* nhwc, float* nchw, int N, int H, int W, int C)
{
  LIBXSMM_VLA_DECL(4,       float, output, nchw, C, H, W);
  LIBXSMM_VLA_DECL(4, const float,  input, nhwc, H, W, C);
  int n, h, w, c;

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(c); LIBXSMM_OMP_VAR(h); LIBXSMM_OMP_VAR(w);
# pragma omp parallel for private(n,c,h,w)
#endif
  for ( n = 0; n < N; n++ ) {
    for ( h = 0; h < H; h++ ) {
      for ( w = 0; w < W; w++ ) {
        for ( c = 0; c < C; c++ ) {
          LIBXSMM_VLA_ACCESS(4, output, n, c, h, w, C, H, W) =
          LIBXSMM_VLA_ACCESS(4, input, n, h, w, c, H, W, C);
        }
      }
    }
  }
}
597 
/* Filter-layout transpose: copy KCRS-ordered fp32 weights into RSCK order. */
LIBXSMM_INLINE void naive_copy_KCRS_to_RSCK(const float* kcrs, float* rsck, int R, int S, int C, int K)
{
  LIBXSMM_VLA_DECL(4,       float, output, rsck, S, C, K);
  LIBXSMM_VLA_DECL(4, const float,  input, kcrs, C, R, S);
  int r, s, c, k;

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(s); LIBXSMM_OMP_VAR(c); LIBXSMM_OMP_VAR(k);
# pragma omp parallel for private(r,s,c,k)
#endif
  for ( r = 0; r < R; r++ ) {
    for ( s = 0; s < S; s++ ) {
      for ( c = 0; c < C; c++ ) {
        for ( k = 0; k < K; k++ ) {
          LIBXSMM_VLA_ACCESS(4, output, r, s, c, k, S, C, K) =
          LIBXSMM_VLA_ACCESS(4, input, k, c, r, s, C, R, S);
        }
      }
    }
  }
}
619 
620 
/* Filter-layout transpose: copy RSCK-ordered fp32 weights into KCRS order
 * (inverse of naive_copy_KCRS_to_RSCK). */
LIBXSMM_INLINE void naive_copy_RSCK_to_KCRS(const float* rsck, float* kcrs, int R, int S, int C, int K)
{
  LIBXSMM_VLA_DECL(4, const float,  input, rsck, S, C, K);
  LIBXSMM_VLA_DECL(4,       float, output, kcrs, C, R, S);
  int r, s, c, k;

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(s); LIBXSMM_OMP_VAR(c); LIBXSMM_OMP_VAR(k);
# pragma omp parallel for private(r,s,c,k)
#endif
  for ( r = 0; r < R; r++ ) {
    for ( s = 0; s < S; s++ ) {
      for ( c = 0; c < C; c++ ) {
        for ( k = 0; k < K; k++ ) {
          LIBXSMM_VLA_ACCESS(4, output, k, c, r, s, C, R, S) =
            LIBXSMM_VLA_ACCESS(4, input, r, s, c, k, S, C, K);
        }
      }
    }
  }
}
642 
/* Repack a [T][N][C] fp32 tensor into the blocked [T][N/bn][C/bc][bn][bc]
 * layout; N and C are assumed divisible by bn and bc (integer division). */
LIBXSMM_INLINE void matrix_copy_NC_to_NCNC(float *src, float *dst, int T, int N, int C, int bn, int bc)
{
  int t, n1, n2, c1, c2;
  int nBlocks = N/bn;
  int cBlocks = C/bc;
  LIBXSMM_VLA_DECL(3, float, real_src, src, N, C);
  LIBXSMM_VLA_DECL(5, float, real_dst, dst, nBlocks, cBlocks, bn, bc);

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(n1); LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(n2); LIBXSMM_OMP_VAR(c2);
# pragma omp parallel for private(t,n1,c1,n2,c2)
#endif
  for (t = 0; t < T; t++) {
    for (n1 = 0; n1 < nBlocks; n1++) {
      for (c1 = 0; c1 < cBlocks; c1++) {
        for (n2 = 0; n2 < bn; n2++) {
          for (c2 = 0; c2 < bc; c2++) {
            LIBXSMM_VLA_ACCESS(5, real_dst, t, n1, c1, n2, c2, nBlocks, cBlocks, bn, bc) =
              LIBXSMM_VLA_ACCESS(3, real_src, t, n1*bn+n2, c1*bc+c2, N, C);
          }
        }
      }
    }
  }
}
668 
/* Unpack a blocked [T][N/bn][C/bc][bn][bc] fp32 tensor back into flat
 * [T][N][C] layout (inverse of matrix_copy_NC_to_NCNC). */
LIBXSMM_INLINE void matrix_copy_NCNC_to_NC(float *src, float *dst, int T, int N, int C, int bn, int bc)
{
  int t, n1, n2, c1, c2;
  int nBlocks = N/bn;
  int cBlocks = C/bc;
  LIBXSMM_VLA_DECL(3, float, real_dst, dst, N, C);
  LIBXSMM_VLA_DECL(5, float, real_src, src, nBlocks, cBlocks, bn, bc);

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(n1); LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(n2); LIBXSMM_OMP_VAR(c2);
# pragma omp parallel for private(t,n1,c1,n2,c2)
#endif
  for (t = 0; t < T; t++) {
    for (n1 = 0; n1 < nBlocks; n1++) {
      for (c1 = 0; c1 < cBlocks; c1++) {
        for (n2 = 0; n2 < bn; n2++) {
          for (c2 = 0; c2 < bc; c2++) {
            LIBXSMM_VLA_ACCESS(3, real_dst, t, n1*bn+n2, c1*bc+c2, N, C) =
              LIBXSMM_VLA_ACCESS(5, real_src, t, n1, c1, n2, c2, nBlocks, cBlocks, bn, bc);
          }
        }
      }
    }
  }
}
694 
/* bf16 variant of matrix_copy_NC_to_NCNC: repack [T][N][C] into the blocked
 * [T][N/bn][C/bc][bn][bc] layout; assumes bn | N and bc | C. */
LIBXSMM_INLINE void matrix_copy_NC_to_NCNC_bf16(libxsmm_bfloat16 *src, libxsmm_bfloat16 *dst, int T, int N, int C, int bn, int bc)
{
  int t, n1, n2, c1, c2;
  int nBlocks = N/bn;
  int cBlocks = C/bc;
  LIBXSMM_VLA_DECL(3, libxsmm_bfloat16, real_src, src, N, C);
  LIBXSMM_VLA_DECL(5, libxsmm_bfloat16, real_dst, dst, nBlocks, cBlocks, bn, bc);

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(n1); LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(n2); LIBXSMM_OMP_VAR(c2);
# pragma omp parallel for private(t,n1,c1,n2,c2)
#endif
  for (t = 0; t < T; t++) {
    for (n1 = 0; n1 < nBlocks; n1++) {
      for (c1 = 0; c1 < cBlocks; c1++) {
        for (n2 = 0; n2 < bn; n2++) {
          for (c2 = 0; c2 < bc; c2++) {
            LIBXSMM_VLA_ACCESS(5, real_dst, t, n1, c1, n2, c2, nBlocks, cBlocks, bn, bc) =
              LIBXSMM_VLA_ACCESS(3, real_src, t, n1*bn+n2, c1*bc+c2, N, C);
          }
        }
      }
    }
  }
}
720 
/* bf16 variant of matrix_copy_NCNC_to_NC: unpack the blocked
 * [T][N/bn][C/bc][bn][bc] layout back into flat [T][N][C]. */
LIBXSMM_INLINE void matrix_copy_NCNC_to_NC_bf16(libxsmm_bfloat16 *src, libxsmm_bfloat16 *dst, int T, int N, int C, int bn, int bc)
{
  int t, n1, n2, c1, c2;
  int nBlocks = N/bn;
  int cBlocks = C/bc;
  LIBXSMM_VLA_DECL(3, libxsmm_bfloat16, real_dst, dst, N, C);
  LIBXSMM_VLA_DECL(5, libxsmm_bfloat16, real_src, src, nBlocks, cBlocks, bn, bc);

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(n1); LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(n2); LIBXSMM_OMP_VAR(c2);
# pragma omp parallel for private(t,n1,c1,n2,c2)
#endif
  for (t = 0; t < T; t++) {
    for (n1 = 0; n1 < nBlocks; n1++) {
      for (c1 = 0; c1 < cBlocks; c1++) {
        for (n2 = 0; n2 < bn; n2++) {
          for (c2 = 0; c2 < bc; c2++) {
            LIBXSMM_VLA_ACCESS(3, real_dst, t, n1*bn+n2, c1*bc+c2, N, C) =
              LIBXSMM_VLA_ACCESS(5, real_src, t, n1, c1, n2, c2, nBlocks, cBlocks, bn, bc);
          }
        }
      }
    }
  }
}
746 
/* Repack a row-major [C][K] fp32 matrix into the blocked
 * [K/bk][C/bc][bc][bk] layout; assumes bc | C and bk | K. */
LIBXSMM_INLINE void matrix_copy_CK_to_KCCK(float *src, float *dst, int C, int K, int bc, int bk)
{
  int k1, k2, c1, c2;
  int kBlocks = K/bk;
  int cBlocks = C/bc;
  LIBXSMM_VLA_DECL(2, float, real_src, src, K);
  LIBXSMM_VLA_DECL(4, float, real_dst, dst, cBlocks, bc, bk);

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(c2); LIBXSMM_OMP_VAR(k2);
# pragma omp parallel for private(k1,c1,c2,k2)
#endif
  for (k1 = 0; k1 < kBlocks; k1++) {
    for (c1 = 0; c1 < cBlocks; c1++) {
      for (c2 = 0; c2 < bc; c2++) {
        for (k2 = 0; k2 < bk; k2++) {
          LIBXSMM_VLA_ACCESS(4, real_dst, k1, c1, c2, k2, cBlocks, bc, bk) =
            LIBXSMM_VLA_ACCESS(2, real_src, c1*bc+c2, k1*bk+k2, K);
        }
      }
    }
  }
}
770 
/* Repack a row-major [C][K] fp32 matrix into the blocked
 * [C/bc][K/bk][bk][bc] layout; assumes bc | C and bk | K. */
LIBXSMM_INLINE void matrix_copy_CK_to_CKKC(float *src, float *dst, int C, int K, int bc, int bk)
{
  int k1, k2, c1, c2;
  int kBlocks = K/bk;
  int cBlocks = C/bc;
  LIBXSMM_VLA_DECL(2, float, real_src, src, K);
  LIBXSMM_VLA_DECL(4, float, real_dst, dst, kBlocks, bk, bc);

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(k1); LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(c2); LIBXSMM_OMP_VAR(k2);
# pragma omp parallel for private(k1,c1,c2,k2)
#endif
  for (c1 = 0; c1 < cBlocks; c1++) {
    for (k1 = 0; k1 < kBlocks; k1++) {
      for (k2 = 0; k2 < bk; k2++) {
        for (c2 = 0; c2 < bc; c2++) {
          LIBXSMM_VLA_ACCESS(4, real_dst, c1, k1, k2, c2, kBlocks, bk, bc) =
            LIBXSMM_VLA_ACCESS(2, real_src, c1*bc+c2, k1*bk+k2, K);
        }
      }
    }
  }
}
794 
/* Repack a row-major [K][C] fp32 matrix into the blocked
 * [K/bk][C/bc][bc][bk] layout; assumes bc | C and bk | K. */
LIBXSMM_INLINE void matrix_copy_KC_to_KCCK(float *src, float *dst, int C, int K, int bc, int bk)
{
  int k1, k2, c1, c2;
  int kBlocks = K/bk;
  int cBlocks = C/bc;
  LIBXSMM_VLA_DECL(2, float, real_src, src, C);
  LIBXSMM_VLA_DECL(4, float, real_dst, dst, cBlocks, bc, bk);

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(c2); LIBXSMM_OMP_VAR(k2);
# pragma omp parallel for private(k1,c1,c2,k2)
#endif
  for (k1 = 0; k1 < kBlocks; k1++) {
    for (c1 = 0; c1 < cBlocks; c1++) {
      for (c2 = 0; c2 < bc; c2++) {
        for (k2 = 0; k2 < bk; k2++) {
          LIBXSMM_VLA_ACCESS(4, real_dst, k1, c1, c2, k2, cBlocks, bc, bk) =
            LIBXSMM_VLA_ACCESS(2, real_src, k1*bk+k2, c1*bc+c2, C);
        }
      }
    }
  }
}
818 
/* Unpack a blocked [K/bk][C/bc][bc][bk] fp32 matrix back into flat
 * row-major [K][C] (inverse of matrix_copy_KC_to_KCCK). */
LIBXSMM_INLINE void matrix_copy_KCCK_to_KC(float *src, float *dst, int C, int K, int bc, int bk)
{
  int k1, k2, c1, c2;
  int kBlocks = K/bk;
  int cBlocks = C/bc;
  LIBXSMM_VLA_DECL(2, float, real_dst, dst, C);
  LIBXSMM_VLA_DECL(4, float, real_src, src, cBlocks, bc, bk);

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(c2); LIBXSMM_OMP_VAR(k2);
# pragma omp parallel for private(k1,c1,c2,k2)
#endif
  for (k1 = 0; k1 < kBlocks; k1++) {
    for (c1 = 0; c1 < cBlocks; c1++) {
      for (c2 = 0; c2 < bc; c2++) {
        for (k2 = 0; k2 < bk; k2++) {
          LIBXSMM_VLA_ACCESS(2, real_dst, k1*bk+k2, c1*bc+c2, C) =
            LIBXSMM_VLA_ACCESS(4, real_src, k1, c1, c2, k2, cBlocks, bc, bk);
        }
      }
    }
  }
}
842 
/* Unpack a blocked [K/bk][C/bc][bc][bk] fp32 matrix back into flat
 * row-major [C][K] (inverse of matrix_copy_CK_to_KCCK). */
LIBXSMM_INLINE void matrix_copy_KCCK_to_CK(float *src, float *dst, int C, int K, int bc, int bk)
{
  int k1, k2, c1, c2;
  int kBlocks = K/bk;
  int cBlocks = C/bc;
  LIBXSMM_VLA_DECL(2, float, real_dst, dst, K);
  LIBXSMM_VLA_DECL(4, float, real_src, src, cBlocks, bc, bk);

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(c2); LIBXSMM_OMP_VAR(k2);
# pragma omp parallel for private(k1,c1,c2,k2)
#endif
  for (k1 = 0; k1 < kBlocks; k1++) {
    for (c1 = 0; c1 < cBlocks; c1++) {
      for (c2 = 0; c2 < bc; c2++) {
        for (k2 = 0; k2 < bk; k2++) {
          LIBXSMM_VLA_ACCESS(2, real_dst, c1*bc+c2, k1*bk+k2, K) =
            LIBXSMM_VLA_ACCESS(4, real_src, k1, c1, c2, k2, cBlocks, bc, bk);
        }
      }
    }
  }
}
866 
/* Repack a row-major [C][K] bf16 matrix into the blocked
 * [K/bk][C/bc][bc/2][bk][2] layout: adjacent pairs of c-values land in the
 * innermost dimension (c2/2 selects the pair, c2%2 the element within it).
 * Assumes bc | C, bk | K, and bc even. */
LIBXSMM_INLINE void matrix_copy_CK_to_KCCK_bf16(libxsmm_bfloat16 *src, libxsmm_bfloat16 *dst, int C, int K, int bc, int bk)
{
  int k1, k2, c1, c2;
  int kBlocks = K/bk;
  int cBlocks = C/bc;
  LIBXSMM_VLA_DECL(2, libxsmm_bfloat16, real_src, src, K);
  LIBXSMM_VLA_DECL(5, libxsmm_bfloat16, real_dst, dst, cBlocks, bc/2, bk, 2);

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(c2); LIBXSMM_OMP_VAR(k2);
# pragma omp parallel for private(k1,c1,c2,k2)
#endif
  for (k1 = 0; k1 < kBlocks; k1++) {
    for (c1 = 0; c1 < cBlocks; c1++) {
      for (c2 = 0; c2 < bc; c2++) {
        for (k2 = 0; k2 < bk; k2++) {
          LIBXSMM_VLA_ACCESS(5, real_dst, k1, c1, c2/2, k2, c2%2, cBlocks, bc/2, bk, 2) =
            LIBXSMM_VLA_ACCESS(2, real_src, c1*bc+c2, k1*bk+k2, K);
        }
      }
    }
  }
}
890 
/* Packs a flat bf16 CK matrix src[C][K] into the blocked layout
 * [cBlocks][kBlocks][bk/2][bc][2]; here pairs of K-values are interleaved in
 * the innermost dimension (k2/2 selects the pair, k2%2 the element).
 * Assumes bk is even, C % bc == 0 and K % bk == 0. */
LIBXSMM_INLINE void matrix_copy_CK_to_CKKC_bf16(libxsmm_bfloat16 *src, libxsmm_bfloat16 *dst, int C, int K, int bc, int bk)
{
  int k1, k2, c1, c2;
  int kBlocks = K/bk;
  int cBlocks = C/bc;
  LIBXSMM_VLA_DECL(2, libxsmm_bfloat16, real_src, src, K);
  LIBXSMM_VLA_DECL(5, libxsmm_bfloat16, real_dst, dst, kBlocks, bk/2, bc, 2);

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(k1); LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(c2); LIBXSMM_OMP_VAR(k2);
# pragma omp parallel for private(k1,c1,c2,k2)
#endif
  for (c1 = 0; c1 < cBlocks; c1++) {
    for (k1 = 0; k1 < kBlocks; k1++) {
      for (k2 = 0; k2 < bk; k2++) {
        for (c2 = 0; c2 < bc; c2++) {
          /* dst[c1][k1][k2/2][c2][k2%2] = src[c1*bc+c2][k1*bk+k2] */
          LIBXSMM_VLA_ACCESS(5, real_dst, c1, k1, k2/2, c2, k2%2, kBlocks, bk/2, bc, 2) =
            LIBXSMM_VLA_ACCESS(2, real_src, c1*bc+c2, k1*bk+k2, K);
        }
      }
    }
  }
}
914 
/* Packs a flat bf16 KC matrix src[K][C] into the blocked layout
 * [kBlocks][cBlocks][bc/2][bk][2] (pairs of C-values interleaved innermost).
 * Same destination layout as matrix_copy_CK_to_KCCK_bf16, only the source is
 * transposed (row index is K here). Assumes bc even, C % bc == 0, K % bk == 0. */
LIBXSMM_INLINE void matrix_copy_KC_to_KCCK_bf16(libxsmm_bfloat16 *src, libxsmm_bfloat16 *dst, int C, int K, int bc, int bk)
{
  int k1, k2, c1, c2;
  int kBlocks = K/bk;
  int cBlocks = C/bc;
  LIBXSMM_VLA_DECL(2, libxsmm_bfloat16, real_src, src, C);
  LIBXSMM_VLA_DECL(5, libxsmm_bfloat16, real_dst, dst, cBlocks, bc/2, bk, 2);

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(c2); LIBXSMM_OMP_VAR(k2);
# pragma omp parallel for private(k1,c1,c2,k2)
#endif
  for (k1 = 0; k1 < kBlocks; k1++) {
    for (c1 = 0; c1 < cBlocks; c1++) {
      for (c2 = 0; c2 < bc; c2++) {
        for (k2 = 0; k2 < bk; k2++) {
          /* dst[k1][c1][c2/2][k2][c2%2] = src[k1*bk+k2][c1*bc+c2] */
          LIBXSMM_VLA_ACCESS(5, real_dst, k1, c1, c2/2, k2, c2%2, cBlocks, bc/2, bk, 2) =
            LIBXSMM_VLA_ACCESS(2, real_src, k1*bk+k2, c1*bc+c2, C);
        }
      }
    }
  }
}
938 
/* Unpacks the blocked bf16 layout [kBlocks][cBlocks][bc/2][bk][2] back into a
 * flat KC matrix dst[K][C] (inverse of matrix_copy_KC_to_KCCK_bf16).
 * Assumes bc even, C % bc == 0 and K % bk == 0. */
LIBXSMM_INLINE void matrix_copy_KCCK_to_KC_bf16(libxsmm_bfloat16 *src, libxsmm_bfloat16 *dst, int C, int K, int bc, int bk)
{
  int k1, k2, c1, c2;
  int kBlocks = K/bk;
  int cBlocks = C/bc;
  LIBXSMM_VLA_DECL(2, libxsmm_bfloat16, real_dst, dst, C);
  LIBXSMM_VLA_DECL(5, libxsmm_bfloat16, real_src, src, cBlocks, bc/2, bk, 2);

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(c2); LIBXSMM_OMP_VAR(k2);
# pragma omp parallel for private(k1,c1,c2,k2)
#endif
  for (k1 = 0; k1 < kBlocks; k1++) {
    for (c1 = 0; c1 < cBlocks; c1++) {
      for (c2 = 0; c2 < bc; c2++) {
        for (k2 = 0; k2 < bk; k2++) {
          /* dst[k1*bk+k2][c1*bc+c2] = src[k1][c1][c2/2][k2][c2%2] */
          LIBXSMM_VLA_ACCESS(2, real_dst, k1*bk+k2, c1*bc+c2, C) =
            LIBXSMM_VLA_ACCESS(5, real_src, k1, c1, c2/2, k2, c2%2, cBlocks, bc/2, bk, 2);
        }
      }
    }
  }
}
962 
/* Unpacks the blocked bf16 layout [kBlocks][cBlocks][bc/2][bk][2] back into a
 * flat CK matrix dst[C][K] (inverse of matrix_copy_CK_to_KCCK_bf16).
 * Assumes bc even, C % bc == 0 and K % bk == 0. */
LIBXSMM_INLINE void matrix_copy_KCCK_to_CK_bf16(libxsmm_bfloat16 *src, libxsmm_bfloat16 *dst, int C, int K, int bc, int bk)
{
  int k1, k2, c1, c2;
  int kBlocks = K/bk;
  int cBlocks = C/bc;
  LIBXSMM_VLA_DECL(2, libxsmm_bfloat16, real_dst, dst, K);
  LIBXSMM_VLA_DECL(5, libxsmm_bfloat16, real_src, src, cBlocks, bc/2, bk, 2);

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(c2); LIBXSMM_OMP_VAR(k2);
# pragma omp parallel for private(k1,c1,c2,k2)
#endif
  for (k1 = 0; k1 < kBlocks; k1++) {
    for (c1 = 0; c1 < cBlocks; c1++) {
      for (c2 = 0; c2 < bc; c2++) {
        for (k2 = 0; k2 < bk; k2++) {
          /* dst[c1*bc+c2][k1*bk+k2] = src[k1][c1][c2/2][k2][c2%2] */
          LIBXSMM_VLA_ACCESS(2, real_dst, c1*bc+c2, k1*bk+k2, K) =
            LIBXSMM_VLA_ACCESS(5, real_src, k1, c1, c2/2, k2, c2%2, cBlocks, bc/2, bk, 2);
        }
      }
    }
  }
}
986 
/* Repacks between the two blocked bf16 layouts: source is
 * [kBlocks][cBlocks][bc/2][bk][2] (C-pairs innermost), destination is
 * [cBlocks][kBlocks][bk/2][bc][2] (K-pairs innermost).
 * Assumes bc and bk even, C % bc == 0 and K % bk == 0. */
LIBXSMM_INLINE void matrix_copy_KCCK_to_CKKC_bf16(libxsmm_bfloat16 *src, libxsmm_bfloat16 *dst, int C, int K, int bc, int bk)
{
  int k1, k2, c1, c2;
  int kBlocks = K/bk;
  int cBlocks = C/bc;
  LIBXSMM_VLA_DECL(5, libxsmm_bfloat16, real_dst, dst, kBlocks, bk/2, bc, 2);
  LIBXSMM_VLA_DECL(5, libxsmm_bfloat16, real_src, src, cBlocks, bc/2, bk, 2);

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(c1); LIBXSMM_OMP_VAR(c2); LIBXSMM_OMP_VAR(k2);
# pragma omp parallel for private(k1,c1,c2,k2)
#endif
  for (k1 = 0; k1 < kBlocks; k1++) {
    for (c1 = 0; c1 < cBlocks; c1++) {
      for (c2 = 0; c2 < bc; c2++) {
        for (k2 = 0; k2 < bk; k2++) {
          /* dst[c1][k1][k2/2][c2][k2%2] = src[k1][c1][c2/2][k2][c2%2] */
          LIBXSMM_VLA_ACCESS(5, real_dst, c1, k1, k2/2, c2, k2%2, kBlocks, bk/2, bc, 2) =
          LIBXSMM_VLA_ACCESS(5, real_src, k1, c1, c2/2, k2, c2%2, cBlocks, bc/2, bk, 2);
        }
      }
    }
  }
}
1010 
matrix_add(int size,float * a,float * b,float * c)1011 LIBXSMM_INLINE void matrix_add(int size, float *a, float *b, float *c)
1012 {
1013   int i;
1014 #if defined(_OPENMP)
1015 # pragma omp parallel for private(i)
1016 #endif
1017   for (i = 0; i < size; i++) {
1018     c[i] = a[i] + b[i];
1019   }
1020 }
1021 
matrix_eltwise_mult(int size,float * a,float * b,float * c)1022 LIBXSMM_INLINE void matrix_eltwise_mult(int size, float *a, float *b, float *c)
1023 {
1024   int i;
1025 #if defined(_OPENMP)
1026 # pragma omp parallel for private(i)
1027 #endif
1028   for (i = 0; i < size; i++) {
1029     c[i] = a[i] * b[i];
1030   }
1031 }
1032 
matrix_eltwise_fma(int size,float * a,float * b,float * c)1033 LIBXSMM_INLINE void matrix_eltwise_fma(int size, float *a, float *b, float *c)
1034 {
1035   int i;
1036 #if defined(_OPENMP)
1037 # pragma omp parallel for private(i)
1038 #endif
1039   for (i = 0; i < size; i++) {
1040     c[i] += a[i] * b[i];
1041   }
1042 }
1043 
matrix_eltwise_mult_ld_a(int m,int n,int ld,float * a,float * b,float * c)1044 LIBXSMM_INLINE void matrix_eltwise_mult_ld_a(int m, int n, int ld, float *a, float *b, float *c)
1045 {
1046   int i;
1047 #if defined(_OPENMP)
1048 # pragma omp parallel for private(i)
1049 #endif
1050   for (i = 0; i < m*n; i++) {
1051     int row = i / m;
1052     int col = i % m;
1053     c[i] = a[row*ld + col] * b[i];
1054   }
1055 }
1056 
matrix_eltwise_mult_ld_ab(int m,int n,int ld,float * a,float * b,float * c)1057 LIBXSMM_INLINE void matrix_eltwise_mult_ld_ab(int m, int n, int ld, float *a, float *b, float *c)
1058 {
1059   int i;
1060 #if defined(_OPENMP)
1061 # pragma omp parallel for private(i)
1062 #endif
1063   for (i = 0; i < m*n; i++) {
1064     int row = i / m;
1065     int col = i % m;
1066     c[i] = a[row*ld + col] * b[row*ld + col];
1067   }
1068 }
1069 
matrix_eltwise_mult_ld_c(int m,int n,int ld,float * a,float * b,float * c)1070 LIBXSMM_INLINE void matrix_eltwise_mult_ld_c(int m, int n, int ld, float *a, float *b, float *c)
1071 {
1072   int i;
1073 #if defined(_OPENMP)
1074 # pragma omp parallel for private(i)
1075 #endif
1076   for (i = 0; i < m*n; i++) {
1077     int row = i / m;
1078     int col = i % m;
1079     c[row*ld + col] = a[i] * b[i];
1080   }
1081 }
1082 
matrix_sigmoid(int size,float * src,float * dst)1083 LIBXSMM_INLINE void matrix_sigmoid(int size, float *src, float *dst)
1084 {
1085   int i;
1086 #if defined(_OPENMP)
1087 # pragma omp parallel for private(i)
1088 #endif
1089   for (i = 0; i < size; i++) {
1090     const float exp_value = (float)exp((double) -src[i]);
1091     dst[i] = 1.0f / (1.0f + exp_value);
1092   }
1093 }
1094 
matrix_sigmoid_ld(int m,int n,int ld,float * src,float * dst)1095 LIBXSMM_INLINE void matrix_sigmoid_ld(int m, int n, int ld, float *src, float *dst)
1096 {
1097   int i;
1098 #if defined(_OPENMP)
1099 # pragma omp parallel for private(i)
1100 #endif
1101   for (i = 0; i < m*n; i++) {
1102     int row = i / m;
1103     int col = i % m;
1104     const float exp_value = (float)exp((double) -src[row*ld + col]);
1105     dst[row*ld + col] = 1.0f / (1.0f + exp_value);
1106   }
1107 }
1108 
matrix_tanh(int size,float * src,float * dst)1109 LIBXSMM_INLINE void matrix_tanh(int size, float *src, float *dst)
1110 {
1111   int i;
1112 #if defined(_OPENMP)
1113 # pragma omp parallel for private(i)
1114 #endif
1115   for (i = 0; i < size; i++) {
1116     dst[i] = (float)tanh((double)src[i]);
1117   }
1118 }
1119 
matrix_tanh_ld(int m,int n,int ld,float * src,float * dst)1120 LIBXSMM_INLINE void matrix_tanh_ld(int m, int n, int ld, float *src, float *dst)
1121 {
1122   int i;
1123 #if defined(_OPENMP)
1124 # pragma omp parallel for private(i)
1125 #endif
1126   for (i = 0; i < m*n; i++) {
1127     int row = i / m;
1128     int col = i % m;
1129     dst[row*ld + col] = (float)tanh((double)src[row*ld + col]);
1130   }
1131 }
1132 
matrix_relu(int size,float * src,float * dst)1133 LIBXSMM_INLINE void matrix_relu(int size, float *src, float *dst)
1134 {
1135   int i;
1136 #if defined(_OPENMP)
1137 # pragma omp parallel for private(i)
1138 #endif
1139   for (i = 0; i < size; i++) {
1140     dst[i] = (src[i] > 0.0f) ? src[i] : 0.0f;
1141   }
1142 }
1143 
matrix_sigmoid_inverse(int size,float * src,float * dst)1144 LIBXSMM_INLINE void matrix_sigmoid_inverse(int size, float *src, float *dst)
1145 {
1146   int i;
1147 #if defined(_OPENMP)
1148 # pragma omp parallel for private(i)
1149 #endif
1150   for (i = 0; i < size; i++) {
1151     const float exp_value = (float)exp((double) -src[i]);
1152     const float sig_exp = 1.0f / (1.0f + exp_value);
1153     dst[i] = (1.0f - sig_exp)*sig_exp;
1154   }
1155 }
1156 
matrix_tanh_inverse(int size,float * src,float * dst)1157 LIBXSMM_INLINE void matrix_tanh_inverse(int size, float *src, float *dst)
1158 {
1159   int i;
1160 #if defined(_OPENMP)
1161 # pragma omp parallel for private(i)
1162 #endif
1163   for (i = 0; i < size; i++) {
1164     const float tanh_value = (float)tanh((double)src[i]);
1165     dst[i] = 1.0f - (tanh_value * tanh_value);
1166   }
1167 }
1168 
matrix_relu_inverse(int size,float * src,float * dst)1169 LIBXSMM_INLINE void matrix_relu_inverse(int size, float *src, float *dst)
1170 {
1171   int i;
1172 #if defined(_OPENMP)
1173 # pragma omp parallel for private(i)
1174 #endif
1175   for (i = 0; i < size; i++) {
1176     dst[i] = (src[i] > 0.0f) ? 1.0f : 0.0f;
1177   }
1178 }
1179 
/* Transposes a rows x cols matrix out-of-place via LIBXSMM's parallel
 * transpose helper. NOTE(review): arguments are passed as m=cols, n=rows with
 * ldi=cols, ldo=rows — presumably matching libxsmm_otrans_omp's expected
 * (column-major) convention; confirm against the LIBXSMM API documentation. */
LIBXSMM_INLINE void matrix_transpose(int rows, int cols, float *src, float *dst)
{
  libxsmm_otrans_omp(dst, src, sizeof(float), cols, rows, cols/*ldi*/, rows/*ldo*/);
}
1184 
matrix_copy(int size,float * src,float * dst)1185 LIBXSMM_INLINE void matrix_copy(int size, float *src, float *dst)
1186 {
1187   int i;
1188 #if defined(_OPENMP)
1189 # pragma omp parallel for private(i)
1190 #endif
1191   for (i = 0; i < size; i++) {
1192     dst[i] = src[i];
1193   }
1194 }
1195 
/* Converts fp32 values to bfloat16 by keeping the upper 16 bits of the
 * IEEE-754 bit pattern (truncation, i.e. round-toward-zero — no rounding to
 * nearest-even here). NOTE(review): t.i[1] selects the high half-word only on
 * little-endian targets — confirm libxsmm_bfloat16_hp's platform assumptions. */
LIBXSMM_INLINE void matrix_copy_f32_bf16(int size, float *src, libxsmm_bfloat16 *dst)
{
  int i;
#if defined(_OPENMP)
# pragma omp parallel for private(i)
#endif
  for (i = 0; i < size; i++) {
    libxsmm_bfloat16_hp t; /* presumably a union of one float and two 16-bit halves */
    t.f = src[i];
    dst[i] = t.i[1]; /* high 16 bits of the float == bf16 representation */
  }
}
1208 
/* Widens bfloat16 values to fp32 exactly: the bf16 bits become the high
 * half-word and the low 16 mantissa bits are zeroed. NOTE(review): relies on
 * little-endian half-word ordering in libxsmm_bfloat16_hp — confirm. */
LIBXSMM_INLINE void matrix_copy_bf16_f32(int size, libxsmm_bfloat16 *src, float *dst)
{
  int i;
#if defined(_OPENMP)
# pragma omp parallel for private(i)
#endif
  for (i = 0; i < size; i++) {
    libxsmm_bfloat16_hp t; /* presumably a union of one float and two 16-bit halves */
    t.i[1] = src[i];
    t.i[0] = 0; /* clear low mantissa bits: conversion is exact */
    dst[i] = t.f;
  }
}
1222 
matrix_copy_ld(int m,int n,int ld,float * src,float * dst)1223 LIBXSMM_INLINE void matrix_copy_ld(int m, int n, int ld, float *src, float *dst)
1224 {
1225   int i;
1226 #if defined(_OPENMP)
1227 # pragma omp parallel for private(i)
1228 #endif
1229   for (i = 0; i < m*n; i++) {
1230     int row = i / m;
1231     int col = i % m;
1232     dst[i] = src[row*ld + col];
1233   }
1234 }
1235 
matrix_copy_bias(int m,int n,int ld,float * src,float * dst)1236 LIBXSMM_INLINE void matrix_copy_bias(int m, int n, int ld, float *src, float *dst)
1237 {
1238   int i;
1239 #if defined(_OPENMP)
1240 # pragma omp parallel for private(i)
1241 #endif
1242   for (i = 0; i < m*n; i++) {
1243     int row = i / m;
1244     int col = i % m;
1245     dst[row*ld + col] = src[col];
1246   }
1247 }
1248 
matrix_copy_forget_bias(int m,int n,int ld,float * src,float * dst,float forget_bias)1249 LIBXSMM_INLINE void matrix_copy_forget_bias(int m, int n, int ld, float *src, float *dst, float forget_bias)
1250 {
1251   int i;
1252 #if defined(_OPENMP)
1253 # pragma omp parallel for private(i)
1254 #endif
1255   for (i = 0; i < m*n; i++) {
1256     int row = i / m;
1257     int col = i % m;
1258     dst[row*ld + col] = src[col] + forget_bias;
1259   }
1260 }
1261 
matrix_complement(int size,float * src,float * dst)1262 LIBXSMM_INLINE void matrix_complement(int size, float *src, float *dst)
1263 {
1264   int i;
1265 #if defined(_OPENMP)
1266 # pragma omp parallel for private(i)
1267 #endif
1268   for (i = 0; i < size; i++) {
1269     dst[i] = 1.0f - src[i];
1270   }
1271 }
1272 
matrix_complement_ld(int m,int n,int ld,float * src,float * dst)1273 LIBXSMM_INLINE void matrix_complement_ld(int m, int n, int ld, float *src, float *dst)
1274 {
1275   int i;
1276 #if defined(_OPENMP)
1277 # pragma omp parallel for private(i)
1278 #endif
1279   for (i = 0; i < m*n; i++) {
1280     int row = i / m;
1281     int col = i % m;
1282     dst[i] = 1.0f - src[row*ld + col];
1283   }
1284 }
1285 
1286 
matrix_complement_square(int size,float * src,float * dst)1287 LIBXSMM_INLINE void matrix_complement_square(int size, float *src, float *dst)
1288 {
1289   int i;
1290 #if defined(_OPENMP)
1291 # pragma omp parallel for private(i)
1292 #endif
1293   for (i = 0; i < size; i++) {
1294     dst[i] = 1.0f - (src[i] * src[i]);
1295   }
1296 }
1297 
matrix_complement_square_ld(int m,int n,int ld,float * src,float * dst)1298 LIBXSMM_INLINE void matrix_complement_square_ld(int m, int n, int ld, float *src, float *dst)
1299 {
1300   int i;
1301 #if defined(_OPENMP)
1302 # pragma omp parallel for private(i)
1303 #endif
1304   for (i = 0; i < m*n; i++) {
1305     int row = i / m;
1306     int col = i % m;
1307     dst[i] = 1.0f - (src[row*ld + col] * src[row*ld + col]);
1308   }
1309 }
1310 
convert_ck_c4k_offset(int C,int K,int offset,float * src,float * dst)1311 LIBXSMM_INLINE void convert_ck_c4k_offset(int C, int K, int offset, float *src, float *dst)
1312 {
1313   /* offsets: i--0, c--1, f--2, o--3 */
1314   int x, y;
1315 #if defined(_OPENMP)
1316   LIBXSMM_OMP_VAR(x);
1317 # pragma omp parallel for private(x, y)
1318 #endif
1319   for (y = 0; y < C; y++) {
1320     for (x = 0; x < K; x++) {
1321       dst[y*4*K + offset*K + x] = src[y*K + x];
1322     }
1323   }
1324 }
1325 
convert_ck_c4k(int C,int K,float * src,float * dst)1326 LIBXSMM_INLINE void convert_ck_c4k(int C, int K, float *src, float *dst)
1327 {
1328   convert_ck_c4k_offset(C, K, 0, src, dst);
1329 }
1330 
/* Copies a CxK fp32 matrix into gate slot 0 of a Cx(4K) bf16 matrix,
 * truncating each value to bfloat16 (upper 16 bits of the fp32 pattern).
 * NOTE(review): t.i[1] assumes little-endian half-word order — confirm. */
LIBXSMM_INLINE void convert_ck_f32_to_c4k_bf16(int C, int K, float *src, libxsmm_bfloat16 *dst)
{
  int x, y;
#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(x);
# pragma omp parallel for private(x, y)
#endif
  for (y = 0; y < C; y++) {
    for (x = 0; x < K; x++) {
      libxsmm_bfloat16_hp t; /* fp32 <-> 2x16-bit union for truncation */
      t.f = src[y*K + x];
      dst[y*4*K + x] = t.i[1]; /* slot 0 of the 4K-wide row */
    }
  }
}
1346 
convert_c4k_4ck(int C,int K,float * src,float * dst)1347 LIBXSMM_INLINE void convert_c4k_4ck(int C, int K, float *src, float *dst)
1348 {
1349   /* offsets: i--0, c--1, f--2, o--3 */
1350   int x, y, offset;
1351 #if defined(_OPENMP)
1352   LIBXSMM_OMP_VAR(x); LIBXSMM_OMP_VAR(y);
1353 # pragma omp parallel for private(x, y, offset)
1354 #endif
1355   for (offset = 0; offset < 4; offset++) {
1356     for (y = 0; y < C; y++) {
1357       for (x = 0; x < K; x++) {
1358         dst[offset*C*K + y*K + x] = src[y*4*K + offset*K + x];
1359       }
1360     }
1361   }
1362 }
1363 
convert_ck_c3k(int C,int K,float * src,float * dst)1364 LIBXSMM_INLINE void convert_ck_c3k(int C, int K, float *src, float *dst)
1365 {
1366   int x, y;
1367 #if defined(_OPENMP)
1368   LIBXSMM_OMP_VAR(x);
1369 # pragma omp parallel for private(x, y)
1370 #endif
1371   for (y = 0; y < C; y++) {
1372     for (x = 0; x < K; x++) {
1373       dst[y*3*K + x] = src[y*K + x];
1374     }
1375   }
1376 }
1377 
convert_nk_nck(int N,int K,int CK,float * src,float * dst)1378 LIBXSMM_INLINE void convert_nk_nck(int N, int K, int CK, float *src, float *dst)
1379 {
1380   int x, y;
1381 #if defined(_OPENMP)
1382   LIBXSMM_OMP_VAR(x);
1383 # pragma omp parallel for private(x, y)
1384 #endif
1385   for (y = 0; y < N; y++) {
1386     for (x = 0; x < K; x++) {
1387       dst[y*CK + x] = src[y*K + x];
1388     }
1389   }
1390 }
1391 
/* Naive reference forward convolution over NCHW activations and KCRS filters:
 * output[img][ofm][oj][oi] += sum_{ifm,kj,ki} input * filter.
 * Logical zero-padding (pad_h/pad_w) is realized by skipping out-of-range
 * taps; physical padding (pad_*_in/out) shifts the VLA base pointers.
 * Compile-time fusions: bias initialization (USE_FUSED_BIAS*) and ReLU
 * (USE_FUSED_RELU / USE_FUSED_BIAS_RELU). */
LIBXSMM_INLINE void naive_conv_fp(naive_conv_t* param, const float* input, float* output, const float* filter, const float* bias)
{
  int nImg      = param->nImg;
  int nIfm      = param->nIfm;
  int nOfm      = param->nOfm;
  int ifhp      = param->ifhp;
  int ifwp      = param->ifwp;
  int ofhp      = param->ofhp;
  int ofwp      = param->ofwp;
  int ifh       = param->ifh;
  int ifw       = param->ifw;
  int ofh       = param->ofh;
  int ofw       = param->ofw;
  int pad_h     = param->pad_h;
  int pad_w     = param->pad_w;
  int pad_h_in  = param->pad_h_in;
  int pad_w_in  = param->pad_w_in;
  int pad_h_out = param->pad_h_out;
  int pad_w_out = param->pad_w_out;
  int kh        = param->kh;
  int kw        = param->kw;
  int stride_h  = param->stride_h;
  int stride_w  = param->stride_w;
  /* loop counters */
  int img, ofm, ifm, oj, oi, ij, ii, kj, ki;

  /* base pointers shifted past the physical output/input padding border */
  LIBXSMM_VLA_DECL(4,       float, output_t, output + (pad_h_out * ofwp + pad_w_out), nOfm, ofhp, ofwp);
  LIBXSMM_VLA_DECL(4, const float,  input_t,  input + (pad_h_in * ifwp + pad_w_in), nIfm, ifhp, ifwp);
  LIBXSMM_VLA_DECL(4, const float, filter_t, filter, nIfm, kh, kw);

#if defined(USE_FUSED_BIAS) || defined(USE_FUSED_BIAS_RELU)
  /* fused bias: seed every output pixel with the per-channel bias */
#if defined(_OPENMP)
# pragma omp parallel for LIBXSMM_OPENMP_COLLAPSE(2) private(img, ofm, ifm, oj, oi, ij, ii, kj, ki)
#endif
  for (img = 0; img < nImg; ++img) {
    for (ofm = 0; ofm < nOfm; ++ofm) {
      for (oj = 0; oj < ofh; ++oj) {
        for (oi = 0; oi < ofw; ++oi) {
          LIBXSMM_VLA_ACCESS(  4, output_t, img, ofm, oj, oi, nOfm, ofhp, ofwp) = bias[ofm];
        }
      }
    }
  }
#else
  LIBXSMM_UNUSED(bias);
#endif

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(ofm); LIBXSMM_OMP_VAR(oj);  LIBXSMM_OMP_VAR(oi);
  LIBXSMM_OMP_VAR(ifm); LIBXSMM_OMP_VAR(ij);  LIBXSMM_OMP_VAR(ii);  LIBXSMM_OMP_VAR(kj);  LIBXSMM_OMP_VAR(ki);
# pragma omp parallel for LIBXSMM_OPENMP_COLLAPSE(2) private(img, ofm, ifm, oj, oi, ij, ii, kj, ki)
#endif
  for (img = 0; img < nImg; ++img) {
    for (ofm = 0; ofm < nOfm; ++ofm) {
      for (ifm = 0; ifm < nIfm; ++ifm) {
        for (oj = 0; oj < ofh; ++oj) {
          ij = oj * stride_h - pad_h; /* top-left input row of this output row */
          for (oi = 0; oi < ofw; ++oi) {
            ii = oi * stride_w - pad_w;
            for (kj = 0; kj < kh; ++kj) {
              if (ij+kj < 0 || ij+kj >= ifh) continue; /* logical zero-pad: skip tap */
              for (ki = 0; ki < kw; ++ki) {
                if (ii+ki < 0 || ii+ki >= ifw) continue;
                LIBXSMM_VLA_ACCESS(  4, output_t, img, ofm, oj, oi, nOfm, ofhp, ofwp) +=
                  LIBXSMM_VLA_ACCESS(4,  input_t, img, ifm, ij + kj, ii + ki, nIfm, ifhp, ifwp)
                  * LIBXSMM_VLA_ACCESS(4, filter_t, ofm, ifm, kj, ki, nIfm, kh, kw);
              }
            }
          }
        }
      }
#if defined(USE_FUSED_RELU) || defined(USE_FUSED_BIAS_RELU)
      /* fused ReLU: clamp negatives after all input channels are accumulated */
      for (oj = 0; oj < ofh; ++oj) {
        for (oi = 0; oi < ofw; ++oi) {
          LIBXSMM_VLA_ACCESS(  4, output_t, img, ofm, oj, oi, nOfm, ofhp, ofwp) =
            (LIBXSMM_VLA_ACCESS(  4, output_t, img, ofm, oj, oi, nOfm, ofhp, ofwp) < 0.0f) ? 0.0f : LIBXSMM_VLA_ACCESS(  4, output_t, img, ofm, oj, oi, nOfm, ofhp, ofwp);
        }
      }
#endif
    }
  }
}
1474 
/* Naive reference backward-data convolution: accumulates the input gradient
 * from the output gradient and the filter (transposed application of
 * naive_conv_fp). Optionally re-applies the forward ReLU mask using the saved
 * forward input (USE_FUSED_RELU_BWD / USE_FUSED_BATCH_STATS_BWD). */
LIBXSMM_INLINE void naive_conv_bp(naive_conv_t* param, float* input, const float* output, const float* filter, const float* naive_input_save)
{
  int nImg      = param->nImg;
  int nIfm      = param->nIfm;
  int nOfm      = param->nOfm;
  int ifhp      = param->ifhp;
  int ifwp      = param->ifwp;
  int ofhp      = param->ofhp;
  int ofwp      = param->ofwp;
  int ifh       = param->ifh;
  int ifw       = param->ifw;
  int ofh       = param->ofh;
  int ofw       = param->ofw;
  int pad_h     = param->pad_h;
  int pad_w     = param->pad_w;
  int pad_h_in  = param->pad_h_in;
  int pad_w_in  = param->pad_w_in;
  int pad_h_out = param->pad_h_out;
  int pad_w_out = param->pad_w_out;
  int kh        = param->kh;
  int kw        = param->kw;
  int stride_h  = param->stride_h;
  int stride_w  = param->stride_w;
  /* loop counters */
  int img, ofm, ifm, oj, oi, ij, ii, kj, ki;

  /* base pointers shifted past the physical padding border */
  LIBXSMM_VLA_DECL(4, const float, output_t, output + (pad_h_out * ofwp + pad_w_out), nOfm, ofhp, ofwp);
  LIBXSMM_VLA_DECL(4,       float,  input_t,  input + (pad_h_in * ifwp + pad_w_in), nIfm, ifhp, ifwp);
  LIBXSMM_VLA_DECL(4, const float, filter_t, filter, nIfm, kh, kw);
#if (defined(USE_FUSED_RELU_BWD) || defined(USE_FUSED_BATCH_STATS_BWD))
  /* saved forward-pass input, used to reconstruct the ReLU mask */
  LIBXSMM_VLA_DECL(4, const float, naive_input_t, naive_input_save + (pad_h_in * ifwp + pad_w_in), nIfm, ifhp, ifwp);
#else
  LIBXSMM_UNUSED(naive_input_save);
#endif

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(ofm); LIBXSMM_OMP_VAR(oj);  LIBXSMM_OMP_VAR(oi);
  LIBXSMM_OMP_VAR(ifm); LIBXSMM_OMP_VAR(ij);  LIBXSMM_OMP_VAR(ii);  LIBXSMM_OMP_VAR(kj);  LIBXSMM_OMP_VAR(ki);
# pragma omp parallel for LIBXSMM_OPENMP_COLLAPSE(2) private(img, ofm, ifm, oj, oi, ij, ii, kj, ki)
#endif
  for (img = 0; img < nImg; ++img) {
    for (ifm = 0; ifm < nIfm; ++ifm) {
      for (ofm = 0; ofm < nOfm; ++ofm) {
        for (oj = 0; oj < ofh; ++oj) {
          ij = oj * stride_h - pad_h;
          for (oi = 0; oi < ofw; ++oi) {
            ii = oi * stride_w - pad_w;
            for (kj = 0; kj < kh; ++kj) {
              if (ij+kj < 0 || ij+kj >= ifh) continue; /* logical zero-pad: skip tap */
              for (ki = 0; ki < kw; ++ki) {
                if (ii+ki < 0 || ii+ki >= ifw) continue;
                LIBXSMM_VLA_ACCESS(4,  input_t, img, ifm, ij + kj, ii + ki, nIfm, ifhp, ifwp) +=
                  LIBXSMM_VLA_ACCESS(4, output_t, img, ofm, oj, oi, nOfm, ofhp, ofwp)
                  * LIBXSMM_VLA_ACCESS(4, filter_t, ofm, ifm, kj, ki, nIfm, kh, kw);
              }
            }
          }
        }
      }
#if (defined(USE_FUSED_RELU_BWD) || defined(USE_FUSED_BATCH_STATS_BWD))
      /* zero the gradient wherever the forward input (post-ReLU) was zero */
      for (ij = 0; ij < ifh; ij++) {
        for (ii = 0; ii < ifw; ii++) {
          if ( LIBXSMM_VLA_ACCESS(4,  naive_input_t, img, ifm, ij, ii , nIfm, ifhp, ifwp) == 0.0 ) {
            LIBXSMM_VLA_ACCESS(4, input_t, img, ifm, ij, ii , nIfm, ifhp, ifwp) = 0.0;
          }
        }
      }
#endif
    }
  }
}
1546 
/* Naive reference weight-update convolution: accumulates the filter gradient
 * from the forward input and the output gradient. Parallelized over
 * (ofm, ifm) so each thread owns a disjoint slice of the filter tensor. */
LIBXSMM_INLINE void naive_conv_wu(naive_conv_t* param, const float* input, const float* output, float* filter)
{
  int nImg      = param->nImg;
  int nIfm      = param->nIfm;
  int nOfm      = param->nOfm;
  int ifhp      = param->ifhp;
  int ifwp      = param->ifwp;
  int ofhp      = param->ofhp;
  int ofwp      = param->ofwp;
  int ifh       = param->ifh;
  int ifw       = param->ifw;
  int ofh       = param->ofh;
  int ofw       = param->ofw;
  int pad_h     = param->pad_h;
  int pad_w     = param->pad_w;
  int pad_h_in  = param->pad_h_in;
  int pad_w_in  = param->pad_w_in;
  int pad_h_out = param->pad_h_out;
  int pad_w_out = param->pad_w_out;
  int kh        = param->kh;
  int kw        = param->kw;
  int stride_h  = param->stride_h;
  int stride_w  = param->stride_w;
  /* loop counters */
  int img, ofm, ifm, oj, oi, ij, ii, kj, ki;

  /* base pointers shifted past the physical padding border */
  LIBXSMM_VLA_DECL(4, const float, output_t, output + (pad_h_out * ofwp + pad_w_out), nOfm, ofhp, ofwp);
  LIBXSMM_VLA_DECL(4, const float,  input_t,  input + (pad_h_in * ifwp + pad_w_in), nIfm, ifhp, ifwp);
  LIBXSMM_VLA_DECL(4,       float, filter_t, filter, nIfm, kh, kw);

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(ofm); LIBXSMM_OMP_VAR(oj);  LIBXSMM_OMP_VAR(oi);
  LIBXSMM_OMP_VAR(ifm); LIBXSMM_OMP_VAR(ij);  LIBXSMM_OMP_VAR(ii);  LIBXSMM_OMP_VAR(kj);  LIBXSMM_OMP_VAR(ki);
# pragma omp parallel for LIBXSMM_OPENMP_COLLAPSE(2) private(img, ofm, ifm, oj, oi, ij, ii, kj, ki)
#endif
  for (ofm = 0; ofm < nOfm; ++ofm) {
    for (ifm = 0; ifm < nIfm; ++ifm) {
      for (img = 0; img < nImg; ++img) {
        for (oj = 0; oj < ofh; ++oj) {
          ij = oj * stride_h - pad_h;
          for (oi = 0; oi < ofw; ++oi) {
            ii = oi * stride_w - pad_w;
            for (kj = 0; kj < kh; ++kj) {
              if (ij+kj < 0 || ij+kj >= ifh) continue; /* logical zero-pad: skip tap */
              for (ki = 0; ki < kw; ++ki) {
                if (ii+ki < 0 || ii+ki >= ifw) continue;
                LIBXSMM_VLA_ACCESS(4, filter_t, ofm, ifm, kj, ki, nIfm, kh, kw) +=
                  LIBXSMM_VLA_ACCESS(4,  input_t, img, ifm, ij + kj, ii + ki, nIfm, ifhp, ifwp)
                  * LIBXSMM_VLA_ACCESS(4, output_t, img, ofm, oj, oi, nOfm, ofhp, ofwp);
              }
            }
          }
        }
      }
    }
  }
}
1604 
naive_conv_fp_int16fp32(naive_conv_t * param,const short * input,float * output,const short * filter)1605 LIBXSMM_INLINE void naive_conv_fp_int16fp32(naive_conv_t* param, const short* input, float* output, const short* filter)
1606 {
1607   int nImg      = param->nImg;
1608   int nIfm      = param->nIfm;
1609   int nOfm      = param->nOfm;
1610   int ifhp      = param->ifhp;
1611   int ifwp      = param->ifwp;
1612   int ofhp      = param->ofhp;
1613   int ofwp      = param->ofwp;
1614   int ifh       = param->ifh;
1615   int ifw       = param->ifw;
1616   int ofh       = param->ofh;
1617   int ofw       = param->ofw;
1618   int pad_h     = param->pad_h;
1619   int pad_w     = param->pad_w;
1620   int pad_h_in  = param->pad_h_in;
1621   int pad_w_in  = param->pad_w_in;
1622   int pad_h_out = param->pad_h_out;
1623   int pad_w_out = param->pad_w_out;
1624   int kh        = param->kh;
1625   int kw        = param->kw;
1626   int stride_h  = param->stride_h;
1627   int stride_w  = param->stride_w;
1628   /* loop counters */
1629   int img, ofm, ifm, oj, oi, ij, ii, kj, ki;
1630 
1631   LIBXSMM_VLA_DECL(4,       float,     output_t, output + (pad_w_out * ofwp + pad_h_out), nOfm, ofhp, ofwp);
1632   LIBXSMM_VLA_DECL(4, const short,      input_t,  input + (pad_w_in * ifwp + pad_h_in), nIfm, ifhp, ifwp);
1633   LIBXSMM_VLA_DECL(4, const short,     filter_t, filter, nIfm, kh, kw);
1634 
1635 
1636 #if defined(_OPENMP)
1637   LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(ofm); LIBXSMM_OMP_VAR(oj);  LIBXSMM_OMP_VAR(oi);
1638   LIBXSMM_OMP_VAR(ifm); LIBXSMM_OMP_VAR(ij);  LIBXSMM_OMP_VAR(ii);  LIBXSMM_OMP_VAR(kj);  LIBXSMM_OMP_VAR(ki);
1639 # pragma omp parallel for LIBXSMM_OPENMP_COLLAPSE(2) private(img, ofm, ifm, oj, oi, ij, ii, kj, ki)
1640 #endif
1641   for (img = 0; img < nImg; ++img) {
1642     for (ofm = 0; ofm < nOfm; ++ofm) {
1643       for (ifm = 0; ifm < nIfm; ++ifm) {
1644         for (oj = 0; oj < ofh; ++oj) {
1645           ij = oj * stride_h - pad_h;
1646           for (oi = 0; oi < ofw; ++oi) {
1647             ii = oi * stride_w - pad_w;
1648             for (kj = 0; kj < kh; ++kj) {
1649               if (ij+kj < 0 || ij+kj >= ifh) continue;
1650               for (ki = 0; ki < kw; ++ki) {
1651                 if (ii+ki < 0 || ii+ki >= ifw) continue;
1652                 LIBXSMM_VLA_ACCESS(4, output_t, img, ofm, oj, oi, nOfm, ofhp, ofwp) +=
1653                   (1.f * LIBXSMM_VLA_ACCESS(4,  input_t, img, ifm, ij + kj, ii + ki, nIfm, ifhp, ifwp))
1654                 * (1.f * LIBXSMM_VLA_ACCESS(4, filter_t, ofm, ifm, kj, ki, nIfm, kh, kw));
1655               }
1656             }
1657           }
1658         }
1659       }
1660     }
1661   }
1662 }
1663 
naive_conv_fp_int16int32(naive_conv_t * param,const short * input,int * output,const short * filter)1664 LIBXSMM_INLINE void naive_conv_fp_int16int32(naive_conv_t* param, const short* input, int* output, const short* filter)
1665 {
1666   int nImg      = param->nImg;
1667   int nIfm      = param->nIfm;
1668   int nOfm      = param->nOfm;
1669   int ifhp      = param->ifhp;
1670   int ifwp      = param->ifwp;
1671   int ofhp      = param->ofhp;
1672   int ofwp      = param->ofwp;
1673   int ifh       = param->ifh;
1674   int ifw       = param->ifw;
1675   int ofh       = param->ofh;
1676   int ofw       = param->ofw;
1677   int pad_h     = param->pad_h;
1678   int pad_w     = param->pad_w;
1679   int pad_h_in  = param->pad_h_in;
1680   int pad_w_in  = param->pad_w_in;
1681   int pad_h_out = param->pad_h_out;
1682   int pad_w_out = param->pad_w_out;
1683   int kh        = param->kh;
1684   int kw        = param->kw;
1685   int stride_h  = param->stride_h;
1686   int stride_w  = param->stride_w;
1687   /* loop counters */
1688   int img, ofm, ifm, oj, oi, ij, ii, kj, ki;
1689 
1690   LIBXSMM_VLA_DECL(4,         int,     output_t, output + (pad_w_out * ofwp + pad_h_out), nOfm, ofhp, ofwp);
1691   LIBXSMM_VLA_DECL(4, const short,      input_t,  input + (pad_w_in * ifwp + pad_h_in), nIfm, ifhp, ifwp);
1692   LIBXSMM_VLA_DECL(4, const short,     filter_t, filter, nIfm, kh, kw);
1693 
1694 
1695 #if defined(_OPENMP)
1696   LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(ofm); LIBXSMM_OMP_VAR(oj);  LIBXSMM_OMP_VAR(oi);
1697   LIBXSMM_OMP_VAR(ifm); LIBXSMM_OMP_VAR(ij);  LIBXSMM_OMP_VAR(ii);  LIBXSMM_OMP_VAR(kj);  LIBXSMM_OMP_VAR(ki);
1698 # pragma omp parallel for LIBXSMM_OPENMP_COLLAPSE(2) private(img, ofm, ifm, oj, oi, ij, ii, kj, ki)
1699 #endif
1700   for (img = 0; img < nImg; ++img) {
1701     for (ofm = 0; ofm < nOfm; ++ofm) {
1702       for (ifm = 0; ifm < nIfm; ++ifm) {
1703         for (oj = 0; oj < ofh; ++oj) {
1704           ij = oj * stride_h - pad_h;
1705           for (oi = 0; oi < ofw; ++oi) {
1706             ii = oi * stride_w - pad_w;
1707             for (kj = 0; kj < kh; ++kj) {
1708               if (ij+kj < 0 || ij+kj >= ifh) continue;
1709               for (ki = 0; ki < kw; ++ki) {
1710                 if (ii+ki < 0 || ii+ki >= ifw) continue;
1711                 LIBXSMM_VLA_ACCESS(  4, output_t, img, ofm, oj, oi, nOfm, ofhp, ofwp) += (int)
1712                  ( (int)LIBXSMM_VLA_ACCESS(4,  input_t, img, ifm, ij + kj, ii + ki, nIfm, ifhp, ifwp))
1713                 * ( (int)  LIBXSMM_VLA_ACCESS(4, filter_t, ofm, ifm, kj, ki, nIfm, kh, kw));
1714               }
1715             }
1716           }
1717         }
1718       }
1719     }
1720   }
1721 }
1722 
naive_conv_fp_int8int32(naive_conv_t * param,const unsigned char * input,int * output,const char * filter)1723 LIBXSMM_INLINE void naive_conv_fp_int8int32(naive_conv_t* param, const unsigned char* input, int* output, const char* filter)
1724 {
1725   int nImg      = param->nImg;
1726   int nIfm      = param->nIfm;
1727   int nOfm      = param->nOfm;
1728   int ifhp      = param->ifhp;
1729   int ifwp      = param->ifwp;
1730   int ofhp      = param->ofhp;
1731   int ofwp      = param->ofwp;
1732   int ifh       = param->ifh;
1733   int ifw       = param->ifw;
1734   int ofh       = param->ofh;
1735   int ofw       = param->ofw;
1736   int pad_h     = param->pad_h;
1737   int pad_w     = param->pad_w;
1738   int pad_h_in  = param->pad_h_in;
1739   int pad_w_in  = param->pad_w_in;
1740   int pad_h_out = param->pad_h_out;
1741   int pad_w_out = param->pad_w_out;
1742   int kh        = param->kh;
1743   int kw        = param->kw;
1744   int stride_h  = param->stride_h;
1745   int stride_w  = param->stride_w;
1746   /* loop counters */
1747   int img, ofm, ifm, oj, oi, ij, ii, kj, ki;
1748 
1749   LIBXSMM_VLA_DECL(4,         int,     output_t, output + (pad_w_out * ofwp + pad_h_out), nOfm, ofhp, ofwp);
1750   LIBXSMM_VLA_DECL(4, const unsigned char,      input_t,  input + (pad_w_in * ifwp + pad_h_in), nIfm, ifhp, ifwp);
1751   LIBXSMM_VLA_DECL(4, const char,     filter_t, filter, nIfm, kh, kw);
1752 
1753 
1754 #if defined(_OPENMP)
1755   LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(ofm); LIBXSMM_OMP_VAR(oj);  LIBXSMM_OMP_VAR(oi);
1756   LIBXSMM_OMP_VAR(ifm); LIBXSMM_OMP_VAR(ij);  LIBXSMM_OMP_VAR(ii);  LIBXSMM_OMP_VAR(kj);  LIBXSMM_OMP_VAR(ki);
1757 # pragma omp parallel for LIBXSMM_OPENMP_COLLAPSE(2) private(img, ofm, ifm, oj, oi, ij, ii, kj, ki)
1758 #endif
1759   for (img = 0; img < nImg; ++img) {
1760     for (ofm = 0; ofm < nOfm; ++ofm) {
1761       for (ifm = 0; ifm < nIfm; ++ifm) {
1762         for (oj = 0; oj < ofh; ++oj) {
1763           ij = oj * stride_h - pad_h;
1764           for (oi = 0; oi < ofw; ++oi) {
1765             ii = oi * stride_w - pad_w;
1766             for (kj = 0; kj < kh; ++kj) {
1767               if (ij+kj < 0 || ij+kj >= ifh) continue;
1768               for (ki = 0; ki < kw; ++ki) {
1769                 if (ii+ki < 0 || ii+ki >= ifw) continue;
1770                 LIBXSMM_VLA_ACCESS(4, output_t, img, ofm, oj, oi, nOfm, ofhp, ofwp) += (int)
1771                 LIBXSMM_VLA_ACCESS(4,  input_t, img, ifm, ij + kj, ii + ki, nIfm, ifhp, ifwp)
1772                 * LIBXSMM_VLA_ACCESS(4, filter_t, ofm, ifm, kj, ki, nIfm, kh, kw);
1773               }
1774             }
1775           }
1776         }
1777       }
1778     }
1779   }
1780 }
1781 
naive_fullyconnected_fp(naive_fullyconnected_t * param,const float * input_ptr,float * output_ptr,const float * filter_ptr)1782 LIBXSMM_INLINE void naive_fullyconnected_fp(naive_fullyconnected_t* param, const float* input_ptr, float* output_ptr, const float* filter_ptr)
1783 {
1784   const int nImg = param->N;
1785   const int nIFm = param->C;
1786   const int nOFm = param->K;
1787 
1788   int img, ifm, ofm;
1789 
1790   LIBXSMM_VLA_DECL(2, const float, input,  input_ptr,  nIFm);
1791   LIBXSMM_VLA_DECL(2, const float, filter, filter_ptr, nIFm);
1792   LIBXSMM_VLA_DECL(2,       float, output, output_ptr, nOFm);
1793 
1794 #if defined(_OPENMP)
1795   LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(ifm); LIBXSMM_OMP_VAR(ofm);
1796 # pragma omp parallel for private(img, ofm, ifm)
1797 #endif
1798   for (ofm = 0; ofm < nOFm; ++ofm) {
1799     for(img = 0; img < nImg; ++img) {
1800       LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) = (float)0;
1801       for (ifm = 0; ifm < nIFm; ++ifm) {
1802         LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) +=
1803           LIBXSMM_VLA_ACCESS(2, filter, ofm, ifm, nIFm) * LIBXSMM_VLA_ACCESS(2, input, img, ifm, nIFm);
1804       }
1805     }
1806   }
1807 }
1808 
naive_fullyconnected_bp(naive_fullyconnected_t * param,float * delinput_ptr,const float * deloutput_ptr,const float * filter_ptr)1809 LIBXSMM_INLINE void naive_fullyconnected_bp(naive_fullyconnected_t* param, float* delinput_ptr, const float* deloutput_ptr, const float* filter_ptr)
1810 {
1811   const int nImg = param->N;
1812   const int nIFm = param->C;
1813   const int nOFm = param->K;
1814 
1815   int img, ifm, ofm;
1816 
1817   LIBXSMM_VLA_DECL(2,       float,  dinput,  delinput_ptr, nIFm);
1818   LIBXSMM_VLA_DECL(2, const float,  filter,    filter_ptr, nIFm);
1819   LIBXSMM_VLA_DECL(2, const float, doutput, deloutput_ptr, nOFm);
1820 
1821 #if defined(_OPENMP)
1822   LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(ofm); LIBXSMM_OMP_VAR(ifm);
1823 # pragma omp parallel for private(img, ofm, ifm)
1824 #endif
1825   for (ifm = 0; ifm < nIFm; ++ifm) {
1826     for(img = 0; img < nImg; ++img) {
1827       LIBXSMM_VLA_ACCESS(2, dinput, img, ifm, nIFm) = (float)0;
1828       for (ofm = 0; ofm < nOFm; ++ofm) {
1829         LIBXSMM_VLA_ACCESS(2, dinput, img, ifm, nIFm) +=
1830           LIBXSMM_VLA_ACCESS(2, filter, ofm, ifm, nIFm) * LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm);
1831       }
1832     }
1833   }
1834 }
1835 
naive_fullyconnected_fused_fp(naive_fullyconnected_t * param,const float * input_ptr,float * output_ptr,const float * filter_ptr,const float * bias_ptr)1836 LIBXSMM_INLINE void naive_fullyconnected_fused_fp(naive_fullyconnected_t* param, const float* input_ptr, float* output_ptr, const float* filter_ptr, const float* bias_ptr)
1837 {
1838   const int nImg = param->N;
1839   const int nIFm = param->C;
1840   const int nOFm = param->K;
1841 
1842   int img, ifm, ofm;
1843 
1844   LIBXSMM_VLA_DECL(2, const float, input,  input_ptr,  nIFm);
1845   LIBXSMM_VLA_DECL(2, const float, filter, filter_ptr, nIFm);
1846   LIBXSMM_VLA_DECL(2,       float, output, output_ptr, nOFm);
1847 
1848 #if defined(_OPENMP)
1849   LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(ifm); LIBXSMM_OMP_VAR(ofm);
1850 # pragma omp parallel for private(img, ofm, ifm)
1851 #endif
1852   for (ofm = 0; ofm < nOFm; ++ofm) {
1853     for(img = 0; img < nImg; ++img) {
1854       LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) = (float)0;
1855       for (ifm = 0; ifm < nIFm; ++ifm) {
1856         LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) +=
1857           LIBXSMM_VLA_ACCESS(2, filter, ofm, ifm, nIFm) * LIBXSMM_VLA_ACCESS(2, input, img, ifm, nIFm);
1858       }
1859       if ( param->fuse_type == 1 ) {
1860         LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) += bias_ptr[ofm];
1861       } else if ( param->fuse_type == 2 ) {
1862         LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) = ( LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) > 0 ) ? LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) : 0;
1863       } else if ( param->fuse_type == 3 ) {
1864         LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) = ((float)tanh((double)LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm)/2.0)+1.0f)/2.0f;
1865       } else if ( param->fuse_type == 4 ) {
1866         LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) += bias_ptr[ofm];
1867         LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) = ( LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) > 0 ) ? LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) : 0;
1868       } else if ( param->fuse_type == 5 ) {
1869         LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) += bias_ptr[ofm];
1870         LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) = ((float)tanh((double)LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm)/2.0)+1.0f)/2.0f;
1871       }
1872     }
1873   }
1874 }
1875 
naive_fullyconnected_fused_bp(naive_fullyconnected_t * param,float * delinput_ptr,float * deloutput_ptr,const float * filter_ptr,float * delbias_ptr,const float * output_ptr)1876 LIBXSMM_INLINE void naive_fullyconnected_fused_bp(naive_fullyconnected_t* param, float* delinput_ptr, float* deloutput_ptr, const float* filter_ptr, float* delbias_ptr, const float* output_ptr)
1877 {
1878   const int nImg = param->N;
1879   const int nIFm = param->C;
1880   const int nOFm = param->K;
1881 
1882   int img, ifm, ofm;
1883 
1884   LIBXSMM_VLA_DECL(2,       float,  dinput,  delinput_ptr, nIFm);
1885   LIBXSMM_VLA_DECL(2, const float,  filter,    filter_ptr, nIFm);
1886   LIBXSMM_VLA_DECL(2,       float, doutput, deloutput_ptr, nOFm);
1887   LIBXSMM_VLA_DECL(2, const float,  output,    output_ptr, nOFm);
1888 
1889   if ( param->fuse_type != 0 ) {
1890 #if defined(_OPENMP)
1891     LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(ofm);
1892 # pragma omp parallel for private(img, ofm)
1893 #endif
1894     for (ofm = 0; ofm < nOFm; ++ofm) {
1895       float dbias = 0.0f;
1896       for(img = 0; img < nImg; ++img) {
1897         if ( param->fuse_type == 1 ) {
1898           dbias += LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm);
1899         } else if ( param->fuse_type == 2 ) {
1900           LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm) = ( LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) > 0 ) ? LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm) : 0;
1901         } else if ( param->fuse_type == 3 ) {
1902           LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm) = LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm)*(1.0f-LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm));
1903         } else if ( param->fuse_type == 4 ) {
1904           LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm) = ( LIBXSMM_VLA_ACCESS(2, output, img, ofm, nOFm) > 0 ) ? LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm) : 0;
1905           dbias += LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm);
1906         } else if ( param->fuse_type == 5 ) {
1907           LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm) = LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm)*(1.0f-LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm));
1908           dbias += LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm);
1909         }
1910       }
1911       delbias_ptr[ofm] = dbias;
1912     }
1913   }
1914 
1915 #if defined(_OPENMP)
1916   LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(ofm); LIBXSMM_OMP_VAR(ifm);
1917 # pragma omp parallel for private(img, ofm, ifm)
1918 #endif
1919   for (ifm = 0; ifm < nIFm; ++ifm) {
1920     for(img = 0; img < nImg; ++img) {
1921       LIBXSMM_VLA_ACCESS(2, dinput, img, ifm, nIFm) = (float)0;
1922       for (ofm = 0; ofm < nOFm; ++ofm) {
1923         LIBXSMM_VLA_ACCESS(2, dinput, img, ifm, nIFm) +=
1924           LIBXSMM_VLA_ACCESS(2, filter, ofm, ifm, nIFm) * LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm);
1925       }
1926     }
1927   }
1928 }
1929 
naive_fullyconnected_wu(naive_fullyconnected_t * param,const float * input_ptr,const float * deloutput_ptr,float * delfilter_ptr)1930 LIBXSMM_INLINE void naive_fullyconnected_wu(naive_fullyconnected_t* param, const float* input_ptr, const float* deloutput_ptr, float* delfilter_ptr)
1931 {
1932   const int nImg = param->N;
1933   const int nIFm = param->C;
1934   const int nOFm = param->K;
1935 
1936   int img, ifm, ofm;
1937 
1938   LIBXSMM_VLA_DECL(2, const float,   input,     input_ptr, nIFm);
1939   LIBXSMM_VLA_DECL(2,       float, dfilter, delfilter_ptr, nIFm);
1940   LIBXSMM_VLA_DECL(2, const float, doutput, deloutput_ptr, nOFm);
1941 
1942 #if defined(_OPENMP)
1943   LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(ofm); LIBXSMM_OMP_VAR(ifm);
1944 # pragma omp parallel for private(img, ofm, ifm)
1945 #endif
1946   for (ofm = 0; ofm < nOFm; ++ofm) {
1947     for (ifm = 0; ifm < nIFm; ++ifm) {
1948       LIBXSMM_VLA_ACCESS(2, dfilter, ofm, ifm, nIFm) = (float)0;
1949       for(img = 0; img < nImg; ++img) {
1950         LIBXSMM_VLA_ACCESS(2, dfilter, ofm, ifm, nIFm) +=
1951           LIBXSMM_VLA_ACCESS(2, doutput, img, ofm, nOFm) * LIBXSMM_VLA_ACCESS(2, input, img, ifm, nIFm);
1952       }
1953     }
1954   }
1955 }
1956 
naive_pooling_fp(naive_pooling_t * param,const float * input_ptr,float * output_ptr,int * mask_ptr)1957 LIBXSMM_INLINE void naive_pooling_fp(naive_pooling_t* param, const float* input_ptr, float* output_ptr, int* mask_ptr)
1958 {
1959   const int nImg = param->N;
1960   const int nFm = param->C;
1961   const int ifh = param->H;
1962   const int ifw = param->W;
1963   const int sh = param->stride_h;
1964   const int sw = param->stride_w;
1965   const int r = param->R;
1966   const int s = param->S;
1967   const int pad_h = param->pad_h;
1968   const int pad_w = param->pad_w;
1969   const int ofh = (ifh + 2*pad_h - r)/sh + 1;
1970   const int ofw = (ifw + 2*pad_w - s)/sw + 1;
1971 
1972 
1973   int img, fm;
1974 
1975   LIBXSMM_VLA_DECL(4, const float, input,   input_ptr, nFm, ifh, ifw);
1976   LIBXSMM_VLA_DECL(4,       int,   mask,     mask_ptr, nFm, ofh, ofw);
1977   LIBXSMM_VLA_DECL(4,       float, output, output_ptr, nFm, ofh, ofw);
1978 
1979 #if defined(_OPENMP)
1980   float* tmp_buffer = (float*)malloc(sizeof(float)*ofh*ofw*omp_get_max_threads());
1981   LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(fm);
1982 # pragma omp parallel for private(img, fm)
1983 #else
1984   float* tmp_buffer = (float*)malloc(sizeof(float)*ofh*ofw);
1985 #endif
1986   for (img = 0; img < nImg; img++) {
1987     for (fm = 0; fm < nFm; fm++) {
1988 #if defined(_OPENMP)
1989       float* lcl_buffer_ptr = tmp_buffer + (ofh*ofw*omp_get_thread_num());
1990 #else
1991       float* lcl_buffer_ptr = tmp_buffer;
1992 #endif
1993       LIBXSMM_VLA_DECL(2, float, lcl_buffer, lcl_buffer_ptr, ofw);
1994       int i, ho, wo, hi, wi, kh, kw;
1995 
1996       if (param->type == 0 ) {
1997         for ( i = 0; i < ofh*ofw; i++ ) {
1998           lcl_buffer_ptr[i] = -FLT_MAX;
1999         }
2000       } else if (param->type == 1) {
2001         for ( i = 0; i < ofh*ofw; i++ ) {
2002           lcl_buffer_ptr[i] = 0.0;
2003         }
2004       } else {
2005         /* shouldn't happen */
2006       }
2007 
2008       for( ho = 0; ho < ofh; ho++ ) {
2009         hi = (ho * sh) - pad_h;
2010         for( wo = 0; wo < ofw; wo++ ) {
2011           wi = (wo * sw) - pad_w;
2012           for( kh = 0; kh < r; kh++ ) {
2013             if (hi+kh < 0 || hi+kh >= ifh) continue;
2014             for( kw = 0; kw < s; kw++ ) {
2015               if (wi+kw < 0 || wi+kw >= ifw) continue;
2016               if ( param->type == 0 ) {
2017                 const int index = (hi+kh)*ifw + wi+kw;
2018                 if ( LIBXSMM_VLA_ACCESS(4, input, img, fm, hi+kh, wi+kw, nFm, ifh, ifw) > LIBXSMM_VLA_ACCESS(2, lcl_buffer, ho, wo, ofw) ) {
2019                   LIBXSMM_VLA_ACCESS(2, lcl_buffer, ho, wo, ofw) = LIBXSMM_VLA_ACCESS(4, input, img, fm, hi+kh, wi+kw, nFm, ifh, ifw);
2020                   LIBXSMM_VLA_ACCESS(4, mask, img, fm, ho, wo, nFm, ofh, ofw) = index;
2021                 }
2022               } else if ( param->type == 1 ) {
2023                 LIBXSMM_VLA_ACCESS(2, lcl_buffer, ho, wo, ofw) += LIBXSMM_VLA_ACCESS(4, input, img, fm, hi+kh, wi+kw, nFm, ifh, ifw);
2024               } else {
2025                 /* shouldn't happen */
2026               }
2027             }
2028           }
2029         }
2030       }
2031 
2032       if (param->type == 0 ) {
2033         for( ho = 0; ho < ofh; ho++ ) {
2034           for( wo = 0; wo < ofw; wo++ ) {
2035             LIBXSMM_VLA_ACCESS(4, output, img, fm, ho, wo, nFm, ofh, ofw) = LIBXSMM_VLA_ACCESS(2, lcl_buffer, ho, wo, ofw);
2036           }
2037         }
2038       } else if (param->type == 1) {
2039         for( ho = 0; ho < ofh; ho++ ) {
2040           for( wo = 0; wo < ofw; wo++ ) {
2041             LIBXSMM_VLA_ACCESS(4, output, img, fm, ho, wo, nFm, ofh, ofw) = LIBXSMM_VLA_ACCESS(2, lcl_buffer, ho, wo, ofw) * (1.0f/(((float)r) * ((float)s)));
2042           }
2043         }
2044       } else {
2045         /* shouldn't happen */
2046       }
2047     }
2048   }
2049 
2050   free( tmp_buffer );
2051 }
2052 
naive_pooling_bp(naive_pooling_t * param,float * dinput_ptr,const float * doutput_ptr,const int * mask_ptr)2053 LIBXSMM_INLINE void naive_pooling_bp(naive_pooling_t* param, float* dinput_ptr, const float* doutput_ptr, const int* mask_ptr)
2054 {
2055   const int nImg = param->N;
2056   const int nFm = param->C;
2057   const int ifh = param->H;
2058   const int ifw = param->W;
2059   const int sh = param->stride_h;
2060   const int sw = param->stride_w;
2061   const int r = param->R;
2062   const int s = param->S;
2063   const int pad_h = param->pad_h;
2064   const int pad_w = param->pad_w;
2065   const int ofh = (ifh + 2*pad_h - r)/sh + 1;
2066   const int ofw = (ifw + 2*pad_w - s)/sw + 1;
2067 
2068   int img, fm;
2069 
2070   LIBXSMM_VLA_DECL(4,       float, dinput,   dinput_ptr, nFm, ifh, ifw);
2071   LIBXSMM_VLA_DECL(4, const int  ,  mask,      mask_ptr, nFm, ofh, ofw);
2072   LIBXSMM_VLA_DECL(4, const float, doutput, doutput_ptr, nFm, ofh, ofw);
2073 
2074 #if defined(_OPENMP)
2075   float* tmp_buffer = (float*)malloc(sizeof(float)*ifh*ifw*omp_get_max_threads());
2076   LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(fm);
2077 # pragma omp parallel for private(img, fm)
2078 #else
2079   float* tmp_buffer = (float*)malloc(sizeof(float)*ofh*ofw);
2080 #endif
2081   for (img = 0; img < nImg; img++) {
2082     for (fm = 0; fm < nFm; fm++) {
2083 #if defined(_OPENMP)
2084       float* lcl_buffer_ptr = tmp_buffer + (ifh*ifw*omp_get_thread_num());
2085 #else
2086       float* lcl_buffer_ptr = tmp_buffer;
2087 #endif
2088       LIBXSMM_VLA_DECL(2, float, lcl_buffer, lcl_buffer_ptr, ifw);
2089       int i, ho, wo, hi, wi, kh, kw;
2090 
2091       for ( i = 0; i < ifh*ifw; i++ ) {
2092         lcl_buffer_ptr[i] = 0.0;
2093       }
2094 
2095       if (param->type == 0 ) {
2096         for( ho = 0; ho < ofh; ho++ ) {
2097           for( wo = 0; wo < ofw; wo++ ) {
2098             lcl_buffer_ptr[LIBXSMM_VLA_ACCESS(4, mask, img, fm, ho, wo, nFm, ofh, ofw)] += LIBXSMM_VLA_ACCESS(4, doutput, img, fm, ho, wo, nFm, ofh, ofw);
2099           }
2100         }
2101       } else if ( param->type == 1 ) {
2102         for( ho = 0; ho < ofh; ho++ ) {
2103           hi = (ho * sh) - pad_h;
2104           for( wo = 0; wo < ofw; wo++ ) {
2105             wi = (wo * sw) - pad_w;
2106             for( kh = 0; kh < r; kh++ ) {
2107               if (hi+kh < 0 || hi+kh >= ifh) continue;
2108               for( kw = 0; kw < s; kw++ ) {
2109                 if (wi+kw < 0 || wi+kw >= ifw) continue;
2110                 LIBXSMM_VLA_ACCESS(2, lcl_buffer, hi+kh, wi+kw, ifw) += ( LIBXSMM_VLA_ACCESS(4, doutput, img, fm, ho, wo, nFm, ofh, ofw) * (1.0f/(((float)r) * ((float)s))) );
2111               }
2112             }
2113           }
2114         }
2115       } else {
2116         /* shouldn't happen */
2117       }
2118 
2119       for( hi = 0; hi < ifh; hi++ ) {
2120         for( wi = 0; wi < ifw; wi++ ) {
2121           LIBXSMM_VLA_ACCESS(4, dinput, img, fm, hi, wi, nFm, ifh, ifw) = LIBXSMM_VLA_ACCESS(2, lcl_buffer, hi, wi, ifw);
2122         }
2123       }
2124     }
2125   }
2126 
2127   free( tmp_buffer );
2128 }
2129 
/* Naive fused batch-normalization forward pass (NCHW).
 * When norm_type == 0, first computes per-channel batch statistics over
 * nImg*ifh*ifw elements and stores mean/reciprocal-stddev/variance into
 * expectval_ptr/rcpstddev_ptr/variance_ptr; otherwise the statistics already
 * present in those buffers are used as-is (batch scaling only).
 * The second pass normalizes, scales by gamma, shifts by beta, optionally adds
 * input_add (fuse_type 2/3/5) and applies ReLU (fuse_type 1/3/4/5), while
 * downsampling by (sh, sw) into the ofh x ofw output. */
LIBXSMM_INLINE void naive_fusedbatchnorm_fp(naive_fusedbatchnorm_t* param, const float* input_ptr, float* output_ptr, const float* input_add_ptr,
                                     const float* beta_ptr, const float* gamma_ptr, float* expectval_ptr, float* rcpstddev_ptr, float* variance_ptr)
{
  const int nImg = param->N;
  const int nFm = param->C;
  const int ifh = param->H;
  const int ifw = param->W;
  const int sh = param->stride_h;
  const int sw = param->stride_w;
  const int ofh = ifh/sh;
  const int ofw = ifw/sw;
  const float nhw = (float)(nImg * ifh * ifw);   /* elements reduced per channel */
  const float recp_nhw = 1.0f/nhw;
  const float sqrt_eps = 1e-7f;                  /* numerical guard inside sqrt */

  int img, fm, hi, wi, ho, wo;

  LIBXSMM_VLA_DECL(4, const float, input,     input_ptr,     nFm, ifh, ifw);
  LIBXSMM_VLA_DECL(4, const float, input_add, input_add_ptr, nFm, ifh, ifw);
  LIBXSMM_VLA_DECL(4,       float, output,    output_ptr,    nFm, ofh, ofw);

  if ( param->norm_type == 0 ) {
#if defined(_OPENMP)
    LIBXSMM_OMP_VAR(wi); LIBXSMM_OMP_VAR(hi);
#   pragma omp parallel for private(img, fm, hi, wi)
#endif
    /* statistics pass: one channel per (parallel) iteration */
    for (fm = 0; fm < nFm; fm++) {
      float ch_sum = 0.0f;     /* sum of x */
      float ch_sumsq = 0.0f;   /* sum of x^2 */
      float tbmean = 0.0f;
      float tbmeansq = 0.0f;
      float tsqbmean = 0.0f;
      float tbrstd = 0.0f;
      float tvariance = 0.0f;

      for ( img = 0; img < nImg; img++ ) {
        for ( hi = 0; hi < ifh; hi++ ) {
          for ( wi = 0; wi < ifw; wi++ ) {
            const float input_val = LIBXSMM_VLA_ACCESS(4, input, img, fm, hi, wi, nFm, ifh, ifw);
            ch_sum   += input_val;
            ch_sumsq += (input_val * input_val);
          }
        }
      }

      /* Var[x] = E[x^2] - E[x]^2 */
      tbmean = recp_nhw * ch_sum;
      tbmeansq  = tbmean * tbmean;
      tsqbmean = recp_nhw * ch_sumsq;
      tvariance = tsqbmean - tbmeansq;
      tbrstd = (float)(1.0/sqrt(tvariance + sqrt_eps));
      expectval_ptr[fm] = tbmean;
      rcpstddev_ptr[fm] = tbrstd;
      variance_ptr[fm] = tvariance;
    }
  }

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(ho); LIBXSMM_OMP_VAR(wo);
# pragma omp parallel for private(img, fm, hi, wi, ho, wo)
#endif
  /* apply pass: normalize + affine transform + optional fused ops, with striding */
  for ( img = 0; img < nImg; img++ ) {
    for ( fm = 0; fm < nFm; fm++ ) {
      for ( hi = 0, ho = 0; hi < ifh; hi += sh, ho++ ) {
        for ( wi = 0, wo = 0; wi < ifw; wi += sw, wo++ ) {
          const float  input_val     =  LIBXSMM_VLA_ACCESS(4, input,     img, fm, hi, wi, nFm, ifh, ifw);
          const float  input_add_val =  LIBXSMM_VLA_ACCESS(4, input_add, img, fm, hi, wi, nFm, ifh, ifw);
                float* output_ptr2   = &LIBXSMM_VLA_ACCESS(4, output,    img, fm, ho, wo, nFm, ofh, ofw);

          /* BN + scale (gamma, beta) */
          float o = gamma_ptr[fm]*(input_val - expectval_ptr[fm])*rcpstddev_ptr[fm] + beta_ptr[fm];
          /* Eltwise */
          if ( (param->fuse_type == 2) || (param->fuse_type == 3) || (param->fuse_type == 5) ) {
            o += input_add_val;
          }
          /* ReLU */
          if ( (param->fuse_type == 1) || (param->fuse_type == 3) || (param->fuse_type == 4) || (param->fuse_type == 5) ) {
            o = ( o < 0.0f ) ? 0.0f : o;
          }
          *output_ptr2 = o;
        }
      }
    }
  }
}
2214 
naive_fusedbatchnorm_bp(naive_fusedbatchnorm_t * param,const float * input_ptr,float * dinput_ptr,const float * output_ptr,float * doutput_ptr,float * dinput_add_ptr,const float * beta_ptr,float * del_beta_ptr,const float * gamma_ptr,float * del_gamma_ptr,const float * expectval_ptr,const float * rcpstddev_ptr)2215 LIBXSMM_INLINE void naive_fusedbatchnorm_bp(naive_fusedbatchnorm_t* param, const float* input_ptr, float* dinput_ptr, const float* output_ptr, float* doutput_ptr, float* dinput_add_ptr,
2216                                      const float* beta_ptr, float* del_beta_ptr, const float* gamma_ptr, float* del_gamma_ptr,
2217                                      const float* expectval_ptr, const float* rcpstddev_ptr)
2218 {
2219   const int nImg = param->N;
2220   const int nFm = param->C;
2221   const int ifh = param->H;
2222   const int ifw = param->W;
2223   const int sh = param->stride_h;
2224   const int sw = param->stride_w;
2225   const int ofh = ifh/sh;
2226   const int ofw = ifw/sw;
2227   const float nhw = (float)(nImg * ifh * ifw);
2228   const float recp_nhw = 1.0f/nhw;
2229 
2230   int img, fm, hi, wi, ho, wo;
2231 
2232   LIBXSMM_VLA_DECL(4, const float, input,      input_ptr,      nFm, ifh, ifw);
2233   LIBXSMM_VLA_DECL(4,       float, dinput,     dinput_ptr,     nFm, ifh, ifw);
2234   LIBXSMM_VLA_DECL(4,       float, dinput_add, dinput_add_ptr, nFm, ifh, ifw);
2235   LIBXSMM_VLA_DECL(4, const float, output,     output_ptr,     nFm, ofh, ofw);
2236   LIBXSMM_VLA_DECL(4,       float, doutput,    doutput_ptr,    nFm, ofh, ofw);
2237   LIBXSMM_UNUSED(beta_ptr);
2238 
2239   if ( param->norm_type == 0 ) {
2240 #if defined(_OPENMP)
2241     LIBXSMM_OMP_VAR(hi); LIBXSMM_OMP_VAR(wi); LIBXSMM_OMP_VAR(ho); LIBXSMM_OMP_VAR(wo);
2242 #   pragma omp parallel for private(img, fm, hi, wi, ho, wo)
2243 #endif
2244     for ( fm = 0; fm < nFm; fm++ ) {
2245       del_gamma_ptr[fm] = 0.0f;
2246       del_beta_ptr[fm] = 0.0f;
2247 
2248       for ( img = 0; img < nImg; img++ ) {
2249         for ( hi = 0, ho = 0; hi < ifh; hi += sh, ho++ ) {
2250           for ( wi = 0, wo = 0; wi < ifw; wi += sw, wo++ ) {
2251                   float* del_input_add_ptr = &LIBXSMM_VLA_ACCESS(4, dinput_add, img, fm, hi, wi, fm, ifh, ifw);
2252             const float  output_val        =  LIBXSMM_VLA_ACCESS(4,     output, img, fm, ho, wo, fm, ofh, ofw);
2253             const float  input_val         =  LIBXSMM_VLA_ACCESS(4,      input, img, fm, hi, wi, fm, ifh, ifw);
2254                   float* del_output_ptr    = &LIBXSMM_VLA_ACCESS(4,    doutput, img, fm, ho, wo, fm, ofh, ofw);
2255 
2256             /* ReLU */
2257             if ( (param->fuse_type == 1) || (param->fuse_type == 3) || (param->fuse_type == 4) || (param->fuse_type == 5) ) {
2258               *del_output_ptr    = (output_val == 0) ? 0 : *del_output_ptr;
2259             }
2260             /* elementwise */
2261             if ( (param->fuse_type == 2) || (param->fuse_type == 3) || (param->fuse_type == 5) ) {
2262               *del_input_add_ptr = *del_output_ptr;
2263             }
2264             del_gamma_ptr[fm] += (input_val - expectval_ptr[fm]) * (*del_output_ptr) * rcpstddev_ptr[fm];
2265             del_beta_ptr[fm]  += *del_output_ptr;
2266           }
2267         }
2268       }
2269     }
2270   }
2271 
2272 #if defined(_OPENMP)
2273 # pragma omp parallel for private(img, fm, hi, wi, ho, wo)
2274 #endif
2275   for ( img = 0; img < nImg; img++ ) {
2276     for ( fm = 0; fm < nFm; fm++ ) {
2277       for ( hi = 0, ho = 0; hi < ifh; hi += sh, ho++ ) {
2278         for ( wi = 0, wo = 0; wi < ifw; wi += sw, wo++) {
2279                 float* del_input_ptr  = &LIBXSMM_VLA_ACCESS(4,     dinput, img, fm, hi, wi, fm, ifh, ifw);
2280           const float  input_val      =  LIBXSMM_VLA_ACCESS(4,      input, img, fm, hi, wi, fm, ifh, ifw);
2281           const float  del_output_val =  LIBXSMM_VLA_ACCESS(4,    doutput, img, fm, ho, wo, fm, ofh, ofw);
2282 
2283           *del_input_ptr = gamma_ptr[fm] * rcpstddev_ptr[fm] * recp_nhw * (nhw * del_output_val -
2284                     (del_beta_ptr[fm] + (input_val - expectval_ptr[fm]) * del_gamma_ptr[fm] * rcpstddev_ptr[fm]));
2285         }
2286       }
2287     }
2288   }
2289 }
2290 
/* Naive fused group-normalization forward pass (NCHW viewed as N x G x C/G x H x W).
 * First pass computes per-(image, group) statistics over nFMG*ifh*ifw elements
 * and stores mean/reciprocal-stddev/variance at index img*nG+g.
 * Second pass normalizes, scales by gamma and shifts by beta (both indexed per
 * channel g*nFMG+fmg), optionally adds input_add (fuse_type 2/3/5) and applies
 * ReLU (fuse_type 1/3/4/5), downsampling by (sh, sw) into ofh x ofw. */
LIBXSMM_INLINE void naive_fusedgroupnorm_fp(naive_fusedgroupnorm_t* param, const float* input_ptr, float* output_ptr, const float* input_add_ptr,
                                     const float* beta_ptr, const float* gamma_ptr, float* expectval_ptr, float* rcpstddev_ptr, float* variance_ptr)
{
  const int nImg = param->N;
  const int nFm = param->C;
  const int ifh = param->H;
  const int ifw = param->W;
  const int sh = param->stride_h;
  const int sw = param->stride_w;
  const int ofh = ifh/sh;
  const int ofw = ifw/sw;
  const int nG = param->G;
  const int nFMG = nFm/nG;                       /* channels per group */
  const float ghw = (float)(nFMG * ifh * ifw);   /* elements reduced per group */
  const float recp_ghw = 1.0f/ghw;
  const float sqrt_eps = 1e-7f;                  /* numerical guard inside sqrt */

  int img, g, fmg, hi, wi, ho, wo;

  LIBXSMM_VLA_DECL(5, const float, input,     input_ptr,     nG,  nFMG, ifh, ifw);
  LIBXSMM_VLA_DECL(5, const float, input_add, input_add_ptr, nG,  nFMG, ifh, ifw);
  LIBXSMM_VLA_DECL(5,       float, output,    output_ptr,    nG,  nFMG, ofh, ofw);

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(img); LIBXSMM_OMP_VAR(g); LIBXSMM_OMP_VAR(fmg); LIBXSMM_OMP_VAR(hi); LIBXSMM_OMP_VAR(wi);
# pragma omp parallel for private(img, g, fmg, hi, wi)
#endif
  /* statistics pass: one (image, group) pair per iteration */
  for ( img = 0; img < nImg; img++ ) {
    for (g = 0; g < nG; g++) {
      float ch_sum = 0.0f;     /* sum of x */
      float ch_sumsq = 0.0f;   /* sum of x^2 */
      float tbmean = 0.0f;
      float tbmeansq = 0.0f;
      float tsqbmean = 0.0f;
      float tbrstd = 0.0f;
      float tvariance = 0.0f;

      for ( fmg = 0; fmg < nFMG; fmg++) {
        for ( hi = 0; hi < ifh; hi++ ) {
          for ( wi = 0; wi < ifw; wi++ ) {
            const float input_val = LIBXSMM_VLA_ACCESS(5, input, img, g, fmg, hi, wi, nG, nFMG, ifh, ifw);
            ch_sum   += input_val;
            ch_sumsq += (input_val * input_val);
          }
        }
      }

      /* Var[x] = E[x^2] - E[x]^2; statistics are per (image, group) */
      tbmean = recp_ghw * ch_sum;
      tbmeansq  = tbmean * tbmean;
      tsqbmean = recp_ghw * ch_sumsq;
      tvariance = tsqbmean - tbmeansq;
      tbrstd = (float)(1.0/sqrt(tvariance + sqrt_eps));
      expectval_ptr[img*nG+g] = tbmean;
      rcpstddev_ptr[img*nG+g] = tbrstd;
      variance_ptr[img*nG+g] = tvariance;
    }
  }

#if defined(_OPENMP)
  LIBXSMM_OMP_VAR(ho); LIBXSMM_OMP_VAR(wo);
# pragma omp parallel for private(img, g, fmg, hi, wi, ho, wo)
#endif
  /* apply pass: normalize + per-channel affine + optional fused ops, with striding */
  for ( img = 0; img < nImg; img++ ) {
    for ( g = 0; g < nG; g++ ) {
      for ( fmg = 0; fmg < nFMG; fmg++ ) {
        for ( hi = 0, ho = 0; hi < ifh; hi += sh, ho++ ) {
          for ( wi = 0, wo = 0; wi < ifw; wi += sw, wo++ ) {
            const float  input_val      =  LIBXSMM_VLA_ACCESS(5, input,     img, g,  fmg, hi, wi, nG,  nFMG, ifh, ifw);
            const float  input_add_val  =  LIBXSMM_VLA_ACCESS(5, input_add, img, g,  fmg, hi, wi, nG,  nFMG, ifh, ifw);
            float* output_ptr2          = &LIBXSMM_VLA_ACCESS(5, output,    img, g,  fmg, ho, wo, nG,  nFMG, ofh, ofw);

            /* BN + scale (gamma, beta) */
            float o = gamma_ptr[g*nFMG+fmg]*(input_val - expectval_ptr[img*nG+g])*rcpstddev_ptr[img*nG+g] + beta_ptr[g*nFMG+fmg];
            /* Eltwise */
            if ( (param->fuse_type == 2) || (param->fuse_type == 3) || (param->fuse_type == 5) ) {
              o += input_add_val;
            }
            /* ReLU */
            if ( (param->fuse_type == 1) || (param->fuse_type == 3) || (param->fuse_type == 4) || (param->fuse_type == 5) ) {
              o = ( o < 0.0f ) ? 0.0f : o;
            }
            *output_ptr2 = o;
          }
        }
      }
    }
  }
}
2379 
naive_fusedgroupnorm_bp(naive_fusedgroupnorm_t * param,const float * input_ptr,float * dinput_ptr,const float * output_ptr,float * doutput_ptr,float * dinput_add_ptr,const float * beta_ptr,float * del_beta_ptr,const float * gamma_ptr,float * del_gamma_ptr,const float * expectval_ptr,const float * rcpstddev_ptr,const float * variance_ptr)2380 LIBXSMM_INLINE void naive_fusedgroupnorm_bp(naive_fusedgroupnorm_t* param, const float* input_ptr, float* dinput_ptr, const float* output_ptr, float* doutput_ptr, float* dinput_add_ptr,
2381                                      const float* beta_ptr, float* del_beta_ptr, const float* gamma_ptr, float* del_gamma_ptr,
2382                                      const float* expectval_ptr, const float* rcpstddev_ptr, const float* variance_ptr)
2383 {
2384   const int nImg = param->N;
2385   const int nFm = param->C;
2386   const int ifh = param->H;
2387   const int ifw = param->W;
2388   const int sh = param->stride_h;
2389   const int sw = param->stride_w;
2390   const int ofh = ifh/sh;
2391   const int ofw = ifw/sw;
2392   const int nG = param->G;
2393   const int nFMG = nFm/nG;
2394   const float ghw = (float)(nFMG * ifh * ifw);
2395   const float recp_ghw = 1.0f/ghw;
2396   const float eps = 1e-7f;
2397 
2398   int img, g, fmg, fm, hi, wi, ho, wo;
2399 
2400   LIBXSMM_VLA_DECL(5, const float, input,      input_ptr,      nG,  nFMG, ifh, ifw);
2401   LIBXSMM_VLA_DECL(5,       float, dinput,     dinput_ptr,     nG,  nFMG, ifh, ifw);
2402   /*LIBXSMM_VLA_DECL(5, const float, output,     output_ptr,     nG,  nFMG, ofh, ofw);*/
2403   LIBXSMM_VLA_DECL(5,       float, doutput,    doutput_ptr,    nG,  nFMG, ofh, ofw);
2404 
2405   LIBXSMM_VLA_DECL(4, const float, input_gb,      input_ptr,      nFm,  ifh, ifw);
2406   LIBXSMM_VLA_DECL(4, const float, output_gb,     output_ptr,     nFm,  ofh, ofw);
2407   LIBXSMM_VLA_DECL(4,       float, doutput_gb,    doutput_ptr,    nFm,  ofh, ofw);
2408   LIBXSMM_VLA_DECL(4,       float, dinput_add,    dinput_add_ptr, nFm, ifh, ifw);
2409 
2410   LIBXSMM_UNUSED(beta_ptr);
2411 
2412 #if defined(_OPENMP)
2413   LIBXSMM_OMP_VAR(hi); LIBXSMM_OMP_VAR(wi); LIBXSMM_OMP_VAR(ho); LIBXSMM_OMP_VAR(wo); LIBXSMM_OMP_VAR(g);
2414 # pragma omp parallel for private(img, fm, hi, wi, ho, wo, g)
2415 #endif
2416   for ( fm = 0; fm < nFm; fm++ ) {
2417     del_gamma_ptr[fm] = 0.0f;
2418     del_beta_ptr[fm] = 0.0f;
2419 
2420     for ( img = 0; img < nImg; img++ ) {
2421       for ( hi = 0, ho = 0; hi < ifh; hi += sh, ho++ ) {
2422         for ( wi = 0, wo = 0; wi < ifw; wi += sw, wo++ ) {
2423                 float* del_input_add_ptr = &LIBXSMM_VLA_ACCESS(4,    dinput_add, img, fm, hi, wi, nFm, ifh, ifw);
2424           const float  output_val        =  LIBXSMM_VLA_ACCESS(4,     output_gb, img, fm, ho, wo, nFm, ofh, ofw);
2425           const float  input_val         =  LIBXSMM_VLA_ACCESS(4,      input_gb, img, fm, hi, wi, nFm, ifh, ifw);
2426                 float* del_output_ptr    = &LIBXSMM_VLA_ACCESS(4,    doutput_gb, img, fm, ho, wo, nFm, ofh, ofw);
2427 
2428           /* ReLU */
2429           if ( (param->fuse_type == 1) || (param->fuse_type == 3) || (param->fuse_type == 4) || (param->fuse_type == 5) ) {
2430             *del_output_ptr    = (output_val == 0) ? 0 : *del_output_ptr;
2431           }
2432           /* elementwise */
2433           if ( (param->fuse_type == 2) || (param->fuse_type == 3) || (param->fuse_type == 5) ) {
2434             *del_input_add_ptr = *del_output_ptr;
2435           }
2436           g = fm/nFMG;
2437           del_gamma_ptr[fm] += (input_val - expectval_ptr[img*nG+g]) * (*del_output_ptr) * rcpstddev_ptr[img*nG+g];
2438           del_beta_ptr[fm]  += *del_output_ptr;
2439         }
2440       }
2441     }
2442   }
2443 
2444 #if defined(_OPENMP)
2445   LIBXSMM_OMP_VAR(fmg);
2446 # pragma omp parallel for private(img, g, fmg, hi, wi, ho, wo)
2447 #endif
2448   for ( img = 0; img < nImg; img++ ) {
2449     for ( g = 0; g < nG; g++ ) {
2450       float d1_val = 0.0;
2451       float d2_val = 0.0;
2452 
2453       for ( fmg = 0; fmg < nFMG; fmg++ ) {
2454         for ( hi = 0, ho = 0; hi < ifh; hi += sh, ho++ ) {
2455           for ( wi = 0, wo = 0; wi < ifw; wi += sw, wo++) {
2456             const float  input_val      =  LIBXSMM_VLA_ACCESS(5,      input, img, g, fmg, hi, wi, nG, nFMG, ifh, ifw);
2457             const float  del_output_val =  LIBXSMM_VLA_ACCESS(5,    doutput, img, g, fmg, ho, wo, nG, nFMG, ofh, ofw);
2458 
2459             d1_val += del_output_val * (input_val - expectval_ptr[img*nG+g]) * gamma_ptr[g*nFMG+fmg];
2460             d2_val += del_output_val * gamma_ptr[g*nFMG+fmg];
2461           }
2462         }
2463       }
2464 
2465       for ( fmg = 0; fmg < nFMG; fmg++ ) {
2466         for ( hi = 0, ho = 0; hi < ifh; hi += sh, ho++ ) {
2467           for ( wi = 0, wo = 0; wi < ifw; wi += sw, wo++) {
2468             const float  input_val      =  LIBXSMM_VLA_ACCESS(5,      input, img, g, fmg, hi, wi, nG, nFMG, ifh, ifw);
2469             const float  del_output_val =  LIBXSMM_VLA_ACCESS(5,    doutput, img, g, fmg, ho, wo, nG, nFMG, ofh, ofw);
2470                   float* del_input_ptr  = &LIBXSMM_VLA_ACCESS(5,     dinput, img, g, fmg, hi, wi, nG, nFMG, ifh, ifw);
2471 
2472             float t0_val = rcpstddev_ptr[img*nG+g] * recp_ghw;
2473             *del_input_ptr = t0_val * ((gamma_ptr[g*nFMG+fmg] * ghw * del_output_val) - d2_val - ((input_val - expectval_ptr[img*nG+g]) * d1_val * (1.0f/(variance_ptr[img*nG+g]+eps))));
2474           }
2475         }
2476       }
2477     }
2478   }
2479 }
2480 
lstm_fwd_copy_bias(int N,int K,float * bigold,float * bcgold,float * bfgold,float * bogold,float forget_bias,float * icfogoldt,int j)2481 LIBXSMM_INLINE void lstm_fwd_copy_bias(int N, int K, float *bigold, float *bcgold, float *bfgold, float *bogold, float forget_bias, float *icfogoldt, int j)
2482 {
2483   LIBXSMM_VLA_DECL(3, float, icfogold, icfogoldt, N, 4 * K);
2484   int i, l;
2485 #if defined(_OPENMP)
2486   LIBXSMM_OMP_VAR(i); LIBXSMM_OMP_VAR(l);
2487 # pragma omp parallel for private(i, l) LIBXSMM_OPENMP_COLLAPSE(2)
2488 #endif
2489   for (i = 0; i < N; i++) {
2490     for (l = 0; l < K; l++) {
2491       LIBXSMM_VLA_ACCESS(3, icfogold, j, i, l,     N, 4 * K) = bigold[l];
2492       LIBXSMM_VLA_ACCESS(3, icfogold, j, i, l+K,   N, 4 * K) = bcgold[l];
2493       LIBXSMM_VLA_ACCESS(3, icfogold, j, i, l+2*K, N, 4 * K) = bfgold[l] + forget_bias;
2494       LIBXSMM_VLA_ACCESS(3, icfogold, j, i, l+3*K, N, 4 * K) = bogold[l];
2495     }
2496   }
2497 }
2498 
/* Fused LSTM forward element-wise stage:
 *   i = sigmoid(i), c = tanh(c), f = sigmoid(f), o = sigmoid(o)
 *   cs = f.csp + i.c,  co = tanh(cs),  h = o.co
 * The gate buffers i/c/f/o are interleaved columns of one N x 4K matrix
 * (row stride 4*K); csp (previous cell state), cs, co and h are N x K.
 * Gates are activated in place so the backward pass can reuse them.
 * With AVX512, K is processed in chunks of 16 plus a scalar tail loop;
 * otherwise a portable scalar loop is used. */
LIBXSMM_INLINE void lstm_fwd_eltwise_merged(int N, int K, float *i, float *c, float *f, float *o, float *csp, float *cs, float *co, float *h)
{
  int j;
#if defined(__AVX512F__)
  int l;
  int rem = (K/16)*16; /* extent of the vectorized part: largest multiple of 16 <= K */
  __m512 minus1 = _mm512_set1_ps (-1.0f);
  __m512 plus1  = _mm512_set1_ps (1.0f);
#if defined(_OPENMP)
# pragma omp parallel for private(j, l) LIBXSMM_OPENMP_COLLAPSE(2)
#endif
  for (j = 0; j < N; j++) {
    for (l = 0; l < rem; l+=16) {
      __m512 iv   = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(i[j*4*K + l]));
      __m512 cv   = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(c[j*4*K + l]));
      __m512 fv   = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(f[j*4*K + l]));
      __m512 ov   = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(o[j*4*K + l]));
      __m512 cspv = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(csp[j*K + l]));
      __m512 csv, cov, hv;
      /* i = sigmoid(i) = 1 / (1 + exp(-i)) */
      iv = _mm512_mul_ps (iv, minus1);
      iv = LIBXSMM_INTRINSICS_MM512_EXP_PS (iv);
      iv = _mm512_add_ps (iv, plus1);
      iv = _mm512_div_ps (plus1, iv);
      /* c = tanh(c) */
      cv = LIBXSMM_INTRINSICS_MM512_TANH_PS (cv);
      /* f = sigmoid(f) */
      fv = _mm512_mul_ps (fv, minus1);
      fv = LIBXSMM_INTRINSICS_MM512_EXP_PS (fv);
      fv = _mm512_add_ps (fv, plus1);
      fv = _mm512_div_ps (plus1, fv);
      /* o = sigmoid(o) */
      ov = _mm512_mul_ps (ov, minus1);
      ov = LIBXSMM_INTRINSICS_MM512_EXP_PS (ov);
      ov = _mm512_add_ps (ov, plus1);
      ov = _mm512_div_ps (plus1, ov);
      /* cs = f.csp + i.c */
      csv = _mm512_mul_ps (fv, cspv);
      csv = _mm512_fmadd_ps (iv, cv, csv);
      /* co = tanh(cs) */
      cov = LIBXSMM_INTRINSICS_MM512_TANH_PS (csv);
      /* h = o.co */
      hv = _mm512_mul_ps (ov, cov);
      /* activated gates are stored back in place; states go to their own buffers */
      _mm512_storeu_ps (&(i[j*4*K + l]), iv);
      _mm512_storeu_ps (&(c[j*4*K + l]), cv);
      _mm512_storeu_ps (&(f[j*4*K + l]), fv);
      _mm512_storeu_ps (&(o[j*4*K + l]), ov);
      _mm512_storeu_ps (&(cs[j*K + l]),  csv);
      _mm512_storeu_ps (&(co[j*K + l]),  cov);
      _mm512_storeu_ps (&(h[j*K + l]),   hv);
    }
  }
  /* scalar tail for the last K - rem columns (K not divisible by 16) */
#if defined(_OPENMP)
# pragma omp parallel for private(j, l) LIBXSMM_OPENMP_COLLAPSE(2)
#endif
  for (j = 0; j < N; j++) {
    for (l = rem; l < K; l++) {
      float exp_value;
      /* i = sigmoid(i) */
      exp_value = (float)exp((double) -i[j*4*K + l]);
      i[j*4*K + l] = 1.0f / (1.0f + exp_value);
      /* c = tanh(c) */
      c[j*4*K + l] = (float)tanh((double)c[j*4*K + l]);
      /* f = sigmoid(f) */
      exp_value = (float)exp((double) -f[j*4*K + l]);
      f[j*4*K + l] = 1.0f / (1.0f + exp_value);
      /* o = sigmoid(o) */
      exp_value = (float)exp((double) -o[j*4*K + l]);
      o[j*4*K + l] = 1.0f / (1.0f + exp_value);
      /* cs = f.csp + i.c */
      cs[j*K + l] = f[j*4*K + l]*csp[j*K + l] + i[j*4*K + l]*c[j*4*K + l];
      /* co = tanh(cs) */
      co[j*K + l] = (float)tanh((double)cs[j*K + l]);
      /* h = o.co */
      h[j*K + l] = o[j*4*K + l] * co[j*K + l];
    }
  }
#else /* portable scalar path: flat index j over the N x K state matrices */
#if defined(_OPENMP)
# pragma omp parallel for private(j)
#endif
  for (j = 0; j < N*K; j++) {
    const int row = j / K;
    const int col = j % K;  /* gates use row*4*K + col due to interleaving */
    float exp_value;
    /* i = sigmoid(i) */
    exp_value = (float)exp((double) -i[row*4*K + col]);
    i[row*4*K + col] = 1.0f / (1.0f + exp_value);
    /* c = tanh(c) */
    c[row*4*K + col] = (float)tanh((double)c[row*4*K + col]);
    /* f = sigmoid(f) */
    exp_value = (float)exp((double) -f[row*4*K + col]);
    f[row*4*K + col] = 1.0f / (1.0f + exp_value);
    /* o = sigmoid(o) */
    exp_value = (float)exp((double) -o[row*4*K + col]);
    o[row*4*K + col] = 1.0f / (1.0f + exp_value);
    /* cs = f.csp + i.c */
    cs[j] = f[row*4*K + col]*csp[j] + i[row*4*K + col]*c[row*4*K + col];
    /* co = tanh(cs) */
    co[j] = (float)tanh((double)cs[j]);
    /* h = o.co */
    h[j] = o[row*4*K + col] * co[j];
  }
#endif
}
2604 
/* Fused LSTM backward element-wise stage. Inputs are the forward activations
 * (i, c, f, o interleaved in an N x 4K buffer; csp = previous cell state and
 * co = tanh of the cell state, both N x K), the hidden-state gradient dh, an
 * optional recurrent gradient dout (NULL at the last time step), and the
 * cell-state gradient dcs. Outputs: gate gradients di/dc/df/dp (N x 4K,
 * interleaved like the activations) and dcsp, the cell-state gradient carried
 * to the previous time step. Layout and vectorization mirror the forward
 * stage (AVX512 chunks of 16 + scalar tail, or a portable scalar loop). */
LIBXSMM_INLINE void lstm_bwd_upd_eltwise_merged(int N, int K, float *i, float *c, float *f, float *o, float *csp, float *co,
                                                float *dh, float *dout, float *di, float *dc, float *df, float *dp, float *dcsp, float *dcs)
{
  int j;
#if defined(__AVX512F__)
  int l;
  int rem = (K/16)*16; /* extent of the vectorized part: largest multiple of 16 <= K */
  __m512 plus1  = _mm512_set1_ps (1.0f);
#if defined(_OPENMP)
# pragma omp parallel for private(j, l) LIBXSMM_OPENMP_COLLAPSE(2)
#endif
  for (j = 0; j < N; j++) {
    for (l = 0; l < rem; l+=16) {
      __m512 iv       = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(i[j*4*K + l]));
      __m512 cv       = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(c[j*4*K + l]));
      __m512 fv       = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(f[j*4*K + l]));
      __m512 ov       = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(o[j*4*K + l]));
      __m512 cspv     = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(csp[j*K + l]));
      __m512 cov      = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(co[j*K + l]));
      __m512 dcsv     = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(dcs[j*K + l]));
      __m512 dhv, doutv, div, dcv, dfv, dov, dcspv, deltav, tv;
      /* compute delta (total hidden-state gradient: dh plus recurrent dout if present) */
      if (NULL == dout) {
        deltav = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(dh[j*K + l]));
      } else {
        dhv    = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(dh[j*K + l]));
        doutv  = LIBXSMM_INTRINSICS_MM512_LOAD_PS (&(dout[j*K + l]));
        deltav = _mm512_add_ps (dhv, doutv);
      }
      /* compute dcsp */
      /* dcsp = delta.o.(1 - (co.co)) + dcs   -- tanh' through co */
      tv    = _mm512_mul_ps (cov, cov);
      tv    = _mm512_sub_ps (plus1, tv);
      dcspv = _mm512_mul_ps (deltav, ov);
      dcspv = _mm512_fmadd_ps (dcspv, tv, dcsv);
      /* compute di */
      /* di = dcsp.c.i.(1 - i)   -- sigmoid' through i */
      tv  = _mm512_sub_ps (plus1, iv);
      tv  = _mm512_mul_ps (iv, tv);
      div = _mm512_mul_ps (dcspv, cv);
      div = _mm512_mul_ps (div, tv);
      /* compute dc */
      /* dc = dcsp.i.(1 - (c.c))   -- tanh' through c */
      tv  = _mm512_mul_ps (cv, cv);
      tv  = _mm512_sub_ps (plus1, tv);
      dcv = _mm512_mul_ps (dcspv, iv);
      dcv = _mm512_mul_ps (dcv, tv);
      /* compute df */
      /* df = dcsp.csp.f.(1 - f)   -- sigmoid' through f */
      tv  = _mm512_sub_ps (plus1, fv);
      tv  = _mm512_mul_ps (fv, tv);
      dfv = _mm512_mul_ps (dcspv, cspv);
      dfv = _mm512_mul_ps (dfv, tv);
      /* compute do */
      /* do = delta.co.o.(1 - o)   -- sigmoid' through o */
      tv  = _mm512_sub_ps (plus1, ov);
      tv  = _mm512_mul_ps (ov, tv);
      dov = _mm512_mul_ps (deltav, cov);
      dov = _mm512_mul_ps (dov, tv);
      /* update dcsp */
      /* dcsp = dcsp.f   -- gradient flowing to the previous cell state */
      dcspv = _mm512_mul_ps (dcspv, fv);
      _mm512_storeu_ps (&(di[j*4*K + l]), div);
      _mm512_storeu_ps (&(dc[j*4*K + l]), dcv);
      _mm512_storeu_ps (&(df[j*4*K + l]), dfv);
      _mm512_storeu_ps (&(dp[j*4*K + l]), dov);
      _mm512_storeu_ps (&(dcsp[j*K + l]), dcspv);
    }
  }
  /* scalar tail for the last K - rem columns (K not divisible by 16) */
#if defined(_OPENMP)
# pragma omp parallel for private(j, l) LIBXSMM_OPENMP_COLLAPSE(2)
#endif
  for (j = 0; j < N; j++) {
    for (l = rem; l < K; l++) {
      float delta;
      /* compute delta */
      if (NULL == dout) {
        delta = dh[j*K + l];
      } else {
        delta = dh[j*K + l] + dout[j*K + l];
      }
      /* compute dcsp */
      dcsp[j*K + l] = delta * o[j*4*K + l] * (1.0f - (co[j*K + l]*co[j*K + l])) + dcs[j*K + l];
      /* compute di */
      di[j*4*K + l] = dcsp[j*K + l] * c[j*4*K + l] * i[j*4*K + l] * (1.0f - i[j*4*K + l]);
      /* compute dc */
      dc[j*4*K + l] = dcsp[j*K + l] * i[j*4*K + l] * (1.0f - (c[j*4*K + l]*c[j*4*K + l]));
      /* compute df */
      df[j*4*K + l] = dcsp[j*K + l] * csp[j*K + l] * f[j*4*K + l] * (1.0f - f[j*4*K + l]);
      /* compute do */
      dp[j*4*K + l] = delta * co[j*K + l] * o[j*4*K + l] * (1.0f - o[j*4*K + l]);
      /* update dcsp */
      dcsp[j*K + l] = dcsp[j*K + l] * f[j*4*K + l];
    }
  }
#else /* portable scalar path: flat index j over the N x K state matrices */
#if defined(_OPENMP)
# pragma omp parallel for private(j)
#endif
  for (j = 0; j < N*K; j++) {
    const int row = j / K;
    const int col = j % K;  /* gates use row*4*K + col due to interleaving */
    float delta;
    /* compute delta */
    if (NULL == dout) {
      delta = dh[j];
    } else {
      delta = dh[j] + dout[j];
    }
    /* compute dcsp */
    dcsp[j] = delta * o[row*4*K + col] * (1.0f - (co[j]*co[j])) + dcs[j];
    /* compute di */
    di[row*4*K + col] = dcsp[j] * c[row*4*K + col] * i[row*4*K + col] * (1.0f - i[row*4*K + col]);
    /* compute dc */
    dc[row*4*K + col] = dcsp[j] * i[row*4*K + col] * (1.0f - (c[row*4*K + col]*c[row*4*K + col]));
    /* compute df */
    df[row*4*K + col] = dcsp[j] * csp[j] * f[row*4*K + col] * (1.0f - f[row*4*K + col]);
    /* compute do */
    dp[row*4*K + col] = delta * co[j] * o[row*4*K + col] * (1.0f - o[row*4*K + col]);
    /* update dcsp */
    dcsp[j] = dcsp[j] * f[row*4*K + col];
  }
#endif
}
2729 
/* Reference (BLAS-based) LSTM forward pass over t time steps.
 * First packs the four per-gate weight matrices into a single C x 4K matrix
 * wgold (and the recurrent weights into rgold when TWO_GEMMS is defined;
 * otherwise the recurrent weights are appended below wgold and x/h are
 * concatenated into the scratch buffer so a single GEMM per step suffices).
 * Per step: icfo = bias (forget_bias added to the f gate), icfo += W*x_t +
 * R*h_{t-1}, then the fused element-wise stage produces cs_t, co_t, h_t.
 * Step 0 uses the initial states hpgold/cspgold in place of step j-1. */
LIBXSMM_INLINE void lstm_ref_fwd( int N, int C, int K, int t, float forget_bias,
                   float *wigold, float *wcgold, float *wfgold, float *wogold,
                   float *rigold, float *rcgold, float *rfgold, float *rogold,
                   float *bigold, float *bcgold, float *bfgold, float *bogold,
                   float *xgoldt, float *cspgold, float *hpgold,
                   float *csgoldt, float *cogoldt, float *hgoldt,
                   float *icfogoldt, float *wgold, float *rgold, float *scratch )
{
#if !defined(TWO_GEMMS)
  float *xhgold = scratch; /* (C+K) x N buffer holding [x_t ; h_{t-1}] */
#endif
  const char transa = 'N', transb = 'N';   /* no transposes */
  const float alpha = 1, beta = 1;
  int j;
  int K4 = K * 4;
  int CK = C + K;
  LIBXSMM_VLA_DECL(2, float, xgold, xgoldt, N * C);
  LIBXSMM_VLA_DECL(2, float, csgold, csgoldt, K * N);
  LIBXSMM_VLA_DECL(2, float, cogold, cogoldt, K * N);
  LIBXSMM_VLA_DECL(2, float, hgold, hgoldt, K * N);
  LIBXSMM_VLA_DECL(3, float, icfogold, icfogoldt, N, 4 * K);
#if defined(PROFILE)
  Gbl_conv_start = libxsmm_timer_tick();
#endif
  /* pack the i/c/f/o weights side-by-side into one 4K-wide matrix */
#if defined(TWO_GEMMS)
  convert_ck_c4k(C, K, wigold, wgold);
  convert_ck_c4k(C, K, wcgold, &(wgold[K]));
  convert_ck_c4k(C, K, wfgold, &(wgold[2*K]));
  convert_ck_c4k(C, K, wogold, &(wgold[3*K]));
  convert_ck_c4k(K, K, rigold, rgold);
  convert_ck_c4k(K, K, rcgold, &(rgold[K]));
  convert_ck_c4k(K, K, rfgold, &(rgold[2*K]));
  convert_ck_c4k(K, K, rogold, &(rgold[3*K]));
#else
  /* single-GEMM variant: recurrent weights are stacked below the input weights */
  LIBXSMM_UNUSED(rgold);
  convert_ck_c4k(C, K, wigold, wgold);
  convert_ck_c4k(C, K, wcgold, &(wgold[K]));
  convert_ck_c4k(C, K, wfgold, &(wgold[2*K]));
  convert_ck_c4k(C, K, wogold, &(wgold[3*K]));
  convert_ck_c4k(K, K, rigold, &(wgold[C*K*4]));
  convert_ck_c4k(K, K, rcgold, &(wgold[C*K*4 + K]));
  convert_ck_c4k(K, K, rfgold, &(wgold[C*K*4 + 2*K]));
  convert_ck_c4k(K, K, rogold, &(wgold[C*K*4 + 3*K]));
#endif
#if defined(PROFILE)
  Gbl_conv_end = libxsmm_timer_tick();
  Gbl_conv_total += libxsmm_timer_duration(Gbl_conv_start, Gbl_conv_end);
#endif
  for (j = 0; j < t; ++j) {
    /* Initialization with bias */
#if defined(PROFILE)
    Gbl_copy_bias_start = libxsmm_timer_tick();
#endif
    lstm_fwd_copy_bias(N, K, bigold, bcgold, bfgold, bogold, forget_bias, icfogoldt, j);
#if defined(PROFILE)
    Gbl_copy_bias_end = libxsmm_timer_tick();
    Gbl_copy_bias_total += libxsmm_timer_duration(Gbl_copy_bias_start, Gbl_copy_bias_end);
    Gbl_blas_start = libxsmm_timer_tick();
#endif
#if defined(TWO_GEMMS)
    /* icfo += W * x */
    LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K4, &N, &C, &alpha, wgold, &K4, &LIBXSMM_VLA_ACCESS(2, xgold, j, 0, N * C), &C, &beta, &LIBXSMM_VLA_ACCESS(3, icfogold, j, 0, 0, N, 4 * K), &K4);
    /* icfo += R * h  (initial hidden state hpgold at step 0) */
    if (j == 0) {
      LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K4, &N, &K, &alpha, rgold, &K4, hpgold, &K, &beta, &LIBXSMM_VLA_ACCESS(3, icfogold, 0, 0, 0, N, 4 * K), &K4);
    } else {
      LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K4, &N, &K, &alpha, rgold, &K4, &LIBXSMM_VLA_ACCESS(2, hgold, j-1, 0, K * N), &K, &beta, &LIBXSMM_VLA_ACCESS(3, icfogold, j, 0, 0, N, 4 * K), &K4);
    }
#else
    /* Concatenate x and h */
    convert_nk_nck(N, C, C+K, &LIBXSMM_VLA_ACCESS(2, xgold, j, 0, N * C), xhgold);
    if (j == 0) {
      convert_nk_nck(N, K, C+K, hpgold, &(xhgold[C]));
    } else {
      convert_nk_nck(N, K, C+K, &LIBXSMM_VLA_ACCESS(2, hgold, j-1, 0, K * N), &(xhgold[C]));
    }
    /* icfo += (W * x) + (R * h) */
    LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K4, &N, &CK, &alpha, wgold, &K4, xhgold, &CK, &beta, &LIBXSMM_VLA_ACCESS(3, icfogold, j, 0, 0, N, 4 * K), &K4);
#endif
#if defined(PROFILE)
    Gbl_blas_end = libxsmm_timer_tick();
    Gbl_blas_total += libxsmm_timer_duration(Gbl_blas_start, Gbl_blas_end);
    Gbl_eltwise_start = libxsmm_timer_tick();
#endif
    /* element-wise stage; step 0 reads the initial cell state cspgold */
    if (j == 0) {
      lstm_fwd_eltwise_merged( N, K,
                               &LIBXSMM_VLA_ACCESS(3, icfogold, 0, 0, 0,   N, 4 * K),
                               &LIBXSMM_VLA_ACCESS(3, icfogold, 0, 0, K,   N, 4 * K),
                               &LIBXSMM_VLA_ACCESS(3, icfogold, 0, 0, 2*K, N, 4 * K),
                               &LIBXSMM_VLA_ACCESS(3, icfogold, 0, 0, 3*K, N, 4 * K),
                               cspgold,
                               &LIBXSMM_VLA_ACCESS(2, csgold, 0, 0, K * N),
                               &LIBXSMM_VLA_ACCESS(2, cogold, 0, 0, K * N),
                               &LIBXSMM_VLA_ACCESS(2, hgold, 0, 0, K * N) );
    } else {
      lstm_fwd_eltwise_merged( N, K,
                               &LIBXSMM_VLA_ACCESS(3, icfogold, j, 0, 0,   N, 4 * K),
                               &LIBXSMM_VLA_ACCESS(3, icfogold, j, 0, K,   N, 4 * K),
                               &LIBXSMM_VLA_ACCESS(3, icfogold, j, 0, 2*K, N, 4 * K),
                               &LIBXSMM_VLA_ACCESS(3, icfogold, j, 0, 3*K, N, 4 * K),
                               &LIBXSMM_VLA_ACCESS(2, csgold, j-1, 0, K * N),
                               &LIBXSMM_VLA_ACCESS(2, csgold, j, 0, K * N),
                               &LIBXSMM_VLA_ACCESS(2, cogold, j, 0, K * N),
                               &LIBXSMM_VLA_ACCESS(2, hgold, j, 0, K * N) );
    }
#if defined(PROFILE)
    Gbl_eltwise_end = libxsmm_timer_tick();
    Gbl_eltwise_total += libxsmm_timer_duration(Gbl_eltwise_start, Gbl_eltwise_end);
#endif
  }
}
2841 
/* Reference (BLAS-based) LSTM backward and weight-update pass, iterating the
 * time steps in reverse. Per step: the fused element-wise backward stage
 * produces the gate gradients, then GEMMs form dout (gradient w.r.t. the
 * previous hidden state, or dhpgold at step 0), dx, dw and dr (TWO_GEMMS
 * variant) or a fused dxh/dw pair (single-GEMM variant), and the bias
 * gradient db is accumulated by reducing the gate gradients over the batch.
 * scratch layout: dicfogoldt (K*N*t*4 floats) followed by doutgoldt
 * (K*N*t floats); without TWO_GEMMS, xhgold and dxhgold follow at offset
 * K*N*t*5. */
LIBXSMM_INLINE void lstm_ref_bwd_upd( int N, int C, int K, int t,
                       float *xgoldt, float *cspgold, float *hpgold,
                       float *csgoldt, float *cogoldt, float *hgoldt,
                       float *icfogoldt, float *wgold, float *rgold,
                       float *dcsgold, float *dhgoldt,
                       float *dwgold, float *drgold, float *dbgold,
                       float *dxgoldt, float *dcspgold, float *dhpgold, float *scratch )
{
#if !defined(TWO_GEMMS)
  float *xhgold   = &(scratch[K*N*t*5]);
  float *dxhgold  = &(scratch[K*N*t*5 + (C+K)*N]);
#endif
  float *dicfogoldt = scratch;
  float *doutgoldt  = &(scratch[K*N*t*4]);
  float *dout, *dcs, *csp;
  const char transa = 'N', transb = 'N';   /* no transposes */
  const char transaT = 'T', transbT = 'T'; /* transposes */
  const float alpha = 1, beta = 1, beta0 = 0;
  int j, l, p;
  int K4 = K * 4;
  int CK = C + K;
  LIBXSMM_VLA_DECL(2, float, xgold, xgoldt, N * C);
  LIBXSMM_VLA_DECL(2, float, csgold, csgoldt, K * N);
  LIBXSMM_VLA_DECL(2, float, cogold, cogoldt, K * N);
  LIBXSMM_VLA_DECL(2, float, hgold, hgoldt, K * N);
  LIBXSMM_VLA_DECL(3, float, icfogold, icfogoldt, N, 4 * K);
  LIBXSMM_VLA_DECL(2, float, dxgold, dxgoldt, N * C);
  LIBXSMM_VLA_DECL(2, float, dhgold, dhgoldt, K * N);
  LIBXSMM_VLA_DECL(3, float, dicfogold, dicfogoldt, N, 4 * K);
  LIBXSMM_VLA_DECL(2, float, doutgold, doutgoldt, K * N);
  for (j = t-1; j >= 0; --j) {
#if defined(PROFILE)
    Gbl_eltwise_start = libxsmm_timer_tick();
#endif
    /* last step has no recurrent gradient yet (dout = NULL) and takes the
     * external cell-state gradient; earlier steps use the carried dcspgold */
    if (t-1 == j) {
      dout = NULL;
      dcs = dcsgold;
    } else {
      dout = &LIBXSMM_VLA_ACCESS(2, doutgold, j, 0, K * N);
      dcs = dcspgold;
    }
    /* previous cell state: initial cspgold at step 0 */
    if (0 == j) {
      csp = cspgold;
    } else {
      csp = &LIBXSMM_VLA_ACCESS(2, csgold, j-1, 0, K * N);
    }
    lstm_bwd_upd_eltwise_merged( N, K,
                                 &LIBXSMM_VLA_ACCESS(3, icfogold, j, 0, 0,   N, 4 * K),
                                 &LIBXSMM_VLA_ACCESS(3, icfogold, j, 0, K,   N, 4 * K),
                                 &LIBXSMM_VLA_ACCESS(3, icfogold, j, 0, 2*K, N, 4 * K),
                                 &LIBXSMM_VLA_ACCESS(3, icfogold, j, 0, 3*K, N, 4 * K),
                                 csp,
                                 &LIBXSMM_VLA_ACCESS(2, cogold, j, 0, K * N),
                                 &LIBXSMM_VLA_ACCESS(2, dhgold, j, 0, K * N),
                                 dout,
                                 &LIBXSMM_VLA_ACCESS(3, dicfogold, j, 0, 0,   N, 4 * K),
                                 &LIBXSMM_VLA_ACCESS(3, dicfogold, j, 0, K,   N, 4 * K),
                                 &LIBXSMM_VLA_ACCESS(3, dicfogold, j, 0, 2*K, N, 4 * K),
                                 &LIBXSMM_VLA_ACCESS(3, dicfogold, j, 0, 3*K, N, 4 * K),
                                 dcspgold, dcs);
#if defined(PROFILE)
    Gbl_eltwise_end = libxsmm_timer_tick();
    Gbl_eltwise_total += libxsmm_timer_duration(Gbl_eltwise_start, Gbl_eltwise_end);
    Gbl_blas_start = libxsmm_timer_tick();
#endif
#if defined(TWO_GEMMS)
    if (j > 0) {
      /* compute dout (= R^T * dicfo, recurrent gradient for step j-1) */
      LIBXSMM_XBLAS_SYMBOL(float)(&transaT, &transb, &K, &N, &K4, &alpha, rgold, &K4, &LIBXSMM_VLA_ACCESS(3, dicfogold, j, 0, 0, N, 4 * K), &K4, &beta0, &LIBXSMM_VLA_ACCESS(2, doutgold, j-1, 0, K * N), &K);
    } else {
      /* compute dhp (gradient w.r.t. the initial hidden state) */
      LIBXSMM_XBLAS_SYMBOL(float)(&transaT, &transb, &K, &N, &K4, &alpha, rgold, &K4, &LIBXSMM_VLA_ACCESS(3, dicfogold, 0, 0, 0, N, 4 * K), &K4, &beta0, dhpgold, &K);
    }

    /* compute dx (accumulated: dx += W^T * dicfo) */
    LIBXSMM_XBLAS_SYMBOL(float)(&transaT, &transb, &C, &N, &K4, &alpha, wgold, &K4, &LIBXSMM_VLA_ACCESS(3, dicfogold, j, 0, 0, N, 4 * K), &K4, &beta, &LIBXSMM_VLA_ACCESS(2, dxgold, j, 0, N * C), &C);

    /* compute dw (accumulated: dw += dicfo * x^T) */
    LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transbT, &K4, &C, &N, &alpha, &LIBXSMM_VLA_ACCESS(3, dicfogold, j, 0, 0, N, 4 * K), &K4, &LIBXSMM_VLA_ACCESS(2, xgold, j, 0, N * C), &C, &beta, dwgold, &K4);

    /* compute dr (accumulated: dr += dicfo * h_{j-1}^T; hpgold at step 0) */
    if (j == 0) {
      LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transbT, &K4, &K, &N, &alpha, &LIBXSMM_VLA_ACCESS(3, dicfogold, j, 0, 0, N, 4 * K), &K4, hpgold, &K, &beta, drgold, &K4);
    } else {
      LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transbT, &K4, &K, &N, &alpha, &LIBXSMM_VLA_ACCESS(3, dicfogold, j, 0, 0, N, 4 * K), &K4, &LIBXSMM_VLA_ACCESS(2, hgold, j-1, 0, K * N), &K, &beta, drgold, &K4);
    }
#else
    /* single-GEMM variant: one transposed GEMM yields [dx ; dout] fused */
    LIBXSMM_UNUSED(rgold); LIBXSMM_UNUSED(drgold);
    LIBXSMM_XBLAS_SYMBOL(float)(&transaT, &transb, &CK, &N, &K4, &alpha, wgold, &K4, &LIBXSMM_VLA_ACCESS(3, dicfogold, j, 0, 0, N, 4 * K), &K4, &beta0, dxhgold, &CK);
    matrix_copy_ld(C, N, C+K, dxhgold, &LIBXSMM_VLA_ACCESS(2, dxgold, j, 0, N * C));
    if (j > 0) {
      matrix_copy_ld(K, N, C+K, &(dxhgold[C]), &LIBXSMM_VLA_ACCESS(2, doutgold, j-1, 0, K * N));
    } else {
      matrix_copy_ld(K, N, C+K, &(dxhgold[C]), dhpgold);
    }

    /* Concatenate x and h */
    convert_nk_nck(N, C, C+K, &LIBXSMM_VLA_ACCESS(2, xgold, j, 0, N * C), xhgold);
    if (j == 0) {
      convert_nk_nck(N, K, C+K, hpgold, &(xhgold[C]));
    } else {
      convert_nk_nck(N, K, C+K, &LIBXSMM_VLA_ACCESS(2, hgold, j-1, 0, K * N), &(xhgold[C]));
    }
    LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transbT, &K4, &CK, &N, &alpha, &LIBXSMM_VLA_ACCESS(3, dicfogold, j, 0, 0, N, 4 * K), &K4, xhgold, &CK, &beta, dwgold, &K4);
#endif
#if defined(PROFILE)
    Gbl_blas_end = libxsmm_timer_tick();
    Gbl_blas_total += libxsmm_timer_duration(Gbl_blas_start, Gbl_blas_end);
#endif
    /* compute db (reduce each of the four gate gradients over the batch) */
#if defined(_OPENMP)
    LIBXSMM_OMP_VAR(p);
# pragma omp parallel for private(l, p)
#endif
    for (l = 0; l < K; l++) {
      for (p = 0; p < N; p++) {
        dbgold[l]       += LIBXSMM_VLA_ACCESS(3, dicfogold, j, p, l,       N, 4 * K);
        dbgold[l + K]   += LIBXSMM_VLA_ACCESS(3, dicfogold, j, p, l + K,   N, 4 * K);
        dbgold[l + 2*K] += LIBXSMM_VLA_ACCESS(3, dicfogold, j, p, l + 2*K, N, 4 * K);
        dbgold[l + 3*K] += LIBXSMM_VLA_ACCESS(3, dicfogold, j, p, l + 3*K, N, 4 * K);
      }
    }
  }
}
2966 
gru_ref_fwd(int N,int C,int K,int t,float * wi,float * wc,float * wf,float * ri,float * rc,float * rf,float * bi,float * bc,float * bf,float * xt,float * hp,float * ht,float * it,float * ct,float * ft,float * ot)2967 LIBXSMM_INLINE void gru_ref_fwd( int N, int C, int K, int t,
2968                   float *wi, float *wc, float *wf,
2969                   float *ri, float *rc, float *rf,
2970                   float *bi, float *bc, float *bf,
2971                   float *xt, float *hp, float *ht,
2972                   float *it, float *ct, float *ft, float *ot )
2973 {
2974   const char transa = 'N', transb = 'N';   /* no transposes */
2975   const float alpha = 1, beta = 1;
2976   int j;
2977   LIBXSMM_VLA_DECL(2, float, x, xt, N * C);
2978   LIBXSMM_VLA_DECL(2, float, h, ht, K * N);
2979   LIBXSMM_VLA_DECL(2, float, i, it, K * N);
2980   LIBXSMM_VLA_DECL(2, float, c, ct, K * N);
2981   LIBXSMM_VLA_DECL(2, float, f, ft, K * N);
2982   LIBXSMM_VLA_DECL(2, float, o, ot, K * N);
2983   for (j = 0; j < t; ++j) {
2984     /* i_t = b_i */
2985     matrix_copy_bias(K, N, K, bi, &LIBXSMM_VLA_ACCESS(2, i, j, 0, K * N));
2986     /* i_t += W_i * x_t */
2987     LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &N, &C, &alpha, wi, &K, &LIBXSMM_VLA_ACCESS(2, x, j, 0, N * C), &C, &beta, &LIBXSMM_VLA_ACCESS(2, i, j, 0, K * N), &K);
2988     /* i_t += R_i * h_{t-1} */
2989     if (0 == j) {
2990       LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &N, &K, &alpha, ri, &K, hp,                                       &K, &beta, &LIBXSMM_VLA_ACCESS(2, i, j, 0, K * N), &K);
2991     } else {
2992       LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &N, &K, &alpha, ri, &K, &LIBXSMM_VLA_ACCESS(2, h, j-1, 0, K * N), &K, &beta, &LIBXSMM_VLA_ACCESS(2, i, j, 0, K * N), &K);
2993     }
2994     /* i_t = sigmoid(i_t) */
2995     matrix_sigmoid(N*K, &LIBXSMM_VLA_ACCESS(2, i, j, 0, K * N), &LIBXSMM_VLA_ACCESS(2, i, j, 0, K * N));
2996     /* c_t = b_c */
2997     matrix_copy_bias(K, N, K, bc, &LIBXSMM_VLA_ACCESS(2, c, j, 0, K * N));
2998     /* c_t += W_c * x_t */
2999     LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &N, &C, &alpha, wc, &K, &LIBXSMM_VLA_ACCESS(2, x, j, 0, N * C), &C, &beta, &LIBXSMM_VLA_ACCESS(2, c, j, 0, K * N), &K);
3000     /* c_t += R_c * h_{t-1} */
3001     if (0 == j) {
3002       LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &N, &K, &alpha, rc, &K, hp,                                       &K, &beta, &LIBXSMM_VLA_ACCESS(2, c, j, 0, K * N), &K);
3003     } else {
3004       LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &N, &K, &alpha, rc, &K, &LIBXSMM_VLA_ACCESS(2, h, j-1, 0, K * N), &K, &beta, &LIBXSMM_VLA_ACCESS(2, c, j, 0, K * N), &K);
3005     }
3006     /* c_t = sigmoid(c_t) */
3007     matrix_sigmoid(N*K, &LIBXSMM_VLA_ACCESS(2, c, j, 0, K * N), &LIBXSMM_VLA_ACCESS(2, c, j, 0, K * N));
3008     /* o_t = h_{t-1} . i_t */
3009     if (0 == j) {
3010       matrix_eltwise_mult(N*K, hp,                                       &LIBXSMM_VLA_ACCESS(2, i, j, 0, K * N), &LIBXSMM_VLA_ACCESS(2, o, j, 0, K * N));
3011     } else {
3012       matrix_eltwise_mult(N*K, &LIBXSMM_VLA_ACCESS(2, h, j-1, 0, K * N), &LIBXSMM_VLA_ACCESS(2, i, j, 0, K * N), &LIBXSMM_VLA_ACCESS(2, o, j, 0, K * N));
3013     }
3014     /* f_t = b_f */
3015     matrix_copy_bias(K, N, K, bf, &LIBXSMM_VLA_ACCESS(2, f, j, 0, K * N));
3016     /* f_t += W_f * x_t */
3017     LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &N, &C, &alpha, wf, &K, &LIBXSMM_VLA_ACCESS(2, x, j, 0, N * C), &C, &beta, &LIBXSMM_VLA_ACCESS(2, f, j, 0, K * N), &K);
3018     /* f_t += R_f * o_t */
3019     LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transb, &K, &N, &K, &alpha, rf, &K, &LIBXSMM_VLA_ACCESS(2, o, j, 0, K * N), &K, &beta, &LIBXSMM_VLA_ACCESS(2, f, j, 0, K * N), &K);
3020     /* f_t = tanh(f_t) */
3021     matrix_tanh(N*K, &LIBXSMM_VLA_ACCESS(2, f, j, 0, K * N), &LIBXSMM_VLA_ACCESS(2, f, j, 0, K * N));
3022     /* h_t = (1 - c_t) . f_t */
3023     matrix_complement  (N*K, &LIBXSMM_VLA_ACCESS(2, c, j, 0, K * N), &LIBXSMM_VLA_ACCESS(2, h, j, 0, K * N));
3024     matrix_eltwise_mult(N*K, &LIBXSMM_VLA_ACCESS(2, h, j, 0, K * N), &LIBXSMM_VLA_ACCESS(2, f, j, 0, K * N), &LIBXSMM_VLA_ACCESS(2, h, j, 0, K * N));
3025     /* h_t += c_t . h_{t-1} */
3026     if (0 == j) {
3027       matrix_eltwise_fma(N*K, &LIBXSMM_VLA_ACCESS(2, c, j, 0, K * N), hp,                                       &LIBXSMM_VLA_ACCESS(2, h, j, 0, K * N));
3028     } else {
3029       matrix_eltwise_fma(N*K, &LIBXSMM_VLA_ACCESS(2, c, j, 0, K * N), &LIBXSMM_VLA_ACCESS(2, h, j-1, 0, K * N), &LIBXSMM_VLA_ACCESS(2, h, j, 0, K * N));
3030     }
3031   }
3032 }
3033 
gru_ref_bwd_upd(int N,int C,int K,int t,float * xt,float * hpD,float * ht,float * it,float * ct,float * ft,float * ot,float * wi,float * wc,float * wf,float * ri,float * rc,float * rf,float * dht,float * dw,float * dr,float * db,float * dxt,float * dhpD,float * scratch)3034 LIBXSMM_INLINE void gru_ref_bwd_upd( int N, int C, int K, int t,
3035                       float *xt,  float *hpD,  float *ht,
3036                       float *it,  float *ct,   float *ft, float *ot,
3037                       float *wi,  float *wc,   float *wf,
3038                       float *ri,  float *rc,   float *rf,
3039                       float *dht, float *dw,   float *dr, float *db,
3040                       float *dxt, float *dhpD, float *scratch )
3041 {
3042   const char transa = 'N', transb = 'N';   /* no transposes */
3043   const char transaT = 'T', transbT = 'T'; /* transposes */
3044   const float alpha = 1, beta = 1, beta0 = 0;
3045   int j, l, p;
3046   float *dwi = dw;
3047   float *dwc = &(dw[C*K]);
3048   float *dwf = &(dw[2*C*K]);
3049   float *dri = dr;
3050   float *drc = &(dr[K*K]);
3051   float *drf = &(dr[2*K*K]);
3052   float *dbi = db;
3053   float *dbc = &(db[K]);
3054   float *dbf = &(db[2*K]);
3055   float *deltaD = scratch;
3056   float *doutD  = &(scratch[N*K]);
3057   float *diD    = &(scratch[2*N*K]);
3058   float *dcD    = &(scratch[3*N*K]);
3059   float *dfD    = &(scratch[4*N*K]);
3060   float *doD    = &(scratch[5*N*K]);
3061   LIBXSMM_VLA_DECL(3, float, x,     xt,     N, C);
3062   LIBXSMM_VLA_DECL(2, float, hp,    hpD,    K);
3063   LIBXSMM_VLA_DECL(3, float, h,     ht,     N, K);
3064   LIBXSMM_VLA_DECL(3, float, i,     it,     N, K);
3065   LIBXSMM_VLA_DECL(3, float, c,     ct,     N, K);
3066   LIBXSMM_VLA_DECL(3, float, f,     ft,     N, K);
3067   LIBXSMM_VLA_DECL(3, float, o,     ot,     N, K);
3068   LIBXSMM_VLA_DECL(3, float, dx,    dxt,    N, C);
3069   LIBXSMM_VLA_DECL(2, float, dhp,   dhpD,   K);
3070   LIBXSMM_VLA_DECL(3, float, dh,    dht,    N, K);
3071   LIBXSMM_VLA_DECL(2, float, di,    diD,    K);
3072   LIBXSMM_VLA_DECL(2, float, dc,    dcD,    K);
3073   LIBXSMM_VLA_DECL(2, float, df,    dfD,    K);
3074   LIBXSMM_VLA_DECL(2, float, dp,    doD,    K);
3075   LIBXSMM_VLA_DECL(2, float, dout,  doutD,  K);
3076   LIBXSMM_VLA_DECL(2, float, delta, deltaD, K);
3077   for (j = t-1; j >= 0; j--) {
3078 #if defined(_OPENMP)
3079     LIBXSMM_OMP_VAR(p);
3080 #   pragma omp parallel for private(l, p) LIBXSMM_OPENMP_COLLAPSE(2)
3081 #endif
3082     for (l = 0; l < N; l++) {
3083       for (p = 0; p < K; p++) {
3084         if (t-1 == j) {
3085           LIBXSMM_VLA_ACCESS(2, delta, l, p, K) = LIBXSMM_VLA_ACCESS(3, dh, t-1, l, p, N, K);
3086         } else {
3087           LIBXSMM_VLA_ACCESS(2, delta, l, p, K) = LIBXSMM_VLA_ACCESS(3, dh, j,   l, p, N, K) + LIBXSMM_VLA_ACCESS(2, dout, l, p, K);
3088         }
3089         /* df = delta . (1 - c_t) . (1 - (f_t . f_t)) */
3090         LIBXSMM_VLA_ACCESS(2, df, l, p, K) = LIBXSMM_VLA_ACCESS(2, delta, l, p, K) * (1.0f - LIBXSMM_VLA_ACCESS(3, c, j, l, p, N, K)) * (1.0f - (LIBXSMM_VLA_ACCESS(3, f, j, l, p, N, K) * LIBXSMM_VLA_ACCESS(3, f, j, l, p, N, K)));
3091         /* dc = delta . (h_{t-1} - f_t) . c_t . (1 - c_t) */
3092         if (0 == j) {
3093           LIBXSMM_VLA_ACCESS(2, dc, l, p, K) = LIBXSMM_VLA_ACCESS(2, delta, l, p, K) * (LIBXSMM_VLA_ACCESS(2, hp, l, p, K) -        LIBXSMM_VLA_ACCESS(3, f, j, l, p, N, K)) * LIBXSMM_VLA_ACCESS(3, c, j, l, p, N, K) * (1.0f - LIBXSMM_VLA_ACCESS(3, c, j, l, p, N, K));
3094         } else {
3095           LIBXSMM_VLA_ACCESS(2, dc, l, p, K) = LIBXSMM_VLA_ACCESS(2, delta, l, p, K) * (LIBXSMM_VLA_ACCESS(3, h, j-1, l, p, N, K) - LIBXSMM_VLA_ACCESS(3, f, j, l, p, N, K)) * LIBXSMM_VLA_ACCESS(3, c, j, l, p, N, K) * (1.0f - LIBXSMM_VLA_ACCESS(3, c, j, l, p, N, K));
3096         }
3097       }
3098     }
3099     /* do = {R_f}^T * df */
3100     LIBXSMM_XBLAS_SYMBOL(float)(&transaT, &transb, &K, &N, &K, &alpha, rf, &K, dfD, &K, &beta0, doD, &K);
3101     /* di = do . h_{t-1} . i_t . (1 - i_t) */
3102     if (0 == j) {
3103 #if defined(_OPENMP)
3104 #     pragma omp parallel for private(l, p) LIBXSMM_OPENMP_COLLAPSE(2)
3105 #endif
3106       for (l = 0; l < N; l++) {
3107         for (p = 0; p < K; p++) {
3108           LIBXSMM_VLA_ACCESS(2, di, l, p, K) = LIBXSMM_VLA_ACCESS(2, dp, l, p, K) * LIBXSMM_VLA_ACCESS(2, hp, l, p, K)        * LIBXSMM_VLA_ACCESS(3, i, 0, l, p, N, K) * (1.0f - LIBXSMM_VLA_ACCESS(3, i, 0, l, p, N, K));
3109         }
3110       }
3111     } else {
3112 #if defined(_OPENMP)
3113 #     pragma omp parallel for private(l, p) LIBXSMM_OPENMP_COLLAPSE(2)
3114 #endif
3115       for (l = 0; l < N; l++) {
3116         for (p = 0; p < K; p++) {
3117           LIBXSMM_VLA_ACCESS(2, di, l, p, K) = LIBXSMM_VLA_ACCESS(2, dp, l, p, K) * LIBXSMM_VLA_ACCESS(3, h, j-1, l, p, N, K) * LIBXSMM_VLA_ACCESS(3, i, j, l, p, N, K) * (1.0f - LIBXSMM_VLA_ACCESS(3, i, j, l, p, N, K));
3118         }
3119       }
3120     }
3121     /* dx_t  = {W_i}^T * di */
3122     LIBXSMM_XBLAS_SYMBOL(float)(&transaT, &transb, &C, &N, &K, &alpha, wi, &K, diD, &K, &beta0, &LIBXSMM_VLA_ACCESS(3, dx, j, 0, 0, N, C), &C);
3123     /* dx_t += {W_c}^T * dc */
3124     LIBXSMM_XBLAS_SYMBOL(float)(&transaT, &transb, &C, &N, &K, &alpha, wc, &K, dcD, &K, &beta,  &LIBXSMM_VLA_ACCESS(3, dx, j, 0, 0, N, C), &C);
3125     /* dx_t += {W_f}^T * df */
3126     LIBXSMM_XBLAS_SYMBOL(float)(&transaT, &transb, &C, &N, &K, &alpha, wf, &K, dfD, &K, &beta,  &LIBXSMM_VLA_ACCESS(3, dx, j, 0, 0, N, C), &C);
3127     /* dh_{t-1}  = {R_i}^T * di */
3128     /* dh_{t-1} += {R_c}^T * dc */
3129     if (0 == j) {
3130       LIBXSMM_XBLAS_SYMBOL(float)(&transaT, &transb, &K, &N, &K, &alpha, ri, &K, diD, &K, &beta0, dhpD, &K);
3131       LIBXSMM_XBLAS_SYMBOL(float)(&transaT, &transb, &K, &N, &K, &alpha, rc, &K, dcD, &K, &beta,  dhpD, &K);
3132     } else {
3133       LIBXSMM_XBLAS_SYMBOL(float)(&transaT, &transb, &K, &N, &K, &alpha, ri, &K, diD, &K, &beta0, doutD, &K);
3134       LIBXSMM_XBLAS_SYMBOL(float)(&transaT, &transb, &K, &N, &K, &alpha, rc, &K, dcD, &K, &beta,  doutD, &K);
3135     }
3136     /* dh_{t-1} += do * i_t + delta * c_t */
3137     if (0 == j) {
3138 #if defined(_OPENMP)
3139 #     pragma omp parallel for private(l, p) LIBXSMM_OPENMP_COLLAPSE(2)
3140 #endif
3141       for (l = 0; l < N; l++) {
3142         for (p = 0; p < K; p++) {
3143           LIBXSMM_VLA_ACCESS(2, dhp,  l, p, K) += LIBXSMM_VLA_ACCESS(2, dp, l, p, K) * LIBXSMM_VLA_ACCESS(3, i, j, l, p, N, K) + LIBXSMM_VLA_ACCESS(2, delta, l, p, K) * LIBXSMM_VLA_ACCESS(3, c, j, l, p, N, K);
3144         }
3145       }
3146     } else {
3147 #if defined(_OPENMP)
3148 #     pragma omp parallel for private(l, p) LIBXSMM_OPENMP_COLLAPSE(2)
3149 #endif
3150       for (l = 0; l < N; l++) {
3151         for (p = 0; p < K; p++) {
3152           LIBXSMM_VLA_ACCESS(2, dout, l, p, K) += LIBXSMM_VLA_ACCESS(2, dp, l, p, K) * LIBXSMM_VLA_ACCESS(3, i, j, l, p, N, K) + LIBXSMM_VLA_ACCESS(2, delta, l, p, K) * LIBXSMM_VLA_ACCESS(3, c, j, l, p, N, K);
3153         }
3154       }
3155     }
3156     /* dw_i += di * {x_t}^T */
3157     /* dw_c += dc * {x_t}^T */
3158     /* dw_f += df * {x_t}^T */
3159     LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transbT, &K, &C, &N, &alpha, diD, &K, &LIBXSMM_VLA_ACCESS(3, x, j, 0, 0, N, C), &C, &beta, dwi, &K);
3160     LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transbT, &K, &C, &N, &alpha, dcD, &K, &LIBXSMM_VLA_ACCESS(3, x, j, 0, 0, N, C), &C, &beta, dwc, &K);
3161     LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transbT, &K, &C, &N, &alpha, dfD, &K, &LIBXSMM_VLA_ACCESS(3, x, j, 0, 0, N, C), &C, &beta, dwf, &K);
3162     /* dr_i += di * {o_t}^T */
3163     /* dr_c += dc * {o_t}^T */
3164     /* dr_f += df * {h_{t-1}}^T */
3165     LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transbT, &K, &K, &N, &alpha, diD, &K, &LIBXSMM_VLA_ACCESS(3, o, j, 0, 0, N, K), &K, &beta, dri, &K);
3166     LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transbT, &K, &K, &N, &alpha, dcD, &K, &LIBXSMM_VLA_ACCESS(3, o, j, 0, 0, N, K), &K, &beta, drc, &K);
3167     if (0 == j) {
3168       LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transbT, &K, &K, &N, &alpha, dfD, &K, &LIBXSMM_VLA_ACCESS(2, hp, 0, 0, K),        &K, &beta, drf, &K);
3169     } else {
3170       LIBXSMM_XBLAS_SYMBOL(float)(&transa, &transbT, &K, &K, &N, &alpha, dfD, &K, &LIBXSMM_VLA_ACCESS(3, h, j-1, 0, 0, N, K), &K, &beta, drf, &K);
3171     }
3172     /* compute db */
3173 #if defined(_OPENMP)
3174 #   pragma omp parallel for private(l, p)
3175 #endif
3176     for (l = 0; l < K; l++) {
3177       for (p = 0; p < N; p++) {
3178         dbi[l] += LIBXSMM_VLA_ACCESS(2, di, p, l, K);
3179         dbc[l] += LIBXSMM_VLA_ACCESS(2, dc, p, l, K);
3180         dbf[l] += LIBXSMM_VLA_ACCESS(2, df, p, l, K);
3181       }
3182     }
3183   }
3184 }
3185 
3186