/*******************************************************************************
* Copyright 2021 Intel Corporation
* Copyright 2021 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/memory.hpp"
#include "common/memory_tracking.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/aarch64/jit_sve_512_x8s8s32x_conv_kernel.hpp"

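// Byte offset of a field within the jit_conv_call_s argument structure;
// the kernel loads its runtime parameters from reg_param1 at these offsets.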
#define GET_OFF(field) static_cast<int32_t>(offsetof(jit_conv_call_s, field))

namespace dnnl {
namespace impl {
namespace cpu {
namespace aarch64 {

using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;
using namespace dnnl::impl::data_type;

namespace {
void pick_loop_order(jit_conv_conf_t &jcp, int nthr) {
    jcp.loop_order = loop_cwgn;
    if (jcp.ngroups > 1) {
        jcp.loop_order = loop_ngcw;
        if (jcp.mb < nthr)
            jcp.loop_order = jcp.ndims == 3 ? loop_nwcg : loop_nhwcg;
    }
}
} // namespace

void jit_sve_512_x8s8s32x_fwd_kernel::prepare_output(int ur_w) {
    int nb_oc_block
            = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
    for (int k = 0; k < nb_oc_block; k++)
        for (int j = 0; j < ur_w; j++) {
            auto vmm = vmm_out(j, k);
            eor(vmm.d, vmm.d, vmm.d);
        }
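    // With u8 src (jcp.signed_input == false), inputs are shifted by 128
    // into the s8 range before the signed sdot. vmm_shift holds the shift
    // constant: 128 per 32-bit lane for the widened depthwise path, -128
    // per byte otherwise (the same 0x80 byte pattern). The bias this
    // introduces is removed in store_output via the compensation term.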
    if (!jcp.signed_input) {
        eor(reg_scratch, reg_scratch, reg_scratch);
        if (jcp.is_depthwise && !jcp.is_fast_depthwise) {
            mov_imm(WReg(reg_tmp0_imm.getIdx()), 128);
            dup(vmm_shift.s, WReg(reg_tmp0_imm.getIdx()));
        } else {
            dup(vmm_shift.b, -128);
        }
    }
}

void jit_sve_512_x8s8s32x_fwd_kernel::cvt2ps(data_type_t type_in,
        const ZReg vmm_in, const XReg reg_base, const int offset,
        bool mask_flag) {

    auto vmm = vmm_in;
    auto reg_addr = get_comp_addr_reg(reg_base, offset);
    switch (type_in) {
        case data_type::f32:
        case data_type::s32:
            if (mask_flag)
                ld1w(vmm.s, ktail_mask / T_z, ptr(reg_addr));
            else
                ld1w(vmm.s, mask_all_one, ptr(reg_addr));
            break;
        case data_type::s8:
            sub(reg_stack, reg_stack, 64);
            str(vmm_tmp, ptr(reg_stack));
            vmm_load_src(vmm_tmp, reg_addr, mask_flag);
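            // Widen the loaded bytes to 32-bit lanes: each zip1 interleaves
            // the vector with itself, so after two passes every source byte
            // sits in the low byte of its own 32-bit lane; sxtb/uxtb then
            // extend that byte (the duplicated upper bytes are ignored).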
            zip1(vmm_tmp.b, vmm_tmp.b, vmm_tmp.b);
            zip1(vmm_tmp.h, vmm_tmp.h, vmm_tmp.h);
            sxtb(vmm.s, mask_all_one / T_m, vmm_tmp.s);
            if (mask_flag) {
                not_(mask_tmp.b, mask_all_one.b, ktail_mask.b);
                mov(vmm.s, mask_tmp / T_m, 0);
            }
            ldr(vmm_tmp, ptr(reg_stack));
            add(reg_stack, reg_stack, 64);
            break;
        case data_type::u8:
            sub(reg_stack, reg_stack, 64);
            str(vmm_tmp, ptr(reg_stack));
            vmm_load_src(vmm_tmp, reg_addr, mask_flag);
            zip1(vmm_tmp.b, vmm_tmp.b, vmm_tmp.b);
            zip1(vmm_tmp.h, vmm_tmp.h, vmm_tmp.h);
            uxtb(vmm.s, mask_all_one / T_m, vmm_tmp.s);
            if (mask_flag) {
                not_(mask_tmp.b, mask_all_one.b, ktail_mask.b);
                mov(vmm.s, mask_tmp / T_m, 0);
            }
            ldr(vmm_tmp, ptr(reg_stack));
            add(reg_stack, reg_stack, 64);
            break;
        default: assert(!"unsupported data type");
    }
    if (type_in != data_type::f32) scvtf(vmm_in.s, mask_all_one, vmm_in.s);
}

void jit_sve_512_x8s8s32x_fwd_kernel::store_output(
        int ur_w, bool last_oc_block_flag) {
    int nb_oc_block
            = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
    int oc_block = jcp.is_depthwise ? jcp.ch_block : jcp.oc_block;

    ldr(reg_bias, ptr(reg_param1, GET_OFF(bias)));
    ldr(reg_ptr_scales, ptr(reg_param1, GET_OFF(scales)));
    if (!jcp.signed_input)
        ldr(reg_compensation, ptr(reg_param1, GET_OFF(compensation)));

    const auto &p = attr_.post_ops_;
    const int sum_idx = p.find(primitive_kind::sum);
    const float *p_sum_scale = nullptr;
    if (sum_idx != -1) {
        const auto &p_entry = p.entry_[sum_idx];
        p_sum_scale = &p_entry.sum.scale;
    }

    if (p_sum_scale && *p_sum_scale != 1.f)
        mov_imm(reg_ptr_sum_scale, (size_t)p_sum_scale);

    for (int k = 0; k < nb_oc_block; k++) {
        const bool mask_flag
                = last_oc_block_flag && k == nb_oc_block - 1 && mask_gflag;
        int scale_offset = jcp.is_oc_scale * (sizeof(float) * k * oc_block);
        if (jcp.with_bias) {
            int bias_offset = jcp.typesize_bia * k * oc_block;

            cvt2ps(jcp.bia_dt, vmm_bias, reg_bias, bias_offset, mask_flag);
        }
        if (!jcp.signed_input) {
            int comp_offset = sizeof(int32_t) * k * oc_block;

            cvt2ps(data_type::s32, vmm_comp, reg_compensation, comp_offset,
                    mask_flag);
        }
        /* optimization under specific conditions: preload scale_offset data */
        if (!jcp.is_fast_depthwise && jcp.signed_input) {
            auto reg_addr = get_comp_addr_reg(reg_ptr_scales, scale_offset);
            ld1w(vmm_pre_load.s, mask_all_one, ptr(reg_addr));
        }
        /* add to accum: compensation, bias and permute */
        for (int j = 0; j < ur_w; j++) {
            auto vmm = vmm_out(j, k);
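            // Fast depthwise: emulate a cross-lane 32-bit permute of zmm by
            // zmm_permute (indices masked to 0..15). For each source lane i,
            // cmpeq selects the lanes whose permute index equals i and the
            // predicated mov broadcasts zmm.s[i] into them.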
            if (jcp.is_fast_depthwise) {
                auto zmm = zmm_out(j, k);
                auto zmm_tmp1 = ZReg(31);
                auto zmm_tmp2 = ZReg(30);
                auto zmm_tmp3 = ZReg(29);
                sub(reg_stack, reg_stack, 64);
                str(zmm_tmp1, ptr(reg_stack));
                sub(reg_stack, reg_stack, 64);
                str(zmm_tmp2, ptr(reg_stack));
                sub(reg_stack, reg_stack, 64);
                str(zmm_tmp3, ptr(reg_stack));
                mov(zmm_tmp1.s, 15);
                and_(zmm_tmp1.b, mask_all_one, zmm_permute.b);
                for (int i = 0; i < 16; i++) {
                    cmpeq(mask_tmp.s, mask_all_one, zmm_tmp1.s, i);
                    dup(zmm_tmp2.s, zmm.s[i]);
                    mov(zmm_tmp3.s, mask_tmp / T_m, zmm_tmp2.s);
                }
                mov(zmm.d, zmm_tmp3.d);
                ldr(zmm_tmp3, ptr(reg_stack));
                add(reg_stack, reg_stack, 64);
                ldr(zmm_tmp2, ptr(reg_stack));
                add(reg_stack, reg_stack, 64);
                ldr(zmm_tmp1, ptr(reg_stack));
                add(reg_stack, reg_stack, 64);
            }
            scvtf(vmm.s, mask_all_one, vmm.s);
            if (!jcp.signed_input) fsub(vmm.s, vmm.s, vmm_comp.s);
            if (jcp.with_bias) fadd(vmm.s, vmm.s, vmm_bias.s);

            if (!jcp.is_fast_depthwise && jcp.signed_input) {
                /* multiply by the scales preloaded above */
                fmul(vmm.s, vmm.s, vmm_pre_load.s);
                if (mask_flag) {
                    not_(mask_tmp.b, mask_all_one.b, ktail_mask.b);
                    mov(vmm.s, mask_tmp / T_m, 0);
                }
            } else {
                auto reg_addr = get_comp_addr_reg(reg_ptr_scales, scale_offset);
                sub(reg_stack, reg_stack, 64);
                str(vmm_tmp, ptr(reg_stack));
                ld1w(vmm_tmp.s, mask_all_one, ptr(reg_addr));
                fmul(vmm.s, vmm.s, vmm_tmp.s);
                ldr(vmm_tmp, ptr(reg_stack));
                add(reg_stack, reg_stack, 64);
                if (mask_flag) {
                    not_(mask_tmp.b, mask_all_one.b, ktail_mask.b);
                    mov(vmm.s, mask_tmp / T_m, 0);
                }
            }
        }
    }

    /* Do post-ops */
    if (p_sum_scale) { // post_op: sum
        for (int k = 0; k < nb_oc_block; k++) {
            const bool mask_flag
                    = last_oc_block_flag && k == nb_oc_block - 1 && mask_gflag;
            for (int j = 0; j < ur_w; j++) {
                int aux_output_offset = jcp.typesize_out
                        * (k * oc_block
                                + j * jcp.oc_without_padding * jcp.ngroups);
                auto vmm = vmm_out(j, k);
                cvt2ps(jcp.dst_dt, vmm_prev_dst, reg_out, aux_output_offset,
                        mask_flag);
                if (*p_sum_scale == 1.f) {
                    fadd(vmm.s, vmm.s, vmm_prev_dst.s);
                } else {
                    sub(reg_stack, reg_stack, 64);
                    str(vmm_tmp, ptr(reg_stack));
                    ld1rw(vmm_tmp.s, mask_all_one / T_z,
                            ptr(reg_ptr_sum_scale));
                    fmla(vmm.s, mask_all_one / T_m, vmm_prev_dst.s, vmm_tmp.s);
                    ldr(vmm_tmp, ptr(reg_stack));
                    add(reg_stack, reg_stack, 64);
                }
            }
        }
    }

    // Properly saturate the accumulators for integer datatypes
    if (one_of(jcp.dst_dt, data_type::u8, data_type::s8, data_type::s32)) {
        if (jcp.dst_dt == data_type::u8) {
            eor(vmm_zero.d, vmm_zero.d, vmm_zero.d);
        }
        float saturation_ubound = types::max_value<float>(jcp.dst_dt);
        mov_imm(aux_reg_saturation, float2int(saturation_ubound));
        dup(vmm_saturation.s, WReg(aux_reg_saturation.getIdx()));

        for (int k = 0; k < nb_oc_block; k++) {
            for (int j = 0; j < ur_w; j++) {
                auto vmm = vmm_out(j, k);
                if (jcp.dst_dt == data_type::u8) {
                    fmaxnm(vmm.s, mask_all_one, vmm_zero.s);
                    fmax(vmm.s, mask_all_one, vmm_zero.s);
                }
                fminnm(vmm.s, mask_all_one, vmm_saturation.s);
                fmin(vmm.s, mask_all_one, vmm_saturation.s);

                frintn(vmm.s, mask_all_one, vmm.s);
                fcvtzs(vmm.s, mask_all_one, vmm.s);
            }
        }
    }

    /* write out register to output_addr */
    for (int k = 0; k < nb_oc_block; k++) {
        const bool mask_flag
                = last_oc_block_flag && k == nb_oc_block - 1 && mask_gflag;
        for (int j = 0; j < ur_w; j++) {
            int aux_output_offset = jcp.typesize_out
                    * (k * oc_block + j * jcp.oc_without_padding * jcp.ngroups);

            auto base = reg_out;
            auto re = get_offset(aux_output_offset);

            auto reg_tmp_adr = ((j % 4) == 0) ? reg_tmp0_adr
                    : ((j % 4) == 1)          ? reg_tmp1_adr
                    : ((j % 4) == 2)          ? reg_tmp2_adr
                                              : reg_tmp3_adr;
            auto reg_tmp_imm = ((j % 4) == 0) ? reg_tmp0_imm
                    : ((j % 4) == 1)          ? reg_tmp1_imm
                    : ((j % 4) == 2)          ? reg_tmp2_imm
                                              : reg_tmp3_imm;
            add_imm(reg_tmp_adr, base, re, reg_tmp_imm);

            auto vmm = vmm_out(j, k);

            auto _mask = mask_flag ? ktail_mask : mask_all_one;
            switch (jcp.dst_dt) {
                case data_type::f32:
                case data_type::s32:
                    st1w(vmm.s, _mask, ptr(reg_tmp_adr));
                    break;
                case data_type::s8:
                    smin(vmm.s, 127);
                    smax(vmm.s, -128);
                    st1b(vmm.s, _mask, ptr(reg_tmp_adr));
                    break;
                case data_type::u8:
                    umin(vmm.s, 255);
                    st1b(vmm.s, _mask, ptr(reg_tmp_adr));
                    break;
                default: assert(!"unknown dst_dt");
            }
        }
    }
}

void jit_sve_512_x8s8s32x_fwd_kernel::compute_ker_dw(int ur_w, int pad_l,
        int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {

    if (sve_len_ != 64)
        assert(!"invalid group blocking for depthwise convolution");

    auto input_spatial_index = [=](int oi, int ki) {
        return (ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l);
    };

    auto input_offset2 = [=](int ii, int ci) {
        if (jcp.is_fused_conv)
            return jcp.typesize_in
                    * (ii * jcp.dw_conv_buffer_oc + ci * jcp.ch_block);
        else
            return jcp.typesize_in * (ii * jcp.ngroups + ci * jcp.ch_block);
    };

    auto input_offset3 = [=](int oi, int ci, int ki) {
        return jcp.typesize_in * input_offset2(input_spatial_index(oi, ki), ci);
    };

    auto kernel_offset = [=](int ci, int ki) {
        return jcp.typesize_in * ((ci * jcp.kh * jcp.kw + ki) * jcp.ch_block);
    };

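    // sdot computes, for every 32-bit lane, the dot product of four signed
    // bytes of vreg_src with four signed bytes of vreg_wei and accumulates
    // the result into the s32 lane of vreg_acc.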
    auto compute = [=](ZReg vreg_acc, ZReg vreg_wei, ZReg vreg_src) {
        sdot(vreg_acc.s, vreg_src.b, vreg_wei.b);
    };

    int ii_start = 0;
    int ii_end = -1;
    if (jcp.is_resrc_depthwise && !h_padded) {
        // find bounds of input spatial indices
        bool first = true;
        for (int ki = 0; ki < jcp.kw; ki++) {
            int oi_start = get_ow_start(ki, pad_l);
            int oi_end = get_ow_end(ur_w, ki, pad_r);
            for (int oi = oi_start; oi < oi_end; oi++) {
                int ii = input_spatial_index(oi, ki);
                if (first || ii < ii_start) ii_start = ii;
                if (first || ii > ii_end) ii_end = ii;
                first = false;
            }
        }
    }

    if (!jcp.signed_input) {
        eor(zmm_shifted_zero.d, zmm_shifted_zero.d, zmm_shifted_zero.d);
        sub(zmm_shifted_zero.b, zmm_shifted_zero.b, vmm_shift.b);
    }

    for (int ci = 0; ci < jcp.nb_ch_blocking; ci++) {
        const bool mask_flag = last_ic_block_flag != no_last_block
                && ci == jcp.nb_ch_blocking - 1;
        if (jcp.is_resrc_depthwise && !h_padded) {
            // now we can load input once and reuse up to jcp.kw times
            for (int ii = ii_start; ii <= ii_end; ii++) {
                int aux_input_offset = input_offset2(ii, ci);
                auto zmm_inp_tmp = zmm_inp(ii, jcp.nb_ch_blocking);
                auto zmm_inp_msk = zmm_inp_tmp;
                if (jcp.is_fast_depthwise) {
                    assert(!mask_flag);
                    auto reg_addr
                            = get_comp_addr_reg(aux_reg_inp, aux_input_offset);
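                    // Emulate a 128-bit broadcast: the 16-byte ldr zeroes the
                    // upper part of the Z register, then the two splice passes
                    // (VL2/VL4 predicates on .d) replicate the low quadword
                    // across all four 128-bit quadrants.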
                    ldr(QReg(zmm_inp_msk.getIdx()), ptr(reg_addr));
                    ptrue(mask_tmp.d, VL2);
                    splice(zmm_inp_msk.d, mask_tmp.d, zmm_inp_msk.d);
                    ptrue(mask_tmp.d, VL4);
                    splice(zmm_inp_msk.d, mask_tmp.d, zmm_inp_msk.d);
                } else {
                    auto reg_addr
                            = get_comp_addr_reg(aux_reg_inp, aux_input_offset);
                    auto zmm_tmp = ZReg(31);
                    sub(reg_stack, reg_stack, 64);
                    str(zmm_tmp, ptr(reg_stack));
                    if (mask_flag) {
                        eor(mask_tmp.b, mask_all_one, mask_tmp.b, mask_tmp.b);
                        eor(mask_tmp2.b, mask_all_one, mask_tmp2.b,
                                mask_tmp2.b);
                        uzp1(mask_tmp.h, ktail_mask.h, mask_tmp.h);
                        uzp1(mask_tmp.b, mask_tmp.b, mask_tmp2.b);
                    } else {
                        ptrue(mask_tmp.b, VL16);
                    }
                    ld1b(zmm_tmp.b, mask_tmp, ptr(reg_addr));
                    zip1(zmm_tmp.b, zmm_tmp.b, zmm_tmp.b);
                    zip1(zmm_tmp.h, zmm_tmp.h, zmm_tmp.h);
                    uxtb(zmm_inp_msk.s, mask_all_one / T_m, zmm_tmp.s);
                    if (mask_flag) {
                        not_(mask_tmp.b, mask_all_one.b, ktail_mask.b);
                        mov(zmm_inp_msk.s, mask_tmp / T_m, 0);
                    }
                    ldr(zmm_tmp, ptr(reg_stack));
                    add(reg_stack, reg_stack, 64);
                }
                if (!jcp.signed_input)
                    sub(zmm_inp_tmp.b, zmm_inp_tmp.b, vmm_shift.b);
            }
        }
        for (int ki = 0; ki < jcp.kw; ki++) {
            int aux_kernel_offset = kernel_offset(ci, ki);
            if (jcp.is_fast_depthwise) {
                auto reg_addr
                        = get_comp_addr_reg(aux_reg_ker, aux_kernel_offset);
                ldr(QReg(zmm_wei.getIdx()), ptr(reg_addr));
                ptrue(mask_tmp.d, VL2);
                splice(zmm_wei.d, mask_tmp.d, zmm_wei.d);
                ptrue(mask_tmp.d, VL4);
                splice(zmm_wei.d, mask_tmp.d, zmm_wei.d);
                not_(mask_tmp.b, mask_all_one, kblend_mask.b);
                mov(zmm_wei.b, kblend_mask / T_m, zmm_wei.b);
                mov(zmm_wei.b, mask_tmp / T_m, 0);
            } else {
                auto reg_addr
                        = get_comp_addr_reg(aux_reg_ker, aux_kernel_offset);
                auto zmm_tmp = ZReg(30);
                sub(reg_stack, reg_stack, 64);
                str(zmm_tmp, ptr(reg_stack));
                ldr(QReg(zmm_tmp.getIdx()), ptr(reg_addr));
                zip1(zmm_tmp.b, zmm_tmp.b, zmm_tmp.b);
                zip1(zmm_tmp.h, zmm_tmp.h, zmm_tmp.h);
                sxtb(zmm_wei.s, mask_all_one / T_m, zmm_tmp.s);
                ldr(zmm_tmp, ptr(reg_stack));
                add(reg_stack, reg_stack, 64);
            }
            if (h_padded) {
                assert(!jcp.signed_input);
                for (int oi = 0; oi < ur_w; oi++)
                    compute(zmm_out(oi, ci), zmm_wei, zmm_shifted_zero);
            } else {
                auto r_zmm_src = zmm_src;
                int oi_start = get_ow_start(ki, pad_l);
                int oi_end = get_ow_end(ur_w, ki, pad_r);
                int start_ = !jcp.signed_input ? 0 : oi_start;
                int end_ = !jcp.signed_input ? ur_w : oi_end;
                for (int oi = start_; oi < end_; oi++) {
                    if (oi >= oi_start && oi < oi_end) {
                        if (jcp.is_resrc_depthwise) {
                            int ii = input_spatial_index(oi, ki);
                            zmm_src = zmm_inp(ii, jcp.nb_ch_blocking);
                        } else {
                            int aux_input_offset = input_offset3(oi, ci, ki);
                            if (jcp.is_fast_depthwise) {
                                assert(!mask_flag);
                                auto reg_addr = get_comp_addr_reg(
                                        aux_reg_inp, aux_input_offset);
                                ldr(QReg(r_zmm_src.getIdx()), ptr(reg_addr));
                                ptrue(mask_tmp.d, VL2);
                                splice(r_zmm_src.d, mask_tmp.d, r_zmm_src.d);
                                ptrue(mask_tmp.d, VL4);
                                splice(r_zmm_src.d, mask_tmp.d, r_zmm_src.d);
                            } else {
                                auto reg_addr = get_comp_addr_reg(
                                        aux_reg_inp, aux_input_offset);
                                auto zmm_tmp = ZReg(31);
                                sub(reg_stack, reg_stack, 64);
                                str(zmm_tmp, ptr(reg_stack));
                                if (mask_flag) {
                                    eor(mask_tmp.b, mask_all_one, mask_tmp.b,
                                            mask_tmp.b);
                                    eor(mask_tmp2.b, mask_all_one, mask_tmp2.b,
                                            mask_tmp2.b);
                                    uzp1(mask_tmp.h, ktail_mask.h, mask_tmp.h);
                                    uzp1(mask_tmp.b, mask_tmp.b, mask_tmp2.b);
                                } else {
                                    ptrue(mask_tmp.b, VL16);
                                }
                                ld1b(zmm_tmp.b, mask_tmp, ptr(reg_addr));
                                zip1(zmm_tmp.b, zmm_tmp.b, zmm_tmp.b);
                                zip1(zmm_tmp.h, zmm_tmp.h, zmm_tmp.h);
                                uxtb(r_zmm_src.s, mask_all_one / T_m,
                                        zmm_tmp.s);
                                if (mask_flag) {
                                    not_(mask_tmp.b, mask_all_one.b,
                                            ktail_mask.b);
                                    mov(r_zmm_src.s, mask_tmp / T_m, 0);
                                }
                                ldr(zmm_tmp, ptr(reg_stack));
                                add(reg_stack, reg_stack, 64);
                            }
                            if (!jcp.signed_input)
                                sub(zmm_src.b, zmm_src.b, vmm_shift.b);
                        }
                        compute(zmm_out(oi, ci), zmm_wei, zmm_src);
                    } else {
                        assert(!jcp.signed_input);
                        compute(zmm_out(oi, ci), zmm_wei, zmm_shifted_zero);
                    }
                }
            }
        }
    }
}

void jit_sve_512_x8s8s32x_fwd_kernel::compute_ker(int ur_w, int pad_l,
        int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
    if (jcp.is_depthwise)
        return compute_ker_dw(ur_w, pad_l, pad_r, last_ic_block_flag, h_padded);

    int kw = jcp.kw;
    int stride_w = jcp.stride_w;
    int ic_block = jcp.ic_block;
    int oc_block = jcp.oc_block;
    int ch_block_all = jcp.ch_block * ic_block * oc_block;

    int nb_oc_block = jcp.nb_oc_blocking;

    auto input_offset = [=](int oi, int ic, int ki) {
        return jcp.typesize_in
                * ((ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l)
                                * jcp.ic_without_padding * jcp.ngroups
                        + 4 * ic);
    };
    auto kernel_offset = [=](int ii, int ic, int ki) {
        return jcp.typesize_in
                * ((ii * jcp.nb_ic * jcp.kd * jcp.kh * jcp.kw + ki)
                                * ch_block_all
                        + 4 * ic * oc_block);
    };
    auto compute = [=](ZReg vreg_acc, ZReg vreg_wei, ZReg vreg_src) {
        sdot(ZRegS(vreg_acc.getIdx()), ZRegB(vreg_src.getIdx()),
                ZRegB(vreg_wei.getIdx()));
    };

    for (int ki = 0; ki < kw; ki++) {
        int jj_start = get_ow_start(ki, pad_l);
        int jj_end = get_ow_end(ur_w, ki, pad_r);
        int ic_tail_size = jcp.ic_without_padding % 4;
        int _start = (!jcp.signed_input) ? 0 : jj_start;
        int _end = (!jcp.signed_input) ? ur_w : jj_end;
        /* Skip the last loads of input if (ic%16)/4 < ic_block/4 */
        int icb = (last_ic_block_flag != no_last_block)
                ? div_up((jcp.ic_without_padding % ic_block), 4)
                : ic_block / 4;
        for (int ic = 0; ic < icb; ic++) {
            if (h_padded) {
                /* fill padded area with shifted values */
                auto inp = vmm_inp(0, nb_oc_block);
                eor(inp.d, inp.d, inp.d);
                sub(inp.b, inp.b, vmm_shift.b);
            } else {
                for (int jj = _start; jj < _end; jj++) {
                    int aux_input_offset = input_offset(jj, ic, ki);
                    if (jj >= jj_start && jj < jj_end) {
                        if (last_ic_block_flag == last_sp_block
                                && ic_tail_size != 0 && ic == icb - 1) {
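                            // IC tail: fewer than four input channels remain,
                            // so gather the remaining 1..3 bytes one at a
                            // time into the vector and broadcast the
                            // assembled 32-bit group to all lanes for sdot.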
                            auto xmm_tmp = VReg16B(
                                    vmm_inp(jj, nb_oc_block).getIdx());
                            for (int r = 0; r < ic_tail_size; ++r) {
                                add_imm(reg_tmp0_adr, aux_reg_inp,
                                        (aux_input_offset + r), reg_tmp0_imm);
                                ldrb(WReg(reg_tmp1_imm.getIdx()),
                                        ptr(reg_tmp0_adr));
                                ins(VReg16B(xmm_tmp.getIdx())[r],
                                        WReg(reg_tmp1_imm.getIdx()));
                            }
                            dup(vmm_inp(jj, nb_oc_block).s,
                                    ZRegS(xmm_tmp.getIdx())[0]);
                        } else {
                            auto base = aux_reg_inp;
                            auto re = get_offset(aux_input_offset);

                            if ((-0x40 <= re) && (re < 0x40) && ((re % 4) == 0))
                                ld1rw(vmm_inp(jj, nb_oc_block).s, mask_all_one,
                                        ptr(base, static_cast<int32_t>(re)));
                            else {
                                auto reg_tmp_adr = ((jj % 4) == 0)
                                        ? reg_tmp0_adr
                                        : ((jj % 4) == 1) ? reg_tmp1_adr
                                        : ((jj % 4) == 2) ? reg_tmp2_adr
                                                          : reg_tmp3_adr;
                                auto reg_tmp_imm = ((jj % 4) == 0)
                                        ? reg_tmp0_imm
                                        : ((jj % 4) == 1) ? reg_tmp1_imm
                                        : ((jj % 4) == 2) ? reg_tmp2_imm
                                                          : reg_tmp3_imm;
                                add_imm(reg_tmp_adr, base, re, reg_tmp_imm);
                                ld1rw(vmm_inp(jj, nb_oc_block).s, mask_all_one,
                                        ptr(reg_tmp_adr));
                            }
                        }
                        if (!jcp.signed_input)
                            sub(vmm_inp(jj, nb_oc_block).b,
                                    vmm_inp(jj, nb_oc_block).b, vmm_shift.b);
                    } else {
                        /* fill padded area with shifted values */
                        if (!jcp.signed_input) {
                            auto inp = vmm_inp(jj, nb_oc_block);
                            eor(inp.d, inp.d, inp.d);
                            sub(inp.b, inp.b, vmm_shift.b);
                        }
                    }
                }
            }
            for (int ii = 0; ii < nb_oc_block; ii++) {
                if (!jcp.signed_input) {
                    int aux_kernel_offset = kernel_offset(ii, ic, ki);
                    auto reg_addr
                            = get_comp_addr_reg(aux_reg_ker, aux_kernel_offset);
                    ld1w(vmm_wei.s, mask_all_one, ptr(reg_addr));
                    for (int jj = _start; jj < _end; jj++) {
                        auto inp = (h_padded == true)
                                ? vmm_inp(0, nb_oc_block)
                                : vmm_inp(jj, nb_oc_block);
                        compute(vmm_out(jj, ii), vmm_wei, inp);
                    }
                } else {
                    if (ii == 0) {
                        int aux_kernel_offset = kernel_offset(ii, ic, ki);
                        auto reg_addr = get_comp_addr_reg(
                                aux_reg_ker, aux_kernel_offset);
                        ld1w(vmm_wei.s, mask_all_one, ptr(reg_addr));
                    }
                    if ((ii + 1) < nb_oc_block) {
                        int aux_kernel_offset = kernel_offset((ii + 1), ic, ki);
                        auto _vmm_wei = ((ii % 2) == 0) ? vmm_comp : vmm_wei;
                        auto reg_addr = get_comp_addr_reg(
                                aux_reg_ker, aux_kernel_offset);
                        ld1w(_vmm_wei.s, mask_all_one, ptr(reg_addr));
                    }
                    for (int jj = _start; jj < _end; jj++) {
                        auto _vmm_wei = ((ii % 2) == 0) ? vmm_wei : vmm_comp;
                        auto inp = (h_padded == true)
                                ? vmm_inp(0, nb_oc_block)
                                : vmm_inp(jj, nb_oc_block);
                        compute(vmm_out(jj, ii), _vmm_wei, inp);
                    }
                }
            }
        }
    }
}

void jit_sve_512_x8s8s32x_fwd_kernel::kh_loop(
        int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag) {
    Label kd_label, kh_label, skip_kd_loop, skip_kh_loop;
    Label f_overflow_label, no_f_overflow_label, d_h_f_overflow_label,
            t_overflow_label, no_t_overflow_label, b_overflow_label,
            no_b_overflow_label, back_overflow_label, no_back_overflow_label,
            d_h_back_overflow_label;

    int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
    int shift_kernel_ptr = jcp.typesize_in * jcp.kw * ch_block_all;
    int shift_input_ptr
            = jcp.typesize_in * jcp.iw * jcp.ic_without_padding * jcp.ngroups;

    if (jcp.ndims == 5) {
        mov(aux_reg_ker_d, reg_ker);
        mov(aux_reg_inp_d, reg_inp);
        if (!jcp.signed_input) {
            //TODO: May be avoided when f_pad=0 and dd0
            //TODO: Potential optimization by precomputing, when kd <<< od?
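            // With u8 src the inputs carry the 128 zero-point shift, so
            // filter taps that fall entirely into front/back padding still
            // contribute and are accumulated here with h_padded = true;
            // store_output later removes the extra term via the
            // compensation values.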
            ldr(reg_ki, ptr(reg_param1, GET_OFF(f_overflow)));
            cmp(reg_ki, 0);
            b(EQ, no_f_overflow_label);
            L(f_overflow_label);
            {
                mov(aux_reg_ker, aux_reg_ker_d);
                mov_imm(reg_kj, jcp.kh);
                L(d_h_f_overflow_label);
                {
                    compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
                    adds_imm(aux_reg_ker, aux_reg_ker, shift_kernel_ptr,
                            reg_tmp0_imm);
                    subs(reg_kj, reg_kj, 1);
                    b(NE, d_h_f_overflow_label);
                }
                add_imm(aux_reg_ker_d, aux_reg_ker_d, shift_kernel_ptr * jcp.kh,
                        reg_tmp0_imm);
                subs(reg_ki, reg_ki, 1);
                b(NE, f_overflow_label);
            }
            L(no_f_overflow_label);
        }

        ldr(reg_ki, ptr(reg_param1, GET_OFF(kd_padding)));
        if ((!jcp.signed_input) || (jcp.dilate_d >= jcp.id)
                || (jcp.signed_input
                        && (jcp.kd - 1) * (jcp.dilate_d + 1)
                                < nstl::max(jcp.f_pad, jcp.back_pad))) {
            cmp(reg_ki, 0);
            b(EQ, skip_kd_loop);
        }
        L(kd_label);
        mov(aux_reg_inp, aux_reg_inp_d);
        mov(aux_reg_ker, aux_reg_ker_d);
    } else {
        if (jcp.is_fused_conv) {
            mov(aux_reg_inp_buffer_ptr, reg_inp_buffer_ptr);
        } else {
            mov(aux_reg_inp, reg_inp);
        }
        mov(aux_reg_ker, reg_ker);
    }

    if (!jcp.signed_input && jcp.ndims > 3) {
        ldr(reg_overflow, ptr(reg_param1, GET_OFF(t_overflow)));
        cmp(reg_overflow, 0);
        b(EQ, no_t_overflow_label);
        L(t_overflow_label);
        {
            compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);

            adds_imm(aux_reg_ker, aux_reg_ker, shift_kernel_ptr, reg_tmp0_imm);
            subs(reg_overflow, reg_overflow, 1);
            cmp(reg_overflow, 0);
            b(GT, t_overflow_label);
        }
        L(no_t_overflow_label);
    }
    ldr(reg_kj, ptr(reg_param1, GET_OFF(kh_padding)));
    if ((!jcp.signed_input) || (jcp.dilate_h >= jcp.ih)
            || (jcp.signed_input
                    && (jcp.kh - 1) * (jcp.dilate_h + 1)
                            < nstl::max(jcp.t_pad, jcp.b_pad))) {
        cmp(reg_kj, 0);
        b(EQ, skip_kh_loop);
    }
    L(kh_label);
    {
        if (jcp.is_fused_conv) {
            ldr(aux_reg_inp, ptr(aux_reg_inp_buffer_ptr));
            add(aux_reg_inp, aux_reg_inp, reg_inp);
        }
        compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, false);

        adds_imm(aux_reg_ker, aux_reg_ker, shift_kernel_ptr, reg_tmp0_imm);
        if (jcp.is_fused_conv) {
            adds_imm(aux_reg_inp_buffer_ptr, aux_reg_inp_buffer_ptr,
                    sizeof(void *), reg_tmp0_imm);
        } else {
            adds_imm(aux_reg_inp, aux_reg_inp,
                    shift_input_ptr * (jcp.dilate_h + 1), reg_tmp0_imm);
        }
        subs(reg_kj, reg_kj, 1);
        cmp(reg_kj, 0);
        b(GT, kh_label);
    }
    L(skip_kh_loop);
    if (!jcp.signed_input && jcp.ndims > 3) {
        ldr(reg_overflow, ptr(reg_param1, GET_OFF(b_overflow)));
        cmp(reg_overflow, 0);
        b(EQ, no_b_overflow_label);
        L(b_overflow_label);
        {
            compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);

            adds_imm(aux_reg_ker, aux_reg_ker, shift_kernel_ptr, reg_tmp0_imm);
            subs(reg_overflow, reg_overflow, 1);
            cmp(reg_overflow, 0);
            b(GT, b_overflow_label);
        }
        L(no_b_overflow_label);
    }

    if (jcp.ndims == 5) {
        adds_imm(aux_reg_inp_d, aux_reg_inp_d,
                shift_input_ptr * jcp.ih * (jcp.dilate_d + 1), reg_tmp0_imm);
        adds_imm(aux_reg_ker_d, aux_reg_ker_d, shift_kernel_ptr * jcp.kh,
                reg_tmp0_imm);
        subs(reg_ki, reg_ki, 1);
        b(NE, kd_label);

        L(skip_kd_loop);
        if (!jcp.signed_input) {
            ldr(reg_ki, ptr(reg_param1, GET_OFF(back_overflow)));
            cmp(reg_ki, 0);
            b(EQ, no_back_overflow_label);
            L(back_overflow_label);
            {
                mov(aux_reg_ker, aux_reg_ker_d);
                mov(reg_kj, jcp.kh);
                L(d_h_back_overflow_label);
                {
                    compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
                    adds_imm(aux_reg_ker, aux_reg_ker, shift_kernel_ptr,
                            reg_tmp0_imm);
                    subs(reg_kj, reg_kj, 1);
                    b(NE, d_h_back_overflow_label);
                }
                adds_imm(aux_reg_ker_d, aux_reg_ker_d,
                        shift_kernel_ptr * jcp.kh, reg_tmp0_imm);
                subs(reg_ki, reg_ki, 1);
                b(NE, back_overflow_label);
            }
            L(no_back_overflow_label);
        }
    }
}

void jit_sve_512_x8s8s32x_fwd_kernel::icb_loop(
        int ur_w, int pad_l, int pad_r, bool is_last_sp_block) {
    prepare_output(ur_w);

    // IC loop
    Label icb_label;
    mov_imm(reg_icb, jcp.nb_ic);
    L(icb_label);
    if (jcp.ngroups % jcp.ch_block != 0 || jcp.ic_without_padding != jcp.ic) {
        Label common_ker, end_ker;

        if (jcp.is_depthwise)
            cmp(reg_oc_blocks, jcp.nb_ch - jcp.nb_ch_blocking);
        else
            cmp(reg_icb, 1); // The last IC block
        b(NE, common_ker);

        kh_loop(ur_w, pad_l, pad_r,
                is_last_sp_block ? last_sp_block : last_ic_block);
        b(end_ker);

        L(common_ker);
        kh_loop(ur_w, pad_l, pad_r, no_last_block);

        L(end_ker);
    } else {
        kh_loop(ur_w, pad_l, pad_r, no_last_block);
    }
    // End of IC Loop
    int inp_step = jcp.ic_block;
    int ker_step = jcp.kd * jcp.kh * jcp.kw * jcp.oc_block * jcp.ic_block;
    adds_imm(reg_inp, reg_inp, jcp.typesize_in * inp_step, reg_tmp0_imm);
    adds_imm(reg_ker, reg_ker, jcp.typesize_in * ker_step, reg_tmp0_imm);

    subs(reg_icb, reg_icb, 1);
    cmp(reg_icb, 0);
    b(GT, icb_label);

    subs_imm(reg_inp, reg_inp, jcp.typesize_in * inp_step * jcp.nb_ic,
            reg_tmp0_imm);
    subs_imm(reg_ker, reg_ker, jcp.typesize_in * ker_step * jcp.nb_ic,
            reg_tmp0_imm);

    if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
        Label common_store, end_store;

        if (jcp.is_depthwise)
            cmp(reg_oc_blocks, jcp.nb_ch - jcp.nb_ch_blocking);
        else
            cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);

        b(NE, common_store);

        store_output(ur_w, true); // last oc block
        b(end_store);

        L(common_store);
        store_output(ur_w, false);

        L(end_store);
    } else {
        store_output(ur_w, false);
    }
}

void jit_sve_512_x8s8s32x_fwd_kernel::vmm_mask_all_one() {
    mask_gflag = false;
    if (sve_len_ == 64) {
        mask_gflag = true;
        ptrue(mask_all_one.b);
    } else if (sve_len_ == 32) {
        ptrue(mask_all_one.b, VL32);
    } else if (sve_len_ == 16) {
        ptrue(mask_all_one.b, VL16);
    } else {
        assert(!"unreachable");
    }
}

void jit_sve_512_x8s8s32x_fwd_kernel::vmm_load_src(
        ZReg src, XReg reg_addr, bool mask_flag) {
    if (mask_flag) {
        eor(mask_tmp.b, mask_all_one, mask_tmp.b, mask_tmp.b);
        eor(mask_tmp2.b, mask_all_one, mask_tmp2.b, mask_tmp2.b);
        uzp1(mask_tmp.h, ktail_mask.h, mask_tmp.h);
        uzp1(mask_tmp.b, mask_tmp.b, mask_tmp2.b);
    } else {
        if (sve_len_ == 64)
            ptrue(mask_tmp.b, VL16);
        else if (sve_len_ == 32)
            ptrue(mask_tmp.b, VL8);
        else if (sve_len_ == 16)
            ptrue(mask_tmp.b, VL4);
        else
            assert(!"unreachable");
    }

    ld1b(src.b, mask_tmp, ptr(reg_addr));
}

void jit_sve_512_x8s8s32x_fwd_kernel::generate() {
    Label permute_index_table;
    int in_ic_shift = jcp.is_fused_conv ? jcp.dw_conv_buffer_oc
                                        : jcp.ic_without_padding * jcp.ngroups;
    int inp_shift_pad = jcp.typesize_in * (jcp.ur_w * jcp.stride_w - jcp.l_pad)
            * in_ic_shift;
    int inp_shift_pad_second_block
            = -1 * jcp.typesize_in * jcp.l_pad * in_ic_shift;
    int inp_shift = jcp.typesize_in * (jcp.ur_w * jcp.stride_w * in_ic_shift);
    int out_shift = jcp.typesize_out
            * (jcp.ur_w * jcp.oc_without_padding * jcp.ngroups);
    preamble();

    vmm_mask_all_one();

    if (jcp.is_depthwise) {
        int idx = jcp.max_regs_ur - 1;
        if (!jcp.is_resrc_depthwise) zmm_src = ZReg(++idx);
        if (jcp.is_fast_depthwise) zmm_permute = ZReg(++idx);
        if (!jcp.signed_input) zmm_shifted_zero = ZReg(++idx);
        // due to extra register used for shifts and compensations
        // and/or saturation, we increment by one more
        if (!jcp.signed_input || jcp.need_saturation) ++idx;
        assert(idx == ker_dw_reg_base_idx);
    }

    if (jcp.is_fused_conv) {
        ldr(reg_inp_buffer_ptr, ptr(reg_param1, GET_OFF(src)));
        /* In case of fused depthwise convolution, `param.src` is not a pointer
           to input, instead it points to a buffer containing pointers to
           consecutive rows of input in format wc with c=jcp.dw_conv_buffer_oc.
         */
        mov_imm(reg_inp, 0);
    } else {
        ldr(reg_inp, ptr(reg_param1, GET_OFF(src)));
    }
    ldr(reg_out, ptr(reg_param1, GET_OFF(dst)));
    ldr(reg_ker, ptr(reg_param1, GET_OFF(filt)));

    if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
        int tail_size = jcp.is_depthwise
                ? jcp.ngroups % jcp.ch_block
                : jcp.oc_without_padding % jcp.oc_block;
        int mask = (1 << tail_size) - 1;
        ldr(reg_oc_blocks, ptr(reg_param1, GET_OFF(oc_blocks)));
        auto regw_tmp = reg_oi;
        mov(regw_tmp, mask);
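        // Build ktail_mask from the bit mask: index() writes the lane
        // numbers 0..15, lsl turns them into per-lane bits (1 << lane), and
        // cmpne activates exactly the lanes whose bit is set in `mask`,
        // i.e. the first tail_size 32-bit lanes.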
        auto vmm_tmp1 = ZReg(31);
        auto vmm_tmp2 = ZReg(30);
        index(vmm_tmp1.s, 0, 1);
        mov(vmm_tmp2.s, 1);
        lsl(vmm_tmp2.s, mask_all_one / T_m, vmm_tmp1.s);
        dup(vmm_tmp1.s, WReg(regw_tmp.getIdx()));
        and_(vmm_tmp1.d, vmm_tmp1.d, vmm_tmp2.d);
        cmpne(ktail_mask.s, mask_all_one, vmm_tmp1.s, 0);
    }
    if (jcp.is_fast_depthwise) {
        // prepare mask register for blending weights
        movk(reg_scratch, uint16_t(0x1111), 0);
        movk(reg_scratch, uint16_t(0x2222), 16);
        movk(reg_scratch, uint16_t(0x4444), 32);
        movk(reg_scratch, uint16_t(0x8888), 48);
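        // The four movk writes cover all 64 bits, so reg_scratch ends up
        // holding 0x8888444422221111; it is spilled to the stack and
        // reloaded as the kblend_mask predicate below.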
        sub(reg_stack, reg_stack, 8);
        str(reg_scratch, ptr(reg_stack));
        ldr(kblend_mask, ptr(reg_stack));
        add(reg_stack, reg_stack, 8);
        // load permute indices from data section
        adr(reg_scratch, permute_index_table);
        ld1w(zmm_permute.s, mask_all_one, ptr(reg_scratch));
    }

    int r_pad = nstl::max(0, jcp.r_pad);
    int n_oi = jcp.ow / jcp.ur_w;
    int r_pad1 = calculate_end_padding(jcp.l_pad, jcp.ur_w * n_oi, jcp.iw,
            jcp.stride_w, calculate_extended_filter_size(jcp.kw, jcp.dilate_w));

    if (jcp.nb_ow == 1) {
        if (r_pad1 > 0 || jcp.ur_w_tail == 0) n_oi--;

        eor(reg_oi, reg_oi, reg_oi);
        if (jcp.ow == jcp.ur_w) {
            icb_loop(jcp.ur_w, jcp.l_pad, r_pad, true);
        } else {
            if (n_oi == 0) {
                icb_loop(jcp.ur_w, jcp.l_pad, r_pad1, jcp.ur_w_tail == 0);
                adds_imm(reg_inp, reg_inp, inp_shift_pad, reg_tmp0_imm);
                adds_imm(reg_out, reg_out, out_shift, reg_tmp0_imm);
                if (jcp.ur_w_tail != 0) {
                    icb_loop(jcp.ur_w_tail, 0, r_pad, true);
                }
            } else {
                if (jcp.l_pad > 0) {
                    icb_loop(jcp.ur_w, jcp.l_pad, 0, false);
                    adds_imm(reg_inp, reg_inp, inp_shift_pad, reg_tmp0_imm);
                    adds_imm(reg_out, reg_out, out_shift, reg_tmp0_imm);

                    adds(reg_oi, reg_oi, 1);
                }
                if ((jcp.l_pad <= 0 && n_oi > 0)
                        || (jcp.l_pad > 0 && n_oi > 1)) {
                    Label ow_loop_label;
                    L(ow_loop_label);
                    {
                        icb_loop(jcp.ur_w, 0, 0, false);
                        adds_imm(reg_inp, reg_inp, inp_shift, reg_tmp0_imm);
                        adds_imm(reg_out, reg_out, out_shift, reg_tmp0_imm);

                        adds(reg_oi, reg_oi, 1);
                        mov_imm(reg_tmp0_imm, n_oi);
                        cmp(reg_oi, reg_tmp0_imm);
                        b(LT, ow_loop_label);
                    }
                }
                if (r_pad1 > 0 || jcp.ur_w_tail == 0) {
                    icb_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0);
                    adds_imm(reg_inp, reg_inp, inp_shift, reg_tmp0_imm);
                    adds_imm(reg_out, reg_out, out_shift, reg_tmp0_imm);
                }
                if (jcp.ur_w_tail != 0) {
                    icb_loop(jcp.ur_w_tail, 0, r_pad, true);
                }
            }
        }
    } else {
        // Only a single ow block is processed per kernel call. The block
        // index is passed through the owb parameter, and the padding
        // handling depends on it.
        Label end_label, last_oi_label, middle_ow_blocks_label, tail_label,
                oi_loop_label, oi_loop_end_label;

        assert(jcp.ow_block % jcp.ur_w == 0);
        int n_oi_not_last_ow_block = jcp.ow_block / jcp.ur_w;
        // to simplify code (and general regs usage),
        // size of ow block must be >= 2 * ur_w
        assert(n_oi_not_last_ow_block > 1);
        int n_oi_next_last_ow_block = n_oi_not_last_ow_block;
        int n_oi_first_ow_block = n_oi_not_last_ow_block;
        int n_oi_last_ow_block
                = (jcp.ow - jcp.ow_block * (jcp.nb_ow - 1)) / jcp.ur_w;
        // prepare right padding
        bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0;
        bool first_ow_block_padded
                = next_last_ow_block_padded && jcp.nb_ow == 2;
        bool last_ow_block_padded
                = (r_pad1 > 0 || jcp.ur_w_tail == 0) && n_oi_last_ow_block > 0;

        if (last_ow_block_padded)
            n_oi_last_ow_block--;
        else if (first_ow_block_padded)
            n_oi_first_ow_block--;
        else if (next_last_ow_block_padded)
            n_oi_next_last_ow_block--;

        ldr(reg_owb, ptr(reg_param1, GET_OFF(owb)));
        cmp(reg_owb, 0); // is that the first ow-block ?
        b(GT, middle_ow_blocks_label);

        // the first ow block, compute left padding
        mov_imm(reg_oi, n_oi_first_ow_block);
        if (jcp.l_pad > 0) {
            icb_loop(jcp.ur_w, jcp.l_pad, 0, false);
            adds_imm(reg_inp, reg_inp, inp_shift_pad, reg_tmp0_imm);
            adds_imm(reg_out, reg_out, out_shift, reg_tmp0_imm);

            subs(reg_oi, reg_oi, 1);
        }
        b(oi_loop_label);

        // middle or last ow block entry
        L(middle_ow_blocks_label);

        if (jcp.l_pad > 0) {
            // just to consider left padding, not compute
            adds_imm(
                    reg_inp, reg_inp, inp_shift_pad_second_block, reg_tmp0_imm);
        }

        // set number of iteration for oi-loop
        if (n_oi_last_ow_block != n_oi_not_last_ow_block) {
            cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ?
            mov_imm(reg_oi, n_oi_last_ow_block);
            b(EQ, oi_loop_label);
        }

        if (n_oi_next_last_ow_block != n_oi_not_last_ow_block) {
            cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?

            mov_imm(reg_oi, n_oi_next_last_ow_block);
            b(EQ, oi_loop_label);
        }
        mov_imm(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks

        // oi loop w/o padding
        L(oi_loop_label);
        {
            cmp(reg_oi, 0);
            b(LE, oi_loop_end_label);

            icb_loop(jcp.ur_w, 0, 0, false);

            adds_imm(reg_inp, reg_inp, inp_shift, reg_tmp0_imm);
            adds_imm(reg_out, reg_out, out_shift, reg_tmp0_imm);
            subs(reg_oi, reg_oi, 1);

            b(oi_loop_label);
        }
        L(oi_loop_end_label);

        ldr(reg_owb, ptr(reg_param1, GET_OFF(owb)));
        cmp(reg_owb, 0); // first ow-block ?
        if (first_ow_block_padded)
            b(EQ, last_oi_label);
        else
            b(EQ, end_label);

        cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
        b(LT, end_label);
        if (next_last_ow_block_padded)
            b(EQ, last_oi_label);
        else
            b(EQ, end_label);

        // that is last block
        if (!last_ow_block_padded) b(tail_label);

        // last oi block with right padding
        L(last_oi_label);
        icb_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0);
        adds_imm(reg_inp, reg_inp, inp_shift, reg_tmp0_imm);
        adds_imm(reg_out, reg_out, out_shift, reg_tmp0_imm);

        ldr(reg_owb, ptr(reg_param1, GET_OFF(owb)));
        cmp(reg_owb, jcp.nb_ow - 1); // last ow_block?
        b(LT, end_label);

        // ur_w tail
        L(tail_label);
        if (jcp.ur_w_tail != 0) { icb_loop(jcp.ur_w_tail, 0, r_pad, true); }
        L(end_label);
    }
    postamble();

    if (jcp.is_fast_depthwise) {
        align(64);
        L(permute_index_table);
        const uint32_t _idx[]
                = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15};
        for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i)
            dd(_idx[i]);
    }
}

bool jit_sve_512_x8s8s32x_fwd_kernel::post_ops_ok(
        jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
    using namespace primitive_kind;
    const auto &p = attr.post_ops_;

    /* At this time, post ops are not supported. */
    return 0 == p.len();
}

status_t jit_sve_512_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
        const convolution_desc_t &cd, memory_desc_t &src_md,
        memory_desc_t &weights_md, memory_desc_t &dst_md,
        memory_desc_t &bias_md, const primitive_attr_t &attr, int nthreads) {
    using namespace prop_kind;

    const memory_desc_wrapper src_d(&src_md);
    const memory_desc_wrapper weights_d(&weights_md);
    const memory_desc_wrapper dst_d(&dst_md);
    const memory_desc_wrapper bias_d(&bias_md);

    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    const int ndims = src_d.ndims();
    const bool is_1d = ndims == 3;
    const bool is_2d = ndims == 4;
    const bool is_3d = ndims == 5;
    assert(is_1d || is_2d || is_3d);

    if (!(mayiuse(sve_512)
                && one_of(src_d.data_type(), data_type::u8, data_type::s8)
                && weights_d.data_type() == data_type::s8
                && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
                        data_type::s8, data_type::u8)))
        return status::unimplemented;

    jcp = zero<decltype(jcp)>();
    jcp.nthr = nthreads;
    jcp.ndims = ndims;
    jcp.prop_kind = cd.prop_kind;
    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];
    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;
    jcp.ic_without_padding = jcp.ic;
    jcp.id = is_3d ? src_d.dims()[2] : 1;
    jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.od = is_3d ? dst_d.dims()[2] : 1;
    jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2];
    jcp.ow = dst_d.dims()[ndims - 1];
    jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];
    jcp.f_pad = is_3d ? cd.padding[0][0] : 0;
    jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4];
    jcp.l_pad = cd.padding[0][ndims - 3];
    jcp.stride_d = is_3d ? cd.strides[0] : 1;
    jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4];
    jcp.stride_w = cd.strides[ndims - 3];
    jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;

    jcp.dilate_d = is_3d ? cd.dilates[0] : 0;
    jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4];
    jcp.dilate_w = cd.dilates[ndims - 3];

    int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
    int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
    int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
    jcp.r_pad = calculate_end_padding(
            jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw);
    jcp.b_pad = calculate_end_padding(
            jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh);
    jcp.back_pad = calculate_end_padding(
            jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd);
    bool kernel_outside_src = false || ext_kw <= jcp.l_pad
            || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad
            || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad;
    if (kernel_outside_src) return status::unimplemented;

    jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false;
    jcp.need_saturation = utils::one_of(
            dst_d.data_type(), data_type::u8, data_type::s8, data_type::s32);
    jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc);

    if (jcp.is_depthwise && is_3d)
        // NOTE: 3D depthwise is not currently supported here.
        return status::unimplemented;

    if (jcp.is_depthwise) {
        jcp.ch_block = 16;
        jcp.ic_block = 1;
        jcp.oc_block = 1;
    } else {
        jcp.ch_block = 1;
        jcp.ic_block = 16;
        jcp.oc_block = 16;

        if (jcp.ngroups == 1) {
            /* For non-grouped convolutions, pad channels to 16 if needed */
            jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
            jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
        } else if (jcp.ngroups != 1
                && ((jcp.ic % jcp.ic_block != 0)
                        || (jcp.oc % jcp.oc_block != 0))) {
            /* For grouped convolutions, oneDNN doesn't support padding.
               When the number of channels per group is not a multiple of
               4, 8, or 16, return unimplemented. */
            jcp.ic_block = (jcp.ic % 8 == 0) && (jcp.oc % 8 == 0) ? 8 : 4;
            jcp.oc_block = jcp.ic_block;
        }
        if (jcp.ic % jcp.ic_block != 0 || jcp.oc % jcp.oc_block != 0)
            return status::unimplemented;
    }

    if (!post_ops_ok(jcp, attr)) return status::unimplemented;

    jcp.is_fast_depthwise = true && jcp.is_depthwise
            && jcp.ngroups % jcp.ch_block == 0; /* groups not multiple of
    ch_block (= 16) would require byte masking for load from src */

    jcp.is_resrc_depthwise = jcp.is_depthwise && jcp.stride_w < jcp.kw
            && jcp.kw < 4 && jcp.dilate_w == 0;
    if (jcp.is_depthwise) {
        jcp.max_regs_ur = 31 - jcp.is_fast_depthwise - !jcp.is_resrc_depthwise
                - !jcp.signed_input
                - (!jcp.signed_input || jcp.need_saturation); // both alias
    } else {
        jcp.max_regs_ur = 31;
    }

    auto set_or_check_wei_format = [&]() {
        using namespace format_tag;
        format_tag_t wei_tag;
        if (jcp.ic_block == 16 || jcp.ch_block == 16) {
            if (is_3d) {
                wei_tag = with_groups ? gOIdhw4i16o4i : OIdhw4i16o4i;
            } else if (is_1d) {
                wei_tag = with_groups ? jcp.is_depthwise ? Goiw16g : gOIw4i16o4i
                                      : OIw4i16o4i;
            } else {
                assert(is_2d);
                wei_tag = with_groups
                        ? jcp.is_depthwise ? Goihw16g : gOIhw4i16o4i
                        : OIhw4i16o4i;
            }
        } else if (jcp.ic_block == 8) {
            assert(with_groups);
            wei_tag = is_3d ? gOIdhw2i8o4i : is_2d ? gOIhw2i8o4i : gOIw2i8o4i;
        } else {
            assert(with_groups && jcp.ic_block == 4);
            wei_tag = is_3d ? gOIdhw4o4i : is_2d ? gOIhw4o4i : gOIw4o4i;
        }

        memory_desc_t want_wei_md = weights_md;
        memory_desc_init_by_tag(want_wei_md, wei_tag);
        if (!jcp.signed_input) {
            want_wei_md.extra.flags = 0
                    | memory_extra_flags::compensation_conv_s8s8
                    | memory_extra_flags::scale_adjust;
            want_wei_md.extra.compensation_mask = (1 << 0)
                    + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0);
            want_wei_md.extra.scale_adjust = 1.f;
        }

        if (weights_md.format_kind == format_kind::any) {
            weights_md = want_wei_md;
            return true;
        }

        return weights_md == want_wei_md;
    };

    if (!set_or_check_wei_format()) return status::unimplemented;

    format_tag_t dat_tag = utils::pick(
            ndims - 3, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);

    if (src_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(src_md, dat_tag));
        jcp.src_tag = dat_tag;
    } else {
        jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
    }
    if (jcp.src_tag != dat_tag) return status::unimplemented;

    if (dst_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(dst_md, dat_tag));
        jcp.dst_tag = dat_tag;
    } else {
        jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
    }
    if (jcp.dst_tag != dat_tag) return status::unimplemented;

    if (jcp.with_bias) {
        if (bias_d.format_kind() == format_kind::any)
            CHECK(memory_desc_init_by_tag(bias_md, format_tag::x));
    }

    jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
    jcp.dst_dt = cd.dst_desc.data_type;

    jcp.typesize_in = types::data_type_size(src_d.data_type());
    jcp.typesize_out = types::data_type_size(dst_d.data_type());
    jcp.typesize_bia
            = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;

    jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block);
    jcp.nb_ic = jcp.ic / jcp.ic_block;
    jcp.nb_oc = jcp.oc / jcp.oc_block;

    // Try to use 4 channel-groups at a time to avoid false sharing (depthwise)
    int nb_ch_blocking = 4;
    for (/* init above */; nb_ch_blocking > 1; nb_ch_blocking--)
        if (jcp.nb_ch % nb_ch_blocking == 0) break;
    jcp.nb_ch_blocking = jcp.is_depthwise ? nb_ch_blocking : 1;

    // If OC blocking is incommensurate with the number of OC blocks (general
    // requirement for all convolutions), or if it results in an unrolling
    // factor smaller than the left padding (special requirement for SSD:fc6),
    // then search for a smaller OC blocking that satisfies both constraints.
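    // Note: ur_w is bounded by the vector-register budget; each of the ur_w
    // unrolled output columns needs `block` accumulators plus one broadcast
    // src register, hence max_regs_ur / (block + 1).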
    auto is_oc_blocking_ok = [&](int block) {
        int ur_w = nstl::min(jcp.ow, jcp.max_regs_ur / (block + 1));
        return jcp.nb_oc % block == 0 && jcp.l_pad <= ur_w
                && jcp.ow % ur_w != 1;
    };

    // choose nb_oc work chunk size for distribution within threads
    int max_threading_nb_oc_chunk = 4;
    jcp.nb_oc_blocking_thr_chunk
            = nstl::min(max_threading_nb_oc_chunk, jcp.nb_oc);
    for (; jcp.nb_oc_blocking_thr_chunk > 1; jcp.nb_oc_blocking_thr_chunk--) {
        if (is_oc_blocking_ok(jcp.nb_oc_blocking_thr_chunk)) break;
    }

    // choose oc blocking for computational kernel
    jcp.nb_oc_blocking = jcp.nb_oc_blocking_thr_chunk;

    if (jcp.is_resrc_depthwise)
        jcp.ur_w = (jcp.max_regs_ur - jcp.kw + jcp.stride_w)
                / (jcp.nb_ch_blocking + jcp.stride_w);
    else
        jcp.ur_w = jcp.max_regs_ur
                / (jcp.is_depthwise ? jcp.nb_ch_blocking
                                    : jcp.nb_oc_blocking + 1);
    if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
    if (!jcp.is_depthwise && jcp.ur_w < jcp.ow) {
        // tune ur_w such that penultimate ur_w block (including ur_w_tail)
        // does not read past the end of src
        const int broadcast_size = 4;
        if (jcp.ic_without_padding % broadcast_size != 0) {
            while (jcp.ur_w > 0) {
                int last_block_size = (jcp.ow % jcp.ur_w == 0)
                        ? jcp.ur_w
                        : jcp.ow % jcp.ur_w;
                int penultimate_iw_index
                        = (jcp.ow - 1 - last_block_size) * jcp.stride_w
                        + (jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad;
                int penultimate_iw_leeway = (jcp.iw - 1 - penultimate_iw_index)
                                * jcp.ic_without_padding
                        + jcp.ic_without_padding % broadcast_size;
                if (penultimate_iw_leeway >= broadcast_size) break;
                --jcp.ur_w;
            }
            if (jcp.ur_w == 0) // no satisfactory ur_w could be found
                return status::unimplemented;
        }
    }
    jcp.ur_w_tail = jcp.ow % jcp.ur_w;

    jcp.ow_block = jcp.ow;
    int base_work_amount = jcp.mb * jcp.nb_ch * jcp.od * jcp.oh
            * (jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk);
    float best_thr_eff
            = (float)base_work_amount / rnd_up(base_work_amount, jcp.nthr);
    int max_nb_ow = div_up(jcp.ow, 2 * jcp.ur_w);
    for (int nb_ow = 1; nb_ow <= max_nb_ow; nb_ow++) {
        int ow_block
                = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), jcp.ur_w), jcp.ow);
        if (ow_block < jcp.nb_oc_blocking_thr_chunk * jcp.oc_block
                && best_thr_eff > 0.8f)
            break;
        if (div_up(jcp.ow, ow_block) != nb_ow) continue;
        auto work_amount = base_work_amount * nb_ow;
        float thr_eff = (float)work_amount / rnd_up(work_amount, jcp.nthr);
        if (ow_block >= 2 * jcp.ur_w && thr_eff > 1.1f * best_thr_eff) {
            jcp.ow_block = ow_block;
            best_thr_eff = thr_eff;
        }
        if (best_thr_eff > 0.9f) break;
    }
    jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);

    bool args_ok = true && jcp.oc % jcp.oc_block == 0 && jcp.l_pad <= jcp.ur_w;
    if (!args_ok) return status::unimplemented;

    int r_pad_no_tail = nstl::max(0,
            calculate_end_padding(jcp.l_pad, jcp.ow - jcp.ur_w_tail, jcp.iw,
                    jcp.stride_w, ext_kw));
    if (r_pad_no_tail > jcp.ur_w) return status::unimplemented;

    pick_loop_order(jcp, jcp.nthr);

    const auto &oscales = attr.output_scales_;
    jcp.is_oc_scale = oscales.mask_ == 1 << 1;

    // only common and per-oc-channel scales are supported
    const bool oscales_ok = one_of(oscales.mask_, 0, 1 << 1);
    if (!oscales_ok) return status::unimplemented;

    jcp.wei_adj_scale
            = (weights_d.extra().flags & memory_extra_flags::scale_adjust)
            ? weights_d.extra().scale_adjust
            : 1.f;

    return status::success;
}

void jit_sve_512_x8s8s32x_fwd_kernel::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
        const primitive_attr_t &attr) {}

} // namespace aarch64
} // namespace cpu
} // namespace impl
} // namespace dnnl