1 /******************************************************************************* 2 * Copyright 2020-2021 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef CPU_X64_JIT_UNI_X8S8S32X_DECONVOLUTION_HPP 18 #define CPU_X64_JIT_UNI_X8S8S32X_DECONVOLUTION_HPP 19 20 #include <functional> 21 #include <memory> 22 23 #include "common/c_types_map.hpp" 24 #include "common/primitive.hpp" 25 26 #include "cpu/cpu_deconvolution_pd.hpp" 27 28 #include "cpu/x64/jit_generator.hpp" 29 #include "cpu/x64/jit_primitive_conf.hpp" 30 31 namespace dnnl { 32 namespace impl { 33 namespace cpu { 34 namespace x64 { 35 36 namespace zp { 37 class jit_uni_deconv_zp_pad_str_kernel_base_t; 38 } // namespace zp 39 40 namespace injector { 41 template <cpu_isa_t isa, typename Vmm> 42 class jit_uni_postops_injector_t; 43 } // namespace injector 44 45 using namespace Xbyak; 46 47 template <cpu_isa_t isa, typename Vmm> 48 struct _jit_uni_x8s8s32x_deconv_fwd_kernel : public jit_generator { 49 DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_uni_x8s8s32x_deconv_fwd_kernel); 50 51 _jit_uni_x8s8s32x_deconv_fwd_kernel(const jit_conv_conf_t &ajcp, 52 const primitive_attr_t &attr, const memory_desc_wrapper &dst_d); 53 ~_jit_uni_x8s8s32x_deconv_fwd_kernel(); 54 55 const jit_conv_conf_t jcp_; 56 57 private: 58 std::unique_ptr<injector::jit_uni_postops_injector_t<isa, Vmm>> 59 postops_injector_; 60 using reg64_t = const Xbyak::Reg64; 61 62 static constexpr dim_t IC_SUB_STEP = 4; 63 static constexpr dim_t KER_MAX_REG_IDX = 13; 64 65 enum ker_block_t { 66 no_last_block = 0x1U, 67 last_ic_block = 0x2U, 68 last_sp_block = 0x4U, 69 }; 70 71 /* data regs */ 72 const reg64_t ®_src_ = r8; 73 const reg64_t ®_filt_ = r9; 74 const reg64_t ®_dst_ = r10; 75 const reg64_t ¶m1_ = abi_param1; 76 const reg64_t ®_kh_ = abi_not_param1; 77 const reg64_t ®_ki_ = r14; 78 79 const reg64_t ®_nur_w_ = rbx; 80 const reg64_t ®_bias_ = rdx; 81 const reg64_t ®_icb_ = reg_bias_; 82 const reg64_t ®_ptr_scales_ = rax; 83 const reg64_t ®_ptr_saturation_ubound_ = rax; 84 const reg64_t ®_oc_blocks_ = rsi; 85 86 const reg64_t &aux_reg_src_ = r11; 87 const reg64_t &aux_reg_filt_ = r12; 88 89 const reg64_t &aux_reg_src_d_ = r13; 90 const reg64_t &aux_reg_filt_d_ = r15; 91 92 const reg64_t ®_compensation_ = r14; 93 const reg64_t ®_scratch_ = r14; 94 const reg64_t ®_ptr_sum_scale_ = r11; 95 const reg64_t ®_ptr_sum_zp_ = r15; 96 const reg64_t ®_bias_alpha_ = abi_not_param1; 97 const reg64_t ®_overflow_ = rax; 98 const reg64_t ®_comp_strides_ = reg_overflow_; 99 const reg64_t ®_ker_long_offt_ = r15; 100 const reg64_t ®_zp_dst_ = r15; 101 const reg64_t ®_zp_src_ = r15; 102 const reg64_t ®_zp_compensation_ = r11; 103 const Xbyak::Address zp_src_pad_comp_addr_ = ptr[rsp]; 104 const Xbyak::Address reg_scratch_preserved_ = ptr[rsp + 8]; 105 static constexpr int reserved_stack_size_ = 16; 106 107 const Vmm vmm_tmp_ = Vmm(3); 108 const Vmm vmm_one_ = Vmm(2); 109 /* used during write-out section of store_output */ 110 const Vmm vmm_zero_ = Vmm(0); 111 const Vmm &vmm_saturation_ = vmm_zero_; 112 const Vmm &vmm_wei_ = vmm_zero_; 113 const Vmm &vmm_scale_ = vmm_zero_; 114 /* signed input */ 115 const Vmm vmm_shift_ = Vmm(1); 116 const Vmm vmm_comp_ = Vmm(1); 117 const Vmm &vmm_bias_ = vmm_zero_; 118 const Vmm &vmm_prev_dst_ = vmm_zero_; 119 const Vmm &vmm_sum_zp_ = vmm_tmp_; 120 121 Vmm vmm_out(int i_ur, int i_oc) const; 122 Vmm vmm_inp(int i_ic, int nb_x_blocking) const; 123 Vmm vmm_bias_alpha() const; 124 Xmm xmm_bias_alpha() const; 125 126 int get_ow_start(int ki, int l_overflow) const noexcept; 127 int get_ow_end(int ur_w, int ki, int r_overflow) const noexcept; 128 int get_blocking_size() const noexcept; 129 int get_tail_size() const noexcept; 130 131 void prepare_output(int ur_w); 132 void apply_postops(int ur_w, bool last_oc_block, const float *p_sum_scale, 133 const int32_t *p_sum_zp); 134 void store_output(int ur_w, bool last_oc_block); 135 void compute_ker(int ur_w, int l_overflow, int r_overflow, 136 ker_block_t last_ic_block_flag, bool h_padded = false); 137 void compute(const Vmm &vreg_acc, const Vmm &vreg_wei, const Vmm &vreg_src); 138 std::function<Vmm()> prepare_round_robin_vmm_inp_generator(int ur_w) const 139 noexcept; 140 void apply_zp_src_pad_str_comp( 141 int ur_w, int l_overflow, int r_overflow, bool h_padded); 142 void append_zp_src_pad_str_comp(int ur_w, int l_overflow, int r_overflow, 143 bool h_padded, bool last_oc_block); 144 void kh_loop(int ur_w, int pad_l, int pad_r, ker_block_t last_ker_block); 145 void icb_loop(int ur_w, int pad_l, int pad_r, bool last_block); 146 void generate() override; 147 void cvt2ps(data_type_t type_in, const Vmm &vmm_in, const Reg64 ®, 148 int offset, int load_size); 149 }; 150 151 template <cpu_isa_t isa> 152 struct jit_uni_x8s8s32x_deconv_fwd_kernel { 153 154 jit_uni_x8s8s32x_deconv_fwd_kernel(const jit_conv_conf_t &ajcp, 155 const primitive_attr_t &attr, const memory_desc_wrapper &dst_d); 156 create_kerneldnnl::impl::cpu::x64::jit_uni_x8s8s32x_deconv_fwd_kernel157 status_t create_kernel() { return kernel_->create_kernel(); } 158 159 ~jit_uni_x8s8s32x_deconv_fwd_kernel(); 160 operator ()dnnl::impl::cpu::x64::jit_uni_x8s8s32x_deconv_fwd_kernel161 void operator()(const jit_deconv_call_s *p) const { (*kernel_)(p); } 162 163 static bool post_ops_ok(jit_conv_conf_t &jcp, 164 const memory_desc_wrapper &dst_d, const primitive_attr_t &attr); 165 166 static status_t init_conf(jit_conv_conf_t &jcp, 167 const deconvolution_desc_t &cd, memory_desc_t &src_md, 168 memory_desc_t &weights_md, memory_desc_t &dst_md, 169 const bool with_bias, memory_desc_t &bias_md, 170 primitive_attr_t &attr, int nthreads); 171 172 static void init_scratchpad(memory_tracking::registrar_t &scratchpad, 173 const jit_conv_conf_t &jcp, const primitive_attr_t &attr); 174 175 using _jit_avx2_x8s8s32x_deconv_fwd_kernel 176 = _jit_uni_x8s8s32x_deconv_fwd_kernel<avx2, Xbyak::Ymm>; 177 178 private: 179 DNNL_DISALLOW_COPY_AND_ASSIGN(jit_uni_x8s8s32x_deconv_fwd_kernel); 180 std::unique_ptr<jit_generator> kernel_; 181 }; 182 183 template <cpu_isa_t isa> 184 struct jit_uni_x8s8s32x_deconvolution_fwd_t : public primitive_t { 185 struct pd_t : public cpu_deconvolution_fwd_pd_t { 186 using cpu_deconvolution_fwd_pd_t::cpu_deconvolution_fwd_pd_t; 187 188 DECLARE_COMMON_PD_T( 189 JIT_IMPL_NAME_HELPER("jit_uni_deconv:", 190 isa == avx2 && jcp_.ver == ver_vnni ? avx2_vnni : isa, 191 ""), 192 jit_uni_x8s8s32x_deconvolution_fwd_t); 193 194 status_t init(engine_t *engine); 195 jit_conv_conf_t jcp_; 196 }; 197 198 jit_uni_x8s8s32x_deconvolution_fwd_t(const pd_t *apd); 199 ~jit_uni_x8s8s32x_deconvolution_fwd_t(); 200 201 status_t init(engine_t *engine) override; 202 status_t execute(const exec_ctx_t &ctx) const override; 203 204 private: 205 status_t execute_forward_1d(const exec_ctx_t &ctx) const; 206 status_t execute_forward_2d(const exec_ctx_t &ctx) const; 207 status_t execute_forward_3d(const exec_ctx_t &ctx) const; pddnnl::impl::cpu::x64::jit_uni_x8s8s32x_deconvolution_fwd_t208 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } 209 std::unique_ptr<jit_uni_x8s8s32x_deconv_fwd_kernel<isa>> kernel_; 210 std::unique_ptr<zp::jit_uni_deconv_zp_pad_str_kernel_base_t> 211 zp_src_pad_comp_kernel_; 212 }; 213 214 } // namespace x64 215 } // namespace cpu 216 } // namespace impl 217 } // namespace dnnl 218 219 #endif 220 221 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s 222