1 /******************************************************************************* 2 * Copyright 2019-2021 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef BINARY_HPP 18 #define BINARY_HPP 19 20 #include <iostream> 21 22 #include "oneapi/dnnl/dnnl.h" 23 24 #include "common.hpp" 25 #include "dnn_types.hpp" 26 #include "dnnl_common.hpp" 27 #include "dnnl_memory.hpp" 28 #include "perf_report.hpp" 29 30 namespace binary { 31 32 using alg_t = attr_t::post_ops_t::kind_t; 33 34 struct settings_t { 35 settings_t() = default; 36 37 // ctor to save certain fields from resetting settings_tbinary::settings_t38 settings_t(const char *perf_template) : settings_t() { 39 this->perf_template = perf_template; 40 } 41 42 std::vector<dims_t> sdims; 43 44 std::vector<std::vector<dnnl_data_type_t>> sdt {{dnnl_f32, dnnl_f32}}; 45 std::vector<dnnl_data_type_t> ddt {dnnl_f32}; 46 std::vector<std::vector<std::string>> stag {{tag::abx, tag::abx}}; 47 std::vector<std::string> dtag {tag::any}; 48 std::vector<alg_t> alg {alg_t::ADD}; 49 std::vector<bool> inplace {false}; 50 std::vector<attr_t::arg_scales_t> scales {attr_t::arg_scales_t()}; 51 std::vector<attr_t::post_ops_t> post_ops {attr_t::post_ops_t()}; 52 std::vector<dnnl_scratchpad_mode_t> scratchpad_mode { 53 dnnl_scratchpad_mode_library}; 54 attr_t attr = {}; 55 56 const char *perf_template_csv 57 = "perf,%engine%,%impl%,%sdt%,%ddt%,%stag%,%dtag%,%alg%,%attr%," 58 "%DESC%,%-time%,%0time%"; 59 const char *perf_template_def 60 = "perf,%engine%,%impl%,%prb%,%-time%,%0time%"; 61 const char *perf_template = perf_template_def; 62 resetbinary::settings_t63 void reset() { *this = settings_t(perf_template); } 64 }; 65 66 struct prb_t { prb_tbinary::prb_t67 prb_t(const std::vector<dims_t> &sdims, 68 const std::vector<dnnl_data_type_t> &sdt, dnnl_data_type_t ddt, 69 const std::vector<std::string> &stag, std::string dtag, alg_t alg, 70 bool inplace, const attr_t &attr) 71 : sdims(sdims) 72 , sdt(sdt) 73 , ddt(ddt) 74 , stag(stag) 75 , dtag(dtag) 76 , alg(alg) 77 , inplace(inplace) 78 , attr(attr) 79 , ndims({(int)sdims[0].size(), (int)sdims[1].size()}) {} ~prb_tbinary::prb_t80 ~prb_t() {} 81 82 std::vector<dims_t> sdims; 83 std::vector<dnnl_data_type_t> sdt; 84 dnnl_data_type_t ddt; 85 std::vector<std::string> stag; 86 std::string dtag; 87 alg_t alg; 88 bool inplace; 89 attr_t attr; 90 std::vector<int> ndims; 91 n_inputsbinary::prb_t92 int n_inputs() const { return 2; } 93 get_broadcast_maskbinary::prb_t94 int get_broadcast_mask(const dims_t &dims_B, int source_num) const { 95 const dims_t &dims_A = this->sdims[source_num]; 96 97 int broadcast_mask = 0; 98 for (int d = 0; d < ndims[source_num]; ++d) 99 broadcast_mask += dims_A[d] == dims_B[d] ? (1 << d) : 0; 100 return broadcast_mask; 101 } 102 }; 103 std::ostream &operator<<(std::ostream &s, const prb_t &prb); 104 105 struct perf_report_t : public base_perf_report_t { 106 using base_perf_report_t::base_perf_report_t; 107 reportbinary::perf_report_t108 void report(const prb_t *prb, const res_t *res, const char *prb_str) { 109 p_ = prb; 110 for (size_t d = 0; d < p_->stag.size(); d++) 111 stag_.push_back(normalize_tag(p_->stag[d], p_->ndims[d])); 112 dtag_ = normalize_tag(p_->dtag, p_->ndims[0]); 113 base_report(res, prb_str); 114 } 115 dump_algbinary::perf_report_t116 void dump_alg(std::ostream &s) const override { s << p_->alg; } 117 dump_descbinary::perf_report_t118 void dump_desc(std::ostream &s) const override { s << p_->sdims; } 119 dump_desc_csvbinary::perf_report_t120 void dump_desc_csv(std::ostream &s) const override { s << p_->sdims; } 121 sdtbinary::perf_report_t122 const std::vector<dnnl_data_type_t> *sdt() const override { 123 return &p_->sdt; 124 } attrbinary::perf_report_t125 const attr_t *attr() const override { return &p_->attr; } ddtbinary::perf_report_t126 const dnnl_data_type_t *ddt() const override { return &p_->ddt; } stagbinary::perf_report_t127 const std::vector<std::string> *stag() const override { return &stag_; } dtagbinary::perf_report_t128 const std::string *dtag() const override { return &dtag_; } 129 130 private: 131 const prb_t *p_ = NULL; 132 std::vector<std::string> stag_; 133 std::string dtag_; 134 }; 135 136 int setup_binary_po(const_dnnl_primitive_desc_t pd, std::vector<int> &args, 137 std::vector<dnn_mem_t> &mem_dt, std::vector<dnn_mem_t> &mem_fp, 138 bool only_positive_values = false); 139 140 void compute_ref(const prb_t *prb, const dnn_mem_t &src0, const dnn_mem_t &src1, 141 const std::vector<dnn_mem_t> &binary_po, dnn_mem_t &dst); 142 143 int doit(const prb_t *prb, res_t *res); 144 int bench(int argc, char **argv); 145 146 } // namespace binary 147 148 #endif 149