1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef RESAMPLING_HPP
18 #define RESAMPLING_HPP
19 
20 #include <assert.h>
21 #include <limits.h>
22 #include <stdint.h>
23 
24 #include <iostream>
25 
26 #include "common.hpp"
27 #include "dnn_types.hpp"
28 #include "dnnl_common.hpp"
29 #include "dnnl_memory.hpp"
30 #include "utils/perf_report.hpp"
31 
32 namespace resampling {
33 
// Resampling algorithm selector used by this driver.
// `undef` is the first enumerator (value 0); the `resampling_*` names are
// aliases that compare equal to the short names.
enum alg_t {
    undef = 0,
    nearest = 1,
    linear = 2,
    // Long-form aliases of the enumerators above.
    resampling_nearest = nearest,
    resampling_linear = linear,
};
// Conversion helpers for alg_t; only declared here, defined elsewhere.
// NOTE(review): presumably str2alg/alg2str translate between alg_t and its
// textual name (alg2str's result is streamed by perf_report_t::dump_alg), and
// alg2alg_kind maps alg_t onto the library's dnnl_alg_kind_t -- confirm
// against the definitions, including what is returned for unknown input.
41 alg_t str2alg(const char *str);
42 const char *alg2str(alg_t alg);
43 dnnl_alg_kind_t alg2alg_kind(alg_t alg);
44 
// Plain aggregate describing one resampling problem shape.
// The extents below are the ones consumed by src_off_f() / dst_off_f():
// `mb` is the outermost index, then `ic`, then the three spatial extents.
// No default member initializers are used, so the type stays an aggregate
// and `desc_t {}` value-initializes every field.
struct desc_t {
    int64_t mb; // outermost extent (shared by source and destination)
    int64_t ic; // second extent (shared by source and destination)
    int64_t id; // source spatial extents, slowest to fastest: id, ih, iw
    int64_t ih;
    int64_t iw;
    int64_t od; // destination spatial extents, slowest to fastest: od, oh, ow
    int64_t oh;
    int64_t ow;
    const char *name; // problem name; not owned by this struct
    int ndims; // number of dimensions; passed to normalize_tag() by the report
};
52 
// Takes an output descriptor and a C string and returns an int.
// NOTE(review): presumably parses `str` into `*desc` and returns a status
// code -- the accepted syntax and the meaning of the return value are not
// visible in this header; confirm against the definition.
53 int str2desc(desc_t *desc, const char *str);
// Stream insertion for a problem descriptor; used by
// perf_report_t::dump_desc(). The emitted format is defined elsewhere.
54 std::ostream &operator<<(std::ostream &s, const desc_t &d);
55 
56 struct settings_t {
57     settings_t() = default;
58 
59     // ctor to save certain fields from resetting
settings_tresampling::settings_t60     settings_t(const char *perf_template) : settings_t() {
61         this->perf_template = perf_template;
62     }
63 
64     desc_t desc {};
65 
66     std::vector<dir_t> dir {FWD_D};
67     std::vector<dnnl_data_type_t> sdt {dnnl_f32};
68     std::vector<dnnl_data_type_t> ddt {dnnl_f32};
69     std::vector<std::string> tag {tag::abx};
70     std::vector<alg_t> alg {nearest};
71     std::vector<attr_t::post_ops_t> post_ops {attr_t::post_ops_t()};
72     std::vector<dnnl_scratchpad_mode_t> scratchpad_mode {
73             dnnl_scratchpad_mode_library};
74     std::vector<int64_t> mb {0};
75 
76     const char *perf_template_csv
77             = "perf,%engine%,%impl%,%name%,%dir%,%sdt%,%ddt%,%tag%,%alg%,%DESC%"
78               ",%-time%,%0time%";
79     const char *perf_template_def
80             = "perf,%engine%,%impl%,%name%,%prb%,%-time%,%0time%";
81     const char *perf_template = perf_template_def;
82 
resetresampling::settings_t83     void reset() { *this = settings_t(perf_template); }
84 };
85 
86 struct prb_t : public desc_t {
prb_tresampling::prb_t87     prb_t(const desc_t &desc, dir_t dir, dnnl_data_type_t sdt,
88             dnnl_data_type_t ddt, const std::string &tag, alg_t alg,
89             const attr_t &attr, int64_t mb = 0)
90         : desc_t(desc)
91         , dir(dir)
92         , sdt(sdt)
93         , ddt(ddt)
94         , tag(tag)
95         , alg(alg)
96         , attr(attr)
97         , user_mb(mb) {
98         if (mb) this->mb = mb;
99     }
~prb_tresampling::prb_t100     ~prb_t() {}
101 
102     dir_t dir;
103     dnnl_data_type_t sdt, ddt;
104     std::string tag;
105     alg_t alg;
106     attr_t attr;
107     int64_t user_mb;
108 
109     BENCHDNN_DISALLOW_COPY_AND_ASSIGN(prb_t);
110 };
// Stream insertion for a whole problem; declared here, defined elsewhere.
// NOTE(review): presumably prints the problem in a form the driver can parse
// back -- confirm against the definition.
111 std::ostream &operator<<(std::ostream &s, const prb_t &prb);
112 
113 struct perf_report_t : public base_perf_report_t {
perf_report_tresampling::perf_report_t114     perf_report_t(const prb_t *prb, const char *perf_template)
115         : base_perf_report_t(perf_template)
116         , p_(prb)
117         , sdt_({prb->sdt})
118         , tag_(normalize_tag(p_->tag, p_->ndims)) {}
119 
dump_algresampling::perf_report_t120     void dump_alg(std::ostream &s) const override { s << alg2str(p_->alg); }
121 
dump_descresampling::perf_report_t122     void dump_desc(std::ostream &s) const override {
123         s << static_cast<const desc_t &>(*p_);
124     }
125 
dump_desc_csvresampling::perf_report_t126     void dump_desc_csv(std::ostream &s) const override {
127         s << p_->mb << ','
128 
129           << p_->ic << ',' << p_->id << ',' << p_->ih << ',' << p_->iw << ','
130 
131           << p_->od << ',' << p_->oh << ',' << p_->ow;
132     }
133 
user_mbresampling::perf_report_t134     const int64_t *user_mb() const override { return &p_->user_mb; }
nameresampling::perf_report_t135     const char *name() const override { return p_->name; }
dirresampling::perf_report_t136     const dir_t *dir() const override { return &p_->dir; }
sdtresampling::perf_report_t137     const std::vector<dnnl_data_type_t> *sdt() const override { return &sdt_; }
ddtresampling::perf_report_t138     const dnnl_data_type_t *ddt() const override { return &p_->ddt; }
tagresampling::perf_report_t139     const std::string *tag() const override { return &tag_; }
140 
141 private:
142     const prb_t *p_;
143     std::vector<dnnl_data_type_t> sdt_;
144     std::string tag_;
145 };
146 
src_off_f(const prb_t * prb,int64_t mb,int64_t ic,int64_t id,int64_t ih,int64_t iw)147 inline int64_t src_off_f(const prb_t *prb, int64_t mb, int64_t ic, int64_t id,
148         int64_t ih, int64_t iw) {
149     return (((mb * prb->ic + ic) * prb->id + id) * prb->ih + ih) * prb->iw + iw;
150 }
151 
dst_off_f(const prb_t * prb,int64_t mb,int64_t ic,int64_t od,int64_t oh,int64_t ow)152 inline int64_t dst_off_f(const prb_t *prb, int64_t mb, int64_t ic, int64_t od,
153         int64_t oh, int64_t ow) {
154     return (((mb * prb->ic + ic) * prb->od + od) * prb->oh + oh) * prb->ow + ow;
155 }
156 
// Reference implementations, declared here and defined elsewhere.
// The forward variant reads `src` (const) and writes `dst`; `binary_po` is a
// read-only list of extra memories.
// NOTE(review): `binary_po` presumably carries the operands of binary
// post-ops -- confirm against the definition.
157 void compute_ref_fwd(const prb_t *prb, const dnn_mem_t &src, dnn_mem_t &dst,
158         const std::vector<dnn_mem_t> &binary_po);
// The backward variant reads `diff_dst` (const) and writes `diff_src`.
159 void compute_ref_bwd(
160         const prb_t *prb, dnn_mem_t &diff_src, const dnn_mem_t &diff_dst);
161 
// Comparison and fill helpers operating on a pair of memories.
// NOTE(review): by naming, `mem_dt` looks like the library-side memory and
// `mem_fp` the reference-side memory, with the outcome recorded in `res`
// and an int status returned -- confirm against the definitions.
162 int compare_src(
163         const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res);
164 int compare_dst(
165         const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res);
166 int fill_dat(
167         const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res);
168 
// Driver entry points.
// NOTE(review): presumably doit() runs a single problem and bench() parses
// the command line and runs the requested problems -- confirm.
169 int doit(const prb_t *prb, res_t *res);
170 int bench(int argc, char **argv);
171 
172 } // namespace resampling
173 
174 #endif
175