/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16 
17 #include <float.h>
18 #include <math.h>
19 #include <stdio.h>
20 #include <stdlib.h>
21 
22 #include <random>
23 
24 #include "oneapi/dnnl/dnnl.h"
25 
26 #include "tests/test_thread.hpp"
27 
28 #include "compare.hpp"
29 #include "dnnl_common.hpp"
30 #include "dnnl_memory.hpp"
31 
32 #include "sum/sum.hpp"
33 
34 namespace sum {
35 
init_pd(dnnl_engine_t engine,const prb_t * prb,dnnl_primitive_desc_t & spd,res_t * res,dir_t dir,const_dnnl_primitive_desc_t hint)36 static int init_pd(dnnl_engine_t engine, const prb_t *prb,
37         dnnl_primitive_desc_t &spd, res_t *res, dir_t dir,
38         const_dnnl_primitive_desc_t hint) {
39     std::vector<dnnl_memory_desc_t> src_d;
40     src_d.resize(prb->n_inputs());
41 
42     dnnl_memory_desc_t dst_d;
43 
44     for (int i_input = 0; i_input < prb->n_inputs(); ++i_input)
45         SAFE(init_md(&src_d[i_input], prb->ndims, prb->dims.data(),
46                      prb->sdt[i_input], prb->stag[i_input]),
47                 CRIT);
48 
49     if (prb->dtag != tag::undef) {
50         SAFE(init_md(&dst_d, prb->ndims, prb->dims.data(), prb->ddt, prb->dtag),
51                 CRIT);
52     }
53 
54     auto dnnl_attr = make_benchdnn_dnnl_wrapper(
55             create_dnnl_attr(prb->attr, attr_args_t()));
56 
57     dnnl_status_t init_status = dnnl_sum_primitive_desc_create(&spd,
58             prb->dtag != tag::undef ? &dst_d : nullptr, prb->n_inputs(),
59             prb->scales.data(), src_d.data(), dnnl_attr, engine);
60 
61     if (init_status == dnnl_unimplemented)
62         return res->state = UNIMPLEMENTED, OK;
63     else
64         SAFE(init_status, WARN);
65 
66     res->impl_name = query_impl_info(spd);
67     BENCHDNN_PRINT(5, "oneDNN implementation: %s\n", res->impl_name.c_str());
68 
69     return OK;
70 }
71 
fill_src(const prb_t * prb,int input_idx,dnn_mem_t & mem_dt,dnn_mem_t & mem_fp)72 int fill_src(
73         const prb_t *prb, int input_idx, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp) {
74 
75     const auto nelems = mem_fp.nelems();
76     const auto dt = prb->sdt[input_idx];
77     const int range = 16;
78     const int f_min = dt == dnnl_u8 ? 0 : -range / 2;
79 
80     dnnl::impl::parallel_nd(nelems, [&](int64_t i) {
81         const float gen = ((97 * i) - 17 * input_idx + 101) % range;
82         const float value = (dt == dnnl_bf16 || dt == dnnl_f16)
83                 ? (f_min + gen) / range
84                 : (f_min + gen) * (1.0f + 4.0f / range);
85         mem_fp.set_elem(i, round_to_nearest_representable(dt, value));
86     });
87 
88     SAFE(mem_dt.reorder(mem_fp), WARN);
89 
90     return OK;
91 }
92 
check_known_skipped_case(const prb_t * prb,res_t * res)93 void check_known_skipped_case(const prb_t *prb, res_t *res) {
94     std::vector<dnnl_data_type_t> dts = prb->sdt;
95     dts.push_back(prb->ddt);
96     check_known_skipped_case_common(dts, FWD_D, res);
97 }
98 
doit(const prb_t * prb,res_t * res)99 int doit(const prb_t *prb, res_t *res) {
100     if (bench_mode == LIST) return res->state = LISTED, OK;
101 
102     check_known_skipped_case(prb, res);
103     if (res->state == SKIPPED) return OK;
104 
105     benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
106     SAFE(init_prim(prim, init_pd, prb, res), WARN);
107     if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;
108 
109     const_dnnl_primitive_desc_t const_pd;
110     DNN_SAFE(dnnl_primitive_get_primitive_desc(prim, &const_pd), CRIT);
111 
112     if (check_mem_size(const_pd) != OK) {
113         return res->state = SKIPPED, res->reason = NOT_ENOUGH_RAM, OK;
114     }
115 
116     const auto q = [&](int index = 0) -> const dnnl_memory_desc_t & {
117         return *dnnl_primitive_desc_query_md(
118                 const_pd, dnnl_query_exec_arg_md, index);
119     };
120 
121     const auto &test_engine = get_test_engine();
122     const auto &dst_md = q(DNNL_ARG_DST);
123     const auto &scratchpad_md = q(DNNL_ARG_SCRATCHPAD);
124 
125     dnn_mem_t dst_fp(dst_md, dnnl_f32, tag::abx, test_engine);
126     dnn_mem_t dst_dt(dst_md, test_engine);
127     dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);
128 
129     args_t args;
130     args.set(DNNL_ARG_DST, dst_dt);
131     args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
132 
133     std::vector<dnn_mem_t> src_fp, src_dt;
134     src_fp.reserve(prb->n_inputs());
135     src_dt.reserve(prb->n_inputs());
136 
137     for (int i_input = 0; i_input < prb->n_inputs(); ++i_input) {
138         const auto &src_md = q(DNNL_ARG_MULTIPLE_SRC + i_input);
139         src_fp.emplace_back(src_md, dnnl_f32, tag::abx, test_engine);
140         src_dt.emplace_back(src_md, test_engine);
141         SAFE(fill_src(prb, i_input, src_dt[i_input], src_fp[i_input]), WARN);
142         args.set(DNNL_ARG_MULTIPLE_SRC + i_input, src_dt[i_input]);
143     }
144 
145     SAFE(execute_and_wait(prim, args), WARN);
146 
147     if (is_bench_mode(CORR)) {
148         compute_ref(prb, src_fp, dst_fp);
149         compare::compare_t cmp;
150         cmp.set_threshold(epsilon_dt(dst_md.data_type) * prb->n_inputs());
151         SAFE(cmp.compare(dst_fp, dst_dt, prb->attr, res), WARN);
152     }
153 
154     return measure_perf(res->timer, prim, args);
155 }
156 
157 } // namespace sum
158