1 /*******************************************************************************
2 * Copyright 2017-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 #include "cpu/x64/jit_uni_i8i8_pooling.hpp"
17 #include <math.h>
18 
19 #include "common/dnnl_thread.hpp"
20 #include "common/utils.hpp"
21 
22 #include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
23 #include "cpu/x64/jit_generator.hpp"
24 
25 namespace dnnl {
26 namespace impl {
27 namespace cpu {
28 namespace x64 {
29 
static bcast_set_t get_supported_bcast_strategies() {
31     return {broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc};
32 }
33 
static inline dim_t get_offset(
        const memory_desc_wrapper &mdw, int n, int c, int d, int h, int w) {
36     switch (mdw.ndims()) {
37         case 3: return mdw.blk_off(n, c, w);
38         case 4: return mdw.blk_off(n, c, h, w);
39         case 5: return mdw.blk_off(n, c, d, h, w);
40         default: assert(!"Invalid tensor dimension in pooling");
41     }
42     return 0;
43 }
44 
45 using namespace Xbyak;
46 
47 using namespace dnnl::impl::utils;
49 using namespace dnnl::impl::types;
50 using namespace alg_kind;
51 
52 #define GET_OFF(field) offsetof(call_params_t, field)
53 
54 struct call_params_t {
55     const char *src_i8;
56     const char *dst_i8;
57     const void *post_ops_binary_rhs_arg_vec;
58     size_t kd_range;
59     size_t kh_range;
60     size_t kw_range;
61     float idivider;
62     const char *src_safe_access;
63     const char *dst_safe_access;
64 };
65 
66 template <cpu_isa_t isa>
67 struct jit_uni_i8i8_pooling_fwd_ker_t : public jit_generator {
68     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_i8i8_pooling_fwd_ker_t)
69 
70     using Vmm = typename cpu_isa_traits<isa>::Vmm;
    Xmm xreg(int idx) const { return Xmm(idx); }
    Ymm yreg(int idx) const { return Ymm(xreg(idx).getIdx()); }
    Vmm vreg(int idx) const { return Vmm(xreg(idx).getIdx()); }
74 
    // In case of avx2 with i8 data we need the maskmovdqu and maskmovq
    // instructions, whose destination pointer is hardcoded to rdi.
    // Windows ABI: abi_param1 is rcx - nothing else to do.
    // Unix ABI: abi_param1 is rdi - copy it to rcx and use rcx as abi_param1.
79     Reg64 reg_param = rcx; // Our "unified abi_param1"
80     Reg64 reg_ptr_src_i8 = r8;
81     Reg64 reg_ptr_dst_i8 = r9;
82     Reg64 reg_ptr_maskmovdqu_dst = rdi; // store destination - must be rdi
83 
84     Reg64 reg_kd_index
85             = rdi; // shared with reg_ptr_maskmovdqu_dst; only used before store
86     Reg64 reg_kh_index = r11;
87     Reg64 reg_kw_index = r10;
88     Reg64 reg_kd = r14;
89     Reg64 reg_kh = r13;
90     Reg64 reg_kw = r12;
91     Reg64 c_iter = r15; // shared with reg_mask; only used after mask init
92 
93     Reg64 aux_reg_src_d
94             = rdx; // shared with reg_tmp; loaded before each accum loop, unused during store
95     Reg64 aux_reg_src_h = rax;
96     Reg64 aux_reg_src_w = rbx;
97 
98     Reg64 reg_tmp = rdx; // only used during mask init and store
99     Reg64 reg_src_safe_access = rbp;
100     Reg64 reg_dst_safe_access = rsi;
101 
102     Reg64 reg_mask = r15; // only used during mask init
103 
104     Opmask k_cmp_mask = Opmask(7);
105 
    Opmask mask(int idx) { return Opmask(6 - idx); }
107 
108     // ref to any of XYZ-regs via xreg/yreg/vreg functions
109     Xmm xmm_tmp = xreg(0); // temp to init vreg_tmp
110     Vmm vreg_tmp = vreg(0); // max pooling : holds minimum values for data_type
111     Vmm vreg_zeros = vreg(1);
112     Vmm vreg_tail = vreg(4);
113 
114     // only in case of <isa> == avx2
115     Vmm vreg_mask = vreg(2); // full byte-mask
116     Xmm xreg_mask_lo = xreg(
117             2); // low 128-bits part of byte-mask (alias for xmm part of vreg_mask)
118     Xmm xreg_mask_hi = xreg(
119             3); // "max" - high 128-bits part of byte-mask (stored separately)
120 
121     // vreg_mask shifted left (aligned left) to be used in tail processing.
122     // Example:       idx [31..0]
123     //          vreg_mask = [0,0,0,0,0,.....,0,x,x,x,x,x] ; x => byte mask (msb set)
124     //          vreg_mask_2 = [x,x,x,x,x,0,0,0,0,0,.....,0]
125     Vmm vreg_mask_2 = vreg(5);
126     Xmm xreg_mask_2_lo = xreg(5); // similar to xreg_mask_lo
127     Xmm xreg_mask_2_hi = xreg(6); // similar to xreg_mask_hi
128 
129     Vmm vreg_mask_q = vreg(3); // "avg" - 1/4 part for non-zero tails
130     Mmx mmx_dst_i8 = Mmx(
131             0); // "avg" - Mmx reg for masked store results of s8/u8 operations
132     Mmx mmx_full_msk = Mmx(
133             1); // "avg" - Mmx reg for full mask (all 8 bytes) - used until not in tail
134     Mmx mmx_tmp = Mmx(2);
135     int post_op_tail_opmask_idx_ = -1;
136     jit_pool_conf_t jpp;
137     std::unique_ptr<injector::jit_uni_postops_injector_t<isa>>
138             postops_injector_;
139 
140     enum : int { max_vidx_base = utils::one_of(isa, sse41, avx2) ? 7 : 2 };
141     //"avg" pool uses more registers for unrolling.
142     enum : int { avg_vidx_base = utils::one_of(isa, sse41, avx2) ? 4 : 2 };
143 
    Vmm max_base_vr(int idx) const { return vreg(max_vidx_base + idx); }
    Vmm avg_base_vr(int idx) const { return vreg(avg_vidx_base + idx); }
146 
    size_t sizeof_src_dt() const { return data_type_size(jpp.src_dt); }
    size_t sizeof_dst_dt() const { return data_type_size(jpp.dst_dt); }
149 
150     /* max pooling */
    Vmm vreg_src(int idx) const { return max_base_vr(idx); } // [0    .. ur_c-1]
    Vmm vreg_dst(int idx) const {
153         return max_base_vr(jpp.ur_c + idx);
154     } // [ur_c .. 2*ur_c-1]
155 
156     /* avg pooling */
157     // s32 used for processing of s8/u8 data
158     // thus we need to take into account ratio of sizes s32/i8 = 4
159     static constexpr data_type_t avg_proc_dt = data_type::s32;
160     enum : int {
161         s32_to_i8_ratio = sizeof(typename prec_traits<avg_proc_dt>::type)
162                 / sizeof(typename prec_traits<data_type::u8>::type),
163         max_num_ll = s32_to_i8_ratio,
164         mmx_msk_base_reg = 3
165     };
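    // For example, on avx2 (vlen = 32 bytes) one Vmm holds 8 s32 lanes, so a
    // full 32-channel s8/u8 block is processed in max_num_ll = 4 slices
    // ("ll" = 0..3) of 8 channels each.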
166 
    Vmm vreg_src_s32(int jj, int ll) {
168         return avg_base_vr(3 * max_num_ll * jj + ll + 0 * max_num_ll);
    } // ll: 0..3 [0..3]
170 
    Vmm vreg_dst_s32(int jj, int ll) {
172         return avg_base_vr(3 * max_num_ll * jj + ll + 1 * max_num_ll);
    } // ll: 0..3 [4..7]
174 
    Vmm vreg_dst_f32(int jj, int ll) {
176         return avg_base_vr(3 * max_num_ll * jj + ll + 2 * max_num_ll);
    } // ll: 0..3 [8..11]
178 
    Mmx mmx_mask(int ll) {
        return Mmx(mmx_msk_base_reg + ll);
    }; // ll: 0..3 [Mmx(3)...Mmx(6)]
182 
183     static bool post_ops_ok(jit_pool_conf_t &jpp, const primitive_attr_t &attr,
184             const memory_desc_wrapper &dst_d);
185 
186     void init_tmp_reg();
187     void init_mask();
188 
    void load_vreg_mask_q(int ll) {};
190 
191     void load_src_max_op(
192             int jj, int ll, size_t offset, bool masked, uint64_t msk);
193     void load_src_avg_op(
194             int jj, int ll, size_t offset, bool masked, uint64_t msk);
195     void load_src(int jj, int ll, int c_tail);
196 
197     void store_dst_max_op(
198             int jj, int ll, size_t offset, bool masked, uint64_t msk);
199     void store_dst_avg_op(
200             int jj, int ll, size_t offset, bool masked, uint64_t msk);
201     void store_dst(int jj, int ll, int c_tail);
202 
203     void compute_avg_step(int ur_c, int c_tail);
204     void compute_max_op(const int jj);
205     void compute_max_step(int ur_c, int c_tail);
206     void compute_step(int ur_c, int c_tail);
207 
208     void compute_c_block();
209     void generate() override;
210 
211     static status_t init_conf(jit_pool_conf_t &jpp, const pooling_pd_t *ppd);
212 
    jit_uni_i8i8_pooling_fwd_ker_t(
            const jit_pool_conf_t &jpp_, const memory_desc_t *dst_md)
215         : jit_generator(nullptr, MAX_CODE_SIZE, true, isa)
216         , jpp(jpp_)
217         , postops_injector_(nullptr) {
218 
219         if (jpp.with_postops) {
220 
221             const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
222             const std::size_t c_tail_elems = jpp.c % simd_w;
223             post_op_tail_opmask_idx_ = 0;
224             if (c_tail_elems) {
225                 for (int ll = max_num_ll - 1; ll >= 0; ll--) {
226                     if (jpp.tail[ll] != 0) {
227                         post_op_tail_opmask_idx_ = ll;
228                         break;
229                     }
230                 }
231             };
232 
233             static constexpr bool preserve_gpr = true;
234             static constexpr bool preserve_vmm = true;
235             static constexpr bool use_exact_tail_scalar_bcast = false;
236             static constexpr std::size_t tmp_vmm_injector = 0u;
237 
238             const binary_injector::rhs_arg_static_params_t rhs_sp {
239                     tmp_vmm_injector, rax, r14, preserve_gpr, preserve_vmm,
240                     GET_OFF(post_ops_binary_rhs_arg_vec),
241                     memory_desc_wrapper(*dst_md), c_tail_elems,
242                     mask(post_op_tail_opmask_idx_),
243                     use_exact_tail_scalar_bcast};
244             const binary_injector::static_params_t bsp {
245                     reg_param, get_supported_bcast_strategies(), rhs_sp};
246 
247             postops_injector_ = utils::make_unique<
248                     injector::jit_uni_postops_injector_t<isa>>(
249                     this, jpp.post_ops, bsp);
250         }
251     }
252 };
253 
254 template <>
void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::load_vreg_mask_q(int ll) {};
256 
257 template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_vreg_mask_q(int ll) {
259 
260     // extract ll-th part of mask (ll-th QWORD)
261     vpblendd(vreg_mask_q, vreg_zeros, vreg_mask,
262             0x3 << 2 * ll); // 0x3 - mask for 2 x DWORD
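    // e.g. for ll == 1: 0x3 << 2 == 0b1100 keeps DWORDs 2..3 (QWORD 1) of
    // vreg_mask and zeroes the rest; the vpermq below (imm == ll) then moves
    // that QWORD into position 0.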
263 
264     // Move mask from ll-th pos to 0-th pos
265     if (ll > 0) vpermq(vreg_mask_q, vreg_mask_q, ll);
266 };
267 
268 template <>
void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::load_src_max_op(
        int jj, int ll, size_t offset, bool masked, uint64_t msk) {
271     using namespace data_type;
272 
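    // Note: SSE4.1 has no masked vector loads, so the c_tail elements are
    // gathered one by one with pinsrd/pinsrb to avoid reading past the
    // tensor boundary.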
273     if (masked) {
274         if (jpp.src_dt == s32)
275             for (int64_t i = 0; i < jpp.c_tail; i++)
276                 pinsrd(vreg_src(jj),
277                         ptr[aux_reg_src_w + offset + i * data_type_size(s32)],
278                         i);
279         else
280             for (int i = 0; i < jpp.c_tail; i++)
281                 pinsrb(vreg_src(jj), ptr[aux_reg_src_w + offset + i], i);
282     } else
283         movups(vreg_src(jj), ptr[aux_reg_src_w + offset]);
284 }
285 
286 template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_src_max_op(
        int jj, int ll, size_t offset, bool masked, uint64_t msk) {
289     using namespace data_type;
290 
291     if (masked) {
292         if (jpp.src_dt == s32) {
293             vpmaskmovd(vreg_src(jj), vreg_mask, ptr[aux_reg_src_w + offset]);
294         } else {
295             // Steps to access 'tail' section:
296             // 1) First load all data from the shifted src ptr
            // 2) Now bring the required data from the end of the reg to the beginning.
298             // Example:  idx=[31..0]
299             //    vreg_src = [x,x,x,x,.....,x,-,-,-,-,-] ; x => byte data
300             //    shift to transform vreg_src = [-,-,-,-,-,x,..,x,x,x,x,]
301             const uint8_t shift = cpu_isa_traits<avx2>::vlen - jpp.c_tail;
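            // Worked example: c_tail == 5 gives shift == 27, so the unaligned
            // load below ends exactly at the last valid byte, and the
            // vperm2i128 + vpalignr pair rotates the register so the 5 tail
            // bytes land in byte positions [0..4].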
302 
303             if (jpp.safe_c_tail) {
304 
305                 /* load src_tail at 'src_address - shift' so that it does not
306                  * spill over the memory boundary */
307                 vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset - shift]);
308 
309                 vperm2i128(vreg_tmp, vreg_src(jj), vreg_src(jj), 0x81);
310                 vpalignr(vreg_src(jj), vreg_tmp, vreg_src(jj), shift);
311 
312             } else {
313                 Label load_data_safely, done;
314                 add(aux_reg_src_w, offset);
315 
316                 // Check if mask crosses page boundary
317                 cmp(aux_reg_src_w, reg_src_safe_access);
318                 ja(load_data_safely, T_NEAR);
319 
320                 vpblendvb(
321                         vreg_src(jj), vreg_tmp, byte[aux_reg_src_w], vreg_mask);
322                 jmp(done, T_NEAR);
323 
324                 L(load_data_safely);
325 
326                 /* load src_tail at 'src_address - shift' so that it does not
327                  * spill over the memory boundary */
328                 vmovups(vreg_src(jj), ptr[aux_reg_src_w - shift]);
329 
330                 vperm2i128(vreg_tmp, vreg_src(jj), vreg_src(jj), 0x81);
331                 vpalignr(vreg_src(jj), vreg_tmp, vreg_src(jj), shift);
332 
333                 L(done);
334                 sub(aux_reg_src_w, offset);
335             }
336         }
337 
338     } else
339         vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]);
340 };
341 
342 template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::load_src_max_op(
        int jj, int ll, size_t offset, bool masked, uint64_t msk) {
345     using namespace data_type;
346 
347     if (masked) {
348         if (jpp.src_dt == s32)
349             vmovups(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]);
350         else
351             vmovdqu8(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]);
352     } else
353         vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]);
354 };
355 
356 template <>
void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::load_src_avg_op(
        int jj, int ll, size_t offset, bool masked, uint64_t msk) {
359     using namespace data_type;
360 
361     const Vmm &vr_src = vreg_src_s32(jj, ll);
362 
363     if (jpp.src_dt == s32) {
364         if (masked)
365             for (int64_t i = 0; i < jpp.c_tail; i++)
366                 pinsrd(vr_src,
367                         ptr[aux_reg_src_w + offset + i * data_type_size(s32)],
368                         i);
369         else
370             movups(vr_src, ptr[aux_reg_src_w + offset]);
371     } else if (utils::one_of(jpp.src_dt, s8, u8)) {
372         if (masked) {
373             const int copy_range = math::ilog2q(jpp.tail[ll] + 1);
374             for (int i = 0; i < copy_range; i++)
375                 pinsrb(vr_src, ptr[aux_reg_src_w + offset + i], i);
376 
377             if (jpp.src_dt == s8)
378                 pmovsxbd(vr_src, vr_src);
379             else
380                 pmovzxbd(vr_src, vr_src);
381         } else {
382             if (jpp.src_dt == s8)
383                 pmovsxbd(vr_src, ptr[aux_reg_src_w + offset]);
384             else
385                 pmovzxbd(vr_src, ptr[aux_reg_src_w + offset]);
386         }
387     } else
388         assert(!"unsupported src data type");
389 }
390 
391 template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_src_avg_op(
        int jj, int ll, size_t offset, bool masked, uint64_t msk) {
394     using namespace data_type;
395 
396     auto load_i8 = [&](bool is_signed, const Vmm &vr_src) {
397         // Need to use mask of tail?
398         if (masked) {
399 
400             // load ll-th part of mask into vreg_mask_q
401             load_vreg_mask_q(ll);
402 
403             // Steps to access 'tail' section:
404             // 1) First load all data from the shifted src ptr
            // 2) Now bring the required data from the end of the reg to the beginning.
406             // Example:  idx=[31..0]
407             //    vreg_src = [x,x,x,x,.....,x,-,-,-,-,-] ; x => byte data
408             //    shift to transform vreg_src = [-,-,-,-,-,x,..,x,x,x,x,]
            // Re-purposing vreg_zeros here; it is set back to zero immediately after.
410             const int msk_gran
411                     = cpu_isa_traits<avx2>::vlen / data_type_size(avg_proc_dt);
412 
413             const uint8_t shift = cpu_isa_traits<avx2>::vlen
414                     - (jpp.c_tail > (ll + 1) * msk_gran
415                                     ? msk_gran
416                                     : jpp.c_tail - (ll * msk_gran));
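            // Worked example: msk_gran == 8, so for c_tail == 21 and ll == 2
            // this slice holds 21 - 16 == 5 valid bytes and shift == 27.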
417             if (jpp.safe_c_tail) {
418                 /* load src_tail at 'src_address - shift' so that it does not
419                  * spill over the memory boundary */
420                 vmovups(vr_src, ptr[aux_reg_src_w + offset - shift]);
421 
422                 vperm2i128(vreg_zeros, vr_src, vr_src, 0x81);
423                 vpalignr(vr_src, vreg_zeros, vr_src, shift);
424                 uni_vpxor(vreg_zeros, vreg_zeros, vreg_zeros);
425             } else {
426                 Label load_data_safely, done;
427                 // assume that it is not safe to load the src_tail
428 
429                 add(aux_reg_src_w, offset);
430 
431                 // Check if load crosses the memory boundary
432                 cmp(aux_reg_src_w, reg_src_safe_access);
433                 ja(load_data_safely, T_NEAR);
434 
435                 vpblendvb(vr_src, vreg_zeros, ptr[aux_reg_src_w], vreg_mask_q);
436                 jmp(done, T_NEAR);
437 
438                 L(load_data_safely);
439 
440                 /* load src_tail at 'src_address - shift' so that it does not
441                  * spill over the memory boundary */
442                 vmovups(vr_src, ptr[aux_reg_src_w - shift]);
443 
444                 vperm2i128(vreg_zeros, vr_src, vr_src, 0x81);
445                 vpalignr(vr_src, vreg_zeros, vr_src, shift);
446                 uni_vpxor(vreg_zeros, vreg_zeros, vreg_zeros);
447 
448                 L(done);
449                 sub(aux_reg_src_w, offset);
450             }
451 
452             // Conversion s8/u8 -> s32
453             if (is_signed)
454                 vpmovsxbd(vr_src, vr_src);
455             else
456                 vpmovzxbd(vr_src, vr_src);
457         } else {
458 
459             // Load from mem into vr_src with conversion
460             if (is_signed)
461                 vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]);
462             else
463                 vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]);
464         }
465     };
466 
467     switch (jpp.src_dt) {
468         case s32:
469             if (masked)
470                 vpmaskmovd(vreg_src_s32(jj, ll), vreg_mask,
471                         ptr[aux_reg_src_w + offset]);
472             else
473                 vmovups(vreg_src_s32(jj, ll), ptr[aux_reg_src_w + offset]);
474             break;
475         case s8: load_i8(true, vreg_src_s32(jj, ll)); break;
476         case u8: load_i8(false, vreg_src_s32(jj, ll)); break;
477         default: assert(!"unsupported src data type");
478     }
479 };
480 
481 template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::load_src_avg_op(
        int jj, int ll, size_t offset, bool masked, uint64_t msk) {
484     using namespace data_type;
485 
486     const Vmm &vr_src
487             = masked ? vreg_src_s32(jj, ll) | mask(ll) : vreg_src_s32(jj, ll);
488 
489     switch (jpp.src_dt) {
490         case s32: vmovups(vr_src, ptr[aux_reg_src_w + offset]); break;
491         case s8: vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]); break;
492         case u8: vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]); break;
493         default: assert(!"unsupported src data type");
494     }
495 };
496 
497 template <cpu_isa_t isa>
void jit_uni_i8i8_pooling_fwd_ker_t<isa>::load_src(int jj, int ll, int c_tail) {
499     using namespace data_type;
500 
501     int c_block = jpp.c_block;
502     int ur_c = jpp.ur_c;
503 
504     switch (jpp.alg) {
505         case pooling_max: {
506             auto offset = jj * c_block * sizeof_src_dt();
507             bool masked = jj == ur_c - 1 && c_tail;
508             load_src_max_op(jj, ll, offset, masked, jpp.tail[0]);
509             break;
510         }
511         case pooling_avg_include_padding:
512         case pooling_avg_exclude_padding: {
513             auto offset = (ll * (c_block / max_num_ll) + jj * c_block)
514                     * sizeof_src_dt();
515             bool masked = jj == ur_c - 1 && c_tail;
516             load_src_avg_op(jj, ll, offset, masked, jpp.tail[ll]);
517             break;
518         }
519         default: assert(!"unsupported algorithm");
520     }
521 }
522 
523 template <>
void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::store_dst_max_op(
        int jj, int ll, size_t offset, bool masked, uint64_t msk) {
526     using namespace data_type;
527 
528     if (masked) {
529         if (jpp.src_dt == s32)
530             for (int i = 0; i < jpp.c_tail; i++)
531                 pextrd(ptr[reg_ptr_dst_i8 + offset + i * data_type_size(s32)],
532                         vreg_dst(jj), i);
533         else if (utils::one_of(jpp.src_dt, u8, s8))
534             for (int i = 0; i < jpp.c_tail; i++)
535                 pextrb(ptr[reg_ptr_dst_i8 + offset + i], vreg_dst(jj), i);
536         else
537             assert(!"unsupported src data type");
538     } else
539         movups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj));
540 }
541 
542 template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::store_dst_max_op(
        int jj, int ll, size_t offset, bool masked, uint64_t msk) {
545     using namespace data_type;
546 
547     Label store_data_safely, done;
548 
549     int c_block = jpp.c_block;
550 
551     const uint64_t low_mask = (1ULL << (c_block / 2)) - 1;
552     const uint8_t shift = cpu_isa_traits<avx2>::vlen - jpp.c_tail;
553 
554     if (masked) {
555         switch (jpp.src_dt) {
556             case s32:
557                 vpmaskmovd(
558                         ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst(jj));
559                 break;
560             case s8:
561             case u8: {
562 
563                 lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]);
564 
565                 if (!jpp.safe_c_tail) {
566                     Xmm xreg_dst = Xmm(vreg_dst(jj).getIdx());
567 
568                     cmp(reg_ptr_maskmovdqu_dst, reg_dst_safe_access);
569                     ja(store_data_safely, T_NEAR);
570 
571                     // Store low half by mask (bytes 0...15)
572                     vmaskmovdqu(xreg_dst, xreg_mask_lo);
573 
574                     // Do we need to store high half (bytes 16...31) ?
575                     if (msk & ~low_mask) {
576                         vextracti128(xreg_dst, vreg_dst(jj), 1);
577                         add(reg_ptr_maskmovdqu_dst, c_block / 2);
578                         vmaskmovdqu(xreg_dst, xreg_mask_hi);
579                     }
580                     jmp(done, T_NEAR);
581                 }
582 
583                 L(store_data_safely);
584 
585                 vperm2i128(vreg_tail, vreg_dst(jj), vreg_dst(jj), 0x08);
586                 if (shift <= 16) {
587                     vpalignr(vreg_tail, vreg_dst(jj), vreg_tail, 16 - shift);
588                 } else {
589                     vpalignr(vreg_tail, vreg_tail, vreg_zeros, 32 - shift);
590                 }
591 
592                 Xmm xreg_tail = Xmm(vreg_tail.getIdx());
593                 // Do we need to store low half (bytes 0...15) ?
594                 if (msk & ~low_mask) {
595                     sub(reg_ptr_maskmovdqu_dst, shift);
596                     vmaskmovdqu(xreg_tail, xreg_mask_2_lo);
597                     add(reg_ptr_maskmovdqu_dst, c_block / 2);
598                 } else {
599                     add(reg_ptr_maskmovdqu_dst, (c_block / 2) - shift);
600                 }
601 
602                 // Store high half by mask (bytes 16..31)
603                 vextracti128(xreg_tail, vreg_tail, 1);
604                 vmaskmovdqu(xreg_tail, xreg_mask_2_hi);
605 
606                 L(done);
607             } break;
608             default: assert(!"unsupported src data type");
609         }
610     } else
611         vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj));
612 }
613 
614 template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::store_dst_max_op(
        int jj, int ll, size_t offset, bool masked, uint64_t msk) {
617     using namespace data_type;
618 
619     if (masked) {
620         switch (jpp.src_dt) {
621             case s32:
622                 vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0));
623                 break;
624             case s8:
625             case u8:
626                 vmovdqu8(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0));
627                 break;
628             default: assert(!"unsupported src data type");
629         }
630     } else
631         vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj));
632 }
633 
634 template <>
void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::store_dst_avg_op(
        int jj, int ll, size_t offset, bool masked, uint64_t msk) {
637     using namespace data_type;
638 
639     // Don't generate useless code
640     if (masked && !msk) return;
641 
642     const Vmm &vr_dst = vreg_dst_s32(jj, ll);
643 
644     if (jpp.src_dt == s32) {
645         if (masked)
646             for (int i = 0; i < jpp.c_tail; i++)
647                 pextrd(ptr[reg_ptr_dst_i8 + offset + i * data_type_size(s32)],
648                         vr_dst, i);
649         else
650             movups(ptr[reg_ptr_dst_i8 + offset], vr_dst);
651     } else if (utils::one_of(jpp.src_dt, s8, u8)) {
652         packssdw(vr_dst, vr_dst);
653         if (jpp.src_dt == s8)
654             packsswb(vr_dst, vr_dst);
655         else
656             packuswb(vr_dst, vr_dst);
657 
658         const int copy_range = masked
659                 ? math::ilog2q(jpp.tail[ll] + 1)
660                 : cpu_isa_traits<sse41>::vlen / data_type_size(avg_proc_dt);
661         for (int i = 0; i < copy_range; i++)
662             pextrb(ptr[reg_ptr_dst_i8 + offset + i], vr_dst, i);
663     } else
664         assert(!"unsupported src data type");
665 }
666 
667 template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::store_dst_avg_op(
        int jj, int ll, size_t offset, bool masked, uint64_t msk) {
670     using namespace data_type;
671 
672     // Don't generate useless code
673     if (masked && !msk) return;
674 
675     auto s32_to_i8 = [&](bool is_signed, const Vmm &vr_dst) {
676         // conversion: s32 -> s16/u16 : {8 x s32}{8 x 0} -> {16 x s16/u16}
677         // Result QWORDs (qw0, qw1) permuted: {qw0, 0, qw1, 0}
678         if (is_signed)
679             vpackssdw(vr_dst, vr_dst, vreg_zeros);
680         else
681             vpackusdw(vr_dst, vr_dst, vreg_zeros);
682 
683         // Permute qwords to restore original order
684         // {qw0, 0, qw1, 0} -> {qw0, qw1, 0, 0}
685         vpermq(vr_dst, vr_dst, 0x58);
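        // (0x58 == 0b01011000: dst.q0 = src.q0, dst.q1 = src.q2,
        //  dst.q2 = dst.q3 = src.q1, which is zero after the pack above)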
686 
687         // conversion: s16/u16 -> s8/u8 : {16 x s16/u16}{16 x 0} -> {32 x s8/u8}
688         // Target QWORD qw = {8 x s8/u8} has proper position: {qw, xx, xx, xx}
689         if (is_signed)
690             vpacksswb(vr_dst, vr_dst, vreg_zeros);
691         else
692             vpackuswb(vr_dst, vr_dst, vreg_zeros);
693     };
694 
695     auto store_i8 = [&](bool is_signed, bool is_masked, const Vmm &vr_dst) {
696         // Conversion s32 -> s8/u8
697         s32_to_i8(is_signed, vr_dst);
698 
699         // early-out for non-masked cases
700         if (!is_masked) {
701             vmovlps(ptr[reg_ptr_dst_i8 + offset], Xmm(vr_dst.getIdx()));
702             return;
703         }
704         // store 8 bytes
705         lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]);
706 
        // Need to use an mmx 8-byte operation to avoid memory violations.
        // NOTICE: it was discovered that the Intel SSE and Intel AVX instructions
        //         maskmovdqu/vmaskmovdqu with a low-8-bytes mask raise an
        //         exception if the high 8 bytes belong to a write-protected page.
711         // NOTE: use indirect move via gpr to avoid transition penalty
712         vmovq(reg_tmp, Xmm(vr_dst.getIdx()));
713         movq(mmx_dst_i8, reg_tmp);
714 
715         // mmx_full_msk - mask for all 8 bytes in zero-tail case
716         // mmx_mask(ll) - ll-th mask of tail in non-zero-tail case
717 
718         const int msk_gran
719                 = cpu_isa_traits<avx2>::vlen / data_type_size(avg_proc_dt);
720 
721         const int ll_end = (ll + 1) * msk_gran; // ((ll + 1) * 8)
722 
        if (is_masked && (ll_end > jpp.c_tail)) { // implies this tail is not full
724             Label store_data_safely, done;
725             const uint8_t shift = msk_gran - jpp.c_tail % msk_gran;
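            // Worked example: for c_tail == 21 and ll == 2 this slice holds
            // 21 % 8 == 5 valid bytes, so shift == 3; shifting data and mask
            // left by 3 bytes and storing at (dst - 3) keeps the 8-byte
            // maskmovq from touching bytes past the buffer end.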
726 
727             if (!jpp.safe_c_tail) {
728                 cmp(reg_ptr_maskmovdqu_dst, reg_dst_safe_access);
729                 ja(store_data_safely, T_NEAR);
730 
731                 /* store dst_tail with overlap outside the channel dimension,
732                  * but assume it's within the memory boundary. */
733                 maskmovq(mmx_dst_i8, mmx_mask(ll));
734                 jmp(done, T_NEAR);
735             }
736 
737             L(store_data_safely);
738 
739             /* store dst_tail at 'dst_address - shift' so that it does not
740              * spill over the memory boundary */
741             movq(mmx_tmp, mmx_mask(ll));
742             psllq(mmx_tmp, shift * 8); // multiply with 8 (bits/byte)
743             psllq(mmx_dst_i8, shift * 8);
744             sub(reg_ptr_maskmovdqu_dst, shift);
745             maskmovq(mmx_dst_i8, mmx_tmp);
746 
747             L(done);
748         } else {
749             maskmovq(mmx_dst_i8, mmx_full_msk);
750         }
751     };
752 
753     switch (jpp.dst_dt) {
754         case s32:
755             if (masked) {
756                 vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask,
757                         vreg_dst_s32(jj, ll));
758             } else
759                 vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst_s32(jj, ll));
760             break;
761         case s8: store_i8(true, masked, vreg_dst_s32(jj, ll)); break;
762         case u8: store_i8(false, masked, vreg_dst_s32(jj, ll)); break;
        default: assert(!"unsupported dst data_type");
764     }
765 }
766 
767 template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::store_dst_avg_op(
        int jj, int ll, size_t offset, bool masked, uint64_t msk) {
770     using namespace data_type;
771 
772     // Don't generate useless code
773     if (masked && !msk) return;
774 
775     const Vmm &vr_dst
776             = masked ? vreg_dst_s32(jj, ll) | mask(ll) : vreg_dst_s32(jj, ll);
777 
778     switch (jpp.dst_dt) {
779         case s32: vmovups(ptr[reg_ptr_dst_i8 + offset], vr_dst); break;
780         case s8: vpmovsdb(ptr[reg_ptr_dst_i8 + offset], vr_dst); break;
781         case u8: vpmovusdb(ptr[reg_ptr_dst_i8 + offset], vr_dst); break;
782         default: assert(!"unsupported dst data_type");
783     }
784 }
785 
786 template <cpu_isa_t isa>
void jit_uni_i8i8_pooling_fwd_ker_t<isa>::store_dst(
        int jj, int ll, int c_tail) {
789     using namespace data_type;
790 
791     int c_block = jpp.c_block;
792     int ur_c = jpp.ur_c;
793 
794     switch (jpp.alg) {
795         case pooling_max: {
796             auto offset = jj * c_block * sizeof_dst_dt();
797             bool masked = jj == ur_c - 1 && c_tail;
798             store_dst_max_op(jj, ll, offset, masked, jpp.tail[ll]);
799             break;
800         }
801         case pooling_avg_include_padding:
802         case pooling_avg_exclude_padding: {
803             auto offset = (ll * (c_block / max_num_ll) + jj * c_block)
804                     * sizeof_dst_dt();
805             bool masked = jj == ur_c - 1 && c_tail;
806             store_dst_avg_op(jj, ll, offset, masked, jpp.tail[ll]);
807             break;
808         }
809         default: assert(!"unsupported pooling algorithm");
810     }
811 }
812 
813 template <>
void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::compute_max_op(const int jj) {
815     using namespace data_type;
816     switch (jpp.src_dt) {
817         case s32: pmaxsd(vreg_dst(jj), vreg_src(jj)); break;
818         case s8: pmaxsb(vreg_dst(jj), vreg_src(jj)); break;
819         case u8: pmaxub(vreg_dst(jj), vreg_src(jj)); break;
820         default: assert(!"unsupported src data type");
821     }
822 }
823 
824 template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::compute_max_op(const int jj) {
826     using namespace data_type;
827     switch (jpp.src_dt) {
828         case s32: vpmaxsd(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); break;
829         case s8: vpmaxsb(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); break;
830         case u8: vpmaxub(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); break;
831         default: assert(!"unsupported src data type");
832     }
833 }
834 
835 template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::compute_max_op(const int jj) {
837     using namespace data_type;
838 
839     // Compare
840     switch (jpp.src_dt) {
841         case s32:
842             vpcmpd(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os);
843             break;
844         case s8:
845             vpcmpb(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os);
846             break;
847         case u8:
848             vpcmpub(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os);
849             break;
850         default: assert(!"unsupported src data type");
851     }
852 
853     // move max values into vreg_dst
854     if (jpp.src_dt == s32)
855         vpblendmd(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj));
856     else
857         vpblendmb(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj));
858 }
859 
860 template <cpu_isa_t isa>
void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_max_step(
        int ur_c, int c_tail) {
863     Label l_kd, l_kh, l_kw;
864 
865     int ih = jpp.ih;
866     int iw = jpp.iw;
867     int c = jpp.c;
868 
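    // Walk the (kd, kh, kw) window assuming a channels-last (n[d]hwc) layout:
    // consecutive w positions are c elements apart, rows are iw * c apart and
    // planes are ih * iw * c apart (in src elements).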
869     for (int jj = 0; jj < ur_c; jj++)
870         uni_vmovups(vreg_dst(jj), vreg_tmp);
871 
872     mov(aux_reg_src_d, reg_ptr_src_i8);
873     xor_(reg_kd_index, reg_kd_index);
874     L(l_kd);
875     {
876         mov(aux_reg_src_h, aux_reg_src_d);
877         xor_(reg_kh_index, reg_kh_index);
878         L(l_kh);
879         {
880             mov(aux_reg_src_w, aux_reg_src_h);
881             xor_(reg_kw_index, reg_kw_index);
882             L(l_kw);
883             {
884                 for (int jj = 0; jj < ur_c; jj++) {
885                     load_src(jj, 0, c_tail);
886                     compute_max_op(jj);
887                 }
888                 add(aux_reg_src_w, c * sizeof_src_dt());
889                 inc(reg_kw_index);
890                 cmp(reg_kw_index, reg_kw);
891                 jl(l_kw, T_NEAR);
892             }
893             add(aux_reg_src_h, iw * c * sizeof_src_dt());
894             inc(reg_kh_index);
895             cmp(reg_kh_index, reg_kh);
896             jl(l_kh, T_NEAR);
897         }
898         add(aux_reg_src_d, ih * iw * c * sizeof_src_dt());
899         inc(reg_kd_index);
900         cmp(reg_kd_index, reg_kd);
901         jl(l_kd, T_NEAR);
902     }
903 
904     for (int jj = 0; jj < ur_c; jj++)
905         store_dst(jj, 0, c_tail);
906 }
907 
908 template <cpu_isa_t isa>
void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_avg_step(
        int ur_c, int c_tail) {
911     using namespace data_type;
912 
913     Label l_kd, l_kh, l_kw;
914 
915     int ih = jpp.ih;
916     int iw = jpp.iw;
917     int c = jpp.c;
918 
919     const int num_ll = data_type_size(avg_proc_dt) / data_type_size(jpp.src_dt);
920 
921     for (int jj = 0; jj < ur_c; jj++) {
922         for (int ll = 0; ll < num_ll; ll++) {
923             bool masked = jj == ur_c - 1 && c_tail;
924             size_t msk = jpp.tail[ll];
925             if (!(masked && !msk)) {
                // Clearing the src regs is not needed as they are written before being read
927                 uni_vpxor(vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll),
928                         vreg_dst_s32(jj, ll));
929             }
930         }
931     }
932 
933     mov(aux_reg_src_d, reg_ptr_src_i8);
934     xor_(reg_kd_index, reg_kd_index);
935     L(l_kd);
936     {
937         mov(aux_reg_src_h, aux_reg_src_d);
938         xor_(reg_kh_index, reg_kh_index);
939         L(l_kh);
940         {
941             mov(aux_reg_src_w, aux_reg_src_h);
942             xor_(reg_kw_index, reg_kw_index);
943             L(l_kw);
944             {
945                 for (int jj = 0; jj < ur_c; jj++) {
946                     for (int ll = 0; ll < num_ll; ll++) {
947                         bool masked = jj == ur_c - 1 && c_tail;
948                         size_t msk = jpp.tail[ll];
949                         if (!(masked && !msk)) {
950                             load_src(jj, ll, c_tail);
951                             uni_vpaddd(vreg_dst_s32(jj, ll),
952                                     vreg_dst_s32(jj, ll), vreg_src_s32(jj, ll));
953                         }
954                     }
955                 }
956                 add(aux_reg_src_w, c * sizeof_src_dt());
957                 inc(reg_kw_index);
958                 cmp(reg_kw_index, reg_kw);
959                 jl(l_kw, T_NEAR);
960             }
961             add(aux_reg_src_h, iw * c * sizeof_src_dt());
962             inc(reg_kh_index);
963             cmp(reg_kh_index, reg_kh);
964             jl(l_kh, T_NEAR);
965         }
966         add(aux_reg_src_d, ih * iw * c * sizeof_src_dt());
967         inc(reg_kd_index);
968         cmp(reg_kd_index, reg_kd);
969         jl(l_kd, T_NEAR);
970     }
971 
972     static constexpr int vlen_size_elem
973             = cpu_isa_traits<isa>::vlen / sizeof(float);
974     const auto reg_tmp_postops = r15;
975     const injector_utils::register_preserve_guard_t reg_guard(this,
976             jpp.with_binary
977                     ? std::initializer_list<Xbyak::Reg64> {reg_tmp_postops}
978                     : std::initializer_list<Xbyak::Reg64> {},
979             {});
980     if (jpp.with_binary) {
981         imul(reg_tmp_postops, c_iter, ur_c * num_ll * vlen_size_elem);
982     }
983 
984     for (int jj = 0; jj < ur_c; jj++) {
985         for (int ll = 0; ll < num_ll; ll++) {
986             const bool masked = jj == ur_c - 1 && c_tail;
987             const size_t msk = jpp.tail[ll];
988             if (!(masked && !msk)) {
989                 const auto &reg_dst_f32 = vreg_dst_f32(jj, ll);
990                 const auto &reg_dst_s32 = vreg_dst_s32(jj, ll);
991                 uni_vcvtdq2ps(reg_dst_f32, reg_dst_s32);
992                 uni_vfmadd132ps(reg_dst_f32, vreg_zeros, vreg_tmp);
993 
994                 if (jpp.with_postops) {
995                     binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
996                     if (jpp.with_binary) {
997                         rhs_arg_params.vmm_idx_to_oc_off_oprnd.emplace(
998                                 reg_dst_f32.getIdx(), reg_tmp_postops);
999                         rhs_arg_params.vmm_idx_to_oc_elem_off_val.emplace(
1000                                 reg_dst_f32.getIdx(),
1001                                 ll * vlen_size_elem + jj * vlen_size_elem);
1002                         rhs_arg_params.vmm_idx_to_oc_off_oprnd.emplace(
1003                                 reg_dst_f32.getIdx(), reg_tmp_postops);
1004                         rhs_arg_params.vmm_idx_to_oc_elem_off_val.emplace(
1005                                 reg_dst_f32.getIdx(),
1006                                 ll * vlen_size_elem + jj * vlen_size_elem);
1007                         const bool tail = ll == post_op_tail_opmask_idx_;
1008                         if (tail && masked)
1009                             rhs_arg_params.vmm_tail_idx_.emplace(
1010                                     reg_dst_f32.getIdx());
1011                     }
1012                     postops_injector_->compute_vector(
1013                             reg_dst_f32.getIdx(), rhs_arg_params);
1014                 }
1015 
1016                 uni_vcvtps2dq(reg_dst_s32, reg_dst_f32);
1017 
1018                 if (jpp.with_postops)
1019                     if (jpp.dst_dt == u8) {
1020                         uni_vpmaxsd(reg_dst_s32, reg_dst_s32, vreg_zeros);
1021                     }
1022                 store_dst(jj, ll, c_tail);
1023             }
1024         }
1025     }
1026 }
1027 
1028 template <cpu_isa_t isa>
void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_step(int ur_c, int c_tail) {
1030     switch (jpp.alg) {
1031         case pooling_max: compute_max_step(ur_c, c_tail); break;
1032         case pooling_avg_include_padding:
1033         case pooling_avg_exclude_padding: compute_avg_step(ur_c, c_tail); break;
1034         default: assert(!"unsupported pooling algorithm");
1035     }
1036 }
1037 
1038 template <cpu_isa_t isa>
void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_c_block() {
1040     Label l_main_loop;
1041 
1042     int nb_c = jpp.nb_c;
1043     int c_block = jpp.c_block;
1044     int ur_c = jpp.ur_c;
1045     int ur_c_tail = jpp.ur_c_tail;
1046     int c_steps = nb_c / ur_c;
1047     int c_tail = jpp.c_tail;
1048 
1049     xor_(c_iter, c_iter);
1050     if (c_steps > 0) {
1051         L(l_main_loop);
1052         {
1053             compute_step(ur_c, 0);
1054             add(reg_ptr_src_i8, ur_c * c_block * sizeof_src_dt());
1055             add(reg_ptr_dst_i8, ur_c * c_block * sizeof_dst_dt());
1056             inc(c_iter);
1057             cmp(c_iter, c_steps);
1058             jl(l_main_loop, T_NEAR);
1059         }
1060     }
1061 
1062     if (ur_c_tail != 0) { compute_step(ur_c_tail, c_tail); }
1063 }
1064 
1065 template <>
void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::init_mask() {}
1067 
1068 template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::init_mask() {
1070     using namespace data_type;
1071     using cpu_isa = cpu_isa_traits<avx2>;
1072 
1073     // AVX2 mask initialization: mask stored in Ymm-regs
1074     auto init = [&](uint64_t bit_mask, bool need_ymm_mask = true,
1075                         bool need_mmx_mask = false) {
1076         const size_t QW_PER_VREG = cpu_isa::vlen / sizeof(uint64_t);
1077 
1078         const size_t DBITS = 8 * sizeof_src_dt();
1079         const uint64_t VMSK = 1ULL << (DBITS - 1);
1080         const size_t D_PER_QW = (8 * sizeof(uint64_t)) / DBITS;
1081         uint64_t vmask[QW_PER_VREG];
1082         for (size_t i = 0; i < QW_PER_VREG; i++) {
1083             uint64_t qw_vmask = 0ULL;
1084             for (size_t j = 0; j < D_PER_QW; j++) {
1085                 if (bit_mask & 1) qw_vmask |= VMSK << DBITS * j;
1086                 bit_mask >>= 1;
1087             }
1088             vmask[i] = qw_vmask;
1089         }
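        // Example of the expansion above for s8/u8 (DBITS == 8): a tail of
        // 5 channels (bit_mask == 0b11111) yields
        // vmask[0] == 0x0000008080808080 and vmask[1..3] == 0, i.e. the MSB
        // of each masked byte is set, as vpblendvb/maskmovq expect.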
1090 
1091         // Need mask in Ymm regs ?
1092         if (need_ymm_mask) {
1093 
1094             // Put QWORDS with target mask into xmm regs
1095             const int xdst_i[QW_PER_VREG]
1096                     = {xreg_mask_lo.getIdx(), xreg_mask_lo.getIdx(),
1097                             xreg_mask_hi.getIdx(), xreg_mask_hi.getIdx()};
1098             const int xsrc_i[QW_PER_VREG] = {
1099                     vreg_zeros
1100                             .getIdx(), // 0-th qword insert in zeros -> {qw0,  0}
1101                     xreg_mask_lo
1102                             .getIdx(), // 1-st and 0-th merge        -> {qw0,qw1}
1103                     vreg_zeros.getIdx(), xreg_mask_hi.getIdx()};
1104             const uint8 qw_dst_idx[QW_PER_VREG]
1105                     = {0, 1, 0, 1}; // qword index in 128-bit xreg
1106 
1107             for (size_t i = 0; i < QW_PER_VREG; i++) {
1108                 mov(reg_mask, vmask[i]);
1109                 vpinsrq(Xmm(xdst_i[i]), Xmm(xsrc_i[i]), reg_mask,
1110                         qw_dst_idx[i]);
1111 
1112                 // Need mask in MMX regs also?
1113                 if (need_mmx_mask)
1114                     movq(mmx_mask(i), reg_mask); // reuse value in reg_mask
1115             }
1116 
1117             // Merge Low (xreg_mask_lo alias for vreg_mask.xreg)
1118             // and High (xreg_mask_hi) into full vreg_mask
1119             // vreg_mask -> {xreg_mask_hi, vreg_mask.xreg}
1120             vinserti128(vreg_mask, vreg_mask, xreg_mask_hi, 1);
1121 
            // Compute a left-aligned copy of vreg_mask and store it in vreg_mask_2 to be used for tail processing.
1123             const uint8_t shift = 32 - jpp.c_tail;
1124             vperm2i128(vreg_mask_2, vreg_mask, vreg_mask, 0x08);
1125             if (shift <= 16) {
1126                 vpalignr(vreg_mask_2, vreg_mask, vreg_mask_2, 16 - shift);
1127             } else {
1128                 vpalignr(vreg_mask_2, vreg_mask_2, vreg_zeros, 32 - shift);
1129             }
1130             vextracti128(xreg_mask_2_hi, vreg_mask_2, 0x1);
1131         }
1132 
1133         // Need mask in MMX regs ?
1134         if (need_mmx_mask) {
1135 
1136             // Only in MMX regs ?
1137             if (!need_ymm_mask)
1138                 for (size_t i = 0; i < QW_PER_VREG; i++) {
1139                     mov(reg_mask, vmask[i]);
1140                     movq(mmx_mask(i), reg_mask);
1141                 }
1142 
1143             // Form full mask for one QWORD
1144             uint64_t qw_full_vmask = 0ULL;
1145             for (size_t i = 0; i < D_PER_QW; i++)
1146                 qw_full_vmask |= VMSK << DBITS * i;
1147 
1148             mov(reg_mask, qw_full_vmask);
1149             movq(mmx_full_msk, reg_mask);
1150         }
1151     };
1152 
1153     uint64_t tail_mask = (1ULL << jpp.c_tail) - 1;
1154     switch (jpp.alg) {
1155         case pooling_max:
1156             // For "max" we need mask only in case of non-zero tail
1157             if (tail_mask) init(tail_mask);
1158             break;
1159         case pooling_avg_include_padding:
1160         case pooling_avg_exclude_padding:
1161             // For "avg" we need mask:
1162             // - s32   - in case of the non-zero tail
1163             // - s8/u8 - irrespective of the tail in MMX regs (always store by mask)
1164             //         - for non-zero tail in Ymm regs (for load)
1165             switch (jpp.src_dt) {
1166                 case s32:
1167                     if (tail_mask) init(tail_mask);
1168                     break;
1169                 case s8:
1170                 case u8:
1171                     init(tail_mask ? tail_mask : ~0ULL, tail_mask != 0, true);
1172                     break;
1173                 default: assert(!"unsupported src data type");
1174             }
1175             break;
1176         default: assert(!"unsupported pooling algorithm");
1177     }
1178 }
1179 
1180 template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::init_mask() {
1182 
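    // Each jpp.tail[ll] bit-mask is placed into opmask k(6 - ll) (see mask()
    // above); max pooling only ever reads mask(0), avg pooling indexes
    // mask(ll) per s32 slice.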
1183     for (int ll = 0; ll < max_num_ll; ll++) {
1184         mov(reg_mask, jpp.tail[ll]);
1185         kmovq(mask(ll), reg_mask);
1186     }
1187 }
1188 
1189 template <cpu_isa_t isa>
void jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_tmp_reg() {
1191     using namespace data_type;
1192 
1193     switch (jpp.alg) {
1194         case pooling_avg_include_padding:
1195         case pooling_avg_exclude_padding:
1196             mov(reg_tmp, ptr[reg_param + offsetof(call_params_t, idivider)]);
1197             uni_vmovq(xmm_tmp, reg_tmp);
1198             uni_vpbroadcastd(vreg_tmp, xmm_tmp);
1199             break;
1200         case pooling_max:
1201             switch (jpp.src_dt) {
1202                 case s32:
1203                     mov(reg_tmp, nstl::numeric_limits<int32_t>::lowest());
1204                     break;
1205                 case s8:
1206                     mov(reg_tmp, nstl::numeric_limits<int8_t>::lowest());
1207                     break;
1208                 case u8:
1209                     mov(reg_tmp, nstl::numeric_limits<uint8_t>::lowest());
1210                     break;
1211                 default: assert(!"unsupported src data_type");
1212             }
1213 
1214             uni_vmovq(xmm_tmp, reg_tmp);
1215             if (jpp.src_dt == s32)
1216                 uni_vpbroadcastd(vreg_tmp, xmm_tmp);
1217             else if (mayiuse(avx2))
1218                 vpbroadcastb(vreg_tmp, xmm_tmp);
1219             else
1220                 pshufb(xmm_tmp, vreg_zeros);
1221             break;
1222         default: assert(!"unsupported pooling algorithm");
1223     }
1224 }
1225 
1226 template <cpu_isa_t isa>
void jit_uni_i8i8_pooling_fwd_ker_t<isa>::generate() {
1228     preamble();
1229 
1230 #if !defined(_WIN32)
1231     // Always use rcx as abi_param1 -
1232     // see the note about maskmovdqu/maskmovq near reg_param.
1233     mov(rcx, rdi);
1234 #endif
1235 
1236 #define READ_PARAM(reg, field) \
1237     mov(reg, ptr[reg_param + offsetof(call_params_t, field)])
1238     READ_PARAM(reg_ptr_src_i8, src_i8);
1239     READ_PARAM(reg_ptr_dst_i8, dst_i8);
1240     READ_PARAM(reg_kd, kd_range);
1241     READ_PARAM(reg_kh, kh_range);
1242     READ_PARAM(reg_kw, kw_range);
1243     READ_PARAM(reg_src_safe_access, src_safe_access);
1244     READ_PARAM(reg_dst_safe_access, dst_safe_access);
1245 
1246 #undef READ_PARAM
1247 
1248     uni_vpxor(vreg_zeros, vreg_zeros, vreg_zeros);
1249 
1250     init_mask();
1251 
1252     init_tmp_reg();
1253 
1254     compute_c_block();
1255 
1256     emms();
1257     postamble();
1258 
1259     if (jpp.with_eltwise && postops_injector_)
1260         postops_injector_->prepare_table();
1261 }
1262 
1263 template <cpu_isa_t isa>
status_t jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_conf(
        jit_pool_conf_t &jpp, const pooling_pd_t *ppd) {
1266     if (!mayiuse(isa)) return status::unimplemented;
1267 
1268     const auto &pd = *ppd->desc();
1269     const memory_desc_wrapper src_d(ppd->src_md());
1270     const memory_desc_wrapper dst_d(ppd->dst_md());
1271     const int ndims = src_d.ndims();
1272     const bool is_1d = ndims == 3;
1273     const bool is_3d = ndims == 5;
1274 
1275     jpp.mb = src_d.dims()[0];
1276     jpp.c = src_d.dims()[1];
1277 
1278     jpp.id = is_3d ? src_d.dims()[ndims - 3] : 1;
1279     jpp.ih = is_1d ? 1 : src_d.dims()[ndims - 2];
1280     jpp.iw = src_d.dims()[ndims - 1];
1281 
1282     jpp.od = is_3d ? dst_d.dims()[ndims - 3] : 1;
1283     jpp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2];
1284     jpp.ow = dst_d.dims()[ndims - 1];
1285 
1286     jpp.stride_d = is_3d ? pd.strides[ndims - 5] : 1;
1287     jpp.stride_h = is_1d ? 1 : pd.strides[ndims - 4];
1288     jpp.stride_w = pd.strides[ndims - 3];
1289 
1290     jpp.kd = is_3d ? pd.kernel[ndims - 5] : 1;
1291     jpp.kh = is_1d ? 1 : pd.kernel[ndims - 4];
1292     jpp.kw = pd.kernel[ndims - 3];
1293 
1294     jpp.f_pad = is_3d ? pd.padding[0][ndims - 5] : 0;
1295     jpp.t_pad = is_1d ? 0 : pd.padding[0][ndims - 4];
1296     jpp.l_pad = pd.padding[0][ndims - 3];
1297 
1298     int back_pad = calculate_end_padding(
1299             jpp.f_pad, jpp.od, jpp.id, jpp.stride_d, jpp.kd);
1300     int bottom_pad = calculate_end_padding(
1301             jpp.t_pad, jpp.oh, jpp.ih, jpp.stride_h, jpp.kh);
1302     int right_pad = calculate_end_padding(
1303             jpp.l_pad, jpp.ow, jpp.iw, jpp.stride_w, jpp.kw);
1304 
1305     if (jpp.f_pad >= jpp.kd || jpp.t_pad >= jpp.kh || jpp.l_pad >= jpp.kw
1306             || back_pad >= jpp.kd || bottom_pad >= jpp.kh
1307             || right_pad >= jpp.kw)
1308         return status::unimplemented;
1309 
1310     jpp.alg = pd.alg_kind;
1311 
1312     jpp.src_dt = pd.src_desc.data_type;
1313     jpp.dst_dt = pd.dst_desc.data_type;
1314 
1315     // data_type items per one vreg on the <isa>
1316     //     isa == sse41   : 16 bytes -> 16 for s8/u8, 4 for s32
1317     //     isa == avx2    : 32 bytes -> 32 for s8/u8, 8 for s32
1318     //     isa == avx512* : 64 bytes -> 64 for s8/u8, 16 for s32
1319     int simd_w = cpu_isa_traits<isa>::vlen / data_type_size(jpp.src_dt);
1320 
1321     /* Verify that vlen-sized memory access happens within the tensor's
1322      * size, otherwise load/store will always spill outside the memory
1323      * boundary.*/
1324     bool safe_load_n_store = IMPLICATION(utils::one_of(isa, avx2, sse41),
1325             jpp.mb * jpp.c * nstl::min(jpp.id, jpp.od)
1326                             * nstl::min(jpp.ih, jpp.oh)
1327                             * nstl::min(jpp.iw, jpp.ow)
1328                     >= simd_w);
1329     if (!safe_load_n_store) return status::unimplemented;
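    // For instance, avx2 with an s8 source and mb == 1, c == 4 and 1x1x1
    // spatial dims gives only 4 elements in total, smaller than the 32-byte
    // vector, so the implementation is rejected here.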
1330 
1331     jpp.c_block = simd_w;
1332     jpp.c_tail = jpp.c % jpp.c_block;
1333     jpp.nb_c = jpp.c / jpp.c_block;
1334     jpp.ur_c = 1;
1335     jpp.ur_c_tail = jpp.c_tail != 0;
1336 
1337     size_t tail_mask = (1ULL << jpp.c_tail) - 1;
1338 
1339     /* If channel_size is bigger than vlen, we can safely assume there is no
1340      * underflow of memory boundary, so always perform c_tail and save
1341      * a couple of compute cycles*/
1342     jpp.safe_c_tail = jpp.c_tail > 0 && jpp.c >= simd_w;
1343 
1344     switch (jpp.alg) {
1345         case pooling_max:
1346             jpp.tail[0] = tail_mask;
1347             jpp.tail[1] = 0;
1348             jpp.tail[2] = 0;
1349             jpp.tail[3] = 0;
1350             break;
1351         case pooling_avg_include_padding:
1352         case pooling_avg_exclude_padding: {
1353             // avg_proc_dt (s32) defines granularity (because u8/s8 processed as s32)
1354             // sse : 4, avx2 : 8, avx512 : 16
1355             const size_t msk_gran
1356                     = cpu_isa_traits<isa>::vlen / data_type_size(avg_proc_dt);
1357             const size_t msk_msk = (1ULL << msk_gran) - 1;
1358             size_t m = tail_mask;
1359             for (size_t ll = 0; ll < max_num_ll; ll++) {
1360                 jpp.tail[ll] = m & msk_msk;
1361                 m = m >> msk_gran;
1362             }
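            // Example: avx2 with c_tail == 21 gives tail_mask == 0x1fffff and
            // msk_gran == 8, so jpp.tail == {0xff, 0xff, 0x1f, 0x0}.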
1363             break;
1364         }
1365         default: return status::unimplemented;
1366     }
1367 
1368     if (!post_ops_ok(jpp, *ppd->attr(), dst_d)) return status::unimplemented;
1369 
1370     return status::success;
1371 }
1372 
1373 template <cpu_isa_t isa>
bool jit_uni_i8i8_pooling_fwd_ker_t<isa>::post_ops_ok(jit_pool_conf_t &jpp,
        const primitive_attr_t &attr, const memory_desc_wrapper &dst_d) {
1376     const auto &post_ops = attr.post_ops_;
1377     const auto &entries = post_ops.entry_;
1378     jpp.with_postops = false;
1379     jpp.with_eltwise = false;
1380     jpp.with_binary = false;
1381 
1382     if (entries.empty()) return true;
1383 
1384     for (const auto &entry : entries) {
1385         if (entry.is_eltwise()) {
1386             const auto alg = entry.eltwise.alg;
1387             jpp.with_eltwise = eltwise_injector::is_supported(isa, alg);
1388         } else if (entry.is_binary()) {
1389             if (isa != avx512_core
1390                     && entry.binary.src1_desc.data_type == data_type::bf16)
1391                 return false;
1392             jpp.with_binary = true;
1393         } else
1394             return false;
1395     }
1396 
1397     jpp.with_postops = jpp.with_eltwise || jpp.with_binary;
1398     jpp.post_ops = post_ops;
1399 
1400     /*
     * TODO Currently the eltwise/binary injectors assume that the data in vmm is f32.
     * In max pooling the data remains in the i8 data type.
1403      */
1404     return IMPLICATION(jpp.with_postops, jpp.alg != pooling_max)
1405             && binary_injector::binary_args_broadcast_supported(
1406                     post_ops, dst_d, get_supported_bcast_strategies());
1407 }
1408 
1409 template <cpu_isa_t isa>
status_t jit_uni_i8i8_pooling_fwd_t<isa>::pd_t::jit_conf() {
1411     return jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_conf(jpp_, this);
1412 }
1413 
1414 template <cpu_isa_t isa>
jit_uni_i8i8_pooling_fwd_t<isa>::jit_uni_i8i8_pooling_fwd_t(const pd_t *apd)
1416     : primitive_t(apd), ker_(nullptr) {}
1417 
1418 template <cpu_isa_t isa>
1419 jit_uni_i8i8_pooling_fwd_t<isa>::~jit_uni_i8i8_pooling_fwd_t() = default;
1420 
1421 template <cpu_isa_t isa>
status_t jit_uni_i8i8_pooling_fwd_t<isa>::init(engine_t *engine) {
1423     CHECK(safe_ptr_assign(ker_,
1424             new jit_uni_i8i8_pooling_fwd_ker_t<isa>(
1425                     pd()->jpp_, pd()->invariant_dst_md())));
1426     return ker_->create_kernel();
1427 }
1428 
1429 template <cpu_isa_t isa>
status_t jit_uni_i8i8_pooling_fwd_t<isa>::execute_forward(
        const exec_ctx_t &ctx) const {
1432     auto src_i8 = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
1433     auto dst_i8 = CTX_OUT_MEM(char *, DNNL_ARG_DST);
1434 
1435     const memory_desc_wrapper src_d(pd()->src_md());
1436     const memory_desc_wrapper dst_d(pd()->dst_md());
1437 
1438     const auto &jpp = pd()->jpp_;
1439     const auto post_ops_binary_rhs_arg_vec
1440             = binary_injector::prepare_binary_args(jpp.post_ops, ctx);
    /* Compute the last addresses at which a full vector access still stays
     * within the memory boundary; the kernel compares against these to pick
     * a safe access path. */
1443     const auto src_safe_access = reinterpret_cast<char *>(
1444             reinterpret_cast<ptrdiff_t>(src_i8 + src_d.size() - 1)
1445             - (cpu_isa_traits<isa>::vlen - 1));
1446 
1447     const auto dst_safe_access = reinterpret_cast<char *>(
1448             reinterpret_cast<ptrdiff_t>(dst_i8 + dst_d.size() - 1)
1449             - (cpu_isa_traits<isa>::vlen - 1));
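    // Invariant: a full vlen-byte access at address p stays in bounds iff
    // p <= *_safe_access, which is what the jitted kernel checks before
    // taking the unmasked fast path.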
1450 
1451     parallel_nd(
1452             jpp.mb, jpp.od, jpp.oh, jpp.ow, [&](int n, int od, int oh, int ow) {
1453                 const int id = nstl::max(od * jpp.stride_d - jpp.f_pad, 0);
1454                 const int ih = nstl::max(oh * jpp.stride_h - jpp.t_pad, 0);
1455                 const int iw = nstl::max(ow * jpp.stride_w - jpp.l_pad, 0);
1456 
1457                 const int kd_start
1458                         = nstl::max(0, jpp.f_pad - od * jpp.stride_d);
1459                 const int kd_end = nstl::min(
1460                         jpp.kd, jpp.id + jpp.f_pad - od * jpp.stride_d);
1461                 const int kh_start
1462                         = nstl::max(0, jpp.t_pad - oh * jpp.stride_h);
1463                 const int kh_end = nstl::min(
1464                         jpp.kh, jpp.ih + jpp.t_pad - oh * jpp.stride_h);
1465                 const int kw_start
1466                         = nstl::max(0, jpp.l_pad - ow * jpp.stride_w);
1467                 const int kw_end = nstl::min(
1468                         jpp.kw, jpp.iw + jpp.l_pad - ow * jpp.stride_w);
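                // E.g. at the left edge with l_pad == 1 and stride_w == 1,
                // ow == 0 gives kw_start == 1 and iw == 0: the padded column
                // is skipped and the window starts at the first real pixel.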
1469 
1470                 auto p = call_params_t();
1471                 p.src_i8 = &src_i8[get_offset(src_d, n, 0, id, ih, iw)
1472                         * src_d.data_type_size()];
1473                 p.dst_i8 = &dst_i8[get_offset(dst_d, n, 0, od, oh, ow)
1474                         * dst_d.data_type_size()];
1475                 p.kd_range = (size_t)(kd_end - kd_start);
1476                 p.kh_range = (size_t)(kh_end - kh_start);
1477                 p.kw_range = (size_t)(kw_end - kw_start);
1478                 p.idivider = 1.0f
1479                         / ((jpp.alg == pooling_avg_exclude_padding)
1480                                         ? p.kd_range * p.kh_range * p.kw_range
1481                                         : jpp.kd * jpp.kh * jpp.kw);
1482                 p.src_safe_access = src_safe_access;
1483                 p.dst_safe_access = dst_safe_access;
1484                 p.post_ops_binary_rhs_arg_vec
1485                         = post_ops_binary_rhs_arg_vec.data();
1486                 (*ker_)(&p);
1487             });
1488     return status::success;
1489 }
1490 
1491 // Explicit instantiation only for supported <isa> values.
1492 //
1493 template struct jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>;
1494 template struct jit_uni_i8i8_pooling_fwd_t<avx512_core>;
1495 
1496 template struct jit_uni_i8i8_pooling_fwd_ker_t<avx2>;
1497 template struct jit_uni_i8i8_pooling_fwd_t<avx2>;
1498 
1499 template struct jit_uni_i8i8_pooling_fwd_ker_t<sse41>;
1500 template struct jit_uni_i8i8_pooling_fwd_t<sse41>;
1501 
1502 } // namespace x64
1503 } // namespace cpu
1504 } // namespace impl
1505 } // namespace dnnl
1506