/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Evangelos Georganas, Alexander Heinecke, Hans Pabst (Intel Corp.)
******************************************************************************/

/* Template fragment: body of the generic NHWC forward-convolution pass.
 * NOTE(review): presumably textually #include'd into a driver function that
 * declares `handle` (layer state/descriptor), `tid`, `start_thread`, the
 * element_{input,output,filter}_type typedefs and the batch-reduce GEMM
 * kernels `br_gemm_kernel` / `br_gemm_kernel2` -- confirm against the
 * dispatching source. The weight layout is selected at preprocessing time
 * via LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_{CUSTOM,RSCK}. */

int img, ofm1, ofm2 = 0, ifm1, ifm2 = 0, oj, oi, kj, ki, oi_use, oj_use, ii_use, ij_use, ofmb, ifmb, ojb, myOfmId, nOfmBlocks, ind, ofm11, ki1, kj1, ojj, oii, ii, ij, spread_out = 1;
/* computing first logical thread */
const int ltid = tid - start_thread;
/* default work split: each thread owns `imgpt` consecutive images and ALL
 * output-feature-map blocks; refined below when imgpt <= 1 or when
 * ofm-parallelization is enabled */
int imgpt = LIBXSMM_UPDIV(handle->desc.N, handle->desc.threads);
int threads_per_image = handle->desc.threads / handle->desc.N;
int my_img_start = LIBXSMM_MIN(ltid * imgpt, handle->desc.N);
int my_img_end = LIBXSMM_MIN((ltid+1) * imgpt, handle->desc.N);
int my_ofm_start = 0;
int my_ofm_end = handle->blocksofm;

/* Batch reduce related variables: per-call pointer batches for the
 * batch-reduce GEMM (address variant). Capacity 1024 bounds the number of
 * (ifm-block x R x S) operands batched per kernel invocation -- NOTE(review):
 * no overflow check here; presumably guaranteed by handle tuning, confirm. */
const element_filter_type *A_ptrs[1024];
const element_input_type *B_ptrs[1024];
unsigned long long n_blocks;

/* offset output pointer in case of physical output padding */
element_output_type* out = (element_output_type*)handle->reg_output->data + ((size_t)handle->desc.pad_h_out * handle->ofwp + handle->desc.pad_w_out) * handle->blocksofm * handle->ofmblock;
LIBXSMM_VLA_DECL(5, element_output_type, output, out, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock);
/* when the input is packed (stride removed) or physically padded, it is
 * staged into scratch; otherwise the registered input tensor is read directly */
element_input_type *input_ptr = ( (handle->pack_input == 1) || (handle->fwd_padding_copy == 1) ) ? (element_input_type*)((char*)handle->scratch + handle->fwd_packing_padding_scratch_offset) : (element_input_type*)handle->reg_input->data;
/* logical input width/height of whatever buffer `input_ptr` points at:
 * padded copy -> ifwp/ifhp plus 2*pad, packed copy -> ofwp/ofhp, else raw */
