1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Evangelos Georganas, Kunal Banerjee (Intel Corp.)
10 ******************************************************************************/
/* Compile-time switch for the cycle-count instrumentation guarded by PROFILE
 * (timestamps read with _rdtsc()). Change "#if 0" to "#if 1" to enable it;
 * the collected counters are printed by logical thread 0 at the end of this
 * template. */
#if 0
#define PROFILE
#endif
14 
/* Loop indices and blocking helpers. BF and KB_BLOCKS are set below; the
 * remaining names are presumably used by the core template #included near the
 * end of this file, so they must keep these exact names. */
libxsmm_blasint j, ik, ikb, in, inb, ic, icb, jk, jb/*jn shadows global variable*/, jc, ek, en, ec, BF, KB_BLOCKS, KB;
/* Tensor dimensions: C is the innermost extent of x/dx, K that of the cell and
 * hidden-state tensors, N the second extent of the 3-D tensors, and
 * t = handle->T the extent of their leading (presumably time-step) dimension.
 * bk/bn/bc are the block sizes along K/N/C. */
libxsmm_blasint K = handle->desc.K;
libxsmm_blasint N = handle->desc.N;
libxsmm_blasint C = handle->desc.C;
libxsmm_blasint t = handle->T;
libxsmm_blasint bk = handle->bk;
libxsmm_blasint bn = handle->bn;
libxsmm_blasint bc = handle->bc;
/* Number of blocks along C, K and N. This is integer division, so it assumes C,
 * K and N are multiples of bc, bk and bn. NOTE(review): confirm the handle setup
 * enforces this; otherwise the remainders are silently dropped. */
const libxsmm_blasint cBlocks = C/bc;
const libxsmm_blasint kBlocks = K/bk;
const libxsmm_blasint nBlocks = N/bn;
unsigned long long blocks;
/* Raw data pointers of the handle's tensors and scratch buffers:
 * - forward inputs: xt, csp, hp;
 * - weights: wt, rt;
 * - per-gate tensors it/ft/ot/cit/cot and cst/ht (presumably saved from the
 *   forward pass);
 * - gradient tensors: dxt, dcsp, dhp, dw, dr, db, dcs, dht;
 * - scratch buffers.
 * ht is NULL when the handle has no ht tensor attached. */
element_input_type  *xt    = (element_input_type* )handle->xt->data;
element_input_type *csp    = (element_input_type* )handle->csp->data;
element_input_type *hpD    = (element_input_type* )handle->hp->data;
element_filter_type *wt    = (element_filter_type*)handle->wt->data;
element_filter_type *rt    = (element_filter_type*)handle->rt->data;
element_output_type *cst   = (element_output_type*)handle->cst->data;
element_output_type *ht    = handle->ht ? (element_output_type*)handle->ht->data : (element_output_type*)NULL;
element_output_type *it    = (element_output_type*)handle->it->data;
element_output_type *ft    = (element_output_type*)handle->ft->data;
element_output_type *ot    = (element_output_type*)handle->ot->data;
element_output_type *cit   = (element_output_type*)handle->cit->data;
element_output_type *cot   = (element_output_type*)handle->cot->data;
element_input_type  *dxt   = (element_input_type*)handle->dxt->data;
element_input_type  *dcsp  = (element_input_type* )handle->dcsp->data;
element_input_type  *dhpD  = (element_input_type* )handle->dhp->data;
element_filter_type *dw    = (element_filter_type*)handle->dw->data;
element_filter_type *dr    = (element_filter_type*)handle->dr->data;
element_output_type *db    = (element_output_type*)handle->db->data;
element_output_type *dcsD  = (element_output_type*)handle->dcs->data;
element_output_type *dht   = (element_output_type*)handle->dht->data;
element_output_type *diD   = (element_output_type*)handle->scratch_di;
element_output_type *dfD   = (element_output_type*)handle->scratch_df;
element_output_type *doD   = (element_output_type*)handle->scratch_do;
element_output_type *dciD  = (element_output_type*)handle->scratch_dci;
element_output_type *doutD = (element_output_type*)handle->scratch_deltat;
element_input_type  *scratch_xT  = (element_input_type* )handle->scratch_xT;
/* Disabled: scratch buffers for an explicitly transposed copy of the weights. */
#if 0
element_filter_type *scratch_wT  = (element_filter_type*)handle->scratch_wT;
element_filter_type *scratch_rT  = (element_filter_type*)handle->scratch_rT;
#endif
element_output_type *scratch_hT  = (element_output_type*)handle->scratch_hT;
/* Per-gate sub-tensors. wt and dw each hold four consecutive C*K blocks, rt and
 * dr four K*K blocks, and db four K-element blocks, in the gate order
 * i, c, f, o. */
