1 /*******************************************************************************
2 * Copyright 2016-2020 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include "dnnl_test_common.hpp"
18 #include "gtest/gtest.h"
19 
20 #include "oneapi/dnnl/dnnl.hpp"
21 
22 namespace dnnl {
23 
24 struct test_inner_product_descr_t {
25     memory::dim mb;
26     memory::dim ic;
27     memory::dim oc;
28     memory::dim kd, kh, kw;
29 };
30 
31 template <typename data_t>
compute_ref_inner_product_fwd(test_inner_product_descr_t ipd,memory & src,memory & weights,memory & bias,memory & dst)32 void compute_ref_inner_product_fwd(test_inner_product_descr_t ipd, memory &src,
33         memory &weights, memory &bias, memory &dst) {
34     const bool w_bias = bias.get_desc().data.ndims != 0;
35     auto src_data = map_memory<data_t>(src);
36     auto weights_data = map_memory<data_t>(weights);
37     auto bias_data = w_bias ? map_memory<data_t>(bias) : nullptr;
38     auto dst_data = map_memory<data_t>(dst);
39 
40     const memory::desc src_d = src.get_desc();
41     const memory::desc weights_d = weights.get_desc();
42     const memory::desc bias_d = bias.get_desc();
43     const memory::desc dst_d = dst.get_desc();
44     const dnnl::impl::memory_desc_wrapper src_mdw(src_d.data);
45     const dnnl::impl::memory_desc_wrapper weights_mdw(weights_d.data);
46     const dnnl::impl::memory_desc_wrapper bias_mdw(bias_d.data);
47     const dnnl::impl::memory_desc_wrapper dst_mdw(dst_d.data);
48 
49     auto padded_ic = src_mdw.padded_dims()[1];
50 
51     dnnl::impl::parallel_nd(ipd.mb, ipd.oc, [&](memory::dim n, memory::dim oc) {
52         memory::dim oidx = n * ipd.oc + oc;
53         dst_data[dst_mdw.off_l(oidx, true)]
54                 = bias_data ? bias_data[bias_mdw.off_l(oc, true)] : data_t {0};
55         for (memory::dim ic = 0; ic < ipd.ic; ic++) {
56             for_(memory::dim kd = 0; kd < ipd.kd; kd++)
57             for_(memory::dim kh = 0; kh < ipd.kh; kh++)
58             for (memory::dim kw = 0; kw < ipd.kw; kw++) {
59                 memory::dim iidx = n * padded_ic * ipd.kd * ipd.kh * ipd.kw
60                         + ic * ipd.kd * ipd.kh * ipd.kw + kd * ipd.kh * ipd.kw
61                         + kh * ipd.kw + kw;
62                 memory::dim widx = oc * padded_ic * ipd.kd * ipd.kh * ipd.kw
63                         + ic * ipd.kd * ipd.kh * ipd.kw + kd * ipd.kh * ipd.kw
64                         + kh * ipd.kw + kw;
65                 dst_data[dst_mdw.off_l(oidx, true)]
66                         += src_data[src_mdw.off_l(iidx, true)]
67                         * weights_data[weights_mdw.off_l(widx, true)];
68             }
69         }
70     });
71 }
72 
73 struct inprod_test_params_t {
74     prop_kind aprop_kind;
75     memory::format_tag src_format;
76     memory::format_tag weights_format;
77     memory::format_tag bias_format;
78     memory::format_tag dst_format;
79     int ndims;
80     test_inner_product_descr_t test_ipd;
81     bool expect_to_fail;
82     dnnl_status_t expected_status;
83 };
84 
85 template <typename data_t>
86 class inner_product_test_t
87     : public ::testing::TestWithParam<inprod_test_params_t> {
88 protected:
SetUp()89     void SetUp() override {
90         auto p = ::testing::TestWithParam<inprod_test_params_t>::GetParam();
91         SKIP_IF_CUDA(!cuda_check_format_tags(p.src_format, p.weights_format,
92                              p.bias_format, p.dst_format),
93                 "Unsupported format tag");
94         SKIP_IF_CUDA(p.ndims > 5, "Unsupported number of dimensions");
95         catch_expected_failures(
96                 [=]() { Test(); }, p.expect_to_fail, p.expected_status);
97     }
98 
cuda_check_format_tags(memory::format_tag src_format,memory::format_tag wei_format,memory::format_tag bia_format,memory::format_tag dst_format)99     bool cuda_check_format_tags(memory::format_tag src_format,
100             memory::format_tag wei_format, memory::format_tag bia_format,
101             memory::format_tag dst_format) {
102         bool src_ok = src_format == memory::format_tag::ncdhw
103                 || src_format == memory::format_tag::ndhwc
104                 || src_format == memory::format_tag::nchw
105                 || src_format == memory::format_tag::nhwc
106                 || src_format == memory::format_tag::ncw
107                 || src_format == memory::format_tag::nwc
108                 || src_format == memory::format_tag::nc
109                 || src_format == memory::format_tag::any;
110         bool wei_ok = wei_format == memory::format_tag::oidhw
111                 || wei_format == memory::format_tag::odhwi
112                 || wei_format == memory::format_tag::dhwio
113                 || wei_format == memory::format_tag::oihw
114                 || wei_format == memory::format_tag::ohwi
115                 || wei_format == memory::format_tag::hwio
116                 || wei_format == memory::format_tag::oiw
117                 || wei_format == memory::format_tag::owi
118                 || wei_format == memory::format_tag::wio
119                 || wei_format == memory::format_tag::io
120                 || wei_format == memory::format_tag::oi
121                 || wei_format == memory::format_tag::any;
122         bool bia_ok = bia_format == memory::format_tag::undef
123                 || bia_format == memory::format_tag::any
124                 || bia_format == memory::format_tag::a
125                 || bia_format == memory::format_tag::x;
126         bool dst_ok = dst_format == memory::format_tag::any
127                 || dst_format == memory::format_tag::nc;
128 
129         return src_ok && wei_ok && bia_ok && dst_ok;
130     }
131 
Test()132     void Test() {
133         auto p = ::testing::TestWithParam<inprod_test_params_t>::GetParam();
134         test_inner_product_descr_t ipd = p.test_ipd;
135         bool has_spatial = ipd.kh > 1 || ipd.kw > 1;
136         if (p.ndims == 5) has_spatial = has_spatial || ipd.kd > 1;
137         bool with_bias = p.bias_format != memory::format_tag::undef;
138 
139         ASSERT_EQ(p.aprop_kind, prop_kind::forward);
140         auto eng = get_test_engine();
141         auto strm = make_stream(eng);
142         memory::data_type data_type = data_traits<data_t>::data_type;
143         ASSERT_EQ(data_type, dnnl::memory::data_type::f32);
144 
145         memory::dims src_dims = {ipd.mb, ipd.ic}, wei_dims = {ipd.oc, ipd.ic};
146         if (has_spatial) {
147             if (p.ndims == 5) {
148                 src_dims.push_back(ipd.kd);
149                 wei_dims.push_back(ipd.kd);
150             }
151             if (p.ndims >= 4) {
152                 src_dims.push_back(ipd.kh);
153                 wei_dims.push_back(ipd.kh);
154             }
155             if (p.ndims >= 3) {
156                 src_dims.push_back(ipd.kw);
157                 wei_dims.push_back(ipd.kw);
158             }
159         }
160         auto ip_src_desc = create_md(src_dims, data_type, p.src_format);
161         auto ip_weights_desc = create_md(wei_dims, data_type, p.weights_format);
162         auto ip_bias_desc = with_bias
163                 ? create_md({ipd.oc}, data_type, p.bias_format)
164                 : create_md({}, data_type, p.bias_format);
165         auto ip_dst_desc = create_md({ipd.mb, ipd.oc}, data_type, p.dst_format);
166 
167         auto ip_desc = with_bias
168                 ? inner_product_forward::desc(p.aprop_kind, ip_src_desc,
169                         ip_weights_desc, ip_bias_desc, ip_dst_desc)
170                 : inner_product_forward::desc(p.aprop_kind, ip_src_desc,
171                         ip_weights_desc, ip_dst_desc);
172 
173         auto ip_primitive_desc
174                 = inner_product_forward::primitive_desc(ip_desc, eng);
175         ip_primitive_desc = inner_product_forward::primitive_desc(
176                 ip_primitive_desc.get()); // test construction from a C pd
177 
178         auto ip_src = test::make_memory(ip_primitive_desc.src_desc(), eng);
179         auto ip_weights
180                 = test::make_memory(ip_primitive_desc.weights_desc(), eng);
181         auto ip_bias = test::make_memory(ip_primitive_desc.bias_desc(), eng);
182         auto ip_dst = test::make_memory(ip_primitive_desc.dst_desc(), eng);
183         auto dst_ref = test::make_memory(ip_primitive_desc.dst_desc(), eng);
184 
185         fill_data<data_t>(
186                 ip_src.get_desc().get_size() / sizeof(data_t), ip_src);
187         fill_data<data_t>(
188                 ip_weights.get_desc().get_size() / sizeof(data_t), ip_weights);
189         if (with_bias) {
190             fill_data<data_t>(
191                     ip_bias.get_desc().get_size() / sizeof(data_t), ip_bias);
192         }
193         check_zero_tail<data_t>(1, ip_src);
194         check_zero_tail<data_t>(1, ip_weights);
195 
196         ASSERT_TRUE(ip_primitive_desc.query_md(query::exec_arg_md, DNNL_ARG_SRC)
197                 == ip_primitive_desc.src_desc());
198         ASSERT_TRUE(ip_primitive_desc.query_md(query::exec_arg_md, DNNL_ARG_DST)
199                 == ip_primitive_desc.dst_desc());
200         ASSERT_TRUE(
201                 ip_primitive_desc.query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS)
202                 == ip_primitive_desc.weights_desc());
203         ASSERT_TRUE(
204                 ip_primitive_desc.query_md(query::exec_arg_md, DNNL_ARG_BIAS)
205                 == ip_primitive_desc.bias_desc());
206 
207         inner_product_forward(ip_primitive_desc)
208                 .execute(strm,
209                         {{DNNL_ARG_SRC, ip_src}, {DNNL_ARG_WEIGHTS, ip_weights},
210                                 {DNNL_ARG_BIAS, ip_bias},
211                                 {DNNL_ARG_DST, ip_dst}});
212         strm.wait();
213 
214         compute_ref_inner_product_fwd<data_t>(
215                 ipd, ip_src, ip_weights, ip_bias, dst_ref);
216         check_zero_tail<data_t>(1, dst_ref);
217         compare_data<data_t>(dst_ref, ip_dst);
218 
219         check_zero_tail<data_t>(0, ip_dst);
220     }
221 };
222 
223 using inner_product_test_float = inner_product_test_t<float>;
224 using inprod_test_params_float = inprod_test_params_t;
225 
226 #define EXPAND_SIZES_3D(...) \
227     5, { __VA_ARGS__ }
228 #define EXPAND_SIZES_2D(mb, ic, oc, kh, kw) \
229     4, { mb, ic, oc, 1, kh, kw }
230 #define EXPAND_SIZES_1D(mb, ic, oc, kw) \
231     3, { mb, ic, oc, 1, 1, kw }
232 
TEST_P(inner_product_test_float,TestsInnerProduct)233 TEST_P(inner_product_test_float, TestsInnerProduct) {}
234 
// Zero-sized minibatch (mb == 0) with all layouts left to the implementation
// and a bias present (bias format `any`). expect_to_fail is omitted (false),
// so this case is not expected to fail.
INSTANTIATE_TEST_SUITE_P(TestInnerProductForwardZeroDim,
        inner_product_test_float,
        ::testing::Values(inprod_test_params_float {prop_kind::forward,
                memory::format_tag::any, memory::format_tag::any,
                memory::format_tag::any, memory::format_tag::any,
                EXPAND_SIZES_2D(0, 32, 48, 6, 6)}));
241 
// Expected-failure cases with invalid sizes (ic == 0, mb == -1, ic == -1);
// each one is expected to fail with dnnl_invalid_arguments.
INSTANTIATE_TEST_SUITE_P(TestInnerProductForwardEF, inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        EXPAND_SIZES_2D(2, 0, 48, 6, 6), true,
                        dnnl_invalid_arguments},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        EXPAND_SIZES_2D(-1, 32, 48, 6, 6), true,
                        dnnl_invalid_arguments},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        EXPAND_SIZES_2D(2, -1, 48, 6, 6), true,
                        dnnl_invalid_arguments}));
259 
// No-bias cases with channel-blocked src/weights layouts (16c, 8c, 4c).
// Most input-channel counts are not a multiple of the block size, so the
// channel dimension carries padding.
INSTANTIATE_TEST_SUITE_P(TestInnerProductForwardNoBias_padded,
        inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(4, 14, 25, 5, 5)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(4, 20, 15, 5, 5)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nChw8c, memory::format_tag::aBcd8b,
                        memory::format_tag::undef, memory::format_tag::nc,
                        EXPAND_SIZES_2D(4, 6, 15, 5, 5)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nChw8c, memory::format_tag::aBcd8b,
                        memory::format_tag::undef, memory::format_tag::nc,
                        EXPAND_SIZES_2D(4, 10, 5, 5, 5)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nChw4c, memory::format_tag::aBcd4b,
                        memory::format_tag::undef, memory::format_tag::nc,
                        EXPAND_SIZES_2D(4, 16, 5, 5, 5)}));
285 
// Same padded 16c/8c shapes as the first four no-bias padded cases, but with
// an `x` bias. Registered via GPU_INSTANTIATE_TEST_SUITE_P, which presumably
// limits the suite to GPU engines -- confirm against the macro definition.
GPU_INSTANTIATE_TEST_SUITE_P(TestInnerProductForward_padded,
        inner_product_test_float,
        ::testing::Values(inprod_test_params_float {prop_kind::forward,
                                  memory::format_tag::nChw16c,
                                  memory::format_tag::aBcd16b,
                                  memory::format_tag::x, memory::format_tag::nc,
                                  EXPAND_SIZES_2D(4, 14, 25, 5, 5)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(4, 20, 15, 5, 5)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nChw8c, memory::format_tag::aBcd8b,
                        memory::format_tag::x, memory::format_tag::nc,
                        EXPAND_SIZES_2D(4, 6, 15, 5, 5)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nChw8c, memory::format_tag::aBcd8b,
                        memory::format_tag::x, memory::format_tag::nc,
                        EXPAND_SIZES_2D(4, 10, 5, 5, 5)}));
306 
307 INSTANTIATE_TEST_SUITE_P(TestInnerProductForwardNoBias,
308         inner_product_test_float,
309         ::testing::Values(
310                 inprod_test_params_float {prop_kind::forward,
311                         memory::format_tag::any, memory::format_tag::any,
312                         memory::format_tag::undef, memory::format_tag::any,
313                         EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
314                 inprod_test_params_float {prop_kind::forward,
315                         memory::format_tag::any, memory::format_tag::any,
316                         memory::format_tag::undef, memory::format_tag::any,
317                         EXPAND_SIZES_2D(2, 512, 48, 2, 2)},
318                 inprod_test_params_float {prop_kind::forward,
319                         memory::format_tag::nwc, memory::format_tag::wio,
320                         memory::format_tag::undef, memory::format_tag::nc,
321                         EXPAND_SIZES_1D(2, 32, 48, 5)},
322                 inprod_test_params_float {prop_kind::forward,
323                         memory::format_tag::nwc, memory::format_tag::owi,
324                         memory::format_tag::undef, memory::format_tag::nc,
325                         EXPAND_SIZES_1D(2, 32, 48, 5)},
326                 inprod_test_params_float {prop_kind::forward,
327                         memory::format_tag::nwc, memory::format_tag::oiw,
328                         memory::format_tag::undef, memory::format_tag::nc,
329                         EXPAND_SIZES_1D(2, 32, 48, 5)},
330                 inprod_test_params_float {prop_kind::forward,
331                         memory::format_tag::ncw, memory::format_tag::oiw,
332                         memory::format_tag::undef, memory::format_tag::nc,
333                         EXPAND_SIZES_1D(2, 32, 48, 5)},
334                 inprod_test_params_float {prop_kind::forward,
335                         memory::format_tag::ncw, memory::format_tag::owi,
336                         memory::format_tag::undef, memory::format_tag::nc,
337                         EXPAND_SIZES_1D(2, 32, 48, 5)},
338                 inprod_test_params_float {prop_kind::forward,
339                         memory::format_tag::ncw, memory::format_tag::wio,
340                         memory::format_tag::undef, memory::format_tag::nc,
341                         EXPAND_SIZES_1D(2, 32, 48, 5)},
342                 inprod_test_params_float {prop_kind::forward,
343                         memory::format_tag::nhwc, memory::format_tag::hwio,
344                         memory::format_tag::undef, memory::format_tag::nc,
345                         EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
346                 inprod_test_params_float {prop_kind::forward,
347                         memory::format_tag::nhwc, memory::format_tag::ohwi,
348                         memory::format_tag::undef, memory::format_tag::nc,
349                         EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
350                 inprod_test_params_float {prop_kind::forward,
351                         memory::format_tag::nhwc, memory::format_tag::oihw,
352                         memory::format_tag::undef, memory::format_tag::nc,
353                         EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
354                 inprod_test_params_float {prop_kind::forward,
355                         memory::format_tag::nchw, memory::format_tag::oihw,
356                         memory::format_tag::undef, memory::format_tag::nc,
357                         EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
358                 inprod_test_params_float {prop_kind::forward,
359                         memory::format_tag::nchw, memory::format_tag::hwio,
360                         memory::format_tag::undef, memory::format_tag::nc,
361                         EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
362                 inprod_test_params_float {prop_kind::forward,
363                         memory::format_tag::nchw, memory::format_tag::ohwi,
364                         memory::format_tag::undef, memory::format_tag::nc,
365                         EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
366                 inprod_test_params_float {prop_kind::forward,
367                         memory::format_tag::nChw8c, memory::format_tag::aBcd8b,
368                         memory::format_tag::undef, memory::format_tag::nc,
369                         EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
370                 inprod_test_params_float {prop_kind::forward,
371                         memory::format_tag::nChw16c,
372                         memory::format_tag::aBcd16b, memory::format_tag::undef,
373                         memory::format_tag::nc,
374                         EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
375                 inprod_test_params_float {prop_kind::forward,
376                         memory::format_tag::any, memory::format_tag::aBcd8b,
377                         memory::format_tag::undef, memory::format_tag::nc,
378                         EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
379                 inprod_test_params_float {prop_kind::forward,
380                         memory::format_tag::nChw8c, memory::format_tag::any,
381                         memory::format_tag::undef, memory::format_tag::nc,
382                         EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
383                 inprod_test_params_float {prop_kind::forward,
384                         memory::format_tag::nChw16c,
385                         memory::format_tag::aBcd16b, memory::format_tag::undef,
386                         memory::format_tag::nc,
387                         EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
388                 inprod_test_params_float {prop_kind::forward,
389                         memory::format_tag::nc, memory::format_tag::oi,
390                         memory::format_tag::undef, memory::format_tag::nc,
391                         EXPAND_SIZES_2D(2, 32, 1152, 1, 1)},
392                 inprod_test_params_float {prop_kind::forward,
393                         memory::format_tag::nc, memory::format_tag::oi,
394                         memory::format_tag::undef, memory::format_tag::nc,
395                         EXPAND_SIZES_2D(2, 2, 4, 1, 1)},
396                 inprod_test_params_float {prop_kind::forward,
397                         memory::format_tag::nc, memory::format_tag::io,
398                         memory::format_tag::undef, memory::format_tag::nc,
399                         EXPAND_SIZES_2D(2, 8, 16, 1, 1)}));
400 
// Five-dimensional src/weights (ndims == 5, i.e. depth/height/width spatial
// extents), with and without an `x` bias. Layouts covered:
//   - implementation-chosen (`any`);
//   - plain ncdhw;
//   - channels-last ndhwc;
//   - channel-blocked nCdhw8c/nCdhw16c.
INSTANTIATE_TEST_SUITE_P(TestInnerProductForward3D, inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::undef, memory::format_tag::any,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::ncdhw, memory::format_tag::dhwio,
                        memory::format_tag::x, memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 3, 5, 7)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::ncdhw, memory::format_tag::odhwi,
                        memory::format_tag::undef, memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 2, 4, 6)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::ncdhw, memory::format_tag::oidhw,
                        memory::format_tag::undef, memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nCdhw8c,
                        memory::format_tag::aBcde8b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nCdhw16c,
                        memory::format_tag::aBcde16b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::ndhwc, memory::format_tag::dhwio,
                        memory::format_tag::undef, memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 16, 48, 3, 3, 3)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::ndhwc, memory::format_tag::odhwi,
                        memory::format_tag::undef, memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 16, 48, 3, 4, 5)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::ndhwc, memory::format_tag::oidhw,
                        memory::format_tag::undef, memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 16, 48, 3, 5, 4)}));
441 
// Cases with bias: an `x` bias, or `any` for the implementation-chosen rows.
// Layouts covered:
//   - implementation-chosen;
//   - plain and channels-last 2D spatial;
//   - channel-blocked 8c/16c;
//   - purely 2D nc x oi problems.
INSTANTIATE_TEST_SUITE_P(TestInnerProductForward, inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        EXPAND_SIZES_2D(2, 512, 48, 2, 2)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nhwc, memory::format_tag::oihw,
                        memory::format_tag::x, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nhwc, memory::format_tag::hwio,
                        memory::format_tag::x, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nchw, memory::format_tag::oihw,
                        memory::format_tag::x, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nChw8c, memory::format_tag::aBcd8b,
                        memory::format_tag::x, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nc, memory::format_tag::oi,
                        memory::format_tag::x, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 1152, 1, 1)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nc, memory::format_tag::oi,
                        memory::format_tag::x, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 2, 4, 1, 1)},
                inprod_test_params_float {prop_kind::forward,
                        memory::format_tag::nc, memory::format_tag::oi,
                        memory::format_tag::x, memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 8, 16, 1, 1)}));
485 } // namespace dnnl
486