1 /*******************************************************************************
2 * Copyright 2020-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "cpu/x64/brgemm/brgemm.hpp"
18
19 #include "common/c_types_map.hpp"
20 #include "common/nstl.hpp"
21 #include "common/type_helpers.hpp"
22 #include "common/utils.hpp"
23
24 #include "cpu/platform.hpp"
25 #include "cpu/x64/brgemm/jit_brdgmm_kernel.hpp"
26 #include "cpu/x64/cpu_barrier.hpp"
27 #include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
28
29 namespace dnnl {
30 namespace impl {
31 namespace cpu {
32 namespace x64 {
33
34 using namespace dnnl::impl::status;
35 using namespace dnnl::impl::utils;
36
37 using namespace prop_kind;
38 using namespace data_type;
39
// Identifiers for the AMX tile-decomposition strategies tried, in this
// order, by brgemm_blocking() below. The starting value (101) is arbitrary;
// the values are used only as loop bounds / switch labels.
enum {
    decomposition_2x2 = 101,
    decomposition_3x1_3,
    decomposition_3x1_2,
    // NOTE(review): typo for "not_defined"; kept as-is because the name is
    // referenced by the decomposition loop below.
    not_definded,
};
46
brgemm_kernel_execute(const brgemm_kernel_t * brg_kernel,int bs,const brgemm_batch_element_t * batch,void * ptr_C,void * scratch)47 void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
48 const brgemm_batch_element_t *batch, void *ptr_C, void *scratch) {
49 brgemm_kernel_params_t brgemm_p;
50
51 brgemm_p.batch = batch;
52 brgemm_p.ptr_A = nullptr;
53 brgemm_p.ptr_B = nullptr;
54 brgemm_p.ptr_C = ptr_C;
55 brgemm_p.ptr_D = ptr_C;
56 brgemm_p.ptr_buf = scratch;
57 brgemm_p.ptr_bias = nullptr;
58 brgemm_p.do_post_ops = 0;
59 brgemm_p.skip_accm = 0;
60 brgemm_p.BS = bs;
61
62 assert(brg_kernel);
63
64 (*brg_kernel)(&brgemm_p);
65 }
66
brgemm_kernel_execute(const brgemm_kernel_t * brg_kernel,int bs,const void * addr_A,const void * addr_B,const brgemm_batch_element_t * batch,void * ptr_C,void * scratch)67 void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
68 const void *addr_A, const void *addr_B,
69 const brgemm_batch_element_t *batch, void *ptr_C, void *scratch) {
70 brgemm_kernel_params_t brgemm_p;
71
72 brgemm_p.batch = batch;
73 brgemm_p.ptr_A = addr_A;
74 brgemm_p.ptr_B = addr_B;
75 brgemm_p.ptr_C = ptr_C;
76 brgemm_p.ptr_D = ptr_C;
77 brgemm_p.ptr_buf = scratch;
78 brgemm_p.ptr_bias = nullptr;
79 brgemm_p.do_post_ops = 0;
80 brgemm_p.skip_accm = 0;
81 brgemm_p.BS = bs;
82 (*brg_kernel)(&brgemm_p);
83 }
84
brgemm_kernel_execute_postops(const brgemm_kernel_t * brg_kernel,int bs,const brgemm_batch_element_t * batch,void * ptr_C,void * ptr_D,const brgemm_post_ops_data_t & post_ops_data,void * scratch)85 void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
86 const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D,
87 const brgemm_post_ops_data_t &post_ops_data, void *scratch) {
88 brgemm_kernel_params_t brgemm_p;
89
90 brgemm_p.batch = batch;
91 brgemm_p.ptr_A = nullptr;
92 brgemm_p.ptr_B = nullptr;
93 brgemm_p.ptr_C = ptr_C;
94 brgemm_p.ptr_D = ptr_D;
95 brgemm_p.ptr_buf = scratch;
96 brgemm_p.ptr_bias = post_ops_data.bias;
97 brgemm_p.ptr_scales = post_ops_data.scales;
98 brgemm_p.do_post_ops = 1;
99 brgemm_p.skip_accm = post_ops_data.skip_accumulation ? 1 : 0;
100 brgemm_p.BS = bs;
101 brgemm_p.post_ops_binary_rhs_arg_vec = post_ops_data.binary_post_ops_rhs;
102 brgemm_p.oc_logical_off = post_ops_data.oc_logical_off;
103 brgemm_p.dst_row_logical_off = post_ops_data.dst_row_logical_off;
104 brgemm_p.data_C_ptr_ = post_ops_data.data_C_ptr_;
105 brgemm_p.first_mb_matrix_addr_off = post_ops_data.first_mb_matrix_addr_off;
106 brgemm_p.a_zp_compensations = post_ops_data.a_zp_compensations;
107 brgemm_p.b_zp_compensations = post_ops_data.b_zp_compensations;
108 brgemm_p.c_zp_values = post_ops_data.c_zp_values;
109 (*brg_kernel)(&brgemm_p);
110 }
111
brgemm_kernel_execute_postops(const brgemm_kernel_t * brg_kernel,int bs,const void * addr_A,const void * addr_B,const brgemm_batch_element_t * batch,void * ptr_C,void * ptr_D,const brgemm_post_ops_data_t & post_ops_data,void * scratch)112 void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
113 const void *addr_A, const void *addr_B,
114 const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D,
115 const brgemm_post_ops_data_t &post_ops_data, void *scratch) {
116 brgemm_kernel_params_t brgemm_p;
117
118 brgemm_p.batch = batch;
119 brgemm_p.ptr_A = addr_A;
120 brgemm_p.ptr_B = addr_B;
121 brgemm_p.ptr_C = ptr_C;
122 brgemm_p.ptr_D = ptr_D;
123 brgemm_p.ptr_buf = scratch;
124 brgemm_p.ptr_bias = post_ops_data.bias;
125 brgemm_p.ptr_scales = post_ops_data.scales;
126 brgemm_p.do_post_ops = 1;
127 brgemm_p.skip_accm = post_ops_data.skip_accumulation ? 1 : 0;
128 brgemm_p.BS = bs;
129 brgemm_p.post_ops_binary_rhs_arg_vec = post_ops_data.binary_post_ops_rhs;
130 brgemm_p.oc_logical_off = post_ops_data.oc_logical_off;
131 brgemm_p.data_C_ptr_ = post_ops_data.data_C_ptr_;
132 brgemm_p.dst_row_logical_off = post_ops_data.dst_row_logical_off;
133 brgemm_p.first_mb_matrix_addr_off = post_ops_data.first_mb_matrix_addr_off;
134
135 (*brg_kernel)(&brgemm_p);
136 }
137
namespace {
// Computes all blocking parameters of the brgemm descriptor:
// bd_* fields block the bcast (M) dimension, ld_* the load (N) dimension,
// rd_* the reduce (K) dimension; *_tail are the remainders and *2 fields
// are second-level (block-of-blocks) factors. Non-AMX and AMX targets use
// different strategies. Returns status::unimplemented when the reduce
// dimension cannot be handled by the current AMX implementation.
status_t brgemm_blocking(brgemm_t *brg) {
    if (!brg->is_int8_amx && !brg->is_bf16_amx) {
        // --- Non-AMX (plain avx512) register blocking ---
        // N is processed in zmm-wide chunks of 16 elements.
        brg->ld_block = 16;
        brg->ldb = brg->load_dim / brg->ld_block;
        brg->ldb_tail = brg->load_dim % brg->ld_block;

        brg->ld_block2 = 4; // (M < 9) ? 2 : 4 | TODO - fix this for INT8
        brg->ldb2 = brg->ldb / brg->ld_block2;
        brg->ldb2_tail = brg->ldb % brg->ld_block2;

        // Not even one full group of N blocks: shrink the group to the tail.
        if (brg->ldb2 == 0) brg->ld_block2 = nstl::max(1, brg->ldb2_tail);
        // Embedded broadcast is only used for f32 (not int8/bf16) with at
        // most a single N block group.
        brg->embd_bcst = !brg->is_int8 && !brg->is_bf16
                && (brg->ldb2_tail <= 1 && brg->ldb2 == 0);

        int ld_block = (brg->ldb2 != 0) ? brg->ld_block2 : brg->ldb2_tail;
        int adj_ld_block = (ld_block == 0) ? (ld_block + 1) : ld_block;

        // zmm register budget for the M direction, after reserving one
        // accumulator row per N block and one broadcast register.
        const int max_avx512_regs = 32;
        const int max_bcst_regs = 1;
        int max_regs = max_avx512_regs - (adj_ld_block + max_bcst_regs);
        int max_block
                = (brg->embd_bcst ? 28
                                  : ((brg->beta == 1.f || brg->beta == 0.f)
                                                  ? max_regs
                                                  : max_regs - 1));
        // s8s8 compensation reserves one more register.
        max_block -= brg->req_s8s8_compensation;
        max_block /= adj_ld_block;
        int min_block = 1;
        float best_bd_block_eff = 0.f;
        brg->bd_block = 1;
        // Pick bd_block maximizing a combined metric: bd_block_disb
        // penalizes M padding (rounding M up to a bd_block multiple),
        // brgemm_microkernel_eff favors register utilization. The A panel
        // for one block must also fit the per-core L1 cache.
        for (int bd_block = max_block; bd_block >= min_block; bd_block--) {
            const auto bd_block_disb = static_cast<float>(brg->bcast_dim)
                    / rnd_up(brg->bcast_dim, bd_block);
            const auto brgemm_microkernel_eff
                    = (static_cast<float>(adj_ld_block) * bd_block)
                    / (((adj_ld_block) + bd_block) * max_block);
            const auto bd_block_eff = bd_block_disb * brgemm_microkernel_eff;

            float block_foot_print = static_cast<float>(brg->typesize_A)
                    * (bd_block * brg->reduce_dim);
            if (block_foot_print <= static_cast<float>(
                        platform::get_per_core_cache_size(1))
                    && (bd_block_eff > best_bd_block_eff)) {
                brg->bd_block = bd_block;
                best_bd_block_eff = bd_block_eff;
            }
        }
        brg->bdb = brg->bcast_dim / brg->bd_block;
        brg->bdb_tail = brg->bcast_dim % brg->bd_block;

        // K is blocked so one step loads 16 bytes of A per row.
        brg->rd_block = 16 / brg->typesize_A;
        brg->rdb = brg->reduce_dim / brg->rd_block;
        brg->rdb_tail = brg->reduce_dim % brg->rd_block;

        brg->is_M_tail = false;
    } else {
        // Blocking configuration for AMX
        const int max_width = 16, min_width = 1;
        brg->ld_block = 16;
        brg->ldb = brg->load_dim / brg->ld_block;
        brg->ldb_tail = brg->load_dim % brg->ld_block;

        // With a level-2 bd_mask, choose the bd_block that minimizes the
        // number of blocks (bdb) while never letting a block run past
        // bcast_dim; returns false when no such mask applies.
        auto find_bd_block_for_bd_mask = [&]() {
            const auto bd_mask_size = brg->bcast_dim;
            if (brg->brgattr.bd_mask_level != 2 || bd_mask_size == 0)
                return false;

            const auto sm_buffer = brg->brgattr.bd_mask;
            auto min_bdb = INT_MAX;
            const auto start_bd_block = nstl::min(max_width, brg->bcast_dim);
            auto best_bd_block = start_bd_block;
            for (auto bd_block = start_bd_block; bd_block > 0; bd_block--) {
                auto bdb = 0;
                for (int i = 0; i < bd_mask_size;) {
                    if (brg->brgattr.bd_mask_level == 2 && sm_buffer[i] == 0) {
                        i++;
                    } else {
                        i += bd_block;
                        if (i > brg->bcast_dim) {
                            // bcast_dim not divided by bd_block
                            bdb = INT_MAX;
                        } else
                            bdb++;
                    }
                }
                if (bdb < min_bdb) {
                    min_bdb = bdb;
                    best_bd_block = bd_block;
                }
            }
            brg->bd_block = best_bd_block;
            brg->bdb_tail = 0;
            brg->bdb = min_bdb;
            return true;
        };

        // Chooses ld_block2 (number of N tiles handled together) given the
        // already chosen bd_block2, derives the second-level N counts, and
        // may grow bd_block2 again when only one N tile is used.
        auto set_decomposition_by_ld = [&]() {
            if (brg->bd_block2 == 1 && brg->ldb > 0 && brg->ldb_tail == 0) {
                if (brg->ldb % 3 == 0)
                    brg->ld_block2 = 3;
                else if (brg->ldb % 2 == 0)
                    brg->ld_block2 = 2;
                else
                    brg->ld_block2 = 1;
            } else {
                brg->ld_block2
                        = (brg->ldb > 0 && brg->ldb % 2 == 0
                                  && brg->ldb_tail == 0 && brg->bd_block2 < 3)
                        ? 2
                        : 1;
            }
            brg->ldb2 = brg->ldb / brg->ld_block2;
            brg->ldb2_tail = brg->ldb % brg->ld_block2;

            // Re-adjust the bd_block2 if possible
            if (brg->ld_block2 == 1 && !brg->is_M_tail && brg->ldb_tail == 0) {
                brg->bd_block2 = (brg->bdb >= 3) ? 3 : (brg->bdb >= 2) ? 2 : 1;
                brg->bdb2 = brg->bdb / brg->bd_block2;
                brg->bdb2_tail = (brg->bd_block2 == 1)
                        ? brg->bdb
                        : brg->bdb % brg->bd_block2;
            }
        };

        // width_step M tiles by 1 N tile; applicable only when M falls
        // strictly between (width_step-1) and width_step full tile widths.
        auto try_3x1_decomposition = [&](int width_step) {
            brg->is_M_tail = false;
            if (brg->bcast_dim > (width_step - 1) * max_width
                    && brg->bcast_dim < width_step * max_width
                    && brg->ldb_tail == 0) {
                if (!find_bd_block_for_bd_mask()) {
                    brg->bd_block = max_width;
                    brg->bdb = div_up(brg->bcast_dim, brg->bd_block);
                    brg->bdb_tail = brg->bcast_dim % brg->bd_block;
                    brg->is_M_tail = true;
                }
                brg->bd_block2 = width_step;
                brg->bdb2 = brg->bdb / brg->bd_block2;
                brg->bdb2_tail = brg->bdb % brg->bd_block2;
                set_decomposition_by_ld();
                return true;
            }
            return false;
        };

        // Default decomposition: up to 2 M tiles by up to 2 N tiles.
        // Succeeds only when it yields a reasonably large square-ish shape.
        auto try_2x2_decomposition = [&]() {
            if (!find_bd_block_for_bd_mask()) {
                // Prefer the largest bd_block that divides M exactly ...
                for (int m_block = max_width; m_block >= min_width; m_block--) {
                    if (brg->bcast_dim % m_block == 0) {
                        brg->bd_block = m_block;
                        break;
                    }
                }
                // ... otherwise take the bd_block with the largest tail.
                if (brg->bd_block == 1) {
                    brg->bd_block = nstl::min(max_width, brg->bcast_dim);
                    brg->bdb_tail = brg->bcast_dim % max_width;
                    for (int i = max_width; i >= min_width; i--) {
                        int i_tail = brg->bcast_dim % i;
                        if (i_tail > brg->bdb_tail || i_tail == 0) {
                            brg->bd_block = i;
                            brg->bdb_tail = i_tail;
                            if (i_tail == 0) break;
                        }
                    }
                }
                brg->bdb = brg->bcast_dim / brg->bd_block;
                brg->bdb_tail = brg->bcast_dim % brg->bd_block;
            }

            brg->bd_block2 = (brg->bdb >= 2) ? 2 : 1;
            brg->bdb2 = brg->bdb / brg->bd_block2;
            brg->bdb2_tail = (brg->bd_block2 == 1) ? brg->bdb
                                                   : brg->bdb % brg->bd_block2;

            brg->is_M_tail = false;

            set_decomposition_by_ld();

            return !(brg->ld_block2 == 1 || brg->bd_block2 == 1
                    || brg->bd_block < 8);
        };

        // Try the decompositions in order of preference; fall back to 2x2
        // unconditionally if none declared itself applicable.
        bool is_decomposition_defined = false;
        for (int i = decomposition_2x2; i != not_definded; i++) {
            switch (i) {
                case decomposition_2x2:
                    is_decomposition_defined = try_2x2_decomposition();
                    break;
                case decomposition_3x1_3:
                    is_decomposition_defined = try_3x1_decomposition(3);
                    break;
                case decomposition_3x1_2:
                    is_decomposition_defined = try_3x1_decomposition(2);
                    break;
                default: assert(!"invalid value"); break;
            };
            if (is_decomposition_defined) break;
        }
        if (!is_decomposition_defined) try_2x2_decomposition();

        // K block: 64 bytes of A (32 bf16 or 64 int8 elements) per tile.
        brg->rd_block = brg->is_bf16_amx ? 32 : 64;
        brg->rdb = brg->reduce_dim / brg->rd_block;
        brg->rdb_tail = brg->reduce_dim % brg->rd_block;

        // Remove these guard in the future (add tail processing by reduction dimension)
        if (brg->rdb > 0 && brg->rdb_tail) return status::unimplemented;
        if (brg->rdb_tail % ((brg->is_bf16_amx) ? 2 : 4))
            return status::unimplemented;
    }

    return status::success;
}
} // namespace
351
// Initializes a brgemm descriptor: validates the problem, resolves data
// types and ISA flags, normalizes row-/column-major layouts into the
// internal row-major view (bcast/load/reduce dims), and computes blocking.
// Returns invalid_arguments on bad inputs, unimplemented for unsupported
// type/ISA combinations.
status_t brgemm_desc_init(brgemm_t *brg, cpu_isa_t isa,
        brgemm_batch_kind_t type, impl::data_type_t dt_a,
        impl::data_type_t dt_b, bool transA, bool transB,
        brgemm_layout_t layout, float alpha, float beta, dim_t LDA, dim_t LDB,
        dim_t LDC, dim_t M, dim_t N, dim_t K, const brgemm_strides_t *strides) {
    /*
    m - number of rows of the matrix op(A) and number of rows of the matrix C
    n - number of columns of the matrix op(B) and number of columns of the matrix C
    k - number of columns of the matrix op(A) and number of rows of the matrix op(B)

    Matrices are in row-major layouts:
        A: lda * m, LDA - lda must be at least max(1, k)
        B: ldb * k, LDB - ldb must be at least max(1, n)
        C: ldc * m, LDC - ldc must be at least max(1, n)

    Matrices are in column-major layouts:
        A: lda * k, LDA - lda must be at least max(1, m)
        B: ldb * n, LDB - ldb must be at least max(1, k)
        C: ldc * n, LDC - ldc must be at least max(1, m)
    */
    if (brg == nullptr) return status::invalid_arguments;
    if (transA || transB) return status::unimplemented;

    brg->layout = layout;
    auto is_row_major = [&]() { return brg->layout == brgemm_row_major; };
    if (M <= 0 || N <= 0 || K <= 0) return status::invalid_arguments;
    // Leading dimensions must cover one full logical row for the layout.
    bool ldx_check = (is_row_major()) ? (LDA < K || LDB < N || LDC < N)
                                      : (LDA < M || LDB < K || LDC < M);
    if (ldx_check) return status::invalid_arguments;

    // Column-major problems are handled by swapping the roles of A and B.
    brg->dt_a = (is_row_major()) ? dt_a : dt_b;
    brg->dt_b = (is_row_major()) ? dt_b : dt_a;

    // Supported type combinations: u8/s8 x s8, bf16 x bf16, f32 x f32.
    brg->is_int8 = (one_of(brg->dt_a, data_type::u8, data_type::s8)
            && brg->dt_b == data_type::s8);
    brg->is_bf16
            = (brg->dt_a == data_type::bf16 && brg->dt_b == data_type::bf16);
    brg->is_f32 = (brg->dt_a == data_type::f32 && brg->dt_b == data_type::f32);
    if (!brg->is_int8 && !brg->is_bf16 && !brg->is_f32)
        return status::unimplemented;
    // Accumulator type; D/bias default to it until set_postops overrides.
    brg->dt_c = (brg->is_int8) ? data_type::s32 : data_type::f32;
    brg->dt_d = brg->dt_c;
    brg->dt_bias = brg->dt_c;

    // Each type combination needs its minimum ISA on this machine.
    if (!IMPLICATION(brg->is_f32, mayiuse(avx512_core)))
        return status::unimplemented;
    if (!IMPLICATION(brg->is_bf16, mayiuse(avx512_core_bf16)))
        return status::unimplemented;
    if (!IMPLICATION(brg->is_int8, mayiuse(avx512_core_vnni)))
        return status::unimplemented;

    if (isa != isa_any) {
        // Caller pinned an ISA: validate it and enable AMX only when both
        // requested and available.
        if (!one_of(isa, avx512_core, avx512_core_bf16, avx512_core_vnni,
                    avx512_core_bf16_amx_bf16, avx512_core_bf16_amx_int8)) {
            return status::invalid_arguments;
        }
        brg->is_int8_amx = brg->is_bf16_amx = false;
        if (brg->is_int8 && isa == avx512_core_bf16_amx_int8) {
            if (!mayiuse(avx512_core_bf16_amx_int8))
                return status::invalid_arguments;
            brg->is_int8_amx = true;
        }
        if (brg->is_bf16 && isa == avx512_core_bf16_amx_bf16) {
            if (!mayiuse(avx512_core_bf16_amx_bf16))
                return status::invalid_arguments;
            brg->is_bf16_amx = true;
        }
    } else {
        // ISA auto-detection: use AMX whenever the CPU supports it.
        brg->is_int8_amx = brg->is_int8 && mayiuse(avx512_core_bf16_amx_int8);
        brg->is_bf16_amx = brg->is_bf16 && mayiuse(avx512_core_bf16_amx_bf16);
    }
    brg->is_amx = (brg->is_int8_amx || brg->is_bf16_amx);
    // Non-AMX s8s8 kernels need compensation (no native s8*s8 instruction).
    brg->req_s8s8_compensation
            = brg->is_int8 && !brg->is_int8_amx && brg->dt_a == data_type::s8;
    // Leading dims follow the A/B swap for column-major problems.
    brg->LDA = (is_row_major()) ? static_cast<int>(LDA) : static_cast<int>(LDB);
    brg->LDB = (is_row_major()) ? static_cast<int>(LDB) : static_cast<int>(LDA);

    brg->LDC = static_cast<int>(LDC);
    brg->LDD = static_cast<int>(LDC);

    // Internal dims: bcast=M, load=N (swapped for column-major), reduce=K.
    brg->bcast_dim
            = (is_row_major()) ? static_cast<int>(M) : static_cast<int>(N);
    brg->load_dim
            = (is_row_major()) ? static_cast<int>(N) : static_cast<int>(M);
    brg->reduce_dim = static_cast<int>(K);

    // Post-ops are disabled until brgemm_desc_set_postops() is called.
    brg->with_bias = false;
    brg->with_eltwise = false;
    brg->with_sum = false;
    brg->sum_scale = 0;
    brg->sum_zp = 0;
    brg->with_scales = false;

    brg->beta = beta;
    brg->alpha = alpha;

    brg->typesize_A = types::data_type_size(brg->dt_a);
    brg->typesize_B = types::data_type_size(brg->dt_b);
    brg->typesize_C = types::data_type_size(brg->dt_c);
    brg->typesize_D = types::data_type_size(brg->dt_d);
    brg->type = type;

    brg->bd_block2 = 0;
    brg->bdb2 = 0;
    brg->bdb2_tail = 0;

    // Number of elements per 4-byte group in the packed B layout.
    brg->ld_step = brg->rd_step = 4 / brg->typesize_A;

    if (strides != nullptr) {
        brg->stride_a = strides->stride_a;
        brg->stride_b = strides->stride_b;
    } else {
        brg->stride_a = brg->stride_b = 0;
    }

    CHECK(brgemm_blocking(brg));

    return status::success;
}
471
// Initializes a descriptor for the batch-reduce *diagonal* gemm (brdgmm)
// kernel: an elementwise-style M x N product with no reduce dimension.
// Only row-major, non-transposed A, alpha == 1 and beta == 0 are supported.
status_t brdgmm_desc_init(brgemm_t *brg, cpu_isa_t isa,
        brgemm_batch_kind_t type, impl::data_type_t dt_a,
        impl::data_type_t dt_b, bool transA, brgemm_layout_t layout,
        float alpha, float beta, dim_t LDA, dim_t LDC, dim_t M, dim_t N,
        const brgemm_strides_t *strides) {

    if (brg == nullptr) return status::invalid_arguments;
    if (transA || layout != brgemm_row_major || alpha != 1.0f || beta != 0.f)
        return status::unimplemented;

    // Both A and C rows span N elements.
    const bool ldx_check = (LDA < N || LDC < N);
    if (ldx_check) return status::invalid_arguments;

    brg->dt_a = dt_a;
    brg->dt_b = dt_b;

    // Supported type combinations: u8/s8 x s8, bf16 x bf16, f32 x f32.
    brg->is_int8 = one_of(brg->dt_a, data_type::u8, data_type::s8)
            && (brg->dt_b == data_type::s8);
    brg->is_bf16
            = (brg->dt_a == data_type::bf16) && (brg->dt_b == data_type::bf16);
    brg->is_f32
            = (brg->dt_a == data_type::f32) && (brg->dt_b == data_type::f32);
    if (!brg->is_int8 && !brg->is_bf16 && !brg->is_f32)
        return status::unimplemented;
    brg->dt_c = (brg->is_int8) ? data_type::s32 : data_type::f32;
    brg->dt_d = brg->dt_c;
    brg->dt_bias = brg->dt_c;

    // The requested ISA must cover the minimum for the data types and be
    // usable on this machine.
    const cpu_isa_t req_isa = brg->is_f32
            ? avx512_core
            : (brg->is_int8 ? avx512_core_vnni : avx512_core_bf16);
    if (!(is_superset(isa, req_isa) && mayiuse(req_isa)))
        return status::unimplemented;

    brg->is_bf16_amx = brg->is_bf16 && mayiuse(avx512_core_bf16_amx_bf16);
    brg->is_dgmm = true;
    brg->type = type;
    brg->layout = layout;
    brg->alpha = alpha;
    brg->beta = beta;

    brg->LDA = static_cast<int>(LDA);
    brg->LDC = static_cast<int>(LDC);
    brg->LDD = static_cast<int>(LDC);

    brg->typesize_A = types::data_type_size(brg->dt_a);
    brg->typesize_B = types::data_type_size(brg->dt_b);
    brg->typesize_C = types::data_type_size(brg->dt_c);
    brg->typesize_D = types::data_type_size(brg->dt_d);

    // In current implementation of dgmm, there is no reduce dim.
    // Also, bcast and load dimensions refer to M and N.

    // Aliases mapping dgmm blocking terminology onto the generic brgemm
    // blocking fields (bd_* = M direction, ld_* = N direction).
    // auto &M = brg->bcast_dim;
    // auto &N = brg->load_dim;
    auto &m_vlen_blk = brg->bd_block;
    auto &nb_m_vlen_blk = brg->bdb;
    auto &m_vlen_tail = brg->bdb_tail;
    auto &m_blocking = brg->bd_block2;
    auto &nb_m_blocking = brg->bdb2;
    auto &m_blocking_tail = brg->bdb2_tail;

    auto &n_vlen_blk = brg->ld_block;
    auto &nb_n_vlen_blk = brg->ldb;
    auto &n_vlen_tail = brg->ldb_tail;
    auto &n_blocking = brg->ld_block2;
    auto &nb_n_blocking = brg->ldb2;
    auto &n_blocking_tail = brg->ldb2_tail;

    brg->bcast_dim = M;
    brg->load_dim = N;
    const int simd_w = 16;

    // begin blocking
    // N is split into zmm-wide (16-element) pieces, grouped up to 4 wide.
    n_vlen_blk = simd_w;
    nb_n_vlen_blk = div_up(N, n_vlen_blk);
    n_vlen_tail = N % n_vlen_blk;
    n_blocking = nstl::min(4, nb_n_vlen_blk);
    nb_n_blocking = div_up(nb_n_vlen_blk, n_blocking);
    n_blocking_tail = nb_n_vlen_blk % n_blocking;

    // M rows are processed one at a time; the M group size is bounded by
    // the accumulator zmm budget divided by the N group width.
    const int max_acc_zmms = 32 - 2 /*zmma, zmmb, post-ops, saturation*/
            - jit_brdgmm_kernel_base_t::is_fast_vnni_int8(*brg) /*perm dst*/;
    m_vlen_blk = 1;
    nb_m_vlen_blk = M / m_vlen_blk;
    m_vlen_tail = M % m_vlen_blk;
    m_blocking = nstl::min(nb_m_vlen_blk, max_acc_zmms / n_blocking);
    nb_m_blocking = div_up(nb_m_vlen_blk, m_blocking);
    m_blocking_tail = nb_m_vlen_blk % m_blocking;

    if (strides != nullptr) {
        brg->stride_a = strides->stride_a;
        brg->stride_b = strides->stride_b;
    } else {
        brg->stride_a = brg->stride_b = 0;
    }

    return status::success;
}
571
brgemm_desc_set_postops(brgemm_t * brg,const primitive_attr_t * attr,const memory_desc_t * dst_md,int LDD,impl::data_type_t dt_bias)572 status_t brgemm_desc_set_postops(brgemm_t *brg, const primitive_attr_t *attr,
573 const memory_desc_t *dst_md, int LDD, impl::data_type_t dt_bias) {
574 if (!brg || !dst_md) return status::invalid_arguments;
575
576 brg->attr = attr;
577 brg->dst_md = dst_md;
578
579 brg->with_bias = (dt_bias == data_type::undef) ? false : true;
580 brg->dt_bias = dt_bias;
581 brg->typesize_bias = (dt_bias == data_type::undef)
582 ? 0
583 : types::data_type_size(brg->dt_bias);
584
585 brg->LDD = LDD;
586 const auto dt_d = dst_md->data_type;
587
588 if ((brg->dt_a == data_type::u8 && brg->dt_b == data_type::s8)
589 && (!one_of(dt_d, data_type::u8, data_type::s8, data_type::s32,
590 data_type::f32))
591 && (!one_of(dt_bias, data_type::undef, data_type::u8, data_type::s8,
592 data_type::s32, data_type::f32, data_type::bf16)))
593 return status::unimplemented;
594 if ((brg->dt_a == data_type::bf16 && brg->dt_b == data_type::bf16)
595 && (!one_of(dt_d, data_type::bf16, data_type::f32))
596 && (!one_of(dt_bias, data_type::undef, data_type::bf16,
597 data_type::f32)))
598 return status::unimplemented;
599 if ((brg->dt_a == data_type::f32 && brg->dt_b == data_type::f32)
600 && (!one_of(dt_d, data_type::f32))
601 && (!one_of(dt_bias, data_type::undef, data_type::f32)))
602 return status::unimplemented;
603
604 brg->dt_d = dt_d;
605 brg->typesize_D = types::data_type_size(brg->dt_d);
606
607 if (!IMPLICATION(
608 brg->is_int8 && brg->dt_d == bf16, mayiuse(avx512_core_bf16)))
609 return status::unimplemented;
610
611 if (!brg->attr) return status::success;
612
613 using namespace injector;
614
615 const auto &post_ops = brg->attr->post_ops_;
616 const memory_desc_wrapper dst_d(dst_md);
617
618 const int binary_ind = post_ops.find(primitive_kind::binary);
619 brg->with_binary = binary_ind != -1;
620 const cpu_isa_t isa = get_max_cpu_isa();
621
622 if ((brg->with_binary && !dst_md)
623 || !injector::post_ops_ok(
624 post_ops_ok_args_t(isa, {sum, eltwise, binary}, post_ops,
625 &dst_d, false /*sum_at_pos_0_only*/,
626 false /*sum_requires_scale_one*/,
627 false /*sum_requires_zp_zero*/,
628 {broadcasting_strategy_t::per_oc,
629 broadcasting_strategy_t::scalar,
630 broadcasting_strategy_t::per_mb_spatial,
631 broadcasting_strategy_t::per_mb_w,
632 broadcasting_strategy_t::no_broadcast})))
633 return status::unimplemented;
634
635 const int sum_idx = post_ops.find(primitive_kind::sum);
636 const bool with_sum = sum_idx != -1;
637 brg->with_sum = with_sum;
638 brg->sum_scale = with_sum ? post_ops.entry_[sum_idx].sum.scale : 0;
639 brg->sum_zp = with_sum ? post_ops.entry_[sum_idx].sum.zero_point : 0;
640 const auto sum_dt
641 = with_sum ? post_ops.entry_[sum_idx].sum.dt : data_type::undef;
642 brg->sum_dt = sum_dt != data_type::undef ? sum_dt : dt_d;
643
644 const int eltwise_ind = post_ops.find(primitive_kind::eltwise);
645 brg->with_eltwise = eltwise_ind != -1;
646
647 brg->with_scales = !attr->output_scales_.has_default_values();
648 if (brg->with_scales) {
649 const auto &oscales = brg->attr->output_scales_;
650 // Note. the current version supports only two different output scale
651 // types:
652 // 1) common (mask_ = 0)
653 // 2) per_n_dim_scale - broadcast across n dimension;
654 // for convolution and inner product promitives it corresponds
655 // to "per_oc" mask_ = 1 << 1; for matmul - to
656 // mask_ = (1 << (ndims - 1))), where ndims is number of
657 // dimensions for original matmul problem
658 // So if oscales.mask_ != 0 (not common) it's assumed here that scale
659 // type is per_n_dim_scale and driver which calls brgemm kernel checked
660 // that mask has correct value for this case
661 brg->is_oc_scale = oscales.mask_ != 0;
662 }
663
664 auto init_zp_type
665 = [&](brgemm_broadcast_t &zp_type, int mem_arg) -> status_t {
666 auto zero_points = attr->zero_points_;
667
668 // common zero point type is supported for now
669 if (!zero_points.common(mem_arg)) return status::unimplemented;
670
671 zp_type = zero_points.has_default_values(mem_arg)
672 ? brgemm_broadcast_t::none
673 : brgemm_broadcast_t::per_tensor;
674 return status::success;
675 };
676
677 init_zp_type(brg->zp_type_a, DNNL_ARG_SRC);
678 init_zp_type(brg->zp_type_b, DNNL_ARG_WEIGHTS);
679 init_zp_type(brg->zp_type_c, DNNL_ARG_DST);
680
681 return status::success;
682 }
683
brgemm_desc_set_attr(brgemm_t * brg,const brgemm_attr_t & brgattr)684 status_t brgemm_desc_set_attr(brgemm_t *brg, const brgemm_attr_t &brgattr) {
685 if (brg == nullptr) return status::invalid_arguments;
686
687 // negative padding is not supported
688 if (brgattr.max_top_vpad < 0 || brgattr.max_bottom_vpad < 0)
689 return status::unimplemented;
690
691 // virtual padding is not supported for "amx"
692 if ((brgattr.max_top_vpad > 0 || brgattr.max_bottom_vpad > 0)
693 && (brg->is_amx))
694 return status::unimplemented;
695
696 if (!brg->is_dgmm) {
697 // virtual padding size is restricted by MAX_VPAD value
698 if (brgattr.max_top_vpad > brgemm_t::MAX_VPAD
699 || brgattr.max_bottom_vpad > brgemm_t::MAX_VPAD)
700 return status::unimplemented;
701
702 // virtual padding is restricted by bd_block size due to
703 // brgemm_kernel implementation. TODO: remove this restriction
704 if (brgattr.max_top_vpad > brg->bd_block
705 || brgattr.max_bottom_vpad > brg->bd_block)
706 return status::unimplemented;
707 }
708
709 // virtual padding is supported for "brgemm_row_major" layout
710 // TODO: remove this restriction
711 if ((brgattr.max_top_vpad > 0 || brgattr.max_bottom_vpad > 0)
712 && brg->layout != brgemm_row_major)
713 return status::unimplemented;
714
715 brg->brgattr = brgattr;
716
717 if (brgattr.bd_mask_level) brgemm_blocking(brg);
718
719 return status::success;
720 }
721
brgemm_kernel_create(brgemm_kernel_t ** brg_kernel,const brgemm_t & brg)722 status_t brgemm_kernel_create(
723 brgemm_kernel_t **brg_kernel, const brgemm_t &brg) {
724 if (brg.is_dgmm) {
725 CHECK(safe_ptr_assign<brgemm_kernel_t>(
726 *brg_kernel, new brdgmm_kernel_t(brg)));
727 return (*brg_kernel)->create_kernel();
728 } else if (brg.is_amx && brg.type == brgemm_addr && brg.brgattr.max_bs >= 1
729 && brg.brgattr.use_uker) {
730 CHECK(safe_ptr_assign<brgemm_kernel_t>(
731 *brg_kernel, new brgemm_amx_uker_t(brg)));
732 return (*brg_kernel)->create_kernel();
733 } else {
734 CHECK(safe_ptr_assign<brgemm_kernel_t>(
735 *brg_kernel, new brgemm_kernel_common_t(brg)));
736 return (*brg_kernel)->create_kernel();
737 }
738 }
739
brgemm_kernel_destroy(brgemm_kernel_t * brg_kernel)740 void brgemm_kernel_destroy(brgemm_kernel_t *brg_kernel) {
741 delete brg_kernel;
742 }
743
// Fills a 64-byte AMX tile-configuration palette for the descriptor:
// rows/column-bytes for the A tiles (bd_block2 of them), B tiles
// (ld_block2 plus an optional N-tail tile), and the C accumulator tiles.
// The caller loads the palette with ldtilecfg before running the kernel.
// Returns unimplemented for non-AMX descriptors or unsupported shapes.
status_t brgemm_init_tiles(const brgemm_t &brg, char palette[64]) {
    constexpr int max_palette_size_in_bytes = 64;

    if (!brg.is_amx) return status::unimplemented;

    //TODO: Add support of tail processing by reduction dimension
    int rd_block = (!brg.rdb && brg.rdb_tail) ? brg.rdb_tail : brg.rd_block;

    palette_config_t *buff = (palette_config_t *)(palette);

    // Zero the whole palette before configuring individual tiles.
    char *_tc = (char *)(buff);
    for (int i = 0; i < max_palette_size_in_bytes; i++)
        _tc[i] = 0;

    // Elements per 4-byte group (vnni packing granularity).
    int rd_step = 4 / brg.typesize_A;

    // Column sizes in bytes for A, B (full and N-tail) and C tiles.
    int Ac = brg.typesize_A * rd_block;

    int Bc = brg.ld_block * brg.typesize_B * rd_step;
    int Bc_t = brg.ldb_tail * brg.typesize_B * rd_step;

    int Cc = brg.ld_block * brg.typesize_C;
    int Cc_t = brg.ldb_tail * brg.typesize_C;

    // B tile rows: K elements packed rd_step per row.
    int Br = (brg.typesize_C != 0) ? Ac / brg.typesize_C : 0;

    if (brg.ldb_tail && (brg.ld_block2 > 1)) return status::unimplemented;

    // A tiles: the last one may carry the M tail.
    for (int m = 0; m < brg.bd_block2; m++) {
        int Ar = (brg.is_M_tail && m == brg.bd_block2 - 1) ? brg.bdb_tail
                                                           : brg.bd_block;
        tc_configure_tile(buff, brg.get_A_tensor(m), Ar, Ac);
    }

    // B tiles: one per N block, plus a separate tile for the N tail.
    for (int n = 0; n < brg.ld_block2; n++)
        tc_configure_tile(buff, brg.get_B_tensor(n), Br, Bc);
    if (brg.ldb_tail)
        tc_configure_tile(buff, brg.get_B_tensor(brg.ld_block2), Br, Bc_t);

    // C tiles: one per (M block, N block) pair, mirroring A/B tails.
    for (int m = 0; m < brg.bd_block2; m++) {
        int Cr = (brg.is_M_tail && m == brg.bd_block2 - 1) ? brg.bdb_tail
                                                           : brg.bd_block;
        for (int n = 0; n < brg.ld_block2; n++)
            tc_configure_tile(buff, brg.get_C_tensor(m, n), Cr, Cc);
        if (brg.ldb_tail)
            tc_configure_tile(
                    buff, brg.get_C_tensor(m, brg.ld_block2), Cr, Cc_t);
    }
    buff->palette_id = amx::get_max_palette();

    return status::success;
}
796
797 } // namespace x64
798 } // namespace cpu
799 } // namespace impl
800 } // namespace dnnl
801
802 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
803