/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Evangelos Georganas (Intel Corp.)
******************************************************************************/

#define TRANS_OUTPUT_TO_VNNI_FORMAT(img, ofm1) do {\
  __m512i zero_reg = _mm512_setzero_si512();\
  src_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, 0, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock);\
  tr_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(5, tr_output, img, ofm1, 0, 0, 0, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2);\
  for (pixel_pair = 0; pixel_pair < n_full_pixel_pairs; pixel_pair++) {\
    for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2+=32) {\
      pixel_0 = _mm512_loadu_si512((element_output_type*)src_out+ofm2);\
      pixel_1 = _mm512_loadu_si512(((element_output_type*)src_out+handle->ofmblock+ofm2));\
      ofms_lo = _mm512_permutex2var_epi16(pixel_0, idx_lo, pixel_1);\
      ofms_hi = _mm512_permutex2var_epi16(pixel_0, idx_hi, pixel_1);\
      _mm512_storeu_si512(tr_out+ofm2*2, ofms_lo);\
      _mm512_storeu_si512((element_output_type*)tr_out+32+ofm2*2, ofms_hi);\
    }\
    src_out += 2*handle->ofmblock;\
    tr_out += 2*handle->ofmblock;\
  }\
  if (half_pixel_pair == 1) {\
    for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2+=32) {\
      pixel_0 = _mm512_loadu_si512((element_output_type*)src_out+ofm2);\
      pixel_1 = _mm512_setzero_si512();\
      ofms_lo = _mm512_permutex2var_epi16(pixel_0, idx_lo, pixel_1);\
      ofms_hi = _mm512_permutex2var_epi16(pixel_0, idx_hi, pixel_1);\
      _mm512_storeu_si512(tr_out+ofm2*2, ofms_lo);\
      _mm512_storeu_si512((element_output_type*)tr_out+32+ofm2*2, ofms_hi);\
    }\
  }\
  for (oi = ((handle->compute_pixels+1)/2)*2; oi < handle->output_pixels; oi+=2) {\
    for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2+=32) {\
      tr_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(5, tr_output, img, ofm1, oi/2, ofm2, 0, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2);\
      _mm512_storeu_si512((element_output_type*)tr_out, zero_reg);\
      _mm512_storeu_si512((element_output_type*)tr_out+32, zero_reg);\
    }\
  }\
} while(0)

#define TRANS_OUTPUT_W_TO_VNNI_FORMAT(img, ofm1, oj, H) do {\
  int h, w_pixel_pair, w_full_pixel_pairs = handle->ofwp/2;\
  for (h = 0; h < H; h++) {\
    src_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj + h, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock);\
    tr_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(6, tr_output_2, img, 0, h, 0, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp_extended/2, handle->ofmblock, 2);\
    for (w_pixel_pair = 0; w_pixel_pair < w_full_pixel_pairs; w_pixel_pair++) {\
      for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2+=32) {\
        pixel_0 = _mm512_loadu_si512((element_output_type*)src_out+ofm2);\
        pixel_1 = _mm512_loadu_si512(((element_output_type*)src_out+handle->ofmblock+ofm2));\
        ofms_lo = _mm512_permutex2var_epi16(pixel_0, idx_lo, pixel_1);\
        ofms_hi = _mm512_permutex2var_epi16(pixel_0, idx_hi, pixel_1);\
        _mm512_storeu_si512(tr_out+ofm2*2, ofms_lo);\
        _mm512_storeu_si512((element_output_type*)tr_out+32+ofm2*2, ofms_hi);\
      }\
      src_out += 2*handle->ofmblock;\
      tr_out += 2*handle->ofmblock;\
    }\
  }\
} while(0)
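/* Note: the two macros above repack the bf16 gradient-output tensor into the VNNI layout
   consumed by the bf16 (batch-reduce) GEMMs: two consecutive pixels are interleaved per
   output feature, i.e. tr_output[ofm1][pixel/2][ofm2][pixel%2], and any tail pixels beyond
   compute_pixels up to output_pixels are zero-filled. A scalar sketch of the same reshuffle
   (illustration only, never compiled): */
#if 0
for (oi = 0; oi < handle->compute_pixels; oi++) {
  for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2++) {
    LIBXSMM_VLA_ACCESS(5, tr_output, img, ofm1, oi/2, ofm2, oi%2, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2) =
      LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oi/handle->ofwp, oi%handle->ofwp, ofm2, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock);
  }
}
#endif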
int img, my_img_start, my_img_end, ofmb, ifmb, ofm1, ifm1, ifm2, ofm2, oj, oi, ii, ij, kj, ki, j_br, img_br, i, j, img_block_size = 1, my_ofm_start, my_ofm_end, my_ifm_start, my_ifm_end, block_ofm, block_ifm, pix;
/* computing first logical thread */
const int ltid = tid - start_thread;

element_output_type *const out = (element_output_type*)handle->grad_output->data + ((size_t)handle->desc.pad_h_out * handle->ofwp + handle->desc.pad_w_out) * handle->ofmblock;
LIBXSMM_VLA_DECL(5, const element_output_type, output, (const element_output_type*)out, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock);
LIBXSMM_VLA_DECL(5, const element_input_type, input, (const element_input_type*)handle->reg_input->data, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock);

element_filter_type *weight_ptr = (element_filter_type*)((char*)handle->scratch + handle->upd_filter_scratch_offset) + ltid * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S;

element_filter_type *filter_dst_ptr = (handle->weight_copies > 1) ? (element_filter_type*)weight_ptr : (element_filter_type*)handle->grad_filter->data;
LIBXSMM_VLA_DECL(7, element_filter_type, weight_dst, (element_filter_type*)filter_dst_ptr, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock/2, handle->ofmblock, 2);

/* This intermediate tensor is used when pixels are NOT fully accumulated */
float *weight_ptr_f32 = (float*) ((char*)handle->scratch + handle->upd_lp_filter_full_scratch_offset) + ltid * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S;

LIBXSMM_VLA_DECL(6, float, weight_private_f32, (float*)weight_ptr_f32, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
/* Accumulation scratch is used when pixels are fully accumulated */
element_filter_type *filter_scratch = (element_filter_type*)((char*)handle->scratch + handle->upd_lp_filter_full_scratch_offset) + ltid * handle->ofmblock * handle->ifmblock * 2;

LIBXSMM_VLA_DECL(2, float, filter_tmp, (float*)filter_scratch, handle->ofmblock);

element_input_type *scratch_tr_input = (element_input_type*)((char*)handle->scratch + handle->upd_lp_input_full_scratch_offset);
element_input_type *zero_ptr_in;
LIBXSMM_VLA_DECL(4, element_input_type, tr_input, (element_input_type*) scratch_tr_input, handle->blocksifm, handle->ifmblock, handle->input_pixels);
LIBXSMM_VLA_DECL(5, element_input_type, tr_input_2, (element_input_type*) scratch_tr_input, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended);

element_output_type *scratch_tr_output = (element_output_type*)((char*)handle->scratch + handle->upd_lp_output_full_scratch_offset);
LIBXSMM_VLA_DECL(5, element_output_type, tr_output, (element_output_type*) scratch_tr_output, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2);
LIBXSMM_VLA_DECL(6, element_output_type, tr_output_2, (element_output_type*) scratch_tr_output, handle->blocksofm, handle->ofhp, handle->ofwp_extended/2, handle->ofmblock, 2);
#if 0
element_output_type *out_ptr = (element_output_type*)handle->grad_output->data + ((size_t)handle->desc.pad_h_out * handle->ofwp + handle->desc.pad_w_out) * handle->ofmblock;
element_output_type *zero_ptr_out;
#endif

/* transpose, copy and reduce work-related variables */
const int reduce_work = (handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S)/16;
const int reduce_chunksize = (reduce_work % handle->desc.threads == 0) ? (reduce_work / handle->desc.threads) : (reduce_work / handle->desc.threads) + 1;
const int reduce_thr_begin = (ltid * reduce_chunksize < reduce_work) ? (ltid * reduce_chunksize) : reduce_work;
const int reduce_thr_end = ((ltid + 1) * reduce_chunksize < reduce_work) ? ((ltid + 1) * reduce_chunksize) : reduce_work;

const float beta = (handle->use_intermediate_f32_wt_tensor ? 1.f : 0.f);
float *dst_ptr;
gemm_br_function br_gemm_kernel = 0;

/* These are used for the vnni reformatting of the f32 output */
__m512i c01 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
const __m512i perm_index = LIBXSMM_INTRINSICS_MM512_SET_EPI16(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8, 23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0);

/* Related to the output transpose */
int n_full_pixel_pairs = handle->compute_pixels/2, half_pixel_pair = handle->compute_pixels%2, pixel_pair;
element_output_type *tr_out, *src_out;
const __m512i selector = LIBXSMM_INTRINSICS_MM512_SET_EPI16(32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0);
const __m512i offsets_lo = LIBXSMM_INTRINSICS_MM512_SET_EPI16(15, 15, 14, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0);
const __m512i offsets_hi = LIBXSMM_INTRINSICS_MM512_SET_EPI16(31, 31, 30, 30, 29, 29, 28, 28, 27, 27, 26, 26, 25, 25, 24, 24, 23, 23, 22, 22, 21, 21, 20, 20, 19, 19, 18, 18, 17, 17, 16, 16);
const __m512i idx_lo = _mm512_or_epi32(selector, offsets_lo);
const __m512i idx_hi = _mm512_or_epi32(selector, offsets_hi);
__m512i pixel_0, pixel_1, ofms_lo, ofms_hi;
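/* idx_lo/idx_hi drive the pixel-pair interleave used by the transpose macros above:
   "selector" sets bit 5 (value 32) in every odd word so that _mm512_permutex2var_epi16
   picks that word from its second operand (pixel_1), while "offsets_lo"/"offsets_hi"
   supply each word index 0..15 (resp. 16..31) twice in a row. The resulting index vector
   produces { p0[k], p1[k], p0[k+1], p1[k+1], ... }, i.e. the ofm values of two adjacent
   pixels end up interleaved, which is exactly the (ofm, 2)-pair VNNI layout the bf16
   GEMM kernels expect. */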
/* Batch reduce related variables */
const element_output_type *A_ptrs[1024];
const element_input_type *B_ptrs[1024];
unsigned long long n_blocks;

libxsmm_blasint LDA = handle->ofmblock;
libxsmm_blasint LDB = handle->input_pixels;
libxsmm_blasint LDC = handle->ofmblock;
int prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
int l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');

const int img_work = handle->desc.N;
const int img_chunksize = (img_work % handle->desc.threads == 0) ? (img_work / handle->desc.threads) : (img_work / handle->desc.threads) + 1;
my_img_start = (ltid * img_chunksize < img_work) ? (ltid * img_chunksize) : img_work;
my_img_end = ((ltid + 1) * img_chunksize < img_work) ? ((ltid + 1) * img_chunksize) : img_work;

libxsmm_barrier_init(handle->barrier, ltid);
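/* Work partitioning over the minibatch: every thread gets a contiguous chunk of
   ceil(N/threads) images, clamped to N. Illustration with hypothetical numbers (not taken
   from the handle): for N = 26 and threads = 8 the chunk is 4, so thread 0 covers images
   [0,4), thread 6 covers [24,26), and thread 7 gets the empty range [26,26). */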
if (handle->upd_linearized_pixels == 1) {
  /* First transpose input and output */
  if (handle->use_hybrid_imgofm_parallelization == 1) {
    if (handle->upd_pack_input_upfront == 0) {
      for (img = my_img_start; img < my_img_end; img++) {
#if 0
        zero_ptr_in = (element_input_type*) &LIBXSMM_VLA_ACCESS(4, tr_input, img, 0, 0, 0, handle->blocksifm, handle->ifmblock, handle->input_pixels);
        memset(zero_ptr_in, 0, handle->desc.C * handle->input_pixels * sizeof(element_input_type));
#endif
        for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1++) {
          transpose_input_pixels_bf16( (element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, 0, 0, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock),
              (element_input_type*)&LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, 0, 0, handle->blocksifm, handle->ifmblock, handle->input_pixels),
              handle->ifmblock, handle->ifhp*handle->ifwp, handle->ifmblock, handle->input_pixels );
#if 0
          for (ij = 0; ij < handle->ifhp; ij++) {
            for (ii = 0; ii < handle->ifwp; ii++) {
              for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) {
                LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, ifm2, ij * handle->ifwp + ii, handle->blocksifm, handle->ifmblock, handle->input_pixels) =
                  LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij, ii, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock);
              }
            }
          }
#endif
        }
      }
    } else {
      for (img = my_img_start; img < my_img_end; img++) {
#if 0
        zero_ptr_in = (element_input_type*) &LIBXSMM_VLA_ACCESS(4, tr_input, img, 0, 0, 0, handle->blocksifm, handle->ifmblock, handle->input_pixels);
        memset(zero_ptr_in, 0, handle->desc.C * handle->input_pixels * sizeof(element_input_type));
#endif
        for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1++) {
          for (ij = 0; ij < handle->ifhp/handle->desc.u; ij++) {
            transpose_input_pixels_bf16( (element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij*handle->desc.u, 0, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock),
                (element_input_type*)&LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, 0, ij * (handle->ifwp/handle->desc.v), handle->blocksifm, handle->ifmblock, handle->input_pixels),
                handle->ifmblock, handle->ifwp/handle->desc.v, 2*handle->ifmblock, handle->input_pixels );
#if 0
            for (ii = 0; ii < handle->ifwp/handle->desc.v; ii++) {
              for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) {
                LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, ifm2, ij * (handle->ifwp/handle->desc.v) + ii, handle->blocksifm, handle->ifmblock, handle->input_pixels) =
                  LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij*handle->desc.u, ii*handle->desc.v, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock);
              }
            }
#endif
          }
        }
      }
    }

    for (img = my_img_start; img < my_img_end; img++) {
      for (ofm1 = 0; ofm1 < handle->blocksofm; ofm1++) {
        TRANS_OUTPUT_TO_VNNI_FORMAT(img, ofm1);
      }
    }
  }
#if 0
  for (img = my_img_start; img < my_img_end; img++) {
    zero_ptr_out = (element_output_type*) &LIBXSMM_VLA_ACCESS(5, tr_output, img, 0, 0, 0, 0, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2);
    memset(zero_ptr_out, 0, handle->desc.K * handle->output_pixels * sizeof(element_output_type));
    for (ofm1 = 0; ofm1 < handle->blocksofm; ofm1++) {
      for (oi = 0; oi < handle->n_used_pixels; oi++) {
        for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2++) {
          LIBXSMM_VLA_ACCESS(5, tr_output, img, ofm1, oi/2, ofm2, oi%2, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2) =
            *((element_output_type*)out_ptr + img * handle->blocksofm * handle->ofwp * handle->ofhp * handle->ofmblock + ofm1 * handle->ofwp * handle->ofhp * handle->ofmblock + oi * handle->ofmblock + ofm2);
        }
      }
    }
  }
#endif
} else {
  if (handle->upd_trans_w_only == 0) {
    if (handle->on_the_fly_input_packing == 0) {
      for (img = my_img_start; img < my_img_end; img++) {
        zero_ptr_in = (element_input_type*) &LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, 0, 0, 0, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended);
        memset(zero_ptr_in, 0, handle->desc.C * handle->ifhp * handle->ifwp_extended * sizeof(element_input_type));
        for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1++) {
          for (ij = 0; ij < handle->ifhp; ij++) {
            for (ii = 0; ii < handle->ifwp; ii++) {
              for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) {
                LIBXSMM_VLA_ACCESS(5, tr_input_2, img, ifm1, ifm2, ij, ii, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended) =
                  LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij, ii, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock);
              }
            }
          }
        }
      }
    }
    for (img = my_img_start; img < my_img_end; img++) {
      for (ofm1 = 0; ofm1 < handle->blocksofm; ofm1++) {
#if 0
        TRANS_OUTPUT_W_TO_VNNI_FORMAT(img, ofm1, 0, handle->ofh);
#else
        for (oj = 0; oj < handle->ofh; oj++) {
          for (oi = 0; oi < handle->ofw; oi++) {
            for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2++) {
              LIBXSMM_VLA_ACCESS(6, tr_output_2, img, ofm1, oj, oi/2, ofm2, oi%2, handle->blocksofm, handle->ofhp, handle->ofwp_extended/2, handle->ofmblock, 2) =
                LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, ofm2, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock);
            }
          }
        }
        if (handle->ofw % 2 == 1) {
          for (oj = 0; oj < handle->ofh; oj++) {
            for (ofm2 = 0; ofm2 < handle->ofmblock; ofm2++) {
              LIBXSMM_VLA_ACCESS(6, tr_output_2, img, ofm1, oj, handle->ofw/2, ofm2, handle->ofw%2, handle->blocksofm, handle->ofhp, handle->ofwp_extended/2, handle->ofmblock, 2) = (element_output_type)0;
            }
          }
        }
#endif
      }
    }
  }
}

/* Make sure we initialize intermediate weights to zero */
if (handle->use_intermediate_f32_wt_tensor == 1 && handle->use_hybrid_imgofm_parallelization == 0) {
  memset(weight_ptr_f32, 0, handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S * sizeof(float));
}

if (handle->upd_linearized_pixels == 0) {
  if (handle->upd_trans_w_only == 1) {
    LDA = handle->ofmblock;
    LDB = handle->ifhp*handle->ifwp_extended;
    LDC = handle->ofmblock;
    prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
    l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
    n_blocks = handle->batchreduce_h_pixels;
    br_gemm_kernel = libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->ifmblock, handle->ofw, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode);
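    /* Note on the kernel above: the address-variant batch-reduce GEMM accumulates
       n_blocks = batchreduce_h_pixels products C += A[i] * B[i] into one
       ofmblock x ifmblock f32 tile, where A[i] is one transposed, VNNI-packed output row
       (LDA = ofmblock) and B[i] the matching transposed input row
       (LDB = ifhp*ifwp_extended). beta is 1 when results are accumulated across oj/img
       iterations into the intermediate f32 weight tensor, and 0 when a single pass fully
       accumulates into the scratch tile. */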
    for (img = my_img_start; img < my_img_end; img++) {
      for (ofmb = 0; ofmb < handle->blocksofm; ofmb += handle->block_upd_ofm) {
        for (oj = 0; oj < handle->ofh; oj += handle->batchreduce_h_pixels) {
          for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_upd_ifm) {
            for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_upd_ofm, handle->blocksofm); ofm1++ ) {
              /* Transpose output block */
              TRANS_OUTPUT_W_TO_VNNI_FORMAT(img, ofm1, oj, handle->batchreduce_h_pixels);
              for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_upd_ifm, handle->blocksifm); ifm1++) {
                /* Transpose input block */
                for (j = 0; j < handle->batchreduce_h_pixels; j++) {
                  transpose_input_pixels_bf16( (element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, oj+j, 0, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock),
                      (element_input_type*)&LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, 0, j, 0, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended),
                      handle->ifmblock, handle->ifwp_extended, handle->ifmblock, handle->ifhp*handle->ifwp_extended );
                }
                for (kj = 0; kj < handle->desc.R; ++kj) {
                  for (ki = 0; ki < handle->desc.S; ++ki) {

                    /* Determine if destination is the accumulation scratch or the intermediate fp32 weight tensor */
                    if (handle->use_intermediate_f32_wt_tensor == 1) {
                      dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(6, weight_private_f32, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
                    } else {
                      dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(2, filter_tmp, 0, 0, handle->ofmblock);
                    }

                    for (j_br = 0; j_br < handle->batchreduce_h_pixels; j_br++) {
                      A_ptrs[j_br] = (element_output_type*) &LIBXSMM_VLA_ACCESS(6, tr_output_2, img, 0, j_br, 0, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp_extended/2, handle->ofmblock, 2);
                      B_ptrs[j_br] = (element_input_type*) &LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, 0, j_br, 0, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended);
                    }

                    br_gemm_kernel(A_ptrs, B_ptrs, dst_ptr, &n_blocks);

                    /* Convert fully accumulated buffer to bf16 weight buffer in case full accumulation has happened */
                    if ((oj + handle->batchreduce_h_pixels >= handle->ofh) && (img == my_img_end - 1)) {
                      LIBXSMM_VLA_DECL(2, float, filter_acc_buffer, (float*)dst_ptr, handle->ofmblock);
                      for (ij = 0; ij < handle->ifmblock; ij+=2) {
                        for (ii = 0; ii < handle->ofmblock; ii+=16) {
                          c01 = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij+1, ii, handle->ofmblock)),
                                                                        LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij, ii, handle->ofmblock)) );
                          _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(7, weight_dst, ofm1, ifm1, kj, ki, ij/2, ii, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock/2, handle->ofmblock, 2), _mm512_permutexvar_epi16(perm_index, c01));
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
  } else {
    int fast_trans = (handle->ofw == 112 && handle->desc.v == 2 && handle->ifmblock == 4 && handle->batchreduce_h_pixels == 1) ? 1 : 0;
    const __m512i skipper = LIBXSMM_INTRINSICS_MM512_SET_EPI16(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 27, 19, 11, 3, 26, 18, 10, 2, 25, 17, 9, 1, 24, 16, 8, 0);
    __m512i p0, p1, p2, p3;
    __m256i _p0, _p1, _p2, _p3;
    __m256i r0 = _mm256_undefined_si256();
    __m256i r1 = _mm256_undefined_si256();
    __m256i r2 = _mm256_undefined_si256();
    __m256i r3 = _mm256_undefined_si256();
    LDA = handle->ofmblock;
    LDB = handle->ifhp*handle->ifwp_extended;
    LDC = handle->ofmblock;
    prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
    l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');
    n_blocks = handle->batchreduce_h_pixels;
    /* Handle case when ofw is odd number... */
    if (handle->ofw % 2 == 1) {
      br_gemm_kernel = libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->ifmblock, handle->ofw+1, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode);
    } else {
      br_gemm_kernel = libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->ifmblock, handle->ofw, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode);
    }
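    /* Presumably the odd-ofw case can be dispatched with an even k-dimension (ofw+1)
       because the transposed output pads the trailing odd pixel column with zeros (see the
       explicit zero fill in the output transpose loop earlier); a zero output column
       contributes nothing to the accumulation, so the padded pixel is harmless. */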
    for (img = my_img_start; img < my_img_end; img++) {
      for (ofmb = 0; ofmb < handle->blocksofm; ofmb += handle->block_upd_ofm) {
        for (oj = 0; oj < handle->ofh; oj += handle->batchreduce_h_pixels) {
          for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_upd_ifm) {
            for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_upd_ofm, handle->blocksofm); ofm1++ ) {
              for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_upd_ifm, handle->blocksifm); ifm1++) {
                for (kj = 0; kj < handle->desc.R; ++kj) {
                  for (ki = 0; ki < handle->desc.S; ++ki) {

                    /* Determine if destination is the accumulation scratch or the intermediate fp32 weight tensor */
                    if (handle->use_intermediate_f32_wt_tensor == 1) {
                      dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(6, weight_private_f32, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
                    } else {
                      dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(2, filter_tmp, 0, 0, handle->ofmblock);
                    }

                    /* Copy the input in such a way that we ignore "w-pixels" based on ki value */
                    if (handle->on_the_fly_input_packing == 1) {
                      if (fast_trans == 1) {
                        for (ii = 0; ii < handle->ofw*2; ii+=32) {
                          p0 = _mm512_loadu_si512((element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, oj*handle->desc.u+kj, ii+ki, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock));
                          p0 = _mm512_permutexvar_epi16(skipper, p0);
                          _p0 = LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(p0, 0);
                          p1 = _mm512_loadu_si512((element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, oj*handle->desc.u+kj, ii+8+ki, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock));
                          p1 = _mm512_permutexvar_epi16(skipper, p1);
                          _p1 = LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(p1, 0);
                          p2 = _mm512_loadu_si512((element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, oj*handle->desc.u+kj, ii+16+ki, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock));
                          p2 = _mm512_permutexvar_epi16(skipper, p2);
                          _p2 = LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(p2, 0);
                          p3 = _mm512_loadu_si512((element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, oj*handle->desc.u+kj, ii+24+ki, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock));
                          p3 = _mm512_permutexvar_epi16(skipper, p3);
                          _p3 = LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(p3, 0);

                          r0 = _mm256_insert_epi64(r0, _mm256_extract_epi64(_p0, 0), 0);
                          r0 = _mm256_insert_epi64(r0, _mm256_extract_epi64(_p1, 0), 1);
                          r0 = _mm256_insert_epi64(r0, _mm256_extract_epi64(_p2, 0), 2);
                          r0 = _mm256_insert_epi64(r0, _mm256_extract_epi64(_p3, 0), 3);
                          _mm256_storeu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, 0, 0, ii/2, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended), r0);

                          r1 = _mm256_insert_epi64(r1, _mm256_extract_epi64(_p0, 1), 0);
                          r1 = _mm256_insert_epi64(r1, _mm256_extract_epi64(_p1, 1), 1);
                          r1 = _mm256_insert_epi64(r1, _mm256_extract_epi64(_p2, 1), 2);
                          r1 = _mm256_insert_epi64(r1, _mm256_extract_epi64(_p3, 1), 3);
                          _mm256_storeu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, 1, 0, ii/2, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended), r1);
                          r2 = _mm256_insert_epi64(r2, _mm256_extract_epi64(_p0, 2), 0);
                          r2 = _mm256_insert_epi64(r2, _mm256_extract_epi64(_p1, 2), 1);
                          r2 = _mm256_insert_epi64(r2, _mm256_extract_epi64(_p2, 2), 2);
                          r2 = _mm256_insert_epi64(r2, _mm256_extract_epi64(_p3, 2), 3);
                          _mm256_storeu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, 2, 0, ii/2, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended), r2);

                          r3 = _mm256_insert_epi64(r3, _mm256_extract_epi64(_p0, 3), 0);
                          r3 = _mm256_insert_epi64(r3, _mm256_extract_epi64(_p1, 3), 1);
                          r3 = _mm256_insert_epi64(r3, _mm256_extract_epi64(_p2, 3), 2);
                          r3 = _mm256_insert_epi64(r3, _mm256_extract_epi64(_p3, 3), 3);
                          _mm256_storeu_si256((__m256i*)&LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, 3, 0, ii/2, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended), r3);

                        }
                      } else {
                        for (ij = 0; ij < handle->batchreduce_h_pixels; ij++) {
                          for (ii = 0; ii < handle->ofw; ii++) {
                            for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) {
                              LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, ifm2, ij, ii, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended) =
                                LIBXSMM_VLA_ACCESS(5, input, img, ifm1, (oj+ij)*handle->desc.u+kj, ii*handle->desc.v+ki, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock);
                            }
                          }
                        }
                      }
                    }

                    for (j_br = 0; j_br < handle->batchreduce_h_pixels; j_br++) {
                      A_ptrs[j_br] = (element_output_type*) &LIBXSMM_VLA_ACCESS(6, tr_output_2, img, ofm1, oj+j_br, 0, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp_extended/2, handle->ofmblock, 2);
                      B_ptrs[j_br] = (element_input_type*) &LIBXSMM_VLA_ACCESS(5, tr_input_2, img, 0, 0, j_br, 0, handle->blocksifm, handle->ifmblock, handle->ifhp, handle->ifwp_extended);
                    }

                    br_gemm_kernel(A_ptrs, B_ptrs, dst_ptr, &n_blocks);

                    /* Convert fully accumulated buffer to bf16 weight buffer in case full accumulation has happened */
                    if ((oj + handle->batchreduce_h_pixels >= handle->ofh) && (img == my_img_end - 1)) {
                      LIBXSMM_VLA_DECL(2, float, filter_acc_buffer, (float*)dst_ptr, handle->ofmblock);
                      for (ij = 0; ij < handle->ifmblock; ij+=2) {
                        for (ii = 0; ii < handle->ofmblock; ii+=16) {
                          c01 = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij+1, ii, handle->ofmblock)),
                                                                        LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij, ii, handle->ofmblock)));
                          _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(7, weight_dst, ofm1, ifm1, kj, ki, ij/2, ii, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock/2, handle->ofmblock, 2), _mm512_permutexvar_epi16(perm_index, c01));
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
  }
} else {
  LDA = handle->ofmblock;
  LDB = handle->input_pixels;
  LDC = handle->ofmblock;
  prefetch_mode = libxsmm_get_gemm_prefetch(LIBXSMM_GEMM_PREFETCH_NONE);
  l_flags = LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N');

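  /* Two strategies follow for the linearized-pixel path. Roughly: with hybrid img/FM
     parallelization the threads are split into handle->weight_copies groups ("tiles"),
     each tile owns a slice of the minibatch and produces its own partial filter copy,
     while the threads inside a tile divide the ifm blocks (or the ofm blocks / R rows
     for the hard-coded N=27 shapes); the copies are summed after the final barrier.
     Without it, every thread processes its own images against all feature blocks and
     accumulates into its thread-private filter copy. */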
  if (handle->use_hybrid_imgofm_parallelization == 1) {
    /* Here we are using batch-reduce kernel and hybrid minibatch/FM parallelization */
    /* FIXME: Hardcoded logic for N=27 */
    int group_size = (handle->desc.threads == 27 && handle->desc.N == 27 && handle->ofw == 14 && handle->desc.R == 1 && handle->desc.u == 1 && ltid >= 24) ? 3 : LIBXSMM_UPDIV(handle->desc.threads, handle->weight_copies);
    int tile_id = ltid / LIBXSMM_UPDIV(handle->desc.threads, handle->weight_copies);
    int tiles = handle->weight_copies;
    int img_per_tile = LIBXSMM_UPDIV(handle->desc.N, tiles);
    int my_in_tile_id = ltid % group_size;
    int ifms_per_thread = LIBXSMM_UPDIV(handle->blocksifm, group_size);
    int ofms_per_thread = LIBXSMM_UPDIV(handle->blocksofm, group_size);
    int my_R_start = 0;
    int my_R_end = handle->desc.R;
    element_filter_type *weight_ptr_group = (handle->weight_copies > 1) ? (element_filter_type*)((char*)handle->scratch + handle->upd_filter_scratch_offset) + tile_id * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S : (element_filter_type*)handle->grad_filter->data;
    LIBXSMM_VLA_DECL(7, element_filter_type, weight_private_group, (element_filter_type*)weight_ptr_group, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock/2, handle->ofmblock, 2);
    /* This intermediate tensor is used when pixels are NOT fully accumulated */
    float *weight_tile_ptr_f32 = (float*)((char*)handle->scratch + handle->upd_lp_filter_full_scratch_offset) + tile_id * handle->desc.C * handle->desc.K * handle->desc.R * handle->desc.S;
    LIBXSMM_VLA_DECL(6, float, weight_private_tile_f32, (float*)weight_tile_ptr_f32, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);

    my_img_start = LIBXSMM_MIN(tile_id * img_per_tile, handle->desc.N);
    my_img_end = LIBXSMM_MIN((tile_id+1) * img_per_tile, handle->desc.N);
    my_ifm_start = LIBXSMM_MIN(my_in_tile_id * ifms_per_thread, handle->blocksifm );
    my_ifm_end = LIBXSMM_MIN((my_in_tile_id+1) * ifms_per_thread, handle->blocksifm );
    my_ofm_start = 0;
    my_ofm_end = handle->blocksofm;
    /* FIXME: Hardcoded logic for N=27 */
    if (handle->desc.threads == 27 && handle->desc.N == 27 && handle->desc.C == 256 && handle->desc.K == 1024 && handle->ofh == 14 && handle->desc.u == 1) {
      my_ofm_start = LIBXSMM_MIN(my_in_tile_id * ofms_per_thread, handle->blocksofm);
      my_ofm_end = LIBXSMM_MIN((my_in_tile_id+1) * ofms_per_thread, handle->blocksofm);
      my_ifm_start = 0;
      my_ifm_end = handle->blocksifm;
    }
    if (handle->desc.threads == 27 && handle->desc.N == 27 && handle->desc.R == 3 && handle->desc.S == 3 && handle->ofh == 14) {
      int r_per_tile = LIBXSMM_UPDIV(handle->desc.R, group_size);
      my_ifm_start = 0;
      my_ifm_end = handle->blocksifm;
      my_ofm_start = 0;
      my_ofm_end = handle->blocksofm;
      my_R_start = LIBXSMM_MIN(my_in_tile_id * r_per_tile, handle->desc.R);
      my_R_end = LIBXSMM_MIN((my_in_tile_id+1) * r_per_tile, handle->desc.R);
    }
    block_ofm = my_ofm_end-my_ofm_start+1;
    block_ifm = my_ifm_end-my_ifm_start+1;
    img_block_size = my_img_end - my_img_start;

    br_gemm_kernel = libxsmm_bsmmdispatch_reducebatch_addr(handle->ofmblock, handle->ifmblock, handle->pixel_blocking, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode);
    n_blocks = img_block_size;

    /* Make sure we initialize intermediate weights to zero */
    if (handle->use_intermediate_f32_wt_tensor == 1) {
      for (ofm1 = my_ofm_start; ofm1 < my_ofm_end; ofm1++ ) {
        for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) {
          for (kj = my_R_start; kj < my_R_end; ++kj) {
            memset((float*)&LIBXSMM_VLA_ACCESS(6, weight_private_tile_f32, ofm1, ifm1, kj, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock), 0, handle->ofmblock * handle->ifmblock * handle->desc.S * sizeof(float));
          }
        }
      }
    }

    libxsmm_barrier_wait(handle->barrier, ltid);

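    /* Main accumulation loop of the hybrid path: for every (ofm1, ifm1, kj, ki) the kernel
       batch-reduces over the img_block_size images of this tile. A_ptrs[i] points at the
       VNNI-packed transposed output starting at pixel "pix", and B_ptrs[i] at the
       transposed input shifted by pix + kj * ifwp + ki; because pixels are linearized,
       that shift selects the input samples aligned with filter tap (kj, ki), so no
       per-tap repacking is needed. */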
    for (img = my_img_start; img < my_img_end; img += img_block_size) {
      for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += block_ofm) {
        for (pix = 0; pix < handle->n_used_pixels; pix += handle->pixel_blocking) {
          for (ifmb = my_ifm_start; ifmb < my_ifm_end; ifmb += block_ifm) {
            for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+block_ofm, my_ofm_end); ofm1++ ) {
              for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+block_ifm, my_ifm_end); ifm1++) {
                for (kj = my_R_start; kj < my_R_end; ++kj) {
                  for (ki = 0; ki < handle->desc.S; ++ki) {

                    /* Determine if destination is the accumulation scratch or the intermediate fp32 weight tensor */
                    if (handle->use_intermediate_f32_wt_tensor == 1) {
                      dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(6, weight_private_tile_f32, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
                    } else {
                      dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(2, filter_tmp, 0, 0, handle->ofmblock);
                    }

                    for (img_br = 0; img_br < img_block_size; img_br++) {
                      A_ptrs[img_br] = &LIBXSMM_VLA_ACCESS(5, tr_output, img + img_br, ofm1, pix/2, 0, 0, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2);
                      B_ptrs[img_br] = &LIBXSMM_VLA_ACCESS(4, tr_input, img + img_br, ifm1, 0, pix + kj * handle->ifwp + ki, handle->blocksifm, handle->ifmblock, handle->input_pixels);
                    }

                    br_gemm_kernel(A_ptrs, B_ptrs, dst_ptr, &n_blocks);

                    /* Convert fully accumulated buffer to bf16 weight buffer in case full accumulation has happened */
                    if ((pix + handle->pixel_blocking >= handle->n_used_pixels) && (img == my_img_end - img_block_size)) {
                      LIBXSMM_VLA_DECL(2, float, filter_acc_buffer, (float*)dst_ptr, handle->ofmblock);
                      for (ij = 0; ij < handle->ifmblock; ij+=2) {
                        for (ii = 0; ii < handle->ofmblock; ii+=16) {
                          c01 = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij+1, ii, handle->ofmblock)),
                                                                        LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij, ii, handle->ofmblock)) );
                          _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(7, weight_private_group, ofm1, ifm1, kj, ki, ij/2, ii, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock/2, handle->ofmblock, 2), _mm512_permutexvar_epi16(perm_index, c01));
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    }

  } else {
    gemm_function gemm_kernel = libxsmm_bsmmdispatch(handle->ofmblock, handle->ifmblock, handle->pixel_blocking, &LDA, &LDB, &LDC, NULL, &beta, &l_flags, &prefetch_mode);

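    /* Non-hybrid linearized path: a single small GEMM per (kj, ki) tap multiplies the
       VNNI-packed transposed output block (ofmblock x pixel_blocking) with the transposed
       input shifted by pix + kj * ifwp + ki (pixel_blocking x ifmblock), accumulating into
       the intermediate f32 weight tensor (beta = 1) or, if all pixels fit in one pass,
       straight into the f32 scratch tile that is converted to bf16 right below. */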
    for (img = my_img_start; img < my_img_end; img++) {
      for (ofmb = 0; ofmb < handle->blocksofm; ofmb += handle->block_upd_ofm) {
        for (pix = 0; pix < handle->n_used_pixels; pix += handle->pixel_blocking) {
          for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_upd_ifm) {
            for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_upd_ofm, handle->blocksofm); ofm1++ ) {
              /* Transpose output block */
              if (pix == 0 && ifmb == 0) {
                TRANS_OUTPUT_TO_VNNI_FORMAT(img, ofm1);
              }
              for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_upd_ifm, handle->blocksifm); ifm1++) {
                /* Transpose input block */
                if (pix == 0 && ofmb == 0 && ofm1 == 0) {
                  if (handle->upd_pack_input_upfront == 0) {
                    transpose_input_pixels_bf16( (element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, 0, 0, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock),
                        (element_input_type*)&LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, 0, 0, handle->blocksifm, handle->ifmblock, handle->input_pixels),
                        handle->ifmblock, handle->ifhp*handle->ifwp, handle->ifmblock, handle->input_pixels );
                  } else {
                    for (ij = 0; ij < handle->ifhp/handle->desc.u; ij++) {
                      transpose_input_pixels_bf16( (element_input_type*)&LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij*handle->desc.u, 0, 0, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock),
                          (element_input_type*)&LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, 0, ij * (handle->ifwp/handle->desc.v), handle->blocksifm, handle->ifmblock, handle->input_pixels),
                          handle->ifmblock, handle->ifwp/handle->desc.v, 2*handle->ifmblock, handle->input_pixels );
                    }
                  }
                }
                for (kj = 0; kj < handle->desc.R; ++kj) {
                  for (ki = 0; ki < handle->desc.S; ++ki) {

                    /* Determine if destination is the accumulation scratch or the intermediate fp32 weight tensor */
                    if (handle->use_intermediate_f32_wt_tensor == 1) {
                      dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(6, weight_private_f32, ofm1, ifm1, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
                    } else {
                      dst_ptr = (float*)&LIBXSMM_VLA_ACCESS(2, filter_tmp, 0, 0, handle->ofmblock);
                    }
                    gemm_kernel( &LIBXSMM_VLA_ACCESS(5, tr_output, img, ofm1, pix/2, 0, 0, handle->blocksofm, handle->output_pixels/2, handle->ofmblock, 2),
                                 &LIBXSMM_VLA_ACCESS(4, tr_input, img, ifm1, 0, pix + kj * handle->ifwp + ki, handle->blocksifm, handle->ifmblock, handle->input_pixels),
                                 dst_ptr);

                    /* Convert fully accumulated buffer to bf16 weight buffer in case full accumulation has happened */
                    if ((pix + handle->pixel_blocking >= handle->n_used_pixels) && (img == my_img_end - 1)) {
                      LIBXSMM_VLA_DECL(2, float, filter_acc_buffer, (float*)dst_ptr, handle->ofmblock);
                      for (ij = 0; ij < handle->ifmblock; ij+=2) {
                        for (ii = 0; ii < handle->ofmblock; ii+=16) {
                          c01 = LIBXSMM_INTRINSISCS_MM512_CVTNE2PS_PBH( LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij+1, ii, handle->ofmblock)),
                                                                        LIBXSMM_INTRINSICS_MM512_LOAD_PS(&LIBXSMM_VLA_ACCESS(2, filter_acc_buffer, ij, ii, handle->ofmblock)) );
                          _mm512_storeu_si512(&LIBXSMM_VLA_ACCESS(7, weight_dst, ofm1, ifm1, kj, ki, ij/2, ii, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock/2, handle->ofmblock, 2), _mm512_permutexvar_epi16(perm_index, c01));
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
  }
}

libxsmm_barrier_wait(handle->barrier, ltid);

if (handle->weight_copies > 1) {
  int active_copies = handle->weight_copies;
  const int filter_size = handle->desc.R * handle->desc.S * handle->desc.C * handle->desc.K;
  LIBXSMM_VLA_DECL(2, element_filter_type, weight_copies_buffer, (element_filter_type*) ((char*)handle->scratch + handle->upd_filter_scratch_offset), filter_size);
  element_filter_type *weight_global_ptr = (element_filter_type*) handle->grad_filter->data;

  /* In this case calculate how many weight copies have actually been computed */
  if (handle->desc.N != handle->desc.threads) {
    active_copies = 1;
    while (active_copies * img_chunksize < handle->desc.N) {
      active_copies++;
    }
  }

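  /* Reduction of the partial filter copies: the gradient filter is treated as a flat array
     of reduce_work = C*K*R*S/16 groups of 16 bf16 values. Each thread owns the contiguous
     range [reduce_thr_begin, reduce_thr_end); for every group it upconverts the 16 values
     of each active copy to f32, sums them, and rounds the result back to bf16 into the
     global gradient filter. */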
  for ( j = reduce_thr_begin; j < reduce_thr_end; j++) {
    __m512 weight_sum = _mm512_setzero_ps();
    for ( i = 0; i < active_copies; i++ ) {
      weight_sum = _mm512_add_ps(weight_sum, LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((const __m256i*)&LIBXSMM_VLA_ACCESS(2, weight_copies_buffer, i, j*16, filter_size))));
    }
    _mm256_storeu_si256((__m256i*)(((libxsmm_bfloat16*) weight_global_ptr) + j*16), LIBXSMM_INTRINSICS_MM512_CVT_FP32_BF16(weight_sum));
  }
  libxsmm_barrier_wait(handle->barrier, ltid);
}

#undef TRANS_OUTPUT_W_TO_VNNI_FORMAT
#undef TRANS_OUTPUT_TO_VNNI_FORMAT