/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Evangelos Georganas, Alexander Heinecke, Hans Pabst (Intel Corp.)
******************************************************************************/

/* Forward-convolution work fragment. It is not a standalone definition: it is
 * meant to be included into an enclosing function body.
 *
 * Names used here but not declared here (must come from the including scope):
 *   handle, tid, start_thread,
 *   element_input_type, element_output_type, element_filter_type,
 *   br_gemm_kernel_strided, br_gemm_kernel_strided2, br_gemm_kernel_offset.
 *
 * Outline:
 *  1. Flatten the iteration space (img, ofm1, oj, oi-block) and give each
 *     logical thread one contiguous chunk [thr_begin, thr_end).
 *  2. If handle->pack_input == 1, copy every u-th input row / v-th input column
 *     from reg_input into a scratch buffer laid out like the output's spatial
 *     extent (ofhp x ofwp); the compute loops then index it with (oj, oi).
 *  3. For each work item: optionally zero the fwd_ofw_rb x ofmblock output
 *     block, then accumulate into it with batch-reduce GEMM kernel calls over
 *     groups of blocksifm_blocking input-feature-map blocks.
 *  4. Wait on handle->barrier before leaving.
 */
int img, ofm1, ofm2, ifm1, ifm2, oj, oi, kj, ki, ii_use, ij_use, oii, spread_out = 1;
/* computing first logical thread */
const int ltid = tid - start_thread;

/* number of tasks that could be run in parallel */
/* NOTE(review): w_tasks truncates; if fwd_ofw_rb does not divide ofw, the last
 * (ofw % fwd_ofw_rb) output columns are never visited by this fragment.
 * Presumably the setup code guarantees divisibility -- confirm. */
const int w_tasks = handle->ofw/handle->fwd_ofw_rb;
/* Flattened work index layout (inverted by the div/mod chains below):
 *   imgofm1ofhofw = ((img * blocksofm + ofm1) * ofh + oj) * w_tasks + oi / fwd_ofw_rb */
const int work = handle->desc.N * handle->blocksofm * handle->ofh * w_tasks;
const int work_KHW = handle->blocksofm * handle->ofh * w_tasks;
const int work_HW = handle->ofh * w_tasks;
/* compute chunk size: ceil(work / threads) */
const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1);
/* compute thr_begin and thr_end (both clamped to work, so trailing threads may get an empty range) */
const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work;
const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work;
int imgofm1ofhofw;
/* Per-thread image range; used only by the input-packing loop below. */
int imgpt = LIBXSMM_UPDIV(handle->desc.N, handle->desc.threads);
int my_img_start = LIBXSMM_MIN(ltid * imgpt, handle->desc.N);
int my_img_end = LIBXSMM_MIN((ltid+1) * imgpt, handle->desc.N);
/* Number of fm_lp_block-sized groups inside one input feature-map block
 * (innermost weight dimensions are [ifmblock_lp][ofmblock][fm_lp_block]). */
int ifmblock_lp = handle->ifmblock/handle->fm_lp_block;
/* Batch reduce related variables: batch count passed by address to every kernel call */
unsigned long long n_blocks;

/* offset output pointer in case of physical output padding: skip pad_h_out
 * rows and pad_w_out columns so that (oj, oi) index the interior region */
element_output_type* out = (element_output_type*)handle->reg_output->data + ((size_t)handle->desc.pad_h_out * handle->ofwp + handle->desc.pad_w_out) * handle->ofmblock;
LIBXSMM_VLA_DECL(5, element_output_type, output, out, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock);
/* With packing, `input` aliases the scratch buffer (spatial extent ofhp x ofwp);
 * otherwise it aliases reg_input directly (spatial extent ifhp x ifwp). */
element_input_type *input_ptr = (handle->pack_input == 1) ?(element_input_type*)((char*)handle->scratch + handle->fwd_packing_padding_scratch_offset) : (element_input_type*)handle->reg_input->data;
const int IFW = (handle->pack_input == 1) ? handle->ofwp : handle->ifwp;
const int IFH = (handle->pack_input == 1) ? handle->ofhp : handle->ifhp;
LIBXSMM_VLA_DECL(5, element_input_type, input, input_ptr, handle->blocksifm, IFH, IFW, handle->ifmblock);
/* Weight view: [ofm1][ifm1][R][S][ifmblock_lp][ofmblock][fm_lp_block] */
LIBXSMM_VLA_DECL(7, const element_filter_type, weight, (element_filter_type*)handle->reg_filter->data, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block);

