1 /*******************************************************************************
2 * Copyright 2016-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
#include <cstring>

#include "dnnl_test_common.hpp"
#include "gtest/gtest.h"

#include "oneapi/dnnl/dnnl.hpp"
21 
22 namespace dnnl {
23 
// Short alias for memory format tags, used throughout the test cases below.
using tag = memory::format_tag;
25 
26 /* iface tests */
27 
28 class iface_sum_test_t : public ::testing::Test {
29 protected:
30     engine eng;
31     stream strm;
32 
SetUp()33     void SetUp() override {
34         eng = get_test_engine();
35         strm = make_stream(eng);
36     }
37 };
38 
// Checks the data type of the sum destination: it must be the explicitly
// requested one, or the source data type when no dst descriptor is passed
// (that case is encoded in the loop below as dt::undef).
TEST_F(iface_sum_test_t, SumTestDstDataTypeCompliance) {
    using dt = memory::data_type;

    const dt src_dt = dt::s8;
    const memory::dims shape = {10, 10, 10, 10};
    const memory::desc src_md(shape, src_dt, tag::abcd);

    // Identical for every iteration, so built once.
    const std::vector<float> scales = {2.f, 2.f};
    const std::vector<memory::desc> srcs_md = {src_md, src_md};

    for_(tag dst_tag : {tag::any, tag::abcd, tag::acdb})
    for (dt dst_dt : {dt::undef, dt::s8, dt::s32, dt::f32}) {
        SKIP_FOR_LOOP_CUDA(dst_dt == dt::s32, "Unsupported data_type");

        const bool dst_omitted = (dst_dt == dt::undef);
        sum::primitive_desc sum_pd = dst_omitted
                ? sum::primitive_desc(scales, srcs_md, eng)
                : sum::primitive_desc(memory::desc(shape, dst_dt, dst_tag),
                        scales, srcs_md, eng);

        const dt expect_dst_dt = dst_omitted ? src_dt : dst_dt;
        ASSERT_EQ(sum_pd.dst_desc().data.data_type, expect_dst_dt);
    }
}
63 
64 /* correctness tests */
65 
// Parameters of one sum correctness test case. This is an aggregate: cases are
// brace-initialized positionally, and trailing fields that are not given are
// value-initialized (false / zero status).
struct sum_test_params {
    // Memory format of each source; its size is the number of sources.
    std::vector<tag> srcs_format;
    // Destination format; only used to build a dst descriptor when
    // is_output_omitted is false.
    tag dst_format;
    // Logical dims shared by all sources and the destination (the reference
    // check indexes exactly four dims).
    memory::dims dims;
    // Per-source scaling factors passed to the sum primitive.
    std::vector<float> scale;
    // When true, the primitive descriptor is created without a dst descriptor.
    bool is_output_omitted;
    // Whether the case is expected to fail, and with which status; both are
    // forwarded to catch_expected_failures().
    bool expect_to_fail;
    dnnl_status_t expected_status;
};
75 
76 template <typename src_data_t, typename acc_t, typename dst_data_t = src_data_t>
77 class sum_test_t : public ::testing::TestWithParam<sum_test_params> {
78 private:
79     memory::data_type src_data_type;
80     memory::data_type dst_data_type;
81 
check_data(const std::vector<memory> & srcs,const std::vector<float> & scale,const memory & dst)82     void check_data(const std::vector<memory> &srcs,
83             const std::vector<float> &scale, const memory &dst) {
84         auto dst_data = map_memory<const dst_data_t>(dst);
85         const auto &dst_d = dst.get_desc();
86         const auto dst_dims = dst_d.data.dims;
87         const dnnl::impl::memory_desc_wrapper dst_mdw(dst_d.data);
88 
89         std::vector<mapped_ptr_t<const src_data_t>> mapped_srcs;
90         mapped_srcs.reserve(srcs.size());
91         for (auto &src : srcs)
92             mapped_srcs.emplace_back(map_memory<const src_data_t>(src));
93 
94         dnnl::impl::parallel_nd(dst_dims[0], dst_dims[1], dst_dims[2],
95                 dst_dims[3],
96                 [&](memory::dim n, memory::dim c, memory::dim h,
97                         memory::dim w) {
98                     if (is_current_test_failed()) return;
99 
100                     acc_t src_sum = 0.0;
101                     for (size_t num = 0; num < srcs.size(); num++) {
102                         auto &src_data = mapped_srcs[num];
103                         const auto &src_d = srcs[num].get_desc();
104                         const auto src_dims = src_d.data.dims;
105                         const dnnl::impl::memory_desc_wrapper src_mdw(
106                                 src_d.data);
107 
108                         auto src_idx = w + src_dims[3] * h
109                                 + src_dims[2] * src_dims[3] * c
110                                 + src_dims[1] * src_dims[2] * src_dims[3] * n;
111                         if (num == 0) {
112                             src_sum = acc_t(scale[num])
113                                     * src_data[src_mdw.off_l(src_idx, false)];
114                         } else {
115                             src_sum += acc_t(scale[num])
116                                     * src_data[src_mdw.off_l(src_idx, false)];
117                         }
118 
119                         src_sum = (std::max)(
120                                 (std::min)(src_sum,
121                                         (std::numeric_limits<acc_t>::max)()),
122                                 std::numeric_limits<acc_t>::lowest());
123                     }
124 
125                     auto dst_idx = w + dst_dims[3] * h
126                             + dst_dims[2] * dst_dims[3] * c
127                             + dst_dims[1] * dst_dims[2] * dst_dims[3] * n;
128 
129                     acc_t dst_val = dst_data[dst_mdw.off_l(dst_idx, false)];
130                     ASSERT_EQ(src_sum, dst_val);
131                 });
132     }
133 
134 protected:
cuda_supported_format_tag(memory::format_tag tag)135     bool cuda_supported_format_tag(memory::format_tag tag) {
136         return impl::utils::one_of(tag, dnnl_a, dnnl_ab, dnnl_abc, dnnl_abcd,
137                 dnnl_abcde, dnnl_abcdef, dnnl_abdec, dnnl_acb, dnnl_acbde,
138                 dnnl_acbdef, dnnl_acdb, dnnl_acdeb, dnnl_ba, dnnl_bac,
139                 dnnl_bacd, dnnl_bca, dnnl_bcda, dnnl_bcdea, dnnl_cba, dnnl_cdba,
140                 dnnl_cdeba, dnnl_decab, dnnl_defcab, dnnl_aBc4b, dnnl_aBcd4b,
141                 dnnl_aBcde4b);
142     }
SetUp()143     void SetUp() override {
144         src_data_type = data_traits<src_data_t>::data_type;
145         dst_data_type = data_traits<dst_data_t>::data_type;
146         sum_test_params p
147                 = ::testing::TestWithParam<sum_test_params>::GetParam();
148         SKIP_IF(get_test_engine_kind() == engine::kind::gpu
149                         && src_data_type == memory::data_type::bf16,
150                 "GPU does not support bfloat16 data type.");
151         SKIP_IF(unsupported_data_type(src_data_type),
152                 "Engine does not support this data type.");
153         SKIP_IF(unsupported_data_type(dst_data_type),
154                 "Engine does not support this data type.");
155 
156         SKIP_IF_CUDA(!cuda_supported_format_tag(p.dst_format),
157                 "Unsupported format tag");
158         for (size_t i = 0; i < p.srcs_format.size(); i++) {
159             SKIP_IF_CUDA(!cuda_supported_format_tag(p.srcs_format[i]),
160                     "Unsupported format tag");
161         }
162         catch_expected_failures(
163                 [=]() { Test(); }, p.expect_to_fail, p.expected_status);
164     }
165 
Test()166     void Test() {
167         sum_test_params p
168                 = ::testing::TestWithParam<sum_test_params>::GetParam();
169 
170         const auto num_srcs = p.srcs_format.size();
171 
172         auto eng = get_test_engine();
173         auto strm = make_stream(eng);
174 
175         std::vector<memory::desc> srcs_md;
176         std::vector<memory> srcs;
177 
178         for (size_t i = 0; i < num_srcs; i++) {
179             auto desc = memory::desc(p.dims, src_data_type, p.srcs_format[i]);
180             auto src_memory = test::make_memory(desc, eng);
181             const size_t sz
182                     = src_memory.get_desc().get_size() / sizeof(src_data_t);
183             fill_data<src_data_t>(sz, src_memory);
184 
185             // Keep few mantissa digits for fp types to avoid round-off errors
186             // With proper scalars the computations give exact results
187             if (!std::is_integral<src_data_t>::value) {
188                 using uint_type = typename data_traits<src_data_t>::uint_type;
189                 int mant_digits
190                         = dnnl::impl::nstl::numeric_limits<src_data_t>::digits;
191                 int want_mant_digits = 3;
192                 auto src_ptr = map_memory<src_data_t>(src_memory);
193                 for (size_t i = 0; i < sz; i++) {
194                     uint_type mask = (uint_type)-1
195                             << (mant_digits - want_mant_digits);
196                     *((uint_type *)&src_ptr[i]) &= mask;
197                 }
198             }
199             srcs_md.push_back(desc);
200             srcs.push_back(src_memory);
201         }
202 
203         memory dst;
204         sum::primitive_desc sum_pd;
205 
206         if (p.is_output_omitted) {
207             ASSERT_NO_THROW(
208                     sum_pd = sum::primitive_desc(p.scale, srcs_md, eng));
209         } else {
210             auto dst_desc = memory::desc(p.dims, dst_data_type, p.dst_format);
211             sum_pd = sum::primitive_desc(dst_desc, p.scale, srcs_md, eng);
212 
213             ASSERT_EQ(sum_pd.dst_desc().data.ndims, dst_desc.data.ndims);
214         }
215         dst = test::make_memory(sum_pd.dst_desc(), eng);
216         // test construction from a C pd
217         sum_pd = sum::primitive_desc(sum_pd.get());
218 
219         ASSERT_TRUE(sum_pd.query_md(query::exec_arg_md, DNNL_ARG_DST)
220                 == sum_pd.dst_desc());
221         for (int i = 0; i < (int)srcs.size(); i++)
222             ASSERT_TRUE(sum_pd.query_md(
223                                 query::exec_arg_md, DNNL_ARG_MULTIPLE_SRC + i)
224                     == sum_pd.src_desc(i));
225 
226         {
227             auto dst_data = map_memory<dst_data_t>(dst);
228             const size_t sz = dst.get_desc().get_size() / sizeof(dst_data_t);
229             // overwriting dst to prevent false positives for test cases.
230             dnnl::impl::parallel_nd(
231                     (ptrdiff_t)sz, [&](ptrdiff_t i) { dst_data[i] = -32; });
232         }
233         sum c(sum_pd);
234         std::unordered_map<int, memory> args = {{DNNL_ARG_DST, dst}};
235         for (int i = 0; i < (int)num_srcs; i++) {
236             args.insert({DNNL_ARG_MULTIPLE_SRC + i, srcs[i]});
237         }
238         c.execute(strm, args);
239         strm.wait();
240 
241         check_data(srcs, p.scale, dst);
242     }
243 };
244 
// Generic cases instantiated for every data type. `omit_output` is stored in
// each case's is_output_omitted field. Fields after it are not given, so
// expect_to_fail is false and expected_status is zero unless stated otherwise.
static auto simple_test_cases = [](bool omit_output) {
    return ::testing::Values(
            // Degenerate shapes: zero-sized dims, and a negative dim that is
            // expected to fail with dnnl_invalid_arguments.
            sum_test_params {{tag::nchw, tag::nChw8c}, tag::nchw, {0, 7, 4, 4},
                    {1.0f, 1.0f}, omit_output},
            sum_test_params {{tag::nchw, tag::nChw8c}, tag::nchw, {1, 0, 4, 4},
                    {1.0f, 1.0f}, omit_output},
            sum_test_params {{tag::nchw, tag::nChw8c}, tag::nchw, {1, 8, 0, 4},
                    {1.0f, 1.0f}, omit_output},
            sum_test_params {{tag::nchw, tag::nChw8c}, tag::nchw, {-1, 8, 4, 4},
                    {1.0f, 1.0f}, omit_output, true, dnnl_invalid_arguments},

            // Plain and blocked formats (mixed between sources and dst), with
            // unit and non-unit scales.
            sum_test_params {{tag::nchw, tag::nChw8c}, tag::nchw,
                    {1, 1024, 38, 50}, {1.0f, 1.0f}, omit_output},
            sum_test_params {{tag::nchw, tag::nchw}, tag::nchw, {2, 8, 2, 2},
                    {1.0f, 1.0f}, omit_output},
            sum_test_params {{tag::nChw8c, tag::nChw8c}, tag::nChw8c,
                    {2, 16, 3, 4}, {1.0f, 1.0f}, omit_output},
            sum_test_params {{tag::nchw, tag::nchw}, tag::nChw8c, {2, 16, 2, 2},
                    {1.0f, 1.0f}, omit_output},
            sum_test_params {{tag::nChw8c, tag::nChw8c}, tag::nchw,
                    {2, 16, 3, 4}, {1.0f, 1.0f}, omit_output},
            sum_test_params {{tag::nchw, tag::nchw}, tag::nchw, {2, 8, 2, 2},
                    {2.0f, 3.0f}, omit_output},
            sum_test_params {{tag::nChw8c, tag::nChw8c}, tag::nChw8c,
                    {2, 16, 3, 4}, {2.0f, 3.0f}, omit_output},
            sum_test_params {{tag::nchw, tag::nchw}, tag::nChw8c, {2, 16, 2, 2},
                    {2.0f, 3.0f}, omit_output},
            sum_test_params {{tag::nChw8c, tag::nChw8c}, tag::nchw,
                    {2, 16, 3, 4}, {2.0f, 3.0f}, omit_output},
            sum_test_params {{tag::nchw, tag::nChw8c}, tag::nchw, {5, 8, 3, 3},
                    {2.0f, 3.0f}, omit_output},
            sum_test_params {{tag::nchw, tag::nChw8c}, tag::nchw,
                    {32, 32, 13, 14}, {2.0f, 3.0f}, omit_output},
            sum_test_params {{tag::nChw16c, tag::nChw8c}, tag::nChw16c,
                    {2, 16, 3, 3}, {2.0f, 3.0f}, omit_output});
};
281 
__anonec0f339d0502(bool omit_output) 282 static auto simple_test_cases_bf16 = [](bool omit_output) {
283     return ::testing::Values(
284             sum_test_params {{tag::nChw16c, tag::nChw16c}, tag::nChw16c,
285                     {1, 16, 1, 1}, {2.0f, 3.0f}, omit_output},
286             sum_test_params {{tag::nchw, tag::nchw}, tag::nchw, {1, 16, 1, 1},
287                     {2.0f, 3.0f}, omit_output},
288             sum_test_params {{tag::nchw, tag::nchw}, tag::nchw, {2, 16, 13, 7},
289                     {2.0f, 3.0f}, omit_output},
290             sum_test_params {{tag::nchw, tag::nchw, tag::nchw, tag::nchw},
291                     tag::nchw, {2, 16, 13, 7}, {2.0f, 3.0f, 4.0f, 5.0f},
292                     omit_output},
293             sum_test_params {{tag::nchw, tag::nchw, tag::nchw}, tag::nchw,
294                     {2, 16, 13, 7}, {2.0f, 3.0f, 4.0f}, omit_output},
295             sum_test_params {
296                     {tag::nchw, tag::nchw, tag::nchw, tag::nchw, tag::nchw},
297                     tag::nchw, {2, 16, 13, 7}, {2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
298                     omit_output},
299             sum_test_params {{tag::nchw, tag::nchw, tag::nchw}, tag::nchw,
300                     {2, 37, 13, 7}, {2.0f, 3.0f, 4.0f}, omit_output},
301             sum_test_params {{tag::nchw, tag::nchw, tag::nchw}, tag::nchw,
302                     {2, 16, 13, 7}, {2.0f, 3.0f, 4.0f}, omit_output},
303             sum_test_params {{tag::nChw16c, tag::nChw16c}, tag::nChw16c,
304                     {2, 16, 13, 7}, {2.0f, 3.0f}, omit_output},
305             sum_test_params {{tag::nChw16c, tag::nChw16c, tag::nChw16c},
306                     tag::nChw16c, {2, 16, 13, 7}, {2.0f, 3.0f, 4.0f},
307                     omit_output},
308             sum_test_params {{tag::nChw16c, tag::nChw16c, tag::nChw16c,
309                                      tag::nChw16c, tag::nChw16c},
310                     tag::nChw16c, {2, 16, 13, 7},
311                     {2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, omit_output},
312             sum_test_params {{tag::nChw16c, tag::nChw16c}, tag::nChw16c,
313                     {2, 128, 23, 15}, {2.5f, 0.125f}, omit_output});
314 };
315 
// Expected-failure cases: two sources but only one scale, which is expected to
// be rejected with dnnl_invalid_arguments. Both cases pass an explicit dst
// descriptor (is_output_omitted == false).
static auto special_test_cases = []() {
    return ::testing::Values(
            sum_test_params {{tag::nchw, tag::nChw8c}, tag::nchw, {1, 8, 4, 4},
                    {1.0f}, false, true, dnnl_invalid_arguments},
            sum_test_params {{tag::nchw, tag::nChw8c}, tag::nchw, {2, 8, 4, 4},
                    {0.1f}, false, true, dnnl_invalid_arguments});
};
323 
324 /* corner cases */
325 #define CASE_CC(itag0, itag1, otag, dims_, ef, st) \
326     sum_test_params { \
327         {tag::itag0, tag::itag1}, tag::otag, memory::dims dims_, {1.0f, 1.0f}, \
328                 0, ef, st \
329     }
__anonec0f339d0702() 330 static auto corner_test_cases = []() {
331     return ::testing::Values(
332             CASE_CC(nchw, nChw8c, nchw, ({0, 7, 4, 4}), false, dnnl_success),
333             CASE_CC(nchw, nChw8c, nchw, ({1, 0, 4, 4}), false, dnnl_success),
334             CASE_CC(nchw, nChw8c, nchw, ({1, 8, 0, 4}), false, dnnl_success),
335             CASE_CC(nchw, nChw8c, nchw, ({-1, 8, 4, 4}), true,
336                     dnnl_invalid_arguments));
337 };
338 #undef CASE_CC
339 
// Declares the TestsSum parameterized test for fixture `test` on CPU. It is
// instantiated with the generic cases (dst explicit or omitted per
// `omit_output`) and with the expected-failure cases. The test body is empty
// because the fixture's SetUp() does the work.
#define CPU_INST_TEST_CASE(test, omit_output) \
    CPU_TEST_P(test, TestsSum) {} \
    CPU_INSTANTIATE_TEST_SUITE_P( \
            TestSum, test, simple_test_cases(omit_output)); \
    CPU_INSTANTIATE_TEST_SUITE_P(TestSumEF, test, special_test_cases());

// CPU-only variant of CPU_INST_TEST_CASE that also instantiates the
// bf16-specific cases.
#define INST_TEST_CASE_BF16(test, omit_output) \
    CPU_TEST_P(test, TestsSum) {} \
    CPU_INSTANTIATE_TEST_SUITE_P( \
            TestSum, test, simple_test_cases(omit_output)); \
    CPU_INSTANTIATE_TEST_SUITE_P( \
            TestSumBf16, test, simple_test_cases_bf16(omit_output)); \
    CPU_INSTANTIATE_TEST_SUITE_P(TestSumEF, test, special_test_cases());

// GPU counterpart of CPU_INST_TEST_CASE.
#define GPU_INST_TEST_CASE(test, omit_output) \
    GPU_TEST_P(test, TestsSum) {} \
    GPU_INSTANTIATE_TEST_SUITE_P( \
            TestSum, test, simple_test_cases(omit_output)); \
    GPU_INSTANTIATE_TEST_SUITE_P(TestSumEF, test, special_test_cases());

// Registers the CPU and the GPU variants of the test for fixture `test`.
#define INST_TEST_CASE(test, omit_output) \
    CPU_INST_TEST_CASE(test, omit_output) \
    GPU_INST_TEST_CASE(test, omit_output)
363 
364 using sum_test_float_omit_output = sum_test_t<float, float>;
365 using sum_test_u8_omit_output = sum_test_t<uint8_t, int32_t>;
366 using sum_test_s8_omit_output = sum_test_t<int8_t, int32_t>;
367 using sum_test_s32_omit_output = sum_test_t<int32_t, float>;
368 using sum_test_f16_omit_output = sum_test_t<float16_t, float>;
369 using sum_test_bf16bf16_omit_output = sum_test_t<bfloat16_t, float>;
370 using sum_test_bf16f32_omit_output = sum_test_t<bfloat16_t, float, float>;
371 
372 using sum_test_float = sum_test_t<float, float>;
373 using sum_test_u8 = sum_test_t<uint8_t, int32_t>;
374 using sum_test_s8 = sum_test_t<int8_t, int32_t>;
375 using sum_test_s32 = sum_test_t<int32_t, float>;
376 using sum_test_f16 = sum_test_t<float16_t, float>;
377 using sum_test_bf16bf16 = sum_test_t<bfloat16_t, float>;
378 using sum_test_bf16f32 = sum_test_t<bfloat16_t, float, float>;
379 
// f32 suite for the corner cases (degenerate and negative dims).
using sum_cc_f32 = sum_test_t<float, float>;

TEST_P(sum_cc_f32, TestSumCornerCases) {}
INSTANTIATE_TEST_SUITE_P(TestSumCornerCases, sum_cc_f32, corner_test_cases());

// Omitted-dst instantiations: the sum primitive descriptor is created without
// a dst descriptor.
INST_TEST_CASE(sum_test_float_omit_output, 1)
INST_TEST_CASE(sum_test_u8_omit_output, 1)
INST_TEST_CASE(sum_test_s8_omit_output, 1)
INST_TEST_CASE(sum_test_s32_omit_output, 1)
INST_TEST_CASE_BF16(sum_test_bf16bf16_omit_output, 1)
// Automatically created dst descriptor has bf16 data type (the source type),
// while this fixture checks an f32 dst, so this test is not
// valid: INST_TEST_CASE(sum_test_bf16f32_omit_output, 1)
// f16 is instantiated for GPU only.
GPU_INST_TEST_CASE(sum_test_f16_omit_output, 1)

// Explicit-dst instantiations.
INST_TEST_CASE(sum_test_float, 0)
INST_TEST_CASE(sum_test_u8, 0)
INST_TEST_CASE(sum_test_s8, 0)
INST_TEST_CASE(sum_test_s32, 0)
INST_TEST_CASE_BF16(sum_test_bf16bf16, 0)
INST_TEST_CASE_BF16(sum_test_bf16f32, 0)
GPU_INST_TEST_CASE(sum_test_f16, 0)
401 
// Drop all file-local instantiation helpers. Previously only two of the four
// were undefined, so INST_TEST_CASE and INST_TEST_CASE_BF16 leaked past this
// point.
#undef CPU_INST_TEST_CASE
#undef GPU_INST_TEST_CASE
#undef INST_TEST_CASE
#undef INST_TEST_CASE_BF16
404 } // namespace dnnl
405