const int IFW = (handle->fwd_padding_copy == 1) ? handle->ifwp + 2*handle->desc.pad_w : ( (handle->pack_input == 1) ? handle->ofwp : handle->ifwp );
const int IFH = (handle->fwd_padding_copy == 1) ? handle->ifhp + 2*handle->desc.pad_h : ( (handle->pack_input == 1) ? handle->ofhp : handle->ifhp );
LIBXSMM_VLA_DECL(5, element_input_type, input, input_ptr, IFH, IFW, handle->blocksifm, handle->ifmblock);
#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM
/* custom weight layout: [ofm1][ifm1][R][S][ifmblock][ofmblock] */
LIBXSMM_VLA_DECL(6, const element_filter_type, weight, (element_filter_type*)handle->reg_filter->data, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
#endif
#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK
/* RSCK weight layout: [R][S][ifm1][ifmblock][ofm1][ofmblock] */
LIBXSMM_VLA_DECL(6, const element_filter_type, weight, (element_filter_type*)handle->reg_filter->data, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock);
#endif

/* lazy barrier init */
libxsmm_barrier_init(handle->barrier, ltid);

/* more threads than images: several threads share one image and split the
 * ofm blocks among themselves instead */
if ( imgpt <= 1 ) {
  my_img_start = LIBXSMM_MIN(ltid / threads_per_image, handle->desc.N);
  my_img_end = LIBXSMM_MIN(my_img_start + 1, handle->desc.N);
  myOfmId = ltid % threads_per_image;
  nOfmBlocks = LIBXSMM_UPDIV(handle->blocksofm, threads_per_image);
  my_ofm_start = LIBXSMM_MIN(myOfmId * nOfmBlocks, handle->blocksofm);
  my_ofm_end = LIBXSMM_MIN((myOfmId+1) * nOfmBlocks, handle->blocksofm);
}

/* optional 2D decomposition: `spread_out` threads cooperate on each image
 * tile and split the ofm blocks; the largest divisor of N from {8,4,2,3}
 * is preferred, and the split is only taken when it divides the thread count */
if ( handle->use_ofm_parallelization == 1 ) {
  if ( handle->desc.N % 8 == 0) {
    spread_out = 8;
  } else if ( handle->desc.N % 4 == 0) {
    spread_out = 4;
  } else if (handle->desc.N % 2 == 0) {
    spread_out = 2;
  } else if (handle->desc.N % 3 == 0) {
    spread_out = 3;
  } else {
    spread_out = 1;
  }
  if ((spread_out > 1) && (handle->desc.threads % spread_out == 0)) {
    int tile_id = ltid / spread_out;
    int ofmpt = LIBXSMM_UPDIV(handle->blocksofm, spread_out);
    int ofm_id = ltid % spread_out;
    /* each tile of spread_out threads now covers spread_out times as many images */
    imgpt = LIBXSMM_UPDIV(handle->desc.N, handle->desc.threads) * spread_out;
    my_img_start = LIBXSMM_MIN(tile_id * imgpt, handle->desc.N);
    my_img_end = LIBXSMM_MIN((tile_id+1) * imgpt, handle->desc.N);
    my_ofm_start = LIBXSMM_MIN(ofm_id * ofmpt, handle->blocksofm);
    my_ofm_end = LIBXSMM_MIN((ofm_id+1) * ofmpt, handle->blocksofm);
  }
}

/* remove stride from input: gather every u-th row / v-th column of the raw
 * input into the scratch `input` buffer so the GEMM can use unit stride */
if (handle->pack_input == 1) {
  int ifmpt = LIBXSMM_UPDIV(handle->blocksifm, spread_out);
  int ifm_id = ltid % spread_out;
  int my_ifm_start = LIBXSMM_MIN(ifm_id * ifmpt, handle->blocksifm);
  int my_ifm_end = LIBXSMM_MIN((ifm_id+1) * ifmpt, handle->blocksifm);
  /* @TODO think about packed format */
  LIBXSMM_VLA_DECL(5, element_input_type, input_src, (element_input_type*)handle->reg_input->data, handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock);
  for (img = my_img_start; img < my_img_end; img++) {
    for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) {
      for (oj = 0; oj < handle->ofh; oj++) {
        for (oi = 0; oi < handle->ofw; oi++) {
          ij_use = oj * handle->desc.u;
          ii_use = oi * handle->desc.v;
          LIBXSMM_PRAGMA_SIMD
          for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) {
            LIBXSMM_VLA_ACCESS(5, input, img, oj, oi, ifm1, ifm2, IFH, IFW, handle->blocksifm, handle->ifmblock) = LIBXSMM_VLA_ACCESS(5, input_src, img, ij_use, ii_use, ifm1, ifm2, handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock);
          }
        }
      }
    }
  }
  /* when threads share an image they also share the packed copy, so they
   * must all finish packing before any of them starts computing */
  if ( handle->use_ofm_parallelization == 1 ) {
    libxsmm_barrier_wait(handle->barrier, ltid);
  }
}

/* physical pad input: copy the raw input into the center of a zero-framed
 * scratch buffer of size (ifhp+2*pad_h) x (ifwp+2*pad_w) */
