1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <stdio.h>
18 #include <stdlib.h>
19
20 #include <sstream>
21
22 #include "dnnl_common.hpp"
23 #include "dnnl_memory.hpp"
24 #include "utils/parser.hpp"
25
26 #include "matmul/matmul.hpp"
27
28 namespace matmul {
29
check_correctness(const settings_t & s,const settings_t & def)30 void check_correctness(const settings_t &s, const settings_t &def) {
31 std::vector<std::pair<dnnl_data_type_t, int>> bia_cfg;
32 for (const auto &i_bia_dt : s.bia_dt) {
33 if (i_bia_dt == dnnl_data_type_undef) {
34 bia_cfg.emplace_back(i_bia_dt, 0);
35 continue;
36 }
37 for (const auto &i_bia_mask : s.bia_mask)
38 bia_cfg.emplace_back(i_bia_dt, i_bia_mask);
39 }
40
41 for_(const auto &i_cfg : s.cfg)
42 for_(const auto &i_stag : s.stag)
43 for_(const auto &i_wtag : s.wtag)
44 for_(const auto &i_dtag : s.dtag)
45 for_(const auto &i_strides : s.strides)
46 for_(const auto &i_rt_dims_masks : s.rt_dims_masks)
47 for_(const auto &i_oscale : s.oscale)
48 for_(const auto &i_zero_points : s.zero_points)
49 for_(const auto &i_post_ops : s.post_ops)
50 for_(const auto &i_scratchpad_mode : s.scratchpad_mode)
51 for (const auto &i_bia_cfg : bia_cfg) {
52 attr_t attr;
53 attr.insert(i_oscale);
54 attr.insert(i_zero_points);
55 attr.insert(i_post_ops);
56 attr.insert(i_scratchpad_mode);
57 handle_legacy_attr(attr, s.attr);
58
59 const bool strided_input = !i_strides[STRIDES_SRC].empty()
60 || !i_strides[STRIDES_WEI].empty()
61 || !i_strides[STRIDES_DST].empty();
62 if (strided_input) {
63 const bool no_stride_with_tag
64 = IMPLICATION(i_stag != def.stag[0],
65 i_strides[STRIDES_SRC].empty())
66 && IMPLICATION(i_wtag != def.wtag[0],
67 i_strides[STRIDES_WEI].empty())
68 && IMPLICATION(i_dtag != def.dtag[0],
69 i_strides[STRIDES_DST].empty());
70
71 if (!no_stride_with_tag) {
72 fprintf(stderr,
73 "ERROR: matmul driver: both `strides` and `tag` knobs "
74 "can not be used with either of `src`, `wei`, and `dst`"
75 " tensors.\n"),
76 fflush(stderr);
77 SAFE_V(FAIL);
78 }
79 }
80
81 const prb_t prb(s.prb_vdims, i_cfg, i_stag, i_wtag, i_dtag, i_strides,
82 i_bia_cfg.first, i_bia_cfg.second, i_rt_dims_masks, attr);
83 std::stringstream ss;
84 ss << prb;
85 const std::string cpp_pstr = ss.str();
86 const char *pstr = cpp_pstr.c_str();
87 BENCHDNN_PRINT(1, "run: %s\n", pstr);
88
89 res_t res {};
90 const int status = doit(&prb, &res);
91
92 bool want_perf_report = false;
93 parse_result(res, want_perf_report, status, pstr);
94
95 if (want_perf_report && is_bench_mode(PERF)) {
96 perf_report_t pr(&prb, s.perf_template);
97 pr.report(&res, pstr);
98 }
99
100 benchdnn_stat.tests++;
101 }
102 }
103
bench(int argc,char ** argv)104 int bench(int argc, char **argv) {
105 driver_name = "matmul";
106 using namespace parser;
107 static settings_t s;
108 static const settings_t def {};
109 for (; argc > 0; --argc, ++argv) {
110 const bool parsed_options = parse_bench_settings(argv[0])
111 || parse_batch(bench, argv[0])
112 || parse_cfg(s.cfg, def.cfg, str2cfg, argv[0])
113 || parse_tag(s.stag, def.stag, argv[0], "stag")
114 || parse_tag(s.wtag, def.wtag, argv[0], "wtag")
115 || parse_tag(s.dtag, def.dtag, argv[0], "dtag")
116 || parse_strides(s.strides, def.strides, argv[0], "strides")
117 || parse_dt(s.bia_dt, def.bia_dt, argv[0], "bia_dt")
118 || parse_vector_option(
119 s.bia_mask, def.bia_mask, atoi, argv[0], "bia_mask")
120 || parse_multivector_option(s.rt_dims_masks, def.rt_dims_masks,
121 atoi, argv[0], "runtime_dims_masks")
122 || parse_attr(s.attr, argv[0])
123 || parse_attr_oscale(s.oscale, argv[0])
124 || parse_attr_zero_points(s.zero_points, argv[0])
125 || parse_attr_post_ops(s.post_ops, argv[0])
126 || parse_attr_scratchpad_mode(
127 s.scratchpad_mode, def.scratchpad_mode, argv[0])
128 || parse_perf_template(s.perf_template, s.perf_template_def,
129 s.perf_template_csv, argv[0])
130 || parse_reset(s, argv[0]);
131 if (!parsed_options) {
132 catch_unknown_options(argv[0]);
133
134 parse_prb_vdims(s.prb_vdims, argv[0]);
135 check_correctness(s, def);
136 }
137 }
138
139 return parse_last_argument();
140 }
141
142 } // namespace matmul
143