1 /*******************************************************************************
2 * Copyright 2017-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef CPU_X64_JIT_SSE41_CONV_KERNEL_F32_HPP
18 #define CPU_X64_JIT_SSE41_CONV_KERNEL_F32_HPP
19 
20 #include "common/c_types_map.hpp"
21 #include "common/memory.hpp"
22 
23 #include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
24 #include "cpu/x64/jit_generator.hpp"
25 #include "cpu/x64/jit_primitive_conf.hpp"
26 
27 namespace dnnl {
28 namespace impl {
29 namespace cpu {
30 namespace x64 {
31 
32 struct jit_sse41_conv_fwd_kernel_f32 : public jit_generator {
33     jit_sse41_conv_fwd_kernel_f32(const jit_conv_conf_t &ajcp,
34             const primitive_attr_t &attr, const memory_desc_t &dst_md);
35 
36     static status_t init_conf(jit_conv_conf_t &jcp,
37             const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
38             const memory_desc_wrapper &weights_d,
39             const memory_desc_wrapper &dst_d, const primitive_attr_t &attr,
40             int nthreads);
41 
42     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_conv_fwd_kernel_f32)
43     jit_conv_conf_t jcp;
44     const primitive_attr_t &attr_;
45 
46 private:
47     static constexpr auto simd_w_ = cpu_isa_traits<sse41>::vlen / sizeof(float);
48     using reg64_t = const Xbyak::Reg64;
49     reg64_t reg_input = rax;
50     reg64_t aux_reg_input = r8;
51     reg64_t reg_kernel = rdx;
52     reg64_t aux_reg_kernel = r9;
53     reg64_t reg_output = rsi;
54     reg64_t reg_bias = rbx;
55 
56     reg64_t kj = r10;
57     reg64_t oi_iter = r11;
58     reg64_t ki_iter = r12;
59     reg64_t reg_kh = abi_not_param1;
60     reg64_t simd_iter = r15;
61     reg64_t reg_oc_blocks = r14;
62     reg64_t imm_addr64 = reg_oc_blocks;
63 
64     Xbyak::Reg32 reg_ci_flag = r13d;
65 
66     std::unique_ptr<injector::jit_uni_postops_injector_t<sse41>>
67             postops_injector_;
68 
69     inline void oh_step_unroll_kw(
70             int ur_w, int pad_l, int pad_r, int oc_blocks);
71     inline void oh_step_nopad(int ur_w, int pad_l, int pad_r, int oc_blocks);
72     inline void width_blk_step(int ur_w, int pad_l, int pad_r, int oc_blocks);
73     inline void solve_common(int oc_blocks);
74 
filter_w_to_inputdnnl::impl::cpu::x64::jit_sse41_conv_fwd_kernel_f3275     inline dim_t filter_w_to_input(int ki, int oi = 0, int pad_l = 0) {
76         return ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l;
77     }
78 
filter_h_to_inputdnnl::impl::cpu::x64::jit_sse41_conv_fwd_kernel_f3279     inline dim_t filter_h_to_input(int ki) {
80         return ki * (jcp.dilate_h + 1) * jcp.iw;
81     }
82 
get_input_offsetdnnl::impl::cpu::x64::jit_sse41_conv_fwd_kernel_f3283     inline dim_t get_input_offset(int i_ic, int i_iw) {
84         dim_t offset;
85         if (utils::one_of(jcp.src_tag, format_tag::ncw, format_tag::nchw,
86                     format_tag::ncdhw)) {
87             offset = i_ic * jcp.ih * jcp.iw + i_iw;
88         } else if (utils::one_of(jcp.src_tag, format_tag::nwc, format_tag::nhwc,
89                            format_tag::ndhwc)) {
90             offset = i_iw * jcp.ic * jcp.ngroups + i_ic;
91         } else {
92             offset = i_iw * jcp.ic_block + i_ic;
93         }
94         return sizeof(float) * offset;
95     }
96 
get_output_offsetdnnl::impl::cpu::x64::jit_sse41_conv_fwd_kernel_f3297     inline dim_t get_output_offset(int i_oc_block, int i_ow) {
98         dim_t offset;
99         if (utils::one_of(jcp.dst_tag, format_tag::nwc, format_tag::nhwc,
100                     format_tag::ndhwc)) {
101             offset = i_ow * jcp.oc * jcp.ngroups + i_oc_block * jcp.oc_block;
102         } else {
103             offset = (i_oc_block * jcp.oh * jcp.ow + i_ow) * jcp.oc_block;
104         }
105         return sizeof(float) * offset;
106     }
107 
get_kernel_offsetdnnl::impl::cpu::x64::jit_sse41_conv_fwd_kernel_f32108     inline dim_t get_kernel_offset(int i_oc_block, int ki, int i_ic) {
109         dim_t block_step_size = jcp.ic_block * jcp.oc_block;
110         dim_t ic_block_step_size = jcp.kh * jcp.kw * block_step_size;
111         dim_t oc_block_step_size = jcp.nb_ic * ic_block_step_size;
112         dim_t offset = i_oc_block * oc_block_step_size + ki * block_step_size
113                 + i_ic * jcp.oc_block;
114         return sizeof(float) * offset;
115     }
116 
117     void apply_postops(const int oc_blocks, const int ur_w);
118 
119     void generate() override;
120 };
121 
122 } // namespace x64
123 } // namespace cpu
124 } // namespace impl
125 } // namespace dnnl
126 
127 #endif
128