1 /******************************************************************************* 2 * Copyright 2018-2021 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef CPU_X64_JIT_AVX512_CORE_X8S8S32X_DECONVOLUTION_HPP 18 #define CPU_X64_JIT_AVX512_CORE_X8S8S32X_DECONVOLUTION_HPP 19 20 #include <vector> 21 #include "common/c_types_map.hpp" 22 #include "common/dnnl_thread.hpp" 23 #include "common/memory.hpp" 24 #include "common/nstl.hpp" 25 #include "common/primitive.hpp" 26 #include "common/type_helpers.hpp" 27 #include "common/utils.hpp" 28 29 #include "cpu/cpu_deconvolution_pd.hpp" 30 #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" 31 #include "cpu/x64/jit_generator.hpp" 32 #include "cpu/x64/jit_primitive_conf.hpp" 33 34 namespace dnnl { 35 namespace impl { 36 namespace cpu { 37 namespace x64 { 38 39 typedef enum { 40 no_last_block = 0x1U, 41 last_ic_block = 0x2U, 42 last_sp_block = 0x4U, 43 } ker_block_t; 44 45 struct ur_w_blks_params_t { 46 struct single_ur_w_blk_params_t { single_ur_w_blk_params_tdnnl::impl::cpu::x64::ur_w_blks_params_t::single_ur_w_blk_params_t47 single_ur_w_blk_params_t( 48 int l_overflow, int r_overflow, bool process_sp_carefully) 49 : l_overflow(l_overflow) 50 , r_overflow(r_overflow) 51 , process_sp_carefully(process_sp_carefully) {} 52 53 // l_overflow - no. of spatial elements of weights standing out of 54 // src spatial when computing the 1st output pixel in the current blk 55 int l_overflow; 56 // r_overflow - no. of spatial elements of weights standing out of 57 // src spatial when computing the lst output pixel in the current blk 58 int r_overflow; 59 // process_sp_carefully - indicates if loading the last src sp 60 // for computation of the last dst sp of the block can't be done 61 // by fetching 4 src sp at once 62 bool process_sp_carefully; 63 }; 64 std::vector<single_ur_w_blk_params_t> blks_params; 65 int num_pre_blks; // num of blocks with l_overflow>0 66 int num_post_blks; // num of blocks with r_overflow>0 or that need to be 67 // processed carefully 68 }; 69 70 template <typename Vmm> 71 struct jit_avx512_core_x8s8s32x_deconv_fwd_kernel : public jit_generator { 72 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_x8s8s32x_deconv_fwd_ker_t); 73 74 jit_avx512_core_x8s8s32x_deconv_fwd_kernel(const jit_conv_conf_t &ajcp, 75 const primitive_attr_t &attr, const memory_desc_t &dst_md); 76 ~jit_avx512_core_x8s8s32x_deconv_fwd_kernel(); 77 78 const jit_conv_conf_t &jcp; 79 const primitive_attr_t &attr_; 80 81 private: 82 std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core>> 83 postops_injector_; 84 85 const int ic_sub_step = 4; 86 87 /* data regs */ 88 const Xbyak::Reg64 reg_src = r8; 89 const Xbyak::Reg64 reg_filt = r9; 90 const Xbyak::Reg64 reg_dst = r10; 91 const Xbyak::Reg64 param1 = abi_param1; 92 const Xbyak::Reg64 reg_kh = abi_not_param1; 93 const Xbyak::Reg64 reg_ki = r14; 94 95 const Xbyak::Reg64 reg_nur_w = rbx; 96 const Xbyak::Reg64 reg_bias = rdx; 97 const Xbyak::Reg64 reg_icb = reg_bias; 98 const Xbyak::Reg64 reg_ptr_scales = rax; 99 const Xbyak::Reg64 reg_ptr_saturation_ubound = rax; 100 const Xbyak::Reg64 reg_oc_blocks = rsi; 101 102 const Xbyak::Reg64 aux_reg_src = r11; 103 const Xbyak::Reg64 aux_reg_filt = r12; 104 105 const Xbyak::Reg64 aux_reg_src_d = r13; 106 const Xbyak::Reg64 aux_reg_filt_d = r15; 107 108 const Xbyak::Reg64 reg_compensation = r14; 109 const Xbyak::Reg64 reg_scratch = r14; 110 const Xbyak::Reg64 reg_ptr_sum_scale = r11; 111 const Xbyak::Reg64 reg_bias_alpha = abi_not_param1; 112 const Xbyak::Reg64 reg_overflow = rax; 113 const Xbyak::Reg64 reg_comp_strides = reg_overflow; 114 const Xbyak::Reg64 reg_ker_long_offt = r15; 115 116 Xbyak::Opmask ktail_mask = Xbyak::Opmask(2); 117 const Vmm vmm_tmp = Vmm(28); 118 const Vmm vmm_one = Vmm(29); 119 /* used during write-out section of store_output */ 120 const Vmm vmm_zero = Vmm(31); 121 const Vmm vmm_saturation = Vmm(31); 122 const Vmm vmm_wei = Vmm(31); 123 124 /* signed input */ 125 const Vmm vmm_shift = Vmm(30); 126 const Vmm vmm_comp = Vmm(30); 127 const Vmm vmm_bias = Vmm(31); 128 const Vmm vmm_prev_dst = Vmm(31); 129 vmm_outdnnl::impl::cpu::x64::jit_avx512_core_x8s8s32x_deconv_fwd_kernel130 Vmm vmm_out(int i_ur, int i_oc) { 131 int idx = i_ur * jcp.nb_oc_blocking + i_oc; 132 assert(idx < 31); 133 return Vmm(idx); 134 } vmm_inpdnnl::impl::cpu::x64::jit_avx512_core_x8s8s32x_deconv_fwd_kernel135 Vmm vmm_inp(int i_ic, int nb_x_blocking) { 136 int idx = i_ic + nb_x_blocking * jcp.ur_w; 137 assert(idx < 31); 138 return Vmm(idx); 139 } vmm_bias_alphadnnl::impl::cpu::x64::jit_avx512_core_x8s8s32x_deconv_fwd_kernel140 Vmm vmm_bias_alpha() { return Vmm(jcp.nb_oc_blocking * jcp.ur_w); } xmm_bias_alphadnnl::impl::cpu::x64::jit_avx512_core_x8s8s32x_deconv_fwd_kernel141 Xbyak::Xmm xmm_bias_alpha() { 142 return Xbyak::Xmm(jcp.nb_oc_blocking * jcp.ur_w); 143 } 144 get_ow_startdnnl::impl::cpu::x64::jit_avx512_core_x8s8s32x_deconv_fwd_kernel145 int get_ow_start(int ki, int l_overflow) { 146 int res = (jcp.ow - 1 + jcp.r_pad) % jcp.stride_w 147 + l_overflow * jcp.stride_w 148 - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1); 149 while (res < 0) 150 res += jcp.stride_w; 151 return res; 152 } 153 get_ow_enddnnl::impl::cpu::x64::jit_avx512_core_x8s8s32x_deconv_fwd_kernel154 int get_ow_end(int ur_w, int ki, int r_overflow) { 155 if (utils::one_of(ur_w, jcp.ow, jcp.ur_w_tail)) 156 ur_w += nstl::min(0, jcp.r_pad); // remove negative padding 157 int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w 158 + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1); 159 while (res < 0) 160 res += jcp.stride_w; 161 return ur_w - res; 162 } 163 void prepare_output(int ur_w); 164 void store_output(int ur_w, bool last_oc_block); 165 void compute_ker(int ur_w, int l_overflow, int r_overflow, 166 ker_block_t last_ic_block_flag, bool h_padded = false); 167 void kh_loop(int ur_w, int pad_l, int pad_r, ker_block_t last_ker_block); 168 void icb_loop(int ur_w, int pad_l, int pad_r, bool last_block); 169 170 ur_w_blks_params_t get_ur_w_blks_params(); 171 172 void generate() override; 173 void cvt2ps(data_type_t type_in, Vmm vmm_in, const Xbyak::Operand &op, 174 bool mask_flag); 175 }; 176 177 struct _jit_avx512_core_x8s8s32x_deconv_fwd_kernel { 178 _jit_avx512_core_x8s8s32x_deconv_fwd_kerneldnnl::impl::cpu::x64::_jit_avx512_core_x8s8s32x_deconv_fwd_kernel179 _jit_avx512_core_x8s8s32x_deconv_fwd_kernel(const jit_conv_conf_t &ajcp, 180 const primitive_attr_t &attr, const memory_desc_t &dst_md) 181 : kernel_(nullptr) { 182 183 int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.ic_block; 184 switch (ch_block) { 185 case 16: 186 kernel_ = new jit_avx512_core_x8s8s32x_deconv_fwd_kernel< 187 Xbyak::Zmm>(ajcp, attr, dst_md); 188 return; 189 case 8: 190 kernel_ = new jit_avx512_core_x8s8s32x_deconv_fwd_kernel< 191 Xbyak::Ymm>(ajcp, attr, dst_md); 192 return; 193 case 4: 194 kernel_ = new jit_avx512_core_x8s8s32x_deconv_fwd_kernel< 195 Xbyak::Xmm>(ajcp, attr, dst_md); 196 return; 197 default: assert(!"invalid channel blocking"); 198 } 199 } 200 create_kerneldnnl::impl::cpu::x64::_jit_avx512_core_x8s8s32x_deconv_fwd_kernel201 status_t create_kernel() { return kernel_->create_kernel(); } 202 ~_jit_avx512_core_x8s8s32x_deconv_fwd_kerneldnnl::impl::cpu::x64::_jit_avx512_core_x8s8s32x_deconv_fwd_kernel203 ~_jit_avx512_core_x8s8s32x_deconv_fwd_kernel() { delete kernel_; } 204 operator ()dnnl::impl::cpu::x64::_jit_avx512_core_x8s8s32x_deconv_fwd_kernel205 void operator()(const jit_deconv_call_s *p) const { (*kernel_)(p); } 206 207 static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr, 208 const memory_desc_wrapper &dst_d); 209 210 static status_t init_conf(jit_conv_conf_t &jcp, 211 const deconvolution_desc_t &cd, memory_desc_t &src_md, 212 memory_desc_t &weights_md, memory_desc_t &dst_md, 213 const bool with_bias, memory_desc_t &bias_md, 214 const primitive_attr_t &attr, int nthreads); 215 216 static void init_scratchpad(memory_tracking::registrar_t &scratchpad, 217 const jit_conv_conf_t &jcp, const primitive_attr_t &attr); 218 219 private: 220 DNNL_DISALLOW_COPY_AND_ASSIGN(_jit_avx512_core_x8s8s32x_deconv_fwd_kernel); 221 jit_generator *kernel_; 222 }; 223 224 template <impl::data_type_t src_type, impl::data_type_t dst_type> 225 struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t : public primitive_t { 226 struct pd_t : public cpu_deconvolution_fwd_pd_t { 227 using cpu_deconvolution_fwd_pd_t::cpu_deconvolution_fwd_pd_t; 228 229 DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_deconvolution:", 230 ((jcp_.ver == ver_vnni) ? avx512_core_vnni 231 : avx512_core), 232 ""), 233 _jit_avx512_core_x8s8s32x_deconvolution_fwd_t); 234 initdnnl::impl::cpu::x64::_jit_avx512_core_x8s8s32x_deconvolution_fwd_t::pd_t235 status_t init(engine_t *engine) { 236 const bool ok = is_fwd() 237 && (desc()->alg_kind & alg_kind::deconvolution_direct) 238 && desc()->src_desc.data_type == src_type 239 && desc()->dst_desc.data_type == dst_type 240 && IMPLICATION(with_bias(), 241 utils::one_of(desc()->bias_desc.data_type, 242 data_type::f32, data_type::s32, 243 data_type::s8, data_type::u8)) 244 && desc()->accum_data_type == data_type::s32 245 && attr()->has_default_values( 246 primitive_attr_t::skip_mask_t::oscale 247 | primitive_attr_t::skip_mask_t::post_ops); 248 if (!ok) return status::unimplemented; 249 250 status_t status 251 = _jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_conf( 252 jcp_, *desc(), src_md_, weights_md_, dst_md_, 253 with_bias(), bias_md_, *attr(), 254 dnnl_get_max_threads()); 255 256 if (status != status::success) return status; 257 258 auto scratchpad = scratchpad_registry().registrar(); 259 _jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_scratchpad( 260 scratchpad, jcp_, *attr()); 261 262 return status::success; 263 } 264 265 jit_conv_conf_t jcp_; 266 }; 267 _jit_avx512_core_x8s8s32x_deconvolution_fwd_tdnnl::impl::cpu::x64::_jit_avx512_core_x8s8s32x_deconvolution_fwd_t268 _jit_avx512_core_x8s8s32x_deconvolution_fwd_t(const pd_t *apd) 269 : primitive_t(apd) {} 270 271 typedef typename prec_traits<src_type>::type src_data_t; 272 typedef typename prec_traits<data_type::s8>::type wei_data_t; 273 typedef typename prec_traits<dst_type>::type dst_data_t; 274 initdnnl::impl::cpu::x64::_jit_avx512_core_x8s8s32x_deconvolution_fwd_t275 status_t init(engine_t *engine) override { 276 CHECK(safe_ptr_assign(kernel_, 277 new _jit_avx512_core_x8s8s32x_deconv_fwd_kernel( 278 pd()->jcp_, *pd()->attr(), *pd()->dst_md(0)))); 279 return kernel_->create_kernel(); 280 } 281 executednnl::impl::cpu::x64::_jit_avx512_core_x8s8s32x_deconvolution_fwd_t282 status_t execute(const exec_ctx_t &ctx) const override { 283 auto ndims = pd()->ndims(); 284 if (ndims == 3) 285 execute_forward_1d(ctx); 286 else if (ndims == 4) 287 execute_forward_2d(ctx); 288 else if (ndims == 5) 289 execute_forward_3d(ctx); 290 else 291 return status::unimplemented; 292 return status::success; 293 } 294 295 private: 296 void execute_forward_1d(const exec_ctx_t &ctx) const; 297 void execute_forward_2d(const exec_ctx_t &ctx) const; 298 void execute_forward_3d(const exec_ctx_t &ctx) const; pddnnl::impl::cpu::x64::_jit_avx512_core_x8s8s32x_deconvolution_fwd_t299 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } 300 std::unique_ptr<_jit_avx512_core_x8s8s32x_deconv_fwd_kernel> kernel_; 301 }; 302 303 } // namespace x64 304 } // namespace cpu 305 } // namespace impl 306 } // namespace dnnl 307 308 #endif 309 310 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s 311