1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <cassert>
18 #include <cmath>
19 #include <memory>
20
21 #include "common/c_types_map.hpp"
22 #include "common/dnnl_thread.hpp"
23 #include "common/math_utils.hpp"
24 #include "common/memory_tracking.hpp"
25 #include "common/nstl.hpp"
26 #include "common/type_helpers.hpp"
27 #include "common/utils.hpp"
28
29 #include "cpu/cpu_batch_normalization_utils.hpp"
30 #include "cpu/platform.hpp"
31 #include "cpu/x64/jit_generator.hpp"
32
33 #include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
34 #include "cpu/x64/jit_uni_tbb_batch_normalization.hpp"
35
36 namespace dnnl {
37 namespace impl {
38 namespace cpu {
39 namespace x64 {
40
41 namespace {
42
43 using namespace memory_tracking::names;
44 using namespace Xbyak;
45 using acc_data_t = float;
46
47 constexpr int bits_per_byte = 8;
48
get_c_padded(const batch_normalization_pd_t * bdesc)49 dim_t get_c_padded(const batch_normalization_pd_t *bdesc) {
50 return bdesc->src_md()->padded_dims[1];
51 }
52
53 template <cpu_isa_t isa>
get_vlen(jit_memory_tag_kind_t tag_kind)54 int get_vlen(jit_memory_tag_kind_t tag_kind) {
55 return isa == sse41 && tag_kind == jit_memory_tag_kind_t::blocked
56 ? 32
57 : cpu_isa_traits<isa>::vlen;
58 }
59
60 template <cpu_isa_t isa>
get_simd_w(jit_memory_tag_kind_t tag_kind)61 int get_simd_w(jit_memory_tag_kind_t tag_kind) {
62 return get_vlen<isa>(tag_kind) / sizeof(acc_data_t);
63 }
64
65 template <cpu_isa_t isa>
get_data_strides(const batch_normalization_pd_t * bdesc,jit_memory_tag_kind_t tag_kind)66 std::tuple<dim_t, dim_t, dim_t> get_data_strides(
67 const batch_normalization_pd_t *bdesc, jit_memory_tag_kind_t tag_kind) {
68 const int simd_w = get_simd_w<isa>(tag_kind);
69 size_t stride_N, stride_S, stride_C;
70
71 if (tag_kind == jit_memory_tag_kind_t::nspc) {
72 stride_C = static_cast<size_t>(simd_w);
73 stride_S = static_cast<size_t>(bdesc->C());
74 stride_N = static_cast<size_t>(bdesc->D() * bdesc->H() * bdesc->W())
75 * stride_S;
76 } else {
77 const size_t C_blks = static_cast<size_t>(get_c_padded(bdesc) / simd_w);
78
79 stride_C = static_cast<size_t>(
80 bdesc->D() * bdesc->H() * bdesc->W() * simd_w);
81 stride_S = static_cast<size_t>(simd_w);
82 stride_N = C_blks * stride_C;
83 }
84
85 return std::make_tuple(stride_N, stride_S, stride_C);
86 }
87
// Address of a call_params_t member relative to the kernel argument register.
#define PARAM_ADDR(x) (reg_param_ + offsetof(call_params_t, x))

// Emits vector loads/stores of per-channel data that transparently handle
// the tail channel block when C is not a multiple of the SIMD width (i.e.
// the channel dimension was zero-padded). AVX512 masks via an opmask
// register, AVX2 via a vector blend mask; other ISAs take the plain path.
template <cpu_isa_t isa>
struct jit_bnorm_process_tail_t {
    using Vmm = typename cpu_isa_traits<isa>::Vmm;

    jit_bnorm_process_tail_t(const batch_normalization_pd_t *bdesc,
            jit_generator *host, Reg64 reg_tmp, Reg64 reg_blk_has_tail,
            Reg64 reg_C, Vmm vtail_mask, Opmask ktail_mask)
        : h_(host)
        , reg_tmp_(reg_tmp)
        , reg_blk_has_tail_(reg_blk_has_tail)
        , reg_C_(reg_C)
        , vtail_mask_(vtail_mask)
        , ktail_mask_(ktail_mask) {
        const memory_desc_wrapper data_d(bdesc->src_md());
        // Masking is only needed when C was padded up to the block size.
        c_is_padded_ = bdesc->C() != data_d.padded_dims()[1];

        // sse41 blocked kernels use a doubled 32-byte block (see get_vlen()).
        const int vlen = isa == sse41 ? 32 : cpu_isa_traits<isa>::vlen;
        tail_ = bdesc->C() % (int)(vlen / sizeof(float));
    }

    jit_generator *const h_;
    const Reg64 reg_tmp_;
    const Reg64 reg_blk_has_tail_; // runtime flag: current blk contains the tail
    const Reg64 reg_C_; // channel-block loop counter of the caller
    const Vmm vtail_mask_; // AVX2 blend mask
    const Opmask ktail_mask_; // AVX512 opmask
    bool c_is_padded_;
    int tail_; // number of valid channels in the tail block

    // Builds the AVX512 opmask with the low `tail_` bits set.
    void prepare_tail_mask_avx512_common() {
        if (!c_is_padded_) return;

        const int mask = (1 << tail_) - 1;

        Reg32 regw_tmp = reg_tmp_.cvt32();
        h_->mov(regw_tmp, mask);
        h_->kmovw(ktail_mask_, regw_tmp);
    }

    // Loads a vector mask with exactly `tail_` all-ones lanes for vmaskmovps.
    void prepare_tail_mask_avx2_common() {
        if (!c_is_padded_) return;

        // Sliding window over 8 ones followed by 8 zeros: reading 8 dwords
        // starting at mask[8 - tail_] yields tail_ set lanes, rest cleared.
        static const uint32_t mask[16] = {0xffffffff, 0xffffffff, 0xffffffff,
                0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0,
                0, 0, 0, 0, 0, 0, 0};

        h_->mov(reg_tmp_, reinterpret_cast<size_t>(&mask[8 - tail_]));
        h_->vmovups(vtail_mask_, h_->ptr[reg_tmp_]);
    }

    // Emits the per-kernel one-time mask setup (no-op for sse41).
    void prepare_tail() {
        if (isa == avx512_common)
            prepare_tail_mask_avx512_common();
        else if (isa == avx2)
            prepare_tail_mask_avx2_common();
    }

    // AVX2 masked move (either direction: reg->mem or mem->reg).
    void uni_vmovups_tail_avx2_common(
            const Operand &dst, const Operand &src, Label &l_ret) {
        if (dst.isMEM()) {
            h_->vmaskmovps(dst.getAddress(), vtail_mask_, Vmm(src.getIdx()));
        } else {
            h_->vmaskmovps(Vmm(dst.getIdx()), vtail_mask_, src.getAddress());
        }
        h_->jmp(l_ret);
    }

    // AVX512 masked move with zeroing of the masked-off lanes.
    void uni_vmovups_tail_avx512_common(
            const Operand &dst, const Operand &src, Label &l_ret) {
        if (dst.isMEM())
            h_->uni_vmovups(dst.getAddress() | ktail_mask_ | h_->T_z,
                    Vmm(src.getIdx()));
        else
            h_->uni_vmovups(Vmm(dst.getIdx()) | ktail_mask_ | h_->T_z,
                    src.getAddress());

        h_->jmp(l_ret);
    }

    // Emits a move that at runtime takes the masked path only when the
    // current blk has a tail AND this is the last channel-block iteration
    // (reg_C_ == 1); otherwise falls through to a full-width move.
    void uni_vmovups_maybe_tail(const Operand &dst, const Operand &src) {
        Label l_no_mask, l_ret;
        if (c_is_padded_) {
            h_->cmp(reg_blk_has_tail_, 0);
            h_->jz(l_no_mask);

            // The tail lives in the last channel-loop iteration.
            h_->cmp(reg_C_, 1);
            h_->jne(l_no_mask);
            assert(isa == avx512_common || isa == avx2);
            if (isa == avx512_common)
                uni_vmovups_tail_avx512_common(dst, src, l_ret);
            else if (isa == avx2)
                uni_vmovups_tail_avx2_common(dst, src, l_ret);
        }
        h_->L(l_no_mask);
        if (dst.isMEM())
            h_->uni_vmovups(dst.getAddress(), Vmm(src.getIdx()));
        else
            h_->uni_vmovups(Vmm(dst.getIdx()), src.getAddress());

        h_->L(l_ret);
    }
};
191
// Emits the fused-ReLU pieces of the bnorm kernels. Forward applies
// max(x, 0); when training with fused norm+ReLU it also records a
// one-bit-per-element "was positive" mask in the workspace, which the
// backward pass reloads to zero the gradient of clipped elements.
template <cpu_isa_t isa>
struct jit_bnorm_process_relu_t {
    using Vmm = typename cpu_isa_traits<isa>::Vmm;

    jit_bnorm_process_relu_t(const batch_normalization_pd_t *bdesc,
            jit_generator *host, Reg64 reg_off_dat, Reg64 reg_tmp,
            Reg64 reg_ptr_ws, Vmm vzero, Vmm vstore_mask, Opmask kstore_mask)
        : h_(host)
        , reg_off_dat_(reg_off_dat)
        , reg_tmp_(reg_tmp)
        , reg_ptr_ws_(reg_ptr_ws)
        , vzero_(vzero)
        , vstore_mask_(vstore_mask)
        , kstore_mask_(kstore_mask) {
        with_relu_ = bdesc->with_relu_post_op() || bdesc->fuse_norm_relu();
        // Inference-only ReLU clamps in place and needs no workspace mask.
        with_relu_inf_only_ = with_relu_
                && !(bdesc->fuse_norm_relu() && bdesc->is_training());

        // Shift converting a data byte-offset into a workspace byte-offset:
        // the workspace holds one bit per element, so the data offset is
        // divided by 8 * sizeof(data type).
        bit_shift_ = static_cast<int>(log2(bits_per_byte
                * types::data_type_size(bdesc->desc()->data_desc.data_type)));
    }

    jit_generator *const h_;
    const Reg64 reg_off_dat_; // byte offset into src/dst, shifted to index ws
    const Reg64 reg_tmp_;
    const Reg64 reg_ptr_ws_; // base pointer of the bit-mask workspace
    const Vmm vzero_, vstore_mask_;
    const Opmask kstore_mask_;
    Label l_relu_mask_avx2_; // inline 8-dword constant for mask expansion
    bool with_relu_, with_relu_inf_only_;
    int bit_shift_;

    void fwd_prepare_relu() {
        if (with_relu_) { h_->uni_vpxor(vzero_, vzero_, vzero_); }
    }

    void bwd_prepare_relu() {
        if (with_relu_) {
            h_->uni_vpxor(vzero_, vzero_, vzero_);
            if (isa == avx2) prepare_l_relu_mask_avx2();
        }
    }

    // Embeds the per-lane bit constants [1<<0 .. 1<<7] into the code stream
    // (jumped over at runtime); bwd_process_relu_avx2 reads them RIP-relative
    // to expand the packed byte mask into lane-wise blend masks.
    void prepare_l_relu_mask_avx2() {
        Label l_mask_after;
        h_->jmp(l_mask_after);
        h_->align(32);
        h_->L(l_relu_mask_avx2_); /* [0x80 0x40 0x20 0x10 0x08 0x04 0x02 0x01] */
        for (int i = 0; i < 8; ++i)
            h_->dd(1 << i);
        h_->L(l_mask_after);
    }

    void fwd_process_relu(Vmm v, const int off = 0) {
        if (with_relu_inf_only_) {
            h_->uni_vmaxps(v, v, vzero_);
        } else if (with_relu_) {
            if (isa == avx512_common)
                fwd_process_relu_avx512_common(v, off);
            else if (isa == avx2)
                fwd_process_relu_avx2(v, off);
            else
                assert(false);
        }
    }

    void bwd_process_relu(Vmm v, const int off = 0) {
        if (with_relu_) {
            if (isa == avx512_common)
                bwd_process_relu_avx512_common(v, off);
            else if (isa == avx2)
                bwd_process_relu_avx2(v, off);
            else
                assert(false);
        }
    }

    // AVX2 forward: mask = (0 < x); store the 8-bit mask to ws; zero the
    // negative lanes. reg_off_dat_ is temporarily shifted from a data offset
    // to a workspace offset and restored afterwards.
    void fwd_process_relu_avx2(Vmm vdst, const int off = 0) {
        Reg64 reg_store_mask = reg_tmp_;
        h_->shr(reg_off_dat_, bit_shift_);
        h_->vcmpps(vstore_mask_, vzero_, vdst, jit_generator::_cmp_lt_os);
        h_->vmovmskps(reg_store_mask, vstore_mask_);
        h_->mov(h_->ptr[reg_ptr_ws_ + reg_off_dat_ + off],
                reg_store_mask.cvt8());
        h_->vblendvps(vdst, vzero_, vdst, vstore_mask_);
        h_->shl(reg_off_dat_, bit_shift_);
    }

    // AVX512 forward: same as above but via an opmask and kmovw.
    void fwd_process_relu_avx512_common(Vmm vdst, const int off = 0) {
        h_->shr(reg_off_dat_, bit_shift_);
        h_->vcmpps(kstore_mask_, vzero_, vdst, jit_generator::_cmp_lt_os);
        h_->kmovw(h_->ptr[reg_ptr_ws_ + reg_off_dat_ + off], kstore_mask_);
        h_->vblendmps(vdst | kstore_mask_, vzero_, vdst);
        h_->shl(reg_off_dat_, bit_shift_);
    }

    // AVX2 backward: broadcast the stored mask byte, expand it to lane-wise
    // all-ones/all-zeros via the inline constants, and zero the masked-off
    // diff_dst lanes.
    void bwd_process_relu_avx2(Vmm vdiff_dst, const int off = 0) {
        h_->shr(reg_off_dat_, bit_shift_);
        h_->vpbroadcastb(
                vstore_mask_, h_->ptr[reg_ptr_ws_ + reg_off_dat_ + off]);
        h_->vpand(vstore_mask_, vstore_mask_,
                h_->ptr[Xbyak::util::rip + l_relu_mask_avx2_]);
        h_->vpcmpeqd(vstore_mask_, vstore_mask_,
                h_->ptr[Xbyak::util::rip + l_relu_mask_avx2_]);
        h_->vblendvps(vdiff_dst, vzero_, vdiff_dst, vstore_mask_);
        h_->shl(reg_off_dat_, bit_shift_);
    }

    // AVX512 backward: reload the opmask and zero masked-off lanes in place.
    void bwd_process_relu_avx512_common(Vmm vdiff_dst, const int off = 0) {
        h_->shr(reg_off_dat_, bit_shift_);
        h_->kmovw(kstore_mask_, h_->ptr[reg_ptr_ws_ + reg_off_dat_ + off]);
        h_->vmovups(vdiff_dst | kstore_mask_ | h_->T_z, vdiff_dst);
        h_->shl(reg_off_dat_, bit_shift_);
    }
};
307
// Routes data moves through bf16 <-> f32 conversion when the tensor data
// type is bf16, using the avx512_core_bf16 instruction when available and
// the software emulation sequence otherwise. For f32 data every move is a
// plain uni_vmovups.
template <cpu_isa_t isa>
struct jit_bnorm_bf16_emulation_t {
    using Vmm = typename cpu_isa_traits<isa>::Vmm;

    jit_bnorm_bf16_emulation_t(const batch_normalization_pd_t *bdesc,
            jit_generator *host, Zmm zmm_reserved_1, Zmm zmm_reserved_2,
            Zmm zmm_reserved_3, Zmm zmm_reserved_4, Reg64 reg_tmp)
        : h_(host), bf16_emu_(nullptr) {
        is_bf16_ = bdesc->desc()->data_desc.data_type == data_type::bf16;
        // Emulation is only instantiated when bf16 data is used on hardware
        // without native bf16 conversion support.
        if (is_bf16_ && !mayiuse(avx512_core_bf16)) {
            bf16_emu_ = utils::make_unique<bf16_emulation_t>(h_, zmm_reserved_1,
                    zmm_reserved_2, zmm_reserved_3, reg_tmp, zmm_reserved_4,
                    zmm_reserved_4);
            bf16_emu_->init_vcvtneps2bf16();
        }
    }

    jit_generator *const h_;
    std::unique_ptr<bf16_emulation_t> bf16_emu_;
    bool is_bf16_;

    // Load (mem -> reg) or store (reg -> mem) one vector of data, converting
    // between bf16 in memory and f32 in registers when needed.
    void uni_vmovups_data(const Operand &dst, const Operand &src) {
        if (dst.isMEM()) {
            if (is_bf16_) {
                constexpr bool isAvx2 = isa == avx2;
                // Both views deliberately reuse the *source* register index:
                // the down-convert is done in place at half the width.
                const typename std::conditional<isAvx2, Xmm, Ymm>::type
                        dst_reg {src.getIdx()};
                const typename std::conditional<isAvx2, Ymm, Zmm>::type
                        src_reg {src.getIdx()};

                // convert f32 output to bf16
                if (!bf16_emu_)
                    h_->vcvtneps2bf16(dst_reg, src_reg);
                else
                    bf16_emu_->vcvtneps2bf16(dst_reg, src_reg);

                h_->vmovdqu16(dst.getAddress(), dst_reg);
            } else {
                h_->uni_vmovups(dst.getAddress(), Vmm(src.getIdx()));
            }
        } else {
            if (is_bf16_) {
                // convert bf16 input to f32: zero-extend to 32 bits and
                // shift into the high half of each dword.
                h_->vpmovzxwd(Vmm(dst.getIdx()), src.getAddress());
                h_->vpslld(Vmm(dst.getIdx()), Vmm(dst.getIdx()), 0x10);
            } else {
                h_->uni_vmovups(Vmm(dst.getIdx()), src.getAddress());
            }
        }
    }

private:
    DNNL_DISALLOW_COPY_AND_ASSIGN(jit_bnorm_bf16_emulation_t);
};
362
// Base JIT kernel that accumulates a per-channel statistic (mean or
// variance) over one slice of the batch. Derived classes select which
// statistic by pointing reg_ptr_stat_ at the proper buffer in generate().
template <cpu_isa_t isa>
struct jit_bnorm_fwd_statistics_t : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_fwd_statistics_t)
    using Vmm = typename cpu_isa_traits<isa>::Vmm;

    const AddressFrame &vmmword
            = (isa == sse41) ? xword : (isa == avx2) ? yword : zword;

    // Arguments passed to the generated kernel on each invocation.
    struct call_params_t {
        size_t N, C, S; // batch count, channel blocks, spatial size
        const void *src;
        const acc_data_t *mean;
        const acc_data_t *var;
        size_t blk_has_tail; // non-zero: this blk contains the channel tail
        size_t do_normalise; // non-zero: divide accumulated sums by N*S
    };

    const Reg64 &reg_param_ = abi_param1;
    const Reg64 &reg_tmp_ = abi_not_param1;
    const Reg64 &reg_N_ = rsi;
    const Reg64 &reg_S_ = rax;
    const Reg64 &reg_C_ = rdx;
    const Reg64 &reg_off_c_ = rbx; // byte offset into the stat arrays
    const Reg64 &reg_blk_has_tail_ = rbp;

    const Reg64 &reg_off_dat_ = r8; // byte offset into src (inner loops)
    const Reg64 &reg_off_dat_save_ = r9; // src offset at channel-loop entry
    const Reg64 &reg_ptr_mean_ = r10;
    const Reg64 &reg_ptr_var_ = r11;
    const Reg64 &reg_ptr_src_ = r12;
    const Reg64 &reg_do_normalise_ = r13;
    const Reg64 &reg_ptr_stat_ = r14; // mean or var; set by derived generate()

    const Vmm v_ = Vmm(0);
    const Vmm vtmp_ = Vmm(1);
    const Vmm vtail_mask_ = Vmm(2);
    const Vmm vNS_ = Vmm(3);
    const Vmm vzero_ = Vmm(4);
    // When variance is computed, two vmms (one for variance and one for
    // mean) are needed to unroll one c block at any moment; therefore the
    // number of registers used for unrolling must be divisible by two.
    static constexpr int min_idx_to_unroll_ = 4;
    static constexpr int max_idx_to_unroll_ = isa == avx512_common ? 28 : 16;
    static constexpr int number_of_vmms_to_unrolling_variables_
            = max_idx_to_unroll_ - min_idx_to_unroll_;
    static_assert(number_of_vmms_to_unrolling_variables_ % 2 == 0
                    && number_of_vmms_to_unrolling_variables_ != 0,
            "Number of register to unrolling must to be divisible by 2.");

    const Opmask &ktail_mask_ = k2;

    const batch_normalization_pd_t *bdesc_;
    const jit_memory_tag_kind_t tag_kind_;
    const int vlen;
    const int simd_w;
    jit_bnorm_process_tail_t<isa> jit_tail_;
    jit_bnorm_bf16_emulation_t<isa> jit_bf16_emu_;
    int stride_N_, stride_S_, stride_C_; // element strides, see get_data_strides()
    size_t data_type_size_, acc_type_size_;

    // Loads the call_params_t fields that stay live across the whole kernel.
    void load_common_params() {
#define PARAM_PTR(x) ptr[PARAM_ADDR(x)]
        mov(reg_ptr_src_, PARAM_PTR(src));
        mov(reg_ptr_mean_, PARAM_PTR(mean));
        mov(reg_ptr_var_, PARAM_PTR(var));
#undef PARAM_PTR
        mov(reg_blk_has_tail_, dword[PARAM_ADDR(blk_has_tail)]);
        mov(reg_do_normalise_, dword[PARAM_ADDR(do_normalise)]);
    }

    // Zero-fills the statistic buffer (via reg_ptr_stat_) for all channel
    // blocks before accumulation starts.
    void zeroise() {
        Label label_zeroise;
        xor_(reg_off_c_, reg_off_c_);
        uni_vpxor(vzero_, vzero_, vzero_);
        mov(reg_C_, dword[PARAM_ADDR(C)]);
        L(label_zeroise);
        {
            jit_tail_.uni_vmovups_maybe_tail(
                    vmmword[reg_ptr_stat_ + reg_off_c_], vzero_);
            // sse41 blocked: second 16-byte half of the doubled block.
            if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) {
                jit_tail_.uni_vmovups_maybe_tail(
                        vmmword[reg_ptr_stat_ + reg_off_c_ + vlen / 2], vzero_);
            }
            add(reg_off_c_, simd_w * acc_type_size_);
            dec(reg_C_);
            jnz(label_zeroise);
        }
    }

    // Loads the running statistic for `c_blks_to_unroll` channel blocks into
    // Vmm(min_idx_to_unroll_ ...); for variance, additionally loads the
    // (already final) mean into the vmms right after them.
    void load_stat(bool compute_mean, const int c_blks_to_unroll = 1) {
        int start_idx = min_idx_to_unroll_;
        int end_idx = c_blks_to_unroll + min_idx_to_unroll_;
        const int step = simd_w * acc_type_size_;

        // load mean or variance
        for (int idx = start_idx, off = 0; idx < end_idx; idx++, off += step) {
            const Vmm vstat = Vmm(idx);
            jit_tail_.uni_vmovups_maybe_tail(
                    vstat, vmmword[reg_ptr_stat_ + reg_off_c_ + off]);
        }

        // if variance is counted then mean also is needed
        if (!compute_mean) {
            start_idx = min_idx_to_unroll_ + c_blks_to_unroll;
            end_idx = min_idx_to_unroll_ + 2 * c_blks_to_unroll;

            for (int idx = start_idx, off = 0; idx < end_idx;
                    idx++, off += step) {
                const Vmm vmean = Vmm(idx);
                jit_tail_.uni_vmovups_maybe_tail(
                        vmean, vmmword[reg_ptr_mean_ + reg_off_c_ + off]);
            }
        }
    }

    // One spatial step: accumulates sum(src) for mean, or
    // sum((src - mean)^2) for variance, for each unrolled channel block.
    void compute_stat(bool compute_mean, const int c_blks_to_unroll = 1) {
        const int start_idx = min_idx_to_unroll_;
        const int end_idx = c_blks_to_unroll + min_idx_to_unroll_;
        const int step = simd_w * data_type_size_;

        for (int idx = start_idx, off = 0; idx < end_idx; idx++, off += step) {
            const Vmm vstat = Vmm(idx);

            jit_bf16_emu_.uni_vmovups_data(
                    v_, vmmword[reg_ptr_src_ + reg_off_dat_ + off]);

            if (compute_mean) {
                uni_vaddps(vstat, vstat, v_);
            } else {
                const Vmm vmean = Vmm(idx + c_blks_to_unroll);

                // var += (src - mean)^2
                uni_vsubps(vtmp_, v_, vmean, vtmp_);
                uni_vfmadd231ps(vstat, vtmp_, vtmp_);
            }
        }
    }

    // Writes the accumulated statistic vmms back to the stat buffer.
    void store_stat(const int c_blks_to_unroll = 1) {
        const int start_idx = min_idx_to_unroll_;
        const int end_idx = c_blks_to_unroll + min_idx_to_unroll_;
        const int step = simd_w * acc_type_size_;

        for (int idx = start_idx, off = 0; idx < end_idx; idx++, off += step) {
            const Vmm vstat = Vmm(idx);

            jit_tail_.uni_vmovups_maybe_tail(
                    vmmword[reg_ptr_stat_ + reg_off_c_ + off], vstat);
        }
    }

    // Blocked layout: outer loop over channel blocks, inner loop over the
    // spatial dimension, one channel block in flight at a time.
    void compute_blocked(bool compute_mean) {
        Label label_C, label_S;
        mov(reg_C_, dword[PARAM_ADDR(C)]);
        L(label_C);
        {
            mov(reg_off_dat_, reg_off_dat_save_);

            load_stat(compute_mean);

            mov(reg_S_, dword[PARAM_ADDR(S)]);
            L(label_S);
            {
                compute_stat(compute_mean);

                add(reg_off_dat_, stride_S_ * data_type_size_);

                dec(reg_S_);
                jnz(label_S);
            }

            store_stat();

            add(reg_off_dat_save_, stride_C_ * data_type_size_);
            add(reg_off_c_, simd_w * acc_type_size_);

            dec(reg_C_);
            jnz(label_C);
        }
    }

    // nspc layout: processes as many channel blocks per spatial sweep as the
    // available vmms allow, falling through to ever-smaller unroll factors
    // until the remaining channel blocks are exhausted.
    void compute_nspc(bool compute_mean) {
        mov(reg_C_, dword[PARAM_ADDR(C)]);

        // When a variance is computed, two values are unrolled: mean and variance,
        // so number_of_vmms_to_unrolling_variables_ is divided by 2.
        const int max_of_unrolled_c_blks = compute_mean
                ? number_of_vmms_to_unrolling_variables_
                : number_of_vmms_to_unrolling_variables_ / 2;
        std::vector<Label> c_unroll_label(max_of_unrolled_c_blks + 1);

        for (int c_blks_to_unroll = max_of_unrolled_c_blks;
                c_blks_to_unroll > 0; --c_blks_to_unroll) {
            L(c_unroll_label[c_blks_to_unroll]);
            {
                cmp(reg_C_, c_blks_to_unroll);
                jl(c_unroll_label[c_blks_to_unroll - 1], T_NEAR);

                mov(reg_off_dat_, reg_off_dat_save_);

                load_stat(compute_mean, c_blks_to_unroll);

                Label label_S;
                mov(reg_S_, dword[PARAM_ADDR(S)]);
                L(label_S);
                {
                    compute_stat(compute_mean, c_blks_to_unroll);

                    add(reg_off_dat_, stride_S_ * data_type_size_);

                    dec(reg_S_);
                    jnz(label_S);
                }

                store_stat(c_blks_to_unroll);

                add(reg_off_c_, c_blks_to_unroll * simd_w * acc_type_size_);
                add(reg_off_dat_save_,
                        c_blks_to_unroll * stride_C_ * data_type_size_);

                sub(reg_C_, c_blks_to_unroll);
                jmp(c_unroll_label[c_blks_to_unroll], T_NEAR);
            }
        }
        L(c_unroll_label[0]);
    }

    // Outer batch loop; dispatches to the layout-specific body. For sse41
    // blocked, a second pass covers the upper half of the doubled block.
    void compute(bool compute_mean) {
        Label label_N;
        mov(reg_N_, dword[PARAM_ADDR(N)]);
        L(label_N);
        {
            xor_(reg_off_dat_save_, reg_off_dat_save_);
            xor_(reg_off_c_, reg_off_c_);

            tag_kind_ == jit_memory_tag_kind_t::nspc
                    ? compute_nspc(compute_mean)
                    : compute_blocked(compute_mean);

            if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) {
                xor_(reg_off_dat_save_, reg_off_dat_save_);
                xor_(reg_off_c_, reg_off_c_);
                add(reg_off_dat_save_, vlen / 2);
                add(reg_off_c_, vlen / 2);

                compute_blocked(compute_mean);
            }

            add(reg_ptr_src_, stride_N_ * data_type_size_);
            dec(reg_N_);
            jnz(label_N);
        }
    }

    // Divides the accumulated per-channel sums by N*S to produce the final
    // statistic; skipped at runtime when do_normalise is zero.
    void normalize() {
        Label label_ret, label_normalise;
        cmp(reg_do_normalise_, 0);
        jz(label_ret);

        // Broadcast float(MB * S) into vNS_.
        const int S = bdesc_->D() * bdesc_->H() * bdesc_->W();
        mov(reg_tmp_, float2int(bdesc_->MB() * S));
        Xmm xtmp = Xmm(vtmp_.getIdx());
        uni_vmovq(xtmp, reg_tmp_);
        uni_vbroadcastss(vNS_, xtmp);

        xor_(reg_off_c_, reg_off_c_);
        mov(reg_C_, dword[PARAM_ADDR(C)]);
        L(label_normalise);
        {
            jit_tail_.uni_vmovups_maybe_tail(
                    v_, vmmword[reg_ptr_stat_ + reg_off_c_]);
            uni_vdivps(v_, v_, vNS_);
            jit_tail_.uni_vmovups_maybe_tail(
                    vmmword[reg_ptr_stat_ + reg_off_c_], v_);

            if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) {
                jit_tail_.uni_vmovups_maybe_tail(
                        v_, vmmword[reg_ptr_stat_ + reg_off_c_ + vlen / 2]);
                uni_vdivps(v_, v_, vNS_);
                jit_tail_.uni_vmovups_maybe_tail(
                        vmmword[reg_ptr_stat_ + reg_off_c_ + vlen / 2], v_);
            }

            add(reg_off_c_, simd_w * acc_type_size_);
            dec(reg_C_);
            jnz(label_normalise);
        }

        L(label_ret);
    }

    jit_bnorm_fwd_statistics_t(const batch_normalization_pd_t *bdesc,
            const jit_memory_tag_kind_t tag_kind)
        : bdesc_(bdesc)
        , tag_kind_(tag_kind)
        , vlen(get_vlen<isa>(tag_kind))
        , simd_w(get_simd_w<isa>(tag_kind))
        , jit_tail_(bdesc, this, reg_tmp_, reg_blk_has_tail_, reg_C_,
                vtail_mask_, ktail_mask_)
        , jit_bf16_emu_(bdesc, this, zmm28, zmm29, zmm30, zmm31, reg_tmp_) {
        static_assert(isa == sse41 || isa == avx2 || isa == avx512_common,
                "unsupported isa");

        std::tie(stride_N_, stride_S_, stride_C_)
                = get_data_strides<isa>(bdesc_, tag_kind);

        data_type_size_
                = types::data_type_size(bdesc->desc()->data_desc.data_type);
        acc_type_size_ = sizeof(acc_data_t);
    }
};
675
// Kernel computing the per-channel mean: points reg_ptr_stat_ at the mean
// buffer, then runs the shared zeroise/accumulate/normalize pipeline.
template <cpu_isa_t isa>
struct jit_bnorm_fwd_mean_t : jit_bnorm_fwd_statistics_t<isa> {
    using call_params_t =
            typename jit_bnorm_fwd_statistics_t<isa>::call_params_t;

    jit_bnorm_fwd_mean_t(const batch_normalization_pd_t *bdesc,
            const jit_memory_tag_kind_t tag_kind)
        : jit_bnorm_fwd_statistics_t<isa>(bdesc, tag_kind) {}

    void generate() override {
        this->preamble();
        this->load_common_params();
        // Base-class helpers read/write the statistic through reg_ptr_stat_.
        this->mov(this->reg_ptr_stat_, this->reg_ptr_mean_);
        this->jit_tail_.prepare_tail();
        this->zeroise();
        this->compute(true); // true: accumulate the sum (mean)
        this->normalize();
        this->postamble();
    }
};
696
// Kernel computing the per-channel variance: points reg_ptr_stat_ at the
// variance buffer; the already-final mean is read from reg_ptr_mean_.
template <cpu_isa_t isa>
struct jit_bnorm_fwd_var_t : jit_bnorm_fwd_statistics_t<isa> {
    using call_params_t =
            typename jit_bnorm_fwd_statistics_t<isa>::call_params_t;

    jit_bnorm_fwd_var_t(const batch_normalization_pd_t *bdesc,
            const jit_memory_tag_kind_t tag_kind)
        : jit_bnorm_fwd_statistics_t<isa>(bdesc, tag_kind) {}

    void generate() override {
        this->preamble();
        this->load_common_params();
        // Base-class helpers read/write the statistic through reg_ptr_stat_.
        this->mov(this->reg_ptr_stat_, this->reg_ptr_var_);
        this->jit_tail_.prepare_tail();
        this->zeroise();
        this->compute(false); // false: accumulate sum((src - mean)^2)
        this->normalize();
        this->postamble();
    }
};
717
718 template <cpu_isa_t isa>
719 struct jit_bnorm_fwd_t : public jit_generator {
720 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_fwd_t)
721 using Vmm = typename cpu_isa_traits<isa>::Vmm;
722
723 const AddressFrame &vmmword
724 = (isa == sse41) ? xword : (isa == avx2) ? yword : zword;
725
    // Arguments passed to the generated forward kernel on each invocation.
    struct call_params_t {
        size_t N, C, S; // batch count, channel blocks, spatial size
        const void *src, *dst;
        const uint8_t *ws; // ReLU bit-mask workspace
        const acc_data_t *mean, *var;
        const acc_data_t *scale, *shift;
        size_t blk_has_tail; // non-zero: this blk contains the channel tail
    };

    const Reg64 &reg_param_ = abi_param1;
    const Reg64 &reg_tmp_ = abi_not_param1;
    const Reg64 &reg_N_ = rsi;
    const Reg64 &reg_S_ = rax;
    const Reg64 &reg_C_ = rdx;
    const Reg64 &reg_off_c_ = rbx; // byte offset into per-channel arrays
    const Reg64 &reg_blk_has_tail_ = rbp;

    const Reg64 &reg_off_dat_ = r8; // byte offset into src/dst (inner loops)
    const Reg64 &reg_off_dat_save_ = r9;
    const Reg64 &reg_ptr_ws_ = r10;
    const Reg64 &reg_ptr_scale_ = r11;
    // NOTE: aliases reg_N_ — compute() spills/restores N around each
    // batch iteration to make both usable.
    const Reg64 &reg_ptr_shift_ = reg_N_;
    const Reg64 &reg_ptr_var_ = r12;
    const Reg64 &reg_ptr_mean_ = r13;
    const Reg64 &reg_ptr_dst_ = r14;
    const Reg64 &reg_ptr_src_ = r15;

    const Vmm vzero_ = Vmm(0);
    const Vmm vone_ = Vmm(1);
    const Vmm vmean_ = Vmm(2);
    const Vmm vvar_ = Vmm(3);
    const Vmm vsqrtvar_ = Vmm(4); // holds 1/sqrt(var + eps) after load_c_specifics()
    const Vmm vgamma_ = Vmm(5);
    const Vmm vbeta_ = Vmm(6);
    const Vmm veps_ = Vmm(7);
    const Vmm vtmp_ = Vmm(8);
    const Vmm v_ = Vmm(9);
    const Vmm vtail_mask_ = Vmm(10);
    const Vmm vstore_mask_ = vtmp_; // aliases vtmp_

    const Opmask &kstore_mask_ = k1;
    const Opmask &ktail_mask_ = k2;

    const batch_normalization_pd_t *bdesc_;
    const jit_memory_tag_kind_t tag_kind_;
    const int vlen;
    const int simd_w;
    jit_bnorm_process_tail_t<isa> jit_tail_;
    jit_bnorm_process_relu_t<isa> jit_relu_;
    jit_bnorm_bf16_emulation_t<isa> jit_bf16_emu_;
    int stride_N_, stride_S_, stride_C_; // element strides, see get_data_strides()
    size_t data_type_size_, acc_type_size_;

    // Local stack-frame layout for the spilled `N` counter and `shift`
    // pointer (frame is allocated in generate()).
    enum {
        stack_off_N = 0,
        stack_off_shift = 8,
        stack_size_required = 16,
    };
784
    // Loads kernel arguments into their dedicated registers, broadcasts the
    // epsilon and 1.0f constants, and spills `shift`/`N` to the local stack
    // frame (their registers are shared / reused in compute()).
    void load_common_params() {
#define PARAM_PTR(x) ptr[PARAM_ADDR(x)]
        mov(reg_ptr_src_, PARAM_PTR(src));
        mov(reg_ptr_dst_, PARAM_PTR(dst));
        mov(reg_ptr_mean_, PARAM_PTR(mean));
        mov(reg_ptr_var_, PARAM_PTR(var));
        mov(reg_ptr_scale_, PARAM_PTR(scale));
        mov(reg_ptr_ws_, PARAM_PTR(ws));

        Xmm x = Xmm(v_.getIdx());

        // veps_ = broadcast(epsilon)
        mov(reg_tmp_, float2int(bdesc_->desc()->batch_norm_epsilon));
        uni_vmovq(x, reg_tmp_);
        uni_vbroadcastss(veps_, x);

        // vone_ = broadcast(1.0f), used to form 1/sqrt(var + eps)
        mov(reg_tmp_, float2int(1.f));
        uni_vmovq(x, reg_tmp_);
        uni_vbroadcastss(vone_, x);

        mov(reg_blk_has_tail_, dword[PARAM_ADDR(blk_has_tail)]);

        // shift and N live on the stack: reg_ptr_shift_ aliases reg_N_.
        mov(reg_tmp_, PARAM_PTR(shift));
        mov(ptr[rsp + stack_off_shift], reg_tmp_);
        mov(reg_tmp_, PARAM_PTR(N));
        mov(ptr[rsp + stack_off_N], reg_tmp_);
#undef PARAM_PTR
    }
812
    // Loads the per-channel-block values for the current reg_off_c_: mean,
    // variance (turned into 1/sqrt(var + eps) in vsqrtvar_), and, when used,
    // gamma (scale) and beta (shift).
    void load_c_specifics() {
        jit_tail_.uni_vmovups_maybe_tail(
                vmean_, vmmword[reg_ptr_mean_ + reg_off_c_]);
        jit_tail_.uni_vmovups_maybe_tail(
                vvar_, vmmword[reg_ptr_var_ + reg_off_c_]);

        // vsqrtvar_ = 1 / sqrt(var + eps)
        uni_vmovups(vsqrtvar_, vvar_);
        uni_vaddps(vsqrtvar_, vsqrtvar_, veps_);
        uni_vsqrtps(vsqrtvar_, vsqrtvar_);

        if (isa == sse41) {
            // sse41 divps is destructive; go through vtmp_ to keep vone_.
            movups(vtmp_, vone_);
            divps(vtmp_, vsqrtvar_);
            movups(vsqrtvar_, vtmp_);
        } else
            vdivps(vsqrtvar_, vone_, vsqrtvar_);

        if (bdesc_->use_scaleshift() || bdesc_->use_scale())
            jit_tail_.uni_vmovups_maybe_tail(
                    vgamma_, vmmword[reg_ptr_scale_ + reg_off_c_]);
        if (bdesc_->use_scaleshift() || bdesc_->use_shift())
            jit_tail_.uni_vmovups_maybe_tail(
                    vbeta_, vmmword[reg_ptr_shift_ + reg_off_c_]);
    }
837
    // One vector of forward bnorm: y = gamma * (x - mean) / sqrt(var + eps)
    // + beta, optionally followed by fused ReLU, then stored (non-temporally
    // when allowed).
    void compute_bnorm(bool stream_store_allowed) {
        jit_bf16_emu_.uni_vmovups_data(
                v_, vmmword[reg_ptr_src_ + reg_off_dat_]);
        uni_vsubps(v_, v_, vmean_);
        uni_vmulps(v_, v_, vsqrtvar_); // vsqrtvar_ holds 1/sqrt(var + eps)

        if (bdesc_->use_scaleshift()
                || (bdesc_->use_scale() && bdesc_->use_shift()))
            uni_vfmadd213ps(v_, vgamma_, vbeta_);
        else if (bdesc_->use_scale())
            uni_vmulps(v_, v_, vgamma_);
        else if (bdesc_->use_shift())
            uni_vaddps(v_, v_, vbeta_);

        jit_relu_.fwd_process_relu(v_);

        if (stream_store_allowed) {
            // Non-temporal store; requires an aligned destination
            // (checked by the caller) and f32 data.
            uni_vmovntps(vmmword[reg_ptr_dst_ + reg_off_dat_], v_);
        } else {
            jit_bf16_emu_.uni_vmovups_data(
                    vmmword[reg_ptr_dst_ + reg_off_dat_], v_);
        }
    }
861
    // Blocked layout: outer loop over channel blocks (per-channel constants
    // loaded once per block), inner loop over the spatial dimension.
    void compute_blocked(bool stream_store_allowed) {
        Label label_C, label_S;
        mov(reg_C_, dword[PARAM_ADDR(C)]);
        L(label_C);
        {
            mov(reg_off_dat_, reg_off_dat_save_);

            load_c_specifics();

            mov(reg_S_, dword[PARAM_ADDR(S)]);
            L(label_S);
            {
                compute_bnorm(stream_store_allowed);

                add(reg_off_dat_, stride_S_ * data_type_size_);

                dec(reg_S_);
                jnz(label_S);
            }

            add(reg_off_dat_save_, stride_C_ * data_type_size_);
            add(reg_off_c_, simd_w * acc_type_size_);

            dec(reg_C_);
            jnz(label_C);
        }
    }
889
    // nspc layout: outer loop over spatial points, inner loop over channel
    // blocks (per-channel constants reloaded each step, as channels are
    // innermost in memory).
    void compute_nspc(bool stream_store_allowed) {
        Label label_C, label_S;
        mov(reg_S_, dword[PARAM_ADDR(S)]);
        L(label_S);
        {
            mov(reg_off_dat_, reg_off_dat_save_);
            xor_(reg_off_c_, reg_off_c_);

            mov(reg_C_, dword[PARAM_ADDR(C)]);
            L(label_C);
            {
                load_c_specifics();

                compute_bnorm(stream_store_allowed);

                add(reg_off_c_, simd_w * acc_type_size_);
                add(reg_off_dat_, stride_C_ * data_type_size_);

                dec(reg_C_);
                jnz(label_C);
            }

            add(reg_off_dat_save_, stride_S_ * data_type_size_);

            dec(reg_S_);
            jnz(label_S);
        }
    }
918
    // Outer batch loop; dispatches to the layout-specific body. reg_N_ is
    // spilled each iteration because it shares a register with
    // reg_ptr_shift_. For sse41 blocked, a second pass covers the upper
    // half of the doubled block.
    void compute(bool stream_store_allowed) {
        Label label_N;
        mov(reg_N_, ptr[rsp + stack_off_N]);
        L(label_N);
        {
            // save reg_N_, because register is shared with reg_ptr_shift_
            mov(ptr[rsp + stack_off_N], reg_N_);
            mov(reg_ptr_shift_, ptr[rsp + stack_off_shift]);

            xor_(reg_off_dat_save_, reg_off_dat_save_);
            xor_(reg_off_c_, reg_off_c_);

            tag_kind_ == jit_memory_tag_kind_t::nspc
                    ? compute_nspc(stream_store_allowed)
                    : compute_blocked(stream_store_allowed);

            if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) {
                xor_(reg_off_dat_save_, reg_off_dat_save_);
                xor_(reg_off_c_, reg_off_c_);
                add(reg_off_dat_save_, vlen / 2);
                add(reg_off_c_, vlen / 2);

                compute_blocked(stream_store_allowed);
            }

            add(reg_ptr_src_, stride_N_ * data_type_size_);
            add(reg_ptr_dst_, stride_N_ * data_type_size_);
            // Workspace advances in bits -> bytes (one bit per element).
            add(reg_ptr_ws_, stride_N_ / bits_per_byte);

            // restore reg_N_, because register is shared with reg_ptr_shift_
            mov(reg_N_, ptr[rsp + stack_off_N]);
            dec(reg_N_);
            jnz(label_N);
        }
    }
954
    // Constructor: wires the shared helper emitters (tail handling, fused
    // ReLU, bf16 emulation) to this generator's registers and precomputes
    // the layout-dependent strides and element sizes used at emit time.
    jit_bnorm_fwd_t(const batch_normalization_pd_t *bdesc,
            const jit_memory_tag_kind_t tag_kind)
        : bdesc_(bdesc)
        , tag_kind_(tag_kind)
        , vlen(get_vlen<isa>(tag_kind))
        , simd_w(get_simd_w<isa>(tag_kind))
        , jit_tail_(bdesc, this, reg_tmp_, reg_blk_has_tail_, reg_C_,
                  vtail_mask_, ktail_mask_)
        , jit_relu_(bdesc, this, reg_off_dat_, reg_tmp_, reg_ptr_ws_, vzero_,
                  vstore_mask_, kstore_mask_)
        , jit_bf16_emu_(bdesc, this, zmm28, zmm29, zmm30, zmm31, reg_tmp_) {
        static_assert(isa == sse41 || isa == avx2 || isa == avx512_common,
                "unsupported isa");

        std::tie(stride_N_, stride_S_, stride_C_)
                = get_data_strides<isa>(bdesc_, tag_kind);

        data_type_size_
                = types::data_type_size(bdesc->desc()->data_desc.data_type);
        acc_type_size_ = sizeof(acc_data_t);
    }
976
    // Top-level code generation. Two variants of the kernel body are
    // emitted: one that may use non-temporal (streaming) stores and a
    // fallback with regular stores. The runtime check on the destination
    // pointer's alignment (vlen - 1) selects between them, since streaming
    // stores require vector-aligned addresses.
    void generate() override {
        // Streaming stores are disabled for bf16 (stores go through the
        // converting path) and for an nspc tail (partial-vector stores).
        bool is_bf16 = bdesc_->desc()->data_desc.data_type == data_type::bf16;
        const bool is_tail_in_nspc_format
                = tag_kind_ == jit_memory_tag_kind_t::nspc
                && jit_tail_.tail_ != 0;
        const bool stream_store_allowed = !is_bf16 && !is_tail_in_nspc_format;

        preamble();
        sub(rsp, stack_size_required); // local spill area (e.g. N, shift ptr)
        load_common_params();
        jit_relu_.fwd_prepare_relu();
        jit_tail_.prepare_tail();

        Label normal_store, end_store;
        test(reg_ptr_dst_, vlen - 1);
        jnz(normal_store, T_NEAR);
        compute(stream_store_allowed);
        jmp(end_store, T_NEAR);
        L(normal_store);
        { compute(false); }
        L(end_store);

        add(rsp, stack_size_required);
        postamble();
    }
1002 };
1003
// JIT kernel for the backward pass of batch normalization: computes
// diff_src from diff_dst (and src when diff statistics are involved).
// The diff_scale/diff_shift reductions are produced by a separate kernel
// (jit_bnorm_bwd_diff_ss_t); this kernel only consumes them.
template <cpu_isa_t isa>
struct jit_bnorm_bwd_t : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_bwd_t)
    using Vmm = typename cpu_isa_traits<isa>::Vmm;

    // Memory operand frame matching the vector width of `isa`.
    const AddressFrame &vmmword
            = (isa == sse41) ? xword : (isa == avx2) ? yword : zword;

    // Arguments passed by the driver at every kernel invocation.
    struct call_params_t {
        size_t N, C, S; // extents to process: batch, channel blocks, spatial
        const void *src, *diff_src, *diff_dst;
        const uint8_t *ws; // ReLU mask workspace (1 bit per element)
        const acc_data_t *mean, *var;
        const acc_data_t *scale, *diff_scale, *diff_shift;
        size_t blk_has_tail; // non-zero if the last channel block is partial
    };

    // General-purpose register assignment (fixed for the whole kernel).
    const Reg64 &reg_param_ = abi_param1;
    const Reg64 &reg_tmp_ = abi_not_param1;
    const Reg64 &reg_N_ = rsi;
    const Reg64 &reg_S_ = rax;
    const Reg64 &reg_C_ = rdx;
    const Reg64 &reg_off_c_ = rbx;
    const Reg64 &reg_blk_has_tail_ = rbp;

    const Reg64 &reg_off_dat_ = r8;
    const Reg64 &reg_off_dat_save_ = r9;
    const Reg64 &reg_ptr_c_ = r10;
    const Reg64 &reg_ptr_ws_ = r11;
    const Reg64 &reg_ptr_diff_dst_ = r12;
    const Reg64 &reg_ptr_diff_src_ = r13;
    const Reg64 &reg_ptr_src_ = r14;

    // Vector register assignment. Per-channel constants (mean, 1/sqrtvar,
    // gamma, scaled diff stats) stay live across the spatial loop.
    const Vmm vzero_ = Vmm(0);
    const Vmm vone_ = Vmm(1);
    const Vmm vmean_ = Vmm(2);
    const Vmm vsqrtvar_ = Vmm(3);
    const Vmm vgamma_ = Vmm(4);
    const Vmm vdiff_gamma_ = Vmm(5);
    const Vmm vdiff_beta_ = Vmm(6);
    const Vmm veps_ = Vmm(7);
    const Vmm vNS_ = Vmm(8);
    const Vmm vtmp_ = Vmm(9);
    const Vmm v_ = Vmm(10);
    const Vmm vtail_mask_ = Vmm(11);
    const Vmm vstore_mask_ = vtmp_; // aliases vtmp_ (never live together)

    const Opmask &kstore_mask_ = k1;
    const Opmask &ktail_mask_ = k2;

    const batch_normalization_pd_t *bdesc_;
    const jit_memory_tag_kind_t tag_kind_;
    const int vlen;
    const int simd_w;
    jit_bnorm_process_tail_t<isa> jit_tail_;
    jit_bnorm_process_relu_t<isa> jit_relu_;
    jit_bnorm_bf16_emulation_t<isa> jit_bf16_emu_;
    int stride_N_, stride_S_, stride_C_; // layout-dependent element strides
    size_t data_type_size_, acc_type_size_;

    // Loads data pointers from call_params_t and broadcasts the scalar
    // constants (eps, 1.0, N*S) into vector registers.
    void load_common_params() {
#define PARAM_PTR(x) ptr[PARAM_ADDR(x)]
        mov(reg_ptr_src_, PARAM_PTR(src));
        mov(reg_ptr_diff_src_, PARAM_PTR(diff_src));
        mov(reg_ptr_diff_dst_, PARAM_PTR(diff_dst));
        mov(reg_ptr_ws_, PARAM_PTR(ws));
#undef PARAM_PTR

        Xmm x = Xmm(v_.getIdx());

        mov(reg_tmp_, float2int(bdesc_->desc()->batch_norm_epsilon));
        uni_vmovq(x, reg_tmp_);
        uni_vbroadcastss(veps_, x);

        mov(reg_tmp_, float2int(1.f));
        uni_vmovq(x, reg_tmp_);
        uni_vbroadcastss(vone_, x);

        // N*S is the normalization factor for the diff statistics.
        const int S = bdesc_->D() * bdesc_->H() * bdesc_->W();
        mov(reg_tmp_, float2int(bdesc_->MB() * S));
        uni_vmovq(x, reg_tmp_);
        uni_vbroadcastss(vNS_, x);

        mov(reg_blk_has_tail_, dword[PARAM_ADDR(blk_has_tail)]);
    }

    // Loads the per-channel-block constants for the current value of
    // reg_off_c_: mean, 1/sqrt(var+eps), gamma, and the pre-scaled diff
    // statistics (diff_gamma * (1/sqrtvar) / NS and diff_beta / NS).
    void load_c_specifics() {
        mov(reg_ptr_c_, ptr[PARAM_ADDR(mean)]);
        jit_tail_.uni_vmovups_maybe_tail(
                vmean_, vmmword[reg_ptr_c_ + reg_off_c_]);

        mov(reg_ptr_c_, ptr[PARAM_ADDR(var)]);
        jit_tail_.uni_vmovups_maybe_tail(
                vsqrtvar_, vmmword[reg_ptr_c_ + reg_off_c_]);
        uni_vaddps(vsqrtvar_, vsqrtvar_, veps_);
        uni_vsqrtps(vsqrtvar_, vsqrtvar_);

        // sse41 has no non-destructive divide, so go through vtmp_.
        if (isa == sse41) {
            movups(vtmp_, vone_);
            divps(vtmp_, vsqrtvar_);
            movups(vsqrtvar_, vtmp_);
        } else
            vdivps(vsqrtvar_, vone_, vsqrtvar_);

        if (bdesc_->use_scaleshift() || bdesc_->use_scale()) {
            mov(reg_ptr_c_, ptr[PARAM_ADDR(scale)]);
            jit_tail_.uni_vmovups_maybe_tail(
                    vgamma_, vmmword[reg_ptr_c_ + reg_off_c_]);
        }

        if (calculate_diff_stats()) {
            mov(reg_ptr_c_, ptr[PARAM_ADDR(diff_scale)]);
            jit_tail_.uni_vmovups_maybe_tail(
                    vdiff_gamma_, vmmword[reg_ptr_c_ + reg_off_c_]);
            uni_vmulps(vdiff_gamma_, vdiff_gamma_, vsqrtvar_);
            uni_vdivps(vdiff_gamma_, vdiff_gamma_, vNS_);
            mov(reg_ptr_c_, ptr[PARAM_ADDR(diff_shift)]);
            jit_tail_.uni_vmovups_maybe_tail(
                    vdiff_beta_, vmmword[reg_ptr_c_ + reg_off_c_]);
            uni_vdivps(vdiff_beta_, vdiff_beta_, vNS_);
        }
    }

    // Emits the per-vector backward computation:
    //   v = diff_dst (masked by the ReLU workspace)
    //   v -= diff_beta/NS + (src - mean) * diff_gamma * (1/sqrtvar) / NS
    //       (only when diff stats participate, i.e. !use_global_stats)
    //   diff_src = v * gamma * (1/sqrtvar)
    void compute_bnorm(bool stream_store_allowed) {
        jit_bf16_emu_.uni_vmovups_data(
                v_, vmmword[reg_ptr_diff_dst_ + reg_off_dat_]);
        jit_relu_.bwd_process_relu(v_);

        if (calculate_diff_stats()) {
            uni_vsubps(v_, v_, vdiff_beta_);
            jit_bf16_emu_.uni_vmovups_data(
                    vtmp_, vmmword[reg_ptr_src_ + reg_off_dat_]);
            uni_vsubps(vtmp_, vtmp_, vmean_);
            uni_vmulps(vtmp_, vtmp_, vdiff_gamma_);
            uni_vsubps(v_, v_, vtmp_);
        }

        if (bdesc_->use_scaleshift() || bdesc_->use_scale())
            uni_vmulps(v_, v_, vgamma_);
        uni_vmulps(v_, v_, vsqrtvar_);

        if (stream_store_allowed) {
            // Non-temporal store: requires an aligned destination (checked
            // by generate() before selecting this variant).
            uni_vmovntps(vmmword[reg_ptr_diff_src_ + reg_off_dat_], v_);
        } else {
            jit_bf16_emu_.uni_vmovups_data(
                    vmmword[reg_ptr_diff_src_ + reg_off_dat_], v_);
        }
    }

    // Loop nest for the blocked layout: channel blocks outer, spatial inner,
    // so per-channel constants are loaded once per channel block.
    void compute_blocked(bool stream_store_allowed) {
        Label label_C, label_S;
        mov(reg_C_, dword[PARAM_ADDR(C)]);
        L(label_C);
        {
            mov(reg_off_dat_, reg_off_dat_save_);

            load_c_specifics();

            mov(reg_S_, dword[PARAM_ADDR(S)]);
            L(label_S);
            {
                compute_bnorm(stream_store_allowed);

                add(reg_off_dat_, stride_S_ * data_type_size_);

                dec(reg_S_);
                jnz(label_S);
            }

            add(reg_off_dat_save_, stride_C_ * data_type_size_);
            add(reg_off_c_, simd_w * acc_type_size_);

            dec(reg_C_);
            jnz(label_C);
        }
    }

    // Loop nest for the nspc layout: spatial outer, channel blocks inner,
    // walking contiguous memory; constants are reloaded per inner iteration.
    void compute_nspc(bool stream_store_allowed) {
        Label label_C, label_S;
        mov(reg_S_, dword[PARAM_ADDR(S)]);
        L(label_S);
        {
            mov(reg_off_dat_, reg_off_dat_save_);
            xor_(reg_off_c_, reg_off_c_);

            mov(reg_C_, dword[PARAM_ADDR(C)]);
            L(label_C);
            {
                load_c_specifics();

                compute_bnorm(stream_store_allowed);

                add(reg_off_c_, simd_w * acc_type_size_);
                add(reg_off_dat_, stride_C_ * data_type_size_);

                dec(reg_C_);
                jnz(label_C);
            }

            add(reg_off_dat_save_, stride_S_ * data_type_size_);

            dec(reg_S_);
            jnz(label_S);
        }
    }

    // Outermost (batch) loop; see the forward kernel for the sse41 blocked
    // special case (16c block processed as two vlen/2 halves).
    void compute(bool stream_store_allowed) {
        Label label_N;
        mov(reg_N_, dword[PARAM_ADDR(N)]);
        L(label_N);
        {
            xor_(reg_off_dat_save_, reg_off_dat_save_);
            xor_(reg_off_c_, reg_off_c_);

            tag_kind_ == jit_memory_tag_kind_t::nspc
                    ? compute_nspc(stream_store_allowed)
                    : compute_blocked(stream_store_allowed);

            if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) {
                xor_(reg_off_dat_save_, reg_off_dat_save_);
                xor_(reg_off_c_, reg_off_c_);
                add(reg_off_dat_save_, vlen / 2);
                add(reg_off_c_, vlen / 2);

                compute_blocked(stream_store_allowed);
            }

            add(reg_ptr_src_, stride_N_ * data_type_size_);
            add(reg_ptr_diff_src_, stride_N_ * data_type_size_);
            add(reg_ptr_diff_dst_, stride_N_ * data_type_size_);
            add(reg_ptr_ws_, stride_N_ / bits_per_byte); // 1 bit per element

            dec(reg_N_);
            jnz(label_N);
        }
    }

    // Diff statistics enter the diff_src formula unless the primitive uses
    // externally provided (global) statistics.
    bool calculate_diff_stats() const { return !bdesc_->use_global_stats(); }

    // Constructor: wires helper emitters (tail, fused ReLU, bf16 emulation)
    // to this generator's registers and precomputes layout strides/sizes.
    jit_bnorm_bwd_t(const batch_normalization_pd_t *bdesc,
            const jit_memory_tag_kind_t tag_kind)
        : bdesc_(bdesc)
        , tag_kind_(tag_kind)
        , vlen(get_vlen<isa>(tag_kind))
        , simd_w(get_simd_w<isa>(tag_kind))
        , jit_tail_(bdesc, this, reg_tmp_, reg_blk_has_tail_, reg_C_,
                  vtail_mask_, ktail_mask_)
        , jit_relu_(bdesc, this, reg_off_dat_, reg_tmp_, reg_ptr_ws_, vzero_,
                  vstore_mask_, kstore_mask_)
        , jit_bf16_emu_(bdesc, this, zmm28, zmm29, zmm30, zmm31, reg_tmp_) {
        static_assert(isa == sse41 || isa == avx2 || isa == avx512_common,
                "unsupported isa");

        std::tie(stride_N_, stride_S_, stride_C_)
                = get_data_strides<isa>(bdesc_, tag_kind);

        data_type_size_
                = types::data_type_size(bdesc->desc()->data_desc.data_type);
        acc_type_size_ = sizeof(acc_data_t);
    }

    // Emits two kernel variants (streaming vs. regular stores) selected at
    // runtime by the diff_src pointer's alignment, mirroring the fwd kernel.
    void generate() override {
        bool is_bf16 = bdesc_->desc()->data_desc.data_type == data_type::bf16;
        const bool is_tail_in_nspc_format
                = tag_kind_ == jit_memory_tag_kind_t::nspc
                && jit_tail_.tail_ != 0;
        const bool stream_store_allowed = !is_bf16 && !is_tail_in_nspc_format;

        preamble();
        load_common_params();
        jit_relu_.bwd_prepare_relu();
        jit_tail_.prepare_tail();

        Label normal_store, end_store;
        test(reg_ptr_diff_src_, vlen - 1);
        jnz(normal_store, T_NEAR);
        compute(stream_store_allowed);
        jmp(end_store, T_NEAR);
        L(normal_store);
        { compute(false); }
        L(end_store);

        postamble();
    }
};
1289
// JIT kernel that computes the backward reductions diff_gamma and diff_beta
// (per channel, accumulated over N and S). Results are accumulated into the
// diff_gamma/diff_beta buffers, which zeroise() clears first; diff_gamma is
// multiplied by 1/sqrt(var + eps) just before the store.
template <cpu_isa_t isa>
struct jit_bnorm_bwd_diff_ss_t : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_bwd_diff_ss_t)
    using Vmm = typename cpu_isa_traits<isa>::Vmm;

    // Memory operand frame matching the vector width of `isa`.
    const AddressFrame &vmmword
            = (isa == sse41) ? xword : (isa == avx2) ? yword : zword;

    // Arguments passed by the driver at every kernel invocation.
    struct call_params_t {
        size_t N, C, S; // extents to process: batch, channel blocks, spatial
        const void *src, *diff_dst;
        const uint8_t *ws; // ReLU mask workspace (1 bit per element)
        const acc_data_t *mean, *var;
        const acc_data_t *diff_gamma, *diff_beta;
        size_t blk_has_tail; // non-zero if the last channel block is partial
    };

    // General-purpose register assignment (fixed for the whole kernel).
    const Reg64 &reg_param_ = abi_param1;
    const Reg64 &reg_tmp_ = abi_not_param1;
    const Reg64 &reg_N_ = rsi;
    const Reg64 &reg_S_ = rax;
    const Reg64 &reg_C_ = rdx;
    const Reg64 &reg_off_c_ = rbx;
    const Reg64 &reg_blk_has_tail_ = rbp;

    const Reg64 &reg_off_dat_ = r8;
    const Reg64 &reg_off_dat_save_ = r9;
    const Reg64 &reg_ptr_c_ = r10;
    const Reg64 &reg_ptr_diff_gamma_ = r11;
    const Reg64 &reg_ptr_diff_beta_ = r12;
    const Reg64 &reg_ptr_ws_ = r13;
    const Reg64 &reg_ptr_diff_dst_ = r14;
    const Reg64 &reg_ptr_src_ = r15;

    // Low vector registers hold shared scratch and constants; registers
    // from min_idx_to_unroll_ upward are used for channel-block unrolling.
    const Vmm vtail_mask_ = Vmm(0);
    const Vmm v_ = Vmm(1);
    const Vmm vtmp_ = Vmm(2);
    const Vmm vstore_mask_ = vtmp_; // aliases vtmp_ (never live together)
    const Vmm vzero_ = Vmm(3);
    const Vmm veps_ = Vmm(4);
    const Vmm vone_ = Vmm(5);
    // Each unrolled channel block occupies three consecutive vmms: one
    // statistic (mean or sqrtvar), diff_beta, and diff_gamma. Hence the
    // number of registers available for unrolling must be divisible by
    // three.
    static constexpr int min_idx_to_unroll_ = 6;
    static constexpr int max_idx_to_unroll_ = isa == avx512_common ? 27 : 15;
    static constexpr int number_of_unrolled_variables_ = 3;
    static constexpr int number_of_vmms_to_unrolling_variables_
            = max_idx_to_unroll_ - min_idx_to_unroll_;
    static_assert(number_of_vmms_to_unrolling_variables_
                            % number_of_unrolled_variables_
                            == 0
                    && number_of_vmms_to_unrolling_variables_ != 0,
            "Number of register to unrolling must to be divisible by 3.");

    const Opmask &kstore_mask_ = k1;
    const Opmask &ktail_mask_ = k2;

    const batch_normalization_pd_t *bdesc_;
    const jit_memory_tag_kind_t tag_kind_;
    const int vlen;
    const int simd_w;
    jit_bnorm_process_tail_t<isa> jit_tail_;
    jit_bnorm_process_relu_t<isa> jit_relu_;
    jit_bnorm_bf16_emulation_t<isa> jit_bf16_emu_;
    int stride_N_, stride_S_, stride_C_; // layout-dependent element strides
    size_t data_type_size_, acc_type_size_;

    // Loads data pointers from call_params_t and broadcasts the scalar
    // constants (eps, 1.0) into vector registers.
    void load_common_params() {
#define PARAM_PTR(x) ptr[PARAM_ADDR(x)]
        mov(reg_ptr_src_, PARAM_PTR(src));
        mov(reg_ptr_diff_dst_, PARAM_PTR(diff_dst));
        mov(reg_ptr_ws_, PARAM_PTR(ws));
        mov(reg_ptr_diff_gamma_, PARAM_PTR(diff_gamma));
        mov(reg_ptr_diff_beta_, PARAM_PTR(diff_beta));
#undef PARAM_PTR

        Xmm x = Xmm(v_.getIdx());

        mov(reg_tmp_, float2int(bdesc_->desc()->batch_norm_epsilon));
        uni_vmovq(x, reg_tmp_);
        uni_vbroadcastss(veps_, x);

        mov(reg_tmp_, float2int(1.f));
        uni_vmovq(x, reg_tmp_);
        uni_vbroadcastss(vone_, x);

        mov(reg_blk_has_tail_, dword[PARAM_ADDR(blk_has_tail)]);
    }

    // Clears the diff_gamma/diff_beta output buffers, which the compute
    // loops accumulate into. The sse41 blocked case also clears the second
    // half of each 16c block (same split as in compute()).
    void zeroise() {
        Label label_zeroise;
        xor_(reg_off_c_, reg_off_c_);
        uni_vpxor(vzero_, vzero_, vzero_);
        mov(reg_C_, dword[PARAM_ADDR(C)]);
        L(label_zeroise);
        {
            jit_tail_.uni_vmovups_maybe_tail(
                    vmmword[reg_ptr_diff_gamma_ + reg_off_c_], vzero_);
            jit_tail_.uni_vmovups_maybe_tail(
                    vmmword[reg_ptr_diff_beta_ + reg_off_c_], vzero_);
            if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) {
                jit_tail_.uni_vmovups_maybe_tail(
                        vmmword[reg_ptr_diff_gamma_ + reg_off_c_ + vlen / 2],
                        vzero_);
                jit_tail_.uni_vmovups_maybe_tail(
                        vmmword[reg_ptr_diff_beta_ + reg_off_c_ + vlen / 2],
                        vzero_);
            }
            add(reg_off_c_, simd_w * acc_type_size_);
            dec(reg_C_);
            jnz(label_zeroise);
        }
    }

    // Loads the mean for `c_blks_to_unroll` consecutive channel blocks into
    // the first vmm of each unroll triple.
    void load_mean(const int c_blks_to_unroll = 1) {
        mov(reg_ptr_c_, ptr[PARAM_ADDR(mean)]);

        const int start_idx = min_idx_to_unroll_;
        const int end_idx = number_of_unrolled_variables_ * c_blks_to_unroll
                + min_idx_to_unroll_;
        const int step = simd_w * acc_type_size_;

        for (int idx = start_idx, off = 0; idx < end_idx;
                idx += number_of_unrolled_variables_, off += step) {
            const Vmm vmean = Vmm(idx);

            jit_tail_.uni_vmovups_maybe_tail(
                    vmean, vmmword[reg_ptr_c_ + reg_off_c_ + off]);
        }
    }

    // Zeroes the diff_beta/diff_gamma accumulators (second and third vmm of
    // each unroll triple) before a spatial accumulation pass.
    void zeroise_diff_beta_and_diff_gamma(const int c_blks_to_unroll = 1) {
        const int start_idx = min_idx_to_unroll_;
        const int end_idx = number_of_unrolled_variables_ * c_blks_to_unroll
                + min_idx_to_unroll_;

        for (int idx = start_idx; idx < end_idx;
                idx += number_of_unrolled_variables_) {
            const Vmm vdiff_beta = Vmm(idx + 1);
            const Vmm vdiff_gamma = Vmm(idx + 2);

            uni_vpxor(vdiff_beta, vdiff_beta, vdiff_beta);
            uni_vpxor(vdiff_gamma, vdiff_gamma, vdiff_gamma);
        }
    }

    // Loads var and computes 1/sqrt(var + eps) in place, reusing the vmm
    // that held the mean (the mean is no longer needed at this point).
    void load_and_prepare_sqrtvar(const int c_blks_to_unroll = 1) {
        mov(reg_ptr_c_, ptr[PARAM_ADDR(var)]);

        const int start_idx = min_idx_to_unroll_;
        const int end_idx = number_of_unrolled_variables_ * c_blks_to_unroll
                + min_idx_to_unroll_;
        const int step = simd_w * acc_type_size_;

        for (int idx = start_idx, off = 0; idx < end_idx;
                idx += number_of_unrolled_variables_, off += step) {
            const Vmm vsqrtvar = Vmm(idx);

            jit_tail_.uni_vmovups_maybe_tail(
                    vsqrtvar, vmmword[reg_ptr_c_ + reg_off_c_ + off]);

            // 1.0 / sqrt(var + eps)
            uni_vaddps(vsqrtvar, vsqrtvar, veps_);
            uni_vsqrtps(vsqrtvar, vsqrtvar);

            // sse41 has no non-destructive divide, so go through vtmp_.
            if (isa == sse41) {
                movups(vtmp_, vone_);
                divps(vtmp_, vsqrtvar);
                movups(vsqrtvar, vtmp_);
            } else
                vdivps(vsqrtvar, vone_, vsqrtvar);
        }
    }

    // Accumulates one spatial step into the unrolled accumulators:
    //   diff_beta  += diff_dst (ReLU-masked)
    //   diff_gamma += (src - mean) * diff_dst
    void compute_diff_beta_and_diff_gamma(const int c_blks_to_unroll = 1) {
        const int start_idx = min_idx_to_unroll_;
        const int end_idx = number_of_unrolled_variables_ * c_blks_to_unroll
                + min_idx_to_unroll_;
        const int step = simd_w * data_type_size_;

        for (int idx = start_idx, off = 0; idx < end_idx;
                idx += number_of_unrolled_variables_, off += step) {
            const Vmm vmean = Vmm(idx);
            const Vmm vdiff_beta = Vmm(idx + 1);
            const Vmm vdiff_gamma = Vmm(idx + 2);

            jit_bf16_emu_.uni_vmovups_data(
                    v_, vmmword[reg_ptr_diff_dst_ + reg_off_dat_ + off]);

            // The workspace offset is in bits, hence the extra division.
            jit_relu_.bwd_process_relu(
                    v_, off / (bits_per_byte * data_type_size_));

            // diff_beta
            uni_vaddps(vdiff_beta, vdiff_beta, v_);

            jit_bf16_emu_.uni_vmovups_data(
                    vtmp_, vmmword[reg_ptr_src_ + reg_off_dat_ + off]);

            // diff_gamma, note that diff_gamma will be multiplied
            // by sqrtvar before store
            uni_vsubps(vtmp_, vtmp_, vmean);
            uni_vfmadd231ps(vdiff_gamma, vtmp_, v_);
        }
    }

    // Adds the unrolled accumulators into the output buffers. diff_gamma is
    // scaled by 1/sqrt(var + eps) here (see comment in the compute loop).
    void store_diff_beta_and_diff_gamma(const int c_blks_to_unroll = 1) {
        const int start_idx = min_idx_to_unroll_;
        const int end_idx = number_of_unrolled_variables_ * c_blks_to_unroll
                + min_idx_to_unroll_;
        const int step = simd_w * acc_type_size_;

        for (int idx = start_idx, off = 0; idx < end_idx;
                idx += number_of_unrolled_variables_, off += step) {
            const Vmm vdiff_beta = Vmm(idx + 1);

            jit_tail_.uni_vmovups_maybe_tail(
                    vtmp_, vmmword[reg_ptr_diff_beta_ + reg_off_c_ + off]);
            uni_vaddps(vdiff_beta, vdiff_beta, vtmp_);
            jit_tail_.uni_vmovups_maybe_tail(
                    vmmword[reg_ptr_diff_beta_ + reg_off_c_ + off], vdiff_beta);
        }

        for (int idx = start_idx, off = 0; idx < end_idx;
                idx += number_of_unrolled_variables_, off += step) {
            const Vmm vsqrtvar = Vmm(idx);
            const Vmm vdiff_gamma = Vmm(idx + 2);

            // multiply diff_gamma by 1.0/sqrt(var + eps)
            uni_vmulps(vdiff_gamma, vdiff_gamma, vsqrtvar);

            jit_tail_.uni_vmovups_maybe_tail(
                    vtmp_, vmmword[reg_ptr_diff_gamma_ + reg_off_c_ + off]);
            uni_vaddps(vdiff_gamma, vdiff_gamma, vtmp_);
            jit_tail_.uni_vmovups_maybe_tail(
                    vmmword[reg_ptr_diff_gamma_ + reg_off_c_ + off],
                    vdiff_gamma);
        }
    }

    // Loop nest for the blocked layout: one channel block at a time, full
    // spatial accumulation, then scale-and-store of the reductions.
    void compute_blocked() {
        Label label_C, label_S;
        mov(reg_C_, dword[PARAM_ADDR(C)]);
        L(label_C);
        {
            mov(reg_off_dat_, reg_off_dat_save_);

            load_mean();
            zeroise_diff_beta_and_diff_gamma();

            mov(reg_S_, dword[PARAM_ADDR(S)]);
            L(label_S);
            {
                compute_diff_beta_and_diff_gamma();

                add(reg_off_dat_, stride_S_ * data_type_size_);

                dec(reg_S_);
                jnz(label_S);
            }

            load_and_prepare_sqrtvar();
            store_diff_beta_and_diff_gamma();

            add(reg_off_dat_save_, stride_C_ * data_type_size_);
            add(reg_off_c_, simd_w * acc_type_size_);

            dec(reg_C_);
            jnz(label_C);
        }
    }

    // Loop nest for the nspc layout: processes as many channel blocks at
    // once as the vmm budget allows, falling back to successively smaller
    // unroll factors until fewer than one block remains.
    void compute_nspc() {
        mov(reg_C_, dword[PARAM_ADDR(C)]);

        constexpr int max_of_unrolled_c_blks
                = number_of_vmms_to_unrolling_variables_
                / number_of_unrolled_variables_;
        std::vector<Label> c_unroll_label(max_of_unrolled_c_blks + 1);

        for (int c_blks_to_unroll = max_of_unrolled_c_blks;
                c_blks_to_unroll > 0; --c_blks_to_unroll) {
            L(c_unroll_label[c_blks_to_unroll]);
            {
                // Not enough blocks left for this unroll factor? Try the
                // next smaller one.
                cmp(reg_C_, c_blks_to_unroll);
                jl(c_unroll_label[c_blks_to_unroll - 1], T_NEAR);

                mov(reg_off_dat_, reg_off_dat_save_);

                load_mean(c_blks_to_unroll);
                zeroise_diff_beta_and_diff_gamma(c_blks_to_unroll);

                Label label_S;
                mov(reg_S_, dword[PARAM_ADDR(S)]);
                L(label_S);
                {
                    compute_diff_beta_and_diff_gamma(c_blks_to_unroll);

                    add(reg_off_dat_, stride_S_ * data_type_size_);

                    dec(reg_S_);
                    jnz(label_S);
                }

                load_and_prepare_sqrtvar(c_blks_to_unroll);
                store_diff_beta_and_diff_gamma(c_blks_to_unroll);

                add(reg_off_c_, c_blks_to_unroll * simd_w * acc_type_size_);
                add(reg_off_dat_save_,
                        c_blks_to_unroll * stride_C_ * data_type_size_);

                sub(reg_C_, c_blks_to_unroll);
                jmp(c_unroll_label[c_blks_to_unroll], T_NEAR);
            }
        }
        L(c_unroll_label[0]);
    }

    // Outermost (batch) loop; see the forward kernel for the sse41 blocked
    // special case (16c block processed as two vlen/2 halves).
    void compute() {
        Label label_N;
        mov(reg_N_, dword[PARAM_ADDR(N)]);
        L(label_N);
        {
            xor_(reg_off_dat_save_, reg_off_dat_save_);
            xor_(reg_off_c_, reg_off_c_);

            tag_kind_ == jit_memory_tag_kind_t::nspc ? compute_nspc()
                                                     : compute_blocked();

            if (isa == sse41 && tag_kind_ == jit_memory_tag_kind_t::blocked) {
                xor_(reg_off_dat_save_, reg_off_dat_save_);
                xor_(reg_off_c_, reg_off_c_);
                add(reg_off_dat_save_, vlen / 2);
                add(reg_off_c_, vlen / 2);

                compute_blocked();
            }

            add(reg_ptr_src_, stride_N_ * data_type_size_);
            add(reg_ptr_diff_dst_, stride_N_ * data_type_size_);
            add(reg_ptr_ws_, stride_N_ / bits_per_byte); // 1 bit per element

            dec(reg_N_);
            jnz(label_N);
        }
    }

    // Constructor: wires helper emitters (tail, fused ReLU, bf16 emulation)
    // to this generator's registers and precomputes layout strides/sizes.
    jit_bnorm_bwd_diff_ss_t(const batch_normalization_pd_t *bdesc,
            const jit_memory_tag_kind_t tag_kind)
        : bdesc_(bdesc)
        , tag_kind_(tag_kind)
        , vlen(get_vlen<isa>(tag_kind))
        , simd_w(get_simd_w<isa>(tag_kind))
        , jit_tail_(bdesc, this, reg_tmp_, reg_blk_has_tail_, reg_C_,
                  vtail_mask_, ktail_mask_)
        , jit_relu_(bdesc, this, reg_off_dat_, reg_tmp_, reg_ptr_ws_, vzero_,
                  vstore_mask_, kstore_mask_)
        , jit_bf16_emu_(bdesc, this, zmm28, zmm29, zmm30, zmm31, reg_tmp_) {
        static_assert(isa == sse41 || isa == avx2 || isa == avx512_common,
                "unsupported isa");

        std::tie(stride_N_, stride_S_, stride_C_)
                = get_data_strides<isa>(bdesc_, tag_kind);

        data_type_size_
                = types::data_type_size(bdesc->desc()->data_desc.data_type);
        acc_type_size_ = sizeof(acc_data_t);
    }

    // Top-level code generation: zero the outputs, then accumulate.
    void generate() override {
        preamble();
        load_common_params();
        jit_relu_.bwd_prepare_relu();
        jit_tail_.prepare_tail();
        zeroise();
        compute();
        postamble();
    }
};
1670 } // namespace
1671 namespace bnorm_tbb_impl {
1672
1673 template <cpu_isa_t isa>
1674 struct driver_t : public c_compatible {
1675 private:
    // Triple over the (N, C, S) iteration space. Used both as a coordinate
    // (work_distribution start/stop) and as a per-dimension thread count;
    // `glob` carries the total thread count in the latter role.
    struct bnorm_dims_t {
        dim_t N, C, S;
        dim_t glob;
    };
1680
1681 DNNL_DISALLOW_COPY_AND_ASSIGN(driver_t);
1682
1683 public:
    // Constructor: caches problem dimensions and derives the cache-blocking
    // policy — how many channel blocks (C_blk_step_) to process per chunk so
    // that the per-chunk working set fits in (an estimate of) the L3 cache.
    driver_t(const batch_normalization_pd_t *bdesc,
            const jit_memory_tag_kind_t tag_kind)
        : bdesc_(bdesc)
        , tag_kind_(tag_kind)
        , simd_w(get_simd_w<isa>(tag_kind)) {
        nthr_ = dnnl_get_max_threads();
        N_ = bdesc_->MB();
        S_ = bdesc_->D() * bdesc_->H() * bdesc_->W();
        C_ = bdesc_->C();
        C_blks_ = get_c_padded(bdesc_) / simd_w;

        // Estimated usable L3: per-core size aggregated over threads, halved.
        const size_t l3_size = platform::get_per_core_cache_size(3) * nthr_ / 2;
        // Backward touches two data tensors (src and diff_dst/diff_src).
        int num_tensors = bdesc_->is_fwd() ? 1 : 2;
        dt_size_ = types::data_type_size(bdesc_->desc()->data_desc.data_type);
        // Bytes touched per channel block across the whole batch.
        const size_t working_set_size
                = dt_size_ * N_ * S_ * simd_w * num_tensors;

        // Blocking only helps the blocked layout, and only when the full
        // problem would overflow the cache estimate.
        do_blocking_ = tag_kind_ == jit_memory_tag_kind_t::nspc
                ? false
                : working_set_size * C_blks_ >= l3_size / 2 && l3_size > 0;

        if (tag_kind_ == jit_memory_tag_kind_t::nspc) {
            // nspc processes all channel blocks in one chunk.
            C_blk_step_ = C_blks_;
        } else {
            // Clamp the chunk to [1, C_blks_].
            C_blk_step_ = l3_size / working_set_size;
            C_blk_step_ = nstl::max<dim_t>(C_blk_step_, 1);
            C_blk_step_ = nstl::min<dim_t>(C_blk_step_, C_blks_);
        }
    }
1713
    // Instantiates and JIT-compiles the kernels required by the propagation
    // direction: fwd normalization (plus mean/var kernels when statistics
    // are computed rather than taken from src), or bwd diff_src plus the
    // diff_scale/diff_shift reduction kernel. Returns status::success or
    // the first failing CHECK status.
    status_t create_kernel() {
        if (bdesc_->is_fwd()) {
            CHECK(safe_ptr_assign(
                    ker_fwd_, new jit_bnorm_fwd_t<isa>(bdesc_, tag_kind_)));
            CHECK(ker_fwd_->create_kernel());
            if (!bdesc_->stats_is_src()) {
                CHECK(safe_ptr_assign(ker_fwd_mean_,
                        new jit_bnorm_fwd_mean_t<isa>(bdesc_, tag_kind_)));
                CHECK(safe_ptr_assign(ker_fwd_var_,
                        new jit_bnorm_fwd_var_t<isa>(bdesc_, tag_kind_)));
                CHECK(ker_fwd_mean_->create_kernel());
                CHECK(ker_fwd_var_->create_kernel());
            }
        } else {
            CHECK(safe_ptr_assign(
                    ker_bwd_, new jit_bnorm_bwd_t<isa>(bdesc_, tag_kind_)));
            CHECK(safe_ptr_assign(ker_bwd_diff_ss_,
                    new jit_bnorm_bwd_diff_ss_t<isa>(bdesc_, tag_kind_)));
            CHECK(ker_bwd_->create_kernel());
            CHECK(ker_bwd_diff_ss_->create_kernel());
        }
        return status::success;
    }
1737
init_scratchpaddnnl::impl::cpu::x64::bnorm_tbb_impl::driver_t1738 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
1739 const batch_normalization_pd_t *bdesc) {
1740
1741 int nthrs = dnnl_get_max_threads();
1742 int C_PADDED = get_c_padded(bdesc);
1743
1744 int sbuf_sz = use_tmp_stats(bdesc) * 2 * C_PADDED;
1745 int pbuf_sz = (use_tmp_diff_scale(bdesc) + use_tmp_diff_shift(bdesc))
1746 * C_PADDED;
1747 int rbuf_sz = (bdesc->is_fwd() ? 1 : 2) * C_PADDED * nthrs;
1748
1749 scratchpad.book<acc_data_t>(key_bnorm_tmp_stats, sbuf_sz);
1750 scratchpad.book<acc_data_t>(key_bnorm_tmp_diff_ss, pbuf_sz);
1751 scratchpad.book<acc_data_t>(key_bnorm_reduction, rbuf_sz);
1752 }
1753
    // Computes per-channel mean and variance for C_blks channel blocks.
    // When the (N, S) space is split across threads, each thread writes its
    // partial sums into its own slice of rbuf and a serial reduction folds
    // them into mean/var (dividing by N*S); otherwise the kernels normalize
    // and write the final statistics directly.
    void exec_fwd_step_stats(const dim_t C_blks, const bnorm_dims_t &nthr,
            const void *src, acc_data_t *mean, acc_data_t *var,
            acc_data_t *rbuf, bool blk_has_tail) {
        size_t stride_C, stride_N, stride_S;
        std::tie(stride_N, stride_S, stride_C)
                = get_data_strides<isa>(bdesc_, tag_kind_);

        const int nthr_NS = nthr.N * nthr.S;
        const bool need_reduction = nthr_NS > 1;
        const dim_t tail_size = blk_has_tail ? C_ % simd_w : simd_w;

        // Number of channel statistic values in this step (last block may
        // be a partial simd_w tail).
        const dim_t size_C_stat = (C_blks - 1) * simd_w + tail_size;

        // Serial reduction of the per-thread partial sums followed by the
        // division by the number of accumulated elements (N * S).
        auto reduce = [&](acc_data_t *stat, acc_data_t *r_stat) {
            if (!need_reduction) return;
            acc_data_t *loc_stat = r_stat;

            for (dim_t c = 0; c < size_C_stat; ++c)
                stat[c] = loc_stat[c];

            for (int thr_ns = 1; thr_ns < nthr_NS; ++thr_ns) {
                loc_stat += size_C_stat;
                for (dim_t c = 0; c < size_C_stat; ++c)
                    stat[c] += loc_stat[c];
            }

            for (dim_t c = 0; c < size_C_stat; ++c)
                stat[c] /= N_ * S_;
        };

        // find local mean
        acc_data_t *r_mean = need_reduction ? rbuf : mean;
        parallel(nthr.glob, [&](int ithr_glob, int nthr_glob) {
            assert(nthr_glob == nthr.glob);
            const auto ithr = map_thread(ithr_glob, nthr);
            bnorm_dims_t start, stop;
            work_distribution(C_blks, ithr, nthr, start, stop);

            auto c = typename jit_bnorm_fwd_mean_t<isa>::call_params_t();
            c.N = stop.N - start.N;
            c.C = stop.C - start.C;
            c.S = stop.S - start.S;

            const size_t d_off = start.N * stride_N + start.C * stride_C
                    + start.S * stride_S;
            c.src = (void *)((char *)src + d_off * dt_size_);
            // Each (N, S) thread writes into its own size_C_stat slice.
            const int ithr_NS = ithr.N * nthr.S + ithr.S;
            c.mean = &r_mean[ithr_NS * size_C_stat + start.C * simd_w];
            c.blk_has_tail = blk_has_tail && stop.C == C_blks;
            c.do_normalise = !need_reduction;
            (*ker_fwd_mean_)(&c);
        });

        // mean reduction
        reduce(mean, r_mean);

        // find local var
        acc_data_t *r_var = need_reduction ? rbuf : var;
        parallel(nthr.glob, [&](int ithr_glob, int nthr_glob) {
            assert(nthr_glob == nthr.glob);
            const auto ithr = map_thread(ithr_glob, nthr);
            bnorm_dims_t start, stop;
            work_distribution(C_blks, ithr, nthr, start, stop);

            auto c = typename jit_bnorm_fwd_var_t<isa>::call_params_t();
            c.N = stop.N - start.N;
            c.C = stop.C - start.C;
            c.S = stop.S - start.S;

            const size_t d_off = start.N * stride_N + start.C * stride_C
                    + start.S * stride_S;
            c.src = (void *)((char *)src + d_off * dt_size_);
            const int ithr_NS = ithr.N * nthr.S + ithr.S;
            // Variance needs the already-reduced mean from the first pass.
            c.mean = &mean[start.C * simd_w];
            c.var = &r_var[ithr_NS * size_C_stat + start.C * simd_w];
            c.blk_has_tail = blk_has_tail && stop.C == C_blks;
            c.do_normalise = !need_reduction;
            (*ker_fwd_var_)(&c);
        });

        // var reduction
        reduce(var, r_var);
    }
1837
    // Runs the forward normalization kernel over C_blks channel blocks,
    // distributing the (N, C, S) space across nthr.glob threads. All
    // pointers are expected to be pre-offset to the start of this step;
    // scale/shift/ws may be null when the corresponding feature is off.
    void exec_fwd_step_normalization(const dim_t C_blks,
            const bnorm_dims_t &nthr, const void *src, void *dst,
            const acc_data_t *scale, const acc_data_t *shift,
            const acc_data_t *mean, const acc_data_t *var, uint8_t *ws,
            bool blk_has_tail) {
        size_t stride_C, stride_N, stride_S;
        std::tie(stride_N, stride_S, stride_C)
                = get_data_strides<isa>(bdesc_, tag_kind_);

        parallel(nthr.glob, [&](int ithr_glob, int nthr_glob) {
            assert(nthr_glob == nthr.glob);
            const auto ithr = map_thread(ithr_glob, nthr);
            bnorm_dims_t start, stop;
            work_distribution(C_blks, ithr, nthr, start, stop);

            auto c = typename jit_bnorm_fwd_t<isa>::call_params_t();
            c.N = stop.N - start.N;
            c.C = stop.C - start.C;
            c.S = stop.S - start.S;

            const size_t d_off = start.N * stride_N + start.C * stride_C
                    + start.S * stride_S;
            c.src = (void *)((char *)src + d_off * dt_size_);
            c.dst = (void *)((char *)dst + d_off * dt_size_);
            c.ws = ws ? &ws[d_off / bits_per_byte] : nullptr; // 1 bit/elem
            c.mean = &mean[start.C * simd_w];
            c.var = &var[start.C * simd_w];
            c.scale = scale ? &scale[start.C * simd_w] : nullptr;
            c.shift = shift ? &shift[start.C * simd_w] : nullptr;
            // Only the thread owning the last channel block handles the tail.
            c.blk_has_tail = blk_has_tail && stop.C == C_blks;
            (*ker_fwd_)(&c);
        });
    }
1871
exec_fwddnnl::impl::cpu::x64::bnorm_tbb_impl::driver_t1872 void exec_fwd(const void *src, void *dst, const acc_data_t *scale,
1873 const acc_data_t *shift, acc_data_t *mean, acc_data_t *var,
1874 uint8_t *ws, const memory_tracking::grantor_t &scratchpad) {
1875 auto rbuf = scratchpad.get<acc_data_t>(key_bnorm_reduction);
1876 if (use_tmp_stats(bdesc_)) {
1877 auto sbuf = scratchpad.get<acc_data_t>(key_bnorm_tmp_stats);
1878 mean = sbuf;
1879 var = sbuf + C_blks_ * simd_w;
1880 }
1881
1882 size_t stride_C;
1883 std::tie(std::ignore, std::ignore, stride_C)
1884 = get_data_strides<isa>(bdesc_, tag_kind_);
1885
1886 dim_t C_blk_step = C_blk_step_;
1887 auto nthr = bnorm_dims_t();
1888
1889 thread_distribution(C_blk_step, nthr);
1890
1891 for (dim_t C_blk_st = 0; C_blk_st < C_blks_; C_blk_st += C_blk_step) {
1892 if (C_blk_st + C_blk_step > C_blks_) {
1893 C_blk_step = C_blks_ - C_blk_st;
1894 thread_distribution(C_blk_step, nthr);
1895 }
1896
1897 if (!bdesc_->stats_is_src()) {
1898 exec_fwd_step_stats(C_blk_step, nthr,
1899 (void *)((char *)src
1900 + (C_blk_st * stride_C) * dt_size_),
1901 mean + C_blk_st * simd_w, var + C_blk_st * simd_w, rbuf,
1902 (C_blk_st + C_blk_step) * simd_w > C_);
1903 }
1904
1905 exec_fwd_step_normalization(C_blk_step, nthr,
1906 (void *)((char *)src + (C_blk_st * stride_C) * dt_size_),
1907 (void *)((char *)dst + (C_blk_st * stride_C) * dt_size_),
1908 scale + C_blk_st * simd_w, shift + C_blk_st * simd_w,
1909 mean + C_blk_st * simd_w, var + C_blk_st * simd_w,
1910 ws + C_blk_st * stride_C / bits_per_byte,
1911 (C_blk_st + C_blk_step) * simd_w > C_);
1912 }
1913 }
1914
    // Computes per-channel diff_scale (diff_gamma) and diff_shift (diff_beta)
    // over C_blks channel blocks. When work is split over N/S, each (N,S)
    // thread writes partial sums into its own slice of 'rbuf', which are then
    // summed into the final diff_scale/diff_shift by 'reduce'.
    void exec_bwd_step_diff_ss(const dim_t C_blks, const bnorm_dims_t &nthr,
            const void *src, const void *diff_dst, const acc_data_t *mean,
            const acc_data_t *var, const uint8_t *ws, acc_data_t *diff_scale,
            acc_data_t *diff_shift, acc_data_t *rbuf, bool blk_has_tail) {
        size_t stride_C, stride_N, stride_S;
        std::tie(stride_N, stride_S, stride_C)
                = get_data_strides<isa>(bdesc_, tag_kind_);

        // Number of valid statistics in this chunk: full simd_w per block,
        // except the last block which may hold only C_ % simd_w channels.
        const dim_t tail_size = blk_has_tail ? C_ % simd_w : simd_w;
        const dim_t size_C_stat = (C_blks - 1) * simd_w + tail_size;

        // Reduction across threads is needed only if N/S are parallelized.
        const int nthr_NS = nthr.N * nthr.S;
        const bool need_reduction = nthr_NS > 1;

        acc_data_t *diff_gamma = diff_scale;
        acc_data_t *diff_beta = diff_shift;

        // With reduction: rbuf holds nthr_NS gamma slices followed by
        // nthr_NS beta slices. Without: write directly to the outputs.
        acc_data_t *const r_diff_gamma = need_reduction ? rbuf : diff_gamma;
        acc_data_t *const r_diff_beta
                = need_reduction ? rbuf + nthr_NS * size_C_stat : diff_beta;

        // Sums the per-thread partial slices into diff_gamma/diff_beta.
        auto reduce = [&]() {
            if (!need_reduction) return;

            // diff_gamma
            const acc_data_t *loc_diff_gamma = r_diff_gamma;
            for (dim_t c = 0; c < size_C_stat; ++c)
                diff_gamma[c] = loc_diff_gamma[c];
            for (int thr_ns = 1; thr_ns < nthr_NS; ++thr_ns) {
                loc_diff_gamma += size_C_stat;
                for (dim_t c = 0; c < size_C_stat; ++c)
                    diff_gamma[c] += loc_diff_gamma[c];
            }

            // diff_beta
            const acc_data_t *loc_diff_beta = r_diff_beta;
            for (dim_t c = 0; c < size_C_stat; ++c)
                diff_beta[c] = loc_diff_beta[c];
            for (int thr_ns = 1; thr_ns < nthr_NS; ++thr_ns) {
                loc_diff_beta += size_C_stat;
                for (dim_t c = 0; c < size_C_stat; ++c)
                    diff_beta[c] += loc_diff_beta[c];
            }
        };

        parallel(nthr.glob, [&](int ithr_glob, int nthr_glob) {
            assert(nthr_glob == nthr.glob);
            const auto ithr = map_thread(ithr_glob, nthr);
            bnorm_dims_t start, stop;
            work_distribution(C_blks, ithr, nthr, start, stop);

            // Flat (N,S)-thread index selects this thread's partial slice.
            const int ithr_NS = ithr.N * nthr.S + ithr.S;
            acc_data_t *loc_diff_gamma = &r_diff_gamma[ithr_NS * size_C_stat];
            acc_data_t *loc_diff_beta = &r_diff_beta[ithr_NS * size_C_stat];

            auto c = typename jit_bnorm_bwd_diff_ss_t<isa>::call_params_t();
            c.N = stop.N - start.N;
            c.C = stop.C - start.C;
            c.S = stop.S - start.S;

            // Element offset of the first value this thread processes.
            const size_t d_off = start.N * stride_N + start.C * stride_C
                    + start.S * stride_S;
            c.src = (void *)((char *)src + d_off * dt_size_);
            c.diff_dst = (void *)((char *)diff_dst + d_off * dt_size_);
            // Workspace bitmask is packed one bit per element.
            c.ws = ws ? &ws[d_off / bits_per_byte] : nullptr;
            c.mean = &mean[start.C * simd_w];
            c.var = &var[start.C * simd_w];
            c.diff_gamma = &loc_diff_gamma[start.C * simd_w];
            c.diff_beta = &loc_diff_beta[start.C * simd_w];
            // Only the thread owning the last C block handles the tail.
            c.blk_has_tail = blk_has_tail && stop.C == C_blks;

            (*ker_bwd_diff_ss_)(&c);
        });

        reduce();
    }
1991
exec_bwd_step_normalizationdnnl::impl::cpu::x64::bnorm_tbb_impl::driver_t1992 void exec_bwd_step_normalization(const dim_t C_blks,
1993 const bnorm_dims_t &nthr, const void *src, void *diff_src,
1994 const void *diff_dst, const acc_data_t *mean, const acc_data_t *var,
1995 const uint8_t *ws, const acc_data_t *scale,
1996 const acc_data_t *diff_scale, const acc_data_t *diff_shift,
1997 bool blk_has_tail) {
1998 size_t stride_C, stride_N, stride_S;
1999 std::tie(stride_N, stride_S, stride_C)
2000 = get_data_strides<isa>(bdesc_, tag_kind_);
2001
2002 parallel(nthr.glob, [&](int ithr_glob, int nthr_glob) {
2003 assert(nthr_glob == nthr.glob);
2004 const auto ithr = map_thread(ithr_glob, nthr);
2005 bnorm_dims_t start, stop;
2006 work_distribution(C_blks, ithr, nthr, start, stop);
2007
2008 auto c = typename jit_bnorm_bwd_t<isa>::call_params_t();
2009 c.N = stop.N - start.N;
2010 c.C = stop.C - start.C;
2011 c.S = stop.S - start.S;
2012
2013 const size_t d_off = start.N * stride_N + start.C * stride_C
2014 + start.S * stride_S;
2015 c.src = (void *)((char *)src + d_off * dt_size_);
2016 c.diff_src = (void *)((char *)diff_src + d_off * dt_size_);
2017 c.diff_dst = (void *)((char *)diff_dst + d_off * dt_size_);
2018 c.ws = ws ? &ws[d_off / bits_per_byte] : nullptr;
2019 c.mean = &mean[start.C * simd_w];
2020 c.var = &var[start.C * simd_w];
2021 c.scale = scale ? &scale[start.C * simd_w] : nullptr;
2022 c.diff_scale = &diff_scale[start.C * simd_w];
2023 c.diff_shift = &diff_shift[start.C * simd_w];
2024 c.blk_has_tail = blk_has_tail && stop.C == C_blks;
2025
2026 (*ker_bwd_)(&c);
2027 });
2028 }
2029
exec_bwddnnl::impl::cpu::x64::bnorm_tbb_impl::driver_t2030 void exec_bwd(const void *src, void *diff_src, const void *diff_dst,
2031 const acc_data_t *scale, acc_data_t *diff_scale,
2032 acc_data_t *diff_shift, const acc_data_t *mean,
2033 const acc_data_t *var, const uint8_t *ws,
2034 const memory_tracking::grantor_t &scratchpad) {
2035 auto rbuf = scratchpad.get<acc_data_t>(key_bnorm_reduction);
2036 if (use_tmp_diff_scale(bdesc_)) {
2037 auto pbuf = scratchpad.get<acc_data_t>(key_bnorm_tmp_diff_ss);
2038 diff_scale = pbuf;
2039 }
2040 if (use_tmp_diff_shift(bdesc_)) {
2041 auto pbuf = scratchpad.get<acc_data_t>(key_bnorm_tmp_diff_ss);
2042 size_t shift_off = use_tmp_diff_scale(bdesc_) ? bdesc_->C() : 0;
2043 diff_shift = &pbuf[shift_off];
2044 }
2045
2046 size_t stride_C;
2047 std::tie(std::ignore, std::ignore, stride_C)
2048 = get_data_strides<isa>(bdesc_, tag_kind_);
2049
2050 dim_t C_blk_step = C_blk_step_;
2051 auto nthr = bnorm_dims_t();
2052
2053 thread_distribution(C_blk_step, nthr);
2054
2055 for (dim_t C_blk_st = 0; C_blk_st < C_blks_; C_blk_st += C_blk_step) {
2056 if (C_blk_st + C_blk_step > C_blks_) {
2057 C_blk_step = C_blks_ - C_blk_st;
2058 thread_distribution(C_blk_step, nthr);
2059 }
2060
2061 exec_bwd_step_diff_ss(C_blk_step, nthr,
2062 (void *)((char *)src + (C_blk_st * stride_C) * dt_size_),
2063 (void *)((char *)diff_dst
2064 + (C_blk_st * stride_C) * dt_size_),
2065 mean + C_blk_st * simd_w, var + C_blk_st * simd_w,
2066 ws + C_blk_st * stride_C / bits_per_byte,
2067 diff_scale + C_blk_st * simd_w,
2068 diff_shift + C_blk_st * simd_w, rbuf,
2069 (C_blk_st + C_blk_step) * simd_w > C_);
2070
2071 exec_bwd_step_normalization(C_blk_step, nthr,
2072 (void *)((char *)src + (C_blk_st * stride_C) * dt_size_),
2073 (void *)((char *)diff_src
2074 + (C_blk_st * stride_C) * dt_size_),
2075 (void *)((char *)diff_dst
2076 + (C_blk_st * stride_C) * dt_size_),
2077 mean + C_blk_st * simd_w, var + C_blk_st * simd_w,
2078 ws + C_blk_st * stride_C / bits_per_byte,
2079 scale + C_blk_st * simd_w, diff_scale + C_blk_st * simd_w,
2080 diff_shift + C_blk_st * simd_w,
2081 (C_blk_st + C_blk_step) * simd_w > C_);
2082 }
2083 }
2084
2085 private:
use_tmp_statsdnnl::impl::cpu::x64::bnorm_tbb_impl::driver_t2086 static bool use_tmp_stats(const batch_normalization_pd_t *bdesc) {
2087 return true && !bdesc->stats_is_src()
2088 && bdesc->desc()->prop_kind == prop_kind::forward_inference;
2089 }
2090
use_tmp_diff_scalednnl::impl::cpu::x64::bnorm_tbb_impl::driver_t2091 static bool use_tmp_diff_scale(const batch_normalization_pd_t *bdesc) {
2092 return false
2093 || (bdesc->is_bwd() && !bdesc->use_scaleshift()
2094 && !bdesc->use_scale())
2095 || bdesc->desc()->prop_kind == prop_kind::backward_data;
2096 }
2097
use_tmp_diff_shiftdnnl::impl::cpu::x64::bnorm_tbb_impl::driver_t2098 static bool use_tmp_diff_shift(const batch_normalization_pd_t *bdesc) {
2099 return false
2100 || (bdesc->is_bwd() && !bdesc->use_scaleshift()
2101 && !bdesc->use_shift())
2102 || bdesc->desc()->prop_kind == prop_kind::backward_data;
2103 }
2104
    // Splits nthr_ threads across the N, C and S dimensions for a chunk of
    // C_blks channel blocks, writing the per-dimension thread counts and
    // the resulting total (nthr.glob) into 'nthr'.
    void thread_distribution(dim_t C_blks, bnorm_dims_t &nthr) {
        if (do_blocking_) {
            // Blocking schedule: prioritize parallelism over N, then C.
            nthr.N = nstl::min<dim_t>(N_, nthr_);
            nthr.C = nstl::min<dim_t>(C_blks, nthr_ / nthr.N);
        } else {
            if (tag_kind_ == jit_memory_tag_kind_t::nspc) {
                // nspc layout heuristics for choosing the C split.
                if ((nthr_ <= C_blks && nthr_ == 1) || C_blks <= 8)
                    nthr.C = 1;
                else if (nthr_ >= 8 && C_blks <= 32)
                    nthr.C = 8;
                else {
                    // gcd guarantees an even split of C_blks over nthr.C.
                    nthr.C = math::gcd((dim_t)nthr_, C_blks);
                    // Unroll by channels in JIT kernel
                    if ((nthr.C == C_blks) || (nthr.C == nthr_)) nthr.C = 1;
                }
            } else {
                nthr.C = math::gcd((dim_t)nthr_, C_blks);
            }
            nthr.N = utils::saturate((dim_t)1, N_, nthr_ / nthr.C);
        }
        // Whatever threads remain go to the spatial dimension.
        nthr.S = utils::saturate((dim_t)1, S_, nthr_ / (nthr.C * nthr.N));
        nthr.glob = nthr.N * nthr.C * nthr.S;
    }
2128
map_thread_cdnnl::impl::cpu::x64::bnorm_tbb_impl::driver_t2129 int map_thread_c(int ithr_glob, const bnorm_dims_t &nthr) {
2130 return ithr_glob / nthr.N / nthr.S;
2131 }
2132
map_threaddnnl::impl::cpu::x64::bnorm_tbb_impl::driver_t2133 bnorm_dims_t map_thread(int ithr_glob, const bnorm_dims_t &nthr) {
2134 auto ithr = bnorm_dims_t();
2135 ithr.glob = ithr_glob;
2136 ithr.C = map_thread_c(ithr.glob, nthr);
2137 ithr.N = ithr.glob / nthr.S % nthr.N;
2138 ithr.S = ithr.glob % nthr.S;
2139 return ithr;
2140 }
2141
    // Evenly splits C_blks channel blocks across nthr_c threads; thread
    // ithr_c receives the half-open range [start_c, stop_c).
    void work_distribution_c(dim_t C_blks, int ithr_c, int nthr_c,
            dim_t &start_c, dim_t &stop_c) {
        balance211(C_blks, nthr_c, ithr_c, start_c, stop_c);
    }
2146
    // Computes thread 'ithr's half-open [start, stop) sub-range in each of
    // the C, N and S dimensions given the thread grid 'nthr'.
    void work_distribution(dim_t C_blks, const bnorm_dims_t &ithr,
            const bnorm_dims_t &nthr, bnorm_dims_t &start, bnorm_dims_t &stop) {
        work_distribution_c(C_blks, ithr.C, nthr.C, start.C, stop.C);
        balance211(N_, nthr.N, ithr.N, start.N, stop.N);
        balance211(S_, nthr.S, ithr.S, start.S, stop.S);
    }
2153
    // Problem descriptor and memory layout the kernels were generated for.
    const batch_normalization_pd_t *bdesc_;
    const jit_memory_tag_kind_t tag_kind_;
    // Channels per C block (SIMD vector width in elements).
    const int simd_w;

    // Selects the blocked thread-distribution strategy
    // (see thread_distribution()).
    bool do_blocking_;

    // Total number of threads available to distribute over N/C/S.
    int nthr_;

    dim_t N_, S_; // MB, D * H *W
    dim_t C_, C_blks_; // C / simd_w
    dim_t C_blk_step_; // for C_blks = 0 .. C_blks_, += C_blk_step_

    // JIT-generated kernels; fwd/bwd subsets are populated depending on the
    // primitive's propagation kind.
    std::unique_ptr<jit_bnorm_fwd_t<isa>> ker_fwd_;
    std::unique_ptr<jit_bnorm_fwd_mean_t<isa>> ker_fwd_mean_;
    std::unique_ptr<jit_bnorm_fwd_var_t<isa>> ker_fwd_var_;
    std::unique_ptr<jit_bnorm_bwd_t<isa>> ker_bwd_;
    std::unique_ptr<jit_bnorm_bwd_diff_ss_t<isa>> ker_bwd_diff_ss_;

    // Bytes per src/dst data element.
    size_t dt_size_;
2173 };
2174 } // namespace bnorm_tbb_impl
2175
2176 using namespace data_type;
2177 using namespace format_tag;
2178 using namespace utils;
2179
2180 /* fwd */
// Forward primitive descriptor initialization: validates ISA, data types,
// memory formats and attributes; picks the blocked vs nspc kernel flavor;
// registers the scratchpad the driver needs.
template <cpu_isa_t isa>
status_t jit_uni_tbb_batch_normalization_fwd_t<isa>::pd_t::init(
        engine_t *engine) {

    // Basic applicability: ISA available, fwd prop, 2D/3D spatial, f32/bf16
    // (bf16 needs avx512_core), supported attributes.
    const bool ok = mayiuse(isa) && is_fwd() && !has_zero_dim_memory()
            && one_of(ndims(), 4, 5) && one_of(src_md()->data_type, f32, bf16)
            && IMPLICATION(src_md()->data_type == bf16,
                    is_superset(isa, avx512_common) && mayiuse(avx512_core))
            && check_scale_shift_data_type()
            && (attr()->has_default_values() || this->with_relu_post_op());
    if (!ok) return status::unimplemented;

    // Blocked layout: 16c for avx512, 8c otherwise.
    const format_tag_t blocked_tag = is_superset(isa, avx512_common)
            ? utils::pick(ndims() - 4, nChw16c, nCdhw16c)
            : utils::pick(ndims() - 4, nChw8c, nCdhw8c);

    const format_tag_t blocked_format
            = memory_desc_matches_tag(*src_md(), blocked_tag)
            ? blocked_tag
            : format_tag::undef;
    const format_tag_t nspc_format
            = memory_desc_matches_one_of_tag(*src_md(), nhwc, ndhwc);

    if (memory_desc_matches_tag(*dst_md(), blocked_format))
        tag_kind_ = jit_memory_tag_kind_t::blocked;
    else if (memory_desc_matches_tag(*dst_md(), nspc_format)) {
        tag_kind_ = jit_memory_tag_kind_t::nspc;
        // nspc kernels do not support a channel tail.
        const int simd_w = get_simd_w<isa>(tag_kind_);
        if (C() % simd_w != 0) return status::unimplemented;
    } else
        return status::unimplemented;

    const bool isa_supports_avx2 = is_superset(isa, avx2);
    if (is_training() && fuse_norm_relu()) {
        // Fused ReLU in training needs a 1-bit-per-element workspace.
        if (!isa_supports_avx2) return status::unimplemented;
        init_default_ws(1);
    }

    // Channel padding (tail handling) requires at least avx2.
    if (memory_desc_wrapper(src_md()).padded_dims()[1] != C()
            && !isa_supports_avx2)
        return status::unimplemented;

    auto scratchpad = scratchpad_registry().registrar();
    bnorm_tbb_impl::driver_t<isa>::init_scratchpad(scratchpad, this);

    return status::success;
}
2228
// Trivial constructor; kernels are created later in init().
template <cpu_isa_t isa>
jit_uni_tbb_batch_normalization_fwd_t<
        isa>::jit_uni_tbb_batch_normalization_fwd_t(const pd_t *apd)
    : primitive_t(apd) {}
2233
// Instantiates the driver and JIT-generates the forward kernels.
template <cpu_isa_t isa>
status_t jit_uni_tbb_batch_normalization_fwd_t<isa>::init(engine_t *engine) {
    CHECK(safe_ptr_assign(bnorm_driver_,
            new bnorm_tbb_impl::driver_t<isa>(pd(), pd()->tag_kind_)));
    return bnorm_driver_->create_kernel();
}
2240
// Forward execution: unpacks the runtime arguments from the context and
// dispatches to the driver.
template <cpu_isa_t isa>
status_t jit_uni_tbb_batch_normalization_fwd_t<isa>::execute(
        const exec_ctx_t &ctx) const {

    const memory_desc_wrapper ss_d(pd()->weights_md());

    const auto use_ss = pd()->use_scaleshift();
    const auto use_sc = pd()->use_scale();
    const auto use_sh = pd()->use_shift();

    // With the combined scale_shift tensor, shift is the second row.
    const size_t shift_off
            = use_ss && !ss_d.has_zero_dim() ? ss_d.off(1, 0) : 0;

    auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
    // Scale comes from either the separate SCALE arg or the combined tensor.
    auto scale = CTX_IN_MEM(
            const acc_data_t *, use_sc ? DNNL_ARG_SCALE : DNNL_ARG_SCALE_SHIFT);
    auto shift = use_sh ? CTX_IN_MEM(const acc_data_t *, DNNL_ARG_SHIFT)
            : use_ss ? &CTX_IN_MEM(const acc_data_t *,
                      DNNL_ARG_SCALE_SHIFT)[shift_off]
                     : nullptr;

    // Statistics are inputs when supplied by the user, outputs otherwise.
    auto mean = pd()->stats_is_src() ? const_cast<acc_data_t *>(
                        CTX_IN_MEM(const acc_data_t *, DNNL_ARG_MEAN))
                                     : CTX_OUT_MEM(acc_data_t *, DNNL_ARG_MEAN);
    auto var = pd()->stats_is_src()
            ? const_cast<acc_data_t *>(
                    CTX_IN_MEM(const acc_data_t *, DNNL_ARG_VARIANCE))
            : CTX_OUT_MEM(acc_data_t *, DNNL_ARG_VARIANCE);

    auto dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);
    auto ws = CTX_OUT_MEM(uint8_t *, DNNL_ARG_WORKSPACE);

    auto scratchpad = ctx.get_scratchpad_grantor();

    bnorm_driver_->exec_fwd(src, dst, scale, shift, mean, var, ws, scratchpad);

    return status::success;
}
2279
// Defaulted destructor; unique_ptr releases the driver.
template <cpu_isa_t isa>
jit_uni_tbb_batch_normalization_fwd_t<
        isa>::~jit_uni_tbb_batch_normalization_fwd_t()
        = default;
2284
2285 template struct jit_uni_tbb_batch_normalization_fwd_t<sse41>;
2286 template struct jit_uni_tbb_batch_normalization_fwd_t<avx2>;
2287 template struct jit_uni_tbb_batch_normalization_fwd_t<avx512_common>;
2288
2289 /* bwd */
// Backward primitive descriptor initialization: mirrors the forward checks,
// additionally requiring matching src/diff_src data types and a workspace
// consistent with the forward primitive when ReLU is fused.
template <cpu_isa_t isa>
status_t jit_uni_tbb_batch_normalization_bwd_t<isa>::pd_t::init(
        engine_t *engine) {

    // Basic applicability: ISA available, bwd prop, 2D/3D spatial, src and
    // diff_src both f32 or both bf16 (bf16 needs avx512_core).
    const bool ok = mayiuse(isa) && is_bwd() && !has_zero_dim_memory()
            && one_of(ndims(), 4, 5) && set_default_formats_common()
            && one_of(true,
                    everyone_is(
                            f32, src_md()->data_type, diff_src_md()->data_type),
                    everyone_is(bf16, src_md()->data_type,
                            diff_src_md()->data_type))
            && IMPLICATION(src_md()->data_type == bf16,
                    is_superset(isa, avx512_common) && mayiuse(avx512_core))
            && check_scale_shift_data_type() && attr()->has_default_values();
    if (!ok) return status::unimplemented;

    // Blocked layout: 16c for avx512, 8c otherwise.
    const format_tag_t blocked_tag = is_superset(isa, avx512_common)
            ? utils::pick(ndims() - 4, nChw16c, nCdhw16c)
            : utils::pick(ndims() - 4, nChw8c, nCdhw8c);

    const format_tag_t blocked_format
            = memory_desc_matches_tag(*src_md(), blocked_tag)
            ? blocked_tag
            : format_tag::undef;
    const format_tag_t nspc_format
            = memory_desc_matches_one_of_tag(*src_md(), nhwc, ndhwc);

    if (memory_desc_matches_tag(*diff_src_md(), blocked_format))
        tag_kind_ = jit_memory_tag_kind_t::blocked;
    else if (memory_desc_matches_tag(*diff_src_md(), nspc_format)) {
        tag_kind_ = jit_memory_tag_kind_t::nspc;
        // nspc kernels do not support a channel tail.
        const int simd_w = get_simd_w<isa>(tag_kind_);
        if (C() % simd_w != 0) return status::unimplemented;
    } else
        return status::unimplemented;

    // Channel padding (tail handling) requires at least avx2.
    const bool isa_supports_avx2 = is_superset(isa, avx2);
    if (memory_desc_wrapper(src_md()).padded_dims()[1] != C()
            && !isa_supports_avx2)
        return status::unimplemented;

    if (fuse_norm_relu()) {
        // Fused ReLU needs the forward-produced 1-bit workspace.
        if (!isa_supports_avx2) return status::unimplemented;
        init_default_ws(1);
        if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;
    }

    auto scratchpad = scratchpad_registry().registrar();
    bnorm_tbb_impl::driver_t<isa>::init_scratchpad(scratchpad, this);

    return status::success;
}
2342
// Trivial constructor; kernels are created later in init().
template <cpu_isa_t isa>
jit_uni_tbb_batch_normalization_bwd_t<
        isa>::jit_uni_tbb_batch_normalization_bwd_t(const pd_t *apd)
    : primitive_t(apd) {}