element_filter_type *witD  = &(wt[0]);
element_filter_type *wctD  = &(wt[C*K]);
element_filter_type *wftD  = &(wt[2*C*K]);
element_filter_type *wotD  = &(wt[3*C*K]);
element_filter_type *ritD  = &(rt[0]);
element_filter_type *rctD  = &(rt[K*K]);
element_filter_type *rftD  = &(rt[2*K*K]);
element_filter_type *rotD  = &(rt[3*K*K]);
element_filter_type *dwiD  = &(dw[0]);
element_filter_type *dwcD  = &(dw[C*K]);
element_filter_type *dwfD  = &(dw[2*C*K]);
element_filter_type *dwoD  = &(dw[3*C*K]);
element_filter_type *driD  = &(dr[0]);
element_filter_type *drcD  = &(dr[K*K]);
element_filter_type *drfD  = &(dr[2*K*K]);
element_filter_type *droD  = &(dr[3*K*K]);
element_output_type *dbi   = &(db[0]);
element_output_type *dbc   = &(db[K]);
element_output_type *dbf   = &(db[2*K]);
element_output_type *dbo   = &(db[3*K]);
/* Disabled: per-gate slices of the transposed-weight scratch buffers above. */
#if 0
element_filter_type *scratch_wiT = &(scratch_wT[0]);
element_filter_type *scratch_wcT = &(scratch_wT[C*K]);
element_filter_type *scratch_wfT = &(scratch_wT[2*C*K]);
element_filter_type *scratch_woT = &(scratch_wT[3*C*K]);
element_filter_type *scratch_riT = &(scratch_rT[0]);
element_filter_type *scratch_rcT = &(scratch_rT[K*K]);
element_filter_type *scratch_rfT = &(scratch_rT[2*K*K]);
element_filter_type *scratch_roT = &(scratch_rT[3*K*K]);
#endif
/* Temporary scratch buffers, viewed below as rows of K elements. */
element_output_type *t1D   = (element_output_type*)handle->scratch_t1;
element_output_type *t2D   = (element_output_type*)handle->scratch_t2;
/* Multidimensional views (LIBXSMM_VLA_DECL) over the raw pointers above:
 * - 3-D views are [t][N][C or K];
 * - 2-D views are [rows][K or C or N];
 * - 4-D weight views are blocked by bc/bk. */
LIBXSMM_VLA_DECL(2, element_output_type, t1, t1D, K);
LIBXSMM_VLA_DECL(2, element_output_type, t2, t2D, K);
LIBXSMM_VLA_DECL(3, element_input_type,  x, xt, N, C);
LIBXSMM_VLA_DECL(2, element_input_type,  cp, csp, K);
LIBXSMM_VLA_DECL(2, element_input_type,  hp, hpD, K);
/* Disabled: views over non-transposed weights (wiD etc. are not declared in
 * this file). */
