1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "dnnl_types.h"
18
19 #include "common/bfloat16.hpp"
20 #include "common/c_types_map.hpp"
21 #include "common/dnnl_thread.hpp"
22 #include "common/memory_tracking.hpp"
23 #include "common/type_helpers.hpp"
24 #include "common/utils.hpp"
25
26 #include "cpu/platform.hpp"
27 #include "cpu/x64/cpu_isa_traits.hpp"
28 #include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
29 #include "cpu/x64/jit_brgemm_conv_utils.hpp"
30 #include "cpu/x64/jit_generator.hpp"
31
32 namespace dnnl {
33 namespace impl {
34 namespace cpu {
35 namespace x64 {
36
37 using namespace dnnl::impl::status;
38 using namespace dnnl::impl::format_tag;
39 using namespace dnnl::impl::memory_tracking::names;
40 using namespace dnnl::impl::utils;
41
42 using namespace prop_kind;
43 using namespace data_type;
44
45 namespace brgemm_convolution_utils {
46
init_tag(format_tag_t & tag,memory_desc_t & md,const memory_desc_wrapper & mdw,const format_tag_t tag_value,bool any_eligible)47 inline status_t init_tag(format_tag_t &tag, memory_desc_t &md,
48 const memory_desc_wrapper &mdw, const format_tag_t tag_value,
49 bool any_eligible) {
50
51 if (mdw.format_kind() == format_kind::any) {
52 if (any_eligible) {
53 CHECK(memory_desc_init_by_tag(md, tag_value));
54 tag = tag_value;
55 } else {
56 tag = format_tag::undef;
57 }
58 } else {
59 tag = mdw.matches_one_of_tag(tag_value);
60 }
61
62 if (tag != tag_value) return status::unimplemented;
63
64 return status::success;
65 }
66
is_amx(cpu_isa_t isa)67 bool is_amx(cpu_isa_t isa) {
68 return one_of(isa, avx512_core_bf16_amx_int8, avx512_core_bf16_amx_bf16);
69 }
70
post_ops_ok(jit_brgemm_conv_conf_t & jcp,primitive_attr_t & attr,const memory_desc_wrapper & dst_d)71 bool post_ops_ok(jit_brgemm_conv_conf_t &jcp, primitive_attr_t &attr,
72 const memory_desc_wrapper &dst_d) {
73 using namespace injector;
74
75 const auto &post_ops = attr.post_ops_;
76
77 return injector::post_ops_ok(post_ops_ok_args_t(get_max_cpu_isa(),
78 {sum, eltwise, binary}, post_ops, &dst_d,
79 false /*sum_at_pos_0_only*/, false /*sum_requires_scale_one*/,
80 false /*sum_requires_zp_zero*/,
81 {broadcasting_strategy_t::per_oc,
82 broadcasting_strategy_t::scalar}));
83 }
84
// Select memory format tags for src/dst/weights/bias based on jcp:
// - src and dst always use the channels-last layout (nwc/nhwc/ndhwc by
//   spatial rank).
// - weights use either a plain layout (jcp.wei_plain, LDB = oc) or a
//   blocked layout keyed on oc_block (64/48/32/16), weights data type
//   (f32/s8/bf16), spatial rank (1d/2d/3d), grouping, and ic padding.
// Returns unimplemented for an unsupported weights data type or when a
// user-fixed layout does not match the expected tag.
status_t pick_tags(jit_brgemm_conv_conf_t &jcp, memory_desc_t &src_md,
        memory_desc_t &weights_md, memory_desc_t &dst_md,
        memory_desc_t &bias_md) {
    format_tag_t src_tag, dst_tag, wei_tag;
    dst_tag = pick(jcp.ndims - 3, nwc, nhwc, ndhwc);

    const memory_desc_wrapper src_d(&src_md);
    const memory_desc_wrapper weights_d(&weights_md);
    const memory_desc_wrapper dst_d(&dst_md);
    // NOTE(review): bias_d is constructed but not referenced in this
    // function — presumably kept for interface symmetry; confirm.
    const memory_desc_wrapper bias_d(&bias_md);
    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;

    const bool is_1d = jcp.ndims == 3;
    const bool is_2d = jcp.ndims == 4;
    const bool is_3d = jcp.ndims == 5;

    if (jcp.wei_plain) {
        // Plain (non-blocked) weights: leading dimension covers full oc.
        jcp.LDB = jcp.oc;
        if (is_3d) {
            if (jcp.wei_dt == f32)
                wei_tag = with_groups ? gdhwio : dhwio;
            else if (jcp.wei_dt == s8)
                wei_tag = with_groups ? gdhwIo4i : dhwIo4i;
            else if (jcp.wei_dt == bf16) {
                wei_tag = with_groups ? gdhwIo2i : dhwIo2i;
            } else
                return status::unimplemented;
        } else if (is_1d) {
            if (jcp.wei_dt == f32)
                wei_tag = with_groups ? gwio : wio;
            else if (jcp.wei_dt == s8)
                wei_tag = with_groups ? gwIo4i : wIo4i;
            else if (jcp.wei_dt == bf16) {
                wei_tag = with_groups ? gwIo2i : wIo2i;
            } else
                return status::unimplemented;
        } else {
            assert(is_2d);
            UNUSED(is_2d);
            if (jcp.wei_dt == f32)
                wei_tag = with_groups ? ghwio : hwio;
            else if (jcp.wei_dt == s8)
                wei_tag = with_groups ? ghwIo4i : hwIo4i;
            else if (jcp.wei_dt == bf16) {
                wei_tag = with_groups ? ghwIo2i : hwIo2i;
            } else
                return status::unimplemented;
        }
    } else {
        // Blocked weights: leading dimension is the oc block; the inner
        // blocking (…o4i/…o2i and the extra 16i for padded ic) depends on
        // the weights data type and jcp.is_ic_padded.
        jcp.LDB = jcp.oc_block;
        if (jcp.oc_block == 64) {
            if (is_3d) {
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gOdhwi64o : Odhwi64o;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOdhwI16i64o4i : OdhwI16i64o4i;
                    else
                        wei_tag = with_groups ? gOdhwI64o4i : OdhwI64o4i;
                } else if (jcp.wei_dt == bf16) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOdhwI16i64o2i : OdhwI16i64o2i;
                    else
                        wei_tag = with_groups ? gOdhwI64o2i : OdhwI64o2i;
                } else
                    return status::unimplemented;
            } else if (is_1d) {
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gOwi64o : Owi64o;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOwI16i64o4i : OwI16i64o4i;
                    else
                        wei_tag = with_groups ? gOwI64o4i : OwI64o4i;
                } else if (jcp.wei_dt == bf16) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOwI16i64o2i : OwI16i64o2i;
                    else
                        wei_tag = with_groups ? gOwI64o2i : OwI64o2i;
                } else
                    return status::unimplemented;
            } else {
                assert(is_2d);
                UNUSED(is_2d);
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gOhwi64o : Ohwi64o;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOhwI16i64o4i : OhwI16i64o4i;
                    else
                        wei_tag = with_groups ? gOhwI64o4i : OhwI64o4i;
                } else if (jcp.wei_dt == bf16) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOhwI16i64o2i : OhwI16i64o2i;
                    else
                        wei_tag = with_groups ? gOhwI64o2i : OhwI64o2i;
                } else
                    return status::unimplemented;
            }
        } else if (jcp.oc_block == 48) {
            if (is_3d) {
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gOdhwi48o : Odhwi48o;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOdhwI16i48o4i : OdhwI16i48o4i;
                    else
                        wei_tag = with_groups ? gOdhwI48o4i : OdhwI48o4i;
                } else if (jcp.wei_dt == bf16) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOdhwI16i48o2i : OdhwI16i48o2i;
                    else
                        wei_tag = with_groups ? gOdhwI48o2i : OdhwI48o2i;
                } else
                    return status::unimplemented;
            } else if (is_1d) {
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gOwi48o : Owi48o;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOwI16i48o4i : OwI16i48o4i;
                    else
                        wei_tag = with_groups ? gOwI48o4i : OwI48o4i;
                } else if (jcp.wei_dt == bf16) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOwI16i48o2i : OwI16i48o2i;
                    else
                        wei_tag = with_groups ? gOwI48o2i : OwI48o2i;
                } else
                    return status::unimplemented;
            } else {
                assert(is_2d);
                UNUSED(is_2d);
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gOhwi48o : Ohwi48o;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOhwI16i48o4i : OhwI16i48o4i;
                    else
                        wei_tag = with_groups ? gOhwI48o4i : OhwI48o4i;
                } else if (jcp.wei_dt == bf16) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOhwI16i48o2i : OhwI16i48o2i;
                    else
                        wei_tag = with_groups ? gOhwI48o2i : OhwI48o2i;
                } else
                    return status::unimplemented;
            }
        } else if (jcp.oc_block == 32) {
            if (is_3d) {
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gOdhwi32o : Odhwi32o;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOdhwI16i32o4i : OdhwI16i32o4i;
                    else
                        wei_tag = with_groups ? gOdhwI32o4i : OdhwI32o4i;
                } else if (jcp.wei_dt == bf16) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOdhwI16i32o2i : OdhwI16i32o2i;
                    else
                        wei_tag = with_groups ? gOdhwI32o2i : OdhwI32o2i;
                } else
                    return status::unimplemented;
            } else if (is_1d) {
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gOwi32o : Owi32o;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOwI16i32o4i : OwI16i32o4i;
                    else
                        wei_tag = with_groups ? gOwI32o4i : OwI32o4i;
                } else if (jcp.wei_dt == bf16) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOwI16i32o2i : OwI16i32o2i;
                    else
                        wei_tag = with_groups ? gOwI32o2i : OwI32o2i;
                } else
                    return status::unimplemented;
            } else {
                assert(is_2d);
                UNUSED(is_2d);
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gOhwi32o : Ohwi32o;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOhwI16i32o4i : OhwI16i32o4i;
                    else
                        wei_tag = with_groups ? gOhwI32o4i : OhwI32o4i;
                } else if (jcp.wei_dt == bf16) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOhwI16i32o2i : OhwI16i32o2i;
                    else
                        wei_tag = with_groups ? gOhwI32o2i : OhwI32o2i;
                } else
                    return status::unimplemented;
            }
        } else {
            // Fallback: oc_block == 16 family of tags.
            if (is_3d) {
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gOdhwi16o : Odhwi16o;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOdhwI16i16o4i : OdhwI16i16o4i;
                    else
                        wei_tag = with_groups ? gOdhwI16o4i : OdhwI16o4i;
                } else if (jcp.wei_dt == bf16) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOdhwI16i16o2i : OdhwI16i16o2i;
                    else
                        wei_tag = with_groups ? gOdhwI16o2i : OdhwI16o2i;
                } else
                    return status::unimplemented;
            } else if (is_1d) {
                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gOwi16o : Owi16o;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOwI16i16o4i : OwI16i16o4i;
                    else
                        wei_tag = with_groups ? gOwI16o4i : OwI16o4i;
                } else if (jcp.wei_dt == bf16) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOwI16i16o2i : OwI16i16o2i;
                    else
                        wei_tag = with_groups ? gOwI16o2i : OwI16o2i;
                } else
                    return status::unimplemented;
            } else {
                assert(is_2d);
                UNUSED(is_2d);

                if (jcp.wei_dt == f32)
                    wei_tag = with_groups ? gOhwi16o : Ohwi16o;
                else if (jcp.wei_dt == s8) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOhwI16i16o4i : OhwI16i16o4i;
                    else
                        wei_tag = with_groups ? gOhwI16o4i : OhwI16o4i;
                } else if (jcp.wei_dt == bf16) {
                    if (jcp.is_ic_padded)
                        wei_tag = with_groups ? gOhwI16i16o2i : OhwI16i16o2i;
                    else
                        wei_tag = with_groups ? gOhwI16o2i : OhwI16o2i;
                } else
                    return status::unimplemented;
            }
        }
    }

    // Activations share one layout: src follows dst (channels-last).
    src_tag = dst_tag;

    // A tag may be imposed on an `any` src/dst only for inference, int8
    // weights, or amx; weights with `any` format are always materialized.
    const bool any_eligible = (jcp.prop_kind == prop_kind::forward_inference
            || jcp.wei_dt == data_type::s8 || is_amx(jcp.isa));
    CHECK(init_tag(jcp.src_tag, src_md, src_d, src_tag, any_eligible));
    CHECK(init_tag(jcp.dst_tag, dst_md, dst_d, dst_tag, any_eligible));
    CHECK(init_tag(jcp.wei_tag, weights_md, weights_d, wei_tag, true));

    return status::success;
}
345
346 struct brg_blocking_t : public jit_brgemm_conv_conf_t {
347 struct array_in_loop_t {
348 dim_t itersize;
349 float repeatn;
350 float overlap;
setdnnl::impl::cpu::x64::brgemm_convolution_utils::brg_blocking_t::array_in_loop_t351 void set(dim_t iter_s, float rpt, float ovlp = 1.f) {
352 itersize = iter_s;
353 repeatn = rpt;
354 overlap = ovlp;
355 }
356 };
357
358 struct loop_t {
359 array_in_loop_t src;
360 array_in_loop_t wei;
361 array_in_loop_t dst;
362 };
363
brg_blocking_tdnnl::impl::cpu::x64::brgemm_convolution_utils::brg_blocking_t364 brg_blocking_t() : jit_brgemm_conv_conf_t() { init(); }
brg_blocking_tdnnl::impl::cpu::x64::brgemm_convolution_utils::brg_blocking_t365 brg_blocking_t(const jit_brgemm_conv_conf_t &jcp)
366 : jit_brgemm_conv_conf_t(jcp) {
367 init();
368 }
initdnnl::impl::cpu::x64::brgemm_convolution_utils::brg_blocking_t369 void init() {
370 ur = 0;
371 ur_block = 0;
372 ur_block_tail = 0;
373 eff = 0.f;
374 nb_kd = 0;
375 nb_kh = 0;
376 nb_kw = 0;
377 sp = 0;
378 sp_block = 0;
379 nb_sp = 0;
380 eff = 0;
381 }
382
383 int ur, ur_block, ur_block_tail;
384 int nb_kd, nb_kh, nb_kw;
385 float eff;
386 static unsigned L1;
387 static unsigned L2;
388 static unsigned L3;
389 // These are rough estimates of the latency (relative) of access to various
390 // cache levels. This is enough for an estimation of data access cost.
391 // TODO: Improve memory access estimates
392 static constexpr float L1_k = 1.f;
393 static constexpr float L2_k = 3.f;
394 static constexpr float L3_k = 15.f;
395 // TODO: At the moment, we are primarily evaluating the fit of the data into
396 // the L1/L2. Need to take into account the difference between the L3 and
397 // memory.
398 static constexpr float mem_k = 15.f;
399 static constexpr int bench_iterations = 1;
400 static constexpr int max_regs = 32;
401 static constexpr int bcast_simd = 16;
402
403 int sp, sp_block, nb_sp;
404 static int last_ic_block_size;
405
get_from_jcpdnnl::impl::cpu::x64::brgemm_convolution_utils::brg_blocking_t406 void get_from_jcp(const jit_brgemm_conv_conf_t &jcp) { *this = jcp; }
save_to_jcpdnnl::impl::cpu::x64::brgemm_convolution_utils::brg_blocking_t407 void save_to_jcp(jit_brgemm_conv_conf_t &jcp) const { jcp = *this; }
408
409 status_t estimate_brgemm_ur(int spb);
410 status_t get_brgemm_ur(
411 const primitive_attr_t *attr, const memory_desc_t &dst_md);
412
413 float io_k(dim_t src, dim_t wei, dim_t dst, float n, float pk,
414 bool is_broadcast, bool is_shared) const;
415
416 float io_k(const loop_t loop, const array_in_loop_t arr, float pk,
417 bool is_broadcast, bool is_shared) const;
418
419 void select_ic_block();
420
421 void update_blocks();
422 bool fast_check_oc_block() const;
423 float est_eff();
424 void iterate_ker_block(brg_blocking_t &best_brgb, int kd_block,
425 int kh_block, bool maybe_use_buffer, int max_ow_block_thr);
426 status_t calc_blocks();
427
428 bool fast_check_oc_block_1x1() const;
429 float est_eff_1x1();
430 void calc_blocks_1x1();
431
432 // utils
get_inp_sizednnl::impl::cpu::x64::brgemm_convolution_utils::brg_blocking_t433 static int get_inp_size(
434 int max_src_size, int dst_size, int k, int stride, int dilate) {
435 auto adj_str = nstl::min(k, stride);
436 const auto res = nstl::min(max_src_size,
437 calculate_end_padding(0, dst_size, 0, adj_str,
438 calculate_extended_filter_size(k, dilate)));
439 return res;
440 }
441
squeeze_valdnnl::impl::cpu::x64::brgemm_convolution_utils::brg_blocking_t442 static float squeeze_val(float eff, float koeff) {
443 if (koeff <= 0) return 1;
444 if (koeff == 1) return eff;
445 const auto k = 1.f / koeff;
446 return (k > 1.f) ? (k - 1 + eff) / k : eff * koeff;
447 }
448
estimate_urdnnl::impl::cpu::x64::brgemm_convolution_utils::brg_blocking_t449 static int estimate_ur(int oc_block) {
450 const auto est_ur = (oc_block == 64)
451 ? 6
452 : ((oc_block == 48) ? 9 : ((oc_block == 32) ? 14 : 28));
453 return est_ur;
454 }
455
inp_wdnnl::impl::cpu::x64::brgemm_convolution_utils::brg_blocking_t456 int inp_w(int out_w, int ker_w) const {
457 return get_inp_size(iw, out_w, ker_w, stride_w, dilate_w);
458 }
459
rnd_simddnnl::impl::cpu::x64::brgemm_convolution_utils::brg_blocking_t460 int rnd_simd(int val) const { return rnd_up(val, simd_w); }
461
rnd_inp_simddnnl::impl::cpu::x64::brgemm_convolution_utils::brg_blocking_t462 int rnd_inp_simd(int out_w, int ker_w, int vic) const {
463 const auto vsp = inp_w(out_w, ker_w);
464 return ((stride_w == 1 && vic >= ic) ? rnd_up(vsp * vic, simd_w)
465 : vsp * rnd_up(vic, simd_w));
466 }
467
468 static constexpr int MAXNLOOPS = 32;
469 loop_t loop[MAXNLOOPS];
470 };
471
// Out-of-class definitions for the static members of brg_blocking_t
// (cache sizes and the last ic block size used by the cost model).
unsigned brg_blocking_t::L1;
unsigned brg_blocking_t::L2;
unsigned brg_blocking_t::L3;
int brg_blocking_t::last_ic_block_size;
476
io_k(dim_t src,dim_t wei,dim_t dst,float n,float pk,bool is_broadcast,bool is_shared) const477 float brg_blocking_t::io_k(dim_t src, dim_t wei, dim_t dst, float n, float pk,
478 bool is_broadcast, bool is_shared) const {
479 if (n < 1) return 0;
480 if (n == 1) return pk;
481 const auto amount = src * src_dsz + wei * wei_dsz + dst * dst_dsz
482 + (use_buffer ? dst * acc_dsz : 0);
483 const auto amount_L1 = is_broadcast ? src * src_dsz : amount;
484 const auto k = is_broadcast
485 ? ((amount_L1 < L1) ? L1_k
486 : ((amount < L2) ? L2_k
487 : (is_shared ? L3_k : mem_k)))
488 : ((amount < L2) ? L2_k : (is_shared ? L3_k : mem_k));
489 const auto cost = pk + k * (n - 1);
490 return cost / n;
491 }
492
io_k(const loop_t loop,const array_in_loop_t arr,float pk,bool is_broadcast,bool is_shared) const493 float brg_blocking_t::io_k(const loop_t loop, const array_in_loop_t arr,
494 float pk, bool is_broadcast, bool is_shared) const {
495 return io_k(loop.src.itersize, loop.wei.itersize, loop.dst.itersize,
496 arr.repeatn * arr.overlap, pk, is_broadcast, is_shared);
497 }
498
// Choose ic_block (and nb_ic) for the current candidate. For amx the block
// follows the brgemm kernel requirements; otherwise the block is sized so
// the working set fits into L1/L2 while keeping the tail efficiency
// (nb_simd / rnd_up(nb_simd, nb_icb)) above a threshold.
void brg_blocking_t::select_ic_block() {
    if (is_1x1 && is_amx(isa)) {
        // TODO: merge with non-1x1 code block below
        const int ic_padded_block = 16 * brg_blocking_t::last_ic_block_size;
        assert(ic < ic_padded_block || ic % ic_padded_block == 0);
        MAYBE_UNUSED(ic_padded_block);
        ic_block = ic;
        nb_ic = utils::div_up(ic, ic_block); // trivially 1 for now
        return;
    }
    auto nb_simd = utils::div_up(ic, simd_w);
    auto max_simd_blocks = nstl::min(5 * simd_w, nb_simd);
    // Minimum acceptable ratio of useful work in the ic tail block.
    const auto nb_icb_eff_threshold = 0.5f;
    const auto padded_ic = last_ic_block_size * (is_ic_padded ? 16 : 1);
    if (is_amx(isa)) {
        if (ic * kw_sets < simd_w) {
            // this is current requirement from brgemm kernel
            ic_block = rnd_up(ic, last_ic_block_size);
        } else {
            if (exec_type == exec_trans) {
                // Largest simd-block count whose tail is still efficient.
                auto simd_blocks = 1;
                for (int nb_icb = max_simd_blocks; nb_icb >= 1; nb_icb--) {
                    auto nb_icb_eff = static_cast<float>(nb_simd)
                            / rnd_up(nb_simd, nb_icb);
                    if (nb_icb_eff >= nb_icb_eff_threshold) {
                        simd_blocks = nb_icb;
                        break;
                    }
                }
                ic_block = simd_blocks * simd_w;
            } else
                ic_block = simd_w;
        }
    } else {
        const auto est_ur = nstl::min(sp_block, estimate_ur(oc_block));
        const auto inp_ur = is_os_blocking ? est_ur : inp_w(est_ur, kw_block);

        if (kw_block > 1) {
            // try to fit src into L1
            const auto inp_per_ic = static_cast<unsigned int>(inp_ur) * src_dsz;
            max_simd_blocks = saturate(1, max_simd_blocks,
                    static_cast<int>(L1 / (inp_per_ic * simd_w)));
        }
        // try to fit all batch for ur into L2
        const auto wei_per_ic = static_cast<unsigned int>(kd_block) * kh_block
                * kw_block * oc_block * wei_dsz;
        const auto inp_per_ic = static_cast<unsigned int>(kd_block) * kh_block
                * inp_ur * src_dsz;
        const auto out_size
                = static_cast<unsigned int>(ur) * oc_block * dst_dsz;

        max_simd_blocks = saturate(1, max_simd_blocks,
                static_cast<int>((L2 - out_size)
                        / ((wei_per_ic + inp_per_ic) * simd_w)));

        // Largest simd-block count whose ic tail is still efficient.
        auto simd_blocks = 1;
        for (int nb_icb = nstl::min(max_simd_blocks, nb_simd); nb_icb >= 1;
                nb_icb--) {
            auto nb_icb_eff
                    = static_cast<float>(nb_simd) / rnd_up(nb_simd, nb_icb);
            if (nb_icb_eff >= nb_icb_eff_threshold) {
                simd_blocks = nb_icb;
                break;
            }
        }

        ic_block = nstl::min(
                (exec_type == exec_trans) ? rnd_up(ic, padded_ic) : ic,
                simd_blocks * simd_w);
    }
    nb_ic = utils::div_up(ic, ic_block);
}
571
// Derive ur (unroll over the spatial dim) by initializing a trial brgemm
// descriptor with the candidate's LDA/LDB/LDC and M/N/K and reading back
// the bd_block the kernel would choose. Fills LDA/LDB/LDC, M/N/K (+tails),
// icp and the os-blocking brgM adjustments as a side effect.
status_t brg_blocking_t::estimate_brgemm_ur(int spb) {
    // Simple simulation of brgemm_desc init
    if (sp_block <= 0) return status::invalid_arguments;
    LDA = is_rtus
            ? (ic_block)
            : (kh_sets > 1 ? kh_sets : 1) * (kw_sets > 1 ? kw_sets : stride_w)
                    * (exec_type == exec_trans ? ic_block
                                               : ngroups * ic_without_padding);
    LDB = oc_block;
    LDC = use_buffer ? oc_block : oc_without_padding;

    // Configure matrix sizes
    // for amx if ic_block != ic then we use exec_trans so K is ic_block
    const auto padded_ic = last_ic_block_size * (is_ic_padded ? 16 : 1);

    icp = rnd_up(ic, padded_ic);
    M = brgM = sp >= sp_block ? sp_block : 0;
    M_tail = brgM_tail = sp % sp_block;
    if (is_os_blocking) {
        if (!is_1x1) M_tail = brgM_tail = (oh * ow) % sp_block;
        // Rows skipped between consecutive output rows in the fused
        // oh*ow spatial dimension.
        oskip = ((ext_kw - 1) / stride_w) * stride_h + (stride_h - 1) * ow;

        brgM = sp_block + oskip * (div_up(M, ow) - 1);

        // round up brgM to help brgemm kernel use max amx_h as brgemm bd_block
        if (use_M_mask == 2) {
            int ibrgM = 0;
            const auto adj_ow = ow_block + oskip;
            while (ibrgM < brgM) {
                if (ibrgM % adj_ow < ow_block)
                    ibrgM += amx_h;
                else
                    ibrgM++;
            }
            brgM = ibrgM;
        } else
            brgM = rnd_up(brgM, amx_h);

        brgM_tail = brgM;
    }

    N = oc >= oc_block ? oc_block : 0;
    N_tail = oc % oc_block;
    K = kh_sets * kw_sets * (ic >= ic_block ? ic_block : 0);
    K_tail = kh_sets * kw_sets
            * (exec_type == exec_trans
                            ? ic_block
                            : rnd_up(ic % ic_block, last_ic_block_size));

    // Use the tail dimensions when the full-block dimension is absent.
    const auto vK = K > 0 ? K : K_tail;
    const auto vM = M > 0 ? M : M_tail;
    const auto vN = N > 0 ? N : N_tail;

    const float alpha = 1.0;
    const float beta = 0.0;
    brgemm_t brg;
    CHECK(brgemm_desc_init(&brg, isa, brgemm_addr, src_dt, wei_dt, false, false,
            brgemm_row_major, alpha, beta, LDA, LDB, LDC, vM, vN, vK));
    ur = brg.bd_block * (is_amx(isa) ? brg.bd_block2 : 1);
    ur_block = brg.bd_block;
    if (is_1x1 && is_amx(isa) && M > 0 && M_tail > 0) {
        // 1x1 amx case also needs the bd_block of the spatial tail.
        brgemm_t brg_sp_tail;
        CHECK(brgemm_desc_init(&brg_sp_tail, isa, brgemm_addr, src_dt, wei_dt,
                false, false, brgemm_row_major, alpha, beta, LDA, LDB, LDC,
                M_tail, vN, vK));
        ur_block_tail = brg_sp_tail.bd_block;
    } else {
        ur_block_tail = 0;
    }
    return status::success;
}
643
get_brgemm_ur(const primitive_attr_t * attr,const memory_desc_t & dst_md)644 status_t brg_blocking_t::get_brgemm_ur(
645 const primitive_attr_t *attr, const memory_desc_t &dst_md) {
646 // Detailed simulation of brgemm convolution init
647 if (sp_block <= 0 || ic_block <= 0 || sp_block <= 0 || oc_block <= 0)
648 return status::invalid_arguments;
649 CHECK(estimate_brgemm_ur(is_os_blocking ? os_block : ow_block));
650
651 LDD = oc_without_padding;
652
653 const float alpha = 1.0;
654 const float beta = 1.0;
655 const float beta_init = 0.0;
656
657 for (int i = 0; i < M; i++) {
658 auto vM = i + 1;
659 // init only needed brgemm descriptors
660 if ((utils::one_of(exec_type, exec_trans, exec_vpad) || is_1x1)
661 && vM != M && vM != M_tail)
662 continue;
663 for (int i_init = 0; i_init < 2; i_init++) {
664 for (int i_N = 0; i_N < 2; i_N++) {
665 for (int i_K = 0; i_K < 2; i_K++) {
666 auto vbeta = (i_init) ? beta_init : beta;
667 auto vN = (i_N) ? N_tail : N;
668 auto vK = (i_K) ? K_tail : K;
669 if (vN == 0 || vK == 0) continue;
670 brgemm_t brg;
671 brgemm_strides_t brg_strides;
672 brg_strides.stride_a = ngroups * ic_without_padding
673 * (dilate_w + 1) * src_dsz;
674 //weights are padded by oc_block and last_ic_block
675 brg_strides.stride_b = rnd_up(ic, last_ic_block_size)
676 * rnd_up(oc, oc_block) * wei_dsz;
677 const auto strides_ptr = (brg_type == brgemm_strd)
678 ? &brg_strides
679 : nullptr;
680 CHECK(brgemm_desc_init(&brg, isa, brg_type, src_dt, wei_dt,
681 false, false, brgemm_row_major, alpha, vbeta, LDA,
682 LDB, LDC, vM, vN, vK, strides_ptr));
683
684 brgemm_attr_t brgattr;
685 brgattr.max_bs = max_batch;
686 const auto max_vpad = (exec_type == exec_vpad)
687 ? nstl::max(l_pad, r_pad)
688 : 0;
689 brgattr.max_top_vpad = max_vpad;
690 brgattr.max_bottom_vpad = max_vpad;
691 CHECK(brgemm_desc_set_attr(&brg, brgattr));
692
693 brg.with_sum = with_sum;
694 CHECK(brgemm_desc_set_postops(
695 &brg, attr, &dst_md, LDD, bia_dt));
696 }
697 }
698 }
699 }
700
701 return status::success;
702 }
703
update_blocks()704 void brg_blocking_t::update_blocks() {
705 if (sp_block <= 0
706 || utils::one_of(0, od_block, oh_block, ic_block, oc_block,
707 kd_block, kh_block, kw_block, os_block, ow_block))
708 return;
709
710 nb_od = div_up(od, od_block);
711 nb_oh = div_up(oh, oh_block);
712 nb_ic = div_up(ic, ic_block);
713 nb_oc = div_up(oc, oc_block);
714 nb_kd = div_up(kd, kd_block);
715 nb_kh = div_up(kh, kh_block);
716 nb_kw = div_up(kw, kw_block);
717 nb_ow = div_up(ow, ow_block);
718 if (is_os_blocking) {
719 nb_os = div_up(os, os_block);
720 sp = os;
721 sp_block = os_block;
722 nb_sp = nb_os;
723 } else {
724 sp = ow;
725 sp_block = ow_block;
726 nb_sp = nb_ow;
727 iw_block = get_inp_size(iwp, ow_block, kw, stride_w, dilate_w);
728 }
729 }
730
fast_check_oc_block() const731 bool brg_blocking_t::fast_check_oc_block() const {
732 // This function for reducing the number of blocking variants
733 // TODO: eliminate heuristic in this function
734 const auto rnd_oc = rnd_up(oc, 16);
735 auto res = false;
736 if (oc_block == 64) {
737 res = (rnd_oc % oc_block == 0 && rnd_oc * wei_dsz < 192 * 4);
738 } else if (oc_block == 48) {
739 const bool big_spatial
740 = id * ih * iw > 81 * stride_d * stride_h * stride_w;
741 res = (rnd_oc % oc_block == 0 && rnd_oc * wei_dsz <= 384 * 4
742 && big_spatial);
743 } else
744 res = true;
745
746 return res;
747 }
748
est_eff()749 float brg_blocking_t::est_eff() {
750 const auto ocblock = oc_block / 16;
751
752 const auto brgemm_microkernel_eff
753 = (static_cast<float>(ocblock) * ur) / ((ur + ocblock) * max_regs);
754
755 const auto ur_eff = static_cast<float>(sp_block) / rnd_up(sp_block, ur);
756 const auto brgemm_eff = squeeze_val(ur
757 * (2.f - nstl::min(1.9f, static_cast<float>(ur) / sp_block))
758 / 64,
759 0.5f);
760
761 const auto sp_amount = nb_od * nb_oh * nb_sp;
762 const auto work_amount = mb * ngroups * nb_oc * sp_amount;
763 const auto sp_eff = (static_cast<float>(sp) / rnd_up(sp, sp_block));
764
765 const auto thr_eff = static_cast<float>(work_amount)
766 / utils::rnd_up(work_amount, nthr);
767
768 const auto oc_block_eff = static_cast<float>(oc) / rnd_up(oc, oc_block);
769
770 const auto job = div_up(work_amount, nthr);
771
772 auto job_eff = 1.f;
773 if (job < nthr) {
774 std::vector<dim_t> thr_jobs(nthr);
775
776 for (int ithr = 0; ithr < nthr; ithr++) {
777 thr_jobs[ithr] = 0;
778 if (ithr >= work_amount) continue;
779 dim_t thr_job = 0;
780 int start {0}, end {0};
781 balance211(work_amount, nthr, ithr, start, end);
782 int n {0}, g {0}, ocb {0}, odp {0}, ohp {0}, spb {0};
783 if (loop_order == loop_ndhwgc)
784 nd_iterator_init(start, n, mb, odp, od, ohp, oh, spb, nb_sp, g,
785 ngroups, ocb, nb_oc);
786 else if (loop_order == loop_ngcdhw)
787 nd_iterator_init(start, n, mb, g, ngroups, ocb, nb_oc, odp, od,
788 ohp, oh, spb, nb_sp);
789
790 for (auto work = start; work < end; work++) {
791 const int ocp = ocb * oc_block;
792 const auto oc_sz = nstl::min(oc - ocp, oc_block);
793 int sp_sz = 0;
794 const int spp = spb * sp_block;
795 sp_sz = nstl::min(sp - spp, sp_block);
796 thr_job += sp_sz * oc_sz;
797
798 if (loop_order == loop_ndhwgc)
799 nd_iterator_step(n, mb, odp, od, ohp, oh, spb, nb_sp, g,
800 ngroups, ocb, nb_oc);
801 else if (loop_order == loop_ngcdhw)
802 nd_iterator_step(n, mb, g, ngroups, ocb, nb_oc, odp, od,
803 ohp, oh, spb, nb_sp);
804 }
805 thr_jobs[ithr] = thr_job;
806 }
807
808 dim_t max_job = 0;
809 dim_t sum_job = 0;
810 for (int ithr = 0; ithr < nthr; ithr++) {
811 if (thr_jobs[ithr] > max_job) max_job = thr_jobs[ithr];
812 sum_job += thr_jobs[ithr];
813 }
814 job_eff = max_job == 0 ? 1
815 : static_cast<float>(sum_job) / (max_job * nthr);
816
817 } else {
818 job_eff = thr_eff;
819 }
820
821 const auto ic_blocking_size = ic_block * nb_ic_blocking;
822 const auto oc_blocking_size = oc_block * ic_blocking_size;
823
824 int l = -1;
825
826 // -- brgemm kernel: loop by simd_w --
827 l++;
828 const auto inp_ur = inp_w(ur, kw_block);
829 loop[l].src.set(inp_ur * simd_w, 1, bcast_simd);
830 loop[l].dst.set(0, 1);
831 loop[l].wei.set(oc_block, 1);
832
833 // -- brgemm kernel: loop by kw in kw_block --
834 l++;
835 auto src_is = rnd_inp_simd(ur, kw_block, ic_blocking_size);
836 loop[l].src.set(src_is, 1, kw_block);
837 loop[l].dst.set(0, 1);
838 loop[l].wei.set(oc_blocking_size, 1);
839
840 // -- brgemm kernel: loop by batch (grouped by kw_block) in ur --
841 l++;
842 loop[l].src.set(src_is, 1);
843 loop[l].dst.set(0, 1);
844 auto wei_is = kw_block * oc_blocking_size;
845 loop[l].wei.set(wei_is, 1);
846 // -- brgemm kernel: loop by ur in sp_block --
847 l++;
848 const auto nb_ur = div_up(sp_block, ur);
849 loop[l].src.set(kd_block * kh_block * src_is, 1);
850 loop[l].dst.set(ur * oc_block, 1);
851 wei_is = kd_block * kh_block * kw_block * oc_blocking_size;
852 loop[l].wei.set(wei_is, nb_ur);
853
854 // -- harness: loop by k_blocks in ks --
855 l++;
856 loop[l].src.set(kd_block * kh_block
857 * rnd_inp_simd(sp_block, kw_block, ic_blocking_size),
858 1);
859 loop[l].dst.set(sp_block * oc_block, nb_kd * nb_kh * nb_kw);
860 loop[l].wei.set(wei_is, 1);
861
862 // -- brgemm kernel: loop by ic_chunks --
863 l++;
864 const auto ic_chunks = div_up(nb_ic, nb_ic_blocking);
865 loop[l].src.set(kd * kh * rnd_inp_simd(sp_block, kw, ic_blocking_size), 1);
866 loop[l].dst.set(sp_block * oc_block, ic_chunks);
867 wei_is = kd * kh * kw * oc_blocking_size;
868 loop[l].wei.set(wei_is, 1);
869
870 const auto dim_oc = (loop_order == loop_ndhwgc) ? 1 : sp_amount;
871 const auto nb_oc_thr = nstl::min(nb_oc, div_up(job, dim_oc));
872 const auto oc_thr = nstl::min(oc, nb_oc_thr * oc_block);
873 const auto nsimd_oc_thr = div_up(oc_thr, simd_w);
874
875 const auto dim_sp = (loop_order == loop_ndhwgc) ? ngroups * nb_oc : 1;
876 const auto nb_sp_thr = nstl::min(nb_sp, div_up(job, dim_sp));
877 const auto sp_thr = nstl::min(sp, nb_sp_thr * sp_block);
878
879 const auto dim_oh = nb_sp * dim_sp;
880 const auto nb_oh_thr = nstl::min(nb_oh, div_up(job, dim_oh));
881 const auto oh_thr = nstl::min(oh, nb_oh_thr * oh_block);
882
883 const auto dim_od = nb_oh * dim_oh;
884 const auto nb_od_thr = nstl::min(nb_od, div_up(job, dim_od));
885 const auto od_thr = nstl::min(od, nb_od_thr * od_block);
886
887 src_is = kd * kh * rnd_inp_simd(sp_block, kw, ic);
888
889 auto wei_op = kd * kh * kw * ocblock * ic;
890 if (loop_order == loop_ndhwgc) {
891 // -- harness: loop by oc_block --
892 l++;
893 loop[l].src.set(src_is, nb_oc_thr);
894 loop[l].dst.set(sp_block * oc_block, 1);
895 wei_is = kd * kh * kw * oc_block * ic;
896 wei_op = kd * kh * kw * nsimd_oc_thr * ic;
897 loop[l].wei.set(wei_is, 1);
898 }
899
900 // -- harness: loop by sp_blocks --
901 l++;
902 loop[l].src.set(src_is, 1);
903 const auto rnd_oc_for_sp
904 = simd_w * ((loop_order == loop_ndhwgc) ? nsimd_oc_thr : ocblock);
905 loop[l].dst.set(sp_block * rnd_oc_for_sp, 1);
906 loop[l].wei.set(wei_op * simd_w, nb_sp_thr);
907 // oh_block almost all is 1. TODO: manage oh_block != 1
908 // -- harness: loop by oh_blocks --
909 l++;
910 src_is = kd * kh * rnd_inp_simd(sp_thr, kw, ic);
911 loop[l].src.set(oh_block * src_is, 1);
912 loop[l].dst.set(sp_thr * rnd_oc_for_sp, 1);
913 loop[l].wei.set(wei_op * simd_w, nb_oh_thr);
914 // od_block almost all is 1. TODO: manage oh_block != 1
915 // -- harness: loop by od_blocks --
916 l++;
917 loop[l].src.set(od_block * oh_thr * src_is, 1);
918 loop[l].dst.set(oh_thr * sp_thr * rnd_oc_for_sp, 1);
919 loop[l].wei.set(wei_op * simd_w, nb_od_thr);
920
921 if (loop_order != loop_ndhwgc) {
922 // -- harness: loop by oc_block --
923 l++;
924 loop[l].src.set(od_thr * oh_thr * src_is, nb_oc_thr);
925 loop[l].dst.set(oc_block * od_thr * oh_thr * sp_thr, 1);
926 loop[l].wei.set(kd * kh * kw * oc_block * ic, 1);
927 }
928
929 // -- harness: loop by mb --
930 l++;
931 const auto mb_thr = nstl::min(mb, div_up(job, sp_amount * ngroups * nb_oc));
932 loop[l].src.set(od_thr * oh_thr * src_is, 1);
933 loop[l].dst.set(od_thr * oh_thr * sp_thr * nsimd_oc_thr * simd_w, 1);
934 loop[l].wei.set(kd * kh * kw * nsimd_oc_thr * simd_w * ic, mb_thr);
935
936 const auto src_op = static_cast<dim_t>(mb_thr) * od_thr
937 * (is_os_blocking ? 1 : oh_thr) * sp_thr * kd * kh * kw * ic;
938 const auto dst_op = static_cast<dim_t>(mb_thr) * od_thr
939 * (is_os_blocking ? 1 : oh_thr) * sp_thr * nsimd_oc_thr;
940 wei_op = kd * kh * kw * nsimd_oc_thr * ic;
941
942 // for "real" application set bench_iterations to 1
943 const auto iterations = bench_iterations;
944 l++;
945 loop[l].src.set(src_op, iterations);
946 loop[l].dst.set(dst_op * simd_w, iterations);
947 loop[l].wei.set(wei_op * simd_w, iterations);
948
949 auto src_mem_k = mem_k;
950 auto dst_mem_k = mem_k;
951 auto wei_mem_k = mem_k;
952 float src_rp = 1;
953 float dst_rp = 1;
954 float wei_rp = 1;
955
956 for (auto il = l; il >= 0; il--) {
957 src_mem_k = io_k(loop[il], loop[il].src, src_mem_k, true,
958 loop_order == loop_ndhwgc ? false : true);
959 dst_mem_k = io_k(loop[il], loop[il].dst, dst_mem_k, false, false);
960 wei_mem_k = io_k(loop[il], loop[il].wei, wei_mem_k, false,
961 loop_order == loop_ndhwgc ? true : false);
962 src_rp *= loop[il].src.repeatn;
963 dst_rp *= loop[il].dst.repeatn;
964 wei_rp *= loop[il].wei.repeatn;
965 }
966 const auto src_ops = (src_op * src_rp) / iterations;
967 const auto dst_ops = (dst_op * dst_rp) / iterations;
968 const auto wei_ops = (wei_op * wei_rp) / iterations;
969
970 const auto src_cost = src_mem_k * src_ops;
971 const auto dst_cost = dst_mem_k * dst_ops;
972 const auto wei_cost = wei_mem_k * wei_ops;
973 const auto call_kernel_cost
974 = 1000.f * job * ic_chunks * nb_kd * nb_kh * nb_kw;
975
976 const auto cache_eff = (static_cast<dim_t>(mb) * od * oh * sp * ic * oc)
977 / (nthr * (src_cost + dst_cost + wei_cost + call_kernel_cost));
978 const auto res_eff = oc_block_eff * brgemm_microkernel_eff * sp_eff
979 * job_eff * ur_eff * cache_eff * brgemm_eff;
980 return res_eff;
981 }
982
// Evaluate all candidate spatial (ow/sp) blockings for one fixed kd/kh kernel
// blocking and keep the most efficient variant found so far in best_brgb.
// Mutates *this as a scratch configuration: kernel-block members, od/oh/ow
// blocks, use_buffer and eff are all overwritten while iterating.
void brg_blocking_t::iterate_ker_block(brg_blocking_t &best_brgb, int kd_block_,
        int kh_block_, bool maybe_use_buffer, int max_ow_block_thr) {

    // Estimated weights footprint of one (ic x oc_block) kernel slice.
    unsigned est_k_amount = ic * oc_block * wei_dsz;

    kd_block = kd_block_;
    kh_block = kh_block_;
    if (one_of(exec_type, exec_vpad, exec_trans)) {
        // Virtual-padding / transform execution: no kernel splitting.
        kw_block = kw;
        kd_block_pad = kd_block;
        kh_block_pad = kh_block;
        kw_block_pad = kw_block;
    } else {
        // Split kw only when the whole kw slab does not fit into L2.
        kw_block = (est_k_amount * kw < L2) ? kw : 1;
        // NOTE(review): these compare kh_block against kd and kw_block
        // against kh — looks like a copy-paste slip for (kd_block >= kd)
        // and (kh_block >= kh), but may be an intentional heuristic; confirm
        // against upstream before changing.
        kd_block_pad = kh_block >= kd ? kd : 1;
        kh_block_pad = kw_block >= kh ? kh : 1;
        kw_block_pad = kw;
    }

    if (exec_type == exec_vpad) {
        od_block = 1;
        oh_block = 1;
    } else if (exec_type == exec_trans) {
        // Size of one transposed input row block plus its output row.
        const auto w_block_size
                = 2 * src_dsz * ic * iwp + dst_dsz * ow * oc_block;
        // Weights + accumulator footprint that competes for the same L2.
        const auto other_size = wei_dsz * kd * kh * kw * ic * oc_block
                + acc_dsz * 2 * amx_h * oc_block;
        // Budget at most half of L2 for the input/output blocks.
        const auto L2_available = nstl::min(static_cast<size_t>(div_up(L2, 2)),
                other_size > L2 ? 0 : L2 - other_size);
        if (idp * ihp * w_block_size > L2_available) {
            // Whole padded input does not fit: shrink od first, then oh.
            od_block = utils::saturate(
                    1, od, int(L2_available / (ihp * w_block_size)));
            if (od_block == 1)
                oh_block = utils::saturate(
                        1, oh, int(L2_available / (w_block_size)));
            else
                oh_block = oh;
        } else {
            od_block = 1;
            oh_block = oh;
        }
        if (is_amx(isa)) {
            // try to fit into L1
            bool L1_fit_res = false;
            auto cur_od_block = od_block;
            auto cur_oh_block = oh_block;
            const auto src_w_block_size
                    = src_dsz * ic * iwp + dst_dsz * ow * oc_block;
            if (src_w_block_size < L1) {
                cur_od_block = utils::saturate(
                        1, od, int(L1 / (ihp * src_w_block_size)));
                if (cur_od_block == 1)
                    cur_oh_block = utils::saturate(
                            1, oh, int(L1 / (src_w_block_size)));
            }
            // Accept the L1-sized blocking only if it keeps od/sp tails small
            // (>= 90% od utilization and >= 80% spatial utilization).
            for (; cur_od_block > 1; cur_od_block--) {
                const auto sp_size = cur_od_block * cur_oh_block * iwp;
                if ((static_cast<float>(od) / rnd_up(od, cur_od_block)) > 0.9f
                        && static_cast<float>(sp_size) / rnd_up(sp, amx_h)
                                > 0.8f) {
                    L1_fit_res = true;
                    break;
                }
            }
            if (cur_od_block == 1) {
                for (; cur_oh_block > 1; cur_oh_block--) {
                    const auto sp_size = cur_oh_block * iwp;
                    if ((static_cast<float>(oh) / rnd_up(oh, cur_oh_block))
                                    > 0.9f
                            && sp_size > 128) {
                        L1_fit_res = true;
                        break;
                    }
                }
            }
            if (L1_fit_res) {
                od_block = cur_od_block;
                oh_block = cur_oh_block;
            }
        }

        // limit oh_block to have good threading
        auto thr_od_block = div_up(od, div_up(nthr, mb * div_up(oc, oc_block)));
        auto thr_oh_block = div_up(oh,
                div_up(nthr,
                        mb * div_up(oc, oc_block) * div_up(od, thr_od_block)));
        od_block = nstl::min(od_block, thr_od_block);
        oh_block = nstl::min(oh_block, thr_oh_block);
    } else {
        od_block = 1;
        oh_block = 1;
    }

    // --- Select ow_block ----
    const auto max_ow_block_L2 = ow;
    auto start_ow_block = nstl::min(max_ow_block_thr, max_ow_block_L2);

    sp = ow;
    const auto start_sp_block = is_os_blocking ? ow : start_ow_block;
    auto prev_spb = 0;
    // Enumerate distinct block sizes spb = ceil(sp / ns) from large to small.
    for (auto ns = 1; ns <= sp; ns++) {
        const auto spb = div_up(sp, ns);
        if (spb == prev_spb || spb > start_sp_block) continue;
        if (is_os_blocking && spb != ow) continue;
        prev_spb = spb;
        ow_block = spb;
        sp_block = ow_block;

        select_ic_block();

        // A buffer is needed whenever partial accumulation happens: ic is
        // chunked or any kernel dimension is split.
        use_buffer = maybe_use_buffer
                && (ic_block * nb_ic_blocking < ic || kd_block != kd
                        || kh_block != kh || kw_block != kw
                        || kd_block_pad != kd || kh_block_pad != kh
                        || kw_block_pad != kw);
        if (exec_type == exec_base)
            use_buffer = use_buffer || (maybe_use_buffer && iwp != iw);
        if (is_amx(isa)) use_buffer = use_buffer || (use_M_mask > 0);

        const status_t st = estimate_brgemm_ur(ow_block);
        if (st != status::success) continue;
        os_block = sp_block = ow_block;
        update_blocks();

        eff = est_eff();

        // best_brgb.eff == 0 means "no candidate accepted yet".
        if (eff > best_brgb.eff || best_brgb.eff == 0) best_brgb = *this;
    }
}
1112
calc_blocks()1113 status_t brg_blocking_t::calc_blocks() {
1114 sp = ow;
1115
1116 nb_ic_blocking = 1;
1117 // --- Select kernel blocking ---
1118 // if dst_dt != acc_dt and we need to store intermediate
1119 // results then we need the out buffer
1120 const auto maybe_use_buffer = (dst_dt != acc_dt || with_sum);
1121
1122 std::vector<int> kd_blocks(1), kh_blocks(1);
1123 kd_blocks[0] = kd;
1124 kh_blocks[0] = kh;
1125 if (kd != 1) {
1126 kd_blocks.resize(2);
1127 kd_blocks[1] = 1;
1128 }
1129 if (kh != 1) {
1130 kh_blocks.resize(2);
1131 kh_blocks[1] = 1;
1132 }
1133
1134 const auto thr_eff_threshold = 0.9f;
1135 const auto max_ow_block_thr = utils::saturate(1, ow,
1136 static_cast<int>(div_up(
1137 mb * ngroups * nb_oc * os, thr_eff_threshold * nthr)));
1138
1139 ow_block = os_block = sp_block = -1;
1140 brg_blocking_t best_brgb = *this;
1141 for (const auto &kd_block : kd_blocks) {
1142 for (const auto &kh_block : kh_blocks) {
1143 iterate_ker_block(best_brgb, kd_block, kh_block, maybe_use_buffer,
1144 max_ow_block_thr);
1145 }
1146 }
1147 *this = best_brgb;
1148 if (!IMPLICATION(!is_os_blocking, sp_block > 0))
1149 return status::unimplemented;
1150
1151 if (is_os_blocking) {
1152 ow_block = ow;
1153 os_block = ow * oh_block;
1154 sp_block = os_block;
1155 ow_tail = 0;
1156 } else {
1157 ow_block = os_block = sp_block;
1158 ow_tail = ow % ow_block;
1159 }
1160 update_blocks();
1161 return status::success;
1162 }
1163
fast_check_oc_block_1x1() const1164 bool brg_blocking_t::fast_check_oc_block_1x1() const {
1165 // This function for reducing the number of blocking variants
1166 // TODO: eliminate heuristic in this function
1167 if (is_1x1 && is_amx(isa)) return true;
1168 const auto rnd_oc = rnd_up(oc, 16);
1169 auto res = false;
1170 if (oc_block == 64) {
1171 const auto big_spatial
1172 = od * oh * ow >= 64 * stride_d * stride_h * stride_w;
1173 res = (rnd_oc % oc_block == 0 && big_spatial);
1174 } else if (oc_block == 48) {
1175 const auto oc_block_eff = static_cast<float>(oc) / rnd_up(oc, oc_block);
1176 res = (oc_block_eff >= 0.95);
1177 } else
1178 res = true;
1179
1180 return res;
1181 }
1182
// Estimate the efficiency of the current 1x1-convolution blocking as the
// product of several partial efficiencies (oc-block tail, brgemm microkernel,
// spatial tail, thread balance, ur tail, cache behavior). The returned score
// is only meaningful relative to other candidates — higher is better.
float brg_blocking_t::est_eff_1x1() {
    const auto ocblock = oc_block / 16;

    // Average effective 2x2 tile block for a dimension split into `block`s;
    // with use_ave the tail blocks are averaged in, otherwise block2 is used.
    auto calc_ave_blk = [&](int dim, int block, bool use_ave) -> float {
        const int nb = dim / block;
        constexpr int max_nb = 2; // only consider 2x2 tile blocking
        const int block2 = nstl::min(max_nb, nb);
        const int nb2 = nb / block2;
        const int nb2_tail = nb % block2;
        if (!use_ave) return block2;
        return (float(nb2) * block2 + nb2_tail) / div_up(nb, block2);
    };
    const bool use_ocb_ave = true;
    const auto ocb_ave = calc_ave_blk(oc_block, 16, use_ocb_ave);
    const bool use_spb_ave = false;
    const auto spb_ave = calc_ave_blk(sp_block, ur_block, use_spb_ave);
    const auto M_n_sp_blks = ur_block > 0 ? nstl::max(M, M_tail) / ur_block : 0;
    const auto M_tail_n_sp_blks
            = ur_block_tail > 0 ? M_tail / ur_block_tail : 0;
    // AMX factor: ratio of ideal 16-row tiles to actual sub-block count.
    const auto amx_fac
            = div_up(M + M_tail, 16) / (M_n_sp_blks + M_tail_n_sp_blks);

    const auto brgemm_microkernel_eff = is_amx(isa)
            ? amx_fac * (static_cast<float>(ocb_ave) * spb_ave)
                    / (ocb_ave + spb_ave)
            : (static_cast<float>(ocblock) * ur) / ((ur + ocblock) * max_regs);
    // Penalty for the ur tail inside sp_block.
    const auto ur_eff = static_cast<float>(sp_block) / rnd_up(sp_block, ur);
    const auto brgemm_eff = squeeze_val(ur
                    * (2.f - nstl::min(1.9f, static_cast<float>(ur) / sp_block))
                    / 64,
            0.5f);

    const auto sp_amount = is_os_blocking ? div_up(nb_os, nb_os_blocking)
                                          : nb_od * nb_oh * nb_sp;
    const auto work_amount = mb * ngroups * nb_oc * sp_amount;

    const auto sp_eff = static_cast<float>(sp) / rnd_up(sp, sp_block);
    const auto thr_eff = static_cast<float>(work_amount)
            / utils::rnd_up(work_amount, nthr);
    const auto oc_block_eff = static_cast<float>(oc) / rnd_up(oc, oc_block);

    // Work units per thread.
    const auto job = div_up(work_amount, nthr);

    // Per-thread extents of each loop dimension, derived from the loop order:
    // dim_* is the stride (in work units) of that loop level.
    const auto dim_oc = (loop_order == loop_ndhwgc) ? 1 : sp_amount;
    const auto nb_oc_thr = nstl::min(nb_oc, div_up(job, dim_oc));
    const auto oc_thr = nstl::min(oc, nb_oc_thr * oc_block);
    const auto nsimd_oc_thr = div_up(oc_thr, simd_w);

    const auto dim_sp = (loop_order == loop_ndhwgc) ? ngroups * nb_oc : 1;
    const auto nb_sp_thr = nstl::min(nb_sp, div_up(job, dim_sp));
    const auto sp_thr = nstl::min(sp, nb_sp_thr * sp_block);

    const auto dim_oh = nb_sp * dim_sp;
    const auto nb_oh_thr = nstl::min(nb_oh, div_up(job, dim_oh));
    const auto oh_thr
            = is_os_blocking ? 1 : nstl::min(oh, nb_oh_thr * oh_block);

    const auto dim_od = nb_oh * dim_oh;
    const auto nb_od_thr = nstl::min(nb_od, div_up(job, dim_od));
    const auto od_thr
            = is_os_blocking ? 1 : nstl::min(od, nb_od_thr * od_block);

    auto job_eff = 1.f;
    if (job < nthr) {
        // Fewer jobs than threads: simulate the actual balance211 work
        // partitioning to measure the real load imbalance.
        std::vector<dim_t> thr_jobs(nthr);
        for (int ithr = 0; ithr < nthr; ithr++) {
            thr_jobs[ithr] = 0;
            if (ithr >= work_amount) continue;
            dim_t thr_job = 0;
            int start {0}, end {0};
            balance211(work_amount, nthr, ithr, start, end);
            int n {0}, g {0}, ocb {0}, oss {0}, odp {0}, ohp {0}, spb {0};
            if (loop_order == loop_ndhwgc) {
                if (is_os_blocking)
                    nd_iterator_init(start, n, mb, oss, sp_amount, g, ngroups,
                            ocb, nb_oc);
                else
                    nd_iterator_init(start, n, mb, odp, od, ohp, oh, spb, nb_sp,
                            g, ngroups, ocb, nb_oc);
            } else if (loop_order == loop_ngcdhw) {
                if (is_os_blocking)
                    nd_iterator_init(start, n, mb, g, ngroups, ocb, nb_oc, oss,
                            sp_amount);
                else
                    nd_iterator_init(start, n, mb, g, ngroups, ocb, nb_oc, odp,
                            od, ohp, oh, spb, nb_sp);
            }

            for (auto work = start; work < end; work++) {
                const int ocp = ocb * oc_block;
                const auto oc_sz = nstl::min(oc - ocp, oc_block);
                int sp_sz = 0;
                if (is_os_blocking) {
                    const auto osb_start = oss * nb_os_blocking;
                    const auto osb_range
                            = nstl::min(nb_os - osb_start, nb_os_blocking);
                    // NOTE(review): sp_sz is overwritten each iteration, so
                    // only the last os sub-block contributes — presumably
                    // intended as an accumulation; confirm against upstream.
                    for (int osb = 0; osb < osb_range; osb++) {
                        const int osp = (osb_start + osb) * sp_block;
                        sp_sz = nstl::min(os - osp, sp_block);
                    }
                } else {
                    const int spp = spb * sp_block;
                    sp_sz = nstl::min(sp - spp, sp_block);
                }
                thr_job += sp_sz * oc_sz;

                if (loop_order == loop_ndhwgc) {
                    if (is_os_blocking)
                        nd_iterator_step(
                                n, mb, oss, sp_amount, g, ngroups, ocb, nb_oc);
                    else
                        nd_iterator_step(n, mb, odp, od, ohp, oh, spb, nb_sp, g,
                                ngroups, ocb, nb_oc);
                } else if (loop_order == loop_ngcdhw) {
                    if (is_os_blocking)
                        nd_iterator_step(
                                n, mb, g, ngroups, ocb, nb_oc, oss, sp_amount);
                    else
                        nd_iterator_step(n, mb, g, ngroups, ocb, nb_oc, odp, od,
                                ohp, oh, spb, nb_sp);
                }
            }
            thr_jobs[ithr] = thr_job;
        }

        dim_t max_job = 0;
        dim_t sum_job = 0;
        for (int ithr = 0; ithr < nthr; ithr++) {
            if (thr_jobs[ithr] > max_job) max_job = thr_jobs[ithr];
            sum_job += thr_jobs[ithr];
        }

        // Ratio of average to maximum per-thread work.
        job_eff = max_job == 0 ? 1
                               : static_cast<float>(sum_job) / (max_job * nthr);
    } else {
        job_eff = thr_eff;
    }

    const auto ic_blocking_size = ic_block * nb_ic_blocking;
    const auto oc_blocking_size = oc_block * ic_blocking_size;

    // Build a nested-loop model (innermost first) of src/dst/wei footprints
    // and repeat counts; io_k() later folds it into per-tensor memory costs.
    int l = -1;
    // -- brgemm kernel: loop by simd_w --
    l++;
    loop[l].src.set(ur * simd_w, 1, bcast_simd);
    loop[l].dst.set(0, 1);
    loop[l].wei.set(oc_block, 1);

    // -- brgemm kernel: loop by ur in sp_block --
    l++;
    const auto nb_ur = div_up(sp_block, ur);
    const auto nb_sp_no_tail = sp / sp_block;
    const auto sp_block_tail = sp % sp_block;
    const auto nb_ur_average
            = (nb_sp_no_tail * nb_ur + div_up(sp_block_tail, ur)) / nb_sp;
    loop[l].src.set(ur * rnd_simd(ic_blocking_size), 1);
    loop[l].dst.set(ur * oc_block, 1);
    loop[l].wei.set(oc_blocking_size, is_amx(isa) ? nb_ur_average : nb_ur);
    // -- brgemm kernel: loop by ic_chunks --
    l++;
    const auto ic_chunks = div_up(nb_ic, nb_ic_blocking);
    loop[l].src.set(sp_block * ic_blocking_size, 1);
    loop[l].dst.set(sp_block * oc_block, ic_chunks);
    auto wei_is = oc_blocking_size;
    auto wei_op = ocblock * ic;
    loop[l].wei.set(wei_is, 1);

    if (loop_order == loop_ndhwgc) {
        // -- harness: loop by oc_block --
        l++;
        loop[l].src.set(sp_block * rnd_simd(ic), nb_oc_thr);
        loop[l].dst.set(sp_block * oc_block, 1);
        wei_is = oc_block * ic;
        wei_op = nsimd_oc_thr * ic;
        loop[l].wei.set(wei_is, 1);
    }

    const auto rnd_oc_for_sp
            = simd_w * ((loop_order == loop_ndhwgc) ? nsimd_oc_thr : ocblock);
    if (is_os_blocking) {
        // -- harness: loop by os_blocks --
        l++;
        loop[l].src.set(sp_block * ic_blocking_size, 1);
        loop[l].dst.set(sp_block * rnd_oc_for_sp, 1);
        loop[l].wei.set(wei_op * simd_w, nb_sp_thr);
    } else {
        // -- harness: loop by sp_blocks --
        l++;
        loop[l].src.set(sp_block * ic_blocking_size, 1);
        loop[l].dst.set(sp_block * rnd_oc_for_sp, 1);
        loop[l].wei.set(wei_op * simd_w, nb_sp_thr);
        // -- harness: loop by oh_blocks --
        l++;
        loop[l].src.set(oh_block * sp_thr * rnd_simd(ic_blocking_size), 1);
        loop[l].dst.set(oh_block * sp_thr * rnd_oc_for_sp, 1);
        loop[l].wei.set(wei_op * simd_w, nb_oh_thr);
        // -- harness: loop by od_blocks --
        l++;
        loop[l].src.set(
                od_block * oh_thr * sp_thr * rnd_simd(ic_blocking_size), 1);
        loop[l].dst.set(od_block * oh_thr * sp_thr * rnd_oc_for_sp, 1);
        loop[l].wei.set(wei_op * simd_w, nb_od_thr);
    }

    if (loop_order != loop_ndhwgc) {
        // -- harness: loop by oc_block --
        l++;
        loop[l].src.set(od_thr * oh_thr * rnd_simd(sp_thr * ic_blocking_size),
                nb_oc_thr);
        loop[l].dst.set(oc_block * od_thr * oh_thr * sp_thr, 1);
        loop[l].wei.set(oc_block * ic, 1);
    }

    // -- harness: loop by mb --
    l++;
    const auto mb_thr = nstl::min(mb, div_up(job, sp_amount * ngroups * nb_oc));
    loop[l].src.set(od_thr * oh_thr * sp_thr * rnd_simd(ic_blocking_size), 1);
    loop[l].dst.set(nsimd_oc_thr * simd_w * od_thr * oh_thr * sp_thr, 1);
    loop[l].wei.set(nsimd_oc_thr * ic * simd_w, mb_thr);

    const auto src_op = static_cast<dim_t>(mb_thr) * od_thr
            * (is_os_blocking ? 1 : oh_thr) * sp_thr * ic_blocking_size;
    const auto dst_op = static_cast<dim_t>(mb_thr) * nsimd_oc_thr * od_thr
            * (is_os_blocking ? 1 : oh_thr) * sp_thr;
    wei_op = nsimd_oc_thr * ic;

    // for "real" application set bench_iterations to 1
    const auto iterations = bench_iterations;
    l++;
    loop[l].src.set(src_op, iterations);
    loop[l].dst.set(dst_op * simd_w, iterations);
    loop[l].wei.set(wei_op * simd_w, iterations);

    auto src_mem_k = mem_k;
    auto dst_mem_k = mem_k;
    auto wei_mem_k = mem_k;
    float src_rp = 1;
    float dst_rp = 1;
    float wei_rp = 1;

    // Walk the loop nest outermost-to-innermost accumulating memory-cost
    // coefficients (io_k) and total repeat factors per tensor.
    for (auto il = l; il >= 0; il--) {
        src_mem_k = io_k(loop[il], loop[il].src, src_mem_k, true, false);
        dst_mem_k = io_k(loop[il], loop[il].dst, dst_mem_k, false, false);
        wei_mem_k = io_k(loop[il], loop[il].wei, wei_mem_k, false, true);
        src_rp *= loop[il].src.repeatn;
        dst_rp *= loop[il].dst.repeatn;
        wei_rp *= loop[il].wei.repeatn;
    }
    const auto src_ops = (src_op * src_rp) / iterations;
    const auto dst_ops = (dst_op * dst_rp) / iterations;
    const auto wei_ops = (wei_op * wei_rp) / iterations;

    const auto src_cost = src_mem_k * src_ops;
    const auto dst_cost = dst_mem_k * dst_ops;
    const auto wei_cost = wei_mem_k * wei_ops;
    // Fixed per-kernel-call overhead model.
    const auto call_kernel_cost = 1000.f * job * ic_chunks;

    const auto up_sp_size = is_os_blocking ? 1 : od * oh;

    // Useful work divided by total modeled memory + call cost.
    const auto cache_eff = (static_cast<dim_t>(mb) * up_sp_size * sp * ic * oc)
            / (nthr * (src_cost + dst_cost + wei_cost + call_kernel_cost));

    const auto res_eff = oc_block_eff * brgemm_microkernel_eff * sp_eff
            * job_eff * ur_eff * cache_eff * brgemm_eff;
    return res_eff;
}
1449
// Select spatial blocking for the 1x1-convolution case. Chooses between
// os-blocking (fused od*oh*ow spatial space) and plain ow-blocking, then
// scans candidate block sizes and keeps the one with the best est_eff_1x1().
void brg_blocking_t::calc_blocks_1x1() {
    // os-blocking is valid only when strides don't reorder the input rows.
    const bool is_os_blocking_ok
            = utils::everyone_is(1, stride_d, stride_h) && iw % stride_w == 0;
    const bool is_ic_zero_padded = ic != ic_without_padding;
    // rtus = reduction of the spatial dims via input copy (AMX fallback or
    // zero-padded ic).
    is_rtus = is_ic_zero_padded || (!is_os_blocking_ok && is_amx(isa));
    if (is_os_blocking_ok || is_rtus) {
        sp = os;
        is_os_blocking = true;
    } else {
        sp = ow;
        is_os_blocking = false;
    }

    // 1x1: no kernel blocking by definition.
    od_block = 1;
    oh_block = 1;
    kd_block = kh_block = kw_block = 1;
    kd_block_pad = kh_block_pad = kw_block_pad = 1;
    nb_ic_blocking = 1;

    const auto thr_eff_threshold = 0.9f;

    const auto max_sp_block_L2 = os;
    // TODO: nb_os_blocking always is 1 for now. Update this code
    nb_os_blocking = 1;
    int start_sp_block = 0;

    if (is_os_blocking) {
        ow_block = 0;

        // Upper bound from threading: enough jobs to keep all threads busy.
        const auto max_os_block_thr = nstl::max(div_up(2048, oc_block),
                static_cast<int>(div_up(mb * ngroups * os, nthr)));
        const auto max_os_block_L2 = max_sp_block_L2;

        // Upper bound to avoid cache-line aliasing when dst strides are
        // multiples of the 4K page: halve the limit per power-of-two factor.
        auto max_os_block_aliasing = 1000000 / nthr;
        if ((oc_without_padding * os * dst_dsz) % 4096 == 0) {
            max_os_block_aliasing /= 1;
            for (auto cur_oc = oc_without_padding;
                    max_os_block_aliasing * dst_dsz > 400 && cur_oc % 2 == 0
                    && cur_oc * os * dst_dsz >= 4096;
                    cur_oc /= 2) {
                max_os_block_aliasing /= 2;
            }
            // Prefer an odd limit to break power-of-two strides.
            max_os_block_aliasing += max_os_block_aliasing % 2 ? 0 : 1;
        }
        max_os_block_aliasing
                = nstl::min(div_up(1001, dst_dsz), max_os_block_aliasing);

        start_sp_block = utils::saturate(1, os,
                nstl::min(nstl::min(max_os_block_thr, max_os_block_L2),
                        max_os_block_aliasing));

    } else {
        os_block = 0;

        const auto max_ow_block_thr = utils::saturate(1, ow,
                static_cast<int>(div_up(
                        mb * ngroups * nb_oc * os, thr_eff_threshold * nthr)));
        const auto max_ow_block_L2 = max_sp_block_L2;

        start_sp_block = utils::saturate(
                1, ow, nstl::min(max_ow_block_thr, max_ow_block_L2));
    }
    // -1 marks "no valid blocking found yet".
    os_block = ow_block = sp_block = -1;
    brg_blocking_t best_brgb = *this;

    auto prev_spb = 0;
    // Enumerate distinct block sizes spb = ceil(sp / ns) from large to small.
    for (auto ns = 1; ns <= sp; ns++) {
        auto spb = div_up(sp, ns);
        if (is_amx(isa)) {
            // Snap spb down to a multiple of the tile width w in [min,16]
            // that minimizes the tail (additive inverse modulo).
            auto min_dis = 16;
            auto best_w = 16;
            const auto max_tile_width = nstl::min(16, sp);
            const auto min_tile_width = utils::saturate(1, 11, sp / 2);
            if (spb < min_tile_width) break;
            for (auto w = max_tile_width; w >= min_tile_width; w--) {
                const auto dis = nstl::additive_inverse_modulo(spb, w);
                if (dis < min_dis) {
                    min_dis = dis;
                    best_w = w;
                }
            }
            spb = nstl::min(sp, rnd_dn(spb, best_w));
            if (spb == prev_spb) continue;
        }
        if (spb == prev_spb || spb > start_sp_block) continue;
        prev_spb = spb;
        os_block = ow_block = sp_block = spb;
        select_ic_block();
        const status_t st = estimate_brgemm_ur(spb);
        if (st != status::success) continue;
        update_blocks();

        // Buffer needed when accumulation is partial over ic chunks.
        use_buffer = (dst_dt != acc_dt || with_sum)
                && (ic_block * nb_ic_blocking < ic);

        eff = est_eff_1x1();
        if (eff > best_brgb.eff || best_brgb.eff == 0) best_brgb = *this;
    }
    *this = best_brgb;
    os_block = ow_block = sp_block;
    update_blocks();
}
1552
// Fill jcp with problem geometry, data types, and padding derived from the
// convolution descriptor and memory descriptors, and run the early
// unsupported-case checks shared by all brgemm convolution variants.
// Returns status::unimplemented for shapes/configs this implementation
// does not handle; blocking selection happens later in init_conf().
status_t init_jcp(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
        const convolution_desc_t &cd, memory_desc_t &src_md,
        memory_desc_t &weights_md, memory_desc_t &dst_md,
        memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads) {
    using namespace prop_kind;

    brg_blocking_t::L1 = platform::get_per_core_cache_size(1);
    brg_blocking_t::L2 = platform::get_per_core_cache_size(2);
    // NOTE(review): L3 is initialized from the level-2 cache size — possibly
    // a deliberate per-core proxy for shared L3, but looks like a typo for
    // get_per_core_cache_size(3); confirm intent.
    brg_blocking_t::L3 = platform::get_per_core_cache_size(2);

    if (!mayiuse(avx512_core)) return status::unimplemented;

    const memory_desc_wrapper src_d(&src_md);
    const memory_desc_wrapper weights_d(&weights_md);
    const memory_desc_wrapper dst_d(&dst_md);
    const memory_desc_wrapper bias_d(&bias_md);

    // Grouped weights carry an extra leading groups dimension.
    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    int ndims = src_d.ndims();

    jcp = zero<decltype(jcp)>();
    jcp.isa = isa;
    jcp.ndims = ndims;
    jcp.prop_kind = cd.prop_kind;
    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];
    jcp.oc_without_padding = dst_d.dims()[1];
    jcp.oc = jcp.oc_without_padding / jcp.ngroups;
    jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups;
    jcp.ic = jcp.ic_without_padding;
    // Spatial dims: 3D (ndims==5) has depth; 1D (ndims==3) has no height.
    jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2];
    jcp.ow = dst_d.dims()[ndims - 1];
    jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];
    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
    jcp.l_pad = cd.padding[0][ndims - 3];
    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
    jcp.stride_w = cd.strides[ndims - 3];

    jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
    jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
    jcp.dilate_w = cd.dilates[ndims - 3];

    jcp.os = jcp.od * jcp.oh * jcp.ow;

    // Effective kernel extents including dilation.
    jcp.ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
    jcp.ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
    jcp.ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);

    // Right/bottom/back padding implied by output size and strides.
    jcp.back_pad = calculate_end_padding(
            jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, jcp.ext_kd);
    jcp.b_pad = calculate_end_padding(
            jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, jcp.ext_kh);
    jcp.r_pad = calculate_end_padding(
            jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, jcp.ext_kw);

    jcp.is_1x1 = jcp.f_pad <= 0 && jcp.back_pad <= 0 && jcp.t_pad <= 0
            && jcp.b_pad <= 0 && jcp.l_pad <= 0 && jcp.r_pad <= 0
            && utils::everyone_is(1, jcp.kd, jcp.kh, jcp.kw);

    jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;

    jcp.src_dt = cd.src_desc.data_type;
    jcp.dst_dt = cd.dst_desc.data_type;
    jcp.wei_dt = cd.weights_desc.data_type;
    jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;

    // vnni granularity: f32 -> 1, bf16 -> 2, int8 -> 4 elements per dword.
    brg_blocking_t::last_ic_block_size
            = (jcp.wei_dt == f32) ? 1 : ((jcp.wei_dt == bf16) ? 2 : 4);

    // TODO: optimize depthwise convolutions (for now direct approach is faster)
    const bool is_depthwise
            = with_groups && jcp.ngroups > 1 && everyone_is(1, jcp.ic, jcp.oc);
    if (is_depthwise) return status::unimplemented;

    // TODO: optimize grouped convolutions with small ic
    const bool is_grouped_small_ic = with_groups && jcp.ngroups > 1
            && jcp.ic <= 16
            // already optimized for amx 1x1 convs
            && IMPLICATION(is_amx(jcp.isa), !jcp.is_1x1);
    if (is_grouped_small_ic) return status::unimplemented;

    // TODO: support s8 in non-amx brgemm convolutions
    if (!IMPLICATION(jcp.src_dt == s8, is_amx(jcp.isa)))
        return status::unimplemented;

    if (!IMPLICATION(jcp.wei_dt == s8, mayiuse(avx512_core_vnni)))
        return status::unimplemented;
    if (!IMPLICATION(jcp.wei_dt == bf16, mayiuse(avx512_core_bf16)))
        return status::unimplemented;

    // Accumulator type is fixed by the source type.
    if (one_of(jcp.src_dt, u8, s8)) {
        jcp.acc_dt = s32;
    } else if (one_of(jcp.src_dt, f32, bf16)) {
        jcp.acc_dt = f32;
    } else
        return status::unimplemented;

    jcp.src_dsz = types::data_type_size(jcp.src_dt);
    jcp.wei_dsz = types::data_type_size(jcp.wei_dt);
    jcp.dst_dsz = types::data_type_size(jcp.dst_dt);
    jcp.acc_dsz = types::data_type_size(jcp.acc_dt);
    jcp.bia_dsz = jcp.with_bias ? types::data_type_size(jcp.bia_dt) : 0;

    if (!post_ops_ok(jcp, attr, dst_d)) return status::unimplemented;

    // simd_w is in elements of the source type (vlen bytes / element size).
    jcp.simd_w = cpu_isa_traits<avx512_common>::vlen / jcp.src_dsz;
    jcp.amx_h = 16;
    jcp.amx_w = 64 / jcp.src_dsz;

    const auto &p = attr.post_ops_;
    jcp.with_sum = p.find(primitive_kind::sum) != -1;
    const int eltwise_ind = p.find(primitive_kind::eltwise);
    jcp.with_eltwise = eltwise_ind != -1;

    const int binary_ind = p.find(primitive_kind::binary);
    jcp.with_binary = binary_ind != -1;

    if (jcp.with_bias) {
        if (bias_d.format_kind() == format_kind::any)
            CHECK(memory_desc_init_by_tag(bias_md, x));
    }

    // Defaults; blocking selection may override these in init_conf().
    jcp.nthr = nthreads;
    jcp.kh_sets = 1;
    jcp.kw_sets = 1;
    jcp.copy_block_only = false;
    jcp.amx_tile_load_xx = false;
    jcp.use_M_mask = 0;
    jcp.is_os_blocking = false;
    jcp.oskip = 0;
    jcp.use_uker = false;
    jcp.use_interleave_stores = false;
    jcp.brgemm_bd_loop_innermost = false;

    // fast check data layout before spending time for blocking selection
    format_tag_t src_tag = pick(jcp.ndims - 3, nwc, nhwc, ndhwc);
    const bool any_eligible = (jcp.prop_kind == prop_kind::forward_inference
            || jcp.wei_dt == data_type::s8 || is_amx(jcp.isa));
    CHECK(init_tag(jcp.src_tag, src_md, src_d, src_tag, any_eligible));

    const auto ic_padded_block = 16 * brg_blocking_t::last_ic_block_size;
    jcp.is_ic_padded = !jcp.is_1x1 && one_of(jcp.wei_dt, bf16, s8)
            && jcp.ic * jcp.kw_sets > ic_padded_block && is_amx(isa);

    return status::success;
}
1707
init_conf(jit_brgemm_conv_conf_t & jcp,cpu_isa_t isa,const convolution_desc_t & cd,memory_desc_t & src_md,memory_desc_t & weights_md,memory_desc_t & dst_md,memory_desc_t & bias_md,primitive_attr_t & attr,int nthreads)1708 status_t init_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
1709 const convolution_desc_t &cd, memory_desc_t &src_md,
1710 memory_desc_t &weights_md, memory_desc_t &dst_md,
1711 memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads) {
1712
1713 using namespace prop_kind;
1714 if (!mayiuse(isa)) return status::unimplemented;
1715
1716 CHECK(init_jcp(
1717 jcp, isa, cd, src_md, weights_md, dst_md, bias_md, attr, nthreads));
1718
1719 if (jcp.is_1x1) return status::unimplemented;
1720 // TODO: check these restrictions
1721 if (is_amx(isa)) {
1722 // disabled for first convolutions excepting 3d
1723 const bool is_3d = jcp.ndims == 5;
1724 if (jcp.ic <= 4 && !is_3d) return status::unimplemented;
1725
1726 if (jcp.f_pad >= jcp.kd || jcp.t_pad >= jcp.kh || jcp.r_pad >= jcp.kw)
1727 return status::unimplemented;
1728 if (jcp.dilate_d > 0 || jcp.dilate_h > 0 || jcp.dilate_w > 0)
1729 return status::unimplemented;
1730 }
1731
1732 const memory_desc_wrapper src_d(&src_md);
1733 const memory_desc_wrapper weights_d(&weights_md);
1734 const memory_desc_wrapper dst_d(&dst_md);
1735 const memory_desc_wrapper bias_d(&bias_md);
1736
1737 jcp.idp = jcp.id + jcp.f_pad + jcp.back_pad;
1738 jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
1739 jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
1740
1741 using namespace data_type;
1742 // ======================= blocking =================================
1743
1744 auto bcast_amount
1745 = static_cast<size_t>(jcp.id) * jcp.ih * jcp.iw * jcp.src_dsz;
1746 auto wei_amount = static_cast<size_t>(jcp.oc) * jcp.kd * jcp.kh * jcp.kw
1747 * jcp.wei_dsz;
1748
1749 jcp.loop_order = (bcast_amount < wei_amount) ? loop_ngcdhw : loop_ndhwgc;
1750
1751 const int min_oc_block = 16;
1752
1753 int selected_ur = 0;
1754 MAYBE_UNUSED(selected_ur);
1755
1756 auto try_exec_type = [&]() {
1757 brg_blocking_t best_brgb = zero<decltype(best_brgb)>();
1758 best_brgb.oc_block = min_oc_block;
1759 brg_blocking_t cur_brgb = zero<decltype(best_brgb)>();
1760 cur_brgb.get_from_jcp(jcp);
1761 auto start_ocb = (is_amx(isa) && jcp.is_os_blocking) ? 2 : 4;
1762 if (jcp.wei_plain)
1763 start_ocb = nstl::min(jcp.ic > 128 ? (jcp.ic > 256 ? 8 : 16) : 32,
1764 div_up(jcp.oc, 16));
1765 start_ocb = nstl::min(div_up(jcp.oc, 16), start_ocb);
1766
1767 auto finish_ocb = 1;
1768 for (auto ocb = start_ocb; ocb >= finish_ocb; ocb--) {
1769 cur_brgb.oc_block = ocb * 16;
1770 cur_brgb.nb_oc = utils::div_up(jcp.oc, cur_brgb.oc_block);
1771 if (!cur_brgb.fast_check_oc_block()) continue;
1772
1773 const status_t blocking_ok = cur_brgb.calc_blocks();
1774 if (blocking_ok != status::success) continue;
1775
1776 const status_t st = cur_brgb.get_brgemm_ur(&attr, dst_md);
1777 if (st != status::success) continue;
1778 cur_brgb.eff = cur_brgb.est_eff();
1779 if (cur_brgb.eff > best_brgb.eff) best_brgb = cur_brgb;
1780 }
1781 if (best_brgb.oc_block == 0 || best_brgb.ic_block == 0
1782 || best_brgb.ow_block == 0)
1783 return false;
1784 best_brgb.save_to_jcp(jcp);
1785 selected_ur = best_brgb.ur;
1786 return true;
1787 };
1788
1789 //-----------------------------------------------------------------------
1790
1791 jcp.exec_type = exec_base;
1792 jcp.brg_type = brgemm_addr; // TODO: Choose right type of BRGEMM
1793
1794 bool try_exec_vpad = false;
1795 bool try_exec_trans = false;
1796 bool try_exec_base = true;
1797
1798 if (!is_amx(isa) && div_up(jcp.l_pad, jcp.stride_w) < jcp.kw
1799 && div_up(jcp.r_pad, jcp.stride_w) < jcp.kw) {
1800 try_exec_vpad = true;
1801 }
1802
1803 const auto ic_padded_block = 16 * brg_blocking_t::last_ic_block_size;
1804 // TODO: remove this restriction
1805 const auto w_padding = jcp.l_pad > 0 || jcp.r_pad > 0;
1806 if (is_amx(isa)) {
1807 try_exec_base = !w_padding
1808 && IMPLICATION(jcp.ic <= ic_padded_block,
1809 jcp.ic % brg_blocking_t::last_ic_block_size == 0)
1810 && IMPLICATION(
1811 jcp.ic > ic_padded_block, jcp.ic % ic_padded_block == 0)
1812 && jcp.ow > 50 /*TODO: reinvestigate this heuristic */;
1813 try_exec_trans = !try_exec_base;
1814 }
1815
1816 bool must_exec_vpad = false;
1817
1818 // TODO: in future use (kd/kh/kw) and (kd/kh/kw)_pad blocks for more
1819 // precise calculation of jcp.max_batch
1820 jcp.max_batch = jcp.kd * jcp.kh * jcp.kw;
1821
1822 //TODO: check wei plain
1823 jcp.wei_plain = false;
1824 jcp.wei_plain = jcp.exec_type == exec_vpad ? jcp.wei_plain : false;
1825
1826 bool try_exec_type_res = false;
1827
1828 if (try_exec_vpad) {
1829 jcp.exec_type = exec_vpad;
1830 try_exec_type_res = try_exec_type();
1831 // to avoid case when both top and bottom virtual padding are non-zero
1832 // TODO: remove this restriction
1833 const auto iw_block = (jcp.ow_block - 1) * jcp.stride_w + 1;
1834 if (!must_exec_vpad && (iw_block > jcp.iw)) try_exec_type_res = false;
1835 }
1836 if (try_exec_type_res == false && try_exec_trans) {
1837 jcp.exec_type = exec_trans;
1838
1839 // try loop_ndhwgc always for exec_trans
1840 jcp.loop_order = loop_ndhwgc;
1841
        // we read input block only once for loop_ndhwgc, so we don't need to
        // keep it in memory
1844 if (jcp.loop_order == loop_ndhwgc) { jcp.copy_block_only = true; }
1845
1846 jcp.is_ic_padded = one_of(jcp.wei_dt, bf16, s8)
1847 && jcp.ic * jcp.kw_sets > ic_padded_block;
1848
1849 if (is_amx(isa) && (/* heuristic*/ jcp.kw_sets == 1 && jcp.ow < 256)) {
1850 jcp.is_os_blocking = jcp.f_pad < jcp.kd && jcp.back_pad < jcp.kd
1851 && jcp.t_pad < jcp.kh && jcp.b_pad < jcp.kh
1852 && jcp.r_pad < jcp.kw && jcp.l_pad < jcp.kw;
1853 jcp.use_M_mask = jcp.is_os_blocking ? 2 : 0;
1854 jcp.use_uker = true;
1855 jcp.use_interleave_stores = true;
1856 // assuming 2x2 decomposition in amx brgemm kernel
1857 // and overlap of input by kw
1858 const auto bd_blocking = 2 * jcp.amx_h;
1859 const auto ld_blocking = 2 * 16;
1860 const auto A_ds
1861 = jcp.src_dsz * bd_blocking * jcp.ic * jcp.kd * jcp.kh;
1862 const auto B_ds = jcp.wei_dsz * ld_blocking * jcp.ic * jcp.kd
1863 * jcp.kh * jcp.kw;
1864 const auto C_ds = jcp.acc_dsz * bd_blocking * ld_blocking;
1865 if (A_ds + B_ds + C_ds > brg_blocking_t::L1)
1866 jcp.amx_tile_load_xx = true;
1867 }
1868
1869 try_exec_type_res = try_exec_type();
1870 }
1871 if (try_exec_base && try_exec_type_res == false) {
1872 jcp.exec_type = exec_base;
1873 try_exec_type_res = try_exec_type();
1874 }
1875
1876 if (try_exec_type_res == false) return status::unimplemented;
1877
1878 // ============ end blocking ===========================================
1879 if (jcp.exec_type == exec_vpad)
1880 jcp.max_vpad = nstl::max(jcp.l_pad, jcp.r_pad);
1881 else
1882 jcp.max_vpad = 0;
1883
1884 if (jcp.ow_block == 0 || jcp.ic_block == 0 || jcp.oc_block == 0)
1885 return status::unimplemented;
1886
1887 jcp.gemm_batch_size = jcp.nb_ic_blocking
1888 * nstl::max(jcp.kd_block * jcp.kh_block * jcp.kw_block,
1889 jcp.kd_block_pad * jcp.kh_block_pad * jcp.kw_block_pad);
1890 // to avoid cache concurrent write access from different threads
1891 size_t sc_size = sizeof(brgemm_batch_element_t);
1892 jcp.adjusted_batch_size
1893 = div_up(rnd_up(jcp.gemm_batch_size * sc_size, 4096), sc_size);
1894
1895 CHECK(pick_tags(jcp, src_md, weights_md, dst_md, bias_md));
1896 CHECK(attr.set_default_formats(&dst_md));
1897
1898 const auto &oscales = attr.output_scales_;
1899 jcp.is_oc_scale = oscales.mask_ == 1 << 1;
1900
1901 // only common and per-oc-channel scales are supported
1902 const bool oscales_ok = one_of(oscales.mask_, 0, 1 << 1);
1903 if (!oscales_ok) return status::unimplemented;
1904
1905 jcp.buffer_size = jcp.LDC * jcp.M;
1906
1907 jcp.nb_od = div_up(jcp.od, jcp.od_block);
1908 jcp.nb_oh = div_up(jcp.oh, jcp.oh_block);
1909
1910 if (jcp.exec_type == exec_trans) {
1911 // TODO: this is rough estimation of buffer for transpose input
1912 dim_t ds = jcp.copy_block_only
1913 ? (brg_blocking_t::get_inp_size(jcp.idp, jcp.od_block, jcp.kd,
1914 jcp.stride_d, jcp.dilate_d)
1915 + nstl::max(0, jcp.f_pad) + nstl::max(0, jcp.back_pad))
1916 : jcp.idp;
1917 dim_t hs = jcp.copy_block_only
1918 ? (brg_blocking_t::get_inp_size(jcp.ihp, jcp.oh_block, jcp.kh,
1919 jcp.stride_h, jcp.dilate_h)
1920 + nstl::max(0, jcp.t_pad) + nstl::max(0, jcp.b_pad))
1921 : jcp.ihp;
1922 if (jcp.is_os_blocking)
1923 hs = div_up(rnd_up(hs * jcp.iwp, jcp.brgM), jcp.iwp);
1924
1925 jcp.inp_buffer_size = rnd_up(ds * hs * jcp.iwp * jcp.ngroups * jcp.nb_ic
1926 * jcp.ic_block * jcp.kh_sets * jcp.kw_sets,
1927 4096);
1928 jcp.inp_buffer_mask_size = rnd_up(static_cast<dim_t>(jcp.nb_od)
1929 * jcp.nb_oh * jcp.nb_ow * jcp.ngroups * jcp.nb_ic,
1930 4096);
1931 }
1932
1933 return status::success;
1934 }
1935
init_1x1_conf(jit_brgemm_conv_conf_t & jcp,cpu_isa_t isa,const convolution_desc_t & cd,memory_desc_t & src_md,memory_desc_t & weights_md,memory_desc_t & dst_md,memory_desc_t & bias_md,primitive_attr_t & attr,int nthreads)1936 status_t init_1x1_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
1937 const convolution_desc_t &cd, memory_desc_t &src_md,
1938 memory_desc_t &weights_md, memory_desc_t &dst_md,
1939 memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads) {
1940
1941 using namespace prop_kind;
1942 if (!mayiuse(isa)) return status::unimplemented;
1943
1944 CHECK(init_jcp(
1945 jcp, isa, cd, src_md, weights_md, dst_md, bias_md, attr, nthreads));
1946
1947 // Maybe fall back to direct jit impl for small batch sizes
1948 // TODO: eliminate performance degradation and remove this constraint
1949 if (is_amx(isa) && jcp.mb == 1) return status::unimplemented;
1950
1951 const memory_desc_wrapper src_d(&src_md);
1952 const memory_desc_wrapper weights_d(&weights_md);
1953 const memory_desc_wrapper dst_d(&dst_md);
1954 const memory_desc_wrapper bias_d(&bias_md);
1955
1956 if (!jcp.is_1x1) return status::unimplemented;
1957
1958 using namespace data_type;
1959 // ===================== blocking =================================
1960
1961 auto bcast_amount
1962 = static_cast<size_t>(jcp.id) * jcp.ih * jcp.iw * jcp.src_dsz;
1963 auto wei_amount = static_cast<size_t>(jcp.oc) * jcp.wei_dsz;
1964
1965 jcp.loop_order = (bcast_amount < wei_amount) ? loop_ngcdhw : loop_ndhwgc;
1966
1967 if (is_amx(isa)) {
1968 // round up ic if needed
1969 const int vnni_width = brg_blocking_t::last_ic_block_size;
1970 const int n_vnni_blocks = utils::div_up(jcp.ic, vnni_width);
1971 const int ic_block = nstl::min(16, n_vnni_blocks) * vnni_width;
1972 const bool do_zeropad = jcp.ic % vnni_width != 0 || jcp.ic > ic_block;
1973 if (do_zeropad) jcp.ic = utils::rnd_up(jcp.ic, ic_block);
1974 const auto ic_padded_block = 16 * vnni_width;
1975 jcp.is_ic_padded = jcp.ic > ic_padded_block;
1976
1977 // try to choose optimal loop order
1978 auto wei_size = (size_t)jcp.oc * jcp.ic * jcp.wei_dsz;
1979 auto max_size = 0.75f * brg_blocking_t::L2;
1980 jcp.loop_order = max_size < wei_size ? loop_ngcdhw : loop_ndhwgc;
1981 }
1982
1983 const auto min_oc_block = 16;
1984
1985 jcp.brg_type = brgemm_addr; // TODO: Choose right type of BRGEMM
1986
1987 // max_batch is 1 and max_vpad is 0 for 1x1 convolutions
1988 jcp.max_batch = 1;
1989 jcp.max_vpad = 0;
1990
1991 jcp.wei_plain = false;
1992
1993 brg_blocking_t best_brgb = zero<decltype(best_brgb)>();
1994 best_brgb.oc_block = min_oc_block;
1995 brg_blocking_t cur_brgb = zero<decltype(cur_brgb)>();
1996 cur_brgb.get_from_jcp(jcp);
1997 auto start_ocb = 4;
1998 if (jcp.wei_plain)
1999 start_ocb = nstl::min(jcp.ic > 128 ? (jcp.ic > 256 ? 8 : 16) : 32,
2000 div_up(jcp.oc, 16));
2001 start_ocb = nstl::min(div_up(jcp.oc, 16), start_ocb);
2002
2003 auto finish_ocb = 1;
2004 for (auto ocb = start_ocb; ocb >= finish_ocb; ocb--) {
2005 cur_brgb.oc_block = ocb * min_oc_block;
2006 cur_brgb.nb_oc = utils::div_up(jcp.oc, cur_brgb.oc_block);
2007
2008 if (!cur_brgb.fast_check_oc_block_1x1()) continue;
2009
2010 cur_brgb.calc_blocks_1x1();
2011 const status_t st = cur_brgb.get_brgemm_ur(&attr, dst_md);
2012 if (st != status::success) continue;
2013 cur_brgb.eff = cur_brgb.est_eff_1x1();
2014 if (cur_brgb.eff > best_brgb.eff) best_brgb = cur_brgb;
2015 }
2016 best_brgb.save_to_jcp(jcp);
2017
2018 // =============== end blocking =================================
2019 jcp.brg_stride_a = jcp.ic_block * jcp.src_dsz;
2020 jcp.brg_stride_b = jcp.ic_block * jcp.oc * jcp.wei_dsz;
2021
2022 if (jcp.ic_block == 0 || jcp.oc_block == 0) return status::unimplemented;
2023
2024 // Configure matrix sizes
2025
2026 if (best_brgb.is_os_blocking) {
2027 if (jcp.os_block == 0) return status::unimplemented;
2028 jcp.M = jcp.brgM = jcp.os_block;
2029 jcp.M_tail = jcp.brgM_tail = jcp.os % jcp.os_block;
2030 } else {
2031 if (jcp.ow_block == 0) return status::unimplemented;
2032 jcp.M = jcp.brgM = jcp.ow_block;
2033 jcp.M_tail = jcp.brgM_tail = jcp.ow % jcp.ow_block;
2034 }
2035
2036 jcp.K = jcp.ic >= jcp.ic_block ? jcp.ic_block : 0;
2037 jcp.N = jcp.oc >= jcp.oc_block ? jcp.oc_block : 0;
2038 jcp.N_tail = jcp.oc % jcp.oc_block;
2039 jcp.K_tail = jcp.ic % jcp.ic_block;
2040
2041 jcp.gemm_batch_size = jcp.nb_ic_blocking;
2042 // to avoid cache concurrent access from different threads
2043 size_t sc_size = sizeof(brgemm_batch_element_t);
2044 jcp.adjusted_batch_size
2045 = div_up(rnd_up(jcp.gemm_batch_size * sc_size, 4096), sc_size);
2046
2047 CHECK(pick_tags(jcp, src_md, weights_md, dst_md, bias_md));
2048 CHECK(attr.set_default_formats(&dst_md));
2049
2050 const auto &oscales = attr.output_scales_;
2051 jcp.is_oc_scale = oscales.mask_ == 1 << 1;
2052
2053 // only common and per-oc-channel scales are supported
2054 const bool oscales_ok = one_of(oscales.mask_, 0, 1 << 1);
2055 if (!oscales_ok) return status::unimplemented;
2056
2057 // no inp buffer or brgemm_vpad for 1x1
2058 constexpr int align_size = platform::get_cache_line_size();
2059 jcp.exec_type = jcp.is_rtus ? exec_trans : exec_base;
2060 jcp.inp_buffer_size
2061 = jcp.is_rtus ? rnd_up(jcp.LDA * jcp.os, align_size) : 0;
2062 jcp.inp_buffer_mask_size = jcp.is_rtus
2063 ? rnd_up(div_up(jcp.nb_ic, jcp.nb_ic_blocking) * jcp.nb_os,
2064 align_size)
2065 : 0;
2066 jcp.buffer_size = jcp.LDC * jcp.M;
2067
2068 #if 0
2069 printf("@@@ debug: nthreads = %d, IC = %d, OC = %d, ID = %d, IH = %d, IW = "
2070 "%d, OD = %d, OH = %d, OW = %d, KD = %d, "
2071 "KH = %d, KW = %d\n",
2072 nthreads, jcp.ic, jcp.oc, jcp.id, jcp.ih, jcp.iw, jcp.od, jcp.oh,
2073 jcp.ow, jcp.kd, jcp.kh, jcp.kw);
2074
2075 printf("@@@ debug: blocking: ic_block = %d, nb_ic_blocking = %d, oc_block "
2076 "= %d, os_block = %d, ow_block = %d, nb_os_blocking = %d, "
2077 "loop_order = %d, "
2078 "wei_plain = %d, wei_tag = %d \n",
2079 jcp.ic_block, jcp.nb_ic_blocking, jcp.oc_block, jcp.os_block,
2080 jcp.ow_block, jcp.nb_os_blocking, jcp.loop_order, jcp.wei_plain,
2081 jcp.wei_tag);
2082
2083 printf("@@@ debug: Matrix configuration: M = %d, N = %d, K = "
2084 "%d, M_tail = %d, N_tail = %d, K_tail = %d, LDA = %d, LDB = %d, LDC "
2085 "= %d ur = %d\n",
2086 jcp.M, jcp.N, jcp.K, jcp.M_tail, jcp.N_tail, jcp.K_tail, jcp.LDA,
2087 jcp.LDB, jcp.LDC, best_brgb.ur);
2088 printf("@@@ debug: brg_type = %d use_buffer = %d \n", jcp.brg_type,
2089 jcp.use_buffer);
2090 fflush(nullptr);
2091 #endif
2092 return status::success;
2093 }
2094
init_scratchpad(memory_tracking::registrar_t & scratchpad,const jit_brgemm_conv_conf_t & jcp)2095 void init_scratchpad(memory_tracking::registrar_t &scratchpad,
2096 const jit_brgemm_conv_conf_t &jcp) {
2097 if (jcp.brg_type == brgemm_addr || jcp.brg_type == brgemm_offs
2098 || (jcp.brg_type == brgemm_strd && jcp.exec_type == exec_vpad))
2099 scratchpad.book(key_brgemm_primitive_batch,
2100 static_cast<size_t>(jcp.nthr) * jcp.adjusted_batch_size,
2101 sizeof(brgemm_batch_element_t), 64);
2102 if (jcp.exec_type == exec_trans) {
2103 size_t inp_buffer_size
2104 = static_cast<size_t>(jcp.nthr) * jcp.inp_buffer_size;
2105 scratchpad.book(
2106 key_conv_brgemm_inp_buffer, inp_buffer_size, jcp.src_dsz);
2107 size_t inp_buffer_mask_size
2108 = static_cast<size_t>(jcp.nthr) * jcp.inp_buffer_mask_size;
2109 scratchpad.book(key_conv_brgemm_inp_buffer_mask, inp_buffer_mask_size,
2110 sizeof(uint8_t));
2111 }
2112 if (jcp.use_buffer) {
2113 scratchpad.book(key_brgemm_primitive_buffer, jcp.nthr * jcp.buffer_size,
2114 jcp.acc_dsz);
2115 }
2116 if (is_amx(jcp.isa)) {
2117 scratchpad.book(
2118 key_conv_amx_tile_buffer, jcp.nthr * 4 * 1024, sizeof(char));
2119 }
2120 }
2121
2122 } // namespace brgemm_convolution_utils
2123
2124 } // namespace x64
2125 } // namespace cpu
2126 } // namespace impl
2127 } // namespace dnnl
2128