1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef RESAMPLING_HPP
18 #define RESAMPLING_HPP
19
20 #include <assert.h>
21 #include <limits.h>
22 #include <stdint.h>
23
24 #include <iostream>
25
26 #include "common.hpp"
27 #include "dnn_types.hpp"
28 #include "dnnl_common.hpp"
29 #include "dnnl_memory.hpp"
30 #include "utils/perf_report.hpp"
31
32 namespace resampling {
33
// Algorithm identifiers used by the resampling driver.
//
// `undef` is the first enumerator (value 0). The `resampling_*` names are
// aliases: they carry exactly the same values as `nearest` and `linear`, so
// either spelling can be used interchangeably.
enum alg_t {
    undef,
    nearest,
    linear,
    resampling_nearest = nearest, // alias of `nearest`
    resampling_linear = linear, // alias of `linear`
};
41 alg_t str2alg(const char *str);
42 const char *alg2str(alg_t alg);
43 dnnl_alg_kind_t alg2alg_kind(alg_t alg);
44
// Plain aggregate describing the sizes of one resampling problem.
//
// The struct has no constructors or default member initializers, so a
// default-initialized instance holds indeterminate values; use `desc_t d {};`
// to get an all-zero / null descriptor.
struct desc_t {
    // mb: outermost extent in src_off_f / dst_off_f (prb_t may override it);
    // ic: second extent, shared by the source and destination offsets.
    int64_t mb, ic;
    // Source extents, used (outer to inner) by src_off_f.
    int64_t id, ih, iw;
    // Destination extents, used (outer to inner) by dst_off_f.
    int64_t od, oh, ow;
    // Problem name as reported by perf_report_t::name(). This is a non-owning
    // pointer: the struct neither copies nor frees the string.
    const char *name;
    // Number of dimensions; forwarded to normalize_tag() by perf_report_t.
    int ndims;
};
52
53 int str2desc(desc_t *desc, const char *str);
54 std::ostream &operator<<(std::ostream &s, const desc_t &d);
55
56 struct settings_t {
57 settings_t() = default;
58
59 // ctor to save certain fields from resetting
settings_tresampling::settings_t60 settings_t(const char *perf_template) : settings_t() {
61 this->perf_template = perf_template;
62 }
63
64 desc_t desc {};
65
66 std::vector<dir_t> dir {FWD_D};
67 std::vector<dnnl_data_type_t> sdt {dnnl_f32};
68 std::vector<dnnl_data_type_t> ddt {dnnl_f32};
69 std::vector<std::string> tag {tag::abx};
70 std::vector<alg_t> alg {nearest};
71 std::vector<attr_t::post_ops_t> post_ops {attr_t::post_ops_t()};
72 std::vector<dnnl_scratchpad_mode_t> scratchpad_mode {
73 dnnl_scratchpad_mode_library};
74 std::vector<int64_t> mb {0};
75
76 const char *perf_template_csv
77 = "perf,%engine%,%impl%,%name%,%dir%,%sdt%,%ddt%,%tag%,%alg%,%DESC%"
78 ",%-time%,%0time%";
79 const char *perf_template_def
80 = "perf,%engine%,%impl%,%name%,%prb%,%-time%,%0time%";
81 const char *perf_template = perf_template_def;
82
resetresampling::settings_t83 void reset() { *this = settings_t(perf_template); }
84 };
85
86 struct prb_t : public desc_t {
prb_tresampling::prb_t87 prb_t(const desc_t &desc, dir_t dir, dnnl_data_type_t sdt,
88 dnnl_data_type_t ddt, const std::string &tag, alg_t alg,
89 const attr_t &attr, int64_t mb = 0)
90 : desc_t(desc)
91 , dir(dir)
92 , sdt(sdt)
93 , ddt(ddt)
94 , tag(tag)
95 , alg(alg)
96 , attr(attr)
97 , user_mb(mb) {
98 if (mb) this->mb = mb;
99 }
~prb_tresampling::prb_t100 ~prb_t() {}
101
102 dir_t dir;
103 dnnl_data_type_t sdt, ddt;
104 std::string tag;
105 alg_t alg;
106 attr_t attr;
107 int64_t user_mb;
108
109 BENCHDNN_DISALLOW_COPY_AND_ASSIGN(prb_t);
110 };
111 std::ostream &operator<<(std::ostream &s, const prb_t &prb);
112
113 struct perf_report_t : public base_perf_report_t {
perf_report_tresampling::perf_report_t114 perf_report_t(const prb_t *prb, const char *perf_template)
115 : base_perf_report_t(perf_template)
116 , p_(prb)
117 , sdt_({prb->sdt})
118 , tag_(normalize_tag(p_->tag, p_->ndims)) {}
119
dump_algresampling::perf_report_t120 void dump_alg(std::ostream &s) const override { s << alg2str(p_->alg); }
121
dump_descresampling::perf_report_t122 void dump_desc(std::ostream &s) const override {
123 s << static_cast<const desc_t &>(*p_);
124 }
125
dump_desc_csvresampling::perf_report_t126 void dump_desc_csv(std::ostream &s) const override {
127 s << p_->mb << ','
128
129 << p_->ic << ',' << p_->id << ',' << p_->ih << ',' << p_->iw << ','
130
131 << p_->od << ',' << p_->oh << ',' << p_->ow;
132 }
133
user_mbresampling::perf_report_t134 const int64_t *user_mb() const override { return &p_->user_mb; }
nameresampling::perf_report_t135 const char *name() const override { return p_->name; }
dirresampling::perf_report_t136 const dir_t *dir() const override { return &p_->dir; }
sdtresampling::perf_report_t137 const std::vector<dnnl_data_type_t> *sdt() const override { return &sdt_; }
ddtresampling::perf_report_t138 const dnnl_data_type_t *ddt() const override { return &p_->ddt; }
tagresampling::perf_report_t139 const std::string *tag() const override { return &tag_; }
140
141 private:
142 const prb_t *p_;
143 std::vector<dnnl_data_type_t> sdt_;
144 std::string tag_;
145 };
146
src_off_f(const prb_t * prb,int64_t mb,int64_t ic,int64_t id,int64_t ih,int64_t iw)147 inline int64_t src_off_f(const prb_t *prb, int64_t mb, int64_t ic, int64_t id,
148 int64_t ih, int64_t iw) {
149 return (((mb * prb->ic + ic) * prb->id + id) * prb->ih + ih) * prb->iw + iw;
150 }
151
dst_off_f(const prb_t * prb,int64_t mb,int64_t ic,int64_t od,int64_t oh,int64_t ow)152 inline int64_t dst_off_f(const prb_t *prb, int64_t mb, int64_t ic, int64_t od,
153 int64_t oh, int64_t ow) {
154 return (((mb * prb->ic + ic) * prb->od + od) * prb->oh + oh) * prb->ow + ow;
155 }
156
157 void compute_ref_fwd(const prb_t *prb, const dnn_mem_t &src, dnn_mem_t &dst,
158 const std::vector<dnn_mem_t> &binary_po);
159 void compute_ref_bwd(
160 const prb_t *prb, dnn_mem_t &diff_src, const dnn_mem_t &diff_dst);
161
162 int compare_src(
163 const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res);
164 int compare_dst(
165 const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res);
166 int fill_dat(
167 const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res);
168
169 int doit(const prb_t *prb, res_t *res);
170 int bench(int argc, char **argv);
171
172 } // namespace resampling
173
174 #endif
175