1 /*******************************************************************************
2 * Copyright 2019-2020 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include "dnnl_test_common.hpp"
18 #include "gtest/gtest.h"
19 
20 #include "oneapi/dnnl/dnnl.hpp"
21 
22 namespace dnnl {
23 
24 using tag = memory::format_tag;
25 
// Parameter pack for one softmax test case. NOTE: members are brace-
// initialized positionally in the INSTANTIATE_* lists below, so their
// declaration order must not change.
template <typename data_t>
struct softmax_test_params_t {
    prop_kind aprop_kind; // forward_training / forward_inference / backward_data
    tag memory_format; // format tag for src/dst
    tag diff_memory_format; // format tag for diff_src/diff_dst (backward only)
    memory::dims dims; // problem dimensions
    int axis; // axis the softmax is computed over
    bool expect_to_fail; // true when primitive creation must fail
    dnnl_status_t expected_status; // status expected when it fails
};
36 
37 template <typename data_t>
38 class softmax_test_t
39     : public ::testing::TestWithParam<softmax_test_params_t<data_t>> {
40 private:
41     softmax_test_params_t<data_t> p;
42     memory dst, workspace;
43     std::shared_ptr<softmax_forward::primitive_desc> pd_fwd_hint;
44 
45 protected:
SetUp()46     void SetUp() override {
47         p = ::testing::TestWithParam<softmax_test_params_t<data_t>>::GetParam();
48 
49         SKIP_IF_CUDA(!cuda_check_format_tag(p.memory_format),
50                 "Unsupported format tag");
51         SKIP_IF_CUDA(!cuda_check_format_tag(p.diff_memory_format),
52                 "Unsupported format tag");
53         SKIP_IF_CUDA(data_traits<data_t>::data_type == memory::data_type::bf16,
54                 "Unsupported datatype for CUDA");
55 
56         catch_expected_failures(
57                 [=]() { Test(); }, p.expect_to_fail, p.expected_status);
58     }
cuda_check_format_tag(memory::format_tag tag)59     bool cuda_check_format_tag(memory::format_tag tag) {
60         return (tag != memory::format_tag::aBcd8b
61                 && tag != memory::format_tag::aBc16b);
62     }
63 
Forward()64     void Forward() {
65         // softmax specific types and values
66         using op_desc_t = softmax_forward::desc;
67         using pd_t = softmax_forward::primitive_desc;
68         allows_attr_t aa {false}; // doesn't support anything
69 
70         auto eng = get_test_engine();
71         auto strm = make_stream(eng);
72         prop_kind pk = p.aprop_kind == prop_kind::backward_data
73                 ? prop_kind::forward_training
74                 : p.aprop_kind;
75         auto prec = data_traits<data_t>::data_type;
76         auto mem_desc = memory::desc(p.dims, prec, p.memory_format);
77 
78         // default op desc ctor
79         auto op_desc = op_desc_t();
80         // regular op desc ctor
81         op_desc = op_desc_t(pk, mem_desc, p.axis);
82 
83         // default pd ctor
84         auto pd = pd_t();
85         // regular pd ctor
86         ASSERT_NO_THROW(pd = pd_t(op_desc, eng));
87         // test all pd ctors
88         test_fwd_pd_constructors<op_desc_t, pd_t>(op_desc, pd, aa);
89         pd_fwd_hint = std::make_shared<pd_t>(pd);
90 
91         // default primitive ctor
92         auto softmax = softmax_forward();
93         // regular primitive ctor
94         softmax = softmax_forward(pd);
95 
96         // query for data_desc from pd via src
97         const auto data_desc = pd.src_desc();
98         // query for data_desc from pd via dst
99         ASSERT_TRUE(pd.dst_desc() == data_desc);
100         // query for data_desc via exec arg number of src
101         ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_SRC) == data_desc);
102         // query for data_desc via exec arg number of dst
103         ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DST) == data_desc);
104 
105         // query for workspace
106         const auto workspace_desc = pd.workspace_desc();
107 
108         // check primitive returns zero_md for all rest md
109         ASSERT_TRUE(pd.weights_desc().is_zero());
110         ASSERT_TRUE(pd.diff_src_desc().is_zero());
111         ASSERT_TRUE(pd.diff_dst_desc().is_zero());
112         ASSERT_TRUE(pd.diff_weights_desc().is_zero());
113 
114         auto src = test::make_memory(data_desc, eng);
115         dst = test::make_memory(data_desc, eng);
116         workspace = test::make_memory(workspace_desc, eng);
117 
118         auto test_with_given_fill = [&](data_t mean, data_t var) {
119             fill_data<data_t>(
120                     data_desc.get_size() / sizeof(data_t), src, mean, var);
121             check_zero_tail<data_t>(1, src);
122 
123             // test out-place mode
124             softmax.execute(strm,
125                     {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst},
126                             {DNNL_ARG_WORKSPACE, workspace}});
127             strm.wait();
128             check_zero_tail<data_t>(0, dst);
129 
130             // test in-place mode
131             if (p.aprop_kind != prop_kind::backward_data) {
132                 softmax.execute(strm,
133                         {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, src},
134                                 {DNNL_ARG_WORKSPACE, workspace}});
135                 strm.wait();
136                 check_zero_tail<data_t>(0, src);
137             }
138         };
139 
140         test_with_given_fill(200, 1);
141     }
142 
Backward()143     void Backward() {
144         // softmax specific types and values
145         using op_desc_t = softmax_backward::desc;
146         using pd_t = softmax_backward::primitive_desc;
147         using hint_pd_t = softmax_forward::primitive_desc;
148         allows_attr_t aa {false}; // doesn't support anything
149 
150         auto eng = get_test_engine();
151         auto strm = make_stream(eng);
152         auto prec = data_traits<data_t>::data_type;
153         auto mem_desc = memory::desc(p.dims, prec, p.memory_format);
154         auto diff_mem_desc = memory::desc(p.dims, prec, p.diff_memory_format);
155 
156         // default op desc ctor
157         auto op_desc = op_desc_t();
158         // regular op desc ctor
159         op_desc = op_desc_t(diff_mem_desc, mem_desc, p.axis);
160 
161         // default pd ctor
162         auto pd = pd_t();
163         // regular pd ctor
164         ASSERT_NO_THROW(pd = pd_t(op_desc, eng, *pd_fwd_hint));
165         // test all pd ctors
166         test_bwd_pd_constructors<op_desc_t, pd_t, hint_pd_t>(
167                 op_desc, pd, *pd_fwd_hint, aa);
168 
169         // default primitive ctor
170         auto softmax = softmax_backward();
171         // regular primitive ctor
172         softmax = softmax_backward(pd);
173 
174         // query for diff_data_desc from pd via diff_src
175         const auto diff_data_desc = pd.diff_src_desc();
176         // query for diff_data_desc from pd via diff_dst
177         ASSERT_TRUE(pd.diff_dst_desc() == diff_data_desc);
178         // query for diff_data_desc via exec arg number of src
179         ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC)
180                 == diff_data_desc);
181         // query for diff_data_desc via exec arg number of dst
182         ASSERT_TRUE(pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST)
183                 == diff_data_desc);
184 
185         // check primitive returns zero_md for all rest md
186         ASSERT_TRUE(pd.src_desc().is_zero());
187         ASSERT_TRUE(pd.weights_desc().is_zero());
188         ASSERT_TRUE(pd.diff_weights_desc().is_zero());
189 
190         auto diff_src = test::make_memory(diff_data_desc, eng);
191         auto diff_dst = test::make_memory(diff_data_desc, eng);
192 
193         auto test_with_given_fill = [&](data_t mean, data_t var) {
194             // Fill the softmax backward diffs
195             fill_data<data_t>(diff_data_desc.get_size() / sizeof(data_t),
196                     diff_dst, data_t(0), data_t(1));
197             check_zero_tail<data_t>(1, diff_dst);
198 
199             softmax.execute(strm,
200                     {{DNNL_ARG_DST, dst}, {DNNL_ARG_DIFF_DST, diff_dst},
201                             {DNNL_ARG_DIFF_SRC, diff_src},
202                             {DNNL_ARG_WORKSPACE, workspace}});
203             strm.wait();
204 
205             check_zero_tail<data_t>(0, diff_src);
206         };
207 
208         test_with_given_fill(0, 1);
209     }
210 
Test()211     void Test() {
212         Forward();
213         if (p.aprop_kind == prop_kind::backward_data) Backward();
214     }
215 };
216 
// Concrete fixture types: gtest requires a distinct type name per
// instantiated test suite, even when the instantiation (float) is shared.
using softmax_forward_test_float = softmax_test_t<float>;
using softmax_forward_test_half = softmax_test_t<float16_t>;
using softmax_forward_test_bfloat16 = softmax_test_t<bfloat16_t>;

using softmax_backward_test_float = softmax_test_t<float>;

// Shorthand for the parameter struct used in the instantiation lists below.
template <typename dt>
using test_params = softmax_test_params_t<dt>;
225 
// Forward fp32 coverage. The first two cases are negative tests: a negative
// dim and an out-of-range axis (5 for a 4D md) must fail primitive creation
// with dnnl_invalid_arguments. Zero-channel cases check empty-problem
// handling; the remaining cases sweep axes and plain/blocked formats.
TEST_P(softmax_forward_test_float, TestsSoftmax) {}
INSTANTIATE_TEST_SUITE_P(TestSoftmaxForwardFloat, softmax_forward_test_float,
        ::testing::Values(test_params<float> {prop_kind::forward_training,
                                  tag::nchw, tag::undef, {2, -2, 128, 256}, 0,
                                  true, dnnl_invalid_arguments},
                test_params<float> {prop_kind::forward_training, tag::nchw,
                        tag::undef, {2, 2, 128, 256}, 5, true,
                        dnnl_invalid_arguments},
                test_params<float> {prop_kind::forward_training, tag::nchw,
                        tag::undef, {2, 0, 5, 5}, 0},
                test_params<float> {prop_kind::forward_training, tag::nchw,
                        tag::undef, {2, 0, 5, 5}, 1},
                test_params<float> {prop_kind::forward_training, tag::nchw,
                        tag::undef, {2, 19, 16, 64}, 1},
                test_params<float> {prop_kind::forward_training, tag::nchw,
                        tag::undef, {1, 8, 128, 1024}, 3},
                test_params<float> {prop_kind::forward_inference, tag::nc,
                        tag::undef, {2, 1000}, 0},
                test_params<float> {prop_kind::forward_inference, tag::nc,
                        tag::undef, {2, 1000}, 1},
                test_params<float> {prop_kind::forward_inference, tag::nc,
                        tag::undef, {1, 13}, 1},
                test_params<float> {prop_kind::forward_inference, tag::ncw,
                        tag::undef, {16, 257, 32}, 1},
                test_params<float> {prop_kind::forward_inference, tag::ncw,
                        tag::undef, {16, 257, 32}, 2},
                test_params<float> {prop_kind::forward_inference, tag::nChw8c,
                        tag::undef, {64, 1011, 1, 1}, 1},
                test_params<float> {prop_kind::forward_inference, tag::nChw8c,
                        tag::undef, {2, 1011, 32, 1}, 2}));
256 
// Forward bf16 coverage, GPU-only instantiation (bf16 is skipped on CUDA in
// SetUp). Same case list as the fp32 forward suite above.
TEST_P(softmax_forward_test_bfloat16, TestsSoftmax) {}
GPU_INSTANTIATE_TEST_SUITE_P(TestSoftmaxForwardBfloat16,
        softmax_forward_test_bfloat16,
        ::testing::Values(test_params<bfloat16_t> {prop_kind::forward_training,
                                  tag::nchw, tag::undef, {2, -2, 128, 256}, 0,
                                  true, dnnl_invalid_arguments},
                test_params<bfloat16_t> {prop_kind::forward_training, tag::nchw,
                        tag::undef, {2, 2, 128, 256}, 5, true,
                        dnnl_invalid_arguments},
                test_params<bfloat16_t> {prop_kind::forward_training, tag::nchw,
                        tag::undef, {2, 0, 5, 5}, 0},
                test_params<bfloat16_t> {prop_kind::forward_training, tag::nchw,
                        tag::undef, {2, 0, 5, 5}, 1},
                test_params<bfloat16_t> {prop_kind::forward_training, tag::nchw,
                        tag::undef, {2, 19, 16, 64}, 1},
                test_params<bfloat16_t> {prop_kind::forward_training, tag::nchw,
                        tag::undef, {1, 8, 128, 1024}, 3},
                test_params<bfloat16_t> {prop_kind::forward_inference, tag::nc,
                        tag::undef, {2, 1000}, 0},
                test_params<bfloat16_t> {prop_kind::forward_inference, tag::nc,
                        tag::undef, {2, 1000}, 1},
                test_params<bfloat16_t> {prop_kind::forward_inference, tag::nc,
                        tag::undef, {1, 13}, 1},
                test_params<bfloat16_t> {prop_kind::forward_inference, tag::ncw,
                        tag::undef, {16, 257, 32}, 1},
                test_params<bfloat16_t> {prop_kind::forward_inference, tag::ncw,
                        tag::undef, {16, 257, 32}, 2},
                test_params<bfloat16_t> {prop_kind::forward_inference,
                        tag::nChw8c, tag::undef, {64, 1011, 1, 1}, 1},
                test_params<bfloat16_t> {prop_kind::forward_inference,
                        tag::nChw8c, tag::undef, {2, 1011, 32, 1}, 2}));
288 
// Forward fp16 coverage, GPU-only instantiation. Same case list as the fp32
// forward suite above.
TEST_P(softmax_forward_test_half, TestsSoftmax) {}
GPU_INSTANTIATE_TEST_SUITE_P(TestSoftmaxForwardHalf, softmax_forward_test_half,
        ::testing::Values(test_params<float16_t> {prop_kind::forward_training,
                                  tag::nchw, tag::undef, {2, -2, 128, 256}, 0,
                                  true, dnnl_invalid_arguments},
                test_params<float16_t> {prop_kind::forward_training, tag::nchw,
                        tag::undef, {2, 2, 128, 256}, 5, true,
                        dnnl_invalid_arguments},
                test_params<float16_t> {prop_kind::forward_training, tag::nchw,
                        tag::undef, {2, 0, 5, 5}, 0},
                test_params<float16_t> {prop_kind::forward_training, tag::nchw,
                        tag::undef, {2, 0, 5, 5}, 1},
                test_params<float16_t> {prop_kind::forward_training, tag::nchw,
                        tag::undef, {2, 19, 16, 64}, 1},
                test_params<float16_t> {prop_kind::forward_training, tag::nchw,
                        tag::undef, {1, 8, 128, 1024}, 3},
                test_params<float16_t> {prop_kind::forward_inference, tag::nc,
                        tag::undef, {2, 1000}, 0},
                test_params<float16_t> {prop_kind::forward_inference, tag::nc,
                        tag::undef, {2, 1000}, 1},
                test_params<float16_t> {prop_kind::forward_inference, tag::nc,
                        tag::undef, {1, 13}, 1},
                test_params<float16_t> {prop_kind::forward_inference, tag::ncw,
                        tag::undef, {16, 257, 32}, 1},
                test_params<float16_t> {prop_kind::forward_inference, tag::ncw,
                        tag::undef, {16, 257, 32}, 2},
                test_params<float16_t> {prop_kind::forward_inference,
                        tag::nChw8c, tag::undef, {64, 1011, 1, 1}, 1},
                test_params<float16_t> {prop_kind::forward_inference,
                        tag::nChw8c, tag::undef, {2, 1011, 32, 1}, 2}));
319 
TEST_P(softmax_backward_test_float,TestsSoftmax)320 TEST_P(softmax_backward_test_float, TestsSoftmax) {}
321 INSTANTIATE_TEST_SUITE_P(TestSoftmaxBackward, softmax_backward_test_float,
322         ::testing::Values(test_params<float> {prop_kind::backward_data,
323                                   tag::nchw, tag::nchw, {2, -2, 128, 256}, 0,
324                                   true, dnnl_invalid_arguments},
325                 test_params<float> {prop_kind::backward_data, tag::nchw,
326                         tag::nchw, {2, 19, 128, 256}, 5, true,
327                         dnnl_invalid_arguments},
328                 test_params<float> {prop_kind::backward_data, tag::nchw,
329                         tag::nchw, {2, 0, 5, 5}, 0},
330                 test_params<float> {prop_kind::backward_data, tag::nhwc,
331                         tag::nchw, {2, 0, 5, 5}, 1},
332                 test_params<float> {prop_kind::backward_data, tag::nchw,
333                         tag::nchw, {2, 19, 16, 64}, 1},
334                 test_params<float> {prop_kind::backward_data, tag::nhwc,
335                         tag::nchw, {1, 8, 128, 1024}, 3},
336                 test_params<float> {prop_kind::backward_data, tag::cn, tag::nc,
337                         {2, 1000}, 0},
338                 test_params<float> {prop_kind::backward_data, tag::nc, tag::nc,
339                         {2, 1000}, 1},
340                 test_params<float> {
341                         prop_kind::backward_data, tag::nc, tag::cn, {1, 13}, 1},
342                 test_params<float> {prop_kind::backward_data, tag::ncw,
343                         tag::ncw, {16, 257, 32}, 1},
344                 test_params<float> {prop_kind::backward_data, tag::nCw16c,
345                         tag::ncw, {16, 257, 32}, 2},
346                 test_params<float> {prop_kind::backward_data, tag::nChw8c,
347                         tag::nChw8c, {64, 1011, 1, 1}, 1},
348                 test_params<float> {prop_kind::backward_data, tag::nchw,
349                         tag::nChw8c, {2, 1011, 32, 1}, 2}));
350 } // namespace dnnl
351