1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Evangelos Georganas, Kunal Banerjee (Intel Corp.)
10 ******************************************************************************/
11 #if 0
12 #define PROFILE
13 #endif
14 
/* helper variables */
libxsmm_blasint j, ik, ikb, in, inb, ic, icb, jk, jb/*jn shadows global variable*/, jc, ek, en, ec, BF, KB_BLOCKS, KB;
/* tensor dimensions */
libxsmm_blasint K = handle->desc.K;  /* output/hidden feature dimension */
libxsmm_blasint N = handle->desc.N;  /* minibatch size */
libxsmm_blasint C = handle->desc.C;  /* input feature dimension */
libxsmm_blasint t = handle->T;       /* number of time steps */
libxsmm_blasint bk = handle->bk;     /* blocking factor along K */
libxsmm_blasint bn = handle->bn;     /* blocking factor along N */
libxsmm_blasint bc = handle->bc;     /* blocking factor along C */
libxsmm_blasint K4 = K * 4;          /* leading dimension spanning all four gates (i,c,f,o) */
const libxsmm_blasint cBlocks = C/bc;
const libxsmm_blasint kBlocks = K/bk;
const libxsmm_blasint nBlocks = N/bn;
unsigned long long blocks;           /* batch-reduce count handed to the BRGEMM kernels */
/* tensor raw pointers */
element_input_type  *xt    = (element_input_type* )handle->xt->data;   /* input sequence x[t][N][C] */
element_input_type *csp    = (element_input_type* )handle->csp->data;  /* previous (t=-1) cell state */
element_input_type *hpD    = (element_input_type* )handle->hp->data;   /* previous (t=-1) hidden state */
element_filter_type *w     = (element_filter_type*)handle->w->data;    /* input weights; 4 gates side by side along K */
element_filter_type *r     = (element_filter_type*)handle->r->data;    /* recurrent weights; 4 gates side by side along K */
element_output_type *cst   = (element_output_type*)handle->cst->data;  /* cell states for all time steps */
element_output_type *ht    = handle->ht ? (element_output_type*)handle->ht->data : (element_output_type*)NULL; /* hidden states; may be absent */
element_output_type *it    = (element_output_type*)handle->it->data;   /* saved input-gate activations */
element_output_type *ft    = (element_output_type*)handle->ft->data;   /* saved forget-gate activations */
element_output_type *ot    = (element_output_type*)handle->ot->data;   /* saved output-gate activations */
element_output_type *cit   = (element_output_type*)handle->cit->data;  /* saved cell-input (candidate) activations */
element_output_type *cot   = (element_output_type*)handle->cot->data;  /* saved cell-output activations */
element_input_type  *dxt   = (element_input_type*)handle->dxt->data;   /* gradient w.r.t. input x */
element_input_type  *dcsp  = (element_input_type* )handle->dcsp->data; /* gradient w.r.t. previous cell state */
element_input_type  *dhpD  = (element_input_type* )handle->dhp->data;  /* gradient w.r.t. previous hidden state */
element_filter_type *dw    = (element_filter_type*)handle->dw->data;   /* gradient w.r.t. input weights (CK layout) */
element_filter_type *dr    = (element_filter_type*)handle->dr->data;   /* gradient w.r.t. recurrent weights (CK layout) */
element_output_type *db    = (element_output_type*)handle->db->data;   /* gradient w.r.t. biases, 4 gates concatenated */
element_output_type *dcsD  = (element_output_type*)handle->dcs->data;  /* incoming gradient w.r.t. final cell state */
element_output_type *dht   = (element_output_type*)handle->dht->data;  /* incoming gradient w.r.t. hidden states */
element_output_type *diD   = (element_output_type*)handle->scratch_di;  /* scratch: input-gate gradient */
element_output_type *dfD   = (element_output_type*)handle->scratch_df;  /* scratch: forget-gate gradient */
element_output_type *doD   = (element_output_type*)handle->scratch_do;  /* scratch: output-gate gradient */
element_output_type *dciD  = (element_output_type*)handle->scratch_dci; /* scratch: candidate gradient */
element_output_type *doutD = (element_output_type*)handle->scratch_deltat; /* scratch: delta carried across time steps */
element_input_type  *scratch_xT  = (element_input_type* )handle->scratch_xT;  /* scratch: transposed input activations */
element_filter_type *scratch_wT  = (element_filter_type*)handle->scratch_wT;  /* scratch: transposed input weights */
element_filter_type *scratch_rT  = (element_filter_type*)handle->scratch_rT;  /* scratch: transposed recurrent weights */
element_output_type *scratch_hT  = (element_output_type*)handle->scratch_hT;  /* scratch: transposed hidden states */
element_filter_type *w_scratch   = (element_filter_type*)handle->scratch_w;   /* scratch: blocked dw accumulator */
element_filter_type *r_scratch   = (element_filter_type*)handle->scratch_r;   /* scratch: blocked dr accumulator */
/* per-gate views: gates are stored side by side along K in the order
 * i, c, f, o, i.e. at column offsets 0, K, 2K, 3K within a row of length 4K */
element_filter_type *wiD   = &(w[0]);
element_filter_type *wcD   = &(w[K]);
element_filter_type *wfD   = &(w[2*K]);
element_filter_type *woD   = &(w[3*K]);
element_filter_type *riD   = &(r[0]);
element_filter_type *rcD   = &(r[K]);
element_filter_type *rfD   = &(r[2*K]);
element_filter_type *roD   = &(r[3*K]);
element_filter_type *dwiD  = &(dw[0]);
element_filter_type *dwcD  = &(dw[K]);
element_filter_type *dwfD  = &(dw[2*K]);
element_filter_type *dwoD  = &(dw[3*K]);
element_filter_type *driD  = &(dr[0]);
element_filter_type *drcD  = &(dr[K]);
element_filter_type *drfD  = &(dr[2*K]);
element_filter_type *droD  = &(dr[3*K]);
/* in the blocked scratch accumulators the gates are stored as four contiguous
 * C*K (resp. K*K) panels rather than interleaved along a row */
element_filter_type *dwiD_scratch  = &(w_scratch[0]);
element_filter_type *dwcD_scratch  = &(w_scratch[C*K]);
element_filter_type *dwfD_scratch  = &(w_scratch[2*C*K]);
element_filter_type *dwoD_scratch  = &(w_scratch[3*C*K]);
element_filter_type *driD_scratch  = &(r_scratch[0]);
element_filter_type *drcD_scratch  = &(r_scratch[K*K]);
element_filter_type *drfD_scratch  = &(r_scratch[2*K*K]);
element_filter_type *droD_scratch  = &(r_scratch[3*K*K]);
element_output_type *dbi   = &(db[0]);
element_output_type *dbc   = &(db[K]);
element_output_type *dbf   = &(db[2*K]);
element_output_type *dbo   = &(db[3*K]);
element_filter_type *scratch_wiT = &(scratch_wT[0]);
element_filter_type *scratch_wcT = &(scratch_wT[C*K]);
element_filter_type *scratch_wfT = &(scratch_wT[2*C*K]);
element_filter_type *scratch_woT = &(scratch_wT[3*C*K]);
element_filter_type *scratch_riT = &(scratch_rT[0]);
element_filter_type *scratch_rcT = &(scratch_rT[K*K]);
element_filter_type *scratch_rfT = &(scratch_rT[2*K*K]);
element_filter_type *scratch_roT = &(scratch_rT[3*K*K]);
element_output_type *t1D   = (element_output_type*)handle->scratch_t1; /* elementwise temporary */
element_output_type *t2D   = (element_output_type*)handle->scratch_t2; /* elementwise temporary */
/* multidimensional arrays */
LIBXSMM_VLA_DECL(2, element_output_type, t1, t1D, K);
LIBXSMM_VLA_DECL(2, element_output_type, t2, t2D, K);
LIBXSMM_VLA_DECL(3, element_input_type,  x, xt, N, C);
LIBXSMM_VLA_DECL(2, element_input_type,  cp, csp, K);
LIBXSMM_VLA_DECL(2, element_input_type,  hp, hpD, K);
/* weight views use leading dimension K4 = 4*K since the gates sit side by side */
LIBXSMM_VLA_DECL(2, element_filter_type, wi, wiD, K4);
LIBXSMM_VLA_DECL(2, element_filter_type, wf, wfD, K4);
LIBXSMM_VLA_DECL(2, element_filter_type, wo, woD, K4);
LIBXSMM_VLA_DECL(2, element_filter_type, wc, wcD, K4);
LIBXSMM_VLA_DECL(2, element_filter_type, ri, riD, K4);
LIBXSMM_VLA_DECL(2, element_filter_type, rf, rfD, K4);
LIBXSMM_VLA_DECL(2, element_filter_type, ro, roD, K4);
LIBXSMM_VLA_DECL(2, element_filter_type, rc, rcD, K4);
LIBXSMM_VLA_DECL(3, element_output_type, cs, cst, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, h, ht, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, i, it, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, f, ft, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, o, ot, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, ci, cit, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, co, cot, N, K);
LIBXSMM_VLA_DECL(3, element_input_type,  dx, dxt, N, C);
LIBXSMM_VLA_DECL(2, element_input_type,  dcp, dcsp, K);
LIBXSMM_VLA_DECL(2, element_input_type,  dhp, dhpD, K);
/* blocked [kb][cb][bc][bk] views over the weight-gradient scratch */
LIBXSMM_VLA_DECL(4, element_filter_type, dwi, dwiD_scratch, cBlocks, bc, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, dwf, dwfD_scratch, cBlocks, bc, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, dwo, dwoD_scratch, cBlocks, bc, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, dwc, dwcD_scratch, cBlocks, bc, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, dri, driD_scratch, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, drf, drfD_scratch, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, dro, droD_scratch, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, drc, drcD_scratch, kBlocks, bk, bk);
/* flat CK views over the user-visible weight gradients (filled at the end) */
LIBXSMM_VLA_DECL(2, element_filter_type, dwi_ck, dwiD, 4*K);
LIBXSMM_VLA_DECL(2, element_filter_type, dwf_ck, dwfD, 4*K);
LIBXSMM_VLA_DECL(2, element_filter_type, dwo_ck, dwoD, 4*K);
LIBXSMM_VLA_DECL(2, element_filter_type, dwc_ck, dwcD, 4*K);
LIBXSMM_VLA_DECL(2, element_filter_type, dri_ck, driD, 4*K);
LIBXSMM_VLA_DECL(2, element_filter_type, drf_ck, drfD, 4*K);
LIBXSMM_VLA_DECL(2, element_filter_type, dro_ck, droD, 4*K);
LIBXSMM_VLA_DECL(2, element_filter_type, drc_ck, drcD, 4*K);
LIBXSMM_VLA_DECL(2, element_output_type, dcs, dcsD, K);
LIBXSMM_VLA_DECL(3, element_output_type, dh, dht, N, K);
LIBXSMM_VLA_DECL(2, element_output_type, di, diD, K);
LIBXSMM_VLA_DECL(2, element_output_type, df, dfD, K);
LIBXSMM_VLA_DECL(2, element_output_type, dp, doD, K); /* output-gate gradient; named "dp" since "do" is a C keyword */
LIBXSMM_VLA_DECL(2, element_output_type, dci, dciD, K);
LIBXSMM_VLA_DECL(2, element_output_type, dout, doutD, K);
LIBXSMM_VLA_DECL(2, element_input_type,  xT, scratch_xT, N);
/* transposed, blocked weight views produced by the transpose loops below */
LIBXSMM_VLA_DECL(4, element_filter_type, wiT, scratch_wiT, kBlocks, bk, bc);
LIBXSMM_VLA_DECL(4, element_filter_type, wcT, scratch_wcT, kBlocks, bk, bc);
LIBXSMM_VLA_DECL(4, element_filter_type, wfT, scratch_wfT, kBlocks, bk, bc);
LIBXSMM_VLA_DECL(4, element_filter_type, woT, scratch_woT, kBlocks, bk, bc);
LIBXSMM_VLA_DECL(4, element_filter_type, riT, scratch_riT, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, rcT, scratch_rcT, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, rfT, scratch_rfT, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, roT, scratch_roT, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(2, element_output_type, hT, scratch_hT, N);
element_output_type *dout_ptr = NULL;
/* define batch-reduce gemm kernels */
/* NOTE(review): arguments assumed to follow the libxsmm dispatch convention
 * (m, n, k, lda, ldb, ldc, alpha, beta, flags, prefetch) -- verify against
 * libxsmm_smmdispatch_reducebatch_addr before altering any of these */
const libxsmm_smmfunction_reducebatch_addr batchreduce_kernela = libxsmm_smmdispatch_reducebatch_addr( bc, bn, bk, &bc, &K, &C, NULL, NULL, NULL, NULL);
const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelb = libxsmm_smmdispatch_reducebatch_addr( bk, bk, bn, &bk, &N, &bk, NULL, NULL, NULL, NULL);
const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelc = libxsmm_smmdispatch_reducebatch_addr( bk, bc, bn, &bk, &N, &bk, NULL, NULL, NULL, NULL);
const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelb1 = libxsmm_smmdispatch_reducebatch_addr( bk, bk, bn, &K, &N, &bk, NULL, NULL, NULL, NULL);
const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelc1 = libxsmm_smmdispatch_reducebatch_addr( bk, bc, bn, &K, &N, &bk, NULL, NULL, NULL, NULL);
const libxsmm_smmfunction_reducebatch_addr batchreduce_kerneld = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, NULL, NULL);

/* Auxiliary arrays for batch-reduce gemm calls */
/* NOTE(review): fixed capacity of 1024 address slots -- assumes no reduction
 * chain here ever exceeds 1024 blocks (not checked at runtime) */
const element_filter_type *A_array[1024];
const element_output_type *B_array[1024];

/* blocked [nb][kb][bn][bk] views over the per-gate gradient scratch */
LIBXSMM_VLA_DECL(4, element_output_type, diB, (element_output_type*)handle->scratch_diB, kBlocks, bn, bk);
LIBXSMM_VLA_DECL(4, element_output_type, dfB, (element_output_type*)handle->scratch_dfB, kBlocks, bn, bk);
LIBXSMM_VLA_DECL(4, element_output_type, dpB, (element_output_type*)handle->scratch_dpB, kBlocks, bn, bk);
LIBXSMM_VLA_DECL(4, element_output_type, dciB, (element_output_type*)handle->scratch_dciB, kBlocks, bn, bk);
174 
175 /* computing first logical thread */
176 const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread;
177 
178 /* number of tasks that could be run in parallel for N and K blocks*/
179 const libxsmm_blasint work_nk = (N/bn) * (K/bk);
180 /* compute chunk size */
181 const libxsmm_blasint chunksize_nk = (work_nk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nk / (libxsmm_blasint)handle->desc.threads) : ((work_nk / (libxsmm_blasint)handle->desc.threads) + 1);
182 /* compute thr_begin and thr_end */
183 const libxsmm_blasint thr_begin_nk = (ltid * chunksize_nk < work_nk) ? (ltid * chunksize_nk) : work_nk;
184 const libxsmm_blasint thr_end_nk = ((ltid + 1) * chunksize_nk < work_nk) ? ((ltid + 1) * chunksize_nk) : work_nk;
185 
186 /* number of tasks that could be run in parallel for N and C blocks*/
187 const libxsmm_blasint work_nc = (N/bn) * (C/bc);
188 /* compute chunk size */
189 const libxsmm_blasint chunksize_nc = (work_nc % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nc / (libxsmm_blasint)handle->desc.threads) : ((work_nc / (libxsmm_blasint)handle->desc.threads) + 1);
190 /* compute thr_begin and thr_end */
191 const libxsmm_blasint thr_begin_nc = (ltid * chunksize_nc < work_nc) ? (ltid * chunksize_nc) : work_nc;
192 const libxsmm_blasint thr_end_nc = ((ltid + 1) * chunksize_nc < work_nc) ? ((ltid + 1) * chunksize_nc) : work_nc;
193 
194 /* number of tasks that could be run in parallel for C and K blocks*/
195 const libxsmm_blasint work_ck = (C/bc) * (K/bk);
196 /* compute chunk size */
197 const libxsmm_blasint chunksize_ck = (work_ck % (libxsmm_blasint)handle->desc.threads == 0) ? (work_ck / (libxsmm_blasint)handle->desc.threads) : ((work_ck / (libxsmm_blasint)handle->desc.threads) + 1);
198 /* compute thr_begin and thr_end */
199 const libxsmm_blasint thr_begin_ck = (ltid * chunksize_ck < work_ck) ? (ltid * chunksize_ck) : work_ck;
200 const libxsmm_blasint thr_end_ck = ((ltid + 1) * chunksize_ck < work_ck) ? ((ltid + 1) * chunksize_ck) : work_ck;
201 
202 /* number of tasks that could be run in parallel for K and K blocks*/
203 const libxsmm_blasint work_kk = (K/bk) * (K/bk);
204 /* compute chunk size */
205 const libxsmm_blasint chunksize_kk = (work_kk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_kk / (libxsmm_blasint)handle->desc.threads) : ((work_kk / (libxsmm_blasint)handle->desc.threads) + 1);
206 /* compute thr_begin and thr_end */
207 const libxsmm_blasint thr_begin_kk = (ltid * chunksize_kk < work_kk) ? (ltid * chunksize_kk) : work_kk;
208 const libxsmm_blasint thr_end_kk = ((ltid + 1) * chunksize_kk < work_kk) ? ((ltid + 1) * chunksize_kk) : work_kk;
209 
210 #if defined(LIBXSMM_RNN_CELL_AVX512)
211 element_output_type *cps_ptr = NULL;
212 int k_tasks = K/16;
213 int k_chunksize = (k_tasks % (libxsmm_blasint)handle->desc.threads == 0) ? (k_tasks / (libxsmm_blasint)handle->desc.threads) : ((k_tasks / (libxsmm_blasint)handle->desc.threads) + 1);
214 /* compute thr_begin and thr_end */
215 const libxsmm_blasint k_thr_begin = (ltid * k_chunksize * 16 < K) ? (ltid * k_chunksize * 16) : K;
216 const libxsmm_blasint k_thr_end = ((ltid + 1) * k_chunksize * 16 < K) ? ((ltid + 1) * k_chunksize * 16) : K;__m512 dbi_sum, dbf_sum, dbo_sum, dbc_sum;
217 #endif
218 /* number of tasks that could be run in parallel for K blocks*/
219 /* compute chunk size */
220 const libxsmm_blasint chunksize_k = (K % (libxsmm_blasint)handle->desc.threads == 0) ? (K / (libxsmm_blasint)handle->desc.threads) : ((K / (libxsmm_blasint)handle->desc.threads) + 1);
221 /* compute thr_begin and thr_end */
222 const libxsmm_blasint thr_begin_k = (ltid * chunksize_k < K) ? (ltid * chunksize_k) : K;
223 const libxsmm_blasint thr_end_k = ((ltid + 1) * chunksize_k < K) ? ((ltid + 1) * chunksize_k) : K;
224 #ifdef PROFILE
225 __int64_t _start, _end, eltwise_cycles = 0, dout_cycles = 0, weight_trans_cycles = 0, act_trans_cycles = 0, dx_cycles = 0, dwdr_cycles = 0, gradient_cycles = 0, reformat_cycles = 0;
226 float total_time = 0.0;
227 #endif
228 int bcbk_multiples_of_16 = ((bc % 16 == 0) && (bk % 16 == 0)) ? 1 : 0;
229 
libxsmm_blasint ikic, inic, inik, icin, ikin; /* fused 2D loop indices used for task partitioning */

/* lazy barrier init */
libxsmm_barrier_init(handle->barrier, (int)ltid);

/* Blocking reduction domain if it is too large */
/* Split the K reduction into BF factors of KB_BLOCKS k-blocks each;
 * BF is decremented until it evenly divides kBlocks (worst case BF = 1). */
BF = 1;
if (K > 1024 && K <= 2048) {
  BF = 8;
  while (kBlocks % BF != 0) {
    BF--;
  }
}

if (K > 2048) {
  BF = 16;
  while (kBlocks % BF != 0) {
    BF--;
  }
}
KB_BLOCKS = kBlocks/BF;
251 
/* initialization is done at the beginning */
if ( (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) {
  /* zero dx over all t time steps; it is accumulated into later */
  libxsmm_internal_matrix_zero(N*C*t, dxt, start_thread, tid, handle->desc.threads);
}

/* initialization is done at the beginning */
if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) {
  /* zero the weight/recurrent/bias gradient accumulators (all four gates) */
  libxsmm_internal_matrix_zero(C*K*4, w_scratch,  start_thread, tid, handle->desc.threads);
  libxsmm_internal_matrix_zero(K*K*4, r_scratch,  start_thread, tid, handle->desc.threads);
  libxsmm_internal_matrix_zero(K*4,   db,  start_thread, tid, handle->desc.threads);
}
263 
#ifdef PROFILE
if (ltid == 0) _start = _rdtsc();
#endif
/* transpose W */
/* Repack each gate of W from the flat [C][4K] layout into the transposed,
 * blocked layout [cBlocks][kBlocks][bk][bc]; (ic, ik) block pairs are
 * distributed over threads via the thr_begin_ck..thr_end_ck range. */
for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) {
  ic = (ikic / (K/bk));
  ik = (ikic % (K/bk));
  for (jk = 0; jk < bk; ++jk) {
    for (jc = 0; jc < bc; ++jc) {
      LIBXSMM_VLA_ACCESS(4, wiT, ic, ik, jk, jc, kBlocks, bk, bc) =  LIBXSMM_VLA_ACCESS(2, wi, ic*bc+jc, ik*bk+jk, 4*K);
      LIBXSMM_VLA_ACCESS(4, wcT, ic, ik, jk, jc, kBlocks, bk, bc) =  LIBXSMM_VLA_ACCESS(2, wc, ic*bc+jc, ik*bk+jk, 4*K);
      LIBXSMM_VLA_ACCESS(4, wfT, ic, ik, jk, jc, kBlocks, bk, bc) =  LIBXSMM_VLA_ACCESS(2, wf, ic*bc+jc, ik*bk+jk, 4*K);
      LIBXSMM_VLA_ACCESS(4, woT, ic, ik, jk, jc, kBlocks, bk, bc) =  LIBXSMM_VLA_ACCESS(2, wo, ic*bc+jc, ik*bk+jk, 4*K);
    }
  }
}

/* transpose R */
/* Same repacking for the recurrent weights R: [K][4K] -> [kBlocks][kBlocks][bk][bk];
 * note ik/ic are derived in the opposite order from the W loop above. */
for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) {
  ik = (ikic / (K/bk));
  ic = (ikic % (K/bk));
  for (jk = 0; jk < bk; ++jk) {
    for (jc = 0; jc < bk; ++jc) {
      LIBXSMM_VLA_ACCESS(4, riT, ic, ik, jk, jc, kBlocks, bk, bk) =  LIBXSMM_VLA_ACCESS(2, ri, ic*bk+jc, ik*bk+jk, 4*K);
      LIBXSMM_VLA_ACCESS(4, rcT, ic, ik, jk, jc, kBlocks, bk, bk) =  LIBXSMM_VLA_ACCESS(2, rc, ic*bk+jc, ik*bk+jk, 4*K);
      LIBXSMM_VLA_ACCESS(4, rfT, ic, ik, jk, jc, kBlocks, bk, bk) =  LIBXSMM_VLA_ACCESS(2, rf, ic*bk+jc, ik*bk+jk, 4*K);
      LIBXSMM_VLA_ACCESS(4, roT, ic, ik, jk, jc, kBlocks, bk, bk) =  LIBXSMM_VLA_ACCESS(2, ro, ic*bk+jc, ik*bk+jk, 4*K);
    }
  }
}
#ifdef PROFILE
if (ltid == 0) {
  _end = _rdtsc();
  weight_trans_cycles += _end - _start;
}
#endif
300 
301 #include "libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_core.tpl.c"
302 
if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) {
#ifdef PROFILE
  if (ltid == 0) _start = _rdtsc();
#endif
  /* Store result weight matrices in CK format */
  /* Copy dw from the blocked accumulator [ikb][icb][bc][bk] back into the
   * user-visible flat [C][4K] layout; work split over (icb, ikb) pairs. */
  for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) {
    icb = ikic / (K/bk);
    ic = icb*bc;
    ikb = ikic % (K/bk);
    ik = ikb*bk;
    for (jc = 0; jc < bc; ++jc) {
      for (jk = 0; jk < bk; ++jk) {
        LIBXSMM_VLA_ACCESS(2, dwi_ck, ic+jc, ik+jk , K4) = LIBXSMM_VLA_ACCESS(4, dwi, ikb, icb, jc, jk, cBlocks, bc, bk);
        LIBXSMM_VLA_ACCESS(2, dwc_ck, ic+jc, ik+jk , K4) = LIBXSMM_VLA_ACCESS(4, dwc, ikb, icb, jc, jk, cBlocks, bc, bk);
        LIBXSMM_VLA_ACCESS(2, dwf_ck, ic+jc, ik+jk , K4) = LIBXSMM_VLA_ACCESS(4, dwf, ikb, icb, jc, jk, cBlocks, bc, bk);
        LIBXSMM_VLA_ACCESS(2, dwo_ck, ic+jc, ik+jk , K4) = LIBXSMM_VLA_ACCESS(4, dwo, ikb, icb, jc, jk, cBlocks, bc, bk);
      }
    }
  }

  /* Same reformat for dr: blocked [ikb][icb][bk][bk] -> flat [K][4K] */
  for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) {
    icb = ikic / (K/bk);
    ic = icb*bk;
    ikb = ikic % (K/bk);
    ik = ikb*bk;
    for (jc = 0; jc < bk; ++jc) {
      for (jk = 0; jk < bk; ++jk) {
        LIBXSMM_VLA_ACCESS(2, dri_ck, ic+jc, ik+jk , K4) = LIBXSMM_VLA_ACCESS(4, dri, ikb, icb, jc, jk, kBlocks, bk, bk);
        LIBXSMM_VLA_ACCESS(2, drc_ck, ic+jc, ik+jk , K4) = LIBXSMM_VLA_ACCESS(4, drc, ikb, icb, jc, jk, kBlocks, bk, bk);
        LIBXSMM_VLA_ACCESS(2, drf_ck, ic+jc, ik+jk , K4) = LIBXSMM_VLA_ACCESS(4, drf, ikb, icb, jc, jk, kBlocks, bk, bk);
        LIBXSMM_VLA_ACCESS(2, dro_ck, ic+jc, ik+jk , K4) = LIBXSMM_VLA_ACCESS(4, dro, ikb, icb, jc, jk, kBlocks, bk, bk);
      }
    }
  }
  /* all threads must finish the reformat before the results are consumed */
  libxsmm_barrier_wait(handle->barrier, (int)ltid);
