1 /*******************************************************************************
2 * Copyright 2020-2021 Intel Corporation
3 * Copyright 2020-2021 FUJITSU LIMITED
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *******************************************************************************/
17 
18 #ifndef CPU_AARCH64_JIT_SVE_CONV_KERNEL_HPP
19 #define CPU_AARCH64_JIT_SVE_CONV_KERNEL_HPP
20 
#include <cassert>
#include <memory>
#include <string>

#include "common/c_types_map.hpp"
#include "common/memory_tracking.hpp"

#include "cpu/aarch64/jit_generator.hpp"
#include "cpu/aarch64/jit_primitive_conf.hpp"

#include "cpu/aarch64/injectors/jit_uni_eltwise_injector.hpp"

#include "cpu/aarch64/jit_op_imm_check.hpp"
30 
/* Immediate-range limits used by the kernel generators.
 * NOTE(review): the instruction forms these bound are not visible in this
 * header (ADDMAX presumably matches the 12-bit unsigned ADD immediate);
 * confirm against their uses in the .cpp files. */
#define LDRWMAX 252
#define ADDMAX 4095
/* Get vector offsets, ofs / VL(VL: 512bits = 64Bytes) */
/* Implemented as a right shift by 6 bits: any remainder below 64 bytes is
 * discarded, so the result is exact only when ofs is a multiple of 64. */
#define VL_OFS(ofs) ((ofs) >> 6)

/* NOTE(review): a using-directive at header scope leaks Xbyak_aarch64 names
 * into every file that includes this header; the declarations below rely on
 * it (XReg, PReg, ZReg, prfm/prfw operands are used unqualified). */
