1 /****************************************************************************** 2 * Copyright (c) Intel Corporation - All rights reserved. * 3 * This file is part of the LIBXSMM library. * 4 * * 5 * For information on the license, see the LICENSE file. * 6 * Further information: https://github.com/hfp/libxsmm/ * 7 * SPDX-License-Identifier: BSD-3-Clause * 8 ******************************************************************************/ 9 /* Hans Pabst (Intel Corp.) 10 ******************************************************************************/ 11 #ifndef LIBXSMM_BLOCKED_GEMM_H 12 #define LIBXSMM_BLOCKED_GEMM_H 13 14 #include "libxsmm_typedefs.h" 15 16 17 /** Denotes the BGEMM data order. */ 18 typedef enum libxsmm_blocked_gemm_order { 19 LIBXSMM_BLOCKED_GEMM_ORDER_JIK = 0, 20 LIBXSMM_BLOCKED_GEMM_ORDER_IJK = 1, 21 LIBXSMM_BLOCKED_GEMM_ORDER_JKI = 2, 22 LIBXSMM_BLOCKED_GEMM_ORDER_IKJ = 3, 23 LIBXSMM_BLOCKED_GEMM_ORDER_KJI = 4, 24 LIBXSMM_BLOCKED_GEMM_ORDER_KIJ = 5 25 } libxsmm_blocked_gemm_order; 26 27 /** Describes the Block-GEMM (BGEMM) operation. */ 28 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_blocked_gemm_handle libxsmm_blocked_gemm_handle; 29 30 31 LIBXSMM_API libxsmm_blocked_gemm_handle* libxsmm_blocked_gemm_handle_create( 32 /** Number of threads used to run BGEMM. */ 33 /*unsigned*/ int nthreads, libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, 34 libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, 35 /** If the block-size (BM, BN, or BK) is not given, a suitable value is chosen internally. */ 36 const libxsmm_blasint* bm, const libxsmm_blasint* bn, const libxsmm_blasint* bk, 37 /** If b_m1, b_n1, b_k1, or b_k2 is not supplied, the respective value defaults to one. */ 38 const libxsmm_blasint* b_m1, const libxsmm_blasint* b_n1, const libxsmm_blasint* b_k1, const libxsmm_blasint* b_k2, 39 /** If alpha is not supplied (NULL), then LIBXSMM_ALPHA is used instead. */ const void* alpha, 40 /** If beta is not supplied (NULL), then LIBXSMM_BETA is used instead. */ const void* beta, 41 /** See libxsmm_gemm_flags (LIBXSMM_FLAGS is used if NULL is given). */ const int* gemm_flags, 42 /** See libxsmm_gemm_prefetch_type; a strategy chosen automatically if NULL is given. */ 43 const libxsmm_gemm_prefetch_type* prefetch, 44 /** See libxsmm_blocked_gemm_order; an order is chosen automatically if NULL is given. */ 45 const libxsmm_blocked_gemm_order* order); 46 47 LIBXSMM_API void libxsmm_blocked_gemm_handle_destroy(const libxsmm_blocked_gemm_handle* handle); 48 49 /** Copy-in functions for A, B, and C matrices. A leading dimension for the source buffer is optional and can be NULL. */ 50 LIBXSMM_API int libxsmm_blocked_gemm_copyin_a(const libxsmm_blocked_gemm_handle* handle, const void* src, const libxsmm_blasint* ld, void* dst); 51 LIBXSMM_API int libxsmm_blocked_gemm_copyin_b(const libxsmm_blocked_gemm_handle* handle, const void* src, const libxsmm_blasint* ld, void* dst); 52 LIBXSMM_API int libxsmm_blocked_gemm_copyin_c(const libxsmm_blocked_gemm_handle* handle, const void* src, const libxsmm_blasint* ld, void* dst); 53 /** Copy-out function for the C-matrix. A leading dimension for the destination buffer is optional and can be NULL. */ 54 LIBXSMM_API int libxsmm_blocked_gemm_copyout_c(const libxsmm_blocked_gemm_handle* handle, const void* src, const libxsmm_blasint* ld, void* dst); 55 56 /** Convert function required to reorganize elements in delta for BWD and UPD passes of RNN, LSTM and GRU */ 57 LIBXSMM_API int libxsmm_blocked_gemm_convert_b_to_a(const libxsmm_blocked_gemm_handle* handle, const void* src, const libxsmm_blasint* ld, void* dst); 58 /** Transpose matrix b for UPD pass of GRU */ 59 LIBXSMM_API int libxsmm_blocked_gemm_transpose_b(const libxsmm_blocked_gemm_handle* handle, const void* src, const libxsmm_blasint* ld, void* dst); 60 61 /** 62 * Fine grain parallelized block-GEMM (BGEMM), which uses a block structure 63 * layout for the A and B matrices. The implementation is parallelized 64 * among M, N, and K using fine-grained on-demand locks when writing C. 65 */ 66 LIBXSMM_API void libxsmm_blocked_gemm_st(const libxsmm_blocked_gemm_handle* handle, const void* a, const void* b, void* c, 67 /*unsigned*/int start_thread, /*unsigned*/int tid); 68 69 /** 70 * Implementation of libxsmm_blocked_gemm, which is parallelized with OpenMP 71 * and uses an OpenMP or custom barrier implementation. The function 72 * allows to run multiple GEMMs, which is specified by 'count' (RNNs). 73 * This function requires to link against libxsmmext. 74 */ 75 LIBXSMM_APIEXT void libxsmm_blocked_gemm_omp(const libxsmm_blocked_gemm_handle* handle, 76 const void* a, const void* b, void* c, /*unsigned*/int count); 77 78 #endif /*LIBXSMM_BLOCKED_GEMM_H*/ 79 80