#ifdef PROFILE
  if (ltid == 0) {
    _end = _rdtsc();
    reformat_cycles += _end - _start;
  }
#endif
}
345 
#ifdef PROFILE
if (ltid == 0) {
  /* NOTE(review): cycle counts are converted to ms assuming a fixed 2.5 GHz
   * clock; "%d" is used for libxsmm_blasint values -- confirm libxsmm_blasint
   * is plain int in this build configuration (a 64-bit blasint would need a
   * wider format specifier) */
  printf("----- PROFILING LSTM BWD/UPD (N = %d, C = %d, K = %d, bn = %d. bc = %d, bk = %d)----\n", N, C, K, bn, bc, bk );
  total_time = (gradient_cycles+dwdr_cycles+dx_cycles+act_trans_cycles+weight_trans_cycles+dout_cycles+eltwise_cycles+reformat_cycles)/(2.5 * 1e9)*1000.0f;
  printf("Transpose weights time is %f ms (%.2f%%)\n", weight_trans_cycles/(2.5 * 1e9)*1000.0f, weight_trans_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time );
  printf("Elementwise time is %f ms (%.2f%%)\n", eltwise_cycles/(2.5 * 1e9)*1000.0f, eltwise_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time );
  printf("Dx GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dx_cycles/(2.5 * 1e9)*1000.0f, dx_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*N*C*K*4/1e9/(dx_cycles/(2.5 * 1e9)));
  printf("Dh GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dout_cycles/(2.5 * 1e9)*1000.0f, dout_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*N*K*K*4/1e9/(dout_cycles/(2.5 * 1e9)));
  printf("Transpose input activations time is %f ms (%.2f%%)\n", act_trans_cycles/(2.5 * 1e9)*1000.0f, act_trans_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time );
  printf("Dwdr GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dwdr_cycles/(2.5 * 1e9)*1000.0f, dwdr_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*(N*K*K*2.0+N*C*K*2.0)*2.0/1e9/(dwdr_cycles/(2.5 * 1e9)));
  printf("Gradient bias calculation time is %f ms (%.2f%%)\n", gradient_cycles/(2.5 * 1e9)*1000.0f, gradient_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time );
  printf("Reformat dwdr time is %f ms (%.2f%%)\n\n", reformat_cycles/(2.5 * 1e9)*1000.0f, reformat_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time );
}
#undef PROFILE
#endif
361