1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 * Copyright 2021 FUJITSU LIMITED
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *******************************************************************************/
17 
18 #ifndef CPU_AARCH64_JIT_UNI_DW_CONV_KERNEL_F32_HPP
19 #define CPU_AARCH64_JIT_UNI_DW_CONV_KERNEL_F32_HPP
20 
21 #include "common/c_types_map.hpp"
22 #include "common/memory_tracking.hpp"
23 
24 #include "cpu/aarch64/jit_generator.hpp"
25 #include "cpu/aarch64/jit_primitive_conf.hpp"
26 
27 #include "cpu/aarch64/injectors/jit_uni_eltwise_injector.hpp"
28 
29 using namespace Xbyak_aarch64;
30 
31 namespace dnnl {
32 namespace impl {
33 namespace cpu {
34 namespace aarch64 {
35 
36 template <cpu_isa_t isa>
37 struct jit_uni_dw_conv_fwd_kernel_f32 : public jit_generator {
38     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_fwd_kernel_f32)
39 
jit_uni_dw_conv_fwd_kernel_f32dnnl::impl::cpu::aarch64::jit_uni_dw_conv_fwd_kernel_f3240     jit_uni_dw_conv_fwd_kernel_f32(jit_conv_conf_t ajcp)
41         : jcp(ajcp), eltwise_injector_(nullptr) {
42         if (jcp.with_eltwise)
43             eltwise_injector_ = new jit_uni_eltwise_injector_f32<sve_512>(
44                     this, jcp.eltwise);
45     }
46 
~jit_uni_dw_conv_fwd_kernel_f32dnnl::impl::cpu::aarch64::jit_uni_dw_conv_fwd_kernel_f3247     ~jit_uni_dw_conv_fwd_kernel_f32() { delete eltwise_injector_; }
48 
49     jit_conv_conf_t jcp;
50 
51 private:
52     using reg64_t = const XReg;
53     const PReg reg_p_all_ones = p2;
54     const int vlen = cpu_isa_traits<isa>::vlen;
55 
56     // dw convolution
57     reg64_t reg_input = x1;
58     reg64_t aux_reg_input = x2;
59     reg64_t reg_kernel = x3;
60     reg64_t aux_reg_kernel = x5;
61     reg64_t reg_ch_blocks = x6;
62     reg64_t reg_output = x7;
63     reg64_t reg_bias = x8;
64     reg64_t reg_kh = x9;
65     reg64_t iter_kh = x10;
66     reg64_t reg_oi = x11;
67     reg64_t aux_reg_ch_blocks = x12;
68     // fused convolution
69     reg64_t reg_input_buffer_ptr = x13;
70     reg64_t aux_reg_input_buffer_ptr = x14;
71     reg64_t reg_iw_offset = reg_input;
72 
73     /* Temprary regs */
74     reg64_t reg_tmp_imm = x15;
75     reg64_t reg_kernel_stack = x16;
76     reg64_t reg_input_stack = x17;
77     reg64_t reg_output_stack = x18;
78     reg64_t reg_bias_stack = x19;
79     reg64_t reg_tmp_addr = x20;
80 
81     inline void load_src(int ur_ch_blocks, int ur_w);
82     inline void compute_loop(int ur_w, int ur_ch_blocks, int pad_l, int pad_r);
83     inline void ow_loop(int ur_ch_blocks);
84     inline void apply_filter_unrolled(
85             int ur_ch_blocks, int ur_w, int pad_l, int pad_r);
86     inline void apply_activation(int ur_ch_blocks, int ur_w);
87     inline void store_dst(int ur_ch_blocks, int ur_w);
88 
get_ker_regdnnl::impl::cpu::aarch64::jit_uni_dw_conv_fwd_kernel_f3289     inline ZReg get_ker_reg(int idx) {
90         assert(idx <= 31);
91         return ZReg(idx + 0);
92     }
get_ker_reg_sdnnl::impl::cpu::aarch64::jit_uni_dw_conv_fwd_kernel_f3293     inline ZRegS get_ker_reg_s(int idx) {
94         assert(idx <= 31);
95         return ZRegS(idx + 0);
96     }
get_src_regdnnl::impl::cpu::aarch64::jit_uni_dw_conv_fwd_kernel_f3297     inline ZReg get_src_reg(int idx) {
98         assert((idx + 1) <= 31);
99         return ZReg(idx + 1);
100     }
get_src_reg_sdnnl::impl::cpu::aarch64::jit_uni_dw_conv_fwd_kernel_f32101     inline ZRegS get_src_reg_s(int idx) {
102         assert((idx + 1) <= 31);
103         return ZRegS(idx + 1);
104     }
105 
get_acc_regdnnl::impl::cpu::aarch64::jit_uni_dw_conv_fwd_kernel_f32106     inline ZReg get_acc_reg(int idx) {
107         assert((idx + 4) <= 31);
108         return ZReg(idx + 4);
109     }
get_acc_reg_sdnnl::impl::cpu::aarch64::jit_uni_dw_conv_fwd_kernel_f32110     inline ZRegS get_acc_reg_s(int idx) {
111         assert((idx + 4) <= 31);
112         return ZRegS(idx + 4);
113     }
114 
get_ow_startdnnl::impl::cpu::aarch64::jit_uni_dw_conv_fwd_kernel_f32115     int get_ow_start(int ki, int pad_l) {
116         return nstl::max(0,
117                 utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w));
118     }
119 
get_ow_enddnnl::impl::cpu::aarch64::jit_uni_dw_conv_fwd_kernel_f32120     int get_ow_end(int ur_w, int ki, int pad_r) {
121         return ur_w
122                 - nstl::max(0,
123                         utils::div_up(
124                                 pad_r - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1),
125                                 jcp.stride_w));
126     }
127 
is_src_layout_nxcdnnl::impl::cpu::aarch64::jit_uni_dw_conv_fwd_kernel_f32128     inline bool is_src_layout_nxc() {
129         return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc,
130                 format_tag::nwc);
131     }
is_dst_layout_nxcdnnl::impl::cpu::aarch64::jit_uni_dw_conv_fwd_kernel_f32132     inline bool is_dst_layout_nxc() {
133         return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc,
134                 format_tag::nwc);
135     }
136 
137     jit_uni_eltwise_injector_f32<sve_512> *eltwise_injector_;
138     void generate() override;
139 };
140 
141 template <cpu_isa_t isa>
142 struct jit_uni_dw_conv_bwd_data_kernel_f32 : public jit_generator {
143     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_bwd_data_kernel_f32)
144 
jit_uni_dw_conv_bwd_data_kernel_f32dnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_data_kernel_f32145     jit_uni_dw_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp) : jcp(ajcp) {}
146     jit_conv_conf_t jcp;
147 
148 private:
149     using reg64_t = const XReg;
150     const PReg reg_p_all_ones = p2;
151 
get_ker_regdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_data_kernel_f32152     inline ZReg get_ker_reg(int idx) { return ZReg(idx + 0); }
get_src_regdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_data_kernel_f32153     inline ZReg get_src_reg(int idx) { return ZReg(idx + 1); }
get_acc_regdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_data_kernel_f32154     inline ZReg get_acc_reg(int idx) { return ZReg(idx + 4); }
get_ker_reg_sdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_data_kernel_f32155     inline ZRegS get_ker_reg_s(int idx) { return ZRegS(idx + 0); }
get_src_reg_sdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_data_kernel_f32156     inline ZRegS get_src_reg_s(int idx) { return ZRegS(idx + 1); }
get_acc_reg_sdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_data_kernel_f32157     inline ZRegS get_acc_reg_s(int idx) { return ZRegS(idx + 4); }
158 
159     reg64_t reg_ddst = x1;
160     reg64_t aux_reg_ddst = x2;
161     reg64_t aux1_reg_ddst = x3;
162     reg64_t reg_kernel = x5;
163     reg64_t aux_reg_kernel = x6;
164     reg64_t aux1_reg_kernel = x7;
165     reg64_t reg_dsrc = x8;
166 
167     reg64_t reg_ur_str_w = x9;
168     reg64_t reg_ch_blocks = x10;
169 
170     reg64_t iter_kh = x11;
171     reg64_t iter_kw = x12;
172     reg64_t reg_kh = x13;
173     reg64_t reg_kw = x14;
174 
175     /* Temprary regs */
176     reg64_t reg_tmp_imm = x15;
177     reg64_t reg_tmp_addr = x16;
178 
179     inline void loop_body(int ur_ch_blocks);
180     inline void load_ddst(int ur_ch_blocks, int ur_str_w);
181     inline void apply_filter(int ur_ch_blocks, int ur_str_w);
182     inline void store_dsrc(int ur_ch_blocks, int ur_str_w);
183 
184     void generate() override;
185 };
186 
187 template <cpu_isa_t isa>
188 struct jit_uni_dw_conv_bwd_weights_kernel_f32 : public jit_generator {
189 
190     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_bwd_weights_kernel_f32)
191 
jit_uni_dw_conv_bwd_weights_kernel_f32dnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_weights_kernel_f32192     jit_uni_dw_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp) : jcp(ajcp) {}
193 
194     jit_conv_conf_t jcp;
195 
196 private:
197     using reg64_t = const XReg;
198     const PReg reg_p_all_ones = p2;
199     int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
200 
get_bias_regdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_weights_kernel_f32201     inline ZReg get_bias_reg(int idx = 0) { return ZReg(idx); }
get_output_regdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_weights_kernel_f32202     inline ZReg get_output_reg(int idx) { return ZReg(idx + 1); }
get_input_regdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_weights_kernel_f32203     inline ZReg get_input_reg(int idx) { return ZReg(idx + 5); }
get_acc_regdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_weights_kernel_f32204     inline ZReg get_acc_reg(int idx) { return ZReg(idx + 2); }
get_aux_regdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_weights_kernel_f32205     inline ZReg get_aux_reg() { return ZReg(0); }
get_bias_reg_sdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_weights_kernel_f32206     inline ZRegS get_bias_reg_s(int idx = 0) { return ZRegS(idx); }
get_output_reg_sdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_weights_kernel_f32207     inline ZRegS get_output_reg_s(int idx) { return ZRegS(idx + 1); }
get_input_reg_sdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_weights_kernel_f32208     inline ZRegS get_input_reg_s(int idx) { return ZRegS(idx + 5); }
get_acc_reg_sdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_weights_kernel_f32209     inline ZRegS get_acc_reg_s(int idx) { return ZRegS(idx + 2); }
get_aux_reg_sdnnl::impl::cpu::aarch64::jit_uni_dw_conv_bwd_weights_kernel_f32210     inline ZRegS get_aux_reg_s() { return ZRegS(0); }
211 
212     reg64_t reg_tmp_input = x1;
213     reg64_t reg_tmp_output = x2;
214     reg64_t reg_tmp_filter = x3;
215     reg64_t reg_kh_offset = x5;
216 
217     /* parameter passed by driver into kernel */
218     reg64_t reg_exec_flags = x14;
219 
220     reg64_t reg_oh_worksize = x6;
221     reg64_t reg_oh = x5;
222 
223     reg64_t reg_iter_ow_blk = x7;
224 
225     reg64_t reg_kh = x8;
226     reg64_t reg_kh_count = x9;
227 
228     /* Base addresses for convolution parameters. */
229     reg64_t reg_input_baddr = x10;
230     reg64_t reg_output_baddr = x11;
231     reg64_t reg_filter_baddr = x12;
232     reg64_t reg_bias_baddr = x13;
233 
234     /* Temporary regs */
235     reg64_t reg_tmp_imm = x15;
236     reg64_t reg_tmp_addr = x16;
237 
238     /* Micro-kernel JIT'ing, fusing 'kw' and 'ow_block' loops into unrolled FMAs
239      */
240     inline void compute_ow_step_unroll(
241             int unroll_w, int l_pad, int pad_offset, int ow_block);
242 
243     /* JIT'ing the outer loops for the micro-kernel -> {kh, oh_block} */
244     inline void compute_h_step(
245             int unroll_w, int l_pad, int pad_offset, int ow_block);
246     inline void compute_h_loop(
247             int unroll_w, int l_pad, int pad_offset, int ow_block);
248 
249     /* Write 'width' micro-kernel JITs; depending on the padding and convolution
250      * size, write a micro-kernel for the left ow-block, middle ow-block(s), and
251      * right ow-block.*/
252     inline void compute_ow_block_unroll();
253 
254     inline void compute_zero_filter();
255     inline void load_filter();
256     inline void zero_filter();
257     inline void load_bias();
258     inline void zero_bias();
259     inline void compute_bias_step_unroll(const int unroll_w);
260     inline void compute_bias_loop(const int block_size);
261     inline void store_filter();
262     inline void store_bias();
263 
264     void generate() override;
265 };
266 } // namespace aarch64
267 } // namespace cpu
268 } // namespace impl
269 } // namespace dnnl
270 
271 #endif
272