1 /*******************************************************************************
2 * Copyright 2016-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include <memory>
18 
19 #include "dnnl_test_common.hpp"
20 #include "gtest/gtest.h"
21 
22 #include "oneapi/dnnl/dnnl.hpp"
23 
24 namespace dnnl {
25 
/// Reference forward ReLU with a negative slope.
/// Returns `s` unchanged when it is positive; otherwise returns `s * alpha`
/// converted back to the element type T.
template <typename T, typename A>
inline T relu_fwd(T s, A alpha) {
    if (s > 0) return s;
    return static_cast<T>(s * alpha);
}
/// Reference backward ReLU with a negative slope.
/// Passes the incoming gradient `dd` through where the forward input `s` was
/// positive; everywhere else (including s == 0) it is scaled by `alpha`.
template <typename T, typename A>
inline T relu_bwd(T dd, T s, A alpha) {
    if (!(s > 0)) return static_cast<T>(dd * alpha);
    return dd;
}
/// Reference forward tanh, evaluated in single precision and converted back
/// to the element type T.
template <typename T>
T tanh_fwd(T s) {
    const float x = static_cast<float>(s);
    return static_cast<T>(::tanhf(x));
}
/// Reference backward tanh: dd * (1 - tanh(s)^2).
/// The squared term is evaluated as (1 - th) * (1 + th), multiplied into `dd`
/// left to right, in single precision.
template <typename T>
T tanh_bwd(T dd, T s) {
    const float th = ::tanhf(static_cast<float>(s));
    const float one_minus_th = 1 - th;
    const float one_plus_th = 1 + th;
    return static_cast<T>(dd * one_minus_th * one_plus_th);
}
43 
/// Reference forward ELU: `s` for positive inputs, otherwise
/// alpha * (exp(s) - 1) converted to T.
template <typename T, typename A>
T elu_fwd(T s, A alpha) {
    if (s > 0) return s;
    return static_cast<T>(alpha * (::expf(s) - 1));
}
/// Reference backward ELU: the incoming gradient `dd` times the local slope,
/// which is 1 for positive `s` and alpha * exp(s) otherwise.
template <typename T, typename A>
T elu_bwd(T dd, T s, A alpha) {
    const auto slope = s > 0 ? 1 : alpha * ::expf(s);
    return static_cast<T>(dd * slope);
}
52 
/// Reference forward GELU, tanh approximation:
///     0.5 * s * (1 + tanh(sqrt(2/pi) * s * (1 + 0.044715 * s^2)))
/// sqrt(2/pi) is written to full float precision. The previous truncated
/// value 0.797884 put ~7e-7 relative error into the reference argument of
/// tanh, which the comparison tolerance then had to absorb.
template <typename T>
T gelu_tanh_fwd(T s) {
    const float sqrt_2_over_pi = 0.79788456080286535588f;
    const float fitting_const = 0.044715f;
    const float x = static_cast<float>(s);
    const float g = sqrt_2_over_pi * x * (1 + fitting_const * x * x);
    return static_cast<T>(0.5 * x * (1 + ::tanhf(g)));
}
60 
/// Reference backward GELU, tanh approximation.
/// With g = sqrt(2/pi) * s * (1 + 0.044715 * s^2) and dg = dg/ds, the
/// derivative of 0.5 * s * (1 + tanh(g)) is
///     0.5 * (1 + tanh(g)) * (1 + s * (1 - tanh(g)) * dg),
/// which is multiplied by the incoming gradient `dd`.
/// sqrt(2/pi) is written to full float precision (it was truncated to
/// 0.797884). tanh(g) is evaluated once instead of twice.
template <typename T>
T gelu_tanh_bwd(T dd, T s) {
    const float sqrt_2_over_pi = 0.79788456080286535588f;
    const float fitting_const = 0.044715f;
    const float x = static_cast<float>(s);
    const float g = sqrt_2_over_pi * x * (1 + fitting_const * x * x);
    const float dg = sqrt_2_over_pi * (1 + 3 * fitting_const * x * x);
    const float th = ::tanhf(g);
    return static_cast<T>(dd * (0.5 * (1 + th) * (1 + x * (1 - th) * dg)));
}
70 
/// Reference forward square: s * s, converted to the element type T.
template <typename T>
T square_fwd(T s) {
    const T squared = s * s;
    return squared;
}
75 
/// Reference backward square: d(s^2)/ds = 2s, scaled by the incoming
/// gradient `dd` (evaluated as (dd * 2) * s).
template <typename T>
T square_bwd(T dd, T s) {
    T grad = dd * 2 * s;
    return grad;
}
80 
/// Reference forward absolute value: `s` if positive, otherwise its negation
/// converted to T.
template <typename T>
T abs_fwd(T s) {
    if (s > 0) return s;
    return T(-s);
}
85 
/// Reference backward absolute value: `dd` times sign(s), where sign(0) == 0.
template <typename T>
T abs_bwd(T dd, T s) {
    int sign = 0;
    if (s > 0)
        sign = 1;
    else if (s < 0)
        sign = -1;
    return dd * sign;
}
90 
/// Reference forward linear transform: alpha * s + beta.
template <typename T, typename A>
T linear_fwd(T s, A alpha, A beta) {
    const auto scaled = alpha * s;
    return scaled + beta;
}
95 
/// Reference backward linear transform: the derivative of alpha * s + beta is
/// alpha, so the result is dd * alpha. `s` and `beta` do not affect it.
template <typename T, typename A>
T linear_bwd(T dd, T s, A alpha, A beta) {
    static_cast<void>(s);
    static_cast<void>(beta);
    const auto grad = dd * alpha;
    return grad;
}
102 
/// Reference forward bounded ReLU.
/// The input is first clamped from below at 0. The result is then clamped
/// from above at `alpha`. If alpha is negative, the upper clamp wins and the
/// result is T(alpha).
template <typename T, typename A>
T bounded_relu_fwd(T s, A alpha) {
    T y = T(0);
    if (s > 0) y = s;
    return (y > alpha) ? T(alpha) : y;
}
108 
/// Reference backward bounded ReLU: the gradient passes through only where
/// the forward input lies strictly inside (0, alpha); elsewhere it is zero.
template <typename T, typename A>
T bounded_relu_bwd(T dd, T s, A alpha) {
    const bool inside = 0 < s && s < alpha;
    return dd * (inside ? 1 : 0);
}
113 
/// Reference forward soft ReLU (softplus): log(1 + exp(s)).
/// Inputs at or above log(FLT_MAX), and NaN, are returned unchanged. Inputs
/// at or above log(FLT_MAX) would make exp(s) overflow in float.
template <typename T>
T soft_relu_fwd(T s) {
    const T overflow_threshold = static_cast<T>(logf(FLT_MAX));
    if (!(s < overflow_threshold)) return s;
    return T(log1pf(::expf(s)));
}
118 
/// Reference backward soft ReLU: the derivative of log(1 + e^s) is the
/// logistic function 1 / (1 + e^-s), which scales the incoming gradient `dd`.
template <typename T>
T soft_relu_bwd(T dd, T s) {
    const auto denom = 1 + ::expf(-s);
    return dd / denom;
}
123 
/// Reference forward logistic (sigmoid): 1 / (1 + exp(-s)), evaluated in
/// single precision. The negation is applied in T before the conversion to
/// float.
template <typename T>
T logistic_fwd(T s) {
    const float e = ::expf(static_cast<float>(-s));
    const float sigma = 1 / (1 + e);
    return static_cast<T>(sigma);
}
129 
/// Reference backward logistic: sigma'(s) = sigma(s) * (1 - sigma(s)),
/// multiplied by the incoming gradient `dd`. sigma is computed in single
/// precision from `s` converted to float.
template <typename T>
T logistic_bwd(T dd, T s) {
    const float x = static_cast<float>(s);
    const float sigma = 1 / (1 + ::expf(-x));
    return static_cast<T>(dd * sigma * (1 - sigma));
}
135 
/// Reference forward exponential, computed with single-precision expf and
/// converted back to T.
template <typename T>
T exp_fwd(T s) {
    const float e = ::expf(static_cast<float>(s));
    return static_cast<T>(e);
}
140 
/// Reference backward exponential: d(e^s)/ds = e^s, scaled by the incoming
/// gradient `dd`.
template <typename T>
T exp_bwd(T dd, T s) {
    const float slope = ::expf(static_cast<float>(s));
    return dd * slope;
}
145 
/// Reference forward swish: s / (1 + exp(-alpha * s)), i.e.
/// s * sigmoid(alpha * s).
template <typename T, typename A>
T swish_fwd(T s, A alpha) {
    const float x = static_cast<float>(s);
    const float denom = 1.0f + ::expf(-alpha * x);
    return static_cast<T>(s / denom);
}
150 
/// Reference backward swish. With sigma = sigmoid(alpha * s), the derivative
/// of s * sigma is sigma + s * alpha * sigma * (1 - sigma), which scales the
/// incoming gradient `dd`.
template <typename T, typename A>
T swish_bwd(T dd, T s, A alpha) {
    const float z = alpha * s;
    const float sigma = 1 / (1 + ::expf(-z));
    return dd * (sigma + s * alpha * sigma * (1 - sigma));
}
156 
/// Reference forward GELU, erf form: 0.5 * s * (1 + erf(s / sqrt(2))).
/// With v = s / sqrt(2), this is (1/sqrt(2)) * v * (1 + erf(v)).
/// 1/sqrt(2) is written to full float precision. The previous truncated value
/// 0.707106 shifted the reference by ~2e-6 at s = 1, an error the
/// implementation under test does not have.
template <typename T>
T gelu_erf_fwd(T s) {
    const float sqrt_2_over_2 = 0.70710678118654752440f;
    const float v = s * sqrt_2_over_2;
    return (T)(sqrt_2_over_2 * v * (1.f + ::erff(v)));
}
163 
/// Reference backward GELU, erf form. With v = s / sqrt(2), the derivative of
/// 0.5 * s * (1 + erf(v)) is
///     0.5 * (1 + erf(v) + v * (2 / sqrt(pi)) * exp(-v^2)),
/// which is multiplied by the incoming gradient `dd`.
/// Both constants are written to full float precision (they were truncated to
/// 1.128379 and 0.707106).
template <typename T>
T gelu_erf_bwd(T dd, T s) {
    const float two_over_sqrt_pi = 1.12837916709551257390f;
    const float sqrt_2_over_2 = 0.70710678118654752440f;
    const float v = s * sqrt_2_over_2;
    return (T)(dd * 0.5f
            * (1.f + ::erff(v) + v * two_over_sqrt_pi * ::expf(-v * v)));
}
172 
// Parameters of a single eltwise test case.
// This is an aggregate: the PARAMS and CASE_EF macros brace-initialize it
// positionally, so the field order is part of its contract. PARAMS supplies
// only the first six fields; the remaining ones are value-initialized.
struct eltwise_test_params_t {
    algorithm alg_kind; // eltwise algorithm under test
    memory::format_tag data_format; // layout of src/dst
    memory::format_tag diff_format; // layout of diff_src/diff_dst
    float alpha, beta; // algorithm parameters passed to the op descriptors
    memory::dims dims; // logical tensor dimensions
    bool expect_to_fail; // forwarded to catch_expected_failures()
    dnnl_status_t expected_status; // status expected when expect_to_fail
};
182 
// Returns the product of the *padded* dims of `md`, i.e. the element count
// including any padding that blocked layouts add (for example, channels
// rounded up to the block size).
memory::dim n_elems(const memory::desc &md) {
    const auto &raw = md.data;
    memory::dim count = 1;
    for (int dim_idx = 0; dim_idx < raw.ndims; ++dim_idx)
        count *= raw.padded_dims[dim_idx];
    return count;
}
190 
191 template <typename data_t>
check_eltwise_fwd(const eltwise_test_params_t & p,const memory::desc & md,const memory & src,const memory & dst)192 void check_eltwise_fwd(const eltwise_test_params_t &p, const memory::desc &md,
193         const memory &src, const memory &dst) {
194     auto src_data = map_memory<data_t>(src);
195     auto dst_data = map_memory<data_t>(dst);
196 
197     memory::dim n = n_elems(md);
198     for (memory::dim i = 0; i < n; ++i) {
199         data_t s = src_data[i];
200         data_t ref_d = 0;
201         switch (p.alg_kind) {
202             case algorithm::eltwise_relu: ref_d = relu_fwd(s, p.alpha); break;
203             case algorithm::eltwise_tanh: ref_d = tanh_fwd(s); break;
204             case algorithm::eltwise_elu: ref_d = elu_fwd(s, p.alpha); break;
205             case algorithm::eltwise_square: ref_d = square_fwd(s); break;
206             case algorithm::eltwise_abs: ref_d = abs_fwd(s); break;
207             case algorithm::eltwise_linear:
208                 ref_d = linear_fwd(s, p.alpha, p.beta);
209                 break;
210             case algorithm::eltwise_bounded_relu:
211                 ref_d = bounded_relu_fwd(s, p.alpha);
212                 break;
213             case algorithm::eltwise_soft_relu: ref_d = soft_relu_fwd(s); break;
214             case algorithm::eltwise_logistic: ref_d = logistic_fwd(s); break;
215             case algorithm::eltwise_exp: ref_d = exp_fwd(s); break;
216             case algorithm::eltwise_gelu_tanh: ref_d = gelu_tanh_fwd(s); break;
217             case algorithm::eltwise_swish: ref_d = swish_fwd(s, p.alpha); break;
218             case algorithm::eltwise_gelu_erf: ref_d = gelu_erf_fwd(s); break;
219             default: assert(!"unknown alg_kind");
220         }
221         dst_data[i] = ref_d;
222     }
223 }
224 
225 template <typename data_t>
compare_eltwise_fwd(const eltwise_test_params_t & p,const memory::desc & md,const memory & dst,const memory & ref_dst)226 void compare_eltwise_fwd(const eltwise_test_params_t &p, const memory::desc &md,
227         const memory &dst, const memory &ref_dst) {
228     data_t eps;
229     if (data_traits<data_t>::data_type == memory::data_type::s8
230             || data_traits<data_t>::data_type == memory::data_type::s32)
231         eps = 0;
232     else
233         eps = static_cast<data_t>(
234                 (data_traits<data_t>::data_type == memory::data_type::f16
235                         || data_traits<data_t>::data_type
236                                 == memory::data_type::bf16)
237                         ? 5e-2
238                         : (p.alg_kind == algorithm::eltwise_elu
239                                   || p.alg_kind == algorithm::eltwise_gelu_tanh
240                                   || p.alg_kind == algorithm::eltwise_gelu_erf)
241                                 ? 2e-5
242                                 : p.alg_kind == algorithm::eltwise_soft_relu
243                                         ? 3e-5
244                                         : 1e-6);
245     compare_data(ref_dst, dst, eps);
246 }
247 
248 template <typename data_t>
check_eltwise_bwd(const eltwise_test_params_t & p,const memory::desc & md,const memory & src,const memory & diff_dst,const memory & diff_src)249 void check_eltwise_bwd(const eltwise_test_params_t &p, const memory::desc &md,
250         const memory &src, const memory &diff_dst, const memory &diff_src) {
251     auto src_data = map_memory<data_t>(src);
252     auto diff_dst_data = map_memory<data_t>(diff_dst);
253     auto diff_src_data = map_memory<data_t>(diff_src);
254 
255     const memory::desc data_d = src.get_desc();
256     const memory::desc diff_data_d = diff_src.get_desc();
257     const dnnl::impl::memory_desc_wrapper data_mdw(data_d.data);
258     const dnnl::impl::memory_desc_wrapper diff_data_mdw(diff_data_d.data);
259 
260     float eps_f = 0;
261     if (p.alg_kind == algorithm::eltwise_soft_relu) {
262         eps_f = 2e-6f;
263     } else if (p.alg_kind == algorithm::eltwise_tanh) {
264         eps_f = (get_test_engine_kind() == engine::kind::gpu) ? 2e-5f : 2e-6f;
265     } else if (p.alg_kind == algorithm::eltwise_gelu_tanh
266             || p.alg_kind == algorithm::eltwise_gelu_erf) {
267         eps_f = 1e-5f;
268     } else {
269         eps_f = 1e-6f;
270     }
271     data_t eps = static_cast<data_t>(eps_f);
272 
273     memory::dim n = n_elems(md);
274     for (memory::dim i = 0; i < n; ++i) {
275         data_t ref_s = src_data[data_mdw.off_l(i)];
276         data_t ref_dd = diff_dst_data[diff_data_mdw.off_l(i)];
277         data_t ref_ds = 0;
278         switch (p.alg_kind) {
279             case algorithm::eltwise_relu:
280                 ref_ds = relu_bwd(ref_dd, ref_s, p.alpha);
281                 break;
282             case algorithm::eltwise_tanh:
283                 ref_ds = tanh_bwd(ref_dd, ref_s);
284                 break;
285             case algorithm::eltwise_elu:
286                 ref_ds = elu_bwd(ref_dd, ref_s, p.alpha);
287                 break;
288             case algorithm::eltwise_square:
289                 ref_ds = square_bwd(ref_dd, ref_s);
290                 break;
291             case algorithm::eltwise_abs: ref_ds = abs_bwd(ref_dd, ref_s); break;
292             case algorithm::eltwise_linear:
293                 ref_ds = linear_bwd(ref_dd, ref_s, p.alpha, p.beta);
294                 break;
295             case algorithm::eltwise_bounded_relu:
296                 ref_ds = bounded_relu_bwd(ref_dd, ref_s, p.alpha);
297                 break;
298             case algorithm::eltwise_soft_relu:
299                 ref_ds = soft_relu_bwd(ref_dd, ref_s);
300                 break;
301             case algorithm::eltwise_logistic:
302                 ref_ds = logistic_bwd(ref_dd, ref_s);
303                 break;
304             case algorithm::eltwise_exp: ref_ds = exp_bwd(ref_dd, ref_s); break;
305             case algorithm::eltwise_gelu_tanh:
306                 ref_ds = gelu_tanh_bwd(ref_dd, ref_s);
307                 break;
308             case algorithm::eltwise_swish:
309                 ref_ds = swish_bwd(ref_dd, ref_s, p.alpha);
310                 break;
311             case algorithm::eltwise_gelu_erf:
312                 ref_ds = gelu_erf_bwd(ref_dd, ref_s);
313                 break;
314             default: assert(!"unknown alg_kind");
315         }
316 
317         data_t tgt = diff_src_data[diff_data_mdw.off_l(i)];
318         const data_t diff = tgt == ref_ds ? 0 : tgt - ref_ds;
319         data_t error = (std::abs(ref_ds) > eps)
320                 ? static_cast<data_t>(diff / ref_ds)
321                 : diff;
322         if (p.alg_kind == algorithm::eltwise_logistic
323                 && (tgt < 1e-3)) { // check for cancellation
324             error = diff;
325         }
326         ASSERT_NEAR(error, 0.0, eps);
327     }
328 }
329 
330 template <typename data_t>
331 class eltwise_test_t : public ::testing::TestWithParam<eltwise_test_params_t> {
332 private:
333     memory src;
334     std::shared_ptr<memory::desc> data_desc;
335     eltwise_forward::primitive_desc eltwise_prim_desc;
336     eltwise_test_params_t p;
337     engine eng;
338     stream strm;
339     memory::data_type data_type;
340 
341 protected:
SetUp()342     void SetUp() override {
343         data_type = data_traits<data_t>::data_type;
344         SKIP_IF(unsupported_data_type(data_type),
345                 "Engine does not support this data type.");
346         p = ::testing::TestWithParam<decltype(p)>::GetParam();
347         SKIP_IF((p.alg_kind != algorithm::eltwise_relu
348                         || (p.alg_kind == algorithm::eltwise_relu
349                                 && p.alpha != 0.0))
350                         && (data_type == memory::data_type::s32
351                                 || data_type == memory::data_type::s8),
352                 "oneDNN only supports relu w/ slope=0 for integers");
353         SKIP_IF_CUDA(p.alg_kind != algorithm::eltwise_relu
354                         && p.alg_kind != algorithm::eltwise_bounded_relu
355                         && p.alg_kind != algorithm::eltwise_tanh
356                         && p.alg_kind != algorithm::eltwise_elu
357                         && p.alg_kind != algorithm::eltwise_logistic,
358                 "Unsupported algorithm type for CUDA");
359         SKIP_IF_CUDA(p.alg_kind == algorithm::eltwise_relu && p.alpha != 0.0,
360                 "DNNL only supports relu w/ slope=0 for integers");
361         SKIP_IF_CUDA(!cuda_check_format_tag(p.data_format),
362                 "Unsupported format tag");
363         SKIP_IF_CUDA(!cuda_check_format_tag(p.diff_format),
364                 "Unsupported format tag");
365         catch_expected_failures(
366                 [=]() { Test(); }, p.expect_to_fail, p.expected_status);
367     }
368 
cuda_check_format_tag(memory::format_tag tag)369     bool cuda_check_format_tag(memory::format_tag tag) {
370         // Blocking is not supported by cuDNN
371         return (tag != memory::format_tag::aBcd8b
372                 && tag != memory::format_tag::aBcd16b
373                 && tag != memory::format_tag::aBcde8b
374                 && tag != memory::format_tag::aBcde16b);
375     }
376 
Test()377     void Test() {
378         p = ::testing::TestWithParam<eltwise_test_params_t>::GetParam();
379 
380         eng = get_test_engine();
381         strm = make_stream(eng);
382 
383         Forward();
384         if (data_type == memory::data_type::f32
385                 || data_type == memory::data_type::bf16) {
386             Backward();
387         }
388     }
389 
Forward()390     void Forward() {
391         data_desc = std::make_shared<memory::desc>(
392                 p.dims, data_type, p.data_format);
393         src = test::make_memory(*data_desc, eng);
394         auto dst = test::make_memory(*data_desc, eng);
395         auto ref_dst = test::make_memory(*data_desc, eng);
396 
397         data_t data_median = data_t(0);
398         data_t data_deviation = (p.alg_kind == algorithm::eltwise_elu
399                                         || p.alg_kind == algorithm::eltwise_exp)
400                         || (p.alg_kind == algorithm::eltwise_swish)
401                 ? data_t(1.0)
402                 : p.alg_kind == algorithm::eltwise_square ? data_t(6.0)
403                                                           : data_t(100.0);
404         fill_data<data_t>(
405                 n_elems(*data_desc), src, data_median, data_deviation);
406         check_zero_tail<data_t>(1, src);
407 
408         auto eltwise_desc = eltwise_forward::desc(prop_kind::forward_training,
409                 p.alg_kind, *data_desc, p.alpha, p.beta);
410         eltwise_prim_desc = eltwise_forward::primitive_desc(eltwise_desc, eng);
411         eltwise_prim_desc = eltwise_forward::primitive_desc(
412                 eltwise_prim_desc.get()); // test construction from a C pd
413 
414         ASSERT_TRUE(eltwise_prim_desc.query_md(query::exec_arg_md, DNNL_ARG_SRC)
415                 == eltwise_prim_desc.src_desc());
416         ASSERT_TRUE(eltwise_prim_desc.query_md(query::exec_arg_md, DNNL_ARG_DST)
417                 == eltwise_prim_desc.dst_desc());
418 
419         eltwise_forward(eltwise_prim_desc)
420                 .execute(strm, {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}});
421         strm.wait();
422 
423         check_zero_tail<data_t>(0, dst);
424         check_eltwise_fwd<data_t>(p, *data_desc, src, ref_dst);
425         check_zero_tail<data_t>(1, ref_dst);
426         compare_eltwise_fwd<data_t>(p, *data_desc, dst, ref_dst);
427     }
428 
Backward()429     void Backward() {
430         SKIP_IF_CUDA(p.alg_kind != algorithm::eltwise_relu
431                         && p.alg_kind != algorithm::eltwise_bounded_relu,
432                 "Unsupported algorithm");
433         SKIP_IF_CUDA(p.diff_format != p.data_format,
434                 "CUDA does not support different data formats for data and "
435                 "diff vectors");
436         memory::desc diff_data_desc(p.dims, data_type, p.diff_format);
437         auto diff_src = test::make_memory(diff_data_desc, eng);
438         auto diff_dst = test::make_memory(diff_data_desc, eng);
439 
440         data_t data_median = data_t(0);
441         data_t data_deviation = p.alg_kind == algorithm::eltwise_elu
442                 ? data_t(1.0)
443                 : p.alg_kind == algorithm::eltwise_square ? data_t(6.0)
444                                                           : data_t(100.0);
445         fill_data<data_t>(
446                 n_elems(diff_data_desc), diff_dst, data_median, data_deviation);
447         check_zero_tail<data_t>(1, diff_dst);
448 
449         auto eltwise_bwd_desc = eltwise_backward::desc(
450                 p.alg_kind, diff_data_desc, *data_desc, p.alpha, p.beta);
451         auto eltwise_bwd_prim_desc = eltwise_backward::primitive_desc(
452                 eltwise_bwd_desc, eng, eltwise_prim_desc);
453         eltwise_bwd_prim_desc
454                 = eltwise_backward::primitive_desc(eltwise_bwd_prim_desc.get());
455 
456         ASSERT_TRUE(
457                 eltwise_bwd_prim_desc.query_md(query::exec_arg_md, DNNL_ARG_SRC)
458                 == eltwise_bwd_prim_desc.src_desc());
459         ASSERT_TRUE(
460                 eltwise_bwd_prim_desc.query_md(query::exec_arg_md, DNNL_ARG_DST)
461                 == eltwise_bwd_prim_desc.dst_desc());
462 
463         eltwise_backward(eltwise_bwd_prim_desc)
464                 .execute(strm,
465                         {{DNNL_ARG_SRC, src}, {DNNL_ARG_DIFF_DST, diff_dst},
466                                 {DNNL_ARG_DIFF_SRC, diff_src}});
467         strm.wait();
468 
469         check_zero_tail<data_t>(0, diff_src);
470         check_eltwise_bwd<data_t>(p, *data_desc, src, diff_dst, diff_src);
471     }
472 };
473 
474 using eltwise_test_f16 = eltwise_test_t<float16_t>;
475 using eltwise_test_bf16 = eltwise_test_t<bfloat16_t>;
476 using eltwise_test_f32 = eltwise_test_t<float>;
477 using eltwise_test_s32 = eltwise_test_t<int>;
478 using eltwise_test_s8 = eltwise_test_t<int8_t>;
479 
// The test bodies are intentionally empty: all of the work (skips, forward
// and backward runs, result checks) is done in eltwise_test_t::SetUp().
TEST_P(eltwise_test_f16, TestsEltwise) {}

TEST_P(eltwise_test_bf16, TestsEltwise) {}

TEST_P(eltwise_test_f32, TestsEltwise) {}

TEST_P(eltwise_test_s32, TestsEltwise) {}

TEST_P(eltwise_test_s8, TestsEltwise) {}
489 
// EXPAND forces one more round of macro expansion. It is presumably a
// workaround for preprocessors that forward __VA_ARGS__ as a single token
// (e.g. MSVC) — TODO confirm.
#define EXPAND(args) args

// Turns a bare format name (e.g. nchw) into memory::format_tag::<name>.
#define EXPAND_FORMATS(data) memory::format_tag::data
// Wraps the trailing dimension list in braces so it initializes memory::dims.
#define EXPAND_DIMS(...) \
    { __VA_ARGS__ }

// Builds one eltwise_test_params_t: algorithm, data/diff formats, alpha, beta,
// then the dims as trailing arguments. expect_to_fail and expected_status are
// not given, so they are value-initialized.
#define PARAMS(alg, data, diff_data, alpha, beta, ...) \
    eltwise_test_params_t { \
        algorithm::alg, EXPAND_FORMATS(data), EXPAND_FORMATS(diff_data), \
                alpha, beta, EXPAND_DIMS(__VA_ARGS__) \
    }

// One PARAMS entry per algorithm of the main set, all sharing the same
// formats, alpha/beta and dims.
#define PARAMS_ALL_ALG(...) \
    EXPAND(PARAMS(eltwise_gelu_tanh, __VA_ARGS__)), \
            EXPAND(PARAMS(eltwise_relu, __VA_ARGS__)), \
            EXPAND(PARAMS(eltwise_tanh, __VA_ARGS__)), \
            EXPAND(PARAMS(eltwise_elu, __VA_ARGS__)), \
            EXPAND(PARAMS(eltwise_square, __VA_ARGS__)), \
            EXPAND(PARAMS(eltwise_abs, __VA_ARGS__)), \
            EXPAND(PARAMS(eltwise_exp, __VA_ARGS__)), \
            EXPAND(PARAMS(eltwise_swish, __VA_ARGS__)), \
            EXPAND(PARAMS(eltwise_gelu_erf, __VA_ARGS__))

// Same as PARAMS_ALL_ALG for the remaining algorithms (linear, soft_relu,
// bounded_relu, logistic), which the suites below list separately.
#define PARAMS_ALL_ALG_SDPART(...) \
    EXPAND(PARAMS(eltwise_linear, __VA_ARGS__)), \
            EXPAND(PARAMS(eltwise_soft_relu, __VA_ARGS__)), \
            EXPAND(PARAMS(eltwise_bounded_relu, __VA_ARGS__)), \
            EXPAND(PARAMS(eltwise_logistic, __VA_ARGS__))

// Instantiate fixture eltwise_test_<data_t> with the given parameters under
// the suite name <str>_<data_t>. The CPU_ variant uses the CPU-only
// instantiation macro (presumably restricting the suite to CPU engines).
// NOTE(review): identifiers starting with "_" + uppercase are reserved in C++.
#define _CPU_INST_TEST_CASE(str, data_t, ...) \
    CPU_INSTANTIATE_TEST_SUITE_P(str##_##data_t, eltwise_test_##data_t, \
            ::testing::Values(__VA_ARGS__))

#define _INST_TEST_CASE(str, data_t, ...) \
    INSTANTIATE_TEST_SUITE_P_(str##_##data_t, eltwise_test_##data_t, \
            ::testing::Values(__VA_ARGS__))

// Per-data-type wrappers. Each one ends with ';', so the composite macros
// below can chain them without separators. f16 is instantiated only through
// the GPU macro.
#define CPU_INST_TEST_CASE_BF16(str, ...) \
    _CPU_INST_TEST_CASE(str, bf16, __VA_ARGS__);
#define INST_TEST_CASE_BF16(str, ...) _INST_TEST_CASE(str, bf16, __VA_ARGS__);
#define GPU_INST_TEST_CASE_F16(str, ...) \
    GPU_INSTANTIATE_TEST_SUITE_P_(TEST_CONCAT(str, _f16), eltwise_test_f16, \
            ::testing::Values(__VA_ARGS__));
#define CPU_INST_TEST_CASE_F32(str, ...) \
    _CPU_INST_TEST_CASE(str, f32, __VA_ARGS__);
#define INST_TEST_CASE_F32(str, ...) _INST_TEST_CASE(str, f32, __VA_ARGS__);
#define CPU_INST_TEST_CASE_S32(str, ...) \
    _CPU_INST_TEST_CASE(str, s32, __VA_ARGS__);
#define INST_TEST_CASE_S32(str, ...) _INST_TEST_CASE(str, s32, __VA_ARGS__);
#define CPU_INST_TEST_CASE_S8(str, ...) \
    _CPU_INST_TEST_CASE(str, s8, __VA_ARGS__);
#define INST_TEST_CASE_S8(str, ...) _INST_TEST_CASE(str, s8, __VA_ARGS__);

// CPU-only suite for f32, bf16, s32 and s8.
#define CPU_INST_TEST_CASE(str, ...) \
    CPU_INST_TEST_CASE_F32(str, __VA_ARGS__) \
    CPU_INST_TEST_CASE_BF16(str, __VA_ARGS__) \
    CPU_INST_TEST_CASE_S32(str, __VA_ARGS__) \
    CPU_INST_TEST_CASE_S8(str, __VA_ARGS__)

// Suite for f16 (GPU only), bf16, f32, s32 and s8.
#define INST_TEST_CASE(str, ...) \
    GPU_INST_TEST_CASE_F16(str, __VA_ARGS__) \
    INST_TEST_CASE_BF16(str, __VA_ARGS__) \
    INST_TEST_CASE_F32(str, __VA_ARGS__) \
    INST_TEST_CASE_S32(str, __VA_ARGS__) \
    INST_TEST_CASE_S8(str, __VA_ARGS__)
555 
// Shapes with a zero-sized minibatch or channel dimension.
INST_TEST_CASE(SimpleZeroDim,
        PARAMS_ALL_ALG(ncdhw, nCdhw8c, 0.1f, 0.f, 0, 2, 4, 4, 4),
        PARAMS_ALL_ALG(ncdhw, nCdhw8c, 0.1f, 0.f, 2, 0, 4, 4, 4),
        PARAMS_ALL_ALG_SDPART(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 0, 4, 2, 2, 2),
        PARAMS_ALL_ALG_SDPART(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 0, 2, 2, 2));

// A 4D nchw case that is expected to fail with dnnl_invalid_arguments. All
// seven fields are given, with expect_to_fail = true.
#define CASE_EF(alg, d0, d1, d2, d3) \
    eltwise_test_params_t { \
        algorithm::eltwise_##alg, EXPAND_FORMATS(nchw), EXPAND_FORMATS(nchw), \
                0.f, 0.f, {d0, d1, d2, d3}, true, dnnl_invalid_arguments \
    }
// Negative minibatch or channel sizes must be rejected.
INST_TEST_CASE(SimpleExpectedFails, CASE_EF(relu, -1, 2, 4, 4),
        CASE_EF(logistic, -1, 2, 4, 4), CASE_EF(relu, 1, -2, 4, 4),
        CASE_EF(logistic, 1, -2, 4, 4));
570 
// 5D shapes mixing plain (ncdhw), blocked (nCdhw8c) and channels-last (ndhwc)
// layouts for the data and diff tensors.
INST_TEST_CASE(Simple_3D,
        PARAMS_ALL_ALG(ncdhw, nCdhw8c, 0.1f, 0.f, 2, 8, 4, 4, 4),
        PARAMS_ALL_ALG(nCdhw8c, ncdhw, 0.1f, 0.f, 2, 16, 4, 4, 4),
        PARAMS_ALL_ALG(ncdhw, ncdhw, 0.1f, 0.f, 2, 16, 8, 8, 8),
        PARAMS_ALL_ALG(nCdhw8c, nCdhw8c, 0.1f, 0.f, 2, 16, 16, 8, 6),
        PARAMS_ALL_ALG(ndhwc, ncdhw, 0.1f, 0.f, 2, 16, 10, 8, 6),
        PARAMS_ALL_ALG(ncdhw, ndhwc, 0.1f, 0.f, 10, 10, 10, 10, 10));

// Blocked 5D layouts whose channel counts (15, 23, 27) are not a multiple of
// the block size, so the buffers carry zero-padded tails.
INST_TEST_CASE(Simple_blocked_3d_padded,
        PARAMS_ALL_ALG(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 15, 2, 2, 2),
        PARAMS_ALL_ALG_SDPART(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 27, 2, 2, 2),
        PARAMS_ALL_ALG(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 23, 2, 2, 2),
        PARAMS_ALL_ALG_SDPART(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 23, 7, 7, 7));

// The same padded-channel scenario for 4D 16c and 8c blocked layouts.
INST_TEST_CASE(Simple_blocked_padded,
        PARAMS_ALL_ALG(nChw16c, nChw16c, 0.1f, 0.2f, 4, 15, 2, 2),
        PARAMS_ALL_ALG_SDPART(nChw16c, nChw16c, 0.1f, 0.2f, 4, 27, 2, 2),
        PARAMS_ALL_ALG(nChw16c, nChw16c, 0.1f, 0.2f, 4, 23, 2, 2),
        PARAMS_ALL_ALG_SDPART(nChw16c, nChw16c, 0.1f, 0.2f, 4, 17, 7, 7),
        PARAMS_ALL_ALG(nChw8c, nChw8c, 0.1f, 0.2f, 4, 15, 2, 2),
        PARAMS_ALL_ALG_SDPART(nChw8c, nChw8c, 0.1f, 0.2f, 4, 27, 2, 2),
        PARAMS_ALL_ALG(nChw8c, nChw8c, 0.1f, 0.2f, 4, 23, 2, 2),
        PARAMS_ALL_ALG_SDPART(nChw8c, nChw8c, 0.1f, 0.2f, 4, 17, 7, 7));

// CPU-only: larger plain 5D shapes with varied alpha/beta.
CPU_INST_TEST_CASE(Simple_NCDHW,
        PARAMS_ALL_ALG(ncdhw, ncdhw, 0.f, 0.f, 2, 32, 28, 28, 28),
        PARAMS_ALL_ALG(ncdhw, ncdhw, 1.f, 0.f, 2, 64, 13, 13, 13),
        PARAMS_ALL_ALG(ncdhw, ncdhw, 1.f, 1.f, 1, 64, 27, 27, 27),
        PARAMS_ALL_ALG(ncdhw, ncdhw, 0.f, 1.f, 1, 128, 11, 11, 11));

// CPU-only: alpha = 0 throughout, so relu has a zero negative slope.
CPU_INST_TEST_CASE(SimpleZeroNegativeSlope,
        PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 8, 4, 4),
        PARAMS_ALL_ALG(nChw16c, nChw16c, 0.f, 0.f, 2, 16, 4, 4),
        PARAMS_ALL_ALG(nChw8c, nChw8c, 0.f, 0.f, 2, 16, 8, 8),
        PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 10, 10, 10, 10),
        PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 256, 64, 8, 16),
        PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 1, 1, 1, 1),
        PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 3, 5, 7, 11));

// Plain 4D shapes, from a single element up to 256x64x8x16.
INST_TEST_CASE(Simple_NCHW, PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 8, 4, 4),
        PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 4, 4),
        PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 8, 8),
        PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 16, 8),
        PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 10, 8),
        PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 10, 10, 10, 10),
        PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 256, 64, 8, 16),
        PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 1, 1, 1, 1),
        PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 3, 5, 7, 11));

INST_TEST_CASE(Simple_NCHW_SDPART,
        PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.1f, 0.f, 256, 64, 8, 16));

// CPU-only: 4D cases where data and diff tensors use different layouts.
CPU_INST_TEST_CASE(Simple, PARAMS_ALL_ALG(nchw, nChw8c, 0.1f, 0.f, 2, 8, 4, 4),
        PARAMS_ALL_ALG(nChw8c, nchw, 0.1f, 0.f, 2, 16, 4, 4),
        PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 8, 8),
        PARAMS_ALL_ALG(nChw8c, nChw8c, 0.1f, 0.f, 2, 16, 16, 8),
        PARAMS_ALL_ALG(nhwc, nchw, 0.1f, 0.f, 2, 16, 10, 8),
        PARAMS_ALL_ALG(nchw, nhwc, 0.1f, 0.f, 10, 10, 10, 10));

CPU_INST_TEST_CASE(Simple_SDPART,
        PARAMS_ALL_ALG_SDPART(nchw, nChw8c, 0.1f, 0.f, 2, 8, 4, 4),
        PARAMS_ALL_ALG_SDPART(nChw8c, nchw, 0.1f, 0.f, 2, 16, 4, 4),
        PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.1f, 0.f, 2, 16, 8, 8),
        PARAMS_ALL_ALG_SDPART(nChw8c, nChw8c, 0.1f, 0.f, 2, 16, 16, 8),
        PARAMS_ALL_ALG_SDPART(nhwc, nchw, 0.1f, 0.f, 2, 16, 10, 8),
        PARAMS_ALL_ALG_SDPART(nchw, nhwc, 0.1f, 0.f, 10, 10, 10, 10));

// Activation shapes named after AlexNet layers (presumably its conv outputs).
INST_TEST_CASE(AlexNet_NCHW,
        PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 96, 55, 55),
        PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 256, 27, 27),
        PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 384, 13, 13),
        PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.f, 0.f, 2, 96, 55, 55),
        PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.f, 0.f, 2, 256, 27, 27),
        PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.f, 0.f, 2, 384, 13, 13));

// 1D tensor in the `x` format.
INST_TEST_CASE(Simple_X, PARAMS_ALL_ALG(x, x, 0.f, 0.f, 55));
647 
648 } // namespace dnnl
649