/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/math_utils.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/x64/jit_generator.hpp"

#include "cpu/x64/jit_uni_batch_normalization_s8.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace Xbyak;

using data_t = int8_t;

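// Per-call arguments for the jitted kernel. The driver below fills these for
// each thread, with src/dst already advanced to that thread's chunk of the
// (N * spatial) x C data.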
struct call_params_t {
    // keep int sizes at 8 bytes -- jit code expects this
    size_t channel_offt_count, spat_offt_count;
    float eps;
    const float *scale, *shift, *mean, *var;
    const data_t *src, *dst;
};

template <cpu_isa_t isa>
struct jit_bnorm_base_t : public jit_generator {

    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_s8_t)

    using Vmm = typename cpu_isa_traits<isa>::Vmm;
    const AddressFrame &vmmword
            = (isa == sse41) ? xword : ((isa == avx2) ? yword : zword);
    const int vlen = cpu_isa_traits<isa>::vlen;

    const batch_normalization_pd_t *pd_;

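    // src/dst are int8, so offsets over them (reg_spat_offt,
    // reg_channel_offt_1byte) are in bytes, while mean/var/scale/shift are
    // f32 arrays addressed through reg_channel_offt_4byte.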
    Reg64 reg_param = abi_param1;

    Reg64 reg_scale = rbx;
    Reg64 reg_shift = rdx;
    Reg64 reg_mean = rbp;

    Reg64 reg_channel_offt_count = r8;
    Reg64 reg_spat_offt = r9;
    Reg64 reg_spat_offt_count = r10;
    Reg64 reg_tmp = r11;
    Reg64 reg_src = r12;
    Reg64 reg_dst = r13;
    Reg64 reg_var = r14;
    Reg64 reg_channel_offt_1byte = r15;
    Reg64 reg_channel_offt_4byte = rax;

    Vmm vzero = Vmm(isa == avx512_core ? 29 : 13);
    Xmm xone = Xmm(14);
    Vmm vone = Vmm(isa == avx512_core ? 30 : 14);
    Vmm veps = Vmm(isa == avx512_core ? 31 : 15);

    size_t simd_w_ = cpu_isa_traits<isa>::vlen / sizeof(float);
    size_t c_in_xmm_ = (isa == sse41) ? 8 : 16;
    size_t chan_data_offt_;
    size_t num_c_blocks_;
    size_t c_tail_;
    bool with_relu_;

    void compute_predefined_variables() {
        chan_data_offt_ = pd_->C() * sizeof(float);
        num_c_blocks_ = pd_->C() / c_in_xmm_;
        c_tail_ = pd_->C() % c_in_xmm_;
        with_relu_ = (pd_->with_relu_post_op() || pd_->fuse_norm_relu())
                && pd_->is_fwd();
    }

    void load_common_params() {
        mov(reg_tmp, float2int(1.0f));
        uni_vmovq(xone, reg_tmp);
        uni_vbroadcastss(vone, xone);

#define PARAM_OFF(x) offsetof(call_params_t, x)
        uni_vbroadcastss(veps, vmmword[reg_param + PARAM_OFF(eps)]);
        uni_vpxor(vzero, vzero, vzero);

        mov(reg_channel_offt_count,
                ptr[reg_param + PARAM_OFF(channel_offt_count)]);
        mov(reg_spat_offt_count, ptr[reg_param + PARAM_OFF(spat_offt_count)]);
        mov(reg_src, ptr[reg_param + PARAM_OFF(src)]);
        mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]);
        mov(reg_mean, ptr[reg_param + PARAM_OFF(mean)]);
        mov(reg_scale, ptr[reg_param + PARAM_OFF(scale)]);
        mov(reg_shift, ptr[reg_param + PARAM_OFF(shift)]);
        mov(reg_var, ptr[reg_param + PARAM_OFF(var)]);
#undef PARAM_OFF
    }

    Address mean_ptr(size_t offt = 0) {
        return vmmword[reg_mean + reg_channel_offt_4byte + offt];
    }

    Address var_ptr(size_t offt = 0) {
        return vmmword[reg_var + reg_channel_offt_4byte + offt];
    }

    Address scale_ptr(size_t offt = 0) {
        return vmmword[reg_scale + reg_channel_offt_4byte + offt
                + 0 * chan_data_offt_];
    }

    Address shift_ptr(size_t offt = 0) {
        return vmmword[reg_shift + reg_channel_offt_4byte + offt
                + 0 * chan_data_offt_];
    }

    Address src_ptr(size_t offt = 0) {
        return vmmword[reg_src + reg_spat_offt + offt];
    }

    Address dst_ptr(size_t offt = 0) {
        return vmmword[reg_dst + reg_spat_offt + offt];
    }

    virtual void prepare_tail_mask() {}
    virtual void load_mean_and_var(const Vmm &vmean, const Vmm &vsqrtvar,
            size_t offt, bool need_tail) {}
    virtual void load_scale(const Vmm &vscale, size_t offt, bool need_tail) {}
    virtual void load_shift(const Vmm &vshift, size_t offt, bool need_tail) {}
    virtual void compute_dst(bool need_tail) {}

    // Precomputes vscale and vshift for the following computation:
    // `vdst = vscale * vsrc + vshift`
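    // With both scale and shift in use this amounts to
    //   vscale = scale / sqrt(var + eps)
    //   vshift = shift - mean * vscale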
    void compute_vscaleshift(const Vmm &vscale, const Vmm &vshift,
            const Vmm &vmean, const Vmm &vsqrtvar, size_t offt,
            bool need_tail) {
        load_mean_and_var(vmean, vsqrtvar, offt, need_tail);
        uni_vaddps(vsqrtvar, vsqrtvar, veps);
        uni_vsqrtps(vsqrtvar, vsqrtvar);

        if (pd_->use_scaleshift() || (pd_->use_scale() && pd_->use_shift())) {
            load_scale(vscale, offt, need_tail);
            uni_vdivps(vscale, vscale, vsqrtvar);
            load_shift(vshift, offt, need_tail);
            uni_vfnmadd231ps(vshift, vmean, vscale);
        } else if (pd_->use_scale()) {
            load_scale(vscale, offt, need_tail);
            uni_vdivps(vscale, vscale, vsqrtvar);
            uni_vmulps(vmean, vmean, vscale);
            uni_vsubps(vshift, vzero, vmean, vshift);
        } else if (pd_->use_shift()) {
            uni_vdivps(vscale, vone, vsqrtvar, vscale);
            load_shift(vshift, offt, need_tail);
            uni_vfnmadd231ps(vshift, vmean, vscale);
        } else {
            uni_vdivps(vscale, vone, vsqrtvar, vscale);
            uni_vmulps(vmean, vmean, vscale);
            uni_vsubps(vshift, vzero, vmean, vshift);
        }
    }

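    // Full c_in_xmm_-channel blocks are processed first; the channel tail,
    // if any, is handled by a second pass with masked/element-wise accesses.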
    void forward() {
        xor_(reg_channel_offt_1byte, reg_channel_offt_1byte);
        xor_(reg_channel_offt_4byte, reg_channel_offt_4byte);
        mov(reg_tmp, sizeof(data_t) * c_in_xmm_);

        if (num_c_blocks_) compute_dst(false);
        if (c_tail_) compute_dst(true);
    }

    // This stub is needed because the virtual methods called from generate()
    // are not yet defined when the base class constructor runs; the
    // alternative would be duplicating generate() in every derived ctor.
    void generate() override {
        preamble();
        compute_predefined_variables();
        load_common_params();
        prepare_tail_mask();
        forward();
        postamble();
    }

    jit_bnorm_base_t(const batch_normalization_pd_t *pd) : pd_(pd) {}
};

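// For reference, the ISA-specific kernels below all implement the following
// per-element computation (a scalar sketch, assuming both scale and shift
// are used; `q_s8` stands for the round-and-saturate performed by the
// convert/pack instructions):
//
//     float x = (float)src[sp][c];
//     float y = scale[c] / sqrtf(var[c] + eps) * (x - mean[c]) + shift[c];
//     if (with_relu) y = fmaxf(y, 0.f);
//     dst[sp][c] = q_s8(y);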
template <cpu_isa_t isa>
struct jit_bnorm_s8_t;

template <>
struct jit_bnorm_s8_t<avx512_core> : public jit_bnorm_base_t<avx512_core> {
    Opmask tail_opmask = Opmask(1); // f32 mask for channel math

    void prepare_tail_mask() override {
        if (!c_tail_) return;

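        // k-mask with the lowest c_tail_ bits set, e.g. c_tail_ == 3 -> 0b111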
        const int mask_f32 = (1 << c_tail_) - 1;

        Reg32 regw_tmp = reg_tmp.cvt32();
        mov(regw_tmp, mask_f32);
        kmovw(tail_opmask, regw_tmp);
    }

    void load_mean_and_var(const Vmm &vmean, const Vmm &vsqrtvar, size_t offt,
            bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vmean, tail_opmask, mean_ptr(offt));
            uni_vmovups_tail(vsqrtvar, tail_opmask, var_ptr(offt));
        } else {
            uni_vmovups(vmean, mean_ptr(offt));
            uni_vmovups(vsqrtvar, var_ptr(offt));
        }
    }

    void load_scale(const Vmm &vscale, size_t offt, bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vscale, tail_opmask, scale_ptr(offt));
        } else {
            uni_vmovups(vscale, scale_ptr(offt));
        }
    }

    void load_shift(const Vmm &vshift, size_t offt, bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vshift, tail_opmask, shift_ptr(offt));
        } else {
            uni_vmovups(vshift, shift_ptr(offt));
        }
    }

    void compute_dst(bool need_tail) override {
        Label c_loop;
        L(c_loop);
        {
            Xmm x = Xmm(0);
            Vmm v = Vmm(0);
            Vmm vscale = Vmm(1);
            Vmm vshift = Vmm(2);
            Vmm vmean = Vmm(3);
            Vmm vsqrtvar = Vmm(4);

            // compute the vscale and vshift vectors once per channel chunk...
            compute_vscaleshift(vscale, vshift, vmean, vsqrtvar, 0, need_tail);

            // ... then run the whole spatial loop with them before moving to
            // the next channel chunk
            mov(reg_spat_offt, reg_channel_offt_1byte);
            Label mb_sp_loop;
            L(mb_sp_loop);
            {
                if (need_tail) {
                    for (size_t tl = 0; tl < c_tail_; tl++)
                        vpinsrb(x, x, src_ptr(tl), tl);
                    vpmovsxbd(v, x);
                } else
                    vpmovsxbd(v, src_ptr());

                vcvtdq2ps(v, v);

                uni_vfmadd213ps(v, vscale, vshift);
                if (with_relu_) uni_vmaxps(v, v, vzero);

                vcvtps2dq(v, v);
                if (need_tail) {
                    vpmovsdb(x, v);
                    for (size_t tl = 0; tl < c_tail_; tl++)
                        vpextrb(dst_ptr(tl), x, tl);
                } else
                    vpmovsdb(dst_ptr(), v);

                add(reg_spat_offt, reg_channel_offt_count);
                cmp(reg_spat_offt, reg_spat_offt_count);
                jl(mb_sp_loop);
            }

            // reg_tmp stays one chunk (c_in_xmm_ channels) ahead so the loop
            // stops before the partial chunk, which is left to the tail pass
            add(reg_tmp, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_1byte, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_4byte, sizeof(float) * c_in_xmm_);
            cmp(reg_tmp, reg_channel_offt_count);
            jle(c_loop);
        }
    }

    jit_bnorm_s8_t(const batch_normalization_pd_t *pd)
        : jit_bnorm_base_t<avx512_core>(pd) {}
};

template <>
struct jit_bnorm_s8_t<avx2> : public jit_bnorm_base_t<avx2> {
    Vmm tail_vmask = Vmm(11);
    Vmm body_vmask = Vmm(12);

    void prepare_tail_mask() override {
        // the channel tail is always < 16, so it is processed in two
        // halves of up to 8 channels each
        static const uint32_t mask_half_ymm[8]
                = {0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0, 0, 0, 0};
        mov(reg_tmp, reinterpret_cast<size_t>(&mask_half_ymm[0]));
        vmovups(body_vmask, ptr[reg_tmp]);

        if (!c_tail_) return;

        static const uint32_t mask_f32[14]
                = {0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
                        0xffffffff, 0xffffffff, 0, 0, 0, 0, 0, 0, 0};

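        // loading 8 dwords starting at index (7 - c_tail_ % simd_w_) yields a
        // mask with exactly (c_tail_ % simd_w_) all-ones lanes at the bottom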
        mov(reg_tmp,
                reinterpret_cast<size_t>(&mask_f32[7 - c_tail_ % simd_w_]));
        vmovups(tail_vmask, ptr[reg_tmp]);
    }

    void load_mean_and_var(const Vmm &vmean, const Vmm &vsqrtvar, size_t offt,
            bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vmean, tail_vmask, mean_ptr(offt));
            uni_vmovups_tail(vsqrtvar, tail_vmask, var_ptr(offt));
        } else {
            uni_vmovups(vmean, mean_ptr(offt));
            uni_vmovups(vsqrtvar, var_ptr(offt));
        }
    }

    void load_scale(const Vmm &vscale, size_t offt, bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vscale, tail_vmask, scale_ptr(offt));
        } else {
            uni_vmovups(vscale, scale_ptr(offt));
        }
    }

    void load_shift(const Vmm &vshift, size_t offt, bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vshift, tail_vmask, shift_ptr(offt));
        } else {
            uni_vmovups(vshift, shift_ptr(offt));
        }
    }

    void compute_dst(bool need_tail) override {
        Label c_loop;
        L(c_loop);
        {

            Xmm x0 = Xmm(0);
            Vmm v0 = Vmm(0);
            Xmm x1 = Xmm(1);
            Vmm v1 = Vmm(1);
            Vmm vscale0 = Vmm(2);
            Vmm vshift0 = Vmm(3);
            Vmm vmean0 = Vmm(4);
            Vmm vsqrtvar0 = Vmm(5);
            Vmm vscale1 = Vmm(6);
            Vmm vshift1 = Vmm(7);
            Vmm vmean1 = Vmm(8);
            Vmm vsqrtvar1 = Vmm(9);

            // compute a pair of vscale and vshift vectors, each covering
            // simd_w_ channels...
            compute_vscaleshift(vscale0, vshift0, vmean0, vsqrtvar0, 0,
                    (c_tail_ < simd_w_ && need_tail) ? true : false);
            if (!need_tail || c_tail_ > simd_w_) {
                compute_vscaleshift(vscale1, vshift1, vmean1, vsqrtvar1,
                        simd_w_ * sizeof(float), need_tail);
            }

            // ... then run the whole spatial loop with them before moving to
            // the next channel chunk
            mov(reg_spat_offt, reg_channel_offt_1byte);
            Label mb_sp_loop;
            L(mb_sp_loop);
            {

                if (need_tail) {
                    for (size_t tl = 0; tl < c_tail_; tl++) {
                        if (tl < simd_w_) {
                            vpinsrb(x0, x0, src_ptr(tl), tl);
                        } else {
                            vpinsrb(x1, x1, src_ptr(tl), tl - simd_w_);
                        }
                    }
                    vpmovsxbd(v0, x0);
                    vpmovsxbd(v1, x1);
                } else {
                    vpmovsxbd(v0, src_ptr());
                    vpmovsxbd(v1, src_ptr(simd_w_));
                }

                vcvtdq2ps(v0, v0);
                vcvtdq2ps(v1, v1);

                uni_vfmadd213ps(v0, vscale0, vshift0);
                uni_vfmadd213ps(v1, vscale1, vshift1);
                if (with_relu_) {
                    uni_vmaxps(v0, v0, vzero);
                    uni_vmaxps(v1, v1, vzero);
                }

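                // the AVX2 pack instructions work within each 128-bit lane,
                // so the cross-lane permutes below are needed to land all 16
                // saturated s8 results in the low xmm half of v0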
                vcvtps2dq(v0, v0); // BA
                vcvtps2dq(v1, v1); // DC
                vpackssdw(v0, v0, v1); // BA + DC -> DBCA
                vpermq(v0, v0, 0xD8); // DBCA -> DCBA
                vperm2i128(v1, v0, v0, 0x1); // DCBA -> BADC
                vpacksswb(v0, v0, v1); // DCBA + BADC -> badcDCBA

                if (need_tail) {
                    for (size_t tl = 0; tl < c_tail_; tl++) {
                        vpextrb(dst_ptr(tl), x0, tl);
                    }
                } else {
                    // vpacksswb leaves 32 int8 values in the ymm register and
                    // the top half of them is garbage, so do a 128-bit masked
                    // store of the low half
                    vmaskmovps(dst_ptr(), body_vmask, v0);
                }

                add(reg_spat_offt, reg_channel_offt_count);
                cmp(reg_spat_offt, reg_spat_offt_count);
                jl(mb_sp_loop);
            }

            // reg_tmp stays one chunk (c_in_xmm_ channels) ahead so the loop
            // stops before the partial chunk, which is left to the tail pass
            add(reg_tmp, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_1byte, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_4byte, sizeof(float) * c_in_xmm_);
            cmp(reg_tmp, reg_channel_offt_count);
            jle(c_loop);
        }
    }

    jit_bnorm_s8_t(const batch_normalization_pd_t *pd)
        : jit_bnorm_base_t<avx2>(pd) {}
};

template <>
struct jit_bnorm_s8_t<sse41> : public jit_bnorm_base_t<sse41> {
    void load_mean_and_var(const Vmm &vmean, const Vmm &vsqrtvar, size_t offt,
            bool need_tail) override {
        if (need_tail) {
            for (size_t tl = 0; tl < c_tail_ % simd_w_; tl++) {
                pinsrd(vmean, mean_ptr(offt + tl * sizeof(float)), tl);
                pinsrd(vsqrtvar, var_ptr(offt + tl * sizeof(float)), tl);
            }
        } else {
            movups(vmean, mean_ptr(offt));
            movups(vsqrtvar, var_ptr(offt));
        }
    }

    void load_scale(const Vmm &vscale, size_t offt, bool need_tail) override {
        if (need_tail) {
            for (size_t tl = 0; tl < c_tail_ % simd_w_; tl++) {
                pinsrd(vscale, scale_ptr(offt + tl * sizeof(float)), tl);
            }
        } else {
            movups(vscale, scale_ptr(offt));
        }
    }

    void load_shift(const Vmm &vshift, size_t offt, bool need_tail) override {
        if (need_tail) {
            for (size_t tl = 0; tl < c_tail_ % simd_w_; tl++) {
                pinsrd(vshift, shift_ptr(offt + tl * sizeof(float)), tl);
            }
        } else {
            movups(vshift, shift_ptr(offt));
        }
    }

    void compute_dst(bool need_tail) override {
        const size_t copy_range = need_tail ? c_tail_ : c_in_xmm_;
        Label c_loop;
        L(c_loop);
        {

            Vmm v0 = Vmm(0);
            Vmm v1 = Vmm(1);
            Vmm vscale0 = Vmm(2);
            Vmm vshift0 = Vmm(3);
            Vmm vmean0 = Vmm(4);
            Vmm vsqrtvar0 = Vmm(5);
            Vmm vscale1 = Vmm(6);
            Vmm vshift1 = Vmm(7);
            Vmm vmean1 = Vmm(8);
            Vmm vsqrtvar1 = Vmm(9);

            // compute a pair of vscale and vshift vectors, each covering
            // simd_w_ channels...
            compute_vscaleshift(vscale0, vshift0, vmean0, vsqrtvar0, 0,
                    (c_tail_ < simd_w_ && need_tail) ? true : false);
            if (!need_tail || c_tail_ > simd_w_) {
                compute_vscaleshift(vscale1, vshift1, vmean1, vsqrtvar1,
                        simd_w_ * sizeof(float), need_tail);
            }

            // ... then run the whole spatial loop with them before moving to
            // the next channel chunk
            mov(reg_spat_offt, reg_channel_offt_1byte);
            Label mb_sp_loop;
            L(mb_sp_loop);
            {
                if (need_tail) {
                    for (size_t tl = 0; tl < copy_range; tl++) {
                        if (tl < simd_w_) {
                            pinsrb(v0, src_ptr(tl), tl);
                        } else {
                            pinsrb(v1, src_ptr(tl), (tl - simd_w_));
                        }
                    }
                    pmovsxbd(v0, v0);
                    pmovsxbd(v1, v1);
                } else {
                    pmovsxbd(v0, src_ptr());
                    pmovsxbd(v1, src_ptr(simd_w_));
                }

                cvtdq2ps(v0, v0);
                cvtdq2ps(v1, v1);

                uni_vfmadd213ps(v0, vscale0, vshift0);
                uni_vfmadd213ps(v1, vscale1, vshift1);
                if (with_relu_) {
                    maxps(v0, vzero);
                    maxps(v1, vzero);
                }

                cvtps2dq(v0, v0);
                cvtps2dq(v1, v1);
                packssdw(v0, v1);
                movups(v1, v0);
                packsswb(v0, v1);

                // A potential perf gain is possible by combining the two
                // halves into a single vector register and using movups
                // instead of byte stores.
                for (size_t tl = 0; tl < copy_range; tl++) {
                    pextrb(dst_ptr(tl), v0, tl);
                }

                add(reg_spat_offt, reg_channel_offt_count);
                cmp(reg_spat_offt, reg_spat_offt_count);
                jl(mb_sp_loop);
            }

            // reg_tmp stays one chunk (c_in_xmm_ channels) ahead so the loop
            // stops before the partial chunk, which is left to the tail pass
            add(reg_tmp, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_1byte, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_4byte, sizeof(float) * c_in_xmm_);
            cmp(reg_tmp, reg_channel_offt_count);
            jle(c_loop);
        }
    }

    jit_bnorm_s8_t(const batch_normalization_pd_t *pd)
        : jit_bnorm_base_t<sse41>(pd) {}
};

namespace bnorm_s8_impl {

template <cpu_isa_t isa>
struct driver_t : public c_compatible {
    driver_t(const batch_normalization_pd_t *pd) : pd_(pd), ker_(pd_) {}
    ~driver_t() = default;

    // TODO: for problems where thread pieces don't fit in the L2 cache, add
    // spatial re-balancing that uses fewer pieces.
    void exec(int ithr, int nthr, const data_t *src, data_t *dst,
            const float *scale, const float *shift, const float *mean,
            const float *var) {
        dim_t N = pd_->MB();
        dim_t C = pd_->C();
        dim_t D = pd_->D();
        dim_t H = pd_->H();
        dim_t W = pd_->W();
        dim_t SP = D * H * W;

        call_params_t p;

        p.eps = pd_->desc()->batch_norm_epsilon;

        p.scale = scale;
        p.shift = shift;
        p.mean = mean;
        p.var = var;

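        // split the N * SP rows across threads; each thread gets a contiguous
        // range of C-channel rows, which is valid because the layout is
        // channels-last (nhwc / ndhwc)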
        dim_t work_amount {N * SP}, start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        p.channel_offt_count = C;
        p.spat_offt_count = (end - start) * p.channel_offt_count;
        p.src = src + start * p.channel_offt_count;
        p.dst = dst + start * p.channel_offt_count;

        if (p.spat_offt_count != 0) ker_(&p);
    }

    status_t create_kernel() { return ker_.create_kernel(); }

private:
    const batch_normalization_pd_t *pd_;

    jit_bnorm_s8_t<isa> ker_;
};

} // namespace bnorm_s8_impl

using namespace data_type;
using namespace format_tag;
using namespace utils;

/* fwd */

template <cpu_isa_t isa>
status_t jit_uni_batch_normalization_s8_fwd_t<isa>::pd_t::init(
        engine_t *engine) {
    auto desired_fmt_tag = (ndims() == 4) ? nhwc : ndhwc;

    bool ok = true && mayiuse(isa) && is_fwd() && !has_zero_dim_memory()
            && one_of(ndims(), 4, 5) && stats_is_src()
            && src_md()->data_type == s8 && check_scale_shift_data_type()
            && memory_desc_matches_tag(*src_md(), desired_fmt_tag)
            && (attr()->has_default_values() || this->with_relu_post_op());
    if (!ok) return status::unimplemented;

    return status::success;
}
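
// A minimal usage sketch (not part of this file), assuming the oneDNN v2.x
// C++ API; n, c, h, w and epsilon are placeholders. An s8, channels-last
// (nhwc), forward-inference batch normalization with user-provided statistics
// dispatches to this implementation when the checks above pass:
//
//     using namespace dnnl;
//     engine eng(engine::kind::cpu, 0);
//     memory::desc data_md({n, c, h, w}, memory::data_type::s8,
//             memory::format_tag::nhwc);
//     auto d = batch_normalization_forward::desc(
//             prop_kind::forward_inference, data_md, epsilon,
//             normalization_flags::use_global_stats
//                     | normalization_flags::use_scale_shift);
//     auto pd = batch_normalization_forward::primitive_desc(d, eng);
//     auto bnorm = batch_normalization_forward(pd);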

template <cpu_isa_t isa>
jit_uni_batch_normalization_s8_fwd_t<isa>::jit_uni_batch_normalization_s8_fwd_t(
        const pd_t *apd)
    : primitive_t(apd) {}

template <cpu_isa_t isa>
status_t jit_uni_batch_normalization_s8_fwd_t<isa>::init(engine_t *engine) {
    CHECK(safe_ptr_assign(
            bnorm_driver_, new bnorm_s8_impl::driver_t<isa>(pd())));
    return bnorm_driver_->create_kernel();
}

template <cpu_isa_t isa>
status_t jit_uni_batch_normalization_s8_fwd_t<isa>::execute(
        const exec_ctx_t &ctx) const {

    const memory_desc_wrapper ss_d(pd()->weights_md());

    const auto use_ss = pd()->use_scaleshift();
    const auto use_sc = pd()->use_scale();
    const auto use_sh = pd()->use_shift();

    const size_t shift_off
            = use_ss && !ss_d.has_zero_dim() ? ss_d.off(1, 0) : 0;

    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto scale = CTX_IN_MEM(
            const float *, use_sc ? DNNL_ARG_SCALE : DNNL_ARG_SCALE_SHIFT);
    auto shift = use_sh ? CTX_IN_MEM(const float *, DNNL_ARG_SHIFT)
            : use_ss ? &CTX_IN_MEM(const float *,
                    DNNL_ARG_SCALE_SHIFT)[shift_off]
                     : nullptr;
    auto mean = const_cast<float *>(CTX_IN_MEM(const float *, DNNL_ARG_MEAN));
    auto var
            = const_cast<float *>(CTX_IN_MEM(const float *, DNNL_ARG_VARIANCE));
    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);

    // run single-threaded if the whole problem is no larger than one 4K
    // memory page
    const bool force_sequential
            = pd()->MB() * pd()->C() * pd()->D() * pd()->H() * pd()->W()
            <= 4096;

    parallel(force_sequential ? 1 : 0, [&](const int ithr, const int nthr) {
        bnorm_driver_->exec(ithr, nthr, src, dst, scale, shift, mean, var);
    });

    return status::success;
}

template <cpu_isa_t isa>
jit_uni_batch_normalization_s8_fwd_t<
        isa>::~jit_uni_batch_normalization_s8_fwd_t() {
    delete bnorm_driver_;
}

/* struct instantiation */
template struct jit_uni_batch_normalization_s8_fwd_t<avx512_core>;
template struct jit_uni_batch_normalization_s8_fwd_t<avx2>;
template struct jit_uni_batch_normalization_s8_fwd_t<sse41>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl