1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <float.h>
18 #include <math.h>
19 #include <stdio.h>
20 #include <stdlib.h>
21
22 #include <random>
23
24 #include "oneapi/dnnl/dnnl.h"
25
26 #include "tests/test_thread.hpp"
27
28 #include "compare.hpp"
29 #include "dnnl_common.hpp"
30 #include "dnnl_memory.hpp"
31
32 #include "sum/sum.hpp"
33
34 namespace sum {
35
init_pd(dnnl_engine_t engine,const prb_t * prb,dnnl_primitive_desc_t & spd,res_t * res,dir_t dir,const_dnnl_primitive_desc_t hint)36 static int init_pd(dnnl_engine_t engine, const prb_t *prb,
37 dnnl_primitive_desc_t &spd, res_t *res, dir_t dir,
38 const_dnnl_primitive_desc_t hint) {
39 std::vector<dnnl_memory_desc_t> src_d;
40 src_d.resize(prb->n_inputs());
41
42 dnnl_memory_desc_t dst_d;
43
44 for (int i_input = 0; i_input < prb->n_inputs(); ++i_input)
45 SAFE(init_md(&src_d[i_input], prb->ndims, prb->dims.data(),
46 prb->sdt[i_input], prb->stag[i_input]),
47 CRIT);
48
49 if (prb->dtag != tag::undef) {
50 SAFE(init_md(&dst_d, prb->ndims, prb->dims.data(), prb->ddt, prb->dtag),
51 CRIT);
52 }
53
54 auto dnnl_attr = make_benchdnn_dnnl_wrapper(
55 create_dnnl_attr(prb->attr, attr_args_t()));
56
57 dnnl_status_t init_status = dnnl_sum_primitive_desc_create(&spd,
58 prb->dtag != tag::undef ? &dst_d : nullptr, prb->n_inputs(),
59 prb->scales.data(), src_d.data(), dnnl_attr, engine);
60
61 if (init_status == dnnl_unimplemented)
62 return res->state = UNIMPLEMENTED, OK;
63 else
64 SAFE(init_status, WARN);
65
66 res->impl_name = query_impl_info(spd);
67 BENCHDNN_PRINT(5, "oneDNN implementation: %s\n", res->impl_name.c_str());
68
69 return OK;
70 }
71
fill_src(const prb_t * prb,int input_idx,dnn_mem_t & mem_dt,dnn_mem_t & mem_fp)72 int fill_src(
73 const prb_t *prb, int input_idx, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp) {
74
75 const auto nelems = mem_fp.nelems();
76 const auto dt = prb->sdt[input_idx];
77 const int range = 16;
78 const int f_min = dt == dnnl_u8 ? 0 : -range / 2;
79
80 dnnl::impl::parallel_nd(nelems, [&](int64_t i) {
81 const float gen = ((97 * i) - 17 * input_idx + 101) % range;
82 const float value = (dt == dnnl_bf16 || dt == dnnl_f16)
83 ? (f_min + gen) / range
84 : (f_min + gen) * (1.0f + 4.0f / range);
85 mem_fp.set_elem(i, round_to_nearest_representable(dt, value));
86 });
87
88 SAFE(mem_dt.reorder(mem_fp), WARN);
89
90 return OK;
91 }
92
check_known_skipped_case(const prb_t * prb,res_t * res)93 void check_known_skipped_case(const prb_t *prb, res_t *res) {
94 std::vector<dnnl_data_type_t> dts = prb->sdt;
95 dts.push_back(prb->ddt);
96 check_known_skipped_case_common(dts, FWD_D, res);
97 }
98
doit(const prb_t * prb,res_t * res)99 int doit(const prb_t *prb, res_t *res) {
100 if (bench_mode == LIST) return res->state = LISTED, OK;
101
102 check_known_skipped_case(prb, res);
103 if (res->state == SKIPPED) return OK;
104
105 benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
106 SAFE(init_prim(prim, init_pd, prb, res), WARN);
107 if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;
108
109 const_dnnl_primitive_desc_t const_pd;
110 DNN_SAFE(dnnl_primitive_get_primitive_desc(prim, &const_pd), CRIT);
111
112 if (check_mem_size(const_pd) != OK) {
113 return res->state = SKIPPED, res->reason = NOT_ENOUGH_RAM, OK;
114 }
115
116 const auto q = [&](int index = 0) -> const dnnl_memory_desc_t & {
117 return *dnnl_primitive_desc_query_md(
118 const_pd, dnnl_query_exec_arg_md, index);
119 };
120
121 const auto &test_engine = get_test_engine();
122 const auto &dst_md = q(DNNL_ARG_DST);
123 const auto &scratchpad_md = q(DNNL_ARG_SCRATCHPAD);
124
125 dnn_mem_t dst_fp(dst_md, dnnl_f32, tag::abx, test_engine);
126 dnn_mem_t dst_dt(dst_md, test_engine);
127 dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);
128
129 args_t args;
130 args.set(DNNL_ARG_DST, dst_dt);
131 args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
132
133 std::vector<dnn_mem_t> src_fp, src_dt;
134 src_fp.reserve(prb->n_inputs());
135 src_dt.reserve(prb->n_inputs());
136
137 for (int i_input = 0; i_input < prb->n_inputs(); ++i_input) {
138 const auto &src_md = q(DNNL_ARG_MULTIPLE_SRC + i_input);
139 src_fp.emplace_back(src_md, dnnl_f32, tag::abx, test_engine);
140 src_dt.emplace_back(src_md, test_engine);
141 SAFE(fill_src(prb, i_input, src_dt[i_input], src_fp[i_input]), WARN);
142 args.set(DNNL_ARG_MULTIPLE_SRC + i_input, src_dt[i_input]);
143 }
144
145 SAFE(execute_and_wait(prim, args), WARN);
146
147 if (is_bench_mode(CORR)) {
148 compute_ref(prb, src_fp, dst_fp);
149 compare::compare_t cmp;
150 cmp.set_threshold(epsilon_dt(dst_md.data_type) * prb->n_inputs());
151 SAFE(cmp.compare(dst_fp, dst_dt, prb->attr, res), WARN);
152 }
153
154 return measure_perf(res->timer, prim, args);
155 }
156
157 } // namespace sum
158