/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/jit/gemm/xe_hp_systolic_gemm.hpp"

#include "common/c_types_map.hpp"
#include "common/dnnl_traits.hpp"
#include "common/float16.hpp"
#include "common/type_helpers.hpp"
#include "gpu/jit/gemm/gemm_walk_orders.hpp"
#include "gpu/jit/ngen_type_bridge.hpp"
#include "gpu/ocl/gemm/xe_systolic_gemm_copy_kernel.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

status_t xe_hp_systolic_gemm_t::pd_t::init(engine_t *engine) {
    using namespace prop_kind;
    using namespace data_type;
    using namespace primitive_kind;
    using smask_t = primitive_attr_t::skip_mask_t;
    using arch_t = compute::gpu_arch_t;

    assert(engine->kind() == engine_kind::gpu);
    auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);

    if (!compute_engine->mayiuse_ngen_kernels()) return status::unimplemented;
    if (!compute_engine->mayiuse_large_grf_mode()) return status::unimplemented;

    dev_info_ = compute_engine->device_info();
    auto arch = dev_info_->gpu_arch();

    const auto &d = desc();

    bool dt_float_ok = (d->a_type() == d->b_type()
            && utils::one_of(d->a_type(), bf16, f16)
            && utils::one_of(d->c_type(), f32, d->a_type()));

    bool dt_int_ok = (utils::one_of(d->a_type(), u8, s8)
            && utils::one_of(d->b_type(), u8, s8) && (d->c_type() == s32));

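    // For integer GEMM, A/B zero points must be known at primitive creation
    // time; remember whether they are nonzero so compensation is only applied
    // when needed. Runtime zero points are only accepted for C (DST).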
    if (dt_int_ok) {
        if (attr()->zero_points_.defined(DNNL_ARG_SRC)) {
            const int *ao_i32 = nullptr;
            attr()->zero_points_.get(DNNL_ARG_SRC, nullptr, nullptr, &ao_i32);
            a_zp_ = (*ao_i32 != 0);
        } else if (!attr()->zero_points_.has_default_values(DNNL_ARG_SRC))
            return status::unimplemented;

        if (attr()->zero_points_.defined(DNNL_ARG_WEIGHTS)) {
            const int *bo_i32 = nullptr;
            attr()->zero_points_.get(
                    DNNL_ARG_WEIGHTS, nullptr, nullptr, &bo_i32);
            b_zp_ = (*bo_i32 != 0);
        } else if (!attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS))
            return status::unimplemented;

        c_zp_ = !attr()->zero_points_.has_default_values(DNNL_ARG_DST);
    }

    bool ok = set_default_formats(d->a_type());
    if (!ok) return status::unimplemented;

    CHECK(attr_.set_default_formats(dst_md(0)));

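    // Small problems are typically better served by the FMA-based
    // (non-systolic) GEMM implementations, so decline them here and let the
    // implementation dispatcher fall through.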
    if (use_fma()) return status::unimplemented;

    // LIMITATIONS:
    // - batch is not supported for unpacked inputs.
    // - runtime dims are not supported
    bool limits_ok
            = !utils::one_of(DNNL_RUNTIME_DIM_VAL, d->m(), d->n(), d->k());
    if (!packed_a())
        limits_ok = limits_ok && (d->lda() != DNNL_RUNTIME_DIM_VAL)
                && (d->batch() == 1);
    if (!packed_b())
        limits_ok = limits_ok && (d->ldb() != DNNL_RUNTIME_DIM_VAL)
                && (d->batch() == 1);
    if (!packed_c())
        limits_ok = limits_ok && (d->ldc() != DNNL_RUNTIME_DIM_VAL);

    auto attr_skip_mask = smask_t::oscale | smask_t::post_ops;

    if (dt_int_ok) attr_skip_mask |= smask_t::zero_points_runtime;

    bool arch_ok = (arch == arch_t::xe_hp);
    arch_ok |= (arch == arch_t::xe_hpg);
    arch_ok |= (arch == arch_t::xe_hpc);

    ok = true && limits_ok && (dt_float_ok || dt_int_ok) && arch_ok
            && compute_engine->mayiuse(compute::device_ext_t::
                            intel_subgroup_split_matrix_multiply_accumulate)
            && attr()->has_default_values(attr_skip_mask)
            && attr()->output_scales_.mask_ == 0 && attr()->post_ops_.len() <= 2
            && IMPLICATION(with_bias(),
                    dt_float_ok
                            && utils::one_of(d->bias_type(), d->a_type(), f32)
                            && utils::one_of(bias_cmask(), 0, 1 << 0, 1 << 1));

    // check if there is sum post op and only at first place
    ok &= IMPLICATION(attr()->post_ops_.find(sum) != -1,
            attr()->post_ops_.find(sum) == 0
                    && attr()->post_ops_.find(sum, 1) == -1);

    // check if post ops are supported
    ok &= IMPLICATION(attr()->post_ops_.len() > 0,
            jit_post_op_injector_is_supported(attr()->post_ops_, true));

    if (dt_int_ok) {
        ok &= IMPLICATION(a_zp_, !packed_b()) && IMPLICATION(b_zp_, !packed_a())
                && IMPLICATION(
                        c_zp_, !attr()->zero_points_.defined(DNNL_ARG_DST));

        int cmask_a = 0, cmask_b = 0, cmask_c = 0;
        attr()->zero_points_.get(DNNL_ARG_WEIGHTS, nullptr, &cmask_b, nullptr);
        attr()->zero_points_.get(DNNL_ARG_SRC, nullptr, &cmask_a, nullptr);
        attr()->zero_points_.get(DNNL_ARG_DST, nullptr, &cmask_c, nullptr);
        ok &= (cmask_a == 0) && (cmask_b == 0)
                && utils::one_of(cmask_c, 0, 1 << 0, 1 << 1);
    }

    if (!ok) return status::unimplemented;

    return status::success;
}

namespace {
struct nocopy_table_t {
    int mn_limit[2][2]; // Use no-copy if m*n < mn_limit * mn_limit and
    int k_limit[2][2]; // Use no-copy if k < k_limit
};

const nocopy_table_t xe_hp_f16_nocopy_table[] = {
        // NN NT TN TT
        {{{1280, 768}, {512, 384}}, {{512, 768}, {1024, 512}}}};

const nocopy_table_t xe_hp_bf16_nocopy_table[] = {
        // NN NT TN TT
        {{{512, 256}, {512, 512}}, {{512, 256}, {384, 384}}}};

const nocopy_table_t xe_hp_x8x8s32_nocopy_table[] = {
        // NN NT TN TT
        {{{384, 384}, {384, 384}}, {{384, 512}, {384, 256}}}};
} // namespace

bool xe_hp_systolic_gemm_t::pd_t::use_fma() {
    using namespace data_type;

    const auto &d = desc();

    if (any_prepacked_) return false;

    // Use FMA implementation if one matrix is very small.
    if (d->m() < 32 && d->n() < 32) return true;
    if (d->m() < 32 && d->k() < 32) return true;
    if (d->n() < 32 && d->k() < 32) return true;

    // Use FMA for small/medium sizes.
    if (utils::one_of(d->c_type(), bf16, f16, s32)) {
        const nocopy_table_t *all_tables[3] = {xe_hp_f16_nocopy_table,
                xe_hp_bf16_nocopy_table, xe_hp_x8x8s32_nocopy_table};
        const int type_idx
                = (d->c_type() == f16) ? 0 : (d->c_type() == bf16) ? 1 : 2;
        const nocopy_table_t *table = all_tables[type_idx];
        const long mnl = table->mn_limit[d->transa()][d->transb()];
        const long kl = table->k_limit[d->transa()][d->transb()];

        if ((d->m() * d->n() < mnl * mnl) && (d->k() < kl)) return true;
    }

    return false;
}

bool xe_hp_systolic_gemm_t::pd_t::set_default_formats(data_type_t dt) {
    using namespace format_tag;
    using new_kernel_t = gen_gemm_xe_systolic_kernel_t;

    auto sz = types::data_type_size(dt);
    const auto &d = desc();
    auto arch = dev_info_->gpu_arch();

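    // Note the swap: a_mdw wraps desc_.b_desc and b_mdw wraps desc_.a_desc.
    // The same A/B exchange is applied in execute() when binding the runtime
    // storages, since the kernels effectively work on the transposed problem.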
    memory_desc_wrapper a_mdw(&desc_.b_desc);
    memory_desc_wrapper b_mdw(&desc_.a_desc);
    memory_desc_wrapper c_mdw(&desc_.c_desc);

    bool a_any = a_mdw.format_any();
    bool b_any = b_mdw.format_any();
    bool c_any = c_mdw.format_any();
    bool batch = d->is_batched();

    format_tag_t a_packed_tag = batch ? ((sz == 2) ? aCB4c8b8c2b : aCB4c8b8c4b)
                                      : ((sz == 2) ? BA4b8a8b2a : BA4b8a8b4a);
    format_tag_t b_packed_tag_48 = batch ? ((sz == 2) ? aBC48b16c : aBC48b32c)
                                         : ((sz == 2) ? AB48a16b : AB48a32b);
    format_tag_t b_packed_tag_32 = batch ? ((sz == 2) ? aBC32b16c : aBC32b32c)
                                         : ((sz == 2) ? AB32a16b : AB32a32b);
    format_tag_t unpacked_tag = batch ? abc : ab;

    if (arch == compute::gpu_arch_t::xe_hpc) {
        a_packed_tag = batch ? ((sz == 2) ? aCB4c8b16c2b : aCB4c8b16c4b)
                             : ((sz == 2) ? BA4b8a16b2a : BA4b8a16b4a);
    }

    bool a_prepacked = a_mdw.matches_tag(a_packed_tag);
    bool bc_prepacked_32 = b_mdw.matches_tag(b_packed_tag_32)
            || c_mdw.matches_tag(b_packed_tag_32);
    bool bc_prepacked_48 = b_mdw.matches_tag(b_packed_tag_48)
            || c_mdw.matches_tag(b_packed_tag_48);
    bool c_prepacked = c_mdw.matches_tag(b_packed_tag_32)
            || c_mdw.matches_tag(b_packed_tag_48);

    any_prepacked_ = a_prepacked || bc_prepacked_32 || bc_prepacked_48;

    unroll_m_ = 32;
    unroll_n_ = 0;
    kernel_tag_ = 0;
    if (bc_prepacked_32)
        unroll_n_ = 32;
    else if (bc_prepacked_48)
        unroll_n_ = 48;

    use_new_kernels_ = !c_prepacked && !with_ab_zero_points() && (d->k() >= 64);
    use_new_kernels_ |= (arch >= compute::gpu_arch_t::xe_hpc);

    new_kernel_t::choose_unrolls(arch, dev_info_->eu_count(), d->a_type(),
            d->b_type(), d->c_type(), d->m(), d->n(), d->k(), d->batch(),
            unroll_m_, unroll_n_, kernel_tag_);

    format_tag_t b_packed_tag
            = (unroll_n_ == 48) ? b_packed_tag_48 : b_packed_tag_32;
    format_tag_t c_packed_tag = use_new_kernels_ ? unpacked_tag : b_packed_tag;

    packed_a_ = packed_b_ = packed_c_ = false;

    if (a_any) {
        CHECK(memory_desc_init_by_tag(
                desc_.b_desc, b_zp_ ? unpacked_tag : a_packed_tag));
        auto &ld = desc_.b_desc.padded_dims[batch ? 1 : 0];
        ld = nice_ld(ld, int(sz));
        desc_.b_desc.format_desc.blocking.strides[batch ? 2 : 1]
                = unroll_m_ * ld;
        packed_a_ = true;
    } else if (a_mdw.matches_one_of_tag(a_packed_tag, ab, ba, abc, acb)
            == undef)
        return false;

    if (b_any) {
        CHECK(memory_desc_init_by_tag(
                desc_.a_desc, a_zp_ ? unpacked_tag : b_packed_tag));
        auto &ld = desc_.a_desc.padded_dims[batch ? 2 : 1];
        ld = nice_ld(ld, int(sz));
        desc_.a_desc.format_desc.blocking.strides[batch ? 1 : 0]
                = unroll_n_ * ld;
        packed_b_ = true;
    } else if (b_mdw.matches_one_of_tag(b_packed_tag, ab, ba, abc, acb)
            == undef)
        return false;

    if (c_any)
        CHECK(memory_desc_init_by_tag(desc_.c_desc, c_packed_tag));
    else if (c_mdw.matches_one_of_tag(c_packed_tag, ab, abc) == undef)
        return false;

    packed_a_ = packed_a_ || a_mdw.matches_tag(a_packed_tag);
    packed_b_ = packed_b_ || b_mdw.matches_tag(b_packed_tag);
    packed_c_ = c_mdw.matches_tag(b_packed_tag);

    return gpu_gemm_pd_t::set_default_formats();
}

status_t xe_hp_systolic_gemm_t::init(engine_t *engine) {
    arch_ = pd()->dev_info_->gpu_arch();
    eu_count_ = pd()->dev_info_->eu_count();

    auto a_type = pd()->desc()->a_type();
    auto b_type = pd()->desc()->b_type();

    int cmask = -1;

    if (pd()->with_c_zero_points())
        pd()->attr()->zero_points_.get(DNNL_ARG_DST, nullptr, &cmask, nullptr);
    else if (pd()->with_bias())
        cmask = pd()->bias_cmask();

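    // Map the C offset/bias mask onto the kernel's "co" kind:
    // 'F' = single fixed value, 'R' = per-row, 'C' = per-column, 'N' = none.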
    switch (cmask) {
        case 0: co_kind_ = 'F'; break;
        case (1 << 1): co_kind_ = 'R'; break;
        case (1 << 0): co_kind_ = 'C'; break;
        case -1:
        default: co_kind_ = 'N'; break;
    }

    if (get_verbose() >= 2) {
        char tag_s[2] = {pd()->kernel_tag(), 0};
        printf("onednn_verbose,info,gpu,gemm,kernel:%dx%d,%s,new:%c\n",
                pd()->unroll_m(), pd()->unroll_n(), tag_s,
                pd()->use_new_kernels() ? 'Y' : 'N');
    }

    // Initialize compute kernels (assembly)
    {
        auto status = pd()->use_new_kernels() ? init_compute_new(engine)
                                              : init_compute_old(engine);
        if (status != status::success) return status;
    }

    // Initialize copy kernels (OpenCL)
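    // copy_kernel_[copy_b][clear_sum]: one copy kernel per matrix (A or B),
    // plus an additional "clear sum" variant that is only built when A/B zero
    // points are present.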
    for (bool copy_b : {false, true}) {
        for (bool clear_sum : {false, true}) {
            if (clear_sum && !pd()->with_ab_zero_points()) continue;
            if (!copy_b ? pd()->packed_a() : pd()->packed_b()) continue;

            using copy_kernel_t = ocl::xe_systolic_gemm_copy_kernel_t;
            compute::kernel_ctx_t kernel_ctx;

            auto trans
                    = !copy_b ? pd()->desc()->transa() : pd()->desc()->transb();
            auto status = copy_kernel_t::init_kernel_ctx(kernel_ctx, arch_,
                    !copy_b ? a_type : b_type, pd()->unroll_n(), copy_b, trans,
                    pd()->with_ab_zero_points(), clear_sum);
            if (status != status::success) return status;

            create_kernel(engine, &copy_kernel_[copy_b][clear_sum],
                    copy_kernel_t::name(arch_), kernel_ctx);
            if (!copy_kernel_[copy_b][clear_sum]) return status::runtime_error;
        }
    }

    return status::success;
}

status_t xe_hp_systolic_gemm_t::init_compute_old(engine_t *engine) {
    using kernel_t = xehp_systolic_gemm_kernel_t<gpu_xe_hp>;
    using arch_t = compute::gpu_arch_t;

    kernel_t::config_t cfg;

    auto a_type = pd()->desc()->a_type();
    auto b_type = pd()->desc()->b_type();
    auto c_type = pd()->desc()->c_type();
    auto acc_type = pd()->impl_acc_type();

    cfg.a_type = convert_dnnl_type_to_ngen(a_type);
    cfg.b_type = convert_dnnl_type_to_ngen(b_type);
    cfg.c_type = convert_dnnl_type_to_ngen(c_type);
    cfg.acc_type = convert_dnnl_type_to_ngen(acc_type);
    cfg.alpha1 = (pd()->alpha() == 1.0f);
    cfg.beta0 = (pd()->beta() == 0.0f);
    cfg.beta1 = (pd()->beta() == 1.0f);
    cfg.post_ops = pd()->attr()->post_ops_;
    cfg.a_bias = cfg.b_bias = pd()->with_ab_zero_points();
    cfg.c_packed = pd()->packed_c();
    cfg.batch = pd()->with_batch();
    walk_n_first_ = cfg.walk_n_first
            = (pd()->desc()->m() >= 2 * pd()->desc()->n());
    cfg.tile_m = pd()->unroll_m();
    cfg.tile_n = pd()->unroll_n();
    cfg.global_3x_buf = (cfg.tile_n == 32);

    if (pd()->with_c_zero_points())
        cfg.co_type = cfg.c_type;
    else if (pd()->with_bias()) {
        cfg.early_c_bias = true;
        cfg.co_type = convert_dnnl_type_to_ngen(pd()->desc()->bias_type());
    }

    switch (co_kind_) {
        case 'F': cfg.c_bias = kernel_t::bias_t::fixed; break;
        case 'R': cfg.c_bias = kernel_t::bias_t::row; break;
        case 'C': cfg.c_bias = kernel_t::bias_t::column; break;
        case 'N':
        default: cfg.c_bias = kernel_t::bias_t::none; break;
    }

    bool may_k_block = (pd()->desc()->k() > kernel_t::min_block_k(a_type));
    bool got_info = false;

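    // Build kernels for the four (first_k_block, last_k_block) combinations.
    // Variants that end up identical (no bias/post-ops on the last block, or
    // beta == 1 on the first block) share a single kernel object instead of
    // being generated twice.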
    for (bool first_k_block : {false, true}) {
        for (bool last_k_block : {false, true}) {
            if ((!first_k_block || !last_k_block) && !may_k_block) continue;
            if (may_k_block && last_k_block
                    && (cfg.c_bias == kernel_t::bias_t::none)
                    && !cfg.have_post_op())
                kernel_[first_k_block][last_k_block]
                        = kernel_[first_k_block][false];
            else if (may_k_block && first_k_block && cfg.beta1)
                kernel_[first_k_block][last_k_block]
                        = kernel_[false][last_k_block];
            else {
                auto cfg_copy = cfg;
                if (!first_k_block) {
                    cfg_copy.beta0 = false;
                    cfg_copy.beta1 = true;
                }
                if (!last_k_block) {
                    cfg_copy.c_bias = kernel_t::bias_t::none;
                    cfg_copy.post_ops = post_ops_t();
                }

                switch (arch_) {
                    case arch_t::xe_hp: {
                        auto kernel = kernel_t(cfg_copy);

                        create_kernel(engine,
                                &kernel_[first_k_block][last_k_block], kernel);

                        if (!got_info) {
                            compute_info_ = kernel.driver_info(eu_count_);
                            got_info = true;
                        }
                        break;
                    }
                    case arch_t::xe_hpg: {
                        using kernel_xe_hpg_t
                                = xehp_systolic_gemm_kernel_t<gpu_xe_hpg>;
                        cfg_copy.emulate64 = true;
                        auto kernel = kernel_xe_hpg_t(
                                cfg_copy.cast<kernel_xe_hpg_t::config_t>());

                        create_kernel(engine,
                                &kernel_[first_k_block][last_k_block], kernel);

                        if (!got_info) {
                            compute_info_ = kernel.driver_info(eu_count_);
                            got_info = true;
                        }
                        break;
                    }
                    default:
                        assert(!"Unsupported GPU architecture.");
                        return status::unimplemented;
                        break;
                }

                if (!kernel_[first_k_block][last_k_block])
                    return status::runtime_error;
            }
        }
    }

    return status::success;
}

status_t xe_hp_systolic_gemm_t::init_compute_new(engine_t *engine) {
    using kernel_t = gen_gemm_xe_systolic_kernel_t;
    using offset_t = kernel_t::offset_t;

    auto a_type = pd()->desc()->a_type();
    auto b_type = pd()->desc()->b_type();
    auto c_type = pd()->desc()->c_type();
    auto co_type = pd()->with_bias() ? pd()->desc()->bias_type() : c_type;
    auto acc_type = pd()->impl_acc_type();

    offset_t ab_offset
            = pd()->with_ab_zero_points() ? offset_t::fixed : offset_t::none;
    offset_t c_offset
            = pd()->with_c_zero_points() ? offset_t::runtime : offset_t::none;
    offset_t bias_offset
            = pd()->with_bias() ? offset_t::runtime : offset_t::none;

    bool may_k_block = (pd()->desc()->k() > kernel_t::min_block_k(a_type));
    bool got_info = false;

    bool with_eltwise
            = (pd()->attr()->post_ops_.find(primitive_kind::eltwise) != -1);

    for (bool first_k_block : {false, true}) {
        for (bool last_k_block : {false, true}) {
            if ((!first_k_block || !last_k_block) && !may_k_block) continue;
            if (may_k_block && last_k_block && (c_offset == offset_t::none)
                    && !with_eltwise)
                kernel_[first_k_block][last_k_block]
                        = kernel_[first_k_block][false];
            else if (may_k_block && first_k_block && pd()->beta() == 1.0f)
                kernel_[first_k_block][last_k_block]
                        = kernel_[false][last_k_block];
            else {
                auto this_beta = pd()->beta();
                auto this_c_offset = c_offset;
                auto *this_post_ops = &pd()->attr()->post_ops_;
                post_ops_t no_post_ops;

                if (!first_k_block) this_beta = 1.0f;
                if (!last_k_block) {
                    this_c_offset = offset_t::none;
                    this_post_ops = &no_post_ops;
                }

                kernel_t kernel;

                auto status = kernel.init(arch_, pd()->with_batch(), ab_offset,
                        ab_offset, this_c_offset, bias_offset, pd()->alpha(),
                        this_beta, *this_post_ops, a_type, b_type, c_type,
                        co_type, acc_type, pd()->unroll_m(), pd()->unroll_n(),
                        pd()->kernel_tag());

                if (status != status::success) return status;

                if (!got_info) {
                    compute_info_ = kernel.driver_info();
                    got_info = true;
                }

                create_kernel(
                        engine, &kernel_[first_k_block][last_k_block], kernel);

                if (!kernel_[first_k_block][last_k_block])
                    return status::runtime_error;
            }
        }
    }

    return status::success;
}

status_t xe_hp_systolic_gemm_t::init_res_storage(
        engine_t *engine, gpu_resource_t *r) const {
    auto a_type = pd()->desc()->a_type();
    auto b_type = pd()->desc()->b_type();

    auto m = pd()->desc()->m();
    auto n = pd()->desc()->n();
    auto k = pd()->desc()->k();

    int64_t align_m = compute_info_.wgTile(LoopM);
    int64_t align_n = compute_info_.wgTile(LoopN);

    auto m_aligned = utils::rnd_up(m, align_m);
    auto n_aligned = utils::rnd_up(n, align_n);

    auto max_ldab_packed = max_ld_packed(k);

    if (!pd()->packed_a()) {
        memory_storage_t *a_packed_ptr;
        engine->create_memory_storage(&a_packed_ptr,
                m_aligned * max_ldab_packed * types::data_type_size(a_type));
        if (!a_packed_ptr) return status::runtime_error;

        std::unique_ptr<memory_storage_t> a_packed(a_packed_ptr);
        r->add_memory_storage(A_PACKED_, std::move(a_packed));
    }

    if (!pd()->packed_b()) {
        memory_storage_t *b_packed_ptr;
        engine->create_memory_storage(&b_packed_ptr,
                n_aligned * max_ldab_packed * types::data_type_size(b_type));
        if (!b_packed_ptr) return status::runtime_error;

        std::unique_ptr<memory_storage_t> b_packed(b_packed_ptr);
        r->add_memory_storage(B_PACKED_, std::move(b_packed));
    }

    return status::success;
}

bool xe_hp_systolic_gemm_t::enable_mn_blocking() const {
    return (pd()->desc()->m() >= 8192) && (pd()->desc()->n() >= 8192);
}

std::tuple<int64_t, int64_t, int64_t>
xe_hp_systolic_gemm_t::get_blocking() const {
    int64_t m = pd()->desc()->m();
    int64_t n = pd()->desc()->n();
    int64_t k = pd()->desc()->k();

    int64_t unroll_k = compute_info_.unroll[LoopK];

    int64_t align_m = compute_info_.wgTile(LoopM);
    int64_t align_n = compute_info_.wgTile(LoopN);

    m = utils::rnd_up(m, align_m);
    n = utils::rnd_up(n, align_n);

    // Decide on m/n blocking.
    int64_t block_m = compute_info_.blocking[LoopM];
    int64_t block_n = compute_info_.blocking[LoopN];
    int64_t max_block_m = utils::rnd_up(m, align_m);
    int64_t max_block_n = utils::rnd_up(n, align_n);

    if (enable_mn_blocking()) {
        if (n <= block_n)
            block_m = (block_m * block_n) / n;
        else if (m <= block_m)
            block_n = (block_m * block_n) / m;
        else if (n < 2 * block_n) {
            block_n = utils::rnd_up(n / 2, align_n);
            block_m = (2 * block_m * block_n) / n;
        } else if (m < 2 * block_m) {
            block_m = utils::rnd_up(m / 2, align_m);
            block_n = (2 * block_m * block_n) / m;
        }

        block_m = utils::rnd_dn(nstl::min(block_m, max_block_m), align_m);
        block_n = utils::rnd_dn(nstl::min(block_n, max_block_n), align_n);
    } else {
        block_m = m;
        block_n = n;
    }

    // Decide on k blocking.
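    // k is split into roughly equal chunks of at most blocking[LoopK],
    // rounded to the k unroll. If the accumulator type differs from the C
    // type, partial sums could not be accumulated in C accurately, so a
    // single k block covering all of k is used instead.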
    int64_t block_k = compute_info_.blocking[LoopK];
    int64_t nblock_k = utils::div_up(k, block_k);
    block_k = utils::div_up(k, nblock_k);
    block_k = utils::rnd_up(
            (pd()->desc()->acc_type != pd()->desc()->c_type()) ? k : block_k,
            unroll_k);

    return std::make_tuple(block_m, block_n, block_k);
}

status_t xe_hp_systolic_gemm_t::launch_copy(const gemm_exec_ctx_t &ctx,
        int64_t r, int64_t c, const memory_storage_t &src, int64_t offset_src,
        int64_t ld_src, const memory_storage_t &dst, int32_t offset_dst,
        int32_t ld_dst, bool copyb) const {

    using copy_kernel_t = ocl::xe_systolic_gemm_copy_kernel_t;

    if (pd()->with_ab_zero_points()) {
        auto status
                = launch_clear_sum(ctx, r, c, dst, offset_dst, ld_dst, copyb);
        if (status) return status;
    }

    int64_t unroll_k = compute_info_.unroll[LoopK];

    int64_t align_r = 0, align_c = 0;

    if (!copyb) {
        align_r = compute_info_.wgTile(LoopM);
        align_c = unroll_k;
    } else {
        align_r = unroll_k;
        align_c = compute_info_.wgTile(LoopN);
    }

    bool transa = (pd()->desc()->transa() == dnnl_trans);
    bool transb = (pd()->desc()->transb() == dnnl_trans);
    bool trans = !copyb ? transa : transb;

    auto &kernel = copy_kernel_[copyb][false];

    assert(kernel);
    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, r);
    arg_list.set(1, c);
    arg_list.set(2, src);
    arg_list.set(3, offset_src);
    arg_list.set(4, ld_src);
    arg_list.set(5, dst);
    arg_list.set(6, offset_dst);
    arg_list.set(7, ld_dst);

    auto elt_size = types::data_type_size(pd()->desc()->a_type());
    size_t r_threads = utils::div_up(utils::rnd_up(r, align_r),
            copy_kernel_t::unroll_r(
                    arch_, elt_size, pd()->unroll_n(), copyb, trans));
    size_t c_threads = utils::div_up(utils::rnd_up(c, align_c),
            copy_kernel_t::unroll_c(
                    arch_, elt_size, pd()->unroll_n(), copyb, trans));
    size_t sg = copy_kernel_t::subgroup_size(arch_, elt_size, copyb, trans);

    size_t r_lsz = trans ? 1 : 16;
    size_t c_lsz = trans ? 16 : 1;

    if (r_threads > r_lsz)
        r_threads = utils::rnd_up(r_threads, r_lsz);
    else
        r_lsz = r_threads;

    if (c_threads > c_lsz)
        c_threads = utils::rnd_up(c_threads, c_lsz);
    else
        c_lsz = c_threads;

    size_t gws[3] = {r_threads * sg, c_threads, 1};
    size_t lws[3] = {r_lsz * sg, c_lsz, 1};

    auto nd_range = compute::nd_range_t(gws, lws);

    return parallel_for(ctx, nd_range, kernel, arg_list);
}


status_t xe_hp_systolic_gemm_t::launch_clear_sum(const gemm_exec_ctx_t &ctx,
        int64_t r, int64_t c, const memory_storage_t &dst, int32_t offset_dst,
        int32_t ld_dst, bool copyb) const {

    auto &kernel = copy_kernel_[copyb][true];

    assert(kernel);
    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, r);
    arg_list.set(1, c);
    arg_list.set(2, dst);
    arg_list.set(3, offset_dst);
    arg_list.set(4, ld_dst);

    auto elt_size = types::data_type_size(pd()->desc()->a_type());
    size_t threads = !copyb ? utils::div_up(r, pd()->unroll_m())
                            : utils::div_up(c, pd()->unroll_n());
    size_t sg = ocl::xe_systolic_gemm_copy_kernel_t::subgroup_size_clear_sum(
            arch_, elt_size, copyb);

    size_t gws[3] = {threads * sg, 1, 1};
    size_t lws[3] = {sg, 1, 1};

    auto nd_range = compute::nd_range_t(gws, lws);

    return parallel_for(ctx, nd_range, kernel, arg_list);
}

status_t xe_hp_systolic_gemm_t::launch_compute(const gemm_exec_ctx_t &ctx,
        int32_t m, int32_t n, int32_t k, const memory_storage_t &ap,
        int64_t offset_a, int32_t lda, const memory_storage_t &bp,
        int64_t offset_b, int32_t ldb, const memory_storage_t &c,
        int64_t offset_c, int32_t ldc, float alpha, float beta, int16_t ao,
        int16_t bo, const memory_storage_t &co, int32_t offset_co,
        bool first_k_block, bool last_k_block, int32_t batch, int32_t stride_a,
        int32_t stride_b, int32_t stride_c) const {

    auto tg_m = compute_info_.wg[LoopM];
    auto tg_n = compute_info_.wg[LoopN];

    auto &kernel = kernel_[first_k_block][last_k_block];

    // kernel void gemm_kernel(global char *Ap, global uchar *Bp, global int *C,
    //                         int k, int ldc,
    //                         long offsetA, long offsetB, long offsetC,
    //                         int m, int n,
    //                         float alpha, float beta,
    //                         int lda, int ldb)

    assert(kernel);

    compute::kernel_arg_list_t arg_list;
    int argn = 0;
    arg_list.set(argn++, ap);
    arg_list.set(argn++, bp);
    arg_list.set(argn++, c);
    arg_list.set(argn++, offset_a);
    arg_list.set(argn++, offset_b);
    arg_list.set(argn++, offset_c);
    arg_list.set(argn++, lda);
    arg_list.set(argn++, ldb);
    arg_list.set(argn++, ldc);
    arg_list.set(argn++, m);
    arg_list.set(argn++, n);
    arg_list.set(argn++, k);
    arg_list.set(argn++, alpha);
    arg_list.set(argn++, beta);
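    // A and B zero points (already negated by the caller) are packed into a
    // single 32-bit argument: ao in the low 16 bits, bo in the high 16 bits.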
    if (pd()->with_ab_zero_points()) {
        uint32_t abo = (uint16_t(ao) | (uint16_t(bo) << 16));
        arg_list.set(argn++, abo);
    }
    if ((pd()->with_bias() || pd()->with_c_zero_points())) {
        arg_list.set(argn++, co);
        arg_list.set(argn++, offset_co);
    }
    if (pd()->use_new_kernels()) {
        uint32_t flags = 0;
        if (co_kind_ == 'R') flags |= FlagCORow;
        if (co_kind_ == 'C') flags |= FlagCOColumn;
        if (!first_k_block) flags |= FlagNoninitialKBlock;
        if (!last_k_block) flags |= FlagNonfinalKBlock;
        arg_list.set(argn++, flags);
    }
    if (pd()->with_batch()) {
        arg_list.set(argn++, stride_a);
        arg_list.set(argn++, stride_b);
        arg_list.set(argn++, stride_c);
    }

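    // Global size: one subgroup per (unroll_m x unroll_n) tile of C, padded
    // up to whole (tg_m x tg_n) workgroups; dimension 0 is scaled by the
    // subgroup size below.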
    auto thread_m = utils::div_up(m, pd()->unroll_m() * tg_m) * tg_m;
    auto thread_n = utils::div_up(n, pd()->unroll_n() * tg_n) * tg_n;

    if (walk_n_first_) std::swap(thread_m, thread_n);

    size_t gws[3] = {size_t(thread_m), size_t(thread_n), 1};
    size_t lws[3] = {size_t(tg_m), size_t(tg_n), 1};
    if (pd()->with_batch()) gws[2] = batch;

    lws[1] *= compute_info_.wgExpand;
    gws[1] *= compute_info_.wgExpand;

    gemm_linear_order_args(arg_list, argn, lws, gws, m, n, false, compute_info_,
            pd()->dev_info_);

    lws[0] *= compute_info_.subgroupSize;
    gws[0] *= compute_info_.subgroupSize;

    auto nd_range = compute::nd_range_t(gws, lws);

    return parallel_for(ctx, nd_range, kernel, arg_list);
}

status_t xe_hp_systolic_gemm_t::execute(const gemm_exec_ctx_t &ctx) const {
    auto a_type = pd()->desc()->a_type();
    auto b_type = pd()->desc()->b_type();
    auto c_type = pd()->desc()->c_type();
    auto bias_type = pd()->desc()->bias_type();

    auto m = pd()->desc()->m();
    auto n = pd()->desc()->n();
    auto k = pd()->desc()->k();
    auto batch = pd()->desc()->batch();

    bool packed_a = pd()->packed_a();
    bool packed_b = pd()->packed_b();
    bool packed_c = pd()->packed_c();

    auto lda = packed_a ? 0 : pd()->desc()->lda();
    auto ldb = packed_b ? 0 : pd()->desc()->ldb();
    auto ldc = packed_c ? pd()->ldc_packed() : pd()->desc()->ldc();

    auto stride_a = pd()->desc()->stride_a();
    auto stride_b = pd()->desc()->stride_b();
    auto stride_c = pd()->desc()->stride_c();

    auto alpha = pd()->alpha();
    auto beta = pd()->beta();

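    // As in set_default_formats(), the A and B storages from the GEMM context
    // are exchanged here; the kernels effectively work on the transposed
    // problem.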
    auto &a = GEMM_CTX_ARG_STORAGE(b);
    auto &b = GEMM_CTX_ARG_STORAGE(a);
    auto &c = GEMM_CTX_ARG_STORAGE(c);
    auto &c_zp = GEMM_CTX_ARG_STORAGE(c_zero_point);
    auto &bias = GEMM_CTX_ARG_STORAGE(bias);
    auto *co = &c_zp;

    auto &a_packed = packed_a ? a : CTX_GPU_RES_STORAGE(A_PACKED_);
    auto &b_packed = packed_b ? b : CTX_GPU_RES_STORAGE(B_PACKED_);

    int32_t ao = 0, bo = 0;

    size_t off_a0
            = a.offset() / types::data_type_size(a_type) + pd()->dyn_offset_a;
    size_t off_b0
            = b.offset() / types::data_type_size(b_type) + pd()->dyn_offset_b;
    size_t off_c0
            = c.offset() / types::data_type_size(c_type) + pd()->dyn_offset_c;
    size_t off_co0 = 0;

    if (pd()->with_ab_zero_points()) {
        const int *ao_i32 = nullptr;
        const int *bo_i32 = nullptr;
        pd()->attr()->zero_points_.get(DNNL_ARG_SRC, nullptr, nullptr, &ao_i32);
        pd()->attr()->zero_points_.get(
                DNNL_ARG_WEIGHTS, nullptr, nullptr, &bo_i32);
        ao = -*ao_i32;
        bo = -*bo_i32;
    }

    if (pd()->with_bias()) {
        off_co0 = bias.offset() / types::data_type_size(bias_type);
        co = &bias;
    }

    int64_t block_m = 0, block_n = 0, block_k = 0;
    std::tie(block_m, block_n, block_k) = get_blocking();

    auto ld_packed = get_ld_packed(k);
    auto lda_packed = packed_a ? pd()->lda_packed() : ld_packed;
    auto ldb_packed = packed_b ? pd()->ldb_packed() : ld_packed;

    status_t status;

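    // If A/B were not provided pre-packed, repack them into the temporary
    // buffers allocated in init_res_storage().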
    if (!packed_a) {
        assert(batch == 1);
        status = launch_copy(
                ctx, m, k, a, off_a0, lda, a_packed, 0, lda_packed, false);
        if (status) return status;
    }

    if (!packed_b) {
        assert(batch == 1);
        status = launch_copy(
                ctx, k, n, b, off_b0, ldb, b_packed, 0, ldb_packed, true);
        if (status) return status;
    }

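    // Process the problem block by block; for non-first k blocks the kernel
    // accumulates into C with beta = 1.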
    for (int64_t Bk = 0; Bk < k; Bk += block_k) {
        int64_t size_k = k - Bk;
        bool first_k_block = (Bk == 0);
        bool last_k_block = (size_k <= block_k);
        if (!last_k_block) size_k = block_k;

        for (int64_t Bm = 0; Bm < m; Bm += block_m) {
            int64_t size_m = m - Bm;
            if (size_m > block_m) size_m = block_m;

            auto off_a_packed = Bm * lda_packed + Bk * pd()->unroll_m();
            if (packed_a) off_a_packed += off_a0;

            for (int64_t Bn = 0; Bn < n; Bn += block_n) {
                int64_t size_n = n - Bn;
                if (size_n > block_n) size_n = block_n;

                auto off_b_packed = Bn * ldb_packed + Bk * pd()->unroll_n();
                if (packed_b) off_b_packed += off_b0;

                auto off_c = off_c0 + Bm + Bn * ldc;
                auto off_co = int32_t(off_co0);
                switch (co_kind_) {
                    case 'R': off_co += Bm; break;
                    case 'C': off_co += Bn; break;
                    default: break;
                }

                float this_beta = first_k_block ? beta : 1.0f;
                status = launch_compute(ctx, size_m, size_n, size_k, a_packed,
                        off_a_packed, lda_packed, b_packed, off_b_packed,
                        ldb_packed, c, off_c, ldc, alpha, this_beta, ao, bo,
                        *co, off_co, first_k_block, last_k_block, batch,
                        stride_a, stride_b, stride_c);
                if (status) return status;
            }
        }
    }

    return status::success;
}

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s