/*******************************************************************************
* Copyright 2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "dnnl_types.h"

#include "common/bfloat16.hpp"
#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory_tracking.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/platform.hpp"
#include "cpu/x64/cpu_isa_traits.hpp"
#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_brgemm_conv_utils.hpp"
#include "cpu/x64/jit_generator.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::format_tag;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;

using namespace prop_kind;
using namespace data_type;

namespace brgemm_convolution_utils {
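// Resolves the memory format tag for md: if the user left the format as
// `any` (and a blocked layout is eligible), initialize md with tag_value;
// otherwise require that the user-provided layout matches tag_value.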
inline status_t init_tag(format_tag_t &tag, memory_desc_t &md,
        const memory_desc_wrapper &mdw, const format_tag_t tag_value,
        bool any_eligible) {

    if (mdw.format_kind() == format_kind::any) {
        if (any_eligible) {
            CHECK(memory_desc_init_by_tag(md, tag_value));
            tag = tag_value;
        } else {
            tag = format_tag::undef;
        }
    } else {
        tag = mdw.matches_one_of_tag(tag_value);
    }

    if (tag != tag_value) return status::unimplemented;

    return status::success;
}

bool is_amx(cpu_isa_t isa) {
    return one_of(isa, avx512_core_bf16_amx_int8, avx512_core_bf16_amx_bf16);
}

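// Checks that the attribute post-ops chain (sum, eltwise, binary) is
// supported by the post-ops injector for the maximum available ISA, with
// per_oc and scalar broadcasting strategies allowed for binary post-ops.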
bool post_ops_ok(jit_brgemm_conv_conf_t &jcp, primitive_attr_t &attr,
        const memory_desc_wrapper &dst_d) {
    using namespace injector;

    const auto &post_ops = attr.post_ops_;

    return injector::post_ops_ok(post_ops_ok_args_t(get_max_cpu_isa(),
            {sum, eltwise, binary}, post_ops, &dst_d,
            false /*sum_at_pos_0_only*/, false /*sum_requires_scale_one*/,
            false /*sum_requires_zp_zero*/,
            {broadcasting_strategy_t::per_oc,
                    broadcasting_strategy_t::scalar}));
}

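// Selects plain channels-last tags for src/dst and a weights tag encoding
// the chosen blocking. The weight tag names read outer-to-inner: e.g.
// gOhwI16i64o2i means groups, blocked oc, spatial, blocked ic, then inner
// blocks of 16 ic, 64 oc, and 2 ic (the vnni granularity for bf16; 4i for
// s8).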
status_t pick_tags(jit_brgemm_conv_conf_t &jcp, memory_desc_t &src_md,
        memory_desc_t &weights_md, memory_desc_t &dst_md,
        memory_desc_t &bias_md) {
    format_tag_t src_tag, dst_tag, wei_tag;
    dst_tag = pick(jcp.ndims - 3, nwc, nhwc, ndhwc);

    const memory_desc_wrapper src_d(&src_md);
    const memory_desc_wrapper weights_d(&weights_md);
    const memory_desc_wrapper dst_d(&dst_md);
    const memory_desc_wrapper bias_d(&bias_md);
    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;

    const bool is_1d = jcp.ndims == 3;
    const bool is_2d = jcp.ndims == 4;
    const bool is_3d = jcp.ndims == 5;

    if (jcp.wei_plain) {
        jcp.LDB = jcp.oc;
        if (is_3d) {
            if (jcp.wei_dt == f32)
                wei_tag = with_groups ? gdhwio : dhwio;
            else if (jcp.wei_dt == s8)
                wei_tag = with_groups ? gdhwIo4i : dhwIo4i;
            else if (jcp.wei_dt == bf16) {
                wei_tag = with_groups ? gdhwIo2i : dhwIo2i;
            } else
                return status::unimplemented;
        } else if (is_1d) {
            if (jcp.wei_dt == f32)
                wei_tag = with_groups ? gwio : wio;
            else if (jcp.wei_dt == s8)
                wei_tag = with_groups ? gwIo4i : wIo4i;
            else if (jcp.wei_dt == bf16) {
                wei_tag = with_groups ? gwIo2i : wIo2i;
            } else
                return status::unimplemented;
        } else {
            assert(is_2d);
            UNUSED(is_2d);
            if (jcp.wei_dt == f32)
                wei_tag = with_groups ? ghwio : hwio;
            else if (jcp.wei_dt == s8)
                wei_tag = with_groups ? ghwIo4i : hwIo4i;
            else if (jcp.wei_dt == bf16) {
                wei_tag = with_groups ? ghwIo2i : hwIo2i;
            } else
                return status::unimplemented;
        }
    } else {
        jcp.LDB = jcp.oc_block;
        if (jcp.oc_block == 64) {
            if (is_3d) {
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gOdhwi64o : Odhwi64o;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOdhwI16i64o4i : OdhwI16i64o4i;
                    else
                        wei_tag = with_groups ? gOdhwI64o4i : OdhwI64o4i;
                } else if (jcp.wei_dt == bf16) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOdhwI16i64o2i : OdhwI16i64o2i;
                    else
                        wei_tag = with_groups ? gOdhwI64o2i : OdhwI64o2i;
                } else
                    return status::unimplemented;
            } else if (is_1d) {
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gOwi64o : Owi64o;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOwI16i64o4i : OwI16i64o4i;
                    else
                        wei_tag = with_groups ? gOwI64o4i : OwI64o4i;
                } else if (jcp.wei_dt == bf16) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOwI16i64o2i : OwI16i64o2i;
                    else
                        wei_tag = with_groups ? gOwI64o2i : OwI64o2i;
                } else
                    return status::unimplemented;
            } else {
                assert(is_2d);
                UNUSED(is_2d);
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gOhwi64o : Ohwi64o;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOhwI16i64o4i : OhwI16i64o4i;
                    else
                        wei_tag = with_groups ? gOhwI64o4i : OhwI64o4i;
                } else if (jcp.wei_dt == bf16) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOhwI16i64o2i : OhwI16i64o2i;
                    else
                        wei_tag = with_groups ? gOhwI64o2i : OhwI64o2i;
                } else
                    return status::unimplemented;
            }
        } else if (jcp.oc_block == 48) {
            if (is_3d) {
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gOdhwi48o : Odhwi48o;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOdhwI16i48o4i : OdhwI16i48o4i;
                    else
                        wei_tag = with_groups ? gOdhwI48o4i : OdhwI48o4i;
                } else if (jcp.wei_dt == bf16) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOdhwI16i48o2i : OdhwI16i48o2i;
                    else
                        wei_tag = with_groups ? gOdhwI48o2i : OdhwI48o2i;
                } else
                    return status::unimplemented;
            } else if (is_1d) {
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gOwi48o : Owi48o;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOwI16i48o4i : OwI16i48o4i;
                    else
                        wei_tag = with_groups ? gOwI48o4i : OwI48o4i;
                } else if (jcp.wei_dt == bf16) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOwI16i48o2i : OwI16i48o2i;
                    else
                        wei_tag = with_groups ? gOwI48o2i : OwI48o2i;
                } else
                    return status::unimplemented;
            } else {
                assert(is_2d);
                UNUSED(is_2d);
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gOhwi48o : Ohwi48o;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOhwI16i48o4i : OhwI16i48o4i;
                    else
                        wei_tag = with_groups ? gOhwI48o4i : OhwI48o4i;
                } else if (jcp.wei_dt == bf16) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOhwI16i48o2i : OhwI16i48o2i;
                    else
                        wei_tag = with_groups ? gOhwI48o2i : OhwI48o2i;
                } else
                    return status::unimplemented;
            }
        } else if (jcp.oc_block == 32) {
            if (is_3d) {
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gOdhwi32o : Odhwi32o;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOdhwI16i32o4i : OdhwI16i32o4i;
                    else
                        wei_tag = with_groups ? gOdhwI32o4i : OdhwI32o4i;
                } else if (jcp.wei_dt == bf16) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOdhwI16i32o2i : OdhwI16i32o2i;
                    else
                        wei_tag = with_groups ? gOdhwI32o2i : OdhwI32o2i;
                } else
                    return status::unimplemented;
            } else if (is_1d) {
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gOwi32o : Owi32o;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOwI16i32o4i : OwI16i32o4i;
                    else
                        wei_tag = with_groups ? gOwI32o4i : OwI32o4i;
                } else if (jcp.wei_dt == bf16) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOwI16i32o2i : OwI16i32o2i;
                    else
                        wei_tag = with_groups ? gOwI32o2i : OwI32o2i;
                } else
                    return status::unimplemented;
            } else {
                assert(is_2d);
                UNUSED(is_2d);
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gOhwi32o : Ohwi32o;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOhwI16i32o4i : OhwI16i32o4i;
                    else
                        wei_tag = with_groups ? gOhwI32o4i : OhwI32o4i;
                } else if (jcp.wei_dt == bf16) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOhwI16i32o2i : OhwI16i32o2i;
                    else
                        wei_tag = with_groups ? gOhwI32o2i : OhwI32o2i;
                } else
                    return status::unimplemented;
            }
        } else {
            if (is_3d) {
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gOdhwi16o : Odhwi16o;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOdhwI16i16o4i : OdhwI16i16o4i;
                    else
                        wei_tag = with_groups ? gOdhwI16o4i : OdhwI16o4i;
                } else if (jcp.wei_dt == bf16) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOdhwI16i16o2i : OdhwI16i16o2i;
                    else
                        wei_tag = with_groups ? gOdhwI16o2i : OdhwI16o2i;
                } else
                    return status::unimplemented;
            } else if (is_1d) {
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gOwi16o : Owi16o;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOwI16i16o4i : OwI16i16o4i;
                    else
                        wei_tag = with_groups ? gOwI16o4i : OwI16o4i;
                } else if (jcp.wei_dt == bf16) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOwI16i16o2i : OwI16i16o2i;
                    else
                        wei_tag = with_groups ? gOwI16o2i : OwI16o2i;
                } else
                    return status::unimplemented;
            } else {
                assert(is_2d);
                UNUSED(is_2d);

                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gOhwi16o : Ohwi16o;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOhwI16i16o4i : OhwI16i16o4i;
                    else
                        wei_tag = with_groups ? gOhwI16o4i : OhwI16o4i;
                } else if (jcp.wei_dt == bf16) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOhwI16i16o2i : OhwI16i16o2i;
                    else
                        wei_tag = with_groups ? gOhwI16o2i : OhwI16o2i;
                } else
                    return status::unimplemented;
            }
        }
    }

    src_tag = dst_tag;

    const bool any_eligible = (jcp.prop_kind == prop_kind::forward_inference
            || jcp.wei_dt == data_type::s8 || is_amx(jcp.isa));
    CHECK(init_tag(jcp.src_tag, src_md, src_d, src_tag, any_eligible));
    CHECK(init_tag(jcp.dst_tag, dst_md, dst_d, dst_tag, any_eligible));
    CHECK(init_tag(jcp.wei_tag, weights_md, weights_d, wei_tag, true));

    return status::success;
}

struct brg_blocking_t : public jit_brgemm_conv_conf_t {
    struct array_in_loop_t {
        dim_t itersize;
        float repeatn;
        float overlap;
        void set(dim_t iter_s, float rpt, float ovlp = 1.f) {
            itersize = iter_s;
            repeatn = rpt;
            overlap = ovlp;
        }
    };

    struct loop_t {
        array_in_loop_t src;
        array_in_loop_t wei;
        array_in_loop_t dst;
    };

    brg_blocking_t() : jit_brgemm_conv_conf_t() { init(); }
    brg_blocking_t(const jit_brgemm_conv_conf_t &jcp)
        : jit_brgemm_conv_conf_t(jcp) {
        init();
    }
    void init() {
        ur = 0;
        ur_block = 0;
        ur_block_tail = 0;
        eff = 0.f;
        nb_kd = 0;
        nb_kh = 0;
        nb_kw = 0;
        sp = 0;
        sp_block = 0;
        nb_sp = 0;
    }

    int ur, ur_block, ur_block_tail;
    int nb_kd, nb_kh, nb_kw;
    float eff;
    static unsigned L1;
    static unsigned L2;
    static unsigned L3;
    // Rough relative estimates of the access latency to the various cache
    // levels; sufficient for estimating data access cost.
    // TODO: Improve memory access estimates
    static constexpr float L1_k = 1.f;
    static constexpr float L2_k = 3.f;
    static constexpr float L3_k = 15.f;
    // TODO: At the moment we primarily evaluate whether the data fits into
    // L1/L2; the difference between L3 and main memory should also be taken
    // into account.
    static constexpr float mem_k = 15.f;
    static constexpr int bench_iterations = 1;
    static constexpr int max_regs = 32;
    static constexpr int bcast_simd = 16;

    int sp, sp_block, nb_sp;
    static int last_ic_block_size;

    void get_from_jcp(const jit_brgemm_conv_conf_t &jcp) { *this = jcp; }
    void save_to_jcp(jit_brgemm_conv_conf_t &jcp) const { jcp = *this; }

    status_t estimate_brgemm_ur(int spb);
    status_t get_brgemm_ur(
            const primitive_attr_t *attr, const memory_desc_t &dst_md);

    float io_k(dim_t src, dim_t wei, dim_t dst, float n, float pk,
            bool is_broadcast, bool is_shared) const;

    float io_k(const loop_t loop, const array_in_loop_t arr, float pk,
            bool is_broadcast, bool is_shared) const;

    void select_ic_block();

    void update_blocks();
    bool fast_check_oc_block() const;
    float est_eff();
    void iterate_ker_block(brg_blocking_t &best_brgb, int kd_block,
            int kh_block, bool maybe_use_buffer, int max_ow_block_thr);
    status_t calc_blocks();

    bool fast_check_oc_block_1x1() const;
    float est_eff_1x1();
    void calc_blocks_1x1();

    // utils
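    // Returns the input span (in elements) required to produce dst_size
    // outputs with the given kernel size, stride, and dilation, clamped to
    // max_src_size.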
    static int get_inp_size(
            int max_src_size, int dst_size, int k, int stride, int dilate) {
        auto adj_str = nstl::min(k, stride);
        const auto res = nstl::min(max_src_size,
                calculate_end_padding(0, dst_size, 0, adj_str,
                        calculate_extended_filter_size(k, dilate)));
        return res;
    }

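    // Squeezes an efficiency value toward 1: for koeff < 1 the result is a
    // weighted average of eff and 1 (with weight 1/koeff); otherwise eff is
    // simply scaled by koeff.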
    static float squeeze_val(float eff, float koeff) {
        if (koeff <= 0) return 1;
        if (koeff == 1) return eff;
        const auto k = 1.f / koeff;
        return (k > 1.f) ? (k - 1 + eff) / k : eff * koeff;
    }

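    // Heuristic ur upper bound per oc_block; the values appear sized to the
    // max_regs = 32 zmm budget (e.g. oc_block 64 -> 4 accumulator columns
    // -> ur 6, i.e. 24 accumulators plus weight/broadcast registers).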
    static int estimate_ur(int oc_block) {
        const auto est_ur = (oc_block == 64)
                ? 6
                : ((oc_block == 48) ? 9 : ((oc_block == 32) ? 14 : 28));
        return est_ur;
    }

    int inp_w(int out_w, int ker_w) const {
        return get_inp_size(iw, out_w, ker_w, stride_w, dilate_w);
    }

    int rnd_simd(int val) const { return rnd_up(val, simd_w); }

    int rnd_inp_simd(int out_w, int ker_w, int vic) const {
        const auto vsp = inp_w(out_w, ker_w);
        return ((stride_w == 1 && vic >= ic) ? rnd_up(vsp * vic, simd_w)
                                             : vsp * rnd_up(vic, simd_w));
    }

    static constexpr int MAXNLOOPS = 32;
    loop_t loop[MAXNLOOPS];
};

unsigned brg_blocking_t::L1;
unsigned brg_blocking_t::L2;
unsigned brg_blocking_t::L3;
int brg_blocking_t::last_ic_block_size;

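// Amortized cost of accessing an array n times inside a loop level: the
// first access costs pk (the cost propagated from the enclosing loop), the
// remaining n - 1 accesses cost the latency of the cache level that the
// working set (amount) fits into.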
float brg_blocking_t::io_k(dim_t src, dim_t wei, dim_t dst, float n, float pk,
        bool is_broadcast, bool is_shared) const {
    if (n < 1) return 0;
    if (n == 1) return pk;
    const auto amount = src * src_dsz + wei * wei_dsz + dst * dst_dsz
            + (use_buffer ? dst * acc_dsz : 0);
    const auto amount_L1 = is_broadcast ? src * src_dsz : amount;
    const auto k = is_broadcast
            ? ((amount_L1 < L1) ? L1_k
                                : ((amount < L2) ? L2_k
                                                 : (is_shared ? L3_k : mem_k)))
            : ((amount < L2) ? L2_k : (is_shared ? L3_k : mem_k));
    const auto cost = pk + k * (n - 1);
    return cost / n;
}

float brg_blocking_t::io_k(const loop_t loop, const array_in_loop_t arr,
        float pk, bool is_broadcast, bool is_shared) const {
    return io_k(loop.src.itersize, loop.wei.itersize, loop.dst.itersize,
            arr.repeatn * arr.overlap, pk, is_broadcast, is_shared);
}

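// Chooses ic_block: picks the largest number of simd blocks, bounded by the
// L1/L2 footprint of the per-ur sources and weights, whose padding
// efficiency nb_simd / rnd_up(nb_simd, nb_icb) stays above the 0.5
// threshold.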
void brg_blocking_t::select_ic_block() {
    if (is_1x1 && is_amx(isa)) {
        // TODO: merge with the non-1x1 code block below
        const int ic_padded_block = 16 * brg_blocking_t::last_ic_block_size;
        assert(ic < ic_padded_block || ic % ic_padded_block == 0);
        MAYBE_UNUSED(ic_padded_block);
        ic_block = ic;
        nb_ic = utils::div_up(ic, ic_block); // trivially 1 for now
        return;
    }
    auto nb_simd = utils::div_up(ic, simd_w);
    auto max_simd_blocks = nstl::min(5 * simd_w, nb_simd);
    const auto nb_icb_eff_threshold = 0.5f;
    const auto padded_ic = last_ic_block_size * (is_ic_padded ? 16 : 1);
    if (is_amx(isa)) {
        if (ic * kw_sets < simd_w) {
            // this is the current requirement from the brgemm kernel
            ic_block = rnd_up(ic, last_ic_block_size);
        } else {
            if (exec_type == exec_trans) {
                auto simd_blocks = 1;
                for (int nb_icb = max_simd_blocks; nb_icb >= 1; nb_icb--) {
                    auto nb_icb_eff = static_cast<float>(nb_simd)
                            / rnd_up(nb_simd, nb_icb);
                    if (nb_icb_eff >= nb_icb_eff_threshold) {
                        simd_blocks = nb_icb;
                        break;
                    }
                }
                ic_block = simd_blocks * simd_w;
            } else
                ic_block = simd_w;
        }
    } else {
        const auto est_ur = nstl::min(sp_block, estimate_ur(oc_block));
        const auto inp_ur = is_os_blocking ? est_ur : inp_w(est_ur, kw_block);

        if (kw_block > 1) {
            // try to fit src into L1
            const auto inp_per_ic = static_cast<unsigned int>(inp_ur) * src_dsz;
            max_simd_blocks = saturate(1, max_simd_blocks,
                    static_cast<int>(L1 / (inp_per_ic * simd_w)));
        }
        // try to fit the whole batch for ur into L2
        const auto wei_per_ic = static_cast<unsigned int>(kd_block) * kh_block
                * kw_block * oc_block * wei_dsz;
        const auto inp_per_ic = static_cast<unsigned int>(kd_block) * kh_block
                * inp_ur * src_dsz;
        const auto out_size
                = static_cast<unsigned int>(ur) * oc_block * dst_dsz;

        max_simd_blocks = saturate(1, max_simd_blocks,
                static_cast<int>((L2 - out_size)
                        / ((wei_per_ic + inp_per_ic) * simd_w)));

        auto simd_blocks = 1;
        for (int nb_icb = nstl::min(max_simd_blocks, nb_simd); nb_icb >= 1;
                nb_icb--) {
            auto nb_icb_eff
                    = static_cast<float>(nb_simd) / rnd_up(nb_simd, nb_icb);
            if (nb_icb_eff >= nb_icb_eff_threshold) {
                simd_blocks = nb_icb;
                break;
            }
        }

        ic_block = nstl::min(
                (exec_type == exec_trans) ? rnd_up(ic, padded_ic) : ic,
                simd_blocks * simd_w);
    }
    nb_ic = utils::div_up(ic, ic_block);
}

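// Simulates brgemm descriptor initialization for the candidate blocking and
// reads back the register-blocking parameters (ur, ur_block) that the brgemm
// kernel would choose.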
status_t brg_blocking_t::estimate_brgemm_ur(int spb) {
    // Simple simulation of brgemm_desc init
    if (sp_block <= 0) return status::invalid_arguments;
    LDA = is_rtus
            ? (ic_block)
            : (kh_sets > 1 ? kh_sets : 1) * (kw_sets > 1 ? kw_sets : stride_w)
                    * (exec_type == exec_trans ? ic_block
                                               : ngroups * ic_without_padding);
    LDB = oc_block;
    LDC = use_buffer ? oc_block : oc_without_padding;

    // Configure matrix sizes
    // for amx: if ic_block != ic then exec_trans is used, so K is ic_block
    const auto padded_ic = last_ic_block_size * (is_ic_padded ? 16 : 1);

    icp = rnd_up(ic, padded_ic);
    M = brgM = sp >= sp_block ? sp_block : 0;
    M_tail = brgM_tail = sp % sp_block;
    if (is_os_blocking) {
        if (!is_1x1) M_tail = brgM_tail = (oh * ow) % sp_block;
        oskip = ((ext_kw - 1) / stride_w) * stride_h + (stride_h - 1) * ow;

        brgM = sp_block + oskip * (div_up(M, ow) - 1);

        // round up brgM to help the brgemm kernel use max amx_h as bd_block
        if (use_M_mask == 2) {
            int ibrgM = 0;
            const auto adj_ow = ow_block + oskip;
            while (ibrgM < brgM) {
                if (ibrgM % adj_ow < ow_block)
                    ibrgM += amx_h;
                else
                    ibrgM++;
            }
            brgM = ibrgM;
        } else
            brgM = rnd_up(brgM, amx_h);

        brgM_tail = brgM;
    }

    N = oc >= oc_block ? oc_block : 0;
    N_tail = oc % oc_block;
    K = kh_sets * kw_sets * (ic >= ic_block ? ic_block : 0);
    K_tail = kh_sets * kw_sets
            * (exec_type == exec_trans
                            ? ic_block
                            : rnd_up(ic % ic_block, last_ic_block_size));

    const auto vK = K > 0 ? K : K_tail;
    const auto vM = M > 0 ? M : M_tail;
    const auto vN = N > 0 ? N : N_tail;

    const float alpha = 1.0;
    const float beta = 0.0;
    brgemm_t brg;
    CHECK(brgemm_desc_init(&brg, isa, brgemm_addr, src_dt, wei_dt, false, false,
            brgemm_row_major, alpha, beta, LDA, LDB, LDC, vM, vN, vK));
    ur = brg.bd_block * (is_amx(isa) ? brg.bd_block2 : 1);
    ur_block = brg.bd_block;
    if (is_1x1 && is_amx(isa) && M > 0 && M_tail > 0) {
        brgemm_t brg_sp_tail;
        CHECK(brgemm_desc_init(&brg_sp_tail, isa, brgemm_addr, src_dt, wei_dt,
                false, false, brgemm_row_major, alpha, beta, LDA, LDB, LDC,
                M_tail, vN, vK));
        ur_block_tail = brg_sp_tail.bd_block;
    } else {
        ur_block_tail = 0;
    }
    return status::success;
}

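// Dry-runs the full set of brgemm descriptors the convolution would create
// (all needed M/N/K/beta combinations, attributes, and post-ops), so that an
// unsupported configuration is rejected already at blocking time.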
status_t brg_blocking_t::get_brgemm_ur(
        const primitive_attr_t *attr, const memory_desc_t &dst_md) {
    // Detailed simulation of brgemm convolution init
    if (sp_block <= 0 || ic_block <= 0 || oc_block <= 0)
        return status::invalid_arguments;
    CHECK(estimate_brgemm_ur(is_os_blocking ? os_block : ow_block));

    LDD = oc_without_padding;

    const float alpha = 1.0;
    const float beta = 1.0;
    const float beta_init = 0.0;

    for (int i = 0; i < M; i++) {
        auto vM = i + 1;
        // init only the brgemm descriptors that are actually needed
        if ((utils::one_of(exec_type, exec_trans, exec_vpad) || is_1x1)
                && vM != M && vM != M_tail)
            continue;
        for (int i_init = 0; i_init < 2; i_init++) {
            for (int i_N = 0; i_N < 2; i_N++) {
                for (int i_K = 0; i_K < 2; i_K++) {
                    auto vbeta = (i_init) ? beta_init : beta;
                    auto vN = (i_N) ? N_tail : N;
                    auto vK = (i_K) ? K_tail : K;
                    if (vN == 0 || vK == 0) continue;
                    brgemm_t brg;
                    brgemm_strides_t brg_strides;
                    brg_strides.stride_a = ngroups * ic_without_padding
                            * (dilate_w + 1) * src_dsz;
                    // weights are padded by oc_block and last_ic_block
                    brg_strides.stride_b = rnd_up(ic, last_ic_block_size)
                            * rnd_up(oc, oc_block) * wei_dsz;
                    const auto strides_ptr = (brg_type == brgemm_strd)
                            ? &brg_strides
                            : nullptr;
                    CHECK(brgemm_desc_init(&brg, isa, brg_type, src_dt, wei_dt,
                            false, false, brgemm_row_major, alpha, vbeta, LDA,
                            LDB, LDC, vM, vN, vK, strides_ptr));

                    brgemm_attr_t brgattr;
                    brgattr.max_bs = max_batch;
                    const auto max_vpad = (exec_type == exec_vpad)
                            ? nstl::max(l_pad, r_pad)
                            : 0;
                    brgattr.max_top_vpad = max_vpad;
                    brgattr.max_bottom_vpad = max_vpad;
                    CHECK(brgemm_desc_set_attr(&brg, brgattr));

                    brg.with_sum = with_sum;
                    CHECK(brgemm_desc_set_postops(
                            &brg, attr, &dst_md, LDD, bia_dt));
                }
            }
        }
    }

    return status::success;
}

void brg_blocking_t::update_blocks() {
    if (sp_block <= 0
            || utils::one_of(0, od_block, oh_block, ic_block, oc_block,
                    kd_block, kh_block, kw_block, os_block, ow_block))
        return;

    nb_od = div_up(od, od_block);
    nb_oh = div_up(oh, oh_block);
    nb_ic = div_up(ic, ic_block);
    nb_oc = div_up(oc, oc_block);
    nb_kd = div_up(kd, kd_block);
    nb_kh = div_up(kh, kh_block);
    nb_kw = div_up(kw, kw_block);
    nb_ow = div_up(ow, ow_block);
    if (is_os_blocking) {
        nb_os = div_up(os, os_block);
        sp = os;
        sp_block = os_block;
        nb_sp = nb_os;
    } else {
        sp = ow;
        sp_block = ow_block;
        nb_sp = nb_ow;
        iw_block = get_inp_size(iwp, ow_block, kw, stride_w, dilate_w);
    }
}

bool brg_blocking_t::fast_check_oc_block() const {
    // This function reduces the number of blocking variants to consider.
    // TODO: eliminate the heuristic in this function
    const auto rnd_oc = rnd_up(oc, 16);
    auto res = false;
    if (oc_block == 64) {
        res = (rnd_oc % oc_block == 0 && rnd_oc * wei_dsz < 192 * 4);
    } else if (oc_block == 48) {
        const bool big_spatial
                = id * ih * iw > 81 * stride_d * stride_h * stride_w;
        res = (rnd_oc % oc_block == 0 && rnd_oc * wei_dsz <= 384 * 4
                && big_spatial);
    } else
        res = true;

    return res;
}

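// Scores a blocking candidate as the product of independent partial
// efficiencies: microkernel register balance, tail losses (ur, sp, oc),
// thread load balance, and a cache model built from the loop nest (loop[])
// via io_k(). Higher is better; used to rank blocking candidates.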
float brg_blocking_t::est_eff() {
    const auto ocblock = oc_block / 16;

    const auto brgemm_microkernel_eff
            = (static_cast<float>(ocblock) * ur) / ((ur + ocblock) * max_regs);

    const auto ur_eff = static_cast<float>(sp_block) / rnd_up(sp_block, ur);
    const auto brgemm_eff = squeeze_val(ur
                    * (2.f - nstl::min(1.9f, static_cast<float>(ur) / sp_block))
                    / 64,
            0.5f);

    const auto sp_amount = nb_od * nb_oh * nb_sp;
    const auto work_amount = mb * ngroups * nb_oc * sp_amount;
    const auto sp_eff = (static_cast<float>(sp) / rnd_up(sp, sp_block));

    const auto thr_eff = static_cast<float>(work_amount)
            / utils::rnd_up(work_amount, nthr);

    const auto oc_block_eff = static_cast<float>(oc) / rnd_up(oc, oc_block);

    const auto job = div_up(work_amount, nthr);

    auto job_eff = 1.f;
    if (job < nthr) {
        std::vector<dim_t> thr_jobs(nthr);

        for (int ithr = 0; ithr < nthr; ithr++) {
            thr_jobs[ithr] = 0;
            if (ithr >= work_amount) continue;
            dim_t thr_job = 0;
            int start {0}, end {0};
            balance211(work_amount, nthr, ithr, start, end);
            int n {0}, g {0}, ocb {0}, odp {0}, ohp {0}, spb {0};
            if (loop_order == loop_ndhwgc)
                nd_iterator_init(start, n, mb, odp, od, ohp, oh, spb, nb_sp, g,
                        ngroups, ocb, nb_oc);
            else if (loop_order == loop_ngcdhw)
                nd_iterator_init(start, n, mb, g, ngroups, ocb, nb_oc, odp, od,
                        ohp, oh, spb, nb_sp);

            for (auto work = start; work < end; work++) {
                const int ocp = ocb * oc_block;
                const auto oc_sz = nstl::min(oc - ocp, oc_block);
                int sp_sz = 0;
                const int spp = spb * sp_block;
                sp_sz = nstl::min(sp - spp, sp_block);
                thr_job += sp_sz * oc_sz;

                if (loop_order == loop_ndhwgc)
                    nd_iterator_step(n, mb, odp, od, ohp, oh, spb, nb_sp, g,
                            ngroups, ocb, nb_oc);
                else if (loop_order == loop_ngcdhw)
                    nd_iterator_step(n, mb, g, ngroups, ocb, nb_oc, odp, od,
                            ohp, oh, spb, nb_sp);
            }
            thr_jobs[ithr] = thr_job;
        }

        dim_t max_job = 0;
        dim_t sum_job = 0;
        for (int ithr = 0; ithr < nthr; ithr++) {
            if (thr_jobs[ithr] > max_job) max_job = thr_jobs[ithr];
            sum_job += thr_jobs[ithr];
        }
        job_eff = max_job == 0 ? 1
                               : static_cast<float>(sum_job) / (max_job * nthr);

    } else {
        job_eff = thr_eff;
    }

    const auto ic_blocking_size = ic_block * nb_ic_blocking;
    const auto oc_blocking_size = oc_block * ic_blocking_size;

    int l = -1;

    // -- brgemm kernel: loop by simd_w --
    l++;
    const auto inp_ur = inp_w(ur, kw_block);
    loop[l].src.set(inp_ur * simd_w, 1, bcast_simd);
    loop[l].dst.set(0, 1);
    loop[l].wei.set(oc_block, 1);

    // -- brgemm kernel: loop by kw in kw_block --
    l++;
    auto src_is = rnd_inp_simd(ur, kw_block, ic_blocking_size);
    loop[l].src.set(src_is, 1, kw_block);
    loop[l].dst.set(0, 1);
    loop[l].wei.set(oc_blocking_size, 1);

    // -- brgemm kernel: loop by batch (grouped by kw_block) in ur --
    l++;
    loop[l].src.set(src_is, 1);
    loop[l].dst.set(0, 1);
    auto wei_is = kw_block * oc_blocking_size;
    loop[l].wei.set(wei_is, 1);
    // -- brgemm kernel: loop by ur in sp_block --
    l++;
    const auto nb_ur = div_up(sp_block, ur);
    loop[l].src.set(kd_block * kh_block * src_is, 1);
    loop[l].dst.set(ur * oc_block, 1);
    wei_is = kd_block * kh_block * kw_block * oc_blocking_size;
    loop[l].wei.set(wei_is, nb_ur);

    // -- harness: loop by k_blocks in ks --
    l++;
    loop[l].src.set(kd_block * kh_block
                    * rnd_inp_simd(sp_block, kw_block, ic_blocking_size),
            1);
    loop[l].dst.set(sp_block * oc_block, nb_kd * nb_kh * nb_kw);
    loop[l].wei.set(wei_is, 1);

    // -- brgemm kernel: loop by ic_chunks --
    l++;
    const auto ic_chunks = div_up(nb_ic, nb_ic_blocking);
    loop[l].src.set(kd * kh * rnd_inp_simd(sp_block, kw, ic_blocking_size), 1);
    loop[l].dst.set(sp_block * oc_block, ic_chunks);
    wei_is = kd * kh * kw * oc_blocking_size;
    loop[l].wei.set(wei_is, 1);

    const auto dim_oc = (loop_order == loop_ndhwgc) ? 1 : sp_amount;
    const auto nb_oc_thr = nstl::min(nb_oc, div_up(job, dim_oc));
    const auto oc_thr = nstl::min(oc, nb_oc_thr * oc_block);
    const auto nsimd_oc_thr = div_up(oc_thr, simd_w);

    const auto dim_sp = (loop_order == loop_ndhwgc) ? ngroups * nb_oc : 1;
    const auto nb_sp_thr = nstl::min(nb_sp, div_up(job, dim_sp));
    const auto sp_thr = nstl::min(sp, nb_sp_thr * sp_block);

    const auto dim_oh = nb_sp * dim_sp;
    const auto nb_oh_thr = nstl::min(nb_oh, div_up(job, dim_oh));
    const auto oh_thr = nstl::min(oh, nb_oh_thr * oh_block);

    const auto dim_od = nb_oh * dim_oh;
    const auto nb_od_thr = nstl::min(nb_od, div_up(job, dim_od));
    const auto od_thr = nstl::min(od, nb_od_thr * od_block);

    src_is = kd * kh * rnd_inp_simd(sp_block, kw, ic);

    auto wei_op = kd * kh * kw * ocblock * ic;
    if (loop_order == loop_ndhwgc) {
        // -- harness: loop by oc_block --
        l++;
        loop[l].src.set(src_is, nb_oc_thr);
        loop[l].dst.set(sp_block * oc_block, 1);
        wei_is = kd * kh * kw * oc_block * ic;
        wei_op = kd * kh * kw * nsimd_oc_thr * ic;
        loop[l].wei.set(wei_is, 1);
    }

    // -- harness: loop by sp_blocks --
    l++;
    loop[l].src.set(src_is, 1);
    const auto rnd_oc_for_sp
            = simd_w * ((loop_order == loop_ndhwgc) ? nsimd_oc_thr : ocblock);
    loop[l].dst.set(sp_block * rnd_oc_for_sp, 1);
    loop[l].wei.set(wei_op * simd_w, nb_sp_thr);
    // oh_block is almost always 1. TODO: handle oh_block != 1
    // -- harness: loop by oh_blocks --
    l++;
    src_is = kd * kh * rnd_inp_simd(sp_thr, kw, ic);
    loop[l].src.set(oh_block * src_is, 1);
    loop[l].dst.set(sp_thr * rnd_oc_for_sp, 1);
    loop[l].wei.set(wei_op * simd_w, nb_oh_thr);
    // od_block is almost always 1. TODO: handle od_block != 1
    // -- harness: loop by od_blocks --
    l++;
    loop[l].src.set(od_block * oh_thr * src_is, 1);
    loop[l].dst.set(oh_thr * sp_thr * rnd_oc_for_sp, 1);
    loop[l].wei.set(wei_op * simd_w, nb_od_thr);

    if (loop_order != loop_ndhwgc) {
        // -- harness: loop by oc_block --
        l++;
        loop[l].src.set(od_thr * oh_thr * src_is, nb_oc_thr);
        loop[l].dst.set(oc_block * od_thr * oh_thr * sp_thr, 1);
        loop[l].wei.set(kd * kh * kw * oc_block * ic, 1);
    }

    // -- harness: loop by mb --
    l++;
    const auto mb_thr = nstl::min(mb, div_up(job, sp_amount * ngroups * nb_oc));
    loop[l].src.set(od_thr * oh_thr * src_is, 1);
    loop[l].dst.set(od_thr * oh_thr * sp_thr * nsimd_oc_thr * simd_w, 1);
    loop[l].wei.set(kd * kh * kw * nsimd_oc_thr * simd_w * ic, mb_thr);

    const auto src_op = static_cast<dim_t>(mb_thr) * od_thr
            * (is_os_blocking ? 1 : oh_thr) * sp_thr * kd * kh * kw * ic;
    const auto dst_op = static_cast<dim_t>(mb_thr) * od_thr
            * (is_os_blocking ? 1 : oh_thr) * sp_thr * nsimd_oc_thr;
    wei_op = kd * kh * kw * nsimd_oc_thr * ic;

    // for a "real" application set bench_iterations to 1
    const auto iterations = bench_iterations;
    l++;
    loop[l].src.set(src_op, iterations);
    loop[l].dst.set(dst_op * simd_w, iterations);
    loop[l].wei.set(wei_op * simd_w, iterations);

    auto src_mem_k = mem_k;
    auto dst_mem_k = mem_k;
    auto wei_mem_k = mem_k;
    float src_rp = 1;
    float dst_rp = 1;
    float wei_rp = 1;

    for (auto il = l; il >= 0; il--) {
        src_mem_k = io_k(loop[il], loop[il].src, src_mem_k, true,
                loop_order != loop_ndhwgc);
        dst_mem_k = io_k(loop[il], loop[il].dst, dst_mem_k, false, false);
        wei_mem_k = io_k(loop[il], loop[il].wei, wei_mem_k, false,
                loop_order == loop_ndhwgc);
        src_rp *= loop[il].src.repeatn;
        dst_rp *= loop[il].dst.repeatn;
        wei_rp *= loop[il].wei.repeatn;
    }
    const auto src_ops = (src_op * src_rp) / iterations;
    const auto dst_ops = (dst_op * dst_rp) / iterations;
    const auto wei_ops = (wei_op * wei_rp) / iterations;

    const auto src_cost = src_mem_k * src_ops;
    const auto dst_cost = dst_mem_k * dst_ops;
    const auto wei_cost = wei_mem_k * wei_ops;
    const auto call_kernel_cost
            = 1000.f * job * ic_chunks * nb_kd * nb_kh * nb_kw;

    const auto cache_eff = (static_cast<dim_t>(mb) * od * oh * sp * ic * oc)
            / (nthr * (src_cost + dst_cost + wei_cost + call_kernel_cost));
    const auto res_eff = oc_block_eff * brgemm_microkernel_eff * sp_eff
            * job_eff * ur_eff * cache_eff * brgemm_eff;
    return res_eff;
}

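// For a fixed kd/kh blocking, derives the remaining blocking parameters
// (kw blocks, od/oh blocks sized to fit L2/L1, then every candidate
// ow_block), scores each candidate with est_eff(), and keeps the best one
// in best_brgb.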
void brg_blocking_t::iterate_ker_block(brg_blocking_t &best_brgb, int kd_block_,
        int kh_block_, bool maybe_use_buffer, int max_ow_block_thr) {

    unsigned est_k_amount = ic * oc_block * wei_dsz;

    kd_block = kd_block_;
    kh_block = kh_block_;
    if (one_of(exec_type, exec_vpad, exec_trans)) {
        kw_block = kw;
        kd_block_pad = kd_block;
        kh_block_pad = kh_block;
        kw_block_pad = kw_block;
    } else {
        kw_block = (est_k_amount * kw < L2) ? kw : 1;
        kd_block_pad = kh_block >= kd ? kd : 1;
        kh_block_pad = kw_block >= kh ? kh : 1;
        kw_block_pad = kw;
    }

    if (exec_type == exec_vpad) {
        od_block = 1;
        oh_block = 1;
    } else if (exec_type == exec_trans) {
        const auto w_block_size
                = 2 * src_dsz * ic * iwp + dst_dsz * ow * oc_block;
        const auto other_size = wei_dsz * kd * kh * kw * ic * oc_block
                + acc_dsz * 2 * amx_h * oc_block;
        const auto L2_available = nstl::min(static_cast<size_t>(div_up(L2, 2)),
                other_size > L2 ? 0 : L2 - other_size);
        if (idp * ihp * w_block_size > L2_available) {
            od_block = utils::saturate(
                    1, od, int(L2_available / (ihp * w_block_size)));
            if (od_block == 1)
                oh_block = utils::saturate(
                        1, oh, int(L2_available / (w_block_size)));
            else
                oh_block = oh;
        } else {
            od_block = 1;
            oh_block = oh;
        }
        if (is_amx(isa)) {
            // try to fit into L1
            bool L1_fit_res = false;
            auto cur_od_block = od_block;
            auto cur_oh_block = oh_block;
            const auto src_w_block_size
                    = src_dsz * ic * iwp + dst_dsz * ow * oc_block;
            if (src_w_block_size < L1) {
                cur_od_block = utils::saturate(
                        1, od, int(L1 / (ihp * src_w_block_size)));
                if (cur_od_block == 1)
                    cur_oh_block = utils::saturate(
                            1, oh, int(L1 / (src_w_block_size)));
            }
            for (; cur_od_block > 1; cur_od_block--) {
                const auto sp_size = cur_od_block * cur_oh_block * iwp;
                if ((static_cast<float>(od) / rnd_up(od, cur_od_block)) > 0.9f
                        && static_cast<float>(sp_size) / rnd_up(sp, amx_h)
                                > 0.8f) {
                    L1_fit_res = true;
                    break;
                }
            }
            if (cur_od_block == 1) {
                for (; cur_oh_block > 1; cur_oh_block--) {
                    const auto sp_size = cur_oh_block * iwp;
                    if ((static_cast<float>(oh) / rnd_up(oh, cur_oh_block))
                                    > 0.9f
                            && sp_size > 128) {
                        L1_fit_res = true;
                        break;
                    }
                }
            }
            if (L1_fit_res) {
                od_block = cur_od_block;
                oh_block = cur_oh_block;
            }
        }

        // limit oh_block to have good threading
        auto thr_od_block = div_up(od, div_up(nthr, mb * div_up(oc, oc_block)));
        auto thr_oh_block = div_up(oh,
                div_up(nthr,
                        mb * div_up(oc, oc_block) * div_up(od, thr_od_block)));
        od_block = nstl::min(od_block, thr_od_block);
        oh_block = nstl::min(oh_block, thr_oh_block);
    } else {
        od_block = 1;
        oh_block = 1;
    }

    // --- Select ow_block ---
    const auto max_ow_block_L2 = ow;
    auto start_ow_block = nstl::min(max_ow_block_thr, max_ow_block_L2);

    sp = ow;
    const auto start_sp_block = is_os_blocking ? ow : start_ow_block;
    auto prev_spb = 0;
    for (auto ns = 1; ns <= sp; ns++) {
        const auto spb = div_up(sp, ns);
        if (spb == prev_spb || spb > start_sp_block) continue;
        if (is_os_blocking && spb != ow) continue;
        prev_spb = spb;
        ow_block = spb;
        sp_block = ow_block;

        select_ic_block();

        use_buffer = maybe_use_buffer
                && (ic_block * nb_ic_blocking < ic || kd_block != kd
                        || kh_block != kh || kw_block != kw
                        || kd_block_pad != kd || kh_block_pad != kh
                        || kw_block_pad != kw);
        if (exec_type == exec_base)
            use_buffer = use_buffer || (maybe_use_buffer && iwp != iw);
        if (is_amx(isa)) use_buffer = use_buffer || (use_M_mask > 0);

        const status_t st = estimate_brgemm_ur(ow_block);
        if (st != status::success) continue;
        os_block = sp_block = ow_block;
        update_blocks();

        eff = est_eff();

        if (eff > best_brgb.eff || best_brgb.eff == 0) best_brgb = *this;
    }
}

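// Top-level blocking search: tries full and unit kd/kh blocks, lets
// iterate_ker_block() explore the spatial blocking for each combination,
// and commits the best-scoring candidate back into *this.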
status_t brg_blocking_t::calc_blocks() {
    sp = ow;

    nb_ic_blocking = 1;
    // --- Select kernel blocking ---
    // If dst_dt != acc_dt (or there is a sum post-op), intermediate results
    // have to be stored, which requires the output buffer.
    const auto maybe_use_buffer = (dst_dt != acc_dt || with_sum);

    std::vector<int> kd_blocks(1), kh_blocks(1);
    kd_blocks[0] = kd;
    kh_blocks[0] = kh;
    if (kd != 1) {
        kd_blocks.resize(2);
        kd_blocks[1] = 1;
    }
    if (kh != 1) {
        kh_blocks.resize(2);
        kh_blocks[1] = 1;
    }

    const auto thr_eff_threshold = 0.9f;
    const auto max_ow_block_thr = utils::saturate(1, ow,
            static_cast<int>(div_up(
                    mb * ngroups * nb_oc * os, thr_eff_threshold * nthr)));

    ow_block = os_block = sp_block = -1;
    brg_blocking_t best_brgb = *this;
    for (const auto &kd_block : kd_blocks) {
        for (const auto &kh_block : kh_blocks) {
            iterate_ker_block(best_brgb, kd_block, kh_block, maybe_use_buffer,
                    max_ow_block_thr);
        }
    }
    *this = best_brgb;
    if (!IMPLICATION(!is_os_blocking, sp_block > 0))
        return status::unimplemented;

    if (is_os_blocking) {
        ow_block = ow;
        os_block = ow * oh_block;
        sp_block = os_block;
        ow_tail = 0;
    } else {
        ow_block = os_block = sp_block;
        ow_tail = ow % ow_block;
    }
    update_blocks();
    return status::success;
}

bool brg_blocking_t::fast_check_oc_block_1x1() const {
    // This function reduces the number of blocking variants to consider.
    // TODO: eliminate the heuristic in this function
    if (is_1x1 && is_amx(isa)) return true;
    const auto rnd_oc = rnd_up(oc, 16);
    auto res = false;
    if (oc_block == 64) {
        const auto big_spatial
                = od * oh * ow >= 64 * stride_d * stride_h * stride_w;
        res = (rnd_oc % oc_block == 0 && big_spatial);
    } else if (oc_block == 48) {
        const auto oc_block_eff = static_cast<float>(oc) / rnd_up(oc, oc_block);
        res = (oc_block_eff >= 0.95);
    } else
        res = true;

    return res;
}

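// 1x1 variant of est_eff(): the same multiplicative efficiency model, with
// an AMX-aware microkernel term (2x2 tile blocking) and os-blocking support.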
est_eff_1x1()1183 float brg_blocking_t::est_eff_1x1() {
1184     const auto ocblock = oc_block / 16;
1185 
1186     auto calc_ave_blk = [&](int dim, int block, bool use_ave) -> float {
1187         const int nb = dim / block;
1188         constexpr int max_nb = 2; // only consider 2x2 tile blocking
1189         const int block2 = nstl::min(max_nb, nb);
1190         const int nb2 = nb / block2;
1191         const int nb2_tail = nb % block2;
1192         if (!use_ave) return block2;
1193         return (float(nb2) * block2 + nb2_tail) / div_up(nb, block2);
1194     };
1195     const bool use_ocb_ave = true;
1196     const auto ocb_ave = calc_ave_blk(oc_block, 16, use_ocb_ave);
1197     const bool use_spb_ave = false;
1198     const auto spb_ave = calc_ave_blk(sp_block, ur_block, use_spb_ave);
1199     const auto M_n_sp_blks = ur_block > 0 ? nstl::max(M, M_tail) / ur_block : 0;
1200     const auto M_tail_n_sp_blks
1201             = ur_block_tail > 0 ? M_tail / ur_block_tail : 0;
1202     const auto amx_fac
1203             = div_up(M + M_tail, 16) / (M_n_sp_blks + M_tail_n_sp_blks);
1204 
1205     const auto brgemm_microkernel_eff = is_amx(isa)
1206             ? amx_fac * (static_cast<float>(ocb_ave) * spb_ave)
1207                     / (ocb_ave + spb_ave)
1208             : (static_cast<float>(ocblock) * ur) / ((ur + ocblock) * max_regs);
1209     const auto ur_eff = static_cast<float>(sp_block) / rnd_up(sp_block, ur);
1210     const auto brgemm_eff = squeeze_val(ur
1211                     * (2.f - nstl::min(1.9f, static_cast<float>(ur) / sp_block))
1212                     / 64,
1213             0.5f);
1214 
1215     const auto sp_amount = is_os_blocking ? div_up(nb_os, nb_os_blocking)
1216                                           : nb_od * nb_oh * nb_sp;
1217     const auto work_amount = mb * ngroups * nb_oc * sp_amount;
1218 
1219     const auto sp_eff = static_cast<float>(sp) / rnd_up(sp, sp_block);
1220     const auto thr_eff = static_cast<float>(work_amount)
1221             / utils::rnd_up(work_amount, nthr);
1222     const auto oc_block_eff = static_cast<float>(oc) / rnd_up(oc, oc_block);
1223 
1224     const auto job = div_up(work_amount, nthr);
1225 
1226     const auto dim_oc = (loop_order == loop_ndhwgc) ? 1 : sp_amount;
1227     const auto nb_oc_thr = nstl::min(nb_oc, div_up(job, dim_oc));
1228     const auto oc_thr = nstl::min(oc, nb_oc_thr * oc_block);
1229     const auto nsimd_oc_thr = div_up(oc_thr, simd_w);
1230 
1231     const auto dim_sp = (loop_order == loop_ndhwgc) ? ngroups * nb_oc : 1;
1232     const auto nb_sp_thr = nstl::min(nb_sp, div_up(job, dim_sp));
1233     const auto sp_thr = nstl::min(sp, nb_sp_thr * sp_block);
1234 
1235     const auto dim_oh = nb_sp * dim_sp;
1236     const auto nb_oh_thr = nstl::min(nb_oh, div_up(job, dim_oh));
1237     const auto oh_thr
1238             = is_os_blocking ? 1 : nstl::min(oh, nb_oh_thr * oh_block);
1239 
1240     const auto dim_od = nb_oh * dim_oh;
1241     const auto nb_od_thr = nstl::min(nb_od, div_up(job, dim_od));
1242     const auto od_thr
1243             = is_os_blocking ? 1 : nstl::min(od, nb_od_thr * od_block);
1244 
    auto job_eff = 1.f;
    if (job < nthr) {
        std::vector<dim_t> thr_jobs(nthr);
        for (int ithr = 0; ithr < nthr; ithr++) {
            thr_jobs[ithr] = 0;
            if (ithr >= work_amount) continue;
            dim_t thr_job = 0;
            int start {0}, end {0};
            balance211(work_amount, nthr, ithr, start, end);
            int n {0}, g {0}, ocb {0}, oss {0}, odp {0}, ohp {0}, spb {0};
            if (loop_order == loop_ndhwgc) {
                if (is_os_blocking)
                    nd_iterator_init(start, n, mb, oss, sp_amount, g, ngroups,
                            ocb, nb_oc);
                else
                    nd_iterator_init(start, n, mb, odp, od, ohp, oh, spb, nb_sp,
                            g, ngroups, ocb, nb_oc);
            } else if (loop_order == loop_ngcdhw) {
                if (is_os_blocking)
                    nd_iterator_init(start, n, mb, g, ngroups, ocb, nb_oc, oss,
                            sp_amount);
                else
                    nd_iterator_init(start, n, mb, g, ngroups, ocb, nb_oc, odp,
                            od, ohp, oh, spb, nb_sp);
            }

            for (auto work = start; work < end; work++) {
                const int ocp = ocb * oc_block;
                const auto oc_sz = nstl::min(oc - ocp, oc_block);
                int sp_sz = 0;
                if (is_os_blocking) {
                    const auto osb_start = oss * nb_os_blocking;
                    const auto osb_range
                            = nstl::min(nb_os - osb_start, nb_os_blocking);
                    for (int osb = 0; osb < osb_range; osb++) {
                        const int osp = (osb_start + osb) * sp_block;
                        sp_sz = nstl::min(os - osp, sp_block);
                    }
                } else {
                    const int spp = spb * sp_block;
                    sp_sz = nstl::min(sp - spp, sp_block);
                }
                thr_job += sp_sz * oc_sz;

                if (loop_order == loop_ndhwgc) {
                    if (is_os_blocking)
                        nd_iterator_step(
                                n, mb, oss, sp_amount, g, ngroups, ocb, nb_oc);
                    else
                        nd_iterator_step(n, mb, odp, od, ohp, oh, spb, nb_sp, g,
                                ngroups, ocb, nb_oc);
                } else if (loop_order == loop_ngcdhw) {
                    if (is_os_blocking)
                        nd_iterator_step(
                                n, mb, g, ngroups, ocb, nb_oc, oss, sp_amount);
                    else
                        nd_iterator_step(n, mb, g, ngroups, ocb, nb_oc, odp, od,
                                ohp, oh, spb, nb_sp);
                }
            }
            thr_jobs[ithr] = thr_job;
        }

        dim_t max_job = 0;
        dim_t sum_job = 0;
        for (int ithr = 0; ithr < nthr; ithr++) {
            if (thr_jobs[ithr] > max_job) max_job = thr_jobs[ithr];
            sum_job += thr_jobs[ithr];
        }

        job_eff = max_job == 0 ? 1
                               : static_cast<float>(sum_job) / (max_job * nthr);
    } else {
        job_eff = thr_eff;
    }

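    // Rough memory-cost model: loop[] describes the loop nest from the
    // innermost brgemm microkernel level outwards. Each level records the
    // src/dst/wei footprint touched per iteration and a repeat count
    // (repeatn); io_k() below converts those footprints into estimated
    // per-element access costs.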
    const auto ic_blocking_size = ic_block * nb_ic_blocking;
    const auto oc_blocking_size = oc_block * ic_blocking_size;

    int l = -1;
    // -- brgemm kernel: loop by simd_w --
    l++;
    loop[l].src.set(ur * simd_w, 1, bcast_simd);
    loop[l].dst.set(0, 1);
    loop[l].wei.set(oc_block, 1);

    // -- brgemm kernel: loop by ur in sp_block --
    l++;
    const auto nb_ur = div_up(sp_block, ur);
    const auto nb_sp_no_tail = sp / sp_block;
    const auto sp_block_tail = sp % sp_block;
    const auto nb_ur_average
            = (nb_sp_no_tail * nb_ur + div_up(sp_block_tail, ur)) / nb_sp;
    loop[l].src.set(ur * rnd_simd(ic_blocking_size), 1);
    loop[l].dst.set(ur * oc_block, 1);
    loop[l].wei.set(oc_blocking_size, is_amx(isa) ? nb_ur_average : nb_ur);
    // -- brgemm kernel: loop by ic_chunks --
    l++;
    const auto ic_chunks = div_up(nb_ic, nb_ic_blocking);
    loop[l].src.set(sp_block * ic_blocking_size, 1);
    loop[l].dst.set(sp_block * oc_block, ic_chunks);
    auto wei_is = oc_blocking_size;
    auto wei_op = ocblock * ic;
    loop[l].wei.set(wei_is, 1);

    if (loop_order == loop_ndhwgc) {
        // -- harness: loop by oc_block --
        l++;
        loop[l].src.set(sp_block * rnd_simd(ic), nb_oc_thr);
        loop[l].dst.set(sp_block * oc_block, 1);
        wei_is = oc_block * ic;
        wei_op = nsimd_oc_thr * ic;
        loop[l].wei.set(wei_is, 1);
    }

    const auto rnd_oc_for_sp
            = simd_w * ((loop_order == loop_ndhwgc) ? nsimd_oc_thr : ocblock);
    if (is_os_blocking) {
        // -- harness: loop by os_blocks --
        l++;
        loop[l].src.set(sp_block * ic_blocking_size, 1);
        loop[l].dst.set(sp_block * rnd_oc_for_sp, 1);
        loop[l].wei.set(wei_op * simd_w, nb_sp_thr);
    } else {
        // -- harness: loop by sp_blocks --
        l++;
        loop[l].src.set(sp_block * ic_blocking_size, 1);
        loop[l].dst.set(sp_block * rnd_oc_for_sp, 1);
        loop[l].wei.set(wei_op * simd_w, nb_sp_thr);
        // -- harness: loop by oh_blocks --
        l++;
        loop[l].src.set(oh_block * sp_thr * rnd_simd(ic_blocking_size), 1);
        loop[l].dst.set(oh_block * sp_thr * rnd_oc_for_sp, 1);
        loop[l].wei.set(wei_op * simd_w, nb_oh_thr);
        // -- harness: loop by od_blocks --
        l++;
        loop[l].src.set(
                od_block * oh_thr * sp_thr * rnd_simd(ic_blocking_size), 1);
        loop[l].dst.set(od_block * oh_thr * sp_thr * rnd_oc_for_sp, 1);
        loop[l].wei.set(wei_op * simd_w, nb_od_thr);
    }

    if (loop_order != loop_ndhwgc) {
        // -- harness: loop by oc_block --
        l++;
        loop[l].src.set(od_thr * oh_thr * rnd_simd(sp_thr * ic_blocking_size),
                nb_oc_thr);
        loop[l].dst.set(oc_block * od_thr * oh_thr * sp_thr, 1);
        loop[l].wei.set(oc_block * ic, 1);
    }

    // -- harness: loop by mb --
    l++;
    const auto mb_thr = nstl::min(mb, div_up(job, sp_amount * ngroups * nb_oc));
    loop[l].src.set(od_thr * oh_thr * sp_thr * rnd_simd(ic_blocking_size), 1);
    loop[l].dst.set(nsimd_oc_thr * simd_w * od_thr * oh_thr * sp_thr, 1);
    loop[l].wei.set(nsimd_oc_thr * ic * simd_w, mb_thr);

    const auto src_op = static_cast<dim_t>(mb_thr) * od_thr
            * (is_os_blocking ? 1 : oh_thr) * sp_thr * ic_blocking_size;
    const auto dst_op = static_cast<dim_t>(mb_thr) * nsimd_oc_thr * od_thr
            * (is_os_blocking ? 1 : oh_thr) * sp_thr;
    wei_op = nsimd_oc_thr * ic;

    // for a "real" application, bench_iterations is set to 1
    const auto iterations = bench_iterations;
    l++;
    loop[l].src.set(src_op, iterations);
    loop[l].dst.set(dst_op * simd_w, iterations);
    loop[l].wei.set(wei_op * simd_w, iterations);

    auto src_mem_k = mem_k;
    auto dst_mem_k = mem_k;
    auto wei_mem_k = mem_k;
    float src_rp = 1;
    float dst_rp = 1;
    float wei_rp = 1;

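    // Walk the nest from the outermost level inwards, propagating the
    // estimated per-element access cost for each tensor and accumulating the
    // total repeat factors.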
    for (auto il = l; il >= 0; il--) {
        src_mem_k = io_k(loop[il], loop[il].src, src_mem_k, true, false);
        dst_mem_k = io_k(loop[il], loop[il].dst, dst_mem_k, false, false);
        wei_mem_k = io_k(loop[il], loop[il].wei, wei_mem_k, false, true);
        src_rp *= loop[il].src.repeatn;
        dst_rp *= loop[il].dst.repeatn;
        wei_rp *= loop[il].wei.repeatn;
    }
    const auto src_ops = (src_op * src_rp) / iterations;
    const auto dst_ops = (dst_op * dst_rp) / iterations;
    const auto wei_ops = (wei_op * wei_rp) / iterations;

    const auto src_cost = src_mem_k * src_ops;
    const auto dst_cost = dst_mem_k * dst_ops;
    const auto wei_cost = wei_mem_k * wei_ops;
    const auto call_kernel_cost = 1000.f * job * ic_chunks;

    const auto up_sp_size = is_os_blocking ? 1 : od * oh;

    const auto cache_eff = (static_cast<dim_t>(mb) * up_sp_size * sp * ic * oc)
            / (nthr * (src_cost + dst_cost + wei_cost + call_kernel_cost));

    const auto res_eff = oc_block_eff * brgemm_microkernel_eff * sp_eff
            * job_eff * ur_eff * cache_eff * brgemm_eff;
    return res_eff;
}

void brg_blocking_t::calc_blocks_1x1() {
    const bool is_os_blocking_ok
            = utils::everyone_is(1, stride_d, stride_h) && iw % stride_w == 0;
    const bool is_ic_zero_padded = ic != ic_without_padding;
    is_rtus = is_ic_zero_padded || (!is_os_blocking_ok && is_amx(isa));
    if (is_os_blocking_ok || is_rtus) {
        sp = os;
        is_os_blocking = true;
    } else {
        sp = ow;
        is_os_blocking = false;
    }

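    // os-blocking flattens od * oh * ow into one spatial dimension, allowing
    // larger brgemm M blocks. When flattening is not naturally possible
    // (non-unit strides) or ic is zero-padded, AMX falls back to is_rtus,
    // where the input is first copied into a dense buffer (see exec_trans in
    // init_1x1_conf).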
    od_block = 1;
    oh_block = 1;
    kd_block = kh_block = kw_block = 1;
    kd_block_pad = kh_block_pad = kw_block_pad = 1;
    nb_ic_blocking = 1;

    const auto thr_eff_threshold = 0.9f;

    const auto max_sp_block_L2 = os;
    // TODO: nb_os_blocking is always 1 for now. Update this code
    nb_os_blocking = 1;
    int start_sp_block = 0;

    if (is_os_blocking) {
        ow_block = 0;

        const auto max_os_block_thr = nstl::max(div_up(2048, oc_block),
                static_cast<int>(div_up(mb * ngroups * os, nthr)));
        const auto max_os_block_L2 = max_sp_block_L2;

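        // 4K-aliasing heuristic: when a dst row is a multiple of the
        // 4096-byte page, consecutive os rows map to the same cache sets, so
        // the os block is shrunk (halved for each factor of two in oc) and
        // finally made odd to break the power-of-two stride.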
        auto max_os_block_aliasing = 1000000 / nthr;
        if ((oc_without_padding * os * dst_dsz) % 4096 == 0) {
            for (auto cur_oc = oc_without_padding;
                    max_os_block_aliasing * dst_dsz > 400 && cur_oc % 2 == 0
                    && cur_oc * os * dst_dsz >= 4096;
                    cur_oc /= 2) {
                max_os_block_aliasing /= 2;
            }
            max_os_block_aliasing += max_os_block_aliasing % 2 ? 0 : 1;
        }
        max_os_block_aliasing
                = nstl::min(div_up(1001, dst_dsz), max_os_block_aliasing);

        start_sp_block = utils::saturate(1, os,
                nstl::min(nstl::min(max_os_block_thr, max_os_block_L2),
                        max_os_block_aliasing));

    } else {
        os_block = 0;

        const auto max_ow_block_thr = utils::saturate(1, ow,
                static_cast<int>(div_up(
                        mb * ngroups * nb_oc * os, thr_eff_threshold * nthr)));
        const auto max_ow_block_L2 = max_sp_block_L2;

        start_sp_block = utils::saturate(
                1, ow, nstl::min(max_ow_block_thr, max_ow_block_L2));
    }
    os_block = ow_block = sp_block = -1;
    brg_blocking_t best_brgb = *this;

    auto prev_spb = 0;
    for (auto ns = 1; ns <= sp; ns++) {
        auto spb = div_up(sp, ns);
        if (is_amx(isa)) {
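            // snap the spatial block to a multiple of a tile width chosen
            // from [min_tile_width, max_tile_width]:
            // additive_inverse_modulo(spb, w) measures how far spb is from
            // the next multiple of w, and the closest fit wins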
            auto min_dis = 16;
            auto best_w = 16;
            const auto max_tile_width = nstl::min(16, sp);
            const auto min_tile_width = utils::saturate(1, 11, sp / 2);
            if (spb < min_tile_width) break;
            for (auto w = max_tile_width; w >= min_tile_width; w--) {
                const auto dis = nstl::additive_inverse_modulo(spb, w);
                if (dis < min_dis) {
                    min_dis = dis;
                    best_w = w;
                }
            }
            spb = nstl::min(sp, rnd_dn(spb, best_w));
            if (spb == prev_spb) continue;
        }
        if (spb == prev_spb || spb > start_sp_block) continue;
        prev_spb = spb;
        os_block = ow_block = sp_block = spb;
        select_ic_block();
        const status_t st = estimate_brgemm_ur(spb);
        if (st != status::success) continue;
        update_blocks();

        use_buffer = (dst_dt != acc_dt || with_sum)
                && (ic_block * nb_ic_blocking < ic);

        eff = est_eff_1x1();
        if (eff > best_brgb.eff || best_brgb.eff == 0) best_brgb = *this;
    }
    *this = best_brgb;
    os_block = ow_block = sp_block;
    update_blocks();
}

status_t init_jcp(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
        const convolution_desc_t &cd, memory_desc_t &src_md,
        memory_desc_t &weights_md, memory_desc_t &dst_md,
        memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads) {
    using namespace prop_kind;

    brg_blocking_t::L1 = platform::get_per_core_cache_size(1);
    brg_blocking_t::L2 = platform::get_per_core_cache_size(2);
    brg_blocking_t::L3 = platform::get_per_core_cache_size(3);

    if (!mayiuse(avx512_core)) return status::unimplemented;

    const memory_desc_wrapper src_d(&src_md);
    const memory_desc_wrapper weights_d(&weights_md);
    const memory_desc_wrapper dst_d(&dst_md);
    const memory_desc_wrapper bias_d(&bias_md);

    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    int ndims = src_d.ndims();

    jcp = zero<decltype(jcp)>();
    jcp.isa = isa;
    jcp.ndims = ndims;
    jcp.prop_kind = cd.prop_kind;
    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];
    jcp.oc_without_padding = dst_d.dims()[1];
    jcp.oc = jcp.oc_without_padding / jcp.ngroups;
    jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups;
    jcp.ic = jcp.ic_without_padding;
    jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2];
    jcp.ow = dst_d.dims()[ndims - 1];
    jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];
    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
    jcp.l_pad = cd.padding[0][ndims - 3];
    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
    jcp.stride_w = cd.strides[ndims - 3];

    jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
    jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
    jcp.dilate_w = cd.dilates[ndims - 3];

    jcp.os = jcp.od * jcp.oh * jcp.ow;

    jcp.ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
    jcp.ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
    jcp.ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);

    jcp.back_pad = calculate_end_padding(
            jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, jcp.ext_kd);
    jcp.b_pad = calculate_end_padding(
            jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, jcp.ext_kh);
    jcp.r_pad = calculate_end_padding(
            jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, jcp.ext_kw);
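    // With dilation d the effective filter size is ext_k = (k - 1) * (d + 1)
    // + 1; the end-side padding then follows from the convolution geometry:
    // end_pad = (o - 1) * stride + ext_k - i - start_pad.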

    jcp.is_1x1 = jcp.f_pad <= 0 && jcp.back_pad <= 0 && jcp.t_pad <= 0
            && jcp.b_pad <= 0 && jcp.l_pad <= 0 && jcp.r_pad <= 0
            && utils::everyone_is(1, jcp.kd, jcp.kh, jcp.kw);

    jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;

    jcp.src_dt = cd.src_desc.data_type;
    jcp.dst_dt = cd.dst_desc.data_type;
    jcp.wei_dt = cd.weights_desc.data_type;
    jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;

    brg_blocking_t::last_ic_block_size
            = (jcp.wei_dt == f32) ? 1 : ((jcp.wei_dt == bf16) ? 2 : 4);
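    // last_ic_block_size is the vnni packing granularity along ic in the
    // weights layout: 1 element for f32, 2 for bf16, 4 for s8/u8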

    // TODO: optimize depthwise convolutions (for now the direct approach is
    // faster)
    const bool is_depthwise
            = with_groups && jcp.ngroups > 1 && everyone_is(1, jcp.ic, jcp.oc);
    if (is_depthwise) return status::unimplemented;

    // TODO: optimize grouped convolutions with small ic
    const bool is_grouped_small_ic = with_groups && jcp.ngroups > 1
            && jcp.ic <= 16
            // already optimized for amx 1x1 convs
            && IMPLICATION(is_amx(jcp.isa), !jcp.is_1x1);
    if (is_grouped_small_ic) return status::unimplemented;

    // TODO: support s8 in non-amx brgemm convolutions
    if (!IMPLICATION(jcp.src_dt == s8, is_amx(jcp.isa)))
        return status::unimplemented;

    if (!IMPLICATION(jcp.wei_dt == s8, mayiuse(avx512_core_vnni)))
        return status::unimplemented;
    if (!IMPLICATION(jcp.wei_dt == bf16, mayiuse(avx512_core_bf16)))
        return status::unimplemented;

    if (one_of(jcp.src_dt, u8, s8)) {
        jcp.acc_dt = s32;
    } else if (one_of(jcp.src_dt, f32, bf16)) {
        jcp.acc_dt = f32;
    } else
        return status::unimplemented;

    jcp.src_dsz = types::data_type_size(jcp.src_dt);
    jcp.wei_dsz = types::data_type_size(jcp.wei_dt);
    jcp.dst_dsz = types::data_type_size(jcp.dst_dt);
    jcp.acc_dsz = types::data_type_size(jcp.acc_dt);
    jcp.bia_dsz = jcp.with_bias ? types::data_type_size(jcp.bia_dt) : 0;

    if (!post_ops_ok(jcp, attr, dst_d)) return status::unimplemented;

    jcp.simd_w = cpu_isa_traits<avx512_common>::vlen / jcp.src_dsz;
    jcp.amx_h = 16;
    jcp.amx_w = 64 / jcp.src_dsz;
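    // a 512-bit zmm register holds 64 / src_dsz elements; an AMX tile has
    // 16 rows of 64 bytes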

    const auto &p = attr.post_ops_;
    jcp.with_sum = p.find(primitive_kind::sum) != -1;
    const int eltwise_ind = p.find(primitive_kind::eltwise);
    jcp.with_eltwise = eltwise_ind != -1;

    const int binary_ind = p.find(primitive_kind::binary);
    jcp.with_binary = binary_ind != -1;

    if (jcp.with_bias) {
        if (bias_d.format_kind() == format_kind::any)
            CHECK(memory_desc_init_by_tag(bias_md, x));
    }

    jcp.nthr = nthreads;
    jcp.kh_sets = 1;
    jcp.kw_sets = 1;
    jcp.copy_block_only = false;
    jcp.amx_tile_load_xx = false;
    jcp.use_M_mask = 0;
    jcp.is_os_blocking = false;
    jcp.oskip = 0;
    jcp.use_uker = false;
    jcp.use_interleave_stores = false;
    jcp.brgemm_bd_loop_innermost = false;

    // fast check of the data layout before spending time on blocking
    // selection
    format_tag_t src_tag = pick(jcp.ndims - 3, nwc, nhwc, ndhwc);
    const bool any_eligible = (jcp.prop_kind == prop_kind::forward_inference
            || jcp.wei_dt == data_type::s8 || is_amx(jcp.isa));
    CHECK(init_tag(jcp.src_tag, src_md, src_d, src_tag, any_eligible));

    const auto ic_padded_block = 16 * brg_blocking_t::last_ic_block_size;
    jcp.is_ic_padded = !jcp.is_1x1 && one_of(jcp.wei_dt, bf16, s8)
            && jcp.ic * jcp.kw_sets > ic_padded_block && is_amx(isa);

    return status::success;
}

status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
        const convolution_desc_t &cd, memory_desc_t &src_md,
        memory_desc_t &weights_md, memory_desc_t &dst_md,
        memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads) {

    using namespace prop_kind;
    if (!mayiuse(isa)) return status::unimplemented;

    CHECK(init_jcp(
            jcp, isa, cd, src_md, weights_md, dst_md, bias_md, attr, nthreads));

    if (jcp.is_1x1) return status::unimplemented;
    // TODO: check these restrictions
    if (is_amx(isa)) {
        // disabled for first convolutions, except for 3D ones
        const bool is_3d = jcp.ndims == 5;
        if (jcp.ic <= 4 && !is_3d) return status::unimplemented;

        if (jcp.f_pad >= jcp.kd || jcp.t_pad >= jcp.kh || jcp.r_pad >= jcp.kw)
            return status::unimplemented;
        if (jcp.dilate_d > 0 || jcp.dilate_h > 0 || jcp.dilate_w > 0)
            return status::unimplemented;
    }

    const memory_desc_wrapper src_d(&src_md);
    const memory_desc_wrapper weights_d(&weights_md);
    const memory_desc_wrapper dst_d(&dst_md);
    const memory_desc_wrapper bias_d(&bias_md);

    jcp.idp = jcp.id + jcp.f_pad + jcp.back_pad;
    jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
    jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;

    using namespace data_type;
    // ======================= blocking =================================

    auto bcast_amount
            = static_cast<size_t>(jcp.id) * jcp.ih * jcp.iw * jcp.src_dsz;
    auto wei_amount = static_cast<size_t>(jcp.oc) * jcp.kd * jcp.kh * jcp.kw
            * jcp.wei_dsz;

    jcp.loop_order = (bcast_amount < wei_amount) ? loop_ngcdhw : loop_ndhwgc;

    const int min_oc_block = 16;

    int selected_ur = 0;
    MAYBE_UNUSED(selected_ur);

    auto try_exec_type = [&]() {
        brg_blocking_t best_brgb = zero<decltype(best_brgb)>();
        best_brgb.oc_block = min_oc_block;
        brg_blocking_t cur_brgb = zero<decltype(best_brgb)>();
        cur_brgb.get_from_jcp(jcp);
        auto start_ocb = (is_amx(isa) && jcp.is_os_blocking) ? 2 : 4;
        if (jcp.wei_plain)
            start_ocb = nstl::min(jcp.ic > 128 ? (jcp.ic > 256 ? 8 : 16) : 32,
                    div_up(jcp.oc, 16));
        start_ocb = nstl::min(div_up(jcp.oc, 16), start_ocb);

        auto finish_ocb = 1;
        for (auto ocb = start_ocb; ocb >= finish_ocb; ocb--) {
            cur_brgb.oc_block = ocb * 16;
            cur_brgb.nb_oc = utils::div_up(jcp.oc, cur_brgb.oc_block);
            if (!cur_brgb.fast_check_oc_block()) continue;

            const status_t blocking_ok = cur_brgb.calc_blocks();
            if (blocking_ok != status::success) continue;

            const status_t st = cur_brgb.get_brgemm_ur(&attr, dst_md);
            if (st != status::success) continue;
            cur_brgb.eff = cur_brgb.est_eff();
            if (cur_brgb.eff > best_brgb.eff) best_brgb = cur_brgb;
        }
        if (best_brgb.oc_block == 0 || best_brgb.ic_block == 0
                || best_brgb.ow_block == 0)
            return false;
        best_brgb.save_to_jcp(jcp);
        selected_ur = best_brgb.ur;
        return true;
    };

    //-----------------------------------------------------------------------

    jcp.exec_type = exec_base;
    jcp.brg_type = brgemm_addr; // TODO: Choose right type of BRGEMM

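    // Three execution strategies are considered below: exec_base (the source
    // is read in place), exec_vpad (left/right padding is passed to brgemm as
    // virtual padding, see jcp.max_vpad) and exec_trans (the input block is
    // first copied into a padded buffer, see jcp.inp_buffer_size).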
    bool try_exec_vpad = false;
    bool try_exec_trans = false;
    bool try_exec_base = true;

    if (!is_amx(isa) && div_up(jcp.l_pad, jcp.stride_w) < jcp.kw
            && div_up(jcp.r_pad, jcp.stride_w) < jcp.kw) {
        try_exec_vpad = true;
    }

    const auto ic_padded_block = 16 * brg_blocking_t::last_ic_block_size;
    // TODO: remove this restriction
    const auto w_padding = jcp.l_pad > 0 || jcp.r_pad > 0;
    if (is_amx(isa)) {
        try_exec_base = !w_padding
                && IMPLICATION(jcp.ic <= ic_padded_block,
                        jcp.ic % brg_blocking_t::last_ic_block_size == 0)
                && IMPLICATION(
                        jcp.ic > ic_padded_block, jcp.ic % ic_padded_block == 0)
                && jcp.ow > 50 /*TODO: reinvestigate this heuristic */;
        try_exec_trans = !try_exec_base;
    }

    bool must_exec_vpad = false;

    // TODO: in future use (kd/kh/kw) and (kd/kh/kw)_pad blocks for more
    // precise calculation of jcp.max_batch
    jcp.max_batch = jcp.kd * jcp.kh * jcp.kw;

    // TODO: check wei plain
    jcp.wei_plain = false;
    jcp.wei_plain = jcp.exec_type == exec_vpad ? jcp.wei_plain : false;

    bool try_exec_type_res = false;

    if (try_exec_vpad) {
        jcp.exec_type = exec_vpad;
        try_exec_type_res = try_exec_type();
        // to avoid the case when both top and bottom virtual padding are
        // non-zero
        // TODO: remove this restriction
        const auto iw_block = (jcp.ow_block - 1) * jcp.stride_w + 1;
        if (!must_exec_vpad && (iw_block > jcp.iw)) try_exec_type_res = false;
    }
    if (try_exec_type_res == false && try_exec_trans) {
        jcp.exec_type = exec_trans;

        // always try loop_ndhwgc for exec_trans
        jcp.loop_order = loop_ndhwgc;

        // for loop_ndhwgc the input block is read only once, so there is no
        // need to keep it in memory
        if (jcp.loop_order == loop_ndhwgc) { jcp.copy_block_only = true; }

        jcp.is_ic_padded = one_of(jcp.wei_dt, bf16, s8)
                && jcp.ic * jcp.kw_sets > ic_padded_block;

        if (is_amx(isa) && (/* heuristic */ jcp.kw_sets == 1 && jcp.ow < 256)) {
            jcp.is_os_blocking = jcp.f_pad < jcp.kd && jcp.back_pad < jcp.kd
                    && jcp.t_pad < jcp.kh && jcp.b_pad < jcp.kh
                    && jcp.r_pad < jcp.kw && jcp.l_pad < jcp.kw;
            jcp.use_M_mask = jcp.is_os_blocking ? 2 : 0;
            jcp.use_uker = true;
            jcp.use_interleave_stores = true;
            // assuming 2x2 decomposition in amx brgemm kernel
            // and overlap of input by kw
            const auto bd_blocking = 2 * jcp.amx_h;
            const auto ld_blocking = 2 * 16;
            const auto A_ds
                    = jcp.src_dsz * bd_blocking * jcp.ic * jcp.kd * jcp.kh;
            const auto B_ds = jcp.wei_dsz * ld_blocking * jcp.ic * jcp.kd
                    * jcp.kh * jcp.kw;
            const auto C_ds = jcp.acc_dsz * bd_blocking * ld_blocking;
            if (A_ds + B_ds + C_ds > brg_blocking_t::L1)
                jcp.amx_tile_load_xx = true;
        }

        try_exec_type_res = try_exec_type();
    }
    if (try_exec_base && try_exec_type_res == false) {
        jcp.exec_type = exec_base;
        try_exec_type_res = try_exec_type();
    }

    if (try_exec_type_res == false) return status::unimplemented;

    // ============ end blocking ===========================================
    if (jcp.exec_type == exec_vpad)
        jcp.max_vpad = nstl::max(jcp.l_pad, jcp.r_pad);
    else
        jcp.max_vpad = 0;

    if (jcp.ow_block == 0 || jcp.ic_block == 0 || jcp.oc_block == 0)
        return status::unimplemented;

    jcp.gemm_batch_size = jcp.nb_ic_blocking
            * nstl::max(jcp.kd_block * jcp.kh_block * jcp.kw_block,
                    jcp.kd_block_pad * jcp.kh_block_pad * jcp.kw_block_pad);
    // to avoid concurrent cache write access from different threads
    size_t sc_size = sizeof(brgemm_batch_element_t);
    jcp.adjusted_batch_size
            = div_up(rnd_up(jcp.gemm_batch_size * sc_size, 4096), sc_size);
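    // Example: gemm_batch_size = 27 with sc_size = 32 bytes occupies 864
    // bytes, which rounds up to 4096 bytes, i.e. 128 elements, so per-thread
    // batch arrays are padded out to whole 4KB pages.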

    CHECK(pick_tags(jcp, src_md, weights_md, dst_md, bias_md));
    CHECK(attr.set_default_formats(&dst_md));

    const auto &oscales = attr.output_scales_;
    jcp.is_oc_scale = oscales.mask_ == 1 << 1;

    // only common and per-oc-channel scales are supported
    const bool oscales_ok = one_of(oscales.mask_, 0, 1 << 1);
    if (!oscales_ok) return status::unimplemented;

    jcp.buffer_size = jcp.LDC * jcp.M;

    jcp.nb_od = div_up(jcp.od, jcp.od_block);
    jcp.nb_oh = div_up(jcp.oh, jcp.oh_block);

    if (jcp.exec_type == exec_trans) {
        // TODO: this is a rough estimation of the buffer size for the
        // transposed input
        dim_t ds = jcp.copy_block_only
                ? (brg_blocking_t::get_inp_size(jcp.idp, jcp.od_block, jcp.kd,
                           jcp.stride_d, jcp.dilate_d)
                        + nstl::max(0, jcp.f_pad) + nstl::max(0, jcp.back_pad))
                : jcp.idp;
        dim_t hs = jcp.copy_block_only
                ? (brg_blocking_t::get_inp_size(jcp.ihp, jcp.oh_block, jcp.kh,
                           jcp.stride_h, jcp.dilate_h)
                        + nstl::max(0, jcp.t_pad) + nstl::max(0, jcp.b_pad))
                : jcp.ihp;
        if (jcp.is_os_blocking)
            hs = div_up(rnd_up(hs * jcp.iwp, jcp.brgM), jcp.iwp);

        jcp.inp_buffer_size = rnd_up(ds * hs * jcp.iwp * jcp.ngroups * jcp.nb_ic
                        * jcp.ic_block * jcp.kh_sets * jcp.kw_sets,
                4096);
        jcp.inp_buffer_mask_size = rnd_up(static_cast<dim_t>(jcp.nb_od)
                        * jcp.nb_oh * jcp.nb_ow * jcp.ngroups * jcp.nb_ic,
                4096);
    }

    return status::success;
}

status_t init_1x1_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
        const convolution_desc_t &cd, memory_desc_t &src_md,
        memory_desc_t &weights_md, memory_desc_t &dst_md,
        memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads) {

    using namespace prop_kind;
    if (!mayiuse(isa)) return status::unimplemented;

    CHECK(init_jcp(
            jcp, isa, cd, src_md, weights_md, dst_md, bias_md, attr, nthreads));

    // Maybe fall back to direct jit impl for small batch sizes
    // TODO: eliminate performance degradation and remove this constraint
    if (is_amx(isa) && jcp.mb == 1) return status::unimplemented;

    const memory_desc_wrapper src_d(&src_md);
    const memory_desc_wrapper weights_d(&weights_md);
    const memory_desc_wrapper dst_d(&dst_md);
    const memory_desc_wrapper bias_d(&bias_md);

    if (!jcp.is_1x1) return status::unimplemented;

    using namespace data_type;
    // ===================== blocking =================================

    auto bcast_amount
            = static_cast<size_t>(jcp.id) * jcp.ih * jcp.iw * jcp.src_dsz;
    auto wei_amount = static_cast<size_t>(jcp.oc) * jcp.wei_dsz;

    jcp.loop_order = (bcast_amount < wei_amount) ? loop_ngcdhw : loop_ndhwgc;

    if (is_amx(isa)) {
        // round up ic if needed
        const int vnni_width = brg_blocking_t::last_ic_block_size;
        const int n_vnni_blocks = utils::div_up(jcp.ic, vnni_width);
        const int ic_block = nstl::min(16, n_vnni_blocks) * vnni_width;
        const bool do_zeropad = jcp.ic % vnni_width != 0 || jcp.ic > ic_block;
        if (do_zeropad) jcp.ic = utils::rnd_up(jcp.ic, ic_block);
        const auto ic_padded_block = 16 * vnni_width;
        jcp.is_ic_padded = jcp.ic > ic_padded_block;

        // try to choose optimal loop order
        auto wei_size = (size_t)jcp.oc * jcp.ic * jcp.wei_dsz;
        auto max_size = 0.75f * brg_blocking_t::L2;
        jcp.loop_order = max_size < wei_size ? loop_ngcdhw : loop_ndhwgc;
    }

    const auto min_oc_block = 16;

    jcp.brg_type = brgemm_addr; // TODO: Choose right type of BRGEMM

    // max_batch is 1 and max_vpad is 0 for 1x1 convolutions
    jcp.max_batch = 1;
    jcp.max_vpad = 0;

    jcp.wei_plain = false;

    brg_blocking_t best_brgb = zero<decltype(best_brgb)>();
    best_brgb.oc_block = min_oc_block;
    brg_blocking_t cur_brgb = zero<decltype(cur_brgb)>();
    cur_brgb.get_from_jcp(jcp);
    auto start_ocb = 4;
    if (jcp.wei_plain)
        start_ocb = nstl::min(jcp.ic > 128 ? (jcp.ic > 256 ? 8 : 16) : 32,
                div_up(jcp.oc, 16));
    start_ocb = nstl::min(div_up(jcp.oc, 16), start_ocb);

    auto finish_ocb = 1;
    for (auto ocb = start_ocb; ocb >= finish_ocb; ocb--) {
        cur_brgb.oc_block = ocb * min_oc_block;
        cur_brgb.nb_oc = utils::div_up(jcp.oc, cur_brgb.oc_block);

        if (!cur_brgb.fast_check_oc_block_1x1()) continue;

        cur_brgb.calc_blocks_1x1();
        const status_t st = cur_brgb.get_brgemm_ur(&attr, dst_md);
        if (st != status::success) continue;
        cur_brgb.eff = cur_brgb.est_eff_1x1();
        if (cur_brgb.eff > best_brgb.eff) best_brgb = cur_brgb;
    }
    best_brgb.save_to_jcp(jcp);

    // =============== end blocking =================================
    jcp.brg_stride_a = jcp.ic_block * jcp.src_dsz;
    jcp.brg_stride_b = jcp.ic_block * jcp.oc * jcp.wei_dsz;

    if (jcp.ic_block == 0 || jcp.oc_block == 0) return status::unimplemented;

    // Configure matrix sizes

    if (best_brgb.is_os_blocking) {
        if (jcp.os_block == 0) return status::unimplemented;
        jcp.M = jcp.brgM = jcp.os_block;
        jcp.M_tail = jcp.brgM_tail = jcp.os % jcp.os_block;
    } else {
        if (jcp.ow_block == 0) return status::unimplemented;
        jcp.M = jcp.brgM = jcp.ow_block;
        jcp.M_tail = jcp.brgM_tail = jcp.ow % jcp.ow_block;
    }

    jcp.K = jcp.ic >= jcp.ic_block ? jcp.ic_block : 0;
    jcp.N = jcp.oc >= jcp.oc_block ? jcp.oc_block : 0;
    jcp.N_tail = jcp.oc % jcp.oc_block;
    jcp.K_tail = jcp.ic % jcp.ic_block;
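    // brgemm computes C[M x N] += A[M x K] * B[K x N]: M is the spatial
    // block (os_block or ow_block), N the oc block and K the ic block;
    // the *_tail values cover the remainders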

    jcp.gemm_batch_size = jcp.nb_ic_blocking;
    // to avoid concurrent cache access from different threads
    size_t sc_size = sizeof(brgemm_batch_element_t);
    jcp.adjusted_batch_size
            = div_up(rnd_up(jcp.gemm_batch_size * sc_size, 4096), sc_size);

    CHECK(pick_tags(jcp, src_md, weights_md, dst_md, bias_md));
    CHECK(attr.set_default_formats(&dst_md));

    const auto &oscales = attr.output_scales_;
    jcp.is_oc_scale = oscales.mask_ == 1 << 1;

    // only common and per-oc-channel scales are supported
    const bool oscales_ok = one_of(oscales.mask_, 0, 1 << 1);
    if (!oscales_ok) return status::unimplemented;

    // no inp buffer or brgemm_vpad for 1x1
    constexpr int align_size = platform::get_cache_line_size();
    jcp.exec_type = jcp.is_rtus ? exec_trans : exec_base;
    jcp.inp_buffer_size
            = jcp.is_rtus ? rnd_up(jcp.LDA * jcp.os, align_size) : 0;
    jcp.inp_buffer_mask_size = jcp.is_rtus
            ? rnd_up(div_up(jcp.nb_ic, jcp.nb_ic_blocking) * jcp.nb_os,
                    align_size)
            : 0;
    jcp.buffer_size = jcp.LDC * jcp.M;

#if 0
    printf("@@@ debug: nthreads = %d, IC = %d, OC = %d, ID = %d, IH = %d, IW = "
           "%d, OD = %d, OH = %d, OW = %d, KD = %d, "
           "KH = %d, KW = %d\n",
            nthreads, jcp.ic, jcp.oc, jcp.id, jcp.ih, jcp.iw, jcp.od, jcp.oh,
            jcp.ow, jcp.kd, jcp.kh, jcp.kw);

    printf("@@@ debug: blocking: ic_block = %d, nb_ic_blocking = %d, oc_block "
           "= %d, os_block = %d, ow_block = %d, nb_os_blocking = %d, "
           "loop_order = %d, "
           "wei_plain = %d, wei_tag = %d \n",
            jcp.ic_block, jcp.nb_ic_blocking, jcp.oc_block, jcp.os_block,
            jcp.ow_block, jcp.nb_os_blocking, jcp.loop_order, jcp.wei_plain,
            jcp.wei_tag);

    printf("@@@ debug: Matrix configuration: M = %d, N = %d, K = "
           "%d, M_tail = %d, N_tail = %d, K_tail = %d, LDA = %d, LDB = %d, LDC "
           "= %d ur = %d\n",
            jcp.M, jcp.N, jcp.K, jcp.M_tail, jcp.N_tail, jcp.K_tail, jcp.LDA,
            jcp.LDB, jcp.LDC, best_brgb.ur);
    printf("@@@ debug: brg_type = %d use_buffer = %d \n", jcp.brg_type,
            jcp.use_buffer);
    fflush(nullptr);
#endif
    return status::success;
}

void init_scratchpad(memory_tracking::registrar_t &scratchpad,
        const jit_brgemm_conv_conf_t &jcp) {
    if (jcp.brg_type == brgemm_addr || jcp.brg_type == brgemm_offs
            || (jcp.brg_type == brgemm_strd && jcp.exec_type == exec_vpad))
        scratchpad.book(key_brgemm_primitive_batch,
                static_cast<size_t>(jcp.nthr) * jcp.adjusted_batch_size,
                sizeof(brgemm_batch_element_t), 64);
    if (jcp.exec_type == exec_trans) {
        size_t inp_buffer_size
                = static_cast<size_t>(jcp.nthr) * jcp.inp_buffer_size;
        scratchpad.book(
                key_conv_brgemm_inp_buffer, inp_buffer_size, jcp.src_dsz);
        size_t inp_buffer_mask_size
                = static_cast<size_t>(jcp.nthr) * jcp.inp_buffer_mask_size;
        scratchpad.book(key_conv_brgemm_inp_buffer_mask, inp_buffer_mask_size,
                sizeof(uint8_t));
    }
    if (jcp.use_buffer) {
        scratchpad.book(key_brgemm_primitive_buffer, jcp.nthr * jcp.buffer_size,
                jcp.acc_dsz);
    }
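    // 4KB of scratch per thread for AMX tiles; a tile register is 1KB
    // (16 rows x 64 bytes), so this presumably covers up to four spilled
    // tiles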
    if (is_amx(jcp.isa)) {
        scratchpad.book(
                key_conv_amx_tile_buffer, jcp.nthr * 4 * 1024, sizeof(char));
    }
}

} // namespace brgemm_convolution_utils

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl