/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <float.h>
#include <math.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>

#include <sstream>

#include "oneapi/dnnl/dnnl.h"

#include "tests/test_thread.hpp"

#include "dnnl_common.hpp"
#include "dnnl_memory.hpp"
#include "utils/compare.hpp"

#include "bnorm/bnorm.hpp"
#include "lnorm/lnorm.hpp"

using namespace bnorm;

namespace lnorm {

static int prepare_fwd(const prb_t *prb, dnn_mem_t &src, dnn_mem_t &mean,
        dnn_mem_t &var, dnn_mem_t &ss, dnn_mem_t &sh) {
    /** Idea: choose src[] values so that both mean and variance are computed
     * exactly (independently of the order of the computations).
     *
     * The `exactness` is achieved via [a1]: src[i] + src[i+1] = 2 * mean.
     *
     * The variation in src is allowed in the last flex_bits bits.
     * If the sequence (L) is too big (flex_bits <= min_flex_bits), the mean
     * value is set to 0 and src is partially filled with zeros (according to
     * density) so that at least want_flex_bits bits are reserved for src
     * variation. Once src is set, variance is computed.
     *
     * ALG_0: mean is set to 0
     * ALG_1: mean is set to 2^p, where p \in {-2, -1, ..., 4}
     * ALG_AUTO: choose between ALG_0 and ALG_1 automatically */
    const int64_t exact_bits = digits_dt(prb->dt);
    const int64_t L = prb->c;
    const int64_t logL = (int64_t)ceilf(log2f(L));

    assert(logL <= 0 || (1LL << (logL - 1)) < L);
    assert(L <= (1LL << logL));

    const int64_t min_flex_bits = 3;
    const int64_t want_flex_bits = MIN2(6, exact_bits / 2);

    check_alg_t alg = prb->check_alg;
    if (alg == ALG_AUTO) /* choose appropriate checking algorithm */
        alg = (exact_bits - logL) / 2 - 1 >= min_flex_bits ? ALG_1 : ALG_0;

    const int64_t flex_bits = alg == ALG_0
            ? want_flex_bits /* BFloat16 has only 7 bits of mantissa */
            : MIN2(prb->dt == dnnl_bf16 ? 7 : exact_bits,
                    (exact_bits - logL) / 2 - 1);

    if (flex_bits < min_flex_bits) return FAIL;
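
    /* Worked example (assuming f32, i.e. digits_dt = 24) with L = 1024:
     * logL = 10, so ALG_AUTO picks ALG_1 since (24 - 10) / 2 - 1 = 6 >= 3,
     * and flex_bits = MIN2(24, 6) = 6: src may vary in its last 6 mantissa
     * bits while mean and variance remain exactly representable. */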

    const int64_t flex_mask = (1 << flex_bits) - 1;

    /* density: (exact_bits - log_2(L * density)) / 2 >= flex_bits */
    const float density = alg == ALG_0
            ? 1.f * (1 << (exact_bits - 2 * flex_bits)) / L
            : 1.f;
    assert((exact_bits - ceilf(log2f(L * density))) / 2 >= flex_bits);
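    /* e.g. for ALG_0 with f32 (exact_bits = 24) and flex_bits = 6:
     * density = 2^12 / L, so for L = 2^20 only about 1 in 256 elements is
     * non-zero, which keeps log2(L * density) small enough for the mean
     * to stay exact. */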

    BENCHDNN_PRINT(6, "check_alg: %s, density = %g, flex_bits = " IFMT "\n",
            check_alg2str(alg), density, flex_bits);

    dnnl::impl::parallel_nd(prb->n, [&](int64_t n) {
        const float m = alg == ALG_0 ? 0.f : 0.25f * (1 << (n % 7));
        float v = 0; /* current variance */

        float *s = (float *)src + n * prb->c;
        for (int64_t c = 0; c < prb->c; ++c) {
            const int64_t l = c + n * 239 * 2; // l[0] must be even

            if (alg == ALG_0 && !flip_coin(l / 2 * 257ULL, density)) {
                s[c] = 0;
                continue;
            }

            const int64_t gen = (l / 2 * 1637) & flex_mask;
            const int sgn = l % 2 == 0 ? 1 : -1; /* [a1] */
            const float f = 1.f * sgn * gen / (1 << flex_bits);
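
            /* Consecutive elements share l / 2 and differ only in sgn, so
             * they receive +f and -f: each pair contributes exactly 2 * m
             * to the sum, which is what keeps the mean exact ([a1]). */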

            src.set_elem(n * prb->c + c, alg == ALG_0 ? f : m * (1.f + f));
            // odd L: the last element has no pair, so pin it to the mean
            if (L % 2 && (c == L - 1)) { s[c] = m; }
            v += (s[c] - m) * (s[c] - m);
        }
        mean.set_elem(n, m);
        var.set_elem(n, v / prb->c);
    });

    const bool use_ss = prb->use_ss();
    const bool use_sc = prb->use_sc();
    const bool use_sh = prb->use_sh();

    dnnl::impl::parallel_nd(prb->c, [&](int64_t c) {
        float sc_value = 1.f / 8 * (1 << (c % 7));
        float sh_value = (c % 3 + 1) * sc_value / 64;
        if (use_sc || use_sh) {
            ((float *)ss)[c] = use_sc ? sc_value : 1.0f;
            ((float *)sh)[c] = use_sh ? sh_value : 0.0f;
        } else {
            ((float *)ss)[c] = use_ss ? sc_value : 1.0f;
            ((float *)ss)[prb->c + c] = use_ss ? sh_value : 0.0f;
        }
    });
    return OK;
}
/** @brief L = 2^k * P, P % 2 != 0 */
static void decompose2(int64_t L, int64_t &k, int64_t &P) {
    P = L;
    for (k = 0; P % 2 == 0; ++k)
        P /= 2;
}
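
/* e.g. decompose2(192, k, P) gives k = 6 and P = 3, since 192 = 2^6 * 3. */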
static int prepare_bwd(const prb_t *prb, dnn_mem_t &src, dnn_mem_t &d_dst,
        dnn_mem_t &mean, dnn_mem_t &var, dnn_mem_t &ss, dnn_mem_t &sh) {
    const int64_t exact_bits = 24;

    if (prb->c < 2) return FAIL;

    const int64_t L = prb->c;
    /** Stabilization idea...
     * Layer normalization (unlike batch normalization) features two types of
     * accumulations in the bwd step:
     * First, accumulation over n:
     *      d_gamma[c] = sum_over_n ddst[n, c] * (src[n, c] - mean[n]) * inv_sigma
     *      d_beta[c] = ...
     * Second, accumulation over c:
     *      dd_gamma[n] = sum_over_c ddst[n, c] * (src[n, c] - mean[n])
     *          * inv_sigma * gamma
     *      dd_gamma_x[n] = ...
     * that is used when computing d_src:
     *      d_src = func(dd_gamma / C, dd_gamma_x / C, ...)
     * To avoid accumulation error in the first case we force sparsity of
     * ddst over n if d_gamma and d_beta need to be computed.
     * To get an exact result of the division in the second case we use the
     * same approach as in batch normalization:
     * try to make dd_gamma = L / 2^t_dd_gamma and dd_gamma_x = L / 2^t_dd_gamma_x,
     * where both t_dd_gamma and t_dd_gamma_x are in {1, .., max_k}.
     * Currently, with no obvious reason, max_k is set to 4 for
     * reasonably small problems and to 8 for big problems.
     *
     * We might hope that the division by L is exact in that case, but that
     * can happen only if P (the odd part of L) is less than 2^exact_bits,
     * hence restriction [r1].
     * */

    int64_t k, P;
    decompose2(L, k, P);

    int64_t log2P = (int64_t)ceilf(log2f(P));
    if (log2P >= exact_bits) return FAIL; /* [r1] */

    const int64_t max_k = 4;
    if (k > max_k && exact_bits - log2P > max_k + 4) {
        log2P += (k - max_k);
        P <<= k - max_k;
        k = max_k;
    }
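
    /* e.g. L = 768 = 2^8 * 3: decompose2 gives k = 8, P = 3, log2P = 2;
     * since k > max_k and 24 - 2 > 8, the excess powers of two move into P:
     * k becomes 4, P becomes 48, log2P becomes 6. */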

    const int64_t param_dd_p2 = 7; // factor_dd <- 2^{0, .., -param_dd_p2+1}
    const int64_t param_dd_gen = 32; // gen_dd <- {1, .., param_dd_gen}

    const int64_t param_f_p2 = 1; // factor_f <- 2^{-1, ..., -param_f_p2}
    const int64_t param_f_gen = 16; // gen_f <- {2, ..., param_f_gen}

    const bool use_ss = prb->use_ss();
    const bool use_sc = prb->use_sc();
    const bool use_sh = prb->use_sh();

    const float density
            = (use_ss || use_sc || use_sh) ? MIN2(1.f, 10.f / prb->n) : 1.f;

    BENCHDNN_PRINT(5,
            "prep_bwd: k:" IFMT ", P:" IFMT " log2P:" IFMT ", density = %g\n",
            k, P, log2P, density);

    // fill gamma and beta
    for (int64_t c = 0; c < prb->c; ++c) {
        const float sc_value = 1.f / 8 * (1 << (c % 7));
        const float sh_value = sc_value / 64;
        if (use_sc || use_sh) {
            ((float *)ss)[c] = use_sc ? sc_value : 1.0f;
            ((float *)sh)[c] = use_sh ? sh_value : 0.0f;
        } else {
            ((float *)ss)[c] = use_ss ? sc_value : 1.0f;
            ((float *)ss)[prb->c + c] = use_ss ? sh_value : 0.0f;
        }
    }

    for (int64_t n = 0; n < prb->n; ++n) {
        const float m = ((float *)mean)[n] = n % 2;

        /* var + eps \in {1/4, 1, 4} */
        const float ve_denom = 4.f / (1 << 2 * (n % 3));
        ((float *)var)[n] = ve_denom - prb->eps;
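        // n % 3 = 0, 1, 2 gives ve_denom = 4, 1, 1/4, so
        // inv_sigma = 1 / sqrt(var + eps) is an exact power of two: 1/2, 1, 2.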

        const int64_t dd_p2 = (n * 127 % param_dd_p2);
        const float factor_dd = 1.f / (1 << dd_p2);
        const int64_t f_p2 = 1 + (n % param_f_p2);
        const float factor_f = 1.f / (1 << f_p2);

        const float target_dd_g = factor_dd * P;
        const float target_dd_g_x = 2 * target_dd_g;

        if (!flip_coin(n, density) && n != 0 && n != prb->n - 1) {
            for (int64_t c = 0; c < prb->c; ++c) {
                ((float *)d_dst)[n * prb->c + c] = 0;
                ((float *)src)[n * prb->c + c] = m;
            }
            continue;
        }
        float dd_g = 0, dd_g_x = 0; /* current dd_gamma and dd_gamma_x */
        for (int64_t c = 0; c < prb->c - 2; ++c) {
            const float g = ((float *)ss)[c];
            float &s = ((float *)src)[n * prb->c + c];
            float &dd = ((float *)d_dst)[n * prb->c + c];

            const int sgn_dd = dd_g < target_dd_g ? 1 : -1;
            dd = sgn_dd * factor_dd * (1 + ((c + n) * 3 % param_dd_gen));
            dd_g += dd * g;

            const int sgn_f = dd_g_x < target_dd_g_x ? 1 : -1;
            const float f = sgn_f * factor_f
                    * (2 + ((c + n) * 7 % (param_f_gen - 1)));

            dd_g_x += f * dd * g;
            s = f + m;
        }

        /* The last 2 elements in src and d_dst are set so that:
         *      dd_gamma == target_dd_gamma
         *      dd_gamma_x == target_dd_gamma_x
         * For this we need to solve the system:
         *      d_dst[l1] * g[c1]           + d_dst[l0] * g[c0]
         *          = target_dd_gamma - dd_gamma
         *      d_dst[l1] * src[l1] * g[c1] + d_dst[l0] * src[l0] * g[c0]
         *          = target_dd_gamma_x - dd_gamma_x
         *
         * Here l0 -- last index, l1 -- last but one.
         * Moreover, let's assume src[l1] = 1 and src[l0] = -1. */
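
        /* Substituting src[l1] - m = 1 and src[l0] - m = -1, and writing
         * f1 = d_dst[l1] * g[c1], f0 = d_dst[l0] * g[c0], the system becomes:
         *      f1 + f0 = target_dd_gamma   - dd_gamma
         *      f1 - f0 = target_dd_gamma_x - dd_gamma_x
         * hence the half-sum / half-difference solution computed below. */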
        int64_t l0 = n * prb->c + prb->c - 1;
        int64_t l1 = n * prb->c + prb->c - 2;

        ((float *)src)[l1] = 1.f + m;
        ((float *)src)[l0] = -1.f + m;
        const float g1 = ((float *)ss)[prb->c - 2];
        const float g0 = ((float *)ss)[prb->c - 1];

        float f1 = ((target_dd_g - dd_g) + (target_dd_g_x - dd_g_x)) / 2;
        float f0 = ((target_dd_g - dd_g) - (target_dd_g_x - dd_g_x)) / 2;

        ((float *)d_dst)[l1] = f1 / g1;
        ((float *)d_dst)[l0] = f0 / g0;

        if (prb->dt == dnnl_bf16) { // truncate to bf16
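            // Zeroing the low half of the f32 bit pattern truncates the
            // mantissa to bf16 precision (assumes little-endian layout, where
            // the first uint16_t holds the least significant bits).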
            ((uint16_t *)(&((float *)d_dst)[l1]))[0] = 0;
            ((uint16_t *)(&((float *)d_dst)[l0]))[0] = 0;
        }
    }

    return OK;
}

static int init_pd(dnnl_engine_t engine, const prb_t *prb,
        dnnl_primitive_desc_t &lpd, res_t *res, dir_t dir,
        const_dnnl_primitive_desc_t hint) {
    dnnl_layer_normalization_desc_t ld;
    dnnl_memory_desc_t data_d, stat_d;

    const int64_t *data_dims = &prb->dims[0];

    SAFE(init_md(&data_d, prb->ndims, data_dims, prb->dt, prb->tag), CRIT);

    const dnnl_memory_desc_t *stat_d_ptr = nullptr;
    if (prb->stat_tag != tag::undef) {
        SAFE(init_md(&stat_d, prb->ndims - 1, data_dims, dnnl_f32,
                     prb->stat_tag),
                CRIT);
        stat_d_ptr = &stat_d;
    }

    auto flags = (dnnl_normalization_flags_t)prb->flags;
    if (prb->dir & FLAG_FWD) {
        auto prop = prb->dir & FLAG_INF ? dnnl_forward_inference
                                        : dnnl_forward_training;
        DNN_SAFE(dnnl_layer_normalization_forward_desc_init(
                         &ld, prop, &data_d, stat_d_ptr, prb->eps, flags),
                WARN);
    } else {
        dnnl_memory_desc_t diff_data_d;
        DNN_SAFE(dnnl_memory_desc_init_by_tag(&diff_data_d, prb->ndims,
                         data_dims, prb->dt, dnnl_format_tag_any),
                WARN);
        auto prop = prb->dir & FLAG_WEI ? dnnl_backward : dnnl_backward_data;
        DNN_SAFE(dnnl_layer_normalization_backward_desc_init(&ld, prop,
                         &diff_data_d, &data_d, stat_d_ptr, prb->eps, flags),
                WARN);
    }

    dnnl_primitive_desc_t hint_fwd_pd_ {};
    dnnl_status_t status = dnnl_success;
    if (prb->dir & FLAG_BWD) {
        dnnl_layer_normalization_desc_t ld_fwd;
        DNN_SAFE(dnnl_layer_normalization_forward_desc_init(&ld_fwd,
                         dnnl_forward_training, &data_d, stat_d_ptr, prb->eps,
                         flags),
                WARN);
        status = dnnl_primitive_desc_create(
                &hint_fwd_pd_, &ld_fwd, nullptr, engine, nullptr);
        if (status == dnnl_unimplemented) return res->state = UNIMPLEMENTED, OK;
    }
    auto hint_fwd_pd = make_benchdnn_dnnl_wrapper(hint_fwd_pd_);
    SAFE(status, WARN);

    auto dnnl_attr = make_benchdnn_dnnl_wrapper(
            create_dnnl_attr(prb->attr, attr_args_t()));

    status = dnnl_primitive_desc_create(
            &lpd, &ld, dnnl_attr, engine, hint_fwd_pd);

    if (status == dnnl_unimplemented) return res->state = UNIMPLEMENTED, OK;
    SAFE(status, WARN);

    res->impl_name = query_impl_info(lpd);
    if (maybe_skip(res->impl_name)) {
        BENCHDNN_PRINT(2, "SKIPPED: oneDNN implementation: %s\n",
                res->impl_name.c_str());
        return res->state = SKIPPED, res->reason = SKIP_IMPL_HIT, OK;
    } else {
        BENCHDNN_PRINT(
                5, "oneDNN implementation: %s\n", res->impl_name.c_str());
        if (!strstr(res->impl_name.c_str(), "jit")) {
            BENCHDNN_PRINT(2, "WARNING: %s",
                    "accuracy of the implementation being tested "
                    "depends on the compiler and might give "
                    "false positives.\n");
            BENCHDNN_PRINT(2, "         %s",
                    "please consider recompiling the sources with"
                    " `-prec-div -fp-model precise` for reliable testing.\n");
        }
    }

    SAFE(check_pd_w_and_wo_attr(res, prb->attr, ld), WARN);

    return OK;
}

void check_known_skipped_case(const prb_t *prb, res_t *res) {
    check_known_skipped_case_common({prb->dt}, prb->dir, res);
    if (res->state == SKIPPED) return;

    if (is_nvidia_gpu()) {
        res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
        return;
    }
}

/* When the error is larger than eps, it could be due to catastrophic
 * cancellation in the final result, which is computed as `Y = a * X + b`.
 * When `a * X` is close to `b` in magnitude and `sign(a * X) = -sign(b)`,
 * even a small relative error in `a * X` can leave the final result
 * (which suffers the cancellation, i.e. `|Y| = |a*X - (-b)|`) with no
 * meaningful digits left in the mantissa. */
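/* e.g. with a * X = 1.0001 and b = -1: Y = 1e-4, so an absolute error of
 * 1e-4 in a * X, tiny relative to a * X itself, is a 100% relative error
 * in Y. */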
void add_additional_fwd_lnorm_check(const prb_t *&prb, const dnn_mem_t &ss_fp,
        const dnn_mem_t &sh_fp, const dnn_mem_t &dst_fp, const float &eps,
        compare::compare_t &cmp) {
    using cmp_args_t = compare::compare_t::driver_check_func_args_t;
    const auto lnorm_add_check = [&](const cmp_args_t &args) {
        bool scale_or_shift = prb->use_ss() || prb->use_sc() || prb->use_sh();
        if (!scale_or_shift) return false;

        dims_t l_dims = md2dims(dst_fp.md_);
        dims_t dims_idx = off2dims_idx(l_dims, args.idx);
        int64_t c = dims_idx[prb->ndims - 1];
        const float beta = prb->use_sh() ? ((const float *)sh_fp)[c]
                                         : ((const float *)ss_fp)[prb->c + c];
        /* Using an empirically derived threshold, check if the cancellation
         * error in `|Y| = |a*X - (-b)|` is huge. */
        bool maybe_cancellation_error
                = (fabsf(args.got - beta)
                          / (fabsf(args.got) > FLT_MIN ? fabsf(args.got) : 1))
                > 1.0f;
        if (maybe_cancellation_error) {
            /* Check for error in `a * X` */
            float diff_aX
                    = fabsf((args.got - beta) - (args.got + args.diff - beta));
            return diff_aX <= eps;
        }
        return false;
    };
    cmp.set_driver_check_function(lnorm_add_check);
}

int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    check_known_skipped_case(prb, res);
    check_sum_post_ops(prb->attr, res);
    if (res->state == SKIPPED) return OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    SAFE(init_prim(prim, init_pd, prb, res), WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    const_dnnl_primitive_desc_t const_pd;
    DNN_SAFE(dnnl_primitive_get_primitive_desc(prim, &const_pd), CRIT);

    if (check_mem_size(const_pd) != OK) {
        return res->state = SKIPPED, res->reason = NOT_ENOUGH_RAM, OK;
    }

    const auto q = [&](int index = 0) -> const dnnl_memory_desc_t & {
        return *dnnl_primitive_desc_query_md(
                const_pd, dnnl_query_exec_arg_md, index);
    };

    const bool use_ss = prb->use_ss();
    const bool use_sc = prb->use_sc();
    const bool use_sh = prb->use_sh();

    const auto &data_md = q(DNNL_ARG_SRC);
    const auto &mean_md = q(DNNL_ARG_MEAN);
    const auto &var_md = q(DNNL_ARG_VARIANCE);
    const auto &ss_md = q(DNNL_ARG_SCALE_SHIFT);
    const auto &scratchpad_md = q(DNNL_ARG_SCRATCHPAD);

    const auto fp = dnnl_f32;
    const auto tag = tag::abx;

    const auto &test_engine = get_test_engine();

    dnn_mem_t src_fp(data_md, fp, tag, test_engine);
    dnn_mem_t src_dt(data_md, test_engine);

    dnn_mem_t &dst_fp = src_fp; // in-place reference
    dnn_mem_t placeholder_dst_dt;
    if (!prb->inplace) { placeholder_dst_dt = dnn_mem_t(data_md, test_engine); }
    dnn_mem_t &dst_dt = prb->inplace ? src_dt : placeholder_dst_dt;

    // On inference w/o global stats the layer norm doesn't require stat
    // memories. Hence, we need to prepare the mean_fp and var_fp ourselves.
    const auto stat_ndims = prb->ndims - 1;
    const auto stat_tag = tag::abx;
    dnn_mem_t mean_fp(stat_ndims, data_md.dims, fp, stat_tag, test_engine);
    dnn_mem_t mean_dt(mean_md, test_engine);

    dnn_mem_t var_fp(stat_ndims, data_md.dims, fp, stat_tag, test_engine);
    dnn_mem_t var_dt(var_md, test_engine);

    dnn_mem_t ss_fp(ss_md, fp, tag::abx, test_engine);
    dnn_mem_t ss_dt(ss_md, test_engine);
    dnn_mem_t d_ss_fp(ss_md, fp, tag::abx, test_engine);
    dnn_mem_t d_ss_dt(ss_md, test_engine);

    dnn_mem_t sh_fp(ss_md, fp, use_sh ? tag::x : tag::abx, test_engine);
    dnn_mem_t sh_dt(ss_md, test_engine);
    dnn_mem_t d_sh_fp(ss_md, fp, use_sh ? tag::x : tag::abx, test_engine);
    dnn_mem_t d_sh_dt(ss_md, test_engine);

    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);

    dnn_mem_t d_dst_dt, placeholder_d_src_dt;

    args_t args;

    if (prb->dir & FLAG_FWD) {
        if (prepare_fwd(prb, src_fp, mean_fp, var_fp, ss_fp, sh_fp) != OK) {
            return res->state = MISTRUSTED, OK;
        }

        SAFE(src_dt.reorder(src_fp), WARN);
        if (prb->flags & GLOB_STATS) {
            /* prepare mean & var if they are inputs */
            SAFE(mean_dt.reorder(mean_fp), WARN);
            SAFE(var_dt.reorder(var_fp), WARN);
        }
        if (use_ss || use_sc) { SAFE(ss_dt.reorder(ss_fp), WARN); }
        if (use_sh) { SAFE(sh_dt.reorder(sh_fp), WARN); }

        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_DST, dst_dt);
        args.set(DNNL_ARG_MEAN, mean_dt);
        args.set(DNNL_ARG_VARIANCE, var_dt);
        args.set(use_sc ? DNNL_ARG_SCALE : DNNL_ARG_SCALE_SHIFT, ss_dt);
        args.set(DNNL_ARG_SHIFT, sh_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args), WARN);

        if (is_bench_mode(CORR)) {
            TIME_REF(compute_ref_fwd(
                    prb, src_fp, mean_fp, var_fp, ss_fp, sh_fp, dst_fp));

            compare::compare_t cmp_data;
            const int digits_f32 = 24;
            const float eps = (1 << (digits_f32 - digits_dt(prb->dt))) * 5e-7;
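            // e.g. eps = 5e-7 for f32 (24 mantissa digits) and
            // 2^16 * 5e-7 ~= 3.3e-2 for bf16 (8 mantissa digits).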
            cmp_data.set_threshold(eps);
            cmp_data.set_data_kind(DATA);
            // TODO: improve bf16 filling
            if (prb->dt == dnnl_bf16) cmp_data.set_zero_trust_percent(100.f);

            add_additional_fwd_lnorm_check(
                    prb, ss_fp, sh_fp, dst_fp, eps, cmp_data);
            SAFE(cmp_data.compare(dst_fp, dst_dt, prb->attr, res), WARN);

            if (!(prb->flags & GLOB_STATS) && !(prb->dir & FLAG_INF)) {
                compare::compare_t cmp_mean;
                cmp_mean.set_data_kind(MEAN);
                if (prb->dt == dnnl_bf16 || prb->dt == dnnl_f16)
                    cmp_mean.set_zero_trust_percent(100.f);
                SAFE(cmp_mean.compare(mean_fp, mean_dt, prb->attr, res), WARN);

                compare::compare_t cmp_var;
                cmp_var.set_data_kind(VAR);
                if (prb->dt == dnnl_bf16 || prb->dt == dnnl_f16)
                    cmp_var.set_zero_trust_percent(100.f);
                SAFE(cmp_var.compare(var_fp, var_dt, prb->attr, res), WARN);
            }
        }
    } else {
        const auto &d_data_md = q(DNNL_ARG_DIFF_DST);

        dnn_mem_t d_dst_fp(d_data_md, fp, tag, test_engine);
        d_dst_dt = dnn_mem_t(d_data_md, test_engine);

        dnn_mem_t &d_src_fp = d_dst_fp; // in-place in ref code
        if (!prb->inplace) {
            placeholder_d_src_dt = dnn_mem_t(d_data_md, test_engine);
        }
        dnn_mem_t &d_src_dt = prb->inplace ? d_dst_dt : placeholder_d_src_dt;

        if (prepare_bwd(prb, src_fp, d_dst_fp, mean_fp, var_fp, ss_fp, sh_fp)
                != OK) {
            return res->state = MISTRUSTED, OK;
        }

        SAFE(src_dt.reorder(src_fp), WARN);
        SAFE(d_dst_dt.reorder(d_dst_fp), WARN);
        SAFE(mean_dt.reorder(mean_fp), WARN);
        SAFE(var_dt.reorder(var_fp), WARN);
        if (use_ss || use_sc) { SAFE(ss_dt.reorder(ss_fp), WARN); }
        if (use_sh) { SAFE(sh_dt.reorder(sh_fp), WARN); }

        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_DIFF_DST, d_dst_dt);
        args.set(DNNL_ARG_DIFF_SRC, d_src_dt);
        args.set(DNNL_ARG_MEAN, mean_dt);
        args.set(DNNL_ARG_VARIANCE, var_dt);
        args.set(use_sc ? DNNL_ARG_SCALE : DNNL_ARG_SCALE_SHIFT, ss_dt);
        args.set(use_sc ? DNNL_ARG_DIFF_SCALE : DNNL_ARG_DIFF_SCALE_SHIFT,
                d_ss_dt);
        args.set(DNNL_ARG_SHIFT, sh_dt);
        args.set(DNNL_ARG_DIFF_SHIFT, d_sh_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args), WARN);

        if (is_bench_mode(CORR)) {
            TIME_REF(compute_ref_bwd(prb, src_fp, mean_fp, var_fp, d_dst_fp,
                    ss_fp, d_src_fp, d_ss_fp, d_sh_fp));

            compare::compare_t cmp_data;
            const int digits_f32 = 24;
            const float eps = (1 << (digits_f32 - digits_dt(prb->dt))) * 2e-7;
            cmp_data.set_threshold(eps);
            cmp_data.set_data_kind(DATA);
            cmp_data.set_zero_trust_percent(70.f);
            SAFE(cmp_data.compare(d_src_fp, d_src_dt, prb->attr, res), WARN);

            if ((use_ss || use_sc) && (prb->dir & FLAG_WEI)) {
                compare::compare_t cmp_ss;
                cmp_ss.set_threshold(eps);
                cmp_ss.set_data_kind(use_ss ? SS : SC);
                SAFE(cmp_ss.compare(d_ss_fp, d_ss_dt, prb->attr, res), WARN);
            }
            if (use_sh && (prb->dir & FLAG_WEI)) {
                compare::compare_t cmp_sh;
                cmp_sh.set_threshold(eps);
                cmp_sh.set_data_kind(SH);
                SAFE(cmp_sh.compare(d_sh_fp, d_sh_dt, prb->attr, res), WARN);
            }
        }
    }

    return measure_perf(res, prim, args);
}

} // namespace lnorm