/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16 
17 #include "dnnl_test_common.hpp"
18 #include "gtest/gtest.h"
19 
20 #include "oneapi/dnnl/dnnl.hpp"
21 
22 #include <vector>
23 
24 namespace dnnl {
25 
namespace P {
// Bit flags used by the test tables below to describe how each matmul
// argument (src/weights/dst) and each attribute (scales/zero points) is
// configured. A single `unsigned` packs: runtime-ness (bit 31), the
// parameter kind (bits 28..30), the matrix it applies to (bits 20..22),
// and the quantization mask (bits 0..1).
// The constants are compile-time constants, hence `constexpr` (they were
// mutable globals before, which allowed accidental modification).

// Common
constexpr unsigned NONE = 0u;

// Value (dim / scale / zero point) is passed at execution time rather than
// at primitive creation time.
constexpr unsigned RUNTIME = 1u << 31;

constexpr unsigned SCALES = 1u << 30;
constexpr unsigned ZERO_POINTS = 1u << 29;

// Request non-dense (padded) strides for the matrix.
constexpr unsigned LEADING_DIM = 1u << 28;

// matrices indices: 1 .. 7
// bits reserved: 20 .. 22
constexpr unsigned MATRIX_MASK = 7u << 20;
constexpr unsigned SRC = 1u << 20;
constexpr unsigned WEIGHTS = 2u << 20;
constexpr unsigned DST = 3u << 20;

// scales and zero points: 1 .. 3
// bits reserved: 0 .. 1
constexpr unsigned MASK_MASK = 3u << 0;
constexpr unsigned COMMON = 1u << 0;
constexpr unsigned PER_N = 1u << 1;
} // namespace P
50 
// Plain-data description of one matmul problem: a memory descriptor
// specification per matrix (src, weights, dst) plus the bias data type.
struct matmul_base_t {
    struct md_t {
        memory::dims dims; // logical dimensions
        memory::data_type dt; // data type
        memory::format_tag tag; // plain format tag; strides derived by init_md()
        unsigned flags; // bitwise OR of P:: flags (P::RUNTIME, P::LEADING_DIM, ...)
    } src, weights, dst;
    memory::data_type bia_dt; // data_type::undef means "no bias"
};
60 
// TODO: a way to generalize?
// Plain-data description of the primitive attributes to attach to a matmul.
struct matmul_attr_t {
    // ctor {P::SCALE, {P::SRC, P::WEIGHTS, P::DST}, {P::POST_OPS, ...}}

    // P:: flags for output scales: P::SCALES | mask bits [| P::RUNTIME].
    unsigned scale_flags;

    // P:: flags for zero points, one entry per matmul argument
    // (P::NONE means no zero point for that argument).
    struct zero_points_t {
        unsigned src, weights, dst;
    } zero_points;

    // One post operation; `alg` is only meaningful for eltwise.
    struct post_op_t {
        primitive::kind kind;
        algorithm alg;
    };

    std::vector<post_op_t> post_ops;
};
78 
// A full parametrized test case: problem + attributes + expected outcome.
struct matmul_test_params_t {
    matmul_base_t base;
    matmul_attr_t attr;

    bool expect_to_fail; // true for negative cases
    dnnl_status_t expected_status; // status expected when expect_to_fail
};
86 
87 using tag = memory::format_tag;
88 
89 class matmul_iface_test_t
90     : public ::testing::TestWithParam<matmul_test_params_t> {
91 protected:
SetUp()92     void SetUp() override {
93         matmul_test_params_t p
94                 = ::testing::TestWithParam<decltype(p)>::GetParam();
95 
96         SKIP_IF(unsupported_data_type(p.base.src.dt),
97                 "Engine does not support this data type.");
98         SKIP_IF(unsupported_data_type(p.base.weights.dt),
99                 "Engine does not support this data type.");
100         SKIP_IF(unsupported_data_type(p.base.dst.dt),
101                 "Engine does not support this data type.");
102         SKIP_IF(unsupported_data_type(p.base.bia_dt),
103                 "Engine does not support this data type.");
104         SKIP_IF(get_test_engine_kind() == engine::kind::gpu
105                         && ((p.attr.zero_points.src & P::PER_N)
106                                 || (p.attr.zero_points.dst & P::PER_N)),
107                 "Per dimensional zero points are not supported on GPU");
108         SKIP_IF(get_test_engine_kind() == engine::kind::cpu
109                         && p.base.src.tag == impl::format_tag::AB8a4b,
110                 "Don't test blocked formats on CPU");
111 
112         SKIP_IF_CUDA((p.attr.zero_points.src != 0 || p.attr.zero_points.dst != 0
113                              || p.attr.zero_points.weights != 0),
114                 "Zero points not supported for CUDA");
115 
116         SKIP_IF_CUDA((p.attr.scale_flags & P::MASK_MASK) == P::PER_N,
117                 "Per dimensional scaling is not supported for CUDA");
118 
119         catch_expected_failures(
120                 [=]() { Test(); }, p.expect_to_fail, p.expected_status, false);
121     }
122 
123     // use `force_no_rt = true` when create final memory
init_md(const matmul_base_t::md_t & desc,bool force_no_rt=false)124     static memory::desc init_md(
125             const matmul_base_t::md_t &desc, bool force_no_rt = false) {
126         const bool runtime = force_no_rt ? false : (desc.flags & P::RUNTIME);
127         const bool use_ld = (desc.flags & P::LEADING_DIM);
128 
129         memory::dims dims = desc.dims;
130         if (runtime)
131             dims = memory::dims(desc.dims.size(), DNNL_RUNTIME_DIM_VAL);
132 
133         if (runtime || use_ld == false)
134             return memory::desc(dims, desc.dt, desc.tag);
135 
136         memory::dims strides;
137         switch (desc.tag) {
138             case tag::ab: strides = {dims[1] + 1, 1}; break;
139             case tag::ba: strides = {1, dims[0] + 1}; break;
140             case tag::abc:
141                 strides = {dims[1] * (dims[2] + 1) + 1, dims[2] + 1, 1};
142                 break;
143             case tag::acb:
144                 strides = {dims[1] * (dims[2] + 1) + 1, dims[2] + 1, 1};
145                 break;
146             default:
147                 throw std::invalid_argument("tag doesn't support custom ld");
148         }
149 
150         return memory::desc(dims, desc.dt, strides);
151     }
152 
create_attr(const matmul_test_params_t & p,primitive_attr & attr,memory & scales_m,memory & zero_points_src_m,memory & zero_points_weights_m,memory & zero_points_dst_m,engine & eng)153     static void create_attr(const matmul_test_params_t &p, primitive_attr &attr,
154             memory &scales_m, memory &zero_points_src_m,
155             memory &zero_points_weights_m, memory &zero_points_dst_m,
156             engine &eng) {
157         const int ndims = (int)p.base.dst.dims.size();
158 
159         // output scales
160         if (p.attr.scale_flags != P::NONE) {
161             ASSERT_TRUE(p.attr.scale_flags & P::SCALES);
162 
163             unsigned scales_mask = p.attr.scale_flags & P::MASK_MASK;
164             ASSERT_TRUE(scales_mask == P::COMMON || scales_mask == P::PER_N);
165 
166             int mask = scales_mask == P::PER_N ? 1 << (ndims - 1) : 0;
167             memory::dim scale_size = mask ? p.base.dst.dims[ndims - 1] : 1;
168 
169             if (p.attr.scale_flags & P::RUNTIME) {
170                 attr.set_output_scales(mask, {DNNL_RUNTIME_F32_VAL});
171 
172                 scales_m = test::make_memory(
173                         {{scale_size}, memory::data_type::f32, {1}}, eng);
174                 auto s = map_memory<float>(scales_m);
175                 GTEST_EXPECT_NE(s, nullptr);
176                 for (memory::dim i = 0; i < scale_size; ++i)
177                     s[i] = 2.f;
178             } else {
179                 std::vector<float> scales(scale_size, 2.f);
180                 attr.set_output_scales(mask, scales);
181             }
182         }
183 
184         // zero points
185         auto handle_zero_points = [&](int arg, unsigned flags,
186                                           const matmul_base_t::md_t &md,
187                                           memory &zero_points_m) {
188             if (flags == P::NONE) return;
189 
190             ASSERT_TRUE(flags & P::ZERO_POINTS);
191             ASSERT_TRUE(flags & P::MATRIX_MASK);
192 
193             // sanity check
194             switch (arg) {
195                 case DNNL_ARG_SRC:
196                     ASSERT_TRUE((flags & P::MATRIX_MASK) == P::SRC);
197                     break;
198                 case DNNL_ARG_WEIGHTS:
199                     ASSERT_TRUE((flags & P::MATRIX_MASK) == P::WEIGHTS);
200                     break;
201                 case DNNL_ARG_DST:
202                     ASSERT_TRUE((flags & P::MATRIX_MASK) == P::DST);
203                     break;
204                 default: ASSERT_TRUE(!"unreachable");
205             }
206 
207             unsigned zero_points_mask = flags & P::MASK_MASK;
208             ASSERT_TRUE(zero_points_mask == P::COMMON
209                     || zero_points_mask == P::PER_N);
210             int mask = zero_points_mask == P::PER_N ? 1 << (ndims - 1) : 0;
211             memory::dim zero_points_size = mask ? md.dims[ndims - 1] : 1;
212 
213             if (flags & P::RUNTIME) {
214                 attr.set_zero_points(arg, mask, {DNNL_RUNTIME_S32_VAL});
215                 zero_points_m = test::make_memory(
216                         {{zero_points_size}, memory::data_type::s32, {1}}, eng);
217                 auto z = map_memory<int32_t>(zero_points_m);
218                 GTEST_EXPECT_NE(z, nullptr);
219                 for (memory::dim i = 0; i < zero_points_size; ++i)
220                     z[i] = (arg % 7) - 3;
221             } else {
222                 std::vector<int32_t> zero_points(
223                         zero_points_size, (arg % 7) - 3);
224                 attr.set_zero_points(arg, mask, zero_points);
225             }
226         };
227 
228         handle_zero_points(DNNL_ARG_SRC, p.attr.zero_points.src, p.base.src,
229                 zero_points_src_m);
230         handle_zero_points(DNNL_ARG_WEIGHTS, p.attr.zero_points.weights,
231                 p.base.weights, zero_points_weights_m);
232         handle_zero_points(DNNL_ARG_DST, p.attr.zero_points.dst, p.base.dst,
233                 zero_points_dst_m);
234 
235         // post ops
236         post_ops po;
237         for (auto post_op : p.attr.post_ops) {
238             switch (post_op.kind) {
239                 case primitive::kind::sum: po.append_sum(); break;
240                 case primitive::kind::eltwise:
241                     po.append_eltwise(1.f, post_op.alg, 0.f, 0.f);
242                     break;
243                 default: ASSERT_TRUE(!"unknown post op kind");
244             }
245         }
246         attr.set_post_ops(po);
247     }
248 
Test()249     void Test() {
250         matmul_test_params_t p
251                 = ::testing::TestWithParam<matmul_test_params_t>::GetParam();
252 
253         auto eng = get_test_engine();
254         auto strm = make_stream(eng);
255 
256         auto check_matrix_flags = [](unsigned flags, unsigned matrix) {
257             if (flags) { ASSERT_EQ(flags & P::MATRIX_MASK, matrix); }
258         };
259         check_matrix_flags(p.base.src.flags, P::SRC);
260         check_matrix_flags(p.base.weights.flags, P::WEIGHTS);
261         check_matrix_flags(p.base.dst.flags, P::DST);
262 
263         auto src_md = init_md(p.base.src);
264         auto weights_md = init_md(p.base.weights);
265         auto dst_md = init_md(p.base.dst);
266 
267         auto bia_md = memory::desc();
268         memory bia_m;
269         if (p.base.bia_dt != memory::data_type::undef) {
270             memory::dims bia_dims(p.base.dst.dims.size() - 1, 1);
271             bia_dims.push_back(p.base.dst.dims.back());
272             tag bia_tag = bia_dims.size() == 2 ? tag::ab : tag::abc;
273             bia_md = init_md({bia_dims, p.base.bia_dt, bia_tag,
274                     p.base.dst.flags & P::RUNTIME});
275             bia_m = test::make_memory(
276                     init_md({bia_dims, p.base.bia_dt, bia_tag}), eng);
277         }
278 
279         auto matmul_d = matmul::desc(src_md, weights_md, bia_md, dst_md);
280 
281         primitive_attr attr;
282         memory scales_m, zero_points_src_m, zero_points_weights_m,
283                 zero_points_dst_m;
284         create_attr(p, attr, scales_m, zero_points_src_m, zero_points_weights_m,
285                 zero_points_dst_m, eng);
286 
287         auto matmul_pd = matmul::primitive_desc(matmul_d, attr, eng);
288 
289         ASSERT_TRUE(matmul_pd.query_md(query::exec_arg_md, DNNL_ARG_SRC)
290                 == matmul_pd.src_desc());
291         ASSERT_TRUE(matmul_pd.query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS)
292                 == matmul_pd.weights_desc());
293         ASSERT_TRUE(matmul_pd.query_md(query::exec_arg_md, DNNL_ARG_BIAS)
294                 == matmul_pd.bias_desc());
295         ASSERT_TRUE(matmul_pd.query_md(query::exec_arg_md, DNNL_ARG_DST)
296                 == matmul_pd.dst_desc());
297 
298         auto matmul_p = matmul(matmul_pd);
299 
300         auto src_m = test::make_memory(init_md(p.base.src, true), eng);
301         auto weights_m = test::make_memory(init_md(p.base.weights, true), eng);
302         auto dst_m = test::make_memory(init_md(p.base.dst, true), eng);
303 
304         // Initialize memory to make sanitizers happy
305         auto set_to_zero = [](memory &m) {
306             if (m) {
307                 auto p = map_memory<char>(m);
308                 GTEST_EXPECT_NE(p, nullptr);
309                 memset(p, 0, m.get_desc().get_size());
310             }
311         };
312         set_to_zero(src_m);
313         set_to_zero(weights_m);
314         set_to_zero(dst_m);
315         set_to_zero(bia_m);
316 
317         matmul_p.execute(strm,
318                 {
319                         {DNNL_ARG_SRC, src_m},
320                         {DNNL_ARG_WEIGHTS, weights_m},
321                         {DNNL_ARG_BIAS, bia_m},
322                         {DNNL_ARG_DST, dst_m},
323                         {DNNL_ARG_ATTR_OUTPUT_SCALES, scales_m},
324                         {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC,
325                                 zero_points_src_m},
326                         {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS,
327                                 zero_points_weights_m},
328                         {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST,
329                                 zero_points_dst_m},
330                 });
331         strm.wait();
332     }
333 };
334 
// Parameters: {tensor dims (src/weights/dst), binary post-op dims,
// format tag, binary post-op data type, ndims}.
struct attr_test_t
    : public ::testing::TestWithParam<std::tuple<memory::dims, memory::dims,
              memory::format_tag, memory::data_type, int>> {};
338 
// Checks that adding attributes (per-channel output scales; sum, eltwise and
// binary post-ops) to a u8:s8:s8 matmul does not change which implementation
// gets dispatched: impl_info_str must be identical with and without them.
// Restricted to x64 CPU, where this property is expected to hold.
HANDLE_EXCEPTIONS_FOR_TEST_P(
        attr_test_t, TestMatmulShouldCallSameImplementationWithAttributes) {
    auto engine_kind = get_test_engine_kind();
    SKIP_IF(!DNNL_X64 || engine_kind != engine::kind::cpu,
            "Binary impl_info_str should be same only on x64 CPU");
    engine e {engine_kind, 0};

    const auto &tensor_dims = std::get<0>(GetParam());
    const auto format_tag = std::get<2>(GetParam());

    auto src_md = memory::desc(tensor_dims, memory::data_type::u8, format_tag);
    auto weights_md
            = memory::desc(tensor_dims, memory::data_type::s8, format_tag);
    auto dst_md = memory::desc(tensor_dims, memory::data_type::s8, format_tag);
    auto bia_md = memory::desc();

    auto matmul_d = matmul::desc(src_md, weights_md, bia_md, dst_md);

    // Reference point: implementation chosen without any attributes.
    std::string impl_info_no_postops;
    auto matmul_pd = matmul::primitive_desc(matmul_d, e);
    ASSERT_NO_THROW(impl_info_no_postops = matmul_pd.impl_info_str(););

    dnnl::primitive_attr attr;
    const float scale = 1.f;
    const float alpha = 1.f;
    const float beta = 1.f;
    const float oscale = 1.5f;

    const int ndims = std::get<4>(GetParam());
    // per-channel output scales
    std::vector<float> oscales(tensor_dims[1], oscale);
    attr.set_output_scales(1 << (ndims - 1), oscales);

    dnnl::post_ops ops;
    ops.append_sum(1.0);
    ops.append_eltwise(scale, algorithm::eltwise_relu, alpha, beta);

    const auto &binary_po_tensor_dims = std::get<1>(GetParam());
    const auto &binary_po_mem_dt = std::get<3>(GetParam());
    SKIP_IF(unsupported_data_type(binary_po_mem_dt),
            "Engine does not support this data type.");
    memory::desc src1_po_md(
            binary_po_tensor_dims, binary_po_mem_dt, format_tag);
    ops.append_binary(algorithm::binary_add, src1_po_md);

    attr.set_post_ops(ops);

    std::string impl_info_with_postops;

    // Same descriptor with attributes must land on the same implementation.
    matmul_pd = matmul::primitive_desc(matmul_d, attr, e);
    ASSERT_NO_THROW(impl_info_with_postops = matmul_pd.impl_info_str(););
    ASSERT_EQ(impl_info_no_postops, impl_info_with_postops);
}
392 
393 /********************************* TEST CASES *********************************/
394 
395 using iface = matmul_iface_test_t;
396 
397 using data_type = memory::data_type;
398 
TEST_P(iface,TestsMatMul)399 TEST_P(iface, TestsMatMul) {}
400 
// Negative ("expected failure") cases: each entry is expected to make
// primitive creation fail with the recorded status.
static auto cases_ef = []() {
    std::vector<matmul_test_params_t> cases;

    // inconsistent dims
    cases.push_back(
            {{{{10, 1}, data_type::f32, tag::ab},
                     {{2, 20}, data_type::f32, tag::ab},
                     {{10, 20}, data_type::f32, tag::ab}, data_type::undef},
                    {}, true, dnnl_invalid_arguments});
    cases.push_back({{{{10, 1}, data_type::f32, tag::ab},
                             {{1, 20}, data_type::f32, tag::ab},
                             {{10, 21}, data_type::f32, tag::ab}},
            {}, true, dnnl_invalid_arguments});
    cases.push_back({{{{10, 1}, data_type::f32, tag::ab},
                             {{1, 1, 20}, data_type::f32, tag::abc},
                             {{10, 20}, data_type::f32, tag::ab}},
            {}, true, dnnl_invalid_arguments});
    cases.push_back({{{{1, 10, 1}, data_type::u8, tag::abc},
                             {{1, 1, 2}, data_type::s8, tag::abc},
                             {{1, 11, 2}, data_type::s8, tag::abc}},
            {}, true, dnnl_invalid_arguments});

    // inconsistent wrt runtime dim vals
    cases.push_back(
            {{{{3, 10, 10}, data_type::f32, tag::abc},
                     {{DNNL_RUNTIME_DIM_VAL, 10, 10}, data_type::f32, tag::abc},
                     {{DNNL_RUNTIME_DIM_VAL, 10, 10}, data_type::f32,
                             tag::abc}},
                    {}, true, dnnl_invalid_arguments});

    // inconsistent wrt broadcasting
    cases.push_back({{{{3, 10, 10}, data_type::f32, tag::abc},
                             {{1, 10, 10}, data_type::f32, tag::abc},
                             {{1, 10, 10}, data_type::f32, tag::abc}},
            {}, true, dnnl_invalid_arguments});

    // no broadcasting on m/k/n dims
    cases.push_back({{{{10, 10}, data_type::f32, tag::ab},
                             {{1, 1}, data_type::f32, tag::ab},
                             {{10, 10}, data_type::f32, tag::ab}},
            {}, true, dnnl_invalid_arguments});

    // f32 data and zero-points
    cases.push_back({{{{10, 1}, data_type::f32, tag::ab},
                             {{1, 20}, data_type::f32, tag::ab},
                             {{10, 20}, data_type::f32, tag::ab}},
            {P::NONE, {P::ZERO_POINTS | P::SRC | P::COMMON}}, true,
            dnnl_unimplemented});

    // bf16 data and zero-points
    cases.push_back({{{{10, 1}, data_type::bf16, tag::ab},
                             {{1, 20}, data_type::bf16, tag::ab},
                             {{10, 20}, data_type::bf16, tag::ab}},
            {P::NONE, {P::ZERO_POINTS | P::SRC | P::COMMON}}, true,
            dnnl_unimplemented});
    // unimplemented data types
    if (get_test_engine_kind() == engine::kind::cpu) {
        cases.push_back(
                {{{{10, 1}, data_type::f32, tag::ab},
                         {{1, 20}, data_type::f32, tag::ab},
                         {{10, 20}, data_type::f32, tag::ab}, data_type::u8},
                        {}, true, dnnl_unimplemented});
    }
    // XXX: disable assert in type_helpers.hpp: default_accum_data_type(...)
    // cases.push_back({{{{10, 1}, data_type::u8, tag::ab}, {{1, 20},
    // data_type::u8, tag::ab},
    //                           {{10, 20}, data_type::u8, tag::ab}},
    //        {}, true, dnnl_unimplemented});

    // unimplemented formats (GPU only)
    cases.push_back({{{{16, 16}, data_type::f32, tag::AB8a4b},
                             {{16, 16}, data_type::f32, tag::AB8a4b},
                             {{16, 16}, data_type::f32, tag::AB8a4b}},
            {}, true, dnnl_unimplemented});

    return ::testing::ValuesIn(cases);
};
INSTANTIATE_TEST_SUITE_P(EF, iface, cases_ef());
479 
// Positive cases for floating-point data types (all three matrices share the
// same data type `dt`); instantiated below for f32 and, on GPU, f16/bf16.
static auto cases_f = [](memory::data_type dt) {
    std::vector<matmul_test_params_t> cases;

    // simple case
    cases.push_back({{{{10, 2}, dt, tag::ab}, {{2, 20}, dt, tag::ab},
                             {{10, 20}, dt, tag::ab}, data_type::undef},
            {}});
    // simple case + leading dimensions
    cases.push_back({{{{10, 1}, dt, tag::ab, P::SRC | P::LEADING_DIM},
                             {{1, 3}, dt, tag::ba},
                             {{10, 3}, dt, tag::ab, P::DST | P::LEADING_DIM},
                             data_type::f32},
            {}});
    // simple case + leading dimensions + runtime dims
    cases.push_back(
            {{{{1, 10}, dt, tag::ab, P::SRC | P::LEADING_DIM | P::RUNTIME},
                     {{10, 2}, dt, tag::ba, P::WEIGHTS | P::RUNTIME},
                     {{1, 2}, dt, tag::ab,
                             P::DST | P::LEADING_DIM | P::RUNTIME},
                     data_type::f32},
                    {}});

    // output scales
    cases.push_back({{{{10, 2}, dt, tag::ab}, {{2, 20}, dt, tag::ab},
                             {{10, 20}, dt, tag::ab}, data_type::undef},
            {P::SCALES | P::COMMON}});
    // output scales + per_n + runtime
    cases.push_back({{{{10, 2}, dt, tag::ab}, {{2, 20}, dt, tag::ab},
                             {{10, 20}, dt, tag::ab}, data_type::undef},
            {P::SCALES | P::PER_N | P::RUNTIME}});

    // post-ops
    cases.push_back({{{{10, 1}, dt, tag::ab}, {{1, 20}, dt, tag::ab},
                             {{10, 20}, dt, tag::ab}},
            {P::NONE, {},
                    {{primitive::kind::eltwise, algorithm::eltwise_relu}}}});
    // multiple post-ops
    cases.push_back({{{{10, 2}, dt, tag::ab}, {{2, 20}, dt, tag::ab},
                             {{10, 20}, dt, tag::ab}},
            {P::SCALES | P::COMMON, {},
                    {{primitive::kind::sum},
                            {primitive::kind::eltwise,
                                    algorithm::eltwise_relu}}}});

    // gemm like: output scale + post-ops(sum)
    cases.push_back({{{{10, 1}, dt, tag::ab}, {{1, 20}, dt, tag::ab},
                             {{10, 20}, dt, tag::ab}, data_type::f32},
            {P::SCALES | P::COMMON, {}, {{primitive::kind::sum}}}});
    // gemm like: output scale + post-ops(sum) + all runtime
    cases.push_back({{{{10, 1}, dt, tag::ab, P::SRC | P::RUNTIME},
                             {{1, 20}, dt, tag::ab, P::WEIGHTS | P::RUNTIME},
                             {{10, 20}, dt, tag::ab, P::DST | P::RUNTIME},
                             data_type::f32},
            {P::SCALES | P::COMMON | P::RUNTIME, {},
                    {{primitive::kind::sum}}}});

    return ::testing::ValuesIn(cases);
};

GPU_INSTANTIATE_TEST_SUITE_P(Generic_f16, iface, cases_f(data_type::f16));
GPU_INSTANTIATE_TEST_SUITE_P(Generic_bf16, iface, cases_f(data_type::bf16));
INSTANTIATE_TEST_SUITE_P(Generic_f32, iface, cases_f(data_type::f32));
542 
// Positive cases for quantized (int8) matmul: src is s8/u8, weights are
// always s8, dst data type varies; exercises scales and zero points.
static auto cases_x8 = [](memory::data_type src_dt, memory::data_type dst_dt) {
    std::vector<matmul_test_params_t> cases;

    // simple case
    cases.push_back(
            {{{{10, 2}, src_dt, tag::ba}, {{2, 20}, data_type::s8, tag::ab},
                     {{10, 20}, dst_dt, tag::ab}, data_type::undef},
                    {}});
    // simple case + leading dimensions
    cases.push_back(
            {{{{10, 1}, src_dt, tag::ba, P::SRC | P::LEADING_DIM},
                     {{1, 3}, data_type::s8, tag::ba},
                     {{10, 3}, dst_dt, tag::ab, P::DST | P::LEADING_DIM},
                     data_type::s8},
                    {}});
    // simple case + leading dimensions + runtime dims
    cases.push_back(
            {{{{1, 10}, src_dt, tag::ba, P::SRC | P::LEADING_DIM | P::RUNTIME},
                     {{10, 2}, data_type::s8, tag::ba, P::WEIGHTS | P::RUNTIME},
                     {{1, 2}, dst_dt, tag::ab,
                             P::DST | P::LEADING_DIM | P::RUNTIME},
                     data_type::u8},
                    {}});

    // output scales
    cases.push_back(
            {{{{10, 2}, src_dt, tag::ab}, {{2, 20}, data_type::s8, tag::ab},
                     {{10, 20}, dst_dt, tag::ab}, data_type::undef},
                    {P::SCALES | P::COMMON}});
    // output scales + per_n + runtime
    cases.push_back(
            {{{{10, 2}, src_dt, tag::ab}, {{2, 20}, data_type::s8, tag::ab},
                     {{10, 20}, dst_dt, tag::ab}, data_type::undef},
                    {P::SCALES | P::PER_N | P::RUNTIME}});

    // zero points
    cases.push_back(
            {{{{10, 2}, src_dt, tag::ba}, {{2, 20}, data_type::s8, tag::ab},
                     {{10, 20}, dst_dt, tag::ab}, data_type::f32},
                    {P::SCALES | P::COMMON,
                            {P::ZERO_POINTS | P::SRC | P::COMMON,
                                    P::ZERO_POINTS | P::WEIGHTS | P::COMMON,
                                    P::ZERO_POINTS | P::DST | P::COMMON}}});

    // zero points + runtime
    cases.push_back(
            {{{{10, 2}, src_dt, tag::ba}, {{2, 20}, data_type::s8, tag::ab},
                     {{10, 20}, dst_dt, tag::ab}, data_type::f32},
                    {P::SCALES | P::COMMON | P::RUNTIME,
                            {P::ZERO_POINTS | P::SRC | P::COMMON, P::NONE,
                                    P::ZERO_POINTS | P::DST | P::COMMON
                                            | P::RUNTIME}}});

    // per_dim_1 zero points + runtime
    cases.push_back({{{{10, 2}, src_dt, tag::ba},
                             {{2, 20}, data_type::s8, tag::ab},
                             {{10, 20}, dst_dt, tag::ab}, data_type::f32},
            {P::SCALES | P::COMMON | P::RUNTIME,
                    {P::ZERO_POINTS | P::SRC | P::PER_N | P::RUNTIME, P::NONE,
                            P::ZERO_POINTS | P::DST | P::PER_N | P::RUNTIME}}});
    // post-ops
    cases.push_back({{{{10, 1}, src_dt, tag::ab},
                             {{1, 20}, data_type::s8, tag::ab},
                             {{10, 20}, dst_dt, tag::ab}},
            {P::NONE, {},
                    {{primitive::kind::eltwise, algorithm::eltwise_relu}}}});
    // multiple post-ops
    cases.push_back(
            {{{{10, 2}, src_dt, tag::ab}, {{2, 20}, data_type::s8, tag::ab},
                     {{10, 20}, dst_dt, tag::ab}, data_type::f32},
                    {P::SCALES | P::COMMON, {},
                            {{primitive::kind::sum},
                                    {primitive::kind::eltwise,
                                            algorithm::eltwise_relu}}}});

    // igemm like: output scale + post-ops(sum)
    cases.push_back({{{{10, 1}, src_dt, tag::ab},
                             {{1, 20}, data_type::s8, tag::ab},
                             {{10, 20}, dst_dt, tag::ab}, data_type::s8},
            {P::SCALES | P::COMMON,
                    {P::ZERO_POINTS | P::SRC | P::COMMON, P::NONE,
                            P::ZERO_POINTS | P::DST | P::COMMON | P::RUNTIME},
                    {{primitive::kind::sum}}}});
    // igemm like: output scale + post-ops(sum) + all runtime
    cases.push_back({{{{10, 2}, src_dt, tag::ba},
                             {{2, 20}, data_type::s8, tag::ba},
                             {{10, 20}, dst_dt, tag::ab}, data_type::s8},
            {P::SCALES | P::PER_N | P::RUNTIME,
                    {P::ZERO_POINTS | P::SRC | P::COMMON | P::RUNTIME,
                            P::ZERO_POINTS | P::WEIGHTS | P::COMMON
                                    | P::RUNTIME,
                            P::ZERO_POINTS | P::DST | P::COMMON | P::RUNTIME},
                    {{primitive::kind::sum}}}});

    return ::testing::ValuesIn(cases);
};
INSTANTIATE_TEST_SUITE_P(
        Generic_s8s8s32, iface, cases_x8(data_type::s8, data_type::s32));
INSTANTIATE_TEST_SUITE_P(
        Generic_u8s8u8, iface, cases_x8(data_type::u8, data_type::u8));
643 
// Shapes for the attribute-dispatch test (attr_test_t). Tuple layout:
// {{src0, src1, dst same_dim}, { binary post-op dim }},
// format_tag, post-op data type, ndims
INSTANTIATE_TEST_SUITE_P(TensorDims, attr_test_t,
        ::testing::Values(
                std::make_tuple(memory::dims {3, 2, 16, 16},
                        memory::dims {3, 1, 16, 16}, tag::abcd,
                        memory::data_type::f32, 4),
                std::make_tuple(memory::dims {9, 9, 64, 64},
                        memory::dims {9, 1, 64, 64}, tag::abcd,
                        memory::data_type::f32, 4),
                std::make_tuple(memory::dims {3, 2, 16, 16},
                        memory::dims {3, 2, 16, 16}, tag::abcd,
                        memory::data_type::f32, 4),
                std::make_tuple(memory::dims {2, 10, 10, 10},
                        memory::dims {2, 10, 10, 10}, tag::abcd,
                        memory::data_type::bf16, 4)));
660 
661 } // namespace dnnl
662