1 /*******************************************************************************
2 * Copyright 2016-2020 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "dnnl_test_common.hpp"
18 #include "gtest/gtest.h"
19
20 #include "oneapi/dnnl/dnnl.hpp"
21
22 namespace dnnl {
23
24 struct test_inner_product_descr_t {
25 memory::dim mb;
26 memory::dim ic;
27 memory::dim oc;
28 memory::dim kd, kh, kw;
29 };
30
31 template <typename data_t>
compute_ref_inner_product_bwd_bias(const test_inner_product_descr_t & ipd,const memory & diff_dst,const memory & diff_bias)32 void compute_ref_inner_product_bwd_bias(const test_inner_product_descr_t &ipd,
33 const memory &diff_dst, const memory &diff_bias) {
34 auto diff_bias_data = map_memory<data_t>(diff_bias);
35 auto diff_dst_data = map_memory<data_t>(diff_dst);
36
37 const memory::desc diff_bias_d = diff_bias.get_desc();
38 const memory::desc diff_dst_d = diff_dst.get_desc();
39 const dnnl::impl::memory_desc_wrapper diff_bias_mdw(diff_bias_d.data);
40 const dnnl::impl::memory_desc_wrapper diff_dst_mdw(diff_dst_d.data);
41
42 dnnl::impl::parallel_nd(ipd.oc, [&](memory::dim oc) {
43 data_t *db = &diff_bias_data[diff_bias_mdw.off_l(oc, true)];
44 *db = data_t(0);
45 for (memory::dim n = 0; n < ipd.mb; ++n) {
46 *db += diff_dst_data[diff_dst_mdw.off_l(n * ipd.oc + oc, true)];
47 }
48 });
49 }
50
51 template <typename data_t>
compute_ref_inner_product_bwd_weights(int ndims,const test_inner_product_descr_t & ipd,const memory & src,const memory & diff_dst,const memory & diff_weights)52 void compute_ref_inner_product_bwd_weights(int ndims,
53 const test_inner_product_descr_t &ipd, const memory &src,
54 const memory &diff_dst, const memory &diff_weights) {
55 auto src_data = map_memory<data_t>(src);
56 auto diff_weights_data = map_memory<data_t>(diff_weights);
57 auto diff_dst_data = map_memory<data_t>(diff_dst);
58
59 const memory::desc src_d = src.get_desc();
60 const memory::desc diff_weights_d = diff_weights.get_desc();
61 const memory::desc diff_dst_d = diff_dst.get_desc();
62 const dnnl::impl::memory_desc_wrapper src_mdw(src_d.data);
63 const dnnl::impl::memory_desc_wrapper diff_weights_mdw(diff_weights_d.data);
64 const dnnl::impl::memory_desc_wrapper diff_dst_mdw(diff_dst_d.data);
65
66 auto padded_ic = src_d.data.padded_dims[1];
67
68 bool has_spatial = ipd.kh > 1 || ipd.kw > 1;
69 if (ndims == 5) has_spatial = has_spatial || ipd.kd > 1;
70 dnnl::impl::parallel_nd(
71 ipd.oc, ipd.ic, [&](memory::dim oc, memory::dim ic) {
72 if (has_spatial) {
73 for_(memory::dim kd = 0; kd < ipd.kd; ++kd)
74 for_(memory::dim kh = 0; kh < ipd.kh; ++kh)
75 for (memory::dim kw = 0; kw < ipd.kw; ++kw) {
76 memory::dim dwidx
77 = oc * padded_ic * ipd.kd * ipd.kh * ipd.kw
78 + ic * ipd.kd * ipd.kh * ipd.kw
79 + kd * ipd.kh * ipd.kw + kh * ipd.kw + kw;
80 data_t *dw = &diff_weights_data[diff_weights_mdw.off_l(
81 dwidx, true)];
82 *dw = data_t(0);
83 for (memory::dim n = 0; n < ipd.mb; ++n) {
84 memory::dim ddidx = n * ipd.oc + oc;
85 memory::dim sidx
86 = n * padded_ic * ipd.kd * ipd.kh * ipd.kw
87 + ic * ipd.kd * ipd.kh * ipd.kw
88 + kd * ipd.kh * ipd.kw + kh * ipd.kw + kw;
89 *dw += diff_dst_data[diff_dst_mdw.off_l(
90 ddidx, true)]
91 * src_data[src_mdw.off_l(sidx, true)];
92 }
93 }
94 } else {
95 memory::dim dwidx = oc * ipd.ic + ic;
96 data_t *dw = &diff_weights_data[diff_weights_mdw.off_l(
97 dwidx, true)];
98 *dw = data_t(0);
99 for (memory::dim n = 0; n < ipd.mb; ++n) {
100 memory::dim ddidx = n * ipd.oc + oc;
101 memory::dim sidx = n * ipd.ic + ic;
102 *dw += diff_dst_data[diff_dst_mdw.off_l(ddidx, true)]
103 * src_data[src_mdw.off_l(sidx, true)];
104 }
105 }
106 });
107 }
108
/// One parameter set for the inner_product_backward_weights tests.
///
/// The INSTANTIATE_* lists brace-initialize this struct positionally. Many
/// entries omit the trailing members, which are then value-initialized
/// (expect_to_fail == false), so the field order must not change.
struct inprod_test_params_t {
    // Requested memory format tags. A diff_bias_format of format_tag::undef
    // makes Test() create the primitive without bias.
    memory::format_tag src_format;
    memory::format_tag diff_weights_format;
    memory::format_tag diff_bias_format;
    memory::format_tag diff_dst_format;
    // Nominal rank of src / diff_weights (EXPAND_SIZES_* set 3, 4 or 5).
    // Test() appends spatial dims only when some kernel extent exceeds 1.
    int ndims;
    test_inner_product_descr_t test_ipd;
    // When true, the test expects a failure reporting expected_status.
    bool expect_to_fail;
    dnnl_status_t expected_status;
};
119
120 template <typename data_t>
121 class inner_product_test_bwd_weights_t
122 : public ::testing::TestWithParam<inprod_test_params_t> {
123 protected:
SetUp()124 void SetUp() override {
125 auto p = ::testing::TestWithParam<inprod_test_params_t>::GetParam();
126 SKIP_IF_CUDA(
127 !cuda_check_format_tags(p.src_format, p.diff_weights_format,
128 p.diff_bias_format, p.diff_dst_format),
129 "Unsupported format tag");
130 SKIP_IF_CUDA(p.ndims > 5, "Unsupported number of dimensions");
131 catch_expected_failures(
132 [=]() { Test(); }, p.expect_to_fail, p.expected_status);
133 }
134
cuda_check_format_tags(memory::format_tag src_format,memory::format_tag diff_wei_format,memory::format_tag diff_bia_format,memory::format_tag diff_dst_format)135 bool cuda_check_format_tags(memory::format_tag src_format,
136 memory::format_tag diff_wei_format,
137 memory::format_tag diff_bia_format,
138 memory::format_tag diff_dst_format) {
139 bool src_ok = src_format == memory::format_tag::ncdhw
140 || src_format == memory::format_tag::ndhwc
141 || src_format == memory::format_tag::nchw
142 || src_format == memory::format_tag::nhwc
143 || src_format == memory::format_tag::ncw
144 || src_format == memory::format_tag::nwc
145 || src_format == memory::format_tag::nc
146 || src_format == memory::format_tag::any;
147 bool diff_wei_ok = diff_wei_format == memory::format_tag::oidhw
148 || diff_wei_format == memory::format_tag::odhwi
149 || diff_wei_format == memory::format_tag::dhwio
150 || diff_wei_format == memory::format_tag::oihw
151 || diff_wei_format == memory::format_tag::ohwi
152 || diff_wei_format == memory::format_tag::hwio
153 || diff_wei_format == memory::format_tag::oiw
154 || diff_wei_format == memory::format_tag::owi
155 || diff_wei_format == memory::format_tag::wio
156 || diff_wei_format == memory::format_tag::io
157 || diff_wei_format == memory::format_tag::oi
158 || diff_wei_format == memory::format_tag::any;
159 bool diff_bia_ok = diff_bia_format == memory::format_tag::undef
160 || diff_bia_format == memory::format_tag::any
161 || diff_bia_format == memory::format_tag::a
162 || diff_bia_format == memory::format_tag::x;
163 bool diff_dst_ok = diff_dst_format == memory::format_tag::any
164 || diff_dst_format == memory::format_tag::nc;
165
166 return src_ok && diff_wei_ok && diff_bia_ok && diff_dst_ok;
167 }
168
Test()169 void Test() {
170 auto p = ::testing::TestWithParam<inprod_test_params_t>::GetParam();
171 test_inner_product_descr_t ipd = p.test_ipd;
172
173 bool has_spatial = ipd.kh > 1 || ipd.kw > 1;
174 if (p.ndims == 5) has_spatial = has_spatial || ipd.kd > 1;
175
176 bool with_bias = p.diff_bias_format != memory::format_tag::undef;
177
178 auto eng = get_test_engine();
179 auto strm = make_stream(eng);
180 memory::data_type data_type = data_traits<data_t>::data_type;
181 ASSERT_EQ(data_type, dnnl::memory::data_type::f32);
182
183 memory::dims src_dims = {ipd.mb, ipd.ic},
184 diff_wei_dims = {ipd.oc, ipd.ic};
185 if (has_spatial) {
186 if (p.ndims == 5) {
187 src_dims.push_back(ipd.kd);
188 diff_wei_dims.push_back(ipd.kd);
189 }
190 if (p.ndims >= 4) {
191 src_dims.push_back(ipd.kh);
192 diff_wei_dims.push_back(ipd.kh);
193 }
194 if (p.ndims >= 3) {
195 src_dims.push_back(ipd.kw);
196 diff_wei_dims.push_back(ipd.kw);
197 }
198 }
199 auto ip_src_desc = create_md(src_dims, data_type, p.src_format);
200 auto ip_diff_weights_desc
201 = create_md(diff_wei_dims, data_type, p.diff_weights_format);
202 auto ip_diff_dst_desc
203 = create_md({ipd.mb, ipd.oc}, data_type, p.diff_dst_format);
204 auto ip_diff_bias_desc = with_bias
205 ? create_md({ipd.oc}, data_type, p.diff_bias_format)
206 : create_md({}, data_type, p.diff_bias_format);
207
208 // Create inner product forward (hint for backward)
209 auto ip_fwd_desc = inner_product_forward::desc(prop_kind::forward,
210 ip_src_desc, ip_diff_weights_desc, ip_diff_dst_desc);
211 auto ip_fwd_pdesc
212 = inner_product_forward::primitive_desc(ip_fwd_desc, eng);
213
214 // Create inner product backward
215 auto ip_desc = with_bias
216 ? inner_product_backward_weights::desc(ip_src_desc,
217 ip_diff_weights_desc, ip_diff_bias_desc,
218 ip_diff_dst_desc)
219 : inner_product_backward_weights::desc(
220 ip_src_desc, ip_diff_weights_desc, ip_diff_dst_desc);
221
222 auto ip_primitive_desc = inner_product_backward_weights::primitive_desc(
223 ip_desc, eng, ip_fwd_pdesc);
224 ip_primitive_desc = inner_product_backward_weights::primitive_desc(
225 ip_primitive_desc.get()); // test construction from a C pd
226
227 auto ip_src = test::make_memory(ip_primitive_desc.src_desc(), eng);
228 auto ip_diff_dst
229 = test::make_memory(ip_primitive_desc.diff_dst_desc(), eng);
230 auto ip_diff_weights
231 = test::make_memory(ip_primitive_desc.diff_weights_desc(), eng);
232 auto diff_weights_ref
233 = test::make_memory(ip_primitive_desc.diff_weights_desc(), eng);
234 auto ip_diff_bias
235 = test::make_memory(ip_primitive_desc.diff_bias_desc(), eng);
236 auto diff_bias_ref
237 = test::make_memory(ip_primitive_desc.diff_bias_desc(), eng);
238
239 fill_data<data_t>(
240 ip_src.get_desc().get_size() / sizeof(data_t), ip_src);
241 fill_data<data_t>(ip_diff_dst.get_desc().get_size() / sizeof(data_t),
242 ip_diff_dst);
243
244 check_zero_tail<data_t>(1, ip_src);
245 check_zero_tail<data_t>(1, ip_diff_dst);
246
247 ASSERT_TRUE(ip_primitive_desc.query_md(query::exec_arg_md, DNNL_ARG_SRC)
248 == ip_primitive_desc.src_desc());
249 ASSERT_TRUE(ip_primitive_desc.query_md(
250 query::exec_arg_md, DNNL_ARG_DIFF_DST)
251 == ip_primitive_desc.diff_dst_desc());
252 ASSERT_TRUE(ip_primitive_desc.query_md(
253 query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS)
254 == ip_primitive_desc.diff_weights_desc());
255 ASSERT_TRUE(ip_primitive_desc.query_md(
256 query::exec_arg_md, DNNL_ARG_DIFF_BIAS)
257 == ip_primitive_desc.diff_bias_desc());
258
259 inner_product_backward_weights(ip_primitive_desc)
260 .execute(strm,
261 {{DNNL_ARG_DIFF_DST, ip_diff_dst},
262 {DNNL_ARG_SRC, ip_src},
263 {DNNL_ARG_DIFF_WEIGHTS, ip_diff_weights},
264 {DNNL_ARG_DIFF_BIAS, ip_diff_bias}});
265 strm.wait();
266
267 compute_ref_inner_product_bwd_weights<data_t>(
268 p.ndims, ipd, ip_src, ip_diff_dst, diff_weights_ref);
269 check_zero_tail<data_t>(1, diff_weights_ref);
270
271 compare_data<data_t>(diff_weights_ref, ip_diff_weights);
272
273 check_zero_tail<data_t>(0, ip_diff_weights);
274
275 if (with_bias) {
276 compute_ref_inner_product_bwd_bias<data_t>(
277 ipd, ip_diff_dst, diff_bias_ref);
278 compare_data<data_t>(diff_bias_ref, ip_diff_bias);
279 }
280 }
281 };
282
// f32 instantiation of the fixture and its parameter type, used by the
// TEST_P / INSTANTIATE_TEST_SUITE_P declarations below.
using inner_product_test_float = inner_product_test_bwd_weights_t<float>;
using inprod_test_params_float = inprod_test_params_t;
285
// Each EXPAND_SIZES_* macro expands to the `ndims, {mb, ic, oc, kd, kh, kw}`
// tail of an inprod_test_params_t initializer. The 2D and 1D forms fill the
// spatial extents they do not take with 1.
#define EXPAND_SIZES_ND(nd, ...) \
    nd, { __VA_ARGS__ }
#define EXPAND_SIZES_3D(...) EXPAND_SIZES_ND(5, __VA_ARGS__)
#define EXPAND_SIZES_2D(mb, ic, oc, kh, kw) \
    EXPAND_SIZES_ND(4, mb, ic, oc, 1, kh, kw)
#define EXPAND_SIZES_1D(mb, ic, oc, kw) \
    EXPAND_SIZES_ND(3, mb, ic, oc, 1, 1, kw)
292
// Intentionally empty: all work happens in the fixture's SetUp(), which runs
// Test() under catch_expected_failures().
TEST_P(inner_product_test_float, TestsInnerProduct) {}
294
// Zero-sized minibatch (mb == 0) with bias. expect_to_fail is omitted
// (value-initialized to false), so creation and execution must succeed.
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeightsZeroDim,
        inner_product_test_float,
        ::testing::Values(inprod_test_params_float {memory::format_tag::any,
                memory::format_tag::any, memory::format_tag::any,
                memory::format_tag::any, EXPAND_SIZES_2D(0, 32, 48, 6, 6)}));
300
// Expected-failure cases: zero input channels, a negative minibatch and a
// negative channel count must all be rejected with dnnl_invalid_arguments.
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeightsEF,
        inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_2D(2, 0, 48, 6, 6), true,
                        dnnl_invalid_arguments},
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_2D(-1, 32, 48, 6, 6), true,
                        dnnl_invalid_arguments},
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_2D(2, -1, 48, 6, 6), true,
                        dnnl_invalid_arguments}));
319
// Channel-blocked layouts (8c/16c) whose ic (17, 10, 5) is not a multiple of
// the block size, so src and diff_weights carry padded channels; no bias.
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeightsNoBias_padded,
        inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 17, 5, 3, 3)},
                inprod_test_params_float {memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 10, 5, 3, 3)},
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::aBcd8b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 17, 5, 3, 3)},
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::aBcd8b, memory::format_tag::undef,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 5, 15, 3, 3)}));
339
// Same padded-channel shapes as the NoBias_padded suite, but with a 1D
// diff_bias (format x). The GPU_ prefix presumably limits this suite to GPU
// engines; see its definition in dnnl_test_common.hpp.
GPU_INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeights_padded,
        inner_product_test_float,
        ::testing::Values(inprod_test_params_float {memory::format_tag::nChw16c,
                                  memory::format_tag::aBcd16b,
                                  memory::format_tag::x, memory::format_tag::nc,
                                  EXPAND_SIZES_2D(2, 17, 5, 3, 3)},
                inprod_test_params_float {memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 10, 5, 3, 3)},
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::aBcd8b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 17, 5, 3, 3)},
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::aBcd8b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 5, 15, 3, 3)}));
358
359 INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeightsNoBias,
360 inner_product_test_float,
361 ::testing::Values(
362 inprod_test_params_float {memory::format_tag::any,
363 memory::format_tag::any, memory::format_tag::undef,
364 memory::format_tag::any,
365 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
366 inprod_test_params_float {memory::format_tag::any,
367 memory::format_tag::any, memory::format_tag::undef,
368 memory::format_tag::any,
369 EXPAND_SIZES_2D(2, 1024, 48, 2, 2)},
370 inprod_test_params_float {memory::format_tag::nwc,
371 memory::format_tag::owi, memory::format_tag::undef,
372 memory::format_tag::nc, EXPAND_SIZES_1D(2, 32, 48, 6)},
373 inprod_test_params_float {memory::format_tag::nwc,
374 memory::format_tag::wio, memory::format_tag::undef,
375 memory::format_tag::nc, EXPAND_SIZES_1D(2, 32, 48, 6)},
376 inprod_test_params_float {memory::format_tag::nwc,
377 memory::format_tag::oiw, memory::format_tag::undef,
378 memory::format_tag::nc, EXPAND_SIZES_1D(2, 32, 48, 6)},
379 inprod_test_params_float {memory::format_tag::ncw,
380 memory::format_tag::oiw, memory::format_tag::undef,
381 memory::format_tag::nc, EXPAND_SIZES_1D(2, 32, 48, 6)},
382 inprod_test_params_float {memory::format_tag::ncw,
383 memory::format_tag::wio, memory::format_tag::undef,
384 memory::format_tag::nc, EXPAND_SIZES_1D(2, 32, 48, 6)},
385 inprod_test_params_float {memory::format_tag::ncw,
386 memory::format_tag::owi, memory::format_tag::undef,
387 memory::format_tag::nc, EXPAND_SIZES_1D(2, 32, 48, 6)},
388 inprod_test_params_float {memory::format_tag::nhwc,
389 memory::format_tag::hwio, memory::format_tag::undef,
390 memory::format_tag::nc,
391 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
392 inprod_test_params_float {memory::format_tag::nhwc,
393 memory::format_tag::oihw, memory::format_tag::undef,
394 memory::format_tag::nc,
395 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
396 inprod_test_params_float {memory::format_tag::nhwc,
397 memory::format_tag::ohwi, memory::format_tag::undef,
398 memory::format_tag::nc,
399 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
400 inprod_test_params_float {memory::format_tag::nchw,
401 memory::format_tag::oihw, memory::format_tag::undef,
402 memory::format_tag::nc,
403 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
404 inprod_test_params_float {memory::format_tag::nchw,
405 memory::format_tag::ohwi, memory::format_tag::undef,
406 memory::format_tag::nc,
407 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
408 inprod_test_params_float {memory::format_tag::nchw,
409 memory::format_tag::hwio, memory::format_tag::undef,
410 memory::format_tag::nc,
411 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
412 inprod_test_params_float {memory::format_tag::nChw8c,
413 memory::format_tag::aBcd8b, memory::format_tag::undef,
414 memory::format_tag::nc,
415 EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
416 inprod_test_params_float {memory::format_tag::nChw16c,
417 memory::format_tag::aBcd16b, memory::format_tag::undef,
418 memory::format_tag::nc,
419 EXPAND_SIZES_2D(2, 32, 1000, 6, 6)},
420 inprod_test_params_float {memory::format_tag::any,
421 memory::format_tag::aBcd16b, memory::format_tag::undef,
422 memory::format_tag::nc,
423 EXPAND_SIZES_2D(2, 32, 1000, 6, 6)},
424 inprod_test_params_float {memory::format_tag::nChw16c,
425 memory::format_tag::any, memory::format_tag::undef,
426 memory::format_tag::nc,
427 EXPAND_SIZES_2D(2, 32, 1000, 6, 6)},
428 inprod_test_params_float {memory::format_tag::nChw16c,
429 memory::format_tag::aBcd16b, memory::format_tag::undef,
430 memory::format_tag::nc,
431 EXPAND_SIZES_2D(2, 32, 1000, 6, 6)},
432 inprod_test_params_float {memory::format_tag::nc,
433 memory::format_tag::oi, memory::format_tag::undef,
434 memory::format_tag::nc,
435 EXPAND_SIZES_2D(2, 32, 1152, 1, 1)},
436 inprod_test_params_float {memory::format_tag::nc,
437 memory::format_tag::oi, memory::format_tag::undef,
438 memory::format_tag::nc, EXPAND_SIZES_2D(2, 2, 4, 1, 1)},
439 inprod_test_params_float {memory::format_tag::nc,
440 memory::format_tag::io, memory::format_tag::undef,
441 memory::format_tag::nc,
442 EXPAND_SIZES_2D(2, 8, 16, 1, 1)}));
443
// With-bias variant: diff_bias uses format `any` or `x` alongside plain,
// channel-blocked and non-spatial 2D (nc/oi/io) layouts.
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeights,
        inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_2D(2, 32, 1024, 2, 2)},
                inprod_test_params_float {memory::format_tag::nhwc,
                        memory::format_tag::hwio, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nhwc,
                        memory::format_tag::oihw, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nchw,
                        memory::format_tag::oihw, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nChw8c,
                        memory::format_tag::aBcd8b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 48, 6, 6)},
                inprod_test_params_float {memory::format_tag::nChw16c,
                        memory::format_tag::aBcd16b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 1000, 6, 6)},
                inprod_test_params_float {memory::format_tag::nc,
                        memory::format_tag::oi, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 32, 1152, 1, 1)},
                inprod_test_params_float {memory::format_tag::nc,
                        memory::format_tag::oi, memory::format_tag::x,
                        memory::format_tag::nc, EXPAND_SIZES_2D(2, 2, 4, 1, 1)},
                inprod_test_params_float {memory::format_tag::nc,
                        memory::format_tag::io, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_2D(2, 8, 16, 1, 1)}));
486
// 5D src / diff_weights (ndims == 5, depth included) with bias, covering
// `any`, plain ncdhw/ndhwc and channel-blocked (8c/16c) layouts.
INSTANTIATE_TEST_SUITE_P(TestInnerProductBackwardWeights3D,
        inner_product_test_float,
        ::testing::Values(
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::any,
                        memory::format_tag::any, memory::format_tag::any,
                        memory::format_tag::any,
                        EXPAND_SIZES_3D(2, 32, 1024, 2, 2, 2)},
                inprod_test_params_float {memory::format_tag::ncdhw,
                        memory::format_tag::oidhw, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::ncdhw,
                        memory::format_tag::dhwio, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::ncdhw,
                        memory::format_tag::odhwi, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::nCdhw8c,
                        memory::format_tag::aBcde8b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 48, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::nCdhw16c,
                        memory::format_tag::aBcde16b, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 32, 1000, 6, 6, 6)},
                inprod_test_params_float {memory::format_tag::ndhwc,
                        memory::format_tag::dhwio, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 16, 48, 3, 3, 3)},
                inprod_test_params_float {memory::format_tag::ndhwc,
                        memory::format_tag::odhwi, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 16, 48, 3, 3, 3)},
                inprod_test_params_float {memory::format_tag::ndhwc,
                        memory::format_tag::oidhw, memory::format_tag::x,
                        memory::format_tag::nc,
                        EXPAND_SIZES_3D(2, 16, 48, 3, 3, 3)}));
530 } // namespace dnnl
531