1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Alexander Heinecke, Kunal Banerjee (Intel Corp.)
10 ******************************************************************************/
11 
12 /* helper variables */
13 libxsmm_blasint i, ik, in, ic, inik;
14 /* input sizes */
15 const libxsmm_blasint K =  handle->desc.K;
16 const libxsmm_blasint N =  handle->desc.N;
17 const libxsmm_blasint C =  handle->desc.C;
18 const libxsmm_blasint t =  handle->T;
19 const libxsmm_blasint bk = handle->bk;
20 const libxsmm_blasint bn = handle->bn;
21 const libxsmm_blasint bc = handle->bc;
22 /* define tensors */
23 element_input_type  *xt = (element_input_type* )handle->xt->data;
24 element_input_type  *hpD= (element_input_type* )handle->hp->data;
25 element_filter_type *wD = (element_filter_type*)handle->w->data;
26 element_filter_type *rD = (element_filter_type*)handle->r->data;
27 element_output_type *b  = (element_output_type*)handle->b->data;
28 element_output_type *ht = (element_output_type*)handle->ht->data;
29 element_output_type *zt = (element_output_type*)handle->internal_z;
30 LIBXSMM_VLA_DECL(3, element_input_type,  x, xt, N, C);
31 LIBXSMM_VLA_DECL(2, element_input_type,  hp, hpD, K);
32 LIBXSMM_VLA_DECL(2, element_filter_type, w, wD, K);
33 LIBXSMM_VLA_DECL(2, element_filter_type, r, rD, K);
34 LIBXSMM_VLA_DECL(3, element_output_type, h, ht, N, K);
35 LIBXSMM_VLA_DECL(3, element_output_type, z, zt, N, K);
36 /* define gemm kernels */
37 libxsmm_smmfunction gemmkernela = libxsmm_smmdispatch( bk, bn, bc, &K, &C, &K, NULL, NULL, NULL, NULL );
38 libxsmm_smmfunction gemmkernelb = libxsmm_smmdispatch( bk, bn, bk, &K, &K, &K, NULL, NULL, NULL, NULL );
39 /* parallelize over C-blocks */
40 /* computing first logical thread */
41 const libxsmm_blasint ltid = (libxsmm_blasint)tid - (libxsmm_blasint)start_thread;
42 /* number of tasks that could be run in parallel */
43 const libxsmm_blasint work = (N/bn) * (K/bk);
44 /* compute chunk size */
45 const libxsmm_blasint chunksize = (work % (libxsmm_blasint)handle->desc.threads == 0) ? (work / (libxsmm_blasint)handle->desc.threads) : ((work / (libxsmm_blasint)handle->desc.threads) + 1);
46 /* compute thr_begin and thr_end */
47 const libxsmm_blasint thr_begin = (ltid * chunksize < work) ? (ltid * chunksize) : work;
48 const libxsmm_blasint thr_end = ((ltid + 1) * chunksize < work) ? ((ltid + 1) * chunksize) : work;
49 
50 /* lazy barrier init */
51 libxsmm_barrier_init(handle->barrier, (int)ltid);
52 
53 /* All data is in column-major format */
54 for (i = 0; i < t; ++i) {
55   /* let's run the cell in blocks for good locality */
56   for (inik = thr_begin; inik < thr_end; ++inik ) {
57     in = (inik / (K/bk))*bn;
58     ik = (inik % (K/bk))*bk;
59 
60     /* z = per_col(b) */
61     libxsmm_internal_matrix_bcst_colvector_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, i, in, ik, N, K), &b[ik] );
62 
63     /* z += W.x */
64     for (ic = 0; ic < C; ic += bc) {
65       /* this is a small matmul */
66       gemmkernela( &LIBXSMM_VLA_ACCESS(2, w, ic, ik, K), &LIBXSMM_VLA_ACCESS(3, x, i, in, ic, N, C), &LIBXSMM_VLA_ACCESS(3, z, i, in, ik, N, K) );
67     }
68     /* z += U.h */
69     if (0 == i) {
70       for (ic = 0; ic < K; ic += bk) {
71         /* this is a small matmul */
72         gemmkernelb( &LIBXSMM_VLA_ACCESS(2, r, ic, ik, K), &LIBXSMM_VLA_ACCESS(2, hp, in, ic, K), &LIBXSMM_VLA_ACCESS(3, z, i, in, ik, N, K) );
73       }
74     } else {
75       for (ic = 0; ic < K; ic += bk) {
76         /* this is a small matmul */
77         gemmkernelb( &LIBXSMM_VLA_ACCESS(2, r, ic, ik, K), &LIBXSMM_VLA_ACCESS(3, h, i-1, in, ic, N, K), &LIBXSMM_VLA_ACCESS(3, z, i, in, ik, N, K) );
78       }
79     }
80 #if defined(LIBXSMM_DNN_RNN_RELU_FWD)
81     libxsmm_internal_matrix_relu_ld(    bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, i, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, i, in, ik, N, K) );
82 #endif
83 #if defined(LIBXSMM_DNN_RNN_SIGMOID_FWD)
84     libxsmm_internal_matrix_sigmoid_ld( bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, i, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, i, in, ik, N, K) );
85 #endif
86 #if defined(LIBXSMM_DNN_RNN_TANH_FWD)
87     libxsmm_internal_matrix_tanh_ld(    bk, bn, K, &LIBXSMM_VLA_ACCESS(3, z, i, in, ik, N, K), &LIBXSMM_VLA_ACCESS(3, h, i, in, ik, N, K) );
88 #endif
89   }
90 
91   libxsmm_barrier_wait(handle->barrier, (int)ltid);
92 }
93