1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved. *
3 * This file is part of the LIBXSMM library. *
4 * *
5 * For information on the license, see the LICENSE file. *
6 * Further information: https://github.com/hfp/libxsmm/ *
7 * SPDX-License-Identifier: BSD-3-Clause *
8 ******************************************************************************/
9 /* Hans Pabst (Intel Corp.)
10 ******************************************************************************/
11 #include "libxsmm_gemm.h"
12 #include "libxsmm_xcopy.h"
13 #include "libxsmm_hash.h"
14 #include <libxsmm_mhd.h>
15
16 #if defined(LIBXSMM_OFFLOAD_TARGET)
17 # pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
18 #endif
19 #if !defined(LIBXSMM_NO_LIBM)
20 # include <math.h>
21 #endif
22 #if defined(LIBXSMM_OFFLOAD_TARGET)
23 # pragma offload_attribute(pop)
24 #endif
25
26 #if !defined(LIBXSMM_GEMM_NOJIT_TRANS) && \
27 /* TODO: fully support calling convention */ \
28 (defined(_WIN32) || defined(__CYGWIN__))
29 # define LIBXSMM_GEMM_NOJIT_TRANS
30 #endif
31 #if !defined(LIBXSMM_GEMM_KPARALLEL) && 0
32 # define LIBXSMM_GEMM_KPARALLEL
33 #endif
34 #if !defined(LIBXSMM_GEMM_BATCHSIZE)
35 # define LIBXSMM_GEMM_BATCHSIZE 1024
36 #endif
37 #if !defined(LIBXSMM_GEMM_TASKGRAIN)
38 # define LIBXSMM_GEMM_TASKGRAIN 128
39 #endif
40 #if !defined(LIBXSMM_GEMM_BATCHREDUCE) && !defined(_WIN32) && !defined(__CYGWIN__) /* not supported */
41 # define LIBXSMM_GEMM_BATCHREDUCE
42 #endif
43 #if !defined(LIBXSMM_GEMM_BATCHSCALE) && (defined(LIBXSMM_GEMM_BATCHREDUCE) || defined(LIBXSMM_WRAP))
44 #define LIBXSMM_GEMM_BATCHSCALE ((unsigned int)LIBXSMM_ROUND(sizeof(libxsmm_mmbatch_item) * (LIBXSMM_GEMM_MMBATCH_SCALE)))
45 #endif
46 #if defined(LIBXSMM_BUILD)
47 # define LIBXSMM_GEMM_WEAK LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK
48 #else
49 # define LIBXSMM_GEMM_WEAK LIBXSMM_API
50 #endif
51
52 #if (0 != LIBXSMM_SYNC) /** Locks for the batch interface (duplicated C indexes). */
53 # define LIBXSMM_GEMM_LOCKIDX(IDX, NPOT) LIBXSMM_MOD2(LIBXSMM_CRC32U(LIBXSMM_BLASINT_NBITS)(2507/*seed*/, &(IDX)), NPOT)
54 # define LIBXSMM_GEMM_LOCKPTR(PTR, NPOT) LIBXSMM_MOD2(LIBXSMM_CRC32U(LIBXSMM_BITS)(1975/*seed*/, &(PTR)), NPOT)
55 # if !defined(LIBXSMM_GEMM_MAXNLOCKS)
56 # define LIBXSMM_GEMM_MAXNLOCKS 1024
57 # endif
58 # if !defined(LIBXSMM_GEMM_LOCKFWD)
59 # define LIBXSMM_GEMM_LOCKFWD
60 # endif
61 # if LIBXSMM_LOCK_TYPE_ISPOD(LIBXSMM_GEMM_LOCK)
62 LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE internal_gemm_locktype {
63 char pad[LIBXSMM_CACHELINE];
64 LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK) state;
65 } internal_gemm_locktype;
66 # else
67 LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE internal_gemm_locktype {
68 LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK) state;
69 } internal_gemm_locktype;
70 # endif
71 LIBXSMM_APIVAR_DEFINE(internal_gemm_locktype internal_gemm_lock[LIBXSMM_GEMM_MAXNLOCKS]);
72 LIBXSMM_APIVAR_DEFINE(unsigned int internal_gemm_nlocks); /* populated number of locks */
73 #endif
74
75 /* definition of corresponding variables */
76 LIBXSMM_APIVAR_PUBLIC_DEF(/*volatile*/libxsmm_dgemm_batch_function libxsmm_original_dgemm_batch_function);
77 LIBXSMM_APIVAR_PUBLIC_DEF(/*volatile*/libxsmm_sgemm_batch_function libxsmm_original_sgemm_batch_function);
78 LIBXSMM_APIVAR_PUBLIC_DEF(/*volatile*/libxsmm_dgemm_function libxsmm_original_dgemm_function);
79 LIBXSMM_APIVAR_PUBLIC_DEF(/*volatile*/libxsmm_sgemm_function libxsmm_original_sgemm_function);
80 LIBXSMM_APIVAR_PUBLIC_DEF(/*volatile*/libxsmm_dgemv_function libxsmm_original_dgemv_function);
81 LIBXSMM_APIVAR_PUBLIC_DEF(/*volatile*/libxsmm_sgemv_function libxsmm_original_sgemv_function);
82 /* definition of corresponding variables */
83 LIBXSMM_APIVAR_PUBLIC_DEF(libxsmm_gemm_descriptor libxsmm_mmbatch_desc);
84 LIBXSMM_APIVAR_PUBLIC_DEF(void* libxsmm_mmbatch_array);
85 LIBXSMM_APIVAR_PUBLIC_DEF(LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK) libxsmm_mmbatch_lock);
86 LIBXSMM_APIVAR_PUBLIC_DEF(unsigned int libxsmm_mmbatch_size);
87 LIBXSMM_APIVAR_PUBLIC_DEF(unsigned int libxsmm_gemm_npargroups);
88 LIBXSMM_APIVAR_PUBLIC_DEF(unsigned int libxsmm_gemm_taskgrain);
89 LIBXSMM_APIVAR_PUBLIC_DEF(int libxsmm_gemm_tasks);
90 LIBXSMM_APIVAR_PUBLIC_DEF(int libxsmm_gemm_wrap);
91
92 LIBXSMM_APIVAR_PRIVATE_DEF(libxsmm_gemm_prefetch_type libxsmm_gemm_auto_prefetch_default);
93 /** Determines the prefetch strategy, which is used in case of LIBXSMM_PREFETCH_AUTO. */
94 LIBXSMM_APIVAR_PRIVATE_DEF(libxsmm_gemm_prefetch_type libxsmm_gemm_auto_prefetch);
95
96 /** Prefetch strategy for tiled GEMM. */
97 LIBXSMM_APIVAR_DEFINE(libxsmm_gemm_prefetch_type internal_gemm_tiled_prefetch);
98 /** Vector width used for GEMM. */
99 LIBXSMM_APIVAR_DEFINE(unsigned int internal_gemm_vwidth);
100 /** Limit the M-extent of the tile. */
101 LIBXSMM_APIVAR_DEFINE(unsigned int internal_gemm_mlimit);
/** Stretch factor of the tile's N-extent relative to its M-extent. */
103 LIBXSMM_APIVAR_DEFINE(float internal_gemm_nstretch);
/** Stretch factor of the tile's K-extent relative to its M-extent. */
105 LIBXSMM_APIVAR_DEFINE(float internal_gemm_kstretch);
106 /** Determines if batch-reduce is enabled */
107 LIBXSMM_APIVAR_DEFINE(int internal_gemm_batchreduce);
108
109
110 #if defined(LIBXSMM_BUILD)
LIBXSMM_FSYMBOL(__real_dgemm_batch)111 LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void LIBXSMM_FSYMBOL(__real_dgemm_batch)(
112 const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
113 const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[], const double* b_array[], const libxsmm_blasint ldb_array[],
114 const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[])
115 {
116 #if (0 != LIBXSMM_BLAS)
117 # if defined(LIBXSMM_WRAP) && (0 > LIBXSMM_WRAP)
118 if (0 > libxsmm_gemm_wrap) {
119 LIBXSMM_FSYMBOL(dgemm_batch)(transa_array, transb_array, m_array, n_array, k_array,
120 alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
121 group_count, group_size);
122 }
123 else
124 # endif
125 {
126 const libxsmm_blasint ptrsize = sizeof(void*);
127 libxsmm_blasint i, j = 0;
128 LIBXSMM_ASSERT(NULL != transa_array && NULL != transb_array && NULL != group_count && NULL != group_size);
129 LIBXSMM_ASSERT(NULL != m_array && NULL != n_array && NULL != k_array && NULL != lda_array && NULL != ldb_array && NULL != ldc_array);
130 LIBXSMM_ASSERT(NULL != a_array && NULL != b_array && NULL != c_array && NULL != alpha_array && NULL != beta_array);
131 for (i = 0; i < *group_count; ++i) {
132 const libxsmm_blasint size = group_size[i];
133 libxsmm_dmmbatch_blas(transa_array + i, transb_array + i, m_array[i], n_array[i], k_array[i], alpha_array + i,
134 a_array + j, lda_array + i, b_array + j, ldb_array + i, beta_array + i,
135 c_array + j, ldc_array + i, 0/*index_base*/, 0/*index_stride*/, &ptrsize, &ptrsize, &ptrsize, size);
136 j += size;
137 }
138 }
139 #else
140 libxsmm_blas_error("dgemm_batch")(transa_array, transb_array, m_array, n_array, k_array,
141 alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
142 group_count, group_size);
143 #endif
144 }
145
146
LIBXSMM_FSYMBOL(__real_sgemm_batch)147 LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void LIBXSMM_FSYMBOL(__real_sgemm_batch)(
148 const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
149 const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[], const float* b_array[], const libxsmm_blasint ldb_array[],
150 const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[])
151 {
152 #if (0 != LIBXSMM_BLAS)
153 # if defined(LIBXSMM_WRAP) && (0 > LIBXSMM_WRAP)
154 if (0 > libxsmm_gemm_wrap) {
155 LIBXSMM_FSYMBOL(sgemm_batch)(transa_array, transb_array, m_array, n_array, k_array,
156 alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
157 group_count, group_size);
158 }
159 else
160 # endif
161 {
162 const libxsmm_blasint ptrsize = sizeof(void*);
163 libxsmm_blasint i, j = 0;
164 LIBXSMM_ASSERT(NULL != transa_array && NULL != transb_array && NULL != group_count && NULL != group_size);
165 LIBXSMM_ASSERT(NULL != m_array && NULL != n_array && NULL != k_array && NULL != lda_array && NULL != ldb_array && NULL != ldc_array);
166 LIBXSMM_ASSERT(NULL != a_array && NULL != b_array && NULL != c_array && NULL != alpha_array && NULL != beta_array);
167 for (i = 0; i < *group_count; ++i) {
168 const libxsmm_blasint size = group_size[i];
169 libxsmm_smmbatch_blas(transa_array + i, transb_array + i, m_array[i], n_array[i], k_array[i], alpha_array + i,
170 a_array + i, lda_array + i, b_array + i, ldb_array + i, beta_array + i,
171 c_array + i, ldc_array + i, 0/*index_base*/, 0/*index_stride*/, &ptrsize, &ptrsize, &ptrsize, size);
172 j += size;
173 }
174 }
175 #else
176 libxsmm_blas_error("sgemm_batch")(transa_array, transb_array, m_array, n_array, k_array,
177 alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
178 group_count, group_size);
179 #endif
180 }
181
182
LIBXSMM_FSYMBOL(__real_dgemm)183 LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void LIBXSMM_FSYMBOL(__real_dgemm)(const char* transa, const char* transb,
184 const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
185 const double* alpha, const double* a, const libxsmm_blasint* lda,
186 const double* b, const libxsmm_blasint* ldb,
187 const double* beta, double* c, const libxsmm_blasint* ldc)
188 {
189 #if (0 != LIBXSMM_BLAS)
190 LIBXSMM_FSYMBOL(dgemm)((LIBXSMM_BLAS_CONST char*)transa, (LIBXSMM_BLAS_CONST char*)transb,
191 (LIBXSMM_BLAS_CONST libxsmm_blasint*)m, (LIBXSMM_BLAS_CONST libxsmm_blasint*)n, (LIBXSMM_BLAS_CONST libxsmm_blasint*)k,
192 (LIBXSMM_BLAS_CONST double*)alpha, (LIBXSMM_BLAS_CONST double*)a, (LIBXSMM_BLAS_CONST libxsmm_blasint*)lda,
193 (LIBXSMM_BLAS_CONST double*)b, (LIBXSMM_BLAS_CONST libxsmm_blasint*)ldb,
194 (LIBXSMM_BLAS_CONST double*) beta, c, (LIBXSMM_BLAS_CONST libxsmm_blasint*)ldc);
195 #else
196 libxsmm_blas_error("dgemm")(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
197 #endif
198 }
199
200
LIBXSMM_FSYMBOL(__real_sgemm)201 LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void LIBXSMM_FSYMBOL(__real_sgemm)(const char* transa, const char* transb,
202 const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
203 const float* alpha, const float* a, const libxsmm_blasint* lda,
204 const float* b, const libxsmm_blasint* ldb,
205 const float* beta, float* c, const libxsmm_blasint* ldc)
206 {
207 #if (0 != LIBXSMM_BLAS)
208 LIBXSMM_FSYMBOL(sgemm)((LIBXSMM_BLAS_CONST char*)transa, (LIBXSMM_BLAS_CONST char*)transb,
209 (LIBXSMM_BLAS_CONST libxsmm_blasint*)m, (LIBXSMM_BLAS_CONST libxsmm_blasint*)n, (LIBXSMM_BLAS_CONST libxsmm_blasint*)k,
210 (LIBXSMM_BLAS_CONST float*)alpha, (LIBXSMM_BLAS_CONST float*)a, (LIBXSMM_BLAS_CONST libxsmm_blasint*)lda,
211 (LIBXSMM_BLAS_CONST float*)b, (LIBXSMM_BLAS_CONST libxsmm_blasint*)ldb,
212 (LIBXSMM_BLAS_CONST float*) beta, c, (LIBXSMM_BLAS_CONST libxsmm_blasint*)ldc);
213 #else
214 libxsmm_blas_error("sgemm")(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
215 #endif
216 }
217
218
LIBXSMM_FSYMBOL(__real_dgemv)219 LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void LIBXSMM_FSYMBOL(__real_dgemv)(const char* trans, const libxsmm_blasint* m, const libxsmm_blasint* n,
220 const double* alpha, const double* a, const libxsmm_blasint* lda, const double* x, const libxsmm_blasint* incx,
221 const double* beta, double* y, const libxsmm_blasint* incy)
222 {
223 #if (0 != LIBXSMM_BLAS)
224 LIBXSMM_FSYMBOL(dgemv)((LIBXSMM_BLAS_CONST char*)trans, (LIBXSMM_BLAS_CONST libxsmm_blasint*)m, (LIBXSMM_BLAS_CONST libxsmm_blasint*)n,
225 (LIBXSMM_BLAS_CONST double*)alpha, (LIBXSMM_BLAS_CONST double*)a, (LIBXSMM_BLAS_CONST libxsmm_blasint*)lda,
226 (LIBXSMM_BLAS_CONST double*)x, (LIBXSMM_BLAS_CONST libxsmm_blasint*)incx,
227 (LIBXSMM_BLAS_CONST double*) beta, y, (LIBXSMM_BLAS_CONST libxsmm_blasint*)incy);
228 #else
229 libxsmm_blas_error("dgemv")(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
230 #endif
231 }
232
233
LIBXSMM_FSYMBOL(__real_sgemv)234 LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void LIBXSMM_FSYMBOL(__real_sgemv)(const char* trans, const libxsmm_blasint* m, const libxsmm_blasint* n,
235 const float* alpha, const float* a, const libxsmm_blasint* lda, const float* x, const libxsmm_blasint* incx,
236 const float* beta, float* y, const libxsmm_blasint* incy)
237 {
238 #if (0 != LIBXSMM_BLAS)
239 LIBXSMM_FSYMBOL(sgemv)((LIBXSMM_BLAS_CONST char*)trans, (LIBXSMM_BLAS_CONST libxsmm_blasint*)m, (LIBXSMM_BLAS_CONST libxsmm_blasint*)n,
240 (LIBXSMM_BLAS_CONST float*)alpha, (LIBXSMM_BLAS_CONST float*)a, (LIBXSMM_BLAS_CONST libxsmm_blasint*)lda,
241 (LIBXSMM_BLAS_CONST float*)x, (LIBXSMM_BLAS_CONST libxsmm_blasint*)incx,
242 (LIBXSMM_BLAS_CONST float*) beta, y, (LIBXSMM_BLAS_CONST libxsmm_blasint*)incy);
243 #else
244 libxsmm_blas_error("sgemv")(trans, m, n, alpha, a, lda, x, incx, beta, y, incy);
245 #endif
246 }
247
248
__real_dgemm_batch(const char transa_array[],const char transb_array[],const libxsmm_blasint m_array[],const libxsmm_blasint n_array[],const libxsmm_blasint k_array[],const double alpha_array[],const double * a_array[],const libxsmm_blasint lda_array[],const double * b_array[],const libxsmm_blasint ldb_array[],const double beta_array[],double * c_array[],const libxsmm_blasint ldc_array[],const libxsmm_blasint * group_count,const libxsmm_blasint group_size[])249 LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void __real_dgemm_batch(
250 const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
251 const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[], const double* b_array[], const libxsmm_blasint ldb_array[],
252 const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[])
253 {
254 LIBXSMM_FSYMBOL(__real_dgemm_batch)(transa_array, transb_array, m_array, n_array, k_array,
255 alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
256 group_count, group_size);
257 }
258
259
__real_sgemm_batch(const char transa_array[],const char transb_array[],const libxsmm_blasint m_array[],const libxsmm_blasint n_array[],const libxsmm_blasint k_array[],const float alpha_array[],const float * a_array[],const libxsmm_blasint lda_array[],const float * b_array[],const libxsmm_blasint ldb_array[],const float beta_array[],float * c_array[],const libxsmm_blasint ldc_array[],const libxsmm_blasint * group_count,const libxsmm_blasint group_size[])260 LIBXSMM_API LIBXSMM_ATTRIBUTE_WEAK void __real_sgemm_batch(
261 const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
262 const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[], const float* b_array[], const libxsmm_blasint ldb_array[],
263 const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[])
264 {
265 LIBXSMM_FSYMBOL(__real_sgemm_batch)(transa_array, transb_array, m_array, n_array, k_array,
266 alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array,
267 group_count, group_size);
268 }
269 #endif /*defined(LIBXSMM_BUILD)*/
270
271
/**
 * Resolves (via the LIBXSMM_BLAS_WRAPPER macro, defined elsewhere) and returns
 * the function-pointer of the original DGEMM_BATCH; the pointer is cached in
 * libxsmm_original_dgemm_batch_function.
 */
LIBXSMM_GEMM_WEAK libxsmm_dgemm_batch_function libxsmm_original_dgemm_batch(void)
{
#if (0 != LIBXSMM_BLAS) && defined(LIBXSMM_WRAP) && (0 > LIBXSMM_WRAP)
  LIBXSMM_BLAS_WRAPPER(1, double, gemm_batch, libxsmm_original_dgemm_batch_function, NULL/*unknown*/);
  /*LIBXSMM_ASSERT(NULL != libxsmm_original_dgemm_batch_function);*/
#else
  LIBXSMM_BLAS_WRAPPER(0, double, gemm_batch, libxsmm_original_dgemm_batch_function, NULL/*unknown*/);
#endif
  return libxsmm_original_dgemm_batch_function; /* NOTE(review): may be NULL if unresolved — callers appear to tolerate this (no assert above) */
}
282
283
/**
 * Resolves (via the LIBXSMM_BLAS_WRAPPER macro, defined elsewhere) and returns
 * the function-pointer of the original SGEMM_BATCH; the pointer is cached in
 * libxsmm_original_sgemm_batch_function.
 */
LIBXSMM_GEMM_WEAK libxsmm_sgemm_batch_function libxsmm_original_sgemm_batch(void)
{
#if (0 != LIBXSMM_BLAS) && defined(LIBXSMM_WRAP) && (0 > LIBXSMM_WRAP)
  LIBXSMM_BLAS_WRAPPER(1, float, gemm_batch, libxsmm_original_sgemm_batch_function, NULL/*unknown*/);
  /*LIBXSMM_ASSERT(NULL != libxsmm_original_sgemm_batch_function);*/
#else
  LIBXSMM_BLAS_WRAPPER(0, float, gemm_batch, libxsmm_original_sgemm_batch_function, NULL/*unknown*/);
#endif
  return libxsmm_original_sgemm_batch_function; /* NOTE(review): may be NULL if unresolved — callers appear to tolerate this (no assert above) */
}
294
295
/**
 * Resolves (via the LIBXSMM_BLAS_WRAPPER macro, defined elsewhere) and returns
 * the function-pointer of the original DGEMM; the pointer is cached in
 * libxsmm_original_dgemm_function and asserted non-NULL when BLAS is enabled.
 */
LIBXSMM_GEMM_WEAK libxsmm_dgemm_function libxsmm_original_dgemm(void)
{
#if (0 != LIBXSMM_BLAS)
  LIBXSMM_BLAS_WRAPPER(1, double, gemm, libxsmm_original_dgemm_function, NULL/*unknown*/);
  LIBXSMM_ASSERT(NULL != libxsmm_original_dgemm_function);
#else
  LIBXSMM_BLAS_WRAPPER(0, double, gemm, libxsmm_original_dgemm_function, NULL/*unknown*/);
#endif
  return libxsmm_original_dgemm_function;
}
306
307
/**
 * Resolves (via the LIBXSMM_BLAS_WRAPPER macro, defined elsewhere) and returns
 * the function-pointer of the original SGEMM; the pointer is cached in
 * libxsmm_original_sgemm_function and asserted non-NULL when BLAS is enabled.
 */
LIBXSMM_GEMM_WEAK libxsmm_sgemm_function libxsmm_original_sgemm(void)
{
#if (0 != LIBXSMM_BLAS)
  LIBXSMM_BLAS_WRAPPER(1, float, gemm, libxsmm_original_sgemm_function, NULL/*unknown*/);
  LIBXSMM_ASSERT(NULL != libxsmm_original_sgemm_function);
#else
  LIBXSMM_BLAS_WRAPPER(0, float, gemm, libxsmm_original_sgemm_function, NULL/*unknown*/);
#endif
  return libxsmm_original_sgemm_function;
}
318
319
/**
 * Resolves (via the LIBXSMM_BLAS_WRAPPER macro, defined elsewhere) and returns
 * the function-pointer of the original DGEMV; the pointer is cached in
 * libxsmm_original_dgemv_function and asserted non-NULL when BLAS is enabled.
 */
LIBXSMM_GEMM_WEAK libxsmm_dgemv_function libxsmm_original_dgemv(void)
{
#if (0 != LIBXSMM_BLAS)
  LIBXSMM_BLAS_WRAPPER(1, double, gemv, libxsmm_original_dgemv_function, NULL/*unknown*/);
  LIBXSMM_ASSERT(NULL != libxsmm_original_dgemv_function);
#else
  LIBXSMM_BLAS_WRAPPER(0, double, gemv, libxsmm_original_dgemv_function, NULL/*unknown*/);
#endif
  return libxsmm_original_dgemv_function;
}
330
331
/**
 * Resolves (via the LIBXSMM_BLAS_WRAPPER macro, defined elsewhere) and returns
 * the function-pointer of the original SGEMV; the pointer is cached in
 * libxsmm_original_sgemv_function and asserted non-NULL when BLAS is enabled.
 */
LIBXSMM_GEMM_WEAK libxsmm_sgemv_function libxsmm_original_sgemv(void)
{
#if (0 != LIBXSMM_BLAS)
  LIBXSMM_BLAS_WRAPPER(1, float, gemv, libxsmm_original_sgemv_function, NULL/*unknown*/);
  LIBXSMM_ASSERT(NULL != libxsmm_original_sgemv_function);
#else
  LIBXSMM_BLAS_WRAPPER(0, float, gemv, libxsmm_original_sgemv_function, NULL/*unknown*/);
#endif
  return libxsmm_original_sgemv_function;
}
342
343
/**
 * Handles a missing/unresolved BLAS symbol: reports it via the LIBXSMM_BLAS_ERROR
 * macro (defined elsewhere) and returns a sink function so the caller can still
 * "invoke" something with the original arguments without any effect.
 */
LIBXSMM_API libxsmm_sink_function libxsmm_blas_error(const char* symbol)
{
  static int error_once = 0; /* shared flag so the error is reported only once */
  LIBXSMM_BLAS_ERROR(symbol, &error_once);
  return libxsmm_sink;
}
350
351
/**
 * One-time initialization of the GEMM subsystem for the given architecture ID.
 * Reads environment variables to configure GEMM wrapping, tiled-GEMM prefetch,
 * batch-interface locks, batch-reduce/batch-wrap buffers, tile-shape stretch
 * factors, and task settings; finally pre-resolves the original BLAS
 * function-pointers.
 */
LIBXSMM_API_INTERN void libxsmm_gemm_init(int archid)
{
  const char* env_w = getenv("LIBXSMM_GEMM_WRAP");
  LIBXSMM_LOCK_ATTR_TYPE(LIBXSMM_GEMM_LOCK) attr;
  LIBXSMM_LOCK_ATTR_INIT(LIBXSMM_GEMM_LOCK, &attr);
#if defined(LIBXSMM_WRAP) /* determines if wrap is considered */
  { /* intercepted GEMMs (1: sequential and non-tiled, 2: parallelized and tiled) */
# if defined(__STATIC) /* with static library the user controls interceptor already */
    libxsmm_gemm_wrap = ((NULL == env_w || 0 == *env_w) /* LIBXSMM_WRAP=0: no promotion */
      ? (0 < (LIBXSMM_WRAP) ? (LIBXSMM_WRAP + 2) : (LIBXSMM_WRAP - 2)) : atoi(env_w));
# else
    libxsmm_gemm_wrap = ((NULL == env_w || 0 == *env_w) ? (LIBXSMM_WRAP) : atoi(env_w));
# endif
  }
#endif
  { /* setup prefetch strategy for tiled GEMMs (LIBXSMM_TGEMM_PREFETCH) */
    const char *const env_p = getenv("LIBXSMM_TGEMM_PREFETCH");
    const libxsmm_gemm_prefetch_type tiled_prefetch_default = LIBXSMM_GEMM_PREFETCH_AL2_AHEAD;
    const int uid = ((NULL == env_p || 0 == *env_p) ? LIBXSMM_PREFETCH_AUTO/*default*/ : atoi(env_p));
    internal_gemm_tiled_prefetch = (0 <= uid ? libxsmm_gemm_uid2prefetch(uid) : tiled_prefetch_default);
  }
#if (0 != LIBXSMM_SYNC)
  { /* initialize locks for the batch interface (LIBXSMM_GEMM_NLOCKS) */
    const char *const env_locks = getenv("LIBXSMM_GEMM_NLOCKS");
    const int nlocks = ((NULL == env_locks || 0 == *env_locks) ? -1/*default*/ : atoi(env_locks));
    unsigned int i;
    /* clamp to LIBXSMM_GEMM_MAXNLOCKS and round up to a power of two (LIBXSMM_GEMM_LOCKIDX relies on this) */
    internal_gemm_nlocks = LIBXSMM_UP2POT(0 > nlocks ? (LIBXSMM_GEMM_MAXNLOCKS) : LIBXSMM_MIN(nlocks, LIBXSMM_GEMM_MAXNLOCKS));
    for (i = 0; i < internal_gemm_nlocks; ++i) LIBXSMM_LOCK_INIT(LIBXSMM_GEMM_LOCK, &internal_gemm_lock[i].state, &attr);
  }
#endif
#if defined(LIBXSMM_GEMM_BATCHREDUCE) || defined(LIBXSMM_WRAP)
  { /* determines if batch-reduce kernel or batch-wrap is considered */
    const char *const env_r = getenv("LIBXSMM_GEMM_BATCHREDUCE");
    internal_gemm_batchreduce = (NULL == env_r || 0 == *env_r) ? 0 : atoi(env_r);
    /* without explicit LIBXSMM_GEMM_WRAP and at sufficient verbosity, only collect statistics */
    if ((NULL == env_w || 0 == *env_w) && ((LIBXSMM_GEMM_MMBATCH_VERBOSITY <= libxsmm_verbosity && INT_MAX != libxsmm_verbosity) || 0 > libxsmm_verbosity)) {
      libxsmm_mmbatch_desc.flags = LIBXSMM_MMBATCH_FLAG_STATISTIC; /* enable auto-batch statistic */
      internal_gemm_batchreduce = 0;
    }
    if (0 != internal_gemm_batchreduce || 0 != libxsmm_gemm_wrap) {
      /* allocate the batch-recording buffer (LIBXSMM_GEMM_BATCHSIZE requests the size) */
      const char *const env_b = getenv("LIBXSMM_GEMM_BATCHSIZE");
      const int env_bi = (NULL == env_b || 0 == *env_b) ? -1/*auto*/ : atoi(env_b);
      const unsigned int env_bu = (unsigned int)(0 >= env_bi ? (LIBXSMM_GEMM_BATCHSIZE) : env_bi);
      const unsigned int batchscale = LIBXSMM_ABS(internal_gemm_batchreduce) * 2048/*arbitrary*/ * 2/*A and B-matrices*/ * sizeof(void*);
      const unsigned int minsize = LIBXSMM_UPDIV(batchscale * env_bu, LIBXSMM_GEMM_BATCHSCALE);
      const unsigned int batchsize = LIBXSMM_MAX(env_bu, minsize);
      const void *const extra = NULL;
      LIBXSMM_ASSERT(1 < (LIBXSMM_GEMM_MMBATCH_SCALE) && NULL == libxsmm_mmbatch_array);
      if (EXIT_SUCCESS == libxsmm_xmalloc(&libxsmm_mmbatch_array, (size_t)batchsize * (LIBXSMM_GEMM_BATCHSCALE), 0/*auto-alignment*/,
        LIBXSMM_MALLOC_FLAG_PRIVATE /*| LIBXSMM_MALLOC_FLAG_SCRATCH*/, &extra, sizeof(extra)))
      {
        LIBXSMM_LOCK_INIT(LIBXSMM_GEMM_LOCK, &libxsmm_mmbatch_lock, &attr);
        LIBXSMM_ASSERT(NULL != libxsmm_mmbatch_array);
        libxsmm_mmbatch_size = batchsize;
      }
    }
  }
#else
  LIBXSMM_UNUSED(env_w);
#endif
  { /* determines the number of groups processed in parallel (LIBXSMM_GEMM_NPARGROUPS) */
    const char *const env_s = getenv("LIBXSMM_GEMM_NPARGROUPS");
    libxsmm_gemm_npargroups = ((NULL == env_s || 0 == *env_s || 0 >= atoi(env_s))
      ? (LIBXSMM_GEMM_NPARGROUPS) : atoi(env_s));
  }
  /* per-architecture defaults: vector width (bytes), tile M-limit, and N/K stretch factors */
  if (LIBXSMM_X86_AVX512_CORE <= archid) {
    internal_gemm_vwidth = 64;
    internal_gemm_mlimit = 48;
    internal_gemm_nstretch = 3.0f;
    internal_gemm_kstretch = 2.0f;
  }
  else if (LIBXSMM_X86_AVX512_MIC <= archid) {
    internal_gemm_vwidth = 64;
    internal_gemm_mlimit = 64;
    internal_gemm_nstretch = 1.0f;
    internal_gemm_kstretch = 1.0f;
  }
  else if (LIBXSMM_X86_AVX2 <= archid) {
    internal_gemm_vwidth = 32;
    internal_gemm_mlimit = 48;
    internal_gemm_nstretch = 3.0f;
    internal_gemm_kstretch = 2.0f;
  }
  else if (LIBXSMM_X86_AVX <= archid) {
    internal_gemm_vwidth = 32;
    internal_gemm_mlimit = 48;
    internal_gemm_nstretch = 5.0f;
    internal_gemm_kstretch = 1.0f;
  }
  else {
    internal_gemm_vwidth = 16;
    internal_gemm_mlimit = 48;
    internal_gemm_nstretch = 7.0f;
    internal_gemm_kstretch = 5.0f;
  }
  { /* setup tile sizes according to environment (LIBXSMM_TGEMM_M, LIBXSMM_TGEMM_N, LIBXSMM_TGEMM_K) */
    const char *const env_m = getenv("LIBXSMM_TGEMM_M"), *const env_n = getenv("LIBXSMM_TGEMM_N"), *const env_k = getenv("LIBXSMM_TGEMM_K");
    const int m = ((NULL == env_m || 0 == *env_m) ? 0 : atoi(env_m));
    const int n = ((NULL == env_n || 0 == *env_n) ? 0 : atoi(env_n));
    const int k = ((NULL == env_k || 0 == *env_k) ? 0 : atoi(env_k));
    if (0 < m) { /* stretch factors are defined relative to the M-extent */
      if (0 < n) internal_gemm_nstretch = ((float)n) / m;
      if (0 < k) internal_gemm_kstretch = ((float)k) / m;
    }
  }
  { /* setup tile sizes according to environment (LIBXSMM_TGEMM_NS, LIBXSMM_TGEMM_KS) */
    const char *const env_ns = getenv("LIBXSMM_TGEMM_NS"), *const env_ks = getenv("LIBXSMM_TGEMM_KS");
    const double ns = ((NULL == env_ns || 0 == *env_ns) ? 0 : atof(env_ns));
    const double ks = ((NULL == env_ks || 0 == *env_ks) ? 0 : atof(env_ks));
    if (0 < ns) internal_gemm_nstretch = (float)LIBXSMM_MIN(24, ns); /* cap at 24 */
    if (0 < ks) internal_gemm_kstretch = (float)LIBXSMM_MIN(24, ks); /* cap at 24 */
  }
  { /* determines if OpenMP tasks are used (when available) */
    const char *const env_t = getenv("LIBXSMM_GEMM_TASKS");
    const int gemm_tasks = ((NULL == env_t || 0 == *env_t) ? 0/*disabled*/ : atoi(env_t));
    libxsmm_gemm_tasks = (0 <= gemm_tasks ? LIBXSMM_ABS(gemm_tasks) : 1/*enabled*/);
  }
  { /* determines grain-size of tasks (when available) */
    const char *const env_g = getenv("LIBXSMM_GEMM_TASKGRAIN");
    const int gemm_taskgrain = ((NULL == env_g || 0 == *env_g || 0 >= atoi(env_g))
      ? (LIBXSMM_GEMM_TASKGRAIN) : atoi(env_g));
    /* adjust grain-size or scale beyond the number of threads */
    libxsmm_gemm_taskgrain = LIBXSMM_MAX(0 < libxsmm_gemm_tasks ? (gemm_taskgrain / libxsmm_gemm_tasks) : gemm_taskgrain, 1);
  }
  LIBXSMM_LOCK_ATTR_DESTROY(LIBXSMM_GEMM_LOCK, &attr);
  /* determine BLAS function-pointers */
  libxsmm_original_dgemm_batch();
  libxsmm_original_sgemm_batch();
  libxsmm_original_dgemm();
  libxsmm_original_sgemm();
  libxsmm_original_dgemv();
  libxsmm_original_sgemv();
}
484
485
/**
 * Tears down the GEMM subsystem: destroys the batch-interface locks, flushes
 * any pending auto-batched multiplications (via the flush function stored in
 * the allocation's extra-info), and releases the batch-recording buffer.
 */
LIBXSMM_API_INTERN void libxsmm_gemm_finalize(void)
{
#if (0 != LIBXSMM_SYNC)
  unsigned int i; for (i = 0; i < internal_gemm_nlocks; ++i) LIBXSMM_LOCK_DESTROY(LIBXSMM_GEMM_LOCK, &internal_gemm_lock[i].state);
#endif
#if defined(LIBXSMM_GEMM_BATCHREDUCE) || defined(LIBXSMM_WRAP)
  if (NULL != libxsmm_mmbatch_array) {
    void *extra = NULL, *const mmbatch_array = libxsmm_mmbatch_array;
    /* the extra-info of the allocation holds a flush callback (if one was registered) */
    if (EXIT_SUCCESS == libxsmm_get_malloc_xinfo(mmbatch_array, NULL/*size*/, NULL/*flags*/, &extra) && NULL != extra) {
      const libxsmm_mmbatch_flush_function flush = *(libxsmm_mmbatch_flush_function*)extra;
      if (NULL != flush) flush();
    }
#if !defined(NDEBUG)
    libxsmm_mmbatch_array = NULL; /* debug builds: make use-after-finalize detectable */
#endif
    libxsmm_xfree(mmbatch_array, 0/*no check*/);
    LIBXSMM_LOCK_DESTROY(LIBXSMM_GEMM_LOCK, &libxsmm_mmbatch_lock);
  }
#endif
}
506
507
/**
 * Determines the effective prefetch strategy from an optional request:
 * a NULL argument selects the current auto-prefetch setting.
 */
LIBXSMM_API libxsmm_gemm_prefetch_type libxsmm_get_gemm_xprefetch(const int* prefetch)
{
  LIBXSMM_INIT /* load configuration */
  return libxsmm_get_gemm_prefetch(NULL == prefetch ? ((int)libxsmm_gemm_auto_prefetch) : *prefetch);
}
513
514
/**
 * Translates a prefetch request into a concrete strategy: a negative request
 * resolves to the auto-detected default, otherwise the value is adopted as-is.
 * On Windows-like targets prefetch is always disabled (calling convention).
 */
LIBXSMM_API libxsmm_gemm_prefetch_type libxsmm_get_gemm_prefetch(int prefetch)
{
  libxsmm_gemm_prefetch_type result;
#if !defined(_WIN32) && !defined(__CYGWIN__) && !defined(__MINGW32__)
  if (0 > prefetch) { /* negative request: fall back to the default strategy */
    LIBXSMM_INIT /* load configuration */
    result = libxsmm_gemm_auto_prefetch_default;
  }
  else {
    result = (libxsmm_gemm_prefetch_type)prefetch;
  }
#else /* TODO: full support for Windows calling convention */
  result = LIBXSMM_GEMM_PREFETCH_NONE;
  LIBXSMM_UNUSED(prefetch);
#endif
  return result;
}
532
533
libxsmm_gemm_prefetch2uid(libxsmm_gemm_prefetch_type prefetch)534 LIBXSMM_API_INTERN int libxsmm_gemm_prefetch2uid(libxsmm_gemm_prefetch_type prefetch)
535 {
536 switch (prefetch) {
537 case LIBXSMM_GEMM_PREFETCH_SIGONLY: return 2;
538 case LIBXSMM_GEMM_PREFETCH_BL2_VIA_C: return 3;
539 case LIBXSMM_GEMM_PREFETCH_AL2_AHEAD: return 4;
540 case LIBXSMM_GEMM_PREFETCH_AL2BL2_VIA_C_AHEAD: return 5;
541 case LIBXSMM_GEMM_PREFETCH_AL2: return 6;
542 case LIBXSMM_GEMM_PREFETCH_AL2BL2_VIA_C: return 7;
543 default: {
544 LIBXSMM_ASSERT(LIBXSMM_GEMM_PREFETCH_NONE == prefetch);
545 return 0;
546 }
547 }
548 }
549
550
libxsmm_gemm_uid2prefetch(int uid)551 LIBXSMM_API_INTERN libxsmm_gemm_prefetch_type libxsmm_gemm_uid2prefetch(int uid)
552 {
553 switch (uid) {
554 case 1: return LIBXSMM_GEMM_PREFETCH_NONE; /* nopf */
555 case 2: return LIBXSMM_GEMM_PREFETCH_SIGONLY; /* pfsigonly */
556 case 3: return LIBXSMM_GEMM_PREFETCH_BL2_VIA_C; /* BL2viaC */
557 case 4: return LIBXSMM_GEMM_PREFETCH_AL2_AHEAD; /* curAL2 */
558 case 5: return LIBXSMM_GEMM_PREFETCH_AL2BL2_VIA_C_AHEAD; /* curAL2_BL2viaC */
559 case 6: return LIBXSMM_GEMM_PREFETCH_AL2; /* AL2 */
560 case 7: return LIBXSMM_GEMM_PREFETCH_AL2BL2_VIA_C; /* AL2_BL2viaC */
561 default: {
562 if (0 != libxsmm_verbosity) { /* library code is expected to be mute */
563 static int error_once = 0;
564 if (1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) {
565 fprintf(stderr, "LIBXSMM WARNING: invalid prefetch strategy requested!\n");
566 }
567 }
568 return LIBXSMM_GEMM_PREFETCH_NONE;
569 }
570 }
571 }
572
573
/**
 * Prints information about a GEMM call with a single precision for both input
 * and output; convenience forwarder to libxsmm_gemm_print2.
 */
LIBXSMM_API void libxsmm_gemm_print(void* ostream,
  libxsmm_gemm_precision precision, const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const void* alpha, const void* a, const libxsmm_blasint* lda,
  const void* b, const libxsmm_blasint* ldb,
  const void* beta, void* c, const libxsmm_blasint* ldc)
{
  libxsmm_gemm_print2(ostream, precision, precision, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
583
584
LIBXSMM_API void libxsmm_gemm_print2(void* ostream,
  libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const void* alpha, const void* a, const libxsmm_blasint* lda,
  const void* b, const libxsmm_blasint* ldb,
  const void* beta, void* c, const libxsmm_blasint* ldc)
{
  /* Prints a GEMM call in human-readable form to the given stream, or - when
   * ostream is NULL - dumps the A, B, and C matrices into MHD image files.
   * Optional arguments follow BLAS-like defaulting: n and k fall back to m,
   * lda/ldb fall back according to the transpose characters, ldc falls back
   * to m, and alpha/beta fall back to LIBXSMM_ALPHA/LIBXSMM_BETA. */
  const libxsmm_blasint nn = *(n ? n : m), kk = *(k ? k : m);
  const char ctransa = (char)(NULL != transa ? (*transa) : (0 == (LIBXSMM_FLAGS & LIBXSMM_GEMM_FLAG_TRANS_A) ? 'n' : 't'));
  const char ctransb = (char)(NULL != transb ? (*transb) : (0 == (LIBXSMM_FLAGS & LIBXSMM_GEMM_FLAG_TRANS_B) ? 'n' : 't'));
  const libxsmm_blasint ilda = (NULL != lda ? *lda : (('n' == ctransa || 'N' == ctransa) ? *m : kk));
  const libxsmm_blasint ildb = (NULL != ldb ? *ldb : (('n' == ctransb || 'N' == ctransb) ? kk : nn));
  const libxsmm_blasint ildc = *(NULL != ldc ? ldc : m);
  libxsmm_mhd_elemtype mhd_elemtype = LIBXSMM_MHD_ELEMTYPE_UNKNOWN;
  /* string_a/string_b hold formatted alpha/beta first and are later reused as file names */
  char string_a[128], string_b[128], typeprefix = 0;

  /* iprec|oprec: equal precisions are required (asserted); mixed types fall to default */
  switch (iprec | oprec) {
    case LIBXSMM_GEMM_PRECISION_F64: {
      LIBXSMM_ASSERT(iprec == oprec);
      LIBXSMM_SNPRINTF(string_a, sizeof(string_a), "%g", NULL != alpha ? *((const double*)alpha) : LIBXSMM_ALPHA);
      LIBXSMM_SNPRINTF(string_b, sizeof(string_b), "%g", NULL != beta ? *((const double*)beta) : LIBXSMM_BETA);
      mhd_elemtype = LIBXSMM_MHD_ELEMTYPE_F64;
      typeprefix = 'd'; /* dgemm */
    } break;
    case LIBXSMM_GEMM_PRECISION_F32: {
      LIBXSMM_ASSERT(iprec == oprec);
      LIBXSMM_SNPRINTF(string_a, sizeof(string_a), "%g", NULL != alpha ? *((const float*)alpha) : LIBXSMM_ALPHA);
      LIBXSMM_SNPRINTF(string_b, sizeof(string_b), "%g", NULL != beta ? *((const float*)beta) : LIBXSMM_BETA);
      mhd_elemtype = LIBXSMM_MHD_ELEMTYPE_F32;
      typeprefix = 's'; /* sgemm */
    } break;
    default: if (0 != libxsmm_verbosity) { /* library code is expected to be mute */
      static int error_once = 0;
      if (1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) { /* TODO: support I16, etc. */
        fprintf(stderr, "LIBXSMM ERROR: unsupported data-type requested!\n");
      }
    }
  }

  if (0 != typeprefix) { /* typeprefix remains 0 for unsupported precision */
    if (NULL != ostream) { /* print information about GEMM call */
      if (NULL != a && NULL != b && NULL != c) { /* full call with addresses */
        fprintf((FILE*)ostream, "%cgemm('%c', '%c', %" PRIuPTR "/*m*/, %" PRIuPTR "/*n*/, %" PRIuPTR "/*k*/,\n"
                                "  %s/*alpha*/, %p/*a*/, %" PRIuPTR "/*lda*/,\n"
                                "              %p/*b*/, %" PRIuPTR "/*ldb*/,\n"
                                "   %s/*beta*/, %p/*c*/, %" PRIuPTR "/*ldc*/)",
          typeprefix, ctransa, ctransb, (uintptr_t)*m, (uintptr_t)nn, (uintptr_t)kk,
          string_a, a, (uintptr_t)ilda, b, (uintptr_t)ildb, string_b, c, (uintptr_t)ildc);
      }
      else { /* compact single-line form when matrix pointers are unavailable */
        fprintf((FILE*)ostream, "%cgemm(trans=%c%c mnk=%" PRIuPTR ",%" PRIuPTR ",%" PRIuPTR
                                " ldx=%" PRIuPTR ",%" PRIuPTR ",%" PRIuPTR " a,b=%s,%s)",
          typeprefix, ctransa, ctransb, (uintptr_t)*m, (uintptr_t)nn, (uintptr_t)kk,
          (uintptr_t)ilda, (uintptr_t)ildb, (uintptr_t)ildc, string_a, string_b);
      }
    }
    else { /* dump A, B, and C matrices into MHD files */
      char extension_header[256];
      size_t data_size[2], size[2];

      if (NULL != a) {
        /* TRANS/ALPHA recorded in the MHD extension header; file name derived from address */
        LIBXSMM_SNPRINTF(extension_header, sizeof(extension_header), "TRANS = %c\nALPHA = %s", ctransa, string_a);
        LIBXSMM_SNPRINTF(string_a, sizeof(string_a), "libxsmm_a_%p.mhd", a);
        data_size[0] = (size_t)ilda; data_size[1] = (size_t)kk; size[0] = (size_t)(*m); size[1] = (size_t)kk;
        libxsmm_mhd_write(string_a, NULL/*offset*/, size, data_size, 2/*ndims*/, 1/*ncomponents*/, mhd_elemtype,
          NULL/*conversion*/, a, NULL/*header_size*/, extension_header, NULL/*extension*/, 0/*extension_size*/);
      }
      if (NULL != b) {
        LIBXSMM_SNPRINTF(extension_header, sizeof(extension_header), "\nTRANS = %c", ctransb);
        LIBXSMM_SNPRINTF(string_a, sizeof(string_a), "libxsmm_b_%p.mhd", b);
        data_size[0] = (size_t)ildb; data_size[1] = (size_t)nn; size[0] = (size_t)kk; size[1] = (size_t)nn;
        libxsmm_mhd_write(string_a, NULL/*offset*/, size, data_size, 2/*ndims*/, 1/*ncomponents*/, mhd_elemtype,
          NULL/*conversion*/, b, NULL/*header_size*/, extension_header, NULL/*extension*/, 0/*extension_size*/);
      }
      if (NULL != c) {
        LIBXSMM_SNPRINTF(extension_header, sizeof(extension_header), "BETA = %s", string_b);
        LIBXSMM_SNPRINTF(string_a, sizeof(string_a), "libxsmm_c_%p.mhd", c);
        data_size[0] = (size_t)ildc; data_size[1] = (size_t)nn; size[0] = (size_t)(*m); size[1] = (size_t)nn;
        libxsmm_mhd_write(string_a, NULL/*offset*/, size, data_size, 2/*ndims*/, 1/*ncomponents*/, mhd_elemtype,
          NULL/*conversion*/, c, NULL/*header_size*/, extension_header, NULL/*extension*/, 0/*extension_size*/);
      }
    }
  }
}
669
670
libxsmm_gemm_dprint(void * ostream,libxsmm_gemm_precision precision,char transa,char transb,libxsmm_blasint m,libxsmm_blasint n,libxsmm_blasint k,double dalpha,const void * a,libxsmm_blasint lda,const void * b,libxsmm_blasint ldb,double dbeta,void * c,libxsmm_blasint ldc)671 LIBXSMM_API void libxsmm_gemm_dprint(
672 void* ostream, libxsmm_gemm_precision precision, char transa, char transb,
673 libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, double dalpha, const void* a, libxsmm_blasint lda,
674 const void* b, libxsmm_blasint ldb, double dbeta, void* c, libxsmm_blasint ldc)
675 {
676 libxsmm_gemm_dprint2(ostream, precision, precision, transa, transb, m, n, k, dalpha, a, lda, b, ldb, dbeta, c, ldc);
677 }
678
679
libxsmm_gemm_dprint2(void * ostream,libxsmm_gemm_precision iprec,libxsmm_gemm_precision oprec,char transa,char transb,libxsmm_blasint m,libxsmm_blasint n,libxsmm_blasint k,double dalpha,const void * a,libxsmm_blasint lda,const void * b,libxsmm_blasint ldb,double dbeta,void * c,libxsmm_blasint ldc)680 LIBXSMM_API void libxsmm_gemm_dprint2(
681 void* ostream, libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, char transa, char transb,
682 libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k, double dalpha, const void* a, libxsmm_blasint lda,
683 const void* b, libxsmm_blasint ldb, double dbeta, void* c, libxsmm_blasint ldc)
684 {
685 switch (iprec) {
686 case LIBXSMM_GEMM_PRECISION_F64: {
687 libxsmm_gemm_print2(ostream, LIBXSMM_GEMM_PRECISION_F64, oprec, &transa, &transb,
688 &m, &n, &k, &dalpha, a, &lda, b, &ldb, &dbeta, c, &ldc);
689 } break;
690 case LIBXSMM_GEMM_PRECISION_F32: {
691 const float alpha = (float)dalpha, beta = (float)dbeta;
692 libxsmm_gemm_print2(ostream, LIBXSMM_GEMM_PRECISION_F32, oprec, &transa, &transb,
693 &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc);
694 } break;
695 default: {
696 libxsmm_gemm_print2(ostream, iprec, oprec, &transa, &transb,
697 &m, &n, &k, &dalpha, a, &lda, b, &ldb, &dbeta, c, &ldc);
698 }
699 }
700 }
701
702
libxsmm_gemm_xprint(void * ostream,libxsmm_xmmfunction kernel,const void * a,const void * b,void * c)703 LIBXSMM_API void libxsmm_gemm_xprint(void* ostream,
704 libxsmm_xmmfunction kernel, const void* a, const void* b, void* c)
705 {
706 const libxsmm_descriptor* desc;
707 libxsmm_code_pointer code;
708 size_t code_size;
709 code.xgemm = kernel;
710 if (NULL != libxsmm_get_kernel_xinfo(code, &desc, &code_size) &&
711 NULL != desc && LIBXSMM_KERNEL_KIND_MATMUL == desc->kind)
712 {
713 libxsmm_gemm_dprint2(ostream,
714 (libxsmm_gemm_precision)LIBXSMM_GETENUM_INP(desc->gemm.desc.datatype),
715 (libxsmm_gemm_precision)LIBXSMM_GETENUM_OUT(desc->gemm.desc.datatype),
716 (char)(0 == (LIBXSMM_GEMM_FLAG_TRANS_A & desc->gemm.desc.flags) ? 'N' : 'T'),
717 (char)(0 == (LIBXSMM_GEMM_FLAG_TRANS_B & desc->gemm.desc.flags) ? 'N' : 'T'),
718 (libxsmm_blasint)desc->gemm.desc.m, (libxsmm_blasint)desc->gemm.desc.n, (libxsmm_blasint)desc->gemm.desc.k,
719 /*0 != (LIBXSMM_GEMM_FLAG_ALPHA_0 & libxsmm_mmbatch_desc.flags) ? 0 : */1, a,
720 (libxsmm_blasint)desc->gemm.desc.lda, b, (libxsmm_blasint)desc->gemm.desc.ldb,
721 0 != (LIBXSMM_GEMM_FLAG_BETA_0 & libxsmm_mmbatch_desc.flags) ? 0 : 1, c, (libxsmm_blasint)desc->gemm.desc.ldc);
722 fprintf((FILE*)ostream, " = %p+%u", code.ptr_const, (unsigned int)code_size);
723 }
724 }
725
726
libxsmm_blas_xgemm(libxsmm_gemm_precision iprec,libxsmm_gemm_precision oprec,const char * transa,const char * transb,const libxsmm_blasint * m,const libxsmm_blasint * n,const libxsmm_blasint * k,const void * alpha,const void * a,const libxsmm_blasint * lda,const void * b,const libxsmm_blasint * ldb,const void * beta,void * c,const libxsmm_blasint * ldc)727 LIBXSMM_API void libxsmm_blas_xgemm(libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec,
728 const char* transa, const char* transb, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
729 const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb,
730 const void* beta, void* c, const libxsmm_blasint* ldc)
731 {
732 LIBXSMM_INIT
733 switch (iprec) {
734 case LIBXSMM_GEMM_PRECISION_F64: {
735 LIBXSMM_ASSERT(iprec == oprec);
736 LIBXSMM_BLAS_XGEMM(double, double, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
737 } break;
738 case LIBXSMM_GEMM_PRECISION_F32: {
739 LIBXSMM_ASSERT(iprec == oprec);
740 LIBXSMM_BLAS_XGEMM(float, float, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
741 } break;
742 default: if (0 != libxsmm_verbosity) { /* library code is expected to be mute */
743 static int error_once = 0;
744 LIBXSMM_UNUSED(oprec);
745 if (1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) { /* TODO: support I16, etc. */
746 fprintf(stderr, "LIBXSMM ERROR: unsupported data-type requested!\n");
747 }
748 }
749 }
750 }
751
752
LIBXSMM_API_INLINE int libxsmm_gemm_plan_internal(unsigned int ntasks,
  unsigned int m, unsigned int n, unsigned int k, /* whole problem size */
  unsigned int tm, unsigned int tn, unsigned int tk, /* tile size (kernel) */
  unsigned int* nmt, unsigned int* nnt, unsigned int* nkt, /* number of tiles */
  unsigned int* mt, unsigned int* nt, unsigned int* kt) /* number of tasks */
{
  /* Plans how to distribute the tiled GEMM over ntasks: computes the number of
   * tiles per dimension (round-up division) and then tries M-parallelism first,
   * widening to MN- and (if enabled) MNK-parallelism via the "replan" state
   * machine (0: initial pass, 1: forced M-only, 2: forced MN).
   * Returns EXIT_SUCCESS, or EXIT_FAILURE if no valid task split was found. */
  unsigned int result = EXIT_SUCCESS, replan = 0;
  LIBXSMM_ASSERT(NULL != nmt && NULL != nnt && NULL != nkt);
  LIBXSMM_ASSERT(NULL != mt && NULL != nt && NULL != kt);
  LIBXSMM_ASSERT(0 < ntasks);
  *nmt = (m + tm - 1) / LIBXSMM_MAX(tm, 1);
  *nnt = (n + tn - 1) / LIBXSMM_MAX(tn, 1);
  *nkt = (k + tk - 1) / LIBXSMM_MAX(tk, 1);
#if !defined(NDEBUG)
  *mt = *nt = *kt = 0;
#endif
  do {
    /* number of M-tasks limited by available tasks (skipped when MN-replan forced) */
    if (1 >= replan) *mt = libxsmm_product_limit(*nmt, ntasks, 0);
    if (1 == replan || ntasks <= *mt) { /* M-parallelism */
      *nt = 1;
      *kt = 1;
      replan = 0;
    }
    else {
      const unsigned int mntasks = libxsmm_product_limit((*nmt) * (*nnt), ntasks, 0);
      /* MN does not beat plain M: redo the plan with M-parallelism forced */
      if (0 == replan && *mt >= mntasks) replan = 1;
      if (2 == replan || (0 == replan && ntasks <= mntasks)) { /* MN-parallelism */
        *nt = mntasks / *mt;
        *kt = 1;
        replan = 0;
      }
      else { /* MNK-parallelism */
        const unsigned int mnktasks = libxsmm_product_limit((*nmt) * (*nnt) * (*nkt), ntasks, 0);
        if (mntasks < mnktasks) {
#if defined(LIBXSMM_GEMM_KPARALLEL)
          *nt = mntasks / *mt;
          *kt = mnktasks / mntasks;
          replan = 0;
#else
          /* K-parallelism is compiled out: warn once (high verbosity only) */
          static int error_once = 0;
          if ((LIBXSMM_VERBOSITY_HIGH <= libxsmm_verbosity || 0 > libxsmm_verbosity) /* library code is expected to be mute */
            && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
          {
            fprintf(stderr, "LIBXSMM WARNING (XGEMM): K-parallelism triggered!\n");
          }
#endif
        }
#if defined(LIBXSMM_GEMM_KPARALLEL)
        else
#endif
        /* fall back to MN-parallelism on the next iteration */
        if (0 == replan) replan = 2;
      }
    }
  } while (0 != replan);
  if (0 == *mt || 0 == *nt || 0 == *kt) {
    result = EXIT_FAILURE;
  }
  return result;
}
812
813
LIBXSMM_API libxsmm_gemm_handle* libxsmm_gemm_handle_init(libxsmm_gemm_blob* blob,
  libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const libxsmm_blasint* lda, const libxsmm_blasint* ldb, const libxsmm_blasint* ldc,
  const void* alpha, const void* beta, int flags, /*unsigned*/int ntasks)
{
  /* Initializes a tiled-GEMM handle inside the caller-provided blob storage:
   * determines tile sizes (tm/tn/tk) and task counts, dispatches the required
   * copy/transpose kernels and the beta=0/beta!=0 GEMM kernels.
   * Returns NULL if the parameters do not conform (e.g. tiles do not evenly
   * divide the problem) or kernel dispatch fails. */
  unsigned int ulda, uldb, um, un, uk, tm = 0, tn = 0, tk = 0, max_ntasks = 0;
  libxsmm_descriptor_blob desc_blob;
  union { /* the handle aliases the caller's blob storage */
    libxsmm_gemm_handle* ptr;
    libxsmm_gemm_blob* blob;
  } result;
  LIBXSMM_ASSERT(sizeof(libxsmm_gemm_handle) <= sizeof(libxsmm_gemm_blob));
  if (NULL != blob && NULL != m && 0 < ntasks) {
    unsigned int ntm = 0, ntn = 0, ntk = 0, mt = 1, nt = 1, kt = 1;
    /* LIBXSMM_TGEMM_M can force the M-tile size (bypasses the automatic search) */
    const char *const env_tm = getenv("LIBXSMM_TGEMM_M");
    libxsmm_blasint klda, kldb, kldc, km, kn;
    libxsmm_gemm_descriptor* desc;
    const int prf_copy = 0;
    double dbeta;
    LIBXSMM_INIT
    result.blob = blob;
#if defined(NDEBUG)
    result.ptr->copy_a.ptr = result.ptr->copy_b.ptr = result.ptr->copy_i.ptr = result.ptr->copy_o.ptr = NULL;
#else
    memset(blob, 0, sizeof(libxsmm_gemm_blob));
#endif
    if (EXIT_SUCCESS != libxsmm_dvalue((libxsmm_datatype)oprec, beta, &dbeta)) dbeta = LIBXSMM_BETA; /* fuse beta into flags */
    result.ptr->gemm_flags = LIBXSMM_GEMM_PFLAGS(transa, transb, LIBXSMM_FLAGS) | (LIBXSMM_NEQ(0, dbeta) ? 0 : LIBXSMM_GEMM_FLAG_BETA_0);
    /* TODO: check that arguments fit into handle (unsigned int vs. libxsmm_blasint) */
    /* k defaults to m, n defaults to k when not given */
    um = (unsigned int)(*m); uk = (NULL != k ? ((unsigned int)(*k)) : um); un = (NULL != n ? ((unsigned int)(*n)) : uk);
    result.ptr->otypesize = libxsmm_typesize((libxsmm_datatype)oprec);
    if (NULL == env_tm || 0 >= atoi(env_tm)) {
      /* automatic tile-size search: try decreasing M-tiles and keep the plan
       * that yields the largest number of tasks */
      const unsigned int vwidth = LIBXSMM_MAX(internal_gemm_vwidth / result.ptr->otypesize, 1);
      const double s2 = (double)internal_gemm_nstretch * internal_gemm_kstretch; /* LIBXSMM_INIT! */
      unsigned int tmi = libxsmm_product_limit(um, internal_gemm_mlimit, 0); /* LIBXSMM_INIT! */
      for (; vwidth <= tmi; tmi = libxsmm_product_limit(um, tmi - 1, 0)) {
        /* N/K tile sizes are derived from the M-tile via the stretch factors */
        const double si = (double)(LIBXSMM_CONFIG_MAX_MNK) / ((double)tmi * tmi * tmi), s = (s2 <= si ? 1 : (s2 / si));
        unsigned int tni = libxsmm_product_limit(un, LIBXSMM_MAX((unsigned int)(tmi * (s * internal_gemm_nstretch)), 1), 0);
        unsigned int tki = libxsmm_product_limit(uk, LIBXSMM_MAX((unsigned int)(tmi * (s * internal_gemm_kstretch)), 1), 0);
        unsigned int ntmi, ntni, ntki, mti = 1, nti = 1, kti = 1;
        LIBXSMM_ASSERT(tmi <= um && tni <= un && tki <= uk);
        if (LIBXSMM_GEMM_FLAG_TRANS_AB == (LIBXSMM_GEMM_FLAG_TRANS_AB & result.ptr->gemm_flags)) {
          /* TT case requires square tiles (tm == tn) */
          const unsigned int ttm = (unsigned int)libxsmm_product_limit(tmi, (unsigned int)ntasks, 0);
          const unsigned int ttn = (unsigned int)libxsmm_product_limit(tni, (unsigned int)ntasks, 0);
          tmi = tni = LIBXSMM_MIN(ttm, ttn); /* prefer threads over larger tile */
        }
        if (EXIT_SUCCESS == libxsmm_gemm_plan_internal((unsigned int)ntasks, um, un, uk, tmi, tni, tki,
          &ntmi, &ntni, &ntki, &mti, &nti, &kti))
        {
          const int exit_plan = ((tmi < um && tni < un && tki < uk && (tm != tmi || tn != tni || tk != tki)) ? 0 : 1);
          const unsigned itasks = mti * nti * kti;
          LIBXSMM_ASSERT(1 <= itasks);
          if (max_ntasks < itasks) { /* better plan found: remember it */
            ntm = ntmi; ntn = ntni; ntk = ntki;
            mt = mti; nt = nti; kt = kti;
            tm = tmi; tn = tni; tk = tki;
            max_ntasks = itasks;
          }
          if (itasks == (unsigned int)ntasks || 0 != exit_plan) break;
        }
      }
    }
    else {
      /* tile size forced by environment variable */
      const unsigned int tmi = atoi(env_tm);
      const double s2 = (double)internal_gemm_nstretch * internal_gemm_kstretch; /* LIBXSMM_INIT! */
      double si, s;
      tm = libxsmm_product_limit(um, LIBXSMM_MIN(tmi, internal_gemm_mlimit), 0); /* LIBXSMM_INIT! */
      si = (double)(LIBXSMM_CONFIG_MAX_MNK) / ((double)tm * tm * tm); s = (s2 <= si ? 1 : (s2 / si));
      tn = libxsmm_product_limit(un, LIBXSMM_MAX((unsigned int)(tm * (s * internal_gemm_nstretch)), 1), 0);
      tk = libxsmm_product_limit(uk, LIBXSMM_MAX((unsigned int)(tm * (s * internal_gemm_kstretch)), 1), 0);
      if (LIBXSMM_GEMM_FLAG_TRANS_AB == (LIBXSMM_GEMM_FLAG_TRANS_AB & result.ptr->gemm_flags)) {
        /* TT case requires square tiles (tm == tn) */
        const unsigned int ttm = (unsigned int)libxsmm_product_limit(tm, (unsigned int)ntasks, 0);
        const unsigned int ttn = (unsigned int)libxsmm_product_limit(tn, (unsigned int)ntasks, 0);
        tm = tn = LIBXSMM_MIN(ttm, ttn); /* prefer threads over larger tile */
      }
      if (EXIT_SUCCESS == libxsmm_gemm_plan_internal((unsigned int)ntasks, um, un, uk, tm, tn, tk,
        &ntm, &ntn, &ntk, &mt, &nt, &kt))
      {
#if defined(NDEBUG)
        max_ntasks = 2; /* only need something unequal to zero to pass below condition */
#else
        max_ntasks = mt * nt * kt;
#endif
      }
    }
    LIBXSMM_ASSERT(LIBXSMM_GEMM_FLAG_TRANS_AB != (LIBXSMM_GEMM_FLAG_TRANS_AB & result.ptr->gemm_flags) || tm == tn);
    /* check for non-conforming GEMM parameters (error), and conforming GEMM parameters (fast-path, fall-back) */
    if (0 == max_ntasks || 0 == tm || 0 == tn || 0 == tk || 0 != (um % tm) || 0 != (un % tn) || 0 != (uk % tk)) {
      return NULL;
    }
    result.ptr->flags = flags;
    if (LIBXSMM_GEMM_HANDLE_FLAG_AUTO == flags && 0 == LIBXSMM_SMM_AI(um, un, uk,
      0 == (result.ptr->gemm_flags & LIBXSMM_GEMM_FLAG_BETA_0) ? 1 : 2/*RFO*/, result.ptr->otypesize))
    {
      /* heuristic: enable copy-in/out for POT extents (cache-thrashing mitigation) */
      if (um == LIBXSMM_UP2POT(um) || un == LIBXSMM_UP2POT(un)) { /* power-of-two (POT) extent(s) */
        result.ptr->flags |= LIBXSMM_GEMM_HANDLE_FLAG_COPY_C;
        if (LIBXSMM_GEMM_FLAG_TRANS_AB != (LIBXSMM_GEMM_FLAG_TRANS_AB & result.ptr->gemm_flags)) {
          result.ptr->flags |= LIBXSMM_GEMM_HANDLE_FLAG_COPY_A;
        }
      }
    }
    result.ptr->itypesize = libxsmm_typesize((libxsmm_datatype)iprec);
    result.ptr->ldc = (unsigned int)(NULL != ldc ? *ldc : *m);
    /* leading dimensions default according to the transpose flags */
    ulda = (NULL != lda ? ((unsigned int)(*lda)) : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & result.ptr->gemm_flags) ? ((unsigned int)(*m)) : uk));
    uldb = (NULL != ldb ? ((unsigned int)(*ldb)) : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & result.ptr->gemm_flags) ? uk : un));
    if (LIBXSMM_GEMM_FLAG_TRANS_AB != (LIBXSMM_GEMM_FLAG_TRANS_AB & result.ptr->gemm_flags)) { /* NN, NT, or TN */
      kldc = (libxsmm_blasint)result.ptr->ldc;
      klda = (libxsmm_blasint)ulda;
      kldb = (libxsmm_blasint)uldb;
      if (0 != (LIBXSMM_GEMM_FLAG_TRANS_A & result.ptr->gemm_flags)) { /* TN */
#if !defined(LIBXSMM_GEMM_NOJIT_TRANS)
        /* A-tiles are transposed into scratch; kernel then reads tm-strided A */
        result.ptr->copy_a.xtrans = libxsmm_dispatch_trans(libxsmm_trans_descriptor_init(&desc_blob,
          result.ptr->itypesize, tk, tm, tm/*ldo*/));
#endif
        klda = (libxsmm_blasint)tm;
      }
      else if (0 != (LIBXSMM_GEMM_HANDLE_FLAG_COPY_A & result.ptr->flags)) {
        result.ptr->copy_a.xmatcopy = libxsmm_dispatch_mcopy(libxsmm_mcopy_descriptor_init(&desc_blob,
          result.ptr->itypesize, tm, tk, tm/*ldo*/, ulda/*ldi*/,
          0/*flags*/, prf_copy, NULL/*unroll*/));
        klda = (libxsmm_blasint)tm;
      }
      if (0 != (LIBXSMM_GEMM_FLAG_TRANS_B & result.ptr->gemm_flags)) { /* NT */
#if !defined(LIBXSMM_GEMM_NOJIT_TRANS)
        /* B-tiles are transposed into scratch; kernel then reads tk-strided B */
        result.ptr->copy_b.xtrans = libxsmm_dispatch_trans(libxsmm_trans_descriptor_init(&desc_blob,
          result.ptr->itypesize, tn, tk, tk/*ldo*/));
#endif
        kldb = (libxsmm_blasint)tk;
      }
      else if (0 != (LIBXSMM_GEMM_HANDLE_FLAG_COPY_B & result.ptr->flags)) {
        result.ptr->copy_b.xmatcopy = libxsmm_dispatch_mcopy(libxsmm_mcopy_descriptor_init(&desc_blob,
          result.ptr->itypesize, tk, tn, tk/*ldo*/, uldb/*ldi*/,
          0/*flags*/, prf_copy, NULL/*unroll*/));
        kldb = (libxsmm_blasint)tk;
      }
      if (0 != (LIBXSMM_GEMM_HANDLE_FLAG_COPY_C & result.ptr->flags)) {
        /* copy-out kernel (scratch -> C) and, for beta!=0, copy-in kernel (C -> scratch) */
        result.ptr->copy_o.xmatcopy = libxsmm_dispatch_mcopy(libxsmm_mcopy_descriptor_init(&desc_blob,
          result.ptr->otypesize, tm, tn, result.ptr->ldc/*ldo*/, tm/*ldi*/,
          0/*flags*/, prf_copy, NULL/*unroll*/));
        if (0 == (result.ptr->gemm_flags & LIBXSMM_GEMM_FLAG_BETA_0)) { /* copy-in only if beta!=0 */
          result.ptr->copy_i.xmatcopy = libxsmm_dispatch_mcopy(libxsmm_mcopy_descriptor_init(&desc_blob,
            result.ptr->otypesize, tm, tn, tm/*ldo*/, result.ptr->ldc/*ldi*/,
            0/*flags*/, prf_copy, NULL/*unroll*/));
        }
        kldc = (libxsmm_blasint)tm;
      }
      result.ptr->lda = ulda; result.ptr->ldb = uldb;
      result.ptr->km = tm; result.ptr->kn = tn;
      result.ptr->mt = mt; result.ptr->nt = nt;
      result.ptr->m = um; result.ptr->n = un;
      result.ptr->dm = LIBXSMM_UPDIV(ntm, mt) * tm;
      result.ptr->dn = LIBXSMM_UPDIV(ntn, nt) * tn;
      km = tm; kn = tn;
    }
    else { /* TT: compute (B^T * A^T)^T via swapped operands and square tiles */
      const unsigned int tt = tm;
      klda = (libxsmm_blasint)uldb;
      kldb = (libxsmm_blasint)ulda;
      kldc = (libxsmm_blasint)tt;
      LIBXSMM_ASSERT(tt == tn);
#if !defined(LIBXSMM_GEMM_NOJIT_TRANS)
      result.ptr->copy_o.xtrans = libxsmm_dispatch_trans(libxsmm_trans_descriptor_init(&desc_blob,
        result.ptr->otypesize, tt, tt, result.ptr->ldc/*ldo*/));
      if (0 == (result.ptr->gemm_flags & LIBXSMM_GEMM_FLAG_BETA_0)) { /* copy-in only if beta!=0 */
        result.ptr->copy_i.xtrans = libxsmm_dispatch_trans(libxsmm_trans_descriptor_init(&desc_blob,
          result.ptr->otypesize, tt, tt, tt/*ldo*/));
      }
#endif
      if (0 != (LIBXSMM_GEMM_HANDLE_FLAG_COPY_A & result.ptr->flags)) {
        result.ptr->copy_a.xmatcopy = libxsmm_dispatch_mcopy(libxsmm_mcopy_descriptor_init(&desc_blob,
          result.ptr->itypesize, tt, tk, tk/*ldo*/, uldb/*ldi*/,
          0/*flags*/, prf_copy, NULL/*unroll*/));
        klda = (libxsmm_blasint)tt;
      }
      if (0 != (LIBXSMM_GEMM_HANDLE_FLAG_COPY_B & result.ptr->flags)) {
        result.ptr->copy_b.xmatcopy = libxsmm_dispatch_mcopy(libxsmm_mcopy_descriptor_init(&desc_blob,
          result.ptr->itypesize, tk, tn, tk/*ldo*/, ulda/*ldi*/,
          0/*flags*/, prf_copy, NULL/*unroll*/));
        kldb = (libxsmm_blasint)tk;
      }
      /* swapped extents/tiles/strides relative to the non-TT branch */
      result.ptr->lda = uldb; result.ptr->ldb = ulda;
      result.ptr->km = tn; result.ptr->kn = tm;
      result.ptr->mt = nt; result.ptr->nt = mt;
      result.ptr->m = un; result.ptr->n = um;
      result.ptr->dm = LIBXSMM_UPDIV(ntn, nt) * tn;
      result.ptr->dn = LIBXSMM_UPDIV(ntm, mt) * tm;
      km = kn = tt;
    }
    result.ptr->dk = ntk / kt * tk;
    result.ptr->kk = tk;
    result.ptr->kt = kt;
    result.ptr->k = uk;
    desc = libxsmm_gemm_descriptor_init2( /* remove transpose flags from kernel request */
      &desc_blob, iprec, oprec, km, kn, result.ptr->kk, klda, kldb, kldc,
      alpha, beta, result.ptr->gemm_flags & ~LIBXSMM_GEMM_FLAG_TRANS_AB, internal_gemm_tiled_prefetch);
    result.ptr->kernel[0] = libxsmm_xmmdispatch(desc);
    if (NULL != result.ptr->kernel[0].xmm) {
      if (0 == (desc->flags & LIBXSMM_GEMM_FLAG_BETA_0)) { /* beta!=0 */
        result.ptr->kernel[1] = result.ptr->kernel[0];
      }
      else { /* generate kernel with beta=1 */
        desc->flags &= ~LIBXSMM_GEMM_FLAG_BETA_0;
        result.ptr->kernel[1] = libxsmm_xmmdispatch(desc);
        if (NULL == result.ptr->kernel[1].xmm) result.ptr = NULL;
      }
    }
    else result.ptr = NULL;
  }
  else {
    result.ptr = NULL;
  }
  return result.ptr;
}
1028
1029
libxsmm_gemm_handle_get_scratch_size_a(const libxsmm_gemm_handle * handle)1030 LIBXSMM_API_INLINE size_t libxsmm_gemm_handle_get_scratch_size_a(const libxsmm_gemm_handle* handle)
1031 {
1032 size_t result;
1033 if (NULL == handle || (0 == (handle->flags & LIBXSMM_GEMM_HANDLE_FLAG_COPY_A)
1034 && (LIBXSMM_GEMM_FLAG_TRANS_AB == (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags) ||
1035 (LIBXSMM_GEMM_FLAG_TRANS_A & handle->gemm_flags) == 0)))
1036 {
1037 result = 0;
1038 }
1039 else {
1040 const size_t size = (size_t)handle->km * handle->kk * handle->itypesize;
1041 result = LIBXSMM_UP2(size, LIBXSMM_CACHELINE);
1042 }
1043 return result;
1044 }
1045
1046
libxsmm_gemm_handle_get_scratch_size_b(const libxsmm_gemm_handle * handle)1047 LIBXSMM_API_INLINE size_t libxsmm_gemm_handle_get_scratch_size_b(const libxsmm_gemm_handle* handle)
1048 {
1049 size_t result;
1050 if (NULL == handle || (0 == (handle->flags & LIBXSMM_GEMM_HANDLE_FLAG_COPY_B)
1051 && (LIBXSMM_GEMM_FLAG_TRANS_AB == (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags) ||
1052 (LIBXSMM_GEMM_FLAG_TRANS_B & handle->gemm_flags) == 0)))
1053 {
1054 result = 0;
1055 }
1056 else {
1057 const size_t size = (size_t)handle->kk * handle->kn * handle->itypesize;
1058 result = LIBXSMM_UP2(size, LIBXSMM_CACHELINE);
1059 }
1060 return result;
1061 }
1062
1063
libxsmm_gemm_handle_get_scratch_size_c(const libxsmm_gemm_handle * handle)1064 LIBXSMM_API_INLINE size_t libxsmm_gemm_handle_get_scratch_size_c(const libxsmm_gemm_handle* handle)
1065 {
1066 size_t result;
1067 if (NULL == handle || (0 == (handle->flags & LIBXSMM_GEMM_HANDLE_FLAG_COPY_C)
1068 && LIBXSMM_GEMM_FLAG_TRANS_AB != (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags)))
1069 {
1070 result = 0;
1071 }
1072 else {
1073 const size_t size = (size_t)handle->km * handle->kn * handle->otypesize;
1074 result = LIBXSMM_UP2(size, LIBXSMM_CACHELINE);
1075 }
1076 return result;
1077 }
1078
1079
libxsmm_gemm_handle_get_scratch_size(const libxsmm_gemm_handle * handle)1080 LIBXSMM_API size_t libxsmm_gemm_handle_get_scratch_size(const libxsmm_gemm_handle* handle)
1081 {
1082 size_t result;
1083 if (NULL != handle) { /* thread-local scratch buffer for GEMM */
1084 const size_t size_a = libxsmm_gemm_handle_get_scratch_size_a(handle);
1085 const size_t size_b = libxsmm_gemm_handle_get_scratch_size_b(handle);
1086 const size_t size_c = libxsmm_gemm_handle_get_scratch_size_c(handle);
1087 result = (size_a + size_b + size_c) * handle->mt * handle->nt * handle->kt;
1088 }
1089 else {
1090 result = 0;
1091 }
1092 return result;
1093 }
1094
1095
LIBXSMM_API void libxsmm_gemm_thread(const libxsmm_gemm_handle* handle, void* scratch,
  const void* a, const void* b, void* c, /*unsigned*/int tid, /*unsigned*/int nthreads)
{
  /* Executes this thread's share of the tiled GEMM described by the handle.
   * Each participating thread derives its (m,n,k)-task from tid, optionally
   * stages A/B/C tiles through its slice of the scratch buffer (via JIT'ted
   * copy/transpose kernels or the internal fall-backs), and calls the GEMM
   * kernel per K-tile. Must be called by all nthreads threads. */
#if !defined(NDEBUG)
  if (NULL != handle && 0 <= tid && tid < nthreads)
#endif
  {
    const unsigned int uthreads = (unsigned int)nthreads;
    const unsigned int ntasks = handle->mt * handle->nt * handle->kt;
    /* if there are more threads than tasks, only every spread-th thread works */
    const unsigned int spread = (ntasks <= uthreads ? (uthreads / ntasks) : 1);
    const unsigned int utid = (unsigned int)tid, vtid = utid / spread;
    if (utid < (spread * ntasks) && 0 == (utid - vtid * spread)) {
      /* "excess" tasks near the end take the remainder up to the full extent */
      const int excess = (uthreads << 1) <= (vtid + ntasks);
      /* decompose the virtual task-ID into (m,n,k)-task coordinates */
      const unsigned int rtid = vtid / handle->mt, mtid = vtid - rtid * handle->mt, ntid = rtid % handle->nt, ktid = vtid / (handle->mt * handle->nt);
      /* [m0,m1) x [n0,n1) x [k0,k1) is this task's sub-problem */
      const unsigned int m0 = mtid * handle->dm, m1 = (0 == excess ? LIBXSMM_MIN(m0 + handle->dm, handle->m) : handle->m);
      const unsigned int n0 = ntid * handle->dn, n1 = (0 == excess ? LIBXSMM_MIN(n0 + handle->dn, handle->n) : handle->n);
      const unsigned int k0 = ktid * handle->dk, k1 = (0 == excess ? LIBXSMM_MIN(k0 + handle->dk, handle->k) : handle->k);
      const unsigned int ldo = (LIBXSMM_GEMM_FLAG_TRANS_AB != (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags) ? handle->km : handle->kk);
      /* calculate increments to simplify address calculations */
      const unsigned int dom = handle->km * handle->otypesize;
      const unsigned int don = handle->kn * handle->otypesize;
      const unsigned int dik = handle->kk * handle->itypesize;
      const unsigned int on = handle->otypesize * n0;
      /* calculate base address of thread-local storage */
      const size_t size_a = libxsmm_gemm_handle_get_scratch_size_a(handle);
      const size_t size_b = libxsmm_gemm_handle_get_scratch_size_b(handle);
      const size_t size_c = libxsmm_gemm_handle_get_scratch_size_c(handle);
      char *const at = (char*)scratch + (size_a + size_b + size_c) * vtid;
      char *const bt = at + size_a, *const ct = bt + size_b;
      /* NULL-kernel: makes the internal copy/transpose routines use their fall-back path */
      const libxsmm_xcopykernel kernel = { NULL };
      /* loop induction variables and other variables */
      unsigned int om = handle->otypesize * m0, im = m0, in = n0, ik = k0, im1, in1, ik1;
      LIBXSMM_ASSERT_MSG(mtid < handle->mt && ntid < handle->nt && ktid < handle->kt, "Invalid task-ID");
      LIBXSMM_ASSERT_MSG(m1 <= handle->m && n1 <= handle->n && k1 <= handle->k, "Invalid task size");
      for (im1 = im + handle->km; (im1 - 1) < m1; im = im1, im1 += handle->km, om += dom) {
        unsigned int dn = don, dka = dik, dkb = dik;
        char *c0 = (char*)c, *ci;
        const char *aa;
        /* set up the A base pointer and K-strides depending on the transpose case */
        if (LIBXSMM_GEMM_FLAG_TRANS_AB != (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags)) {
          if (0 != (LIBXSMM_GEMM_FLAG_TRANS_A & handle->gemm_flags)) { /* TN */
            aa = (const char*)a + ((size_t)im * handle->lda + k0) * handle->itypesize;
          }
          else if (0 != (LIBXSMM_GEMM_FLAG_TRANS_B & handle->gemm_flags)) { /* NT */
            aa = (const char*)a + ((size_t)k0 * handle->lda + im) * handle->itypesize;
            dka *= handle->lda; dkb *= handle->ldb;
          }
          else { /* NN */
            aa = (const char*)a + ((size_t)k0 * handle->lda + im) * handle->itypesize;
            dka *= handle->lda;
          }
          c0 += (size_t)on * handle->ldc + om;
          dn *= handle->ldc;
        }
        else { /* TT: operands swapped (A-role is played by b) */
          aa = (const char*)b + ((size_t)k0 * handle->lda + im) * handle->itypesize;
          c0 += (size_t)on + handle->ldc * (size_t)om;
          dka *= handle->lda;
        }
        for (in = n0, in1 = in + handle->kn; (in1 - 1) < n1; in = in1, in1 += handle->kn, c0 += dn) {
          const char *a0 = aa, *b0 = (const char*)b;
          /* set up the B base pointer for this N-tile */
          if (LIBXSMM_GEMM_FLAG_TRANS_AB != (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags)) {
            if (0 != (LIBXSMM_GEMM_FLAG_TRANS_B & handle->gemm_flags)) { /* NT */
              b0 += ((size_t)k0 * handle->ldb + in) * handle->itypesize;
            }
            else { /* NN or TN */
              b0 += ((size_t)in * handle->ldb + k0) * handle->itypesize;
            }
          }
          else { /* TT: operands swapped (B-role is played by a) */
            b0 = (const char*)a + ((size_t)in * handle->ldb + k0) * handle->itypesize;
          }
          /* stage the C-tile: either in-place (ci = c0) or copied/transposed into scratch (ci = ct) */
          if (NULL == handle->copy_i.ptr_const) {
            ci = (NULL == handle->copy_o.ptr_const ? c0 : ct);
            if (LIBXSMM_GEMM_FLAG_TRANS_AB == (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags)) {
              const unsigned int km = handle->kn, kn = handle->km;
              libxsmm_otrans_internal(ct/*out*/, c0/*in*/, handle->otypesize, handle->ldc/*ldi*/, kn/*ldo*/,
                0, km, 0, kn, km/*tile*/, kn/*tile*/, kernel);
              ci = ct;
            }
            else if (0 != (LIBXSMM_GEMM_HANDLE_FLAG_COPY_C & handle->flags)) {
              if (0 == (handle->gemm_flags & LIBXSMM_GEMM_FLAG_BETA_0)) { /* copy-in only if beta!=0 */
                libxsmm_matcopy_internal(ct/*out*/, c0/*in*/, handle->otypesize, handle->ldc/*ldi*/, handle->km/*ldo*/,
                  0, handle->km, 0, handle->kn, handle->km/*tile*/, handle->kn/*tile*/, kernel);
              }
              ci = ct;
            }
          }
          else { /* MCOPY/TCOPY kernel */
            handle->copy_i.xmatcopy(c0, &handle->ldc, ct, &handle->km);
            ci = ct;
          }
          for (ik = k0, ik1 = ik + handle->kk; (ik1 - 1) < k1; ik = ik1, ik1 += handle->kk) {
            /* a1/b1 are the next K-tile's pointers (also used for prefetch) */
            const char *const a1 = a0 + dka, *const b1 = b0 + dkb, *ai = a0, *bi = b0;
            /* stage the A-tile (transpose/copy into scratch if required) */
            if (NULL == handle->copy_a.ptr_const) {
              if (LIBXSMM_GEMM_FLAG_TRANS_AB != (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags) &&
                 (LIBXSMM_GEMM_FLAG_TRANS_A & handle->gemm_flags) != 0) /* pure A-transpose */
              {
                LIBXSMM_ASSERT(ldo == handle->km);
                libxsmm_otrans_internal(at/*out*/, a0/*in*/, handle->itypesize, handle->lda/*ldi*/, ldo,
                  0, handle->kk, 0, handle->km, handle->kk/*tile*/, handle->km/*tile*/, kernel);
                ai = at;
              }
              else if (0 != (LIBXSMM_GEMM_HANDLE_FLAG_COPY_A & handle->flags)) {
                libxsmm_matcopy_internal(at/*out*/, a0/*in*/, handle->itypesize, handle->lda/*ldi*/, ldo,
                  0, handle->km, 0, handle->kk, handle->km/*tile*/, handle->kk/*tile*/, kernel);
                ai = at;
              }
            }
            else { /* MCOPY/TCOPY kernel */
              handle->copy_a.xmatcopy(a0, &handle->lda, at, &ldo);
              ai = at;
            }
            /* stage the B-tile (transpose/copy into scratch if required) */
            if (NULL == handle->copy_b.ptr_const) {
              if (LIBXSMM_GEMM_FLAG_TRANS_AB != (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags) &&
                 (LIBXSMM_GEMM_FLAG_TRANS_B & handle->gemm_flags) != 0) /* pure B-transpose */
              {
                libxsmm_otrans_internal(bt/*out*/, b0/*in*/, handle->itypesize, handle->ldb/*ldi*/, handle->kk/*ldo*/,
                  0, handle->kn, 0, handle->kk, handle->kn/*tile*/, handle->kk/*tile*/, kernel);
                bi = bt;
              }
              else if (0 != (LIBXSMM_GEMM_HANDLE_FLAG_COPY_B & handle->flags)) {
                libxsmm_matcopy_internal(bt/*out*/, b0/*in*/, handle->itypesize, handle->ldb/*ldi*/, handle->kk/*ldo*/,
                  0, handle->kk, 0, handle->kn, handle->kk/*tile*/, handle->kn/*tile*/, kernel);
                bi = bt;
              }
            }
            else { /* MCOPY/TCOPY kernel */
              handle->copy_b.xmatcopy(b0, &handle->ldb, bt, &handle->kk);
              bi = bt;
            }
            /* beta0-kernel on first-touch, beta1-kernel otherwise (beta0/beta1 are identical if beta=1) */
            LIBXSMM_MMCALL_PRF(handle->kernel[k0!=ik?1:0].xmm, ai, bi, ci, a1, b1, c0);
            a0 = a1;
            b0 = b1;
          }
          /* TODO: synchronize */
          /* write the staged C-tile back (if it was staged at all) */
          if (NULL == handle->copy_o.ptr_const) {
            if (LIBXSMM_GEMM_FLAG_TRANS_AB == (LIBXSMM_GEMM_FLAG_TRANS_AB & handle->gemm_flags)) {
              libxsmm_otrans_internal(c0/*out*/, ct/*in*/, handle->otypesize, handle->km/*ldi*/, handle->ldc/*ldo*/,
                0, handle->km, 0, handle->kn, handle->km/*tile*/, handle->kn/*tile*/, kernel);
            }
            else if (0 != (LIBXSMM_GEMM_HANDLE_FLAG_COPY_C & handle->flags)) {
              libxsmm_matcopy_internal(c0/*out*/, ct/*in*/, handle->otypesize, handle->km/*ldi*/, handle->ldc/*ldo*/,
                0, handle->km, 0, handle->kn, handle->km/*tile*/, handle->kn/*tile*/, kernel);
            }
          }
          else { /* MCOPY/TCOPY kernel */
            handle->copy_o.xmatcopy(ct, &handle->km, c0, &handle->ldc);
          }
        }
      }
    }
  }
#if !defined(NDEBUG)
  else if (/*implies LIBXSMM_INIT*/0 != libxsmm_get_verbosity()) { /* library code is expected to be mute */
    static int error_once = 0;
    if (1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED)) {
      fprintf(stderr, "LIBXSMM ERROR: libxsmm_gemm_thread - invalid handle!\n");
    }
  }
#endif
}
1258
1259
libxsmm_xgemm(libxsmm_gemm_precision iprec,libxsmm_gemm_precision oprec,const char * transa,const char * transb,const libxsmm_blasint * m,const libxsmm_blasint * n,const libxsmm_blasint * k,const void * alpha,const void * a,const libxsmm_blasint * lda,const void * b,const libxsmm_blasint * ldb,const void * beta,void * c,const libxsmm_blasint * ldc)1260 LIBXSMM_API void libxsmm_xgemm(libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec,
1261 const char* transa, const char* transb, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
1262 const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb,
1263 const void* beta, void* c, const libxsmm_blasint* ldc)
1264 {
1265 libxsmm_gemm_blob blob;
1266 const libxsmm_gemm_handle *const handle = libxsmm_gemm_handle_init(&blob, iprec, oprec, transa, transb,
1267 m, n, k, lda, ldb, ldc, alpha, beta, LIBXSMM_GEMM_HANDLE_FLAG_AUTO, 1/*ntasks*/);
1268 const size_t scratch_size = libxsmm_gemm_handle_get_scratch_size(handle);
1269 void* scratch = NULL;
1270 if (NULL != handle && (0 == scratch_size ||
1271 NULL != (scratch = libxsmm_scratch_malloc(scratch_size, LIBXSMM_CACHELINE, LIBXSMM_MALLOC_INTERNAL_CALLER))))
1272 {
1273 libxsmm_gemm_thread(handle, scratch, a, b, c, 0/*tid*/, 1/*ntasks*/);
1274 libxsmm_free(scratch);
1275 }
1276 else { /* fall-back or error */
1277 static int error_once = 0;
1278 if (NULL == handle) { /* fall-back */
1279 if ((LIBXSMM_VERBOSITY_HIGH <= libxsmm_verbosity || 0 > libxsmm_verbosity) /* library code is expected to be mute */
1280 && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
1281 {
1282 fprintf(stderr, "LIBXSMM WARNING (XGEMM): fall-back code path triggered!\n");
1283 }
1284 }
1285 else if (0 != libxsmm_verbosity && /* library code is expected to be mute */
1286 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
1287 {
1288 fprintf(stderr, "LIBXSMM ERROR: failed to allocate GEMM-scratch memory!\n");
1289 }
1290 libxsmm_blas_xgemm(iprec, oprec, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
1291 }
1292 }
1293
1294
libxsmm_dgemm_batch(const char transa_array[],const char transb_array[],const libxsmm_blasint m_array[],const libxsmm_blasint n_array[],const libxsmm_blasint k_array[],const double alpha_array[],const double * a_array[],const libxsmm_blasint lda_array[],const double * b_array[],const libxsmm_blasint ldb_array[],const double beta_array[],double * c_array[],const libxsmm_blasint ldc_array[],const libxsmm_blasint * group_count,const libxsmm_blasint group_size[])1295 LIBXSMM_API void libxsmm_dgemm_batch(
1296 const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
1297 const double alpha_array[], const double* a_array[], const libxsmm_blasint lda_array[], const double* b_array[], const libxsmm_blasint ldb_array[],
1298 const double beta_array[], double* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[])
1299 {
1300 const libxsmm_blasint ngroups = LIBXSMM_ABS(*group_count), ptrsize = sizeof(void*);
1301 libxsmm_blasint i, j = 0;
1302 for (i = 0; i < ngroups; ++i) {
1303 const libxsmm_blasint size = group_size[i];
1304 libxsmm_gemm_batch(LIBXSMM_GEMM_PRECISION_F64, LIBXSMM_GEMM_PRECISION_F64, transa_array + i, transb_array + i,
1305 m_array[i], n_array[i], k_array[i], alpha_array + i, a_array + j, lda_array + i, b_array + j, ldb_array + i, beta_array + i, c_array + j, ldc_array + i,
1306 0/*index_base*/, 0/*index_stride*/, &ptrsize, &ptrsize, &ptrsize, size);
1307 j += LIBXSMM_ABS(size);
1308 }
1309 }
1310
1311
libxsmm_sgemm_batch(const char transa_array[],const char transb_array[],const libxsmm_blasint m_array[],const libxsmm_blasint n_array[],const libxsmm_blasint k_array[],const float alpha_array[],const float * a_array[],const libxsmm_blasint lda_array[],const float * b_array[],const libxsmm_blasint ldb_array[],const float beta_array[],float * c_array[],const libxsmm_blasint ldc_array[],const libxsmm_blasint * group_count,const libxsmm_blasint group_size[])1312 LIBXSMM_API void libxsmm_sgemm_batch(
1313 const char transa_array[], const char transb_array[], const libxsmm_blasint m_array[], const libxsmm_blasint n_array[], const libxsmm_blasint k_array[],
1314 const float alpha_array[], const float* a_array[], const libxsmm_blasint lda_array[], const float* b_array[], const libxsmm_blasint ldb_array[],
1315 const float beta_array[], float* c_array[], const libxsmm_blasint ldc_array[], const libxsmm_blasint* group_count, const libxsmm_blasint group_size[])
1316 {
1317 const libxsmm_blasint ngroups = LIBXSMM_ABS(*group_count), ptrsize = sizeof(void*);
1318 libxsmm_blasint i, j = 0;
1319 for (i = 0; i < ngroups; ++i) {
1320 const libxsmm_blasint size = group_size[i];
1321 libxsmm_gemm_batch(LIBXSMM_GEMM_PRECISION_F32, LIBXSMM_GEMM_PRECISION_F32, transa_array + i, transb_array + i,
1322 m_array[i], n_array[i], k_array[i], alpha_array + i, a_array + j, lda_array + i, b_array + j, ldb_array + i, beta_array + i, c_array + j, ldc_array + i,
1323 0/*index_base*/, 0/*index_stride*/, &ptrsize, &ptrsize, &ptrsize, size);
1324 j += LIBXSMM_ABS(size);
1325 }
1326 }
1327
1328
/** Double-precision GEMM with a dgemm-compatible signature (libxsmm_blasint indexes). */
LIBXSMM_API void libxsmm_dgemm(const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const double* alpha, const double* a, const libxsmm_blasint* lda,
  const double* b, const libxsmm_blasint* ldb,
  const double* beta, double* c, const libxsmm_blasint* ldc)
{
  /* delegate to the type-dispatching LIBXSMM_XGEMM macro (F64 input, F64 output) */
  LIBXSMM_XGEMM(double, double, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
1337
1338
/** Single-precision GEMM with an sgemm-compatible signature (libxsmm_blasint indexes). */
LIBXSMM_API void libxsmm_sgemm(const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const float* alpha, const float* a, const libxsmm_blasint* lda,
  const float* b, const libxsmm_blasint* ldb,
  const float* beta, float* c, const libxsmm_blasint* ldc)
{
  /* delegate to the type-dispatching LIBXSMM_XGEMM macro (F32 input, F32 output) */
  LIBXSMM_XGEMM(float, float, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
1347
1348
/** Low-precision GEMM: 16-bit integer (short) inputs with 32-bit integer output/accumulation. */
LIBXSMM_API void libxsmm_wigemm(const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const int* alpha, const short* a, const libxsmm_blasint* lda,
  const short* b, const libxsmm_blasint* ldb,
  const int* beta, int* c, const libxsmm_blasint* ldc)
{
  /* delegate to the type-dispatching LIBXSMM_XGEMM macro (I16 input, I32 output) */
  LIBXSMM_XGEMM(short, int, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
1357
1358
/** Mixed-precision GEMM: Bfloat16 inputs with single-precision (F32) output/accumulation. */
LIBXSMM_API void libxsmm_bsgemm(const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const float* alpha, const libxsmm_bfloat16* a, const libxsmm_blasint* lda,
  const libxsmm_bfloat16* b, const libxsmm_blasint* ldb,
  const float* beta, float* c, const libxsmm_blasint* ldc)
{
  /* delegate to the type-dispatching LIBXSMM_XGEMM macro (BF16 input, F32 output) */
  LIBXSMM_XGEMM(libxsmm_bfloat16, float, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
1367
1368
/** Executes this task's share of a batch of small matrix multiplications with a JIT'ted kernel.
 * The kernel is either a regular prefetch-capable SMM kernel (kernel.xmm), or a batch-reduce
 * kernel (kernel.xbm) when LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS is set in "flags".
 * Operands are located either via index arrays (0 != index_stride: stride_a/b/c hold element
 * indexes relative to index_base, accessed at positions i*index_stride), or via arrays of
 * pointers traversed with singular Byte-strides (*stride_a/b/c, 0 == index_stride).
 * The batch [0,|batchsize|) is evenly partitioned among "ntasks"; this call processes the
 * slice belonging to "tid". A negative batchsize disables the C-index synchronization.
 * Returns EXIT_SUCCESS, or EXIT_FAILURE if the batch-reduce scratch is too small.
 */
LIBXSMM_API int libxsmm_mmbatch_kernel(libxsmm_xmmfunction kernel, libxsmm_blasint index_base,
  libxsmm_blasint index_stride, const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[],
  const void* a, const void* b, void* c, libxsmm_blasint batchsize, /*unsigned*/int tid, /*unsigned*/int ntasks,
  unsigned char itypesize, unsigned char otypesize, int flags)
{
  int result = EXIT_SUCCESS;
  /* even work partition: task "tid" owns batch entries [begin,end) */
  const libxsmm_blasint size = LIBXSMM_ABS(batchsize);
  const libxsmm_blasint tasksize = LIBXSMM_UPDIV(size, ntasks);
  const libxsmm_blasint begin = tid * tasksize, span = begin + tasksize;
  const libxsmm_blasint end = LIBXSMM_MIN(span, size);

  LIBXSMM_ASSERT(NULL != kernel.xmm);
  if (begin < end) {
    const char *const a0 = (const char*)a, *const b0 = (const char*)b;
    char *const c0 = (char*)c;

    LIBXSMM_ASSERT(0 < itypesize && 0 < otypesize);
    if (0 == (LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS & flags)) {
      if (0 != index_stride) { /* stride arrays contain indexes */
        libxsmm_blasint i = begin * index_stride, ic = (NULL != stride_c ? (LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, i) - index_base) : 0);
        const char* ai = &a0[NULL != stride_a ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_a, i) - index_base) * itypesize) : 0];
        const char* bi = &b0[NULL != stride_b ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_b, i) - index_base) * itypesize) : 0];
        char* ci = &c0[ic * otypesize];
        /* when this task owns the batch tail, stop one entry early so the stride arrays
         * are never read past the last entry; the tail is issued below as "remainder" */
        const libxsmm_blasint end1 = (end != size ? end : (end - 1)) * index_stride;
#if (0 != LIBXSMM_SYNC)
        if (1 == ntasks || 0 == internal_gemm_nlocks || 0 > batchsize || 0 != (LIBXSMM_GEMM_FLAG_BETA_0 & flags))
#endif
        { /* no locking */
          if (NULL != stride_a && NULL != stride_b && NULL != stride_c) {
            const unsigned char ibits = (unsigned char)LIBXSMM_INTRINSICS_BITSCANBWD32(itypesize);
            const unsigned char obits = (unsigned char)LIBXSMM_INTRINSICS_BITSCANBWD32(otypesize);

            if (itypesize == (1 << ibits) && otypesize == (1 << obits)) { /* power-of-two type-sizes: use shifts */
              for (i += index_stride; i <= end1; i += index_stride) {
                const char *const an = &a0[(LIBXSMM_ACCESS(const libxsmm_blasint, stride_a, i) - index_base) << ibits];
                const char *const bn = &b0[(LIBXSMM_ACCESS(const libxsmm_blasint, stride_b, i) - index_base) << ibits];
                char *const cn = &c0[(LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, i) - index_base) << obits];
                kernel.xmm(ai, bi, ci, an, bn, cn); /* with prefetch */
                ai = an; bi = bn; ci = cn;
              }
            }
            else { /* non-pot type sizes */
              for (i += index_stride; i <= end1; i += index_stride) {
                const char *const an = &a0[(LIBXSMM_ACCESS(const libxsmm_blasint, stride_a, i) - index_base) * itypesize];
                const char *const bn = &b0[(LIBXSMM_ACCESS(const libxsmm_blasint, stride_b, i) - index_base) * itypesize];
                char *const cn = &c0[(LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, i) - index_base) * otypesize];
                kernel.xmm(ai, bi, ci, an, bn, cn); /* with prefetch */
                ai = an; bi = bn; ci = cn;
              }
            }
          }
          else { /* mixed specification of strides */
            for (i += index_stride; i <= end1; i += index_stride) {
              const char *const an = &a0[NULL != stride_a ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_a, i) - index_base) * itypesize) : 0];
              const char *const bn = &b0[NULL != stride_b ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_b, i) - index_base) * itypesize) : 0];
              char *const cn = &c0[NULL != stride_c ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, i) - index_base) * otypesize) : 0];
              kernel.xmm(ai, bi, ci, an, bn, cn); /* with prefetch */
              ai = an; bi = bn; ci = cn;
            }
          }
          if (end == size) { /* remainder multiplication */
            kernel.xmm(ai, bi, ci, ai, bi, ci); /* pseudo-prefetch */
          }
        }
