/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Evangelos Georganas, Kunal Banerjee (Intel Corp.)
******************************************************************************/

/* here we assume that input and output use the same (bn) minibatch blocking */
const int bn = handle->bn;
const int bk = handle->bk;
const int bc = handle->bc;
const int nBlocksIFm = handle->desc.C / bc;
const int nBlocksOFm = handle->desc.K / bk;
const int nBlocksMB  = handle->desc.N / bn;

/* computing first logical thread */
const int ltid = tid - start_thread;

/* kernel used to transpose the filters */
libxsmm_xtransfunction tr_kernel = handle->tr_kernel;

#if defined(LIBXSMM_DNN_FC_BWD_FUSE_RELU) || defined(LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID)
/* number of eltwise tasks that could be run in parallel */
const int eltwise_work = nBlocksOFm * nBlocksMB;
/* compute chunk size */
const int eltwise_chunksize = (eltwise_work % handle->desc.threads == 0) ? (eltwise_work / handle->desc.threads) : ((eltwise_work / handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
const int eltwise_thr_begin = (ltid * eltwise_chunksize < eltwise_work) ? (ltid * eltwise_chunksize) : eltwise_work;
const int eltwise_thr_end = ((ltid + 1) * eltwise_chunksize < eltwise_work) ? ((ltid + 1) * eltwise_chunksize) : eltwise_work;
int mb1ofm1;
#endif

#ifdef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
/* number of dbias tasks that could be run in parallel */
const int dbias_work = nBlocksOFm;
/* compute chunk size */
const int dbias_chunksize = (dbias_work % handle->desc.threads == 0) ? (dbias_work / handle->desc.threads) : ((dbias_work / handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
const int dbias_thr_begin = (ltid * dbias_chunksize < dbias_work) ? (ltid * dbias_chunksize) : dbias_work;
const int dbias_thr_end = ((ltid + 1) * dbias_chunksize < dbias_work) ? ((ltid + 1) * dbias_chunksize) : dbias_work;
#endif
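
/* Worked example of the chunking scheme above (hypothetical numbers): with
 * eltwise_work = 10 tasks and handle->desc.threads = 4, the chunk size is
 * ceil(10/4) = 3, so thread 0 covers [0,3), thread 1 [3,6), thread 2 [6,9)
 * and thread 3 [9,10); the comparison against the total work clamps
 * begin/end into range even when ltid*chunksize overshoots. */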

/* loop variables */
int ofm1 = 0, mb1 = 0, ofm2 = 0, mb2 = 0;

#if defined(LIBXSMM_DNN_FC_BWD_FUSE_RELU) || defined(LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID)
/* the fused eltwise pass writes the masked/scaled gradient into scratch,
   past the C*K elements reserved for the transposed filters */
element_output_type *grad_output_ptr = ((element_output_type*)handle->scratch) + (handle->desc.C * handle->desc.K);
LIBXSMM_VLA_DECL(4, const element_output_type, doutput_orig, (element_output_type*)handle->grad_output->data, nBlocksOFm, bn, bk);
#else
element_output_type *grad_output_ptr = (element_output_type*)handle->grad_output->data;
#endif
LIBXSMM_VLA_DECL(4, element_output_type, doutput, grad_output_ptr, nBlocksOFm, bn, bk);

#ifdef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
LIBXSMM_VLA_DECL(2, float, dbias, (float*)handle->grad_bias->data, handle->bk);
#endif
#ifdef LIBXSMM_DNN_FC_BWD_FUSE_RELU
LIBXSMM_VLA_DECL(4, unsigned char, relumask, (unsigned char*)handle->relumask->data, nBlocksOFm, handle->bn, handle->bk);
#endif

/* lazy barrier init */
libxsmm_barrier_init(handle->barrier, ltid);

#if defined(LIBXSMM_DNN_FC_BWD_FUSE_RELU) || defined(LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID)
for ( mb1ofm1 = eltwise_thr_begin; mb1ofm1 < eltwise_thr_end; ++mb1ofm1 ) {
  mb1 = mb1ofm1 % nBlocksMB;
  ofm1 = mb1ofm1 / nBlocksMB;

  for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) {
    for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) {
      float l_cur_out = LIBXSMM_VLA_ACCESS(4, doutput_orig, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk);
#ifdef LIBXSMM_DNN_FC_BWD_FUSE_RELU
      /* zero the gradient wherever the forward ReLU was inactive */
      l_cur_out = (LIBXSMM_VLA_ACCESS(4, relumask, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) != 0) ? l_cur_out : (element_output_type)0;
#endif
#ifdef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
      /* apply the fused sigmoid backward scaling */
      l_cur_out = l_cur_out * (1.0f - l_cur_out);
#endif
      LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk) = l_cur_out;
    }
  }
}

/* wait for the eltwise pass to finish */
libxsmm_barrier_wait(handle->barrier, ltid);
#endif

#if defined(LIBXSMM_DNN_FC_BWD_FUSE_BIAS)
for ( ofm1 = dbias_thr_begin; ofm1 < dbias_thr_end; ++ofm1 ) {
  for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) {
    LIBXSMM_VLA_ACCESS(2, dbias, ofm1, ofm2, handle->bk) = 0.0f;
  }

  /* reduce the output gradient over the full minibatch */
  for ( mb1 = 0; mb1 < nBlocksMB; ++mb1 ) {
    for ( mb2 = 0; mb2 < handle->bn; ++mb2 ) {
      for ( ofm2 = 0; ofm2 < handle->bk; ++ofm2 ) {
        LIBXSMM_VLA_ACCESS(2, dbias, ofm1, ofm2, handle->bk) += LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1, mb2, ofm2, nBlocksOFm, handle->bn, handle->bk);
      }
    }
  }
}

/* wait for the dbias reduction to finish */
libxsmm_barrier_wait(handle->barrier, ltid);
#endif
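
/* Layout note (illustrative): the 4D views above use a blocked layout, so
 * LIBXSMM_VLA_ACCESS(4, doutput, mb1, ofm1, mb2, ofm2, nBlocksOFm, bn, bk)
 * resolves to the flat offset
 *   ((mb1 * nBlocksOFm + ofm1) * bn + mb2) * bk + ofm2,
 * i.e. doutput is an [nBlocksMB][nBlocksOFm][bn][bk] tensor. */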
if ( (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) || (kind == LIBXSMM_DNN_COMPUTE_KIND_BWDUPD) ) {
  const int use_2d_blocking = handle->bwd_2d_blocking;

  /* number of tasks that could be run in parallel */
  const int work = nBlocksIFm * nBlocksMB;
  /* compute chunk size */
  const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1);
  /* compute thr_begin and thr_end */
  const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work;
  const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work;

  /* number of transpose tasks that could be run in parallel */
  const int transpose_work = nBlocksIFm * nBlocksOFm;
  /* compute chunk size */
  const int transpose_chunksize = (transpose_work % handle->desc.threads == 0) ? (transpose_work / handle->desc.threads) : ((transpose_work / handle->desc.threads) + 1);
  /* compute thr_begin and thr_end */
  const int transpose_thr_begin = (ltid * transpose_chunksize < transpose_work) ? (ltid * transpose_chunksize) : transpose_work;
  const int transpose_thr_end = ((ltid + 1) * transpose_chunksize < transpose_work) ? ((ltid + 1) * transpose_chunksize) : transpose_work;

  /* loop variables */
  int ifm1 = 0, ifm2 = 0, ifm1ofm1 = 0, mb1ifm1 = 0;
  int im_tasks_per_thread = 0, in_tasks_per_thread = 0, my_in_start = 0, my_in_end = 0, my_im_start = 0, my_im_end = 0, my_row_id = 0, my_col_id = 0, row_teams = 0, column_teams = 0;

  LIBXSMM_VLA_DECL(4, const element_filter_type, filter,    (element_filter_type*)handle->reg_filter->data, nBlocksIFm, bc, bk);
  LIBXSMM_VLA_DECL(4,       element_input_type,  dinput,    (element_input_type* )handle->grad_input->data, nBlocksIFm, bn, bc);
  LIBXSMM_VLA_DECL(4,       element_filter_type, filter_tr, (element_filter_type*)handle->scratch,          nBlocksOFm, bk, bc);

  /* blocking factor over the accumulation (K) dimension */
  unsigned long long blocks = nBlocksOFm;
  int KB_BLOCKS = nBlocksOFm, BF = 1;
  BF = handle->bwd_bf;
  KB_BLOCKS = nBlocksOFm / BF;
  blocks = KB_BLOCKS;

  if (use_2d_blocking == 1) {
    row_teams = handle->bwd_row_teams;
    column_teams = handle->bwd_column_teams;
    my_col_id = ltid % column_teams;
    my_row_id = ltid / column_teams;
    im_tasks_per_thread = LIBXSMM_UPDIV(nBlocksMB, row_teams);
    in_tasks_per_thread = LIBXSMM_UPDIV(nBlocksIFm, column_teams);
    my_im_start = LIBXSMM_MIN(my_row_id * im_tasks_per_thread, nBlocksMB);
    my_im_end   = LIBXSMM_MIN((my_row_id + 1) * im_tasks_per_thread, nBlocksMB);
    my_in_start = LIBXSMM_MIN(my_col_id * in_tasks_per_thread, nBlocksIFm);
    my_in_end   = LIBXSMM_MIN((my_col_id + 1) * in_tasks_per_thread, nBlocksIFm);
  }

  /* transpose the weights: swap block order and transpose each bc x bk sub-block */
  for (ifm1ofm1 = transpose_thr_begin; ifm1ofm1 < transpose_thr_end; ++ifm1ofm1) {
    const unsigned int ubk = (unsigned int)bk;
    const unsigned int ubc = (unsigned int)bc;
    ofm1 = ifm1ofm1 / nBlocksIFm;
    ifm1 = ifm1ofm1 % nBlocksIFm;
    tr_kernel(&LIBXSMM_VLA_ACCESS(4, filter, ofm1, ifm1, 0, 0, nBlocksIFm, bc, bk), &ubk,
              &LIBXSMM_VLA_ACCESS(4, filter_tr, ifm1, ofm1, 0, 0, nBlocksOFm, bk, bc), &ubc);

#if 0
    /* scalar reference for the transpose kernel above */
    for (ofm2 = 0; ofm2 < bk; ++ofm2) {
      for (ifm2 = 0; ifm2 < bc; ++ifm2) {
        LIBXSMM_VLA_ACCESS(4, filter_tr, ifm1, ofm1, ofm2, ifm2, nBlocksOFm, bk, bc) =
          LIBXSMM_VLA_ACCESS(4, filter, ofm1, ifm1, ifm2, ofm2, nBlocksIFm, bc, bk);
      }
    }
#endif
  }

  /* wait for the transpose to finish */
  libxsmm_barrier_wait(handle->barrier, ltid);
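
  /* What the batch-reduce calls below compute, in illustrative scalar form:
   * for one (mb1, ifm1) block the kernel accumulates over `blocks` K-blocks,
   *   dinput[mb1][ifm1][0:bn][0:bc] += sum_kb doutput[mb1][kb][0:bn][0:bk]
   *                                         * filter_tr[ifm1][kb][0:bk][0:bc],
   * i.e. dX = dY * W^T on one block; the *_zerobeta variant overwrites the
   * destination (beta = 0) instead of accumulating into it. */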
  if (use_2d_blocking == 1) {
    if (BF > 1) {
      for ( ofm1 = 0; ofm1 < BF; ++ofm1 ) {
        for (ifm1 = my_in_start; ifm1 < my_in_end; ++ifm1) {
          for (mb1 = my_im_start; mb1 < my_im_end; ++mb1) {
            /* initialize the accumulator on the first K-block */
            if ( ofm1 == 0 ) {
              for ( mb2 = 0; mb2 < bn; ++mb2 ) {
                for ( ifm2 = 0; ifm2 < bc; ++ifm2 ) {
                  LIBXSMM_VLA_ACCESS(4, dinput, mb1, ifm1, mb2, ifm2, nBlocksIFm, bn, bc) = (element_input_type)0;
                }
              }
            }
            batchreduce_kernel_bwd( &LIBXSMM_VLA_ACCESS(4, filter_tr, ifm1, ofm1*KB_BLOCKS, 0, 0, nBlocksOFm, bk, bc),
                                    &LIBXSMM_VLA_ACCESS(4, doutput,   mb1,  ofm1*KB_BLOCKS, 0, 0, nBlocksOFm, bn, bk),
                                    &LIBXSMM_VLA_ACCESS(4, dinput,    mb1,  ifm1, 0, 0, nBlocksIFm, bn, bc), &blocks);
          }
        }
      }
    } else {
      for (ifm1 = my_in_start; ifm1 < my_in_end; ++ifm1) {
        for (mb1 = my_im_start; mb1 < my_im_end; ++mb1) {
          batchreduce_kernel_bwd_zerobeta( &LIBXSMM_VLA_ACCESS(4, filter_tr, ifm1, 0, 0, 0, nBlocksOFm, bk, bc),
                                           &LIBXSMM_VLA_ACCESS(4, doutput,   mb1,  0, 0, 0, nBlocksOFm, bn, bk),
                                           &LIBXSMM_VLA_ACCESS(4, dinput,    mb1,  ifm1, 0, 0, nBlocksIFm, bn, bc), &blocks);
        }
      }
    }
  } else {
    if (BF > 1) {
      for ( ofm1 = 0; ofm1 < BF; ++ofm1 ) {
        for ( mb1ifm1 = thr_begin; mb1ifm1 < thr_end; ++mb1ifm1 ) {
          mb1 = mb1ifm1 % nBlocksMB;
          ifm1 = mb1ifm1 / nBlocksMB;
          /* initialize the accumulator on the first K-block */
          if ( ofm1 == 0 ) {
            for ( mb2 = 0; mb2 < bn; ++mb2 ) {
              for ( ifm2 = 0; ifm2 < bc; ++ifm2 ) {
                LIBXSMM_VLA_ACCESS(4, dinput, mb1, ifm1, mb2, ifm2, nBlocksIFm, bn, bc) = (element_input_type)0;
              }
            }
          }
          batchreduce_kernel_bwd( &LIBXSMM_VLA_ACCESS(4, filter_tr, ifm1, ofm1*KB_BLOCKS, 0, 0, nBlocksOFm, bk, bc),
                                  &LIBXSMM_VLA_ACCESS(4, doutput,   mb1,  ofm1*KB_BLOCKS, 0, 0, nBlocksOFm, bn, bk),
                                  &LIBXSMM_VLA_ACCESS(4, dinput,    mb1,  ifm1, 0, 0, nBlocksIFm, bn, bc), &blocks);
        }
      }
    } else {
      for ( mb1ifm1 = thr_begin; mb1ifm1 < thr_end; ++mb1ifm1 ) {
        mb1 = mb1ifm1 % nBlocksMB;
        ifm1 = mb1ifm1 / nBlocksMB;
        batchreduce_kernel_bwd_zerobeta( &LIBXSMM_VLA_ACCESS(4, filter_tr, ifm1, 0, 0, 0, nBlocksOFm, bk, bc),
                                         &LIBXSMM_VLA_ACCESS(4, doutput,   mb1,  0, 0, 0, nBlocksOFm, bn, bk),
                                         &LIBXSMM_VLA_ACCESS(4, dinput,    mb1,  ifm1, 0, 0, nBlocksIFm, bn, bc), &blocks);
      }
    }
  }

  libxsmm_barrier_wait(handle->barrier, ltid);
}
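
/* Worked example of the accumulation split used in the backward pass above
 * (hypothetical numbers): with nBlocksOFm = 8 and handle->bwd_bf = 2, each
 * batch-reduce call sums KB_BLOCKS = 4 K-blocks; pass ofm1 = 0 zero-fills
 * dinput and accumulates blocks 0..3, pass ofm1 = 1 accumulates blocks 4..7. */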
if ( (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) || (kind == LIBXSMM_DNN_COMPUTE_KIND_BWDUPD) ) {
  /* number of tasks that could be run in parallel */
  const int ofm_subtasks = (handle->upd_2d_blocking == 1) ? 1 : handle->ofm_subtasks;
  const int ifm_subtasks = (handle->upd_2d_blocking == 1) ? 1 : handle->ifm_subtasks;
  /* sub-block sizes when a bc x bk filter block is split among subtasks */
  const int bbk = (handle->upd_2d_blocking == 1) ? bk : bk / ofm_subtasks;
  const int bbc = (handle->upd_2d_blocking == 1) ? bc : bc / ifm_subtasks;
  const int work = nBlocksIFm * ifm_subtasks * nBlocksOFm * ofm_subtasks;
  const int Cck_work = nBlocksIFm * ifm_subtasks * ofm_subtasks;
  const int Cc_work = nBlocksIFm * ifm_subtasks;

  /* 2D blocking parameters */
  int use_2d_blocking = handle->upd_2d_blocking;
  int im_tasks_per_thread = 0, in_tasks_per_thread = 0, my_in_start = 0, my_in_end = 0, my_im_start = 0, my_im_end = 0, my_row_id = 0, my_col_id = 0, row_teams = 0, column_teams = 0;

  /* compute chunk size */
  const int chunksize = (work % handle->desc.threads == 0) ? (work / handle->desc.threads) : ((work / handle->desc.threads) + 1);
  /* compute thr_begin and thr_end */
  const int thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work;
  const int thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work;
  int BF = handle->upd_bf;

  /* loop variables */
  int ifm1ofm1 = 0, ifm1 = 0, ifm2 = 0, bfn = 0, ii = 0, jj = 0;

  /* batch-reduce related variables */
  unsigned long long blocks = nBlocksMB / BF;

  LIBXSMM_VLA_DECL(4, const element_input_type,  input,   (element_input_type* )handle->reg_input->data,   nBlocksIFm, bn, bc);
  LIBXSMM_VLA_DECL(4,       element_filter_type, dfilter, (element_filter_type*)handle->grad_filter->data, nBlocksIFm, bc, bk);

  if (use_2d_blocking == 1) {
    row_teams = handle->upd_row_teams;
    column_teams = handle->upd_column_teams;
    my_col_id = ltid % column_teams;
    my_row_id = ltid / column_teams;
    im_tasks_per_thread = LIBXSMM_UPDIV(nBlocksIFm, row_teams);
    in_tasks_per_thread = LIBXSMM_UPDIV(nBlocksOFm, column_teams);
    my_im_start = LIBXSMM_MIN(my_row_id * im_tasks_per_thread, nBlocksIFm);
    my_im_end   = LIBXSMM_MIN((my_row_id + 1) * im_tasks_per_thread, nBlocksIFm);
    my_in_start = LIBXSMM_MIN(my_col_id * in_tasks_per_thread, nBlocksOFm);
    my_in_end   = LIBXSMM_MIN((my_col_id + 1) * in_tasks_per_thread, nBlocksOFm);
  }

  if (use_2d_blocking == 1) {
    if (BF == 1) {
      for (ofm1 = my_in_start; ofm1 < my_in_end; ++ofm1) {
        for (ifm1 = my_im_start; ifm1 < my_im_end; ++ifm1) {
          batchreduce_kernel_upd_zerobeta( &LIBXSMM_VLA_ACCESS(4, doutput, 0, ofm1, 0, 0, nBlocksOFm, bn, bk),
                                           &LIBXSMM_VLA_ACCESS(4, input,   0, ifm1, 0, 0, nBlocksIFm, bn, bc),
                                           &LIBXSMM_VLA_ACCESS(4, dfilter, ofm1, ifm1, 0, 0, nBlocksIFm, bc, bk), &blocks);
        }
      }
    } else {
      for (bfn = 0; bfn < BF; bfn++) {
        for (ofm1 = my_in_start; ofm1 < my_in_end; ++ofm1) {
          for (ifm1 = my_im_start; ifm1 < my_im_end; ++ifm1) {
            /* initialize the current work task to zero */
            if (bfn == 0) {
              for (ii = 0; ii < bc; ii++) {
                for (jj = 0; jj < bk; jj++) {
                  LIBXSMM_VLA_ACCESS(4, dfilter, ofm1, ifm1, ii, jj, nBlocksIFm, bc, bk) = (element_filter_type)0;
                }
              }
            }
            batchreduce_kernel_upd( &LIBXSMM_VLA_ACCESS(4, doutput, bfn*blocks, ofm1, 0, 0, nBlocksOFm, bn, bk),
                                    &LIBXSMM_VLA_ACCESS(4, input,   bfn*blocks, ifm1, 0, 0, nBlocksIFm, bn, bc),
                                    &LIBXSMM_VLA_ACCESS(4, dfilter, ofm1, ifm1, 0, 0, nBlocksIFm, bc, bk), &blocks);
          }
        }
      }
    }
  } else {
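    /* How the flat index below decomposes (hypothetical numbers): with
     * nBlocksIFm = 4, ifm_subtasks = 2 and ofm_subtasks = 2, Cc_work = 8 and
     * Cck_work = 16, so ifm1ofm1 = 27 maps to ofm1 = 27/16 = 1,
     * ofm2 = (27%16)/8 = 1, ifm1 = ((27%16)%8)/2 = 1, ifm2 = ((27%16)%8)%2 = 1,
     * i.e. each task owns one bbc x bbk sub-block of one filter block. */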
    if (BF == 1) {
      for ( ifm1ofm1 = thr_begin; ifm1ofm1 < thr_end; ++ifm1ofm1 ) {
        ofm1 = ifm1ofm1 / Cck_work;
        ofm2 = (ifm1ofm1 % Cck_work) / Cc_work;
        ifm1 = ((ifm1ofm1 % Cck_work) % Cc_work) / ifm_subtasks;
        ifm2 = ((ifm1ofm1 % Cck_work) % Cc_work) % ifm_subtasks;

        batchreduce_kernel_upd_zerobeta( &LIBXSMM_VLA_ACCESS(4, doutput, 0, ofm1, 0, ofm2*bbk, nBlocksOFm, bn, bk),
                                         &LIBXSMM_VLA_ACCESS(4, input,   0, ifm1, 0, ifm2*bbc, nBlocksIFm, bn, bc),
                                         &LIBXSMM_VLA_ACCESS(4, dfilter, ofm1, ifm1, ifm2*bbc, ofm2*bbk, nBlocksIFm, bc, bk), &blocks);
      }
    } else {
      for (bfn = 0; bfn < BF; bfn++) {
        for ( ifm1ofm1 = thr_begin; ifm1ofm1 < thr_end; ++ifm1ofm1 ) {
          ofm1 = ifm1ofm1 / Cck_work;
          ofm2 = (ifm1ofm1 % Cck_work) / Cc_work;
          ifm1 = ((ifm1ofm1 % Cck_work) % Cc_work) / ifm_subtasks;
          ifm2 = ((ifm1ofm1 % Cck_work) % Cc_work) % ifm_subtasks;

          /* initialize the current work task to zero */
          if (bfn == 0) {
            for (ii = 0; ii < bbc; ii++) {
              for (jj = 0; jj < bbk; jj++) {
                LIBXSMM_VLA_ACCESS(4, dfilter, ofm1, ifm1, ifm2*bbc + ii, ofm2*bbk + jj, nBlocksIFm, bc, bk) = (element_filter_type)0;
              }
            }
          }

          batchreduce_kernel_upd( &LIBXSMM_VLA_ACCESS(4, doutput, bfn*blocks, ofm1, 0, ofm2*bbk, nBlocksOFm, bn, bk),
                                  &LIBXSMM_VLA_ACCESS(4, input,   bfn*blocks, ifm1, 0, ifm2*bbc, nBlocksIFm, bn, bc),
                                  &LIBXSMM_VLA_ACCESS(4, dfilter, ofm1, ifm1, ifm2*bbc, ofm2*bbk, nBlocksIFm, bc, bk), &blocks);
        }
      }
    }
  }

  libxsmm_barrier_wait(handle->barrier, ltid);
}