1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "common/c_types_map.hpp"
18 #include "common/dnnl_thread.hpp"
19 #include "common/nstl.hpp"
20 #include "common/type_helpers.hpp"
21 #include "common/utils.hpp"
22
23 #include "cpu/x64/amx_tile_configure.hpp"
24 #include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
25 #include "cpu/x64/jit_brgemm_1x1_conv.hpp"
26
27 namespace dnnl {
28 namespace impl {
29 namespace cpu {
30 namespace x64 {
31
32 using namespace dnnl::impl::status;
33 using namespace dnnl::impl::memory_tracking::names;
34 using namespace dnnl::impl::utils;
35
36 using namespace nstl;
37 using namespace data_type;
38
// Selects one of three values according to the `ndims` variable that must be
// visible at the expansion site: v5 when ndims == 5, v4 when ndims == 4,
// v3 when ndims == 3, and 0 for any other rank.
#define ndims_pick(v5, v4, v3) \
    ((ndims == 5) ? (v5) : ((ndims == 4) ? (v4) : ((ndims == 3) ? (v3) : 0)))
41
42 template <cpu_isa_t isa>
init(engine_t * engine)43 status_t brgemm_1x1_convolution_fwd_t<isa>::pd_t::init(engine_t *engine) {
44 using namespace data_type;
45 using namespace utils;
46
47 const auto src_type = src_md(0)->data_type;
48 const auto wei_type = weights_md(0)->data_type;
49 const auto dst_type = dst_md(0)->data_type;
50
51 using skip_mask_t = primitive_attr_t::skip_mask_t;
52 auto skip_mask = skip_mask_t::post_ops | skip_mask_t::sum_dt;
53 if (one_of(src_type, u8, s8)) skip_mask |= skip_mask_t::oscale;
54
55 bool ok = is_fwd() && set_default_alg_kind(alg_kind::convolution_direct)
56 && expect_data_types(src_type, wei_type, data_type::undef, dst_type,
57 data_type::undef)
58 && IMPLICATION(with_bias(),
59 ((one_of(src_type, u8, s8)
60 && one_of(bias_md_.data_type, f32, s32, s8, u8))
61 || (one_of(src_type, bf16)
62 && one_of(bias_md_.data_type, f32, bf16))
63 || (one_of(src_type, f32)
64 && one_of(bias_md_.data_type, f32))))
65 && attr()->has_default_values(skip_mask, dst_type)
66 && attr()->post_ops_.check_sum_consistent_dt(dst_type)
67 && !has_zero_dim_memory();
68 if (!ok) return status::unimplemented;
69
70 CHECK(brgemm_convolution_utils::init_1x1_conf(jcp_, isa, *desc(), src_md_,
71 weights_md_, dst_md_, bias_md_, attr_, dnnl_get_max_threads()));
72
73 for (int i = 0; i < 16; i++)
74 brgs_[i].bcast_dim = brgs_[i].load_dim = brgs_[i].reduce_dim = 0;
75
76 const float alpha = 1.0;
77 const float beta = 1.0;
78 const auto &p = attr()->post_ops_;
79 const int sum_idx = p.find(primitive_kind::sum);
80 with_sum = (sum_idx != -1);
81 sum_scale = with_sum ? p.entry_[sum_idx].sum.scale : 0.0;
82
83 for_(int i_init = 0; i_init < 2; i_init++)
84 for_(int i_M = 0; i_M < 2; i_M++)
85 for_(int i_N = 0; i_N < 2; i_N++)
86 for (int i_K = 0; i_K < 2; i_K++) {
87 auto vbeta = (i_init) ? 0 : beta;
88 auto vM = (i_M) ? jcp_.M_tail : jcp_.M;
89 auto vN = (i_N) ? jcp_.N_tail : jcp_.N;
90 auto vK = (i_K) ? jcp_.K_tail : jcp_.K;
91 brgemm_t &brg = brgs_[get_brg_idx(i_init, i_M, i_N, i_K)];
92 if (vM == 0 || vN == 0 || vK == 0) continue;
93 brgemm_strides_t brg_strides;
94 brg_strides.stride_a = jcp_.brg_stride_a;
95 brg_strides.stride_b = jcp_.brg_stride_b;
96 const auto strides_ptr
97 = (jcp_.brg_type == brgemm_strd) ? &brg_strides : nullptr;
98 CHECK(brgemm_desc_init(&brg, isa, jcp_.brg_type, src_type, wei_type,
99 false, false, brgemm_row_major, alpha, vbeta, jcp_.LDA,
100 jcp_.LDB, jcp_.LDC, vM, vN, vK, strides_ptr));
101
102 brgemm_attr_t brgattr;
103 brgattr.max_bs = jcp_.gemm_batch_size;
104 brgattr.max_top_vpad = jcp_.max_vpad;
105 brgattr.max_bottom_vpad = jcp_.max_vpad;
106 brgattr.hint_expected_A_size = 0;
107 brgattr.hint_expected_B_size = brgattr.max_bs * vK * vN;
108 brgattr.hint_expected_C_size = 0;
109 brgattr.wary_tail_read = false;
110 const bool is_amx = brgemm_convolution_utils::is_amx(isa);
111 brgattr.use_uker = is_amx && brg.rdb > 1;
112 brgattr.use_interleave_stores = brgattr.use_uker;
113 CHECK(brgemm_desc_set_attr(&brg, brgattr));
114 auto LDD = jcp_.oc_without_padding;
115 brg.with_sum = with_sum;
116 CHECK(brgemm_desc_set_postops(
117 &brg, attr(), &dst_md_, LDD, jcp_.bia_dt));
118 }
119
120 auto scratchpad = scratchpad_registry().registrar();
121 brgemm_convolution_utils::init_scratchpad(scratchpad, jcp_);
122
123 return status::success;
124 }
125
126 template <cpu_isa_t isa>
init(engine_t * engine)127 status_t brgemm_1x1_convolution_fwd_t<isa>::init(engine_t *engine) {
128 auto ndims = pd()->ndims();
129 if (ndims < 3 || ndims > 5) assert(!"Invalid ndims!");
130
131 const auto &jcp = pd()->jcp_;
132
133 ID = ndims_pick(jcp.id, 1, 1);
134 IH = ndims_pick(jcp.ih, jcp.ih, 1);
135 IW = jcp.iw;
136
137 OD = ndims_pick(jcp.od, 1, 1);
138 OH = ndims_pick(jcp.oh, jcp.oh, 1);
139 OW = jcp.ow;
140
141 SD = ndims_pick(jcp.stride_d, 1, 1);
142 SH = ndims_pick(jcp.stride_h, jcp.stride_h, 1);
143 SW = jcp.stride_w;
144
145 bia_dsz = jcp.bia_dsz;
146 acc_dsz = jcp.acc_dsz;
147 src_dsz = jcp.src_dsz;
148 wei_dsz = jcp.wei_dsz;
149
150 ic_chunks = div_up(jcp.nb_ic, jcp.nb_ic_blocking);
151
152 // const variables used for address calculations
153 src_w_sz = (dim_t)IW * jcp.ngroups * jcp.ic_without_padding;
154 src_h_sz = IH * src_w_sz;
155 src_d_sz = ID * src_h_sz;
156 dst_w_sz = (dim_t)OW * jcp.oc_without_padding;
157 dst_h_sz = OH * dst_w_sz;
158 dst_d_sz = OD * dst_h_sz;
159
160 const auto src_type = pd()->src_md(0)->data_type;
161 const auto wei_type = pd()->weights_md(0)->data_type;
162
163 const auto last_ic_block
164 = (src_type == f32) ? 1 : ((src_type == bf16) ? 2 : 4);
165
166 wei_oc_sz = jcp.wei_plain ? jcp.oc : jcp.oc_block;
167 wei_ic_sz = jcp.wei_plain
168 ? (dim_t)rnd_up(jcp.ic, last_ic_block) * jcp.oc
169 : (dim_t)rnd_up(jcp.ic, last_ic_block) * jcp.oc_block;
170 wei_ocb_sz = jcp.wei_plain ? jcp.oc_block * last_ic_block
171 : jcp.nb_oc * wei_ic_sz;
172
173 need_postwork = jcp.with_bias || jcp.with_eltwise || jcp.with_binary
174 || (one_of(src_type, u8, s8) && wei_type == s8) // oscales needed
175 || (jcp.dst_dt != jcp.acc_dt) || jcp.with_sum;
176
177 for (int i = 0; i < 16; i++)
178 brg_kernels_[i] = nullptr;
179
180 if (jcp.is_rtus) {
181 CHECK(safe_ptr_assign(rtus_kernel_,
182 new jit_avx512_core_brgemm_conv_trans_kernel::
183 jit_avx512_core_brgemm_conv_rtus_kernel_t(jcp)));
184 CHECK(rtus_kernel_->create_kernel());
185 }
186
187 const bool is_amx = brgemm_convolution_utils::is_amx(isa);
188 for_(int i_M = 0; i_M < 2; i_M++)
189 for_(int i_N = 0; i_N < 2; i_N++)
190 for_(int i_K = 0; i_K < 2; i_K++)
191 for (int i_init = 0; i_init < 2; i_init++) {
192 auto brg_idx = get_brg_idx(i_init, i_M, i_N, i_K);
193 auto &brg = pd()->brgs_[brg_idx];
194 if (brg.bcast_dim > 0 && brg.load_dim > 0 && brg.reduce_dim > 0
195 && !brg_kernels_[brg_idx]) {
196 brgemm_kernel_t *brg_kernel = nullptr;
197 CHECK(brgemm_kernel_create(&brg_kernel, brg));
198 CHECK(safe_ptr_assign(brg_kernels_[brg_idx], brg_kernel));
199 if (is_amx) {
200 amx_palette_t tmp;
201 int &palette_idx = brg_kernel_palette_idx_[brg_idx];
202 palette_idx = -1;
203 CHECK(brgemm_init_tiles(brg, tmp.p));
204 // check if it's in set of tile configs
205 for (size_t i = 0; i < brg_kernel_palette_.size(); i++) {
206 const bool is_match = 0
207 == std::memcmp(brg_kernel_palette_[i].p, tmp.p,
208 AMX_PALETTE_SIZE);
209 if (is_match) {
210 palette_idx = i;
211 break;
212 }
213 }
214 // add to set of tile configs if needed
215 if (palette_idx == -1) {
216 palette_idx = brg_kernel_palette_.size();
217 brg_kernel_palette_.push_back(tmp);
218 }
219 }
220 }
221 }
222 return status::success;
223 }
224
225 template <cpu_isa_t isa>
maybe_rtus(int ithr,const char * __restrict src,char * __restrict inp_buffer,uint8_t * __restrict inp_buffer_mask,int g,int n,int icc,int od,int oh,int ow) const226 void brgemm_1x1_convolution_fwd_t<isa>::maybe_rtus(int ithr,
227 const char *__restrict src, char *__restrict inp_buffer,
228 uint8_t *__restrict inp_buffer_mask, int g, int n, int icc, int od,
229 int oh, int ow) const {
230 const auto &jcp = pd()->jcp_;
231 if (!jcp.is_rtus) return;
232 assert(jcp.is_os_blocking);
233 const size_t src_dt_size = jcp.src_dsz;
234
235 const auto os = (od * OH + oh) * OW + ow;
236 const auto osb = os / jcp.os_block;
237
238 uint8_t *bmask = &inp_buffer_mask[icc * jcp.nb_os + osb];
239 if (bmask && *bmask) return; // skip if already masked
240 if (bmask) *bmask = 1; // set mask to skip next time
241
242 const auto g_ic = g * jcp.ic_without_padding
243 + icc * jcp.nb_ic_blocking * jcp.ic_block;
244
245 auto call_kernel = [&](int nh, int nw, int od, int oh, int ow) {
246 assert(nh == 0 || (nw == 0 && ow == 0));
247 if (utils::everyone_is(0, nh, nw)) return;
248 const int id = od * jcp.stride_d;
249 const int ih = oh * jcp.stride_h;
250 const int iw = ow * jcp.stride_w;
251 const auto inp_offset = n * src_d_sz + id * src_h_sz + ih * src_w_sz
252 + iw * jcp.ngroups * jcp.ic_without_padding + g_ic;
253 auto p = jit_avx512_core_brgemm_conv_trans_kernel::
254 jit_brgemm_conv_trans_kernel_call_s();
255 p.h_count = nh;
256 p.owb = nw;
257 p.src = src + src_dt_size * inp_offset;
258 p.dst = inp_buffer;
259 (*rtus_kernel_)(&p);
260 inp_buffer += src_dt_size * (nh * jcp.ow + nw) * jcp.LDA;
261 };
262
263 const bool is_os_tail = jcp.os - os < jcp.os_block;
264 int count = is_os_tail ? jcp.M_tail : jcp.M;
265
266 if (count < OW || ow > 0) {
267 // copy to end of row
268 const auto nw = nstl::min(count, OW - ow);
269 call_kernel(0, nw, od, oh, ow);
270 count -= nw;
271 if (count == 0) return;
272 ow = 0;
273 oh = (oh + 1) % OH;
274 if (oh == 0) od++;
275 }
276
277 while (od < OD) {
278 // copy to end of column
279 const auto nh = nstl::min(count / OW, OH - oh);
280 call_kernel(nh, 0, od, oh, ow);
281 count -= nh * OW;
282 if (count == 0) return;
283 oh = (oh + nh) % OH;
284 if (oh == 0) od++;
285 if (count < OW) {
286 // copy partial row
287 const auto nw = count;
288 call_kernel(0, nw, od, oh, ow);
289 return;
290 }
291 }
292 }
293
294 template <cpu_isa_t isa>
exec_ker(const brgemm_exec_ctx_t & brgemm_ctx,int ithr,brgemm_batch_element_t * const __restrict brg_batch,char * const c_buffer,const char * inp_buffer,int g,int n,int ocb,int od,int oh,int ow,int icc,int * last_palette_idx) const295 void brgemm_1x1_convolution_fwd_t<isa>::exec_ker(
296 const brgemm_exec_ctx_t &brgemm_ctx, int ithr,
297 brgemm_batch_element_t *const __restrict brg_batch,
298 char *const c_buffer, const char *inp_buffer, int g, int n, int ocb,
299 int od, int oh, int ow, int icc, int *last_palette_idx) const {
300
301 const memory_desc_wrapper src_d(pd()->src_md());
302 const memory_desc_wrapper weights_d(pd()->weights_md());
303 const memory_desc_wrapper dst_d(pd()->dst_md());
304 const size_t src_dt_size = types::data_type_size(src_d.data_type());
305 const size_t wei_dt_size = types::data_type_size(weights_d.data_type());
306 const size_t dst_dt_size = types::data_type_size(dst_d.data_type());
307
308 const char *const __restrict src = brgemm_ctx.src;
309 const char *const __restrict weights = brgemm_ctx.weights;
310 const char *const __restrict bias = brgemm_ctx.bias;
311 char *const __restrict dst = brgemm_ctx.dst;
312 const std::vector<const void *> &post_ops_binary_rhs_arg_vec
313 = brgemm_ctx.post_ops_binary_rhs_arg_vec;
314
315 const float *oscales = pd()->attr()->output_scales_.scales_;
316
317 const auto &jcp = pd()->jcp_;
318 auto ndims = pd()->ndims();
319
320 const bool is_amx = brgemm_convolution_utils::is_amx(isa);
321 char *const wsp_tile
322 = is_amx ? brgemm_ctx.wsp_tile + ithr * 4 * 1024 : nullptr;
323
324 const int id = ndims_pick(od * SD, 0, 0);
325 const int ih = ndims_pick(oh * SH, oh * SH, 0);
326 const int iw = ow * SW;
327
328 const int oc = ocb * jcp.oc_block;
329 const int g_oc = g * jcp.oc + oc;
330
331 const int icb = icc * jcp.nb_ic_blocking;
332 const int ic = icb * jcp.ic_block;
333 const int g_ic = g * jcp.ic + ic;
334
335 const bool kernel_init = (icc == 0);
336
337 const auto os = (od * OH + oh) * OW + ow;
338
339 const bool is_os_tail = jcp.is_os_blocking ? (jcp.os - os < jcp.os_block)
340 : (OW - ow < jcp.ow_block);
341 const bool is_oc_tail = (jcp.oc - oc < jcp.oc_block);
342 const bool is_ic_tail
343 = (icc == ic_chunks - 1 && ((jcp.ic - ic) % jcp.ic_block != 0));
344
345 const auto src_offset = n * src_d_sz + id * src_h_sz + ih * src_w_sz
346 + iw * jcp.ngroups * jcp.ic_without_padding + g_ic;
347 const auto src_base
348 = jcp.is_rtus ? inp_buffer : src + src_dt_size * src_offset;
349 const auto wei_offset = jcp.wei_plain ? g * wei_ic_sz + ocb * wei_ocb_sz
350 : g * wei_ocb_sz + ocb * wei_ic_sz;
351 const auto wei_base = weights + wei_dt_size * wei_offset;
352 const auto ptr_D = dst
353 + dst_dt_size
354 * (n * dst_d_sz + od * dst_h_sz + oh * dst_w_sz
355 + ow * jcp.oc_without_padding + g_oc);
356 char *const ptr_C = (jcp.use_buffer) ? c_buffer : (char *)ptr_D;
357
358 const auto bias_w
359 = bias ? bias + (bias_d.blk_off(g_oc) * bia_dsz) : nullptr;
360 const auto nb_ic_b = nstl::min(jcp.nb_ic_blocking, jcp.nb_ic - icb)
361 - (is_ic_tail ? 1 : 0);
362
363 const auto call_brgemm = [=](int brg_idx, int ic_block_s, int n_ic_blocks,
364 bool do_postops) {
365 for (int k = 0; k < n_ic_blocks; k++) {
366 const auto ic_off = (ic_block_s + k) * jcp.ic_block;
367 const auto src_ic = ic_off;
368 const auto wei_ic = ic + ic_off;
369 const auto ptr_A = src_base + src_dt_size * src_ic;
370 const auto ptr_B = wei_base + wei_dt_size * wei_ic * wei_oc_sz;
371 brg_batch[k].ptr.A = ptr_A;
372 brg_batch[k].ptr.B = ptr_B;
373 brg_batch[k].vvpad.top = 0;
374 brg_batch[k].vvpad.bottom = 0;
375 }
376
377 // NOTE: avoid some costly tile reconfigurations here by keeping track
378 // of the previous brg kernel tile configuration palette
379 // TODO: adjust harness to avoid even more tile reconfigurations
380 if (is_amx) {
381 const int curr_palette_idx = brg_kernel_palette_idx_[brg_idx];
382 if (curr_palette_idx != *last_palette_idx) {
383 amx_tile_configure(brg_kernel_palette_[curr_palette_idx].p);
384 *last_palette_idx = curr_palette_idx;
385 }
386 }
387
388 const brgemm_kernel_t *brg_ker = brg_kernels_[brg_idx].get();
389 if (do_postops) {
390 const brgemm_post_ops_data_t post_ops_data {
391 static_cast<const void *>(bias_w),
392 &oscales[jcp.is_oc_scale * g_oc],
393 post_ops_binary_rhs_arg_vec.data(),
394 static_cast<size_t>(g_oc), 0, dst};
395
396 brgemm_kernel_execute_postops(brg_ker, n_ic_blocks, brg_batch,
397 (void *)ptr_C, (void *)ptr_D, post_ops_data,
398 (void *)wsp_tile);
399 } else {
400 brgemm_kernel_execute(brg_ker, n_ic_blocks, brg_batch,
401 (void *)ptr_C, (void *)wsp_tile);
402 }
403 };
404
405 const auto do_post_work
406 = (need_postwork || jcp.use_buffer) && icc == ic_chunks - 1;
407
408 if (nb_ic_b > 0) {
409 const auto brg_idx
410 = get_brg_idx(kernel_init, is_os_tail, is_oc_tail, false);
411 call_brgemm(brg_idx, 0, nb_ic_b, do_post_work && !is_ic_tail);
412 }
413 if (is_ic_tail) {
414 const auto use_init_ker = (kernel_init && nb_ic_b == 0);
415 const auto brg_idx
416 = get_brg_idx(use_init_ker, is_os_tail, is_oc_tail, true);
417
418 call_brgemm(brg_idx, nb_ic_b, 1, do_post_work);
419 }
420 }
421
422 template <cpu_isa_t isa>
execute_forward_all(const exec_ctx_t & ctx) const423 void brgemm_1x1_convolution_fwd_t<isa>::execute_forward_all(
424 const exec_ctx_t &ctx) const {
425
426 brgemm_exec_ctx_t brgemm_ctx(ctx, pd());
427
428 const memory_tracking::grantor_t scratchpad = ctx.get_scratchpad_grantor();
429
430 const auto &jcp = pd()->jcp_;
431 const bool is_amx = brgemm_convolution_utils::is_amx(isa);
432
433 brgemm_batch_element_t *const brg_batch_global
434 = (jcp.brg_type != brgemm_strd)
435 ? scratchpad.template get<brgemm_batch_element_t>(
436 key_brgemm_primitive_batch)
437 : nullptr;
438 char *const c_buffer_global = (jcp.use_buffer)
439 ? scratchpad.template get<char>(key_brgemm_primitive_buffer)
440 : nullptr;
441 char *inp_buffer_base = (jcp.is_rtus)
442 ? scratchpad.template get<char>(key_conv_brgemm_inp_buffer)
443 : nullptr;
444 uint8_t *inp_buffer_mask_base = (jcp.is_rtus)
445 ? scratchpad.template get<uint8_t>(key_conv_brgemm_inp_buffer_mask)
446 : nullptr;
447
448 if (jcp.is_os_blocking) {
449 const int os_chunks = div_up(jcp.nb_os, jcp.nb_os_blocking);
450 const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_oc * os_chunks;
451
452 #define BRGC_WO(...) \
453 parallel(pd()->jcp_.nthr, [&](const int ithr, const int nthr) { \
454 if (ithr >= work_amount) return; \
455 brgemm_batch_element_t *const brg_batch \
456 = brg_batch_global + (size_t)ithr * jcp.adjusted_batch_size; \
457 char *const c_buffer = (jcp.use_buffer) \
458 ? c_buffer_global + ithr * acc_dsz * jcp.LDC * jcp.M \
459 : nullptr; \
460 char *inp_buffer = (jcp.is_rtus) \
461 ? inp_buffer_base + ithr * src_dsz * jcp.inp_buffer_size \
462 : nullptr; \
463 uint8_t *__restrict inp_buffer_mask = (jcp.is_rtus) \
464 ? inp_buffer_mask_base + ithr * jcp.inp_buffer_mask_size \
465 : nullptr; \
466 int last_n = -1; \
467 int last_g = -1; \
468 int last_palette_idx = -1; \
469 int start {0}, end {0}; \
470 balance211(work_amount, nthr, ithr, start, end); \
471 int n {0}, g {0}, ocb {0}, oss {0}; \
472 nd_iterator_init(start, __VA_ARGS__); \
473 for (auto work = start; work < end; work++) { \
474 if (jcp.is_rtus && (last_n != n || last_g != g)) \
475 std::memset(inp_buffer_mask, 0, jcp.inp_buffer_mask_size); \
476 const auto osb_start = oss * jcp.nb_os_blocking; \
477 const auto osb_range \
478 = nstl::min(jcp.nb_os - osb_start, jcp.nb_os_blocking); \
479 for (int osb = 0; osb < osb_range; osb++) { \
480 const int os = (osb_start + osb) * jcp.os_block; \
481 const int od = os / (OH * OW); \
482 const int oh = (os % (OH * OW)) / OW; \
483 const int ow = os % OW; \
484 char *inp_buffer_sp = (jcp.is_rtus) \
485 ? inp_buffer + src_dsz * os * jcp.LDA \
486 : nullptr; \
487 for (int icc = 0; icc < ic_chunks; icc++) { \
488 if (jcp.is_rtus) \
489 maybe_rtus(ithr, brgemm_ctx.src, inp_buffer_sp, \
490 inp_buffer_mask, g, n, icc, od, oh, ow); \
491 exec_ker(brgemm_ctx, ithr, brg_batch, c_buffer, \
492 inp_buffer_sp, g, n, ocb, od, oh, ow, icc, \
493 &last_palette_idx); \
494 } \
495 } \
496 last_n = n; \
497 last_g = g; \
498 nd_iterator_step(__VA_ARGS__); \
499 } \
500 if (is_amx) amx_tile_release(); \
501 });
502
503 if (jcp.loop_order == loop_ndhwgc)
504 BRGC_WO(n, jcp.mb, oss, os_chunks, g, jcp.ngroups, ocb, jcp.nb_oc)
505 else if (jcp.loop_order == loop_ngcdhw)
506 BRGC_WO(n, jcp.mb, g, jcp.ngroups, ocb, jcp.nb_oc, oss, os_chunks)
507 else
508 assert(!"Unknown loop order");
509
510 #undef BRGC_WO
511
512 } else {
513 const int work_amount
514 = jcp.mb * jcp.ngroups * jcp.nb_oc * OD * OH * jcp.nb_ow;
515
516 #define BRGC_WO(...) \
517 parallel(pd()->jcp_.nthr, [&](const int ithr, const int nthr) { \
518 if (ithr >= work_amount) return; \
519 brgemm_batch_element_t *const brg_batch \
520 = brg_batch_global + (size_t)ithr * jcp.adjusted_batch_size; \
521 char *const c_buffer = (jcp.use_buffer) \
522 ? c_buffer_global + ithr * acc_dsz * jcp.LDC * jcp.M \
523 : nullptr; \
524 int last_palette_idx = -1; \
525 int start {0}, end {0}; \
526 balance211(work_amount, nthr, ithr, start, end); \
527 int n {0}, g {0}, ocb {0}, od {0}, oh {0}, owb {0}; \
528 nd_iterator_init(start, __VA_ARGS__); \
529 for (auto work = start; work < end; work++) { \
530 for (int icc = 0; icc < ic_chunks; icc++) { \
531 const int ow = owb * jcp.ow_block; \
532 exec_ker(brgemm_ctx, ithr, brg_batch, c_buffer, nullptr, g, n, \
533 ocb, od, oh, ow, icc, &last_palette_idx); \
534 } \
535 nd_iterator_step(__VA_ARGS__); \
536 } \
537 if (is_amx) amx_tile_release(); \
538 });
539
540 if (jcp.loop_order == loop_ndhwgc)
541 BRGC_WO(n, jcp.mb, od, OD, oh, OH, owb, jcp.nb_ow, g, jcp.ngroups,
542 ocb, jcp.nb_oc)
543 else if (jcp.loop_order == loop_ngcdhw)
544 BRGC_WO(n, jcp.mb, g, jcp.ngroups, ocb, jcp.nb_oc, od, OD, oh, OH,
545 owb, jcp.nb_ow)
546 else
547 assert(!"Unknown loop order");
548
549 #undef BRGC_WO
550 }
551 }
552
// Explicit instantiations of the primitive for the ISA variants provided by
// this translation unit.
template struct brgemm_1x1_convolution_fwd_t<avx512_core>;
template struct brgemm_1x1_convolution_fwd_t<avx512_core_vnni>;
template struct brgemm_1x1_convolution_fwd_t<avx512_core_bf16>;
template struct brgemm_1x1_convolution_fwd_t<avx512_core_bf16_amx_int8>;
template struct brgemm_1x1_convolution_fwd_t<avx512_core_bf16_amx_bf16>;
558
559 } // namespace x64
560 } // namespace cpu
561 } // namespace impl
562 } // namespace dnnl
563
564 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
565