/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Evangelos Georganas (Intel Corp.)
******************************************************************************/

/* All data is in column-major format */
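/* Forward pass of the LSTM cell, one time step at a time. For every step j,
 * each thread handles its share of (bn x bk) blocks of the minibatch (N) /
 * output (K) plane and accumulates the gate pre-activations with batch-reduce
 * GEMMs:
 *   i  = W_i.x + R_i.h + b_i            f = W_f.x + R_f.h + (b_f + forget_bias)
 *   ci = W_c.x + R_c.h + b_d            o = W_o.x + R_o.h + b_o
 * At j == 0 the recurrent term reads the initial hidden state hp instead of
 * h[j-1]. The reduction over the input (C) and hidden (K) dimensions can be
 * split into BF chunks (the CB loop): biases are broadcast only in the first
 * chunk (CB == 0) and the element-wise part runs only after the last chunk. */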
for (j = 0; j < t; ++j) {
  /* let's run the cell in blocks for good locality */
  /* Block reduction loop if requested */
  for (CB = 0; CB < BF; CB++) {
    for (inik = thr_begin; inik < thr_end; ++inik ) {
      in = (inik % (N/bn))*bn;
      ikb = inik / (N/bn);
      ik = ikb*bk;
      /* initialize i with bi */
#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &bi[ik] );
      /* i += W.x */
      for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) {
        A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wi, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk);
        B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C);
      }
      /* Reduce batch gemm call */
      blocks = CB_BLOCKS;
      batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles += gemm_end-gemm_start;
      }
#endif
#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* i += R.h */
      if (0 == j) {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ri, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K);
        }
      } else {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ri, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K);
        }
      }
      /* Reduce batch gemm call */
      blocks = KB_BLOCKS;
      batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles2 += gemm_end-gemm_start;
      }
#endif
#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* initialize ci with bd */
      if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &bd[ik] );
      /* ci += W.x */
      for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) {
        A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wc, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk);
        B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C);
      }
      /* Reduce batch gemm call */
      blocks = CB_BLOCKS;
      batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles += gemm_end-gemm_start;
      }
#endif
#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* ci += R.h */
      if (0 == j) {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rc, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K);
        }
      } else {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rc, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K);
        }
      }
      /* Reduce batch gemm call */
      blocks = KB_BLOCKS;
      batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles2 += gemm_end-gemm_start;
      }
#endif
#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* initialize f with (bf + forget_bias) */
      if (CB == 0) libxsmm_internal_matrix_bcst_colvector_const_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &bf[ik], handle->forget_bias );
      /* f += W.x */
      for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) {
        A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wf, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk);
        B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C);
      }
      /* Reduce batch gemm call */
      blocks = CB_BLOCKS;
      batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles += gemm_end-gemm_start;
      }
#endif
#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* f += R.h */
      if (0 == j) {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rf, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K);
        }
      } else {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, rf, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K);
        }
      }
      /* Reduce batch gemm call */
      blocks = KB_BLOCKS;
      batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles2 += gemm_end-gemm_start;
      }
#endif
#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* initialize o with bo */
      if (CB == 0) libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &bo[ik] );
      /* o += W.x */
      for (icb = 0, ic = 0; icb < CB_BLOCKS; ic += bc, icb++) {
        A_array[icb] = &LIBXSMM_VLA_ACCESS(4, wo, ikb, icb + CB*CB_BLOCKS, 0, 0, cBlocks, bc, bk);
        B_array[icb] = &LIBXSMM_VLA_ACCESS(3, x, j, in, ic + CB*CB_BLOCKS*bc, N, C);
      }
      /* Reduce batch gemm call */
      blocks = CB_BLOCKS;
      batchreduce_kernela(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles += gemm_end-gemm_start;
      }
#endif
#ifdef PROFILE
      if (ltid == 0) gemm_start = _rdtsc();
#endif
      /* o += R.h */
      if (0 == j) {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ro, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(2, hp, in, ic + CB*KB_BLOCKS*bk, K);
        }
      } else {
        for (ic = 0, icb = 0; icb < KB_BLOCKS; ic += bk, icb++) {
          A_array[icb] = &LIBXSMM_VLA_ACCESS(4, ro, ikb, icb + CB*KB_BLOCKS, 0, 0, kBlocks, bk, bk);
          B_array[icb] = &LIBXSMM_VLA_ACCESS(3, h, j-1, in, ic + CB*KB_BLOCKS*bk, N, K);
        }
      }
      /* Reduce batch gemm call */
      blocks = KB_BLOCKS;
      batchreduce_kernelb(A_array, B_array, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &blocks);
#ifdef PROFILE
      if (ltid == 0) {
        gemm_end = _rdtsc();
        gemm_cycles2 += gemm_end-gemm_start;
      }
#endif

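      /* After the last reduction chunk has been accumulated, apply the
       * element-wise part of the cell to this block:
       *   i = sigmoid(i), f = sigmoid(f), o = sigmoid(o), ci = tanh(ci)
       *   cs = f * cs_prev + i * ci,  co = tanh(cs),  h = o * co
       * where cs_prev is cp at j == 0 and the cs of step j-1 otherwise. */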
      if (CB == BF-1) {
#ifdef PROFILE
        if (ltid == 0) {
          eltwise_start = _rdtsc();
        }
#endif
        cps_ptr = (j == 0) ? &LIBXSMM_VLA_ACCESS(2, cp, in, ik, K) : &LIBXSMM_VLA_ACCESS(3, cs, j-1, in, ik, N, K);
        /* Compute i, ci, f, o, cs, co and h */
#if defined(LIBXSMM_RNN_CELL_AVX512)
        if (bk % 16 == 0 && bc % 16 == 0) {
#include "libxsmm_internal_lstm_fwd_fused_eltwise.tpl.c"
        } else {
          libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) );
          libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) );
          libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) );
          libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) );
          libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) );
          libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) );
          libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) );
          libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) );
        }
#else
        libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K) );
        libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K) );
        libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K) );
        libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K) );
        libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, f, j, in, ik, N, K), cps_ptr, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) );
        libxsmm_internal_matrix_eltwise_fma_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, i, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, ci, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K) );
        libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, cs, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K) );
        libxsmm_internal_matrix_eltwise_mult_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, o, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, co, j, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, j, in, ik, N, K) );
#endif

#ifdef PROFILE
        if (ltid == 0) {
          eltwise_end = _rdtsc();
          eltwise_cycles += eltwise_end-eltwise_start;
        }
#endif
      }
    }
  }
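  /* Threads must synchronize between time steps: the recurrent GEMMs of step
   * j+1 read h (and cs) blocks that are produced by other threads in step j. */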
  libxsmm_barrier_wait(handle->barrier, (int)ltid);
}
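/* Note: this code carries no function signature of its own; it is presumably
 * used as a template fragment that is #include'd into a driver routine, which
 * is expected to declare the loop indices, the blocking parameters (t, N, C, K,
 * bn, bc, bk, BF, CB_BLOCKS, KB_BLOCKS, cBlocks, kBlocks), the LIBXSMM_VLA
 * views of activations, weights and biases, the A_array/B_array pointer lists,
 * and the two batch-reduce kernels (batchreduce_kernela for the W.x terms,
 * batchreduce_kernelb for the R.h terms). Each batch-reduce call accumulates
 * its (bk x bn) output block by the sum, over 'blocks' small GEMMs, of
 * A_array[b] * B_array[b]. */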