1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Hans Pabst (Intel Corp.)
10 ******************************************************************************/
11 #include "libxsmm_gemm.h"
12 #include "libxsmm_xcopy.h"
13 #include "libxsmm_hash.h"
14 #include <libxsmm_mhd.h>
15 
16 #if defined(LIBXSMM_OFFLOAD_TARGET)
17 # pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
18 #endif
19 #if !defined(LIBXSMM_NO_LIBM)
20 # include <math.h>
21 #endif
22 #if defined(LIBXSMM_OFFLOAD_TARGET)
23 # pragma offload_attribute(pop)
24 #endif
25 
26 #if !defined(LIBXSMM_GEMM_NOJIT_TRANS) && \
27   /* TODO: fully support calling convention */ \
28   (defined(_WIN32) || defined(__CYGWIN__))
29 # define LIBXSMM_GEMM_NOJIT_TRANS
30 #endif
31 #if !defined(LIBXSMM_GEMM_KPARALLEL) && 0
32 # define LIBXSMM_GEMM_KPARALLEL
33 #endif
34 #if !defined(LIBXSMM_GEMM_BATCHSIZE)
35 # define LIBXSMM_GEMM_BATCHSIZE 1024
36 #endif
37 #if !defined(LIBXSMM_GEMM_TASKGRAIN)
38 # define LIBXSMM_GEMM_TASKGRAIN 128
39 #endif
40 #if !defined(LIBXSMM_GEMM_BATCHREDUCE) && !defined(_WIN32) && !defined(__CYGWIN__) /* not supported */
41 # define LIBXSMM_GEMM_BATCHREDUCE
42 #endif
43 #if !defined(LIBXSMM_GEMM_BATCHSCALE) && (defined(LIBXSMM_GEMM_BATCHREDUCE) || defined(LIBXSMM_WRAP))
44 # define LIBXSMM_GEMM_BATCHSCALE ((unsigned int)LIBXSMM_ROUND(sizeof(libxsmm_mmbatch_item) * (LIBXSMM_GEMM_MMBATCH_SCALE)))
45 #endif
46 #if defined(LIBXSMM_BUILD)
47 # define LIBXSMM_GEMM_WEAK LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK
48 #else
49 # define LIBXSMM_GEMM_WEAK LIBXSMM_API
50 #endif
51 
52 #if (0 != LIBXSMM_SYNC) /** Locks for the batch interface (duplicated C indexes). */
53 # define LIBXSMM_GEMM_LOCKIDX(IDX, NPOT) LIBXSMM_MOD2(LIBXSMM_CRC32U(LIBXSMM_BLASINT_NBITS)(2507/*seed*/, &(IDX)), NPOT)
54 # define LIBXSMM_GEMM_LOCKPTR(PTR, NPOT) LIBXSMM_MOD2(LIBXSMM_CRC32U(LIBXSMM_BITS)(1975/*seed*/, &(PTR)), NPOT)
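/* Note: LIBXSMM_GEMM_LOCKIDX/LIBXSMM_GEMM_LOCKPTR hash a C-index or C-pointer (CRC32 with a fixed
 * seed) and reduce the result modulo a power-of-two lock count (NPOT), so updates targeting the
 * same C-block serialize on the same lock slot while unrelated updates spread across the pool. */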
55 # if !defined(LIBXSMM_GEMM_MAXNLOCKS)
56 #   define LIBXSMM_GEMM_MAXNLOCKS 1024
57 # endif
58 # if !defined(LIBXSMM_GEMM_LOCKFWD)
59 #   define LIBXSMM_GEMM_LOCKFWD
60 # endif
61 # if LIBXSMM_LOCK_TYPE_ISPOD(LIBXSMM_GEMM_LOCK)
62 LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE internal_gemm_locktype {
63   char pad[LIBXSMM_CACHELINE];
64   LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK) state;
65 } internal_gemm_locktype;
66 # else
67 LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE internal_gemm_locktype {
68   LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK) state;
69 } internal_gemm_locktype;
70 # endif
71 LIBXSMM_APIVAR_DEFINE(internal_gemm_locktype internal_gemm_lock[LIBXSMM_GEMM_MAXNLOCKS]);
72 LIBXSMM_APIVAR_DEFINE(unsigned int internal_gemm_nlocks); /* populated number of locks */
73 #endif
74 
75 /* definition of corresponding variables */
76 LIBXSMM_APIVAR_PUBLIC_DEF(/*volatile*/libxsmm_dgemm_batch_function libxsmm_original_dgemm_batch_function);
77 LIBXSMM_APIVAR_PUBLIC_DEF(/*volatile*/libxsmm_sgemm_batch_function libxsmm_original_sgemm_batch_function);
78 LIBXSMM_APIVAR_PUBLIC_DEF(/*volatile*/libxsmm_dgemm_function libxsmm_original_dgemm_function);
79 LIBXSMM_APIVAR_PUBLIC_DEF(/*volatile*/libxsmm_sgemm_function libxsmm_original_sgemm_function);
80 LIBXSMM_APIVAR_PUBLIC_DEF(/*volatile*/libxsmm_dgemv_function libxsmm_original_dgemv_function);
81 LIBXSMM_APIVAR_PUBLIC_DEF(/*volatile*/libxsmm_sgemv_function libxsmm_original_sgemv_function);
82 /* definition of corresponding variables */
83 LIBXSMM_APIVAR_PUBLIC_DEF(libxsmm_gemm_descriptor libxsmm_mmbatch_desc);
84 LIBXSMM_APIVAR_PUBLIC_DEF(void* libxsmm_mmbatch_array);
85 LIBXSMM_APIVAR_PUBLIC_DEF(LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK) libxsmm_mmbatch_lock);
86 LIBXSMM_APIVAR_PUBLIC_DEF(unsigned int libxsmm_mmbatch_size);
87 LIBXSMM_APIVAR_PUBLIC_DEF(unsigned int libxsmm_gemm_npargroups);
88 LIBXSMM_APIVAR_PUBLIC_DEF(unsigned int libxsmm_gemm_taskgrain);
89 LIBXSMM_APIVAR_PUBLIC_DEF(int libxsmm_gemm_tasks);
90 LIBXSMM_APIVAR_PUBLIC_DEF(int libxsmm_gemm_wrap);
91 
92 LIBXSMM_APIVAR_PRIVATE_DEF(libxsmm_gemm_prefetch_type libxsmm_gemm_auto_prefetch_default);
93 /** Determines the prefetch strategy, which is used in case of LIBXSMM_PREFETCH_AUTO. */
94 LIBXSMM_APIVAR_PRIVATE_DEF(libxsmm_gemm_prefetch_type libxsmm_gemm_auto_prefetch);
95 
96 /** Prefetch strategy for tiled GEMM. */
97 LIBXSMM_APIVAR_DEFINE(libxsmm_gemm_prefetch_type internal_gemm_tiled_prefetch);
98 /** Vector width used for GEMM. */
99 LIBXSMM_APIVAR_DEFINE(unsigned int internal_gemm_vwidth);
100 /** Limit the M-extent of the tile. */
101 LIBXSMM_APIVAR_DEFINE(unsigned int internal_gemm_mlimit);
102 /** Stretch of the tile's N-extent relative to its M-extent (tile shape). */
103 LIBXSMM_APIVAR_DEFINE(float internal_gemm_nstretch);
104 /** Stretch of the tile's K-extent relative to its M-extent (tile shape). */
105 LIBXSMM_APIVAR_DEFINE(float internal_gemm_kstretch);
106 /** Determines if batch-reduce is enabled. */
107 LIBXSMM_APIVAR_DEFINE(int internal_gemm_batchreduce);
108 
109 
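/* The __real_* definitions below are the fall-through targets of the GEMM interception layer
 * (linker --wrap convention, see LIBXSMM_WRAP): intercepted BLAS calls can be redirected while
 * __real_* still reaches the original BLAS implementation, or an error sink when no BLAS is
 * linked (0 == LIBXSMM_BLAS). */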
110 #if defined(LIBXSMM_BUILD)
111 LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void LIBXSMM_FSYMBOL(__real_dgemm_batch)(
112   const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
113   const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[], const double* b_array[], const libxsmm_blasint ldb_array[],
114   const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[])
115 {
116 #if (0 != LIBXSMM_BLAS)
117 # if defined(LIBXSMM_WRAP) && (0 > LIBXSMM_WRAP)
118   if (0 > libxsmm_gemm_wrap) {
119     LIBXSMM_FSYMBOL(dgemm_batch)(transa_array, transb_array, m_array, n_array, k_array,
120       alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
121       group_count, group_size);
122   }
123   else
124 # endif
125   {
126     const libxsmm_blasint ptrsize = sizeof(void*);
127     libxsmm_blasint i, j = 0;
128     LIBXSMM_ASSERT(NULL != transa_array && NULL != transb_array && NULL != group_count && NULL != group_size);
129     LIBXSMM_ASSERT(NULL != m_array && NULL != n_array && NULL != k_array && NULL != lda_array && NULL != ldb_array && NULL != ldc_array);
130     LIBXSMM_ASSERT(NULL != a_array && NULL != b_array && NULL != c_array && NULL != alpha_array && NULL != beta_array);
131     for (i = 0; i < *group_count; ++i) {
132       const libxsmm_blasint size = group_size[i];
133       libxsmm_dmmbatch_blas(transa_array + i, transb_array + i, m_array[i], n_array[i], k_array[i], alpha_array + i,
134         a_array + j, lda_array + i, b_array + j, ldb_array + i, beta_array + i,
135         c_array + j, ldc_array + i, 0/*index_base*/, 0/*index_stride*/, &ptrsize, &ptrsize, &ptrsize, size);
136       j += size;
137     }
138   }
139 #else
140   libxsmm_blas_error("dgemm_batch")(transa_array, transb_array, m_array, n_array, k_array,
141     alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
142     group_count, group_size);
143 #endif
144 }
145 
146 
147 LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void LIBXSMM_FSYMBOL(__real_sgemm_batch)(
148   const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
149   const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[], const float* b_array[], const libxsmm_blasint ldb_array[],
150   const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[])
151 {
152 #if (0 != LIBXSMM_BLAS)
153 # if defined(LIBXSMM_WRAP) && (0 > LIBXSMM_WRAP)
154   if (0 > libxsmm_gemm_wrap) {
155     LIBXSMM_FSYMBOL(sgemm_batch)(transa_array, transb_array, m_array, n_array, k_array,
156       alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
157       group_count, group_size);
158   }
159   else
160 # endif
161   {
162     const libxsmm_blasint ptrsize = sizeof(void*);
163     libxsmm_blasint i, j = 0;
164     LIBXSMM_ASSERT(NULL != transa_array && NULL != transb_array && NULL != group_count && NULL != group_size);
165     LIBXSMM_ASSERT(NULL != m_array && NULL != n_array && NULL != k_array && NULL != lda_array && NULL != ldb_array && NULL != ldc_array);
166     LIBXSMM_ASSERT(NULL != a_array && NULL != b_array && NULL != c_array && NULL != alpha_array && NULL != beta_array);
167     for (i = 0; i < *group_count; ++i) {
168       const libxsmm_blasint size = group_size[i];
169       libxsmm_smmbatch_blas(transa_array + i, transb_array + i, m_array[i], n_array[i], k_array[i], alpha_array + i,
170         a_array + j, lda_array + i, b_array + j, ldb_array + i, beta_array + i,
171         c_array + j, ldc_array + i, 0/*index_base*/, 0/*index_stride*/, &ptrsize, &ptrsize, &ptrsize, size);
172       j += size;
173     }
174   }
175 #else
176   libxsmm_blas_error("sgemm_batch")(transa_array, transb_array, m_array, n_array, k_array,
177     alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
178     group_count, group_size);
179 #endif
180 }
181 
182 
183 LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void LIBXSMM_FSYMBOL(__real_dgemm)(const char* transa, const char* transb,
184   const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
185   const double* alpha, const double* a, const libxsmm_blasint* lda,
186   const double* b, const libxsmm_blasint* ldb,
187   const double* beta, double* c, const libxsmm_blasint* ldc)
188 {
189 #if (0 != LIBXSMM_BLAS)
190   LIBXSMM_FSYMBOL(dgemm)((LIBXSMM_BLAS_CONST char*)transa, (LIBXSMM_BLAS_CONST char*)transb,
191     (LIBXSMM_BLAS_CONST libxsmm_blasint*)m, (LIBXSMM_BLAS_CONST libxsmm_blasint*)n, (LIBXSMM_BLAS_CONST libxsmm_blasint*)k,
192     (LIBXSMM_BLAS_CONST double*)alpha, (LIBXSMM_BLAS_CONST double*)a, (LIBXSMM_BLAS_CONST libxsmm_blasint*)lda,
193                                        (LIBXSMM_BLAS_CONST double*)b, (LIBXSMM_BLAS_CONST libxsmm_blasint*)ldb,
194     (LIBXSMM_BLAS_CONST double*) beta,                             c, (LIBXSMM_BLAS_CONST libxsmm_blasint*)ldc);
195 #else
196   libxsmm_blas_error("dgemm")(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
197 #endif
198 }
199 
200 
201 LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void LIBXSMM_FSYMBOL(__real_sgemm)(const char* transa, const char* transb,
202   const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
203   const float* alpha, const float* a, const libxsmm_blasint* lda,
204   const float* b, const libxsmm_blasint* ldb,
205   const float* beta, float* c, const libxsmm_blasint* ldc)
206 {
207 #if (0 != LIBXSMM_BLAS)
208   LIBXSMM_FSYMBOL(sgemm)((LIBXSMM_BLAS_CONST char*)transa, (LIBXSMM_BLAS_CONST char*)transb,
209     (LIBXSMM_BLAS_CONST libxsmm_blasint*)m, (LIBXSMM_BLAS_CONST libxsmm_blasint*)n, (LIBXSMM_BLAS_CONST libxsmm_blasint*)k,
210     (LIBXSMM_BLAS_CONST float*)alpha, (LIBXSMM_BLAS_CONST float*)a, (LIBXSMM_BLAS_CONST libxsmm_blasint*)lda,
211                                       (LIBXSMM_BLAS_CONST float*)b, (LIBXSMM_BLAS_CONST libxsmm_blasint*)ldb,
212     (LIBXSMM_BLAS_CONST float*) beta,                            c, (LIBXSMM_BLAS_CONST libxsmm_blasint*)ldc);
213 #else
214   libxsmm_blas_error("sgemm")(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
215 #endif
216 }
217 
218 
219 LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void LIBXSMM_FSYMBOL(__real_dgemv)(const char* trans, const libxsmm_blasint* m, const libxsmm_blasint* n,
220   const double* alpha, const double* a, const libxsmm_blasint* lda, const double* x, const libxsmm_blasint* incx,
221   const double* beta, double* y, const libxsmm_blasint* incy)
222 {
223 #if (0 != LIBXSMM_BLAS)
224   LIBXSMM_FSYMBOL(dgemv)((LIBXSMM_BLAS_CONST char*)trans, (LIBXSMM_BLAS_CONST libxsmm_blasint*)m, (LIBXSMM_BLAS_CONST libxsmm_blasint*)n,
225     (LIBXSMM_BLAS_CONST double*)alpha, (LIBXSMM_BLAS_CONST double*)a, (LIBXSMM_BLAS_CONST libxsmm_blasint*)lda,
226                                        (LIBXSMM_BLAS_CONST double*)x, (LIBXSMM_BLAS_CONST libxsmm_blasint*)incx,
227     (LIBXSMM_BLAS_CONST double*) beta,                             y, (LIBXSMM_BLAS_CONST libxsmm_blasint*)incy);
228 #else
229   libxsmm_blas_error("dgemv")(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
230 #endif
231 }
232 
233 
234 LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void LIBXSMM_FSYMBOL(__real_sgemv)(const char* trans, const libxsmm_blasint* m, const libxsmm_blasint* n,
235   const float* alpha, const float* a, const libxsmm_blasint* lda, const float* x, const libxsmm_blasint* incx,
236   const float* beta, float* y, const libxsmm_blasint* incy)
237 {
238 #if (0 != LIBXSMM_BLAS)
239   LIBXSMM_FSYMBOL(sgemv)((LIBXSMM_BLAS_CONST char*)trans, (LIBXSMM_BLAS_CONST libxsmm_blasint*)m, (LIBXSMM_BLAS_CONST libxsmm_blasint*)n,
240     (LIBXSMM_BLAS_CONST float*)alpha, (LIBXSMM_BLAS_CONST float*)a, (LIBXSMM_BLAS_CONST libxsmm_blasint*)lda,
241                                       (LIBXSMM_BLAS_CONST float*)x, (LIBXSMM_BLAS_CONST libxsmm_blasint*)incx,
242     (LIBXSMM_BLAS_CONST float*) beta,                            y, (LIBXSMM_BLAS_CONST libxsmm_blasint*)incy);
243 #else
244   libxsmm_blas_error("sgemv")(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
245 #endif
246 }
247 
248 
249 LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void __real_dgemm_batch(
250   const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
251   const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[], const double* b_array[], const libxsmm_blasint ldb_array[],
252   const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[])
253 {
254   LIBXSMM_FSYMBOL(__real_dgemm_batch)(transa_array, transb_array, m_array, n_array, k_array,
255     alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
256     group_count, group_size);
257 }
258 
259 
260 LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void __real_sgemm_batch(
261   const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
262   const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[], const float* b_array[], const libxsmm_blasint ldb_array[],
263   const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[])
264 {
265   LIBXSMM_FSYMBOL(__real_sgemm_batch)(transa_array, transb_array, m_array, n_array, k_array,
266     alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
267     group_count, group_size);
268 }
269 #endif /*defined(LIBXSMM_BUILD)*/
270 
271 
272 LIBXSMM_GEMM_WEAK libxsmm_dgemm_batch_function libxsmm_original_dgemm_batch(void)
273 {
274 #if (0 != LIBXSMM_BLAS) && defined(LIBXSMM_WRAP) && (0 > LIBXSMM_WRAP)
275   LIBXSMM_BLAS_WRAPPER(1, double, gemm_batch, libxsmm_original_dgemm_batch_function, NULL/*unknown*/);
276   /*LIBXSMM_ASSERT(NULL != libxsmm_original_dgemm_batch_function);*/
277 #else
278   LIBXSMM_BLAS_WRAPPER(0, double, gemm_batch, libxsmm_original_dgemm_batch_function, NULL/*unknown*/);
279 #endif
280   return libxsmm_original_dgemm_batch_function;
281 }
282 
283 
284 LIBXSMM_GEMM_WEAK libxsmm_sgemm_batch_function libxsmm_original_sgemm_batch(void)
285 {
286 #if (0 != LIBXSMM_BLAS) && defined(LIBXSMM_WRAP) && (0 > LIBXSMM_WRAP)
287   LIBXSMM_BLAS_WRAPPER(1, float, gemm_batch, libxsmm_original_sgemm_batch_function, NULL/*unknown*/);
288   /*LIBXSMM_ASSERT(NULL != libxsmm_original_sgemm_batch_function);*/
289 #else
290   LIBXSMM_BLAS_WRAPPER(0, float, gemm_batch, libxsmm_original_sgemm_batch_function, NULL/*unknown*/);
291 #endif
292   return libxsmm_original_sgemm_batch_function;
293 }
294 
295 
296 LIBXSMM_GEMM_WEAK libxsmm_dgemm_function libxsmm_original_dgemm(void)
297 {
298 #if (0 != LIBXSMM_BLAS)
299   LIBXSMM_BLAS_WRAPPER(1, double, gemm, libxsmm_original_dgemm_function, NULL/*unknown*/);
300   LIBXSMM_ASSERT(NULL != libxsmm_original_dgemm_function);
301 #else
302   LIBXSMM_BLAS_WRAPPER(0, double, gemm, libxsmm_original_dgemm_function, NULL/*unknown*/);
303 #endif
304   return libxsmm_original_dgemm_function;
305 }
306 
307 
308 LIBXSMM_GEMM_WEAK libxsmm_sgemm_function libxsmm_original_sgemm(void)
309 {
310 #if (0 != LIBXSMM_BLAS)
311   LIBXSMM_BLAS_WRAPPER(1, float, gemm, libxsmm_original_sgemm_function, NULL/*unknown*/);
312   LIBXSMM_ASSERT(NULL != libxsmm_original_sgemm_function);
313 #else
314   LIBXSMM_BLAS_WRAPPER(0, float, gemm, libxsmm_original_sgemm_function, NULL/*unknown*/);
315 #endif
316   return libxsmm_original_sgemm_function;
317 }
318 
319 
320 LIBXSMM_GEMM_WEAK libxsmm_dgemv_function libxsmm_original_dgemv(void)
321 {
322 #if (0 != LIBXSMM_BLAS)
323   LIBXSMM_BLAS_WRAPPER(1, double, gemv, libxsmm_original_dgemv_function, NULL/*unknown*/);
324   LIBXSMM_ASSERT(NULL != libxsmm_original_dgemv_function);
325 #else
326   LIBXSMM_BLAS_WRAPPER(0, double, gemv, libxsmm_original_dgemv_function, NULL/*unknown*/);
327 #endif
328   return libxsmm_original_dgemv_function;
329 }
330 
331 
332 LIBXSMM_GEMM_WEAK libxsmm_sgemv_function libxsmm_original_sgemv(void)
333 {
334 #if (0 != LIBXSMM_BLAS)
335   LIBXSMM_BLAS_WRAPPER(1, float, gemv, libxsmm_original_sgemv_function, NULL/*unknown*/);
336   LIBXSMM_ASSERT(NULL != libxsmm_original_sgemv_function);
337 #else
338   LIBXSMM_BLAS_WRAPPER(0, float, gemv, libxsmm_original_sgemv_function, NULL/*unknown*/);
339 #endif
340   return libxsmm_original_sgemv_function;
341 }
342 
343 
344 LIBXSMM_API libxsmm_sink_function libxsmm_blas_error(const char* symbol)
345 {
346   static int error_once = 0;
347   LIBXSMM_BLAS_ERROR(symbol, &error_once);
348   return libxsmm_sink;
349 }
350 
351 
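/* Initialization reads a number of (optional) environment variables: LIBXSMM_GEMM_WRAP,
 * LIBXSMM_TGEMM_PREFETCH, LIBXSMM_GEMM_NLOCKS, LIBXSMM_GEMM_BATCHREDUCE, LIBXSMM_GEMM_BATCHSIZE,
 * LIBXSMM_GEMM_NPARGROUPS, LIBXSMM_TGEMM_M/N/K, LIBXSMM_TGEMM_NS/KS, LIBXSMM_GEMM_TASKS, and
 * LIBXSMM_GEMM_TASKGRAIN; unset or empty variables fall back to the compile-time defaults. */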
352 LIBXSMM_API_INTERN void libxsmm_gemm_init(int archid)
353 {
354   const char* env_w = getenv("LIBXSMM_GEMM_WRAP");
355   LIBXSMM_LOCK_ATTR_TYPE(LIBXSMM_GEMM_LOCK) attr;
356   LIBXSMM_LOCK_ATTR_INIT(LIBXSMM_GEMM_LOCK, &attr);
357 #if defined(LIBXSMM_WRAP) /* determines if wrap is considered */
358   { /* intercepted GEMMs (1: sequential and non-tiled, 2: parallelized and tiled) */
359 # if defined(__STATIC) /* with static library the user controls interceptor already */
360     libxsmm_gemm_wrap = ((NULL == env_w || 0 == *env_w) /* LIBXSMM_WRAP=0: no promotion */
361       ? (0 < (LIBXSMM_WRAP) ? (LIBXSMM_WRAP + 2) : (LIBXSMM_WRAP - 2)) : atoi(env_w));
362 # else
363     libxsmm_gemm_wrap = ((NULL == env_w || 0 == *env_w) ? (LIBXSMM_WRAP) : atoi(env_w));
364 # endif
365   }
366 #endif
367   { /* setup prefetch strategy for tiled GEMMs */
368     const char *const env_p = getenv("LIBXSMM_TGEMM_PREFETCH");
369     const libxsmm_gemm_prefetch_type tiled_prefetch_default = LIBXSMM_GEMM_PREFETCH_AL2_AHEAD;
370     const int uid = ((NULL == env_p || 0 == *env_p) ? LIBXSMM_PREFETCH_AUTO/*default*/ : atoi(env_p));
371     internal_gemm_tiled_prefetch = (0 <= uid ? libxsmm_gemm_uid2prefetch(uid) : tiled_prefetch_default);
372   }
373 #if (0 != LIBXSMM_SYNC)
374   { /* initialize locks for the batch interface */
375     const char *const env_locks = getenv("LIBXSMM_GEMM_NLOCKS");
376     const int nlocks = ((NULL == env_locks || 0 == *env_locks) ? -1/*default*/ : atoi(env_locks));
377     unsigned int i;
378     internal_gemm_nlocks = LIBXSMM_UP2POT(0 > nlocks ? (LIBXSMM_GEMM_MAXNLOCKS) : LIBXSMM_MIN(nlocks, LIBXSMM_GEMM_MAXNLOCKS));
379     for (i = 0; i < internal_gemm_nlocks; ++i) LIBXSMM_LOCK_INIT(LIBXSMM_GEMM_LOCK, &internal_gemm_lock[i].state, &attr);
380   }
381 #endif
382 #if defined(LIBXSMM_GEMM_BATCHREDUCE) || defined(LIBXSMM_WRAP)
383   { /* determines if batch-reduce kernel or batch-wrap is considered */
384     const char *const env_r = getenv("LIBXSMM_GEMM_BATCHREDUCE");
385     internal_gemm_batchreduce = (NULL == env_r || 0 == *env_r) ? 0 : atoi(env_r);
386     if ((NULL == env_w || 0 == *env_w) && ((LIBXSMM_GEMM_MMBATCH_VERBOSITY <= libxsmm_verbosity && INT_MAX != libxsmm_verbosity) || 0 > libxsmm_verbosity)) {
387       libxsmm_mmbatch_desc.flags = LIBXSMM_MMBATCH_FLAG_STATISTIC; /* enable auto-batch statistic */
388       internal_gemm_batchreduce = 0;
389     }
390     if (0 != internal_gemm_batchreduce || 0 != libxsmm_gemm_wrap) {
391       const char *const env_b = getenv("LIBXSMM_GEMM_BATCHSIZE");
392       const int env_bi = (NULL == env_b || 0 == *env_b) ? -1/*auto*/ : atoi(env_b);
393       const unsigned int env_bu = (unsigned int)(0 >= env_bi ? (LIBXSMM_GEMM_BATCHSIZE) : env_bi);
394       const unsigned int batchscale = LIBXSMM_ABS(internal_gemm_batchreduce) * 2048/*arbitrary*/ * 2/*A and B-matrices*/ * sizeof(void*);
395       const unsigned int minsize = LIBXSMM_UPDIV(batchscale * env_bu, LIBXSMM_GEMM_BATCHSCALE);
396       const unsigned int batchsize = LIBXSMM_MAX(env_bu, minsize);
397       const void *const extra = NULL;
398       LIBXSMM_ASSERT(1 < (LIBXSMM_GEMM_MMBATCH_SCALE) && NULL == libxsmm_mmbatch_array);
399       if (EXIT_SUCCESS == libxsmm_xmalloc(&libxsmm_mmbatch_array, (size_t)batchsize * (LIBXSMM_GEMM_BATCHSCALE), 0/*auto-alignment*/,
400         LIBXSMM_MALLOC_FLAG_PRIVATE /*| LIBXSMM_MALLOC_FLAG_SCRATCH*/, &extra, sizeof(extra)))
401       {
402         LIBXSMM_LOCK_INIT(LIBXSMM_GEMM_LOCK, &libxsmm_mmbatch_lock, &attr);
403         LIBXSMM_ASSERT(NULL != libxsmm_mmbatch_array);
404         libxsmm_mmbatch_size = batchsize;
405       }
406     }
407   }
408 #else
409   LIBXSMM_UNUSED(env_w);
410 #endif
411   { /* determines the number of parallelized batch-groups */
412     const char *const env_s = getenv("LIBXSMM_GEMM_NPARGROUPS");
413     libxsmm_gemm_npargroups = ((NULL == env_s || 0 == *env_s || 0 >= atoi(env_s))
414       ? (LIBXSMM_GEMM_NPARGROUPS) : atoi(env_s));
415   }
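  /* Per-ISA defaults for the tiling heuristic: vector width in bytes, the upper limit of a tile's
   * M-extent, and the N/K stretch factors relative to the M-extent. */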
416   if (LIBXSMM_X86_AVX512_CORE <= archid) {
417     internal_gemm_vwidth = 64;
418     internal_gemm_mlimit = 48;
419     internal_gemm_nstretch = 3.0f;
420     internal_gemm_kstretch = 2.0f;
421   }
422   else if (LIBXSMM_X86_AVX512_MIC <= archid) {
423     internal_gemm_vwidth = 64;
424     internal_gemm_mlimit = 64;
425     internal_gemm_nstretch = 1.0f;
426     internal_gemm_kstretch = 1.0f;
427   }
428   else if (LIBXSMM_X86_AVX2 <= archid) {
429     internal_gemm_vwidth = 32;
430     internal_gemm_mlimit = 48;
431     internal_gemm_nstretch = 3.0f;
432     internal_gemm_kstretch = 2.0f;
433   }
434   else if (LIBXSMM_X86_AVX <= archid) {
435     internal_gemm_vwidth = 32;
436     internal_gemm_mlimit = 48;
437     internal_gemm_nstretch = 5.0f;
438     internal_gemm_kstretch = 1.0f;
439   }
440   else {
441     internal_gemm_vwidth = 16;
442     internal_gemm_mlimit = 48;
443     internal_gemm_nstretch = 7.0f;
444     internal_gemm_kstretch = 5.0f;
445   }
446   { /* setup tile sizes according to environment (LIBXSMM_TGEMM_M, LIBXSMM_TGEMM_N, LIBXSMM_TGEMM_K) */
447     const char *const env_m = getenv("LIBXSMM_TGEMM_M"), *const env_n = getenv("LIBXSMM_TGEMM_N"), *const env_k = getenv("LIBXSMM_TGEMM_K");
448     const int m = ((NULL == env_m || 0 == *env_m) ? 0 : atoi(env_m));
449     const int n = ((NULL == env_n || 0 == *env_n) ? 0 : atoi(env_n));
450     const int k = ((NULL == env_k || 0 == *env_k) ? 0 : atoi(env_k));
451     if (0 < m) {
452       if (0 < n) internal_gemm_nstretch = ((float)n) / m;
453       if (0 < k) internal_gemm_kstretch = ((float)k) / m;
454     }
455   }
456   { /* setup tile sizes according to environment (LIBXSMM_TGEMM_NS, LIBXSMM_TGEMM_KS) */
457     const char *const env_ns = getenv("LIBXSMM_TGEMM_NS"), *const env_ks = getenv("LIBXSMM_TGEMM_KS");
458     const double ns = ((NULL == env_ns || 0 == *env_ns) ? 0 : atof(env_ns));
459     const double ks = ((NULL == env_ks || 0 == *env_ks) ? 0 : atof(env_ks));
460     if (0 < ns) internal_gemm_nstretch = (float)LIBXSMM_MIN(24, ns);
461     if (0 < ks) internal_gemm_kstretch = (float)LIBXSMM_MIN(24, ks);
462   }
463   { /* determines if OpenMP tasks are used (when available) */
464     const char *const env_t = getenv("LIBXSMM_GEMM_TASKS");
465     const int gemm_tasks = ((NULL == env_t || 0 == *env_t) ? 0/*disabled*/ : atoi(env_t));
466     libxsmm_gemm_tasks = (0 <= gemm_tasks ? LIBXSMM_ABS(gemm_tasks) : 1/*enabled*/);
467   }
468   { /* determines grain-size of tasks (when available) */
469     const char *const env_g = getenv("LIBXSMM_GEMM_TASKGRAIN");
470     const int gemm_taskgrain = ((NULL == env_g || 0 == *env_g || 0 >= atoi(env_g))
471       ? (LIBXSMM_GEMM_TASKGRAIN) : atoi(env_g));
472     /* adjust grain-size or scale beyond the number of threads */
473     libxsmm_gemm_taskgrain = LIBXSMM_MAX(0 < libxsmm_gemm_tasks ? (gemm_taskgrain / libxsmm_gemm_tasks) : gemm_taskgrain, 1);
474   }
475   LIBXSMM_LOCK_ATTR_DESTROY(LIBXSMM_GEMM_LOCK, &attr);
476   /* determine BLAS function-pointers */
477   libxsmm_original_dgemm_batch();
478   libxsmm_original_sgemm_batch();
479   libxsmm_original_dgemm();
480   libxsmm_original_sgemm();
481   libxsmm_original_dgemv();
482   libxsmm_original_sgemv();
483 }
484 
485 
486 LIBXSMM_API_INTERN void libxsmm_gemm_finalize(void)
487 {
488 #if (0 != LIBXSMM_SYNC)
489   unsigned int i; for (i = 0; i < internal_gemm_nlocks; ++i) LIBXSMM_LOCK_DESTROY(LIBXSMM_GEMM_LOCK, &internal_gemm_lock[i].state);
490 #endif
491 #if defined(LIBXSMM_GEMM_BATCHREDUCE) || defined(LIBXSMM_WRAP)
492   if (NULL != libxsmm_mmbatch_array) {
493     void *extra = NULL, *const mmbatch_array = libxsmm_mmbatch_array;
494     if (EXIT_SUCCESS == libxsmm_get_malloc_xinfo(mmbatch_array, NULL/*size*/, NULL/*flags*/, &extra) && NULL != extra) {
495       const libxsmm_mmbatch_flush_function flush = *(libxsmm_mmbatch_flush_function*)extra;
496       if (NULL != flush) flush();
497     }
498 #if !defined(NDEBUG)
499     libxsmm_mmbatch_array = NULL;
500 #endif
501     libxsmm_xfree(mmbatch_array, 0/*no check*/);
502     LIBXSMM_LOCK_DESTROY(LIBXSMM_GEMM_LOCK, &libxsmm_mmbatch_lock);
503   }
504 #endif
505 }
506 
507 
508 LIBXSMM_API libxsmm_gemm_prefetch_type libxsmm_get_gemm_xprefetch(const int* prefetch)
509 {
510   LIBXSMM_INIT /* load configuration */
511   return libxsmm_get_gemm_prefetch(NULL == prefetch ? ((int)libxsmm_gemm_auto_prefetch) : *prefetch);
512 }
513 
514 
515 LIBXSMM_API libxsmm_gemm_prefetch_type libxsmm_get_gemm_prefetch(int prefetch)
516 {
517   libxsmm_gemm_prefetch_type result;
518 #if !defined(_WIN32) && !defined(__CYGWIN__) && !defined(__MINGW32__)
519   if (0 > prefetch) {
520     LIBXSMM_INIT /* load configuration */
521     result = libxsmm_gemm_auto_prefetch_default;
522   }
523   else {
524     result = (libxsmm_gemm_prefetch_type)prefetch;
525   }
526 #else /* TODO: full support for Windows calling convention */
527   result = LIBXSMM_GEMM_PREFETCH_NONE;
528   LIBXSMM_UNUSED(prefetch);
529 #endif
530   return result;
531 }
532 
533 
534 LIBXSMM_API_INTERN int libxsmm_gemm_prefetch2uid(libxsmm_gemm_prefetch_type prefetch)
535 {
536   switch (prefetch) {
537     case LIBXSMM_GEMM_PREFETCH_SIGONLY:            return 2;
538     case LIBXSMM_GEMM_PREFETCH_BL2_VIA_C:          return 3;
539     case LIBXSMM_GEMM_PREFETCH_AL2_AHEAD:          return 4;
540     case LIBXSMM_GEMM_PREFETCH_AL2BL2_VIA_C_AHEAD: return 5;
541     case LIBXSMM_GEMM_PREFETCH_AL2:                return 6;
542     case LIBXSMM_GEMM_PREFETCH_AL2BL2_VIA_C:       return 7;
543     default: {
544       LIBXSMM_ASSERT(LIBXSMM_GEMM_PREFETCH_NONE == prefetch);
545       return 0;
546     }
547   }
548 }
549 
550 
551 LIBXSMM_API_INTERN libxsmm_gemm_prefetch_type libxsmm_gemm_uid2prefetch(int uid)
552 {
553   switch (uid) {
554     case 1: return LIBXSMM_GEMM_PREFETCH_NONE;               /* nopf */
555     case 2: return LIBXSMM_GEMM_PREFETCH_SIGONLY;            /* pfsigonly */
556     case 3: return LIBXSMM_GEMM_PREFETCH_BL2_VIA_C;          /* BL2viaC */
557     case 4: return LIBXSMM_GEMM_PREFETCH_AL2_AHEAD;          /* curAL2 */
558     case 5: return LIBXSMM_GEMM_PREFETCH_AL2BL2_VIA_C_AHEAD; /* curAL2_BL2viaC */
559     case 6: return LIBXSMM_GEMM_PREFETCH_AL2;                /* AL2 */
560     case 7: return LIBXSMM_GEMM_PREFETCH_AL2BL2_VIA_C;       /* AL2_BL2viaC */
561     default: {
562       if (0 != libxsmm_verbosity) { /* library code is expected to be mute */
563         static int error_once = 0;
564         if (1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) {
565           fprintf(stderr, "LIBXSMM WARNING: invalid prefetch strategy requested!\n");
566         }
567       }
568       return LIBXSMM_GEMM_PREFETCH_NONE;
569     }
570   }
571 }
572 
573 
574 LIBXSMM_API void libxsmm_gemm_print(void* ostream,
575   libxsmm_gemm_precision precision, const char* transa, const char* transb,
576   const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
577   const void* alpha, const void* a, const libxsmm_blasint* lda,
578   const void* b, const libxsmm_blasint* ldb,
579   const void* beta, void* c, const libxsmm_blasint* ldc)
580 {
581   libxsmm_gemm_print2(ostream, precision, precision, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
582 }
583 
584 
585 LIBXSMM_API void libxsmm_gemm_print2(void* ostream,
586   libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, const char* transa, const char* transb,
587   const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
588   const void* alpha, const void* a, const libxsmm_blasint* lda,
589   const void* b, const libxsmm_blasint* ldb,
590   const void* beta, void* c, const libxsmm_blasint* ldc)
591 {
592   const libxsmm_blasint nn = *(n ? n : m), kk = *(k ? k : m);
593   const char ctransa = (char)(NULL != transa ? (*transa) : (0 == (LIBXSMM_FLAGS & LIBXSMM_GEMM_FLAG_TRANS_A) ? 'n' : 't'));
594   const char ctransb = (char)(NULL != transb ? (*transb) : (0 == (LIBXSMM_FLAGS & LIBXSMM_GEMM_FLAG_TRANS_B) ? 'n' : 't'));
595   const libxsmm_blasint ilda = (NULL != lda ? *lda : (('n' == ctransa || 'N' == ctransa) ? *m : kk));
596   const libxsmm_blasint ildb = (NULL != ldb ? *ldb : (('n' == ctransb || 'N' == ctransb) ? kk : nn));
597   const libxsmm_blasint ildc = *(NULL != ldc ? ldc : m);
598   libxsmm_mhd_elemtype mhd_elemtype = LIBXSMM_MHD_ELEMTYPE_UNKNOWN;
599   char string_a[128], string_b[128], typeprefix = 0;
600 
601   switch (iprec | oprec) {
602     case LIBXSMM_GEMM_PRECISION_F64: {
603       LIBXSMM_ASSERT(iprec == oprec);
604       LIBXSMM_SNPRINTF(string_a, sizeof(string_a), "%g", NULL != alpha ? *((const double*)alpha) : LIBXSMM_ALPHA);
605       LIBXSMM_SNPRINTF(string_b, sizeof(string_b), "%g", NULL != beta  ? *((const double*)beta)  : LIBXSMM_BETA);
606       mhd_elemtype = LIBXSMM_MHD_ELEMTYPE_F64;
607       typeprefix = 'd';
608     } break;
609     case LIBXSMM_GEMM_PRECISION_F32: {
610       LIBXSMM_ASSERT(iprec == oprec);
611       LIBXSMM_SNPRINTF(string_a, sizeof(string_a), "%g", NULL != alpha ? *((const float*)alpha) : LIBXSMM_ALPHA);
612       LIBXSMM_SNPRINTF(string_b, sizeof(string_b), "%g", NULL != beta  ? *((const float*)beta)  : LIBXSMM_BETA);
613       mhd_elemtype = LIBXSMM_MHD_ELEMTYPE_F32;
614       typeprefix = 's';
615     } break;
616     default: if (0 != libxsmm_verbosity) { /* library code is expected to be mute */
617       static int error_once = 0;
618       if (1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) { /* TODO: support I16, etc. */
619         fprintf(stderr, "LIBXSMM ERROR: unsupported data-type requested!\n");
620       }
621     }
622   }
623 
624   if (0 != typeprefix) {
625     if (NULL != ostream) { /* print information about GEMM call */
626       if (NULL != a && NULL != b && NULL != c) {
627         fprintf((FILE*)ostream, "%cgemm('%c', '%c', %" PRIuPTR "/*m*/, %" PRIuPTR "/*n*/, %" PRIuPTR "/*k*/,\n"
628                                 "  %s/*alpha*/, %p/*a*/, %" PRIuPTR "/*lda*/,\n"
629                                 "              %p/*b*/, %" PRIuPTR "/*ldb*/,\n"
630                                 "   %s/*beta*/, %p/*c*/, %" PRIuPTR "/*ldc*/)",
631           typeprefix, ctransa, ctransb, (uintptr_t)*m, (uintptr_t)nn, (uintptr_t)kk,
632           string_a, a, (uintptr_t)ilda, b, (uintptr_t)ildb, string_b, c, (uintptr_t)ildc);
633       }
634       else {
635         fprintf((FILE*)ostream, "%cgemm(trans=%c%c mnk=%" PRIuPTR ",%" PRIuPTR ",%" PRIuPTR
636                                                  " ldx=%" PRIuPTR ",%" PRIuPTR ",%" PRIuPTR " a,b=%s,%s)",
637           typeprefix, ctransa, ctransb, (uintptr_t)*m, (uintptr_t)nn, (uintptr_t)kk,
638           (uintptr_t)ilda, (uintptr_t)ildb, (uintptr_t)ildc, string_a, string_b);
639       }
640     }
641     else { /* dump A, B, and C matrices into MHD files */
642       char extension_header[256];
643       size_t data_size[2], size[2];
644 
645       if (NULL != a) {
646         LIBXSMM_SNPRINTF(extension_header, sizeof(extension_header), "TRANS = %c\nALPHA = %s", ctransa, string_a);
647         LIBXSMM_SNPRINTF(string_a, sizeof(string_a), "libxsmm_a_%p.mhd", a);
648         data_size[0] = (size_t)ilda; data_size[1] = (size_t)kk; size[0] = (size_t)(*m); size[1] = (size_t)kk;
649         libxsmm_mhd_write(string_a, NULL/*offset*/, size, data_size, 2/*ndims*/, 1/*ncomponents*/, mhd_elemtype,
650           NULL/*conversion*/, a, NULL/*header_size*/, extension_header, NULL/*extension*/, 0/*extension_size*/);
651       }
652       if (NULL != b) {
653         LIBXSMM_SNPRINTF(extension_header, sizeof(extension_header), "\nTRANS = %c", ctransb);
654         LIBXSMM_SNPRINTF(string_a, sizeof(string_a), "libxsmm_b_%p.mhd", b);
655         data_size[0] = (size_t)ildb; data_size[1] = (size_t)nn; size[0] = (size_t)kk; size[1] = (size_t)nn;
656         libxsmm_mhd_write(string_a, NULL/*offset*/, size, data_size, 2/*ndims*/, 1/*ncomponents*/, mhd_elemtype,
657           NULL/*conversion*/, b, NULL/*header_size*/, extension_header, NULL/*extension*/, 0/*extension_size*/);
658       }
659       if (NULL != c) {
660         LIBXSMM_SNPRINTF(extension_header, sizeof(extension_header), "BETA = %s", string_b);
661         LIBXSMM_SNPRINTF(string_a, sizeof(string_a), "libxsmm_c_%p.mhd", c);
662         data_size[0] = (size_t)ildc; data_size[1] = (size_t)nn; size[0] = (size_t)(*m); size[1] = (size_t)nn;
663         libxsmm_mhd_write(string_a, NULL/*offset*/, size, data_size, 2/*ndims*/, 1/*ncomponents*/, mhd_elemtype,
664           NULL/*conversion*/, c, NULL/*header_size*/, extension_header, NULL/*extension*/, 0/*extension_size*/);
665       }
666     }
667   }
668 }
669 
670 
671 LIBXSMM_API void libxsmm_gemm_dprint(
672   void* ostream, libxsmm_gemm_precision precision, char transa, char transb,
673   libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, double dalpha, const void* a, libxsmm_blasint lda,
674   const void* b, libxsmm_blasint ldb, double dbeta, void* c, libxsmm_blasint ldc)
675 {
676   libxsmm_gemm_dprint2(ostream, precision, precision, transa, transb, m, n, k, dalpha, a, lda, b, ldb, dbeta, c, ldc);
677 }
678 
679 
680 LIBXSMM_API void libxsmm_gemm_dprint2(
681   void* ostream, libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, char transa, char transb,
682   libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, double dalpha, const void* a, libxsmm_blasint lda,
683   const void* b, libxsmm_blasint ldb, double dbeta, void* c, libxsmm_blasint ldc)
684 {
685   switch (iprec) {
686     case LIBXSMM_GEMM_PRECISION_F64: {
687       libxsmm_gemm_print2(ostream, LIBXSMM_GEMM_PRECISION_F64, oprec, &transa, &transb,
688         &m, &n, &k, &dalpha, a, &lda, b, &ldb, &dbeta, c, &ldc);
689     } break;
690     case LIBXSMM_GEMM_PRECISION_F32: {
691       const float alpha = (float)dalpha, beta = (float)dbeta;
692       libxsmm_gemm_print2(ostream, LIBXSMM_GEMM_PRECISION_F32, oprec, &transa, &transb,
693         &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc);
694     } break;
695     default: {
696       libxsmm_gemm_print2(ostream, iprec, oprec, &transa, &transb,
697         &m, &n, &k, &dalpha, a, &lda, b, &ldb, &dbeta, c, &ldc);
698     }
699   }
700 }
701 
702 
703 LIBXSMM_API void libxsmm_gemm_xprint(void* ostream,
704   libxsmm_xmmfunction kernel, const void* a, const void* b, void* c)
705 {
706   const libxsmm_descriptor* desc;
707   libxsmm_code_pointer code;
708   size_t code_size;
709   code.xgemm = kernel;
710   if (NULL != libxsmm_get_kernel_xinfo(code, &desc, &code_size) &&
711       NULL != desc && LIBXSMM_KERNEL_KIND_MATMUL == desc->kind)
712   {
713     libxsmm_gemm_dprint2(ostream,
714       (libxsmm_gemm_precision)LIBXSMM_GETENUM_INP(desc->gemm.desc.datatype),
715       (libxsmm_gemm_precision)LIBXSMM_GETENUM_OUT(desc->gemm.desc.datatype),
716       (char)(0 == (LIBXSMM_GEMM_FLAG_TRANS_A & desc->gemm.desc.flags) ? 'N' : 'T'),
717       (char)(0 == (LIBXSMM_GEMM_FLAG_TRANS_B & desc->gemm.desc.flags) ? 'N' : 'T'),
718       (libxsmm_blasint)desc->gemm.desc.m, (libxsmm_blasint)desc->gemm.desc.n, (libxsmm_blasint)desc->gemm.desc.k,
719       /*0 != (LIBXSMM_GEMM_FLAG_ALPHA_0 & libxsmm_mmbatch_desc.flags) ? 0 : */1, a,
720       (libxsmm_blasint)desc->gemm.desc.lda, b, (libxsmm_blasint)desc->gemm.desc.ldb,
721       0 != (LIBXSMM_GEMM_FLAG_BETA_0 & libxsmm_mmbatch_desc.flags) ? 0 : 1, c, (libxsmm_blasint)desc->gemm.desc.ldc);
722     fprintf((FILE*)ostream, " = %p+%u", code.ptr_const, (unsigned int)code_size);
723   }
724 }
725 
726 
727 LIBXSMM_API void libxsmm_blas_xgemm(libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec,
728   const char* transa, const char* transb, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
729   const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb,
730   const void* beta, void* c, const libxsmm_blasint* ldc)
731 {
732   LIBXSMM_INIT
733   switch (iprec) {
734     case LIBXSMM_GEMM_PRECISION_F64: {
735       LIBXSMM_ASSERT(iprec == oprec);
736       LIBXSMM_BLAS_XGEMM(double, double, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
737     } break;
738     case LIBXSMM_GEMM_PRECISION_F32: {
739       LIBXSMM_ASSERT(iprec == oprec);
740       LIBXSMM_BLAS_XGEMM(float, float, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
741     } break;
742     default: if (0 != libxsmm_verbosity) { /* library code is expected to be mute */
743       static int error_once = 0;
744       LIBXSMM_UNUSED(oprec);
745       if (1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) { /* TODO: support I16, etc. */
746         fprintf(stderr, "LIBXSMM ERROR: unsupported data-type requested!\n");
747       }
748     }
749   }
750 }
751 
752 
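/* Plans the task decomposition for a tiled GEMM: given the problem extents (m, n, k), the tile
 * shape (tm, tn, tk), and the number of available tasks, it first attempts pure M-parallelism,
 * then MN-parallelism, and finally MNK-parallelism (the K-dimension is only split when
 * LIBXSMM_GEMM_KPARALLEL is enabled; otherwise a warning is printed at high verbosity). */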
753 LIBXSMM_API_INLINE int libxsmm_gemm_plan_internal(unsigned int ntasks,
754   unsigned int m, unsigned int n, unsigned int k,           /* whole problem size */
755   unsigned int tm, unsigned int tn, unsigned int tk,        /* tile size (kernel) */
756   unsigned int* nmt, unsigned int* nnt, unsigned int* nkt,  /* number of tiles */
757   unsigned int* mt, unsigned int* nt, unsigned int* kt)     /* number of tasks */
758 {
759   unsigned int result = EXIT_SUCCESS, replan = 0;
760   LIBXSMM_ASSERT(NULL != nmt && NULL != nnt && NULL != nkt);
761   LIBXSMM_ASSERT(NULL != mt && NULL != nt && NULL != kt);
762   LIBXSMM_ASSERT(0 < ntasks);
763   *nmt = (m + tm - 1) / LIBXSMM_MAX(tm, 1);
764   *nnt = (n + tn - 1) / LIBXSMM_MAX(tn, 1);
765   *nkt = (k + tk - 1) / LIBXSMM_MAX(tk, 1);
766 #if !defined(NDEBUG)
767   *mt = *nt = *kt = 0;
768 #endif
769   do {
770     if (1 >= replan) *mt = libxsmm_product_limit(*nmt, ntasks, 0);
771     if (1 == replan || ntasks <= *mt) { /* M-parallelism */
772       *nt = 1;
773       *kt = 1;
774       replan = 0;
775     }
776     else {
777       const unsigned int mntasks = libxsmm_product_limit((*nmt) * (*nnt), ntasks, 0);
778       if (0 == replan && *mt >= mntasks) replan = 1;
779       if (2 == replan || (0 == replan && ntasks <= mntasks)) { /* MN-parallelism */
780         *nt = mntasks / *mt;
781         *kt = 1;
782         replan = 0;
783       }
784       else { /* MNK-parallelism */
785         const unsigned int mnktasks = libxsmm_product_limit((*nmt) * (*nnt) * (*nkt), ntasks, 0);
786         if (mntasks < mnktasks) {
787 #if defined(LIBXSMM_GEMM_KPARALLEL)
788           *nt = mntasks / *mt;
789           *kt = mnktasks / mntasks;
790           replan = 0;
791 #else
792           static int error_once = 0;
793           if ((LIBXSMM_VERBOSITY_HIGH <= libxsmm_verbosity || 0 > libxsmm_verbosity) /* library code is expected to be mute */
794             && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
795           {
796             fprintf(stderr, "LIBXSMM WARNING (XGEMM): K-parallelism triggered!\n");
797           }
798 #endif
799         }
800 #if defined(LIBXSMM_GEMM_KPARALLEL)
801         else
802 #endif
803         if (0 == replan) replan = 2;
804       }
805     }
806   } while (0 != replan);
807   if (0 == *mt || 0 == *nt || 0 == *kt) {
808     result = EXIT_FAILURE;
809   }
810   return result;
811 }
812 
813 
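/* Sets up a tiled-GEMM handle inside the caller-provided blob: unless LIBXSMM_TGEMM_M dictates the
 * tile's M-extent, tile sizes are searched using the per-ISA vector width, M-limit, and N/K stretch
 * factors; depending on the transpose case (NN/NT/TN vs. TT) copy and/or transpose kernels are
 * dispatched for A, B, and C, and two SMM kernels are prepared (one with the requested beta, and a
 * second with the beta=0 flag cleared, presumably to accumulate partial results across K-tiles). */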
814 LIBXSMM_API libxsmm_gemm_handle* libxsmm_gemm_handle_init(libxsmm_gemm_blob* blob,
815   libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, const char* transa, const char* transb,
816   const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
817   const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc,
818   const void* alpha, const void* beta, int flags, /*unsigned*/int ntasks)
819 {
820   unsigned int ulda, uldb, um, un, uk, tm = 0, tn = 0, tk = 0, max_ntasks = 0;
821   libxsmm_descriptor_blob desc_blob;
822   union {
823     libxsmm_gemm_handle* ptr;
824     libxsmm_gemm_blob* blob;
825   } result;
826   LIBXSMM_ASSERT(sizeof(libxsmm_gemm_handle) <= sizeof(libxsmm_gemm_blob));
827   if (NULL != blob && NULL != m && 0 < ntasks) {
828     unsigned int ntm = 0, ntn = 0, ntk = 0, mt = 1, nt = 1, kt = 1;
829     const char *const env_tm = getenv("LIBXSMM_TGEMM_M");
830     libxsmm_blasint klda, kldb, kldc, km, kn;
831     libxsmm_gemm_descriptor* desc;
832     const int prf_copy = 0;
833     double dbeta;
834     LIBXSMM_INIT
835     result.blob = blob;
836 #if defined(NDEBUG)
837     result.ptr->copy_a.ptr = result.ptr->copy_b.ptr = result.ptr->copy_i.ptr = result.ptr->copy_o.ptr = NULL;
838 #else
839     memset(blob, 0, sizeof(libxsmm_gemm_blob));
840 #endif
841     if (EXIT_SUCCESS != libxsmm_dvalue((libxsmm_datatype)oprec, beta, &dbeta)) dbeta = LIBXSMM_BETA; /* fuse beta into flags */
842     result.ptr->gemm_flags = LIBXSMM_GEMM_PFLAGS(transa, transb, LIBXSMM_FLAGS) | (LIBXSMM_NEQ(0, dbeta) ? 0 : LIBXSMM_GEMM_FLAG_BETA_0);
843     /* TODO: check that arguments fit into handle (unsigned int vs. libxsmm_blasint) */
844     um = (unsigned int)(*m); uk = (NULL != k ? ((unsigned int)(*k)) : um); un = (NULL != n ? ((unsigned int)(*n)) : uk);
845     result.ptr->otypesize = libxsmm_typesize((libxsmm_datatype)oprec);
846     if (NULL == env_tm || 0 >= atoi(env_tm)) {
847       const unsigned int vwidth = LIBXSMM_MAX(internal_gemm_vwidth / result.ptr->otypesize, 1);
848       const double s2 = (double)internal_gemm_nstretch * internal_gemm_kstretch; /* LIBXSMM_INIT! */
849       unsigned int tmi = libxsmm_product_limit(um, internal_gemm_mlimit, 0); /* LIBXSMM_INIT! */
850       for (; vwidth <= tmi; tmi = libxsmm_product_limit(um, tmi - 1, 0)) {
851         const double si = (double)(LIBXSMM_CONFIG_MAX_MNK) / ((double)tmi * tmi * tmi), s = (s2 <= si ? 1 : (s2 / si));
852         unsigned int tni = libxsmm_product_limit(un, LIBXSMM_MAX((unsigned int)(tmi * (s * internal_gemm_nstretch)), 1), 0);
853         unsigned int tki = libxsmm_product_limit(uk, LIBXSMM_MAX((unsigned int)(tmi * (s * internal_gemm_kstretch)), 1), 0);
854         unsigned int ntmi, ntni, ntki, mti = 1, nti = 1, kti = 1;
855         LIBXSMM_ASSERT(tmi <= um && tni <= un && tki <= uk);
856         if (LIBXSMM_GEMM_FLAG_TRANS_AB == (LIBXSMM_GEMM_FLAG_TRANS_AB & result.ptr->gemm_flags)) {
857           const unsigned int ttm = (unsigned int)libxsmm_product_limit(tmi, (unsigned int)ntasks, 0);
858           const unsigned int ttn = (unsigned int)libxsmm_product_limit(tni, (unsigned int)ntasks, 0);
859           tmi = tni = LIBXSMM_MIN(ttm, ttn); /* prefer threads over larger tile */
860         }
861         if (EXIT_SUCCESS == libxsmm_gemm_plan_internal((unsigned int)ntasks, um, un, uk, tmi, tni, tki,
862           &ntmi, &ntni, &ntki, &mti, &nti, &kti))
863         {
864           const int exit_plan = ((tmi < um && tni < un && tki < uk && (tm != tmi || tn != tni || tk != tki)) ? 0 : 1);
865           const unsigned itasks = mti * nti * kti;
866           LIBXSMM_ASSERT(1 <= itasks);
867           if (max_ntasks < itasks) {
868             ntm = ntmi; ntn = ntni; ntk = ntki;
869             mt = mti; nt = nti; kt = kti;
870             tm = tmi; tn = tni; tk = tki;
871             max_ntasks = itasks;
872           }
873           if (itasks == (unsigned int)ntasks || 0 != exit_plan) break;
874         }
875       }
876     }
877     else {
878       const unsigned int tmi = atoi(env_tm);
879       const double s2 = (double)internal_gemm_nstretch * internal_gemm_kstretch; /* LIBXSMM_INIT! */
880       double si, s;
881       tm = libxsmm_product_limit(um, LIBXSMM_MIN(tmi, internal_gemm_mlimit), 0); /* LIBXSMM_INIT! */
882       si = (double)(LIBXSMM_CONFIG_MAX_MNK) / ((double)tm * tm * tm); s = (s2 <= si ? 1 : (s2 / si));
883       tn = libxsmm_product_limit(un, LIBXSMM_MAX((unsigned int)(tm * (s * internal_gemm_nstretch)), 1), 0);
884       tk = libxsmm_product_limit(uk, LIBXSMM_MAX((unsigned int)(tm * (s * internal_gemm_kstretch)), 1), 0);
885       if (LIBXSMM_GEMM_FLAG_TRANS_AB == (LIBXSMM_GEMM_FLAG_TRANS_AB & result.ptr->gemm_flags)) {
886         const unsigned int ttm = (unsigned int)libxsmm_product_limit(tm, (unsigned int)ntasks, 0);
887         const unsigned int ttn = (unsigned int)libxsmm_product_limit(tn, (unsigned int)ntasks, 0);
888         tm = tn = LIBXSMM_MIN(ttm, ttn); /* prefer threads over larger tile */
889       }
890       if (EXIT_SUCCESS == libxsmm_gemm_plan_internal((unsigned int)ntasks, um, un, uk, tm, tn, tk,
891         &ntm, &ntn, &ntk, &mt, &nt, &kt))
892       {
893 #if defined(NDEBUG)
894         max_ntasks = 2; /* only need something unequal to zero to pass below condition */
895 #else
896         max_ntasks = mt * nt * kt;
897 #endif
898       }
899     }
900     LIBXSMM_ASSERT(LIBXSMM_GEMM_FLAG_TRANS_AB != (LIBXSMM_GEMM_FLAG_TRANS_AB & result.ptr->gemm_flags) || tm == tn);
901     /* check for non-conforming GEMM parameters (error), and conforming GEMM parameters (fast-path, fall-back) */
902     if (0 == max_ntasks || 0 == tm || 0 == tn || 0 == tk || 0 != (um % tm) || 0 != (un % tn) || 0 != (uk % tk)) {
903       return NULL;
904     }
905     result.ptr->flags = flags;
906     if (LIBXSMM_GEMM_HANDLE_FLAG_AUTO == flags && 0 == LIBXSMM_SMM_AI(um, un, uk,
907       0 == (result.ptr->gemm_flags & LIBXSMM_GEMM_FLAG_BETA_0) ? 1 : 2/*RFO*/, result.ptr->otypesize))
908     {
909       if (um == LIBXSMM_UP2POT(um) || un == LIBXSMM_UP2POT(un)) { /* power-of-two (POT) extent(s) */
910         result.ptr->flags |= LIBXSMM_GEMM_HANDLE_FLAG_COPY_C;
911         if (LIBXSMM_GEMM_FLAG_TRANS_AB != (LIBXSMM_GEMM_FLAG_TRANS_AB & result.ptr->gemm_flags)) {
912           result.ptr->flags |= LIBXSMM_GEMM_HANDLE_FLAG_COPY_A;
913         }
914       }
915     }
916     result.ptr->itypesize = libxsmm_typesize((libxsmm_datatype)iprec);
917     result.ptr->ldc = (unsigned int)(NULL != ldc ? *ldc : *m);
918     ulda = (NULL != lda ? ((unsigned int)(*lda)) : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & result.ptr->gemm_flags) ? ((unsigned int)(*m)) : uk));
919     uldb = (NULL != ldb ? ((unsigned int)(*ldb)) : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & result.ptr->gemm_flags) ? uk : un));
920     if (LIBXSMM_GEMM_FLAG_TRANS_AB != (LIBXSMM_GEMM_FLAG_TRANS_AB & result.ptr->gemm_flags)) { /* NN, NT, or TN */
921       kldc = (libxsmm_blasint)result.ptr->ldc;
922       klda = (libxsmm_blasint)ulda;
923       kldb = (libxsmm_blasint)uldb;
924       if (0 != (LIBXSMM_GEMM_FLAG_TRANS_A & result.ptr->gemm_flags)) { /* TN */
925 #if !defined(LIBXSMM_GEMM_NOJIT_TRANS)
926         result.ptr->copy_a.xtrans = libxsmm_dispatch_trans(libxsmm_trans_descriptor_init(&desc_blob,
927           result.ptr->itypesize, tk, tm, tm/*ldo*/));
928 #endif
929         klda = (libxsmm_blasint)tm;
930       }
931       else if (0 != (LIBXSMM_GEMM_HANDLE_FLAG_COPY_A & result.ptr->flags)) {
932         result.ptr->copy_a.xmatcopy = libxsmm_dispatch_mcopy(libxsmm_mcopy_descriptor_init(&desc_blob,
933           result.ptr->itypesize, tm, tk, tm/*ldo*/, ulda/*ldi*/,
934           0/*flags*/, prf_copy, NULL/*unroll*/));
935         klda = (libxsmm_blasint)tm;
936       }
937       if (0 != (LIBXSMM_GEMM_FLAG_TRANS_B & result.ptr->gemm_flags)) { /* NT */
938 #if !defined(LIBXSMM_GEMM_NOJIT_TRANS)
939         result.ptr->copy_b.xtrans = libxsmm_dispatch_trans(libxsmm_trans_descriptor_init(&desc_blob,
940           result.ptr->itypesize, tn, tk, tk/*ldo*/));
941 #endif
942         kldb = (libxsmm_blasint)tk;
943       }
944       else if (0 != (LIBXSMM_GEMM_HANDLE_FLAG_COPY_B & result.ptr->flags)) {
945         result.ptr->copy_b.xmatcopy = libxsmm_dispatch_mcopy(libxsmm_mcopy_descriptor_init(&desc_blob,
946           result.ptr->itypesize, tk, tn, tk/*ldo*/, uldb/*ldi*/,
947           0/*flags*/, prf_copy, NULL/*unroll*/));
948         kldb = (libxsmm_blasint)tk;
949       }
950       if (0 != (LIBXSMM_GEMM_HANDLE_FLAG_COPY_C & result.ptr->flags)) {
951         result.ptr->copy_o.xmatcopy = libxsmm_dispatch_mcopy(libxsmm_mcopy_descriptor_init(&desc_blob,
952           result.ptr->otypesize, tm, tn, result.ptr->ldc/*ldo*/, tm/*ldi*/,
953           0/*flags*/, prf_copy, NULL/*unroll*/));
954         if (0 == (result.ptr->gemm_flags & LIBXSMM_GEMM_FLAG_BETA_0)) { /* copy-in only if beta!=0 */
955           result.ptr->copy_i.xmatcopy = libxsmm_dispatch_mcopy(libxsmm_mcopy_descriptor_init(&desc_blob,
956             result.ptr->otypesize, tm, tn, tm/*ldo*/, result.ptr->ldc/*ldi*/,
957             0/*flags*/, prf_copy, NULL/*unroll*/));
958         }
959         kldc = (libxsmm_blasint)tm;
960       }
961       result.ptr->lda = ulda; result.ptr->ldb = uldb;
962       result.ptr->km = tm; result.ptr->kn = tn;
963       result.ptr->mt = mt; result.ptr->nt = nt;
964       result.ptr->m = um; result.ptr->n = un;
965       result.ptr->dm = LIBXSMM_UPDIV(ntm, mt) * tm;
966       result.ptr->dn = LIBXSMM_UPDIV(ntn, nt) * tn;
967       km = tm; kn = tn;
968     }
969     else { /* TT */
970       const unsigned int tt = tm;
971       klda = (libxsmm_blasint)uldb;
972       kldb = (libxsmm_blasint)ulda;
973       kldc = (libxsmm_blasint)tt;
974       LIBXSMM_ASSERT(tt == tn);
975 #if !defined(LIBXSMM_GEMM_NOJIT_TRANS)
976       result.ptr->copy_o.xtrans = libxsmm_dispatch_trans(libxsmm_trans_descriptor_init(&desc_blob,
977         result.ptr->otypesize, tt, tt, result.ptr->ldc/*ldo*/));
978       if (0 == (result.ptr->gemm_flags & LIBXSMM_GEMM_FLAG_BETA_0)) { /* copy-in only if beta!=0 */
979         result.ptr->copy_i.xtrans = libxsmm_dispatch_trans(libxsmm_trans_descriptor_init(&desc_blob,
980           result.ptr->otypesize, tt, tt, tt/*ldo*/));
981       }
982 #endif
983       if (0 != (LIBXSMM_GEMM_HANDLE_FLAG_COPY_A & result.ptr->flags)) {
984         result.ptr->copy_a.xmatcopy = libxsmm_dispatch_mcopy(libxsmm_mcopy_descriptor_init(&desc_blob,
985           result.ptr->itypesize, tt, tk, tk/*ldo*/, uldb/*ldi*/,
986           0/*flags*/, prf_copy, NULL/*unroll*/));
987         klda = (libxsmm_blasint)tt;
988       }
989       if (0 != (LIBXSMM_GEMM_HANDLE_FLAG_COPY_B & result.ptr->flags)) {
990         result.ptr->copy_b.xmatcopy = libxsmm_dispatch_mcopy(libxsmm_mcopy_descriptor_init(&desc_blob,
991           result.ptr->itypesize, tk, tn, tk/*ldo*/, ulda/*ldi*/,
992           0/*flags*/, prf_copy, NULL/*unroll*/));
993         kldb = (libxsmm_blasint)tk;
994       }
995       result.ptr->lda = uldb; result.ptr->ldb = ulda;
996       result.ptr->km = tn; result.ptr->kn = tm;
997       result.ptr->mt = nt; result.ptr->nt = mt;
998       result.ptr->m = un; result.ptr->n = um;
999       result.ptr->dm = LIBXSMM_UPDIV(ntn, nt) * tn;
1000       result.ptr->dn = LIBXSMM_UPDIV(ntm, mt) * tm;
1001       km = kn = tt;
1002     }
1003     result.ptr->dk = ntk / kt * tk;
1004     result.ptr->kk = tk;
1005     result.ptr->kt = kt;
1006     result.ptr->k = uk;
1007     desc = libxsmm_gemm_descriptor_init2( /* remove transpose flags from kernel request */
1008       &desc_blob, iprec, oprec, km, kn, result.ptr->kk, klda, kldb, kldc,
1009       alpha, beta, result.ptr->gemm_flags & ~LIBXSMM_GEMM_FLAG_TRANS_AB, internal_gemm_tiled_prefetch);
1010     result.ptr->kernel[0] = libxsmm_xmmdispatch(desc);
1011     if (NULL != result.ptr->kernel[0].xmm) {
1012       if (0 == (desc->flags & LIBXSMM_GEMM_FLAG_BETA_0)) { /* beta!=0 */
1013         result.ptr->kernel[1] = result.ptr->kernel[0];
1014       }
1015       else { /* generate kernel with beta=1 */
1016         desc->flags &= ~LIBXSMM_GEMM_FLAG_BETA_0;
1017         result.ptr->kernel[1] = libxsmm_xmmdispatch(desc);
1018         if (NULL == result.ptr->kernel[1].xmm) result.ptr = NULL;
1019       }
1020     }
1021     else result.ptr = NULL;
1022   }
1023   else {
1024     result.ptr = NULL;
1025   }
1026   return result.ptr;
1027 }
1028 
1029 
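/* The three helpers below report the size of the per-task staging buffer used to copy or
 * transpose A, B, and C respectively: zero if the corresponding operand is consumed in
 * place, otherwise the tile footprint rounded up to a multiple of the cache line size. */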
LIBXSMM_API_INLINE size_t libxsmm_gemm_handle_get_scratch_size_a(const libxsmm_gemm_handle* handle)
1031 {
1032   size_t result;
1033   if (NULL == handle || (0 == (handle->flags & LIBXSMM_GEMM_HANDLE_FLAG_COPY_A)
1034     && (LIBXSMM_GEMM_FLAG_TRANS_AB == (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags) ||
1035        (LIBXSMM_GEMM_FLAG_TRANS_A & handle->gemm_flags) == 0)))
1036   {
1037     result = 0;
1038   }
1039   else {
1040     const size_t size = (size_t)handle->km * handle->kk * handle->itypesize;
1041     result = LIBXSMM_UP2(size, LIBXSMM_CACHELINE);
1042   }
1043   return result;
1044 }
1045 
1046 
LIBXSMM_API_INLINE size_t libxsmm_gemm_handle_get_scratch_size_b(const libxsmm_gemm_handle* handle)
1048 {
1049   size_t result;
1050   if (NULL == handle || (0 == (handle->flags & LIBXSMM_GEMM_HANDLE_FLAG_COPY_B)
1051     && (LIBXSMM_GEMM_FLAG_TRANS_AB == (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags) ||
1052        (LIBXSMM_GEMM_FLAG_TRANS_B & handle->gemm_flags) == 0)))
1053   {
1054     result = 0;
1055   }
1056   else {
1057     const size_t size = (size_t)handle->kk * handle->kn * handle->itypesize;
1058     result = LIBXSMM_UP2(size, LIBXSMM_CACHELINE);
1059   }
1060   return result;
1061 }
1062 
1063 
LIBXSMM_API_INLINE size_t libxsmm_gemm_handle_get_scratch_size_c(const libxsmm_gemm_handle* handle)
1065 {
1066   size_t result;
1067   if (NULL == handle || (0 == (handle->flags & LIBXSMM_GEMM_HANDLE_FLAG_COPY_C)
1068     && LIBXSMM_GEMM_FLAG_TRANS_AB != (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags)))
1069   {
1070     result = 0;
1071   }
1072   else {
1073     const size_t size = (size_t)handle->km * handle->kn * handle->otypesize;
1074     result = LIBXSMM_UP2(size, LIBXSMM_CACHELINE);
1075   }
1076   return result;
1077 }
1078 
1079 
LIBXSMM_API size_t libxsmm_gemm_handle_get_scratch_size(const libxsmm_gemm_handle* handle)
1081 {
1082   size_t result;
1083   if (NULL != handle) { /* thread-local scratch buffer for GEMM */
1084     const size_t size_a = libxsmm_gemm_handle_get_scratch_size_a(handle);
1085     const size_t size_b = libxsmm_gemm_handle_get_scratch_size_b(handle);
1086     const size_t size_c = libxsmm_gemm_handle_get_scratch_size_c(handle);
1087     result = (size_a + size_b + size_c) * handle->mt * handle->nt * handle->kt;
1088   }
1089   else {
1090     result = 0;
1091   }
1092   return result;
1093 }
1094 
1095 
LIBXSMM_API void libxsmm_gemm_thread(const libxsmm_gemm_handle* handle, void* scratch,
1097   const void* a, const void* b, void* c, /*unsigned*/int tid, /*unsigned*/int nthreads)
1098 {
1099 #if !defined(NDEBUG)
1100   if (NULL != handle && 0 <= tid && tid < nthreads)
1101 #endif
1102   {
1103     const unsigned int uthreads = (unsigned int)nthreads;
1104     const unsigned int ntasks = handle->mt * handle->nt * handle->kt;
1105     const unsigned int spread = (ntasks <= uthreads ? (uthreads / ntasks) : 1);
1106     const unsigned int utid = (unsigned int)tid, vtid = utid / spread;
1107     if (utid < (spread * ntasks) && 0 == (utid - vtid * spread)) {
1108       const int excess = (uthreads << 1) <= (vtid + ntasks);
1109       const unsigned int rtid = vtid / handle->mt, mtid = vtid - rtid * handle->mt, ntid = rtid % handle->nt, ktid = vtid / (handle->mt * handle->nt);
1110       const unsigned int m0 = mtid * handle->dm, m1 = (0 == excess ? LIBXSMM_MIN(m0 + handle->dm, handle->m) : handle->m);
1111       const unsigned int n0 = ntid * handle->dn, n1 = (0 == excess ? LIBXSMM_MIN(n0 + handle->dn, handle->n) : handle->n);
1112       const unsigned int k0 = ktid * handle->dk, k1 = (0 == excess ? LIBXSMM_MIN(k0 + handle->dk, handle->k) : handle->k);
1113       const unsigned int ldo = (LIBXSMM_GEMM_FLAG_TRANS_AB != (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags) ? handle->km : handle->kk);
1114       /* calculate increments to simplify address calculations */
1115       const unsigned int dom = handle->km * handle->otypesize;
1116       const unsigned int don = handle->kn * handle->otypesize;
1117       const unsigned int dik = handle->kk * handle->itypesize;
1118       const unsigned int on = handle->otypesize * n0;
1119       /* calculate base address of thread-local storage */
1120       const size_t size_a = libxsmm_gemm_handle_get_scratch_size_a(handle);
1121       const size_t size_b = libxsmm_gemm_handle_get_scratch_size_b(handle);
1122       const size_t size_c = libxsmm_gemm_handle_get_scratch_size_c(handle);
1123       char *const at = (char*)scratch + (size_a + size_b + size_c) * vtid;
1124       char *const bt = at + size_a, *const ct = bt + size_b;
1125       const libxsmm_xcopykernel kernel = { NULL };
1126       /* loop induction variables and other variables */
1127       unsigned int om = handle->otypesize * m0, im = m0, in = n0, ik = k0, im1, in1, ik1;
1128       LIBXSMM_ASSERT_MSG(mtid < handle->mt && ntid < handle->nt && ktid < handle->kt, "Invalid task-ID");
1129       LIBXSMM_ASSERT_MSG(m1 <= handle->m && n1 <= handle->n && k1 <= handle->k, "Invalid task size");
1130       for (im1 = im + handle->km; (im1 - 1) < m1; im = im1, im1 += handle->km, om += dom) {
1131         unsigned int dn = don, dka = dik, dkb = dik;
1132         char *c0 = (char*)c, *ci;
1133         const char *aa;
1134         if (LIBXSMM_GEMM_FLAG_TRANS_AB != (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags)) {
1135           if (0 != (LIBXSMM_GEMM_FLAG_TRANS_A & handle->gemm_flags)) { /* TN */
1136             aa = (const char*)a + ((size_t)im * handle->lda + k0) * handle->itypesize;
1137           }
1138           else if (0 != (LIBXSMM_GEMM_FLAG_TRANS_B & handle->gemm_flags)) { /* NT */
1139             aa = (const char*)a + ((size_t)k0 * handle->lda + im) * handle->itypesize;
1140             dka *= handle->lda; dkb *= handle->ldb;
1141           }
1142           else { /* NN */
1143             aa = (const char*)a + ((size_t)k0 * handle->lda + im) * handle->itypesize;
1144             dka *= handle->lda;
1145           }
1146           c0 += (size_t)on * handle->ldc + om;
1147           dn *= handle->ldc;
1148         }
1149         else { /* TT */
1150           aa = (const char*)b + ((size_t)k0 * handle->lda + im) * handle->itypesize;
1151           c0 += (size_t)on + handle->ldc * (size_t)om;
1152           dka *= handle->lda;
1153         }
1154         for (in = n0, in1 = in + handle->kn; (in1 - 1) < n1; in = in1, in1 += handle->kn, c0 += dn) {
1155           const char *a0 = aa, *b0 = (const char*)b;
1156           if (LIBXSMM_GEMM_FLAG_TRANS_AB != (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags)) {
1157             if (0 != (LIBXSMM_GEMM_FLAG_TRANS_B & handle->gemm_flags)) { /* NT */
1158               b0 += ((size_t)k0 * handle->ldb + in) * handle->itypesize;
1159             }
1160             else { /* NN or TN */
1161               b0 += ((size_t)in * handle->ldb + k0) * handle->itypesize;
1162             }
1163           }
1164           else { /* TT */
1165             b0 = (const char*)a + ((size_t)in * handle->ldb + k0) * handle->itypesize;
1166           }
1167           if (NULL == handle->copy_i.ptr_const) {
1168             ci = (NULL == handle->copy_o.ptr_const ? c0 : ct);
1169             if (LIBXSMM_GEMM_FLAG_TRANS_AB == (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags)) {
1170               const unsigned int km = handle->kn, kn = handle->km;
1171               libxsmm_otrans_internal(ct/*out*/, c0/*in*/, handle->otypesize, handle->ldc/*ldi*/, kn/*ldo*/,
1172                 0, km, 0, kn, km/*tile*/, kn/*tile*/, kernel);
1173               ci = ct;
1174             }
1175             else if (0 != (LIBXSMM_GEMM_HANDLE_FLAG_COPY_C & handle->flags)) {
1176               if (0 == (handle->gemm_flags & LIBXSMM_GEMM_FLAG_BETA_0)) { /* copy-in only if beta!=0 */
1177                 libxsmm_matcopy_internal(ct/*out*/, c0/*in*/, handle->otypesize, handle->ldc/*ldi*/, handle->km/*ldo*/,
1178                   0, handle->km, 0, handle->kn, handle->km/*tile*/, handle->kn/*tile*/, kernel);
1179               }
1180               ci = ct;
1181             }
1182           }
1183           else { /* MCOPY/TCOPY kernel */
1184             handle->copy_i.xmatcopy(c0, &handle->ldc, ct, &handle->km);
1185             ci = ct;
1186           }
1187           for (ik = k0, ik1 = ik + handle->kk; (ik1 - 1) < k1; ik = ik1, ik1 += handle->kk) {
1188             const char *const a1 = a0 + dka, *const b1 = b0 + dkb, *ai = a0, *bi = b0;
1189             if (NULL == handle->copy_a.ptr_const) {
1190               if (LIBXSMM_GEMM_FLAG_TRANS_AB != (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags) &&
1191                  (LIBXSMM_GEMM_FLAG_TRANS_A & handle->gemm_flags) != 0) /* pure A-transpose */
1192               {
1193                 LIBXSMM_ASSERT(ldo == handle->km);
1194                 libxsmm_otrans_internal(at/*out*/, a0/*in*/, handle->itypesize, handle->lda/*ldi*/, ldo,
1195                   0, handle->kk, 0, handle->km, handle->kk/*tile*/, handle->km/*tile*/, kernel);
1196                 ai = at;
1197               }
1198               else if (0 != (LIBXSMM_GEMM_HANDLE_FLAG_COPY_A & handle->flags)) {
1199                 libxsmm_matcopy_internal(at/*out*/, a0/*in*/, handle->itypesize, handle->lda/*ldi*/, ldo,
1200                   0, handle->km, 0, handle->kk, handle->km/*tile*/, handle->kk/*tile*/, kernel);
1201                 ai = at;
1202               }
1203             }
1204             else { /* MCOPY/TCOPY kernel */
1205               handle->copy_a.xmatcopy(a0, &handle->lda, at, &ldo);
1206               ai = at;
1207             }
1208             if (NULL == handle->copy_b.ptr_const) {
1209               if (LIBXSMM_GEMM_FLAG_TRANS_AB != (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags) &&
1210                  (LIBXSMM_GEMM_FLAG_TRANS_B & handle->gemm_flags) != 0) /* pure B-transpose */
1211               {
1212                 libxsmm_otrans_internal(bt/*out*/, b0/*in*/, handle->itypesize, handle->ldb/*ldi*/, handle->kk/*ldo*/,
1213                   0, handle->kn, 0, handle->kk, handle->kn/*tile*/, handle->kk/*tile*/, kernel);
1214                 bi = bt;
1215               }
1216               else if (0 != (LIBXSMM_GEMM_HANDLE_FLAG_COPY_B & handle->flags)) {
1217                 libxsmm_matcopy_internal(bt/*out*/, b0/*in*/, handle->itypesize, handle->ldb/*ldi*/, handle->kk/*ldo*/,
1218                   0, handle->kk, 0, handle->kn, handle->kk/*tile*/, handle->kn/*tile*/, kernel);
1219                 bi = bt;
1220               }
1221             }
1222             else { /* MCOPY/TCOPY kernel */
1223               handle->copy_b.xmatcopy(b0, &handle->ldb, bt, &handle->kk);
1224               bi = bt;
1225             }
1226             /* beta0-kernel on first-touch, beta1-kernel otherwise (beta0/beta1 are identical if beta=1) */
1227             LIBXSMM_MMCALL_PRF(handle->kernel[k0!=ik?1:0].xmm, ai, bi, ci, a1, b1, c0);
1228             a0 = a1;
1229             b0 = b1;
1230           }
1231           /* TODO: synchronize */
1232           if (NULL == handle->copy_o.ptr_const) {
1233             if (LIBXSMM_GEMM_FLAG_TRANS_AB == (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags)) {
1234               libxsmm_otrans_internal(c0/*out*/, ct/*in*/, handle->otypesize, handle->km/*ldi*/, handle->ldc/*ldo*/,
1235                 0, handle->km, 0, handle->kn, handle->km/*tile*/, handle->kn/*tile*/, kernel);
1236             }
1237             else if (0 != (LIBXSMM_GEMM_HANDLE_FLAG_COPY_C & handle->flags)) {
1238               libxsmm_matcopy_internal(c0/*out*/, ct/*in*/, handle->otypesize, handle->km/*ldi*/, handle->ldc/*ldo*/,
1239                 0, handle->km, 0, handle->kn, handle->km/*tile*/, handle->kn/*tile*/, kernel);
1240             }
1241           }
1242           else { /* MCOPY/TCOPY kernel */
1243             handle->copy_o.xmatcopy(ct, &handle->km, c0, &handle->ldc);
1244           }
1245         }
1246       }
1247     }
1248   }
1249 #if !defined(NDEBUG)
1250   else if (/*implies LIBXSMM_INIT*/0 != libxsmm_get_verbosity()) { /* library code is expected to be mute */
1251     static int error_once = 0;
1252     if (1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) {
1253       fprintf(stderr, "LIBXSMM ERROR: libxsmm_gemm_thread - invalid handle!\n");
1254     }
1255   }
1256 #endif
1257 }
1258 
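/* Illustrative sketch (disabled; not part of this translation unit's build): driving
 * libxsmm_gemm_thread from an OpenMP parallel region with a shared scratch buffer.
 * The function name example_gemm_omp, its arguments, and the use of <omp.h> are
 * assumptions made for the example only. */
#if 0
static void example_gemm_omp(libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k,
  const double* a, const double* b, double* c, int nthreads)
{
  const libxsmm_blasint lda = m, ldb = k, ldc = m;
  const double alpha = 1.0, beta = 1.0;
  libxsmm_gemm_blob blob;
  const libxsmm_gemm_handle *const handle = libxsmm_gemm_handle_init(&blob,
    LIBXSMM_GEMM_PRECISION_F64, LIBXSMM_GEMM_PRECISION_F64, "N", "N",
    &m, &n, &k, &lda, &ldb, &ldc, &alpha, &beta,
    LIBXSMM_GEMM_HANDLE_FLAG_AUTO, nthreads/*ntasks*/);
  if (NULL != handle) {
    /* the reported size already covers all mt * nt * kt tasks of the handle */
    const size_t scratch_size = libxsmm_gemm_handle_get_scratch_size(handle);
    void *const scratch = (0 != scratch_size
      ? libxsmm_scratch_malloc(scratch_size, LIBXSMM_CACHELINE, LIBXSMM_MALLOC_INTERNAL_CALLER)
      : NULL);
    if (0 == scratch_size || NULL != scratch) {
#     pragma omp parallel num_threads(nthreads)
      libxsmm_gemm_thread(handle, scratch, a, b, c, omp_get_thread_num(), omp_get_num_threads());
    }
    libxsmm_free(scratch);
  }
}
#endif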
1259 
LIBXSMM_API void libxsmm_xgemm(libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec,
1261   const char* transa, const char* transb, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
1262   const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb,
1263   const void* beta, void* c, const libxsmm_blasint* ldc)
1264 {
1265   libxsmm_gemm_blob blob;
1266   const libxsmm_gemm_handle *const handle = libxsmm_gemm_handle_init(&blob, iprec, oprec, transa, transb,
1267     m, n, k, lda, ldb, ldc, alpha, beta, LIBXSMM_GEMM_HANDLE_FLAG_AUTO, 1/*ntasks*/);
1268   const size_t scratch_size = libxsmm_gemm_handle_get_scratch_size(handle);
1269   void* scratch = NULL;
1270   if (NULL != handle && (0 == scratch_size ||
1271       NULL != (scratch = libxsmm_scratch_malloc(scratch_size, LIBXSMM_CACHELINE, LIBXSMM_MALLOC_INTERNAL_CALLER))))
1272   {
1273     libxsmm_gemm_thread(handle, scratch, a, b, c, 0/*tid*/, 1/*ntasks*/);
1274     libxsmm_free(scratch);
1275   }
1276   else { /* fall-back or error */
1277     static int error_once = 0;
1278     if (NULL == handle) { /* fall-back */
1279       if ((LIBXSMM_VERBOSITY_HIGH <= libxsmm_verbosity || 0 > libxsmm_verbosity) /* library code is expected to be mute */
1280         && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
1281       {
1282         fprintf(stderr, "LIBXSMM WARNING (XGEMM): fall-back code path triggered!\n");
1283       }
1284     }
1285     else if (0 != libxsmm_verbosity && /* library code is expected to be mute */
1286       1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
1287     {
1288       fprintf(stderr, "LIBXSMM ERROR: failed to allocate GEMM-scratch memory!\n");
1289     }
1290     libxsmm_blas_xgemm(iprec, oprec, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
1291   }
1292 }
1293 
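/* Illustrative sketch (disabled): calling the precision-generic entry point above with
 * double-precision data and column-major, BLAS-style arguments. The function name
 * example_xgemm_f64 and the matrix extents are assumptions made for the example only. */
#if 0
static void example_xgemm_f64(void)
{
  static double a[256*256], b[256*256], c[256*256]; /* column-major buffers */
  const libxsmm_blasint m = 256, n = 256, k = 256, lda = 256, ldb = 256, ldc = 256;
  const double alpha = 1.0, beta = 0.0;
  libxsmm_xgemm(LIBXSMM_GEMM_PRECISION_F64, LIBXSMM_GEMM_PRECISION_F64, "N", "N",
    &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc);
}
#endif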
1294 
LIBXSMM_API void libxsmm_dgemm_batch(
1296   const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
1297   const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[], const double* b_array[], const libxsmm_blasint ldb_array[],
1298   const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[])
1299 {
1300   const libxsmm_blasint ngroups = LIBXSMM_ABS(*group_count), ptrsize = sizeof(void*);
1301   libxsmm_blasint i, j = 0;
1302   for (i = 0; i < ngroups; ++i) {
1303     const libxsmm_blasint size = group_size[i];
1304     libxsmm_gemm_batch(LIBXSMM_GEMM_PRECISION_F64, LIBXSMM_GEMM_PRECISION_F64, transa_array + i, transb_array + i,
1305       m_array[i], n_array[i], k_array[i], alpha_array + i, a_array + j, lda_array + i, b_array + j, ldb_array + i, beta_array + i, c_array + j, ldc_array + i,
1306       0/*index_base*/, 0/*index_stride*/, &ptrsize, &ptrsize, &ptrsize, size);
1307     j += LIBXSMM_ABS(size);
1308   }
1309 }
1310 
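/* Illustrative sketch (disabled): one group of three equally shaped C += A * B products
 * via the grouped interface above. The function name example_dgemm_batch and all extents
 * are assumptions made for the example only. */
#if 0
static void example_dgemm_batch(const double* a[3], const double* b[3], double* c[3])
{
  const char transa = 'N', transb = 'N';
  const libxsmm_blasint m = 32, n = 32, k = 32, lda = 32, ldb = 32, ldc = 32;
  const libxsmm_blasint group_count = 1, group_size = 3;
  const double alpha = 1.0, beta = 1.0;
  libxsmm_dgemm_batch(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb,
    &beta, c, &ldc, &group_count, &group_size);
}
#endif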
1311 
LIBXSMM_API void libxsmm_sgemm_batch(
1313   const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
1314   const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[], const float* b_array[], const libxsmm_blasint ldb_array[],
1315   const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[])
1316 {
1317   const libxsmm_blasint ngroups = LIBXSMM_ABS(*group_count), ptrsize = sizeof(void*);
1318   libxsmm_blasint i, j = 0;
1319   for (i = 0; i < ngroups; ++i) {
1320     const libxsmm_blasint size = group_size[i];
1321     libxsmm_gemm_batch(LIBXSMM_GEMM_PRECISION_F32, LIBXSMM_GEMM_PRECISION_F32, transa_array + i, transb_array + i,
1322       m_array[i], n_array[i], k_array[i], alpha_array + i, a_array + j, lda_array + i, b_array + j, ldb_array + i, beta_array + i, c_array + j, ldc_array + i,
1323       0/*index_base*/, 0/*index_stride*/, &ptrsize, &ptrsize, &ptrsize, size);
1324     j += LIBXSMM_ABS(size);
1325   }
1326 }
1327 
1328 
LIBXSMM_API void libxsmm_dgemm(const char* transa, const char* transb,
1330   const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
1331   const double* alpha, const double* a, const libxsmm_blasint* lda,
1332   const double* b, const libxsmm_blasint* ldb,
1333   const double* beta, double* c, const libxsmm_blasint* ldc)
1334 {
1335   LIBXSMM_XGEMM(double, double, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
1336 }
1337 
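/* Illustrative sketch (disabled): the typed wrappers such as libxsmm_dgemm above take
 * BLAS-style arguments by pointer. The function name example_dgemm and all extents are
 * assumptions made for the example only. */
#if 0
static void example_dgemm(const double* a, const double* b, double* c)
{
  const libxsmm_blasint m = 64, n = 64, k = 64, lda = 64, ldb = 64, ldc = 64;
  const double alpha = 1.0, beta = 1.0;
  libxsmm_dgemm("N", "N", &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc);
}
#endif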
1338 
LIBXSMM_API void libxsmm_sgemm(const char* transa, const char* transb,
1340   const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
1341   const float* alpha, const float* a, const libxsmm_blasint* lda,
1342   const float* b, const libxsmm_blasint* ldb,
1343   const float* beta, float* c, const libxsmm_blasint* ldc)
1344 {
1345   LIBXSMM_XGEMM(float, float, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
1346 }
1347 
1348 
LIBXSMM_API void libxsmm_wigemm(const char* transa, const char* transb,
1350   const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
1351   const int* alpha, const short* a, const libxsmm_blasint* lda,
1352   const short* b, const libxsmm_blasint* ldb,
1353   const int* beta, int* c, const libxsmm_blasint* ldc)
1354 {
1355   LIBXSMM_XGEMM(short, int, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
1356 }
1357 
1358 
LIBXSMM_API void libxsmm_bsgemm(const char* transa, const char* transb,
1360   const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
1361   const float* alpha, const libxsmm_bfloat16* a, const libxsmm_blasint* lda,
1362   const libxsmm_bfloat16* b, const libxsmm_blasint* ldb,
1363   const float* beta, float* c, const libxsmm_blasint* ldc)
1364 {
1365   LIBXSMM_XGEMM(libxsmm_bfloat16, float, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
1366 }
1367 
1368 
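/* The kernel-driver below supports two ways of addressing the batch: with a non-zero
 * index_stride, the stride arrays hold per-multiplication element indexes (relative to
 * index_base) that are walked in steps of index_stride Bytes and scaled by the type size;
 * with a zero index_stride, a/b/c are arrays of pointers and the single stride values give
 * the Byte-distance between consecutive pointer slots. If locking is enabled, updates that
 * target the same C destination are serialized through a small pool of hashed locks; the
 * batch-reduce path instead gathers A/B pointers per run of identical C destinations and
 * calls the batch-reduce kernel. */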
LIBXSMM_API int libxsmm_mmbatch_kernel(libxsmm_xmmfunction kernel, libxsmm_blasint index_base,
1370   libxsmm_blasint index_stride, const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[],
1371   const void* a, const void* b, void* c, libxsmm_blasint batchsize, /*unsigned*/int tid, /*unsigned*/int ntasks,
1372   unsigned char itypesize, unsigned char otypesize, int flags)
1373 {
1374   int result = EXIT_SUCCESS;
1375   const libxsmm_blasint size = LIBXSMM_ABS(batchsize);
1376   const libxsmm_blasint tasksize = LIBXSMM_UPDIV(size, ntasks);
1377   const libxsmm_blasint begin = tid * tasksize, span = begin + tasksize;
1378   const libxsmm_blasint end = LIBXSMM_MIN(span, size);
1379 
1380   LIBXSMM_ASSERT(NULL != kernel.xmm);
1381   if (begin < end) {
1382     const char *const a0 = (const char*)a, *const b0 = (const char*)b;
1383     char *const c0 = (char*)c;
1384 
1385     LIBXSMM_ASSERT(0 < itypesize && 0 < otypesize);
1386     if (0 == (LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS & flags)) {
1387       if (0 != index_stride) { /* stride arrays contain indexes */
1388         libxsmm_blasint i = begin * index_stride, ic = (NULL != stride_c ? (LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, i) - index_base) : 0);
1389         const char* ai = &a0[NULL != stride_a ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_a, i) - index_base) * itypesize) : 0];
1390         const char* bi = &b0[NULL != stride_b ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_b, i) - index_base) * itypesize) : 0];
1391         char*       ci = &c0[ic * otypesize];
1392         const libxsmm_blasint end1 = (end != size ? end : (end - 1)) * index_stride;
1393 #if (0 != LIBXSMM_SYNC)
1394         if (1 == ntasks || 0 == internal_gemm_nlocks || 0 > batchsize || 0 != (LIBXSMM_GEMM_FLAG_BETA_0 & flags))
1395 #endif
1396         { /* no locking */
1397           if (NULL != stride_a && NULL != stride_b && NULL != stride_c) {
1398             const unsigned char ibits = (unsigned char)LIBXSMM_INTRINSICS_BITSCANBWD32(itypesize);
1399             const unsigned char obits = (unsigned char)LIBXSMM_INTRINSICS_BITSCANBWD32(otypesize);
1400 
1401             if (itypesize == (1 << ibits) && otypesize == (1 << obits)) {
1402               for (i += index_stride; i <= end1; i += index_stride) {
1403                 const char *const an = &a0[(LIBXSMM_ACCESS(const libxsmm_blasint, stride_a, i) - index_base) << ibits];
1404                 const char *const bn = &b0[(LIBXSMM_ACCESS(const libxsmm_blasint, stride_b, i) - index_base) << ibits];
1405                 char       *const cn = &c0[(LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, i) - index_base) << obits];
1406                 kernel.xmm(ai, bi, ci, an, bn, cn); /* with prefetch */
1407                 ai = an; bi = bn; ci = cn;
1408               }
1409             }
1410             else { /* non-pot type sizes */
1411               for (i += index_stride; i <= end1; i += index_stride) {
1412                 const char *const an = &a0[(LIBXSMM_ACCESS(const libxsmm_blasint, stride_a, i) - index_base) * itypesize];
1413                 const char *const bn = &b0[(LIBXSMM_ACCESS(const libxsmm_blasint, stride_b, i) - index_base) * itypesize];
1414                 char       *const cn = &c0[(LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, i) - index_base) * otypesize];
1415                 kernel.xmm(ai, bi, ci, an, bn, cn); /* with prefetch */
1416                 ai = an; bi = bn; ci = cn;
1417               }
1418             }
1419           }
1420           else { /* mixed specification of strides */
1421             for (i += index_stride; i <= end1; i += index_stride) {
1422               const char *const an = &a0[NULL != stride_a ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_a, i) - index_base) * itypesize) : 0];
1423               const char *const bn = &b0[NULL != stride_b ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_b, i) - index_base) * itypesize) : 0];
1424               char       *const cn = &c0[NULL != stride_c ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, i) - index_base) * otypesize) : 0];
1425               kernel.xmm(ai, bi, ci, an, bn, cn); /* with prefetch */
1426               ai = an; bi = bn; ci = cn;
1427             }
1428           }
1429           if (end == size) { /* remainder multiplication */
1430             kernel.xmm(ai, bi, ci, ai, bi, ci); /* pseudo-prefetch */
1431           }
1432         }
1433 #if (0 != LIBXSMM_SYNC)
1434         else { /* synchronize among C-indexes */
1435           LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK)* lock = &internal_gemm_lock[LIBXSMM_GEMM_LOCKIDX(ic, internal_gemm_nlocks)].state;
1436 # if defined(LIBXSMM_GEMM_LOCKFWD)
1437           LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK)* lock0 = NULL;
1438 # endif
1439           LIBXSMM_ASSERT(NULL != lock);
1440           if (NULL != stride_a && NULL != stride_b && NULL != stride_c) {
1441             for (i += index_stride; i <= end1; i += index_stride) {
1442               ic = LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, i) - index_base;
1443               {
1444                 const char *const an = &a0[(LIBXSMM_ACCESS(const libxsmm_blasint, stride_a, i) - index_base) * itypesize];
1445                 const char *const bn = &b0[(LIBXSMM_ACCESS(const libxsmm_blasint, stride_b, i) - index_base) * itypesize];
1446                 char       *const cn = &c0[ic * otypesize];
1447                 LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK) *const lock1 = &internal_gemm_lock[LIBXSMM_GEMM_LOCKIDX(ic, internal_gemm_nlocks)].state;
1448 # if defined(LIBXSMM_GEMM_LOCKFWD)
1449                 if (lock != lock0) { lock0 = lock; LIBXSMM_LOCK_ACQUIRE(LIBXSMM_GEMM_LOCK, lock); }
1450 # else
1451                 LIBXSMM_LOCK_ACQUIRE(LIBXSMM_GEMM_LOCK, lock);
1452 # endif
1453                 kernel.xmm(ai, bi, ci, an, bn, cn); /* with prefetch */
1454 # if defined(LIBXSMM_GEMM_LOCKFWD)
1455                 if (lock != lock1 || i == end1) { LIBXSMM_LOCK_RELEASE(LIBXSMM_GEMM_LOCK, lock); lock = lock1; }
1456 # else
1457                 LIBXSMM_LOCK_RELEASE(LIBXSMM_GEMM_LOCK, lock); lock = lock1;
1458 # endif
1459                 ai = an; bi = bn; ci = cn; /* next */
1460               }
1461             }
1462           }
1463           else {
1464             for (i += index_stride; i <= end1; i += index_stride) {
1465               ic = (NULL != stride_c ? (LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, i) - index_base) : 0);
1466               {
1467                 const char *const an = &a0[NULL != stride_a ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_a, i) - index_base) * itypesize) : 0];
1468                 const char *const bn = &b0[NULL != stride_b ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_b, i) - index_base) * itypesize) : 0];
1469                 char       *const cn = &c0[ic * otypesize];
1470                 LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK) *const lock1 = &internal_gemm_lock[LIBXSMM_GEMM_LOCKIDX(ic, internal_gemm_nlocks)].state;
1471 # if defined(LIBXSMM_GEMM_LOCKFWD)
1472                 if (lock != lock0) { lock0 = lock; LIBXSMM_LOCK_ACQUIRE(LIBXSMM_GEMM_LOCK, lock); }
1473 # else
1474                 LIBXSMM_LOCK_ACQUIRE(LIBXSMM_GEMM_LOCK, lock);
1475 # endif
1476                 kernel.xmm(ai, bi, ci, an, bn, cn); /* with prefetch */
1477 # if defined(LIBXSMM_GEMM_LOCKFWD)
1478                 if (lock != lock1 || i == end1) { LIBXSMM_LOCK_RELEASE(LIBXSMM_GEMM_LOCK, lock); lock = lock1; }
1479 # else
1480                 LIBXSMM_LOCK_RELEASE(LIBXSMM_GEMM_LOCK, lock); lock = lock1;
1481 # endif
1482                 ai = an; bi = bn; ci = cn; /* next */
1483               }
1484             }
1485           }
1486           if (end == size) { /* remainder multiplication */
1487             LIBXSMM_LOCK_ACQUIRE(LIBXSMM_GEMM_LOCK, lock);
1488             kernel.xmm(ai, bi, ci, ai, bi, ci); /* pseudo-prefetch */
1489             LIBXSMM_LOCK_RELEASE(LIBXSMM_GEMM_LOCK, lock);
1490           }
1491         }
1492 #endif /*(0 != LIBXSMM_SYNC)*/
1493       }
1494       else { /* singular strides are measured in Bytes */
1495         const libxsmm_blasint da = (NULL != stride_a ? (*stride_a - index_base * sizeof(void*)) : 0);
1496         const libxsmm_blasint db = (NULL != stride_b ? (*stride_b - index_base * sizeof(void*)) : 0);
1497         const libxsmm_blasint dc = (NULL != stride_c ? (*stride_c - index_base * sizeof(void*)) : 0);
1498         libxsmm_blasint i;
1499         const libxsmm_blasint end1 = (end != size ? end : (end - 1));
1500         const char *ai = a0 + (size_t)da * begin, *bi = b0 + (size_t)db * begin;
1501         char* ci = c0 + (size_t)dc * begin;
1502 #if (0 != LIBXSMM_SYNC)
1503         if (1 == ntasks || 0 == internal_gemm_nlocks || 0 > batchsize || 0 != (LIBXSMM_GEMM_FLAG_BETA_0 & flags))
1504 #endif
1505         { /* no locking */
1506           for (i = begin; i < end1; ++i) {
1507             const char *const an = ai + da, *const bn = bi + db;
1508             char *const cn = ci + dc;
1509 #if defined(LIBXSMM_GEMM_CHECK)
1510             if (NULL != *((const void**)ai) && NULL != *((const void**)bi) && NULL != *((const void**)ci))
1511 #endif
1512             {
1513               kernel.xmm( /* with prefetch */
1514                 *((const void**)ai), *((const void**)bi), *((void**)ci),
1515                 *((const void**)an), *((const void**)bn), *((const void**)cn));
1516             }
1517             ai = an; bi = bn; ci = cn; /* next */
1518           }
1519           if ( /* remainder multiplication */
1520 #if defined(LIBXSMM_GEMM_CHECK)
1521             NULL != *((const void**)ai) && NULL != *((const void**)bi) && NULL != *((const void**)ci) &&
1522 #endif
1523             end == size)
1524           {
1525             kernel.xmm( /* pseudo-prefetch */
1526               *((const void**)ai), *((const void**)bi), *((void**)ci),
1527               *((const void**)ai), *((const void**)bi), *((const void**)ci));
1528           }
1529         }
1530 #if (0 != LIBXSMM_SYNC)
1531         else { /* synchronize among C-indexes */
1532           void* cc = *((void**)ci);
1533           LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK)* lock = &internal_gemm_lock[LIBXSMM_GEMM_LOCKPTR(cc, internal_gemm_nlocks)].state;
1534 # if defined(LIBXSMM_GEMM_LOCKFWD)
1535           LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK)* lock0 = NULL;
1536 # endif
1537           LIBXSMM_ASSERT(NULL != lock);
1538           for (i = begin + 1; i <= end1; ++i) {
1539             const char *const an = ai + da, *const bn = bi + db;
1540             char *const cn = ci + dc;
1541             void *const nc = *((void**)cn);
1542 # if defined(LIBXSMM_GEMM_CHECK)
1543             if (NULL != *((const void**)ai) && NULL != *((const void**)bi) && NULL != cc)
1544 # endif
1545             {
1546               LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK) *const lock1 = &internal_gemm_lock[LIBXSMM_GEMM_LOCKPTR(nc, internal_gemm_nlocks)].state;
1547 # if defined(LIBXSMM_GEMM_LOCKFWD)
1548               if (lock != lock0) { lock0 = lock; LIBXSMM_LOCK_ACQUIRE(LIBXSMM_GEMM_LOCK, lock); }
1549 # else
1550               LIBXSMM_LOCK_ACQUIRE(LIBXSMM_GEMM_LOCK, lock);
1551 # endif
1552               kernel.xmm( /* with prefetch */
1553                 *((const void**)ai), *((const void**)bi), cc,
1554                 *((const void**)an), *((const void**)bn), *((const void**)cn));
1555 # if defined(LIBXSMM_GEMM_LOCKFWD)
1556               if (lock != lock1 || i == end1) { LIBXSMM_LOCK_RELEASE(LIBXSMM_GEMM_LOCK, lock); lock = lock1; }
1557 # else
1558               LIBXSMM_LOCK_RELEASE(LIBXSMM_GEMM_LOCK, lock); lock = lock1;
1559 # endif
1560             }
1561             ai = an; bi = bn; ci = cn; cc = nc; /* next */
1562           }
1563           if ( /* remainder multiplication */
1564 # if defined(LIBXSMM_GEMM_CHECK)
1565             NULL != *((const void**)ai) && NULL != *((const void**)bi) && NULL != cc &&
1566 # endif
1567             end == size)
1568           {
1569             LIBXSMM_LOCK_ACQUIRE(LIBXSMM_GEMM_LOCK, lock);
1570             kernel.xmm( /* pseudo-prefetch */
1571               *((const void**)ai), *((const void**)bi), cc,
1572               *((const void**)ai), *((const void**)bi), cc);
1573             LIBXSMM_LOCK_RELEASE(LIBXSMM_GEMM_LOCK, lock);
1574           }
1575         }
1576 #endif /*(0 != LIBXSMM_SYNC)*/
1577       }
1578     }
1579 #if defined(LIBXSMM_GEMM_BATCHREDUCE)
1580     else /* LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS */
1581 # if defined(LIBXSMM_GEMM_CHECK)
1582     if (
1583 #   if (0 != LIBXSMM_SYNC)
1584       (1 == ntasks || 0 == internal_gemm_nlocks || 0 > batchsize) &&
1585 #   endif
1586       (0 == (LIBXSMM_GEMM_FLAG_BETA_0 & flags)) &&
1587       (0 != internal_gemm_batchreduce))
1588 # endif
1589     {
1590       const unsigned int n = libxsmm_mmbatch_size * (LIBXSMM_GEMM_BATCHSCALE) / ((unsigned int)sizeof(void*));
1591       LIBXSMM_ASSERT(NULL != libxsmm_mmbatch_array && 0 != libxsmm_mmbatch_size);
1592       if ((2U/*A and B matrices*/ * tasksize) <= n) {
1593         const void **ai = (const void**)libxsmm_mmbatch_array + begin, **bi = ai + size;
1594         unsigned long long count;
1595         if (0 != index_stride) { /* stride arrays contain indexes */
1596           const size_t end_stride = (size_t)end * index_stride;
1597           size_t i = (size_t)begin * index_stride;
1598           char *ci = &c0[NULL != stride_c ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, i) - index_base) * otypesize) : 0], *cn = ci;
1599           do {
1600             for (count = 0; i < end_stride && ci == cn; ++count) {
1601               const size_t j = i + index_stride;
1602               *ai++ = &a0[NULL != stride_a ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_a, i) - index_base) * itypesize) : 0];
1603               *bi++ = &b0[NULL != stride_b ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_b, i) - index_base) * itypesize) : 0];
1604                  cn = &c0[NULL != stride_c ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, j) - index_base) * otypesize) : 0];
1605               i = j;
1606             }
1607             ai = (const void**)libxsmm_mmbatch_array + begin; bi = ai + size;
1608             kernel.xbm(ai, bi, ci, &count);
1609             ci = cn;
1610           } while (i < end_stride);
1611         }
1612         else { /* singular strides are measured in Bytes */
1613           const libxsmm_blasint da = (NULL != stride_a ? (*stride_a - index_base * sizeof(void*)) : 0);
1614           const libxsmm_blasint db = (NULL != stride_b ? (*stride_b - index_base * sizeof(void*)) : 0);
1615           const libxsmm_blasint dc = (NULL != stride_c ? (*stride_c - index_base * sizeof(void*)) : 0);
1616           const char *ia = a0 + (size_t)da * begin, *ib = b0 + (size_t)db * begin;
1617           char* ic = c0 + (size_t)dc * begin;
1618           if (
1619 # if defined(LIBXSMM_GEMM_CHECK)
1620             NULL != *((const void**)ia) && NULL != *((const void**)ib) && NULL != *((const void**)ic) &&
1621 # endif
1622             sizeof(void*) == da && sizeof(void*) == db) /* fast path */
1623           {
1624             if (0 != dc) {
1625               libxsmm_blasint i = begin;
1626               char* jc = ic;
1627               do {
1628                 for (count = 0; i < end && *((const void**)ic) == *((const void**)jc); ++i) {
1629 # if defined(LIBXSMM_GEMM_CHECK)
1630                   if (NULL != *((const void**)jc))
1631 # endif
1632                   ++count;
1633                   jc += dc; /* next */
1634                 }
1635                 memcpy((void*)ai, ia, count * sizeof(void*));
1636                 memcpy((void*)bi, ib, count * sizeof(void*));
1637                 kernel.xbm(ai, bi, *((void**)ic), &count);
1638                 ic = jc;
1639               } while (i < end);
1640             }
1641             else { /* fastest path */
1642               count = (unsigned long long)end - begin;
1643               memcpy((void*)ai, ia, count * sizeof(void*));
1644               memcpy((void*)bi, ib, count * sizeof(void*));
1645               kernel.xbm(ai, bi, *((void**)ic), &count);
1646             }
1647           }
1648           else { /* custom-copy required */
1649             libxsmm_blasint i = begin;
1650             char* jc = ic;
1651             do {
1652               for (count = 0; i < end && *((const void**)ic) == *((const void**)jc); ++i) {
1653 # if defined(LIBXSMM_GEMM_CHECK)
1654                 if (NULL != *((const void**)ia) && NULL != *((const void**)ib) && NULL != *((const void**)jc))
1655 # endif
1656                 {
1657                   *ai++ = *((const void**)ia); *bi++ = *((const void**)ib);
1658                   ++count;
1659                 }
1660                 ia += da; ib += db; jc += dc; /* next */
1661               }
1662               ai = (const void**)libxsmm_mmbatch_array + begin; bi = ai + size;
1663               kernel.xbm(ai, bi, *((void**)ic), &count);
1664               ic = jc;
1665             } while (i < end);
1666           }
1667         }
1668       }
1669       else { /* fall-back */
1670         result = EXIT_FAILURE;
1671       }
1672     }
1673 #endif /*defined(LIBXSMM_GEMM_BATCHREDUCE)*/
1674   }
1675   /* coverity[missing_unlock] */
1676   return result;
1677 }
1678 
1679 
LIBXSMM_API void libxsmm_gemm_internal_set_batchflag(libxsmm_gemm_descriptor* descriptor, void* c, libxsmm_blasint index_stride,
1681   libxsmm_blasint batchsize, int multithreaded)
1682 {
1683   LIBXSMM_ASSERT(NULL != descriptor);
1684   if (0 != (LIBXSMM_GEMM_FLAG_BETA_0 & descriptor->flags)) {
1685     const uintptr_t vw = (LIBXSMM_X86_AVX512 <= libxsmm_target_archid ? 64 : 32);
1686     /* assume that all C-matrices are aligned eventually */
1687     if (0 == LIBXSMM_MOD2((uintptr_t)c, vw)
1688 #if 0 /* should fall-back in BE */
1689       && LIBXSMM_X86_AVX <= libxsmm_target_archid
1690 #endif
1691       && 0 != index_stride)
1692     {
1693       const int oprec = LIBXSMM_GETENUM_OUT(descriptor->datatype);
1694       const libxsmm_blasint typesize = LIBXSMM_TYPESIZE(oprec);
1695       const libxsmm_blasint csize = (libxsmm_blasint)descriptor->ldc * descriptor->n * typesize;
1696       /* finalize assumption if matrix-size is a multiple of the vector-width */
1697       descriptor->flags |= (unsigned short)(0 == LIBXSMM_MOD2(csize, vw) ? LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT : 0);
1698     }
1699   }
1700 #if defined(LIBXSMM_GEMM_BATCHREDUCE)
1701   else if (0 != internal_gemm_batchreduce) { /* check if reduce-batch kernel can be used */
1702     static int error_once = 0;
1703     LIBXSMM_ASSERT(NULL != libxsmm_mmbatch_array);
1704 # if (0 != LIBXSMM_SYNC)
1705     if (0 == multithreaded || 0 == internal_gemm_nlocks || 0 > batchsize)
1706 # endif
1707     {
1708       int result = EXIT_FAILURE;
1709       switch (LIBXSMM_GETENUM_INP(descriptor->datatype)) {
1710         case LIBXSMM_GEMM_PRECISION_F64: {
1711           if (LIBXSMM_GEMM_PRECISION_F64 == LIBXSMM_GETENUM_OUT(descriptor->datatype)) {
1712             result = EXIT_SUCCESS;
1713           }
1714         } break;
1715         case LIBXSMM_GEMM_PRECISION_F32: {
1716           if (LIBXSMM_GEMM_PRECISION_F32 == LIBXSMM_GETENUM_OUT(descriptor->datatype)) {
1717             result = EXIT_SUCCESS;
1718           }
1719         } break;
1720       }
1721       if (EXIT_SUCCESS == result) {
1722         descriptor->flags |= LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS;
1723         descriptor->prefetch = 0; /* omit decision */
1724       }
1725       else {
1726         if ((LIBXSMM_VERBOSITY_WARN <= libxsmm_verbosity || 0 > libxsmm_verbosity) && /* library code is expected to be mute */
1727           1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
1728         {
1729           fprintf(stderr, "LIBXSMM WARNING: data type not supported in batch-reduce!\n");
1730         }
1731       }
1732     }
1733 # if (0 != LIBXSMM_SYNC)
1734     else if ((LIBXSMM_VERBOSITY_WARN <= libxsmm_verbosity || 0 > libxsmm_verbosity) && /* library code is expected to be mute */
1735       1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
1736     {
      fprintf(stderr, "LIBXSMM WARNING: potential data races prevent batch-reduce!\n");
1738     }
1739 # endif
1740   }
1741 #endif /*defined(LIBXSMM_GEMM_BATCHREDUCE)*/
1742 #if !defined(LIBXSMM_GEMM_BATCHREDUCE) || (0 == LIBXSMM_SYNC)
1743   LIBXSMM_UNUSED(batchsize); LIBXSMM_UNUSED(multithreaded);
1744 #endif
1745 }
1746 
1747 
LIBXSMM_API_INTERN void libxsmm_dmmbatch_blas(const char* transa, const char* transb, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k,
1749   const double* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb, const double* beta, void* c, const libxsmm_blasint* ldc,
1750   libxsmm_blasint index_base, libxsmm_blasint index_stride, const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[],
1751   libxsmm_blasint batchsize)
1752 {
1753   const libxsmm_blasint end = LIBXSMM_ABS(batchsize);
1754   libxsmm_blasint i;
1755 
1756   if (0 != index_stride) { /* stride arrays contain indexes */
1757     const libxsmm_blasint da = (NULL != stride_a ? (*stride_a - index_base) : 0);
1758     const libxsmm_blasint db = (NULL != stride_b ? (*stride_b - index_base) : 0);
1759     const libxsmm_blasint dc = (NULL != stride_c ? (*stride_c - index_base) : 0);
1760     const libxsmm_blasint end1 = end * index_stride;
1761     const double *const a0 = (const double*)a, *const b0 = (const double*)b, *ai = a0 + da, *bi = b0 + db;
1762     double *const c0 = (double*)c, *ci = c0 + dc;
1763     for (i = index_stride; i <= end1; i += index_stride) {
1764       const double *const an = &a0[NULL != stride_a ? (LIBXSMM_ACCESS(const libxsmm_blasint, stride_a, i) - index_base) : 0];
1765       const double *const bn = &b0[NULL != stride_b ? (LIBXSMM_ACCESS(const libxsmm_blasint, stride_b, i) - index_base) : 0];
1766       double       *const cn = &c0[NULL != stride_c ? (LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, i) - index_base) : 0];
1767 #if defined(LIBXSMM_GEMM_CHECK)
1768       if (NULL != ai && NULL != bi && NULL != ci)
1769 #endif
1770       {
1771         libxsmm_blas_dgemm(transa, transb, &m, &n, &k, alpha, ai, lda, bi, ldb, beta, ci, ldc);
1772       }
1773       ai = an; bi = bn; ci = cn;
1774     }
1775   }
1776   else { /* singular strides are measured in Bytes */
1777     const libxsmm_blasint da = (NULL != stride_a ? (*stride_a - index_base * sizeof(void*)) : 0);
1778     const libxsmm_blasint db = (NULL != stride_b ? (*stride_b - index_base * sizeof(void*)) : 0);
1779     const libxsmm_blasint dc = (NULL != stride_c ? (*stride_c - index_base * sizeof(void*)) : 0);
1780     const char *const a0 = (const char*)a, *const b0 = (const char*)b, *ai = a0, *bi = b0;
1781     char *const c0 = (char*)c, *ci = c0;
1782     for (i = 0; i < end; ++i) {
1783       const char *const an = ai + da, *const bn = bi + db;
1784       char *const cn = ci + dc;
1785 #if defined(LIBXSMM_GEMM_CHECK)
1786       if (NULL != *((const double**)ai) && NULL != *((const double**)bi) && NULL != *((const double**)ci))
1787 #endif
1788       {
1789         libxsmm_blas_dgemm(transa, transb, &m, &n, &k, alpha, *((const double**)ai), lda, *((const double**)bi), ldb, beta, *((double**)ci), ldc);
1790       }
1791       ai = an; bi = bn; ci = cn; /* next */
1792     }
1793   }
1794 }
1795 
1796 
LIBXSMM_API_INTERN void libxsmm_smmbatch_blas(const char* transa, const char* transb, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k,
1798   const float* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb, const float* beta, void* c, const libxsmm_blasint* ldc,
1799   libxsmm_blasint index_base, libxsmm_blasint index_stride, const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[],
1800   libxsmm_blasint batchsize)
1801 {
1802   const libxsmm_blasint end = LIBXSMM_ABS(batchsize);
1803   libxsmm_blasint i;
1804 
1805   if (0 != index_stride) { /* stride arrays contain indexes */
1806     const libxsmm_blasint da = (NULL != stride_a ? (*stride_a - index_base) : 0);
1807     const libxsmm_blasint db = (NULL != stride_b ? (*stride_b - index_base) : 0);
1808     const libxsmm_blasint dc = (NULL != stride_c ? (*stride_c - index_base) : 0);
1809     const libxsmm_blasint end1 = end * index_stride;
1810     const float *a0 = (const float*)a, *b0 = (const float*)b, *ai = a0 + da, *bi = b0 + db;
1811     float *c0 = (float*)c, *ci = c0 + dc;
1812     for (i = index_stride; i <= end1; i += index_stride) {
1813       const float *const an = &a0[NULL != stride_a ? (LIBXSMM_ACCESS(const libxsmm_blasint, stride_a, i) - index_base) : 0];
1814       const float *const bn = &b0[NULL != stride_b ? (LIBXSMM_ACCESS(const libxsmm_blasint, stride_b, i) - index_base) : 0];
1815       float       *const cn = &c0[NULL != stride_c ? (LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, i) - index_base) : 0];
1816 #if defined(LIBXSMM_GEMM_CHECK)
1817       if (NULL != ai && NULL != bi && NULL != ci)
1818 #endif
1819       {
1820         libxsmm_blas_sgemm(transa, transb, &m, &n, &k, alpha, ai, lda, bi, ldb, beta, ci, ldc);
1821       }
1822       ai = an; bi = bn; ci = cn;
1823     }
1824   }
1825   else { /* singular strides are measured in Bytes */
1826     const libxsmm_blasint da = (NULL != stride_a ? (*stride_a - index_base * sizeof(void*)) : 0);
1827     const libxsmm_blasint db = (NULL != stride_b ? (*stride_b - index_base * sizeof(void*)) : 0);
1828     const libxsmm_blasint dc = (NULL != stride_c ? (*stride_c - index_base * sizeof(void*)) : 0);
1829     const char *a0 = (const char*)a, *b0 = (const char*)b, *ai = a0, *bi = b0;
1830     char *c0 = (char*)c, *ci = c0;
1831     for (i = 0; i < end; ++i) {
1832       const char *const an = ai + da;
1833       const char *const bn = bi + db;
1834       char *const cn = ci + dc;
1835 #if defined(LIBXSMM_GEMM_CHECK)
1836       if (NULL != *((const float**)ai) && NULL != *((const float**)bi) && NULL != *((const float**)ci))
1837 #endif
1838       {
1839         libxsmm_blas_sgemm(transa, transb, &m, &n, &k, alpha, *((const float**)ai), lda, *((const float**)bi), ldb, beta, *((float**)ci), ldc);
1840       }
1841       ai = an; bi = bn; ci = cn; /* next */
1842     }
1843   }
1844 }
1845 
1846 
LIBXSMM_API int libxsmm_mmbatch_blas(
1848   libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, const char* transa, const char* transb, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k,
1849   const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb, const void* beta, void* c, const libxsmm_blasint* ldc,
1850   libxsmm_blasint index_base, libxsmm_blasint index_stride, const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[],
1851   libxsmm_blasint batchsize)
1852 {
1853   int result;
1854   if (NULL != a && NULL != b && NULL != c) {
1855     switch (LIBXSMM_GETENUM(iprec, oprec)) {
1856       case LIBXSMM_GEMM_PRECISION_F64: {
1857         libxsmm_dmmbatch_blas(transa, transb, m, n, k,
1858           (const double*)alpha, a, lda, b, ldb, (const double*)beta, c, ldc,
1859           index_base, index_stride, stride_a, stride_b, stride_c, batchsize);
1860         result = EXIT_SUCCESS;
1861       } break;
1862       case LIBXSMM_GEMM_PRECISION_F32: {
1863         libxsmm_smmbatch_blas(transa, transb, m, n, k,
1864           (const float*)alpha, a, lda, b, ldb, (const float*)beta, c, ldc,
1865           index_base, index_stride, stride_a, stride_b, stride_c, batchsize);
1866         result = EXIT_SUCCESS;
1867       } break;
1868       default: result = EXIT_FAILURE;
1869     }
1870   }
1871   else {
1872     result = EXIT_FAILURE;
1873   }
1874   return result;
1875 }
1876 
1877 
LIBXSMM_API void libxsmm_mmbatch(libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec,
1879   const char* transa, const char* transb, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k,
1880   const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb,
1881   const void* beta, void* c, const libxsmm_blasint* ldc, libxsmm_blasint index_base, libxsmm_blasint index_stride,
1882   const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[],
1883   libxsmm_blasint batchsize, /*unsigned*/int tid, /*unsigned*/int nthreads)
1884 {
1885   static int error_once = 0;
1886 #if defined(LIBXSMM_GEMM_CHECK)
1887   if (NULL != a && NULL != b && NULL != c && 0 <= tid && tid < nthreads)
1888 #endif
1889   {
1890     const unsigned char otypesize = libxsmm_typesize((libxsmm_datatype)oprec);
1891     int result = EXIT_FAILURE;
1892     LIBXSMM_INIT
1893     if (LIBXSMM_SMM_AI(m, n, k, 2/*RFO*/, otypesize)) { /* check if an SMM is suitable */
1894       const int gemm_flags = LIBXSMM_GEMM_PFLAGS(transa, transb, LIBXSMM_FLAGS);
1895       libxsmm_descriptor_blob blob;
1896       libxsmm_gemm_descriptor *const desc = libxsmm_gemm_descriptor_init2(&blob, iprec, oprec, m, n, k,
1897         NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k),
1898         NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n),
1899         NULL != ldc ? *ldc : m, alpha, beta, gemm_flags, libxsmm_get_gemm_prefetch(LIBXSMM_PREFETCH_AUTO));
1900       if (NULL != desc) {
1901         libxsmm_xmmfunction kernel;
1902         libxsmm_gemm_internal_set_batchflag(desc, c, index_stride, batchsize, 0/*multi-threaded*/);
1903         kernel = libxsmm_xmmdispatch(desc);
1904         if (NULL != kernel.xmm) {
1905           result = libxsmm_mmbatch_kernel(kernel, index_base, index_stride,
1906             stride_a, stride_b, stride_c, a, b, c, batchsize, tid, nthreads,
1907             libxsmm_typesize((libxsmm_datatype)iprec), otypesize, desc->flags);
1908         }
1909       }
1910     }
1911     if (EXIT_SUCCESS != result) { /* quiet fall-back */
1912       if (EXIT_SUCCESS == libxsmm_mmbatch_blas(iprec, oprec,
1913         transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
1914         index_base, index_stride, stride_a, stride_b, stride_c, batchsize))
1915       {
1916         if (LIBXSMM_VERBOSITY_WARN <= libxsmm_verbosity || 0 > libxsmm_verbosity) {
          const size_t threshold = LIBXSMM_MNK_SIZE(m, n, k);
1918           static size_t threshold_max = 0;
1919           if (threshold_max < threshold) {
1920             LIBXSMM_STDIO_ACQUIRE();
1921             fprintf(stderr, "LIBXSMM WARNING: ");
1922             libxsmm_gemm_print2(stderr, iprec, oprec, transa, transb, &m, &n, &k,
1923               alpha, NULL/*a*/, lda, NULL/*b*/, ldb, beta, NULL/*c*/, ldc);
1924             fprintf(stderr, " => batched GEMM was falling back to BLAS!\n");
            LIBXSMM_STDIO_RELEASE();
            threshold_max = threshold;
          }
        }
      }
      else if (0 != libxsmm_verbosity /* library code is expected to be mute */
        && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
      {
        fprintf(stderr, "LIBXSMM ERROR: libxsmm_mmbatch failed!\n");
      }
    }
  }
#if defined(LIBXSMM_GEMM_CHECK)
  else if (0 != libxsmm_verbosity /* library code is expected to be mute */
    && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
  {
    fprintf(stderr, "LIBXSMM ERROR: incorrect arguments (libxsmm_mmbatch)!\n");
  }
#endif
}
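

/* Usage sketch (not part of LIBXSMM): how the thread-aware libxsmm_mmbatch above might be
 * driven from an OpenMP parallel region, with every thread passing its own tid and the
 * total nthreads. The helper name example_mmbatch_omp, the strides sa/sb/sc, and the batch
 * layout chosen here (index_stride of zero with scalar byte-strides between consecutive
 * matrices) are assumptions made for illustration only; consult the LIBXSMM documentation
 * for the exact semantics of index_base, index_stride, and the stride arrays.
 */
#if 0 /* illustrative only; excluded from the build */
#include <omp.h>
void example_mmbatch_omp(const double* a, const double* b, double* c,
  libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint batchsize)
{
  /* assumed byte-strides between consecutive matrices of a contiguous batch */
  const libxsmm_blasint sa = (libxsmm_blasint)(sizeof(double) * m * k);
  const libxsmm_blasint sb = (libxsmm_blasint)(sizeof(double) * k * n);
  const libxsmm_blasint sc = (libxsmm_blasint)(sizeof(double) * m * n);
  const double alpha = 1.0, beta = 1.0;
# pragma omp parallel /* each thread works on its share of the batch */
  libxsmm_mmbatch(LIBXSMM_GEMM_PRECISION_F64, LIBXSMM_GEMM_PRECISION_F64,
    "N", "N", m, n, k, &alpha, a, &m, b, &k, &beta, c, &m,
    0/*index_base*/, 0/*index_stride*/, &sa, &sb, &sc, batchsize,
    omp_get_thread_num(), omp_get_num_threads());
}
#endif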


LIBXSMM_API void libxsmm_gemm_batch(libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec,
  const char* transa, const char* transb, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k,
  const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb,
  const void* beta, void* c, const libxsmm_blasint* ldc, libxsmm_blasint index_base, libxsmm_blasint index_stride,
  const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[],
  libxsmm_blasint batchsize)
{
  libxsmm_mmbatch(iprec, oprec, transa, transb, m, n, k,
    alpha, a, lda, b, ldb, beta, c, ldc, index_base, index_stride,
    stride_a, stride_b, stride_c, batchsize, 0/*tid*/, 1/*nthreads*/);
}
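

/* Usage sketch (not part of LIBXSMM): the sequential entry point above is simply
 * libxsmm_mmbatch with tid=0 and nthreads=1. The helper name example_gemm_batch and the
 * batch layout (index_stride of zero with scalar byte-strides) are assumptions made for
 * illustration only.
 */
#if 0 /* illustrative only; excluded from the build */
void example_gemm_batch(const float* a, const float* b, float* c,
  libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, libxsmm_blasint batchsize)
{
  const libxsmm_blasint sa = (libxsmm_blasint)(sizeof(float) * m * k);
  const libxsmm_blasint sb = (libxsmm_blasint)(sizeof(float) * k * n);
  const libxsmm_blasint sc = (libxsmm_blasint)(sizeof(float) * m * n);
  const float alpha = 1.f, beta = 0.f;
  libxsmm_gemm_batch(LIBXSMM_GEMM_PRECISION_F32, LIBXSMM_GEMM_PRECISION_F32,
    "N", "N", m, n, k, &alpha, a, &m, b, &k, &beta, c, &m,
    0/*index_base*/, 0/*index_stride*/, &sa, &sb, &sc, batchsize);
}
#endif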


#if defined(LIBXSMM_BUILD) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__))

/* implementation provided for Fortran 77 compatibility */
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_dgemm)(const char*, const char*,
  const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const double*, const double*, const libxsmm_blasint*,
  const double*, const libxsmm_blasint*,
  const double*, double*, const libxsmm_blasint*);
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_dgemm)(const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const double* alpha, const double* a, const libxsmm_blasint* lda,
  const double* b, const libxsmm_blasint* ldb,
  const double* beta, double* c, const libxsmm_blasint* ldc)
{
  libxsmm_dgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

/* implementation provided for Fortran 77 compatibility */
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_sgemm)(const char*, const char*,
  const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const float*, const float*, const libxsmm_blasint*,
  const float*, const libxsmm_blasint*,
  const float*, float*, const libxsmm_blasint*);
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_sgemm)(const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const float* alpha, const float* a, const libxsmm_blasint* lda,
  const float* b, const libxsmm_blasint* ldb,
  const float* beta, float* c, const libxsmm_blasint* ldc)
{
  libxsmm_sgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}


/* implementation provided for Fortran 77 compatibility */
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_wigemm)(const char*, const char*,
  const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const int*, const short*, const libxsmm_blasint*,
  const short*, const libxsmm_blasint*,
  const int*, int*, const libxsmm_blasint*);
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_wigemm)(const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const int* alpha, const short* a, const libxsmm_blasint* lda,
  const short* b, const libxsmm_blasint* ldb,
  const int* beta, int* c, const libxsmm_blasint* ldc)
{
  libxsmm_wigemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}


/* implementation provided for Fortran 77 compatibility */
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_bsgemm)(const char*, const char*,
  const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const float*, const libxsmm_bfloat16*, const libxsmm_blasint*,
  const libxsmm_bfloat16*, const libxsmm_blasint*,
  const float*, float*, const libxsmm_blasint*);
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_bsgemm)(const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const float* alpha, const libxsmm_bfloat16* a, const libxsmm_blasint* lda,
  const libxsmm_bfloat16* b, const libxsmm_blasint* ldb,
  const float* beta, float* c, const libxsmm_blasint* ldc)
{
  libxsmm_bsgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}


/* implementation provided for Fortran 77 compatibility */
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_blas_xgemm)(const libxsmm_gemm_precision*, const libxsmm_gemm_precision*,
  const char*, const char*, const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const float*, const float*, const libxsmm_blasint*,
  const float*, const libxsmm_blasint*,
  const float*, float*, const libxsmm_blasint*);
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_blas_xgemm)(const libxsmm_gemm_precision* iprec, const libxsmm_gemm_precision* oprec,
  const char* transa, const char* transb, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const float* alpha, const float* a, const libxsmm_blasint* lda,
  const float* b, const libxsmm_blasint* ldb,
  const float* beta, float* c, const libxsmm_blasint* ldc)
{
  LIBXSMM_ASSERT(NULL != iprec && NULL != oprec);
  libxsmm_blas_xgemm(*iprec, *oprec, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}


/* implementation provided for Fortran 77 compatibility */
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_blas_dgemm)(const char*, const char*,
  const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const double*, const double*, const libxsmm_blasint*,
  const double*, const libxsmm_blasint*,
  const double*, double*, const libxsmm_blasint*);
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_blas_dgemm)(const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const double* alpha, const double* a, const libxsmm_blasint* lda,
  const double* b, const libxsmm_blasint* ldb,
  const double* beta, double* c, const libxsmm_blasint* ldc)
{
  libxsmm_blas_dgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}


/* implementation provided for Fortran 77 compatibility */
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_blas_sgemm)(const char*, const char*,
  const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const float*, const float*, const libxsmm_blasint*,
  const float*, const libxsmm_blasint*,
  const float*, float*, const libxsmm_blasint*);
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_blas_sgemm)(const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const float* alpha, const float* a, const libxsmm_blasint* lda,
  const float* b, const libxsmm_blasint* ldb,
  const float* beta, float* c, const libxsmm_blasint* ldc)
{
  libxsmm_blas_sgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}


/* implementation provided for Fortran 77 compatibility */
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_mmbatch)(const libxsmm_gemm_precision*, const libxsmm_gemm_precision*,
  const char*, const char*, const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const void*, const void*, const libxsmm_blasint*, const void*, const libxsmm_blasint*,
  const void*, void*, const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const libxsmm_blasint[], const libxsmm_blasint[], const libxsmm_blasint[],
  const libxsmm_blasint*, const /*unsigned*/int*, const /*unsigned*/int*);
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_mmbatch)(const libxsmm_gemm_precision* iprec, const libxsmm_gemm_precision* oprec,
  const char* transa, const char* transb, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb,
  const void* beta, void* c, const libxsmm_blasint* ldc, const libxsmm_blasint* index_base, const libxsmm_blasint* index_stride,
  const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[],
  const libxsmm_blasint* batchsize, const /*unsigned*/int* tid, const /*unsigned*/int* nthreads)
{
  LIBXSMM_ASSERT(NULL != iprec && NULL != oprec && NULL != m && NULL != n && NULL != k);
  LIBXSMM_ASSERT(NULL != index_base && NULL != index_stride && NULL != batchsize);
  LIBXSMM_ASSERT(NULL != tid && NULL != nthreads);
  libxsmm_mmbatch(*iprec, *oprec, transa, transb, *m, *n, *k, alpha, a, lda, b, ldb, beta, c, ldc,
    *index_base, *index_stride, stride_a, stride_b, stride_c, *batchsize, *tid, *nthreads);
}


/* implementation provided for Fortran 77 compatibility */
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_gemm_batch)(const libxsmm_gemm_precision*, const libxsmm_gemm_precision*,
  const char*, const char*, const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const void*, const void*, const libxsmm_blasint*, const void*, const libxsmm_blasint*,
  const void*, void*, const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const libxsmm_blasint[], const libxsmm_blasint[], const libxsmm_blasint[],
  const libxsmm_blasint*);
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_gemm_batch)(const libxsmm_gemm_precision* iprec, const libxsmm_gemm_precision* oprec,
  const char* transa, const char* transb, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb,
  const void* beta, void* c, const libxsmm_blasint* ldc, const libxsmm_blasint* index_base, const libxsmm_blasint* index_stride,
  const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[],
  const libxsmm_blasint* batchsize)
{
  LIBXSMM_ASSERT(NULL != iprec && NULL != oprec && NULL != m && NULL != n && NULL != k);
  LIBXSMM_ASSERT(NULL != index_base && NULL != index_stride && NULL != batchsize);
  libxsmm_gemm_batch(*iprec, *oprec, transa, transb, *m, *n, *k, alpha, a, lda, b, ldb, beta, c, ldc,
    *index_base, *index_stride, stride_a, stride_b, stride_c, *batchsize);
}

#endif /*defined(LIBXSMM_BUILD) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__))*/
