/*******************************************************************************
* Copyright 2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <cstring> // std::memcmp / std::memset below

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/x64/amx_tile_configure.hpp"
#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
#include "cpu/x64/jit_brgemm_1x1_conv.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;

using namespace nstl;
using namespace data_type;

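// Pick a value according to the spatial rank of the convolution:
// 5D (with depth), 4D (height and width), or 3D (width only) tensors.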
#define ndims_pick(v5, v4, v3) \
    ((ndims == 5) ? (v5) : (ndims == 4) ? (v4) : (ndims == 3) ? (v3) : 0)

template <cpu_isa_t isa>
status_t brgemm_1x1_convolution_fwd_t<isa>::pd_t::init(engine_t *engine) {
    using namespace data_type;
    using namespace utils;

    const auto src_type = src_md(0)->data_type;
    const auto wei_type = weights_md(0)->data_type;
    const auto dst_type = dst_md(0)->data_type;

    using skip_mask_t = primitive_attr_t::skip_mask_t;
    auto skip_mask = skip_mask_t::post_ops | skip_mask_t::sum_dt;
    if (one_of(src_type, u8, s8)) skip_mask |= skip_mask_t::oscale;

    bool ok = is_fwd() && set_default_alg_kind(alg_kind::convolution_direct)
            && expect_data_types(src_type, wei_type, data_type::undef, dst_type,
                    data_type::undef)
            && IMPLICATION(with_bias(),
                    ((one_of(src_type, u8, s8)
                             && one_of(bias_md_.data_type, f32, s32, s8, u8))
                            || (one_of(src_type, bf16)
                                    && one_of(bias_md_.data_type, f32, bf16))
                            || (one_of(src_type, f32)
                                    && one_of(bias_md_.data_type, f32))))
            && attr()->has_default_values(skip_mask, dst_type)
            && attr()->post_ops_.check_sum_consistent_dt(dst_type)
            && !has_zero_dim_memory();
    if (!ok) return status::unimplemented;

    CHECK(brgemm_convolution_utils::init_1x1_conf(jcp_, isa, *desc(), src_md_,
            weights_md_, dst_md_, bias_md_, attr_, dnnl_get_max_threads()));

    for (int i = 0; i < 16; i++)
        brgs_[i].bcast_dim = brgs_[i].load_dim = brgs_[i].reduce_dim = 0;

    const float alpha = 1.0;
    const float beta = 1.0;
    const auto &p = attr()->post_ops_;
    const int sum_idx = p.find(primitive_kind::sum);
    with_sum = (sum_idx != -1);
    sum_scale = with_sum ? p.entry_[sum_idx].sum.scale : 0.0;

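    // Initialize up to 16 BRGEMM descriptors, one per combination of
    // {initialize C (beta = 0) vs. accumulate into C} x {full vs. tail M}
    // x {full vs. tail N} x {full vs. tail K}.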
    for_(int i_init = 0; i_init < 2; i_init++)
    for_(int i_M = 0; i_M < 2; i_M++)
    for_(int i_N = 0; i_N < 2; i_N++)
    for (int i_K = 0; i_K < 2; i_K++) {
        auto vbeta = (i_init) ? 0 : beta;
        auto vM = (i_M) ? jcp_.M_tail : jcp_.M;
        auto vN = (i_N) ? jcp_.N_tail : jcp_.N;
        auto vK = (i_K) ? jcp_.K_tail : jcp_.K;
        brgemm_t &brg = brgs_[get_brg_idx(i_init, i_M, i_N, i_K)];
        if (vM == 0 || vN == 0 || vK == 0) continue;
        brgemm_strides_t brg_strides;
        brg_strides.stride_a = jcp_.brg_stride_a;
        brg_strides.stride_b = jcp_.brg_stride_b;
        const auto strides_ptr
                = (jcp_.brg_type == brgemm_strd) ? &brg_strides : nullptr;
        CHECK(brgemm_desc_init(&brg, isa, jcp_.brg_type, src_type, wei_type,
                false, false, brgemm_row_major, alpha, vbeta, jcp_.LDA,
                jcp_.LDB, jcp_.LDC, vM, vN, vK, strides_ptr));

        brgemm_attr_t brgattr;
        brgattr.max_bs = jcp_.gemm_batch_size;
        brgattr.max_top_vpad = jcp_.max_vpad;
        brgattr.max_bottom_vpad = jcp_.max_vpad;
        brgattr.hint_expected_A_size = 0;
        brgattr.hint_expected_B_size = brgattr.max_bs * vK * vN;
        brgattr.hint_expected_C_size = 0;
        brgattr.wary_tail_read = false;
        const bool is_amx = brgemm_convolution_utils::is_amx(isa);
        brgattr.use_uker = is_amx && brg.rdb > 1;
        brgattr.use_interleave_stores = brgattr.use_uker;
        CHECK(brgemm_desc_set_attr(&brg, brgattr));
        auto LDD = jcp_.oc_without_padding;
        brg.with_sum = with_sum;
        CHECK(brgemm_desc_set_postops(
                &brg, attr(), &dst_md_, LDD, jcp_.bia_dt));
    }

    auto scratchpad = scratchpad_registry().registrar();
    brgemm_convolution_utils::init_scratchpad(scratchpad, jcp_);

    return status::success;
}

template <cpu_isa_t isa>
status_t brgemm_1x1_convolution_fwd_t<isa>::init(engine_t *engine) {
    auto ndims = pd()->ndims();
    if (ndims < 3 || ndims > 5) assert(!"Invalid ndims!");

    const auto &jcp = pd()->jcp_;

    ID = ndims_pick(jcp.id, 1, 1);
    IH = ndims_pick(jcp.ih, jcp.ih, 1);
    IW = jcp.iw;

    OD = ndims_pick(jcp.od, 1, 1);
    OH = ndims_pick(jcp.oh, jcp.oh, 1);
    OW = jcp.ow;

    SD = ndims_pick(jcp.stride_d, 1, 1);
    SH = ndims_pick(jcp.stride_h, jcp.stride_h, 1);
    SW = jcp.stride_w;

    bia_dsz = jcp.bia_dsz;
    acc_dsz = jcp.acc_dsz;
    src_dsz = jcp.src_dsz;
    wei_dsz = jcp.wei_dsz;

    ic_chunks = div_up(jcp.nb_ic, jcp.nb_ic_blocking);

    // const variables used for address calculations
    src_w_sz = (dim_t)IW * jcp.ngroups * jcp.ic_without_padding;
    src_h_sz = IH * src_w_sz;
    src_d_sz = ID * src_h_sz;
    dst_w_sz = (dim_t)OW * jcp.oc_without_padding;
    dst_h_sz = OH * dst_w_sz;
    dst_d_sz = OD * dst_h_sz;

    const auto src_type = pd()->src_md(0)->data_type;
    const auto wei_type = pd()->weights_md(0)->data_type;

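    // Number of input channels packed together in the innermost weights
    // block: 1 for f32, 2 for bf16, 4 for int8 (VNNI granularity).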
    const auto last_ic_block
            = (src_type == f32) ? 1 : ((src_type == bf16) ? 2 : 4);

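    // Weights strides (in elements) used for addressing in exec_ker();
    // the values differ between the plain and blocked weights layouts.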
    wei_oc_sz = jcp.wei_plain ? jcp.oc : jcp.oc_block;
    wei_ic_sz = jcp.wei_plain
            ? (dim_t)rnd_up(jcp.ic, last_ic_block) * jcp.oc
            : (dim_t)rnd_up(jcp.ic, last_ic_block) * jcp.oc_block;
    wei_ocb_sz = jcp.wei_plain ? jcp.oc_block * last_ic_block
                               : jcp.nb_oc * wei_ic_sz;

    need_postwork = jcp.with_bias || jcp.with_eltwise || jcp.with_binary
            || (one_of(src_type, u8, s8) && wei_type == s8) // oscales needed
            || (jcp.dst_dt != jcp.acc_dt) || jcp.with_sum;

    for (int i = 0; i < 16; i++)
        brg_kernels_[i] = nullptr;

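    // rtus ("reduce to unit stride"): a copy kernel that gathers the
    // strided source points into a dense per-thread buffer so the 1x1
    // kernel can read its input with unit stride.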
    if (jcp.is_rtus) {
        CHECK(safe_ptr_assign(rtus_kernel_,
                new jit_avx512_core_brgemm_conv_trans_kernel::
                        jit_avx512_core_brgemm_conv_rtus_kernel_t(jcp)));
        CHECK(rtus_kernel_->create_kernel());
    }

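    // Generate a kernel for every BRGEMM descriptor initialized in
    // pd_t::init(). On AMX, deduplicate tile-configuration palettes so
    // kernels sharing a configuration reuse one palette entry.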
    const bool is_amx = brgemm_convolution_utils::is_amx(isa);
    for_(int i_M = 0; i_M < 2; i_M++)
    for_(int i_N = 0; i_N < 2; i_N++)
    for_(int i_K = 0; i_K < 2; i_K++)
    for (int i_init = 0; i_init < 2; i_init++) {
        auto brg_idx = get_brg_idx(i_init, i_M, i_N, i_K);
        auto &brg = pd()->brgs_[brg_idx];
        if (brg.bcast_dim > 0 && brg.load_dim > 0 && brg.reduce_dim > 0
                && !brg_kernels_[brg_idx]) {
            brgemm_kernel_t *brg_kernel = nullptr;
            CHECK(brgemm_kernel_create(&brg_kernel, brg));
            CHECK(safe_ptr_assign(brg_kernels_[brg_idx], brg_kernel));
            if (is_amx) {
                amx_palette_t tmp;
                int &palette_idx = brg_kernel_palette_idx_[brg_idx];
                palette_idx = -1;
                CHECK(brgemm_init_tiles(brg, tmp.p));
                // check if it's in the set of tile configs
                for (size_t i = 0; i < brg_kernel_palette_.size(); i++) {
                    const bool is_match = 0
                            == std::memcmp(brg_kernel_palette_[i].p, tmp.p,
                                    AMX_PALETTE_SIZE);
                    if (is_match) {
                        palette_idx = i;
                        break;
                    }
                }
                // add to the set of tile configs if needed
                if (palette_idx == -1) {
                    palette_idx = brg_kernel_palette_.size();
                    brg_kernel_palette_.push_back(tmp);
                }
            }
        }
    }
    return status::success;
}

template <cpu_isa_t isa>
void brgemm_1x1_convolution_fwd_t<isa>::maybe_rtus(int ithr,
        const char *__restrict src, char *__restrict inp_buffer,
        uint8_t *__restrict inp_buffer_mask, int g, int n, int icc, int od,
        int oh, int ow) const {
    const auto &jcp = pd()->jcp_;
    if (!jcp.is_rtus) return;
    assert(jcp.is_os_blocking);
    const size_t src_dt_size = jcp.src_dsz;

    const auto os = (od * OH + oh) * OW + ow;
    const auto osb = os / jcp.os_block;

    uint8_t *bmask = &inp_buffer_mask[icc * jcp.nb_os + osb];
    if (bmask && *bmask) return; // skip if already masked
    if (bmask) *bmask = 1; // set mask to skip next time

    const auto g_ic = g * jcp.ic_without_padding
            + icc * jcp.nb_ic_blocking * jcp.ic_block;

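    // Run the rtus copy kernel for either nh full output rows or a partial
    // row of nw points starting at (od, oh, ow), then advance the
    // destination buffer past the copied region.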
    auto call_kernel = [&](int nh, int nw, int od, int oh, int ow) {
        assert(nh == 0 || (nw == 0 && ow == 0));
        if (utils::everyone_is(0, nh, nw)) return;
        const int id = od * jcp.stride_d;
        const int ih = oh * jcp.stride_h;
        const int iw = ow * jcp.stride_w;
        const auto inp_offset = n * src_d_sz + id * src_h_sz + ih * src_w_sz
                + iw * jcp.ngroups * jcp.ic_without_padding + g_ic;
        auto p = jit_avx512_core_brgemm_conv_trans_kernel::
                jit_brgemm_conv_trans_kernel_call_s();
        p.h_count = nh;
        p.owb = nw;
        p.src = src + src_dt_size * inp_offset;
        p.dst = inp_buffer;
        (*rtus_kernel_)(&p);
        inp_buffer += src_dt_size * (nh * jcp.ow + nw) * jcp.LDA;
    };

    const bool is_os_tail = jcp.os - os < jcp.os_block;
    int count = is_os_tail ? jcp.M_tail : jcp.M;

    if (count < OW || ow > 0) {
        // copy to end of row
        const auto nw = nstl::min(count, OW - ow);
        call_kernel(0, nw, od, oh, ow);
        count -= nw;
        if (count == 0) return;
        ow = 0;
        oh = (oh + 1) % OH;
        if (oh == 0) od++;
    }

    while (od < OD) {
        // copy to end of column
        const auto nh = nstl::min(count / OW, OH - oh);
        call_kernel(nh, 0, od, oh, ow);
        count -= nh * OW;
        if (count == 0) return;
        oh = (oh + nh) % OH;
        if (oh == 0) od++;
        if (count < OW) {
            // copy partial row
            const auto nw = count;
            call_kernel(0, nw, od, oh, ow);
            return;
        }
    }
}

template <cpu_isa_t isa>
void brgemm_1x1_convolution_fwd_t<isa>::exec_ker(
        const brgemm_exec_ctx_t &brgemm_ctx, int ithr,
        brgemm_batch_element_t *const __restrict brg_batch,
        char *const c_buffer, const char *inp_buffer, int g, int n, int ocb,
        int od, int oh, int ow, int icc, int *last_palette_idx) const {

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper weights_d(pd()->weights_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    // bias descriptor (weights_md(1)); needed for bias_w's offset below
    const memory_desc_wrapper bias_d(pd()->weights_md(1));
    const size_t src_dt_size = types::data_type_size(src_d.data_type());
    const size_t wei_dt_size = types::data_type_size(weights_d.data_type());
    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());

    const char *const __restrict src = brgemm_ctx.src;
    const char *const __restrict weights = brgemm_ctx.weights;
    const char *const __restrict bias = brgemm_ctx.bias;
    char *const __restrict dst = brgemm_ctx.dst;
    const std::vector<const void *> &post_ops_binary_rhs_arg_vec
            = brgemm_ctx.post_ops_binary_rhs_arg_vec;

    const float *oscales = pd()->attr()->output_scales_.scales_;

    const auto &jcp = pd()->jcp_;
    auto ndims = pd()->ndims();

    const bool is_amx = brgemm_convolution_utils::is_amx(isa);
    char *const wsp_tile
            = is_amx ? brgemm_ctx.wsp_tile + ithr * 4 * 1024 : nullptr;

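    // Map the output point (od, oh, ow) to the corresponding (strided)
    // source point and compute the group-adjusted channel offsets plus
    // the kernel-init and tail flags for this chunk.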
    const int id = ndims_pick(od * SD, 0, 0);
    const int ih = ndims_pick(oh * SH, oh * SH, 0);
    const int iw = ow * SW;

    const int oc = ocb * jcp.oc_block;
    const int g_oc = g * jcp.oc + oc;

    const int icb = icc * jcp.nb_ic_blocking;
    const int ic = icb * jcp.ic_block;
    const int g_ic = g * jcp.ic + ic;

    const bool kernel_init = (icc == 0);

    const auto os = (od * OH + oh) * OW + ow;

    const bool is_os_tail = jcp.is_os_blocking ? (jcp.os - os < jcp.os_block)
                                               : (OW - ow < jcp.ow_block);
    const bool is_oc_tail = (jcp.oc - oc < jcp.oc_block);
    const bool is_ic_tail
            = (icc == ic_chunks - 1 && ((jcp.ic - ic) % jcp.ic_block != 0));

    const auto src_offset = n * src_d_sz + id * src_h_sz + ih * src_w_sz
            + iw * jcp.ngroups * jcp.ic_without_padding + g_ic;
    const auto src_base
            = jcp.is_rtus ? inp_buffer : src + src_dt_size * src_offset;
    const auto wei_offset = jcp.wei_plain ? g * wei_ic_sz + ocb * wei_ocb_sz
                                          : g * wei_ocb_sz + ocb * wei_ic_sz;
    const auto wei_base = weights + wei_dt_size * wei_offset;
    const auto ptr_D = dst
            + dst_dt_size
                    * (n * dst_d_sz + od * dst_h_sz + oh * dst_w_sz
                            + ow * jcp.oc_without_padding + g_oc);
    char *const ptr_C = (jcp.use_buffer) ? c_buffer : (char *)ptr_D;

    const auto bias_w
            = bias ? bias + (bias_d.blk_off(g_oc) * bia_dsz) : nullptr;
    const auto nb_ic_b = nstl::min(jcp.nb_ic_blocking, jcp.nb_ic - icb)
            - (is_ic_tail ? 1 : 0);

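    // Fill the batch with A/B pointers for n_ic_blocks input-channel
    // blocks starting at ic_block_s, then execute the BRGEMM kernel,
    // applying the post-ops on the last chunk if requested.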
    const auto call_brgemm = [=](int brg_idx, int ic_block_s, int n_ic_blocks,
                                     bool do_postops) {
        for (int k = 0; k < n_ic_blocks; k++) {
            const auto ic_off = (ic_block_s + k) * jcp.ic_block;
            const auto src_ic = ic_off;
            const auto wei_ic = ic + ic_off;
            const auto ptr_A = src_base + src_dt_size * src_ic;
            const auto ptr_B = wei_base + wei_dt_size * wei_ic * wei_oc_sz;
            brg_batch[k].ptr.A = ptr_A;
            brg_batch[k].ptr.B = ptr_B;
            brg_batch[k].vvpad.top = 0;
            brg_batch[k].vvpad.bottom = 0;
        }

        // NOTE: avoid some costly tile reconfigurations here by keeping track
        //       of the previous brg kernel tile configuration palette
        // TODO: adjust harness to avoid even more tile reconfigurations
        if (is_amx) {
            const int curr_palette_idx = brg_kernel_palette_idx_[brg_idx];
            if (curr_palette_idx != *last_palette_idx) {
                amx_tile_configure(brg_kernel_palette_[curr_palette_idx].p);
                *last_palette_idx = curr_palette_idx;
            }
        }

        const brgemm_kernel_t *brg_ker = brg_kernels_[brg_idx].get();
        if (do_postops) {
            const brgemm_post_ops_data_t post_ops_data {
                    static_cast<const void *>(bias_w),
                    &oscales[jcp.is_oc_scale * g_oc],
                    post_ops_binary_rhs_arg_vec.data(),
                    static_cast<size_t>(g_oc), 0, dst};

            brgemm_kernel_execute_postops(brg_ker, n_ic_blocks, brg_batch,
                    (void *)ptr_C, (void *)ptr_D, post_ops_data,
                    (void *)wsp_tile);
        } else {
            brgemm_kernel_execute(brg_ker, n_ic_blocks, brg_batch,
                    (void *)ptr_C, (void *)wsp_tile);
        }
    };

    const auto do_post_work
            = (need_postwork || jcp.use_buffer) && icc == ic_chunks - 1;

    if (nb_ic_b > 0) {
        const auto brg_idx
                = get_brg_idx(kernel_init, is_os_tail, is_oc_tail, false);
        call_brgemm(brg_idx, 0, nb_ic_b, do_post_work && !is_ic_tail);
    }
    if (is_ic_tail) {
        const auto use_init_ker = (kernel_init && nb_ic_b == 0);
        const auto brg_idx
                = get_brg_idx(use_init_ker, is_os_tail, is_oc_tail, true);

        call_brgemm(brg_idx, nb_ic_b, 1, do_post_work);
    }
}

template <cpu_isa_t isa>
void brgemm_1x1_convolution_fwd_t<isa>::execute_forward_all(
        const exec_ctx_t &ctx) const {

    brgemm_exec_ctx_t brgemm_ctx(ctx, pd());

    const memory_tracking::grantor_t scratchpad = ctx.get_scratchpad_grantor();

    const auto &jcp = pd()->jcp_;
    const bool is_amx = brgemm_convolution_utils::is_amx(isa);

    brgemm_batch_element_t *const brg_batch_global
            = (jcp.brg_type != brgemm_strd)
            ? scratchpad.template get<brgemm_batch_element_t>(
                    key_brgemm_primitive_batch)
            : nullptr;
    char *const c_buffer_global = (jcp.use_buffer)
            ? scratchpad.template get<char>(key_brgemm_primitive_buffer)
            : nullptr;
    char *inp_buffer_base = (jcp.is_rtus)
            ? scratchpad.template get<char>(key_conv_brgemm_inp_buffer)
            : nullptr;
    uint8_t *inp_buffer_mask_base = (jcp.is_rtus)
            ? scratchpad.template get<uint8_t>(key_conv_brgemm_inp_buffer_mask)
            : nullptr;

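    // Two work decompositions follow: with os-blocking, threads iterate
    // over chunks of the flattened output space; otherwise they iterate
    // over (od, oh, ow-block) coordinates. BRGC_WO instantiates the
    // parallel loop nest for a given loop order.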
    if (jcp.is_os_blocking) {
        const int os_chunks = div_up(jcp.nb_os, jcp.nb_os_blocking);
        const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_oc * os_chunks;

#define BRGC_WO(...) \
    parallel(pd()->jcp_.nthr, [&](const int ithr, const int nthr) { \
        if (ithr >= work_amount) return; \
        brgemm_batch_element_t *const brg_batch \
                = brg_batch_global + (size_t)ithr * jcp.adjusted_batch_size; \
        char *const c_buffer = (jcp.use_buffer) \
                ? c_buffer_global + ithr * acc_dsz * jcp.LDC * jcp.M \
                : nullptr; \
        char *inp_buffer = (jcp.is_rtus) \
                ? inp_buffer_base + ithr * src_dsz * jcp.inp_buffer_size \
                : nullptr; \
        uint8_t *__restrict inp_buffer_mask = (jcp.is_rtus) \
                ? inp_buffer_mask_base + ithr * jcp.inp_buffer_mask_size \
                : nullptr; \
        int last_n = -1; \
        int last_g = -1; \
        int last_palette_idx = -1; \
        int start {0}, end {0}; \
        balance211(work_amount, nthr, ithr, start, end); \
        int n {0}, g {0}, ocb {0}, oss {0}; \
        nd_iterator_init(start, __VA_ARGS__); \
        for (auto work = start; work < end; work++) { \
            if (jcp.is_rtus && (last_n != n || last_g != g)) \
                std::memset(inp_buffer_mask, 0, jcp.inp_buffer_mask_size); \
            const auto osb_start = oss * jcp.nb_os_blocking; \
            const auto osb_range \
                    = nstl::min(jcp.nb_os - osb_start, jcp.nb_os_blocking); \
            for (int osb = 0; osb < osb_range; osb++) { \
                const int os = (osb_start + osb) * jcp.os_block; \
                const int od = os / (OH * OW); \
                const int oh = (os % (OH * OW)) / OW; \
                const int ow = os % OW; \
                char *inp_buffer_sp = (jcp.is_rtus) \
                        ? inp_buffer + src_dsz * os * jcp.LDA \
                        : nullptr; \
                for (int icc = 0; icc < ic_chunks; icc++) { \
                    if (jcp.is_rtus) \
                        maybe_rtus(ithr, brgemm_ctx.src, inp_buffer_sp, \
                                inp_buffer_mask, g, n, icc, od, oh, ow); \
                    exec_ker(brgemm_ctx, ithr, brg_batch, c_buffer, \
                            inp_buffer_sp, g, n, ocb, od, oh, ow, icc, \
                            &last_palette_idx); \
                } \
            } \
            last_n = n; \
            last_g = g; \
            nd_iterator_step(__VA_ARGS__); \
        } \
        if (is_amx) amx_tile_release(); \
    });

        if (jcp.loop_order == loop_ndhwgc)
            BRGC_WO(n, jcp.mb, oss, os_chunks, g, jcp.ngroups, ocb, jcp.nb_oc)
        else if (jcp.loop_order == loop_ngcdhw)
            BRGC_WO(n, jcp.mb, g, jcp.ngroups, ocb, jcp.nb_oc, oss, os_chunks)
        else
            assert(!"Unknown loop order");

#undef BRGC_WO

    } else {
        const int work_amount
                = jcp.mb * jcp.ngroups * jcp.nb_oc * OD * OH * jcp.nb_ow;

#define BRGC_WO(...) \
    parallel(pd()->jcp_.nthr, [&](const int ithr, const int nthr) { \
        if (ithr >= work_amount) return; \
        brgemm_batch_element_t *const brg_batch \
                = brg_batch_global + (size_t)ithr * jcp.adjusted_batch_size; \
        char *const c_buffer = (jcp.use_buffer) \
                ? c_buffer_global + ithr * acc_dsz * jcp.LDC * jcp.M \
                : nullptr; \
        int last_palette_idx = -1; \
        int start {0}, end {0}; \
        balance211(work_amount, nthr, ithr, start, end); \
        int n {0}, g {0}, ocb {0}, od {0}, oh {0}, owb {0}; \
        nd_iterator_init(start, __VA_ARGS__); \
        for (auto work = start; work < end; work++) { \
            for (int icc = 0; icc < ic_chunks; icc++) { \
                const int ow = owb * jcp.ow_block; \
                exec_ker(brgemm_ctx, ithr, brg_batch, c_buffer, nullptr, g, n, \
                        ocb, od, oh, ow, icc, &last_palette_idx); \
            } \
            nd_iterator_step(__VA_ARGS__); \
        } \
        if (is_amx) amx_tile_release(); \
    });

        if (jcp.loop_order == loop_ndhwgc)
            BRGC_WO(n, jcp.mb, od, OD, oh, OH, owb, jcp.nb_ow, g, jcp.ngroups,
                    ocb, jcp.nb_oc)
        else if (jcp.loop_order == loop_ngcdhw)
            BRGC_WO(n, jcp.mb, g, jcp.ngroups, ocb, jcp.nb_oc, od, OD, oh, OH,
                    owb, jcp.nb_ow)
        else
            assert(!"Unknown loop order");

#undef BRGC_WO
    }
}

template struct brgemm_1x1_convolution_fwd_t<avx512_core>;
template struct brgemm_1x1_convolution_fwd_t<avx512_core_vnni>;
template struct brgemm_1x1_convolution_fwd_t<avx512_core_bf16>;
template struct brgemm_1x1_convolution_fwd_t<avx512_core_bf16_amx_int8>;
template struct brgemm_1x1_convolution_fwd_t<avx512_core_bf16_amx_bf16>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s