1 /****************************************************************************** 2 * Copyright (c) Intel Corporation - All rights reserved. * 3 * This file is part of the LIBXSMM library. * 4 * * 5 * For information on the license, see the LICENSE file. * 6 * Further information: https://github.com/hfp/libxsmm/ * 7 * SPDX-License-Identifier: BSD-3-Clause * 8 ******************************************************************************/ 9 /* Hans Pabst (Intel Corp.) 10 ******************************************************************************/ 11 #ifndef LIBXSMM_FRONTEND_H 12 #define LIBXSMM_FRONTEND_H 13 14 #include "libxsmm_typedefs.h" 15 16 /** Helper macros for eliding prefetch address calculations depending on prefetch scheme. */ 17 #if !defined(_WIN32) && !defined(__CYGWIN__) /* TODO: fully support calling convention */ 18 #if 0 != ((LIBXSMM_PREFETCH) & 2/*AL2*/) \ 19 || 0 != ((LIBXSMM_PREFETCH) & 8/*AL2_AHEAD*/) 20 # define LIBXSMM_GEMM_PREFETCH_A(EXPR) (EXPR) 21 #endif 22 #if 0 != ((LIBXSMM_PREFETCH) & 4/*BL2_VIA_C*/) \ 23 || 0 != ((LIBXSMM_PREFETCH) & 16/*BL1*/) 24 # define LIBXSMM_GEMM_PREFETCH_B(EXPR) (EXPR) 25 #endif 26 #endif 27 /** Secondary helper macros derived from the above group. */ 28 #if defined(LIBXSMM_GEMM_PREFETCH_A) 29 # define LIBXSMM_NOPREFETCH_A(EXPR) 30 #else 31 # define LIBXSMM_NOPREFETCH_A(EXPR) EXPR 32 # define LIBXSMM_GEMM_PREFETCH_A(EXPR) 0 33 #endif 34 #if defined(LIBXSMM_GEMM_PREFETCH_B) 35 # define LIBXSMM_NOPREFETCH_B(EXPR) 36 #else 37 # define LIBXSMM_NOPREFETCH_B(EXPR) EXPR 38 # define LIBXSMM_GEMM_PREFETCH_B(EXPR) 0 39 #endif 40 #if defined(LIBXSMM_GEMM_PREFETCH_C) 41 # define LIBXSMM_NOPREFETCH_C(EXPR) 42 #else 43 # define LIBXSMM_NOPREFETCH_C(EXPR) EXPR 44 # define LIBXSMM_GEMM_PREFETCH_C(EXPR) 0 45 #endif 46 47 /** MKL_DIRECT_CALL requires to include the MKL interface. */ 48 #if (defined(MKL_DIRECT_CALL_SEQ) || defined(MKL_DIRECT_CALL)) 49 # if (0 != LIBXSMM_ILP64 && !defined(MKL_ILP64)) 50 # error "Inconsistent ILP64 configuration detected!" 51 # endif 52 # if defined(LIBXSMM_OFFLOAD_BUILD) 53 # pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET)) 54 # include <mkl.h> 55 # pragma offload_attribute(pop) 56 # else 57 # include <mkl.h> 58 # endif 59 #endif 60 61 /** Automatically select a prefetch-strategy (libxsmm_get_gemm_xprefetch, etc.). */ 62 #define LIBXSMM_PREFETCH_AUTO -1 63 64 /** Append "_omp" postfix to the given symbol. */ 65 #define LIBXSMM_USEOMP(FUNCTION) LIBXSMM_CONCATENATE(FUNCTION, _omp) 66 67 /** Helper macro for BLAS-style prefixes. */ 68 #define LIBXSMM_TPREFIX_NAME(TYPE) LIBXSMM_CONCATENATE(LIBXSMM_TPREFIX_, TYPE) 69 #define LIBXSMM_TPREFIX(TYPE, FUNCTION) LIBXSMM_CONCATENATE(LIBXSMM_TPREFIX_NAME(TYPE), FUNCTION) 70 #define LIBXSMM_TPREFIX_doubledouble d 71 #define LIBXSMM_TPREFIX_floatfloat s 72 #define LIBXSMM_TPREFIX_shortfloat ws 73 #define LIBXSMM_TPREFIX_shortint wi 74 #define LIBXSMM_TPREFIX_libxsmm_bfloat16float bs 75 /** Defaults if only the input type is specified. */ 76 #define LIBXSMM_TPREFIX_double LIBXSMM_TPREFIX_doubledouble 77 #define LIBXSMM_TPREFIX_float LIBXSMM_TPREFIX_floatfloat 78 #define LIBXSMM_TPREFIX_short LIBXSMM_TPREFIX_shortint 79 80 /** Construct symbol name from a given real type name (float, double and short). */ 81 #define LIBXSMM_BLAS_FNTYPE(TYPE, KIND) LIBXSMM_CONCATENATE3(libxsmm_, LIBXSMM_TPREFIX(TYPE, KIND), _function) 82 #define LIBXSMM_MMFUNCTION_TYPE(TYPE) LIBXSMM_CONCATENATE(libxsmm_, LIBXSMM_TPREFIX(TYPE, mmfunction)) 83 #define LIBXSMM_MMDISPATCH_SYMBOL(TYPE) LIBXSMM_CONCATENATE(libxsmm_, LIBXSMM_TPREFIX(TYPE, mmdispatch)) 84 #define LIBXSMM_XBLAS_SYMBOL(TYPE) LIBXSMM_CONCATENATE(libxsmm_blas_, LIBXSMM_TPREFIX(TYPE, gemm)) 85 #define LIBXSMM_XGEMM_SYMBOL(TYPE) LIBXSMM_CONCATENATE(libxsmm_, LIBXSMM_TPREFIX(TYPE, gemm)) 86 #define LIBXSMM_YGEMM_SYMBOL(TYPE) LIBXSMM_USEOMP(LIBXSMM_XGEMM_SYMBOL(TYPE)) 87 #define LIBXSMM_BLAS_SYMBOL(TYPE, KIND) LIBXSMM_FSYMBOL(LIBXSMM_TPREFIX(TYPE, KIND)) 88 #define LIBXSMM_CBLAS_SYMBOL LIBXSMM_TPREFIX 89 90 #define LIBXSMM_BLAS_DECL(TYPE, KIND, DECL) LIBXSMM_CONCATENATE(LIBXSMM_BLAS_, LIBXSMM_TPREFIX(TYPE, KIND))(DECL) 91 #if !defined(MKL_DIRECT_CALL_SEQ) && !defined(MKL_DIRECT_CALL) 92 # define LIBXSMM_BLAS_dgemm(DECL) DECL; 93 # define LIBXSMM_BLAS_sgemm(DECL) DECL; 94 # define LIBXSMM_BLAS_dgemv(DECL) DECL; 95 # define LIBXSMM_BLAS_sgemv(DECL) DECL; 96 #else 97 # define LIBXSMM_BLAS_dgemm 98 # define LIBXSMM_BLAS_sgemm 99 # define LIBXSMM_BLAS_dgemv 100 # define LIBXSMM_BLAS_sgemv 101 #endif 102 103 /* Construct prefix names, function type or dispatch function from given input and output types. */ 104 #define LIBXSMM_MMFUNCTION_TYPE2(ITYPE, OTYPE) LIBXSMM_MMFUNCTION_TYPE(LIBXSMM_CONCATENATE(ITYPE, OTYPE)) 105 #define LIBXSMM_MMDISPATCH_SYMBOL2(ITYPE, OTYPE) LIBXSMM_MMDISPATCH_SYMBOL(LIBXSMM_CONCATENATE(ITYPE, OTYPE)) 106 #define LIBXSMM_TPREFIX_NAME2(ITYPE, OTYPE) LIBXSMM_TPREFIX_NAME(LIBXSMM_CONCATENATE(ITYPE, OTYPE)) 107 #define LIBXSMM_TPREFIX2(ITYPE, OTYPE, FUNCTION) LIBXSMM_TPREFIX(LIBXSMM_CONCATENATE(ITYPE, OTYPE), FUNCTION) 108 109 /** Helper macro for comparing selected types. */ 110 #define LIBXSMM_EQUAL(T1, T2) LIBXSMM_CONCATENATE3(LIBXSMM_EQUAL_, T1, T2) 111 #define LIBXSMM_EQUAL_floatfloat 1 112 #define LIBXSMM_EQUAL_doubledouble 1 113 #define LIBXSMM_EQUAL_floatdouble 0 114 #define LIBXSMM_EQUAL_doublefloat 0 115 #define LIBXSMM_EQUAL_shortdouble 0 116 #define LIBXSMM_EQUAL_shortfloat 0 117 118 #if defined(LIBXSMM_BLAS_CONST) 119 # undef LIBXSMM_BLAS_CONST 120 # define LIBXSMM_BLAS_CONST const 121 #elif defined(OPENBLAS_CONST) 122 # define LIBXSMM_BLAS_CONST OPENBLAS_CONST 123 #elif defined(LIBXSMM_BLAS_NONCONST) || defined(__OPENBLAS) || defined(__OPENBLAS77) 124 # define LIBXSMM_BLAS_CONST 125 #else 126 # define LIBXSMM_BLAS_CONST const 127 #endif 128 129 #if !defined(LIBXSMM_NO_BLAS) 130 # if (!defined(__BLAS) || (0 != __BLAS)) 131 # define LIBXSMM_NO_BLAS 0 132 # define LIBXSMM_BLAS 1 133 # else 134 # define LIBXSMM_NO_BLAS 1 135 # define LIBXSMM_BLAS 0 136 # endif 137 #endif 138 139 #if defined(__BLAS) && (1 == __BLAS) 140 # if defined(__OPENBLAS) 141 LIBXSMM_EXTERN void openblas_set_num_threads(int num_threads); 142 # define LIBXSMM_BLAS_INIT openblas_set_num_threads(1); 143 # endif 144 #endif 145 #if !defined(LIBXSMM_BLAS_INIT) 146 # define LIBXSMM_BLAS_INIT 147 #endif 148 149 #if defined(LIBXSMM_BUILD) 150 # if defined(LIBXSMM_BUILD_EXT) && !defined(__STATIC) 151 # define LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_APIEXT 152 # elif defined(LIBXSMM_NO_BLAS) && (1 == LIBXSMM_NO_BLAS) 153 # define LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_API 154 # endif 155 #endif 156 #if !defined(LIBXSMM_BLAS_SYMBOL_VISIBILITY) 157 # define LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_EXTERN LIBXSMM_VISIBILITY_IMPORT LIBXSMM_RETARGETABLE 158 #endif 159 160 #define LIBXSMM_BLAS_SYMBOL_SIGNATURE_gemm_batch(CONST_STAR, STAR, TYPE) char CONST_STAR, char CONST_STAR, \ 161 libxsmm_blasint CONST_STAR, libxsmm_blasint CONST_STAR, libxsmm_blasint CONST_STAR, \ 162 TYPE CONST_STAR, TYPE CONST_STAR STAR, libxsmm_blasint CONST_STAR, TYPE CONST_STAR STAR, libxsmm_blasint CONST_STAR, \ 163 TYPE CONST_STAR, TYPE STAR STAR, libxsmm_blasint CONST_STAR, libxsmm_blasint CONST_STAR, libxsmm_blasint CONST_STAR 164 #define LIBXSMM_BLAS_SYMBOL_SIGNATURE_gemm(CONST_STAR, STAR, TYPE) char CONST_STAR, char CONST_STAR, \ 165 libxsmm_blasint CONST_STAR, libxsmm_blasint CONST_STAR, libxsmm_blasint CONST_STAR, TYPE CONST_STAR, TYPE CONST_STAR, libxsmm_blasint CONST_STAR, \ 166 TYPE CONST_STAR, libxsmm_blasint CONST_STAR, TYPE CONST_STAR, TYPE STAR, libxsmm_blasint CONST_STAR 167 #define LIBXSMM_BLAS_SYMBOL_SIGNATURE_gemv(CONST_STAR, STAR, TYPE) char CONST_STAR, libxsmm_blasint CONST_STAR, libxsmm_blasint CONST_STAR, \ 168 TYPE CONST_STAR, TYPE CONST_STAR, libxsmm_blasint CONST_STAR, TYPE CONST_STAR, libxsmm_blasint CONST_STAR, \ 169 TYPE CONST_STAR, TYPE STAR, libxsmm_blasint CONST_STAR 170 #define LIBXSMM_BLAS_SYMBOL_SIGNATURE(CONST_STAR, STAR, TYPE, KIND) LIBXSMM_CONCATENATE(LIBXSMM_BLAS_SYMBOL_SIGNATURE_, KIND)(CONST_STAR, STAR, TYPE) 171 #define LIBXSMM_BLAS_SYMBOL_FDECL(CONST_STAR, STAR, TYPE, KIND) LIBXSMM_BLAS_SYMBOL_VISIBILITY \ 172 void LIBXSMM_BLAS_SYMBOL(TYPE, KIND)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(CONST_STAR, STAR, TYPE, KIND)) 173 #define LIBXSMM_BLAS_SYMBOL_CDECL(CONST_STAR, STAR, TYPE, KIND) LIBXSMM_BLAS_SYMBOL_VISIBILITY \ 174 void LIBXSMM_CBLAS_SYMBOL(TYPE, KIND)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(CONST_STAR, STAR, TYPE, KIND)) 175 176 #if (0 != LIBXSMM_BLAS) /* BLAS available */ 177 # define LIBXSMM_BLAS_SYMBOL_DECL(TYPE, KIND) LIBXSMM_BLAS_DECL(TYPE, KIND, LIBXSMM_BLAS_SYMBOL_FDECL(LIBXSMM_BLAS_CONST*, *, TYPE, KIND)) 178 #else 179 # define LIBXSMM_BLAS_SYMBOL_DECL(TYPE, KIND) 180 #endif 181 182 /** Helper macro consolidating the transpose requests into a set of flags. */ 183 #define LIBXSMM_GEMM_FLAGS(TRANSA, TRANSB) /* check for N/n rather than T/t since C/c is also valid! */ \ 184 ((('n' == (TRANSA) || *"N" == (TRANSA)) ? LIBXSMM_GEMM_FLAG_NONE : LIBXSMM_GEMM_FLAG_TRANS_A) \ 185 | (('n' == (TRANSB) || *"N" == (TRANSB)) ? LIBXSMM_GEMM_FLAG_NONE : LIBXSMM_GEMM_FLAG_TRANS_B)) 186 187 /** Helper macro consolidating CBLAS transpose requests into a set of flags. */ 188 #define LIBXSMM_GEMM_CFLAGS(TRANSA, TRANSB) /* check for N/n rather than T/t since C/c is also valid! */ \ 189 ((CblasNoTrans == (TRANSA) ? LIBXSMM_GEMM_FLAG_NONE : LIBXSMM_GEMM_FLAG_TRANS_A) \ 190 | (CblasNoTrans == (TRANSB) ? LIBXSMM_GEMM_FLAG_NONE : LIBXSMM_GEMM_FLAG_TRANS_B)) 191 192 /** Helper macro consolidating the transpose requests into a set of flags. */ 193 #define LIBXSMM_GEMM_VNNI_FLAGS(TRANSA, TRANSB, VNNIA, VNNIB) /* check for N/n rather than T/t since C/c is also valid! */ \ 194 ((('n' == (TRANSA) || *"N" == (TRANSA)) ? LIBXSMM_GEMM_FLAG_NONE : LIBXSMM_GEMM_FLAG_TRANS_A) \ 195 | (('n' == (TRANSB) || *"N" == (TRANSB)) ? LIBXSMM_GEMM_FLAG_NONE : LIBXSMM_GEMM_FLAG_TRANS_B) \ 196 | (('n' == (VNNIA) || *"N" == (VNNIA)) ? LIBXSMM_GEMM_FLAG_NONE : LIBXSMM_GEMM_FLAG_VNNI_A) \ 197 | (('n' == (VNNIB) || *"N" == (VNNIB)) ? LIBXSMM_GEMM_FLAG_NONE : LIBXSMM_GEMM_FLAG_VNNI_B)) 198 199 /** Helper macro allowing NULL-requests (transposes) supplied by some default. */ 200 #define LIBXSMM_GEMM_PFLAGS(TRANSA, TRANSB, DEFAULT) LIBXSMM_GEMM_FLAGS( \ 201 NULL != ((const void*)(TRANSA)) ? (*(const char*)(TRANSA)) : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & (DEFAULT)) ? 'n' : 't'), \ 202 NULL != ((const void*)(TRANSB)) ? (*(const char*)(TRANSB)) : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & (DEFAULT)) ? 'n' : 't')) \ 203 | (~(LIBXSMM_GEMM_FLAG_TRANS_A | LIBXSMM_GEMM_FLAG_TRANS_B) & (DEFAULT)) 204 205 /** Inlinable GEMM exercising the compiler's code generation (macro template). TODO: only NN is supported and SP/DP matrices. */ 206 #define LIBXSMM_INLINE_XGEMM(ITYPE, OTYPE, TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) { \ 207 /* Use 'n' (instead of 'N') avoids warning about "no macro replacement within a character constant". */ \ 208 const char libxsmm_inline_xgemm_transa_ = (char)(NULL != ((void*)(TRANSA)) ? (*(const char*)(TRANSA)) : \ 209 (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & LIBXSMM_FLAGS) ? 'n' : 't')); \ 210 const char libxsmm_inline_xgemm_transb_ = (char)(NULL != ((void*)(TRANSB)) ? (*(const char*)(TRANSB)) : \ 211 (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & LIBXSMM_FLAGS) ? 'n' : 't')); \ 212 const libxsmm_blasint libxsmm_inline_xgemm_m_ = *(const libxsmm_blasint*)(M); /* must be specified */ \ 213 const libxsmm_blasint libxsmm_inline_xgemm_k_ = (NULL != ((void*)(K)) ? (*(const libxsmm_blasint*)(K)) : libxsmm_inline_xgemm_m_); \ 214 const libxsmm_blasint libxsmm_inline_xgemm_n_ = (NULL != ((void*)(N)) ? (*(const libxsmm_blasint*)(N)) : libxsmm_inline_xgemm_k_); \ 215 const libxsmm_blasint libxsmm_inline_xgemm_lda_ = (NULL != ((void*)(LDA)) ? (*(const libxsmm_blasint*)(LDA)) : \ 216 (('n' == libxsmm_inline_xgemm_transa_ || *"N" == libxsmm_inline_xgemm_transa_) ? libxsmm_inline_xgemm_m_ : libxsmm_inline_xgemm_k_)); \ 217 const libxsmm_blasint libxsmm_inline_xgemm_ldb_ = (NULL != ((void*)(LDB)) ? (*(const libxsmm_blasint*)(LDB)) : \ 218 (('n' == libxsmm_inline_xgemm_transb_ || *"N" == libxsmm_inline_xgemm_transb_) ? libxsmm_inline_xgemm_k_ : libxsmm_inline_xgemm_n_)); \ 219 const libxsmm_blasint libxsmm_inline_xgemm_ldc_ = (NULL != ((void*)(LDC)) ? (*(const libxsmm_blasint*)(LDC)) : libxsmm_inline_xgemm_m_); \ 220 const OTYPE libxsmm_inline_xgemm_alpha_ = (NULL != ((void*)(ALPHA)) ? (*(const OTYPE*)(ALPHA)) : ((OTYPE)LIBXSMM_ALPHA)); \ 221 const OTYPE libxsmm_inline_xgemm_beta_ = (NULL != ((void*)(BETA)) ? (*(const OTYPE*)(BETA)) : ((OTYPE)LIBXSMM_BETA)); \ 222 libxsmm_blasint libxsmm_inline_xgemm_ni_, libxsmm_inline_xgemm_mi_ = 0, libxsmm_inline_xgemm_ki_; /* loop induction variables */ \ 223 LIBXSMM_ASSERT('n' == libxsmm_inline_xgemm_transa_ || *"N" == libxsmm_inline_xgemm_transa_); \ 224 LIBXSMM_ASSERT('n' == libxsmm_inline_xgemm_transb_ || *"N" == libxsmm_inline_xgemm_transb_); \ 225 LIBXSMM_PRAGMA_SIMD \ 226 for (libxsmm_inline_xgemm_mi_ = 0; libxsmm_inline_xgemm_mi_ < libxsmm_inline_xgemm_m_; ++libxsmm_inline_xgemm_mi_) { \ 227 LIBXSMM_PRAGMA_LOOP_COUNT(1, LIBXSMM_CONFIG_MAX_DIM, LIBXSMM_CONFIG_AVG_DIM) \ 228 for (libxsmm_inline_xgemm_ki_ = 0; libxsmm_inline_xgemm_ki_ < libxsmm_inline_xgemm_k_; ++libxsmm_inline_xgemm_ki_) { \ 229 LIBXSMM_PRAGMA_UNROLL \ 230 for (libxsmm_inline_xgemm_ni_ = 0; libxsmm_inline_xgemm_ni_ < libxsmm_inline_xgemm_n_; ++libxsmm_inline_xgemm_ni_) { \ 231 ((OTYPE*)(C))[libxsmm_inline_xgemm_ni_*libxsmm_inline_xgemm_ldc_+libxsmm_inline_xgemm_mi_] \ 232 = ((const ITYPE*)(B))[libxsmm_inline_xgemm_ni_*libxsmm_inline_xgemm_ldb_+libxsmm_inline_xgemm_ki_] * \ 233 (((const ITYPE*)(A))[libxsmm_inline_xgemm_ki_*libxsmm_inline_xgemm_lda_+libxsmm_inline_xgemm_mi_] * libxsmm_inline_xgemm_alpha_) \ 234 + ((const OTYPE*)(C))[libxsmm_inline_xgemm_ni_*libxsmm_inline_xgemm_ldc_+libxsmm_inline_xgemm_mi_] * libxsmm_inline_xgemm_beta_; \ 235 } \ 236 } \ 237 } \ 238 } 239 240 #if (defined(LIBXSMM_INIT) || defined(LIBXSMM_CTOR)) 241 # undef LIBXSMM_INIT 242 # define LIBXSMM_INIT LIBXSMM_ASSERT_MSG(1 < libxsmm_ninit, "LIBXSMM is not initialized"); 243 # define LIBXSMM_INIT_COMPLETED 244 #else 245 # define LIBXSMM_INIT if (2 > libxsmm_ninit) libxsmm_init(); 246 #endif 247 248 /** Map to appropriate BLAS function (or fall-back). The mapping is used, e.g., inside of LIBXSMM_BLAS_XGEMM. */ 249 #define LIBXSMM_BLAS_FUNCTION(ITYPE, OTYPE, FUNCTION) LIBXSMM_CONCATENATE(LIBXSMM_BLAS_FUNCTION_, LIBXSMM_TPREFIX2(ITYPE, OTYPE, FUNCTION)) 250 #if (0 != LIBXSMM_BLAS) /* Helper macro to eventually (if defined) call libxsmm_init */ 251 # if defined(LIBXSMM_INIT_COMPLETED) 252 # define LIBXSMM_BLAS_FUNCTION_dgemm_batch libxsmm_original_dgemm_batch_function 253 # define LIBXSMM_BLAS_FUNCTION_sgemm_batch libxsmm_original_sgemm_batch_function 254 # define LIBXSMM_BLAS_FUNCTION_dgemm libxsmm_original_dgemm_function 255 # define LIBXSMM_BLAS_FUNCTION_sgemm libxsmm_original_sgemm_function 256 # define LIBXSMM_BLAS_FUNCTION_dgemv libxsmm_original_dgemv_function 257 # define LIBXSMM_BLAS_FUNCTION_sgemv libxsmm_original_sgemv_function 258 # else 259 # define LIBXSMM_BLAS_FUNCTION_dgemm_batch libxsmm_original_dgemm_batch() 260 # define LIBXSMM_BLAS_FUNCTION_sgemm_batch libxsmm_original_sgemm_batch() 261 # define LIBXSMM_BLAS_FUNCTION_dgemm libxsmm_original_dgemm() 262 # define LIBXSMM_BLAS_FUNCTION_sgemm libxsmm_original_sgemm() 263 # define LIBXSMM_BLAS_FUNCTION_dgemv libxsmm_original_dgemv() 264 # define LIBXSMM_BLAS_FUNCTION_sgemv libxsmm_original_sgemv() 265 # endif 266 #else /* no BLAS */ 267 # define LIBXSMM_BLAS_FUNCTION_dgemm_batch libxsmm_blas_error("dgemm_batch") 268 # define LIBXSMM_BLAS_FUNCTION_sgemm_batch libxsmm_blas_error("sgemm_batch") 269 # define LIBXSMM_BLAS_FUNCTION_dgemm libxsmm_blas_error("dgemm") 270 # define LIBXSMM_BLAS_FUNCTION_sgemm libxsmm_blas_error("sgemm") 271 # define LIBXSMM_BLAS_FUNCTION_dgemv libxsmm_blas_error("dgemv") 272 # define LIBXSMM_BLAS_FUNCTION_sgemv libxsmm_blas_error("sgemv") 273 #endif 274 /** Low-precision (BLAS-like) function symbols. */ 275 #define LIBXSMM_BLAS_FUNCTION_wigemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \ 276 LIBXSMM_INLINE_XGEMM(short, int, TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) 277 #define LIBXSMM_BLAS_FUNCTION_bsgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \ 278 LIBXSMM_INLINE_XGEMM(libxsmm_bfloat16, float, TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) 279 280 /** Short-cut macros to construct desired BLAS function symbol. */ 281 #define LIBXSMM_BLAS_FUNCTION1(TYPE, FUNCTION) LIBXSMM_BLAS_FUNCTION(TYPE, TYPE, FUNCTION) 282 #define LIBXSMM_GEMM_BATCH_SYMBOL(TYPE) LIBXSMM_BLAS_FUNCTION1(TYPE, gemm_batch) 283 #define LIBXSMM_GEMM_SYMBOL(TYPE) LIBXSMM_BLAS_FUNCTION1(TYPE, gemm) 284 #define LIBXSMM_GEMV_SYMBOL(TYPE) LIBXSMM_BLAS_FUNCTION1(TYPE, gemv) 285 286 /** BLAS-based GEMM supplied by the linked LAPACK/BLAS library (macro template). */ 287 #define LIBXSMM_BLAS_XGEMM(ITYPE, OTYPE, TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) { \ 288 /* Use 'n' (instead of 'N') avoids warning about "no macro replacement within a character constant". */ \ 289 const char libxsmm_blas_xgemm_transa_ = (char)(NULL != ((void*)(TRANSA)) ? (*(const char*)(TRANSA)) : \ 290 (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & LIBXSMM_FLAGS) ? 'n' : 't')); \ 291 const char libxsmm_blas_xgemm_transb_ = (char)(NULL != ((void*)(TRANSB)) ? (*(const char*)(TRANSB)) : \ 292 (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & LIBXSMM_FLAGS) ? 'n' : 't')); \ 293 const libxsmm_blasint *const libxsmm_blas_xgemm_k_ = (NULL != ((void*)(K)) ? (K) : (M)); \ 294 const libxsmm_blasint *const libxsmm_blas_xgemm_n_ = (NULL != ((void*)(N)) ? (N) : libxsmm_blas_xgemm_k_); \ 295 const libxsmm_blasint libxsmm_blas_xgemm_lda_ = LIBXSMM_MAX(NULL != ((void*)(LDA)) ? *(LDA) : \ 296 *(('n' == libxsmm_blas_xgemm_transa_ || *"N" == libxsmm_blas_xgemm_transa_) ? (M) : libxsmm_blas_xgemm_k_), 1); \ 297 const libxsmm_blasint libxsmm_blas_xgemm_ldb_ = LIBXSMM_MAX(NULL != ((void*)(LDB)) ? *(LDB) : \ 298 *(('n' == libxsmm_blas_xgemm_transb_ || *"N" == libxsmm_blas_xgemm_transb_) ? libxsmm_blas_xgemm_k_ : libxsmm_blas_xgemm_n_), 1); \ 299 const libxsmm_blasint libxsmm_blas_xgemm_ldc_ = LIBXSMM_MAX(NULL != ((void*)(LDC)) ? *(LDC) : *(M), 1); \ 300 const OTYPE libxsmm_blas_xgemm_alpha_ = (NULL != ((void*)(ALPHA)) ? (*(const OTYPE*)(ALPHA)) : ((OTYPE)LIBXSMM_ALPHA)); \ 301 const OTYPE libxsmm_blas_xgemm_beta_ = (NULL != ((void*)(BETA)) ? (*(const OTYPE*)(BETA)) : ((OTYPE)LIBXSMM_BETA)); \ 302 LIBXSMM_BLAS_FUNCTION(ITYPE, OTYPE, gemm)(&libxsmm_blas_xgemm_transa_, &libxsmm_blas_xgemm_transb_, \ 303 M, libxsmm_blas_xgemm_n_, libxsmm_blas_xgemm_k_, \ 304 &libxsmm_blas_xgemm_alpha_, (const ITYPE*)(A), &libxsmm_blas_xgemm_lda_, \ 305 (const ITYPE*)(B), &libxsmm_blas_xgemm_ldb_, \ 306 &libxsmm_blas_xgemm_beta_, (ITYPE*)(C), &libxsmm_blas_xgemm_ldc_); \ 307 } 308 309 /** Helper macros for calling a dispatched function in a row/column-major aware fashion. */ 310 #define LIBXSMM_MMCALL_ABC(FN, A, B, C) \ 311 LIBXSMM_ASSERT(FN); FN(A, B, C) 312 #define LIBXSMM_MMCALL_PRF(FN, A, B, C, PA, PB, PC) { \ 313 LIBXSMM_NOPREFETCH_A(LIBXSMM_UNUSED(PA)); \ 314 LIBXSMM_NOPREFETCH_B(LIBXSMM_UNUSED(PB)); \ 315 LIBXSMM_NOPREFETCH_C(LIBXSMM_UNUSED(PC)); \ 316 LIBXSMM_ASSERT(FN); FN(A, B, C, \ 317 LIBXSMM_GEMM_PREFETCH_A(PA), \ 318 LIBXSMM_GEMM_PREFETCH_B(PB), \ 319 LIBXSMM_GEMM_PREFETCH_C(PC)); \ 320 } 321 322 #if (0/*LIBXSMM_GEMM_PREFETCH_NONE*/ == LIBXSMM_PREFETCH) 323 # define LIBXSMM_MMCALL_LDX(FN, A, B, C, M, N, K, LDA, LDB, LDC) \ 324 LIBXSMM_MMCALL_ABC(FN, A, B, C) 325 #else 326 # define LIBXSMM_MMCALL_LDX(FN, A, B, C, M, N, K, LDA, LDB, LDC) \ 327 LIBXSMM_MMCALL_PRF(FN, A, B, C, (A) + ((size_t)LDA) * (K), (B) + ((size_t)LDB) * (N), (C) + ((size_t)LDC) * (N)) 328 #endif 329 #define LIBXSMM_MMCALL(FN, A, B, C, M, N, K) LIBXSMM_MMCALL_LDX(FN, A, B, C, M, N, K, M, K, M) 330 331 /** Calculate problem size from M, N, and K using the correct integer type in order to cover the general case. */ 332 #define LIBXSMM_MNK_SIZE(M, N, K) (((size_t)(M)) * ((size_t)(N)) * ((size_t)(K))) 333 /** Calculate total number of matrix-elements; matrices A, B, C are given per M, N, K, and emphasize (S) the C-size. */ 334 #define LIBXSMM_SIZE(M, N, K, S) \ 335 (((size_t)(M) * (size_t)(K)) + ((size_t)(K) * (size_t)(N)) + \ 336 (((size_t)(S) * (size_t)(M) * (size_t)(N)))) 337 /** Condition based on arithmetic intensity (AI) */ 338 #define LIBXSMM_SMM_AI(M, N, K, S, TYPESIZE) \ 339 ((LIBXSMM_MNK_SIZE(M, N, K) * 2) <= ((size_t)(TYPESIZE) * 4/*AI*/ * LIBXSMM_SIZE(M, N, K, S))) 340 /** Determine whether an SMM is suitable, i.e., small enough. */ 341 #if !defined(LIBXSMM_THRESHOLD_AI) /* traditional MNK-threshold */ 342 # define LIBXSMM_SMM(M, N, K, S, TYPESIZE) (LIBXSMM_MNK_SIZE(M, N, K) <= (LIBXSMM_MAX_MNK)) 343 #else /* threshold based on arithmetic intensity */ 344 # define LIBXSMM_SMM LIBXSMM_SMM_AI 345 #endif 346 347 /** Fall-back code paths: LIBXSMM_XGEMM_FALLBACK0, and LIBXSMM_XGEMM_FALLBACK1 (macro template). */ 348 #if !defined(LIBXSMM_XGEMM_FALLBACK0) 349 # define LIBXSMM_XGEMM_FALLBACK0(ITYPE, OTYPE, TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \ 350 LIBXSMM_BLAS_FUNCTION(ITYPE, OTYPE, gemm)(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) 351 #endif 352 #if !defined(LIBXSMM_XGEMM_FALLBACK1) 353 # define LIBXSMM_XGEMM_FALLBACK1(ITYPE, OTYPE, TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \ 354 LIBXSMM_BLAS_FUNCTION(ITYPE, OTYPE, gemm)(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) 355 #endif 356 357 /** 358 * Execute a specialized function, or use a fall-back code path depending on threshold (macro template). 359 * LIBXSMM_XGEMM_FALLBACK0 or specialized function: below LIBXSMM_MAX_MNK 360 * LIBXSMM_XGEMM_FALLBACK1: above LIBXSMM_MAX_MNK 361 */ 362 #define LIBXSMM_XGEMM(ITYPE, OTYPE, TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) { \ 363 const int libxsmm_xgemm_flags_ = LIBXSMM_GEMM_PFLAGS(TRANSA, TRANSB, LIBXSMM_FLAGS); \ 364 const libxsmm_blasint *const libxsmm_xgemm_k_ = (NULL != (K) ? (K) : (M)); \ 365 const libxsmm_blasint *const libxsmm_xgemm_n_ = (NULL != (N) ? (N) : libxsmm_xgemm_k_); \ 366 const libxsmm_blasint libxsmm_xgemm_lda_ = LIBXSMM_MAX(NULL != ((void*)(LDA)) ? *(LDA) : \ 367 *(0 == (LIBXSMM_GEMM_FLAG_TRANS_A & libxsmm_xgemm_flags_) ? (M) : libxsmm_xgemm_k_), 1); \ 368 const libxsmm_blasint libxsmm_xgemm_ldb_ = LIBXSMM_MAX(NULL != ((void*)(LDB)) ? *(LDB) : \ 369 *(0 == (LIBXSMM_GEMM_FLAG_TRANS_B & libxsmm_xgemm_flags_) ? libxsmm_xgemm_k_ : libxsmm_xgemm_n_), 1); \ 370 const libxsmm_blasint libxsmm_xgemm_ldc_ = LIBXSMM_MAX(NULL != (LDC) ? *(LDC) : *(M), 1); \ 371 if (LIBXSMM_SMM(*(M), *libxsmm_xgemm_n_, *libxsmm_xgemm_k_, 2/*RFO*/, sizeof(OTYPE))) { \ 372 const LIBXSMM_MMFUNCTION_TYPE2(ITYPE, OTYPE) libxsmm_mmfunction_ = LIBXSMM_MMDISPATCH_SYMBOL2(ITYPE, OTYPE)( \ 373 *(M), *libxsmm_xgemm_n_, *libxsmm_xgemm_k_, &libxsmm_xgemm_lda_, &libxsmm_xgemm_ldb_, &libxsmm_xgemm_ldc_, \ 374 (const OTYPE*)(ALPHA), (const OTYPE*)(BETA), &libxsmm_xgemm_flags_, NULL); \ 375 if (NULL != libxsmm_mmfunction_) { \ 376 LIBXSMM_MMCALL_LDX(libxsmm_mmfunction_, (const ITYPE*)(A), (const ITYPE*)(B), (OTYPE*)(C), \ 377 *(M), *libxsmm_xgemm_n_, *libxsmm_xgemm_k_, libxsmm_xgemm_lda_, libxsmm_xgemm_ldb_, libxsmm_xgemm_ldc_); \ 378 } \ 379 else { \ 380 const char libxsmm_xgemm_transa_ = (char)(0 == (LIBXSMM_GEMM_FLAG_TRANS_A & libxsmm_xgemm_flags_) ? 'n' : 't'); \ 381 const char libxsmm_xgemm_transb_ = (char)(0 == (LIBXSMM_GEMM_FLAG_TRANS_B & libxsmm_xgemm_flags_) ? 'n' : 't'); \ 382 const OTYPE libxsmm_xgemm_alpha_ = (NULL != ((void*)(ALPHA)) ? (*(const OTYPE*)(ALPHA)) : ((OTYPE)LIBXSMM_ALPHA)); \ 383 const OTYPE libxsmm_xgemm_beta_ = (NULL != ((void*)(BETA)) ? (*(const OTYPE*)(BETA)) : ((OTYPE)LIBXSMM_BETA)); \ 384 LIBXSMM_XGEMM_FALLBACK0(ITYPE, OTYPE, &libxsmm_xgemm_transa_, &libxsmm_xgemm_transb_, \ 385 M, libxsmm_xgemm_n_, libxsmm_xgemm_k_, \ 386 &libxsmm_xgemm_alpha_, A, &libxsmm_xgemm_lda_, \ 387 B, &libxsmm_xgemm_ldb_, \ 388 &libxsmm_xgemm_beta_, C, &libxsmm_xgemm_ldc_); \ 389 } \ 390 } \ 391 else { \ 392 const char libxsmm_xgemm_transa_ = (char)(0 == (LIBXSMM_GEMM_FLAG_TRANS_A & libxsmm_xgemm_flags_) ? 'n' : 't'); \ 393 const char libxsmm_xgemm_transb_ = (char)(0 == (LIBXSMM_GEMM_FLAG_TRANS_B & libxsmm_xgemm_flags_) ? 'n' : 't'); \ 394 const OTYPE libxsmm_xgemm_alpha_ = (NULL != ((void*)(ALPHA)) ? (*(const OTYPE*)(ALPHA)) : ((OTYPE)LIBXSMM_ALPHA)); \ 395 const OTYPE libxsmm_xgemm_beta_ = (NULL != ((void*)(BETA)) ? (*(const OTYPE*)(BETA)) : ((OTYPE)LIBXSMM_BETA)); \ 396 LIBXSMM_XGEMM_FALLBACK1(ITYPE, OTYPE, &libxsmm_xgemm_transa_, &libxsmm_xgemm_transb_, \ 397 M, libxsmm_xgemm_n_, libxsmm_xgemm_k_, \ 398 &libxsmm_xgemm_alpha_, A, &libxsmm_xgemm_lda_, \ 399 B, &libxsmm_xgemm_ldb_, \ 400 &libxsmm_xgemm_beta_, C, &libxsmm_xgemm_ldc_); \ 401 } \ 402 } 403 404 /** Helper macro to setup a matrix with some initial values. */ 405 #define LIBXSMM_MATINIT_AUX(OMP, TYPE, SEED, DST, NROWS, NCOLS, LD, SCALE) { \ 406 /*const*/ double libxsmm_matinit_seed_ = (double)SEED; /* avoid constant conditional */ \ 407 const double libxsmm_matinit_scale_ = (SCALE) * libxsmm_matinit_seed_ + (SCALE); \ 408 const libxsmm_blasint libxsmm_matinit_ld_ = (libxsmm_blasint)LD; \ 409 libxsmm_blasint libxsmm_matinit_i_ = 0, libxsmm_matinit_j_; \ 410 LIBXSMM_OMP_VAR(libxsmm_matinit_i_); LIBXSMM_OMP_VAR(libxsmm_matinit_j_); \ 411 if (0 != libxsmm_matinit_seed_) { \ 412 OMP(parallel for private(libxsmm_matinit_i_, libxsmm_matinit_j_)) \ 413 for (libxsmm_matinit_i_ = 0; libxsmm_matinit_i_ < ((libxsmm_blasint)NCOLS); ++libxsmm_matinit_i_) { \ 414 for (libxsmm_matinit_j_ = 0; libxsmm_matinit_j_ < ((libxsmm_blasint)NROWS); ++libxsmm_matinit_j_) { \ 415 const libxsmm_blasint libxsmm_matinit_k_ = libxsmm_matinit_i_ * libxsmm_matinit_ld_ + libxsmm_matinit_j_; \ 416 (DST)[libxsmm_matinit_k_] = (TYPE)(libxsmm_matinit_scale_ / (1.0 + libxsmm_matinit_k_)); \ 417 } \ 418 for (; libxsmm_matinit_j_ < libxsmm_matinit_ld_; ++libxsmm_matinit_j_) { \ 419 const libxsmm_blasint libxsmm_matinit_k_ = libxsmm_matinit_i_ * libxsmm_matinit_ld_ + libxsmm_matinit_j_; \ 420 (DST)[libxsmm_matinit_k_] = (TYPE)SEED; \ 421 } \ 422 } \ 423 } \ 424 else { /* shuffle based initialization */ \ 425 const unsigned int libxsmm_matinit_maxval_ = ((unsigned int)NCOLS) * ((unsigned int)libxsmm_matinit_ld_); \ 426 const TYPE libxsmm_matinit_maxval2_ = (TYPE)(libxsmm_matinit_maxval_ / 2), libxsmm_matinit_inv_ = (TYPE)((SCALE) / libxsmm_matinit_maxval2_); \ 427 const size_t libxsmm_matinit_shuffle_ = libxsmm_shuffle(libxsmm_matinit_maxval_); \ 428 OMP(parallel for private(libxsmm_matinit_i_, libxsmm_matinit_j_)) \ 429 for (libxsmm_matinit_i_ = 0; libxsmm_matinit_i_ < ((libxsmm_blasint)NCOLS); ++libxsmm_matinit_i_) { \ 430 for (libxsmm_matinit_j_ = 0; libxsmm_matinit_j_ < libxsmm_matinit_ld_; ++libxsmm_matinit_j_) { \ 431 const libxsmm_blasint libxsmm_matinit_k_ = libxsmm_matinit_i_ * libxsmm_matinit_ld_ + libxsmm_matinit_j_; \ 432 (DST)[libxsmm_matinit_k_] = libxsmm_matinit_inv_ * /* normalize values to an interval of [-1, +1] */ \ 433 ((TYPE)(libxsmm_matinit_shuffle_ * libxsmm_matinit_k_ % libxsmm_matinit_maxval_) - libxsmm_matinit_maxval2_); \ 434 } \ 435 } \ 436 } \ 437 } 438 439 #define LIBXSMM_MATINIT(TYPE, SEED, DST, NROWS, NCOLS, LD, SCALE) \ 440 LIBXSMM_MATINIT_AUX(LIBXSMM_ELIDE, TYPE, SEED, DST, NROWS, NCOLS, LD, SCALE) 441 #define LIBXSMM_MATINIT_SEQ(TYPE, SEED, DST, NROWS, NCOLS, LD, SCALE) \ 442 LIBXSMM_MATINIT(TYPE, SEED, DST, NROWS, NCOLS, LD, SCALE) 443 #define LIBXSMM_MATINIT_OMP(TYPE, SEED, DST, NROWS, NCOLS, LD, SCALE) \ 444 LIBXSMM_MATINIT_AUX(LIBXSMM_PRAGMA_OMP, TYPE, SEED, DST, NROWS, NCOLS, LD, SCALE) 445 446 /** Call libxsmm_gemm_print using LIBXSMM's GEMM-flags. */ 447 #define LIBXSMM_GEMM_PRINT(OSTREAM, PRECISION, FLAGS, M, N, K, DALPHA, A, LDA, B, LDB, DBETA, C, LDC) \ 448 LIBXSMM_GEMM_PRINT2(OSTREAM, PRECISION, PRECISION, FLAGS, M, N, K, DALPHA, A, LDA, B, LDB, DBETA, C, LDC) 449 #define LIBXSMM_GEMM_PRINT2(OSTREAM, IPREC, OPREC, FLAGS, M, N, K, DALPHA, A, LDA, B, LDB, DBETA, C, LDC) \ 450 libxsmm_gemm_dprint2(OSTREAM, (libxsmm_gemm_precision)(IPREC), (libxsmm_gemm_precision)(OPREC), \ 451 /* Use 'n' (instead of 'N') avoids warning about "no macro replacement within a character constant". */ \ 452 (char)(0 == (LIBXSMM_GEMM_FLAG_TRANS_A & (FLAGS)) ? 'n' : 't'), \ 453 (char)(0 == (LIBXSMM_GEMM_FLAG_TRANS_B & (FLAGS)) ? 'n' : 't'), \ 454 M, N, K, DALPHA, A, LDA, B, LDB, DBETA, C, LDC) 455 456 /** 457 * Utility function, which either prints information about the GEMM call 458 * or dumps (FILE/ostream=0) all input and output data into MHD files. 459 * The Meta Image Format (MHD) is suitable for visual inspection using, 460 * e.g., ITK-SNAP or ParaView. 461 */ 462 LIBXSMM_API void libxsmm_gemm_print(void* ostream, 463 libxsmm_gemm_precision precision, const char* transa, const char* transb, 464 const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, 465 const void* alpha, const void* a, const libxsmm_blasint* lda, 466 const void* b, const libxsmm_blasint* ldb, 467 const void* beta, void* c, const libxsmm_blasint* ldc); 468 LIBXSMM_API void libxsmm_gemm_print2(void* ostream, 469 libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, const char* transa, const char* transb, 470 const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, 471 const void* alpha, const void* a, const libxsmm_blasint* lda, 472 const void* b, const libxsmm_blasint* ldb, 473 const void* beta, void* c, const libxsmm_blasint* ldc); 474 LIBXSMM_API void libxsmm_gemm_dprint(void* ostream, 475 libxsmm_gemm_precision precision, char transa, char transb, 476 libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, 477 double dalpha, const void* a, libxsmm_blasint lda, 478 const void* b, libxsmm_blasint ldb, 479 double dbeta, void* c, libxsmm_blasint ldc); 480 LIBXSMM_API void libxsmm_gemm_dprint2(void* ostream, 481 libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, char transa, char transb, 482 libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, 483 double dalpha, const void* a, libxsmm_blasint lda, 484 const void* b, libxsmm_blasint ldb, 485 double dbeta, void* c, libxsmm_blasint ldc); 486 LIBXSMM_API void libxsmm_gemm_xprint(void* ostream, 487 libxsmm_xmmfunction kernel, const void* a, const void* b, void* c); 488 489 /** GEMM_BATCH: fall-back prototype functions served by any compliant LAPACK/BLAS. */ 490 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_dgemm_batch_function)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, double, gemm_batch)); 491 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_sgemm_batch_function)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, float, gemm_batch)); 492 /** GEMM: fall-back prototype functions served by any compliant LAPACK/BLAS. */ 493 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_dgemm_function)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, double, gemm)); 494 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_sgemm_function)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, float, gemm)); 495 /** GEMV: fall-back prototype functions served by any compliant LAPACK/BLAS. */ 496 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_dgemv_function)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, double, gemv)); 497 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_sgemv_function)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, float, gemv)); 498 /** Helper function to consume arguments when called. */ 499 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_sink_function)(LIBXSMM_VARIADIC); 500 501 /** The original BLAS functions. */ 502 LIBXSMM_APIVAR_PUBLIC(/*volatile*/libxsmm_dgemm_batch_function libxsmm_original_dgemm_batch_function); 503 LIBXSMM_APIVAR_PUBLIC(/*volatile*/libxsmm_sgemm_batch_function libxsmm_original_sgemm_batch_function); 504 LIBXSMM_APIVAR_PUBLIC(/*volatile*/libxsmm_dgemm_function libxsmm_original_dgemm_function); 505 LIBXSMM_APIVAR_PUBLIC(/*volatile*/libxsmm_sgemm_function libxsmm_original_sgemm_function); 506 LIBXSMM_APIVAR_PUBLIC(/*volatile*/libxsmm_dgemv_function libxsmm_original_dgemv_function); 507 LIBXSMM_APIVAR_PUBLIC(/*volatile*/libxsmm_sgemv_function libxsmm_original_sgemv_function); 508 LIBXSMM_API libxsmm_dgemm_batch_function libxsmm_original_dgemm_batch(void); 509 LIBXSMM_API libxsmm_sgemm_batch_function libxsmm_original_sgemm_batch(void); 510 LIBXSMM_API libxsmm_dgemm_function libxsmm_original_dgemm(void); 511 LIBXSMM_API libxsmm_sgemm_function libxsmm_original_sgemm(void); 512 LIBXSMM_API libxsmm_dgemv_function libxsmm_original_dgemv(void); 513 LIBXSMM_API libxsmm_sgemv_function libxsmm_original_sgemv(void); 514 LIBXSMM_API libxsmm_sink_function libxsmm_blas_error(const char* symbol); 515 LIBXSMM_API void libxsmm_sink(LIBXSMM_VARIADIC); 516 517 /** 518 * General dense matrix multiplication, which re-exposes LAPACK/BLAS 519 * but allows to rely on LIBXSMM's defaults (libxsmm_config.h) 520 * when supplying NULL-arguments in certain places. 521 */ 522 LIBXSMM_API void libxsmm_blas_xgemm(libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, 523 const char* transa, const char* transb, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k, 524 const void* alpha, const void* a, const libxsmm_blasint* lda, 525 const void* b, const libxsmm_blasint* ldb, 526 const void* beta, void* c, const libxsmm_blasint* ldc); 527 528 #define libxsmm_blas_dgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \ 529 libxsmm_blas_xgemm(LIBXSMM_GEMM_PRECISION_F64, LIBXSMM_GEMM_PRECISION_F64, \ 530 TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) 531 #define libxsmm_blas_sgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \ 532 libxsmm_blas_xgemm(LIBXSMM_GEMM_PRECISION_F32, LIBXSMM_GEMM_PRECISION_F32, \ 533 TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) 534 535 #define libxsmm_dgemm_omp(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \ 536 libxsmm_xgemm_omp(LIBXSMM_GEMM_PRECISION_F64, LIBXSMM_GEMM_PRECISION_F64, \ 537 TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) 538 #define libxsmm_sgemm_omp(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \ 539 libxsmm_xgemm_omp(LIBXSMM_GEMM_PRECISION_F32, LIBXSMM_GEMM_PRECISION_F32, \ 540 TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) 541 542 /** Translates GEMM prefetch request into prefetch-enumeration (incl. FE's auto-prefetch). */ 543 LIBXSMM_API libxsmm_gemm_prefetch_type libxsmm_get_gemm_xprefetch(const int* prefetch); 544 LIBXSMM_API libxsmm_gemm_prefetch_type libxsmm_get_gemm_prefetch(int prefetch); 545 546 #endif /*LIBXSMM_FRONTEND_H*/ 547 548