#if 0
LIBXSMM_VLA_DECL(4, element_filter_type, wi, wiD, cBlocks, bc, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, wf, wfD, cBlocks, bc, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, wo, woD, cBlocks, bc, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, wc, wcD, cBlocks, bc, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, ri, riD, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, rf, rfD, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, ro, roD, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, rc, rcD, kBlocks, bk, bk);
#endif
LIBXSMM_VLA_DECL(3, element_output_type, cs, cst, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, h, ht, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, i, it, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, f, ft, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, o, ot, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, ci, cit, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, co, cot, N, K);
LIBXSMM_VLA_DECL(3, element_input_type,  dx, dxt, N, C);
LIBXSMM_VLA_DECL(2, element_input_type,  dcp, dcsp, K);
LIBXSMM_VLA_DECL(2, element_input_type,  dhp, dhpD, K);
LIBXSMM_VLA_DECL(4, element_filter_type, dwi, dwiD, cBlocks, bc, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, dwf, dwfD, cBlocks, bc, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, dwo, dwoD, cBlocks, bc, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, dwc, dwcD, cBlocks, bc, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, dri, driD, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, drf, drfD, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, dro, droD, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, drc, drcD, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(2, element_output_type, dcs, dcsD, K);
LIBXSMM_VLA_DECL(3, element_output_type, dh, dht, N, K);
LIBXSMM_VLA_DECL(2, element_output_type, di, diD, K);
LIBXSMM_VLA_DECL(2, element_output_type, df, dfD, K);
LIBXSMM_VLA_DECL(2, element_output_type, dp, doD, K);
LIBXSMM_VLA_DECL(2, element_output_type, dci, dciD, K);
LIBXSMM_VLA_DECL(2, element_output_type, dout, doutD, K);
LIBXSMM_VLA_DECL(2, element_input_type,  xT, scratch_xT, N);
/* Weight views in a blocked "transposed" layout: [cBlocks][kBlocks][bk][bc]
 * for W and [kBlocks][kBlocks][bk][bk] for R. They alias wt/rt directly (the
 * explicit transpose that would fill them is compiled out), so this presumably
 * assumes wt/rt are already stored in this layout. TODO confirm. */
LIBXSMM_VLA_DECL(4, element_filter_type, wiT, witD, kBlocks, bk, bc);
LIBXSMM_VLA_DECL(4, element_filter_type, wcT, wctD, kBlocks, bk, bc);
LIBXSMM_VLA_DECL(4, element_filter_type, wfT, wftD, kBlocks, bk, bc);
LIBXSMM_VLA_DECL(4, element_filter_type, woT, wotD, kBlocks, bk, bc);
LIBXSMM_VLA_DECL(4, element_filter_type, riT, ritD, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, rcT, rctD, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, rfT, rftD, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, roT, rotD, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(2, element_output_type, hT, scratch_hT, N);
element_output_type *dout_ptr = NULL;
/* Batch-reduce GEMM kernels for the block shapes used by the core template.
 * Arguments are presumably (m, n, k, lda, ldb, ldc, alpha, beta, flags,
 * prefetch) per libxsmm's dispatch API; confirm. For example, kernela writes a
 * bc x bn tile with ldc = C, matching the row stride of dx. */
const libxsmm_smmfunction_reducebatch_addr batchreduce_kernela = libxsmm_smmdispatch_reducebatch_addr( bc, bn, bk, &bc, &K, &C, NULL, NULL, NULL, NULL);
const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelb = libxsmm_smmdispatch_reducebatch_addr( bk, bk, bn, &bk, &N, &bk, NULL, NULL, NULL, NULL);
const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelc = libxsmm_smmdispatch_reducebatch_addr( bk, bc, bn, &bk, &N, &bk, NULL, NULL, NULL, NULL);
const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelb1 = libxsmm_smmdispatch_reducebatch_addr( bk, bk, bn, &K, &N, &bk, NULL, NULL, NULL, NULL);
const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelc1 = libxsmm_smmdispatch_reducebatch_addr( bk, bc, bn, &K, &N, &bk, NULL, NULL, NULL, NULL);
const libxsmm_smmfunction_reducebatch_addr batchreduce_kerneld = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, NULL, NULL);

/* Pointer lists (A and B operands) passed to the batch-reduce GEMM calls.
 * Capacity is fixed at 1024 entries and no bound check is visible here.
 * NOTE(review): the core template must never batch more than 1024 blocks. */
const element_filter_type *A_array[1024];
const element_output_type *B_array[1024];

/* Blocked 4-D views (inner dimensions kBlocks x bn x bk) over the
 * scratch_diB/dfB/dpB/dciB buffers. */
LIBXSMM_VLA_DECL(4, element_output_type, diB, (element_output_type*)handle->scratch_diB, kBlocks, bn, bk);
LIBXSMM_VLA_DECL(4, element_output_type, dfB, (element_output_type*)handle->scratch_dfB, kBlocks, bn, bk);
LIBXSMM_VLA_DECL(4, element_output_type, dpB, (element_output_type*)handle->scratch_dpB, kBlocks, bn, bk);
LIBXSMM_VLA_DECL(4, element_output_type, dciB, (element_output_type*)handle->scratch_dciB, kBlocks, bn, bk);
161 
/* Logical thread id relative to the first thread (start_thread) of the team. */
const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread;

/* Static work partitioning. For each task domain below, every thread gets a
 * contiguous chunk of ceil(work / threads) tasks. thr_begin/thr_end are clamped
 * to the work size, so surplus threads receive an empty range. */

/* number of tasks that could be run in parallel for N and K blocks*/
const libxsmm_blasint work_nk = (N/bn) * (K/bk);
/* compute chunk size */
const libxsmm_blasint chunksize_nk = (work_nk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nk / (libxsmm_blasint)handle->desc.threads) : ((work_nk / (libxsmm_blasint)handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
const libxsmm_blasint thr_begin_nk = (ltid * chunksize_nk < work_nk) ? (ltid * chunksize_nk) : work_nk;
const libxsmm_blasint thr_end_nk = ((ltid + 1) * chunksize_nk < work_nk) ? ((ltid + 1) * chunksize_nk) : work_nk;

/* number of tasks that could be run in parallel for N and C blocks*/
const libxsmm_blasint work_nc = (N/bn) * (C/bc);
/* compute chunk size */
const libxsmm_blasint chunksize_nc = (work_nc % (libxsmm_blasint)handle->desc.threads == 0) ? (work_nc / (libxsmm_blasint)handle->desc.threads) : ((work_nc / (libxsmm_blasint)handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
const libxsmm_blasint thr_begin_nc = (ltid * chunksize_nc < work_nc) ? (ltid * chunksize_nc) : work_nc;
const libxsmm_blasint thr_end_nc = ((ltid + 1) * chunksize_nc < work_nc) ? ((ltid + 1) * chunksize_nc) : work_nc;

/* number of tasks that could be run in parallel for C and K blocks*/
const libxsmm_blasint work_ck = (C/bc) * (K/bk);
/* compute chunk size */
const libxsmm_blasint chunksize_ck = (work_ck % (libxsmm_blasint)handle->desc.threads == 0) ? (work_ck / (libxsmm_blasint)handle->desc.threads) : ((work_ck / (libxsmm_blasint)handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
const libxsmm_blasint thr_begin_ck = (ltid * chunksize_ck < work_ck) ? (ltid * chunksize_ck) : work_ck;
const libxsmm_blasint thr_end_ck = ((ltid + 1) * chunksize_ck < work_ck) ? ((ltid + 1) * chunksize_ck) : work_ck;

/* number of tasks that could be run in parallel for K and K blocks*/
const libxsmm_blasint work_kk = (K/bk) * (K/bk);
/* compute chunk size */
const libxsmm_blasint chunksize_kk = (work_kk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_kk / (libxsmm_blasint)handle->desc.threads) : ((work_kk / (libxsmm_blasint)handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
const libxsmm_blasint thr_begin_kk = (ltid * chunksize_kk < work_kk) ? (ltid * chunksize_kk) : work_kk;
const libxsmm_blasint thr_end_kk = ((ltid + 1) * chunksize_kk < work_kk) ? ((ltid + 1) * chunksize_kk) : work_kk;

/* AVX-512 path: K is partitioned in units of 16 elements; [k_thr_begin,
 * k_thr_end) is this thread's element range.
 * NOTE(review): k_tasks = K/16 rounds down, so when K is not a multiple of 16
 * the tail K%16 elements can be left unassigned (e.g. K = 40 with one thread
 * gives [0, 32)).
 * The __m512 accumulators dbi_sum/dbf_sum/dbo_sum/dbc_sum are declared at the
 * end of the k_thr_end line. */
#if defined(LIBXSMM_RNN_CELL_AVX512)
element_output_type *cps_ptr = NULL;
int k_tasks = K/16;
int k_chunksize = (k_tasks % (libxsmm_blasint)handle->desc.threads == 0) ? (k_tasks / (libxsmm_blasint)handle->desc.threads) : ((k_tasks / (libxsmm_blasint)handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
const libxsmm_blasint k_thr_begin = (ltid * k_chunksize * 16 < K) ? (ltid * k_chunksize * 16) : K;
const libxsmm_blasint k_thr_end = ((ltid + 1) * k_chunksize * 16 < K) ? ((ltid + 1) * k_chunksize * 16) : K;__m512 dbi_sum, dbf_sum, dbo_sum, dbc_sum;
#endif
/* number of tasks that could be run in parallel over K: one task per element
 * of K (not per K block) */
/* compute chunk size */
const libxsmm_blasint chunksize_k = (K % (libxsmm_blasint)handle->desc.threads == 0) ? (K / (libxsmm_blasint)handle->desc.threads) : ((K / (libxsmm_blasint)handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
const libxsmm_blasint thr_begin_k = (ltid * chunksize_k < K) ? (ltid * chunksize_k) : K;
const libxsmm_blasint thr_end_k = ((ltid + 1) * chunksize_k < K) ? ((ltid + 1) * chunksize_k) : K;
/* Per-phase cycle counters for the PROFILE instrumentation. */
#ifdef PROFILE
__int64_t _start, _end, eltwise_cycles = 0, dout_cycles = 0, weight_trans_cycles = 0, act_trans_cycles = 0, dx_cycles = 0, dwdr_cycles = 0, gradient_cycles = 0;
float total_time = 0.0;
#endif
/* 1 when both bc and bk are multiples of 16, otherwise 0 (presumably selects a
 * 16-wide code path in the core template). */
int bcbk_multiples_of_16 = ((bc % 16 == 0) && (bk % 16 == 0)) ? 1 : 0;

/* Flattened 2-D task indices for the partitioned loops. */
libxsmm_blasint ikic, inic, inik, icin, ikin;
218 
219 /* lazy barrier init */
220 libxsmm_barrier_init(handle->barrier, (int)ltid);
221 
222 /* Blocking reduction domain if it is too large */
223 BF = 1;
224 if (K > 1024 && K <= 2048) {
225   BF = 8;
226   while (kBlocks % BF != 0) {
227     BF--;
228   }
229 }
230 
231 if (K > 2048) {
232   BF = 16;
233   while (kBlocks % BF != 0) {
234     BF--;
235   }
236 }
237 KB_BLOCKS = kBlocks/BF;
238 
239 /* initialization is done at the beginning */
240 if ( (LIBXSMM_DNN_COMPUTE_KIND_BWD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) {
241   libxsmm_internal_matrix_zero(N*C*t, dxt, start_thread, tid, handle->desc.threads);
242 }
243 
244 /* initialization is done at the beginning */
245 if ( (LIBXSMM_DNN_COMPUTE_KIND_UPD == kind) || (LIBXSMM_DNN_COMPUTE_KIND_BWDUPD == kind) ) {
246   libxsmm_internal_matrix_zero(C*K*4, dw,  start_thread, tid, handle->desc.threads);
247   libxsmm_internal_matrix_zero(K*K*4, dr,  start_thread, tid, handle->desc.threads);
248   libxsmm_internal_matrix_zero(K*4,   db,  start_thread, tid, handle->desc.threads);
249 }
250 
/* Disabled: explicit transpose of the per-gate weights W (C x K) and R (K x K)
 * into the blocked wiT..woT / riT..roT views, timed under PROFILE. It reads the
 * wi..wo / ri..ro views, which are themselves only declared inside a disabled
 * block, so it cannot be re-enabled on its own. While it is compiled out,
 * wiT..roT alias wt/rt directly. */
#if 0
#ifdef PROFILE
if (ltid == 0) _start = _rdtsc();
#endif
/* transpose W */
for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) {
  ic = (ikic / (K/bk));
  ik = (ikic % (K/bk));
  for (jk = 0; jk < bk; ++jk) {
    for (jc = 0; jc < bc; ++jc) {
      LIBXSMM_VLA_ACCESS(4, wiT, ic, ik, jk, jc, kBlocks, bk, bc) =  LIBXSMM_VLA_ACCESS(4, wi, ik, ic, jc, jk, cBlocks, bc, bk);
      LIBXSMM_VLA_ACCESS(4, wcT, ic, ik, jk, jc, kBlocks, bk, bc) =  LIBXSMM_VLA_ACCESS(4, wc, ik, ic, jc, jk, cBlocks, bc, bk);
      LIBXSMM_VLA_ACCESS(4, wfT, ic, ik, jk, jc, kBlocks, bk, bc) =  LIBXSMM_VLA_ACCESS(4, wf, ik, ic, jc, jk, cBlocks, bc, bk);
      LIBXSMM_VLA_ACCESS(4, woT, ic, ik, jk, jc, kBlocks, bk, bc) =  LIBXSMM_VLA_ACCESS(4, wo, ik, ic, jc, jk, cBlocks, bc, bk);
    }
  }
}

/* transpose R */
for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) {
  ik = (ikic / (K/bk));
  ic = (ikic % (K/bk));
  for (jk = 0; jk < bk; ++jk) {
    for (jc = 0; jc < bk; ++jc) {
      LIBXSMM_VLA_ACCESS(4, riT, ic, ik, jk, jc, kBlocks, bk, bk) =  LIBXSMM_VLA_ACCESS(4, ri, ik, ic, jc, jk, kBlocks, bk, bk);
      LIBXSMM_VLA_ACCESS(4, rcT, ic, ik, jk, jc, kBlocks, bk, bk) =  LIBXSMM_VLA_ACCESS(4, rc, ik, ic, jc, jk, kBlocks, bk, bk);
      LIBXSMM_VLA_ACCESS(4, rfT, ic, ik, jk, jc, kBlocks, bk, bk) =  LIBXSMM_VLA_ACCESS(4, rf, ik, ic, jc, jk, kBlocks, bk, bk);
      LIBXSMM_VLA_ACCESS(4, roT, ic, ik, jk, jc, kBlocks, bk, bk) =  LIBXSMM_VLA_ACCESS(4, ro, ik, ic, jc, jk, kBlocks, bk, bk);
    }
  }
}
#ifdef PROFILE
if (ltid == 0) {
  _end = _rdtsc();
  weight_trans_cycles += _end - _start;
}
#endif
#endif

/* Backward / weight-update computation proper. It is textually included here
 * and presumably relies on the locals, views, kernels and thread ranges
 * declared above. */
#include "libxsmm_dnn_rnncell_st_lstm_bwdupd_nc_kcck_core.tpl.c"
291 
#ifdef PROFILE
/* Print the per-phase cycle counts recorded on logical thread 0. Each phase is
 * shown in milliseconds (assuming a fixed 2.5 GHz timestamp rate, i.e.
 * 2.5 * 1e9 cycles per second), as a share of the total, and for the GEMM
 * phases as GFLOPS.
 * The dimensions are cast to int because libxsmm_blasint is not guaranteed to
 * be int (e.g. ILP64 builds); passing a wider integer for %d is undefined
 * behavior. */
if (ltid == 0) {
  printf("----- PROFILING LSTM BWD/UPD (N = %d, C = %d, K = %d, bn = %d. bc = %d, bk = %d)----\n", (int)N, (int)C, (int)K, (int)bn, (int)bc, (int)bk );
  total_time = (gradient_cycles+dwdr_cycles+dx_cycles+act_trans_cycles+weight_trans_cycles+dout_cycles+eltwise_cycles)/(2.5 * 1e9)*1000.0f;
  printf("Transpose weights time is %f ms (%.2f%%)\n", weight_trans_cycles/(2.5 * 1e9)*1000.0f, weight_trans_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time );
  printf("Elementwise time is %f ms (%.2f%%)\n", eltwise_cycles/(2.5 * 1e9)*1000.0f, eltwise_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time );
  printf("Dx GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dx_cycles/(2.5 * 1e9)*1000.0f, dx_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*N*C*K*4/1e9/(dx_cycles/(2.5 * 1e9)));
  printf("Dh GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dout_cycles/(2.5 * 1e9)*1000.0f, dout_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*N*K*K*4/1e9/(dout_cycles/(2.5 * 1e9)));
  printf("Transpose input activations time is %f ms (%.2f%%)\n", act_trans_cycles/(2.5 * 1e9)*1000.0f, act_trans_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time );
  printf("Dwdr GEMM time is %f ms (%.2f%%) at %f GFLOPS\n", dwdr_cycles/(2.5 * 1e9)*1000.0f, dwdr_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*2.0*(N*K*K*2.0+N*C*K*2.0)*2.0/1e9/(dwdr_cycles/(2.5 * 1e9)));
  printf("Gradient bias calculation time is %f ms (%.2f%%)\n", gradient_cycles/(2.5 * 1e9)*1000.0f, gradient_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time );
}
#undef PROFILE
#endif
306