/******************************************************************************
 * Copyright (c) Intel Corporation - All rights reserved.                     *
 * This file is part of the LIBXSMM library.                                  *
 *                                                                            *
 * For information on the license, see the LICENSE file.                      *
 * Further information: https://github.com/hfp/libxsmm/                       *
 * SPDX-License-Identifier: BSD-3-Clause                                      *
 ******************************************************************************/
/* Evangelos Georganas, Kunal Banerjee (Intel Corp.)
 ******************************************************************************/
/* NOTE(review): This is a template fragment that is textually #include'd into
 * an enclosing dispatch function. `handle`, `tid`, `start_thread`, `kind` and
 * the `element_input_type`/`element_output_type`/`element_filter_type`
 * typedefs are expected to be in scope at the inclusion site — confirm
 * against the including .c file. It sets up all per-thread state for the
 * LSTM backward/update pass (NC activation layout, KCCK weight layout),
 * transposes W and R into blocked scratch, runs the core compute template,
 * and finally reformats the weight gradients back to CK layout. */
/* Flip to 1 to enable the cycle-level profiling printouts at the bottom. */
#if 0
#define PROFILE
#endif

/* helper variables */
libxsmm_blasint j, ik, ikb, in, inb, ic, icb, jk, jb/*jn shadows global variable*/, jc, ek, en, ec, BF, KB_BLOCKS, KB;
/* tensor dimensions */
libxsmm_blasint K = handle->desc.K;  /* output/state feature dimension */
libxsmm_blasint N = handle->desc.N;  /* minibatch */
libxsmm_blasint C = handle->desc.C;  /* input feature dimension */
libxsmm_blasint t = handle->T;       /* number of time steps */
libxsmm_blasint bk = handle->bk;     /* K blocking factor */
libxsmm_blasint bn = handle->bn;     /* N blocking factor */
libxsmm_blasint bc = handle->bc;     /* C blocking factor */
libxsmm_blasint K4 = K * 4;          /* leading dim of the 4-gate concatenated weights */
const libxsmm_blasint cBlocks = C/bc;
const libxsmm_blasint kBlocks = K/bk;
const libxsmm_blasint nBlocks = N/bn;
unsigned long long blocks;
/* tensor raw pointers */
element_input_type *xt = (element_input_type* )handle->xt->data;    /* input x over all t */
element_input_type *csp = (element_input_type* )handle->csp->data;  /* previous cell state */
element_input_type *hpD = (element_input_type* )handle->hp->data;   /* previous hidden state */
element_filter_type *w = (element_filter_type*)handle->w->data;     /* input weights, gates packed along K */
element_filter_type *r = (element_filter_type*)handle->r->data;     /* recurrent weights, gates packed along K */
element_output_type *cst = (element_output_type*)handle->cst->data; /* cell states over all t */
/* ht may be absent; guard before dereferencing the tensor handle */
element_output_type *ht = handle->ht ? (element_output_type*)handle->ht->data : (element_output_type*)NULL;
element_output_type *it = (element_output_type*)handle->it->data;   /* input-gate activations */
element_output_type *ft = (element_output_type*)handle->ft->data;   /* forget-gate activations */
element_output_type *ot = (element_output_type*)handle->ot->data;   /* output-gate activations */
element_output_type *cit = (element_output_type*)handle->cit->data; /* candidate (c~) activations */
element_output_type *cot = (element_output_type*)handle->cot->data; /* tanh(cell state) */
element_input_type *dxt = (element_input_type*)handle->dxt->data;   /* gradient w.r.t. input */
element_input_type *dcsp = (element_input_type* )handle->dcsp->data;
element_input_type *dhpD = (element_input_type* )handle->dhp->data;
element_filter_type *dw = (element_filter_type*)handle->dw->data;   /* final dW output (CK format) */
element_filter_type *dr = (element_filter_type*)handle->dr->data;   /* final dR output (CK format) */
element_output_type *db = (element_output_type*)handle->db->data;   /* bias gradients, 4 gates packed */
element_output_type *dcsD = (element_output_type*)handle->dcs->data;
element_output_type *dht = (element_output_type*)handle->dht->data; /* incoming gradient w.r.t. h */
/* per-gate gradient scratch (one N x K panel each) */
element_output_type *diD = (element_output_type*)handle->scratch_di;
element_output_type *dfD = (element_output_type*)handle->scratch_df;
element_output_type *doD = (element_output_type*)handle->scratch_do;
element_output_type *dciD = (element_output_type*)handle->scratch_dci;
element_output_type *doutD = (element_output_type*)handle->scratch_deltat;
/* transposed-activation / transposed-weight scratch */
element_input_type *scratch_xT = (element_input_type* )handle->scratch_xT;
element_filter_type *scratch_wT = (element_filter_type*)handle->scratch_wT;
element_filter_type *scratch_rT = (element_filter_type*)handle->scratch_rT;
element_output_type *scratch_hT = (element_output_type*)handle->scratch_hT;
/* blocked (KCCK) accumulation scratch for dW/dR; reformatted to CK at the end */
element_filter_type *w_scratch = (element_filter_type*)handle->scratch_w;
element_filter_type *r_scratch = (element_filter_type*)handle->scratch_r;
/* gate order inside the packed weight/bias buffers is i, c, f, o (stride K) */
element_filter_type *wiD = &(w[0]);
element_filter_type *wcD = &(w[K]);
element_filter_type *wfD = &(w[2*K]);
element_filter_type *woD = &(w[3*K]);
element_filter_type *riD = &(r[0]);
element_filter_type *rcD = &(r[K]);
element_filter_type *rfD = &(r[2*K]);
element_filter_type *roD = &(r[3*K]);
element_filter_type *dwiD = &(dw[0]);
element_filter_type *dwcD = &(dw[K]);
element_filter_type *dwfD = &(dw[2*K]);
element_filter_type *dwoD = &(dw[3*K]);
element_filter_type *driD = &(dr[0]);
element_filter_type *drcD = &(dr[K]);
element_filter_type *drfD = &(dr[2*K]);
element_filter_type *droD = &(dr[3*K]);
/* in the blocked scratch the gates are laid out as whole C*K (resp. K*K) planes */
element_filter_type *dwiD_scratch = &(w_scratch[0]);
element_filter_type *dwcD_scratch = &(w_scratch[C*K]);
element_filter_type *dwfD_scratch = &(w_scratch[2*C*K]);
element_filter_type *dwoD_scratch = &(w_scratch[3*C*K]);
element_filter_type *driD_scratch = &(r_scratch[0]);
element_filter_type *drcD_scratch = &(r_scratch[K*K]);
element_filter_type *drfD_scratch = &(r_scratch[2*K*K]);
element_filter_type *droD_scratch = &(r_scratch[3*K*K]);
element_output_type *dbi = &(db[0]);
element_output_type *dbc = &(db[K]);
element_output_type *dbf = &(db[2*K]);
element_output_type *dbo = &(db[3*K]);
element_filter_type *scratch_wiT = &(scratch_wT[0]);
element_filter_type *scratch_wcT = &(scratch_wT[C*K]);
element_filter_type *scratch_wfT = &(scratch_wT[2*C*K]);
element_filter_type *scratch_woT = &(scratch_wT[3*C*K]);
element_filter_type *scratch_riT = &(scratch_rT[0]);
element_filter_type *scratch_rcT = &(scratch_rT[K*K]);
element_filter_type *scratch_rfT = &(scratch_rT[2*K*K]);
element_filter_type *scratch_roT = &(scratch_rT[3*K*K]);
element_output_type *t1D = (element_output_type*)handle->scratch_t1;
element_output_type *t2D = (element_output_type*)handle->scratch_t2;
/* multidimensional arrays */
LIBXSMM_VLA_DECL(2, element_output_type, t1, t1D, K);
LIBXSMM_VLA_DECL(2, element_output_type, t2, t2D, K);
LIBXSMM_VLA_DECL(3, element_input_type, x, xt, N, C);   /* x[t][n][c] */
LIBXSMM_VLA_DECL(2, element_input_type, cp, csp, K);
LIBXSMM_VLA_DECL(2, element_input_type, hp, hpD, K);
/* per-gate views into the concatenated weights; row stride K4 = 4*K */
LIBXSMM_VLA_DECL(2, element_filter_type, wi, wiD, K4);
LIBXSMM_VLA_DECL(2, element_filter_type, wf, wfD, K4);
LIBXSMM_VLA_DECL(2, element_filter_type, wo, woD, K4);
LIBXSMM_VLA_DECL(2, element_filter_type, wc, wcD, K4);
LIBXSMM_VLA_DECL(2, element_filter_type, ri, riD, K4);
LIBXSMM_VLA_DECL(2, element_filter_type, rf, rfD, K4);
LIBXSMM_VLA_DECL(2, element_filter_type, ro, roD, K4);
LIBXSMM_VLA_DECL(2, element_filter_type, rc, rcD, K4);
LIBXSMM_VLA_DECL(3, element_output_type, cs, cst, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, h, ht, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, i, it, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, f, ft, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, o, ot, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, ci, cit, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, co, cot, N, K);
LIBXSMM_VLA_DECL(3, element_input_type, dx, dxt, N, C);
LIBXSMM_VLA_DECL(2, element_input_type, dcp, dcsp, K);
LIBXSMM_VLA_DECL(2, element_input_type, dhp, dhpD, K);
/* blocked gradient-weight views: [kBlocks][cBlocks][bc][bk] (KCCK) */
LIBXSMM_VLA_DECL(4, element_filter_type, dwi, dwiD_scratch, cBlocks, bc, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, dwf, dwfD_scratch, cBlocks, bc, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, dwo, dwoD_scratch, cBlocks, bc, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, dwc, dwcD_scratch, cBlocks, bc, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, dri, driD_scratch, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, drf, drfD_scratch, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, dro, droD_scratch, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, drc, drcD_scratch, kBlocks, bk, bk);
/* flat CK views used when copying the final gradients out; row stride 4*K */
LIBXSMM_VLA_DECL(2, element_filter_type, dwi_ck, dwiD, 4*K);
LIBXSMM_VLA_DECL(2, element_filter_type, dwf_ck, dwfD, 4*K);
LIBXSMM_VLA_DECL(2, element_filter_type, dwo_ck, dwoD, 4*K);
LIBXSMM_VLA_DECL(2, element_filter_type, dwc_ck, dwcD, 4*K);
LIBXSMM_VLA_DECL(2, element_filter_type, dri_ck, driD, 4*K);
LIBXSMM_VLA_DECL(2, element_filter_type, drf_ck, drfD, 4*K);
LIBXSMM_VLA_DECL(2, element_filter_type, dro_ck, droD, 4*K);
LIBXSMM_VLA_DECL(2, element_filter_type, drc_ck, drcD, 4*K);
LIBXSMM_VLA_DECL(2, element_output_type, dcs, dcsD, K);
LIBXSMM_VLA_DECL(3, element_output_type, dh, dht, N, K);
LIBXSMM_VLA_DECL(2, element_output_type, di, diD, K);
LIBXSMM_VLA_DECL(2, element_output_type, df, dfD, K);
LIBXSMM_VLA_DECL(2, element_output_type, dp, doD, K);
LIBXSMM_VLA_DECL(2, element_output_type, dci, dciD, K);
LIBXSMM_VLA_DECL(2, element_output_type, dout, doutD, K);
LIBXSMM_VLA_DECL(2, element_input_type, xT, scratch_xT, N);
/* transposed blocked weights: [cBlocks][kBlocks][bk][bc] for W, square blocks for R */
LIBXSMM_VLA_DECL(4, element_filter_type, wiT, scratch_wiT, kBlocks, bk, bc);
LIBXSMM_VLA_DECL(4, element_filter_type, wcT, scratch_wcT, kBlocks, bk, bc);
LIBXSMM_VLA_DECL(4, element_filter_type, wfT, scratch_wfT, kBlocks, bk, bc);
LIBXSMM_VLA_DECL(4, element_filter_type, woT, scratch_woT, kBlocks, bk, bc);
LIBXSMM_VLA_DECL(4, element_filter_type, riT, scratch_riT, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, rcT, scratch_rcT, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, rfT, scratch_rfT, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, roT, scratch_roT, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(2, element_output_type, hT, scratch_hT, N);
element_output_type *dout_ptr = NULL;
/* define batch-reduce gemm kernels (single-precision; dispatched once per call) */
const libxsmm_smmfunction_reducebatch_addr batchreduce_kernela = libxsmm_smmdispatch_reducebatch_addr( bc, bn, bk, &bc, &K, &C, NULL, NULL, NULL, NULL);
const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelb = libxsmm_smmdispatch_reducebatch_addr( bk, bk, bn, &bk, &N, &bk, NULL, NULL, NULL, NULL);
const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelc = libxsmm_smmdispatch_reducebatch_addr( bk, bc, bn, &bk, &N, &bk, NULL, NULL, NULL, NULL);
const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelb1 = libxsmm_smmdispatch_reducebatch_addr( bk, bk, bn, &K, &N, &bk, NULL, NULL, NULL, NULL);
const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelc1 = libxsmm_smmdispatch_reducebatch_addr( bk, bc, bn, &K, &N, &bk, NULL, NULL, NULL, NULL);
const libxsmm_smmfunction_reducebatch_addr batchreduce_kerneld = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, NULL, NULL);