using namespace Xbyak_aarch64;
37 
38 namespace dnnl {
39 namespace impl {
40 namespace cpu {
41 namespace aarch64 {
42 
43 struct jit_sve_512_conv_fwd_kernel : public jit_generator {
44 
jit_sve_512_conv_fwd_kerneldnnl::impl::cpu::aarch64::jit_sve_512_conv_fwd_kernel45     jit_sve_512_conv_fwd_kernel(
46             const jit_conv_conf_t &ajcp, const primitive_attr_t &attr)
47         : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) {
48 
49         if (jcp.with_eltwise)
50             eltwise_injector_ = new jit_uni_eltwise_injector_f32<sve_512>(
51                     this, jcp.eltwise);
52     }
53 
~jit_sve_512_conv_fwd_kerneldnnl::impl::cpu::aarch64::jit_sve_512_conv_fwd_kernel54     ~jit_sve_512_conv_fwd_kernel() { delete eltwise_injector_; }
55 
56     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sve_512_conv_fwd_kernel)
57 
58     jit_conv_conf_t jcp;
59     const primitive_attr_t &attr_;
60 
61     static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr);
62     static status_t init_conf(jit_conv_conf_t &jcp,
63             const convolution_desc_t &cd, memory_desc_t &src_pd,
64             memory_desc_t &weights_pd, memory_desc_t &dst_pd,
65             memory_desc_t &bias_pd, const primitive_attr_t &attr, int nthreads);
66     static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
67             const jit_conv_conf_t &jcp);
68 
69 private:
70     using reg64_t = const XReg;
71     enum {
72         typesize = sizeof(float),
73         ker_reg_base_idx = 28,
74     };
75 
76     const PReg reg_p_all_ones = p3;
77 
78     reg64_t param = abi_param1;
79     reg64_t reg_inp = x1; // src base addr (2d)
80     reg64_t reg_ker = x2; // ker base addr (2d)
81     reg64_t aux_reg_ker_d = x2; // ker addr (3d)
82     reg64_t reg_out = x3; // dst base addr (2d)
83     reg64_t reg_ki = x3; // d-dim loop var? (3d)
84     reg64_t reg_owb = x5; // num of ow-block
85     reg64_t reg_out_prf = x6; // addr for prefetch
86 
87     reg64_t aux_reg_inp = x7; // src addr (main loop)
88     reg64_t aux_reg_inp2 = x24; // src addr (main loop)
89     reg64_t aux_reg_inp3 = x25; // src addr (main loop)
90     reg64_t reg_out_ofs = x7; // dst addr (store_output)
91     reg64_t aux_reg_ker = x8; // ker addr (main loop)
92     reg64_t reg_channel = x9; // reduce workload
93     reg64_t reg_bias = x10; // bias addr (prepare_out)
94 
95     reg64_t aux_reg_inp_d = x11; // src addr (3d)
96     reg64_t reg_oi = x11;
97 
98     reg64_t reg_kh = x12; // ker h size
99     reg64_t reg_kj = x13; // ker h workload
100 
101     /* Temporary registers for ARM insts */
102     reg64_t reg_tmp_addr = x14;
103     reg64_t reg_prev_bcast_addr = x15;
104     reg64_t reg_prev_wei_addr = x16;
105     reg64_t reg_tmp_imm = x17;
106 
107     reg64_t reg_out_org = x18; // dst base addr (3d)
108     reg64_t reg_oi_org = x19; // base oi (3d)
109     reg64_t aux_reg_ker_d_org = x20;
110     reg64_t reg_ker_org = x21; // ker base addr (3d)
111     reg64_t reg_inp_org = x29; // src base addr (3d)
112 
prefetchdnnl::impl::cpu::aarch64::jit_sve_512_conv_fwd_kernel113     void prefetch(
114             const std::string prfop, int level, reg64_t in, long long int ofs) {
115         bool for_load = false;
116         if (prfop == "LD") {
117             for_load = true;
118         } else if (prfop == "ST") {
119             for_load = false;
120         } else {
121             assert(!"invalid prfop");
122         }
123 
124         bool cacheline_aligned = ((ofs & 0xFF) == 0) ? true : false;
125         if (cacheline_aligned == true) {
126             Prfop op = PLDL1KEEP;
127             switch (level) {
128                 case 1: op = (for_load == true) ? PLDL1KEEP : PSTL1KEEP; break;
129                 case 2: op = (for_load == true) ? PLDL2KEEP : PSTL2KEEP; break;
130                 case 3: op = (for_load == true) ? PLDL3KEEP : PSTL3KEEP; break;
131                 default: assert(!"invalid prfop"); break;
132             }
133 
134             if ((ofs <= PRFMMAX) && (ofs >= 0)) {
135                 prfm(op, ptr(in, static_cast<int32_t>(ofs)));
136             } else {
137                 add_imm(reg_tmp_addr, in, ofs, reg_tmp_imm);
138                 prfm(op, ptr(reg_tmp_addr));
139             }
140         } else {
141             PrfopSve op_sve = PLDL1KEEP_SVE;
142             switch (level) {
143                 case 1:
144                     op_sve = (for_load == true) ? PLDL1KEEP_SVE : PSTL1KEEP_SVE;
145                     break;
146                 case 2:
147                     op_sve = (for_load == true) ? PLDL2KEEP_SVE : PSTL2KEEP_SVE;
148                     break;
149                 case 3:
150                     op_sve = (for_load == true) ? PLDL3KEEP_SVE : PSTL3KEEP_SVE;
151                     break;
152                 default: assert(!"invalid level"); break;
153             }
154 
155             if ((VL_OFS(ofs) <= PRFWMAX)
156                     && (VL_OFS(ofs) >= (-1 * PRFWMAX - 1))) {
157                 prfw(op_sve, reg_p_all_ones,
158                         ptr(in, static_cast<int32_t>(VL_OFS(ofs))));
159             } else {
160                 add_imm(reg_tmp_addr, in, ofs, reg_tmp_imm);
161                 prfw(op_sve, reg_p_all_ones, ptr(reg_tmp_addr));
162             }
163         }
164     }
165 
166     jit_uni_eltwise_injector_f32<sve_512> *eltwise_injector_;
167 
168     inline void prepare_output(int ur_w);
169     inline void store_output(int ur_w);
170     inline void compute_loop_fma_core(int ur_w, int pad_l, int pad_r);
171     inline void compute_loop(int ur_w, int pad_l, int pad_r);
172 
173     void generate() override;
174 
get_output_offsetdnnl::impl::cpu::aarch64::jit_sve_512_conv_fwd_kernel175     inline size_t get_output_offset(int oi, int n_oc_block) {
176         const bool is_nxc_layout = is_dst_layout_nxc();
177         size_t ow_str = is_nxc_layout ? jcp.ngroups * jcp.oc : jcp.oc_block;
178         size_t ocb_str = is_nxc_layout
179                 ? jcp.oc_block
180                 : (size_t)jcp.od * jcp.oh * jcp.ow * jcp.oc_block;
181 
182         return jcp.typesize_out * (n_oc_block * ocb_str + oi * ow_str);
183     }
184 
get_input_offsetdnnl::impl::cpu::aarch64::jit_sve_512_conv_fwd_kernel185     inline size_t get_input_offset(int ki, int ic, int oi, int pad_l) {
186         const bool is_nxc_layout = is_src_layout_nxc();
187         size_t iw_str = is_nxc_layout ? jcp.ngroups * jcp.ic
188                                       : (!jcp.is_1stconv ? jcp.ic_block : 1);
189         size_t ic_str = !jcp.is_1stconv || is_nxc_layout
190                 ? 1
191                 : (size_t)jcp.iw * jcp.ih * jcp.id;
192         size_t iw_idx = ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l;
193 
194         return jcp.typesize_in * (iw_idx * iw_str + ic * ic_str);
195     }
196 
get_kernel_offsetdnnl::impl::cpu::aarch64::jit_sve_512_conv_fwd_kernel197     inline int get_kernel_offset(
198             int ki, int ic, int n_oc_block, int ker_number) {
199         return jcp.typesize_in * jcp.oc_block
200                 * (n_oc_block * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw
201                                 * jcp.kd
202                         + (ic + ker_number) + ki * jcp.ic_block);
203     }
204 
get_ow_startdnnl::impl::cpu::aarch64::jit_sve_512_conv_fwd_kernel205     inline int get_ow_start(int ki, int pad_l) {
206         return nstl::max(0,
207                 utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w));
208     }
209 
get_ow_enddnnl::impl::cpu::aarch64::jit_sve_512_conv_fwd_kernel210     inline int get_ow_end(int ur_w, int ki, int pad_r) {
211         return ur_w
212                 - nstl::max(0,
213                         utils::div_up(
214                                 pad_r - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1),
215                                 jcp.stride_w));
216     }
is_src_layout_nxcdnnl::impl::cpu::aarch64::jit_sve_512_conv_fwd_kernel217     inline bool is_src_layout_nxc() {
218         return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc,
219                 format_tag::nwc);
220     }
is_dst_layout_nxcdnnl::impl::cpu::aarch64::jit_sve_512_conv_fwd_kernel221     inline bool is_dst_layout_nxc() {
222         return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc,
223                 format_tag::nwc);
224     }
225 };
226 
227 struct jit_sve_512_conv_bwd_data_kernel_f32 : public jit_generator {
228 
jit_sve_512_conv_bwd_data_kernel_f32dnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_data_kernel_f32229     jit_sve_512_conv_bwd_data_kernel_f32(const jit_conv_conf_t &ajcp)
230         : jcp(ajcp) {}
231 
232     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sve_512_conv_bwd_data_kernel_f32)
233     jit_conv_conf_t jcp;
234     void (*jit_ker_)(jit_conv_call_s *);
235 
236     static status_t init_conf(jit_conv_conf_t &jcp,
237             const convolution_desc_t &cd, memory_desc_t &diff_src_d,
238             memory_desc_t &weights_d, memory_desc_t &diff_dst_d, int nthreads);
239     static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
240             const jit_conv_conf_t &jcp);
241 
242 private:
243     using reg64_t = const XReg;
244     enum {
245         typesize = sizeof(float),
246     };
247     int ker_reg_base_idx = (jcp.nb_ic_blocking == 1) ? 16 : 24;
248 
249     reg64_t param = abi_param1;
250     reg64_t reg_dst = x1;
251     reg64_t reg_ker = x2;
252     reg64_t reg_src = x3;
253 
254     reg64_t reg_dst_prf = x23;
255     reg64_t reg_ker_prf = x5;
256     reg64_t reg_src_prf = x6;
257     reg64_t reg_iwb = x24;
258 
259     reg64_t aux_reg_dst = x7;
260     reg64_t aux_reg_ker = x8;
261 
262     reg64_t aux_reg_dst_prf = x9;
263     reg64_t aux_reg_ker_prf = x10;
264 
265     reg64_t aux_reg_dst_d_prf = x6;
266     reg64_t aux_reg_dst_d = x11;
267     reg64_t aux_reg_ker_d_prf = x12;
268     reg64_t aux_reg_ker_d = x2;
269     reg64_t reg_ki = x3;
270 
271     reg64_t reg_kj = x13;
272     reg64_t reg_oi = x11;
273     reg64_t reg_kh = x12;
274 
275     reg64_t reg_channel = x9;
276 
277     reg64_t reg_tmp = x14;
278     reg64_t reg_long_offt = x7;
279 
280     /* Temporary registers for ARM insts */
281     reg64_t reg_prev_bcast_addr = x15;
282     reg64_t reg_prev_bcast_addr2 = x17;
283     reg64_t reg_prev_bcast_addr3 = x21;
284     reg64_t reg_tmp_imm = x16;
285     reg64_t reg_tmp_addr = x18;
286 
287     reg64_t reg_src_prf_org = x19;
288     reg64_t reg_src_org = x20;
289     reg64_t reg_oi_org = x25;
290     reg64_t reg_dst_org = x22;
291     reg64_t reg_ker_org = x26;
292     reg64_t reg_input_org = x22;
293     reg64_t reg_kernel_org = x26;
294 
295     const PReg reg_p_all_ones = p3;
296 
prefetchdnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_data_kernel_f32297     long long int prefetch(const std::string prfop, int level, reg64_t in,
298             long long int ofs, long long int prev_ofs) {
299         bool for_load = false;
300         if (prfop == "LD") {
301             for_load = true;
302         } else if (prfop == "ST") {
303             for_load = false;
304         } else {
305             assert(!"invalid prfop");
306         }
307 
308         bool cacheline_aligned = ((ofs & 0xFF) == 0) ? true : false;
309         if (cacheline_aligned == true) {
310             Prfop op = PLDL1KEEP;
311             switch (level) {
312                 case 1: op = (for_load == true) ? PLDL1KEEP : PSTL1KEEP; break;
313                 case 2: op = (for_load == true) ? PLDL2KEEP : PSTL2KEEP; break;
314                 case 3: op = (for_load == true) ? PLDL3KEEP : PSTL3KEEP; break;
315                 default: assert(!"invalid prfop"); break;
316             }
317 
318             long long int tmp_ofs = ofs - prev_ofs;
319             if ((ofs <= PRFMMAX) && (ofs >= 0)) {
320                 prfm(op, ptr(in, static_cast<int32_t>(ofs)));
321             } else if ((tmp_ofs <= PRFMMAX) && (tmp_ofs >= 0)) {
322                 prfm(op, ptr(reg_tmp_addr, static_cast<int32_t>(tmp_ofs)));
323             } else {
324                 add_imm(reg_tmp_addr, in, ofs, reg_tmp_imm);
325                 prfm(op, ptr(reg_tmp_addr));
326                 prev_ofs = ofs;
327             }
328         } else {
329             PrfopSve op_sve = PLDL1KEEP_SVE;
330             switch (level) {
331                 case 1:
332                     op_sve = (for_load == true) ? PLDL1KEEP_SVE : PSTL1KEEP_SVE;
333                     break;
334                 case 2:
335                     op_sve = (for_load == true) ? PLDL2KEEP_SVE : PSTL2KEEP_SVE;
336                     break;
337                 case 3:
338                     op_sve = (for_load == true) ? PLDL3KEEP_SVE : PSTL3KEEP_SVE;
339                     break;
340                 default: assert(!"invalid prfop"); break;
341             }
342 
343             long long int tmp_ofs = ofs - prev_ofs;
344             if ((VL_OFS(ofs) <= PRFWMAX)
345                     && (VL_OFS(ofs) >= (-1 * PRFWMAX - 1))) {
346                 prfw(op_sve, reg_p_all_ones,
347                         ptr(in, static_cast<int32_t>(VL_OFS(ofs))));
348             } else if ((VL_OFS(tmp_ofs) <= PRFWMAX)
349                     && (VL_OFS(tmp_ofs) >= (-1 * PRFWMAX - 1))) {
350                 prfw(op_sve, reg_p_all_ones,
351                         ptr(reg_tmp_addr,
352                                 static_cast<int32_t>(VL_OFS(tmp_ofs))));
353             } else {
354                 add_imm(reg_tmp_addr, in, ofs, reg_tmp_imm);
355                 prfw(op_sve, reg_p_all_ones, ptr(reg_tmp_addr));
356                 prev_ofs = ofs;
357             }
358         }
359         return prev_ofs;
360     }
361 
362     ZReg reg_wei = ZReg(31);
363 
364     inline void prepare_output(int ur_w);
365     inline void store_output(int ur_w);
366     inline void compute_loop_fma(int ur_w, int l_overflow, int r_overflow);
367     inline void compute_loop_fma_core(
368             int ur_w, int l_overflow, int r_overflow, int k_offset);
369     inline void compute_loop(
370             int ur_w, int l_overflow, int r_overflow, int k_offset = 0);
371     void generate() override;
372 
get_iw_startdnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_data_kernel_f32373     inline int get_iw_start(int ki, int l_overflow) {
374         int res = (jcp.iw - 1 + jcp.r_pad) % jcp.stride_w
375                 + l_overflow * jcp.stride_w
376                 - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1);
377         while (res < 0)
378             res += jcp.stride_w;
379 
380         return res;
381     }
382 
get_iw_enddnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_data_kernel_f32383     inline int get_iw_end(int ur_w, int ki, int r_overflow) {
384         if (utils::one_of(ur_w, jcp.iw, jcp.ur_w_tail))
385             ur_w += nstl::min(0, jcp.r_pad); // remove negative padding
386         int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w
387                 + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1);
388         while (res < 0)
389             res += jcp.stride_w;
390 
391         return ur_w - res;
392     }
393 
get_diff_src_offsetdnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_data_kernel_f32394     inline size_t get_diff_src_offset(int iw, int icb) {
395         const bool is_nxc_layout = is_dsrc_layout_nxc();
396         size_t iw_str = is_nxc_layout ? jcp.ngroups * jcp.ic : jcp.ic_block;
397         size_t icb_str = is_nxc_layout
398                 ? jcp.ic_block
399                 : (size_t)jcp.id * jcp.ih * jcp.iw * jcp.ic_block;
400 
401         return typesize * (icb * icb_str + iw * iw_str);
402     }
403 
get_dst_offsetdnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_data_kernel_f32404     inline ptrdiff_t get_dst_offset(int iw, int oc, int kw) {
405         ptrdiff_t ow
406                 = (iw + jcp.l_pad - kw * (jcp.dilate_w + 1)) / jcp.stride_w;
407         ptrdiff_t ow_str
408                 = is_ddst_layout_nxc() ? jcp.ngroups * jcp.oc : jcp.oc_block;
409 
410         return typesize * (ow * ow_str + oc);
411     };
412 
is_dsrc_layout_nxcdnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_data_kernel_f32413     inline bool is_dsrc_layout_nxc() {
414         return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc,
415                 format_tag::nwc);
416     }
is_ddst_layout_nxcdnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_data_kernel_f32417     inline bool is_ddst_layout_nxc() {
418         return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc,
419                 format_tag::nwc);
420     }
421 };
422 
423 struct jit_sve_512_conv_bwd_weights_kernel_f32 : public jit_generator {
424 
jit_sve_512_conv_bwd_weights_kernel_f32dnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_weights_kernel_f32425     jit_sve_512_conv_bwd_weights_kernel_f32(const jit_conv_conf_t &ajcp)
426         : jcp(ajcp) {}
427 
generatednnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_weights_kernel_f32428     void generate() override {
429         if (jcp.harness != harness_nxc) {
430             generate_kernel();
431         } else {
432             assert(!"none microkernel");
433         }
434     }
435 
436     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sve_512_conv_bwd_weights_kernel_f32)
437 
438     static status_t init_conf(jit_conv_conf_t &jcp,
439             const convolution_desc_t &cd, memory_desc_t &src_md,
440             memory_desc_t &diff_weights_md, memory_desc_t &diff_bias_md,
441             memory_desc_t &diff_dst_md, int nthreads);
442     static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
443             const jit_conv_conf_t &jcp);
444 
445     jit_conv_conf_t jcp;
446 
447 private:
448     using reg64_t = const XReg;
449     enum { typesize = sizeof(float) };
450     static const int max_ur_w;
451     static const int min_oh_reduce;
452 
453     reg64_t param = abi_param1;
454     reg64_t reg_input = x1;
455     reg64_t reg_kernel = x2;
456     reg64_t reg_output = x3;
457     reg64_t b_ic = x20;
458     reg64_t kj = x5;
459     reg64_t reg_kh = x6;
460     reg64_t reg_ur_w_trips = x7;
461     reg64_t reg_oj = x8;
462     reg64_t reg_tmp = x10;
463     reg64_t reg_icb = x9;
464 
465     reg64_t ki = x11;
466     reg64_t reg_kd_count = x12;
467     reg64_t reg_oi = x12;
468     reg64_t reg_d_index = x13;
469     reg64_t reg_input_d = x8;
470     reg64_t reg_output_d = x9;
471     reg64_t aux_reg_input = x12;
472     reg64_t aux_reg_kernel = x13;
473     reg64_t reg_bias = x9;
474     reg64_t reg_oc_tail = x10;
475 
476     /* Temporary registers */
477     reg64_t reg_add_tmp = x14;
478     reg64_t reg_tmp_imm = x15;
479 
480     reg64_t reg_kd_count_org = x16;
481     reg64_t reg_input_d_org = x17;
482     reg64_t reg_output_d_org = x18;
483     reg64_t reg_d_index_org = x19;
484 
485     reg64_t reg_input_org = x24;
486     reg64_t reg_kernel_org = x22;
487     reg64_t reg_output_org = x23;
488 
489     reg64_t reg_pre_addr_input = x25;
490     reg64_t reg_pre_addr_out = x26;
491     reg64_t reg_pre_addr_ker = x26;
492     reg64_t reg_ker_start_addr = x27;
493     reg64_t reg_addr_diff_input = x28;
494 
495     const PReg reg_p_all_ones = p3;
496 
prefetchdnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_weights_kernel_f32497     void prefetch(
498             const std::string prfop, int level, reg64_t in, long long int ofs) {
499         bool for_load = false;
500         if (prfop == "LD") {
501             for_load = true;
502         } else if (prfop == "ST") {
503             for_load = false;
504         } else {
505             assert(!"invalid prfop");
506         }
507 
508         bool cacheline_aligned = ((ofs & 0xFF) == 0) ? true : false;
509         if (cacheline_aligned == true) {
510             Prfop op = PLDL1KEEP;
511             switch (level) {
512                 case 1: op = (for_load == true) ? PLDL1KEEP : PSTL1KEEP; break;
513                 case 2: op = (for_load == true) ? PLDL2KEEP : PSTL2KEEP; break;
514                 case 3: op = (for_load == true) ? PLDL3KEEP : PSTL3KEEP; break;
515                 default: assert(!"invalid prfop"); break;
516             }
517 
518             if ((ofs <= PRFMMAX) && (ofs >= 0)) {
519                 prfm(op, ptr(in, static_cast<int32_t>(ofs)));
520             } else {
521                 add_imm(reg_add_tmp, in, ofs, reg_tmp_imm);
522                 prfm(op, ptr(reg_add_tmp));
523             }
524         } else {
525             PrfopSve op_sve;
526             switch (level) {
527                 case 1:
528                     op_sve = (for_load == true) ? PLDL1KEEP_SVE : PSTL1KEEP_SVE;
529                     break;
530                 case 2:
531                     op_sve = (for_load == true) ? PLDL2KEEP_SVE : PSTL2KEEP_SVE;
532                     break;
533                 case 3:
534                     op_sve = (for_load == true) ? PLDL3KEEP_SVE : PSTL3KEEP_SVE;
535                     break;
536                 default: assert(!"invalid prfop"); break;
537             }
538 
539             if ((VL_OFS(ofs) <= PRFWMAX)
540                     && (VL_OFS(ofs) >= (-1 * PRFWMAX - 1))) {
541                 prfw(op_sve, reg_p_all_ones,
542                         ptr(in, static_cast<int32_t>(VL_OFS(ofs))));
543             } else {
544                 add_imm(reg_add_tmp, in, ofs, reg_tmp_imm);
545                 prfw(op_sve, reg_p_all_ones, ptr(reg_add_tmp));
546             }
547         }
548     }
549 
550     inline void bias_kernel_2d();
551     inline void bias_kernel_3d();
552     inline void maybe_zero_kernel();
553     inline void compute_oh_step_unroll_ow_icblock(
554             int ic_block_step, int max_ur_w);
555     inline void od_step_comeback_pointers();
556     inline void oh_step_comeback_pointers();
557     inline void compute_oh_step_unroll_ow(int ic_block_step, int max_ur_w);
558     inline void compute_ic_block_step(int ur_w, int pad_l, int pad_r,
559             int ic_block_step, int input_offset, int kernel_offset,
560             int output_offset, bool input_wraparound = false);
561     inline void compute_oh_step_common(int ic_block_step, int max_ur_w);
562     inline void compute_oh_step_disp();
563     inline void compute_oh_loop_common();
564     inline void compute_oh_loop_partial();
565     inline void compute_od_loop_partial();
566 
567     inline bool compute_full_spat_loop();
568     inline bool flat_4ops_compute();
569 
570     inline void compute_loop();
is_src_layout_nxcdnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_weights_kernel_f32571     inline bool is_src_layout_nxc() {
572         return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc,
573                 format_tag::nwc);
574     }
is_ddst_layout_nxcdnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_weights_kernel_f32575     inline bool is_ddst_layout_nxc() {
576         return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc,
577                 format_tag::nwc);
578     }
579 
get_full_src_offsetdnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_weights_kernel_f32580     inline ptrdiff_t get_full_src_offset(
581             int i_iw, int i_ic, ptrdiff_t input_offset) {
582         const bool is_nxc_layout = is_src_layout_nxc();
583         const size_t w_shift_st = (jcp.is_hw_transp ? jcp.iw : 1)
584                 * (jcp.is_1stconv ? 1 : jcp.ic_block);
585         ptrdiff_t w_shift = is_nxc_layout ? jcp.ngroups * jcp.ic : w_shift_st;
586         ptrdiff_t ic_shift = (jcp.is_1stconv && !is_nxc_layout
587                         ? (ptrdiff_t)jcp.ih * jcp.iw * jcp.id
588                         : 1);
589 
590         ptrdiff_t local_input_offset = i_iw * w_shift + i_ic * ic_shift;
591         return input_offset + typesize * local_input_offset;
592     };
593 
get_iw_idxdnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_weights_kernel_f32594     inline int get_iw_idx(int ow, int kw, int l_pad) {
595         return ow * jcp.stride_w + kw * (jcp.dilate_w + 1) - l_pad;
596     }
597 
598     void generate_kernel();
599 
600     static void balance(const jit_conv_conf_t &j, int &nthr, int &nthr_mb,
601             int &nthr_g, int &nthr_oc_b, int &nthr_ic_b, int nthreads);
602 };
603 
604 } // namespace aarch64
605 } // namespace cpu
606 } // namespace impl
607 } // namespace dnnl
608 
609 #endif
610