/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Evangelos Georganas, Alexander Heinecke, Hans Pabst (Intel Corp.)
******************************************************************************/

int img, ofm1, ofm2 = 0, ifm1, ifm2 = 0, oj, oi, kj, ki, oi_use, oj_use, ii_use, ij_use, ofmb, ifmb, ojb, myOfmId, nOfmBlocks, ind, ofm11, ki1, kj1, ojj, oii, ii, ij, spread_out = 1;
/* computing first logical thread */
const int ltid = tid - start_thread;
int imgpt = LIBXSMM_UPDIV(handle->desc.N, handle->desc.threads);
int threads_per_image = handle->desc.threads / handle->desc.N;
int my_img_start = LIBXSMM_MIN(ltid * imgpt, handle->desc.N);
int my_img_end = LIBXSMM_MIN((ltid+1) * imgpt, handle->desc.N);
int my_ofm_start = 0;
int my_ofm_end = handle->blocksofm;

/* Batch reduce related variables */
const element_filter_type *A_ptrs[1024];
const element_input_type *B_ptrs[1024];
unsigned long long n_blocks;

/* offset output pointer in case of physical output padding */
element_output_type* out = (element_output_type*)handle->reg_output->data + ((size_t)handle->desc.pad_h_out * handle->ofwp + handle->desc.pad_w_out) * handle->ofmblock;
LIBXSMM_VLA_DECL(5, element_output_type, output, out, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock);
element_input_type *input_ptr = ( (handle->pack_input == 1) || (handle->fwd_padding_copy == 1) ) ? (element_input_type*)((char*)handle->scratch + handle->fwd_packing_padding_scratch_offset) : (element_input_type*)handle->reg_input->data;
const int IFW = (handle->fwd_padding_copy == 1) ? handle->ifwp + 2*handle->desc.pad_w : ( (handle->pack_input == 1) ? handle->ofwp : handle->ifwp );
const int IFH = (handle->fwd_padding_copy == 1) ? handle->ifhp + 2*handle->desc.pad_h : ( (handle->pack_input == 1) ? handle->ofhp : handle->ifhp );
LIBXSMM_VLA_DECL(5, element_input_type, input, input_ptr, handle->blocksifm, IFH, IFW, handle->ifmblock);
LIBXSMM_VLA_DECL(6, const element_filter_type, weight, (element_filter_type*)handle->reg_filter->data, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);

/* lazy barrier init */
libxsmm_barrier_init(handle->barrier, ltid);

/* if there are at least as many threads as images, the threads sharing an image split the ofm blocks among themselves */
if ( imgpt <= 1 ) {
  my_img_start = LIBXSMM_MIN(ltid / threads_per_image, handle->desc.N);
  my_img_end = LIBXSMM_MIN(my_img_start + 1, handle->desc.N);
  myOfmId = ltid % threads_per_image;
  nOfmBlocks = LIBXSMM_UPDIV(handle->blocksofm, threads_per_image);
  my_ofm_start = LIBXSMM_MIN(myOfmId * nOfmBlocks, handle->blocksofm);
  my_ofm_end = LIBXSMM_MIN((myOfmId+1) * nOfmBlocks, handle->blocksofm);
}

/* optionally form groups of "spread_out" threads that share a set of images and split the ofm blocks */
if ( handle->use_ofm_parallelization == 1 ) {
  if ( handle->desc.N % 8 == 0) {
    spread_out = 8;
  } else if ( handle->desc.N % 4 == 0) {
    spread_out = 4;
  } else if (handle->desc.N % 2 == 0) {
    spread_out = 2;
  } else if (handle->desc.N % 3 == 0) {
    spread_out = 3;
  } else {
    spread_out = 1;
  }
  if ((spread_out > 1) && (handle->desc.threads % spread_out == 0)) {
    int tile_id = ltid / spread_out;
    int ofmpt = LIBXSMM_UPDIV(handle->blocksofm, spread_out);
    int ofm_id = ltid % spread_out;
    imgpt = LIBXSMM_UPDIV(handle->desc.N, handle->desc.threads) * spread_out;
    my_img_start = LIBXSMM_MIN(tile_id * imgpt, handle->desc.N);
    my_img_end = LIBXSMM_MIN((tile_id+1) * imgpt, handle->desc.N);
    my_ofm_start = LIBXSMM_MIN(ofm_id * ofmpt, handle->blocksofm);
    my_ofm_end = LIBXSMM_MIN((ofm_id+1) * ofmpt, handle->blocksofm);
  }
}

