1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <stdio.h>
18 #include <stdlib.h>
19
20 #include <sstream>
21
22 #include "oneapi/dnnl/dnnl.h"
23
24 #include "tests/test_thread.hpp"
25
26 #include "compare.hpp"
27 #include "dnnl_common.hpp"
28 #include "dnnl_memory.hpp"
29
30 #include "binary/binary.hpp"
31 #include "resampling/resampling.hpp"
32
33 namespace resampling {
34
fill_dat(const prb_t * prb,data_kind_t kind,dnn_mem_t & mem_dt,dnn_mem_t & mem_fp,res_t * res)35 int fill_dat(const prb_t *prb, data_kind_t kind, dnn_mem_t &mem_dt,
36 dnn_mem_t &mem_fp, res_t *res) {
37 const auto nelems = mem_fp.nelems();
38 const auto dt = mem_dt.dt();
39 const int range = 16;
40 const int f_min = 0;
41
42 dnnl::impl::parallel_nd(nelems, [&](int64_t i) {
43 const float gen = ((97 * i) - 19 * kind + 101) % (range + 1);
44 const float value = dt == dnnl_f32 || is_integral_dt(dt)
45 ? (f_min + gen) * (1.0f + 4.0f / range)
46 : (f_min + gen) / range;
47
48 mem_fp.set_elem(i, round_to_nearest_representable(dt, value));
49 });
50
51 SAFE(mem_dt.reorder(mem_fp), WARN);
52
53 return OK;
54 }
55
fill_src(const prb_t * prb,dnn_mem_t & mem_dt,dnn_mem_t & mem_fp,res_t * res)56 int fill_src(
57 const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) {
58 return fill_dat(prb, SRC, mem_dt, mem_fp, res);
59 }
60
fill_dst(const prb_t * prb,dnn_mem_t & mem_dt,dnn_mem_t & mem_fp,res_t * res)61 int fill_dst(
62 const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) {
63 return fill_dat(prb, DST, mem_dt, mem_fp, res);
64 }
65
init_pd(dnnl_engine_t engine,const prb_t * prb,dnnl_primitive_desc_t & rpd,res_t * res,dir_t dir,const_dnnl_primitive_desc_t hint)66 static int init_pd(dnnl_engine_t engine, const prb_t *prb,
67 dnnl_primitive_desc_t &rpd, res_t *res, dir_t dir,
68 const_dnnl_primitive_desc_t hint) {
69 dnnl_memory_desc_t src_d, dst_d;
70
71 dnnl_dims_t src_1d_dims = {prb->mb, prb->ic, prb->iw};
72 dnnl_dims_t src_2d_dims = {prb->mb, prb->ic, prb->ih, prb->iw};
73 dnnl_dims_t src_3d_dims = {prb->mb, prb->ic, prb->id, prb->ih, prb->iw};
74 dnnl_dim_t *src_dims = prb->ndims == 5
75 ? src_3d_dims
76 : prb->ndims == 4 ? src_2d_dims : src_1d_dims;
77
78 dnnl_dims_t dst_1d_dims = {prb->mb, prb->ic, prb->ow};
79 dnnl_dims_t dst_2d_dims = {prb->mb, prb->ic, prb->oh, prb->ow};
80 dnnl_dims_t dst_3d_dims = {prb->mb, prb->ic, prb->od, prb->oh, prb->ow};
81 dnnl_dim_t *dst_dims = prb->ndims == 5
82 ? dst_3d_dims
83 : prb->ndims == 4 ? dst_2d_dims : dst_1d_dims;
84
85 std::string src_tag = (prb->dir & FLAG_FWD) ? prb->tag : tag::any;
86 std::string dst_tag = (prb->dir & FLAG_BWD) ? prb->tag : tag::any;
87
88 SAFE(init_md(&src_d, prb->ndims, src_dims, prb->sdt, src_tag), CRIT);
89
90 SAFE(init_md(&dst_d, prb->ndims, dst_dims, prb->ddt, dst_tag), CRIT);
91
92 dnnl_alg_kind_t alg = alg2alg_kind(prb->alg);
93 dnnl_resampling_desc_t pd;
94
95 if (prb->dir & FLAG_FWD) {
96 auto prop_kind = prb->dir & FLAG_INF ? dnnl_forward_inference
97 : dnnl_forward_training;
98 DNN_SAFE(dnnl_resampling_forward_desc_init(
99 &pd, prop_kind, alg, nullptr, &src_d, &dst_d),
100 WARN);
101 } else {
102 DNN_SAFE(dnnl_resampling_backward_desc_init(
103 &pd, alg, nullptr, &src_d, &dst_d),
104 WARN);
105 }
106
107 dnnl_primitive_desc_t hint_fwd_pd_ {};
108 dnnl_status_t status = dnnl_success;
109 if (prb->dir & FLAG_BWD) {
110 dnnl_memory_desc_t fwd_src_d, fwd_dst_d;
111 SAFE(init_md(&fwd_src_d, prb->ndims, src_dims, prb->sdt, prb->tag),
112 CRIT);
113 SAFE(init_md(&fwd_dst_d, prb->ndims, dst_dims, prb->ddt, tag::any),
114 CRIT);
115
116 dnnl_resampling_desc_t rd_fwd;
117 DNN_SAFE(dnnl_resampling_forward_desc_init(&rd_fwd,
118 dnnl_forward_training, alg, nullptr, &fwd_src_d,
119 &fwd_dst_d),
120 WARN);
121
122 status = dnnl_primitive_desc_create(
123 &hint_fwd_pd_, &rd_fwd, nullptr, engine, nullptr);
124 if (status == dnnl_unimplemented) return res->state = UNIMPLEMENTED, OK;
125 }
126 auto hint_fwd_pd = make_benchdnn_dnnl_wrapper(hint_fwd_pd_);
127 SAFE(status, WARN);
128
129 attr_args_t attr_args;
130 attr_args.prepare_binary_post_op_mds(prb->attr, prb->ndims, dst_dims);
131 const auto dnnl_attr = make_benchdnn_dnnl_wrapper(
132 create_dnnl_attr(prb->attr, attr_args));
133
134 status = dnnl_primitive_desc_create(
135 &rpd, &pd, dnnl_attr, engine, hint_fwd_pd);
136
137 if (status == dnnl_unimplemented) return res->state = UNIMPLEMENTED, OK;
138 SAFE(status, WARN);
139
140 res->impl_name = query_impl_info(rpd);
141 if (maybe_skip(res->impl_name)) {
142 BENCHDNN_PRINT(2, "SKIPPED: oneDNN implementation: %s\n",
143 res->impl_name.c_str());
144 return res->state = SKIPPED, res->reason = SKIP_IMPL_HIT, OK;
145 } else {
146 BENCHDNN_PRINT(
147 5, "oneDNN implementation: %s\n", res->impl_name.c_str());
148 }
149
150 return OK;
151 }
152
check_known_skipped_case(const prb_t * prb,res_t * res)153 void check_known_skipped_case(const prb_t *prb, res_t *res) {
154 check_known_skipped_case_common({prb->sdt, prb->ddt}, prb->dir, res);
155
156 if (res->state == SKIPPED) return;
157
158 if (is_nvidia_gpu()) {
159 const bool dt_ok = prb->sdt != dnnl_s8 && prb->ddt != dnnl_s8;
160 if (prb->ndims == 5 || prb->alg == nearest || !prb->attr.is_def()
161 || !dt_ok) {
162 res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
163 return;
164 }
165 }
166 }
167
168 /* The following issue takes place for integer data types:
169 * Sometimes there are differences in the order of operations between
170 * the version of the algorithm implemented in the kernel and the reference
171 * algorithm. Therefore, this function is especially important if the
172 * destination data type is an integer, because when the floating-point
173 * type is used to compute the algorithm and if the returned value is very
174 * close to x.5, there may be a difference between the output value of
175 * reference and the kernel, as one version may round up and the other down.
176 * Therefore, we can assume that two values are equal to each other when:
177 * - there is a difference in the order of operations,
178 * - and the output value of the algorithm is very close to x.5,
179 * - and the difference between the output value of reference and expected is 1,
180 * - and the output type is an integer type */
add_additional_check_to_compare(compare::compare_t & cmp)181 void add_additional_check_to_compare(compare::compare_t &cmp) {
182 using cmp_args_t = compare::compare_t::driver_check_func_args_t;
183 cmp.set_driver_check_function([&](const cmp_args_t &args) -> bool {
184 if (!is_integral_dt(args.dt)) return false;
185 // Check that original value is close to x.5f
186 static constexpr float small_eps = 9e-6;
187 if (fabsf((floorf(args.exp_f32) + 0.5f) - args.exp_f32) >= small_eps)
188 return false;
189 // If it was, check that exp and got values reside on opposite sides of it.
190 if (args.exp == floorf(args.exp_f32))
191 return args.got == ceilf(args.exp_f32);
192 else if (args.exp == ceilf(args.exp_f32))
193 return args.got == floorf(args.exp_f32);
194 else {
195 assert(!"unexpected scenario");
196 return false;
197 }
198 });
199 }
200
// Top-level driver entry for one resampling problem: creates the primitive,
// fills the inputs, executes, optionally validates against the reference
// implementation (CORR mode), and finally measures performance.
int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    check_known_skipped_case(prb, res);
    if (res->state == SKIPPED) return OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    SAFE(init_prim(prim, init_pd, prb, res), WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    const_dnnl_primitive_desc_t const_pd;
    DNN_SAFE(dnnl_primitive_get_primitive_desc(prim, &const_pd), CRIT);

    // Skip the case instead of failing when memory requirements are too big.
    if (check_mem_size(const_pd) != OK) {
        return res->state = SKIPPED, res->reason = NOT_ENOUGH_RAM, OK;
    }

    // Shortcut to query memory descriptors of execution arguments.
    const auto q = [&](int index = 0) -> const dnnl_memory_desc_t & {
        return *dnnl_primitive_desc_query_md(
                const_pd, dnnl_query_exec_arg_md, index);
    };

    // Backward works on diff tensors; forward on plain src/dst.
    const auto &src_md
            = prb->dir == BWD_D ? q(DNNL_ARG_DIFF_SRC) : q(DNNL_ARG_SRC);
    const auto &dst_md
            = prb->dir == BWD_D ? q(DNNL_ARG_DIFF_DST) : q(DNNL_ARG_DST);
    const auto &scratchpad_md = q(DNNL_ARG_SCRATCHPAD);

    // Reference tensors are kept in plain f32 abx format.
    const auto fp = dnnl_f32;
    const auto tag = tag::abx;

    const auto &test_engine = get_test_engine();

    dnn_mem_t src_fp(src_md, fp, tag, test_engine);
    dnn_mem_t src_dt(src_md, test_engine);

    dnn_mem_t dst_fp(dst_md, fp, tag, test_engine);
    dnn_mem_t dst_dt(dst_md, test_engine);
    // A sum post-op reads the destination, so it must be pre-filled.
    if (prb->attr.post_ops.find(attr_t::post_ops_t::kind_t::SUM) >= 0)
        SAFE(fill_dst(prb, dst_dt, dst_fp, res), WARN);

    std::vector<dnn_mem_t> binary_po_fp, binary_po_dt;
    std::vector<int> binary_po_args;
    // When post-ops are present, the relative difference between the
    // reference output and the kernel output can change, and the compare
    // function relies on a relative difference. Therefore, we must avoid
    // situations where the relative difference becomes very small after a
    // post-op executes. Hence all binary post-op values are kept positive
    // when the linear algorithm is used, since that algorithm may produce
    // small differences between the expected and the obtained values.
    const bool only_positive_values = prb->alg == linear;
    SAFE(binary::setup_binary_po(const_pd, binary_po_args, binary_po_dt,
                 binary_po_fp, only_positive_values),
            WARN);

    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);

    args_t args;

    compare::compare_t cmp;
    // The linear algorithm may apply operations in a different order than the
    // reference; allow off-by-one rounding on integer outputs (see above).
    const bool operations_order_can_be_different = prb->alg == linear;
    if (operations_order_can_be_different) add_additional_check_to_compare(cmp);

    if (prb->dir & FLAG_FWD) {
        SAFE(fill_src(prb, src_dt, src_fp, res), WARN);
        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_DST, dst_dt);
        args.set(binary_po_args, binary_po_dt);

        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args), WARN);

        if (is_bench_mode(CORR)) {
            compute_ref_fwd(prb, src_fp, dst_fp, binary_po_fp);
            // Nearest is exact; linear accumulates rounding error.
            float trh = prb->alg == nearest ? 0.f : 3 * epsilon_dt(prb->ddt);

            if (is_nvidia_gpu()) {
                // cuDNN precision is different from ref one due to different
                // computation algorithm used for resampling.
                trh = prb->ddt == dnnl_f16 ? 4e-2 : 2e-5;
            }

            cmp.set_threshold(trh);
            // No sense to test zero trust for upsampling since it produces
            // valid zeros.
            // TODO: validate this once again.
            cmp.set_zero_trust_percent(100.f);
            SAFE(cmp.compare(dst_fp, dst_dt, prb->attr, res), WARN);
        }
    } else {
        SAFE(fill_dst(prb, dst_dt, dst_fp, res), WARN);
        args.set(DNNL_ARG_DIFF_DST, dst_dt);
        args.set(DNNL_ARG_DIFF_SRC, src_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args), WARN);

        if (is_bench_mode(CORR)) {
            compute_ref_bwd(prb, src_fp, dst_fp);
            float trh = prb->alg == nearest ? 0.f : 6 * epsilon_dt(prb->sdt);
            // cuDNN precision is different from ref one due to different
            // computation algorithm used for resampling.
            if (is_nvidia_gpu()) trh = 2e-5;

            cmp.set_threshold(trh);
            // No sense to test zero trust for upsampling since it produces
            // valid zeros.
            // TODO: validate this once again.
            cmp.set_zero_trust_percent(100.f);
            SAFE(cmp.compare(src_fp, src_dt, prb->attr, res), WARN);
        }
    }

    return measure_perf(res->timer, prim, args);
}
319
320 } // namespace resampling
321