if (handle->fwd_padding_copy == 1) {
  int ifmpt = LIBXSMM_UPDIV(handle->blocksifm, spread_out);
  int ifm_id = ltid % spread_out;
  int my_ifm_start = LIBXSMM_MIN(ifm_id * ifmpt, handle->blocksifm);
  int my_ifm_end = LIBXSMM_MIN((ifm_id+1) * ifmpt, handle->blocksifm);
  LIBXSMM_VLA_DECL(5, element_input_type, input_src, (element_input_type*)handle->reg_input->data, handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock);
  for (img = my_img_start; img < my_img_end; img++) {
    for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) {
      /* copy the inner part */
      for (ij = 0; ij < handle->ifhp+(2*handle->desc.pad_h); ij++) {
        for (ii = 0; ii < handle->ifwp+(2*handle->desc.pad_w); ii++) {
          if ( (ij >= handle->desc.pad_h) && (ii >= handle->desc.pad_w) && (ij < handle->ifhp+handle->desc.pad_h) && (ii < handle->ifwp+handle->desc.pad_w) ) {
            LIBXSMM_PRAGMA_SIMD
            for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) {
              LIBXSMM_VLA_ACCESS(5, input, img, ij, ii, ifm1, ifm2, IFH, IFW, handle->blocksifm, handle->ifmblock) =
                LIBXSMM_VLA_ACCESS(5, input_src, img, ij-handle->desc.pad_h, ii-handle->desc.pad_w, ifm1, ifm2, handle->ifhp, handle->ifwp, handle->blocksifm, handle->ifmblock);
            }
          } else {
            /* zero the padding frame */
            LIBXSMM_PRAGMA_SIMD
            for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) {
              LIBXSMM_VLA_ACCESS(5, input, img, ij, ii, ifm1, ifm2, IFH, IFW, handle->blocksifm, handle->ifmblock) = (element_input_type)0;
            }
          }
        }
      }
    }
  }
  /* same sharing argument as the packing barrier above */
  if ( handle->use_ofm_parallelization == 1 ) {
    libxsmm_barrier_wait(handle->barrier, ltid);
  }
}

if (handle->use_fallback_fwd_loops == 1) {
  /* fallback scheduling: flatten (img, ofm1, oj) into a single task space
   * and give each thread one contiguous chunk */
  /* number of tasks that could be run in parallel */
  const int work = handle->desc.N * handle->blocksofm * handle->ofh;
  /* compute chunk size */
  const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1);
  /* compute thr_begin and thr_end */
  const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work;
  const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work;
  int imgofm1ofh;

  if ( handle->avoid_fmas_in_rim == 1) {
    /* rim-avoidance variant: filter taps whose input row/column falls into
     * the zero padding are skipped (no FLOPS) or handled with a shifted
     * start column; NOTE(review): the (1-pad_h_in)/(1-pad_w_in) offsets
     * suggest this path assumes logical padding of 1 (e.g. 3x3/pad-1) --
     * confirm against the handle setup code */
    for (imgofm1ofh = thr_begin; imgofm1ofh < thr_end; ++imgofm1ofh) {
      img = imgofm1ofh / (handle->blocksofm*handle->ofh);
#if 1
      ofm1 = (imgofm1ofh % (handle->blocksofm*handle->ofh))/handle->ofh;
      oj = (imgofm1ofh % (handle->blocksofm*handle->ofh))%handle->ofh;
#else
      oj = (imgofm1ofh % (handle->blocksofm*handle->ofh))/handle->blocksofm;
      ofm1 = (imgofm1ofh % (handle->blocksofm*handle->ofh))%handle->blocksofm;
#endif

      for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) {
        if ( (ifmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0) {
          /* set output feature map to zero */
          element_output_type* temp_ptr = &(LIBXSMM_VLA_ACCESS(  5, output, img, oj, 0, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock));
          for (oi = 0; oi < handle->ofw; ++oi) {
            LIBXSMM_PRAGMA_SIMD
            for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) {
              temp_ptr[ofm2] = (element_output_type)0;
            }
            temp_ptr += handle->blocksofm*handle->ofmblock;
          }
        }
        for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) {
          for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) {
            for (kj = 0; kj < handle->desc.R; kj++) {
              for (ki = 0; ki < handle->desc.S; ki++) {
                /* Prepare batch-reduce kernel arguments */
                if (handle->pack_input == 1) {
                  ij_use = oj;
                  ii_use = oi;
                } else {
                  ij_use = oj * handle->desc.u - (1-handle->desc.pad_h_in);
                  ii_use = oi * handle->desc.v - (1-handle->desc.pad_w_in);
                }
                oi_use = oi;
                oj_use = oj;

                if (kj == 0 && oj == 0) {
                  /* Do no FLOPS: top filter row reads above the image */
                } else if (kj == handle->desc.R-1 && oj == handle->ofh-1 ) {
                  /* Do no FLOPS: bottom filter row reads below the image */
                } else if ( oi == 0 && ki == 0 ) {
                  /* left edge: shift start one column right and use the
                   * narrower kernel (br_gemm_kernel2) to skip the pad column */
                  ind = 0;
                  for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) {
#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM
                    A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
#endif
#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK
                    A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, kj, ki, ifm2, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock);
#endif
                    B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ij_use + kj, ii_use + ki + 1, ifm2, 0, IFH, IFW, handle->blocksifm, handle->ifmblock);
                    ind++;
                  }
                  n_blocks = ind;
                  br_gemm_kernel2(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, oj_use, oi_use + 1, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock), &n_blocks);
                } else if (oi == handle->ofw-handle->fwd_ofw_rb && ki == handle->desc.S-1) {
                  /* right edge: keep the start but use the narrower kernel
                   * so the last (pad) column is not touched */
                  ind = 0;
                  for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) {
#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM
                    A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
#endif
#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK
                    A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, kj, ki, ifm2, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock);
