1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Evangelos Georganas, Kunal Banerjee (Intel Corp.)
10 ******************************************************************************/
11 #if 0
12 #define PROFILE
13 #endif
14 
15 /* helper variables */
16 libxsmm_blasint j, ik, ikb, in, ic, icb, inik, BF, CB, CB_BLOCKS, KB_BLOCKS;
17 /* input sizes */
18 const libxsmm_blasint K =  handle->desc.K;
19 const libxsmm_blasint N =  handle->desc.N;
20 const libxsmm_blasint C =  handle->desc.C;
21 const libxsmm_blasint t =  handle->T;
22 const libxsmm_blasint bk = handle->bk;
23 const libxsmm_blasint bn = handle->bn;
24 const libxsmm_blasint bc = handle->bc;
25 const libxsmm_blasint cBlocks = C/bc;
26 const libxsmm_blasint kBlocks = K/bk;
27 unsigned long long blocks;
28 
29 /* define tensors */
30 element_input_type  *xt  = (element_input_type* )handle->xt->data;
31 element_input_type  *csp = (element_input_type* )handle->csp->data;
32 element_input_type  *hpD = (element_input_type* )handle->hp->data;
33 element_filter_type *w   = (element_filter_type*)handle->w->data;
34 element_filter_type *r   = (element_filter_type*)handle->r->data;
35 element_output_type *b   = (element_output_type*)handle->b->data;
36 element_output_type *cst = (element_output_type*)handle->cst->data;
37 element_output_type *ht  = (element_output_type*)handle->ht->data;
38 element_output_type *it  = (element_output_type*)handle->it->data;
39 element_output_type *ft  = (element_output_type*)handle->ft->data;
40 element_output_type *ot  = (element_output_type*)handle->ot->data;
41 element_output_type *cit = (element_output_type*)handle->cit->data;
42 element_output_type *cot = (element_output_type*)handle->cot->data;
43 element_filter_type *wiD = &(w[0]);
44 element_filter_type *wcD = &(w[C*K]);
45 element_filter_type *wfD = &(w[2*C*K]);
46 element_filter_type *woD = &(w[3*C*K]);
47 element_filter_type *riD = &(r[0]);
48 element_filter_type *rcD = &(r[K*K]);
49 element_filter_type *rfD = &(r[2*K*K]);
50 element_filter_type *roD = &(r[3*K*K]);
51 element_output_type *bi  = &(b[0]);
52 element_output_type *bd  = &(b[K]);
53 element_output_type *bf  = &(b[2*K]);
54 element_output_type *bo  = &(b[3*K]);
55 LIBXSMM_VLA_DECL(3, element_input_type,  x, xt, N, C);
56 LIBXSMM_VLA_DECL(2, element_input_type,  cp, csp, K);
57 LIBXSMM_VLA_DECL(2, element_input_type,  hp, hpD, K);
58 LIBXSMM_VLA_DECL(4, element_filter_type, wi, wiD, cBlocks, bc, bk);
59 LIBXSMM_VLA_DECL(4, element_filter_type, wf, wfD, cBlocks, bc, bk);
60 LIBXSMM_VLA_DECL(4, element_filter_type, wo, woD, cBlocks, bc, bk);
61 LIBXSMM_VLA_DECL(4, element_filter_type, wc, wcD, cBlocks, bc, bk);
62 LIBXSMM_VLA_DECL(4, element_filter_type, ri, riD, kBlocks, bk, bk);
63 LIBXSMM_VLA_DECL(4, element_filter_type, rf, rfD, kBlocks, bk, bk);
64 LIBXSMM_VLA_DECL(4, element_filter_type, ro, roD, kBlocks, bk, bk);
65 LIBXSMM_VLA_DECL(4, element_filter_type, rc, rcD, kBlocks, bk, bk);
66 LIBXSMM_VLA_DECL(3, element_output_type, cs, cst, N, K);
67 LIBXSMM_VLA_DECL(3, element_output_type, h, ht, N, K);
68 LIBXSMM_VLA_DECL(3, element_output_type, i, it, N, K);
69 LIBXSMM_VLA_DECL(3, element_output_type, f, ft, N, K);
70 LIBXSMM_VLA_DECL(3, element_output_type, o, ot, N, K);
71 LIBXSMM_VLA_DECL(3, element_output_type, ci, cit, N, K);
72 LIBXSMM_VLA_DECL(3, element_output_type, co, cot, N, K);
73 /* define batch-reduce gemm kernels */
74 const libxsmm_smmfunction_reducebatch_addr batchreduce_kernela = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bc, &bk, &C, &K, NULL, NULL, NULL, NULL );
75 const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelb = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, NULL, NULL );
76 /* Auxiliary arrays for batch-reduce gemms */
77 const element_filter_type *A_array[1024];
78 const element_input_type  *B_array[1024];
79 element_output_type *cps_ptr = NULL;
80 
81 /* parallelize over C-blocks */
82 /* computing first logical thread */
83 const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread;
84 /* number of tasks that could be run in parallel */
85 const libxsmm_blasint work = (N/bn) * (K/bk);
86 /* compute chunk size */
87 const libxsmm_blasint chunksize = (work % (libxsmm_blasint)handle->desc.threads == 0) ? (work / (libxsmm_blasint)handle->desc.threads) : ((work / (libxsmm_blasint)handle->desc.threads) + 1);
88 /* compute thr_begin and thr_end */
89 const libxsmm_blasint thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work;
90 const libxsmm_blasint thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work;
91 
92 const int use_fused_implementation = (C == 2048 && K == 2048) ? 1 : 0;
93 #ifdef PROFILE
94 __int64_t eltwise_start, eltwise_end, eltwise_cycles = 0, gemm_start, gemm_end, gemm_cycles = 0, gemm_cycles2 = 0;
95 float total_time = 0.0;
96 #endif
97 
98 /* lazy barrier init */
99 libxsmm_barrier_init(handle->barrier, (int)ltid);
100 
101 /* Blocking reduction domain if it is too large */
102 BF = 1;
103 if ((C > 1024 && C <= 2048) || (K > 1024 && K <= 2048)) {
104   BF = 8;
105   while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) {
106     BF--;
107   }
108 }
109 if (C > 2048 || K > 2048) {
110   BF = 16;
111   while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) {
112     BF--;
113   }
114 }
115 
116 if (C == 2048 && K == 1024) {
117   BF = 2;
118 }
119 
120 CB_BLOCKS = cBlocks/BF;
121 KB_BLOCKS = kBlocks/BF;
122 
123 if (use_fused_implementation) {
124 #include "libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_fused.tpl.c"
125 } else {
126 #include "libxsmm_dnn_rnncell_st_lstm_fwd_nc_kcck_diffused.tpl.c"
127 }
128 
#ifdef PROFILE
if (ltid == 0) {
  /* Cycle counts are converted to milliseconds assuming a fixed 2.5 GHz clock. */
  const double freq_hz    = 2.5 * 1e9;
  const double eltwise_ms = eltwise_cycles / freq_hz * 1000.0;
  const double gemm_ms    = gemm_cycles    / freq_hz * 1000.0;
  const double gemm2_ms   = gemm_cycles2   / freq_hz * 1000.0;
  /* FLOP totals over t steps, with one GEMM per gate (the factor 4). They are
   * formed in double: N*C*K in libxsmm_blasint arithmetic can overflow a
   * 32-bit integer before the original "* 2.0" promoted it. */
  const double gemm_flops  = (double)t * 2.0 * (double)N * (double)C * (double)K * 4.0;
  const double gemm2_flops = (double)t * 2.0 * (double)N * (double)K * (double)K * 4.0;
  /* %d needs int arguments; libxsmm_blasint may be a 64-bit type. */
  printf("----- PROFILING LSTM FWD (N = %d, C = %d, K = %d, bn = %d, bc = %d, bk = %d)----\n", (int)N, (int)C, (int)K, (int)bn, (int)bc, (int)bk );
  total_time = (float)(gemm_ms + gemm2_ms + eltwise_ms);
  printf("Elementwise time is %f ms (%.2f%%)\n", eltwise_ms, eltwise_ms*100.0/total_time );
  printf("GEMM W*x  time is %f ms (%.2f%%) at %f GFLOPS\n", gemm_ms, gemm_ms*100.0/total_time, gemm_flops/1e9/(gemm_ms/1000.0));
  printf("GEMM R*h  time is %f ms (%.2f%%) at %f GFLOPS\n\n", gemm2_ms, gemm2_ms*100.0/total_time, gemm2_flops/1e9/(gemm2_ms/1000.0));
}
#undef PROFILE
#endif
139