1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef GPU_JIT_CONV_CONFIG_HPP
18 #define GPU_JIT_CONV_CONFIG_HPP
19
20 #include <iostream>
21 #include <sstream>
22
23 #include "common/c_types_map.hpp"
24 #include "common/convolution_pd.hpp"
25 #include "common/math_utils.hpp"
26 #include "common/memory_desc_wrapper.hpp"
27 #include "common/type_helpers.hpp"
28 #include "gpu/compute/compute.hpp"
29 #include "gpu/compute/compute_engine.hpp"
30 #include "gpu/jit/conv/fma_support.hpp"
31 #include "gpu/jit/conv/tensor.hpp"
32 #include "gpu/jit/conv/utils.hpp"
33 #include "gpu/jit/jit_eltwise_injector.hpp"
34
35 namespace dnnl {
36 namespace impl {
37 namespace gpu {
38 namespace jit {
39
40 // Description of the convolution problem.
41 class conv_problem_t {
42 public:
43 conv_problem_t() = default;
44
45 status_t init(convolution_pd_t *conv_pd) {
46 if (conv_pd->has_zero_dim_memory()) return status::unimplemented;
47
48 is_fwd = conv_pd->is_fwd();
49 is_bwd_d = conv_pd->is_bwd_d();
50 is_bwd_w = conv_pd->is_bwd_w();
51 with_bias = conv_pd->with_bias();
52 with_groups = conv_pd->with_groups();
53
54 orig_src_md = *conv_pd->invariant_src_md();
55 orig_wei_md = *conv_pd->invariant_wei_md();
56 orig_dst_md = *conv_pd->invariant_dst_md();
57 orig_bia_md = *conv_pd->invariant_bia_md();
58
59 src_data_type = orig_src_md.data_type;
60 wei_data_type = orig_wei_md.data_type;
61 dst_data_type = orig_dst_md.data_type;
62 bia_data_type = orig_bia_md.data_type;
63
64 if (with_bias)
65 bia_layout = layout_t(orig_bia_md, "a", /*do_normalize=*/false);
66
67 ndims = conv_pd->ndims();
68
69 mb = conv_pd->MB();
70 g = conv_pd->G();
71 ic = ir_utils::safe_divide(conv_pd->IC(), g);
72 oc = ir_utils::safe_divide(conv_pd->OC(), g);
73
74 // Input spatial.
75 id = conv_pd->ID();
76 ih = conv_pd->IH();
77 iw = conv_pd->IW();
78
79 // Output spatial.
80 od = conv_pd->OD();
81 oh = conv_pd->OH();
82 ow = conv_pd->OW();
83
84 // Kernel sizes.
85 kd = conv_pd->KD();
86 kh = conv_pd->KH();
87 kw = conv_pd->KW();
88
89 // Strides.
90 sd = conv_pd->KSD();
91 sh = conv_pd->KSH();
92 sw = conv_pd->KSW();
93
94 // Padding.
95 pd = conv_pd->padFront();
96 ph = conv_pd->padT();
97 pw = conv_pd->padL();
98
99 // Dilation.
100 dd = conv_pd->KDD();
101 dh = conv_pd->KDH();
102 dw = conv_pd->KDW();
103
104 try_reduce_to_1d();
105
106 is_dw = with_groups && (g > 1) && (oc == 1) && (ic == 1);
107
108 return status::success;
109 }
110
111 // Reduces spatial dimensions to 1D for 1x1 kernels with unit strides and
// equal input/output spatial sizes.
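// A minimal worked example (hypothetical shapes, not taken from this file):
// a 1x1x1 kernel with unit strides, zero padding and od/oh/ow == id/ih/iw
// == 2 x 28 x 28 collapses into W, giving ow = iw = 2 * 28 * 28 = 1568 and
// od = oh = id = ih = kd = kh = 1.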
112 void try_reduce_to_1d() {
113 bool is_1x1 = (kd * kh * kw == 1);
114 bool is_stride1 = (sd == 1 && sh == 1 && sw == 1);
115 bool is_eq_oi = (od == id && oh == ih && ow == iw);
116 if (is_1x1 && is_stride1 && is_eq_oi) {
117 ir_assert(pd == 0 && ph == 0 && pw == 0);
118 ow = od * oh * ow;
119 iw = id * ih * iw;
120 od = id = kd = 1;
121 oh = ih = kh = 1;
122 reduced_to_1d = true;
123 }
124 }
125
126 memory_desc_wrapper orig_src_mdw() const {
127 return memory_desc_wrapper(orig_src_md);
128 }
129 memory_desc_wrapper orig_wei_mdw() const {
130 return memory_desc_wrapper(orig_wei_md);
131 }
132 memory_desc_wrapper orig_dst_mdw() const {
133 return memory_desc_wrapper(orig_dst_md);
134 }
135
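// Builds a compact problem descriptor. For a hypothetical 3x3 problem with
// mb = 16, g = 1, ic = oc = 64 and 56x56 spatial, it would return
// "mb16g1ic64id1ih56iw56oc64od1oh56ow56kd1kh3kw3pd0ph1pw1".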
136 std::string desc_str() const {
137 std::ostringstream oss;
138 oss << "mb" << mb;
139 oss << "g" << g;
140 oss << "ic" << ic;
141 oss << "id" << id;
142 oss << "ih" << ih;
143 oss << "iw" << iw;
144 oss << "oc" << oc;
145 oss << "od" << od;
146 oss << "oh" << oh;
147 oss << "ow" << ow;
148 oss << "kd" << kd;
149 oss << "kh" << kh;
150 oss << "kw" << kw;
151 oss << "pd" << pd;
152 oss << "ph" << ph;
153 oss << "pw" << pw;
154 return oss.str();
155 }
156
157 memory_desc_t orig_src_md;
158 memory_desc_t orig_wei_md;
159 memory_desc_t orig_dst_md;
160 memory_desc_t orig_bia_md;
161
162 layout_t src_layout;
163 layout_t wei_layout;
164 layout_t dst_layout;
165 layout_t bia_layout;
166
167 data_type_t src_data_type;
168 data_type_t wei_data_type;
169 data_type_t dst_data_type;
170 data_type_t bia_data_type;
171
172 bool is_fwd;
173 bool is_bwd_d;
174 bool is_bwd_w;
175 bool with_bias;
176 bool with_groups;
177 bool is_dw;
178
179 int ndims;
180 int mb; // Batch size.
181 int g; // Groups.
182 int ic, oc; // Input and output channels.
183 int id, ih, iw; // Input spatial sizes.
184 int od, oh, ow; // Output spatial sizes.
185 int kd, kh, kw; // Kernel sizes.
186 int sd, sh, sw; // Strides.
187 int pd, ph, pw; // Padding in the beginning.
188 int dd, dh, dw; // Dilation.
189 bool reduced_to_1d; // Whether the problem spatial was reduced to 1D.
190 };
191
192 // Parameters for kernel generation.
193 class conv_config_t : public conv_problem_t {
194 public:
195 conv_config_t() = default;
196
197 status_t init(convolution_pd_t *conv_pd, primitive_attr_t *attr,
198 engine_t *engine) {
199 // These functions have implicit dependencies between them. They cannot be
200 // reordered without verifying that these dependencies are satisfied.
201 CHECK(conv_problem_t::init(conv_pd));
202 CHECK(init_hw(engine));
203 CHECK(init_abc_data_types());
204 CHECK(init_acc_data_type());
205 CHECK(init_fma_kind());
206 CHECK(init_data_layouts(conv_pd));
207
208 if (!data_types_ok()) return status::unimplemented;
209
210 // Group convolution is not supported, except for depthwise convolution,
211 // which is supported for forward only.
212 if (with_groups && g > 1 && !(is_dw && is_fwd))
213 return status::unimplemented;
214
215 CHECK(init_common_config());
216
217 const memory_desc_t *output_md = nullptr;
218 if (is_fwd) {
219 CHECK(init_fwd(conv_pd, engine));
220 output_md = conv_pd->dst_md();
221 } else if (is_bwd_d) {
222 CHECK(init_bwd_d(conv_pd));
223 output_md = conv_pd->diff_src_md();
224 } else if (is_bwd_w) {
225 CHECK(init_bwd_w());
226 output_md = conv_pd->diff_weights_md();
227 } else {
228 ir_error_not_expected();
229 }
230
231 CHECK(attr->set_default_formats(output_md));
232
233 if (!post_ops_ok(conv_pd)) return status::unimplemented;
234 if (!hw_ok(engine)) return status::unimplemented;
235
236 return status::success;
237 }
238
239 status_t init_fwd(convolution_pd_t *conv_pd, engine_t *engine) {
240 using namespace ir_utils;
241
242 if (ic < 16 && !is_dpas_fma() && !is_dw) return status::unimplemented;
243
244 auto &src_md = *conv_pd->invariant_src_md();
245 auto &dst_md = *conv_pd->invariant_dst_md();
246 bool is_src_nhwc = (orig_src_mdw().is_plain()
247 && src_layout == make_layout(src_md, "axb"));
248 bool is_dst_nhwc = (orig_dst_mdw().is_plain()
249 && dst_layout == make_layout(dst_md, "axb"));
250 // Set dispatch and kernel parameters.
251 if (is_dw) {
252 g_tg_blk = (is_int8_dst() ? 32 : 16);
253 mb_thr_blk
254 = (mb < 16 || is_src_nhwc ? 1
255 : hw <= ngen::HW::XeLP ? 8 : 16);
256 mb_thr_dim = (mb_thr_blk == 1 ? 1 : 2);
257 ow_thr_blk = (mb_thr_blk == 1 ? 8 : 1);
258 ow_thr_dim = 1;
259 oc_thr_blk = 1;
260 oc_thr_dim = 1;
261 ic_thr_dim = 1;
262 ic_blk = 1;
263
264 int iw_load_blk = (ow_thr_blk - 1) * sw + (kw - 1) + 1;
265 bool do_kw_buf = (kw > 1 && mb_thr_blk == 1 && iw_load_blk <= 32);
266 kw_blk = (do_kw_buf ? kw : 1);
267 } else if (fma_kind == fma_kind_t::mad
268 && src_data_type == data_type::f32) {
269 const int max_tg_size = 16;
270 g_tg_blk = 1;
271 mb_thr_blk = (mb < 16 ? 1 : 8);
272 mb_thr_dim = std::min((mb_thr_blk != 1) ? (32 / mb_thr_blk) : 1,
273 utils::div_up(mb, mb_thr_blk));
274 #ifdef GEN_CONV_DEBUG
275 mb_thr_blk = getenv_int("mb_thr_blk", mb_thr_blk);
276 #endif
277 oc_thr_blk = 16;
278 oc_thr_dim = std::min(4, utils::div_up(oc, oc_thr_blk));
279 oc_thr_dim = (1 << math::ilog2q(oc_thr_dim));
280
281 if (mb_thr_dim > 1) {
282 ow_thr_blk = 1;
283 ow_thr_dim = 1;
284 } else {
285 const int pref_ow_thr_dim
286 = max_tg_size / (oc_thr_dim * mb_thr_dim);
287 const int pref_ow_block
288 = (mb_thr_blk == 1) ? 8 : kw > 1 ? 4 : 1;
289 ow_thr_blk = ow < pref_ow_block * pref_ow_thr_dim
290 ? (1 << math::ilog2q(
291 utils::div_up(ow, pref_ow_thr_dim)))
292 : pref_ow_block;
293 ow_thr_dim = pref_ow_thr_dim;
294 }
295 ic_thr_dim = 1;
296 kw_blk = 1;
297 ic_blk = (is_small_ic() ? ic : 16);
298 } else if (is_dpas_fma()) {
299 g_tg_blk = 1;
300 mb_thr_blk = is_small_ic() ? 8 : (mb < 16 ? 1 : 32);
301 mb_thr_dim = (is_small_ic())
302 ? (mb < 16 ? std::min(utils::div_up(mb, mb_thr_blk), 4) : 4)
303 : 1;
304 oc_thr_blk = 32;
305 if (hw >= ngen::HW::XeHPC && !is_small_ic()) oc_thr_blk = 64;
306 oc_thr_dim = std::min(4, utils::div_up(oc, oc_thr_blk));
307 oc_thr_dim = (1 << math::ilog2q(oc_thr_dim));
308 if (is_small_ic()) {
309 ow_thr_blk = 4;
310 } else {
311 ow_thr_blk = (mb < 16 ? 16 : 1);
312 if (ow < ow_thr_blk) ow_thr_blk = 8;
313 }
314 ow_thr_dim = is_small_ic()
315 ? 1
316 : std::min(4, utils::div_up(ow, ow_thr_blk));
317 if (is_small_ic()) {
318 kw_blk = 8;
319 ic_blk = (is_s32_accumulator() ? 4 : 2);
320 } else {
321 kw_blk = 1;
322 ic_blk = (is_s32_accumulator() ? 32 : 16);
323 }
324
325 ic_thr_dim = init_fwd_ic_thr_dim(
326 engine, mb_thr_blk, oc_thr_blk, ow_thr_blk, ic_blk);
327
328 // Disable M/N thread group blocking when K thread group blocking
329 // is enabled. For some reason combining them results in lower
330 // performance.
331 if (ic_thr_dim > 1) {
332 ow_thr_dim = 1;
333 oc_thr_dim = 1;
334 }
335 } else {
336 ir_error_not_expected();
337 }
338 g_thr_blk = g_tg_blk;
339
340 int ic_padded = utils::rnd_up(ic, ic_blk);
341 ic_thr_blk = ir_utils::safe_divide(ic_padded, ic_thr_dim);
342
343 ow_thr_dim = (1 << math::ilog2q(ow_thr_dim));
344
345 #ifdef GEN_CONV_DEBUG
346 mb_thr_blk = getenv_int("mb_thr_blk", mb_thr_blk);
347 mb_thr_dim = getenv_int("mb_thr_dim", mb_thr_dim);
348 oc_thr_blk = getenv_int("oc_thr_blk", oc_thr_blk);
349 oc_thr_dim = getenv_int("oc_thr_dim", oc_thr_dim);
350 ow_thr_blk = getenv_int("ow_thr_blk", ow_thr_blk);
351 ow_thr_dim = getenv_int("ow_thr_dim", ow_thr_dim);
352 #endif
353
354 tg_grid_dim[0] = oc_thr_dim;
355 tg_grid_dim[1] = mb_thr_dim * ow_thr_dim;
356 tg_grid_dim[2] = ic_thr_dim;
357
358 // Round down to a power of 2.
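// For example, a hypothetical tg_grid_dim[0] of 6 becomes 4, 5 becomes 4,
// and exact powers of two (1, 2, 4, ...) are left unchanged.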
359 tg_grid_dim[0] = (1 << math::ilog2q(tg_grid_dim[0]));
360 tg_grid_dim[1] = (1 << math::ilog2q(tg_grid_dim[1]));
361 tg_grid_dim[2] = (1 << math::ilog2q(tg_grid_dim[2]));
362
363 #ifdef GEN_CONV_DEBUG
364 tg_grid_dim[0] = getenv_int("tg0", tg_grid_dim[0]);
365 tg_grid_dim[1] = getenv_int("tg1", tg_grid_dim[1]);
366 #endif
367
368 mb_tg_blk = mb_thr_dim * mb_thr_blk;
369 oc_tg_blk = oc_thr_dim * oc_thr_blk;
370 ow_tg_blk = ow_thr_dim * ow_thr_blk;
371
372 #ifdef GEN_CONV_DEBUG
373 mb_tg_blk = getenv_int("mb_tg_blk", mb_tg_blk);
374 oc_tg_blk = getenv_int("oc_tg_blk", oc_tg_blk);
375 ow_tg_blk = getenv_int("ow_tg_blk", ow_tg_blk);
376 #endif
377
378 // TODO: Update estimate_register_count.
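// In the implicit GEMM view assumed by this kernel (FWD: src -> A, wei -> B,
// dst -> C), M maps to MB x OW, N to OC and the per-iteration K block to
// IC x KW; OD/OH and KD/KH are handled by the kernel grid and the reduction
// loops, and groups act as the batch dimension B.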
379 b_blk = g_tg_blk;
380 m_tg_blk = mb_tg_blk * ow_tg_blk;
381 n_tg_blk = oc_tg_blk;
382 k_blk = ic_blk * kw_blk;
383
384 int g_tg_padded = utils::rnd_up(g, g_tg_blk);
385 int mb_tg_padded = utils::rnd_up(mb, mb_tg_blk);
386 int oc_tg_padded = utils::rnd_up(oc, oc_tg_blk);
387 int ow_tg_padded = utils::rnd_up(ow, ow_tg_blk);
388
389 g_tg_dim = g_tg_padded / g_tg_blk;
390 mb_tg_dim = mb_tg_padded / mb_tg_blk;
391 oc_tg_dim = oc_tg_padded / oc_tg_blk;
392
393 ow_tg_dim = ow_tg_padded / ow_tg_blk;
394
395 kernel_grid_dim[0] = oc_tg_dim;
396 kernel_grid_dim[1] = g_tg_dim * od * oh * ow_tg_dim;
397 kernel_grid_dim[2] = mb_tg_dim;
398
399 allow_grf_reorder = is_small_ic() || is_dw;
400
401 if (kd * kh * kw > 9) do_loop_unroll = false;
402 if (is_dw) {
403 use_preload = false;
404 do_loop_unroll = false;
405 }
406 if (is_small_ic()) {
407 reuse_headers = true;
408 do_loop_unroll = false;
409 }
410
411 regs = hw <= ngen::HW::XeLP ? 128 : 256;
412
413 // XXX: in the case of NHWC layouts or small MB, allow reorders on XeHPC
414 // since A/B tile loads may be strided.
415 if (hw >= ngen::HW::XeHPC
416 && (mb_thr_blk == 1 || is_src_nhwc || is_dst_nhwc))
417 allow_grf_reorder = true;
418
419 if (mb >= 16) {
420 // Large batch performance is slightly behind for some cases.
421 bool large_batch_ok = false;
422 if (hw >= ngen::HW::XeHPC) large_batch_ok = true;
423 if (is_src_nhwc) large_batch_ok = true;
424 // TODO: Fix issues with mb zero padding
425 if (is_small_ic() && mb % 16 == 0) large_batch_ok = true;
426 if (!large_batch_ok) return status::unimplemented;
427 }
428
429 fixup_inference_consistency();
430 if (!try_reduce_grf_usage()) return status::unimplemented;
431
432 return status::success;
433 }
434
435 status_t init_bwd_d(convolution_pd_t *conv_pd) {
436 using namespace ir_utils;
437
438 // Set dispatch and kernel parameters.
439 mb_thr_blk = (mb < 16 ? 1 : 32);
440 ic_thr_blk = 32;
441 if (hw >= ngen::HW::XeHPC) ic_thr_blk = 64;
442 iw_thr_blk = (mb < 16 ? 16 : 1);
443 if (iw < iw_thr_blk) iw_thr_blk = 8;
444
445 #ifdef GEN_CONV_DEBUG
446 mb_thr_blk = getenv_int("mb_thr_blk", mb_thr_blk);
447 ic_thr_blk = getenv_int("ic_thr_blk", ic_thr_blk);
448 iw_thr_blk = getenv_int("iw_thr_blk", iw_thr_blk);
449 #endif
450
451 regs = 256;
452
453 tg_grid_dim[0] = std::min(4, utils::div_up(ic, ic_thr_blk));
454 tg_grid_dim[1] = std::min(4, utils::div_up(iw, iw_thr_blk));
455 tg_grid_dim[2] = 1;
456
457 // Round down to a power of 2.
458 tg_grid_dim[0] = (1 << math::ilog2q(tg_grid_dim[0]));
459 tg_grid_dim[1] = (1 << math::ilog2q(tg_grid_dim[1]));
460 tg_grid_dim[2] = (1 << math::ilog2q(tg_grid_dim[2]));
461
462 #ifdef GEN_CONV_DEBUG
463 tg_grid_dim[0] = getenv_int("tg0", tg_grid_dim[0]);
464 tg_grid_dim[1] = getenv_int("tg1", tg_grid_dim[1]);
465 #endif
466
467 mb_tg_blk = mb_thr_blk;
468 ic_tg_blk = tg_grid_dim[0] * ic_thr_blk;
469 iw_tg_blk = tg_grid_dim[1] * iw_thr_blk;
470 oc_blk = (is_s32_accumulator() ? 32 : 16);
471
472 #ifdef GEN_CONV_DEBUG
473 mb_tg_blk = getenv_int("mb_tg_blk", mb_tg_blk);
474 ic_tg_blk = getenv_int("ic_tg_blk", ic_tg_blk);
475 iw_tg_blk = getenv_int("iw_tg_blk", iw_tg_blk);
476 #endif
477
478 m_tg_blk = mb_tg_blk * iw_tg_blk;
479 n_tg_blk = ic_tg_blk;
480 k_blk = oc_blk;
481
482 int mb_tg_padded = utils::rnd_up(mb, mb_tg_blk);
483 int ic_tg_padded = utils::rnd_up(ic, ic_tg_blk);
484 int iw_tg_padded = utils::rnd_up(iw, iw_tg_blk);
485
486 mb_tg_dim = mb_tg_padded / mb_tg_blk;
487 ic_tg_dim = ic_tg_padded / ic_tg_blk;
488
489 iw_tg_dim = iw_tg_padded / iw_tg_blk;
490
491 kernel_grid_dim[0] = ic_tg_dim;
492 kernel_grid_dim[1] = id * ih * iw_tg_dim;
493 kernel_grid_dim[2] = mb_tg_dim;
494
495 allow_grf_reorder = false;
496
497 // Do not perform full unrolling when there are too many inner
498 // iterations.
499 int kernel_limit = is_f32_conv() ? 4 : 9;
500 if (kd * kh * kw > kernel_limit) do_loop_unroll = false;
501
502 // Do not perform full unrolling with non-unit stride. These cases have
503 // non-trivial post-increment updates which result in unrolling all
504 // reduction loops and exceeding the instruction cache.
505 if (sd * sh * sw != 1) do_loop_unroll = false;
506
507 fixup_inference_consistency();
508 if (!try_reduce_grf_usage()) return status::unimplemented;
509
510 auto &src_md = *conv_pd->invariant_src_md();
511 auto &dst_md = *conv_pd->invariant_dst_md();
512
513 // Validate layouts.
514 bool is_src_nhwc = (orig_src_mdw().is_plain()
515 && src_layout == make_layout(src_md, "axb"));
516 bool is_dst_nhwc = (orig_dst_mdw().is_plain()
517 && dst_layout == make_layout(dst_md, "axb"));
518 // XXX: in the case of NHWC layouts or small MB, allow reorders on XeHPC
519 // since A/B tile loads may be strided.
520 if (hw >= ngen::HW::XeHPC
521 && (mb_thr_blk == 1 || is_src_nhwc || is_dst_nhwc))
522 allow_grf_reorder = true;
523
524 if (hw < ngen::HW::XeHPC)
525 // Blocked large batch performance is slightly behind.
526 if (!is_src_nhwc && mb >= 16) return status::unimplemented;
527
528 return status::success;
529 }
530
531 status_t init_bwd_w() {
532 using namespace ir_utils;
533
534 if (fma_kind == fma_kind_t::mad) {
535 // Performance for small ic and small mb is worse than the ocl:ncsp
536 // implementation.
537 if (is_small_ic() && mb < 16) return status::unimplemented;
538
539 oc_thr_blk = simd_size;
540 ic_thr_blk = (ic < simd_size ? utils::rnd_up_pow2(ic) : simd_size);
541 kw_blk = utils::rnd_up_pow2(
542 std::min(utils::div_up(simd_size, ic_thr_blk), kw));
543 mb_blk = mb < 16 ? 1 : 16;
544 mb_tg_blk = mb_blk;
545 ow_thr_blk = mb < 16 ? std::min(16, utils::rnd_up_pow2(ow)) : 1;
546 } else if (is_dpas_fma()) {
547 oc_thr_blk = (oc <= 16 ? 16 : 32);
548 if (hw >= ngen::HW::XeHPC) oc_thr_blk = (oc <= 16 ? 16 : 64);
549 // Minimum value required due to blocking in the dpas data format.
550 int min_ic_thr_blk = is_s32_accumulator() ? 4 : 2;
551 ic_thr_blk = (ic <= 16
552 ? std::max(utils::rnd_up_pow2(ic), min_ic_thr_blk)
553 : mb < 16 ? 16 : 32);
554 kw_blk = utils::rnd_up_pow2(
555 std::min(utils::div_up(16, ic_thr_blk), kw));
556
557 mb_blk = mb < 16 ? 1 : 16;
558 mb_tg_blk = (mb < 16 || mb <= mb_blk) ? mb_blk : 2 * mb_blk;
559 ow_thr_blk = mb < 16 ? 16 : 1;
560 // TODO: Investigate why registers are insufficient even though m_tg_blk is
561 // the same
562 if (mb < 16 && kw > 8 && kw_blk >= 8) kw_blk = 4;
563 } else {
564 ir_error_not_expected();
565 }
566
567 #ifdef GEN_CONV_DEBUG
568 oc_thr_blk = getenv_int("oc_thr_blk", oc_thr_blk);
569 ic_thr_blk = getenv_int("ic_thr_blk", ic_thr_blk);
570 kw_blk = getenv_int("kw_blk", kw_blk);
571 ow_thr_blk = getenv_int("ow_thr_blk", ow_thr_blk);
572 mb_blk = getenv_int("mb_blk", mb_blk);
573 mb_tg_blk = getenv_int("mb_tg_blk", mb_tg_blk);
574 #endif
575
576 kw_tg_dim = utils::div_up(kw, kw_blk);
577
578 int max_oc_thr_dim = 4;
579 int max_ic_thr_dim = 4;
580
581 // Prefer larger thread groups when possible on XeHPC.
582 if (hw >= ngen::HW::XeHPC) {
583 if (oc / oc_thr_blk >= 8) {
584 max_oc_thr_dim = 8;
585 } else {
586 max_ic_thr_dim = 8;
587 }
588 }
589
590 regs = 256;
591 tg_grid_dim[0]
592 = std::min(max_oc_thr_dim, utils::div_up(oc, oc_thr_blk));
593 tg_grid_dim[1]
594 = std::min(max_ic_thr_dim, utils::div_up(ic, ic_thr_blk));
595 tg_grid_dim[2] = 1;
596
597 // Round down to a power of 2.
598 tg_grid_dim[0] = (1 << math::ilog2q(tg_grid_dim[0]));
599 tg_grid_dim[1] = (1 << math::ilog2q(tg_grid_dim[1]));
600 tg_grid_dim[2] = (1 << math::ilog2q(tg_grid_dim[2]));
601
602 #ifdef GEN_CONV_DEBUG
603 tg_grid_dim[0] = getenv_int("tg0", tg_grid_dim[0]);
604 tg_grid_dim[1] = getenv_int("tg1", tg_grid_dim[1]);
605 #endif
606
607 oc_tg_blk = tg_grid_dim[0] * oc_thr_blk;
608 ic_tg_blk = tg_grid_dim[1] * ic_thr_blk;
609 kw_tg_blk = kw_blk;
610
611 init_bwd_w_spatial_blocks();
612
613 mb_unroll = mb_tg_blk / mb_blk;
614 ow_unroll = mb < 16 && is_dpas_fma() ? ow_tg_blk / ow_thr_blk : 1;
615
616 m_tg_blk = ic_tg_blk * kw_tg_blk;
617 n_tg_blk = oc_tg_blk;
618 k_blk = mb_blk * ow_thr_blk;
619
620 int oc_tg_padded = utils::rnd_up(oc, oc_tg_blk);
621 int ic_tg_padded = utils::rnd_up(ic, ic_tg_blk);
622 int mb_tg_padded = utils::rnd_up(mb, mb_tg_blk);
623 int od_tg_padded = utils::rnd_up(od, od_tg_blk);
624 int oh_tg_padded = utils::rnd_up(oh, oh_tg_blk);
625 int ow_tg_padded = utils::rnd_up(ow, ow_tg_blk);
626
627 oc_tg_dim = oc_tg_padded / oc_tg_blk;
628 ic_tg_dim = ic_tg_padded / ic_tg_blk;
629
630 mb_tg_dim = mb_tg_padded / mb_tg_blk;
631 od_tg_dim = od_tg_padded / od_tg_blk;
632 oh_tg_dim = oh_tg_padded / oh_tg_blk;
633 ow_tg_dim = ow_tg_padded / ow_tg_blk;
634
635 kernel_grid_dim[0] = oc_tg_dim;
636 kernel_grid_dim[1] = ic_tg_dim * kd * kh * kw_tg_dim * od_tg_dim
637 * oh_tg_dim * ow_tg_dim;
638 kernel_grid_dim[2] = mb_tg_dim;
639
640 // Set BWD_W-specific settings.
641 do_b_reduction = with_bias;
642 do_loop_unroll = (hw >= ngen::HW::XeHPC && is_dpas_fma() && mb_blk > 1);
643 allow_grf_reorder = is_dpas_fma();
644 zero_out_output = true;
645 do_atomic_update = true;
646 do_post_wei_reorder = (wei_data_type == data_type::bf16);
647 do_post_bia_reorder = (with_bias && bia_data_type == data_type::bf16);
648
649 #ifdef GEN_CONV_DEBUG
650 do_loop_unroll = getenv_bool("do_loop_unroll", do_loop_unroll);
651 allow_grf_reorder = getenv_bool("allow_grf_reorder", allow_grf_reorder);
652 #endif
653
654 fixup_inference_consistency();
655 if (!try_reduce_grf_usage()) return status::unimplemented;
656
657 if (do_post_wei_reorder) {
658 wei_layout = wei_layout.retype(type_t::f32());
659 orig_wei_md.data_type = data_type::f32;
660 }
661 if (do_post_bia_reorder) {
662 bia_layout = bia_layout.retype(type_t::f32());
663 orig_bia_md.data_type = data_type::f32;
664 }
665
666 // XXX: disable f32 bwd_w due to hang
667 if (hw == ngen::HW::XeHP || hw == ngen::HW::XeHPG)
668 if (src_data_type == data_type::f32
669 && dst_data_type == data_type::f32)
670 return status::unimplemented;
671
672 return status::success;
673 }
674
675 void init_bwd_w_spatial_blocks() {
676 od_tg_blk = 1;
677 oh_tg_blk = 1;
678 ow_tg_blk = ow_thr_blk;
679 bool are_small_large_channels
680 = (std::min(ic, oc) <= 64 && std::max(ic, oc) >= 256);
681 int sp_min_blk = 24;
682 int sp_max_blk = (are_small_large_channels ? 100 : 50);
683
684 auto get_score = [&](int oh_blk, int ow_blk) {
685 int sp_blk = oh_blk * ow_blk;
686 int oh_padded = utils::rnd_up(oh, oh_blk);
687 int ow_padded = utils::rnd_up(ow, ow_blk);
688
689 double extra_work
690 = (oh_padded * ow_padded - oh * ow) / double(oh * ow);
691 // ohw_eff == 0: no useful computation
692 // ohw_eff == 1: all computation is useful
693 double ohw_eff = 1 - std::min(extra_work, 1.0);
694 int score = int(ohw_eff * 10000);
695
696 // Prefer [sp_min_blk; sp_max_blk] range for the total spatial size.
697 bool sp_size_ok = (sp_blk >= sp_min_blk && sp_blk <= sp_max_blk);
698
699 if (hw >= ngen::HW::XeHPC) {
700 bool sp_block_ok = false;
701 // Avoid OH blocking when OW blocking is enabled and big enough (to
702 // avoid code explosion due to mandatory unrolling of inner
703 // iterations). Exception: when OH/OW are fully blocked - even with
704 // code explosion such blocks may give the best performance.
705 sp_block_ok |= (oh_blk == 1 || ow_blk <= 2);
706 sp_block_ok |= (oh_blk == oh && ow_blk == ow);
707 if (sp_size_ok && sp_block_ok) {
708 double sp_range = sp_max_blk - sp_min_blk;
709 double sp_score = (sp_blk - sp_min_blk) / sp_range * 100;
710 score += sp_score;
711 }
712 } else if (sp_size_ok) {
713 score += 100;
714 }
715 return score;
716 };
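// A worked example with hypothetical sizes: for oh = ow = 28 and
// ic = oc = 128 (so sp_max_blk = 50), get_score(2, 14) needs no padding
// (extra_work = 0, ohw_eff = 1) and scores 10000 plus the +100
// spatial-size bonus on pre-XeHPC hardware; on XeHPC the bonus is withheld
// because oh_blk > 1 and ow_blk > 2.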
717
718 int max_score = 0;
719 for (int oh_blk = 1; oh_blk <= sp_max_blk; oh_blk++) {
720 for (int ow_blk = ow_thr_blk; ow_blk <= sp_max_blk;
721 ow_blk += ow_thr_blk) {
722 int score = get_score(oh_blk, ow_blk);
723 if (score > max_score) {
724 oh_tg_blk = oh_blk;
725 ow_tg_blk = ow_blk;
726 max_score = score;
727 }
728 }
729 }
730
731 #ifdef GEN_CONV_DEBUG
732 od_tg_blk = getenv_int("od_tg_blk", od_tg_blk);
733 oh_tg_blk = getenv_int("oh_tg_blk", oh_tg_blk);
734 ow_tg_blk = getenv_int("ow_tg_blk", ow_tg_blk);
735 #endif
736 }
737
738 status_t init_common_config() {
739 using namespace ir_utils;
740
741 use_preload = true;
742
743 if (hw <= ngen::HW::XeLP) use_preload = false;
744
745 // No SLM buffering by default (will be enabled later).
746 disable_slm_buffering();
747
748 // No prefetch by default (will be enabled later).
749 disable_prefetch();
750
751 do_b_reduction = false;
752 pad_slm = true;
753 assign_sbids = is_dpas_fma();
754 do_loop_unroll = hw > ngen::HW::XeLP;
755 reduce_grf_usage = true;
756 zero_out_output = false;
757 do_atomic_update = false;
758 reuse_headers = hw <= ngen::HW::XeLP;
759 do_post_wei_reorder = false;
760 do_post_bia_reorder = false;
761 a_sub_tiles = 1;
762 b_sub_tiles = 1;
763
764 #ifdef GEN_CONV_DEBUG
765 use_preload = getenv_bool("use_preload", use_preload);
766 pad_slm = getenv_bool("pad_slm", pad_slm);
767 assign_sbids = getenv_bool("assign_sbids", assign_sbids);
768 do_loop_unroll = getenv_bool("do_loop_unroll", do_loop_unroll);
769 reduce_grf_usage = getenv_bool("reduce_grf_usage", reduce_grf_usage);
770 allow_grf_reorder = getenv_bool("allow_grf_reorder", allow_grf_reorder);
771 reuse_headers = getenv_bool("reuse_headers", reuse_headers);
772 a_sub_tiles = getenv_int("a_sub_tiles", a_sub_tiles);
773 b_sub_tiles = getenv_int("b_sub_tiles", b_sub_tiles);
774 #endif
775
776 return status::success;
777 }
778
779 bool post_ops_ok(const convolution_pd_t *pd) const {
780 auto *attr = pd->attr();
781
782 if (is_fwd || is_bwd_d) {
783 auto attr_skip_mask = primitive_attr_t::skip_mask_t::post_ops
784 | primitive_attr_t::skip_mask_t::oscale_runtime
785 | primitive_attr_t::skip_mask_t::sum_dt;
786 if (!attr->has_default_values(attr_skip_mask)) return false;
787 } else {
788 if (!attr->has_default_values()) return false;
789 }
790
791 if (!attr->output_scales_.has_default_values()) {
792 // Only common and per_oc output scales were tested.
793 if (!utils::one_of(attr->output_scales_.mask_, 0, (1 << 1)))
794 return false;
795 }
796 for (int i = 0; i < attr->post_ops_.len(); i++) {
797 auto &po = attr->post_ops_.entry_[i];
798 if (po.is_eltwise()) {
799 if (!jit_eltwise_injector_f32_is_supported(po.eltwise.alg))
800 return false;
801 } else if (po.is_binary()) {
802 int mask = utils::get_dims_mask(pd->invariant_dst_md()->dims,
803 po.binary.src1_desc.dims, ndims);
804 // per_oc broadcast is always supported.
805 if ((mask & (1 << 1)) == 0) continue;
806 auto rhs_layout = layout_t(po.binary.src1_desc);
807 // No blocks means it's a scalar, which can always be loaded.
808 if (rhs_layout.blocks().empty()) return true;
809
810 auto rhs0 = rhs_layout.blocks()[0];
811 int block_bytes = rhs0.block * rhs_layout.type().size();
812 // Innermost block must:
813 // - be across output channels
814 // - be dense
815 // - be aligned to 32 bytes (for HWord loads)
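// For example, a hypothetical per_oc f16 tensor with an innermost dense
// block of 16 over OC occupies 16 * 2 = 32 bytes and is accepted, while an
// 8-element block (16 bytes) would be rejected.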
816 if (rhs0.dim_idx != 1 || dim_t(rhs0.stride) != 1
817 || block_bytes % 32 != 0)
818 return false;
819 }
820 }
821 return true;
822 }
823
824 bool hw_ok(const engine_t *engine) const {
825 auto *compute_engine
826 = utils::downcast<const compute::compute_engine_t *>(engine);
827 if (regs == 256 && !compute_engine->mayiuse_large_grf_mode())
828 return false;
829 return true;
830 }
831
832 bool data_types_ok() const {
833 bool is_bf16 = utils::one_of(data_type::bf16, src_data_type,
834 wei_data_type, dst_data_type, bia_data_type);
835 if (is_bf16 && hw <= ngen::HW::XeLP) return false;
836
837 if (is_fwd) return true;
838 if (is_bwd_d) return true;
839 if (is_bwd_w) {
840 bool ok = true;
841 ok &= (src_data_type == data_type::bf16
842 || src_data_type == data_type::f32);
843 ok &= (dst_data_type == src_data_type);
844 ok &= utils::one_of(wei_data_type, src_data_type, data_type::f32);
845
846 if (with_bias) {
847 ok &= utils::one_of(
848 bia_data_type, src_data_type, data_type::f32);
849 }
850 return ok;
851 }
852 return false;
853 }
854
855 bool is_s32_accumulator() const { return acc_data_type == data_type::s32; }
856 bool is_f32_conv() const {
857 return utils::everyone_is(src_data_type, wei_data_type, data_type::f32);
858 }
859 bool is_int8_dst() const {
860 return utils::one_of(dst_data_type, data_type::s8, data_type::u8);
861 }
862 bool is_small_ic() const { return ic < simd_size; }
863 bool is_dpas_fma() const {
864 return utils::one_of(fma_kind, fma_kind_t::dpas, fma_kind_t::dpasw);
865 }
866
867 int grf_size() const { return ngen::GRF::bytes(hw); }
868
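// Maps the kernel/thread-group grids to an ND-range, with the SIMD width
// folded into dimension 0. A hypothetical example: simd_size = 8,
// tg_grid_dim = {4, 2, 1} and kernel_grid_dim = {2, 8, 1} give
// lws = {32, 2, 1} and gws = {64, 16, 1}.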
869 compute::nd_range_t nd_range() const {
870 size_t gws[3];
871 size_t lws[3];
872 for (int i = 0; i < 3; i++) {
873 lws[i] = tg_grid_dim[i] * (i == 0 ? simd_size : 1);
874 gws[i] = kernel_grid_dim[i] * lws[i];
875 }
876 return compute::nd_range_t(gws, lws);
877 }
878
879 std::string str() const {
880 using namespace ir_utils;
881
882 std::ostringstream oss;
883 // clang-format off
884 oss << " Problem: " << desc_str() << std::endl;
885 oss << " Source layout: " << src_layout << std::endl;
886 oss << " Weights layout: " << wei_layout << std::endl;
887 oss << " Destination layout: " << dst_layout << std::endl;
888 oss << " MB TG block: " << mb_tg_blk << std::endl;
889 oss << " OD TG block: " << od_tg_blk << std::endl;
890 oss << " OH TG block: " << oh_tg_blk << std::endl;
891 oss << " OW TG block: " << ow_tg_blk << std::endl;
892 oss << " OC TG block: " << oc_tg_blk << std::endl;
893 oss << " Kernel grid: " << make_seq_print_helper(kernel_grid_dim, " x ") << std::endl;
894 oss << " Thread group: " << make_seq_print_helper(tg_grid_dim, " x ") << std::endl;
895 oss << " FMA kind: " << fma_kind::to_string(fma_kind) << std::endl;
896 oss << " Use SLM for A: " << to_string(use_a_slm) << std::endl;
897 oss << " Use SLM for B: " << to_string(use_b_slm) << std::endl;
898 oss << " SLM buffers: " << slm_bufs << std::endl;
899 oss << " GMEM to SLM, GRF buffers: " << gmem_bufs << std::endl;
900 oss << " Pad SLM: " << to_string(pad_slm) << std::endl;
901 oss << " Use prefetch: " << to_string(use_prefetch) << std::endl;
902 oss << " Prefetch buffers: " << prefetch_bufs << std::endl;
903 oss << " Do loop unroll: " << to_string(do_loop_unroll) << std::endl;
904 oss << " Assign SBIDs: " << to_string(assign_sbids) << std::endl;
905 oss << " Reduce GRF usage: " << to_string(reduce_grf_usage) << std::endl;
906 oss << " Reuse headers: " << to_string(reuse_headers) << std::endl;
907 oss << " Allow GRF reorder: " << to_string(allow_grf_reorder) << std::endl;
908 oss << " A sub-tiles: " << a_sub_tiles << std::endl;
909 oss << " B sub-tiles: " << b_sub_tiles << std::endl;
910 // clang-format on
911 return oss.str();
912 }
913
914 data_type_t a_data_type;
915 data_type_t b_data_type;
916 data_type_t c_data_type;
917 data_type_t acc_data_type;
918
919 ngen::HW hw = ngen::HW::Unknown;
920 int simd_size; // SIMD width.
921 int regs; // Number of registers.
922
923 // Thread group dimensions (thread group grid).
924 std::array<int, 3> tg_grid_dim;
925
926 // Number of thread groups across dimensions (kernel grid).
927 std::array<int, 3> kernel_grid_dim;
928
929 // Number of thread group blocks across problem dimensions.
930 int g_tg_dim;
931 int ic_tg_dim;
932 int iw_tg_dim;
933 int kw_tg_dim;
934 int mb_tg_dim;
935 int oc_tg_dim;
936 int od_tg_dim;
937 int oh_tg_dim;
938 int ow_tg_dim;
939
940 // Block sizes per thread group.
941 int g_tg_blk;
942 int ic_tg_blk;
943 int iw_tg_blk;
944 int kw_tg_blk;
945 int mb_tg_blk;
946 int oc_tg_blk;
947 int od_tg_blk;
948 int oh_tg_blk;
949 int ow_tg_blk;
950
951 // Number of thread blocks across problem dimensions.
952 int ic_thr_dim;
953 int mb_thr_dim;
954 int oc_thr_dim;
955 int ow_thr_dim;
956
957 // Block sizes per thread.
958 int g_thr_blk;
959 int ic_thr_blk;
960 int iw_thr_blk;
961 int mb_thr_blk;
962 int oc_thr_blk;
963 int ow_thr_blk;
964
965 // Block sizes per iteration.
966 int ic_blk;
967 int kw_blk;
968 int mb_blk;
969 int oc_blk;
970
971 // Block sizes in GEMM notation.
972 int b_blk;
973 int m_tg_blk;
974 int n_tg_blk;
975 int k_blk;
976
977 // Unroll sizes.
978 int mb_unroll;
979 int ow_unroll;
980
981 bool do_b_reduction;
982
983 fma_kind_t fma_kind; // Which instruction backend to use.
984
985 bool use_preload; // Whether to use SLM or prefetch.
986 bool use_a_slm; // Whether to use SLM for A.
987 bool use_b_slm; // Whether to use SLM for B.
988 bool use_prefetch; // Whether to use prefetch for A and B.
989 bool pad_slm; // Whether to pad SLM to avoid write conflicts.
990 bool assign_sbids; // Whether to manually assign SBID tokens.
991 int slm_bufs; // Number of SLM buffers to use.
992 int gmem_bufs; // Number of GRF buffers to use for GMEM -> SLM copy.
993 int prefetch_bufs; // Number of prefetch buffers for A and B.
994 bool do_loop_unroll; // Whether to fully unroll inner loops.
995 bool reduce_grf_usage; // Whether to try to reduce GRF usage based on heuristics.
996 bool allow_grf_reorder; // Whether to allow GRF reorders to FMA-friendly layouts.
997 bool zero_out_output; // Whether to zero out outputs before the main kernel.
998 bool do_atomic_update; // Whether to use atomics during C update.
999 bool reuse_headers; // Whether to reuse header messages to reduce GRF usage.
1000
1001 // Specific to BWD_W.
1002 bool do_post_bia_reorder; // Whether to perform an extra reorder for bias.
1003 bool do_post_wei_reorder; // Whether to perform an extra reorder for weights.
1004
1005 // Sub-tiles to split into for the inner A x B multiplication:
1006 // for i in range(0, a_sub_tiles):
1007 // A_i = load(...)
1008 // for j in range(0, b_sub_tiles):
1009 // B_j = load(...)
1010 // C_i_j += A_i * B_j
1011 //
1012 // GRF buffers for A_i and B_j are reused. Factors greater than one help to
1013 // reduce GRF usage.
1014 int a_sub_tiles;
1015 int b_sub_tiles;
1016
1017 private:
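// Chooses how many threads to slice the IC reduction across (K-slicing) for
// FWD. A hypothetical example, assuming mb_thr_blk == 1 and fewer threads
// than EUs: ic = 512 with ic_blk = 32 and a 3x3 kernel gives ic_blocks = 16
// and reduction_blocks = 144, so ic_thr_dim = max_divisor(16, {1, 2, 4, 8})
// = 8; the reduction is large enough (>= 32 blocks) that no further limiting
// is applied.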
1018 int init_fwd_ic_thr_dim(engine_t *engine, int mb_thr_blk, int oc_thr_blk,
1019 int ow_thr_blk, int ic_blk) const {
1020 if (mb_thr_blk > 1) return 1;
1021
1022 int ic_blocks = utils::div_up(ic, ic_blk);
1023 int reduction_blocks = ic_blocks * kd * kh * kw;
1024
1025 int oc_nthr = utils::div_up(oc, oc_thr_blk);
1026 int ow_nthr = utils::div_up(ow, ow_thr_blk);
1027 int mb_nthr = utils::div_up(mb, mb_thr_blk);
1028 int nthr = mb_nthr * oc_nthr * od * oh * ow_nthr;
1029
1030 auto *compute_engine
1031 = utils::downcast<const compute::compute_engine_t *>(engine);
1032 int eus = compute_engine->device_info()->eu_count();
1033
1034 int ret_ic_thr_dim = 1;
1035 if (!is_small_ic() && reduction_blocks >= 16 && (nthr < eus)) {
1036 ret_ic_thr_dim = ir_utils::max_divisor(ic_blocks, {1, 2, 4, 8});
1037
1038 // If reduction is too small, limit k-slicing.
1039 int reduction_threshold = 32;
1040 if (reduction_blocks < reduction_threshold) {
1041 int max_ic_thr_dim = utils::div_up(eus, nthr);
1042 max_ic_thr_dim = (1 << math::ilog2q(max_ic_thr_dim));
1043 ret_ic_thr_dim = std::min(ret_ic_thr_dim, max_ic_thr_dim);
1044 }
1045 }
1046 return ret_ic_thr_dim;
1047 }
1048
1049 status_t init_hw(engine_t *engine) {
1050 using namespace compute;
1051
1052 auto compute_engine
1053 = utils::downcast<compute::compute_engine_t *>(engine);
1054 auto device_info = compute_engine->device_info();
1055
1056 switch (device_info->gpu_arch()) {
1057 case gpu_arch_t::gen9: hw = ngen::HW::Gen9; break;
1058 case gpu_arch_t::xe_lp: hw = ngen::HW::XeLP; break;
1059 case gpu_arch_t::xe_hp: hw = ngen::HW::XeHP; break;
1060 case gpu_arch_t::xe_hpg: hw = ngen::HW::XeHPG; break;
1061 case gpu_arch_t::xe_hpc: hw = ngen::HW::XeHPC; break;
1062 default: return status::unimplemented;
1063 }
1064 return status::success;
1065 }
1066
1067 // Initializes A/B/C data types (GEMM notation: C += A * B) according to
1068 // the following convention:
1069 // FWD: src -> A, wei -> B, dst -> C
1070 // BWD_D: diff_dst -> A, wei -> B, diff_src -> C
1071 // BWD_W: src -> A, diff_dst -> B, diff_wei -> C
1072 status_t init_abc_data_types() {
1073 if (is_fwd) {
1074 a_data_type = src_data_type;
1075 b_data_type = wei_data_type;
1076 c_data_type = dst_data_type;
1077 } else if (is_bwd_d) {
1078 a_data_type = dst_data_type;
1079 b_data_type = wei_data_type;
1080 c_data_type = src_data_type;
1081 } else if (is_bwd_w) {
1082 a_data_type = src_data_type;
1083 b_data_type = dst_data_type;
1084 // Always use f32 for accumulation/storing in the main kernel.
1085 c_data_type = data_type::f32;
1086 } else {
1087 ir_error_not_expected();
1088 }
1089 return status::success;
1090 }
1091
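// Accumulator type selection sketch: s8/u8 x s8/u8 -> s32; f16 x f16 or
// bf16 x bf16 -> f32; f32 x f32 (with f32 C) -> f32; any other combination
// is reported as unimplemented.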
1092 status_t init_acc_data_type() {
1093 auto a = a_data_type;
1094 auto b = b_data_type;
1095 auto c = c_data_type;
1096 if (utils::one_of(a, data_type::s8, data_type::u8)
1097 && utils::one_of(b, data_type::s8, data_type::u8)) {
1098 acc_data_type = data_type::s32;
1099 return status::success;
1100 }
1101 if (utils::everyone_is(data_type::f16, a, b)
1102 || utils::everyone_is(data_type::bf16, a, b)) {
1103 acc_data_type = data_type::f32;
1104 return status::success;
1105 }
1106 if (utils::everyone_is(data_type::f32, a, b, c)) {
1107 acc_data_type = data_type::f32;
1108 return status::success;
1109 }
1110 return status::unimplemented;
1111 }
1112
1113 status_t init_fma_kind() {
1114 fma_kind = fma_kind::get_supported_kind(
1115 hw, a_data_type, b_data_type, acc_data_type);
1116
1117 simd_size = fma_kind::get_simd_size(
1118 hw, fma_kind, a_data_type, b_data_type, acc_data_type);
1119
1120 bool use_mad = false;
1121 if (is_small_ic() && !is_dw) {
1122 if (is_fwd && (kw != 7 || mb % 8 != 0))
1123 use_mad = true;
1124 else if (is_bwd_d)
1125 use_mad = true;
1126 } else if (is_dw) {
1127 use_mad = true;
1128 }
1129
1130 if (use_mad) {
1131 fma_kind = fma_kind_t::mad;
1132 simd_size = fma_kind::get_simd_size(
1133 hw, fma_kind, a_data_type, b_data_type, acc_data_type);
1134 }
1135
1136 #ifdef GEN_CONV_DEBUG
1137 fma_kind = fma_kind::from_string(ir_utils::getenv_str(
1138 "fma_kind", fma_kind::to_string(fma_kind)));
1139 simd_size = fma_kind::get_simd_size(
1140 hw, fma_kind, a_data_type, b_data_type, acc_data_type);
1141
1142 #endif
1143 if (fma_kind == fma_kind_t::unknown) return status::unimplemented;
1144
1145 // Disable using mad instruction backend until performance parity is
1146 // reached with OpenCL kernels.
1147 if (fma_kind == fma_kind_t::mad) {
1148 if (hw < ngen::HW::XeHP) return status::unimplemented;
1149 if (is_bwd_d) {
1150 if (!is_f32_conv()) return status::unimplemented;
1151 if (is_small_ic()) return status::unimplemented;
1152 return status::success;
1153 }
1154 }
1155
1156 return status::success;
1157 }
1158
1159 status_t init_data_layouts(convolution_pd_t *conv_pd) {
1160 std::string src_tag;
1161 std::string wei_tag;
1162 std::string dst_tag;
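// The tags below follow the oneDNN memory format tag convention: letters
// name logical dims (a = dim 0, b = dim 1, ...), 'x' stands for the
// remaining (spatial) dims, and trailing "<N><letter>" pairs are inner
// blocks. For example, a hypothetical activations tag "ABx32a16b" is
// blocked by 32 over the minibatch and by 16 over the channels.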
1163
1164 const bool is_wei16aXb = hw >= ngen::HW::XeHPC;
1165 assert(hw != ngen::HW::Unknown);
1166 bool is_mb_block = mb >= 16;
1167
1168 // Src/Dst buffers should generally be the same format to avoid reorders
1169 // between FWD, BWD_D, and BWD_W.
1170 if (is_small_ic() && !is_dw) {
1171 src_tag = is_s32_accumulator() ? "ABx8a4b" : "ABx8a2b";
1172 } else if (fma_kind == fma_kind_t::mad) {
1173 if (is_s32_accumulator()) {
1174 src_tag = (!is_mb_block ? "aBx32b" : "ABx32a32b");
1175 } else {
1176 src_tag = (!is_mb_block ? "aBx16b" : "ABx32a16b");
1177 }
1178 if (is_fwd) {
1179 int max_simd_size = 16;
1180 if (simd_size > max_simd_size) simd_size = max_simd_size;
1181 }
1182 } else if (is_s32_accumulator()) {
1183 src_tag = (!is_mb_block ? "aBx32b" : "ABx32a32b");
1184 } else {
1185 src_tag = (!is_mb_block ? "aBx16b" : "ABx32a16b");
1186 }
1187
1188 if (fma_kind == fma_kind_t::mad) {
1189 if (is_dw) {
1190 if (is_int8_dst()) {
1191 dst_tag = (!is_mb_block ? "aBx32b" : "ABx32a32b");
1192 } else {
1193 dst_tag = (!is_mb_block ? "aBx16b" : "ABx32a16b");
1194 }
1195 } else {
1196 dst_tag = (!is_mb_block ? "aBx16b" : "ABx32a16b");
1197 }
1198 if (is_bwd_d) {
1199 int max_simd_size = 16;
1200 if (simd_size > max_simd_size) simd_size = max_simd_size;
1201 }
1202 } else if (is_int8_dst()) {
1203 dst_tag = (!is_mb_block ? "aBx32b" : "ABx32a32b");
1204 } else {
1205 dst_tag = (!is_mb_block ? "aBx16b" : "ABx32a16b");
1206 }
1207
1208 // Weight reorders are generally small, so reordering weights between
1209 // FWD and BWD_D/BWD_W implementations for optimization purposes makes
1210 // sense.
1211 if (is_fwd) {
1212 if (is_small_ic() && !is_dw) {
1213 if (fma_kind == fma_kind_t::mad)
1214 wei_tag = "bAx16a";
1215 else if (is_s32_accumulator())
1216 wei_tag = is_wei16aXb ? "ABx16a4b" : "ABx8a4b";
1217 else
1218 wei_tag = is_wei16aXb ? "ABx16a2b" : "ABx8a2b";
1219 } else {
1220 if (is_dw) {
1221 if (is_s32_accumulator())
1222 wei_tag = "Abcx32a";
1223 else
1224 wei_tag = "Abcx16a";
1225 } else if (fma_kind == fma_kind_t::mad) {
1226 wei_tag = "BAx16b16a";
1227 } else if (is_s32_accumulator()) {
1228 wei_tag = is_wei16aXb ? "ABx2a8b16a4b" : "ABx4a8b8a4b";
1229 } else {
1230 wei_tag = is_wei16aXb ? "ABx2a8b16a2b" : "ABx4a8b8a2b";
1231 }
1232 }
1233 } else if (is_bwd_d) {
1234 if (fma_kind == fma_kind_t::mad)
1235 wei_tag = "ABx16a16b";
1236 else if (is_s32_accumulator())
1237 wei_tag = is_wei16aXb ? "BAx2b8a16b4a" : "BAx4b8a8b4a";
1238 else
1239 wei_tag = is_wei16aXb ? "BAx2b8a16b2a" : "BAx4b8a8b2a";
1240 } else if (is_bwd_w) {
1241 if (is_small_ic()) {
1242 wei_tag = "Axb16a";
1243 } else {
1244 wei_tag = "ABx16b16a";
1245 }
1246 }
1247
1248 if (with_groups && !is_dw) wei_tag = prepend_groups_to_tag(wei_tag);
1249
1250 #ifdef GEN_CONV_DEBUG
1251 src_tag = ir_utils::getenv_str("stag", src_tag);
1252 wei_tag = ir_utils::getenv_str("wtag", wei_tag);
1253 dst_tag = ir_utils::getenv_str("dtag", dst_tag);
1254 #endif
1255
1256 auto &src_md = *conv_pd->invariant_src_md();
1257 auto &wei_md = *conv_pd->invariant_wei_md();
1258 auto &dst_md = *conv_pd->invariant_dst_md();
1259 auto &bia_md = *conv_pd->invariant_bia_md();
1260
1261 // Select layouts.
1262 src_layout = init_layout(src_md, src_tag);
1263 wei_layout = init_layout(wei_md, wei_tag);
1264 dst_layout = init_layout(dst_md, dst_tag);
1265 if (with_bias) bia_layout = init_layout(bia_md, "a");
1266
1267 // Validate layouts.
1268 bool is_src_nhwc = false;
1269 bool is_dst_nhwc = false;
1270
1271 if (is_fwd || is_bwd_d) {
1272 is_src_nhwc = (orig_src_mdw().is_plain()
1273 && src_layout == layout_t(src_md, "axb"));
1274 is_dst_nhwc = (orig_dst_mdw().is_plain()
1275 && dst_layout == layout_t(dst_md, "axb"));
1276 if (is_src_nhwc != is_dst_nhwc) return status::unimplemented;
1277
1278 // HWord loads require 32 byte alignment. For NHWC layout it means
1279 // input/output channels must be multiples of 32 bytes.
1280 size_t ic_bytes = ic * types::data_type_size(src_data_type);
1281 size_t oc_bytes = oc * types::data_type_size(dst_data_type);
1282 if (is_dst_nhwc && (ic_bytes % 32 != 0 || oc_bytes % 32 != 0))
1283 return status::unimplemented;
1284 }
1285 if (!is_src_nhwc
1286 && !src_layout.is_strictly_equal(make_layout(src_md, src_tag)))
1287 return status::unimplemented;
1288 if (!is_dst_nhwc
1289 && !dst_layout.is_strictly_equal(make_layout(dst_md, dst_tag)))
1290 return status::unimplemented;
1291 if (!wei_layout.is_strictly_equal(make_layout(wei_md, wei_tag)))
1292 return status::unimplemented;
1293 return status::success;
1294 }
1295
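// A hypothetical example: tg_grid_dim = {4, 4, 1} enables SLM for both A
// and B; since 4 * 4 = 16 > 8, pref_slm_bufs = 3, and with loop unrolling
// and dpas this selects triple SLM buffering (slm_bufs = 3) with two GRF
// buffers for the GMEM -> SLM copy (gmem_bufs = 2).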
1296 void enable_slm_buffering() {
1297 using namespace ir_utils;
1298
1299 use_a_slm = (tg_grid_dim[0] > 1);
1300 use_b_slm = (tg_grid_dim[1] > 1);
1301 if (use_a_slm || use_b_slm) {
1302 int pref_slm_bufs = (tg_grid_dim[0] * tg_grid_dim[1] <= 8 ? 2 : 3);
1303 if (do_loop_unroll) {
1304 slm_bufs = pref_slm_bufs;
1305 gmem_bufs = (is_dpas_fma() ? 2 : 1);
1306 } else {
1307 // Double/triple SLM buffering is not supported when only one
1308 // matrix is SLM-buffered.
1309 slm_bufs = (use_a_slm == use_b_slm ? pref_slm_bufs : 1);
1310 gmem_bufs = 1;
1311 }
1312 } else {
1313 slm_bufs = 0;
1314 gmem_bufs = 0;
1315 }
1316 #ifdef GEN_CONV_DEBUG
1317 use_a_slm = getenv_bool("use_a_slm", use_a_slm);
1318 use_b_slm = getenv_bool("use_b_slm", use_b_slm);
1319 slm_bufs = getenv_int("slm_bufs", slm_bufs);
1320 gmem_bufs = getenv_int("gmem_bufs", gmem_bufs);
1321 #endif
1322 }
1323
1324 void enable_prefetch() {
1325 using namespace ir_utils;
1326
1327 use_prefetch = true;
1328 prefetch_bufs = is_bwd_w ? 2 : 3;
1329 #ifdef GEN_CONV_DEBUG
1330 use_prefetch = getenv_bool("use_prefetch", use_prefetch);
1331 prefetch_bufs = getenv_int("prefetch_bufs", prefetch_bufs);
1332 #endif
1333 }
1334
1335 void disable_slm_buffering() {
1336 use_a_slm = false;
1337 use_b_slm = false;
1338 slm_bufs = 0;
1339 gmem_bufs = 0;
1340 }
1341
1342 void disable_prefetch() {
1343 use_prefetch = false;
1344 prefetch_bufs = 0;
1345 }
1346
1347 // Overwrites parameters that are implied by other parameters.
1348 void fixup_inference_consistency() {
1349 // Can't reuse headers with loop unroll and post-increment offset updates.
1350 if (reuse_headers) do_loop_unroll = false;
1351
1352 bool prefer_prefetch = false;
1353 if (hw >= ngen::HW::XeHPC) prefer_prefetch = true;
1354
1355 if (use_preload) {
1356 // Prefetches are only supported with loop unrolling.
1357 if (prefer_prefetch && do_loop_unroll) {
1358 enable_prefetch();
1359 } else {
1360 enable_slm_buffering();
1361 }
1362 }
1363 // Downgrade dpasw -> dpas for some cases.
1364 if (fma_kind == fma_kind_t::dpasw) {
1365 // dpasw is executed by fused EUs (across X thread group
1366 // dimension). Do not use dpasw if X is uneven.
1367 if (tg_grid_dim[0] % 2 != 0) fma_kind = fma_kind_t::dpas;
1368 // dpasw can't be generated in the case of a direct load from GMEM combined with a reorder.
1369 if (is_bwd_w && allow_grf_reorder && (!use_a_slm || !use_b_slm))
1370 fma_kind = fma_kind_t::dpas;
1371 }
1372 }
1373
1374 bool try_reduce_grf_usage() {
1375 if (!reduce_grf_usage) return true;
1376
1377 // TODO: Improve the register count estimate; it fails to account for tmp
1378 // values like mask_registers, among other things.
1379 double reg_factor = is_bwd_w ? 0.875 : 0.95;
1380 int max_regs = int(regs * reg_factor);
1381 int regs = estimate_register_count();
1382 if (regs <= max_regs) return true;
1383
1384 // Try to disable GRF buffering.
1385 if (gmem_bufs > 1) {
1386 gmem_bufs = 1;
1387 int regs = estimate_register_count();
1388 if (regs <= max_regs) return true;
1389 }
1390
1391 // Try to use sub-tiles for B.
1392 int n_thr_blk = utils::div_up(n_tg_blk, tg_grid_dim[0]);
1393 int max_b_sub_tiles
1394 = std::min((use_b_slm ? 4 : 2), n_thr_blk / simd_size);
1395 // XXX: avoid layout mismatch for B loads
1396 if (hw >= ngen::HW::XeHPC && is_bwd_w) max_b_sub_tiles = 2;
1397 while (b_sub_tiles < max_b_sub_tiles) {
1398 b_sub_tiles *= 2;
1399 int regs = estimate_register_count();
1400 if (regs <= max_regs) return true;
1401 }
1402
1403 // Try to use double SLM buffering.
1404 if (slm_bufs == 3) {
1405 slm_bufs = 2;
1406 int regs = estimate_register_count();
1407 if (regs <= max_regs) return true;
1408 }
1409
1410 // Try to use single SLM buffering.
1411 if (slm_bufs == 2) {
1412 slm_bufs = 1;
1413 int regs = estimate_register_count();
1414 if (regs <= max_regs) return true;
1415 }
1416
1417 // Last resort settings to reduce GRF usage.
1418 reuse_headers = true;
1419 do_loop_unroll = false;
1420
1421 return estimate_register_count() <= max_regs;
1422 }
1423
1424 int estimate_register_count() const {
1425 int reg_bytes = ngen::GRF::bytes(hw);
1426
1427 // Assume 8 HWords per GMEM load for double-blocked layouts and 1 HWord
1428 // otherwise.
1429 int hword_bytes = 32;
1430 int a_gmem_msg_bytes
1431 = (a_layout().is_n_blocked(2) ? 8 : 1) * hword_bytes;
1432 int b_gmem_msg_bytes
1433 = (b_layout().is_n_blocked(2) ? 8 : 1) * hword_bytes;
1434
1435 // Assume 8 HWords per SLM load/store.
1436 int slm_msg_bytes = 256;
1437
1438 int nthr = tg_grid_dim[0] * tg_grid_dim[1];
1439 int m_thr_blk = utils::div_up(m_tg_blk, tg_grid_dim[1]);
1440 int n_thr_blk = utils::div_up(n_tg_blk, tg_grid_dim[0]);
1441 int k_thr_blk = k_blk;
1442
1443 int a_size = int(types::data_type_size(a_data_type));
1444 int b_size = int(types::data_type_size(b_data_type));
1445 int acc_size = int(types::data_type_size(acc_data_type));
1446
1447 // Registers for C += A * B operation.
1448 int a_tile_bytes = m_thr_blk * k_thr_blk * a_size;
1449 int b_tile_bytes = k_thr_blk * n_thr_blk * b_size;
1450 int a_bytes = utils::div_up(a_tile_bytes, a_sub_tiles);
1451 int b_bytes = utils::div_up(b_tile_bytes, b_sub_tiles);
1452 int acc_bytes = m_thr_blk * n_thr_blk * acc_size;
1453
1454 int a_regs = utils::div_up(a_bytes, reg_bytes);
1455 int b_regs = utils::div_up(b_bytes, reg_bytes);
1456 int acc_regs = utils::div_up(acc_bytes, reg_bytes);
1457
1458 int a_headers = utils::div_up(
1459 a_tile_bytes, use_a_slm ? slm_msg_bytes : a_gmem_msg_bytes);
1460 int b_headers = utils::div_up(
1461 b_tile_bytes, use_b_slm ? slm_msg_bytes : b_gmem_msg_bytes);
1462
1463 if (fma_kind == fma_kind_t::dpasw) {
1464 // dpasw reuses registers between fused threads across tg0. M is
1465 // split across tg1, N is split across tg0, so dpasw allows sharing
1466 // matrix A, which is (M x K).
1467 a_regs = utils::div_up(a_regs, 2);
1468 a_headers = utils::div_up(a_headers, 2);
1469 }
1470
1471 // Size of A/B thread blocks when split full A/B TG blocks across all
1472 // threads in TG.
1473 int a_tg_per_thr_bytes = utils::div_up(m_tg_blk * k_blk * a_size, nthr);
1474 int b_tg_per_thr_bytes = utils::div_up(k_blk * n_tg_blk * b_size, nthr);
1475
1476 // Temporary registers for GMEM -> SLM load.
1477 int a_g2s_bytes = (use_a_slm ? a_tg_per_thr_bytes : 0);
1478 int b_g2s_bytes = (use_b_slm ? b_tg_per_thr_bytes : 0);
1479
1480 // Account for dedicated headers for prefetches.
1481 if (use_prefetch) {
1482 a_headers += utils::div_up(a_tg_per_thr_bytes, a_gmem_msg_bytes);
1483 b_headers += utils::div_up(b_tg_per_thr_bytes, b_gmem_msg_bytes);
1484 }
1485
1486 int a_g2s_regs = utils::div_up(a_g2s_bytes, reg_bytes);
1487 int b_g2s_regs = utils::div_up(b_g2s_bytes, reg_bytes);
1488
1489 // Two sets of headers for GMEM -> GRF and GRF -> SLM.
1490 int a_g2s_headers = utils::div_up(a_g2s_bytes, a_gmem_msg_bytes)
1491 + utils::div_up(a_g2s_bytes, slm_msg_bytes);
1492 int b_g2s_headers = utils::div_up(b_g2s_bytes, b_gmem_msg_bytes)
1493 + utils::div_up(b_g2s_bytes, slm_msg_bytes);
1494
1495 // Extra registers for GRF <-> GRF reorders.
1496 int reorder_regs = 0;
1497
1498 // Assume A/B need reorders to temporary buffers.
1499 if (allow_grf_reorder) {
1500 if (is_bwd_w) {
1501 // Hardcoded for now; this is the upper bound for the temporary
1502 // buffer size for BWD_W.
1503 int bwd_w_reorder_regs = 16;
1504 reorder_regs += bwd_w_reorder_regs;
1505 }
1506
1507 int ab_reorder_regs = 0;
1508
1509 if (use_a_slm) {
1510 ab_reorder_regs = std::max(ab_reorder_regs, a_g2s_regs);
1511 } else {
1512 int a_reorder_regs = a_regs;
1513 // Loads must be aligned to a GRF boundary; account for cases
1514 // when the load size is less than the register size.
1515 if (a_gmem_msg_bytes < reg_bytes) {
1516 a_reorder_regs
1517 *= utils::div_up(reg_bytes, a_gmem_msg_bytes);
1518 }
1519 ab_reorder_regs = std::max(ab_reorder_regs, a_reorder_regs);
1520 }
1521 if (use_b_slm) {
1522 ab_reorder_regs = std::max(ab_reorder_regs, b_g2s_regs);
1523 } else {
1524 int b_reorder_regs = b_regs;
1525 // Loads must be aligned to a GRF boundary; account for cases
1526 // when the load size is less than the register size.
1527 if (b_gmem_msg_bytes < reg_bytes) {
1528 b_reorder_regs
1529 *= utils::div_up(reg_bytes, b_gmem_msg_bytes);
1530 }
1531 ab_reorder_regs = std::max(ab_reorder_regs, b_reorder_regs);
1532 }
1533 reorder_regs += ab_reorder_regs;
1534 }
1535
1536 int g2s_regs = gmem_bufs * (a_g2s_regs + b_g2s_regs);
1537 int g2s_headers = a_g2s_headers + b_g2s_headers;
1538
1539 int data_regs = a_regs + b_regs + acc_regs + g2s_regs;
1540 int header_regs = a_headers + b_headers + g2s_headers;
1541 if (reuse_headers) header_regs = 1;
1542
1543 int estimated_regs = data_regs + reorder_regs + header_regs;
1544
1545 return estimated_regs;
1546 }
1547
1548 const layout_t &a_layout() const {
1549 const layout_t *ret = nullptr;
1550 if (is_fwd) {
1551 ret = &src_layout;
1552 } else if (is_bwd_d) {
1553 ret = &dst_layout;
1554 } else if (is_bwd_w) {
1555 ret = &src_layout;
1556 }
1557 ir_assert(ret && !ret->is_empty()) << "Layout is not initialized.";
1558 return *ret;
1559 }
1560
1561 const layout_t &b_layout() const {
1562 const layout_t *ret = nullptr;
1563 if (is_fwd) {
1564 ret = &wei_layout;
1565 } else if (is_bwd_d) {
1566 ret = &wei_layout;
1567 } else if (is_bwd_w) {
1568 ret = &dst_layout;
1569 }
1570 ir_assert(ret && !ret->is_empty()) << "Layout is not initialized.";
1571 return *ret;
1572 }
1573
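// Example: a hypothetical weights tag "ABx8a2b" becomes "aBCx8b2c": every
// dim letter is shifted by one and a leading groups dim 'a' is prepended.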
1574 static std::string prepend_groups_to_tag(const std::string &tag) {
1575 auto ret = tag;
1576 for (auto &c : ret) {
1577 bool is_lower_dim = ('a' <= c && c < 'a' + DNNL_MAX_NDIMS);
1578 bool is_upper_dim = ('A' <= c && c < 'A' + DNNL_MAX_NDIMS);
1579 if (!is_lower_dim && !is_upper_dim) continue;
1580 c += 1;
1581 }
1582 return "a" + ret;
1583 }
1584
1585 static layout_t init_layout(
1586 memory_desc_t &user_md, const std::string &optimal_tag) {
1587 auto optimal = make_layout(user_md, optimal_tag);
1588 if (user_md.format_kind != format_kind::any) {
1589 auto user = make_layout(user_md);
1590 // If the layouts are physically different, return the layout passed by
1591 // the user; status::unimplemented will be returned later.
1592 if (user != optimal) return user;
1593 } else {
1594 user_md = optimal.to_dnnl(user_md.dims);
1595 }
1596 return optimal;
1597 }
1598
1599 static layout_t make_layout(const memory_desc_t &md) {
1600 return layout_t(md, /*do_normalize=*/false);
1601 }
1602
1603 static layout_t make_layout(
1604 const memory_desc_t &md, const std::string &tag) {
1605 return layout_t(md, tag, /*do_normalize=*/false);
1606 }
1607 };
1608
1609 inline std::ostream &operator<<(std::ostream &out, const conv_config_t &cfg) {
1610 out << cfg.str();
1611 return out;
1612 }
1613
1614 } // namespace jit
1615 } // namespace gpu
1616 } // namespace impl
1617 } // namespace dnnl
1618
1619 #endif
1620