#if (0 != LIBXSMM_SYNC)
        else { /* synchronize among C-indexes */
          LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK)* lock = &internal_gemm_lock[LIBXSMM_GEMM_LOCKIDX(ic, internal_gemm_nlocks)].state;
# if defined(LIBXSMM_GEMM_LOCKFWD)
          LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK)* lock0 = NULL;
# endif
          LIBXSMM_ASSERT(NULL != lock);
          if (NULL != stride_a && NULL != stride_b && NULL != stride_c) {
            for (i += index_stride; i <= end1; i += index_stride) {
              ic = LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, i) - index_base;
              {
                const char *const an = &a0[(LIBXSMM_ACCESS(const libxsmm_blasint, stride_a, i) - index_base) * itypesize];
                const char *const bn = &b0[(LIBXSMM_ACCESS(const libxsmm_blasint, stride_b, i) - index_base) * itypesize];
                char *const cn = &c0[ic * otypesize];
                LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK) *const lock1 = &internal_gemm_lock[LIBXSMM_GEMM_LOCKIDX(ic, internal_gemm_nlocks)].state;
# if defined(LIBXSMM_GEMM_LOCKFWD)
                /* lock-forwarding: keep holding the same lock across consecutive entries that hash to it */
                if (lock != lock0) { lock0 = lock; LIBXSMM_LOCK_ACQUIRE(LIBXSMM_GEMM_LOCK, lock); }
# else
                LIBXSMM_LOCK_ACQUIRE(LIBXSMM_GEMM_LOCK, lock);
# endif
                kernel.xmm(ai, bi, ci, an, bn, cn); /* with prefetch */
# if defined(LIBXSMM_GEMM_LOCKFWD)
                if (lock != lock1 || i == end1) { LIBXSMM_LOCK_RELEASE(LIBXSMM_GEMM_LOCK, lock); lock = lock1; }
# else
                LIBXSMM_LOCK_RELEASE(LIBXSMM_GEMM_LOCK, lock); lock = lock1;
# endif
                ai = an; bi = bn; ci = cn; /* next */
              }
            }
          }
          else { /* mixed specification of strides */
            for (i += index_stride; i <= end1; i += index_stride) {
              ic = (NULL != stride_c ? (LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, i) - index_base) : 0);
              {
                const char *const an = &a0[NULL != stride_a ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_a, i) - index_base) * itypesize) : 0];
                const char *const bn = &b0[NULL != stride_b ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_b, i) - index_base) * itypesize) : 0];
                char *const cn = &c0[ic * otypesize];
                LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK) *const lock1 = &internal_gemm_lock[LIBXSMM_GEMM_LOCKIDX(ic, internal_gemm_nlocks)].state;
