1 /*******************************************************************************
2 * Copyright 2018-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <float.h>
18 #include <math.h>
19 #include <stdio.h>
20 #include <stdlib.h>
21
22 #include "oneapi/dnnl/dnnl.h"
23
24 #include "tests/test_thread.hpp"
25
26 #include "dnnl_common.hpp"
27 #include "dnnl_memory.hpp"
28
29 #include "norm.hpp"
30
31 #include "binary/binary.hpp"
32 #include "conv/deconv.hpp"
33 using namespace conv;
34
35 namespace deconv {
36
// Exchanges the values stored in `a` and `b`. Safe when both references name
// the same object (the value is simply left unchanged).
inline static void swap(int64_t &a, int64_t &b) {
    const int64_t old_b = b;
    b = a;
    a = old_b;
}
42
transpose_data_wei(const prb_t * prb,const dnn_mem_t & wei,const dnn_mem_t & wei_tr)43 int transpose_data_wei(
44 const prb_t *prb, const dnn_mem_t &wei, const dnn_mem_t &wei_tr) {
45 dnnl::impl::parallel_nd(prb->g, prb->oc / prb->g, prb->ic / prb->g, prb->kd,
46 prb->kh, prb->kw,
47 [&](int64_t g, int64_t oc, int64_t ic, int64_t kd, int64_t kh,
48 int64_t kw) {
49 int64_t ch_idx
50 = (g * prb->ic / prb->g + ic) * prb->oc / prb->g + oc;
51 int64_t idx = ((ch_idx * prb->kd + kd) * prb->kh + kh) * prb->kw
52 + kw;
53 ((float *)wei_tr)[idx]
54 = ((float *)wei)[wei_off_f(prb, g, oc, ic, kd, kh, kw)];
55 });
56
57 return OK;
58 }
59
init_pd(dnnl_engine_t engine,const prb_t * prb,dnnl_primitive_desc_t & dpd,res_t * res,dir_t dir,const_dnnl_primitive_desc_t hint)60 static int init_pd(dnnl_engine_t engine, const prb_t *prb,
61 dnnl_primitive_desc_t &dpd, res_t *res, dir_t dir,
62 const_dnnl_primitive_desc_t hint) {
63 dnnl_deconvolution_desc_t cd;
64 dnnl_memory_desc_t src_d, wei_d, bia_d, dst_d;
65
66 dnnl_dims_t src_1d_dims = {prb->mb, prb->ic, prb->iw};
67 dnnl_dims_t src_2d_dims = {prb->mb, prb->ic, prb->ih, prb->iw};
68 dnnl_dims_t src_3d_dims = {prb->mb, prb->ic, prb->id, prb->ih, prb->iw};
69 dnnl_dim_t *src_dims = prb->ndims == 5
70 ? src_3d_dims
71 : prb->ndims == 4 ? src_2d_dims : src_1d_dims;
72
73 dnnl_dims_t wei_1d_dims
74 = {prb->g, prb->oc / prb->g, prb->ic / prb->g, prb->kw};
75 dnnl_dims_t wei_2d_dims
76 = {prb->g, prb->oc / prb->g, prb->ic / prb->g, prb->kh, prb->kw};
77 dnnl_dims_t wei_3d_dims = {prb->g, prb->oc / prb->g, prb->ic / prb->g,
78 prb->kd, prb->kh, prb->kw};
79 dnnl_dim_t *wei_dims = prb->ndims == 5
80 ? &wei_3d_dims[!prb->has_groups]
81 : prb->ndims == 4 ? &wei_2d_dims[!prb->has_groups]
82 : &wei_1d_dims[!prb->has_groups];
83
84 dnnl_dims_t bia_dims = {prb->oc};
85
86 dnnl_dims_t dst_1d_dims = {prb->mb, prb->oc, prb->ow};
87 dnnl_dims_t dst_2d_dims = {prb->mb, prb->oc, prb->oh, prb->ow};
88 dnnl_dims_t dst_3d_dims = {prb->mb, prb->oc, prb->od, prb->oh, prb->ow};
89 dnnl_dim_t *dst_dims = prb->ndims == 5
90 ? dst_3d_dims
91 : prb->ndims == 4 ? dst_2d_dims : dst_1d_dims;
92
93 SAFE(init_md(&src_d, prb->ndims, src_dims, prb->cfg[SRC].dt, prb->stag),
94 CRIT);
95 SAFE(init_md(&wei_d, prb->ndims + prb->has_groups, wei_dims,
96 prb->cfg[WEI].dt, prb->wtag),
97 CRIT);
98 DNN_SAFE(dnnl_memory_desc_init_by_tag(&bia_d, 1, bia_dims, prb->cfg[BIA].dt,
99 dnnl_format_tag_any),
100 WARN);
101 SAFE(init_md(&dst_d, prb->ndims, dst_dims, prb->cfg[DST].dt, prb->dtag),
102 CRIT);
103
104 dnnl_dim_t strides_nd[] = {prb->sd, prb->sh, prb->sw};
105 dnnl_dim_t dilates_nd[] = {prb->dd, prb->dh, prb->dw};
106 dnnl_dim_t padding_nd[] = {prb->pd, prb->ph, prb->pw};
107 dnnl_dim_t padding_r_nd[] = {prb->pd_r, prb->ph_r, prb->pw_r};
108
109 dnnl_dim_t *strides = strides_nd + (5 - prb->ndims);
110 dnnl_dim_t *dilates = dilates_nd + (5 - prb->ndims);
111 dnnl_dim_t *padding = padding_nd + (5 - prb->ndims);
112 dnnl_dim_t *padding_r = padding_r_nd + (5 - prb->ndims);
113
114 dnnl_alg_kind_t alg = dnnl_deconvolution_direct;
115 if (prb->alg == WINO) alg = dnnl_deconvolution_winograd;
116
117 switch (prb->dir) {
118 case FWD_D:
119 case FWD_B:
120 case FWD_I:
121 DNN_SAFE(dnnl_dilated_deconvolution_forward_desc_init(&cd,
122 prb->dir == FWD_I ? dnnl_forward_inference
123 : dnnl_forward_training,
124 alg, &src_d, &wei_d,
125 prb->dir == FWD_B ? &bia_d : nullptr, &dst_d,
126 strides, dilates, padding, padding_r),
127 WARN);
128 break;
129 case BWD_D:
130 DNN_SAFE(dnnl_dilated_deconvolution_backward_data_desc_init(&cd,
131 alg, &src_d, &wei_d, &dst_d, strides, dilates,
132 padding, padding_r),
133 WARN);
134 break;
135 case BWD_W:
136 case BWD_WB:
137 DNN_SAFE(dnnl_dilated_deconvolution_backward_weights_desc_init(&cd,
138 alg, &src_d, &wei_d,
139 prb->dir == BWD_W ? nullptr : &bia_d, &dst_d,
140 strides, dilates, padding, padding_r),
141 WARN);
142 break;
143 default: DNN_SAFE(dnnl_invalid_arguments, CRIT);
144 }
145
146 DNN_SAFE(cd.accum_data_type == prb->cfg[ACC].dt ? dnnl_success
147 : dnnl_unimplemented,
148 CRIT);
149
150 attr_args_t attr_args;
151 attr_args.prepare_output_scales(prb->attr, prb->scales, prb->oc);
152 attr_args.prepare_post_ops_mds(prb->attr, prb->ndims, dst_dims);
153 auto dnnl_attr = make_benchdnn_dnnl_wrapper(
154 create_dnnl_attr(prb->attr, attr_args));
155
156 dnnl_status_t init_status
157 = dnnl_primitive_desc_create(&dpd, &cd, dnnl_attr, engine, nullptr);
158
159 if (!res) return OK;
160
161 if (init_status == dnnl_unimplemented) {
162 return res->state = UNIMPLEMENTED, OK;
163 }
164 SAFE(init_status, WARN);
165
166 res->impl_name = query_impl_info(dpd);
167 if (maybe_skip(res->impl_name)) {
168 BENCHDNN_PRINT(2, "SKIPPED: oneDNN implementation: %s\n",
169 res->impl_name.c_str());
170 return res->state = SKIPPED, res->reason = SKIP_IMPL_HIT, OK;
171 } else {
172 BENCHDNN_PRINT(
173 5, "oneDNN implementation: %s\n", res->impl_name.c_str());
174 }
175
176 SAFE(check_pd_w_and_wo_attr(res, prb->attr, cd), WARN);
177
178 return OK;
179 }
180
init_prim_ref(benchdnn_dnnl_wrapper_t<dnnl_primitive_t> & prim_ref,const prb_t * prb)181 int init_prim_ref(
182 benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &prim_ref, const prb_t *prb) {
183 if (!(is_bench_mode(CORR) && is_gpu() && fast_ref_gpu)) return OK;
184
185 // Create a new copy of prb to avoid potentially corrupting the test by
186 // modifying prb in place.
187 // DIRECT algorithm is used to prevent fallback to the slow benchdnn
188 // reference implementation.
189 auto cpu_attr = prb->attr;
190 update_cpu_ref_attrs(cpu_attr);
191 prb_t prb_cpu {*prb, prb->dir, conf_f32, tag::abx, tag::abx, tag::abx,
192 DIRECT, cpu_attr, prb->mb, prb->is_deconv};
193 dnnl_primitive_desc_t pd_ref_ {};
194 SAFE(init_pd(get_cpu_engine(), &prb_cpu, pd_ref_, nullptr, prb->dir,
195 nullptr),
196 WARN);
197 auto pd_ref = make_benchdnn_dnnl_wrapper(pd_ref_);
198
199 dnnl_primitive_t prim_ref_ {};
200 if (pd_ref) {
201 DNN_SAFE(dnnl_primitive_create(&prim_ref_, pd_ref), WARN);
202 BENCHDNN_PRINT(
203 5, "%s\n", "benchdnn: use CPU primitive as the reference");
204 }
205 prim_ref.reset(prim_ref_);
206 return OK;
207 }
208
check_known_skipped_case(const prb_t * prb,res_t * res)209 void check_known_skipped_case(const prb_t *prb, res_t *res) {
210 check_known_skipped_case_common(
211 {prb->cfg[SRC].dt, prb->cfg[WEI].dt, prb->cfg[DST].dt}, prb->dir,
212 res);
213 if (res->state == SKIPPED) return;
214
215 if (is_nvidia_gpu()) {
216 const int64_t ID = prb->id, IH = prb->ih, IW = prb->iw;
217 const int64_t OD = prb->od, OH = prb->oh, OW = prb->ow;
218 const int64_t KD = prb->kd, KH = prb->kh, KW = prb->kw;
219 const int64_t SD = prb->sd, SH = prb->sh, SW = prb->sw;
220 const int64_t PD = prb->pd, PH = prb->ph, PW = prb->pw;
221 const int64_t PD_R = prb->pd_r, PH_R = prb->ph_r, PW_R = prb->pw_r;
222 const bool pad_ok = PD >= PD_R && PH >= PH_R && PW >= PW_R;
223 // copy-pasted from str2desc, dilation is not supported for Nvidia
224 const auto compute_out
225 = [](int64_t i, int64_t k, int64_t s, int64_t p) {
226 return (i - 1) * s + k - 2 * p;
227 };
228 const bool out_ok = OD == compute_out(ID, KD, SD, PD)
229 && OH == compute_out(IH, KH, SH, PH)
230 && OW == compute_out(IW, KW, SW, PW);
231
232 bool post_ops_ok = prb->attr.post_ops.is_def();
233
234 const auto stag = normalize_tag(prb->stag, prb->ndims);
235 const bool stag_is_axb = stag == normalize_tag(tag::axb, prb->ndims);
236 const bool fwd_tag_ok = !((prb->dir & FLAG_FWD) && stag_is_axb);
237 const bool bwd_tag_ok
238 = !((prb->dir == BWD_W || prb->dir == BWD_WB) && stag_is_axb);
239 const bool tag_ok = fwd_tag_ok && bwd_tag_ok;
240 // TODO: specified wtag (even for supported formats) is not working?
241 if (!pad_ok || !out_ok || !post_ops_ok || !tag_ok) {
242 res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
243 return;
244 }
245
246 // FIXME: there's a bug in the library resulting in
247 // memory_tracking.hpp:458: Assertion `registry_.size() == 0' failed.
248 // For any spatial case, when both BWD_W and BWD_WB are run.
249 // It must be cache interaction, but not clear which side is
250 // guilty. Likely Nvidia implementation. Switch it off until further
251 // investigation.
252 if (prb->dir == BWD_WB) {
253 res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
254 return;
255 }
256 }
257
258 // GPU:
259 // * BWD: doesn't support any attributes
260 // * FWD: support only post ops and all but x8s8bf16 cfg
261 if (is_gpu()) {
262 const bool only_non_default_post_ops = prb->attr.oscale.is_def()
263 && prb->attr.scales.is_def() && prb->attr.zero_points.is_def();
264 const bool is_x8s8bf16_cfg
265 = prb->cfg[WEI].dt == dnnl_s8 && prb->cfg[DST].dt == dnnl_bf16;
266 const bool fwd_ok = !is_x8s8bf16_cfg
267 && IMPLICATION(
268 (prb->dir & FLAG_FWD), only_non_default_post_ops);
269 const bool bwd_ok
270 = IMPLICATION((prb->dir & FLAG_BWD), prb->attr.is_def());
271 if (!fwd_ok || !bwd_ok) {
272 res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
273 return;
274 }
275 }
276 }
277
doit(const prb_t * prb,res_t * res)278 int doit(const prb_t *prb, res_t *res) {
279 if (bench_mode == LIST) return res->state = LISTED, OK;
280
281 check_known_skipped_case(prb, res);
282 check_sum_post_ops(prb->attr, res);
283 if (res->state == SKIPPED) return OK;
284
285 benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
286 SAFE(init_prim(prim, init_pd, prb, res), WARN);
287 if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;
288
289 const_dnnl_primitive_desc_t const_pd;
290 DNN_SAFE(dnnl_primitive_get_primitive_desc(prim, &const_pd), CRIT);
291
292 if (check_mem_size(const_pd) != OK) {
293 return res->state = SKIPPED, res->reason = NOT_ENOUGH_RAM, OK;
294 }
295
296 const auto q = [&](int index = 0) -> const dnnl_memory_desc_t & {
297 return *dnnl_primitive_desc_query_md(
298 const_pd, dnnl_query_exec_arg_md, index);
299 };
300
301 const auto &src_md
302 = prb->dir == BWD_D ? q(DNNL_ARG_DIFF_SRC) : q(DNNL_ARG_SRC);
303 const auto &wei_md = prb->dir & FLAG_WEI ? q(DNNL_ARG_DIFF_WEIGHTS)
304 : q(DNNL_ARG_WEIGHTS);
305 const auto &bia_md
306 = prb->dir & FLAG_WEI ? q(DNNL_ARG_DIFF_BIAS) : q(DNNL_ARG_BIAS);
307 const auto &dst_md
308 = prb->dir & FLAG_BWD ? q(DNNL_ARG_DIFF_DST) : q(DNNL_ARG_DST);
309 const auto &scratchpad_md = q(DNNL_ARG_SCRATCHPAD);
310 auto wei_tr_md = wei_md;
311
312 const bool with_groups = true;
313 swap(wei_tr_md.dims[with_groups + 0], wei_tr_md.dims[with_groups + 1]);
314
315 const auto fp = dnnl_f32;
316 const auto src_tag = tag::abx;
317 const auto wei_tag = tag::abx;
318
319 // Use CPU prim as the reference in GPU testing to reduce testing time.
320 benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim_ref;
321 SAFE(init_prim_ref(prim_ref, prb), WARN);
322
323 const auto &test_engine = get_test_engine();
324 const auto &ref_engine = prim_ref ? get_cpu_engine() : get_test_engine();
325
326 dnn_mem_t src_dt(src_md, test_engine);
327 dnn_mem_t wei_dt(wei_md, test_engine);
328 dnn_mem_t dst_dt(dst_md, test_engine);
329 dnn_mem_t bia_dt(bia_md, test_engine);
330 dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);
331 std::vector<dnn_mem_t> binary_po_fp, binary_po_dt;
332 std::vector<int> binary_po_args;
333 SAFE(binary::setup_binary_po(const_pd, binary_po_args, binary_po_dt,
334 binary_po_fp, ref_engine),
335 WARN);
336
337 dnn_mem_t src_fp(src_md, fp, src_tag, ref_engine);
338 dnn_mem_t wei_fp(wei_md, fp, wei_tag, ref_engine);
339 dnn_mem_t dst_fp(dst_md, fp, src_tag, ref_engine);
340 dnn_mem_t wei_tr_fp(wei_tr_md, fp, wei_tag, ref_engine);
341 dnn_mem_t bia_fp(bia_md, fp, tag::x, ref_engine);
342 dnn_mem_t scratchpad_fp(scratchpad_md, ref_engine);
343 dnn_mem_t src_zero_points_m;
344 dnn_mem_t dst_zero_points_m;
345
346 /* fill memory + reorders <-> */
347 if (need_dst_init(prb)) SAFE(fill_dst(prb, dst_dt, dst_fp, res), WARN);
348 if (need_src_init(prb)) SAFE(fill_src(prb, src_dt, src_fp, res), WARN);
349 if (need_wei_init(prb)) {
350 SAFE(fill_wei(prb, wei_dt, wei_fp, res), WARN);
351 SAFE(transpose_data_wei(prb, wei_fp, wei_tr_fp), WARN);
352 }
353 if (need_bia_init(prb)) SAFE(fill_bia(prb, bia_dt, bia_fp, res), WARN);
354
355 args_t args, ref_args;
356
357 // Update prb descriptor to re-use convolution reference.
358 prb_t p_tr((desc_t)*prb, prb->dir, prb->cfg, prb->stag, prb->wtag,
359 prb->dtag, prb->alg, prb->attr, prb->mb, true);
360 swap(p_tr.ic, p_tr.oc);
361 swap(p_tr.ih, p_tr.oh);
362 swap(p_tr.id, p_tr.od);
363 swap(p_tr.iw, p_tr.ow);
364
365 if (prb->dir & FLAG_FWD) {
366 maybe_prepare_runtime_zero_points(src_zero_points_m, prb->attr,
367 DNNL_ARG_SRC, prb->ic, prb->src_zp);
368 maybe_prepare_runtime_zero_points(dst_zero_points_m, prb->attr,
369 DNNL_ARG_DST, prb->oc, prb->dst_zp);
370
371 args.set(DNNL_ARG_SRC, src_dt);
372 args.set(DNNL_ARG_WEIGHTS, wei_dt);
373 args.set(DNNL_ARG_BIAS, bia_dt);
374 args.set(DNNL_ARG_DST, dst_dt);
375 args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
376 args.set(binary_po_args, binary_po_dt);
377 args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zero_points_m);
378 args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zero_points_m);
379
380 SAFE(execute_and_wait(prim, args), WARN);
381
382 if (is_bench_mode(CORR)) {
383 ref_args.set(DNNL_ARG_SRC, src_fp);
384 ref_args.set(DNNL_ARG_WEIGHTS, wei_fp);
385 ref_args.set(DNNL_ARG_BIAS, bia_fp);
386 ref_args.set(DNNL_ARG_DST, dst_fp);
387 ref_args.set(DNNL_ARG_DIFF_WEIGHTS, wei_tr_fp); // Hack. See ref.
388 ref_args.set(DNNL_ARG_SCRATCHPAD, scratchpad_fp);
389 ref_args.set(binary_po_args, binary_po_fp);
390
391 TIME_REF(deconv::compute_ref_fwd(&p_tr, prim_ref, ref_args));
392 SAFE(compare_data(prb, DST, dst_dt, dst_fp, res), WARN);
393 }
394 } else if (prb->dir == BWD_D) {
395 args.set(DNNL_ARG_DIFF_DST, dst_dt);
396 args.set(DNNL_ARG_WEIGHTS, wei_dt);
397 args.set(DNNL_ARG_DIFF_SRC, src_dt);
398 args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
399
400 SAFE(execute_and_wait(prim, args), WARN);
401
402 if (is_bench_mode(CORR)) {
403 ref_args.set(DNNL_ARG_DIFF_SRC, src_fp);
404 ref_args.set(DNNL_ARG_WEIGHTS, wei_fp);
405 ref_args.set(DNNL_ARG_DIFF_DST, dst_fp);
406 ref_args.set(DNNL_ARG_DIFF_WEIGHTS, wei_tr_fp); // Hack. See ref.
407 ref_args.set(DNNL_ARG_SCRATCHPAD, scratchpad_fp);
408
409 TIME_REF(deconv::compute_ref_bwd_d(&p_tr, prim_ref, ref_args));
410 SAFE(compare_data(prb, SRC, src_dt, src_fp, res), WARN);
411 }
412 } else if (prb->dir & FLAG_BWD && prb->dir & FLAG_WEI) {
413 args.set(DNNL_ARG_SRC, src_dt);
414 args.set(DNNL_ARG_DIFF_DST, dst_dt);
415 args.set(DNNL_ARG_DIFF_WEIGHTS, wei_dt);
416 args.set(DNNL_ARG_DIFF_BIAS, bia_dt);
417 args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
418
419 SAFE(execute_and_wait(prim, args), WARN);
420
421 if (is_bench_mode(CORR)) {
422 ref_args.set(DNNL_ARG_SRC, src_fp);
423 ref_args.set(DNNL_ARG_WEIGHTS, wei_tr_fp); // Hack. See ref.
424 ref_args.set(DNNL_ARG_DIFF_DST, dst_fp);
425 ref_args.set(DNNL_ARG_DIFF_WEIGHTS, wei_fp);
426 ref_args.set(DNNL_ARG_DIFF_BIAS, bia_fp);
427 ref_args.set(DNNL_ARG_SCRATCHPAD, scratchpad_fp);
428
429 TIME_REF(deconv::compute_ref_bwd_w(&p_tr, prim_ref, ref_args));
430 SAFE(compare_data(&p_tr, WEI, wei_dt, wei_fp, res), WARN);
431 if (prb->dir & FLAG_BIA)
432 SAFE(compare_data(prb, BIA, bia_dt, bia_fp, res), WARN);
433 }
434 } else {
435 SAFE(FAIL, CRIT);
436 }
437
438 return measure_perf(res, prim, args);
439 }
440
441 } // namespace deconv
442