/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Evangelos Georganas (Intel Corp.)
******************************************************************************/
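/* This template computes the forward pass of an LSTM cell with batch-reduce GEMMs.
 * For reference, the per-element math implemented below is:
 *   i  = sigmoid(wi*x + ri*h_prev + bi)
 *   ci = tanh   (wc*x + rc*h_prev + bd)
 *   f  = sigmoid(wf*x + rf*h_prev + bf + forget_bias)
 *   o  = sigmoid(wo*x + ro*h_prev + bo)
 *   cs = f * cs_prev + i * ci
 *   co = tanh(cs)
 *   h  = o * co
 * The GEMM work is split into two passes: first the input contributions W*x for all
 * timesteps, then the recurrent contributions R*h together with the element-wise
 * operations, which have to proceed timestep by timestep. */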

/* First pass: initialize the gates with their biases and add the W*x contribution */
for (j = 0; j < t; ++j) {
  /* Run the cell in blocks for good locality */
  /* Loop over the reduction blocks if blocking is requested */
  for (CB = 0; CB < BF; CB++) {
    for (inik = thr_begin; inik < thr_end; ++inik ) {
      in = (inik % (N/bn))*bn;
      ikb = inik / (N/bn);
      ik = ikb*bk;
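      /* in and ik (derived from the work item inik above) locate the bn x bk tile of
       * the gate matrices this thread works on: in is the row offset within the
       * minibatch N, ik the column offset within the hidden-state dimension K. */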
#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* initialize i with bi */
      if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &bi[ik] );
      /* i += W.x */
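      /* The reduction dimension C is split into BF chunks of CB_BLOCKS blocks of size
       * bc; A_array/B_array gather the weight and input blocks of chunk CB and a
       * single batch-reduce GEMM accumulates all of them into the bn x bk tile of the
       * gate. The same pattern is repeated for ci, f and o below. */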
      for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) {
        A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wi, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk);
        B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C);
      }
      /* Batch-reduce GEMM call */
      blocks = CB_BLOCKS;
      batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles += gemm_end-gemm_start;
      }
#endif

#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* initialize ci with bd */
      if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &bd[ik] );
      /* ci += W.x */
      for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) {
        A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wc, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk);
        B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C);
      }
      /* Batch-reduce GEMM call */
      blocks = CB_BLOCKS;
      batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles += gemm_end-gemm_start;
      }
#endif

#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* initialize f with (bf + forget_bias) */
      if (CB == 0) libxsmm_internal_matrix_bcst_colvector_const_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &bf[ik], handle->forget_bias );
      /* f += W.x */
      for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) {
        A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wf, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk);
        B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C);
      }
      /* Batch-reduce GEMM call */
      blocks = CB_BLOCKS;
      batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles += gemm_end-gemm_start;
      }
#endif

#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* initialize o with bo */
      if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &bo[ik] );
      /* o += W.x */
      for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) {
        A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wo, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk);
        B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C);
      }
      /* Batch-reduce GEMM call */
      blocks = CB_BLOCKS;
      batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles += gemm_end-gemm_start;
      }
#endif
    }
  }
}

/* Second pass: add the R*h contribution and apply the element-wise LSTM operations */
for (j = 0; j < t; ++j) {
  /* Run the cell in blocks for good locality */
  /* Loop over the reduction blocks if blocking is requested */
  for (CB = 0; CB < BF; CB++) {
    for (inik = thr_begin; inik < thr_end; ++inik ) {
      in = (inik % (N/bn))*bn;
      ikb = inik / (N/bn);
      ik = ikb*bk;
#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
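      /* For the first timestep (j == 0) the recurrent input is the initial hidden
       * state hp; afterwards it is h from timestep j-1. As in the W*x pass, the
       * recurrent reduction dimension K is split into BF chunks of KB_BLOCKS blocks
       * of size bk. */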
      /* i += R.h */
      if (0 == j) {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ri, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K);
        }
      } else {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ri, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K);
        }
      }
      /* Batch-reduce GEMM call */
      blocks = KB_BLOCKS;
      batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles2 += gemm_end-gemm_start;
      }
#endif

#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* ci += R.h */
      if (0 == j) {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rc, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K);
        }
      } else {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rc, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K);
        }
      }
      /* Batch-reduce GEMM call */
      blocks = KB_BLOCKS;
      batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles2 += gemm_end-gemm_start;
      }
#endif

#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* f += R.h */
      if (0 == j) {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rf, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K);
        }
      } else {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rf, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K);
        }
      }
      /* Batch-reduce GEMM call */
      blocks = KB_BLOCKS;
      batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles2 += gemm_end-gemm_start;
      }
#endif

#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* o += R.h */
      if (0 == j) {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ro, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K);
        }
      } else {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ro, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K);
        }
      }
      /* Batch-reduce GEMM call */
      blocks = KB_BLOCKS;
      batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles2 += gemm_end-gemm_start;
      }
#endif

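      /* Once the last reduction block has been accumulated, the tiles of i, ci, f and
       * o hold the complete pre-activations (bias + W*x + R*h), so the element-wise
       * part of the cell can be applied. */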
      if (CB == BF-1) {
#ifdef PROFILE
        if (ltid == 0) {
          eltwise_start = _rdtsc();
        }
#endif
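        /* Previous cell state: the external cp for the first timestep, otherwise cs
         * from timestep j-1. */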
        cps_ptr = (j == 0) ? &LIBXSMM_VLA_ACCESS(2, cp, in, ik, K) : &LIBXSMM_VLA_ACCESS(3, cs, j-1, in, ik, N, K);
        /* Compute i, ci, f, o, cs, co and h */
#if defined(LIBXSMM_RNN_CELL_AVX512)
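        /* Fused AVX-512 element-wise kernel; it is used only when bk and bc are
         * multiples of 16, presumably so that tiles map onto full 16-wide fp32
         * vectors. Otherwise the generic helpers below are used. */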
        if (bk % 16 == 0 && bc % 16 == 0) {
#include "libxsmm_internal_lstm_fwd_fused_eltwise.tpl.c"
        } else {
          libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) );
          libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) );
          libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) );
          libxsmm_internal_matrix_tanh_ld(    bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) );
          libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) );
          libxsmm_internal_matrix_eltwise_fma_ld(  bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) );
          libxsmm_internal_matrix_tanh_ld(         bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) );
          libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K),  &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) );
        }
#else
        libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) );
        libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) );
        libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) );
        libxsmm_internal_matrix_tanh_ld(    bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) );
        libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) );
        libxsmm_internal_matrix_eltwise_fma_ld(  bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) );
        libxsmm_internal_matrix_tanh_ld(         bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) );
        libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K),  &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) );
#endif

#ifdef PROFILE
        if (ltid == 0) {
          eltwise_end = _rdtsc();
          eltwise_cycles += eltwise_end-eltwise_start;
        }
#endif
      }
    }
  }
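  /* Synchronize before advancing to the next timestep: the R*h GEMMs and the
   * element-wise updates of step j+1 read h and cs tiles of step j that may have
   * been produced by other threads. */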
  libxsmm_barrier_wait(handle->barrier, (int)ltid);
}