/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Evangelos Georganas, Kunal Banerjee (Intel Corp.)
******************************************************************************/
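/* Set the #if below to 1 to define PROFILE and enable the rdtsc-based timing
   output emitted at the end of this template. */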
#if 0
#define PROFILE
#endif

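/* Convert an (m x n) bf16 matrix _src with leading dimension ld into the fp32
   matrix _dst; 16 elements are processed per vector iteration, so m is
   assumed to be a multiple of 16. */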
#define MATRIX_CVT_BF16_FP32_LD(m, n, ld, _src, _dst) \
do { \
  libxsmm_bfloat16 *src = _src; \
  float *dst = _dst; \
  libxsmm_blasint __i,__j; \
  for ( __j = 0; __j < n; ++__j ) { \
    for ( __i = 0; __i < m; __i+=16 ) { \
      _mm512_storeu_ps((float*)&dst[(__j*ld)+__i], LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&src[(__j*ld)+__i]))); \
    } \
  } \
} while (0)

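/* Broadcast the bf16 column vector _colv of length m across all n columns of
   the fp32 matrix _srcdst (leading dimension ld), upconverting on the fly;
   m is again assumed to be a multiple of 16. */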
#define MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD(m, n, ld, _srcdst, _colv) \
do { \
  libxsmm_bfloat16 *colv = _colv; \
  float *srcdst = _srcdst; \
  libxsmm_blasint __i,__j; \
  for ( __j = 0; __j < n; ++__j ) { \
    for ( __i = 0; __i < m; __i+=16 ) { \
      _mm512_storeu_ps((float*)&srcdst[(__j*ld)+__i], LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&colv[__i]))); \
    } \
  } \
} while (0)

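/* Same as the broadcast macro above, but additionally adds the scalar
   const_bias to every element (used e.g. to seed a gate with its bias plus a
   constant). */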
#define MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_CONST_LD(m, n, ld, _srcdst, _colv, const_bias) \
do { \
  libxsmm_bfloat16 *colv = _colv; \
  float *srcdst = _srcdst; \
  libxsmm_blasint __i,__j; \
  __m512 vbias = _mm512_set1_ps(const_bias); \
  for ( __j = 0; __j < n; ++__j ) { \
    for ( __i = 0; __i < m; __i+=16 ) { \
      _mm512_storeu_ps((float*)&srcdst[(__j*ld)+__i], _mm512_add_ps(vbias, LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&colv[__i])))); \
    } \
  } \
} while (0)
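
/* Illustrative use of the macros above (a sketch only; the real call sites
 * live in the fused/diffused templates included further down, and the VLA
 * names mirror the declarations below): upconvert one bn x bk block of the
 * previous cell state, and seed a forget-gate block with its bias column plus
 * a constant of 1.0f (hypothetical value):
 *
 *   MATRIX_CVT_BF16_FP32_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, cp_bf16, in, ik, K),
 *                            &LIBXSMM_VLA_ACCESS(2, cp, in, ik, K) );
 *   MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_CONST_LD( bk, bn, K,
 *                            &LIBXSMM_VLA_ACCESS(3, f, 0, in, ik, N, K), &bf[ik], 1.0f );
 */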

/* helper variables */
libxsmm_blasint j, ik, ikb, in, /*ic, icb,*/ inik, BF, CB, CB_BLOCKS, KB_BLOCKS;
/* input sizes */
const libxsmm_blasint K =  handle->desc.K;
const libxsmm_blasint N =  handle->desc.N;
const libxsmm_blasint C =  handle->desc.C;
const libxsmm_blasint t =  handle->T;
const libxsmm_blasint bk = handle->bk;
const libxsmm_blasint bn = handle->bn;
const libxsmm_blasint bc = handle->bc;
const libxsmm_blasint cBlocks = C/bc;
const libxsmm_blasint kBlocks = K/bk;
/* low-precision blocking: two bf16 values are packed along the reduction dimension (VNNI-style layout) */
const int lpb = 2;
const int bc_lp = bc/lpb;
const int bk_lp = bk/lpb;
unsigned long long blocks, blocksa, blocksb;

/* define tensors */
element_input_type  *xt  = (element_input_type* )handle->xt->data;
element_input_type  *csp = (element_input_type* )handle->csp->data;
element_input_type  *hpD = (element_input_type* )handle->hp->data;
element_filter_type *w   = (element_filter_type*)handle->w->data;
element_filter_type *r   = (element_filter_type*)handle->r->data;
element_output_type *b   = (element_output_type*)handle->b->data;

/* These buffers are scratch for the fp32 outputs of the GEMMs (intermediate results) */
float *cst = (float*)handle->cst_scratch;
float *ht  = (float*)handle->ht_scratch;
float *it  = (float*)handle->it_scratch;
float *ft  = (float*)handle->ft_scratch;
float *ot  = (float*)handle->ot_scratch;
float *cit = (float*)handle->cit_scratch;
float *cot = (float*)handle->cot_scratch;
/* The previous cell state also has to be upconverted to fp32 since it is used in the elementwise functions */
float *csp_f32 = (float*)handle->csp_scratch;
/* These are the bf16 output data */
element_output_type *cst_bf16 = (element_output_type*)handle->cst->data;
element_output_type *ht_bf16  = (element_output_type*)handle->ht->data;
element_output_type *it_bf16  = (element_output_type*)handle->it->data;
element_output_type *ft_bf16  = (element_output_type*)handle->ft->data;
element_output_type *ot_bf16  = (element_output_type*)handle->ot->data;
element_output_type *cit_bf16 = (element_output_type*)handle->cit->data;
element_output_type *cot_bf16 = (element_output_type*)handle->cot->data;

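/* Per-gate views into the concatenated weight (w), recurrent (r), and bias (b)
   tensors; the gate order in memory is i, c, f, o */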
element_filter_type *wiD = &(w[0]);
element_filter_type *wcD = &(w[C*K]);
element_filter_type *wfD = &(w[2*C*K]);
element_filter_type *woD = &(w[3*C*K]);
element_filter_type *riD = &(r[0]);
element_filter_type *rcD = &(r[K*K]);
element_filter_type *rfD = &(r[2*K*K]);
element_filter_type *roD = &(r[3*K*K]);
element_output_type *bi  = &(b[0]);
element_output_type *bd  = &(b[K]);   /* cell-input (candidate) bias; named bd to avoid a clash with the blocking factor bc */
element_output_type *bf  = &(b[2*K]);
element_output_type *bo  = &(b[3*K]);
LIBXSMM_VLA_DECL(2, float,  cp, csp_f32, K);
LIBXSMM_VLA_DECL(2, element_input_type,  cp_bf16, csp, K);
LIBXSMM_VLA_DECL(3, element_input_type,  x, xt, N, C);
LIBXSMM_VLA_DECL(2, element_input_type,  hp, hpD, K);
LIBXSMM_VLA_DECL(5, element_filter_type, wi, wiD, cBlocks, bc_lp, bk, lpb);
LIBXSMM_VLA_DECL(5, element_filter_type, wf, wfD, cBlocks, bc_lp, bk, lpb);
LIBXSMM_VLA_DECL(5, element_filter_type, wo, woD, cBlocks, bc_lp, bk, lpb);
LIBXSMM_VLA_DECL(5, element_filter_type, wc, wcD, cBlocks, bc_lp, bk, lpb);
LIBXSMM_VLA_DECL(5, element_filter_type, ri, riD, kBlocks, bk_lp, bk, lpb);
LIBXSMM_VLA_DECL(5, element_filter_type, rf, rfD, kBlocks, bk_lp, bk, lpb);
LIBXSMM_VLA_DECL(5, element_filter_type, ro, roD, kBlocks, bk_lp, bk, lpb);
LIBXSMM_VLA_DECL(5, element_filter_type, rc, rcD, kBlocks, bk_lp, bk, lpb);
LIBXSMM_VLA_DECL(3, float, cs, cst, N, K);
LIBXSMM_VLA_DECL(3, float, h, ht, N, K);
LIBXSMM_VLA_DECL(3, float, i, it, N, K);
LIBXSMM_VLA_DECL(3, float, f, ft, N, K);
LIBXSMM_VLA_DECL(3, float, o, ot, N, K);
LIBXSMM_VLA_DECL(3, float, ci, cit, N, K);
LIBXSMM_VLA_DECL(3, float, co, cot, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, cs_out, cst_bf16, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, h_out, ht_bf16, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, i_out, it_bf16, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, f_out, ft_bf16, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, o_out, ot_bf16, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, ci_out, cit_bf16, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, co_out, cot_bf16, N, K);
/* define batch-reduce GEMM kernels for the input (W*x) and recurrent (R*h) GEMMs */
const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernela = handle->fwd_kernela;
const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernelb = handle->fwd_kernelb;

float *cps_ptr = NULL;

/* parallelize over (N, K) blocks */
/* computing first logical thread */
const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread;
/* number of tasks that could be run in parallel */
const libxsmm_blasint work = (N/bn) * (K/bk);
/* compute chunk size */
const libxsmm_blasint chunksize = (work % (libxsmm_blasint)handle->desc.threads == 0) ? (work / (libxsmm_blasint)handle->desc.threads) : ((work / (libxsmm_blasint)handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
const libxsmm_blasint thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work;
const libxsmm_blasint thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work;
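/* e.g. (hypothetical sizes) N = 168, bn = 24, K = 1024, bk = 64: work = 7*16 = 112
   blocks; with 28 threads chunksize = 4, so thread 0 covers [0,4) and
   thread 27 covers [108,112) */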

#if 0
/* number of tasks that could be run in parallel for C and K blocks */
const libxsmm_blasint work_ck = (C/bc) * (K/bk);
/* compute chunk size */
const libxsmm_blasint chunksize_ck = (work_ck % (libxsmm_blasint)handle->desc.threads == 0) ? (work_ck / (libxsmm_blasint)handle->desc.threads) : ((work_ck / (libxsmm_blasint)handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
const libxsmm_blasint thr_begin_ck = (ltid * chunksize_ck < work_ck) ? (ltid * chunksize_ck) : work_ck;
const libxsmm_blasint thr_end_ck = ((ltid + 1) * chunksize_ck < work_ck) ? ((ltid + 1) * chunksize_ck) : work_ck;
/* number of tasks that could be run in parallel for K and K blocks */
const libxsmm_blasint work_kk = (K/bk) * (K/bk);
/* compute chunk size */
const libxsmm_blasint chunksize_kk = (work_kk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_kk / (libxsmm_blasint)handle->desc.threads) : ((work_kk / (libxsmm_blasint)handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
const libxsmm_blasint thr_begin_kk = (ltid * chunksize_kk < work_kk) ? (ltid * chunksize_kk) : work_kk;
const libxsmm_blasint thr_end_kk = ((ltid + 1) * chunksize_kk < work_kk) ? ((ltid + 1) * chunksize_kk) : work_kk;
#endif
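/* Heuristic switch: the fused variant (see the #include below) is selected
   only for this specific problem size */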
const int use_fused_implementation = (C == 2048 && K == 2048) ? 1 : 0;

#ifdef PROFILE
__int64_t eltwise_start, eltwise_end, eltwise_cycles = 0, gemm_start, gemm_end, gemm_cycles = 0, gemm_cycles2 = 0, reformat_start, reformat_end, reformat_cycles = 0;
float total_time = 0.0;
#endif

/* lazy barrier init */
libxsmm_barrier_init(handle->barrier, (int)ltid);

/* Block the reduction domain if it is too large: start from BF = 8 (resp. 16)
   and decrement until BF divides both cBlocks and kBlocks */
BF = 1;
if ((C > 1024 && C <= 2048) || (K > 1024 && K <= 2048)) {
  BF = 8;
  while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) {
    BF--;
  }
}
if (C > 2048 || K > 2048) {
  BF = 16;
  while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) {
    BF--;
  }
}

if (C == 2048 && K == 1024) {
  BF = 2;
}

CB_BLOCKS = cBlocks/BF;
KB_BLOCKS = kBlocks/BF;
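/* e.g. (hypothetical sizes) C = K = 2048 with bc = bk = 64: cBlocks = kBlocks = 32,
   BF = 8 divides both, so each batch-reduce GEMM call reduces over
   CB_BLOCKS = KB_BLOCKS = 4 blocks */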

#ifdef PROFILE
if (ltid == 0) reformat_start = _rdtsc();
#endif

/* Upconvert the cp input (previous cell state) to fp32, which is used by the elementwise operations */
for (inik = thr_begin; inik < thr_end; ++inik ) {
  in = (inik % (N/bn))*bn;  /* row offset of the N-block */
  ikb = inik / (N/bn);      /* K-block index */
  ik = ikb*bk;              /* column offset of the K-block */
  MATRIX_CVT_BF16_FP32_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, cp_bf16, in, ik, K), &LIBXSMM_VLA_ACCESS(2, cp, in, ik, K));
}

libxsmm_barrier_wait(handle->barrier, (int)ltid);
#ifdef PROFILE
if (ltid == 0) {
  reformat_end = _rdtsc();
  reformat_cycles = reformat_end - reformat_start;
}
#endif

if (use_fused_implementation) {
#include "libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_fused_bf16.tpl.c"
} else {
#include "libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_diffused_bf16.tpl.c"
}

#ifdef PROFILE
if (ltid == 0) {
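  /* NOTE: the timings below assume a fixed 2.5 GHz clock (the 2.5 * 1e9
     factor); the GFLOPS numbers count 2*N*C*K (resp. 2*N*K*K) flops per gate
     GEMM, 4 gates, over t time steps */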
  printf("----- PROFILING LSTM FWD (N = %d, C = %d, K = %d, bn = %d, bc = %d, bk = %d)----\n", N, C, K, bn, bc, bk );
  total_time = (gemm_cycles+gemm_cycles2+eltwise_cycles+reformat_cycles)/(2.5 * 1e9)*1000.0f;
  printf("Elementwise time is %f ms (%.2f%%)\n", eltwise_cycles/(2.5 * 1e9)*1000.0f, eltwise_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time );
  printf("Reformat weights time is %f ms (%.2f%%)\n", reformat_cycles/(2.5 * 1e9)*1000.0f, reformat_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time );
  printf("GEMM W*x  time is %f ms (%.2f%%) at %f GFLOPS\n", gemm_cycles/(2.5 * 1e9)*1000.0f, gemm_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*(N*C*K*2.0)*4.0/1e9/(gemm_cycles/(2.5 * 1e9)));
  printf("GEMM R*h  time is %f ms (%.2f%%) at %f GFLOPS\n\n", gemm_cycles2/(2.5 * 1e9)*1000.0f, gemm_cycles2/(2.5 * 1e9)*1000.0f*100.0/total_time, t*(N*K*K*2.0)*4.0/1e9/(gemm_cycles2/(2.5 * 1e9)));
}
#undef PROFILE
#endif

#undef MATRIX_CVT_BF16_FP32_LD
#undef MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD
#undef MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_CONST_LD