# if defined(LIBXSMM_GEMM_LOCKFWD)
                if (lock != lock0) { lock0 = lock; LIBXSMM_LOCK_ACQUIRE(LIBXSMM_GEMM_LOCK, lock); }
# else
                LIBXSMM_LOCK_ACQUIRE(LIBXSMM_GEMM_LOCK, lock);
# endif
                kernel.xmm(ai, bi, ci, an, bn, cn); /* with prefetch */
# if defined(LIBXSMM_GEMM_LOCKFWD)
                if (lock != lock1 || i == end1) { LIBXSMM_LOCK_RELEASE(LIBXSMM_GEMM_LOCK, lock); lock = lock1; }
# else
                LIBXSMM_LOCK_RELEASE(LIBXSMM_GEMM_LOCK, lock); lock = lock1;
# endif
                ai = an; bi = bn; ci = cn; /* next */
              }
            }
          }
          if (end == size) { /* remainder multiplication */
            LIBXSMM_LOCK_ACQUIRE(LIBXSMM_GEMM_LOCK, lock);
            kernel.xmm(ai, bi, ci, ai, bi, ci); /* pseudo-prefetch */
            LIBXSMM_LOCK_RELEASE(LIBXSMM_GEMM_LOCK, lock);
          }
        }
#endif /*(0 != LIBXSMM_SYNC)*/
      }
      else { /* singular strides are measured in Bytes */
        const libxsmm_blasint da = (NULL != stride_a ? (*stride_a - index_base * sizeof(void*)) : 0);
        const libxsmm_blasint db = (NULL != stride_b ? (*stride_b - index_base * sizeof(void*)) : 0);
        const libxsmm_blasint dc = (NULL != stride_c ? (*stride_c - index_base * sizeof(void*)) : 0);
        libxsmm_blasint i;
        /* stop one entry early on the tail-owning task; tail issued as "remainder" below */
        const libxsmm_blasint end1 = (end != size ? end : (end - 1));
        const char *ai = a0 + (size_t)da * begin, *bi = b0 + (size_t)db * begin;
        char* ci = c0 + (size_t)dc * begin;
#if (0 != LIBXSMM_SYNC)
        if (1 == ntasks || 0 == internal_gemm_nlocks || 0 > batchsize || 0 != (LIBXSMM_GEMM_FLAG_BETA_0 & flags))
#endif
        { /* no locking */
          for (i = begin; i < end1; ++i) {
            const char *const an = ai + da, *const bn = bi + db;
            char *const cn = ci + dc;
#if defined(LIBXSMM_GEMM_CHECK)
            if (NULL != *((const void**)ai) && NULL != *((const void**)bi) && NULL != *((const void**)ci))
#endif
            {
              kernel.xmm( /* with prefetch */
                *((const void**)ai), *((const void**)bi), *((void**)ci),
                *((const void**)an), *((const void**)bn), *((const void**)cn));
            }
            ai = an; bi = bn; ci = cn; /* next */
          }
          if ( /* remainder multiplication */
#if defined(LIBXSMM_GEMM_CHECK)
            NULL != *((const void**)ai) && NULL != *((const void**)bi) && NULL != *((const void**)ci) &&
#endif
            end == size)
          {
            kernel.xmm( /* pseudo-prefetch */
              *((const void**)ai), *((const void**)bi), *((void**)ci),
              *((const void**)ai), *((const void**)bi), *((const void**)ci));
          }
        }
#if (0 != LIBXSMM_SYNC)
        else { /* synchronize among C-indexes */
          void* cc = *((void**)ci);
          LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK)* lock = &internal_gemm_lock[LIBXSMM_GEMM_LOCKPTR(cc, internal_gemm_nlocks)].state;
# if defined(LIBXSMM_GEMM_LOCKFWD)
          LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK)* lock0 = NULL;
