/*******************************************************************************
* Copyright 2018-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <float.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "oneapi/dnnl/dnnl.h"

#include "tests/test_thread.hpp"

#include "dnnl_common.hpp"
#include "dnnl_memory.hpp"

#include "norm.hpp"

#include "binary/binary.hpp"
#include "conv/deconv.hpp"
using namespace conv;

namespace deconv {

inline static void swap(int64_t &a, int64_t &b) {
    int64_t temp = a;
    a = b;
    b = temp;
}

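// Copy deconvolution weights laid out as (g, oc, ic, kd, kh, kw) into
// (g, ic, oc, kd, kh, kw) order. The convolution reference used for checking
// treats deconvolution as a convolution with src/dst (and hence ic/oc)
// swapped, so it needs the weights with the two channel dims transposed.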
int transpose_data_wei(
        const prb_t *prb, const dnn_mem_t &wei, const dnn_mem_t &wei_tr) {
    dnnl::impl::parallel_nd(prb->g, prb->oc / prb->g, prb->ic / prb->g, prb->kd,
            prb->kh, prb->kw,
            [&](int64_t g, int64_t oc, int64_t ic, int64_t kd, int64_t kh,
                    int64_t kw) {
                int64_t ch_idx
                        = (g * prb->ic / prb->g + ic) * prb->oc / prb->g + oc;
                int64_t idx = ((ch_idx * prb->kd + kd) * prb->kh + kh) * prb->kw
                        + kw;
                ((float *)wei_tr)[idx]
                        = ((float *)wei)[wei_off_f(prb, g, oc, ic, kd, kh, kw)];
            });

    return OK;
}

static int init_pd(dnnl_engine_t engine, const prb_t *prb,
        dnnl_primitive_desc_t &dpd, res_t *res, dir_t dir,
        const_dnnl_primitive_desc_t hint) {
    dnnl_deconvolution_desc_t cd;
    dnnl_memory_desc_t src_d, wei_d, bia_d, dst_d;

    dnnl_dims_t src_1d_dims = {prb->mb, prb->ic, prb->iw};
    dnnl_dims_t src_2d_dims = {prb->mb, prb->ic, prb->ih, prb->iw};
    dnnl_dims_t src_3d_dims = {prb->mb, prb->ic, prb->id, prb->ih, prb->iw};
    dnnl_dim_t *src_dims = prb->ndims == 5
            ? src_3d_dims
            : prb->ndims == 4 ? src_2d_dims : src_1d_dims;

    dnnl_dims_t wei_1d_dims
            = {prb->g, prb->oc / prb->g, prb->ic / prb->g, prb->kw};
    dnnl_dims_t wei_2d_dims
            = {prb->g, prb->oc / prb->g, prb->ic / prb->g, prb->kh, prb->kw};
    dnnl_dims_t wei_3d_dims = {prb->g, prb->oc / prb->g, prb->ic / prb->g,
            prb->kd, prb->kh, prb->kw};
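    // The weights dims above include a leading groups dimension. For a
    // problem without groups, indexing at [!prb->has_groups] skips that first
    // entry, so the memory descriptor is built with plain (oc, ic, ...) dims.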
    dnnl_dim_t *wei_dims = prb->ndims == 5
            ? &wei_3d_dims[!prb->has_groups]
            : prb->ndims == 4 ? &wei_2d_dims[!prb->has_groups]
                              : &wei_1d_dims[!prb->has_groups];

    dnnl_dims_t bia_dims = {prb->oc};

    dnnl_dims_t dst_1d_dims = {prb->mb, prb->oc, prb->ow};
    dnnl_dims_t dst_2d_dims = {prb->mb, prb->oc, prb->oh, prb->ow};
    dnnl_dims_t dst_3d_dims = {prb->mb, prb->oc, prb->od, prb->oh, prb->ow};
    dnnl_dim_t *dst_dims = prb->ndims == 5
            ? dst_3d_dims
            : prb->ndims == 4 ? dst_2d_dims : dst_1d_dims;

    SAFE(init_md(&src_d, prb->ndims, src_dims, prb->cfg[SRC].dt, prb->stag),
            CRIT);
    SAFE(init_md(&wei_d, prb->ndims + prb->has_groups, wei_dims,
                 prb->cfg[WEI].dt, prb->wtag),
            CRIT);
    DNN_SAFE(dnnl_memory_desc_init_by_tag(&bia_d, 1, bia_dims, prb->cfg[BIA].dt,
                     dnnl_format_tag_any),
            WARN);
    SAFE(init_md(&dst_d, prb->ndims, dst_dims, prb->cfg[DST].dt, prb->dtag),
            CRIT);

    dnnl_dim_t strides_nd[] = {prb->sd, prb->sh, prb->sw};
    dnnl_dim_t dilates_nd[] = {prb->dd, prb->dh, prb->dw};
    dnnl_dim_t padding_nd[] = {prb->pd, prb->ph, prb->pw};
    dnnl_dim_t padding_r_nd[] = {prb->pd_r, prb->ph_r, prb->pw_r};

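    // The spatial arrays above are ordered {d, h, w}. Offsetting by
    // (5 - ndims) drops the leading entries for 1D and 2D problems, so the
    // pointers always address exactly the spatial dims that exist.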
    dnnl_dim_t *strides = strides_nd + (5 - prb->ndims);
    dnnl_dim_t *dilates = dilates_nd + (5 - prb->ndims);
    dnnl_dim_t *padding = padding_nd + (5 - prb->ndims);
    dnnl_dim_t *padding_r = padding_r_nd + (5 - prb->ndims);

    dnnl_alg_kind_t alg = dnnl_deconvolution_direct;
    if (prb->alg == WINO) alg = dnnl_deconvolution_winograd;

    switch (prb->dir) {
        case FWD_D:
        case FWD_B:
        case FWD_I:
            DNN_SAFE(dnnl_dilated_deconvolution_forward_desc_init(&cd,
                             prb->dir == FWD_I ? dnnl_forward_inference
                                               : dnnl_forward_training,
                             alg, &src_d, &wei_d,
                             prb->dir == FWD_B ? &bia_d : nullptr, &dst_d,
                             strides, dilates, padding, padding_r),
                    WARN);
            break;
        case BWD_D:
            DNN_SAFE(dnnl_dilated_deconvolution_backward_data_desc_init(&cd,
                             alg, &src_d, &wei_d, &dst_d, strides, dilates,
                             padding, padding_r),
                    WARN);
            break;
        case BWD_W:
        case BWD_WB:
            DNN_SAFE(dnnl_dilated_deconvolution_backward_weights_desc_init(&cd,
                             alg, &src_d, &wei_d,
                             prb->dir == BWD_W ? nullptr : &bia_d, &dst_d,
                             strides, dilates, padding, padding_r),
                    WARN);
            break;
        default: DNN_SAFE(dnnl_invalid_arguments, CRIT);
    }

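    // Check that the accumulation data type chosen by the library matches the
    // one the test config expects; a mismatch is reported as unimplemented.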
    DNN_SAFE(cd.accum_data_type == prb->cfg[ACC].dt ? dnnl_success
                                                    : dnnl_unimplemented,
            CRIT);

    attr_args_t attr_args;
    attr_args.prepare_output_scales(prb->attr, prb->scales, prb->oc);
    attr_args.prepare_post_ops_mds(prb->attr, prb->ndims, dst_dims);
    auto dnnl_attr = make_benchdnn_dnnl_wrapper(
            create_dnnl_attr(prb->attr, attr_args));

    dnnl_status_t init_status
            = dnnl_primitive_desc_create(&dpd, &cd, dnnl_attr, engine, nullptr);

    if (!res) return OK;

    if (init_status == dnnl_unimplemented) {
        return res->state = UNIMPLEMENTED, OK;
    }
    SAFE(init_status, WARN);

    res->impl_name = query_impl_info(dpd);
    if (maybe_skip(res->impl_name)) {
        BENCHDNN_PRINT(2, "SKIPPED: oneDNN implementation: %s\n",
                res->impl_name.c_str());
        return res->state = SKIPPED, res->reason = SKIP_IMPL_HIT, OK;
    } else {
        BENCHDNN_PRINT(
                5, "oneDNN implementation: %s\n", res->impl_name.c_str());
    }

    SAFE(check_pd_w_and_wo_attr(res, prb->attr, cd), WARN);

    return OK;
}

int init_prim_ref(
        benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &prim_ref, const prb_t *prb) {
    if (!(is_bench_mode(CORR) && is_gpu() && fast_ref_gpu)) return OK;

    // Create a new copy of prb to avoid potentially corrupting the test by
    // modifying prb in place.
    // The DIRECT algorithm is used to prevent a fallback to the slow benchdnn
    // reference implementation.
    auto cpu_attr = prb->attr;
    update_cpu_ref_attrs(cpu_attr);
    prb_t prb_cpu {*prb, prb->dir, conf_f32, tag::abx, tag::abx, tag::abx,
            DIRECT, cpu_attr, prb->mb, prb->is_deconv};
    dnnl_primitive_desc_t pd_ref_ {};
    SAFE(init_pd(get_cpu_engine(), &prb_cpu, pd_ref_, nullptr, prb->dir,
                 nullptr),
            WARN);
    auto pd_ref = make_benchdnn_dnnl_wrapper(pd_ref_);

    dnnl_primitive_t prim_ref_ {};
    if (pd_ref) {
        DNN_SAFE(dnnl_primitive_create(&prim_ref_, pd_ref), WARN);
        BENCHDNN_PRINT(
                5, "%s\n", "benchdnn: use CPU primitive as the reference");
    }
    prim_ref.reset(prim_ref_);
    return OK;
}

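// Mark cases known to be unsupported (by data type, vendor backend, or
// attribute combination) as SKIPPED before any memory is allocated or the
// primitive is created.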
void check_known_skipped_case(const prb_t *prb, res_t *res) {
    check_known_skipped_case_common(
            {prb->cfg[SRC].dt, prb->cfg[WEI].dt, prb->cfg[DST].dt}, prb->dir,
            res);
    if (res->state == SKIPPED) return;

    if (is_nvidia_gpu()) {
        const int64_t ID = prb->id, IH = prb->ih, IW = prb->iw;
        const int64_t OD = prb->od, OH = prb->oh, OW = prb->ow;
        const int64_t KD = prb->kd, KH = prb->kh, KW = prb->kw;
        const int64_t SD = prb->sd, SH = prb->sh, SW = prb->sw;
        const int64_t PD = prb->pd, PH = prb->ph, PW = prb->pw;
        const int64_t PD_R = prb->pd_r, PH_R = prb->ph_r, PW_R = prb->pw_r;
        const bool pad_ok = PD >= PD_R && PH >= PH_R && PW >= PW_R;
        // Formula copied from str2desc; dilation is not supported on Nvidia.
        const auto compute_out
                = [](int64_t i, int64_t k, int64_t s, int64_t p) {
                      return (i - 1) * s + k - 2 * p;
                  };
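        // Example: i = 5, k = 3, s = 2, p = 1 gives
        // o = (5 - 1) * 2 + 3 - 2 * 1 = 9, the usual deconvolution output
        // size without dilation.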
        const bool out_ok = OD == compute_out(ID, KD, SD, PD)
                && OH == compute_out(IH, KH, SH, PH)
                && OW == compute_out(IW, KW, SW, PW);

        bool post_ops_ok = prb->attr.post_ops.is_def();

        const auto stag = normalize_tag(prb->stag, prb->ndims);
        const bool stag_is_axb = stag == normalize_tag(tag::axb, prb->ndims);
        const bool fwd_tag_ok = !((prb->dir & FLAG_FWD) && stag_is_axb);
        const bool bwd_tag_ok
                = !((prb->dir == BWD_W || prb->dir == BWD_WB) && stag_is_axb);
        const bool tag_ok = fwd_tag_ok && bwd_tag_ok;
        // TODO: specified wtag (even for supported formats) is not working?
        if (!pad_ok || !out_ok || !post_ops_ok || !tag_ok) {
            res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
            return;
        }

        // FIXME: there's a bug in the library resulting in
        // memory_tracking.hpp:458: Assertion `registry_.size() == 0' failed.
        // It occurs for any spatial case when both BWD_W and BWD_WB are run.
        // It is likely a cache interaction, but it is not clear which side is
        // at fault; the Nvidia implementation is the most likely culprit.
        // Switch it off until further investigation.
        if (prb->dir == BWD_WB) {
            res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
            return;
        }
    }

    // GPU:
    //     * BWD: doesn't support any attributes
    //     * FWD: supports only post-ops, and any config except x8s8bf16
    if (is_gpu()) {
        const bool only_non_default_post_ops = prb->attr.oscale.is_def()
                && prb->attr.scales.is_def() && prb->attr.zero_points.is_def();
        const bool is_x8s8bf16_cfg
                = prb->cfg[WEI].dt == dnnl_s8 && prb->cfg[DST].dt == dnnl_bf16;
        const bool fwd_ok = !is_x8s8bf16_cfg
                && IMPLICATION(
                        (prb->dir & FLAG_FWD), only_non_default_post_ops);
        const bool bwd_ok
                = IMPLICATION((prb->dir & FLAG_BWD), prb->attr.is_def());
        if (!fwd_ok || !bwd_ok) {
            res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
            return;
        }
    }
}

int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    check_known_skipped_case(prb, res);
    check_sum_post_ops(prb->attr, res);
    if (res->state == SKIPPED) return OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    SAFE(init_prim(prim, init_pd, prb, res), WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    const_dnnl_primitive_desc_t const_pd;
    DNN_SAFE(dnnl_primitive_get_primitive_desc(prim, &const_pd), CRIT);

    if (check_mem_size(const_pd) != OK) {
        return res->state = SKIPPED, res->reason = NOT_ENOUGH_RAM, OK;
    }

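    // Shortcut for querying memory descriptors of execution arguments from
    // the primitive descriptor, so the test uses exactly the layouts the
    // implementation selected.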
    const auto q = [&](int index = 0) -> const dnnl_memory_desc_t & {
        return *dnnl_primitive_desc_query_md(
                const_pd, dnnl_query_exec_arg_md, index);
    };

    const auto &src_md
            = prb->dir == BWD_D ? q(DNNL_ARG_DIFF_SRC) : q(DNNL_ARG_SRC);
    const auto &wei_md = prb->dir & FLAG_WEI ? q(DNNL_ARG_DIFF_WEIGHTS)
                                             : q(DNNL_ARG_WEIGHTS);
    const auto &bia_md
            = prb->dir & FLAG_WEI ? q(DNNL_ARG_DIFF_BIAS) : q(DNNL_ARG_BIAS);
    const auto &dst_md
            = prb->dir & FLAG_BWD ? q(DNNL_ARG_DIFF_DST) : q(DNNL_ARG_DST);
    const auto &scratchpad_md = q(DNNL_ARG_SCRATCHPAD);
    auto wei_tr_md = wei_md;

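    // Build the descriptor for the transposed weights: swap the two channel
    // dims (the entries right after the groups dim) so that wei_tr_fp below
    // matches the layout transpose_data_wei() produces for the convolution
    // reference.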
    const bool with_groups = true;
    swap(wei_tr_md.dims[with_groups + 0], wei_tr_md.dims[with_groups + 1]);

    const auto fp = dnnl_f32;
    const auto src_tag = tag::abx;
    const auto wei_tag = tag::abx;

    // Use CPU prim as the reference in GPU testing to reduce testing time.
    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim_ref;
    SAFE(init_prim_ref(prim_ref, prb), WARN);

    const auto &test_engine = get_test_engine();
    const auto &ref_engine = prim_ref ? get_cpu_engine() : get_test_engine();

    dnn_mem_t src_dt(src_md, test_engine);
    dnn_mem_t wei_dt(wei_md, test_engine);
    dnn_mem_t dst_dt(dst_md, test_engine);
    dnn_mem_t bia_dt(bia_md, test_engine);
    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);
    std::vector<dnn_mem_t> binary_po_fp, binary_po_dt;
    std::vector<int> binary_po_args;
    SAFE(binary::setup_binary_po(const_pd, binary_po_args, binary_po_dt,
                 binary_po_fp, ref_engine),
            WARN);

    dnn_mem_t src_fp(src_md, fp, src_tag, ref_engine);
    dnn_mem_t wei_fp(wei_md, fp, wei_tag, ref_engine);
    dnn_mem_t dst_fp(dst_md, fp, src_tag, ref_engine);
    dnn_mem_t wei_tr_fp(wei_tr_md, fp, wei_tag, ref_engine);
    dnn_mem_t bia_fp(bia_md, fp, tag::x, ref_engine);
    dnn_mem_t scratchpad_fp(scratchpad_md, ref_engine);
    dnn_mem_t src_zero_points_m;
    dnn_mem_t dst_zero_points_m;

    /* Fill memory (both library and f32 reference copies), reordering between
     * layouts as needed. */
    if (need_dst_init(prb)) SAFE(fill_dst(prb, dst_dt, dst_fp, res), WARN);
    if (need_src_init(prb)) SAFE(fill_src(prb, src_dt, src_fp, res), WARN);
    if (need_wei_init(prb)) {
        SAFE(fill_wei(prb, wei_dt, wei_fp, res), WARN);
        SAFE(transpose_data_wei(prb, wei_fp, wei_tr_fp), WARN);
    }
    if (need_bia_init(prb)) SAFE(fill_bia(prb, bia_dt, bia_fp, res), WARN);

    args_t args, ref_args;

    // Update the prb descriptor so the convolution reference can be re-used:
    // a deconvolution mapping (ic, id/ih/iw) to (oc, od/oh/ow) is checked
    // against a convolution on the swapped problem.
    prb_t p_tr((desc_t)*prb, prb->dir, prb->cfg, prb->stag, prb->wtag,
            prb->dtag, prb->alg, prb->attr, prb->mb, true);
    swap(p_tr.ic, p_tr.oc);
    swap(p_tr.ih, p_tr.oh);
    swap(p_tr.id, p_tr.od);
    swap(p_tr.iw, p_tr.ow);

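    // Runtime zero points apply only to the forward pass (GPU backward does
    // not support attributes, as checked above), so they are prepared just
    // before setting the forward execution arguments.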
    if (prb->dir & FLAG_FWD) {
        maybe_prepare_runtime_zero_points(src_zero_points_m, prb->attr,
                DNNL_ARG_SRC, prb->ic, prb->src_zp);
        maybe_prepare_runtime_zero_points(dst_zero_points_m, prb->attr,
                DNNL_ARG_DST, prb->oc, prb->dst_zp);

        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_WEIGHTS, wei_dt);
        args.set(DNNL_ARG_BIAS, bia_dt);
        args.set(DNNL_ARG_DST, dst_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
        args.set(binary_po_args, binary_po_dt);
        args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zero_points_m);
        args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zero_points_m);

        SAFE(execute_and_wait(prim, args), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_WEIGHTS, wei_fp);
            ref_args.set(DNNL_ARG_BIAS, bia_fp);
            ref_args.set(DNNL_ARG_DST, dst_fp);
            ref_args.set(DNNL_ARG_DIFF_WEIGHTS, wei_tr_fp); // Hack. See ref.
            ref_args.set(DNNL_ARG_SCRATCHPAD, scratchpad_fp);
            ref_args.set(binary_po_args, binary_po_fp);

            TIME_REF(deconv::compute_ref_fwd(&p_tr, prim_ref, ref_args));
            SAFE(compare_data(prb, DST, dst_dt, dst_fp, res), WARN);
        }
    } else if (prb->dir == BWD_D) {
        args.set(DNNL_ARG_DIFF_DST, dst_dt);
        args.set(DNNL_ARG_WEIGHTS, wei_dt);
        args.set(DNNL_ARG_DIFF_SRC, src_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_DIFF_SRC, src_fp);
            ref_args.set(DNNL_ARG_WEIGHTS, wei_fp);
            ref_args.set(DNNL_ARG_DIFF_DST, dst_fp);
            ref_args.set(DNNL_ARG_DIFF_WEIGHTS, wei_tr_fp); // Hack. See ref.
            ref_args.set(DNNL_ARG_SCRATCHPAD, scratchpad_fp);

            TIME_REF(deconv::compute_ref_bwd_d(&p_tr, prim_ref, ref_args));
            SAFE(compare_data(prb, SRC, src_dt, src_fp, res), WARN);
        }
    } else if (prb->dir & FLAG_BWD && prb->dir & FLAG_WEI) {
        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_DIFF_DST, dst_dt);
        args.set(DNNL_ARG_DIFF_WEIGHTS, wei_dt);
        args.set(DNNL_ARG_DIFF_BIAS, bia_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_WEIGHTS, wei_tr_fp); // Hack. See ref.
            ref_args.set(DNNL_ARG_DIFF_DST, dst_fp);
            ref_args.set(DNNL_ARG_DIFF_WEIGHTS, wei_fp);
            ref_args.set(DNNL_ARG_DIFF_BIAS, bia_fp);
            ref_args.set(DNNL_ARG_SCRATCHPAD, scratchpad_fp);

            TIME_REF(deconv::compute_ref_bwd_w(&p_tr, prim_ref, ref_args));
            SAFE(compare_data(&p_tr, WEI, wei_dt, wei_fp, res), WARN);
            if (prb->dir & FLAG_BIA)
                SAFE(compare_data(prb, BIA, bia_dt, bia_fp, res), WARN);
        }
    } else {
        SAFE(FAIL, CRIT);
    }

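    // Performance measurement happens inside measure_perf(), which collects
    // timings only when the performance bench mode is requested; otherwise it
    // returns OK without executing the primitive again.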
    return measure_perf(res, prim, args);
}

} // namespace deconv