/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Kunal Banerjee (Intel Corp.)
******************************************************************************/

/* helper variables */
libxsmm_blasint j, ik, ikb, in, ic, icb, inik, BF, CB, CB_BLOCKS, KB_BLOCKS, ikic, jk, jc;
/* input sizes */
const libxsmm_blasint K =  handle->desc.K;
const libxsmm_blasint N =  handle->desc.N;
const libxsmm_blasint C =  handle->desc.C;
const libxsmm_blasint t =  handle->T;
const libxsmm_blasint bk = handle->bk;
const libxsmm_blasint bn = handle->bn;
const libxsmm_blasint bc = handle->bc;
const libxsmm_blasint K3 = K * 3;
const libxsmm_blasint cBlocks = C/bc;
const libxsmm_blasint kBlocks = K/bk;
unsigned long long blocks;

/* define tensors */
element_input_type  *xt  = (element_input_type* )handle->xt->data;
element_input_type  *hpD = (element_input_type* )handle->hp->data;
element_filter_type *w   = (element_filter_type*)handle->w->data;
element_filter_type *r   = (element_filter_type*)handle->r->data;
element_filter_type *w_scratch   = (element_filter_type*)handle->scratch_w;
element_filter_type *r_scratch   = (element_filter_type*)handle->scratch_r;
element_output_type *b   = (element_output_type*)handle->b->data;
element_output_type *ht  = (element_output_type*)handle->ht->data;
element_output_type *it  = (element_output_type*)handle->it->data;
element_output_type *ct  = (element_output_type*)handle->cit->data;
element_output_type *ft  = (element_output_type*)handle->ft->data;
element_output_type *ot  = (element_output_type*)handle->ot->data;
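/* w and r store the i/c/f gate weights interleaved along the K dimension (leading dimension 3*K),
   so the per-gate base pointers below sit at column offsets 0, K and 2*K */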
element_filter_type *wiD = &(w[0]);
element_filter_type *wcD = &(w[K]);
element_filter_type *wfD = &(w[2*K]);
element_filter_type *riD = &(r[0]);
element_filter_type *rcD = &(r[K]);
element_filter_type *rfD = &(r[2*K]);
element_filter_type *wiD_scratch = &(w_scratch[0]);
element_filter_type *wcD_scratch = &(w_scratch[C*K]);
element_filter_type *wfD_scratch = &(w_scratch[2*C*K]);
element_filter_type *riD_scratch = &(r_scratch[0]);
element_filter_type *rcD_scratch = &(r_scratch[K*K]);
element_filter_type *rfD_scratch = &(r_scratch[2*K*K]);
element_output_type *bi  = &(b[0]);
element_output_type *bd  = &(b[K]);
element_output_type *bf  = &(b[2*K]);
LIBXSMM_VLA_DECL(3, element_input_type,  x, xt, N, C);
LIBXSMM_VLA_DECL(2, element_input_type,  hp, hpD, K);
LIBXSMM_VLA_DECL(4, element_filter_type, wi, wiD_scratch, cBlocks, bc, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, wc, wcD_scratch, cBlocks, bc, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, wf, wfD_scratch, cBlocks, bc, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, ri, riD_scratch, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, rc, rcD_scratch, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, rf, rfD_scratch, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(2, element_filter_type, wi_ck, wiD, K3);
LIBXSMM_VLA_DECL(2, element_filter_type, wc_ck, wcD, K3);
LIBXSMM_VLA_DECL(2, element_filter_type, wf_ck, wfD, K3);
LIBXSMM_VLA_DECL(2, element_filter_type, ri_ck, riD, K3);
LIBXSMM_VLA_DECL(2, element_filter_type, rc_ck, rcD, K3);
LIBXSMM_VLA_DECL(2, element_filter_type, rf_ck, rfD, K3);
LIBXSMM_VLA_DECL(3, element_output_type, h, ht, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, i, it, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, c, ct, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, f, ft, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, o, ot, N, K);
/* define batch-reduce gemm kernels */
const libxsmm_smmfunction_reducebatch_addr batchreduce_kernela = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bc, &bk, &C, &K, NULL, NULL, NULL, NULL );
const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelb = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, NULL, NULL );
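/* the address-variant batch-reduce kernels accumulate the output tile with sum_i A[i] * B[i],
   where A/B are arrays of block pointers and the block count is passed at call time;
   kernela reduces over bc-wide input blocks, kernelb over bk-wide hidden-state blocks */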
/* Auxiliary arrays for batch-reduce gemms */
const element_filter_type *A_array[1024];
const element_input_type  *B_array[1024];
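/* NB: the 1024 pointer slots bound the number of reduction blocks per batch-reduce call,
   i.e. CB_BLOCKS and KB_BLOCKS must not exceed 1024 */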

/* thread-parallel work partitioning */
/* compute the logical thread id within the team */
const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread;
/* number of tasks that could be run in parallel */
const libxsmm_blasint work = (N/bn) * (K/bk);
/* compute chunk size */
const libxsmm_blasint chunksize = (work % (libxsmm_blasint)handle->desc.threads == 0) ? (work / (libxsmm_blasint)handle->desc.threads) : ((work / (libxsmm_blasint)handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
const libxsmm_blasint thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work;
const libxsmm_blasint thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work;

/* number of tasks that could be run in parallel for C and K blocks */
const libxsmm_blasint work_ck = (C/bc) * (K/bk);
/* compute chunk size */
const libxsmm_blasint chunksize_ck = (work_ck % (libxsmm_blasint)handle->desc.threads == 0) ? (work_ck / (libxsmm_blasint)handle->desc.threads) : ((work_ck / (libxsmm_blasint)handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
const libxsmm_blasint thr_begin_ck = (ltid * chunksize_ck < work_ck) ? (ltid * chunksize_ck) : work_ck;
const libxsmm_blasint thr_end_ck = ((ltid + 1) * chunksize_ck < work_ck) ? ((ltid + 1) * chunksize_ck) : work_ck;

/* number of tasks that could be run in parallel for K and K blocks */
const libxsmm_blasint work_kk = (K/bk) * (K/bk);
/* compute chunk size */
const libxsmm_blasint chunksize_kk = (work_kk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_kk / (libxsmm_blasint)handle->desc.threads) : ((work_kk / (libxsmm_blasint)handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
const libxsmm_blasint thr_begin_kk = (ltid * chunksize_kk < work_kk) ? (ltid * chunksize_kk) : work_kk;
const libxsmm_blasint thr_end_kk = ((ltid + 1) * chunksize_kk < work_kk) ? ((ltid + 1) * chunksize_kk) : work_kk;
#if 0
const int use_fused_implementation = (C == 2048 && K == 2048) ? 1 : 0;
#endif
/* lazy barrier init */
libxsmm_barrier_init(handle->barrier, (int)ltid);

/* Block the reduction domain (over C and K) when it is large: BF is lowered until it
   evenly divides both cBlocks and kBlocks */
BF = 1;
if ((C > 1024 && C <= 2048) || (K > 1024 && K <= 2048)) {
  BF = 8;
  while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) {
    BF--;
  }
}
if (C > 2048 || K > 2048) {
  BF = 16;
  while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) {
    BF--;
  }
}

if (C == 2048 && K == 1024) {
  BF = 2;
}

CB_BLOCKS = cBlocks/BF;
KB_BLOCKS = kBlocks/BF;
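/* with blocking factor BF, each output tile is accumulated over BF chunks of CB_BLOCKS
   (resp. KB_BLOCKS) reduction blocks; the CB loops below iterate over these chunks and
   keep accumulating into the same i/c/f tiles */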

/* Upfront reformatting of W and R */
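/* the [C x 3K] / [K x 3K] ck-format weights are repacked into blocked layouts,
   wi/wc/wf: [kBlocks][cBlocks][bc][bk] and ri/rc/rf: [kBlocks][kBlocks][bk][bk],
   so that every (ikb, icb) panel is contiguous and can serve directly as an A operand
   of the batch-reduce GEMMs */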
/* reformat W */
for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) {
  ic = (ikic / (K/bk));
  ik = (ikic % (K/bk));
  for (jk = 0; jk < bk; ++jk) {
    for (jc = 0; jc < bc; ++jc) {
      LIBXSMM_VLA_ACCESS(4, wi, ik, ic, jc, jk, cBlocks, bc, bk) =  LIBXSMM_VLA_ACCESS(2, wi_ck, ic*bc+jc, ik*bk+jk, 3*K);
      LIBXSMM_VLA_ACCESS(4, wc, ik, ic, jc, jk, cBlocks, bc, bk) =  LIBXSMM_VLA_ACCESS(2, wc_ck, ic*bc+jc, ik*bk+jk, 3*K);
      LIBXSMM_VLA_ACCESS(4, wf, ik, ic, jc, jk, cBlocks, bc, bk) =  LIBXSMM_VLA_ACCESS(2, wf_ck, ic*bc+jc, ik*bk+jk, 3*K);
    }
  }
}

/* reformat R */
for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) {
  ik = (ikic / (K/bk));
  ic = (ikic % (K/bk));
  for (jk = 0; jk < bk; ++jk) {
    for (jc = 0; jc < bk; ++jc) {
      LIBXSMM_VLA_ACCESS(4, ri, ik, ic, jc, jk, kBlocks, bk, bk) =  LIBXSMM_VLA_ACCESS(2, ri_ck, ic*bk+jc, ik*bk+jk, 3*K);
      LIBXSMM_VLA_ACCESS(4, rc, ik, ic, jc, jk, kBlocks, bk, bk) =  LIBXSMM_VLA_ACCESS(2, rc_ck, ic*bk+jc, ik*bk+jk, 3*K);
      LIBXSMM_VLA_ACCESS(4, rf, ik, ic, jc, jk, kBlocks, bk, bk) =  LIBXSMM_VLA_ACCESS(2, rf_ck, ic*bk+jc, ik*bk+jk, 3*K);
    }
  }
}

libxsmm_barrier_wait(handle->barrier, (int)ltid);

/* lazy barrier init */
libxsmm_barrier_init(handle->barrier, (int)ltid);

/* All data is in column-major format */
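/* per time step the cell below computes (h_prev is hp for j == 0 and h(j-1) otherwise,
   "." denotes the element-wise product):
     i = sigmoid( W_i x + R_i h_prev + b_i )
     c = sigmoid( W_c x + R_c h_prev + b_d )
     o = h_prev . i
     f = tanh( W_f x + R_f o + b_f )
     h = (1 - c) . f + c . h_prev
*/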
for (j = 0; j < t; ++j) {
  /* let's run the cell in blocks for good locality */
  /* Block reduction loop if requested */
  for (CB = 0; CB < BF; CB++) {
    for (inik = thr_begin; inik < thr_end; ++inik ) {
      in = (inik % (N/bn))*bn;
      ikb = inik / (N/bn);
      ik = ikb*bk;
      /* initialize i with bi */
      if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &bi[ik] );
      /* i += W.x */
      for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) {
        A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wi, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk);
        B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C);
      }
      /* batch-reduce GEMM call */
      blocks = CB_BLOCKS;
      batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocks);
      /* i += R.hp */
      if (0 == j) {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ri, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K);
        }
      } else {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ri, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K);
        }
      }
      /* batch-reduce GEMM call */
      blocks = KB_BLOCKS;
      batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocks);
      /* initialize c with bd */
      if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &bd[ik] );
      /* c += W.x */
      for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) {
        A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wc, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk);
        B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C);
      }
      /* batch-reduce GEMM call */
      blocks = CB_BLOCKS;
      batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &blocks);
      /* c += R.hp */
      if (0 == j) {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rc, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K);
        }
      } else {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rc, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K);
        }
      }
      /* batch-reduce GEMM call */
      blocks = KB_BLOCKS;
      batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &blocks);

      if (CB == BF-1) {
        /* i = sigmoid(i) */
        libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) );
        /* o = hp . i */
        if (0 == j) {
          libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, hp, in, ik, K),        &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) );
        } else {
          libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) );
        }
      }
    }
  }
  /* we need a barrier here to ensure all elements of o are computed before f can be computed */
  libxsmm_barrier_wait(handle->barrier, (int)ltid);
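  /* second pass over the blocks: f consumes the full o computed above, then the new state h is formed */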
  for (CB = 0; CB < BF; CB++) {
    for (inik = thr_begin; inik < thr_end; ++inik ) {
      in = (inik % (N/bn))*bn;
      ikb = inik / (N/bn);
      ik = ikb*bk;
      /* initialize f with bf */
      if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &bf[ik] );
      /* f += W.x */
      for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) {
        A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wf, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk);
        B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C);
      }
      /* batch-reduce GEMM call */
      blocks = CB_BLOCKS;
      batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocks);
      /* f += R.o */
      for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
        A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rf, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
        B_array[icb] = &LIBXSMM_VLA_ACCESS(3, o, j, in, ic + CB*KB_BLOCKS*bk, N, K);
      }
      /* batch-reduce GEMM call */
      blocks = KB_BLOCKS;
      batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocks);

      if (CB == BF-1) {
        /* f = tanh(f) */
        libxsmm_internal_matrix_tanh_ld         ( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) );
        /* c = sigmoid(c) */
        libxsmm_internal_matrix_sigmoid_ld      ( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K) );
        /* h = (1 - c) . f */
        libxsmm_internal_matrix_complement_ld   ( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) );
        libxsmm_internal_matrix_eltwise_mult_ld ( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) );
        /* h += c . hp */
        if (0 == j) {
          libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, hp, in, ik, K),        &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) );
        } else {
          libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) );
        }
      }
    }
  }
  libxsmm_barrier_wait(handle->barrier, (int)ltid);
}
286