/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Evangelos Georganas (Intel Corp.)
******************************************************************************/

/* All data is in column-major format */
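/* For each time step j, the blocked loops below evaluate the LSTM forward recurrence
 * (wx/rx are the input/recurrent weights, bx the biases, hp/cp the initial states):
 *   i     = sigmoid( wi.x[j] + ri.h[j-1] + bi )
 *   ci    = tanh(    wc.x[j] + rc.h[j-1] + bd )
 *   f     = sigmoid( wf.x[j] + rf.h[j-1] + bf + forget_bias )
 *   o     = sigmoid( wo.x[j] + ro.h[j-1] + bo )
 *   cs[j] = f * cs[j-1] + i * ci
 *   co    = tanh( cs[j] )
 *   h[j]  = o * co
 * where h[-1] = hp and cs[-1] = cp. */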
for (j = 0; j < t; ++j) {
  /* let's run the cell in blocks for good locality */
  /* Block reduction loop if requested */
  for (CB = 0; CB < BF; CB++) {
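    /* The CB loop tiles the W.x / R.h reductions into BF chunks of CB_BLOCKS (input) and
     * KB_BLOCKS (recurrent) blocks that are accumulated into the gate buffers: the bias
     * broadcast happens on the first chunk (CB == 0) and the element-wise stage on the
     * last one (CB == BF-1). Each inik work item covers one bn x bk tile, with in the
     * minibatch (row) offset and ik = ikb*bk the output (column) offset. */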
    for (inik = thr_begin; inik < thr_end; ++inik ) {
      in = (inik % (N/bn))*bn;
      ikb = inik / (N/bn);
      ik = ikb*bk;
#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* initialize i with bi */
      if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &bi[ik] );
      /* i += W.x */
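      /* A_array/B_array gather pointers to the CB_BLOCKS weight and input blocks of this
       * reduction chunk; one batch-reduce kernel call then accumulates the sum of all
       * block products into the destination tile (the same pattern is used for every gate). */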
      for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) {
        A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wi, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk);
        B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C);
      }
      /* Batch-reduce GEMM call */
      blocks = CB_BLOCKS;
      batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles += gemm_end-gemm_start;
      }
#endif
#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* i += R.h */
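      /* At the first time step the recurrent term reads the initial hidden state hp,
       * afterwards it reads h from time step j-1 (same for the ci, f and o gates below). */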
      if (0 == j) {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ri, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K);
        }
      } else {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ri, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K);
        }
      }
      /* Batch-reduce GEMM call */
      blocks = KB_BLOCKS;
      batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles2 += gemm_end-gemm_start;
      }
#endif
#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* initialize ci with bd */
      if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &bd[ik] );
      /* ci += W.x */
      for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) {
        A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wc, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk);
        B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C);
      }
      /* Batch-reduce GEMM call */
      blocks = CB_BLOCKS;
      batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles += gemm_end-gemm_start;
      }
#endif
#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* ci += R.h */
      if (0 == j) {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rc, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K);
        }
      } else {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rc, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K);
        }
      }
      /* Batch-reduce GEMM call */
      blocks = KB_BLOCKS;
      batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles2 += gemm_end-gemm_start;
      }
#endif
#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* initialize f with (bf + forget_bias) */
      if (CB == 0) libxsmm_internal_matrix_bcst_colvector_const_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &bf[ik], handle->forget_bias );
      /* f += W.x */
      for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) {
        A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wf, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk);
        B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C);
      }
      /* Batch-reduce GEMM call */
      blocks = CB_BLOCKS;
      batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles += gemm_end-gemm_start;
      }
#endif
#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* f += R.h */
      if (0 == j) {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rf, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K);
        }
      } else {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rf, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K);
        }
      }
      /* Batch-reduce GEMM call */
      blocks = KB_BLOCKS;
      batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles2 += gemm_end-gemm_start;
      }
#endif
#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* initialize o with bo */
      if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &bo[ik] );
      /* o += W.x */
      for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) {
        A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wo, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk);
        B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C);
      }
      /* Batch-reduce GEMM call */
      blocks = CB_BLOCKS;
      batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles += gemm_end-gemm_start;
      }
#endif
#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* o += R.h */
      if (0 == j) {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ro, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K);
        }
      } else {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ro, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K);
        }
      }
      /* Batch-reduce GEMM call */
      blocks = KB_BLOCKS;
      batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles2 += gemm_end-gemm_start;
      }
#endif

      if (CB == BF-1) {
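        /* All BF reduction chunks have been accumulated into i, ci, f and o at this point,
         * so the gate activations and the cs/co/h updates can be applied. */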
#ifdef PROFILE
        if (ltid == 0) {
          eltwise_start = _rdtsc();
        }
#endif
        /* previous cell state: the initial state cp at the first time step, cs of step j-1 otherwise */
        cps_ptr = (j == 0) ? &LIBXSMM_VLA_ACCESS(2, cp, in, ik, K) : &LIBXSMM_VLA_ACCESS(3, cs, j-1, in, ik, N, K);
        /* Compute i, ci, f, o, cs, co and h */
#if defined(LIBXSMM_RNN_CELL_AVX512)
        /* the fused element-wise template is only used when bk and bc are multiples of 16 */
        if (bk % 16 == 0 && bc % 16 == 0) {
#include "libxsmm_internal_lstm_fwd_fused_eltwise.tpl.c"
        } else {
          libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) );
          libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) );
          libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) );
          libxsmm_internal_matrix_tanh_ld(    bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) );
          libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) );
          libxsmm_internal_matrix_eltwise_fma_ld(  bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) );
          libxsmm_internal_matrix_tanh_ld(         bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) );
          libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K),  &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) );
        }
#else
        libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) );
        libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) );
        libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) );
        libxsmm_internal_matrix_tanh_ld(    bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) );
        libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) );
        libxsmm_internal_matrix_eltwise_fma_ld(  bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) );
        libxsmm_internal_matrix_tanh_ld(         bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) );
        libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K),  &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) );
#endif

#ifdef PROFILE
        if (ltid == 0) {
          eltwise_end = _rdtsc();
          eltwise_cycles += eltwise_end-eltwise_start;
        }
#endif
      }
    }
  }
  libxsmm_barrier_wait(handle->barrier, (int)ltid);
}