# endif
          LIBXSMM_ASSERT(NULL != lock);
          for (i = begin + 1; i <= end1; ++i) {
            const char *const an = ai + da, *const bn = bi + db;
            char *const cn = ci + dc;
            void *const nc = *((void**)cn);
# if defined(LIBXSMM_GEMM_CHECK)
            if (NULL != *((const void**)ai) && NULL != *((const void**)bi) && NULL != cc)
# endif
            {
              LIBXSMM_LOCK_TYPE(LIBXSMM_GEMM_LOCK) *const lock1 = &internal_gemm_lock[LIBXSMM_GEMM_LOCKPTR(nc, internal_gemm_nlocks)].state;
# if defined(LIBXSMM_GEMM_LOCKFWD)
              if (lock != lock0) { lock0 = lock; LIBXSMM_LOCK_ACQUIRE(LIBXSMM_GEMM_LOCK, lock); }
# else
              LIBXSMM_LOCK_ACQUIRE(LIBXSMM_GEMM_LOCK, lock);
# endif
              kernel.xmm( /* with prefetch */
                *((const void**)ai), *((const void**)bi), cc,
                *((const void**)an), *((const void**)bn), *((const void**)cn));
# if defined(LIBXSMM_GEMM_LOCKFWD)
              if (lock != lock1 || i == end1) { LIBXSMM_LOCK_RELEASE(LIBXSMM_GEMM_LOCK, lock); lock = lock1; }
# else
              LIBXSMM_LOCK_RELEASE(LIBXSMM_GEMM_LOCK, lock); lock = lock1;
# endif
            }
            ai = an; bi = bn; ci = cn; cc = nc; /* next */
          }
          if ( /* remainder multiplication */
# if defined(LIBXSMM_GEMM_CHECK)
            NULL != *((const void**)ai) && NULL != *((const void**)bi) && NULL != cc &&
# endif
            end == size)
          {
            LIBXSMM_LOCK_ACQUIRE(LIBXSMM_GEMM_LOCK, lock);
            kernel.xmm( /* pseudo-prefetch */
              *((const void**)ai), *((const void**)bi), cc,
              *((const void**)ai), *((const void**)bi), cc);
            LIBXSMM_LOCK_RELEASE(LIBXSMM_GEMM_LOCK, lock);
          }
        }
