1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Evangelos Georganas (Intel Corp.)
10 ******************************************************************************/
11 
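/* TRANS_OUTPUT_TO_VNNI_FORMAT repacks one (img, ofm1) gradient-output block from
 * [pixels][ofmblock] into the VNNI-2 layout [pixels/2][ofmblock][2]: two consecutive
 * pixels are interleaved channel-wise with permutex2var so that each bf16 pair is
 * stored contiguously for the bf16 GEMM. An odd trailing pixel is paired with zeros,
 * and pixel pairs between compute_pixels (rounded up to even) and output_pixels are
 * zero-filled. */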
12 #define TRANS_OUTPUT_TO_VNNI_FORMAT(img, ofm1) do {\
13   __m512i zero_reg = _mm512_setzero_si512();\
14   src_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, 0, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock);\
15   tr_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(5, tr_output, img, ofm1, 0, 0, 0, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2);\
16   for (pixel_pair = 0; pixel_pair < n_full_pixel_pairs; pixel_pair++) {\
17     for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2+=32) {\
18       pixel_0 = _mm512_loadu_si512((element_output_type*)src_out+ofm2);\
19       pixel_1 = _mm512_loadu_si512(((element_output_type*)src_out+handle->ofmblock+ofm2));\
20       ofms_lo = _mm512_permutex2var_epi16(pixel_0, idx_lo, pixel_1);\
21       ofms_hi = _mm512_permutex2var_epi16(pixel_0, idx_hi, pixel_1);\
22       _mm512_storeu_si512(tr_out+ofm2*2, ofms_lo);\
23       _mm512_storeu_si512((element_output_type*)tr_out+32+ofm2*2, ofms_hi);\
24     }\
25     src_out += 2*handle->ofmblock;\
26     tr_out += 2*handle->ofmblock;\
27   }\
28   if (half_pixel_pair == 1) {\
29     for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2+=32) {\
30       pixel_0 = _mm512_loadu_si512((element_output_type*)src_out+ofm2);\
31       pixel_1 = _mm512_setzero_si512();\
32       ofms_lo = _mm512_permutex2var_epi16(pixel_0, idx_lo, pixel_1);\
33       ofms_hi = _mm512_permutex2var_epi16(pixel_0, idx_hi, pixel_1);\
34       _mm512_storeu_si512(tr_out+ofm2*2, ofms_lo);\
35       _mm512_storeu_si512((element_output_type*)tr_out+32+ofm2*2, ofms_hi);\
36     }\
37   }\
38   for (oi = ((handle->compute_pixels+1)/2)*2; oi < handle->output_pixels; oi+=2) {\
39     for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2+=32) {\
40       tr_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(5, tr_output, img, ofm1, oi/2, ofm2, 0, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2);\
41       _mm512_storeu_si512((element_output_type*)tr_out, zero_reg);\
42       _mm512_storeu_si512((element_output_type*)tr_out+32, zero_reg);\
43     }\
44   }\
45 } while(0)
46 
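/* TRANS_OUTPUT_W_TO_VNNI_FORMAT performs the same repacking row by row: for H output
 * rows starting at row oj of block (img, ofm1), neighboring w-pixel pairs are
 * interleaved into [ofwp_extended/2][ofmblock][2] and staged at block (img, 0) of
 * tr_output_2, which acts as a per-block scratch for the W-tiled batch-reduce GEMM. */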
47 #define TRANS_OUTPUT_W_TO_VNNI_FORMAT(img, ofm1, oj, H) do {\
48   int h, w_pixel_pair, w_full_pixel_pairs = handle->ofwp/2;\
49   for (h=0; h<H; h++) {\
50     src_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj + h, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock);\
51     tr_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(6, tr_output_2, img, 0, h, 0, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp_extended/2, handle->ofmblock, 2);\
52     for (w_pixel_pair = 0; w_pixel_pair < w_full_pixel_pairs; w_pixel_pair++) {\
53       for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2+=32) {\
54         pixel_0 = _mm512_loadu_si512((element_output_type*)src_out+ofm2);\
55         pixel_1 = _mm512_loadu_si512(((element_output_type*)src_out+handle->ofmblock+ofm2));\
56         ofms_lo = _mm512_permutex2var_epi16(pixel_0, idx_lo, pixel_1);\
57         ofms_hi = _mm512_permutex2var_epi16(pixel_0, idx_hi, pixel_1);\
58         _mm512_storeu_si512(tr_out+ofm2*2, ofms_lo);\
59         _mm512_storeu_si512((element_output_type*)tr_out+32+ofm2*2, ofms_hi);\
60       }\
61       src_out += 2*handle->ofmblock;\
62       tr_out += 2*handle->ofmblock;\
63     }\
64   }\
65 } while(0)
66 
67 int img, my_img_start, my_img_end, ofmb, ifmb, ofm1, ifm1, ifm2, ofm2, oj, oi, ii, ij, kj, ki, j_br, img_br, i, j, img_block_size = 1, my_ofm_start, my_ofm_end, my_ifm_start, my_ifm_end, block_ofm, block_ifm, pix;
68 /* computing first logical thread */
69 const int ltid = tid - start_thread;
70 
71 element_output_type *const out = (element_output_type*)handle->grad_output->data + ((size_t)handle->desc.pad_h_out * handle->ofwp + handle->desc.pad_w_out) * handle->ofmblock;
72 LIBXSMM_VLA_DECL(5, const element_output_type, output, (const element_output_type*)out, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock);
73 LIBXSMM_VLA_DECL(5, const element_input_type, input, (const element_input_type*)handle->reg_input->data, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock);
74 
75 element_filter_type *weight_ptr = (element_filter_type*)((char*)handle->scratch + handle->upd_filter_scratch_offset) + ltid * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S;
76 
77 element_filter_type *filter_dst_ptr = (handle->weight_copies > 1) ? (element_filter_type*)weight_ptr : (element_filter_type*)handle->grad_filter->data;
78 LIBXSMM_VLA_DECL(7, element_filter_type, weight_dst, (element_filter_type*)filter_dst_ptr, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock/2, handle->ofmblock, 2);
79 
80 /* This intermediate tensor is used when pixels are NOT fully accumulated  */
81 float *weight_ptr_f32 = (float*) ((char*)handle->scratch + handle->upd_lp_filter_full_scratch_offset) + ltid * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S;
82 
83 LIBXSMM_VLA_DECL(6, float, weight_private_f32, (float*)weight_ptr_f32, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
84 /* Accumulation scratch is used when pixels are fully accumulated */
85 element_filter_type *filter_scratch = (element_filter_type*)((char*)handle->scratch + handle->upd_lp_filter_full_scratch_offset) + ltid * handle->ofmblock * handle->ifmblock * 2;
86 
87 LIBXSMM_VLA_DECL(2, float, filter_tmp, (float*)filter_scratch, handle->ofmblock);
88 
89 element_input_type *scratch_tr_input = (element_input_type*)((char*)handle->scratch + handle->upd_lp_input_full_scratch_offset);
90 element_input_type *zero_ptr_in;
91 LIBXSMM_VLA_DECL(4, element_input_type, tr_input, (element_input_type*) scratch_tr_input, handle->blocksifm, handle->ifmblock, handle->input_pixels);
92 LIBXSMM_VLA_DECL(5, element_input_type, tr_input_2, (element_input_type*) scratch_tr_input, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended);
93 
94 element_output_type *scratch_tr_output = (element_output_type*)((char*)handle->scratch + handle->upd_lp_output_full_scratch_offset);
95 LIBXSMM_VLA_DECL(5, element_output_type, tr_output, (element_output_type*) scratch_tr_output, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2);
96 LIBXSMM_VLA_DECL(6, element_output_type, tr_output_2, (element_output_type*) scratch_tr_output, handle->blocksofm, handle->ofhp, handle->ofwp_extended/2, handle->ofmblock, 2);
97 #if 0
98 element_output_type *out_ptr = (element_output_type*)handle->grad_output->data + ((size_t)handle->desc.pad_h_out * handle->ofwp + handle->desc.pad_w_out) * handle->ofmblock;
99 element_output_type *zero_ptr_out;
100 #endif
101 
102 /* transpose, copy and reduce work-related variables  */
103 const int reduce_work = (handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S)/16;
104 const int reduce_chunksize = (reduce_work % handle->desc.threads == 0) ? (reduce_work / handle->desc.threads) : (reduce_work / handle->desc.threads) + 1;
105 const int reduce_thr_begin = (ltid * reduce_chunksize < reduce_work) ? (ltid * reduce_chunksize) : reduce_work;
106 const int reduce_thr_end = ((ltid + 1) * reduce_chunksize < reduce_work) ? ((ltid + 1) * reduce_chunksize) : reduce_work;
107 
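/* beta selects the GEMM accumulation mode: beta = 1 accumulates into the intermediate
 * fp32 weight tensor across pixel/row/image blocks, while beta = 0 overwrites the small
 * per-thread fp32 scratch when a single GEMM already covers the full accumulation. */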
108 const float beta = (handle->use_intermediate_f32_wt_tensor ? 1.f : 0.f);
109 float *dst_ptr;
110 gemm_br_function br_gemm_kernel = 0;
111 
112 /* These are used for the vnni reformatting of the f32 output  */
113 __m512i c01 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
114 const __m512i perm_index = LIBXSMM_INTRINSICS_MM512_SET_EPI16(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8, 23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0);
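/* perm_index interleaves the lower and upper 256-bit halves of the packed CVTNE2 result
 * element-wise, so filter rows ij and ij+1 end up paired per output channel, matching
 * the (ifmblock/2, ofmblock, 2) VNNI layout of the bf16 weight tensor. */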
115 
116 /* Related to the output transpose */
117 int n_full_pixel_pairs = handle->compute_pixels/2, half_pixel_pair = handle->compute_pixels%2, pixel_pair;
118 element_output_type *tr_out, *src_out;
119 const __m512i selector = LIBXSMM_INTRINSICS_MM512_SET_EPI16(32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0);
120 const __m512i offsets_lo = LIBXSMM_INTRINSICS_MM512_SET_EPI16(15, 15, 14, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0);
121 const __m512i offsets_hi = LIBXSMM_INTRINSICS_MM512_SET_EPI16(31, 31, 30, 30, 29, 29, 28, 28, 27, 27, 26, 26, 25, 25, 24, 24, 23, 23, 22, 22, 21, 21, 20, 20, 19, 19, 18, 18, 17, 17, 16, 16);
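/* In a permutex2var index, bit 5 (value 32) selects the second source operand, so
 * idx_lo/idx_hi alternate word i of pixel_0 with word i of pixel_1 for channels 0-15
 * and 16-31 respectively, yielding the [pixel pair][ofm][2] interleave in one shuffle
 * per half. */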
122 const __m512i idx_lo =  _mm512_or_epi32(selector, offsets_lo);
123 const __m512i idx_hi =  _mm512_or_epi32(selector, offsets_hi);
124 __m512i pixel_0, pixel_1, ofms_lo, ofms_hi;
125 
126 /* Batch reduce related variables */
127 const element_output_type *A_ptrs[1024];
128 const element_input_type  *B_ptrs[1024];
129 unsigned long long n_blocks;
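/* A_ptrs/B_ptrs hold the per-block operand addresses consumed by the address-variant
 * batch-reduce kernels; 1024 is a static upper bound on the number of blocks reduced
 * in a single call (batchreduce_h_pixels rows or img_block_size images). */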
130 
131 libxsmm_blasint LDA = handle->ofmblock;
132 libxsmm_blasint LDB = handle->input_pixels;
133 libxsmm_blasint LDC = handle->ofmblock;
134 int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
135 int l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
136 
137 const int img_work = handle->desc.N;
138 const int img_chunksize = (img_work % handle->desc.threads == 0) ? (img_work / handle->desc.threads) : (img_work / handle->desc.threads) + 1;
139 my_img_start = (ltid * img_chunksize < img_work) ? (ltid * img_chunksize) : img_work;
140 my_img_end = ((ltid + 1) * img_chunksize < img_work) ? ((ltid + 1) * img_chunksize) : img_work;
141 
142 libxsmm_barrier_init(handle->barrier, ltid);
143 
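/* Transpose phase. With upd_linearized_pixels the ofh x ofw (and ifh x ifw) planes are
 * treated as flat pixel streams and, for the hybrid img/ofm parallelization, transposed
 * up front; in all other configurations the rows stay separate and are transposed
 * either here or on the fly inside the compute loops below. */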
144 if (handle->upd_linearized_pixels == 1) {
145   /* First transpose input and output */
146   if (handle->use_hybrid_imgofm_parallelization == 1) {
147     if (handle->upd_pack_input_upfront == 0) {
148       for (img = my_img_start; img < my_img_end; img++) {
149 #if 0
150         zero_ptr_in = (element_input_type*) &LIBXSMM_VLA_ACCESS(4, tr_input, img, 0, 0, 0, handle->blocksifm, handle->ifmblock, handle->input_pixels);
151         memset(zero_ptr_in, 0, handle->desc.C * handle->input_pixels * sizeof(element_input_type));
152 #endif
153         for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1++) {
154           transpose_input_pixels_bf16( (element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, 0, 0, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock),
155               (element_input_type*)&LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, 0, 0, handle->blocksifm, handle->ifmblock, handle->input_pixels),
156               handle->ifmblock, handle->ifhp*handle->ifwp, handle->ifmblock, handle->input_pixels );
157 #if 0
158           for (ij = 0; ij < handle->ifhp; ij++) {
159             for (ii = 0; ii < handle->ifwp; ii++) {
160               for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) {
161                 LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, ifm2, ij * handle->ifwp + ii, handle->blocksifm, handle->ifmblock, handle->input_pixels) =
162                   LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij, ii, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock);
163               }
164             }
165           }
166 #endif
167         }
168       }
169     } else {
170       for (img = my_img_start; img < my_img_end; img++) {
171 #if 0
172         zero_ptr_in = (element_input_type*) &LIBXSMM_VLA_ACCESS(4, tr_input, img, 0, 0, 0, handle->blocksifm, handle->ifmblock, handle->input_pixels);
173         memset(zero_ptr_in, 0, handle->desc.C * handle->input_pixels * sizeof(element_input_type));
174 #endif
175         for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1++) {
176           for (ij = 0; ij < handle->ifhp/handle->desc.u; ij++) {
177             transpose_input_pixels_bf16( (element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij*handle->desc.u, 0, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock),
178                 (element_input_type*)&LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, 0, ij * (handle->ifwp/handle->desc.v), handle->blocksifm, handle->ifmblock, handle->input_pixels),
179                 handle->ifmblock, handle->ifwp/handle->desc.v, 2*handle->ifmblock, handle->input_pixels );
180 #if 0
181             for (ii = 0; ii < handle->ifwp/handle->desc.v; ii++) {
182               for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) {
183                 LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, ifm2, ij * (handle->ifwp/handle->desc.v) + ii, handle->blocksifm, handle->ifmblock, handle->input_pixels) =
184                   LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij*handle->desc.u, ii*handle->desc.v, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock);
185               }
186             }
187 #endif
188           }
189         }
190       }
191     }
192 
193     for (img = my_img_start; img < my_img_end; img++) {
194       for (ofm1 = 0; ofm1 < handle->blocksofm; ofm1++) {
195         TRANS_OUTPUT_TO_VNNI_FORMAT(img, ofm1);
196       }
197     }
198   }
199 #if 0
200   for (img = my_img_start; img < my_img_end; img++) {
201     zero_ptr_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(5, tr_output, img, 0, 0, 0, 0, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2);
202     memset(zero_ptr_out, 0, handle->desc.K * handle->output_pixels * sizeof(element_output_type));
203     for (ofm1 = 0; ofm1 < handle->blocksofm; ofm1++) {
204       for (oi = 0; oi < handle->n_used_pixels; oi++) {
205         for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2++) {
206           LIBXSMM_VLA_ACCESS(5, tr_output, img, ofm1, oi/2, ofm2, oi%2, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2) =
207             *((element_output_type*)out_ptr + img * handle->blocksofm * handle->ofwp * handle->ofhp * handle->ofmblock + ofm1 * handle->ofwp * handle->ofhp * handle->ofmblock + oi * handle->ofmblock + ofm2);
208         }
209       }
210     }
211   }
212 #endif
213 } else {
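  /* Non-linearized path: H and W stay separate. Unless the transposes are deferred to
   * the compute loop (upd_trans_w_only) or the input is packed on the fly, pre-transpose
   * the input to [ifmblock][ifhp][ifwp_extended] and the output rows to the row-wise
   * VNNI-2 layout. */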
214   if (handle->upd_trans_w_only == 0) {
215     if (handle->on_the_fly_input_packing == 0) {
216       for (img = my_img_start; img < my_img_end; img++) {
217         zero_ptr_in = (element_input_type*) &LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, 0, 0, 0, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended);
218         memset(zero_ptr_in, 0, handle->desc.C * handle->ifhp * handle->ifwp_extended * sizeof(element_input_type));
219         for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1++) {
220           for (ij = 0; ij < handle->ifhp; ij++) {
221             for (ii = 0; ii < handle->ifwp; ii++) {
222               for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) {
223                 LIBXSMM_VLA_ACCESS(5, tr_input_2, img, ifm1, ifm2, ij, ii, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended) =
224                   LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij, ii, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock);
225               }
226             }
227           }
228         }
229       }
230     }
231     for (img = my_img_start; img < my_img_end; img++) {
232       for (ofm1 = 0; ofm1 < handle->blocksofm; ofm1++) {
233 #if 0
234         TRANS_OUTPUT_W_TO_VNNI_FORMAT(img, ofm1, 0, handle->ofh);
235 #else
236         for (oj = 0; oj < handle->ofh; oj++) {
237           for (oi = 0; oi < handle->ofw; oi++) {
238             for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2++) {
239               LIBXSMM_VLA_ACCESS(6, tr_output_2, img, ofm1, oj, oi/2, ofm2, oi%2, handle->blocksofm, handle->ofhp, handle->ofwp_extended/2, handle->ofmblock, 2) =
240                 LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, ofm2, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock);
241             }
242           }
243         }
244         if (handle->ofw % 2 == 1) {
245           for (oj = 0; oj < handle->ofh; oj++) {
246             for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2++) {
247               LIBXSMM_VLA_ACCESS(6, tr_output_2, img, ofm1, oj, handle->ofw/2, ofm2, handle->ofw%2, handle->blocksofm, handle->ofhp, handle->ofwp_extended/2, handle->ofmblock, 2) = (element_output_type)0;
248             }
249           }
250         }
251 #endif
252       }
253     }
254   }
255 }
256 
257 /* Make sure we initialize intermediate weights to zero */
258 if (handle->use_intermediate_f32_wt_tensor == 1 && handle->use_hybrid_imgofm_parallelization == 0) {
259   memset(weight_ptr_f32, 0, handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S * sizeof(float));
260 }
261 
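/* Compute phase: every (ofm1, ifm1, kj, ki) filter block is produced by a bf16 GEMM on
 * transposed output (A, VNNI-2) and transposed input (B). Results accumulate either in
 * the per-thread intermediate fp32 tensor or in the small fp32 scratch, and are
 * converted back to the bf16 VNNI weight layout once the accumulation over all pixels
 * (and images) is complete. */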
262 if (handle->upd_linearized_pixels == 0) {
263   if (handle->upd_trans_w_only == 1) {
264     LDA = handle->ofmblock;
265     LDB = handle->ifhp*handle->ifwp_extended;
266     LDC = handle->ofmblock;
267     prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
268     l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
269     n_blocks = handle->batchreduce_h_pixels;
270     br_gemm_kernel =  libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->ifmblock, handle->ofw, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode);
271 
272     for (img = my_img_start; img < my_img_end; img++) {
273       for (ofmb = 0; ofmb < handle->blocksofm; ofmb += handle->block_upd_ofm) {
274         for (oj = 0; oj < handle->ofh; oj += handle->batchreduce_h_pixels){
275           for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_upd_ifm) {
276             for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_upd_ofm, handle->blocksofm); ofm1++ ) {
277               /* Transpose output block */
278               TRANS_OUTPUT_W_TO_VNNI_FORMAT(img, ofm1, oj, handle->batchreduce_h_pixels);
279               for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_upd_ifm, handle->blocksifm); ifm1++) {
280                 /* Transpose input block */
281                 for (j=0; j < handle->batchreduce_h_pixels; j++) {
282                   transpose_input_pixels_bf16( (element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, oj+j, 0, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock),
283                       (element_input_type*)&LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, 0, j, 0, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended),
284                       handle->ifmblock, handle->ifwp_extended, handle->ifmblock, handle->ifhp*handle->ifwp_extended );
285                 }
286                 for (kj = 0; kj < handle->desc.R; ++kj) {
287                   for (ki = 0; ki < handle->desc.S; ++ki) {
288 
289                     /* Determine if destination is the accumulation scratch or the intermediate fp32 weight tensor */
290                     if (handle->use_intermediate_f32_wt_tensor == 1) {
291                       dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(6, weight_private_f32, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
292                     } else {
293                       dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(2, filter_tmp, 0, 0, handle->ofmblock);
294                     }
295 
296                     for (j_br = 0; j_br < handle->batchreduce_h_pixels; j_br++) {
297                       A_ptrs[j_br] = (element_output_type*) &LIBXSMM_VLA_ACCESS(6, tr_output_2, img, 0, j_br, 0, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp_extended/2, handle->ofmblock, 2);
298                       B_ptrs[j_br] = (element_input_type*) &LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, 0, j_br, 0, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended);
299                     }
300 
301                     br_gemm_kernel(A_ptrs, B_ptrs, dst_ptr, &n_blocks);
302 
303                     /* Convert the fully accumulated buffer to the bf16 weight buffer once full accumulation has happened */
304                     if ((oj + handle->batchreduce_h_pixels >= handle->ofh) && (img == my_img_end - 1)) {
305                       LIBXSMM_VLA_DECL(2, float, filter_acc_buffer, (float*)dst_ptr, handle->ofmblock);
306                       for (ij = 0; ij < handle->ifmblock; ij+=2) {
307                         for (ii = 0; ii < handle->ofmblock; ii+=16) {
308                           c01 = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij+1, ii, handle->ofmblock)),
309                                                                         LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij, ii, handle->ofmblock)) );
310                           _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(7, weight_dst, ofm1, ifm1, kj, ki, ij/2, ii, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock/2, handle->ofmblock, 2), _mm512_permutexvar_epi16(perm_index, c01));
311                         }
312                       }
313                     }
314                   }
315                 }
316               }
317             }
318           }
319         }
320       }
321     }
322   } else {
323     int fast_trans = (handle->ofw == 112 && handle->desc.v == 2 && handle->ifmblock == 4 && handle->batchreduce_h_pixels == 1) ? 1 : 0;
324     const __m512i skipper = LIBXSMM_INTRINSICS_MM512_SET_EPI16(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 27, 19, 11, 3, 26, 18, 10, 2, 25, 17, 9, 1, 24, 16, 8, 0);
325     __m512i p0, p1, p2, p3;
326     __m256i _p0, _p1, _p2, _p3;
327     __m256i r0 = _mm256_undefined_si256();
328     __m256i r1 = _mm256_undefined_si256();
329     __m256i r2 = _mm256_undefined_si256();
330     __m256i r3 = _mm256_undefined_si256();
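    /* fast_trans covers the ofw==112, stride-2, ifmblock==4, single-row case: the
     * "skipper" permutation keeps only the 4 channels of every second input pixel, so
     * four strided pixels per 512-bit load are repacked and scattered channel-major
     * into tr_input_2. */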
331     LDA = handle->ofmblock;
332     LDB = handle->ifhp*handle->ifwp_extended;
333     LDC = handle->ofmblock;
334     prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
335     l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
336     n_blocks = handle->batchreduce_h_pixels;
337     /* Handle the case when ofw is an odd number */
338     if (handle->ofw % 2 == 1) {
339       br_gemm_kernel =  libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->ifmblock, handle->ofw+1, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode);
340     } else {
341       br_gemm_kernel =  libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->ifmblock, handle->ofw, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode);
342     }
343 
344     for (img = my_img_start; img < my_img_end; img++) {
345       for (ofmb = 0; ofmb < handle->blocksofm; ofmb += handle->block_upd_ofm) {
346         for (oj = 0; oj < handle->ofh; oj += handle->batchreduce_h_pixels){
347           for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_upd_ifm) {
348             for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_upd_ofm, handle->blocksofm); ofm1++ ) {
349               for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_upd_ifm, handle->blocksifm); ifm1++) {
350                 for (kj = 0; kj < handle->desc.R; ++kj) {
351                   for (ki = 0; ki < handle->desc.S; ++ki) {
352 
353                     /* Determine if destination is the accumulation scratch or the intermediate fp32 weight tensor */
354                     if (handle->use_intermediate_f32_wt_tensor == 1) {
355                       dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(6, weight_private_f32, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
356                     } else {
357                       dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(2, filter_tmp, 0, 0, handle->ofmblock);
358                     }
359 
360                     /* Copy the input in such a way that we ignore "w-pixels" based on ki value  */
361                     if (handle->on_the_fly_input_packing == 1) {
362                       if (fast_trans == 1) {
363                         for (ii = 0; ii < handle->ofw*2; ii+=32) {
364                           p0 = _mm512_loadu_si512((element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, oj*handle->desc.u+kj, ii+ki, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock));
365                           p0 = _mm512_permutexvar_epi16(skipper, p0);
366                           _p0 = LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(p0, 0);
367                           p1 = _mm512_loadu_si512((element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, oj*handle->desc.u+kj, ii+8+ki, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock));
368                           p1 = _mm512_permutexvar_epi16(skipper, p1);
369                           _p1 = LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(p1, 0);
370                           p2 = _mm512_loadu_si512((element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, oj*handle->desc.u+kj, ii+16+ki, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock));
371                           p2 = _mm512_permutexvar_epi16(skipper, p2);
372                           _p2 = LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(p2, 0);
373                           p3 = _mm512_loadu_si512((element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, oj*handle->desc.u+kj, ii+24+ki, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock));
374                           p3 = _mm512_permutexvar_epi16(skipper, p3);
375                           _p3 = LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(p3, 0);
376 
377                           r0 = _mm256_insert_epi64 (r0, _mm256_extract_epi64(_p0, 0), 0);
378                           r0 = _mm256_insert_epi64 (r0, _mm256_extract_epi64(_p1, 0), 1);
379                           r0 = _mm256_insert_epi64 (r0, _mm256_extract_epi64(_p2, 0), 2);
380                           r0 = _mm256_insert_epi64 (r0, _mm256_extract_epi64(_p3, 0), 3);
381                           _mm256_storeu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, 0, 0, ii/2, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended), r0);
382 
383                           r1 = _mm256_insert_epi64 (r1, _mm256_extract_epi64(_p0, 1), 0);
384                           r1 = _mm256_insert_epi64 (r1, _mm256_extract_epi64(_p1, 1), 1);
385                           r1 = _mm256_insert_epi64 (r1, _mm256_extract_epi64(_p2, 1), 2);
386                           r1 = _mm256_insert_epi64 (r1, _mm256_extract_epi64(_p3, 1), 3);
387                           _mm256_storeu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, 1, 0, ii/2, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended), r1);
388 
389                           r2 = _mm256_insert_epi64 (r2, _mm256_extract_epi64(_p0, 2), 0);
390                           r2 = _mm256_insert_epi64 (r2, _mm256_extract_epi64(_p1, 2), 1);
391                           r2 = _mm256_insert_epi64 (r2, _mm256_extract_epi64(_p2, 2), 2);
392                           r2 = _mm256_insert_epi64 (r2, _mm256_extract_epi64(_p3, 2), 3);
393                           _mm256_storeu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, 2, 0, ii/2, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended), r2);
394 
395                           r3 = _mm256_insert_epi64 (r3, _mm256_extract_epi64(_p0, 3), 0);
396                           r3 = _mm256_insert_epi64 (r3, _mm256_extract_epi64(_p1, 3), 1);
397                           r3 = _mm256_insert_epi64 (r3, _mm256_extract_epi64(_p2, 3), 2);
398                           r3 = _mm256_insert_epi64 (r3, _mm256_extract_epi64(_p3, 3), 3);
399                           _mm256_storeu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, 3, 0, ii/2, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended), r3);
400 
401                         }
402                       } else {
403                         for (ij = 0; ij < handle->batchreduce_h_pixels; ij++) {
404                           for (ii = 0; ii < handle->ofw; ii++) {
405                             for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) {
406                               LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, ifm2, ij, ii, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended) =
407                                 LIBXSMM_VLA_ACCESS(5, input, img, ifm1, (oj+ij)*handle->desc.u+kj, ii*handle->desc.v+ki, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock);
408                             }
409                           }
410                         }
411                       }
412                     }
413 
414                     for (j_br = 0; j_br < handle->batchreduce_h_pixels; j_br++) {
415                       A_ptrs[j_br] = (element_output_type*) &LIBXSMM_VLA_ACCESS(6, tr_output_2, img, ofm1, oj+j_br, 0, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp_extended/2, handle->ofmblock, 2);
416                       B_ptrs[j_br] = (element_input_type*) &LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, 0, j_br, 0, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended);
417                     }
418 
419                     br_gemm_kernel(A_ptrs, B_ptrs, dst_ptr, &n_blocks);
420 
421                     /* Convert the fully accumulated buffer to the bf16 weight buffer once full accumulation has happened */
422                     if ((oj + handle->batchreduce_h_pixels >= handle->ofh) && (img == my_img_end - 1)) {
423                       LIBXSMM_VLA_DECL(2, float, filter_acc_buffer, (float*)dst_ptr, handle->ofmblock);
424                       for (ij = 0; ij < handle->ifmblock; ij+=2) {
425                         for (ii = 0; ii < handle->ofmblock; ii+=16) {
426                           c01 = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij+1, ii, handle->ofmblock)),
427                                                                         LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij, ii, handle->ofmblock)));
428                           _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(7, weight_dst, ofm1, ifm1, kj, ki, ij/2, ii, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock/2, handle->ofmblock, 2), _mm512_permutexvar_epi16(perm_index, c01));
429                         }
430                       }
431                     }
432                   }
433                 }
434               }
435             }
436           }
437         }
438       }
439     }
440   }
441 } else {
442   LDA = handle->ofmblock;
443   LDB = handle->input_pixels;
444   LDC = handle->ofmblock;
445   prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
446   l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
447 
448   if (handle->use_hybrid_imgofm_parallelization == 1) {
449     /* Here we are using the batch-reduce kernel and hybrid minibatch/FM parallelization */
450     /* FIXME: Hardcoded logic for N=27 */
451     int group_size = (handle->desc.threads == 27 && handle->desc.N == 27 && handle->ofw == 14 && handle->desc.R == 1 && handle->desc.u == 1 && ltid >= 24) ? 3 : LIBXSMM_UPDIV(handle->desc.threads, handle->weight_copies);
452     int tile_id = ltid / LIBXSMM_UPDIV(handle->desc.threads, handle->weight_copies);
453     int tiles = handle->weight_copies;
454     int img_per_tile = LIBXSMM_UPDIV(handle->desc.N, tiles);
455     int my_in_tile_id = ltid % group_size;
456     int ifms_per_thread = LIBXSMM_UPDIV(handle->blocksifm, group_size);
457     int ofms_per_thread = LIBXSMM_UPDIV(handle->blocksofm, group_size);
458     int my_R_start = 0;
459     int my_R_end = handle->desc.R;
460     element_filter_type *weight_ptr_group = (handle->weight_copies > 1) ? (element_filter_type*)((char*)handle->scratch + handle->upd_filter_scratch_offset) + tile_id * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S : (element_filter_type*)handle->grad_filter->data;
461     LIBXSMM_VLA_DECL(7, element_filter_type, weight_private_group, (element_filter_type*)weight_ptr_group, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock/2, handle->ofmblock, 2);
462     /* This intermediate tensor is used when pixels are NOT fully accumulated  */
463     float *weight_tile_ptr_f32 = (float*)((char*)handle->scratch + handle->upd_lp_filter_full_scratch_offset) + tile_id * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S;
464     LIBXSMM_VLA_DECL(6, float, weight_private_tile_f32, (float*)weight_tile_ptr_f32, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
465 
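    /* Images are split across the weight_copies tiles (each tile accumulates into its
     * own weight copy), while the group_size threads within a tile split the work over
     * ifm blocks by default, or over ofm blocks / filter rows R for the hardcoded
     * 27-thread special cases below. */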
466     my_img_start = LIBXSMM_MIN(tile_id * img_per_tile, handle->desc.N);
467     my_img_end = LIBXSMM_MIN((tile_id+1) * img_per_tile, handle->desc.N);
468     my_ifm_start = LIBXSMM_MIN(my_in_tile_id * ifms_per_thread, handle->blocksifm  );
469     my_ifm_end = LIBXSMM_MIN((my_in_tile_id+1) * ifms_per_thread, handle->blocksifm  );
470     my_ofm_start = 0;
471     my_ofm_end = handle->blocksofm;
472     /* FIXME: Hardcoded logic for N=27 */
473     if (handle->desc.threads == 27 && handle->desc.N == 27 && handle->desc.C == 256 && handle->desc.K == 1024 && handle->ofh == 14 && handle->desc.u == 1) {
474       my_ofm_start = LIBXSMM_MIN(my_in_tile_id * ofms_per_thread, handle->blocksofm);
475       my_ofm_end = LIBXSMM_MIN((my_in_tile_id+1) * ofms_per_thread, handle->blocksofm);
476       my_ifm_start = 0;
477       my_ifm_end = handle->blocksifm;
478     }
479     if (handle->desc.threads == 27 && handle->desc.N == 27 && handle->desc.R == 3 && handle->desc.S == 3 && handle->ofh == 14) {
480       int r_per_tile = LIBXSMM_UPDIV(handle->desc.R, group_size);
481       my_ifm_start = 0;
482       my_ifm_end = handle->blocksifm;
483       my_ofm_start = 0;
484       my_ofm_end = handle->blocksofm;
485       my_R_start = LIBXSMM_MIN(my_in_tile_id * r_per_tile, handle->desc.R);
486       my_R_end = LIBXSMM_MIN((my_in_tile_id+1) * r_per_tile, handle->desc.R);
487     }
488     block_ofm = my_ofm_end-my_ofm_start+1;
489     block_ifm = my_ifm_end-my_ifm_start+1;
490     img_block_size = my_img_end - my_img_start;
491 
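    /* Address-variant batch-reduce GEMM: one call reduces over the img_block_size
     * images of this tile for a given pixel block (M=ofmblock, N=ifmblock,
     * K=pixel_blocking). */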
492     br_gemm_kernel = libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->ifmblock, handle->pixel_blocking, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode);
493     n_blocks = img_block_size;
494 
495     /* Make sure we initialize intermediate weights to zero */
496     if (handle->use_intermediate_f32_wt_tensor == 1) {
497       for (ofm1 = my_ofm_start; ofm1 < my_ofm_end; ofm1++ ) {
498         for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) {
499           for (kj = my_R_start; kj < my_R_end; ++kj) {
500             memset((float*)&LIBXSMM_VLA_ACCESS(6, weight_private_tile_f32, ofm1, ifm1, kj, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), 0, handle->ofmblock * handle->ifmblock * handle->desc.S * sizeof(float));
501           }
502         }
503       }
504     }
505 
506     libxsmm_barrier_wait(handle->barrier, ltid);
507 
508     for (img = my_img_start; img < my_img_end; img += img_block_size) {
509       for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += block_ofm) {
510         for (pix = 0; pix < handle->n_used_pixels; pix += handle->pixel_blocking){
511           for (ifmb = my_ifm_start; ifmb < my_ifm_end; ifmb += block_ifm) {
512             for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+block_ofm, my_ofm_end); ofm1++ ) {
513               for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+block_ifm, my_ifm_end); ifm1++) {
514                 for (kj = my_R_start; kj < my_R_end; ++kj) {
515                   for (ki = 0; ki < handle->desc.S; ++ki) {
516 
517                     /* Determine if destination is the accumulation scratch or the intermediate fp32 weight tensor */
518                     if (handle->use_intermediate_f32_wt_tensor == 1) {
519                       dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(6, weight_private_tile_f32, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
520                     } else {
521                       dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(2, filter_tmp, 0, 0, handle->ofmblock);
522                     }
523 
524                     for (img_br = 0; img_br < img_block_size; img_br++) {
525                       A_ptrs[img_br] = &LIBXSMM_VLA_ACCESS(5, tr_output, img + img_br, ofm1, pix/2, 0, 0, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2);
526                       B_ptrs[img_br] = &LIBXSMM_VLA_ACCESS(4, tr_input, img + img_br, ifm1, 0, pix + kj * handle->ifwp + ki, handle->blocksifm, handle->ifmblock, handle->input_pixels);
527                     }
528 
529                     br_gemm_kernel(A_ptrs, B_ptrs, dst_ptr, &n_blocks);
530 
531                     /* Convert the fully accumulated buffer to the bf16 weight buffer once full accumulation has happened */
532                     if ((pix + handle->pixel_blocking >= handle->n_used_pixels) && (img == my_img_end - img_block_size)) {
533                       LIBXSMM_VLA_DECL(2, float, filter_acc_buffer, (float*)dst_ptr, handle->ofmblock);
534                       for (ij = 0; ij < handle->ifmblock; ij+=2) {
535                         for (ii = 0; ii < handle->ofmblock; ii+=16) {
536                           c01 = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij+1, ii, handle->ofmblock)),
537                                                                         LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij, ii, handle->ofmblock)) );
538                           _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(7, weight_private_group, ofm1, ifm1, kj, ki, ij/2, ii, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock/2, handle->ofmblock, 2), _mm512_permutexvar_epi16(perm_index, c01));
539                         }
540                       }
541                     }
542                   }
543                 }
544               }
545             }
546           }
547         }
548       }
549     }
550 
551   } else {
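    /* No hybrid parallelization: images are split across threads, a plain (non
     * batch-reduce) bf16 GEMM handles one pixel block at a time, and the input/output
     * blocks are transposed lazily the first time they are touched for each image. */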
552     gemm_function gemm_kernel = libxsmm_bsmmdispatch(handle->ofmblock, handle->ifmblock, handle->pixel_blocking, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode);
553 
554     for (img = my_img_start; img < my_img_end; img++) {
555       for (ofmb = 0; ofmb < handle->blocksofm; ofmb += handle->block_upd_ofm) {
556         for (pix = 0; pix < handle->n_used_pixels; pix += handle->pixel_blocking){
557           for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_upd_ifm) {
558             for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_upd_ofm, handle->blocksofm); ofm1++ ) {
559               /* Transpose output block  */
560               if (pix == 0 && ifmb == 0) {
561                 TRANS_OUTPUT_TO_VNNI_FORMAT(img, ofm1);
562               }
563               for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_upd_ifm, handle->blocksifm); ifm1++) {
564                 /* Transpose input block */
565                 if (pix == 0 && ofmb == 0 && ofm1 == 0) {
566                   if (handle->upd_pack_input_upfront == 0) {
567                     transpose_input_pixels_bf16( (element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, 0, 0, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock),
568                         (element_input_type*)&LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, 0, 0, handle->blocksifm, handle->ifmblock, handle->input_pixels),
569                         handle->ifmblock, handle->ifhp*handle->ifwp, handle->ifmblock, handle->input_pixels );
570                   } else {
571                     for (ij = 0; ij < handle->ifhp/handle->desc.u; ij++) {
572                       transpose_input_pixels_bf16( (element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij*handle->desc.u, 0, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock),
573                           (element_input_type*)&LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, 0, ij * (handle->ifwp/handle->desc.v), handle->blocksifm, handle->ifmblock, handle->input_pixels),
574                           handle->ifmblock, handle->ifwp/handle->desc.v, 2*handle->ifmblock, handle->input_pixels );
575                     }
576                   }
577                 }
578                 for (kj = 0; kj < handle->desc.R; ++kj) {
579                   for (ki = 0; ki < handle->desc.S; ++ki) {
580 
581                     /* Determine if destination is the accumulation scratch or the intermediate fp32 weight tensor */
582                     if (handle->use_intermediate_f32_wt_tensor == 1) {
583                       dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(6, weight_private_f32, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
584                     } else {
585                       dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(2, filter_tmp, 0, 0, handle->ofmblock);
586                     }
587                     gemm_kernel( &LIBXSMM_VLA_ACCESS(5, tr_output, img, ofm1, pix/2, 0, 0, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2),
588                         &LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, 0, pix + kj * handle->ifwp + ki, handle->blocksifm, handle->ifmblock, handle->input_pixels),
589                         dst_ptr);
590 
591                     /* Convert the fully accumulated buffer to the bf16 weight buffer once full accumulation has happened */
592                     if ((pix + handle->pixel_blocking >= handle->n_used_pixels) && (img == my_img_end - 1)) {
593                       LIBXSMM_VLA_DECL(2, float, filter_acc_buffer, (float*)dst_ptr, handle->ofmblock);
594                       for (ij = 0; ij < handle->ifmblock; ij+=2) {
595                         for (ii = 0; ii < handle->ofmblock; ii+=16) {
596                           c01 = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij+1, ii, handle->ofmblock)),
597                                                                         LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij, ii, handle->ofmblock)) );
598                           _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(7, weight_dst, ofm1, ifm1, kj, ki, ij/2, ii, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock/2, handle->ofmblock, 2), _mm512_permutexvar_epi16(perm_index, c01));
599                         }
600                       }
601                     }
602                   }
603                 }
604               }
605             }
606           }
607         }
608       }
609     }
610   }
611 }
612 
613 libxsmm_barrier_wait(handle->barrier, ltid);
614 
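/* If multiple private weight copies were produced, reduce them into the global gradient
 * filter: each thread sums its 16-element chunks over the active copies in fp32 and
 * rounds the result back to bf16. */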
615 if (handle->weight_copies > 1) {
616   int active_copies = handle->weight_copies;
617   const int filter_size = handle->desc.R  * handle->desc.S * handle->desc.C * handle->desc.K;
618   LIBXSMM_VLA_DECL(2, element_filter_type, weight_copies_buffer, (element_filter_type*) ((char*)handle->scratch + handle->upd_filter_scratch_offset), filter_size);
619   element_filter_type *weight_global_ptr = (element_filter_type*) handle->grad_filter->data;
620 
621   /* In this case, compute how many weight copies have actually been computed */
622   if (handle->desc.N != handle->desc.threads) {
623     active_copies = 1;
624     while (active_copies * img_chunksize < handle->desc.N) {
625       active_copies++;
626     }
627   }
628 
629   for ( j = reduce_thr_begin; j < reduce_thr_end; j++) {
630     __m512 weight_sum = _mm512_setzero_ps();
631     for ( i = 0; i < active_copies; i++ ) {
632       weight_sum = _mm512_add_ps(weight_sum, LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((const __m256i*)&LIBXSMM_VLA_ACCESS(2, weight_copies_buffer, i, j*16, filter_size))));
633     }
634     _mm256_storeu_si256((__m256i*)(((libxsmm_bfloat16*) weight_global_ptr) + j*16), LIBXSMM_INTRINSICS_MM512_CVT_FP32_BF16(weight_sum));
635   }
636   libxsmm_barrier_wait(handle->barrier, ltid);
637 }
638 
639 #undef TRANS_OUTPUT_W_TO_VNNI_FORMAT
640 #undef TRANS_OUTPUT_TO_VNNI_FORMAT
641