1 /******************************************************************************* 2 * Copyright 2021 Intel Corporation 3 * Copyright 2021 FUJITSU LIMITED 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 *******************************************************************************/ 17 18 #ifndef CPU_AARCH64_JIT_UNI_DW_CONV_KERNEL_F32_HPP 19 #define CPU_AARCH64_JIT_UNI_DW_CONV_KERNEL_F32_HPP 20 21 #include "common/c_types_map.hpp" 22 #include "common/memory_tracking.hpp" 23 24 #include "cpu/aarch64/jit_generator.hpp" 25 #include "cpu/aarch64/jit_primitive_conf.hpp" 26 27 #include "cpu/aarch64/injectors/jit_uni_eltwise_injector.hpp" 28 29 using namespace Xbyak_aarch64; 30 31 namespace dnnl { 32 namespace impl { 33 namespace cpu { 34 namespace aarch64 { 35 36 template <cpu_isa_t isa> 37 struct jit_uni_dw_conv_fwd_kernel_f32 : public jit_generator { 38 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_fwd_kernel_f32) 39 jit_uni_dw_conv_fwd_kernel_f32dnnl::impl::cpu::aarch64::jit_uni_dw_conv_fwd_kernel_f3240 jit_uni_dw_conv_fwd_kernel_f32(jit_conv_conf_t ajcp) 41 : jcp(ajcp), eltwise_injector_(nullptr) { 42 if (jcp.with_eltwise) 43 eltwise_injector_ = new jit_uni_eltwise_injector_f32<sve_512>( 44 this, jcp.eltwise); 45 } 46 ~jit_uni_dw_conv_fwd_kernel_f32dnnl::impl::cpu::aarch64::jit_uni_dw_conv_fwd_kernel_f3247 ~jit_uni_dw_conv_fwd_kernel_f32() { delete eltwise_injector_; } 48 49 jit_conv_conf_t jcp; 50 51 private: 52 using reg64_t = const XReg; 53 const PReg reg_p_all_ones = p2; 54 const int vlen = cpu_isa_traits<isa>::vlen; 55 56 // dw convolution 57 reg64_t reg_input = x1; 58 reg64_t aux_reg_input = x2; 59 reg64_t reg_kernel = x3; 60 reg64_t aux_reg_kernel = x5; 61 reg64_t reg_ch_blocks = x6; 62 reg64_t reg_output = x7; 63 reg64_t reg_bias = x8; 64 reg64_t reg_kh = x9; 65 reg64_t iter_kh = x10; 66 reg64_t reg_oi = x11; 67 reg64_t aux_reg_ch_blocks = x12; 68 // fused convolution 69 reg64_t reg_input_buffer_ptr = x13; 70 reg64_t aux_reg_input_buffer_ptr = x14; 71 reg64_t reg_iw_offset = reg_input; 72 73 /* Temprary regs */ 74 reg64_t reg_tmp_imm = x15; 75 reg64_t reg_kernel_stack = x16; 76 reg64_t reg_input_stack = x17; 77 reg64_t reg_output_stack = x18; 78 reg64_t reg_bias_stack = x19; 79 reg64_t reg_tmp_addr = x20; 80 81 inline void load_src(int ur_ch_blocks, int ur_w); 82 inline void compute_loop(int ur_w, int ur_ch_blocks, int pad_l, int pad_r); 83 inline void ow_loop(int ur_ch_blocks); 84 inline void apply_filter_unrolled( 85 int ur_ch_blocks, int ur_w, int pad_l, int pad_r); 86 inline void apply_activation(int ur_ch_blocks, int ur_w); 87 inline void store_dst(int ur_ch_blocks, int ur_w); 88 get_ker_regdnnl::impl::cpu::aarch64::jit_uni_dw_conv_fwd_kernel_f3289 inline ZReg get_ker_reg(int idx) { 90 assert(idx <= 31); 91 return ZReg(idx + 0); 92 } get_ker_reg_sdnnl::impl::cpu::aarch64::jit_uni_dw_conv_fwd_kernel_f3293 inline ZRegS get_ker_reg_s(int idx) { 94 assert(idx <= 31); 95 return ZRegS(idx + 0); 96 } get_src_regdnnl::impl::cpu::aarch64::jit_uni_dw_conv_fwd_kernel_f3297 inline ZReg get_src_reg(int idx) { 98 assert((idx + 1) <= 31); 99 return ZReg(idx + 1); 100 } get_src_reg_sdnnl::impl::cpu::aarch64::jit_uni_dw_conv_fwd_kernel_f32101 inline ZRegS get_src_reg_s(int idx) { 102 assert((idx + 1) <= 31); 103 return ZRegS(idx + 1); 104 } 105 get_acc_regdnnl::impl::cpu::aarch64::jit_uni_dw_conv_fwd_kernel_f32106 inline ZReg get_acc_reg(int idx) { 107 assert((idx + 4) <= 31); 108 return ZReg(idx + 4); 109 } get_acc_reg_sdnnl::impl::cpu::aarch64::jit_uni_dw_conv_fwd_kernel_f32110 inline ZRegS get_acc_reg_s(int idx) { 111 assert((idx + 4) <= 31); 112 return ZRegS(idx + 4); 113 } 114 get_ow_startdnnl::impl::cpu::aarch64::jit_uni_dw_conv_fwd_kernel_f32115 int get_ow_start(int ki, int pad_l) { 116 return nstl::max(0, 117 utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w)); 118 } 119 get_ow_enddnnl::impl::cpu::aarch64::jit_uni_dw_conv_fwd_kernel_f32120 int get_ow_end(int ur_w, int ki, int pad_r) { 121 return ur_w 122 - nstl::max(0, 123 utils::div_up( 124 pad_r - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1), 125 jcp.stride_w)); 126 } 127 is_src_layout_nxcdnnl::impl::cpu::aarch64::jit_uni_dw_conv_fwd_kernel_f32128 inline bool is_src_layout_nxc() { 129 return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc, 130 format_tag::nwc); 131 } is_dst_layout_nxcdnnl::impl::cpu::aarch64::jit_uni_dw_conv_fwd_kernel_f32132 inline bool is_dst_layout_nxc() { 133 return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc, 134 format_tag::nwc); 135 } 136 137 jit_uni_eltwise_injector_f32<sve_512> *eltwise_injector_; 138 void generate() override; 139 }; 140 141 template <cpu_isa_t isa> 142 struct jit_uni_dw_conv_bwd_data_kernel_f32 : public jit_generator { 143 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_bwd_data_kernel_f32) 144 jit_uni_dw_conv_bwd_data_kernel_f32dnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_data_kernel_f32145 jit_uni_dw_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp) : jcp(ajcp) {} 146 jit_conv_conf_t jcp; 147 148 private: 149 using reg64_t = const XReg; 150 const PReg reg_p_all_ones = p2; 151 get_ker_regdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_data_kernel_f32152 inline ZReg get_ker_reg(int idx) { return ZReg(idx + 0); } get_src_regdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_data_kernel_f32153 inline ZReg get_src_reg(int idx) { return ZReg(idx + 1); } get_acc_regdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_data_kernel_f32154 inline ZReg get_acc_reg(int idx) { return ZReg(idx + 4); } get_ker_reg_sdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_data_kernel_f32155 inline ZRegS get_ker_reg_s(int idx) { return ZRegS(idx + 0); } get_src_reg_sdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_data_kernel_f32156 inline ZRegS get_src_reg_s(int idx) { return ZRegS(idx + 1); } get_acc_reg_sdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_data_kernel_f32157 inline ZRegS get_acc_reg_s(int idx) { return ZRegS(idx + 4); } 158 159 reg64_t reg_ddst = x1; 160 reg64_t aux_reg_ddst = x2; 161 reg64_t aux1_reg_ddst = x3; 162 reg64_t reg_kernel = x5; 163 reg64_t aux_reg_kernel = x6; 164 reg64_t aux1_reg_kernel = x7; 165 reg64_t reg_dsrc = x8; 166 167 reg64_t reg_ur_str_w = x9; 168 reg64_t reg_ch_blocks = x10; 169 170 reg64_t iter_kh = x11; 171 reg64_t iter_kw = x12; 172 reg64_t reg_kh = x13; 173 reg64_t reg_kw = x14; 174 175 /* Temprary regs */ 176 reg64_t reg_tmp_imm = x15; 177 reg64_t reg_tmp_addr = x16; 178 179 inline void loop_body(int ur_ch_blocks); 180 inline void load_ddst(int ur_ch_blocks, int ur_str_w); 181 inline void apply_filter(int ur_ch_blocks, int ur_str_w); 182 inline void store_dsrc(int ur_ch_blocks, int ur_str_w); 183 184 void generate() override; 185 }; 186 187 template <cpu_isa_t isa> 188 struct jit_uni_dw_conv_bwd_weights_kernel_f32 : public jit_generator { 189 190 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_bwd_weights_kernel_f32) 191 jit_uni_dw_conv_bwd_weights_kernel_f32dnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_weights_kernel_f32192 jit_uni_dw_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp) : jcp(ajcp) {} 193 194 jit_conv_conf_t jcp; 195 196 private: 197 using reg64_t = const XReg; 198 const PReg reg_p_all_ones = p2; 199 int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float); 200 get_bias_regdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_weights_kernel_f32201 inline ZReg get_bias_reg(int idx = 0) { return ZReg(idx); } get_output_regdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_weights_kernel_f32202 inline ZReg get_output_reg(int idx) { return ZReg(idx + 1); } get_input_regdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_weights_kernel_f32203 inline ZReg get_input_reg(int idx) { return ZReg(idx + 5); } get_acc_regdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_weights_kernel_f32204 inline ZReg get_acc_reg(int idx) { return ZReg(idx + 2); } get_aux_regdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_weights_kernel_f32205 inline ZReg get_aux_reg() { return ZReg(0); } get_bias_reg_sdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_weights_kernel_f32206 inline ZRegS get_bias_reg_s(int idx = 0) { return ZRegS(idx); } get_output_reg_sdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_weights_kernel_f32207 inline ZRegS get_output_reg_s(int idx) { return ZRegS(idx + 1); } get_input_reg_sdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_weights_kernel_f32208 inline ZRegS get_input_reg_s(int idx) { return ZRegS(idx + 5); } get_acc_reg_sdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_weights_kernel_f32209 inline ZRegS get_acc_reg_s(int idx) { return ZRegS(idx + 2); } get_aux_reg_sdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_weights_kernel_f32210 inline ZRegS get_aux_reg_s() { return ZRegS(0); } 211 212 reg64_t reg_tmp_input = x1; 213 reg64_t reg_tmp_output = x2; 214 reg64_t reg_tmp_filter = x3; 215 reg64_t reg_kh_offset = x5; 216 217 /* parameter passed by driver into kernel */ 218 reg64_t reg_exec_flags = x14; 219 220 reg64_t reg_oh_worksize = x6; 221 reg64_t reg_oh = x5; 222 223 reg64_t reg_iter_ow_blk = x7; 224 225 reg64_t reg_kh = x8; 226 reg64_t reg_kh_count = x9; 227 228 /* Base addresses for convolution parameters. */ 229 reg64_t reg_input_baddr = x10; 230 reg64_t reg_output_baddr = x11; 231 reg64_t reg_filter_baddr = x12; 232 reg64_t reg_bias_baddr = x13; 233 234 /* Temporary regs */ 235 reg64_t reg_tmp_imm = x15; 236 reg64_t reg_tmp_addr = x16; 237 238 /* Micro-kernel JIT'ing, fusing 'kw' and 'ow_block' loops into unrolled FMAs 239 */ 240 inline void compute_ow_step_unroll( 241 int unroll_w, int l_pad, int pad_offset, int ow_block); 242 243 /* JIT'ing the outer loops for the micro-kernel -> {kh, oh_block} */ 244 inline void compute_h_step( 245 int unroll_w, int l_pad, int pad_offset, int ow_block); 246 inline void compute_h_loop( 247 int unroll_w, int l_pad, int pad_offset, int ow_block); 248 249 /* Write 'width' micro-kernel JITs; depending on the padding and convolution 250 * size, write a micro-kernel for the left ow-block, middle ow-block(s), and 251 * right ow-block.*/ 252 inline void compute_ow_block_unroll(); 253 254 inline void compute_zero_filter(); 255 inline void load_filter(); 256 inline void zero_filter(); 257 inline void load_bias(); 258 inline void zero_bias(); 259 inline void compute_bias_step_unroll(const int unroll_w); 260 inline void compute_bias_loop(const int block_size); 261 inline void store_filter(); 262 inline void store_bias(); 263 264 void generate() override; 265 }; 266 } // namespace aarch64 267 } // namespace cpu 268 } // namespace impl 269 } // namespace dnnl 270 271 #endif 272