1 /*******************************************************************************
2 * Copyright 2018-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <cstdint>
18 #if defined(_MSC_VER)
19 #include <malloc.h>
20 #endif
21
22 #include "oneapi/dnnl/dnnl_types.h"
23
24 #include "common/bfloat16.hpp"
25 #include "common/dnnl_traits.hpp"
26 #include "common/nstl.hpp"
27 #include "common/utils.hpp"
28
29 #include "cpu/platform.hpp"
30
31 #include "cpu/gemm/f32/gemm_utils_f32.hpp"
32 #include "cpu/gemm/gemm_msan_unpoison.hpp"
33
34 #include "cpu/x64/jit_generator.hpp"
35
36 #include "cpu/x64/gemm/gemm_driver.hpp"
37 #include "cpu/x64/gemm/gemm_info.hpp"
38 #include "cpu/x64/gemm/gemm_partition.hpp"
39 #include "cpu/x64/gemm/gemm_threading.hpp"
40 #include "cpu/x64/gemm/gemm_utils.hpp"
41 #include "cpu/x64/gemm/gemv_driver.hpp"
42
43 #include "cpu/x64/gemm/f32/jit_avx512_common_gemm_f32.hpp"
44 #include "cpu/x64/gemm/f32/jit_avx512_core_gemm_smalln_tn_f32_kern.hpp"
45 #include "cpu/x64/gemm/f32/jit_avx_gemm_f32.hpp"
46
47 #include "cpu/x64/gemm/s8x8s32/jit_avx512_core_gemv_s8x8s32.hpp"
48
49 namespace dnnl {
50 namespace impl {
51 namespace cpu {
52 namespace x64 {
53
54 template <typename c_type>
55 struct alignas(64) gemm_per_thread_t {
56 volatile int32_t result;
57 volatile int32_t compute_done;
58 int32_t thr_k_stride;
59 int32_t nthr_k;
60 dim_t ldc_local;
61 dim_t ldc_global;
62 c_type *c_local;
63 c_type *volatile c_global;
64 gemm_slice_t slice;
65 };
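// Per-thread bookkeeping used when the k dimension is split across threads:
// each thread accumulates into its own c_local and later reduces into the
// shared c_global (see sum_k_blocks() below). The struct is aligned to a
// cache line (alignas(64)) so the volatile flags used for cross-thread
// synchronization do not cause false sharing.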
66
67 template <typename T>
68 int get_vector_length() {
69 int v_bytes;
70
71 if (mayiuse(avx512_core))
72 v_bytes = cpu_isa_traits<avx512_core>::vlen;
73 else if (mayiuse(avx))
74 v_bytes = cpu_isa_traits<avx>::vlen;
75 else
76 v_bytes = cpu_isa_traits<sse41>::vlen;
77
78 return v_bytes / sizeof(T);
79 }
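// Returns the SIMD width in elements of T for the best available ISA.
// For example, with avx512_core the vector length is 64 bytes, so
// get_vector_length<float>() evaluates to 64 / sizeof(float) = 16.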
80
81 template <typename c_type>
82 static inline void round_to_nearest(c_type *rounded_val, double fp_val) {
83 if (fp_val >= 0.) {
84 fp_val += 0.5;
85 if (fp_val > INT32_MAX) { fp_val = INT32_MAX; }
86 } else {
87 fp_val -= 0.5;
88 if (fp_val < INT32_MIN) { fp_val = INT32_MIN; }
89 }
90 *rounded_val = (c_type)fp_val;
91 }
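// Rounds half away from zero and saturates to the int32 range before the
// final cast; used by the integer GEMM paths below whenever alpha/beta
// scaling has to go through double precision.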
92
93 template <typename c_type>
94 static inline void add_results(const dim_t m, const dim_t n, const float alpha,
95 const float beta, const c_type *c_partial_sum, const dim_t ldcp,
96 c_type *c_data, const dim_t ldc, const c_type *co,
97 offset_type offsetc) {
98
99 constexpr bool is_int8 = data_traits<c_type>::data_type == data_type::s32;
100
101 for (dim_t j = 0; j < n; ++j) {
102 for (dim_t i = 0; i < m; ++i) {
103 c_type ctemp = c_partial_sum[i + j * ldcp];
104
105 if (alpha == 1.0f) {
106 if (beta == 0.0f) {
107 c_data[i + j * ldc] = ctemp;
108 } else {
109 if (is_int8) {
110 double c_float
111 = (double)beta * (double)c_data[i + j * ldc];
112 c_float += (double)ctemp;
113 round_to_nearest(&c_data[i + j * ldc], c_float);
114 } else {
115 c_data[i + j * ldc] *= beta;
116 c_data[i + j * ldc] += ctemp;
117 }
118 }
119 } else if (alpha == -1.0f) {
120 if (beta == 0.0f) {
121 c_data[i + j * ldc] = -ctemp;
122 } else {
123 if (is_int8) {
124 double c_float
125 = (double)beta * (double)c_data[i + j * ldc];
126 c_float -= (double)ctemp;
127 round_to_nearest(&c_data[i + j * ldc], c_float);
128 } else {
129 c_data[i + j * ldc] *= beta;
130 c_data[i + j * ldc] -= ctemp;
131 }
132 }
133 } else {
134 if (beta == 0.0f) {
135 if (is_int8) {
136 double c_float = alpha * (double)ctemp;
137 round_to_nearest(&c_data[i + j * ldc], c_float);
138 } else {
139 c_data[i + j * ldc] = alpha * ctemp;
140 }
141
142 } else {
143 if (is_int8) {
144 double c_float = alpha * (double)ctemp
145 + beta * (double)c_data[i + j * ldc];
146 round_to_nearest(&c_data[i + j * ldc], c_float);
147 } else {
148 c_data[i + j * ldc] *= beta;
149 c_data[i + j * ldc] += alpha * ctemp;
150 }
151 }
152 }
153
154 if (offsetc == offset_type::fixed) {
155 c_data[i + j * ldc] += co[0];
156 } else if (offsetc == offset_type::row) {
157 c_data[i + j * ldc] += co[j];
158 } else if (offsetc == offset_type::column) {
159 c_data[i + j * ldc] += co[i];
160 }
161 }
162 }
163 }
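// add_results() folds a partial-sum buffer (c_partial_sum, leading dimension
// ldcp) into the output C with the requested alpha/beta, then applies the C
// offset according to offsetc. Conceptually, per element:
//   C[i,j] = alpha * partial[i,j] + beta * C[i,j] + co[...]
// where the int8 path evaluates the expression in double and rounds, while
// the floating-point path updates C in place.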
164
165 template <typename a_type, typename b_type, typename c_type>
166 static inline dim_t get_k_padd(
167 int ithr, dim_t k, const gemm_info_t<a_type, b_type, c_type> *arg) {
168 if (arg->a_packed) {
169 dim_t block_m, block_k;
170 arg->a_packed->get_blocking(ithr, block_m, block_k);
171 return block_k;
172 } else if (arg->b_packed) {
173 dim_t block_n, block_k;
174 arg->b_packed->get_blocking(ithr, block_k, block_n);
175 return block_k;
176 } else {
177 dim_t k_padd = 0;
178
179 if (k <= arg->bk_traditional) {
180 k_padd = utils::rnd_up(k, arg->uk);
181 k_padd = nstl::max(dim_t(128), k_padd);
182 } else if (k < 2 * arg->bk)
183 k_padd = utils::rnd_up((k + 1) / 2, arg->uk);
184 else
185 k_padd = arg->bk;
186
187 return k_padd;
188 }
189 }
190
191 template <typename a_type, typename b_type, typename c_type>
192 static inline dim_t get_m_padd(
193 int ithr, dim_t m, const gemm_info_t<a_type, b_type, c_type> *arg) {
194 if (arg->a_packed) {
195 dim_t block_m, block_k;
196 arg->a_packed->get_blocking(ithr, block_m, block_k);
197 return block_m;
198 } else
199 return utils::rnd_up(
200 nstl::min(nstl::max(m, arg->um), arg->bm), arg->um);
201 }
202
203 template <typename a_type, typename b_type, typename c_type>
204 static inline dim_t get_m_padd_parallel_a(int ithr, dim_t m,
205 const gemm_info_t<a_type, b_type, c_type> *arg, int nthrs) {
206 auto m_padd = get_m_padd(ithr, m, arg);
207
208 if (!arg->a_packed) {
209 constexpr auto multiplier = 10;
210
211 m_padd *= nstl::max(nthrs, multiplier);
212 if (m_padd > m) m_padd = utils::rnd_up(m, arg->um);
213 }
214
215 return m_padd;
216 }
217
218 template <typename a_type, typename b_type, typename c_type>
219 static inline dim_t get_n_padd(int ithr, dim_t n, dim_t k,
220 const gemm_info_t<a_type, b_type, c_type> *arg) {
221 if (arg->b_packed) {
222 dim_t block_n, block_k;
223 arg->b_packed->get_blocking(ithr, block_k, block_n);
224 return block_n;
225 } else {
226 auto bn = (k < arg->blocking_small_k) ? arg->bn_small_k : arg->bn;
227 return utils::rnd_up(nstl::min(nstl::max(n, arg->un), bn), arg->un);
228 }
229 }
230
231 static inline void *align(void *ptr, size_t alignment) {
232 return (void *)utils::rnd_up((uintptr_t)ptr, alignment);
233 }
234
235 template <typename scale_t, typename mat_t>
236 void scale_matrix(
237 dim_t m, dim_t n, scale_t alpha, mat_t *__restrict p_mat, dim_t ld) {
238 if (data_traits<mat_t>::data_type == data_type::f32) {
239 for (dim_t j = 0; j < n; j++) {
240 for (dim_t i = 0; i < m; i++) {
241 p_mat[i + j * ld] = (mat_t)((scale_t)p_mat[i + j * ld] * alpha);
242 }
243 }
244 }
245 }
246
247 template <typename mat_t>
248 static void sum_matrices(dim_t m, dim_t n, mat_t *__restrict dst, dim_t ld_dst,
249 mat_t *__restrict src, dim_t ld_src) {
250
251 for (dim_t j = 0; j < n; j++) {
252 PRAGMA_OMP_SIMD()
253 for (int i = 0; i < m; i++)
254 dst[i + j * ld_dst] += src[i + j * ld_src];
255 }
256 }
257
258 template <typename c_type>
259 static void sum_k_blocks(
260 int ithr, gemm_per_thread_t<c_type> *thread_arg, bool wait) {
261
262 auto m = thread_arg[ithr].slice.m;
263 auto n = thread_arg[ithr].slice.n;
264 auto ithr_k = thread_arg[ithr].slice.ithr_k;
265 auto nthr_k = thread_arg[ithr].nthr_k;
266 auto stride = thread_arg[ithr].thr_k_stride;
267 dim_t n0, nn;
268
269 partition_1d(ithr_k, nthr_k, n, n0, nn);
270
271 auto get_thread_arg = [&](int thr_k) -> gemm_per_thread_t<c_type> & {
272 return thread_arg[ithr + (thr_k - ithr_k) * stride];
273 };
274
275 auto wait_thread = [&](int thr_k) {
276 if (wait) {
277 auto &tk_arg = get_thread_arg(thr_k);
278 while (!tk_arg.compute_done) {}
279 }
280 };
281
282 auto add_thread_results = [&](int thr_k) {
283 auto &tk_arg = get_thread_arg(thr_k);
284
285 sum_matrices(m, nn, tk_arg.c_global + n0 * tk_arg.ldc_global,
286 tk_arg.ldc_global, tk_arg.c_local + n0 * tk_arg.ldc_local,
287 tk_arg.ldc_local);
288 };
289
290 // First accumulate this thread's results while they are in cache.
291 if (ithr_k > 0) {
292 wait_thread(0);
293 add_thread_results(ithr_k);
294 }
295
296 // Then accumulate the others.
297 for (int thr_k = 1; thr_k < nthr_k; thr_k++) {
298 if (thr_k != ithr_k) {
299 wait_thread(thr_k);
300 add_thread_results(thr_k);
301 }
302 }
303 }
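// sum_k_blocks() performs the reduction over the k-partitioned threads: the
// thread's n range is split 1-D across the nthr_k participants
// (partition_1d), and each participant adds every k-peer's c_local block
// into c_global for its own column slice, starting with its own block while
// it is still in cache. The spin-wait on compute_done assumes a syncable
// threading runtime, i.e. all participating threads run concurrently.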
304
305 template <typename a_type, typename b_type, typename c_type>
306 static dnnl_status_t pack_no_copy(gemm_info_t<a_type, b_type, c_type> *arg) {
307
308 if (arg->packing == pack_type::pack_a) {
309 return gemm_utils::pack_no_copy(arg->a, arg->lda, arg->m, arg->k,
310 arg->transa, arg->alpha, arg->pack_dst);
311 } else {
312 return gemm_utils::pack_no_copy(arg->b, arg->ldb, arg->k, arg->n,
313 arg->transb, arg->alpha, arg->pack_dst);
314 }
315 }
316
317 template <typename a_type, typename b_type, typename c_type>
318 static dnnl_status_t gemm_packing_driver(int ithr, dim_t m, dim_t n, dim_t k,
319 const a_type *a, const b_type *b,
320 const gemm_info_t<a_type, b_type, c_type> *arg) {
321
322 if (m <= 0 || n <= 0) return dnnl_success;
323
324 gemm_pack_storage_t *pack_dst = arg->pack_dst;
325
326 if (!pack_dst->is_first_thread_in_slice(ithr)) return dnnl_success;
327
328 dim_t block_r, block_c;
329 pack_dst->get_blocking(ithr, block_r, block_c);
330
331 auto do_a = (arg->packing == pack_type::pack_a);
332 auto mn = do_a ? m : n;
333 auto mn_padd = do_a ? block_r : block_c;
334 auto k_padd = do_a ? block_c : block_r;
335 dim_t mn_stride, k_stride;
336
337 if (do_a) {
338 mn_stride = (arg->transa == no_trans) ? 1 : arg->lda;
339 k_stride = (arg->transa == no_trans) ? arg->lda : 1;
340 } else {
341 mn_stride = (arg->transb == no_trans) ? arg->ldb : 1;
342 k_stride = (arg->transb == no_trans) ? 1 : arg->ldb;
343 }
344
345 dim_t blk_k = 0;
346 for (dim_t Bk = 0; Bk < k; Bk += k_padd, blk_k++) {
347 dim_t nk = nstl::min(k - Bk, k_padd);
348
349 for (dim_t Bmn = 0; Bmn < mn; Bmn += mn_padd) {
350 dim_t nmn = nstl::min(mn - Bmn, mn_padd);
351
352 if (do_a) {
353 auto a_src = a + mn_stride * Bmn + k_stride * Bk;
354 auto a_dst = pack_dst->matrix<a_type>(ithr, Bmn, Bk);
355 auto a_row_sum = pack_dst->row_sums<c_type>(ithr, Bmn, blk_k);
356
357 arg->copyA(&nk, &nmn, a_src, &arg->lda, &arg->alpha, a_dst,
358 nullptr, nullptr, a_row_sum);
359 } else {
360 auto b_src = b + mn_stride * Bmn + k_stride * Bk;
361 auto b_dst = pack_dst->matrix<b_type>(ithr, Bk, Bmn);
362 auto b_col_sum = pack_dst->col_sums<c_type>(ithr, blk_k, Bmn);
363
364 arg->copyB(&nk, &nmn, b_src, &arg->ldb, &arg->alpha, b_dst,
365 nullptr, nullptr, b_col_sum);
366 }
367 }
368 }
369
370 return dnnl_success;
371 }
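// The packing driver walks the requested matrix in blocks whose sizes come
// from the pack storage blocking and hands each block to the copyA/copyB
// kernels, which also emit the row/column sums that the integer paths later
// need for offset compensation.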
372
373 template <typename a_type, typename b_type, typename c_type>
374 void gemm_kernel(dim_t m, dim_t n, const dim_t k, const float alpha,
375 const a_type *a, const b_type *b, float beta, c_type *c,
376 const dim_t ldc, const c_type *a_row_sum, const c_type *b_col_sum,
377 const c_type *co, offset_type offsetc,
378 const gemm_info_t<a_type, b_type, c_type> *arg) {
379
380 #ifdef DNNL_WITH_SYCL
381 std::vector<c_type> col_offset_vec(m);
382 std::vector<c_type> row_offset_vec(n);
383 c_type *col_offset = col_offset_vec.data();
384 c_type *row_offset = row_offset_vec.data();
385 #else
386 // Since m and n are limited by blocking, stack overflow should not happen;
387 // the allocation is at most 32kB.
388 #if !defined(_MSC_VER)
389 c_type col_offset[m];
390 c_type row_offset[n];
391 #else
392 c_type *col_offset = (c_type *)_alloca(sizeof(*col_offset) * m);
393 c_type *row_offset = (c_type *)_alloca(sizeof(*row_offset) * n);
394 #endif
395 #endif
396
397 bool col_req = false;
398 bool row_req = false;
399
400 constexpr bool is_int8 = utils::one_of(
401 data_traits<a_type>::data_type, data_type::s8, data_type::u8);
402 constexpr bool is_f32 = data_traits<a_type>::data_type == data_type::f32;
403 bool is_int8_amx = is_int8 && mayiuse(avx512_core_bf16_amx_int8);
404
405 // Unconditionally zero initialize these arrays
406 for (dim_t i = 0; i < m; i++)
407 col_offset[i] = 0;
408 for (dim_t i = 0; i < n; i++)
409 row_offset[i] = 0;
410
411 if (is_int8) {
412 c_type ao = arg->ao;
413 c_type bo = arg->bo;
414 c_type co_0 = offsetc == offset_type::none ? 0 : co[0];
415
416 if (bo != 0 || offsetc == offset_type::column) col_req = true;
417 if (ao != 0 || offsetc == offset_type::row) row_req = true;
418
419 // One of the column or row offset buffers is needed, but not both.
420 if ((ao != 0 && bo != 0)
421 || (offsetc == offset_type::fixed && co_0 != 0)) {
422 if (!col_req && !row_req) {
423 if (m <= n) {
424 col_req = true;
425 } else {
426 row_req = true;
427 }
428 }
429 }
430
431 if (col_req) {
432 if (offsetc == offset_type::column) {
433 for (dim_t i = 0; i < m; i++)
434 col_offset[i] += co[i];
435 }
436
437 if (bo != 0 && a_row_sum) {
438 for (dim_t i = 0; i < m; i++)
439 col_offset[i] -= bo * a_row_sum[i];
440 }
441 }
442
443 if (row_req) {
444 if (offsetc == offset_type::row) {
445 for (dim_t i = 0; i < n; i++)
446 row_offset[i] += co[i];
447 }
448
449 if (ao != 0 && b_col_sum) {
450 for (dim_t i = 0; i < n; i++)
451 row_offset[i] -= ao * b_col_sum[i];
452 }
453 }
454
455 if (offsetc == offset_type::fixed && co_0 != 0) {
456 if (col_req) {
457 for (dim_t i = 0; i < m; i++)
458 col_offset[i] += co_0;
459 } else {
460 for (dim_t i = 0; i < n; i++)
461 row_offset[i] += co_0;
462 }
463 }
464
465 if (ao != 0 && bo != 0) {
466 if (col_req) {
467 for (dim_t i = 0; i < m; i++)
468 col_offset[i] += (c_type)k * ao * bo;
469 } else {
470 for (dim_t i = 0; i < n; i++)
471 row_offset[i] += (c_type)k * ao * bo;
472 }
473 }
474 }
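// The col_offset/row_offset vectors built above are the usual zero-point
// compensation terms for integer GEMM: the contributions of the a/b offsets
// are expressed through the precomputed row sums of A and column sums of B
// plus a k * ao * bo correction, and the C offset (fixed/row/column) is
// folded into one of the two vectors (the shorter one in the fixed case).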
475
476 bool isBeta0 = beta == 0.0f;
477
478 /* Column and row offsets are ignored by non-integer compute kernels.
479 * Scaling is done only for bfloat16 kernels.
480 */
481 if (m > 0 && n > 0)
482 arg->kernel[isBeta0][col_req][row_req](
483 &m, &n, &k, &alpha, a, b, c, ldc, col_offset, row_offset);
484
485 msan_unpoison_matrix(c, m, n, ldc, sizeof(*c));
486
487 // sgemm kernels don't support bias yet.
488 if (is_f32) {
489 if (co && offsetc == offset_type::column) {
490 for (dim_t j = 0; j < n; j++) {
491 for (dim_t i = 0; i < m; i++) {
492 c[i + j * ldc] += co[i];
493 }
494 }
495 }
496 }
497
498 // AMX igemm kernels don't support row & col sums yet.
499 if (is_int8_amx) {
500 for (dim_t j = 0; j < n; j++) {
501 for (dim_t i = 0; i < m; i++) {
502 if (row_req) c[i + j * ldc] += row_offset[j];
503 if (col_req) c[i + j * ldc] += col_offset[i];
504 }
505 }
506 }
507 }
508
509 template <typename a_type, typename b_type, typename c_type>
510 static dnnl_status_t gemm_kernel_driver(int ithr, dim_t m, dim_t n, dim_t k,
511 const a_type *a, const b_type *b, float beta, c_type *c, dim_t ldc,
512 offset_type offsetc, const c_type *co,
513 const gemm_info_t<a_type, b_type, c_type> *arg) {
514
515 if (arg->packing != pack_type::none)
516 return gemm_packing_driver(ithr, m, n, k, a, b, arg);
517
518 if (m <= 0 || n <= 0) return dnnl_success;
519
520 dim_t lda = arg->lda;
521 dim_t ldb = arg->ldb;
522
523 float alpha = arg->alpha;
524
525 constexpr bool is_int8 = utils::one_of(
526 data_traits<a_type>::data_type, data_type::s8, data_type::u8);
527 constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16;
528 bool is_int8_amx = is_int8 && mayiuse(avx512_core_bf16_amx_int8);
529 bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_bf16_amx_bf16);
530 bool is_amx = is_int8_amx || is_bf16_amx;
531
532 const std::shared_ptr<const gemm_pack_storage_t> &a_packed = arg->a_packed;
533 const std::shared_ptr<const gemm_pack_storage_t> &b_packed = arg->b_packed;
534
535 // Scaling C matrix.
536 if (!is_int8 && beta != 1.0f && beta != 0.0f) {
537 scale_matrix(m, n, beta, c, ldc);
538 beta = 1.0f;
539 }
540
541 // Quick exit for C = beta * C
542 if (!is_int8 && alpha == 0.0f) {
543 if (beta == 0.0f) scale_matrix(m, n, beta, c, ldc);
544
545 return dnnl_success;
546 }
547
548 // Get block sizes.
549 dim_t k_padd = get_k_padd(ithr, k, arg);
550 dim_t m_padd = get_m_padd(ithr, m, arg);
551 dim_t n_padd = get_n_padd(ithr, n, k, arg);
552
553 // Padding for temporary buffer for C
554 dim_t ldc_buf = gemm_utils::get_ld_padd<c_type>(m_padd);
555
556 dim_t strideAm = (arg->transa == no_trans) ? 1 : lda;
557 dim_t strideAn = (arg->transa != no_trans) ? 1 : lda;
558 dim_t strideBm = (arg->transb == no_trans) ? 1 : ldb;
559 dim_t strideBn = (arg->transb != no_trans) ? 1 : ldb;
560
561 size_t a_buf_nelems = m_padd * k_padd;
562 size_t b_buf_nelems = k_padd * n_padd;
563 // A and B buffers need more space due to zero-padding.
564 if (is_amx) {
565 a_buf_nelems = utils::rnd_up(m_padd, arg->um)
566 * utils::rnd_up(k_padd, arg->uk);
567 b_buf_nelems = utils::rnd_up(k_padd, arg->uk)
568 * utils::rnd_up(n_padd, arg->un);
569 }
570 size_t a_row_sum_nelems = m_padd;
571 size_t b_col_sum_nelems = n_padd;
572
573 if (a_packed) a_buf_nelems = a_row_sum_nelems = 0;
574 if (b_packed) b_buf_nelems = b_col_sum_nelems = 0;
575
576 size_t mem_size = a_buf_nelems * sizeof(*a) + PAGE_4K
577 + b_buf_nelems * sizeof(*b) + PAGE_4K;
578
579 if (is_int8) {
580 mem_size += a_row_sum_nelems * sizeof(*c) + PAGE_4K
581 + b_col_sum_nelems * sizeof(*c) + PAGE_4K;
582 }
583
584 bool need_c_buffer
585 = (is_int8 && (alpha != 1.0f || (beta != 1.0f && beta != 0.0f)))
586 // AMX bfloat16 kernels don't support alpha scaling yet,
587 // so we need to use accumulation buffer even if beta == 0.
588 || (is_bf16_amx && alpha != 1.0f);
589
590 if (need_c_buffer) {
591 size_t c_buf_nelems = ldc_buf * n_padd;
592 mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K;
593 }
594
595 char *mem = nullptr;
596
597 if (mem_size > 0) {
598 mem = (char *)malloc(mem_size, 128);
599 if (!mem) return dnnl_out_of_memory;
600 }
601
602 a_type *bufferA = (a_type *)align(mem, PAGE_4K);
603 b_type *bufferB = (b_type *)align(bufferA + a_buf_nelems, PAGE_4K);
604
605 c_type *a_row_sum = nullptr;
606 c_type *b_col_sum = nullptr;
607 if (is_int8) {
608 a_row_sum = (c_type *)align(bufferB + b_buf_nelems, PAGE_4K);
609 b_col_sum = (c_type *)align(a_row_sum + a_row_sum_nelems, PAGE_4K);
610 }
611
612 c_type *bufferC = nullptr;
613 if (need_c_buffer) {
614 if (is_int8)
615 bufferC = (c_type *)align(b_col_sum + b_col_sum_nelems, PAGE_4K);
616 else
617 bufferC = (c_type *)align(bufferB + b_buf_nelems, PAGE_4K);
618 }
619
620 int a_block_copied = 0;
621 dim_t sizeM = 0;
622 for (dim_t Bm = 0; Bm < m; Bm += sizeM) {
623 sizeM = m - Bm;
624 if (sizeM > m_padd) sizeM = m_padd;
625
626 dim_t sizeK = 0;
627 dim_t blk_k = 0;
628 for (dim_t Bk = 0; Bk < k; Bk += sizeK, blk_k++) {
629 sizeK = k - Bk;
630 if (sizeK > k_padd) sizeK = k_padd;
631
632 // Scale C blocks by beta only for the first k-block of the partial sum.
633 auto beta_eff = (Bk == 0) ? beta : 1.0f;
634
635 // Apply the C offset only to the last k-block of the partial sum.
636 auto offsetc_eff = offset_type::none;
637 if (Bk + sizeK == k) offsetc_eff = offsetc;
638
639 dim_t sizeN = 0;
640 for (dim_t Bn = 0; Bn < n; Bn += sizeN) {
641 sizeN = n - Bn;
642 if (sizeN > n_padd) sizeN = n_padd;
643
644 if (b_packed) {
645 bufferB = b_packed->matrix<b_type>(ithr, Bk, Bn);
646 if (is_int8)
647 b_col_sum = b_packed->col_sums<c_type>(ithr, blk_k, Bn);
648 } else {
649 const b_type *b_block = b + Bk * strideBm + Bn * strideBn;
650 const float one = 1.0f;
651
652 /* Column sum argument is ignored for non-integer kernels
653 * and scaling factor is ignored by 8-bit and 16-bit copy
654 * kernels.
655 */
656 arg->copyB(&sizeK, &sizeN, b_block, &ldb, &one, bufferB,
657 nullptr, nullptr, b_col_sum);
658 }
659
660 dim_t sizeUM = 0;
661 for (dim_t Um = 0; Um < sizeM; Um += sizeUM) {
662 sizeUM = sizeM - Um;
663 if (sizeUM > arg->um) sizeUM = arg->um;
664
665 /* Use the whole A buffer only if we have multiple B
666 * blocks for k-dimension, otherwise we are wasting cache
667 * to store B and C blocks.
668 */
669 dim_t Um_forA = 0;
670 if (sizeN < n) Um_forA = Um;
671
672 a_type *bufferA_eff = nullptr;
673 c_type *a_row_sum_eff = nullptr;
674
675 if (a_packed) {
676 Um_forA = Um;
677
678 // TODO Can we simplify this!
679 dim_t buf_shift = 0;
680 if (is_amx)
681 buf_shift = Um_forA * utils::rnd_up(sizeK, arg->uk);
682 else
683 buf_shift = Um_forA * sizeK;
684
685 bufferA_eff = a_packed->matrix<a_type>(ithr, Bm, Bk)
686 + buf_shift;
687
688 if (is_int8)
689 a_row_sum_eff = a_packed->row_sums<c_type>(
690 ithr, Bm, blk_k)
691 + Um_forA;
692 } else {
693 // TODO Can we simplify this!
694 dim_t buf_shift = 0;
695 if (is_amx)
696 buf_shift = Um_forA * utils::rnd_up(sizeK, arg->uk);
697 else
698 buf_shift = Um_forA * sizeK;
699
700 bufferA_eff = bufferA + buf_shift;
701 a_row_sum_eff
702 = a_row_sum ? a_row_sum + Um_forA : nullptr;
703
704 if (!a_block_copied) {
705 const a_type *a_block
706 = a + (Bm + Um) * strideAm + Bk * strideAn;
707
708 /* Row sum argument is ignored for non-integer
709 * kernels and scaling factor is ignored by 8-bit
710 * and 16-bit copy kernels.
711 */
712 arg->copyA(&sizeK, &sizeUM, a_block, &lda, &alpha,
713 bufferA_eff, nullptr, nullptr,
714 a_row_sum_eff);
715 }
716 }
717
718 c_type *c_block = c + (Bm + Um) + Bn * ldc;
719
720 dim_t co_stride = 0;
721 if (offsetc_eff == offset_type::row)
722 co_stride = Bn;
723 else if (offsetc_eff == offset_type::column)
724 co_stride = Bm + Um;
725
726 if (need_c_buffer) {
727 gemm_kernel(sizeUM, sizeN, sizeK, 1.0f, bufferA_eff,
728 bufferB, 0.0f, bufferC + Um, ldc_buf,
729 a_row_sum_eff, b_col_sum, (c_type *)nullptr,
730 offset_type::none, arg);
731
732 /* Finish the block adding the necessary alpha, beta
733 * and offsets.
734 */
735 add_results(sizeUM, sizeN, alpha, beta_eff,
736 bufferC + Um, ldc_buf, c_block, ldc,
737 co + co_stride, offsetc_eff);
738 } else {
739 gemm_kernel(sizeUM, sizeN, sizeK, alpha, bufferA_eff,
740 bufferB, beta_eff, c_block, ldc, a_row_sum_eff,
741 b_col_sum, co + co_stride, offsetc_eff, arg);
742 }
743 }
744 a_block_copied = 1;
745 }
746 a_block_copied = 0;
747 }
748 }
749
750 free(mem);
751
752 return dnnl_success;
753 }
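// gemm_kernel_driver() implements the single-thread blocked GEMM: it tiles m
// by m_padd, k by k_padd and n by n_padd, packs the current A and B panels
// into aligned scratch buffers (unless pre-packed storage is supplied), and
// calls gemm_kernel() per unrolled m sub-block. Within one (Bm, Bk) pair the
// A panel is copied during the first Bn iteration only and reused afterwards;
// when alpha/beta handling requires it, results go through bufferC and
// add_results() instead of being written to C directly.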
754
755 template <typename a_type, typename b_type, typename c_type>
756 static dnnl_status_t kernel_driver_parallel_acopiedbcopy(int ithr, dim_t m,
757 dim_t n, dim_t k, dim_t blk_k, dim_t Bk, const a_type *bufferA,
758 const b_type *b, float beta, c_type *c, offset_type offsetc,
759 const c_type *co, const c_type *a_row_sum,
760 const gemm_info_t<a_type, b_type, c_type> *arg) {
761
762 dim_t ldb = arg->ldb;
763 dim_t ldc = arg->ldc;
764
765 float alpha = arg->alpha;
766
767 const std::shared_ptr<const gemm_pack_storage_t> &b_packed = arg->b_packed;
768
769 if (m <= 0 || n <= 0) { return dnnl_success; }
770
771 // Padding along N dimension.
772 dim_t n_padd = get_n_padd(ithr, n, k, arg);
773
774 // Padding for temporary buffer for C
775 dim_t ldc_buf = gemm_utils::get_ld_padd<c_type>(m);
776
777 dim_t strideBn = (arg->transb != 0) ? 1 : ldb;
778
779 size_t b_buf_nelems = k * n_padd;
780 size_t b_col_sum_nelems = n_padd;
781
782 constexpr bool is_int8 = utils::one_of(
783 data_traits<a_type>::data_type, data_type::s8, data_type::u8);
784 constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16;
785 bool is_int8_amx = is_int8 && mayiuse(avx512_core_bf16_amx_int8);
786 bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_bf16_amx_bf16);
787 bool is_amx = is_int8_amx || is_bf16_amx;
788
789 // B buffer needs more space due to zero-padding.
790 if (is_amx)
791 b_buf_nelems
792 = utils::rnd_up(k, arg->uk) * utils::rnd_up(n_padd, arg->un);
793
794 if (b_packed) b_buf_nelems = b_col_sum_nelems = 0;
795
796 size_t mem_size = b_buf_nelems * sizeof(*b) + PAGE_4K;
797
798 if (is_int8) { mem_size += b_col_sum_nelems * sizeof(*c) + PAGE_4K; }
799
800 bool need_c_buffer
801 = (is_int8 && (alpha != 1.0f || (beta != 1.0f && beta != 0.0f)))
802 // AMX bfloat16 kernels don't support alpha scaling yet,
803 // so we need to use accumulation buffer even if beta == 0.
804 || (is_bf16_amx && alpha != 1.0f);
805
806 if (need_c_buffer) {
807 size_t c_buf_nelems = ldc_buf * n_padd;
808 mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K;
809 }
810
811 char *mem = nullptr;
812
813 if (mem_size > 0) {
814 mem = (char *)malloc(mem_size, 128);
815 if (!mem) return dnnl_out_of_memory;
816 }
817
818 b_type *bufferB = (b_type *)align(mem, PAGE_4K);
819
820 c_type *b_col_sum = nullptr;
821 if (is_int8) {
822 b_col_sum = (c_type *)align(bufferB + b_buf_nelems, PAGE_4K);
823 }
824
825 c_type *bufferC = nullptr;
826 if (need_c_buffer) {
827 if (is_int8)
828 bufferC = (c_type *)align(b_col_sum + b_col_sum_nelems, PAGE_4K);
829 else
830 bufferC = (c_type *)align(bufferB + b_buf_nelems, PAGE_4K);
831 }
832
833 dim_t sizeN = 0;
834 for (dim_t Bn = 0; Bn < n; Bn += sizeN) {
835 sizeN = n - Bn;
836 if (sizeN > n_padd) sizeN = n_padd;
837
838 if (b_packed) {
839 bufferB = b_packed->matrix<b_type>(ithr, Bk, Bn);
840 if (is_int8)
841 b_col_sum = b_packed->col_sums<c_type>(ithr, blk_k, Bn);
842 } else {
843 const b_type *b_block = b + Bn * strideBn;
844 const float one = 1.0f;
845
846 /* Column sum argument is ignored for non-integer kernels and
847 * scaling factor is ignored by 8-bit and 16-bit copy kernels.
848 */
849 arg->copyB(&k, &sizeN, b_block, &ldb, &one, bufferB, nullptr,
850 nullptr, b_col_sum);
851 }
852
853 dim_t co_stride = 0;
854 if (offsetc == offset_type::fixed) {
855 co_stride = 0;
856 } else if (offsetc == offset_type::row) {
857 co_stride = Bn;
858 } else if (offsetc == offset_type::column) {
859 co_stride = 0;
860 }
861
862 c_type *c_block = c + Bn * ldc;
863 if (need_c_buffer) {
864 gemm_kernel(m, sizeN, k, 1.0f, bufferA, bufferB, 0.0f, bufferC,
865 ldc_buf, a_row_sum, b_col_sum, (c_type *)nullptr,
866 offset_type::none, arg);
867
868 // Finish the block adding the necessary alpha, beta and offsets.
869 add_results(m, sizeN, alpha, beta, bufferC, ldc_buf, c_block, ldc,
870 co + co_stride, offsetc);
871 } else {
872 gemm_kernel(m, sizeN, k, alpha, bufferA, bufferB, beta, c_block,
873 ldc, a_row_sum, b_col_sum, co + co_stride, offsetc, arg);
874 }
875 }
876
877 free(mem);
878
879 return dnnl_success;
880 }
881
882 static inline bool nocopy_checker_avx2(const int nthr, const int transa,
883 const int transb, const dim_t m, const dim_t n, const dim_t k,
884 const dim_t lda, const dim_t ldb, const dim_t ldc) {
885 static const dim_t BM_NOCOPY_AVX2 = 64;
886 static const dim_t MN_NOCOPY_AVX2 = 128;
887 static const dim_t N_TRANSB_PER_THR = 1;
888 static const dim_t K_TRANSB_PER_THR = 1;
889 static const dim_t N_NOTRANSB_PER_THR = 16;
890 static const dim_t K_NOTRANSB_PER_THR = 2;
891 static const double FORCE_NOCOPY_THRESH = 0.0038;
892
893 // Crude threshold for nocopy kernels if copy overhead is significant.
894 if (1.0 / m + 1.0 / n >= FORCE_NOCOPY_THRESH) { return true; }
895
896 if (m <= 378 && n <= 378 && k >= nthr * 378) return false;
897
898 if (m >= nthr * 378 && k >= nthr * 378) return false;
899
900 if (transb == no_trans) {
901 if (m <= MN_NOCOPY_AVX2 && n <= MN_NOCOPY_AVX2) return true;
902 if (n <= nthr * N_NOTRANSB_PER_THR) return true;
903 if (k <= nthr * K_NOTRANSB_PER_THR) return true;
904 if (m <= BM_NOCOPY_AVX2 && n >= nthr * N_NOTRANSB_PER_THR) return true;
905 } else {
906 if (m <= MN_NOCOPY_AVX2 && n <= MN_NOCOPY_AVX2) return true;
907 if (n <= nthr * N_TRANSB_PER_THR) return true;
908 if (k <= nthr * K_TRANSB_PER_THR) return true;
909 }
910
911 return false;
912 }
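// The 1/m + 1/n threshold in both checkers is a rough estimate of the ratio
// of copy traffic to compute: packing A and B moves O(m*k + k*n) elements
// while the multiplication does O(m*n*k) flops, and (m*k + k*n) / (m*n*k)
// reduces to 1/n + 1/m. When that ratio is large, the copy overhead
// dominates and the no-copy kernels are preferred.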
913
914 static inline bool nocopy_checker_avx512(int nthr, const int transa,
915 const int transb, const dim_t m, const dim_t n, const dim_t k,
916 const dim_t lda, const dim_t ldb, const dim_t ldc) {
917 // Constants definition
918 static const dim_t BAD_LD_MULT = 256;
919 static const dim_t VERYBAD_LD_MULT = 1024;
920 static const dim_t M_TRANSB_PER_THR = 28;
921 static const dim_t N_TRANSB_PER_THR = 28;
922 static const dim_t K_TRANSB_PER_THR = 1;
923 static const dim_t MN_NOTRANSB_PER_THR = 28;
924 static const dim_t K_NOTRANSB_PER_THR = 1;
925 static const double FORCE_NOCOPY_THRESH = 0.00196;
926
927 bool is_NN = transa == no_trans && transb == no_trans;
928 bool is_NT = transa == no_trans && transb == do_trans;
929 bool is_TN = transa == do_trans && transb == no_trans;
930
931 bool is_lda_bad = lda % BAD_LD_MULT == 0;
932 bool is_ldb_bad = ldb % BAD_LD_MULT == 0;
933 bool is_ldc_bad = ldc % BAD_LD_MULT == 0;
934 bool is_ld_bad = is_lda_bad || is_ldb_bad || is_ldc_bad;
935
936 bool is_lda_verybad = lda % VERYBAD_LD_MULT == 0;
937
938 // Copy-based performs better for the TN case with small N in the sequential case.
939 if (nthr == 1 && is_TN && m > 100
940 && ((m < 1200 && n < 200 && k < 1200)
941 || (is_lda_bad && is_ldb_bad)))
942 return false;
943
944 // Copy-based performs better for the NN case on a very bad leading dimension if
945 // each thread has enough work.
946 if (nthr <= 8 && is_NN && is_lda_verybad && k > 500 && n > 100)
947 return false;
948
949 // Crude threshold for nocopy kernels if copy overhead is significant.
950 if (1.0 / m + 1.0 / n >= FORCE_NOCOPY_THRESH
951 && !(is_lda_verybad && is_NT)) {
952 return true;
953 }
954
955 // Copy strategy usually performs better than nocopy on "bad" leading
956 // dimensions.
957 if (is_ld_bad) {
958 bool use_copy_based = false;
959
960 if (m >= 32 && n > 16) use_copy_based = true;
961
962 // Nocopy outperforms copy-based in certain conditions.
963 if (m >= 32 && n == 16
964 && (k >= 6400 || transa == do_trans || m == 4096))
965 use_copy_based = true;
966
967 if (use_copy_based) return false;
968 }
969
970 if (m <= 378 && n <= 378 && k >= nthr * 378) return false;
971
972 if (m >= nthr * 378 && k >= nthr * 378) return false;
973
974 if (transb == no_trans) {
975 if (m <= nthr * MN_NOTRANSB_PER_THR) return true;
976 if (n <= nthr * MN_NOTRANSB_PER_THR) return true;
977 if (k <= nthr * K_NOTRANSB_PER_THR) return true;
978 } else {
979 if (m <= nthr * M_TRANSB_PER_THR && m >= n) return true;
980 if (n <= nthr * N_TRANSB_PER_THR) return true;
981 if (k <= nthr * K_TRANSB_PER_THR) return true;
982 }
983 return false;
984 }
985
986 template <typename a_type, typename b_type, typename c_type>
987 static inline bool nocopy_checker(
988 int nthr, const gemm_info_t<a_type, b_type, c_type> *arg) {
989
990 if (data_traits<a_type>::data_type != data_type::f32) return false;
991
992 if (!mayiuse(avx)) return false;
993
994 if (arg->force_nocopy) return true;
995
996 auto m = arg->m, n = arg->n, k = arg->k;
997 auto lda = arg->lda, ldb = arg->ldb, ldc = arg->ldc;
998 auto transa = arg->transa, transb = arg->transb;
999 auto packing = arg->packing;
1000
1001 if (packing != pack_type::none) ldc = 64;
1002
1003 if (arg->a_packed || arg->b_packed)
1004 return false;
1005 else if (mayiuse(avx512_core))
1006 return nocopy_checker_avx512(
1007 nthr, transa, transb, m, n, k, lda, ldb, ldc);
1008 else
1009 return nocopy_checker_avx2(
1010 nthr, transa, transb, m, n, k, lda, ldb, ldc);
1011 }
1012
1013 template <typename a_type, typename b_type, typename c_type>
1014 static inline void set_thread_opts_nopack(int nthrs, int nthrs_spawn,
1015 gemm_threading_t &thread_info,
1016 const gemm_info_t<a_type, b_type, c_type> *arg) {
1017
1018 static constexpr dim_t N2D_MAX = 384;
1019 static constexpr dim_t M2D_MIN = 384;
1020
1021 constexpr bool is_int8 = utils::one_of(
1022 data_traits<a_type>::data_type, data_type::s8, data_type::u8);
1023 bool isSgemm = data_traits<a_type>::data_type == data_type::f32;
1024
1025 dim_t m = arg->m;
1026 dim_t n = arg->n;
1027 dim_t k = arg->k;
1028
1029 thread_info.nthrs_m = 0;
1030 thread_info.nthrs_n = 0;
1031 thread_info.nthrs_k = 0;
1032 thread_info.copy = copy_type::nonshared;
1033 thread_info.partition = partition_type::row_1d;
1034
1035 // TODO Check if we can use dynamic scheduling for sgemm.
1036 // TODO Check if we should use 3D blocking.
1037 thread_info.nthrs_k = 1;
1038 thread_info.thread_k = k;
1039
1040 bool condition_2D_bsrc = false;
1041 if (isSgemm) {
1042 // If m is large and n is small then do 1D partitioning for AVX2.
1043 if (!mayiuse(avx512_core) && n <= N2D_MAX && (m >= nthrs * M2D_MIN))
1044 condition_2D_bsrc = false;
1045 else
1046 condition_2D_bsrc
1047 = ((n > nthrs * N2D_MAX) || (n <= nthrs * N2D_MAX / 2))
1048 && (m >= 2 * M2D_MIN);
1049 } else {
1050 int scale = mayiuse(avx512_core) ? nthrs : 20;
1051 condition_2D_bsrc = (256 * m > scale * n) && (scale * m < 256 * n);
1052 }
1053
1054 // TODO Check if we should use k-partitioning.
1055
1056 int condition_1D_copya = false;
1057 if (mayiuse(avx512_core)) {
1058 const dim_t thresh = isSgemm ? N2D_MAX / 4 : 68;
1059 if (m >= 1000 && (n >= nthrs * thresh)) {
1060 condition_2D_bsrc = false;
1061 condition_1D_copya = true;
1062 }
1063 } else {
1064 if (m >= 1000 && n >= 4000) {
1065 condition_2D_bsrc = false;
1066 condition_1D_copya = true;
1067 }
1068 }
1069
1070 // If A or B offset is non-zero, we need to keep 1D_copya to reduce update
1071 // overhead.
1072 // TODO: the reason seems to be in the copy_sum_bx routines. At least,
1073 // after a simple optimization of copy_sum_ax for avx512, the similar
1074 // restriction on offset B became unnecessary. Revisit.
1075 if (is_int8 && arg->ao != 0 && (arg->bo != 0 || mayiuse(avx512_core))) {
1076 condition_2D_bsrc = false;
1077 condition_1D_copya = true;
1078 }
1079
1080 if (condition_2D_bsrc) {
1081 int nthrs_m = 1;
1082 int nthrs_n = nthrs;
1083
1084 if (isSgemm) {
1085 while ((nthrs_n % 2 == 0)
1086 && (n / nthrs > N2D_MAX || n / nthrs_n <= N2D_MAX / 2)
1087 && (m / nthrs_m >= 2 * M2D_MIN) && (nthrs_m < 4)) {
1088 nthrs_m *= 2;
1089 nthrs_n /= 2;
1090 }
1091
1092 thread_info.nthrs_m = nthrs_m;
1093 thread_info.nthrs_n = nthrs_n;
1094 thread_info.partition = partition_type::col_major_2d;
1095 } else {
1096 if (m == 800 && n == 300) {
1097 // TODO: Expand this branch to other problem sizes.
1098
1099 auto &thread_m = thread_info.thread_m;
1100 auto &thread_n = thread_info.thread_n;
1101
1102 const dim_t block_m = arg->um * 4;
1103 constexpr dim_t block_n = 64;
1104 constexpr dim_t small_m = 16;
1105 constexpr dim_t small_n = 2;
1106
1107 std::tie(nthrs_m, nthrs_n)
1108 = gemm_utils::calc_nthr_2d(nthrs, m, n, block_m,
1109 block_n, small_m, small_n, thread_m, thread_n);
1110
1111 thread_info.nthrs_m = nthrs_m;
1112 thread_info.nthrs_n = nthrs_n;
1113 thread_info.partition = partition_type::mnk_3d;
1114
1115 } else if ((n <= 64 || n >= 256)) {
1116 while (((nthrs_n > 1) && (n / nthrs_n < arg->un)
1117 && (m / nthrs_m >= 2 * arg->um)
1118 && mayiuse(avx512_core))
1119 || ((nthrs_n % 2 == 0)
1120 && (n / nthrs > N2D_MAX
1121 || n / nthrs_n <= N2D_MAX / 2)
1122 && (m / nthrs_m >= 2 * M2D_MIN)
1123 && (nthrs_m < 4))) {
1124 nthrs_m *= 2;
1125 nthrs_n /= 2;
1126 }
1127
1128 thread_info.nthrs_m = nthrs_m;
1129 thread_info.nthrs_n = nthrs_n;
1130 thread_info.partition = partition_type::col_major_2d;
1131 } else {
1132 // Use 3D decomposition from pack api without k-partitioning.
1133 set_thread_opts_pack(nthrs, thread_info, arg, false);
1134 }
1135 }
1136
1137 } else if (condition_1D_copya && dnnl_thr_syncable()) {
1138 // Use parallel copy A algorithm
1139 thread_info.copy = copy_type::shared_a;
1140 thread_info.partition = partition_type::col_1d;
1141 thread_info.nthrs_m = 1;
1142 thread_info.nthrs_n = nthrs_spawn; // Using all spawned threads.
1143 } else {
1144 auto veclen = get_vector_length<c_type>();
1145
1146 if (m > n && (m >= nthrs * veclen || n < nthrs)) {
1147 if (n <= 20 && is_int8) {
1148 // Use 3D decomposition forcing m-blocking only.
1149 set_thread_opts_pack(
1150 nthrs, thread_info, arg, false, true, false);
1151 } else {
1152 thread_info.partition = partition_type::row_1d;
1153 thread_info.nthrs_m = nthrs;
1154 thread_info.nthrs_n = 1;
1155 }
1156 } else {
1157 thread_info.partition = partition_type::col_1d;
1158 thread_info.nthrs_m = 1;
1159 thread_info.nthrs_n = nthrs;
1160 }
1161 }
1162 }
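// Roughly, the heuristics above pick between three strategies: moderately
// shaped problems get a 2-D column-major (or 3-D) partition; very large
// problems and offset-carrying int8 cases switch to the shared
// parallel-copy-A path when the threading runtime is syncable; everything
// else falls back to a 1-D split over rows or columns depending on the
// m/n shape.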
1163
1164 template <typename a_type, typename b_type, typename c_type>
1165 static inline void set_thread_opts_pack(int nthrs,
1166 gemm_threading_t &thread_info,
1167 const gemm_info_t<a_type, b_type, c_type> *arg,
1168 bool do_k_blocking = true, bool do_m_blocking = true,
1169 bool do_n_blocking = true) {
1170
1171 constexpr bool is_int8 = utils::one_of(
1172 data_traits<a_type>::data_type, data_type::s8, data_type::u8);
1173 constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16;
1174
1175 bool do_m_blocking_only = do_m_blocking && !do_n_blocking;
1176
1177 auto m = arg->m, n = arg->n, k = arg->k;
1178
1179 auto &nthr_m = thread_info.nthrs_m;
1180 auto &nthr_n = thread_info.nthrs_n;
1181 auto &nthr_k = thread_info.nthrs_k;
1182 auto &thread_m = thread_info.thread_m;
1183 auto &thread_n = thread_info.thread_n;
1184 auto &thread_k = thread_info.thread_k;
1185 auto &block_m = thread_info.block_m;
1186 auto &block_n = thread_info.block_n;
1187 auto &block_k = thread_info.block_k;
1188
1189 constexpr auto MBLK = 64;
1190 constexpr auto NBLK = 64;
1191 auto KBLK = is_int8 ? 3072 : 256;
1192 KBLK = do_m_blocking_only && is_int8 ? 384 : KBLK;
1193
1194 nthr_m = nthr_n = nthr_k = 1;
1195 thread_info.copy = copy_type::nonshared;
1196 thread_info.partition = partition_type::mnk_3d;
1197
1198 auto choose_blocking
1199 = [](dim_t size_z, dim_t &thread_z, int &nthr_z, dim_t block_z_init,
1200 dim_t &block_z, dim_t block_align) {
1201 thread_z = utils::div_up(size_z, nthr_z);
1202 auto num_blk = utils::div_up(thread_z, block_z_init);
1203 block_z = utils::div_up(thread_z, num_blk);
1204 block_z = utils::rnd_up(block_z, block_align);
1205 thread_z = num_blk * block_z;
1206 if (thread_z * nthr_z > size_z)
1207 nthr_z = utils::div_up(size_z, thread_z);
1208 };
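// choose_blocking() splits size_z evenly across nthr_z threads, picks
// roughly block_z_init-sized blocks rounded up to block_align (so copy
// kernels see full unroll factors), and then shrinks nthr_z again if the
// rounded-up per-thread chunks would overshoot size_z.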
1209
1210 auto choose_m_blocking = [&]() {
1211 auto align = get_vector_length<c_type>();
1212 align = do_m_blocking_only ? arg->um : align;
1213 choose_blocking(m, thread_m, nthr_m, arg->bm, block_m, align);
1214 };
1215 auto choose_n_blocking = [&]() {
1216 choose_blocking(n, thread_n, nthr_n, arg->bn, block_n, arg->un);
1217 };
1218 auto choose_k_blocking = [&]() {
1219 auto align = nstl::max(arg->uk, dim_t(4));
1220 choose_blocking(k, thread_k, nthr_k, arg->bk, block_k, align);
1221 };
1222
1223 // Choose k blocking.
1224 if ((m / MBLK + n / NBLK) < nthrs && do_k_blocking) {
1225 for (int nk = 1; nk <= 4 && k >= ((KBLK + 1) * nk); nk++)
1226 if (nthrs % nk == 0) nthr_k = nk;
1227
1228 // Sacrifice one thread and try again if parallelism is too small in
1229 // n-dimension.
1230 if (nthr_k == 1 && nthrs > 1 && do_m_blocking_only) {
1231 nthrs--;
1232 for (int nk = 1; nk <= 4 && k >= ((KBLK + 1) * nk); nk++)
1233 if (nthrs % nk == 0) nthr_k = nk;
1234 }
1235
1236 // Allow up to 2 threads to be sacrificed for large k >> m, n.
1237 if (nthr_k < 4 && k >= m * 4 && k >= n * 4 && nthrs > 10 && is_bf16) {
1238 for (int nk = 1; nk <= 4 && k >= ((KBLK + 1) * nk); nk++)
1239 if (nthrs % nk <= 2) nthr_k = nk;
1240 }
1241 }
1242
1243 choose_k_blocking();
1244
1245 // Choose m/n blocking.
1246 auto min_mblk = mayiuse(avx512_core) ? (MBLK / 2) : arg->um;
1247 min_mblk = do_m_blocking ? min_mblk : m;
1248 min_mblk = do_m_blocking_only ? arg->um : min_mblk;
1249 auto min_nblk = do_n_blocking ? NBLK / 2 : n;
1250
1251 std::tie(nthr_m, nthr_n) = partition_2d_minblk(m, n, MBLK, NBLK, min_mblk,
1252 min_nblk, arg->um, arg->un, nthrs / nthr_k,
1253 do_m_blocking && do_n_blocking && do_k_blocking);
1254
1255 auto nthr_m_init = nthr_m, nthr_n_init = nthr_n;
1256
1257 choose_m_blocking();
1258 choose_n_blocking();
1259
1260 if (is_int8 && do_m_blocking && do_n_blocking) {
1261 // If we lost a thread in one dimension because we padded the blocking
1262 // size, try to rebalance the other dimensions.
1263 if ((nthr_n != nthr_n_init)
1264 && ((nthr_m + 1) * nthr_n * nthr_k <= nthrs)) {
1265 nthr_m++;
1266 choose_m_blocking();
1267 }
1268
1269 if ((nthr_m != nthr_m_init)
1270 && (nthr_m * (nthr_n + 1) * nthr_k <= nthrs)) {
1271 nthr_n++;
1272 choose_n_blocking();
1273 }
1274 }
1275 }
1276
1277 template <typename a_type, typename b_type, typename c_type>
1278 static inline int set_thread_opts(int nthrs, int nthrs_spawn,
1279 gemm_threading_t &thread_info,
1280 const gemm_info_t<a_type, b_type, c_type> *arg) {
1281
1282 thread_info.block_m = thread_info.block_n = thread_info.block_k = -1;
1283 thread_info.thread_m = thread_info.thread_n = thread_info.thread_k = -1;
1284
1285 constexpr bool is_int8 = utils::one_of(
1286 data_traits<a_type>::data_type, data_type::s8, data_type::u8);
1287 constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16;
1288
1289 if (nocopy_checker(nthrs, arg)) {
1290 thread_info.copy = copy_type::no_copy;
1291 thread_info.partition = partition_type::mnk_3d;
1292 int nthrs_m = 0;
1293 int nthrs_n = 0;
1294 int nthrs_k = 0;
1295 dim_t BM = 0;
1296 dim_t BN = 0;
1297 dim_t BK = 0;
1298 auto m = arg->m, n = arg->n, k = arg->k;
1299
1300 if (mayiuse(avx512_core)) {
1301 cpu::gemm_utils::calc_nthr_nocopy_avx512_common(m, n, k, nthrs,
1302 &nthrs_m, &nthrs_n, &nthrs_k, &BM, &BN, &BK);
1303 } else {
1304 cpu::gemm_utils::calc_nthr_nocopy_avx(m, n, k, nthrs, &nthrs_m,
1305 &nthrs_n, &nthrs_k, &BM, &BN, &BK);
1306 }
1307
1308 // Block information is being ignored. We will create partitioning
1309 // later.
1310 thread_info.nthrs_m = nthrs_m;
1311 thread_info.nthrs_n = nthrs_n;
1312 thread_info.nthrs_k = nthrs_k;
1313 } else {
1314 if (arg->packing != pack_type::none && (is_int8 || is_bf16))
1315 set_thread_opts_pack(nthrs, thread_info, arg);
1316 else
1317 set_thread_opts_nopack(nthrs, nthrs_spawn, thread_info, arg);
1318 }
1319
1320 return thread_info.nthrs_m * thread_info.nthrs_n * thread_info.nthrs_k;
1321 }
1322
1323 template <typename a_type, typename b_type, typename c_type>
1324 static inline std::tuple<const a_type *, const b_type *, c_type *,
1325 const c_type *>
1326 decompose_matrices(const gemm_slice_t &slice,
1327 const gemm_info_t<a_type, b_type, c_type> *arg) {
1328
1329 dim_t stride_am = (arg->transa == no_trans) ? 1 : arg->lda;
1330 dim_t stride_ak = (arg->transa != no_trans) ? 1 : arg->lda;
1331 dim_t stride_bn = (arg->transb != no_trans) ? 1 : arg->ldb;
1332 dim_t stride_bk = (arg->transb == no_trans) ? 1 : arg->ldb;
1333
1334 auto a = arg->a;
1335 auto b = arg->b;
1336 auto c = arg->c;
1337 if (a) a += slice.off_m * stride_am + slice.off_k * stride_ak;
1338 if (b) b += slice.off_n * stride_bn + slice.off_k * stride_bk;
1339 if (c) c += slice.off_m + slice.off_n * arg->ldc;
1340
1341 dim_t co_stride;
1342 switch (arg->offsetc) {
1343 case offset_type::row: co_stride = slice.off_n; break;
1344 case offset_type::column: co_stride = slice.off_m; break;
1345 default: co_stride = 0; break;
1346 }
1347 auto co = arg->co;
1348 if (co) co += co_stride;
1349
1350 return std::make_tuple(a, b, c, co);
1351 }
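// decompose_matrices() translates a thread's (off_m, off_n, off_k) slice into
// shifted A/B/C/co pointers using the transpose-dependent strides, so the
// per-thread drivers can work on plain sub-matrices.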
1352
1353 template <typename a_type, typename b_type, typename c_type>
1354 static dnnl_status_t parallel_a_copy(const int ithr, const int nthrs,
1355 const dim_t m, const dim_t n, const dim_t k, const a_type *a,
1356 const b_type *b, float beta, c_type *c, dim_t ldc, offset_type offsetc,
1357 const c_type *co, const gemm_info_t<a_type, b_type, c_type> *arg,
1358 char **p_shared_mem) {
1359
1360 if (arg->packing != pack_type::none)
1361 return gemm_packing_driver(ithr, m, n, k, a, b, arg);
1362
1363 const dim_t lda = arg->lda;
1364 const dim_t ldb = arg->ldb;
1365 const dim_t strideAm = (arg->transa == no_trans) ? 1 : lda;
1366 const dim_t strideAn = (arg->transa != no_trans) ? 1 : lda;
1367 const dim_t strideBm = (arg->transb == no_trans) ? 1 : ldb;
1368
1369 float alpha = arg->alpha;
1370
1371 constexpr bool is_int8 = utils::one_of(
1372 data_traits<a_type>::data_type, data_type::s8, data_type::u8);
1373 constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16;
1374 bool is_int8_amx = is_int8 && mayiuse(avx512_core_bf16_amx_int8);
1375 bool is_bf16_amx = is_bf16 && mayiuse(avx512_core_bf16_amx_bf16);
1376 bool is_amx = is_int8_amx || is_bf16_amx;
1377
1378 const std::shared_ptr<const gemm_pack_storage_t> &a_packed = arg->a_packed;
1379
1380 // Scaling C matrix.
1381 if (!is_int8 && beta != 1.0f && beta != 0.0f) {
1382 scale_matrix(m, n, beta, c, ldc);
1383 beta = 1.0f;
1384 }
1385
1386 // Padding along M, K dimensions.
1387 dim_t m_padd = get_m_padd_parallel_a(ithr, m, arg, nthrs);
1388 dim_t k_padd = get_k_padd(ithr, k, arg);
1389
1390 size_t a_buf_nelems = m_padd * k_padd;
1391
1392 // A buffer needs more space due to zero-padding.
1393 if (is_amx)
1394 a_buf_nelems = utils::rnd_up(m_padd, arg->um)
1395 * utils::rnd_up(k_padd, arg->uk);
1396
1397 // Allocate shared memory for A and its row sum buffers in master thread.
1398 char *mem = nullptr;
1399 a_type *bufferA = nullptr;
1400 c_type *a_row_sum = nullptr;
1401
1402 if (!a_packed) {
1403 if (ithr == 0) { // If master thread
1404 size_t mem_size = (a_buf_nelems * sizeof(*a) + PAGE_4K);
1405
1406 if (is_int8) {
1407 size_t a_row_sum_nelems = m_padd;
1408 mem_size += a_row_sum_nelems * sizeof(*c) + PAGE_4K;
1409 }
1410
1411 *p_shared_mem = (char *)malloc(mem_size, 128);
1412 }
1413
1414 dnnl_thr_barrier();
1415
1416 mem = *p_shared_mem;
1417 bufferA = (a_type *)align(mem, PAGE_4K);
1418
1419 if (is_int8)
1420 a_row_sum = (c_type *)align(bufferA + a_buf_nelems, PAGE_4K);
1421
1422 if (!mem) return dnnl_out_of_memory;
1423 }
1424
1425 dnnl_status_t result = dnnl_success; // Return status
1426
1427 dim_t sizeK = 0;
1428 dim_t blk_k = 0;
1429 for (dim_t Bk = 0; Bk < k; Bk += sizeK, blk_k++) {
1430 sizeK = k - Bk;
1431 if (sizeK > k_padd) sizeK = k_padd;
1432
1433 // Scale C blocks by beta only for the first term of partial sum.
1434 auto beta_eff = (Bk == 0) ? beta : 1.0f;
1435
1436 // Apply C offset for the last k-block of the partial sum.
1437 auto offsetc_eff = offset_type::none;
1438 if (Bk + sizeK == k) offsetc_eff = offsetc;
1439
1440 dim_t sizeM = 0;
1441 for (dim_t Bm = 0; Bm < m; Bm += sizeM) {
1442 sizeM = m - Bm;
1443 if (sizeM > m_padd) sizeM = m_padd;
1444
1445 if ((ithr < nthrs) && !a_packed) {
1446 dim_t band = (sizeM + nthrs - 1) / nthrs;
1447 band = utils::rnd_up(band, arg->um);
1448
1449 dim_t offset = band * ithr;
1450
1451 // If the offset is too large, don't use that thread for copying.
1452 if (offset >= sizeM) {
1453 offset = 0;
1454 band = 0;
1455 }
1456
1457 // Handle the tail of the copy.
1458 if (offset + band > sizeM) { band = sizeM - offset; }
1459
1460 if (band > 0) {
1461 const a_type *a_block
1462 = a + (Bm + offset) * strideAm + Bk * strideAn;
1463
1464 dim_t buf_shift = 0;
1465 if (is_amx)
1466 buf_shift = offset * utils::rnd_up(sizeK, arg->uk);
1467 else
1468 buf_shift = offset * sizeK;
1469
1470 /* Row sum argument is ignored for non-integer kernels and
1471 * scaling factor is ignored by 8-bit and 16-bit copy
1472 * kernels.
1473 */
1474 c_type *a_row_sum_eff
1475 = a_row_sum ? a_row_sum + offset : nullptr;
1476 arg->copyA(&sizeK, &band, a_block, &lda, &alpha,
1477 bufferA + buf_shift, nullptr, nullptr,
1478 a_row_sum_eff);
1479 }
1480 }
1481 if (!a_packed)
1482 dnnl_thr_barrier(); // Wait for finishing parallel copy.
1483
1484 const b_type *b_block = b + Bk * strideBm;
1485 c_type *c_block = c + Bm;
1486
1487 dim_t co_stride = 0;
1488 if (offsetc_eff == offset_type::fixed) {
1489 co_stride = 0;
1490 } else if (offsetc_eff == offset_type::row) {
1491 co_stride = 0;
1492 } else if (offsetc_eff == offset_type::column) {
1493 co_stride = Bm;
1494 }
1495
1496 auto bufferA_eff
1497 = a_packed ? a_packed->matrix<a_type>(0, Bm, Bk) : bufferA;
1498 auto a_row_sum_eff = a_packed
1499 ? a_packed->row_sums<c_type>(0, Bm, blk_k)
1500 : a_row_sum;
1501
1502 auto this_result = kernel_driver_parallel_acopiedbcopy(ithr, sizeM,
1503 n, sizeK, blk_k, Bk, bufferA_eff, b_block, beta_eff,
1504 c_block, offsetc_eff, co + co_stride, a_row_sum_eff, arg);
1505
1506 if (this_result != dnnl_success) result = this_result;
1507
1508 if (!a_packed)
1509 dnnl_thr_barrier(); // Wait for kernel computations to finish.
1510 }
1511 }
1512
1513 // Free memory allocated in master thread
1514 if (ithr == 0 && !a_packed) free(mem);
1515
1516 return result;
1517 }
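// parallel_a_copy() is the shared-A strategy: the master thread allocates one
// shared buffer, all threads copy disjoint bands (rounded to arg->um) of the
// current A block into it, a barrier publishes the packed panel, and each
// thread then runs kernel_driver_parallel_acopiedbcopy() on its portion of
// the work before the next barrier. This only works with a syncable
// threading runtime, matching the copy_type::shared_a selection above.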
1518
1519 template <typename T>
1520 static inline void adjust_thread_count(dim_t m, dim_t n, dim_t k, int *nthrs) {
1521
1522 const double omp_overhead_small_core = 3.0e+3;
1523 const double omp_intercept_big_core = 4.0e+3;
1524 const double omp_slope_big_core = 5.0e+2;
1525
1526 auto veclen = get_vector_length<T>();
1527 const double fp_per_cycle = 2.0 * 2.0 * veclen;
1528
1529 const bool is_f32 = data_traits<T>::data_type == data_type::f32;
1530
1531 const bool is_avx512_mic = mayiuse(avx512_mic);
1532 const bool is_avx512 = mayiuse(avx512_core);
1533 const bool is_avx = mayiuse(avx);
1534 const bool is_only_avx2 = mayiuse(avx2) && !is_avx512;
1535
1536 if (is_avx512_mic) return;
1537
1538 // Some sgemm cases still benefit from using all threads.
1539 const bool use_all_threads = is_f32 && n > 50
1540 && ((is_avx && m <= 3) || (is_avx512 && m <= 10));
1541 if (use_all_threads) return;
1542
1543 if (is_only_avx2)
1544 if (m > 10 * n && n < *nthrs)
1545 if (m / *nthrs < veclen * 3)
1546 *nthrs = nstl::max(m / veclen / 3, dim_t(1));
1547
1548 double gemm_cycles = m * n * k / fp_per_cycle;
1549 gemm_cycles *= is_f32 ? 2.0 : 8.0;
1550
1551 int i = *nthrs;
1552
1553 // Use a different model for omp overheads if nthrs is <= 4
1554 if (*nthrs <= 4 && omp_overhead_small_core > 0) {
1555 double omp_cycles = omp_overhead_small_core;
1556 if (gemm_cycles < omp_cycles) {
1557 *nthrs = 1;
1558 return;
1559 } else {
1560 while (i > 1) {
1561 if (omp_cycles * i < gemm_cycles * (i - 1)) break;
1562 --i;
1563 }
1564 }
1565 } else {
1566 if (gemm_cycles < (omp_intercept_big_core + 2 * omp_slope_big_core)) {
1567 *nthrs = 1;
1568 return;
1569 }
1570
1571 // adaptive decrement to march faster.
1572 while (i > 1) {
1573 double omp_cycles = omp_intercept_big_core + i * omp_slope_big_core;
1574 if (omp_cycles * i < gemm_cycles * (i - 1)) break;
1575
1576 if (i < 10)
1577 i -= 2;
1578 else if (i < 30)
1579 i -= 4;
1580 else
1581 i -= 8;
1582 }
1583 }
1584
1585 if (i < 1) i = 1;
1586
1587 *nthrs = i;
1588 }
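// A rough reading of the model above: fp_per_cycle = 2 * 2 * veclen, which
// reads as two flops per FMA times two FMA pipes per core, so for f32 on
// avx512_core (veclen = 16) a core is assumed to retire ~64 flops per cycle
// and gemm_cycles ~= 2 * m*n*k / 64 (non-f32 paths use a factor of 8).
// Threads are then shed while the modeled OpenMP overhead per thread
// (a flat ~3e3 cycles for <= 4 threads, an intercept-plus-slope model
// otherwise) is not amortized by the extra parallelism.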
1589
1590 template <typename a_type, typename b_type, typename c_type>
1591 static dnnl_status_t call_no_copy_sgemm(
1592 int nthrs, gemm_info_t<a_type, b_type, c_type> *arg) {
1593
1594 if (arg->packing == pack_type::none) {
1595 auto transa_char = (arg->transa != do_trans) ? "N" : "T";
1596 auto transb_char = (arg->transb != do_trans) ? "N" : "T";
1597
1598 if (mayiuse(avx512_core))
1599 return jit_avx512_common_gemm_f32(nthrs, transa_char, transb_char,
1600 &arg->m, &arg->n, &arg->k, &arg->alpha, (float *)arg->a,
1601 &arg->lda, (float *)arg->b, &arg->ldb, &arg->beta,
1602 (float *)arg->c, &arg->ldc, (float *)arg->co);
1603 else
1604 return jit_avx_gemm_f32(nthrs, transa_char, transb_char, &arg->m,
1605 &arg->n, &arg->k, &arg->alpha, (float *)arg->a, &arg->lda,
1606 (float *)arg->b, &arg->ldb, &arg->beta, (float *)arg->c,
1607 &arg->ldc, (float *)arg->co);
1608 } else
1609 return pack_no_copy(arg);
1610 }
1611
1612 template <typename a_type, typename b_type, typename c_type>
1613 static dnnl_status_t gemm_threading_driver(
1614 gemm_info_t<a_type, b_type, c_type> *arg) {
1615
1616 auto packing = (arg->packing != pack_type::none);
1617 auto is_a_packed = (arg->transa == packed);
1618 auto is_b_packed = (arg->transb == packed);
1619 constexpr bool is_int8 = utils::one_of(
1620 data_traits<a_type>::data_type, data_type::s8, data_type::u8);
1621 constexpr bool is_bf16 = data_traits<a_type>::data_type == data_type::bf16;
1622
1623 if ((arg->m <= 0) || (arg->n <= 0)) return dnnl_success;
1624
1625 if (!is_a_packed && !is_b_packed && jump_to_gemv_s8x8s32(arg))
1626 return dnnl_success;
1627
1628 if (!is_a_packed && !is_b_packed
1629 && jump_to_gemm_smalln_tn(arg) == dnnl_success)
1630 return dnnl_success;
1631
1632 if (!is_a_packed && !is_b_packed && jump_to_gemv(arg) == dnnl_success)
1633 return dnnl_success;
1634
    if (is_a_packed && arg->bo != 0)
        if (!arg->a_packed->has_row_sums()) return dnnl_invalid_arguments;

    if (is_b_packed && arg->ao != 0)
        if (!arg->b_packed->has_col_sums()) return dnnl_invalid_arguments;

    auto nthr_max = dnnl_get_current_num_threads();
    int nthr_goal = nthr_max;

    adjust_thread_count<c_type>(arg->m, arg->n, arg->k, &nthr_goal);

    const gemm_threading_t *force_threading = nullptr;
    gemm_threading_t force_k_decomp;

    // Initialize per-thread data.
    // Note: to support k blocking with non-packed GEMM, threading must be
    // chosen now and force_threading set.
    if (!packing) {
        // Override choice of thread count if data is pre-packed for a
        // particular number of threads.
        if (is_a_packed && is_b_packed)
            if (arg->a_packed->threading() != arg->b_packed->threading())
                return dnnl_invalid_arguments;
        if (is_a_packed)
            force_threading = &arg->a_packed->threading();
        else if (is_b_packed)
            force_threading = &arg->b_packed->threading();
        else if (arg->m <= 768 && arg->n <= 768 && arg->k >= 2048 && is_bf16) {
            // Try k-partitioning.
            set_thread_opts_pack(nthr_goal, force_k_decomp, arg);

            // Decide partition type later if no partitions in k-dimension.
            if (force_k_decomp.nthrs_k > 1) force_threading = &force_k_decomp;
        } else if (arg->n <= 128 && arg->k >= 3072 && is_int8) {
            // Use k-partitioning if necessary.
            // Use 3D decomposition from pack api without n-partitioning.
            set_thread_opts_pack(
                    nthr_goal, force_k_decomp, arg, true, true, false);

            // Decide partition type later if no partitions in k-dimension.
            if (force_k_decomp.nthrs_k > 1 && force_k_decomp.nthrs_m > 1)
                force_threading = &force_k_decomp;
        }

        if (force_threading) {
            nthr_goal = force_threading->nthrs();
            arg->update_blocking(*force_threading);
        }
    } else {
        // Prepare packed data layout.
        gemm_pack_storage_t *pack_dst = arg->pack_dst;
        bool do_a = (arg->packing == pack_type::pack_a);

        pack_dst->which() = do_a ? matrix_id::a : matrix_id::b;
        pack_dst->setup(nthr_goal, do_a && is_int8, !do_a && is_int8);

        auto &thread_info = pack_dst->threading();
        force_threading = &thread_info;

        nthr_goal = set_thread_opts(nthr_goal, nthr_max, thread_info, arg);
        arg->update_blocking(thread_info);

        if (thread_info.copy != copy_type::no_copy) {
            for (int ithr = 0; ithr < nthr_goal; ithr++) {
                if (!pack_dst->is_first_thread_in_slice(ithr)) continue;

                auto slice = thread_info.get_thread_slice(
                        ithr, arg->m, arg->n, arg->k);

                auto m = slice.m, n = slice.n, k = slice.k;

                auto m_padd = (thread_info.copy == copy_type::shared_a)
                        ? get_m_padd_parallel_a(
                                ithr, m, arg, thread_info.nthrs())
                        : get_m_padd(ithr, m, arg);
                auto n_padd = get_n_padd(ithr, n, k, arg);
                auto k_padd = get_k_padd(ithr, k, arg);

                do_a ? pack_dst->set_blocking(ithr, m, k, m_padd, k_padd)
                     : pack_dst->set_blocking(ithr, k, n, k_padd, n_padd);
            }
        } else {
            auto ld = do_a ? gemm_utils::get_ld_padd<a_type>(arg->m)
                           : gemm_utils::get_ld_padd<b_type>(arg->k);

            pack_dst->set_nocopy(0, no_trans, ld, do_a ? arg->k : arg->n);
        }

        do_a ? pack_dst->finalize<a_type, c_type>()
             : pack_dst->finalize<b_type, c_type>();

        if (arg->measure_only) return dnnl_success;
    }

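    // A final heuristic check can still route the f32 problem to the no-copy
    // SGEMM kernels before any copy-based blocking is set up.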
    if (nocopy_checker(nthr_goal, arg))
        return call_no_copy_sgemm(nthr_goal, arg);

    if (nthr_goal == 1)
        return gemm_kernel_driver(0, arg->m, arg->n, arg->k, arg->a, arg->b,
                arg->beta, arg->c, arg->ldc, arg->offsetc, arg->co, arg);

    bool k_blocking = force_threading && (force_threading->nthrs_k > 1);
    bool k_summing = k_blocking && !packing;

    auto *thread_arg = (gemm_per_thread_t<c_type> *)malloc(
            sizeof(gemm_per_thread_t<c_type>) * nthr_max, PAGE_4K);

    if (!thread_arg) return dnnl_out_of_memory;

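    // Fill in per-thread bookkeeping; when a forced decomposition is in use,
    // the problem slice for each thread is already known here.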
    dim_t max_mt = 0, max_nt = 0;
    for (int ithr = 0; ithr < nthr_max; ithr++) {
        thread_arg[ithr].result = dnnl_success;
        thread_arg[ithr].compute_done = false;
        thread_arg[ithr].c_local = thread_arg[ithr].c_global = nullptr;
        thread_arg[ithr].ldc_global = arg->ldc;
        thread_arg[ithr].ldc_local = 0;

        if (force_threading) {
            thread_arg[ithr].slice = force_threading->get_thread_slice(
                    ithr, arg->m, arg->n, arg->k);
            thread_arg[ithr].nthr_k = force_threading->nthrs_k;
            thread_arg[ithr].thr_k_stride = force_threading->thr_k_stride();
            max_mt = nstl::max(max_mt, thread_arg[ithr].slice.m);
            max_nt = nstl::max(max_nt, thread_arg[ithr].slice.n);
        } else {
            thread_arg[ithr].slice = {0, 0, 0, 0, 0, 0, 0, 0, 0};
            thread_arg[ithr].nthr_k = 1;
            thread_arg[ithr].thr_k_stride = 0;
        }
    }

    // Create temporary C buffers for k blocking if needed.
    c_type *c_local_storage = nullptr;
    if (k_summing) {
        const dim_t BAD_LD_MULT = 256;
        dim_t ldc_local = max_mt % BAD_LD_MULT
                ? max_mt
                : gemm_utils::get_ld_padd<c_type>(max_mt);
        dim_t c_local_stride = ldc_local * max_nt;
        c_local_storage = (c_type *)malloc(
                sizeof(c_type) * c_local_stride * nthr_goal, PAGE_4K);

        if (!c_local_storage) {
            free(thread_arg);
            return dnnl_out_of_memory;
        }

        for (int ithr = 0; ithr < nthr_goal; ithr++) {
            thread_arg[ithr].c_local = c_local_storage + ithr * c_local_stride;
            thread_arg[ithr].ldc_local = ldc_local;
        }
    }

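    // Scratch buffer handed to parallel_a_copy() when a single packed copy of
    // A is shared between threads.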
    char *shared_mem = nullptr;

    // Always use the maximum number of threads to avoid OMP overhead that can
    // occur due to changing thread counts.
    int nthr_spawn = dnnl_thr_syncable() ? nthr_max : nthr_goal;

    parallel(nthr_spawn, [&](int ithr, int nthr) {
        int nthr_eff = force_threading ? nthr_goal : nstl::min(nthr_goal, nthr);

        if (nthr_eff == 1) {
            thread_arg[0].result = gemm_kernel_driver(0, arg->m, arg->n, arg->k,
                    arg->a, arg->b, arg->beta, arg->c, arg->ldc, arg->offsetc,
                    arg->co, arg);
        } else {
            gemm_threading_t thread_info;

            if (force_threading)
                thread_info = *force_threading;
            else {
                nthr_eff = set_thread_opts(nthr_eff, nthr, thread_info, arg);
                if (ithr < nthr_eff)
                    thread_arg[ithr].slice = thread_info.get_thread_slice(
                            ithr, arg->m, arg->n, arg->k);
            }

            for (; ithr < nthr_eff; ithr += nthr) {
                // Get submatrices and parameters for this thread's GEMM.
                const a_type *a = nullptr;
                const b_type *b = nullptr;
                c_type *c = nullptr;
                const c_type *co = nullptr;
                std::tie(a, b, c, co)
                        = decompose_matrices(thread_arg[ithr].slice, arg);

                auto m = thread_arg[ithr].slice.m;
                auto n = thread_arg[ithr].slice.n;
                auto k = thread_arg[ithr].slice.k;
                thread_arg[ithr].c_global = c;
                auto c_eff = c;
                auto ldc_eff = arg->ldc;
                auto beta_eff = arg->beta;
                auto offsetc_eff = arg->offsetc;

                // For all but the first k block: substitute a local C matrix
                // and disable post-ops.
                if (k_summing && thread_arg[ithr].slice.ithr_k > 0) {
                    c_eff = thread_arg[ithr].c_local;
                    ldc_eff = thread_arg[ithr].ldc_local;
                    beta_eff = 0;
                    offsetc_eff = offset_type::none;
                }

                // Dispatch appropriate GEMM driver.
                switch (thread_info.copy) {
                    case copy_type::shared_a:
                        thread_arg[ithr].result = parallel_a_copy(ithr,
                                nthr_eff, m, n, k, a, b, beta_eff, c_eff,
                                ldc_eff, offsetc_eff, co, arg, &shared_mem);
                        break;

                    default:
                    case copy_type::nonshared:
                        thread_arg[ithr].result = gemm_kernel_driver(ithr, m, n,
                                k, a, b, beta_eff, c_eff, ldc_eff, offsetc_eff,
                                co, arg);
                        break;

                    case copy_type::no_copy:
                        // This route is taken only if we realize we need
                        // no-copy after launching the parallel section,
                        // because fewer threads were spawned than expected.
                        assert(data_traits<a_type>::data_type
                                == data_type::f32);
                        assert(arg->packing == pack_type::none);

                        if (mayiuse(avx512_core)) {
                            avx512_common_gemm_f32::sgemm_nocopy_driver(
                                    arg->transa == no_trans ? "N" : "T",
                                    arg->transb == no_trans ? "N" : "T", m, n,
                                    k, &arg->alpha, (float *)a, arg->lda,
                                    (float *)b, arg->ldb, &beta_eff,
                                    (float *)c_eff, ldc_eff, nullptr, nullptr);
                        } else {
                            avx_gemm_f32::sgemm_nocopy_driver(
                                    arg->transa == no_trans ? "N" : "T",
                                    arg->transb == no_trans ? "N" : "T", m, n,
                                    k, &arg->alpha, (float *)a, arg->lda,
                                    (float *)b, arg->ldb, &beta_eff,
                                    (float *)c_eff, ldc_eff, nullptr, nullptr);
                        }
                        thread_arg[ithr].result = dnnl_success;
                        break;
                }

                // Sum thread results along the k dimension, parallelized in
                // the n dimension. To avoid deadlocks, results are summed
                // later if not all threads are running concurrently. We can
                // only detect if this is safe when using OpenMP.
#if DNNL_THR_SYNC == 1
                if (k_summing && (nthr >= nthr_eff)) {
                    thread_arg[ithr].compute_done = true;
                    sum_k_blocks(ithr, thread_arg, true);
                }
#endif
            }
        }
    });

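    // Propagate the first failure recorded by any worker thread.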
    dnnl_status_t result = dnnl_success;
    for (int ithr = 0; ithr < nthr_max; ithr++) {
        if (thread_arg[ithr].result != dnnl_success) {
            result = static_cast<dnnl_status_t>(thread_arg[ithr].result);
            break;
        }
    }

    // Sum thread results along the k dimension if this wasn't done earlier.
    if (k_summing && !thread_arg[0].compute_done) {
        parallel(nthr_goal, [&](int ithr, int nthr) {
            for (; ithr < nthr_goal; ithr += nthr)
                sum_k_blocks(ithr, thread_arg, false);
        });
    }

    if (c_local_storage) dnnl::impl::free(c_local_storage);
    dnnl::impl::free(thread_arg);

    return result;
}

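// Public entry point for the x64 GEMM drivers: asserts ISA support for the
// requested data types, builds the gemm_info_t problem descriptor, and hands
// it to the threading driver.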
template <typename a_type, typename b_type, typename c_type>
dnnl_status_t gemm_driver(const char *transA, const char *transB,
        const char *offsetC, const dim_t *m, const dim_t *n, const dim_t *k,
        const float *alpha, const a_type *a, const dim_t *lda, const a_type *oa,
        const b_type *b, const dim_t *ldb, const b_type *ob, const float *beta,
        c_type *c, const dim_t *ldc, const c_type *oc, const bool force_nocopy,
        pack_type packing, gemm_pack_storage_t *pack_dst, bool measure_only) {

    constexpr bool is_int8 = utils::one_of(
            data_traits<a_type>::data_type, data_type::s8, data_type::u8);
    MAYBE_UNUSED(is_int8);

    // gemm_driver supports bfloat16 gemm for Intel AVX512 and
    // Intel AVX512 BF16.
    assert(IMPLICATION(data_traits<a_type>::data_type == data_type::bf16,
            mayiuse(avx512_core) && !force_nocopy));

    // gemm_driver supports 8-bit integer gemm for Intel AVX512, Intel AVX2,
    // Intel AVX, Intel SSE4.1 and Intel DL Boost.
    assert(IMPLICATION(is_int8, mayiuse(sse41) && !mayiuse(avx512_mic)));

    // gemm_driver supports sgemm for Intel AVX512, Intel AVX2, Intel AVX,
    // and Intel SSE4.1.
    assert(IMPLICATION(
            data_traits<a_type>::data_type == data_type::f32, mayiuse(sse41)));

    // 8-bit integer gemm doesn't support nocopy kernels.
    assert(IMPLICATION(is_int8, !force_nocopy));

    // gemm_driver can only dispatch nocopy for avx and above.
    assert(IMPLICATION(force_nocopy, mayiuse(avx)));

    gemm_info_t<a_type, b_type, c_type> args(transA, transB, offsetC, m, n, k,
            alpha, a, lda, oa, b, ldb, ob, beta, c, ldc, oc, force_nocopy,
            packing, pack_dst, measure_only);

    // Check if copy algorithm kernels were generated on supported ISAs.
    if (!args.hasKernels()) return dnnl_unimplemented;

    return gemm_threading_driver(&args);
}

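// Illustrative sketch of a call through the float instantiation below. The
// sizes, the "F" (fixed) offset convention, and the zero offsets are
// hypothetical examples, not requirements of this driver:
//
//   dim_t m = 64, n = 48, k = 32, lda = m, ldb = k, ldc = m;
//   float alpha = 1.0f, beta = 0.0f;
//   std::vector<float> a(lda * k), b(ldb * n), c(ldc * n);
//   float oa = 0.0f, ob = 0.0f, oc = 0.0f;
//   auto st = gemm_driver<float, float, float>("N", "N", "F", &m, &n, &k,
//           &alpha, a.data(), &lda, &oa, b.data(), &ldb, &ob, &beta, c.data(),
//           &ldc, &oc, /*force_nocopy=*/false, pack_type::none,
//           /*pack_dst=*/nullptr, /*measure_only=*/false);
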
template // Instantiate gemm_bf16bf16f32
dnnl_status_t
gemm_driver<bfloat16_t, bfloat16_t, float>(const char *transA,
        const char *transB, const char *offsetC, const dim_t *m,
        const dim_t *n, const dim_t *k, const float *alpha,
        const bfloat16_t *a, const dim_t *lda, const bfloat16_t *oa,
        const bfloat16_t *b, const dim_t *ldb, const bfloat16_t *ob,
        const float *beta, float *c, const dim_t *ldc, const float *oc,
        const bool force_nocopy, pack_type packing,
        gemm_pack_storage_t *pack_dst, bool measure_only);

template // Instantiate gemm_s8s8s32
dnnl_status_t
gemm_driver<int8_t, int8_t, int32_t>(const char *transA,
        const char *transB, const char *offsetC, const dim_t *m,
        const dim_t *n, const dim_t *k, const float *alpha,
        const int8_t *a, const dim_t *lda, const int8_t *oa,
        const int8_t *b, const dim_t *ldb, const int8_t *ob,
        const float *beta, int32_t *c, const dim_t *ldc,
        const int32_t *oc, const bool force_nocopy, pack_type packing,
        gemm_pack_storage_t *pack_dst, bool measure_only);

template // Instantiate gemm_s8u8s32
dnnl_status_t
gemm_driver<int8_t, uint8_t, int32_t>(const char *transA,
        const char *transB, const char *offsetC, const dim_t *m,
        const dim_t *n, const dim_t *k, const float *alpha,
        const int8_t *a, const dim_t *lda, const int8_t *oa,
        const uint8_t *b, const dim_t *ldb, const uint8_t *ob,
        const float *beta, int32_t *c, const dim_t *ldc,
        const int32_t *oc, const bool force_nocopy, pack_type packing,
        gemm_pack_storage_t *pack_dst, bool measure_only);

template // Instantiate sgemm
dnnl_status_t
gemm_driver<float, float, float>(const char *transA, const char *transB,
        const char *offsetC, const dim_t *m, const dim_t *n,
        const dim_t *k, const float *alpha, const float *a,
        const dim_t *lda, const float *oa, const float *b,
        const dim_t *ldb, const float *ob, const float *beta, float *c,
        const dim_t *ldc, const float *oc, const bool force_nocopy,
        pack_type packing, gemm_pack_storage_t *pack_dst,
        bool measure_only);

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl