1 /*******************************************************************************
2 * Copyright 2020-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include "cpu/x64/brgemm/brgemm.hpp"
18 
19 #include "common/c_types_map.hpp"
20 #include "common/nstl.hpp"
21 #include "common/type_helpers.hpp"
22 #include "common/utils.hpp"
23 
24 #include "cpu/platform.hpp"
25 #include "cpu/x64/brgemm/jit_brdgmm_kernel.hpp"
26 #include "cpu/x64/cpu_barrier.hpp"
27 #include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
28 
29 namespace dnnl {
30 namespace impl {
31 namespace cpu {
32 namespace x64 {
33 
34 using namespace dnnl::impl::status;
35 using namespace dnnl::impl::utils;
36 
37 using namespace prop_kind;
38 using namespace data_type;
39 
// Identifiers for the AMX blocking "decomposition" strategies tried, in this
// order, by brgemm_blocking() below. Values start at 101 (arbitrary non-zero
// base) and are iterated sequentially until one strategy succeeds.
enum {
    decomposition_2x2 = 101,
    decomposition_3x1_3,
    decomposition_3x1_2,
    not_definded, // NOTE(review): typo for "not_defined"; kept as-is since it
                  // is referenced by name elsewhere in this file
};
46 
brgemm_kernel_execute(const brgemm_kernel_t * brg_kernel,int bs,const brgemm_batch_element_t * batch,void * ptr_C,void * scratch)47 void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
48         const brgemm_batch_element_t *batch, void *ptr_C, void *scratch) {
49     brgemm_kernel_params_t brgemm_p;
50 
51     brgemm_p.batch = batch;
52     brgemm_p.ptr_A = nullptr;
53     brgemm_p.ptr_B = nullptr;
54     brgemm_p.ptr_C = ptr_C;
55     brgemm_p.ptr_D = ptr_C;
56     brgemm_p.ptr_buf = scratch;
57     brgemm_p.ptr_bias = nullptr;
58     brgemm_p.do_post_ops = 0;
59     brgemm_p.skip_accm = 0;
60     brgemm_p.BS = bs;
61 
62     assert(brg_kernel);
63 
64     (*brg_kernel)(&brgemm_p);
65 }
66 
brgemm_kernel_execute(const brgemm_kernel_t * brg_kernel,int bs,const void * addr_A,const void * addr_B,const brgemm_batch_element_t * batch,void * ptr_C,void * scratch)67 void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs,
68         const void *addr_A, const void *addr_B,
69         const brgemm_batch_element_t *batch, void *ptr_C, void *scratch) {
70     brgemm_kernel_params_t brgemm_p;
71 
72     brgemm_p.batch = batch;
73     brgemm_p.ptr_A = addr_A;
74     brgemm_p.ptr_B = addr_B;
75     brgemm_p.ptr_C = ptr_C;
76     brgemm_p.ptr_D = ptr_C;
77     brgemm_p.ptr_buf = scratch;
78     brgemm_p.ptr_bias = nullptr;
79     brgemm_p.do_post_ops = 0;
80     brgemm_p.skip_accm = 0;
81     brgemm_p.BS = bs;
82     (*brg_kernel)(&brgemm_p);
83 }
84 
brgemm_kernel_execute_postops(const brgemm_kernel_t * brg_kernel,int bs,const brgemm_batch_element_t * batch,void * ptr_C,void * ptr_D,const brgemm_post_ops_data_t & post_ops_data,void * scratch)85 void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
86         const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D,
87         const brgemm_post_ops_data_t &post_ops_data, void *scratch) {
88     brgemm_kernel_params_t brgemm_p;
89 
90     brgemm_p.batch = batch;
91     brgemm_p.ptr_A = nullptr;
92     brgemm_p.ptr_B = nullptr;
93     brgemm_p.ptr_C = ptr_C;
94     brgemm_p.ptr_D = ptr_D;
95     brgemm_p.ptr_buf = scratch;
96     brgemm_p.ptr_bias = post_ops_data.bias;
97     brgemm_p.ptr_scales = post_ops_data.scales;
98     brgemm_p.do_post_ops = 1;
99     brgemm_p.skip_accm = post_ops_data.skip_accumulation ? 1 : 0;
100     brgemm_p.BS = bs;
101     brgemm_p.post_ops_binary_rhs_arg_vec = post_ops_data.binary_post_ops_rhs;
102     brgemm_p.oc_logical_off = post_ops_data.oc_logical_off;
103     brgemm_p.dst_row_logical_off = post_ops_data.dst_row_logical_off;
104     brgemm_p.data_C_ptr_ = post_ops_data.data_C_ptr_;
105     brgemm_p.first_mb_matrix_addr_off = post_ops_data.first_mb_matrix_addr_off;
106     brgemm_p.a_zp_compensations = post_ops_data.a_zp_compensations;
107     brgemm_p.b_zp_compensations = post_ops_data.b_zp_compensations;
108     brgemm_p.c_zp_values = post_ops_data.c_zp_values;
109     (*brg_kernel)(&brgemm_p);
110 }
111 
brgemm_kernel_execute_postops(const brgemm_kernel_t * brg_kernel,int bs,const void * addr_A,const void * addr_B,const brgemm_batch_element_t * batch,void * ptr_C,void * ptr_D,const brgemm_post_ops_data_t & post_ops_data,void * scratch)112 void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs,
113         const void *addr_A, const void *addr_B,
114         const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D,
115         const brgemm_post_ops_data_t &post_ops_data, void *scratch) {
116     brgemm_kernel_params_t brgemm_p;
117 
118     brgemm_p.batch = batch;
119     brgemm_p.ptr_A = addr_A;
120     brgemm_p.ptr_B = addr_B;
121     brgemm_p.ptr_C = ptr_C;
122     brgemm_p.ptr_D = ptr_D;
123     brgemm_p.ptr_buf = scratch;
124     brgemm_p.ptr_bias = post_ops_data.bias;
125     brgemm_p.ptr_scales = post_ops_data.scales;
126     brgemm_p.do_post_ops = 1;
127     brgemm_p.skip_accm = post_ops_data.skip_accumulation ? 1 : 0;
128     brgemm_p.BS = bs;
129     brgemm_p.post_ops_binary_rhs_arg_vec = post_ops_data.binary_post_ops_rhs;
130     brgemm_p.oc_logical_off = post_ops_data.oc_logical_off;
131     brgemm_p.data_C_ptr_ = post_ops_data.data_C_ptr_;
132     brgemm_p.dst_row_logical_off = post_ops_data.dst_row_logical_off;
133     brgemm_p.first_mb_matrix_addr_off = post_ops_data.first_mb_matrix_addr_off;
134 
135     (*brg_kernel)(&brgemm_p);
136 }
137 
138 namespace {
// Computes the blocking parameters of the descriptor (bd_* blocks along the
// broadcast/M dimension, ld_* blocks along the load/N dimension, rd_* blocks
// along the reduce/K dimension). Two code paths:
//   - non-AMX (AVX-512 vector kernels): pick bd_block by a footprint/
//     efficiency heuristic over the zmm register budget;
//   - AMX: try the decomposition strategies (2x2, 3x1_3, 3x1_2) until one
//     fits, optionally driven by a bd_mask from the attributes.
// Returns status::unimplemented for reduce-dim shapes the kernels cannot
// handle yet; status::success otherwise.
status_t brgemm_blocking(brgemm_t *brg) {
    if (!brg->is_int8_amx && !brg->is_bf16_amx) {
        // --- Vector (non-AMX) path ---
        // One zmm holds 16 elements of C per row; split N into ld_block
        // chunks, then group chunks into ld_block2 columns of accumulators.
        brg->ld_block = 16;
        brg->ldb = brg->load_dim / brg->ld_block;
        brg->ldb_tail = brg->load_dim % brg->ld_block;

        brg->ld_block2 = 4; // (M < 9) ? 2 : 4 | TODO - fix this for INT8
        brg->ldb2 = brg->ldb / brg->ld_block2;
        brg->ldb2_tail = brg->ldb % brg->ld_block2;

        if (brg->ldb2 == 0) brg->ld_block2 = nstl::max(1, brg->ldb2_tail);
        // Embedded broadcast is only considered for f32 with a single narrow
        // column of accumulators.
        brg->embd_bcst = !brg->is_int8 && !brg->is_bf16
                && (brg->ldb2_tail <= 1 && brg->ldb2 == 0);

        int ld_block = (brg->ldb2 != 0) ? brg->ld_block2 : brg->ldb2_tail;
        int adj_ld_block = (ld_block == 0) ? (ld_block + 1) : ld_block;

        // Register budget: 32 zmm registers minus B-load registers
        // (adj_ld_block) and one broadcast register; one more is reserved
        // when beta requires a separate accumulator load/blend, and one for
        // s8s8 compensation if needed.
        const int max_avx512_regs = 32;
        const int max_bcst_regs = 1;
        int max_regs = max_avx512_regs - (adj_ld_block + max_bcst_regs);
        int max_block
                = (brg->embd_bcst ? 28
                                  : ((brg->beta == 1.f || brg->beta == 0.f)
                                                  ? max_regs
                                                  : max_regs - 1));
        max_block -= brg->req_s8s8_compensation;
        max_block /= adj_ld_block;
        int min_block = 1;
        float best_bd_block_eff = 0.f;
        brg->bd_block = 1;
        // Search for the bd_block maximizing a combined efficiency metric:
        // balance of M over blocks (bd_block_disb) times a microkernel
        // shape efficiency, subject to the A-block fitting in L1.
        for (int bd_block = max_block; bd_block >= min_block; bd_block--) {
            const auto bd_block_disb = static_cast<float>(brg->bcast_dim)
                    / rnd_up(brg->bcast_dim, bd_block);
            const auto brgemm_microkernel_eff
                    = (static_cast<float>(adj_ld_block) * bd_block)
                    / (((adj_ld_block) + bd_block) * max_block);
            const auto bd_block_eff = bd_block_disb * brgemm_microkernel_eff;

            float block_foot_print = static_cast<float>(brg->typesize_A)
                    * (bd_block * brg->reduce_dim);
            if (block_foot_print <= static_cast<float>(
                        platform::get_per_core_cache_size(1))
                    && (bd_block_eff > best_bd_block_eff)) {
                brg->bd_block = bd_block;
                best_bd_block_eff = bd_block_eff;
            }
        }
        brg->bdb = brg->bcast_dim / brg->bd_block;
        brg->bdb_tail = brg->bcast_dim % brg->bd_block;

        // Reduce dim is consumed 16 bytes of A at a time (vnni granularity).
        brg->rd_block = 16 / brg->typesize_A;
        brg->rdb = brg->reduce_dim / brg->rd_block;
        brg->rdb_tail = brg->reduce_dim % brg->rd_block;

        brg->is_M_tail = false;
    } else {
        // Blocking configuration for AMX
        // A tile covers at most 16 rows of A/C.
        const int max_width = 16, min_width = 1;
        brg->ld_block = 16;
        brg->ldb = brg->load_dim / brg->ld_block;
        brg->ldb_tail = brg->load_dim % brg->ld_block;

        // When a level-2 bd_mask is set, choose the bd_block that minimizes
        // the number of blocks (bdb) needed to cover the unmasked rows;
        // returns false when no mask-driven choice applies.
        auto find_bd_block_for_bd_mask = [&]() {
            const auto bd_mask_size = brg->bcast_dim;
            if (brg->brgattr.bd_mask_level != 2 || bd_mask_size == 0)
                return false;

            const auto sm_buffer = brg->brgattr.bd_mask;
            auto min_bdb = INT_MAX;
            const auto start_bd_block = nstl::min(max_width, brg->bcast_dim);
            auto best_bd_block = start_bd_block;
            for (auto bd_block = start_bd_block; bd_block > 0; bd_block--) {
                auto bdb = 0;
                for (int i = 0; i < bd_mask_size;) {
                    if (brg->brgattr.bd_mask_level == 2 && sm_buffer[i] == 0) {
                        i++; // skip masked-out row
                    } else {
                        i += bd_block;
                        if (i > brg->bcast_dim) {
                            // bcast_dim not divided by bd_block
                            bdb = INT_MAX;
                        } else
                            bdb++;
                    }
                }
                if (bdb < min_bdb) {
                    min_bdb = bdb;
                    best_bd_block = bd_block;
                }
            }
            brg->bd_block = best_bd_block;
            brg->bdb_tail = 0;
            brg->bdb = min_bdb;
            return true;
        };

        // Given the bd_block2 choice, pick ld_block2 (number of B/C tile
        // columns) and the derived ldb2 counts; may widen bd_block2 again
        // when only one column is used.
        auto set_decomposition_by_ld = [&]() {
            if (brg->bd_block2 == 1 && brg->ldb > 0 && brg->ldb_tail == 0) {
                if (brg->ldb % 3 == 0)
                    brg->ld_block2 = 3;
                else if (brg->ldb % 2 == 0)
                    brg->ld_block2 = 2;
                else
                    brg->ld_block2 = 1;
            } else {
                brg->ld_block2
                        = (brg->ldb > 0 && brg->ldb % 2 == 0
                                  && brg->ldb_tail == 0 && brg->bd_block2 < 3)
                        ? 2
                        : 1;
            }
            brg->ldb2 = brg->ldb / brg->ld_block2;
            brg->ldb2_tail = brg->ldb % brg->ld_block2;

            // Re-adjust the bd_block2 if possible
            if (brg->ld_block2 == 1 && !brg->is_M_tail && brg->ldb_tail == 0) {
                brg->bd_block2 = (brg->bdb >= 3) ? 3 : (brg->bdb >= 2) ? 2 : 1;
                brg->bdb2 = brg->bdb / brg->bd_block2;
                brg->bdb2_tail = (brg->bd_block2 == 1)
                        ? brg->bdb
                        : brg->bdb % brg->bd_block2;
            }
        };

        // width_step x 1 decomposition: applies when M falls strictly between
        // (width_step-1) and width_step full tile widths with no N tail.
        auto try_3x1_decomposition = [&](int width_step) {
            brg->is_M_tail = false;
            if (brg->bcast_dim > (width_step - 1) * max_width
                    && brg->bcast_dim < width_step * max_width
                    && brg->ldb_tail == 0) {
                if (!find_bd_block_for_bd_mask()) {
                    brg->bd_block = max_width;
                    brg->bdb = div_up(brg->bcast_dim, brg->bd_block);
                    brg->bdb_tail = brg->bcast_dim % brg->bd_block;
                    brg->is_M_tail = true;
                }
                brg->bd_block2 = width_step;
                brg->bdb2 = brg->bdb / brg->bd_block2;
                brg->bdb2_tail = brg->bdb % brg->bd_block2;
                set_decomposition_by_ld();
                return true;
            }
            return false;
        };

        // 2x2 decomposition: pick the bd_block (preferring an exact divisor
        // of M, then minimizing the tail) and pair rows/columns of tiles.
        auto try_2x2_decomposition = [&]() {
            if (!find_bd_block_for_bd_mask()) {
                for (int m_block = max_width; m_block >= min_width; m_block--) {
                    if (brg->bcast_dim % m_block == 0) {
                        brg->bd_block = m_block;
                        break;
                    }
                }
                if (brg->bd_block == 1) {
                    // No exact divisor: choose the block with the largest
                    // (or zero) tail.
                    brg->bd_block = nstl::min(max_width, brg->bcast_dim);
                    brg->bdb_tail = brg->bcast_dim % max_width;
                    for (int i = max_width; i >= min_width; i--) {
                        int i_tail = brg->bcast_dim % i;
                        if (i_tail > brg->bdb_tail || i_tail == 0) {
                            brg->bd_block = i;
                            brg->bdb_tail = i_tail;
                            if (i_tail == 0) break;
                        }
                    }
                }
                brg->bdb = brg->bcast_dim / brg->bd_block;
                brg->bdb_tail = brg->bcast_dim % brg->bd_block;
            }

            brg->bd_block2 = (brg->bdb >= 2) ? 2 : 1;
            brg->bdb2 = brg->bdb / brg->bd_block2;
            brg->bdb2_tail = (brg->bd_block2 == 1) ? brg->bdb
                                                   : brg->bdb % brg->bd_block2;

            brg->is_M_tail = false;

            set_decomposition_by_ld();

            // Reject degenerate 2x2 shapes so a 3x1 strategy can be tried.
            return !(brg->ld_block2 == 1 || brg->bd_block2 == 1
                    || brg->bd_block < 8);
        };

        // Try the strategies in fixed order; stop at the first that fits.
        bool is_decomposition_defined = false;
        for (int i = decomposition_2x2; i != not_definded; i++) {
            switch (i) {
                case decomposition_2x2:
                    is_decomposition_defined = try_2x2_decomposition();
                    break;
                case decomposition_3x1_3:
                    is_decomposition_defined = try_3x1_decomposition(3);
                    break;
                case decomposition_3x1_2:
                    is_decomposition_defined = try_3x1_decomposition(2);
                    break;
                default: assert(!"invalid value"); break;
            };
            if (is_decomposition_defined) break;
        }
        // Fall back to 2x2 even if it reported a degenerate shape.
        if (!is_decomposition_defined) try_2x2_decomposition();

        // AMX tile consumes 64 bytes of A per row: 32 bf16 or 64 int8 values.
        brg->rd_block = brg->is_bf16_amx ? 32 : 64;
        brg->rdb = brg->reduce_dim / brg->rd_block;
        brg->rdb_tail = brg->reduce_dim % brg->rd_block;

        // Remove these guard in the future (add tail processing by reduction dimension)
        if (brg->rdb > 0 && brg->rdb_tail) return status::unimplemented;
        if (brg->rdb_tail % ((brg->is_bf16_amx) ? 2 : 4))
            return status::unimplemented;
    }

    return status::success;
}
350 } // namespace
351 
// Initializes a BRGEMM descriptor from problem sizes, data types and leading
// dimensions, validating the arguments against the capabilities of the
// requested (or detected) ISA, then computes the blocking via
// brgemm_blocking(). Column-major problems are handled by swapping the roles
// of A/B and M/N so the kernel always works on a row-major view.
//
// @param brg     descriptor to fill (must not be null)
// @param isa     requested ISA, or isa_any to auto-detect AMX usage
// @param type    batch kind (address/offset/strided)
// @param dt_a/dt_b  data types of A and B (int8/bf16/f32 combos supported)
// @param transA/transB  transposition flags (not supported -> unimplemented)
// @param layout  row- or column-major
// @param alpha/beta  GEMM scaling factors
// @param LDA/LDB/LDC leading dimensions; M/N/K problem sizes
// @param strides optional fixed strides for brgemm_strd batch kind
// @return status::success, invalid_arguments, or unimplemented
status_t brgemm_desc_init(brgemm_t *brg, cpu_isa_t isa,
        brgemm_batch_kind_t type, impl::data_type_t dt_a,
        impl::data_type_t dt_b, bool transA, bool transB,
        brgemm_layout_t layout, float alpha, float beta, dim_t LDA, dim_t LDB,
        dim_t LDC, dim_t M, dim_t N, dim_t K, const brgemm_strides_t *strides) {
    /*
    m - number of rows of the matrix op(A) and number of rows of the matrix C
    n - number of columns of the matrix op(B) and number of columns of the matrix C
    k - number of columns of the matrix op(A) and number of rows of the matrix op(B)

    Matrices are in row-major layouts:
        A: lda * m, LDA - lda must be at least max(1, k)
        B: ldb * k, LDB - ldb must be at least max(1, n)
        C: ldc * m, LDC - ldc must be at least max(1, n)

    Matrices are in column-major layouts:
        A: lda * k, LDA - lda must be at least max(1, m)
        B: ldb * n, LDB - ldb must be at least max(1, k)
        C: ldc * n, LDC - ldc must be at least max(1, m)
    */
    if (brg == nullptr) return status::invalid_arguments;
    if (transA || transB) return status::unimplemented;

    brg->layout = layout;
    auto is_row_major = [&]() { return brg->layout == brgemm_row_major; };
    if (M <= 0 || N <= 0 || K <= 0) return status::invalid_arguments;
    // Minimal leading-dimension sanity checks per the table above.
    bool ldx_check = (is_row_major()) ? (LDA < K || LDB < N || LDC < N)
                                      : (LDA < M || LDB < K || LDC < M);
    if (ldx_check) return status::invalid_arguments;

    // Column-major: swap A and B so the kernel sees a row-major problem.
    brg->dt_a = (is_row_major()) ? dt_a : dt_b;
    brg->dt_b = (is_row_major()) ? dt_b : dt_a;

    // Supported type combinations: u8/s8 x s8, bf16 x bf16, f32 x f32.
    brg->is_int8 = (one_of(brg->dt_a, data_type::u8, data_type::s8)
            && brg->dt_b == data_type::s8);
    brg->is_bf16
            = (brg->dt_a == data_type::bf16 && brg->dt_b == data_type::bf16);
    brg->is_f32 = (brg->dt_a == data_type::f32 && brg->dt_b == data_type::f32);
    if (!brg->is_int8 && !brg->is_bf16 && !brg->is_f32)
        return status::unimplemented;
    // Accumulator type; D/bias default to it until set_postops overrides.
    brg->dt_c = (brg->is_int8) ? data_type::s32 : data_type::f32;
    brg->dt_d = brg->dt_c;
    brg->dt_bias = brg->dt_c;

    // The chosen type combination must be runnable on this machine.
    if (!IMPLICATION(brg->is_f32, mayiuse(avx512_core)))
        return status::unimplemented;
    if (!IMPLICATION(brg->is_bf16, mayiuse(avx512_core_bf16)))
        return status::unimplemented;
    if (!IMPLICATION(brg->is_int8, mayiuse(avx512_core_vnni)))
        return status::unimplemented;

    if (isa != isa_any) {
        // Explicit ISA request: validate it and enable AMX only when both
        // requested and available.
        if (!one_of(isa, avx512_core, avx512_core_bf16, avx512_core_vnni,
                    avx512_core_bf16_amx_bf16, avx512_core_bf16_amx_int8)) {
            return status::invalid_arguments;
        }
        brg->is_int8_amx = brg->is_bf16_amx = false;
        if (brg->is_int8 && isa == avx512_core_bf16_amx_int8) {
            if (!mayiuse(avx512_core_bf16_amx_int8))
                return status::invalid_arguments;
            brg->is_int8_amx = true;
        }
        if (brg->is_bf16 && isa == avx512_core_bf16_amx_bf16) {
            if (!mayiuse(avx512_core_bf16_amx_bf16))
                return status::invalid_arguments;
            brg->is_bf16_amx = true;
        }
    } else {
        // Auto mode: use AMX whenever the hardware supports it.
        brg->is_int8_amx = brg->is_int8 && mayiuse(avx512_core_bf16_amx_int8);
        brg->is_bf16_amx = brg->is_bf16 && mayiuse(avx512_core_bf16_amx_bf16);
    }
    brg->is_amx = (brg->is_int8_amx || brg->is_bf16_amx);
    // s8*s8 on non-AMX VNNI needs a +128 compensation trick.
    brg->req_s8s8_compensation
            = brg->is_int8 && !brg->is_int8_amx && brg->dt_a == data_type::s8;
    brg->LDA = (is_row_major()) ? static_cast<int>(LDA) : static_cast<int>(LDB);
    brg->LDB = (is_row_major()) ? static_cast<int>(LDB) : static_cast<int>(LDA);

    brg->LDC = static_cast<int>(LDC);
    brg->LDD = static_cast<int>(LDC);

    // Internal naming: bcast dim = M, load dim = N, reduce dim = K
    // (after the row-major normalization above).
    brg->bcast_dim
            = (is_row_major()) ? static_cast<int>(M) : static_cast<int>(N);
    brg->load_dim
            = (is_row_major()) ? static_cast<int>(N) : static_cast<int>(M);
    brg->reduce_dim = static_cast<int>(K);

    // Post-ops are all off until brgemm_desc_set_postops() is called.
    brg->with_bias = false;
    brg->with_eltwise = false;
    brg->with_sum = false;
    brg->sum_scale = 0;
    brg->sum_zp = 0;
    brg->with_scales = false;

    brg->beta = beta;
    brg->alpha = alpha;

    brg->typesize_A = types::data_type_size(brg->dt_a);
    brg->typesize_B = types::data_type_size(brg->dt_b);
    brg->typesize_C = types::data_type_size(brg->dt_c);
    brg->typesize_D = types::data_type_size(brg->dt_d);
    brg->type = type;

    brg->bd_block2 = 0;
    brg->bdb2 = 0;
    brg->bdb2_tail = 0;

    // Step sizes derived from the 4-byte vnni granule of A.
    brg->ld_step = brg->rd_step = 4 / brg->typesize_A;

    if (strides != nullptr) {
        brg->stride_a = strides->stride_a;
        brg->stride_b = strides->stride_b;
    } else {
        brg->stride_a = brg->stride_b = 0;
    }

    CHECK(brgemm_blocking(brg));

    return status::success;
}
471 
// Initializes a descriptor for the batch-reduce *diagonal* GEMM (brdgmm)
// kernel: an elementwise-style M x N product with no reduce dimension.
// Validates the arguments, selects the required ISA, and computes the
// M/N blocking directly (no call to brgemm_blocking()).
//
// @param brg    descriptor to fill (must not be null)
// @param isa    ISA the caller intends to run on; must cover the data types
// @param type   batch kind
// @param dt_a/dt_b data types (int8/bf16/f32 combos, same as brgemm)
// @param transA transposition flag (not supported)
// @param layout only brgemm_row_major is supported
// @param alpha/beta only alpha == 1 and beta == 0 are supported
// @param LDA/LDC leading dimensions; M/N problem sizes
// @param strides optional fixed strides for the strided batch kind
// @return status::success, invalid_arguments, or unimplemented
status_t brdgmm_desc_init(brgemm_t *brg, cpu_isa_t isa,
        brgemm_batch_kind_t type, impl::data_type_t dt_a,
        impl::data_type_t dt_b, bool transA, brgemm_layout_t layout,
        float alpha, float beta, dim_t LDA, dim_t LDC, dim_t M, dim_t N,
        const brgemm_strides_t *strides) {

    if (brg == nullptr) return status::invalid_arguments;
    if (transA || layout != brgemm_row_major || alpha != 1.0f || beta != 0.f)
        return status::unimplemented;

    const bool ldx_check = (LDA < N || LDC < N);
    if (ldx_check) return status::invalid_arguments;

    brg->dt_a = dt_a;
    brg->dt_b = dt_b;

    // Supported type combinations: u8/s8 x s8, bf16 x bf16, f32 x f32.
    brg->is_int8 = one_of(brg->dt_a, data_type::u8, data_type::s8)
            && (brg->dt_b == data_type::s8);
    brg->is_bf16
            = (brg->dt_a == data_type::bf16) && (brg->dt_b == data_type::bf16);
    brg->is_f32
            = (brg->dt_a == data_type::f32) && (brg->dt_b == data_type::f32);
    if (!brg->is_int8 && !brg->is_bf16 && !brg->is_f32)
        return status::unimplemented;
    brg->dt_c = (brg->is_int8) ? data_type::s32 : data_type::f32;
    brg->dt_d = brg->dt_c;
    brg->dt_bias = brg->dt_c;

    // The caller's ISA must be a superset of what the data types require,
    // and the machine must actually support it.
    const cpu_isa_t req_isa = brg->is_f32
            ? avx512_core
            : (brg->is_int8 ? avx512_core_vnni : avx512_core_bf16);
    if (!(is_superset(isa, req_isa) && mayiuse(req_isa)))
        return status::unimplemented;

    brg->is_bf16_amx = brg->is_bf16 && mayiuse(avx512_core_bf16_amx_bf16);
    brg->is_dgmm = true;
    brg->type = type;
    brg->layout = layout;
    brg->alpha = alpha;
    brg->beta = beta;

    brg->LDA = static_cast<int>(LDA);
    brg->LDC = static_cast<int>(LDC);
    brg->LDD = static_cast<int>(LDC);

    brg->typesize_A = types::data_type_size(brg->dt_a);
    brg->typesize_B = types::data_type_size(brg->dt_b);
    brg->typesize_C = types::data_type_size(brg->dt_c);
    brg->typesize_D = types::data_type_size(brg->dt_d);

    // In current implementation of dgmm, there is no reduce dim.
    // Also, bcast and load dimensions refer to M and N.

    // Readable aliases for the generic bd_*/ld_* blocking fields.
    // auto &M = brg->bcast_dim;
    // auto &N = brg->load_dim;
    auto &m_vlen_blk = brg->bd_block;
    auto &nb_m_vlen_blk = brg->bdb;
    auto &m_vlen_tail = brg->bdb_tail;
    auto &m_blocking = brg->bd_block2;
    auto &nb_m_blocking = brg->bdb2;
    auto &m_blocking_tail = brg->bdb2_tail;

    auto &n_vlen_blk = brg->ld_block;
    auto &nb_n_vlen_blk = brg->ldb;
    auto &n_vlen_tail = brg->ldb_tail;
    auto &n_blocking = brg->ld_block2;
    auto &nb_n_blocking = brg->ldb2;
    auto &n_blocking_tail = brg->ldb2_tail;

    brg->bcast_dim = M;
    brg->load_dim = N;
    const int simd_w = 16;

    // begin blocking
    // N: split into simd-wide chunks, then group up to 4 chunks per step.
    n_vlen_blk = simd_w;
    nb_n_vlen_blk = div_up(N, n_vlen_blk);
    n_vlen_tail = N % n_vlen_blk;
    n_blocking = nstl::min(4, nb_n_vlen_blk);
    nb_n_blocking = div_up(nb_n_vlen_blk, n_blocking);
    n_blocking_tail = nb_n_vlen_blk % n_blocking;

    // M: one row per step; m_blocking rows share the accumulator zmm budget
    // with the n_blocking columns.
    const int max_acc_zmms = 32 - 2 /*zmma, zmmb, post-ops, saturation*/
            - jit_brdgmm_kernel_base_t::is_fast_vnni_int8(*brg) /*perm dst*/;
    m_vlen_blk = 1;
    nb_m_vlen_blk = M / m_vlen_blk;
    m_vlen_tail = M % m_vlen_blk;
    m_blocking = nstl::min(nb_m_vlen_blk, max_acc_zmms / n_blocking);
    nb_m_blocking = div_up(nb_m_vlen_blk, m_blocking);
    m_blocking_tail = nb_m_vlen_blk % m_blocking;

    if (strides != nullptr) {
        brg->stride_a = strides->stride_a;
        brg->stride_b = strides->stride_b;
    } else {
        brg->stride_a = brg->stride_b = 0;
    }

    return status::success;
}
571 
brgemm_desc_set_postops(brgemm_t * brg,const primitive_attr_t * attr,const memory_desc_t * dst_md,int LDD,impl::data_type_t dt_bias)572 status_t brgemm_desc_set_postops(brgemm_t *brg, const primitive_attr_t *attr,
573         const memory_desc_t *dst_md, int LDD, impl::data_type_t dt_bias) {
574     if (!brg || !dst_md) return status::invalid_arguments;
575 
576     brg->attr = attr;
577     brg->dst_md = dst_md;
578 
579     brg->with_bias = (dt_bias == data_type::undef) ? false : true;
580     brg->dt_bias = dt_bias;
581     brg->typesize_bias = (dt_bias == data_type::undef)
582             ? 0
583             : types::data_type_size(brg->dt_bias);
584 
585     brg->LDD = LDD;
586     const auto dt_d = dst_md->data_type;
587 
588     if ((brg->dt_a == data_type::u8 && brg->dt_b == data_type::s8)
589             && (!one_of(dt_d, data_type::u8, data_type::s8, data_type::s32,
590                     data_type::f32))
591             && (!one_of(dt_bias, data_type::undef, data_type::u8, data_type::s8,
592                     data_type::s32, data_type::f32, data_type::bf16)))
593         return status::unimplemented;
594     if ((brg->dt_a == data_type::bf16 && brg->dt_b == data_type::bf16)
595             && (!one_of(dt_d, data_type::bf16, data_type::f32))
596             && (!one_of(dt_bias, data_type::undef, data_type::bf16,
597                     data_type::f32)))
598         return status::unimplemented;
599     if ((brg->dt_a == data_type::f32 && brg->dt_b == data_type::f32)
600             && (!one_of(dt_d, data_type::f32))
601             && (!one_of(dt_bias, data_type::undef, data_type::f32)))
602         return status::unimplemented;
603 
604     brg->dt_d = dt_d;
605     brg->typesize_D = types::data_type_size(brg->dt_d);
606 
607     if (!IMPLICATION(
608                 brg->is_int8 && brg->dt_d == bf16, mayiuse(avx512_core_bf16)))
609         return status::unimplemented;
610 
611     if (!brg->attr) return status::success;
612 
613     using namespace injector;
614 
615     const auto &post_ops = brg->attr->post_ops_;
616     const memory_desc_wrapper dst_d(dst_md);
617 
618     const int binary_ind = post_ops.find(primitive_kind::binary);
619     brg->with_binary = binary_ind != -1;
620     const cpu_isa_t isa = get_max_cpu_isa();
621 
622     if ((brg->with_binary && !dst_md)
623             || !injector::post_ops_ok(
624                     post_ops_ok_args_t(isa, {sum, eltwise, binary}, post_ops,
625                             &dst_d, false /*sum_at_pos_0_only*/,
626                             false /*sum_requires_scale_one*/,
627                             false /*sum_requires_zp_zero*/,
628                             {broadcasting_strategy_t::per_oc,
629                                     broadcasting_strategy_t::scalar,
630                                     broadcasting_strategy_t::per_mb_spatial,
631                                     broadcasting_strategy_t::per_mb_w,
632                                     broadcasting_strategy_t::no_broadcast})))
633         return status::unimplemented;
634 
635     const int sum_idx = post_ops.find(primitive_kind::sum);
636     const bool with_sum = sum_idx != -1;
637     brg->with_sum = with_sum;
638     brg->sum_scale = with_sum ? post_ops.entry_[sum_idx].sum.scale : 0;
639     brg->sum_zp = with_sum ? post_ops.entry_[sum_idx].sum.zero_point : 0;
640     const auto sum_dt
641             = with_sum ? post_ops.entry_[sum_idx].sum.dt : data_type::undef;
642     brg->sum_dt = sum_dt != data_type::undef ? sum_dt : dt_d;
643 
644     const int eltwise_ind = post_ops.find(primitive_kind::eltwise);
645     brg->with_eltwise = eltwise_ind != -1;
646 
647     brg->with_scales = !attr->output_scales_.has_default_values();
648     if (brg->with_scales) {
649         const auto &oscales = brg->attr->output_scales_;
650         // Note. the current version supports only two different output scale
651         // types:
652         //     1) common (mask_ = 0)
653         //     2) per_n_dim_scale - broadcast across n dimension;
654         //        for convolution and inner product promitives it corresponds
655         //        to "per_oc" mask_ = 1 << 1; for matmul - to
656         //        mask_ = (1 << (ndims - 1))), where ndims is number of
657         //        dimensions for original matmul problem
658         // So if oscales.mask_ != 0 (not common) it's assumed here that scale
659         // type is per_n_dim_scale and driver which calls brgemm kernel checked
660         // that mask has correct value for this case
661         brg->is_oc_scale = oscales.mask_ != 0;
662     }
663 
664     auto init_zp_type
665             = [&](brgemm_broadcast_t &zp_type, int mem_arg) -> status_t {
666         auto zero_points = attr->zero_points_;
667 
668         // common zero point type is supported for now
669         if (!zero_points.common(mem_arg)) return status::unimplemented;
670 
671         zp_type = zero_points.has_default_values(mem_arg)
672                 ? brgemm_broadcast_t::none
673                 : brgemm_broadcast_t::per_tensor;
674         return status::success;
675     };
676 
677     init_zp_type(brg->zp_type_a, DNNL_ARG_SRC);
678     init_zp_type(brg->zp_type_b, DNNL_ARG_WEIGHTS);
679     init_zp_type(brg->zp_type_c, DNNL_ARG_DST);
680 
681     return status::success;
682 }
683 
brgemm_desc_set_attr(brgemm_t * brg,const brgemm_attr_t & brgattr)684 status_t brgemm_desc_set_attr(brgemm_t *brg, const brgemm_attr_t &brgattr) {
685     if (brg == nullptr) return status::invalid_arguments;
686 
687     // negative padding is not supported
688     if (brgattr.max_top_vpad < 0 || brgattr.max_bottom_vpad < 0)
689         return status::unimplemented;
690 
691     // virtual padding is not supported for "amx"
692     if ((brgattr.max_top_vpad > 0 || brgattr.max_bottom_vpad > 0)
693             && (brg->is_amx))
694         return status::unimplemented;
695 
696     if (!brg->is_dgmm) {
697         // virtual padding size is restricted by MAX_VPAD value
698         if (brgattr.max_top_vpad > brgemm_t::MAX_VPAD
699                 || brgattr.max_bottom_vpad > brgemm_t::MAX_VPAD)
700             return status::unimplemented;
701 
702         // virtual padding is restricted by bd_block size due to
703         // brgemm_kernel implementation. TODO: remove this restriction
704         if (brgattr.max_top_vpad > brg->bd_block
705                 || brgattr.max_bottom_vpad > brg->bd_block)
706             return status::unimplemented;
707     }
708 
709     // virtual padding is supported for "brgemm_row_major" layout
710     // TODO: remove this restriction
711     if ((brgattr.max_top_vpad > 0 || brgattr.max_bottom_vpad > 0)
712             && brg->layout != brgemm_row_major)
713         return status::unimplemented;
714 
715     brg->brgattr = brgattr;
716 
717     if (brgattr.bd_mask_level) brgemm_blocking(brg);
718 
719     return status::success;
720 }
721 
brgemm_kernel_create(brgemm_kernel_t ** brg_kernel,const brgemm_t & brg)722 status_t brgemm_kernel_create(
723         brgemm_kernel_t **brg_kernel, const brgemm_t &brg) {
724     if (brg.is_dgmm) {
725         CHECK(safe_ptr_assign<brgemm_kernel_t>(
726                 *brg_kernel, new brdgmm_kernel_t(brg)));
727         return (*brg_kernel)->create_kernel();
728     } else if (brg.is_amx && brg.type == brgemm_addr && brg.brgattr.max_bs >= 1
729             && brg.brgattr.use_uker) {
730         CHECK(safe_ptr_assign<brgemm_kernel_t>(
731                 *brg_kernel, new brgemm_amx_uker_t(brg)));
732         return (*brg_kernel)->create_kernel();
733     } else {
734         CHECK(safe_ptr_assign<brgemm_kernel_t>(
735                 *brg_kernel, new brgemm_kernel_common_t(brg)));
736         return (*brg_kernel)->create_kernel();
737     }
738 }
739 
// Releases a kernel created by brgemm_kernel_create(). Safe to call with
// nullptr (delete on null is a no-op).
void brgemm_kernel_destroy(brgemm_kernel_t *brg_kernel) {
    delete brg_kernel;
}
743 
// Fills a 64-byte AMX tile-configuration palette for the given descriptor:
// rows/column-bytes for the A tiles (bd_block2 of them), B tiles (ld_block2
// plus an optional N-tail tile), and the C accumulator tiles, as laid out by
// the brg.get_{A,B,C}_tensor() numbering. The caller loads the palette with
// ldtilecfg before running the kernel.
//
// @param brg     AMX descriptor (non-AMX returns unimplemented)
// @param palette [out] 64-byte buffer interpreted as palette_config_t
// @return status::success or unimplemented
status_t brgemm_init_tiles(const brgemm_t &brg, char palette[64]) {
    constexpr int max_palette_size_in_bytes = 64;

    if (!brg.is_amx) return status::unimplemented;

    //TODO: Add support of tail processing by reduction dimension
    int rd_block = (!brg.rdb && brg.rdb_tail) ? brg.rdb_tail : brg.rd_block;

    palette_config_t *buff = (palette_config_t *)(palette);

    // Zero the whole palette first; unset tiles stay disabled.
    char *_tc = (char *)(buff);
    for (int i = 0; i < max_palette_size_in_bytes; i++)
        _tc[i] = 0;

    // Elements of B packed per 4-byte vnni granule.
    int rd_step = 4 / brg.typesize_A;

    // Per-tile row byte-widths: A covers rd_block elements, B covers
    // ld_block (or the N tail) vnni-packed columns, C covers ld_block
    // accumulators.
    int Ac = brg.typesize_A * rd_block;

    int Bc = brg.ld_block * brg.typesize_B * rd_step;
    int Bc_t = brg.ldb_tail * brg.typesize_B * rd_step;

    int Cc = brg.ld_block * brg.typesize_C;
    int Cc_t = brg.ldb_tail * brg.typesize_C;

    // B tile row count derived from A's byte width (guard div-by-zero).
    int Br = (brg.typesize_C != 0) ? Ac / brg.typesize_C : 0;

    if (brg.ldb_tail && (brg.ld_block2 > 1)) return status::unimplemented;

    // A tiles: the last one may be the M-tail block.
    for (int m = 0; m < brg.bd_block2; m++) {
        int Ar = (brg.is_M_tail && m == brg.bd_block2 - 1) ? brg.bdb_tail
                                                           : brg.bd_block;
        tc_configure_tile(buff, brg.get_A_tensor(m), Ar, Ac);
    }

    // B tiles: ld_block2 full tiles plus an optional N-tail tile.
    for (int n = 0; n < brg.ld_block2; n++)
        tc_configure_tile(buff, brg.get_B_tensor(n), Br, Bc);
    if (brg.ldb_tail)
        tc_configure_tile(buff, brg.get_B_tensor(brg.ld_block2), Br, Bc_t);

    // C tiles: one per (m, n) pair, plus optional N-tail column.
    for (int m = 0; m < brg.bd_block2; m++) {
        int Cr = (brg.is_M_tail && m == brg.bd_block2 - 1) ? brg.bdb_tail
                                                           : brg.bd_block;
        for (int n = 0; n < brg.ld_block2; n++)
            tc_configure_tile(buff, brg.get_C_tensor(m, n), Cr, Cc);
        if (brg.ldb_tail)
            tc_configure_tile(
                    buff, brg.get_C_tensor(m, brg.ld_block2), Cr, Cc_t);
    }
    buff->palette_id = amx::get_max_palette();

    return status::success;
}
796 
797 } // namespace x64
798 } // namespace cpu
799 } // namespace impl
800 } // namespace dnnl
801 
802 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
803