/*******************************************************************************
* Copyright 2020-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/dnnl_thread.hpp"
#include "common/nstl.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_primitive.hpp"
#include "cpu/zero_point_utils.hpp"

#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_uni_deconv_zp_pad_str_kernel.hpp"
#include "cpu/x64/jit_uni_x8s8s32x_deconvolution.hpp"

#define GET_OFF(field) offsetof(jit_deconv_call_s, field)
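// GET_OFF resolves the byte offset of a runtime argument inside
// jit_deconv_call_s, so the generated code can address per-call parameters
// relative to the ABI argument register (param1_), e.g.:
//   mov(reg_src_, ptr[param1_ + GET_OFF(src)]);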

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;
using namespace Xbyak;

using namespace nstl;

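// Weight block offset helper: grouped weights carry a leading groups
// dimension, so the group index is forwarded to blk_off() only when the
// primitive was created with groups.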
#define wht_blk_off(d, g, ...) \
    (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) \
                         : (d).blk_off(__VA_ARGS__))

template <cpu_isa_t isa>
status_t jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::init_conf(
        jit_conv_conf_t &jcp, const deconvolution_desc_t &cd,
        memory_desc_t &src_md, memory_desc_t &weights_md, memory_desc_t &dst_md,
        const bool with_bias, memory_desc_t &bias_md, primitive_attr_t &attr,
        int nthreads) {
    const memory_desc_wrapper src_d(&src_md);
    const memory_desc_wrapper dst_d(&dst_md);
    const memory_desc_wrapper weights_d(&weights_md);
    const memory_desc_wrapper bias_d(&bias_md);

    if (!(mayiuse(isa)
                && one_of(src_d.data_type(), data_type::u8, data_type::s8)
                && weights_d.data_type() == data_type::s8
                && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
                        data_type::s8, data_type::u8)))
        return status::unimplemented;

    jcp = zero<decltype(jcp)>();
    jcp.nthr = nthreads;

    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    jcp.signed_input = src_d.data_type() == data_type::s8;
    const int ndims = jcp.ndims = dst_d.ndims();
    const bool is_1d = ndims == 3;
    const bool is_2d = ndims == 4;
    const bool is_3d = ndims == 5;
    const bool is_avx2 = isa == avx2;

    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;
    jcp.id = is_3d ? src_d.dims()[2] : 1;
    jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups;
    jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups;
    jcp.is_depthwise = true && with_groups
            && utils::everyone_is(
                    1, jcp.ic_without_padding, jcp.oc_without_padding);
    jcp.ver = mayiuse(avx2_vnni) ? ver_vnni : ver_unused;

    /* TODO: future work, on hold until depthwise specialized kernel is
     * implemented. */
    if (jcp.is_depthwise && (jcp.signed_input || is_3d))
        return status::unimplemented;

    if (!zero_points_valid(&attr)) return status::unimplemented;
    jcp.src_zero_point = !attr.zero_points_.has_default_values(DNNL_ARG_SRC);
    jcp.dst_zero_point = !attr.zero_points_.has_default_values(DNNL_ARG_DST);
    jcp.zp_src_is_common = attr.zero_points_.common(DNNL_ARG_SRC);

    format_tag_t dat_tag = utils::pick(
            ndims - 3, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);

    if (src_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(src_md, dat_tag));
        jcp.src_tag = dat_tag;
    } else {
        jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
    }
    if (jcp.src_tag != dat_tag) return status::unimplemented;

    if (dst_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(dst_md, dat_tag));
        jcp.dst_tag = dat_tag;
    } else {
        jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
    }
    if (jcp.dst_tag != dat_tag) return status::unimplemented;

    auto set_or_check_wei_format = [&]() {
        using namespace format_tag;
        format_tag_t wei_tag;
        if (jcp.ic_block == 8 || jcp.ch_block == 8) {
            if (is_1d) {
                wei_tag = with_groups ? jcp.is_depthwise ? Goiw8g : gOIw2i8o4i
                                      : OIw2i8o4i;
            } else if (is_2d) {
                wei_tag = with_groups ? jcp.is_depthwise ? Goihw8g : gOIhw2i8o4i
                                      : OIhw2i8o4i;
            } else {
                wei_tag = with_groups ? gOIdhw2i8o4i : OIdhw2i8o4i;
            }
        } else {
            if (is_avx2) {
                assert(with_groups && jcp.ic_block == 4);
                wei_tag = is_3d ? gOIdhw4o4i : is_2d ? gOIhw4o4i : gOIw4o4i;
            } else {
                if (is_1d) {
                    wei_tag = with_groups ? jcp.is_depthwise ? Goiw4g : gOIw4o4i
                                          : OIw4o4i;
                } else if (is_2d) {
                    wei_tag = with_groups
                            ? jcp.is_depthwise ? Goihw4g : gOIhw4o4i
                            : OIhw4o4i;
                } else {
                    wei_tag = with_groups ? gOIdhw4o4i : OIdhw4o4i;
                }
            }
        }

        memory_desc_t want_wei_md = weights_md;
        memory_desc_init_by_tag(want_wei_md, wei_tag);
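        /* For s8 src * s8 weights the kernel shifts the src into the u8
         * domain (see prepare_output), which requires a compensation term
         * precomputed alongside the weights. On ISAs without VNNI the
         * weights are additionally pre-scaled by 0.5 at reorder time to
         * avoid saturating the s16 intermediate of vpmaddubsw; the factor
         * is recorded in scale_adjust and undone through wei_adj_scale when
         * the output is scaled. */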
        if (jcp.signed_input && !jcp.is_depthwise) {
            want_wei_md.extra.flags = 0
                    | memory_extra_flags::compensation_conv_s8s8
                    | memory_extra_flags::scale_adjust;
            want_wei_md.extra.compensation_mask = (1 << 0)
                    + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0);
            want_wei_md.extra.scale_adjust = (jcp.ver == ver_vnni) ? 1.f : 0.5f;
        }
        if (jcp.src_zero_point) set_zp_src_comp_flags(want_wei_md, with_groups);

        if (weights_md.format_kind == format_kind::any) {
            weights_md = want_wei_md;
            return true;
        }

        return weights_md == want_wei_md;
    };

    jcp.with_bias = with_bias;
    if (jcp.with_bias) {
        if (bias_d.format_kind() == format_kind::any)
            CHECK(memory_desc_init_by_tag(bias_md, format_tag::x));
    }

    jcp.prop_kind = cd.prop_kind;
    jcp.mb = src_d.dims()[0];
    jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.od = is_3d ? dst_d.dims()[2] : 1;
    jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2];
    jcp.ow = dst_d.dims()[ndims - 1];
    jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];
    jcp.f_pad = is_3d ? cd.padding[0][0] : 0;
    jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4];
    jcp.l_pad = cd.padding[0][ndims - 3];
    jcp.stride_d = is_3d ? cd.strides[0] : 1;
    jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4];
    jcp.stride_w = cd.strides[ndims - 3];

    if (jcp.is_depthwise) {
        jcp.ch_block = is_avx2 ? 8 : 4;
        jcp.oc_block = 1;
        jcp.ic_block = 1;
    } else {
        jcp.ch_block = 1;
        jcp.oc_block = is_avx2 ? 8 : 4;
        jcp.ic_block = is_avx2 ? 8 : 4;

        if (jcp.ngroups == 1) {
            jcp.oc = utils::rnd_up(jcp.oc_without_padding, jcp.oc_block);
            jcp.ic = utils::rnd_up(jcp.ic_without_padding, jcp.ic_block);
        } else if (jcp.ngroups != 1
                && ((jcp.ic % jcp.ic_block != 0)
                        || (jcp.oc % jcp.oc_block != 0))) {
            /* oneDNN does not pad channels for grouped convolution. If the
             * number of channels per group is not a multiple of 8 on avx2:
             * - use Xmm (blocks of 4) when it is a multiple of 4,
             * - otherwise return unimplemented. */
            jcp.oc_block = jcp.ic_block = 4;
        }
        if (jcp.ic % jcp.ic_block != 0 || jcp.oc % jcp.oc_block != 0)
            return status::unimplemented;
    }

    if (!set_or_check_wei_format()) return status::unimplemented;

    jcp.dilate_d = is_3d ? cd.dilates[0] : 0;
    jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4];
    jcp.dilate_w = cd.dilates[ndims - 3];

    if (!IMPLICATION(jcp.dilate_d, jcp.stride_d == 1)
            || !IMPLICATION(jcp.dilate_h, jcp.stride_h == 1)
            || !IMPLICATION(jcp.dilate_w, jcp.stride_w == 1))
        return status::unimplemented;

    const int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
    const int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
    const int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
    jcp.r_pad = calculate_end_padding(
            jcp.l_pad, jcp.iw, jcp.ow, jcp.stride_w, ext_kw);
    jcp.b_pad = calculate_end_padding(
            jcp.t_pad, jcp.ih, jcp.oh, jcp.stride_h, ext_kh);
    jcp.back_pad = calculate_end_padding(
            jcp.f_pad, jcp.id, jcp.od, jcp.stride_d, ext_kd);
    const bool kernel_outside_src = false || ext_kw <= jcp.l_pad
            || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad
            || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad;
    if (kernel_outside_src) return status::unimplemented;

    CHECK(attr.set_default_formats(&dst_md));
    if (!post_ops_ok(jcp, dst_d, attr)) return status::unimplemented;

    const auto &p = attr.post_ops_;
    const int eltwise_ind = p.find(primitive_kind::eltwise);
    jcp.with_eltwise = eltwise_ind != -1;
    if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise;

    const int binary_ind = p.find(primitive_kind::binary);
    jcp.with_binary = binary_ind != -1;

    const int sum_ind = p.find(primitive_kind::sum);
    jcp.with_sum = sum_ind != -1;

    const auto &oscales = attr.output_scales_;
    jcp.is_oc_scale = oscales.mask_ == 1 << 1;

    jcp.post_ops = p;

    // only common and per-oc-channel scales are supported
    const bool oscales_ok = one_of(oscales.mask_, 0, 1 << 1);
    if (!oscales_ok) return status::unimplemented;

    jcp.dst_dt = dst_d.data_type();
    jcp.bia_dt = jcp.with_bias ? bias_d.data_type() : data_type::undef;
    jcp.typesize_bia
            = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
    jcp.typesize_in = types::data_type_size(src_d.data_type());
    jcp.typesize_out = types::data_type_size(dst_d.data_type());

    jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block);
    jcp.nb_oc = jcp.oc / jcp.oc_block;
    jcp.nb_ic = jcp.ic / jcp.ic_block;

    /* kernel blocking params */
    const int regs = (jcp.ver == ver_vnni ? 14 : 12);
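    // Register budget for blocking: out of 16 vmms, the VNNI path (a plain
    // vpdpbusd) leaves 14 for output accumulators and broadcast inputs,
    // while the emulated int8 path reserves extra helpers such as vmm_tmp_
    // and vmm_one_ in compute(), leaving 12.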

    jcp.nb_ch_blocking = 1;
    jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc);
    for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--)
        if (jcp.nb_oc % jcp.nb_oc_blocking == 0
                && jcp.l_pad <= regs / (jcp.nb_oc_blocking + 1))
            break;

    jcp.ur_w = regs / (jcp.nb_oc_blocking + 1);
    const int l_overflow = max(
            0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w);

    if (jcp.ow < jcp.ur_w) {
        jcp.ur_w = jcp.ow;
        jcp.ur_w_tail = 0;
    } else {
        for (; jcp.ur_w >= 1; jcp.ur_w--) {
            /* ur_w should be a multiple of stride_w in order
               to simplify the logic of get_ow_start and get_ow_end */
            const bool is_multiple_of_stride = jcp.ur_w % jcp.stride_w == 0;

            /* boundary conditions:
               these conditions ensure that all elements close to a boundary
               are computed in a single call of the compute loop */
            const bool left_boundary_covered
                    = jcp.ur_w >= l_overflow * jcp.stride_w;
            jcp.ur_w_tail = jcp.ow % jcp.ur_w;
            const int r_overflow_no_tail = max(0,
                    ((jcp.kw - 1) * (jcp.dilate_w + 1) - max(0, jcp.r_pad)
                            - jcp.ur_w_tail)
                            / jcp.stride_w);
            const bool right_boundary_covered
                    = jcp.ur_w >= r_overflow_no_tail * jcp.stride_w;

            if (is_multiple_of_stride && left_boundary_covered
                    && right_boundary_covered)
                break;
            else if (jcp.ur_w == 1)
                /* The boundary conditions above also keep the calls to
                   icb_loop simple: if they were not satisfied, special
                   cases would be needed to pick the correct
                   l_overflow/r_overflow values when different iterations
                   of the compute loop work on locations close to a
                   boundary. To keep the code simple, return unimplemented
                   in the extreme case where a good ur_w cannot be found. */
                return status::unimplemented;
        }
    }

    jcp.wei_adj_scale
            = (weights_d.extra().flags & memory_extra_flags::scale_adjust)
            ? weights_d.extra().scale_adjust
            : 1.f;

    jcp.loop_order = jcp.ngroups > 1 ? loop_ngc : loop_cgn;
    return status::success;
}

template <cpu_isa_t isa>
jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::jit_uni_x8s8s32x_deconv_fwd_kernel(
        const jit_conv_conf_t &ajcp, const primitive_attr_t &attr,
        const memory_desc_wrapper &dst_d)
    : kernel_(nullptr) {

    const int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.ic_block;
    switch (ch_block) {
        case 8:
            if (isa == avx2) {
                kernel_ = utils::make_unique<
                        _jit_avx2_x8s8s32x_deconv_fwd_kernel>(
                        ajcp, attr, dst_d);
                return;
            } else
                assert(!"invalid channel blocking for current ISA");
        case 4:
            kernel_ = utils::make_unique<
                    _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Xbyak::Xmm>>(
                    ajcp, attr, dst_d);
            return;
        default: assert(!"invalid channel blocking");
    }
}

template <cpu_isa_t isa>
jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::~jit_uni_x8s8s32x_deconv_fwd_kernel()
        = default;

template <cpu_isa_t isa>
void jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
        const primitive_attr_t &attr) {
    if (jcp.signed_input && jcp.ver != ver_vnni) {
        dim_t count = nstl::max<dim_t>(attr.output_scales_.count_, 8);
        scratchpad.book<float>(key_conv_adjusted_scales, count);
    }

    if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp)) {
        const dim_t zp_pad_comp_size = jcp.oc_without_padding * jcp.ngroups
                * jcp.kd * jcp.kh * jcp.kw;
        scratchpad.book<int32_t>(key_deconv_zp, zp_pad_comp_size);
    }
}

template <cpu_isa_t isa>
bool jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::post_ops_ok(jit_conv_conf_t &jcp,
        const memory_desc_wrapper &dst_d, const primitive_attr_t &attr) {
    using namespace injector;

    return injector::post_ops_ok(post_ops_ok_args_t(isa, {sum, eltwise, binary},
            attr.post_ops_, &dst_d, false /*sum_at_pos_0_only*/,
            false /*sum_requires_scale_one*/, false /*sum_requires_zp_zero*/,
            {broadcasting_strategy_t::per_oc,
                    broadcasting_strategy_t::scalar}));
}

template <cpu_isa_t isa, typename Vmm>
_jit_uni_x8s8s32x_deconv_fwd_kernel<isa,
        Vmm>::_jit_uni_x8s8s32x_deconv_fwd_kernel(const jit_conv_conf_t &ajcp,
        const primitive_attr_t &attr, const memory_desc_wrapper &dst_d)
    : jit_generator(nullptr, MAX_CODE_SIZE, true, isa)
    , jcp_(ajcp)
    , postops_injector_(nullptr) {

    if (jcp_.with_eltwise || jcp_.with_binary || jcp_.with_sum) {
        const std::size_t tail_size = get_tail_size();

        static constexpr bool preserve_gpr = true;
        static constexpr bool preserve_vmm = true;
        static constexpr bool use_exact_tail_scalar_bcast = false;
        static constexpr size_t vmm_helper_idx = 15;

        const binary_injector::rhs_arg_static_params_t rhs_sp {vmm_helper_idx,
                this->r14, this->r15, preserve_gpr, preserve_vmm,
                GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), dst_d,
                tail_size, Xbyak::Opmask(2), use_exact_tail_scalar_bcast};
        const binary_injector::static_params_t bsp {this->param1_, rhs_sp};

        postops_injector_ = utils::make_unique<
                injector::jit_uni_postops_injector_t<isa, Vmm>>(
                this, jcp_.post_ops, bsp);
    }
}

template <cpu_isa_t isa, typename Vmm>
_jit_uni_x8s8s32x_deconv_fwd_kernel<isa,
        Vmm>::~_jit_uni_x8s8s32x_deconv_fwd_kernel()
        = default;

template <cpu_isa_t isa, typename Vmm>
Vmm _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::vmm_out(
        int i_ur, int i_oc) const {
    const int idx = i_ur * jcp_.nb_oc_blocking + i_oc;
    assert(idx < KER_MAX_REG_IDX);
    /* remap the reg indices to avoid using xmm0 in eltwise injector */
    return Vmm(15 - idx);
}

template <cpu_isa_t isa, typename Vmm>
Vmm _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::vmm_inp(
        int i_ic, int nb_x_blocking) const {
    const int idx = i_ic + nb_x_blocking * jcp_.ur_w;
    assert(idx < KER_MAX_REG_IDX);
    return Vmm(15 - idx);
}

template <cpu_isa_t isa, typename Vmm>
Vmm _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::vmm_bias_alpha() const {
    return Vmm(15 - jcp_.nb_oc_blocking * jcp_.ur_w);
}

template <cpu_isa_t isa, typename Vmm>
Xmm _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::xmm_bias_alpha() const {
    return Xmm(vmm_bias_alpha().getIdx());
}

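// get_ow_start/get_ow_end map a filter tap ki to the sub-range of the
// current ur_w output block it actually contributes to. Deconvolution is
// computed as a transposed convolution, so a tap only hits output columns
// where (ow + l_pad - ki * (dilate_w + 1)) is divisible by stride_w; the
// modular arithmetic below finds the first (resp. one-past-last) such
// column, with l_overflow/r_overflow accounting for taps that fall into
// the left/right padding.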
template <cpu_isa_t isa, typename Vmm>
int _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::get_ow_start(
        int ki, int l_overflow) const noexcept {
    int res = (jcp_.ow - 1 + jcp_.r_pad) % jcp_.stride_w
            + l_overflow * jcp_.stride_w
            - (jcp_.kw - 1 - ki) * (jcp_.dilate_w + 1);
    while (res < 0)
        res += jcp_.stride_w;
    return res;
}

template <cpu_isa_t isa, typename Vmm>
int _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::get_ow_end(
        int ur_w, int ki, int r_overflow) const noexcept {
    if (utils::one_of(ur_w, jcp_.ow, jcp_.ur_w_tail))
        ur_w += nstl::min(0, jcp_.r_pad); // remove negative padding
    int res = (ur_w - 1 + jcp_.l_pad) % jcp_.stride_w
            + r_overflow * jcp_.stride_w - ki * (jcp_.dilate_w + 1);
    while (res < 0)
        res += jcp_.stride_w;
    return ur_w - res;
}

template <cpu_isa_t isa, typename Vmm>
int _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::get_blocking_size() const
        noexcept {
    return jcp_.is_depthwise ? jcp_.ch_block : jcp_.oc_block;
}

template <cpu_isa_t isa, typename Vmm>
int _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::get_tail_size() const
        noexcept {
    return jcp_.is_depthwise ? jcp_.ngroups % jcp_.ch_block
                             : jcp_.oc_without_padding % jcp_.oc_block;
}

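// compute() performs one int8 multiply-accumulate step,
// vreg_acc (s32) += vreg_src (u8) * vreg_wei (s8), with three lowerings:
// - VNNI: a single fused vpdpbusd;
// - depthwise: weights were already sign-extended to s32 at load time, so
//   a plain 32-bit multiply-add suffices;
// - generic emulation: vpmaddubsw forms u8*s8 products summed in pairs to
//   s16 (this step can saturate, which is why non-VNNI weights are
//   pre-scaled by 0.5), then vpmaddwd against a vector of ones widens the
//   pairs to s32 before the final add.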
template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::compute(
        const Vmm &vreg_acc, const Vmm &vreg_wei, const Vmm &vreg_src) {

    if (jcp_.ver == ver_vnni) {
        vpdpbusd(vreg_acc, vreg_src, vreg_wei, Xbyak::VexEncoding);
    } else if (jcp_.is_depthwise) {
        uni_vmovups(vmm_tmp_, vreg_src);
        uni_vpmulld(vmm_tmp_, vmm_tmp_, vreg_wei);
        uni_vpaddd(vreg_acc, vreg_acc, vmm_tmp_);
    } else {
        uni_vpmaddubsw(vmm_tmp_, vreg_src, vreg_wei);
        uni_vpmaddwd(vmm_tmp_, vmm_tmp_, vmm_one_);
        uni_vpaddd(vreg_acc, vreg_acc, vmm_tmp_);
    }
}

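// The zero-point padding compensation below needs temporaries while every
// output accumulator must stay live. The generator hands out the vmm slots
// normally used for broadcast inputs in round-robin fashion, since those
// slots are free at the point the compensation is applied.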
template <cpu_isa_t isa, typename Vmm>
std::function<Vmm()> _jit_uni_x8s8s32x_deconv_fwd_kernel<isa,
        Vmm>::prepare_round_robin_vmm_inp_generator(int ur_w) const noexcept {

    const int start_vmm_idx = vmm_inp(ur_w - 1, jcp_.nb_oc_blocking).getIdx();
    const int end_vmm_idx = vmm_inp(0, jcp_.nb_oc_blocking).getIdx() + 1;
    int current_vmm_idx = start_vmm_idx;

    return [=]() mutable {
        const Vmm vmm {static_cast<int>(current_vmm_idx++)};

        if (current_vmm_idx == end_vmm_idx) current_vmm_idx = start_vmm_idx;

        return vmm;
    };
}

template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::apply_zp_src_pad_str_comp(
        int ur_w, int l_overflow, int r_overflow, bool h_padded) {
    Xbyak::Label end_zp_pad, no_tail;

    // Applied once per icb loop. The zp src stride padding compensation is
    // computed as: zp_pad_str_compensation = conv(1, weights_s8) * zero_point_source
    cmp(reg_icb_, jcp_.nb_ic);
    jne(end_zp_pad, T_NEAR);

    if (jcp_.ngroups % jcp_.ch_block
            || jcp_.oc_without_padding % jcp_.oc_block) {
        if (jcp_.is_depthwise)
            cmp(reg_oc_blocks_, jcp_.nb_ch - 1);
        else
            cmp(reg_oc_blocks_, jcp_.nb_oc - jcp_.nb_oc_blocking);
        jne(no_tail, T_NEAR);

        append_zp_src_pad_str_comp(
                ur_w, l_overflow, r_overflow, h_padded, true /*last_oc_block*/);
        jmp(end_zp_pad, T_NEAR);
    }

    L(no_tail);
    append_zp_src_pad_str_comp(
            ur_w, l_overflow, r_overflow, h_padded, false /*last_oc_block*/);

    L(end_zp_pad);
}

template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::append_zp_src_pad_str_comp(
        int ur_w, int l_overflow, int r_overflow, bool h_padded,
        bool last_oc_block) {

    const auto &reg_zp_src_pad_comp = reg_scratch_;
    const auto get_next_comp_vmm = prepare_round_robin_vmm_inp_generator(ur_w);
    bool base_comp_addr_loaded = false;

    const auto load_base_zp_src_pad_comp_addr = [&]() {
        if (!base_comp_addr_loaded) {
            if (jcp_.ndims == 5) mov(reg_scratch_preserved_, reg_scratch_);

            if (jcp_.ndims > 3)
                mov(reg_zp_src_pad_comp, zp_src_pad_comp_addr_);
            else
                mov(reg_zp_src_pad_comp,
                        qword[param1_ + GET_OFF(zp_src_pad_str_compensation)]);

            base_comp_addr_loaded = true;
        }
    };

    const auto load_zp_src_pad_comp = [&](const Vmm &zp_pad_comp_vmm,
                                              const Xbyak::Address &comp_addr,
                                              const int ocb) {
        const bool is_tail = last_oc_block && ocb == jcp_.nb_oc_blocking - 1;

        if (is_tail)
            load_data(data_type::s32, zp_pad_comp_vmm, comp_addr,
                    get_tail_size());
        else
            uni_vmovups(zp_pad_comp_vmm, comp_addr);
    };

    const auto get_zp_src_comp_pad_off = [&](int it_kw, int ocb) {
        const auto kw_offset = it_kw * jcp_.oc_without_padding * jcp_.ngroups;
        const auto oc_offset = ocb * jcp_.oc_block;

        return (kw_offset + oc_offset) * sizeof(int32_t);
    };

    for (int it_kw = 0; it_kw < jcp_.kw; ++it_kw) {
        const int ow_start = get_ow_start(it_kw, l_overflow);
        const int ow_end = get_ow_end(ur_w, it_kw, r_overflow);

        for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
            Vmm zp_src_comp_pad_vmm; // will be assigned later
            bool ocb_zp_loaded = false;

            const auto zp_src_comp_pad_off
                    = get_zp_src_comp_pad_off(it_kw, ocb);

            for (int it_ow = 0; it_ow < ur_w; ++it_ow) {

                const bool inside_padded_area = h_padded
                        || !(it_ow >= ow_start && it_ow < ow_end
                                && ((it_ow + jcp_.l_pad - it_kw) % jcp_.stride_w
                                        == 0));

                if (inside_padded_area) {
                    load_base_zp_src_pad_comp_addr();

                    if (!ocb_zp_loaded) {
                        zp_src_comp_pad_vmm = get_next_comp_vmm();
                        const auto comp_addr = ptr[reg_zp_src_pad_comp
                                + zp_src_comp_pad_off];
                        load_zp_src_pad_comp(
                                zp_src_comp_pad_vmm, comp_addr, ocb);
                        ocb_zp_loaded = true;
                    }

                    const auto vmm_dst = vmm_out(it_ow, ocb);
                    uni_vpaddd(vmm_dst, vmm_dst, zp_src_comp_pad_vmm);
                }
            }
        }
    }

    if (jcp_.ndims > 3) {
        if (!base_comp_addr_loaded) load_base_zp_src_pad_comp_addr();

        const auto kh_offset = jcp_.kw * jcp_.oc_without_padding * jcp_.ngroups
                * sizeof(int32_t);

        add(reg_zp_src_pad_comp, kh_offset);
        mov(zp_src_pad_comp_addr_, reg_zp_src_pad_comp);
    }

    if (jcp_.ndims == 5 && base_comp_addr_loaded)
        mov(reg_scratch_, reg_scratch_preserved_);
}

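// compute_ker processes one filter row over a ur_w-wide block of outputs.
// When the input is s8 or carries a source zero point, every output
// position is visited with stride 1 (ur_w_stride = 1) so that positions
// mapping to padding still accumulate the shifted-input / zero-point
// compensation terms; otherwise positions are skipped by stride_w and
// padding contributes nothing.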
template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::compute_ker(int ur_w,
        int l_overflow, int r_overflow, ker_block_t last_ic_block_flag,
        bool h_padded) {

    const bool signed_input_or_src_zp
            = (jcp_.signed_input || jcp_.src_zero_point);

    const int ch_block_all = jcp_.ch_block * jcp_.ic_block * jcp_.oc_block;
    const int ur_w_stride = signed_input_or_src_zp ? 1 : jcp_.stride_w;

    const auto src_offset = [=](int oj, int icb, int ki) {
        return jcp_.typesize_in
                * (((oj + jcp_.l_pad - ki * (jcp_.dilate_w + 1))
                           / jcp_.stride_w)
                                * jcp_.ngroups * jcp_.ic_without_padding
                        + icb * 4);
    };

    const auto kernel_offset = [=](int ocb, int icb, int ki) {
        return jcp_.typesize_in
                * ((ocb * jcp_.nb_ic * jcp_.kd * jcp_.kh * jcp_.kw + ki)
                                * ch_block_all
                        + icb * jcp_.oc_block * IC_SUB_STEP);
    };

    for (int ki = 0; ki < jcp_.kw; ki++) {

        const int jj_start = get_ow_start(ki, l_overflow);
        const int jj_end = get_ow_end(ur_w, ki, r_overflow);

        const int _start = (signed_input_or_src_zp) ? 0 : jj_start;
        const int _end = (signed_input_or_src_zp) ? ur_w : jj_end;

        const int tail_size = jcp_.is_depthwise ? jcp_.ngroups % jcp_.ch_block
                                                : jcp_.ic_without_padding % 4;
        const int n_ic_blocks = jcp_.is_depthwise
                ? 1
                : (last_ic_block_flag != no_last_block ? div_up(
                           jcp_.ic_without_padding % jcp_.ic_block, 4)
                                                       : jcp_.ic_block / 4);

        for (int icb1 = 0; icb1 < n_ic_blocks; icb1++) {
            if (h_padded) {
                /* fill padded area with shifted values */
                if (jcp_.signed_input) {
                    const Vmm inp = vmm_inp(0, jcp_.nb_oc_blocking);
                    uni_vpxor(inp, inp, inp);
                    uni_vpsubb(inp, inp, vmm_shift_);
                }
            } else {

                for (int jj = _start; jj < _end; jj += ur_w_stride) {

                    const int aux_src_off = src_offset(jj, icb1, ki);
                    const auto vmm_src = vmm_inp(jj, jcp_.nb_oc_blocking);

                    if (jj >= jj_start && jj < jj_end
                            && ((jj + jcp_.l_pad - ki) % jcp_.stride_w == 0)) {
                        if (jcp_.is_depthwise) {
                            if (tail_size != 0)
                                assert(jcp_.nb_oc_blocking == 1);
                            uni_vpxor(vmm_src, vmm_src, vmm_src);
                            const bool mask_flag
                                    = last_ic_block_flag != no_last_block
                                    && tail_size;
                            load_data(data_type::u8, vmm_src, aux_reg_src_,
                                    aux_src_off,
                                    mask_flag ? tail_size : jcp_.ch_block);
                        } else if ((last_ic_block_flag == last_sp_block)
                                && tail_size != 0 && icb1 == n_ic_blocks - 1) {
                            const auto vmm_inp_tmp = Xmm(vmm_src.getIdx());
                            load_bytes(vmm_inp_tmp, aux_reg_src_, aux_src_off,
                                    tail_size);
                            uni_vpbroadcastd(vmm_src, vmm_inp_tmp);
                        } else {
                            uni_vpbroadcastd(
                                    vmm_src, ptr[aux_reg_src_ + aux_src_off]);
                        }
                        if (jcp_.signed_input)
                            uni_vpsubb(vmm_src, vmm_src, vmm_shift_);
                    } else {
                        /* fill padded area with shifted values */
                        if (jcp_.signed_input) {
                            uni_vpxor(vmm_src, vmm_src, vmm_src);
                            uni_vpsubb(vmm_src, vmm_src, vmm_shift_);
                        }
                    }
                }
            }
            for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
                const int aux_filt_off = kernel_offset(ocb, icb1, ki);

                if (_end - _start > 0) {
                    if (jcp_.is_depthwise) {
                        uni_vpmovsxbd(
                                vmm_wei_, ptr[aux_reg_filt_ + aux_filt_off]);
                    } else
                        uni_vmovups(
                                vmm_wei_, ptr[aux_reg_filt_ + aux_filt_off]);
                }

                for (int jj = _start; jj < _end; jj += ur_w_stride) {

                    const bool inside_padded_area = h_padded
                            || !(jj >= jj_start && jj < jj_end
                                    && ((jj + jcp_.l_pad - ki) % jcp_.stride_w
                                            == 0));
                    const auto vmm_dst = vmm_out(jj, ocb);
                    if (jcp_.signed_input || !inside_padded_area) {
                        const Vmm inp = vmm_inp(
                                h_padded ? 0 : jj, jcp_.nb_oc_blocking);
                        compute(vmm_dst, vmm_wei_, inp);
                    }
                }
            }
        }
    }

    if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp_))
        apply_zp_src_pad_str_comp(ur_w, l_overflow, r_overflow, h_padded);
}

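// kh_loop walks the filter rows (and, for 3D, the filter depth) that
// overlap the current output row. With signed input or a source zero
// point, filter rows that fall entirely into padding (the *_overflow
// counts passed from the driver) and the rows skipped by stride "holes"
// still have to be visited with h_padded = true, because the input shift
// / zero-point compensation must be accumulated for them as well.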
template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::kh_loop(int ur_w,
        int l_overflow, int r_overflow, ker_block_t last_ic_block_flag) {

    const bool signed_input_or_src_zp
            = (jcp_.signed_input || jcp_.src_zero_point);

    const int ch_block_all = jcp_.ch_block * jcp_.ic_block * jcp_.oc_block;
    const int shift_src_ih = jcp_.typesize_in * (jcp_.dilate_h + 1) * jcp_.iw
            * jcp_.ngroups * jcp_.ic_without_padding;
    const int shift_src_id = jcp_.typesize_in * (jcp_.dilate_d + 1) * jcp_.ih
            * jcp_.iw * jcp_.ngroups * jcp_.ic_without_padding;
    const int stride_h = signed_input_or_src_zp ? 1 : jcp_.stride_h;
    const int shift_filt_kh
            = jcp_.typesize_in * jcp_.kw * ch_block_all * stride_h;
    const int stride_d = signed_input_or_src_zp ? 1 : jcp_.stride_d;
    const int shift_filt_kd
            = jcp_.typesize_in * jcp_.kw * ch_block_all * jcp_.kh * stride_d;

    Label kd_loop_label, kh_loop_label, skip_kh_loop, skip_kd_loop;
    Label t_overflow_label, no_t_overflow_label, b_overflow_label,
            no_b_overflow_label;
    Label back_overflow_label, no_back_overflow_label, d_h_overflow_label,
            front_overflow_label, no_front_overflow_label, d_h_overflow_label2;
    if (jcp_.ndims == 5) {
        mov(aux_reg_filt_d_, reg_filt_);
        mov(aux_reg_src_d_, reg_src_);

        if (signed_input_or_src_zp) {
            mov(reg_ki_, ptr[param1_ + GET_OFF(back_overflow)]);
            cmp(reg_ki_, 0);
            je(no_back_overflow_label, T_NEAR);

            L(back_overflow_label);
            {
                mov(aux_reg_filt_, aux_reg_filt_d_);
                mov(reg_kh_, jcp_.kh);
                L(d_h_overflow_label);
                {
                    compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
                    add(aux_reg_filt_, shift_filt_kh);
                    dec(reg_kh_);
                    jnz(d_h_overflow_label);
                }

                add(aux_reg_filt_d_, shift_filt_kd);
                dec(reg_ki_);
                jnz(back_overflow_label);
            }
            L(no_back_overflow_label);
        }

        mov(reg_ki_, ptr[param1_ + GET_OFF(kd_padding)]);

        if ((signed_input_or_src_zp) || (jcp_.dilate_d >= jcp_.id)
                || ((!signed_input_or_src_zp)
                        && ((min(jcp_.f_pad, jcp_.back_pad) < 0)
                                || ((jcp_.kd - 1) * (jcp_.dilate_d + 1)
                                        < nstl::max(
                                                jcp_.f_pad, jcp_.back_pad))))) {
            cmp(reg_ki_, 0);
            je(skip_kd_loop, T_NEAR);
        }

        L(kd_loop_label);
        mov(aux_reg_src_, aux_reg_src_d_);
        mov(aux_reg_filt_, aux_reg_filt_d_);
    } else {
        mov(aux_reg_src_, reg_src_);
        mov(aux_reg_filt_, reg_filt_);
    }

    if (signed_input_or_src_zp && jcp_.ndims > 3) {
        /* Weights are transposed, so first compute 'bottom' padding. */
        mov(reg_overflow_, ptr[param1_ + GET_OFF(b_overflow)]);
        cmp(reg_overflow_, 0);
        je(no_b_overflow_label, T_NEAR);
        L(b_overflow_label);
        {
            compute_ker(ur_w, 0, 0, last_ic_block_flag, true);

            add(aux_reg_filt_, shift_filt_kh);
            dec(reg_overflow_);
            cmp(reg_overflow_, 0);
            jg(b_overflow_label, T_NEAR);
        }
        L(no_b_overflow_label);
    }

    mov(reg_kh_, ptr[param1_ + GET_OFF(kh_padding)]);

    if ((signed_input_or_src_zp) || (jcp_.dilate_h >= jcp_.ih)
            || ((!signed_input_or_src_zp)
                    && ((min(jcp_.t_pad, jcp_.b_pad) < 0)
                            || ((jcp_.kh - 1) * (jcp_.dilate_h + 1)
                                    < nstl::max(jcp_.t_pad, jcp_.b_pad))))) {
        cmp(reg_kh_, 0);
        je(skip_kh_loop, T_NEAR);
    }

    L(kh_loop_label);
    {
        compute_ker(ur_w, l_overflow, r_overflow, last_ic_block_flag, false);
        sub(aux_reg_src_, shift_src_ih);
        add(aux_reg_filt_, shift_filt_kh);
        dec(reg_kh_);

        /* Insert weight compensation in stride 'holes' */
        if (signed_input_or_src_zp && jcp_.stride_h > 1) {
            Label kh_comp_loop;

            cmp(reg_kh_, 0);
            je(skip_kh_loop, T_NEAR);
            mov(reg_comp_strides_, jcp_.stride_h - 1);
            L(kh_comp_loop);
            {
                compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
                add(aux_reg_filt_, shift_filt_kh);
                dec(reg_comp_strides_);
                cmp(reg_comp_strides_, 0);
                jg(kh_comp_loop, T_NEAR);
            }
        }
        cmp(reg_kh_, 0);
        jg(kh_loop_label, T_NEAR);
    }
    L(skip_kh_loop);
    if (signed_input_or_src_zp && jcp_.ndims > 3) {
        mov(reg_overflow_, ptr[param1_ + GET_OFF(t_overflow)]);
        cmp(reg_overflow_, 0);
        je(no_t_overflow_label, T_NEAR);
        L(t_overflow_label);
        {
            compute_ker(ur_w, 0, 0, last_ic_block_flag, true);

            add(aux_reg_filt_, shift_filt_kh);
            dec(reg_overflow_);
            cmp(reg_overflow_, 0);
            jg(t_overflow_label, T_NEAR);
        }
        L(no_t_overflow_label);
    }

    if (jcp_.ndims == 5) {
        sub(aux_reg_src_d_, shift_src_id);
        add(aux_reg_filt_d_, shift_filt_kd);
        dec(reg_ki_);

        /* Insert weight compensation in stride 'holes' */
        if (signed_input_or_src_zp && jcp_.stride_d > 1) {
            Label kd_comp_loop, kd_kh_comp_loop;
            cmp(reg_ki_, 0);
            jz(skip_kd_loop, T_NEAR);
            mov(reg_comp_strides_, jcp_.stride_d - 1);
            L(kd_comp_loop);
            mov(aux_reg_filt_, aux_reg_filt_d_);
            mov(reg_kh_, jcp_.kh);
            L(kd_kh_comp_loop);
            {
                compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
                add(aux_reg_filt_, shift_filt_kh);
                dec(reg_kh_);
                jnz(kd_kh_comp_loop, T_NEAR);
            }
            add(aux_reg_filt_d_, shift_filt_kd);
            dec(reg_comp_strides_);
            jnz(kd_comp_loop);
        }

        cmp(reg_ki_, 0);
        jg(kd_loop_label, T_NEAR);
        L(skip_kd_loop);
        if (signed_input_or_src_zp) {
            mov(reg_ki_, ptr[param1_ + GET_OFF(f_overflow)]);
            cmp(reg_ki_, 0);
            jz(no_front_overflow_label, T_NEAR);
            L(front_overflow_label);
            {
                mov(aux_reg_filt_, aux_reg_filt_d_);
                mov(reg_kh_, jcp_.kh);
                L(d_h_overflow_label2);
                {
                    compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
                    add(aux_reg_filt_, shift_filt_kh);
                    dec(reg_kh_);
                    jnz(d_h_overflow_label2);
                }
                add(aux_reg_filt_d_, shift_filt_kd);
                dec(reg_ki_);
                jnz(front_overflow_label);
            }
            L(no_front_overflow_label);
        }
    }
}

template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::prepare_output(int ur_w) {
    for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
        for (int ur = 0; ur < ur_w; ur++) {
            const Vmm vmm = vmm_out(ur, ocb);
            uni_vpxor(vmm, vmm, vmm);
        }
    }
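    /* For s8 input, broadcast 128 (0x80 per byte) into vmm_shift_. The
     * kernel subtracts it bytewise from each input (equivalent to adding
     * 128 mod 256), moving s8 values into the u8 range expected by
     * vpdpbusd/vpmaddubsw; the matching compensation precomputed alongside
     * the weights is applied later in store_output(). */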
    if (jcp_.signed_input) {
        const auto xmm_shift = Xbyak::Xmm(vmm_shift_.getIdx());
        mov(reg_scratch_, 0x80808080);
        uni_vmovq(xmm_shift, reg_scratch_);
        uni_vpbroadcastd(vmm_shift_, xmm_shift);
    }
}

template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::cvt2ps(data_type_t type_in,
        const Vmm &vmm_in, const Reg64 &reg, int offset, int load_size) {

    load_data(type_in, vmm_in, reg, offset, load_size);
    if (type_in != data_type::f32) uni_vcvtdq2ps(vmm_in, vmm_in);
}

template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::apply_postops(int ur_w,
        bool last_oc_block, const float *p_sum_scale, const int32_t *p_sum_zp) {
    const auto sum_injector = [=]() {
        if (p_sum_scale) { // post_op: sum
            for (int k = 0; k < jcp_.nb_oc_blocking; k++) {
                const bool mask_flag
                        = last_oc_block == 1 && k == jcp_.nb_oc_blocking - 1;
                for (int j = 0; j < ur_w; j++) {
                    const int aux_output_offset = jcp_.typesize_out
                            * (k * jcp_.oc_block
                                    + j * jcp_.oc_without_padding
                                            * jcp_.ngroups);
                    cvt2ps(jcp_.dst_dt, vmm_prev_dst_, reg_dst_,
                            aux_output_offset,
                            mask_flag ? get_tail_size() : get_blocking_size());
                    if (*p_sum_zp != 0) {
                        uni_vbroadcastss(vmm_sum_zp_, ptr[reg_ptr_sum_zp_]);
                        uni_vcvtdq2ps(vmm_sum_zp_, vmm_sum_zp_);
                        uni_vsubps(vmm_prev_dst_, vmm_prev_dst_, vmm_sum_zp_);
                    }
                    const Vmm vmm = vmm_out(j, k);
                    if (*p_sum_scale == 1.f)
                        uni_vaddps(vmm, vmm, vmm_prev_dst_);
                    else {
                        uni_vbroadcastss(vmm_tmp_, ptr[reg_ptr_sum_scale_]);
                        uni_vfmadd231ps(vmm, vmm_prev_dst_, vmm_tmp_);
                    }
                }
            }
        }
    };

    if (p_sum_scale)
        postops_injector_->set_lambda_injector(
                primitive_kind::sum, sum_injector);

    binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
    if (jcp_.with_binary) {
        for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
            const bool mask_flag
                    = last_oc_block && ocb == jcp_.nb_oc_blocking - 1;
            for (int ur = 0; ur < ur_w; ur++) {
                const int vmm_idx = vmm_out(ur, ocb).getIdx();
                const size_t aux_output_offset = jcp_.typesize_out
                        * (ocb * jcp_.oc_block
                                + ur * jcp_.oc_without_padding * jcp_.ngroups);

                rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_dst_);
                rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
                        vmm_idx, aux_output_offset);
                if (mask_flag) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
            }
        }
    }
    const int nb_oc_block
            = jcp_.is_depthwise ? jcp_.nb_ch_blocking : jcp_.nb_oc_blocking;
    postops_injector_->compute_vector_range(
            16 - nb_oc_block * ur_w, 16, rhs_arg_params);
}

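// store_output finalizes a ur_w x nb_oc_blocking tile of s32 accumulators:
// add the src zero-point term, convert to f32, apply the s8s8 compensation,
// bias and output scales, run the post-ops, add the dst zero point,
// saturate to the destination range, and store with optional tail masking.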
template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::store_output(
        int ur_w, bool last_oc_block) {
    mov(reg_bias_, ptr[param1_ + GET_OFF(bias)]);
    mov(reg_ptr_scales_, ptr[param1_ + GET_OFF(scales)]);

    if (jcp_.signed_input)
        mov(reg_compensation_, ptr[param1_ + GET_OFF(compensation)]);

    if (jcp_.src_zero_point) {
        mov(reg_zp_src_, ptr[param1_ + GET_OFF(src_zero_point)]);
        mov(reg_zp_compensation_, ptr[param1_ + GET_OFF(zp_compensation)]);
    }

    const auto &p = jcp_.post_ops;
    const int sum_idx = p.find(primitive_kind::sum);
    const float *p_sum_scale
            = (sum_idx != -1) ? &p.entry_[sum_idx].sum.scale : nullptr;
    const int32_t *p_sum_zp
            = (sum_idx != -1) ? &p.entry_[sum_idx].sum.zero_point : nullptr;

    if (jcp_.with_bias && jcp_.signed_input && jcp_.ver != ver_vnni) {
        mov(reg_bias_alpha_, float2int(jcp_.wei_adj_scale));
        uni_vmovq(xmm_bias_alpha(), reg_bias_alpha_);
        uni_vbroadcastss(vmm_bias_alpha(), xmm_bias_alpha());
    }

    if (jcp_.src_zero_point) {
        const auto &vmm_src_zp = vmm_tmp_;
        const auto &vmm_zp_comp = vmm_scale_;
        uni_vbroadcastss(vmm_src_zp, ptr[reg_zp_src_]);

        for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
            const int zp_offset = sizeof(int32_t) * ocb * jcp_.oc_block;
            const bool mask_flag
                    = last_oc_block && ocb == jcp_.nb_oc_blocking - 1;
            const int load_size
                    = mask_flag ? get_tail_size() : get_blocking_size();
            load_data(data_type::s32, vmm_zp_comp, reg_zp_compensation_,
                    zp_offset, load_size);
            uni_vpmulld(vmm_zp_comp, vmm_zp_comp, vmm_src_zp);

            for (int ur = 0; ur < ur_w; ur++) {
                const auto vmm_dst = vmm_out(ur, ocb);
                uni_vpaddd(vmm_dst, vmm_dst, vmm_zp_comp);
            }
        }
    }

    for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
        const bool mask_flag = last_oc_block && ocb == jcp_.nb_oc_blocking - 1;
        const int scale_offset
                = jcp_.is_oc_scale * (sizeof(float) * ocb * jcp_.oc_block);

        const auto vmm_bias_ = vmm_tmp_;
        if (jcp_.with_bias) {
            int bias_offset = jcp_.typesize_bia * ocb * jcp_.oc_block;
            cvt2ps(jcp_.bia_dt, vmm_bias_, reg_bias_, bias_offset,
                    mask_flag ? get_tail_size() : get_blocking_size());
            if (jcp_.signed_input && jcp_.ver != ver_vnni)
                uni_vmulps(vmm_bias_, vmm_bias_, vmm_bias_alpha());
        }
        if (jcp_.signed_input) {
            const int comp_offset = sizeof(int32_t) * ocb * jcp_.oc_block;
            cvt2ps(data_type::s32, vmm_comp_, reg_compensation_, comp_offset,
                    mask_flag ? get_tail_size() : get_blocking_size());
        }

        /* add to ymm_accum: compensation, bias */
        uni_vmovups(vmm_scale_, ptr[reg_ptr_scales_ + scale_offset]);
        for (int ur = 0; ur < ur_w; ur++) {
            const Vmm vmm = vmm_out(ur, ocb);
            uni_vcvtdq2ps(vmm, vmm);
            if (jcp_.signed_input) uni_vaddps(vmm, vmm, vmm_comp_);
            if (jcp_.with_bias) uni_vaddps(vmm, vmm, vmm_bias_);
            uni_vmulps(vmm, vmm, vmm_scale_);
        }
    }

    if (p_sum_scale && *p_sum_scale != 1.f)
        mov(reg_ptr_sum_scale_, reinterpret_cast<size_t>(p_sum_scale));
    if (p_sum_zp && *p_sum_zp != 0) {
        mov(reg_ptr_sum_zp_, reinterpret_cast<size_t>(p_sum_zp));
    }
    if (jcp_.with_eltwise || jcp_.with_binary || jcp_.with_sum)
        apply_postops(ur_w, last_oc_block, p_sum_scale, p_sum_zp);
    if (jcp_.dst_zero_point) {
        mov(reg_zp_dst_, ptr[param1_ + GET_OFF(dst_zero_point)]);
        const auto &vmm_zp_dst = vmm_tmp_;
        uni_vbroadcastss(vmm_zp_dst, ptr[reg_zp_dst_]);
        uni_vcvtdq2ps(vmm_zp_dst, vmm_zp_dst);

        for_(int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++)
        for (int ur = 0; ur < ur_w; ur++) {
            const auto vmm_dst = vmm_out(ur, ocb);
            uni_vaddps(vmm_dst, vmm_dst, vmm_zp_dst);
        }
    }

    // Properly saturate the accumulators for integer datatypes

    // No need to saturate on lower bound for signed integer types, as
    // the conversion to int would return INT_MIN, and then proper
    // saturation will happen when storing data
    if (jcp_.dst_dt == data_type::u8) {
        uni_vpxor(vmm_zero_, vmm_zero_, vmm_zero_);
        for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
            for (int ur = 0; ur < ur_w; ur++) {
                const Vmm vmm = vmm_out(ur, ocb);
                uni_vmaxps(vmm, vmm, vmm_zero_);
            }
        }
    }

    if (one_of(jcp_.dst_dt, data_type::u8, data_type::s8, data_type::s32)) {
        float saturation_ubound = types::max_value<float>(jcp_.dst_dt);
        const Xmm xmm_saturation(vmm_saturation_.getIdx());
        mov(reg_ptr_saturation_ubound_, float2int(saturation_ubound));
        uni_vmovq(xmm_saturation, reg_ptr_saturation_ubound_);
        uni_vbroadcastss(vmm_saturation_, xmm_saturation);

        for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
            for (int ur = 0; ur < ur_w; ur++) {
                const Vmm vmm = vmm_out(ur, ocb);
                uni_vminps(vmm, vmm, vmm_saturation_);
            }
        }
    }

    if (one_of(jcp_.dst_dt, data_type::u8, data_type::s8, data_type::s32)) {
        for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
            for (int ur = 0; ur < ur_w; ur++) {
                const Vmm vmm = vmm_out(ur, ocb);
                uni_vcvtps2dq(vmm, vmm);
            }
        }
    }

    /* write out register to output_addr */
    for (int ocb = 0; ocb < jcp_.nb_oc_blocking; ocb++) {
        const bool mask_flag = last_oc_block && ocb == jcp_.nb_oc_blocking - 1;
        for (int ur = 0; ur < ur_w; ur++) {
            const int aux_dst_off = jcp_.typesize_out
                    * (ur * jcp_.ngroups * jcp_.oc_without_padding
                            + ocb * jcp_.oc_block);
            const Vmm r_vmm = vmm_out(ur, ocb);
            store_data(jcp_.dst_dt, r_vmm, reg_dst_, aux_dst_off,
                    mask_flag ? get_tail_size() : get_blocking_size());
        }
    }
}

template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::icb_loop(
        int ur_w, int l_overflow, int r_overflow, bool is_last_sp_block) {

    const int shift_src_icb = jcp_.typesize_in * jcp_.ic_block;
    const size_t shift_filt_icb = (size_t)jcp_.typesize_in * jcp_.kd * jcp_.kh
            * jcp_.kw * jcp_.ic_block * jcp_.oc_block;

    prepare_output(ur_w);

    Label skip_icb_loop, icb_loop_label;

    mov(reg_icb_, jcp_.nb_ic);
    mov(reg_oc_blocks_, ptr[param1_ + GET_OFF(oc_blocks)]);

    if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp_)
            && jcp_.ndims > 3) {
        mov(reg_scratch_,
                qword[param1_ + GET_OFF(zp_src_pad_str_compensation)]);
        mov(zp_src_pad_comp_addr_, reg_scratch_);
    }

    L(icb_loop_label);
    {
        if (jcp_.ngroups % jcp_.ch_block != 0
                || jcp_.ic_without_padding != jcp_.ic) {
            Label common_ker, end_ker;
            if (jcp_.is_depthwise) {
                cmp(reg_oc_blocks_, jcp_.nb_ch - 1);
                jne(common_ker, T_NEAR);
            } else {
                cmp(reg_icb_, 1);
                jg(common_ker, T_NEAR);
            }

            kh_loop(ur_w, l_overflow, r_overflow, last_sp_block);
            jmp(end_ker, T_NEAR);

            L(common_ker);
            kh_loop(ur_w, l_overflow, r_overflow, no_last_block);

            L(end_ker);
        } else {
            kh_loop(ur_w, l_overflow, r_overflow, no_last_block);
        }

        add(reg_src_, shift_src_icb);
        safe_add(reg_filt_, shift_filt_icb, reg_ker_long_offt_);
        dec(reg_icb_);
        cmp(reg_icb_, 0);
        jg(icb_loop_label, T_NEAR);
    }

    /* rewind the pointers */
    sub(reg_src_, jcp_.nb_ic * shift_src_icb);
    safe_sub(reg_filt_, jcp_.nb_ic * shift_filt_icb, reg_ker_long_offt_);
    L(skip_icb_loop);

    if (jcp_.ngroups % jcp_.ch_block != 0
            || jcp_.oc_without_padding != jcp_.oc) {
        Label common_store, end_store;
        if (jcp_.is_depthwise)
            cmp(reg_oc_blocks_, jcp_.nb_ch - 1);
        else
            cmp(reg_oc_blocks_, jcp_.nb_oc - jcp_.nb_oc_blocking);
        jne(common_store, T_NEAR);

        store_output(ur_w, true);
        jmp(end_store, T_NEAR);

        L(common_store);
        store_output(ur_w, false);

        L(end_store);

    } else {
        store_output(ur_w, false);
    }
}

template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_deconv_fwd_kernel<isa, Vmm>::generate() {
    preamble();

    if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp_))
        sub(rsp, reserved_stack_size_);

    const auto vmm_one_128 = Xbyak::Xmm(vmm_one_.getIdx());
    mov(reg_scratch_, 0x10001);
    uni_vmovq(vmm_one_128, reg_scratch_);
    uni_vpbroadcastd(vmm_one_, vmm_one_128);

    mov(reg_src_, ptr[param1_ + GET_OFF(src)]);
    mov(reg_filt_, ptr[param1_ + GET_OFF(filt)]);
    mov(reg_dst_, ptr[param1_ + GET_OFF(dst)]);

    const int dst_shift = jcp_.typesize_out * jcp_.ur_w * jcp_.ngroups
            * jcp_.oc_without_padding;
    const int src_shift = jcp_.typesize_in * (jcp_.ur_w / jcp_.stride_w)
            * jcp_.ngroups * jcp_.ic_without_padding;

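    // l_overflow / r_overflow count how many filter taps (in units of
    // stride_w steps) reach into the left / right padding for the first /
    // last ur_w block; r_overflow1 is the analogue for the block preceding
    // the ur_w_tail remainder. The ow range is then split into an optional
    // left-boundary block, nur_w interior blocks, an optional
    // right-boundary block, and the tail.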
    const int l_overflow = max(0,
            ((jcp_.kw - 1) * (jcp_.dilate_w + 1) - jcp_.l_pad) / jcp_.stride_w);
    const int r_overflow = max(0,
            ((jcp_.kw - 1) * (jcp_.dilate_w + 1) - max(0, jcp_.r_pad))
                    / jcp_.stride_w);

    const int r_overflow1 = nstl::max(0,
            ((jcp_.kw - 1) * (jcp_.dilate_w + 1) - nstl::max(0, jcp_.r_pad)
                    - jcp_.ur_w_tail)
                    / jcp_.stride_w);
    int nur_w = jcp_.ow / jcp_.ur_w;
    if (r_overflow1 > 0) nur_w--;

    if (jcp_.ur_w == jcp_.ow) {
        icb_loop(jcp_.ur_w, l_overflow, r_overflow, true);
    } else if (nur_w == 0) {
        icb_loop(jcp_.ur_w, l_overflow, r_overflow1, jcp_.ur_w_tail == 0);
        add(reg_src_, src_shift);
        add(reg_dst_, dst_shift);
        if (jcp_.ur_w_tail != 0) icb_loop(jcp_.ur_w_tail, 0, r_overflow, true);
    } else {
        xor_(reg_nur_w_, reg_nur_w_);
        if (l_overflow > 0) {
            icb_loop(jcp_.ur_w, l_overflow, 0, false);
            add(reg_src_, src_shift);
            add(reg_dst_, dst_shift);
            inc(reg_nur_w_);
        }
        if ((l_overflow <= 0 && nur_w > 0) || (l_overflow > 0 && nur_w > 1)) {
            Label ow_loop_label;
            L(ow_loop_label);
            {
                icb_loop(jcp_.ur_w, 0, 0, false);
                add(reg_src_, src_shift);
                add(reg_dst_, dst_shift);
                inc(reg_nur_w_);
                cmp(reg_nur_w_, nur_w);
                jl(ow_loop_label, T_NEAR);
            }
        }
        if (r_overflow1 > 0) {
            icb_loop(jcp_.ur_w, 0, r_overflow1, jcp_.ur_w_tail == 0);
            add(reg_src_, src_shift);
            add(reg_dst_, dst_shift);
        }
        if (jcp_.ur_w_tail != 0) {
            icb_loop(jcp_.ur_w_tail, 0, r_overflow, true);
        }
    }

    if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp_))
        add(rsp, reserved_stack_size_);

    postamble();

    if (jcp_.with_eltwise) postops_injector_->prepare_table();
}

template <cpu_isa_t isa>
jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::jit_uni_x8s8s32x_deconvolution_fwd_t(
        const pd_t *apd)
    : primitive_t(apd) {}

template <cpu_isa_t isa>
jit_uni_x8s8s32x_deconvolution_fwd_t<
        isa>::~jit_uni_x8s8s32x_deconvolution_fwd_t()
        = default;

template <cpu_isa_t isa>
status_t jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::pd_t::init(
        engine_t *engine) {
    using namespace data_type;
    using skip_mask_t = primitive_attr_t::skip_mask_t;
    const bool ok = true && is_fwd()
            && (desc()->alg_kind & alg_kind::deconvolution_direct)
            && utils::one_of(src_md(0)->data_type, s8, u8)
            && weights_md(0)->data_type == s8
            && IMPLICATION(with_bias(),
                    utils::one_of(weights_md(1)->data_type, f32, s32, s8, u8))
            && utils::one_of(dst_md(0)->data_type, f32, s32, s8, u8)
            && desc()->accum_data_type == s32
            && attr()->has_default_values(skip_mask_t::oscale
                    | skip_mask_t::post_ops | skip_mask_t::zero_points_runtime);
    if (!ok) return status::unimplemented;

    CHECK(jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::init_conf(jcp_, *desc(),
            src_md_, weights_md_, dst_md_, with_bias(), bias_md_, attr_,
            dnnl_get_max_threads()));

    auto scratchpad = scratchpad_registry().registrar();
    jit_uni_x8s8s32x_deconv_fwd_kernel<isa>::init_scratchpad(
            scratchpad, jcp_, *attr());

    return status::success;
}

template <cpu_isa_t isa>
status_t jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::init(engine_t *engine) {
    CHECK(safe_ptr_assign(kernel_,
            new jit_uni_x8s8s32x_deconv_fwd_kernel<isa>(pd()->jcp_,
                    *pd()->attr(), memory_desc_wrapper(pd()->dst_md()))));

    if (zp::should_calculate_deconv_zp_src_pad_str_comp(pd()->jcp_)) {
        CHECK(safe_ptr_assign(zp_src_pad_comp_kernel_,
                zp::create_deconv_zp_pad_str_comp_ker<isa>(pd()->jcp_)));
        const auto zp_kernel_status = zp_src_pad_comp_kernel_->create_kernel();
        if (zp_kernel_status != status::success) return zp_kernel_status;
    }

    return kernel_->create_kernel();
}

template <cpu_isa_t isa>
status_t jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::execute(
        const exec_ctx_t &ctx) const {
    const auto &_pd = pd();
    const auto &ndims = _pd->ndims();
    if (ndims == 3)
        return execute_forward_1d(ctx);
    else if (ndims == 4)
        return execute_forward_2d(ctx);
    else if (ndims == 5)
        return execute_forward_3d(ctx);
    else
        return status::unimplemented;
    return status::success;
}

template <cpu_isa_t isa>
status_t jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::execute_forward_1d(
        const exec_ctx_t &ctx) const {
    const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    const auto weights = CTX_IN_MEM(const int8_t *, DNNL_ARG_WEIGHTS);
    const auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    DEFINE_ZERO_POINTS_BUFFER(zp_src, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(zp_dst, DNNL_ARG_DST);

    const auto &jcp = pd()->jcp_;

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());

    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(jcp.post_ops, ctx);
    auto scratchpad = ctx.get_scratchpad_grantor();
    int32_t *zp_src_comp_scratch = scratchpad.get<int32_t>(key_deconv_zp);

    if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp))
        zp::compute_deconv_zp_pad_str_comp_ker(jcp, pd()->with_groups(),
                weights_d, weights, zp_src, zp_src_comp_scratch,
                zp_src_pad_comp_kernel_.get());

    const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
    const int nb_groups = jcp.nb_ch;

    const float *oscales = pd()->attr()->output_scales_.scales_;
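    // Without VNNI the weights were pre-scaled by wei_adj_scale (0.5) at
    // reorder time, so fold the inverse factor into a scratchpad copy of
    // the output scales before handing them to the kernel.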
1464 if (jcp.signed_input && jcp.ver != ver_vnni) {
1465 auto local_scales = ctx.get_scratchpad_grantor().template get<float>(
1466 key_conv_adjusted_scales);
1467 const size_t count = pd()->attr()->output_scales_.count_;
1468 const float factor = 1.f / pd()->jcp_.wei_adj_scale;
1469 if (count == 1) {
1470 utils::array_set(local_scales, oscales[0] * factor, 8);
1471 } else {
1472 for (size_t c = 0; c < count; c++)
1473 local_scales[c] = oscales[c] * factor;
1474 }
1475 oscales = local_scales;
1476 }
1477 const size_t offset = weights_d.size() - weights_d.additional_buffer_size();
1478 auto w = const_cast<int8_t *>(weights);
1479 int32_t *compensation = (jcp.signed_input)
1480 ? reinterpret_cast<int32_t *>(&w[offset])
1481 : nullptr;
1482 const int32_t *zp_compensation = jcp.src_zero_point
1483 ? get_src_zp_comp_from_wei(
1484 weights, weights_d, jcp.signed_input, jcp.ngroups, jcp.oc)
1485 : nullptr;
1486
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        const int work_amount = jcp.mb * nb_groups * oc_chunks;
        balance211(work_amount, nthr, ithr, start, end);

        auto p = jit_deconv_call_s();

        int n {0}, g {0}, occ {0};
        if (jcp.loop_order == loop_ngc)
            nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks);
        else if (jcp.loop_order == loop_cgn)
            nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb);
        else
            assert(!"unsupported loop order");
        while (start < end) {

            const int ocb = occ * jcp.nb_oc_blocking;
            const int g_oc
                    = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block;
            const int g_ic = g * jcp.ch_block * jcp.ic;

            p.dst = dst + dst_dt_size * dst_d.blk_off(n, g_oc);
            p.src = src + src_d.blk_off(n, g_ic);
            p.filt = weights + wht_blk_off(weights_d, g, ocb, 0);
            p.bias = jcp.with_bias
                    ? bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia)
                    : nullptr;
            p.compensation = (jcp.signed_input) ? compensation + g_oc : nullptr;
            p.scales = &oscales[jcp.is_oc_scale * g_oc];
            p.t_overflow = 0;
            p.b_overflow = 0;
            p.kh_padding = jcp.kh;
            p.oc_blocks = jcp.is_depthwise ? g : ocb;
            p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
            p.oc_l_off = g_oc;
            p.zp_compensation
                    = jcp.src_zero_point ? zp_compensation + g_oc : nullptr;
            p.zp_src_pad_str_compensation = zp_src_comp_scratch
                    ? zp_src_comp_scratch + g_oc
                    : nullptr;
            p.src_zero_point = zp_src;
            p.dst_zero_point = zp_dst;
            p.dst_orig = dst;

            (*kernel_)(&p);

            ++start;
            if (jcp.loop_order == loop_ngc)
                nd_iterator_step(n, jcp.mb, g, nb_groups, occ, oc_chunks);
            else if (jcp.loop_order == loop_cgn)
                nd_iterator_step(occ, oc_chunks, g, nb_groups, n, jcp.mb);
            else
                assert(!"unsupported loop order");
        }
    });
    return status::success;
}

template <cpu_isa_t isa>
status_t jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::execute_forward_2d(
        const exec_ctx_t &ctx) const {
    const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    const auto weights = CTX_IN_MEM(const int8_t *, DNNL_ARG_WEIGHTS);
    const auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    DEFINE_ZERO_POINTS_BUFFER(zp_src, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(zp_dst, DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());

    const auto &jcp = pd()->jcp_;
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(jcp.post_ops, ctx);

    auto scratchpad = ctx.get_scratchpad_grantor();
    int32_t *zp_src_comp_scratch = scratchpad.get<int32_t>(key_deconv_zp);

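    // Zero-point compensation for the padded/strided edges depends on the
    // runtime zp_src value, so it is recomputed here into the key_deconv_zp
    // scratchpad rather than being baked into the weights at reorder time.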
    if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp))
        zp::compute_deconv_zp_pad_str_comp_ker(jcp, pd()->with_groups(),
                weights_d, weights, zp_src, zp_src_comp_scratch,
                zp_src_pad_comp_kernel_.get());

    const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
    const int nb_groups = jcp.nb_ch;

    const size_t src_h_stride = src_d.blk_off(0, 0, 1);
    const size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
    const size_t wht_kh_stride = wht_blk_off(weights_d, 0, 0, 0, 1);

    const float *oscales = pd()->attr()->output_scales_.scales_;
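    // Without VNNI, s8s8 weights are pre-scaled at reorder time by
    // wei_adj_scale (typically 0.5f) to avoid overflow in the intermediate
    // products; fold the inverse factor back into the output scales here.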
    if (jcp.signed_input && jcp.ver != ver_vnni) {
        auto local_scales = ctx.get_scratchpad_grantor().template get<float>(
                key_conv_adjusted_scales);
        const size_t count = pd()->attr()->output_scales_.count_;
        const float factor = 1.f / pd()->jcp_.wei_adj_scale;
        if (count == 1) {
            utils::array_set(local_scales, oscales[0] * factor, 8);
        } else {
            for (size_t c = 0; c < count; c++)
                local_scales[c] = oscales[c] * factor;
        }
        oscales = local_scales;
    }
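    // For signed input, the reorder stores per-oc int32 compensation values in
    // a trailer past the blocked weights data; recover its offset as the
    // descriptor size minus the additional buffer size.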
    const size_t offset = weights_d.size() - weights_d.additional_buffer_size();
    auto w = const_cast<int8_t *>(weights);
    int32_t *compensation = (jcp.signed_input)
            ? reinterpret_cast<int32_t *>(&w[offset])
            : nullptr;
    const int32_t *zp_compensation = jcp.src_zero_point
            ? get_src_zp_comp_from_wei(
                    weights, weights_d, jcp.signed_input, jcp.ngroups, jcp.oc)
            : nullptr;

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        const int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh;
        balance211(work_amount, nthr, ithr, start, end);

        auto p = jit_deconv_call_s();

        /* loop order: ngc or cgn */
        int n {0}, g {0}, occ {0}, oh_s {0};
        if (jcp.loop_order == loop_ngc)
            nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
                    oh_s, jcp.oh);
        else if (jcp.loop_order == loop_cgn)
            nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb,
                    oh_s, jcp.oh);
        else
            assert(!"unsupported loop order");
        while (start < end) {

            const int ocb = occ * jcp.nb_oc_blocking;
            const int g_oc
                    = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block;
            const int g_ic = g * jcp.ch_block * jcp.ic;
            const int work_rem = end - start;
            const int oh_e
                    = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;

            const auto dst_w = dst + dst_dt_size * dst_d.blk_off(n, g_oc);
            const auto src_w = src + src_d.blk_off(n, g_ic);
            const auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0);
            const auto bias_w = jcp.with_bias
                    ? bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia)
                    : nullptr;
            const int32_t *compensation_w
                    = (jcp.signed_input) ? compensation + g_oc : nullptr;

            const auto scales = &oscales[jcp.is_oc_scale * g_oc];
            for (int oj = oh_s; oj < oh_e; oj++) {
                int ih_max = 0, kh_lo = 0, kh_len = 0;
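                // Trim the kernel-row window to the taps that land on valid
                // input rows for output row oj: kh_lo/kh_len give the window
                // [kh_lo, kh_lo + kh_len) and ih_max the corresponding
                // starting input row.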
                if (jcp.dilate_h != 0 && jcp.stride_h == 1) {
                    /* dilation */
                    const int dilate_h = jcp.dilate_h + 1;
                    // Note: use div_up to account for "holes" in filter
                    const int o_t_overflow = div_up(
                            max(0, (jcp.kh - 1) * dilate_h - oj - jcp.t_pad),
                            dilate_h);
                    const int o_b_overflow
                            = div_up(max(0,
                                             (jcp.kh - 1) * dilate_h + 1
                                                     - jcp.oh + oj - jcp.b_pad),
                                    dilate_h);
                    kh_len = jcp.kh - o_t_overflow - o_b_overflow;
                    kh_lo = o_b_overflow;
                    ih_max = oj + jcp.t_pad - o_b_overflow * dilate_h;
                } else {
                    const int o_t_overflow = max(
                            0, (jcp.kh - (oj + 1 + jcp.t_pad)) / jcp.stride_h);
                    const int o_b_overflow = max(0,
                            ((oj + jcp.kh) - (jcp.oh + jcp.b_pad))
                                    / jcp.stride_h);
                    const int overflow_kh_hi = jcp.kh - 1
                            - modulo(jcp.oh + jcp.b_pad - (oj + 1),
                                    jcp.stride_h);
                    const int overflow_kh_lo = (oj + jcp.t_pad) % jcp.stride_h;

                    kh_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h
                            + 1 - o_t_overflow - o_b_overflow;
                    kh_lo = overflow_kh_lo + o_b_overflow * jcp.stride_h;
                    ih_max = (oj + jcp.t_pad - kh_lo) / jcp.stride_h;
                }

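                // When compensation applies (signed input or src zero point)
                // the kernel walks the full kernel height and handles the
                // clipped rows itself, so the weights pointer is not advanced.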
                const int wei_stride
                        = (!jcp.signed_input && !jcp.src_zero_point)
                        ? kh_lo * wht_kh_stride
                        : 0;
                p.src = src_w + ih_max * src_h_stride;
                p.dst = dst_w + dst_dt_size * oj * dst_h_stride;
                p.filt = wht_w + wei_stride;
                p.bias = bias_w;
                p.compensation = compensation_w;
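                // t_overflow/b_overflow count the kernel rows clipped at the
                // top/bottom edges; the kernel presumably uses them on the
                // compensation path for rows that fell into padding.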
                p.t_overflow = jcp.dilate_h > 0
                        ? jcp.kh - kh_len - kh_lo
                        : max(0,
                                jcp.kh
                                        - (kh_lo
                                                + max(0, kh_len - 1)
                                                        * jcp.stride_h
                                                + 1));
                p.b_overflow = kh_lo;
                p.kh_padding = kh_len;
                p.scales = scales;
                p.oc_blocks = jcp.is_depthwise ? g : ocb;
                p.post_ops_binary_rhs_arg_vec
                        = post_ops_binary_rhs_arg_vec.data();
                p.oc_l_off = g_oc;
                p.zp_compensation
                        = jcp.src_zero_point ? zp_compensation + g_oc : nullptr;
                p.zp_src_pad_str_compensation = jcp.src_zero_point
                        ? zp_src_comp_scratch + g_oc
                        : nullptr;
                p.src_zero_point = zp_src;
                p.dst_zero_point = zp_dst;
                p.dst_orig = dst;

                (*kernel_)(&p);
            }
            if (jcp.loop_order == loop_ngc)
                nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
                        oc_chunks, oh_s, jcp.oh);
            else if (jcp.loop_order == loop_cgn)
                nd_iterator_jump(start, end, occ, oc_chunks, g, nb_groups, n,
                        jcp.mb, oh_s, jcp.oh);
            else
                assert(!"unsupported loop order");
        }
    });
    return status::success;
}

template <cpu_isa_t isa>
status_t jit_uni_x8s8s32x_deconvolution_fwd_t<isa>::execute_forward_3d(
        const exec_ctx_t &ctx) const {
    const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    const auto weights = CTX_IN_MEM(const int8_t *, DNNL_ARG_WEIGHTS);
    const auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    const auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    DEFINE_ZERO_POINTS_BUFFER(zp_src, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(zp_dst, DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());

    const auto &jcp = pd()->jcp_;
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(jcp.post_ops, ctx);

    const auto scratchpad = ctx.get_scratchpad_grantor();
    int32_t *const zp_src_comp_scratch = scratchpad.get<int32_t>(key_deconv_zp);

    if (zp::should_calculate_deconv_zp_src_pad_str_comp(jcp))
        zp::compute_deconv_zp_pad_str_comp_ker(jcp, pd()->with_groups(),
                weights_d, weights, zp_src, zp_src_comp_scratch,
                zp_src_pad_comp_kernel_.get());

    const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
    const int nb_groups = jcp.nb_ch;

    const size_t src_d_stride = src_d.blk_off(0, 0, 1);
    const size_t src_h_stride = src_d.blk_off(0, 0, 0, 1);
    const size_t dst_d_stride = dst_d.blk_off(0, 0, 1);
    const size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1);
    const size_t wht_kd_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
    const size_t wht_kh_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);

    const float *oscales = pd()->attr()->output_scales_.scales_;
    if (jcp.signed_input && jcp.ver != ver_vnni) {
        const auto local_scales
                = ctx.get_scratchpad_grantor().template get<float>(
                        key_conv_adjusted_scales);
        const size_t count = pd()->attr()->output_scales_.count_;
        const float factor = 1.f / pd()->jcp_.wei_adj_scale;
        if (count == 1) {
            utils::array_set(local_scales, oscales[0] * factor, 8);
        } else {
            for (size_t c = 0; c < count; c++)
                local_scales[c] = oscales[c] * factor;
        }
        oscales = local_scales;
    }
    const size_t offset = weights_d.size() - weights_d.additional_buffer_size();
    auto w = const_cast<int8_t *>(weights);
    int32_t *compensation = (jcp.signed_input)
            ? reinterpret_cast<int32_t *>(&w[offset])
            : nullptr;
    const int32_t *zp_compensation = jcp.src_zero_point
            ? get_src_zp_comp_from_wei(
                    weights, weights_d, jcp.signed_input, jcp.ngroups, jcp.oc)
            : nullptr;

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        const int work_amount
                = jcp.mb * nb_groups * oc_chunks * jcp.od * jcp.oh;
        balance211(work_amount, nthr, ithr, start, end);

        auto p = jit_deconv_call_s();

        /* loop order: ngc or cgn */
        int n {0}, g {0}, occ {0}, od_s {0}, oh_s {0};
        if (jcp.loop_order == loop_ngc)
            nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
                    od_s, jcp.od, oh_s, jcp.oh);
        else if (jcp.loop_order == loop_cgn)
            nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb,
                    od_s, jcp.od, oh_s, jcp.oh);
        else
            assert(!"unsupported loop order");
        while (start < end) {

            const int ocb = occ * jcp.nb_oc_blocking;
            const int g_oc
                    = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block;
            const int g_ic = g * jcp.ch_block * jcp.ic;
            const int work_rem = end - start;
            const int oh_e
                    = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
            int input_d_s = 0, kd_len = 0, kd_lo = 0;

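            // Depth analogue of the kh windowing in the oj loop below:
            // compute the kernel-depth window [kd_lo, kd_lo + kd_len) that
            // reaches output depth od_s and the matching starting input depth.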
            if (jcp.dilate_d != 0 && jcp.stride_d == 1) {
                /* dilation */
                const int dilate_d = jcp.dilate_d + 1;
                // Note: use div_up to account for "holes" in filter
                const int d_t_overflow = div_up(
                        max(0, (jcp.kd - 1) * dilate_d - od_s - jcp.f_pad),
                        dilate_d);
                const int d_back_overflow
                        = div_up(max(0,
                                         (jcp.kd - 1) * dilate_d + 1 - jcp.od
                                                 + od_s - jcp.back_pad),
                                dilate_d);
                kd_len = jcp.kd - d_t_overflow - d_back_overflow;
                kd_lo = d_back_overflow;
                input_d_s = od_s + jcp.f_pad - d_back_overflow * dilate_d;
            } else {
                const int d_t_overflow = max(
                        0, (jcp.kd - (od_s + 1 + jcp.f_pad)) / jcp.stride_d);
                const int d_back_overflow = max(0,
                        ((od_s + jcp.kd) - (jcp.od + jcp.back_pad))
                                / jcp.stride_d);
                const int overflow_kd_hi = jcp.kd - 1
                        - modulo(jcp.od + jcp.back_pad - (od_s + 1),
                                jcp.stride_d);
                const int overflow_kd_lo = (od_s + jcp.f_pad) % jcp.stride_d;

                kd_len = (overflow_kd_hi - overflow_kd_lo) / jcp.stride_d + 1
                        - d_t_overflow - d_back_overflow;
                kd_lo = overflow_kd_lo + d_back_overflow * jcp.stride_d;
                input_d_s = (od_s + jcp.f_pad - kd_lo) / jcp.stride_d;
            }

            auto dst_w = dst
                    + dst_dt_size
                            * (dst_d.blk_off(n, g_oc) + od_s * dst_d_stride);
            const auto src_w
                    = src + src_d.blk_off(n, g_ic) + input_d_s * src_d_stride;
            const auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0)
                    + ((jcp.signed_input || jcp.src_zero_point) ? 0 : kd_lo)
                            * wht_kd_stride;
            const auto bias_w = jcp.with_bias
                    ? bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia)
                    : nullptr;
            const int32_t *compensation_w
                    = (jcp.signed_input) ? compensation + g_oc : nullptr;

            const auto scales = &oscales[jcp.is_oc_scale * g_oc];

            for (int oj = oh_s; oj < oh_e; oj++) {
                int ih_max = 0, kh_lo = 0, kh_len = 0;
                if (jcp.dilate_h != 0 && jcp.stride_h == 1) {
                    /* dilation */
                    const int dilate_h = jcp.dilate_h + 1;
                    // Note: use div_up to account for "holes" in filter
                    const int o_t_overflow = div_up(
                            max(0, (jcp.kh - 1) * dilate_h - oj - jcp.t_pad),
                            dilate_h);
                    const int o_b_overflow
                            = div_up(max(0,
                                             (jcp.kh - 1) * dilate_h + 1
                                                     - jcp.oh + oj - jcp.b_pad),
                                    dilate_h);
                    kh_len = jcp.kh - o_t_overflow - o_b_overflow;
                    kh_lo = o_b_overflow;
                    ih_max = oj + jcp.t_pad - o_b_overflow * dilate_h;
                } else {
                    const int o_t_overflow = max(
                            0, (jcp.kh - (oj + 1 + jcp.t_pad)) / jcp.stride_h);
                    const int o_b_overflow = max(0,
                            ((oj + jcp.kh) - (jcp.oh + jcp.b_pad))
                                    / jcp.stride_h);
                    const int overflow_kh_hi = jcp.kh - 1
                            - modulo(jcp.oh + jcp.b_pad - (oj + 1),
                                    jcp.stride_h);
                    const int overflow_kh_lo = (oj + jcp.t_pad) % jcp.stride_h;

                    kh_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h
                            + 1 - o_t_overflow - o_b_overflow;
                    kh_lo = overflow_kh_lo + o_b_overflow * jcp.stride_h;
                    ih_max = (oj + jcp.t_pad - kh_lo) / jcp.stride_h;
                }

                const int wei_stride
                        = (!jcp.signed_input && !jcp.src_zero_point)
                        ? kh_lo * wht_kh_stride
                        : 0;
                p.src = src_w + ih_max * src_h_stride;
                p.dst = dst_w + dst_dt_size * oj * dst_h_stride;
                p.filt = wht_w + wei_stride;
                p.bias = bias_w;
                p.compensation = compensation_w;
                /* Note: Currently this kernel doesn't support dilations and
                   strides together */
                p.t_overflow = jcp.dilate_h > 0
                        ? jcp.kh - kh_len - kh_lo
                        : max(0,
                                jcp.kh
                                        - (kh_lo
                                                + max(0, kh_len - 1)
                                                        * jcp.stride_h
                                                + 1));
                p.b_overflow = kh_lo;
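                // f_overflow/back_overflow mirror t_overflow/b_overflow along
                // depth: kernel taps clipped at the front/back of the volume.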
                p.f_overflow = jcp.dilate_d > 0
                        ? jcp.kd - kd_len - kd_lo
                        : max(0,
                                jcp.kd
                                        - (kd_lo
                                                + max(0, kd_len - 1)
                                                        * jcp.stride_d
                                                + 1));
                p.back_overflow = kd_lo;
                p.kh_padding = kh_len;
                p.kd_padding = kd_len;
                p.scales = scales;
                p.oc_blocks = jcp.is_depthwise ? g : ocb;
                p.post_ops_binary_rhs_arg_vec
                        = post_ops_binary_rhs_arg_vec.data();
                p.oc_l_off = g_oc;
                p.zp_compensation
                        = jcp.src_zero_point ? zp_compensation + g_oc : nullptr;
                p.zp_src_pad_str_compensation = jcp.src_zero_point
                        ? zp_src_comp_scratch + g_oc
                        : nullptr;
                p.src_zero_point = zp_src;
                p.dst_zero_point = zp_dst;
                p.dst_orig = dst;

                (*kernel_)(&p);
            }

            if (jcp.loop_order == loop_ngc)
                nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
                        oc_chunks, od_s, jcp.od, oh_s, jcp.oh);
            else if (jcp.loop_order == loop_cgn)
                nd_iterator_jump(start, end, occ, oc_chunks, g, nb_groups, n,
                        jcp.mb, od_s, jcp.od, oh_s, jcp.oh);
            else
                assert(!"unsupported loop order");
        }
    });

    return status::success;
}

using namespace data_type;
template struct jit_uni_x8s8s32x_deconvolution_fwd_t<avx2>;
template struct jit_uni_x8s8s32x_deconvolution_fwd_t<sse41>;
template struct jit_uni_x8s8s32x_deconv_fwd_kernel<avx2>;
template struct jit_uni_x8s8s32x_deconv_fwd_kernel<sse41>;
template struct _jit_uni_x8s8s32x_deconv_fwd_kernel<avx2, Xbyak::Ymm>;
template struct _jit_uni_x8s8s32x_deconv_fwd_kernel<avx2, Xbyak::Xmm>;
template struct _jit_uni_x8s8s32x_deconv_fwd_kernel<sse41, Xbyak::Xmm>;
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl