/*******************************************************************************
* Copyright 2020-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/dnnl_thread.hpp"
#include "common/nstl.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_primitive.hpp"
#include "cpu/zero_point_utils.hpp"

#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_uni_deconv_zp_pad_str_kernel.hpp"
#include "cpu/x64/jit_uni_x8s8s32x_deconvolution.hpp"

#define GET_OFF(field) offsetof(jit_deconv_call_s, field)

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;
using namespace Xbyak;

using namespace nstl;

#define wht_blk_off(d, g, ...) \
    (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) \
                         : (d).blk_off(__VA_ARGS__))
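// Usage sketch (illustrative): wht_blk_off(weights_d, g, ocb, icb) expands to
// weights_d.blk_off(g, ocb, icb) when the primitive has grouped weights and
// to weights_d.blk_off(ocb, icb) otherwise, so callers never branch on
// with_groups() themselves.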

template <cpu_isa_t isa>
status_t jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::init_conf(
        jit_conv_conf_t &jcp, const deconvolution_desc_t &cd,
        memory_desc_t &src_md, memory_desc_t &weights_md, memory_desc_t &dst_md,
        const bool with_bias, memory_desc_t &bias_md, primitive_attr_t &attr,
        int nthreads) {
    const memory_desc_wrapper src_d(&src_md);
    const memory_desc_wrapper dst_d(&dst_md);
    const memory_desc_wrapper weights_d(&weights_md);
    const memory_desc_wrapper bias_d(&bias_md);

    if (!(mayiuse(isa)
                && one_of(src_d.data_type(), data_type::u8, data_type::s8)
                && weights_d.data_type() == data_type::s8
                && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
                        data_type::s8, data_type::u8)))
        return status::unimplemented;

    jcp = zero<decltype(jcp)>();
    jcp.nthr = nthreads;

    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    jcp.signed_input = src_d.data_type() == data_type::s8;
    const int ndims = jcp.ndims = dst_d.ndims();
    const bool is_1d = ndims == 3;
    const bool is_2d = ndims == 4;
    const bool is_3d = ndims == 5;
    const bool is_avx2 = isa == avx2;

    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;
    jcp.id = is_3d ? src_d.dims()[2] : 1;
    jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups;
    jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups;
    jcp.is_depthwise = true && with_groups
            && utils::everyone_is(
                    1, jcp.ic_without_padding, jcp.oc_without_padding);
    jcp.ver = mayiuse(avx2_vnni) ? ver_vnni : ver_unused;

    /* TODO: future work, on hold until depthwise specialized kernel is
     * implemented. */
    if (jcp.is_depthwise && (jcp.signed_input || is_3d))
        return status::unimplemented;

    if (!zero_points_valid(&attr)) return status::unimplemented;
    jcp.src_zero_point = !attr.zero_points_.has_default_values(DNNL_ARG_SRC);
    jcp.dst_zero_point = !attr.zero_points_.has_default_values(DNNL_ARG_DST);
    jcp.zp_src_is_common = attr.zero_points_.common(DNNL_ARG_SRC);

    format_tag_t dat_tag = utils::pick(
            ndims - 3, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);

    if (src_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(src_md, dat_tag));
        jcp.src_tag = dat_tag;
    } else {
        jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
    }
    if (jcp.src_tag != dat_tag) return status::unimplemented;

    if (dst_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(dst_md, dat_tag));
        jcp.dst_tag = dat_tag;
    } else {
        jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
    }
    if (jcp.dst_tag != dat_tag) return status::unimplemented;

    auto set_or_check_wei_format = [&]() {
        using namespace format_tag;
        format_tag_t wei_tag;
        if (jcp.ic_block == 8 || jcp.ch_block == 8) {
            if (is_1d) {
                wei_tag = with_groups ? jcp.is_depthwise ? Goiw8g : gOIw2i8o4i
                                      : OIw2i8o4i;
            } else if (is_2d) {
                wei_tag = with_groups ? jcp.is_depthwise ? Goihw8g : gOIhw2i8o4i
                                      : OIhw2i8o4i;
            } else {
                wei_tag = with_groups ? gOIdhw2i8o4i : OIdhw2i8o4i;
            }
        } else {
            if (is_avx2) {
                assert(with_groups && jcp.ic_block == 4);
                wei_tag = is_3d ? gOIdhw4o4i : is_2d ? gOIhw4o4i : gOIw4o4i;
            } else {
                if (is_1d) {
                    wei_tag = with_groups ? jcp.is_depthwise ? Goiw4g : gOIw4o4i
                                          : OIw4o4i;
                } else if (is_2d) {
                    wei_tag = with_groups
                            ? jcp.is_depthwise ? Goihw4g : gOIhw4o4i
                            : OIhw4o4i;
                } else {
                    wei_tag = with_groups ? gOIdhw4o4i : OIdhw4o4i;
                }
            }
        }

        memory_desc_t want_wei_md = weights_md;
        memory_desc_init_by_tag(want_wei_md, wei_tag);
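        // For an s8 source on ISAs without VNNI, the kernel emulates the
        // u8*s8 dot product, so the weights reorder is asked (via the extra
        // flags below) to precompute an s8s8 compensation term and to
        // pre-scale the weights by 0.5f (scale_adjust), which keeps the
        // intermediate 16-bit products of vpmaddubsw from saturating.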
        if (jcp.signed_input && !jcp.is_depthwise) {
            want_wei_md.extra.flags = 0
                    | memory_extra_flags::compensation_conv_s8s8
                    | memory_extra_flags::scale_adjust;
            want_wei_md.extra.compensation_mask = (1 << 0)
                    + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0);
            want_wei_md.extra.scale_adjust = (jcp.ver == ver_vnni) ? 1.f : 0.5f;
        }
        if (jcp.src_zero_point) set_zp_src_comp_flags(want_wei_md, with_groups);

        if (weights_md.format_kind == format_kind::any) {
            weights_md = want_wei_md;
            return true;
        }

        return weights_md == want_wei_md;
    };

    jcp.with_bias = with_bias;
    if (jcp.with_bias) {
        if (bias_d.format_kind() == format_kind::any)
            CHECK(memory_desc_init_by_tag(bias_md, format_tag::x));
    }

    jcp.prop_kind = cd.prop_kind;
    jcp.mb = src_d.dims()[0];
    jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.od = is_3d ? dst_d.dims()[2] : 1;
    jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2];
    jcp.ow = dst_d.dims()[ndims - 1];
    jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];
    jcp.f_pad = is_3d ? cd.padding[0][0] : 0;
    jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4];
    jcp.l_pad = cd.padding[0][ndims - 3];
    jcp.stride_d = is_3d ? cd.strides[0] : 1;
    jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4];
    jcp.stride_w = cd.strides[ndims - 3];

    if (jcp.is_depthwise) {
        jcp.ch_block = is_avx2 ? 8 : 4;
        jcp.oc_block = 1;
        jcp.ic_block = 1;
    } else {
        jcp.ch_block = 1;
        jcp.oc_block = is_avx2 ? 8 : 4;
        jcp.ic_block = is_avx2 ? 8 : 4;

        if (jcp.ngroups == 1) {
            jcp.oc = utils::rnd_up(jcp.oc_without_padding, jcp.oc_block);
            jcp.ic = utils::rnd_up(jcp.ic_without_padding, jcp.ic_block);
        } else if (jcp.ngroups != 1
                && ((jcp.ic % jcp.ic_block != 0)
                        || (jcp.oc % jcp.oc_block != 0))) {
            /* For grouped convolution, oneDNN doesn't support padding.
             * When the number of channels per group is not a multiple of 8
             * with avx2:
             * - use Xmm when channels per group are a multiple of 4,
             * - otherwise return unimplemented. */
            jcp.oc_block = jcp.ic_block = 4;
        }
        if (jcp.ic % jcp.ic_block != 0 || jcp.oc % jcp.oc_block != 0)
            return status::unimplemented;
    }

    if (!set_or_check_wei_format()) return status::unimplemented;

    jcp.dilate_d = is_3d ? cd.dilates[0] : 0;
    jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4];
    jcp.dilate_w = cd.dilates[ndims - 3];

    if (!IMPLICATION(jcp.dilate_d, jcp.stride_d == 1)
            || !IMPLICATION(jcp.dilate_h, jcp.stride_h == 1)
            || !IMPLICATION(jcp.dilate_w, jcp.stride_w == 1))
        return status::unimplemented;

    const int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
    const int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
    const int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
    jcp.r_pad = calculate_end_padding(
            jcp.l_pad, jcp.iw, jcp.ow, jcp.stride_w, ext_kw);
    jcp.b_pad = calculate_end_padding(
            jcp.t_pad, jcp.ih, jcp.oh, jcp.stride_h, ext_kh);
    jcp.back_pad = calculate_end_padding(
            jcp.f_pad, jcp.id, jcp.od, jcp.stride_d, ext_kd);
    const bool kernel_outside_src = false || ext_kw <= jcp.l_pad
            || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad
            || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad;
    if (kernel_outside_src) return status::unimplemented;

    CHECK(attr.set_default_formats(&dst_md));
    if (!post_ops_ok(jcp, dst_d, attr)) return status::unimplemented;

    const auto &p = attr.post_ops_;
    const int eltwise_ind = p.find(primitive_kind::eltwise);
    jcp.with_eltwise = eltwise_ind != -1;
    if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise;

    const int binary_ind = p.find(primitive_kind::binary);
    jcp.with_binary = binary_ind != -1;

    const int sum_ind = p.find(primitive_kind::sum);
    jcp.with_sum = sum_ind != -1;

    const auto &oscales = attr.output_scales_;
    jcp.is_oc_scale = oscales.mask_ == 1 << 1;

    jcp.post_ops = p;

    // only common and per-oc-channel scales are supported
    const bool oscales_ok = one_of(oscales.mask_, 0, 1 << 1);
    if (!oscales_ok) return status::unimplemented;

    jcp.dst_dt = dst_d.data_type();
    jcp.bia_dt = jcp.with_bias ? bias_d.data_type() : data_type::undef;
    jcp.typesize_bia
            = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
    jcp.typesize_in = types::data_type_size(src_d.data_type());
    jcp.typesize_out = types::data_type_size(dst_d.data_type());

    jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block);
    jcp.nb_oc = jcp.oc / jcp.oc_block;
    jcp.nb_ic = jcp.ic / jcp.ic_block;

    /* kernel blocking params */
    const int regs = (jcp.ver == ver_vnni ? 14 : 12);

    jcp.nb_ch_blocking = 1;
    jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc);
    for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--)
        if (jcp.nb_oc % jcp.nb_oc_blocking == 0
                && jcp.l_pad <= regs / (jcp.nb_oc_blocking + 1))
            break;
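    // Register budget (a sketch): each unrolled output column keeps
    // nb_oc_blocking accumulator vmms live plus one vmm for the broadcast
    // source value, so ur_w * (nb_oc_blocking + 1) <= regs; the remaining
    // vmms are reserved for weights, temporaries, and injector helpers.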

    jcp.ur_w = regs / (jcp.nb_oc_blocking + 1);
    const int l_overflow = max(
            0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w);

    if (jcp.ow < jcp.ur_w) {
        jcp.ur_w = jcp.ow;
        jcp.ur_w_tail = 0;
    } else {
        for (; jcp.ur_w >= 1; jcp.ur_w--) {
            /* ur_w should be a multiple of stride_w in order
               to simplify logic for get_ow_start and get_ow_end */
            const bool is_multiple_of_stride = jcp.ur_w % jcp.stride_w == 0;

            /* boundary conditions:
               These conditions ensure all elements close to boundary
               are computed in a single call of compute loop */
            const bool left_boundary_covered
                    = jcp.ur_w >= l_overflow * jcp.stride_w;
            jcp.ur_w_tail = jcp.ow % jcp.ur_w;
            const int r_overflow_no_tail = max(0,
                    ((jcp.kw - 1) * (jcp.dilate_w + 1) - max(0, jcp.r_pad)
                            - jcp.ur_w_tail)
                            / jcp.stride_w);
            const bool right_boundary_covered
                    = jcp.ur_w >= r_overflow_no_tail * jcp.stride_w;

            if (is_multiple_of_stride && left_boundary_covered
                    && right_boundary_covered)
                break;
            else if (jcp.ur_w == 1)
                /* The boundary conditions above are also important
                   to maintain simplicity of calls to icb_loop:
                   if those conditions are not satisfied,
                   then special cases will need to be added
                   to use correct l_overflow/r_overflow values
                   when different iterations of the compute loop
                   work on locations close to the boundary.
                   So to keep the code simple, return unimplemented
                   for the extreme case when a good ur_w cannot be found.
                 */
                return status::unimplemented;
        }
    }

    jcp.wei_adj_scale
            = (weights_d.extra().flags & memory_extra_flags::scale_adjust)
            ? weights_d.extra().scale_adjust
            : 1.f;

    jcp.loop_order = jcp.ngroups > 1 ? loop_ngc : loop_cgn;
    return status::success;
}

template <cpu_isa_t isa>
jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::jit_uni_x8s8s32x_deconv_fwd_kernel(
        const jit_conv_conf_t &ajcp, const primitive_attr_t &attr,
        const memory_desc_wrapper &dst_d)
    : kernel_(nullptr) {

    const int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.ic_block;
    switch (ch_block) {
        case 8:
            if (isa == avx2) {
                kernel_ = utils::make_unique<
                        _jit_avx2_x8s8s32x_deconv_fwd_kernel>(
                        ajcp, attr, dst_d);
                return;
            } else
                assert(!"invalid channel blocking for current ISA");
        case 4:
            kernel_ = utils::make_unique<
                    _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Xbyak::Xmm>>(
                    ajcp, attr, dst_d);
            return;
        default: assert(!"invalid channel blocking");
    }
}

template <cpu_isa_t isa>
jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::~jit_uni_x8s8s32x_deconv_fwd_kernel()
        = default;

template <cpu_isa_t isa>
void jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
        const primitive_attr_t &attr) {
    if (jcp.signed_input && jcp.ver != ver_vnni) {
        dim_t count = nstl::max<dim_t>(attr.output_scales_.count_, 8);
        scratchpad.book<float>(key_conv_adjusted_scales, count);
    }

    if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp)) {
        const dim_t zp_pad_comp_size = jcp.oc_without_padding * jcp.ngroups
                * jcp.kd * jcp.kh * jcp.kw;
        scratchpad.book<int32_t>(key_deconv_zp, zp_pad_comp_size);
    }
}

template <cpu_isa_t isa>
bool jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::post_ops_ok(jit_conv_conf_t &jcp,
        const memory_desc_wrapper &dst_d, const primitive_attr_t &attr) {
    using namespace injector;

    return injector::post_ops_ok(post_ops_ok_args_t(isa, {sum, eltwise, binary},
            attr.post_ops_, &dst_d, false /*sum_at_pos_0_only*/,
            false /*sum_requires_scale_one*/, false /*sum_requires_zp_zero*/,
            {broadcasting_strategy_t::per_oc,
                    broadcasting_strategy_t::scalar}));
}

template <cpu_isa_t isa, typename Vmm>
_jit_uni_x8s8s32x_deconv_fwd_kernel<isa,
        Vmm>::_jit_uni_x8s8s32x_deconv_fwd_kernel(const jit_conv_conf_t &ajcp,
        const primitive_attr_t &attr, const memory_desc_wrapper &dst_d)
    : jit_generator(nullptr, MAX_CODE_SIZE, true, isa)
    , jcp_(ajcp)
    , postops_injector_(nullptr) {

    if (jcp_.with_eltwise || jcp_.with_binary || jcp_.with_sum) {
        const std::size_t tail_size = get_tail_size();

        static constexpr bool preserve_gpr = true;
        static constexpr bool preserve_vmm = true;
        static constexpr bool use_exact_tail_scalar_bcast = false;
        static constexpr size_t vmm_helper_idx = 15;

        const binary_injector::rhs_arg_static_params_t rhs_sp {vmm_helper_idx,
                this->r14, this->r15, preserve_gpr, preserve_vmm,
                GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), dst_d,
                tail_size, Xbyak::Opmask(2), use_exact_tail_scalar_bcast};
        const binary_injector::static_params_t bsp {this->param1_, rhs_sp};

        postops_injector_ = utils::make_unique<
                injector::jit_uni_postops_injector_t<isa, Vmm>>(
                this, jcp_.post_ops, bsp);
    }
}

template <cpu_isa_t isa, typename Vmm>
_jit_uni_x8s8s32x_deconv_fwd_kernel<isa,
        Vmm>::~_jit_uni_x8s8s32x_deconv_fwd_kernel()
        = default;

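// Vector register layout used by the helpers below (a sketch): indices are
// assigned top-down from 15, so the ur_w * nb_oc_blocking accumulators occupy
// Vmm(15) .. Vmm(16 - ur_w * nb_oc_blocking), the broadcast source values sit
// immediately below them, and vmm_bias_alpha() aliases the first source slot,
// which is no longer live at the point where it is used in store_output().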
template <cpu_isa_t isa, typename Vmm>
Vmm _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::vmm_out(
        int i_ur, int i_oc) const {
    const int idx = i_ur * jcp_.nb_oc_blocking + i_oc;
    assert(idx < KER_MAX_REG_IDX);
    /* remap the reg indices to avoid using xmm0 in eltwise injector */
    return Vmm(15 - idx);
}

template <cpu_isa_t isa, typename Vmm>
Vmm _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::vmm_inp(
        int i_ic, int nb_x_blocking) const {
    const int idx = i_ic + nb_x_blocking * jcp_.ur_w;
    assert(idx < KER_MAX_REG_IDX);
    return Vmm(15 - idx);
}

template <cpu_isa_t isa, typename Vmm>
Vmm _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::vmm_bias_alpha() const {
    return Vmm(15 - jcp_.nb_oc_blocking * jcp_.ur_w);
}

template <cpu_isa_t isa, typename Vmm>
Xmm _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::xmm_bias_alpha() const {
    return Xmm(vmm_bias_alpha().getIdx());
}

template <cpu_isa_t isa, typename Vmm>
int _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::get_ow_start(
        int ki, int l_overflow) const noexcept {
    int res = (jcp_.ow - 1 + jcp_.r_pad) % jcp_.stride_w
            + l_overflow * jcp_.stride_w
            - (jcp_.kw - 1 - ki) * (jcp_.dilate_w + 1);
    while (res < 0)
        res += jcp_.stride_w;
    return res;
}

template <cpu_isa_t isa, typename Vmm>
int _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::get_ow_end(
        int ur_w, int ki, int r_overflow) const noexcept {
    if (utils::one_of(ur_w, jcp_.ow, jcp_.ur_w_tail))
        ur_w += nstl::min(0, jcp_.r_pad); // remove negative padding
    int res = (ur_w - 1 + jcp_.l_pad) % jcp_.stride_w
            + r_overflow * jcp_.stride_w - ki * (jcp_.dilate_w + 1);
    while (res < 0)
        res += jcp_.stride_w;
    return ur_w - res;
}

template <cpu_isa_t isa, typename Vmm>
int _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::get_blocking_size() const
        noexcept {
    return jcp_.is_depthwise ? jcp_.ch_block : jcp_.oc_block;
}

template <cpu_isa_t isa, typename Vmm>
int _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::get_tail_size() const
        noexcept {
    return jcp_.is_depthwise ? jcp_.ngroups % jcp_.ch_block
                             : jcp_.oc_without_padding % jcp_.oc_block;
}

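// compute() below emits one accumulation step acc_s32 += src_u8 * wei_s8,
// reduced over groups of four interleaved input channels. On VNNI hardware
// this is a single vpdpbusd; without VNNI it is emulated as
//   vpmaddubsw: 16-bit pairs = u8 * s8 (may saturate, hence the 0.5 weight
//               pre-scaling chosen in init_conf),
//   vpmaddwd:   32-bit lanes = 16-bit pairs * vector of 16-bit ones,
//   vpaddd:     accumulate into the output register.
// The depthwise variant multiplies sign-/zero-extended 32-bit lanes directly.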
template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::compute(
        const Vmm &vreg_acc, const Vmm &vreg_wei, const Vmm &vreg_src) {

    if (jcp_.ver == ver_vnni) {
        vpdpbusd(vreg_acc, vreg_src, vreg_wei, Xbyak::VexEncoding);
    } else if (jcp_.is_depthwise) {
        uni_vmovups(vmm_tmp_, vreg_src);
        uni_vpmulld(vmm_tmp_, vmm_tmp_, vreg_wei);
        uni_vpaddd(vreg_acc, vreg_acc, vmm_tmp_);
    } else {
        uni_vpmaddubsw(vmm_tmp_, vreg_src, vreg_wei);
        uni_vpmaddwd(vmm_tmp_, vmm_tmp_, vmm_one_);
        uni_vpaddd(vreg_acc, vreg_acc, vmm_tmp_);
    }
}

template <cpu_isa_t isa, typename Vmm>
std::function<Vmm()> _jit_uni_x8s8s32x_deconv_fwd_kernel<isa,
        Vmm>::prepare_round_robin_vmm_inp_generator(int ur_w) const noexcept {

    const int start_vmm_idx = vmm_inp(ur_w - 1, jcp_.nb_oc_blocking).getIdx();
    const int end_vmm_idx = vmm_inp(0, jcp_.nb_oc_blocking).getIdx() + 1;
    int current_vmm_idx = start_vmm_idx;

    return [=]() mutable {
        const Vmm vmm {static_cast<int>(current_vmm_idx++)};

        if (current_vmm_idx == end_vmm_idx) current_vmm_idx = start_vmm_idx;

        return vmm;
    };
}

template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::apply_zp_src_pad_str_comp(
        int ur_w, int l_overflow, int r_overflow, bool h_padded) {
    Xbyak::Label end_zp_pad, no_tail;

    // applied once per icb loop; the zp src stride/padding compensation is
    // calculated as zp_pad_str_compensation
    //         = conv(1, weights_s8) * zero_point_source
    cmp(reg_icb_, jcp_.nb_ic);
    jne(end_zp_pad, T_NEAR);

    if (jcp_.ngroups % jcp_.ch_block
            || jcp_.oc_without_padding % jcp_.oc_block) {
        if (jcp_.is_depthwise)
            cmp(reg_oc_blocks_, jcp_.nb_ch - 1);
        else
            cmp(reg_oc_blocks_, jcp_.nb_oc - jcp_.nb_oc_blocking);
        jne(no_tail, T_NEAR);

        append_zp_src_pad_str_comp(
                ur_w, l_overflow, r_overflow, h_padded, true /*last_oc_block*/);
        jmp(end_zp_pad, T_NEAR);
    }

    L(no_tail);
    append_zp_src_pad_str_comp(
            ur_w, l_overflow, r_overflow, h_padded, false /*last_oc_block*/);

    L(end_zp_pad);
}

template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::append_zp_src_pad_str_comp(
        int ur_w, int l_overflow, int r_overflow, bool h_padded,
        bool last_oc_block) {

    const auto &reg_zp_src_pad_comp = reg_scratch_;
    const auto get_next_comp_vmm = prepare_round_robin_vmm_inp_generator(ur_w);
    bool base_comp_addr_loaded = false;

    const auto load_base_zp_src_pad_comp_addr = [&]() {
        if (!base_comp_addr_loaded) {
            if (jcp_.ndims == 5) mov(reg_scratch_preserved_, reg_scratch_);

            if (jcp_.ndims > 3)
                mov(reg_zp_src_pad_comp, zp_src_pad_comp_addr_);
            else
                mov(reg_zp_src_pad_comp,
                        qword[param1_ + GET_OFF(zp_src_pad_str_compensation)]);

            base_comp_addr_loaded = true;
        }
    };

    const auto load_zp_src_pad_comp = [&](const Vmm &zp_pad_comp_vmm,
                                              const Xbyak::Address &comp_addr,
                                              const int ocb) {
        const bool is_tail = last_oc_block && ocb == jcp_.nb_oc_blocking - 1;

        if (is_tail)
            load_data(data_type::s32, zp_pad_comp_vmm, comp_addr,
                    get_tail_size());
        else
            uni_vmovups(zp_pad_comp_vmm, comp_addr);
    };

    const auto get_zp_src_comp_pad_off = [&](int it_kw, int ocb) {
        const auto kw_offset = it_kw * jcp_.oc_without_padding * jcp_.ngroups;
        const auto oc_offset = ocb * jcp_.oc_block;

        return (kw_offset + oc_offset) * sizeof(int32_t);
    };

    for (int it_kw = 0; it_kw < jcp_.kw; ++it_kw) {
        const int ow_start = get_ow_start(it_kw, l_overflow);
        const int ow_end = get_ow_end(ur_w, it_kw, r_overflow);

        for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
            Vmm zp_src_comp_pad_vmm; // will be assigned later
            bool ocb_zp_loaded = false;

            const auto zp_src_comp_pad_off
                    = get_zp_src_comp_pad_off(it_kw, ocb);

            for (int it_ow = 0; it_ow < ur_w; ++it_ow) {

                const bool inside_padded_area = h_padded
                        || !(it_ow >= ow_start && it_ow < ow_end
                                && ((it_ow + jcp_.l_pad - it_kw) % jcp_.stride_w
                                        == 0));

                if (inside_padded_area) {
                    load_base_zp_src_pad_comp_addr();

                    if (!ocb_zp_loaded) {
                        zp_src_comp_pad_vmm = get_next_comp_vmm();
                        const auto comp_addr = ptr[reg_zp_src_pad_comp
                                + zp_src_comp_pad_off];
                        load_zp_src_pad_comp(
                                zp_src_comp_pad_vmm, comp_addr, ocb);
                        ocb_zp_loaded = true;
                    }

                    const auto vmm_dst = vmm_out(it_ow, ocb);
                    uni_vpaddd(vmm_dst, vmm_dst, zp_src_comp_pad_vmm);
                }
            }
        }
    }

    if (jcp_.ndims > 3) {
        if (!base_comp_addr_loaded) load_base_zp_src_pad_comp_addr();

        const auto kh_offset = jcp_.kw * jcp_.oc_without_padding * jcp_.ngroups
                * sizeof(int32_t);

        add(reg_zp_src_pad_comp, kh_offset);
        mov(zp_src_pad_comp_addr_, reg_zp_src_pad_comp);
    }

    if (jcp_.ndims == 5 && base_comp_addr_loaded)
        mov(reg_scratch_, reg_scratch_preserved_);
}

template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::compute_ker(int ur_w,
        int l_overflow, int r_overflow, ker_block_t last_ic_block_flag,
        bool h_padded) {

    const bool signed_input_or_src_zp
            = (jcp_.signed_input || jcp_.src_zero_point);

    const int ch_block_all = jcp_.ch_block * jcp_.ic_block * jcp_.oc_block;
    const int ur_w_stride = signed_input_or_src_zp ? 1 : jcp_.stride_w;

    const auto src_offset = [=](int oj, int icb, int ki) {
        return jcp_.typesize_in
                * (((oj + jcp_.l_pad - ki * (jcp_.dilate_w + 1))
                           / jcp_.stride_w)
                                * jcp_.ngroups * jcp_.ic_without_padding
                        + icb * 4);
    };

    const auto kernel_offset = [=](int ocb, int icb, int ki) {
        return jcp_.typesize_in
                * ((ocb * jcp_.nb_ic * jcp_.kd * jcp_.kh * jcp_.kw + ki)
                                * ch_block_all
                        + icb * jcp_.oc_block * IC_SUB_STEP);
    };

    for (int ki = 0; ki < jcp_.kw; ki++) {

        const int jj_start = get_ow_start(ki, l_overflow);
        const int jj_end = get_ow_end(ur_w, ki, r_overflow);

        const int _start = (signed_input_or_src_zp) ? 0 : jj_start;
        const int _end = (signed_input_or_src_zp) ? ur_w : jj_end;

        const int tail_size = jcp_.is_depthwise ? jcp_.ngroups % jcp_.ch_block
                                                : jcp_.ic_without_padding % 4;
        const int n_ic_blocks = jcp_.is_depthwise
                ? 1
                : (last_ic_block_flag != no_last_block ? div_up(
                           jcp_.ic_without_padding % jcp_.ic_block, 4)
                                                       : jcp_.ic_block / 4);
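        // Non-depthwise kernels consume input channels in sub-groups of four
        // (IC_SUB_STEP), matching the 4-way u8*s8 dot product of
        // vpdpbusd/vpmaddubsw; n_ic_blocks counts how many such sub-groups
        // the current (possibly tail) ic block contains.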

        for (int icb1 = 0; icb1 < n_ic_blocks; icb1++) {
            if (h_padded) {
                /* fill padded area with shifted values */
                if (jcp_.signed_input) {
                    const Vmm inp = vmm_inp(0, jcp_.nb_oc_blocking);
                    uni_vpxor(inp, inp, inp);
                    uni_vpsubb(inp, inp, vmm_shift_);
                }
            } else {

                for (int jj = _start; jj < _end; jj += ur_w_stride) {

                    const int aux_src_off = src_offset(jj, icb1, ki);
                    const auto vmm_src = vmm_inp(jj, jcp_.nb_oc_blocking);

                    if (jj >= jj_start && jj < jj_end
                            && ((jj + jcp_.l_pad - ki) % jcp_.stride_w == 0)) {
                        if (jcp_.is_depthwise) {
                            if (tail_size != 0)
                                assert(jcp_.nb_oc_blocking == 1);
                            uni_vpxor(vmm_src, vmm_src, vmm_src);
                            const bool mask_flag
                                    = last_ic_block_flag != no_last_block
                                    && tail_size;
                            load_data(data_type::u8, vmm_src, aux_reg_src_,
                                    aux_src_off,
                                    mask_flag ? tail_size : jcp_.ch_block);
                        } else if ((last_ic_block_flag == last_sp_block)
                                && tail_size != 0 && icb1 == n_ic_blocks - 1) {
                            const auto vmm_inp_tmp = Xmm(vmm_src.getIdx());
                            load_bytes(vmm_inp_tmp, aux_reg_src_, aux_src_off,
                                    tail_size);
                            uni_vpbroadcastd(vmm_src, vmm_inp_tmp);
                        } else {
                            uni_vpbroadcastd(
                                    vmm_src, ptr[aux_reg_src_ + aux_src_off]);
                        }
                        if (jcp_.signed_input)
                            uni_vpsubb(vmm_src, vmm_src, vmm_shift_);
                    } else {
                        /* fill padded area with shifted values */
                        if (jcp_.signed_input) {
                            uni_vpxor(vmm_src, vmm_src, vmm_src);
                            uni_vpsubb(vmm_src, vmm_src, vmm_shift_);
                        }
                    }
                }
            }
            for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
                const int aux_filt_off = kernel_offset(ocb, icb1, ki);

                if (_end - _start > 0) {
                    if (jcp_.is_depthwise) {
                        uni_vpmovsxbd(
                                vmm_wei_, ptr[aux_reg_filt_ + aux_filt_off]);
                    } else
                        uni_vmovups(
                                vmm_wei_, ptr[aux_reg_filt_ + aux_filt_off]);
                }

                for (int jj = _start; jj < _end; jj += ur_w_stride) {

                    const bool inside_padded_area = h_padded
                            || !(jj >= jj_start && jj < jj_end
                                    && ((jj + jcp_.l_pad - ki) % jcp_.stride_w
                                            == 0));
                    const auto vmm_dst = vmm_out(jj, ocb);
                    if (jcp_.signed_input || !inside_padded_area) {
                        const Vmm inp = vmm_inp(
                                h_padded ? 0 : jj, jcp_.nb_oc_blocking);
                        compute(vmm_dst, vmm_wei_, inp);
                    }
                }
            }
        }
    }

    if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp_))
        apply_zp_src_pad_str_comp(ur_w, l_overflow, r_overflow, h_padded);
}

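// kh_loop below walks the filter over the h (and, for 5D, d) dimension. With
// an s8 source or a source zero-point the traversal is dense (effective
// stride 1) so that every filter tap contributes its compensation term,
// including taps that fall into padding or into stride "holes"; those extra
// iterations call compute_ker() with h_padded == true.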
template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::kh_loop(int ur_w,
        int l_overflow, int r_overflow, ker_block_t last_ic_block_flag) {

    const bool signed_input_or_src_zp
            = (jcp_.signed_input || jcp_.src_zero_point);

    const int ch_block_all = jcp_.ch_block * jcp_.ic_block * jcp_.oc_block;
    const int shift_src_ih = jcp_.typesize_in * (jcp_.dilate_h + 1) * jcp_.iw
            * jcp_.ngroups * jcp_.ic_without_padding;
    const int shift_src_id = jcp_.typesize_in * (jcp_.dilate_d + 1) * jcp_.ih
            * jcp_.iw * jcp_.ngroups * jcp_.ic_without_padding;
    const int stride_h = signed_input_or_src_zp ? 1 : jcp_.stride_h;
    const int shift_filt_kh
            = jcp_.typesize_in * jcp_.kw * ch_block_all * stride_h;
    const int stride_d = signed_input_or_src_zp ? 1 : jcp_.stride_d;
    const int shift_filt_kd
            = jcp_.typesize_in * jcp_.kw * ch_block_all * jcp_.kh * stride_d;

    Label kd_loop_label, kh_loop_label, skip_kh_loop, skip_kd_loop;
    Label t_overflow_label, no_t_overflow_label, b_overflow_label,
            no_b_overflow_label;
    Label back_overflow_label, no_back_overflow_label, d_h_overflow_label,
            front_overflow_label, no_front_overflow_label, d_h_overflow_label2;
    if (jcp_.ndims == 5) {
        mov(aux_reg_filt_d_, reg_filt_);
        mov(aux_reg_src_d_, reg_src_);

        if (signed_input_or_src_zp) {
            mov(reg_ki_, ptr[param1_ + GET_OFF(back_overflow)]);
            cmp(reg_ki_, 0);
            je(no_back_overflow_label, T_NEAR);

            L(back_overflow_label);
            {
                mov(aux_reg_filt_, aux_reg_filt_d_);
                mov(reg_kh_, jcp_.kh);
                L(d_h_overflow_label);
                {
                    compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
                    add(aux_reg_filt_, shift_filt_kh);
                    dec(reg_kh_);
                    jnz(d_h_overflow_label);
                }

                add(aux_reg_filt_d_, shift_filt_kd);
                dec(reg_ki_);
                jnz(back_overflow_label);
            }
            L(no_back_overflow_label);
        }

        mov(reg_ki_, ptr[param1_ + GET_OFF(kd_padding)]);

        if ((signed_input_or_src_zp) || (jcp_.dilate_d >= jcp_.id)
                || ((!signed_input_or_src_zp)
                        && ((min(jcp_.f_pad, jcp_.back_pad) < 0)
                                || ((jcp_.kd - 1) * (jcp_.dilate_d + 1)
                                        < nstl::max(
                                                jcp_.f_pad, jcp_.back_pad))))) {
            cmp(reg_ki_, 0);
            je(skip_kd_loop, T_NEAR);
        }

        L(kd_loop_label);
        mov(aux_reg_src_, aux_reg_src_d_);
        mov(aux_reg_filt_, aux_reg_filt_d_);
    } else {
        mov(aux_reg_src_, reg_src_);
        mov(aux_reg_filt_, reg_filt_);
    }

    if (signed_input_or_src_zp && jcp_.ndims > 3) {
        /* Weights are transposed, so first compute 'bottom' padding. */
        mov(reg_overflow_, ptr[param1_ + GET_OFF(b_overflow)]);
        cmp(reg_overflow_, 0);
        je(no_b_overflow_label, T_NEAR);
        L(b_overflow_label);
        {
            compute_ker(ur_w, 0, 0, last_ic_block_flag, true);

            add(aux_reg_filt_, shift_filt_kh);
            dec(reg_overflow_);
            cmp(reg_overflow_, 0);
            jg(b_overflow_label, T_NEAR);
        }
        L(no_b_overflow_label);
    }

    mov(reg_kh_, ptr[param1_ + GET_OFF(kh_padding)]);

    if ((signed_input_or_src_zp) || (jcp_.dilate_h >= jcp_.ih)
            || ((!signed_input_or_src_zp)
                    && ((min(jcp_.t_pad, jcp_.b_pad) < 0)
                            || ((jcp_.kh - 1) * (jcp_.dilate_h + 1)
                                    < nstl::max(jcp_.t_pad, jcp_.b_pad))))) {
        cmp(reg_kh_, 0);
        je(skip_kh_loop, T_NEAR);
    }

    L(kh_loop_label);
    {
        compute_ker(ur_w, l_overflow, r_overflow, last_ic_block_flag, false);
        sub(aux_reg_src_, shift_src_ih);
        add(aux_reg_filt_, shift_filt_kh);
        dec(reg_kh_);

        /* Insert weight compensation in stride 'holes' */
        if (signed_input_or_src_zp && jcp_.stride_h > 1) {
            Label kh_comp_loop;

            cmp(reg_kh_, 0);
            je(skip_kh_loop, T_NEAR);
            mov(reg_comp_strides_, jcp_.stride_h - 1);
            L(kh_comp_loop);
            {
                compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
                add(aux_reg_filt_, shift_filt_kh);
                dec(reg_comp_strides_);
                cmp(reg_comp_strides_, 0);
                jg(kh_comp_loop, T_NEAR);
            }
        }
        cmp(reg_kh_, 0);
        jg(kh_loop_label, T_NEAR);
    }
    L(skip_kh_loop);
    if (signed_input_or_src_zp && jcp_.ndims > 3) {
        mov(reg_overflow_, ptr[param1_ + GET_OFF(t_overflow)]);
        cmp(reg_overflow_, 0);
        je(no_t_overflow_label, T_NEAR);
        L(t_overflow_label);
        {
            compute_ker(ur_w, 0, 0, last_ic_block_flag, true);

            add(aux_reg_filt_, shift_filt_kh);
            dec(reg_overflow_);
            cmp(reg_overflow_, 0);
            jg(t_overflow_label, T_NEAR);
        }
        L(no_t_overflow_label);
    }

    if (jcp_.ndims == 5) {
        sub(aux_reg_src_d_, shift_src_id);
        add(aux_reg_filt_d_, shift_filt_kd);
        dec(reg_ki_);

        /* Insert weight compensation in stride 'holes' */
        if (signed_input_or_src_zp && jcp_.stride_d > 1) {
            Label kd_comp_loop, kd_kh_comp_loop;
            cmp(reg_ki_, 0);
            jz(skip_kd_loop, T_NEAR);
            mov(reg_comp_strides_, jcp_.stride_d - 1);
            L(kd_comp_loop);
            mov(aux_reg_filt_, aux_reg_filt_d_);
            mov(reg_kh_, jcp_.kh);
            L(kd_kh_comp_loop);
            {
                compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
                add(aux_reg_filt_, shift_filt_kh);
                dec(reg_kh_);
                jnz(kd_kh_comp_loop, T_NEAR);
            }
            add(aux_reg_filt_d_, shift_filt_kd);
            dec(reg_comp_strides_);
            jnz(kd_comp_loop);
        }

        cmp(reg_ki_, 0);
        jg(kd_loop_label, T_NEAR);
        L(skip_kd_loop);
        if (signed_input_or_src_zp) {
            mov(reg_ki_, ptr[param1_ + GET_OFF(f_overflow)]);
            cmp(reg_ki_, 0);
            jz(no_front_overflow_label, T_NEAR);
            L(front_overflow_label);
            {
                mov(aux_reg_filt_, aux_reg_filt_d_);
                mov(reg_kh_, jcp_.kh);
                L(d_h_overflow_label2);
                {
                    compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
                    add(aux_reg_filt_, shift_filt_kh);
                    dec(reg_kh_);
                    jnz(d_h_overflow_label2);
                }
                add(aux_reg_filt_d_, shift_filt_kd);
                dec(reg_ki_);
                jnz(front_overflow_label);
            }
            L(no_front_overflow_label);
        }
    }
}

template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::prepare_output(int ur_w) {
    for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
        for (int ur = 0; ur < ur_w; ur++) {
            const Vmm vmm = vmm_out(ur, ocb);
            uni_vpxor(vmm, vmm, vmm);
        }
    }
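    // For s8 input, vmm_shift_ below is filled with 0x80 in every byte;
    // compute_ker() subtracts it (uni_vpsubb, which wraps per byte) to shift
    // source values by 128 into u8 range for the u8*s8 multiply
    // instructions. The shift is undone later by adding the precomputed
    // weights compensation in store_output().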
    if (jcp_.signed_input) {
        const auto xmm_shift = Xbyak::Xmm(vmm_shift_.getIdx());
        mov(reg_scratch_, 0x80808080);
        uni_vmovq(xmm_shift, reg_scratch_);
        uni_vpbroadcastd(vmm_shift_, xmm_shift);
    }
}

template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::cvt2ps(data_type_t type_in,
        const Vmm &vmm_in, const Reg64 &reg, int offset, int load_size) {

    load_data(type_in, vmm_in, reg, offset, load_size);
    if (type_in != data_type::f32) uni_vcvtdq2ps(vmm_in, vmm_in);
}

template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::apply_postops(int ur_w,
        bool last_oc_block, const float *p_sum_scale, const int32_t *p_sum_zp) {
    const auto sum_injector = [=]() {
        if (p_sum_scale) { // post_op: sum
            for (int k = 0; k < jcp_.nb_oc_blocking; k++) {
                const bool mask_flag
                        = last_oc_block == 1 && k == jcp_.nb_oc_blocking - 1;
                for (int j = 0; j < ur_w; j++) {
                    const int aux_output_offset = jcp_.typesize_out
                            * (k * jcp_.oc_block
                                    + j * jcp_.oc_without_padding
                                            * jcp_.ngroups);
                    cvt2ps(jcp_.dst_dt, vmm_prev_dst_, reg_dst_,
                            aux_output_offset,
                            mask_flag ? get_tail_size() : get_blocking_size());
                    if (*p_sum_zp != 0) {
                        uni_vbroadcastss(vmm_sum_zp_, ptr[reg_ptr_sum_zp_]);
                        uni_vcvtdq2ps(vmm_sum_zp_, vmm_sum_zp_);
                        uni_vsubps(vmm_prev_dst_, vmm_prev_dst_, vmm_sum_zp_);
                    }
                    const Vmm vmm = vmm_out(j, k);
                    if (*p_sum_scale == 1.f)
                        uni_vaddps(vmm, vmm, vmm_prev_dst_);
                    else {
                        uni_vbroadcastss(vmm_tmp_, ptr[reg_ptr_sum_scale_]);
                        uni_vfmadd231ps(vmm, vmm_prev_dst_, vmm_tmp_);
                    }
                }
            }
        }
    };

    if (p_sum_scale)
        postops_injector_->set_lambda_injector(
                primitive_kind::sum, sum_injector);

    binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
    if (jcp_.with_binary) {
        for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
            const bool mask_flag
                    = last_oc_block && ocb == jcp_.nb_oc_blocking - 1;
            for (int ur = 0; ur < ur_w; ur++) {
                const int vmm_idx = vmm_out(ur, ocb).getIdx();
                const size_t aux_output_offset = jcp_.typesize_out
                        * (ocb * jcp_.oc_block
                                + ur * jcp_.oc_without_padding * jcp_.ngroups);

                rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_dst_);
                rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
                        vmm_idx, aux_output_offset);
                if (mask_flag) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
            }
        }
    }
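    // The accumulators occupy the top of the register file (see vmm_out), so
    // the injector is asked to process exactly the contiguous vmm index range
    // [16 - nb_oc_block * ur_w, 16) that holds them.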
    const int nb_oc_block
            = jcp_.is_depthwise ? jcp_.nb_ch_blocking : jcp_.nb_oc_blocking;
    postops_injector_->compute_vector_range(
            16 - nb_oc_block * ur_w, 16, rhs_arg_params);
}

template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::store_output(
        int ur_w, bool last_oc_block) {
    mov(reg_bias_, ptr[param1_ + GET_OFF(bias)]);
    mov(reg_ptr_scales_, ptr[param1_ + GET_OFF(scales)]);

    if (jcp_.signed_input)
        mov(reg_compensation_, ptr[param1_ + GET_OFF(compensation)]);

    if (jcp_.src_zero_point) {
        mov(reg_zp_src_, ptr[param1_ + GET_OFF(src_zero_point)]);
        mov(reg_zp_compensation_, ptr[param1_ + GET_OFF(zp_compensation)]);
    }

    const auto &p = jcp_.post_ops;
    const int sum_idx = p.find(primitive_kind::sum);
    const float *p_sum_scale
            = (sum_idx != -1) ? &p.entry_[sum_idx].sum.scale : nullptr;
    const int32_t *p_sum_zp
            = (sum_idx != -1) ? &p.entry_[sum_idx].sum.zero_point : nullptr;

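    // On pre-VNNI ISAs the weights were pre-scaled by wei_adj_scale (0.5f,
    // set up in init_conf) at reorder time, so the bias is multiplied by the
    // same factor below to stay consistent with the scaled accumulators.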
1071     if (jcp_.with_bias && jcp_.signed_input && jcp_.ver != ver_vnni) {
1072         mov(reg_bias_alpha_, float2int(jcp_.wei_adj_scale));
1073         uni_vmovq(xmm_bias_alpha(), reg_bias_alpha_);
1074         uni_vbroadcastss(vmm_bias_alpha(), xmm_bias_alpha());
1075     }
1076 
1077     if (jcp_.src_zero_point) {
1078         const auto &vmm_src_zp = vmm_tmp_;
1079         const auto &vmm_zp_comp = vmm_scale_;
1080         uni_vbroadcastss(vmm_src_zp, ptr[reg_zp_src_]);
1081 
1082         for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
1083             const int zp_offset = sizeof(int32_t) * ocb * jcp_.oc_block;
1084             const bool mask_flag
1085                     = last_oc_block && ocb == jcp_.nb_oc_blocking - 1;
1086             const int load_size
1087                     = mask_flag ? get_tail_size() : get_blocking_size();
1088             load_data(data_type::s32, vmm_zp_comp, reg_zp_compensation_,
1089                     zp_offset, load_size);
1090             uni_vpmulld(vmm_zp_comp, vmm_zp_comp, vmm_src_zp);
1091 
1092             for (int ur = 0; ur < ur_w; ur++) {
1093                 const auto vmm_dst = vmm_out(ur, ocb);
1094                 uni_vpaddd(vmm_dst, vmm_dst, vmm_zp_comp);
1095             }
1096         }
1097     }
1098 
1099     for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
1100         const bool mask_flag = last_oc_block && ocb == jcp_.nb_oc_blocking - 1;
1101         const int scale_offset
1102                 = jcp_.is_oc_scale * (sizeof(float) * ocb * jcp_.oc_block);
1103 
1104         const auto vmm_bias_ = vmm_tmp_;
1105         if (jcp_.with_bias) {
1106             int bias_offset = jcp_.typesize_bia * ocb * jcp_.oc_block;
1107             cvt2ps(jcp_.bia_dt, vmm_bias_, reg_bias_, bias_offset,
1108                     mask_flag ? get_tail_size() : get_blocking_size());
1109             if (jcp_.signed_input && jcp_.ver != ver_vnni)
1110                 uni_vmulps(vmm_bias_, vmm_bias_, vmm_bias_alpha());
1111         }
1112         if (jcp_.signed_input) {
1113             const int comp_offset = sizeof(int32_t) * ocb * jcp_.oc_block;
1114             cvt2ps(data_type::s32, vmm_comp_, reg_compensation_, comp_offset,
1115                     mask_flag ? get_tail_size() : get_blocking_size());
1116         }
1117 
1118         /* add to ymm_accum: compensation, bias */
1119         uni_vmovups(vmm_scale_, ptr[reg_ptr_scales_ + scale_offset]);
1120         for (int ur = 0; ur < ur_w; ur++) {
1121             const Vmm vmm = vmm_out(ur, ocb);
1122             uni_vcvtdq2ps(vmm, vmm);
1123             if (jcp_.signed_input) uni_vaddps(vmm, vmm, vmm_comp_);
1124             if (jcp_.with_bias) uni_vaddps(vmm, vmm, vmm_bias_);
1125             uni_vmulps(vmm, vmm, vmm_scale_);
1126         }
1127     }
1128 
1129     if (p_sum_scale && *p_sum_scale != 1.f)
1130         mov(reg_ptr_sum_scale_, reinterpret_cast<size_t>(p_sum_scale));
1131     if (p_sum_zp && *p_sum_zp != 0) {
1132         mov(reg_ptr_sum_zp_, reinterpret_cast<size_t>(p_sum_zp));
1133     }
1134     if (jcp_.with_eltwise || jcp_.with_binary || jcp_.with_sum)
1135         apply_postops(ur_w, last_oc_block, p_sum_scale, p_sum_zp);
1136     if (jcp_.dst_zero_point) {
1137         mov(reg_zp_dst_, ptr[param1_ + GET_OFF(dst_zero_point)]);
1138         const auto &vmm_zp_dst = vmm_tmp_;
1139         uni_vbroadcastss(vmm_zp_dst, ptr[reg_zp_dst_]);
1140         uni_vcvtdq2ps(vmm_zp_dst, vmm_zp_dst);
1141 
1142         for_(int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++)
1143         for (int ur = 0; ur < ur_w; ur++) {
1144             const auto vmm_dst = vmm_out(ur, ocb);
1145             uni_vaddps(vmm_dst, vmm_dst, vmm_zp_dst);
1146         }
1147     }
1148 
1149     // Properly saturate the accumulators for integer datatypes
1150 
1151     // No need to saturate on lower bound for signed integer types, as
1152     // the conversion to int would return INT_MIN, and then proper
1153     // saturation will happen when storing data
1154     if (jcp_.dst_dt == data_type::u8) {
1155         uni_vpxor(vmm_zero_, vmm_zero_, vmm_zero_);
1156         for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
1157             for (int ur = 0; ur < ur_w; ur++) {
1158                 const Vmm vmm = vmm_out(ur, ocb);
1159                 uni_vmaxps(vmm, vmm, vmm_zero_);
1160             }
1161         }
1162     }
1163 
1164     if (one_of(jcp_.dst_dt, data_type::u8, data_type::s8, data_type::s32)) {
1165         float saturation_ubound = types::max_value<float>(jcp_.dst_dt);
1166         const Xmm xmm_saturation(vmm_saturation_.getIdx());
1167         mov(reg_ptr_saturation_ubound_, float2int(saturation_ubound));
1168         uni_vmovq(xmm_saturation, reg_ptr_saturation_ubound_);
1169         uni_vbroadcastss(vmm_saturation_, xmm_saturation);
1170 
1171         for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
1172             for (int ur = 0; ur < ur_w; ur++) {
1173                 const Vmm vmm = vmm_out(ur, ocb);
1174                 uni_vminps(vmm, vmm, vmm_saturation_);
1175             }
1176         }
1177     }
1178 
1179     if (one_of(jcp_.dst_dt, data_type::u8, data_type::s8, data_type::s32)) {
1180         for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
1181             for (int ur = 0; ur < ur_w; ur++) {
1182                 const Vmm vmm = vmm_out(ur, ocb);
1183                 uni_vcvtps2dq(vmm, vmm);
1184             }
1185         }
1186     }
1187 
1188     /* write out register to output_addr */
1189     for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
1190         const bool mask_flag = last_oc_block && ocb == jcp_.nb_oc_blocking - 1;
1191         for (int ur = 0; ur < ur_w; ur++) {
1192             const int aux_dst_off = jcp_.typesize_out
1193                     * (ur * jcp_.ngroups * jcp_.oc_without_padding
1194                             + ocb * jcp_.oc_block);
1195             const Vmm r_vmm = vmm_out(ur, ocb);
1196             store_data(jcp_.dst_dt, r_vmm, reg_dst_, aux_dst_off,
1197                     mask_flag ? get_tail_size() : get_blocking_size());
1198         }
1199     }
1200 }
1201 
1202 template <cpu_isa_t isa, typename Vmm>
icb_loop(int ur_w,int l_overflow,int r_overflow,bool is_last_sp_block)1203 void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::icb_loop(
1204         int ur_w, int l_overflow, int r_overflow, bool is_last_sp_block) {
1205 
1206     const int shift_src_icb = jcp_.typesize_in * jcp_.ic_block;
1207     const size_t shift_filt_icb = (size_t)jcp_.typesize_in * jcp_.kd * jcp_.kh
1208             * jcp_.kw * jcp_.ic_block * jcp_.oc_block;
1209 
1210     prepare_output(ur_w);
1211 
1212     Label skip_icb_loop, icb_loop_label;
1213 
1214     mov(reg_icb_, jcp_.nb_ic);
1215     mov(reg_oc_blocks_, ptr[param1_ + GET_OFF(oc_blocks)]);
1216 
1217     if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp_)
1218             && jcp_.ndims > 3) {
1219         mov(reg_scratch_,
1220                 qword[param1_ + GET_OFF(zp_src_pad_str_compensation)]);
1221         mov(zp_src_pad_comp_addr_, reg_scratch_);
1222     }
1223 
1224     L(icb_loop_label);
1225     {
        if (jcp_.ngroups % jcp_.ch_block != 0
                || jcp_.ic_without_padding != jcp_.ic) {
            Label common_ker, end_ker;
            if (jcp_.is_depthwise) {
                cmp(reg_oc_blocks_, jcp_.nb_ch - 1);
                jne(common_ker, T_NEAR);
            } else {
                cmp(reg_icb_, 1);
                jg(common_ker, T_NEAR);
            }

            kh_loop(ur_w, l_overflow, r_overflow, last_sp_block);
            jmp(end_ker, T_NEAR);

            L(common_ker);
            kh_loop(ur_w, l_overflow, r_overflow, no_last_block);

            L(end_ker);
        } else {
            kh_loop(ur_w, l_overflow, r_overflow, no_last_block);
        }

        add(reg_src_, shift_src_icb);
        safe_add(reg_filt_, shift_filt_icb, reg_ker_long_offt_);
        dec(reg_icb_);
        cmp(reg_icb_, 0);
        jg(icb_loop_label, T_NEAR);
    }

    /* come-back pointers */
    sub(reg_src_, jcp_.nb_ic * shift_src_icb);
    safe_sub(reg_filt_, jcp_.nb_ic * shift_filt_icb, reg_ker_long_offt_);
    L(skip_icb_loop);

    if (jcp_.ngroups % jcp_.ch_block != 0
            || jcp_.oc_without_padding != jcp_.oc) {
        Label common_store, end_store;
        if (jcp_.is_depthwise)
            cmp(reg_oc_blocks_, jcp_.nb_ch - 1);
        else
            cmp(reg_oc_blocks_, jcp_.nb_oc - jcp_.nb_oc_blocking);
        jne(common_store, T_NEAR);

        store_output(ur_w, true);
        jmp(end_store, T_NEAR);

        L(common_store);
        store_output(ur_w, false);

        L(end_store);

    } else {
        store_output(ur_w, false);
    }
}

template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::generate() {
    preamble();

    if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp_))
        sub(rsp, reserved_stack_size_);

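    /* 0x10001 broadcast to every dword gives a vector of 16-bit ones; it is
     * the pmaddwd multiplier that pairwise-adds 16-bit products into the
     * 32-bit accumulators. */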
    const auto vmm_one_128 = Xbyak::Xmm(vmm_one_.getIdx());
    mov(reg_scratch_, 0x10001);
    uni_vmovq(vmm_one_128, reg_scratch_);
    uni_vpbroadcastd(vmm_one_, vmm_one_128);

    mov(reg_src_, ptr[param1_ + GET_OFF(src)]);
    mov(reg_filt_, ptr[param1_ + GET_OFF(filt)]);
    mov(reg_dst_, ptr[param1_ + GET_OFF(dst)]);

    const int dst_shift = jcp_.typesize_out * jcp_.ur_w * jcp_.ngroups
            * jcp_.oc_without_padding;
    const int src_shift = jcp_.typesize_in * (jcp_.ur_w / jcp_.stride_w)
            * jcp_.ngroups * jcp_.ic_without_padding;

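    /* l_overflow / r_overflow count the filter columns whose reads would
     * fall outside the source row at the left / right border; the compute
     * loops are clamped by these values. r_overflow1 is the variant for the
     * last full ur_w block preceding the tail. */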
    const int l_overflow = max(0,
            ((jcp_.kw - 1) * (jcp_.dilate_w + 1) - jcp_.l_pad) / jcp_.stride_w);
    const int r_overflow = max(0,
            ((jcp_.kw - 1) * (jcp_.dilate_w + 1) - max(0, jcp_.r_pad))
                    / jcp_.stride_w);

    const int r_overflow1 = nstl::max(0,
            ((jcp_.kw - 1) * (jcp_.dilate_w + 1) - nstl::max(0, jcp_.r_pad)
                    - jcp_.ur_w_tail)
                    / jcp_.stride_w);
    int nur_w = jcp_.ow / jcp_.ur_w;
    if (r_overflow1 > 0) nur_w--;

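    /* Output-width dispatch: either a single block covers the whole row,
     * only border blocks are needed, or we emit left border + steady-state
     * loop + right border + tail. */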
    if (jcp_.ur_w == jcp_.ow) {
        icb_loop(jcp_.ur_w, l_overflow, r_overflow, true);
    } else if (nur_w == 0) {
        icb_loop(jcp_.ur_w, l_overflow, r_overflow1, jcp_.ur_w_tail == 0);
        add(reg_src_, src_shift);
        add(reg_dst_, dst_shift);
        if (jcp_.ur_w_tail != 0) icb_loop(jcp_.ur_w_tail, 0, r_overflow, true);
    } else {
        xor_(reg_nur_w_, reg_nur_w_);
        if (l_overflow > 0) {
            icb_loop(jcp_.ur_w, l_overflow, 0, false);
            add(reg_src_, src_shift);
            add(reg_dst_, dst_shift);
            inc(reg_nur_w_);
        }
        if ((l_overflow <= 0 && nur_w > 0) || (l_overflow > 0 && nur_w > 1)) {
            Label ow_loop_label;
            L(ow_loop_label);
            {
                icb_loop(jcp_.ur_w, 0, 0, false);
                add(reg_src_, src_shift);
                add(reg_dst_, dst_shift);
                inc(reg_nur_w_);
                cmp(reg_nur_w_, nur_w);
                jl(ow_loop_label, T_NEAR);
            }
        }
        if (r_overflow1 > 0) {
            icb_loop(jcp_.ur_w, 0, r_overflow1, jcp_.ur_w_tail == 0);
            add(reg_src_, src_shift);
            add(reg_dst_, dst_shift);
        }
        if (jcp_.ur_w_tail != 0) {
            icb_loop(jcp_.ur_w_tail, 0, r_overflow, true);
        }
    }

    if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp_))
        add(rsp, reserved_stack_size_);

    postamble();

    if (jcp_.with_eltwise) postops_injector_->prepare_table();
}

template <cpu_isa_t isa>
jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::jit_uni_x8s8s32x_deconvolution_fwd_t(
        const pd_t *apd)
    : primitive_t(apd) {}

template <cpu_isa_t isa>
jit_uni_x8s8s32x_deconvolution_fwd_t<
        isa>::~jit_uni_x8s8s32x_deconvolution_fwd_t()
        = default;

template <cpu_isa_t isa>
status_t jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::pd_t::init(
        engine_t *engine) {
    using namespace data_type;
    using skip_mask_t = primitive_attr_t::skip_mask_t;
    const bool ok = is_fwd()
            && desc()->alg_kind == alg_kind::deconvolution_direct
            && utils::one_of(src_md(0)->data_type, s8, u8)
            && weights_md(0)->data_type == s8
            && IMPLICATION(with_bias(),
                    utils::one_of(weights_md(1)->data_type, f32, s32, s8, u8))
            && utils::one_of(dst_md(0)->data_type, f32, s32, s8, u8)
            && desc()->accum_data_type == s32
            && attr()->has_default_values(skip_mask_t::oscale
                    | skip_mask_t::post_ops | skip_mask_t::zero_points_runtime);
    if (!ok) return status::unimplemented;

    CHECK(jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::init_conf(jcp_, *desc(),
            src_md_, weights_md_, dst_md_, with_bias(), bias_md_, attr_,
            dnnl_get_max_threads()));

    auto scratchpad = scratchpad_registry().registrar();
    jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::init_scratchpad(
            scratchpad, jcp_, *attr());

    return status::success;
}

template <cpu_isa_t isa>
status_t jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::init(engine_t *engine) {
    CHECK(safe_ptr_assign(kernel_,
            new jit_uni_x8s8s32x_deconv_fwd_kernel<isa>(pd()->jcp_,
                    *pd()->attr(), memory_desc_wrapper(pd()->dst_md()))));

    if (zp::should_calculate_deconv_zp_src_pad_str_comp(pd()->jcp_)) {
        CHECK(safe_ptr_assign(zp_src_pad_comp_kernel_,
                zp::create_deconv_zp_pad_str_comp_ker<isa>(pd()->jcp_)));
        const auto zp_kernel_status = zp_src_pad_comp_kernel_->create_kernel();
        if (zp_kernel_status != status::success) return zp_kernel_status;
    }

    return kernel_->create_kernel();
}

template <cpu_isa_t isa>
status_t jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::execute(
        const exec_ctx_t &ctx) const {
    const auto ndims = pd()->ndims();
    if (ndims == 3)
        return execute_forward_1d(ctx);
    else if (ndims == 4)
        return execute_forward_2d(ctx);
    else if (ndims == 5)
        return execute_forward_3d(ctx);
    else
        return status::unimplemented;
}

template <cpu_isa_t isa>
status_t jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::execute_forward_1d(
        const exec_ctx_t &ctx) const {
    const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    const auto weights = CTX_IN_MEM(const int8_t *, DNNL_ARG_WEIGHTS);
    const auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    DEFINE_ZERO_POINTS_BUFFER(zp_src, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(zp_dst, DNNL_ARG_DST);

    const auto &jcp = pd()->jcp_;

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());

    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(jcp.post_ops, ctx);
    auto scratchpad = ctx.get_scratchpad_grantor();
    int32_t *zp_src_comp_scratch = scratchpad.get<int32_t>(key_deconv_zp);

    if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp))
        zp::compute_deconv_zp_pad_str_comp_ker(jcp, pd()->with_groups(),
                weights_d, weights, zp_src, zp_src_comp_scratch,
                zp_src_pad_comp_kernel_.get());

    const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
    const int nb_groups = jcp.nb_ch;

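    /* On the non-VNNI signed-input path the weights come pre-scaled by
     * wei_adj_scale (presumably to keep the pmaddubsw intermediates within
     * s16 range), so the inverse factor is folded back into the output
     * scales here. */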
    const float *oscales = pd()->attr()->output_scales_.scales_;
    if (jcp.signed_input && jcp.ver != ver_vnni) {
        auto local_scales = ctx.get_scratchpad_grantor().template get<float>(
                key_conv_adjusted_scales);
        const size_t count = pd()->attr()->output_scales_.count_;
        const float factor = 1.f / pd()->jcp_.wei_adj_scale;
        if (count == 1) {
            utils::array_set(local_scales, oscales[0] * factor, 8);
        } else {
            for (size_t c = 0; c < count; c++)
                local_scales[c] = oscales[c] * factor;
        }
        oscales = local_scales;
    }
    const size_t offset = weights_d.size() - weights_d.additional_buffer_size();
    auto w = const_cast<int8_t *>(weights);
    int32_t *compensation = (jcp.signed_input)
            ? reinterpret_cast<int32_t *>(&w[offset])
            : nullptr;
    const int32_t *zp_compensation = jcp.src_zero_point
            ? get_src_zp_comp_from_wei(
                    weights, weights_d, jcp.signed_input, jcp.ngroups, jcp.oc)
            : nullptr;

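    /* Parallelize over (mb, group, oc-chunk) work items; balance211 hands
     * each thread a contiguous, evenly sized share. */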
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        const int work_amount = jcp.mb * nb_groups * oc_chunks;
        balance211(work_amount, nthr, ithr, start, end);

        auto p = jit_deconv_call_s();

        int n {0}, g {0}, occ {0};
        if (jcp.loop_order == loop_ngc)
            nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks);
        else if (jcp.loop_order == loop_cgn)
            nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb);
        else
            assert(!"unsupported loop order");
        while (start < end) {

            const int ocb = occ * jcp.nb_oc_blocking;
            const int g_oc
                    = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block;
            const int g_ic = g * jcp.ch_block * jcp.ic;

            p.dst = dst + dst_dt_size * dst_d.blk_off(n, g_oc);
            p.src = src + src_d.blk_off(n, g_ic);
            p.filt = weights + wht_blk_off(weights_d, g, ocb, 0);
            p.bias = jcp.with_bias
                    ? bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia)
                    : nullptr;
            p.compensation = (jcp.signed_input) ? compensation + g_oc : nullptr;
            p.scales = &oscales[jcp.is_oc_scale * g_oc];
            p.t_overflow = 0;
            p.b_overflow = 0;
            p.kh_padding = jcp.kh;
            p.oc_blocks = jcp.is_depthwise ? g : ocb;
            p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
            p.oc_l_off = g_oc;
            p.zp_compensation
                    = jcp.src_zero_point ? zp_compensation + g_oc : nullptr;
            p.zp_src_pad_str_compensation = zp_src_comp_scratch
                    ? zp_src_comp_scratch + g_oc
                    : nullptr;
            p.src_zero_point = zp_src;
            p.dst_zero_point = zp_dst;
            p.dst_orig = dst;

            (*kernel_)(&p);

            ++start;
            if (jcp.loop_order == loop_ngc)
                nd_iterator_step(n, jcp.mb, g, nb_groups, occ, oc_chunks);
            else if (jcp.loop_order == loop_cgn)
                nd_iterator_step(occ, oc_chunks, g, nb_groups, n, jcp.mb);
            else
                assert(!"unsupported loop order");
        }
    });
    return status::success;
}

template <cpu_isa_t isa>
status_t jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::execute_forward_2d(
        const exec_ctx_t &ctx) const {
    const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    const auto weights = CTX_IN_MEM(const int8_t *, DNNL_ARG_WEIGHTS);
    const auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    DEFINE_ZERO_POINTS_BUFFER(zp_src, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(zp_dst, DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());

    const auto &jcp = pd()->jcp_;
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(jcp.post_ops, ctx);

    auto scratchpad = ctx.get_scratchpad_grantor();
    int32_t *zp_src_comp_scratch = scratchpad.get<int32_t>(key_deconv_zp);

    if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp))
        zp::compute_deconv_zp_pad_str_comp_ker(jcp, pd()->with_groups(),
                weights_d, weights, zp_src, zp_src_comp_scratch,
                zp_src_pad_comp_kernel_.get());

    const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
    const int nb_groups = jcp.nb_ch;

    const size_t src_h_stride = src_d.blk_off(0, 0, 1);
    const size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
    const size_t wht_kh_stride = wht_blk_off(weights_d, 0, 0, 0, 1);

    const float *oscales = pd()->attr()->output_scales_.scales_;
    if (jcp.signed_input && jcp.ver != ver_vnni) {
        auto local_scales = ctx.get_scratchpad_grantor().template get<float>(
                key_conv_adjusted_scales);
        const size_t count = pd()->attr()->output_scales_.count_;
        const float factor = 1.f / pd()->jcp_.wei_adj_scale;
        if (count == 1) {
            utils::array_set(local_scales, oscales[0] * factor, 8);
        } else {
            for (size_t c = 0; c < count; c++)
                local_scales[c] = oscales[c] * factor;
        }
        oscales = local_scales;
    }
    const size_t offset = weights_d.size() - weights_d.additional_buffer_size();
    auto w = const_cast<int8_t *>(weights);
    int32_t *compensation = (jcp.signed_input)
            ? reinterpret_cast<int32_t *>(&w[offset])
            : nullptr;
    const int32_t *zp_compensation = jcp.src_zero_point
            ? get_src_zp_comp_from_wei(
                    weights, weights_d, jcp.signed_input, jcp.ngroups, jcp.oc)
            : nullptr;

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        const int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh;
        balance211(work_amount, nthr, ithr, start, end);

        auto p = jit_deconv_call_s();

        /* loop order: cgn or ngc */
        int n {0}, g {0}, occ {0}, oh_s {0};
        if (jcp.loop_order == loop_ngc)
            nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
                    oh_s, jcp.oh);
        else if (jcp.loop_order == loop_cgn)
            nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb,
                    oh_s, jcp.oh);
        else
            assert(!"unsupported loop order");
        while (start < end) {

            const int ocb = occ * jcp.nb_oc_blocking;
            const int g_oc
                    = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block;
            const int g_ic = g * jcp.ch_block * jcp.ic;
            const int work_rem = end - start;
            const int oh_e
                    = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;

            const auto dst_w = dst + dst_dt_size * dst_d.blk_off(n, g_oc);
            const auto src_w = src + src_d.blk_off(n, g_ic);
            const auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0);
            const auto bias_w = jcp.with_bias
                    ? bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia)
                    : nullptr;
            const int32_t *compensation_w
                    = (jcp.signed_input) ? compensation + g_oc : nullptr;

            const auto scales = &oscales[jcp.is_oc_scale * g_oc];
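            /* For each output row, clip the kh range against top/bottom
             * padding: kh_lo is the first contributing filter tap, kh_len
             * the number of taps, and ih_max the source row where the
             * kernel starts reading. */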
            for (int oj = oh_s; oj < oh_e; oj++) {
                int ih_max = 0, kh_lo = 0, kh_len = 0;
                if (jcp.dilate_h != 0 && jcp.stride_h == 1) {
                    /* dilation */
                    const int dilate_h = jcp.dilate_h + 1;
                    // Note: use div_up to account for "holes" in filter
                    const int o_t_overflow = div_up(
                            max(0, (jcp.kh - 1) * dilate_h - oj - jcp.t_pad),
                            dilate_h);
                    const int o_b_overflow
                            = div_up(max(0,
                                             (jcp.kh - 1) * dilate_h + 1
                                                     - jcp.oh + oj - jcp.b_pad),
                                    dilate_h);
                    kh_len = jcp.kh - o_t_overflow - o_b_overflow;
                    kh_lo = o_b_overflow;
                    ih_max = oj + jcp.t_pad - o_b_overflow * dilate_h;
                } else {
                    const int o_t_overflow = max(
                            0, (jcp.kh - (oj + 1 + jcp.t_pad)) / jcp.stride_h);
                    const int o_b_overflow = max(0,
                            ((oj + jcp.kh) - (jcp.oh + jcp.b_pad))
                                    / jcp.stride_h);
                    const int overflow_kh_hi = jcp.kh - 1
                            - modulo(jcp.oh + jcp.b_pad - (oj + 1),
                                    jcp.stride_h);
                    const int overflow_kh_lo = (oj + jcp.t_pad) % jcp.stride_h;

                    kh_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h
                            + 1 - o_t_overflow - o_b_overflow;
                    kh_lo = overflow_kh_lo + o_b_overflow * jcp.stride_h;
                    ih_max = (oj + jcp.t_pad - kh_lo) / jcp.stride_h;
                }

                const int wei_stride
                        = (!jcp.signed_input && !jcp.src_zero_point)
                        ? kh_lo * wht_kh_stride
                        : 0;
                p.src = src_w + ih_max * src_h_stride;
                p.dst = dst_w + dst_dt_size * oj * dst_h_stride;
                p.filt = wht_w + wei_stride;
                p.bias = bias_w;
                p.compensation = compensation_w;
                p.t_overflow = jcp.dilate_h > 0
                        ? jcp.kh - kh_len - kh_lo
                        : max(0,
                                jcp.kh
                                        - (kh_lo
                                                + max(0, kh_len - 1)
                                                        * jcp.stride_h
                                                + 1));
                p.b_overflow = kh_lo;
                p.kh_padding = kh_len;
                p.scales = scales;
                p.oc_blocks = jcp.is_depthwise ? g : ocb;
                p.post_ops_binary_rhs_arg_vec
                        = post_ops_binary_rhs_arg_vec.data();
                p.oc_l_off = g_oc;
                p.zp_compensation
                        = jcp.src_zero_point ? zp_compensation + g_oc : nullptr;
                p.zp_src_pad_str_compensation = jcp.src_zero_point
                        ? zp_src_comp_scratch + g_oc
                        : nullptr;
                p.src_zero_point = zp_src;
                p.dst_zero_point = zp_dst;
                p.dst_orig = dst;

                (*kernel_)(&p);
            }
            if (jcp.loop_order == loop_ngc)
                nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
                        oc_chunks, oh_s, jcp.oh);
            else if (jcp.loop_order == loop_cgn)
                nd_iterator_jump(start, end, occ, oc_chunks, g, nb_groups, n,
                        jcp.mb, oh_s, jcp.oh);
            else
                assert(!"unsupported loop order");
        }
    });
    return status::success;
}

template <cpu_isa_t isa>
status_t jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::execute_forward_3d(
        const exec_ctx_t &ctx) const {
    const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    const auto weights = CTX_IN_MEM(const int8_t *, DNNL_ARG_WEIGHTS);
    const auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    DEFINE_ZERO_POINTS_BUFFER(zp_src, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(zp_dst, DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());

    const auto &jcp = pd()->jcp_;
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(jcp.post_ops, ctx);

    const auto scratchpad = ctx.get_scratchpad_grantor();
    int32_t *const zp_src_comp_scratch = scratchpad.get<int32_t>(key_deconv_zp);

    if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp))
        zp::compute_deconv_zp_pad_str_comp_ker(jcp, pd()->with_groups(),
                weights_d, weights, zp_src, zp_src_comp_scratch,
                zp_src_pad_comp_kernel_.get());

    const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
    const int nb_groups = jcp.nb_ch;

    const size_t src_d_stride = src_d.blk_off(0, 0, 1);
    const size_t src_h_stride = src_d.blk_off(0, 0, 0, 1);
    const size_t dst_d_stride = dst_d.blk_off(0, 0, 1);
    const size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1);
    const size_t wht_kd_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
    const size_t wht_kh_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);

    const float *oscales = pd()->attr()->output_scales_.scales_;
    if (jcp.signed_input && jcp.ver != ver_vnni) {
        const auto local_scales
                = ctx.get_scratchpad_grantor().template get<float>(
                        key_conv_adjusted_scales);
        const size_t count = pd()->attr()->output_scales_.count_;
        const float factor = 1.f / pd()->jcp_.wei_adj_scale;
        if (count == 1) {
            utils::array_set(local_scales, oscales[0] * factor, 8);
        } else {
            for (size_t c = 0; c < count; c++)
                local_scales[c] = oscales[c] * factor;
        }
        oscales = local_scales;
    }
    const size_t offset = weights_d.size() - weights_d.additional_buffer_size();
    auto w = const_cast<int8_t *>(weights);
    int32_t *compensation = (jcp.signed_input)
            ? reinterpret_cast<int32_t *>(&w[offset])
            : nullptr;
    const int32_t *zp_compensation = jcp.src_zero_point
            ? get_src_zp_comp_from_wei(
                    weights, weights_d, jcp.signed_input, jcp.ngroups, jcp.oc)
            : nullptr;

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.od * jcp.oh;
        balance211(work_amount, nthr, ithr, start, end);

        auto p = jit_deconv_call_s();

        /* loop order: cgn or ngc */
        int n {0}, g {0}, occ {0}, od_s {0}, oh_s {0};
        if (jcp.loop_order == loop_ngc)
            nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
                    od_s, jcp.od, oh_s, jcp.oh);
        else if (jcp.loop_order == loop_cgn)
            nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb,
                    od_s, jcp.od, oh_s, jcp.oh);
        else
            assert(!"unsupported loop order");
        while (start < end) {

            const int ocb = occ * jcp.nb_oc_blocking;
            const int g_oc
                    = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block;
            const int g_ic = g * jcp.ch_block * jcp.ic;
            const int work_rem = end - start;
            const int oh_e
                    = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
            int input_d_s = 0, kd_len = 0, kd_lo = 0;
            int d_t_overflow, d_back_overflow;

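            /* Depth-axis analogue of the kh clipping: kd_lo / kd_len bound
             * the contributing kd taps and input_d_s is the first source
             * depth index for this output depth position. */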
            if (jcp.dilate_d != 0 && jcp.stride_d == 1) {
                /* dilation */
                int dilate_d = jcp.dilate_d + 1;
                // Note: use div_up to account for "holes" in filter
                d_t_overflow = div_up(
                        max(0, (jcp.kd - 1) * dilate_d - od_s - jcp.f_pad),
                        dilate_d);
                d_back_overflow
                        = div_up(max(0,
                                         (jcp.kd - 1) * dilate_d + 1 - jcp.od
                                                 + od_s - jcp.back_pad),
                                dilate_d);
                kd_len = jcp.kd - d_t_overflow - d_back_overflow;
                kd_lo = d_back_overflow;
                input_d_s = od_s + jcp.f_pad - d_back_overflow * dilate_d;
            } else {
                const int d_t_overflow = max(
                        0, (jcp.kd - (od_s + 1 + jcp.f_pad)) / jcp.stride_d);
                const int d_back_overflow = max(0,
                        ((od_s + jcp.kd) - (jcp.od + jcp.back_pad))
                                / jcp.stride_d);
                const int overflow_kd_hi = jcp.kd - 1
                        - modulo(jcp.od + jcp.back_pad - (od_s + 1),
                                jcp.stride_d);
                const int overflow_kd_lo = (od_s + jcp.f_pad) % jcp.stride_d;

                kd_len = (overflow_kd_hi - overflow_kd_lo) / jcp.stride_d + 1
                        - d_t_overflow - d_back_overflow;
                kd_lo = overflow_kd_lo + d_back_overflow * jcp.stride_d;
                input_d_s = (od_s + jcp.f_pad - kd_lo) / jcp.stride_d;
            }

            auto dst_w = dst
                    + dst_dt_size
                            * (dst_d.blk_off(n, g_oc) + od_s * dst_d_stride);
            const auto src_w
                    = src + src_d.blk_off(n, g_ic) + input_d_s * src_d_stride;
            const auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0)
                    + ((jcp.signed_input || jcp.src_zero_point) ? 0 : kd_lo)
                            * wht_kd_stride;
            const auto bias_w = jcp.with_bias
                    ? bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia)
                    : nullptr;
            const int32_t *compensation_w
                    = (jcp.signed_input) ? compensation + g_oc : nullptr;

            const auto scales = &oscales[jcp.is_oc_scale * g_oc];

            for (int oj = oh_s; oj < oh_e; oj++) {
                int ih_max = 0, kh_lo = 0, kh_len = 0;
                if (jcp.dilate_h != 0 && jcp.stride_h == 1) {
                    /* dilation */
                    const int dilate_h = jcp.dilate_h + 1;
                    // Note: use div_up to account for "holes" in filter
                    const int o_t_overflow = div_up(
                            max(0, (jcp.kh - 1) * dilate_h - oj - jcp.t_pad),
                            dilate_h);
                    const int o_b_overflow
                            = div_up(max(0,
                                             (jcp.kh - 1) * dilate_h + 1
                                                     - jcp.oh + oj - jcp.b_pad),
                                    dilate_h);
                    kh_len = jcp.kh - o_t_overflow - o_b_overflow;
                    kh_lo = o_b_overflow;
                    ih_max = oj + jcp.t_pad - o_b_overflow * dilate_h;
                } else {
                    const int o_t_overflow = max(
                            0, (jcp.kh - (oj + 1 + jcp.t_pad)) / jcp.stride_h);
                    const int o_b_overflow = max(0,
                            ((oj + jcp.kh) - (jcp.oh + jcp.b_pad))
                                    / jcp.stride_h);
                    const int overflow_kh_hi = jcp.kh - 1
                            - modulo(jcp.oh + jcp.b_pad - (oj + 1),
                                    jcp.stride_h);
                    const int overflow_kh_lo = (oj + jcp.t_pad) % jcp.stride_h;

                    kh_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h
                            + 1 - o_t_overflow - o_b_overflow;
                    kh_lo = overflow_kh_lo + o_b_overflow * jcp.stride_h;
                    ih_max = (oj + jcp.t_pad - kh_lo) / jcp.stride_h;
                }

                const int wei_stride
                        = (!jcp.signed_input && !jcp.src_zero_point)
                        ? kh_lo * wht_kh_stride
                        : 0;
                p.src = src_w + ih_max * src_h_stride;
                p.dst = dst_w + dst_dt_size * oj * dst_h_stride;
                p.filt = wht_w + wei_stride;
                p.bias = bias_w;
                p.compensation = compensation_w;
                /* Note: Currently this kernel doesn't support dilations and
                strides together */
                p.t_overflow = jcp.dilate_h > 0
                        ? jcp.kh - kh_len - kh_lo
                        : max(0,
                                jcp.kh
                                        - (kh_lo
                                                + max(0, kh_len - 1)
                                                        * jcp.stride_h
                                                + 1));
                p.b_overflow = kh_lo;
                p.f_overflow = jcp.dilate_d > 0
                        ? jcp.kd - kd_len - kd_lo
                        : max(0,
                                jcp.kd
                                        - (kd_lo
                                                + max(0, kd_len - 1)
                                                        * jcp.stride_d
                                                + 1));
                p.back_overflow = kd_lo;
                p.kh_padding = kh_len;
                p.kd_padding = kd_len;
                p.scales = scales;
                p.oc_blocks = jcp.is_depthwise ? g : ocb;
                p.post_ops_binary_rhs_arg_vec
                        = post_ops_binary_rhs_arg_vec.data();
                p.oc_l_off = g_oc;
                p.zp_compensation
                        = jcp.src_zero_point ? zp_compensation + g_oc : nullptr;
                p.zp_src_pad_str_compensation = jcp.src_zero_point
                        ? zp_src_comp_scratch + g_oc
                        : nullptr;
                p.src_zero_point = zp_src;
                p.dst_zero_point = zp_dst;
                p.dst_orig = dst;

                (*kernel_)(&p);
            }

            if (jcp.loop_order == loop_ngc)
                nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
                        oc_chunks, od_s, jcp.od, oh_s, jcp.oh);
            else if (jcp.loop_order == loop_cgn)
                nd_iterator_jump(start, end, occ, oc_chunks, g, nb_groups, n,
                        jcp.mb, od_s, jcp.od, oh_s, jcp.oh);
            else
                assert(!"unsupported loop order");
        }
    });

    return status::success;
}

using namespace data_type;
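/* Explicit template instantiations for the supported ISA / vector-register
 * combinations. */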
template struct jit_uni_x8s8s32x_deconvolution_fwd_t<avx2>;
template struct jit_uni_x8s8s32x_deconvolution_fwd_t<sse41>;
template struct jit_uni_x8s8s32x_deconv_fwd_kernel<avx2>;
template struct jit_uni_x8s8s32x_deconv_fwd_kernel<sse41>;
template struct _jit_uni_x8s8s32x_deconv_fwd_kernel<avx2, Xbyak::Ymm>;
template struct _jit_uni_x8s8s32x_deconv_fwd_kernel<avx2, Xbyak::Xmm>;
template struct _jit_uni_x8s8s32x_deconv_fwd_kernel<sse41, Xbyak::Xmm>;
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl