1 /******************************************************************************* 2 * Copyright 2020-2021 Intel Corporation 3 * Copyright 2020-2021 FUJITSU LIMITED 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 *******************************************************************************/ 17 18 #ifndef CPU_AARCH64_JIT_SVE_CONV_KERNEL_HPP 19 #define CPU_AARCH64_JIT_SVE_CONV_KERNEL_HPP 20 21 #include "common/c_types_map.hpp" 22 #include "common/memory_tracking.hpp" 23 24 #include "cpu/aarch64/jit_generator.hpp" 25 #include "cpu/aarch64/jit_primitive_conf.hpp" 26 27 #include "cpu/aarch64/injectors/jit_uni_eltwise_injector.hpp" 28 29 #include "cpu/aarch64/jit_op_imm_check.hpp" 30 31 #define LDRWMAX 252 32 #define ADDMAX 4095 33 /* Get vector offsets, ofs / VL(VL: 512bits = 64Bytes) */ 34 #define VL_OFS(ofs) ((ofs) >> 6) 35 36 using namespace Xbyak_aarch64; 37 38 namespace dnnl { 39 namespace impl { 40 namespace cpu { 41 namespace aarch64 { 42 43 struct jit_sve_512_conv_fwd_kernel : public jit_generator { 44 jit_sve_512_conv_fwd_kerneldnnl::impl::cpu::aarch64::jit_sve_512_conv_fwd_kernel45 jit_sve_512_conv_fwd_kernel( 46 const jit_conv_conf_t &ajcp, const primitive_attr_t &attr) 47 : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) { 48 49 if (jcp.with_eltwise) 50 eltwise_injector_ = new jit_uni_eltwise_injector_f32<sve_512>( 51 this, jcp.eltwise); 52 } 53 ~jit_sve_512_conv_fwd_kerneldnnl::impl::cpu::aarch64::jit_sve_512_conv_fwd_kernel54 ~jit_sve_512_conv_fwd_kernel() { delete eltwise_injector_; } 55 56 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sve_512_conv_fwd_kernel) 57 58 jit_conv_conf_t jcp; 59 const primitive_attr_t &attr_; 60 61 static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr); 62 static status_t init_conf(jit_conv_conf_t &jcp, 63 const convolution_desc_t &cd, memory_desc_t &src_pd, 64 memory_desc_t &weights_pd, memory_desc_t &dst_pd, 65 memory_desc_t &bias_pd, const primitive_attr_t &attr, int nthreads); 66 static void init_scratchpad(memory_tracking::registrar_t &scratchpad, 67 const jit_conv_conf_t &jcp); 68 69 private: 70 using reg64_t = const XReg; 71 enum { 72 typesize = sizeof(float), 73 ker_reg_base_idx = 28, 74 }; 75 76 const PReg reg_p_all_ones = p3; 77 78 reg64_t param = abi_param1; 79 reg64_t reg_inp = x1; // src base addr (2d) 80 reg64_t reg_ker = x2; // ker base addr (2d) 81 reg64_t aux_reg_ker_d = x2; // ker addr (3d) 82 reg64_t reg_out = x3; // dst base addr (2d) 83 reg64_t reg_ki = x3; // d-dim loop var? (3d) 84 reg64_t reg_owb = x5; // num of ow-block 85 reg64_t reg_out_prf = x6; // addr for prefetch 86 87 reg64_t aux_reg_inp = x7; // src addr (main loop) 88 reg64_t aux_reg_inp2 = x24; // src addr (main loop) 89 reg64_t aux_reg_inp3 = x25; // src addr (main loop) 90 reg64_t reg_out_ofs = x7; // dst addr (store_output) 91 reg64_t aux_reg_ker = x8; // ker addr (main loop) 92 reg64_t reg_channel = x9; // reduce workload 93 reg64_t reg_bias = x10; // bias addr (prepare_out) 94 95 reg64_t aux_reg_inp_d = x11; // src addr (3d) 96 reg64_t reg_oi = x11; 97 98 reg64_t reg_kh = x12; // ker h size 99 reg64_t reg_kj = x13; // ker h workload 100 101 /* Temporary registers for ARM insts */ 102 reg64_t reg_tmp_addr = x14; 103 reg64_t reg_prev_bcast_addr = x15; 104 reg64_t reg_prev_wei_addr = x16; 105 reg64_t reg_tmp_imm = x17; 106 107 reg64_t reg_out_org = x18; // dst base addr (3d) 108 reg64_t reg_oi_org = x19; // base oi (3d) 109 reg64_t aux_reg_ker_d_org = x20; 110 reg64_t reg_ker_org = x21; // ker base addr (3d) 111 reg64_t reg_inp_org = x29; // src base addr (3d) 112 prefetchdnnl::impl::cpu::aarch64::jit_sve_512_conv_fwd_kernel113 void prefetch( 114 const std::string prfop, int level, reg64_t in, long long int ofs) { 115 bool for_load = false; 116 if (prfop == "LD") { 117 for_load = true; 118 } else if (prfop == "ST") { 119 for_load = false; 120 } else { 121 assert(!"invalid prfop"); 122 } 123 124 bool cacheline_aligned = ((ofs & 0xFF) == 0) ? true : false; 125 if (cacheline_aligned == true) { 126 Prfop op = PLDL1KEEP; 127 switch (level) { 128 case 1: op = (for_load == true) ? PLDL1KEEP : PSTL1KEEP; break; 129 case 2: op = (for_load == true) ? PLDL2KEEP : PSTL2KEEP; break; 130 case 3: op = (for_load == true) ? PLDL3KEEP : PSTL3KEEP; break; 131 default: assert(!"invalid prfop"); break; 132 } 133 134 if ((ofs <= PRFMMAX) && (ofs >= 0)) { 135 prfm(op, ptr(in, static_cast<int32_t>(ofs))); 136 } else { 137 add_imm(reg_tmp_addr, in, ofs, reg_tmp_imm); 138 prfm(op, ptr(reg_tmp_addr)); 139 } 140 } else { 141 PrfopSve op_sve = PLDL1KEEP_SVE; 142 switch (level) { 143 case 1: 144 op_sve = (for_load == true) ? PLDL1KEEP_SVE : PSTL1KEEP_SVE; 145 break; 146 case 2: 147 op_sve = (for_load == true) ? PLDL2KEEP_SVE : PSTL2KEEP_SVE; 148 break; 149 case 3: 150 op_sve = (for_load == true) ? PLDL3KEEP_SVE : PSTL3KEEP_SVE; 151 break; 152 default: assert(!"invalid level"); break; 153 } 154 155 if ((VL_OFS(ofs) <= PRFWMAX) 156 && (VL_OFS(ofs) >= (-1 * PRFWMAX - 1))) { 157 prfw(op_sve, reg_p_all_ones, 158 ptr(in, static_cast<int32_t>(VL_OFS(ofs)))); 159 } else { 160 add_imm(reg_tmp_addr, in, ofs, reg_tmp_imm); 161 prfw(op_sve, reg_p_all_ones, ptr(reg_tmp_addr)); 162 } 163 } 164 } 165 166 jit_uni_eltwise_injector_f32<sve_512> *eltwise_injector_; 167 168 inline void prepare_output(int ur_w); 169 inline void store_output(int ur_w); 170 inline void compute_loop_fma_core(int ur_w, int pad_l, int pad_r); 171 inline void compute_loop(int ur_w, int pad_l, int pad_r); 172 173 void generate() override; 174 get_output_offsetdnnl::impl::cpu::aarch64::jit_sve_512_conv_fwd_kernel175 inline size_t get_output_offset(int oi, int n_oc_block) { 176 const bool is_nxc_layout = is_dst_layout_nxc(); 177 size_t ow_str = is_nxc_layout ? jcp.ngroups * jcp.oc : jcp.oc_block; 178 size_t ocb_str = is_nxc_layout 179 ? jcp.oc_block 180 : (size_t)jcp.od * jcp.oh * jcp.ow * jcp.oc_block; 181 182 return jcp.typesize_out * (n_oc_block * ocb_str + oi * ow_str); 183 } 184 get_input_offsetdnnl::impl::cpu::aarch64::jit_sve_512_conv_fwd_kernel185 inline size_t get_input_offset(int ki, int ic, int oi, int pad_l) { 186 const bool is_nxc_layout = is_src_layout_nxc(); 187 size_t iw_str = is_nxc_layout ? jcp.ngroups * jcp.ic 188 : (!jcp.is_1stconv ? jcp.ic_block : 1); 189 size_t ic_str = !jcp.is_1stconv || is_nxc_layout 190 ? 1 191 : (size_t)jcp.iw * jcp.ih * jcp.id; 192 size_t iw_idx = ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l; 193 194 return jcp.typesize_in * (iw_idx * iw_str + ic * ic_str); 195 } 196 get_kernel_offsetdnnl::impl::cpu::aarch64::jit_sve_512_conv_fwd_kernel197 inline int get_kernel_offset( 198 int ki, int ic, int n_oc_block, int ker_number) { 199 return jcp.typesize_in * jcp.oc_block 200 * (n_oc_block * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw 201 * jcp.kd 202 + (ic + ker_number) + ki * jcp.ic_block); 203 } 204 get_ow_startdnnl::impl::cpu::aarch64::jit_sve_512_conv_fwd_kernel205 inline int get_ow_start(int ki, int pad_l) { 206 return nstl::max(0, 207 utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w)); 208 } 209 get_ow_enddnnl::impl::cpu::aarch64::jit_sve_512_conv_fwd_kernel210 inline int get_ow_end(int ur_w, int ki, int pad_r) { 211 return ur_w 212 - nstl::max(0, 213 utils::div_up( 214 pad_r - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1), 215 jcp.stride_w)); 216 } is_src_layout_nxcdnnl::impl::cpu::aarch64::jit_sve_512_conv_fwd_kernel217 inline bool is_src_layout_nxc() { 218 return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc, 219 format_tag::nwc); 220 } is_dst_layout_nxcdnnl::impl::cpu::aarch64::jit_sve_512_conv_fwd_kernel221 inline bool is_dst_layout_nxc() { 222 return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc, 223 format_tag::nwc); 224 } 225 }; 226 227 struct jit_sve_512_conv_bwd_data_kernel_f32 : public jit_generator { 228 jit_sve_512_conv_bwd_data_kernel_f32dnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_data_kernel_f32229 jit_sve_512_conv_bwd_data_kernel_f32(const jit_conv_conf_t &ajcp) 230 : jcp(ajcp) {} 231 232 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sve_512_conv_bwd_data_kernel_f32) 233 jit_conv_conf_t jcp; 234 void (*jit_ker_)(jit_conv_call_s *); 235 236 static status_t init_conf(jit_conv_conf_t &jcp, 237 const convolution_desc_t &cd, memory_desc_t &diff_src_d, 238 memory_desc_t &weights_d, memory_desc_t &diff_dst_d, int nthreads); 239 static void init_scratchpad(memory_tracking::registrar_t &scratchpad, 240 const jit_conv_conf_t &jcp); 241 242 private: 243 using reg64_t = const XReg; 244 enum { 245 typesize = sizeof(float), 246 }; 247 int ker_reg_base_idx = (jcp.nb_ic_blocking == 1) ? 16 : 24; 248 249 reg64_t param = abi_param1; 250 reg64_t reg_dst = x1; 251 reg64_t reg_ker = x2; 252 reg64_t reg_src = x3; 253 254 reg64_t reg_dst_prf = x23; 255 reg64_t reg_ker_prf = x5; 256 reg64_t reg_src_prf = x6; 257 reg64_t reg_iwb = x24; 258 259 reg64_t aux_reg_dst = x7; 260 reg64_t aux_reg_ker = x8; 261 262 reg64_t aux_reg_dst_prf = x9; 263 reg64_t aux_reg_ker_prf = x10; 264 265 reg64_t aux_reg_dst_d_prf = x6; 266 reg64_t aux_reg_dst_d = x11; 267 reg64_t aux_reg_ker_d_prf = x12; 268 reg64_t aux_reg_ker_d = x2; 269 reg64_t reg_ki = x3; 270 271 reg64_t reg_kj = x13; 272 reg64_t reg_oi = x11; 273 reg64_t reg_kh = x12; 274 275 reg64_t reg_channel = x9; 276 277 reg64_t reg_tmp = x14; 278 reg64_t reg_long_offt = x7; 279 280 /* Temporary registers for ARM insts */ 281 reg64_t reg_prev_bcast_addr = x15; 282 reg64_t reg_prev_bcast_addr2 = x17; 283 reg64_t reg_prev_bcast_addr3 = x21; 284 reg64_t reg_tmp_imm = x16; 285 reg64_t reg_tmp_addr = x18; 286 287 reg64_t reg_src_prf_org = x19; 288 reg64_t reg_src_org = x20; 289 reg64_t reg_oi_org = x25; 290 reg64_t reg_dst_org = x22; 291 reg64_t reg_ker_org = x26; 292 reg64_t reg_input_org = x22; 293 reg64_t reg_kernel_org = x26; 294 295 const PReg reg_p_all_ones = p3; 296 prefetchdnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_data_kernel_f32297 long long int prefetch(const std::string prfop, int level, reg64_t in, 298 long long int ofs, long long int prev_ofs) { 299 bool for_load = false; 300 if (prfop == "LD") { 301 for_load = true; 302 } else if (prfop == "ST") { 303 for_load = false; 304 } else { 305 assert(!"invalid prfop"); 306 } 307 308 bool cacheline_aligned = ((ofs & 0xFF) == 0) ? true : false; 309 if (cacheline_aligned == true) { 310 Prfop op = PLDL1KEEP; 311 switch (level) { 312 case 1: op = (for_load == true) ? PLDL1KEEP : PSTL1KEEP; break; 313 case 2: op = (for_load == true) ? PLDL2KEEP : PSTL2KEEP; break; 314 case 3: op = (for_load == true) ? PLDL3KEEP : PSTL3KEEP; break; 315 default: assert(!"invalid prfop"); break; 316 } 317 318 long long int tmp_ofs = ofs - prev_ofs; 319 if ((ofs <= PRFMMAX) && (ofs >= 0)) { 320 prfm(op, ptr(in, static_cast<int32_t>(ofs))); 321 } else if ((tmp_ofs <= PRFMMAX) && (tmp_ofs >= 0)) { 322 prfm(op, ptr(reg_tmp_addr, static_cast<int32_t>(tmp_ofs))); 323 } else { 324 add_imm(reg_tmp_addr, in, ofs, reg_tmp_imm); 325 prfm(op, ptr(reg_tmp_addr)); 326 prev_ofs = ofs; 327 } 328 } else { 329 PrfopSve op_sve = PLDL1KEEP_SVE; 330 switch (level) { 331 case 1: 332 op_sve = (for_load == true) ? PLDL1KEEP_SVE : PSTL1KEEP_SVE; 333 break; 334 case 2: 335 op_sve = (for_load == true) ? PLDL2KEEP_SVE : PSTL2KEEP_SVE; 336 break; 337 case 3: 338 op_sve = (for_load == true) ? PLDL3KEEP_SVE : PSTL3KEEP_SVE; 339 break; 340 default: assert(!"invalid prfop"); break; 341 } 342 343 long long int tmp_ofs = ofs - prev_ofs; 344 if ((VL_OFS(ofs) <= PRFWMAX) 345 && (VL_OFS(ofs) >= (-1 * PRFWMAX - 1))) { 346 prfw(op_sve, reg_p_all_ones, 347 ptr(in, static_cast<int32_t>(VL_OFS(ofs)))); 348 } else if ((VL_OFS(tmp_ofs) <= PRFWMAX) 349 && (VL_OFS(tmp_ofs) >= (-1 * PRFWMAX - 1))) { 350 prfw(op_sve, reg_p_all_ones, 351 ptr(reg_tmp_addr, 352 static_cast<int32_t>(VL_OFS(tmp_ofs)))); 353 } else { 354 add_imm(reg_tmp_addr, in, ofs, reg_tmp_imm); 355 prfw(op_sve, reg_p_all_ones, ptr(reg_tmp_addr)); 356 prev_ofs = ofs; 357 } 358 } 359 return prev_ofs; 360 } 361 362 ZReg reg_wei = ZReg(31); 363 364 inline void prepare_output(int ur_w); 365 inline void store_output(int ur_w); 366 inline void compute_loop_fma(int ur_w, int l_overflow, int r_overflow); 367 inline void compute_loop_fma_core( 368 int ur_w, int l_overflow, int r_overflow, int k_offset); 369 inline void compute_loop( 370 int ur_w, int l_overflow, int r_overflow, int k_offset = 0); 371 void generate() override; 372 get_iw_startdnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_data_kernel_f32373 inline int get_iw_start(int ki, int l_overflow) { 374 int res = (jcp.iw - 1 + jcp.r_pad) % jcp.stride_w 375 + l_overflow * jcp.stride_w 376 - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1); 377 while (res < 0) 378 res += jcp.stride_w; 379 380 return res; 381 } 382 get_iw_enddnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_data_kernel_f32383 inline int get_iw_end(int ur_w, int ki, int r_overflow) { 384 if (utils::one_of(ur_w, jcp.iw, jcp.ur_w_tail)) 385 ur_w += nstl::min(0, jcp.r_pad); // remove negative padding 386 int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w 387 + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1); 388 while (res < 0) 389 res += jcp.stride_w; 390 391 return ur_w - res; 392 } 393 get_diff_src_offsetdnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_data_kernel_f32394 inline size_t get_diff_src_offset(int iw, int icb) { 395 const bool is_nxc_layout = is_dsrc_layout_nxc(); 396 size_t iw_str = is_nxc_layout ? jcp.ngroups * jcp.ic : jcp.ic_block; 397 size_t icb_str = is_nxc_layout 398 ? jcp.ic_block 399 : (size_t)jcp.id * jcp.ih * jcp.iw * jcp.ic_block; 400 401 return typesize * (icb * icb_str + iw * iw_str); 402 } 403 get_dst_offsetdnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_data_kernel_f32404 inline ptrdiff_t get_dst_offset(int iw, int oc, int kw) { 405 ptrdiff_t ow 406 = (iw + jcp.l_pad - kw * (jcp.dilate_w + 1)) / jcp.stride_w; 407 ptrdiff_t ow_str 408 = is_ddst_layout_nxc() ? jcp.ngroups * jcp.oc : jcp.oc_block; 409 410 return typesize * (ow * ow_str + oc); 411 }; 412 is_dsrc_layout_nxcdnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_data_kernel_f32413 inline bool is_dsrc_layout_nxc() { 414 return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc, 415 format_tag::nwc); 416 } is_ddst_layout_nxcdnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_data_kernel_f32417 inline bool is_ddst_layout_nxc() { 418 return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc, 419 format_tag::nwc); 420 } 421 }; 422 423 struct jit_sve_512_conv_bwd_weights_kernel_f32 : public jit_generator { 424 jit_sve_512_conv_bwd_weights_kernel_f32dnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_weights_kernel_f32425 jit_sve_512_conv_bwd_weights_kernel_f32(const jit_conv_conf_t &ajcp) 426 : jcp(ajcp) {} 427 generatednnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_weights_kernel_f32428 void generate() override { 429 if (jcp.harness != harness_nxc) { 430 generate_kernel(); 431 } else { 432 assert(!"none microkernel"); 433 } 434 } 435 436 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sve_512_conv_bwd_weights_kernel_f32) 437 438 static status_t init_conf(jit_conv_conf_t &jcp, 439 const convolution_desc_t &cd, memory_desc_t &src_md, 440 memory_desc_t &diff_weights_md, memory_desc_t &diff_bias_md, 441 memory_desc_t &diff_dst_md, int nthreads); 442 static void init_scratchpad(memory_tracking::registrar_t &scratchpad, 443 const jit_conv_conf_t &jcp); 444 445 jit_conv_conf_t jcp; 446 447 private: 448 using reg64_t = const XReg; 449 enum { typesize = sizeof(float) }; 450 static const int max_ur_w; 451 static const int min_oh_reduce; 452 453 reg64_t param = abi_param1; 454 reg64_t reg_input = x1; 455 reg64_t reg_kernel = x2; 456 reg64_t reg_output = x3; 457 reg64_t b_ic = x20; 458 reg64_t kj = x5; 459 reg64_t reg_kh = x6; 460 reg64_t reg_ur_w_trips = x7; 461 reg64_t reg_oj = x8; 462 reg64_t reg_tmp = x10; 463 reg64_t reg_icb = x9; 464 465 reg64_t ki = x11; 466 reg64_t reg_kd_count = x12; 467 reg64_t reg_oi = x12; 468 reg64_t reg_d_index = x13; 469 reg64_t reg_input_d = x8; 470 reg64_t reg_output_d = x9; 471 reg64_t aux_reg_input = x12; 472 reg64_t aux_reg_kernel = x13; 473 reg64_t reg_bias = x9; 474 reg64_t reg_oc_tail = x10; 475 476 /* Temporary registers */ 477 reg64_t reg_add_tmp = x14; 478 reg64_t reg_tmp_imm = x15; 479 480 reg64_t reg_kd_count_org = x16; 481 reg64_t reg_input_d_org = x17; 482 reg64_t reg_output_d_org = x18; 483 reg64_t reg_d_index_org = x19; 484 485 reg64_t reg_input_org = x24; 486 reg64_t reg_kernel_org = x22; 487 reg64_t reg_output_org = x23; 488 489 reg64_t reg_pre_addr_input = x25; 490 reg64_t reg_pre_addr_out = x26; 491 reg64_t reg_pre_addr_ker = x26; 492 reg64_t reg_ker_start_addr = x27; 493 reg64_t reg_addr_diff_input = x28; 494 495 const PReg reg_p_all_ones = p3; 496 prefetchdnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_weights_kernel_f32497 void prefetch( 498 const std::string prfop, int level, reg64_t in, long long int ofs) { 499 bool for_load = false; 500 if (prfop == "LD") { 501 for_load = true; 502 } else if (prfop == "ST") { 503 for_load = false; 504 } else { 505 assert(!"invalid prfop"); 506 } 507 508 bool cacheline_aligned = ((ofs & 0xFF) == 0) ? true : false; 509 if (cacheline_aligned == true) { 510 Prfop op = PLDL1KEEP; 511 switch (level) { 512 case 1: op = (for_load == true) ? PLDL1KEEP : PSTL1KEEP; break; 513 case 2: op = (for_load == true) ? PLDL2KEEP : PSTL2KEEP; break; 514 case 3: op = (for_load == true) ? PLDL3KEEP : PSTL3KEEP; break; 515 default: assert(!"invalid prfop"); break; 516 } 517 518 if ((ofs <= PRFMMAX) && (ofs >= 0)) { 519 prfm(op, ptr(in, static_cast<int32_t>(ofs))); 520 } else { 521 add_imm(reg_add_tmp, in, ofs, reg_tmp_imm); 522 prfm(op, ptr(reg_add_tmp)); 523 } 524 } else { 525 PrfopSve op_sve; 526 switch (level) { 527 case 1: 528 op_sve = (for_load == true) ? PLDL1KEEP_SVE : PSTL1KEEP_SVE; 529 break; 530 case 2: 531 op_sve = (for_load == true) ? PLDL2KEEP_SVE : PSTL2KEEP_SVE; 532 break; 533 case 3: 534 op_sve = (for_load == true) ? PLDL3KEEP_SVE : PSTL3KEEP_SVE; 535 break; 536 default: assert(!"invalid prfop"); break; 537 } 538 539 if ((VL_OFS(ofs) <= PRFWMAX) 540 && (VL_OFS(ofs) >= (-1 * PRFWMAX - 1))) { 541 prfw(op_sve, reg_p_all_ones, 542 ptr(in, static_cast<int32_t>(VL_OFS(ofs)))); 543 } else { 544 add_imm(reg_add_tmp, in, ofs, reg_tmp_imm); 545 prfw(op_sve, reg_p_all_ones, ptr(reg_add_tmp)); 546 } 547 } 548 } 549 550 inline void bias_kernel_2d(); 551 inline void bias_kernel_3d(); 552 inline void maybe_zero_kernel(); 553 inline void compute_oh_step_unroll_ow_icblock( 554 int ic_block_step, int max_ur_w); 555 inline void od_step_comeback_pointers(); 556 inline void oh_step_comeback_pointers(); 557 inline void compute_oh_step_unroll_ow(int ic_block_step, int max_ur_w); 558 inline void compute_ic_block_step(int ur_w, int pad_l, int pad_r, 559 int ic_block_step, int input_offset, int kernel_offset, 560 int output_offset, bool input_wraparound = false); 561 inline void compute_oh_step_common(int ic_block_step, int max_ur_w); 562 inline void compute_oh_step_disp(); 563 inline void compute_oh_loop_common(); 564 inline void compute_oh_loop_partial(); 565 inline void compute_od_loop_partial(); 566 567 inline bool compute_full_spat_loop(); 568 inline bool flat_4ops_compute(); 569 570 inline void compute_loop(); is_src_layout_nxcdnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_weights_kernel_f32571 inline bool is_src_layout_nxc() { 572 return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc, 573 format_tag::nwc); 574 } is_ddst_layout_nxcdnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_weights_kernel_f32575 inline bool is_ddst_layout_nxc() { 576 return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc, 577 format_tag::nwc); 578 } 579 get_full_src_offsetdnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_weights_kernel_f32580 inline ptrdiff_t get_full_src_offset( 581 int i_iw, int i_ic, ptrdiff_t input_offset) { 582 const bool is_nxc_layout = is_src_layout_nxc(); 583 const size_t w_shift_st = (jcp.is_hw_transp ? jcp.iw : 1) 584 * (jcp.is_1stconv ? 1 : jcp.ic_block); 585 ptrdiff_t w_shift = is_nxc_layout ? jcp.ngroups * jcp.ic : w_shift_st; 586 ptrdiff_t ic_shift = (jcp.is_1stconv && !is_nxc_layout 587 ? (ptrdiff_t)jcp.ih * jcp.iw * jcp.id 588 : 1); 589 590 ptrdiff_t local_input_offset = i_iw * w_shift + i_ic * ic_shift; 591 return input_offset + typesize * local_input_offset; 592 }; 593 get_iw_idxdnnl::impl::cpu::aarch64::jit_sve_512_conv_bwd_weights_kernel_f32594 inline int get_iw_idx(int ow, int kw, int l_pad) { 595 return ow * jcp.stride_w + kw * (jcp.dilate_w + 1) - l_pad; 596 } 597 598 void generate_kernel(); 599 600 static void balance(const jit_conv_conf_t &j, int &nthr, int &nthr_mb, 601 int &nthr_g, int &nthr_oc_b, int &nthr_ic_b, int nthreads); 602 }; 603 604 } // namespace aarch64 605 } // namespace cpu 606 } // namespace impl 607 } // namespace dnnl 608 609 #endif 610