1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <float.h>
18 #include <math.h>
19 #include <stddef.h>
20 #include <stdio.h>
21 #include <stdlib.h>
22
23 #include <sstream>
24
25 #include "oneapi/dnnl/dnnl.h"
26
27 #include "tests/test_thread.hpp"
28
29 #include "dnnl_common.hpp"
30 #include "dnnl_memory.hpp"
31 #include "utils/compare.hpp"
32
33 #include "bnorm/bnorm.hpp"
34 #include "lnorm/lnorm.hpp"
35
36 using namespace bnorm;
37
38 namespace lnorm {
39
/* Fills `src` deterministically so that each row's mean and variance are
 * exactly representable in the target datatype, writes those reference
 * statistics into `mean`/`var`, and initializes the scale (`ss`) and shift
 * (`sh`) memories. Returns FAIL when the channel count is too large to
 * guarantee exactness (caller treats the case as MISTRUSTED). */
static int prepare_fwd(const prb_t *prb, dnn_mem_t &src, dnn_mem_t &mean,
        dnn_mem_t &var, dnn_mem_t &ss, dnn_mem_t &sh) {
    /** Idea: choose src[] values so that both mean and variance are computed
     * exactly (independently of the order of the computations).
     *
     * The `exactness` is achieved via [a1]: src[i] + src[i+1] = 2 * mean.
     *
     * The variation in src is allowed in the last flex_bits bits.
     * If the sequence (L) is too big (flex_bits <= min_flex_bits), the mean
     * value is set to 0 and src is partially filled with zeros (according to
     * density so that at least want_flex_bits is reserved for src variation.
     * Once src is set, variance is computed.
     *
     * ALG_0: mean is set to 0
     * ALG_1: mean is set to 2^prb, where prb \in {-2, -1, ..., 4}
     * ALG_AUTO: choose between ALG_0 and ALG_1 automatically */
    const int64_t exact_bits = digits_dt(prb->dt);
    const int64_t L = prb->c;
    // Smallest power of two covering L: 2^(logL-1) < L <= 2^logL.
    const int64_t logL = (int64_t)ceilf(log2f(L));

    assert(logL <= 0 || (1LL << (logL - 1)) < L);
    assert(L <= (1LL << logL));

    const int64_t min_flex_bits = 3;
    const int64_t want_flex_bits = MIN2(6, exact_bits / 2);

    check_alg_t alg = prb->check_alg;
    if (alg == ALG_AUTO) /* choose appropriate checking algorithm */
        alg = (exact_bits - logL) / 2 - 1 >= min_flex_bits ? ALG_1 : ALG_0;

    const int64_t flex_bits = alg == ALG_0
            ? want_flex_bits /* BFloat16 has only 7 bits of mantissa */
            : MIN2(prb->dt == dnnl_bf16 ? 7 : exact_bits,
                    (exact_bits - logL) / 2 - 1);

    // Not enough mantissa headroom to vary src while staying exact.
    if (flex_bits < min_flex_bits) return FAIL;

    const int64_t flex_mask = (1 << flex_bits) - 1;

    /* density: (exact_bits - log_2(L * density)) / 2 >= flex_bits */
    const float density = alg == ALG_0
            ? 1.f * (1 << (exact_bits - 2 * flex_bits)) / L
            : 1.f;
    assert((exact_bits - ceilf(log2f(L * density))) / 2 >= flex_bits);

    BENCHDNN_PRINT(6, "check_alg: %s, density = %g, flex_bits = " IFMT "\n",
            check_alg2str(alg), density, flex_bits);

    dnnl::impl::parallel_nd(prb->n, [&](int64_t n) {
        // ALG_1: per-row mean cycles through 2^{-2, ..., 4}.
        const float m = alg == ALG_0 ? 0.f : 0.25f * (1 << (n % 7));
        float v = 0; /* current variance */

        // Raw f32 view of the n-th row (src memory is f32 at this point).
        float *s = (float *)src + n * prb->c;
        for (int64_t c = 0; c < prb->c; ++c) {
            const int64_t l = c + n * 239 * 2; // l[0] must be even

            // ALG_0 fills sparsely so that the row sum stays exact.
            if (alg == ALG_0 && !flip_coin(l / 2 * 257ULL, density)) {
                s[c] = 0;
                continue;
            }

            const int64_t gen = (l / 2 * 1637) & flex_mask;
            const int sgn = l % 2 == 0 ? 1 : -1; /* [a1] */
            const float f = 1.f * sgn * gen / (1 << flex_bits);

            src.set_elem(n * prb->c + c, alg == ALG_0 ? f : m * (1.f + f));
            // Odd row length: pin the unpaired last element to the mean so
            // the pairing trick [a1] still yields an exact mean.
            if (L % 2 && (c == L - 1)) { s[c] = m; }
            v += (s[c] - m) * (s[c] - m);
        }
        mean.set_elem(n, m);
        var.set_elem(n, v / prb->c);
    });

    const bool use_ss = prb->use_ss();
    const bool use_sc = prb->use_sc();
    const bool use_sh = prb->use_sh();

    dnnl::impl::parallel_nd(prb->c, [&](int64_t c) {
        // Exactly-representable scale/shift values spread over channels.
        float sc_value = 1.f / 8 * (1 << (c % 7));
        float sh_value = (c % 3 + 1) * sc_value / 64;
        if (use_sc || use_sh) {
            // Separate scale and shift memories.
            ((float *)ss)[c] = use_sc ? sc_value : 1.0f;
            ((float *)sh)[c] = use_sh ? sh_value : 0.0f;
        } else {
            // Combined scale_shift layout: first C scales, then C shifts.
            ((float *)ss)[c] = use_ss ? sc_value : 1.0f;
            ((float *)ss)[prb->c + c] = use_ss ? sh_value : 0.0f;
        }
    });
    return OK;
}
/** @brief Decomposes L = 2^k * P with P odd (P % 2 != 0).
 *  @param L value to decompose; must be positive.
 *  @param k [out] exponent of the largest power of two dividing L.
 *  @param P [out] remaining odd factor. */
static void decompose2(int64_t L, int64_t &k, int64_t &P) {
    // Guard against L <= 0: zero would make the loop below spin forever
    // (0 % 2 == 0 and 0 / 2 == 0). Callers validate prb->c >= 2 beforehand.
    assert(L > 0);
    P = L;
    for (k = 0; P % 2 == 0; ++k)
        P /= 2;
}
/* Fills `src`, `d_dst`, `mean`, `var` and scale/shift so that the backward
 * accumulations (d_gamma/d_beta over n; dd_gamma/dd_gamma_x over c) come out
 * exact in f32 — see the stabilization comment below for the construction.
 * Returns FAIL when C < 2 or when exact division by C cannot be ensured. */
static int prepare_bwd(const prb_t *prb, dnn_mem_t &src, dnn_mem_t &d_dst,
        dnn_mem_t &mean, dnn_mem_t &var, dnn_mem_t &ss, dnn_mem_t &sh) {
    const int64_t exact_bits = 24;

    // The closing-pair trick below needs at least two channels.
    if (prb->c < 2) return FAIL;

    const int64_t L = prb->c;
    /** Stabilization idea...
     * Layer Normalization (unlike batch normalization) features two types of
     * accumulations in bwd step:
     * First, accumulation over n:
     *      d_gamma[c] = sum_over_n ddst[n, c] * (src[n, c] - mean[n]) * inv_sigma
     *      d_beta[c] = ...
     * Second, accumulation over c:
     *      dd_gamma[n] = sum_over_c ddst[n, c] * (src[n, c] - mean[n])
     *          * inv_sigma * gamma
     *      dd_gamma_x[n] = ...
     * that is used when computing d_src:
     *      d_src = func(dd_gamma / C, dd_gamma_x / C, ...)
     * To avoid accumulation error in the first case we will force sparsity
     * of ddst over n if d_gamma and d_beta need to be computed.
     * To get exact result of division in the second case we use the same
     * approach as in batch normalization:
     * Try to make dd_gamma = L / 2^t_dd_gamma and dd_gamma_x = L / 2^t_dd_gamma_x,
     * where both t_dd_gamma and t_dd_gamma_x are in {1, .., max_k}.
     * Currently, with no obvious reason, max_k is set to 4 for
     * reasonably small problems and to 8 for big problems.
     *
     * We might hope that division by L would be exact in that case,
     * but that might happen iff L is less than 2^exact_bits, hence
     * restriction [r1].
     * */

    int64_t k, P;
    decompose2(L, k, P);

    int64_t log2P = (int64_t)ceilf(log2f(P));
    if (log2P >= exact_bits) return FAIL; /* [r1] */

    // Cap the power-of-two factor; fold the excess back into P when the
    // mantissa still has room for it.
    const int64_t max_k = 4;
    if (k > max_k && exact_bits - log2P > max_k + 4) {
        log2P += (k - max_k);
        P <<= k - max_k;
        k = max_k;
    }

    const int64_t param_dd_p2 = 7; // factor_dd <- 2^{0, .., -param_dd_p2+1}
    const int64_t param_dd_gen = 32; // gen_dd <- {1, .., param_dd_gen}

    const int64_t param_f_p2 = 1; // factor_f <- 2^{-1, ..., -param_f_p2}
    const int64_t param_f_gen = 16; // gen_f <- {2, ..., param_s_gen}

    const bool use_ss = prb->use_ss();
    const bool use_sc = prb->use_sc();
    const bool use_sh = prb->use_sh();

    // Sparsify d_dst over n (keep ~10 dense rows) when d_gamma/d_beta are
    // accumulated, to avoid accumulation error in the first case above.
    const float density
            = (use_ss || use_sc || use_sh) ? MIN2(1.f, 10.f / prb->n) : 1.f;

    BENCHDNN_PRINT(5,
            "prep_bwd: k:" IFMT ", P:" IFMT " log2P:" IFMT ", density = %g\n",
            k, P, log2P, density);

    // fill gamma and beta
    for (int64_t c = 0; c < prb->c; ++c) {
        const float sc_value = 1.f / 8 * (1 << (c % 7));
        const float sh_value = sc_value / 64;
        if (use_sc || use_sh) {
            // Separate scale and shift memories.
            ((float *)ss)[c] = use_sc ? sc_value : 1.0f;
            ((float *)sh)[c] = use_sh ? sh_value : 0.0f;
        } else {
            // Combined scale_shift layout: first C scales, then C shifts.
            ((float *)ss)[c] = use_ss ? sc_value : 1.0f;
            ((float *)ss)[prb->c + c] = use_ss ? sh_value : 0.0f;
        }
    }

    for (int64_t n = 0; n < prb->n; ++n) {
        const float m = ((float *)mean)[n] = n % 2;

        /* var + eps \in {1/4, 1, 4} */
        const float ve_denom = 4.f / (1 << 2 * (n % 3));
        ((float *)var)[n] = ve_denom - prb->eps;

        const int64_t dd_p2 = (n * 127 % param_dd_p2);
        const float factor_dd = 1.f / (1 << dd_p2);
        const int64_t f_p2 = 1 + (n % param_f_p2);
        const float factor_f = 1.f / (1 << f_p2);

        const float target_dd_g = factor_dd * P;
        const float target_dd_g_x = 2 * target_dd_g;

        // Sparse row: zero d_dst and set src to the mean so the row
        // contributes nothing to d_gamma/d_beta. First/last rows stay dense.
        if (!flip_coin(n, density) && n != 0 && n != prb->n - 1) {
            for (int64_t c = 0; c < prb->c; ++c) {
                ((float *)d_dst)[n * prb->c + c] = 0;
                ((float *)src)[n * prb->c + c] = m;
            }
            continue;
        }
        float dd_g = 0, dd_g_x = 0; /* current dd_gamma and dd_gamma_x */
        // Fill all but the last two channels, steering the running sums
        // toward their targets via the sign choices below.
        for (int64_t c = 0; c < prb->c - 2; ++c) {
            const float g = ((float *)ss)[c];
            float &s = ((float *)src)[n * prb->c + c];
            float &dd = ((float *)d_dst)[n * prb->c + c];

            const int sgn_dd = dd_g < target_dd_g ? 1 : -1;
            dd = sgn_dd * factor_dd * (1 + ((c + n) * 3 % param_dd_gen));
            dd_g += dd * g;

            const int sgn_f = dd_g_x < target_dd_g_x ? 1 : -1;
            const float f = sgn_f * factor_f
                    * (2 + ((c + n) * 7 % (param_f_gen - 1)));

            dd_g_x += f * dd * g;
            s = f + m;
        }

        /* the last 2 elements in src and d_dst are set, so that:
         *      dd_gamma == target_dd_gamma
         *      dd_gamma_x == target_dd_gamma_x
         * For this we need to solve the system:
         *      d_dst[l1] * g[c1]           + d_dst[l0] * g[c0]
         *          = target_dd_gamma - dd_gamma
         *      d_dst[l1] * src[l1] * g[c1] + d_dst[l0] * src[l0] * g[c0]
         *          = target_dd_gamam_x - dd_gamma_x
         *
         * Here l0 -- last index, l1 -- last but one.
         * More over, let's assume src[l1] = 1 and src[l0] = -1. */
        int64_t l0 = n * prb->c + prb->c - 1;
        int64_t l1 = n * prb->c + prb->c - 2;

        ((float *)src)[l1] = 1.f + m;
        ((float *)src)[l0] = -1.f + m;
        const float g1 = ((float *)ss)[prb->c - 2];
        const float g0 = ((float *)ss)[prb->c - 1];

        float f1 = ((target_dd_g - dd_g) + (target_dd_g_x - dd_g_x)) / 2;
        float f0 = ((target_dd_g - dd_g) - (target_dd_g_x - dd_g_x)) / 2;

        ((float *)d_dst)[l1] = f1 / g1;
        ((float *)d_dst)[l0] = f0 / g0;

        // Zero the low mantissa half so the values survive a bf16 reorder
        // without rounding (little-endian: low 16 bits of the f32).
        if (prb->dt == dnnl_bf16) { // truncate to bf16
            ((uint16_t *)(&((float *)d_dst)[l1]))[0] = 0;
            ((uint16_t *)(&((float *)d_dst)[l0]))[0] = 0;
        }
    }

    return OK;
}
285
/* Creates the layer-normalization primitive descriptor `lpd` for `prb` on
 * `engine`. For backward problems a forward-training pd is created first and
 * passed as the hint. Sets res->state to UNIMPLEMENTED/SKIPPED (returning OK)
 * when the library has no fitting implementation or it is skip-listed. */
static int init_pd(dnnl_engine_t engine, const prb_t *prb,
        dnnl_primitive_desc_t &lpd, res_t *res, dir_t dir,
        const_dnnl_primitive_desc_t hint) {
    dnnl_layer_normalization_desc_t ld;
    dnnl_memory_desc_t data_d, stat_d;

    const int64_t *data_dims = &prb->dims[0];

    SAFE(init_md(&data_d, prb->ndims, data_dims, prb->dt, prb->tag), CRIT);

    // Statistics md is optional: only forced when the user asked for a
    // specific stat layout; otherwise the implementation picks one.
    const dnnl_memory_desc_t *stat_d_ptr = nullptr;
    if (prb->stat_tag != tag::undef) {
        SAFE(init_md(&stat_d, prb->ndims - 1, data_dims, dnnl_f32,
                     prb->stat_tag),
                CRIT);
        stat_d_ptr = &stat_d;
    }

    auto flags = (dnnl_normalization_flags_t)prb->flags;
    if (prb->dir & FLAG_FWD) {
        auto prop = prb->dir & FLAG_INF ? dnnl_forward_inference
                                        : dnnl_forward_training;
        DNN_SAFE(dnnl_layer_normalization_forward_desc_init(
                         &ld, prop, &data_d, stat_d_ptr, prb->eps, flags),
                WARN);
    } else {
        // Let the implementation choose the diff layout (format_tag_any).
        dnnl_memory_desc_t diff_data_d;
        DNN_SAFE(dnnl_memory_desc_init_by_tag(&diff_data_d, prb->ndims,
                         data_dims, prb->dt, dnnl_format_tag_any),
                WARN);
        auto prop = prb->dir & FLAG_WEI ? dnnl_backward : dnnl_backward_data;
        DNN_SAFE(dnnl_layer_normalization_backward_desc_init(&ld, prop,
                         &diff_data_d, &data_d, stat_d_ptr, prb->eps, flags),
                WARN);
    }

    // Backward pds require a forward hint pd; build one from the same shapes.
    dnnl_primitive_desc_t hint_fwd_pd_ {};
    dnnl_status_t status = dnnl_success;
    if (prb->dir & FLAG_BWD) {
        dnnl_layer_normalization_desc_t ld_fwd;
        DNN_SAFE(dnnl_layer_normalization_forward_desc_init(&ld_fwd,
                         dnnl_forward_training, &data_d, stat_d_ptr, prb->eps,
                         flags),
                WARN);
        status = dnnl_primitive_desc_create(
                &hint_fwd_pd_, &ld_fwd, nullptr, engine, nullptr);
        if (status == dnnl_unimplemented) return res->state = UNIMPLEMENTED, OK;
    }
    // Wrap before the SAFE check so the hint is released on any exit path.
    auto hint_fwd_pd = make_benchdnn_dnnl_wrapper(hint_fwd_pd_);
    SAFE(status, WARN);

    auto dnnl_attr = make_benchdnn_dnnl_wrapper(
            create_dnnl_attr(prb->attr, attr_args_t()));

    status = dnnl_primitive_desc_create(
            &lpd, &ld, dnnl_attr, engine, hint_fwd_pd);

    if (status == dnnl_unimplemented) return res->state = UNIMPLEMENTED, OK;
    SAFE(status, WARN);

    res->impl_name = query_impl_info(lpd);
    if (maybe_skip(res->impl_name)) {
        BENCHDNN_PRINT(2, "SKIPPED: oneDNN implementation: %s\n",
                res->impl_name.c_str());
        return res->state = SKIPPED, res->reason = SKIP_IMPL_HIT, OK;
    } else {
        BENCHDNN_PRINT(
                5, "oneDNN implementation: %s\n", res->impl_name.c_str());
        // Non-jit (reference) implementations depend on compiler flags for
        // accuracy, so warn that failures may be false positives.
        if (!strstr(res->impl_name.c_str(), "jit")) {
            BENCHDNN_PRINT(2, "WARNING: %s",
                    "accuracy of the implementation being tested "
                    "depends on the compiler and might give "
                    "false-positives.\n");
            BENCHDNN_PRINT(2, "         %s",
                    "please consider recompiling the sources with"
                    " `-prec-div -fp-model precise` for a reliable testing.\n");
        }
    }

    SAFE(check_pd_w_and_wo_attr(res, prb->attr, ld), WARN);

    return OK;
}
369
check_known_skipped_case(const prb_t * prb,res_t * res)370 void check_known_skipped_case(const prb_t *prb, res_t *res) {
371 check_known_skipped_case_common({prb->dt}, prb->dir, res);
372 if (res->state == SKIPPED) return;
373
374 if (is_nvidia_gpu()) {
375 res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
376 return;
377 }
378 }
379
380 /* When the error is larger than eps, It could be
381 * due to catastrophic cancellation in final result
382 * which is computed as `Y = a * X + b`.
383 * When `a * X` is close to `b` and `sign(a * X) = - sign(b)`.
384 * Then large error in `a * X` could result in a final
385 * result (which has a cancellation i.e. `|Y| = |a*X - (-b)|`)
386 * which has no meaningful digits left in mantissa.*/
add_additional_fwd_lnorm_check(const prb_t * & prb,const dnn_mem_t & ss_fp,const dnn_mem_t & sh_fp,const dnn_mem_t & dst_fp,const float & eps,compare::compare_t & cmp)387 void add_additional_fwd_lnorm_check(const prb_t *&prb, const dnn_mem_t &ss_fp,
388 const dnn_mem_t &sh_fp, const dnn_mem_t &dst_fp, const float &eps,
389 compare::compare_t &cmp) {
390 using cmp_args_t = compare::compare_t::driver_check_func_args_t;
391 const auto lnorm_add_check = [&](const cmp_args_t &args) {
392 bool scale_or_shift = prb->use_ss() || prb->use_sc() || prb->use_sh();
393 if (!scale_or_shift) return false;
394
395 dims_t l_dims = md2dims(dst_fp.md_);
396 dims_t dims_idx = off2dims_idx(l_dims, args.idx);
397 int64_t c = dims_idx[prb->ndims - 1];
398 const float beta = prb->use_sh() ? ((const float *)sh_fp)[c]
399 : ((const float *)ss_fp)[prb->c + c];
400 /* Using an empirically derived threshold,
401 * check if cancellation error
402 * in `|Y| = |a*X - (-b)|` is huge.*/
403 bool maybe_cancellation_error
404 = (fabsf(args.got - beta)
405 / (fabsf(args.got) > FLT_MIN ? fabsf(args.got) : 1))
406 > 1.0f;
407 if (maybe_cancellation_error) {
408 /* Check for error in `a * X` */
409 float diff_aX
410 = fabsf((args.got - beta) - (args.got + args.diff - beta));
411 return diff_aX <= eps;
412 }
413 return false;
414 };
415 cmp.set_driver_check_function(lnorm_add_check);
416 }
417
/* Driver entry point for one lnorm problem: builds the primitive, prepares
 * input data, executes, verifies against the reference in CORR mode, and
 * finally measures performance. Returns OK with res->state set accordingly
 * on skips/unimplemented cases. */
int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    check_known_skipped_case(prb, res);
    check_sum_post_ops(prb->attr, res);
    if (res->state == SKIPPED) return OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    SAFE(init_prim(prim, init_pd, prb, res), WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    const_dnnl_primitive_desc_t const_pd;
    DNN_SAFE(dnnl_primitive_get_primitive_desc(prim, &const_pd), CRIT);

    if (check_mem_size(const_pd) != OK) {
        return res->state = SKIPPED, res->reason = NOT_ENOUGH_RAM, OK;
    }

    // Queries the md the chosen implementation expects for a given exec arg.
    const auto q = [&](int index = 0) -> const dnnl_memory_desc_t & {
        return *dnnl_primitive_desc_query_md(
                const_pd, dnnl_query_exec_arg_md, index);
    };

    const bool use_ss = prb->use_ss();
    const bool use_sc = prb->use_sc();
    const bool use_sh = prb->use_sh();

    const auto &data_md = q(DNNL_ARG_SRC);
    const auto &mean_md = q(DNNL_ARG_MEAN);
    const auto &var_md = q(DNNL_ARG_VARIANCE);
    const auto &ss_md = q(DNNL_ARG_SCALE_SHIFT);
    const auto &scratchpad_md = q(DNNL_ARG_SCRATCHPAD);

    // Reference-side memories are plain f32 / abx; *_dt are library-side.
    const auto fp = dnnl_f32;
    const auto tag = tag::abx;

    const auto &test_engine = get_test_engine();

    dnn_mem_t src_fp(data_md, fp, tag, test_engine);
    dnn_mem_t src_dt(data_md, test_engine);

    dnn_mem_t &dst_fp = src_fp; // in-place reference
    dnn_mem_t placeholder_dst_dt;
    if (!prb->inplace) { placeholder_dst_dt = dnn_mem_t(data_md, test_engine); }
    dnn_mem_t &dst_dt = prb->inplace ? src_dt : placeholder_dst_dt;

    // On inference w/o global stats the layer norm doesn't require stat
    // memories. Hence, we need to prepare the mean_fp and var_fp ourselves.
    const auto stat_ndims = prb->ndims - 1;
    const auto stat_tag = tag::abx;
    dnn_mem_t mean_fp(stat_ndims, data_md.dims, fp, stat_tag, test_engine);
    dnn_mem_t mean_dt(mean_md, test_engine);

    dnn_mem_t var_fp(stat_ndims, data_md.dims, fp, stat_tag, test_engine);
    dnn_mem_t var_dt(var_md, test_engine);

    dnn_mem_t ss_fp(ss_md, fp, tag::abx, test_engine);
    dnn_mem_t ss_dt(ss_md, test_engine);
    dnn_mem_t d_ss_fp(ss_md, fp, tag::abx, test_engine);
    dnn_mem_t d_ss_dt(ss_md, test_engine);

    dnn_mem_t sh_fp(ss_md, fp, use_sh ? tag::x : tag::abx, test_engine);
    dnn_mem_t sh_dt(ss_md, test_engine);
    dnn_mem_t d_sh_fp(ss_md, fp, use_sh ? tag::x : tag::abx, test_engine);
    dnn_mem_t d_sh_dt(ss_md, test_engine);

    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);

    dnn_mem_t d_dst_dt, placeholder_d_src_dt;

    args_t args;

    if (prb->dir & FLAG_FWD) {
        // Filling failure means exactness cannot be guaranteed -> MISTRUSTED.
        if (prepare_fwd(prb, src_fp, mean_fp, var_fp, ss_fp, sh_fp) != OK) {
            return res->state = MISTRUSTED, OK;
        }

        SAFE(src_dt.reorder(src_fp), WARN);
        if (prb->flags & GLOB_STATS) {
            /* prepare mean & var if they are inputs */
            SAFE(mean_dt.reorder(mean_fp), WARN);
            SAFE(var_dt.reorder(var_fp), WARN);
        }
        if (use_ss || use_sc) { SAFE(ss_dt.reorder(ss_fp), WARN); }
        if (use_sh) { SAFE(sh_dt.reorder(sh_fp), WARN); }

        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_DST, dst_dt);
        args.set(DNNL_ARG_MEAN, mean_dt);
        args.set(DNNL_ARG_VARIANCE, var_dt);
        args.set(use_sc ? DNNL_ARG_SCALE : DNNL_ARG_SCALE_SHIFT, ss_dt);
        args.set(DNNL_ARG_SHIFT, sh_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args), WARN);

        if (is_bench_mode(CORR)) {
            TIME_REF(compute_ref_fwd(
                    prb, src_fp, mean_fp, var_fp, ss_fp, sh_fp, dst_fp));

            compare::compare_t cmp_data;
            // Scale the f32-based threshold by the tested datatype's
            // missing mantissa bits.
            const int digits_f32 = 24;
            const float eps = (1 << (digits_f32 - digits_dt(prb->dt))) * 5e-7;
            cmp_data.set_threshold(eps);
            cmp_data.set_data_kind(DATA);
            // TODO: improve bf16 filling
            if (prb->dt == dnnl_bf16) cmp_data.set_zero_trust_percent(100.f);

            add_additional_fwd_lnorm_check(
                    prb, ss_fp, sh_fp, dst_fp, eps, cmp_data);
            SAFE(cmp_data.compare(dst_fp, dst_dt, prb->attr, res), WARN);

            // Stats are outputs only when computed by the primitive
            // (training, no global stats).
            if (!(prb->flags & GLOB_STATS) && !(prb->dir & FLAG_INF)) {
                compare::compare_t cmp_mean;
                cmp_mean.set_data_kind(MEAN);
                if (prb->dt == dnnl_bf16 || prb->dt == dnnl_f16)
                    cmp_mean.set_zero_trust_percent(100.f);
                SAFE(cmp_mean.compare(mean_fp, mean_dt, prb->attr, res), WARN);

                compare::compare_t cmp_var;
                cmp_var.set_data_kind(VAR);
                if (prb->dt == dnnl_bf16 || prb->dt == dnnl_f16)
                    cmp_var.set_zero_trust_percent(100.f);
                SAFE(cmp_var.compare(var_fp, var_dt, prb->attr, res), WARN);
            }
        }
    } else {
        const auto &d_data_md = q(DNNL_ARG_DIFF_DST);

        dnn_mem_t d_dst_fp(d_data_md, fp, tag, test_engine);
        d_dst_dt = dnn_mem_t(d_data_md, test_engine);

        dnn_mem_t &d_src_fp = d_dst_fp; // in-place in ref code
        if (!prb->inplace) {
            placeholder_d_src_dt = dnn_mem_t(d_data_md, test_engine);
        }
        dnn_mem_t &d_src_dt = prb->inplace ? d_dst_dt : placeholder_d_src_dt;

        // Filling failure means exactness cannot be guaranteed -> MISTRUSTED.
        if (prepare_bwd(prb, src_fp, d_dst_fp, mean_fp, var_fp, ss_fp, sh_fp)
                != OK) {
            return res->state = MISTRUSTED, OK;
        }

        SAFE(src_dt.reorder(src_fp), WARN);
        SAFE(d_dst_dt.reorder(d_dst_fp), WARN);
        SAFE(mean_dt.reorder(mean_fp), WARN);
        SAFE(var_dt.reorder(var_fp), WARN);
        if (use_ss || use_sc) { SAFE(ss_dt.reorder(ss_fp), WARN); }
        if (use_sh) { SAFE(sh_dt.reorder(sh_fp), WARN); }

        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_DIFF_DST, d_dst_dt);
        args.set(DNNL_ARG_DIFF_SRC, d_src_dt);
        args.set(DNNL_ARG_MEAN, mean_dt);
        args.set(DNNL_ARG_VARIANCE, var_dt);
        args.set(use_sc ? DNNL_ARG_SCALE : DNNL_ARG_SCALE_SHIFT, ss_dt);
        args.set(use_sc ? DNNL_ARG_DIFF_SCALE : DNNL_ARG_DIFF_SCALE_SHIFT,
                d_ss_dt);
        args.set(DNNL_ARG_SHIFT, sh_dt);
        args.set(DNNL_ARG_DIFF_SHIFT, d_sh_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args), WARN);

        if (is_bench_mode(CORR)) {
            TIME_REF(compute_ref_bwd(prb, src_fp, mean_fp, var_fp, d_dst_fp,
                    ss_fp, d_src_fp, d_ss_fp, d_sh_fp));

            compare::compare_t cmp_data;
            // Scale the f32-based threshold by the tested datatype's
            // missing mantissa bits.
            const int digits_f32 = 24;
            const float eps = (1 << (digits_f32 - digits_dt(prb->dt))) * 2e-7;
            cmp_data.set_threshold(eps);
            cmp_data.set_data_kind(DATA);
            cmp_data.set_zero_trust_percent(70.f);
            SAFE(cmp_data.compare(d_src_fp, d_src_dt, prb->attr, res), WARN);

            // Weight diffs are only produced in dnnl_backward propagation.
            if ((use_ss || use_sc) && (prb->dir & FLAG_WEI)) {
                compare::compare_t cmp_ss;
                cmp_ss.set_threshold(eps);
                cmp_ss.set_data_kind(use_ss ? SS : SC);
                SAFE(cmp_ss.compare(d_ss_fp, d_ss_dt, prb->attr, res), WARN);
            }
            if (use_sh && (prb->dir & FLAG_WEI)) {
                compare::compare_t cmp_sh;
                cmp_sh.set_threshold(eps);
                cmp_sh.set_data_kind(SH);
                SAFE(cmp_sh.compare(d_sh_fp, d_sh_dt, prb->attr, res), WARN);
            }
        }
    }

    return measure_perf(res, prim, args);
}
611
612 } // namespace lnorm
613