1 /****************************************************************************** 2 * Copyright (c) Intel Corporation - All rights reserved. * 3 * This file is part of the LIBXSMM library. * 4 * * 5 * For information on the license, see the LICENSE file. * 6 * Further information: https://github.com/hfp/libxsmm/ * 7 * SPDX-License-Identifier: BSD-3-Clause * 8 ******************************************************************************/ 9 /* Evangelos Georganas, Kunal Banerjee (Intel Corp.) 10 ******************************************************************************/ 11 #if 0 12 #define PROFILE 13 #endif 14 15 /* helper variables */ 16 libxsmm_blasint j, ik, ikb, in, ic, icb, inik, BF, CB, CB_BLOCKS, KB_BLOCKS; 17 /* input sizes */ 18 const libxsmm_blasint K = handle->desc.K; 19 const libxsmm_blasint N = handle->desc.N; 20 const libxsmm_blasint C = handle->desc.C; 21 const libxsmm_blasint t = handle->T; 22 const libxsmm_blasint bk = handle->bk; 23 const libxsmm_blasint bn = handle->bn; 24 const libxsmm_blasint bc = handle->bc; 25 const libxsmm_blasint cBlocks = C/bc; 26 const libxsmm_blasint kBlocks = K/bk; 27 unsigned long long blocks; 28 29 /* define tensors */ 30 element_input_type *xt = (element_input_type* )handle->xt->data; 31 element_input_type *csp = (element_input_type* )handle->csp->data; 32 element_input_type *hpD = (element_input_type* )handle->hp->data; 33 element_filter_type *w = (element_filter_type*)handle->w->data; 34 element_filter_type *r = (element_filter_type*)handle->r->data; 35 element_output_type *b = (element_output_type*)handle->b->data; 36 element_output_type *cst = (element_output_type*)handle->cst->data; 37 element_output_type *ht = (element_output_type*)handle->ht->data; 38 element_output_type *it = (element_output_type*)handle->it->data; 39 element_output_type *ft = (element_output_type*)handle->ft->data; 40 element_output_type *ot = (element_output_type*)handle->ot->data; 41 element_output_type *cit = (element_output_type*)handle->cit->data; 42 element_output_type *cot = (element_output_type*)handle->cot->data; 43 element_filter_type *wiD = &(w[0]); 44 element_filter_type *wcD = &(w[C*K]); 45 element_filter_type *wfD = &(w[2*C*K]); 46 element_filter_type *woD = &(w[3*C*K]); 47 element_filter_type *riD = &(r[0]); 48 element_filter_type *rcD = &(r[K*K]); 49 element_filter_type *rfD = &(r[2*K*K]); 50 element_filter_type *roD = &(r[3*K*K]); 51 element_output_type *bi = &(b[0]); 52 element_output_type *bd = &(b[K]); 53 element_output_type *bf = &(b[2*K]); 54 element_output_type *bo = &(b[3*K]); 55 LIBXSMM_VLA_DECL(3, element_input_type, x, xt, N, C); 56 LIBXSMM_VLA_DECL(2, element_input_type, cp, csp, K); 57 LIBXSMM_VLA_DECL(2, element_input_type, hp, hpD, K); 58 LIBXSMM_VLA_DECL(4, element_filter_type, wi, wiD, cBlocks, bc, bk); 59 LIBXSMM_VLA_DECL(4, element_filter_type, wf, wfD, cBlocks, bc, bk); 60 LIBXSMM_VLA_DECL(4, element_filter_type, wo, woD, cBlocks, bc, bk); 61 LIBXSMM_VLA_DECL(4, element_filter_type, wc, wcD, cBlocks, bc, bk); 62 LIBXSMM_VLA_DECL(4, element_filter_type, ri, riD, kBlocks, bk, bk); 63 LIBXSMM_VLA_DECL(4, element_filter_type, rf, rfD, kBlocks, bk, bk); 64 LIBXSMM_VLA_DECL(4, element_filter_type, ro, roD, kBlocks, bk, bk); 65 LIBXSMM_VLA_DECL(4, element_filter_type, rc, rcD, kBlocks, bk, bk); 66 LIBXSMM_VLA_DECL(3, element_output_type, cs, cst, N, K); 67 LIBXSMM_VLA_DECL(3, element_output_type, h, ht, N, K); 68 LIBXSMM_VLA_DECL(3, element_output_type, i, it, N, K); 69 LIBXSMM_VLA_DECL(3, element_output_type, f, ft, N, K); 70 LIBXSMM_VLA_DECL(3, element_output_type, o, ot, N, K); 71 LIBXSMM_VLA_DECL(3, element_output_type, ci, cit, N, K); 72 LIBXSMM_VLA_DECL(3, element_output_type, co, cot, N, K); 73 /* define batch-reduce gemm kernels */ 74 const libxsmm_smmfunction_reducebatch_addr batchreduce_kernela = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bc, &bk, &C, &K, NULL, NULL, NULL, NULL ); 75 const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelb = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, NULL, NULL ); 76 /* Auxiliary arrays for batch-reduce gemms */ 77 const element_filter_type *A_array[1024]; 78 const element_input_type *B_array[1024]; 79 element_output_type *cps_ptr = NULL; 80 81 /* parallelize over C-blocks */ 82 /* computing first logical thread */ 83 const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread; 84 /* number of tasks that could be run in parallel */ 85 const libxsmm_blasint work = (N/bn) * (K/bk); 86 /* compute chunk size */ 87 const libxsmm_blasint chunksize = (work % (libxsmm_blasint)handle->desc.threads == 0) ? (work / (libxsmm_blasint)handle->desc.threads) : ((work / (libxsmm_blasint)handle->desc.threads) + 1); 88 /* compute thr_begin and thr_end */ 89 const libxsmm_blasint thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; 90 const libxsmm_blasint thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; 91 92 const int use_fused_implementation = (C == 2048 && K == 2048) ? 1 : 0; 93 #ifdef PROFILE 94 __int64_t eltwise_start, eltwise_end, eltwise_cycles = 0, gemm_start, gemm_end, gemm_cycles = 0, gemm_cycles2 = 0; 95 float total_time = 0.0; 96 #endif 97 98 /* lazy barrier init */ 99 libxsmm_barrier_init(handle->barrier, (int)ltid); 100 101 /* Blocking reduction domain if it is too large */ 102 BF = 1; 103 if ((C > 1024 && C <= 2048) || (K > 1024 && K <= 2048)) { 104 BF = 8; 105 while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) { 106 BF--; 107 } 108 } 109 if (C > 2048 || K > 2048) { 110 BF = 16; 111 while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) { 112 BF--; 113 } 114 } 115 116 if (C == 2048 && K == 1024) { 117 BF = 2; 118 } 119 120 CB_BLOCKS = cBlocks/BF; 121 KB_BLOCKS = kBlocks/BF; 122 123 if (use_fused_implementation) { 124 #include "libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_fused.tpl.c" 125 } else { 126 #include "libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_diffused.tpl.c" 127 } 128 129 #ifdef PROFILE 130 if (ltid == 0) { 131 printf("----- PROFILING LSTM FWD (N = %d, C = %d, K = %d, bn = %d. bc = %d, bk = %d)----\n", N, C, K, bn, bc, bk ); 132 total_time = (gemm_cycles+gemm_cycles2+eltwise_cycles)/(2.5 * 1e9)*1000.0f; 133 printf("Elementwise time is %f ms (%.2f%%)\n", eltwise_cycles/(2.5 * 1e9)*1000.0f, eltwise_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time ); 134 printf("GEMM W*x time is %f ms (%.2f%%) at %f GFLOPS\n", gemm_cycles/(2.5 * 1e9)*1000.0f, gemm_cycles/(2.5 * 1e9)*1000.0f*100.0/total_time, t*(N*C*K*2.0)*4.0/1e9/(gemm_cycles/(2.5 * 1e9))); 135 printf("GEMM R*h time is %f ms (%.2f%%) at %f GFLOPS\n\n", gemm_cycles2/(2.5 * 1e9)*1000.0f, gemm_cycles2/(2.5 * 1e9)*1000.0f*100.0/total_time, t*(N*K*K*2.0)*4.0/1e9/(gemm_cycles2/(2.5 * 1e9))); 136 } 137 #undef PROFILE 138 #endif 139