1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Kunal Banerjee (Intel Corp.), Dheevatsa Mudigere (Intel Corp.)
10    Alexander Heinecke (Intel Corp.), Hans Pabst (Intel Corp.)
11 ******************************************************************************/
12 #include "libxsmm_blocked_gemm_types.h"
13 #include <libxsmm.h>
14 
15 
libxsmm_blocked_gemm_handle_create(int nthreads,libxsmm_gemm_precision iprec,libxsmm_gemm_precision oprec,libxsmm_blasint m,libxsmm_blasint n,libxsmm_blasint k,const libxsmm_blasint * bm,const libxsmm_blasint * bn,const libxsmm_blasint * bk,const libxsmm_blasint * b_m1,const libxsmm_blasint * b_n1,const libxsmm_blasint * b_k1,const libxsmm_blasint * b_k2,const void * alpha,const void * beta,const int * gemm_flags,const libxsmm_gemm_prefetch_type * prefetch,const libxsmm_blocked_gemm_order * order)16 LIBXSMM_API libxsmm_blocked_gemm_handle* libxsmm_blocked_gemm_handle_create(/*unsigned*/ int nthreads,
17   libxsmm_gemm_precision iprec, libxsmm_gemm_precision oprec, libxsmm_blasint m, libxsmm_blasint n, libxsmm_blasint k,
18   const libxsmm_blasint* bm, const libxsmm_blasint* bn, const libxsmm_blasint* bk,
19   const libxsmm_blasint* b_m1, const libxsmm_blasint* b_n1, const libxsmm_blasint* b_k1, const libxsmm_blasint* b_k2,
20   const void* alpha, const void* beta, const int* gemm_flags,
21   const libxsmm_gemm_prefetch_type* prefetch,
22   const libxsmm_blocked_gemm_order* order)
23 {
24   const char *const env_m = getenv("LIBXSMM_BLOCKED_GEMM_M"), *const env_n = getenv("LIBXSMM_BLOCKED_GEMM_N"), *const env_k = getenv("LIBXSMM_BLOCKED_GEMM_K");
25   const libxsmm_blasint mm = LIBXSMM_MIN(0 == bm ? ((NULL == env_m || 0 == *env_m) ? 32 : atoi(env_m)) : *bm, m);
26   const libxsmm_blasint kk = LIBXSMM_MIN(0 == bk ? ((NULL == env_k || 0 == *env_k) ? mm : atoi(env_k)) : *bk, k);
27   const libxsmm_blasint nn = LIBXSMM_MIN(0 == bn ? ((NULL == env_n || 0 == *env_n) ? kk : atoi(env_n)) : *bn, n);
28   libxsmm_blocked_gemm_handle* result = 0;
29   static int error_once = 0;
30 
31   if (0 < m && 0 < n && 0 < k && 0 < mm && 0 < nn && 0 < kk && 0 < nthreads) {
32     libxsmm_blocked_gemm_handle handle;
33     memset(&handle, 0, sizeof(handle));
34     if (0 == (m % mm) && 0 == (n % nn) && 0 == (k % kk) &&
35         0 == (m % *b_m1) && 0 == (n % *b_n1) && 0 == (k % *b_k1) &&
36         0 == ((k / *b_k1 / *b_k2) % kk) && 0 == ((n / *b_n1) % nn) && 0 == ((m / *b_m1) % mm))
37     { /* check for valid block-size */
38       libxsmm_gemm_descriptor* desc;
39       libxsmm_descriptor_blob blob;
40       if (0 == prefetch) { /* auto-prefetch */
41         /* TODO: more sophisticated strategy perhaps according to CPUID */
42         const libxsmm_gemm_prefetch_type prefetch_default = LIBXSMM_GEMM_PREFETCH_AL2BL2_VIA_C;
43         const char *const env_p = getenv("LIBXSMM_BLOCKED_GEMM_PREFETCH");
44         desc = libxsmm_gemm_descriptor_init2(&blob, iprec, oprec, mm, nn, kk, mm/*lda*/, kk/*ldb*/, mm/*ldc*/,
45           alpha, beta, 0 == gemm_flags ? LIBXSMM_GEMM_FLAG_NONE : *gemm_flags,
46           (NULL == env_p || 0 == *env_p) ? prefetch_default : libxsmm_gemm_uid2prefetch(atoi(env_p)));
47       }
48       else { /* user-defined */
49         desc = libxsmm_gemm_descriptor_init2(&blob, iprec, oprec, mm, nn, kk, mm/*lda*/, kk/*ldb*/, mm/*ldc*/,
50           alpha, beta, 0 == gemm_flags ? LIBXSMM_GEMM_FLAG_NONE : *gemm_flags, *prefetch);
51       }
52       if (0 != desc) {
53         handle.mb = m / mm; handle.nb = n / nn; handle.kb = k / kk;
54         if (LIBXSMM_GEMM_PREFETCH_NONE != desc->prefetch) {
55           handle.kernel_pf = libxsmm_xmmdispatch(desc);
56           desc->prefetch = LIBXSMM_GEMM_PREFETCH_NONE;
57           handle.kernel = libxsmm_xmmdispatch(desc);
58         }
59         else { /* no prefetch */
60           handle.kernel = libxsmm_xmmdispatch(desc);
61           handle.kernel_pf.xmm = 0;
62         }
63       }
64       if (0 != handle.kernel.xmm) {
65         const size_t tls_size = LIBXSMM_UP2((size_t)mm * nn * LIBXSMM_TYPESIZE(oprec), LIBXSMM_CACHELINE) * nthreads;
66         const size_t size_locks = (size_t)handle.mb * (size_t)handle.nb * sizeof(libxsmm_blocked_gemm_lock);
67         handle.locks = (libxsmm_blocked_gemm_lock*)libxsmm_aligned_malloc(size_locks, LIBXSMM_CACHELINE);
68         handle.buffer = libxsmm_aligned_malloc(tls_size, LIBXSMM_CACHELINE);
69         result = (libxsmm_blocked_gemm_handle*)malloc(sizeof(libxsmm_blocked_gemm_handle));
70 
71         if (224 <= nthreads
72 #if !defined(__MIC__)
73           && LIBXSMM_X86_AVX512_MIC <= libxsmm_target_archid
74           && LIBXSMM_X86_AVX512_CORE > libxsmm_target_archid
75 #endif
76           )
77         {
78           handle.barrier = libxsmm_barrier_create(nthreads / 4, 4);
79         }
80         else {
81           handle.barrier = libxsmm_barrier_create(nthreads / 2, 2);
82         }
83         if (0 != result && 0 != handle.barrier && 0 != handle.buffer && 0 != handle.locks) {
84           handle.m = m; handle.n = n; handle.k = k; handle.bm = mm; handle.bn = nn; handle.bk = kk;
85           handle.b_m1 = *b_m1; handle.b_n1 = *b_n1; handle.b_k1 = *b_k1; handle.b_k2 = *b_k2;
86           handle.iprec = iprec; handle.oprec = oprec;
87           memset(handle.locks, 0, size_locks);
88           handle.order = (0 == order ? LIBXSMM_BLOCKED_GEMM_ORDER_JIK : *order);
89           handle.nthreads = nthreads;
90           *result = handle;
91         }
92         else {
93           if (0 != libxsmm_verbosity /* library code is expected to be mute */
94             && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
95           {
96             fprintf(stderr, "LIBXSMM ERROR: BGEMM handle allocation failed!\n");
97           }
98           libxsmm_barrier_release(handle.barrier);
99           libxsmm_free(handle.buffer);
100           libxsmm_free(handle.locks);
101           free(result);
102           result = 0;
103         }
104       }
105       else if (0 != libxsmm_verbosity /* library code is expected to be mute */
106         && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
107       {
108         fprintf(stderr, "LIBXSMM ERROR: unsupported BGEMM kernel requested!\n");
109       }
110     }
111     else if (0 != libxsmm_verbosity /* library code is expected to be mute */
112       && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
113     {
114       fprintf(stderr, "LIBXSMM ERROR: BGEMM block-size is invalid!\n");
115     }
116   }
117   else if (0 != libxsmm_verbosity /* library code is expected to be mute */
118     && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
119   {
120     fprintf(stderr, "LIBXSMM ERROR: invalid arguments for libxsmm_blocked_gemm_handle_create!\n");
121   }
122 
123   return result;
124 }
125 
126 
libxsmm_blocked_gemm_handle_destroy(const libxsmm_blocked_gemm_handle * handle)127 LIBXSMM_API void libxsmm_blocked_gemm_handle_destroy(const libxsmm_blocked_gemm_handle* handle)
128 {
129   if (0 != handle) {
130     libxsmm_barrier_release(handle->barrier);
131     libxsmm_free(handle->buffer);
132     libxsmm_free(handle->locks);
133     free((libxsmm_blocked_gemm_handle*)handle);
134   }
135 }
136 
137 
libxsmm_blocked_gemm_copyin_a(const libxsmm_blocked_gemm_handle * handle,const void * src,const libxsmm_blasint * ld,void * dst)138 LIBXSMM_API int libxsmm_blocked_gemm_copyin_a(const libxsmm_blocked_gemm_handle* handle, const void* src, const libxsmm_blasint* ld, void* dst)
139 {
140   int result = EXIT_SUCCESS;
141   static int error_once = 0;
142 
143   if (0 != handle) {
144 #if 0 /* TODO: support leading dimension for the source buffer */
145     const libxsmm_blasint ild = (0 == ld ? handle->m : *ld);
146     assert(ild >= handle->m);
147 #else
148     LIBXSMM_UNUSED(ld);
149 #endif
150     switch (handle->iprec) {
151       case LIBXSMM_GEMM_PRECISION_F64: {
152 #       define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE double
153 #       include "template/libxsmm_blocked_gemm_copyin_a.tpl.c"
154 #       undef  LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
155       } break;
156       case LIBXSMM_GEMM_PRECISION_F32: {
157 #       define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE float
158 #       include "template/libxsmm_blocked_gemm_copyin_a.tpl.c"
159 #       undef  LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
160       } break;
161       case LIBXSMM_GEMM_PRECISION_I16: {
162 #       define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE short
163 #       include "template/libxsmm_blocked_gemm_copyin_a.tpl.c"
164 #       undef  LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
165       } break;
166       default: {
167         if (0 != libxsmm_verbosity /* library code is expected to be mute */
168           && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
169         {
170           fprintf(stderr, "LIBXSMM ERROR: BGEMM precision of matrix A is not supported!\n");
171         }
172         result = EXIT_FAILURE;
173       }
174     }
175   }
176   else {
177     if (0 != libxsmm_verbosity /* library code is expected to be mute */
178       && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
179     {
180       fprintf(stderr, "LIBXSMM ERROR: BGEMM-handle cannot be NULL!\n");
181     }
182     result = EXIT_FAILURE;
183   }
184   return result;
185 }
186 
187 
libxsmm_blocked_gemm_copyin_b(const libxsmm_blocked_gemm_handle * handle,const void * src,const libxsmm_blasint * ld,void * dst)188 LIBXSMM_API int libxsmm_blocked_gemm_copyin_b(const libxsmm_blocked_gemm_handle* handle, const void* src, const libxsmm_blasint* ld, void* dst)
189 {
190   int result = EXIT_SUCCESS;
191   static int error_once = 0;
192 
193   if (0 != handle) {
194 #if 0 /* TODO: support leading dimension for the source buffer */
195     const libxsmm_blasint ild = (0 == ld ? handle->k : *ld);
196     assert(ild >= handle->k);
197 #else
198     LIBXSMM_UNUSED(ld);
199 #endif
200     switch (handle->iprec) {
201       case LIBXSMM_GEMM_PRECISION_F64: {
202 #       define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE double
203 #       include "template/libxsmm_blocked_gemm_copyin_b.tpl.c"
204 #       undef  LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
205       } break;
206       case LIBXSMM_GEMM_PRECISION_F32: {
207 #       define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE float
208 #       include "template/libxsmm_blocked_gemm_copyin_b.tpl.c"
209 #       undef  LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
210       } break;
211       case LIBXSMM_GEMM_PRECISION_I16: {
212 #       define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE short
213 #       include "template/libxsmm_blocked_gemm_copyin_b.tpl.c"
214 #       undef  LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
215       } break;
216       default: {
217         if (0 != libxsmm_verbosity /* library code is expected to be mute */
218           && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
219         {
220           fprintf(stderr, "LIBXSMM ERROR: BGEMM precision of matrix B is not supported!\n");
221         }
222         result = EXIT_FAILURE;
223       }
224     }
225   }
226   else {
227     if (0 != libxsmm_verbosity /* library code is expected to be mute */
228       && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
229     {
230       fprintf(stderr, "LIBXSMM ERROR: BGEMM-handle cannot be NULL!\n");
231     }
232     result = EXIT_FAILURE;
233   }
234   return result;
235 }
236 
237 
libxsmm_blocked_gemm_copyin_c(const libxsmm_blocked_gemm_handle * handle,const void * src,const libxsmm_blasint * ld,void * dst)238 LIBXSMM_API int libxsmm_blocked_gemm_copyin_c(const libxsmm_blocked_gemm_handle* handle, const void* src, const libxsmm_blasint* ld, void* dst)
239 {
240   int result = EXIT_SUCCESS;
241   static int error_once = 0;
242 
243   if (0 != handle) {
244 #if 0 /* TODO: support leading dimension for the source buffer */
245     const libxsmm_blasint ild = (0 == ld ? handle->m : *ld);
246     assert(ild >= handle->m);
247 #else
248     LIBXSMM_UNUSED(ld);
249 #endif
250     switch (handle->oprec) {
251       case LIBXSMM_GEMM_PRECISION_F64: {
252 #       define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE double
253 #       include "template/libxsmm_blocked_gemm_copyin_c.tpl.c"
254 #       undef  LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
255       } break;
256       case LIBXSMM_GEMM_PRECISION_F32: {
257 #       define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE float
258 #       include "template/libxsmm_blocked_gemm_copyin_c.tpl.c"
259 #       undef  LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
260       } break;
261       case LIBXSMM_GEMM_PRECISION_I16: {
262 #       define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE int
263 #       include "template/libxsmm_blocked_gemm_copyin_c.tpl.c"
264 #       undef  LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
265       } break;
266       default: {
267         if (0 != libxsmm_verbosity /* library code is expected to be mute */
268           && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
269         {
270           fprintf(stderr, "LIBXSMM ERROR: BGEMM precision of matrix A is not supported!\n");
271         }
272         result = EXIT_FAILURE;
273       }
274     }
275   }
276   else {
277     if (0 != libxsmm_verbosity /* library code is expected to be mute */
278       && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
279     {
280       fprintf(stderr, "LIBXSMM ERROR: BGEMM-handle cannot be NULL!\n");
281     }
282     result = EXIT_FAILURE;
283   }
284   return result;
285 }
286 
287 
libxsmm_blocked_gemm_copyout_c(const libxsmm_blocked_gemm_handle * handle,const void * src,const libxsmm_blasint * ld,void * dst)288 LIBXSMM_API int libxsmm_blocked_gemm_copyout_c(const libxsmm_blocked_gemm_handle* handle, const void* src, const libxsmm_blasint* ld, void* dst)
289 {
290   int result = EXIT_SUCCESS;
291   static int error_once = 0;
292 
293   if (0 != handle) {
294 #if 0 /* TODO: support leading dimension for the source buffer */
295     const libxsmm_blasint ild = (0 == ld ? handle->m : *ld);
296     assert(ild >= handle->m);
297 #else
298     LIBXSMM_UNUSED(ld);
299 #endif
300     switch (handle->oprec) {
301       case LIBXSMM_GEMM_PRECISION_F64: {
302 #       define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE double
303 #       include "template/libxsmm_blocked_gemm_copyout_c.tpl.c"
304 #       undef  LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
305       } break;
306       case LIBXSMM_GEMM_PRECISION_F32: {
307 #       define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE float
308 #       include "template/libxsmm_blocked_gemm_copyout_c.tpl.c"
309 #       undef  LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
310       } break;
311       case LIBXSMM_GEMM_PRECISION_I16: {
312 #       define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE int
313 #       include "template/libxsmm_blocked_gemm_copyout_c.tpl.c"
314 #       undef  LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
315       } break;
316       default: {
317         if (0 != libxsmm_verbosity /* library code is expected to be mute */
318           && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
319         {
320           fprintf(stderr, "LIBXSMM ERROR: BGEMM precision of matrix A is not supported!\n");
321         }
322         result = EXIT_FAILURE;
323       }
324     }
325   }
326   else {
327     if (0 != libxsmm_verbosity /* library code is expected to be mute */
328       && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
329     {
330       fprintf(stderr, "LIBXSMM ERROR: BGEMM-handle cannot be NULL!\n");
331     }
332     result = EXIT_FAILURE;
333   }
334   return result;
335 }
336 
337 
libxsmm_blocked_gemm_convert_b_to_a(const libxsmm_blocked_gemm_handle * handle,const void * src,const libxsmm_blasint * ld,void * dst)338 LIBXSMM_API int libxsmm_blocked_gemm_convert_b_to_a(const libxsmm_blocked_gemm_handle* handle, const void* src, const libxsmm_blasint* ld, void* dst)
339 {
340   int result = EXIT_SUCCESS;
341   static int error_once = 0;
342 
343   if (0 != handle) {
344 #if 0 /* TODO: support leading dimension for the source buffer */
345     const libxsmm_blasint ild = (0 == ld ? handle->k : *ld);
346     assert(ild >= handle->k);
347 #else
348     LIBXSMM_UNUSED(ld);
349 #endif
350     switch (handle->iprec) {
351       case LIBXSMM_GEMM_PRECISION_F64: {
352 #       define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE double
353 #       include "template/libxsmm_blocked_gemm_convert_b_to_a.tpl.c"
354 #       undef  LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
355       } break;
356       case LIBXSMM_GEMM_PRECISION_F32: {
357 #       define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE float
358 #       include "template/libxsmm_blocked_gemm_convert_b_to_a.tpl.c"
359 #       undef  LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
360       } break;
361       case LIBXSMM_GEMM_PRECISION_I16: {
362 #       define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE short
363 #       include "template/libxsmm_blocked_gemm_convert_b_to_a.tpl.c"
364 #       undef  LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
365       } break;
366       default: {
367         if (0 != libxsmm_verbosity /* library code is expected to be mute */
368           && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
369         {
370           fprintf(stderr, "LIBXSMM ERROR: BGEMM precision of matrix B is not supported!\n");
371         }
372         result = EXIT_FAILURE;
373       }
374     }
375   }
376   else {
377     if (0 != libxsmm_verbosity /* library code is expected to be mute */
378       && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
379     {
380       fprintf(stderr, "LIBXSMM ERROR: BGEMM-handle cannot be NULL!\n");
381     }
382     result = EXIT_FAILURE;
383   }
384   return result;
385 }
386 
387 
libxsmm_blocked_gemm_transpose_b(const libxsmm_blocked_gemm_handle * handle,const void * src,const libxsmm_blasint * ld,void * dst)388 LIBXSMM_API int libxsmm_blocked_gemm_transpose_b(const libxsmm_blocked_gemm_handle* handle, const void* src, const libxsmm_blasint* ld, void* dst)
389 {
390   int result = EXIT_SUCCESS;
391   static int error_once = 0;
392 
393   if (0 != handle) {
394 #if 0 /* TODO: support leading dimension for the source buffer */
395     const libxsmm_blasint ild = (0 == ld ? handle->k : *ld);
396     assert(ild >= handle->k);
397 #else
398     LIBXSMM_UNUSED(ld);
399 #endif
400     switch (handle->iprec) {
401       case LIBXSMM_GEMM_PRECISION_F64: {
402 #       define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE double
403 #       include "template/libxsmm_blocked_gemm_transpose_b.tpl.c"
404 #       undef  LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
405       } break;
406       case LIBXSMM_GEMM_PRECISION_F32: {
407 #       define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE float
408 #       include "template/libxsmm_blocked_gemm_transpose_b.tpl.c"
409 #       undef  LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
410       } break;
411       case LIBXSMM_GEMM_PRECISION_I16: {
412 #       define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE short
413 #       include "template/libxsmm_blocked_gemm_transpose_b.tpl.c"
414 #       undef  LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE
415       } break;
416       default: {
417         if (0 != libxsmm_verbosity /* library code is expected to be mute */
418           && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
419         {
420           fprintf(stderr, "LIBXSMM ERROR: BGEMM precision of matrix B is not supported!\n");
421         }
422         result = EXIT_FAILURE;
423       }
424     }
425   }
426   else {
427     if (0 != libxsmm_verbosity /* library code is expected to be mute */
428       && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
429     {
430       fprintf(stderr, "LIBXSMM ERROR: BGEMM-handle cannot be NULL!\n");
431     }
432     result = EXIT_FAILURE;
433   }
434   return result;
435 }
436 
437 
internal_bgemm_order(libxsmm_blocked_gemm_order order,libxsmm_blasint w_i,libxsmm_blasint nw_i,libxsmm_blasint nw_j,libxsmm_blasint nw_k,libxsmm_blasint * i2,libxsmm_blasint * j2,libxsmm_blasint * k2)438 LIBXSMM_API_INLINE void internal_bgemm_order(libxsmm_blocked_gemm_order order,
439   libxsmm_blasint w_i, libxsmm_blasint nw_i, libxsmm_blasint nw_j, libxsmm_blasint nw_k,
440   libxsmm_blasint* i2, libxsmm_blasint* j2, libxsmm_blasint* k2)
441 {
442   switch (order) {
443     case LIBXSMM_BLOCKED_GEMM_ORDER_JIK: {
444       *j2 = (w_i / (nw_i * nw_k));
445       *i2 = (w_i - (*j2) * (nw_i * nw_k)) / nw_k;
446       *k2 = (w_i % nw_k);
447     } break;
448     case LIBXSMM_BLOCKED_GEMM_ORDER_IJK: {
449       *i2 = (w_i / (nw_j * nw_k));
450       *j2 = (w_i - (*i2) * (nw_j * nw_k)) / nw_k;
451       *k2 = (w_i % nw_k);
452     } break;
453     case LIBXSMM_BLOCKED_GEMM_ORDER_JKI: {
454       *j2 = (w_i / (nw_k * nw_i));
455       *k2 = (w_i - (*j2) * (nw_k * nw_i)) / nw_i;
456       *i2 = (w_i % nw_i);
457     } break;
458     case LIBXSMM_BLOCKED_GEMM_ORDER_IKJ: {
459       *i2 = (w_i / (nw_k * nw_j));
460       *k2 = (w_i - (*i2) * (nw_k * nw_j)) / nw_j;
461       *j2 = (w_i % nw_j);
462     } break;
463     case LIBXSMM_BLOCKED_GEMM_ORDER_KJI: {
464       *k2 = (w_i / (nw_j * nw_i));
465       *j2 = (w_i - (*k2) * (nw_j * nw_i)) / nw_i;
466       *i2 = (w_i % nw_i);
467     } break;
468     case LIBXSMM_BLOCKED_GEMM_ORDER_KIJ: {
469       *k2 = (w_i / (nw_i * nw_j));
470       *i2 = (w_i - (*k2) * (nw_i * nw_j)) / nw_j;
471       *j2 = (w_i % nw_j);
472     } break;
473     default: assert(0/*should never happen*/);
474   }
475 }
476 
libxsmm_blocked_gemm_st(const libxsmm_blocked_gemm_handle * handle,const void * a,const void * b,void * c,int start_thread,int tid)477 LIBXSMM_API void libxsmm_blocked_gemm_st(const libxsmm_blocked_gemm_handle* handle, const void* a, const void* b, void* c,
478   /*unsigned*/int start_thread, /*unsigned*/int tid)
479 {
480   static int error_once = 0;
481 #if defined(LIBXSMM_BLOCKED_GEMM_CHECKS)
482   if (0 != handle && 0 != a && 0 != b && 0 != c && start_thread <= tid && 0 <= tid)
483 #endif
484   {
485     const int ltid = tid - start_thread;
486     if (handle->nthreads > 1) {
487       libxsmm_barrier_init(handle->barrier, ltid);
488     }
489     switch (handle->iprec) {
490       case LIBXSMM_GEMM_PRECISION_F64: {
491 #       define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE_AB double
492 #       define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE_C  double
493 #       include "template/libxsmm_blocked_gemm.tpl.c"
494 #       undef  LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE_AB
495 #       undef  LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE_C
496       } break;
497       case LIBXSMM_GEMM_PRECISION_F32: {
498 #       define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE_AB float
499 #       define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE_C  float
500 #       include "template/libxsmm_blocked_gemm.tpl.c"
501 #       undef  LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE_AB
502 #       undef  LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE_C
503       } break;
504       case LIBXSMM_GEMM_PRECISION_I16: {
505 #       define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE_AB short
506 #       define LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE_C  int
507 #       include "template/libxsmm_blocked_gemm.tpl.c"
508 #       undef LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE_C
509 #       undef LIBXSMM_BLOCKED_GEMM_TEMPLATE_TYPE_AB
510       } break;
511       default: if (0 != libxsmm_verbosity /* library code is expected to be mute */
512         && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
513       {
514         fprintf(stderr, "LIBXSMM ERROR: BGEMM precision is not supported!\n");
515       }
516     }
517     if (handle->nthreads > 1) {
518       libxsmm_barrier_wait(handle->barrier, ltid);
519     }
520   }
521 #if defined(LIBXSMM_BLOCKED_GEMM_CHECKS)
522   else if (0 != libxsmm_verbosity /* library code is expected to be mute */
523     && 1 == LIBXSMM_ATOMIC_ADD_FETCH(&error_once, 1, LIBXSMM_ATOMIC_RELAXED))
524   {
525     fprintf(stderr, "LIBXSMM ERROR: invalid arguments for libxsmm_blocked_gemm!\n");
526   }
527 #endif
528 }
529 
530