1 /*******************************************************************************
2 * Copyright 2018-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include <stddef.h>
18 #include <stdio.h>
19 #include <stdlib.h>
20 
21 #include "oneapi/dnnl/dnnl.h"
22 
23 #include "tests/test_thread.hpp"
24 
25 #include "dnnl_common.hpp"
26 #include "dnnl_memory.hpp"
27 #include "utils/compare.hpp"
28 
29 #include "shuffle/shuffle.hpp"
30 
31 namespace shuffle {
32 
fill_src(const prb_t * prb,dnn_mem_t & mem_dt,dnn_mem_t & mem_fp)33 int fill_src(const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp) {
34     auto get_range = [](const dnnl_data_type_t dt) {
35         if (dt == dnnl_s8 || dt == dnnl_u8)
36             return 256;
37         else if (dt == dnnl_bf16 || dt == dnnl_f16)
38             return 128;
39         return 1024;
40     };
41 
42     const auto nelems = mem_fp.nelems();
43     const int range = get_range(prb->dt);
44     const int f_min = prb->dt == dnnl_u8 ? 0 : -range / 2;
45 
46     dnnl::impl::parallel_nd(nelems, [&](int64_t i) {
47         const float gen = ((97 * i) + 101) % range;
48         const float value = (prb->dt == dnnl_bf16 || prb->dt == dnnl_f16)
49                 ? (f_min + gen) / range
50                 : (f_min + gen) * (1.0f + 4.0f / range);
51         mem_fp.set_elem(i, round_to_nearest_representable(prb->dt, value));
52     });
53 
54     SAFE(mem_dt.reorder(mem_fp), WARN);
55 
56     return OK;
57 }
58 
init_pd(dnnl_engine_t engine,const prb_t * prb,dnnl_primitive_desc_t & spd,res_t * res,dir_t dir,const_dnnl_primitive_desc_t hint)59 static int init_pd(dnnl_engine_t engine, const prb_t *prb,
60         dnnl_primitive_desc_t &spd, res_t *res, dir_t dir,
61         const_dnnl_primitive_desc_t hint) {
62     dnnl_shuffle_desc_t sd;
63 
64     dnnl_memory_desc_t data_d;
65     SAFE(init_md(&data_d, prb->ndims, prb->dims.data(), prb->dt, prb->tag),
66             WARN);
67 
68     if (prb->dir & FLAG_FWD) {
69         auto prop_kind = prb->dir & FLAG_INF ? dnnl_forward_inference
70                                              : dnnl_forward_training;
71 
72         DNN_SAFE(dnnl_shuffle_forward_desc_init(
73                          &sd, prop_kind, &data_d, prb->axis, prb->group),
74                 WARN);
75     } else {
76         dnnl_memory_desc_t diff_data_d;
77         DNN_SAFE(dnnl_memory_desc_init_by_tag(&diff_data_d, prb->ndims,
78                          prb->dims.data(), prb->dt, dnnl_format_tag_any),
79                 WARN);
80 
81         DNN_SAFE(dnnl_shuffle_backward_desc_init(
82                          &sd, &diff_data_d, prb->axis, prb->group),
83                 WARN);
84     }
85 
86     dnnl_primitive_desc_t hint_fwd_pd_ {};
87     dnnl_status_t status = dnnl_success;
88     if (prb->dir & FLAG_BWD) {
89         dnnl_shuffle_desc_t sd_fwd;
90         DNN_SAFE(dnnl_shuffle_forward_desc_init(&sd_fwd, dnnl_forward_training,
91                          &data_d, prb->axis, prb->group),
92                 WARN);
93 
94         status = dnnl_primitive_desc_create(
95                 &hint_fwd_pd_, &sd_fwd, nullptr, engine, nullptr);
96         if (status == dnnl_unimplemented) return res->state = UNIMPLEMENTED, OK;
97     }
98     auto hint_fwd_pd = make_benchdnn_dnnl_wrapper(hint_fwd_pd_);
99     SAFE(status, WARN);
100 
101     auto dnnl_attr = make_benchdnn_dnnl_wrapper(
102             create_dnnl_attr(prb->attr, attr_args_t()));
103 
104     status = dnnl_primitive_desc_create(
105             &spd, &sd, dnnl_attr, engine, hint_fwd_pd);
106 
107     if (status == dnnl_unimplemented) return res->state = UNIMPLEMENTED, OK;
108     SAFE(status, WARN);
109 
110     res->impl_name = query_impl_info(spd);
111     BENCHDNN_PRINT(5, "oneDNN implementation: %s\n", res->impl_name.c_str());
112 
113     SAFE(check_pd_w_and_wo_attr(res, prb->attr, sd), WARN);
114 
115     return OK;
116 }
117 
check_known_skipped_case(const prb_t * prb,res_t * res)118 void check_known_skipped_case(const prb_t *prb, res_t *res) {
119     check_known_skipped_case_common({prb->dt}, prb->dir, res);
120     if (res->state == SKIPPED) return;
121 
122     if (is_nvidia_gpu()) {
123         res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
124         return;
125     }
126 }
127 
doit(const prb_t * prb,res_t * res)128 int doit(const prb_t *prb, res_t *res) {
129     if (bench_mode == LIST) return res->state = LISTED, OK;
130 
131     check_known_skipped_case(prb, res);
132     check_sum_post_ops(prb->attr, res);
133     if (res->state == SKIPPED) return OK;
134 
135     benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
136     SAFE(init_prim(prim, init_pd, prb, res), WARN);
137     if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;
138 
139     const_dnnl_primitive_desc_t const_pd;
140     DNN_SAFE(dnnl_primitive_get_primitive_desc(prim, &const_pd), CRIT);
141 
142     if (check_mem_size(const_pd) != OK) {
143         return res->state = SKIPPED, res->reason = NOT_ENOUGH_RAM, OK;
144     }
145 
146     const auto q = [&](int index = 0) -> const dnnl_memory_desc_t & {
147         return *dnnl_primitive_desc_query_md(
148                 const_pd, dnnl_query_exec_arg_md, index);
149     };
150 
151     const auto &data_md
152             = prb->dir & FLAG_FWD ? q(DNNL_ARG_SRC) : q(DNNL_ARG_DIFF_SRC);
153     const auto &scratchpad_md = q(DNNL_ARG_SCRATCHPAD);
154     const auto &test_engine = get_test_engine();
155 
156     dnn_mem_t src_fp(data_md, dnnl_f32, tag::abx, test_engine);
157     dnn_mem_t src_dt(data_md, test_engine);
158 
159     dnn_mem_t dst_fp(data_md, dnnl_f32, tag::abx, test_engine);
160     dnn_mem_t dst_dt(data_md, test_engine);
161 
162     dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);
163 
164     SAFE(fill_src(prb, src_dt, src_fp), WARN);
165 
166     const int i_arg = prb->dir == FWD_D ? DNNL_ARG_SRC : DNNL_ARG_DIFF_DST;
167     const int o_arg = prb->dir == FWD_D ? DNNL_ARG_DST : DNNL_ARG_DIFF_SRC;
168 
169     args_t args;
170 
171     args.set(i_arg, src_dt);
172     args.set(o_arg, dst_dt);
173     args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
174 
175     SAFE(execute_and_wait(prim, args), WARN);
176 
177     if (is_bench_mode(CORR)) {
178         TIME_REF(compute_ref(prb, src_fp, dst_fp));
179         compare::compare_t cmp;
180         SAFE(cmp.compare(dst_fp, dst_dt, prb->attr, res), WARN);
181     }
182 
183     return measure_perf(res, prim, args);
184 }
185 
186 } // namespace shuffle
187