1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Evangelos Georganas, Alexander Heinecke, Hans Pabst (Intel Corp.)
10 ******************************************************************************/
11 
12 int img, ofm1, ofm2, ifm1, ifm2, oj, oi, kj, ki, ii_use, ij_use, oii, spread_out = 1;
13 /* computing first logical thread */
14 const int ltid = tid - start_thread;
15 
16 /* number of tasks that could be run in parallel */
17 const int w_tasks = handle->ofw/handle->fwd_ofw_rb;
18 const int work = handle->desc.N * handle->blocksofm * handle->ofh * w_tasks;
19 const int work_KHW = handle->blocksofm * handle->ofh * w_tasks;
20 const int work_HW = handle->ofh * w_tasks;
21 /* compute chunk size */
22 const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1);
23 /* compute thr_begin and thr_end */
24 const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work;
25 const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work;
26 int imgofm1ofhofw;
27 int imgpt = LIBXSMM_UPDIV(handle->desc.N, handle->desc.threads);
28 int my_img_start = LIBXSMM_MIN(ltid * imgpt, handle->desc.N);
29 int my_img_end = LIBXSMM_MIN((ltid+1) * imgpt, handle->desc.N);
30 int ifmblock_lp =  handle->ifmblock/handle->fm_lp_block;
31 /* Batch reduce related variables */
32 unsigned long long n_blocks;
33 
34 /* offset output pointer in case of physical output padding */
35 element_output_type* out = (element_output_type*)handle->reg_output->data + ((size_t)handle->desc.pad_h_out * handle->ofwp + handle->desc.pad_w_out) * handle->ofmblock;
36 LIBXSMM_VLA_DECL(5, element_output_type, output, out, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock);
37 element_input_type *input_ptr = (handle->pack_input == 1) ?(element_input_type*)((char*)handle->scratch + handle->fwd_packing_padding_scratch_offset) : (element_input_type*)handle->reg_input->data;
38 const int IFW = (handle->pack_input == 1) ? handle->ofwp : handle->ifwp;
39 const int IFH = (handle->pack_input == 1) ? handle->ofhp : handle->ifhp;
40 LIBXSMM_VLA_DECL(5, element_input_type, input, input_ptr, handle->blocksifm, IFH, IFW, handle->ifmblock);
41 LIBXSMM_VLA_DECL(7, const element_filter_type, weight, (element_filter_type*)handle->reg_filter->data, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block);
42 
43 libxsmm_barrier_init(handle->barrier, ltid);
44 
45 if (handle->pack_input == 1) {
46   int ifmpt = LIBXSMM_UPDIV(handle->blocksifm, spread_out);
47   int ifm_id = ltid % spread_out;
48   int my_ifm_start = LIBXSMM_MIN(ifm_id * ifmpt, handle->blocksifm);
49   int my_ifm_end = LIBXSMM_MIN((ifm_id+1) * ifmpt, handle->blocksifm);
50   LIBXSMM_VLA_DECL(5, element_input_type, input_src, (element_input_type*)handle->reg_input->data, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock);
51   for (img = my_img_start; img < my_img_end; img++) {
52     for (ifm1 = my_ifm_start; ifm1 < my_ifm_end; ifm1++) {
53       for (oj = 0; oj < handle->ofh; oj++) {
54         for (oi = 0; oi < handle->ofw; oi++) {
55           ij_use = oj * handle->desc.u;
56           ii_use = oi * handle->desc.v;
57           LIBXSMM_PRAGMA_SIMD
58             for (ifm2 = 0; ifm2 < handle->ifmblock; ifm2++) {
59               LIBXSMM_VLA_ACCESS(5,  input, img, ifm1, oj, oi, ifm2, handle->blocksifm, IFH, IFW, handle->ifmblock) = LIBXSMM_VLA_ACCESS(5,  input_src,  img, ifm1, ij_use, ii_use, ifm2, handle->blocksifm, handle->ifhp, handle->ifwp, handle->ifmblock);
60             }
61         }
62       }
63     }
64   }
65   if ( handle->use_ofm_parallelization == 1 ) {
66     libxsmm_barrier_wait(handle->barrier, ltid);
67   }
68 }
69 
/* Compute phase. Each thread walks its slice [thr_begin, thr_end) of the flattened
 * (img, ofm1, oj, oi-tile) work space. For each work item it may first clear the output
 * tile (fwd_ofw_rb pixels x ofmblock channels), then accumulate into that tile with
 * batch-reduce GEMM kernels over the input feature-map blocks, stepping ifm1 by
 * blocksifm_blocking. There are three code paths:
 *   - avoid_fmas_in_rim == 1: one kernel call per filter tap (kj, ki); taps at the image
 *     border are skipped or shifted by one pixel (see below);
 *   - R == S == 1: a stride-based batch-reduce call (br_gemm_kernel_strided);
 *   - otherwise: an offset-based batch-reduce call driven by handle->A_offsets/B_offsets. */
if (handle->avoid_fmas_in_rim == 1) {
  /* one filter tap per call, so the batch covers only the ifm blocking */
  n_blocks = handle->blocksifm_blocking;
  for (imgofm1ofhofw = thr_begin; imgofm1ofhofw < thr_end; ++imgofm1ofhofw) {
    /* flat index -> (img, ofm1, oj, oi); oi is the first output column of the tile */
    img = imgofm1ofhofw / work_KHW;
    ofm1 = (imgofm1ofhofw % work_KHW)/work_HW;
    oj = ((imgofm1ofhofw % work_KHW)%work_HW)/w_tasks;
    oi = (((imgofm1ofhofw % work_KHW)%work_HW)%w_tasks)*handle->fwd_ofw_rb;
    /* Top-left input coordinate of the tile. For unpacked input, the "- (1 - pad_*_in)"
     * term shifts back by one row/column when pad_*_in is 0.
     * NOTE(review): this reads as "logical padding of 1, pad_*_in of it physically
     * present" -- confirm against the handle setup. */
    ij_use = (handle->pack_input == 1) ? oj : oj * handle->desc.u - (1-handle->desc.pad_h_in);
    ii_use = (handle->pack_input == 1) ? oi : oi * handle->desc.v - (1-handle->desc.pad_w_in);
    if ( ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0) {
      /* set the output tile (fwd_ofw_rb pixels x ofmblock) to zero before accumulating;
       * presumably avoid_acc_load == 1 means the kernels do not read the prior output */
      element_output_type* temp_ptr   = &(LIBXSMM_VLA_ACCESS(  5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock));
      for (oii = 0; oii < handle->fwd_ofw_rb; ++oii) {
        LIBXSMM_PRAGMA_SIMD
          for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) {
            temp_ptr[ofm2] = (element_output_type)0;
          }
        temp_ptr += handle->ofmblock;
      }
    }
    /* NOTE(review): the rim handling below skips only kj == 0 / kj == R-1 on the first/last
     * output row and shifts by exactly one pixel on the first/last tile. That looks
     * tailored to R == S == 3, unit stride and padding 1 -- confirm these are enforced
     * wherever avoid_fmas_in_rim is set. */
    for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1 += handle->blocksifm_blocking) {
      for (kj = 0; kj < handle->desc.R; kj++) {
        for (ki = 0; ki < handle->desc.S; ki++) {
          if (kj == 0 && oj == 0) {
            /* first filter row on the top output row: tap skipped, no FLOPS */
          } else if (kj == handle->desc.R-1 && oj == handle->ofh-1 ) {
            /* last filter row on the bottom output row: tap skipped, no FLOPS */
          } else if ( oi == 0 && ki == 0 ) {
            /* Leftmost tile, first filter column: start one pixel to the right (input
             * ii_use+ki+1, output oi+1), so the leftmost output pixel gets no contribution
             * from this tap. NOTE(review): presumably br_gemm_kernel_strided2 is generated
             * for fwd_ofw_rb-1 output pixels -- confirm at kernel creation. */
            br_gemm_kernel_strided2( &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm1, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block),
                &LIBXSMM_VLA_ACCESS(5,  input,  img, ifm1, ij_use+kj, ii_use+ki+1, 0, handle->blocksifm, IFH, IFW, handle->ifmblock),
                &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi+1, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks);
          } else if (oi == handle->ofw-handle->fwd_ofw_rb  && ki == handle->desc.S-1) {
            /* Rightmost tile, last filter column: same shortened kernel, anchored at the tile
             * start, presumably leaving out the rightmost output pixel for this tap. */
            br_gemm_kernel_strided2( &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm1, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block),
                &LIBXSMM_VLA_ACCESS(5,  input,  img, ifm1, ij_use+kj, ii_use+ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock),
                &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks);
          } else {
            /* interior tap: full tile */
            br_gemm_kernel_strided( &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm1, kj, ki, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block),
                &LIBXSMM_VLA_ACCESS(5,  input,  img, ifm1, ij_use+kj, ii_use+ki, 0, handle->blocksifm, IFH, IFW, handle->ifmblock),
                &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks);
          }
        }
      }
    }
  }
} else {
  /* Strided based BRGEMM  */
  /* the batch covers blocksifm_blocking ifm blocks times every filter tap (R*S) */
  n_blocks = (unsigned long long)handle->blocksifm_blocking * handle->desc.R * handle->desc.S;
  if (handle->desc.R == 1 && handle->desc.S == 1) {
    for (imgofm1ofhofw = thr_begin; imgofm1ofhofw < thr_end; ++imgofm1ofhofw) {
      /* flat index -> (img, ofm1, oj, oi); oi is the first output column of the tile */
      img = imgofm1ofhofw / work_KHW;
      ofm1 = (imgofm1ofhofw % work_KHW)/work_HW;
      oj = ((imgofm1ofhofw % work_KHW)%work_HW)/w_tasks;
      oi = (((imgofm1ofhofw % work_KHW)%work_HW)%w_tasks)*handle->fwd_ofw_rb;
      /* packed input is already strided, so it is indexed directly by (oj, oi) */
      ij_use = (handle->pack_input == 1) ? oj : oj * handle->desc.u;
      ii_use = (handle->pack_input == 1) ? oi : oi * handle->desc.v;
      if (  ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0) {
        /* set the output tile (fwd_ofw_rb pixels x ofmblock) to zero before accumulating */
        element_output_type* temp_ptr   = &(LIBXSMM_VLA_ACCESS(  5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock));
        for (oii = 0; oii < handle->fwd_ofw_rb; ++oii) {
          LIBXSMM_PRAGMA_SIMD
            for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) {
              temp_ptr[ofm2] = (element_output_type)0;
            }
          temp_ptr += handle->ofmblock;
        }
      }
      for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1 += handle->blocksifm_blocking) {
        br_gemm_kernel_strided( &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm1, 0, 0, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block),
            &LIBXSMM_VLA_ACCESS(5,  input,  img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock),
            &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks);
      }
    }
  } else { /* Offset based BRGEMM */
    for (imgofm1ofhofw = thr_begin; imgofm1ofhofw < thr_end; ++imgofm1ofhofw) {
      /* flat index -> (img, ofm1, oj, oi); oi is the first output column of the tile */
      img = imgofm1ofhofw / work_KHW;
      ofm1 = (imgofm1ofhofw % work_KHW)/work_HW;
      oj = ((imgofm1ofhofw % work_KHW)%work_HW)/w_tasks;
      oi = (((imgofm1ofhofw % work_KHW)%work_HW)%w_tasks)*handle->fwd_ofw_rb;
      ij_use = (handle->pack_input == 1) ? oj : oj * handle->desc.u;
      ii_use = (handle->pack_input == 1) ? oi : oi * handle->desc.v;
      if ( ((handle->options & LIBXSMM_DNN_CONV_OPTION_OVERWRITE) > 0) && handle->avoid_acc_load == 0) {
        /* set the output tile (fwd_ofw_rb pixels x ofmblock) to zero before accumulating */
        element_output_type* temp_ptr   = &(LIBXSMM_VLA_ACCESS(  5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock));
        for (oii = 0; oii < handle->fwd_ofw_rb; ++oii) {
          LIBXSMM_PRAGMA_SIMD
            for (ofm2 = 0; ofm2 < handle->ofmblock; ++ofm2) {
              temp_ptr[ofm2] = (element_output_type)0;
            }
          temp_ptr += handle->ofmblock;
        }
      }
      /* the kernel walks the ifm blocks and R x S taps itself, using the precomputed
       * A_offsets (filter) / B_offsets (input) arrays relative to the base pointers below */
      for (ifm1 = 0; ifm1 < handle->blocksifm; ifm1 += handle->blocksifm_blocking) {
        br_gemm_kernel_offset( &LIBXSMM_VLA_ACCESS(7, weight, ofm1, ifm1, 0, 0, 0, 0, 0, handle->blocksifm, handle->desc.R, handle->desc.S, ifmblock_lp, handle->ofmblock, handle->fm_lp_block),
            &LIBXSMM_VLA_ACCESS(5,  input,  img, ifm1, ij_use, ii_use, 0, handle->blocksifm, IFH, IFW, handle->ifmblock),
            &LIBXSMM_VLA_ACCESS(5, output, img, ofm1, oj, oi, 0, handle->blocksofm, handle->ofhp, handle->ofwp, handle->ofmblock), &n_blocks, handle->A_offsets, handle->B_offsets);
      }
    }
  }
}
/* all threads of the team wait here before leaving the template */
libxsmm_barrier_wait(handle->barrier, ltid);
170 
171