1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved. *
3 * This file is part of the LIBXSMM library. *
4 * *
5 * For information on the license, see the LICENSE file. *
6 * Further information: https://github.com/hfp/libxsmm/ *
7 * SPDX-License-Identifier: BSD-3-Clause *
8 ******************************************************************************/
9 /* Kunal Banerjee (Intel Corp.), Dheevatsa Mudigere (Intel Corp.)
10 Alexander Heinecke (Intel Corp.), Hans Pabst (Intel Corp.)
11 ******************************************************************************/
12 #include "libxsmm_blocked_gemm_types.h"
13 #include <libxsmm.h>
14
15
libxsmm_blocked_gemm_handle_create(int nthreads,libxsmm_gemm_precision iprec,libxsmm_gemm_precision oprec,libxsmm_blasint m,libxsmm_blasint n,libxsmm_blasint k,const libxsmm_blasint * bm,const libxsmm_blasint * bn,const libxsmm_blasint * bk,const libxsmm_blasint * b_m1,const libxsmm_blasint * b_n1,const libxsmm_blasint * b_k1,const libxsmm_blasint * b_k2,const void * alpha,const void * beta,const int * gemm_flags,const libxsmm_gemm_prefetch_type * prefetch,const libxsmm_blocked_gemm_order * order)16 LIBXSMM_API libxsmm_blocked_gemm_handle* libxsmm_blocked_gemm_handle_create(/*unsigned*/ int nthreads,
17 libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k,
18 const libxsmm_blasint* bm, const libxsmm_blasint* bn, const libxsmm_blasint* bk,
19 const libxsmm_blasint* b_m1, const libxsmm_blasint* b_n1, const libxsmm_blasint* b_k1, const libxsmm_blasint* b_k2,
20 const void* alpha, const void* beta, const int* gemm_flags,
21 const libxsmm_gemm_prefetch_type* prefetch,
22 const libxsmm_blocked_gemm_order* order)
23 {
24 const char *const env_m = getenv("LIBXSMM_BLOCKED_GEMM_M"), *const env_n = getenv("LIBXSMM_BLOCKED_GEMM_N"), *const env_k = getenv("LIBXSMM_BLOCKED_GEMM_K");
25 const libxsmm_blasint mm = LIBXSMM_MIN(0 == bm ? ((NULL == env_m || 0 == *env_m) ? 32 : atoi(env_m)) : *bm, m);
26 const libxsmm_blasint kk = LIBXSMM_MIN(0 == bk ? ((NULL == env_k || 0 == *env_k) ? mm : atoi(env_k)) : *bk, k);
27 const libxsmm_blasint nn = LIBXSMM_MIN(0 == bn ? ((NULL == env_n || 0 == *env_n) ? kk : atoi(env_n)) : *bn, n);
28 libxsmm_blocked_gemm_handle* result = 0;
29 static int error_once = 0;
30
31 if (0 < m && 0 < n && 0 < k && 0 < mm && 0 < nn && 0 < kk && 0 < nthreads) {
32 libxsmm_blocked_gemm_handle handle;
33 memset(&handle, 0, sizeof(handle));
34 if (0 == (m % mm) && 0 == (n % nn) && 0 == (k % kk) &&
35 0 == (m % *b_m1) && 0 == (n % *b_n1) && 0 == (k % *b_k1) &&
36 0 == ((k / *b_k1 / *b_k2) % kk) && 0 == ((n / *b_n1) % nn) && 0 == ((m / *b_m1) % mm))
37 { /* check for valid block-size */
38 libxsmm_gemm_descriptor* desc;
39 libxsmm_descriptor_blob blob;
40 if (0 == prefetch) { /* auto-prefetch */
41 /* TODO: more sophisticated strategy perhaps according to CPUID */
42 const libxsmm_gemm_prefetch_type prefetch_default = LIBXSMM_GEMM_PREFETCH_AL2BL2_VIA_C;
43 const char *const env_p = getenv("LIBXSMM_BLOCKED_GEMM_PREFETCH");
44 desc = libxsmm_gemm_descriptor_init2(&blob, iprec, oprec, mm, nn, kk, mm/*lda*/, kk/*ldb*/, mm/*ldc*/,
45 alpha, beta, 0 == gemm_flags ? LIBXSMM_GEMM_FLAG_NONE : *gemm_flags,
46 (NULL == env_p || 0 == *env_p) ? prefetch_default : libxsmm_gemm_uid2prefetch(atoi(env_p)));
47 }
48 else { /* user-defined */
49 desc = libxsmm_gemm_descriptor_init2(&blob, iprec, oprec, mm, nn, kk, mm/*lda*/, kk/*ldb*/, mm/*ldc*/,
50 alpha, beta, 0 == gemm_flags ? LIBXSMM_GEMM_FLAG_NONE : *gemm_flags, *prefetch);
51 }
52 if (0 != desc) {
53 handle.mb = m / mm; handle.nb = n / nn; handle.kb = k / kk;
54 if (LIBXSMM_GEMM_PREFETCH_NONE != desc->prefetch) {
55 handle.kernel_pf = libxsmm_xmmdispatch(desc);
56 desc->prefetch = LIBXSMM_GEMM_PREFETCH_NONE;
57 handle.kernel = libxsmm_xmmdispatch(desc);
58 }
59 else { /* no prefetch */
60 handle.kernel = libxsmm_xmmdispatch(desc);
61 handle.kernel_pf.xmm = 0;
62 }
63 }
64 if (0 != handle.kernel.xmm) {
65 const size_t tls_size = LIBXSMM_UP2((size_t)mm * nn * LIBXSMM_TYPESIZE(oprec), LIBXSMM_CACHELINE) * nthreads;
66 const size_t size_locks = (size_t)handle.mb * (size_t)handle.nb * sizeof(libxsmm_blocked_gemm_lock);
67 handle.locks = (libxsmm_blocked_gemm_lock*)libxsmm_aligned_malloc(size_locks, LIBXSMM_CACHELINE);
68 handle.buffer = libxsmm_aligned_malloc(tls_size, LIBXSMM_CACHELINE);
69 result = (libxsmm_blocked_gemm_handle*)malloc(sizeof(libxsmm_blocked_gemm_handle));
70
71 if (224 <= nthreads
72 #if !defined(__MIC__)
73 && LIBXSMM_X86_AVX512_MIC <= libxsmm_target_archid
74 && LIBXSMM_X86_AVX512_CORE > libxsmm_target_archid
75 #endif
76 )
77 {
78 handle.barrier = libxsmm_barrier_create(nthreads / 4, 4);
79 }
80 else {
81 handle.barrier = libxsmm_barrier_create(nthreads / 2, 2);
82 }
83 if (0 != result && 0 != handle.barrier && 0 != handle.buffer && 0 != handle.locks) {
84 handle.m = m; handle.n = n; handle.k = k; handle.bm = mm; handle.bn = nn; handle.bk = kk;
85 handle.b_m1 = *b_m1; handle.b_n1 = *b_n1; handle.b_k1 = *b_k1; handle.b_k2 = *b_k2;
86 handle.iprec = iprec; handle.oprec = oprec;
87 memset(handle.locks, 0, size_locks);
88 handle.order = (0 == order ? LIBXSMM_BLOCKED_GEMM_ORDER_JIK : *order);
89 handle.nthreads = nthreads;
90 *result = handle;
91 }
92 else {
93 if (0 != libxsmm_verbosity /* library code is expected to be mute */
94 && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
95 {
96 fprintf(stderr, "LIBXSMM ERROR: BGEMM handle allocation failed!\n");
97 }
98 libxsmm_barrier_release(handle.barrier);
99 libxsmm_free(handle.buffer);
100 libxsmm_free(handle.locks);
101 free(result);
102 result = 0;
103 }
104 }
105 else if (0 != libxsmm_verbosity /* library code is expected to be mute */
106 && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
107 {
108 fprintf(stderr, "LIBXSMM ERROR: unsupported BGEMM kernel requested!\n");
109 }
110 }
111 else if (0 != libxsmm_verbosity /* library code is expected to be mute */
112 && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
113 {
114 fprintf(stderr, "LIBXSMM ERROR: BGEMM block-size is invalid!\n");
115 }
116 }
117 else if (0 != libxsmm_verbosity /* library code is expected to be mute */
118 && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
119 {
120 fprintf(stderr, "LIBXSMM ERROR: invalid arguments for libxsmm_blocked_gemm_handle_create!\n");
121 }
122
123 return result;
124 }
125
126
libxsmm_blocked_gemm_handle_destroy(const libxsmm_blocked_gemm_handle * handle)127 LIBXSMM_API void libxsmm_blocked_gemm_handle_destroy(const libxsmm_blocked_gemm_handle* handle)
128 {
129 if (0 != handle) {
130 libxsmm_barrier_release(handle->barrier);
131 libxsmm_free(handle->buffer);
132 libxsmm_free(handle->locks);
133 free((libxsmm_blocked_gemm_handle*)handle);
134 }
135 }
136
137
libxsmm_blocked_gemm_copyin_a(const libxsmm_blocked_gemm_handle * handle,const void * src,const libxsmm_blasint * ld,void * dst)138 LIBXSMM_API int libxsmm_blocked_gemm_copyin_a(const libxsmm_blocked_gemm_handle* handle, const void* src, const libxsmm_blasint* ld, void* dst)
139 {
140 int result = EXIT_SUCCESS;
141 static int error_once = 0;
142
143 if (0 != handle) {
144 #if 0 /* TODO: support leading dimension for the source buffer */
145 const libxsmm_blasint ild = (0 == ld ? handle->m : *ld);
146 assert(ild >= handle->m);
147 #else
148 LIBXSMM_UNUSED(ld);
149 #endif
150 switch (handle->iprec) {
151 case LIBXSMM_GEMM_PRECISION_F64: {
152 # define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE double
153 # include "template/libxsmm_blocked_gemm_copyin_a.tpl.c"
154 # undef LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
155 } break;
156 case LIBXSMM_GEMM_PRECISION_F32: {
157 # define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE float
158 # include "template/libxsmm_blocked_gemm_copyin_a.tpl.c"
159 # undef LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
160 } break;
161 case LIBXSMM_GEMM_PRECISION_I16: {
162 # define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE short
163 # include "template/libxsmm_blocked_gemm_copyin_a.tpl.c"
164 # undef LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
165 } break;
166 default: {
167 if (0 != libxsmm_verbosity /* library code is expected to be mute */
168 && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
169 {
170 fprintf(stderr, "LIBXSMM ERROR: BGEMM precision of matrix A is not supported!\n");
171 }
172 result = EXIT_FAILURE;
173 }
174 }
175 }
176 else {
177 if (0 != libxsmm_verbosity /* library code is expected to be mute */
178 && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
179 {
180 fprintf(stderr, "LIBXSMM ERROR: BGEMM-handle cannot be NULL!\n");
181 }
182 result = EXIT_FAILURE;
183 }
184 return result;
185 }
186
187
libxsmm_blocked_gemm_copyin_b(const libxsmm_blocked_gemm_handle * handle,const void * src,const libxsmm_blasint * ld,void * dst)188 LIBXSMM_API int libxsmm_blocked_gemm_copyin_b(const libxsmm_blocked_gemm_handle* handle, const void* src, const libxsmm_blasint* ld, void* dst)
189 {
190 int result = EXIT_SUCCESS;
191 static int error_once = 0;
192
193 if (0 != handle) {
194 #if 0 /* TODO: support leading dimension for the source buffer */
195 const libxsmm_blasint ild = (0 == ld ? handle->k : *ld);
196 assert(ild >= handle->k);
197 #else
198 LIBXSMM_UNUSED(ld);
199 #endif
200 switch (handle->iprec) {
201 case LIBXSMM_GEMM_PRECISION_F64: {
202 # define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE double
203 # include "template/libxsmm_blocked_gemm_copyin_b.tpl.c"
204 # undef LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
205 } break;
206 case LIBXSMM_GEMM_PRECISION_F32: {
207 # define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE float
208 # include "template/libxsmm_blocked_gemm_copyin_b.tpl.c"
209 # undef LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
210 } break;
211 case LIBXSMM_GEMM_PRECISION_I16: {
212 # define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE short
213 # include "template/libxsmm_blocked_gemm_copyin_b.tpl.c"
214 # undef LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
215 } break;
216 default: {
217 if (0 != libxsmm_verbosity /* library code is expected to be mute */
218 && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
219 {
220 fprintf(stderr, "LIBXSMM ERROR: BGEMM precision of matrix B is not supported!\n");
221 }
222 result = EXIT_FAILURE;
223 }
224 }
225 }
226 else {
227 if (0 != libxsmm_verbosity /* library code is expected to be mute */
228 && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
229 {
230 fprintf(stderr, "LIBXSMM ERROR: BGEMM-handle cannot be NULL!\n");
231 }
232 result = EXIT_FAILURE;
233 }
234 return result;
235 }
236
237
libxsmm_blocked_gemm_copyin_c(const libxsmm_blocked_gemm_handle * handle,const void * src,const libxsmm_blasint * ld,void * dst)238 LIBXSMM_API int libxsmm_blocked_gemm_copyin_c(const libxsmm_blocked_gemm_handle* handle, const void* src, const libxsmm_blasint* ld, void* dst)
239 {
240 int result = EXIT_SUCCESS;
241 static int error_once = 0;
242
243 if (0 != handle) {
244 #if 0 /* TODO: support leading dimension for the source buffer */
245 const libxsmm_blasint ild = (0 == ld ? handle->m : *ld);
246 assert(ild >= handle->m);
247 #else
248 LIBXSMM_UNUSED(ld);
249 #endif
250 switch (handle->oprec) {
251 case LIBXSMM_GEMM_PRECISION_F64: {
252 # define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE double
253 # include "template/libxsmm_blocked_gemm_copyin_c.tpl.c"
254 # undef LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
255 } break;
256 case LIBXSMM_GEMM_PRECISION_F32: {
257 # define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE float
258 # include "template/libxsmm_blocked_gemm_copyin_c.tpl.c"
259 # undef LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
260 } break;
261 case LIBXSMM_GEMM_PRECISION_I16: {
262 # define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE int
263 # include "template/libxsmm_blocked_gemm_copyin_c.tpl.c"
264 # undef LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
265 } break;
266 default: {
267 if (0 != libxsmm_verbosity /* library code is expected to be mute */
268 && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
269 {
270 fprintf(stderr, "LIBXSMM ERROR: BGEMM precision of matrix A is not supported!\n");
271 }
272 result = EXIT_FAILURE;
273 }
274 }
275 }
276 else {
277 if (0 != libxsmm_verbosity /* library code is expected to be mute */
278 && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
279 {
280 fprintf(stderr, "LIBXSMM ERROR: BGEMM-handle cannot be NULL!\n");
281 }
282 result = EXIT_FAILURE;
283 }
284 return result;
285 }
286
287
libxsmm_blocked_gemm_copyout_c(const libxsmm_blocked_gemm_handle * handle,const void * src,const libxsmm_blasint * ld,void * dst)288 LIBXSMM_API int libxsmm_blocked_gemm_copyout_c(const libxsmm_blocked_gemm_handle* handle, const void* src, const libxsmm_blasint* ld, void* dst)
289 {
290 int result = EXIT_SUCCESS;
291 static int error_once = 0;
292
293 if (0 != handle) {
294 #if 0 /* TODO: support leading dimension for the source buffer */
295 const libxsmm_blasint ild = (0 == ld ? handle->m : *ld);
296 assert(ild >= handle->m);
297 #else
298 LIBXSMM_UNUSED(ld);
299 #endif
300 switch (handle->oprec) {
301 case LIBXSMM_GEMM_PRECISION_F64: {
302 # define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE double
303 # include "template/libxsmm_blocked_gemm_copyout_c.tpl.c"
304 # undef LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
305 } break;
306 case LIBXSMM_GEMM_PRECISION_F32: {
307 # define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE float
308 # include "template/libxsmm_blocked_gemm_copyout_c.tpl.c"
309 # undef LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
310 } break;
311 case LIBXSMM_GEMM_PRECISION_I16: {
312 # define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE int
313 # include "template/libxsmm_blocked_gemm_copyout_c.tpl.c"
314 # undef LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
315 } break;
316 default: {
317 if (0 != libxsmm_verbosity /* library code is expected to be mute */
318 && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
319 {
320 fprintf(stderr, "LIBXSMM ERROR: BGEMM precision of matrix A is not supported!\n");
321 }
322 result = EXIT_FAILURE;
323 }
324 }
325 }
326 else {
327 if (0 != libxsmm_verbosity /* library code is expected to be mute */
328 && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
329 {
330 fprintf(stderr, "LIBXSMM ERROR: BGEMM-handle cannot be NULL!\n");
331 }
332 result = EXIT_FAILURE;
333 }
334 return result;
335 }
336
337
libxsmm_blocked_gemm_convert_b_to_a(const libxsmm_blocked_gemm_handle * handle,const void * src,const libxsmm_blasint * ld,void * dst)338 LIBXSMM_API int libxsmm_blocked_gemm_convert_b_to_a(const libxsmm_blocked_gemm_handle* handle, const void* src, const libxsmm_blasint* ld, void* dst)
339 {
340 int result = EXIT_SUCCESS;
341 static int error_once = 0;
342
343 if (0 != handle) {
344 #if 0 /* TODO: support leading dimension for the source buffer */
345 const libxsmm_blasint ild = (0 == ld ? handle->k : *ld);
346 assert(ild >= handle->k);
347 #else
348 LIBXSMM_UNUSED(ld);
349 #endif
350 switch (handle->iprec) {
351 case LIBXSMM_GEMM_PRECISION_F64: {
352 # define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE double
353 # include "template/libxsmm_blocked_gemm_convert_b_to_a.tpl.c"
354 # undef LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
355 } break;
356 case LIBXSMM_GEMM_PRECISION_F32: {
357 # define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE float
358 # include "template/libxsmm_blocked_gemm_convert_b_to_a.tpl.c"
359 # undef LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
360 } break;
361 case LIBXSMM_GEMM_PRECISION_I16: {
362 # define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE short
363 # include "template/libxsmm_blocked_gemm_convert_b_to_a.tpl.c"
364 # undef LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
365 } break;
366 default: {
367 if (0 != libxsmm_verbosity /* library code is expected to be mute */
368 && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
369 {
370 fprintf(stderr, "LIBXSMM ERROR: BGEMM precision of matrix B is not supported!\n");
371 }
372 result = EXIT_FAILURE;
373 }
374 }
375 }
376 else {
377 if (0 != libxsmm_verbosity /* library code is expected to be mute */
378 && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
379 {
380 fprintf(stderr, "LIBXSMM ERROR: BGEMM-handle cannot be NULL!\n");
381 }
382 result = EXIT_FAILURE;
383 }
384 return result;
385 }
386
387
libxsmm_blocked_gemm_transpose_b(const libxsmm_blocked_gemm_handle * handle,const void * src,const libxsmm_blasint * ld,void * dst)388 LIBXSMM_API int libxsmm_blocked_gemm_transpose_b(const libxsmm_blocked_gemm_handle* handle, const void* src, const libxsmm_blasint* ld, void* dst)
389 {
390 int result = EXIT_SUCCESS;
391 static int error_once = 0;
392
393 if (0 != handle) {
394 #if 0 /* TODO: support leading dimension for the source buffer */
395 const libxsmm_blasint ild = (0 == ld ? handle->k : *ld);
396 assert(ild >= handle->k);
397 #else
398 LIBXSMM_UNUSED(ld);
399 #endif
400 switch (handle->iprec) {
401 case LIBXSMM_GEMM_PRECISION_F64: {
402 # define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE double
403 # include "template/libxsmm_blocked_gemm_transpose_b.tpl.c"
404 # undef LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
405 } break;
406 case LIBXSMM_GEMM_PRECISION_F32: {
407 # define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE float
408 # include "template/libxsmm_blocked_gemm_transpose_b.tpl.c"
409 # undef LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
410 } break;
411 case LIBXSMM_GEMM_PRECISION_I16: {
412 # define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE short
413 # include "template/libxsmm_blocked_gemm_transpose_b.tpl.c"
414 # undef LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
415 } break;
416 default: {
417 if (0 != libxsmm_verbosity /* library code is expected to be mute */
418 && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
419 {
420 fprintf(stderr, "LIBXSMM ERROR: BGEMM precision of matrix B is not supported!\n");
421 }
422 result = EXIT_FAILURE;
423 }
424 }
425 }
426 else {
427 if (0 != libxsmm_verbosity /* library code is expected to be mute */
428 && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
429 {
430 fprintf(stderr, "LIBXSMM ERROR: BGEMM-handle cannot be NULL!\n");
431 }
432 result = EXIT_FAILURE;
433 }
434 return result;
435 }
436
437
internal_bgemm_order(libxsmm_blocked_gemm_order order,libxsmm_blasint w_i,libxsmm_blasint nw_i,libxsmm_blasint nw_j,libxsmm_blasint nw_k,libxsmm_blasint * i2,libxsmm_blasint * j2,libxsmm_blasint * k2)438 LIBXSMM_API_INLINE void internal_bgemm_order(libxsmm_blocked_gemm_order order,
439 libxsmm_blasint w_i, libxsmm_blasint nw_i, libxsmm_blasint nw_j, libxsmm_blasint nw_k,
440 libxsmm_blasint* i2, libxsmm_blasint* j2, libxsmm_blasint* k2)
441 {
442 switch (order) {
443 case LIBXSMM_BLOCKED_GEMM_ORDER_JIK: {
444 *j2 = (w_i / (nw_i * nw_k));
445 *i2 = (w_i - (*j2) * (nw_i * nw_k)) / nw_k;
446 *k2 = (w_i % nw_k);
447 } break;
448 case LIBXSMM_BLOCKED_GEMM_ORDER_IJK: {
449 *i2 = (w_i / (nw_j * nw_k));
450 *j2 = (w_i - (*i2) * (nw_j * nw_k)) / nw_k;
451 *k2 = (w_i % nw_k);
452 } break;
453 case LIBXSMM_BLOCKED_GEMM_ORDER_JKI: {
454 *j2 = (w_i / (nw_k * nw_i));
455 *k2 = (w_i - (*j2) * (nw_k * nw_i)) / nw_i;
456 *i2 = (w_i % nw_i);
457 } break;
458 case LIBXSMM_BLOCKED_GEMM_ORDER_IKJ: {
459 *i2 = (w_i / (nw_k * nw_j));
460 *k2 = (w_i - (*i2) * (nw_k * nw_j)) / nw_j;
461 *j2 = (w_i % nw_j);
462 } break;
463 case LIBXSMM_BLOCKED_GEMM_ORDER_KJI: {
464 *k2 = (w_i / (nw_j * nw_i));
465 *j2 = (w_i - (*k2) * (nw_j * nw_i)) / nw_i;
466 *i2 = (w_i % nw_i);
467 } break;
468 case LIBXSMM_BLOCKED_GEMM_ORDER_KIJ: {
469 *k2 = (w_i / (nw_i * nw_j));
470 *i2 = (w_i - (*k2) * (nw_i * nw_j)) / nw_j;
471 *j2 = (w_i % nw_j);
472 } break;
473 default: assert(0/*should never happen*/);
474 }
475 }
476
libxsmm_blocked_gemm_st(const libxsmm_blocked_gemm_handle * handle,const void * a,const void * b,void * c,int start_thread,int tid)477 LIBXSMM_API void libxsmm_blocked_gemm_st(const libxsmm_blocked_gemm_handle* handle, const void* a, const void* b, void* c,
478 /*unsigned*/int start_thread, /*unsigned*/int tid)
479 {
480 static int error_once = 0;
481 #if defined(LIBXSMM_BLOCKED_GEMM_CHECKS)
482 if (0 != handle && 0 != a && 0 != b && 0 != c && start_thread <= tid && 0 <= tid)
483 #endif
484 {
485 const int ltid = tid - start_thread;
486 if (handle->nthreads > 1) {
487 libxsmm_barrier_init(handle->barrier, ltid);
488 }
489 switch (handle->iprec) {
490 case LIBXSMM_GEMM_PRECISION_F64: {
491 # define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE_AB double
492 # define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE_C double
493 # include "template/libxsmm_blocked_gemm.tpl.c"
494 # undef LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE_AB
495 # undef LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE_C
496 } break;
497 case LIBXSMM_GEMM_PRECISION_F32: {
498 # define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE_AB float
499 # define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE_C float
500 # include "template/libxsmm_blocked_gemm.tpl.c"
501 # undef LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE_AB
502 # undef LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE_C
503 } break;
504 case LIBXSMM_GEMM_PRECISION_I16: {
505 # define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE_AB short
506 # define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE_C int
507 # include "template/libxsmm_blocked_gemm.tpl.c"
508 # undef LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE_C
509 # undef LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE_AB
510 } break;
511 default: if (0 != libxsmm_verbosity /* library code is expected to be mute */
512 && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
513 {
514 fprintf(stderr, "LIBXSMM ERROR: BGEMM precision is not supported!\n");
515 }
516 }
517 if (handle->nthreads > 1) {
518 libxsmm_barrier_wait(handle->barrier, ltid);
519 }
520 }
521 #if defined(LIBXSMM_BLOCKED_GEMM_CHECKS)
522 else if (0 != libxsmm_verbosity /* library code is expected to be mute */
523 && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
524 {
525 fprintf(stderr, "LIBXSMM ERROR: invalid arguments for libxsmm_blocked_gemm!\n");
526 }
527 #endif
528 }
529
530