2347
// Instantiates the driver and JIT-generates the backward kernels.
template <cpu_isa_t isa>
status_t jit_uni_tbb_batch_normalization_bwd_t<isa>::init(engine_t *engine) {
    CHECK(safe_ptr_assign(bnorm_driver_,
            new bnorm_tbb_impl::driver_t<isa>(pd(), pd()->tag_kind_)));
    return bnorm_driver_->create_kernel();
}
2354
// Backward execution: unpacks the runtime arguments from the context and
// dispatches to the driver.
template <cpu_isa_t isa>
status_t jit_uni_tbb_batch_normalization_bwd_t<isa>::execute(
        const exec_ctx_t &ctx) const {

    const memory_desc_wrapper diff_ss_d(pd()->diff_weights_md());

    const auto use_ss = pd()->use_scaleshift();
    const auto use_sc = pd()->use_scale();
    const auto use_sh = pd()->use_shift();

    // With the combined diff_scale_shift tensor, diff_shift is row 1.
    const size_t diff_shift_off
            = use_ss && !diff_ss_d.has_zero_dim() ? diff_ss_d.off(1, 0) : 0;

    auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
    auto mean = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_MEAN);
    auto var = CTX_IN_MEM(const acc_data_t *, DNNL_ARG_VARIANCE);
    auto diff_dst = CTX_IN_MEM(const void *, DNNL_ARG_DIFF_DST);
    // Scale comes from either the separate SCALE arg or the combined tensor.
    auto scale = CTX_IN_MEM(
            const acc_data_t *, use_sc ? DNNL_ARG_SCALE : DNNL_ARG_SCALE_SHIFT);
    auto ws = CTX_IN_MEM(const uint8_t *, DNNL_ARG_WORKSPACE);

    auto diff_src = CTX_OUT_MEM(void *, DNNL_ARG_DIFF_SRC);
    auto diff_scale = CTX_OUT_MEM(acc_data_t *,
            use_sc ? DNNL_ARG_DIFF_SCALE : DNNL_ARG_DIFF_SCALE_SHIFT);
    auto diff_shift = use_sh ? CTX_OUT_MEM(acc_data_t *, DNNL_ARG_DIFF_SHIFT)
            : use_ss ? &diff_scale[diff_shift_off] : nullptr;

    auto scratchpad = ctx.get_scratchpad_grantor();

    bnorm_driver_->exec_bwd(src, diff_src, diff_dst, scale, diff_scale,
            diff_shift, mean, var, ws, scratchpad);

    return status::success;
}
2389
// Defaulted destructor; unique_ptr releases the driver.
template <cpu_isa_t isa>
jit_uni_tbb_batch_normalization_bwd_t<
        isa>::~jit_uni_tbb_batch_normalization_bwd_t()
        = default;
2394
2395 template struct jit_uni_tbb_batch_normalization_bwd_t<sse41>;
2396 template struct jit_uni_tbb_batch_normalization_bwd_t<avx2>;
2397 template struct jit_uni_tbb_batch_normalization_bwd_t<avx512_common>;
2398 } // namespace x64
2399 } // namespace cpu
2400 } // namespace impl
2401 } // namespace dnnl
2402