1 /*******************************************************************************
2 * Copyright 2016-2020 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include "dnnl_test_common.hpp"
18 #include "gtest/gtest.h"
19 
20 #include "oneapi/dnnl/dnnl.hpp"
21 
22 namespace dnnl {
23 
24 struct test_inner_product_descr_t {
25     memory::dim mb;
26     memory::dim ic;
27     memory::dim oc;
28     memory::dim kd, kh, kw;
29 };
30 
31 template <typename data_t>
compute_ref_inner_product_bwd_bias(const test_inner_product_descr_t & ipd,const memory & diff_dst,const memory & diff_bias)32 void compute_ref_inner_product_bwd_bias(const test_inner_product_descr_t &ipd,
33         const memory &diff_dst, const memory &diff_bias) {
34     auto diff_bias_data = map_memory<data_t>(diff_bias);
35     auto diff_dst_data = map_memory<data_t>(diff_dst);
36 
37     const memory::desc diff_bias_d = diff_bias.get_desc();
38     const memory::desc diff_dst_d = diff_dst.get_desc();
39     const dnnl::impl::memory_desc_wrapper diff_bias_mdw(diff_bias_d.data);
40     const dnnl::impl::memory_desc_wrapper diff_dst_mdw(diff_dst_d.data);
41 
42     dnnl::impl::parallel_nd(ipd.oc, [&](memory::dim oc) {
43         data_t *db = &diff_bias_data[diff_bias_mdw.off_l(oc, true)];
44         *db = data_t(0);
45         for (memory::dim n = 0; n < ipd.mb; ++n) {
46             *db += diff_dst_data[diff_dst_mdw.off_l(n * ipd.oc + oc, true)];
47         }
48     });
49 }
50 
51 template <typename data_t>
compute_ref_inner_product_bwd_weights(int ndims,const test_inner_product_descr_t & ipd,const memory & src,const memory & diff_dst,const memory & diff_weights)52 void compute_ref_inner_product_bwd_weights(int ndims,
53         const test_inner_product_descr_t &ipd, const memory &src,
54         const memory &diff_dst, const memory &diff_weights) {
55     auto src_data = map_memory<data_t>(src);
56     auto diff_weights_data = map_memory<data_t>(diff_weights);
57     auto diff_dst_data = map_memory<data_t>(diff_dst);
58 
59     const memory::desc src_d = src.get_desc();
60     const memory::desc diff_weights_d = diff_weights.get_desc();
61     const memory::desc diff_dst_d = diff_dst.get_desc();
62     const dnnl::impl::memory_desc_wrapper src_mdw(src_d.data);
63     const dnnl::impl::memory_desc_wrapper diff_weights_mdw(diff_weights_d.data);
64     const dnnl::impl::memory_desc_wrapper diff_dst_mdw(diff_dst_d.data);
65 
66     auto padded_ic = src_d.data.padded_dims[1];
67 
68     bool has_spatial = ipd.kh > 1 || ipd.kw > 1;
69     if (ndims == 5) has_spatial = has_spatial || ipd.kd > 1;
70     dnnl::impl::parallel_nd(
71             ipd.oc, ipd.ic, [&](memory::dim oc, memory::dim ic) {
72                 if (has_spatial) {
73                     for_(memory::dim kd = 0; kd < ipd.kd; ++kd)
74                     for_(memory::dim kh = 0; kh < ipd.kh; ++kh)
75                     for (memory::dim kw = 0; kw < ipd.kw; ++kw) {
76                         memory::dim dwidx
77                                 = oc * padded_ic * ipd.kd * ipd.kh * ipd.kw
78                                 + ic * ipd.kd * ipd.kh * ipd.kw
79                                 + kd * ipd.kh * ipd.kw + kh * ipd.kw + kw;
80                         data_t *dw = &diff_weights_data[diff_weights_mdw.off_l(
81                                 dwidx, true)];
82                         *dw = data_t(0);
83                         for (memory::dim n = 0; n < ipd.mb; ++n) {
84                             memory::dim ddidx = n * ipd.oc + oc;
85                             memory::dim sidx
86                                     = n * padded_ic * ipd.kd * ipd.kh * ipd.kw
87                                     + ic * ipd.kd * ipd.kh * ipd.kw
88                                     + kd * ipd.kh * ipd.kw + kh * ipd.kw + kw;
89                             *dw += diff_dst_data[diff_dst_mdw.off_l(
90                                            ddidx, true)]
91                                     * src_data[src_mdw.off_l(sidx, true)];
92                         }
93                     }
94                 } else {
95                     memory::dim dwidx = oc * ipd.ic + ic;
96                     data_t *dw = &diff_weights_data[diff_weights_mdw.off_l(
97                             dwidx, true)];
98                     *dw = data_t(0);
99                     for (memory::dim n = 0; n < ipd.mb; ++n) {
100                         memory::dim ddidx = n * ipd.oc + oc;
101                         memory::dim sidx = n * ipd.ic + ic;
102                         *dw += diff_dst_data[diff_dst_mdw.off_l(ddidx, true)]
103                                 * src_data[src_mdw.off_l(sidx, true)];
104                     }
105                 }
106             });
107 }
108 
109 struct inprod_test_params_t {
110     memory::format_tag src_format;
111     memory::format_tag diff_weights_format;
112     memory::format_tag diff_bias_format;
113     memory::format_tag diff_dst_format;
114     int ndims;
115     test_inner_product_descr_t test_ipd;
116     bool expect_to_fail;
117     dnnl_status_t expected_status;
118 };
119 
120 template <typename data_t>
121 class inner_product_test_bwd_weights_t
122     : public ::testing::TestWithParam<inprod_test_params_t> {
123 protected:
SetUp()124     void SetUp() override {
125         auto p = ::testing::TestWithParam<inprod_test_params_t>::GetParam();
126         SKIP_IF_CUDA(
127                 !cuda_check_format_tags(p.src_format, p.diff_weights_format,
128                         p.diff_bias_format, p.diff_dst_format),
129                 "Unsupported format tag");
130         SKIP_IF_CUDA(p.ndims > 5, "Unsupported number of dimensions");
131         catch_expected_failures(
132                 [=]() { Test(); }, p.expect_to_fail, p.expected_status);
133     }
134 
cuda_check_format_tags(memory::format_tag src_format,memory::format_tag diff_wei_format,memory::format_tag diff_bia_format,memory::format_tag diff_dst_format)135     bool cuda_check_format_tags(memory::format_tag src_format,
136             memory::format_tag diff_wei_format,
137             memory::format_tag diff_bia_format,
138             memory::format_tag diff_dst_format) {
139         bool src_ok = src_format == memory::format_tag::ncdhw
140                 || src_format == memory::format_tag::ndhwc
141                 || src_format == memory::format_tag::nchw
142                 || src_format == memory::format_tag::nhwc
143                 || src_format == memory::format_tag::ncw
144                 || src_format == memory::format_tag::nwc
145                 || src_format == memory::format_tag::nc
146                 || src_format == memory::format_tag::any;
147         bool diff_wei_ok = diff_wei_format == memory::format_tag::oidhw
148                 || diff_wei_format == memory::format_tag::odhwi
149                 || diff_wei_format == memory::format_tag::dhwio
150                 || diff_wei_format == memory::format_tag::oihw
151                 || diff_wei_format == memory::format_tag::ohwi
152                 || diff_wei_format == memory::format_tag::hwio
153                 || diff_wei_format == memory::format_tag::oiw
154                 || diff_wei_format == memory::format_tag::owi
155                 || diff_wei_format == memory::format_tag::wio
156                 || diff_wei_format == memory::format_tag::io
157                 || diff_wei_format == memory::format_tag::oi
158                 || diff_wei_format == memory::format_tag::any;
159         bool diff_bia_ok = diff_bia_format == memory::format_tag::undef
160                 || diff_bia_format == memory::format_tag::any
161                 || diff_bia_format == memory::format_tag::a
162                 || diff_bia_format == memory::format_tag::x;
163         bool diff_dst_ok = diff_dst_format == memory::format_tag::any
164                 || diff_dst_format == memory::format_tag::nc;
165 
166         return src_ok && diff_wei_ok && diff_bia_ok && diff_dst_ok;
167     }
168 
Test()169     void Test() {
170         auto p = ::testing::TestWithParam<inprod_test_params_t>::GetParam();
171         test_inner_product_descr_t ipd = p.test_ipd;
172 
173         bool has_spatial = ipd.kh > 1 || ipd.kw > 1;
174         if (p.ndims == 5) has_spatial = has_spatial || ipd.kd > 1;
175 
176         bool with_bias = p.diff_bias_format != memory::format_tag::undef;
177 
178         auto eng = get_test_engine();
179         auto strm = make_stream(eng);
180         memory::data_type data_type = data_traits<data_t>::data_type;
181         ASSERT_EQ(data_type, dnnl::memory::data_type::f32);
182 
183         memory::dims src_dims = {ipd.mb, ipd.ic},
184                      diff_wei_dims = {ipd.oc, ipd.ic};
185         if (has_spatial) {
186             if (p.ndims == 5) {
187                 src_dims.push_back(ipd.kd);
188                 diff_wei_dims.push_back(ipd.kd);
189             }
190             if (p.ndims >= 4) {
191                 src_dims.push_back(ipd.kh);
192                 diff_wei_dims.push_back(ipd.kh);
193             }
194             if (p.ndims >= 3) {
195                 src_dims.push_back(ipd.kw);
196                 diff_wei_dims.push_back(ipd.kw);
197             }
198         }
199         auto ip_src_desc = create_md(src_dims, data_type, p.src_format);
200         auto ip_diff_weights_desc
201                 = create_md(diff_wei_dims, data_type, p.diff_weights_format);
202         auto ip_diff_dst_desc
203                 = create_md({ipd.mb, ipd.oc}, data_type, p.diff_dst_format);
204         auto ip_diff_bias_desc = with_bias
205                 ? create_md({ipd.oc}, data_type, p.diff_bias_format)
206                 : create_md({}, data_type, p.diff_bias_format);
207 
208         // Create inner product forward (hint for backward)
209         auto ip_fwd_desc = inner_product_forward::desc(prop_kind::forward,
210                 ip_src_desc, ip_diff_weights_desc, ip_diff_dst_desc);
211         auto ip_fwd_pdesc
212                 = inner_product_forward::primitive_desc(ip_fwd_desc, eng);
213 
214         // Create inner product backward
215         auto ip_desc = with_bias
216                 ? inner_product_backward_weights::desc(ip_src_desc,
217                         ip_diff_weights_desc, ip_diff_bias_desc,
218                         ip_diff_dst_desc)
219                 : inner_product_backward_weights::desc(
220                         ip_src_desc, ip_diff_weights_desc, ip_diff_dst_desc);
221 
222         auto ip_primitive_desc = inner_product_backward_weights::primitive_desc(
223                 ip_desc, eng, ip_fwd_pdesc);
224         ip_primitive_desc = inner_product_backward_weights::primitive_desc(
225                 ip_primitive_desc.get()); // test construction from a C pd
226 
227         auto ip_src = test::make_memory(ip_primitive_desc.src_desc(), eng);
228         auto ip_diff_dst
229                 = test::make_memory(ip_primitive_desc.diff_dst_desc(), eng);
230         auto ip_diff_weights
231                 = test::make_memory(ip_primitive_desc.diff_weights_desc(), eng);
232         auto diff_weights_ref
233                 = test::make_memory(ip_primitive_desc.diff_weights_desc(), eng);
234         auto ip_diff_bias
235                 = test::make_memory(ip_primitive_desc.diff_bias_desc(), eng);
236         auto diff_bias_ref
237                 = test::make_memory(ip_primitive_desc.diff_bias_desc(), eng);
238 
239         fill_data<data_t>(
240                 ip_src.get_desc().get_size() / sizeof(data_t), ip_src);
241         fill_data<data_t>(ip_diff_dst.get_desc().get_size() / sizeof(data_t),
242                 ip_diff_dst);
243 
244         check_zero_tail<data_t>(1, ip_src);
245         check_zero_tail<data_t>(1, ip_diff_dst);
246 
247         ASSERT_TRUE(ip_primitive_desc.query_md(query::exec_arg_md, DNNL_ARG_SRC)
248                 == ip_primitive_desc.src_desc());
249         ASSERT_TRUE(ip_primitive_desc.query_md(
250                             query::exec_arg_md, DNNL_ARG_DIFF_DST)
251                 == ip_primitive_desc.diff_dst_desc());
252         ASSERT_TRUE(ip_primitive_desc.query_md(
253                             query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS)
254                 == ip_primitive_desc.diff_weights_desc());
255         ASSERT_TRUE(ip_primitive_desc.query_md(
256                             query::exec_arg_md, DNNL_ARG_DIFF_BIAS)
257                 == ip_primitive_desc.diff_bias_desc());
258 
259         inner_product_backward_weights(ip_primitive_desc)
260                 .execute(strm,
261                         {{DNNL_ARG_DIFF_DST, ip_diff_dst},
262                                 {DNNL_ARG_SRC, ip_src},
263                                 {DNNL_ARG_DIFF_WEIGHTS, ip_diff_weights},
264                                 {DNNL_ARG_DIFF_BIAS, ip_diff_bias}});
265         strm.wait();
266 
267         compute_ref_inner_product_bwd_weights<data_t>(
268                 p.ndims, ipd, ip_src, ip_diff_dst, diff_weights_ref);
269         check_zero_tail<data_t>(1, diff_weights_ref);
270 
271         compare_data<data_t>(diff_weights_ref, ip_diff_weights);
272 
273         check_zero_tail<data_t>(0, ip_diff_weights);
274 
275         if (with_bias) {
276             compute_ref_inner_product_bwd_bias<data_t>(
277                     ipd, ip_diff_dst, diff_bias_ref);
278             compare_data<data_t>(diff_bias_ref, ip_diff_bias);
279         }
280     }
281 };
282 
283 using inner_product_test_float = inner_product_test_bwd_weights_t<float>;
284 using inprod_test_params_float = inprod_test_params_t;
285 
286 #define EXPAND_SIZES_3D(...) \
287     5, { __VA_ARGS__ }
288 #define EXPAND_SIZES_2D(mb, ic, oc, kh, kw) \
289     4, { mb, ic, oc, 1, kh, kw }
290 #define EXPAND_SIZES_1D(mb, ic, oc, kw) \
291     3, { mb, ic, oc, 1, 1, kw }
292 
TEST_P(inner_product_test_float, TestsInnerProduct) {
    // Intentionally empty: the fixture's SetUp() runs the whole test.
}
294 
// Empty minibatch (mb = 0), all formats left to the library, with bias.
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeightsZeroDim,
        inner_product_test_float,
        ::testing::Values(inprod_test_params_float {memory::format_tag::any,
                memory::format_tag::any, memory::format_tag::any,
                memory::format_tag::any, EXPAND_SIZES_2D(0, 32, 48, 6, 6)}));
300 
// Expected-failure cases: ic = 0, mb = -1 and ic = -1 are marked
// expect_to_fail with dnnl_invalid_arguments as the expected status.
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeightsEF,
        inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_2D(2, 0, 48, 6, 6), true,
                        dnnl_invalid_arguments},
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_2D(-1, 32, 48, 6, 6), true,
                        dnnl_invalid_arguments},
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_2D(2, -1, 48, 6, 6), true,
                        dnnl_invalid_arguments}));
319 
// Blocked layouts whose IC (17, 10, 5) is not a multiple of the channel
// block (16 or 8), so src/diff_weights carry padded tails. No bias.
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeightsNoBias_padded,
        inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 17, 5, 3, 3)},
                inprod_test_params_float {memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 10, 5, 3, 3)},
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::aBcd8b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 17, 5, 3, 3)},
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::aBcd8b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 5, 15, 3, 3)}));
339 
// Same padded-IC shapes as the no-bias suite above, but with a bias (`x`).
// Registered through GPU_INSTANTIATE_TEST_SUITE_P, which presumably limits it
// to GPU engines (macro defined outside this file) - TODO confirm.
GPU_INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeights_padded,
        inner_product_test_float,
        ::testing::Values(inprod_test_params_float {memory::format_tag::nChw16c,
                                  memory::format_tag::aBcd16b,
                                  memory::format_tag::x, memory::format_tag::nc,
                                  EXPAND_SIZES_2D(2, 17, 5, 3, 3)},
                inprod_test_params_float {memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 10, 5, 3, 3)},
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::aBcd8b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 17, 5, 3, 3)},
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::aBcd8b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 5, 15, 3, 3)}));
358 
// No-bias coverage across layouts:
//  - `any` everywhere (spatial 6x6 and 2x2 with large IC),
//  - 1D (ncw/nwc src x oiw/owi/wio weights),
//  - 2D plain (nchw/nhwc src x oihw/ohwi/hwio weights),
//  - 2D blocked (nChw8c/nChw16c) including `any` on one side,
//  - non-spatial (kh = kw = 1 -> 2D nc src, oi/io weights).
// NOTE(review): the nChw16c / aBcd16b / EXPAND_SIZES_2D(2, 32, 1000, 6, 6)
// case appears twice; kept so gtest case indices do not shift.
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeightsNoBias,
        inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::undef,
                        memory::format_tag::any,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::undef,
                        memory::format_tag::any,
                        EXPAND_SIZES_2D(2, 1024, 48, 2, 2)},
                inprod_test_params_float {memory::format_tag::nwc,
                        memory::format_tag::owi, memory::format_tag::undef,
                        memory::format_tag::nc, EXPAND_SIZES_1D(2, 32, 48, 6)},
                inprod_test_params_float {memory::format_tag::nwc,
                        memory::format_tag::wio, memory::format_tag::undef,
                        memory::format_tag::nc, EXPAND_SIZES_1D(2, 32, 48, 6)},
                inprod_test_params_float {memory::format_tag::nwc,
                        memory::format_tag::oiw, memory::format_tag::undef,
                        memory::format_tag::nc, EXPAND_SIZES_1D(2, 32, 48, 6)},
                inprod_test_params_float {memory::format_tag::ncw,
                        memory::format_tag::oiw, memory::format_tag::undef,
                        memory::format_tag::nc, EXPAND_SIZES_1D(2, 32, 48, 6)},
                inprod_test_params_float {memory::format_tag::ncw,
                        memory::format_tag::wio, memory::format_tag::undef,
                        memory::format_tag::nc, EXPAND_SIZES_1D(2, 32, 48, 6)},
                inprod_test_params_float {memory::format_tag::ncw,
                        memory::format_tag::owi, memory::format_tag::undef,
                        memory::format_tag::nc, EXPAND_SIZES_1D(2, 32, 48, 6)},
                inprod_test_params_float {memory::format_tag::nhwc,
                        memory::format_tag::hwio, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nhwc,
                        memory::format_tag::oihw, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nhwc,
                        memory::format_tag::ohwi, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nchw,
                        memory::format_tag::oihw, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nchw,
                        memory::format_tag::ohwi, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nchw,
                        memory::format_tag::hwio, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::aBcd8b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 1000, 6, 6)},
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::aBcd16b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 1000, 6, 6)},
                inprod_test_params_float {memory::format_tag::nChw16c,
                        memory::format_tag::any, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 1000, 6, 6)},
                inprod_test_params_float {memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 1000, 6, 6)},
                inprod_test_params_float {memory::format_tag::nc,
                        memory::format_tag::oi, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 1152, 1, 1)},
                inprod_test_params_float {memory::format_tag::nc,
                        memory::format_tag::oi, memory::format_tag::undef,
                        memory::format_tag::nc, EXPAND_SIZES_2D(2, 2, 4, 1, 1)},
                inprod_test_params_float {memory::format_tag::nc,
                        memory::format_tag::io, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 8, 16, 1, 1)}));
443 
// With-bias coverage (bias format `any` or `x`): `any` everywhere, 2D plain
// and blocked spatial layouts, and non-spatial nc src with oi/io weights.
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeights,
        inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_2D(2, 32, 1024, 2, 2)},
                inprod_test_params_float {memory::format_tag::nhwc,
                        memory::format_tag::hwio, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nhwc,
                        memory::format_tag::oihw, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nchw,
                        memory::format_tag::oihw, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::aBcd8b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 1000, 6, 6)},
                inprod_test_params_float {memory::format_tag::nc,
                        memory::format_tag::oi, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 1152, 1, 1)},
                inprod_test_params_float {memory::format_tag::nc,
                        memory::format_tag::oi, memory::format_tag::x,
                        memory::format_tag::nc, EXPAND_SIZES_2D(2, 2, 4, 1, 1)},
                inprod_test_params_float {memory::format_tag::nc,
                        memory::format_tag::io, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 8, 16, 1, 1)}));
486 
// 3D spatial (ndims = 5) coverage with bias: `any`, plain ncdhw/ndhwc src with
// oidhw/odhwi/dhwio weights, and blocked nCdhw8c/nCdhw16c layouts.
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeights3D,
        inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_3D(2, 32, 1024, 2, 2, 2)},
                inprod_test_params_float {memory::format_tag::ncdhw,
                        memory::format_tag::oidhw, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::ncdhw,
                        memory::format_tag::dhwio, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::ncdhw,
                        memory::format_tag::odhwi, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::nCdhw8c,
                        memory::format_tag::aBcde8b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::nCdhw16c,
                        memory::format_tag::aBcde16b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 1000, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::ndhwc,
                        memory::format_tag::dhwio, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 16, 48, 3, 3, 3)},
                inprod_test_params_float {memory::format_tag::ndhwc,
                        memory::format_tag::odhwi, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 16, 48, 3, 3, 3)},
                inprod_test_params_float {memory::format_tag::ndhwc,
                        memory::format_tag::oidhw, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 16, 48, 3, 3, 3)}));
530 } // namespace dnnl
531