1 /*******************************************************************************
2 * Copyright 2018-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <stddef.h>
18 #include <stdio.h>
19 #include <stdlib.h>
20
21 #include "oneapi/dnnl/dnnl.h"
22
23 #include "tests/test_thread.hpp"
24
25 #include "dnnl_common.hpp"
26 #include "dnnl_memory.hpp"
27 #include "utils/compare.hpp"
28
29 #include "shuffle/shuffle.hpp"
30
31 namespace shuffle {
32
fill_src(const prb_t * prb,dnn_mem_t & mem_dt,dnn_mem_t & mem_fp)33 int fill_src(const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp) {
34 auto get_range = [](const dnnl_data_type_t dt) {
35 if (dt == dnnl_s8 || dt == dnnl_u8)
36 return 256;
37 else if (dt == dnnl_bf16 || dt == dnnl_f16)
38 return 128;
39 return 1024;
40 };
41
42 const auto nelems = mem_fp.nelems();
43 const int range = get_range(prb->dt);
44 const int f_min = prb->dt == dnnl_u8 ? 0 : -range / 2;
45
46 dnnl::impl::parallel_nd(nelems, [&](int64_t i) {
47 const float gen = ((97 * i) + 101) % range;
48 const float value = (prb->dt == dnnl_bf16 || prb->dt == dnnl_f16)
49 ? (f_min + gen) / range
50 : (f_min + gen) * (1.0f + 4.0f / range);
51 mem_fp.set_elem(i, round_to_nearest_representable(prb->dt, value));
52 });
53
54 SAFE(mem_dt.reorder(mem_fp), WARN);
55
56 return OK;
57 }
58
init_pd(dnnl_engine_t engine,const prb_t * prb,dnnl_primitive_desc_t & spd,res_t * res,dir_t dir,const_dnnl_primitive_desc_t hint)59 static int init_pd(dnnl_engine_t engine, const prb_t *prb,
60 dnnl_primitive_desc_t &spd, res_t *res, dir_t dir,
61 const_dnnl_primitive_desc_t hint) {
62 dnnl_shuffle_desc_t sd;
63
64 dnnl_memory_desc_t data_d;
65 SAFE(init_md(&data_d, prb->ndims, prb->dims.data(), prb->dt, prb->tag),
66 WARN);
67
68 if (prb->dir & FLAG_FWD) {
69 auto prop_kind = prb->dir & FLAG_INF ? dnnl_forward_inference
70 : dnnl_forward_training;
71
72 DNN_SAFE(dnnl_shuffle_forward_desc_init(
73 &sd, prop_kind, &data_d, prb->axis, prb->group),
74 WARN);
75 } else {
76 dnnl_memory_desc_t diff_data_d;
77 DNN_SAFE(dnnl_memory_desc_init_by_tag(&diff_data_d, prb->ndims,
78 prb->dims.data(), prb->dt, dnnl_format_tag_any),
79 WARN);
80
81 DNN_SAFE(dnnl_shuffle_backward_desc_init(
82 &sd, &diff_data_d, prb->axis, prb->group),
83 WARN);
84 }
85
86 dnnl_primitive_desc_t hint_fwd_pd_ {};
87 dnnl_status_t status = dnnl_success;
88 if (prb->dir & FLAG_BWD) {
89 dnnl_shuffle_desc_t sd_fwd;
90 DNN_SAFE(dnnl_shuffle_forward_desc_init(&sd_fwd, dnnl_forward_training,
91 &data_d, prb->axis, prb->group),
92 WARN);
93
94 status = dnnl_primitive_desc_create(
95 &hint_fwd_pd_, &sd_fwd, nullptr, engine, nullptr);
96 if (status == dnnl_unimplemented) return res->state = UNIMPLEMENTED, OK;
97 }
98 auto hint_fwd_pd = make_benchdnn_dnnl_wrapper(hint_fwd_pd_);
99 SAFE(status, WARN);
100
101 auto dnnl_attr = make_benchdnn_dnnl_wrapper(
102 create_dnnl_attr(prb->attr, attr_args_t()));
103
104 status = dnnl_primitive_desc_create(
105 &spd, &sd, dnnl_attr, engine, hint_fwd_pd);
106
107 if (status == dnnl_unimplemented) return res->state = UNIMPLEMENTED, OK;
108 SAFE(status, WARN);
109
110 res->impl_name = query_impl_info(spd);
111 BENCHDNN_PRINT(5, "oneDNN implementation: %s\n", res->impl_name.c_str());
112
113 SAFE(check_pd_w_and_wo_attr(res, prb->attr, sd), WARN);
114
115 return OK;
116 }
117
check_known_skipped_case(const prb_t * prb,res_t * res)118 void check_known_skipped_case(const prb_t *prb, res_t *res) {
119 check_known_skipped_case_common({prb->dt}, prb->dir, res);
120 if (res->state == SKIPPED) return;
121
122 if (is_nvidia_gpu()) {
123 res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
124 return;
125 }
126 }
127
doit(const prb_t * prb,res_t * res)128 int doit(const prb_t *prb, res_t *res) {
129 if (bench_mode == LIST) return res->state = LISTED, OK;
130
131 check_known_skipped_case(prb, res);
132 check_sum_post_ops(prb->attr, res);
133 if (res->state == SKIPPED) return OK;
134
135 benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
136 SAFE(init_prim(prim, init_pd, prb, res), WARN);
137 if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;
138
139 const_dnnl_primitive_desc_t const_pd;
140 DNN_SAFE(dnnl_primitive_get_primitive_desc(prim, &const_pd), CRIT);
141
142 if (check_mem_size(const_pd) != OK) {
143 return res->state = SKIPPED, res->reason = NOT_ENOUGH_RAM, OK;
144 }
145
146 const auto q = [&](int index = 0) -> const dnnl_memory_desc_t & {
147 return *dnnl_primitive_desc_query_md(
148 const_pd, dnnl_query_exec_arg_md, index);
149 };
150
151 const auto &data_md
152 = prb->dir & FLAG_FWD ? q(DNNL_ARG_SRC) : q(DNNL_ARG_DIFF_SRC);
153 const auto &scratchpad_md = q(DNNL_ARG_SCRATCHPAD);
154 const auto &test_engine = get_test_engine();
155
156 dnn_mem_t src_fp(data_md, dnnl_f32, tag::abx, test_engine);
157 dnn_mem_t src_dt(data_md, test_engine);
158
159 dnn_mem_t dst_fp(data_md, dnnl_f32, tag::abx, test_engine);
160 dnn_mem_t dst_dt(data_md, test_engine);
161
162 dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);
163
164 SAFE(fill_src(prb, src_dt, src_fp), WARN);
165
166 const int i_arg = prb->dir == FWD_D ? DNNL_ARG_SRC : DNNL_ARG_DIFF_DST;
167 const int o_arg = prb->dir == FWD_D ? DNNL_ARG_DST : DNNL_ARG_DIFF_SRC;
168
169 args_t args;
170
171 args.set(i_arg, src_dt);
172 args.set(o_arg, dst_dt);
173 args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
174
175 SAFE(execute_and_wait(prim, args), WARN);
176
177 if (is_bench_mode(CORR)) {
178 TIME_REF(compute_ref(prb, src_fp, dst_fp));
179 compare::compare_t cmp;
180 SAFE(cmp.compare(dst_fp, dst_dt, prb->attr, res), WARN);
181 }
182
183 return measure_perf(res, prim, args);
184 }
185
186 } // namespace shuffle
187