/*******************************************************************************
* Copyright 2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory_tracking.hpp"
#include "common/tag_traits.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_primitive.hpp"

#include "cpu/x64/amx_tile_configure.hpp"
#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
#include "cpu/x64/matmul/brgemm_matmul.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {
namespace matmul {

using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;

using namespace nstl;

using namespace data_type;

template <cpu_isa_t isa>
status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
    const auto src_dt = src_md_.data_type;
    const auto wei_dt = weights_md_.data_type;
    const auto dst_dt = dst_md_.data_type;

    const bool is_f32 = everyone_is(f32, src_dt, wei_dt, dst_dt);
    const bool is_int8 = one_of(src_dt, u8, s8) && wei_dt == s8
            && one_of(dst_dt, u8, s8, s32, f32, bf16);
    const bool is_bf16
            = everyone_is(bf16, src_dt, wei_dt) && one_of(dst_dt, bf16, f32);

    auto check_bias = [&]() -> bool {
        const bool is_bia_dt_correct
                = (is_int8
                          && one_of(weights_md(1)->data_type, f32, s32, s8, u8,
                                  bf16))
                || (is_bf16 && one_of(weights_md(1)->data_type, f32, bf16))
                || (is_f32 && weights_md(1)->data_type == f32);
        return IMPLICATION(with_bias(), is_bia_dt_correct && is_bias_1xN());
    };

    auto check_attr_oscale = [&]() -> bool {
        const auto &oscale = attr()->output_scales_;
        return IMPLICATION(
                oscale.mask_ != 0, oscale.mask_ == (1 << (dst_md_.ndims - 1)));
    };

    auto check_attr_zero_points
            = [&]() -> bool { return attr()->zero_points_.common(); };

    const bool problem_dt_correct = is_int8 || is_bf16 || is_f32;
    bool ok = mayiuse(isa) && problem_dt_correct
            && !has_runtime_dims_or_strides()
            && attr()->has_default_values(primitive_attr_t::skip_mask_t::oscale
                            | primitive_attr_t::skip_mask_t::zero_points_runtime
                            | primitive_attr_t::skip_mask_t::post_ops
                            | primitive_attr_t::skip_mask_t::sum_dt,
                    dst_dt)
            && attr()->post_ops_.check_sum_consistent_dt(dst_dt)
            && check_attr_oscale() && check_attr_zero_points() && check_bias();
    if (!ok) return status::unimplemented;

    CHECK(init_brgemm_matmul_conf(isa, bgmmc_, *desc(), src_md_, weights_md_,
            dst_md_, bias_md_, attr_));

    const float alpha = 1.0;
    const float beta = 1.0;
    const float beta_init = 0.0;
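    // Pre-create BRGeMM descriptors for every combination of
    // {initialize (beta = 0) / accumulate} x {full / tail M block}
    // x {full / tail N block} x {full / tail K block}.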
    for_(int i_init = 0; i_init < 2; i_init++)
    for_(int i_M = 0; i_M < 2; i_M++)
    for_(int i_N = 0; i_N < 2; i_N++)
    for (int i_K = 0; i_K < 2; i_K++) {
        auto vbeta = (i_init) ? beta_init : beta;
        auto vM = (i_M) ? bgmmc_.M_tail : bgmmc_.M_blk;
        auto vN = (i_N) ? bgmmc_.N_tail : bgmmc_.N_blk;
        auto vK = (i_K) ? bgmmc_.K_tail : bgmmc_.K_blk;

        int idx = get_brg_kernel_idx(i_init, i_M, i_N, i_K);
        if (idx < 0) continue;
        brgemm_t &brg = brg_descs_[idx];
        auto LDA = i_K && bgmmc_.use_buffer_a_tail_only
                ? (dim_t)bgmmc_.wei_k_blk
                : bgmmc_.LDA;
        CHECK(brgemm_desc_init(&brg, isa, bgmmc_.brg_type, bgmmc_.src_dt,
                bgmmc_.wei_dt, false, false, brgemm_row_major, alpha, vbeta,
                LDA, bgmmc_.LDB, bgmmc_.LDC, vM, vN, vK));

        auto LDD = bgmmc_.N;
        CHECK(brgemm_desc_set_postops(
                &brg, attr(), &dst_md_, LDD, bgmmc_.bia_dt));

        brgemm_attr_t brgattr;
        constexpr bool is_amx = one_of(
                isa, avx512_core_bf16_amx_int8, avx512_core_bf16_amx_bf16);
        if (is_amx) {
            brgattr.max_bs = bgmmc_.brgemm_batch_size;
            brgattr.wary_tail_read = false;

            // TODO: change expected sizes to local chunks wrt L2 blocking
            brgattr.hint_expected_A_size = vM * vK * bgmmc_.brgemm_batch_size;
            brgattr.hint_expected_B_size = vN * vK * bgmmc_.brgemm_batch_size;
            brgattr.hint_expected_C_size = vM * vN * bgmmc_.brgemm_batch_size;
            brgattr.hint_innermost_loop = brgemm_ld_loop_innermost;
        }

        brgattr.generate_skip_accumulation
                = bgmmc_.post_ops_applicable && bgmmc_.nthr_k > 1;

        CHECK(brgemm_desc_set_attr(&brg, brgattr));
    }

    auto scratchpad = scratchpad_registry().registrar();
    init_scratchpad(scratchpad, bgmmc_);

    return status::success;
}

template <cpu_isa_t isa>
status_t brgemm_matmul_t<isa>::init(engine_t *engine) {
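    // Generate a JIT kernel for every BRGeMM descriptor created at pd
    // creation time; on AMX also pre-compute the tile configuration palette
    // for each kernel so it can be loaded cheaply during execution.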
    for_(int i_M = 0; i_M < 2; i_M++)
    for_(int i_N = 0; i_N < 2; i_N++)
    for_(int i_K = 0; i_K < 2; i_K++)
    for (int i_init = 0; i_init < 2; i_init++) {
        int idx = pd()->get_brg_kernel_idx(i_init, i_M, i_N, i_K);
        if (idx < 0) continue;

        brgemm_kernel_t *ker = nullptr;
        CHECK(brgemm_kernel_create(&ker, pd()->get_brg_desc(idx)));
        CHECK(safe_ptr_assign(brg_kernels_[idx], ker));
        if (one_of(isa, avx512_core_bf16_amx_int8, avx512_core_bf16_amx_bf16))
            CHECK(brgemm_init_tiles(
                    pd()->get_brg_desc(idx), &brg_kernel_palettes_[idx][0]));
    }

    const auto &bgmmc = pd()->get_brgemm_matmul_conf();
    if (bgmmc.use_buffer_b)
        CHECK(create_brgemm_matmul_copy_b(copy_B_kernel_, &bgmmc));

    if (bgmmc.use_buffer_a || bgmmc.use_buffer_a_tail_only)
        CHECK(create_brgemm_matmul_copy_a(copy_A_kernel_, &bgmmc));

    if (bgmmc.nthr_k > 1 && bgmmc.acc_dt == f32) {
        CHECK(safe_ptr_assign(
                acc_ker_f32_, new cpu_accumulator_1d_t<data_type::f32>()));
        CHECK(acc_ker_f32_->create_kernel());
    } else if (bgmmc.nthr_k > 1 && bgmmc.acc_dt == s32) {
        CHECK(safe_ptr_assign(
                acc_ker_s32_, new cpu_accumulator_1d_t<data_type::s32>()));
        CHECK(acc_ker_s32_->create_kernel());
    }

    return status::success;
}

template <cpu_isa_t isa>
status_t brgemm_matmul_t<isa>::execute_body(const exec_ctx_t &ctx) const {
    DEFINE_ZERO_POINT_VALUE(src_zero_point, DNNL_ARG_SRC);
    DEFINE_ZERO_POINT_VALUE(wei_zero_point, DNNL_ARG_WEIGHTS);
    DEFINE_ZERO_POINT_VALUE(dst_zero_point, DNNL_ARG_DST);

    brg_matmul_exec_ctx_t brgmm_ctx(
            ctx, pd(), src_zero_point, wei_zero_point, dst_zero_point);

    const auto &bgmmc = pd()->get_brgemm_matmul_conf();
    const bool use_buffer_a
            = bgmmc.use_buffer_a || bgmmc.use_buffer_a_tail_only;
    constexpr bool is_amx
            = one_of(isa, avx512_core_bf16_amx_int8, avx512_core_bf16_amx_bf16);
    const int num_threads = brgmm_ctx.get_num_threads_for_parallelization();

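    // Threads form a 2D grid: nthr_bmn threads share the
    // batch x M_chunks x N_chunks work and, when parallel K reduction is
    // used, nthr_k threads additionally split the K chunks and reduce their
    // partial results afterwards.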
    parallel(num_threads, [&](const int ithr, const int nthr) {
        const int ithr_bmn = brgmm_ctx.get_thread_idx_for_bmn(ithr);
        const int ithr_k = brgmm_ctx.get_thread_idx_for_k(ithr);
        if (ithr_bmn < 0 || ithr_k < 0) return;
        int start {0}, end {0};
        balance211(brgmm_ctx.get_parallel_work_amount(),
                brgmm_ctx.get_num_threads_for_bmn(), ithr_bmn, start, end);
        int kc_start {0}, kc_end {bgmmc.K_chunks};
        if (brgmm_ctx.parallel_reduction_is_used())
            balance211((int)bgmmc.K_chunks, brgmm_ctx.get_num_threads_for_k(),
                    ithr_k, kc_start, kc_end);

        if (is_amx) {
            const auto base_ker_idx = brgmm_ctx.get_base_brgemm_kernel_idx();
            amx_tile_configure(&brg_kernel_palettes_[base_ker_idx][0]);
        }

        int b {0}, mc {0}, nc {0};
        nd_iterator_init(
                start, b, bgmmc.batch, mc, bgmmc.M_chunks, nc, bgmmc.N_chunks);
        while (start < end) {
            auto m_start = mc * bgmmc.M_chunk_size;
            auto m_end = nstl::min(
                    (mc + 1) * bgmmc.M_chunk_size, bgmmc.num_M_blocks);
            auto n_start = nc * bgmmc.N_chunk_size;
            auto n_end = nstl::min(
                    (nc + 1) * bgmmc.N_chunk_size, bgmmc.num_N_blocks);
            for_(int kc = kc_start; kc < kc_end; kc++)
            for (int nb = n_start; nb < n_end; nb++) {
                if (bgmmc.use_buffer_b)
                    copy_b_chunk_in_buffer(brgmm_ctx, ithr, b, nb, kc);
                for (int mb = m_start; mb < m_end; mb++) {
                    if (use_buffer_a && nb == n_start)
                        copy_a_chunk_in_buffer(brgmm_ctx, ithr, b, mb, kc);
                    compute_kernel(
                            brgmm_ctx, ithr, b, mb, nb, kc, kc == kc_start);
                }
            }
            ++start;
            nd_iterator_step(
                    b, bgmmc.batch, mc, bgmmc.M_chunks, nc, bgmmc.N_chunks);
        }
        if (is_amx) { amx_tile_release(); }
    });

    maybe_reduce_partial_results_and_apply_postops(brgmm_ctx);

    return status::success;
}

template <cpu_isa_t isa>
void brgemm_matmul_t<isa>::compute_kernel(
        const brg_matmul_exec_ctx_t &brgmm_ctx, int ithr, int b_idx,
        int m_blk_idx, int n_blk_idx, int k_chunk_idx, bool do_init) const {
    constexpr bool is_amx
            = one_of(isa, avx512_core_bf16_amx_int8, avx512_core_bf16_amx_bf16);
    const auto &bgmmc = pd()->get_brgemm_matmul_conf();
    const auto addr_batch = brgmm_ctx.get_batch_elem_ptr(ithr);
    const int base_brg_ker_idx = brgmm_ctx.get_base_brgemm_kernel_idx();

    const auto wsp_tile = brgmm_ctx.get_tile_workspace(ithr);
    const int m = m_blk_idx * bgmmc.M_blk;
    const int n = n_blk_idx * bgmmc.N_blk;
    const int k_blk_idx = k_chunk_idx * bgmmc.brgemm_batch_size;

    const bool is_M_tail = (bgmmc.M - m < bgmmc.M_blk);
    const bool is_N_tail = (bgmmc.N - n < bgmmc.N_blk);
    const bool is_last_K_chunk = brgmm_ctx.is_last_K_chunk(k_chunk_idx);
    const bool is_K_tail = is_last_K_chunk && bgmmc.K_tail > 0;

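    // Full K blocks of this chunk are processed by the main kernel selected
    // below; a remaining K tail (if any) is handled by a dedicated tail
    // kernel at the end of this function.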
    const int gemm_batch = brgmm_ctx.get_brgemm_batch_size(k_chunk_idx);
    const int brg_ker_idx
            = pd()->get_brg_kernel_idx(do_init, is_M_tail, is_N_tail, false);
    const auto brg_kernel = brg_kernels_[brg_ker_idx].get();
    const auto ptr_bias = brgmm_ctx.get_bias_ptr(n);
    auto ptr_D = brgmm_ctx.get_data_C_ptr(b_idx, m, n);
    auto ptr_C = (bgmmc.use_buffer_c)
            ? brgmm_ctx.get_buf_C_ptr(ithr, m_blk_idx, n_blk_idx)
            : ptr_D;

    const auto zp_comp_a = brgmm_ctx.get_zp_a_compensation_ptr(ithr, n_blk_idx);
    const auto zp_comp_b
            = brgmm_ctx.get_zp_b_compensation_result_ptr(ithr, m_blk_idx);
    const auto zp_c_val_ptr = brgmm_ctx.get_zp_c_val_ptr();
    const auto &post_ops_binary_rhs_arg_vec
            = brgmm_ctx.get_post_ops_binary_rhs_arg_vec();
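    // Post-ops can be fused into the BRGeMM call only when no cross-thread
    // reduction over K is needed; otherwise they are applied later in
    // maybe_reduce_partial_results_and_apply_postops().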
    const bool post_ops_applicable = bgmmc.post_ops_applicable
            && (bgmmc.nthr_k <= 1 || bgmmc.K_chunks == 1);

    if (gemm_batch > 0 && brg_kernel != nullptr) {
        const bool is_tile_reconf_required = is_amx && (is_M_tail || is_N_tail);
        if (is_tile_reconf_required)
            amx_tile_configure(&brg_kernel_palettes_[brg_ker_idx][0]);

        brgmm_ctx.init_brgemm_batch_elements_values(
                ithr, 0, gemm_batch, b_idx, m_blk_idx, k_blk_idx, n_blk_idx);

        if (post_ops_applicable && is_last_K_chunk && !is_K_tail) {
            void *scratch = is_amx
                    ? static_cast<void *>(wsp_tile)
                    : static_cast<void *>(brgmm_ctx.get_s8s8_comp_ptr(
                            ithr, b_idx, n_blk_idx));

            const size_t dst_row_logical_off = m_blk_idx * bgmmc.M_blk;
            const size_t batch_first_dim_idx = bgmmc.batch_ndims > 1
                    ? b_idx / bgmmc.batch_without_first_dim
                    : 0;
            const size_t first_mb_matrix_addr_off
                    = batch_first_dim_idx * (bgmmc.M * bgmmc.N)
                    + (m * bgmmc.N + n);
            const brgemm_post_ops_data_t post_ops_data {
                    static_cast<const void *>(ptr_bias),
                    brgmm_ctx.get_oscales_ptr(n),
                    post_ops_binary_rhs_arg_vec.data(), static_cast<size_t>(n),
                    dst_row_logical_off, brgmm_ctx.get_data_C_ptr(0, 0, 0),
                    first_mb_matrix_addr_off,
                    static_cast<const void *>(zp_comp_a),
                    static_cast<const void *>(zp_comp_b),
                    static_cast<const void *>(zp_c_val_ptr)};

            brgemm_kernel_execute_postops(brg_kernel, gemm_batch, addr_batch,
                    (void *)ptr_C, (void *)ptr_D, post_ops_data, scratch);
        } else {
            brgemm_kernel_execute(brg_kernel, gemm_batch, addr_batch,
                    (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr);
        }

        if (is_tile_reconf_required)
            amx_tile_configure(&brg_kernel_palettes_[base_brg_ker_idx][0]);
    }
    if (is_K_tail) {
        brgmm_ctx.init_brgemm_batch_elements_values(
                ithr, gemm_batch, 1, b_idx, m_blk_idx, k_blk_idx, n_blk_idx);

        const bool use_init_ker = (do_init && gemm_batch == 0);
        const int brg_ker_idx = pd()->get_brg_kernel_idx(
                use_init_ker, is_M_tail, is_N_tail, true);
        const auto brg_kernel_k_tail = brg_kernels_[brg_ker_idx].get();
        const bool is_tile_reconf_required
                = is_amx && bgmmc.K_tail != bgmmc.K_blk;
        if (is_tile_reconf_required)
            amx_tile_configure(&brg_kernel_palettes_[brg_ker_idx][0]);
        if (post_ops_applicable) {
            void *scratch = is_amx
                    ? static_cast<void *>(wsp_tile)
                    : static_cast<void *>(brgmm_ctx.get_s8s8_comp_ptr(
                            ithr, b_idx, n_blk_idx));

            const size_t dst_row_logical_off = m_blk_idx * bgmmc.M_blk;
            const size_t batch_first_dim_idx = bgmmc.batch_ndims > 1
                    ? b_idx / bgmmc.batch_without_first_dim
                    : 0;
            const size_t first_mb_matrix_addr_off
                    = batch_first_dim_idx * (bgmmc.M * bgmmc.N)
                    + (m * bgmmc.N + n);
            const brgemm_post_ops_data_t post_ops_data {
                    static_cast<const void *>(ptr_bias),
                    brgmm_ctx.get_oscales_ptr(n),
                    post_ops_binary_rhs_arg_vec.data(), static_cast<size_t>(n),
                    dst_row_logical_off, brgmm_ctx.get_data_C_ptr(0, 0, 0),
                    first_mb_matrix_addr_off,
                    static_cast<const void *>(zp_comp_a),
                    static_cast<const void *>(zp_comp_b),
                    static_cast<const void *>(zp_c_val_ptr)};

            brgemm_kernel_execute_postops(brg_kernel_k_tail, 1, addr_batch,
                    (void *)ptr_C, (void *)ptr_D, post_ops_data, scratch);
        } else {
            brgemm_kernel_execute(brg_kernel_k_tail, 1, addr_batch,
                    (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr);
        }
        if (is_tile_reconf_required)
            amx_tile_configure(&brg_kernel_palettes_[base_brg_ker_idx][0]);
    }
}

template <cpu_isa_t isa>
void brgemm_matmul_t<isa>::maybe_reduce_partial_results_and_apply_postops(
        const brg_matmul_exec_ctx_t &brgmm_ctx) const {
    if (!brgmm_ctx.parallel_reduction_is_used()) return;

    const auto &bgmmc = pd()->get_brgemm_matmul_conf();
    const int num_threads = brgmm_ctx.get_num_threads_for_parallelization();

    parallel(num_threads, [&](const int ithr, const int nthr) {
        const int nthr_k = brgmm_ctx.get_num_threads_for_k();
        const int ithr_bmn = brgmm_ctx.get_thread_idx_for_bmn(ithr);
        const int ithr_k = brgmm_ctx.get_thread_idx_for_k(ithr);
        if (ithr_bmn < 0 || ithr_k < 0) return;

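        // One accumulation buffer exists per K-thread that produced a partial
        // result; all of them are reduced into buffer 0, which then feeds the
        // optional post-ops pass below.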
        const int num_reduction_buffers = nstl::min(nthr_k, bgmmc.K_chunks);

        int bmn_start {0}, bmn_end {0};
        int start {0}, end {0};
        balance211(brgmm_ctx.get_parallel_work_amount(),
                brgmm_ctx.get_num_threads_for_bmn(), ithr_bmn, bmn_start,
                bmn_end);
        balance211(bmn_end - bmn_start, nthr_k, ithr_k, start, end);

        int b {0}, mc {0}, nc {0};

        assert(bgmmc.batch == 1);
        nd_iterator_init(bmn_start + start, b, bgmmc.batch, mc, bgmmc.M_chunks,
                nc, bgmmc.N_chunks);
        while (start < end) {
            auto mb_start = mc * bgmmc.M_chunk_size;
            auto mb_end = nstl::min(
                    (mc + 1) * bgmmc.M_chunk_size, bgmmc.num_M_blocks);
            auto nb_start = nc * bgmmc.N_chunk_size;
            auto nb_end = nstl::min(
                    (nc + 1) * bgmmc.N_chunk_size, bgmmc.num_N_blocks);
            for (int mb = mb_start; mb < mb_end; mb++) {
                const int curr_M_blk
                        = nstl::min(bgmmc.M - mb * bgmmc.M_blk, bgmmc.M_blk);
                const bool is_M_tail = curr_M_blk < bgmmc.M_blk;
                const int curr_N_chunk_size
                        = nstl::min(bgmmc.N, nb_end * bgmmc.N_blk)
                        - nb_start * bgmmc.N_blk;
                char *buf_reduced_base = brgmm_ctx.get_buf_C_par_reduction_ptr(
                        0, mb, nb_start);
                const size_t m_offset = bgmmc.LDC * bgmmc.acc_dt_sz;
                for (int r = 1; r < num_reduction_buffers; r++) {
                    const char *buf_to_reduce_base
                            = brgmm_ctx.get_buf_C_par_reduction_ptr(
                                    r, mb, nb_start);
                    for (int m = 0; m < curr_M_blk; m++) {
                        accumulate(buf_reduced_base + m * m_offset,
                                buf_to_reduce_base + m * m_offset,
                                curr_N_chunk_size);
                    }
                }
                if (bgmmc.post_ops_applicable) {
                    for (int nb = nb_start; nb < nb_end; nb++) {
                        const bool is_N_tail
                                = (bgmmc.N - nb * bgmmc.N_blk < bgmmc.N_blk);
                        const int brg_ker_idx = pd()->get_brg_kernel_idx(
                                false, is_M_tail, is_N_tail, false);
                        const auto brg_kernel = brg_kernels_[brg_ker_idx].get();
                        const int m = mb * bgmmc.M_blk;
                        const int n = nb * bgmmc.N_blk;
                        const auto ptr_bias = brgmm_ctx.get_bias_ptr(n);
                        auto ptr_D = brgmm_ctx.get_data_C_ptr(b, m, n);
                        auto ptr_C = brgmm_ctx.get_buf_C_par_reduction_ptr(
                                0, mb, nb);

                        // TODO: support reduction for zp/s8s8 compensations
                        // computed in copy routines
                        const auto zp_comp_a
                                = brgmm_ctx.get_zp_a_compensation_ptr(ithr, nb);
                        const auto zp_comp_b
                                = brgmm_ctx.get_zp_b_compensation_result_ptr(
                                        ithr, mb);
                        const auto zp_c_val_ptr = brgmm_ctx.get_zp_c_val_ptr();
                        const auto &post_ops_binary_rhs_arg_vec
                                = brgmm_ctx.get_post_ops_binary_rhs_arg_vec();

                        const size_t dst_row_logical_off = mb * bgmmc.M_blk;
                        const size_t batch_first_dim_idx = bgmmc.batch_ndims > 1
                                ? b / bgmmc.batch_without_first_dim
                                : 0;
                        const size_t first_mb_matrix_addr_off
                                = batch_first_dim_idx * (bgmmc.M * bgmmc.N)
                                + (m * bgmmc.N + n);
                        // apply post-ops and convert to dst data type only
                        constexpr bool skip_accumulation = true;
                        const brgemm_post_ops_data_t post_ops_data {
                                static_cast<const void *>(ptr_bias),
                                brgmm_ctx.get_oscales_ptr(n),
                                post_ops_binary_rhs_arg_vec.data(),
                                static_cast<size_t>(n), dst_row_logical_off,
                                brgmm_ctx.get_data_C_ptr(0, 0, 0),
                                first_mb_matrix_addr_off,
                                static_cast<const void *>(zp_comp_a),
                                static_cast<const void *>(zp_comp_b),
                                static_cast<const void *>(zp_c_val_ptr),
                                skip_accumulation};

                        brgemm_kernel_execute_postops(brg_kernel, 0, nullptr,
                                (void *)ptr_C, (void *)ptr_D, post_ops_data,
                                nullptr);
                    }
                }
            }
            ++start;
            nd_iterator_step(
                    b, bgmmc.batch, mc, bgmmc.M_chunks, nc, bgmmc.N_chunks);
        }
    });
}

template <cpu_isa_t isa>
void brgemm_matmul_t<isa>::copy_a_chunk_in_buffer(
        const brg_matmul_exec_ctx_t &brgmm_ctx, int ithr, int b_idx,
        int m_blk_idx, int k_chunk_idx) const {
    const auto &bgmmc = pd()->get_brgemm_matmul_conf();

    auto ctx = jit_brgemm_matmul_copy_a_t::ctx_t();
    const int k_start = k_chunk_idx * bgmmc.K_chunk_elems;
    const bool is_K_tail
            = brgmm_ctx.is_last_K_chunk(k_chunk_idx) && bgmmc.K_tail > 0;
    const int gemm_batch = brgmm_ctx.get_brgemm_batch_size(k_chunk_idx);
    const int gemm_batch_iters = bgmmc.use_buffer_a_tail_only ? 0 : gemm_batch;
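    // In tail-only mode the full K blocks are read directly from the source,
    // so only the K tail (copied below) goes through the copy kernel.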

    const int m = m_blk_idx * bgmmc.M_blk;
    const bool is_M_tail = (bgmmc.M - m < bgmmc.M_blk);
    ctx.current_M_blk = is_M_tail ? bgmmc.M_tail : bgmmc.M_blk;
    ctx.zp_b_compensation_buffer_ptr
            = (void *)brgmm_ctx.get_zp_b_compensation_buffer_ptr(
                    ithr, m_blk_idx);
    ctx.zp_a_compensation_result_ptr
            = (void *)brgmm_ctx.get_zp_b_compensation_result_ptr(
                    ithr, m_blk_idx);
    ctx.zp_b_neg_value_ptr = (void *)brgmm_ctx.get_zp_b_neg_val_ptr();
    ctx.zp_ab_comp_ptr = (void *)brgmm_ctx.get_zp_ab_mixed_comp_ptr();

    for (int gb = 0; gb < gemm_batch_iters; gb++) {
        const int k = k_start + gb * bgmmc.K_blk;
        ctx.src = (void *)brgmm_ctx.get_data_A_ptr(b_idx, m, k);
        ctx.tr_src = (void *)brgmm_ctx.get_buf_A_ptr(ithr, m_blk_idx, gb);
        ctx.current_K_blk = nstl::min(bgmmc.K_blk, bgmmc.K);
        ctx.current_K_start = k;

        (*copy_A_kernel_)(&ctx);
    }
    if (is_K_tail) {
        const auto K_tail = bgmmc.K % bgmmc.K_blk;
        const int k = k_start + gemm_batch * bgmmc.K_blk;
        ctx.src = (void *)brgmm_ctx.get_data_A_ptr(b_idx, m, k);
        ctx.tr_src = (void *)brgmm_ctx.get_buf_A_ptr(
                ithr, m_blk_idx, gemm_batch_iters);
        ctx.current_K_blk = K_tail;
        ctx.current_K_start = k;

        (*copy_A_kernel_)(&ctx);
    }
}

template <cpu_isa_t isa>
void brgemm_matmul_t<isa>::copy_b_chunk_in_buffer(
        const brg_matmul_exec_ctx_t &brgmm_ctx, int ithr, int b_idx,
        int n_blk_idx, int k_chunk_idx) const {
    const auto &bgmmc = pd()->get_brgemm_matmul_conf();

    const int k_start = k_chunk_idx * bgmmc.K_chunk_elems;
    const bool is_K_tail
            = brgmm_ctx.is_last_K_chunk(k_chunk_idx) && bgmmc.K_tail > 0;
    const int gemm_batch = brgmm_ctx.get_brgemm_batch_size(k_chunk_idx);
    auto ctx = jit_brgemm_matmul_copy_b_t::ctx_t();

    const int n = n_blk_idx * bgmmc.N_blk;
    const bool is_N_tail = (bgmmc.N - n < bgmmc.N_blk);
    ctx.current_N_blk = is_N_tail ? bgmmc.N_tail : bgmmc.N_blk;
    ctx.zp_a_compensation_ptr
            = (void *)brgmm_ctx.get_zp_a_compensation_ptr(ithr, n_blk_idx);
    ctx.zp_a_neg_value_ptr = (void *)brgmm_ctx.get_zp_a_neg_val_ptr();

    int gb = 0;
    for (; gb < gemm_batch; gb++) {
        const int k = k_start + gb * bgmmc.K_blk;
        ctx.src = (void *)brgmm_ctx.get_data_B_ptr(b_idx, k, n);
        ctx.tr_src = (void *)brgmm_ctx.get_buf_B_ptr(ithr, gb, n_blk_idx);
        ctx.compensation_ptr
                = (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx);
        ctx.current_K_start = k;
        ctx.current_K_iters = nstl::min(bgmmc.K_blk, bgmmc.K);

        (*copy_B_kernel_)(&ctx);
    }

    if (is_K_tail) {
        const int k = k_start + gb * bgmmc.K_blk;
        ctx.src = (void *)brgmm_ctx.get_data_B_ptr(b_idx, k, n);
        ctx.tr_src = (void *)brgmm_ctx.get_buf_B_ptr(ithr, gb, n_blk_idx);
        ctx.compensation_ptr
                = (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx);
        ctx.current_K_start = k;
        ctx.current_K_iters = bgmmc.K % bgmmc.K_blk;

        (*copy_B_kernel_)(&ctx);
    }
}

template <cpu_isa_t isa>
void brgemm_matmul_t<isa>::accumulate(
        char *result_ptr, const char *reduce_ptr, size_t size) const {
    if (pd()->get_brgemm_matmul_conf().acc_dt == f32)
        acc_ker_f32_->accumulate(
                (float *)result_ptr, (const float *)reduce_ptr, size);
    else if (pd()->get_brgemm_matmul_conf().acc_dt == s32)
        acc_ker_s32_->accumulate(
                (int32_t *)result_ptr, (const int32_t *)reduce_ptr, size);
    else
        assert(!"unsupported accumulation data type");
}

template <cpu_isa_t isa>
struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
    brg_matmul_exec_ctx_t(const exec_ctx_t &ctx, const pd_t *pd, int32_t src_zp,
            int32_t wei_zp, int32_t dst_zp)
        : bgmmc_(pd->get_brgemm_matmul_conf()) {

        data_A_ptr_ = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
        data_B_ptr_ = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
        data_C_ptr_ = CTX_OUT_MEM(char *, DNNL_ARG_DST);

        bias_ptr_ = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
        oscales_ptr_ = pd->attr()->output_scales_.scales_;
        memory_tracking::grantor_t scratchpad = ctx.get_scratchpad_grantor();
        const auto &bgmmc = pd->get_brgemm_matmul_conf();

        batch_element_ptr_ = scratchpad.template get<brgemm_batch_element_t>(
                key_brgemm_primitive_batch);

        const bool use_buffer_a
                = bgmmc.use_buffer_a || bgmmc.use_buffer_a_tail_only;
        buf_A_ptr_ = (use_buffer_a)
                ? scratchpad.template get<char>(key_brgemm_primitive_buffer_a)
                : nullptr;

        buf_B_ptr_ = (bgmmc.use_buffer_b)
                ? scratchpad.template get<char>(key_brgemm_primitive_buffer_b)
                : nullptr;

        buf_C_ptr_ = (bgmmc.use_buffer_c)
                ? scratchpad.template get<char>(key_brgemm_primitive_buffer)
                : nullptr;

        is_amx_ = one_of(
                isa, avx512_core_bf16_amx_int8, avx512_core_bf16_amx_bf16);
        wsp_tile_ptr_ = is_amx_
                ? ctx.get_scratchpad_grantor().template get<char>(
                        key_conv_amx_tile_buffer)
                : nullptr;

        const memory_desc_wrapper weights_d(pd->weights_md(0));
        const dim_t comp_offset = bgmmc_.b_dt_sz
                * (weights_d.size() - weights_d.additional_buffer_size());
        s8s8_compensation_ptr_ = (bgmmc.s8s8_compensation_required)
                ? ((bgmmc.use_buffer_b)
                                ? scratchpad.template get<int32_t>(
                                        key_brgemm_primitive_buffer_comp)
                                : const_cast<int32_t *>(
                                        reinterpret_cast<const int32_t *>(
                                                &data_B_ptr_[comp_offset])))
                : nullptr;

        zero_point_a_compensations_ptr_ = bgmmc.has_zero_point_a
                ? scratchpad.template get<int32_t>(
                        key_brgemm_primitive_zp_comp_a)
                : nullptr;
        zero_point_b_compensations_ptr_ = bgmmc.has_zero_point_b
                ? scratchpad.template get<int32_t>(
                        key_brgemm_primitive_zp_comp_b)
                : nullptr;

        zero_point_a_negative_val_ = -src_zp;
        zero_point_b_negative_val_ = -wei_zp;
        zero_point_mixed_ab_compensation_component_
                = bgmmc.K * zero_point_a_negative_val_;

        zero_point_c_val_ = dst_zp;

        post_ops_binary_rhs_arg_vec_ = binary_injector::prepare_binary_args(
                pd->attr()->post_ops_, ctx);
        base_brg_ker_idx_ = pd->get_brg_kernel_idx(true, false, false, false);
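        // vnni_factor is the number of K elements packed together in the
        // VNNI layout of blocked B: 4 for int8 (AMX), 2 for bf16 (AMX),
        // 1 otherwise.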
        vnni_factor = isa == avx512_core_bf16_amx_int8
                ? 4
                : isa == avx512_core_bf16_amx_bf16 ? 2 : 1;

        reorder_zp_a_comp_ptr_ = nullptr;
        if (bgmmc_.has_zero_point_a && bgmmc_.blocked_B) {
            // Store the pointer to compensation values computed in reorder to
            // scale them locally by the zp_a value just before usage in
            // post-ops. Using a single global scaling before the parallel
            // section might produce significant overhead for small problems
            // running in multithreaded execution mode.
            const size_t reorder_zp_a_comp_offset
                    = weights_d.size() - weights_d.additional_buffer_size();
            const size_t s8s8_buffer_sz = bgmmc.s8s8_compensation_required
                    ? bgmmc.s8s8_comp_b_str * sizeof(int32_t)
                    : 0;
            reorder_zp_a_comp_ptr_
                    = const_cast<int32_t *>(reinterpret_cast<const int32_t *>(
                            &data_B_ptr_[reorder_zp_a_comp_offset
                                    + s8s8_buffer_sz]));
        }

        // parallelization
        parallel_work_amount_ = bgmmc.batch * bgmmc.M_chunks * bgmmc.N_chunks;

        // The number of threads available during primitive execution may
        // increase (e.g. Eigen threadpool implementation) or decrease
        // (e.g. nested parallelism) compared to the number of threads
        // available during primitive creation. So we limit the total number
        // of threads to the minimum of these two values to prevent potential
        // OOM issues.
        nthr_ = nstl::min(dnnl_get_current_num_threads(), bgmmc.nthr);

        nthr_k_ = bgmmc.nthr_k > 0 && bgmmc.nthr_k <= nthr_ ? bgmmc.nthr_k : 1;
        nthr_bmn_ = nthr_ / nthr_k_;
        num_threads_used_ = nthr_k_ * nthr_bmn_;

        // If parallel_work_amount_ == 1 and parallel reduction is not used, we
        // limit the number of threads to 1 as parallel(1, ...) does not create
        // a parallel section at all. We do not limit the number of threads for
        // 1 < parallel_work_amount_ < dnnl_get_max_threads() to avoid
        // potential overhead on spawning a different number of OMP threads
        // from layer to layer.
        if (parallel_work_amount_ == 1 && !parallel_reduction_is_used())
            nthr_ = nthr_bmn_ = nthr_k_ = 1;

        const bool need_to_calculate_compensation_for_a
                = bgmmc.has_zero_point_b;
        const bool need_to_calculate_compensation_for_b = !IMPLICATION(
                (bgmmc.has_zero_point_a || bgmmc.s8s8_compensation_required),
                bgmmc.blocked_B);
        const bool calculate_compensations_in_copy_routines
                = need_to_calculate_compensation_for_a
                || need_to_calculate_compensation_for_b;
        // currently parallel reduction is supported only for the case of
        // non-batched problems without computation of any compensations in
        // copy routines
        assert(IMPLICATION(parallel_reduction_is_used(),
                bgmmc.batch == 1 && !calculate_compensations_in_copy_routines));
        MAYBE_UNUSED(need_to_calculate_compensation_for_a);
        MAYBE_UNUSED(need_to_calculate_compensation_for_b);
        MAYBE_UNUSED(calculate_compensations_in_copy_routines);
    }

    // NOTE: gb --> generalized batch, bb --> broadcast batch
    int get_bb_idx(int gb_idx, const brgemm_matmul_bcast_desc_t &bd) const {
        if (!bd.bcast_mask) // no broadcast
            return gb_idx;

        int gb_off_before_bcast = utils::rnd_dn(
                gb_idx, bd.first_bcast_dim_to_last_batch_dim_prod);
        int bb_idx = gb_off_before_bcast / (bd.bcast_dims_prod);

        dim_t cur_bcast_dims_prod = bd.bcast_dims_prod;
        int mask = 1 << (bgmmc_.batch_ndims - bd.first_bcast_dim - 1);
        for (int d = bd.first_bcast_dim; d < bd.last_bcast_dim; ++d) {
            if (bd.bcast_mask & mask) // broadcast
                cur_bcast_dims_prod /= bd.batch_dims[d];
            else {
                int cur_b = (gb_idx / bd.gb_off[d]) % bd.batch_dims[d];
                bb_idx += cur_b * (bd.gb_off[d] / cur_bcast_dims_prod);
            }
            mask >>= 1;
        }
        bb_idx += gb_idx % bd.gb_off[bd.last_bcast_dim];
        return bb_idx;
    }

    const char *get_data_A_ptr(int b, int m, int k) const {
        int cur_b = get_bb_idx(b, bgmmc_.bcast_A_desc);
        return data_A_ptr_ + get_data_A_off(cur_b, m, k);
    }

    const char *get_data_B_ptr(int b, int k, int n) const {
        int cur_b = get_bb_idx(b, bgmmc_.bcast_B_desc);
        return data_B_ptr_ + get_data_B_off(cur_b, k, n);
    }

    char *get_data_C_ptr(int b, int m, int n) const {
        return data_C_ptr_ + get_data_C_off(b, m, n);
    }

    brgemm_batch_element_t *get_batch_elem_ptr(int ithr) const {
        return batch_element_ptr_
                + ithr * bgmmc_.brgemm_batch_element_per_thr_sz;
    }

    void init_brgemm_batch_elements_values(int ithr, int brg_batch_start,
            int brg_batch_iters, int b_idx, int m_blk_idx, int k_blk_idx,
            int n_blk_idx) const {
        auto addr_batch = get_batch_elem_ptr(ithr);

        const int m = m_blk_idx * bgmmc_.M_blk;
        const int n = n_blk_idx * bgmmc_.N_blk;

        for (int b_iter = 0; b_iter < brg_batch_iters; b_iter++) {
            const int brg_batch_idx = brg_batch_start + b_iter;
            const int k = (k_blk_idx + brg_batch_idx) * bgmmc_.K_blk;
            addr_batch[b_iter].ptr.A = bgmmc_.use_buffer_a
                    ? get_buf_A_ptr(ithr, m_blk_idx, brg_batch_idx)
                    : get_data_A_ptr(b_idx, m, k);
            addr_batch[b_iter].ptr.B = (bgmmc_.use_buffer_b)
                    ? get_buf_B_ptr(ithr, brg_batch_idx, n_blk_idx)
                    : get_data_B_ptr(b_idx, k, n);
        }
    }

    char *get_buf_A_ptr(int ithr, int m_blk_idx, int k_blk_idx) const {
        if (!bgmmc_.use_buffer_a && !bgmmc_.use_buffer_a_tail_only)
            return nullptr;

        const int k_blk_local = bgmmc_.use_buffer_a_tail_only ? 0 : k_blk_idx;
        const int m_blk_local = m_blk_idx % bgmmc_.M_chunk_size;
        return buf_A_ptr_ + ithr * bgmmc_.buffer_a_per_thread_sz
                + m_blk_local * bgmmc_.buffer_a_chunk_shift_along_m
                + k_blk_local * bgmmc_.buffer_a_chunk_sz;
    }

    char *get_buf_B_ptr(int ithr, int k_blk_idx, int n_blk_idx) const {
        UNUSED(n_blk_idx);
        if (!bgmmc_.use_buffer_b) return nullptr;

        return buf_B_ptr_ + ithr * bgmmc_.buffer_b_per_thread_sz
                + k_blk_idx * bgmmc_.buffer_b_chunk_sz;
    }

    char *get_buf_C_ptr(int ithr, int m_blk_idx, int n_blk_idx) const {
        if (!bgmmc_.use_buffer_c) return nullptr;

        if (bgmmc_.nthr_k > 1) {
            const int nthr_k = bgmmc_.nthr_k <= nthr_ ? bgmmc_.nthr_k : 1;
            const int nthr_bmn = nthr_ / nthr_k;
            const int ithr_k = ithr / nthr_bmn;
            return get_buf_C_par_reduction_ptr(ithr_k, m_blk_idx, n_blk_idx);
        }

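        // The per-thread C buffer holds an M_chunk_size x N_chunk_size grid
        // of blocks stored row-major; compute this block's slot within it.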
        const int m_blk_local = m_blk_idx % bgmmc_.M_chunk_size;
        const int n_blk_local = n_blk_idx % bgmmc_.N_chunk_size;
        const int buf_idx = bgmmc_.N_chunk_size * m_blk_local + n_blk_local;

        return buf_C_ptr_ + ithr * bgmmc_.buffer_c_per_thread_sz
                + buf_idx * bgmmc_.buffer_c_chunk_sz;
    }

    char *get_buf_C_par_reduction_ptr(
            int ithr_k, int m_blk_idx, int n_blk_idx) const {
        if (bgmmc_.nthr_k <= 1) return nullptr;

        const int m = m_blk_idx * bgmmc_.M_blk;
        const int n = n_blk_idx * bgmmc_.N_blk;

        if (!bgmmc_.post_ops_applicable && ithr_k == 0)
            return get_data_C_ptr(0, m, n);

        int k_buf_idx = ithr_k - (!bgmmc_.post_ops_applicable ? 1 : 0);
        return buf_C_ptr_ + k_buf_idx * bgmmc_.buffer_c_per_thread_sz
                + get_data_C_off(0, m, n) * bgmmc_.acc_dt_sz / bgmmc_.c_dt_sz;
    }

    // Auxiliary functions for getting offsets with pre-calculated memory
    // strides for each tensor to get a general solution for all possible
    // dimensions without significant overhead
    dim_t get_data_A_off(int b, int m, int k) const {
        return bgmmc_.A_strides[2] * b + bgmmc_.A_strides[1] * m
                + bgmmc_.A_strides[0] * k;
    }
    dim_t get_data_B_off(int b, int k, int n) const {
        int k_idx = bgmmc_.blocked_B ? k / bgmmc_.wei_k_blk : k;
        int n_idx = bgmmc_.blocked_B ? n / bgmmc_.wei_n_blk : n;

        return bgmmc_.B_strides[2] * b + bgmmc_.B_strides[1] * k_idx
                + bgmmc_.B_strides[0] * n_idx
                + get_data_B_off_within_block(k, n);
    }

    dim_t get_data_B_off_within_block(int k, int n) const {
        using namespace format_tag;

        if (!bgmmc_.blocked_B) return 0;

        int x0 = k % bgmmc_.wei_k_blk;
        int x1 = n % bgmmc_.wei_n_blk;
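        // Within a wei_k_blk x wei_n_blk block, B is stored VNNI-interleaved:
        // groups of vnni_factor consecutive K elements are kept contiguous
        // for each column n.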
        dim_t offset = (x0 / vnni_factor) * vnni_factor * bgmmc_.wei_n_blk
                + x1 * vnni_factor + x0 % vnni_factor;
        return bgmmc_.b_dt_sz * offset;
    }

    dim_t get_data_C_off(int b, int m, int n) const {
        return bgmmc_.C_strides[2] * b + bgmmc_.C_strides[1] * m
                + bgmmc_.C_strides[0] * n;
    }

    const char *get_bias_ptr(int n) const {
        if (!bgmmc_.with_bias) return nullptr;

        return bias_ptr_ + n * bgmmc_.bias_dt_sz;
    }

    int32_t *get_s8s8_comp_ptr(int ithr, int b, int n_blk_idx) const {
        if (!bgmmc_.s8s8_compensation_required) return nullptr;

        const int n_blk_local = bgmmc_.use_buffer_b
                ? n_blk_idx % bgmmc_.N_chunk_size
                : n_blk_idx;
        return s8s8_compensation_ptr_ + ithr * bgmmc_.s8s8_comp_ithr_str
                + b * bgmmc_.s8s8_comp_b_str
                + n_blk_local * bgmmc_.s8s8_comp_n_str;
    }

    const float *get_oscales_ptr(int n) const {
        return oscales_ptr_ + bgmmc_.is_oscale_per_n * n;
    }

    const int32_t *get_zp_a_neg_val_ptr() const {
        return &zero_point_a_negative_val_;
    }

    const int32_t *get_zp_b_neg_val_ptr() const {
        return &zero_point_b_negative_val_;
    }

    const int32_t *get_zp_ab_mixed_comp_ptr() const {
        return &zero_point_mixed_ab_compensation_component_;
    }

    const int32_t *get_zp_c_val_ptr() const { return &zero_point_c_val_; }

    int32_t *get_zp_a_compensation_ptr(int ithr, int n_blk_idx) const {
        if (!bgmmc_.has_zero_point_a) return nullptr;

        const int n_blk_local = n_blk_idx % bgmmc_.N_chunk_size;
        int32_t *zp_comp = zero_point_a_compensations_ptr_
                + ithr * bgmmc_.zp_a_comp_elems_per_thr
                + n_blk_local * bgmmc_.zp_a_comp_shift_n;

        if (bgmmc_.blocked_B) {
            // Scale the compensation values computed in reorder by the zp_a
            // value locally, just before usage. Using a single global scaling
            // before the parallel section might produce significant overhead
            // for small problems running in multithreaded execution mode.
            const int base_offset = n_blk_idx * bgmmc_.wei_n_blk;
            PRAGMA_OMP_SIMD()
            for (int b = 0; b < bgmmc_.wei_n_blk; b++)
                zp_comp[b] = -zero_point_a_negative_val_
                        * reorder_zp_a_comp_ptr_[base_offset + b];
        }
        return zp_comp;
    }

    int32_t *get_zp_b_compensation_result_ptr(int ithr, int m_blk_idx) const {
        if (!bgmmc_.has_zero_point_b) return nullptr;

        const int m_blk_local = m_blk_idx % bgmmc_.M_chunk_size;
        return zero_point_b_compensations_ptr_
                + ithr * bgmmc_.zp_b_comp_elems_per_thr
                + m_blk_local * bgmmc_.zp_b_comp_result_shift_m;
    }

    int32_t *get_zp_b_compensation_buffer_ptr(int ithr, int m_blk_idx) const {
        if (!bgmmc_.has_zero_point_b) return nullptr;

        const int m_blk_local = m_blk_idx % bgmmc_.M_chunk_size;
        return get_zp_b_compensation_result_ptr(ithr, 0)
                + bgmmc_.zp_b_comp_buffer_start
                + m_blk_local * bgmmc_.zp_b_comp_buffer_shift_m;
    }

    char *get_tile_workspace(int ithr) const {
        return is_amx_ ? wsp_tile_ptr_ + ithr * bgmmc_.wsp_tile_per_thr_bytes
                       : nullptr;
    }

    const std::vector<const void *> &get_post_ops_binary_rhs_arg_vec() const {
        return post_ops_binary_rhs_arg_vec_;
    }

    int get_base_brgemm_kernel_idx() const { return base_brg_ker_idx_; }

    bool is_last_K_chunk(int k_chunk_idx) const {
        return k_chunk_idx == bgmmc_.K_chunks - 1;
    }

    int get_brgemm_batch_size(int k_chunk_idx) const {
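        // Every chunk except the last contains brgemm_batch_size full
        // K blocks; the last chunk gets whatever full blocks remain (the
        // K tail itself is handled by a separate tail kernel).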
        const int last_brgemm_batch_size
                = (nstl::max(bgmmc_.K, bgmmc_.K_blk)
                          - k_chunk_idx * bgmmc_.K_chunk_elems)
                / bgmmc_.K_blk;
        return is_last_K_chunk(k_chunk_idx) ? last_brgemm_batch_size
                                            : bgmmc_.brgemm_batch_size;
    }

    int get_parallel_work_amount() const { return parallel_work_amount_; }
    int get_num_threads_for_k() const { return nthr_k_; }
    bool parallel_reduction_is_used() const {
        return nthr_k_ > 1 && bgmmc_.K_chunks > 1;
    }
    int get_num_threads_for_bmn() const { return nthr_bmn_; }
    // ithr = ithr_k * nthr_bmn + ithr_bmn
    int get_thread_idx_for_k(int ithr) const {
        if (ithr >= num_threads_used_) return -1;
        const int ithr_k = ithr / nthr_bmn_;
        return ithr_k < bgmmc_.K_chunks ? ithr_k : -1;
    }
    int get_thread_idx_for_bmn(int ithr) const {
        if (ithr >= num_threads_used_) return -1;
        const int ithr_bmn = ithr % nthr_bmn_;
        return ithr_bmn < parallel_work_amount_ ? ithr_bmn : -1;
    }
    int get_num_threads_for_parallelization() const { return nthr_; }

private:
    bool is_amx_;
    const brgemm_matmul_conf_t &bgmmc_;
    const char *data_A_ptr_;
    const char *data_B_ptr_;
    char *data_C_ptr_;
    brgemm_batch_element_t *batch_element_ptr_;

    char *buf_A_ptr_;
    char *buf_B_ptr_;
    char *buf_C_ptr_;

    char *wsp_tile_ptr_;
    const char *bias_ptr_;
    const float *oscales_ptr_;
    int32_t *s8s8_compensation_ptr_;

    int32_t *zero_point_a_compensations_ptr_;
    int32_t *zero_point_b_compensations_ptr_;
    int32_t *reorder_zp_a_comp_ptr_;

    int32_t zero_point_a_negative_val_;
    int32_t zero_point_b_negative_val_;
    int32_t zero_point_mixed_ab_compensation_component_;
    int32_t zero_point_c_val_;
    std::vector<const void *> post_ops_binary_rhs_arg_vec_;

    int base_brg_ker_idx_;
    int vnni_factor;

    // parallelization parameters
    int parallel_work_amount_;
    int nthr_, nthr_k_, nthr_bmn_, num_threads_used_;
};

template struct brgemm_matmul_t<avx512_core_bf16_amx_int8>;
template struct brgemm_matmul_t<avx512_core_bf16_amx_bf16>;
template struct brgemm_matmul_t<avx512_core_bf16>;
template struct brgemm_matmul_t<avx512_core_vnni>;
template struct brgemm_matmul_t<avx512_core>;

} // namespace matmul
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl