1 /*******************************************************************************
2 * Copyright 2020-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef CPU_X64_JIT_UNI_X8S8S32X_DECONVOLUTION_HPP
18 #define CPU_X64_JIT_UNI_X8S8S32X_DECONVOLUTION_HPP
19 
20 #include <functional>
21 #include <memory>
22 
23 #include "common/c_types_map.hpp"
24 #include "common/primitive.hpp"
25 
26 #include "cpu/cpu_deconvolution_pd.hpp"
27 
28 #include "cpu/x64/jit_generator.hpp"
29 #include "cpu/x64/jit_primitive_conf.hpp"
30 
31 namespace dnnl {
32 namespace impl {
33 namespace cpu {
34 namespace x64 {
35 
36 namespace zp {
37 class jit_uni_deconv_zp_pad_str_kernel_base_t;
38 } // namespace zp
39 
40 namespace injector {
41 template <cpu_isa_t isa, typename Vmm>
42 class jit_uni_postops_injector_t;
43 } // namespace injector
44 
45 using namespace Xbyak;
46 
/* JIT code generator for the int8 (x8s8s32x) forward deconvolution kernel.
 *
 * `isa` is the target instruction set and `Vmm` the vector register type the
 * generated code works with; for example, the
 * `_jit_avx2_x8s8s32x_deconv_fwd_kernel` alias in this header pairs avx2 with
 * Xbyak::Ymm.
 *
 * NOTE(review): many of the register / vector names declared below are
 * aliases of the same physical register (rax, r11, r14, r15, Vmm(0), Vmm(1),
 * Vmm(3) are each shared by several names). The code generation in the .cpp
 * must keep the live ranges of aliased names disjoint; nothing in this header
 * enforces it.
 *
 * NOTE(review): several reference members are initialized from other members
 * (reg_icb_ from reg_bias_, reg_comp_strides_ from reg_overflow_, the vmm_*
 * references from vmm_zero_ / vmm_tmp_). The referenced member must be
 * declared first, so do not reorder these declarations.
 */
template <cpu_isa_t isa, typename Vmm>
struct _jit_uni_x8s8s32x_deconv_fwd_kernel : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_uni_x8s8s32x_deconv_fwd_kernel);

    /* ajcp: kernel configuration (copied into jcp_); attr: primitive
     * attributes (presumably consulted for post-ops and zero points -- TODO
     * confirm); dst_d: destination memory descriptor. Defined out of line. */
    _jit_uni_x8s8s32x_deconv_fwd_kernel(const jit_conv_conf_t &ajcp,
            const primitive_attr_t &attr, const memory_desc_wrapper &dst_d);
    /* Not defined in this header: postops_injector_ points to a type that is
     * only forward-declared here, and std::unique_ptr needs the complete type
     * wherever it is destroyed. */
    ~_jit_uni_x8s8s32x_deconv_fwd_kernel();

    /* Kernel configuration, stored by value. */
    const jit_conv_conf_t jcp_;

private:
    /* Owned post-ops injector. */
    std::unique_ptr<injector::jit_uni_postops_injector_t<isa, Vmm>>
            postops_injector_;
    using reg64_t = const Xbyak::Reg64;

    /* Presumably the number of int8 input channels folded into one 32-bit
     * lane per multiply-accumulate step -- TODO confirm. */
    static constexpr dim_t IC_SUB_STEP = 4;
    /* Presumably the highest vector register index the kernel may use --
     * TODO confirm against vmm_out()/vmm_inp(). */
    static constexpr dim_t KER_MAX_REG_IDX = 13;

    /* Distinct power-of-two flags describing which (last) block is being
     * computed; passed to compute_ker() and kh_loop(). */
    enum ker_block_t {
        no_last_block = 0x1U,
        last_ic_block = 0x2U,
        last_sp_block = 0x4U,
    };

    /* data regs */
    const reg64_t &reg_src_ = r8;
    const reg64_t &reg_filt_ = r9;
    const reg64_t &reg_dst_ = r10;
    const reg64_t &param1_ = abi_param1;
    const reg64_t &reg_kh_ = abi_not_param1;
    const reg64_t &reg_ki_ = r14;

    const reg64_t &reg_nur_w_ = rbx;
    const reg64_t &reg_bias_ = rdx;
    // Shares rdx with reg_bias_; must be declared after it.
    const reg64_t &reg_icb_ = reg_bias_;
    const reg64_t &reg_ptr_scales_ = rax;
    const reg64_t &reg_ptr_saturation_ubound_ = rax;
    const reg64_t &reg_oc_blocks_ = rsi;

    const reg64_t &aux_reg_src_ = r11;
    const reg64_t &aux_reg_filt_ = r12;

    const reg64_t &aux_reg_src_d_ = r13;
    const reg64_t &aux_reg_filt_d_ = r15;

    // The names below reuse registers already named above (r14, r11, r15,
    // rax, abi_not_param1).
    const reg64_t &reg_compensation_ = r14;
    const reg64_t &reg_scratch_ = r14;
    const reg64_t &reg_ptr_sum_scale_ = r11;
    const reg64_t &reg_ptr_sum_zp_ = r15;
    const reg64_t &reg_bias_alpha_ = abi_not_param1;
    const reg64_t &reg_overflow_ = rax;
    // Shares rax with reg_overflow_; must be declared after it.
    const reg64_t &reg_comp_strides_ = reg_overflow_;
    const reg64_t &reg_ker_long_offt_ = r15;
    const reg64_t &reg_zp_dst_ = r15;
    const reg64_t &reg_zp_src_ = r15;
    const reg64_t &reg_zp_compensation_ = r11;
    /* Two stack slots at [rsp] and [rsp + 8]. reserved_stack_size_ (16 bytes)
     * presumably is the amount the generated prologue reserves on the stack
     * for them -- confirm in generate(). */
    const Xbyak::Address zp_src_pad_comp_addr_ = ptr[rsp];
    const Xbyak::Address reg_scratch_preserved_ = ptr[rsp + 8];
    static constexpr int reserved_stack_size_ = 16;

    const Vmm vmm_tmp_ = Vmm(3);
    const Vmm vmm_one_ = Vmm(2);
    /* used during write-out section of store_output */
    const Vmm vmm_zero_ = Vmm(0);
    // All of the following references alias Vmm(0).
    const Vmm &vmm_saturation_ = vmm_zero_;
    const Vmm &vmm_wei_ = vmm_zero_;
    const Vmm &vmm_scale_ = vmm_zero_;
    /* signed input */
    // vmm_shift_ and vmm_comp_ are both Vmm(1).
    const Vmm vmm_shift_ = Vmm(1);
    const Vmm vmm_comp_ = Vmm(1);
    const Vmm &vmm_bias_ = vmm_zero_;
    const Vmm &vmm_prev_dst_ = vmm_zero_;
    // Aliases vmm_tmp_, i.e. Vmm(3).
    const Vmm &vmm_sum_zp_ = vmm_tmp_;

    /* Code-generation helpers. They are defined out of line, so the notes
     * below are based on names and signatures only -- confirm against the
     * implementation. */

    // Register holding the accumulator for (output position, oc block).
    Vmm vmm_out(int i_ur, int i_oc) const;
    // Register holding a loaded input value.
    Vmm vmm_inp(int i_ic, int nb_x_blocking) const;
    Vmm vmm_bias_alpha() const;
    Xmm xmm_bias_alpha() const;

    // Index / size helpers for the ow loop bounds and channel blocking.
    int get_ow_start(int ki, int l_overflow) const noexcept;
    int get_ow_end(int ur_w, int ki, int r_overflow) const noexcept;
    int get_blocking_size() const noexcept;
    int get_tail_size() const noexcept;

    void prepare_output(int ur_w);
    void apply_postops(int ur_w, bool last_oc_block, const float *p_sum_scale,
            const int32_t *p_sum_zp);
    void store_output(int ur_w, bool last_oc_block);
    void compute_ker(int ur_w, int l_overflow, int r_overflow,
            ker_block_t last_ic_block_flag, bool h_padded = false);
    void compute(const Vmm &vreg_acc, const Vmm &vreg_wei, const Vmm &vreg_src);
    // Returns a callable that yields input registers in rotation.
    std::function<Vmm()> prepare_round_robin_vmm_inp_generator(int ur_w) const
            noexcept;
    // Source zero-point compensation for padded / strided positions.
    void apply_zp_src_pad_str_comp(
            int ur_w, int l_overflow, int r_overflow, bool h_padded);
    void append_zp_src_pad_str_comp(int ur_w, int l_overflow, int r_overflow,
            bool h_padded, bool last_oc_block);
    void kh_loop(int ur_w, int pad_l, int pad_r, ker_block_t last_ker_block);
    void icb_loop(int ur_w, int pad_l, int pad_r, bool last_block);
    // Emits the kernel's machine code (jit_generator hook).
    void generate() override;
    // Loads `load_size` elements of `type_in` from [reg + offset] into
    // vmm_in, converted to f32 (presumably -- confirm).
    void cvt2ps(data_type_t type_in, const Vmm &vmm_in, const Reg64 &reg,
            int offset, int load_size);
};
150 
151 template <cpu_isa_t isa>
152 struct jit_uni_x8s8s32x_deconv_fwd_kernel {
153 
154     jit_uni_x8s8s32x_deconv_fwd_kernel(const jit_conv_conf_t &ajcp,
155             const primitive_attr_t &attr, const memory_desc_wrapper &dst_d);
156 
create_kerneldnnl::impl::cpu::x64::jit_uni_x8s8s32x_deconv_fwd_kernel157     status_t create_kernel() { return kernel_->create_kernel(); }
158 
159     ~jit_uni_x8s8s32x_deconv_fwd_kernel();
160 
operator ()dnnl::impl::cpu::x64::jit_uni_x8s8s32x_deconv_fwd_kernel161     void operator()(const jit_deconv_call_s *p) const { (*kernel_)(p); }
162 
163     static bool post_ops_ok(jit_conv_conf_t &jcp,
164             const memory_desc_wrapper &dst_d, const primitive_attr_t &attr);
165 
166     static status_t init_conf(jit_conv_conf_t &jcp,
167             const deconvolution_desc_t &cd, memory_desc_t &src_md,
168             memory_desc_t &weights_md, memory_desc_t &dst_md,
169             const bool with_bias, memory_desc_t &bias_md,
170             primitive_attr_t &attr, int nthreads);
171 
172     static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
173             const jit_conv_conf_t &jcp, const primitive_attr_t &attr);
174 
175     using _jit_avx2_x8s8s32x_deconv_fwd_kernel
176             = _jit_uni_x8s8s32x_deconv_fwd_kernel<avx2, Xbyak::Ymm>;
177 
178 private:
179     DNNL_DISALLOW_COPY_AND_ASSIGN(jit_uni_x8s8s32x_deconv_fwd_kernel);
180     std::unique_ptr<jit_generator> kernel_;
181 };
182 
183 template <cpu_isa_t isa>
184 struct jit_uni_x8s8s32x_deconvolution_fwd_t : public primitive_t {
185     struct pd_t : public cpu_deconvolution_fwd_pd_t {
186         using cpu_deconvolution_fwd_pd_t::cpu_deconvolution_fwd_pd_t;
187 
188         DECLARE_COMMON_PD_T(
189                 JIT_IMPL_NAME_HELPER("jit_uni_deconv:",
190                         isa == avx2 && jcp_.ver == ver_vnni ? avx2_vnni : isa,
191                         ""),
192                 jit_uni_x8s8s32x_deconvolution_fwd_t);
193 
194         status_t init(engine_t *engine);
195         jit_conv_conf_t jcp_;
196     };
197 
198     jit_uni_x8s8s32x_deconvolution_fwd_t(const pd_t *apd);
199     ~jit_uni_x8s8s32x_deconvolution_fwd_t();
200 
201     status_t init(engine_t *engine) override;
202     status_t execute(const exec_ctx_t &ctx) const override;
203 
204 private:
205     status_t execute_forward_1d(const exec_ctx_t &ctx) const;
206     status_t execute_forward_2d(const exec_ctx_t &ctx) const;
207     status_t execute_forward_3d(const exec_ctx_t &ctx) const;
pddnnl::impl::cpu::x64::jit_uni_x8s8s32x_deconvolution_fwd_t208     const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
209     std::unique_ptr<jit_uni_x8s8s32x_deconv_fwd_kernel<isa>> kernel_;
210     std::unique_ptr<zp::jit_uni_deconv_zp_pad_str_kernel_base_t>
211             zp_src_pad_comp_kernel_;
212 };
213 
214 } // namespace x64
215 } // namespace cpu
216 } // namespace impl
217 } // namespace dnnl
218 
219 #endif
220 
221 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
222