1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Hans Pabst (Intel Corp.)
10 ******************************************************************************/
11 #ifndef LIBXSMM_FRONTEND_H
12 #define LIBXSMM_FRONTEND_H
13 
14 #include "libxsmm_typedefs.h"
15 
/**
 * Prefetch helpers: LIBXSMM_GEMM_PREFETCH_x(EXPR) keeps the prefetch address expression only
 * if the build-time prefetch scheme (LIBXSMM_PREFETCH bit-mask) uses operand x.
 */
#if !defined(_WIN32) && !defined(__CYGWIN__) /* TODO: fully support calling convention */
# if 0 != ((LIBXSMM_PREFETCH) & (2/*AL2*/ | 8/*AL2_AHEAD*/))
#   define LIBXSMM_GEMM_PREFETCH_A(EXPR) (EXPR)
# endif
# if 0 != ((LIBXSMM_PREFETCH) & (4/*BL2_VIA_C*/ | 16/*BL1*/))
#   define LIBXSMM_GEMM_PREFETCH_B(EXPR) (EXPR)
# endif
#endif
/**
 * Secondary helpers derived from the above: an operand that is not prefetched collapses to 0,
 * and LIBXSMM_NOPREFETCH_x(EXPR) keeps EXPR exactly when operand x is NOT prefetched
 * (e.g., to mark the otherwise unused prefetch argument).
 */
#ifndef LIBXSMM_GEMM_PREFETCH_A
# define LIBXSMM_NOPREFETCH_A(EXPR) EXPR
# define LIBXSMM_GEMM_PREFETCH_A(EXPR) 0
#else
# define LIBXSMM_NOPREFETCH_A(EXPR)
#endif
#ifndef LIBXSMM_GEMM_PREFETCH_B
# define LIBXSMM_NOPREFETCH_B(EXPR) EXPR
# define LIBXSMM_GEMM_PREFETCH_B(EXPR) 0
#else
# define LIBXSMM_NOPREFETCH_B(EXPR)
#endif
#ifndef LIBXSMM_GEMM_PREFETCH_C
# define LIBXSMM_NOPREFETCH_C(EXPR) EXPR
# define LIBXSMM_GEMM_PREFETCH_C(EXPR) 0
#else
# define LIBXSMM_NOPREFETCH_C(EXPR)
#endif
46 
/**
 * MKL_DIRECT_CALL requires to include the MKL interface.
 * With MKL_DIRECT_CALL or MKL_DIRECT_CALL_SEQ defined, <mkl.h> is included here.
 * An ILP64 build of LIBXSMM (LIBXSMM_ILP64 != 0) must then also use MKL's ILP64 interface (MKL_ILP64),
 * otherwise the integer width of BLAS arguments would disagree.
 */
#if (defined(MKL_DIRECT_CALL_SEQ) || defined(MKL_DIRECT_CALL))
# if (0 != LIBXSMM_ILP64 && !defined(MKL_ILP64))
#   error "Inconsistent ILP64 configuration detected!"
# endif
# if defined(LIBXSMM_OFFLOAD_BUILD) /* wrap the header into the (Intel) offload attribute push/pop */
#   pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
#   include <mkl.h>
#   pragma offload_attribute(pop)
# else
#   include <mkl.h>
# endif
#endif
60 
/** Automatically select a prefetch-strategy (libxsmm_get_gemm_xprefetch, etc.). */
#define LIBXSMM_PREFETCH_AUTO -1

/** Append "_omp" postfix to the given symbol. */
#define LIBXSMM_USEOMP(SYMBOL) LIBXSMM_CONCATENATE(SYMBOL, _omp)

/**
 * BLAS-style type prefixes: LIBXSMM_TPREFIX(T, NAME) pastes the prefix of type-key T in front
 * of NAME, e.g., LIBXSMM_TPREFIX(double, gemm) gives dgemm. A type-key is either an
 * input/output pair pasted together (e.g., shortint) or a single input type (defaults below).
 */
#define LIBXSMM_TPREFIX_NAME(T) LIBXSMM_CONCATENATE(LIBXSMM_TPREFIX_, T)
#define LIBXSMM_TPREFIX(T, NAME) LIBXSMM_CONCATENATE(LIBXSMM_TPREFIX_NAME(T), NAME)
/* input/output pairs */
#define LIBXSMM_TPREFIX_libxsmm_bfloat16float bs
#define LIBXSMM_TPREFIX_shortint wi
#define LIBXSMM_TPREFIX_shortfloat ws
#define LIBXSMM_TPREFIX_floatfloat s
#define LIBXSMM_TPREFIX_doubledouble d
/* single input type: same prefix as the corresponding pair */
#define LIBXSMM_TPREFIX_short LIBXSMM_TPREFIX_shortint
#define LIBXSMM_TPREFIX_float LIBXSMM_TPREFIX_floatfloat
#define LIBXSMM_TPREFIX_double LIBXSMM_TPREFIX_doubledouble

/** Symbol names constructed from a (real) type-key such as float, double, or short. */
#define LIBXSMM_CBLAS_SYMBOL            LIBXSMM_TPREFIX
#define LIBXSMM_BLAS_SYMBOL(T, NAME)    LIBXSMM_FSYMBOL(LIBXSMM_TPREFIX(T, NAME))
#define LIBXSMM_BLAS_FNTYPE(T, NAME)    LIBXSMM_CONCATENATE3(libxsmm_, LIBXSMM_TPREFIX(T, NAME), _function)
#define LIBXSMM_MMFUNCTION_TYPE(T)      LIBXSMM_CONCATENATE(libxsmm_, LIBXSMM_TPREFIX(T, mmfunction))
#define LIBXSMM_MMDISPATCH_SYMBOL(T)    LIBXSMM_CONCATENATE(libxsmm_, LIBXSMM_TPREFIX(T, mmdispatch))
#define LIBXSMM_XGEMM_SYMBOL(T)         LIBXSMM_CONCATENATE(libxsmm_, LIBXSMM_TPREFIX(T, gemm))
#define LIBXSMM_XBLAS_SYMBOL(T)         LIBXSMM_CONCATENATE(libxsmm_blas_, LIBXSMM_TPREFIX(T, gemm))
#define LIBXSMM_YGEMM_SYMBOL(T)         LIBXSMM_USEOMP(LIBXSMM_XGEMM_SYMBOL(T))
89 
/**
 * LIBXSMM_BLAS_DECL(TYPE, KIND, DECL) forwards the declaration DECL to LIBXSMM_BLAS_<prefix><KIND>
 * (e.g., LIBXSMM_BLAS_dgemm for double/gemm), which decides whether DECL is emitted.
 * Without MKL_DIRECT_CALL(_SEQ), DECL is emitted as a declaration (terminated by ';').
 * With MKL_DIRECT_CALL(_SEQ), <mkl.h> is included by this header (presumably supplying the
 * prototypes) and DECL is dropped. The selector macros must be function-like in both cases:
 * an object-like empty macro would leave "(DECL)" behind, which is not valid C.
 * Only dgemm, sgemm, dgemv, and sgemv are covered.
 */
#define LIBXSMM_BLAS_DECL(TYPE, KIND, DECL) LIBXSMM_CONCATENATE(LIBXSMM_BLAS_, LIBXSMM_TPREFIX(TYPE, KIND))(DECL)
#if !defined(MKL_DIRECT_CALL_SEQ) && !defined(MKL_DIRECT_CALL)
# define LIBXSMM_BLAS_dgemm(DECL) DECL;
# define LIBXSMM_BLAS_sgemm(DECL) DECL;
# define LIBXSMM_BLAS_dgemv(DECL) DECL;
# define LIBXSMM_BLAS_sgemv(DECL) DECL;
#else /* consume (and discard) the declaration */
# define LIBXSMM_BLAS_dgemm(DECL)
# define LIBXSMM_BLAS_sgemm(DECL)
# define LIBXSMM_BLAS_dgemv(DECL)
# define LIBXSMM_BLAS_sgemv(DECL)
#endif
102 
/*
 * Variants taking an input (IT) and output (OT) type: both type names are pasted into one
 * type-key (e.g., short and int give shortint) before the single-key macro is applied.
 */
#define LIBXSMM_TPREFIX2(IT, OT, NAME)     LIBXSMM_TPREFIX(LIBXSMM_CONCATENATE(IT, OT), NAME)
#define LIBXSMM_TPREFIX_NAME2(IT, OT)      LIBXSMM_TPREFIX_NAME(LIBXSMM_CONCATENATE(IT, OT))
#define LIBXSMM_MMDISPATCH_SYMBOL2(IT, OT) LIBXSMM_MMDISPATCH_SYMBOL(LIBXSMM_CONCATENATE(IT, OT))
#define LIBXSMM_MMFUNCTION_TYPE2(IT, OT)   LIBXSMM_MMFUNCTION_TYPE(LIBXSMM_CONCATENATE(IT, OT))

/** Compile-time comparison of two type names: 1 if equal, 0 otherwise (only the pairs listed below). */
#define LIBXSMM_EQUAL(LHS, RHS) LIBXSMM_CONCATENATE3(LIBXSMM_EQUAL_, LHS, RHS)
#define LIBXSMM_EQUAL_doubledouble 1
#define LIBXSMM_EQUAL_floatfloat 1
#define LIBXSMM_EQUAL_doublefloat 0
#define LIBXSMM_EQUAL_floatdouble 0
#define LIBXSMM_EQUAL_shortfloat 0
#define LIBXSMM_EQUAL_shortdouble 0
117 
/**
 * Qualifier for read-only BLAS arguments in the prototypes (used as "LIBXSMM_BLAS_CONST*").
 * Precedence: an already defined LIBXSMM_BLAS_CONST is normalized to "const"; otherwise
 * OPENBLAS_CONST is reused if defined; LIBXSMM_BLAS_NONCONST, __OPENBLAS, or __OPENBLAS77
 * select no qualifier; the default is "const".
 */
#if defined(LIBXSMM_BLAS_CONST)
# undef LIBXSMM_BLAS_CONST /* normalize whatever value was given */
# define LIBXSMM_BLAS_CONST const
#elif defined(OPENBLAS_CONST)
# define LIBXSMM_BLAS_CONST OPENBLAS_CONST
#elif defined(LIBXSMM_BLAS_NONCONST) || defined(__OPENBLAS) || defined(__OPENBLAS77)
# define LIBXSMM_BLAS_CONST
#else
# define LIBXSMM_BLAS_CONST const
#endif
128 
/**
 * BLAS availability as complementary 0/1 switches LIBXSMM_NO_BLAS and LIBXSMM_BLAS.
 * Unless given, LIBXSMM_NO_BLAS is 0 (BLAS available) except when __BLAS is defined as 0.
 * LIBXSMM_BLAS is derived from LIBXSMM_NO_BLAS unless given, so that it is defined (and
 * consistent) even if only LIBXSMM_NO_BLAS was supplied by the user (expected to be 0 or 1).
 */
#if !defined(LIBXSMM_NO_BLAS)
# if (!defined(__BLAS) || (0 != __BLAS))
#   define LIBXSMM_NO_BLAS 0
# else
#   define LIBXSMM_NO_BLAS 1
# endif
#endif
#if !defined(LIBXSMM_BLAS)
# if (0 != LIBXSMM_NO_BLAS)
#   define LIBXSMM_BLAS 0
# else
#   define LIBXSMM_BLAS 1
# endif
#endif
138 
/**
 * LIBXSMM_BLAS_INIT: initialization statement for the BLAS library. With OpenBLAS
 * (__OPENBLAS and __BLAS == 1), it restricts OpenBLAS to a single thread; otherwise it is
 * empty unless already defined.
 */
#if defined(__OPENBLAS) && defined(__BLAS) && (1 == __BLAS)
LIBXSMM_EXTERN void openblas_set_num_threads(int num_threads);
# define LIBXSMM_BLAS_INIT openblas_set_num_threads(1);
#elif !defined(LIBXSMM_BLAS_INIT)
# define LIBXSMM_BLAS_INIT /* nothing to initialize */
#endif
148 
/**
 * Linkage/visibility attributes of the BLAS prototypes (LIBXSMM_BLAS_SYMBOL_FDECL/CDECL).
 * Building LIBXSMM itself may select LIBXSMM_APIEXT (extension library, shared) or LIBXSMM_API
 * (no BLAS: LIBXSMM provides the symbols); otherwise the symbols are imported.
 */
#if !defined(LIBXSMM_BUILD)
  /* not building LIBXSMM: keep the default below */
#elif defined(LIBXSMM_BUILD_EXT) && !defined(__STATIC)
# define LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_APIEXT
#elif defined(LIBXSMM_NO_BLAS) && (1 == LIBXSMM_NO_BLAS)
# define LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_API
#endif
#ifndef LIBXSMM_BLAS_SYMBOL_VISIBILITY
# define LIBXSMM_BLAS_SYMBOL_VISIBILITY LIBXSMM_EXTERN LIBXSMM_VISIBILITY_IMPORT LIBXSMM_RETARGETABLE
#endif
159 
/**
 * Parameter lists of the Fortran-style BLAS routines (all arguments by address):
 * CS spells a read-only pointer (e.g., "const*"), PS a writable one (e.g., "*"), T the real type.
 * The gemm list follows the BLAS order transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc.
 */
#define LIBXSMM_BLAS_SYMBOL_SIGNATURE_gemm_batch(CS, PS, T) \
  char CS, char CS, \
  libxsmm_blasint CS, libxsmm_blasint CS, libxsmm_blasint CS, \
  T CS, T CS PS, libxsmm_blasint CS, \
  T CS PS, libxsmm_blasint CS, \
  T CS, T PS PS, libxsmm_blasint CS, \
  libxsmm_blasint CS, libxsmm_blasint CS
#define LIBXSMM_BLAS_SYMBOL_SIGNATURE_gemm(CS, PS, T) \
  char CS, char CS, \
  libxsmm_blasint CS, libxsmm_blasint CS, libxsmm_blasint CS, \
  T CS, T CS, libxsmm_blasint CS, \
  T CS, libxsmm_blasint CS, \
  T CS, T PS, libxsmm_blasint CS
#define LIBXSMM_BLAS_SYMBOL_SIGNATURE_gemv(CS, PS, T) \
  char CS, \
  libxsmm_blasint CS, libxsmm_blasint CS, \
  T CS, T CS, libxsmm_blasint CS, \
  T CS, libxsmm_blasint CS, \
  T CS, T PS, libxsmm_blasint CS
/** Select the parameter list by KIND (gemm_batch, gemm, or gemv). */
#define LIBXSMM_BLAS_SYMBOL_SIGNATURE(CS, PS, T, KIND) \
  LIBXSMM_CONCATENATE(LIBXSMM_BLAS_SYMBOL_SIGNATURE_, KIND)(CS, PS, T)
/** Prototype of the Fortran symbol (FDECL) or of the CBLAS-style name (CDECL). */
#define LIBXSMM_BLAS_SYMBOL_FDECL(CS, PS, T, KIND) LIBXSMM_BLAS_SYMBOL_VISIBILITY \
  void LIBXSMM_BLAS_SYMBOL(T, KIND)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(CS, PS, T, KIND))
#define LIBXSMM_BLAS_SYMBOL_CDECL(CS, PS, T, KIND) LIBXSMM_BLAS_SYMBOL_VISIBILITY \
  void LIBXSMM_CBLAS_SYMBOL(T, KIND)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(CS, PS, T, KIND))

/** Declare the Fortran BLAS symbol (subject to LIBXSMM_BLAS_DECL) if BLAS is available. */
#if (0 == LIBXSMM_BLAS) /* no BLAS: nothing to declare */
# define LIBXSMM_BLAS_SYMBOL_DECL(T, KIND)
#else
# define LIBXSMM_BLAS_SYMBOL_DECL(T, KIND) LIBXSMM_BLAS_DECL(T, KIND, LIBXSMM_BLAS_SYMBOL_FDECL(LIBXSMM_BLAS_CONST*, *, T, KIND))
#endif
181 
/**
 * Helper macro consolidating the transpose requests (BLAS character codes) into a set of flags.
 * Only 'N'/'n' means "no transpose"; any other character requests a transpose. This is why
 * N/n is checked rather than T/t, since C/c (conjugate transpose) is also valid.
 */
#define LIBXSMM_GEMM_FLAGS(TRANSA, TRANSB) \
   ((('n' == (TRANSA) || *"N" == (TRANSA)) ? LIBXSMM_GEMM_FLAG_NONE : LIBXSMM_GEMM_FLAG_TRANS_A) \
  | (('n' == (TRANSB) || *"N" == (TRANSB)) ? LIBXSMM_GEMM_FLAG_NONE : LIBXSMM_GEMM_FLAG_TRANS_B))

/**
 * Helper macro consolidating CBLAS transpose requests into a set of flags.
 * Anything other than CblasNoTrans (e.g., CblasTrans or CblasConjTrans) requests a transpose.
 */
#define LIBXSMM_GEMM_CFLAGS(TRANSA, TRANSB) \
   ((CblasNoTrans == (TRANSA) ? LIBXSMM_GEMM_FLAG_NONE : LIBXSMM_GEMM_FLAG_TRANS_A) \
  | (CblasNoTrans == (TRANSB) ? LIBXSMM_GEMM_FLAG_NONE : LIBXSMM_GEMM_FLAG_TRANS_B))

/**
 * Like LIBXSMM_GEMM_FLAGS, and additionally consolidates the VNNI-layout requests:
 * 'N'/'n' means no VNNI layout for the respective matrix, any other character requests it.
 */
#define LIBXSMM_GEMM_VNNI_FLAGS(TRANSA, TRANSB, VNNIA, VNNIB) \
   ((('n' == (TRANSA) || *"N" == (TRANSA)) ? LIBXSMM_GEMM_FLAG_NONE : LIBXSMM_GEMM_FLAG_TRANS_A) \
  | (('n' == (TRANSB) || *"N" == (TRANSB)) ? LIBXSMM_GEMM_FLAG_NONE : LIBXSMM_GEMM_FLAG_TRANS_B) \
  | (('n' == (VNNIA) || *"N" == (VNNIA)) ? LIBXSMM_GEMM_FLAG_NONE : LIBXSMM_GEMM_FLAG_VNNI_A) \
  | (('n' == (VNNIB) || *"N" == (VNNIB)) ? LIBXSMM_GEMM_FLAG_NONE : LIBXSMM_GEMM_FLAG_VNNI_B))

/**
 * Helper macro allowing NULL-requests (transposes) supplied by some default.
 * TRANSA/TRANSB point to a transpose character or are NULL; a NULL request takes the
 * respective TRANS bit of DEFAULT. All non-transpose bits of DEFAULT are passed through.
 * The whole expansion is parenthesized so it composes safely with other operators.
 */
#define LIBXSMM_GEMM_PFLAGS(TRANSA, TRANSB, DEFAULT) (LIBXSMM_GEMM_FLAGS( \
  NULL != ((const void*)(TRANSA)) ? (*(const char*)(TRANSA)) : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & (DEFAULT)) ? 'n' : 't'), \
  NULL != ((const void*)(TRANSB)) ? (*(const char*)(TRANSB)) : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & (DEFAULT)) ? 'n' : 't')) \
  | (~(LIBXSMM_GEMM_FLAG_TRANS_A | LIBXSMM_GEMM_FLAG_TRANS_B) & (DEFAULT)))
204 
/**
 * Inlinable GEMM exercising the compiler's code generation (macro template), computing
 * C = alpha * A * B + beta * C for column-major A (M x K), B (K x N), and C (M x N).
 * Arguments are passed by address: M is required; TRANSA/TRANSB (default from LIBXSMM_FLAGS),
 * K (default M), N (default K), LDA/LDB/LDC (default: tight), and ALPHA/BETA (default
 * LIBXSMM_ALPHA/LIBXSMM_BETA) may be NULL. TODO: only NN is supported (asserted) and SP/DP matrices.
 * BETA is applied exactly once per element of C before the accumulation over K; if beta is
 * zero, C is not read, i.e., C need not be initialized (as in BLAS).
 */
#define LIBXSMM_INLINE_XGEMM(ITYPE, OTYPE, TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) { \
  /* Use 'n' (instead of 'N') avoids warning about "no macro replacement within a character constant". */ \
  const char libxsmm_inline_xgemm_transa_ = (char)(NULL != ((void*)(TRANSA)) ? (*(const char*)(TRANSA)) : \
    (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & LIBXSMM_FLAGS) ? 'n' : 't')); \
  const char libxsmm_inline_xgemm_transb_ = (char)(NULL != ((void*)(TRANSB)) ? (*(const char*)(TRANSB)) : \
    (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & LIBXSMM_FLAGS) ? 'n' : 't')); \
  const libxsmm_blasint libxsmm_inline_xgemm_m_ = *(const libxsmm_blasint*)(M); /* must be specified */ \
  const libxsmm_blasint libxsmm_inline_xgemm_k_ = (NULL != ((void*)(K)) ? (*(const libxsmm_blasint*)(K)) : libxsmm_inline_xgemm_m_); \
  const libxsmm_blasint libxsmm_inline_xgemm_n_ = (NULL != ((void*)(N)) ? (*(const libxsmm_blasint*)(N)) : libxsmm_inline_xgemm_k_); \
  const libxsmm_blasint libxsmm_inline_xgemm_lda_ = (NULL != ((void*)(LDA)) ? (*(const libxsmm_blasint*)(LDA)) : \
    (('n' == libxsmm_inline_xgemm_transa_ || *"N" == libxsmm_inline_xgemm_transa_) ? libxsmm_inline_xgemm_m_ : libxsmm_inline_xgemm_k_)); \
  const libxsmm_blasint libxsmm_inline_xgemm_ldb_ = (NULL != ((void*)(LDB)) ? (*(const libxsmm_blasint*)(LDB)) : \
    (('n' == libxsmm_inline_xgemm_transb_ || *"N" == libxsmm_inline_xgemm_transb_) ? libxsmm_inline_xgemm_k_ : libxsmm_inline_xgemm_n_)); \
  const libxsmm_blasint libxsmm_inline_xgemm_ldc_ = (NULL != ((void*)(LDC)) ? (*(const libxsmm_blasint*)(LDC)) : libxsmm_inline_xgemm_m_); \
  const OTYPE libxsmm_inline_xgemm_alpha_ = (NULL != ((void*)(ALPHA)) ? (*(const OTYPE*)(ALPHA)) : ((OTYPE)LIBXSMM_ALPHA)); \
  const OTYPE libxsmm_inline_xgemm_beta_  = (NULL != ((void*)(BETA))  ? (*(const OTYPE*)(BETA))  : ((OTYPE)LIBXSMM_BETA)); \
  libxsmm_blasint libxsmm_inline_xgemm_ni_, libxsmm_inline_xgemm_mi_ = 0, libxsmm_inline_xgemm_ki_; /* loop induction variables */ \
  LIBXSMM_ASSERT('n' == libxsmm_inline_xgemm_transa_ || *"N" == libxsmm_inline_xgemm_transa_); \
  LIBXSMM_ASSERT('n' == libxsmm_inline_xgemm_transb_ || *"N" == libxsmm_inline_xgemm_transb_); \
  for (libxsmm_inline_xgemm_ni_ = 0; libxsmm_inline_xgemm_ni_ < libxsmm_inline_xgemm_n_; ++libxsmm_inline_xgemm_ni_) { \
    /* scale column ni of C by beta exactly once (not per k-iteration); beta == 0 does not read C */ \
    LIBXSMM_PRAGMA_SIMD \
    for (libxsmm_inline_xgemm_mi_ = 0; libxsmm_inline_xgemm_mi_ < libxsmm_inline_xgemm_m_; ++libxsmm_inline_xgemm_mi_) { \
      OTYPE *const libxsmm_inline_xgemm_cij_ = ((OTYPE*)(C)) + (libxsmm_inline_xgemm_ni_*libxsmm_inline_xgemm_ldc_+libxsmm_inline_xgemm_mi_); \
      *libxsmm_inline_xgemm_cij_ = (0 != libxsmm_inline_xgemm_beta_ \
        ? ((OTYPE)(*libxsmm_inline_xgemm_cij_ * libxsmm_inline_xgemm_beta_)) : ((OTYPE)0)); \
    } \
    /* accumulate alpha * A(:,ki) * B(ki,ni) into column ni of C */ \
    LIBXSMM_PRAGMA_LOOP_COUNT(1, LIBXSMM_CONFIG_MAX_DIM, LIBXSMM_CONFIG_AVG_DIM) \
    for (libxsmm_inline_xgemm_ki_ = 0; libxsmm_inline_xgemm_ki_ < libxsmm_inline_xgemm_k_; ++libxsmm_inline_xgemm_ki_) { \
      const ITYPE libxsmm_inline_xgemm_bkj_ = ((const ITYPE*)(B))[libxsmm_inline_xgemm_ni_*libxsmm_inline_xgemm_ldb_+libxsmm_inline_xgemm_ki_]; \
      LIBXSMM_PRAGMA_SIMD \
      for (libxsmm_inline_xgemm_mi_ = 0; libxsmm_inline_xgemm_mi_ < libxsmm_inline_xgemm_m_; ++libxsmm_inline_xgemm_mi_) { \
        ((OTYPE*)(C))[libxsmm_inline_xgemm_ni_*libxsmm_inline_xgemm_ldc_+libxsmm_inline_xgemm_mi_] += libxsmm_inline_xgemm_bkj_ * \
          (((const ITYPE*)(A))[libxsmm_inline_xgemm_ki_*libxsmm_inline_xgemm_lda_+libxsmm_inline_xgemm_mi_] * libxsmm_inline_xgemm_alpha_); \
      } \
    } \
  } \
}
239 
/**
 * LIBXSMM_INIT: self-terminated statement (includes ';') ensuring the library is initialized.
 * By default, libxsmm_init() is called lazily while libxsmm_ninit is below 2.
 * If LIBXSMM_INIT or LIBXSMM_CTOR is defined upfront, initialization is only asserted,
 * and LIBXSMM_INIT_COMPLETED is defined.
 */
#if !defined(LIBXSMM_INIT) && !defined(LIBXSMM_CTOR)
# define LIBXSMM_INIT if (libxsmm_ninit < 2) libxsmm_init();
#else
# undef LIBXSMM_INIT
# define LIBXSMM_INIT LIBXSMM_ASSERT_MSG(1 < libxsmm_ninit, "LIBXSMM is not initialized");
# define LIBXSMM_INIT_COMPLETED
#endif
247 
/**
 * Map to the appropriate BLAS function (or fall-back), e.g., LIBXSMM_BLAS_FUNCTION(double, double, gemm)
 * gives LIBXSMM_BLAS_FUNCTION_dgemm. The mapping is used, e.g., inside of LIBXSMM_BLAS_XGEMM.
 */
#define LIBXSMM_BLAS_FUNCTION(IT, OT, NAME) LIBXSMM_CONCATENATE(LIBXSMM_BLAS_FUNCTION_, LIBXSMM_TPREFIX2(IT, OT, NAME))
#if (0 == LIBXSMM_BLAS) /* no BLAS: every symbol resolves to the error handler */
# define LIBXSMM_BLAS_FUNCTION_dgemm_batch libxsmm_blas_error("dgemm_batch")
# define LIBXSMM_BLAS_FUNCTION_sgemm_batch libxsmm_blas_error("sgemm_batch")
# define LIBXSMM_BLAS_FUNCTION_dgemm libxsmm_blas_error("dgemm")
# define LIBXSMM_BLAS_FUNCTION_sgemm libxsmm_blas_error("sgemm")
# define LIBXSMM_BLAS_FUNCTION_dgemv libxsmm_blas_error("dgemv")
# define LIBXSMM_BLAS_FUNCTION_sgemv libxsmm_blas_error("sgemv")
#elif defined(LIBXSMM_INIT_COMPLETED) /* initialization asserted: use the function pointers directly */
# define LIBXSMM_BLAS_FUNCTION_dgemm_batch libxsmm_original_dgemm_batch_function
# define LIBXSMM_BLAS_FUNCTION_sgemm_batch libxsmm_original_sgemm_batch_function
# define LIBXSMM_BLAS_FUNCTION_dgemm libxsmm_original_dgemm_function
# define LIBXSMM_BLAS_FUNCTION_sgemm libxsmm_original_sgemm_function
# define LIBXSMM_BLAS_FUNCTION_dgemv libxsmm_original_dgemv_function
# define LIBXSMM_BLAS_FUNCTION_sgemv libxsmm_original_sgemv_function
#else /* query the function by calling the getter (which may eventually call libxsmm_init) */
# define LIBXSMM_BLAS_FUNCTION_dgemm_batch libxsmm_original_dgemm_batch()
# define LIBXSMM_BLAS_FUNCTION_sgemm_batch libxsmm_original_sgemm_batch()
# define LIBXSMM_BLAS_FUNCTION_dgemm libxsmm_original_dgemm()
# define LIBXSMM_BLAS_FUNCTION_sgemm libxsmm_original_sgemm()
# define LIBXSMM_BLAS_FUNCTION_dgemv libxsmm_original_dgemv()
# define LIBXSMM_BLAS_FUNCTION_sgemv libxsmm_original_sgemv()
#endif
/** Low-precision (BLAS-like) function symbols, served by the inlinable GEMM. */
#define LIBXSMM_BLAS_FUNCTION_wigemm(TA, TB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \
  LIBXSMM_INLINE_XGEMM(short, int, TA, TB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
#define LIBXSMM_BLAS_FUNCTION_bsgemm(TA, TB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \
  LIBXSMM_INLINE_XGEMM(libxsmm_bfloat16, float, TA, TB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)

/** Short-cut macros to construct the desired BLAS function symbol (same input and output type). */
#define LIBXSMM_BLAS_FUNCTION1(T, NAME) LIBXSMM_BLAS_FUNCTION(T, T, NAME)
#define LIBXSMM_GEMM_BATCH_SYMBOL(T) LIBXSMM_BLAS_FUNCTION1(T, gemm_batch)
#define LIBXSMM_GEMM_SYMBOL(T) LIBXSMM_BLAS_FUNCTION1(T, gemm)
#define LIBXSMM_GEMV_SYMBOL(T) LIBXSMM_BLAS_FUNCTION1(T, gemv)
285 
/**
 * BLAS-based GEMM supplied by the linked LAPACK/BLAS library (macro template).
 * Arguments are passed by address: M is required; TRANSA/TRANSB default to LIBXSMM_FLAGS,
 * K to M, N to K, leading dimensions to the tight value (clamped to at least 1), and
 * ALPHA/BETA to LIBXSMM_ALPHA/LIBXSMM_BETA when NULL.
 * A and B are passed as const ITYPE*, while the output C is passed as OTYPE*
 * (matters for mixed precision, where the output type differs from the input type).
 */
#define LIBXSMM_BLAS_XGEMM(ITYPE, OTYPE, TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) { \
  /* Use 'n' (instead of 'N') avoids warning about "no macro replacement within a character constant". */ \
  const char libxsmm_blas_xgemm_transa_ = (char)(NULL != ((void*)(TRANSA)) ? (*(const char*)(TRANSA)) : \
    (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & LIBXSMM_FLAGS) ? 'n' : 't')); \
  const char libxsmm_blas_xgemm_transb_ = (char)(NULL != ((void*)(TRANSB)) ? (*(const char*)(TRANSB)) : \
    (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & LIBXSMM_FLAGS) ? 'n' : 't')); \
  const libxsmm_blasint *const libxsmm_blas_xgemm_k_ = (NULL != ((void*)(K)) ? (K) : (M)); \
  const libxsmm_blasint *const libxsmm_blas_xgemm_n_ = (NULL != ((void*)(N)) ? (N) : libxsmm_blas_xgemm_k_); \
  const libxsmm_blasint libxsmm_blas_xgemm_lda_ = LIBXSMM_MAX(NULL != ((void*)(LDA)) ? *(LDA) : \
    *(('n' == libxsmm_blas_xgemm_transa_ || *"N" == libxsmm_blas_xgemm_transa_) ? (M) : libxsmm_blas_xgemm_k_), 1); \
  const libxsmm_blasint libxsmm_blas_xgemm_ldb_ = LIBXSMM_MAX(NULL != ((void*)(LDB)) ? *(LDB) : \
    *(('n' == libxsmm_blas_xgemm_transb_ || *"N" == libxsmm_blas_xgemm_transb_) ? libxsmm_blas_xgemm_k_ : libxsmm_blas_xgemm_n_), 1); \
  const libxsmm_blasint libxsmm_blas_xgemm_ldc_ = LIBXSMM_MAX(NULL != ((void*)(LDC)) ? *(LDC) : *(M), 1); \
  const OTYPE libxsmm_blas_xgemm_alpha_ = (NULL != ((void*)(ALPHA)) ? (*(const OTYPE*)(ALPHA)) : ((OTYPE)LIBXSMM_ALPHA)); \
  const OTYPE libxsmm_blas_xgemm_beta_  = (NULL != ((void*)(BETA))  ? (*(const OTYPE*)(BETA))  : ((OTYPE)LIBXSMM_BETA)); \
  LIBXSMM_BLAS_FUNCTION(ITYPE, OTYPE, gemm)(&libxsmm_blas_xgemm_transa_, &libxsmm_blas_xgemm_transb_, \
    M, libxsmm_blas_xgemm_n_, libxsmm_blas_xgemm_k_, \
    &libxsmm_blas_xgemm_alpha_, (const ITYPE*)(A), &libxsmm_blas_xgemm_lda_, \
                                (const ITYPE*)(B), &libxsmm_blas_xgemm_ldb_, \
     &libxsmm_blas_xgemm_beta_,       (OTYPE*)(C), &libxsmm_blas_xgemm_ldc_); \
}
308 
/**
 * Helper macros for calling a dispatched function in a row/column-major aware fashion.
 * LIBXSMM_MMCALL_ABC calls FN(A, B, C) as a single statement (do-while(0), safe in if/else).
 * LIBXSMM_MMCALL_PRF additionally passes the prefetch locations PA/PB/PC; an operand the
 * configured prefetch scheme does not use is replaced by 0 (and marked as unused).
 */
#define LIBXSMM_MMCALL_ABC(FN, A, B, C) do { \
  LIBXSMM_ASSERT(FN); FN(A, B, C); \
} while (0)
#define LIBXSMM_MMCALL_PRF(FN, A, B, C, PA, PB, PC) { \
  LIBXSMM_NOPREFETCH_A(LIBXSMM_UNUSED(PA)); \
  LIBXSMM_NOPREFETCH_B(LIBXSMM_UNUSED(PB)); \
  LIBXSMM_NOPREFETCH_C(LIBXSMM_UNUSED(PC)); \
  LIBXSMM_ASSERT(FN); FN(A, B, C, \
    LIBXSMM_GEMM_PREFETCH_A(PA), \
    LIBXSMM_GEMM_PREFETCH_B(PB), \
    LIBXSMM_GEMM_PREFETCH_C(PC)); \
}

/**
 * Call with explicit leading dimensions. With prefetch enabled (LIBXSMM_PREFETCH != 0), the
 * prefetch locations are A + LDA*K, B + LDB*N, and C + LDC*N, i.e., presumably the operands of
 * the next call for consecutively stored matrices. Leading dimensions are parenthesized before
 * the size_t conversion so that expressions are converted as a whole.
 */
#if (0/*LIBXSMM_GEMM_PREFETCH_NONE*/ == LIBXSMM_PREFETCH)
# define LIBXSMM_MMCALL_LDX(FN, A, B, C, M, N, K, LDA, LDB, LDC) \
  LIBXSMM_MMCALL_ABC(FN, A, B, C)
#else
# define LIBXSMM_MMCALL_LDX(FN, A, B, C, M, N, K, LDA, LDB, LDC) \
  LIBXSMM_MMCALL_PRF(FN, A, B, C, (A) + ((size_t)(LDA)) * (K), (B) + ((size_t)(LDB)) * (N), (C) + ((size_t)(LDC)) * (N))
#endif
/** Call with tight leading dimensions (LDA=M, LDB=K, LDC=M). */
#define LIBXSMM_MMCALL(FN, A, B, C, M, N, K) LIBXSMM_MMCALL_LDX(FN, A, B, C, M, N, K, M, K, M)
330 
/** Problem size M*N*K, with every factor converted to size_t to cover the general case (no int overflow). */
#define LIBXSMM_MNK_SIZE(M, N, K) (((size_t)(M)) * ((size_t)(N)) * ((size_t)(K)))
/**
 * Total number of matrix elements of A (M x K), B (K x N), and C (M x N),
 * where the C-size is emphasized (weighted) by S.
 */
#define LIBXSMM_SIZE(M, N, K, S) ( \
    ((size_t)(M) * (size_t)(K)) \
  + ((size_t)(K) * (size_t)(N)) \
  + ((size_t)(S) * (size_t)(M) * (size_t)(N)))
/** Arithmetic-intensity (AI) condition: 2*M*N*K <= 4 * ELEMSIZE * LIBXSMM_SIZE(M, N, K, S). */
#define LIBXSMM_SMM_AI(M, N, K, S, ELEMSIZE) \
  ((LIBXSMM_MNK_SIZE(M, N, K) * 2) <= ((size_t)(ELEMSIZE) * 4/*AI*/ * LIBXSMM_SIZE(M, N, K, S)))
/** Determine whether an SMM is suitable, i.e., small enough. */
#if defined(LIBXSMM_THRESHOLD_AI) /* threshold based on arithmetic intensity */
# define LIBXSMM_SMM LIBXSMM_SMM_AI
#else /* traditional MNK-threshold */
# define LIBXSMM_SMM(M, N, K, S, ELEMSIZE) (LIBXSMM_MNK_SIZE(M, N, K) <= (LIBXSMM_MAX_MNK))
#endif
346 
/**
 * Fall-back code paths (macro templates; may be predefined to override the default).
 * LIBXSMM_XGEMM_FALLBACK0 serves small problems without a dispatched kernel,
 * LIBXSMM_XGEMM_FALLBACK1 serves problems beyond the LIBXSMM_SMM threshold.
 * Both default to the BLAS gemm selected by LIBXSMM_BLAS_FUNCTION for the given types.
 */
#ifndef LIBXSMM_XGEMM_FALLBACK0
# define LIBXSMM_XGEMM_FALLBACK0(IT, OT, TA, TB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \
     LIBXSMM_BLAS_FUNCTION(IT, OT, gemm)(TA, TB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
#endif
#ifndef LIBXSMM_XGEMM_FALLBACK1
# define LIBXSMM_XGEMM_FALLBACK1(IT, OT, TA, TB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \
     LIBXSMM_BLAS_FUNCTION(IT, OT, gemm)(TA, TB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
#endif
356 
357 /**
358  * Execute a specialized function, or use a fall-back code path depending on threshold (macro template).
359  * LIBXSMM_XGEMM_FALLBACK0 or specialized function: below LIBXSMM_MAX_MNK
360  * LIBXSMM_XGEMM_FALLBACK1: above LIBXSMM_MAX_MNK
361  */
362 #define LIBXSMM_XGEMM(ITYPE, OTYPE, TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) { \
363   const int libxsmm_xgemm_flags_ = LIBXSMM_GEMM_PFLAGS(TRANSA, TRANSB, LIBXSMM_FLAGS); \
364   const libxsmm_blasint *const libxsmm_xgemm_k_ = (NULL != (K) ? (K) : (M)); \
365   const libxsmm_blasint *const libxsmm_xgemm_n_ = (NULL != (N) ? (N) : libxsmm_xgemm_k_); \
366   const libxsmm_blasint libxsmm_xgemm_lda_ = LIBXSMM_MAX(NULL != ((void*)(LDA)) ? *(LDA) : \
367     *(0 == (LIBXSMM_GEMM_FLAG_TRANS_A & libxsmm_xgemm_flags_) ? (M) : libxsmm_xgemm_k_), 1); \
368   const libxsmm_blasint libxsmm_xgemm_ldb_ = LIBXSMM_MAX(NULL != ((void*)(LDB)) ? *(LDB) : \
369     *(0 == (LIBXSMM_GEMM_FLAG_TRANS_B & libxsmm_xgemm_flags_) ? libxsmm_xgemm_k_ : libxsmm_xgemm_n_), 1); \
370   const libxsmm_blasint libxsmm_xgemm_ldc_ = LIBXSMM_MAX(NULL != (LDC) ? *(LDC) : *(M), 1); \
371   if (LIBXSMM_SMM(*(M), *libxsmm_xgemm_n_, *libxsmm_xgemm_k_, 2/*RFO*/, sizeof(OTYPE))) { \
372     const LIBXSMM_MMFUNCTION_TYPE2(ITYPE, OTYPE) libxsmm_mmfunction_ = LIBXSMM_MMDISPATCH_SYMBOL2(ITYPE, OTYPE)( \
373       *(M), *libxsmm_xgemm_n_, *libxsmm_xgemm_k_, &libxsmm_xgemm_lda_, &libxsmm_xgemm_ldb_, &libxsmm_xgemm_ldc_, \
374       (const OTYPE*)(ALPHA), (const OTYPE*)(BETA), &libxsmm_xgemm_flags_, NULL); \
375     if (NULL != libxsmm_mmfunction_) { \
376       LIBXSMM_MMCALL_LDX(libxsmm_mmfunction_, (const ITYPE*)(A), (const ITYPE*)(B), (OTYPE*)(C), \
377         *(M), *libxsmm_xgemm_n_, *libxsmm_xgemm_k_, libxsmm_xgemm_lda_, libxsmm_xgemm_ldb_, libxsmm_xgemm_ldc_); \
378     } \
379     else { \
380       const char libxsmm_xgemm_transa_ = (char)(0 == (LIBXSMM_GEMM_FLAG_TRANS_A & libxsmm_xgemm_flags_) ? 'n' : 't'); \
381       const char libxsmm_xgemm_transb_ = (char)(0 == (LIBXSMM_GEMM_FLAG_TRANS_B & libxsmm_xgemm_flags_) ? 'n' : 't'); \
382       const OTYPE libxsmm_xgemm_alpha_ = (NULL != ((void*)(ALPHA)) ? (*(const OTYPE*)(ALPHA)) : ((OTYPE)LIBXSMM_ALPHA)); \
383       const OTYPE libxsmm_xgemm_beta_  = (NULL != ((void*)(BETA))  ? (*(const OTYPE*)(BETA))  : ((OTYPE)LIBXSMM_BETA)); \
384       LIBXSMM_XGEMM_FALLBACK0(ITYPE, OTYPE, &libxsmm_xgemm_transa_, &libxsmm_xgemm_transb_, \
385         M, libxsmm_xgemm_n_, libxsmm_xgemm_k_, \
386         &libxsmm_xgemm_alpha_, A, &libxsmm_xgemm_lda_, \
387                                B, &libxsmm_xgemm_ldb_, \
388          &libxsmm_xgemm_beta_, C, &libxsmm_xgemm_ldc_); \
389     } \
390   } \
391   else { \
392     const char libxsmm_xgemm_transa_ = (char)(0 == (LIBXSMM_GEMM_FLAG_TRANS_A & libxsmm_xgemm_flags_) ? 'n' : 't'); \
393     const char libxsmm_xgemm_transb_ = (char)(0 == (LIBXSMM_GEMM_FLAG_TRANS_B & libxsmm_xgemm_flags_) ? 'n' : 't'); \
394     const OTYPE libxsmm_xgemm_alpha_ = (NULL != ((void*)(ALPHA)) ? (*(const OTYPE*)(ALPHA)) : ((OTYPE)LIBXSMM_ALPHA)); \
395     const OTYPE libxsmm_xgemm_beta_  = (NULL != ((void*)(BETA))  ? (*(const OTYPE*)(BETA))  : ((OTYPE)LIBXSMM_BETA)); \
396     LIBXSMM_XGEMM_FALLBACK1(ITYPE, OTYPE, &libxsmm_xgemm_transa_, &libxsmm_xgemm_transb_, \
397       M, libxsmm_xgemm_n_, libxsmm_xgemm_k_, \
398       &libxsmm_xgemm_alpha_, A, &libxsmm_xgemm_lda_, \
399                              B, &libxsmm_xgemm_ldb_, \
400        &libxsmm_xgemm_beta_, C, &libxsmm_xgemm_ldc_); \
401   } \
402 }
403 
/**
 * Helper macro to setup a matrix with some initial values (column-major, NROWS x NCOLS, leading
 * dimension LD). OMP wraps the loop pragma (LIBXSMM_ELIDE: sequential, LIBXSMM_PRAGMA_OMP: OpenMP).
 * SEED != 0: DST[k] = SCALE * (SEED + 1) / (1 + k) with k = i * LD + j for the NROWS rows of
 * column i, and the padding rows [NROWS, LD) are set to SEED.
 * SEED == 0: shuffle based values (via libxsmm_shuffle) normalized to about [-SCALE, +SCALE];
 * with fewer than two elements there is nothing to spread and the values are 0.
 * Macro arguments are parenthesized before conversion, and SEED is evaluated only once.
 */
#define LIBXSMM_MATINIT_AUX(OMP, TYPE, SEED, DST, NROWS, NCOLS, LD, SCALE) { \
  /*const*/ double libxsmm_matinit_seed_ = (double)(SEED); /* avoid constant conditional */ \
  const double libxsmm_matinit_scale_ = (SCALE) * libxsmm_matinit_seed_ + (SCALE); \
  const libxsmm_blasint libxsmm_matinit_ld_ = (libxsmm_blasint)(LD); \
  libxsmm_blasint libxsmm_matinit_i_ = 0, libxsmm_matinit_j_; \
  LIBXSMM_OMP_VAR(libxsmm_matinit_i_); LIBXSMM_OMP_VAR(libxsmm_matinit_j_); \
  if (0 != libxsmm_matinit_seed_) { \
    OMP(parallel for private(libxsmm_matinit_i_, libxsmm_matinit_j_)) \
    for (libxsmm_matinit_i_ = 0; libxsmm_matinit_i_ < ((libxsmm_blasint)(NCOLS)); ++libxsmm_matinit_i_) { \
      for (libxsmm_matinit_j_ = 0; libxsmm_matinit_j_ < ((libxsmm_blasint)(NROWS)); ++libxsmm_matinit_j_) { \
        const libxsmm_blasint libxsmm_matinit_k_ = libxsmm_matinit_i_ * libxsmm_matinit_ld_ + libxsmm_matinit_j_; \
        (DST)[libxsmm_matinit_k_] = (TYPE)(libxsmm_matinit_scale_ / (1.0 + libxsmm_matinit_k_)); \
      } \
      for (; libxsmm_matinit_j_ < libxsmm_matinit_ld_; ++libxsmm_matinit_j_) { /* padding rows */ \
        const libxsmm_blasint libxsmm_matinit_k_ = libxsmm_matinit_i_ * libxsmm_matinit_ld_ + libxsmm_matinit_j_; \
        (DST)[libxsmm_matinit_k_] = (TYPE)libxsmm_matinit_seed_; \
      } \
    } \
  } \
  else { /* shuffle based initialization */ \
    const unsigned int libxsmm_matinit_maxval_ = ((unsigned int)(NCOLS)) * ((unsigned int)libxsmm_matinit_ld_); \
    const TYPE libxsmm_matinit_maxval2_ = (TYPE)(libxsmm_matinit_maxval_ / 2); \
    /* fewer than two elements: avoid division by zero (NaN for floating-point, UB for integer TYPE) */ \
    const TYPE libxsmm_matinit_inv_ = (TYPE)(0 != libxsmm_matinit_maxval2_ ? ((SCALE) / libxsmm_matinit_maxval2_) : 0); \
    const size_t libxsmm_matinit_shuffle_ = libxsmm_shuffle(libxsmm_matinit_maxval_); \
    OMP(parallel for private(libxsmm_matinit_i_, libxsmm_matinit_j_)) \
    for (libxsmm_matinit_i_ = 0; libxsmm_matinit_i_ < ((libxsmm_blasint)(NCOLS)); ++libxsmm_matinit_i_) { \
      for (libxsmm_matinit_j_ = 0; libxsmm_matinit_j_ < libxsmm_matinit_ld_; ++libxsmm_matinit_j_) { \
        const libxsmm_blasint libxsmm_matinit_k_ = libxsmm_matinit_i_ * libxsmm_matinit_ld_ + libxsmm_matinit_j_; \
        (DST)[libxsmm_matinit_k_] = libxsmm_matinit_inv_ * /* normalize values to an interval of [-1, +1] (times SCALE) */ \
          ((TYPE)(libxsmm_matinit_shuffle_ * libxsmm_matinit_k_ % libxsmm_matinit_maxval_) - libxsmm_matinit_maxval2_); \
      } \
    } \
  } \
}

