/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/math_utils.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/x64/jit_generator.hpp"

#include "cpu/x64/jit_uni_batch_normalization_s8.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

namespace {

using namespace Xbyak;

using data_t = int8_t;

struct call_params_t {
    // keep int sizes at 8 bytes -- jit code expects this
    size_t channel_offt_count, spat_offt_count;
    float eps;
    const float *scale, *shift, *mean, *var;
    const data_t *src, *dst;
};

template <cpu_isa_t isa>
struct jit_bnorm_base_t : public jit_generator {

    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_t)

    using Vmm = typename cpu_isa_traits<isa>::Vmm;
    const AddressFrame &vmmword
            = (isa == sse41) ? xword : ((isa == avx2) ? yword : zword);
    const int vlen = cpu_isa_traits<isa>::vlen;

    const batch_normalization_pd_t *pd_;

    Reg64 reg_param = abi_param1;

    Reg64 reg_scale = rbx;
    Reg64 reg_shift = rdx;
    Reg64 reg_mean = rbp;

    Reg64 reg_channel_offt_count = r8;
    Reg64 reg_spat_offt = r9;
    Reg64 reg_spat_offt_count = r10;
    Reg64 reg_tmp = r11;
    Reg64 reg_src = r12;
    Reg64 reg_dst = r13;
    Reg64 reg_var = r14;
    Reg64 reg_channel_offt_1byte = r15;
    Reg64 reg_channel_offt_4byte = rax;

    Vmm vzero = Vmm(isa == avx512_core ? 29 : 13);
    Xmm xone = Xmm(14);
    Vmm vone = Vmm(isa == avx512_core ? 30 : 14);
    Vmm veps = Vmm(isa == avx512_core ? 31 : 15);

    size_t simd_w_ = cpu_isa_traits<isa>::vlen / sizeof(float);
    size_t c_in_xmm_ = (isa == sse41) ? 8 : 16;
    size_t chan_data_offt_;
    size_t num_c_blocks_;
    size_t c_tail_;
    bool with_relu_;

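    // Channels are processed in chunks of c_in_xmm_ int8 values; the
    // remainder is handled by the tail path. For example (numbers here are
    // only illustrative), C = 20 with c_in_xmm_ = 16 gives num_c_blocks_ = 1
    // and c_tail_ = 4: one full 16-channel block plus a 4-channel tail.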
    void compute_predefined_variables() {
        chan_data_offt_ = pd_->C() * sizeof(float);
        num_c_blocks_ = pd_->C() / c_in_xmm_;
        c_tail_ = pd_->C() % c_in_xmm_;
        with_relu_ = (pd_->with_relu_post_op() || pd_->fuse_norm_relu())
                && pd_->is_fwd();
    }

    void load_common_params() {
        mov(reg_tmp, float2int(1.0f));
        uni_vmovq(xone, reg_tmp);
        uni_vbroadcastss(vone, xone);

#define PARAM_OFF(x) offsetof(call_params_t, x)
        uni_vbroadcastss(veps, vmmword[reg_param + PARAM_OFF(eps)]);
        uni_vpxor(vzero, vzero, vzero);

        mov(reg_channel_offt_count,
                ptr[reg_param + PARAM_OFF(channel_offt_count)]);
        mov(reg_spat_offt_count, ptr[reg_param + PARAM_OFF(spat_offt_count)]);
        mov(reg_src, ptr[reg_param + PARAM_OFF(src)]);
        mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]);
        mov(reg_mean, ptr[reg_param + PARAM_OFF(mean)]);
        mov(reg_scale, ptr[reg_param + PARAM_OFF(scale)]);
        mov(reg_shift, ptr[reg_param + PARAM_OFF(shift)]);
        mov(reg_var, ptr[reg_param + PARAM_OFF(var)]);
#undef PARAM_OFF
    }

    Address mean_ptr(size_t offt = 0) {
        return vmmword[reg_mean + reg_channel_offt_4byte + offt];
    }

    Address var_ptr(size_t offt = 0) {
        return vmmword[reg_var + reg_channel_offt_4byte + offt];
    }

    Address scale_ptr(size_t offt = 0) {
        return vmmword[reg_scale + reg_channel_offt_4byte + offt
                + 0 * chan_data_offt_];
    }

    Address shift_ptr(size_t offt = 0) {
        return vmmword[reg_shift + reg_channel_offt_4byte + offt
                + 0 * chan_data_offt_];
    }

    Address src_ptr(size_t offt = 0) {
        return vmmword[reg_src + reg_spat_offt + offt];
    }

    Address dst_ptr(size_t offt = 0) {
        return vmmword[reg_dst + reg_spat_offt + offt];
    }

    virtual void prepare_tail_mask() {}
    virtual void load_mean_and_var(const Vmm &vmean, const Vmm &vsqrtvar,
            size_t offt, bool need_tail) {}
    virtual void load_scale(const Vmm &vscale, size_t offt, bool need_tail) {}
    virtual void load_shift(const Vmm &vshift, size_t offt, bool need_tail) {}
    virtual void compute_dst(bool need_tail) {}

    // Precomputes vscale and vshift for the following computation:
    // `vdst = vscale * vsrc + vshift`
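    // The batch normalization formula
    //     dst = scale * (src - mean) / sqrt(var + eps) + shift
    // is folded into this per-channel affine form with
    //     vscale = scale / sqrt(var + eps)
    //     vshift = shift - mean * vscale
    // (when scale and/or shift are not used they default to 1 and 0, which is
    // what the branches below implement).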
    void compute_vscaleshift(const Vmm &vscale, const Vmm &vshift,
            const Vmm &vmean, const Vmm &vsqrtvar, size_t offt,
            bool need_tail) {
        load_mean_and_var(vmean, vsqrtvar, offt, need_tail);
        uni_vaddps(vsqrtvar, vsqrtvar, veps);
        uni_vsqrtps(vsqrtvar, vsqrtvar);

        if (pd_->use_scaleshift() || (pd_->use_scale() && pd_->use_shift())) {
            load_scale(vscale, offt, need_tail);
            uni_vdivps(vscale, vscale, vsqrtvar);
            load_shift(vshift, offt, need_tail);
            uni_vfnmadd231ps(vshift, vmean, vscale);
        } else if (pd_->use_scale()) {
            load_scale(vscale, offt, need_tail);
            uni_vdivps(vscale, vscale, vsqrtvar);
            uni_vmulps(vmean, vmean, vscale);
            uni_vsubps(vshift, vzero, vmean, vshift);
        } else if (pd_->use_shift()) {
            uni_vdivps(vscale, vone, vsqrtvar, vscale);
            load_shift(vshift, offt, need_tail);
            uni_vfnmadd231ps(vshift, vmean, vscale);
        } else {
            uni_vdivps(vscale, vone, vsqrtvar, vscale);
            uni_vmulps(vmean, vmean, vscale);
            uni_vsubps(vshift, vzero, vmean, vshift);
        }
    }

    void forward() {
        xor_(reg_channel_offt_1byte, reg_channel_offt_1byte);
        xor_(reg_channel_offt_4byte, reg_channel_offt_4byte);
        mov(reg_tmp, sizeof(data_t) * c_in_xmm_);

        if (num_c_blocks_) compute_dst(false);
        if (c_tail_) compute_dst(true);
    }

    // This stub lives in the base class; the alternative is duplicating
    // generate() in every jit_bnorm_t ctor, because the virtual methods it
    // calls are not yet defined at the moment of base ctor initialization.
    void generate() override {
        preamble();
        compute_predefined_variables();
        load_common_params();
        prepare_tail_mask();
        forward();
        postamble();
    }

    jit_bnorm_base_t(const batch_normalization_pd_t *pd) : pd_(pd) {}
};

template <cpu_isa_t isa>
struct jit_bnorm_t;

template <>
struct jit_bnorm_t<avx512_core> : public jit_bnorm_base_t<avx512_core> {
    Opmask tail_opmask = Opmask(1); // f32 mask for channel math

    void prepare_tail_mask() override {
        if (!c_tail_) return;

        const int mask_f32 = (1 << c_tail_) - 1;
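        // e.g., c_tail_ = 3 gives mask_f32 = 0b111, so only the three lowest
        // f32 lanes participate in the masked loads/stores below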

        Reg32 regw_tmp = reg_tmp.cvt32();
        mov(regw_tmp, mask_f32);
        kmovw(tail_opmask, regw_tmp);
    }

    void load_mean_and_var(const Vmm &vmean, const Vmm &vsqrtvar, size_t offt,
            bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vmean, tail_opmask, mean_ptr(offt));
            uni_vmovups_tail(vsqrtvar, tail_opmask, var_ptr(offt));
        } else {
            uni_vmovups(vmean, mean_ptr(offt));
            uni_vmovups(vsqrtvar, var_ptr(offt));
        }
    }

    void load_scale(const Vmm &vscale, size_t offt, bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vscale, tail_opmask, scale_ptr(offt));
        } else {
            uni_vmovups(vscale, scale_ptr(offt));
        }
    }

    void load_shift(const Vmm &vshift, size_t offt, bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vshift, tail_opmask, shift_ptr(offt));
        } else {
            uni_vmovups(vshift, shift_ptr(offt));
        }
    }

    void compute_dst(bool need_tail) override {
        Label c_loop;
        L(c_loop);
        {
            Xmm x = Xmm(0);
            Vmm v = Vmm(0);
            Vmm vscale = Vmm(1);
            Vmm vshift = Vmm(2);
            Vmm vmean = Vmm(3);
            Vmm vsqrtvar = Vmm(4);

            // compute a single pair of vscale and vshift vectors...
            compute_vscaleshift(vscale, vshift, vmean, vsqrtvar, 0, need_tail);

            // ... then run the whole spatial loop with them and move to the
            // next channel chunk
            mov(reg_spat_offt, reg_channel_offt_1byte);
            Label mb_sp_loop;
            L(mb_sp_loop);
            {
                if (need_tail) {
                    for (size_t tl = 0; tl < c_tail_; tl++)
                        vpinsrb(x, x, src_ptr(tl), tl);
                    vpmovsxbd(v, x);
                } else
                    vpmovsxbd(v, src_ptr());

                vcvtdq2ps(v, v);

                uni_vfmadd213ps(v, vscale, vshift);
                if (with_relu_) uni_vmaxps(v, v, vzero);

                vcvtps2dq(v, v);
                if (need_tail) {
                    vpmovsdb(x, v);
                    for (size_t tl = 0; tl < c_tail_; tl++)
                        vpextrb(dst_ptr(tl), x, tl);
                } else
                    vpmovsdb(dst_ptr(), v);

                add(reg_spat_offt, reg_channel_offt_count);
                cmp(reg_spat_offt, reg_spat_offt_count);
                jl(mb_sp_loop);
            }

            // reg_tmp runs c_in_xmm_ channels ahead of the current offset so
            // that the loop stops before the channel tail, which is processed
            // separately
            add(reg_tmp, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_1byte, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_4byte, sizeof(float) * c_in_xmm_);
            cmp(reg_tmp, reg_channel_offt_count);
            jle(c_loop);
        }
    }

    jit_bnorm_t(const batch_normalization_pd_t *pd)
        : jit_bnorm_base_t<avx512_core>(pd) {}
};

template <>
struct jit_bnorm_t<avx2> : public jit_bnorm_base_t<avx2> {
    Vmm tail_vmask = Vmm(11);
    Vmm body_vmask = Vmm(12);

    void prepare_tail_mask() override {
        // the tail is always < 16 channels; process it in two 8-lane halves
        static const uint32_t mask_half_ymm[8]
                = {0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0, 0, 0, 0};
        mov(reg_tmp, reinterpret_cast<size_t>(&mask_half_ymm[0]));
        vmovups(body_vmask, ptr[reg_tmp]);

        if (!c_tail_) return;

        static const uint32_t mask_f32[14]
                = {0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
                        0xffffffff, 0xffffffff, 0, 0, 0, 0, 0, 0, 0};

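        // &mask_f32[7 - k] (with k = c_tail_ % simd_w_) points at k ones
        // followed by zeros, so the masked load below covers exactly the k
        // tail lanes; e.g., k = 3 selects the three lowest f32 lanes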
        mov(reg_tmp,
                reinterpret_cast<size_t>(&mask_f32[7 - c_tail_ % simd_w_]));
        vmovups(tail_vmask, ptr[reg_tmp]);
    }

    void load_mean_and_var(const Vmm &vmean, const Vmm &vsqrtvar, size_t offt,
            bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vmean, tail_vmask, mean_ptr(offt));
            uni_vmovups_tail(vsqrtvar, tail_vmask, var_ptr(offt));
        } else {
            uni_vmovups(vmean, mean_ptr(offt));
            uni_vmovups(vsqrtvar, var_ptr(offt));
        }
    }

    void load_scale(const Vmm &vscale, size_t offt, bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vscale, tail_vmask, scale_ptr(offt));
        } else {
            uni_vmovups(vscale, scale_ptr(offt));
        }
    }

    void load_shift(const Vmm &vshift, size_t offt, bool need_tail) override {
        if (need_tail) {
            uni_vmovups_tail(vshift, tail_vmask, shift_ptr(offt));
        } else {
            uni_vmovups(vshift, shift_ptr(offt));
        }
    }

    void compute_dst(bool need_tail) override {
        Label c_loop;
        L(c_loop);
        {

            Xmm x0 = Xmm(0);
            Vmm v0 = Vmm(0);
            Xmm x1 = Xmm(1);
            Vmm v1 = Vmm(1);
            Vmm vscale0 = Vmm(2);
            Vmm vshift0 = Vmm(3);
            Vmm vmean0 = Vmm(4);
            Vmm vsqrtvar0 = Vmm(5);
            Vmm vscale1 = Vmm(6);
            Vmm vshift1 = Vmm(7);
            Vmm vmean1 = Vmm(8);
            Vmm vsqrtvar1 = Vmm(9);

            // compute a pair of vscale and vshift vectors, each covering
            // simd_w_ channels...
            compute_vscaleshift(vscale0, vshift0, vmean0, vsqrtvar0, 0,
                    (c_tail_ < simd_w_ && need_tail) ? true : false);
            if (!need_tail || c_tail_ > simd_w_) {
                compute_vscaleshift(vscale1, vshift1, vmean1, vsqrtvar1,
                        simd_w_ * sizeof(float), need_tail);
            }

            // ... then run the whole spatial loop with them and move to the
            // next channel chunk
            mov(reg_spat_offt, reg_channel_offt_1byte);
            Label mb_sp_loop;
            L(mb_sp_loop);
            {

                if (need_tail) {
                    for (size_t tl = 0; tl < c_tail_; tl++) {
                        if (tl < simd_w_) {
                            vpinsrb(x0, x0, src_ptr(tl), tl);
                        } else {
                            vpinsrb(x1, x1, src_ptr(tl), tl - simd_w_);
                        }
                    }
                    vpmovsxbd(v0, x0);
                    vpmovsxbd(v1, x1);
                } else {
                    vpmovsxbd(v0, src_ptr());
                    vpmovsxbd(v1, src_ptr(simd_w_));
                }

                vcvtdq2ps(v0, v0);
                vcvtdq2ps(v1, v1);

                uni_vfmadd213ps(v0, vscale0, vshift0);
                uni_vfmadd213ps(v1, vscale1, vshift1);
                if (with_relu_) {
                    uni_vmaxps(v0, v0, vzero);
                    uni_vmaxps(v1, v1, vzero);
                }

                vcvtps2dq(v0, v0); // BA
                vcvtps2dq(v1, v1); // DC
                vpackssdw(v0, v0, v1); // BA + DC -> DBCA
                vpermq(v0, v0, 0xD8); // DBCA -> DCBA
                vperm2i128(v1, v0, v0, 0x1); // DCBA -> BADC
                vpacksswb(v0, v0, v1); // DCBA + BADC -> badcDCBA
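                // after this pack/permute sequence the low 128 bits of v0
                // hold the 16 saturated s8 results in channel order; the
                // upper half must not be stored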

                if (need_tail) {
                    for (size_t tl = 0; tl < c_tail_; tl++) {
                        vpextrb(dst_ptr(tl), x0, tl);
                    }
                } else {
                    // since vpacksswb produces 32 int8 values in the ymm and
                    // its top half is garbage, do a 128-bit masked store
                    vmaskmovps(dst_ptr(), body_vmask, v0);
                }

                add(reg_spat_offt, reg_channel_offt_count);
                cmp(reg_spat_offt, reg_spat_offt_count);
                jl(mb_sp_loop);
            }

            // reg_tmp runs c_in_xmm_ channels ahead of the current offset so
            // that the loop stops before the channel tail, which is processed
            // separately
            add(reg_tmp, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_1byte, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_4byte, sizeof(float) * c_in_xmm_);
            cmp(reg_tmp, reg_channel_offt_count);
            jle(c_loop);
        }
    }

    jit_bnorm_t(const batch_normalization_pd_t *pd)
        : jit_bnorm_base_t<avx2>(pd) {}
};

template <>
struct jit_bnorm_t<sse41> : public jit_bnorm_base_t<sse41> {
    void load_mean_and_var(const Vmm &vmean, const Vmm &vsqrtvar, size_t offt,
            bool need_tail) override {
        if (need_tail) {
            for (size_t tl = 0; tl < c_tail_ % simd_w_; tl++) {
                pinsrd(vmean, mean_ptr(offt + tl * sizeof(float)), tl);
                pinsrd(vsqrtvar, var_ptr(offt + tl * sizeof(float)), tl);
            }
        } else {
            movups(vmean, mean_ptr(offt));
            movups(vsqrtvar, var_ptr(offt));
        }
    }

    void load_scale(const Vmm &vscale, size_t offt, bool need_tail) override {
        if (need_tail) {
            for (size_t tl = 0; tl < c_tail_ % simd_w_; tl++) {
                pinsrd(vscale, scale_ptr(offt + tl * sizeof(float)), tl);
            }
        } else {
            movups(vscale, scale_ptr(offt));
        }
    }

    void load_shift(const Vmm &vshift, size_t offt, bool need_tail) override {
        if (need_tail) {
            for (size_t tl = 0; tl < c_tail_ % simd_w_; tl++) {
                pinsrd(vshift, shift_ptr(offt + tl * sizeof(float)), tl);
            }
        } else {
            movups(vshift, shift_ptr(offt));
        }
    }

    void compute_dst(bool need_tail) override {
        const size_t copy_range = need_tail ? c_tail_ : c_in_xmm_;
        Label c_loop;
        L(c_loop);
        {

            Vmm v0 = Vmm(0);
            Vmm v1 = Vmm(1);
            Vmm vscale0 = Vmm(2);
            Vmm vshift0 = Vmm(3);
            Vmm vmean0 = Vmm(4);
            Vmm vsqrtvar0 = Vmm(5);
            Vmm vscale1 = Vmm(6);
            Vmm vshift1 = Vmm(7);
            Vmm vmean1 = Vmm(8);
            Vmm vsqrtvar1 = Vmm(9);

            // compute a pair of vscale and vshift vectors, each covering
            // simd_w_ channels...
            compute_vscaleshift(vscale0, vshift0, vmean0, vsqrtvar0, 0,
                    (c_tail_ < simd_w_ && need_tail) ? true : false);
            if (!need_tail || c_tail_ > simd_w_) {
                compute_vscaleshift(vscale1, vshift1, vmean1, vsqrtvar1,
                        simd_w_ * sizeof(float), need_tail);
            }

            // ... then run the whole spatial loop with them and move to the
            // next channel chunk
            mov(reg_spat_offt, reg_channel_offt_1byte);
            Label mb_sp_loop;
            L(mb_sp_loop);
            {
                if (need_tail) {
                    for (size_t tl = 0; tl < copy_range; tl++) {
                        if (tl < simd_w_) {
                            pinsrb(v0, src_ptr(tl), tl);
                        } else {
                            pinsrb(v1, src_ptr(tl), (tl - simd_w_));
                        }
                    }
                    pmovsxbd(v0, v0);
                    pmovsxbd(v1, v1);
                } else {
                    pmovsxbd(v0, src_ptr());
                    pmovsxbd(v1, src_ptr(simd_w_));
                }

                cvtdq2ps(v0, v0);
                cvtdq2ps(v1, v1);

                uni_vfmadd213ps(v0, vscale0, vshift0);
                uni_vfmadd213ps(v1, vscale1, vshift1);
                if (with_relu_) {
                    maxps(v0, vzero);
                    maxps(v1, vzero);
                }

                cvtps2dq(v0, v0);
                cvtps2dq(v1, v1);
                packssdw(v0, v1);
                movups(v1, v0);
                packsswb(v0, v1);

                // A potential perf gain is possible by combining the two
                // halves into a single vector register and using movups
                // instead of byte stores.
                for (size_t tl = 0; tl < copy_range; tl++) {
                    pextrb(dst_ptr(tl), v0, tl);
                }

                add(reg_spat_offt, reg_channel_offt_count);
                cmp(reg_spat_offt, reg_spat_offt_count);
                jl(mb_sp_loop);
            }

            // reg_tmp runs c_in_xmm_ channels ahead of the current offset so
            // that the loop stops before the channel tail, which is processed
            // separately
            add(reg_tmp, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_1byte, sizeof(data_t) * c_in_xmm_);
            add(reg_channel_offt_4byte, sizeof(float) * c_in_xmm_);
            cmp(reg_tmp, reg_channel_offt_count);
            jle(c_loop);
        }
    }

    jit_bnorm_t(const batch_normalization_pd_t *pd)
        : jit_bnorm_base_t<sse41>(pd) {}
};

} // namespace

namespace bnorm_s8_impl {

template <cpu_isa_t isa>
struct driver_t : public c_compatible {
    driver_t(const batch_normalization_pd_t *pd) : pd_(pd), ker_(pd_) {}
    ~driver_t() = default;

    // TODO: for problems where the per-thread pieces don't fit the L2 cache,
    // re-balance spatially using fewer pieces.
    void exec(int ithr, int nthr, const data_t *src, data_t *dst,
            const float *scale, const float *shift, const float *mean,
            const float *var) {
        dim_t N = pd_->MB();
        dim_t C = pd_->C();
        dim_t D = pd_->D();
        dim_t H = pd_->H();
        dim_t W = pd_->W();
        dim_t SP = D * H * W;

        call_params_t p;

        p.eps = pd_->desc()->batch_norm_epsilon;

        p.scale = scale;
        p.shift = shift;
        p.mean = mean;
        p.var = var;

        dim_t work_amount {N * SP}, start {0}, end {0};
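        // balance211 splits the N * SP spatial rows as evenly as possible
        // across nthr threads (earlier threads get the extra rows when the
        // split is uneven); each thread then runs the kernel over its
        // contiguous [start, end) row range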
        balance211(work_amount, nthr, ithr, start, end);

        p.channel_offt_count = C;
        p.spat_offt_count = (end - start) * p.channel_offt_count;
        p.src = src + start * p.channel_offt_count;
        p.dst = dst + start * p.channel_offt_count;

        if (p.spat_offt_count != 0) ker_(&p);
    }

    status_t create_kernel() { return ker_.create_kernel(); }

private:
    const batch_normalization_pd_t *pd_;

    jit_bnorm_t<isa> ker_;
};

} // namespace bnorm_s8_impl

using namespace data_type;
using namespace format_tag;
using namespace utils;

/* fwd */

template <cpu_isa_t isa>
status_t jit_uni_batch_normalization_s8_fwd_t<isa>::pd_t::init(
        engine_t *engine) {
    auto desired_fmt_tag = (ndims() == 4) ? nhwc : ndhwc;

    bool ok = true && mayiuse(isa) && is_fwd() && !has_zero_dim_memory()
            && one_of(ndims(), 4, 5) && stats_is_src()
            && src_md()->data_type == s8 && check_scale_shift_data_type()
            && memory_desc_matches_tag(*src_md(), desired_fmt_tag)
            && (attr()->has_default_values() || this->with_relu_post_op());
    if (!ok) return status::unimplemented;

    return status::success;
}

template <cpu_isa_t isa>
jit_uni_batch_normalization_s8_fwd_t<isa>::jit_uni_batch_normalization_s8_fwd_t(
        const pd_t *apd)
    : primitive_t(apd) {}

template <cpu_isa_t isa>
status_t jit_uni_batch_normalization_s8_fwd_t<isa>::init(engine_t *engine) {
    CHECK(safe_ptr_assign(
            bnorm_driver_, new bnorm_s8_impl::driver_t<isa>(pd())));
    return bnorm_driver_->create_kernel();
}

template <cpu_isa_t isa>
status_t jit_uni_batch_normalization_s8_fwd_t<isa>::execute(
        const exec_ctx_t &ctx) const {

    const memory_desc_wrapper ss_d(pd()->weights_md());

    const auto use_ss = pd()->use_scaleshift();
    const auto use_sc = pd()->use_scale();
    const auto use_sh = pd()->use_shift();

    const size_t shift_off
            = use_ss && !ss_d.has_zero_dim() ? ss_d.off(1, 0) : 0;

    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto scale = CTX_IN_MEM(
            const float *, use_sc ? DNNL_ARG_SCALE : DNNL_ARG_SCALE_SHIFT);
    auto shift = use_sh ? CTX_IN_MEM(const float *, DNNL_ARG_SHIFT)
                        : use_ss ? &CTX_IN_MEM(const float *,
                                  DNNL_ARG_SCALE_SHIFT)[shift_off]
                                 : nullptr;
    auto mean = const_cast<float *>(CTX_IN_MEM(const float *, DNNL_ARG_MEAN));
    auto var
            = const_cast<float *>(CTX_IN_MEM(const float *, DNNL_ARG_VARIANCE));
    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);

    // run sequentially if the problem is not larger than one 4K memory page
    const bool force_sequential
            = pd()->MB() * pd()->C() * pd()->D() * pd()->H() * pd()->W()
            <= 4096;

    parallel(force_sequential ? 1 : 0, [&](const int ithr, const int nthr) {
        bnorm_driver_->exec(ithr, nthr, src, dst, scale, shift, mean, var);
    });

    return status::success;
}

template <cpu_isa_t isa>
jit_uni_batch_normalization_s8_fwd_t<
        isa>::~jit_uni_batch_normalization_s8_fwd_t() {
    delete bnorm_driver_;
}

/* struct instantiation */
template struct jit_uni_batch_normalization_s8_fwd_t<avx512_core>;
template struct jit_uni_batch_normalization_s8_fwd_t<avx2>;
template struct jit_uni_batch_normalization_s8_fwd_t<sse41>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl