1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include <stdio.h>
18 #include <stdlib.h>
19 
20 #include <sstream>
21 
22 #include "dnnl_common.hpp"
23 #include "dnnl_memory.hpp"
24 #include "utils/parser.hpp"
25 
26 #include "matmul/matmul.hpp"
27 
28 namespace matmul {
29 
check_correctness(const settings_t & s,const settings_t & def)30 void check_correctness(const settings_t &s, const settings_t &def) {
31     std::vector<std::pair<dnnl_data_type_t, int>> bia_cfg;
32     for (const auto &i_bia_dt : s.bia_dt) {
33         if (i_bia_dt == dnnl_data_type_undef) {
34             bia_cfg.emplace_back(i_bia_dt, 0);
35             continue;
36         }
37         for (const auto &i_bia_mask : s.bia_mask)
38             bia_cfg.emplace_back(i_bia_dt, i_bia_mask);
39     }
40 
41     for_(const auto &i_cfg : s.cfg)
42     for_(const auto &i_stag : s.stag)
43     for_(const auto &i_wtag : s.wtag)
44     for_(const auto &i_dtag : s.dtag)
45     for_(const auto &i_strides : s.strides)
46     for_(const auto &i_rt_dims_masks : s.rt_dims_masks)
47     for_(const auto &i_oscale : s.oscale)
48     for_(const auto &i_zero_points : s.zero_points)
49     for_(const auto &i_post_ops : s.post_ops)
50     for_(const auto &i_scratchpad_mode : s.scratchpad_mode)
51     for (const auto &i_bia_cfg : bia_cfg) {
52         attr_t attr;
53         attr.insert(i_oscale);
54         attr.insert(i_zero_points);
55         attr.insert(i_post_ops);
56         attr.insert(i_scratchpad_mode);
57         handle_legacy_attr(attr, s.attr);
58 
59         const bool strided_input = !i_strides[STRIDES_SRC].empty()
60                 || !i_strides[STRIDES_WEI].empty()
61                 || !i_strides[STRIDES_DST].empty();
62         if (strided_input) {
63             const bool no_stride_with_tag
64                     = IMPLICATION(i_stag != def.stag[0],
65                               i_strides[STRIDES_SRC].empty())
66                     && IMPLICATION(i_wtag != def.wtag[0],
67                             i_strides[STRIDES_WEI].empty())
68                     && IMPLICATION(i_dtag != def.dtag[0],
69                             i_strides[STRIDES_DST].empty());
70 
71             if (!no_stride_with_tag) {
72                 fprintf(stderr,
73                         "ERROR: matmul driver: both `strides` and `tag` knobs "
74                         "can not be used with either of `src`, `wei`, and `dst`"
75                         " tensors.\n"),
76                         fflush(stderr);
77                 SAFE_V(FAIL);
78             }
79         }
80 
81         const prb_t prb(s.prb_vdims, i_cfg, i_stag, i_wtag, i_dtag, i_strides,
82                 i_bia_cfg.first, i_bia_cfg.second, i_rt_dims_masks, attr);
83         std::stringstream ss;
84         ss << prb;
85         const std::string cpp_pstr = ss.str();
86         const char *pstr = cpp_pstr.c_str();
87         BENCHDNN_PRINT(1, "run: %s\n", pstr);
88 
89         res_t res {};
90         const int status = doit(&prb, &res);
91 
92         bool want_perf_report = false;
93         parse_result(res, want_perf_report, status, pstr);
94 
95         if (want_perf_report && is_bench_mode(PERF)) {
96             perf_report_t pr(&prb, s.perf_template);
97             pr.report(&res, pstr);
98         }
99 
100         benchdnn_stat.tests++;
101     }
102 }
103 
bench(int argc,char ** argv)104 int bench(int argc, char **argv) {
105     driver_name = "matmul";
106     using namespace parser;
107     static settings_t s;
108     static const settings_t def {};
109     for (; argc > 0; --argc, ++argv) {
110         const bool parsed_options = parse_bench_settings(argv[0])
111                 || parse_batch(bench, argv[0])
112                 || parse_cfg(s.cfg, def.cfg, str2cfg, argv[0])
113                 || parse_tag(s.stag, def.stag, argv[0], "stag")
114                 || parse_tag(s.wtag, def.wtag, argv[0], "wtag")
115                 || parse_tag(s.dtag, def.dtag, argv[0], "dtag")
116                 || parse_strides(s.strides, def.strides, argv[0], "strides")
117                 || parse_dt(s.bia_dt, def.bia_dt, argv[0], "bia_dt")
118                 || parse_vector_option(
119                         s.bia_mask, def.bia_mask, atoi, argv[0], "bia_mask")
120                 || parse_multivector_option(s.rt_dims_masks, def.rt_dims_masks,
121                         atoi, argv[0], "runtime_dims_masks")
122                 || parse_attr(s.attr, argv[0])
123                 || parse_attr_oscale(s.oscale, argv[0])
124                 || parse_attr_zero_points(s.zero_points, argv[0])
125                 || parse_attr_post_ops(s.post_ops, argv[0])
126                 || parse_attr_scratchpad_mode(
127                         s.scratchpad_mode, def.scratchpad_mode, argv[0])
128                 || parse_perf_template(s.perf_template, s.perf_template_def,
129                         s.perf_template_csv, argv[0])
130                 || parse_reset(s, argv[0]);
131         if (!parsed_options) {
132             catch_unknown_options(argv[0]);
133 
134             parse_prb_vdims(s.prb_vdims, argv[0]);
135             check_correctness(s, def);
136         }
137     }
138 
139     return parse_last_argument();
140 }
141 
142 } // namespace matmul
143