/** Sequential matrix initialization (see LIBXSMM_MATINIT_AUX for the values). */
#define LIBXSMM_MATINIT(TYPE, SEED, DST, NROWS, NCOLS, LD, SCALE) \
  LIBXSMM_MATINIT_AUX(LIBXSMM_ELIDE, TYPE, SEED, DST, NROWS, NCOLS, LD, SCALE)
#define LIBXSMM_MATINIT_SEQ(TYPE, SEED, DST, NROWS, NCOLS, LD, SCALE) \
  LIBXSMM_MATINIT(TYPE, SEED, DST, NROWS, NCOLS, LD, SCALE)
/** OpenMP-parallel matrix initialization (same values as LIBXSMM_MATINIT). */
#define LIBXSMM_MATINIT_OMP(TYPE, SEED, DST, NROWS, NCOLS, LD, SCALE) \
  LIBXSMM_MATINIT_AUX(LIBXSMM_PRAGMA_OMP, TYPE, SEED, DST, NROWS, NCOLS, LD, SCALE)
445 
/**
 * Print a GEMM call described by LIBXSMM's GEMM-flags instead of transpose characters:
 * the TRANS_A/TRANS_B bits of GFLAGS become 't' (set) or 'n' (clear), and the call is
 * forwarded to libxsmm_gemm_dprint2 with separate input/output precision.
 */
#define LIBXSMM_GEMM_PRINT2(STREAM, IPRECISION, OPRECISION, GFLAGS, M, N, K, DALPHA, A, LDA, B, LDB, DBETA, C, LDC) \
  libxsmm_gemm_dprint2(STREAM, (libxsmm_gemm_precision)(IPRECISION), (libxsmm_gemm_precision)(OPRECISION), \
    /* 'n' (not 'N') avoids a warning about "no macro replacement within a character constant" */ \
    (char)(0 != (LIBXSMM_GEMM_FLAG_TRANS_A & (GFLAGS)) ? 't' : 'n'), \
    (char)(0 != (LIBXSMM_GEMM_FLAG_TRANS_B & (GFLAGS)) ? 't' : 'n'), \
    M, N, K, DALPHA, A, LDA, B, LDB, DBETA, C, LDC)
/** Same as LIBXSMM_GEMM_PRINT2 with one PRECISION for input and output. */
#define LIBXSMM_GEMM_PRINT(STREAM, PRECISION, GFLAGS, M, N, K, DALPHA, A, LDA, B, LDB, DBETA, C, LDC) \
  LIBXSMM_GEMM_PRINT2(STREAM, PRECISION, PRECISION, GFLAGS, M, N, K, DALPHA, A, LDA, B, LDB, DBETA, C, LDC)
455 
456 /**
457  * Utility function, which either prints information about the GEMM call
458  * or dumps (FILE/ostream=0) all input and output data into MHD files.
459  * The Meta Image Format (MHD) is suitable for visual inspection using,
460  * e.g., ITK-SNAP or ParaView.
461  */
462 LIBXSMM_API void libxsmm_gemm_print(void* ostream,
463   libxsmm_gemm_precision precision, const char* transa, const char* transb,
464   const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
465   const void* alpha, const void* a, const libxsmm_blasint* lda,
466   const void* b, const libxsmm_blasint* ldb,
467   const void* beta, void* c, const libxsmm_blasint* ldc);
468 LIBXSMM_API void libxsmm_gemm_print2(void* ostream,
469   libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, const char* transa, const char* transb,
470   const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
471   const void* alpha, const void* a, const libxsmm_blasint* lda,
472   const void* b, const libxsmm_blasint* ldb,
473   const void* beta, void* c, const libxsmm_blasint* ldc);
474 LIBXSMM_API void libxsmm_gemm_dprint(void* ostream,
475   libxsmm_gemm_precision precision, char transa, char transb,
476   libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k,
477   double dalpha, const void* a, libxsmm_blasint lda,
478   const void* b, libxsmm_blasint ldb,
479   double dbeta, void* c, libxsmm_blasint ldc);
480 LIBXSMM_API void libxsmm_gemm_dprint2(void* ostream,
481   libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, char transa, char transb,
482   libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k,
483   double dalpha, const void* a, libxsmm_blasint lda,
484   const void* b, libxsmm_blasint ldb,
485   double dbeta, void* c, libxsmm_blasint ldc);
486 LIBXSMM_API void libxsmm_gemm_xprint(void* ostream,
487   libxsmm_xmmfunction kernel, const void* a, const void* b, void* c);
488 
/**
 * GEMM_BATCH: fall-back prototype functions served by any compliant LAPACK/BLAS.
 * All signatures are built by LIBXSMM_BLAS_SYMBOL_SIGNATURE with "const*" for read-only arguments.
 */
LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_dgemm_batch_function)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, double, gemm_batch));
LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_sgemm_batch_function)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, float, gemm_batch));
/** GEMM: fall-back prototype functions served by any compliant LAPACK/BLAS. */
LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_dgemm_function)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, double, gemm));
LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_sgemm_function)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, float,  gemm));
/** GEMV: fall-back prototype functions served by any compliant LAPACK/BLAS. */
LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_dgemv_function)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, double, gemv));
LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_sgemv_function)(LIBXSMM_BLAS_SYMBOL_SIGNATURE(const*, *, float,  gemv));
/**
 * Helper function to consume arguments when called (variadic); this is the type returned by
 * libxsmm_blas_error, which stands in for BLAS functions when no BLAS is available.
 */
LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_sink_function)(LIBXSMM_VARIADIC);
500 
/**
 * The original BLAS functions: public variables holding pointers to the
 * gemm_batch, gemm, and gemv routines in double and single precision.
 * NOTE(review): "original" presumably means the BLAS implementation that
 * LIBXSMM would otherwise intercept or wrap -- confirm.
 * The volatile qualifier is deliberately left commented out (TODO: record why).
 */
LIBXSMM_APIVAR_PUBLIC(/*volatile*/libxsmm_dgemm_batch_function libxsmm_original_dgemm_batch_function);
LIBXSMM_APIVAR_PUBLIC(/*volatile*/libxsmm_sgemm_batch_function libxsmm_original_sgemm_batch_function);
LIBXSMM_APIVAR_PUBLIC(/*volatile*/libxsmm_dgemm_function libxsmm_original_dgemm_function);
LIBXSMM_APIVAR_PUBLIC(/*volatile*/libxsmm_sgemm_function libxsmm_original_sgemm_function);
LIBXSMM_APIVAR_PUBLIC(/*volatile*/libxsmm_dgemv_function libxsmm_original_dgemv_function);
LIBXSMM_APIVAR_PUBLIC(/*volatile*/libxsmm_sgemv_function libxsmm_original_sgemv_function);
/**
 * Accessors returning a pointer to the corresponding original BLAS routine
 * (presumably resolving it lazily on first use -- TODO confirm).
 */
LIBXSMM_API libxsmm_dgemm_batch_function libxsmm_original_dgemm_batch(void);
LIBXSMM_API libxsmm_sgemm_batch_function libxsmm_original_sgemm_batch(void);
LIBXSMM_API libxsmm_dgemm_function libxsmm_original_dgemm(void);
LIBXSMM_API libxsmm_sgemm_function libxsmm_original_sgemm(void);
LIBXSMM_API libxsmm_dgemv_function libxsmm_original_dgemv(void);
LIBXSMM_API libxsmm_sgemv_function libxsmm_original_sgemv(void);
/**
 * Returns a libxsmm_sink_function for the given BLAS symbol name
 * (presumably used to report a missing BLAS symbol -- TODO confirm).
 */
LIBXSMM_API libxsmm_sink_function libxsmm_blas_error(const char* symbol);
/** Consumes the given (LIBXSMM_VARIADIC) arguments; matches libxsmm_sink_function. */
LIBXSMM_API void libxsmm_sink(LIBXSMM_VARIADIC);
516 
517 /**
518  * General dense matrix multiplication, which re-exposes LAPACK/BLAS
519  * but allows to rely on LIBXSMM's defaults (libxsmm_config.h)
520  * when supplying NULL-arguments in certain places.
521  */
522 LIBXSMM_API void libxsmm_blas_xgemm(libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec,
523   const char* transa, const char* transb, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
524   const void* alpha, const void* a, const libxsmm_blasint* lda,
525   const void* b, const libxsmm_blasint* ldb,
526   const void* beta, void* c, const libxsmm_blasint* ldc);
527 
/**
 * Double-precision shorthand for libxsmm_blas_xgemm: both the input and the
 * output precision are fixed to LIBXSMM_GEMM_PRECISION_F64, and all other
 * arguments are forwarded unchanged in BLAS ?gemm order.
 */
#define libxsmm_blas_dgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \
  libxsmm_blas_xgemm( \
    LIBXSMM_GEMM_PRECISION_F64/*iprec*/, LIBXSMM_GEMM_PRECISION_F64/*oprec*/, \
    TRANSA, TRANSB, \
    M, N, K, \
    ALPHA, A, LDA, \
    B, LDB, \
    BETA, C, LDC)
/**
 * Single-precision shorthand for libxsmm_blas_xgemm: input and output
 * precision are both LIBXSMM_GEMM_PRECISION_F32; everything else is
 * forwarded unchanged.
 */
#define libxsmm_blas_sgemm(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \
  libxsmm_blas_xgemm( \
    LIBXSMM_GEMM_PRECISION_F32/*iprec*/, LIBXSMM_GEMM_PRECISION_F32/*oprec*/, \
    TRANSA, TRANSB, \
    M, N, K, \
    ALPHA, A, LDA, \
    B, LDB, \
    BETA, C, LDC)
534 
/**
 * Double-precision shorthand for libxsmm_xgemm_omp: input and output
 * precision are both LIBXSMM_GEMM_PRECISION_F64; the remaining arguments
 * are forwarded unchanged in BLAS ?gemm order.
 * NOTE(review): libxsmm_xgemm_omp is not declared in this header; presumably
 * it comes from LIBXSMM's OpenMP extension library -- confirm before use.
 */
#define libxsmm_dgemm_omp(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \
  libxsmm_xgemm_omp( \
    LIBXSMM_GEMM_PRECISION_F64/*iprec*/, LIBXSMM_GEMM_PRECISION_F64/*oprec*/, \
    TRANSA, TRANSB, \
    M, N, K, \
    ALPHA, A, LDA, \
    B, LDB, \
    BETA, C, LDC)
/** Single-precision counterpart: forwards to libxsmm_xgemm_omp with F32 input and output precision. */
#define libxsmm_sgemm_omp(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \
  libxsmm_xgemm_omp( \
    LIBXSMM_GEMM_PRECISION_F32/*iprec*/, LIBXSMM_GEMM_PRECISION_F32/*oprec*/, \
    TRANSA, TRANSB, \
    M, N, K, \
    ALPHA, A, LDA, \
    B, LDB, \
    BETA, C, LDC)
541 
/**
 * Translates a GEMM prefetch request into the prefetch enumeration
 * (libxsmm_gemm_prefetch_type), including the frontend's auto-prefetch.
 * libxsmm_get_gemm_xprefetch takes the request by pointer (presumably NULL
 * selects a default -- TODO confirm); libxsmm_get_gemm_prefetch takes it by value.
 */
LIBXSMM_API libxsmm_gemm_prefetch_type libxsmm_get_gemm_xprefetch(const int* prefetch);
LIBXSMM_API libxsmm_gemm_prefetch_type libxsmm_get_gemm_prefetch(int prefetch);
545 
546 #endif /*LIBXSMM_FRONTEND_H*/
547 
548