1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef BINARY_HPP
18 #define BINARY_HPP
19 
#include <iostream>
#include <string>
#include <utility>
#include <vector>

#include "oneapi/dnnl/dnnl.h"

#include "common.hpp"
#include "dnn_types.hpp"
#include "dnnl_common.hpp"
#include "dnnl_memory.hpp"
#include "perf_report.hpp"
29 
30 namespace binary {
31 
32 using alg_t = attr_t::post_ops_t::kind_t;
33 
34 struct settings_t {
35     settings_t() = default;
36 
37     // ctor to save certain fields from resetting
settings_tbinary::settings_t38     settings_t(const char *perf_template) : settings_t() {
39         this->perf_template = perf_template;
40     }
41 
42     std::vector<dims_t> sdims;
43 
44     std::vector<std::vector<dnnl_data_type_t>> sdt {{dnnl_f32, dnnl_f32}};
45     std::vector<dnnl_data_type_t> ddt {dnnl_f32};
46     std::vector<std::vector<std::string>> stag {{tag::abx, tag::abx}};
47     std::vector<std::string> dtag {tag::any};
48     std::vector<alg_t> alg {alg_t::ADD};
49     std::vector<bool> inplace {false};
50     std::vector<attr_t::arg_scales_t> scales {attr_t::arg_scales_t()};
51     std::vector<attr_t::post_ops_t> post_ops {attr_t::post_ops_t()};
52     std::vector<dnnl_scratchpad_mode_t> scratchpad_mode {
53             dnnl_scratchpad_mode_library};
54     attr_t attr = {};
55 
56     const char *perf_template_csv
57             = "perf,%engine%,%impl%,%sdt%,%ddt%,%stag%,%dtag%,%alg%,%attr%,"
58               "%DESC%,%-time%,%0time%";
59     const char *perf_template_def
60             = "perf,%engine%,%impl%,%prb%,%-time%,%0time%";
61     const char *perf_template = perf_template_def;
62 
resetbinary::settings_t63     void reset() { *this = settings_t(perf_template); }
64 };
65 
66 struct prb_t {
prb_tbinary::prb_t67     prb_t(const std::vector<dims_t> &sdims,
68             const std::vector<dnnl_data_type_t> &sdt, dnnl_data_type_t ddt,
69             const std::vector<std::string> &stag, std::string dtag, alg_t alg,
70             bool inplace, const attr_t &attr)
71         : sdims(sdims)
72         , sdt(sdt)
73         , ddt(ddt)
74         , stag(stag)
75         , dtag(dtag)
76         , alg(alg)
77         , inplace(inplace)
78         , attr(attr)
79         , ndims({(int)sdims[0].size(), (int)sdims[1].size()}) {}
~prb_tbinary::prb_t80     ~prb_t() {}
81 
82     std::vector<dims_t> sdims;
83     std::vector<dnnl_data_type_t> sdt;
84     dnnl_data_type_t ddt;
85     std::vector<std::string> stag;
86     std::string dtag;
87     alg_t alg;
88     bool inplace;
89     attr_t attr;
90     std::vector<int> ndims;
91 
n_inputsbinary::prb_t92     int n_inputs() const { return 2; }
93 
get_broadcast_maskbinary::prb_t94     int get_broadcast_mask(const dims_t &dims_B, int source_num) const {
95         const dims_t &dims_A = this->sdims[source_num];
96 
97         int broadcast_mask = 0;
98         for (int d = 0; d < ndims[source_num]; ++d)
99             broadcast_mask += dims_A[d] == dims_B[d] ? (1 << d) : 0;
100         return broadcast_mask;
101     }
102 };
103 std::ostream &operator<<(std::ostream &s, const prb_t &prb);
104 
105 struct perf_report_t : public base_perf_report_t {
106     using base_perf_report_t::base_perf_report_t;
107 
reportbinary::perf_report_t108     void report(const prb_t *prb, const res_t *res, const char *prb_str) {
109         p_ = prb;
110         for (size_t d = 0; d < p_->stag.size(); d++)
111             stag_.push_back(normalize_tag(p_->stag[d], p_->ndims[d]));
112         dtag_ = normalize_tag(p_->dtag, p_->ndims[0]);
113         base_report(res, prb_str);
114     }
115 
dump_algbinary::perf_report_t116     void dump_alg(std::ostream &s) const override { s << p_->alg; }
117 
dump_descbinary::perf_report_t118     void dump_desc(std::ostream &s) const override { s << p_->sdims; }
119 
dump_desc_csvbinary::perf_report_t120     void dump_desc_csv(std::ostream &s) const override { s << p_->sdims; }
121 
sdtbinary::perf_report_t122     const std::vector<dnnl_data_type_t> *sdt() const override {
123         return &p_->sdt;
124     }
attrbinary::perf_report_t125     const attr_t *attr() const override { return &p_->attr; }
ddtbinary::perf_report_t126     const dnnl_data_type_t *ddt() const override { return &p_->ddt; }
stagbinary::perf_report_t127     const std::vector<std::string> *stag() const override { return &stag_; }
dtagbinary::perf_report_t128     const std::string *dtag() const override { return &dtag_; }
129 
130 private:
131     const prb_t *p_ = NULL;
132     std::vector<std::string> stag_;
133     std::string dtag_;
134 };
135 
// The functions below are defined elsewhere in the binary driver. Their
// descriptions are inferred from the signatures.
// NOTE(review): confirm against the definitions.

// Prepares memories for the binary post-ops of `pd`. Presumably it appends
// argument ids to `args` and the matching memories to `mem_dt` (library-side)
// and `mem_fp` (reference-side). `only_positive_values` presumably restricts
// the generated fill values to positive numbers. Returns an int status code.
int setup_binary_po(const_dnnl_primitive_desc_t pd, std::vector<int> &args,
        std::vector<dnn_mem_t> &mem_dt, std::vector<dnn_mem_t> &mem_fp,
        bool only_positive_values = false);

// Computes the reference result of `prb` from `src0` and `src1` into `dst`.
// The post-op memories in `binary_po` are presumably applied as well.
void compute_ref(const prb_t *prb, const dnn_mem_t &src0, const dnn_mem_t &src1,
        const std::vector<dnn_mem_t> &binary_po, dnn_mem_t &dst);

// Runs a single problem `prb` and records its outcome in `res`.
// Returns an int status code (presumably OK/FAIL from common.hpp).
int doit(const prb_t *prb, res_t *res);
// Driver entry point; presumably parses `argv` and runs the requested cases.
int bench(int argc, char **argv);
145 
146 } // namespace binary
147 
148 #endif
149