1 /*******************************************************************************
2 * Copyright 2016-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <memory>
18
19 #include "dnnl_test_common.hpp"
20 #include "gtest/gtest.h"
21
22 #include "oneapi/dnnl/dnnl.hpp"
23
24 namespace dnnl {
25
// Reference forward (leaky) ReLU: positive inputs pass through unchanged,
// everything else is multiplied by `alpha` and converted back to T.
template <typename T, typename A>
inline T relu_fwd(T s, A alpha) {
    if (s > 0) return s;
    return static_cast<T>(s * alpha);
}
// Reference backward (leaky) ReLU: the incoming gradient `dd` is kept for
// positive `s` and scaled by `alpha` otherwise.
template <typename T, typename A>
inline T relu_bwd(T dd, T s, A alpha) {
    if (s > 0) return dd;
    return static_cast<T>(dd * alpha);
}
// Reference forward tanh, evaluated in single precision and converted to T.
template <typename T>
T tanh_fwd(T s) {
    const float x = static_cast<float>(s);
    const float y = ::tanhf(x);
    return static_cast<T>(y);
}
// Reference backward tanh: dd * (1 - tanh(s)) * (1 + tanh(s)), with tanh
// evaluated in single precision.
template <typename T>
T tanh_bwd(T dd, T s) {
    const float t = ::tanhf(static_cast<float>(s));
    const float below_one = 1 - t;
    const float above_one = 1 + t;
    return static_cast<T>(dd * below_one * above_one);
}
43
// Reference forward ELU: identity for positive `s`, alpha * (exp(s) - 1)
// otherwise.
template <typename T, typename A>
T elu_fwd(T s, A alpha) {
    if (s > 0) return s;
    return static_cast<T>(alpha * (::expf(s) - 1));
}
// Reference backward ELU: the gradient is multiplied by 1 for positive `s`
// and by alpha * exp(s) otherwise.
template <typename T, typename A>
T elu_bwd(T dd, T s, A alpha) {
    // `auto` keeps the common type the conditional expression yields.
    const auto slope = s > 0 ? 1 : alpha * ::expf(s);
    return static_cast<T>(dd * slope);
}
52
// Reference forward GELU (tanh approximation):
//   0.5 * s * (1 + tanh(k * s * (1 + c * s^2)))
// with k = 0.797884 (approximately sqrt(2 / pi)) and c = 0.044715.
template <typename T>
T gelu_tanh_fwd(T s) {
    const float k = 0.797884;
    const float c = 0.044715;
    const float inner = k * s * (1 + c * s * s);
    const float th = ::tanhf(inner);
    return static_cast<T>(0.5 * s * (1 + th));
}
60
// Reference backward GELU (tanh approximation). With
//   g  = k * s * (1 + c * s^2)   and   dg = k * (1 + 3 * c * s^2)
// the result is dd * 0.5 * (1 + tanh(g)) * (1 + s * (1 - tanh(g)) * dg).
template <typename T>
T gelu_tanh_bwd(T dd, T s) {
    const float k = 0.797884;
    const float c = 0.044715;
    const float inner = k * s * (1 + c * s * s);
    const float d_inner = k * (1 + 3 * c * s * s);
    const float th = ::tanhf(inner);
    return static_cast<T>(dd * (0.5 * (1 + th) * (1 + s * (1 - th) * d_inner)));
}
70
// Reference forward square: s * s.
template <typename T>
T square_fwd(T s) {
    return static_cast<T>(s * s);
}
75
// Reference backward square: dd * 2 * s.
template <typename T>
T square_bwd(T dd, T s) {
    const auto doubled_grad = dd * 2;
    return static_cast<T>(doubled_grad * s);
}
80
// Reference forward absolute value.
template <typename T>
T abs_fwd(T s) {
    if (s > 0) return s;
    return T(-s);
}
85
// Reference backward absolute value: dd times the sign of `s`
// (+1, -1, or 0 when `s` is neither positive nor negative).
template <typename T>
T abs_bwd(T dd, T s) {
    int sign = 0;
    if (s > 0)
        sign = 1;
    else if (s < 0)
        sign = -1;
    return dd * sign;
}
90
// Reference forward linear transform: alpha * s + beta, converted to T.
template <typename T, typename A>
T linear_fwd(T s, A alpha, A beta) {
    return static_cast<T>(alpha * s + beta);
}
95
// Reference backward linear transform: dd * alpha. `s` and `beta` do not
// affect the result; they are accepted so the signature parallels the other
// *_bwd helpers.
template <typename T, typename A>
T linear_bwd(T dd, T s, A alpha, A beta) {
    static_cast<void>(s);
    static_cast<void>(beta);
    return static_cast<T>(dd * alpha);
}
102
// Reference forward bounded ReLU: clamp `s` from below at 0, then cap the
// result at `alpha`.
template <typename T, typename A>
T bounded_relu_fwd(T s, A alpha) {
    const T non_negative = s > 0 ? s : T(0);
    if (non_negative > alpha) return T(alpha);
    return non_negative;
}
108
// Reference backward bounded ReLU: the gradient passes only when `s` lies
// strictly between 0 and `alpha`; otherwise it is zeroed.
template <typename T, typename A>
T bounded_relu_bwd(T dd, T s, A alpha) {
    const bool in_open_range = 0 < s && s < alpha;
    const int gate = in_open_range ? 1 : 0;
    return dd * gate;
}
113
// Reference forward soft ReLU: log(1 + exp(s)). Inputs at or above
// log(FLT_MAX) are returned unchanged instead of being exponentiated.
template <typename T>
T soft_relu_fwd(T s) {
    const T exp_limit = (T)logf(FLT_MAX);
    if (s < exp_limit) return T(log1pf(::expf(s)));
    return s;
}
118
// Reference backward soft ReLU: dd / (1 + exp(-s)).
template <typename T>
T soft_relu_bwd(T dd, T s) {
    const float denominator = 1 + ::expf(-s);
    return static_cast<T>(dd / denominator);
}
123
// Reference forward logistic (sigmoid): 1 / (1 + exp(-s)), evaluated in
// single precision.
template <typename T>
T logistic_fwd(T s) {
    const float exp_neg = ::expf(static_cast<float>(-s));
    const float sigmoid = 1 / (1 + exp_neg);
    return static_cast<T>(sigmoid);
}
129
130 template <typename T>
logistic_bwd(T dd,T s)131 T logistic_bwd(T dd, T s) {
132 float v = logistic_fwd<float>(s);
133 return (T)(dd * v * (1 - v));
134 }
135
// Reference forward exponent, evaluated in single precision.
template <typename T>
T exp_fwd(T s) {
    const float x = static_cast<float>(s);
    return static_cast<T>(::expf(x));
}
140
// Reference backward exponent: dd * exp(s), with exp evaluated in single
// precision.
template <typename T>
T exp_bwd(T dd, T s) {
    const float exp_s = ::expf(static_cast<float>(s));
    return dd * exp_s;
}
145
// Reference forward swish: s / (1 + exp(-alpha * s)).
template <typename T, typename A>
T swish_fwd(T s, A alpha) {
    const auto denominator = 1.0f + ::expf(-alpha * static_cast<float>(s));
    return static_cast<T>(s / denominator);
}
150
151 template <typename T, typename A>
152 T swish_bwd(T dd, T s, A alpha) {
153 float v = logistic_fwd<float>(alpha * s);
154 return dd * (v + s * alpha * v * (1 - v));
155 }
156
// Reference forward GELU (erf form). With v = s * 0.707106 the result is
// 0.707106 * v * (1 + erf(v)), i.e. roughly 0.5 * s * (1 + erf(s / sqrt(2))).
template <typename T>
T gelu_erf_fwd(T s) {
    const float inv_sqrt2 = 0.707106;
    const float scaled = s * inv_sqrt2;
    const float result = inv_sqrt2 * scaled * (1.f + ::erff(scaled));
    return static_cast<T>(result);
}
163
// Reference backward GELU (erf form). With v = s * 0.707106 the result is
// dd * 0.5 * (1 + erf(v) + v * 1.128379 * exp(-v * v)).
template <typename T>
T gelu_erf_bwd(T dd, T s) {
    const float pdf_scale = 1.128379; // approximately 2 / sqrt(pi)
    const float inv_sqrt2 = 0.707106;
    const float scaled = s * inv_sqrt2;
    const float bracket = 1.f + ::erff(scaled)
            + scaled * pdf_scale * ::expf(-scaled * scaled);
    return static_cast<T>(dd * 0.5f * bracket);
}
172
173 struct eltwise_test_params_t {
174 algorithm alg_kind;
175 memory::format_tag data_format;
176 memory::format_tag diff_format;
177 float alpha, beta;
178 memory::dims dims;
179 bool expect_to_fail;
180 dnnl_status_t expected_status;
181 };
182
// Returns the product of the padded dimensions of `md`, i.e. the number of
// elements including any padding.
memory::dim n_elems(const memory::desc &md) {
    const int ndims = md.data.ndims;
    memory::dim total = 1;
    for (int d = 0; d < ndims; ++d) {
        total *= md.data.padded_dims[d];
    }
    return total;
}
190
191 template <typename data_t>
check_eltwise_fwd(const eltwise_test_params_t & p,const memory::desc & md,const memory & src,const memory & dst)192 void check_eltwise_fwd(const eltwise_test_params_t &p, const memory::desc &md,
193 const memory &src, const memory &dst) {
194 auto src_data = map_memory<data_t>(src);
195 auto dst_data = map_memory<data_t>(dst);
196
197 memory::dim n = n_elems(md);
198 for (memory::dim i = 0; i < n; ++i) {
199 data_t s = src_data[i];
200 data_t ref_d = 0;
201 switch (p.alg_kind) {
202 case algorithm::eltwise_relu: ref_d = relu_fwd(s, p.alpha); break;
203 case algorithm::eltwise_tanh: ref_d = tanh_fwd(s); break;
204 case algorithm::eltwise_elu: ref_d = elu_fwd(s, p.alpha); break;
205 case algorithm::eltwise_square: ref_d = square_fwd(s); break;
206 case algorithm::eltwise_abs: ref_d = abs_fwd(s); break;
207 case algorithm::eltwise_linear:
208 ref_d = linear_fwd(s, p.alpha, p.beta);
209 break;
210 case algorithm::eltwise_bounded_relu:
211 ref_d = bounded_relu_fwd(s, p.alpha);
212 break;
213 case algorithm::eltwise_soft_relu: ref_d = soft_relu_fwd(s); break;
214 case algorithm::eltwise_logistic: ref_d = logistic_fwd(s); break;
215 case algorithm::eltwise_exp: ref_d = exp_fwd(s); break;
216 case algorithm::eltwise_gelu_tanh: ref_d = gelu_tanh_fwd(s); break;
217 case algorithm::eltwise_swish: ref_d = swish_fwd(s, p.alpha); break;
218 case algorithm::eltwise_gelu_erf: ref_d = gelu_erf_fwd(s); break;
219 default: assert(!"unknown alg_kind");
220 }
221 dst_data[i] = ref_d;
222 }
223 }
224
225 template <typename data_t>
compare_eltwise_fwd(const eltwise_test_params_t & p,const memory::desc & md,const memory & dst,const memory & ref_dst)226 void compare_eltwise_fwd(const eltwise_test_params_t &p, const memory::desc &md,
227 const memory &dst, const memory &ref_dst) {
228 data_t eps;
229 if (data_traits<data_t>::data_type == memory::data_type::s8
230 || data_traits<data_t>::data_type == memory::data_type::s32)
231 eps = 0;
232 else
233 eps = static_cast<data_t>(
234 (data_traits<data_t>::data_type == memory::data_type::f16
235 || data_traits<data_t>::data_type
236 == memory::data_type::bf16)
237 ? 5e-2
238 : (p.alg_kind == algorithm::eltwise_elu
239 || p.alg_kind == algorithm::eltwise_gelu_tanh
240 || p.alg_kind == algorithm::eltwise_gelu_erf)
241 ? 2e-5
242 : p.alg_kind == algorithm::eltwise_soft_relu
243 ? 3e-5
244 : 1e-6);
245 compare_data(ref_dst, dst, eps);
246 }
247
248 template <typename data_t>
check_eltwise_bwd(const eltwise_test_params_t & p,const memory::desc & md,const memory & src,const memory & diff_dst,const memory & diff_src)249 void check_eltwise_bwd(const eltwise_test_params_t &p, const memory::desc &md,
250 const memory &src, const memory &diff_dst, const memory &diff_src) {
251 auto src_data = map_memory<data_t>(src);
252 auto diff_dst_data = map_memory<data_t>(diff_dst);
253 auto diff_src_data = map_memory<data_t>(diff_src);
254
255 const memory::desc data_d = src.get_desc();
256 const memory::desc diff_data_d = diff_src.get_desc();
257 const dnnl::impl::memory_desc_wrapper data_mdw(data_d.data);
258 const dnnl::impl::memory_desc_wrapper diff_data_mdw(diff_data_d.data);
259
260 float eps_f = 0;
261 if (p.alg_kind == algorithm::eltwise_soft_relu) {
262 eps_f = 2e-6f;
263 } else if (p.alg_kind == algorithm::eltwise_tanh) {
264 eps_f = (get_test_engine_kind() == engine::kind::gpu) ? 2e-5f : 2e-6f;
265 } else if (p.alg_kind == algorithm::eltwise_gelu_tanh
266 || p.alg_kind == algorithm::eltwise_gelu_erf) {
267 eps_f = 1e-5f;
268 } else {
269 eps_f = 1e-6f;
270 }
271 data_t eps = static_cast<data_t>(eps_f);
272
273 memory::dim n = n_elems(md);
274 for (memory::dim i = 0; i < n; ++i) {
275 data_t ref_s = src_data[data_mdw.off_l(i)];
276 data_t ref_dd = diff_dst_data[diff_data_mdw.off_l(i)];
277 data_t ref_ds = 0;
278 switch (p.alg_kind) {
279 case algorithm::eltwise_relu:
280 ref_ds = relu_bwd(ref_dd, ref_s, p.alpha);
281 break;
282 case algorithm::eltwise_tanh:
283 ref_ds = tanh_bwd(ref_dd, ref_s);
284 break;
285 case algorithm::eltwise_elu:
286 ref_ds = elu_bwd(ref_dd, ref_s, p.alpha);
287 break;
288 case algorithm::eltwise_square:
289 ref_ds = square_bwd(ref_dd, ref_s);
290 break;
291 case algorithm::eltwise_abs: ref_ds = abs_bwd(ref_dd, ref_s); break;
292 case algorithm::eltwise_linear:
293 ref_ds = linear_bwd(ref_dd, ref_s, p.alpha, p.beta);
294 break;
295 case algorithm::eltwise_bounded_relu:
296 ref_ds = bounded_relu_bwd(ref_dd, ref_s, p.alpha);
297 break;
298 case algorithm::eltwise_soft_relu:
299 ref_ds = soft_relu_bwd(ref_dd, ref_s);
300 break;
301 case algorithm::eltwise_logistic:
302 ref_ds = logistic_bwd(ref_dd, ref_s);
303 break;
304 case algorithm::eltwise_exp: ref_ds = exp_bwd(ref_dd, ref_s); break;
305 case algorithm::eltwise_gelu_tanh:
306 ref_ds = gelu_tanh_bwd(ref_dd, ref_s);
307 break;
308 case algorithm::eltwise_swish:
309 ref_ds = swish_bwd(ref_dd, ref_s, p.alpha);
310 break;
311 case algorithm::eltwise_gelu_erf:
312 ref_ds = gelu_erf_bwd(ref_dd, ref_s);
313 break;
314 default: assert(!"unknown alg_kind");
315 }
316
317 data_t tgt = diff_src_data[diff_data_mdw.off_l(i)];
318 const data_t diff = tgt == ref_ds ? 0 : tgt - ref_ds;
319 data_t error = (std::abs(ref_ds) > eps)
320 ? static_cast<data_t>(diff / ref_ds)
321 : diff;
322 if (p.alg_kind == algorithm::eltwise_logistic
323 && (tgt < 1e-3)) { // check for cancellation
324 error = diff;
325 }
326 ASSERT_NEAR(error, 0.0, eps);
327 }
328 }
329
330 template <typename data_t>
331 class eltwise_test_t : public ::testing::TestWithParam<eltwise_test_params_t> {
332 private:
333 memory src;
334 std::shared_ptr<memory::desc> data_desc;
335 eltwise_forward::primitive_desc eltwise_prim_desc;
336 eltwise_test_params_t p;
337 engine eng;
338 stream strm;
339 memory::data_type data_type;
340
341 protected:
SetUp()342 void SetUp() override {
343 data_type = data_traits<data_t>::data_type;
344 SKIP_IF(unsupported_data_type(data_type),
345 "Engine does not support this data type.");
346 p = ::testing::TestWithParam<decltype(p)>::GetParam();
347 SKIP_IF((p.alg_kind != algorithm::eltwise_relu
348 || (p.alg_kind == algorithm::eltwise_relu
349 && p.alpha != 0.0))
350 && (data_type == memory::data_type::s32
351 || data_type == memory::data_type::s8),
352 "oneDNN only supports relu w/ slope=0 for integers");
353 SKIP_IF_CUDA(p.alg_kind != algorithm::eltwise_relu
354 && p.alg_kind != algorithm::eltwise_bounded_relu
355 && p.alg_kind != algorithm::eltwise_tanh
356 && p.alg_kind != algorithm::eltwise_elu
357 && p.alg_kind != algorithm::eltwise_logistic,
358 "Unsupported algorithm type for CUDA");
359 SKIP_IF_CUDA(p.alg_kind == algorithm::eltwise_relu && p.alpha != 0.0,
360 "DNNL only supports relu w/ slope=0 for integers");
361 SKIP_IF_CUDA(!cuda_check_format_tag(p.data_format),
362 "Unsupported format tag");
363 SKIP_IF_CUDA(!cuda_check_format_tag(p.diff_format),
364 "Unsupported format tag");
365 catch_expected_failures(
366 [=]() { Test(); }, p.expect_to_fail, p.expected_status);
367 }
368
cuda_check_format_tag(memory::format_tag tag)369 bool cuda_check_format_tag(memory::format_tag tag) {
370 // Blocking is not supported by cuDNN
371 return (tag != memory::format_tag::aBcd8b
372 && tag != memory::format_tag::aBcd16b
373 && tag != memory::format_tag::aBcde8b
374 && tag != memory::format_tag::aBcde16b);
375 }
376
Test()377 void Test() {
378 p = ::testing::TestWithParam<eltwise_test_params_t>::GetParam();
379
380 eng = get_test_engine();
381 strm = make_stream(eng);
382
383 Forward();
384 if (data_type == memory::data_type::f32
385 || data_type == memory::data_type::bf16) {
386 Backward();
387 }
388 }
389
Forward()390 void Forward() {
391 data_desc = std::make_shared<memory::desc>(
392 p.dims, data_type, p.data_format);
393 src = test::make_memory(*data_desc, eng);
394 auto dst = test::make_memory(*data_desc, eng);
395 auto ref_dst = test::make_memory(*data_desc, eng);
396
397 data_t data_median = data_t(0);
398 data_t data_deviation = (p.alg_kind == algorithm::eltwise_elu
399 || p.alg_kind == algorithm::eltwise_exp)
400 || (p.alg_kind == algorithm::eltwise_swish)
401 ? data_t(1.0)
402 : p.alg_kind == algorithm::eltwise_square ? data_t(6.0)
403 : data_t(100.0);
404 fill_data<data_t>(
405 n_elems(*data_desc), src, data_median, data_deviation);
406 check_zero_tail<data_t>(1, src);
407
408 auto eltwise_desc = eltwise_forward::desc(prop_kind::forward_training,
409 p.alg_kind, *data_desc, p.alpha, p.beta);
410 eltwise_prim_desc = eltwise_forward::primitive_desc(eltwise_desc, eng);
411 eltwise_prim_desc = eltwise_forward::primitive_desc(
412 eltwise_prim_desc.get()); // test construction from a C pd
413
414 ASSERT_TRUE(eltwise_prim_desc.query_md(query::exec_arg_md, DNNL_ARG_SRC)
415 == eltwise_prim_desc.src_desc());
416 ASSERT_TRUE(eltwise_prim_desc.query_md(query::exec_arg_md, DNNL_ARG_DST)
417 == eltwise_prim_desc.dst_desc());
418
419 eltwise_forward(eltwise_prim_desc)
420 .execute(strm, {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}});
421 strm.wait();
422
423 check_zero_tail<data_t>(0, dst);
424 check_eltwise_fwd<data_t>(p, *data_desc, src, ref_dst);
425 check_zero_tail<data_t>(1, ref_dst);
426 compare_eltwise_fwd<data_t>(p, *data_desc, dst, ref_dst);
427 }
428
Backward()429 void Backward() {
430 SKIP_IF_CUDA(p.alg_kind != algorithm::eltwise_relu
431 && p.alg_kind != algorithm::eltwise_bounded_relu,
432 "Unsupported algorithm");
433 SKIP_IF_CUDA(p.diff_format != p.data_format,
434 "CUDA does not support different data formats for data and "
435 "diff vectors");
436 memory::desc diff_data_desc(p.dims, data_type, p.diff_format);
437 auto diff_src = test::make_memory(diff_data_desc, eng);
438 auto diff_dst = test::make_memory(diff_data_desc, eng);
439
440 data_t data_median = data_t(0);
441 data_t data_deviation = p.alg_kind == algorithm::eltwise_elu
442 ? data_t(1.0)
443 : p.alg_kind == algorithm::eltwise_square ? data_t(6.0)
444 : data_t(100.0);
445 fill_data<data_t>(
446 n_elems(diff_data_desc), diff_dst, data_median, data_deviation);
447 check_zero_tail<data_t>(1, diff_dst);
448
449 auto eltwise_bwd_desc = eltwise_backward::desc(
450 p.alg_kind, diff_data_desc, *data_desc, p.alpha, p.beta);
451 auto eltwise_bwd_prim_desc = eltwise_backward::primitive_desc(
452 eltwise_bwd_desc, eng, eltwise_prim_desc);
453 eltwise_bwd_prim_desc
454 = eltwise_backward::primitive_desc(eltwise_bwd_prim_desc.get());
455
456 ASSERT_TRUE(
457 eltwise_bwd_prim_desc.query_md(query::exec_arg_md, DNNL_ARG_SRC)
458 == eltwise_bwd_prim_desc.src_desc());
459 ASSERT_TRUE(
460 eltwise_bwd_prim_desc.query_md(query::exec_arg_md, DNNL_ARG_DST)
461 == eltwise_bwd_prim_desc.dst_desc());
462
463 eltwise_backward(eltwise_bwd_prim_desc)
464 .execute(strm,
465 {{DNNL_ARG_SRC, src}, {DNNL_ARG_DIFF_DST, diff_dst},
466 {DNNL_ARG_DIFF_SRC, diff_src}});
467 strm.wait();
468
469 check_zero_tail<data_t>(0, diff_src);
470 check_eltwise_bwd<data_t>(p, *data_desc, src, diff_dst, diff_src);
471 }
472 };
473
474 using eltwise_test_f16 = eltwise_test_t<float16_t>;
475 using eltwise_test_bf16 = eltwise_test_t<bfloat16_t>;
476 using eltwise_test_f32 = eltwise_test_t<float>;
477 using eltwise_test_s32 = eltwise_test_t<int>;
478 using eltwise_test_s8 = eltwise_test_t<int8_t>;
479
TEST_P(eltwise_test_f16,TestsEltwise)480 TEST_P(eltwise_test_f16, TestsEltwise) {}
481
TEST_P(eltwise_test_bf16,TestsEltwise)482 TEST_P(eltwise_test_bf16, TestsEltwise) {}
483
TEST_P(eltwise_test_f32,TestsEltwise)484 TEST_P(eltwise_test_f32, TestsEltwise) {}
485
TEST_P(eltwise_test_s32,TestsEltwise)486 TEST_P(eltwise_test_s32, TestsEltwise) {}
487
TEST_P(eltwise_test_s8,TestsEltwise)488 TEST_P(eltwise_test_s8, TestsEltwise) {}
489
// Expands to its argument unchanged.
#define EXPAND(args) args

// Qualifies a bare format-tag name with memory::format_tag.
#define EXPAND_FORMATS(data) memory::format_tag::data
// Wraps the listed dimensions into a braced initializer.
#define EXPAND_DIMS(...) \
    { __VA_ARGS__ }

// Builds one eltwise_test_params_t. The trailing variadic arguments are the
// tensor dimensions; expect_to_fail / expected_status are left to be
// value-initialized.
#define PARAMS(alg, data, diff_data, alpha, beta, ...) \
    eltwise_test_params_t { \
        algorithm::alg, EXPAND_FORMATS(data), EXPAND_FORMATS(diff_data), \
                alpha, beta, EXPAND_DIMS(__VA_ARGS__) \
    }

// Same parameters repeated for the first group of algorithms.
#define PARAMS_ALL_ALG(...) \
    EXPAND(PARAMS(eltwise_gelu_tanh, __VA_ARGS__)), \
            EXPAND(PARAMS(eltwise_relu, __VA_ARGS__)), \
            EXPAND(PARAMS(eltwise_tanh, __VA_ARGS__)), \
            EXPAND(PARAMS(eltwise_elu, __VA_ARGS__)), \
            EXPAND(PARAMS(eltwise_square, __VA_ARGS__)), \
            EXPAND(PARAMS(eltwise_abs, __VA_ARGS__)), \
            EXPAND(PARAMS(eltwise_exp, __VA_ARGS__)), \
            EXPAND(PARAMS(eltwise_swish, __VA_ARGS__)), \
            EXPAND(PARAMS(eltwise_gelu_erf, __VA_ARGS__))

// Same parameters repeated for the second group of algorithms
// (linear, soft_relu, bounded_relu, logistic).
#define PARAMS_ALL_ALG_SDPART(...) \
    EXPAND(PARAMS(eltwise_linear, __VA_ARGS__)), \
            EXPAND(PARAMS(eltwise_soft_relu, __VA_ARGS__)), \
            EXPAND(PARAMS(eltwise_bounded_relu, __VA_ARGS__)), \
            EXPAND(PARAMS(eltwise_logistic, __VA_ARGS__))

// Internal helpers that instantiate suite `str` for the fixture
// eltwise_test_<data_t>. They carry an _IMPL suffix rather than a leading
// underscore: identifiers that start with an underscore followed by an
// uppercase letter are reserved for the implementation.
#define CPU_INST_TEST_CASE_IMPL(str, data_t, ...) \
    CPU_INSTANTIATE_TEST_SUITE_P(str##_##data_t, eltwise_test_##data_t, \
            ::testing::Values(__VA_ARGS__))

#define INST_TEST_CASE_IMPL(str, data_t, ...) \
    INSTANTIATE_TEST_SUITE_P_(str##_##data_t, eltwise_test_##data_t, \
            ::testing::Values(__VA_ARGS__))

// Per-data-type instantiation helpers. Each one ends with a semicolon so
// that several of them can be chained inside a single macro.
#define CPU_INST_TEST_CASE_BF16(str, ...) \
    CPU_INST_TEST_CASE_IMPL(str, bf16, __VA_ARGS__);
#define INST_TEST_CASE_BF16(str, ...) \
    INST_TEST_CASE_IMPL(str, bf16, __VA_ARGS__);
#define GPU_INST_TEST_CASE_F16(str, ...) \
    GPU_INSTANTIATE_TEST_SUITE_P_(TEST_CONCAT(str, _f16), eltwise_test_f16, \
            ::testing::Values(__VA_ARGS__));
#define CPU_INST_TEST_CASE_F32(str, ...) \
    CPU_INST_TEST_CASE_IMPL(str, f32, __VA_ARGS__);
#define INST_TEST_CASE_F32(str, ...) INST_TEST_CASE_IMPL(str, f32, __VA_ARGS__);
#define CPU_INST_TEST_CASE_S32(str, ...) \
    CPU_INST_TEST_CASE_IMPL(str, s32, __VA_ARGS__);
#define INST_TEST_CASE_S32(str, ...) INST_TEST_CASE_IMPL(str, s32, __VA_ARGS__);
#define CPU_INST_TEST_CASE_S8(str, ...) \
    CPU_INST_TEST_CASE_IMPL(str, s8, __VA_ARGS__);
#define INST_TEST_CASE_S8(str, ...) INST_TEST_CASE_IMPL(str, s8, __VA_ARGS__);

// Instantiates suite `str` through the CPU-specific macro for
// f32, bf16, s32 and s8.
#define CPU_INST_TEST_CASE(str, ...) \
    CPU_INST_TEST_CASE_F32(str, __VA_ARGS__) \
    CPU_INST_TEST_CASE_BF16(str, __VA_ARGS__) \
    CPU_INST_TEST_CASE_S32(str, __VA_ARGS__) \
    CPU_INST_TEST_CASE_S8(str, __VA_ARGS__)

// Instantiates suite `str` for f16 (through the GPU-specific macro) and for
// bf16, f32, s32 and s8.
#define INST_TEST_CASE(str, ...) \
    GPU_INST_TEST_CASE_F16(str, __VA_ARGS__) \
    INST_TEST_CASE_BF16(str, __VA_ARGS__) \
    INST_TEST_CASE_F32(str, __VA_ARGS__) \
    INST_TEST_CASE_S32(str, __VA_ARGS__) \
    INST_TEST_CASE_S8(str, __VA_ARGS__)
555
556 INST_TEST_CASE(SimpleZeroDim,
557 PARAMS_ALL_ALG(ncdhw, nCdhw8c, 0.1f, 0.f, 0, 2, 4, 4, 4),
558 PARAMS_ALL_ALG(ncdhw, nCdhw8c, 0.1f, 0.f, 2, 0, 4, 4, 4),
559 PARAMS_ALL_ALG_SDPART(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 0, 4, 2, 2, 2),
560 PARAMS_ALL_ALG_SDPART(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 0, 2, 2, 2));
561
562 #define CASE_EF(alg, d0, d1, d2, d3) \
563 eltwise_test_params_t { \
564 algorithm::eltwise_##alg, EXPAND_FORMATS(nchw), EXPAND_FORMATS(nchw), \
565 0.f, 0.f, {d0, d1, d2, d3}, true, dnnl_invalid_arguments \
566 }
567 INST_TEST_CASE(SimpleExpectedFails, CASE_EF(relu, -1, 2, 4, 4),
568 CASE_EF(logistic, -1, 2, 4, 4), CASE_EF(relu, 1, -2, 4, 4),
569 CASE_EF(logistic, 1, -2, 4, 4));
570
571 INST_TEST_CASE(Simple_3D,
572 PARAMS_ALL_ALG(ncdhw, nCdhw8c, 0.1f, 0.f, 2, 8, 4, 4, 4),
573 PARAMS_ALL_ALG(nCdhw8c, ncdhw, 0.1f, 0.f, 2, 16, 4, 4, 4),
574 PARAMS_ALL_ALG(ncdhw, ncdhw, 0.1f, 0.f, 2, 16, 8, 8, 8),
575 PARAMS_ALL_ALG(nCdhw8c, nCdhw8c, 0.1f, 0.f, 2, 16, 16, 8, 6),
576 PARAMS_ALL_ALG(ndhwc, ncdhw, 0.1f, 0.f, 2, 16, 10, 8, 6),
577 PARAMS_ALL_ALG(ncdhw, ndhwc, 0.1f, 0.f, 10, 10, 10, 10, 10));
578
579 INST_TEST_CASE(Simple_blocked_3d_padded,
580 PARAMS_ALL_ALG(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 15, 2, 2, 2),
581 PARAMS_ALL_ALG_SDPART(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 27, 2, 2, 2),
582 PARAMS_ALL_ALG(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 23, 2, 2, 2),
583 PARAMS_ALL_ALG_SDPART(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 23, 7, 7, 7));
584
585 INST_TEST_CASE(Simple_blocked_padded,
586 PARAMS_ALL_ALG(nChw16c, nChw16c, 0.1f, 0.2f, 4, 15, 2, 2),
587 PARAMS_ALL_ALG_SDPART(nChw16c, nChw16c, 0.1f, 0.2f, 4, 27, 2, 2),
588 PARAMS_ALL_ALG(nChw16c, nChw16c, 0.1f, 0.2f, 4, 23, 2, 2),
589 PARAMS_ALL_ALG_SDPART(nChw16c, nChw16c, 0.1f, 0.2f, 4, 17, 7, 7),
590 PARAMS_ALL_ALG(nChw8c, nChw8c, 0.1f, 0.2f, 4, 15, 2, 2),
591 PARAMS_ALL_ALG_SDPART(nChw8c, nChw8c, 0.1f, 0.2f, 4, 27, 2, 2),
592 PARAMS_ALL_ALG(nChw8c, nChw8c, 0.1f, 0.2f, 4, 23, 2, 2),
593 PARAMS_ALL_ALG_SDPART(nChw8c, nChw8c, 0.1f, 0.2f, 4, 17, 7, 7));
594
595 CPU_INST_TEST_CASE(Simple_NCDHW,
596 PARAMS_ALL_ALG(ncdhw, ncdhw, 0.f, 0.f, 2, 32, 28, 28, 28),
597 PARAMS_ALL_ALG(ncdhw, ncdhw, 1.f, 0.f, 2, 64, 13, 13, 13),
598 PARAMS_ALL_ALG(ncdhw, ncdhw, 1.f, 1.f, 1, 64, 27, 27, 27),
599 PARAMS_ALL_ALG(ncdhw, ncdhw, 0.f, 1.f, 1, 128, 11, 11, 11));
600
601 CPU_INST_TEST_CASE(SimpleZeroNegativeSlope,
602 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 8, 4, 4),
603 PARAMS_ALL_ALG(nChw16c, nChw16c, 0.f, 0.f, 2, 16, 4, 4),
604 PARAMS_ALL_ALG(nChw8c, nChw8c, 0.f, 0.f, 2, 16, 8, 8),
605 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 10, 10, 10, 10),
606 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 256, 64, 8, 16),
607 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 1, 1, 1, 1),
608 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 3, 5, 7, 11));
609
610 INST_TEST_CASE(Simple_NCHW, PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 8, 4, 4),
611 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 4, 4),
612 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 8, 8),
613 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 16, 8),
614 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 10, 8),
615 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 10, 10, 10, 10),
616 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 256, 64, 8, 16),
617 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 1, 1, 1, 1),
618 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 3, 5, 7, 11));
619
620 INST_TEST_CASE(Simple_NCHW_SDPART,
621 PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.1f, 0.f, 256, 64, 8, 16));
622
623 CPU_INST_TEST_CASE(Simple, PARAMS_ALL_ALG(nchw, nChw8c, 0.1f, 0.f, 2, 8, 4, 4),
624 PARAMS_ALL_ALG(nChw8c, nchw, 0.1f, 0.f, 2, 16, 4, 4),
625 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 8, 8),
626 PARAMS_ALL_ALG(nChw8c, nChw8c, 0.1f, 0.f, 2, 16, 16, 8),
627 PARAMS_ALL_ALG(nhwc, nchw, 0.1f, 0.f, 2, 16, 10, 8),
628 PARAMS_ALL_ALG(nchw, nhwc, 0.1f, 0.f, 10, 10, 10, 10));
629
630 CPU_INST_TEST_CASE(Simple_SDPART,
631 PARAMS_ALL_ALG_SDPART(nchw, nChw8c, 0.1f, 0.f, 2, 8, 4, 4),
632 PARAMS_ALL_ALG_SDPART(nChw8c, nchw, 0.1f, 0.f, 2, 16, 4, 4),
633 PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.1f, 0.f, 2, 16, 8, 8),
634 PARAMS_ALL_ALG_SDPART(nChw8c, nChw8c, 0.1f, 0.f, 2, 16, 16, 8),
635 PARAMS_ALL_ALG_SDPART(nhwc, nchw, 0.1f, 0.f, 2, 16, 10, 8),
636 PARAMS_ALL_ALG_SDPART(nchw, nhwc, 0.1f, 0.f, 10, 10, 10, 10));
637
638 INST_TEST_CASE(AlexNet_NCHW,
639 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 96, 55, 55),
640 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 256, 27, 27),
641 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 384, 13, 13),
642 PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.f, 0.f, 2, 96, 55, 55),
643 PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.f, 0.f, 2, 256, 27, 27),
644 PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.f, 0.f, 2, 384, 13, 13));
645
646 INST_TEST_CASE(Simple_X, PARAMS_ALL_ALG(x, x, 0.f, 0.f, 55));
647
648 } // namespace dnnl
649