/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/math_utils.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/x64/jit_generator.hpp"

#include "cpu/x64/jit_uni_batch_normalization_s8.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

namespace {

using namespace Xbyak;

using data_t = int8_t;

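// Everything the generated kernel needs for a single call; filled in by
// driver_t::exec() below.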
struct call_params_t {
    // keep int sizes at 8 bytes -- jit code expects this
    size_t channel_offt_count, spat_offt_count;
    float eps;
    const float *scale, *shift, *mean, *var;
    const data_t *src, *dst;
};

template <cpu_isa_t isa>
struct jit_bnorm_base_t : public jit_generator {

    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_t)

    using Vmm = typename cpu_isa_traits<isa>::Vmm;
    const AddressFrame &vmmword
            = (isa == sse41) ? xword : ((isa == avx2) ? yword : zword);
    const int vlen = cpu_isa_traits<isa>::vlen;

    const batch_normalization_pd_t *pd_;

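    // The call_params_t pointer arrives in the first ABI argument register;
    // the remaining GPRs hold base pointers, offsets, and loop counters.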
    Reg64 reg_param = abi_param1;

    Reg64 reg_scale = rbx;
    Reg64 reg_shift = rdx;
    Reg64 reg_mean = rbp;

    Reg64 reg_channel_offt_count = r8;
    Reg64 reg_spat_offt = r9;
    Reg64 reg_spat_offt_count = r10;
    Reg64 reg_tmp = r11;
    Reg64 reg_src = r12;
    Reg64 reg_dst = r13;
    Reg64 reg_var = r14;
    Reg64 reg_channel_offt_1byte = r15;
    Reg64 reg_channel_offt_4byte = rax;

    Vmm vzero = Vmm(isa == avx512_core ? 29 : 13);
    Xmm xone = Xmm(14);
    Vmm vone = Vmm(isa == avx512_core ? 30 : 14);
    Vmm veps = Vmm(isa == avx512_core ? 31 : 15);

    size_t simd_w_ = cpu_isa_traits<isa>::vlen / sizeof(float);
    size_t c_in_xmm_ = (isa == sse41) ? 8 : 16;
    size_t chan_data_offt_;
    size_t num_c_blocks_;
    size_t c_tail_;
    bool with_relu_;

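    // Channels are processed in blocks of c_in_xmm_ (8 for sse41, 16
    // otherwise): num_c_blocks_ full blocks plus a c_tail_ remainder handled
    // with masked/partial loads and stores.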
    void compute_predefined_variables() {
        chan_data_offt_ = pd_->C() * sizeof(float);
        num_c_blocks_ = pd_->C() / c_in_xmm_;
        c_tail_ = pd_->C() % c_in_xmm_;
        with_relu_ = (pd_->with_relu_post_op() || pd_->fuse_norm_relu())
                && pd_->is_fwd();
    }

    void load_common_params() {
        mov(reg_tmp, float2int(1.0f));
        uni_vmovq(xone, reg_tmp);
        uni_vbroadcastss(vone, xone);

#define PARAM_OFF(x) offsetof(call_params_t, x)
        uni_vbroadcastss(veps, vmmword[reg_param + PARAM_OFF(eps)]);
        uni_vpxor(vzero, vzero, vzero);

        mov(reg_channel_offt_count,
                ptr[reg_param + PARAM_OFF(channel_offt_count)]);
        mov(reg_spat_offt_count, ptr[reg_param + PARAM_OFF(spat_offt_count)]);
        mov(reg_src, ptr[reg_param + PARAM_OFF(src)]);
        mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]);
        mov(reg_mean, ptr[reg_param + PARAM_OFF(mean)]);
        mov(reg_scale, ptr[reg_param + PARAM_OFF(scale)]);
        mov(reg_shift, ptr[reg_param + PARAM_OFF(shift)]);
        mov(reg_var, ptr[reg_param + PARAM_OFF(var)]);
#undef PARAM_OFF
    }

    Address mean_ptr(size_t offt = 0) {
        return vmmword[reg_mean + reg_channel_offt_4byte + offt];
    }

    Address var_ptr(size_t offt = 0) {
        return vmmword[reg_var + reg_channel_offt_4byte + offt];
    }

    Address scale_ptr(size_t offt = 0) {
        return vmmword[reg_scale + reg_channel_offt_4byte + offt
                + 0 * chan_data_offt_];
    }

    Address shift_ptr(size_t offt = 0) {
        return vmmword[reg_shift + reg_channel_offt_4byte + offt
                + 0 * chan_data_offt_];
    }

    Address src_ptr(size_t offt = 0) {
        return vmmword[reg_src + reg_spat_offt + offt];
    }

    Address dst_ptr(size_t offt = 0) {
        return vmmword[reg_dst + reg_spat_offt + offt];
    }

    virtual void prepare_tail_mask() {}
    virtual void load_mean_and_var(const Vmm &vmean, const Vmm &vsqrtvar,
            size_t offt, bool need_tail) {}
    virtual void load_scale(const Vmm &vscale, size_t offt, bool need_tail) {}
    virtual void load_shift(const Vmm &vshift, size_t offt, bool need_tail) {}
    virtual void compute_dst(bool need_tail) {}

    // Precomputes vscale and vshift for the following
    // `vdst = vscale * vsrc + vshift`
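    // where, depending on the flags,
    //   vscale = scale / sqrt(var + eps)
    //   vshift = shift - mean * vscale
    // with scale == 1 and/or shift == 0 when the corresponding flag is unset.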
    void compute_vscaleshift(const Vmm &vscale, const Vmm &vshift,
            const Vmm &vmean, const Vmm &vsqrtvar, size_t offt,
            bool need_tail) {
        load_mean_and_var(vmean, vsqrtvar, offt, need_tail);
        uni_vaddps(vsqrtvar, vsqrtvar, veps);
        uni_vsqrtps(vsqrtvar, vsqrtvar);

        if (pd_->use_scaleshift() || (pd_->use_scale() && pd_->use_shift())) {
            load_scale(vscale, offt, need_tail);
            uni_vdivps(vscale, vscale, vsqrtvar);
            load_shift(vshift, offt, need_tail);
            uni_vfnmadd231ps(vshift, vmean, vscale);
        } else if (pd_->use_scale()) {
            load_scale(vscale, offt, need_tail);
            uni_vdivps(vscale, vscale, vsqrtvar);
            uni_vmulps(vmean, vmean, vscale);
            uni_vsubps(vshift, vzero, vmean, vshift);
        } else if (pd_->use_shift()) {
            uni_vdivps(vscale, vone, vsqrtvar, vscale);
            load_shift(vshift, offt, need_tail);
            uni_vfnmadd231ps(vshift, vmean, vscale);
        } else {
            uni_vdivps(vscale, vone, vsqrtvar, vscale);
            uni_vmulps(vmean, vmean, vscale);
            uni_vsubps(vshift, vzero, vmean, vshift);
        }
    }

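    // Drive the whole kernel: full blocks of c_in_xmm_ channels first, then
    // the channel tail (if any) via the need_tail path.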
    void forward() {
        xor_(reg_channel_offt_1byte, reg_channel_offt_1byte);
        xor_(reg_channel_offt_4byte, reg_channel_offt_4byte);
        mov(reg_tmp, sizeof(data_t) * c_in_xmm_);

        if (num_c_blocks_) compute_dst(false);
        if (c_tail_) compute_dst(true);
    }

    // Either this stub is used or the logic has to be duplicated in each
    // jit_bnorm_t ctor, because the virtual methods involved are not yet
    // defined at the time the base class ctor runs.
    void generate() override {
        preamble();
        compute_predefined_variables();
        load_common_params();
        prepare_tail_mask();
        forward();
        postamble();
    }

    jit_bnorm_base_t(const batch_normalization_pd_t *pd) : pd_(pd) {}
};

template <cpu_isa_t isa>
struct jit_bnorm_t;

template <>
struct jit_bnorm_t<avx512_core> : public jit_bnorm_base_t<avx512_core> {
    Opmask tail_opmask = Opmask(1); // f32 mask for channel math

    void prepare_tail_mask() override {
        if (!c_tail_) return;

        const int mask_f32 = (1 << c_tail_) - 1;

        Reg32 regw_tmp = reg_tmp.cvt32();
        mov(regw_tmp, mask_f32);
        kmovw(tail_opmask, regw_tmp);
    }

    void load_mean_and_var(const Vmm &vmean, const Vmm &vsqrtvar, size_t offt,
            bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vmean, tail_opmask, mean_ptr(offt));
            uni_vmovups_tail(vsqrtvar, tail_opmask, var_ptr(offt));
        } else {
            uni_vmovups(vmean, mean_ptr(offt));
            uni_vmovups(vsqrtvar, var_ptr(offt));
        }
    }

    void load_scale(const Vmm &vscale, size_t offt, bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vscale, tail_opmask, scale_ptr(offt));
        } else {
            uni_vmovups(vscale, scale_ptr(offt));
        }
    }

    void load_shift(const Vmm &vshift, size_t offt, bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vshift, tail_opmask, shift_ptr(offt));
        } else {
            uni_vmovups(vshift, shift_ptr(offt));
        }
    }

    void compute_dst(bool need_tail) override {
        Label c_loop;
        L(c_loop);
        {
            Xmm x = Xmm(0);
            Vmm v = Vmm(0);
            Vmm vscale = Vmm(1);
            Vmm vshift = Vmm(2);
            Vmm vmean = Vmm(3);
            Vmm vsqrtvar = Vmm(4);

            // compute single vscale and vshift vectors...
            compute_vscaleshift(vscale, vshift, vmean, vsqrtvar, 0, need_tail);

            // ... then run the whole spatial loop with them and move on to
            // the next channel chunk
            mov(reg_spat_offt, reg_channel_offt_1byte);
            Label mb_sp_loop;
            L(mb_sp_loop);
            {
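                // With a tail, gather c_tail_ s8 values byte by byte into an
                // xmm before sign-extending to s32; otherwise load and
                // sign-extend a full vector at once.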
                if (need_tail) {
                    for (size_t tl = 0; tl < c_tail_; tl++)
                        vpinsrb(x, x, src_ptr(tl), tl);
                    vpmovsxbd(v, x);
                } else
                    vpmovsxbd(v, src_ptr());

                vcvtdq2ps(v, v);

                uni_vfmadd213ps(v, vscale, vshift);
                if (with_relu_) uni_vmaxps(v, v, vzero);

                vcvtps2dq(v, v);
                if (need_tail) {
                    vpmovsdb(x, v);
                    for (size_t tl = 0; tl < c_tail_; tl++)
                        vpextrb(dst_ptr(tl), x, tl);
                } else
                    vpmovsdb(dst_ptr(), v);

                add(reg_spat_offt, reg_channel_offt_count);
                cmp(reg_spat_offt, reg_spat_offt_count);
                jl(mb_sp_loop);
            }

            // reg_tmp looks c_in_xmm_ channels ahead so the loop stops before
            // the channel tail, which is processed separately
            add(reg_tmp, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_1byte, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_4byte, sizeof(float) * c_in_xmm_);
            cmp(reg_tmp, reg_channel_offt_count);
            jle(c_loop);
        }
    }

    jit_bnorm_t(const batch_normalization_pd_t *pd)
        : jit_bnorm_base_t<avx512_core>(pd) {}
};

template <>
struct jit_bnorm_t<avx2> : public jit_bnorm_base_t<avx2> {
    Vmm tail_vmask = Vmm(11);
    Vmm body_vmask = Vmm(12);

    void prepare_tail_mask() override {
        // the tail is always < 16 channels; process it in two parts
        static const uint32_t mask_half_ymm[8]
                = {0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0, 0, 0, 0};
        mov(reg_tmp, reinterpret_cast<size_t>(&mask_half_ymm[0]));
        vmovups(body_vmask, ptr[reg_tmp]);

        if (!c_tail_) return;

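        // Sliding-window mask: starting the 8-element load at index
        // (7 - c_tail_ % simd_w_) yields exactly (c_tail_ % simd_w_) active
        // f32 lanes.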
        static const uint32_t mask_f32[14]
                = {0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
                        0xffffffff, 0xffffffff, 0, 0, 0, 0, 0, 0, 0};

        mov(reg_tmp,
                reinterpret_cast<size_t>(&mask_f32[7 - c_tail_ % simd_w_]));
        vmovups(tail_vmask, ptr[reg_tmp]);
    }

    void load_mean_and_var(const Vmm &vmean, const Vmm &vsqrtvar, size_t offt,
            bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vmean, tail_vmask, mean_ptr(offt));
            uni_vmovups_tail(vsqrtvar, tail_vmask, var_ptr(offt));
        } else {
            uni_vmovups(vmean, mean_ptr(offt));
            uni_vmovups(vsqrtvar, var_ptr(offt));
        }
    }

    void load_scale(const Vmm &vscale, size_t offt, bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vscale, tail_vmask, scale_ptr(offt));
        } else {
            uni_vmovups(vscale, scale_ptr(offt));
        }
    }

    void load_shift(const Vmm &vshift, size_t offt, bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vshift, tail_vmask, shift_ptr(offt));
        } else {
            uni_vmovups(vshift, shift_ptr(offt));
        }
    }

    void compute_dst(bool need_tail) override {
        Label c_loop;
        L(c_loop);
        {

            Xmm x0 = Xmm(0);
            Vmm v0 = Vmm(0);
            Xmm x1 = Xmm(1);
            Vmm v1 = Vmm(1);
            Vmm vscale0 = Vmm(2);
            Vmm vshift0 = Vmm(3);
            Vmm vmean0 = Vmm(4);
            Vmm vsqrtvar0 = Vmm(5);
            Vmm vscale1 = Vmm(6);
            Vmm vshift1 = Vmm(7);
            Vmm vmean1 = Vmm(8);
            Vmm vsqrtvar1 = Vmm(9);

            // compute a pair of vscale and vshift vectors, each covering
            // 8 channels...
            compute_vscaleshift(vscale0, vshift0, vmean0, vsqrtvar0, 0,
                    (c_tail_ < simd_w_ && need_tail) ? true : false);
            if (!need_tail || c_tail_ > simd_w_) {
                compute_vscaleshift(vscale1, vshift1, vmean1, vsqrtvar1,
                        simd_w_ * sizeof(float), need_tail);
            }

            // ... then run the whole spatial loop with them and move on to
            // the next channel chunk
            mov(reg_spat_offt, reg_channel_offt_1byte);
            Label mb_sp_loop;
            L(mb_sp_loop);
            {

                if (need_tail) {
                    for (size_t tl = 0; tl < c_tail_; tl++) {
                        if (tl < simd_w_) {
                            vpinsrb(x0, x0, src_ptr(tl), tl);
                        } else {
                            vpinsrb(x1, x1, src_ptr(tl), tl - simd_w_);
                        }
                    }
                    vpmovsxbd(v0, x0);
                    vpmovsxbd(v1, x1);
                } else {
                    vpmovsxbd(v0, src_ptr());
                    vpmovsxbd(v1, src_ptr(simd_w_));
                }

                vcvtdq2ps(v0, v0);
                vcvtdq2ps(v1, v1);

                uni_vfmadd213ps(v0, vscale0, vshift0);
                uni_vfmadd213ps(v1, vscale1, vshift1);
                if (with_relu_) {
                    uni_vmaxps(v0, v0, vzero);
                    uni_vmaxps(v1, v1, vzero);
                }

                vcvtps2dq(v0, v0); // BA
                vcvtps2dq(v1, v1); // DC
                vpackssdw(v0, v0, v1); // BA + DC -> DBCA
                vpermq(v0, v0, 0xD8); // DBCA -> DCBA
                vperm2i128(v1, v0, v0, 0x1); // DCBA -> BADC
                vpacksswb(v0, v0, v1); // DCBA + BADC -> badcDCBA

                if (need_tail) {
                    for (size_t tl = 0; tl < c_tail_; tl++) {
                        vpextrb(dst_ptr(tl), x0, tl);
                    }
                } else {
                    // vpacksswb produces 32 bytes in the ymm, but the top
                    // half of them is garbage, so do a 128-bit masked store
                    vmaskmovps(dst_ptr(), body_vmask, v0);
                }

                add(reg_spat_offt, reg_channel_offt_count);
                cmp(reg_spat_offt, reg_spat_offt_count);
                jl(mb_sp_loop);
            }

            // reg_tmp looks c_in_xmm_ channels ahead so the loop stops before
            // the channel tail, which is processed separately
            add(reg_tmp, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_1byte, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_4byte, sizeof(float) * c_in_xmm_);
            cmp(reg_tmp, reg_channel_offt_count);
            jle(c_loop);
        }
    }

    jit_bnorm_t(const batch_normalization_pd_t *pd)
        : jit_bnorm_base_t<avx2>(pd) {}
};

template <>
struct jit_bnorm_t<sse41> : public jit_bnorm_base_t<sse41> {
    void load_mean_and_var(const Vmm &vmean, const Vmm &vsqrtvar, size_t offt,
            bool need_tail) override {
        if (need_tail) {
            for (size_t tl = 0; tl < c_tail_ % simd_w_; tl++) {
                pinsrd(vmean, mean_ptr(offt + tl * sizeof(float)), tl);
                pinsrd(vsqrtvar, var_ptr(offt + tl * sizeof(float)), tl);
            }
        } else {
            movups(vmean, mean_ptr(offt));
            movups(vsqrtvar, var_ptr(offt));
        }
    }

    void load_scale(const Vmm &vscale, size_t offt, bool need_tail) override {
        if (need_tail) {
            for (size_t tl = 0; tl < c_tail_ % simd_w_; tl++) {
                pinsrd(vscale, scale_ptr(offt + tl * sizeof(float)), tl);
            }
        } else {
            movups(vscale, scale_ptr(offt));
        }
    }

    void load_shift(const Vmm &vshift, size_t offt, bool need_tail) override {
        if (need_tail) {
            for (size_t tl = 0; tl < c_tail_ % simd_w_; tl++) {
                pinsrd(vshift, shift_ptr(offt + tl * sizeof(float)), tl);
            }
        } else {
            movups(vshift, shift_ptr(offt));
        }
    }

    void compute_dst(bool need_tail) override {
        const size_t copy_range = need_tail ? c_tail_ : c_in_xmm_;
        Label c_loop;
        L(c_loop);
        {

            Vmm v0 = Vmm(0);
            Vmm v1 = Vmm(1);
            Vmm vscale0 = Vmm(2);
            Vmm vshift0 = Vmm(3);
            Vmm vmean0 = Vmm(4);
            Vmm vsqrtvar0 = Vmm(5);
            Vmm vscale1 = Vmm(6);
            Vmm vshift1 = Vmm(7);
            Vmm vmean1 = Vmm(8);
            Vmm vsqrtvar1 = Vmm(9);

            // compute a pair of vscale and vshift vectors covering
            // 8 channels in total...
            compute_vscaleshift(vscale0, vshift0, vmean0, vsqrtvar0, 0,
                    (c_tail_ < simd_w_ && need_tail) ? true : false);
            if (!need_tail || c_tail_ > simd_w_) {
                compute_vscaleshift(vscale1, vshift1, vmean1, vsqrtvar1,
                        simd_w_ * sizeof(float), need_tail);
            }

            // ... then run the whole spatial loop with them and move on to
            // the next channel chunk
            mov(reg_spat_offt, reg_channel_offt_1byte);
            Label mb_sp_loop;
            L(mb_sp_loop);
            {
                if (need_tail) {
                    for (size_t tl = 0; tl < copy_range; tl++) {
                        if (tl < simd_w_) {
                            pinsrb(v0, src_ptr(tl), tl);
                        } else {
                            pinsrb(v1, src_ptr(tl), (tl - simd_w_));
                        }
                    }
                    pmovsxbd(v0, v0);
                    pmovsxbd(v1, v1);
                } else {
                    pmovsxbd(v0, src_ptr());
                    pmovsxbd(v1, src_ptr(simd_w_));
                }

                cvtdq2ps(v0, v0);
                cvtdq2ps(v1, v1);

                uni_vfmadd213ps(v0, vscale0, vshift0);
                uni_vfmadd213ps(v1, vscale1, vshift1);
                if (with_relu_) {
                    maxps(v0, vzero);
                    maxps(v1, vzero);
                }

                cvtps2dq(v0, v0);
                cvtps2dq(v1, v1);
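                // packssdw merges the two 4-channel s32 halves into 8 s16;
                // packsswb then narrows them to s8 in the low 8 bytes of v0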
                packssdw(v0, v1);
                movups(v1, v0);
                packsswb(v0, v1);

                // A potential perf gain is possible by combining the two
                // halves into a single vector register and using movups
                // instead of byte stores.
                for (size_t tl = 0; tl < copy_range; tl++) {
                    pextrb(dst_ptr(tl), v0, tl);
                }

                add(reg_spat_offt, reg_channel_offt_count);
                cmp(reg_spat_offt, reg_spat_offt_count);
                jl(mb_sp_loop);
            }

            // reg_tmp looks c_in_xmm_ channels ahead so the loop stops before
            // the channel tail, which is processed separately
            add(reg_tmp, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_1byte, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_4byte, sizeof(float) * c_in_xmm_);
            cmp(reg_tmp, reg_channel_offt_count);
            jle(c_loop);
        }
    }

    jit_bnorm_t(const batch_normalization_pd_t *pd)
        : jit_bnorm_base_t<sse41>(pd) {}
};

} // namespace

namespace bnorm_s8_impl {

template <cpu_isa_t isa>
struct driver_t : public c_compatible {
    driver_t(const batch_normalization_pd_t *pd) : pd_(pd), ker_(pd_) {}
    ~driver_t() = default;

    // TODO: for problems where thread pieces don't fit L2 cache, add spatial
    // re-balancing that uses fewer pieces.
    void exec(int ithr, int nthr, const data_t *src, data_t *dst,
            const float *scale, const float *shift, const float *mean,
            const float *var) {
        dim_t N = pd_->MB();
        dim_t C = pd_->C();
        dim_t D = pd_->D();
        dim_t H = pd_->H();
        dim_t W = pd_->W();
        dim_t SP = D * H * W;

        call_params_t p;

        p.eps = pd_->desc()->batch_norm_epsilon;

        p.scale = scale;
        p.shift = shift;
        p.mean = mean;
        p.var = var;

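        // Split the N * SP outer work across threads: each thread gets a
        // contiguous [start, end) slice and always handles full channel rows
        // (nhwc/ndhwc layout), so threads touch disjoint data.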
        dim_t work_amount {N * SP}, start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        p.channel_offt_count = C;
        p.spat_offt_count = (end - start) * p.channel_offt_count;
        p.src = src + start * p.channel_offt_count;
        p.dst = dst + start * p.channel_offt_count;

        if (p.spat_offt_count != 0) ker_(&p);
    }

    status_t create_kernel() { return ker_.create_kernel(); }

private:
    const batch_normalization_pd_t *pd_;

    jit_bnorm_t<isa> ker_;
};

} // namespace bnorm_s8_impl

using namespace data_type;
using namespace format_tag;
using namespace utils;

/* fwd */

template <cpu_isa_t isa>
status_t jit_uni_batch_normalization_s8_fwd_t<isa>::pd_t::init(
        engine_t *engine) {
    auto desired_fmt_tag = (ndims() == 4) ? nhwc : ndhwc;

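    // Inference-only forward: s8 data in [n]hwc layout with externally
    // provided (global) mean and variance.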
    bool ok = true && mayiuse(isa) && is_fwd() && !has_zero_dim_memory()
            && one_of(ndims(), 4, 5) && stats_is_src()
            && src_md()->data_type == s8 && check_scale_shift_data_type()
            && memory_desc_matches_tag(*src_md(), desired_fmt_tag)
            && (attr()->has_default_values() || this->with_relu_post_op());
    if (!ok) return status::unimplemented;

    return status::success;
}

template <cpu_isa_t isa>
jit_uni_batch_normalization_s8_fwd_t<isa>::jit_uni_batch_normalization_s8_fwd_t(
        const pd_t *apd)
    : primitive_t(apd) {}

template <cpu_isa_t isa>
status_t jit_uni_batch_normalization_s8_fwd_t<isa>::init(engine_t *engine) {
    CHECK(safe_ptr_assign(
            bnorm_driver_, new bnorm_s8_impl::driver_t<isa>(pd())));
    return bnorm_driver_->create_kernel();
}

template <cpu_isa_t isa>
status_t jit_uni_batch_normalization_s8_fwd_t<isa>::execute(
        const exec_ctx_t &ctx) const {

    const memory_desc_wrapper ss_d(pd()->weights_md());

    const auto use_ss = pd()->use_scaleshift();
    const auto use_sc = pd()->use_scale();
    const auto use_sh = pd()->use_shift();

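    // With the combined scale_shift weights (a 2 x C tensor), the shift
    // values live in the second row, hence the offset ss_d.off(1, 0).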
    const size_t shift_off
            = use_ss && !ss_d.has_zero_dim() ? ss_d.off(1, 0) : 0;

    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto scale = CTX_IN_MEM(
            const float *, use_sc ? DNNL_ARG_SCALE : DNNL_ARG_SCALE_SHIFT);
    auto shift = use_sh ? CTX_IN_MEM(const float *, DNNL_ARG_SHIFT)
                        : use_ss ? &CTX_IN_MEM(const float *,
                                  DNNL_ARG_SCALE_SHIFT)[shift_off]
                                 : nullptr;
    auto mean = const_cast<float *>(CTX_IN_MEM(const float *, DNNL_ARG_MEAN));
    auto var = const_cast<float *>(
            CTX_IN_MEM(const float *, DNNL_ARG_VARIANCE));
    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);

    // run sequentially if the problem fits within a single 4K memory page
    const bool force_sequential
            = pd()->MB() * pd()->C() * pd()->D() * pd()->H() * pd()->W()
            <= 4096;

    parallel(force_sequential ? 1 : 0, [&](const int ithr, const int nthr) {
        bnorm_driver_->exec(ithr, nthr, src, dst, scale, shift, mean, var);
    });

    return status::success;
}

template <cpu_isa_t isa>
jit_uni_batch_normalization_s8_fwd_t<
        isa>::~jit_uni_batch_normalization_s8_fwd_t() {
    delete bnorm_driver_;
}

/* struct instantiation */
template struct jit_uni_batch_normalization_s8_fwd_t<avx512_core>;
template struct jit_uni_batch_normalization_s8_fwd_t<avx2>;
template struct jit_uni_batch_normalization_s8_fwd_t<sse41>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl