/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Evangelos Georganas (Intel Corp.)
******************************************************************************/

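/* Added overview (inferred from the code below, not part of the original header):
 * this template computes the forward pass of an LSTM cell over t timesteps in two
 * sweeps. The first sweep accumulates the input contributions of the four gates,
 * e.g. i = bi + Wi*x[j], for all timesteps; the second sweep adds the recurrent
 * contributions, e.g. i += Ri*h[j-1], and then applies the element-wise gate
 * functions. Work is distributed over threads via the inik index (blocks of bn
 * rows of N and bk columns of K), and BF optionally splits the reduction
 * dimension into chunks of CB_BLOCKS (over C) resp. KB_BLOCKS (over K) for
 * cache blocking. */
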
/* First perform the W*x part of the output */
for (j = 0; j < t; ++j) {
  /* let's run the cell in blocks for good locality */
  /* Block reduction loop if requested */
  for (CB = 0; CB < BF; CB++) {
    for (inik = thr_begin; inik < thr_end; ++inik ) {
      in = (inik % (N/bn))*bn;
      ikb = inik / (N/bn);
      ik = ikb*bk;
      /* initialize i with bi */
#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &bi[ik] );
      /* i += W.x */
      for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) {
        A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wi, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk);
        B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C);
      }
      /* Reduce batch gemm call */
      blocks = CB_BLOCKS;
      batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles += gemm_end-gemm_start;
      }
#endif

#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* initialize ci with bd */
      if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &bd[ik] );
      /* ci += W.x */
      for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) {
        A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wc, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk);
        B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C);
      }
      /* Reduce batch gemm call */
      blocks = CB_BLOCKS;
      batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles += gemm_end-gemm_start;
      }
#endif

#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* initialize f with (bf + forget_bias) */
      if (CB == 0) libxsmm_internal_matrix_bcst_colvector_const_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &bf[ik], handle->forget_bias );
      /* f += W.x */
      for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) {
        A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wf, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk);
        B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C);
      }
      /* Reduce batch gemm call */
      blocks = CB_BLOCKS;
      batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles += gemm_end-gemm_start;
      }
#endif

#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* initialize o with bo */
      if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &bo[ik] );
      /* o += W.x */
      for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) {
        A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wo, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk);
        B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C);
      }
      /* Reduce batch gemm call */
      blocks = CB_BLOCKS;
      batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles += gemm_end-gemm_start;
      }
#endif
    }
  }
}

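/* Added note (summarizing the second sweep below): the recurrent contribution
 * R*h carries a dependency across timesteps, since h[j-1] must be complete
 * before it can be read at step j; this is why the loop below ends every
 * timestep with a barrier. At j == 0 the previous hidden state is taken from
 * hp. The element-wise stage (gate activations, cell and hidden state update)
 * runs only once per output block, on the last reduction chunk (CB == BF-1). */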
/* Compute the R*h part of the output */
for (j = 0; j < t; ++j) {
  /* let's run the cell in blocks for good locality */
  /* Block reduction loop if requested */
  for (CB = 0; CB < BF; CB++) {
    for (inik = thr_begin; inik < thr_end; ++inik ) {
      in = (inik % (N/bn))*bn;
      ikb = inik / (N/bn);
      ik = ikb*bk;
#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* i += R.h */
      if (0 == j) {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ri, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K);
        }
      } else {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ri, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K);
        }
      }
      /* Reduce batch gemm call */
      blocks = KB_BLOCKS;
      batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles2 += gemm_end-gemm_start;
      }
#endif

#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* ci += R.h */
      if (0 == j) {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rc, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K);
        }
      } else {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rc, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K);
        }
      }
      /* Reduce batch gemm call */
      blocks = KB_BLOCKS;
      batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles2 += gemm_end-gemm_start;
      }
#endif

#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* f += R.h */
      if (0 == j) {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rf, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K);
        }
      } else {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rf, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K);
        }
      }
      /* Reduce batch gemm call */
      blocks = KB_BLOCKS;
      batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles2 += gemm_end-gemm_start;
      }
#endif
#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* o += R.h */
      if (0 == j) {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ro, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K);
        }
      } else {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ro, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K);
        }
      }
      /* Reduce batch gemm call */
      blocks = KB_BLOCKS;
      batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles2 += gemm_end-gemm_start;
      }
#endif

      if (CB == BF-1) {
#ifdef PROFILE
        if (ltid == 0) {
          eltwise_start = _rdtsc();
        }
#endif
        cps_ptr = (j == 0) ? &LIBXSMM_VLA_ACCESS(2, cp, in, ik, K) : &LIBXSMM_VLA_ACCESS(3, cs, j-1, in, ik, N, K);
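        /* Added note: the calls below (or the fused AVX-512 kernel) implement the
         * standard LSTM element-wise update for this bk x bn block:
         *   i  = sigmoid(i),  f = sigmoid(f),  o = sigmoid(o),  ci = tanh(ci)
         *   cs = f * cs_prev + i * ci   (cs_prev is cp at j == 0, else cs[j-1])
         *   co = tanh(cs)
         *   h  = o * co */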
        /* Compute i, ci, f, o, cs, co and h */
#if defined(LIBXSMM_RNN_CELL_AVX512)
        if (bk % 16 == 0 && bc % 16 == 0) {
#include "libxsmm_internal_lstm_fwd_fused_eltwise.tpl.c"
        } else {
          libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) );
          libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) );
          libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) );
          libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) );
          libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) );
          libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) );
          libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) );
          libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) );
        }
#else
        libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) );
        libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) );
        libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) );
        libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) );
        libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) );
        libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) );
        libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) );
        libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) );
#endif

#ifdef PROFILE
        if (ltid == 0) {
          eltwise_end = _rdtsc();
          eltwise_cycles += eltwise_end-eltwise_start;
        }
#endif
      }
    }
  }
  libxsmm_barrier_wait(handle->barrier, (int)ltid);
}