1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <cmath>
18 #include <memory>
19
20 #include "dnnl_test_common.hpp"
21 #include "gtest/gtest.h"
22
23 #include "oneapi/dnnl/dnnl.hpp"
24
// Instantiate lnorm_test_t with the given test-case values, restricted to
// CPU engines (CPU_INSTANTIATE_TEST_SUITE_P is a no-op for GPU-only runs).
#define CPU_INST_TEST_CASE(str, ...) \
    CPU_INSTANTIATE_TEST_SUITE_P( \
            str, lnorm_test_t, ::testing::Values(__VA_ARGS__));
28
29 namespace dnnl {
30
// Parameters describing one layer normalization test case. This aggregate is
// brace-initialized by the test-case lists, so field order must not change.
struct test_lnorm_params_t {
    memory::format_tag data_tag; // layout of src/dst
    memory::format_tag stat_tag; // layout of mean/variance
    memory::format_tag diff_tag; // layout of diff_src/diff_dst
    memory::dims dims; // data dims; the last dim is the normalized axis
    float epsilon; // added to variance before taking the square root
    bool expect_to_fail; // whether primitive creation is expected to fail
    dnnl_status_t expected_status; // status expected when it fails
};
40
41 template <typename T>
fill(const memory & m)42 void fill(const memory &m) {
43 auto numElements = m.get_desc().get_size() / sizeof(T);
44 fill_data<T>(numElements, m);
45 }
46
47 class lnorm_test_t : public ::testing::TestWithParam<test_lnorm_params_t> {
48 private:
49 std::shared_ptr<test_memory> src, dst, diff_src, diff_dst;
50 memory weights, bias, diff_weights, diff_bias, mean, variance;
51
52 std::shared_ptr<memory::desc> data_d;
53 std::shared_ptr<memory::desc> stat_d;
54 std::shared_ptr<memory::desc> diff_d;
55
56 layer_normalization_forward::primitive_desc lnorm_fwd_pd;
57 layer_normalization_backward::primitive_desc lnorm_bwd_pd;
58
59 test_lnorm_params_t p;
60 engine eng;
61 stream strm;
62
63 protected:
SetUp()64 void SetUp() override {
65 SKIP_IF_CUDA(true, "Layer normalization not supported by CUDA.");
66 p = ::testing::TestWithParam<decltype(p)>::GetParam();
67 catch_expected_failures(
68 [=]() { Test(); }, p.expect_to_fail, p.expected_status);
69 }
70
Test()71 void Test() {
72 eng = get_test_engine();
73 strm = make_stream(eng);
74
75 data_d = std::make_shared<memory::desc>(
76 p.dims, memory::data_type::f32, p.data_tag);
77 memory::dims stat_dims(p.dims.begin(), p.dims.end() - 1);
78 stat_d = std::make_shared<memory::desc>(
79 stat_dims, memory::data_type::f32, p.stat_tag);
80 diff_d = std::make_shared<memory::desc>(
81 p.dims, memory::data_type::f32, p.diff_tag);
82
83 src = std::make_shared<test_memory>(*data_d, eng);
84 dst = std::make_shared<test_memory>(*data_d, eng);
85 diff_src = std::make_shared<test_memory>(*diff_d, eng);
86 diff_dst = std::make_shared<test_memory>(*diff_d, eng);
87
88 auto training = prop_kind::forward_training;
89 auto inference = prop_kind::forward_inference;
90
91 using flags = normalization_flags;
92 Forward(training);
93 Forward(training, flags::use_global_stats);
94 Forward(training, flags::use_scale_shift);
95 Forward(training, flags::use_scale);
96 Forward(training, flags::use_shift);
97 Forward(training, flags::use_scale | flags::use_shift);
98 Forward(training, flags::use_scale_shift | flags::use_global_stats);
99 Forward(inference);
100 Forward(inference, flags::use_global_stats);
101 Forward(inference, flags::use_scale_shift);
102
103 Backward(prop_kind::backward_data);
104 Backward(prop_kind::backward_data, flags::use_global_stats);
105 Backward(prop_kind::backward, flags::use_scale_shift);
106 Backward(prop_kind::backward, flags::use_scale);
107 Backward(prop_kind::backward, flags::use_shift);
108 Backward(prop_kind::backward, flags::use_scale | flags::use_shift);
109 Backward(prop_kind::backward,
110 flags::use_scale_shift | flags::use_global_stats);
111 }
112
Forward(prop_kind pk,normalization_flags flags=normalization_flags::none)113 void Forward(prop_kind pk,
114 normalization_flags flags = normalization_flags::none) {
115 fwd_iface_test_stat_any(pk, flags);
116
117 bool useScaleShift
118 = (bool)(flags & normalization_flags::use_scale_shift);
119 bool useScale = (bool)(flags & normalization_flags::use_scale);
120 bool useShift = (bool)(flags & normalization_flags::use_shift);
121 bool useGlobalStats
122 = (bool)(flags & normalization_flags::use_global_stats);
123 bool isTraining = pk == prop_kind::forward_training;
124
125 auto lnorm_fwd_d = layer_normalization_forward::desc(
126 pk, *data_d, *stat_d, p.epsilon, flags);
127
128 lnorm_fwd_pd
129 = layer_normalization_forward::primitive_desc(lnorm_fwd_d, eng);
130 lnorm_fwd_pd = layer_normalization_forward::primitive_desc(
131 lnorm_fwd_pd.get()); // test construction from a C pd
132
133 ASSERT_TRUE(lnorm_fwd_pd.query_md(query::exec_arg_md, DNNL_ARG_SRC)
134 == lnorm_fwd_pd.src_desc());
135 ASSERT_TRUE(lnorm_fwd_pd.query_md(query::exec_arg_md, DNNL_ARG_DST)
136 == lnorm_fwd_pd.dst_desc());
137 ASSERT_TRUE(lnorm_fwd_pd.query_md(query::exec_arg_md, DNNL_ARG_MEAN)
138 == lnorm_fwd_pd.mean_desc());
139 ASSERT_TRUE(lnorm_fwd_pd.query_md(query::exec_arg_md, DNNL_ARG_VARIANCE)
140 == lnorm_fwd_pd.variance_desc());
141 ASSERT_TRUE(
142 lnorm_fwd_pd.query_md(query::exec_arg_md, DNNL_ARG_SCALE_SHIFT)
143 == lnorm_fwd_pd.weights_desc());
144
145 if (useScaleShift || useScale)
146 weights = test::make_memory(lnorm_fwd_pd.weights_desc(), eng);
147 if (useShift)
148 bias = test::make_memory(lnorm_fwd_pd.weights_desc(), eng);
149 if (isTraining || useGlobalStats) {
150 mean = test::make_memory(*stat_d, eng);
151 variance = test::make_memory(*stat_d, eng);
152 }
153
154 fill<float>(src->get());
155 fill<float>(dst->get());
156 if (useScaleShift || useScale) fill<float>(weights);
157 if (useShift) fill<float>(bias);
158 if (useGlobalStats) {
159 fill<float>(mean);
160 fill<float>(variance);
161 }
162
163 execlnormFwd(
164 isTraining, useGlobalStats, useScaleShift, useScale, useShift);
165 check_lnorm_fwd(p, src->get(), mean, variance, weights, bias,
166 dst->get(), flags, pk);
167 }
168
Backward(prop_kind pk,normalization_flags flags=normalization_flags::none)169 void Backward(prop_kind pk,
170 normalization_flags flags = normalization_flags::none) {
171 bwd_iface_test_stat_any(pk, flags);
172
173 bool useScaleShift
174 = (bool)(flags & normalization_flags::use_scale_shift);
175 bool useScale = (bool)(flags & normalization_flags::use_scale);
176 bool useShift = (bool)(flags & normalization_flags::use_shift);
177
178 auto lnorm_fwd_d
179 = layer_normalization_forward::desc(prop_kind::forward_training,
180 *data_d, *stat_d, p.epsilon, flags);
181 lnorm_fwd_pd
182 = layer_normalization_forward::primitive_desc(lnorm_fwd_d, eng);
183
184 auto lnorm_bwd_d = layer_normalization_backward::desc(
185 pk, *diff_d, *data_d, *stat_d, p.epsilon, flags);
186 lnorm_bwd_pd = layer_normalization_backward::primitive_desc(
187 lnorm_bwd_d, eng, lnorm_fwd_pd);
188 lnorm_bwd_pd = layer_normalization_backward::primitive_desc(
189 lnorm_bwd_pd.get()); // test construction from a C pd
190
191 ASSERT_TRUE(lnorm_bwd_pd.query_md(query::exec_arg_md, DNNL_ARG_SRC)
192 == lnorm_bwd_pd.src_desc());
193 ASSERT_TRUE(lnorm_bwd_pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC)
194 == lnorm_bwd_pd.diff_src_desc());
195 ASSERT_TRUE(lnorm_bwd_pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST)
196 == lnorm_bwd_pd.diff_dst_desc());
197 ASSERT_TRUE(lnorm_bwd_pd.query_md(query::exec_arg_md, DNNL_ARG_MEAN)
198 == lnorm_bwd_pd.mean_desc());
199 ASSERT_TRUE(lnorm_bwd_pd.query_md(query::exec_arg_md, DNNL_ARG_VARIANCE)
200 == lnorm_bwd_pd.variance_desc());
201 ASSERT_TRUE(
202 lnorm_bwd_pd.query_md(query::exec_arg_md, DNNL_ARG_SCALE_SHIFT)
203 == lnorm_bwd_pd.weights_desc());
204 ASSERT_TRUE(lnorm_bwd_pd.query_md(
205 query::exec_arg_md, DNNL_ARG_DIFF_SCALE_SHIFT)
206 == lnorm_bwd_pd.diff_weights_desc());
207
208 if (useScaleShift || useScale)
209 weights = test::make_memory(lnorm_bwd_pd.weights_desc(), eng);
210 if (useShift)
211 bias = test::make_memory(lnorm_bwd_pd.weights_desc(), eng);
212 if (useScaleShift || useScale)
213 diff_weights
214 = test::make_memory(lnorm_bwd_pd.diff_weights_desc(), eng);
215 if (useShift)
216 diff_bias
217 = test::make_memory(lnorm_bwd_pd.diff_weights_desc(), eng);
218 mean = test::make_memory(*stat_d, eng);
219 variance = test::make_memory(*stat_d, eng);
220
221 if (useScaleShift || useScale) fill<float>(weights);
222 if (useShift) fill<float>(bias);
223 fill<float>(diff_src->get());
224 fill<float>(diff_dst->get());
225 fill<float>(mean);
226 fill<float>(variance);
227
228 execlnormBwd(useScaleShift, useScale, useShift, pk);
229 check_lnorm_bwd(p, src->get(), diff_dst->get(), mean, variance, weights,
230 diff_src->get(), diff_weights, diff_bias, flags, pk);
231 }
232
execlnormFwd(bool isTraining,bool useGlobalStats,bool useScaleShift,bool useScale,bool useShift)233 void execlnormFwd(bool isTraining, bool useGlobalStats, bool useScaleShift,
234 bool useScale, bool useShift) {
235 std::unordered_map<int, memory> args = {
236 {DNNL_ARG_SRC, src->get()},
237 {DNNL_ARG_DST, dst->get()},
238 };
239
240 if (useScaleShift) args.insert({DNNL_ARG_SCALE_SHIFT, weights});
241 if (useScale) args.insert({DNNL_ARG_SCALE, weights});
242 if (useShift) args.insert({DNNL_ARG_SHIFT, bias});
243
244 if (isTraining || useGlobalStats) {
245 args.insert({DNNL_ARG_MEAN, mean});
246 args.insert({DNNL_ARG_VARIANCE, variance});
247 }
248
249 layer_normalization_forward(lnorm_fwd_pd).execute(strm, args);
250 strm.wait();
251 }
252
execlnormBwd(bool useScaleShift,bool useScale,bool useShift,prop_kind pk)253 void execlnormBwd(
254 bool useScaleShift, bool useScale, bool useShift, prop_kind pk) {
255 std::unordered_map<int, memory> args = {
256 {DNNL_ARG_SRC, src->get()},
257 {DNNL_ARG_DIFF_DST, diff_dst->get()},
258 {DNNL_ARG_MEAN, mean},
259 {DNNL_ARG_VARIANCE, variance},
260 {DNNL_ARG_DIFF_SRC, diff_src->get()},
261 };
262
263 if (useScaleShift) {
264 args.insert({DNNL_ARG_SCALE_SHIFT, weights});
265 if (pk == prop_kind::backward)
266 args.insert({DNNL_ARG_DIFF_SCALE_SHIFT, diff_weights});
267 }
268
269 if (useScale) {
270 args.insert({DNNL_ARG_SCALE, weights});
271 if (pk == prop_kind::backward)
272 args.insert({DNNL_ARG_DIFF_SCALE, diff_weights});
273 }
274
275 if (useShift) {
276 args.insert({DNNL_ARG_SHIFT, bias});
277 if (pk == prop_kind::backward)
278 args.insert({DNNL_ARG_DIFF_SHIFT, diff_bias});
279 }
280
281 layer_normalization_backward(lnorm_bwd_pd).execute(strm, args);
282 strm.wait();
283 }
284
check_lnorm_fwd(const test_lnorm_params_t & p,const memory & src,const memory & mean,const memory & variance,const memory & weights,const memory & bias,const memory & dst,normalization_flags flags,prop_kind pk)285 void check_lnorm_fwd(const test_lnorm_params_t &p, const memory &src,
286 const memory &mean, const memory &variance, const memory &weights,
287 const memory &bias, const memory &dst, normalization_flags flags,
288 prop_kind pk) {
289 const size_t nelems = std::accumulate(p.dims.begin(), p.dims.end(),
290 size_t(1), std::multiplies<size_t>());
291 if (!nelems) return;
292
293 const bool use_weights_bias
294 = (bool)(flags & normalization_flags::use_scale_shift);
295 const bool use_weights = (bool)(flags & normalization_flags::use_scale);
296 const bool use_bias = (bool)(flags & normalization_flags::use_shift);
297 const bool calculate_stats
298 = !(bool)(flags & normalization_flags::use_global_stats);
299 const bool is_training = pk == prop_kind::forward_training;
300
301 const memory::desc src_d = src.get_desc();
302 const memory::desc dst_d = dst.get_desc();
303 const memory::desc weights_d
304 = use_weights_bias || use_weights || use_bias
305 ? weights.get_desc()
306 : memory::desc();
307 const dnnl::impl::memory_desc_wrapper src_mdw(src_d.data);
308 const dnnl::impl::memory_desc_wrapper stat_mdw((*stat_d).data);
309 const dnnl::impl::memory_desc_wrapper dst_mdw(dst_d.data);
310 const dnnl::impl::memory_desc_wrapper weights_mdw(weights_d.data);
311
312 const auto ndims = src_mdw.ndims();
313 const auto C = src_mdw.dims()[ndims - 1];
314
315 auto src_data = map_memory<const float>(src);
316 GTEST_EXPECT_NE(src_data, nullptr);
317 auto dst_data = map_memory<const float>(dst);
318 GTEST_EXPECT_NE(dst_data, nullptr);
319 auto weights_data = use_weights_bias || use_weights
320 ? map_memory<const float>(weights)
321 : nullptr;
322 if (use_weights_bias || use_weights)
323 GTEST_EXPECT_NE(weights_data, nullptr);
324 const size_t bias_off = use_weights_bias && !weights_mdw.has_zero_dim()
325 ? weights_mdw.off_l(C, true)
326 : 0;
327 auto bias_data = use_bias
328 ? map_memory<const float>(bias)
329 : use_weights_bias ? &weights_data[bias_off] : nullptr;
330 if (use_weights_bias || use_bias) GTEST_EXPECT_NE(bias_data, nullptr);
331 auto mean_data = (!calculate_stats || is_training)
332 ? map_memory<const float>(mean)
333 : nullptr;
334 if (!calculate_stats || is_training)
335 GTEST_EXPECT_NE(mean_data, nullptr);
336 auto variance_data = (!calculate_stats || is_training)
337 ? map_memory<const float>(variance)
338 : nullptr;
339 if (!calculate_stats || is_training)
340 GTEST_EXPECT_NE(variance_data, nullptr);
341
342 if (!calculate_stats || is_training) {
343 const memory::desc stat_d = mean.get_desc();
344 const dnnl::impl::memory_desc_wrapper stat_mdw(stat_d.data);
345 }
346
347 float eps = static_cast<float>(1.e-4 * nelems / C);
348 dnnl::impl::parallel_nd(nelems / C, [&](memory::dim n) {
349 if (is_current_test_failed()) return;
350 float ref_mean = float(0);
351 float ref_variance = float(0);
352 const auto stat_off = stat_mdw.off_l(n);
353
354 if (calculate_stats) {
355 for (memory::dim c = 0; c < C; c++)
356 ref_mean += src_data[src_mdw.off_l(n * C + c)];
357 ref_mean /= C;
358
359 if (is_training) {
360 float mean_norm_max = std::max(
361 std::abs(mean_data[stat_off]), std::abs(ref_mean));
362 if (mean_norm_max < eps) mean_norm_max = float(1);
363 ASSERT_NEAR(
364 (mean_data[stat_off] - ref_mean) / mean_norm_max,
365 0., eps);
366 }
367
368 for (memory::dim c = 0; c < C; c++) {
369 float tmp = src_data[src_mdw.off_l(n * C + c)] - ref_mean;
370 ref_variance += tmp * tmp;
371 }
372 ref_variance /= C;
373
374 if (is_training) {
375 float variance_norm_max
376 = std::max(std::abs(variance_data[stat_off]),
377 std::abs(ref_variance));
378 if (variance_norm_max < eps) variance_norm_max = float(1);
379 ASSERT_NEAR((variance_data[stat_off] - ref_variance)
380 / variance_norm_max,
381 0., eps);
382 }
383 } else {
384 ref_mean = mean_data[stat_off];
385 ref_variance = variance_data[stat_off];
386 }
387
388 float ref_sqrt_variance
389 = static_cast<float>(sqrt(ref_variance + p.epsilon));
390 float ref_rsqrt_variance = float(1) / (ref_sqrt_variance);
391
392 for (memory::dim c = 0; c < C; c++) {
393 float ref_dst = float(0);
394 float wei = (use_weights_bias || use_weights)
395 ? weights_data[weights_mdw.off_l(c, true)]
396 : 1.0f;
397 float bia = (use_weights_bias || use_bias)
398 ? bias_data[weights_mdw.off_l(c, true)]
399 : 0.0f;
400 ref_dst = wei
401 * ((float)src_data[src_mdw.off_l(n * C + c)]
402 - ref_mean)
403 * ref_rsqrt_variance
404 + bia;
405
406 float out = dst_data[dst_mdw.off_l(n * C + c)];
407 float norm_max = std::max(std::abs(out), std::abs(ref_dst));
408 if (norm_max < 1e-2) norm_max = 1.;
409 ASSERT_NEAR((out - ref_dst) / norm_max, 0., eps);
410 }
411 });
412 }
413
check_lnorm_bwd(const test_lnorm_params_t & p,const memory & src,const memory & diff_dst,const memory & mean,const memory & variance,const memory & weights,const memory & diff_src,const memory & diff_weights,const memory & diff_bias,normalization_flags flags,prop_kind pk)414 void check_lnorm_bwd(const test_lnorm_params_t &p, const memory &src,
415 const memory &diff_dst, const memory &mean, const memory &variance,
416 const memory &weights, const memory &diff_src,
417 const memory &diff_weights, const memory &diff_bias,
418 normalization_flags flags, prop_kind pk) {
419 const ptrdiff_t nelems = std::accumulate(p.dims.begin(), p.dims.end(),
420 size_t(1), std::multiplies<size_t>());
421 if (!nelems) return;
422
423 const bool use_weights_bias
424 = (bool)(flags & normalization_flags::use_scale_shift);
425 const bool use_weights = (bool)(flags & normalization_flags::use_scale);
426 const bool use_bias = (bool)(flags & normalization_flags::use_shift);
427 const bool calculate_diff_stats
428 = !(bool)(flags & normalization_flags::use_global_stats);
429
430 const memory::desc src_d = src.get_desc();
431 const memory::desc diff_dst_d = diff_dst.get_desc();
432 const memory::desc weights_d = use_weights || use_weights_bias
433 ? weights.get_desc()
434 : memory::desc();
435 const memory::desc diff_src_d = diff_src.get_desc();
436 const memory::desc diff_weights_d = use_weights || use_weights_bias
437 ? diff_weights.get_desc()
438 : memory::desc();
439 const memory::desc diff_bias_d = use_bias
440 ? diff_bias.get_desc()
441 : use_weights_bias ? diff_weights.get_desc() : memory::desc();
442
443 const dnnl::impl::memory_desc_wrapper src_mdw(src_d.data);
444 const dnnl::impl::memory_desc_wrapper stat_mdw((*stat_d).data);
445 const dnnl::impl::memory_desc_wrapper diff_dst_mdw(diff_dst_d.data);
446 const dnnl::impl::memory_desc_wrapper weights_mdw(weights_d.data);
447 const dnnl::impl::memory_desc_wrapper diff_src_mdw(diff_src_d.data);
448 const dnnl::impl::memory_desc_wrapper diff_weights_mdw(
449 diff_weights_d.data);
450 const dnnl::impl::memory_desc_wrapper diff_bias_mdw(diff_bias_d.data);
451
452 const auto ndims = src_mdw.ndims();
453 const auto C = src_mdw.dims()[ndims - 1];
454
455 auto src_data = map_memory<const float>(src);
456 GTEST_EXPECT_NE(src_data, nullptr);
457 auto weights_data = use_weights_bias || use_weights
458 ? map_memory<const float>(weights)
459 : nullptr;
460 if (use_weights_bias || use_weights)
461 GTEST_EXPECT_NE(weights_data, nullptr);
462 auto diff_dst_data = map_memory<const float>(diff_dst);
463 GTEST_EXPECT_NE(diff_dst_data, nullptr);
464 auto mean_data = map_memory<const float>(mean);
465 GTEST_EXPECT_NE(mean_data, nullptr);
466 auto variance_data = map_memory<const float>(variance);
467 GTEST_EXPECT_NE(variance_data, nullptr);
468 const auto diff_src_data = map_memory<float>(diff_src);
469 GTEST_EXPECT_NE(diff_src_data, nullptr);
470 const auto diff_weights_data
471 = (pk == prop_kind::backward
472 && (use_weights_bias || use_weights))
473 ? map_memory<float>(diff_weights)
474 : nullptr;
475 if (pk == prop_kind::backward && (use_weights_bias || use_weights))
476 GTEST_EXPECT_NE(diff_weights_data, nullptr);
477 const size_t diff_bias_off
478 = use_weights_bias && !diff_weights_mdw.has_zero_dim()
479 ? diff_weights_mdw.off_l(C, true)
480 : 0;
481 const auto diff_bias_data = (pk == prop_kind::backward) ? use_bias
482 ? map_memory<float>(diff_bias)
483 : &diff_weights_data[diff_bias_off]
484 : nullptr;
485 if (pk == prop_kind::backward && (use_weights_bias || use_bias))
486 GTEST_EXPECT_NE(diff_bias_data, nullptr);
487
488 if (nelems == 0) {
489 if (pk == prop_kind::backward) {
490 if (use_weights_bias || use_weights) {
491 for (memory::dim c = 0; c < C; ++c) {
492 auto dg = diff_weights_data[diff_weights_mdw.off_l(
493 c, true)];
494 ASSERT_NEAR(dg, 0., 1e-7);
495 }
496 }
497 if (use_weights_bias || use_bias) {
498 for (memory::dim c = 0; c < C; ++c) {
499 auto db = diff_bias_data[diff_bias_mdw.off_l(c, true)];
500 ASSERT_NEAR(db, 0., 1e-7);
501 }
502 }
503 }
504 return;
505 }
506
507 const float eps = static_cast<float>(1.e-4 * nelems / C);
508
509 dnnl::impl::parallel_nd(C, [&](memory::dim c) {
510 if (is_current_test_failed()) return;
511
512 float ref_diff_gamma = float(0);
513 float ref_diff_beta = float(0);
514 for (memory::dim n = 0; n < nelems / C; n++) {
515 size_t stat_off = stat_mdw.off_l(n);
516 const float sqrt_variance
517 = 1.0f / sqrt(variance_data[stat_off] + p.epsilon);
518
519 ref_diff_gamma += (src_data[src_mdw.off_l(n * C + c)]
520 - mean_data[stat_off])
521 * diff_dst_data[diff_dst_mdw.off_l(n * C + c)]
522 * sqrt_variance;
523 ref_diff_beta += diff_dst_data[diff_dst_mdw.off_l(n * C + c)];
524 }
525
526 if (pk == prop_kind::backward) {
527 if (use_weights_bias || use_weights) {
528 auto diff_gamma
529 = diff_weights_data[diff_weights_mdw.off_l(c)];
530 float norm_max = std::max(
531 std::abs(diff_gamma), std::abs(ref_diff_gamma));
532 if (norm_max < 1e-2) norm_max = float(1);
533 ASSERT_NEAR(
534 (diff_gamma - ref_diff_gamma) / norm_max, 0., eps);
535 }
536
537 if (use_weights_bias || use_bias) {
538 auto diff_beta = diff_bias_data[diff_bias_mdw.off_l(c)];
539 float norm_max = std::max(
540 std::abs(diff_beta), std::abs(ref_diff_beta));
541 if (norm_max < 1e-2) norm_max = float(1);
542 ASSERT_NEAR(
543 (diff_beta - ref_diff_beta) / norm_max, 0., eps);
544 }
545 }
546 });
547
548 dnnl::impl::parallel_nd(nelems / C, [&](memory::dim n) {
549 if (is_current_test_failed()) return;
550
551 size_t stat_off = stat_mdw.off_l(n);
552 const float sqrt_variance
553 = 1.0f / sqrt(variance_data[stat_off] + p.epsilon);
554
555 float ref_dd_gamma = float(0);
556 float ref_dd_gamma_x = float(0);
557 if (calculate_diff_stats) {
558 for (memory::dim c = 0; c < C; c++) {
559 auto gamma = use_weights_bias || use_weights
560 ? weights_data[weights_mdw.off_l(c)]
561 : 1;
562 ref_dd_gamma += diff_dst_data[diff_dst_mdw.off_l(n * C + c)]
563 * gamma;
564 ref_dd_gamma_x
565 += diff_dst_data[diff_dst_mdw.off_l(n * C + c)]
566 * gamma
567 * (src_data[src_mdw.off_l(n * C + c)]
568 - mean_data[stat_off]);
569 }
570 ref_dd_gamma_x *= sqrt_variance;
571 }
572 for (memory::dim c = 0; c < C; c++) {
573 auto gamma = use_weights_bias || use_weights
574 ? weights_data[weights_mdw.off_l(c)]
575 : 1;
576 float ref_diff_src
577 = diff_dst_data[diff_dst_mdw.off_l(n * C + c)] * gamma;
578 if (calculate_diff_stats) {
579 ref_diff_src -= ref_dd_gamma / C
580 + (src_data[src_mdw.off_l(n * C + c)]
581 - mean_data[stat_off])
582 * ref_dd_gamma_x * sqrt_variance / C;
583 }
584 ref_diff_src *= sqrt_variance;
585 float out_diff_src
586 = diff_src_data[diff_src_mdw.off_l(n * C + c)];
587 float norm_max = std::max(
588 std::abs(out_diff_src), std::abs(ref_diff_src));
589 if (norm_max < eps) norm_max = float(1);
590 ASSERT_NEAR((out_diff_src - ref_diff_src) / norm_max, 0., eps);
591 }
592 });
593 }
594
fwd_iface_test_stat_any(prop_kind pk,normalization_flags flags)595 void fwd_iface_test_stat_any(prop_kind pk, normalization_flags flags) {
596 // non stats if inference w/o use global stats
597 if (pk == prop_kind::forward_inference
598 && !(bool)(flags & normalization_flags::use_global_stats))
599 return;
600
601 using tag = memory::format_tag;
602
603 tag expect_stat_tag = derive_stat_tag();
604 if (expect_stat_tag == tag::undef) return; // optimism
605
606 memory::dims stat_dims(p.dims.begin(), p.dims.end() - 1);
607 memory::desc expect_stat_md(
608 stat_dims, memory::data_type::f32, expect_stat_tag);
609
610 // no stat_md provided at all
611 {
612 layer_normalization_forward::primitive_desc fwd_pd(
613 {pk, *data_d, p.epsilon, flags}, eng);
614
615 EXPECT_EQ(fwd_pd.mean_desc(), expect_stat_md);
616 EXPECT_EQ(fwd_pd.variance_desc(), expect_stat_md);
617 }
618
619 // stat_md with format_tag::any
620 {
621 memory::desc any_stat_md(
622 stat_dims, memory::data_type::f32, tag::any);
623 layer_normalization_forward::primitive_desc fwd_pd(
624 {pk, *data_d, any_stat_md, p.epsilon, flags}, eng);
625
626 EXPECT_EQ(fwd_pd.mean_desc(), expect_stat_md);
627 EXPECT_EQ(fwd_pd.variance_desc(), expect_stat_md);
628 }
629 }
630
bwd_iface_test_stat_any(prop_kind pk,normalization_flags flags)631 void bwd_iface_test_stat_any(prop_kind pk, normalization_flags flags) {
632 using tag = memory::format_tag;
633
634 tag expect_stat_tag = derive_stat_tag();
635 if (expect_stat_tag == tag::undef) return; // optimism
636
637 memory::dims stat_dims(p.dims.begin(), p.dims.end() - 1);
638 memory::desc expect_stat_md(
639 stat_dims, memory::data_type::f32, expect_stat_tag);
640
641 layer_normalization_forward::primitive_desc fwd_pd(
642 {prop_kind::forward_training, *data_d, p.epsilon, flags}, eng);
643
644 // no stat_md provided at all
645 {
646 layer_normalization_backward::primitive_desc bwd_pd(
647 {pk, *diff_d, *data_d, p.epsilon, flags}, eng, fwd_pd);
648
649 EXPECT_EQ(bwd_pd.mean_desc(), expect_stat_md);
650 EXPECT_EQ(bwd_pd.variance_desc(), expect_stat_md);
651 }
652
653 // stat_md with format_tag::any
654 {
655 memory::desc any_stat_md(
656 stat_dims, memory::data_type::f32, tag::any);
657 layer_normalization_backward::primitive_desc bwd_pd(
658 {pk, *diff_d, *data_d, any_stat_md, p.epsilon, flags}, eng,
659 fwd_pd);
660
661 EXPECT_EQ(bwd_pd.mean_desc(), expect_stat_md);
662 EXPECT_EQ(bwd_pd.variance_desc(), expect_stat_md);
663 }
664 }
665
666 private:
derive_stat_tag() const667 memory::format_tag derive_stat_tag() const {
668 using tag = memory::format_tag;
669 tag expect_stat_tag = tag::undef;
670
671 // TODO: add more cases and test cases
672 // XXX: currently test only simple cases like `abc`, `acb`. Extend,
673 // if possible, to blocked formats too.
674 switch (p.data_tag) {
675 case tag::abc: expect_stat_tag = tag::ab; break;
676 case tag::bac: expect_stat_tag = tag::ba; break;
677 default: break;
678 }
679
680 return expect_stat_tag;
681 }
682 };
683
TEST_P(lnorm_test_t,TestsLnormF32)684 TEST_P(lnorm_test_t, TestsLnormF32) {}
685
686 #include "layer_normalization.h"
687 } // namespace dnnl
688