/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Evangelos Georganas, Kunal Banerjee (Intel Corp.)
******************************************************************************/

/* here we assume that the input and output tensors use the same blocking scheme */
const int bn = handle->bn;
const int bk = handle->bk;
const int bc = handle->bc;
const int nBlocksIFm = handle->desc.C / bc;
const int nBlocksOFm = handle->desc.K / bk;
const int nBlocksMB  = handle->desc.N / bn;

/* computing first logical thread */
const int ltid = tid - start_thread;

/* transpose kernel used to transpose the filters */
libxsmm_xtransfunction tr_kernel = handle->tr_kernel;

#if defined(LIBXSMM_DNN_FC_BWD_FUSE_RELU) || defined(LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID)
/* number of eltwise tasks that could be run in parallel */
const int eltwise_work = nBlocksOFm * nBlocksMB;
/* compute chunk size */
const int eltwise_chunksize = (eltwise_work % handle->desc.threads == 0) ? (eltwise_work / handle->desc.threads) : ((eltwise_work / handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
const int eltwise_thr_begin = (ltid * eltwise_chunksize < eltwise_work) ? (ltid * eltwise_chunksize) : eltwise_work;
const int eltwise_thr_end = ((ltid + 1) * eltwise_chunksize < eltwise_work) ? ((ltid + 1) * eltwise_chunksize) : eltwise_work;
int mb1ofm1;
#endif

#ifdef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
/* number of dbias tasks that could be run in parallel */
const int dbias_work = nBlocksOFm;
/* compute chunk size */
const int dbias_chunksize = (dbias_work % handle->desc.threads == 0) ? (dbias_work / handle->desc.threads) : ((dbias_work / handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
const int dbias_thr_begin = (ltid * dbias_chunksize < dbias_work) ? (ltid * dbias_chunksize) : dbias_work;
const int dbias_thr_end = ((ltid + 1) * dbias_chunksize < dbias_work) ? ((ltid + 1) * dbias_chunksize) : dbias_work;
#endif

/* loop variables */
int ofm1 = 0, mb1 = 0, ofm2 = 0, mb2 = 0;

#if defined(LIBXSMM_DNN_FC_BWD_FUSE_RELU) || defined(LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID)
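/* the gradient w.r.t. output is staged in scratch, past the C*K elements reserved for the
   transposed filters, so that the original grad_output tensor is not overwritten by the
   fused eltwise pass below */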
element_output_type *grad_output_ptr = ((element_output_type*)handle->scratch)+(handle->desc.C*handle->desc.K);
LIBXSMM_VLA_DECL(4, const element_output_type, doutput_orig, (element_output_type*)handle->grad_output->data, nBlocksOFm, bn, bk);
#else
element_output_type *grad_output_ptr = (element_output_type*)handle->grad_output->data;
#endif
LIBXSMM_VLA_DECL(4, element_output_type, doutput, grad_output_ptr, nBlocksOFm, bn, bk);

#ifdef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
LIBXSMM_VLA_DECL(2,       float,                   dbias, (float*)              handle->grad_bias->data,                          handle->bk);
#endif
#ifdef LIBXSMM_DNN_FC_BWD_FUSE_RELU
LIBXSMM_VLA_DECL(4, unsigned char,              relumask, (unsigned char*)      handle->relumask->data,   nBlocksOFm, handle->bn, handle->bk);
#endif

/* lazy barrier init */
libxsmm_barrier_init(handle->barrier, ltid);

#if defined(LIBXSMM_DNN_FC_BWD_FUSE_RELU) || defined(LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID)
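/* fused element-wise backward: read the original gradient, apply the stored ReLU mask and/or
   the sigmoid term x*(1-x), and write the result into the scratch copy used by the GEMMs below */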
for ( mb1ofm1 = eltwise_thr_begin; mb1ofm1 < eltwise_thr_end; ++mb1ofm1 ) {
  mb1  = mb1ofm1%nBlocksMB;
  ofm1 = mb1ofm1/nBlocksMB;

  for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) {
    for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) {
      float l_cur_out = LIBXSMM_VLA_ACCESS(4, doutput_orig, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk);
#ifdef LIBXSMM_DNN_FC_BWD_FUSE_RELU
      l_cur_out = (LIBXSMM_VLA_ACCESS(4, relumask, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) != 0) ? l_cur_out : (element_output_type)0;
#endif
#ifdef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
      l_cur_out = l_cur_out*(1.0f - l_cur_out);
#endif
      LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = l_cur_out;
    }
  }
}

/* wait for eltwise to finish */
libxsmm_barrier_wait(handle->barrier, ltid);
#endif

#if defined(LIBXSMM_DNN_FC_BWD_FUSE_BIAS)
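/* bias gradient: zero the dbias blocks assigned to this thread, then reduce the gradient over the full minibatch */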
for ( ofm1 = dbias_thr_begin; ofm1 < dbias_thr_end; ++ofm1 ) {
  for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) {
    LIBXSMM_VLA_ACCESS( 2, dbias, ofm1, ofm2, handle->bk ) = 0.0f;
  }

  for ( mb1 = 0; mb1 < nBlocksMB; ++mb1 ) {
    for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) {
      for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) {
        LIBXSMM_VLA_ACCESS( 2, dbias, ofm1, ofm2, handle->bk ) += LIBXSMM_VLA_ACCESS(4,  doutput, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk);
      }
    }
  }
}

/* wait for dbias reduction to finish */
libxsmm_barrier_wait(handle->barrier, ltid);
#endif

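/* backward pass w.r.t. the input: dinput = doutput * W^T, computed block-wise using the transposed filters */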
if ( (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) || (kind == LIBXSMM_DNN_COMPUTE_KIND_BWDUPD) ) {
  const int use_2d_blocking = handle->bwd_2d_blocking;

  /* number of tasks that could be run in parallel */
  const int work = nBlocksIFm * nBlocksMB;
  /* compute chunk size */
  const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1);
  /* compute thr_begin and thr_end */
  const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work;
  const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work;

  /* number of tasks for transpose that could be run in parallel */
  const int transpose_work = nBlocksIFm * nBlocksOFm;
  /* compute chunk size */
  const int transpose_chunksize = (transpose_work % handle->desc.threads == 0) ? (transpose_work / handle->desc.threads) : ((transpose_work / handle->desc.threads) + 1);
  /* compute thr_begin and thr_end */
  const int transpose_thr_begin = (ltid * transpose_chunksize < transpose_work) ? (ltid * transpose_chunksize) : transpose_work;
  const int transpose_thr_end = ((ltid + 1) * transpose_chunksize < transpose_work) ? ((ltid + 1) * transpose_chunksize) : transpose_work;

  /* loop variables */
  int ifm1 = 0, ifm2 = 0, ifm1ofm1 = 0, mb1ifm1 = 0;
  int im_tasks_per_thread = 0, in_tasks_per_thread = 0, my_in_start = 0, my_in_end = 0, my_im_start = 0, my_im_end = 0, my_row_id = 0, my_col_id = 0, row_teams = 0, column_teams = 0;

  LIBXSMM_VLA_DECL(4, const element_filter_type, filter, (element_filter_type*)handle->reg_filter->data, nBlocksIFm, bc, bk);
  LIBXSMM_VLA_DECL(4,        element_input_type,    dinput, (element_input_type* )handle->grad_input->data, nBlocksIFm, bn, bc);
  LIBXSMM_VLA_DECL(4,       element_filter_type, filter_tr, (element_filter_type*)handle->scratch, nBlocksOFm, bk, bc);

  int BF = handle->bwd_bf;
  int KB_BLOCKS = nBlocksOFm / BF;
  unsigned long long blocks = KB_BLOCKS;

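  /* 2D blocking: threads form a row_teams x column_teams grid; each thread owns a contiguous
     range of minibatch blocks (rows) and of IFm blocks (columns) */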
143   if (use_2d_blocking == 1) {
144     row_teams = handle->bwd_row_teams;
145     column_teams = handle->bwd_column_teams;
146     my_col_id = ltid % column_teams;
147     my_row_id = ltid / column_teams;
148     im_tasks_per_thread = LIBXSMM_UPDIV(nBlocksMB, row_teams);
149     in_tasks_per_thread = LIBXSMM_UPDIV(nBlocksIFm, column_teams);
150     my_im_start = LIBXSMM_MIN(my_row_id * im_tasks_per_thread, nBlocksMB);
151     my_im_end = LIBXSMM_MIN((my_row_id+1) * im_tasks_per_thread, nBlocksMB);
152     my_in_start = LIBXSMM_MIN(my_col_id * in_tasks_per_thread, nBlocksIFm);
153     my_in_end = LIBXSMM_MIN((my_col_id+1) * in_tasks_per_thread, nBlocksIFm);
154   }
155 
156   /* transpose weight */
157   for (ifm1ofm1 = transpose_thr_begin; ifm1ofm1 < transpose_thr_end; ++ifm1ofm1) {
158     const unsigned int ubk = (unsigned int)bk;
159     const unsigned int ubc = (unsigned int)bc;
160     ofm1 = ifm1ofm1 / nBlocksIFm;
161     ifm1 = ifm1ofm1 % nBlocksIFm;
162     tr_kernel(&LIBXSMM_VLA_ACCESS(4, filter,  ofm1, ifm1, 0, 0, nBlocksIFm, bc, bk), &ubk,
163               &LIBXSMM_VLA_ACCESS(4, filter_tr, ifm1, ofm1, 0, 0, nBlocksOFm, bk, bc), &ubc);
164 
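    /* the disabled loop below is the scalar reference for the transpose performed by tr_kernel above */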
#if 0
    for (ofm2 = 0; ofm2 < bk; ++ofm2) {
      for (ifm2 = 0; ifm2 < bc; ++ifm2) {
        LIBXSMM_VLA_ACCESS(4, filter_tr, ifm1, ofm1, ofm2, ifm2, nBlocksOFm, bk, bc) =
          LIBXSMM_VLA_ACCESS(4, filter,  ofm1, ifm1, ifm2, ofm2, nBlocksIFm, bc, bk);
      }
    }
#endif
  }

  /* wait for transpose to finish */
  libxsmm_barrier_wait(handle->barrier, ltid);

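  /* dinput = doutput * W^T via batch-reduce GEMM; with BF > 1 the OFm dimension is accumulated
     in BF chunks of KB_BLOCKS blocks, otherwise a single zero-beta call covers all OFm blocks */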
  if (use_2d_blocking == 1) {
    if (BF > 1) {
      for ( ofm1 = 0; ofm1 < BF; ++ofm1 ) {
        for (ifm1 = my_in_start; ifm1 < my_in_end; ++ifm1) {
          for (mb1 = my_im_start; mb1 < my_im_end; ++mb1) {
            /* zero-initialize dinput on the first accumulation chunk */
            if ( ofm1 == 0 ) {
              for ( mb2 = 0; mb2 < bn; ++mb2 ) {
                for ( ifm2 = 0; ifm2 < bc; ++ifm2 ) {
                  LIBXSMM_VLA_ACCESS(4, dinput, mb1, ifm1, mb2, ifm2, nBlocksIFm, bn, bc) = (element_input_type)0;
                }
              }
            }
            batchreduce_kernel_bwd( &LIBXSMM_VLA_ACCESS(4, filter_tr, ifm1, ofm1*KB_BLOCKS, 0, 0, nBlocksOFm, bk, bc ),
                &LIBXSMM_VLA_ACCESS(4, doutput,   mb1,  ofm1*KB_BLOCKS, 0, 0, nBlocksOFm, bn, bk),
                &LIBXSMM_VLA_ACCESS(4, dinput,    mb1,  ifm1, 0, 0, nBlocksIFm, bn, bc), &blocks);
          }
        }
      }
    } else {
      for (ifm1 = my_in_start; ifm1 < my_in_end; ++ifm1) {
        for (mb1 = my_im_start; mb1 < my_im_end; ++mb1) {
          batchreduce_kernel_bwd_zerobeta( &LIBXSMM_VLA_ACCESS(4, filter_tr, ifm1, 0, 0, 0, nBlocksOFm, bk, bc),
              &LIBXSMM_VLA_ACCESS(4, doutput,   mb1,  0, 0, 0, nBlocksOFm, bn, bk),
              &LIBXSMM_VLA_ACCESS(4, dinput,    mb1,  ifm1, 0, 0, nBlocksIFm, bn, bc), &blocks);
        }
      }
    }
  } else {
    if (BF > 1) {
      for ( ofm1 = 0; ofm1 < BF; ++ofm1 ) {
        for ( mb1ifm1 = thr_begin; mb1ifm1 < thr_end; ++mb1ifm1 ) {
          mb1  = mb1ifm1%nBlocksMB;
          ifm1 = mb1ifm1/nBlocksMB;
          /* zero-initialize dinput on the first accumulation chunk */
          if ( ofm1 == 0 ) {
            for ( mb2 = 0; mb2 < bn; ++mb2 ) {
              for ( ifm2 = 0; ifm2 < bc; ++ifm2 ) {
                LIBXSMM_VLA_ACCESS(4, dinput, mb1, ifm1, mb2, ifm2, nBlocksIFm, bn, bc) = (element_input_type)0;
              }
            }
          }
          batchreduce_kernel_bwd( &LIBXSMM_VLA_ACCESS(4, filter_tr, ifm1, ofm1*KB_BLOCKS, 0, 0, nBlocksOFm, bk, bc ),
              &LIBXSMM_VLA_ACCESS(4, doutput,   mb1,  ofm1*KB_BLOCKS, 0, 0, nBlocksOFm, bn, bk),
              &LIBXSMM_VLA_ACCESS(4, dinput,    mb1,  ifm1, 0, 0, nBlocksIFm, bn, bc), &blocks);
        }
      }
    } else {
      for ( mb1ifm1 = thr_begin; mb1ifm1 < thr_end; ++mb1ifm1 ) {
        mb1  = mb1ifm1%nBlocksMB;
        ifm1 = mb1ifm1/nBlocksMB;
        batchreduce_kernel_bwd_zerobeta( &LIBXSMM_VLA_ACCESS(4, filter_tr, ifm1, 0, 0, 0, nBlocksOFm, bk, bc ),
            &LIBXSMM_VLA_ACCESS(4, doutput,   mb1,  0, 0, 0, nBlocksOFm, bn, bk),
            &LIBXSMM_VLA_ACCESS(4, dinput,    mb1,  ifm1, 0, 0, nBlocksIFm, bn, bc), &blocks);
      }
    }
  }

  libxsmm_barrier_wait(handle->barrier, ltid);
}

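/* weight-update pass: dfilter = input^T * doutput per (ofm1, ifm1) block, reduced over the minibatch */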
if ( (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) || (kind == LIBXSMM_DNN_COMPUTE_KIND_BWDUPD) ) {
  /* number of tasks that could be run in parallel: each bc x bk filter block may be further split
     into ifm_subtasks x ofm_subtasks sub-blocks of size bbc x bbk (no sub-splitting with 2D blocking) */
  const int ofm_subtasks = (handle->upd_2d_blocking == 1) ? 1 : handle->ofm_subtasks;
  const int ifm_subtasks = (handle->upd_2d_blocking == 1) ? 1 : handle->ifm_subtasks;
  const int bbk = (handle->upd_2d_blocking == 1) ? bk : bk/ofm_subtasks;
  const int bbc = (handle->upd_2d_blocking == 1) ? bc : bc/ifm_subtasks;
  const int work = nBlocksIFm * ifm_subtasks * nBlocksOFm * ofm_subtasks;
  const int Cck_work = nBlocksIFm * ifm_subtasks * ofm_subtasks;
  const int Cc_work = nBlocksIFm * ifm_subtasks;

  /* 2D blocking parameters */
  int use_2d_blocking = handle->upd_2d_blocking;
  int im_tasks_per_thread = 0, in_tasks_per_thread = 0, my_in_start = 0, my_in_end = 0, my_im_start = 0, my_im_end = 0, my_row_id = 0, my_col_id = 0, row_teams = 0, column_teams = 0;

  /* compute chunk size */
  const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1);
  /* compute thr_begin and thr_end */
  const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work;
  const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work;
  int BF = handle->upd_bf;

  /* loop variables */
  int ifm1ofm1 = 0, ifm1 = 0, ifm2 = 0, bfn = 0, ii = 0, jj = 0;

  /* Batch reduce related variables */
  unsigned long long  blocks = nBlocksMB/BF;

  LIBXSMM_VLA_DECL(4, const element_input_type,  input,    (element_input_type* )handle->reg_input->data, nBlocksIFm, bn, bc);
  LIBXSMM_VLA_DECL(4,       element_filter_type, dfilter,  (element_filter_type*)handle->grad_filter->data, nBlocksIFm, bc, bk);

  if (use_2d_blocking == 1) {
    row_teams = handle->upd_row_teams;
    column_teams = handle->upd_column_teams;
    my_col_id = ltid % column_teams;
    my_row_id = ltid / column_teams;
    im_tasks_per_thread = LIBXSMM_UPDIV(nBlocksIFm, row_teams);
    in_tasks_per_thread = LIBXSMM_UPDIV(nBlocksOFm, column_teams);
    my_im_start = LIBXSMM_MIN(my_row_id * im_tasks_per_thread, nBlocksIFm);
    my_im_end = LIBXSMM_MIN((my_row_id+1) * im_tasks_per_thread, nBlocksIFm);
    my_in_start = LIBXSMM_MIN(my_col_id * in_tasks_per_thread, nBlocksOFm);
    my_in_end = LIBXSMM_MIN((my_col_id+1) * in_tasks_per_thread, nBlocksOFm);
  }

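  /* dfilter accumulation: with BF == 1 a single zero-beta batch-reduce call covers all nBlocksMB
     minibatch blocks; with BF > 1 the minibatch is split into BF chunks of `blocks` blocks each
     and accumulated into dfilter */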
  if (use_2d_blocking == 1) {
    if (BF == 1) {
      for (ofm1 = my_in_start; ofm1 < my_in_end; ++ofm1) {
        for (ifm1 = my_im_start; ifm1 < my_im_end; ++ifm1) {
          batchreduce_kernel_upd_zerobeta(&LIBXSMM_VLA_ACCESS(4, doutput, 0, ofm1, 0, 0, nBlocksOFm, bn, bk),
                                          &LIBXSMM_VLA_ACCESS(4, input,   0, ifm1, 0, 0, nBlocksIFm, bn, bc),
                                          &LIBXSMM_VLA_ACCESS(4, dfilter, ofm1, ifm1, 0, 0, nBlocksIFm, bc, bk), &blocks);
        }
      }
    } else {
      for (bfn = 0; bfn < BF; bfn++) {
        for (ofm1 = my_in_start; ofm1 < my_in_end; ++ofm1) {
          for (ifm1 = my_im_start; ifm1 < my_im_end; ++ifm1) {
            /* initialize current work task to zero */
            if (bfn == 0) {
              for (ii = 0; ii<bc; ii++) {
                for (jj = 0; jj<bk; jj++) {
                  LIBXSMM_VLA_ACCESS(4, dfilter, ofm1, ifm1, ii, jj, nBlocksIFm, bc, bk) = (element_filter_type)0;
                }
              }
            }
            batchreduce_kernel_upd( &LIBXSMM_VLA_ACCESS(4, doutput, bfn*blocks, ofm1, 0, 0, nBlocksOFm, bn, bk),
                                    &LIBXSMM_VLA_ACCESS(4, input,   bfn*blocks, ifm1, 0, 0, nBlocksIFm, bn, bc),
                                    &LIBXSMM_VLA_ACCESS(4, dfilter, ofm1, ifm1, 0, 0, nBlocksIFm, bc, bk), &blocks);
          }
        }
      }
    }
  } else {
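    /* the flat task index ifm1ofm1 is decomposed into (ofm1, ofm2, ifm1, ifm2) using the
       Cck_work / Cc_work factors computed above */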
    if (BF == 1) {
      for ( ifm1ofm1 = thr_begin; ifm1ofm1 < thr_end; ++ifm1ofm1 ) {
        ofm1 = ifm1ofm1 / Cck_work;
        ofm2 = (ifm1ofm1 % Cck_work) / Cc_work;
        ifm1 = ((ifm1ofm1 % Cck_work) % Cc_work) / ifm_subtasks;
        ifm2 = ((ifm1ofm1 % Cck_work) % Cc_work) % ifm_subtasks;

        batchreduce_kernel_upd_zerobeta( &LIBXSMM_VLA_ACCESS(4, doutput, 0, ofm1, 0, ofm2*bbk, nBlocksOFm, bn, bk),
                                         &LIBXSMM_VLA_ACCESS(4, input,   0, ifm1, 0, ifm2*bbc, nBlocksIFm, bn, bc),
                                         &LIBXSMM_VLA_ACCESS(4, dfilter, ofm1, ifm1, ifm2*bbc, ofm2*bbk, nBlocksIFm, bc, bk), &blocks);
      }
    } else {
      for (bfn = 0; bfn < BF; bfn++) {
        for ( ifm1ofm1 = thr_begin; ifm1ofm1 < thr_end; ++ifm1ofm1 ) {
          ofm1 = ifm1ofm1 / Cck_work;
          ofm2 = (ifm1ofm1 % Cck_work) / Cc_work;
          ifm1 = ((ifm1ofm1 % Cck_work) % Cc_work) / ifm_subtasks;
          ifm2 = ((ifm1ofm1 % Cck_work) % Cc_work) % ifm_subtasks;

          /* initialize current work task to zero */
          if (bfn == 0) {
            for (ii = 0; ii<bbc; ii++) {
              for (jj = 0; jj<bbk; jj++) {
                LIBXSMM_VLA_ACCESS(4, dfilter, ofm1, ifm1, ifm2*bbc+ii, ofm2*bbk+jj, nBlocksIFm, bc, bk) = (element_filter_type)0;
              }
            }
          }

          batchreduce_kernel_upd( &LIBXSMM_VLA_ACCESS(4, doutput, bfn*blocks, ofm1, 0, ofm2*bbk, nBlocksOFm, bn, bk),
                                  &LIBXSMM_VLA_ACCESS(4, input,   bfn*blocks, ifm1, 0, ifm2*bbc, nBlocksIFm, bn, bc),
                                  &LIBXSMM_VLA_ACCESS(4, dfilter, ofm1, ifm1, ifm2*bbc, ofm2*bbk, nBlocksIFm, bc, bk), &blocks);
        }
      }
    }
  }

  libxsmm_barrier_wait(handle->barrier, ltid);
}