/*******************************************************************************
 * Copyright 2019-2021 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/

#include "gpu/jit/gemm/xe_hp_systolic_gemm.hpp"

#include "common/c_types_map.hpp"
#include "common/dnnl_traits.hpp"
#include "common/float16.hpp"
#include "common/type_helpers.hpp"
#include "gpu/jit/gemm/gemm_walk_orders.hpp"
#include "gpu/jit/ngen_type_bridge.hpp"
#include "gpu/ocl/gemm/xe_systolic_gemm_copy_kernel.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

status_t xe_hp_systolic_gemm_t::pd_t::init(engine_t *engine) {
    using namespace prop_kind;
    using namespace data_type;
    using namespace primitive_kind;
    using smask_t = primitive_attr_t::skip_mask_t;
    using arch_t = compute::gpu_arch_t;

    assert(engine->kind() == engine_kind::gpu);
    auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);

    if (!compute_engine->mayiuse_ngen_kernels()) return status::unimplemented;
    if (!compute_engine->mayiuse_large_grf_mode()) return status::unimplemented;

    dev_info_ = compute_engine->device_info();
    auto arch = dev_info_->gpu_arch();

    const auto &d = desc();

    bool dt_float_ok = (d->a_type() == d->b_type()
            && utils::one_of(d->a_type(), bf16, f16)
            && utils::one_of(d->c_type(), f32, d->a_type()));

    bool dt_int_ok = (utils::one_of(d->a_type(), u8, s8)
            && utils::one_of(d->b_type(), u8, s8) && (d->c_type() == s32));

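    // Integer GEMM: compile-time A/B zero points are folded into the
    // a_zp_/b_zp_ flags; runtime A/B zero points are not supported. C zero
    // points (including runtime ones) are tracked separately via c_zp_.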
    if (dt_int_ok) {
        if (attr()->zero_points_.defined(DNNL_ARG_SRC)) {
            const int *ao_i32 = nullptr;
            attr()->zero_points_.get(DNNL_ARG_SRC, nullptr, nullptr, &ao_i32);
            a_zp_ = (*ao_i32 != 0);
        } else if (!attr()->zero_points_.has_default_values(DNNL_ARG_SRC))
            return status::unimplemented;

        if (attr()->zero_points_.defined(DNNL_ARG_WEIGHTS)) {
            const int *bo_i32 = nullptr;
            attr()->zero_points_.get(
                    DNNL_ARG_WEIGHTS, nullptr, nullptr, &bo_i32);
            b_zp_ = (*bo_i32 != 0);
        } else if (!attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS))
            return status::unimplemented;

        c_zp_ = !attr()->zero_points_.has_default_values(DNNL_ARG_DST);
    }

    bool ok = set_default_formats(d->a_type());
    if (!ok) return status::unimplemented;

    CHECK(attr_.set_default_formats(dst_md(0)));

    if (use_fma()) return status::unimplemented;

    // LIMITATIONS:
    // - batch is not supported for unpacked inputs.
    // - runtime dims are not supported.
    bool limits_ok
            = !utils::one_of(DNNL_RUNTIME_DIM_VAL, d->m(), d->n(), d->k());
    if (!packed_a())
        limits_ok = limits_ok && (d->lda() != DNNL_RUNTIME_DIM_VAL)
                && (d->batch() == 1);
    if (!packed_b())
        limits_ok = limits_ok && (d->ldb() != DNNL_RUNTIME_DIM_VAL)
                && (d->batch() == 1);
    if (!packed_c())
        limits_ok = limits_ok && (d->ldc() != DNNL_RUNTIME_DIM_VAL);

    auto attr_skip_mask = smask_t::oscale | smask_t::post_ops;

    if (dt_int_ok) attr_skip_mask |= smask_t::zero_points_runtime;

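    // This systolic implementation supports the Xe-HP, Xe-HPG, and Xe-HPC
    // architectures.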
    bool arch_ok = (arch == arch_t::xe_hp);
    arch_ok |= (arch == arch_t::xe_hpg);
    arch_ok |= (arch == arch_t::xe_hpc);

    ok = true && limits_ok && (dt_float_ok || dt_int_ok) && arch_ok
            && compute_engine->mayiuse(compute::device_ext_t::
                            intel_subgroup_split_matrix_multiply_accumulate)
            && attr()->has_default_values(attr_skip_mask)
            && attr()->output_scales_.mask_ == 0 && attr()->post_ops_.len() <= 2
            && IMPLICATION(with_bias(),
                    dt_float_ok
                            && utils::one_of(d->bias_type(), d->a_type(), f32)
                            && utils::one_of(bias_cmask(), 0, 1 << 0, 1 << 1));

    // Check that a sum post-op, if present, appears only once and in the first
    // position.
    ok &= IMPLICATION(attr()->post_ops_.find(sum) != -1,
            attr()->post_ops_.find(sum) == 0
                    && attr()->post_ops_.find(sum, 1) == -1);

    // Check that the post-ops are supported by the injector.
    ok &= IMPLICATION(attr()->post_ops_.len() > 0,
            jit_post_op_injector_is_supported(attr()->post_ops_, true));

    if (dt_int_ok) {
        ok &= IMPLICATION(a_zp_, !packed_b()) && IMPLICATION(b_zp_, !packed_a())
                && IMPLICATION(
                        c_zp_, !attr()->zero_points_.defined(DNNL_ARG_DST));

        int cmask_a = 0, cmask_b = 0, cmask_c = 0;
        attr()->zero_points_.get(DNNL_ARG_WEIGHTS, nullptr, &cmask_b, nullptr);
        attr()->zero_points_.get(DNNL_ARG_SRC, nullptr, &cmask_a, nullptr);
        attr()->zero_points_.get(DNNL_ARG_DST, nullptr, &cmask_c, nullptr);
        ok &= (cmask_a == 0) && (cmask_b == 0)
                && utils::one_of(cmask_c, 0, 1 << 0, 1 << 1);
    }

    if (!ok) return status::unimplemented;

    return status::success;
}

namespace {
struct nocopy_table_t {
    int mn_limit[2][2]; // Use no-copy if m*n < mn_limit * mn_limit
    int k_limit[2][2]; // and k < k_limit.
};

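// Size thresholds, indexed by [transa][transb], below which the copy-free
// (FMA) implementation is preferred over the systolic one.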
const nocopy_table_t xe_hp_f16_nocopy_table[] = {
        // NN     NT     TN    TT
        {{{1280, 768}, {512, 384}}, {{512, 768}, {1024, 512}}}};

const nocopy_table_t xe_hp_bf16_nocopy_table[] = {
        // NN   NT     TN   TT
        {{{512, 256}, {512, 512}}, {{512, 256}, {384, 384}}}};

const nocopy_table_t xe_hp_x8x8s32_nocopy_table[] = {
        // NN   NT     TN   TT
        {{{384, 384}, {384, 384}}, {{384, 512}, {384, 256}}}};
} // namespace

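// Decide whether to fall back to the copy-free FMA GEMM implementation
// instead of the systolic kernels.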
bool xe_hp_systolic_gemm_t::pd_t::use_fma() {
    using namespace data_type;

    const auto &d = desc();

    if (any_prepacked_) return false;

    // Use FMA implementation if one matrix is very small.
    if (d->m() < 32 && d->n() < 32) return true;
    if (d->m() < 32 && d->k() < 32) return true;
    if (d->n() < 32 && d->k() < 32) return true;

    // Use FMA for small/medium sizes.
    if (utils::one_of(d->c_type(), bf16, f16, s32)) {
        const nocopy_table_t *all_tables[3] = {xe_hp_f16_nocopy_table,
                xe_hp_bf16_nocopy_table, xe_hp_x8x8s32_nocopy_table};
        const int type_idx
                = (d->c_type() == f16) ? 0 : (d->c_type() == bf16) ? 1 : 2;
        const nocopy_table_t *table = all_tables[type_idx];
        const long mnl = table->mn_limit[d->transa()][d->transb()];
        const long kl = table->k_limit[d->transa()][d->transb()];

        if ((d->m() * d->n() < mnl * mnl) && (d->k() < kl)) return true;
    }

    return false;
}

bool xe_hp_systolic_gemm_t::pd_t::set_default_formats(data_type_t dt) {
    using namespace format_tag;
    using new_kernel_t = gen_gemm_xe_systolic_kernel_t;

    auto sz = types::data_type_size(dt);
    const auto &d = desc();
    auto arch = dev_info_->gpu_arch();

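    // Note the swap: desc_.b_desc is wrapped as this implementation's A matrix
    // and desc_.a_desc as its B matrix (the kernels operate on the transposed
    // problem; execute() performs the same swap on the runtime arguments).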
    memory_desc_wrapper a_mdw(&desc_.b_desc);
    memory_desc_wrapper b_mdw(&desc_.a_desc);
    memory_desc_wrapper c_mdw(&desc_.c_desc);

    bool a_any = a_mdw.format_any();
    bool b_any = b_mdw.format_any();
    bool c_any = c_mdw.format_any();
    bool batch = d->is_batched();

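    // Blocked ("packed") layouts consumed by the systolic kernels; the exact
    // blocking depends on the element size and, for A, on the architecture.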
    format_tag_t a_packed_tag = batch ? ((sz == 2) ? aCB4c8b8c2b : aCB4c8b8c4b)
                                      : ((sz == 2) ? BA4b8a8b2a : BA4b8a8b4a);
    format_tag_t b_packed_tag_48 = batch ? ((sz == 2) ? aBC48b16c : aBC48b32c)
                                         : ((sz == 2) ? AB48a16b : AB48a32b);
    format_tag_t b_packed_tag_32 = batch ? ((sz == 2) ? aBC32b16c : aBC32b32c)
                                         : ((sz == 2) ? AB32a16b : AB32a32b);
    format_tag_t unpacked_tag = batch ? abc : ab;

    if (arch == compute::gpu_arch_t::xe_hpc) {
        a_packed_tag = batch ? ((sz == 2) ? aCB4c8b16c2b : aCB4c8b16c4b)
                             : ((sz == 2) ? BA4b8a16b2a : BA4b8a16b4a);
    }

    bool a_prepacked = a_mdw.matches_tag(a_packed_tag);
    bool bc_prepacked_32 = b_mdw.matches_tag(b_packed_tag_32)
            || c_mdw.matches_tag(b_packed_tag_32);
    bool bc_prepacked_48 = b_mdw.matches_tag(b_packed_tag_48)
            || c_mdw.matches_tag(b_packed_tag_48);
    bool c_prepacked = c_mdw.matches_tag(b_packed_tag_32)
            || c_mdw.matches_tag(b_packed_tag_48);

    any_prepacked_ = a_prepacked || bc_prepacked_32 || bc_prepacked_48;

    unroll_m_ = 32;
    unroll_n_ = 0;
    kernel_tag_ = 0;
    if (bc_prepacked_32)
        unroll_n_ = 32;
    else if (bc_prepacked_48)
        unroll_n_ = 48;

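    // The newer kernel generator is used unless C is pre-packed, A/B zero
    // points are needed, or k is small; Xe-HPC and newer always use it.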
    use_new_kernels_ = !c_prepacked && !with_ab_zero_points() && (d->k() >= 64);
    use_new_kernels_ |= (arch >= compute::gpu_arch_t::xe_hpc);

    new_kernel_t::choose_unrolls(arch, dev_info_->eu_count(), d->a_type(),
            d->b_type(), d->c_type(), d->m(), d->n(), d->k(), d->batch(),
            unroll_m_, unroll_n_, kernel_tag_);

    format_tag_t b_packed_tag
            = (unroll_n_ == 48) ? b_packed_tag_48 : b_packed_tag_32;
    format_tag_t c_packed_tag = use_new_kernels_ ? unpacked_tag : b_packed_tag;

    packed_a_ = packed_b_ = packed_c_ = false;

    if (a_any) {
        CHECK(memory_desc_init_by_tag(
                desc_.b_desc, b_zp_ ? unpacked_tag : a_packed_tag));
        auto &ld = desc_.b_desc.padded_dims[batch ? 1 : 0];
        ld = nice_ld(ld, int(sz));
        desc_.b_desc.format_desc.blocking.strides[batch ? 2 : 1]
                = unroll_m_ * ld;
        packed_a_ = true;
    } else if (a_mdw.matches_one_of_tag(a_packed_tag, ab, ba, abc, acb)
            == undef)
        return false;

    if (b_any) {
        CHECK(memory_desc_init_by_tag(
                desc_.a_desc, a_zp_ ? unpacked_tag : b_packed_tag));
        auto &ld = desc_.a_desc.padded_dims[batch ? 2 : 1];
        ld = nice_ld(ld, int(sz));
        desc_.a_desc.format_desc.blocking.strides[batch ? 1 : 0]
                = unroll_n_ * ld;
        packed_b_ = true;
    } else if (b_mdw.matches_one_of_tag(b_packed_tag, ab, ba, abc, acb)
            == undef)
        return false;

    if (c_any)
        CHECK(memory_desc_init_by_tag(desc_.c_desc, c_packed_tag));
    else if (c_mdw.matches_one_of_tag(c_packed_tag, ab, abc) == undef)
        return false;

    packed_a_ = packed_a_ || a_mdw.matches_tag(a_packed_tag);
    packed_b_ = packed_b_ || b_mdw.matches_tag(b_packed_tag);
    packed_c_ = c_mdw.matches_tag(b_packed_tag);

    return gpu_gemm_pd_t::set_default_formats();
}

status_t xe_hp_systolic_gemm_t::init(engine_t *engine) {
    arch_ = pd()->dev_info_->gpu_arch();
    eu_count_ = pd()->dev_info_->eu_count();

    auto a_type = pd()->desc()->a_type();
    auto b_type = pd()->desc()->b_type();

    int cmask = -1;

    if (pd()->with_c_zero_points())
        pd()->attr()->zero_points_.get(DNNL_ARG_DST, nullptr, &cmask, nullptr);
    else if (pd()->with_bias())
        cmask = pd()->bias_cmask();

    switch (cmask) {
        case 0: co_kind_ = 'F'; break;
        case (1 << 1): co_kind_ = 'R'; break;
        case (1 << 0): co_kind_ = 'C'; break;
        case -1:
        default: co_kind_ = 'N'; break;
    }

    if (get_verbose() >= 2) {
        char tag_s[2] = {pd()->kernel_tag(), 0};
        printf("onednn_verbose,info,gpu,gemm,kernel:%dx%d,%s,new:%c\n",
                pd()->unroll_m(), pd()->unroll_n(), tag_s,
                pd()->use_new_kernels() ? 'Y' : 'N');
    }

    // Initialize compute kernels (assembly)
    {
        auto status = pd()->use_new_kernels() ? init_compute_new(engine)
                                              : init_compute_old(engine);
        if (status != status::success) return status;
    }

    // Initialize copy kernels (OpenCL)
    for (bool copy_b : {false, true}) {
        for (bool clear_sum : {false, true}) {
            if (clear_sum && !pd()->with_ab_zero_points()) continue;
            if (!copy_b ? pd()->packed_a() : pd()->packed_b()) continue;

            using copy_kernel_t = ocl::xe_systolic_gemm_copy_kernel_t;
            compute::kernel_ctx_t kernel_ctx;

            auto trans
                    = !copy_b ? pd()->desc()->transa() : pd()->desc()->transb();
            auto status = copy_kernel_t::init_kernel_ctx(kernel_ctx, arch_,
                    !copy_b ? a_type : b_type, pd()->unroll_n(), copy_b, trans,
                    pd()->with_ab_zero_points(), clear_sum);
            if (status != status::success) return status;

            create_kernel(engine, &copy_kernel_[copy_b][clear_sum],
                    copy_kernel_t::name(arch_), kernel_ctx);
            if (!copy_kernel_[copy_b][clear_sum]) return status::runtime_error;
        }
    }

    return status::success;
}

status_t xe_hp_systolic_gemm_t::init_compute_old(engine_t *engine) {
    using kernel_t = xehp_systolic_gemm_kernel_t<gpu_xe_hp>;
    using arch_t = compute::gpu_arch_t;

    kernel_t::config_t cfg;

    auto a_type = pd()->desc()->a_type();
    auto b_type = pd()->desc()->b_type();
    auto c_type = pd()->desc()->c_type();
    auto acc_type = pd()->impl_acc_type();

    cfg.a_type = convert_dnnl_type_to_ngen(a_type);
    cfg.b_type = convert_dnnl_type_to_ngen(b_type);
    cfg.c_type = convert_dnnl_type_to_ngen(c_type);
    cfg.acc_type = convert_dnnl_type_to_ngen(acc_type);
    cfg.alpha1 = (pd()->alpha() == 1.0f);
    cfg.beta0 = (pd()->beta() == 0.0f);
    cfg.beta1 = (pd()->beta() == 1.0f);
    cfg.post_ops = pd()->attr()->post_ops_;
    cfg.a_bias = cfg.b_bias = pd()->with_ab_zero_points();
    cfg.c_packed = pd()->packed_c();
    cfg.batch = pd()->with_batch();
    walk_n_first_ = cfg.walk_n_first
            = (pd()->desc()->m() >= 2 * pd()->desc()->n());
    cfg.tile_m = pd()->unroll_m();
    cfg.tile_n = pd()->unroll_n();
    cfg.global_3x_buf = (cfg.tile_n == 32);

    if (pd()->with_c_zero_points())
        cfg.co_type = cfg.c_type;
    else if (pd()->with_bias()) {
        cfg.early_c_bias = true;
        cfg.co_type = convert_dnnl_type_to_ngen(pd()->desc()->bias_type());
    }

    switch (co_kind_) {
        case 'F': cfg.c_bias = kernel_t::bias_t::fixed; break;
        case 'R': cfg.c_bias = kernel_t::bias_t::row; break;
        case 'C': cfg.c_bias = kernel_t::bias_t::column; break;
        case 'N':
        default: cfg.c_bias = kernel_t::bias_t::none; break;
    }

    bool may_k_block = (pd()->desc()->k() > kernel_t::min_block_k(a_type));
    bool got_info = false;

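    // Build up to four kernel variants, specialized on whether a k block is
    // the first (beta handling) and/or the last (bias, post-ops) block of the
    // GEMM. Variants that would be identical share the same kernel object.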
    for (bool first_k_block : {false, true}) {
        for (bool last_k_block : {false, true}) {
            if ((!first_k_block || !last_k_block) && !may_k_block) continue;
            if (may_k_block && last_k_block
                    && (cfg.c_bias == kernel_t::bias_t::none)
                    && !cfg.have_post_op())
                kernel_[first_k_block][last_k_block]
                        = kernel_[first_k_block][false];
            else if (may_k_block && first_k_block && cfg.beta1)
                kernel_[first_k_block][last_k_block]
                        = kernel_[false][last_k_block];
            else {
                auto cfg_copy = cfg;
                if (!first_k_block) {
                    cfg_copy.beta0 = false;
                    cfg_copy.beta1 = true;
                }
                if (!last_k_block) {
                    cfg_copy.c_bias = kernel_t::bias_t::none;
                    cfg_copy.post_ops = post_ops_t();
                }

                switch (arch_) {
                    case arch_t::xe_hp: {
                        auto kernel = kernel_t(cfg_copy);

                        create_kernel(engine,
                                &kernel_[first_k_block][last_k_block], kernel);

                        if (!got_info) {
                            compute_info_ = kernel.driver_info(eu_count_);
                            got_info = true;
                        }
                        break;
                    }
                    case arch_t::xe_hpg: {
                        using kernel_xe_hpg_t
                                = xehp_systolic_gemm_kernel_t<gpu_xe_hpg>;
                        cfg_copy.emulate64 = true;
                        auto kernel = kernel_xe_hpg_t(
                                cfg_copy.cast<kernel_xe_hpg_t::config_t>());

                        create_kernel(engine,
                                &kernel_[first_k_block][last_k_block], kernel);

                        if (!got_info) {
                            compute_info_ = kernel.driver_info(eu_count_);
                            got_info = true;
                        }
                        break;
                    }
                    default:
                        assert(!"Unsupported GPU architecture.");
                        return status::unimplemented;
                        break;
                }

                if (!kernel_[first_k_block][last_k_block])
                    return status::runtime_error;
            }
        }
    }

    return status::success;
}

status_t xe_hp_systolic_gemm_t::init_compute_new(engine_t *engine) {
    using kernel_t = gen_gemm_xe_systolic_kernel_t;
    using offset_t = kernel_t::offset_t;

    auto a_type = pd()->desc()->a_type();
    auto b_type = pd()->desc()->b_type();
    auto c_type = pd()->desc()->c_type();
    auto co_type = pd()->with_bias() ? pd()->desc()->bias_type() : c_type;
    auto acc_type = pd()->impl_acc_type();

    offset_t ab_offset
            = pd()->with_ab_zero_points() ? offset_t::fixed : offset_t::none;
    offset_t c_offset
            = pd()->with_c_zero_points() ? offset_t::runtime : offset_t::none;
    offset_t bias_offset
            = pd()->with_bias() ? offset_t::runtime : offset_t::none;

    bool may_k_block = (pd()->desc()->k() > kernel_t::min_block_k(a_type));
    bool got_info = false;

    bool with_eltwise
            = (pd()->attr()->post_ops_.find(primitive_kind::eltwise) != -1);

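    // Same first/last k-block specialization scheme as in init_compute_old(),
    // but using the newer kernel generator.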
    for (bool first_k_block : {false, true}) {
        for (bool last_k_block : {false, true}) {
            if ((!first_k_block || !last_k_block) && !may_k_block) continue;
            if (may_k_block && last_k_block && (c_offset == offset_t::none)
                    && !with_eltwise)
                kernel_[first_k_block][last_k_block]
                        = kernel_[first_k_block][false];
            else if (may_k_block && first_k_block && pd()->beta() == 1.0f)
                kernel_[first_k_block][last_k_block]
                        = kernel_[false][last_k_block];
            else {
                auto this_beta = pd()->beta();
                auto this_c_offset = c_offset;
                auto *this_post_ops = &pd()->attr()->post_ops_;
                post_ops_t no_post_ops;

                if (!first_k_block) this_beta = 1.0f;
                if (!last_k_block) {
                    this_c_offset = offset_t::none;
                    this_post_ops = &no_post_ops;
                }

                kernel_t kernel;

                auto status = kernel.init(arch_, pd()->with_batch(), ab_offset,
                        ab_offset, this_c_offset, bias_offset, pd()->alpha(),
                        this_beta, *this_post_ops, a_type, b_type, c_type,
                        co_type, acc_type, pd()->unroll_m(), pd()->unroll_n(),
                        pd()->kernel_tag());

                if (status != status::success) return status;

                if (!got_info) {
                    compute_info_ = kernel.driver_info();
                    got_info = true;
                }

                create_kernel(
                        engine, &kernel_[first_k_block][last_k_block], kernel);

                if (!kernel_[first_k_block][last_k_block])
                    return status::runtime_error;
            }
        }
    }

    return status::success;
}

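// Allocate device scratch buffers to hold packed copies of A and/or B when
// the user did not supply them already packed.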
status_t xe_hp_systolic_gemm_t::init_res_storage(
        engine_t *engine, gpu_resource_t *r) const {
    auto a_type = pd()->desc()->a_type();
    auto b_type = pd()->desc()->b_type();

    auto m = pd()->desc()->m();
    auto n = pd()->desc()->n();
    auto k = pd()->desc()->k();

    int64_t align_m = compute_info_.wgTile(LoopM);
    int64_t align_n = compute_info_.wgTile(LoopN);

    auto m_aligned = utils::rnd_up(m, align_m);
    auto n_aligned = utils::rnd_up(n, align_n);

    auto max_ldab_packed = max_ld_packed(k);

    if (!pd()->packed_a()) {
        memory_storage_t *a_packed_ptr;
        engine->create_memory_storage(&a_packed_ptr,
                m_aligned * max_ldab_packed * types::data_type_size(a_type));
        if (!a_packed_ptr) return status::runtime_error;

        std::unique_ptr<memory_storage_t> a_packed(a_packed_ptr);
        r->add_memory_storage(A_PACKED_, std::move(a_packed));
    }

    if (!pd()->packed_b()) {
        memory_storage_t *b_packed_ptr;
        engine->create_memory_storage(&b_packed_ptr,
                n_aligned * max_ldab_packed * types::data_type_size(b_type));
        if (!b_packed_ptr) return status::runtime_error;

        std::unique_ptr<memory_storage_t> b_packed(b_packed_ptr);
        r->add_memory_storage(B_PACKED_, std::move(b_packed));
    }

    return status::success;
}

bool xe_hp_systolic_gemm_t::enable_mn_blocking() const {
    return (pd()->desc()->m() >= 8192) && (pd()->desc()->n() >= 8192);
}

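// Compute the m/n/k block sizes used by execute(). m/n blocking is enabled
// only for very large problems; k is never split when the accumulator type
// differs from the C type (presumably to avoid accumulating partial sums in
// the lower-precision C buffer).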
std::tuple<int64_t, int64_t, int64_t>
xe_hp_systolic_gemm_t::get_blocking() const {
    int64_t m = pd()->desc()->m();
    int64_t n = pd()->desc()->n();
    int64_t k = pd()->desc()->k();

    int64_t unroll_k = compute_info_.unroll[LoopK];

    int64_t align_m = compute_info_.wgTile(LoopM);
    int64_t align_n = compute_info_.wgTile(LoopN);

    m = utils::rnd_up(m, align_m);
    n = utils::rnd_up(n, align_n);

    // Decide on m/n blocking.
    int64_t block_m = compute_info_.blocking[LoopM];
    int64_t block_n = compute_info_.blocking[LoopN];
    int64_t max_block_m = utils::rnd_up(m, align_m);
    int64_t max_block_n = utils::rnd_up(n, align_n);

    if (enable_mn_blocking()) {
        if (n <= block_n)
            block_m = (block_m * block_n) / n;
        else if (m <= block_m)
            block_n = (block_m * block_n) / m;
        else if (n < 2 * block_n) {
            block_n = utils::rnd_up(n / 2, align_n);
            block_m = (2 * block_m * block_n) / n;
        } else if (m < 2 * block_m) {
            block_m = utils::rnd_up(m / 2, align_m);
            block_n = (2 * block_m * block_n) / m;
        }

        block_m = utils::rnd_dn(nstl::min(block_m, max_block_m), align_m);
        block_n = utils::rnd_dn(nstl::min(block_n, max_block_n), align_n);
    } else {
        block_m = m;
        block_n = n;
    }

    // Decide on k blocking.
    int64_t block_k = compute_info_.blocking[LoopK];
    int64_t nblock_k = utils::div_up(k, block_k);
    block_k = utils::div_up(k, nblock_k);
    block_k = utils::rnd_up(
            (pd()->desc()->acc_type != pd()->desc()->c_type()) ? k : block_k,
            unroll_k);

    return std::make_tuple(block_m, block_n, block_k);
}

status_t xe_hp_systolic_gemm_t::launch_copy(const gemm_exec_ctx_t &ctx,
        int64_t r, int64_t c, const memory_storage_t &src, int64_t offset_src,
        int64_t ld_src, const memory_storage_t &dst, int32_t offset_dst,
        int32_t ld_dst, bool copyb) const {

    using copy_kernel_t = ocl::xe_systolic_gemm_copy_kernel_t;

    if (pd()->with_ab_zero_points()) {
        auto status
                = launch_clear_sum(ctx, r, c, dst, offset_dst, ld_dst, copyb);
        if (status) return status;
    }

    int64_t unroll_k = compute_info_.unroll[LoopK];

    int64_t align_r = 0, align_c = 0;

    if (!copyb) {
        align_r = compute_info_.wgTile(LoopM);
        align_c = unroll_k;
    } else {
        align_r = unroll_k;
        align_c = compute_info_.wgTile(LoopN);
    }

    bool transa = (pd()->desc()->transa() == dnnl_trans);
    bool transb = (pd()->desc()->transb() == dnnl_trans);
    bool trans = !copyb ? transa : transb;

    auto &kernel = copy_kernel_[copyb][false];

    assert(kernel);
    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, r);
    arg_list.set(1, c);
    arg_list.set(2, src);
    arg_list.set(3, offset_src);
    arg_list.set(4, ld_src);
    arg_list.set(5, dst);
    arg_list.set(6, offset_dst);
    arg_list.set(7, ld_dst);

    auto elt_size = types::data_type_size(pd()->desc()->a_type());
    size_t r_threads = utils::div_up(utils::rnd_up(r, align_r),
            copy_kernel_t::unroll_r(
                    arch_, elt_size, pd()->unroll_n(), copyb, trans));
    size_t c_threads = utils::div_up(utils::rnd_up(c, align_c),
            copy_kernel_t::unroll_c(
                    arch_, elt_size, pd()->unroll_n(), copyb, trans));
    size_t sg = copy_kernel_t::subgroup_size(arch_, elt_size, copyb, trans);

    size_t r_lsz = trans ? 1 : 16;
    size_t c_lsz = trans ? 16 : 1;

    if (r_threads > r_lsz)
        r_threads = utils::rnd_up(r_threads, r_lsz);
    else
        r_lsz = r_threads;

    if (c_threads > c_lsz)
        c_threads = utils::rnd_up(c_threads, c_lsz);
    else
        c_lsz = c_threads;

    size_t gws[3] = {r_threads * sg, c_threads, 1};
    size_t lws[3] = {r_lsz * sg, c_lsz, 1};

    auto nd_range = compute::nd_range_t(gws, lws);

    return parallel_for(ctx, nd_range, kernel, arg_list);
}

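// Clear the compensation (row/column sum) region of a packed buffer before
// the copy kernel accumulates into it; only needed when A/B zero points are
// used.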
status_t xe_hp_systolic_gemm_t::launch_clear_sum(const gemm_exec_ctx_t &ctx,
        int64_t r, int64_t c, const memory_storage_t &dst, int32_t offset_dst,
        int32_t ld_dst, bool copyb) const {

    auto &kernel = copy_kernel_[copyb][true];

    assert(kernel);
    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, r);
    arg_list.set(1, c);
    arg_list.set(2, dst);
    arg_list.set(3, offset_dst);
    arg_list.set(4, ld_dst);

    auto elt_size = types::data_type_size(pd()->desc()->a_type());
    size_t threads = !copyb ? utils::div_up(r, pd()->unroll_m())
                            : utils::div_up(c, pd()->unroll_n());
    size_t sg = ocl::xe_systolic_gemm_copy_kernel_t::subgroup_size_clear_sum(
            arch_, elt_size, copyb);

    size_t gws[3] = {threads * sg, 1, 1};
    size_t lws[3] = {sg, 1, 1};

    auto nd_range = compute::nd_range_t(gws, lws);

    return parallel_for(ctx, nd_range, kernel, arg_list);
}

status_t xe_hp_systolic_gemm_t::launch_compute(const gemm_exec_ctx_t &ctx,
        int32_t m, int32_t n, int32_t k, const memory_storage_t &ap,
        int64_t offset_a, int32_t lda, const memory_storage_t &bp,
        int64_t offset_b, int32_t ldb, const memory_storage_t &c,
        int64_t offset_c, int32_t ldc, float alpha, float beta, int16_t ao,
        int16_t bo, const memory_storage_t &co, int32_t offset_co,
        bool first_k_block, bool last_k_block, int32_t batch, int32_t stride_a,
        int32_t stride_b, int32_t stride_c) const {

    auto tg_m = compute_info_.wg[LoopM];
    auto tg_n = compute_info_.wg[LoopN];

    auto &kernel = kernel_[first_k_block][last_k_block];

    // Mandatory kernel arguments, in the order set below:
    //   kernel void gemm_kernel(global char *Ap, global uchar *Bp, global int *C,
    //                           long offsetA, long offsetB, long offsetC,
    //                           int lda, int ldb, int ldc,
    //                           int m, int n, int k,
    //                           float alpha, float beta, ...)
    // followed by optional arguments (zero points, bias/C offset, flags, batch
    // strides) depending on the problem.

    assert(kernel);

    compute::kernel_arg_list_t arg_list;
    int argn = 0;
    arg_list.set(argn++, ap);
    arg_list.set(argn++, bp);
    arg_list.set(argn++, c);
    arg_list.set(argn++, offset_a);
    arg_list.set(argn++, offset_b);
    arg_list.set(argn++, offset_c);
    arg_list.set(argn++, lda);
    arg_list.set(argn++, ldb);
    arg_list.set(argn++, ldc);
    arg_list.set(argn++, m);
    arg_list.set(argn++, n);
    arg_list.set(argn++, k);
    arg_list.set(argn++, alpha);
    arg_list.set(argn++, beta);
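    // A and B zero points are passed packed into a single 32-bit argument:
    // ao in the low 16 bits, bo in the high 16 bits.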
    if (pd()->with_ab_zero_points()) {
        uint32_t abo = (uint16_t(ao) | (uint16_t(bo) << 16));
        arg_list.set(argn++, abo);
    }
    if ((pd()->with_bias() || pd()->with_c_zero_points())) {
        arg_list.set(argn++, co);
        arg_list.set(argn++, offset_co);
    }
    if (pd()->use_new_kernels()) {
        uint32_t flags = 0;
        if (co_kind_ == 'R') flags |= FlagCORow;
        if (co_kind_ == 'C') flags |= FlagCOColumn;
        if (!first_k_block) flags |= FlagNoninitialKBlock;
        if (!last_k_block) flags |= FlagNonfinalKBlock;
        arg_list.set(argn++, flags);
    }
    if (pd()->with_batch()) {
        arg_list.set(argn++, stride_a);
        arg_list.set(argn++, stride_b);
        arg_list.set(argn++, stride_c);
    }

    auto thread_m = utils::div_up(m, pd()->unroll_m() * tg_m) * tg_m;
    auto thread_n = utils::div_up(n, pd()->unroll_n() * tg_n) * tg_n;

    if (walk_n_first_) std::swap(thread_m, thread_n);

    size_t gws[3] = {size_t(thread_m), size_t(thread_n), 1};
    size_t lws[3] = {size_t(tg_m), size_t(tg_n), 1};
    if (pd()->with_batch()) gws[2] = batch;

    lws[1] *= compute_info_.wgExpand;
    gws[1] *= compute_info_.wgExpand;

    gemm_linear_order_args(arg_list, argn, lws, gws, m, n, false, compute_info_,
            pd()->dev_info_);

    lws[0] *= compute_info_.subgroupSize;
    gws[0] *= compute_info_.subgroupSize;

    auto nd_range = compute::nd_range_t(gws, lws);

    return parallel_for(ctx, nd_range, kernel, arg_list);
}

status_t xe_hp_systolic_gemm_t::execute(const gemm_exec_ctx_t &ctx) const {
    auto a_type = pd()->desc()->a_type();
    auto b_type = pd()->desc()->b_type();
    auto c_type = pd()->desc()->c_type();
    auto bias_type = pd()->desc()->bias_type();

    auto m = pd()->desc()->m();
    auto n = pd()->desc()->n();
    auto k = pd()->desc()->k();
    auto batch = pd()->desc()->batch();

    bool packed_a = pd()->packed_a();
    bool packed_b = pd()->packed_b();
    bool packed_c = pd()->packed_c();

    auto lda = packed_a ? 0 : pd()->desc()->lda();
    auto ldb = packed_b ? 0 : pd()->desc()->ldb();
    auto ldc = packed_c ? pd()->ldc_packed() : pd()->desc()->ldc();

    auto stride_a = pd()->desc()->stride_a();
    auto stride_b = pd()->desc()->stride_b();
    auto stride_c = pd()->desc()->stride_c();

    auto alpha = pd()->alpha();
    auto beta = pd()->beta();

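    // Note the swap: the context's B storage becomes this implementation's A
    // and vice versa, matching the descriptor swap in set_default_formats().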
    auto &a = GEMM_CTX_ARG_STORAGE(b);
    auto &b = GEMM_CTX_ARG_STORAGE(a);
    auto &c = GEMM_CTX_ARG_STORAGE(c);
    auto &c_zp = GEMM_CTX_ARG_STORAGE(c_zero_point);
    auto &bias = GEMM_CTX_ARG_STORAGE(bias);
    auto *co = &c_zp;

    auto &a_packed = packed_a ? a : CTX_GPU_RES_STORAGE(A_PACKED_);
    auto &b_packed = packed_b ? b : CTX_GPU_RES_STORAGE(B_PACKED_);

    int32_t ao = 0, bo = 0;

    size_t off_a0
            = a.offset() / types::data_type_size(a_type) + pd()->dyn_offset_a;
    size_t off_b0
            = b.offset() / types::data_type_size(b_type) + pd()->dyn_offset_b;
    size_t off_c0
            = c.offset() / types::data_type_size(c_type) + pd()->dyn_offset_c;
    size_t off_co0 = 0;

    if (pd()->with_ab_zero_points()) {
        const int *ao_i32 = nullptr;
        const int *bo_i32 = nullptr;
        pd()->attr()->zero_points_.get(DNNL_ARG_SRC, nullptr, nullptr, &ao_i32);
        pd()->attr()->zero_points_.get(
                DNNL_ARG_WEIGHTS, nullptr, nullptr, &bo_i32);
        ao = -*ao_i32;
        bo = -*bo_i32;
    }

    if (pd()->with_bias()) {
        off_co0 = bias.offset() / types::data_type_size(bias_type);
        co = &bias;
    }

    int64_t block_m = 0, block_n = 0, block_k = 0;
    std::tie(block_m, block_n, block_k) = get_blocking();

    auto ld_packed = get_ld_packed(k);
    auto lda_packed = packed_a ? pd()->lda_packed() : ld_packed;
    auto ldb_packed = packed_b ? pd()->ldb_packed() : ld_packed;

    status_t status;

    if (!packed_a) {
        assert(batch == 1);
        status = launch_copy(
                ctx, m, k, a, off_a0, lda, a_packed, 0, lda_packed, false);
        if (status) return status;
    }

    if (!packed_b) {
        assert(batch == 1);
        status = launch_copy(
                ctx, k, n, b, off_b0, ldb, b_packed, 0, ldb_packed, true);
        if (status) return status;
    }

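    // Blocked GEMM driver: iterate over k blocks (accumulating with beta = 1
    // after the first), then over m and n blocks, launching one systolic
    // kernel per (m, n, k) block.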
    for (int64_t Bk = 0; Bk < k; Bk += block_k) {
        int64_t size_k = k - Bk;
        bool first_k_block = (Bk == 0);
        bool last_k_block = (size_k <= block_k);
        if (!last_k_block) size_k = block_k;

        for (int64_t Bm = 0; Bm < m; Bm += block_m) {
            int64_t size_m = m - Bm;
            if (size_m > block_m) size_m = block_m;

            auto off_a_packed = Bm * lda_packed + Bk * pd()->unroll_m();
            if (packed_a) off_a_packed += off_a0;

            for (int64_t Bn = 0; Bn < n; Bn += block_n) {
                int64_t size_n = n - Bn;
                if (size_n > block_n) size_n = block_n;

                auto off_b_packed = Bn * ldb_packed + Bk * pd()->unroll_n();
                if (packed_b) off_b_packed += off_b0;

                auto off_c = off_c0 + Bm + Bn * ldc;
                auto off_co = int32_t(off_co0);
                switch (co_kind_) {
                    case 'R': off_co += Bm; break;
                    case 'C': off_co += Bn; break;
                    default: break;
                }

                float this_beta = first_k_block ? beta : 1.0f;
                status = launch_compute(ctx, size_m, size_n, size_k, a_packed,
                        off_a_packed, lda_packed, b_packed, off_b_packed,
                        ldb_packed, c, off_c, ldc, alpha, this_beta, ao, bo,
                        *co, off_co, first_k_block, last_k_block, batch,
                        stride_a, stride_b, stride_c);
                if (status) return status;
            }
        }
    }

    return status::success;
}

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s