#endif
                    B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ij_use + kj, ii_use + ki, ifm2, 0, IFH, IFW, handle->blocksifm, handle->ifmblock);
                    ind++;
                  }
                  n_blocks = ind;
                  br_gemm_kernel2(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, oj_use, oi_use, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock), &n_blocks);
                } else {
                  /* interior: full-width batch-reduce GEMM over the
                   * blocksifm_blocking input-feature blocks at this tap */
                  ind = 0;
                  for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) {
#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM
                    A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
#endif
#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK
                    A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, kj, ki, ifm2, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock);
#endif
                    B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ij_use + kj, ii_use + ki, ifm2, 0, IFH, IFW, handle->blocksifm, handle->ifmblock);
                    ind++;
                  }
                  n_blocks = ind;
                  br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, oj_use, oi_use, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock), &n_blocks);
                }
              }
            }
          }
        }
      }
    }
  } else {
    /* fallback without rim handling: batch ALL taps (R x S x ifm blocks)
     * of one ofw strip into a single batch-reduce GEMM call */
    for (imgofm1ofh = thr_begin; imgofm1ofh < thr_end; ++imgofm1ofh) {
      img = imgofm1ofh / (handle->blocksofm*handle->ofh);
#if 1
      ofm1 = (imgofm1ofh % (handle->blocksofm*handle->ofh))/handle->ofh;
      oj = (imgofm1ofh % (handle->blocksofm*handle->ofh))%handle->ofh;
#else
      oj = (imgofm1ofh % (handle->blocksofm*handle->ofh))/handle->blocksofm;
      ofm1 = (imgofm1ofh % (handle->blocksofm*handle->ofh))%handle->blocksofm;
#endif

      for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) {

        if ( (ifmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0) {
          /* set output feature map to zero */
          element_output_type* temp_ptr = &(LIBXSMM_VLA_ACCESS(  5, output, img, oj, 0, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock));
          for (oi = 0; oi < handle->ofw; ++oi) {
            LIBXSMM_PRAGMA_SIMD
            for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) {
              temp_ptr[ofm2] = (element_output_type)0;
            }
            temp_ptr += handle->blocksofm*handle->ofmblock;
          }
        }

        for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) {
          for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) {
            /* Prepare batch-reduce kernel arguments */
            if (handle->pack_input == 1) {
              ij_use = oj;
              ii_use = oi;
            } else {
              ij_use = oj * handle->desc.u;
              ii_use = oi * handle->desc.v;
            }
            oi_use = oi;
            oj_use = oj;
            ind = 0;
            for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) {
              for (kj = 0; kj < handle->desc.R; kj++) {
                for (ki = 0; ki < handle->desc.S; ki++) {
#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM
                  A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
#endif
#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK
                  A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, kj, ki, ifm2, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock);
#endif
                  B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ij_use + kj, ii_use + ki, ifm2, 0, IFH, IFW, handle->blocksifm, handle->ifmblock);
                  ind++;
                }
              }
            }
            n_blocks = ind;
            br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, oj_use, oi_use, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock), &n_blocks);
          }
        }
      }
    }
  }

} else {
  /* tuned path: blocked loops over (ofm, ifm, oj) with two loop orders */
  if (handle->loop_order == 0) {
    if ( handle->avoid_fmas_in_rim == 1) {
      for (img = my_img_start; img < my_img_end; img++) {
        for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += handle->block_fwd_ofm) {
          for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) {
            for (ojb = 0; ojb < handle->ofh; ojb += handle->block_fwd_oj) {
              for (ofm11 = ofmb; ofm11 < LIBXSMM_MIN(ofmb+handle->block_fwd_ofm, my_ofm_end); ofm11++ ) {
                /* optional per-thread rotation of the ofm order to spread
                 * concurrent filter accesses across the cache */
                ofm1 = (handle->shuffle_filter_accesses == 1) ? (ofm11+ltid)%handle->blocksofm : ofm11;
                if ( (ifmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0 && ojb == 0) {
                  /* set output feature map to zero */
                  for (oj = 0; oj < handle->ofh; ++oj) {
                    element_output_type* temp_ptr = &(LIBXSMM_VLA_ACCESS(  5, output, img, oj, 0, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock));
                    for (oi = 0; oi < handle->ofw; ++oi) {
                      LIBXSMM_PRAGMA_SIMD
                      for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) {
                        temp_ptr[ofm2] = (element_output_type)0;
                      }
                      temp_ptr += handle->blocksofm*handle->ofmblock;
                    }
                  }
                }

                for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) {
                  for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_fwd_oj,handle->ofh); oj += handle->fwd_ofh_rb) {
                    for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) {
                      for (kj1 = 0; kj1 < handle->desc.R; kj1++) {
                        for (ki1 = 0; ki1 < handle->desc.S; ki1++) {
                          /* Prepare batch-reduce kernel arguments */
                          if (handle->pack_input == 1) {
                            ij_use = oj;
                            ii_use = oi;
                          } else {
                            /* NOTE(review): offsets assume logical pad 1 in
                             * this rim path -- see fallback variant above */
                            ij_use = oj * handle->desc.u - (1-handle->desc.pad_h_in);
                            ii_use = oi * handle->desc.v - (1-handle->desc.pad_w_in);
                          }
                          oi_use = oi;
                          oj_use = oj;

                          /* optionally rotate the tap order per thread as well */
                          ki = (handle->shuffle_filter_accesses == 1) ? (ki1+ltid)%handle->desc.S : ki1;
                          kj = (handle->shuffle_filter_accesses == 1) ? (kj1+ltid)%handle->desc.R : kj1;

                          if (kj == 0 && oj == 0) {
                            /* Do no FLOPS: tap is entirely in the top padding */
                          } else if (kj == handle->desc.R-1 && oj == handle->ofh-1 ) {
                            /* Do no FLOPS: tap is entirely in the bottom padding */
                          } else if ( oi == 0 && ki == 0 ) {
                            /* left edge: shifted start + narrower kernel */
                            ind = 0;
                            for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) {
#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM
                              A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
#endif
#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK
                              A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, kj, ki, ifm2, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock);
#endif
                              B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ij_use + kj, ii_use + ki + 1, ifm2, 0, IFH, IFW, handle->blocksifm, handle->ifmblock);
                              ind++;
                            }
                            n_blocks = ind;
                            br_gemm_kernel2(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, oj_use, oi_use + 1, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock), &n_blocks);
                          } else if (oi == handle->ofw-handle->fwd_ofw_rb && ki == handle->desc.S-1) {
                            /* right edge: narrower kernel, unshifted start */
                            ind = 0;
                            for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) {
#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM
                              A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
#endif
#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK
                              A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, kj, ki, ifm2, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock);
#endif
                              B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ij_use + kj, ii_use + ki, ifm2, 0, IFH, IFW, handle->blocksifm, handle->ifmblock);
                              ind++;
                            }
                            n_blocks = ind;
                            br_gemm_kernel2(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, oj_use, oi_use, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock), &n_blocks);
                          } else {
                            /* interior tap: full-width kernel */
                            ind = 0;
                            for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) {
#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM
                              A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
#endif
#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK
                              A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, kj, ki, ifm2, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock);
