/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <stdio.h>
#include <stdlib.h>

#include <sstream>

#include "oneapi/dnnl/dnnl.h"

#include "tests/test_thread.hpp"

#include "compare.hpp"
#include "dnnl_common.hpp"
#include "dnnl_memory.hpp"

#include "binary/binary.hpp"
#include "resampling/resampling.hpp"

namespace resampling {

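// Deterministically fill a tensor: each element gets a value derived from its
// index and the data kind, kept within a small range and rounded to the
// nearest value representable in the target data type, so the fp32 reference
// copy and the reordered test-precision copy hold the same values.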
int fill_dat(const prb_t *prb, data_kind_t kind, dnn_mem_t &mem_dt,
        dnn_mem_t &mem_fp, res_t *res) {
    const auto nelems = mem_fp.nelems();
    const auto dt = mem_dt.dt();
    const int range = 16;
    const int f_min = 0;

    dnnl::impl::parallel_nd(nelems, [&](int64_t i) {
        const float gen = ((97 * i) - 19 * kind + 101) % (range + 1);
        const float value = dt == dnnl_f32 || is_integral_dt(dt)
                ? (f_min + gen) * (1.0f + 4.0f / range)
                : (f_min + gen) / range;

        mem_fp.set_elem(i, round_to_nearest_representable(dt, value));
    });

    SAFE(mem_dt.reorder(mem_fp), WARN);

    return OK;
}

int fill_src(
        const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) {
    return fill_dat(prb, SRC, mem_dt, mem_fp, res);
}

int fill_dst(
        const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) {
    return fill_dat(prb, DST, mem_dt, mem_fp, res);
}

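// Create the resampling primitive descriptor for the requested problem:
// build source/destination memory descriptors from the problem dimensions,
// initialize a forward or backward operation descriptor, and, for the
// backward pass, create a forward primitive descriptor first to pass as a
// hint.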
static int init_pd(dnnl_engine_t engine, const prb_t *prb,
        dnnl_primitive_desc_t &rpd, res_t *res, dir_t dir,
        const_dnnl_primitive_desc_t hint) {
    dnnl_memory_desc_t src_d, dst_d;

    dnnl_dims_t src_1d_dims = {prb->mb, prb->ic, prb->iw};
    dnnl_dims_t src_2d_dims = {prb->mb, prb->ic, prb->ih, prb->iw};
    dnnl_dims_t src_3d_dims = {prb->mb, prb->ic, prb->id, prb->ih, prb->iw};
    dnnl_dim_t *src_dims = prb->ndims == 5
            ? src_3d_dims
            : prb->ndims == 4 ? src_2d_dims : src_1d_dims;

    dnnl_dims_t dst_1d_dims = {prb->mb, prb->ic, prb->ow};
    dnnl_dims_t dst_2d_dims = {prb->mb, prb->ic, prb->oh, prb->ow};
    dnnl_dims_t dst_3d_dims = {prb->mb, prb->ic, prb->od, prb->oh, prb->ow};
    dnnl_dim_t *dst_dims = prb->ndims == 5
            ? dst_3d_dims
            : prb->ndims == 4 ? dst_2d_dims : dst_1d_dims;

    std::string src_tag = (prb->dir & FLAG_FWD) ? prb->tag : tag::any;
    std::string dst_tag = (prb->dir & FLAG_BWD) ? prb->tag : tag::any;

    SAFE(init_md(&src_d, prb->ndims, src_dims, prb->sdt, src_tag), CRIT);

    SAFE(init_md(&dst_d, prb->ndims, dst_dims, prb->ddt, dst_tag), CRIT);

    dnnl_alg_kind_t alg = alg2alg_kind(prb->alg);
    dnnl_resampling_desc_t pd;

    if (prb->dir & FLAG_FWD) {
        auto prop_kind = prb->dir & FLAG_INF ? dnnl_forward_inference
                                             : dnnl_forward_training;
        DNN_SAFE(dnnl_resampling_forward_desc_init(
                         &pd, prop_kind, alg, nullptr, &src_d, &dst_d),
                WARN);
    } else {
        DNN_SAFE(dnnl_resampling_backward_desc_init(
                         &pd, alg, nullptr, &src_d, &dst_d),
                WARN);
    }

    dnnl_primitive_desc_t hint_fwd_pd_ {};
    dnnl_status_t status = dnnl_success;
    if (prb->dir & FLAG_BWD) {
        dnnl_memory_desc_t fwd_src_d, fwd_dst_d;
        SAFE(init_md(&fwd_src_d, prb->ndims, src_dims, prb->sdt, prb->tag),
                CRIT);
        SAFE(init_md(&fwd_dst_d, prb->ndims, dst_dims, prb->ddt, tag::any),
                CRIT);

        dnnl_resampling_desc_t rd_fwd;
        DNN_SAFE(dnnl_resampling_forward_desc_init(&rd_fwd,
                         dnnl_forward_training, alg, nullptr, &fwd_src_d,
                         &fwd_dst_d),
                WARN);

        status = dnnl_primitive_desc_create(
                &hint_fwd_pd_, &rd_fwd, nullptr, engine, nullptr);
        if (status == dnnl_unimplemented) return res->state = UNIMPLEMENTED, OK;
    }
    auto hint_fwd_pd = make_benchdnn_dnnl_wrapper(hint_fwd_pd_);
    SAFE(status, WARN);

    attr_args_t attr_args;
    attr_args.prepare_binary_post_op_mds(prb->attr, prb->ndims, dst_dims);
    const auto dnnl_attr = make_benchdnn_dnnl_wrapper(
            create_dnnl_attr(prb->attr, attr_args));

    status = dnnl_primitive_desc_create(
            &rpd, &pd, dnnl_attr, engine, hint_fwd_pd);

    if (status == dnnl_unimplemented) return res->state = UNIMPLEMENTED, OK;
    SAFE(status, WARN);

    res->impl_name = query_impl_info(rpd);
    if (maybe_skip(res->impl_name)) {
        BENCHDNN_PRINT(2, "SKIPPED: oneDNN implementation: %s\n",
                res->impl_name.c_str());
        return res->state = SKIPPED, res->reason = SKIP_IMPL_HIT, OK;
    } else {
        BENCHDNN_PRINT(
                5, "oneDNN implementation: %s\n", res->impl_name.c_str());
    }

    return OK;
}

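// Skip cases that are known to be unsupported: common data-type/direction
// checks plus NVIDIA GPU limitations (no 5D problems, no nearest algorithm,
// no attributes, and no s8 data types).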
void check_known_skipped_case(const prb_t *prb, res_t *res) {
    check_known_skipped_case_common({prb->sdt, prb->ddt}, prb->dir, res);

    if (res->state == SKIPPED) return;

    if (is_nvidia_gpu()) {
        const bool dt_ok = prb->sdt != dnnl_s8 && prb->ddt != dnnl_s8;
        if (prb->ndims == 5 || prb->alg == nearest || !prb->attr.is_def()
                || !dt_ok) {
            res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
            return;
        }
    }
}

/* The following issue can occur for integer data types:
 * the kernel and the reference implementation may perform the same
 * computation with a different order of operations. This matters most when
 * the destination data type is an integer: the algorithm is computed in
 * floating point, and if the intermediate result lands very close to x.5,
 * the reference and the kernel may round in opposite directions.
 * Therefore, two values are considered equal when:
 * - the output data type is an integer type,
 * - the floating-point result of the algorithm is very close to x.5,
 * - and the reference and the obtained values differ by exactly 1
 *   (one rounded up, the other rounded down). */
void add_additional_check_to_compare(compare::compare_t &cmp) {
    using cmp_args_t = compare::compare_t::driver_check_func_args_t;
    cmp.set_driver_check_function([&](const cmp_args_t &args) -> bool {
        if (!is_integral_dt(args.dt)) return false;
        // Check that original value is close to x.5f
        static constexpr float small_eps = 9e-6;
        if (fabsf((floorf(args.exp_f32) + 0.5f) - args.exp_f32) >= small_eps)
            return false;
        // If it was, check that exp and got values reside on opposite sides of it.
        if (args.exp == floorf(args.exp_f32))
            return args.got == ceilf(args.exp_f32);
        else if (args.exp == ceilf(args.exp_f32))
            return args.got == floorf(args.exp_f32);
        else {
            assert(!"unexpected scenario");
            return false;
        }
    });
}

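// Main driver entry point: create the primitive, fill the inputs, execute,
// and, in correctness mode, compare the result against the reference
// implementation before measuring performance.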
int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    check_known_skipped_case(prb, res);
    if (res->state == SKIPPED) return OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    SAFE(init_prim(prim, init_pd, prb, res), WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    const_dnnl_primitive_desc_t const_pd;
    DNN_SAFE(dnnl_primitive_get_primitive_desc(prim, &const_pd), CRIT);

    if (check_mem_size(const_pd) != OK) {
        return res->state = SKIPPED, res->reason = NOT_ENOUGH_RAM, OK;
    }

    const auto q = [&](int index = 0) -> const dnnl_memory_desc_t & {
        return *dnnl_primitive_desc_query_md(
                const_pd, dnnl_query_exec_arg_md, index);
    };

    const auto &src_md
            = prb->dir == BWD_D ? q(DNNL_ARG_DIFF_SRC) : q(DNNL_ARG_SRC);
    const auto &dst_md
            = prb->dir == BWD_D ? q(DNNL_ARG_DIFF_DST) : q(DNNL_ARG_DST);
    const auto &scratchpad_md = q(DNNL_ARG_SCRATCHPAD);

    const auto fp = dnnl_f32;
    const auto tag = tag::abx;

    const auto &test_engine = get_test_engine();

    dnn_mem_t src_fp(src_md, fp, tag, test_engine);
    dnn_mem_t src_dt(src_md, test_engine);

    dnn_mem_t dst_fp(dst_md, fp, tag, test_engine);
    dnn_mem_t dst_dt(dst_md, test_engine);
    if (prb->attr.post_ops.find(attr_t::post_ops_t::kind_t::SUM) >= 0)
        SAFE(fill_dst(prb, dst_dt, dst_fp, res), WARN);

    std::vector<dnn_mem_t> binary_po_fp, binary_po_dt;
    std::vector<int> binary_po_args;
    // Binary post-ops can change the relative difference between the
    // reference output and the kernel output, and the compare function
    // relies on that relative difference. To avoid situations where the
    // compared values become very small after a post-op is applied, all
    // values for binary post-ops are kept positive when the linear
    // algorithm is used, since that algorithm may already produce small
    // differences between the expected and the obtained values.
    const bool only_positive_values = prb->alg == linear;
    SAFE(binary::setup_binary_po(const_pd, binary_po_args, binary_po_dt,
                 binary_po_fp, only_positive_values),
            WARN);

    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);

    args_t args;

    compare::compare_t cmp;
    const bool operations_order_can_be_different = prb->alg == linear;
    if (operations_order_can_be_different) add_additional_check_to_compare(cmp);

    if (prb->dir & FLAG_FWD) {
        SAFE(fill_src(prb, src_dt, src_fp, res), WARN);
        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_DST, dst_dt);
        args.set(binary_po_args, binary_po_dt);

        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args), WARN);

        if (is_bench_mode(CORR)) {
            compute_ref_fwd(prb, src_fp, dst_fp, binary_po_fp);
            float trh = prb->alg == nearest ? 0.f : 3 * epsilon_dt(prb->ddt);

            if (is_nvidia_gpu()) {
                // cuDNN precision differs from the reference due to a
                // different computation algorithm used for resampling.
                trh = prb->ddt == dnnl_f16 ? 4e-2 : 2e-5;
            }

            cmp.set_threshold(trh);
            // No sense in testing zero trust for upsampling since it
            // produces valid zeros.
            // TODO: validate this once again.
            cmp.set_zero_trust_percent(100.f);
            SAFE(cmp.compare(dst_fp, dst_dt, prb->attr, res), WARN);
        }
    } else {
        SAFE(fill_dst(prb, dst_dt, dst_fp, res), WARN);
        args.set(DNNL_ARG_DIFF_DST, dst_dt);
        args.set(DNNL_ARG_DIFF_SRC, src_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args), WARN);

        if (is_bench_mode(CORR)) {
            compute_ref_bwd(prb, src_fp, dst_fp);
            float trh = prb->alg == nearest ? 0.f : 6 * epsilon_dt(prb->sdt);
            // cuDNN precision differs from the reference due to a different
            // computation algorithm used for resampling.
            if (is_nvidia_gpu()) trh = 2e-5;

            cmp.set_threshold(trh);
            // No sense in testing zero trust for upsampling since it
            // produces valid zeros.
            // TODO: validate this once again.
            cmp.set_zero_trust_percent(100.f);
            SAFE(cmp.compare(src_fp, src_dt, prb->attr, res), WARN);
        }
    }

    return measure_perf(res->timer, prim, args);
}

} // namespace resampling