#endif /*(0 != LIBXSMM_SYNC)*/
      }
    }
#if defined(LIBXSMM_GEMM_BATCHREDUCE)
    else /* LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS */
# if defined(LIBXSMM_GEMM_CHECK)
    if (
#   if (0 != LIBXSMM_SYNC)
      (1 == ntasks || 0 == internal_gemm_nlocks || 0 > batchsize) &&
#   endif
      (0 == (LIBXSMM_GEMM_FLAG_BETA_0 & flags)) &&
      (0 != internal_gemm_batchreduce))
# endif
    {
      /* number of void*-slots available in the global batch-array scratch */
      const unsigned int n = libxsmm_mmbatch_size * (LIBXSMM_GEMM_BATCHSCALE) / ((unsigned int)sizeof(void*));
      LIBXSMM_ASSERT(NULL != libxsmm_mmbatch_array && 0 != libxsmm_mmbatch_size);
      if ((2U/*A and B matrices*/ * tasksize) <= n) {
        const void **ai = (const void**)libxsmm_mmbatch_array + begin, **bi = ai + size;
        unsigned long long count;
        if (0 != index_stride) { /* stride arrays contain indexes */
          const size_t end_stride = (size_t)end * index_stride;
          size_t i = (size_t)begin * index_stride;
          char *ci = &c0[NULL != stride_c ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, i) - index_base) * otypesize) : 0], *cn = ci;
          do {
            /* collect consecutive entries sharing the same C-matrix, reduce them in one kernel call */
            for (count = 0; i < end_stride && ci == cn; ++count) {
              const size_t j = i + index_stride;
              *ai++ = &a0[NULL != stride_a ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_a, i) - index_base) * itypesize) : 0];
              *bi++ = &b0[NULL != stride_b ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_b, i) - index_base) * itypesize) : 0];
              cn = &c0[NULL != stride_c ? ((LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, j) - index_base) * otypesize) : 0];
              i = j;
            }
            ai = (const void**)libxsmm_mmbatch_array + begin; bi = ai + size;
            kernel.xbm(ai, bi, ci, &count);
            ci = cn;
          } while (i < end_stride);
        }
        else { /* singular strides are measured in Bytes */
          const libxsmm_blasint da = (NULL != stride_a ? (*stride_a - index_base * sizeof(void*)) : 0);
          const libxsmm_blasint db = (NULL != stride_b ? (*stride_b - index_base * sizeof(void*)) : 0);
          const libxsmm_blasint dc = (NULL != stride_c ? (*stride_c - index_base * sizeof(void*)) : 0);
          const char *ia = a0 + (size_t)da * begin, *ib = b0 + (size_t)db * begin;
          char* ic = c0 + (size_t)dc * begin;
          if (
# if defined(LIBXSMM_GEMM_CHECK)
            NULL != *((const void**)ia) && NULL != *((const void**)ib) && NULL != *((const void**)ic) &&
# endif
            sizeof(void*) == da && sizeof(void*) == db) /* fast path */
          {
            if (0 != dc) {
              libxsmm_blasint i = begin;
              char* jc = ic;
              /* NOTE(review): ia/ib are not advanced between chunks here, so chunks after the
               * first appear to re-copy the leading A/B pointers — verify against upstream. */
              do {
                for (count = 0; i < end && *((const void**)ic) == *((const void**)jc); ++i) {
# if defined(LIBXSMM_GEMM_CHECK)
                  if (NULL != *((const void**)jc))
# endif
                  ++count;
                  jc += dc; /* next */
                }
                memcpy((void*)ai, ia, count * sizeof(void*));
                memcpy((void*)bi, ib, count * sizeof(void*));
                kernel.xbm(ai, bi, *((void**)ic), &count);
                ic = jc;
              } while (i < end);
            }
            else { /* fastest path */
              count = (unsigned long long)end - begin;
              memcpy((void*)ai, ia, count * sizeof(void*));
              memcpy((void*)bi, ib, count * sizeof(void*));
              kernel.xbm(ai, bi, *((void**)ic), &count);
            }
          }
          else { /* custom-copy required */
            libxsmm_blasint i = begin;
            char* jc = ic;
            do {
              for (count = 0; i < end && *((const void**)ic) == *((const void**)jc); ++i) {
# if defined(LIBXSMM_GEMM_CHECK)
                if (NULL != *((const void**)ia) && NULL != *((const void**)ib) && NULL != *((const void**)jc))
# endif
                {
                  *ai++ = *((const void**)ia); *bi++ = *((const void**)ib);
                  ++count;
                }
                ia += da; ib += db; jc += dc; /* next */
              }
              ai = (const void**)libxsmm_mmbatch_array + begin; bi = ai + size;
              kernel.xbm(ai, bi, *((void**)ic), &count);
              ic = jc;
            } while (i < end);
          }
        }
      }
      else { /* fall-back */
        result = EXIT_FAILURE;
      }
    }