#endif
                              B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ij_use + kj, ii_use + ki, ifm2, 0, IFH, IFW, handle->blocksifm, handle->ifmblock);
                              ind++;
                            }
                            n_blocks = ind;
                            br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, oj_use, oi_use, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock), &n_blocks);
                          }
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    } else {
      /* loop order 0, no rim handling: batch all R x S x ifm-block taps
       * into one batch-reduce GEMM per output strip */
      for (img = my_img_start; img < my_img_end; img++) {
        for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += handle->block_fwd_ofm) {
          for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) {
            for (ojb = 0; ojb < handle->ofh; ojb += handle->block_fwd_oj) {
              for (ofm11 = ofmb; ofm11 < LIBXSMM_MIN(ofmb+handle->block_fwd_ofm, my_ofm_end); ofm11++ ) {
                ofm1 = (handle->shuffle_filter_accesses == 1) ? (ofm11+ltid)%handle->blocksofm : ofm11;
                if ( (ifmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0 && ojb == 0) {
                  /* set output feature map to zero */
                  for (oj = 0; oj < handle->ofh; ++oj) {
                    element_output_type* temp_ptr = &(LIBXSMM_VLA_ACCESS(  5, output, img, oj, 0, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock));
                    for (oi = 0; oi < handle->ofw; ++oi) {
                      LIBXSMM_PRAGMA_SIMD
                      for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) {
                        temp_ptr[ofm2] = (element_output_type)0;
                      }
                      temp_ptr += handle->blocksofm * handle->ofmblock;
                    }
                  }
                }

                for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) {
                  for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_fwd_oj,handle->ofh); oj += handle->fwd_ofh_rb) {
                    for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) {
                      /* Prepare batch-reduce kernel arguments */
                      if (handle->pack_input == 1) {
                        ij_use = oj;
                        ii_use = oi;
                      } else {
                        ij_use = oj * handle->desc.u;
                        ii_use = oi * handle->desc.v;
                      }
                      oi_use = oi;
                      oj_use = oj;
                      ind = 0;
                      for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) {
                        for (kj1 = 0; kj1 < handle->desc.R; kj1++) {
                          for (ki1 = 0; ki1 < handle->desc.S; ki1++) {
                            ki = (handle->shuffle_filter_accesses == 1) ? (ki1+ltid)%handle->desc.S : ki1;
                            kj = (handle->shuffle_filter_accesses == 1) ? (kj1+ltid)%handle->desc.R : kj1;
#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM
                            A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
#endif
#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK
                            A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, kj, ki, ifm2, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock);
#endif
                            B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ij_use + kj, ii_use + ki, ifm2, 0, IFH, IFW, handle->blocksifm, handle->ifmblock);
                            ind++;
                          }
                        }
                      }
                      n_blocks = ind;
                      br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, oj_use, oi_use, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock), &n_blocks);
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
  }

  /* loop order 1: iterate output pixels outermost, feature maps innermost */
  if (handle->loop_order == 1) {
    for (img = my_img_start; img < my_img_end; img++) {
      for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += handle->block_fwd_ofm) {
        for (ojb = 0; ojb < handle->ofh; ojb += handle->block_fwd_oj) {
          for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_fwd_oj,handle->ofh); oj += handle->fwd_ofh_rb) {
            for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) {
              for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_fwd_ofm, my_ofm_end); ofm1++ ) {
                if (((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0 && oj == 0 && oi == 0) {
                  /* set output feature map to zero */
                  for (ojj = 0; ojj < handle->ofh; ++ojj) {
                    element_output_type* temp_ptr = &(LIBXSMM_VLA_ACCESS(  5, output, img, ojj, 0, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock));
                    for (oii = 0; oii < handle->ofw; ++oii) {
                      LIBXSMM_PRAGMA_SIMD
                      for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) {
                        temp_ptr[ofm2] = (element_output_type)0;
                      }
                      temp_ptr += handle->blocksofm * handle->ofmblock;
                    }
                  }
                }
                for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) {
                  for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) {
                    /* Prepare batch-reduce kernel arguments */
                    if (handle->pack_input == 1) {
                      ij_use = oj;
                      ii_use = oi;
                    } else {
                      ij_use = oj * handle->desc.u;
                      ii_use = oi * handle->desc.v;
                    }
                    oi_use = oi;
                    oj_use = oj;
                    ind = 0;
                    for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) {
                      for (kj = 0; kj < handle->desc.R; kj++) {
                        for (ki = 0; ki < handle->desc.S; ki++) {
#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_CUSTOM
                          A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
#endif
#ifdef LIBXSMM_DNN_TPL_FWD_DIRECT_GENERIC_NHWC_RSCK
                          A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, kj, ki, ifm2, 0, ofm1, 0, handle->desc.S, handle->blocksifm, handle->ifmblock, handle->blocksofm, handle->ofmblock);
#endif
                          B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ij_use + kj, ii_use + ki, ifm2, 0, IFH, IFW, handle->blocksifm, handle->ifmblock);
                          ind++;
                        }
                      }
                    }
                    n_blocks = ind;
                    br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, oj_use, oi_use, ofm1, 0, handle->ofhp, handle->ofwp, handle->blocksofm, handle->ofmblock), &n_blocks);
                  }
                }
              }
            }
          }
        }
      }
    }
  }
}

/* all threads rendezvous before the driver returns the output tensor */
libxsmm_barrier_wait(handle->barrier, ltid);