1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <math.h>
18 #include <random>
19 #include <stdio.h>
20 #include <stdlib.h>
21
22 #include "oneapi/dnnl/dnnl.h"
23
24 #include "tests/test_thread.hpp"
25
26 #include "dnnl_common.hpp"
27 #include "dnnl_memory.hpp"
28 #include "utils/compare.hpp"
29
30 #include "binary/binary.hpp"
31 #include "eltwise/eltwise.hpp"
32
33 namespace eltwise {
34
init_pd(dnnl_engine_t engine,const prb_t * prb,dnnl_primitive_desc_t & epd,res_t * res,dir_t dir,const_dnnl_primitive_desc_t hint)35 static int init_pd(dnnl_engine_t engine, const prb_t *prb,
36 dnnl_primitive_desc_t &epd, res_t *res, dir_t dir,
37 const_dnnl_primitive_desc_t hint) {
38 dnnl_eltwise_desc_t ed;
39 dnnl_memory_desc_t data_d;
40
41 SAFE(init_md(&data_d, prb->ndims, prb->dims.data(), prb->dt, prb->tag),
42 CRIT);
43
44 dnnl_alg_kind_t alg = attr_t::post_ops_t::kind2dnnl_kind(prb->alg);
45
46 if (prb->dir & FLAG_FWD) {
47 auto prop = prb->dir & FLAG_INF ? dnnl_forward_inference
48 : dnnl_forward_training;
49
50 DNN_SAFE(dnnl_eltwise_forward_desc_init(
51 &ed, prop, alg, &data_d, prb->alpha, prb->beta),
52 WARN);
53 } else {
54 dnnl_memory_desc_t diff_data_d;
55 DNN_SAFE(dnnl_memory_desc_init_by_tag(&diff_data_d, prb->ndims,
56 prb->dims.data(), prb->dt, dnnl_format_tag_any),
57 WARN);
58 DNN_SAFE(dnnl_eltwise_backward_desc_init(&ed, alg, &diff_data_d,
59 &data_d, prb->alpha, prb->beta),
60 WARN);
61 }
62
63 attr_args_t attr_args;
64 attr_args.prepare_post_ops_mds(prb->attr, prb->ndims, prb->dims.data());
65 auto dnnl_attr = make_benchdnn_dnnl_wrapper(
66 create_dnnl_attr(prb->attr, attr_args));
67
68 dnnl_status_t init_status
69 = dnnl_primitive_desc_create(&epd, &ed, dnnl_attr, engine, nullptr);
70
71 if (init_status == dnnl_unimplemented)
72 return res->state = UNIMPLEMENTED, OK;
73 else
74 SAFE(init_status, WARN);
75
76 res->impl_name = query_impl_info(epd);
77 if (maybe_skip(res->impl_name)) {
78 BENCHDNN_PRINT(2, "SKIPPED: oneDNN implementation: %s\n",
79 res->impl_name.c_str());
80 return res->state = SKIPPED, res->reason = SKIP_IMPL_HIT, OK;
81 } else {
82 BENCHDNN_PRINT(
83 5, "oneDNN implementation: %s\n", res->impl_name.c_str());
84 }
85
86 SAFE(check_pd_w_and_wo_attr(res, prb->attr, ed), WARN);
87
88 return OK;
89 }
90
check_abs_err(const prb_t * prb,const float & s,const float & trh)91 static bool check_abs_err(const prb_t *prb, const float &s, const float &trh) {
92 const float approx_machine_eps = 2 * epsilon_dt(dnnl_f32);
93 const float comp_err = approx_machine_eps / trh;
94
95 switch (prb->alg) {
96 case alg_t::ELU:
97 case alg_t::ELU_DST:
98 // catch catastrophic cancellation when (exp(s) - 1), s < 0 and
99 // s is close to zero.
100 return (prb->dir & FLAG_FWD) && std::signbit(s)
101 && (fabsf(expf(s) - 1.f) <= comp_err);
102 case alg_t::GELU_TANH: {
103 // catch catastrophic cancellation
104 // (4.f is magic scale for f32)
105 const float sqrt_2_over_pi = 0.797884;
106 const float fitting_const = 0.044715;
107 float v = tanhf(sqrt_2_over_pi * s * (1 + fitting_const * s * s));
108 float dg = sqrt_2_over_pi * (1 + 3 * fitting_const * s * s);
109 if (fabsf(1.f + v) <= comp_err) return true;
110 return (prb->dir & FLAG_BWD) && std::signbit(s)
111 && fabsf(1.f + s * (1.f - v) * dg) <= 4.f * comp_err;
112 }
113 case alg_t::GELU_ERF: {
114 // Catch catastrophic cancellation
115 // which occurs at large negative s.
116 // Factor 2 (in bwd) is to account for the fact that error is
117 // accumulated for each summand (except the 1) when they
118 // are of the same order of magnitude.
119 const float sqrt_2_over_2 = 0.707106769084930419921875f;
120 const float two_over_sqrt_pi = 1.12837922573089599609375f;
121 float v = s * sqrt_2_over_2;
122 if (prb->dir & FLAG_FWD)
123 return fabsf(1.f + erff(v)) <= comp_err;
124 else
125 return fabsf(1.f + erff(v)
126 + v * two_over_sqrt_pi * expf(-v * v))
127 <= comp_err * 2;
128 }
129 case alg_t::TANH:
130 // catch catastrophic cancellation, which occurs when err in tanh(s)
131 // is high and tanh(s) is close to 1.
132 return (prb->dir & FLAG_BWD) && (1.f - tanhf(fabsf(s))) <= comp_err;
133 case alg_t::TANH_DST: // sse41 can't do fma
134 // catch catastrophic cancellation, which occurs when err in tanh(s)
135 // is high and tanh(s) is close to 1.
136 return (prb->dir & FLAG_BWD) && (1.f - s * s) <= comp_err;
137 case alg_t::SRELU:
138 // when s is negative, expf(s) -> 0 rapidly
139 // which leads to log1pf(expf(s)) -> 0
140 // which leads to high relative error,
141 // while abs error is still low.
142 // (10.f is magic scale for bf16)
143 return (prb->dir & FLAG_FWD) && std::signbit(s)
144 && log1pf(expf(s)) <= 10.f * comp_err;
145 case alg_t::LOGSIGMOID:
146 // same situation like in SRELU
147 // in logsigmoid when s is positive
148 // results -> 0
149 return (prb->dir & FLAG_FWD) && !std::signbit(s)
150 && log1pf(expf(-s)) <= 10.f * comp_err;
151 case alg_t::MISH:
152 // same situation like in SRELU
153 return (prb->dir & FLAG_FWD) && std::signbit(s)
154 && s * tanh(log1pf(expf(s))) <= 10.f * comp_err;
155 case alg_t::LOGISTIC:
156 // when s >= 4, logistic(s) -> 0 rapidly, which leads to high
157 // relative error of logistic(s) * (1 - logistic(s)) due to
158 // catastrohic cancellation.
159 return (prb->dir & FLAG_BWD) && !std::signbit(s)
160 && (1.f / (1.f + expf(s))) <= comp_err;
161 case alg_t::SWISH: {
162 // catch cancellation happening when W(s) ~~ -1 in (1 + W(s))
163 // formula part on backward.
164 const float alpha_s = prb->alpha * s;
165 return (prb->dir & FLAG_BWD)
166 && (alpha_s * (1.f - 1.f / (1.f + expf(-alpha_s)))
167 <= comp_err);
168 }
169 default: return false;
170 }
171 }
172
get_eltwise_threshold(dnnl_data_type_t dt,alg_t alg,bool is_fwd)173 float get_eltwise_threshold(dnnl_data_type_t dt, alg_t alg, bool is_fwd) {
174 // Tolerate only rounding error (1 ulp) for other than fp32 precisions.
175 float trh = dt == dnnl_f32 ? 4e-6 : epsilon_dt(dt);
176 // Tolerate bigger compute errors for complex algorithms.
177 const bool alg_has_higher_tolerance = alg == alg_t::GELU_TANH
178 || alg == alg_t::ELU || alg == alg_t::SWISH || alg == alg_t::TANH
179 || alg == alg_t::SRELU || alg == alg_t::LOGSIGMOID
180 || alg == alg_t::MISH || alg == alg_t::LOG
181 || ((alg == alg_t::ELU_DST || alg == alg_t::TANH_DST) && is_fwd);
182 if (dt == dnnl_f32 && alg_has_higher_tolerance) trh = 4e-5;
183 return trh;
184 }
185
get_eltwise_zero_trust_percent(const prb_t * prb)186 static float get_eltwise_zero_trust_percent(const prb_t *prb) {
187 float ztp = 60.f; // default for eltwise due to filling.
188 switch (prb->alg) {
189 case alg_t::LINEAR:
190 if (prb->alpha == 0) ztp = 100.f;
191 break;
192 case alg_t::BRELU:
193 if ((prb->alpha == 0) || (prb->dir & FLAG_BWD)) ztp = 100.f;
194 break;
195 case alg_t::CLIP:
196 case alg_t::CLIP_V2:
197 case alg_t::CLIP_V2_DST:
198 if ((prb->alpha == 0 && prb->beta == 0) || (prb->dir & FLAG_BWD))
199 ztp = 100.f;
200 break;
201 case alg_t::POW:
202 if (prb->alpha == 0 || ((prb->dir & FLAG_BWD) && prb->beta == 0))
203 ztp = 100.f;
204 break;
205 default: break;
206 }
207 // Integral data types with small float values will produce most zeros.
208 // u8 with negative alpha will produce only zeros.
209 if (is_integral_dt(prb->dt)) ztp = 100.f;
210 return ztp;
211 }
212
// Fills `mem_fp` (f32 reference memory) with a deterministic mix of test
// values and reorders them into `mem_dt` (library-side memory).
// Values cycle through 13 generation variants by element index, covering
// small/large positive and negative integers and fractions, values near
// overflow-prone points for exp/mish, and the exact alpha/beta values.
// Each value is rounded to what prb->dt can represent, so both memories hold
// the same inputs. Returns OK, or the reorder's error status.
int fill_data(const prb_t *prb, data_kind_t kind, dnn_mem_t &mem_dt,
        dnn_mem_t &mem_fp) {
    const auto nelems = mem_fp.nelems();
    if (nelems == 0) return OK;

    // Do fixed partitioning to have same filling for any number of threads.
    const int64_t n_chunks = 16;
    const int64_t chunk_size = div_up(nelems, n_chunks);

    dnnl::impl::parallel_nd(n_chunks, [&](int64_t idx_chunk) {
        int64_t idx_start = idx_chunk * chunk_size;
        // For small nelems trailing chunks get idx_end <= idx_start, so their
        // loop below does not run.
        int64_t idx_end = MIN2(idx_start + chunk_size, nelems);
        // Note 1: we use a different seed for each chunk to avoid
        // repeating patterns. We could use discard(idx_start) too but
        // we avoid it for two reasons:
        //   a. it has a complexity in O(idx_start).
        //   b. igen and fgen below might require more than 1 sample
        //   per idx, so we cannot deterministically compute the
        //   number of states we need to discard
        // Note 2: We also advance the state to avoid having only
        // small values as first chunk input. The +1 is necessary to
        // avoid generating zeros in first chunk.
        // Note 3: we multiply by kind + 1 to have different values in
        // src/dst and diff_dst. The +1 is to avoid 0 again.
        std::minstd_rand msr((idx_start + 1) * (kind + 1));
        msr.discard(1);
        std::uniform_int_distribution<> igen(0, 10);
        // TODO: 0.09 due to log impl doesn't give good accuracy in 0.99 points
        std::uniform_real_distribution<> fgen(0.f, 0.09f);

        for (int64_t idx = idx_start; idx < idx_end; ++idx) {
            static constexpr int64_t num_of_generation_variants = 13;
            float value = FLT_MAX;
            switch (idx % num_of_generation_variants) {
                case 0: value = (float)igen(msr); break; // [0, 10] pos
                case 1: value = -(float)igen(msr); break; // [-10, 0] neg
                case 2: value = fgen(msr); break; // [0, 0.09) pos
                case 3: value = -fgen(msr); break; // (-0.09, 0] neg
                case 4: value = 10 * (float)igen(msr); break; // {0,10,..,100}
                case 5: value = -10 * (float)igen(msr); break; // {-100,..,0}
                case 6: value = 10.f * fgen(msr); break; // [0, 0.9) pos
                case 7: value = -10.f * fgen(msr); break; // (-0.9, 0] neg
                case 8:
                    value = 88.f + 10.f * fgen(msr);
                    break; // [88, 88.9): near logf(FLT_MAX) for exp alg testing
                case 9:
                    value = 22.f + 10.f * fgen(msr);
                    break; // [22, 22.9): near logf(FLT_MAX)/4 for bwd mish
                case 10:
                    value = 44.f + 10.f * fgen(msr);
                    break; // [44, 44.9): near logf(FLT_MAX)/2 for fwd mish
                case 11: value = prb->alpha; break; // `x = alpha` corner cases
                case 12: value = prb->beta; break; // `x = beta` corner cases
            }
            value = round_to_nearest_representable(prb->dt, value);

            // Hack: -0 may lead to different sign in the answer since input
            // passes through simple reorder which converts -0 into +0.
            // (The comparison is also true for +0, which is harmless.)
            if (value == -0.f) value = 0.f;

            mem_fp.set_elem(idx, value);
        }
    });

    SAFE(mem_dt.reorder(mem_fp), WARN);

    return OK;
}
281
check_known_skipped_case(const prb_t * prb,res_t * res)282 void check_known_skipped_case(const prb_t *prb, res_t *res) {
283 check_known_skipped_case_common({prb->dt}, prb->dir, res);
284 if (res->state == SKIPPED) return;
285
286 bool is_invalid = false;
287 switch (prb->alg) {
288 case alg_t::CLIP:
289 case alg_t::CLIP_V2:
290 case alg_t::CLIP_V2_DST: is_invalid = prb->beta < prb->alpha; break;
291 case alg_t::BRELU:
292 case alg_t::ELU_DST:
293 case alg_t::RELU_DST: is_invalid = prb->alpha < 0; break;
294 case alg_t::ROUND:
295 is_invalid = prb->dt != dnnl_f32 || prb->dir & FLAG_BWD;
296 break;
297 default: break;
298 };
299 if (is_invalid) {
300 res->state = SKIPPED, res->reason = INVALID_CASE;
301 return;
302 }
303
304 if (is_nvidia_gpu()) {
305 if (!is_nvidia_eltwise_ok(prb->dir, prb->alg, prb->alpha)
306 || !prb->attr.post_ops.is_def()) {
307 res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
308 return;
309 }
310 }
311 }
312
// Runs one eltwise problem end to end: skip checks, primitive creation,
// memory setup and filling, execution, correctness comparison against the
// f32 reference (in CORR mode), and performance measurement.
// Returns OK for skipped/unimplemented cases (state recorded in `res`), the
// first failing status from the SAFE/DNN_SAFE checks, or measure_perf()'s
// result.
int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    check_known_skipped_case(prb, res);
    check_sum_post_ops(prb->attr, res);
    if (res->state == SKIPPED) return OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    SAFE(init_prim(prim, init_pd, prb, res), WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    const_dnnl_primitive_desc_t const_pd;
    DNN_SAFE(dnnl_primitive_get_primitive_desc(prim, &const_pd), CRIT);

    if (check_mem_size(const_pd) != OK) {
        return res->state = SKIPPED, res->reason = NOT_ENOUGH_RAM, OK;
    }

    // Query helper: memory descriptor of the given execution argument.
    const auto q = [&](int index = 0) -> const dnnl_memory_desc_t & {
        return *dnnl_primitive_desc_query_md(
                const_pd, dnnl_query_exec_arg_md, index);
    };

    const bool is_fwd = prb->dir & FLAG_FWD;
    const auto &src_md = q(DNNL_ARG_SRC);
    const auto &dst_md = q(DNNL_ARG_DST);
    // Backward with use_dst() takes dst as its data input, so the "src"
    // memories below are laid out per dst_md in that case.
    const auto &data_md = !is_fwd && prb->use_dst() ? dst_md : src_md;
    const auto &scratchpad_md = q(DNNL_ARG_SCRATCHPAD);
    const auto &test_engine = get_test_engine();

    dnn_mem_t src_fp(data_md, dnnl_f32, tag::abx, test_engine);
    dnn_mem_t src_dt(data_md, test_engine);

    // we need src_fp for proper comparison, => no in-place reference
    dnn_mem_t dst_fp(data_md, dnnl_f32, tag::abx, test_engine);
    dnn_mem_t placeholder_dst_dt;
    if (!prb->inplace) { placeholder_dst_dt = dnn_mem_t(data_md, test_engine); }
    // In-place: the library writes dst into src_dt.
    dnn_mem_t &dst_dt = prb->inplace ? src_dt : placeholder_dst_dt;

    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);
    std::vector<dnn_mem_t> binary_po_fp, binary_po_dt;
    std::vector<int> binary_po_args;
    SAFE(binary::setup_binary_po(
                 const_pd, binary_po_args, binary_po_dt, binary_po_fp),
            WARN);

    // Declared at function scope (not inside the backward branch) because
    // they are registered in `args`, which measure_perf() uses at the end.
    dnn_mem_t d_dst_dt, placeholder_d_src_dt;

    SAFE(fill_data(prb, SRC, src_dt, src_fp), WARN);

    args_t args;

    // Reference-side input used by check_abs_err() and compute_ref_bwd().
    dnn_mem_t &arg_fp = !is_fwd && prb->use_dst() ? dst_fp : src_fp;

    // Must live outside the `if` below: eltwise_add_check captures it by
    // reference and is invoked later, from cmp.compare().
    const float trh = get_eltwise_threshold(prb->dt, prb->alg, is_fwd);
    compare::compare_t cmp;
    if (is_bench_mode(CORR)) {
        cmp.set_threshold(trh);
        cmp.set_zero_trust_percent(get_eltwise_zero_trust_percent(prb));

        // Extra acceptance rule consulted by the comparator; `args` here is
        // the lambda parameter and shadows the outer args_t.
        const auto eltwise_add_check =
                [&](const compare::compare_t::driver_check_func_args_t &args) {
                    // Some algorithms require absolute value comparison for inputs
                    // where catastrophic cancellation may happen.
                    const float src = arg_fp.get_elem(args.idx);
                    if (check_abs_err(prb, src, trh)) return args.diff <= trh;
                    // With binary post-ops, fall back to the absolute check.
                    if (prb->attr.post_ops.binary_index() != -1)
                        return args.diff <= trh;
                    return false;
                };
        cmp.set_driver_check_function(eltwise_add_check);
    }

    if (prb->dir & FLAG_FWD) {
        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_DST, dst_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
        args.set(binary_po_args, binary_po_dt);

        SAFE(execute_and_wait(prim, args), WARN);

        if (is_bench_mode(CORR)) {
            TIME_REF(compute_ref_fwd(prb, src_fp, binary_po_fp, dst_fp));
            SAFE(cmp.compare(dst_fp, dst_dt, prb->attr, res), WARN);
        }
    } else {
        const auto &d_data_md = q(DNNL_ARG_DIFF_DST);

        dnn_mem_t d_dst_fp
                = dnn_mem_t(d_data_md, dnnl_f32, tag::abx, test_engine);
        d_dst_dt = dnn_mem_t(d_data_md, test_engine);

        dnn_mem_t &d_src_fp = d_dst_fp; // in-place reference
        if (!prb->inplace) {
            placeholder_d_src_dt = dnn_mem_t(d_data_md, test_engine);
        }
        // In-place: the library writes diff_src into d_dst_dt.
        dnn_mem_t &d_src_dt = prb->inplace ? d_dst_dt : placeholder_d_src_dt;

        SAFE(fill_data(prb, DST, d_dst_dt, d_dst_fp), WARN);

        args.set(DNNL_ARG_DIFF_DST, d_dst_dt);
        args.set(DNNL_ARG_DIFF_SRC, d_src_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        if (prb->use_dst()) {
            // Backward from dst: derive dst from the filled src via the
            // forward reference, then feed it to the library as DST.
            // Outside CORR mode dst_fp is not computed before this reorder.
            if (is_bench_mode(CORR))
                TIME_REF(compute_ref_fwd(prb, src_fp, binary_po_fp, dst_fp));
            SAFE(dst_dt.reorder(dst_fp), WARN);
            // make dst_fp of same values as for bf16, otherwise there are high
            // relative and absolute errors due to initial difference in source
            // values which become worse particularly when (1 - x) is used.
            if (dst_dt.dt() != dst_fp.dt()) SAFE(dst_fp.reorder(dst_dt), WARN);
            args.set(DNNL_ARG_DST, dst_dt);
        } else {
            args.set(DNNL_ARG_SRC, src_dt);
        }
        SAFE(execute_and_wait(prim, args), WARN);

        if (is_bench_mode(CORR)) {
            TIME_REF(compute_ref_bwd(prb, arg_fp, d_dst_fp, d_src_fp));
            SAFE(cmp.compare(d_src_fp, d_src_dt, prb->attr, res), WARN);
        }
    }

    return measure_perf(res, prim, args);
}
441
442 } // namespace eltwise
443