1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 * Copyright 2021 FUJITSU LIMITED
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *******************************************************************************/
17 
18 #include "common/c_types_map.hpp"
19 #include "common/memory.hpp"
20 #include "common/memory_tracking.hpp"
21 #include "common/nstl.hpp"
22 #include "common/type_helpers.hpp"
23 #include "common/utils.hpp"
24 
25 #include "cpu/aarch64/jit_sve_512_x8s8s32x_conv_kernel.hpp"
26 
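// GET_OFF(field) expands to the byte offset of `field` inside the
// jit_conv_call_s argument block; kernel arguments are read below as, e.g.,
//     ldr(reg_bias, ptr(reg_param1, GET_OFF(bias)));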
27 #define GET_OFF(field) static_cast<int32_t>(offsetof(jit_conv_call_s, field))
28 
29 namespace dnnl {
30 namespace impl {
31 namespace cpu {
32 namespace aarch64 {
33 
34 using namespace dnnl::impl::memory_tracking::names;
35 using namespace dnnl::impl::utils;
36 using namespace dnnl::impl::data_type;
37 
38 namespace {
39 void pick_loop_order(jit_conv_conf_t &jcp, int nthr) {
40     jcp.loop_order = loop_cwgn;
41     if (jcp.ngroups > 1) {
42         jcp.loop_order = loop_ngcw;
43         if (jcp.mb < nthr)
44             jcp.loop_order = jcp.ndims == 3 ? loop_nwcg : loop_nhwcg;
45     }
46 }
47 } // namespace
48 
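// prepare_output() zeroes the ur_w x nb_oc_block accumulator registers.
// For an unsigned (u8) source (!jcp.signed_input) it also places the 128
// shift value in vmm_shift, which is used below to move the u8 inputs into
// the signed range; the compensation read in store_output() offsets that
// shift again.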
49 void jit_sve_512_x8s8s32x_fwd_kernel::prepare_output(int ur_w) {
50     int nb_oc_block
51             = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
52     for (int k = 0; k < nb_oc_block; k++)
53         for (int j = 0; j < ur_w; j++) {
54             auto vmm = vmm_out(j, k);
55             eor(vmm.d, vmm.d, vmm.d);
56         }
57     if (!jcp.signed_input) {
58         eor(reg_scratch, reg_scratch, reg_scratch);
59         if (jcp.is_depthwise && !jcp.is_fast_depthwise) {
60             mov_imm(WReg(reg_tmp0_imm.getIdx()), 128);
61             dup(vmm_shift.s, WReg(reg_tmp0_imm.getIdx()));
62         } else {
63             dup(vmm_shift.b, -128);
64         }
65     }
66 }
67 
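// cvt2ps() loads a vector of `type_in` from reg_base + offset into vmm_in and
// converts it to f32.  f32/s32 use ld1w (under ktail_mask when mask_flag is
// set); s8/u8 are loaded as packed bytes, widened with zip1 + sxtb/uxtb, and
// the tail lanes are cleared explicitly.  vmm_tmp is spilled to the stack
// around the byte paths because it is used as scratch.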
68 void jit_sve_512_x8s8s32x_fwd_kernel::cvt2ps(data_type_t type_in,
69         const ZReg vmm_in, const XReg reg_base, const int offset,
70         bool mask_flag) {
71 
72     auto vmm = vmm_in;
73     auto reg_addr = get_comp_addr_reg(reg_base, offset);
74     switch (type_in) {
75         case data_type::f32:
76         case data_type::s32:
77             if (mask_flag)
78                 ld1w(vmm.s, ktail_mask / T_z, ptr(reg_addr));
79             else
80                 ld1w(vmm.s, mask_all_one, ptr(reg_addr));
81             break;
82         case data_type::s8:
83             sub(reg_stack, reg_stack, 64);
84             str(vmm_tmp, ptr(reg_stack));
85             vmm_load_src(vmm_tmp, reg_addr, mask_flag);
86             zip1(vmm_tmp.b, vmm_tmp.b, vmm_tmp.b);
87             zip1(vmm_tmp.h, vmm_tmp.h, vmm_tmp.h);
88             sxtb(vmm.s, mask_all_one / T_m, vmm_tmp.s);
89             if (mask_flag) {
90                 not_(mask_tmp.b, mask_all_one.b, ktail_mask.b);
91                 mov(vmm.s, mask_tmp / T_m, 0);
92             }
93             ldr(vmm_tmp, ptr(reg_stack));
94             add(reg_stack, reg_stack, 64);
95             break;
96         case data_type::u8:
97             sub(reg_stack, reg_stack, 64);
98             str(vmm_tmp, ptr(reg_stack));
99             vmm_load_src(vmm_tmp, reg_addr, mask_flag);
100             zip1(vmm_tmp.b, vmm_tmp.b, vmm_tmp.b);
101             zip1(vmm_tmp.h, vmm_tmp.h, vmm_tmp.h);
102             uxtb(vmm.s, mask_all_one / T_m, vmm_tmp.s);
103             if (mask_flag) {
104                 not_(mask_tmp.b, mask_all_one.b, ktail_mask.b);
105                 mov(vmm.s, mask_tmp / T_m, 0);
106             }
107             ldr(vmm_tmp, ptr(reg_stack));
108             add(reg_stack, reg_stack, 64);
109             break;
110         default: assert(!"unsupported data type");
111     }
112     if (type_in != data_type::f32) scvtf(vmm_in.s, mask_all_one, vmm_in.s);
113 }
114 
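// store_output() finalizes one ur_w x nb_oc_block tile: the s32 accumulators
// are converted to f32, the zero-point compensation is subtracted (u8 source
// only), bias is added, the per-channel output scales and the optional sum
// post-op are applied, integer destinations are saturated and rounded, and
// the results are written with st1w/st1b, masking the oc tail on the last
// block.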
115 void jit_sve_512_x8s8s32x_fwd_kernel::store_output(
116         int ur_w, bool last_oc_block_flag) {
117     int nb_oc_block
118             = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
119     int oc_block = jcp.is_depthwise ? jcp.ch_block : jcp.oc_block;
120 
121     ldr(reg_bias, ptr(reg_param1, GET_OFF(bias)));
122     ldr(reg_ptr_scales, ptr(reg_param1, GET_OFF(scales)));
123     if (!jcp.signed_input)
124         ldr(reg_compensation, ptr(reg_param1, GET_OFF(compensation)));
125 
126     const auto &p = attr_.post_ops_;
127     const int sum_idx = p.find(primitive_kind::sum);
128     const float *p_sum_scale = nullptr;
129     if (sum_idx != -1) {
130         const auto &p_entry = p.entry_[sum_idx];
131         p_sum_scale = &p_entry.sum.scale;
132     }
133 
134     if (p_sum_scale && *p_sum_scale != 1.f)
135         mov_imm(reg_ptr_sum_scale, (size_t)p_sum_scale);
136 
137     for (int k = 0; k < nb_oc_block; k++) {
138         const bool mask_flag
139                 = last_oc_block_flag && k == nb_oc_block - 1 && mask_gflag;
140         int scale_offset = jcp.is_oc_scale * (sizeof(float) * k * oc_block);
141         if (jcp.with_bias) {
142             int bias_offset = jcp.typesize_bia * k * oc_block;
143 
144             cvt2ps(jcp.bia_dt, vmm_bias, reg_bias, bias_offset, mask_flag);
145         }
146         if (!jcp.signed_input) {
147             int comp_offset = sizeof(int32_t) * k * oc_block;
148 
149             cvt2ps(data_type::s32, vmm_comp, reg_compensation, comp_offset,
150                     mask_flag);
151         }
152         /* when possible, preload this oc block's scales once and reuse
                them for every ur_w output below */
153         if (!jcp.is_fast_depthwise && jcp.signed_input) {
154             auto reg_addr = get_comp_addr_reg(reg_ptr_scales, scale_offset);
155             ld1w(vmm_pre_load.s, mask_all_one, ptr(reg_addr));
156         }
157         /* add to accum: compensation, bias and permute */
158         for (int j = 0; j < ur_w; j++) {
159             auto vmm = vmm_out(j, k);
160             if (jcp.is_fast_depthwise) {
161                 auto zmm = zmm_out(j, k);
162                 auto zmm_tmp1 = ZReg(31);
163                 auto zmm_tmp2 = ZReg(30);
164                 auto zmm_tmp3 = ZReg(29);
165                 sub(reg_stack, reg_stack, 64);
166                 str(zmm_tmp1, ptr(reg_stack));
167                 sub(reg_stack, reg_stack, 64);
168                 str(zmm_tmp2, ptr(reg_stack));
169                 sub(reg_stack, reg_stack, 64);
170                 str(zmm_tmp3, ptr(reg_stack));
171                 mov(zmm_tmp1.s, 15);
172                 and_(zmm_tmp1.b, mask_all_one, zmm_permute.b);
173                 for (int i = 0; i < 16; i++) {
174                     cmpeq(mask_tmp.s, mask_all_one, zmm_tmp1.s, i);
175                     dup(zmm_tmp2.s, zmm.s[i]);
176                     mov(zmm_tmp3.s, mask_tmp / T_m, zmm_tmp2.s);
177                 }
178                 mov(zmm.d, zmm_tmp3.d);
179                 ldr(zmm_tmp3, ptr(reg_stack));
180                 add(reg_stack, reg_stack, 64);
181                 ldr(zmm_tmp2, ptr(reg_stack));
182                 add(reg_stack, reg_stack, 64);
183                 ldr(zmm_tmp1, ptr(reg_stack));
184                 add(reg_stack, reg_stack, 64);
185             }
186             scvtf(vmm.s, mask_all_one, vmm.s);
187             if (!jcp.signed_input) fsub(vmm.s, vmm.s, vmm_comp.s);
188             if (jcp.with_bias) fadd(vmm.s, vmm.s, vmm_bias.s);
189 
190             if (!jcp.is_fast_depthwise && jcp.signed_input) {
191                 /* multiply by the scales preloaded above for this oc block */
192                 fmul(vmm.s, vmm.s, vmm_pre_load.s);
193                 if (mask_flag) {
194                     not_(mask_tmp.b, mask_all_one.b, ktail_mask.b);
195                     mov(vmm.s, mask_tmp / T_m, 0);
196                 }
197             } else {
198                 auto reg_addr = get_comp_addr_reg(reg_ptr_scales, scale_offset);
199                 sub(reg_stack, reg_stack, 64);
200                 str(vmm_tmp, ptr(reg_stack));
201                 ld1w(vmm_tmp.s, mask_all_one, ptr(reg_addr));
202                 fmul(vmm.s, vmm.s, vmm_tmp.s);
203                 ldr(vmm_tmp, ptr(reg_stack));
204                 add(reg_stack, reg_stack, 64);
205                 if (mask_flag) {
206                     not_(mask_tmp.b, mask_all_one.b, ktail_mask.b);
207                     mov(vmm.s, mask_tmp / T_m, 0);
208                 }
209             }
210         }
211     }
212 
213     /* Do post-ops */
214     if (p_sum_scale) { // post_op: sum
215         for (int k = 0; k < nb_oc_block; k++) {
216             const bool mask_flag
217                     = last_oc_block_flag && k == nb_oc_block - 1 && mask_gflag;
218             for (int j = 0; j < ur_w; j++) {
219                 int aux_output_offset = jcp.typesize_out
220                         * (k * oc_block
221                                 + j * jcp.oc_without_padding * jcp.ngroups);
222                 auto vmm = vmm_out(j, k);
223                 cvt2ps(jcp.dst_dt, vmm_prev_dst, reg_out, aux_output_offset,
224                         mask_flag);
225                 if (*p_sum_scale == 1.f) {
226                     fadd(vmm.s, vmm.s, vmm_prev_dst.s);
227                 } else {
228                     sub(reg_stack, reg_stack, 64);
229                     str(vmm_tmp, ptr(reg_stack));
230                     ld1rw(vmm_tmp.s, mask_all_one / T_z,
231                             ptr(reg_ptr_sum_scale));
232                     fmla(vmm.s, mask_all_one / T_m, vmm_prev_dst.s, vmm_tmp.s);
233                     ldr(vmm_tmp, ptr(reg_stack));
234                     add(reg_stack, reg_stack, 64);
235                 }
236             }
237         }
238     }
239 
240     // Properly saturate the accumulators for integer datatypes
241     if (one_of(jcp.dst_dt, data_type::u8, data_type::s8, data_type::s32)) {
242         if (jcp.dst_dt == data_type::u8) {
243             eor(vmm_zero.d, vmm_zero.d, vmm_zero.d);
244         }
245         float saturation_ubound = types::max_value<float>(jcp.dst_dt);
246         mov_imm(aux_reg_saturation, float2int(saturation_ubound));
247         dup(vmm_saturation.s, WReg(aux_reg_saturation.getIdx()));
248 
249         for (int k = 0; k < nb_oc_block; k++) {
250             for (int j = 0; j < ur_w; j++) {
251                 auto vmm = vmm_out(j, k);
252                 if (jcp.dst_dt == data_type::u8) {
253                     fmaxnm(vmm.s, mask_all_one, vmm_zero.s);
254                     fmax(vmm.s, mask_all_one, vmm_zero.s);
255                 }
256                 fminnm(vmm.s, mask_all_one, vmm_saturation.s);
257                 fmin(vmm.s, mask_all_one, vmm_saturation.s);
258 
259                 frintn(vmm.s, mask_all_one, vmm.s);
260                 fcvtzs(vmm.s, mask_all_one, vmm.s);
261             }
262         }
263     }
264 
265     /* write out register to output_addr */
266     for (int k = 0; k < nb_oc_block; k++) {
267         const bool mask_flag
268                 = last_oc_block_flag && k == nb_oc_block - 1 && mask_gflag;
269         for (int j = 0; j < ur_w; j++) {
270             int aux_output_offset = jcp.typesize_out
271                     * (k * oc_block + j * jcp.oc_without_padding * jcp.ngroups);
272 
273             auto base = reg_out;
274             auto re = get_offset(aux_output_offset);
275 
276             auto reg_tmp_adr = ((j % 4) == 0) ? reg_tmp0_adr
277                                               : ((j % 4) == 1)
278                             ? reg_tmp1_adr
279                             : ((j % 4) == 2) ? reg_tmp2_adr : reg_tmp3_adr;
280             auto reg_tmp_imm = ((j % 4) == 0) ? reg_tmp0_imm
281                                               : ((j % 4) == 1)
282                             ? reg_tmp1_imm
283                             : ((j % 4) == 2) ? reg_tmp2_imm : reg_tmp3_imm;
284             add_imm(reg_tmp_adr, base, re, reg_tmp_imm);
285 
286             auto vmm = vmm_out(j, k);
287 
288             auto _mask = mask_flag ? ktail_mask : mask_all_one;
289             switch (jcp.dst_dt) {
290                 case data_type::f32:
291                 case data_type::s32:
292                     st1w(vmm.s, _mask, ptr(reg_tmp_adr));
293                     break;
294                 case data_type::s8:
295                     smin(vmm.s, 127);
296                     smax(vmm.s, -128);
297                     st1b(vmm.s, _mask, ptr(reg_tmp_adr));
298                     break;
299                 case data_type::u8:
300                     umin(vmm.s, 255);
301                     st1b(vmm.s, _mask, ptr(reg_tmp_adr));
302                     break;
303                 default: assert(!"unknown dst_dt");
304             }
305         }
306     }
307 }
308 
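// Depthwise kernel body: each 32-bit lane accumulates one channel with
// sdot(acc.s, src.b, wei.b).  With is_resrc_depthwise the input row is loaded
// once per channel block and reused for all kw taps; with is_fast_depthwise
// 16 weight bytes are fetched with a 128-bit ldr and replicated across the
// vector with splice.  For a u8 source, padded positions use zmm_shifted_zero
// (0 minus vmm_shift) so the compensation stays consistent.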
309 void jit_sve_512_x8s8s32x_fwd_kernel::compute_ker_dw(int ur_w, int pad_l,
310         int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
311 
312     if (sve_len_ != 64)
313         assert(!"invalid group blocking for depthwise convolution");
314 
315     auto input_spatial_index = [=](int oi, int ki) {
316         return (ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l);
317     };
318 
319     auto input_offset2 = [=](int ii, int ci) {
320         if (jcp.is_fused_conv)
321             return jcp.typesize_in
322                     * (ii * jcp.dw_conv_buffer_oc + ci * jcp.ch_block);
323         else
324             return jcp.typesize_in * (ii * jcp.ngroups + ci * jcp.ch_block);
325     };
326 
327     auto input_offset3 = [=](int oi, int ci, int ki) {
328         return jcp.typesize_in * input_offset2(input_spatial_index(oi, ki), ci);
329     };
330 
331     auto kernel_offset = [=](int ci, int ki) {
332         return jcp.typesize_in * ((ci * jcp.kh * jcp.kw + ki) * jcp.ch_block);
333     };
334 
335     auto compute = [=](ZReg vreg_acc, ZReg vreg_wei, ZReg vreg_src) {
336         sdot(vreg_acc.s, vreg_src.b, vreg_wei.b);
337     };
338 
339     int ii_start = 0;
340     int ii_end = -1;
341     if (jcp.is_resrc_depthwise && !h_padded) {
342         // find bounds of input spatial indices
343         bool first = true;
344         for (int ki = 0; ki < jcp.kw; ki++) {
345             int oi_start = get_ow_start(ki, pad_l);
346             int oi_end = get_ow_end(ur_w, ki, pad_r);
347             for (int oi = oi_start; oi < oi_end; oi++) {
348                 int ii = input_spatial_index(oi, ki);
349                 if (first || ii < ii_start) ii_start = ii;
350                 if (first || ii > ii_end) ii_end = ii;
351                 first = false;
352             }
353         }
354     }
355 
356     if (!jcp.signed_input) {
357         eor(zmm_shifted_zero.d, zmm_shifted_zero.d, zmm_shifted_zero.d);
358         sub(zmm_shifted_zero.b, zmm_shifted_zero.b, vmm_shift.b);
359     }
360 
361     for (int ci = 0; ci < jcp.nb_ch_blocking; ci++) {
362         const bool mask_flag = last_ic_block_flag != no_last_block
363                 && ci == jcp.nb_ch_blocking - 1;
364         if (jcp.is_resrc_depthwise && !h_padded) {
365             // now we can load input once and reuse up to jcp.kw times
366             for (int ii = ii_start; ii <= ii_end; ii++) {
367                 int aux_input_offset = input_offset2(ii, ci);
368                 auto zmm_inp_tmp = zmm_inp(ii, jcp.nb_ch_blocking);
369                 auto zmm_inp_msk = zmm_inp_tmp;
370                 if (jcp.is_fast_depthwise) {
371                     assert(!mask_flag);
372                     auto reg_addr
373                             = get_comp_addr_reg(aux_reg_inp, aux_input_offset);
374                     ldr(QReg(zmm_inp_msk.getIdx()), ptr(reg_addr));
375                     ptrue(mask_tmp.d, VL2);
376                     splice(zmm_inp_msk.d, mask_tmp.d, zmm_inp_msk.d);
377                     ptrue(mask_tmp.d, VL4);
378                     splice(zmm_inp_msk.d, mask_tmp.d, zmm_inp_msk.d);
379                 } else {
380                     auto reg_addr
381                             = get_comp_addr_reg(aux_reg_inp, aux_input_offset);
382                     auto zmm_tmp = ZReg(31);
383                     sub(reg_stack, reg_stack, 64);
384                     str(zmm_tmp, ptr(reg_stack));
385                     if (mask_flag) {
386                         eor(mask_tmp.b, mask_all_one, mask_tmp.b, mask_tmp.b);
387                         eor(mask_tmp2.b, mask_all_one, mask_tmp2.b,
388                                 mask_tmp2.b);
389                         uzp1(mask_tmp.h, ktail_mask.h, mask_tmp.h);
390                         uzp1(mask_tmp.b, mask_tmp.b, mask_tmp2.b);
391                     } else {
392                         ptrue(mask_tmp.b, VL16);
393                     }
394                     ld1b(zmm_tmp.b, mask_tmp, ptr(reg_addr));
395                     zip1(zmm_tmp.b, zmm_tmp.b, zmm_tmp.b);
396                     zip1(zmm_tmp.h, zmm_tmp.h, zmm_tmp.h);
397                     uxtb(zmm_inp_msk.s, mask_all_one / T_m, zmm_tmp.s);
398                     if (mask_flag) {
399                         not_(mask_tmp.b, mask_all_one.b, ktail_mask.b);
400                         mov(zmm_inp_msk.s, mask_tmp / T_m, 0);
401                     }
402                     ldr(zmm_tmp, ptr(reg_stack));
403                     add(reg_stack, reg_stack, 64);
404                 }
405                 if (!jcp.signed_input)
406                     sub(zmm_inp_tmp.b, zmm_inp_tmp.b, vmm_shift.b);
407             }
408         }
409         for (int ki = 0; ki < jcp.kw; ki++) {
410             int aux_kernel_offset = kernel_offset(ci, ki);
411             if (jcp.is_fast_depthwise) {
412                 auto reg_addr
413                         = get_comp_addr_reg(aux_reg_ker, aux_kernel_offset);
414                 ldr(QReg(zmm_wei.getIdx()), ptr(reg_addr));
415                 ptrue(mask_tmp.d, VL2);
416                 splice(zmm_wei.d, mask_tmp.d, zmm_wei.d);
417                 ptrue(mask_tmp.d, VL4);
418                 splice(zmm_wei.d, mask_tmp.d, zmm_wei.d);
419                 not_(mask_tmp.b, mask_all_one, kblend_mask.b);
420                 mov(zmm_wei.b, kblend_mask / T_m, zmm_wei.b);
421                 mov(zmm_wei.b, mask_tmp / T_m, 0);
422             } else {
423                 auto reg_addr
424                         = get_comp_addr_reg(aux_reg_ker, aux_kernel_offset);
425                 auto zmm_tmp = ZReg(30);
426                 sub(reg_stack, reg_stack, 64);
427                 str(zmm_tmp, ptr(reg_stack));
428                 ldr(QReg(zmm_tmp.getIdx()), ptr(reg_addr));
429                 zip1(zmm_tmp.b, zmm_tmp.b, zmm_tmp.b);
430                 zip1(zmm_tmp.h, zmm_tmp.h, zmm_tmp.h);
431                 sxtb(zmm_wei.s, mask_all_one / T_m, zmm_tmp.s);
432                 ldr(zmm_tmp, ptr(reg_stack));
433                 add(reg_stack, reg_stack, 64);
434             }
435             if (h_padded) {
436                 assert(!jcp.signed_input);
437                 for (int oi = 0; oi < ur_w; oi++)
438                     compute(zmm_out(oi, ci), zmm_wei, zmm_shifted_zero);
439             } else {
440                 auto r_zmm_src = zmm_src;
441                 int oi_start = get_ow_start(ki, pad_l);
442                 int oi_end = get_ow_end(ur_w, ki, pad_r);
443                 int start_ = !jcp.signed_input ? 0 : oi_start;
444                 int end_ = !jcp.signed_input ? ur_w : oi_end;
445                 for (int oi = start_; oi < end_; oi++) {
446                     if (oi >= oi_start && oi < oi_end) {
447                         if (jcp.is_resrc_depthwise) {
448                             int ii = input_spatial_index(oi, ki);
449                             zmm_src = zmm_inp(ii, jcp.nb_ch_blocking);
450                         } else {
451                             int aux_input_offset = input_offset3(oi, ci, ki);
452                             if (jcp.is_fast_depthwise) {
453                                 assert(!mask_flag);
454                                 auto reg_addr = get_comp_addr_reg(
455                                         aux_reg_inp, aux_input_offset);
456                                 ldr(QReg(r_zmm_src.getIdx()), ptr(reg_addr));
457                                 ptrue(mask_tmp.d, VL2);
458                                 splice(r_zmm_src.d, mask_tmp.d, r_zmm_src.d);
459                                 ptrue(mask_tmp.d, VL4);
460                                 splice(r_zmm_src.d, mask_tmp.d, r_zmm_src.d);
461                             } else {
462                                 auto reg_addr = get_comp_addr_reg(
463                                         aux_reg_inp, aux_input_offset);
464                                 auto zmm_tmp = ZReg(31);
465                                 sub(reg_stack, reg_stack, 64);
466                                 str(zmm_tmp, ptr(reg_stack));
467                                 if (mask_flag) {
468                                     eor(mask_tmp.b, mask_all_one, mask_tmp.b,
469                                             mask_tmp.b);
470                                     eor(mask_tmp2.b, mask_all_one, mask_tmp2.b,
471                                             mask_tmp2.b);
472                                     uzp1(mask_tmp.h, ktail_mask.h, mask_tmp.h);
473                                     uzp1(mask_tmp.b, mask_tmp.b, mask_tmp2.b);
474                                 } else {
475                                     ptrue(mask_tmp.b, VL16);
476                                 }
477                                 ld1b(zmm_tmp.b, mask_tmp, ptr(reg_addr));
478                                 zip1(zmm_tmp.b, zmm_tmp.b, zmm_tmp.b);
479                                 zip1(zmm_tmp.h, zmm_tmp.h, zmm_tmp.h);
480                                 uxtb(r_zmm_src.s, mask_all_one / T_m,
481                                         zmm_tmp.s);
482                                 if (mask_flag) {
483                                     not_(mask_tmp.b, mask_all_one.b,
484                                             ktail_mask.b);
485                                     mov(r_zmm_src.s, mask_tmp / T_m, 0);
486                                 }
487                                 ldr(zmm_tmp, ptr(reg_stack));
488                                 add(reg_stack, reg_stack, 64);
489                             }
490                             if (!jcp.signed_input)
491                                 sub(zmm_src.b, zmm_src.b, vmm_shift.b);
492                         }
493                         compute(zmm_out(oi, ci), zmm_wei, zmm_src);
494                     } else {
495                         assert(!jcp.signed_input);
496                         compute(zmm_out(oi, ci), zmm_wei, zmm_shifted_zero);
497                     }
498                 }
499             }
500         }
501     }
502 }
503 
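// Generic (non-depthwise) kernel body: input channels are processed in groups
// of four.  ld1rw broadcasts one 4-byte ic group of the input to every 32-bit
// lane, ld1w fetches the matching 16o x 4i block of int8 weights (4i16o4i
// layout), and sdot accumulates the four products per lane into the s32
// accumulators.  For a u8 source, inputs are shifted by vmm_shift and padded
// columns are filled with the shifted zero value.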
504 void jit_sve_512_x8s8s32x_fwd_kernel::compute_ker(int ur_w, int pad_l,
505         int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
506     if (jcp.is_depthwise)
507         return compute_ker_dw(ur_w, pad_l, pad_r, last_ic_block_flag, h_padded);
508 
509     int kw = jcp.kw;
510     int stride_w = jcp.stride_w;
511     int ic_block = jcp.ic_block;
512     int oc_block = jcp.oc_block;
513     int ch_block_all = jcp.ch_block * ic_block * oc_block;
514 
515     int nb_oc_block = jcp.nb_oc_blocking;
516 
517     auto input_offset = [=](int oi, int ic, int ki) {
518         return jcp.typesize_in
519                 * ((ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l)
520                                 * jcp.ic_without_padding * jcp.ngroups
521                         + 4 * ic);
522     };
523     auto kernel_offset = [=](int ii, int ic, int ki) {
524         return jcp.typesize_in
525                 * ((ii * jcp.nb_ic * jcp.kd * jcp.kh * jcp.kw + ki)
526                                 * ch_block_all
527                         + 4 * ic * oc_block);
528     };
529     auto compute = [=](ZReg vreg_acc, ZReg vreg_wei, ZReg vreg_src) {
530         sdot(ZRegS(vreg_acc.getIdx()), ZRegB(vreg_src.getIdx()),
531                 ZRegB(vreg_wei.getIdx()));
532     };
533 
534     for (int ki = 0; ki < kw; ki++) {
535         int jj_start = get_ow_start(ki, pad_l);
536         int jj_end = get_ow_end(ur_w, ki, pad_r);
537         int ic_tail_size = jcp.ic_without_padding % 4;
538         int _start = (!jcp.signed_input) ? 0 : jj_start;
539         int _end = (!jcp.signed_input) ? ur_w : jj_end;
540         /* Skip the last loads of input if (ic%16)/4 < ic_block/4 */
541         int icb = (last_ic_block_flag != no_last_block)
542                 ? div_up((jcp.ic_without_padding % ic_block), 4)
543                 : ic_block / 4;
544         for (int ic = 0; ic < icb; ic++) {
545             if (h_padded) {
546                 /* fill padded area with shifted values */
547                 auto inp = vmm_inp(0, nb_oc_block);
548                 eor(inp.d, inp.d, inp.d);
549                 sub(inp.b, inp.b, vmm_shift.b);
550             } else {
551                 for (int jj = _start; jj < _end; jj++) {
552                     int aux_input_offset = input_offset(jj, ic, ki);
553                     if (jj >= jj_start && jj < jj_end) {
554                         if (last_ic_block_flag == last_sp_block
555                                 && ic_tail_size != 0 && ic == icb - 1) {
556                             auto xmm_tmp = VReg16B(
557                                     vmm_inp(jj, nb_oc_block).getIdx());
558                             for (int r = 0; r < ic_tail_size; ++r) {
559                                 add_imm(reg_tmp0_adr, aux_reg_inp,
560                                         (aux_input_offset + r), reg_tmp0_imm);
561                                 ldrb(WReg(reg_tmp1_imm.getIdx()),
562                                         ptr(reg_tmp0_adr));
563                                 ins(VReg16B(xmm_tmp.getIdx())[r],
564                                         WReg(reg_tmp1_imm.getIdx()));
565                             }
566                             dup(vmm_inp(jj, nb_oc_block).s,
567                                     ZRegS(xmm_tmp.getIdx())[0]);
568                         } else {
569                             auto base = aux_reg_inp;
570                             auto re = get_offset(aux_input_offset);
571 
572                             if ((-0x40 <= re) && (re < 0x40) && ((re % 4) == 0))
573                                 ld1rw(vmm_inp(jj, nb_oc_block).s, mask_all_one,
574                                         ptr(base, static_cast<int32_t>(re)));
575                             else {
576                                 auto reg_tmp_adr = ((jj % 4) == 0)
577                                         ? reg_tmp0_adr
578                                         : ((jj % 4) == 1) ? reg_tmp1_adr
579                                                           : ((jj % 4) == 2)
580                                                         ? reg_tmp2_adr
581                                                         : reg_tmp3_adr;
582                                 auto reg_tmp_imm = ((jj % 4) == 0)
583                                         ? reg_tmp0_imm
584                                         : ((jj % 4) == 1) ? reg_tmp1_imm
585                                                           : ((jj % 4) == 2)
586                                                         ? reg_tmp2_imm
587                                                         : reg_tmp3_imm;
588                                 add_imm(reg_tmp_adr, base, re, reg_tmp_imm);
589                                 ld1rw(vmm_inp(jj, nb_oc_block).s, mask_all_one,
590                                         ptr(reg_tmp_adr));
591                             }
592                         }
593                         if (!jcp.signed_input)
594                             sub(vmm_inp(jj, nb_oc_block).b,
595                                     vmm_inp(jj, nb_oc_block).b, vmm_shift.b);
596                     } else {
597                         /* fill padded area with shifted values */
598                         if (!jcp.signed_input) {
599                             auto inp = vmm_inp(jj, nb_oc_block);
600                             eor(inp.d, inp.d, inp.d);
601                             sub(inp.b, inp.b, vmm_shift.b);
602                         }
603                     }
604                 }
605             }
606             for (int ii = 0; ii < nb_oc_block; ii++) {
607                 if (!jcp.signed_input) {
608                     int aux_kernel_offset = kernel_offset(ii, ic, ki);
609                     auto reg_addr
610                             = get_comp_addr_reg(aux_reg_ker, aux_kernel_offset);
611                     ld1w(vmm_wei.s, mask_all_one, ptr(reg_addr));
612                     for (int jj = _start; jj < _end; jj++) {
613                         auto inp = (h_padded == true)
614                                 ? vmm_inp(0, nb_oc_block)
615                                 : vmm_inp(jj, nb_oc_block);
616                         compute(vmm_out(jj, ii), vmm_wei, inp);
617                     }
618                 } else {
619                     if (ii == 0) {
620                         int aux_kernel_offset = kernel_offset(ii, ic, ki);
621                         auto reg_addr = get_comp_addr_reg(
622                                 aux_reg_ker, aux_kernel_offset);
623                         ld1w(vmm_wei.s, mask_all_one, ptr(reg_addr));
624                     }
625                     if ((ii + 1) < nb_oc_block) {
626                         int aux_kernel_offset = kernel_offset((ii + 1), ic, ki);
627                         auto _vmm_wei = ((ii % 2) == 0) ? vmm_comp : vmm_wei;
628                         auto reg_addr = get_comp_addr_reg(
629                                 aux_reg_ker, aux_kernel_offset);
630                         ld1w(_vmm_wei.s, mask_all_one, ptr(reg_addr));
631                     }
632                     for (int jj = _start; jj < _end; jj++) {
633                         auto _vmm_wei = ((ii % 2) == 0) ? vmm_wei : vmm_comp;
634                         auto inp = (h_padded == true)
635                                 ? vmm_inp(0, nb_oc_block)
636                                 : vmm_inp(jj, nb_oc_block);
637                         compute(vmm_out(jj, ii), _vmm_wei, inp);
638                     }
639                 }
640             }
641         }
642     }
643 }
644 
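// kh_loop() drives compute_ker over the kernel height (and depth for 3D).
// For a u8 source the top/bottom (and front/back) overflow regions are walked
// with h_padded = true so that padded rows still contribute the shifted-zero
// inputs and the precomputed compensation remains exact.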
645 void jit_sve_512_x8s8s32x_fwd_kernel::kh_loop(
646         int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag) {
647     Label kd_label, kh_label, skip_kd_loop, skip_kh_loop;
648     Label f_overflow_label, no_f_overflow_label, d_h_f_overflow_label,
649             t_overflow_label, no_t_overflow_label, b_overflow_label,
650             no_b_overflow_label, back_overflow_label, no_back_overflow_label,
651             d_h_back_overflow_label;
652 
653     int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
654     int shift_kernel_ptr = jcp.typesize_in * jcp.kw * ch_block_all;
655     int shift_input_ptr
656             = jcp.typesize_in * jcp.iw * jcp.ic_without_padding * jcp.ngroups;
657 
658     if (jcp.ndims == 5) {
659         mov(aux_reg_ker_d, reg_ker);
660         mov(aux_reg_inp_d, reg_inp);
661         if (!jcp.signed_input) {
662             //TODO: May be avoided when f_pad=0 and dd0
663             //TODO: Potential optimization by precomputing, when kd <<< od?
664             ldr(reg_ki, ptr(reg_param1, GET_OFF(f_overflow)));
665             cmp(reg_ki, 0);
666             b(EQ, no_f_overflow_label);
667             L(f_overflow_label);
668             {
669                 mov(aux_reg_ker, aux_reg_ker_d);
670                 mov_imm(reg_kj, jcp.kh);
671                 L(d_h_f_overflow_label);
672                 {
673                     compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
674                     adds_imm(aux_reg_ker, aux_reg_ker, shift_kernel_ptr,
675                             reg_tmp0_imm);
676                     subs(reg_kj, reg_kj, 1);
677                     b(NE, d_h_f_overflow_label);
678                 }
679                 add_imm(aux_reg_ker_d, aux_reg_ker_d, shift_kernel_ptr * jcp.kh,
680                         reg_tmp0_imm);
681                 subs(reg_ki, reg_ki, 1);
682                 b(NE, f_overflow_label);
683             }
684             L(no_f_overflow_label);
685         }
686 
687         ldr(reg_ki, ptr(reg_param1, GET_OFF(kd_padding)));
688         if ((!jcp.signed_input) || (jcp.dilate_d >= jcp.id)
689                 || (jcp.signed_input
690                         && (jcp.kd - 1) * (jcp.dilate_d + 1)
691                                 < nstl::max(jcp.f_pad, jcp.back_pad))) {
692             cmp(reg_ki, 0);
693             b(EQ, skip_kd_loop);
694         }
695         L(kd_label);
696         mov(aux_reg_inp, aux_reg_inp_d);
697         mov(aux_reg_ker, aux_reg_ker_d);
698     } else {
699         if (jcp.is_fused_conv) {
700             mov(aux_reg_inp_buffer_ptr, reg_inp_buffer_ptr);
701         } else {
702             mov(aux_reg_inp, reg_inp);
703         }
704         mov(aux_reg_ker, reg_ker);
705     }
706 
707     if (!jcp.signed_input && jcp.ndims > 3) {
708         ldr(reg_overflow, ptr(reg_param1, GET_OFF(t_overflow)));
709         cmp(reg_overflow, 0);
710         b(EQ, no_t_overflow_label);
711         L(t_overflow_label);
712         {
713             compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
714 
715             adds_imm(aux_reg_ker, aux_reg_ker, shift_kernel_ptr, reg_tmp0_imm);
716             subs(reg_overflow, reg_overflow, 1);
717             cmp(reg_overflow, 0);
718             b(GT, t_overflow_label);
719         }
720         L(no_t_overflow_label);
721     }
722     ldr(reg_kj, ptr(reg_param1, GET_OFF(kh_padding)));
723     if ((!jcp.signed_input) || (jcp.dilate_h >= jcp.ih)
724             || (jcp.signed_input
725                     && (jcp.kh - 1) * (jcp.dilate_h + 1)
726                             < nstl::max(jcp.t_pad, jcp.b_pad))) {
727         cmp(reg_kj, 0);
728         b(EQ, skip_kh_loop);
729     }
730     L(kh_label);
731     {
732         if (jcp.is_fused_conv) {
733             ldr(aux_reg_inp, ptr(aux_reg_inp_buffer_ptr));
734             add(aux_reg_inp, aux_reg_inp, reg_inp);
735         }
736         compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, false);
737 
738         adds_imm(aux_reg_ker, aux_reg_ker, shift_kernel_ptr, reg_tmp0_imm);
739         if (jcp.is_fused_conv) {
740             adds_imm(aux_reg_inp_buffer_ptr, aux_reg_inp_buffer_ptr,
741                     sizeof(void *), reg_tmp0_imm);
742         } else {
743             adds_imm(aux_reg_inp, aux_reg_inp,
744                     shift_input_ptr * (jcp.dilate_h + 1), reg_tmp0_imm);
745         }
746         subs(reg_kj, reg_kj, 1);
747         cmp(reg_kj, 0);
748         b(GT, kh_label);
749     }
750     L(skip_kh_loop);
751     if (!jcp.signed_input && jcp.ndims > 3) {
752         ldr(reg_overflow, ptr(reg_param1, GET_OFF(b_overflow)));
753         cmp(reg_overflow, 0);
754         b(EQ, no_b_overflow_label);
755         L(b_overflow_label);
756         {
757             compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
758 
759             adds_imm(aux_reg_ker, aux_reg_ker, shift_kernel_ptr, reg_tmp0_imm);
760             subs(reg_overflow, reg_overflow, 1);
761             cmp(reg_overflow, 0);
762             b(GT, b_overflow_label);
763         }
764         L(no_b_overflow_label);
765     }
766 
767     if (jcp.ndims == 5) {
768         adds_imm(aux_reg_inp_d, aux_reg_inp_d,
769                 shift_input_ptr * jcp.ih * (jcp.dilate_d + 1), reg_tmp0_imm);
770         adds_imm(aux_reg_ker_d, aux_reg_ker_d, shift_kernel_ptr * jcp.kh,
771                 reg_tmp0_imm);
772         subs(reg_ki, reg_ki, 1);
773         b(NE, kd_label);
774 
775         L(skip_kd_loop);
776         if (!jcp.signed_input) {
777             ldr(reg_ki, ptr(reg_param1, GET_OFF(back_overflow)));
778             cmp(reg_ki, 0);
779             b(EQ, no_back_overflow_label);
780             L(back_overflow_label);
781             {
782                 mov(aux_reg_ker, aux_reg_ker_d);
783                 mov(reg_kj, jcp.kh);
784                 L(d_h_back_overflow_label);
785                 {
786                     compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
787                     adds_imm(aux_reg_ker, aux_reg_ker, shift_kernel_ptr,
788                             reg_tmp0_imm);
789                     subs(reg_kj, reg_kj, 1);
790                     b(NE, d_h_back_overflow_label);
791                 }
792                 adds_imm(aux_reg_ker_d, aux_reg_ker_d,
793                         shift_kernel_ptr * jcp.kh, reg_tmp0_imm);
794                 subs(reg_ki, reg_ki, 1);
795                 b(NE, back_overflow_label);
796             }
797             L(no_back_overflow_label);
798         }
799     }
800 }
801 
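// icb_loop(): zero the accumulators, iterate over the nb_ic input-channel
// blocks calling kh_loop (switching to the tail variant on the last block
// when channel tails exist), rewind the input/weight pointers, and store the
// tile, taking the last-oc-block path when the output channels have a tail.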
802 void jit_sve_512_x8s8s32x_fwd_kernel::icb_loop(
803         int ur_w, int pad_l, int pad_r, bool is_last_sp_block) {
804     prepare_output(ur_w);
805 
806     // IC loop
807     Label icb_label;
808     mov_imm(reg_icb, jcp.nb_ic);
809     L(icb_label);
810     if (jcp.ngroups % jcp.ch_block != 0 || jcp.ic_without_padding != jcp.ic) {
811         Label common_ker, end_ker;
812 
813         if (jcp.is_depthwise)
814             cmp(reg_oc_blocks, jcp.nb_ch - jcp.nb_ch_blocking);
815         else
816             cmp(reg_icb, 1); // The last IC block
817         b(NE, common_ker);
818 
819         kh_loop(ur_w, pad_l, pad_r,
820                 is_last_sp_block ? last_sp_block : last_ic_block);
821         b(end_ker);
822 
823         L(common_ker);
824         kh_loop(ur_w, pad_l, pad_r, no_last_block);
825 
826         L(end_ker);
827     } else {
828         kh_loop(ur_w, pad_l, pad_r, no_last_block);
829     }
830     // End of IC Loop
831     int inp_step = jcp.ic_block;
832     int ker_step = jcp.kd * jcp.kh * jcp.kw * jcp.oc_block * jcp.ic_block;
833     adds_imm(reg_inp, reg_inp, jcp.typesize_in * inp_step, reg_tmp0_imm);
834     adds_imm(reg_ker, reg_ker, jcp.typesize_in * ker_step, reg_tmp0_imm);
835 
836     subs(reg_icb, reg_icb, 1);
837     cmp(reg_icb, 0);
838     b(GT, icb_label);
839 
840     subs_imm(reg_inp, reg_inp, jcp.typesize_in * inp_step * jcp.nb_ic,
841             reg_tmp0_imm);
842     subs_imm(reg_ker, reg_ker, jcp.typesize_in * ker_step * jcp.nb_ic,
843             reg_tmp0_imm);
844 
845     if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
846         Label common_store, end_store;
847 
848         if (jcp.is_depthwise)
849             cmp(reg_oc_blocks, jcp.nb_ch - jcp.nb_ch_blocking);
850         else
851             cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
852 
853         b(NE, common_store);
854 
855         store_output(ur_w, true); // last oc block
856         b(end_store);
857 
858         L(common_store);
859         store_output(ur_w, false);
860 
861         L(end_store);
862     } else {
863         store_output(ur_w, false);
864     }
865 }
866 
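// Set mask_all_one to an all-true predicate for the runtime SVE vector length
// (64, 32 or 16 bytes).  mask_gflag records whether the full 512-bit length
// is available; store_output() applies the oc-tail mask only in that case.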
867 void jit_sve_512_x8s8s32x_fwd_kernel::vmm_mask_all_one() {
868     mask_gflag = false;
869     if (sve_len_ == 64) {
870         mask_gflag = true;
871         ptrue(mask_all_one.b);
872     } else if (sve_len_ == 32) {
873         ptrue(mask_all_one.b, VL32);
874     } else if (sve_len_ == 16) {
875         ptrue(mask_all_one.b, VL16);
876     } else {
877         assert(!"unreachable");
878     }
879 }
880 
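// vmm_load_src() loads one int8 byte per 32-bit lane.  With mask_flag the
// byte-granularity predicate is derived from ktail_mask by narrowing it twice
// with uzp1; otherwise a fixed predicate covering sve_len_ / 4 bytes is used.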
881 void jit_sve_512_x8s8s32x_fwd_kernel::vmm_load_src(
882         ZReg src, XReg reg_addr, bool mask_flag) {
883     if (mask_flag) {
884         eor(mask_tmp.b, mask_all_one, mask_tmp.b, mask_tmp.b);
885         eor(mask_tmp2.b, mask_all_one, mask_tmp2.b, mask_tmp2.b);
886         uzp1(mask_tmp.h, ktail_mask.h, mask_tmp.h);
887         uzp1(mask_tmp.b, mask_tmp.b, mask_tmp2.b);
888     } else {
889         if (sve_len_ == 64)
890             ptrue(mask_tmp.b, VL16);
891         else if (sve_len_ == 32)
892             ptrue(mask_tmp.b, VL8);
893         else if (sve_len_ == 16)
894             ptrue(mask_tmp.b, VL4);
895         else
896             assert(!"unreachable");
897     }
898 
899     ld1b(src.b, mask_tmp, ptr(reg_addr));
900 }
901 
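// generate() is the kernel entry point: it emits the prologue, initializes
// the vector-length and channel-tail predicates, loads the argument pointers
// from jit_conv_call_s, prepares the fast-depthwise blend mask and permute
// table, and then emits the output-width loop, either a single pass over ow
// (nb_ow == 1) or the owb-indexed block scheme.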
902 void jit_sve_512_x8s8s32x_fwd_kernel::generate() {
903     Label permute_index_table;
904     int in_ic_shift = jcp.is_fused_conv ? jcp.dw_conv_buffer_oc
905                                         : jcp.ic_without_padding * jcp.ngroups;
906     int inp_shift_pad = jcp.typesize_in * (jcp.ur_w * jcp.stride_w - jcp.l_pad)
907             * in_ic_shift;
908     int inp_shift_pad_second_block
909             = -1 * jcp.typesize_in * jcp.l_pad * in_ic_shift;
910     int inp_shift = jcp.typesize_in * (jcp.ur_w * jcp.stride_w * in_ic_shift);
911     int out_shift = jcp.typesize_out
912             * (jcp.ur_w * jcp.oc_without_padding * jcp.ngroups);
913     preamble();
914 
915     vmm_mask_all_one();
916 
917     if (jcp.is_depthwise) {
918         int idx = jcp.max_regs_ur - 1;
919         if (!jcp.is_resrc_depthwise) zmm_src = ZReg(++idx);
920         if (jcp.is_fast_depthwise) zmm_permute = ZReg(++idx);
921         if (!jcp.signed_input) zmm_shifted_zero = ZReg(++idx);
922         // due to extra register used for shifts and compensations
923         // and/or saturation, we increment by one more
924         if (!jcp.signed_input || jcp.need_saturation) ++idx;
925         assert(idx == ker_dw_reg_base_idx);
926     }
927 
928     if (jcp.is_fused_conv) {
929         ldr(reg_inp_buffer_ptr, ptr(reg_param1, GET_OFF(src)));
930         /* In the case of fused depthwise convolution, `param.src` is not a
931         pointer to the input; instead it points to a buffer containing
932         pointers to consecutive rows of the input in wc format with
933         c = jcp.dw_conv_buffer_oc. */
934         mov_imm(reg_inp, 0);
935     } else {
936         ldr(reg_inp, ptr(reg_param1, GET_OFF(src)));
937     }
938     ldr(reg_out, ptr(reg_param1, GET_OFF(dst)));
939     ldr(reg_ker, ptr(reg_param1, GET_OFF(filt)));
940 
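    // Build ktail_mask with the first tail_size 32-bit lanes active:
    // broadcast mask = (1 << tail_size) - 1, compute 1 << lane for each lane
    // (index + lsl), AND the two, and set the predicate where the result is
    // non-zero.  E.g. tail_size = 3 gives mask 0b111 and lanes {0, 1, 2}.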
941     if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
942         int tail_size = jcp.is_depthwise
943                 ? jcp.ngroups % jcp.ch_block
944                 : jcp.oc_without_padding % jcp.oc_block;
945         int mask = (1 << tail_size) - 1;
946         ldr(reg_oc_blocks, ptr(reg_param1, GET_OFF(oc_blocks)));
947         auto regw_tmp = reg_oi;
948         mov(regw_tmp, mask);
949         auto vmm_tmp1 = ZReg(31);
950         auto vmm_tmp2 = ZReg(30);
951         index(vmm_tmp1.s, 0, 1);
952         mov(vmm_tmp2.s, 1);
953         lsl(vmm_tmp2.s, mask_all_one / T_m, vmm_tmp1.s);
954         dup(vmm_tmp1.s, WReg(regw_tmp.getIdx()));
955         and_(vmm_tmp1.d, vmm_tmp1.d, vmm_tmp2.d);
956         cmpne(ktail_mask.s, mask_all_one, vmm_tmp1.s, 0);
957     }
958     if (jcp.is_fast_depthwise) {
959         // prepare mask register for blending weights
960         movk(reg_scratch, uint16_t(0x1111), 0);
961         movk(reg_scratch, uint16_t(0x2222), 16);
962         movk(reg_scratch, uint16_t(0x4444), 32);
963         movk(reg_scratch, uint16_t(0x8888), 48);
964         sub(reg_stack, reg_stack, 8);
965         str(reg_scratch, ptr(reg_stack));
966         ldr(kblend_mask, ptr(reg_stack));
967         add(reg_stack, reg_stack, 8);
968         // load permute indices from data section
969         adr(reg_scratch, permute_index_table);
970         ld1w(zmm_permute.s, mask_all_one, ptr(reg_scratch));
971     }
972 
973     int r_pad = nstl::max(0, jcp.r_pad);
974     int n_oi = jcp.ow / jcp.ur_w;
975     int r_pad1 = calculate_end_padding(jcp.l_pad, jcp.ur_w * n_oi, jcp.iw,
976             jcp.stride_w, calculate_extended_filter_size(jcp.kw, jcp.dilate_w));
977 
978     if (jcp.nb_ow == 1) {
979         if (r_pad1 > 0 || jcp.ur_w_tail == 0) n_oi--;
980 
981         eor(reg_oi, reg_oi, reg_oi);
982         if (jcp.ow == jcp.ur_w) {
983             icb_loop(jcp.ur_w, jcp.l_pad, r_pad, true);
984         } else {
985             if (n_oi == 0) {
986                 icb_loop(jcp.ur_w, jcp.l_pad, r_pad1, jcp.ur_w_tail == 0);
987                 adds_imm(reg_inp, reg_inp, inp_shift_pad, reg_tmp0_imm);
988                 adds_imm(reg_out, reg_out, out_shift, reg_tmp0_imm);
989                 if (jcp.ur_w_tail != 0) {
990                     icb_loop(jcp.ur_w_tail, 0, r_pad, true);
991                 }
992             } else {
993                 if (jcp.l_pad > 0) {
994                     icb_loop(jcp.ur_w, jcp.l_pad, 0, false);
995                     adds_imm(reg_inp, reg_inp, inp_shift_pad, reg_tmp0_imm);
996                     adds_imm(reg_out, reg_out, out_shift, reg_tmp0_imm);
997 
998                     adds(reg_oi, reg_oi, 1);
999                 }
1000                 if ((jcp.l_pad <= 0 && n_oi > 0)
1001                         || (jcp.l_pad > 0 && n_oi > 1)) {
1002                     Label ow_loop_label;
1003                     L(ow_loop_label);
1004                     {
1005                         icb_loop(jcp.ur_w, 0, 0, false);
1006                         adds_imm(reg_inp, reg_inp, inp_shift, reg_tmp0_imm);
1007                         adds_imm(reg_out, reg_out, out_shift, reg_tmp0_imm);
1008 
1009                         adds(reg_oi, reg_oi, 1);
1010                         mov_imm(reg_tmp0_imm, n_oi);
1011                         cmp(reg_oi, reg_tmp0_imm);
1012                         b(LT, ow_loop_label);
1013                     }
1014                 }
1015                 if (r_pad1 > 0 || jcp.ur_w_tail == 0) {
1016                     icb_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0);
1017                     adds_imm(reg_inp, reg_inp, inp_shift, reg_tmp0_imm);
1018                     adds_imm(reg_out, reg_out, out_shift, reg_tmp0_imm);
1019                 }
1020                 if (jcp.ur_w_tail != 0) {
1021                     icb_loop(jcp.ur_w_tail, 0, r_pad, true);
1022                 }
1023             }
1024         }
1025     } else {
1026         // Only a single ow block is processed per call.
1027         // The block index is passed as the owb parameter,
1028         // and the padding handling depends on it.
1029         Label end_label, last_oi_label, middle_ow_blocks_label, tail_label,
1030                 oi_loop_label, oi_loop_end_label;
1031 
1032         assert(jcp.ow_block % jcp.ur_w == 0);
1033         int n_oi_not_last_ow_block = jcp.ow_block / jcp.ur_w;
1034         // to simplify code (and general regs usage),
1035         // size of ow block must be >= 2 * ur_w
1036         assert(n_oi_not_last_ow_block > 1);
1037         int n_oi_next_last_ow_block = n_oi_not_last_ow_block;
1038         int n_oi_first_ow_block = n_oi_not_last_ow_block;
1039         int n_oi_last_ow_block
1040                 = (jcp.ow - jcp.ow_block * (jcp.nb_ow - 1)) / jcp.ur_w;
1041         // prepare right padding
1042         bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0;
1043         bool first_ow_block_padded
1044                 = next_last_ow_block_padded && jcp.nb_ow == 2;
1045         bool last_ow_block_padded
1046                 = (r_pad1 > 0 || jcp.ur_w_tail == 0) && n_oi_last_ow_block > 0;
1047 
1048         if (last_ow_block_padded)
1049             n_oi_last_ow_block--;
1050         else if (first_ow_block_padded)
1051             n_oi_first_ow_block--;
1052         else if (next_last_ow_block_padded)
1053             n_oi_next_last_ow_block--;
1054 
1055         ldr(reg_owb, ptr(reg_param1, GET_OFF(owb)));
1056         cmp(reg_owb, 0); // is that the first ow-block ?
1057         b(GT, middle_ow_blocks_label);
1058 
1059         // the first ow block, compute left padding
1060         mov_imm(reg_oi, n_oi_first_ow_block);
1061         if (jcp.l_pad > 0) {
1062             icb_loop(jcp.ur_w, jcp.l_pad, 0, false);
1063             adds_imm(reg_inp, reg_inp, inp_shift_pad, reg_tmp0_imm);
1064             adds_imm(reg_out, reg_out, out_shift, reg_tmp0_imm);
1065 
1066             subs(reg_oi, reg_oi, 1);
1067         }
1068         b(oi_loop_label);
1069 
1070         // middle or last ow block entry
1071         L(middle_ow_blocks_label);
1072 
1073         if (jcp.l_pad > 0) {
1074             // only adjust the input pointer for left padding; nothing is computed here
1075             adds_imm(
1076                     reg_inp, reg_inp, inp_shift_pad_second_block, reg_tmp0_imm);
1077         }
1078 
1079         // set the number of iterations for the oi-loop
1080         if (n_oi_last_ow_block != n_oi_not_last_ow_block) {
1081             cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ?
1082             mov_imm(reg_oi, n_oi_last_ow_block);
1083             b(EQ, oi_loop_label);
1084         }
1085 
1086         if (n_oi_next_last_ow_block != n_oi_not_last_ow_block) {
1087             cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
1088 
1089             mov_imm(reg_oi, n_oi_next_last_ow_block);
1090             b(EQ, oi_loop_label);
1091         }
1092         mov_imm(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks
1093 
1094         // oi loop w/o padding
1095         L(oi_loop_label);
1096         {
1097             cmp(reg_oi, 0);
1098             b(LE, oi_loop_end_label);
1099 
1100             icb_loop(jcp.ur_w, 0, 0, false);
1101 
1102             adds_imm(reg_inp, reg_inp, inp_shift, reg_tmp0_imm);
1103             adds_imm(reg_out, reg_out, out_shift, reg_tmp0_imm);
1104             subs(reg_oi, reg_oi, 1);
1105 
1106             b(oi_loop_label);
1107         }
1108         L(oi_loop_end_label);
1109 
1110         ldr(reg_owb, ptr(reg_param1, GET_OFF(owb)));
1111         cmp(reg_owb, 0); // first ow-block ?
1112         if (first_ow_block_padded)
1113             b(EQ, last_oi_label);
1114         else
1115             b(EQ, end_label);
1116 
1117         cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
1118         b(LT, end_label);
1119         if (next_last_ow_block_padded)
1120             b(EQ, last_oi_label);
1121         else
1122             b(EQ, end_label);
1123 
1124         // this is the last ow block
1125         if (!last_ow_block_padded) b(tail_label);
1126 
1127         // last oi block with right padding
1128         L(last_oi_label);
1129         icb_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0);
1130         adds_imm(reg_inp, reg_inp, inp_shift, reg_tmp0_imm);
1131         adds_imm(reg_out, reg_out, out_shift, reg_tmp0_imm);
1132 
1133         ldr(reg_owb, ptr(reg_param1, GET_OFF(owb)));
1134         cmp(reg_owb, jcp.nb_ow - 1); // last ow_block?
1135         b(LT, end_label);
1136 
1137         // ur_w tail
1138         L(tail_label);
1139         if (jcp.ur_w_tail != 0) { icb_loop(jcp.ur_w_tail, 0, r_pad, true); }
1140         L(end_label);
1141     }
1142     postamble();
1143 
1144     if (jcp.is_fast_depthwise) {
1145         align(64);
1146         L(permute_index_table);
1147         const uint32_t _idx[]
1148                 = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15};
1149         for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i)
1150             dd(_idx[i]);
1151     }
1152 }
1153 
1154 bool jit_sve_512_x8s8s32x_fwd_kernel::post_ops_ok(
1155         jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
1156     using namespace primitive_kind;
1157     const auto &p = attr.post_ops_;
1158 
1159     /* At this time, post-ops are not supported. */
1160     return 0 == p.len();
1161 }
1162 
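// init_conf() validates the problem (ISA, data types, geometry and padding),
// fills jit_conv_conf_t with the blocking parameters (ch/ic/oc block sizes,
// depthwise flags, register budget) and selects the expected weights layout;
// unsupported configurations return status::unimplemented.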
1163 status_t jit_sve_512_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
1164         const convolution_desc_t &cd, memory_desc_t &src_md,
1165         memory_desc_t &weights_md, memory_desc_t &dst_md,
1166         memory_desc_t &bias_md, const primitive_attr_t &attr, int nthreads) {
1167     using namespace prop_kind;
1168 
1169     const memory_desc_wrapper src_d(&src_md);
1170     const memory_desc_wrapper weights_d(&weights_md);
1171     const memory_desc_wrapper dst_d(&dst_md);
1172     const memory_desc_wrapper bias_d(&bias_md);
1173 
1174     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
1175     const int ndims = src_d.ndims();
1176     const bool is_1d = ndims == 3;
1177     const bool is_2d = ndims == 4;
1178     const bool is_3d = ndims == 5;
1179     assert(is_1d || is_2d || is_3d);
1180 
1181     if (!(mayiuse(sve_512)
1182                 && one_of(src_d.data_type(), data_type::u8, data_type::s8)
1183                 && weights_d.data_type() == data_type::s8
1184                 && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
1185                         data_type::s8, data_type::u8)))
1186         return status::unimplemented;
1187 
1188     jcp = zero<decltype(jcp)>();
1189     jcp.nthr = nthreads;
1190     jcp.ndims = ndims;
1191     jcp.prop_kind = cd.prop_kind;
1192     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
1193     jcp.mb = src_d.dims()[0];
1194     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
1195     jcp.oc_without_padding = jcp.oc;
1196     jcp.ic = src_d.dims()[1] / jcp.ngroups;
1197     jcp.ic_without_padding = jcp.ic;
1198     jcp.id = is_3d ? src_d.dims()[2] : 1;
1199     jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2];
1200     jcp.iw = src_d.dims()[ndims - 1];
1201     jcp.od = is_3d ? dst_d.dims()[2] : 1;
1202     jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2];
1203     jcp.ow = dst_d.dims()[ndims - 1];
1204     jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1;
1205     jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2];
1206     jcp.kw = weights_d.dims()[with_groups + ndims - 1];
1207     jcp.f_pad = is_3d ? cd.padding[0][0] : 0;
1208     jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4];
1209     jcp.l_pad = cd.padding[0][ndims - 3];
1210     jcp.stride_d = is_3d ? cd.strides[0] : 1;
1211     jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4];
1212     jcp.stride_w = cd.strides[ndims - 3];
1213     jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
1214 
1215     jcp.dilate_d = is_3d ? cd.dilates[0] : 0;
1216     jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4];
1217     jcp.dilate_w = cd.dilates[ndims - 3];
1218 
1219     int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
1220     int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
1221     int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
1222     jcp.r_pad = calculate_end_padding(
1223             jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw);
1224     jcp.b_pad = calculate_end_padding(
1225             jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh);
1226     jcp.back_pad = calculate_end_padding(
1227             jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd);
1228     bool kernel_outside_src = false || ext_kw <= jcp.l_pad
1229             || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad
1230             || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad;
1231     if (kernel_outside_src) return status::unimplemented;
1232 
1233     jcp.signed_input = src_d.data_type() == data_type::s8;
1234     jcp.need_saturation = utils::one_of(
1235             dst_d.data_type(), data_type::u8, data_type::s8, data_type::s32);
1236     jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc);
1237 
1238     if (jcp.is_depthwise && is_3d)
1239         // NOTE: 3D depthwise is not currently supported here.
1240         return status::unimplemented;
1241 
1242     if (jcp.is_depthwise) {
1243         jcp.ch_block = 16;
1244         jcp.ic_block = 1;
1245         jcp.oc_block = 1;
1246     } else {
1247         jcp.ch_block = 1;
1248         jcp.ic_block = 16;
1249         jcp.oc_block = 16;
1250 
1251         if (jcp.ngroups == 1) {
1252             /* For non grouped convolutions, pad channels by 16 if needed */
1253             jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
1254             jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
1255         } else if (jcp.ngroups != 1
1256                 && ((jcp.ic % jcp.ic_block != 0)
1257                         || (jcp.oc % jcp.oc_block != 0))) {
1258             /* For grouped convolutions, oneDNN doesn't support padding.
1259                When channels per group is not multiple of 4, 8, 16, return unimplemented. */
1260             jcp.ic_block = (jcp.ic % 8 == 0) && (jcp.oc % 8 == 0) ? 8 : 4;
1261             jcp.oc_block = jcp.ic_block;
1262         }
1263         if (jcp.ic % jcp.ic_block != 0 || jcp.oc % jcp.oc_block != 0)
1264             return status::unimplemented;
1265     }
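    // Illustrative cases (not exhaustive): a non-grouped convolution with
    // ic = 20, oc = 30 is padded up to ic = 32, oc = 32; a grouped convolution
    // with 24 channels per group cannot be padded, so the block size drops to
    // 8 (or to 4 if the channels are only a multiple of 4), and anything that
    // still does not divide evenly was rejected just above.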

    if (!post_ops_ok(jcp, attr)) return status::unimplemented;

    /* Groups that are not a multiple of ch_block (= 16) would require byte
       masking for the loads from src. */
    jcp.is_fast_depthwise
            = jcp.is_depthwise && jcp.ngroups % jcp.ch_block == 0;

    jcp.is_resrc_depthwise = jcp.is_depthwise && jcp.stride_w < jcp.kw
            && jcp.kw < 4 && jcp.dilate_w == 0;
    if (jcp.is_depthwise) {
        jcp.max_regs_ur = 31 - jcp.is_fast_depthwise - !jcp.is_resrc_depthwise
                - !jcp.signed_input
                - (!jcp.signed_input || jcp.need_saturation); // both alias
    } else {
        jcp.max_regs_ur = 31;
    }
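    // Rough register budget: SVE-512 exposes 32 vector registers, and up to
    // 31 of them can be used here for the width-unrolled accumulators and
    // inputs. In the depthwise path each extra feature above (fast-depthwise
    // permutation, not re-using source rows, the u8 source shift, saturation)
    // pins roughly one more register, which is what the subtractions model.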

    auto set_or_check_wei_format = [&]() {
        using namespace format_tag;
        format_tag_t wei_tag;
        if (jcp.ic_block == 16 || jcp.ch_block == 16) {
            if (is_3d) {
                wei_tag = with_groups ? gOIdhw4i16o4i : OIdhw4i16o4i;
            } else if (is_1d) {
                wei_tag = with_groups ? jcp.is_depthwise ? Goiw16g : gOIw4i16o4i
                                      : OIw4i16o4i;
            } else {
                assert(is_2d);
                wei_tag = with_groups
                        ? jcp.is_depthwise ? Goihw16g : gOIhw4i16o4i
                        : OIhw4i16o4i;
            }
        } else if (jcp.ic_block == 8) {
            assert(with_groups);
            wei_tag = is_3d ? gOIdhw2i8o4i : is_2d ? gOIhw2i8o4i : gOIw2i8o4i;
        } else {
            assert(with_groups && jcp.ic_block == 4);
            wei_tag = is_3d ? gOIdhw4o4i : is_2d ? gOIhw4o4i : gOIw4o4i;
        }

        memory_desc_t want_wei_md = weights_md;
        memory_desc_init_by_tag(want_wei_md, wei_tag);
        if (!jcp.signed_input) {
            want_wei_md.extra.flags = 0
                    | memory_extra_flags::compensation_conv_s8s8
                    | memory_extra_flags::scale_adjust;
            want_wei_md.extra.compensation_mask = (1 << 0)
                    + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0);
            want_wei_md.extra.scale_adjust = 1.f;
        }

        if (weights_md.format_kind == format_kind::any) {
            weights_md = want_wei_md;
            return true;
        }

        return weights_md == want_wei_md;
    };

    if (!set_or_check_wei_format()) return status::unimplemented;
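    // Roughly speaking, the 4i16o4i-style weight tags keep 4 input channels
    // innermost so the int8 kernel can accumulate four s8 products at a time
    // (sdot-style) into a block of 16 (or 8/4) output channels; depthwise
    // weights stay in the plain per-group Goihw16g / Goiw16g layout.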

    format_tag_t dat_tag = utils::pick(
            ndims - 3, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);

    if (src_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(src_md, dat_tag));
        jcp.src_tag = dat_tag;
    } else {
        jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
    }
    if (jcp.src_tag != dat_tag) return status::unimplemented;

    if (dst_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(dst_md, dat_tag));
        jcp.dst_tag = dat_tag;
    } else {
        jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
    }
    if (jcp.dst_tag != dat_tag) return status::unimplemented;

    if (jcp.with_bias) {
        if (bias_d.format_kind() == format_kind::any)
            CHECK(memory_desc_init_by_tag(bias_md, format_tag::x));
    }

    jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
    jcp.dst_dt = cd.dst_desc.data_type;

    jcp.typesize_in = types::data_type_size(src_d.data_type());
    jcp.typesize_out = types::data_type_size(dst_d.data_type());
    jcp.typesize_bia
            = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;

    jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block);
    jcp.nb_ic = jcp.ic / jcp.ic_block;
    jcp.nb_oc = jcp.oc / jcp.oc_block;

    // Try to use 4 channel-groups at a time to avoid false sharing (depthwise)
    int nb_ch_blocking = 4;
    for (/* init above */; nb_ch_blocking > 1; nb_ch_blocking--)
        if (jcp.nb_ch % nb_ch_blocking == 0) break;
    jcp.nb_ch_blocking = jcp.is_depthwise ? nb_ch_blocking : 1;
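    // Example: ngroups = 96 with ch_block = 16 gives nb_ch = 6, so the loop
    // settles on nb_ch_blocking = 3 (6 % 4 != 0 but 6 % 3 == 0); if no divisor
    // in {4, 3, 2} matches, it falls back to one channel block at a time.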

    // If OC blocking is incommensurate with the number of OC blocks (general
    // requirement for all convolutions), or if it results in an unrolling
    // factor smaller than the left padding (special requirement for SSD:fc6),
    // then search for a smaller OC blocking that satisfies both constraints.
    auto is_oc_blocking_ok = [&](int block) {
        int ur_w = nstl::min(jcp.ow, jcp.max_regs_ur / (block + 1));
        return jcp.nb_oc % block == 0 && jcp.l_pad <= ur_w
                && jcp.ow % ur_w != 1;
    };

    // choose nb_oc work chunk size for distribution within threads
    int max_threading_nb_oc_chunk = 4;
    jcp.nb_oc_blocking_thr_chunk
            = nstl::min(max_threading_nb_oc_chunk, jcp.nb_oc);
    for (; jcp.nb_oc_blocking_thr_chunk > 1; jcp.nb_oc_blocking_thr_chunk--) {
        if (is_oc_blocking_ok(jcp.nb_oc_blocking_thr_chunk)) break;
    }

    // choose oc blocking for computational kernel
    jcp.nb_oc_blocking = jcp.nb_oc_blocking_thr_chunk;
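    // Example: nb_oc = 8 first tries a chunk of 4 output-channel blocks; the
    // chunk is kept only if nb_oc divides evenly by it, l_pad fits into the
    // resulting ur_w, and the ow % ur_w == 1 corner case is avoided, otherwise
    // it shrinks toward 1.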

    if (jcp.is_resrc_depthwise)
        jcp.ur_w = (jcp.max_regs_ur - jcp.kw + jcp.stride_w)
                / (jcp.nb_ch_blocking + jcp.stride_w);
    else
        jcp.ur_w = jcp.max_regs_ur
                / (jcp.is_depthwise ? jcp.nb_ch_blocking
                                    : jcp.nb_oc_blocking + 1);
    if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
    if (!jcp.is_depthwise && jcp.ur_w < jcp.ow) {
        // tune ur_w such that penultimate ur_w block (including ur_w_tail)
        // does not read past the end of src
        const int broadcast_size = 4;
        if (jcp.ic_without_padding % broadcast_size != 0) {
            while (jcp.ur_w > 0) {
                int last_block_size = (jcp.ow % jcp.ur_w == 0)
                        ? jcp.ur_w
                        : jcp.ow % jcp.ur_w;
                int penultimate_iw_index
                        = (jcp.ow - 1 - last_block_size) * jcp.stride_w
                        + (jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad;
                int penultimate_iw_leeway = (jcp.iw - 1 - penultimate_iw_index)
                                * jcp.ic_without_padding
                        + jcp.ic_without_padding % broadcast_size;
                if (penultimate_iw_leeway >= broadcast_size) break;
                --jcp.ur_w;
            }
            if (jcp.ur_w == 0) // no satisfactory ur_w could be found
                return status::unimplemented;
        }
    }
    jcp.ur_w_tail = jcp.ow % jcp.ur_w;
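    // At this point ur_w is the width unroll of a full block and ur_w_tail
    // (possibly 0) is the leftover, e.g. ow = 28 with ur_w = 6 gives four
    // full blocks plus a tail of 4.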

    jcp.ow_block = jcp.ow;
    int base_work_amount = jcp.mb * jcp.nb_ch * jcp.od * jcp.oh
            * (jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk);
    float best_thr_eff
            = (float)base_work_amount / rnd_up(base_work_amount, jcp.nthr);
    int max_nb_ow = div_up(jcp.ow, 2 * jcp.ur_w);
    for (int nb_ow = 1; nb_ow <= max_nb_ow; nb_ow++) {
        int ow_block
                = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), jcp.ur_w), jcp.ow);
        if (ow_block < jcp.nb_oc_blocking_thr_chunk * jcp.oc_block
                && best_thr_eff > 0.8f)
            break;
        if (div_up(jcp.ow, ow_block) != nb_ow) continue;
        auto work_amount = base_work_amount * nb_ow;
        float thr_eff = (float)work_amount / rnd_up(work_amount, jcp.nthr);
        if (ow_block >= 2 * jcp.ur_w && thr_eff > 1.1f * best_thr_eff) {
            jcp.ow_block = ow_block;
            best_thr_eff = thr_eff;
        }
        if (best_thr_eff > 0.9f) break;
    }
    jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);
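    // Example: with nthr = 16 and base_work_amount = 24, splitting ow into
    // nb_ow = 2 blocks raises the work amount to 48, which divides evenly
    // over the threads (thr_eff goes from 24/32 = 0.75 to 48/48 = 1.0), so
    // the smaller ow_block is adopted as long as it still spans at least
    // 2 * ur_w.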

    bool args_ok = jcp.oc % jcp.oc_block == 0 && jcp.l_pad <= jcp.ur_w;
    if (!args_ok) return status::unimplemented;

    int r_pad_no_tail = nstl::max(0,
            calculate_end_padding(jcp.l_pad, jcp.ow - jcp.ur_w_tail, jcp.iw,
                    jcp.stride_w, ext_kw));
    if (r_pad_no_tail > jcp.ur_w) return status::unimplemented;
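    // The kernel only special-cases right padding inside the last full ur_w
    // block; if the padding that remains once the tail block is peeled off is
    // wider than one block, the case is rejected.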

    pick_loop_order(jcp, jcp.nthr);

    const auto &oscales = attr.output_scales_;
    jcp.is_oc_scale = oscales.mask_ == 1 << 1;

    // only common and per-oc-channel scales are supported
    const bool oscales_ok = one_of(oscales.mask_, 0, 1 << 1);
    if (!oscales_ok) return status::unimplemented;
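    // The scale mask follows the usual oneDNN convention: bit i set means a
    // separate scale per element of dimension i, so 0 is one scale for the
    // whole tensor and 1 << 1 is one scale per output channel (dst dim 1).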

    jcp.wei_adj_scale
            = (weights_d.extra().flags & memory_extra_flags::scale_adjust)
            ? weights_d.extra().scale_adjust
            : 1.f;

    return status::success;
}

void jit_sve_512_x8s8s32x_fwd_kernel::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
        const primitive_attr_t &attr) {}

} // namespace aarch64
} // namespace cpu
} // namespace impl
} // namespace dnnl