1 /*******************************************************************************
2 * Copyright 2018-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include <cstdint>
18 #if defined(_MSC_VER)
19 #include <malloc.h>
20 #endif
21 
22 #include "oneapi/dnnl/dnnl_types.h"
23 
24 #include "common/bfloat16.hpp"
25 #include "common/dnnl_traits.hpp"
26 #include "common/nstl.hpp"
27 #include "common/utils.hpp"
28 
29 #include "cpu/platform.hpp"
30 
31 #include "cpu/gemm/f32/gemm_utils_f32.hpp"
32 #include "cpu/gemm/gemm_msan_unpoison.hpp"
33 
34 #include "cpu/x64/jit_generator.hpp"
35 
36 #include "cpu/x64/gemm/gemm_driver.hpp"
37 #include "cpu/x64/gemm/gemm_info.hpp"
38 #include "cpu/x64/gemm/gemm_partition.hpp"
39 #include "cpu/x64/gemm/gemm_threading.hpp"
40 #include "cpu/x64/gemm/gemm_utils.hpp"
41 #include "cpu/x64/gemm/gemv_driver.hpp"
42 
43 #include "cpu/x64/gemm/f32/jit_avx512_common_gemm_f32.hpp"
44 #include "cpu/x64/gemm/f32/jit_avx512_core_gemm_smalln_tn_f32_kern.hpp"
45 #include "cpu/x64/gemm/f32/jit_avx_gemm_f32.hpp"
46 
47 #include "cpu/x64/gemm/s8x8s32/jit_avx512_core_gemv_s8x8s32.hpp"
48 
49 namespace dnnl {
50 namespace impl {
51 namespace cpu {
52 namespace x64 {
53 
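// Per-thread bookkeeping for the threaded GEMM driver. Each thread owns a
// slice of the problem; when the k dimension is split across threads, partial
// results are accumulated into c_local and later reduced into c_global by
// sum_k_blocks(), which polls the volatile compute_done flag. alignas(64)
// keeps entries on separate cache lines to avoid false sharing.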
54 template <typename c_type>
55 struct alignas(64) gemm_per_thread_t {
56     volatile int32_t result;
57     volatile int32_t compute_done;
58     int32_t thr_k_stride;
59     int32_t nthr_k;
60     dim_t ldc_local;
61     dim_t ldc_global;
62     c_type *c_local;
63     c_type *volatile c_global;
64     gemm_slice_t slice;
65 };
66 
67 template <typename T>
68 int get_vector_length() {
69     int v_bytes;
70 
71     if (mayiuse(avx512_core))
72         v_bytes = cpu_isa_traits<avx512_core>::vlen;
73     else if (mayiuse(avx))
74         v_bytes = cpu_isa_traits<avx>::vlen;
75     else
76         v_bytes = cpu_isa_traits<sse41>::vlen;
77 
78     return v_bytes / sizeof(T);
79 }
80 
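// Round a double to the nearest value of c_type (half away from zero),
// saturating to the int32 range.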
81 template <typename c_type>
82 static inline void round_to_nearest(c_type *rounded_val, double fp_val) {
83     if (fp_val >= 0.) {
84         fp_val += 0.5;
85         if (fp_val > INT32_MAX) { fp_val = INT32_MAX; }
86     } else {
87         fp_val -= 0.5;
88         if (fp_val < INT32_MIN) { fp_val = INT32_MIN; }
89     }
90     *rounded_val = (c_type)fp_val;
91 }
92 
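// Add a partial-sum buffer (c_partial_sum, leading dimension ldcp) into the
// output matrix c_data, applying alpha/beta scaling and the requested C
// offset. For integer C the update is computed in double precision and
// rounded with round_to_nearest().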
93 template <typename c_type>
94 static inline void add_results(const dim_t m, const dim_t n, const float alpha,
95         const float beta, const c_type *c_partial_sum, const dim_t ldcp,
96         c_type *c_data, const dim_t ldc, const c_type *co,
97         offset_type offsetc) {
98 
99     constexpr bool is_int8 = data_traits<c_type>::data_type == data_type::s32;
100 
101     for (dim_t j = 0; j < n; ++j) {
102         for (dim_t i = 0; i < m; ++i) {
103             c_type ctemp = c_partial_sum[i + j * ldcp];
104 
105             if (alpha == 1.0f) {
106                 if (beta == 0.0f) {
107                     c_data[i + j * ldc] = ctemp;
108                 } else {
109                     if (is_int8) {
110                         double c_float
111                                 = (double)beta * (double)c_data[i + j * ldc];
112                         c_float += (double)ctemp;
113                         round_to_nearest(&c_data[i + j * ldc], c_float);
114                     } else {
115                         c_data[i + j * ldc] *= beta;
116                         c_data[i + j * ldc] += ctemp;
117                     }
118                 }
119             } else if (alpha == -1.0f) {
120                 if (beta == 0.0f) {
121                     c_data[i + j * ldc] = -ctemp;
122                 } else {
123                     if (is_int8) {
124                         double c_float
125                                 = (double)beta * (double)c_data[i + j * ldc];
126                         c_float -= (double)ctemp;
127                         round_to_nearest(&c_data[i + j * ldc], c_float);
128                     } else {
129                         c_data[i + j * ldc] *= beta;
130                         c_data[i + j * ldc] -= ctemp;
131                     }
132                 }
133             } else {
134                 if (beta == 0.0f) {
135                     if (is_int8) {
136                         double c_float = alpha * (double)ctemp;
137                         round_to_nearest(&c_data[i + j * ldc], c_float);
138                     } else {
139                         c_data[i + j * ldc] = alpha * ctemp;
140                     }
141 
142                 } else {
143                     if (is_int8) {
144                         double c_float = alpha * (double)ctemp
145                                 + beta * (double)c_data[i + j * ldc];
146                         round_to_nearest(&c_data[i + j * ldc], c_float);
147                     } else {
148                         c_data[i + j * ldc] *= beta;
149                         c_data[i + j * ldc] += alpha * ctemp;
150                     }
151                 }
152             }
153 
154             if (offsetc == offset_type::fixed) {
155                 c_data[i + j * ldc] += co[0];
156             } else if (offsetc == offset_type::row) {
157                 c_data[i + j * ldc] += co[j];
158             } else if (offsetc == offset_type::column) {
159                 c_data[i + j * ldc] += co[i];
160             }
161         }
162     }
163 }
164 
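// Blocking-size helpers: return the padded K/M/N block sizes for this thread,
// honoring the blocking recorded in pre-packed A/B storage when present.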
165 template <typename a_type, typename b_type, typename c_type>
166 static inline dim_t get_k_padd(
167         int ithr, dim_t k, const gemm_info_t<a_type, b_type, c_type> *arg) {
168     if (arg->a_packed) {
169         dim_t block_m, block_k;
170         arg->a_packed->get_blocking(ithr, block_m, block_k);
171         return block_k;
172     } else if (arg->b_packed) {
173         dim_t block_n, block_k;
174         arg->b_packed->get_blocking(ithr, block_k, block_n);
175         return block_k;
176     } else {
177         dim_t k_padd = 0;
178 
179         if (k <= arg->bk_traditional) {
180             k_padd = utils::rnd_up(k, arg->uk);
181             k_padd = nstl::max(dim_t(128), k_padd);
182         } else if (k < 2 * arg->bk)
183             k_padd = utils::rnd_up((k + 1) / 2, arg->uk);
184         else
185             k_padd = arg->bk;
186 
187         return k_padd;
188     }
189 }
190 
191 template <typename a_type, typename b_type, typename c_type>
192 static inline dim_t get_m_padd(
193         int ithr, dim_t m, const gemm_info_t<a_type, b_type, c_type> *arg) {
194     if (arg->a_packed) {
195         dim_t block_m, block_k;
196         arg->a_packed->get_blocking(ithr, block_m, block_k);
197         return block_m;
198     } else
199         return utils::rnd_up(
200                 nstl::min(nstl::max(m, arg->um), arg->bm), arg->um);
201 }
202 
203 template <typename a_type, typename b_type, typename c_type>
204 static inline dim_t get_m_padd_parallel_a(int ithr, dim_t m,
205         const gemm_info_t<a_type, b_type, c_type> *arg, int nthrs) {
206     auto m_padd = get_m_padd(ithr, m, arg);
207 
208     if (!arg->a_packed) {
209         constexpr auto multiplier = 10;
210 
211         m_padd *= nstl::max(nthrs, multiplier);
212         if (m_padd > m) m_padd = utils::rnd_up(m, arg->um);
213     }
214 
215     return m_padd;
216 }
217 
218 template <typename a_type, typename b_type, typename c_type>
219 static inline dim_t get_n_padd(int ithr, dim_t n, dim_t k,
220         const gemm_info_t<a_type, b_type, c_type> *arg) {
221     if (arg->b_packed) {
222         dim_t block_n, block_k;
223         arg->b_packed->get_blocking(ithr, block_k, block_n);
224         return block_n;
225     } else {
226         auto bn = (k < arg->blocking_small_k) ? arg->bn_small_k : arg->bn;
227         return utils::rnd_up(nstl::min(nstl::max(n, arg->un), bn), arg->un);
228     }
229 }
230 
231 static inline void *align(void *ptr, size_t alignment) {
232     return (void *)utils::rnd_up((uintptr_t)ptr, alignment);
233 }
234 
235 template <typename scale_t, typename mat_t>
236 void scale_matrix(
237         dim_t m, dim_t n, scale_t alpha, mat_t *__restrict p_mat, dim_t ld) {
238     if (data_traits<mat_t>::data_type == data_type::f32) {
239         for (dim_t j = 0; j < n; j++) {
240             for (dim_t i = 0; i < m; i++) {
241                 p_mat[i + j * ld] = (mat_t)((scale_t)p_mat[i + j * ld] * alpha);
242             }
243         }
244     }
245 }
246 
247 template <typename mat_t>
248 static void sum_matrices(dim_t m, dim_t n, mat_t *__restrict dst, dim_t ld_dst,
249         mat_t *__restrict src, dim_t ld_src) {
250 
251     for (dim_t j = 0; j < n; j++) {
252         PRAGMA_OMP_SIMD()
253         for (int i = 0; i < m; i++)
254             dst[i + j * ld_dst] += src[i + j * ld_src];
255     }
256 }
257 
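// Reduce the partial results of a k-partitioned GEMM: over a 1D partition of
// the columns, each thread adds the c_local buffers of its k-group (its own
// block first, while still warm in cache) into c_global. When wait is set, it
// spins on compute_done before reading another thread's buffer.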
258 template <typename c_type>
259 static void sum_k_blocks(
260         int ithr, gemm_per_thread_t<c_type> *thread_arg, bool wait) {
261 
262     auto m = thread_arg[ithr].slice.m;
263     auto n = thread_arg[ithr].slice.n;
264     auto ithr_k = thread_arg[ithr].slice.ithr_k;
265     auto nthr_k = thread_arg[ithr].nthr_k;
266     auto stride = thread_arg[ithr].thr_k_stride;
267     dim_t n0, nn;
268 
269     partition_1d(ithr_k, nthr_k, n, n0, nn);
270 
271     auto get_thread_arg = [&](int thr_k) -> gemm_per_thread_t<c_type> & {
272         return thread_arg[ithr + (thr_k - ithr_k) * stride];
273     };
274 
275     auto wait_thread = [&](int thr_k) {
276         if (wait) {
277             auto &tk_arg = get_thread_arg(thr_k);
278             while (!tk_arg.compute_done) {}
279         }
280     };
281 
282     auto add_thread_results = [&](int thr_k) {
283         auto &tk_arg = get_thread_arg(thr_k);
284 
285         sum_matrices(m, nn, tk_arg.c_global + n0 * tk_arg.ldc_global,
286                 tk_arg.ldc_global, tk_arg.c_local + n0 * tk_arg.ldc_local,
287                 tk_arg.ldc_local);
288     };
289 
290     // First accumulate this thread's results while they are in cache.
291     if (ithr_k > 0) {
292         wait_thread(0);
293         add_thread_results(ithr_k);
294     }
295 
296     // Then accumulate the others.
297     for (int thr_k = 1; thr_k < nthr_k; thr_k++) {
298         if (thr_k != ithr_k) {
299             wait_thread(thr_k);
300             add_thread_results(thr_k);
301         }
302     }
303 }
304 
305 template <typename a_type, typename b_type, typename c_type>
306 static dnnl_status_t pack_no_copy(gemm_info_t<a_type, b_type, c_type> *arg) {
307 
308     if (arg->packing == pack_type::pack_a) {
309         return gemm_utils::pack_no_copy(arg->a, arg->lda, arg->m, arg->k,
310                 arg->transa, arg->alpha, arg->pack_dst);
311     } else {
312         return gemm_utils::pack_no_copy(arg->b, arg->ldb, arg->k, arg->n,
313                 arg->transb, arg->alpha, arg->pack_dst);
314     }
315 }
316 
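// Packing driver: walk A or B in blocks matching the pack-storage blocking
// and run the copy kernels into the destination, computing row/column sums
// along the way for integer GEMM.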
317 template <typename a_type, typename b_type, typename c_type>
318 static dnnl_status_t gemm_packing_driver(int ithr, dim_t m, dim_t n, dim_t k,
319         const a_type *a, const b_type *b,
320         const gemm_info_t<a_type, b_type, c_type> *arg) {
321 
322     if (m <= 0 || n <= 0) return dnnl_success;
323 
324     gemm_pack_storage_t *pack_dst = arg->pack_dst;
325 
326     if (!pack_dst->is_first_thread_in_slice(ithr)) return dnnl_success;
327 
328     dim_t block_r, block_c;
329     pack_dst->get_blocking(ithr, block_r, block_c);
330 
331     auto do_a = (arg->packing == pack_type::pack_a);
332     auto mn = do_a ? m : n;
333     auto mn_padd = do_a ? block_r : block_c;
334     auto k_padd = do_a ? block_c : block_r;
335     dim_t mn_stride, k_stride;
336 
337     if (do_a) {
338         mn_stride = (arg->transa == no_trans) ? 1 : arg->lda;
339         k_stride = (arg->transa == no_trans) ? arg->lda : 1;
340     } else {
341         mn_stride = (arg->transb == no_trans) ? arg->ldb : 1;
342         k_stride = (arg->transb == no_trans) ? 1 : arg->ldb;
343     }
344 
345     dim_t blk_k = 0;
346     for (dim_t Bk = 0; Bk < k; Bk += k_padd, blk_k++) {
347         dim_t nk = nstl::min(k - Bk, k_padd);
348 
349         for (dim_t Bmn = 0; Bmn < mn; Bmn += mn_padd) {
350             dim_t nmn = nstl::min(mn - Bmn, mn_padd);
351 
352             if (do_a) {
353                 auto a_src = a + mn_stride * Bmn + k_stride * Bk;
354                 auto a_dst = pack_dst->matrix<a_type>(ithr, Bmn, Bk);
355                 auto a_row_sum = pack_dst->row_sums<c_type>(ithr, Bmn, blk_k);
356 
357                 arg->copyA(&nk, &nmn, a_src, &arg->lda, &arg->alpha, a_dst,
358                         nullptr, nullptr, a_row_sum);
359             } else {
360                 auto b_src = b + mn_stride * Bmn + k_stride * Bk;
361                 auto b_dst = pack_dst->matrix<b_type>(ithr, Bk, Bmn);
362                 auto b_col_sum = pack_dst->col_sums<c_type>(ithr, blk_k, Bmn);
363 
364                 arg->copyB(&nk, &nmn, b_src, &arg->ldb, &arg->alpha, b_dst,
365                         nullptr, nullptr, b_col_sum);
366             }
367         }
368     }
369 
370     return dnnl_success;
371 }
372 
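// Run the compute kernel on a single (m x n x k) block. For integer GEMM the
// zero-point compensation terms (derived from ao, bo, the row/column sums and
// the C offset) are folded into per-row/per-column offset vectors passed to
// the kernel. f32 column offsets and the row/column terms for AMX int8
// kernels are applied after the kernel call, since those kernels do not
// support them yet.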
373 template <typename a_type, typename b_type, typename c_type>
374 void gemm_kernel(dim_t m, dim_t n, const dim_t k, const float alpha,
375         const a_type *a, const b_type *b, float beta, c_type *c,
376         const dim_t ldc, const c_type *a_row_sum, const c_type *b_col_sum,
377         const c_type *co, offset_type offsetc,
378         const gemm_info_t<a_type, b_type, c_type> *arg) {
379 
380 #ifdef DNNL_WITH_SYCL
381     std::vector<c_type> col_offset_vec(m);
382     std::vector<c_type> row_offset_vec(n);
383     c_type *col_offset = col_offset_vec.data();
384     c_type *row_offset = row_offset_vec.data();
385 #else
386     // Since m and n are limited by blocking, stack overflow should not happen;
387     // usage is at most 32 kB.
388 #if !defined(_MSC_VER)
389     c_type col_offset[m];
390     c_type row_offset[n];
391 #else
392     c_type *col_offset = (c_type *)_alloca(sizeof(*col_offset) * m);
393     c_type *row_offset = (c_type *)_alloca(sizeof(*row_offset) * n);
394 #endif
395 #endif
396 
397     bool col_req = false;
398     bool row_req = false;
399 
400     constexpr bool is_int8 = utils::one_of(
401             data_traits<a_type>::data_type, data_type::s8, data_type::u8);
402     constexpr bool is_f32 = data_traits<a_type>::data_type == data_type::f32;
403     bool is_int8_amx = is_int8 && mayiuse(avx512_core_bf16_amx_int8);
404 
405     // Unconditionally zero initialize these arrays
406     for (dim_t i = 0; i < m; i++)
407         col_offset[i] = 0;
408     for (dim_t i = 0; i < n; i++)
409         row_offset[i] = 0;
410 
411     if (is_int8) {
412         c_type ao = arg->ao;
413         c_type bo = arg->bo;
414         c_type co_0 = offsetc == offset_type::none ? 0 : co[0];
415 
416         if (bo != 0 || offsetc == offset_type::column) col_req = true;
417         if (ao != 0 || offsetc == offset_type::row) row_req = true;
418 
419         // Only one of the column or row offsets is needed, not both.
420         if ((ao != 0 && bo != 0)
421                 || (offsetc == offset_type::fixed && co_0 != 0)) {
422             if (!col_req && !row_req) {
423                 if (m <= n) {
424                     col_req = true;
425                 } else {
426                     row_req = true;
427                 }
428             }
429         }
430 
431         if (col_req) {
432             if (offsetc == offset_type::column) {
433                 for (dim_t i = 0; i < m; i++)
434                     col_offset[i] += co[i];
435             }
436 
437             if (bo != 0 && a_row_sum) {
438                 for (dim_t i = 0; i < m; i++)
439                     col_offset[i] -= bo * a_row_sum[i];
440             }
441         }
442 
443         if (row_req) {
444             if (offsetc == offset_type::row) {
445                 for (dim_t i = 0; i < n; i++)
446                     row_offset[i] += co[i];
447             }
448 
449             if (ao != 0 && b_col_sum) {
450                 for (dim_t i = 0; i < n; i++)
451                     row_offset[i] -= ao * b_col_sum[i];
452             }
453         }
454 
455         if (offsetc == offset_type::fixed && co_0 != 0) {
456             if (col_req) {
457                 for (dim_t i = 0; i < m; i++)
458                     col_offset[i] += co_0;
459             } else {
460                 for (dim_t i = 0; i < n; i++)
461                     row_offset[i] += co_0;
462             }
463         }
464 
465         if (ao != 0 && bo != 0) {
466             if (col_req) {
467                 for (dim_t i = 0; i < m; i++)
468                     col_offset[i] += (c_type)k * ao * bo;
469             } else {
470                 for (dim_t i = 0; i < n; i++)
471                     row_offset[i] += (c_type)k * ao * bo;
472             }
473         }
474     }
475 
476     bool isBeta0 = beta == 0.0f;
477 
478     /* Column and row offsets are ignored by non-integer compute kernels.
479      * Scaling is done only for bfloat16 kernels.
480      */
481     if (m > 0 && n > 0)
482         arg->kernel[isBeta0][col_req][row_req](
483                 &m, &n, &k, &alpha, a, b, c, ldc, col_offset, row_offset);
484 
485     msan_unpoison_matrix(c, m, n, ldc, sizeof(*c));
486 
487     // sgemm kernels don't support bias yet.
488     if (is_f32) {
489         if (co && offsetc == offset_type::column) {
490             for (dim_t j = 0; j < n; j++) {
491                 for (dim_t i = 0; i < m; i++) {
492                     c[i + j * ldc] += co[i];
493                 }
494             }
495         }
496     }
497 
498     // AMX igemm kernels don't support row & col sums yet.
499     if (is_int8_amx) {
500         for (dim_t j = 0; j < n; j++) {
501             for (dim_t i = 0; i < m; i++) {
502                 if (row_req) c[i + j * ldc] += row_offset[j];
503                 if (col_req) c[i + j * ldc] += col_offset[i];
504             }
505         }
506     }
507 }
508 
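// Per-thread kernel driver: block the problem over (Bm, Bk, Bn), copy A and B
// panels into aligned scratch buffers (or take them from pack storage), and
// call gemm_kernel() per block. A temporary C buffer is used when alpha/beta
// cannot be applied directly by the integer or AMX bf16 kernels; add_results()
// then merges it into C with the proper scaling and offsets.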
509 template <typename a_type, typename b_type, typename c_type>
510 static dnnl_status_t gemm_kernel_driver(int ithr, dim_t m, dim_t n, dim_t k,
511         const a_type *a, const b_type *b, float beta, c_type *c, dim_t ldc,
512         offset_type offsetc, const c_type *co,
513         const gemm_info_t<a_type, b_type, c_type> *arg) {
514 
515     if (arg->packing != pack_type::none)
516         return gemm_packing_driver(ithr, m, n, k, a, b, arg);
517 
518     if (m <= 0 || n <= 0) return dnnl_success;
519 
520     dim_t lda = arg->lda;
521     dim_t ldb = arg->ldb;
522 
523     float alpha = arg->alpha;
524 
525     constexpr bool is_int8 = utils::one_of(
526             data_traits<a_type>::data_type, data_type::s8, data_type::u8);
527     constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16;
528     bool is_int8_amx = is_int8 && mayiuse(avx512_core_bf16_amx_int8);
529     bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_bf16_amx_bf16);
530     bool is_amx = is_int8_amx || is_bf16_amx;
531 
532     const std::shared_ptr<const gemm_pack_storage_t> &a_packed = arg->a_packed;
533     const std::shared_ptr<const gemm_pack_storage_t> &b_packed = arg->b_packed;
534 
535     // Scaling C matrix.
536     if (!is_int8 && beta != 1.0f && beta != 0.0f) {
537         scale_matrix(m, n, beta, c, ldc);
538         beta = 1.0f;
539     }
540 
541     // Quick exit for C = beta * C
542     if (!is_int8 && alpha == 0.0f) {
543         if (beta == 0.0f) scale_matrix(m, n, beta, c, ldc);
544 
545         return dnnl_success;
546     }
547 
548     // Get block sizes.
549     dim_t k_padd = get_k_padd(ithr, k, arg);
550     dim_t m_padd = get_m_padd(ithr, m, arg);
551     dim_t n_padd = get_n_padd(ithr, n, k, arg);
552 
553     // Padding for temporary buffer for C
554     dim_t ldc_buf = gemm_utils::get_ld_padd<c_type>(m_padd);
555 
556     dim_t strideAm = (arg->transa == no_trans) ? 1 : lda;
557     dim_t strideAn = (arg->transa != no_trans) ? 1 : lda;
558     dim_t strideBm = (arg->transb == no_trans) ? 1 : ldb;
559     dim_t strideBn = (arg->transb != no_trans) ? 1 : ldb;
560 
561     size_t a_buf_nelems = m_padd * k_padd;
562     size_t b_buf_nelems = k_padd * n_padd;
563     // A and B buffers need more space due to zero-padding.
564     if (is_amx) {
565         a_buf_nelems = utils::rnd_up(m_padd, arg->um)
566                 * utils::rnd_up(k_padd, arg->uk);
567         b_buf_nelems = utils::rnd_up(k_padd, arg->uk)
568                 * utils::rnd_up(n_padd, arg->un);
569     }
570     size_t a_row_sum_nelems = m_padd;
571     size_t b_col_sum_nelems = n_padd;
572 
573     if (a_packed) a_buf_nelems = a_row_sum_nelems = 0;
574     if (b_packed) b_buf_nelems = b_col_sum_nelems = 0;
575 
576     size_t mem_size = a_buf_nelems * sizeof(*a) + PAGE_4K
577             + b_buf_nelems * sizeof(*b) + PAGE_4K;
578 
579     if (is_int8) {
580         mem_size += a_row_sum_nelems * sizeof(*c) + PAGE_4K
581                 + b_col_sum_nelems * sizeof(*c) + PAGE_4K;
582     }
583 
584     bool need_c_buffer
585             = (is_int8 && (alpha != 1.0f || (beta != 1.0f && beta != 0.0f)))
586             // AMX bfloat16 kernels don't support alpha scaling yet,
587             // so we need to use accumulation buffer even if beta == 0.
588             || (is_bf16_amx && alpha != 1.0f);
589 
590     if (need_c_buffer) {
591         size_t c_buf_nelems = ldc_buf * n_padd;
592         mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K;
593     }
594 
595     char *mem = nullptr;
596 
597     if (mem_size > 0) {
598         mem = (char *)malloc(mem_size, 128);
599         if (!mem) return dnnl_out_of_memory;
600     }
601 
602     a_type *bufferA = (a_type *)align(mem, PAGE_4K);
603     b_type *bufferB = (b_type *)align(bufferA + a_buf_nelems, PAGE_4K);
604 
605     c_type *a_row_sum = nullptr;
606     c_type *b_col_sum = nullptr;
607     if (is_int8) {
608         a_row_sum = (c_type *)align(bufferB + b_buf_nelems, PAGE_4K);
609         b_col_sum = (c_type *)align(a_row_sum + a_row_sum_nelems, PAGE_4K);
610     }
611 
612     c_type *bufferC = nullptr;
613     if (need_c_buffer) {
614         if (is_int8)
615             bufferC = (c_type *)align(b_col_sum + b_col_sum_nelems, PAGE_4K);
616         else
617             bufferC = (c_type *)align(bufferB + b_buf_nelems, PAGE_4K);
618     }
619 
620     int a_block_copied = 0;
621     dim_t sizeM = 0;
622     for (dim_t Bm = 0; Bm < m; Bm += sizeM) {
623         sizeM = m - Bm;
624         if (sizeM > m_padd) sizeM = m_padd;
625 
626         dim_t sizeK = 0;
627         dim_t blk_k = 0;
628         for (dim_t Bk = 0; Bk < k; Bk += sizeK, blk_k++) {
629             sizeK = k - Bk;
630             if (sizeK > k_padd) sizeK = k_padd;
631 
632             // Scale C blocks by beta only for the first k-block of the partial sum.
633             auto beta_eff = (Bk == 0) ? beta : 1.0f;
634 
635             // Apply C offset for the last k-block of the partial sum.
636             auto offsetc_eff = offset_type::none;
637             if (Bk + sizeK == k) offsetc_eff = offsetc;
638 
639             dim_t sizeN = 0;
640             for (dim_t Bn = 0; Bn < n; Bn += sizeN) {
641                 sizeN = n - Bn;
642                 if (sizeN > n_padd) sizeN = n_padd;
643 
644                 if (b_packed) {
645                     bufferB = b_packed->matrix<b_type>(ithr, Bk, Bn);
646                     if (is_int8)
647                         b_col_sum = b_packed->col_sums<c_type>(ithr, blk_k, Bn);
648                 } else {
649                     const b_type *b_block = b + Bk * strideBm + Bn * strideBn;
650                     const float one = 1.0f;
651 
652                     /* Column sum argument is ignored for non-integer kernels
653                      * and scaling factor is ignored by 8-bit and 16-bit copy
654                      * kernels.
655                      */
656                     arg->copyB(&sizeK, &sizeN, b_block, &ldb, &one, bufferB,
657                             nullptr, nullptr, b_col_sum);
658                 }
659 
660                 dim_t sizeUM = 0;
661                 for (dim_t Um = 0; Um < sizeM; Um += sizeUM) {
662                     sizeUM = sizeM - Um;
663                     if (sizeUM > arg->um) sizeUM = arg->um;
664 
665                     /* Use the whole A buffer only if we have multiple B
666                      * blocks for k-dimension, otherwise we are wasting cache
667                      * to store B and C blocks.
668                      */
669                     dim_t Um_forA = 0;
670                     if (sizeN < n) Um_forA = Um;
671 
672                     a_type *bufferA_eff = nullptr;
673                     c_type *a_row_sum_eff = nullptr;
674 
675                     if (a_packed) {
676                         Um_forA = Um;
677 
678                         // TODO Can we simplify this?
679                         dim_t buf_shift = 0;
680                         if (is_amx)
681                             buf_shift = Um_forA * utils::rnd_up(sizeK, arg->uk);
682                         else
683                             buf_shift = Um_forA * sizeK;
684 
685                         bufferA_eff = a_packed->matrix<a_type>(ithr, Bm, Bk)
686                                 + buf_shift;
687 
688                         if (is_int8)
689                             a_row_sum_eff = a_packed->row_sums<c_type>(
690                                                     ithr, Bm, blk_k)
691                                     + Um_forA;
692                     } else {
693                         // TODO Can we simplify this?
694                         dim_t buf_shift = 0;
695                         if (is_amx)
696                             buf_shift = Um_forA * utils::rnd_up(sizeK, arg->uk);
697                         else
698                             buf_shift = Um_forA * sizeK;
699 
700                         bufferA_eff = bufferA + buf_shift;
701                         a_row_sum_eff
702                                 = a_row_sum ? a_row_sum + Um_forA : nullptr;
703 
704                         if (!a_block_copied) {
705                             const a_type *a_block
706                                     = a + (Bm + Um) * strideAm + Bk * strideAn;
707 
708                             /* Row sum argument is ignored for non-integer
709                              * kernels and scaling factor is ignored by 8-bit
710                              * and 16-bit copy kernels.
711                              */
712                             arg->copyA(&sizeK, &sizeUM, a_block, &lda, &alpha,
713                                     bufferA_eff, nullptr, nullptr,
714                                     a_row_sum_eff);
715                         }
716                     }
717 
718                     c_type *c_block = c + (Bm + Um) + Bn * ldc;
719 
720                     dim_t co_stride = 0;
721                     if (offsetc_eff == offset_type::row)
722                         co_stride = Bn;
723                     else if (offsetc_eff == offset_type::column)
724                         co_stride = Bm + Um;
725 
726                     if (need_c_buffer) {
727                         gemm_kernel(sizeUM, sizeN, sizeK, 1.0f, bufferA_eff,
728                                 bufferB, 0.0f, bufferC + Um, ldc_buf,
729                                 a_row_sum_eff, b_col_sum, (c_type *)nullptr,
730                                 offset_type::none, arg);
731 
732                         /* Finish the block adding the necessary alpha, beta
733                          * and offsets.
734                          */
735                         add_results(sizeUM, sizeN, alpha, beta_eff,
736                                 bufferC + Um, ldc_buf, c_block, ldc,
737                                 co + co_stride, offsetc_eff);
738                     } else {
739                         gemm_kernel(sizeUM, sizeN, sizeK, alpha, bufferA_eff,
740                                 bufferB, beta_eff, c_block, ldc, a_row_sum_eff,
741                                 b_col_sum, co + co_stride, offsetc_eff, arg);
742                     }
743                 }
744                 a_block_copied = 1;
745             }
746             a_block_copied = 0;
747         }
748     }
749 
750     free(mem);
751 
752     return dnnl_success;
753 }
754 
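// Kernel driver for the parallel-copy-A path: A has already been copied (or
// packed) into bufferA by the caller, so only B panels are copied here before
// calling gemm_kernel() for each Bn block.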
755 template <typename a_type, typename b_type, typename c_type>
756 static dnnl_status_t kernel_driver_parallel_acopiedbcopy(int ithr, dim_t m,
757         dim_t n, dim_t k, dim_t blk_k, dim_t Bk, const a_type *bufferA,
758         const b_type *b, float beta, c_type *c, offset_type offsetc,
759         const c_type *co, const c_type *a_row_sum,
760         const gemm_info_t<a_type, b_type, c_type> *arg) {
761 
762     dim_t ldb = arg->ldb;
763     dim_t ldc = arg->ldc;
764 
765     float alpha = arg->alpha;
766 
767     const std::shared_ptr<const gemm_pack_storage_t> &b_packed = arg->b_packed;
768 
769     if (m <= 0 || n <= 0) { return dnnl_success; }
770 
771     // Padding along N dimension.
772     dim_t n_padd = get_n_padd(ithr, n, k, arg);
773 
774     // Padding for temporary buffer for C
775     dim_t ldc_buf = gemm_utils::get_ld_padd<c_type>(m);
776 
777     dim_t strideBn = (arg->transb != 0) ? 1 : ldb;
778 
779     size_t b_buf_nelems = k * n_padd;
780     size_t b_col_sum_nelems = n_padd;
781 
782     constexpr bool is_int8 = utils::one_of(
783             data_traits<a_type>::data_type, data_type::s8, data_type::u8);
784     constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16;
785     bool is_int8_amx = is_int8 && mayiuse(avx512_core_bf16_amx_int8);
786     bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_bf16_amx_bf16);
787     bool is_amx = is_int8_amx || is_bf16_amx;
788 
789     // B buffer needs more space due to zero-padding.
790     if (is_amx)
791         b_buf_nelems
792                 = utils::rnd_up(k, arg->uk) * utils::rnd_up(n_padd, arg->un);
793 
794     if (b_packed) b_buf_nelems = b_col_sum_nelems = 0;
795 
796     size_t mem_size = b_buf_nelems * sizeof(*b) + PAGE_4K;
797 
798     if (is_int8) { mem_size += b_col_sum_nelems * sizeof(*c) + PAGE_4K; }
799 
800     bool need_c_buffer
801             = (is_int8 && (alpha != 1.0f || (beta != 1.0f && beta != 0.0f)))
802             // AMX bfloat16 kernels don't support alpha scaling yet,
803             // so we need to use accumulation buffer even if beta == 0.
804             || (is_bf16_amx && alpha != 1.0f);
805 
806     if (need_c_buffer) {
807         size_t c_buf_nelems = ldc_buf * n_padd;
808         mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K;
809     }
810 
811     char *mem = nullptr;
812 
813     if (mem_size > 0) {
814         mem = (char *)malloc(mem_size, 128);
815         if (!mem) return dnnl_out_of_memory;
816     }
817 
818     b_type *bufferB = (b_type *)align(mem, PAGE_4K);
819 
820     c_type *b_col_sum = nullptr;
821     if (is_int8) {
822         b_col_sum = (c_type *)align(bufferB + b_buf_nelems, PAGE_4K);
823     }
824 
825     c_type *bufferC = nullptr;
826     if (need_c_buffer) {
827         if (is_int8)
828             bufferC = (c_type *)align(b_col_sum + b_col_sum_nelems, PAGE_4K);
829         else
830             bufferC = (c_type *)align(bufferB + b_buf_nelems, PAGE_4K);
831     }
832 
833     dim_t sizeN = 0;
834     for (dim_t Bn = 0; Bn < n; Bn += sizeN) {
835         sizeN = n - Bn;
836         if (sizeN > n_padd) sizeN = n_padd;
837 
838         if (b_packed) {
839             bufferB = b_packed->matrix<b_type>(ithr, Bk, Bn);
840             if (is_int8)
841                 b_col_sum = b_packed->col_sums<c_type>(ithr, blk_k, Bn);
842         } else {
843             const b_type *b_block = b + Bn * strideBn;
844             const float one = 1.0f;
845 
846             /* Column sum argument is ignored for non-integer kernels and
847              * scaling factor is ignored by 8-bit and 16-bit copy kernels.
848              */
849             arg->copyB(&k, &sizeN, b_block, &ldb, &one, bufferB, nullptr,
850                     nullptr, b_col_sum);
851         }
852 
853         dim_t co_stride = 0;
854         if (offsetc == offset_type::fixed) {
855             co_stride = 0;
856         } else if (offsetc == offset_type::row) {
857             co_stride = Bn;
858         } else if (offsetc == offset_type::column) {
859             co_stride = 0;
860         }
861 
862         c_type *c_block = c + Bn * ldc;
863         if (need_c_buffer) {
864             gemm_kernel(m, sizeN, k, 1.0f, bufferA, bufferB, 0.0f, bufferC,
865                     ldc_buf, a_row_sum, b_col_sum, (c_type *)nullptr,
866                     offset_type::none, arg);
867 
868             // Finish the block adding the necessary alpha, beta and offsets.
869             add_results(m, sizeN, alpha, beta, bufferC, ldc_buf, c_block, ldc,
870                     co + co_stride, offsetc);
871         } else {
872             gemm_kernel(m, sizeN, k, alpha, bufferA, bufferB, beta, c_block,
873                     ldc, a_row_sum, b_col_sum, co + co_stride, offsetc, arg);
874         }
875     }
876 
877     free(mem);
878 
879     return dnnl_success;
880 }
881 
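// Heuristics deciding whether the no-copy (direct) sgemm kernels are likely
// to outperform the copy-based path, based on problem shape, leading
// dimensions and thread count; tuned separately for AVX2 and AVX-512.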
882 static inline bool nocopy_checker_avx2(const int nthr, const int transa,
883         const int transb, const dim_t m, const dim_t n, const dim_t k,
884         const dim_t lda, const dim_t ldb, const dim_t ldc) {
885     static const dim_t BM_NOCOPY_AVX2 = 64;
886     static const dim_t MN_NOCOPY_AVX2 = 128;
887     static const dim_t N_TRANSB_PER_THR = 1;
888     static const dim_t K_TRANSB_PER_THR = 1;
889     static const dim_t N_NOTRANSB_PER_THR = 16;
890     static const dim_t K_NOTRANSB_PER_THR = 2;
891     static const double FORCE_NOCOPY_THRESH = 0.0038;
892 
893     // Crude threshold for nocopy kernels if copy overhead is significant.
894     if (1.0 / m + 1.0 / n >= FORCE_NOCOPY_THRESH) { return true; }
895 
896     if (m <= 378 && n <= 378 && k >= nthr * 378) return false;
897 
898     if (m >= nthr * 378 && k >= nthr * 378) return false;
899 
900     if (transb == no_trans) {
901         if (m <= MN_NOCOPY_AVX2 && n <= MN_NOCOPY_AVX2) return true;
902         if (n <= nthr * N_NOTRANSB_PER_THR) return true;
903         if (k <= nthr * K_NOTRANSB_PER_THR) return true;
904         if (m <= BM_NOCOPY_AVX2 && n >= nthr * N_NOTRANSB_PER_THR) return true;
905     } else {
906         if (m <= MN_NOCOPY_AVX2 && n <= MN_NOCOPY_AVX2) return true;
907         if (n <= nthr * N_TRANSB_PER_THR) return true;
908         if (k <= nthr * K_TRANSB_PER_THR) return true;
909     }
910 
911     return false;
912 }
913 
914 static inline bool nocopy_checker_avx512(int nthr, const int transa,
915         const int transb, const dim_t m, const dim_t n, const dim_t k,
916         const dim_t lda, const dim_t ldb, const dim_t ldc) {
917     // Constants definition
918     static const dim_t BAD_LD_MULT = 256;
919     static const dim_t VERYBAD_LD_MULT = 1024;
920     static const dim_t M_TRANSB_PER_THR = 28;
921     static const dim_t N_TRANSB_PER_THR = 28;
922     static const dim_t K_TRANSB_PER_THR = 1;
923     static const dim_t MN_NOTRANSB_PER_THR = 28;
924     static const dim_t K_NOTRANSB_PER_THR = 1;
925     static const double FORCE_NOCOPY_THRESH = 0.00196;
926 
927     bool is_NN = transa == no_trans && transb == no_trans;
928     bool is_NT = transa == no_trans && transb == do_trans;
929     bool is_TN = transa == do_trans && transb == no_trans;
930 
931     bool is_lda_bad = lda % BAD_LD_MULT == 0;
932     bool is_ldb_bad = ldb % BAD_LD_MULT == 0;
933     bool is_ldc_bad = ldc % BAD_LD_MULT == 0;
934     bool is_ld_bad = is_lda_bad || is_ldb_bad || is_ldc_bad;
935 
936     bool is_lda_verybad = lda % VERYBAD_LD_MULT == 0;
937 
938     // Copy-based performs better for the TN case with small n in the sequential case.
939     if (nthr == 1 && is_TN && m > 100
940             && ((m < 1200 && n < 200 && k < 1200)
941                     || (is_lda_bad && is_ldb_bad)))
942         return false;
943 
944     // Copy-based performs better for NN case on very bad leading dimension if
945     // each thread has enough work.
946     if (nthr <= 8 && is_NN && is_lda_verybad && k > 500 && n > 100)
947         return false;
948 
949     // Crude threshold for nocopy kernels if copy overhead is significant.
950     if (1.0 / m + 1.0 / n >= FORCE_NOCOPY_THRESH
951             && !(is_lda_verybad && is_NT)) {
952         return true;
953     }
954 
955     // Copy strategy usually performs better than nocopy on "bad" leading
956     // dimensions.
957     if (is_ld_bad) {
958         bool use_copy_based = false;
959 
960         if (m >= 32 && n > 16) use_copy_based = true;
961 
962         // Nocopy outperforms copy-based in certain conditions.
963         if (m >= 32 && n == 16
964                 && (k >= 6400 || transa == do_trans || m == 4096))
965             use_copy_based = true;
966 
967         if (use_copy_based) return false;
968     }
969 
970     if (m <= 378 && n <= 378 && k >= nthr * 378) return false;
971 
972     if (m >= nthr * 378 && k >= nthr * 378) return false;
973 
974     if (transb == no_trans) {
975         if (m <= nthr * MN_NOTRANSB_PER_THR) return true;
976         if (n <= nthr * MN_NOTRANSB_PER_THR) return true;
977         if (k <= nthr * K_NOTRANSB_PER_THR) return true;
978     } else {
979         if (m <= nthr * M_TRANSB_PER_THR && m >= n) return true;
980         if (n <= nthr * N_TRANSB_PER_THR) return true;
981         if (k <= nthr * K_TRANSB_PER_THR) return true;
982     }
983     return false;
984 }
985 
986 template <typename a_type, typename b_type, typename c_type>
987 static inline bool nocopy_checker(
988         int nthr, const gemm_info_t<a_type, b_type, c_type> *arg) {
989 
990     if (data_traits<a_type>::data_type != data_type::f32) return false;
991 
992     if (!mayiuse(avx)) return false;
993 
994     if (arg->force_nocopy) return true;
995 
996     auto m = arg->m, n = arg->n, k = arg->k;
997     auto lda = arg->lda, ldb = arg->ldb, ldc = arg->ldc;
998     auto transa = arg->transa, transb = arg->transb;
999     auto packing = arg->packing;
1000 
1001     if (packing != pack_type::none) ldc = 64;
1002 
1003     if (arg->a_packed || arg->b_packed)
1004         return false;
1005     else if (mayiuse(avx512_core))
1006         return nocopy_checker_avx512(
1007                 nthr, transa, transb, m, n, k, lda, ldb, ldc);
1008     else
1009         return nocopy_checker_avx2(
1010                 nthr, transa, transb, m, n, k, lda, ldb, ldc);
1011 }
1012 
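// Choose the threading layout (partition type and per-dimension thread
// counts) for the copy-based path when no pack storage is used.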
1013 template <typename a_type, typename b_type, typename c_type>
1014 static inline void set_thread_opts_nopack(int nthrs, int nthrs_spawn,
1015         gemm_threading_t &thread_info,
1016         const gemm_info_t<a_type, b_type, c_type> *arg) {
1017 
1018     static constexpr dim_t N2D_MAX = 384;
1019     static constexpr dim_t M2D_MIN = 384;
1020 
1021     constexpr bool is_int8 = utils::one_of(
1022             data_traits<a_type>::data_type, data_type::s8, data_type::u8);
1023     bool isSgemm = data_traits<a_type>::data_type == data_type::f32;
1024 
1025     dim_t m = arg->m;
1026     dim_t n = arg->n;
1027     dim_t k = arg->k;
1028 
1029     thread_info.nthrs_m = 0;
1030     thread_info.nthrs_n = 0;
1031     thread_info.nthrs_k = 0;
1032     thread_info.copy = copy_type::nonshared;
1033     thread_info.partition = partition_type::row_1d;
1034 
1035     // TODO Check if we can use dynamic scheduling for sgemm.
1036     // TODO Check if we should use 3D blocking.
1037     thread_info.nthrs_k = 1;
1038     thread_info.thread_k = k;
1039 
1040     bool condition_2D_bsrc = false;
1041     if (isSgemm) {
1042         // If m is large and n is small then do 1D partitioning for AVX2.
1043         if (!mayiuse(avx512_core) && n <= N2D_MAX && (m >= nthrs * M2D_MIN))
1044             condition_2D_bsrc = false;
1045         else
1046             condition_2D_bsrc
1047                     = ((n > nthrs * N2D_MAX) || (n <= nthrs * N2D_MAX / 2))
1048                     && (m >= 2 * M2D_MIN);
1049     } else {
1050         int scale = mayiuse(avx512_core) ? nthrs : 20;
1051         condition_2D_bsrc = (256 * m > scale * n) && (scale * m < 256 * n);
1052     }
1053 
1054     // TODO Check if we should use k-partitioning.
1055 
1056     int condition_1D_copya = false;
1057     if (mayiuse(avx512_core)) {
1058         const dim_t thresh = isSgemm ? N2D_MAX / 4 : 68;
1059         if (m >= 1000 && (n >= nthrs * thresh)) {
1060             condition_2D_bsrc = false;
1061             condition_1D_copya = true;
1062         }
1063     } else {
1064         if (m >= 1000 && n >= 4000) {
1065             condition_2D_bsrc = false;
1066             condition_1D_copya = true;
1067         }
1068     }
1069 
1070     // If A or B offset is non-zero, we need to keep 1D_copya to reduce update
1071     // overhead.
1072     // TODO: the reason seems to be in the copy_sum_bx routines. At least,
1073     //       after a simple optimization of copy_sum_ax for avx512, a similar
1074     //       restriction on offset B became unnecessary. Revisit.
1075     if (is_int8 && arg->ao != 0 && (arg->bo != 0 || mayiuse(avx512_core))) {
1076         condition_2D_bsrc = false;
1077         condition_1D_copya = true;
1078     }
1079 
1080     if (condition_2D_bsrc) {
1081         int nthrs_m = 1;
1082         int nthrs_n = nthrs;
1083 
1084         if (isSgemm) {
1085             while ((nthrs_n % 2 == 0)
1086                     && (n / nthrs > N2D_MAX || n / nthrs_n <= N2D_MAX / 2)
1087                     && (m / nthrs_m >= 2 * M2D_MIN) && (nthrs_m < 4)) {
1088                 nthrs_m *= 2;
1089                 nthrs_n /= 2;
1090             }
1091 
1092             thread_info.nthrs_m = nthrs_m;
1093             thread_info.nthrs_n = nthrs_n;
1094             thread_info.partition = partition_type::col_major_2d;
1095         } else {
1096             if (m == 800 && n == 300) {
1097                 // TODO: Expand this branch to other problem sizes.
1098 
1099                 auto &thread_m = thread_info.thread_m;
1100                 auto &thread_n = thread_info.thread_n;
1101 
1102                 const dim_t block_m = arg->um * 4;
1103                 constexpr dim_t block_n = 64;
1104                 constexpr dim_t small_m = 16;
1105                 constexpr dim_t small_n = 2;
1106 
1107                 std::tie(nthrs_m, nthrs_n)
1108                         = gemm_utils::calc_nthr_2d(nthrs, m, n, block_m,
1109                                 block_n, small_m, small_n, thread_m, thread_n);
1110 
1111                 thread_info.nthrs_m = nthrs_m;
1112                 thread_info.nthrs_n = nthrs_n;
1113                 thread_info.partition = partition_type::mnk_3d;
1114 
1115             } else if ((n <= 64 || n >= 256)) {
1116                 while (((nthrs_n > 1) && (n / nthrs_n < arg->un)
1117                                && (m / nthrs_m >= 2 * arg->um)
1118                                && mayiuse(avx512_core))
1119                         || ((nthrs_n % 2 == 0)
1120                                 && (n / nthrs > N2D_MAX
1121                                         || n / nthrs_n <= N2D_MAX / 2)
1122                                 && (m / nthrs_m >= 2 * M2D_MIN)
1123                                 && (nthrs_m < 4))) {
1124                     nthrs_m *= 2;
1125                     nthrs_n /= 2;
1126                 }
1127 
1128                 thread_info.nthrs_m = nthrs_m;
1129                 thread_info.nthrs_n = nthrs_n;
1130                 thread_info.partition = partition_type::col_major_2d;
1131             } else {
1132                 // Use 3D decomposition from pack api without k-partitioning.
1133                 set_thread_opts_pack(nthrs, thread_info, arg, false);
1134             }
1135         }
1136 
1137     } else if (condition_1D_copya && dnnl_thr_syncable()) {
1138         // Use parallel copy A algorithm
1139         thread_info.copy = copy_type::shared_a;
1140         thread_info.partition = partition_type::col_1d;
1141         thread_info.nthrs_m = 1;
1142         thread_info.nthrs_n = nthrs_spawn; // Using all spawned threads.
1143     } else {
1144         auto veclen = get_vector_length<c_type>();
1145 
1146         if (m > n && (m >= nthrs * veclen || n < nthrs)) {
1147             if (n <= 20 && is_int8) {
1148                 // Use 3D decomposition forcing m-blocking only.
1149                 set_thread_opts_pack(
1150                         nthrs, thread_info, arg, false, true, false);
1151             } else {
1152                 thread_info.partition = partition_type::row_1d;
1153                 thread_info.nthrs_m = nthrs;
1154                 thread_info.nthrs_n = 1;
1155             }
1156         } else {
1157             thread_info.partition = partition_type::col_1d;
1158             thread_info.nthrs_m = 1;
1159             thread_info.nthrs_n = nthrs;
1160         }
1161     }
1162 }
1163 
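// Choose a 3D (m, n, k) thread decomposition and block sizes for the
// pack-based path; the do_*_blocking flags restrict which dimensions may be
// split.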
1164 template <typename a_type, typename b_type, typename c_type>
1165 static inline void set_thread_opts_pack(int nthrs,
1166         gemm_threading_t &thread_info,
1167         const gemm_info_t<a_type, b_type, c_type> *arg,
1168         bool do_k_blocking = true, bool do_m_blocking = true,
1169         bool do_n_blocking = true) {
1170 
1171     constexpr bool is_int8 = utils::one_of(
1172             data_traits<a_type>::data_type, data_type::s8, data_type::u8);
1173     constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16;
1174 
1175     bool do_m_blocking_only = do_m_blocking && !do_n_blocking;
1176 
1177     auto m = arg->m, n = arg->n, k = arg->k;
1178 
1179     auto &nthr_m = thread_info.nthrs_m;
1180     auto &nthr_n = thread_info.nthrs_n;
1181     auto &nthr_k = thread_info.nthrs_k;
1182     auto &thread_m = thread_info.thread_m;
1183     auto &thread_n = thread_info.thread_n;
1184     auto &thread_k = thread_info.thread_k;
1185     auto &block_m = thread_info.block_m;
1186     auto &block_n = thread_info.block_n;
1187     auto &block_k = thread_info.block_k;
1188 
1189     constexpr auto MBLK = 64;
1190     constexpr auto NBLK = 64;
1191     auto KBLK = is_int8 ? 3072 : 256;
1192     KBLK = do_m_blocking_only && is_int8 ? 384 : KBLK;
1193 
1194     nthr_m = nthr_n = nthr_k = 1;
1195     thread_info.copy = copy_type::nonshared;
1196     thread_info.partition = partition_type::mnk_3d;
1197 
1198     auto choose_blocking
1199             = [](dim_t size_z, dim_t &thread_z, int &nthr_z, dim_t block_z_init,
1200                       dim_t &block_z, dim_t block_align) {
1201                   thread_z = utils::div_up(size_z, nthr_z);
1202                   auto num_blk = utils::div_up(thread_z, block_z_init);
1203                   block_z = utils::div_up(thread_z, num_blk);
1204                   block_z = utils::rnd_up(block_z, block_align);
1205                   thread_z = num_blk * block_z;
1206                   if (thread_z * nthr_z > size_z)
1207                       nthr_z = utils::div_up(size_z, thread_z);
1208               };
1209 
1210     auto choose_m_blocking = [&]() {
1211         auto align = get_vector_length<c_type>();
1212         align = do_m_blocking_only ? arg->um : align;
1213         choose_blocking(m, thread_m, nthr_m, arg->bm, block_m, align);
1214     };
1215     auto choose_n_blocking = [&]() {
1216         choose_blocking(n, thread_n, nthr_n, arg->bn, block_n, arg->un);
1217     };
1218     auto choose_k_blocking = [&]() {
1219         auto align = nstl::max(arg->uk, dim_t(4));
1220         choose_blocking(k, thread_k, nthr_k, arg->bk, block_k, align);
1221     };
1222 
1223     // Choose k blocking.
1224     if ((m / MBLK + n / NBLK) < nthrs && do_k_blocking) {
1225         for (int nk = 1; nk <= 4 && k >= ((KBLK + 1) * nk); nk++)
1226             if (nthrs % nk == 0) nthr_k = nk;
1227 
1228         // Sacrifice one thread and try again if parallelism is too small in
1229         // n-dimension.
1230         if (nthr_k == 1 && nthrs > 1 && do_m_blocking_only) {
1231             nthrs--;
1232             for (int nk = 1; nk <= 4 && k >= ((KBLK + 1) * nk); nk++)
1233                 if (nthrs % nk == 0) nthr_k = nk;
1234         }
1235 
1236         // Allow up to 2 threads to be sacrificed for large k >> m, n.
1237         if (nthr_k < 4 && k >= m * 4 && k >= n * 4 && nthrs > 10 && is_bf16) {
1238             for (int nk = 1; nk <= 4 && k >= ((KBLK + 1) * nk); nk++)
1239                 if (nthrs % nk <= 2) nthr_k = nk;
1240         }
1241     }
1242 
1243     choose_k_blocking();
1244 
1245     // Choose m/n blocking.
1246     auto min_mblk = mayiuse(avx512_core) ? (MBLK / 2) : arg->um;
1247     min_mblk = do_m_blocking ? min_mblk : m;
1248     min_mblk = do_m_blocking_only ? arg->um : min_mblk;
1249     auto min_nblk = do_n_blocking ? NBLK / 2 : n;
1250 
1251     std::tie(nthr_m, nthr_n) = partition_2d_minblk(m, n, MBLK, NBLK, min_mblk,
1252             min_nblk, arg->um, arg->un, nthrs / nthr_k,
1253             do_m_blocking && do_n_blocking && do_k_blocking);
1254 
1255     auto nthr_m_init = nthr_m, nthr_n_init = nthr_n;
1256 
1257     choose_m_blocking();
1258     choose_n_blocking();
1259 
1260     if (is_int8 && do_m_blocking && do_n_blocking) {
1261         // If we lost a thread in one dimension because we padded the blocking
1262         // size, try to rebalance the other dimensions.
1263         if ((nthr_n != nthr_n_init)
1264                 && ((nthr_m + 1) * nthr_n * nthr_k <= nthrs)) {
1265             nthr_m++;
1266             choose_m_blocking();
1267         }
1268 
1269         if ((nthr_m != nthr_m_init)
1270                 && (nthr_m * (nthr_n + 1) * nthr_k <= nthrs)) {
1271             nthr_n++;
1272             choose_n_blocking();
1273         }
1274     }
1275 }
1276 
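// Top-level threading decision: use the no-copy partitioning when the
// no-copy checker fires, otherwise the pack or no-pack heuristics. Returns
// the number of threads the chosen partition will use.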
1277 template <typename a_type, typename b_type, typename c_type>
1278 static inline int set_thread_opts(int nthrs, int nthrs_spawn,
1279         gemm_threading_t &thread_info,
1280         const gemm_info_t<a_type, b_type, c_type> *arg) {
1281 
1282     thread_info.block_m = thread_info.block_n = thread_info.block_k = -1;
1283     thread_info.thread_m = thread_info.thread_n = thread_info.thread_k = -1;
1284 
1285     constexpr bool is_int8 = utils::one_of(
1286             data_traits<a_type>::data_type, data_type::s8, data_type::u8);
1287     constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16;
1288 
1289     if (nocopy_checker(nthrs, arg)) {
1290         thread_info.copy = copy_type::no_copy;
1291         thread_info.partition = partition_type::mnk_3d;
1292         int nthrs_m = 0;
1293         int nthrs_n = 0;
1294         int nthrs_k = 0;
1295         dim_t BM = 0;
1296         dim_t BN = 0;
1297         dim_t BK = 0;
1298         auto m = arg->m, n = arg->n, k = arg->k;
1299 
1300         if (mayiuse(avx512_core)) {
1301             cpu::gemm_utils::calc_nthr_nocopy_avx512_common(m, n, k, nthrs,
1302                     &nthrs_m, &nthrs_n, &nthrs_k, &BM, &BN, &BK);
1303         } else {
1304             cpu::gemm_utils::calc_nthr_nocopy_avx(m, n, k, nthrs, &nthrs_m,
1305                     &nthrs_n, &nthrs_k, &BM, &BN, &BK);
1306         }
1307 
1308         // Block information is being ignored. We will create partitioning
1309         // later.
1310         thread_info.nthrs_m = nthrs_m;
1311         thread_info.nthrs_n = nthrs_n;
1312         thread_info.nthrs_k = nthrs_k;
1313     } else {
1314         if (arg->packing != pack_type::none && (is_int8 || is_bf16))
1315             set_thread_opts_pack(nthrs, thread_info, arg);
1316         else
1317             set_thread_opts_nopack(nthrs, nthrs_spawn, thread_info, arg);
1318     }
1319 
1320     return thread_info.nthrs_m * thread_info.nthrs_n * thread_info.nthrs_k;
1321 }
1322 
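// Translate a thread's slice offsets into the corresponding A, B, C and
// C-offset pointers.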
1323 template <typename a_type, typename b_type, typename c_type>
1324 static inline std::tuple<const a_type *, const b_type *, c_type *,
1325         const c_type *>
1326 decompose_matrices(const gemm_slice_t &slice,
1327         const gemm_info_t<a_type, b_type, c_type> *arg) {
1328 
1329     dim_t stride_am = (arg->transa == no_trans) ? 1 : arg->lda;
1330     dim_t stride_ak = (arg->transa != no_trans) ? 1 : arg->lda;
1331     dim_t stride_bn = (arg->transb != no_trans) ? 1 : arg->ldb;
1332     dim_t stride_bk = (arg->transb == no_trans) ? 1 : arg->ldb;
1333 
1334     auto a = arg->a;
1335     auto b = arg->b;
1336     auto c = arg->c;
1337     if (a) a += slice.off_m * stride_am + slice.off_k * stride_ak;
1338     if (b) b += slice.off_n * stride_bn + slice.off_k * stride_bk;
1339     if (c) c += slice.off_m + slice.off_n * arg->ldc;
1340 
1341     dim_t co_stride;
1342     switch (arg->offsetc) {
1343         case offset_type::row: co_stride = slice.off_n; break;
1344         case offset_type::column: co_stride = slice.off_m; break;
1345         default: co_stride = 0; break;
1346     }
1347     auto co = arg->co;
1348     if (co) co += co_stride;
1349 
1350     return std::make_tuple(a, b, c, co);
1351 }
1352 
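// Parallel-copy-A driver: the master thread allocates a shared scratch buffer
// for the A panel (plus its row sums for integer GEMM); all threads then
// cooperatively copy bands of A into it and run the kernel driver on their
// share of the output columns.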
1353 template <typename a_type, typename b_type, typename c_type>
1354 static dnnl_status_t parallel_a_copy(const int ithr, const int nthrs,
1355         const dim_t m, const dim_t n, const dim_t k, const a_type *a,
1356         const b_type *b, float beta, c_type *c, dim_t ldc, offset_type offsetc,
1357         const c_type *co, const gemm_info_t<a_type, b_type, c_type> *arg,
1358         char **p_shared_mem) {
1359 
1360     if (arg->packing != pack_type::none)
1361         return gemm_packing_driver(ithr, m, n, k, a, b, arg);
1362 
1363     const dim_t lda = arg->lda;
1364     const dim_t ldb = arg->ldb;
1365     const dim_t strideAm = (arg->transa == no_trans) ? 1 : lda;
1366     const dim_t strideAn = (arg->transa != no_trans) ? 1 : lda;
1367     const dim_t strideBm = (arg->transb == no_trans) ? 1 : ldb;
1368 
1369     float alpha = arg->alpha;
1370 
1371     constexpr bool is_int8 = utils::one_of(
1372             data_traits<a_type>::data_type, data_type::s8, data_type::u8);
1373     constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16;
1374     bool is_int8_amx = is_int8 && mayiuse(avx512_core_bf16_amx_int8);
1375     bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_bf16_amx_bf16);
1376     bool is_amx = is_int8_amx || is_bf16_amx;
1377 
1378     const std::shared_ptr<const gemm_pack_storage_t> &a_packed = arg->a_packed;
1379 
    // Scale the C matrix.
    if (!is_int8 && beta != 1.0f && beta != 0.0f) {
        scale_matrix(m, n, beta, c, ldc);
        beta = 1.0f;
    }

    // Padding along the M and K dimensions.
    dim_t m_padd = get_m_padd_parallel_a(ithr, m, arg, nthrs);
    dim_t k_padd = get_k_padd(ithr, k, arg);

    size_t a_buf_nelems = m_padd * k_padd;

    // The A buffer needs more space due to zero-padding.
    if (is_amx)
        a_buf_nelems = utils::rnd_up(m_padd, arg->um)
                * utils::rnd_up(k_padd, arg->uk);
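    // The rounding above is an assumption about the AMX copy kernels: they
    // appear to pack A in whole um x uk blocks (zero-padding any tail), so
    // the buffer is sized for full blocks rather than for m_padd * k_padd.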

    // Allocate shared memory for A and its row-sum buffers in the master
    // thread.
    char *mem = nullptr;
    a_type *bufferA = nullptr;
    c_type *a_row_sum = nullptr;

    if (!a_packed) {
        if (ithr == 0) { // Master thread allocates the shared buffer.
            size_t mem_size = (a_buf_nelems * sizeof(*a) + PAGE_4K);

            if (is_int8) {
                size_t a_row_sum_nelems = m_padd;
                mem_size += a_row_sum_nelems * sizeof(*c) + PAGE_4K;
            }

            *p_shared_mem = (char *)malloc(mem_size, 128);
        }

        dnnl_thr_barrier();

        mem = *p_shared_mem;
        bufferA = (a_type *)align(mem, PAGE_4K);

        if (is_int8)
            a_row_sum = (c_type *)align(bufferA + a_buf_nelems, PAGE_4K);

        if (!mem) return dnnl_out_of_memory;
    }
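    // Sketch of the shared-A scheme implemented below (comment only): for
    // each (Bk, Bm) block, all participating threads cooperatively copy a
    // band of rows of A into the shared bufferA, synchronize on a barrier,
    // and then each thread runs its own B copy plus kernel on that block via
    // kernel_driver_parallel_acopiedbcopy(); a second barrier keeps bufferA
    // alive until every thread has finished computing with it.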

    dnnl_status_t result = dnnl_success; // Return status.

    dim_t sizeK = 0;
    dim_t blk_k = 0;
    for (dim_t Bk = 0; Bk < k; Bk += sizeK, blk_k++) {
        sizeK = k - Bk;
        if (sizeK > k_padd) sizeK = k_padd;

        // Scale C blocks by beta only for the first term of the partial sum.
        auto beta_eff = (Bk == 0) ? beta : 1.0f;

        // Apply the C offset only for the last k-block of the partial sum.
        auto offsetc_eff = offset_type::none;
        if (Bk + sizeK == k) offsetc_eff = offsetc;

        dim_t sizeM = 0;
        for (dim_t Bm = 0; Bm < m; Bm += sizeM) {
            sizeM = m - Bm;
            if (sizeM > m_padd) sizeM = m_padd;

            if ((ithr < nthrs) && !a_packed) {
                dim_t band = (sizeM + nthrs - 1) / nthrs;
                band = utils::rnd_up(band, arg->um);

                dim_t offset = band * ithr;

                // If the offset is too large, don't use this thread for
                // copying.
                if (offset >= sizeM) {
                    offset = 0;
                    band = 0;
                }

                // Handle the tail of the copy.
                if (offset + band > sizeM) { band = sizeM - offset; }

                if (band > 0) {
                    const a_type *a_block
                            = a + (Bm + offset) * strideAm + Bk * strideAn;

                    dim_t buf_shift = 0;
                    if (is_amx)
                        buf_shift = offset * utils::rnd_up(sizeK, arg->uk);
                    else
                        buf_shift = offset * sizeK;

                    /* Row sum argument is ignored for non-integer kernels and
                     * scaling factor is ignored by 8-bit and 16-bit copy
                     * kernels.
                     */
                    c_type *a_row_sum_eff
                            = a_row_sum ? a_row_sum + offset : nullptr;
                    arg->copyA(&sizeK, &band, a_block, &lda, &alpha,
                            bufferA + buf_shift, nullptr, nullptr,
                            a_row_sum_eff);
                }
            }
            if (!a_packed)
                dnnl_thr_barrier(); // Wait for the parallel copy to finish.

            const b_type *b_block = b + Bk * strideBm;
            c_type *c_block = c + Bm;

            dim_t co_stride = 0;
            if (offsetc_eff == offset_type::fixed) {
                co_stride = 0;
            } else if (offsetc_eff == offset_type::row) {
                co_stride = 0;
            } else if (offsetc_eff == offset_type::column) {
                co_stride = Bm;
            }

            auto bufferA_eff
                    = a_packed ? a_packed->matrix<a_type>(0, Bm, Bk) : bufferA;
            auto a_row_sum_eff = a_packed
                    ? a_packed->row_sums<c_type>(0, Bm, blk_k)
                    : a_row_sum;

            auto this_result = kernel_driver_parallel_acopiedbcopy(ithr, sizeM,
                    n, sizeK, blk_k, Bk, bufferA_eff, b_block, beta_eff,
                    c_block, offsetc_eff, co + co_stride, a_row_sum_eff, arg);

            if (this_result != dnnl_success) result = this_result;

            if (!a_packed)
                dnnl_thr_barrier(); // Wait for kernel computations to finish.
        }
    }

    // Free the memory allocated by the master thread.
    if (ithr == 0 && !a_packed) free(mem);

    return result;
}

template <typename T>
static inline void adjust_thread_count(dim_t m, dim_t n, dim_t k, int *nthrs) {

    const double omp_overhead_small_core = 3.0e+3;
    const double omp_intercept_big_core = 4.0e+3;
    const double omp_slope_big_core = 5.0e+2;

    auto veclen = get_vector_length<T>();
    const double fp_per_cycle = 2.0 * 2.0 * veclen;
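    // Rough peak-throughput estimate: two FMA pipes times two flops (multiply
    // plus add) per vector lane. For example, with f32 on avx512_core the
    // vector length is 16 lanes, giving an assumed 2 * 2 * 16 = 64 flops per
    // cycle; the constants above are model parameters, not measurements.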

    const bool is_f32 = data_traits<T>::data_type == data_type::f32;

    const bool is_avx512_mic = mayiuse(avx512_mic);
    const bool is_avx512 = mayiuse(avx512_core);
    const bool is_avx = mayiuse(avx);
    const bool is_only_avx2 = mayiuse(avx2) && !is_avx512;

    if (is_avx512_mic) return;

    // Some sgemm cases still benefit from using all threads.
    const bool use_all_threads = is_f32 && n > 50
            && ((is_avx && m <= 3) || (is_avx512 && m <= 10));
    if (use_all_threads) return;

    if (is_only_avx2)
        if (m > 10 * n && n < *nthrs)
            if (m / *nthrs < veclen * 3)
                *nthrs = nstl::max(m / veclen / 3, dim_t(1));

    double gemm_cycles = m * n * k / fp_per_cycle;
    gemm_cycles *= is_f32 ? 2.0 : 8.0;

    int i = *nthrs;
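    // Break-even heuristic used below (a sketch, not an exact model): keeping
    // i threads is accepted once
    //     omp_cycles * i < gemm_cycles * (i - 1),
    // which is equivalent to gemm_cycles / i + omp_cycles < gemm_cycles, i.e.
    // the estimated parallel time (work split i ways plus per-thread overhead)
    // still beats running the whole GEMM on a single thread.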

    // Use a different model for OMP overheads if nthrs is <= 4.
    if (*nthrs <= 4 && omp_overhead_small_core > 0) {
        double omp_cycles = omp_overhead_small_core;
        if (gemm_cycles < omp_cycles) {
            *nthrs = 1;
            return;
        } else {
            while (i > 1) {
                if (omp_cycles * i < gemm_cycles * (i - 1)) break;
                --i;
            }
        }
    } else {
        if (gemm_cycles < (omp_intercept_big_core + 2 * omp_slope_big_core)) {
            *nthrs = 1;
            return;
        }

        // Adaptive decrement to converge faster.
        while (i > 1) {
            double omp_cycles = omp_intercept_big_core + i * omp_slope_big_core;
            if (omp_cycles * i < gemm_cycles * (i - 1)) break;

            if (i < 10)
                i -= 2;
            else if (i < 30)
                i -= 4;
            else
                i -= 8;
        }
    }

    if (i < 1) i = 1;

    *nthrs = i;
}

template <typename a_type, typename b_type, typename c_type>
static dnnl_status_t call_no_copy_sgemm(
        int nthrs, gemm_info_t<a_type, b_type, c_type> *arg) {

    if (arg->packing == pack_type::none) {
        auto transa_char = (arg->transa != do_trans) ? "N" : "T";
        auto transb_char = (arg->transb != do_trans) ? "N" : "T";

        if (mayiuse(avx512_core))
            return jit_avx512_common_gemm_f32(nthrs, transa_char, transb_char,
                    &arg->m, &arg->n, &arg->k, &arg->alpha, (float *)arg->a,
                    &arg->lda, (float *)arg->b, &arg->ldb, &arg->beta,
                    (float *)arg->c, &arg->ldc, (float *)arg->co);
        else
            return jit_avx_gemm_f32(nthrs, transa_char, transb_char, &arg->m,
                    &arg->n, &arg->k, &arg->alpha, (float *)arg->a, &arg->lda,
                    (float *)arg->b, &arg->ldb, &arg->beta, (float *)arg->c,
                    &arg->ldc, (float *)arg->co);
    } else
        return pack_no_copy(arg);
}

template <typename a_type, typename b_type, typename c_type>
static dnnl_status_t gemm_threading_driver(
        gemm_info_t<a_type, b_type, c_type> *arg) {

    auto packing = (arg->packing != pack_type::none);
    auto is_a_packed = (arg->transa == packed);
    auto is_b_packed = (arg->transb == packed);
    constexpr bool is_int8 = utils::one_of(
            data_traits<a_type>::data_type, data_type::s8, data_type::u8);
    constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16;

    if ((arg->m <= 0) || (arg->n <= 0)) return dnnl_success;

    if (!is_a_packed && !is_b_packed && jump_to_gemv_s8x8s32(arg))
        return dnnl_success;

    if (!is_a_packed && !is_b_packed
            && jump_to_gemm_smalln_tn(arg) == dnnl_success)
        return dnnl_success;

    if (!is_a_packed && !is_b_packed && jump_to_gemv(arg) == dnnl_success)
        return dnnl_success;

    if (is_a_packed && arg->bo != 0)
        if (!arg->a_packed->has_row_sums()) return dnnl_invalid_arguments;

    if (is_b_packed && arg->ao != 0)
        if (!arg->b_packed->has_col_sums()) return dnnl_invalid_arguments;

    auto nthr_max = dnnl_get_current_num_threads();
    int nthr_goal = nthr_max;

    adjust_thread_count<c_type>(arg->m, arg->n, arg->k, &nthr_goal);

    const gemm_threading_t *force_threading = nullptr;
    gemm_threading_t force_k_decomp;

    // Initialize per-thread data.
    // Note: to support k blocking with non-packed GEMM, threading must be
    //   chosen now and force_threading set.
    if (!packing) {
        // Override choice of thread count if data is pre-packed for a
        //  particular number of threads.
        if (is_a_packed && is_b_packed)
            if (arg->a_packed->threading() != arg->b_packed->threading())
                return dnnl_invalid_arguments;
        if (is_a_packed)
            force_threading = &arg->a_packed->threading();
        else if (is_b_packed)
            force_threading = &arg->b_packed->threading();
        else if (arg->m <= 768 && arg->n <= 768 && arg->k >= 2048 && is_bf16) {
            // Try k-partitioning.
            set_thread_opts_pack(nthr_goal, force_k_decomp, arg);

            // Decide partition type later if no partitions in k-dimension.
            if (force_k_decomp.nthrs_k > 1) force_threading = &force_k_decomp;
        } else if (arg->n <= 128 && arg->k >= 3072 && is_int8) {
            // Use k-partitioning if necessary.
            // Use 3D decomposition from pack api without n-partitioning.
            set_thread_opts_pack(
                    nthr_goal, force_k_decomp, arg, true, true, false);

            // Decide partition type later if no partitions in k-dimension.
            if (force_k_decomp.nthrs_k > 1 && force_k_decomp.nthrs_m > 1)
                force_threading = &force_k_decomp;
        }

        if (force_threading) {
            nthr_goal = force_threading->nthrs();
            arg->update_blocking(*force_threading);
        }
    } else {
        // Prepare packed data layout.
        gemm_pack_storage_t *pack_dst = arg->pack_dst;
        bool do_a = (arg->packing == pack_type::pack_a);

        pack_dst->which() = do_a ? matrix_id::a : matrix_id::b;
        pack_dst->setup(nthr_goal, do_a && is_int8, !do_a && is_int8);

        auto &thread_info = pack_dst->threading();
        force_threading = &thread_info;

        nthr_goal = set_thread_opts(nthr_goal, nthr_max, thread_info, arg);
        arg->update_blocking(thread_info);

        if (thread_info.copy != copy_type::no_copy) {
            for (int ithr = 0; ithr < nthr_goal; ithr++) {
                if (!pack_dst->is_first_thread_in_slice(ithr)) continue;

                auto slice = thread_info.get_thread_slice(
                        ithr, arg->m, arg->n, arg->k);

                auto m = slice.m, n = slice.n, k = slice.k;

                auto m_padd = (thread_info.copy == copy_type::shared_a)
                        ? get_m_padd_parallel_a(
                                ithr, m, arg, thread_info.nthrs())
                        : get_m_padd(ithr, m, arg);
                auto n_padd = get_n_padd(ithr, n, k, arg);
                auto k_padd = get_k_padd(ithr, k, arg);

                do_a ? pack_dst->set_blocking(ithr, m, k, m_padd, k_padd)
                     : pack_dst->set_blocking(ithr, k, n, k_padd, n_padd);
            }
        } else {
            auto ld = do_a ? gemm_utils::get_ld_padd<a_type>(arg->m)
                           : gemm_utils::get_ld_padd<b_type>(arg->k);

            pack_dst->set_nocopy(0, no_trans, ld, do_a ? arg->k : arg->n);
        }

        do_a ? pack_dst->finalize<a_type, c_type>()
             : pack_dst->finalize<b_type, c_type>();

        if (arg->measure_only) return dnnl_success;
    }

    if (nocopy_checker(nthr_goal, arg))
        return call_no_copy_sgemm(nthr_goal, arg);

    if (nthr_goal == 1)
        return gemm_kernel_driver(0, arg->m, arg->n, arg->k, arg->a, arg->b,
                arg->beta, arg->c, arg->ldc, arg->offsetc, arg->co, arg);

    bool k_blocking = force_threading && (force_threading->nthrs_k > 1);
    bool k_summing = k_blocking && !packing;
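    // Comment-only sketch of k summing: when the k dimension is split across
    // threads, every thread that is not the first in its k-chain computes its
    // partial product into a private C buffer (c_local, with beta = 0 and no
    // C offset), and the partial results are later reduced into the global C
    // by sum_k_blocks(), either inside the parallel region or afterwards.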

    auto *thread_arg = (gemm_per_thread_t<c_type> *)malloc(
            sizeof(gemm_per_thread_t<c_type>) * nthr_max, PAGE_4K);

    if (!thread_arg) return dnnl_out_of_memory;

    dim_t max_mt = 0, max_nt = 0;
    for (int ithr = 0; ithr < nthr_max; ithr++) {
        thread_arg[ithr].result = dnnl_success;
        thread_arg[ithr].compute_done = false;
        thread_arg[ithr].c_local = thread_arg[ithr].c_global = nullptr;
        thread_arg[ithr].ldc_global = arg->ldc;
        thread_arg[ithr].ldc_local = 0;

        if (force_threading) {
            thread_arg[ithr].slice = force_threading->get_thread_slice(
                    ithr, arg->m, arg->n, arg->k);
            thread_arg[ithr].nthr_k = force_threading->nthrs_k;
            thread_arg[ithr].thr_k_stride = force_threading->thr_k_stride();
            max_mt = nstl::max(max_mt, thread_arg[ithr].slice.m);
            max_nt = nstl::max(max_nt, thread_arg[ithr].slice.n);
        } else {
            thread_arg[ithr].slice = {0, 0, 0, 0, 0, 0, 0, 0, 0};
            thread_arg[ithr].nthr_k = 1;
            thread_arg[ithr].thr_k_stride = 0;
        }
    }

    // Create temporary C buffers for k blocking if needed.
    c_type *c_local_storage = nullptr;
    if (k_summing) {
        const dim_t BAD_LD_MULT = 256;
        dim_t ldc_local = max_mt % BAD_LD_MULT
                ? max_mt
                : gemm_utils::get_ld_padd<c_type>(max_mt);
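        // Assumption behind BAD_LD_MULT: a leading dimension that is a
        // multiple of 256 elements tends to map consecutive columns onto the
        // same cache sets, so such strides are padded via get_ld_padd() to
        // avoid that aliasing.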
        dim_t c_local_stride = ldc_local * max_nt;
        c_local_storage = (c_type *)malloc(
                sizeof(c_type) * c_local_stride * nthr_goal, PAGE_4K);

        if (!c_local_storage) {
            free(thread_arg);
            return dnnl_out_of_memory;
        }

        for (int ithr = 0; ithr < nthr_goal; ithr++) {
            thread_arg[ithr].c_local = c_local_storage + ithr * c_local_stride;
            thread_arg[ithr].ldc_local = ldc_local;
        }
    }

    char *shared_mem = nullptr;

    // Always use the maximum number of threads to avoid OMP overhead that can
    // occur due to changing thread counts.
    int nthr_spawn = dnnl_thr_syncable() ? nthr_max : nthr_goal;

    parallel(nthr_spawn, [&](int ithr, int nthr) {
        int nthr_eff = force_threading ? nthr_goal : nstl::min(nthr_goal, nthr);

        if (nthr_eff == 1) {
            thread_arg[0].result = gemm_kernel_driver(0, arg->m, arg->n, arg->k,
                    arg->a, arg->b, arg->beta, arg->c, arg->ldc, arg->offsetc,
                    arg->co, arg);
        } else {
            gemm_threading_t thread_info;

            if (force_threading)
                thread_info = *force_threading;
            else {
                nthr_eff = set_thread_opts(nthr_eff, nthr, thread_info, arg);
                if (ithr < nthr_eff)
                    thread_arg[ithr].slice = thread_info.get_thread_slice(
                            ithr, arg->m, arg->n, arg->k);
            }

            for (; ithr < nthr_eff; ithr += nthr) {
                // Get submatrices and parameters for this thread's GEMM.
                const a_type *a = nullptr;
                const b_type *b = nullptr;
                c_type *c = nullptr;
                const c_type *co = nullptr;
                std::tie(a, b, c, co)
                        = decompose_matrices(thread_arg[ithr].slice, arg);

                auto m = thread_arg[ithr].slice.m;
                auto n = thread_arg[ithr].slice.n;
                auto k = thread_arg[ithr].slice.k;
                thread_arg[ithr].c_global = c;
                auto c_eff = c;
                auto ldc_eff = arg->ldc;
                auto beta_eff = arg->beta;
                auto offsetc_eff = arg->offsetc;

                // For all but the first k block: substitute the local C matrix
                // and disable postops.
                if (k_summing && thread_arg[ithr].slice.ithr_k > 0) {
                    c_eff = thread_arg[ithr].c_local;
                    ldc_eff = thread_arg[ithr].ldc_local;
                    beta_eff = 0;
                    offsetc_eff = offset_type::none;
                }

                // Dispatch the appropriate GEMM driver.
                switch (thread_info.copy) {
                    case copy_type::shared_a:
                        thread_arg[ithr].result = parallel_a_copy(ithr,
                                nthr_eff, m, n, k, a, b, beta_eff, c_eff,
                                ldc_eff, offsetc_eff, co, arg, &shared_mem);
                        break;

                    default:
                    case copy_type::nonshared:
                        thread_arg[ithr].result = gemm_kernel_driver(ithr, m, n,
                                k, a, b, beta_eff, c_eff, ldc_eff, offsetc_eff,
                                co, arg);
                        break;

                    case copy_type::no_copy:
                        // This route is taken only if we realize we need
                        // no-copy after launching the parallel section,
                        // because fewer threads were spawned than expected.
                        assert(data_traits<a_type>::data_type
                                == data_type::f32);
                        assert(arg->packing == pack_type::none);

                        if (mayiuse(avx512_core)) {
                            avx512_common_gemm_f32::sgemm_nocopy_driver(
                                    arg->transa == no_trans ? "N" : "T",
                                    arg->transb == no_trans ? "N" : "T", m, n,
                                    k, &arg->alpha, (float *)a, arg->lda,
                                    (float *)b, arg->ldb, &beta_eff,
                                    (float *)c_eff, ldc_eff, nullptr, nullptr);
                        } else {
                            avx_gemm_f32::sgemm_nocopy_driver(
                                    arg->transa == no_trans ? "N" : "T",
                                    arg->transb == no_trans ? "N" : "T", m, n,
                                    k, &arg->alpha, (float *)a, arg->lda,
                                    (float *)b, arg->ldb, &beta_eff,
                                    (float *)c_eff, ldc_eff, nullptr, nullptr);
                        }
                        thread_arg[ithr].result = dnnl_success;
                        break;
                }

                // Sum thread results along the k dimension, parallelized in
                // the n dimension. To avoid deadlocks, results are summed
                // later if not all threads are running concurrently. We can
                // only detect if this is safe when using OpenMP.
#if DNNL_THR_SYNC == 1
                if (k_summing && (nthr >= nthr_eff)) {
                    thread_arg[ithr].compute_done = true;
                    sum_k_blocks(ithr, thread_arg, true);
                }
#endif
            }
        }
    });

    dnnl_status_t result = dnnl_success; // Initialize to success
    for (int ithr = 0; ithr < nthr_max; ithr++) {
        if (thread_arg[ithr].result != dnnl_success) {
            result = static_cast<dnnl_status_t>(thread_arg[ithr].result);
            break;
        }
    }

    // Sum thread results along k dimension if this wasn't done earlier.
    if (k_summing && !thread_arg[0].compute_done) {
        parallel(nthr_goal, [&](int ithr, int nthr) {
            for (; ithr < nthr_goal; ithr += nthr)
                sum_k_blocks(ithr, thread_arg, false);
        });
    }

    if (c_local_storage) dnnl::impl::free(c_local_storage);
    dnnl::impl::free(thread_arg);

    return result;
}

template <typename a_type, typename b_type, typename c_type>
dnnl_status_t gemm_driver(const char *transA, const char *transB,
        const char *offsetC, const dim_t *m, const dim_t *n, const dim_t *k,
        const float *alpha, const a_type *a, const dim_t *lda, const a_type *oa,
        const b_type *b, const dim_t *ldb, const b_type *ob, const float *beta,
        c_type *c, const dim_t *ldc, const c_type *oc, const bool force_nocopy,
        pack_type packing, gemm_pack_storage_t *pack_dst, bool measure_only) {

    constexpr bool is_int8 = utils::one_of(
            data_traits<a_type>::data_type, data_type::s8, data_type::u8);
    MAYBE_UNUSED(is_int8);

    // gemm_driver supports bfloat16 gemm for Intel AVX512 and
    // Intel AVX512 BF16.
    assert(IMPLICATION(data_traits<a_type>::data_type == data_type::bf16,
            mayiuse(avx512_core) && !force_nocopy));

    // gemm_driver supports 8-bit integer gemm for Intel AVX512, Intel AVX2,
    // Intel AVX, Intel SSE4.1, and Intel DL Boost.
    assert(IMPLICATION(is_int8, mayiuse(sse41) && !mayiuse(avx512_mic)));

    // gemm_driver supports sgemm for Intel AVX512, Intel AVX2, Intel AVX,
    // and Intel SSE4.1.
    assert(IMPLICATION(
            data_traits<a_type>::data_type == data_type::f32, mayiuse(sse41)));

    // 8-bit integer gemm doesn't support nocopy kernels.
    assert(IMPLICATION(is_int8, !force_nocopy));

    // gemm_driver can only dispatch nocopy for avx and above.
    assert(IMPLICATION(force_nocopy, mayiuse(avx)));

    gemm_info_t<a_type, b_type, c_type> args(transA, transB, offsetC, m, n, k,
            alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc, force_nocopy,
            packing, pack_dst, measure_only);

    // Check if copy algorithm kernels were generated on supported ISAs.
    if (!args.hasKernels()) return dnnl_unimplemented;

    return gemm_threading_driver(&args);
}

template // Instantiate gemm_bf16bf16f32
        dnnl_status_t
        gemm_driver<bfloat16_t, bfloat16_t, float>(const char *transA,
                const char *transB, const char *offsetC, const dim_t *m,
                const dim_t *n, const dim_t *k, const float *alpha,
                const bfloat16_t *a, const dim_t *lda, const bfloat16_t *oa,
                const bfloat16_t *b, const dim_t *ldb, const bfloat16_t *ob,
                const float *beta, float *c, const dim_t *ldc, const float *oc,
                const bool force_nocopy, pack_type packing,
                gemm_pack_storage_t *pack_dst, bool measure_only);

template // Instantiate gemm_s8s8s32
        dnnl_status_t
        gemm_driver<int8_t, int8_t, int32_t>(const char *transA,
                const char *transB, const char *offsetC, const dim_t *m,
                const dim_t *n, const dim_t *k, const float *alpha,
                const int8_t *a, const dim_t *lda, const int8_t *oa,
                const int8_t *b, const dim_t *ldb, const int8_t *ob,
                const float *beta, int32_t *c, const dim_t *ldc,
                const int32_t *oc, const bool force_nocopy, pack_type packing,
                gemm_pack_storage_t *pack_dst, bool measure_only);

template // Instantiate gemm_s8u8s32
        dnnl_status_t
        gemm_driver<int8_t, uint8_t, int32_t>(const char *transA,
                const char *transB, const char *offsetC, const dim_t *m,
                const dim_t *n, const dim_t *k, const float *alpha,
                const int8_t *a, const dim_t *lda, const int8_t *oa,
                const uint8_t *b, const dim_t *ldb, const uint8_t *ob,
                const float *beta, int32_t *c, const dim_t *ldc,
                const int32_t *oc, const bool force_nocopy, pack_type packing,
                gemm_pack_storage_t *pack_dst, bool measure_only);

template // Instantiate sgemm
        dnnl_status_t
        gemm_driver<float, float, float>(const char *transA, const char *transB,
                const char *offsetC, const dim_t *m, const dim_t *n,
                const dim_t *k, const float *alpha, const float *a,
                const dim_t *lda, const float *oa, const float *b,
                const dim_t *ldb, const float *ob, const float *beta, float *c,
                const dim_t *ldc, const float *oc, const bool force_nocopy,
                pack_type packing, gemm_pack_storage_t *pack_dst,
                bool measure_only);

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl