1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Hans Pabst (Intel Corp.)
10 ******************************************************************************/
11 #ifndef LIBXSMM_BLOCKED_GEMM_H
12 #define LIBXSMM_BLOCKED_GEMM_H
13 
14 #include "libxsmm_typedefs.h"
15 
16 
/**
 * Selects the BGEMM data order.
 * Each enumerator names a permutation of the letters I, J, and K. Presumably
 * this is the nesting order of the loops over the M-, N-, and K-blocks
 * (TODO confirm against the implementation). The numeric values are spelled
 * out explicitly and must stay unchanged.
 */
typedef enum libxsmm_blocked_gemm_order {
  /* K named last */
  LIBXSMM_BLOCKED_GEMM_ORDER_JIK = 0,
  LIBXSMM_BLOCKED_GEMM_ORDER_IJK = 1,
  /* K named in the middle */
  LIBXSMM_BLOCKED_GEMM_ORDER_JKI = 2,
  LIBXSMM_BLOCKED_GEMM_ORDER_IKJ = 3,
  /* K named first */
  LIBXSMM_BLOCKED_GEMM_ORDER_KJI = 4,
  LIBXSMM_BLOCKED_GEMM_ORDER_KIJ = 5
} libxsmm_blocked_gemm_order;
26 
27 /** Describes the Block-GEMM (BGEMM) operation. */
28 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_blocked_gemm_handle libxsmm_blocked_gemm_handle;
29 
30 
31 LIBXSMM_API libxsmm_blocked_gemm_handle* libxsmm_blocked_gemm_handle_create(
32   /** Number of threads used to run BGEMM. */
33   /*unsigned*/ int nthreads, libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec,
34   libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k,
35   /** If the block-size (BM, BN, or BK) is not given, a suitable value is chosen internally. */
36   const libxsmm_blasint* bm, const libxsmm_blasint* bn, const libxsmm_blasint* bk,
37   /** If b_m1, b_n1, b_k1, or b_k2 is not supplied, the respective value defaults to one. */
38   const libxsmm_blasint* b_m1, const libxsmm_blasint* b_n1, const libxsmm_blasint* b_k1, const libxsmm_blasint* b_k2,
39   /** If alpha is not supplied (NULL), then LIBXSMM_ALPHA is used instead. */ const void* alpha,
40   /** If beta is not supplied (NULL), then LIBXSMM_BETA is used instead. */   const void*  beta,
41   /** See libxsmm_gemm_flags (LIBXSMM_FLAGS is used if NULL is given). */ const int* gemm_flags,
42   /** See libxsmm_gemm_prefetch_type; a strategy chosen automatically if NULL is given. */
43   const libxsmm_gemm_prefetch_type* prefetch,
44   /** See libxsmm_blocked_gemm_order; an order is chosen automatically if NULL is given. */
45   const libxsmm_blocked_gemm_order* order);
46 
47 LIBXSMM_API void libxsmm_blocked_gemm_handle_destroy(const libxsmm_blocked_gemm_handle* handle);
48 
49 /** Copy-in functions for A, B, and C matrices. A leading dimension for the source buffer is optional and can be NULL. */
50 LIBXSMM_API int libxsmm_blocked_gemm_copyin_a(const libxsmm_blocked_gemm_handle* handle, const void* src, const libxsmm_blasint* ld, void* dst);
51 LIBXSMM_API int libxsmm_blocked_gemm_copyin_b(const libxsmm_blocked_gemm_handle* handle, const void* src, const libxsmm_blasint* ld, void* dst);
52 LIBXSMM_API int libxsmm_blocked_gemm_copyin_c(const libxsmm_blocked_gemm_handle* handle, const void* src, const libxsmm_blasint* ld, void* dst);
53 /** Copy-out function for the C-matrix. A leading dimension for the destination buffer is optional and can be NULL. */
54 LIBXSMM_API int libxsmm_blocked_gemm_copyout_c(const libxsmm_blocked_gemm_handle* handle, const void* src, const libxsmm_blasint* ld, void* dst);
55 
56 /** Convert function required to reorganize elements in delta for BWD and UPD passes of RNN, LSTM and GRU */
57 LIBXSMM_API int libxsmm_blocked_gemm_convert_b_to_a(const libxsmm_blocked_gemm_handle* handle, const void* src, const libxsmm_blasint* ld, void* dst);
58 /** Transpose matrix b for UPD pass of GRU */
59 LIBXSMM_API int libxsmm_blocked_gemm_transpose_b(const libxsmm_blocked_gemm_handle* handle, const void* src, const libxsmm_blasint* ld, void* dst);
60 
61 /**
62 * Fine grain parallelized block-GEMM (BGEMM), which uses a block structure
63 * layout for the A and B matrices. The implementation is parallelized
64 * among M, N, and K using fine-grained on-demand locks when writing C.
65 */
66 LIBXSMM_API void libxsmm_blocked_gemm_st(const libxsmm_blocked_gemm_handle* handle, const void* a, const void* b, void* c,
67   /*unsigned*/int start_thread, /*unsigned*/int tid);
68 
69 /**
70  * Implementation of libxsmm_blocked_gemm, which is parallelized with OpenMP
71  * and uses an OpenMP or custom barrier implementation. The function
72  * allows to run multiple GEMMs, which is specified by 'count' (RNNs).
73  * This function requires to link against libxsmmext.
74  */
75 LIBXSMM_APIEXT void libxsmm_blocked_gemm_omp(const libxsmm_blocked_gemm_handle* handle,
76   const void* a, const void* b, void* c, /*unsigned*/int count);
77 
78 #endif /*LIBXSMM_BLOCKED_GEMM_H*/
79 
80