/* Auxiliary arrays for batch-reduce gemm calls */
/* NOTE(review): fixed capacity 1024 bounds the reduction batch (KB_BLOCKS,
 * nBlocks); presumably guaranteed elsewhere by the handle setup — confirm. */
const element_filter_type *A_array[1024];
const element_output_type *B_array[1024];

/* blocked per-gate gradient buffers used by the core template */
LIBXSMM_VLA_DECL(4, element_output_type, diB, (element_output_type*)handle->scratch_diB, kBlocks, bn, bk);
LIBXSMM_VLA_DECL(4, element_output_type, dfB, (element_output_type*)handle->scratch_dfB, kBlocks, bn, bk);
LIBXSMM_VLA_DECL(4, element_output_type, dpB, (element_output_type*)handle->scratch_dpB, kBlocks, bn, bk);
LIBXSMM_VLA_DECL(4, element_output_type, dciB, (element_output_type*)handle->scratch_dciB, kBlocks, bn, bk);

/* computing first logical thread */
const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread;

/* Each of the following sections statically partitions a 2D block grid
 * (rounded-up chunk per thread, clamped to the total work). */
/* number of tasks that could be run in parallel for N and K blocks*/
const libxsmm_blasint work_nk = (N/bn) * (K/bk);
/* compute chunk size */
const libxsmm_blasint chunksize_nk = (work_nk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nk / (libxsmm_blasint)handle->desc.threads) : ((work_nk / (libxsmm_blasint)handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
const libxsmm_blasint thr_begin_nk = (ltid * chunksize_nk < work_nk) ? (ltid * chunksize_nk) : work_nk;
const libxsmm_blasint thr_end_nk = ((ltid + 1) * chunksize_nk < work_nk) ? ((ltid + 1) * chunksize_nk) : work_nk;

/* number of tasks that could be run in parallel for N and C blocks*/
const libxsmm_blasint work_nc = (N/bn) * (C/bc);
/* compute chunk size */
const libxsmm_blasint chunksize_nc = (work_nc % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nc / (libxsmm_blasint)handle->desc.threads) : ((work_nc / (libxsmm_blasint)handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
const libxsmm_blasint thr_begin_nc = (ltid * chunksize_nc < work_nc) ? (ltid * chunksize_nc) : work_nc;
const libxsmm_blasint thr_end_nc = ((ltid + 1) * chunksize_nc < work_nc) ? ((ltid + 1) * chunksize_nc) : work_nc;

/* number of tasks that could be run in parallel for C and K blocks*/
const libxsmm_blasint work_ck = (C/bc) * (K/bk);
/* compute chunk size */
const libxsmm_blasint chunksize_ck = (work_ck % (libxsmm_blasint)handle->desc.threads == 0) ? (work_ck / (libxsmm_blasint)handle->desc.threads) : ((work_ck / (libxsmm_blasint)handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
const libxsmm_blasint thr_begin_ck = (ltid * chunksize_ck < work_ck) ? (ltid * chunksize_ck) : work_ck;
const libxsmm_blasint thr_end_ck = ((ltid + 1) * chunksize_ck < work_ck) ? ((ltid + 1) * chunksize_ck) : work_ck;

/* number of tasks that could be run in parallel for K and K blocks*/
const libxsmm_blasint work_kk = (K/bk) * (K/bk);
/* compute chunk size */
const libxsmm_blasint chunksize_kk = (work_kk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_kk / (libxsmm_blasint)handle->desc.threads) : ((work_kk / (libxsmm_blasint)handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
const libxsmm_blasint thr_begin_kk = (ltid * chunksize_kk < work_kk) ? (ltid * chunksize_kk) : work_kk;
const libxsmm_blasint thr_end_kk = ((ltid + 1) * chunksize_kk < work_kk) ? ((ltid + 1) * chunksize_kk) : work_kk;

#if defined(LIBXSMM_RNN_CELL_AVX512)
element_output_type *cps_ptr = NULL;
/* K partitioned in 16-float (one zmm register) granules for the bias reduction */
int k_tasks = K/16;
int k_chunksize = (k_tasks % (libxsmm_blasint)handle->desc.threads == 0) ? (k_tasks / (libxsmm_blasint)handle->desc.threads) : ((k_tasks / (libxsmm_blasint)handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
const libxsmm_blasint k_thr_begin = (ltid * k_chunksize * 16 < K) ? (ltid * k_chunksize * 16) : K;
const libxsmm_blasint k_thr_end = ((ltid + 1) * k_chunksize * 16 < K) ? ((ltid + 1) * k_chunksize * 16) : K;
__m512 dbi_sum, dbf_sum, dbo_sum, dbc_sum;
#endif
/* number of tasks that could be run in parallel for K blocks*/
/* compute chunk size */
const libxsmm_blasint chunksize_k = (K % (libxsmm_blasint)handle->desc.threads == 0) ? (K / (libxsmm_blasint)handle->desc.threads) : ((K / (libxsmm_blasint)handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
const libxsmm_blasint thr_begin_k = (ltid * chunksize_k < K) ? (ltid * chunksize_k) : K;
const libxsmm_blasint thr_end_k = ((ltid + 1) * chunksize_k < K) ? ((ltid + 1) * chunksize_k) : K;
#ifdef PROFILE
__int64_t _start, _end, eltwise_cycles = 0, dout_cycles = 0, weight_trans_cycles = 0, act_trans_cycles = 0, dx_cycles = 0, dwdr_cycles = 0, gradient_cycles = 0, reformat_cycles = 0;
float total_time = 0.0;
#endif
/* the core template can use 16-wide (AVX-512) kernels only if both block sizes allow it */
int bcbk_multiples_of_16 = ((bc % 16 == 0) && (bk % 16 == 0)) ? 1 : 0;

libxsmm_blasint ikic, inic, inik, icin, ikin;

/* lazy barrier init */
libxsmm_barrier_init(handle->barrier, (int)ltid);

/* Blocking reduction domain if it is too large: pick the largest BF <= 8
 * (resp. 16) that evenly divides kBlocks, so KB_BLOCKS stays integral. */
BF = 1;
if (K > 1024 && K <= 2048) {
  BF = 8;
  while (kBlocks % BF != 0) {
    BF--;
  }
}

if (K > 2048) {
  BF = 16;
  while (kBlocks % BF != 0) {
    BF--;
  }
}
KB_BLOCKS = kBlocks/BF;

/* initialization is done at the beginning */
if ( (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) {
  libxsmm_internal_matrix_zero(N*C*t, dxt, start_thread, tid, handle->desc.threads);
}

/* initialization is done at the beginning */
if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) {
  libxsmm_internal_matrix_zero(C*K*4, w_scratch, start_thread, tid, handle->desc.threads);
  libxsmm_internal_matrix_zero(K*K*4, r_scratch, start_thread, tid, handle->desc.threads);
  libxsmm_internal_matrix_zero(K*4, db, start_thread, tid, handle->desc.threads);
}

#ifdef PROFILE
if (ltid == 0) _start = _rdtsc();
#endif
/* transpose W: wT[ic][ik][jk][jc] = w[ic*bc+jc][ik*bk+jk] (blocked C x K -> K x C) */
for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) {
  ic = (ikic / (K/bk));
  ik = (ikic % (K/bk));
  for (jk = 0; jk < bk; ++jk) {
    for (jc = 0; jc < bc; ++jc) {
      LIBXSMM_VLA_ACCESS(4, wiT, ic, ik, jk, jc, kBlocks, bk, bc) = LIBXSMM_VLA_ACCESS(2, wi, ic*bc+jc, ik*bk+jk, 4*K);
      LIBXSMM_VLA_ACCESS(4, wcT, ic, ik, jk, jc, kBlocks, bk, bc) = LIBXSMM_VLA_ACCESS(2, wc, ic*bc+jc, ik*bk+jk, 4*K);
      LIBXSMM_VLA_ACCESS(4, wfT, ic, ik, jk, jc, kBlocks, bk, bc) = LIBXSMM_VLA_ACCESS(2, wf, ic*bc+jc, ik*bk+jk, 4*K);
      LIBXSMM_VLA_ACCESS(4, woT, ic, ik, jk, jc, kBlocks, bk, bc) = LIBXSMM_VLA_ACCESS(2, wo, ic*bc+jc, ik*bk+jk, 4*K);
    }
  }
}

/* transpose R: same scheme on the square K x K recurrent weights */
for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) {
  ik = (ikic / (K/bk));
  ic = (ikic % (K/bk));
  for (jk = 0; jk < bk; ++jk) {
    for (jc = 0; jc < bk; ++jc) {
      LIBXSMM_VLA_ACCESS(4, riT, ic, ik, jk, jc, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(2, ri, ic*bk+jc, ik*bk+jk, 4*K);
      LIBXSMM_VLA_ACCESS(4, rcT, ic, ik, jk, jc, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(2, rc, ic*bk+jc, ik*bk+jk, 4*K);
      LIBXSMM_VLA_ACCESS(4, rfT, ic, ik, jk, jc, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(2, rf, ic*bk+jc, ik*bk+jk, 4*K);
      LIBXSMM_VLA_ACCESS(4, roT, ic, ik, jk, jc, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(2, ro, ic*bk+jc, ik*bk+jk, 4*K);
    }
  }
}
#ifdef PROFILE
if (ltid == 0) {
  _end = _rdtsc();
  weight_trans_cycles += _end - _start;
}
#endif

/* main backward/update compute over all time steps lives in the core template */
#include "libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_core.tpl.c"

if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) {
#ifdef PROFILE
  if (ltid == 0) _start = _rdtsc();
#endif
  /* Store result weight matrices in CK format */
  /* copy dW from blocked KCCK scratch back to the user-visible CK layout */
  for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) {
    icb = ikic / (K/bk);
    ic = icb*bc;
    ikb = ikic % (K/bk);
    ik = ikb*bk;
    for (jc = 0; jc < bc; ++jc) {
      for (jk = 0; jk < bk; ++jk) {
        LIBXSMM_VLA_ACCESS(2, dwi_ck, ic+jc, ik+jk , K4) = LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, jc, jk, cBlocks, bc, bk);
        LIBXSMM_VLA_ACCESS(2, dwc_ck, ic+jc, ik+jk , K4) = LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, jc, jk, cBlocks, bc, bk);
        LIBXSMM_VLA_ACCESS(2, dwf_ck, ic+jc, ik+jk , K4) = LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, jc, jk, cBlocks, bc, bk);
        LIBXSMM_VLA_ACCESS(2, dwo_ck, ic+jc, ik+jk , K4) = LIBXSMM_VLA_ACCESS(4, dwo, ikb, icb, jc, jk, cBlocks, bc, bk);
      }
    }
  }

  /* same for dR (square K x K blocks) */
  for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) {
    icb = ikic / (K/bk);
    ic = icb*bk;
    ikb = ikic % (K/bk);
    ik = ikb*bk;
    for (jc = 0; jc < bk; ++jc) {
      for (jk = 0; jk < bk; ++jk) {
        LIBXSMM_VLA_ACCESS(2, dri_ck, ic+jc, ik+jk , K4) = LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, jc, jk, kBlocks, bk, bk);
        LIBXSMM_VLA_ACCESS(2, drc_ck, ic+jc, ik+jk , K4) = LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, jc, jk, kBlocks, bk, bk);
        LIBXSMM_VLA_ACCESS(2, drf_ck, ic+jc, ik+jk , K4) = LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, jc, jk, kBlocks, bk, bk);
        LIBXSMM_VLA_ACCESS(2, dro_ck, ic+jc, ik+jk , K4) = LIBXSMM_VLA_ACCESS(4, dro, ikb, icb, jc, jk, kBlocks, bk, bk);
      }
    }
  }
  libxsmm_barrier_wait(handle->barrier, (int)ltid);
#ifdef PROFILE
  if (ltid == 0) {
    _end = _rdtsc();
    reformat_cycles += _end - _start;
  }
#endif
}

#ifdef PROFILE
/* NOTE(review): timings assume a fixed 2.5 GHz TSC frequency (hard-coded below) */
if (ltid == 0) {
  printf("----- PROFILING LSTM BWD/UPD (N = %d, C = %d, K = %d, bn = %d. bc = %d, bk = %d)----\n", N, C, K, bn, bc, bk );
  total_time = (gradient_cycles+dwdr_cycles+dx_cycles+act_trans_cycles+weight_trans_cycles+dout_cycles+eltwise_cycles+reformat_cycles)/(2.5 * 1e9)*1000.0f;
  printf("Transpose weights time is %f ms (%.2f%%)\n", weight_trans_cycles/(2.5 * 1e9)*1000.0f, weight_trans_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time );
  printf("Elementwise time is %f ms (%.2f%%)\n", eltwise_cycles/(2.5 * 1e9)*1000.0f, eltwise_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time );
  printf("Dx GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dx_cycles/(2.5 * 1e9)*1000.0f, dx_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*N*C*K*4/1e9/(dx_cycles/(2.5 * 1e9)));
  printf("Dh GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dout_cycles/(2.5 * 1e9)*1000.0f, dout_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*N*K*K*4/1e9/(dout_cycles/(2.5 * 1e9)));
  printf("Transpose input activations time is %f ms (%.2f%%)\n", act_trans_cycles/(2.5 * 1e9)*1000.0f, act_trans_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time );
  printf("Dwdr GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dwdr_cycles/(2.5 * 1e9)*1000.0f, dwdr_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*(N*K*K*2.0+N*C*K*2.0)*2.0/1e9/(dwdr_cycles/(2.5 * 1e9)));
  printf("Gradient bias calculation time is %f ms (%.2f%%)\n", gradient_cycles/(2.5 * 1e9)*1000.0f, gradient_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time );
  printf("Reformat dwdr time is %f ms (%.2f%%)\n\n", reformat_cycles/(2.5 * 1e9)*1000.0f, reformat_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time );
}
#undef PROFILE
#endif