libxsmm_barrier_init(handle->barrier, ltid);

if (handle->pack_input == 1) {
  /* spread_out is fixed at 1, so ifm_id == 0 and every thread packs all ifm
   * blocks for its own image range [my_img_start, my_img_end). */
  int ifmpt = LIBXSMM_UPDIV(handle->blocksifm, spread_out);
  int ifm_id = ltid % spread_out;
  int my_ifm_start = LIBXSMM_MIN(ifm_id * ifmpt, handle->blocksifm);
  int my_ifm_end = LIBXSMM_MIN((ifm_id+1) * ifmpt, handle->blocksifm);
  LIBXSMM_VLA_DECL(5, element_input_type, input_src, (element_input_type*)handle->reg_input->data, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock);
  for (img = my_img_start; img < my_img_end; img++) {
    for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) {
      for (oj = 0; oj < handle->ofh; oj++) {
        for (oi = 0; oi < handle->ofw; oi++) {
          /* Only the filter-origin tap (kj == ki == 0) is gathered, so
           * presumably packing is only enabled for 1x1 filters -- confirm in setup. */
          ij_use = oj * handle->desc.u;
          ii_use = oi * handle->desc.v;
          LIBXSMM_PRAGMA_SIMD
          for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) {
            LIBXSMM_VLA_ACCESS(5, input, img, ifm1, oj, oi, ifm2, handle->blocksifm, IFH, IFW, handle->ifmblock) = LIBXSMM_VLA_ACCESS(5, input_src, img, ifm1, ij_use, ii_use, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock);
          }
        }
      }
    }
  }
  /* NOTE(review): a thread packs images by [my_img_start, my_img_end) but
   * computes over the flat range [thr_begin, thr_end). The two coincide only
   * when N is a multiple of threads; otherwise a thread can read images packed
   * by another thread. That is only safe if setup selects
   * use_ofm_parallelization == 1 in that case -- confirm. */
  if ( handle->use_ofm_parallelization == 1 ) {
    libxsmm_barrier_wait(handle->barrier, ltid);
  }
}