/* remove stride from input */
if (handle->pack_input == 1) {
  int ifmpt = LIBXSMM_UPDIV(handle->blocksifm, spread_out);
  int ifm_id = ltid % spread_out;
  int my_ifm_start = LIBXSMM_MIN(ifm_id * ifmpt, handle->blocksifm);
  int my_ifm_end = LIBXSMM_MIN((ifm_id+1) * ifmpt, handle->blocksifm);
  LIBXSMM_VLA_DECL(5, element_input_type, input_src, (element_input_type*)handle->reg_input->data, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock);
  for (img = my_img_start; img < my_img_end; img++) {
    for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) {
      for (oj = 0; oj < handle->ofh; oj++) {
        for (oi = 0; oi < handle->ofw; oi++) {
          ij_use = oj * handle->desc.u;
          ii_use = oi * handle->desc.v;
          LIBXSMM_PRAGMA_SIMD
          for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) {
            LIBXSMM_VLA_ACCESS(5, input, img, ifm1, oj, oi, ifm2, handle->blocksifm, IFH, IFW, handle->ifmblock) =
              LIBXSMM_VLA_ACCESS(5, input_src, img, ifm1, ij_use, ii_use, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock);
          }
        }
      }
    }
  }
  if ( handle->use_ofm_parallelization == 1 || handle->desc.N % handle->desc.threads != 0) {
    libxsmm_barrier_wait(handle->barrier, ltid);
  }
}

/* physical pad input */
if (handle->fwd_padding_copy == 1) {
  int ifmpt = LIBXSMM_UPDIV(handle->blocksifm, spread_out);
  int ifm_id = ltid % spread_out;
  int my_ifm_start = LIBXSMM_MIN(ifm_id * ifmpt, handle->blocksifm);
  int my_ifm_end = LIBXSMM_MIN((ifm_id+1) * ifmpt, handle->blocksifm);
  LIBXSMM_VLA_DECL(5, element_input_type, input_src, (element_input_type*)handle->reg_input->data, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock);
  for (img = my_img_start; img < my_img_end; img++) {
    for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) {
      /* copy the inner part and zero the padding rim */
      for (ij = 0; ij < handle->ifhp+(2*handle->desc.pad_h); ij++) {
        for (ii = 0; ii < handle->ifwp+(2*handle->desc.pad_w); ii++) {
          if ( (ij >= handle->desc.pad_h) && (ii >= handle->desc.pad_w) && (ij < handle->ifhp+handle->desc.pad_h) && (ii < handle->ifwp+handle->desc.pad_w) ) {
            LIBXSMM_PRAGMA_SIMD
            for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) {
              LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij, ii, ifm2, handle->blocksifm, IFH, IFW, handle->ifmblock) =
                LIBXSMM_VLA_ACCESS(5, input_src, img, ifm1, ij-handle->desc.pad_h, ii-handle->desc.pad_w, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock);
            }
          } else {
            LIBXSMM_PRAGMA_SIMD
            for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) {
              LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij, ii, ifm2, handle->blocksifm, IFH, IFW, handle->ifmblock) = (element_input_type)0;
            }
          }
        }
      }
    }
  }
  if ( handle->use_ofm_parallelization == 1 || handle->desc.N % handle->desc.threads != 0 ) {
    libxsmm_barrier_wait(handle->barrier, ltid);
  }
}

