/*******************************************************************************
* Copyright 2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory_tracking.hpp"
#include "common/tag_traits.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_primitive.hpp"

#include "cpu/x64/amx_tile_configure.hpp"
#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
#include "cpu/x64/matmul/brgemm_matmul.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {
namespace matmul {

using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;

using namespace nstl;

using namespace data_type;

template <cpu_isa_t isa>
status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
    const auto src_dt = src_md_.data_type;
    const auto wei_dt = weights_md_.data_type;
    const auto dst_dt = dst_md_.data_type;

    const bool is_f32 = everyone_is(f32, src_dt, wei_dt, dst_dt);
    const bool is_int8 = one_of(src_dt, u8, s8) && wei_dt == s8
            && one_of(dst_dt, u8, s8, s32, f32, bf16);
    const bool is_bf16
            = everyone_is(bf16, src_dt, wei_dt) && one_of(dst_dt, bf16, f32);

    auto check_bias = [&]() -> bool {
        const bool is_bia_dt_correct
                = (is_int8
                          && one_of(weights_md(1)->data_type, f32, s32, s8, u8,
                                  bf16))
                || (is_bf16 && one_of(weights_md(1)->data_type, f32, bf16))
                || (is_f32 && weights_md(1)->data_type == f32);
        return IMPLICATION(with_bias(), is_bia_dt_correct && is_bias_1xN());
    };

    auto check_attr_oscale = [&]() -> bool {
        const auto &oscale = attr()->output_scales_;
        return IMPLICATION(
                oscale.mask_ != 0, oscale.mask_ == (1 << (dst_md_.ndims - 1)));
    };

    auto check_attr_zero_points
            = [&]() -> bool { return attr()->zero_points_.common(); };

    const bool problem_dt_correct = is_int8 || is_bf16 || is_f32;
    bool ok = mayiuse(isa) && problem_dt_correct
            && !has_runtime_dims_or_strides()
            && attr()->has_default_values(primitive_attr_t::skip_mask_t::oscale
                            | primitive_attr_t::skip_mask_t::zero_points_runtime
                            | primitive_attr_t::skip_mask_t::post_ops
                            | primitive_attr_t::skip_mask_t::sum_dt,
                    dst_dt)
            && attr()->post_ops_.check_sum_consistent_dt(dst_dt)
            && check_attr_oscale() && check_attr_zero_points() && check_bias();
    if (!ok) return status::unimplemented;

    CHECK(init_brgemm_matmul_conf(isa, bgmmc_, *desc(), src_md_, weights_md_,
            dst_md_, bias_md_, attr_));

    const float alpha = 1.0;
    const float beta = 1.0;
    const float beta_init = 0.0;
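    // Create brgemm descriptors for every combination of (initialize C vs.
    // accumulate into C) x (full/tail M block) x (full/tail N block) x
    // (full/tail K block); get_brg_kernel_idx() maps a combination to its
    // slot in brg_descs_.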
    for_(int i_init = 0; i_init < 2; i_init++)
    for_(int i_M = 0; i_M < 2; i_M++)
    for_(int i_N = 0; i_N < 2; i_N++)
    for (int i_K = 0; i_K < 2; i_K++) {
        auto vbeta = (i_init) ? beta_init : beta;
        auto vM = (i_M) ? bgmmc_.M_tail : bgmmc_.M_blk;
        auto vN = (i_N) ? bgmmc_.N_tail : bgmmc_.N_blk;
        auto vK = (i_K) ? bgmmc_.K_tail : bgmmc_.K_blk;

        int idx = get_brg_kernel_idx(i_init, i_M, i_N, i_K);
        if (idx < 0) continue;
        brgemm_t &brg = brg_descs_[idx];
        auto LDA = i_K && bgmmc_.use_buffer_a_tail_only
                ? (dim_t)bgmmc_.wei_k_blk
                : bgmmc_.LDA;
        CHECK(brgemm_desc_init(&brg, isa, bgmmc_.brg_type, bgmmc_.src_dt,
                bgmmc_.wei_dt, false, false, brgemm_row_major, alpha, vbeta,
                LDA, bgmmc_.LDB, bgmmc_.LDC, vM, vN, vK));

        auto LDD = bgmmc_.N;
        CHECK(brgemm_desc_set_postops(
                &brg, attr(), &dst_md_, LDD, bgmmc_.bia_dt));

        brgemm_attr_t brgattr;
        constexpr bool is_amx = one_of(
                isa, avx512_core_bf16_amx_int8, avx512_core_bf16_amx_bf16);
        if (is_amx) {
            brgattr.max_bs = bgmmc_.brgemm_batch_size;
            brgattr.wary_tail_read = false;

            // TODO: change expected sizes to local chunks wrt L2 blocking
            brgattr.hint_expected_A_size = vM * vK * bgmmc_.brgemm_batch_size;
            brgattr.hint_expected_B_size = vN * vK * bgmmc_.brgemm_batch_size;
            brgattr.hint_expected_C_size = vM * vN * bgmmc_.brgemm_batch_size;
            brgattr.hint_innermost_loop = brgemm_ld_loop_innermost;
        }

        brgattr.generate_skip_accumulation
                = bgmmc_.post_ops_applicable && bgmmc_.nthr_k > 1;

        CHECK(brgemm_desc_set_attr(&brg, brgattr));
    }

    auto scratchpad = scratchpad_registry().registrar();
    init_scratchpad(scratchpad, bgmmc_);

    return status::success;
}

template <cpu_isa_t isa>
status_t brgemm_matmul_t<isa>::init(engine_t *engine) {
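    // Generate a JIT kernel for every brgemm descriptor prepared in
    // pd_t::init(); for AMX ISAs, also pre-compute the tile configuration
    // palette associated with each kernel.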
    for_(int i_M = 0; i_M < 2; i_M++)
    for_(int i_N = 0; i_N < 2; i_N++)
    for_(int i_K = 0; i_K < 2; i_K++)
    for (int i_init = 0; i_init < 2; i_init++) {
        int idx = pd()->get_brg_kernel_idx(i_init, i_M, i_N, i_K);
        if (idx < 0) continue;

        brgemm_kernel_t *ker = nullptr;
        CHECK(brgemm_kernel_create(&ker, pd()->get_brg_desc(idx)));
        CHECK(safe_ptr_assign(brg_kernels_[idx], ker));
        if (one_of(isa, avx512_core_bf16_amx_int8, avx512_core_bf16_amx_bf16))
            CHECK(brgemm_init_tiles(
                    pd()->get_brg_desc(idx), &brg_kernel_palettes_[idx][0]));
    }

    const auto &bgmmc = pd()->get_brgemm_matmul_conf();
    if (bgmmc.use_buffer_b)
        CHECK(create_brgemm_matmul_copy_b(copy_B_kernel_, &bgmmc));

    if (bgmmc.use_buffer_a || bgmmc.use_buffer_a_tail_only)
        CHECK(create_brgemm_matmul_copy_a(copy_A_kernel_, &bgmmc));

    if (bgmmc.nthr_k > 1 && bgmmc.acc_dt == f32) {
        CHECK(safe_ptr_assign(
                acc_ker_f32_, new cpu_accumulator_1d_t<data_type::f32>()));
        CHECK(acc_ker_f32_->create_kernel());
    } else if (bgmmc.nthr_k > 1 && bgmmc.acc_dt == s32) {
        CHECK(safe_ptr_assign(
                acc_ker_s32_, new cpu_accumulator_1d_t<data_type::s32>()));
        CHECK(acc_ker_s32_->create_kernel());
    }

    return status::success;
}

template <cpu_isa_t isa>
status_t brgemm_matmul_t<isa>::execute_body(const exec_ctx_t &ctx) const {
    DEFINE_ZERO_POINT_VALUE(src_zero_point, DNNL_ARG_SRC);
    DEFINE_ZERO_POINT_VALUE(wei_zero_point, DNNL_ARG_WEIGHTS);
    DEFINE_ZERO_POINT_VALUE(dst_zero_point, DNNL_ARG_DST);

    brg_matmul_exec_ctx_t brgmm_ctx(
            ctx, pd(), src_zero_point, wei_zero_point, dst_zero_point);

    const auto &bgmmc = pd()->get_brgemm_matmul_conf();
    const bool use_buffer_a
            = bgmmc.use_buffer_a || bgmmc.use_buffer_a_tail_only;
    constexpr bool is_amx
            = one_of(isa, avx512_core_bf16_amx_int8, avx512_core_bf16_amx_bf16);
    const int num_threads = brgmm_ctx.get_num_threads_for_parallelization();
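    // Threads are organized as a 2D grid: nthr_bmn threads split the
    // batch/M/N chunk space and nthr_k threads split the K chunks for
    // parallel reduction over K (see brg_matmul_exec_ctx_t).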
    parallel(num_threads, [&](const int ithr, const int nthr) {
        const int ithr_bmn = brgmm_ctx.get_thread_idx_for_bmn(ithr);
        const int ithr_k = brgmm_ctx.get_thread_idx_for_k(ithr);
        if (ithr_bmn < 0 || ithr_k < 0) return;
        int start {0}, end {0};
        balance211(brgmm_ctx.get_parallel_work_amount(),
                brgmm_ctx.get_num_threads_for_bmn(), ithr_bmn, start, end);
        int kc_start {0}, kc_end {bgmmc.K_chunks};
        if (brgmm_ctx.parallel_reduction_is_used())
            balance211((int)bgmmc.K_chunks, brgmm_ctx.get_num_threads_for_k(),
                    ithr_k, kc_start, kc_end);

        if (is_amx) {
            const auto base_ker_idx = brgmm_ctx.get_base_brgemm_kernel_idx();
            amx_tile_configure(&brg_kernel_palettes_[base_ker_idx][0]);
        }

        int b {0}, mc {0}, nc {0};
        nd_iterator_init(
                start, b, bgmmc.batch, mc, bgmmc.M_chunks, nc, bgmmc.N_chunks);
        while (start < end) {
            auto m_start = mc * bgmmc.M_chunk_size;
            auto m_end = nstl::min(
                    (mc + 1) * bgmmc.M_chunk_size, bgmmc.num_M_blocks);
            auto n_start = nc * bgmmc.N_chunk_size;
            auto n_end = nstl::min(
                    (nc + 1) * bgmmc.N_chunk_size, bgmmc.num_N_blocks);
            for_(int kc = kc_start; kc < kc_end; kc++)
            for (int nb = n_start; nb < n_end; nb++) {
                if (bgmmc.use_buffer_b)
                    copy_b_chunk_in_buffer(brgmm_ctx, ithr, b, nb, kc);
                for (int mb = m_start; mb < m_end; mb++) {
                    if (use_buffer_a && nb == n_start)
                        copy_a_chunk_in_buffer(brgmm_ctx, ithr, b, mb, kc);
                    compute_kernel(
                            brgmm_ctx, ithr, b, mb, nb, kc, kc == kc_start);
                }
            }
            ++start;
            nd_iterator_step(
                    b, bgmmc.batch, mc, bgmmc.M_chunks, nc, bgmmc.N_chunks);
        }
        if (is_amx) { amx_tile_release(); }
    });

    maybe_reduce_partial_results_and_apply_postops(brgmm_ctx);

    return status::success;
}

template <cpu_isa_t isa>
void brgemm_matmul_t<isa>::compute_kernel(
        const brg_matmul_exec_ctx_t &brgmm_ctx, int ithr, int b_idx,
        int m_blk_idx, int n_blk_idx, int k_chunk_idx, bool do_init) const {
    constexpr bool is_amx
            = one_of(isa, avx512_core_bf16_amx_int8, avx512_core_bf16_amx_bf16);
    const auto &bgmmc = pd()->get_brgemm_matmul_conf();
    const auto addr_batch = brgmm_ctx.get_batch_elem_ptr(ithr);
    const int base_brg_ker_idx = brgmm_ctx.get_base_brgemm_kernel_idx();

    const auto wsp_tile = brgmm_ctx.get_tile_workspace(ithr);
    const int m = m_blk_idx * bgmmc.M_blk;
    const int n = n_blk_idx * bgmmc.N_blk;
    const int k_blk_idx = k_chunk_idx * bgmmc.brgemm_batch_size;

    const bool is_M_tail = (bgmmc.M - m < bgmmc.M_blk);
    const bool is_N_tail = (bgmmc.N - n < bgmmc.N_blk);
    const bool is_last_K_chunk = brgmm_ctx.is_last_K_chunk(k_chunk_idx);
    const bool is_K_tail = is_last_K_chunk && bgmmc.K_tail > 0;

    const int gemm_batch = brgmm_ctx.get_brgemm_batch_size(k_chunk_idx);
    const int brg_ker_idx
            = pd()->get_brg_kernel_idx(do_init, is_M_tail, is_N_tail, false);
    const auto brg_kernel = brg_kernels_[brg_ker_idx].get();
    const auto ptr_bias = brgmm_ctx.get_bias_ptr(n);
    auto ptr_D = brgmm_ctx.get_data_C_ptr(b_idx, m, n);
    auto ptr_C = (bgmmc.use_buffer_c)
            ? brgmm_ctx.get_buf_C_ptr(ithr, m_blk_idx, n_blk_idx)
            : ptr_D;

    const auto zp_comp_a = brgmm_ctx.get_zp_a_compensation_ptr(ithr, n_blk_idx);
    const auto zp_comp_b
            = brgmm_ctx.get_zp_b_compensation_result_ptr(ithr, m_blk_idx);
    const auto zp_c_val_ptr = brgmm_ctx.get_zp_c_val_ptr();
    const auto &post_ops_binary_rhs_arg_vec
            = brgmm_ctx.get_post_ops_binary_rhs_arg_vec();
    const bool post_ops_applicable = bgmmc.post_ops_applicable
            && (bgmmc.nthr_k <= 1 || bgmmc.K_chunks == 1);

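    // Main part of the K chunk: process the full K blocks in one brgemm
    // batch call. Post-ops are fused here only on the last K chunk and only
    // if no K tail remains; otherwise they are applied by the K-tail call
    // below or by the separate K-parallel reduction path.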
    if (gemm_batch > 0 && brg_kernel != nullptr) {
        const bool is_tile_reconf_required = is_amx && (is_M_tail || is_N_tail);
        if (is_tile_reconf_required)
            amx_tile_configure(&brg_kernel_palettes_[brg_ker_idx][0]);

        brgmm_ctx.init_brgemm_batch_elements_values(
                ithr, 0, gemm_batch, b_idx, m_blk_idx, k_blk_idx, n_blk_idx);

        if (post_ops_applicable && is_last_K_chunk && !is_K_tail) {
            void *scratch = is_amx
                    ? static_cast<void *>(wsp_tile)
                    : static_cast<void *>(brgmm_ctx.get_s8s8_comp_ptr(
                            ithr, b_idx, n_blk_idx));

            const size_t dst_row_logical_off = m_blk_idx * bgmmc.M_blk;
            const size_t batch_first_dim_idx = bgmmc.batch_ndims > 1
                    ? b_idx / bgmmc.batch_without_first_dim
                    : 0;
            const size_t first_mb_matrix_addr_off
                    = batch_first_dim_idx * (bgmmc.M * bgmmc.N)
                    + (m * bgmmc.N + n);
            const brgemm_post_ops_data_t post_ops_data {
                    static_cast<const void *>(ptr_bias),
                    brgmm_ctx.get_oscales_ptr(n),
                    post_ops_binary_rhs_arg_vec.data(), static_cast<size_t>(n),
                    dst_row_logical_off, brgmm_ctx.get_data_C_ptr(0, 0, 0),
                    first_mb_matrix_addr_off,
                    static_cast<const void *>(zp_comp_a),
                    static_cast<const void *>(zp_comp_b),
                    static_cast<const void *>(zp_c_val_ptr)};

            brgemm_kernel_execute_postops(brg_kernel, gemm_batch, addr_batch,
                    (void *)ptr_C, (void *)ptr_D, post_ops_data, scratch);
        } else {
            brgemm_kernel_execute(brg_kernel, gemm_batch, addr_batch,
                    (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr);
        }

        if (is_tile_reconf_required)
            amx_tile_configure(&brg_kernel_palettes_[base_brg_ker_idx][0]);
    }
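    // K tail: the remainder of K is processed by a dedicated tail kernel;
    // on AMX a tile reconfiguration is needed when the tail size differs
    // from K_blk.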
    if (is_K_tail) {
        brgmm_ctx.init_brgemm_batch_elements_values(
                ithr, gemm_batch, 1, b_idx, m_blk_idx, k_blk_idx, n_blk_idx);

        const bool use_init_ker = (do_init && gemm_batch == 0);
        const int brg_ker_idx = pd()->get_brg_kernel_idx(
                use_init_ker, is_M_tail, is_N_tail, true);
        const auto brg_kernel_k_tail = brg_kernels_[brg_ker_idx].get();
        const bool is_tile_reconf_required
                = is_amx && bgmmc.K_tail != bgmmc.K_blk;
        if (is_tile_reconf_required)
            amx_tile_configure(&brg_kernel_palettes_[brg_ker_idx][0]);
        if (post_ops_applicable) {
            void *scratch = is_amx
                    ? static_cast<void *>(wsp_tile)
                    : static_cast<void *>(brgmm_ctx.get_s8s8_comp_ptr(
                            ithr, b_idx, n_blk_idx));

            const size_t dst_row_logical_off = m_blk_idx * bgmmc.M_blk;
            const size_t batch_first_dim_idx = bgmmc.batch_ndims > 1
                    ? b_idx / bgmmc.batch_without_first_dim
                    : 0;
            const size_t first_mb_matrix_addr_off
                    = batch_first_dim_idx * (bgmmc.M * bgmmc.N)
                    + (m * bgmmc.N + n);
            const brgemm_post_ops_data_t post_ops_data {
                    static_cast<const void *>(ptr_bias),
                    brgmm_ctx.get_oscales_ptr(n),
                    post_ops_binary_rhs_arg_vec.data(), static_cast<size_t>(n),
                    dst_row_logical_off, brgmm_ctx.get_data_C_ptr(0, 0, 0),
                    first_mb_matrix_addr_off,
                    static_cast<const void *>(zp_comp_a),
                    static_cast<const void *>(zp_comp_b),
                    static_cast<const void *>(zp_c_val_ptr)};

            brgemm_kernel_execute_postops(brg_kernel_k_tail, 1, addr_batch,
                    (void *)ptr_C, (void *)ptr_D, post_ops_data, scratch);
        } else {
            brgemm_kernel_execute(brg_kernel_k_tail, 1, addr_batch,
                    (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr);
        }
        if (is_tile_reconf_required)
            amx_tile_configure(&brg_kernel_palettes_[base_brg_ker_idx][0]);
    }
}

template <cpu_isa_t isa>
void brgemm_matmul_t<isa>::maybe_reduce_partial_results_and_apply_postops(
        const brg_matmul_exec_ctx_t &brgmm_ctx) const {
    if (!brgmm_ctx.parallel_reduction_is_used()) return;
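    // Sum the per-thread partial C buffers produced by the K-parallel
    // threads and, when post-ops are present, apply them while converting
    // the accumulated values to the destination data type.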

    const auto &bgmmc = pd()->get_brgemm_matmul_conf();
    const int num_threads = brgmm_ctx.get_num_threads_for_parallelization();

    parallel(num_threads, [&](const int ithr, const int nthr) {
        const int nthr_k = brgmm_ctx.get_num_threads_for_k();
        const int ithr_bmn = brgmm_ctx.get_thread_idx_for_bmn(ithr);
        const int ithr_k = brgmm_ctx.get_thread_idx_for_k(ithr);
        if (ithr_bmn < 0 || ithr_k < 0) return;

        const int num_reduction_buffers = nstl::min(nthr_k, bgmmc.K_chunks);

        int bmn_start {0}, bmn_end {0};
        int start {0}, end {0};
        balance211(brgmm_ctx.get_parallel_work_amount(),
                brgmm_ctx.get_num_threads_for_bmn(), ithr_bmn, bmn_start,
                bmn_end);
        balance211(bmn_end - bmn_start, nthr_k, ithr_k, start, end);

        int b {0}, mc {0}, nc {0};

        assert(bgmmc.batch == 1);
        nd_iterator_init(bmn_start + start, b, bgmmc.batch, mc, bgmmc.M_chunks,
                nc, bgmmc.N_chunks);
        while (start < end) {
            auto mb_start = mc * bgmmc.M_chunk_size;
            auto mb_end = nstl::min(
                    (mc + 1) * bgmmc.M_chunk_size, bgmmc.num_M_blocks);
            auto nb_start = nc * bgmmc.N_chunk_size;
            auto nb_end = nstl::min(
                    (nc + 1) * bgmmc.N_chunk_size, bgmmc.num_N_blocks);
            for (int mb = mb_start; mb < mb_end; mb++) {
                const int curr_M_blk
                        = nstl::min(bgmmc.M - mb * bgmmc.M_blk, bgmmc.M_blk);
                const bool is_M_tail = curr_M_blk < bgmmc.M_blk;
                const int curr_N_chunk_size
                        = nstl::min(bgmmc.N, nb_end * bgmmc.N_blk)
                        - nb_start * bgmmc.N_blk;
                char *buf_reduced_base = brgmm_ctx.get_buf_C_par_reduction_ptr(
                        0, mb, nb_start);
                const size_t m_offset = bgmmc.LDC * bgmmc.acc_dt_sz;
                for (int r = 1; r < num_reduction_buffers; r++) {
                    const char *buf_to_reduce_base
                            = brgmm_ctx.get_buf_C_par_reduction_ptr(
                                    r, mb, nb_start);
                    for (int m = 0; m < curr_M_blk; m++) {
                        accumulate(buf_reduced_base + m * m_offset,
                                buf_to_reduce_base + m * m_offset,
                                curr_N_chunk_size);
                    }
                }
                if (bgmmc.post_ops_applicable) {
                    for (int nb = nb_start; nb < nb_end; nb++) {
                        const bool is_N_tail
                                = (bgmmc.N - nb * bgmmc.N_blk < bgmmc.N_blk);
                        const int brg_ker_idx = pd()->get_brg_kernel_idx(
                                false, is_M_tail, is_N_tail, false);
                        const auto brg_kernel = brg_kernels_[brg_ker_idx].get();
                        const int m = mb * bgmmc.M_blk;
                        const int n = nb * bgmmc.N_blk;
                        const auto ptr_bias = brgmm_ctx.get_bias_ptr(n);
                        auto ptr_D = brgmm_ctx.get_data_C_ptr(b, m, n);
                        auto ptr_C = brgmm_ctx.get_buf_C_par_reduction_ptr(
                                0, mb, nb);

                        // TODO: support reduction for zp/s8s8 compensations
                        // computed in copy routines
                        const auto zp_comp_a
                                = brgmm_ctx.get_zp_a_compensation_ptr(ithr, nb);
                        const auto zp_comp_b
                                = brgmm_ctx.get_zp_b_compensation_result_ptr(
                                        ithr, mb);
                        const auto zp_c_val_ptr = brgmm_ctx.get_zp_c_val_ptr();
                        const auto &post_ops_binary_rhs_arg_vec
                                = brgmm_ctx.get_post_ops_binary_rhs_arg_vec();

                        const size_t dst_row_logical_off = mb * bgmmc.M_blk;
                        const size_t batch_first_dim_idx = bgmmc.batch_ndims > 1
                                ? b / bgmmc.batch_without_first_dim
                                : 0;
                        const size_t first_mb_matrix_addr_off
                                = batch_first_dim_idx * (bgmmc.M * bgmmc.N)
                                + (m * bgmmc.N + n);
                        // apply post-ops and convert to dst data type only
                        constexpr bool skip_accumulation = true;
                        const brgemm_post_ops_data_t post_ops_data {
                                static_cast<const void *>(ptr_bias),
                                brgmm_ctx.get_oscales_ptr(n),
                                post_ops_binary_rhs_arg_vec.data(),
                                static_cast<size_t>(n), dst_row_logical_off,
                                brgmm_ctx.get_data_C_ptr(0, 0, 0),
                                first_mb_matrix_addr_off,
                                static_cast<const void *>(zp_comp_a),
                                static_cast<const void *>(zp_comp_b),
                                static_cast<const void *>(zp_c_val_ptr),
                                skip_accumulation};

                        brgemm_kernel_execute_postops(brg_kernel, 0, nullptr,
                                (void *)ptr_C, (void *)ptr_D, post_ops_data,
                                nullptr);
                    }
                }
            }
            ++start;
            nd_iterator_step(
                    b, bgmmc.batch, mc, bgmmc.M_chunks, nc, bgmmc.N_chunks);
        }
    });
}

template <cpu_isa_t isa>
void brgemm_matmul_t<isa>::copy_a_chunk_in_buffer(
        const brg_matmul_exec_ctx_t &brgmm_ctx, int ithr, int b_idx,
        int m_blk_idx, int k_chunk_idx) const {
    const auto &bgmmc = pd()->get_brgemm_matmul_conf();

    auto ctx = jit_brgemm_matmul_copy_a_t::ctx_t();
    const int k_start = k_chunk_idx * bgmmc.K_chunk_elems;
    const bool is_K_tail
            = brgmm_ctx.is_last_K_chunk(k_chunk_idx) && bgmmc.K_tail > 0;
    const int gemm_batch = brgmm_ctx.get_brgemm_batch_size(k_chunk_idx);
    const int gemm_batch_iters = bgmmc.use_buffer_a_tail_only ? 0 : gemm_batch;
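    // With use_buffer_a_tail_only, only the K tail of A is copied into the
    // buffer, so the main-batch copy loop below runs zero iterations.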

    const int m = m_blk_idx * bgmmc.M_blk;
    const bool is_M_tail = (bgmmc.M - m < bgmmc.M_blk);
    ctx.current_M_blk = is_M_tail ? bgmmc.M_tail : bgmmc.M_blk;
    ctx.zp_b_compensation_buffer_ptr
            = (void *)brgmm_ctx.get_zp_b_compensation_buffer_ptr(
                    ithr, m_blk_idx);
    ctx.zp_a_compensation_result_ptr
            = (void *)brgmm_ctx.get_zp_b_compensation_result_ptr(
                    ithr, m_blk_idx);
    ctx.zp_b_neg_value_ptr = (void *)brgmm_ctx.get_zp_b_neg_val_ptr();
    ctx.zp_ab_comp_ptr = (void *)brgmm_ctx.get_zp_ab_mixed_comp_ptr();

    for (int gb = 0; gb < gemm_batch_iters; gb++) {
        const int k = k_start + gb * bgmmc.K_blk;
        ctx.src = (void *)brgmm_ctx.get_data_A_ptr(b_idx, m, k);
        ctx.tr_src = (void *)brgmm_ctx.get_buf_A_ptr(ithr, m_blk_idx, gb);
        ctx.current_K_blk = nstl::min(bgmmc.K_blk, bgmmc.K);
        ctx.current_K_start = k;

        (*copy_A_kernel_)(&ctx);
    }
    if (is_K_tail) {
        const auto K_tail = bgmmc.K % bgmmc.K_blk;
        const int k = k_start + gemm_batch * bgmmc.K_blk;
        ctx.src = (void *)brgmm_ctx.get_data_A_ptr(b_idx, m, k);
        ctx.tr_src = (void *)brgmm_ctx.get_buf_A_ptr(
                ithr, m_blk_idx, gemm_batch_iters);
        ctx.current_K_blk = K_tail;
        ctx.current_K_start = k;

        (*copy_A_kernel_)(&ctx);
    }
}

template <cpu_isa_t isa>
void brgemm_matmul_t<isa>::copy_b_chunk_in_buffer(
        const brg_matmul_exec_ctx_t &brgmm_ctx, int ithr, int b_idx,
        int n_blk_idx, int k_chunk_idx) const {
    const auto &bgmmc = pd()->get_brgemm_matmul_conf();

    const int k_start = k_chunk_idx * bgmmc.K_chunk_elems;
    const bool is_K_tail
            = brgmm_ctx.is_last_K_chunk(k_chunk_idx) && bgmmc.K_tail > 0;
    const int gemm_batch = brgmm_ctx.get_brgemm_batch_size(k_chunk_idx);
    auto ctx = jit_brgemm_matmul_copy_b_t::ctx_t();

    const int n = n_blk_idx * bgmmc.N_blk;
    const bool is_N_tail = (bgmmc.N - n < bgmmc.N_blk);
    ctx.current_N_blk = is_N_tail ? bgmmc.N_tail : bgmmc.N_blk;
    ctx.zp_a_compensation_ptr
            = (void *)brgmm_ctx.get_zp_a_compensation_ptr(ithr, n_blk_idx);
    ctx.zp_a_neg_value_ptr = (void *)brgmm_ctx.get_zp_a_neg_val_ptr();

    int gb = 0;
    for (; gb < gemm_batch; gb++) {
        const int k = k_start + gb * bgmmc.K_blk;
        ctx.src = (void *)brgmm_ctx.get_data_B_ptr(b_idx, k, n);
        ctx.tr_src = (void *)brgmm_ctx.get_buf_B_ptr(ithr, gb, n_blk_idx);
        ctx.compensation_ptr
                = (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx);
        ctx.current_K_start = k;
        ctx.current_K_iters = nstl::min(bgmmc.K_blk, bgmmc.K);

        (*copy_B_kernel_)(&ctx);
    }

    if (is_K_tail) {
        const int k = k_start + gb * bgmmc.K_blk;
        ctx.src = (void *)brgmm_ctx.get_data_B_ptr(b_idx, k, n);
        ctx.tr_src = (void *)brgmm_ctx.get_buf_B_ptr(ithr, gb, n_blk_idx);
        ctx.compensation_ptr
                = (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx);
        ctx.current_K_start = k;
        ctx.current_K_iters = bgmmc.K % bgmmc.K_blk;

        (*copy_B_kernel_)(&ctx);
    }
}

template <cpu_isa_t isa>
void brgemm_matmul_t<isa>::accumulate(
        char *result_ptr, const char *reduce_ptr, size_t size) const {
    if (pd()->get_brgemm_matmul_conf().acc_dt == f32)
        acc_ker_f32_->accumulate(
                (float *)result_ptr, (const float *)reduce_ptr, size);
    else if (pd()->get_brgemm_matmul_conf().acc_dt == s32)
        acc_ker_s32_->accumulate(
                (int32_t *)result_ptr, (const int32_t *)reduce_ptr, size);
    else
        assert(!"unsupported accumulation data type");
}

template <cpu_isa_t isa>
struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
    brg_matmul_exec_ctx_t(const exec_ctx_t &ctx, const pd_t *pd, int32_t src_zp,
            int32_t wei_zp, int32_t dst_zp)
        : bgmmc_(pd->get_brgemm_matmul_conf()) {

        data_A_ptr_ = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
        data_B_ptr_ = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
        data_C_ptr_ = CTX_OUT_MEM(char *, DNNL_ARG_DST);

        bias_ptr_ = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
        oscales_ptr_ = pd->attr()->output_scales_.scales_;
        memory_tracking::grantor_t scratchpad = ctx.get_scratchpad_grantor();
        const auto &bgmmc = pd->get_brgemm_matmul_conf();

        batch_element_ptr_ = scratchpad.template get<brgemm_batch_element_t>(
                key_brgemm_primitive_batch);

        const bool use_buffer_a
                = bgmmc.use_buffer_a || bgmmc.use_buffer_a_tail_only;
        buf_A_ptr_ = (use_buffer_a)
                ? scratchpad.template get<char>(key_brgemm_primitive_buffer_a)
                : nullptr;

        buf_B_ptr_ = (bgmmc.use_buffer_b)
                ? scratchpad.template get<char>(key_brgemm_primitive_buffer_b)
                : nullptr;

        buf_C_ptr_ = (bgmmc.use_buffer_c)
                ? scratchpad.template get<char>(key_brgemm_primitive_buffer)
                : nullptr;

        is_amx_ = one_of(
                isa, avx512_core_bf16_amx_int8, avx512_core_bf16_amx_bf16);
        wsp_tile_ptr_ = is_amx_
                ? ctx.get_scratchpad_grantor().template get<char>(
                        key_conv_amx_tile_buffer)
                : nullptr;

        const memory_desc_wrapper weights_d(pd->weights_md(0));
        const dim_t comp_offset = bgmmc_.b_dt_sz
                * (weights_d.size() - weights_d.additional_buffer_size());
        s8s8_compensation_ptr_ = (bgmmc.s8s8_compensation_required)
                ? ((bgmmc.use_buffer_b)
                                ? scratchpad.template get<int32_t>(
                                        key_brgemm_primitive_buffer_comp)
                                : const_cast<int32_t *>(
                                        reinterpret_cast<const int32_t *>(
                                                &data_B_ptr_[comp_offset])))
                : nullptr;

        zero_point_a_compensations_ptr_ = bgmmc.has_zero_point_a
                ? scratchpad.template get<int32_t>(
                        key_brgemm_primitive_zp_comp_a)
                : nullptr;
        zero_point_b_compensations_ptr_ = bgmmc.has_zero_point_b
                ? scratchpad.template get<int32_t>(
                        key_brgemm_primitive_zp_comp_b)
                : nullptr;

        zero_point_a_negative_val_ = -src_zp;
        zero_point_b_negative_val_ = -wei_zp;
        zero_point_mixed_ab_compensation_component_
                = bgmmc.K * zero_point_a_negative_val_;

        zero_point_c_val_ = dst_zp;

        post_ops_binary_rhs_arg_vec_ = binary_injector::prepare_binary_args(
                pd->attr()->post_ops_, ctx);
        base_brg_ker_idx_ = pd->get_brg_kernel_idx(true, false, false, false);
        vnni_factor = isa == avx512_core_bf16_amx_int8
                ? 4
                : isa == avx512_core_bf16_amx_bf16 ? 2 : 1;

        reorder_zp_a_comp_ptr_ = nullptr;
        if (bgmmc_.has_zero_point_a && bgmmc_.blocked_B) {
            // Store a pointer to the compensation values computed in the
            // reorder so they can be scaled locally by the zp_a value just
            // before usage in post-ops. Using a single global scaling before
            // the parallel section might produce significant overhead for
            // small problems running in multithreaded execution mode.
            const size_t reorder_zp_a_comp_offset
                    = weights_d.size() - weights_d.additional_buffer_size();
            const size_t s8s8_buffer_sz = bgmmc.s8s8_compensation_required
                    ? bgmmc.s8s8_comp_b_str * sizeof(int32_t)
                    : 0;
            reorder_zp_a_comp_ptr_
                    = const_cast<int32_t *>(reinterpret_cast<const int32_t *>(
                            &data_B_ptr_[reorder_zp_a_comp_offset
                                    + s8s8_buffer_sz]));
        }

        // parallelization
        parallel_work_amount_ = bgmmc.batch * bgmmc.M_chunks * bgmmc.N_chunks;

        // The number of threads available during primitive execution may
        // increase (e.g. Eigen threadpool implementation) or decrease
        // (e.g. nested parallelism) compared to the number of threads
        // available during primitive creation. So we limit the total number
        // of threads to the minimum of these two values to prevent potential
        // OOM issues.
        nthr_ = nstl::min(dnnl_get_current_num_threads(), bgmmc.nthr);

        nthr_k_ = bgmmc.nthr_k > 0 && bgmmc.nthr_k <= nthr_ ? bgmmc.nthr_k : 1;
        nthr_bmn_ = nthr_ / nthr_k_;
        num_threads_used_ = nthr_k_ * nthr_bmn_;

        // If parallel_work_amount_ == 1 and parallel reduction is not used,
        // we limit the number of threads to 1 because parallel(1, ...) does
        // not create a parallel section at all. We do not limit the number
        // of threads for the case
        // 1 < parallel_work_amount_ < dnnl_get_max_threads() to avoid the
        // potential overhead of spawning a different number of OMP threads
        // from layer to layer.
        if (parallel_work_amount_ == 1 && !parallel_reduction_is_used())
            nthr_ = nthr_bmn_ = nthr_k_ = 1;

        const bool need_to_calculate_compensation_for_a
                = bgmmc.has_zero_point_b;
        const bool need_to_calculate_compensation_for_b = !IMPLICATION(
                (bgmmc.has_zero_point_a || bgmmc.s8s8_compensation_required),
                bgmmc.blocked_B);
        const bool calculate_compensations_in_copy_routines
                = need_to_calculate_compensation_for_a
                || need_to_calculate_compensation_for_b;
        // Currently parallel reduction is supported only for the case of
        // non-batched problems without computation of any compensations in
        // the copy routines.
        assert(IMPLICATION(parallel_reduction_is_used(),
                bgmmc.batch == 1 && !calculate_compensations_in_copy_routines));
        MAYBE_UNUSED(need_to_calculate_compensation_for_a);
        MAYBE_UNUSED(need_to_calculate_compensation_for_b);
        MAYBE_UNUSED(calculate_compensations_in_copy_routines);
    }

    // NOTE: gb --> generalized batch, bb --> broadcast batch
    int get_bb_idx(int gb_idx, const brgemm_matmul_bcast_desc_t &bd) const {
        if (!bd.bcast_mask) // no broadcast
            return gb_idx;

        int gb_off_before_bcast = utils::rnd_dn(
                gb_idx, bd.first_bcast_dim_to_last_batch_dim_prod);
        int bb_idx = gb_off_before_bcast / (bd.bcast_dims_prod);

        dim_t cur_bcast_dims_prod = bd.bcast_dims_prod;
        int mask = 1 << (bgmmc_.batch_ndims - bd.first_bcast_dim - 1);
        for (int d = bd.first_bcast_dim; d < bd.last_bcast_dim; ++d) {
            if (bd.bcast_mask & mask) // broadcast
                cur_bcast_dims_prod /= bd.batch_dims[d];
            else {
                int cur_b = (gb_idx / bd.gb_off[d]) % bd.batch_dims[d];
                bb_idx += cur_b * (bd.gb_off[d] / cur_bcast_dims_prod);
            }
            mask >>= 1;
        }
        bb_idx += gb_idx % bd.gb_off[bd.last_bcast_dim];
        return bb_idx;
    }

    const char *get_data_A_ptr(int b, int m, int k) const {
        int cur_b = get_bb_idx(b, bgmmc_.bcast_A_desc);
        return data_A_ptr_ + get_data_A_off(cur_b, m, k);
    }

    const char *get_data_B_ptr(int b, int k, int n) const {
        int cur_b = get_bb_idx(b, bgmmc_.bcast_B_desc);
        return data_B_ptr_ + get_data_B_off(cur_b, k, n);
    }

    char *get_data_C_ptr(int b, int m, int n) const {
        return data_C_ptr_ + get_data_C_off(b, m, n);
    }

    brgemm_batch_element_t *get_batch_elem_ptr(int ithr) const {
        return batch_element_ptr_
                + ithr * bgmmc_.brgemm_batch_element_per_thr_sz;
    }

    void init_brgemm_batch_elements_values(int ithr, int brg_batch_start,
            int brg_batch_iters, int b_idx, int m_blk_idx, int k_blk_idx,
            int n_blk_idx) const {
        auto addr_batch = get_batch_elem_ptr(ithr);

        const int m = m_blk_idx * bgmmc_.M_blk;
        const int n = n_blk_idx * bgmmc_.N_blk;

        for (int b_iter = 0; b_iter < brg_batch_iters; b_iter++) {
            const int brg_batch_idx = brg_batch_start + b_iter;
            const int k = (k_blk_idx + brg_batch_idx) * bgmmc_.K_blk;
            addr_batch[b_iter].ptr.A = bgmmc_.use_buffer_a
                    ? get_buf_A_ptr(ithr, m_blk_idx, brg_batch_idx)
                    : get_data_A_ptr(b_idx, m, k);
            addr_batch[b_iter].ptr.B = (bgmmc_.use_buffer_b)
                    ? get_buf_B_ptr(ithr, brg_batch_idx, n_blk_idx)
                    : get_data_B_ptr(b_idx, k, n);
        }
    }

    char *get_buf_A_ptr(int ithr, int m_blk_idx, int k_blk_idx) const {
        if (!bgmmc_.use_buffer_a && !bgmmc_.use_buffer_a_tail_only)
            return nullptr;

        const int k_blk_local = bgmmc_.use_buffer_a_tail_only ? 0 : k_blk_idx;
        const int m_blk_local = m_blk_idx % bgmmc_.M_chunk_size;
        return buf_A_ptr_ + ithr * bgmmc_.buffer_a_per_thread_sz
                + m_blk_local * bgmmc_.buffer_a_chunk_shift_along_m
                + k_blk_local * bgmmc_.buffer_a_chunk_sz;
    }

    char *get_buf_B_ptr(int ithr, int k_blk_idx, int n_blk_idx) const {
        UNUSED(n_blk_idx);
        if (!bgmmc_.use_buffer_b) return nullptr;

        return buf_B_ptr_ + ithr * bgmmc_.buffer_b_per_thread_sz
                + k_blk_idx * bgmmc_.buffer_b_chunk_sz;
    }

    char *get_buf_C_ptr(int ithr, int m_blk_idx, int n_blk_idx) const {
        if (!bgmmc_.use_buffer_c) return nullptr;

        if (bgmmc_.nthr_k > 1) {
            const int nthr_k = bgmmc_.nthr_k <= nthr_ ? bgmmc_.nthr_k : 1;
            const int nthr_bmn = nthr_ / nthr_k;
            const int ithr_k = ithr / nthr_bmn;
            return get_buf_C_par_reduction_ptr(ithr_k, m_blk_idx, n_blk_idx);
        }

        const int m_blk_local = m_blk_idx % bgmmc_.M_chunk_size;
        const int n_blk_local = n_blk_idx % bgmmc_.N_chunk_size;
        const int buf_idx = bgmmc_.N_chunk_size * m_blk_local + n_blk_local;

        return buf_C_ptr_ + ithr * bgmmc_.buffer_c_per_thread_sz
                + buf_idx * bgmmc_.buffer_c_chunk_sz;
    }

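    // Returns the partial-result buffer of a given K-thread. When post-ops
    // are not applicable, thread 0 accumulates directly into the
    // destination, so one less scratch buffer is needed.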
    char *get_buf_C_par_reduction_ptr(
            int ithr_k, int m_blk_idx, int n_blk_idx) const {
        if (bgmmc_.nthr_k <= 1) return nullptr;

        const int m = m_blk_idx * bgmmc_.M_blk;
        const int n = n_blk_idx * bgmmc_.N_blk;

        if (!bgmmc_.post_ops_applicable && ithr_k == 0)
            return get_data_C_ptr(0, m, n);

        int k_buf_idx = ithr_k - (!bgmmc_.post_ops_applicable ? 1 : 0);
        return buf_C_ptr_ + k_buf_idx * bgmmc_.buffer_c_per_thread_sz
                + get_data_C_off(0, m, n) * bgmmc_.acc_dt_sz / bgmmc_.c_dt_sz;
    }

    // Auxiliary functions for getting offsets with pre-calculated memory
    // strides for each tensor, providing a general solution for all possible
    // dimensions without significant overhead
    dim_t get_data_A_off(int b, int m, int k) const {
        return bgmmc_.A_strides[2] * b + bgmmc_.A_strides[1] * m
                + bgmmc_.A_strides[0] * k;
    }
    dim_t get_data_B_off(int b, int k, int n) const {
        int k_idx = bgmmc_.blocked_B ? k / bgmmc_.wei_k_blk : k;
        int n_idx = bgmmc_.blocked_B ? n / bgmmc_.wei_n_blk : n;

        return bgmmc_.B_strides[2] * b + bgmmc_.B_strides[1] * k_idx
                + bgmmc_.B_strides[0] * n_idx
                + get_data_B_off_within_block(k, n);
    }

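    // Offset inside a VNNI-blocked B block: groups of vnni_factor
    // consecutive K elements are packed together for each N position, i.e.
    // offset = ((k' / vnni_factor) * wei_n_blk + n') * vnni_factor
    //         + k' % vnni_factor, with k' and n' taken within the block.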
    dim_t get_data_B_off_within_block(int k, int n) const {
        using namespace format_tag;

        if (!bgmmc_.blocked_B) return 0;

        int x0 = k % bgmmc_.wei_k_blk;
        int x1 = n % bgmmc_.wei_n_blk;
        dim_t offset = (x0 / vnni_factor) * vnni_factor * bgmmc_.wei_n_blk
                + x1 * vnni_factor + x0 % vnni_factor;
        return bgmmc_.b_dt_sz * offset;
    }

    dim_t get_data_C_off(int b, int m, int n) const {
        return bgmmc_.C_strides[2] * b + bgmmc_.C_strides[1] * m
                + bgmmc_.C_strides[0] * n;
    }

    const char *get_bias_ptr(int n) const {
        if (!bgmmc_.with_bias) return nullptr;

        return bias_ptr_ + n * bgmmc_.bias_dt_sz;
    }

    int32_t *get_s8s8_comp_ptr(int ithr, int b, int n_blk_idx) const {
        if (!bgmmc_.s8s8_compensation_required) return nullptr;

        const int n_blk_local = bgmmc_.use_buffer_b
                ? n_blk_idx % bgmmc_.N_chunk_size
                : n_blk_idx;
        return s8s8_compensation_ptr_ + ithr * bgmmc_.s8s8_comp_ithr_str
                + b * bgmmc_.s8s8_comp_b_str
                + n_blk_local * bgmmc_.s8s8_comp_n_str;
    }

    const float *get_oscales_ptr(int n) const {
        return oscales_ptr_ + bgmmc_.is_oscale_per_n * n;
    }

    const int32_t *get_zp_a_neg_val_ptr() const {
        return &zero_point_a_negative_val_;
    }

    const int32_t *get_zp_b_neg_val_ptr() const {
        return &zero_point_b_negative_val_;
    }

    const int32_t *get_zp_ab_mixed_comp_ptr() const {
        return &zero_point_mixed_ab_compensation_component_;
    }

    const int32_t *get_zp_c_val_ptr() const { return &zero_point_c_val_; }

    int32_t *get_zp_a_compensation_ptr(int ithr, int n_blk_idx) const {
        if (!bgmmc_.has_zero_point_a) return nullptr;

        const int n_blk_local = n_blk_idx % bgmmc_.N_chunk_size;
        int32_t *zp_comp = zero_point_a_compensations_ptr_
                + ithr * bgmmc_.zp_a_comp_elems_per_thr
                + n_blk_local * bgmmc_.zp_a_comp_shift_n;

        if (bgmmc_.blocked_B) {
            // Scale the compensation values computed in the reorder by the
            // zp_a value locally, just before usage. Using a single global
            // scaling before the parallel section might produce significant
            // overhead for small problems running in multithreaded execution
            // mode.
            const int base_offset = n_blk_idx * bgmmc_.wei_n_blk;
            PRAGMA_OMP_SIMD()
            for (int b = 0; b < bgmmc_.wei_n_blk; b++)
                zp_comp[b] = -zero_point_a_negative_val_
                        * reorder_zp_a_comp_ptr_[base_offset + b];
        }
        return zp_comp;
    }

    int32_t *get_zp_b_compensation_result_ptr(int ithr, int m_blk_idx) const {
        if (!bgmmc_.has_zero_point_b) return nullptr;

        const int m_blk_local = m_blk_idx % bgmmc_.M_chunk_size;
        return zero_point_b_compensations_ptr_
                + ithr * bgmmc_.zp_b_comp_elems_per_thr
                + m_blk_local * bgmmc_.zp_b_comp_result_shift_m;
    }

    int32_t *get_zp_b_compensation_buffer_ptr(int ithr, int m_blk_idx) const {
        if (!bgmmc_.has_zero_point_b) return nullptr;

        const int m_blk_local = m_blk_idx % bgmmc_.M_chunk_size;
        return get_zp_b_compensation_result_ptr(ithr, 0)
                + bgmmc_.zp_b_comp_buffer_start
                + m_blk_local * bgmmc_.zp_b_comp_buffer_shift_m;
    }

    char *get_tile_workspace(int ithr) const {
        return is_amx_ ? wsp_tile_ptr_ + ithr * bgmmc_.wsp_tile_per_thr_bytes
                       : nullptr;
    }

    const std::vector<const void *> &get_post_ops_binary_rhs_arg_vec() const {
        return post_ops_binary_rhs_arg_vec_;
    }

    int get_base_brgemm_kernel_idx() const { return base_brg_ker_idx_; }

    bool is_last_K_chunk(int k_chunk_idx) const {
        return k_chunk_idx == bgmmc_.K_chunks - 1;
    }

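    // The last K chunk may contain fewer full K blocks than
    // brgemm_batch_size; derive its effective batch size from the remaining
    // K elements (the K tail block, if any, is handled separately in
    // compute_kernel()).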
    int get_brgemm_batch_size(int k_chunk_idx) const {
        const int last_brgemm_batch_size
                = (nstl::max(bgmmc_.K, bgmmc_.K_blk)
                          - k_chunk_idx * bgmmc_.K_chunk_elems)
                / bgmmc_.K_blk;
        return is_last_K_chunk(k_chunk_idx) ? last_brgemm_batch_size
                                            : bgmmc_.brgemm_batch_size;
    }

    int get_parallel_work_amount() const { return parallel_work_amount_; }
    int get_num_threads_for_k() const { return nthr_k_; }
    bool parallel_reduction_is_used() const {
        return nthr_k_ > 1 && bgmmc_.K_chunks > 1;
    }
    int get_num_threads_for_bmn() const { return nthr_bmn_; }
    // ithr = ithr_k * nthr_bmn + ithr_bmn
    int get_thread_idx_for_k(int ithr) const {
        if (ithr >= num_threads_used_) return -1;
        const int ithr_k = ithr / nthr_bmn_;
        return ithr_k < bgmmc_.K_chunks ? ithr_k : -1;
    }
    int get_thread_idx_for_bmn(int ithr) const {
        if (ithr >= num_threads_used_) return -1;
        const int ithr_bmn = ithr % nthr_bmn_;
        return ithr_bmn < parallel_work_amount_ ? ithr_bmn : -1;
    }
    int get_num_threads_for_parallelization() const { return nthr_; }

private:
    bool is_amx_;
    const brgemm_matmul_conf_t &bgmmc_;
    const char *data_A_ptr_;
    const char *data_B_ptr_;
    char *data_C_ptr_;
    brgemm_batch_element_t *batch_element_ptr_;

    char *buf_A_ptr_;
    char *buf_B_ptr_;
    char *buf_C_ptr_;

    char *wsp_tile_ptr_;
    const char *bias_ptr_;
    const float *oscales_ptr_;
    int32_t *s8s8_compensation_ptr_;

    int32_t *zero_point_a_compensations_ptr_;
    int32_t *zero_point_b_compensations_ptr_;
    int32_t *reorder_zp_a_comp_ptr_;

    int32_t zero_point_a_negative_val_;
    int32_t zero_point_b_negative_val_;
    int32_t zero_point_mixed_ab_compensation_component_;
    int32_t zero_point_c_val_;
    std::vector<const void *> post_ops_binary_rhs_arg_vec_;

    int base_brg_ker_idx_;
    int vnni_factor;

    // parallelization parameters
    int parallel_work_amount_;
    int nthr_, nthr_k_, nthr_bmn_, num_threads_used_;
};

template struct brgemm_matmul_t<avx512_core_bf16_amx_int8>;
template struct brgemm_matmul_t<avx512_core_bf16_amx_bf16>;
template struct brgemm_matmul_t<avx512_core_bf16>;
template struct brgemm_matmul_t<avx512_core_vnni>;
template struct brgemm_matmul_t<avx512_core>;

} // namespace matmul
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl