1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "dnnl_test_common.hpp"
18 #include "gtest/gtest.h"
19
20 #include "oneapi/dnnl/dnnl.hpp"
21
22 #include <vector>
23
24 namespace dnnl {
25
namespace P {
// Bit-flag vocabulary used by the test cases below to describe, per tensor
// and per attribute, which features a case exercises. Declared `constexpr`
// so the flags are true compile-time constants and cannot be modified.

// Common
constexpr unsigned NONE = 0u;

// The quantity (dim, scale or zero point) is passed at run time.
constexpr unsigned RUNTIME = 1u << 31;

constexpr unsigned SCALES = 1u << 30;
constexpr unsigned ZERO_POINTS = 1u << 29;

// Use non-dense (padded) leading-dimension strides for the tensor.
constexpr unsigned LEADING_DIM = 1u << 28;

// matrices indices: 1 .. 7
// bits reserved: 20 .. 22
constexpr unsigned MATRIX_MASK = 7u << 20;
constexpr unsigned SRC = 1u << 20;
constexpr unsigned WEIGHTS = 2u << 20;
constexpr unsigned DST = 3u << 20;

// scales and zero points: 1 .. 3
// bits reserved: 0 .. 1
constexpr unsigned MASK_MASK = 3u << 0;
constexpr unsigned COMMON = 1u << 0;
constexpr unsigned PER_N = 1u << 1;
} // namespace P
50
// Shape/type/layout description of one matmul problem (src x weights -> dst).
struct matmul_base_t {
    // Descriptor for a single tensor: its dims, data type, memory format
    // tag, and P:: flag bits (RUNTIME, LEADING_DIM, matrix id, ...).
    struct md_t {
        memory::dims dims;
        memory::data_type dt;
        memory::format_tag tag;
        unsigned flags;
    } src, weights, dst;
    // Bias data type; memory::data_type::undef means "no bias".
    memory::data_type bia_dt;
};
60
// TODO: a way to generalize?
// Attribute description of one matmul test case: output scales, per-argument
// zero points, and a list of post-ops.
struct matmul_attr_t {
    // ctor {P::SCALE, {P::SRC, P::WEIGHTS, P::DST}, {P::POST_OPS, ...}}

    // P:: flag bits for output scales (P::SCALES | COMMON/PER_N | RUNTIME).
    unsigned scale_flags;

    // P:: flag bits per argument (P::ZERO_POINTS | matrix id | mask bits).
    struct zero_points_t {
        unsigned src, weights, dst;
    } zero_points;

    // One post-op entry; `alg` is only meaningful for eltwise post-ops.
    struct post_op_t {
        primitive::kind kind;
        algorithm alg;
    };

    std::vector<post_op_t> post_ops;
};
78
// Full parameter pack for one parametrized matmul interface test.
struct matmul_test_params_t {
    matmul_base_t base;
    matmul_attr_t attr;

    // When `expect_to_fail` is set, primitive creation must fail with
    // `expected_status` (checked by catch_expected_failures()).
    bool expect_to_fail;
    dnnl_status_t expected_status;
};

// Shorthand used by the case tables below.
using tag = memory::format_tag;
88
// Value-parametrized fixture exercising the matmul primitive API surface:
// descriptor/attr creation, primitive-desc queries and a single execution.
class matmul_iface_test_t
    : public ::testing::TestWithParam<matmul_test_params_t> {
protected:
    // Skips cases the current engine cannot run, then dispatches to Test(),
    // converting expected failures into a status check.
    void SetUp() override {
        matmul_test_params_t p
                = ::testing::TestWithParam<decltype(p)>::GetParam();

        SKIP_IF(unsupported_data_type(p.base.src.dt),
                "Engine does not support this data type.");
        SKIP_IF(unsupported_data_type(p.base.weights.dt),
                "Engine does not support this data type.");
        SKIP_IF(unsupported_data_type(p.base.dst.dt),
                "Engine does not support this data type.");
        SKIP_IF(unsupported_data_type(p.base.bia_dt),
                "Engine does not support this data type.");
        SKIP_IF(get_test_engine_kind() == engine::kind::gpu
                        && ((p.attr.zero_points.src & P::PER_N)
                                || (p.attr.zero_points.dst & P::PER_N)),
                "Per dimensional zero points are not supported on GPU");
        SKIP_IF(get_test_engine_kind() == engine::kind::cpu
                        && p.base.src.tag == impl::format_tag::AB8a4b,
                "Don't test blocked formats on CPU");

        SKIP_IF_CUDA((p.attr.zero_points.src != 0 || p.attr.zero_points.dst != 0
                             || p.attr.zero_points.weights != 0),
                "Zero points not supported for CUDA");

        SKIP_IF_CUDA((p.attr.scale_flags & P::MASK_MASK) == P::PER_N,
                "Per dimensional scaling is not supported for CUDA");

        // [=] copies `p` into the closure (and captures `this` for Test()).
        catch_expected_failures(
                [=]() { Test(); }, p.expect_to_fail, p.expected_status, false);
    }

    // Builds a memory::desc from an md_t. RUNTIME flags turn every dim into
    // DNNL_RUNTIME_DIM_VAL; LEADING_DIM flags produce padded (non-dense)
    // strides to emulate a custom leading dimension.
    // use `force_no_rt = true` when create final memory
    static memory::desc init_md(
            const matmul_base_t::md_t &desc, bool force_no_rt = false) {
        const bool runtime = force_no_rt ? false : (desc.flags & P::RUNTIME);
        const bool use_ld = (desc.flags & P::LEADING_DIM);

        memory::dims dims = desc.dims;
        if (runtime)
            dims = memory::dims(desc.dims.size(), DNNL_RUNTIME_DIM_VAL);

        if (runtime || use_ld == false)
            return memory::desc(dims, desc.dt, desc.tag);

        // Pad the innermost stride by +1 to get a non-trivial leading dim.
        memory::dims strides;
        switch (desc.tag) {
            case tag::ab: strides = {dims[1] + 1, 1}; break;
            case tag::ba: strides = {1, dims[0] + 1}; break;
            case tag::abc:
                strides = {dims[1] * (dims[2] + 1) + 1, dims[2] + 1, 1};
                break;
            // NOTE(review): acb uses the same strides as abc, so the custom-ld
            // layout is identical for both tags — confirm this is intentional.
            case tag::acb:
                strides = {dims[1] * (dims[2] + 1) + 1, dims[2] + 1, 1};
                break;
            default:
                throw std::invalid_argument("tag doesn't support custom ld");
        }

        return memory::desc(dims, desc.dt, strides);
    }

    // Populates `attr` (output scales, zero points, post-ops) from the case
    // description; allocates runtime scale/zero-point memories on `eng` and
    // fills them with fixed values (scales = 2.f, zps = (arg % 7) - 3).
    static void create_attr(const matmul_test_params_t &p, primitive_attr &attr,
            memory &scales_m, memory &zero_points_src_m,
            memory &zero_points_weights_m, memory &zero_points_dst_m,
            engine &eng) {
        const int ndims = (int)p.base.dst.dims.size();

        // output scales
        if (p.attr.scale_flags != P::NONE) {
            ASSERT_TRUE(p.attr.scale_flags & P::SCALES);

            unsigned scales_mask = p.attr.scale_flags & P::MASK_MASK;
            ASSERT_TRUE(scales_mask == P::COMMON || scales_mask == P::PER_N);

            // PER_N == per last (N) dimension of dst.
            int mask = scales_mask == P::PER_N ? 1 << (ndims - 1) : 0;
            memory::dim scale_size = mask ? p.base.dst.dims[ndims - 1] : 1;

            if (p.attr.scale_flags & P::RUNTIME) {
                attr.set_output_scales(mask, {DNNL_RUNTIME_F32_VAL});

                scales_m = test::make_memory(
                        {{scale_size}, memory::data_type::f32, {1}}, eng);
                auto s = map_memory<float>(scales_m);
                GTEST_EXPECT_NE(s, nullptr);
                for (memory::dim i = 0; i < scale_size; ++i)
                    s[i] = 2.f;
            } else {
                std::vector<float> scales(scale_size, 2.f);
                attr.set_output_scales(mask, scales);
            }
        }

        // zero points
        auto handle_zero_points = [&](int arg, unsigned flags,
                                          const matmul_base_t::md_t &md,
                                          memory &zero_points_m) {
            if (flags == P::NONE) return;

            ASSERT_TRUE(flags & P::ZERO_POINTS);
            ASSERT_TRUE(flags & P::MATRIX_MASK);

            // sanity check: flags must name the same matrix as `arg`
            switch (arg) {
                case DNNL_ARG_SRC:
                    ASSERT_TRUE((flags & P::MATRIX_MASK) == P::SRC);
                    break;
                case DNNL_ARG_WEIGHTS:
                    ASSERT_TRUE((flags & P::MATRIX_MASK) == P::WEIGHTS);
                    break;
                case DNNL_ARG_DST:
                    ASSERT_TRUE((flags & P::MATRIX_MASK) == P::DST);
                    break;
                default: ASSERT_TRUE(!"unreachable");
            }

            unsigned zero_points_mask = flags & P::MASK_MASK;
            ASSERT_TRUE(zero_points_mask == P::COMMON
                    || zero_points_mask == P::PER_N);
            int mask = zero_points_mask == P::PER_N ? 1 << (ndims - 1) : 0;
            memory::dim zero_points_size = mask ? md.dims[ndims - 1] : 1;

            if (flags & P::RUNTIME) {
                attr.set_zero_points(arg, mask, {DNNL_RUNTIME_S32_VAL});
                zero_points_m = test::make_memory(
                        {{zero_points_size}, memory::data_type::s32, {1}}, eng);
                auto z = map_memory<int32_t>(zero_points_m);
                GTEST_EXPECT_NE(z, nullptr);
                for (memory::dim i = 0; i < zero_points_size; ++i)
                    z[i] = (arg % 7) - 3;
            } else {
                std::vector<int32_t> zero_points(
                        zero_points_size, (arg % 7) - 3);
                attr.set_zero_points(arg, mask, zero_points);
            }
        };

        handle_zero_points(DNNL_ARG_SRC, p.attr.zero_points.src, p.base.src,
                zero_points_src_m);
        handle_zero_points(DNNL_ARG_WEIGHTS, p.attr.zero_points.weights,
                p.base.weights, zero_points_weights_m);
        handle_zero_points(DNNL_ARG_DST, p.attr.zero_points.dst, p.base.dst,
                zero_points_dst_m);

        // post ops
        post_ops po;
        for (auto post_op : p.attr.post_ops) {
            switch (post_op.kind) {
                case primitive::kind::sum: po.append_sum(); break;
                case primitive::kind::eltwise:
                    po.append_eltwise(1.f, post_op.alg, 0.f, 0.f);
                    break;
                default: ASSERT_TRUE(!"unknown post op kind");
            }
        }
        attr.set_post_ops(po);
    }

    // Builds descriptors and attributes, creates the primitive, verifies
    // exec_arg_md queries round-trip, then runs one execution on zeroed
    // memories (values are irrelevant — only the API behavior is tested).
    void Test() {
        matmul_test_params_t p
                = ::testing::TestWithParam<matmul_test_params_t>::GetParam();

        auto eng = get_test_engine();
        auto strm = make_stream(eng);

        // A tensor's flags, when set, must carry its own matrix id.
        auto check_matrix_flags = [](unsigned flags, unsigned matrix) {
            if (flags) { ASSERT_EQ(flags & P::MATRIX_MASK, matrix); }
        };
        check_matrix_flags(p.base.src.flags, P::SRC);
        check_matrix_flags(p.base.weights.flags, P::WEIGHTS);
        check_matrix_flags(p.base.dst.flags, P::DST);

        auto src_md = init_md(p.base.src);
        auto weights_md = init_md(p.base.weights);
        auto dst_md = init_md(p.base.dst);

        // Bias broadcasts over all but the last (N) dimension of dst.
        auto bia_md = memory::desc();
        memory bia_m;
        if (p.base.bia_dt != memory::data_type::undef) {
            memory::dims bia_dims(p.base.dst.dims.size() - 1, 1);
            bia_dims.push_back(p.base.dst.dims.back());
            tag bia_tag = bia_dims.size() == 2 ? tag::ab : tag::abc;
            bia_md = init_md({bia_dims, p.base.bia_dt, bia_tag,
                    p.base.dst.flags & P::RUNTIME});
            bia_m = test::make_memory(
                    init_md({bia_dims, p.base.bia_dt, bia_tag}), eng);
        }

        auto matmul_d = matmul::desc(src_md, weights_md, bia_md, dst_md);

        primitive_attr attr;
        memory scales_m, zero_points_src_m, zero_points_weights_m,
                zero_points_dst_m;
        create_attr(p, attr, scales_m, zero_points_src_m, zero_points_weights_m,
                zero_points_dst_m, eng);

        auto matmul_pd = matmul::primitive_desc(matmul_d, attr, eng);

        // exec_arg_md queries must match the dedicated desc accessors.
        ASSERT_TRUE(matmul_pd.query_md(query::exec_arg_md, DNNL_ARG_SRC)
                == matmul_pd.src_desc());
        ASSERT_TRUE(matmul_pd.query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS)
                == matmul_pd.weights_desc());
        ASSERT_TRUE(matmul_pd.query_md(query::exec_arg_md, DNNL_ARG_BIAS)
                == matmul_pd.bias_desc());
        ASSERT_TRUE(matmul_pd.query_md(query::exec_arg_md, DNNL_ARG_DST)
                == matmul_pd.dst_desc());

        auto matmul_p = matmul(matmul_pd);

        // Real memories use concrete dims/strides (force_no_rt = true).
        auto src_m = test::make_memory(init_md(p.base.src, true), eng);
        auto weights_m = test::make_memory(init_md(p.base.weights, true), eng);
        auto dst_m = test::make_memory(init_md(p.base.dst, true), eng);

        // Initialize memory to make sanitizers happy
        auto set_to_zero = [](memory &m) {
            if (m) {
                auto p = map_memory<char>(m);
                GTEST_EXPECT_NE(p, nullptr);
                memset(p, 0, m.get_desc().get_size());
            }
        };
        set_to_zero(src_m);
        set_to_zero(weights_m);
        set_to_zero(dst_m);
        set_to_zero(bia_m);

        matmul_p.execute(strm,
                {
                        {DNNL_ARG_SRC, src_m},
                        {DNNL_ARG_WEIGHTS, weights_m},
                        {DNNL_ARG_BIAS, bia_m},
                        {DNNL_ARG_DST, dst_m},
                        {DNNL_ARG_ATTR_OUTPUT_SCALES, scales_m},
                        {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC,
                                zero_points_src_m},
                        {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS,
                                zero_points_weights_m},
                        {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST,
                                zero_points_dst_m},
                });
        strm.wait();
    }
};
334
// Fixture for attribute/impl-dispatch tests. Tuple layout:
// {tensor dims, binary post-op dims, format tag, post-op data type, ndims}.
struct attr_test_t
    : public ::testing::TestWithParam<std::tuple<memory::dims, memory::dims,
              memory::format_tag, memory::data_type, int>> {};
338
// Checks that attaching attributes (per-channel output scales, sum/eltwise/
// binary post-ops) does not change which implementation is dispatched:
// impl_info_str() must be identical with and without attributes.
HANDLE_EXCEPTIONS_FOR_TEST_P(
        attr_test_t, TestMatmulShouldCallSameImplementationWithAttributes) {
    auto engine_kind = get_test_engine_kind();
    SKIP_IF(!DNNL_X64 || engine_kind != engine::kind::cpu,
            "Binary impl_info_str should be same only on x64 CPU");
    engine e {engine_kind, 0};

    const auto &tensor_dims = std::get<0>(GetParam());
    const auto format_tag = std::get<2>(GetParam());

    // u8 src, s8 weights/dst — an int8 configuration.
    auto src_md = memory::desc(tensor_dims, memory::data_type::u8, format_tag);
    auto weights_md
            = memory::desc(tensor_dims, memory::data_type::s8, format_tag);
    auto dst_md = memory::desc(tensor_dims, memory::data_type::s8, format_tag);
    auto bia_md = memory::desc();

    auto matmul_d = matmul::desc(src_md, weights_md, bia_md, dst_md);

    // Baseline: implementation name without any attributes.
    std::string impl_info_no_postops;
    auto matmul_pd = matmul::primitive_desc(matmul_d, e);
    ASSERT_NO_THROW(impl_info_no_postops = matmul_pd.impl_info_str(););

    dnnl::primitive_attr attr;
    const float scale = 1.f;
    const float alpha = 1.f;
    const float beta = 1.f;
    const float oscale = 1.5f;

    const int ndims = std::get<4>(GetParam());
    // per-channel output scales
    std::vector<float> oscales(tensor_dims[1], oscale);
    attr.set_output_scales(1 << (ndims - 1), oscales);

    dnnl::post_ops ops;
    ops.append_sum(1.0);
    ops.append_eltwise(scale, algorithm::eltwise_relu, alpha, beta);

    const auto &binary_po_tensor_dims = std::get<1>(GetParam());
    const auto &binary_po_mem_dt = std::get<3>(GetParam());
    SKIP_IF(unsupported_data_type(binary_po_mem_dt),
            "Engine does not support this data type.");
    memory::desc src1_po_md(
            binary_po_tensor_dims, binary_po_mem_dt, format_tag);
    ops.append_binary(algorithm::binary_add, src1_po_md);

    attr.set_post_ops(ops);

    std::string impl_info_with_postops;

    matmul_pd = matmul::primitive_desc(matmul_d, attr, e);
    ASSERT_NO_THROW(impl_info_with_postops = matmul_pd.impl_info_str(););
    ASSERT_EQ(impl_info_no_postops, impl_info_with_postops);
}
392
/********************************* TEST CASES *********************************/

// Short aliases for the case tables below.
using iface = matmul_iface_test_t;

using data_type = memory::data_type;

// The fixture's SetUp()/Test() do all the work; the body is empty.
TEST_P(iface, TestsMatMul) {}
400
// Expected-failure cases: each entry must make primitive creation fail with
// the listed status (dnnl_invalid_arguments / dnnl_unimplemented).
static auto cases_ef = []() {
    std::vector<matmul_test_params_t> cases;

    // inconsistent dims (K of src vs weights, N of weights vs dst, ndims)
    cases.push_back(
            {{{{10, 1}, data_type::f32, tag::ab},
                     {{2, 20}, data_type::f32, tag::ab},
                     {{10, 20}, data_type::f32, tag::ab}, data_type::undef},
                    {}, true, dnnl_invalid_arguments});
    cases.push_back({{{{10, 1}, data_type::f32, tag::ab},
                             {{1, 20}, data_type::f32, tag::ab},
                             {{10, 21}, data_type::f32, tag::ab}},
            {}, true, dnnl_invalid_arguments});
    cases.push_back({{{{10, 1}, data_type::f32, tag::ab},
                             {{1, 1, 20}, data_type::f32, tag::abc},
                             {{10, 20}, data_type::f32, tag::ab}},
            {}, true, dnnl_invalid_arguments});
    cases.push_back({{{{1, 10, 1}, data_type::u8, tag::abc},
                             {{1, 1, 2}, data_type::s8, tag::abc},
                             {{1, 11, 2}, data_type::s8, tag::abc}},
            {}, true, dnnl_invalid_arguments});

    // inconsistent wrt runtime dim vals
    cases.push_back(
            {{{{3, 10, 10}, data_type::f32, tag::abc},
                     {{DNNL_RUNTIME_DIM_VAL, 10, 10}, data_type::f32, tag::abc},
                     {{DNNL_RUNTIME_DIM_VAL, 10, 10}, data_type::f32,
                             tag::abc}},
                    {}, true, dnnl_invalid_arguments});

    // inconsistent wrt broadcasting (batch of dst can't shrink)
    cases.push_back({{{{3, 10, 10}, data_type::f32, tag::abc},
                             {{1, 10, 10}, data_type::f32, tag::abc},
                             {{1, 10, 10}, data_type::f32, tag::abc}},
            {}, true, dnnl_invalid_arguments});

    // no broadcasting on m/k/n dims
    cases.push_back({{{{10, 10}, data_type::f32, tag::ab},
                             {{1, 1}, data_type::f32, tag::ab},
                             {{10, 10}, data_type::f32, tag::ab}},
            {}, true, dnnl_invalid_arguments});

    // f32 data and zero-points
    cases.push_back({{{{10, 1}, data_type::f32, tag::ab},
                             {{1, 20}, data_type::f32, tag::ab},
                             {{10, 20}, data_type::f32, tag::ab}},
            {P::NONE, {P::ZERO_POINTS | P::SRC | P::COMMON}}, true,
            dnnl_unimplemented});

    // bf16 data and zero-points
    cases.push_back({{{{10, 1}, data_type::bf16, tag::ab},
                             {{1, 20}, data_type::bf16, tag::ab},
                             {{10, 20}, data_type::bf16, tag::ab}},
            {P::NONE, {P::ZERO_POINTS | P::SRC | P::COMMON}}, true,
            dnnl_unimplemented});
    // unimplemented data types
    if (get_test_engine_kind() == engine::kind::cpu) {
        cases.push_back(
                {{{{10, 1}, data_type::f32, tag::ab},
                         {{1, 20}, data_type::f32, tag::ab},
                         {{10, 20}, data_type::f32, tag::ab}, data_type::u8},
                        {}, true, dnnl_unimplemented});
    }
    // XXX: disable assert in type_helpers.hpp: default_accum_data_type(...)
    // cases.push_back({{{{10, 1}, data_type::u8, tag::ab}, {{1, 20},
    // data_type::u8, tag::ab},
    //         {{10, 20}, data_type::u8, tag::ab}},
    //         {}, true, dnnl_unimplemented});

    // unimplemented formats (GPU only; CPU skips this tag in SetUp())
    cases.push_back({{{{16, 16}, data_type::f32, tag::AB8a4b},
                             {{16, 16}, data_type::f32, tag::AB8a4b},
                             {{16, 16}, data_type::f32, tag::AB8a4b}},
            {}, true, dnnl_unimplemented});

    return ::testing::ValuesIn(cases);
};
INSTANTIATE_TEST_SUITE_P(EF, iface, cases_ef());
479
// Positive floating-point cases, parameterized by data type (f32/bf16/f16):
// plain, leading-dim, runtime-dim, output-scale and post-op combinations.
static auto cases_f = [](memory::data_type dt) {
    std::vector<matmul_test_params_t> cases;

    // simple case
    cases.push_back({{{{10, 2}, dt, tag::ab}, {{2, 20}, dt, tag::ab},
                             {{10, 20}, dt, tag::ab}, data_type::undef},
            {}});
    // simple case + leading dimensions
    cases.push_back({{{{10, 1}, dt, tag::ab, P::SRC | P::LEADING_DIM},
                             {{1, 3}, dt, tag::ba},
                             {{10, 3}, dt, tag::ab, P::DST | P::LEADING_DIM},
                             data_type::f32},
            {}});
    // simple case + leading dimensions + runtime dims
    cases.push_back(
            {{{{1, 10}, dt, tag::ab, P::SRC | P::LEADING_DIM | P::RUNTIME},
                     {{10, 2}, dt, tag::ba, P::WEIGHTS | P::RUNTIME},
                     {{1, 2}, dt, tag::ab,
                             P::DST | P::LEADING_DIM | P::RUNTIME},
                     data_type::f32},
                    {}});

    // output scales
    cases.push_back({{{{10, 2}, dt, tag::ab}, {{2, 20}, dt, tag::ab},
                             {{10, 20}, dt, tag::ab}, data_type::undef},
            {P::SCALES | P::COMMON}});
    // output scales + per_n + runtime
    cases.push_back({{{{10, 2}, dt, tag::ab}, {{2, 20}, dt, tag::ab},
                             {{10, 20}, dt, tag::ab}, data_type::undef},
            {P::SCALES | P::PER_N | P::RUNTIME}});

    // post-ops
    cases.push_back({{{{10, 1}, dt, tag::ab}, {{1, 20}, dt, tag::ab},
                             {{10, 20}, dt, tag::ab}},
            {P::NONE, {},
                    {{primitive::kind::eltwise, algorithm::eltwise_relu}}}});
    // multiple post-ops
    cases.push_back({{{{10, 2}, dt, tag::ab}, {{2, 20}, dt, tag::ab},
                             {{10, 20}, dt, tag::ab}},
            {P::SCALES | P::COMMON, {},
                    {{primitive::kind::sum},
                            {primitive::kind::eltwise,
                                    algorithm::eltwise_relu}}}});

    // gemm like: output scale + post-ops(sum)
    cases.push_back({{{{10, 1}, dt, tag::ab}, {{1, 20}, dt, tag::ab},
                             {{10, 20}, dt, tag::ab}, data_type::f32},
            {P::SCALES | P::COMMON, {}, {{primitive::kind::sum}}}});
    // gemm like: output scale + post-ops(sum) + all runtime
    cases.push_back({{{{10, 1}, dt, tag::ab, P::SRC | P::RUNTIME},
                             {{1, 20}, dt, tag::ab, P::WEIGHTS | P::RUNTIME},
                             {{10, 20}, dt, tag::ab, P::DST | P::RUNTIME},
                             data_type::f32},
            {P::SCALES | P::COMMON | P::RUNTIME, {},
                    {{primitive::kind::sum}}}});

    return ::testing::ValuesIn(cases);
};

// f16/bf16 variants only make sense on GPU engines in this suite.
GPU_INSTANTIATE_TEST_SUITE_P(Generic_f16, iface, cases_f(data_type::f16));
GPU_INSTANTIATE_TEST_SUITE_P(Generic_bf16, iface, cases_f(data_type::bf16));
INSTANTIATE_TEST_SUITE_P(Generic_f32, iface, cases_f(data_type::f32));
542
// Positive int8 cases (s8/u8 src, s8 weights), including zero points —
// the attribute only meaningful for integer matmul.
static auto cases_x8 = [](memory::data_type src_dt, memory::data_type dst_dt) {
    std::vector<matmul_test_params_t> cases;

    // simple case
    cases.push_back(
            {{{{10, 2}, src_dt, tag::ba}, {{2, 20}, data_type::s8, tag::ab},
                     {{10, 20}, dst_dt, tag::ab}, data_type::undef},
                    {}});
    // simple case + leading dimensions
    cases.push_back(
            {{{{10, 1}, src_dt, tag::ba, P::SRC | P::LEADING_DIM},
                     {{1, 3}, data_type::s8, tag::ba},
                     {{10, 3}, dst_dt, tag::ab, P::DST | P::LEADING_DIM},
                     data_type::s8},
                    {}});
    // simple case + leading dimensions + runtime dims
    cases.push_back(
            {{{{1, 10}, src_dt, tag::ba, P::SRC | P::LEADING_DIM | P::RUNTIME},
                     {{10, 2}, data_type::s8, tag::ba, P::WEIGHTS | P::RUNTIME},
                     {{1, 2}, dst_dt, tag::ab,
                             P::DST | P::LEADING_DIM | P::RUNTIME},
                     data_type::u8},
                    {}});

    // output scales
    cases.push_back(
            {{{{10, 2}, src_dt, tag::ab}, {{2, 20}, data_type::s8, tag::ab},
                     {{10, 20}, dst_dt, tag::ab}, data_type::undef},
                    {P::SCALES | P::COMMON}});
    // output scales + per_n + runtime
    cases.push_back(
            {{{{10, 2}, src_dt, tag::ab}, {{2, 20}, data_type::s8, tag::ab},
                     {{10, 20}, dst_dt, tag::ab}, data_type::undef},
                    {P::SCALES | P::PER_N | P::RUNTIME}});

    // zero points
    cases.push_back(
            {{{{10, 2}, src_dt, tag::ba}, {{2, 20}, data_type::s8, tag::ab},
                     {{10, 20}, dst_dt, tag::ab}, data_type::f32},
                    {P::SCALES | P::COMMON,
                            {P::ZERO_POINTS | P::SRC | P::COMMON,
                                    P::ZERO_POINTS | P::WEIGHTS | P::COMMON,
                                    P::ZERO_POINTS | P::DST | P::COMMON}}});

    // zero points + runtime
    cases.push_back(
            {{{{10, 2}, src_dt, tag::ba}, {{2, 20}, data_type::s8, tag::ab},
                     {{10, 20}, dst_dt, tag::ab}, data_type::f32},
                    {P::SCALES | P::COMMON | P::RUNTIME,
                            {P::ZERO_POINTS | P::SRC | P::COMMON, P::NONE,
                                    P::ZERO_POINTS | P::DST | P::COMMON
                                            | P::RUNTIME}}});

    // per_dim_1 zero points + runtime
    cases.push_back({{{{10, 2}, src_dt, tag::ba},
                             {{2, 20}, data_type::s8, tag::ab},
                             {{10, 20}, dst_dt, tag::ab}, data_type::f32},
            {P::SCALES | P::COMMON | P::RUNTIME,
                    {P::ZERO_POINTS | P::SRC | P::PER_N | P::RUNTIME, P::NONE,
                            P::ZERO_POINTS | P::DST | P::PER_N | P::RUNTIME}}});
    // post-ops
    cases.push_back({{{{10, 1}, src_dt, tag::ab},
                             {{1, 20}, data_type::s8, tag::ab},
                             {{10, 20}, dst_dt, tag::ab}},
            {P::NONE, {},
                    {{primitive::kind::eltwise, algorithm::eltwise_relu}}}});
    // multiple post-ops
    cases.push_back(
            {{{{10, 2}, src_dt, tag::ab}, {{2, 20}, data_type::s8, tag::ab},
                     {{10, 20}, dst_dt, tag::ab}, data_type::f32},
                    {P::SCALES | P::COMMON, {},
                            {{primitive::kind::sum},
                                    {primitive::kind::eltwise,
                                            algorithm::eltwise_relu}}}});

    // igemm like: output scale + post-ops(sum)
    cases.push_back({{{{10, 1}, src_dt, tag::ab},
                             {{1, 20}, data_type::s8, tag::ab},
                             {{10, 20}, dst_dt, tag::ab}, data_type::s8},
            {P::SCALES | P::COMMON,
                    {P::ZERO_POINTS | P::SRC | P::COMMON, P::NONE,
                            P::ZERO_POINTS | P::DST | P::COMMON | P::RUNTIME},
                    {{primitive::kind::sum}}}});
    // igemm like: output scale + post-ops(sum) + all runtime
    cases.push_back({{{{10, 2}, src_dt, tag::ba},
                             {{2, 20}, data_type::s8, tag::ba},
                             {{10, 20}, dst_dt, tag::ab}, data_type::s8},
            {P::SCALES | P::PER_N | P::RUNTIME,
                    {P::ZERO_POINTS | P::SRC | P::COMMON | P::RUNTIME,
                            P::ZERO_POINTS | P::WEIGHTS | P::COMMON
                                    | P::RUNTIME,
                            P::ZERO_POINTS | P::DST | P::COMMON | P::RUNTIME},
                    {{primitive::kind::sum}}}});

    return ::testing::ValuesIn(cases);
};
INSTANTIATE_TEST_SUITE_P(
        Generic_s8s8s32, iface, cases_x8(data_type::s8, data_type::s32));
INSTANTIATE_TEST_SUITE_P(
        Generic_u8s8u8, iface, cases_x8(data_type::u8, data_type::u8));
643
// Cases for the attr_test_t dispatch-consistency test; the binary post-op
// tensor either matches dst exactly or broadcasts over dim 1.
INSTANTIATE_TEST_SUITE_P(TensorDims, attr_test_t,
        ::testing::Values(
                // {{src0, src1, dst same_dim}, { binary post-op dim }},
                // format_tag, post-op data type, ndims
                std::make_tuple(memory::dims {3, 2, 16, 16},
                        memory::dims {3, 1, 16, 16}, tag::abcd,
                        memory::data_type::f32, 4),
                std::make_tuple(memory::dims {9, 9, 64, 64},
                        memory::dims {9, 1, 64, 64}, tag::abcd,
                        memory::data_type::f32, 4),
                std::make_tuple(memory::dims {3, 2, 16, 16},
                        memory::dims {3, 2, 16, 16}, tag::abcd,
                        memory::data_type::f32, 4),
                std::make_tuple(memory::dims {2, 10, 10, 10},
                        memory::dims {2, 10, 10, 10}, tag::abcd,
                        memory::data_type::bf16, 4)));
660
661 } // namespace dnnl
662