if (handle->use_fallback_fwd_loops == 1) {
  /* number of tasks that could be run in parallel */
  const int work = handle->desc.N * handle->blocksofm * handle->ofh;
  /* compute chunk size */
  const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1);
  /* compute thr_begin and thr_end */
  const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work;
  const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work;
  int imgofm1ofh;

  if ( handle->avoid_fmas_in_rim == 1) {
    for (imgofm1ofh = thr_begin; imgofm1ofh < thr_end; ++imgofm1ofh) {
      /* decompose the flat task index into img, ofm1 and oj */
      img = imgofm1ofh / (handle->blocksofm*handle->ofh);
#if 1
      ofm1 = (imgofm1ofh % (handle->blocksofm*handle->ofh))/handle->ofh;
      oj = (imgofm1ofh % (handle->blocksofm*handle->ofh))%handle->ofh;
#else
      oj = (imgofm1ofh % (handle->blocksofm*handle->ofh))/handle->blocksofm;
      ofm1 = (imgofm1ofh % (handle->blocksofm*handle->ofh))%handle->blocksofm;
#endif

      for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) {
        if ( (ifmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0) {
          /* set output feature map to zero */
          element_output_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock));
          for (oi = 0; oi < handle->ofw; ++oi) {
            LIBXSMM_PRAGMA_SIMD
            for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) {
              temp_ptr[ofm2] = (element_output_type)0;
            }
            temp_ptr += handle->ofmblock;
          }
        }
        for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) {
          for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) {
            for (kj = 0; kj < handle->desc.R; kj++) {
              for (ki = 0; ki < handle->desc.S; ki++) {
                /* Prepare batch-reduce kernel arguments */
                if (handle->pack_input == 1) {
                  ij_use = oj;
                  ii_use = oi;
                } else {
                  ij_use = oj * handle->desc.u - (1-handle->desc.pad_h_in);
                  ii_use = oi * handle->desc.v - (1-handle->desc.pad_w_in);
                }
                oi_use = oi;
                oj_use = oj;

                /* at the image borders, skip or shift the GEMMs that would only read zero padding */
                if (kj == 0 && oj == 0) {
                  /* Do no FLOPS */
                } else if (kj == handle->desc.R-1 && oj == handle->ofh-1 ) {
                  /* Do no FLOPS */
                } else if ( oi == 0 && ki == 0 ) {
                  ind = 0;
                  for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) {
                    A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
                    B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki + 1, 0, handle->blocksifm, IFH, IFW, handle->ifmblock);
                    ind++;
                  }
                  n_blocks = ind;
                  br_gemm_kernel2(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use + 1, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks);
                } else if (oi == handle->ofw-handle->fwd_ofw_rb && ki == handle->desc.S-1) {
                  ind = 0;
                  for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) {
                    A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
                    B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock);
                    ind++;
                  }
                  n_blocks = ind;
                  br_gemm_kernel2(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks);
                } else {
                  ind = 0;
                  for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) {
                    A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
                    B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock);
                    ind++;
                  }
                  n_blocks = ind;
                  br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks);
                }
              }
            }
          }
        }
      }
    }
  } else {
    for (imgofm1ofh = thr_begin; imgofm1ofh < thr_end; ++imgofm1ofh) {
      img = imgofm1ofh / (handle->blocksofm*handle->ofh);
#if 1
      ofm1 = (imgofm1ofh % (handle->blocksofm*handle->ofh))/handle->ofh;
      oj = (imgofm1ofh % (handle->blocksofm*handle->ofh))%handle->ofh;
#else
      oj = (imgofm1ofh % (handle->blocksofm*handle->ofh))/handle->blocksofm;
      ofm1 = (imgofm1ofh % (handle->blocksofm*handle->ofh))%handle->blocksofm;
#endif

      for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) {

        if ( (ifmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0) {
          /* set output feature map to zero */
          element_output_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock));
          for (oi = 0; oi < handle->ofw; ++oi) {
            LIBXSMM_PRAGMA_SIMD
            for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) {
              temp_ptr[ofm2] = (element_output_type)0;
            }
            temp_ptr += handle->ofmblock;
          }
        }

        for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) {
          for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) {
            /* Prepare batch-reduce kernel arguments */
            if (handle->pack_input == 1) {
              ij_use = oj;
              ii_use = oi;
            } else {
              ij_use = oj * handle->desc.u;
              ii_use = oi * handle->desc.v;
            }
            oi_use = oi;
            oj_use = oj;
            ind = 0;
            for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) {
              for (kj = 0; kj < handle->desc.R; kj++) {
                for (ki = 0; ki < handle->desc.S; ki++) {
                  A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
                  B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock);
                  ind++;
                }
              }
            }
            n_blocks = ind;
            br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks);
          }
        }
      }
    }
  }

} else {
  if (handle->loop_order == 0) {
    if ( handle->avoid_fmas_in_rim == 1) {
      for (img = my_img_start; img < my_img_end; img++) {
        for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += handle->block_fwd_ofm) {
          for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) {
            for (ojb = 0; ojb < handle->ofh; ojb += handle->block_fwd_oj) {
              for (ofm11 = ofmb; ofm11 < LIBXSMM_MIN(ofmb+handle->block_fwd_ofm, my_ofm_end); ofm11++ ) {
                /* optionally stagger filter accesses across threads */
                ofm1 = (handle->shuffle_filter_accesses == 1) ? (ofm11+ltid)%handle->blocksofm : ofm11;
                if ( (ifmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0 && ojb == 0) {
                  /* set output feature map to zero */
                  for (oj = 0; oj < handle->ofh; ++oj) {
                    element_output_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock));
                    for (oi = 0; oi < handle->ofw; ++oi) {
                      LIBXSMM_PRAGMA_SIMD
                      for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) {
                        temp_ptr[ofm2] = (element_output_type)0;
                      }
                      temp_ptr += handle->ofmblock;
                    }
                  }
                }

                for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) {
                  for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_fwd_oj,handle->ofh); oj += handle->fwd_ofh_rb) {
                    for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) {
                      for (kj1 = 0; kj1 < handle->desc.R; kj1++) {
                        for (ki1 = 0; ki1 < handle->desc.S; ki1++) {
                          /* Prepare batch-reduce kernel arguments */
                          if (handle->pack_input == 1) {
                            ij_use = oj;
                            ii_use = oi;
                          } else {
                            ij_use = oj * handle->desc.u - (1-handle->desc.pad_h_in);
                            ii_use = oi * handle->desc.v - (1-handle->desc.pad_w_in);
                          }
                          oi_use = oi;
                          oj_use = oj;

                          ki = (handle->shuffle_filter_accesses == 1) ? (ki1+ltid)%handle->desc.S : ki1;
                          kj = (handle->shuffle_filter_accesses == 1) ? (kj1+ltid)%handle->desc.R : kj1;

                          /* at the image borders, skip or shift the GEMMs that would only read zero padding */
                          if (kj == 0 && oj == 0) {
                            /* Do no FLOPS */
                          } else if (kj == handle->desc.R-1 && oj == handle->ofh-1 ) {
                            /* Do no FLOPS */
                          } else if ( oi == 0 && ki == 0 ) {
                            ind = 0;
                            for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) {
                              A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
                              B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki + 1, 0, handle->blocksifm, IFH, IFW, handle->ifmblock);
                              ind++;
                            }
                            n_blocks = ind;
                            br_gemm_kernel2(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use + 1, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks);
                          } else if (oi == handle->ofw-handle->fwd_ofw_rb && ki == handle->desc.S-1) {
                            ind = 0;
                            for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) {
                              A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
                              B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock);
                              ind++;
                            }
                            n_blocks = ind;
                            br_gemm_kernel2(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks);
                          } else {
                            ind = 0;
                            for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) {
                              A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
                              B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock);
                              ind++;
                            }
                            n_blocks = ind;
                            br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks);
                          }
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    } else {
      for (img = my_img_start; img < my_img_end; img++) {
        for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += handle->block_fwd_ofm) {
          for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) {
            for (ojb = 0; ojb < handle->ofh; ojb += handle->block_fwd_oj) {
              for (ofm11 = ofmb; ofm11 < LIBXSMM_MIN(ofmb+handle->block_fwd_ofm, my_ofm_end); ofm11++ ) {
                ofm1 = (handle->shuffle_filter_accesses == 1) ? (ofm11+ltid)%handle->blocksofm : ofm11;
                if ( (ifmb == 0) && ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0 && ojb == 0) {
                  /* set output feature map to zero */
                  for (oj = 0; oj < handle->ofh; ++oj) {
                    element_output_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock));
                    for (oi = 0; oi < handle->ofw; ++oi) {
                      LIBXSMM_PRAGMA_SIMD
                      for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) {
                        temp_ptr[ofm2] = (element_output_type)0;
                      }
                      temp_ptr += handle->ofmblock;
                    }
                  }
                }

                for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) {
                  for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_fwd_oj,handle->ofh); oj += handle->fwd_ofh_rb) {
                    for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) {
                      /* Prepare batch-reduce kernel arguments */
                      if (handle->pack_input == 1) {
                        ij_use = oj;
                        ii_use = oi;
                      } else {
                        ij_use = oj * handle->desc.u;
                        ii_use = oi * handle->desc.v;
                      }
                      oi_use = oi;
                      oj_use = oj;
                      ind = 0;
                      for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) {
                        for (kj1 = 0; kj1 < handle->desc.R; kj1++) {
                          for (ki1 = 0; ki1 < handle->desc.S; ki1++) {
                            ki = (handle->shuffle_filter_accesses == 1) ? (ki1+ltid)%handle->desc.S : ki1;
                            kj = (handle->shuffle_filter_accesses == 1) ? (kj1+ltid)%handle->desc.R : kj1;
                            A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
                            B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock);
                            ind++;
                          }
                        }
                      }
                      n_blocks = ind;
                      br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks);
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
  }

  if (handle->loop_order == 1) {
    for (img = my_img_start; img < my_img_end; img++) {
      for (ofmb = my_ofm_start; ofmb < my_ofm_end; ofmb += handle->block_fwd_ofm) {
        for (ojb = 0; ojb < handle->ofh; ojb += handle->block_fwd_oj) {
          for (oj = ojb; oj < LIBXSMM_MIN(ojb+handle->block_fwd_oj,handle->ofh); oj += handle->fwd_ofh_rb) {
            for (oi = 0; oi < handle->ofw; oi += handle->fwd_ofw_rb) {
              for (ofm1 = ofmb; ofm1 < LIBXSMM_MIN(ofmb+handle->block_fwd_ofm, my_ofm_end); ofm1++ ) {
                if (((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0 && oj == 0 && oi == 0) {
                  /* set output feature map to zero */
                  for (ojj = 0; ojj < handle->ofh; ++ojj) {
                    element_output_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, ojj, 0, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock));
                    for (oii = 0; oii < handle->ofw; ++oii) {
                      LIBXSMM_PRAGMA_SIMD
                      for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) {
                        temp_ptr[ofm2] = (element_output_type)0;
                      }
                      temp_ptr += handle->ofmblock;
                    }
                  }
                }
                for (ifmb = 0; ifmb < handle->blocksifm; ifmb += handle->block_fwd_ifm) {
                  for (ifm1 = ifmb; ifm1 < LIBXSMM_MIN(ifmb+handle->block_fwd_ifm, handle->blocksifm); ifm1 += handle->blocksifm_blocking) {
                    /* Prepare batch-reduce kernel arguments */
                    if (handle->pack_input == 1) {
                      ij_use = oj;
                      ii_use = oi;
                    } else {
                      ij_use = oj * handle->desc.u;
                      ii_use = oi * handle->desc.v;
                    }
                    oi_use = oi;
                    oj_use = oj;
                    ind = 0;
                    for (ifm2 = ifm1; ifm2 < ifm1 + handle->blocksifm_blocking; ifm2++) {
                      for (kj = 0; kj < handle->desc.R; kj++) {
                        for (ki = 0; ki < handle->desc.S; ki++) {
                          A_ptrs[ind] = &LIBXSMM_VLA_ACCESS(6, weight, ofm1, ifm2, kj, ki, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, handle->ifmblock, handle->ofmblock);
                          B_ptrs[ind] = &LIBXSMM_VLA_ACCESS(5, input, img, ifm2, ij_use + kj, ii_use + ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock);
                          ind++;
                        }
                      }
                    }
                    n_blocks = ind;
                    br_gemm_kernel(A_ptrs, B_ptrs, &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj_use, oi_use, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks);
                  }
                }
              }
            }
          }
        }
      }
    }
  }
}

libxsmm_barrier_wait(handle->barrier, ltid);