/*******************************************************************************
* Copyright 2017-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "cpu/x64/jit_uni_i8i8_pooling.hpp"
#include <math.h>

#include "common/dnnl_thread.hpp"
#include "common/utils.hpp"

#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_generator.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

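// Note (ours, for orientation): the binary post-op injector below is
// configured to accept rhs operands broadcast either as a single scalar or
// per output channel; see post_ops_ok() and the rhs_arg_static_params_t
// setup in the kernel constructor.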
static bcast_set_t get_supported_bcast_strategies() {
    return {broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc};
}

static inline dim_t get_offset(
        const memory_desc_wrapper &mdw, int n, int c, int d, int h, int w) {
    switch (mdw.ndims()) {
        case 3: return mdw.blk_off(n, c, w);
        case 4: return mdw.blk_off(n, c, h, w);
        case 5: return mdw.blk_off(n, c, d, h, w);
        default: assert(!"Invalid tensor dimension in pooling");
    }
    return 0;
}

using namespace Xbyak;

using namespace dnnl::impl::utils;
using namespace dnnl::impl::types;
using namespace alg_kind;

#define GET_OFF(field) offsetof(call_params_t, field)

struct call_params_t {
    const char *src_i8;
    const char *dst_i8;
    const void *post_ops_binary_rhs_arg_vec;
    size_t kd_range;
    size_t kh_range;
    size_t kw_range;
    float idivider;
    const char *src_safe_access;
    const char *dst_safe_access;
};

template <cpu_isa_t isa>
struct jit_uni_i8i8_pooling_fwd_ker_t : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_i8i8_pooling_fwd_ker_t)

    using Vmm = typename cpu_isa_traits<isa>::Vmm;
    Xmm xreg(int idx) const { return Xmm(idx); }
    Ymm yreg(int idx) const { return Ymm(xreg(idx).getIdx()); }
    Vmm vreg(int idx) const { return Vmm(xreg(idx).getIdx()); }

    // In case of avx2 with data type i8 we need to use the maskmovdqu and
    // maskmovq instructions, which have their destination hardcoded in rdi.
    // Windows ABI: abi_param1 is rcx - nothing else to do
    // Unix ABI: abi_param1 is rdi - copy it to rcx and use it as abi_param1
    Reg64 reg_param = rcx; // Our "unified abi_param1"
    Reg64 reg_ptr_src_i8 = r8;
    Reg64 reg_ptr_dst_i8 = r9;
    Reg64 reg_ptr_maskmovdqu_dst = rdi; // store destination - must be rdi

    Reg64 reg_kd_index
            = rdi; // shared with reg_ptr_maskmovdqu_dst; only used before store
    Reg64 reg_kh_index = r11;
    Reg64 reg_kw_index = r10;
    Reg64 reg_kd = r14;
    Reg64 reg_kh = r13;
    Reg64 reg_kw = r12;
    Reg64 c_iter = r15; // shared with reg_mask; only used after mask init

    Reg64 aux_reg_src_d
            = rdx; // shared with reg_tmp; loaded before each accum loop, unused during store
    Reg64 aux_reg_src_h = rax;
    Reg64 aux_reg_src_w = rbx;

    Reg64 reg_tmp = rdx; // only used during mask init and store
    Reg64 reg_src_safe_access = rbp;
    Reg64 reg_dst_safe_access = rsi;

    Reg64 reg_mask = r15; // only used during mask init

    Opmask k_cmp_mask = Opmask(7);

    Opmask mask(int idx) { return Opmask(6 - idx); }

    // ref to any of XYZ-regs via xreg/yreg/vreg functions
    Xmm xmm_tmp = xreg(0); // temp to init vreg_tmp
    Vmm vreg_tmp = vreg(0); // max pooling : holds minimum values for data_type
    Vmm vreg_zeros = vreg(1);
    Vmm vreg_tail = vreg(4);

    // only in case of <isa> == avx2
    Vmm vreg_mask = vreg(2); // full byte-mask
    Xmm xreg_mask_lo = xreg(
            2); // low 128-bits part of byte-mask (alias for xmm part of vreg_mask)
    Xmm xreg_mask_hi = xreg(
            3); // "max" - high 128-bits part of byte-mask (stored separately)

    // vreg_mask shifted left (aligned left) to be used in tail processing.
    // Example: idx [31..0]
    //     vreg_mask = [0,0,0,0,0,.....,0,x,x,x,x,x] ; x => byte mask (msb set)
    //     vreg_mask_2 = [x,x,x,x,x,0,0,0,0,0,.....,0]
    Vmm vreg_mask_2 = vreg(5);
    Xmm xreg_mask_2_lo = xreg(5); // similar to xreg_mask_lo
    Xmm xreg_mask_2_hi = xreg(6); // similar to xreg_mask_hi

    Vmm vreg_mask_q = vreg(3); // "avg" - 1/4 part for non-zero tails
    Mmx mmx_dst_i8 = Mmx(
            0); // "avg" - Mmx reg for masked store results of s8/u8 operations
    Mmx mmx_full_msk = Mmx(
            1); // "avg" - Mmx reg for full mask (all 8 bytes) - used until not in tail
    Mmx mmx_tmp = Mmx(2);
    int post_op_tail_opmask_idx_ = -1;
    jit_pool_conf_t jpp;
    std::unique_ptr<injector::jit_uni_postops_injector_t<isa>>
            postops_injector_;

    enum : int { max_vidx_base = utils::one_of(isa, sse41, avx2) ? 7 : 2 };
    // "avg" pool uses more registers for unrolling.
    enum : int { avg_vidx_base = utils::one_of(isa, sse41, avx2) ? 4 : 2 };

    Vmm max_base_vr(int idx) const { return vreg(max_vidx_base + idx); }
    Vmm avg_base_vr(int idx) const { return vreg(avg_vidx_base + idx); }

    size_t sizeof_src_dt() const { return data_type_size(jpp.src_dt); }
    size_t sizeof_dst_dt() const { return data_type_size(jpp.dst_dt); }

    /* max pooling */
    Vmm vreg_src(int idx) const { return max_base_vr(idx); } // [0 .. ur_c-1]
    Vmm vreg_dst(int idx) const {
        return max_base_vr(jpp.ur_c + idx);
    } // [ur_c .. 2*ur_c-1]

    // s32 is used for processing s8/u8 data,
    // so we need to take into account the size ratio s32/i8 = 4
    static constexpr data_type_t avg_proc_dt = data_type::s32;
    enum : int {
        s32_to_i8_ratio = sizeof(typename prec_traits<avg_proc_dt>::type)
                / sizeof(typename prec_traits<data_type::u8>::type),
        max_num_ll = s32_to_i8_ratio,
        mmx_msk_base_reg = 3
    };

    Vmm vreg_src_s32(int jj, int ll) {
        return avg_base_vr(3 * max_num_ll * jj + ll + 0 * max_num_ll);
    } // ll: 0..4 [0..3]

    Vmm vreg_dst_s32(int jj, int ll) {
        return avg_base_vr(3 * max_num_ll * jj + ll + 1 * max_num_ll);
    } // ll: 0..4 [4..7]

    Vmm vreg_dst_f32(int jj, int ll) {
        return avg_base_vr(3 * max_num_ll * jj + ll + 2 * max_num_ll);
    } // ll: 0..4 [8..11]

    Mmx mmx_mask(int ll) {
        return Mmx(mmx_msk_base_reg + ll);
    }; // ll: 0..4 [Mmx(3)...Mmx(6)]
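
    // Worked example (ours, for illustration only): with avx2 and "avg"
    // pooling, avg_vidx_base = 4 and max_num_ll = 4, so for jj = 0 the
    // accessors above resolve to
    //     vreg_src_s32(0, ll) -> vreg(4)..vreg(7)
    //     vreg_dst_s32(0, ll) -> vreg(8)..vreg(11)
    //     vreg_dst_f32(0, ll) -> vreg(12)..vreg(15)
    // which, together with vreg(0)..vreg(3) above, uses all 16 ymm registers.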

    static bool post_ops_ok(jit_pool_conf_t &jpp, const primitive_attr_t &attr,
            const memory_desc_wrapper &dst_d);

    void init_tmp_reg();
    void init_mask();

    void load_vreg_mask_q(int ll) {};

    void load_src_max_op(
            int jj, int ll, size_t offset, bool masked, uint64_t msk);
    void load_src_avg_op(
            int jj, int ll, size_t offset, bool masked, uint64_t msk);
    void load_src(int jj, int ll, int c_tail);

    void store_dst_max_op(
            int jj, int ll, size_t offset, bool masked, uint64_t msk);
    void store_dst_avg_op(
            int jj, int ll, size_t offset, bool masked, uint64_t msk);
    void store_dst(int jj, int ll, int c_tail);

    void compute_avg_step(int ur_c, int c_tail);
    void compute_max_op(const int jj);
    void compute_max_step(int ur_c, int c_tail);
    void compute_step(int ur_c, int c_tail);

    void compute_c_block();
    void generate() override;

    static status_t init_conf(jit_pool_conf_t &jpp, const pooling_pd_t *ppd);

    jit_uni_i8i8_pooling_fwd_ker_t(
            const jit_pool_conf_t &jpp_, const memory_desc_t *dst_md)
        : jit_generator(nullptr, MAX_CODE_SIZE, true, isa)
        , jpp(jpp_)
        , postops_injector_(nullptr) {

        if (jpp.with_postops) {

            const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
            const std::size_t c_tail_elems = jpp.c % simd_w;
            post_op_tail_opmask_idx_ = 0;
            if (c_tail_elems) {
                for (int ll = max_num_ll - 1; ll >= 0; ll--) {
                    if (jpp.tail[ll] != 0) {
                        post_op_tail_opmask_idx_ = ll;
                        break;
                    }
                }
            };

            static constexpr bool preserve_gpr = true;
            static constexpr bool preserve_vmm = true;
            static constexpr bool use_exact_tail_scalar_bcast = false;
            static constexpr std::size_t tmp_vmm_injector = 0u;

            const binary_injector::rhs_arg_static_params_t rhs_sp {
                    tmp_vmm_injector, rax, r14, preserve_gpr, preserve_vmm,
                    GET_OFF(post_ops_binary_rhs_arg_vec),
                    memory_desc_wrapper(*dst_md), c_tail_elems,
                    mask(post_op_tail_opmask_idx_),
                    use_exact_tail_scalar_bcast};
            const binary_injector::static_params_t bsp {
                    reg_param, get_supported_bcast_strategies(), rhs_sp};

            postops_injector_ = utils::make_unique<
                    injector::jit_uni_postops_injector_t<isa>>(
                    this, jpp.post_ops, bsp);
        }
    }
};

template <>
void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::load_vreg_mask_q(int ll) {};

template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_vreg_mask_q(int ll) {

    // extract ll-th part of mask (ll-th QWORD)
    vpblendd(vreg_mask_q, vreg_zeros, vreg_mask,
            0x3 << 2 * ll); // 0x3 - mask for 2 x DWORD

    // Move mask from ll-th pos to 0-th pos
    if (ll > 0) vpermq(vreg_mask_q, vreg_mask_q, ll);
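
    // Illustrative note (ours): vpermq's imm8 selects a source qword per
    // 2-bit field, so imm8 = ll copies qword `ll` into position 0; the
    // remaining lanes replicate qword 0, which the blend above left zeroed
    // for ll > 0 (e.g. ll = 2 -> imm8 = 0x02).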
};

template <>
void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::load_src_max_op(
        int jj, int ll, size_t offset, bool masked, uint64_t msk) {
    using namespace data_type;

    if (masked) {
        if (jpp.src_dt == s32)
            for (int64_t i = 0; i < jpp.c_tail; i++)
                pinsrd(vreg_src(jj),
                        ptr[aux_reg_src_w + offset + i * data_type_size(s32)],
                        i);
        else
            for (int i = 0; i < jpp.c_tail; i++)
                pinsrb(vreg_src(jj), ptr[aux_reg_src_w + offset + i], i);
    } else
        movups(vreg_src(jj), ptr[aux_reg_src_w + offset]);
}

template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_src_max_op(
        int jj, int ll, size_t offset, bool masked, uint64_t msk) {
    using namespace data_type;

    if (masked) {
        if (jpp.src_dt == s32) {
            vpmaskmovd(vreg_src(jj), vreg_mask, ptr[aux_reg_src_w + offset]);
        } else {
            // Steps to access 'tail' section:
            // 1) First load all data from the shifted src ptr
            // 2) Now bring the required data from the end of reg to beginning.
            // Example: idx=[31..0]
            //     vreg_src = [x,x,x,x,.....,x,-,-,-,-,-] ; x => byte data
            //     shift to transform vreg_src = [-,-,-,-,-,x,..,x,x,x,x,]
            const uint8_t shift = cpu_isa_traits<avx2>::vlen - jpp.c_tail;

            if (jpp.safe_c_tail) {

                /* load src_tail at 'src_address - shift' so that it does not
                 * spill over the memory boundary */
                vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset - shift]);

                vperm2i128(vreg_tmp, vreg_src(jj), vreg_src(jj), 0x81);
                vpalignr(vreg_src(jj), vreg_tmp, vreg_src(jj), shift);
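
                // Note (ours): the vperm2i128/vpalignr pair above implements
                // a full 256-bit right shift by `shift` bytes (0x81 moves the
                // upper 128 bits into the lower lane and zeroes the upper
                // lane), so the c_tail bytes land at the bottom of vreg_src.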

            } else {
                Label load_data_safely, done;
                add(aux_reg_src_w, offset);

                // Check if mask crosses page boundary
                cmp(aux_reg_src_w, reg_src_safe_access);
                ja(load_data_safely, T_NEAR);

                vpblendvb(
                        vreg_src(jj), vreg_tmp, byte[aux_reg_src_w], vreg_mask);
                jmp(done, T_NEAR);

                L(load_data_safely);

                /* load src_tail at 'src_address - shift' so that it does not
                 * spill over the memory boundary */
                vmovups(vreg_src(jj), ptr[aux_reg_src_w - shift]);

                vperm2i128(vreg_tmp, vreg_src(jj), vreg_src(jj), 0x81);
                vpalignr(vreg_src(jj), vreg_tmp, vreg_src(jj), shift);

                L(done);
                sub(aux_reg_src_w, offset);
            }
        }

    } else
        vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]);
};

template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::load_src_max_op(
        int jj, int ll, size_t offset, bool masked, uint64_t msk) {
    using namespace data_type;

    if (masked) {
        if (jpp.src_dt == s32)
            vmovups(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]);
        else
            vmovdqu8(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]);
    } else
        vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]);
};

template <>
void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::load_src_avg_op(
        int jj, int ll, size_t offset, bool masked, uint64_t msk) {
    using namespace data_type;

    const Vmm &vr_src = vreg_src_s32(jj, ll);

    if (jpp.src_dt == s32) {
        if (masked)
            for (int64_t i = 0; i < jpp.c_tail; i++)
                pinsrd(vr_src,
                        ptr[aux_reg_src_w + offset + i * data_type_size(s32)],
                        i);
        else
            movups(vr_src, ptr[aux_reg_src_w + offset]);
    } else if (utils::one_of(jpp.src_dt, s8, u8)) {
        if (masked) {
            const int copy_range = math::ilog2q(jpp.tail[ll] + 1);
            for (int i = 0; i < copy_range; i++)
                pinsrb(vr_src, ptr[aux_reg_src_w + offset + i], i);

            if (jpp.src_dt == s8)
                pmovsxbd(vr_src, vr_src);
            else
                pmovzxbd(vr_src, vr_src);
        } else {
            if (jpp.src_dt == s8)
                pmovsxbd(vr_src, ptr[aux_reg_src_w + offset]);
            else
                pmovzxbd(vr_src, ptr[aux_reg_src_w + offset]);
        }
    } else
        assert(!"unsupported src data type");
}

template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_src_avg_op(
        int jj, int ll, size_t offset, bool masked, uint64_t msk) {
    using namespace data_type;

    auto load_i8 = [&](bool is_signed, const Vmm &vr_src) {
        // Need to use mask of tail?
        if (masked) {

            // load ll-th part of mask into vreg_mask_q
            load_vreg_mask_q(ll);

            // Steps to access 'tail' section:
            // 1) First load all data from the shifted src ptr
            // 2) Now bring the required data from the end of reg to beginning.
            // Example: idx=[31..0]
            //     vreg_src = [x,x,x,x,.....,x,-,-,-,-,-] ; x => byte data
            //     shift to transform vreg_src = [-,-,-,-,-,x,..,x,x,x,x,]
            // Re-purposing vreg_zeros here. Set it back to zero immediately.
            const int msk_gran
                    = cpu_isa_traits<avx2>::vlen / data_type_size(avg_proc_dt);

            const uint8_t shift = cpu_isa_traits<avx2>::vlen
                    - (jpp.c_tail > (ll + 1) * msk_gran
                                    ? msk_gran
                                    : jpp.c_tail - (ll * msk_gran));
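            // Worked example (ours, for illustration only): msk_gran = 8 on
            // avx2, so for c_tail = 11 and ll = 1 this chunk holds only
            // 11 - 8 = 3 valid elements and shift = 32 - 3 = 29 bytes.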
            if (jpp.safe_c_tail) {
                /* load src_tail at 'src_address - shift' so that it does not
                 * spill over the memory boundary */
                vmovups(vr_src, ptr[aux_reg_src_w + offset - shift]);

                vperm2i128(vreg_zeros, vr_src, vr_src, 0x81);
                vpalignr(vr_src, vreg_zeros, vr_src, shift);
                uni_vpxor(vreg_zeros, vreg_zeros, vreg_zeros);
            } else {
                Label load_data_safely, done;
                // assume that it is not safe to load the src_tail

                add(aux_reg_src_w, offset);

                // Check if load crosses the memory boundary
                cmp(aux_reg_src_w, reg_src_safe_access);
                ja(load_data_safely, T_NEAR);

                vpblendvb(vr_src, vreg_zeros, ptr[aux_reg_src_w], vreg_mask_q);
                jmp(done, T_NEAR);

                L(load_data_safely);

                /* load src_tail at 'src_address - shift' so that it does not
                 * spill over the memory boundary */
                vmovups(vr_src, ptr[aux_reg_src_w - shift]);

                vperm2i128(vreg_zeros, vr_src, vr_src, 0x81);
                vpalignr(vr_src, vreg_zeros, vr_src, shift);
                uni_vpxor(vreg_zeros, vreg_zeros, vreg_zeros);

                L(done);
                sub(aux_reg_src_w, offset);
            }

            // Conversion s8/u8 -> s32
            if (is_signed)
                vpmovsxbd(vr_src, vr_src);
            else
                vpmovzxbd(vr_src, vr_src);
        } else {

            // Load from mem into vr_src with conversion
            if (is_signed)
                vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]);
            else
                vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]);
        }
    };

    switch (jpp.src_dt) {
        case s32:
            if (masked)
                vpmaskmovd(vreg_src_s32(jj, ll), vreg_mask,
                        ptr[aux_reg_src_w + offset]);
            else
                vmovups(vreg_src_s32(jj, ll), ptr[aux_reg_src_w + offset]);
            break;
        case s8: load_i8(true, vreg_src_s32(jj, ll)); break;
        case u8: load_i8(false, vreg_src_s32(jj, ll)); break;
        default: assert(!"unsupported src data type");
    }
};

template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::load_src_avg_op(
        int jj, int ll, size_t offset, bool masked, uint64_t msk) {
    using namespace data_type;

    const Vmm &vr_src
            = masked ? vreg_src_s32(jj, ll) | mask(ll) : vreg_src_s32(jj, ll);

    switch (jpp.src_dt) {
        case s32: vmovups(vr_src, ptr[aux_reg_src_w + offset]); break;
        case s8: vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]); break;
        case u8: vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]); break;
        default: assert(!"unsupported src data type");
    }
};

template <cpu_isa_t isa>
void jit_uni_i8i8_pooling_fwd_ker_t<isa>::load_src(int jj, int ll, int c_tail) {
    using namespace data_type;

    int c_block = jpp.c_block;
    int ur_c = jpp.ur_c;

    switch (jpp.alg) {
        case pooling_max: {
            auto offset = jj * c_block * sizeof_src_dt();
            bool masked = jj == ur_c - 1 && c_tail;
            load_src_max_op(jj, ll, offset, masked, jpp.tail[0]);
            break;
        }
        case pooling_avg_include_padding:
        case pooling_avg_exclude_padding: {
            auto offset = (ll * (c_block / max_num_ll) + jj * c_block)
                    * sizeof_src_dt();
            bool masked = jj == ur_c - 1 && c_tail;
            load_src_avg_op(jj, ll, offset, masked, jpp.tail[ll]);
            break;
        }
        default: assert(!"unsupported algorithm");
    }
}

template <>
void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::store_dst_max_op(
        int jj, int ll, size_t offset, bool masked, uint64_t msk) {
    using namespace data_type;

    if (masked) {
        if (jpp.src_dt == s32)
            for (int i = 0; i < jpp.c_tail; i++)
                pextrd(ptr[reg_ptr_dst_i8 + offset + i * data_type_size(s32)],
                        vreg_dst(jj), i);
        else if (utils::one_of(jpp.src_dt, u8, s8))
            for (int i = 0; i < jpp.c_tail; i++)
                pextrb(ptr[reg_ptr_dst_i8 + offset + i], vreg_dst(jj), i);
        else
            assert(!"unsupported src data type");
    } else
        movups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj));
}

template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::store_dst_max_op(
        int jj, int ll, size_t offset, bool masked, uint64_t msk) {
    using namespace data_type;

    Label store_data_safely, done;

    int c_block = jpp.c_block;

    const uint64_t low_mask = (1ULL << (c_block / 2)) - 1;
    const uint8_t shift = cpu_isa_traits<avx2>::vlen - jpp.c_tail;
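    // For illustration (ours): with avx2 and s8/u8, c_block = 32, so
    // low_mask = 0xffff covers bytes 0..15 and `shift` is the distance of
    // the c_tail bytes from the top of the 32-byte register.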

    if (masked) {
        switch (jpp.src_dt) {
            case s32:
                vpmaskmovd(
                        ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst(jj));
                break;
            case s8:
            case u8: {

                lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]);

                if (!jpp.safe_c_tail) {
                    Xmm xreg_dst = Xmm(vreg_dst(jj).getIdx());

                    cmp(reg_ptr_maskmovdqu_dst, reg_dst_safe_access);
                    ja(store_data_safely, T_NEAR);

                    // Store low half by mask (bytes 0...15)
                    vmaskmovdqu(xreg_dst, xreg_mask_lo);

                    // Do we need to store high half (bytes 16...31) ?
                    if (msk & ~low_mask) {
                        vextracti128(xreg_dst, vreg_dst(jj), 1);
                        add(reg_ptr_maskmovdqu_dst, c_block / 2);
                        vmaskmovdqu(xreg_dst, xreg_mask_hi);
                    }
                    jmp(done, T_NEAR);
                }

                L(store_data_safely);

                vperm2i128(vreg_tail, vreg_dst(jj), vreg_dst(jj), 0x08);
                if (shift <= 16) {
                    vpalignr(vreg_tail, vreg_dst(jj), vreg_tail, 16 - shift);
                } else {
                    vpalignr(vreg_tail, vreg_tail, vreg_zeros, 32 - shift);
                }

                Xmm xreg_tail = Xmm(vreg_tail.getIdx());
                // Do we need to store low half (bytes 0...15) ?
                if (msk & ~low_mask) {
                    sub(reg_ptr_maskmovdqu_dst, shift);
                    vmaskmovdqu(xreg_tail, xreg_mask_2_lo);
                    add(reg_ptr_maskmovdqu_dst, c_block / 2);
                } else {
                    add(reg_ptr_maskmovdqu_dst, (c_block / 2) - shift);
                }

                // Store high half by mask (bytes 16..31)
                vextracti128(xreg_tail, vreg_tail, 1);
                vmaskmovdqu(xreg_tail, xreg_mask_2_hi);

                L(done);
            } break;
            default: assert(!"unsupported src data type");
        }
    } else
        vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj));
}

template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::store_dst_max_op(
        int jj, int ll, size_t offset, bool masked, uint64_t msk) {
    using namespace data_type;

    if (masked) {
        switch (jpp.src_dt) {
            case s32:
                vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0));
                break;
            case s8:
            case u8:
                vmovdqu8(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0));
                break;
            default: assert(!"unsupported src data type");
        }
    } else
        vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj));
}

template <>
void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::store_dst_avg_op(
        int jj, int ll, size_t offset, bool masked, uint64_t msk) {
    using namespace data_type;

    // Don't generate useless code
    if (masked && !msk) return;

    const Vmm &vr_dst = vreg_dst_s32(jj, ll);

    if (jpp.src_dt == s32) {
        if (masked)
            for (int i = 0; i < jpp.c_tail; i++)
                pextrd(ptr[reg_ptr_dst_i8 + offset + i * data_type_size(s32)],
                        vr_dst, i);
        else
            movups(ptr[reg_ptr_dst_i8 + offset], vr_dst);
    } else if (utils::one_of(jpp.src_dt, s8, u8)) {
        packssdw(vr_dst, vr_dst);
        if (jpp.src_dt == s8)
            packsswb(vr_dst, vr_dst);
        else
            packuswb(vr_dst, vr_dst);

        const int copy_range = masked
                ? math::ilog2q(jpp.tail[ll] + 1)
                : cpu_isa_traits<sse41>::vlen / data_type_size(avg_proc_dt);
        for (int i = 0; i < copy_range; i++)
            pextrb(ptr[reg_ptr_dst_i8 + offset + i], vr_dst, i);
    } else
        assert(!"unsupported src data type");
}

template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::store_dst_avg_op(
        int jj, int ll, size_t offset, bool masked, uint64_t msk) {
    using namespace data_type;

    // Don't generate useless code
    if (masked && !msk) return;

    auto s32_to_i8 = [&](bool is_signed, const Vmm &vr_dst) {
        // conversion: s32 -> s16/u16 : {8 x s32}{8 x 0} -> {16 x s16/u16}
        // Result QWORDs (qw0, qw1) permuted: {qw0, 0, qw1, 0}
        if (is_signed)
            vpackssdw(vr_dst, vr_dst, vreg_zeros);
        else
            vpackusdw(vr_dst, vr_dst, vreg_zeros);

        // Permute qwords to restore original order
        // {qw0, 0, qw1, 0} -> {qw0, qw1, 0, 0}
        vpermq(vr_dst, vr_dst, 0x58);

        // conversion: s16/u16 -> s8/u8 : {16 x s16/u16}{16 x 0} -> {32 x s8/u8}
        // Target QWORD qw = {8 x s8/u8} has proper position: {qw, xx, xx, xx}
        if (is_signed)
            vpacksswb(vr_dst, vr_dst, vreg_zeros);
        else
            vpackuswb(vr_dst, vr_dst, vreg_zeros);
    };

    auto store_i8 = [&](bool is_signed, bool is_masked, const Vmm &vr_dst) {
        // Conversion s32 -> s8/u8
        s32_to_i8(is_signed, vr_dst);

        // early-out for non-masked cases
        if (!is_masked) {
            vmovlps(ptr[reg_ptr_dst_i8 + offset], Xmm(vr_dst.getIdx()));
            return;
        }
        // store 8 bytes
        lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]);

        // Need to use an mmx 8-byte operation to avoid memory violations.
        // NOTICE: it was discovered that the Intel SSE and Intel AVX
        //         instructions maskmovdqu/vmaskmovdqu with a low-8-bytes mask
        //         throw an exception if the high 8 bytes belong to a
        //         write-protected page.
        // NOTE: use an indirect move via gpr to avoid transition penalty
        vmovq(reg_tmp, Xmm(vr_dst.getIdx()));
        movq(mmx_dst_i8, reg_tmp);

        // mmx_full_msk - mask for all 8 bytes in zero-tail case
        // mmx_mask(ll) - ll-th mask of tail in non-zero-tail case

        const int msk_gran
                = cpu_isa_traits<avx2>::vlen / data_type_size(avg_proc_dt);

        const int ll_end = (ll + 1) * msk_gran; // ((ll + 1) * 8)

        if (is_masked && (ll_end > jpp.c_tail)) { // implies this tail is not full
            Label store_data_safely, done;
            const uint8_t shift = msk_gran - jpp.c_tail % msk_gran;

            if (!jpp.safe_c_tail) {
                cmp(reg_ptr_maskmovdqu_dst, reg_dst_safe_access);
                ja(store_data_safely, T_NEAR);

                /* store dst_tail with overlap outside the channel dimension,
                 * but assume it's within the memory boundary. */
                maskmovq(mmx_dst_i8, mmx_mask(ll));
                jmp(done, T_NEAR);
            }

            L(store_data_safely);

            /* store dst_tail at 'dst_address - shift' so that it does not
             * spill over the memory boundary */
            movq(mmx_tmp, mmx_mask(ll));
            psllq(mmx_tmp, shift * 8); // multiply with 8 (bits/byte)
            psllq(mmx_dst_i8, shift * 8);
            sub(reg_ptr_maskmovdqu_dst, shift);
            maskmovq(mmx_dst_i8, mmx_tmp);

            L(done);
        } else {
            maskmovq(mmx_dst_i8, mmx_full_msk);
        }
    };

    switch (jpp.dst_dt) {
        case s32:
            if (masked) {
                vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask,
                        vreg_dst_s32(jj, ll));
            } else
                vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst_s32(jj, ll));
            break;
        case s8: store_i8(true, masked, vreg_dst_s32(jj, ll)); break;
        case u8: store_i8(false, masked, vreg_dst_s32(jj, ll)); break;
        default: assert(!"unsupported dst data_type");
    }
}

template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::store_dst_avg_op(
        int jj, int ll, size_t offset, bool masked, uint64_t msk) {
    using namespace data_type;

    // Don't generate useless code
    if (masked && !msk) return;

    const Vmm &vr_dst
            = masked ? vreg_dst_s32(jj, ll) | mask(ll) : vreg_dst_s32(jj, ll);

    switch (jpp.dst_dt) {
        case s32: vmovups(ptr[reg_ptr_dst_i8 + offset], vr_dst); break;
        case s8: vpmovsdb(ptr[reg_ptr_dst_i8 + offset], vr_dst); break;
        case u8: vpmovusdb(ptr[reg_ptr_dst_i8 + offset], vr_dst); break;
        default: assert(!"unsupported dst data_type");
    }
}

template <cpu_isa_t isa>
void jit_uni_i8i8_pooling_fwd_ker_t<isa>::store_dst(
        int jj, int ll, int c_tail) {
    using namespace data_type;

    int c_block = jpp.c_block;
    int ur_c = jpp.ur_c;

    switch (jpp.alg) {
        case pooling_max: {
            auto offset = jj * c_block * sizeof_dst_dt();
            bool masked = jj == ur_c - 1 && c_tail;
            store_dst_max_op(jj, ll, offset, masked, jpp.tail[ll]);
            break;
        }
        case pooling_avg_include_padding:
        case pooling_avg_exclude_padding: {
            auto offset = (ll * (c_block / max_num_ll) + jj * c_block)
                    * sizeof_dst_dt();
            bool masked = jj == ur_c - 1 && c_tail;
            store_dst_avg_op(jj, ll, offset, masked, jpp.tail[ll]);
            break;
        }
        default: assert(!"unsupported pooling algorithm");
    }
}

template <>
void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::compute_max_op(const int jj) {
    using namespace data_type;
    switch (jpp.src_dt) {
        case s32: pmaxsd(vreg_dst(jj), vreg_src(jj)); break;
        case s8: pmaxsb(vreg_dst(jj), vreg_src(jj)); break;
        case u8: pmaxub(vreg_dst(jj), vreg_src(jj)); break;
        default: assert(!"unsupported src data type");
    }
}

template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::compute_max_op(const int jj) {
    using namespace data_type;
    switch (jpp.src_dt) {
        case s32: vpmaxsd(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); break;
        case s8: vpmaxsb(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); break;
        case u8: vpmaxub(vreg_dst(jj), vreg_dst(jj), vreg_src(jj)); break;
        default: assert(!"unsupported src data type");
    }
}

template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::compute_max_op(const int jj) {
    using namespace data_type;

    // Compare
    switch (jpp.src_dt) {
        case s32:
            vpcmpd(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os);
            break;
        case s8:
            vpcmpb(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os);
            break;
        case u8:
            vpcmpub(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os);
            break;
        default: assert(!"unsupported src data type");
    }

    // move max values into vreg_dst
    if (jpp.src_dt == s32)
        vpblendmd(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj));
    else
        vpblendmb(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj));
}

template <cpu_isa_t isa>
void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_max_step(
        int ur_c, int c_tail) {
    Label l_kd, l_kh, l_kw;

    int ih = jpp.ih;
    int iw = jpp.iw;
    int c = jpp.c;

    for (int jj = 0; jj < ur_c; jj++)
        uni_vmovups(vreg_dst(jj), vreg_tmp);

    mov(aux_reg_src_d, reg_ptr_src_i8);
    xor_(reg_kd_index, reg_kd_index);
    L(l_kd);
    {
        mov(aux_reg_src_h, aux_reg_src_d);
        xor_(reg_kh_index, reg_kh_index);
        L(l_kh);
        {
            mov(aux_reg_src_w, aux_reg_src_h);
            xor_(reg_kw_index, reg_kw_index);
            L(l_kw);
            {
                for (int jj = 0; jj < ur_c; jj++) {
                    load_src(jj, 0, c_tail);
                    compute_max_op(jj);
                }
                add(aux_reg_src_w, c * sizeof_src_dt());
                inc(reg_kw_index);
                cmp(reg_kw_index, reg_kw);
                jl(l_kw, T_NEAR);
            }
            add(aux_reg_src_h, iw * c * sizeof_src_dt());
            inc(reg_kh_index);
            cmp(reg_kh_index, reg_kh);
            jl(l_kh, T_NEAR);
        }
        add(aux_reg_src_d, ih * iw * c * sizeof_src_dt());
        inc(reg_kd_index);
        cmp(reg_kd_index, reg_kd);
        jl(l_kd, T_NEAR);
    }

    for (int jj = 0; jj < ur_c; jj++)
        store_dst(jj, 0, c_tail);
}

template <cpu_isa_t isa>
void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_avg_step(
        int ur_c, int c_tail) {
    using namespace data_type;

    Label l_kd, l_kh, l_kw;

    int ih = jpp.ih;
    int iw = jpp.iw;
    int c = jpp.c;

    const int num_ll = data_type_size(avg_proc_dt) / data_type_size(jpp.src_dt);

    for (int jj = 0; jj < ur_c; jj++) {
        for (int ll = 0; ll < num_ll; ll++) {
            bool masked = jj == ur_c - 1 && c_tail;
            size_t msk = jpp.tail[ll];
            if (!(masked && !msk)) {
                // Clearing of src regs is not needed as they are written
                // before being read
                uni_vpxor(vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll),
                        vreg_dst_s32(jj, ll));
            }
        }
    }

    mov(aux_reg_src_d, reg_ptr_src_i8);
    xor_(reg_kd_index, reg_kd_index);
    L(l_kd);
    {
        mov(aux_reg_src_h, aux_reg_src_d);
        xor_(reg_kh_index, reg_kh_index);
        L(l_kh);
        {
            mov(aux_reg_src_w, aux_reg_src_h);
            xor_(reg_kw_index, reg_kw_index);
            L(l_kw);
            {
                for (int jj = 0; jj < ur_c; jj++) {
                    for (int ll = 0; ll < num_ll; ll++) {
                        bool masked = jj == ur_c - 1 && c_tail;
                        size_t msk = jpp.tail[ll];
                        if (!(masked && !msk)) {
                            load_src(jj, ll, c_tail);
                            uni_vpaddd(vreg_dst_s32(jj, ll),
                                    vreg_dst_s32(jj, ll), vreg_src_s32(jj, ll));
                        }
                    }
                }
                add(aux_reg_src_w, c * sizeof_src_dt());
                inc(reg_kw_index);
                cmp(reg_kw_index, reg_kw);
                jl(l_kw, T_NEAR);
            }
            add(aux_reg_src_h, iw * c * sizeof_src_dt());
            inc(reg_kh_index);
            cmp(reg_kh_index, reg_kh);
            jl(l_kh, T_NEAR);
        }
        add(aux_reg_src_d, ih * iw * c * sizeof_src_dt());
        inc(reg_kd_index);
        cmp(reg_kd_index, reg_kd);
        jl(l_kd, T_NEAR);
    }

    static constexpr int vlen_size_elem
            = cpu_isa_traits<isa>::vlen / sizeof(float);
    const auto reg_tmp_postops = r15;
    const injector_utils::register_preserve_guard_t reg_guard(this,
            jpp.with_binary
                    ? std::initializer_list<Xbyak::Reg64> {reg_tmp_postops}
                    : std::initializer_list<Xbyak::Reg64> {},
            {});
    if (jpp.with_binary) {
        imul(reg_tmp_postops, c_iter, ur_c * num_ll * vlen_size_elem);
    }

    for (int jj = 0; jj < ur_c; jj++) {
        for (int ll = 0; ll < num_ll; ll++) {
            const bool masked = jj == ur_c - 1 && c_tail;
            const size_t msk = jpp.tail[ll];
            if (!(masked && !msk)) {
                const auto &reg_dst_f32 = vreg_dst_f32(jj, ll);
                const auto &reg_dst_s32 = vreg_dst_s32(jj, ll);
                uni_vcvtdq2ps(reg_dst_f32, reg_dst_s32);
                uni_vfmadd132ps(reg_dst_f32, vreg_zeros, vreg_tmp);

                if (jpp.with_postops) {
                    binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
                    if (jpp.with_binary) {
                        rhs_arg_params.vmm_idx_to_oc_off_oprnd.emplace(
                                reg_dst_f32.getIdx(), reg_tmp_postops);
                        rhs_arg_params.vmm_idx_to_oc_elem_off_val.emplace(
                                reg_dst_f32.getIdx(),
                                ll * vlen_size_elem + jj * vlen_size_elem);
                        const bool tail = ll == post_op_tail_opmask_idx_;
                        if (tail && masked)
                            rhs_arg_params.vmm_tail_idx_.emplace(
                                    reg_dst_f32.getIdx());
                    }
                    postops_injector_->compute_vector(
                            reg_dst_f32.getIdx(), rhs_arg_params);
                }

                uni_vcvtps2dq(reg_dst_s32, reg_dst_f32);

                if (jpp.with_postops)
                    if (jpp.dst_dt == u8) {
                        uni_vpmaxsd(reg_dst_s32, reg_dst_s32, vreg_zeros);
                    }
                store_dst(jj, ll, c_tail);
            }
        }
    }
}

template <cpu_isa_t isa>
void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_step(int ur_c, int c_tail) {
    switch (jpp.alg) {
        case pooling_max: compute_max_step(ur_c, c_tail); break;
        case pooling_avg_include_padding:
        case pooling_avg_exclude_padding: compute_avg_step(ur_c, c_tail); break;
        default: assert(!"unsupported pooling algorithm");
    }
}

template <cpu_isa_t isa>
void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_c_block() {
    Label l_main_loop;

    int nb_c = jpp.nb_c;
    int c_block = jpp.c_block;
    int ur_c = jpp.ur_c;
    int ur_c_tail = jpp.ur_c_tail;
    int c_steps = nb_c / ur_c;
    int c_tail = jpp.c_tail;

    xor_(c_iter, c_iter);
    if (c_steps > 0) {
        L(l_main_loop);
        {
            compute_step(ur_c, 0);
            add(reg_ptr_src_i8, ur_c * c_block * sizeof_src_dt());
            add(reg_ptr_dst_i8, ur_c * c_block * sizeof_dst_dt());
            inc(c_iter);
            cmp(c_iter, c_steps);
            jl(l_main_loop, T_NEAR);
        }
    }

    if (ur_c_tail != 0) { compute_step(ur_c_tail, c_tail); }
}

template <>
void jit_uni_i8i8_pooling_fwd_ker_t<sse41>::init_mask() {}

template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::init_mask() {
    using namespace data_type;
    using cpu_isa = cpu_isa_traits<avx2>;

    // AVX2 mask initialization: mask stored in Ymm-regs
    auto init = [&](uint64_t bit_mask, bool need_ymm_mask = true,
                        bool need_mmx_mask = false) {
        const size_t QW_PER_VREG = cpu_isa::vlen / sizeof(uint64_t);

        const size_t DBITS = 8 * sizeof_src_dt();
        const uint64_t VMSK = 1ULL << (DBITS - 1);
        const size_t D_PER_QW = (8 * sizeof(uint64_t)) / DBITS;
        uint64_t vmask[QW_PER_VREG];
        for (size_t i = 0; i < QW_PER_VREG; i++) {
            uint64_t qw_vmask = 0ULL;
            for (size_t j = 0; j < D_PER_QW; j++) {
                if (bit_mask & 1) qw_vmask |= VMSK << DBITS * j;
                bit_mask >>= 1;
            }
            vmask[i] = qw_vmask;
        }
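
        // Worked example (ours, for illustration only): for s8/u8 src
        // (DBITS = 8, VMSK = 0x80, D_PER_QW = 8) and bit_mask = 0b11111
        // (c_tail = 5), the loop above yields
        // vmask[0] = 0x0000008080808080 (msb set in bytes 0..4) and
        // vmask[1] = vmask[2] = vmask[3] = 0.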

        // Need mask in Ymm regs ?
        if (need_ymm_mask) {

            // Put QWORDS with target mask into xmm regs
            const int xdst_i[QW_PER_VREG]
                    = {xreg_mask_lo.getIdx(), xreg_mask_lo.getIdx(),
                            xreg_mask_hi.getIdx(), xreg_mask_hi.getIdx()};
            const int xsrc_i[QW_PER_VREG] = {
                    vreg_zeros
                            .getIdx(), // 0-th qword insert in zeros -> {qw0, 0}
                    xreg_mask_lo
                            .getIdx(), // 1-st and 0-th merge -> {qw0,qw1}
                    vreg_zeros.getIdx(), xreg_mask_hi.getIdx()};
            const uint8 qw_dst_idx[QW_PER_VREG]
                    = {0, 1, 0, 1}; // qword index in 128-bit xreg

            for (size_t i = 0; i < QW_PER_VREG; i++) {
                mov(reg_mask, vmask[i]);
                vpinsrq(Xmm(xdst_i[i]), Xmm(xsrc_i[i]), reg_mask,
                        qw_dst_idx[i]);

                // Need mask in MMX regs also?
                if (need_mmx_mask)
                    movq(mmx_mask(i), reg_mask); // reuse value in reg_mask
            }

            // Merge Low (xreg_mask_lo alias for vreg_mask.xreg)
            // and High (xreg_mask_hi) into full vreg_mask
            // vreg_mask -> {xreg_mask_hi, vreg_mask.xreg}
            vinserti128(vreg_mask, vreg_mask, xreg_mask_hi, 1);

            // Compute a mask aligned to the left from vreg_mask and store it
            // in vreg_mask_2 to be used for tail processing.
            const uint8_t shift = 32 - jpp.c_tail;
            vperm2i128(vreg_mask_2, vreg_mask, vreg_mask, 0x08);
            if (shift <= 16) {
                vpalignr(vreg_mask_2, vreg_mask, vreg_mask_2, 16 - shift);
            } else {
                vpalignr(vreg_mask_2, vreg_mask_2, vreg_zeros, 32 - shift);
            }
            vextracti128(xreg_mask_2_hi, vreg_mask_2, 0x1);
        }

        // Need mask in MMX regs ?
        if (need_mmx_mask) {

            // Only in MMX regs ?
            if (!need_ymm_mask)
                for (size_t i = 0; i < QW_PER_VREG; i++) {
                    mov(reg_mask, vmask[i]);
                    movq(mmx_mask(i), reg_mask);
                }

            // Form full mask for one QWORD
            uint64_t qw_full_vmask = 0ULL;
            for (size_t i = 0; i < D_PER_QW; i++)
                qw_full_vmask |= VMSK << DBITS * i;

            mov(reg_mask, qw_full_vmask);
            movq(mmx_full_msk, reg_mask);
        }
    };

    uint64_t tail_mask = (1ULL << jpp.c_tail) - 1;
    switch (jpp.alg) {
        case pooling_max:
            // For "max" we need mask only in case of non-zero tail
            if (tail_mask) init(tail_mask);
            break;
        case pooling_avg_include_padding:
        case pooling_avg_exclude_padding:
            // For "avg" we need mask:
            //     - s32 - in case of the non-zero tail
            //     - s8/u8 - irrespective of the tail in MMX regs (always store by mask)
            //             - for non-zero tail in Ymm regs (for load)
            switch (jpp.src_dt) {
                case s32:
                    if (tail_mask) init(tail_mask);
                    break;
                case s8:
                case u8:
                    init(tail_mask ? tail_mask : ~0ULL, tail_mask != 0, true);
                    break;
                default: assert(!"unsupported src data type");
            }
            break;
        default: assert(!"unsupported pooling algorithm");
    }
}

template <>
void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::init_mask() {

    for (int ll = 0; ll < max_num_ll; ll++) {
        mov(reg_mask, jpp.tail[ll]);
        kmovq(mask(ll), reg_mask);
    }
}

template <cpu_isa_t isa>
void jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_tmp_reg() {
    using namespace data_type;

    switch (jpp.alg) {
        case pooling_avg_include_padding:
        case pooling_avg_exclude_padding:
            mov(reg_tmp, ptr[reg_param + offsetof(call_params_t, idivider)]);
            uni_vmovq(xmm_tmp, reg_tmp);
            uni_vpbroadcastd(vreg_tmp, xmm_tmp);
            break;
        case pooling_max:
            switch (jpp.src_dt) {
                case s32:
                    mov(reg_tmp, nstl::numeric_limits<int32_t>::lowest());
                    break;
                case s8:
                    mov(reg_tmp, nstl::numeric_limits<int8_t>::lowest());
                    break;
                case u8:
                    mov(reg_tmp, nstl::numeric_limits<uint8_t>::lowest());
                    break;
                default: assert(!"unsupported src data_type");
            }

            uni_vmovq(xmm_tmp, reg_tmp);
            if (jpp.src_dt == s32)
                uni_vpbroadcastd(vreg_tmp, xmm_tmp);
            else if (mayiuse(avx2))
                vpbroadcastb(vreg_tmp, xmm_tmp);
            else
                pshufb(xmm_tmp, vreg_zeros);
            break;
        default: assert(!"unsupported pooling algorithm");
    }
}

template <cpu_isa_t isa>
void jit_uni_i8i8_pooling_fwd_ker_t<isa>::generate() {
    preamble();

#if !defined(_WIN32)
    // Always use rcx as abi_param1 -
    // see the note about maskmovdqu/maskmovq near reg_param.
    mov(rcx, rdi);
#endif

#define READ_PARAM(reg, field) \
    mov(reg, ptr[reg_param + offsetof(call_params_t, field)])
    READ_PARAM(reg_ptr_src_i8, src_i8);
    READ_PARAM(reg_ptr_dst_i8, dst_i8);
    READ_PARAM(reg_kd, kd_range);
    READ_PARAM(reg_kh, kh_range);
    READ_PARAM(reg_kw, kw_range);
    READ_PARAM(reg_src_safe_access, src_safe_access);
    READ_PARAM(reg_dst_safe_access, dst_safe_access);

#undef READ_PARAM

    uni_vpxor(vreg_zeros, vreg_zeros, vreg_zeros);

    init_mask();

    init_tmp_reg();

    compute_c_block();

    emms();
    postamble();

    if (jpp.with_eltwise && postops_injector_)
        postops_injector_->prepare_table();
}

template <cpu_isa_t isa>
status_t jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_conf(
        jit_pool_conf_t &jpp, const pooling_pd_t *ppd) {
    if (!mayiuse(isa)) return status::unimplemented;

    const auto &pd = *ppd->desc();
    const memory_desc_wrapper src_d(ppd->src_md());
    const memory_desc_wrapper dst_d(ppd->dst_md());
    const int ndims = src_d.ndims();
    const bool is_1d = ndims == 3;
    const bool is_3d = ndims == 5;

    jpp.mb = src_d.dims()[0];
    jpp.c = src_d.dims()[1];

    jpp.id = is_3d ? src_d.dims()[ndims - 3] : 1;
    jpp.ih = is_1d ? 1 : src_d.dims()[ndims - 2];
    jpp.iw = src_d.dims()[ndims - 1];

    jpp.od = is_3d ? dst_d.dims()[ndims - 3] : 1;
    jpp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2];
    jpp.ow = dst_d.dims()[ndims - 1];

    jpp.stride_d = is_3d ? pd.strides[ndims - 5] : 1;
    jpp.stride_h = is_1d ? 1 : pd.strides[ndims - 4];
    jpp.stride_w = pd.strides[ndims - 3];

    jpp.kd = is_3d ? pd.kernel[ndims - 5] : 1;
    jpp.kh = is_1d ? 1 : pd.kernel[ndims - 4];
    jpp.kw = pd.kernel[ndims - 3];

    jpp.f_pad = is_3d ? pd.padding[0][ndims - 5] : 0;
    jpp.t_pad = is_1d ? 0 : pd.padding[0][ndims - 4];
    jpp.l_pad = pd.padding[0][ndims - 3];

    int back_pad = calculate_end_padding(
            jpp.f_pad, jpp.od, jpp.id, jpp.stride_d, jpp.kd);
    int bottom_pad = calculate_end_padding(
            jpp.t_pad, jpp.oh, jpp.ih, jpp.stride_h, jpp.kh);
    int right_pad = calculate_end_padding(
            jpp.l_pad, jpp.ow, jpp.iw, jpp.stride_w, jpp.kw);

    if (jpp.f_pad >= jpp.kd || jpp.t_pad >= jpp.kh || jpp.l_pad >= jpp.kw
            || back_pad >= jpp.kd || bottom_pad >= jpp.kh
            || right_pad >= jpp.kw)
        return status::unimplemented;

    jpp.alg = pd.alg_kind;

    jpp.src_dt = pd.src_desc.data_type;
    jpp.dst_dt = pd.dst_desc.data_type;

    // data_type items per one vreg on the <isa>
    //     isa == sse41   : 16 bytes -> 16 for s8/u8, 4 for s32
    //     isa == avx2    : 32 bytes -> 32 for s8/u8, 8 for s32
    //     isa == avx512* : 64 bytes -> 64 for s8/u8, 16 for s32
    int simd_w = cpu_isa_traits<isa>::vlen / data_type_size(jpp.src_dt);

    /* Verify that vlen-sized memory access happens within the tensor's
     * size, otherwise load/store will always spill outside the memory
     * boundary.*/
    bool safe_load_n_store = IMPLICATION(utils::one_of(isa, avx2, sse41),
            jpp.mb * jpp.c * nstl::min(jpp.id, jpp.od)
                            * nstl::min(jpp.ih, jpp.oh)
                            * nstl::min(jpp.iw, jpp.ow)
                    >= simd_w);
    if (!safe_load_n_store) return status::unimplemented;

    jpp.c_block = simd_w;
    jpp.c_tail = jpp.c % jpp.c_block;
    jpp.nb_c = jpp.c / jpp.c_block;
    jpp.ur_c = 1;
    jpp.ur_c_tail = jpp.c_tail != 0;

    size_t tail_mask = (1ULL << jpp.c_tail) - 1;

    /* If channel_size is bigger than vlen, we can safely assume there is no
     * underflow of memory boundary, so always perform c_tail and save
     * a couple of compute cycles */
    jpp.safe_c_tail = jpp.c_tail > 0 && jpp.c >= simd_w;

    switch (jpp.alg) {
        case pooling_max:
            jpp.tail[0] = tail_mask;
            jpp.tail[1] = 0;
            jpp.tail[2] = 0;
            jpp.tail[3] = 0;
            break;
        case pooling_avg_include_padding:
        case pooling_avg_exclude_padding: {
            // avg_proc_dt (s32) defines granularity (because u8/s8 processed as s32)
            // sse : 4, avx2 : 8, avx512 : 16
            const size_t msk_gran
                    = cpu_isa_traits<isa>::vlen / data_type_size(avg_proc_dt);
            const size_t msk_msk = (1ULL << msk_gran) - 1;
            size_t m = tail_mask;
            for (size_t ll = 0; ll < max_num_ll; ll++) {
                jpp.tail[ll] = m & msk_msk;
                m = m >> msk_gran;
            }
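            // Worked example (ours, for illustration only): avx2 with an s8
            // src gives c_block = 32; for c = 43, c_tail = 11 and
            // tail_mask = 0x7ff, so with msk_gran = 8 the loop above yields
            // jpp.tail[0] = 0xff, jpp.tail[1] = 0x7, jpp.tail[2..3] = 0.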
            break;
        }
        default: return status::unimplemented;
    }

    if (!post_ops_ok(jpp, *ppd->attr(), dst_d)) return status::unimplemented;

    return status::success;
}

template <cpu_isa_t isa>
bool jit_uni_i8i8_pooling_fwd_ker_t<isa>::post_ops_ok(jit_pool_conf_t &jpp,
        const primitive_attr_t &attr, const memory_desc_wrapper &dst_d) {
    const auto &post_ops = attr.post_ops_;
    const auto &entries = post_ops.entry_;
    jpp.with_postops = false;
    jpp.with_eltwise = false;
    jpp.with_binary = false;

    if (entries.empty()) return true;

    for (const auto &entry : entries) {
        if (entry.is_eltwise()) {
            const auto alg = entry.eltwise.alg;
            jpp.with_eltwise = eltwise_injector::is_supported(isa, alg);
        } else if (entry.is_binary()) {
            if (isa != avx512_core
                    && entry.binary.src1_desc.data_type == data_type::bf16)
                return false;
            jpp.with_binary = true;
        } else
            return false;
    }

    jpp.with_postops = jpp.with_eltwise || jpp.with_binary;
    jpp.post_ops = post_ops;

    /*
     * TODO: Currently the eltwise/binary injectors assume that the data in
     * vmm has f32 data type. In max pooling the data remains in the i8 data
     * type.
     */
    return IMPLICATION(jpp.with_postops, jpp.alg != pooling_max)
            && binary_injector::binary_args_broadcast_supported(
                    post_ops, dst_d, get_supported_bcast_strategies());
}

template <cpu_isa_t isa>
status_t jit_uni_i8i8_pooling_fwd_t<isa>::pd_t::jit_conf() {
    return jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_conf(jpp_, this);
}

template <cpu_isa_t isa>
jit_uni_i8i8_pooling_fwd_t<isa>::jit_uni_i8i8_pooling_fwd_t(const pd_t *apd)
    : primitive_t(apd), ker_(nullptr) {}

template <cpu_isa_t isa>
jit_uni_i8i8_pooling_fwd_t<isa>::~jit_uni_i8i8_pooling_fwd_t() = default;

template <cpu_isa_t isa>
status_t jit_uni_i8i8_pooling_fwd_t<isa>::init(engine_t *engine) {
    CHECK(safe_ptr_assign(ker_,
            new jit_uni_i8i8_pooling_fwd_ker_t<isa>(
                    pd()->jpp_, pd()->invariant_dst_md())));
    return ker_->create_kernel();
}

template <cpu_isa_t isa>
status_t jit_uni_i8i8_pooling_fwd_t<isa>::execute_forward(
        const exec_ctx_t &ctx) const {
    auto src_i8 = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    auto dst_i8 = CTX_OUT_MEM(char *, DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());

    const auto &jpp = pd()->jpp_;
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(jpp.post_ops, ctx);
    /* Calculate whether the memory access will happen outside of the memory
     * boundary; if so, compute a safe memory access. */
    const auto src_safe_access = reinterpret_cast<char *>(
            reinterpret_cast<ptrdiff_t>(src_i8 + src_d.size() - 1)
            - (cpu_isa_traits<isa>::vlen - 1));

    const auto dst_safe_access = reinterpret_cast<char *>(
            reinterpret_cast<ptrdiff_t>(dst_i8 + dst_d.size() - 1)
            - (cpu_isa_traits<isa>::vlen - 1));
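
    /* Note (ours): *_safe_access is the last address at which a full
     * vlen-sized vector load/store still stays inside the buffer; any access
     * from an address less than or equal to it cannot overrun the memory. */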

    parallel_nd(
            jpp.mb, jpp.od, jpp.oh, jpp.ow, [&](int n, int od, int oh, int ow) {
                const int id = nstl::max(od * jpp.stride_d - jpp.f_pad, 0);
                const int ih = nstl::max(oh * jpp.stride_h - jpp.t_pad, 0);
                const int iw = nstl::max(ow * jpp.stride_w - jpp.l_pad, 0);

                const int kd_start
                        = nstl::max(0, jpp.f_pad - od * jpp.stride_d);
                const int kd_end = nstl::min(
                        jpp.kd, jpp.id + jpp.f_pad - od * jpp.stride_d);
                const int kh_start
                        = nstl::max(0, jpp.t_pad - oh * jpp.stride_h);
                const int kh_end = nstl::min(
                        jpp.kh, jpp.ih + jpp.t_pad - oh * jpp.stride_h);
                const int kw_start
                        = nstl::max(0, jpp.l_pad - ow * jpp.stride_w);
                const int kw_end = nstl::min(
                        jpp.kw, jpp.iw + jpp.l_pad - ow * jpp.stride_w);

                auto p = call_params_t();
                p.src_i8 = &src_i8[get_offset(src_d, n, 0, id, ih, iw)
                        * src_d.data_type_size()];
                p.dst_i8 = &dst_i8[get_offset(dst_d, n, 0, od, oh, ow)
                        * dst_d.data_type_size()];
                p.kd_range = (size_t)(kd_end - kd_start);
                p.kh_range = (size_t)(kh_end - kh_start);
                p.kw_range = (size_t)(kw_end - kw_start);
                p.idivider = 1.0f
                        / ((jpp.alg == pooling_avg_exclude_padding)
                                        ? p.kd_range * p.kh_range * p.kw_range
                                        : jpp.kd * jpp.kh * jpp.kw);
                p.src_safe_access = src_safe_access;
                p.dst_safe_access = dst_safe_access;
                p.post_ops_binary_rhs_arg_vec
                        = post_ops_binary_rhs_arg_vec.data();
                (*ker_)(&p);
            });
    return status::success;
}

// Explicit instantiation only for supported <isa> values.
//
template struct jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>;
template struct jit_uni_i8i8_pooling_fwd_t<avx512_core>;

template struct jit_uni_i8i8_pooling_fwd_ker_t<avx2>;
template struct jit_uni_i8i8_pooling_fwd_t<avx2>;

template struct jit_uni_i8i8_pooling_fwd_ker_t<sse41>;
template struct jit_uni_i8i8_pooling_fwd_t<sse41>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl