1 /****************************************************************************** 2 * Copyright (c) Intel Corporation - All rights reserved. * 3 * This file is part of the LIBXSMM library. * 4 * * 5 * For information on the license, see the LICENSE file. * 6 * Further information: https://github.com/hfp/libxsmm/ * 7 * SPDX-License-Identifier: BSD-3-Clause * 8 ******************************************************************************/ 9 /* Alexander Heinecke, Kunal Banerjee (Intel Corp.) 10 ******************************************************************************/ 11 12 /* helper variables */ 13 libxsmm_blasint i, ik, in, ic, inik; 14 /* input sizes */ 15 const libxsmm_blasint K = handle->desc.K; 16 const libxsmm_blasint N = handle->desc.N; 17 const libxsmm_blasint C = handle->desc.C; 18 const libxsmm_blasint t = handle->T; 19 const libxsmm_blasint bk = handle->bk; 20 const libxsmm_blasint bn = handle->bn; 21 const libxsmm_blasint bc = handle->bc; 22 /* define tensors */ 23 element_input_type *xt = (element_input_type* )handle->xt->data; 24 element_input_type *hpD= (element_input_type* )handle->hp->data; 25 element_filter_type *wD = (element_filter_type*)handle->w->data; 26 element_filter_type *rD = (element_filter_type*)handle->r->data; 27 element_output_type *b = (element_output_type*)handle->b->data; 28 element_output_type *ht = (element_output_type*)handle->ht->data; 29 element_output_type *zt = (element_output_type*)handle->internal_z; 30 LIBXSMM_VLA_DECL(3, element_input_type, x, xt, N, C); 31 LIBXSMM_VLA_DECL(2, element_input_type, hp, hpD, K); 32 LIBXSMM_VLA_DECL(2, element_filter_type, w, wD, K); 33 LIBXSMM_VLA_DECL(2, element_filter_type, r, rD, K); 34 LIBXSMM_VLA_DECL(3, element_output_type, h, ht, N, K); 35 LIBXSMM_VLA_DECL(3, element_output_type, z, zt, N, K); 36 /* define gemm kernels */ 37 libxsmm_smmfunction gemmkernela = libxsmm_smmdispatch( bk, bn, bc, &K, &C, &K, NULL, NULL, NULL, NULL ); 38 libxsmm_smmfunction gemmkernelb = libxsmm_smmdispatch( bk, bn, bk, &K, &K, &K, NULL, NULL, NULL, NULL ); 39 /* parallelize over C-blocks */ 40 /* computing first logical thread */ 41 const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread; 42 /* number of tasks that could be run in parallel */ 43 const libxsmm_blasint work = (N/bn) * (K/bk); 44 /* compute chunk size */ 45 const libxsmm_blasint chunksize = (work % (libxsmm_blasint)handle->desc.threads == 0) ? (work / (libxsmm_blasint)handle->desc.threads) : ((work / (libxsmm_blasint)handle->desc.threads) + 1); 46 /* compute thr_begin and thr_end */ 47 const libxsmm_blasint thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work; 48 const libxsmm_blasint thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work; 49 50 /* lazy barrier init */ 51 libxsmm_barrier_init(handle->barrier, (int)ltid); 52 53 /* All data is in column-major format */ 54 for (i = 0; i < t; ++i) { 55 /* let's run the cell in blocks for good locality */ 56 for (inik = thr_begin; inik < thr_end; ++inik ) { 57 in = (inik / (K/bk))*bn; 58 ik = (inik % (K/bk))*bk; 59 60 /* z = per_col(b) */ 61 libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, i, in, ik, N, K), &b[ik] ); 62 63 /* z += W.x */ 64 for (ic = 0; ic < C; ic += bc) { 65 /* this is a small matmul */ 66 gemmkernela( &LIBXSMM_VLA_ACCESS(2, w, ic, ik, K), &LIBXSMM_VLA_ACCESS(3, x, i, in, ic, N, C), &LIBXSMM_VLA_ACCESS(3, z, i, in, ik, N, K) ); 67 } 68 /* z += U.h */ 69 if (0 == i) { 70 for (ic = 0; ic < K; ic += bk) { 71 /* this is a small matmul */ 72 gemmkernelb( &LIBXSMM_VLA_ACCESS(2, r, ic, ik, K), &LIBXSMM_VLA_ACCESS(2, hp, in, ic, K), &LIBXSMM_VLA_ACCESS(3, z, i, in, ik, N, K) ); 73 } 74 } else { 75 for (ic = 0; ic < K; ic += bk) { 76 /* this is a small matmul */ 77 gemmkernelb( &LIBXSMM_VLA_ACCESS(2, r, ic, ik, K), &LIBXSMM_VLA_ACCESS(3, h, i-1, in, ic, N, K), &LIBXSMM_VLA_ACCESS(3, z, i, in, ik, N, K) ); 78 } 79 } 80 #if defined(LIBXSMM_DNN_RNN_RELU_FWD) 81 libxsmm_internal_matrix_relu_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, i, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, i, in, ik, N, K) ); 82 #endif 83 #if defined(LIBXSMM_DNN_RNN_SIGMOID_FWD) 84 libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, i, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, i, in, ik, N, K) ); 85 #endif 86 #if defined(LIBXSMM_DNN_RNN_TANH_FWD) 87 libxsmm_internal_matrix_tanh_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, i, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, i, in, ik, N, K) ); 88 #endif 89 } 90 91 libxsmm_barrier_wait(handle->barrier, (int)ltid); 92 } 93