/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/math_utils.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/x64/jit_generator.hpp"

#include "cpu/x64/jit_uni_batch_normalization_s8.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace Xbyak;

using data_t = int8_t;

struct call_params_t {
    // keep int sizes at 8 bytes -- jit code expects this
    size_t channel_offt_count, spat_offt_count;
    float eps;
    const float *scale, *shift, *mean, *var;
    const data_t *src, *dst;
};

template <cpu_isa_t isa>
struct jit_bnorm_base_t : public jit_generator {

    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_s8_t)

    using Vmm = typename cpu_isa_traits<isa>::Vmm;
    const AddressFrame &vmmword
            = (isa == sse41) ? xword : ((isa == avx2) ? yword : zword);
    const int vlen = cpu_isa_traits<isa>::vlen;

    const batch_normalization_pd_t *pd_;

    Reg64 reg_param = abi_param1;

    Reg64 reg_scale = rbx;
    Reg64 reg_shift = rdx;
    Reg64 reg_mean = rbp;

    Reg64 reg_channel_offt_count = r8;
    Reg64 reg_spat_offt = r9;
    Reg64 reg_spat_offt_count = r10;
    Reg64 reg_tmp = r11;
    Reg64 reg_src = r12;
    Reg64 reg_dst = r13;
    Reg64 reg_var = r14;
    Reg64 reg_channel_offt_1byte = r15;
    Reg64 reg_channel_offt_4byte = rax;

    Vmm vzero = Vmm(isa == avx512_core ? 29 : 13);
    Xmm xone = Xmm(14);
    Vmm vone = Vmm(isa == avx512_core ? 30 : 14);
    Vmm veps = Vmm(isa == avx512_core ? 31 : 15);

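    // simd_w_ is the number of f32 lanes per full vector; c_in_xmm_ is the
    // number of s8 channels handled per iteration of the channel loop
    // (16 for avx2/avx512_core, 8 for sse41).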
    size_t simd_w_ = cpu_isa_traits<isa>::vlen / sizeof(float);
    size_t c_in_xmm_ = (isa == sse41) ? 8 : 16;
    size_t chan_data_offt_;
    size_t num_c_blocks_;
    size_t c_tail_;
    bool with_relu_;

    void compute_predefined_variables() {
        chan_data_offt_ = pd_->C() * sizeof(float);
        num_c_blocks_ = pd_->C() / c_in_xmm_;
        c_tail_ = pd_->C() % c_in_xmm_;
        with_relu_ = (pd_->with_relu_post_op() || pd_->fuse_norm_relu())
                && pd_->is_fwd();
    }

    void load_common_params() {
        mov(reg_tmp, float2int(1.0f));
        uni_vmovq(xone, reg_tmp);
        uni_vbroadcastss(vone, xone);

#define PARAM_OFF(x) offsetof(call_params_t, x)
        uni_vbroadcastss(veps, vmmword[reg_param + PARAM_OFF(eps)]);
        uni_vpxor(vzero, vzero, vzero);

        mov(reg_channel_offt_count,
                ptr[reg_param + PARAM_OFF(channel_offt_count)]);
        mov(reg_spat_offt_count, ptr[reg_param + PARAM_OFF(spat_offt_count)]);
        mov(reg_src, ptr[reg_param + PARAM_OFF(src)]);
        mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]);
        mov(reg_mean, ptr[reg_param + PARAM_OFF(mean)]);
        mov(reg_scale, ptr[reg_param + PARAM_OFF(scale)]);
        mov(reg_shift, ptr[reg_param + PARAM_OFF(shift)]);
        mov(reg_var, ptr[reg_param + PARAM_OFF(var)]);
#undef PARAM_OFF
    }

    Address mean_ptr(size_t offt = 0) {
        return vmmword[reg_mean + reg_channel_offt_4byte + offt];
    }

    Address var_ptr(size_t offt = 0) {
        return vmmword[reg_var + reg_channel_offt_4byte + offt];
    }

    Address scale_ptr(size_t offt = 0) {
        return vmmword[reg_scale + reg_channel_offt_4byte + offt
                + 0 * chan_data_offt_];
    }

    Address shift_ptr(size_t offt = 0) {
        return vmmword[reg_shift + reg_channel_offt_4byte + offt
                + 0 * chan_data_offt_];
    }

    Address src_ptr(size_t offt = 0) {
        return vmmword[reg_src + reg_spat_offt + offt];
    }

    Address dst_ptr(size_t offt = 0) {
        return vmmword[reg_dst + reg_spat_offt + offt];
    }

    virtual void prepare_tail_mask() {}
    virtual void load_mean_and_var(const Vmm &vmean, const Vmm &vsqrtvar,
            size_t offt, bool need_tail) {}
    virtual void load_scale(const Vmm &vscale, size_t offt, bool need_tail) {}
    virtual void load_shift(const Vmm &vshift, size_t offt, bool need_tail) {}
    virtual void compute_dst(bool need_tail) {}

    // Precomputes vscale and vshift for the following computation:
    // `vdst = vscale * vsrc + vshift`
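    // With s = sqrt(var + eps):
    //     vscale = scale / s              (1 / s if scale is not used)
    //     vshift = shift - mean * vscale  (shift taken as 0 if not used)
    // so vscale * src + vshift == scale * (src - mean) / s + shift.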
    void compute_vscaleshift(const Vmm &vscale, const Vmm &vshift,
            const Vmm &vmean, const Vmm &vsqrtvar, size_t offt,
            bool need_tail) {
        load_mean_and_var(vmean, vsqrtvar, offt, need_tail);
        uni_vaddps(vsqrtvar, vsqrtvar, veps);
        uni_vsqrtps(vsqrtvar, vsqrtvar);

        if (pd_->use_scaleshift() || (pd_->use_scale() && pd_->use_shift())) {
            load_scale(vscale, offt, need_tail);
            uni_vdivps(vscale, vscale, vsqrtvar);
            load_shift(vshift, offt, need_tail);
            uni_vfnmadd231ps(vshift, vmean, vscale);
        } else if (pd_->use_scale()) {
            load_scale(vscale, offt, need_tail);
            uni_vdivps(vscale, vscale, vsqrtvar);
            uni_vmulps(vmean, vmean, vscale);
            uni_vsubps(vshift, vzero, vmean, vshift);
        } else if (pd_->use_shift()) {
            uni_vdivps(vscale, vone, vsqrtvar, vscale);
            load_shift(vshift, offt, need_tail);
            uni_vfnmadd231ps(vshift, vmean, vscale);
        } else {
            uni_vdivps(vscale, vone, vsqrtvar, vscale);
            uni_vmulps(vmean, vmean, vscale);
            uni_vsubps(vshift, vzero, vmean, vshift);
        }
    }

    void forward() {
        xor_(reg_channel_offt_1byte, reg_channel_offt_1byte);
        xor_(reg_channel_offt_4byte, reg_channel_offt_4byte);
        mov(reg_tmp, sizeof(data_t) * c_in_xmm_);

        if (num_c_blocks_) compute_dst(false);
        if (c_tail_) compute_dst(true);
    }

    // Either this stub or duplication in each jit_bnorm_s8_t ctor, because
    // the virtual methods involved are not yet defined at the time the base
    // ctor runs.
    void generate() override {
        preamble();
        compute_predefined_variables();
        load_common_params();
        prepare_tail_mask();
        forward();
        postamble();
    }

    jit_bnorm_base_t(const batch_normalization_pd_t *pd) : pd_(pd) {}
};

template <cpu_isa_t isa>
struct jit_bnorm_s8_t;

template <>
struct jit_bnorm_s8_t<avx512_core> : public jit_bnorm_base_t<avx512_core> {
    Opmask tail_opmask = Opmask(1); // f32 mask for channel math

    void prepare_tail_mask() override {
        if (!c_tail_) return;

        const int mask_f32 = (1 << c_tail_) - 1;

        Reg32 regw_tmp = reg_tmp.cvt32();
        mov(regw_tmp, mask_f32);
        kmovw(tail_opmask, regw_tmp);
    }

    void load_mean_and_var(const Vmm &vmean, const Vmm &vsqrtvar, size_t offt,
            bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vmean, tail_opmask, mean_ptr(offt));
            uni_vmovups_tail(vsqrtvar, tail_opmask, var_ptr(offt));
        } else {
            uni_vmovups(vmean, mean_ptr(offt));
            uni_vmovups(vsqrtvar, var_ptr(offt));
        }
    }

    void load_scale(const Vmm &vscale, size_t offt, bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vscale, tail_opmask, scale_ptr(offt));
        } else {
            uni_vmovups(vscale, scale_ptr(offt));
        }
    }

    void load_shift(const Vmm &vshift, size_t offt, bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vshift, tail_opmask, shift_ptr(offt));
        } else {
            uni_vmovups(vshift, shift_ptr(offt));
        }
    }

    void compute_dst(bool need_tail) override {
        Label c_loop;
        L(c_loop);
        {
            Xmm x = Xmm(0);
            Vmm v = Vmm(0);
            Vmm vscale = Vmm(1);
            Vmm vshift = Vmm(2);
            Vmm vmean = Vmm(3);
            Vmm vsqrtvar = Vmm(4);

            // compute a single pair of vscale and vshift vectors...
            compute_vscaleshift(vscale, vshift, vmean, vsqrtvar, 0, need_tail);

            // ... then run the whole spatial loop with them and move on to
            // the next channel chunk
            mov(reg_spat_offt, reg_channel_offt_1byte);
            Label mb_sp_loop;
            L(mb_sp_loop);
            {
                if (need_tail) {
                    for (size_t tl = 0; tl < c_tail_; tl++)
                        vpinsrb(x, x, src_ptr(tl), tl);
                    vpmovsxbd(v, x);
                } else
                    vpmovsxbd(v, src_ptr());

                vcvtdq2ps(v, v);

                uni_vfmadd213ps(v, vscale, vshift);
                if (with_relu_) uni_vmaxps(v, v, vzero);

                vcvtps2dq(v, v);
                if (need_tail) {
                    vpmovsdb(x, v);
                    for (size_t tl = 0; tl < c_tail_; tl++)
                        vpextrb(dst_ptr(tl), x, tl);
                } else
                    vpmovsdb(dst_ptr(), v);

                add(reg_spat_offt, reg_channel_offt_count);
                cmp(reg_spat_offt, reg_spat_offt_count);
                jl(mb_sp_loop);
            }

            // reg_tmp runs c_in_xmm_ channels ahead so the loop stops before
            // the tail, which is processed separately
            add(reg_tmp, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_1byte, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_4byte, sizeof(float) * c_in_xmm_);
            cmp(reg_tmp, reg_channel_offt_count);
            jle(c_loop);
        }
    }

    jit_bnorm_s8_t(const batch_normalization_pd_t *pd)
        : jit_bnorm_base_t<avx512_core>(pd) {}
};

template <>
struct jit_bnorm_s8_t<avx2> : public jit_bnorm_base_t<avx2> {
    Vmm tail_vmask = Vmm(11);
    Vmm body_vmask = Vmm(12);

    void prepare_tail_mask() override {
        // the tail is always < 16 channels; process it in two halves
        static const uint32_t mask_half_ymm[8]
                = {0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0, 0, 0, 0};
        mov(reg_tmp, reinterpret_cast<size_t>(&mask_half_ymm[0]));
        vmovups(body_vmask, ptr[reg_tmp]);
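        // body_vmask selects the low 128 bits of a ymm; it is used below for
        // the masked store of the packed s8 results, where only the low half
        // of the register holds valid data.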

        if (!c_tail_) return;

        static const uint32_t mask_f32[14]
                = {0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
                        0xffffffff, 0xffffffff, 0, 0, 0, 0, 0, 0, 0};

        mov(reg_tmp,
                reinterpret_cast<size_t>(&mask_f32[7 - c_tail_ % simd_w_]));
        vmovups(tail_vmask, ptr[reg_tmp]);
    }

    void load_mean_and_var(const Vmm &vmean, const Vmm &vsqrtvar, size_t offt,
            bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vmean, tail_vmask, mean_ptr(offt));
            uni_vmovups_tail(vsqrtvar, tail_vmask, var_ptr(offt));
        } else {
            uni_vmovups(vmean, mean_ptr(offt));
            uni_vmovups(vsqrtvar, var_ptr(offt));
        }
    }

    void load_scale(const Vmm &vscale, size_t offt, bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vscale, tail_vmask, scale_ptr(offt));
        } else {
            uni_vmovups(vscale, scale_ptr(offt));
        }
    }

    void load_shift(const Vmm &vshift, size_t offt, bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vshift, tail_vmask, shift_ptr(offt));
        } else {
            uni_vmovups(vshift, shift_ptr(offt));
        }
    }

    void compute_dst(bool need_tail) override {
        Label c_loop;
        L(c_loop);
        {

            Xmm x0 = Xmm(0);
            Vmm v0 = Vmm(0);
            Xmm x1 = Xmm(1);
            Vmm v1 = Vmm(1);
            Vmm vscale0 = Vmm(2);
            Vmm vshift0 = Vmm(3);
            Vmm vmean0 = Vmm(4);
            Vmm vsqrtvar0 = Vmm(5);
            Vmm vscale1 = Vmm(6);
            Vmm vshift1 = Vmm(7);
            Vmm vmean1 = Vmm(8);
            Vmm vsqrtvar1 = Vmm(9);

            // compute a pair of vscale and vshift vectors, simd_w_ channels
            // each...
            compute_vscaleshift(vscale0, vshift0, vmean0, vsqrtvar0, 0,
                    (c_tail_ < simd_w_ && need_tail) ? true : false);
            if (!need_tail || c_tail_ > simd_w_) {
                compute_vscaleshift(vscale1, vshift1, vmean1, vsqrtvar1,
                        simd_w_ * sizeof(float), need_tail);
            }

            // ... then run the whole spatial loop with them and move on to
            // the next channel chunk
            mov(reg_spat_offt, reg_channel_offt_1byte);
            Label mb_sp_loop;
            L(mb_sp_loop);
            {

                if (need_tail) {
                    for (size_t tl = 0; tl < c_tail_; tl++) {
                        if (tl < simd_w_) {
                            vpinsrb(x0, x0, src_ptr(tl), tl);
                        } else {
                            vpinsrb(x1, x1, src_ptr(tl), tl - simd_w_);
                        }
                    }
                    vpmovsxbd(v0, x0);
                    vpmovsxbd(v1, x1);
                } else {
                    vpmovsxbd(v0, src_ptr());
                    vpmovsxbd(v1, src_ptr(simd_w_));
                }

                vcvtdq2ps(v0, v0);
                vcvtdq2ps(v1, v1);

                uni_vfmadd213ps(v0, vscale0, vshift0);
                uni_vfmadd213ps(v1, vscale1, vshift1);
                if (with_relu_) {
                    uni_vmaxps(v0, v0, vzero);
                    uni_vmaxps(v1, v1, vzero);
                }

                vcvtps2dq(v0, v0); // BA
                vcvtps2dq(v1, v1); // DC
                vpackssdw(v0, v0, v1); // BA + DC -> DBCA
                vpermq(v0, v0, 0xD8); // DBCA -> DCBA
                vperm2i128(v1, v0, v0, 0x1); // DCBA -> BADC
                vpacksswb(v0, v0, v1); // DCBA + BADC -> badcDCBA
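                // at this point the 16 saturated s8 results for channels
                // 0..15 sit in the low 128 bits of v0 in channel order; the
                // upper half is a shuffled duplicate dropped by the store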

                if (need_tail) {
                    for (size_t tl = 0; tl < c_tail_; tl++) {
                        vpextrb(dst_ptr(tl), x0, tl);
                    }
                } else {
                    // vpacksswb produces 32 bytes in the ymm, but the top
                    // half of them is garbage, so do a 128-bit masked store
                    vmaskmovps(dst_ptr(), body_vmask, v0);
                }

                add(reg_spat_offt, reg_channel_offt_count);
                cmp(reg_spat_offt, reg_spat_offt_count);
                jl(mb_sp_loop);
            }

            // reg_tmp runs c_in_xmm_ channels ahead so the loop stops before
            // the tail, which is processed separately
            add(reg_tmp, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_1byte, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_4byte, sizeof(float) * c_in_xmm_);
            cmp(reg_tmp, reg_channel_offt_count);
            jle(c_loop);
        }
    }

    jit_bnorm_s8_t(const batch_normalization_pd_t *pd)
        : jit_bnorm_base_t<avx2>(pd) {}
};

template <>
struct jit_bnorm_s8_t<sse41> : public jit_bnorm_base_t<sse41> {
    void load_mean_and_var(const Vmm &vmean, const Vmm &vsqrtvar, size_t offt,
            bool need_tail) override {
        if (need_tail) {
            for (size_t tl = 0; tl < c_tail_ % simd_w_; tl++) {
                pinsrd(vmean, mean_ptr(offt + tl * sizeof(float)), tl);
                pinsrd(vsqrtvar, var_ptr(offt + tl * sizeof(float)), tl);
            }
        } else {
            movups(vmean, mean_ptr(offt));
            movups(vsqrtvar, var_ptr(offt));
        }
    }

    void load_scale(const Vmm &vscale, size_t offt, bool need_tail) override {
        if (need_tail) {
            for (size_t tl = 0; tl < c_tail_ % simd_w_; tl++) {
                pinsrd(vscale, scale_ptr(offt + tl * sizeof(float)), tl);
            }
        } else {
            movups(vscale, scale_ptr(offt));
        }
    }

    void load_shift(const Vmm &vshift, size_t offt, bool need_tail) override {
        if (need_tail) {
            for (size_t tl = 0; tl < c_tail_ % simd_w_; tl++) {
                pinsrd(vshift, shift_ptr(offt + tl * sizeof(float)), tl);
            }
        } else {
            movups(vshift, shift_ptr(offt));
        }
    }

    void compute_dst(bool need_tail) override {
        const size_t copy_range = need_tail ? c_tail_ : c_in_xmm_;
        Label c_loop;
        L(c_loop);
        {

            Vmm v0 = Vmm(0);
            Vmm v1 = Vmm(1);
            Vmm vscale0 = Vmm(2);
            Vmm vshift0 = Vmm(3);
            Vmm vmean0 = Vmm(4);
            Vmm vsqrtvar0 = Vmm(5);
            Vmm vscale1 = Vmm(6);
            Vmm vshift1 = Vmm(7);
            Vmm vmean1 = Vmm(8);
            Vmm vsqrtvar1 = Vmm(9);

            // compute a pair of vscale and vshift vectors, simd_w_ channels
            // each...
            compute_vscaleshift(vscale0, vshift0, vmean0, vsqrtvar0, 0,
                    (c_tail_ < simd_w_ && need_tail) ? true : false);
            if (!need_tail || c_tail_ > simd_w_) {
                compute_vscaleshift(vscale1, vshift1, vmean1, vsqrtvar1,
                        simd_w_ * sizeof(float), need_tail);
            }

            // ... then run the whole spatial loop with them and move on to
            // the next channel chunk
            mov(reg_spat_offt, reg_channel_offt_1byte);
            Label mb_sp_loop;
            L(mb_sp_loop);
            {
                if (need_tail) {
                    for (size_t tl = 0; tl < copy_range; tl++) {
                        if (tl < simd_w_) {
                            pinsrb(v0, src_ptr(tl), tl);
                        } else {
                            pinsrb(v1, src_ptr(tl), (tl - simd_w_));
                        }
                    }
                    pmovsxbd(v0, v0);
                    pmovsxbd(v1, v1);
                } else {
                    pmovsxbd(v0, src_ptr());
                    pmovsxbd(v1, src_ptr(simd_w_));
                }

                cvtdq2ps(v0, v0);
                cvtdq2ps(v1, v1);

                uni_vfmadd213ps(v0, vscale0, vshift0);
                uni_vfmadd213ps(v1, vscale1, vshift1);
                if (with_relu_) {
                    maxps(v0, vzero);
                    maxps(v1, vzero);
                }

                cvtps2dq(v0, v0);
                cvtps2dq(v1, v1);
                packssdw(v0, v1);
                movups(v1, v0);
                packsswb(v0, v1);

                // A potential perf gain is possible by combining the two
                // halves into a single vector register and using movups
                // instead of byte stores.
                for (size_t tl = 0; tl < copy_range; tl++) {
                    pextrb(dst_ptr(tl), v0, tl);
                }

                add(reg_spat_offt, reg_channel_offt_count);
                cmp(reg_spat_offt, reg_spat_offt_count);
                jl(mb_sp_loop);
            }

            // reg_tmp runs c_in_xmm_ channels ahead so the loop stops before
            // the tail, which is processed separately
            add(reg_tmp, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_1byte, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_4byte, sizeof(float) * c_in_xmm_);
            cmp(reg_tmp, reg_channel_offt_count);
            jle(c_loop);
        }
    }

    jit_bnorm_s8_t(const batch_normalization_pd_t *pd)
        : jit_bnorm_base_t<sse41>(pd) {}
};

namespace bnorm_s8_impl {

template <cpu_isa_t isa>
struct driver_t : public c_compatible {
    driver_t(const batch_normalization_pd_t *pd) : pd_(pd), ker_(pd_) {}
    ~driver_t() = default;

    // TODO: for problems where thread pieces don't fit L2 cache, add spatial
    // re-balance using fewer pieces.
    void exec(int ithr, int nthr, const data_t *src, data_t *dst,
            const float *scale, const float *shift, const float *mean,
            const float *var) {
        dim_t N = pd_->MB();
        dim_t C = pd_->C();
        dim_t D = pd_->D();
        dim_t H = pd_->H();
        dim_t W = pd_->W();
        dim_t SP = D * H * W;

        call_params_t p;

        p.eps = pd_->desc()->batch_norm_epsilon;

        p.scale = scale;
        p.shift = shift;
        p.mean = mean;
        p.var = var;

        dim_t work_amount {N * SP}, start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);
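        // each thread gets a contiguous range [start, end) of the N * SP
        // spatial points; every point is a dense row of C s8 channels (nhwc),
        // so the offsets below are expressed in multiples of C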

        p.channel_offt_count = C;
        p.spat_offt_count = (end - start) * p.channel_offt_count;
        p.src = src + start * p.channel_offt_count;
        p.dst = dst + start * p.channel_offt_count;

        if (p.spat_offt_count != 0) ker_(&p);
    }

    status_t create_kernel() { return ker_.create_kernel(); }

private:
    const batch_normalization_pd_t *pd_;

    jit_bnorm_s8_t<isa> ker_;
};

} // namespace bnorm_s8_impl

using namespace data_type;
using namespace format_tag;
using namespace utils;

/* fwd */

template <cpu_isa_t isa>
status_t jit_uni_batch_normalization_s8_fwd_t<isa>::pd_t::init(
        engine_t *engine) {
    auto desired_fmt_tag = (ndims() == 4) ? nhwc : ndhwc;

    bool ok = true && mayiuse(isa) && is_fwd() && !has_zero_dim_memory()
            && one_of(ndims(), 4, 5) && stats_is_src()
            && src_md()->data_type == s8 && check_scale_shift_data_type()
            && memory_desc_matches_tag(*src_md(), desired_fmt_tag)
            && (attr()->has_default_values() || this->with_relu_post_op());
    if (!ok) return status::unimplemented;

    return status::success;
}

template <cpu_isa_t isa>
jit_uni_batch_normalization_s8_fwd_t<isa>::jit_uni_batch_normalization_s8_fwd_t(
        const pd_t *apd)
    : primitive_t(apd) {}

template <cpu_isa_t isa>
status_t jit_uni_batch_normalization_s8_fwd_t<isa>::init(engine_t *engine) {
    CHECK(safe_ptr_assign(
            bnorm_driver_, new bnorm_s8_impl::driver_t<isa>(pd())));
    return bnorm_driver_->create_kernel();
}

template <cpu_isa_t isa>
status_t jit_uni_batch_normalization_s8_fwd_t<isa>::execute(
        const exec_ctx_t &ctx) const {

    const memory_desc_wrapper ss_d(pd()->weights_md());

    const auto use_ss = pd()->use_scaleshift();
    const auto use_sc = pd()->use_scale();
    const auto use_sh = pd()->use_shift();

    const size_t shift_off
            = use_ss && !ss_d.has_zero_dim() ? ss_d.off(1, 0) : 0;
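    // with combined scale_shift weights, scale occupies the first row of the
    // weights tensor and shift the second, so ss_d.off(1, 0) is the offset of
    // the first shift element within the same buffer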

    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto scale = CTX_IN_MEM(
            const float *, use_sc ? DNNL_ARG_SCALE : DNNL_ARG_SCALE_SHIFT);
    auto shift = use_sh ? CTX_IN_MEM(const float *, DNNL_ARG_SHIFT)
                        : use_ss ? &CTX_IN_MEM(const float *,
                                  DNNL_ARG_SCALE_SHIFT)[shift_off]
                                 : nullptr;
    auto mean = const_cast<float *>(CTX_IN_MEM(const float *, DNNL_ARG_MEAN));
    auto var
            = const_cast<float *>(CTX_IN_MEM(const float *, DNNL_ARG_VARIANCE));
    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);

    // run sequentially if the whole problem is no larger than one 4K memory
    // page
    const bool force_sequential
            = pd()->MB() * pd()->C() * pd()->D() * pd()->H() * pd()->W()
            <= 4096;

    parallel(force_sequential ? 1 : 0, [&](const int ithr, const int nthr) {
        bnorm_driver_->exec(ithr, nthr, src, dst, scale, shift, mean, var);
    });

    return status::success;
}

template <cpu_isa_t isa>
jit_uni_batch_normalization_s8_fwd_t<
        isa>::~jit_uni_batch_normalization_s8_fwd_t() {
    delete bnorm_driver_;
}

/* struct instantiation */
template struct jit_uni_batch_normalization_s8_fwd_t<avx512_core>;
template struct jit_uni_batch_normalization_s8_fwd_t<avx2>;
template struct jit_uni_batch_normalization_s8_fwd_t<sse41>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl