/*******************************************************************************
* Copyright 2016-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "dnnl_test_common.hpp"
#include "gtest/gtest.h"

#include "oneapi/dnnl/dnnl.hpp"

#include "tests/test_isa_common.hpp"

namespace dnnl {

using tag = memory::format_tag;

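// Layouts the implementation is expected to pick when given tag::any:
// plain ("flat", no inner blocking) or channel-blocked (a single inner block
// of 8 or 16 elements over the channels dimension).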
enum class data_fmt_t { flat, blocked_cX };

#define FLT data_fmt_t::flat
#define BLK data_fmt_t::blocked_cX

struct conv_any_fmt_test_params_t {
    prop_kind aprop_kind;
    algorithm aalgorithm;
    data_fmt_t expected_src_fmt;
    data_fmt_t expected_dst_fmt;
    test_convolution_sizes_t test_cd;
};

template <typename data_t>
class convolution_any_fmt_test_t
    : public ::testing::TestWithParam<conv_any_fmt_test_params_t> {
protected:
    void SetUp() override {
#if DNNL_X64
        // Skip this test if the library cannot select a blocked format
        // a priori. Currently blocking is supported only for sse41 and
        // later CPUs.
        bool implementation_supports_blocking = dnnl::mayiuse(cpu_isa::sse41);
        if (!implementation_supports_blocking) return;
#else
        return;
#endif

        auto p = ::testing::TestWithParam<
                conv_any_fmt_test_params_t>::GetParam();

        ASSERT_EQ(p.aprop_kind, prop_kind::forward);
        ASSERT_EQ(p.aalgorithm, algorithm::convolution_direct);
        auto eng = get_test_engine();
        memory::data_type data_type = data_traits<data_t>::data_type;
        SKIP_IF_CUDA((p.expected_src_fmt == BLK || p.expected_dst_fmt == BLK),
                "unsupported format");
        ASSERT_EQ(data_type, dnnl::memory::data_type::f32);

        test_convolution_sizes_t cd = p.test_cd;

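        // All memory descriptors are created with tag::any so the primitive
        // descriptor is free to choose the layout; the checks below verify
        // which layout was actually selected.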
        auto c_src_desc
                = create_md({cd.mb, cd.ic, cd.ih, cd.iw}, data_type, tag::any);
        auto c_weights_desc = cd.ng > 1
                ? create_md({cd.ng, cd.oc / cd.ng, cd.ic / cd.ng, cd.kh, cd.kw},
                        data_type, tag::any)
                : create_md({cd.oc, cd.ic, cd.kh, cd.kw}, data_type, tag::any);
        auto c_dst_desc
                = create_md({cd.mb, cd.oc, cd.oh, cd.ow}, data_type, tag::any);

        auto conv_desc = convolution_forward::desc(p.aprop_kind, p.aalgorithm,
                c_src_desc, c_weights_desc, c_dst_desc, {cd.strh, cd.strw},
                {cd.padh, cd.padw}, {cd.padh, cd.padw});

        auto conv_prim_desc
                = convolution_forward::primitive_desc(conv_desc, eng);

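        // A "flat" descriptor has no inner blocks; a "blocked" descriptor has
        // a single inner block of 8 or 16 over the channel dimension (index 1).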
        auto check_fmt = [&](const dnnl_memory_desc_t &md,
                                 data_fmt_t expected) {
            bool ok = false;
            if (expected == FLT) {
                ok = true && md.format_kind == dnnl_blocked
                        && md.format_desc.blocking.inner_nblks == 0;
            } else if (expected == BLK) {
                ok = true && md.format_kind == dnnl_blocked
                        && md.format_desc.blocking.inner_nblks == 1
                        && md.format_desc.blocking.inner_idxs[0] == 1
                        && (false || md.format_desc.blocking.inner_blks[0] == 8
                                || md.format_desc.blocking.inner_blks[0] == 16);
            }
            return ok;
        };

        ASSERT_TRUE(
                check_fmt(conv_prim_desc.src_desc().data, p.expected_src_fmt));
        ASSERT_TRUE(
                check_fmt(conv_prim_desc.dst_desc().data, p.expected_dst_fmt));
    }
};

using conv_any_fmt_test_float = convolution_any_fmt_test_t<float>;

TEST_P(conv_any_fmt_test_float, TestsConvolutionAnyFmt) {}

#define CPARAMS prop_kind::forward, algorithm::convolution_direct

using tf32 = conv_any_fmt_test_params_t;

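// AlexNet forward convolution shapes (conv1-conv5). Each size list follows
// test_convolution_sizes_t order: mb, g, ic, ih, iw, oc, oh, ow, kh, kw,
// padh, padw, strh, strw.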
#define ALEXNET_SUITE(EFMT) \
    tf32 {CPARAMS, FLT, EFMT, \
            {2, 1, 3, 227, 227, 96, 55, 55, 11, 11, 0, 0, 4, 4}}, \
            tf32 {CPARAMS, EFMT, EFMT, \
                    {2, 2, 96, 27, 27, 256, 27, 27, 5, 5, 2, 2, 1, 1}}, \
            tf32 {CPARAMS, EFMT, EFMT, \
                    {2, 1, 256, 13, 13, 384, 13, 13, 3, 3, 1, 1, 1, 1}}, \
            tf32 {CPARAMS, EFMT, EFMT, \
                    {2, 2, 384, 13, 13, 384, 13, 13, 3, 3, 1, 1, 1, 1}}, \
            tf32 { \
        CPARAMS, EFMT, EFMT, { \
            2, 2, 384, 13, 13, 256, 13, 13, 3, 3, 1, 1, 1, 1 \
        } \
    }

#if DNNL_X64
CPU_INSTANTIATE_TEST_SUITE_P(TestConvolutionAlexnetAnyFmtForward,
        conv_any_fmt_test_float, ::testing::Values(ALEXNET_SUITE(BLK)));
#endif
} // namespace dnnl