/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Kunal Banerjee (Intel Corp.)
******************************************************************************/

/* helper variables */
libxsmm_blasint j, ik, ikb, in, ic, icb, inik, BF, CB, CB_BLOCKS, KB_BLOCKS, ikic, jk, jc;
/* input sizes */
const libxsmm_blasint K = handle->desc.K;
const libxsmm_blasint N = handle->desc.N;
const libxsmm_blasint C = handle->desc.C;
const libxsmm_blasint t = handle->T;
const libxsmm_blasint bk = handle->bk;
const libxsmm_blasint bn = handle->bn;
const libxsmm_blasint bc = handle->bc;
const libxsmm_blasint K3 = K * 3;
const libxsmm_blasint cBlocks = C/bc;
const libxsmm_blasint kBlocks = K/bk;
unsigned long long blocks;

/* define tensors */
element_input_type  *xt  = (element_input_type* )handle->xt->data;
element_input_type  *hpD = (element_input_type* )handle->hp->data;
element_filter_type *w   = (element_filter_type*)handle->w->data;
element_filter_type *r   = (element_filter_type*)handle->r->data;
element_filter_type *w_scratch = (element_filter_type*)handle->scratch_w;
element_filter_type *r_scratch = (element_filter_type*)handle->scratch_r;
element_output_type *b   = (element_output_type*)handle->b->data;
element_output_type *ht  = (element_output_type*)handle->ht->data;
element_output_type *it  = (element_output_type*)handle->it->data;
element_output_type *ct  = (element_output_type*)handle->cit->data;
element_output_type *ft  = (element_output_type*)handle->ft->data;
element_output_type *ot  = (element_output_type*)handle->ot->data;
element_filter_type *wiD = &(w[0]);
element_filter_type *wcD = &(w[K]);
element_filter_type *wfD = &(w[2*K]);
element_filter_type *riD = &(r[0]);
element_filter_type *rcD = &(r[K]);
element_filter_type *rfD = &(r[2*K]);
element_filter_type *wiD_scratch = &(w_scratch[0]);
element_filter_type *wcD_scratch = &(w_scratch[C*K]);
element_filter_type *wfD_scratch = &(w_scratch[2*C*K]);
element_filter_type *riD_scratch = &(r_scratch[0]);
element_filter_type *rcD_scratch = &(r_scratch[K*K]);
element_filter_type *rfD_scratch = &(r_scratch[2*K*K]);
element_output_type *bi = &(b[0]);
element_output_type *bd = &(b[K]);
element_output_type *bf = &(b[2*K]);
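/* note: w, r, and b each pack the three gates side by side along the K dimension
 * (columns [0,K) -> i, [K,2K) -> c, [2K,3K) -> f), which is why the gate base
 * pointers above are offset by multiples of K and the ck-format views below use
 * the row stride K3 = 3*K */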
LIBXSMM_VLA_DECL(3, element_input_type,  x,  xt,  N, C);
LIBXSMM_VLA_DECL(2, element_input_type,  hp, hpD, K);
LIBXSMM_VLA_DECL(4, element_filter_type, wi, wiD_scratch, cBlocks, bc, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, wc, wcD_scratch, cBlocks, bc, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, wf, wfD_scratch, cBlocks, bc, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, ri, riD_scratch, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, rc, rcD_scratch, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(4, element_filter_type, rf, rfD_scratch, kBlocks, bk, bk);
LIBXSMM_VLA_DECL(2, element_filter_type, wi_ck, wiD, K3);
LIBXSMM_VLA_DECL(2, element_filter_type, wc_ck, wcD, K3);
LIBXSMM_VLA_DECL(2, element_filter_type, wf_ck, wfD, K3);
LIBXSMM_VLA_DECL(2, element_filter_type, ri_ck, riD, K3);
LIBXSMM_VLA_DECL(2, element_filter_type, rc_ck, rcD, K3);
LIBXSMM_VLA_DECL(2, element_filter_type, rf_ck, rfD, K3);
LIBXSMM_VLA_DECL(3, element_output_type, h, ht, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, i, it, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, c, ct, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, f, ft, N, K);
LIBXSMM_VLA_DECL(3, element_output_type, o, ot, N, K);
/* define batch-reduce gemm kernels */
const libxsmm_smmfunction_reducebatch_addr batchreduce_kernela = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bc, &bk, &C, &K, NULL, NULL, NULL, NULL );
const libxsmm_smmfunction_reducebatch_addr batchreduce_kernelb = libxsmm_smmdispatch_reducebatch_addr( bk, bn, bk, &bk, &K, &K, NULL, NULL, NULL, NULL );
/* auxiliary arrays for batch-reduce gemms */
const element_filter_type *A_array[1024];
const element_input_type  *B_array[1024];
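/* note: an address-variant batch-reduce kernel performs an accumulating reduction
 * over `blocks` operand pairs, i.e. C += sum_l A_array[l] * B_array[l]; with the
 * NULL arguments above, LIBXSMM's defaults (alpha = 1, beta = 1) should apply,
 * which is why each gate block below is first seeded with its bias and then
 * accumulated into */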
/* parallelize the cell over the (N/bn) x (K/bk) blocks of the minibatch and output domains */
/* computing first logical thread */
const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread;
/* number of tasks that could be run in parallel */
const libxsmm_blasint work = (N/bn) * (K/bk);
/* compute chunk size (ceiling division of work over threads) */
const libxsmm_blasint chunksize = (work % (libxsmm_blasint)handle->desc.threads == 0) ? (work / (libxsmm_blasint)handle->desc.threads) : ((work / (libxsmm_blasint)handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
const libxsmm_blasint thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work;
const libxsmm_blasint thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work;

/* number of tasks that could be run in parallel for C and K blocks */
const libxsmm_blasint work_ck = (C/bc) * (K/bk);
/* compute chunk size */
const libxsmm_blasint chunksize_ck = (work_ck % (libxsmm_blasint)handle->desc.threads == 0) ? (work_ck / (libxsmm_blasint)handle->desc.threads) : ((work_ck / (libxsmm_blasint)handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
const libxsmm_blasint thr_begin_ck = (ltid * chunksize_ck < work_ck) ? (ltid * chunksize_ck) : work_ck;
const libxsmm_blasint thr_end_ck = ((ltid + 1) * chunksize_ck < work_ck) ? ((ltid + 1) * chunksize_ck) : work_ck;

/* number of tasks that could be run in parallel for K and K blocks */
const libxsmm_blasint work_kk = (K/bk) * (K/bk);
/* compute chunk size */
const libxsmm_blasint chunksize_kk = (work_kk % (libxsmm_blasint)handle->desc.threads == 0) ? (work_kk / (libxsmm_blasint)handle->desc.threads) : ((work_kk / (libxsmm_blasint)handle->desc.threads) + 1);
/* compute thr_begin and thr_end */
const libxsmm_blasint thr_begin_kk = (ltid * chunksize_kk < work_kk) ? (ltid * chunksize_kk) : work_kk;
const libxsmm_blasint thr_end_kk = ((ltid + 1) * chunksize_kk < work_kk) ? ((ltid + 1) * chunksize_kk) : work_kk;
#if 0
const int use_fused_implementation = (C == 2048 && K == 2048) ? 1 : 0;
#endif
/* lazy barrier init */
libxsmm_barrier_init(handle->barrier, (int)ltid);

/* block the reduction domain if it is too large */
BF = 1;
if ((C > 1024 && C <= 2048) || (K > 1024 && K <= 2048)) {
  BF = 8;
  while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) {
    BF--;
  }
}
if (C > 2048 || K > 2048) {
  BF = 16;
  while ( (cBlocks % BF != 0) || (kBlocks % BF != 0) ) {
    BF--;
  }
}

if (C == 2048 && K == 1024) {
  BF = 2;
}

CB_BLOCKS = cBlocks/BF;
KB_BLOCKS = kBlocks/BF;
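/* for instance, under the hypothetical sizes C = K = 2048 with bc = bk = 64:
 * cBlocks = kBlocks = 32 and BF = 8, so each reduction chunk covers
 * CB_BLOCKS = KB_BLOCKS = 32/8 = 4 blocks of the C/K reduction domain */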
/* Upfront reformatting of W and R */
/* reformat W */
for (ikic = thr_begin_ck; ikic < thr_end_ck; ++ikic ) {
  ic = (ikic / (K/bk));
  ik = (ikic % (K/bk));
  for (jk = 0; jk < bk; ++jk) {
    for (jc = 0; jc < bc; ++jc) {
      LIBXSMM_VLA_ACCESS(4, wi, ik, ic, jc, jk, cBlocks, bc, bk) = LIBXSMM_VLA_ACCESS(2, wi_ck, ic*bc+jc, ik*bk+jk, 3*K);
      LIBXSMM_VLA_ACCESS(4, wc, ik, ic, jc, jk, cBlocks, bc, bk) = LIBXSMM_VLA_ACCESS(2, wc_ck, ic*bc+jc, ik*bk+jk, 3*K);
      LIBXSMM_VLA_ACCESS(4, wf, ik, ic, jc, jk, cBlocks, bc, bk) = LIBXSMM_VLA_ACCESS(2, wf_ck, ic*bc+jc, ik*bk+jk, 3*K);
    }
  }
}

/* reformat R */
for (ikic = thr_begin_kk; ikic < thr_end_kk; ++ikic ) {
  ik = (ikic / (K/bk));
  ic = (ikic % (K/bk));
  for (jk = 0; jk < bk; ++jk) {
    for (jc = 0; jc < bk; ++jc) {
      LIBXSMM_VLA_ACCESS(4, ri, ik, ic, jc, jk, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(2, ri_ck, ic*bk+jc, ik*bk+jk, 3*K);
      LIBXSMM_VLA_ACCESS(4, rc, ik, ic, jc, jk, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(2, rc_ck, ic*bk+jc, ik*bk+jk, 3*K);
      LIBXSMM_VLA_ACCESS(4, rf, ik, ic, jc, jk, kBlocks, bk, bk) = LIBXSMM_VLA_ACCESS(2, rf_ck, ic*bk+jc, ik*bk+jk, 3*K);
    }
  }
}

libxsmm_barrier_wait(handle->barrier, (int)ltid);

/* lazy barrier init */
libxsmm_barrier_init(handle->barrier, (int)ltid);
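/* per timestep, the cell below computes (".": elementwise product):
 *   i = sigmoid(W_i x + R_i h_prev + b_i)
 *   o = h_prev . i
 *   c = sigmoid(W_c x + R_c h_prev + b_c)
 *   f = tanh(W_f x + R_f o + b_f)
 *   h = (1 - c) . f + c . h_prev
 * i.e. a GRU with i acting as the reset gate, c as the update gate, and f as
 * the candidate state; f needs o from all K-blocks, hence the barrier between
 * the two inik passes below */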
/* All data is in column-major format */
for (j = 0; j < t; ++j) {
  /* let's run the cell in blocks for good locality */
  /* Block reduction loop if requested */
  for (CB = 0; CB < BF; CB++) {
    for (inik = thr_begin; inik < thr_end; ++inik ) {
      in = (inik % (N/bn))*bn;
      ikb = inik / (N/bn);
      ik = ikb*bk;
      /* initialize i with bi */
      if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &bi[ik] );
      /* i += W.x */
      for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) {
        A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wi, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk);
        B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C);
      }
      /* Reduce batch gemm call */
      blocks = CB_BLOCKS;
      batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocks);
      /* i += R.hp */
      if (0 == j) {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ri, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K);
        }
      } else {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ri, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K);
        }
      }
      /* Reduce batch gemm call */
      blocks = KB_BLOCKS;
      batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocks);
      /* initialize c with bd */
      if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &bd[ik] );
      /* c += W.x */
      for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) {
        A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wc, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk);
        B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C);
      }
      /* Reduce batch gemm call */
      blocks = CB_BLOCKS;
      batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &blocks);
      /* c += R.hp */
      if (0 == j) {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rc, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K);
        }
      } else {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rc, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K);
        }
      }
      /* Reduce batch gemm call */
      blocks = KB_BLOCKS;
      batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &blocks);

      if (CB == BF-1) {
        /* i = sigmoid(i) */
        libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) );
        /* o = hp . i */
        if (0 == j) {
          libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(2, hp, in, ik, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) );
        } else {
          libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) );
        }
      }
    }
  }
  libxsmm_barrier_wait(handle->barrier, (int)ltid);
  /* We need a barrier here to ensure all elements of o are computed before f can be computed */
  for (CB = 0; CB < BF; CB++) {
    for (inik = thr_begin; inik < thr_end; ++inik ) {
      in = (inik % (N/bn))*bn;
      ikb = inik / (N/bn);
      ik = ikb*bk;
      /* initialize f with bf */
      if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &bf[ik] );
      /* f += W.x */
      for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) {
        A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wf, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk);
        B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C);
      }
      /* Reduce batch gemm call */
      blocks = CB_BLOCKS;
      batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocks);
      /* f += R.o */
      for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
        A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rf, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
        B_array[icb] = &LIBXSMM_VLA_ACCESS(3, o, j, in, ic + CB*KB_BLOCKS*bk, N, K);
      }
      /* Reduce batch gemm call */
      blocks = KB_BLOCKS;
      batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocks);

      if (CB == BF-1) {
        /* f = tanh(f) */
        libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) );
        /* c = sigmoid(c) */
        libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K) );
        /* h = (1 - c) . f */
        libxsmm_internal_matrix_complement_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) );
        libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) );
        /* h += c . hp */
        if (0 == j) {
          libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(2, hp, in, ik, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) );
        } else {
          libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, c, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) );
        }
      }
    }
  }
  libxsmm_barrier_wait(handle->barrier, (int)ltid);
}