#endif /*defined(LIBXSMM_GEMM_BATCHREDUCE)*/
  }
  /* coverity[missing_unlock] */
  return result;
}
1678
1679
/** Adjusts a GEMM descriptor for batch execution: for beta=0 it may hint
 * non-temporal stores for C (when C looks suitably aligned), otherwise it
 * may select the batch-reduce (pointer-address) kernel family if supported.
 * "multithreaded" and "batchsize" guard against data races on shared C-matrices.
 */
LIBXSMM_API void libxsmm_gemm_internal_set_batchflag(libxsmm_gemm_descriptor* descriptor, void* c, libxsmm_blasint index_stride,
  libxsmm_blasint batchsize, int multithreaded)
{
  LIBXSMM_ASSERT(NULL != descriptor);
  if (0 != (LIBXSMM_GEMM_FLAG_BETA_0 & descriptor->flags)) {
    /* vector width in Bytes: 64 on AVX-512 targets, 32 otherwise */
    const uintptr_t vw = (LIBXSMM_X86_AVX512 <= libxsmm_target_archid ? 64 : 32);
    /* assume that all C-matrices are aligned eventually */
    if (0 == LIBXSMM_MOD2((uintptr_t)c, vw)
#if 0 /* should fall-back in BE */
      && LIBXSMM_X86_AVX <= libxsmm_target_archid
#endif
      && 0 != index_stride)
    {
      const int oprec = LIBXSMM_GETENUM_OUT(descriptor->datatype);
      const libxsmm_blasint typesize = LIBXSMM_TYPESIZE(oprec);
      const libxsmm_blasint csize = (libxsmm_blasint)descriptor->ldc * descriptor->n * typesize;
      /* finalize assumption if matrix-size is a multiple of the vector-width */
      descriptor->flags |= (unsigned short)(0 == LIBXSMM_MOD2(csize, vw) ? LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT : 0);
    }
  }
#if defined(LIBXSMM_GEMM_BATCHREDUCE)
  else if (0 != internal_gemm_batchreduce) { /* check if reduce-batch kernel can be used */
    static int error_once = 0;
    LIBXSMM_ASSERT(NULL != libxsmm_mmbatch_array);
# if (0 != LIBXSMM_SYNC)
    if (0 == multithreaded || 0 == internal_gemm_nlocks || 0 > batchsize)
# endif
    {
      int result = EXIT_FAILURE;
      /* batch-reduce is enabled only for homogeneous F64 and F32 input/output types */
      switch (LIBXSMM_GETENUM_INP(descriptor->datatype)) {
        case LIBXSMM_GEMM_PRECISION_F64: {
          if (LIBXSMM_GEMM_PRECISION_F64 == LIBXSMM_GETENUM_OUT(descriptor->datatype)) {
            result = EXIT_SUCCESS;
          }
        } break;
        case LIBXSMM_GEMM_PRECISION_F32: {
          if (LIBXSMM_GEMM_PRECISION_F32 == LIBXSMM_GETENUM_OUT(descriptor->datatype)) {
            result = EXIT_SUCCESS;
          }
        } break;
      }
      if (EXIT_SUCCESS == result) {
        descriptor->flags |= LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS;
        descriptor->prefetch = 0; /* omit decision */
      }
      else {
        if ((LIBXSMM_VERBOSITY_WARN <= libxsmm_verbosity || 0 > libxsmm_verbosity) && /* library code is expected to be mute */
          1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
        {
          fprintf(stderr, "LIBXSMM WARNING: data type not supported in batch-reduce!\n");
        }
      }
    }
# if (0 != LIBXSMM_SYNC)
    else if ((LIBXSMM_VERBOSITY_WARN <= libxsmm_verbosity || 0 > libxsmm_verbosity) && /* library code is expected to be mute */
      1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
    {
      fprintf(stderr, "LIBXSMM: potential data races prevent batch-reduce.\n");
    }
# endif
  }
#endif /*defined(LIBXSMM_GEMM_BATCHREDUCE)*/
#if !defined(LIBXSMM_GEMM_BATCHREDUCE) || (0 == LIBXSMM_SYNC)
  LIBXSMM_UNUSED(batchsize); LIBXSMM_UNUSED(multithreaded);
#endif
}
1746
1747
libxsmm_dmmbatch_blas(const char * transa,const char * transb,libxsmm_blasint m,libxsmm_blasint n,libxsmm_blasint k,const double * alpha,const void * a,const libxsmm_blasint * lda,const void * b,const libxsmm_blasint * ldb,const double * beta,void * c,const libxsmm_blasint * ldc,libxsmm_blasint index_base,libxsmm_blasint index_stride,const libxsmm_blasint stride_a[],const libxsmm_blasint stride_b[],const libxsmm_blasint stride_c[],libxsmm_blasint batchsize)1748 LIBXSMM_API_INTERN void libxsmm_dmmbatch_blas(const char* transa, const char* transb, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k,
1749 const double* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb, const double* beta, void* c, const libxsmm_blasint* ldc,
1750 libxsmm_blasint index_base, libxsmm_blasint index_stride, const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[],
1751 libxsmm_blasint batchsize)
1752 {
1753 const libxsmm_blasint end = LIBXSMM_ABS(batchsize);
1754 libxsmm_blasint i;
1755
1756 if (0 != index_stride) { /* stride arrays contain indexes */
1757 const libxsmm_blasint da = (NULL != stride_a ? (*stride_a - index_base) : 0);
1758 const libxsmm_blasint db = (NULL != stride_b ? (*stride_b - index_base) : 0);
1759 const libxsmm_blasint dc = (NULL != stride_c ? (*stride_c - index_base) : 0);
1760 const libxsmm_blasint end1 = end * index_stride;
1761 const double *const a0 = (const double*)a, *const b0 = (const double*)b, *ai = a0 + da, *bi = b0 + db;
1762 double *const c0 = (double*)c, *ci = c0 + dc;
1763 for (i = index_stride; i <= end1; i += index_stride) {
1764 const double *const an = &a0[NULL != stride_a ? (LIBXSMM_ACCESS(const libxsmm_blasint, stride_a, i) - index_base) : 0];
1765 const double *const bn = &b0[NULL != stride_b ? (LIBXSMM_ACCESS(const libxsmm_blasint, stride_b, i) - index_base) : 0];
1766 double *const cn = &c0[NULL != stride_c ? (LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, i) - index_base) : 0];
1767 #if defined(LIBXSMM_GEMM_CHECK)
1768 if (NULL != ai && NULL != bi && NULL != ci)
1769 #endif
1770 {
1771 libxsmm_blas_dgemm(transa, transb, &m, &n, &k, alpha, ai, lda, bi, ldb, beta, ci, ldc);
1772 }
1773 ai = an; bi = bn; ci = cn;
1774 }
1775 }
1776 else { /* singular strides are measured in Bytes */
1777 const libxsmm_blasint da = (NULL != stride_a ? (*stride_a - index_base * sizeof(void*)) : 0);
1778 const libxsmm_blasint db = (NULL != stride_b ? (*stride_b - index_base * sizeof(void*)) : 0);
1779 const libxsmm_blasint dc = (NULL != stride_c ? (*stride_c - index_base * sizeof(void*)) : 0);
1780 const char *const a0 = (const char*)a, *const b0 = (const char*)b, *ai = a0, *bi = b0;
1781 char *const c0 = (char*)c, *ci = c0;
1782 for (i = 0; i < end; ++i) {
1783 const char *const an = ai + da, *const bn = bi + db;
1784 char *const cn = ci + dc;
1785 #if defined(LIBXSMM_GEMM_CHECK)
1786 if (NULL != *((const double**)ai) && NULL != *((const double**)bi) && NULL != *((const double**)ci))
1787 #endif
1788 {
1789 libxsmm_blas_dgemm(transa, transb, &m, &n, &k, alpha, *((const double**)ai), lda, *((const double**)bi), ldb, beta, *((double**)ci), ldc);
1790 }
1791 ai = an; bi = bn; ci = cn; /* next */
1792 }
1793 }
1794 }
1795
1796
libxsmm_smmbatch_blas(const char * transa,const char * transb,libxsmm_blasint m,libxsmm_blasint n,libxsmm_blasint k,const float * alpha,const void * a,const libxsmm_blasint * lda,const void * b,const libxsmm_blasint * ldb,const float * beta,void * c,const libxsmm_blasint * ldc,libxsmm_blasint index_base,libxsmm_blasint index_stride,const libxsmm_blasint stride_a[],const libxsmm_blasint stride_b[],const libxsmm_blasint stride_c[],libxsmm_blasint batchsize)1797 LIBXSMM_API_INTERN void libxsmm_smmbatch_blas(const char* transa, const char* transb, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k,
1798 const float* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb, const float* beta, void* c, const libxsmm_blasint* ldc,
1799 libxsmm_blasint index_base, libxsmm_blasint index_stride, const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[],
1800 libxsmm_blasint batchsize)
1801 {
1802 const libxsmm_blasint end = LIBXSMM_ABS(batchsize);
1803 libxsmm_blasint i;
1804
1805 if (0 != index_stride) { /* stride arrays contain indexes */
1806 const libxsmm_blasint da = (NULL != stride_a ? (*stride_a - index_base) : 0);
1807 const libxsmm_blasint db = (NULL != stride_b ? (*stride_b - index_base) : 0);
1808 const libxsmm_blasint dc = (NULL != stride_c ? (*stride_c - index_base) : 0);
1809 const libxsmm_blasint end1 = end * index_stride;
1810 const float *a0 = (const float*)a, *b0 = (const float*)b, *ai = a0 + da, *bi = b0 + db;
1811 float *c0 = (float*)c, *ci = c0 + dc;
1812 for (i = index_stride; i <= end1; i += index_stride) {
1813 const float *const an = &a0[NULL != stride_a ? (LIBXSMM_ACCESS(const libxsmm_blasint, stride_a, i) - index_base) : 0];
1814 const float *const bn = &b0[NULL != stride_b ? (LIBXSMM_ACCESS(const libxsmm_blasint, stride_b, i) - index_base) : 0];
1815 float *const cn = &c0[NULL != stride_c ? (LIBXSMM_ACCESS(const libxsmm_blasint, stride_c, i) - index_base) : 0];
1816 #if defined(LIBXSMM_GEMM_CHECK)
1817 if (NULL != ai && NULL != bi && NULL != ci)
1818 #endif
1819 {
1820 libxsmm_blas_sgemm(transa, transb, &m, &n, &k, alpha, ai, lda, bi, ldb, beta, ci, ldc);
1821 }
1822 ai = an; bi = bn; ci = cn;
1823 }
1824 }
1825 else { /* singular strides are measured in Bytes */
1826 const libxsmm_blasint da = (NULL != stride_a ? (*stride_a - index_base * sizeof(void*)) : 0);
1827 const libxsmm_blasint db = (NULL != stride_b ? (*stride_b - index_base * sizeof(void*)) : 0);
1828 const libxsmm_blasint dc = (NULL != stride_c ? (*stride_c - index_base * sizeof(void*)) : 0);
1829 const char *a0 = (const char*)a, *b0 = (const char*)b, *ai = a0, *bi = b0;
1830 char *c0 = (char*)c, *ci = c0;
1831 for (i = 0; i < end; ++i) {
1832 const char *const an = ai + da;
1833 const char *const bn = bi + db;
1834 char *const cn = ci + dc;
1835 #if defined(LIBXSMM_GEMM_CHECK)
1836 if (NULL != *((const float**)ai) && NULL != *((const float**)bi) && NULL != *((const float**)ci))
1837 #endif
1838 {
1839 libxsmm_blas_sgemm(transa, transb, &m, &n, &k, alpha, *((const float**)ai), lda, *((const float**)bi), ldb, beta, *((float**)ci), ldc);
1840 }
1841 ai = an; bi = bn; ci = cn; /* next */
1842 }
1843 }
1844 }
1845
1846
libxsmm_mmbatch_blas(libxsmm_gemm_precision iprec,libxsmm_gemm_precision oprec,const char * transa,const char * transb,libxsmm_blasint m,libxsmm_blasint n,libxsmm_blasint k,const void * alpha,const void * a,const libxsmm_blasint * lda,const void * b,const libxsmm_blasint * ldb,const void * beta,void * c,const libxsmm_blasint * ldc,libxsmm_blasint index_base,libxsmm_blasint index_stride,const libxsmm_blasint stride_a[],const libxsmm_blasint stride_b[],const libxsmm_blasint stride_c[],libxsmm_blasint batchsize)1847 LIBXSMM_API int libxsmm_mmbatch_blas(
1848 libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, const char* transa, const char* transb, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k,
1849 const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb, const void* beta, void* c, const libxsmm_blasint* ldc,
1850 libxsmm_blasint index_base, libxsmm_blasint index_stride, const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[],
1851 libxsmm_blasint batchsize)
1852 {
1853 int result;
1854 if (NULL != a && NULL != b && NULL != c) {
1855 switch (LIBXSMM_GETENUM(iprec, oprec)) {
1856 case LIBXSMM_GEMM_PRECISION_F64: {
1857 libxsmm_dmmbatch_blas(transa, transb, m, n, k,
1858 (const double*)alpha, a, lda, b, ldb, (const double*)beta, c, ldc,
1859 index_base, index_stride, stride_a, stride_b, stride_c, batchsize);
1860 result = EXIT_SUCCESS;
1861 } break;
1862 case LIBXSMM_GEMM_PRECISION_F32: {
1863 libxsmm_smmbatch_blas(transa, transb, m, n, k,
1864 (const float*)alpha, a, lda, b, ldb, (const float*)beta, c, ldc,
1865 index_base, index_stride, stride_a, stride_b, stride_c, batchsize);
1866 result = EXIT_SUCCESS;
1867 } break;
1868 default: result = EXIT_FAILURE;
1869 }
1870 }
1871 else {
1872 result = EXIT_FAILURE;
1873 }
1874 return result;
1875 }
1876
1877
libxsmm_mmbatch(libxsmm_gemm_precision iprec,libxsmm_gemm_precision oprec,const char * transa,const char * transb,libxsmm_blasint m,libxsmm_blasint n,libxsmm_blasint k,const void * alpha,const void * a,const libxsmm_blasint * lda,const void * b,const libxsmm_blasint * ldb,const void * beta,void * c,const libxsmm_blasint * ldc,libxsmm_blasint index_base,libxsmm_blasint index_stride,const libxsmm_blasint stride_a[],const libxsmm_blasint stride_b[],const libxsmm_blasint stride_c[],libxsmm_blasint batchsize,int tid,int nthreads)1878 LIBXSMM_API void libxsmm_mmbatch(libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec,
1879 const char* transa, const char* transb, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k,
1880 const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb,
1881 const void* beta, void* c, const libxsmm_blasint* ldc, libxsmm_blasint index_base, libxsmm_blasint index_stride,
1882 const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[],
1883 libxsmm_blasint batchsize, /*unsigned*/int tid, /*unsigned*/int nthreads)
1884 {
1885 static int error_once = 0;
1886 #if defined(LIBXSMM_GEMM_CHECK)
1887 if (NULL != a && NULL != b && NULL != c && 0 <= tid && tid < nthreads)
1888 #endif
1889 {
1890 const unsigned char otypesize = libxsmm_typesize((libxsmm_datatype)oprec);
1891 int result = EXIT_FAILURE;
1892 LIBXSMM_INIT
1893 if (LIBXSMM_SMM_AI(m, n, k, 2/*RFO*/, otypesize)) { /* check if an SMM is suitable */
1894 const int gemm_flags = LIBXSMM_GEMM_PFLAGS(transa, transb, LIBXSMM_FLAGS);
1895 libxsmm_descriptor_blob blob;
1896 libxsmm_gemm_descriptor *const desc = libxsmm_gemm_descriptor_init2(&blob, iprec, oprec, m, n, k,
1897 NULL != lda ? *lda : (0 == (LIBXSMM_GEMM_FLAG_TRANS_A & gemm_flags) ? m : k),
1898 NULL != ldb ? *ldb : (0 == (LIBXSMM_GEMM_FLAG_TRANS_B & gemm_flags) ? k : n),
1899 NULL != ldc ? *ldc : m, alpha, beta, gemm_flags, libxsmm_get_gemm_prefetch(LIBXSMM_PREFETCH_AUTO));
1900 if (NULL != desc) {
1901 libxsmm_xmmfunction kernel;
1902 libxsmm_gemm_internal_set_batchflag(desc, c, index_stride, batchsize, 0/*multi-threaded*/);
1903 kernel = libxsmm_xmmdispatch(desc);
1904 if (NULL != kernel.xmm) {
1905 result = libxsmm_mmbatch_kernel(kernel, index_base, index_stride,
1906 stride_a, stride_b, stride_c, a, b, c, batchsize, tid, nthreads,
1907 libxsmm_typesize((libxsmm_datatype)iprec), otypesize, desc->flags);
1908 }
1909 }
1910 }
1911 if (EXIT_SUCCESS != result) { /* quiet fall-back */
1912 if (EXIT_SUCCESS == libxsmm_mmbatch_blas(iprec, oprec,
1913 transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
1914 index_base, index_stride, stride_a, stride_b, stride_c, batchsize))
1915 {
1916 if (LIBXSMM_VERBOSITY_WARN <= libxsmm_verbosity || 0 > libxsmm_verbosity) {
1917 const size_t threshold = LIBXSMM_MNK_SIZE(m, n, m);
1918 static size_t threshold_max = 0;
1919 if (threshold_max < threshold) {
1920 LIBXSMM_STDIO_ACQUIRE();
1921 fprintf(stderr, "LIBXSMM WARNING: ");
1922 libxsmm_gemm_print2(stderr, iprec, oprec, transa, transb, &m, &n, &k,
1923 alpha, NULL/*a*/, lda, NULL/*b*/, ldb, beta, NULL/*c*/, ldc);
1924 fprintf(stderr, " => batched GEMM was falling back to BLAS!\n");
1925 LIBXSMM_STDIO_RELEASE();
1926 threshold_max = threshold;
1927 }
1928 }
1929 }
1930 else if (0 != libxsmm_verbosity /* library code is expected to be mute */
1931 && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
1932 {
1933 fprintf(stderr, "LIBXSMM ERROR: libxsmm_mmbatch failed!\n");
1934 }
1935 }
1936 }
1937 #if defined(LIBXSMM_GEMM_CHECK)
1938 else if (0 != libxsmm_verbosity /* library code is expected to be mute */
1939 && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
1940 {
1941 fprintf(stderr, "LIBXSMM ERROR: incorrect arguments (libxsmm_mmbatch)!\n");
1942 }
1943 #endif
1944 }
1945
1946
libxsmm_gemm_batch(libxsmm_gemm_precision iprec,libxsmm_gemm_precision oprec,const char * transa,const char * transb,libxsmm_blasint m,libxsmm_blasint n,libxsmm_blasint k,const void * alpha,const void * a,const libxsmm_blasint * lda,const void * b,const libxsmm_blasint * ldb,const void * beta,void * c,const libxsmm_blasint * ldc,libxsmm_blasint index_base,libxsmm_blasint index_stride,const libxsmm_blasint stride_a[],const libxsmm_blasint stride_b[],const libxsmm_blasint stride_c[],libxsmm_blasint batchsize)1947 LIBXSMM_API void libxsmm_gemm_batch(libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec,
1948 const char* transa, const char* transb, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k,
1949 const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb,
1950 const void* beta, void* c, const libxsmm_blasint* ldc, libxsmm_blasint index_base, libxsmm_blasint index_stride,
1951 const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[],
1952 libxsmm_blasint batchsize)
1953 {
1954 libxsmm_mmbatch(iprec, oprec, transa, transb, m, n, k,
1955 alpha,a, lda, b, ldb, beta, c, ldc, index_base, index_stride,
1956 stride_a, stride_b, stride_c, batchsize, 0/*tid*/, 1/*nthreads*/);
1957 }
1958
1959
1960 #if defined(LIBXSMM_BUILD) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__))
1961
1962 /* implementation provided for Fortran 77 compatibility */
/* Fortran-callable symbol: all scalars arrive by reference per F77 convention. */
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_dgemm)(const char*, const char*,
  const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const double*, const double*, const libxsmm_blasint*,
  const double*, const libxsmm_blasint*,
  const double*, double*, const libxsmm_blasint*);
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_dgemm)(const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const double* alpha, const double* a, const libxsmm_blasint* lda,
  const double* b, const libxsmm_blasint* ldb,
  const double* beta, double* c, const libxsmm_blasint* ldc)
{
  /* forward unchanged to the C-interface double-precision GEMM */
  libxsmm_dgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
1976
1977 /* implementation provided for Fortran 77 compatibility */
/* Fortran-callable symbol: all scalars arrive by reference per F77 convention. */
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_sgemm)(const char*, const char*,
  const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const float*, const float*, const libxsmm_blasint*,
  const float*, const libxsmm_blasint*,
  const float*, float*, const libxsmm_blasint*);
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_sgemm)(const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const float* alpha, const float* a, const libxsmm_blasint* lda,
  const float* b, const libxsmm_blasint* ldb,
  const float* beta, float* c, const libxsmm_blasint* ldc)
{
  /* forward unchanged to the C-interface single-precision GEMM */
  libxsmm_sgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
1991
1992
1993 /* implementation provided for Fortran 77 compatibility */
/* Fortran-callable symbol for the int16-input/int32-output GEMM (short A/B, int C). */
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_wigemm)(const char*, const char*,
  const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const int*, const short*, const libxsmm_blasint*,
  const short*, const libxsmm_blasint*,
  const int*, int*, const libxsmm_blasint*);
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_wigemm)(const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const int* alpha, const short* a, const libxsmm_blasint* lda,
  const short* b, const libxsmm_blasint* ldb,
  const int* beta, int* c, const libxsmm_blasint* ldc)
{
  /* forward unchanged to the C-interface mixed-precision integer GEMM */
  libxsmm_wigemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
2007
2008
2009 /* implementation provided for Fortran 77 compatibility */
/* Fortran-callable symbol for the bfloat16-input/float-output GEMM. */
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_bsgemm)(const char*, const char*,
  const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const float*, const libxsmm_bfloat16*, const libxsmm_blasint*,
  const libxsmm_bfloat16*, const libxsmm_blasint*,
  const float*, float*, const libxsmm_blasint*);
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_bsgemm)(const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const float* alpha, const libxsmm_bfloat16* a, const libxsmm_blasint* lda,
  const libxsmm_bfloat16* b, const libxsmm_blasint* ldb,
  const float* beta, float* c, const libxsmm_blasint* ldc)
{
  /* forward unchanged to the C-interface bf16 GEMM */
  libxsmm_bsgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
2023
2024
2025 /* implementation provided for Fortran 77 compatibility */
/* Fortran-callable symbol for the precision-dispatched BLAS GEMM.
 * NOTE(review): the binding types the buffers as float*, but the effective
 * element type presumably follows iprec/oprec — confirm against the Fortran
 * interface module. iprec/oprec must be non-NULL (asserted below). */
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_blas_xgemm)(const libxsmm_gemm_precision*, const libxsmm_gemm_precision*,
  const char*, const char*, const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const float*, const float*, const libxsmm_blasint*,
  const float*, const libxsmm_blasint*,
  const float*, float*, const libxsmm_blasint*);
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_blas_xgemm)(const libxsmm_gemm_precision* iprec, const libxsmm_gemm_precision* oprec,
  const char* transa, const char* transb, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const float* alpha, const float* a, const libxsmm_blasint* lda,
  const float* b, const libxsmm_blasint* ldb,
  const float* beta, float* c, const libxsmm_blasint* ldc)
{
  LIBXSMM_ASSERT(NULL != iprec && NULL != oprec);
  /* dereference the precision enums and forward to the C interface */
  libxsmm_blas_xgemm(*iprec, *oprec, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
2040
2041
2042 /* implementation provided for Fortran 77 compatibility */
/* Fortran-callable symbol forwarding to the BLAS-backed double-precision GEMM. */
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_blas_dgemm)(const char*, const char*,
  const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const double*, const double*, const libxsmm_blasint*,
  const double*, const libxsmm_blasint*,
  const double*, double*, const libxsmm_blasint*);
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_blas_dgemm)(const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const double* alpha, const double* a, const libxsmm_blasint* lda,
  const double* b, const libxsmm_blasint* ldb,
  const double* beta, double* c, const libxsmm_blasint* ldc)
{
  /* forward unchanged to the C interface */
  libxsmm_blas_dgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
2056
2057
2058 /* implementation provided for Fortran 77 compatibility */
/* Fortran-callable symbol forwarding to the BLAS-backed single-precision GEMM. */
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_blas_sgemm)(const char*, const char*,
  const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const float*, const float*, const libxsmm_blasint*,
  const float*, const libxsmm_blasint*,
  const float*, float*, const libxsmm_blasint*);
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_blas_sgemm)(const char* transa, const char* transb,
  const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const float* alpha, const float* a, const libxsmm_blasint* lda,
  const float* b, const libxsmm_blasint* ldb,
  const float* beta, float* c, const libxsmm_blasint* ldc)
{
  /* forward unchanged to the C interface */
  libxsmm_blas_sgemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
2072
2073
2074 /* implementation provided for Fortran 77 compatibility */
/* Fortran-callable symbol for the threaded batch interface.
 * Scalar arguments (precisions, dimensions, bases, batch size, tid/nthreads)
 * arrive by reference and are asserted non-NULL before being dereferenced. */
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_mmbatch)(const libxsmm_gemm_precision*, const libxsmm_gemm_precision*,
  const char*, const char*, const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const void*, const void*, const libxsmm_blasint*, const void*, const libxsmm_blasint*,
  const void*, void*, const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const libxsmm_blasint[], const libxsmm_blasint[], const libxsmm_blasint[],
  const libxsmm_blasint*, const /*unsigned*/int*, const /*unsigned*/int*);
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_mmbatch)(const libxsmm_gemm_precision* iprec, const libxsmm_gemm_precision* oprec,
  const char* transa, const char* transb, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb,
  const void* beta, void* c, const libxsmm_blasint* ldc, const libxsmm_blasint* index_base, const libxsmm_blasint* index_stride,
  const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[],
  const libxsmm_blasint* batchsize, const /*unsigned*/int* tid, const /*unsigned*/int* nthreads)
{
  LIBXSMM_ASSERT(NULL != iprec && NULL != oprec && NULL != m && NULL != n && NULL != k);
  LIBXSMM_ASSERT(NULL != index_base && NULL != index_stride && NULL != batchsize);
  LIBXSMM_ASSERT(NULL != tid && NULL != nthreads);
  /* dereference the by-reference scalars and forward to the C interface */
  libxsmm_mmbatch(*iprec, *oprec, transa, transb, *m, *n, *k, alpha, a, lda, b, ldb, beta, c, ldc,
    *index_base, *index_stride, stride_a, stride_b, stride_c, *batchsize, *tid, *nthreads);
}
2094
2095
2096 /* implementation provided for Fortran 77 compatibility */
/* Fortran-callable symbol for the sequential batch interface.
 * Scalar arguments arrive by reference and are asserted non-NULL before use. */
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_gemm_batch)(const libxsmm_gemm_precision*, const libxsmm_gemm_precision*,
  const char*, const char*, const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const void*, const void*, const libxsmm_blasint*, const void*, const libxsmm_blasint*,
  const void*, void*, const libxsmm_blasint*, const libxsmm_blasint*, const libxsmm_blasint*,
  const libxsmm_blasint[], const libxsmm_blasint[], const libxsmm_blasint[],
  const libxsmm_blasint*);
LIBXSMM_API void LIBXSMM_FSYMBOL(libxsmm_gemm_batch)(const libxsmm_gemm_precision* iprec, const libxsmm_gemm_precision* oprec,
  const char* transa, const char* transb, const libxsmm_blasint* m, const libxsmm_blasint* n, const libxsmm_blasint* k,
  const void* alpha, const void* a, const libxsmm_blasint* lda, const void* b, const libxsmm_blasint* ldb,
  const void* beta, void* c, const libxsmm_blasint* ldc, const libxsmm_blasint* index_base, const libxsmm_blasint* index_stride,
  const libxsmm_blasint stride_a[], const libxsmm_blasint stride_b[], const libxsmm_blasint stride_c[],
  const libxsmm_blasint* batchsize)
{
  LIBXSMM_ASSERT(NULL != iprec && NULL != oprec && NULL != m && NULL != n && NULL != k);
  LIBXSMM_ASSERT(NULL != index_base && NULL != index_stride && NULL != batchsize);
  /* dereference the by-reference scalars and forward to the C interface */
  libxsmm_gemm_batch(*iprec, *oprec, transa, transb, *m, *n, *k, alpha, a, lda, b, ldb, beta, c, ldc,
    *index_base, *index_stride, stride_a, stride_b, stride_c, *batchsize);
}
2115
2116 #endif /*defined(LIBXSMM_BUILD) && (!defined(LIBXSMM_NOFORTRAN) || defined(__clang_analyzer__))*/
2117
2118