if (handle->avoid_fmas_in_rim == 1) {
  /* One kernel call per filter tap (kj, ki); each batches blocksifm_blocking ifm blocks. */
  n_blocks = handle->blocksifm_blocking;
  for (imgofm1ofhofw = thr_begin; imgofm1ofhofw < thr_end; ++imgofm1ofhofw) {
    img = imgofm1ofhofw / work_KHW;
    ofm1 = (imgofm1ofhofw % work_KHW)/work_HW;
    oj = ((imgofm1ofhofw % work_KHW)%work_HW)/w_tasks;
    oi = (((imgofm1ofhofw % work_KHW)%work_HW)%w_tasks)*handle->fwd_ofw_rb;
    /* Without packing, the input origin is shifted by -(1 - pad_h_in) rows and
     * -(1 - pad_w_in) columns. With pad_*_in == 0, tap 0 at output row/column 0
     * would address index -1; those taps are skipped or trimmed below.
     * This appears to assume a logical padding of 1 -- confirm. */
    ij_use = (handle->pack_input == 1) ? oj : oj * handle->desc.u - (1-handle->desc.pad_h_in);
    ii_use = (handle->pack_input == 1) ? oi : oi * handle->desc.v - (1-handle->desc.pad_w_in);
    if ( ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0) {
      /* set output feature map to zero */
      element_output_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock));
      for (oii = 0; oii < handle->fwd_ofw_rb; ++oii) {
        LIBXSMM_PRAGMA_SIMD
        for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) {
          temp_ptr[ofm2] = (element_output_type)0;
        }
        temp_ptr += handle->ofmblock;
      }
    }
    for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1 += handle->blocksifm_blocking) {
      for (kj = 0; kj < handle->desc.R; kj++) {
        for (ki = 0; ki < handle->desc.S; ki++) {
          if (kj == 0 && oj == 0) {
            /* Do no FLOPS: top filter row at the first output row (presumably reads top padding) */
          } else if (kj == handle->desc.R-1 && oj == handle->ofh-1 ) {
            /* Do no FLOPS: bottom filter row at the last output row (presumably reads bottom padding) */
          } else if ( oi == 0 && ki == 0 ) {
            /* First column block, left filter column: start one pixel later in
             * both input and output, so the first output pixel gets no contribution
             * from this tap. Presumably strided2 processes fwd_ofw_rb-1 pixels.
             * NOTE(review): the +1 input column matches output pixel oi+1 only
             * when desc.v == 1 -- presumably this path is unit-stride only. */
            br_gemm_kernel_strided2( &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm1, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block),
                                     &LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij_use+kj, ii_use+ki+1, 0, handle->blocksifm, IFH, IFW, handle->ifmblock),
                                     &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi+1, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks);
          } else if (oi == handle->ofw-handle->fwd_ofw_rb && ki == handle->desc.S-1) {
            /* Last column block, right filter column: same kernel starting at oi,
             * presumably so the block's last output pixel gets no contribution
             * from this tap. */
            br_gemm_kernel_strided2( &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm1, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block),
                                     &LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij_use+kj, ii_use+ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock),
                                     &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks);
          } else {
            /* Interior tap: full fwd_ofw_rb-pixel block. */
            br_gemm_kernel_strided( &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm1, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block),
                                    &LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij_use+kj, ii_use+ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock),
                                    &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks);
          }
        }
      }
    }
  }
} else {
  /* Strided based BRGEMM: one kernel call covers blocksifm_blocking ifm blocks
   * times all R*S filter taps. */
  n_blocks = (unsigned long long)handle->blocksifm_blocking * handle->desc.R * handle->desc.S;
  if (handle->desc.R == 1 && handle->desc.S == 1) {
    /* 1x1 filter: consecutive batch entries are addressed by the strided kernel. */
    for (imgofm1ofhofw = thr_begin; imgofm1ofhofw < thr_end; ++imgofm1ofhofw) {
      img = imgofm1ofhofw / work_KHW;
      ofm1 = (imgofm1ofhofw % work_KHW)/work_HW;
      oj = ((imgofm1ofhofw % work_KHW)%work_HW)/w_tasks;
      oi = (((imgofm1ofhofw % work_KHW)%work_HW)%w_tasks)*handle->fwd_ofw_rb;
      /* Packed input is already subsampled, so it is indexed in output coordinates. */
      ij_use = (handle->pack_input == 1) ? oj : oj * handle->desc.u;
      ii_use = (handle->pack_input == 1) ? oi : oi * handle->desc.v;
      if ( ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0) {
        /* set output feature map to zero */
        element_output_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock));
        for (oii = 0; oii < handle->fwd_ofw_rb; ++oii) {
          LIBXSMM_PRAGMA_SIMD
          for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) {
            temp_ptr[ofm2] = (element_output_type)0;
          }
          temp_ptr += handle->ofmblock;
        }
      }
      for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1 += handle->blocksifm_blocking) {
        br_gemm_kernel_strided( &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm1, 0, 0, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block),
                                &LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock),
                                &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks);
      }
    }
  } else { /* Offset based BRGEMM */
    /* General RxS filter: batch entries are located through the precomputed
     * handle->A_offsets / handle->B_offsets tables (contents built elsewhere). */
    for (imgofm1ofhofw = thr_begin; imgofm1ofhofw < thr_end; ++imgofm1ofhofw) {
      img = imgofm1ofhofw / work_KHW;
      ofm1 = (imgofm1ofhofw % work_KHW)/work_HW;
      oj = ((imgofm1ofhofw % work_KHW)%work_HW)/w_tasks;
      oi = (((imgofm1ofhofw % work_KHW)%work_HW)%w_tasks)*handle->fwd_ofw_rb;
      ij_use = (handle->pack_input == 1) ? oj : oj * handle->desc.u;
      ii_use = (handle->pack_input == 1) ? oi : oi * handle->desc.v;
      if ( ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0) {
        /* set output feature map to zero */
        element_output_type* temp_ptr = &(LIBXSMM_VLA_ACCESS( 5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock));
        for (oii = 0; oii < handle->fwd_ofw_rb; ++oii) {
          LIBXSMM_PRAGMA_SIMD
          for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) {
            temp_ptr[ofm2] = (element_output_type)0;
          }
          temp_ptr += handle->ofmblock;
        }
      }
      for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1 += handle->blocksifm_blocking) {
        br_gemm_kernel_offset( &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm1, 0, 0, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block),
                               &LIBXSMM_VLA_ACCESS(5, input, img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock),
                               &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks, handle->A_offsets, handle->B_offsets);
      }
    }
  }
}
/* All threads synchronize before the enclosing function continues. */
libxsmm_barrier_wait(handle->barrier, ltid);