/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <cassert>
#include <cmath>
#include <memory>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/math_utils.hpp"
#include "common/memory_tracking.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_batch_normalization_utils.hpp"
#include "cpu/platform.hpp"
#include "cpu/x64/jit_generator.hpp"

#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
#include "cpu/x64/jit_uni_tbb_batch_normalization.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

namespace {

using namespace memory_tracking::names;
using namespace Xbyak;
using acc_data_t = float;

constexpr int bits_per_byte = 8;

dim_t get_c_padded(const batch_normalization_pd_t *bdesc) {
    return bdesc->src_md()->padded_dims[1];
}

template <cpu_isa_t isa>
int get_vlen(jit_memory_tag_kind_t tag_kind) {
    return isa == sse41 && tag_kind == jit_memory_tag_kind_t::blocked
            ? 32
            : cpu_isa_traits<isa>::vlen;
}

template <cpu_isa_t isa>
int get_simd_w(jit_memory_tag_kind_t tag_kind) {
    return get_vlen<isa>(tag_kind) / sizeof(acc_data_t);
}
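
// Illustrative lane counts (a sketch, not exhaustive): with f32 accumulation
// (acc_data_t == float, 4 bytes), avx512_common has vlen = 64 bytes, so
// simd_w = 16; avx2 has vlen = 32, so simd_w = 8. The sse41 + blocked case
// reports vlen = 32 because the 8c block is processed as two 16-byte xmm
// halves (see the vlen / 2 offsets below), which also gives simd_w = 8.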

template <cpu_isa_t isa>
std::tuple<dim_t, dim_t, dim_t> get_data_strides(
        const batch_normalization_pd_t *bdesc, jit_memory_tag_kind_t tag_kind) {
    const int simd_w = get_simd_w<isa>(tag_kind);
    size_t stride_N, stride_S, stride_C;

    if (tag_kind == jit_memory_tag_kind_t::nspc) {
        stride_C = static_cast<size_t>(simd_w);
        stride_S = static_cast<size_t>(bdesc->C());
        stride_N = static_cast<size_t>(bdesc->D() * bdesc->H() * bdesc->W())
                * stride_S;
    } else {
        const size_t C_blks = static_cast<size_t>(get_c_padded(bdesc) / simd_w);

        stride_C = static_cast<size_t>(
                bdesc->D() * bdesc->H() * bdesc->W() * simd_w);
        stride_S = static_cast<size_t>(simd_w);
        stride_N = C_blks * stride_C;
    }

    return std::make_tuple(stride_N, stride_S, stride_C);
}
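
// Worked example (hedged, with assumed shapes): an avx2 blocked (8c) f32
// tensor with N = 2, C = 16, H = W = 4 has simd_w = 8, C_blks = 16 / 8 = 2,
// stride_C = 4 * 4 * 8 = 128, stride_S = 8, and stride_N = 2 * 128 = 256.
// The same tensor in nspc layout gives stride_C = 8, stride_S = C = 16, and
// stride_N = 4 * 4 * 16 = 256. All strides are in elements; callers scale
// them by the data type size.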

#define PARAM_ADDR(x) (reg_param_ + offsetof(call_params_t, x))
template <cpu_isa_t isa>
struct jit_bnorm_process_tail_t {
    using Vmm = typename cpu_isa_traits<isa>::Vmm;

    jit_bnorm_process_tail_t(const batch_normalization_pd_t *bdesc,
            jit_generator *host, Reg64 reg_tmp, Reg64 reg_blk_has_tail,
            Reg64 reg_C, Vmm vtail_mask, Opmask ktail_mask)
        : h_(host)
        , reg_tmp_(reg_tmp)
        , reg_blk_has_tail_(reg_blk_has_tail)
        , reg_C_(reg_C)
        , vtail_mask_(vtail_mask)
        , ktail_mask_(ktail_mask) {
        const memory_desc_wrapper data_d(bdesc->src_md());
        c_is_padded_ = bdesc->C() != data_d.padded_dims()[1];

        const int vlen = isa == sse41 ? 32 : cpu_isa_traits<isa>::vlen;
        tail_ = bdesc->C() % (int)(vlen / sizeof(float));
    }

    jit_generator *const h_;
    const Reg64 reg_tmp_;
    const Reg64 reg_blk_has_tail_;
    const Reg64 reg_C_;
    const Vmm vtail_mask_;
    const Opmask ktail_mask_;
    bool c_is_padded_;
    int tail_;

    void prepare_tail_mask_avx512_common() {
        if (!c_is_padded_) return;

        const int mask = (1 << tail_) - 1;

        Reg32 regw_tmp = reg_tmp_.cvt32();
        h_->mov(regw_tmp, mask);
        h_->kmovw(ktail_mask_, regw_tmp);
    }
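
    // Example (a sketch): with C = 19 on avx512_common (16 f32 lanes),
    // tail_ = 19 % 16 = 3, so mask = (1 << 3) - 1 = 0b111 and ktail_mask_
    // enables only the first three lanes in the masked moves below.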

    void prepare_tail_mask_avx2_common() {
        if (!c_is_padded_) return;

        static const uint32_t mask[16] = {0xffffffff, 0xffffffff, 0xffffffff,
                0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0,
                0, 0, 0, 0, 0, 0, 0};

        h_->mov(reg_tmp_, reinterpret_cast<size_t>(&mask[8 - tail_]));
        h_->vmovups(vtail_mask_, h_->ptr[reg_tmp_]);
    }
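
    // The table above acts as a sliding window: reading 8 dwords starting at
    // &mask[8 - tail_] yields exactly tail_ leading 0xffffffff entries. For
    // instance (a sketch), tail_ = 3 loads {-1, -1, -1, 0, 0, 0, 0, 0}, which
    // vmaskmovps interprets as "touch lanes 0..2 only".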

    void prepare_tail() {
        if (isa == avx512_common)
            prepare_tail_mask_avx512_common();
        else if (isa == avx2)
            prepare_tail_mask_avx2_common();
    }

    void uni_vmovups_tail_avx2_common(
            const Operand &dst, const Operand &src, Label &l_ret) {
        if (dst.isMEM()) {
            h_->vmaskmovps(dst.getAddress(), vtail_mask_, Vmm(src.getIdx()));
        } else {
            h_->vmaskmovps(Vmm(dst.getIdx()), vtail_mask_, src.getAddress());
        }
        h_->jmp(l_ret);
    }

    void uni_vmovups_tail_avx512_common(
            const Operand &dst, const Operand &src, Label &l_ret) {
        if (dst.isMEM())
            h_->uni_vmovups(dst.getAddress() | ktail_mask_ | h_->T_z,
                    Vmm(src.getIdx()));
        else
            h_->uni_vmovups(Vmm(dst.getIdx()) | ktail_mask_ | h_->T_z,
                    src.getAddress());

        h_->jmp(l_ret);
    }

    void uni_vmovups_maybe_tail(const Operand &dst, const Operand &src) {
        Label l_no_mask, l_ret;
        if (c_is_padded_) {
            h_->cmp(reg_blk_has_tail_, 0);
            h_->jz(l_no_mask);

            h_->cmp(reg_C_, 1);
            h_->jne(l_no_mask);
            assert(isa == avx512_common || isa == avx2);
            if (isa == avx512_common)
                uni_vmovups_tail_avx512_common(dst, src, l_ret);
            else if (isa == avx2)
                uni_vmovups_tail_avx2_common(dst, src, l_ret);
        }
        h_->L(l_no_mask);
        if (dst.isMEM())
            h_->uni_vmovups(dst.getAddress(), Vmm(src.getIdx()));
        else
            h_->uni_vmovups(Vmm(dst.getIdx()), src.getAddress());

        h_->L(l_ret);
    }
};

template <cpu_isa_t isa>
struct jit_bnorm_process_relu_t {
    using Vmm = typename cpu_isa_traits<isa>::Vmm;

    jit_bnorm_process_relu_t(const batch_normalization_pd_t *bdesc,
            jit_generator *host, Reg64 reg_off_dat, Reg64 reg_tmp,
            Reg64 reg_ptr_ws, Vmm vzero, Vmm vstore_mask, Opmask kstore_mask)
        : h_(host)
        , reg_off_dat_(reg_off_dat)
        , reg_tmp_(reg_tmp)
        , reg_ptr_ws_(reg_ptr_ws)
        , vzero_(vzero)
        , vstore_mask_(vstore_mask)
        , kstore_mask_(kstore_mask) {
        with_relu_ = bdesc->with_relu_post_op() || bdesc->fuse_norm_relu();
        with_relu_inf_only_ = with_relu_
                && !(bdesc->fuse_norm_relu() && bdesc->is_training());

        bit_shift_ = static_cast<int>(log2(bits_per_byte
                * types::data_type_size(bdesc->desc()->data_desc.data_type)));
    }
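
    // bit_shift_ turns a data-space byte offset into a workspace byte index:
    // the ReLU workspace keeps one bit per element, so (a hedged sketch) for
    // f32 data bit_shift_ = log2(8 * 4) = 5 and reg_off_dat_ >> 5 selects the
    // workspace byte holding the 8 mask bits of the current 8-element group;
    // for bf16 the shift is log2(8 * 2) = 4.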

    jit_generator *const h_;
    const Reg64 reg_off_dat_;
    const Reg64 reg_tmp_;
    const Reg64 reg_ptr_ws_;
    const Vmm vzero_, vstore_mask_;
    const Opmask kstore_mask_;
    Label l_relu_mask_avx2_;
    bool with_relu_, with_relu_inf_only_;
    int bit_shift_;

    void fwd_prepare_relu() {
        if (with_relu_) { h_->uni_vpxor(vzero_, vzero_, vzero_); }
    }

    void bwd_prepare_relu() {
        if (with_relu_) {
            h_->uni_vpxor(vzero_, vzero_, vzero_);
            if (isa == avx2) prepare_l_relu_mask_avx2();
        }
    }

    void prepare_l_relu_mask_avx2() {
        Label l_mask_after;
        h_->jmp(l_mask_after);
        h_->align(32);
        h_->L(l_relu_mask_avx2_); /* [0x80 0x40 0x20 0x10 0x08 0x04 0x02 0x01] */
        for (int i = 0; i < 8; ++i)
            h_->dd(1 << i);
        h_->L(l_mask_after);
    }

    void fwd_process_relu(Vmm v, const int off = 0) {
        if (with_relu_inf_only_) {
            h_->uni_vmaxps(v, v, vzero_);
        } else if (with_relu_) {
            if (isa == avx512_common)
                fwd_process_relu_avx512_common(v, off);
            else if (isa == avx2)
                fwd_process_relu_avx2(v, off);
            else
                assert(false);
        }
    }

    void bwd_process_relu(Vmm v, const int off = 0) {
        if (with_relu_) {
            if (isa == avx512_common)
                bwd_process_relu_avx512_common(v, off);
            else if (isa == avx2)
                bwd_process_relu_avx2(v, off);
            else
                assert(false);
        }
    }

    void fwd_process_relu_avx2(Vmm vdst, const int off = 0) {
        Reg64 reg_store_mask = reg_tmp_;
        h_->shr(reg_off_dat_, bit_shift_);
        h_->vcmpps(vstore_mask_, vzero_, vdst, jit_generator::_cmp_lt_os);
        h_->vmovmskps(reg_store_mask, vstore_mask_);
        h_->mov(h_->ptr[reg_ptr_ws_ + reg_off_dat_ + off],
                reg_store_mask.cvt8());
        h_->vblendvps(vdst, vzero_, vdst, vstore_mask_);
        h_->shl(reg_off_dat_, bit_shift_);
    }

    void fwd_process_relu_avx512_common(Vmm vdst, const int off = 0) {
        h_->shr(reg_off_dat_, bit_shift_);
        h_->vcmpps(kstore_mask_, vzero_, vdst, jit_generator::_cmp_lt_os);
        h_->kmovw(h_->ptr[reg_ptr_ws_ + reg_off_dat_ + off], kstore_mask_);
        h_->vblendmps(vdst | kstore_mask_, vzero_, vdst);
        h_->shl(reg_off_dat_, bit_shift_);
    }

    void bwd_process_relu_avx2(Vmm vdiff_dst, const int off = 0) {
        h_->shr(reg_off_dat_, bit_shift_);
        h_->vpbroadcastb(
                vstore_mask_, h_->ptr[reg_ptr_ws_ + reg_off_dat_ + off]);
        h_->vpand(vstore_mask_, vstore_mask_,
                h_->ptr[Xbyak::util::rip + l_relu_mask_avx2_]);
        h_->vpcmpeqd(vstore_mask_, vstore_mask_,
                h_->ptr[Xbyak::util::rip + l_relu_mask_avx2_]);
        h_->vblendvps(vdiff_dst, vzero_, vdiff_dst, vstore_mask_);
        h_->shl(reg_off_dat_, bit_shift_);
    }
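
    // The avx2 backward path re-expands the packed per-element mask bits:
    // vpbroadcastb copies the workspace byte into every lane, vpand against
    // the [0x80 ... 0x01] table isolates bit i in lane i, and vpcmpeqd turns
    // a set bit into an all-ones lane mask for vblendvps. As a sketch, a
    // workspace byte of 0b00000101 keeps lanes 0 and 2 of vdiff_dst and
    // zeroes the rest.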

    void bwd_process_relu_avx512_common(Vmm vdiff_dst, const int off = 0) {
        h_->shr(reg_off_dat_, bit_shift_);
        h_->kmovw(kstore_mask_, h_->ptr[reg_ptr_ws_ + reg_off_dat_ + off]);
        h_->vmovups(vdiff_dst | kstore_mask_ | h_->T_z, vdiff_dst);
        h_->shl(reg_off_dat_, bit_shift_);
    }
};

template <cpu_isa_t isa>
struct jit_bnorm_bf16_emulation_t {
    using Vmm = typename cpu_isa_traits<isa>::Vmm;

    jit_bnorm_bf16_emulation_t(const batch_normalization_pd_t *bdesc,
            jit_generator *host, Zmm zmm_reserved_1, Zmm zmm_reserved_2,
            Zmm zmm_reserved_3, Zmm zmm_reserved_4, Reg64 reg_tmp)
        : h_(host), bf16_emu_(nullptr) {
        is_bf16_ = bdesc->desc()->data_desc.data_type == data_type::bf16;
        if (is_bf16_ && !mayiuse(avx512_core_bf16)) {
            bf16_emu_ = utils::make_unique<bf16_emulation_t>(h_, zmm_reserved_1,
                    zmm_reserved_2, zmm_reserved_3, reg_tmp, zmm_reserved_4,
                    zmm_reserved_4);
            bf16_emu_->init_vcvtneps2bf16();
        }
    }

    jit_generator *const h_;
    std::unique_ptr<bf16_emulation_t> bf16_emu_;
    bool is_bf16_;

    void uni_vmovups_data(const Operand &dst, const Operand &src) {
        if (dst.isMEM()) {
            if (is_bf16_) {
                constexpr bool isAvx2 = isa == avx2;
                const typename std::conditional<isAvx2, Xmm, Ymm>::type
                        dst_reg {src.getIdx()};
                const typename std::conditional<isAvx2, Ymm, Zmm>::type
                        src_reg {src.getIdx()};

                // convert f32 output to bf16
                if (!bf16_emu_)
                    h_->vcvtneps2bf16(dst_reg, src_reg);
                else
                    bf16_emu_->vcvtneps2bf16(dst_reg, src_reg);

                h_->vmovdqu16(dst.getAddress(), dst_reg);
            } else {
                h_->uni_vmovups(dst.getAddress(), Vmm(src.getIdx()));
            }
        } else {
            if (is_bf16_) {
                // convert bf16 input to f32
                h_->vpmovzxwd(Vmm(dst.getIdx()), src.getAddress());
                h_->vpslld(Vmm(dst.getIdx()), Vmm(dst.getIdx()), 0x10);
            } else {
                h_->uni_vmovups(Vmm(dst.getIdx()), src.getAddress());
            }
        }
    }
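
    // The load path works because bf16 is the upper half of an f32:
    // zero-extend each 16-bit word to 32 bits (vpmovzxwd), then shift left by
    // 16 (vpslld) to restore the sign/exponent/mantissa positions. A scalar
    // sketch: f32_bits = (uint32_t)bf16_bits << 16, so 0x3f80 becomes
    // 0x3f800000 == 1.0f.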

private:
    DNNL_DISALLOW_COPY_AND_ASSIGN(jit_bnorm_bf16_emulation_t);
};

template <cpu_isa_t isa>
struct jit_bnorm_fwd_statistics_t : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_fwd_statistics_t)
    using Vmm = typename cpu_isa_traits<isa>::Vmm;

    const AddressFrame &vmmword
            = (isa == sse41) ? xword : (isa == avx2) ? yword : zword;

    struct call_params_t {
        size_t N, C, S;
        const void *src;
        const acc_data_t *mean;
        const acc_data_t *var;
        size_t blk_has_tail;
        size_t do_normalise;
    };

    const Reg64 &reg_param_ = abi_param1;
    const Reg64 &reg_tmp_ = abi_not_param1;
    const Reg64 &reg_N_ = rsi;
    const Reg64 &reg_S_ = rax;
    const Reg64 &reg_C_ = rdx;
    const Reg64 &reg_off_c_ = rbx;
    const Reg64 &reg_blk_has_tail_ = rbp;

    const Reg64 &reg_off_dat_ = r8;
    const Reg64 &reg_off_dat_save_ = r9;
    const Reg64 &reg_ptr_mean_ = r10;
    const Reg64 &reg_ptr_var_ = r11;
    const Reg64 &reg_ptr_src_ = r12;
    const Reg64 &reg_do_normalise_ = r13;
    const Reg64 &reg_ptr_stat_ = r14;

    const Vmm v_ = Vmm(0);
    const Vmm vtmp_ = Vmm(1);
    const Vmm vtail_mask_ = Vmm(2);
    const Vmm vNS_ = Vmm(3);
    const Vmm vzero_ = Vmm(4);
    // When the variance is computed, two vmms (one for the variance and one
    // for the mean) are needed to unroll one C block at any moment, so the
    // number of registers used for unrolling must be divisible by two.
    static constexpr int min_idx_to_unroll_ = 4;
    static constexpr int max_idx_to_unroll_ = isa == avx512_common ? 28 : 16;
    static constexpr int number_of_vmms_to_unrolling_variables_
            = max_idx_to_unroll_ - min_idx_to_unroll_;
    static_assert(number_of_vmms_to_unrolling_variables_ % 2 == 0
                    && number_of_vmms_to_unrolling_variables_ != 0,
            "Number of registers used for unrolling must be divisible by 2.");
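
    // Register budget sketch: vmms 0..3 are reserved for v_, vtmp_,
    // vtail_mask_, and vNS_ (vzero_ reuses Vmm(4), which is only live during
    // zeroise()), so unrolling uses vmms 4..27 on avx512_common, i.e. up to
    // 24 C blocks for the mean pass or 12 mean/variance pairs, and vmms 4..15
    // otherwise, i.e. 12 or 6 C blocks respectively.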

    const Opmask &ktail_mask_ = k2;

    const batch_normalization_pd_t *bdesc_;
    const jit_memory_tag_kind_t tag_kind_;
    const int vlen;
    const int simd_w;
    jit_bnorm_process_tail_t<isa> jit_tail_;
    jit_bnorm_bf16_emulation_t<isa> jit_bf16_emu_;
    int stride_N_, stride_S_, stride_C_;
    size_t data_type_size_, acc_type_size_;

    void load_common_params() {
#define PARAM_PTR(x) ptr[PARAM_ADDR(x)]
        mov(reg_ptr_src_, PARAM_PTR(src));
        mov(reg_ptr_mean_, PARAM_PTR(mean));
        mov(reg_ptr_var_, PARAM_PTR(var));
#undef PARAM_PTR
        mov(reg_blk_has_tail_, dword[PARAM_ADDR(blk_has_tail)]);
        mov(reg_do_normalise_, dword[PARAM_ADDR(do_normalise)]);
    }

    void zeroise() {
        Label label_zeroise;
        xor_(reg_off_c_, reg_off_c_);
        uni_vpxor(vzero_, vzero_, vzero_);
        mov(reg_C_, dword[PARAM_ADDR(C)]);
        L(label_zeroise);
        {
            jit_tail_.uni_vmovups_maybe_tail(
                    vmmword[reg_ptr_stat_ + reg_off_c_], vzero_);
            if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) {
                jit_tail_.uni_vmovups_maybe_tail(
                        vmmword[reg_ptr_stat_ + reg_off_c_ + vlen / 2], vzero_);
            }
            add(reg_off_c_, simd_w * acc_type_size_);
            dec(reg_C_);
            jnz(label_zeroise);
        }
    }

    void load_stat(bool compute_mean, const int c_blks_to_unroll = 1) {
        int start_idx = min_idx_to_unroll_;
        int end_idx = c_blks_to_unroll + min_idx_to_unroll_;
        const int step = simd_w * acc_type_size_;

        // load the mean or the variance
        for (int idx = start_idx, off = 0; idx < end_idx; idx++, off += step) {
            const Vmm vstat = Vmm(idx);
            jit_tail_.uni_vmovups_maybe_tail(
                    vstat, vmmword[reg_ptr_stat_ + reg_off_c_ + off]);
        }

        // if the variance is computed, the mean is needed as well
        if (!compute_mean) {
            start_idx = min_idx_to_unroll_ + c_blks_to_unroll;
            end_idx = min_idx_to_unroll_ + 2 * c_blks_to_unroll;

            for (int idx = start_idx, off = 0; idx < end_idx;
                    idx++, off += step) {
                const Vmm vmean = Vmm(idx);
                jit_tail_.uni_vmovups_maybe_tail(
                        vmean, vmmword[reg_ptr_mean_ + reg_off_c_ + off]);
            }
        }
    }

    void compute_stat(bool compute_mean, const int c_blks_to_unroll = 1) {
        const int start_idx = min_idx_to_unroll_;
        const int end_idx = c_blks_to_unroll + min_idx_to_unroll_;
        const int step = simd_w * data_type_size_;

        for (int idx = start_idx, off = 0; idx < end_idx; idx++, off += step) {
            const Vmm vstat = Vmm(idx);

            jit_bf16_emu_.uni_vmovups_data(
                    v_, vmmword[reg_ptr_src_ + reg_off_dat_ + off]);

            if (compute_mean) {
                uni_vaddps(vstat, vstat, v_);
            } else {
                const Vmm vmean = Vmm(idx + c_blks_to_unroll);

                // var += (src - mean)^2
                uni_vsubps(vtmp_, v_, vmean, vtmp_);
                uni_vfmadd231ps(vstat, vtmp_, vtmp_);
            }
        }
    }

    void store_stat(const int c_blks_to_unroll = 1) {
        const int start_idx = min_idx_to_unroll_;
        const int end_idx = c_blks_to_unroll + min_idx_to_unroll_;
        const int step = simd_w * acc_type_size_;

        for (int idx = start_idx, off = 0; idx < end_idx; idx++, off += step) {
            const Vmm vstat = Vmm(idx);

            jit_tail_.uni_vmovups_maybe_tail(
                    vmmword[reg_ptr_stat_ + reg_off_c_ + off], vstat);
        }
    }

    void compute_blocked(bool compute_mean) {
        Label label_C, label_S;
        mov(reg_C_, dword[PARAM_ADDR(C)]);
        L(label_C);
        {
            mov(reg_off_dat_, reg_off_dat_save_);

            load_stat(compute_mean);

            mov(reg_S_, dword[PARAM_ADDR(S)]);
            L(label_S);
            {
                compute_stat(compute_mean);

                add(reg_off_dat_, stride_S_ * data_type_size_);

                dec(reg_S_);
                jnz(label_S);
            }

            store_stat();

            add(reg_off_dat_save_, stride_C_ * data_type_size_);
            add(reg_off_c_, simd_w * acc_type_size_);

            dec(reg_C_);
            jnz(label_C);
        }
    }

    void compute_nspc(bool compute_mean) {
        mov(reg_C_, dword[PARAM_ADDR(C)]);

        // When the variance is computed, two values are unrolled per C block
        // (the mean and the variance), so
        // number_of_vmms_to_unrolling_variables_ is divided by 2.
        const int max_of_unrolled_c_blks = compute_mean
                ? number_of_vmms_to_unrolling_variables_
                : number_of_vmms_to_unrolling_variables_ / 2;
        std::vector<Label> c_unroll_label(max_of_unrolled_c_blks + 1);

        for (int c_blks_to_unroll = max_of_unrolled_c_blks;
                c_blks_to_unroll > 0; --c_blks_to_unroll) {
            L(c_unroll_label[c_blks_to_unroll]);
            {
                cmp(reg_C_, c_blks_to_unroll);
                jl(c_unroll_label[c_blks_to_unroll - 1], T_NEAR);

                mov(reg_off_dat_, reg_off_dat_save_);

                load_stat(compute_mean, c_blks_to_unroll);

                Label label_S;
                mov(reg_S_, dword[PARAM_ADDR(S)]);
                L(label_S);
                {
                    compute_stat(compute_mean, c_blks_to_unroll);

                    add(reg_off_dat_, stride_S_ * data_type_size_);

                    dec(reg_S_);
                    jnz(label_S);
                }

                store_stat(c_blks_to_unroll);

                add(reg_off_c_, c_blks_to_unroll * simd_w * acc_type_size_);
                add(reg_off_dat_save_,
                        c_blks_to_unroll * stride_C_ * data_type_size_);

                sub(reg_C_, c_blks_to_unroll);
                jmp(c_unroll_label[c_blks_to_unroll], T_NEAR);
            }
        }
        L(c_unroll_label[0]);
    }
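
    // The emitted control flow is a descending dispatch chain: at run time
    // the kernel consumes as many full max-width groups of C blocks as it
    // can, then falls through to progressively narrower unroll widths until
    // reg_C_ reaches zero. As a sketch, with a 24-block maximum and C = 53
    // blocks, the 24-wide body runs twice and the 5-wide body once.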

    void compute(bool compute_mean) {
        Label label_N;
        mov(reg_N_, dword[PARAM_ADDR(N)]);
        L(label_N);
        {
            xor_(reg_off_dat_save_, reg_off_dat_save_);
            xor_(reg_off_c_, reg_off_c_);

            tag_kind_ == jit_memory_tag_kind_t::nspc
                    ? compute_nspc(compute_mean)
                    : compute_blocked(compute_mean);

            if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) {
                xor_(reg_off_dat_save_, reg_off_dat_save_);
                xor_(reg_off_c_, reg_off_c_);
                add(reg_off_dat_save_, vlen / 2);
                add(reg_off_c_, vlen / 2);

                compute_blocked(compute_mean);
            }

            add(reg_ptr_src_, stride_N_ * data_type_size_);
            dec(reg_N_);
            jnz(label_N);
        }
    }

    void normalize() {
        Label label_ret, label_normalise;
        cmp(reg_do_normalise_, 0);
        jz(label_ret);

        const int S = bdesc_->D() * bdesc_->H() * bdesc_->W();
        mov(reg_tmp_, float2int(bdesc_->MB() * S));
        Xmm xtmp = Xmm(vtmp_.getIdx());
        uni_vmovq(xtmp, reg_tmp_);
        uni_vbroadcastss(vNS_, xtmp);

        xor_(reg_off_c_, reg_off_c_);
        mov(reg_C_, dword[PARAM_ADDR(C)]);
        L(label_normalise);
        {
            jit_tail_.uni_vmovups_maybe_tail(
                    v_, vmmword[reg_ptr_stat_ + reg_off_c_]);
            uni_vdivps(v_, v_, vNS_);
            jit_tail_.uni_vmovups_maybe_tail(
                    vmmword[reg_ptr_stat_ + reg_off_c_], v_);

            if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) {
                jit_tail_.uni_vmovups_maybe_tail(
                        v_, vmmword[reg_ptr_stat_ + reg_off_c_ + vlen / 2]);
                uni_vdivps(v_, v_, vNS_);
                jit_tail_.uni_vmovups_maybe_tail(
                        vmmword[reg_ptr_stat_ + reg_off_c_ + vlen / 2], v_);
            }

            add(reg_off_c_, simd_w * acc_type_size_);
            dec(reg_C_);
            jnz(label_normalise);
        }

        L(label_ret);
    }
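
    // normalize() divides the accumulated per-channel sums by
    // NS = MB * D * H * W, i.e. mean[c] = sum(src[:, c, :]) / NS and
    // var[c] = sum((src[:, c, :] - mean[c])^2) / NS. The do_normalise flag
    // lets a caller skip the division, presumably so partial sums can be
    // accumulated over several calls before the final one normalizes.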

    jit_bnorm_fwd_statistics_t(const batch_normalization_pd_t *bdesc,
            const jit_memory_tag_kind_t tag_kind)
        : bdesc_(bdesc)
        , tag_kind_(tag_kind)
        , vlen(get_vlen<isa>(tag_kind))
        , simd_w(get_simd_w<isa>(tag_kind))
        , jit_tail_(bdesc, this, reg_tmp_, reg_blk_has_tail_, reg_C_,
                  vtail_mask_, ktail_mask_)
        , jit_bf16_emu_(bdesc, this, zmm28, zmm29, zmm30, zmm31, reg_tmp_) {
        static_assert(isa == sse41 || isa == avx2 || isa == avx512_common,
                "unsupported isa");

        std::tie(stride_N_, stride_S_, stride_C_)
                = get_data_strides<isa>(bdesc_, tag_kind);

        data_type_size_
                = types::data_type_size(bdesc->desc()->data_desc.data_type);
        acc_type_size_ = sizeof(acc_data_t);
    }
};

template <cpu_isa_t isa>
struct jit_bnorm_fwd_mean_t : jit_bnorm_fwd_statistics_t<isa> {
    using call_params_t =
            typename jit_bnorm_fwd_statistics_t<isa>::call_params_t;

    jit_bnorm_fwd_mean_t(const batch_normalization_pd_t *bdesc,
            const jit_memory_tag_kind_t tag_kind)
        : jit_bnorm_fwd_statistics_t<isa>(bdesc, tag_kind) {}

    void generate() override {
        this->preamble();
        this->load_common_params();
        this->mov(this->reg_ptr_stat_, this->reg_ptr_mean_);
        this->jit_tail_.prepare_tail();
        this->zeroise();
        this->compute(true);
        this->normalize();
        this->postamble();
    }
};

template <cpu_isa_t isa>
struct jit_bnorm_fwd_var_t : jit_bnorm_fwd_statistics_t<isa> {
    using call_params_t =
            typename jit_bnorm_fwd_statistics_t<isa>::call_params_t;

    jit_bnorm_fwd_var_t(const batch_normalization_pd_t *bdesc,
            const jit_memory_tag_kind_t tag_kind)
        : jit_bnorm_fwd_statistics_t<isa>(bdesc, tag_kind) {}

    void generate() override {
        this->preamble();
        this->load_common_params();
        this->mov(this->reg_ptr_stat_, this->reg_ptr_var_);
        this->jit_tail_.prepare_tail();
        this->zeroise();
        this->compute(false);
        this->normalize();
        this->postamble();
    }
};

template <cpu_isa_t isa>
struct jit_bnorm_fwd_t : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_fwd_t)
    using Vmm = typename cpu_isa_traits<isa>::Vmm;

    const AddressFrame &vmmword
            = (isa == sse41) ? xword : (isa == avx2) ? yword : zword;

    struct call_params_t {
        size_t N, C, S;
        const void *src, *dst;
        const uint8_t *ws;
        const acc_data_t *mean, *var;
        const acc_data_t *scale, *shift;
        size_t blk_has_tail;
    };

    const Reg64 &reg_param_ = abi_param1;
    const Reg64 &reg_tmp_ = abi_not_param1;
    const Reg64 &reg_N_ = rsi;
    const Reg64 &reg_S_ = rax;
    const Reg64 &reg_C_ = rdx;
    const Reg64 &reg_off_c_ = rbx;
    const Reg64 &reg_blk_has_tail_ = rbp;

    const Reg64 &reg_off_dat_ = r8;
    const Reg64 &reg_off_dat_save_ = r9;
    const Reg64 &reg_ptr_ws_ = r10;
    const Reg64 &reg_ptr_scale_ = r11;
    const Reg64 &reg_ptr_shift_ = reg_N_;
    const Reg64 &reg_ptr_var_ = r12;
    const Reg64 &reg_ptr_mean_ = r13;
    const Reg64 &reg_ptr_dst_ = r14;
    const Reg64 &reg_ptr_src_ = r15;

    const Vmm vzero_ = Vmm(0);
    const Vmm vone_ = Vmm(1);
    const Vmm vmean_ = Vmm(2);
    const Vmm vvar_ = Vmm(3);
    const Vmm vsqrtvar_ = Vmm(4);
    const Vmm vgamma_ = Vmm(5);
    const Vmm vbeta_ = Vmm(6);
    const Vmm veps_ = Vmm(7);
    const Vmm vtmp_ = Vmm(8);
    const Vmm v_ = Vmm(9);
    const Vmm vtail_mask_ = Vmm(10);
    const Vmm vstore_mask_ = vtmp_;

    const Opmask &kstore_mask_ = k1;
    const Opmask &ktail_mask_ = k2;

    const batch_normalization_pd_t *bdesc_;
    const jit_memory_tag_kind_t tag_kind_;
    const int vlen;
    const int simd_w;
    jit_bnorm_process_tail_t<isa> jit_tail_;
    jit_bnorm_process_relu_t<isa> jit_relu_;
    jit_bnorm_bf16_emulation_t<isa> jit_bf16_emu_;
    int stride_N_, stride_S_, stride_C_;
    size_t data_type_size_, acc_type_size_;

    enum {
        stack_off_N = 0,
        stack_off_shift = 8,
        stack_size_required = 16,
    };
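
    // reg_N_ (rsi) doubles as reg_ptr_shift_, so both values are kept in a
    // 16-byte stack scratch area: [rsp + stack_off_N] holds the remaining N
    // count and [rsp + stack_off_shift] the shift pointer; compute() swaps
    // them in and out around each iteration of the N loop.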

    void load_common_params() {
#define PARAM_PTR(x) ptr[PARAM_ADDR(x)]
        mov(reg_ptr_src_, PARAM_PTR(src));
        mov(reg_ptr_dst_, PARAM_PTR(dst));
        mov(reg_ptr_mean_, PARAM_PTR(mean));
        mov(reg_ptr_var_, PARAM_PTR(var));
        mov(reg_ptr_scale_, PARAM_PTR(scale));
        mov(reg_ptr_ws_, PARAM_PTR(ws));

        Xmm x = Xmm(v_.getIdx());

        mov(reg_tmp_, float2int(bdesc_->desc()->batch_norm_epsilon));
        uni_vmovq(x, reg_tmp_);
        uni_vbroadcastss(veps_, x);

        mov(reg_tmp_, float2int(1.f));
        uni_vmovq(x, reg_tmp_);
        uni_vbroadcastss(vone_, x);

        mov(reg_blk_has_tail_, dword[PARAM_ADDR(blk_has_tail)]);

        mov(reg_tmp_, PARAM_PTR(shift));
        mov(ptr[rsp + stack_off_shift], reg_tmp_);
        mov(reg_tmp_, PARAM_PTR(N));
        mov(ptr[rsp + stack_off_N], reg_tmp_);
#undef PARAM_PTR
    }

    void load_c_specifics() {
        jit_tail_.uni_vmovups_maybe_tail(
                vmean_, vmmword[reg_ptr_mean_ + reg_off_c_]);
        jit_tail_.uni_vmovups_maybe_tail(
                vvar_, vmmword[reg_ptr_var_ + reg_off_c_]);

        uni_vmovups(vsqrtvar_, vvar_);
        uni_vaddps(vsqrtvar_, vsqrtvar_, veps_);
        uni_vsqrtps(vsqrtvar_, vsqrtvar_);

        if (isa == sse41) {
            movups(vtmp_, vone_);
            divps(vtmp_, vsqrtvar_);
            movups(vsqrtvar_, vtmp_);
        } else
            vdivps(vsqrtvar_, vone_, vsqrtvar_);

        if (bdesc_->use_scaleshift() || bdesc_->use_scale())
            jit_tail_.uni_vmovups_maybe_tail(
                    vgamma_, vmmword[reg_ptr_scale_ + reg_off_c_]);
        if (bdesc_->use_scaleshift() || bdesc_->use_shift())
            jit_tail_.uni_vmovups_maybe_tail(
                    vbeta_, vmmword[reg_ptr_shift_ + reg_off_c_]);
    }

    void compute_bnorm(bool stream_store_allowed) {
        jit_bf16_emu_.uni_vmovups_data(
                v_, vmmword[reg_ptr_src_ + reg_off_dat_]);
        uni_vsubps(v_, v_, vmean_);
        uni_vmulps(v_, v_, vsqrtvar_);

        if (bdesc_->use_scaleshift()
                || (bdesc_->use_scale() && bdesc_->use_shift()))
            uni_vfmadd213ps(v_, vgamma_, vbeta_);
        else if (bdesc_->use_scale())
            uni_vmulps(v_, v_, vgamma_);
        else if (bdesc_->use_shift())
            uni_vaddps(v_, v_, vbeta_);

        jit_relu_.fwd_process_relu(v_);

        if (stream_store_allowed) {
            uni_vmovntps(vmmword[reg_ptr_dst_ + reg_off_dat_], v_);
        } else {
            jit_bf16_emu_.uni_vmovups_data(
                    vmmword[reg_ptr_dst_ + reg_off_dat_], v_);
        }
    }
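
    // After load_c_specifics(), vsqrtvar_ actually holds 1 / sqrt(var + eps),
    // so the sequence above computes the usual forward formula
    //   dst = gamma * (src - mean) / sqrt(var + eps) + beta
    // with one subtract, one multiply, and one fma per vector.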

    void compute_blocked(bool stream_store_allowed) {
        Label label_C, label_S;
        mov(reg_C_, dword[PARAM_ADDR(C)]);
        L(label_C);
        {
            mov(reg_off_dat_, reg_off_dat_save_);

            load_c_specifics();

            mov(reg_S_, dword[PARAM_ADDR(S)]);
            L(label_S);
            {
                compute_bnorm(stream_store_allowed);

                add(reg_off_dat_, stride_S_ * data_type_size_);

                dec(reg_S_);
                jnz(label_S);
            }

            add(reg_off_dat_save_, stride_C_ * data_type_size_);
            add(reg_off_c_, simd_w * acc_type_size_);

            dec(reg_C_);
            jnz(label_C);
        }
    }

    void compute_nspc(bool stream_store_allowed) {
        Label label_C, label_S;
        mov(reg_S_, dword[PARAM_ADDR(S)]);
        L(label_S);
        {
            mov(reg_off_dat_, reg_off_dat_save_);
            xor_(reg_off_c_, reg_off_c_);

            mov(reg_C_, dword[PARAM_ADDR(C)]);
            L(label_C);
            {
                load_c_specifics();

                compute_bnorm(stream_store_allowed);

                add(reg_off_c_, simd_w * acc_type_size_);
                add(reg_off_dat_, stride_C_ * data_type_size_);

                dec(reg_C_);
                jnz(label_C);
            }

            add(reg_off_dat_save_, stride_S_ * data_type_size_);

            dec(reg_S_);
            jnz(label_S);
        }
    }

    void compute(bool stream_store_allowed) {
        Label label_N;
        mov(reg_N_, ptr[rsp + stack_off_N]);
        L(label_N);
        {
            // save reg_N_, because the register is shared with
            // reg_ptr_shift_
            mov(ptr[rsp + stack_off_N], reg_N_);
            mov(reg_ptr_shift_, ptr[rsp + stack_off_shift]);

            xor_(reg_off_dat_save_, reg_off_dat_save_);
            xor_(reg_off_c_, reg_off_c_);

            tag_kind_ == jit_memory_tag_kind_t::nspc
                    ? compute_nspc(stream_store_allowed)
                    : compute_blocked(stream_store_allowed);

            if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) {
                xor_(reg_off_dat_save_, reg_off_dat_save_);
                xor_(reg_off_c_, reg_off_c_);
                add(reg_off_dat_save_, vlen / 2);
                add(reg_off_c_, vlen / 2);

                compute_blocked(stream_store_allowed);
            }

            add(reg_ptr_src_, stride_N_ * data_type_size_);
            add(reg_ptr_dst_, stride_N_ * data_type_size_);
            add(reg_ptr_ws_, stride_N_ / bits_per_byte);

            // restore reg_N_, because the register is shared with
            // reg_ptr_shift_
            mov(reg_N_, ptr[rsp + stack_off_N]);
            dec(reg_N_);
            jnz(label_N);
        }
    }

    jit_bnorm_fwd_t(const batch_normalization_pd_t *bdesc,
            const jit_memory_tag_kind_t tag_kind)
        : bdesc_(bdesc)
        , tag_kind_(tag_kind)
        , vlen(get_vlen<isa>(tag_kind))
        , simd_w(get_simd_w<isa>(tag_kind))
        , jit_tail_(bdesc, this, reg_tmp_, reg_blk_has_tail_, reg_C_,
                  vtail_mask_, ktail_mask_)
        , jit_relu_(bdesc, this, reg_off_dat_, reg_tmp_, reg_ptr_ws_, vzero_,
                  vstore_mask_, kstore_mask_)
        , jit_bf16_emu_(bdesc, this, zmm28, zmm29, zmm30, zmm31, reg_tmp_) {
        static_assert(isa == sse41 || isa == avx2 || isa == avx512_common,
                "unsupported isa");

        std::tie(stride_N_, stride_S_, stride_C_)
                = get_data_strides<isa>(bdesc_, tag_kind);

        data_type_size_
                = types::data_type_size(bdesc->desc()->data_desc.data_type);
        acc_type_size_ = sizeof(acc_data_t);
    }

    void generate() override {
        bool is_bf16 = bdesc_->desc()->data_desc.data_type == data_type::bf16;
        const bool is_tail_in_nspc_format
                = tag_kind_ == jit_memory_tag_kind_t::nspc
                && jit_tail_.tail_ != 0;
        const bool stream_store_allowed = !is_bf16 && !is_tail_in_nspc_format;

        preamble();
        sub(rsp, stack_size_required);
        load_common_params();
        jit_relu_.fwd_prepare_relu();
        jit_tail_.prepare_tail();

        Label normal_store, end_store;
        test(reg_ptr_dst_, vlen - 1);
        jnz(normal_store, T_NEAR);
        compute(stream_store_allowed);
        jmp(end_store, T_NEAR);
        L(normal_store);
        { compute(false); }
        L(end_store);

        add(rsp, stack_size_required);
        postamble();
    }
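
    // Non-temporal (streaming) stores require vector-aligned addresses, so
    // generate() decides at run time: test(reg_ptr_dst_, vlen - 1) checks the
    // low address bits and falls back to compute(false), i.e. regular stores,
    // when dst is not vlen-aligned. bf16 output and nspc blocks with a tail
    // disable streaming altogether via stream_store_allowed.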
};

template <cpu_isa_t isa>
struct jit_bnorm_bwd_t : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_bwd_t)
    using Vmm = typename cpu_isa_traits<isa>::Vmm;

    const AddressFrame &vmmword
            = (isa == sse41) ? xword : (isa == avx2) ? yword : zword;

    struct call_params_t {
        size_t N, C, S;
        const void *src, *diff_src, *diff_dst;
        const uint8_t *ws;
        const acc_data_t *mean, *var;
        const acc_data_t *scale, *diff_scale, *diff_shift;
        size_t blk_has_tail;
    };

    const Reg64 &reg_param_ = abi_param1;
    const Reg64 &reg_tmp_ = abi_not_param1;
    const Reg64 &reg_N_ = rsi;
    const Reg64 &reg_S_ = rax;
    const Reg64 &reg_C_ = rdx;
    const Reg64 &reg_off_c_ = rbx;
    const Reg64 &reg_blk_has_tail_ = rbp;

    const Reg64 &reg_off_dat_ = r8;
    const Reg64 &reg_off_dat_save_ = r9;
    const Reg64 &reg_ptr_c_ = r10;
    const Reg64 &reg_ptr_ws_ = r11;
    const Reg64 &reg_ptr_diff_dst_ = r12;
    const Reg64 &reg_ptr_diff_src_ = r13;
    const Reg64 &reg_ptr_src_ = r14;

    const Vmm vzero_ = Vmm(0);
    const Vmm vone_ = Vmm(1);
    const Vmm vmean_ = Vmm(2);
    const Vmm vsqrtvar_ = Vmm(3);
    const Vmm vgamma_ = Vmm(4);
    const Vmm vdiff_gamma_ = Vmm(5);
    const Vmm vdiff_beta_ = Vmm(6);
    const Vmm veps_ = Vmm(7);
    const Vmm vNS_ = Vmm(8);
    const Vmm vtmp_ = Vmm(9);
    const Vmm v_ = Vmm(10);
    const Vmm vtail_mask_ = Vmm(11);
    const Vmm vstore_mask_ = vtmp_;

    const Opmask &kstore_mask_ = k1;
    const Opmask &ktail_mask_ = k2;

    const batch_normalization_pd_t *bdesc_;
    const jit_memory_tag_kind_t tag_kind_;
    const int vlen;
    const int simd_w;
    jit_bnorm_process_tail_t<isa> jit_tail_;
    jit_bnorm_process_relu_t<isa> jit_relu_;
    jit_bnorm_bf16_emulation_t<isa> jit_bf16_emu_;
    int stride_N_, stride_S_, stride_C_;
    size_t data_type_size_, acc_type_size_;

    void load_common_params() {
#define PARAM_PTR(x) ptr[PARAM_ADDR(x)]
        mov(reg_ptr_src_, PARAM_PTR(src));
        mov(reg_ptr_diff_src_, PARAM_PTR(diff_src));
        mov(reg_ptr_diff_dst_, PARAM_PTR(diff_dst));
        mov(reg_ptr_ws_, PARAM_PTR(ws));
#undef PARAM_PTR

        Xmm x = Xmm(v_.getIdx());

        mov(reg_tmp_, float2int(bdesc_->desc()->batch_norm_epsilon));
        uni_vmovq(x, reg_tmp_);
        uni_vbroadcastss(veps_, x);

        mov(reg_tmp_, float2int(1.f));
        uni_vmovq(x, reg_tmp_);
        uni_vbroadcastss(vone_, x);

        const int S = bdesc_->D() * bdesc_->H() * bdesc_->W();
        mov(reg_tmp_, float2int(bdesc_->MB() * S));
        uni_vmovq(x, reg_tmp_);
        uni_vbroadcastss(vNS_, x);

        mov(reg_blk_has_tail_, dword[PARAM_ADDR(blk_has_tail)]);
    }

    void load_c_specifics() {
        mov(reg_ptr_c_, ptr[PARAM_ADDR(mean)]);
        jit_tail_.uni_vmovups_maybe_tail(
                vmean_, vmmword[reg_ptr_c_ + reg_off_c_]);

        mov(reg_ptr_c_, ptr[PARAM_ADDR(var)]);
        jit_tail_.uni_vmovups_maybe_tail(
                vsqrtvar_, vmmword[reg_ptr_c_ + reg_off_c_]);
        uni_vaddps(vsqrtvar_, vsqrtvar_, veps_);
        uni_vsqrtps(vsqrtvar_, vsqrtvar_);

        if (isa == sse41) {
            movups(vtmp_, vone_);
            divps(vtmp_, vsqrtvar_);
            movups(vsqrtvar_, vtmp_);
        } else
            vdivps(vsqrtvar_, vone_, vsqrtvar_);

        if (bdesc_->use_scaleshift() || bdesc_->use_scale()) {
            mov(reg_ptr_c_, ptr[PARAM_ADDR(scale)]);
            jit_tail_.uni_vmovups_maybe_tail(
                    vgamma_, vmmword[reg_ptr_c_ + reg_off_c_]);
        }

        if (calculate_diff_stats()) {
            mov(reg_ptr_c_, ptr[PARAM_ADDR(diff_scale)]);
            jit_tail_.uni_vmovups_maybe_tail(
                    vdiff_gamma_, vmmword[reg_ptr_c_ + reg_off_c_]);
            uni_vmulps(vdiff_gamma_, vdiff_gamma_, vsqrtvar_);
            uni_vdivps(vdiff_gamma_, vdiff_gamma_, vNS_);
            mov(reg_ptr_c_, ptr[PARAM_ADDR(diff_shift)]);
            jit_tail_.uni_vmovups_maybe_tail(
                    vdiff_beta_, vmmword[reg_ptr_c_ + reg_off_c_]);
            uni_vdivps(vdiff_beta_, vdiff_beta_, vNS_);
        }
    }

    void compute_bnorm(bool stream_store_allowed) {
        jit_bf16_emu_.uni_vmovups_data(
                v_, vmmword[reg_ptr_diff_dst_ + reg_off_dat_]);
        jit_relu_.bwd_process_relu(v_);

        if (calculate_diff_stats()) {
            uni_vsubps(v_, v_, vdiff_beta_);
            jit_bf16_emu_.uni_vmovups_data(
                    vtmp_, vmmword[reg_ptr_src_ + reg_off_dat_]);
            uni_vsubps(vtmp_, vtmp_, vmean_);
            uni_vmulps(vtmp_, vtmp_, vdiff_gamma_);
            uni_vsubps(v_, v_, vtmp_);
        }

        if (bdesc_->use_scaleshift() || bdesc_->use_scale())
            uni_vmulps(v_, v_, vgamma_);
        uni_vmulps(v_, v_, vsqrtvar_);

        if (stream_store_allowed) {
            uni_vmovntps(vmmword[reg_ptr_diff_src_ + reg_off_dat_], v_);
        } else {
            jit_bf16_emu_.uni_vmovups_data(
                    vmmword[reg_ptr_diff_src_ + reg_off_dat_], v_);
        }
    }
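
    // With the precomputed vdiff_beta_ = diff_beta / NS and vdiff_gamma_ =
    // diff_gamma / (NS * sqrt(var + eps)), the sequence above implements the
    // standard backward-by-data formula
    //   diff_src = gamma / sqrt(var + eps)
    //           * (diff_dst - diff_beta / NS
    //              - (src - mean) * diff_gamma / (NS * sqrt(var + eps)))
    // where the two correction terms are skipped when use_global_stats() is
    // set (calculate_diff_stats() == false).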

    void compute_blocked(bool stream_store_allowed) {
        Label label_C, label_S;
        mov(reg_C_, dword[PARAM_ADDR(C)]);
        L(label_C);
        {
            mov(reg_off_dat_, reg_off_dat_save_);

            load_c_specifics();

            mov(reg_S_, dword[PARAM_ADDR(S)]);
            L(label_S);
            {
                compute_bnorm(stream_store_allowed);

                add(reg_off_dat_, stride_S_ * data_type_size_);

                dec(reg_S_);
                jnz(label_S);
            }

            add(reg_off_dat_save_, stride_C_ * data_type_size_);
            add(reg_off_c_, simd_w * acc_type_size_);

            dec(reg_C_);
            jnz(label_C);
        }
    }

    void compute_nspc(bool stream_store_allowed) {
        Label label_C, label_S;
        mov(reg_S_, dword[PARAM_ADDR(S)]);
        L(label_S);
        {
            mov(reg_off_dat_, reg_off_dat_save_);
            xor_(reg_off_c_, reg_off_c_);

            mov(reg_C_, dword[PARAM_ADDR(C)]);
            L(label_C);
            {
                load_c_specifics();

                compute_bnorm(stream_store_allowed);

                add(reg_off_c_, simd_w * acc_type_size_);
                add(reg_off_dat_, stride_C_ * data_type_size_);

                dec(reg_C_);
                jnz(label_C);
            }

            add(reg_off_dat_save_, stride_S_ * data_type_size_);

            dec(reg_S_);
            jnz(label_S);
        }
    }

    void compute(bool stream_store_allowed) {
        Label label_N;
        mov(reg_N_, dword[PARAM_ADDR(N)]);
        L(label_N);
        {
            xor_(reg_off_dat_save_, reg_off_dat_save_);
            xor_(reg_off_c_, reg_off_c_);

            tag_kind_ == jit_memory_tag_kind_t::nspc
                    ? compute_nspc(stream_store_allowed)
                    : compute_blocked(stream_store_allowed);

            if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) {
                xor_(reg_off_dat_save_, reg_off_dat_save_);
                xor_(reg_off_c_, reg_off_c_);
                add(reg_off_dat_save_, vlen / 2);
                add(reg_off_c_, vlen / 2);

                compute_blocked(stream_store_allowed);
            }

            add(reg_ptr_src_, stride_N_ * data_type_size_);
            add(reg_ptr_diff_src_, stride_N_ * data_type_size_);
            add(reg_ptr_diff_dst_, stride_N_ * data_type_size_);
            add(reg_ptr_ws_, stride_N_ / bits_per_byte);

            dec(reg_N_);
            jnz(label_N);
        }
    }

    bool calculate_diff_stats() const { return !bdesc_->use_global_stats(); }

    jit_bnorm_bwd_t(const batch_normalization_pd_t *bdesc,
            const jit_memory_tag_kind_t tag_kind)
        : bdesc_(bdesc)
        , tag_kind_(tag_kind)
        , vlen(get_vlen<isa>(tag_kind))
        , simd_w(get_simd_w<isa>(tag_kind))
        , jit_tail_(bdesc, this, reg_tmp_, reg_blk_has_tail_, reg_C_,
                  vtail_mask_, ktail_mask_)
        , jit_relu_(bdesc, this, reg_off_dat_, reg_tmp_, reg_ptr_ws_, vzero_,
                  vstore_mask_, kstore_mask_)
        , jit_bf16_emu_(bdesc, this, zmm28, zmm29, zmm30, zmm31, reg_tmp_) {
        static_assert(isa == sse41 || isa == avx2 || isa == avx512_common,
                "unsupported isa");

        std::tie(stride_N_, stride_S_, stride_C_)
                = get_data_strides<isa>(bdesc_, tag_kind);

        data_type_size_
                = types::data_type_size(bdesc->desc()->data_desc.data_type);
        acc_type_size_ = sizeof(acc_data_t);
    }

    void generate() override {
        bool is_bf16 = bdesc_->desc()->data_desc.data_type == data_type::bf16;
        const bool is_tail_in_nspc_format
                = tag_kind_ == jit_memory_tag_kind_t::nspc
                && jit_tail_.tail_ != 0;
        const bool stream_store_allowed = !is_bf16 && !is_tail_in_nspc_format;

        preamble();
        load_common_params();
        jit_relu_.bwd_prepare_relu();
        jit_tail_.prepare_tail();

        Label normal_store, end_store;
        test(reg_ptr_diff_src_, vlen - 1);
        jnz(normal_store, T_NEAR);
        compute(stream_store_allowed);
        jmp(end_store, T_NEAR);
        L(normal_store);
        { compute(false); }
        L(end_store);

        postamble();
    }
};

template <cpu_isa_t isa>
struct jit_bnorm_bwd_diff_ss_t : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_bwd_diff_ss_t)
    using Vmm = typename cpu_isa_traits<isa>::Vmm;

    const AddressFrame &vmmword
            = (isa == sse41) ? xword : (isa == avx2) ? yword : zword;

    struct call_params_t {
        size_t N, C, S;
        const void *src, *diff_dst;
        const uint8_t *ws;
        const acc_data_t *mean, *var;
        const acc_data_t *diff_gamma, *diff_beta;
        size_t blk_has_tail;
    };

    const Reg64 &reg_param_ = abi_param1;
    const Reg64 &reg_tmp_ = abi_not_param1;
    const Reg64 &reg_N_ = rsi;
    const Reg64 &reg_S_ = rax;
    const Reg64 &reg_C_ = rdx;
    const Reg64 &reg_off_c_ = rbx;
    const Reg64 &reg_blk_has_tail_ = rbp;

    const Reg64 &reg_off_dat_ = r8;
    const Reg64 &reg_off_dat_save_ = r9;
    const Reg64 &reg_ptr_c_ = r10;
    const Reg64 &reg_ptr_diff_gamma_ = r11;
    const Reg64 &reg_ptr_diff_beta_ = r12;
    const Reg64 &reg_ptr_ws_ = r13;
    const Reg64 &reg_ptr_diff_dst_ = r14;
    const Reg64 &reg_ptr_src_ = r15;

    const Vmm vtail_mask_ = Vmm(0);
    const Vmm v_ = Vmm(1);
    const Vmm vtmp_ = Vmm(2);
    const Vmm vstore_mask_ = vtmp_;
    const Vmm vzero_ = Vmm(3);
    const Vmm veps_ = Vmm(4);
    const Vmm vone_ = Vmm(5);
    // diff_beta, diff_gamma, and one of the statistics (mean or sqrtvar) are
    // unrolled together, i.e. three vmms are needed to unroll one C block at
    // any moment, so the number of registers used for unrolling must be
    // divisible by three.
    static constexpr int min_idx_to_unroll_ = 6;
    static constexpr int max_idx_to_unroll_ = isa == avx512_common ? 27 : 15;
    static constexpr int number_of_unrolled_variables_ = 3;
    static constexpr int number_of_vmms_to_unrolling_variables_
            = max_idx_to_unroll_ - min_idx_to_unroll_;
    static_assert(number_of_vmms_to_unrolling_variables_
                                    % number_of_unrolled_variables_
                            == 0
                    && number_of_vmms_to_unrolling_variables_ != 0,
            "Number of registers used for unrolling must be divisible by 3.");
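
    // Register budget sketch: vmms 0..5 are reserved above, leaving vmms
    // 6..26 on avx512_common (21 registers, i.e. 7 C blocks of
    // {sqrtvar, diff_beta, diff_gamma} triples) and vmms 6..14 otherwise
    // (9 registers, i.e. 3 C blocks).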
1345 
1346     const Opmask &kstore_mask_ = k1;
1347     const Opmask &ktail_mask_ = k2;
1348 
1349     const batch_normalization_pd_t *bdesc_;
1350     const jit_memory_tag_kind_t tag_kind_;
1351     const int vlen;
1352     const int simd_w;
1353     jit_bnorm_process_tail_t<isa> jit_tail_;
1354     jit_bnorm_process_relu_t<isa> jit_relu_;
1355     jit_bnorm_bf16_emulation_t<isa> jit_bf16_emu_;
1356     int stride_N_, stride_S_, stride_C_;
1357     size_t data_type_size_, acc_type_size_;
1358 
load_common_paramsdnnl::impl::cpu::x64::__anon2759f98d0111::jit_bnorm_bwd_diff_ss_t1359     void load_common_params() {
1360 #define PARAM_PTR(x) ptr[PARAM_ADDR(x)]
1361         mov(reg_ptr_src_, PARAM_PTR(src));
1362         mov(reg_ptr_diff_dst_, PARAM_PTR(diff_dst));
1363         mov(reg_ptr_ws_, PARAM_PTR(ws));
1364         mov(reg_ptr_diff_gamma_, PARAM_PTR(diff_gamma));
1365         mov(reg_ptr_diff_beta_, PARAM_PTR(diff_beta));
1366 #undef PARAM_PTR
1367 
1368         Xmm x = Xmm(v_.getIdx());
1369 
1370         mov(reg_tmp_, float2int(bdesc_->desc()->batch_norm_epsilon));
1371         uni_vmovq(x, reg_tmp_);
1372         uni_vbroadcastss(veps_, x);
1373 
1374         mov(reg_tmp_, float2int(1.f));
1375         uni_vmovq(x, reg_tmp_);
1376         uni_vbroadcastss(vone_, x);
1377 
1378         mov(reg_blk_has_tail_, dword[PARAM_ADDR(blk_has_tail)]);
1379     }
1380 
zeroisednnl::impl::cpu::x64::__anon2759f98d0111::jit_bnorm_bwd_diff_ss_t1381     void zeroise() {
1382         Label label_zeroise;
1383         xor_(reg_off_c_, reg_off_c_);
1384         uni_vpxor(vzero_, vzero_, vzero_);
1385         mov(reg_C_, dword[PARAM_ADDR(C)]);
1386         L(label_zeroise);
1387         {
1388             jit_tail_.uni_vmovups_maybe_tail(
1389                     vmmword[reg_ptr_diff_gamma_ + reg_off_c_], vzero_);
1390             jit_tail_.uni_vmovups_maybe_tail(
1391                     vmmword[reg_ptr_diff_beta_ + reg_off_c_], vzero_);
1392             if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) {
1393                 jit_tail_.uni_vmovups_maybe_tail(
1394                         vmmword[reg_ptr_diff_gamma_ + reg_off_c_ + vlen / 2],
1395                         vzero_);
1396                 jit_tail_.uni_vmovups_maybe_tail(
1397                         vmmword[reg_ptr_diff_beta_ + reg_off_c_ + vlen / 2],
1398                         vzero_);
1399             }
1400             add(reg_off_c_, simd_w * acc_type_size_);
1401             dec(reg_C_);
1402             jnz(label_zeroise);
1403         }
1404     }
1405 
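    // Loads the per-channel mean into every third vmm of the unroll pool
    // (indices idx, idx + 3, ...), one vmm per unrolled C block.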
1406     void load_mean(const int c_blks_to_unroll = 1) {
1407         mov(reg_ptr_c_, ptr[PARAM_ADDR(mean)]);
1408 
1409         const int start_idx = min_idx_to_unroll_;
1410         const int end_idx = number_of_unrolled_variables_ * c_blks_to_unroll
1411                 + min_idx_to_unroll_;
1412         const int step = simd_w * acc_type_size_;
1413 
1414         for (int idx = start_idx, off = 0; idx < end_idx;
1415                 idx += number_of_unrolled_variables_, off += step) {
1416             const Vmm vmean = Vmm(idx);
1417 
1418             jit_tail_.uni_vmovups_maybe_tail(
1419                     vmean, vmmword[reg_ptr_c_ + reg_off_c_ + off]);
1420         }
1421     }
1422 
1423     void zeroise_diff_beta_and_diff_gamma(const int c_blks_to_unroll = 1) {
1424         const int start_idx = min_idx_to_unroll_;
1425         const int end_idx = number_of_unrolled_variables_ * c_blks_to_unroll
1426                 + min_idx_to_unroll_;
1427 
1428         for (int idx = start_idx; idx < end_idx;
1429                 idx += number_of_unrolled_variables_) {
1430             const Vmm vdiff_beta = Vmm(idx + 1);
1431             const Vmm vdiff_gamma = Vmm(idx + 2);
1432 
1433             uni_vpxor(vdiff_beta, vdiff_beta, vdiff_beta);
1434             uni_vpxor(vdiff_gamma, vdiff_gamma, vdiff_gamma);
1435         }
1436     }
1437 
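    // Loads the per-channel variance and converts it in place to the
    // inverse standard deviation, vsqrtvar = 1.0f / sqrt(var + eps); the
    // sse41 branch works around the destructive two-operand divps encoding.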
1438     void load_and_prepare_sqrtvar(const int c_blks_to_unroll = 1) {
1439         mov(reg_ptr_c_, ptr[PARAM_ADDR(var)]);
1440 
1441         const int start_idx = min_idx_to_unroll_;
1442         const int end_idx = number_of_unrolled_variables_ * c_blks_to_unroll
1443                 + min_idx_to_unroll_;
1444         const int step = simd_w * acc_type_size_;
1445 
1446         for (int idx = start_idx, off = 0; idx < end_idx;
1447                 idx += number_of_unrolled_variables_, off += step) {
1448             const Vmm vsqrtvar = Vmm(idx);
1449 
1450             jit_tail_.uni_vmovups_maybe_tail(
1451                     vsqrtvar, vmmword[reg_ptr_c_ + reg_off_c_ + off]);
1452 
1453             // 1.0 / sqrt(var + eps)
1454             uni_vaddps(vsqrtvar, vsqrtvar, veps_);
1455             uni_vsqrtps(vsqrtvar, vsqrtvar);
1456 
1457             if (isa == sse41) {
1458                 movups(vtmp_, vone_);
1459                 divps(vtmp_, vsqrtvar);
1460                 movups(vsqrtvar, vtmp_);
1461             } else
1462                 vdivps(vsqrtvar, vone_, vsqrtvar);
1463         }
1464     }
1465 
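    // Accumulates the backward statistics over the spatial dimension:
    //     diff_beta  += diff_dst
    //     diff_gamma += (src - mean) * diff_dst
    // applying the fused-ReLU workspace to diff_dst first when present.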
1466     void compute_diff_beta_and_diff_gamma(const int c_blks_to_unroll = 1) {
1467         const int start_idx = min_idx_to_unroll_;
1468         const int end_idx = number_of_unrolled_variables_ * c_blks_to_unroll
1469                 + min_idx_to_unroll_;
1470         const int step = simd_w * data_type_size_;
1471 
1472         for (int idx = start_idx, off = 0; idx < end_idx;
1473                 idx += number_of_unrolled_variables_, off += step) {
1474             const Vmm vmean = Vmm(idx);
1475             const Vmm vdiff_beta = Vmm(idx + 1);
1476             const Vmm vdiff_gamma = Vmm(idx + 2);
1477 
1478             jit_bf16_emu_.uni_vmovups_data(
1479                     v_, vmmword[reg_ptr_diff_dst_ + reg_off_dat_ + off]);
1480 
1481             jit_relu_.bwd_process_relu(
1482                     v_, off / (bits_per_byte * data_type_size_));
1483 
1484             // diff_beta
1485             uni_vaddps(vdiff_beta, vdiff_beta, v_);
1486 
1487             jit_bf16_emu_.uni_vmovups_data(
1488                     vtmp_, vmmword[reg_ptr_src_ + reg_off_dat_ + off]);
1489 
1490             // diff_gamma; note that diff_gamma is multiplied by
1491             // 1.0 / sqrt(var + eps) just before it is stored
1492             uni_vsubps(vtmp_, vtmp_, vmean);
1493             uni_vfmadd231ps(vdiff_gamma, vtmp_, v_);
1494         }
1495     }
1496 
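    // Read-modify-write accumulation into the output buffers; diff_gamma is
    // scaled by 1.0f / sqrt(var + eps) on the way out so memory always
    // holds the final quantity.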
1497     void store_diff_beta_and_diff_gamma(const int c_blks_to_unroll = 1) {
1498         const int start_idx = min_idx_to_unroll_;
1499         const int end_idx = number_of_unrolled_variables_ * c_blks_to_unroll
1500                 + min_idx_to_unroll_;
1501         const int step = simd_w * acc_type_size_;
1502 
1503         for (int idx = start_idx, off = 0; idx < end_idx;
1504                 idx += number_of_unrolled_variables_, off += step) {
1505             const Vmm vdiff_beta = Vmm(idx + 1);
1506 
1507             jit_tail_.uni_vmovups_maybe_tail(
1508                     vtmp_, vmmword[reg_ptr_diff_beta_ + reg_off_c_ + off]);
1509             uni_vaddps(vdiff_beta, vdiff_beta, vtmp_);
1510             jit_tail_.uni_vmovups_maybe_tail(
1511                     vmmword[reg_ptr_diff_beta_ + reg_off_c_ + off], vdiff_beta);
1512         }
1513 
1514         for (int idx = start_idx, off = 0; idx < end_idx;
1515                 idx += number_of_unrolled_variables_, off += step) {
1516             const Vmm vsqrtvar = Vmm(idx);
1517             const Vmm vdiff_gamma = Vmm(idx + 2);
1518 
1519             // multiply diff_gamma by 1.0/sqrt(var + eps)
1520             uni_vmulps(vdiff_gamma, vdiff_gamma, vsqrtvar);
1521 
1522             jit_tail_.uni_vmovups_maybe_tail(
1523                     vtmp_, vmmword[reg_ptr_diff_gamma_ + reg_off_c_ + off]);
1524             uni_vaddps(vdiff_gamma, vdiff_gamma, vtmp_);
1525             jit_tail_.uni_vmovups_maybe_tail(
1526                     vmmword[reg_ptr_diff_gamma_ + reg_off_c_ + off],
1527                     vdiff_gamma);
1528         }
1529     }
1530 
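    // Blocked-layout path: one C block at a time, running the full S loop
    // per block before folding in sqrtvar and storing.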
1531     void compute_blocked() {
1532         Label label_C, label_S;
1533         mov(reg_C_, dword[PARAM_ADDR(C)]);
1534         L(label_C);
1535         {
1536             mov(reg_off_dat_, reg_off_dat_save_);
1537 
1538             load_mean();
1539             zeroise_diff_beta_and_diff_gamma();
1540 
1541             mov(reg_S_, dword[PARAM_ADDR(S)]);
1542             L(label_S);
1543             {
1544                 compute_diff_beta_and_diff_gamma();
1545 
1546                 add(reg_off_dat_, stride_S_ * data_type_size_);
1547 
1548                 dec(reg_S_);
1549                 jnz(label_S);
1550             }
1551 
1552             load_and_prepare_sqrtvar();
1553             store_diff_beta_and_diff_gamma();
1554 
1555             add(reg_off_dat_save_, stride_C_ * data_type_size_);
1556             add(reg_off_c_, simd_w * acc_type_size_);
1557 
1558             dec(reg_C_);
1559             jnz(label_C);
1560         }
1561     }
1562 
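    // nspc-layout path: a chain of labels dispatches on the remaining
    // C-block count so each chunk uses the largest unroll factor the vmm
    // budget allows.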
1563     void compute_nspc() {
1564         mov(reg_C_, dword[PARAM_ADDR(C)]);
1565 
1566         constexpr int max_of_unrolled_c_blks
1567                 = number_of_vmms_to_unrolling_variables_
1568                 / number_of_unrolled_variables_;
1569         std::vector<Label> c_unroll_label(max_of_unrolled_c_blks + 1);
1570 
1571         for (int c_blks_to_unroll = max_of_unrolled_c_blks;
1572                 c_blks_to_unroll > 0; --c_blks_to_unroll) {
1573             L(c_unroll_label[c_blks_to_unroll]);
1574             {
1575                 cmp(reg_C_, c_blks_to_unroll);
1576                 jl(c_unroll_label[c_blks_to_unroll - 1], T_NEAR);
1577 
1578                 mov(reg_off_dat_, reg_off_dat_save_);
1579 
1580                 load_mean(c_blks_to_unroll);
1581                 zeroise_diff_beta_and_diff_gamma(c_blks_to_unroll);
1582 
1583                 Label label_S;
1584                 mov(reg_S_, dword[PARAM_ADDR(S)]);
1585                 L(label_S);
1586                 {
1587                     compute_diff_beta_and_diff_gamma(c_blks_to_unroll);
1588 
1589                     add(reg_off_dat_, stride_S_ * data_type_size_);
1590 
1591                     dec(reg_S_);
1592                     jnz(label_S);
1593                 }
1594 
1595                 load_and_prepare_sqrtvar(c_blks_to_unroll);
1596                 store_diff_beta_and_diff_gamma(c_blks_to_unroll);
1597 
1598                 add(reg_off_c_, c_blks_to_unroll * simd_w * acc_type_size_);
1599                 add(reg_off_dat_save_,
1600                         c_blks_to_unroll * stride_C_ * data_type_size_);
1601 
1602                 sub(reg_C_, c_blks_to_unroll);
1603                 jmp(c_unroll_label[c_blks_to_unroll], T_NEAR);
1604             }
1605         }
1606         L(c_unroll_label[0]);
1607     }
1608 
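    // Outer loop over the mini-batch; for sse41 with the blocked layout the
    // body runs twice, the second pass offset by half a vector to process
    // the upper half of each 8c block.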
1609     void compute() {
1610         Label label_N;
1611         mov(reg_N_, dword[PARAM_ADDR(N)]);
1612         L(label_N);
1613         {
1614             xor_(reg_off_dat_save_, reg_off_dat_save_);
1615             xor_(reg_off_c_, reg_off_c_);
1616 
1617             tag_kind_ == jit_memory_tag_kind_t::nspc ? compute_nspc()
1618                                                      : compute_blocked();
1619 
1620             if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) {
1621                 xor_(reg_off_dat_save_, reg_off_dat_save_);
1622                 xor_(reg_off_c_, reg_off_c_);
1623                 add(reg_off_dat_save_, vlen / 2);
1624                 add(reg_off_c_, vlen / 2);
1625 
1626                 compute_blocked();
1627             }
1628 
1629             add(reg_ptr_src_, stride_N_ * data_type_size_);
1630             add(reg_ptr_diff_dst_, stride_N_ * data_type_size_);
1631             add(reg_ptr_ws_, stride_N_ / bits_per_byte);
1632 
1633             dec(reg_N_);
1634             jnz(label_N);
1635         }
1636     }
1637 
1638     jit_bnorm_bwd_diff_ss_t(const batch_normalization_pd_t *bdesc,
1639             const jit_memory_tag_kind_t tag_kind)
1640         : bdesc_(bdesc)
1641         , tag_kind_(tag_kind)
1642         , vlen(get_vlen<isa>(tag_kind))
1643         , simd_w(get_simd_w<isa>(tag_kind))
1644         , jit_tail_(bdesc, this, reg_tmp_, reg_blk_has_tail_, reg_C_,
1645                   vtail_mask_, ktail_mask_)
1646         , jit_relu_(bdesc, this, reg_off_dat_, reg_tmp_, reg_ptr_ws_, vzero_,
1647                   vstore_mask_, kstore_mask_)
1648         , jit_bf16_emu_(bdesc, this, zmm28, zmm29, zmm30, zmm31, reg_tmp_) {
1649         static_assert(isa == sse41 || isa == avx2 || isa == avx512_common,
1650                 "unsupported isa");
1651 
1652         std::tie(stride_N_, stride_S_, stride_C_)
1653                 = get_data_strides<isa>(bdesc_, tag_kind);
1654 
1655         data_type_size_
1656                 = types::data_type_size(bdesc->desc()->data_desc.data_type);
1657         acc_type_size_ = sizeof(acc_data_t);
1658     }
1659 
1660     void generate() override {
1661         preamble();
1662         load_common_params();
1663         jit_relu_.bwd_prepare_relu();
1664         jit_tail_.prepare_tail();
1665         zeroise();
1666         compute();
1667         postamble();
1668     }
1669 };
1670 } // namespace
1671 namespace bnorm_tbb_impl {
1672 
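// Dispatches the JIT kernels above across threads: the driver decides the
// cache blocking, owns the kernels, and implements the forward/backward
// execution flows.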
1673 template <cpu_isa_t isa>
1674 struct driver_t : public c_compatible {
1675 private:
1676     struct bnorm_dims_t {
1677         dim_t N, C, S;
1678         dim_t glob;
1679     };
1680 
1681     DNNL_DISALLOW_COPY_AND_ASSIGN(driver_t);
1682 
1683 public:
1684     driver_t(const batch_normalization_pd_t *bdesc,
1685             const jit_memory_tag_kind_t tag_kind)
1686         : bdesc_(bdesc)
1687         , tag_kind_(tag_kind)
1688         , simd_w(get_simd_w<isa>(tag_kind)) {
1689         nthr_ = dnnl_get_max_threads();
1690         N_ = bdesc_->MB();
1691         S_ = bdesc_->D() * bdesc_->H() * bdesc_->W();
1692         C_ = bdesc_->C();
1693         C_blks_ = get_c_padded(bdesc_) / simd_w;
1694 
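        // Cache-blocking heuristic: estimate the per-C-block working set
        // (src, plus diff_dst on backward, over the full N * S range); for
        // the blocked layout, if the whole problem exceeds the L3-based
        // budget below, the channel dimension is processed in
        // C_blk_step_-sized chunks that fit that budget.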
1695         const size_t l3_size = platform::get_per_core_cache_size(3) * nthr_ / 2;
1696         int num_tensors = bdesc_->is_fwd() ? 1 : 2;
1697         dt_size_ = types::data_type_size(bdesc_->desc()->data_desc.data_type);
1698         const size_t working_set_size
1699                 = dt_size_ * N_ * S_ * simd_w * num_tensors;
1700 
1701         do_blocking_ = tag_kind_ == jit_memory_tag_kind_t::nspc
1702                 ? false
1703                 : working_set_size * C_blks_ >= l3_size / 2 && l3_size > 0;
1704 
1705         if (tag_kind_ == jit_memory_tag_kind_t::nspc) {
1706             C_blk_step_ = C_blks_;
1707         } else {
1708             C_blk_step_ = l3_size / working_set_size;
1709             C_blk_step_ = nstl::max<dim_t>(C_blk_step_, 1);
1710             C_blk_step_ = nstl::min<dim_t>(C_blk_step_, C_blks_);
1711         }
1712     }
1713 
1714     status_t create_kernel() {
1715         if (bdesc_->is_fwd()) {
1716             CHECK(safe_ptr_assign(
1717                     ker_fwd_, new jit_bnorm_fwd_t<isa>(bdesc_, tag_kind_)));
1718             CHECK(ker_fwd_->create_kernel());
1719             if (!bdesc_->stats_is_src()) {
1720                 CHECK(safe_ptr_assign(ker_fwd_mean_,
1721                         new jit_bnorm_fwd_mean_t<isa>(bdesc_, tag_kind_)));
1722                 CHECK(safe_ptr_assign(ker_fwd_var_,
1723                         new jit_bnorm_fwd_var_t<isa>(bdesc_, tag_kind_)));
1724                 CHECK(ker_fwd_mean_->create_kernel());
1725                 CHECK(ker_fwd_var_->create_kernel());
1726             }
1727         } else {
1728             CHECK(safe_ptr_assign(
1729                     ker_bwd_, new jit_bnorm_bwd_t<isa>(bdesc_, tag_kind_)));
1730             CHECK(safe_ptr_assign(ker_bwd_diff_ss_,
1731                     new jit_bnorm_bwd_diff_ss_t<isa>(bdesc_, tag_kind_)));
1732             CHECK(ker_bwd_->create_kernel());
1733             CHECK(ker_bwd_diff_ss_->create_kernel());
1734         }
1735         return status::success;
1736     }
1737 
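    // Scratchpad layout: sbuf holds temporary mean/var for inference when
    // stats are not user-provided, pbuf holds temporary diff_scale/diff_shift
    // when those are not exposed as outputs, and rbuf holds per-thread
    // partials for the cross-thread reductions.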
1738     static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
1739             const batch_normalization_pd_t *bdesc) {
1740 
1741         int nthrs = dnnl_get_max_threads();
1742         int C_PADDED = get_c_padded(bdesc);
1743 
1744         int sbuf_sz = use_tmp_stats(bdesc) * 2 * C_PADDED;
1745         int pbuf_sz = (use_tmp_diff_scale(bdesc) + use_tmp_diff_shift(bdesc))
1746                 * C_PADDED;
1747         int rbuf_sz = (bdesc->is_fwd() ? 1 : 2) * C_PADDED * nthrs;
1748 
1749         scratchpad.book<acc_data_t>(key_bnorm_tmp_stats, sbuf_sz);
1750         scratchpad.book<acc_data_t>(key_bnorm_tmp_diff_ss, pbuf_sz);
1751         scratchpad.book<acc_data_t>(key_bnorm_reduction, rbuf_sz);
1752     }
1753 
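    // Forward statistics: each thread computes partial mean/var over its
    // (N, C, S) sub-range; when N * S is split across threads the partials
    // land in rbuf and reduce() sums them and divides by N_ * S_, otherwise
    // the kernel normalizes in place (do_normalise).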
1754     void exec_fwd_step_stats(const dim_t C_blks, const bnorm_dims_t &nthr,
1755             const void *src, acc_data_t *mean, acc_data_t *var,
1756             acc_data_t *rbuf, bool blk_has_tail) {
1757         size_t stride_C, stride_N, stride_S;
1758         std::tie(stride_N, stride_S, stride_C)
1759                 = get_data_strides<isa>(bdesc_, tag_kind_);
1760 
1761         const int nthr_NS = nthr.N * nthr.S;
1762         const bool need_reduction = nthr_NS > 1;
1763         const dim_t tail_size = blk_has_tail ? C_ % simd_w : simd_w;
1764 
1765         const dim_t size_C_stat = (C_blks - 1) * simd_w + tail_size;
1766 
1767         auto reduce = [&](acc_data_t *stat, acc_data_t *r_stat) {
1768             if (!need_reduction) return;
1769             acc_data_t *loc_stat = r_stat;
1770 
1771             for (dim_t c = 0; c < size_C_stat; ++c)
1772                 stat[c] = loc_stat[c];
1773 
1774             for (int thr_ns = 1; thr_ns < nthr_NS; ++thr_ns) {
1775                 loc_stat += size_C_stat;
1776                 for (dim_t c = 0; c < size_C_stat; ++c)
1777                     stat[c] += loc_stat[c];
1778             }
1779 
1780             for (dim_t c = 0; c < size_C_stat; ++c)
1781                 stat[c] /= N_ * S_;
1782         };
1783 
1784         // find local mean
1785         acc_data_t *r_mean = need_reduction ? rbuf : mean;
1786         parallel(nthr.glob, [&](int ithr_glob, int nthr_glob) {
1787             assert(nthr_glob == nthr.glob);
1788             const auto ithr = map_thread(ithr_glob, nthr);
1789             bnorm_dims_t start, stop;
1790             work_distribution(C_blks, ithr, nthr, start, stop);
1791 
1792             auto c = typename jit_bnorm_fwd_mean_t<isa>::call_params_t();
1793             c.N = stop.N - start.N;
1794             c.C = stop.C - start.C;
1795             c.S = stop.S - start.S;
1796 
1797             const size_t d_off = start.N * stride_N + start.C * stride_C
1798                     + start.S * stride_S;
1799             c.src = (void *)((char *)src + d_off * dt_size_);
1800             const int ithr_NS = ithr.N * nthr.S + ithr.S;
1801             c.mean = &r_mean[ithr_NS * size_C_stat + start.C * simd_w];
1802             c.blk_has_tail = blk_has_tail && stop.C == C_blks;
1803             c.do_normalise = !need_reduction;
1804             (*ker_fwd_mean_)(&c);
1805         });
1806 
1807         // mean reduction
1808         reduce(mean, r_mean);
1809 
1810         // find local var
1811         acc_data_t *r_var = need_reduction ? rbuf : var;
1812         parallel(nthr.glob, [&](int ithr_glob, int nthr_glob) {
1813             assert(nthr_glob == nthr.glob);
1814             const auto ithr = map_thread(ithr_glob, nthr);
1815             bnorm_dims_t start, stop;
1816             work_distribution(C_blks, ithr, nthr, start, stop);
1817 
1818             auto c = typename jit_bnorm_fwd_var_t<isa>::call_params_t();
1819             c.N = stop.N - start.N;
1820             c.C = stop.C - start.C;
1821             c.S = stop.S - start.S;
1822 
1823             const size_t d_off = start.N * stride_N + start.C * stride_C
1824                     + start.S * stride_S;
1825             c.src = (void *)((char *)src + d_off * dt_size_);
1826             const int ithr_NS = ithr.N * nthr.S + ithr.S;
1827             c.mean = &mean[start.C * simd_w];
1828             c.var = &r_var[ithr_NS * size_C_stat + start.C * simd_w];
1829             c.blk_has_tail = blk_has_tail && stop.C == C_blks;
1830             c.do_normalise = !need_reduction;
1831             (*ker_fwd_var_)(&c);
1832         });
1833 
1834         // var reduction
1835         reduce(var, r_var);
1836     }
1837 
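    // Forward normalization: applies mean/var (and optional scale/shift) to
    // every element of the current chunk of C blocks, emitting ReLU
    // workspace bits when a workspace is provided.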
1838     void exec_fwd_step_normalization(const dim_t C_blks,
1839             const bnorm_dims_t &nthr, const void *src, void *dst,
1840             const acc_data_t *scale, const acc_data_t *shift,
1841             const acc_data_t *mean, const acc_data_t *var, uint8_t *ws,
1842             bool blk_has_tail) {
1843         size_t stride_C, stride_N, stride_S;
1844         std::tie(stride_N, stride_S, stride_C)
1845                 = get_data_strides<isa>(bdesc_, tag_kind_);
1846 
1847         parallel(nthr.glob, [&](int ithr_glob, int nthr_glob) {
1848             assert(nthr_glob == nthr.glob);
1849             const auto ithr = map_thread(ithr_glob, nthr);
1850             bnorm_dims_t start, stop;
1851             work_distribution(C_blks, ithr, nthr, start, stop);
1852 
1853             auto c = typename jit_bnorm_fwd_t<isa>::call_params_t();
1854             c.N = stop.N - start.N;
1855             c.C = stop.C - start.C;
1856             c.S = stop.S - start.S;
1857 
1858             const size_t d_off = start.N * stride_N + start.C * stride_C
1859                     + start.S * stride_S;
1860             c.src = (void *)((char *)src + d_off * dt_size_);
1861             c.dst = (void *)((char *)dst + d_off * dt_size_);
1862             c.ws = ws ? &ws[d_off / bits_per_byte] : nullptr;
1863             c.mean = &mean[start.C * simd_w];
1864             c.var = &var[start.C * simd_w];
1865             c.scale = scale ? &scale[start.C * simd_w] : nullptr;
1866             c.shift = shift ? &shift[start.C * simd_w] : nullptr;
1867             c.blk_has_tail = blk_has_tail && stop.C == C_blks;
1868             (*ker_fwd_)(&c);
1869         });
1870     }
1871 
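    // Forward driver: walks the channel dimension in C_blk_step_ chunks,
    // computing statistics first unless they come from the user, then
    // normalizing; the last chunk may be narrower, in which case the thread
    // distribution is recomputed.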
1872     void exec_fwd(const void *src, void *dst, const acc_data_t *scale,
1873             const acc_data_t *shift, acc_data_t *mean, acc_data_t *var,
1874             uint8_t *ws, const memory_tracking::grantor_t &scratchpad) {
1875         auto rbuf = scratchpad.get<acc_data_t>(key_bnorm_reduction);
1876         if (use_tmp_stats(bdesc_)) {
1877             auto sbuf = scratchpad.get<acc_data_t>(key_bnorm_tmp_stats);
1878             mean = sbuf;
1879             var = sbuf + C_blks_ * simd_w;
1880         }
1881 
1882         size_t stride_C;
1883         std::tie(std::ignore, std::ignore, stride_C)
1884                 = get_data_strides<isa>(bdesc_, tag_kind_);
1885 
1886         dim_t C_blk_step = C_blk_step_;
1887         auto nthr = bnorm_dims_t();
1888 
1889         thread_distribution(C_blk_step, nthr);
1890 
1891         for (dim_t C_blk_st = 0; C_blk_st < C_blks_; C_blk_st += C_blk_step) {
1892             if (C_blk_st + C_blk_step > C_blks_) {
1893                 C_blk_step = C_blks_ - C_blk_st;
1894                 thread_distribution(C_blk_step, nthr);
1895             }
1896 
1897             if (!bdesc_->stats_is_src()) {
1898                 exec_fwd_step_stats(C_blk_step, nthr,
1899                         (void *)((char *)src
1900                                 + (C_blk_st * stride_C) * dt_size_),
1901                         mean + C_blk_st * simd_w, var + C_blk_st * simd_w, rbuf,
1902                         (C_blk_st + C_blk_step) * simd_w > C_);
1903             }
1904 
1905             exec_fwd_step_normalization(C_blk_step, nthr,
1906                     (void *)((char *)src + (C_blk_st * stride_C) * dt_size_),
1907                     (void *)((char *)dst + (C_blk_st * stride_C) * dt_size_),
1908                     scale + C_blk_st * simd_w, shift + C_blk_st * simd_w,
1909                     mean + C_blk_st * simd_w, var + C_blk_st * simd_w,
1910                     ws + C_blk_st * stride_C / bits_per_byte,
1911                     (C_blk_st + C_blk_step) * simd_w > C_);
1912         }
1913     }
1914 
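    // Backward pass 1: per-thread partial diff_gamma/diff_beta are written
    // either directly to the output buffers or, when N * S is split across
    // threads, into rbuf (all diff_gamma partials first, then diff_beta),
    // and summed afterwards by reduce().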
1915     void exec_bwd_step_diff_ss(const dim_t C_blks, const bnorm_dims_t &nthr,
1916             const void *src, const void *diff_dst, const acc_data_t *mean,
1917             const acc_data_t *var, const uint8_t *ws, acc_data_t *diff_scale,
1918             acc_data_t *diff_shift, acc_data_t *rbuf, bool blk_has_tail) {
1919         size_t stride_C, stride_N, stride_S;
1920         std::tie(stride_N, stride_S, stride_C)
1921                 = get_data_strides<isa>(bdesc_, tag_kind_);
1922 
1923         const dim_t tail_size = blk_has_tail ? C_ % simd_w : simd_w;
1924         const dim_t size_C_stat = (C_blks - 1) * simd_w + tail_size;
1925 
1926         const int nthr_NS = nthr.N * nthr.S;
1927         const bool need_reduction = nthr_NS > 1;
1928 
1929         acc_data_t *diff_gamma = diff_scale;
1930         acc_data_t *diff_beta = diff_shift;
1931 
1932         acc_data_t *const r_diff_gamma = need_reduction ? rbuf : diff_gamma;
1933         acc_data_t *const r_diff_beta
1934                 = need_reduction ? rbuf + nthr_NS * size_C_stat : diff_beta;
1935 
1936         auto reduce = [&]() {
1937             if (!need_reduction) return;
1938 
1939             // diff_gamma
1940             const acc_data_t *loc_diff_gamma = r_diff_gamma;
1941             for (dim_t c = 0; c < size_C_stat; ++c)
1942                 diff_gamma[c] = loc_diff_gamma[c];
1943             for (int thr_ns = 1; thr_ns < nthr_NS; ++thr_ns) {
1944                 loc_diff_gamma += size_C_stat;
1945                 for (dim_t c = 0; c < size_C_stat; ++c)
1946                     diff_gamma[c] += loc_diff_gamma[c];
1947             }
1948 
1949             // diff_beta
1950             const acc_data_t *loc_diff_beta = r_diff_beta;
1951             for (dim_t c = 0; c < size_C_stat; ++c)
1952                 diff_beta[c] = loc_diff_beta[c];
1953             for (int thr_ns = 1; thr_ns < nthr_NS; ++thr_ns) {
1954                 loc_diff_beta += size_C_stat;
1955                 for (dim_t c = 0; c < size_C_stat; ++c)
1956                     diff_beta[c] += loc_diff_beta[c];
1957             }
1958         };
1959 
1960         parallel(nthr.glob, [&](int ithr_glob, int nthr_glob) {
1961             assert(nthr_glob == nthr.glob);
1962             const auto ithr = map_thread(ithr_glob, nthr);
1963             bnorm_dims_t start, stop;
1964             work_distribution(C_blks, ithr, nthr, start, stop);
1965 
1966             const int ithr_NS = ithr.N * nthr.S + ithr.S;
1967             acc_data_t *loc_diff_gamma = &r_diff_gamma[ithr_NS * size_C_stat];
1968             acc_data_t *loc_diff_beta = &r_diff_beta[ithr_NS * size_C_stat];
1969 
1970             auto c = typename jit_bnorm_bwd_diff_ss_t<isa>::call_params_t();
1971             c.N = stop.N - start.N;
1972             c.C = stop.C - start.C;
1973             c.S = stop.S - start.S;
1974 
1975             const size_t d_off = start.N * stride_N + start.C * stride_C
1976                     + start.S * stride_S;
1977             c.src = (void *)((char *)src + d_off * dt_size_);
1978             c.diff_dst = (void *)((char *)diff_dst + d_off * dt_size_);
1979             c.ws = ws ? &ws[d_off / bits_per_byte] : nullptr;
1980             c.mean = &mean[start.C * simd_w];
1981             c.var = &var[start.C * simd_w];
1982             c.diff_gamma = &loc_diff_gamma[start.C * simd_w];
1983             c.diff_beta = &loc_diff_beta[start.C * simd_w];
1984             c.blk_has_tail = blk_has_tail && stop.C == C_blks;
1985 
1986             (*ker_bwd_diff_ss_)(&c);
1987         });
1988 
1989         reduce();
1990     }
1991 
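    // Backward pass 2: consumes the reduced diff_gamma/diff_beta together
    // with mean/var to compute diff_src for the current chunk of C blocks.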
1992     void exec_bwd_step_normalization(const dim_t C_blks,
1993             const bnorm_dims_t &nthr, const void *src, void *diff_src,
1994             const void *diff_dst, const acc_data_t *mean, const acc_data_t *var,
1995             const uint8_t *ws, const acc_data_t *scale,
1996             const acc_data_t *diff_scale, const acc_data_t *diff_shift,
1997             bool blk_has_tail) {
1998         size_t stride_C, stride_N, stride_S;
1999         std::tie(stride_N, stride_S, stride_C)
2000                 = get_data_strides<isa>(bdesc_, tag_kind_);
2001 
2002         parallel(nthr.glob, [&](int ithr_glob, int nthr_glob) {
2003             assert(nthr_glob == nthr.glob);
2004             const auto ithr = map_thread(ithr_glob, nthr);
2005             bnorm_dims_t start, stop;
2006             work_distribution(C_blks, ithr, nthr, start, stop);
2007 
2008             auto c = typename jit_bnorm_bwd_t<isa>::call_params_t();
2009             c.N = stop.N - start.N;
2010             c.C = stop.C - start.C;
2011             c.S = stop.S - start.S;
2012 
2013             const size_t d_off = start.N * stride_N + start.C * stride_C
2014                     + start.S * stride_S;
2015             c.src = (void *)((char *)src + d_off * dt_size_);
2016             c.diff_src = (void *)((char *)diff_src + d_off * dt_size_);
2017             c.diff_dst = (void *)((char *)diff_dst + d_off * dt_size_);
2018             c.ws = ws ? &ws[d_off / bits_per_byte] : nullptr;
2019             c.mean = &mean[start.C * simd_w];
2020             c.var = &var[start.C * simd_w];
2021             c.scale = scale ? &scale[start.C * simd_w] : nullptr;
2022             c.diff_scale = &diff_scale[start.C * simd_w];
2023             c.diff_shift = &diff_shift[start.C * simd_w];
2024             c.blk_has_tail = blk_has_tail && stop.C == C_blks;
2025 
2026             (*ker_bwd_)(&c);
2027         });
2028     }
2029 
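    // Backward driver: mirrors exec_fwd(), redirecting diff_scale/diff_shift
    // into the scratchpad when the primitive does not expose them as
    // outputs.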
2030     void exec_bwd(const void *src, void *diff_src, const void *diff_dst,
2031             const acc_data_t *scale, acc_data_t *diff_scale,
2032             acc_data_t *diff_shift, const acc_data_t *mean,
2033             const acc_data_t *var, const uint8_t *ws,
2034             const memory_tracking::grantor_t &scratchpad) {
2035         auto rbuf = scratchpad.get<acc_data_t>(key_bnorm_reduction);
2036         if (use_tmp_diff_scale(bdesc_)) {
2037             auto pbuf = scratchpad.get<acc_data_t>(key_bnorm_tmp_diff_ss);
2038             diff_scale = pbuf;
2039         }
2040         if (use_tmp_diff_shift(bdesc_)) {
2041             auto pbuf = scratchpad.get<acc_data_t>(key_bnorm_tmp_diff_ss);
2042             size_t shift_off = use_tmp_diff_scale(bdesc_) ? bdesc_->C() : 0;
2043             diff_shift = &pbuf[shift_off];
2044         }
2045 
2046         size_t stride_C;
2047         std::tie(std::ignore, std::ignore, stride_C)
2048                 = get_data_strides<isa>(bdesc_, tag_kind_);
2049 
2050         dim_t C_blk_step = C_blk_step_;
2051         auto nthr = bnorm_dims_t();
2052 
2053         thread_distribution(C_blk_step, nthr);
2054 
2055         for (dim_t C_blk_st = 0; C_blk_st < C_blks_; C_blk_st += C_blk_step) {
2056             if (C_blk_st + C_blk_step > C_blks_) {
2057                 C_blk_step = C_blks_ - C_blk_st;
2058                 thread_distribution(C_blk_step, nthr);
2059             }
2060 
2061             exec_bwd_step_diff_ss(C_blk_step, nthr,
2062                     (void *)((char *)src + (C_blk_st * stride_C) * dt_size_),
2063                     (void *)((char *)diff_dst
2064                             + (C_blk_st * stride_C) * dt_size_),
2065                     mean + C_blk_st * simd_w, var + C_blk_st * simd_w,
2066                     ws + C_blk_st * stride_C / bits_per_byte,
2067                     diff_scale + C_blk_st * simd_w,
2068                     diff_shift + C_blk_st * simd_w, rbuf,
2069                     (C_blk_st + C_blk_step) * simd_w > C_);
2070 
2071             exec_bwd_step_normalization(C_blk_step, nthr,
2072                     (void *)((char *)src + (C_blk_st * stride_C) * dt_size_),
2073                     (void *)((char *)diff_src
2074                             + (C_blk_st * stride_C) * dt_size_),
2075                     (void *)((char *)diff_dst
2076                             + (C_blk_st * stride_C) * dt_size_),
2077                     mean + C_blk_st * simd_w, var + C_blk_st * simd_w,
2078                     ws + C_blk_st * stride_C / bits_per_byte,
2079                     scale + C_blk_st * simd_w, diff_scale + C_blk_st * simd_w,
2080                     diff_shift + C_blk_st * simd_w,
2081                     (C_blk_st + C_blk_step) * simd_w > C_);
2082         }
2083     }
2084 
2085 private:
2086     static bool use_tmp_stats(const batch_normalization_pd_t *bdesc) {
2087         return true && !bdesc->stats_is_src()
2088                 && bdesc->desc()->prop_kind == prop_kind::forward_inference;
2089     }
2090 
2091     static bool use_tmp_diff_scale(const batch_normalization_pd_t *bdesc) {
2092         return false
2093                 || (bdesc->is_bwd() && !bdesc->use_scaleshift()
2094                         && !bdesc->use_scale())
2095                 || bdesc->desc()->prop_kind == prop_kind::backward_data;
2096     }
2097 
2098     static bool use_tmp_diff_shift(const batch_normalization_pd_t *bdesc) {
2099         return false
2100                 || (bdesc->is_bwd() && !bdesc->use_scaleshift()
2101                         && !bdesc->use_shift())
2102                 || bdesc->desc()->prop_kind == prop_kind::backward_data;
2103     }
2104 
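    // Splits nthr_ threads across (N, C, S). With cache blocking, N is
    // parallelized first; otherwise C is split so each thread gets an equal
    // number of channel blocks (via gcd), leaving leftover parallelism to
    // the N and then S dimensions.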
2105     void thread_distribution(dim_t C_blks, bnorm_dims_t &nthr) {
2106         if (do_blocking_) {
2107             nthr.N = nstl::min<dim_t>(N_, nthr_);
2108             nthr.C = nstl::min<dim_t>(C_blks, nthr_ / nthr.N);
2109         } else {
2110             if (tag_kind_ == jit_memory_tag_kind_t::nspc) {
2111                 if ((nthr_ <= C_blks && nthr_ == 1) || C_blks <= 8)
2112                     nthr.C = 1;
2113                 else if (nthr_ >= 8 && C_blks <= 32)
2114                     nthr.C = 8;
2115                 else {
2116                     nthr.C = math::gcd((dim_t)nthr_, C_blks);
2117                     // Unroll by channels in JIT kernel
2118                     if ((nthr.C == C_blks) || (nthr.C == nthr_)) nthr.C = 1;
2119                 }
2120             } else {
2121                 nthr.C = math::gcd((dim_t)nthr_, C_blks);
2122             }
2123             nthr.N = utils::saturate((dim_t)1, N_, nthr_ / nthr.C);
2124         }
2125         nthr.S = utils::saturate((dim_t)1, S_, nthr_ / (nthr.C * nthr.N));
2126         nthr.glob = nthr.N * nthr.C * nthr.S;
2127     }
2128 
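    // The global thread index is decomposed as
    //     ithr_glob = (ithr.C * nthr.N + ithr.N) * nthr.S + ithr.S,
    // i.e. C is the outermost and S the innermost dimension.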
2129     int map_thread_c(int ithr_glob, const bnorm_dims_t &nthr) {
2130         return ithr_glob / nthr.N / nthr.S;
2131     }
2132 
2133     bnorm_dims_t map_thread(int ithr_glob, const bnorm_dims_t &nthr) {
2134         auto ithr = bnorm_dims_t();
2135         ithr.glob = ithr_glob;
2136         ithr.C = map_thread_c(ithr.glob, nthr);
2137         ithr.N = ithr.glob / nthr.S % nthr.N;
2138         ithr.S = ithr.glob % nthr.S;
2139         return ithr;
2140     }
2141 
2142     void work_distribution_c(dim_t C_blks, int ithr_c, int nthr_c,
2143             dim_t &start_c, dim_t &stop_c) {
2144         balance211(C_blks, nthr_c, ithr_c, start_c, stop_c);
2145     }
2146 
2147     void work_distribution(dim_t C_blks, const bnorm_dims_t &ithr,
2148             const bnorm_dims_t &nthr, bnorm_dims_t &start, bnorm_dims_t &stop) {
2149         work_distribution_c(C_blks, ithr.C, nthr.C, start.C, stop.C);
2150         balance211(N_, nthr.N, ithr.N, start.N, stop.N);
2151         balance211(S_, nthr.S, ithr.S, start.S, stop.S);
2152     }
2153 
2154     const batch_normalization_pd_t *bdesc_;
2155     const jit_memory_tag_kind_t tag_kind_;
2156     const int simd_w;
2157 
2158     bool do_blocking_;
2159 
2160     int nthr_;
2161 
2162     dim_t N_, S_; // MB, D * H * W
2163     dim_t C_, C_blks_; // C, C / simd_w
2164     dim_t C_blk_step_; // for C_blk_st = 0 .. C_blks_, += C_blk_step_
2165 
2166     std::unique_ptr<jit_bnorm_fwd_t<isa>> ker_fwd_;
2167     std::unique_ptr<jit_bnorm_fwd_mean_t<isa>> ker_fwd_mean_;
2168     std::unique_ptr<jit_bnorm_fwd_var_t<isa>> ker_fwd_var_;
2169     std::unique_ptr<jit_bnorm_bwd_t<isa>> ker_bwd_;
2170     std::unique_ptr<jit_bnorm_bwd_diff_ss_t<isa>> ker_bwd_diff_ss_;
2171 
2172     size_t dt_size_;
2173 };
2174 } // namespace bnorm_tbb_impl
2175 
2176 using namespace data_type;
2177 using namespace format_tag;
2178 using namespace utils;
2179 
2180 /* fwd */
2181 template <cpu_isa_t isa>
2182 status_t jit_uni_tbb_batch_normalization_fwd_t<isa>::pd_t::init(
2183         engine_t *engine) {
2184 
2185     const bool ok = mayiuse(isa) && is_fwd() && !has_zero_dim_memory()
2186             && one_of(ndims(), 4, 5) && one_of(src_md()->data_type, f32, bf16)
2187             && IMPLICATION(src_md()->data_type == bf16,
2188                     is_superset(isa, avx512_common) && mayiuse(avx512_core))
2189             && check_scale_shift_data_type()
2190             && (attr()->has_default_values() || this->with_relu_post_op());
2191     if (!ok) return status::unimplemented;
2192 
2193     const format_tag_t blocked_tag = is_superset(isa, avx512_common)
2194             ? utils::pick(ndims() - 4, nChw16c, nCdhw16c)
2195             : utils::pick(ndims() - 4, nChw8c, nCdhw8c);
2196 
2197     const format_tag_t blocked_format
2198             = memory_desc_matches_tag(*src_md(), blocked_tag)
2199             ? blocked_tag
2200             : format_tag::undef;
2201     const format_tag_t nspc_format
2202             = memory_desc_matches_one_of_tag(*src_md(), nhwc, ndhwc);
2203 
2204     if (memory_desc_matches_tag(*dst_md(), blocked_format))
2205         tag_kind_ = jit_memory_tag_kind_t::blocked;
2206     else if (memory_desc_matches_tag(*dst_md(), nspc_format)) {
2207         tag_kind_ = jit_memory_tag_kind_t::nspc;
2208         const int simd_w = get_simd_w<isa>(tag_kind_);
2209         if (C() % simd_w != 0) return status::unimplemented;
2210     } else
2211         return status::unimplemented;
2212 
2213     const bool isa_supports_avx2 = is_superset(isa, avx2);
2214     if (is_training() && fuse_norm_relu()) {
2215         if (!isa_supports_avx2) return status::unimplemented;
2216         init_default_ws(1);
2217     }
2218 
2219     if (memory_desc_wrapper(src_md()).padded_dims()[1] != C()
2220             && !isa_supports_avx2)
2221         return status::unimplemented;
2222 
2223     auto scratchpad = scratchpad_registry().registrar();
2224     bnorm_tbb_impl::driver_t<isa>::init_scratchpad(scratchpad, this);
2225 
2226     return status::success;
2227 }
2228 
2229 template <cpu_isa_t isa>
2230 jit_uni_tbb_batch_normalization_fwd_t<
2231         isa>::jit_uni_tbb_batch_normalization_fwd_t(const pd_t *apd)
2232     : primitive_t(apd) {}
2233 
2234 template <cpu_isa_t isa>
2235 status_t jit_uni_tbb_batch_normalization_fwd_t<isa>::init(engine_t *engine) {
2236     CHECK(safe_ptr_assign(bnorm_driver_,
2237             new bnorm_tbb_impl::driver_t<isa>(pd(), pd()->tag_kind_)));
2238     return bnorm_driver_->create_kernel();
2239 }
2240 
2241 template <cpu_isa_t isa>
2242 status_t jit_uni_tbb_batch_normalization_fwd_t<isa>::execute(
2243         const exec_ctx_t &ctx) const {
2244 
2245     const memory_desc_wrapper ss_d(pd()->weights_md());
2246 
2247     const auto use_ss = pd()->use_scaleshift();
2248     const auto use_sc = pd()->use_scale();
2249     const auto use_sh = pd()->use_shift();
2250 
2251     const size_t shift_off
2252             = use_ss && !ss_d.has_zero_dim() ? ss_d.off(1, 0) : 0;
2253 
2254     auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
2255     auto scale = CTX_IN_MEM(
2256             const acc_data_t *, use_sc ? DNNL_ARG_SCALE : DNNL_ARG_SCALE_SHIFT);
2257     auto shift = use_sh ? CTX_IN_MEM(const acc_data_t *, DNNL_ARG_SHIFT)
2258                         : use_ss ? &CTX_IN_MEM(const acc_data_t *,
2259                                   DNNL_ARG_SCALE_SHIFT)[shift_off]
2260                                  : nullptr;
2261 
2262     auto mean = pd()->stats_is_src() ? const_cast<acc_data_t *>(
2263                         CTX_IN_MEM(const acc_data_t *, DNNL_ARG_MEAN))
2264                                      : CTX_OUT_MEM(acc_data_t *, DNNL_ARG_MEAN);
2265     auto var = pd()->stats_is_src()
2266             ? const_cast<acc_data_t *>(
2267                     CTX_IN_MEM(const acc_data_t *, DNNL_ARG_VARIANCE))
2268             : CTX_OUT_MEM(acc_data_t *, DNNL_ARG_VARIANCE);
2269 
2270     auto dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);
2271     auto ws = CTX_OUT_MEM(uint8_t *, DNNL_ARG_WORKSPACE);
2272 
2273     auto scratchpad = ctx.get_scratchpad_grantor();
2274 
2275     bnorm_driver_->exec_fwd(src, dst, scale, shift, mean, var, ws, scratchpad);
2276 
2277     return status::success;
2278 }
2279 
2280 template <cpu_isa_t isa>
2281 jit_uni_tbb_batch_normalization_fwd_t<
2282         isa>::~jit_uni_tbb_batch_normalization_fwd_t()
2283         = default;
2284 
2285 template struct jit_uni_tbb_batch_normalization_fwd_t<sse41>;
2286 template struct jit_uni_tbb_batch_normalization_fwd_t<avx2>;
2287 template struct jit_uni_tbb_batch_normalization_fwd_t<avx512_common>;
2288 
2289 /* bwd */
2290 template <cpu_isa_t isa>
2291 status_t jit_uni_tbb_batch_normalization_bwd_t<isa>::pd_t::init(
2292         engine_t *engine) {
2293 
2294     const bool ok = mayiuse(isa) && is_bwd() && !has_zero_dim_memory()
2295             && one_of(ndims(), 4, 5) && set_default_formats_common()
2296             && one_of(true,
2297                     everyone_is(
2298                             f32, src_md()->data_type, diff_src_md()->data_type),
2299                     everyone_is(bf16, src_md()->data_type,
2300                             diff_src_md()->data_type))
2301             && IMPLICATION(src_md()->data_type == bf16,
2302                     is_superset(isa, avx512_common) && mayiuse(avx512_core))
2303             && check_scale_shift_data_type() && attr()->has_default_values();
2304     if (!ok) return status::unimplemented;
2305 
2306     const format_tag_t blocked_tag = is_superset(isa, avx512_common)
2307             ? utils::pick(ndims() - 4, nChw16c, nCdhw16c)
2308             : utils::pick(ndims() - 4, nChw8c, nCdhw8c);
2309 
2310     const format_tag_t blocked_format
2311             = memory_desc_matches_tag(*src_md(), blocked_tag)
2312             ? blocked_tag
2313             : format_tag::undef;
2314     const format_tag_t nspc_format
2315             = memory_desc_matches_one_of_tag(*src_md(), nhwc, ndhwc);
2316 
2317     if (memory_desc_matches_tag(*diff_src_md(), blocked_format))
2318         tag_kind_ = jit_memory_tag_kind_t::blocked;
2319     else if (memory_desc_matches_tag(*diff_src_md(), nspc_format)) {
2320         tag_kind_ = jit_memory_tag_kind_t::nspc;
2321         const int simd_w = get_simd_w<isa>(tag_kind_);
2322         if (C() % simd_w != 0) return status::unimplemented;
2323     } else
2324         return status::unimplemented;
2325 
2326     const bool isa_supports_avx2 = is_superset(isa, avx2);
2327     if (memory_desc_wrapper(src_md()).padded_dims()[1] != C()
2328             && !isa_supports_avx2)
2329         return status::unimplemented;
2330 
2331     if (fuse_norm_relu()) {
2332         if (!isa_supports_avx2) return status::unimplemented;
2333         init_default_ws(1);
2334         if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;
2335     }
2336 
2337     auto scratchpad = scratchpad_registry().registrar();
2338     bnorm_tbb_impl::driver_t<isa>::init_scratchpad(scratchpad, this);
2339 
2340     return status::success;
2341 }
2342 
2343 template <cpu_isa_t isa>
2344 jit_uni_tbb_batch_normalization_bwd_t<
2345         isa>::jit_uni_tbb_batch_normalization_bwd_t(const pd_t *apd)
2346     : primitive_t(apd) {}
2347 
2348 template <cpu_isa_t isa>
2349 status_t jit_uni_tbb_batch_normalization_bwd_t<isa>::init(engine_t *engine) {
2350     CHECK(safe_ptr_assign(bnorm_driver_,
2351             new bnorm_tbb_impl::driver_t<isa>(pd(), pd()->tag_kind_)));
2352     return bnorm_driver_->create_kernel();
2353 }
2354 
2355 template <cpu_isa_t isa>
2356 status_t jit_uni_tbb_batch_normalization_bwd_t<isa>::execute(
2357         const exec_ctx_t &ctx) const {
2358 
2359     const memory_desc_wrapper diff_ss_d(pd()->diff_weights_md());
2360 
2361     const auto use_ss = pd()->use_scaleshift();
2362     const auto use_sc = pd()->use_scale();
2363     const auto use_sh = pd()->use_shift();
2364 
2365     const size_t diff_shift_off
2366             = use_ss && !diff_ss_d.has_zero_dim() ? diff_ss_d.off(1, 0) : 0;
2367 
2368     auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
2369     auto mean = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_MEAN);
2370     auto var = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_VARIANCE);
2371     auto diff_dst = CTX_IN_MEM(const void *, DNNL_ARG_DIFF_DST);
2372     auto scale = CTX_IN_MEM(
2373             const acc_data_t *, use_sc ? DNNL_ARG_SCALE : DNNL_ARG_SCALE_SHIFT);
2374     auto ws = CTX_IN_MEM(const uint8_t *, DNNL_ARG_WORKSPACE);
2375 
2376     auto diff_src = CTX_OUT_MEM(void *, DNNL_ARG_DIFF_SRC);
2377     auto diff_scale = CTX_OUT_MEM(acc_data_t *,
2378             use_sc ? DNNL_ARG_DIFF_SCALE : DNNL_ARG_DIFF_SCALE_SHIFT);
2379     auto diff_shift = use_sh ? CTX_OUT_MEM(acc_data_t *, DNNL_ARG_DIFF_SHIFT)
2380                              : use_ss ? &diff_scale[diff_shift_off] : nullptr;
2381 
2382     auto scratchpad = ctx.get_scratchpad_grantor();
2383 
2384     bnorm_driver_->exec_bwd(src, diff_src, diff_dst, scale, diff_scale,
2385             diff_shift, mean, var, ws, scratchpad);
2386 
2387     return status::success;
2388 }
2389 
2390 template <cpu_isa_t isa>
2391 jit_uni_tbb_batch_normalization_bwd_t<
2392         isa>::~jit_uni_tbb_batch_normalization_bwd_t()
2393         = default;
2394 
2395 template struct jit_uni_tbb_batch_normalization_bwd_t<sse41>;
2396 template struct jit_uni_tbb_batch_normalization_bwd_t<avx2>;
2397 template struct jit_uni_tbb_batch_normalization_bwd_t<avx512_common>;
2398 } // namespace x64
2399 } // namespace cpu
2400 } // namespace impl
2401 } // namespace dnnl
2402