/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <algorithm>
#include <cmath>
#include <functional>
#include <memory>
#include <numeric>
#include <unordered_map>

#include "dnnl_test_common.hpp"
#include "gtest/gtest.h"

#include "oneapi/dnnl/dnnl.hpp"

#define CPU_INST_TEST_CASE(str, ...) \
    CPU_INSTANTIATE_TEST_SUITE_P( \
            str, lnorm_test_t, ::testing::Values(__VA_ARGS__));

namespace dnnl {

struct test_lnorm_params_t {
    memory::format_tag data_tag;
    memory::format_tag stat_tag;
    memory::format_tag diff_tag;
    memory::dims dims;
    float epsilon;
    bool expect_to_fail;
    dnnl_status_t expected_status;
};

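// A hypothetical parameter set, for illustration only (the actual test cases
// are included from "layer_normalization.h" at the end of this file):
//   test_lnorm_params_t {memory::format_tag::abc, memory::format_tag::ab,
//           memory::format_tag::abc, {2, 16, 64}, 1e-5f, false, dnnl_success}

// Fill a memory object with test data; the element count is derived from the
// size reported by its memory descriptor.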
template <typename T>
void fill(const memory &m) {
    auto numElements = m.get_desc().get_size() / sizeof(T);
    fill_data<T>(numElements, m);
}

class lnorm_test_t : public ::testing::TestWithParam<test_lnorm_params_t> {
private:
    std::shared_ptr<test_memory> src, dst, diff_src, diff_dst;
    memory weights, bias, diff_weights, diff_bias, mean, variance;

    std::shared_ptr<memory::desc> data_d;
    std::shared_ptr<memory::desc> stat_d;
    std::shared_ptr<memory::desc> diff_d;

    layer_normalization_forward::primitive_desc lnorm_fwd_pd;
    layer_normalization_backward::primitive_desc lnorm_bwd_pd;

    test_lnorm_params_t p;
    engine eng;
    stream strm;

protected:
    void SetUp() override {
        SKIP_IF_CUDA(true, "Layer normalization not supported by CUDA.");
        p = ::testing::TestWithParam<decltype(p)>::GetParam();
        catch_expected_failures(
                [=]() { Test(); }, p.expect_to_fail, p.expected_status);
    }

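    // Build the memory descriptors from the test parameters, then exercise
    // the forward and backward passes under each supported flag combination.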
    void Test() {
        eng = get_test_engine();
        strm = make_stream(eng);

        data_d = std::make_shared<memory::desc>(
                p.dims, memory::data_type::f32, p.data_tag);
        // Statistics span all dimensions but the last (normalization) axis.
        memory::dims stat_dims(p.dims.begin(), p.dims.end() - 1);
        stat_d = std::make_shared<memory::desc>(
                stat_dims, memory::data_type::f32, p.stat_tag);
        diff_d = std::make_shared<memory::desc>(
                p.dims, memory::data_type::f32, p.diff_tag);

        src = std::make_shared<test_memory>(*data_d, eng);
        dst = std::make_shared<test_memory>(*data_d, eng);
        diff_src = std::make_shared<test_memory>(*diff_d, eng);
        diff_dst = std::make_shared<test_memory>(*diff_d, eng);

        auto training = prop_kind::forward_training;
        auto inference = prop_kind::forward_inference;

        using flags = normalization_flags;
        Forward(training);
        Forward(training, flags::use_global_stats);
        Forward(training, flags::use_scale_shift);
        Forward(training, flags::use_scale);
        Forward(training, flags::use_shift);
        Forward(training, flags::use_scale | flags::use_shift);
        Forward(training, flags::use_scale_shift | flags::use_global_stats);
        Forward(inference);
        Forward(inference, flags::use_global_stats);
        Forward(inference, flags::use_scale_shift);

        Backward(prop_kind::backward_data);
        Backward(prop_kind::backward_data, flags::use_global_stats);
        Backward(prop_kind::backward, flags::use_scale_shift);
        Backward(prop_kind::backward, flags::use_scale);
        Backward(prop_kind::backward, flags::use_shift);
        Backward(prop_kind::backward, flags::use_scale | flags::use_shift);
        Backward(prop_kind::backward,
                flags::use_scale_shift | flags::use_global_stats);
    }

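    // Create a forward primitive descriptor, validate its queryable memory
    // descriptors, allocate and fill the required buffers, then execute the
    // primitive and check the output against a reference computation.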
    void Forward(prop_kind pk,
            normalization_flags flags = normalization_flags::none) {
        fwd_iface_test_stat_any(pk, flags);

        bool useScaleShift
                = (bool)(flags & normalization_flags::use_scale_shift);
        bool useScale = (bool)(flags & normalization_flags::use_scale);
        bool useShift = (bool)(flags & normalization_flags::use_shift);
        bool useGlobalStats
                = (bool)(flags & normalization_flags::use_global_stats);
        bool isTraining = pk == prop_kind::forward_training;

        auto lnorm_fwd_d = layer_normalization_forward::desc(
                pk, *data_d, *stat_d, p.epsilon, flags);

        lnorm_fwd_pd
                = layer_normalization_forward::primitive_desc(lnorm_fwd_d, eng);
        lnorm_fwd_pd = layer_normalization_forward::primitive_desc(
                lnorm_fwd_pd.get()); // test construction from a C pd

        ASSERT_TRUE(lnorm_fwd_pd.query_md(query::exec_arg_md, DNNL_ARG_SRC)
                == lnorm_fwd_pd.src_desc());
        ASSERT_TRUE(lnorm_fwd_pd.query_md(query::exec_arg_md, DNNL_ARG_DST)
                == lnorm_fwd_pd.dst_desc());
        ASSERT_TRUE(lnorm_fwd_pd.query_md(query::exec_arg_md, DNNL_ARG_MEAN)
                == lnorm_fwd_pd.mean_desc());
        ASSERT_TRUE(lnorm_fwd_pd.query_md(query::exec_arg_md, DNNL_ARG_VARIANCE)
                == lnorm_fwd_pd.variance_desc());
        ASSERT_TRUE(
                lnorm_fwd_pd.query_md(query::exec_arg_md, DNNL_ARG_SCALE_SHIFT)
                == lnorm_fwd_pd.weights_desc());

        if (useScaleShift || useScale)
            weights = test::make_memory(lnorm_fwd_pd.weights_desc(), eng);
        if (useShift)
            bias = test::make_memory(lnorm_fwd_pd.weights_desc(), eng);
        if (isTraining || useGlobalStats) {
            mean = test::make_memory(*stat_d, eng);
            variance = test::make_memory(*stat_d, eng);
        }

        fill<float>(src->get());
        fill<float>(dst->get());
        if (useScaleShift || useScale) fill<float>(weights);
        if (useShift) fill<float>(bias);
        if (useGlobalStats) {
            fill<float>(mean);
            fill<float>(variance);
        }

        execlnormFwd(
                isTraining, useGlobalStats, useScaleShift, useScale, useShift);
        check_lnorm_fwd(p, src->get(), mean, variance, weights, bias,
                dst->get(), flags, pk);
    }

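    // Same flow as Forward(), for the backward pass; the backward primitive
    // descriptor requires the forward one as a hint.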
    void Backward(prop_kind pk,
            normalization_flags flags = normalization_flags::none) {
        bwd_iface_test_stat_any(pk, flags);

        bool useScaleShift
                = (bool)(flags & normalization_flags::use_scale_shift);
        bool useScale = (bool)(flags & normalization_flags::use_scale);
        bool useShift = (bool)(flags & normalization_flags::use_shift);

        auto lnorm_fwd_d
                = layer_normalization_forward::desc(prop_kind::forward_training,
                        *data_d, *stat_d, p.epsilon, flags);
        lnorm_fwd_pd
                = layer_normalization_forward::primitive_desc(lnorm_fwd_d, eng);

        auto lnorm_bwd_d = layer_normalization_backward::desc(
                pk, *diff_d, *data_d, *stat_d, p.epsilon, flags);
        lnorm_bwd_pd = layer_normalization_backward::primitive_desc(
                lnorm_bwd_d, eng, lnorm_fwd_pd);
        lnorm_bwd_pd = layer_normalization_backward::primitive_desc(
                lnorm_bwd_pd.get()); // test construction from a C pd

        ASSERT_TRUE(lnorm_bwd_pd.query_md(query::exec_arg_md, DNNL_ARG_SRC)
                == lnorm_bwd_pd.src_desc());
        ASSERT_TRUE(lnorm_bwd_pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC)
                == lnorm_bwd_pd.diff_src_desc());
        ASSERT_TRUE(lnorm_bwd_pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST)
                == lnorm_bwd_pd.diff_dst_desc());
        ASSERT_TRUE(lnorm_bwd_pd.query_md(query::exec_arg_md, DNNL_ARG_MEAN)
                == lnorm_bwd_pd.mean_desc());
        ASSERT_TRUE(lnorm_bwd_pd.query_md(query::exec_arg_md, DNNL_ARG_VARIANCE)
                == lnorm_bwd_pd.variance_desc());
        ASSERT_TRUE(
                lnorm_bwd_pd.query_md(query::exec_arg_md, DNNL_ARG_SCALE_SHIFT)
                == lnorm_bwd_pd.weights_desc());
        ASSERT_TRUE(lnorm_bwd_pd.query_md(
                            query::exec_arg_md, DNNL_ARG_DIFF_SCALE_SHIFT)
                == lnorm_bwd_pd.diff_weights_desc());

        if (useScaleShift || useScale)
            weights = test::make_memory(lnorm_bwd_pd.weights_desc(), eng);
        if (useShift)
            bias = test::make_memory(lnorm_bwd_pd.weights_desc(), eng);
        if (useScaleShift || useScale)
            diff_weights
                    = test::make_memory(lnorm_bwd_pd.diff_weights_desc(), eng);
        if (useShift)
            diff_bias
                    = test::make_memory(lnorm_bwd_pd.diff_weights_desc(), eng);
        mean = test::make_memory(*stat_d, eng);
        variance = test::make_memory(*stat_d, eng);

        if (useScaleShift || useScale) fill<float>(weights);
        if (useShift) fill<float>(bias);
        fill<float>(diff_src->get());
        fill<float>(diff_dst->get());
        fill<float>(mean);
        fill<float>(variance);

        execlnormBwd(useScaleShift, useScale, useShift, pk);
        check_lnorm_bwd(p, src->get(), diff_dst->get(), mean, variance, weights,
                diff_src->get(), diff_weights, diff_bias, flags, pk);
    }

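    // Bind the execution arguments required by the chosen flags and run the
    // forward primitive on the test stream.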
    void execlnormFwd(bool isTraining, bool useGlobalStats, bool useScaleShift,
            bool useScale, bool useShift) {
        std::unordered_map<int, memory> args = {
                {DNNL_ARG_SRC, src->get()},
                {DNNL_ARG_DST, dst->get()},
        };

        if (useScaleShift) args.insert({DNNL_ARG_SCALE_SHIFT, weights});
        if (useScale) args.insert({DNNL_ARG_SCALE, weights});
        if (useShift) args.insert({DNNL_ARG_SHIFT, bias});

        if (isTraining || useGlobalStats) {
            args.insert({DNNL_ARG_MEAN, mean});
            args.insert({DNNL_ARG_VARIANCE, variance});
        }

        layer_normalization_forward(lnorm_fwd_pd).execute(strm, args);
        strm.wait();
    }

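    // Bind the backward execution arguments; the diff_* outputs for scale and
    // shift are only produced by prop_kind::backward, not backward_data.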
    void execlnormBwd(
            bool useScaleShift, bool useScale, bool useShift, prop_kind pk) {
        std::unordered_map<int, memory> args = {
                {DNNL_ARG_SRC, src->get()},
                {DNNL_ARG_DIFF_DST, diff_dst->get()},
                {DNNL_ARG_MEAN, mean},
                {DNNL_ARG_VARIANCE, variance},
                {DNNL_ARG_DIFF_SRC, diff_src->get()},
        };

        if (useScaleShift) {
            args.insert({DNNL_ARG_SCALE_SHIFT, weights});
            if (pk == prop_kind::backward)
                args.insert({DNNL_ARG_DIFF_SCALE_SHIFT, diff_weights});
        }

        if (useScale) {
            args.insert({DNNL_ARG_SCALE, weights});
            if (pk == prop_kind::backward)
                args.insert({DNNL_ARG_DIFF_SCALE, diff_weights});
        }

        if (useShift) {
            args.insert({DNNL_ARG_SHIFT, bias});
            if (pk == prop_kind::backward)
                args.insert({DNNL_ARG_DIFF_SHIFT, diff_bias});
        }

        layer_normalization_backward(lnorm_bwd_pd).execute(strm, args);
        strm.wait();
    }

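    // Reference check for the forward pass: recompute mean, variance, and the
    // normalized output for every statistics point and compare them to the
    // primitive's results within a relative tolerance.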
    void check_lnorm_fwd(const test_lnorm_params_t &p, const memory &src,
            const memory &mean, const memory &variance, const memory &weights,
            const memory &bias, const memory &dst, normalization_flags flags,
            prop_kind pk) {
        const size_t nelems = std::accumulate(p.dims.begin(), p.dims.end(),
                size_t(1), std::multiplies<size_t>());
        if (!nelems) return;

        const bool use_weights_bias
                = (bool)(flags & normalization_flags::use_scale_shift);
        const bool use_weights = (bool)(flags & normalization_flags::use_scale);
        const bool use_bias = (bool)(flags & normalization_flags::use_shift);
        const bool calculate_stats
                = !(bool)(flags & normalization_flags::use_global_stats);
        const bool is_training = pk == prop_kind::forward_training;

        const memory::desc src_d = src.get_desc();
        const memory::desc dst_d = dst.get_desc();
        const memory::desc weights_d
                = use_weights_bias || use_weights || use_bias
                ? weights.get_desc()
                : memory::desc();
        const dnnl::impl::memory_desc_wrapper src_mdw(src_d.data);
        const dnnl::impl::memory_desc_wrapper stat_mdw((*stat_d).data);
        const dnnl::impl::memory_desc_wrapper dst_mdw(dst_d.data);
        const dnnl::impl::memory_desc_wrapper weights_mdw(weights_d.data);

        const auto ndims = src_mdw.ndims();
        const auto C = src_mdw.dims()[ndims - 1];

        auto src_data = map_memory<const float>(src);
        GTEST_EXPECT_NE(src_data, nullptr);
        auto dst_data = map_memory<const float>(dst);
        GTEST_EXPECT_NE(dst_data, nullptr);
        auto weights_data = use_weights_bias || use_weights
                ? map_memory<const float>(weights)
                : nullptr;
        if (use_weights_bias || use_weights)
            GTEST_EXPECT_NE(weights_data, nullptr);
        const size_t bias_off = use_weights_bias && !weights_mdw.has_zero_dim()
                ? weights_mdw.off_l(C, true)
                : 0;
        auto bias_data = use_bias
                ? map_memory<const float>(bias)
                : use_weights_bias ? &weights_data[bias_off] : nullptr;
        if (use_weights_bias || use_bias) GTEST_EXPECT_NE(bias_data, nullptr);
        auto mean_data = (!calculate_stats || is_training)
                ? map_memory<const float>(mean)
                : nullptr;
        if (!calculate_stats || is_training)
            GTEST_EXPECT_NE(mean_data, nullptr);
        auto variance_data = (!calculate_stats || is_training)
                ? map_memory<const float>(variance)
                : nullptr;
        if (!calculate_stats || is_training)
            GTEST_EXPECT_NE(variance_data, nullptr);
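        // Relative tolerance, scaled with the number of statistics points
        // (nelems / C) to admit accumulated rounding error.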
        float eps = static_cast<float>(1.e-4 * nelems / C);
        dnnl::impl::parallel_nd(nelems / C, [&](memory::dim n) {
            if (is_current_test_failed()) return;
            float ref_mean = float(0);
            float ref_variance = float(0);
            const auto stat_off = stat_mdw.off_l(n);

            if (calculate_stats) {
                for (memory::dim c = 0; c < C; c++)
                    ref_mean += src_data[src_mdw.off_l(n * C + c)];
                ref_mean /= C;

                if (is_training) {
                    float mean_norm_max = std::max(
                            std::abs(mean_data[stat_off]), std::abs(ref_mean));
                    if (mean_norm_max < eps) mean_norm_max = float(1);
                    ASSERT_NEAR(
                            (mean_data[stat_off] - ref_mean) / mean_norm_max,
                            0., eps);
                }

                for (memory::dim c = 0; c < C; c++) {
                    float tmp = src_data[src_mdw.off_l(n * C + c)] - ref_mean;
                    ref_variance += tmp * tmp;
                }
                ref_variance /= C;

                if (is_training) {
                    float variance_norm_max
                            = std::max(std::abs(variance_data[stat_off]),
                                    std::abs(ref_variance));
                    if (variance_norm_max < eps) variance_norm_max = float(1);
                    ASSERT_NEAR((variance_data[stat_off] - ref_variance)
                                    / variance_norm_max,
                            0., eps);
                }
            } else {
                ref_mean = mean_data[stat_off];
                ref_variance = variance_data[stat_off];
            }

            float ref_sqrt_variance
                    = static_cast<float>(sqrt(ref_variance + p.epsilon));
            float ref_rsqrt_variance = float(1) / (ref_sqrt_variance);

            for (memory::dim c = 0; c < C; c++) {
                float ref_dst = float(0);
                float wei = (use_weights_bias || use_weights)
                        ? weights_data[weights_mdw.off_l(c, true)]
                        : 1.0f;
                float bia = (use_weights_bias || use_bias)
                        ? bias_data[weights_mdw.off_l(c, true)]
                        : 0.0f;
                ref_dst = wei
                                * ((float)src_data[src_mdw.off_l(n * C + c)]
                                        - ref_mean)
                                * ref_rsqrt_variance
                        + bia;

                float out = dst_data[dst_mdw.off_l(n * C + c)];
                float norm_max = std::max(std::abs(out), std::abs(ref_dst));
                if (norm_max < 1e-2) norm_max = 1.;
                ASSERT_NEAR((out - ref_dst) / norm_max, 0., eps);
            }
        });
    }

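    // Reference check for the backward pass: recompute diff_gamma and
    // diff_beta per channel and diff_src per element, then compare within a
    // relative tolerance.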
    void check_lnorm_bwd(const test_lnorm_params_t &p, const memory &src,
            const memory &diff_dst, const memory &mean, const memory &variance,
            const memory &weights, const memory &diff_src,
            const memory &diff_weights, const memory &diff_bias,
            normalization_flags flags, prop_kind pk) {
        const ptrdiff_t nelems = std::accumulate(p.dims.begin(), p.dims.end(),
                size_t(1), std::multiplies<size_t>());
        // Note: the nelems == 0 case is handled after the buffers are mapped,
        // where the diff_weights / diff_bias outputs are checked to be zero.

        const bool use_weights_bias
                = (bool)(flags & normalization_flags::use_scale_shift);
        const bool use_weights = (bool)(flags & normalization_flags::use_scale);
        const bool use_bias = (bool)(flags & normalization_flags::use_shift);
        const bool calculate_diff_stats
                = !(bool)(flags & normalization_flags::use_global_stats);

        const memory::desc src_d = src.get_desc();
        const memory::desc diff_dst_d = diff_dst.get_desc();
        const memory::desc weights_d = use_weights || use_weights_bias
                ? weights.get_desc()
                : memory::desc();
        const memory::desc diff_src_d = diff_src.get_desc();
        const memory::desc diff_weights_d = use_weights || use_weights_bias
                ? diff_weights.get_desc()
                : memory::desc();
        const memory::desc diff_bias_d = use_bias
                ? diff_bias.get_desc()
                : use_weights_bias ? diff_weights.get_desc() : memory::desc();

        const dnnl::impl::memory_desc_wrapper src_mdw(src_d.data);
        const dnnl::impl::memory_desc_wrapper stat_mdw((*stat_d).data);
        const dnnl::impl::memory_desc_wrapper diff_dst_mdw(diff_dst_d.data);
        const dnnl::impl::memory_desc_wrapper weights_mdw(weights_d.data);
        const dnnl::impl::memory_desc_wrapper diff_src_mdw(diff_src_d.data);
        const dnnl::impl::memory_desc_wrapper diff_weights_mdw(
                diff_weights_d.data);
        const dnnl::impl::memory_desc_wrapper diff_bias_mdw(diff_bias_d.data);

        const auto ndims = src_mdw.ndims();
        const auto C = src_mdw.dims()[ndims - 1];

        auto src_data = map_memory<const float>(src);
        GTEST_EXPECT_NE(src_data, nullptr);
        auto weights_data = use_weights_bias || use_weights
                ? map_memory<const float>(weights)
                : nullptr;
        if (use_weights_bias || use_weights)
            GTEST_EXPECT_NE(weights_data, nullptr);
        auto diff_dst_data = map_memory<const float>(diff_dst);
        GTEST_EXPECT_NE(diff_dst_data, nullptr);
        auto mean_data = map_memory<const float>(mean);
        GTEST_EXPECT_NE(mean_data, nullptr);
        auto variance_data = map_memory<const float>(variance);
        GTEST_EXPECT_NE(variance_data, nullptr);
        const auto diff_src_data = map_memory<float>(diff_src);
        GTEST_EXPECT_NE(diff_src_data, nullptr);
        const auto diff_weights_data
                = (pk == prop_kind::backward
                          && (use_weights_bias || use_weights))
                ? map_memory<float>(diff_weights)
                : nullptr;
        if (pk == prop_kind::backward && (use_weights_bias || use_weights))
            GTEST_EXPECT_NE(diff_weights_data, nullptr);
        const size_t diff_bias_off
                = use_weights_bias && !diff_weights_mdw.has_zero_dim()
                ? diff_weights_mdw.off_l(C, true)
                : 0;
        const auto diff_bias_data = (pk == prop_kind::backward)
                ? (use_bias ? map_memory<float>(diff_bias)
                            : &diff_weights_data[diff_bias_off])
                : nullptr;
        if (pk == prop_kind::backward && (use_weights_bias || use_bias))
            GTEST_EXPECT_NE(diff_bias_data, nullptr);

        // With no elements, only check that diff_gamma / diff_beta are zero.
        if (nelems == 0) {
            if (pk == prop_kind::backward) {
                if (use_weights_bias || use_weights) {
                    for (memory::dim c = 0; c < C; ++c) {
                        auto dg = diff_weights_data[diff_weights_mdw.off_l(
                                c, true)];
                        ASSERT_NEAR(dg, 0., 1e-7);
                    }
                }
                if (use_weights_bias || use_bias) {
                    for (memory::dim c = 0; c < C; ++c) {
                        auto db = diff_bias_data[diff_bias_mdw.off_l(c, true)];
                        ASSERT_NEAR(db, 0., 1e-7);
                    }
                }
            }
            return;
        }

        const float eps = static_cast<float>(1.e-4 * nelems / C);

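        // First pass: reduce over the statistics points to validate
        // diff_gamma and diff_beta per channel.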
        dnnl::impl::parallel_nd(C, [&](memory::dim c) {
            if (is_current_test_failed()) return;

            float ref_diff_gamma = float(0);
            float ref_diff_beta = float(0);
            for (memory::dim n = 0; n < nelems / C; n++) {
                size_t stat_off = stat_mdw.off_l(n);
                // sqrt_variance holds 1 / sqrt(variance + epsilon).
                const float sqrt_variance
                        = 1.0f / sqrt(variance_data[stat_off] + p.epsilon);

                ref_diff_gamma += (src_data[src_mdw.off_l(n * C + c)]
                                          - mean_data[stat_off])
                        * diff_dst_data[diff_dst_mdw.off_l(n * C + c)]
                        * sqrt_variance;
                ref_diff_beta += diff_dst_data[diff_dst_mdw.off_l(n * C + c)];
            }

            if (pk == prop_kind::backward) {
                if (use_weights_bias || use_weights) {
                    auto diff_gamma
                            = diff_weights_data[diff_weights_mdw.off_l(c)];
                    float norm_max = std::max(
                            std::abs(diff_gamma), std::abs(ref_diff_gamma));
                    if (norm_max < 1e-2) norm_max = float(1);
                    ASSERT_NEAR(
                            (diff_gamma - ref_diff_gamma) / norm_max, 0., eps);
                }

                if (use_weights_bias || use_bias) {
                    auto diff_beta = diff_bias_data[diff_bias_mdw.off_l(c)];
                    float norm_max = std::max(
                            std::abs(diff_beta), std::abs(ref_diff_beta));
                    if (norm_max < 1e-2) norm_max = float(1);
                    ASSERT_NEAR(
                            (diff_beta - ref_diff_beta) / norm_max, 0., eps);
                }
            }
        });

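        // Second pass: per-element validation of diff_src.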
        dnnl::impl::parallel_nd(nelems / C, [&](memory::dim n) {
            if (is_current_test_failed()) return;

            size_t stat_off = stat_mdw.off_l(n);
            // sqrt_variance holds 1 / sqrt(variance + epsilon).
            const float sqrt_variance
                    = 1.0f / sqrt(variance_data[stat_off] + p.epsilon);

            float ref_dd_gamma = float(0);
            float ref_dd_gamma_x = float(0);
            if (calculate_diff_stats) {
                for (memory::dim c = 0; c < C; c++) {
                    auto gamma = use_weights_bias || use_weights
                            ? weights_data[weights_mdw.off_l(c)]
                            : 1;
                    ref_dd_gamma += diff_dst_data[diff_dst_mdw.off_l(n * C + c)]
                            * gamma;
                    ref_dd_gamma_x
                            += diff_dst_data[diff_dst_mdw.off_l(n * C + c)]
                            * gamma
                            * (src_data[src_mdw.off_l(n * C + c)]
                                    - mean_data[stat_off]);
                }
                ref_dd_gamma_x *= sqrt_variance;
            }
            for (memory::dim c = 0; c < C; c++) {
                auto gamma = use_weights_bias || use_weights
                        ? weights_data[weights_mdw.off_l(c)]
                        : 1;
                float ref_diff_src
                        = diff_dst_data[diff_dst_mdw.off_l(n * C + c)] * gamma;
                if (calculate_diff_stats) {
                    ref_diff_src -= ref_dd_gamma / C
                            + (src_data[src_mdw.off_l(n * C + c)]
                                      - mean_data[stat_off])
                                    * ref_dd_gamma_x * sqrt_variance / C;
                }
                ref_diff_src *= sqrt_variance;
                float out_diff_src
                        = diff_src_data[diff_src_mdw.off_l(n * C + c)];
                float norm_max = std::max(
                        std::abs(out_diff_src), std::abs(ref_diff_src));
                if (norm_max < eps) norm_max = float(1);
                ASSERT_NEAR((out_diff_src - ref_diff_src) / norm_max, 0., eps);
            }
        });
    }

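    // With no stat_md provided (or one with format_tag::any), the library
    // must derive the statistics layout from the data layout; check that the
    // mean and variance descriptors match the expected plain tag.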
    void fwd_iface_test_stat_any(prop_kind pk, normalization_flags flags) {
        // Statistics are neither an input nor an output for inference
        // without use_global_stats.
        if (pk == prop_kind::forward_inference
                && !(bool)(flags & normalization_flags::use_global_stats))
            return;

        using tag = memory::format_tag;

        tag expect_stat_tag = derive_stat_tag();
        if (expect_stat_tag == tag::undef) return; // layout not covered

        memory::dims stat_dims(p.dims.begin(), p.dims.end() - 1);
        memory::desc expect_stat_md(
                stat_dims, memory::data_type::f32, expect_stat_tag);

        // no stat_md provided at all
        {
            layer_normalization_forward::primitive_desc fwd_pd(
                    {pk, *data_d, p.epsilon, flags}, eng);

            EXPECT_EQ(fwd_pd.mean_desc(), expect_stat_md);
            EXPECT_EQ(fwd_pd.variance_desc(), expect_stat_md);
        }

        // stat_md with format_tag::any
        {
            memory::desc any_stat_md(
                    stat_dims, memory::data_type::f32, tag::any);
            layer_normalization_forward::primitive_desc fwd_pd(
                    {pk, *data_d, any_stat_md, p.epsilon, flags}, eng);

            EXPECT_EQ(fwd_pd.mean_desc(), expect_stat_md);
            EXPECT_EQ(fwd_pd.variance_desc(), expect_stat_md);
        }
    }

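    // The same statistics-layout derivation checks, applied to the backward
    // primitive descriptor.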
    void bwd_iface_test_stat_any(prop_kind pk, normalization_flags flags) {
        using tag = memory::format_tag;

        tag expect_stat_tag = derive_stat_tag();
        if (expect_stat_tag == tag::undef) return; // layout not covered

        memory::dims stat_dims(p.dims.begin(), p.dims.end() - 1);
        memory::desc expect_stat_md(
                stat_dims, memory::data_type::f32, expect_stat_tag);

        layer_normalization_forward::primitive_desc fwd_pd(
                {prop_kind::forward_training, *data_d, p.epsilon, flags}, eng);

        // no stat_md provided at all
        {
            layer_normalization_backward::primitive_desc bwd_pd(
                    {pk, *diff_d, *data_d, p.epsilon, flags}, eng, fwd_pd);

            EXPECT_EQ(bwd_pd.mean_desc(), expect_stat_md);
            EXPECT_EQ(bwd_pd.variance_desc(), expect_stat_md);
        }

        // stat_md with format_tag::any
        {
            memory::desc any_stat_md(
                    stat_dims, memory::data_type::f32, tag::any);
            layer_normalization_backward::primitive_desc bwd_pd(
                    {pk, *diff_d, *data_d, any_stat_md, p.epsilon, flags}, eng,
                    fwd_pd);

            EXPECT_EQ(bwd_pd.mean_desc(), expect_stat_md);
            EXPECT_EQ(bwd_pd.variance_desc(), expect_stat_md);
        }
    }

private:
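    // Map the data layout to the plain statistics layout expected from the
    // library: the data layout with its innermost (normalization) dimension
    // dropped.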
    memory::format_tag derive_stat_tag() const {
        using tag = memory::format_tag;
        tag expect_stat_tag = tag::undef;

        // TODO: add more cases and test cases
        // XXX: currently test only simple cases like `abc`, `bac`. Extend,
        //      if possible, to blocked formats too.
        switch (p.data_tag) {
            case tag::abc: expect_stat_tag = tag::ab; break;
            case tag::bac: expect_stat_tag = tag::ba; break;
            default: break;
        }

        return expect_stat_tag;
    }
};

TEST_P(lnorm_test_t, TestsLnormF32) {}

#include "layer_normalization.h"
} // namespace dnnl