/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <math.h>
#include <random>
#include <stdio.h>
#include <stdlib.h>

#include "oneapi/dnnl/dnnl.h"

#include "tests/test_thread.hpp"

#include "dnnl_common.hpp"
#include "dnnl_memory.hpp"
#include "utils/compare.hpp"

#include "binary/binary.hpp"
#include "eltwise/eltwise.hpp"

namespace eltwise {

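// Creates an eltwise primitive descriptor for the given problem: builds the
// data (and, for backward, diff data) memory descriptors, initializes the
// forward or backward op descriptor, attaches attributes, and records an
// unimplemented or skipped state in `res` when applicable.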
static int init_pd(dnnl_engine_t engine, const prb_t *prb,
        dnnl_primitive_desc_t &epd, res_t *res, dir_t dir,
        const_dnnl_primitive_desc_t hint) {
    dnnl_eltwise_desc_t ed;
    dnnl_memory_desc_t data_d;

    SAFE(init_md(&data_d, prb->ndims, prb->dims.data(), prb->dt, prb->tag),
            CRIT);

    dnnl_alg_kind_t alg = attr_t::post_ops_t::kind2dnnl_kind(prb->alg);

    if (prb->dir & FLAG_FWD) {
        auto prop = prb->dir & FLAG_INF ? dnnl_forward_inference
                                        : dnnl_forward_training;

        DNN_SAFE(dnnl_eltwise_forward_desc_init(
                         &ed, prop, alg, &data_d, prb->alpha, prb->beta),
                WARN);
    } else {
        dnnl_memory_desc_t diff_data_d;
        DNN_SAFE(dnnl_memory_desc_init_by_tag(&diff_data_d, prb->ndims,
                         prb->dims.data(), prb->dt, dnnl_format_tag_any),
                WARN);
        DNN_SAFE(dnnl_eltwise_backward_desc_init(&ed, alg, &diff_data_d,
                         &data_d, prb->alpha, prb->beta),
                WARN);
    }

    attr_args_t attr_args;
    attr_args.prepare_post_ops_mds(prb->attr, prb->ndims, prb->dims.data());
    auto dnnl_attr = make_benchdnn_dnnl_wrapper(
            create_dnnl_attr(prb->attr, attr_args));

    dnnl_status_t init_status
            = dnnl_primitive_desc_create(&epd, &ed, dnnl_attr, engine, nullptr);

    if (init_status == dnnl_unimplemented)
        return res->state = UNIMPLEMENTED, OK;
    else
        SAFE(init_status, WARN);

    res->impl_name = query_impl_info(epd);
    if (maybe_skip(res->impl_name)) {
        BENCHDNN_PRINT(2, "SKIPPED: oneDNN implementation: %s\n",
                res->impl_name.c_str());
        return res->state = SKIPPED, res->reason = SKIP_IMPL_HIT, OK;
    } else {
        BENCHDNN_PRINT(
                5, "oneDNN implementation: %s\n", res->impl_name.c_str());
    }

    SAFE(check_pd_w_and_wo_attr(res, prb->attr, ed), WARN);

    return OK;
}

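// Decides whether the absolute error criterion should be used for a given
// source value `s`: returns true for the per-algorithm inputs where
// catastrophic cancellation makes the relative-error check too strict while
// the absolute error stays small.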
static bool check_abs_err(const prb_t *prb, const float &s, const float &trh) {
    const float approx_machine_eps = 2 * epsilon_dt(dnnl_f32);
    const float comp_err = approx_machine_eps / trh;

    switch (prb->alg) {
        case alg_t::ELU:
        case alg_t::ELU_DST:
            // catch catastrophic cancellation in (exp(s) - 1) when s < 0 and
            // s is close to zero.
            return (prb->dir & FLAG_FWD) && std::signbit(s)
                    && (fabsf(expf(s) - 1.f) <= comp_err);
        case alg_t::GELU_TANH: {
            // catch catastrophic cancellation
            // (4.f is magic scale for f32)
            const float sqrt_2_over_pi = 0.797884;
            const float fitting_const = 0.044715;
            float v = tanhf(sqrt_2_over_pi * s * (1 + fitting_const * s * s));
            float dg = sqrt_2_over_pi * (1 + 3 * fitting_const * s * s);
            if (fabsf(1.f + v) <= comp_err) return true;
            return (prb->dir & FLAG_BWD) && std::signbit(s)
                    && fabsf(1.f + s * (1.f - v) * dg) <= 4.f * comp_err;
        }
        case alg_t::GELU_ERF: {
            // Catch catastrophic cancellation
            // which occurs at large negative s.
            // Factor 2 (in bwd) is to account for the fact that error is
            // accumulated for each summand (except the 1) when they
            // are of the same order of magnitude.
            const float sqrt_2_over_2 = 0.707106769084930419921875f;
            const float two_over_sqrt_pi = 1.12837922573089599609375f;
            float v = s * sqrt_2_over_2;
            if (prb->dir & FLAG_FWD)
                return fabsf(1.f + erff(v)) <= comp_err;
            else
                return fabsf(1.f + erff(v)
                               + v * two_over_sqrt_pi * expf(-v * v))
                        <= comp_err * 2;
        }
        case alg_t::TANH:
            // catch catastrophic cancellation, which occurs when err in tanh(s)
            // is high and tanh(s) is close to 1.
            return (prb->dir & FLAG_BWD) && (1.f - tanhf(fabsf(s))) <= comp_err;
        case alg_t::TANH_DST: // sse41 can't do fma
            // catch catastrophic cancellation, which occurs when err in tanh(s)
            // is high and tanh(s) is close to 1.
            return (prb->dir & FLAG_BWD) && (1.f - s * s) <= comp_err;
        case alg_t::SRELU:
            // when s is negative, expf(s) -> 0 rapidly
            // which leads to log1pf(expf(s)) -> 0
            // which leads to high relative error,
            // while abs error is still low.
            // (10.f is magic scale for bf16)
            return (prb->dir & FLAG_FWD) && std::signbit(s)
                    && log1pf(expf(s)) <= 10.f * comp_err;
        case alg_t::LOGSIGMOID:
            // same situation as in SRELU:
            // in logsigmoid, when s is positive,
            // the result -> 0
            return (prb->dir & FLAG_FWD) && !std::signbit(s)
                    && log1pf(expf(-s)) <= 10.f * comp_err;
        case alg_t::MISH:
            // same situation as in SRELU
            return (prb->dir & FLAG_FWD) && std::signbit(s)
                    && s * tanh(log1pf(expf(s))) <= 10.f * comp_err;
        case alg_t::LOGISTIC:
            // when s >= 4, logistic(s) -> 0 rapidly, which leads to high
            // relative error of logistic(s) * (1 - logistic(s)) due to
            // catastrophic cancellation.
            return (prb->dir & FLAG_BWD) && !std::signbit(s)
                    && (1.f / (1.f + expf(s))) <= comp_err;
        case alg_t::SWISH: {
            // catch cancellation happening when W(s) ~~ -1 in (1 + W(s))
            // formula part on backward.
            const float alpha_s = prb->alpha * s;
            return (prb->dir & FLAG_BWD)
                    && (alpha_s * (1.f - 1.f / (1.f + expf(-alpha_s)))
                            <= comp_err);
        }
        default: return false;
    }
}

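// Returns the relative error threshold for the given data type and algorithm:
// f32 gets a fixed budget, lower precisions tolerate one ulp of rounding
// error, and numerically harder algorithms get a larger f32 budget.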
float get_eltwise_threshold(dnnl_data_type_t dt, alg_t alg, bool is_fwd) {
    // Tolerate only rounding error (1 ulp) for other than fp32 precisions.
    float trh = dt == dnnl_f32 ? 4e-6 : epsilon_dt(dt);
    // Tolerate bigger compute errors for complex algorithms.
    const bool alg_has_higher_tolerance = alg == alg_t::GELU_TANH
            || alg == alg_t::ELU || alg == alg_t::SWISH || alg == alg_t::TANH
            || alg == alg_t::SRELU || alg == alg_t::LOGSIGMOID
            || alg == alg_t::MISH || alg == alg_t::LOG
            || ((alg == alg_t::ELU_DST || alg == alg_t::TANH_DST) && is_fwd);
    if (dt == dnnl_f32 && alg_has_higher_tolerance) trh = 4e-5;
    return trh;
}

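// Returns the percentage of zero output values the comparison tolerates before
// treating the result as untrusted; some algorithm/parameter combinations
// legitimately produce all zeros.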
static float get_eltwise_zero_trust_percent(const prb_t *prb) {
    float ztp = 60.f; // default for eltwise due to filling.
    switch (prb->alg) {
        case alg_t::LINEAR:
            if (prb->alpha == 0) ztp = 100.f;
            break;
        case alg_t::BRELU:
            if ((prb->alpha == 0) || (prb->dir & FLAG_BWD)) ztp = 100.f;
            break;
        case alg_t::CLIP:
        case alg_t::CLIP_V2:
        case alg_t::CLIP_V2_DST:
            if ((prb->alpha == 0 && prb->beta == 0) || (prb->dir & FLAG_BWD))
                ztp = 100.f;
            break;
        case alg_t::POW:
            if (prb->alpha == 0 || ((prb->dir & FLAG_BWD) && prb->beta == 0))
                ztp = 100.f;
            break;
        default: break;
    }
    // Integral data types with small float values will produce mostly zeros.
    // u8 with negative alpha will produce only zeros.
    if (is_integral_dt(prb->dt)) ztp = 100.f;
    return ztp;
}

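// Fills the f32 reference memory with a deterministic mix of value classes
// (small/large, positive/negative, alpha/beta corner cases) and reorders it
// into the library memory object.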
int fill_data(const prb_t *prb, data_kind_t kind, dnn_mem_t &mem_dt,
        dnn_mem_t &mem_fp) {
    const auto nelems = mem_fp.nelems();
    if (nelems == 0) return OK;

    /* Do fixed partitioning to have the same filling for any number of
     * threads. */
    const int64_t n_chunks = 16;
    const int64_t chunk_size = div_up(nelems, n_chunks);

    dnnl::impl::parallel_nd(n_chunks, [&](int64_t idx_chunk) {
        int64_t idx_start = idx_chunk * chunk_size;
        int64_t idx_end = MIN2(idx_start + chunk_size, nelems);
        // Note 1: we use a different seed for each chunk to avoid
        // repeating patterns. We could use discard(idx_start) too but
        // we avoid it for two reasons:
        //   a. it has a complexity in O(idx_start).
        //   b. igen and fgen below might require more than 1 sample
        //   per idx, so we cannot deterministically compute the
        //   number of states we need to discard.
        // Note 2: We also advance the state to avoid having only
        // small values as first chunk input. The +1 is necessary to
        // avoid generating zeros in the first chunk.
        // Note 3: we multiply by kind + 1 to have different values in
        // src/dst and diff_dst. The +1 is to avoid 0 again.
        std::minstd_rand msr((idx_start + 1) * (kind + 1));
        msr.discard(1);
        std::uniform_int_distribution<> igen(0, 10);
        // TODO: the upper bound is 0.09 because the log implementation does
        // not give good accuracy at points near 0.99.
        std::uniform_real_distribution<> fgen(0.f, 0.09f);
        for (int64_t idx = idx_start; idx < idx_end; ++idx) {
            static constexpr int64_t num_of_generation_variants = 13;
            float value = FLT_MAX;
            switch (idx % num_of_generation_variants) {
                case 0: value = (float)igen(msr); break; // [0-10] pos
                case 1: value = -(float)igen(msr); break; // [0-10] neg
                case 2: value = fgen(msr); break; // [0.-0.1) pos
                case 3: value = -fgen(msr); break; // [0.-0.1) neg
                case 4: value = 10 * (float)igen(msr); break; // [0-100] pos
                case 5: value = -10 * (float)igen(msr); break; // [0-100] neg
                case 6: value = 10.f * fgen(msr); break; // [0.-1.) pos
                case 7: value = -10.f * fgen(msr); break; // [0.-1.) neg
                case 8:
                    value = 88.f + 10.f * fgen(msr);
                    break; // values close to logf(FLT_MAX) for exp alg testing
                case 9:
                    value = 22.f + 10.f * fgen(msr);
                    break; // values close to logf(FLT_MAX)/4.0 for bwd mish alg testing
                case 10:
                    value = 44.f + 10.f * fgen(msr);
                    break; // values close to logf(FLT_MAX)/2.0 for fwd mish alg testing
                case 11: value = prb->alpha; break; // `x = alpha` corner cases
                case 12: value = prb->beta; break; // `x = beta` corner cases
            }
            value = round_to_nearest_representable(prb->dt, value);

            // Hack: -0 may lead to different sign in the answer since input
            // passes through simple reorder which converts -0 into +0.
            if (value == -0.f) value = 0.f;

            mem_fp.set_elem(idx, value);
        }
    });

    SAFE(mem_dt.reorder(mem_fp), WARN);

    return OK;
}

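// Skips problems that are invalid for the algorithm (e.g. clip with
// beta < alpha) or not supported by the backend under test.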
void check_known_skipped_case(const prb_t *prb, res_t *res) {
    check_known_skipped_case_common({prb->dt}, prb->dir, res);
    if (res->state == SKIPPED) return;

    bool is_invalid = false;
    switch (prb->alg) {
        case alg_t::CLIP:
        case alg_t::CLIP_V2:
        case alg_t::CLIP_V2_DST: is_invalid = prb->beta < prb->alpha; break;
        case alg_t::BRELU:
        case alg_t::ELU_DST:
        case alg_t::RELU_DST: is_invalid = prb->alpha < 0; break;
        case alg_t::ROUND:
            is_invalid = prb->dt != dnnl_f32 || prb->dir & FLAG_BWD;
            break;
        default: break;
    };
    if (is_invalid) {
        res->state = SKIPPED, res->reason = INVALID_CASE;
        return;
    }

    if (is_nvidia_gpu()) {
        if (!is_nvidia_eltwise_ok(prb->dir, prb->alg, prb->alpha)
                || !prb->attr.post_ops.is_def()) {
            res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
            return;
        }
    }
}

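// Driver entry point: creates the primitive, fills the inputs, executes the
// forward or backward pass, compares against the reference implementation in
// correctness mode, and measures performance.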
int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    check_known_skipped_case(prb, res);
    check_sum_post_ops(prb->attr, res);
    if (res->state == SKIPPED) return OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    SAFE(init_prim(prim, init_pd, prb, res), WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    const_dnnl_primitive_desc_t const_pd;
    DNN_SAFE(dnnl_primitive_get_primitive_desc(prim, &const_pd), CRIT);

    if (check_mem_size(const_pd) != OK) {
        return res->state = SKIPPED, res->reason = NOT_ENOUGH_RAM, OK;
    }

    const auto q = [&](int index = 0) -> const dnnl_memory_desc_t & {
        return *dnnl_primitive_desc_query_md(
                const_pd, dnnl_query_exec_arg_md, index);
    };

    const bool is_fwd = prb->dir & FLAG_FWD;
    const auto &src_md = q(DNNL_ARG_SRC);
    const auto &dst_md = q(DNNL_ARG_DST);
    const auto &data_md = !is_fwd && prb->use_dst() ? dst_md : src_md;
    const auto &scratchpad_md = q(DNNL_ARG_SCRATCHPAD);
    const auto &test_engine = get_test_engine();

    dnn_mem_t src_fp(data_md, dnnl_f32, tag::abx, test_engine);
    dnn_mem_t src_dt(data_md, test_engine);

    // we need src_fp for proper comparison, => no in-place reference
    dnn_mem_t dst_fp(data_md, dnnl_f32, tag::abx, test_engine);
    dnn_mem_t placeholder_dst_dt;
    if (!prb->inplace) { placeholder_dst_dt = dnn_mem_t(data_md, test_engine); }
    dnn_mem_t &dst_dt = prb->inplace ? src_dt : placeholder_dst_dt;

    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);
    std::vector<dnn_mem_t> binary_po_fp, binary_po_dt;
    std::vector<int> binary_po_args;
    SAFE(binary::setup_binary_po(
                 const_pd, binary_po_args, binary_po_dt, binary_po_fp),
            WARN);

    dnn_mem_t d_dst_dt, placeholder_d_src_dt;

    SAFE(fill_data(prb, SRC, src_dt, src_fp), WARN);

    args_t args;

    dnn_mem_t &arg_fp = !is_fwd && prb->use_dst() ? dst_fp : src_fp;

    // `trh` must stay in this scope: `eltwise_add_check` captures it by
    // reference, so defining it inside the block below would leave a dangling
    // reference by the time the comparison invokes the lambda.
    const float trh = get_eltwise_threshold(prb->dt, prb->alg, is_fwd);
    compare::compare_t cmp;
    if (is_bench_mode(CORR)) {
        cmp.set_threshold(trh);
        cmp.set_zero_trust_percent(get_eltwise_zero_trust_percent(prb));

        const auto eltwise_add_check =
                [&](const compare::compare_t::driver_check_func_args_t &args) {
                    // Some algorithms require absolute value comparison for inputs
                    // where catastrophic cancellation may happen.
                    const float src = arg_fp.get_elem(args.idx);
                    if (check_abs_err(prb, src, trh)) return args.diff <= trh;
                    if (prb->attr.post_ops.binary_index() != -1)
                        return args.diff <= trh;
                    return false;
                };
        cmp.set_driver_check_function(eltwise_add_check);
    }

    if (prb->dir & FLAG_FWD) {
        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_DST, dst_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
        args.set(binary_po_args, binary_po_dt);

        SAFE(execute_and_wait(prim, args), WARN);

        if (is_bench_mode(CORR)) {
            TIME_REF(compute_ref_fwd(prb, src_fp, binary_po_fp, dst_fp));
            SAFE(cmp.compare(dst_fp, dst_dt, prb->attr, res), WARN);
        }
    } else {
        const auto &d_data_md = q(DNNL_ARG_DIFF_DST);

        dnn_mem_t d_dst_fp
                = dnn_mem_t(d_data_md, dnnl_f32, tag::abx, test_engine);
        d_dst_dt = dnn_mem_t(d_data_md, test_engine);

        dnn_mem_t &d_src_fp = d_dst_fp; // in-place reference
        if (!prb->inplace) {
            placeholder_d_src_dt = dnn_mem_t(d_data_md, test_engine);
        }
        dnn_mem_t &d_src_dt = prb->inplace ? d_dst_dt : placeholder_d_src_dt;

        SAFE(fill_data(prb, DST, d_dst_dt, d_dst_fp), WARN);

        args.set(DNNL_ARG_DIFF_DST, d_dst_dt);
        args.set(DNNL_ARG_DIFF_SRC, d_src_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        if (prb->use_dst()) {
            if (is_bench_mode(CORR))
                TIME_REF(compute_ref_fwd(prb, src_fp, binary_po_fp, dst_fp));
            SAFE(dst_dt.reorder(dst_fp), WARN);
            // Make dst_fp hold the same values as dst_dt (e.g. for bf16);
            // otherwise the initial difference in source values leads to high
            // relative and absolute errors, which get worse in particular
            // when (1 - x) is used.
            if (dst_dt.dt() != dst_fp.dt()) SAFE(dst_fp.reorder(dst_dt), WARN);
            args.set(DNNL_ARG_DST, dst_dt);
        } else {
            args.set(DNNL_ARG_SRC, src_dt);
        }
        SAFE(execute_and_wait(prim, args), WARN);

        if (is_bench_mode(CORR)) {
            TIME_REF(compute_ref_bwd(prb, arg_fp, d_dst_fp, d_src_fp));
            SAFE(cmp.compare(d_src_fp, d_src_dt, prb->attr, res), WARN);
        }
    }

    return measure_perf(res, prim, args);
}

} // namespace eltwise