1 /*******************************************************************************
2 * Copyright 2018-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef CPU_X64_JIT_AVX512_CORE_X8S8S32X_DECONVOLUTION_HPP
18 #define CPU_X64_JIT_AVX512_CORE_X8S8S32X_DECONVOLUTION_HPP
19 
#include <cassert>
#include <memory>
#include <vector>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory.hpp"
#include "common/nstl.hpp"
#include "common/primitive.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_deconvolution_pd.hpp"
#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/jit_primitive_conf.hpp"
33 
34 namespace dnnl {
35 namespace impl {
36 namespace cpu {
37 namespace x64 {
38 
// Kind of block the kernel is currently emitting code for. The enumerators
// are distinct powers of two so they can be tested as individual bits; they
// are passed to compute_ker() / kh_loop() (presumably to mark the last
// input-channel block and/or the last spatial block — confirm in the .cpp).
enum ker_block_t {
    no_last_block = 1u << 0,
    last_ic_block = 1u << 1,
    last_sp_block = 1u << 2,
};
44 
// Description of the sequence of ur_w blocks the kernel iterates over (built
// by get_ur_w_blks_params()): per-block overflow parameters plus the number of
// leading / trailing blocks that need special handling.
struct ur_w_blks_params_t {
    struct single_ur_w_blk_params_t {
        single_ur_w_blk_params_t(
                int l_overflow, int r_overflow, bool process_sp_carefully)
            : l_overflow(l_overflow)
            , r_overflow(r_overflow)
            , process_sp_carefully(process_sp_carefully) {}

        // l_overflow - no. of spatial elements of weights standing out of
        // src spatial when computing the 1st output pixel in the current blk
        int l_overflow;
        // r_overflow - no. of spatial elements of weights standing out of
        // src spatial when computing the last output pixel in the current blk
        int r_overflow;
        // process_sp_carefully - indicates if loading the last src sp
        // for computation of the last dst sp of the block can't be done
        // by fetching 4 src sp at once
        bool process_sp_carefully;
    };
    std::vector<single_ur_w_blk_params_t> blks_params;
    // Counters default to 0 so a default-constructed instance never exposes
    // indeterminate values before the builder fills them in.
    // num of blocks with l_overflow>0
    int num_pre_blks = 0;
    // num of blocks with r_overflow>0 or that need to be processed carefully
    int num_post_blks = 0;
};
69 
// JIT code generator for the int8 ("x8s8s32x") forward deconvolution kernel
// targeting avx512_core. `Vmm` is the Xbyak vector register type used for the
// computation (Xbyak::Zmm, Xbyak::Ymm or Xbyak::Xmm).
// Constructor, destructor and the code-generation stages are defined out of
// line (not visible in this header).
template <typename Vmm>
struct jit_avx512_core_x8s8s32x_deconv_fwd_kernel : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_x8s8s32x_deconv_fwd_ker_t);

    jit_avx512_core_x8s8s32x_deconv_fwd_kernel(const jit_conv_conf_t &ajcp,
            const primitive_attr_t &attr, const memory_desc_t &dst_md);
    ~jit_avx512_core_x8s8s32x_deconv_fwd_kernel();

    // Stored by reference: the referenced objects must outlive the kernel.
    const jit_conv_conf_t &jcp;
    const primitive_attr_t &attr_;

private:
    std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core>>
            postops_injector_;

    // NOTE(review): presumably 4 because int8 dot-product instructions reduce
    // groups of 4 input channels per step — confirm in the .cpp.
    const int ic_sub_step = 4;

    // General-purpose register assignment. Several names intentionally alias
    // the same physical register (rax, rdx, r11, r14, r15, abi_not_param1);
    // presumably they are live in disjoint phases of the generated code —
    // verify before introducing new uses of any alias.
    /* data regs */
    const Xbyak::Reg64 reg_src = r8;
    const Xbyak::Reg64 reg_filt = r9;
    const Xbyak::Reg64 reg_dst = r10;
    const Xbyak::Reg64 param1 = abi_param1;
    const Xbyak::Reg64 reg_kh = abi_not_param1;
    const Xbyak::Reg64 reg_ki = r14;

    const Xbyak::Reg64 reg_nur_w = rbx;
    const Xbyak::Reg64 reg_bias = rdx;
    const Xbyak::Reg64 reg_icb = reg_bias;
    const Xbyak::Reg64 reg_ptr_scales = rax;
    const Xbyak::Reg64 reg_ptr_saturation_ubound = rax;
    const Xbyak::Reg64 reg_oc_blocks = rsi;

    const Xbyak::Reg64 aux_reg_src = r11;
    const Xbyak::Reg64 aux_reg_filt = r12;

    const Xbyak::Reg64 aux_reg_src_d = r13;
    const Xbyak::Reg64 aux_reg_filt_d = r15;

    const Xbyak::Reg64 reg_compensation = r14;
    const Xbyak::Reg64 reg_scratch = r14;
    const Xbyak::Reg64 reg_ptr_sum_scale = r11;
    const Xbyak::Reg64 reg_bias_alpha = abi_not_param1;
    const Xbyak::Reg64 reg_overflow = rax;
    const Xbyak::Reg64 reg_comp_strides = reg_overflow;
    const Xbyak::Reg64 reg_ker_long_offt = r15;

    // Opmask register k2 (presumably the channel-tail mask).
    Xbyak::Opmask ktail_mask = Xbyak::Opmask(2);
    // Vector registers 28..31 are reserved for the helpers below; 30 and 31
    // are each shared by several names.
    const Vmm vmm_tmp = Vmm(28);
    const Vmm vmm_one = Vmm(29);
    /* used during write-out section of store_output */
    const Vmm vmm_zero = Vmm(31);
    const Vmm vmm_saturation = Vmm(31);
    const Vmm vmm_wei = Vmm(31);

    /* signed input */
    const Vmm vmm_shift = Vmm(30);
    const Vmm vmm_comp = Vmm(30);
    const Vmm vmm_bias = Vmm(31);
    const Vmm vmm_prev_dst = Vmm(31);

    // Accumulator register for unrolled output position `i_ur` and oc block
    // `i_oc`; registers are numbered ur-major (i_ur * nb_oc_blocking + i_oc).
    // NOTE(review): the debug assert admits indices 28..30, which alias
    // vmm_tmp / vmm_one / vmm_shift; presumably init_conf keeps
    // nb_oc_blocking * ur_w smaller — confirm.
    Vmm vmm_out(int i_ur, int i_oc) {
        int idx = i_ur * jcp.nb_oc_blocking + i_oc;
        assert(idx < 31);
        return Vmm(idx);
    }
    // Input register for index `i_ic`, offset by nb_x_blocking * ur_w
    // (i.e. placed after that many accumulator-sized slots).
    Vmm vmm_inp(int i_ic, int nb_x_blocking) {
        int idx = i_ic + nb_x_blocking * jcp.ur_w;
        assert(idx < 31);
        return Vmm(idx);
    }
    // First register after the nb_oc_blocking * ur_w accumulators, as a
    // full-width Vmm and as its Xmm view.
    Vmm vmm_bias_alpha() { return Vmm(jcp.nb_oc_blocking * jcp.ur_w); }
    Xbyak::Xmm xmm_bias_alpha() {
        return Xbyak::Xmm(jcp.nb_oc_blocking * jcp.ur_w);
    }

    // Start offset within the current ur_w block for kernel tap `ki` given
    // `l_overflow` (semantics inferred from names — confirm against the
    // .cpp). A negative intermediate is normalized by adding stride_w until
    // it is non-negative.
    int get_ow_start(int ki, int l_overflow) {
        int res = (jcp.ow - 1 + jcp.r_pad) % jcp.stride_w
                + l_overflow * jcp.stride_w
                - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1);
        while (res < 0)
            res += jcp.stride_w;
        return res;
    }

    // End offset (exclusive, returned as ur_w - res) within a block of
    // `ur_w` positions for kernel tap `ki` given `r_overflow`. When the block
    // is the whole output width or the tail block, a negative r_pad is first
    // subtracted from ur_w. A negative intermediate is normalized by adding
    // stride_w until it is non-negative.
    int get_ow_end(int ur_w, int ki, int r_overflow) {
        if (utils::one_of(ur_w, jcp.ow, jcp.ur_w_tail))
            ur_w += nstl::min(0, jcp.r_pad); // remove negative padding
        int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w
                + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1);
        while (res < 0)
            res += jcp.stride_w;
        return ur_w - res;
    }
    // Code-generation stages (defined out of line).
    void prepare_output(int ur_w);
    void store_output(int ur_w, bool last_oc_block);
    void compute_ker(int ur_w, int l_overflow, int r_overflow,
            ker_block_t last_ic_block_flag, bool h_padded = false);
    void kh_loop(int ur_w, int pad_l, int pad_r, ker_block_t last_ker_block);
    void icb_loop(int ur_w, int pad_l, int pad_r, bool last_block);

    // Builds the per-ur_w-block overflow description.
    ur_w_blks_params_t get_ur_w_blks_params();

    void generate() override;
    void cvt2ps(data_type_t type_in, Vmm vmm_in, const Xbyak::Operand &op,
            bool mask_flag);
};
176 
177 struct _jit_avx512_core_x8s8s32x_deconv_fwd_kernel {
178 
_jit_avx512_core_x8s8s32x_deconv_fwd_kerneldnnl::impl::cpu::x64::_jit_avx512_core_x8s8s32x_deconv_fwd_kernel179     _jit_avx512_core_x8s8s32x_deconv_fwd_kernel(const jit_conv_conf_t &ajcp,
180             const primitive_attr_t &attr, const memory_desc_t &dst_md)
181         : kernel_(nullptr) {
182 
183         int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.ic_block;
184         switch (ch_block) {
185             case 16:
186                 kernel_ = new jit_avx512_core_x8s8s32x_deconv_fwd_kernel<
187                         Xbyak::Zmm>(ajcp, attr, dst_md);
188                 return;
189             case 8:
190                 kernel_ = new jit_avx512_core_x8s8s32x_deconv_fwd_kernel<
191                         Xbyak::Ymm>(ajcp, attr, dst_md);
192                 return;
193             case 4:
194                 kernel_ = new jit_avx512_core_x8s8s32x_deconv_fwd_kernel<
195                         Xbyak::Xmm>(ajcp, attr, dst_md);
196                 return;
197             default: assert(!"invalid channel blocking");
198         }
199     }
200 
create_kerneldnnl::impl::cpu::x64::_jit_avx512_core_x8s8s32x_deconv_fwd_kernel201     status_t create_kernel() { return kernel_->create_kernel(); }
202 
~_jit_avx512_core_x8s8s32x_deconv_fwd_kerneldnnl::impl::cpu::x64::_jit_avx512_core_x8s8s32x_deconv_fwd_kernel203     ~_jit_avx512_core_x8s8s32x_deconv_fwd_kernel() { delete kernel_; }
204 
operator ()dnnl::impl::cpu::x64::_jit_avx512_core_x8s8s32x_deconv_fwd_kernel205     void operator()(const jit_deconv_call_s *p) const { (*kernel_)(p); }
206 
207     static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr,
208             const memory_desc_wrapper &dst_d);
209 
210     static status_t init_conf(jit_conv_conf_t &jcp,
211             const deconvolution_desc_t &cd, memory_desc_t &src_md,
212             memory_desc_t &weights_md, memory_desc_t &dst_md,
213             const bool with_bias, memory_desc_t &bias_md,
214             const primitive_attr_t &attr, int nthreads);
215 
216     static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
217             const jit_conv_conf_t &jcp, const primitive_attr_t &attr);
218 
219 private:
220     DNNL_DISALLOW_COPY_AND_ASSIGN(_jit_avx512_core_x8s8s32x_deconv_fwd_kernel);
221     jit_generator *kernel_;
222 };
223 
224 template <impl::data_type_t src_type, impl::data_type_t dst_type>
225 struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t : public primitive_t {
226     struct pd_t : public cpu_deconvolution_fwd_pd_t {
227         using cpu_deconvolution_fwd_pd_t::cpu_deconvolution_fwd_pd_t;
228 
229         DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_deconvolution:",
230                                     ((jcp_.ver == ver_vnni) ? avx512_core_vnni
231                                                             : avx512_core),
232                                     ""),
233                 _jit_avx512_core_x8s8s32x_deconvolution_fwd_t);
234 
initdnnl::impl::cpu::x64::_jit_avx512_core_x8s8s32x_deconvolution_fwd_t::pd_t235         status_t init(engine_t *engine) {
236             const bool ok = is_fwd()
237                     && (desc()->alg_kind & alg_kind::deconvolution_direct)
238                     && desc()->src_desc.data_type == src_type
239                     && desc()->dst_desc.data_type == dst_type
240                     && IMPLICATION(with_bias(),
241                             utils::one_of(desc()->bias_desc.data_type,
242                                     data_type::f32, data_type::s32,
243                                     data_type::s8, data_type::u8))
244                     && desc()->accum_data_type == data_type::s32
245                     && attr()->has_default_values(
246                             primitive_attr_t::skip_mask_t::oscale
247                             | primitive_attr_t::skip_mask_t::post_ops);
248             if (!ok) return status::unimplemented;
249 
250             status_t status
251                     = _jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_conf(
252                             jcp_, *desc(), src_md_, weights_md_, dst_md_,
253                             with_bias(), bias_md_, *attr(),
254                             dnnl_get_max_threads());
255 
256             if (status != status::success) return status;
257 
258             auto scratchpad = scratchpad_registry().registrar();
259             _jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_scratchpad(
260                     scratchpad, jcp_, *attr());
261 
262             return status::success;
263         }
264 
265         jit_conv_conf_t jcp_;
266     };
267 
_jit_avx512_core_x8s8s32x_deconvolution_fwd_tdnnl::impl::cpu::x64::_jit_avx512_core_x8s8s32x_deconvolution_fwd_t268     _jit_avx512_core_x8s8s32x_deconvolution_fwd_t(const pd_t *apd)
269         : primitive_t(apd) {}
270 
271     typedef typename prec_traits<src_type>::type src_data_t;
272     typedef typename prec_traits<data_type::s8>::type wei_data_t;
273     typedef typename prec_traits<dst_type>::type dst_data_t;
274 
initdnnl::impl::cpu::x64::_jit_avx512_core_x8s8s32x_deconvolution_fwd_t275     status_t init(engine_t *engine) override {
276         CHECK(safe_ptr_assign(kernel_,
277                 new _jit_avx512_core_x8s8s32x_deconv_fwd_kernel(
278                         pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
279         return kernel_->create_kernel();
280     }
281 
executednnl::impl::cpu::x64::_jit_avx512_core_x8s8s32x_deconvolution_fwd_t282     status_t execute(const exec_ctx_t &ctx) const override {
283         auto ndims = pd()->ndims();
284         if (ndims == 3)
285             execute_forward_1d(ctx);
286         else if (ndims == 4)
287             execute_forward_2d(ctx);
288         else if (ndims == 5)
289             execute_forward_3d(ctx);
290         else
291             return status::unimplemented;
292         return status::success;
293     }
294 
295 private:
296     void execute_forward_1d(const exec_ctx_t &ctx) const;
297     void execute_forward_2d(const exec_ctx_t &ctx) const;
298     void execute_forward_3d(const exec_ctx_t &ctx) const;
pddnnl::impl::cpu::x64::_jit_avx512_core_x8s8s32x_deconvolution_fwd_t299     const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
300     std::unique_ptr<_jit_avx512_core_x8s8s32x_deconv_fwd_kernel> kernel_;
301 };
302 
303 } // namespace x64
304 } // namespace cpu
305 } // namespace impl
306 } // namespace dnnl
307 
308 #endif
309 
310 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
311