/*******************************************************************************
* Copyright 2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_JIT_CONV_CONFIG_HPP
#define GPU_JIT_CONV_CONFIG_HPP

#include <iostream>
#include <sstream>

#include "common/c_types_map.hpp"
#include "common/convolution_pd.hpp"
#include "common/math_utils.hpp"
#include "common/memory_desc_wrapper.hpp"
#include "common/type_helpers.hpp"
#include "gpu/compute/compute.hpp"
#include "gpu/compute/compute_engine.hpp"
#include "gpu/jit/conv/fma_support.hpp"
#include "gpu/jit/conv/tensor.hpp"
#include "gpu/jit/conv/utils.hpp"
#include "gpu/jit/jit_eltwise_injector.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

// Description of the convolution problem.
class conv_problem_t {
public:
    conv_problem_t() = default;

    status_t init(convolution_pd_t *conv_pd) {
        if (conv_pd->has_zero_dim_memory()) return status::unimplemented;

        is_fwd = conv_pd->is_fwd();
        is_bwd_d = conv_pd->is_bwd_d();
        is_bwd_w = conv_pd->is_bwd_w();
        with_bias = conv_pd->with_bias();
        with_groups = conv_pd->with_groups();

        orig_src_md = *conv_pd->invariant_src_md();
        orig_wei_md = *conv_pd->invariant_wei_md();
        orig_dst_md = *conv_pd->invariant_dst_md();
        orig_bia_md = *conv_pd->invariant_bia_md();

        src_data_type = orig_src_md.data_type;
        wei_data_type = orig_wei_md.data_type;
        dst_data_type = orig_dst_md.data_type;
        bia_data_type = orig_bia_md.data_type;

        if (with_bias)
            bia_layout = layout_t(orig_bia_md, "a", /*do_normalize=*/false);

        ndims = conv_pd->ndims();

        mb = conv_pd->MB();
        g = conv_pd->G();
        ic = ir_utils::safe_divide(conv_pd->IC(), g);
        oc = ir_utils::safe_divide(conv_pd->OC(), g);

        // Input spatial.
        id = conv_pd->ID();
        ih = conv_pd->IH();
        iw = conv_pd->IW();

        // Output spatial.
        od = conv_pd->OD();
        oh = conv_pd->OH();
        ow = conv_pd->OW();

        // Kernel sizes.
        kd = conv_pd->KD();
        kh = conv_pd->KH();
        kw = conv_pd->KW();

        // Strides.
        sd = conv_pd->KSD();
        sh = conv_pd->KSH();
        sw = conv_pd->KSW();

        // Padding.
        pd = conv_pd->padFront();
        ph = conv_pd->padT();
        pw = conv_pd->padL();

        // Dilation.
        dd = conv_pd->KDD();
        dh = conv_pd->KDH();
        dw = conv_pd->KDW();

        try_reduce_to_1d();

        is_dw = with_groups && (g > 1) && (oc == 1) && (ic == 1);

        return status::success;
    }

    // Reduces the problem to 1D spatial (collapsing D/H/W into W) for 1x1
    // kernels with unit strides and no padding.
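    // For example (a hypothetical 1x1x1 problem with unit strides and
    // od/oh/ow == id/ih/iw): id/ih/iw = 4/8/8 -> iw = ow = 256, and all other
    // spatial and kernel sizes become 1.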
    void try_reduce_to_1d() {
        bool is_1x1 = (kd * kh * kw == 1);
        bool is_stride1 = (sd == 1 && sh == 1 && sw == 1);
        bool is_eq_oi = (od == id && oh == ih && ow == iw);
        if (is_1x1 && is_stride1 && is_eq_oi) {
            ir_assert(pd == 0 && ph == 0 && pw == 0);
            ow = od * oh * ow;
            iw = id * ih * iw;
            od = id = kd = 1;
            oh = ih = kh = 1;
            reduced_to_1d = true;
        }
    }

    memory_desc_wrapper orig_src_mdw() const {
        return memory_desc_wrapper(orig_src_md);
    }
    memory_desc_wrapper orig_wei_mdw() const {
        return memory_desc_wrapper(orig_wei_md);
    }
    memory_desc_wrapper orig_dst_mdw() const {
        return memory_desc_wrapper(orig_dst_md);
    }

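    // Returns a compact problem descriptor string (similar in spirit to the
    // benchdnn convolution notation), e.g. a hypothetical 3x3 problem prints
    // as "mb32g1ic64id1ih56iw56oc64od1oh56ow56kd1kh3kw3pd0ph1pw1".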
    std::string desc_str() const {
        std::ostringstream oss;
        oss << "mb" << mb;
        oss << "g" << g;
        oss << "ic" << ic;
        oss << "id" << id;
        oss << "ih" << ih;
        oss << "iw" << iw;
        oss << "oc" << oc;
        oss << "od" << od;
        oss << "oh" << oh;
        oss << "ow" << ow;
        oss << "kd" << kd;
        oss << "kh" << kh;
        oss << "kw" << kw;
        oss << "pd" << pd;
        oss << "ph" << ph;
        oss << "pw" << pw;
        return oss.str();
    }

    memory_desc_t orig_src_md;
    memory_desc_t orig_wei_md;
    memory_desc_t orig_dst_md;
    memory_desc_t orig_bia_md;

    layout_t src_layout;
    layout_t wei_layout;
    layout_t dst_layout;
    layout_t bia_layout;

    data_type_t src_data_type;
    data_type_t wei_data_type;
    data_type_t dst_data_type;
    data_type_t bia_data_type;

    bool is_fwd;
    bool is_bwd_d;
    bool is_bwd_w;
    bool with_bias;
    bool with_groups;
    bool is_dw;

    int ndims;
    int mb; // Batch size.
    int g; // Groups.
    int ic, oc; // Input and output channels.
    int id, ih, iw; // Input spatial sizes.
    int od, oh, ow; // Output spatial sizes.
    int kd, kh, kw; // Kernel sizes.
    int sd, sh, sw; // Strides.
    int pd, ph, pw; // Padding in the beginning.
    int dd, dh, dw; // Dilation.
    bool reduced_to_1d; // Whether the problem spatial was reduced to 1D.
};

// Parameters for kernel generation.
class conv_config_t : public conv_problem_t {
public:
    conv_config_t() = default;

    status_t init(convolution_pd_t *conv_pd, primitive_attr_t *attr,
            engine_t *engine) {
        // These functions have implicit dependencies between them. They cannot be
        // reordered without verifying that these dependencies are still satisfied.
        CHECK(conv_problem_t::init(conv_pd));
        CHECK(init_hw(engine));
        CHECK(init_abc_data_types());
        CHECK(init_acc_data_type());
        CHECK(init_fma_kind());
        CHECK(init_data_layouts(conv_pd));

        if (!data_types_ok()) return status::unimplemented;

        // Group convolution is not supported.
        // Depthwise convolution is supported for forward.
        if (with_groups && g > 1 && !(is_dw && is_fwd))
            return status::unimplemented;

        CHECK(init_common_config());

        const memory_desc_t *output_md = nullptr;
        if (is_fwd) {
            CHECK(init_fwd(conv_pd, engine));
            output_md = conv_pd->dst_md();
        } else if (is_bwd_d) {
            CHECK(init_bwd_d(conv_pd));
            output_md = conv_pd->diff_src_md();
        } else if (is_bwd_w) {
            CHECK(init_bwd_w());
            output_md = conv_pd->diff_weights_md();
        } else {
            ir_error_not_expected();
        }

        CHECK(attr->set_default_formats(output_md));

        if (!post_ops_ok(conv_pd)) return status::unimplemented;
        if (!hw_ok(engine)) return status::unimplemented;

        return status::success;
    }

    status_t init_fwd(convolution_pd_t *conv_pd, engine_t *engine) {
        using namespace ir_utils;

        if (ic < 16 && !is_dpas_fma() && !is_dw) return status::unimplemented;

        auto &src_md = *conv_pd->invariant_src_md();
        auto &dst_md = *conv_pd->invariant_dst_md();
        bool is_src_nhwc = (orig_src_mdw().is_plain()
                && src_layout == make_layout(src_md, "axb"));
        bool is_dst_nhwc = (orig_dst_mdw().is_plain()
                && dst_layout == make_layout(dst_md, "axb"));
        // Set dispatch and kernel parameters.
        if (is_dw) {
            g_tg_blk = (is_int8_dst() ? 32 : 16);
            mb_thr_blk
                    = (mb < 16 || is_src_nhwc ? 1
                                              : hw <= ngen::HW::XeLP ? 8 : 16);
            mb_thr_dim = (mb_thr_blk == 1 ? 1 : 2);
            ow_thr_blk = (mb_thr_blk == 1 ? 8 : 1);
            ow_thr_dim = 1;
            oc_thr_blk = 1;
            oc_thr_dim = 1;
            ic_thr_dim = 1;
            ic_blk = 1;

            int iw_load_blk = (ow_thr_blk - 1) * sw + (kw - 1) + 1;
            bool do_kw_buf = (kw > 1 && mb_thr_blk == 1 && iw_load_blk <= 32);
            kw_blk = (do_kw_buf ? kw : 1);
        } else if (fma_kind == fma_kind_t::mad
                && src_data_type == data_type::f32) {
            const int max_tg_size = 16;
            g_tg_blk = 1;
            mb_thr_blk = (mb < 16 ? 1 : 8);
            mb_thr_dim = std::min((mb_thr_blk != 1) ? (32 / mb_thr_blk) : 1,
                    utils::div_up(mb, mb_thr_blk));
#ifdef GEN_CONV_DEBUG
            mb_thr_blk = getenv_int("mb_thr_blk", mb_thr_blk);
#endif
            oc_thr_blk = 16;
            oc_thr_dim = std::min(4, utils::div_up(oc, oc_thr_blk));
            oc_thr_dim = (1 << math::ilog2q(oc_thr_dim));

            if (mb_thr_dim > 1) {
                ow_thr_blk = 1;
                ow_thr_dim = 1;
            } else {
                const int pref_ow_thr_dim
                        = max_tg_size / (oc_thr_dim * mb_thr_dim);
                const int pref_ow_block
                        = (mb_thr_blk == 1) ? 8 : kw > 1 ? 4 : 1;
                ow_thr_blk = ow < pref_ow_block * pref_ow_thr_dim
                        ? (1 << math::ilog2q(
                                   utils::div_up(ow, pref_ow_thr_dim)))
                        : pref_ow_block;
                ow_thr_dim = pref_ow_thr_dim;
            }
            ic_thr_dim = 1;
            kw_blk = 1;
            ic_blk = (is_small_ic() ? ic : 16);
        } else if (is_dpas_fma()) {
            g_tg_blk = 1;
            mb_thr_blk = is_small_ic() ? 8 : (mb < 16 ? 1 : 32);
            mb_thr_dim = (is_small_ic())
                    ? (mb < 16 ? std::min(utils::div_up(mb, mb_thr_blk), 4) : 4)
                    : 1;
            oc_thr_blk = 32;
            if (hw >= ngen::HW::XeHPC && !is_small_ic()) oc_thr_blk = 64;
            oc_thr_dim = std::min(4, utils::div_up(oc, oc_thr_blk));
            oc_thr_dim = (1 << math::ilog2q(oc_thr_dim));
            if (is_small_ic()) {
                ow_thr_blk = 4;
            } else {
                ow_thr_blk = (mb < 16 ? 16 : 1);
                if (ow < ow_thr_blk) ow_thr_blk = 8;
            }
            ow_thr_dim = is_small_ic()
                    ? 1
                    : std::min(4, utils::div_up(ow, ow_thr_blk));
            if (is_small_ic()) {
                kw_blk = 8;
                ic_blk = (is_s32_accumulator() ? 4 : 2);
            } else {
                kw_blk = 1;
                ic_blk = (is_s32_accumulator() ? 32 : 16);
            }

            ic_thr_dim = init_fwd_ic_thr_dim(
                    engine, mb_thr_blk, oc_thr_blk, ow_thr_blk, ic_blk);

            // Disable M/N thread group blocking when K thread group blocking
            // is enabled. For some reason combining them results in lower
            // performance.
            if (ic_thr_dim > 1) {
                ow_thr_dim = 1;
                oc_thr_dim = 1;
            }
        } else {
            ir_error_not_expected();
        }
        g_thr_blk = g_tg_blk;

        int ic_padded = utils::rnd_up(ic, ic_blk);
        ic_thr_blk = ir_utils::safe_divide(ic_padded, ic_thr_dim);

        ow_thr_dim = (1 << math::ilog2q(ow_thr_dim));

#ifdef GEN_CONV_DEBUG
        mb_thr_blk = getenv_int("mb_thr_blk", mb_thr_blk);
        mb_thr_dim = getenv_int("mb_thr_dim", mb_thr_dim);
        oc_thr_blk = getenv_int("oc_thr_blk", oc_thr_blk);
        oc_thr_dim = getenv_int("oc_thr_dim", oc_thr_dim);
        ow_thr_blk = getenv_int("ow_thr_blk", ow_thr_blk);
        ow_thr_dim = getenv_int("ow_thr_dim", ow_thr_dim);
#endif

        tg_grid_dim[0] = oc_thr_dim;
        tg_grid_dim[1] = mb_thr_dim * ow_thr_dim;
        tg_grid_dim[2] = ic_thr_dim;

        // Round down to a power of 2.
        tg_grid_dim[0] = (1 << math::ilog2q(tg_grid_dim[0]));
        tg_grid_dim[1] = (1 << math::ilog2q(tg_grid_dim[1]));
        tg_grid_dim[2] = (1 << math::ilog2q(tg_grid_dim[2]));

#ifdef GEN_CONV_DEBUG
        tg_grid_dim[0] = getenv_int("tg0", tg_grid_dim[0]);
        tg_grid_dim[1] = getenv_int("tg1", tg_grid_dim[1]);
#endif

        mb_tg_blk = mb_thr_dim * mb_thr_blk;
        oc_tg_blk = oc_thr_dim * oc_thr_blk;
        ow_tg_blk = ow_thr_dim * ow_thr_blk;

#ifdef GEN_CONV_DEBUG
        mb_tg_blk = getenv_int("mb_tg_blk", mb_tg_blk);
        oc_tg_blk = getenv_int("oc_tg_blk", oc_tg_blk);
        ow_tg_blk = getenv_int("ow_tg_blk", ow_tg_blk);
#endif

        // TODO: Update estimate_register_count.
        b_blk = g_tg_blk;
        m_tg_blk = mb_tg_blk * ow_tg_blk;
        n_tg_blk = oc_tg_blk;
        k_blk = ic_blk * kw_blk;

        int g_tg_padded = utils::rnd_up(g, g_tg_blk);
        int mb_tg_padded = utils::rnd_up(mb, mb_tg_blk);
        int oc_tg_padded = utils::rnd_up(oc, oc_tg_blk);
        int ow_tg_padded = utils::rnd_up(ow, ow_tg_blk);

        g_tg_dim = g_tg_padded / g_tg_blk;
        mb_tg_dim = mb_tg_padded / mb_tg_blk;
        oc_tg_dim = oc_tg_padded / oc_tg_blk;

        ow_tg_dim = ow_tg_padded / ow_tg_blk;

        kernel_grid_dim[0] = oc_tg_dim;
        kernel_grid_dim[1] = g_tg_dim * od * oh * ow_tg_dim;
        kernel_grid_dim[2] = mb_tg_dim;

        allow_grf_reorder = is_small_ic() || is_dw;

        if (kd * kh * kw > 9) do_loop_unroll = false;
        if (is_dw) {
            use_preload = false;
            do_loop_unroll = false;
        }
        if (is_small_ic()) {
            reuse_headers = true;
            do_loop_unroll = false;
        }

        regs = hw <= ngen::HW::XeLP ? 128 : 256;

        // XXX: in case of nhwc or small mb allow reorders on XeHPC
        // since A/B tile loads may be strided
        if (hw >= ngen::HW::XeHPC
                && (mb_thr_blk == 1 || is_src_nhwc || is_dst_nhwc))
            allow_grf_reorder = true;

        if (mb >= 16) {
            // Large batch performance is slightly behind for some cases.
            bool large_batch_ok = false;
            if (hw >= ngen::HW::XeHPC) large_batch_ok = true;
            if (is_src_nhwc) large_batch_ok = true;
            // TODO: Fix issues with mb zero padding
            if (is_small_ic() && mb % 16 == 0) large_batch_ok = true;
            if (!large_batch_ok) return status::unimplemented;
        }

        fixup_inference_consistency();
        if (!try_reduce_grf_usage()) return status::unimplemented;

        return status::success;
    }

    status_t init_bwd_d(convolution_pd_t *conv_pd) {
        using namespace ir_utils;

        // Set dispatch and kernel parameters.
        mb_thr_blk = (mb < 16 ? 1 : 32);
        ic_thr_blk = 32;
        if (hw >= ngen::HW::XeHPC) ic_thr_blk = 64;
        iw_thr_blk = (mb < 16 ? 16 : 1);
        if (iw < iw_thr_blk) iw_thr_blk = 8;

#ifdef GEN_CONV_DEBUG
        mb_thr_blk = getenv_int("mb_thr_blk", mb_thr_blk);
        ic_thr_blk = getenv_int("ic_thr_blk", ic_thr_blk);
        iw_thr_blk = getenv_int("iw_thr_blk", iw_thr_blk);
#endif

        regs = 256;

        tg_grid_dim[0] = std::min(4, utils::div_up(ic, ic_thr_blk));
        tg_grid_dim[1] = std::min(4, utils::div_up(iw, iw_thr_blk));
        tg_grid_dim[2] = 1;

        // Round down to a power of 2.
        tg_grid_dim[0] = (1 << math::ilog2q(tg_grid_dim[0]));
        tg_grid_dim[1] = (1 << math::ilog2q(tg_grid_dim[1]));
        tg_grid_dim[2] = (1 << math::ilog2q(tg_grid_dim[2]));

#ifdef GEN_CONV_DEBUG
        tg_grid_dim[0] = getenv_int("tg0", tg_grid_dim[0]);
        tg_grid_dim[1] = getenv_int("tg1", tg_grid_dim[1]);
#endif

        mb_tg_blk = mb_thr_blk;
        ic_tg_blk = tg_grid_dim[0] * ic_thr_blk;
        iw_tg_blk = tg_grid_dim[1] * iw_thr_blk;
        oc_blk = (is_s32_accumulator() ? 32 : 16);

#ifdef GEN_CONV_DEBUG
        mb_tg_blk = getenv_int("mb_tg_blk", mb_tg_blk);
        ic_tg_blk = getenv_int("ic_tg_blk", ic_tg_blk);
        iw_tg_blk = getenv_int("iw_tg_blk", iw_tg_blk);
#endif

        m_tg_blk = mb_tg_blk * iw_tg_blk;
        n_tg_blk = ic_tg_blk;
        k_blk = oc_blk;

        int mb_tg_padded = utils::rnd_up(mb, mb_tg_blk);
        int ic_tg_padded = utils::rnd_up(ic, ic_tg_blk);
        int iw_tg_padded = utils::rnd_up(iw, iw_tg_blk);

        mb_tg_dim = mb_tg_padded / mb_tg_blk;
        ic_tg_dim = ic_tg_padded / ic_tg_blk;

        iw_tg_dim = iw_tg_padded / iw_tg_blk;

        kernel_grid_dim[0] = ic_tg_dim;
        kernel_grid_dim[1] = id * ih * iw_tg_dim;
        kernel_grid_dim[2] = mb_tg_dim;

        allow_grf_reorder = false;

        // Do not perform full unrolling when there are too many inner
        // iterations.
        int kernel_limit = is_f32_conv() ? 4 : 9;
        if (kd * kh * kw > kernel_limit) do_loop_unroll = false;

        // Do not perform full unrolling with non-unit stride. These cases have
        // non-trivial post-increment updates which result in unrolling all
        // reduction loops and exceeding the instruction cache.
        if (sd * sh * sw != 1) do_loop_unroll = false;

        fixup_inference_consistency();
        if (!try_reduce_grf_usage()) return status::unimplemented;

        auto &src_md = *conv_pd->invariant_src_md();
        auto &dst_md = *conv_pd->invariant_dst_md();

        // Validate layouts.
        bool is_src_nhwc = (orig_src_mdw().is_plain()
                && src_layout == make_layout(src_md, "axb"));
        bool is_dst_nhwc = (orig_dst_mdw().is_plain()
                && dst_layout == make_layout(dst_md, "axb"));
        // XXX: in case of nhwc or small mb allow reorders on XeHPC
        // since A/B tile loads may be strided
        if (hw >= ngen::HW::XeHPC
                && (mb_thr_blk == 1 || is_src_nhwc || is_dst_nhwc))
            allow_grf_reorder = true;

        if (hw < ngen::HW::XeHPC)
            // Blocked large batch performance is slightly behind.
            if (!is_src_nhwc && mb >= 16) return status::unimplemented;

        return status::success;
    }

    status_t init_bwd_w() {
        using namespace ir_utils;

        if (fma_kind == fma_kind_t::mad) {
            // Performance for small ic and small mb is worse than the ocl:ncsp
            // implementation.
            if (is_small_ic() && mb < 16) return status::unimplemented;

            oc_thr_blk = simd_size;
            ic_thr_blk = (ic < simd_size ? utils::rnd_up_pow2(ic) : simd_size);
            kw_blk = utils::rnd_up_pow2(
                    std::min(utils::div_up(simd_size, ic_thr_blk), kw));
            mb_blk = mb < 16 ? 1 : 16;
            mb_tg_blk = mb_blk;
            ow_thr_blk = mb < 16 ? std::min(16, utils::rnd_up_pow2(ow)) : 1;
        } else if (is_dpas_fma()) {
            oc_thr_blk = (oc <= 16 ? 16 : 32);
            if (hw >= ngen::HW::XeHPC) oc_thr_blk = (oc <= 16 ? 16 : 64);
            // Value required due to blocking in dpas data format
            int min_ic_thr_blk = is_s32_accumulator() ? 4 : 2;
            ic_thr_blk = (ic <= 16
                            ? std::max(utils::rnd_up_pow2(ic), min_ic_thr_blk)
                            : mb < 16 ? 16 : 32);
            kw_blk = utils::rnd_up_pow2(
                    std::min(utils::div_up(16, ic_thr_blk), kw));

            mb_blk = mb < 16 ? 1 : 16;
            mb_tg_blk = (mb < 16 || mb <= mb_blk) ? mb_blk : 2 * mb_blk;
            ow_thr_blk = mb < 16 ? 16 : 1;
            // TODO: Investigate why registers are insufficient even though
            // m_tg_blk is the same.
            if (mb < 16 && kw > 8 && kw_blk >= 8) kw_blk = 4;
        } else {
            ir_error_not_expected();
        }

#ifdef GEN_CONV_DEBUG
        oc_thr_blk = getenv_int("oc_thr_blk", oc_thr_blk);
        ic_thr_blk = getenv_int("ic_thr_blk", ic_thr_blk);
        kw_blk = getenv_int("kw_blk", kw_blk);
        ow_thr_blk = getenv_int("ow_thr_blk", ow_thr_blk);
        mb_blk = getenv_int("mb_blk", mb_blk);
        mb_tg_blk = getenv_int("mb_tg_blk", mb_tg_blk);
#endif

        kw_tg_dim = utils::div_up(kw, kw_blk);

        int max_oc_thr_dim = 4;
        int max_ic_thr_dim = 4;

        // Prefer larger thread groups when possible on XeHPC.
        if (hw >= ngen::HW::XeHPC) {
            if (oc / oc_thr_blk >= 8) {
                max_oc_thr_dim = 8;
            } else {
                max_ic_thr_dim = 8;
            }
        }

        regs = 256;
        tg_grid_dim[0]
                = std::min(max_oc_thr_dim, utils::div_up(oc, oc_thr_blk));
        tg_grid_dim[1]
                = std::min(max_ic_thr_dim, utils::div_up(ic, ic_thr_blk));
        tg_grid_dim[2] = 1;

        // Round down to a power of 2.
        tg_grid_dim[0] = (1 << math::ilog2q(tg_grid_dim[0]));
        tg_grid_dim[1] = (1 << math::ilog2q(tg_grid_dim[1]));
        tg_grid_dim[2] = (1 << math::ilog2q(tg_grid_dim[2]));

#ifdef GEN_CONV_DEBUG
        tg_grid_dim[0] = getenv_int("tg0", tg_grid_dim[0]);
        tg_grid_dim[1] = getenv_int("tg1", tg_grid_dim[1]);
#endif

        oc_tg_blk = tg_grid_dim[0] * oc_thr_blk;
        ic_tg_blk = tg_grid_dim[1] * ic_thr_blk;
        kw_tg_blk = kw_blk;

        init_bwd_w_spatial_blocks();

        mb_unroll = mb_tg_blk / mb_blk;
        ow_unroll = mb < 16 && is_dpas_fma() ? ow_tg_blk / ow_thr_blk : 1;

        m_tg_blk = ic_tg_blk * kw_tg_blk;
        n_tg_blk = oc_tg_blk;
        k_blk = mb_blk * ow_thr_blk;

        int oc_tg_padded = utils::rnd_up(oc, oc_tg_blk);
        int ic_tg_padded = utils::rnd_up(ic, ic_tg_blk);
        int mb_tg_padded = utils::rnd_up(mb, mb_tg_blk);
        int od_tg_padded = utils::rnd_up(od, od_tg_blk);
        int oh_tg_padded = utils::rnd_up(oh, oh_tg_blk);
        int ow_tg_padded = utils::rnd_up(ow, ow_tg_blk);

        oc_tg_dim = oc_tg_padded / oc_tg_blk;
        ic_tg_dim = ic_tg_padded / ic_tg_blk;

        mb_tg_dim = mb_tg_padded / mb_tg_blk;
        od_tg_dim = od_tg_padded / od_tg_blk;
        oh_tg_dim = oh_tg_padded / oh_tg_blk;
        ow_tg_dim = ow_tg_padded / ow_tg_blk;

        kernel_grid_dim[0] = oc_tg_dim;
        kernel_grid_dim[1] = ic_tg_dim * kd * kh * kw_tg_dim * od_tg_dim
                * oh_tg_dim * ow_tg_dim;
        kernel_grid_dim[2] = mb_tg_dim;

        // Set BWD_W-specific settings.
        do_b_reduction = with_bias;
        do_loop_unroll = (hw >= ngen::HW::XeHPC && is_dpas_fma() && mb_blk > 1);
        allow_grf_reorder = is_dpas_fma();
        zero_out_output = true;
        do_atomic_update = true;
        do_post_wei_reorder = (wei_data_type == data_type::bf16);
        do_post_bia_reorder = (with_bias && bia_data_type == data_type::bf16);

#ifdef GEN_CONV_DEBUG
        do_loop_unroll = getenv_bool("do_loop_unroll", do_loop_unroll);
        allow_grf_reorder = getenv_bool("allow_grf_reorder", allow_grf_reorder);
#endif

        fixup_inference_consistency();
        if (!try_reduce_grf_usage()) return status::unimplemented;

        if (do_post_wei_reorder) {
            wei_layout = wei_layout.retype(type_t::f32());
            orig_wei_md.data_type = data_type::f32;
        }
        if (do_post_bia_reorder) {
            bia_layout = bia_layout.retype(type_t::f32());
            orig_bia_md.data_type = data_type::f32;
        }

        // XXX: disable f32 bwd_w due to hang
        if (hw == ngen::HW::XeHP || hw == ngen::HW::XeHPG)
            if (src_data_type == data_type::f32
                    && dst_data_type == data_type::f32)
                return status::unimplemented;

        return status::success;
    }

    void init_bwd_w_spatial_blocks() {
        od_tg_blk = 1;
        oh_tg_blk = 1;
        ow_tg_blk = ow_thr_blk;
        bool are_small_large_channels
                = (std::min(ic, oc) <= 64 && std::max(ic, oc) >= 256);
        int sp_min_blk = 24;
        int sp_max_blk = (are_small_large_channels ? 100 : 50);

        auto get_score = [&](int oh_blk, int ow_blk) {
            int sp_blk = oh_blk * ow_blk;
            int oh_padded = utils::rnd_up(oh, oh_blk);
            int ow_padded = utils::rnd_up(ow, ow_blk);

            double extra_work
                    = (oh_padded * ow_padded - oh * ow) / double(oh * ow);
            // ohw_eff == 0: no useful computation
            // ohw_eff == 1: all computation is useful
            double ohw_eff = 1 - std::min(extra_work, 1.0);
            int score = int(ohw_eff * 10000);

            // Prefer [sp_min_blk; sp_max_blk] range for the total spatial size.
            bool sp_size_ok = (sp_blk >= sp_min_blk && sp_blk <= sp_max_blk);

            if (hw >= ngen::HW::XeHPC) {
                bool sp_block_ok = false;
                // Avoid OH blocking when OW blocking is enabled and big enough (to
                // avoid code explosion due to mandatory unrolling of inner
                // iterations). Exception: when OH/OW are fully blocked - even with
                // code explosion such blocks may give the best performance.
                sp_block_ok |= (oh_blk == 1 || ow_blk <= 2);
                sp_block_ok |= (oh_blk == oh && ow_blk == ow);
                if (sp_size_ok && sp_block_ok) {
                    double sp_range = sp_max_blk - sp_min_blk;
                    double sp_score = (sp_blk - sp_min_blk) / sp_range * 100;
                    score += sp_score;
                }
            } else if (sp_size_ok) {
                score += 100;
            }
            return score;
        };

        int max_score = 0;
        for (int oh_blk = 1; oh_blk <= sp_max_blk; oh_blk++) {
            for (int ow_blk = ow_thr_blk; ow_blk <= sp_max_blk;
                    ow_blk += ow_thr_blk) {
                int score = get_score(oh_blk, ow_blk);
                if (score > max_score) {
                    oh_tg_blk = oh_blk;
                    ow_tg_blk = ow_blk;
                    max_score = score;
                }
            }
        }

#ifdef GEN_CONV_DEBUG
        od_tg_blk = getenv_int("od_tg_blk", od_tg_blk);
        oh_tg_blk = getenv_int("oh_tg_blk", oh_tg_blk);
        ow_tg_blk = getenv_int("ow_tg_blk", ow_tg_blk);
#endif
    }

    status_t init_common_config() {
        using namespace ir_utils;

        use_preload = true;

        if (hw <= ngen::HW::XeLP) use_preload = false;

        // No SLM buffering by default (will be enabled later).
        disable_slm_buffering();

        // No prefetch by default (will be enabled later).
        disable_prefetch();

        do_b_reduction = false;
        pad_slm = true;
        assign_sbids = is_dpas_fma();
        do_loop_unroll = hw > ngen::HW::XeLP;
        reduce_grf_usage = true;
        zero_out_output = false;
        do_atomic_update = false;
        reuse_headers = hw <= ngen::HW::XeLP;
        do_post_wei_reorder = false;
        do_post_bia_reorder = false;
        a_sub_tiles = 1;
        b_sub_tiles = 1;

#ifdef GEN_CONV_DEBUG
        use_preload = getenv_bool("use_preload", use_preload);
        pad_slm = getenv_bool("pad_slm", pad_slm);
        assign_sbids = getenv_bool("assign_sbids", assign_sbids);
        do_loop_unroll = getenv_bool("do_loop_unroll", do_loop_unroll);
        reduce_grf_usage = getenv_bool("reduce_grf_usage", reduce_grf_usage);
        allow_grf_reorder = getenv_bool("allow_grf_reorder", allow_grf_reorder);
        reuse_headers = getenv_bool("reuse_headers", reuse_headers);
        a_sub_tiles = getenv_int("a_sub_tiles", a_sub_tiles);
        b_sub_tiles = getenv_int("b_sub_tiles", b_sub_tiles);
#endif

        return status::success;
    }

    bool post_ops_ok(const convolution_pd_t *pd) const {
        auto *attr = pd->attr();

        if (is_fwd || is_bwd_d) {
            auto attr_skip_mask = primitive_attr_t::skip_mask_t::post_ops
                    | primitive_attr_t::skip_mask_t::oscale_runtime
                    | primitive_attr_t::skip_mask_t::sum_dt;
            if (!attr->has_default_values(attr_skip_mask)) return false;
        } else {
            if (!attr->has_default_values()) return false;
        }

        if (!attr->output_scales_.has_default_values()) {
            // Only common and per_oc output scales were tested.
            if (!utils::one_of(attr->output_scales_.mask_, 0, (1 << 1)))
                return false;
        }
        for (int i = 0; i < attr->post_ops_.len(); i++) {
            auto &po = attr->post_ops_.entry_[i];
            if (po.is_eltwise()) {
                if (!jit_eltwise_injector_f32_is_supported(po.eltwise.alg))
                    return false;
            } else if (po.is_binary()) {
                int mask = utils::get_dims_mask(pd->invariant_dst_md()->dims,
                        po.binary.src1_desc.dims, ndims);
                // per_oc broadcast is always supported.
                if ((mask & (1 << 1)) == 0) continue;
                auto rhs_layout = layout_t(po.binary.src1_desc);
                // No blocks means it is a scalar, which can always be loaded.
                if (rhs_layout.blocks().empty()) return true;

                auto rhs0 = rhs_layout.blocks()[0];
                int block_bytes = rhs0.block * rhs_layout.type().size();
                // The innermost block must:
                // - be across output channels
                // - be dense
                // - be aligned to 32 bytes (for HWord loads)
                if (rhs0.dim_idx != 1 || dim_t(rhs0.stride) != 1
                        || block_bytes % 32 != 0)
                    return false;
            }
        }
        return true;
    }

    bool hw_ok(const engine_t *engine) const {
        auto *compute_engine
                = utils::downcast<const compute::compute_engine_t *>(engine);
        if (regs == 256 && !compute_engine->mayiuse_large_grf_mode())
            return false;
        return true;
    }

    bool data_types_ok() const {
        bool is_bf16 = utils::one_of(data_type::bf16, src_data_type,
                wei_data_type, dst_data_type, bia_data_type);
        if (is_bf16 && hw <= ngen::HW::XeLP) return false;

        if (is_fwd) return true;
        if (is_bwd_d) return true;
        if (is_bwd_w) {
            bool ok = true;
            ok &= (src_data_type == data_type::bf16
                    || src_data_type == data_type::f32);
            ok &= (dst_data_type == src_data_type);
            ok &= utils::one_of(wei_data_type, src_data_type, data_type::f32);

            if (with_bias) {
                ok &= utils::one_of(
                        bia_data_type, src_data_type, data_type::f32);
            }
            return ok;
        }
        return false;
    }

    bool is_s32_accumulator() const { return acc_data_type == data_type::s32; }
    bool is_f32_conv() const {
        return utils::everyone_is(src_data_type, wei_data_type, data_type::f32);
    }
    bool is_int8_dst() const {
        return utils::one_of(dst_data_type, data_type::s8, data_type::u8);
    }
    bool is_small_ic() const { return ic < simd_size; }
    bool is_dpas_fma() const {
        return utils::one_of(fma_kind, fma_kind_t::dpas, fma_kind_t::dpasw);
    }

    int grf_size() const { return ngen::GRF::bytes(hw); }

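    // Maps the kernel/thread group grids to an OpenCL-style ND-range: the
    // first local dimension carries the SIMD width (one subgroup per logical
    // thread), and each global size is the kernel grid scaled by the
    // corresponding local size.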
    compute::nd_range_t nd_range() const {
        size_t gws[3];
        size_t lws[3];
        for (int i = 0; i < 3; i++) {
            lws[i] = tg_grid_dim[i] * (i == 0 ? simd_size : 1);
            gws[i] = kernel_grid_dim[i] * lws[i];
        }
        return compute::nd_range_t(gws, lws);
    }

    std::string str() const {
        using namespace ir_utils;

        std::ostringstream oss;
        // clang-format off
        oss << "  Problem:                    " << desc_str() << std::endl;
        oss << "  Source layout:              " << src_layout << std::endl;
        oss << "  Weights layout:             " << wei_layout << std::endl;
        oss << "  Destination layout:         " << dst_layout << std::endl;
        oss << "  MB TG block:                " << mb_tg_blk << std::endl;
        oss << "  OD TG block:                " << od_tg_blk << std::endl;
        oss << "  OH TG block:                " << oh_tg_blk << std::endl;
        oss << "  OW TG block:                " << ow_tg_blk << std::endl;
        oss << "  OC TG block:                " << oc_tg_blk << std::endl;
        oss << "  Kernel grid:                " << make_seq_print_helper(kernel_grid_dim, " x ") << std::endl;
        oss << "  Thread group:               " << make_seq_print_helper(tg_grid_dim, " x ") << std::endl;
        oss << "  FMA kind:                   " << fma_kind::to_string(fma_kind) << std::endl;
        oss << "  Use SLM for A:              " << to_string(use_a_slm) << std::endl;
        oss << "  Use SLM for B:              " << to_string(use_b_slm) << std::endl;
        oss << "  SLM buffers:                " << slm_bufs << std::endl;
        oss << "  GMEM to SLM, GRF buffers:   " << gmem_bufs << std::endl;
        oss << "  Pad SLM:                    " << to_string(pad_slm) << std::endl;
        oss << "  Use prefetch:               " << to_string(use_prefetch) << std::endl;
        oss << "  Prefetch buffers:           " << prefetch_bufs << std::endl;
        oss << "  Do loop unroll:             " << to_string(do_loop_unroll) << std::endl;
        oss << "  Assign SBIDs:               " << to_string(assign_sbids) << std::endl;
        oss << "  Reduce GRF usage:           " << to_string(reduce_grf_usage) << std::endl;
        oss << "  Reuse headers:              " << to_string(reuse_headers) << std::endl;
        oss << "  Allow GRF reorder:          " << to_string(allow_grf_reorder) << std::endl;
        oss << "  A sub-tiles:                " << a_sub_tiles << std::endl;
        oss << "  B sub-tiles:                " << b_sub_tiles << std::endl;
        // clang-format on
        return oss.str();
    }

    data_type_t a_data_type;
    data_type_t b_data_type;
    data_type_t c_data_type;
    data_type_t acc_data_type;

    ngen::HW hw = ngen::HW::Unknown;
    int simd_size; // SIMD width.
    int regs; // Number of registers.

    // Thread group dimensions (thread group grid).
    std::array<int, 3> tg_grid_dim;

    // Number of thread groups across dimensions (kernel grid).
    std::array<int, 3> kernel_grid_dim;

    // Number of thread group blocks across problem dimensions.
    int g_tg_dim;
    int ic_tg_dim;
    int iw_tg_dim;
    int kw_tg_dim;
    int mb_tg_dim;
    int oc_tg_dim;
    int od_tg_dim;
    int oh_tg_dim;
    int ow_tg_dim;

    // Block sizes per thread group.
    int g_tg_blk;
    int ic_tg_blk;
    int iw_tg_blk;
    int kw_tg_blk;
    int mb_tg_blk;
    int oc_tg_blk;
    int od_tg_blk;
    int oh_tg_blk;
    int ow_tg_blk;

    // Number of thread blocks across problem dimensions.
    int ic_thr_dim;
    int mb_thr_dim;
    int oc_thr_dim;
    int ow_thr_dim;

    // Block sizes per thread.
    int g_thr_blk;
    int ic_thr_blk;
    int iw_thr_blk;
    int mb_thr_blk;
    int oc_thr_blk;
    int ow_thr_blk;

    // Block sizes per iteration.
    int ic_blk;
    int kw_blk;
    int mb_blk;
    int oc_blk;

    // Block sizes in GEMM notation.
    int b_blk;
    int m_tg_blk;
    int n_tg_blk;
    int k_blk;

    // Unroll sizes.
    int mb_unroll;
    int ow_unroll;

    bool do_b_reduction;

    fma_kind_t fma_kind; // Which instruction backend to use.

    bool use_preload; // Whether to use SLM or prefetch.
    bool use_a_slm; // Whether to use SLM for A.
    bool use_b_slm; // Whether to use SLM for B.
    bool use_prefetch; // Whether to use prefetch for A and B.
    bool pad_slm; // Whether to pad SLM to avoid write conflicts.
    bool assign_sbids; // Whether to manually assign SBID tokens.
    int slm_bufs; // Number of SLM buffers to use.
    int gmem_bufs; // Number of GRF buffers to use for GMEM -> SLM copy.
    int prefetch_bufs; // Number of prefetch buffers for A and B.
    bool do_loop_unroll; // Whether to fully unroll inner loops.
    bool reduce_grf_usage; // Whether to try to reduce GRF usage based on heuristics.
    bool allow_grf_reorder; // Whether to allow GRF reorders to FMA-friendly layouts.
    bool zero_out_output; // Whether to zero out outputs before the main kernel.
    bool do_atomic_update; // Whether to use atomics during C update.
    bool reuse_headers; // Whether to reuse header messages to reduce GRF usage.

    // Specific to BWD_W.
    bool do_post_bia_reorder; // Whether to perform an extra reorder for bias.
    bool do_post_wei_reorder; // Whether to perform an extra reorder for weights.

    // Sub-tiles to split into for the inner A x B multiplication:
    // for i in range(0, a_sub_tiles):
    //     A_i = load(...)
    //     for j in range(0, b_sub_tiles):
    //         B_j = load(...)
    //         C_i_j += A_i * B_j
    //
    // GRF buffers for A_i and B_j are reused. Factors greater than one help to
    // reduce GRF usage.
    int a_sub_tiles;
    int b_sub_tiles;

private:
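    // Chooses the number of thread group slices along IC (k-slicing) for
    // forward convolution. Slicing is considered only when the reduction is
    // large enough and the regular (non-sliced) thread count does not already
    // saturate the available EUs.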
    int init_fwd_ic_thr_dim(engine_t *engine, int mb_thr_blk, int oc_thr_blk,
            int ow_thr_blk, int ic_blk) const {
        if (mb_thr_blk > 1) return 1;

        int ic_blocks = utils::div_up(ic, ic_blk);
        int reduction_blocks = ic_blocks * kd * kh * kw;

        int oc_nthr = utils::div_up(oc, oc_thr_blk);
        int ow_nthr = utils::div_up(ow, ow_thr_blk);
        int mb_nthr = utils::div_up(mb, mb_thr_blk);
        int nthr = mb_nthr * oc_nthr * od * oh * ow_nthr;

        auto *compute_engine
                = utils::downcast<const compute::compute_engine_t *>(engine);
        int eus = compute_engine->device_info()->eu_count();

        int ret_ic_thr_dim = 1;
        if (!is_small_ic() && reduction_blocks >= 16 && (nthr < eus)) {
            ret_ic_thr_dim = ir_utils::max_divisor(ic_blocks, {1, 2, 4, 8});

            // If reduction is too small, limit k-slicing.
            int reduction_threshold = 32;
            if (reduction_blocks < reduction_threshold) {
                int max_ic_thr_dim = utils::div_up(eus, nthr);
                max_ic_thr_dim = (1 << math::ilog2q(max_ic_thr_dim));
                ret_ic_thr_dim = std::min(ret_ic_thr_dim, max_ic_thr_dim);
            }
        }
        return ret_ic_thr_dim;
    }

    status_t init_hw(engine_t *engine) {
        using namespace compute;

        auto compute_engine
                = utils::downcast<compute::compute_engine_t *>(engine);
        auto device_info = compute_engine->device_info();

        switch (device_info->gpu_arch()) {
            case gpu_arch_t::gen9: hw = ngen::HW::Gen9; break;
            case gpu_arch_t::xe_lp: hw = ngen::HW::XeLP; break;
            case gpu_arch_t::xe_hp: hw = ngen::HW::XeHP; break;
            case gpu_arch_t::xe_hpg: hw = ngen::HW::XeHPG; break;
            case gpu_arch_t::xe_hpc: hw = ngen::HW::XeHPC; break;
            default: return status::unimplemented;
        }
        return status::success;
    }

    // Initializes A/B/C data types (GEMM notation: C += A * B) according to
    // the following convention:
    // FWD:        src -> A,      wei -> B,      dst -> C
    // BWD_D: diff_dst -> A,      wei -> B, diff_src -> C
    // BWD_W:      src -> A, diff_dst -> B, diff_wei -> C
    status_t init_abc_data_types() {
        if (is_fwd) {
            a_data_type = src_data_type;
            b_data_type = wei_data_type;
            c_data_type = dst_data_type;
        } else if (is_bwd_d) {
            a_data_type = dst_data_type;
            b_data_type = wei_data_type;
            c_data_type = src_data_type;
        } else if (is_bwd_w) {
            a_data_type = src_data_type;
            b_data_type = dst_data_type;
            // Always use f32 for accumulation/storing in the main kernel.
            c_data_type = data_type::f32;
        } else {
            ir_error_not_expected();
        }
        return status::success;
    }

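    // Accumulator type selection: s32 for int8 (s8/u8) inputs, f32 for
    // f16/bf16/f32 inputs; other combinations are not supported.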
    status_t init_acc_data_type() {
        auto a = a_data_type;
        auto b = b_data_type;
        auto c = c_data_type;
        if (utils::one_of(a, data_type::s8, data_type::u8)
                && utils::one_of(b, data_type::s8, data_type::u8)) {
            acc_data_type = data_type::s32;
            return status::success;
        }
        if (utils::everyone_is(data_type::f16, a, b)
                || utils::everyone_is(data_type::bf16, a, b)) {
            acc_data_type = data_type::f32;
            return status::success;
        }
        if (utils::everyone_is(data_type::f32, a, b, c)) {
            acc_data_type = data_type::f32;
            return status::success;
        }
        return status::unimplemented;
    }

    status_t init_fma_kind() {
        fma_kind = fma_kind::get_supported_kind(
                hw, a_data_type, b_data_type, acc_data_type);

        simd_size = fma_kind::get_simd_size(
                hw, fma_kind, a_data_type, b_data_type, acc_data_type);

        bool use_mad = false;
        if (is_small_ic() && !is_dw) {
            if (is_fwd && (kw != 7 || mb % 8 != 0))
                use_mad = true;
            else if (is_bwd_d)
                use_mad = true;
        } else if (is_dw) {
            use_mad = true;
        }

        if (use_mad) {
            fma_kind = fma_kind_t::mad;
            simd_size = fma_kind::get_simd_size(
                    hw, fma_kind, a_data_type, b_data_type, acc_data_type);
        }

#ifdef GEN_CONV_DEBUG
        fma_kind = fma_kind::from_string(ir_utils::getenv_str(
                "fma_kind", fma_kind::to_string(fma_kind)));
        simd_size = fma_kind::get_simd_size(
                hw, fma_kind, a_data_type, b_data_type, acc_data_type);

#endif
        if (fma_kind == fma_kind_t::unknown) return status::unimplemented;

        // Disable using mad instruction backend until performance parity is
        // reached with OpenCL kernels.
        if (fma_kind == fma_kind_t::mad) {
            if (hw < ngen::HW::XeHP) return status::unimplemented;
            if (is_bwd_d) {
                if (!is_f32_conv()) return status::unimplemented;
                if (is_small_ic()) return status::unimplemented;
                return status::success;
            }
        }

        return status::success;
    }

    status_t init_data_layouts(convolution_pd_t *conv_pd) {
        std::string src_tag;
        std::string wei_tag;
        std::string dst_tag;

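        // Note on tag notation (oneDNN format-tag style): lowercase letters
        // enumerate dimensions in order, an uppercase letter marks a blocked
        // dimension, a trailing "<N><letter>" pair gives the inner block size
        // (e.g. "aBx16b" keeps dim 0 plain and blocks dim 1 by 16), and 'x'
        // stands for the remaining spatial dimensions.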
        const bool is_wei16aXb = hw >= ngen::HW::XeHPC;
        assert(hw != ngen::HW::Unknown);
        bool is_mb_block = mb >= 16;

        // Src/Dst buffers should generally use the same format to avoid reorders
        // between FWD, BWD_D, and BWD_W.
        if (is_small_ic() && !is_dw) {
            src_tag = is_s32_accumulator() ? "ABx8a4b" : "ABx8a2b";
        } else if (fma_kind == fma_kind_t::mad) {
            if (is_s32_accumulator()) {
                src_tag = (!is_mb_block ? "aBx32b" : "ABx32a32b");
            } else {
                src_tag = (!is_mb_block ? "aBx16b" : "ABx32a16b");
            }
            if (is_fwd) {
                int max_simd_size = 16;
                if (simd_size > max_simd_size) simd_size = max_simd_size;
            }
        } else if (is_s32_accumulator()) {
            src_tag = (!is_mb_block ? "aBx32b" : "ABx32a32b");
        } else {
            src_tag = (!is_mb_block ? "aBx16b" : "ABx32a16b");
        }

        if (fma_kind == fma_kind_t::mad) {
            if (is_dw) {
                if (is_int8_dst()) {
                    dst_tag = (!is_mb_block ? "aBx32b" : "ABx32a32b");
                } else {
                    dst_tag = (!is_mb_block ? "aBx16b" : "ABx32a16b");
                }
            } else {
                dst_tag = (!is_mb_block ? "aBx16b" : "ABx32a16b");
            }
            if (is_bwd_d) {
                int max_simd_size = 16;
                if (simd_size > max_simd_size) simd_size = max_simd_size;
            }
        } else if (is_int8_dst()) {
            dst_tag = (!is_mb_block ? "aBx32b" : "ABx32a32b");
        } else {
            dst_tag = (!is_mb_block ? "aBx16b" : "ABx32a16b");
        }

        // Weight reorders are generally small, so reordering weights between
        // FWD and BWD_D/BWD_W implementations for optimization purposes makes
        // sense.
        if (is_fwd) {
            if (is_small_ic() && !is_dw) {
                if (fma_kind == fma_kind_t::mad)
                    wei_tag = "bAx16a";
                else if (is_s32_accumulator())
                    wei_tag = is_wei16aXb ? "ABx16a4b" : "ABx8a4b";
                else
                    wei_tag = is_wei16aXb ? "ABx16a2b" : "ABx8a2b";
            } else {
                if (is_dw) {
                    if (is_s32_accumulator())
                        wei_tag = "Abcx32a";
                    else
                        wei_tag = "Abcx16a";
                } else if (fma_kind == fma_kind_t::mad) {
                    wei_tag = "BAx16b16a";
                } else if (is_s32_accumulator()) {
                    wei_tag = is_wei16aXb ? "ABx2a8b16a4b" : "ABx4a8b8a4b";
                } else {
                    wei_tag = is_wei16aXb ? "ABx2a8b16a2b" : "ABx4a8b8a2b";
                }
            }
        } else if (is_bwd_d) {
            if (fma_kind == fma_kind_t::mad)
                wei_tag = "ABx16a16b";
            else if (is_s32_accumulator())
                wei_tag = is_wei16aXb ? "BAx2b8a16b4a" : "BAx4b8a8b4a";
            else
                wei_tag = is_wei16aXb ? "BAx2b8a16b2a" : "BAx4b8a8b2a";
        } else if (is_bwd_w) {
            if (is_small_ic()) {
                wei_tag = "Axb16a";
            } else {
                wei_tag = "ABx16b16a";
            }
        }

        if (with_groups && !is_dw) wei_tag = prepend_groups_to_tag(wei_tag);

#ifdef GEN_CONV_DEBUG
        src_tag = ir_utils::getenv_str("stag", src_tag);
        wei_tag = ir_utils::getenv_str("wtag", wei_tag);
        dst_tag = ir_utils::getenv_str("dtag", dst_tag);
#endif

        auto &src_md = *conv_pd->invariant_src_md();
        auto &wei_md = *conv_pd->invariant_wei_md();
        auto &dst_md = *conv_pd->invariant_dst_md();
        auto &bia_md = *conv_pd->invariant_bia_md();

        // Select layouts.
        src_layout = init_layout(src_md, src_tag);
        wei_layout = init_layout(wei_md, wei_tag);
        dst_layout = init_layout(dst_md, dst_tag);
        if (with_bias) bia_layout = init_layout(bia_md, "a");

        // Validate layouts.
        bool is_src_nhwc = false;
        bool is_dst_nhwc = false;

        if (is_fwd || is_bwd_d) {
            is_src_nhwc = (orig_src_mdw().is_plain()
                    && src_layout == layout_t(src_md, "axb"));
            is_dst_nhwc = (orig_dst_mdw().is_plain()
                    && dst_layout == layout_t(dst_md, "axb"));
            if (is_src_nhwc != is_dst_nhwc) return status::unimplemented;

            // HWord loads require 32 byte alignment. For NHWC layout it means
            // input/output channels must be multiples of 32 bytes.
            size_t ic_bytes = ic * types::data_type_size(src_data_type);
            size_t oc_bytes = oc * types::data_type_size(dst_data_type);
            if (is_dst_nhwc && (ic_bytes % 32 != 0 || oc_bytes % 32 != 0))
                return status::unimplemented;
        }
        if (!is_src_nhwc
                && !src_layout.is_strictly_equal(make_layout(src_md, src_tag)))
            return status::unimplemented;
        if (!is_dst_nhwc
                && !dst_layout.is_strictly_equal(make_layout(dst_md, dst_tag)))
            return status::unimplemented;
        if (!wei_layout.is_strictly_equal(make_layout(wei_md, wei_tag)))
            return status::unimplemented;
        return status::success;
    }

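    // SLM buffering: A is staged through SLM when the thread group is split
    // along dimension 0 (threads working on different N blocks share the same
    // A tile), and similarly B when split along dimension 1. With loop unroll,
    // double/triple buffering is used to overlap SLM loads with computation.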
enable_slm_buffering()1296     void enable_slm_buffering() {
1297         using namespace ir_utils;
1298 
1299         use_a_slm = (tg_grid_dim[0] > 1);
1300         use_b_slm = (tg_grid_dim[1] > 1);
1301         if (use_a_slm || use_b_slm) {
1302             int pref_slm_bufs = (tg_grid_dim[0] * tg_grid_dim[1] <= 8 ? 2 : 3);
1303             if (do_loop_unroll) {
1304                 slm_bufs = pref_slm_bufs;
1305                 gmem_bufs = (is_dpas_fma() ? 2 : 1);
1306             } else {
1307                 // Double/triple SLM buffering is not supported when only one
1308                 // matrix is SLM-buffered.
1309                 slm_bufs = (use_a_slm == use_b_slm ? pref_slm_bufs : 1);
1310                 gmem_bufs = 1;
1311             }
1312         } else {
1313             slm_bufs = 0;
1314             gmem_bufs = 0;
1315         }
1316 #ifdef GEN_CONV_DEBUG
1317         use_a_slm = getenv_bool("use_a_slm", use_a_slm);
1318         use_b_slm = getenv_bool("use_b_slm", use_b_slm);
1319         slm_bufs = getenv_int("slm_bufs", slm_bufs);
1320         gmem_bufs = getenv_int("gmem_bufs", gmem_bufs);
1321 #endif
1322     }
1323 
enable_prefetch()1324     void enable_prefetch() {
1325         using namespace ir_utils;
1326 
1327         use_prefetch = true;
1328         prefetch_bufs = is_bwd_w ? 2 : 3;
1329 #ifdef GEN_CONV_DEBUG
1330         use_prefetch = getenv_bool("use_prefetch", use_prefetch);
1331         prefetch_bufs = getenv_int("prefetch_bufs", prefetch_bufs);
1332 #endif
1333     }
1334 
disable_slm_buffering()1335     void disable_slm_buffering() {
1336         use_a_slm = false;
1337         use_b_slm = false;
1338         slm_bufs = 0;
1339         gmem_bufs = 0;
1340     }
1341 
disable_prefetch()1342     void disable_prefetch() {
1343         use_prefetch = false;
1344         prefetch_bufs = 0;
1345     }
1346 
1347     // Overwrites parameters that are implied by other parameters.
fixup_inference_consistency()1348     void fixup_inference_consistency() {
1349         // Can't reuse headers with loop unroll and post-increment offset updates.
1350         if (reuse_headers) do_loop_unroll = false;
1351 
1352         bool prefer_prefetch = false;
1353         if (hw >= ngen::HW::XeHPC) prefer_prefetch = true;
1354 
1355         if (use_preload) {
1356             // Prefetches are only supported with loop unrolling.
1357             if (prefer_prefetch && do_loop_unroll) {
1358                 enable_prefetch();
1359             } else {
1360                 enable_slm_buffering();
1361             }
1362         }
        // Downgrade dpasw -> dpas in some cases.
        if (fma_kind == fma_kind_t::dpasw) {
            // dpasw is executed by fused EUs (across the X thread group
            // dimension). Do not use dpasw if X is uneven.
            if (tg_grid_dim[0] % 2 != 0) fma_kind = fma_kind_t::dpas;
            // dpasw can't be generated when an operand is loaded directly
            // from GMEM and reordered in GRF.
            if (is_bwd_w && allow_grf_reorder && (!use_a_slm || !use_b_slm))
                fma_kind = fma_kind_t::dpas;
        }
    }

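    // Tries to fit the kernel into the GRF budget by progressively relaxing
    // performance-oriented settings (GRF buffering, B sub-tiling, SLM
    // buffering depth, header reuse, loop unrolling). Returns true if the
    // estimated register count fits into the budget.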
    bool try_reduce_grf_usage() {
        if (!reduce_grf_usage) return true;

        // TODO: Improve the register count estimate; it fails to account for
        // temporary values such as mask_registers, among other things.
        double reg_factor = is_bwd_w ? 0.875 : 0.95;
        int max_regs = int(regs * reg_factor);
        int regs = estimate_register_count();
        if (regs <= max_regs) return true;

        // Try to disable GRF buffering.
        if (gmem_bufs > 1) {
            gmem_bufs = 1;
            int regs = estimate_register_count();
            if (regs <= max_regs) return true;
        }

        // Try to use sub-tiles for B.
        int n_thr_blk = utils::div_up(n_tg_blk, tg_grid_dim[0]);
        int max_b_sub_tiles
                = std::min((use_b_slm ? 4 : 2), n_thr_blk / simd_size);
        // XXX: avoid layout mismatch for B loads
        if (hw >= ngen::HW::XeHPC && is_bwd_w) max_b_sub_tiles = 2;
        while (b_sub_tiles < max_b_sub_tiles) {
            b_sub_tiles *= 2;
            int regs = estimate_register_count();
            if (regs <= max_regs) return true;
        }

        // Try to use double SLM buffering.
        if (slm_bufs == 3) {
            slm_bufs = 2;
            int regs = estimate_register_count();
            if (regs <= max_regs) return true;
        }

        // Try to use single SLM buffering.
        if (slm_bufs == 2) {
            slm_bufs = 1;
            int regs = estimate_register_count();
            if (regs <= max_regs) return true;
        }

        // Last resort settings to reduce GRF usage.
        reuse_headers = true;
        do_loop_unroll = false;

        return estimate_register_count() <= max_regs;
    }

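    // Rough estimate of the number of GRF registers consumed by the kernel:
    // A/B/C tiles, message headers, GMEM -> SLM staging buffers and GRF
    // reorder buffers.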
    int estimate_register_count() const {
        int reg_bytes = ngen::GRF::bytes(hw);

        // Assume 8 HWords per GMEM load for double-blocked layouts and
        // 1 HWord otherwise.
        int hword_bytes = 32;
        int a_gmem_msg_bytes
                = (a_layout().is_n_blocked(2) ? 8 : 1) * hword_bytes;
        int b_gmem_msg_bytes
                = (b_layout().is_n_blocked(2) ? 8 : 1) * hword_bytes;

        // Assume 8 HWords per SLM load/store.
        int slm_msg_bytes = 256;

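        // Per-thread tile sizes: the thread group level M/N blocks are split
        // across the thread grid (M across tg1, N across tg0).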
        int nthr = tg_grid_dim[0] * tg_grid_dim[1];
        int m_thr_blk = utils::div_up(m_tg_blk, tg_grid_dim[1]);
        int n_thr_blk = utils::div_up(n_tg_blk, tg_grid_dim[0]);
        int k_thr_blk = k_blk;

        int a_size = int(types::data_type_size(a_data_type));
        int b_size = int(types::data_type_size(b_data_type));
        int acc_size = int(types::data_type_size(acc_data_type));

        // Registers for C += A * B operation.
        int a_tile_bytes = m_thr_blk * k_thr_blk * a_size;
        int b_tile_bytes = k_thr_blk * n_thr_blk * b_size;
        int a_bytes = utils::div_up(a_tile_bytes, a_sub_tiles);
        int b_bytes = utils::div_up(b_tile_bytes, b_sub_tiles);
        int acc_bytes = m_thr_blk * n_thr_blk * acc_size;

        int a_regs = utils::div_up(a_bytes, reg_bytes);
        int b_regs = utils::div_up(b_bytes, reg_bytes);
        int acc_regs = utils::div_up(acc_bytes, reg_bytes);
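        // For illustration only: with a hypothetical 32 x 32 x 32 thread
        // block, an f32 accumulator (4 bytes) and 32-byte GRFs, acc_bytes is
        // 32 * 32 * 4 = 4096, i.e. acc_regs = 128.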

        int a_headers = utils::div_up(
                a_tile_bytes, use_a_slm ? slm_msg_bytes : a_gmem_msg_bytes);
        int b_headers = utils::div_up(
                b_tile_bytes, use_b_slm ? slm_msg_bytes : b_gmem_msg_bytes);

        if (fma_kind == fma_kind_t::dpasw) {
            // dpasw reuses registers between fused threads across tg0. M is
            // split across tg1 and N is split across tg0, so dpasw allows
            // sharing matrix A, which is (M x K).
            a_regs = utils::div_up(a_regs, 2);
            a_headers = utils::div_up(a_headers, 2);
        }

        // Size of A/B thread blocks when the full A/B TG blocks are split
        // across all threads in the TG.
        int a_tg_per_thr_bytes = utils::div_up(m_tg_blk * k_blk * a_size, nthr);
        int b_tg_per_thr_bytes = utils::div_up(k_blk * n_tg_blk * b_size, nthr);

        // Temporary registers for GMEM -> SLM load.
        int a_g2s_bytes = (use_a_slm ? a_tg_per_thr_bytes : 0);
        int b_g2s_bytes = (use_b_slm ? b_tg_per_thr_bytes : 0);

        // Account for dedicated headers for prefetches.
        if (use_prefetch) {
            a_headers += utils::div_up(a_tg_per_thr_bytes, a_gmem_msg_bytes);
            b_headers += utils::div_up(b_tg_per_thr_bytes, b_gmem_msg_bytes);
        }

        int a_g2s_regs = utils::div_up(a_g2s_bytes, reg_bytes);
        int b_g2s_regs = utils::div_up(b_g2s_bytes, reg_bytes);

        // Two sets of headers for GMEM -> GRF and GRF -> SLM.
        int a_g2s_headers = utils::div_up(a_g2s_bytes, a_gmem_msg_bytes)
                + utils::div_up(a_g2s_bytes, slm_msg_bytes);
        int b_g2s_headers = utils::div_up(b_g2s_bytes, b_gmem_msg_bytes)
                + utils::div_up(b_g2s_bytes, slm_msg_bytes);

        // Extra registers for GRF <-> GRF reorders.
        int reorder_regs = 0;

        // Assume A/B need reorders to temporary buffers.
        if (allow_grf_reorder) {
            if (is_bwd_w) {
                // Hardcoded for now; this is the upper bound for the
                // temporary buffer size for BWD_W.
                int bwd_w_reorder_regs = 16;
                reorder_regs += bwd_w_reorder_regs;
            }

            int ab_reorder_regs = 0;

            if (use_a_slm) {
                ab_reorder_regs = std::max(ab_reorder_regs, a_g2s_regs);
            } else {
                int a_reorder_regs = a_regs;
                // Loads must be aligned to a GRF boundary; account for cases
                // when the load size is less than the register size.
                if (a_gmem_msg_bytes < reg_bytes) {
                    a_reorder_regs
                            *= utils::div_up(reg_bytes, a_gmem_msg_bytes);
                }
                ab_reorder_regs = std::max(ab_reorder_regs, a_reorder_regs);
            }
            if (use_b_slm) {
                ab_reorder_regs = std::max(ab_reorder_regs, b_g2s_regs);
            } else {
                int b_reorder_regs = b_regs;
                // Loads must be aligned to a GRF boundary; account for cases
                // when the load size is less than the register size.
                if (b_gmem_msg_bytes < reg_bytes) {
                    b_reorder_regs
                            *= utils::div_up(reg_bytes, b_gmem_msg_bytes);
                }
                ab_reorder_regs = std::max(ab_reorder_regs, b_reorder_regs);
            }
            reorder_regs += ab_reorder_regs;
        }

        int g2s_regs = gmem_bufs * (a_g2s_regs + b_g2s_regs);
        int g2s_headers = a_g2s_headers + b_g2s_headers;

        int data_regs = a_regs + b_regs + acc_regs + g2s_regs;
        int header_regs = a_headers + b_headers + g2s_headers;
        if (reuse_headers) header_regs = 1;

        int estimated_regs = data_regs + reorder_regs + header_regs;

        return estimated_regs;
    }

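    // GEMM view of the convolution operands: A is the source (FWD, BWD_W) or
    // destination (BWD_D) tensor; B is the weights (FWD, BWD_D) or
    // destination (BWD_W) tensor.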
    const layout_t &a_layout() const {
        const layout_t *ret = nullptr;
        if (is_fwd) {
            ret = &src_layout;
        } else if (is_bwd_d) {
            ret = &dst_layout;
        } else if (is_bwd_w) {
            ret = &src_layout;
        }
        ir_assert(ret && !ret->is_empty()) << "Layout is not initialized.";
        return *ret;
    }

    const layout_t &b_layout() const {
        const layout_t *ret = nullptr;
        if (is_fwd) {
            ret = &wei_layout;
        } else if (is_bwd_d) {
            ret = &wei_layout;
        } else if (is_bwd_w) {
            ret = &dst_layout;
        }
        ir_assert(ret && !ret->is_empty()) << "Layout is not initialized.";
        return *ret;
    }

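    // Prepends a group dimension to a format tag by shifting every dimension
    // letter by one position, e.g. "ABx16a16b" becomes "aBCx16b16c".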
    static std::string prepend_groups_to_tag(const std::string &tag) {
        auto ret = tag;
        for (auto &c : ret) {
            bool is_lower_dim = ('a' <= c && c < 'a' + DNNL_MAX_NDIMS);
            bool is_upper_dim = ('A' <= c && c < 'A' + DNNL_MAX_NDIMS);
            if (!is_lower_dim && !is_upper_dim) continue;
            c += 1;
        }
        return "a" + ret;
    }

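    // Returns the layout to use for a tensor. When the user memory descriptor
    // has format_kind::any, the optimal layout is chosen and written back to
    // the descriptor; otherwise the user-provided layout is kept.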
    static layout_t init_layout(
            memory_desc_t &user_md, const std::string &optimal_tag) {
        auto optimal = make_layout(user_md, optimal_tag);
        if (user_md.format_kind != format_kind::any) {
            auto user = make_layout(user_md);
            // If the layouts are physically different, return the layout
            // passed by the user; status::unimplemented is returned later.
            if (user != optimal) return user;
        } else {
            user_md = optimal.to_dnnl(user_md.dims);
        }
        return optimal;
    }

    static layout_t make_layout(const memory_desc_t &md) {
        return layout_t(md, /*do_normalize=*/false);
    }

    static layout_t make_layout(
            const memory_desc_t &md, const std::string &tag) {
        return layout_t(md, tag, /*do_normalize=*/false);
    }
};

inline std::ostream &operator<<(std::ostream &out, const conv_config_t &cfg) {
    out << cfg.str();
    return out;
}

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif