1 /******************************************************************************* 2 * Copyright 2017-2020 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef CPU_X64_JIT_SSE41_CONV_KERNEL_F32_HPP 18 #define CPU_X64_JIT_SSE41_CONV_KERNEL_F32_HPP 19 20 #include "common/c_types_map.hpp" 21 #include "common/memory.hpp" 22 23 #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" 24 #include "cpu/x64/jit_generator.hpp" 25 #include "cpu/x64/jit_primitive_conf.hpp" 26 27 namespace dnnl { 28 namespace impl { 29 namespace cpu { 30 namespace x64 { 31 32 struct jit_sse41_conv_fwd_kernel_f32 : public jit_generator { 33 jit_sse41_conv_fwd_kernel_f32(const jit_conv_conf_t &ajcp, 34 const primitive_attr_t &attr, const memory_desc_t &dst_md); 35 36 static status_t init_conf(jit_conv_conf_t &jcp, 37 const convolution_desc_t &cd, const memory_desc_wrapper &src_d, 38 const memory_desc_wrapper &weights_d, 39 const memory_desc_wrapper &dst_d, const primitive_attr_t &attr, 40 int nthreads); 41 42 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_conv_fwd_kernel_f32) 43 jit_conv_conf_t jcp; 44 const primitive_attr_t &attr_; 45 46 private: 47 static constexpr auto simd_w_ = cpu_isa_traits<sse41>::vlen / sizeof(float); 48 using reg64_t = const Xbyak::Reg64; 49 reg64_t reg_input = rax; 50 reg64_t aux_reg_input = r8; 51 reg64_t reg_kernel = rdx; 52 reg64_t aux_reg_kernel = r9; 53 reg64_t reg_output = rsi; 54 reg64_t reg_bias = rbx; 55 56 reg64_t kj = r10; 57 reg64_t oi_iter = r11; 58 reg64_t ki_iter = r12; 59 reg64_t reg_kh = abi_not_param1; 60 reg64_t simd_iter = r15; 61 reg64_t reg_oc_blocks = r14; 62 reg64_t imm_addr64 = reg_oc_blocks; 63 Xbyak::Reg32 reg_ci_flag = r13d; 64 65 std::unique_ptr<injector::jit_uni_postops_injector_t<sse41>> 66 postops_injector_; 67 68 inline void oh_step_unroll_kw( 69 int ur_w, int pad_l, int pad_r, int oc_blocks); 70 inline void oh_step_nopad(int ur_w, int pad_l, int pad_r, int oc_blocks); 71 inline void width_blk_step(int ur_w, int pad_l, int pad_r, int oc_blocks); 72 inline void solve_common(int oc_blocks); 73 filter_w_to_inputdnnl::impl::cpu::x64::jit_sse41_conv_fwd_kernel_f3274 inline dim_t filter_w_to_input(int ki, int oi = 0, int pad_l = 0) { 75 return ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l; 76 } 77 filter_h_to_inputdnnl::impl::cpu::x64::jit_sse41_conv_fwd_kernel_f3278 inline dim_t filter_h_to_input(int ki) { 79 return ki * (jcp.dilate_h + 1) * jcp.iw; 80 } 81 get_input_offsetdnnl::impl::cpu::x64::jit_sse41_conv_fwd_kernel_f3282 inline dim_t get_input_offset(int i_ic, int i_iw) { 83 dim_t offset; 84 if (utils::one_of(jcp.src_tag, format_tag::ncw, format_tag::nchw, 85 format_tag::ncdhw)) { 86 offset = i_ic * jcp.ih * jcp.iw + i_iw; 87 } else if (utils::one_of(jcp.src_tag, format_tag::nwc, format_tag::nhwc, 88 format_tag::ndhwc)) { 89 offset = i_iw * jcp.ic * jcp.ngroups + i_ic; 90 } else { 91 offset = i_iw * jcp.ic_block + i_ic; 92 } 93 return sizeof(float) * offset; 94 } 95 get_output_offsetdnnl::impl::cpu::x64::jit_sse41_conv_fwd_kernel_f3296 inline dim_t get_output_offset(int i_oc_block, int i_ow) { 97 dim_t offset; 98 if (utils::one_of(jcp.dst_tag, format_tag::nwc, format_tag::nhwc, 99 format_tag::ndhwc)) { 100 offset = i_ow * jcp.oc * jcp.ngroups + i_oc_block * jcp.oc_block; 101 } else { 102 offset = (i_oc_block * jcp.oh * jcp.ow + i_ow) * jcp.oc_block; 103 } 104 return sizeof(float) * offset; 105 } 106 get_kernel_offsetdnnl::impl::cpu::x64::jit_sse41_conv_fwd_kernel_f32107 inline dim_t get_kernel_offset(int i_oc_block, int ki, int i_ic) { 108 dim_t block_step_size = jcp.ic_block * jcp.oc_block; 109 dim_t ic_block_step_size = jcp.kh * jcp.kw * block_step_size; 110 dim_t oc_block_step_size = jcp.nb_ic * ic_block_step_size; 111 dim_t offset = i_oc_block * oc_block_step_size + ki * block_step_size 112 + i_ic * jcp.oc_block; 113 return sizeof(float) * offset; 114 } 115 116 void apply_postops(const int oc_blocks, const int ur_w); 117 118 void generate() override; 119 }; 120 121 } // namespace x64 122 } // namespace cpu 123 } // namespace impl 124 } // namespace dnnl 125 126 #endif 127