/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Evangelos Georganas, Kunal Banerjee (Intel Corp.)
******************************************************************************/
/* Change "#if 0" to "#if 1" to enable the cycle-counting/printf code guarded by PROFILE. */
#if 0
#define PROFILE
#endif

/* Converts n rows of m bf16 elements from _src into fp32 elements in _dst.
 * Consecutive rows are ld elements apart in both buffers.
 * Each row is handled in unmasked chunks of 16 elements (256-bit load,
 * 512-bit store), so a row length that is not a multiple of 16 touches
 * elements past m -- callers presumably guarantee m % 16 == 0 (TODO confirm).
 * All macro arguments are parenthesized so that expression arguments expand
 * with the intended precedence. */
#define MATRIX_CVT_BF16_FP32_LD(m, n, ld, _src, _dst) \
do { \
  libxsmm_bfloat16 *src = (_src); \
  float *dst = (_dst); \
  libxsmm_blasint __i,__j; \
  for ( __j = 0; __j < (n); ++__j ) { \
    for ( __i = 0; __i < (m); __i+=16 ) { \
      _mm512_storeu_ps((float*)&dst[(__j*(ld))+__i], LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&src[(__j*(ld))+__i]))); \
    } \
  } \
} while (0)

/* Writes the m-element bf16 vector _colv, converted to fp32, into each of the
 * n rows of _srcdst (rows are ld elements apart). The previous contents of
 * _srcdst are overwritten, not accumulated. Same 16-element chunking as
 * MATRIX_CVT_BF16_FP32_LD. */
#define MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD(m, n, ld, _srcdst, _colv) \
do { \
  libxsmm_bfloat16 *colv = (_colv); \
  float *srcdst = (_srcdst); \
  libxsmm_blasint __i,__j; \
  for ( __j = 0; __j < (n); ++__j ) { \
    for ( __i = 0; __i < (m); __i+=16 ) { \
      _mm512_storeu_ps((float*)&srcdst[(__j*(ld))+__i], LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&colv[__i]))); \
    } \
  } \
} while (0)

/* Same as MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD, but the scalar const_bias
 * is added to every converted element before it is stored. */
#define MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_CONST_LD(m, n, ld, _srcdst, _colv, const_bias) \
do { \
  libxsmm_bfloat16 *colv = (_colv); \
  float *srcdst = (_srcdst); \
  libxsmm_blasint __i,__j; \
  __m512 vbias = _mm512_set1_ps((const_bias)); \
  for ( __j = 0; __j < (n); ++__j ) { \
    for ( __i = 0; __i < (m); __i+=16 ) { \
      _mm512_storeu_ps((float*)&srcdst[(__j*(ld))+__i], _mm512_add_ps(vbias, LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(_mm256_loadu_si256((__m256i*)&colv[__i])))); \
    } \
  } \
} while (0)

/* helper variables */
53 libxsmm_blasint j, ik, ikb, in, /*ic, icb,*/ inik, BF, CB, CB_BLOCKS, KB_BLOCKS; 54 /* input sizes */ 55 const libxsmm_blasint K = handle->desc.K; 56 const libxsmm_blasint N = handle->desc.N; 57 const libxsmm_blasint C = handle->desc.C; 58 const libxsmm_blasint t = handle->T; 59 const libxsmm_blasint bk = handle->bk; 60 const libxsmm_blasint bn = handle->bn; 61 const libxsmm_blasint bc = handle->bc; 62 const libxsmm_blasint cBlocks = C/bc; 63 const libxsmm_blasint kBlocks = K/bk; 64 int lpb = 2; 65 const int bc_lp = bc/lpb; 66 const int bk_lp = bk/lpb; 67 unsigned long long blocks, blocksa, blocksb; 68 69 /* define tensors */ 70 element_input_type *xt = (element_input_type* )handle->xt->data; 71 element_input_type *csp = (element_input_type* )handle->csp->data; 72 element_input_type *hpD = (element_input_type* )handle->hp->data; 73 element_filter_type *w = (element_filter_type*)handle->w->data; 74 element_filter_type *r = (element_filter_type*)handle->r->data; 75 element_output_type *b = (element_output_type*)handle->b->data; 76 77 /* These buffers are scratch for fp32 output of gemms (intermmediate results) */ 78 float *cst = (float*)handle->cst_scratch; 79 float *ht = (float*)handle->ht_scratch; 80 float *it = (float*)handle->it_scratch; 81 float *ft = (float*)handle->ft_scratch; 82 float *ot = (float*)handle->ot_scratch; 83 float *cit = (float*)handle->cit_scratch; 84 float *cot = (float*)handle->cot_scratch; 85 /* This has to be also upconverted since it is used in the elementwise functions */ 86 float *csp_f32 = (float*)handle->csp_scratch; 87 /* These are the output bf16 data */ 88 element_output_type *cst_bf16 = (element_output_type*)handle->cst->data; 89 element_output_type *ht_bf16 = (element_output_type*)handle->ht->data; 90 element_output_type *it_bf16 = (element_output_type*)handle->it->data; 91 element_output_type *ft_bf16 = (element_output_type*)handle->ft->data; 92 element_output_type *ot_bf16 = (element_output_type*)handle->ot->data; 93 
element_output_type *cit_bf16 = (element_output_type*)handle->cit->data; 94 element_output_type *cot_bf16 = (element_output_type*)handle->cot->data; 95 96 element_filter_type *wiD = &(w[0]); 97 element_filter_type *wcD = &(w[C*K]); 98 element_filter_type *wfD = &(w[2*C*K]); 99 element_filter_type *woD = &(w[3*C*K]); 100 element_filter_type *riD = &(r[0]); 101 element_filter_type *rcD = &(r[K*K]); 102 element_filter_type *rfD = &(r[2*K*K]); 103 element_filter_type *roD = &(r[3*K*K]); 104 element_output_type *bi = &(b[0]); 105 element_output_type *bd = &(b[K]); 106 element_output_type *bf = &(b[2*K]); 107 element_output_type *bo = &(b[3*K]); 108 LIBXSMM_VLA_DECL(2, float, cp, csp_f32, K); 109 LIBXSMM_VLA_DECL(2, element_input_type, cp_bf16, csp, K); 110 LIBXSMM_VLA_DECL(3, element_input_type, x, xt, N, C); 111 LIBXSMM_VLA_DECL(2, element_input_type, hp, hpD, K); 112 LIBXSMM_VLA_DECL(5, element_filter_type, wi, wiD, cBlocks, bc_lp, bk, lpb); 113 LIBXSMM_VLA_DECL(5, element_filter_type, wf, wfD, cBlocks, bc_lp, bk, lpb); 114 LIBXSMM_VLA_DECL(5, element_filter_type, wo, woD, cBlocks, bc_lp, bk, lpb); 115 LIBXSMM_VLA_DECL(5, element_filter_type, wc, wcD, cBlocks, bc_lp, bk, lpb); 116 LIBXSMM_VLA_DECL(5, element_filter_type, ri, riD, kBlocks, bk_lp, bk, lpb); 117 LIBXSMM_VLA_DECL(5, element_filter_type, rf, rfD, kBlocks, bk_lp, bk, lpb); 118 LIBXSMM_VLA_DECL(5, element_filter_type, ro, roD, kBlocks, bk_lp, bk, lpb); 119 LIBXSMM_VLA_DECL(5, element_filter_type, rc, rcD, kBlocks, bk_lp, bk, lpb); 120 LIBXSMM_VLA_DECL(3, float, cs, cst, N, K); 121 LIBXSMM_VLA_DECL(3, float, h, ht, N, K); 122 LIBXSMM_VLA_DECL(3, float, i, it, N, K); 123 LIBXSMM_VLA_DECL(3, float, f, ft, N, K); 124 LIBXSMM_VLA_DECL(3, float, o, ot, N, K); 125 LIBXSMM_VLA_DECL(3, float, ci, cit, N, K); 126 LIBXSMM_VLA_DECL(3, float, co, cot, N, K); 127 LIBXSMM_VLA_DECL(3, element_output_type, cs_out, cst_bf16, N, K); 128 LIBXSMM_VLA_DECL(3, element_output_type, h_out, ht_bf16, N, K); 129 LIBXSMM_VLA_DECL(3, 
element_output_type, i_out, it_bf16, N, K); 130 LIBXSMM_VLA_DECL(3, element_output_type, f_out, ft_bf16, N, K); 131 LIBXSMM_VLA_DECL(3, element_output_type, o_out, ot_bf16, N, K); 132 LIBXSMM_VLA_DECL(3, element_output_type, ci_out, cit_bf16, N, K); 133 LIBXSMM_VLA_DECL(3, element_output_type, co_out, cot_bf16, N, K); 134 /* define batch-reduce gemm kernels */ 135 const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernela = handle->fwd_kernela; 136 const libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernelb = handle->fwd_kernelb; 137 138 float *cps_ptr = NULL; 139 140 /* parallelize over C-blocks */ 141 /* computing first logical thread */ 142 const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread; 143 /* number of tasks that could be run in parallel */ 144 const libxsmm_blasint work = (N/bn) * (K/bk); 145 /* compute chunk size */ 146 const libxsmm_blasint chunksize = (work % (libxsmm_blasint)handle->desc.threads == 0) ? (work / (libxsmm_blasint)handle->desc.threads) : ((work / (libxsmm_blasint)handle->desc.threads) + 1); 147 /* compute thr_begin and thr_end */ 148 const libxsmm_blasint thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; 149 const libxsmm_blasint thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; 150 151 #if 0 152 /* number of tasks that could be run in parallel for C and K blocks*/ 153 const libxsmm_blasint work_ck = (C/bc) * (K/bk); 154 /* compute chunk size */ 155 const libxsmm_blasint chunksize_ck = (work_ck % (libxsmm_blasint)handle->desc.threads == 0) ? (work_ck / (libxsmm_blasint)handle->desc.threads) : ((work_ck / (libxsmm_blasint)handle->desc.threads) + 1); 156 /* compute thr_begin and thr_end */ 157 const libxsmm_blasint thr_begin_ck = (ltid * chunksize_ck < work_ck) ? (ltid * chunksize_ck) : work_ck; 158 const libxsmm_blasint thr_end_ck = ((ltid + 1) * chunksize_ck < work_ck) ? 
((ltid + 1) * chunksize_ck) : work_ck; 159 /* number of tasks that could be run in parallel for K and K blocks*/ 160 const libxsmm_blasint work_kk = (K/bk) * (K/bk); 161 /* compute chunk size */ 162 const libxsmm_blasint chunksize_kk = (work_kk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_kk / (libxsmm_blasint)handle->desc.threads) : ((work_kk / (libxsmm_blasint)handle->desc.threads) + 1); 163 /* compute thr_begin and thr_end */ 164 const libxsmm_blasint thr_begin_kk = (ltid * chunksize_kk < work_kk) ? (ltid * chunksize_kk) : work_kk; 165 const libxsmm_blasint thr_end_kk = ((ltid + 1) * chunksize_kk < work_kk) ? ((ltid + 1) * chunksize_kk) : work_kk; 166 #endif 167 const int use_fused_implementation = (C == 2048 && K == 2048) ? 1 : 0; 168 169 #ifdef PROFILE 170 __int64_t eltwise_start, eltwise_end, eltwise_cycles = 0, gemm_start, gemm_end, gemm_cycles = 0, gemm_cycles2 = 0, reformat_start, reformat_end, reformat_cycles = 0; 171 float total_time = 0.0; 172 #endif 173 174 /* lazy barrier init */ 175 libxsmm_barrier_init(handle->barrier, (int)ltid); 176 177 /* Blocking reduction domain if it is too large */ 178 BF = 1; 179 if ((C > 1024 && C <= 2048) || (K > 1024 && K <= 2048)) { 180 BF = 8; 181 while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) { 182 BF--; 183 } 184 } 185 if (C > 2048 || K > 2048) { 186 BF = 16; 187 while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) { 188 BF--; 189 } 190 } 191 192 if (C == 2048 && K == 1024) { 193 BF = 2; 194 } 195 196 CB_BLOCKS = cBlocks/BF; 197 KB_BLOCKS = kBlocks/BF; 198 199 #ifdef PROFILE 200 if (ltid == 0) reformat_start = _rdtsc(); 201 #endif 202 203 /* Upconvert the cp input to fp32 that is used for elementwise stuff */ 204 for (inik = thr_begin; inik < thr_end; ++inik ) { 205 in = (inik % (N/bn))*bn; 206 ikb = inik / (N/bn); 207 ik = ikb*bk; 208 MATRIX_CVT_BF16_FP32_LD( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, cp_bf16, in, ik, K), &LIBXSMM_VLA_ACCESS(2, cp, in, ik, K)); 209 } 210 211 
libxsmm_barrier_wait(handle->barrier, (int)ltid); 212 #ifdef PROFILE 213 if (ltid == 0) { 214 reformat_end = _rdtsc(); 215 reformat_cycles = reformat_end - reformat_start; 216 } 217 #endif 218 219 if (use_fused_implementation) { 220 #include "libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_fused_bf16.tpl.c" 221 } else { 222 #include "libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_diffused_bf16.tpl.c" 223 } 224 225 #ifdef PROFILE 226 if (ltid == 0) { 227 printf("----- PROFILING LSTM FWD (N = %d, C = %d, K = %d, bn = %d. bc = %d, bk = %d)----\n", N, C, K, bn, bc, bk ); 228 total_time = (gemm_cycles+gemm_cycles2+eltwise_cycles+reformat_cycles)/(2.5 * 1e9)*1000.0f; 229 printf("Elementwise time is %f ms (%.2f%%)\n", eltwise_cycles/(2.5 * 1e9)*1000.0f, eltwise_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); 230 printf("Reformat weights time is %f ms (%.2f%%)\n", reformat_cycles/(2.5 * 1e9)*1000.0f, reformat_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); 231 printf("GEMM W*x time is %f ms (%.2f%%) at %f GFLOPS\n", gemm_cycles/(2.5 * 1e9)*1000.0f, gemm_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*(N*C*K*2.0)*4.0/1e9/(gemm_cycles/(2.5 * 1e9))); 232 printf("GEMM R*h time is %f ms (%.2f%%) at %f GFLOPS\n\n", gemm_cycles2/(2.5 * 1e9)*1000.0f, gemm_cycles2/(2.5 * 1e9)*1000.0f*100.0/total_time, t*(N*K*K*2.0)*4.0/1e9/(gemm_cycles2/(2.5 * 1e9))); 233 } 234 #undef PROFILE 235 #endif 236 237 #undef MATRIX_CVT_BF16_FP32_LD 238 #undef MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_LD 239 #undef MATRIX_BCST_CVT_BF16_FP32_COLVECTOR_CONST_LD 240