/*******************************************************************************
* Copyright 2021 Intel Corporation
* Copyright 2021 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>
#include <float.h>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory.hpp"
#include "common/memory_tracking.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/aarch64/cpu_barrier.hpp"
#include "cpu/platform.hpp"

#include "cpu/aarch64/jit_sve_512_1x1_conv_kernel.hpp"

#include "cpu/aarch64/jit_uni_1x1_conv_utils.hpp"

#define GET_OFF(field) \
    static_cast<int32_t>(offsetof(jit_1x1_conv_call_s, field))

namespace dnnl {
namespace impl {
namespace cpu {
namespace aarch64 {

using namespace dnnl::impl::format_tag;
using namespace dnnl::impl::prop_kind;
using namespace dnnl::impl::utils;
void jit_sve_512_1x1_conv_kernel::bcast_loop(int load_loop_blk) {

    mov(aux1_reg_bcast_data, reg_bcast_data);
    mov(aux_reg_bcast_data, reg_bcast_data);
    mov(aux_reg_output_data, reg_output_data);
    mov(reg_bcast_loop_iter, reg_bcast_loop_work);

    Label bcast_loop;
    Label bcast_loop_tail;
    Label large_tail;

    cmp_imm(reg_bcast_loop_iter, jcp.bcast_block, reg_tmp_imm);
    b(LT, bcast_loop_tail);

    L(bcast_loop);
    {
        assert(jcp.bcast_block % jcp.ur == 0);
        int num_substeps = jcp.bcast_block / jcp.ur;
        assert(num_substeps > 0 && num_substeps < 10);
        for (int i = 0; i < num_substeps; i++) {
            if (i + 1 == num_substeps) L(large_tail);
            reduce_loop(load_loop_blk, jcp.ur, i, false);
            if (i < num_substeps - 1) {
                add_imm(aux1_reg_bcast_data, aux1_reg_bcast_data,
                        jcp.bcast_loop_bcast_substep, reg_tmp_imm);
                add_imm(aux_reg_output_data, aux_reg_output_data,
                        jcp.bcast_loop_output_substep, reg_tmp_imm);
            } else {
                add_imm(aux1_reg_bcast_data, aux1_reg_bcast_data,
                        jcp.bcast_loop_bcast_step
                                - (num_substeps - 1)
                                        * jcp.bcast_loop_bcast_substep,
                        reg_tmp_imm);
                add_imm(aux_reg_output_data, aux_reg_output_data,
                        jcp.bcast_loop_output_step
                                - (num_substeps - 1)
                                        * jcp.bcast_loop_output_substep,
                        reg_tmp_imm);
            }
            subs_imm(reg_bcast_loop_iter, reg_bcast_loop_iter, jcp.ur,
                    reg_tmp_imm);
        }
        cmp_imm(reg_bcast_loop_iter, jcp.bcast_block, reg_tmp_imm);
        b(GE, bcast_loop);
    }

    L(bcast_loop_tail);
    if (jcp.ur_tail) {
        Label bcast_loop_tail_out;
        if (jcp.ur_tail >= jcp.ur) {
            cmp_imm(reg_bcast_loop_iter, jcp.ur, reg_tmp_imm);
            b(GE, large_tail);
        }
        if (jcp.ur_tail % jcp.ur) {
            cmp(reg_bcast_loop_iter, 0);
            b(LE, bcast_loop_tail_out);
            reduce_loop(load_loop_blk, jcp.ur_tail % jcp.ur, 0, true);
            L(bcast_loop_tail_out);
        }
    }
}

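/* reduce_loop() accumulates one ur x load_loop_blk tile of outputs over the
 * whole reduce dimension: it initializes the accumulators (with bias or
 * zero), runs fma_block() over reduce_loop_unroll-sized chunks, and stores
 * the results, applying the sum/eltwise post-ops where requested. */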
void jit_sve_512_1x1_conv_kernel::reduce_loop(
        int load_loop_blk, int ur, int substep, bool wraparound) {

    const bool out_layout_nxc = is_out_layout_nxc(jcp);
    const bool load_layout_nxc = is_load_layout_nxc(jcp);
    const bool bcast_layout_nxc = is_bcast_layout_nxc(jcp);
    const int reduce_dim_tail = jcp.reduce_dim % jcp.reduce_block;

    auto vreg_sum = [=]() { return ZReg(31); };
    auto vreg_sum_s = [=]() { return ZRegS(31); };

    auto vreg_load = [=](int i_load, int i_fma) {
        return ZReg(utils::rnd_up(ur * load_loop_blk, jcp.fma_step)
                + jcp.fma_step * i_load + i_fma);
    };
    auto vreg_load_s = [=](int i_load, int i_fma) {
        return ZRegS(utils::rnd_up(ur * load_loop_blk, jcp.fma_step)
                + jcp.fma_step * i_load + i_fma);
    };

    auto vreg_accum = [=](int i_load, int i_ur) {
        return ZReg(i_ur * load_loop_blk + i_load);
    };
    auto vreg_accum_s = [=](int i_load, int i_ur) {
        return ZRegS(i_ur * load_loop_blk + i_load);
    };
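
    /* SVE z-register budget implied by the lambdas above: z0 up to
     * z(ur*load_loop_blk - 1) hold the accumulators, the next
     * fma_step*load_loop_blk registers hold the loaded weights, the
     * registers above those (see bcast_reg_ofs in fma_block()) hold
     * broadcast inputs, and z31 additionally serves as the sum post-op
     * register in store(). */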

    auto bias_load = [=](int i_load, int i_ur) {
        int ofs = jcp.typesize_out * jcp.oc_block * i_load;
        if (ldr_imm_check(ofs)) {
            ldr(vreg_accum(i_load, i_ur),
                    ptr(reg_bias_data, static_cast<int32_t>(VL64_OFS(ofs))));
        } else {
            add_imm(reg_tmp_ofs, reg_bias_data, ofs, reg_tmp_imm);
            ldr(vreg_accum(i_load, i_ur), ptr(reg_tmp_ofs));
        }
    };

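    /* bcast_load() broadcasts one input scalar into a z-register. ld1rw
     * accepts only a limited immediate offset (see ld1rw_imm_check()), so
     * when the offset does not fit, the last materialized address is kept
     * in reg_prev_bcast_addr and reused; prev_ofs tracks the byte offset
     * that address corresponds to. */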
    auto bcast_load = [=](int i_reduce, int i_ur, int prev_ofs, int bcast_idx) {
        assert(i_ur < jcp.ur);
        assert(i_reduce <= jcp.reduce_loop_unroll);
        int ofs;
        if (one_of(jcp.prop_kind, forward_training, forward_inference,
                    backward_data)) {
            assert(jcp.reduce_loop_unroll == jcp.reduce_block);
            const int reduce_mul = bcast_layout_nxc ? jcp.reduce_dim
                                                    : jcp.reduce_loop_unroll;
            ofs = (i_reduce == jcp.reduce_loop_unroll)
                    ? (jcp.bcast_dim + i_ur) * reduce_mul
                    : i_ur * reduce_mul + i_reduce;
        } else {
            int rmul = bcast_layout_nxc ? jcp.ic : jcp.ic_block;
            ofs = i_reduce * rmul + i_ur;
        }

        ofs = jcp.typesize_in * ofs;
        int tmp_ofs = ofs;
        if (ld1rw_imm_check(ofs)) {
            ld1rw(ZRegS(bcast_idx), reg_p_all_ones,
                    ptr(aux_reg_bcast_data, static_cast<int32_t>(ofs)));
        } else {
            if ((prev_ofs != -1) && ld1rw_imm_check(ofs - prev_ofs)) {
                ld1rw(ZRegS(bcast_idx), reg_p_all_ones,
                        ptr(reg_prev_bcast_addr,
                                static_cast<int32_t>((ofs - prev_ofs))));
            } else {
                if ((prev_ofs != -1) && ((ofs - prev_ofs) >= 0)) {
                    ofs = ofs - prev_ofs;
                    add_imm(reg_prev_bcast_addr, reg_prev_bcast_addr, ofs,
                            reg_tmp_imm);
                } else {
                    add_imm(reg_prev_bcast_addr, aux_reg_bcast_data, ofs,
                            reg_tmp_imm);
                }
                prev_ofs = tmp_ofs;

                ld1rw(ZRegS(bcast_idx), reg_p_all_ones,
                        ptr(reg_prev_bcast_addr));
            }
        }
        return prev_ofs;
    };

    auto load_load = [=](int i_reduce, int i_load, int i_fma) {
        int ofs;
        int u0 = i_reduce % jcp.reduce_loop_unroll;
        int u1 = i_reduce / jcp.reduce_loop_unroll;
        int lmul = jcp.load_block
                * (load_layout_nxc ? 1
                                   : utils::rnd_up(
                                           jcp.reduce_dim, jcp.reduce_block));
        int rmul = load_layout_nxc ? jcp.load_dim : jcp.load_block;
        ofs = i_load * lmul + u0 * rmul;
        ofs = u1 * jcp.reduce_loop_load_step + jcp.typesize_in * ofs;

        if (ldr_imm_check(ofs)) {
            ofs = VL64_OFS(ofs);
            ldr(vreg_load(i_load, i_fma),
                    ptr(aux_reg_load_data, static_cast<int32_t>(ofs)));
        } else {
            add_imm(reg_tmp_ofs, aux_reg_load_data, ofs, reg_tmp_imm);
            ldr(vreg_load(i_load, i_fma), ptr(reg_tmp_ofs));
        }
    };

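    /* out_load()/out_str() use the same offset-caching scheme as
     * bcast_load(), with reg_prev_out_addr as the cached base. For
     * backward_weights, i_load indexes rows whose pitch is only known at
     * run time, so the address is formed as i_load * reg_output_stride
     * plus the base via madd. */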
    auto out_load = [=](int i_load, int i_ur, int prev_ofs) {
        int ofs, ofs_tmp;
        int bwd_iload
                = (i_load != 0) && one_of(jcp.prop_kind, backward_weights);
        auto r = (bwd_iload) ? reg_tmp_ofs : aux_reg_output_data;

        if (one_of(jcp.prop_kind, forward_training, forward_inference,
                    backward_data)) {
            int i_load_shift = out_layout_nxc
                    ? jcp.load_block
                    : (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim)
                            * jcp.load_block;
            int i_ur_shift = out_layout_nxc ? jcp.load_dim : jcp.load_block;
            ofs = (i_load * i_load_shift + i_ur * i_ur_shift)
                    * jcp.typesize_out;
        } else {
            ofs = jcp.typesize_out * jcp.load_block * i_ur;
        }

        ofs_tmp = ofs;

        if (bwd_iload) mov(r, i_load);
        if (ldr_imm_check(ofs)) {
            if (bwd_iload) madd(r, r, reg_output_stride, aux_reg_output_data);
            ldr(vreg_sum(), ptr(r, static_cast<int32_t>(VL64_OFS(ofs))));
        } else {
            if ((prev_ofs != -1) && ((ofs - prev_ofs) > 0)
                    && (VL64_OFS(ofs - prev_ofs) <= LDRMAX)) {
                if (bwd_iload)
                    madd(r, r, reg_output_stride, reg_prev_out_addr);
                else
                    r = reg_prev_out_addr;
                ldr(vreg_sum(),
                        ptr(r, static_cast<int32_t>(VL64_OFS(ofs - prev_ofs))));
            } else {
                if ((prev_ofs != -1) && ((ofs - prev_ofs) > 0)) {
                    ofs = ofs - prev_ofs;
                    add_imm(reg_prev_out_addr, reg_prev_out_addr, ofs,
                            reg_tmp_imm);
                } else {
                    add_imm(reg_prev_out_addr, aux_reg_output_data, ofs,
                            reg_tmp_imm);
                }
                if (bwd_iload)
                    madd(r, r, reg_output_stride, reg_prev_out_addr);
                else
                    r = reg_prev_out_addr;
                ldr(vreg_sum(), ptr(r));

                prev_ofs = ofs_tmp;
            }
        }
        return prev_ofs;
    };

    auto out_str = [=](int i_load, int i_ur, int prev_ofs) {
        int ofs, ofs_tmp;
        int bwd_iload
                = (i_load != 0) && one_of(jcp.prop_kind, backward_weights);
        auto r = (bwd_iload) ? reg_tmp_ofs : aux_reg_output_data;
        if (one_of(jcp.prop_kind, forward_training, forward_inference,
                    backward_data)) {
            ofs = (i_load * jcp.bcast_dim + i_ur) * jcp.load_block
                    * jcp.typesize_out;
        } else {
            ofs = jcp.typesize_out * jcp.load_block * i_ur;
        }
        ofs_tmp = ofs;

        if (bwd_iload) mov(r, i_load);
        if (str_imm_check(ofs)) {
            if (bwd_iload) madd(r, r, reg_output_stride, aux_reg_output_data);
            str(vreg_accum(i_load, i_ur),
                    ptr(r, static_cast<int32_t>(VL64_OFS(ofs))));
        } else {
            if ((prev_ofs != -1) && str_imm_check(ofs - prev_ofs)) {
                if (bwd_iload)
                    madd(r, r, reg_output_stride, reg_prev_out_addr);
                else
                    r = reg_prev_out_addr;
                str(vreg_accum(i_load, i_ur),
                        ptr(r, static_cast<int32_t>(VL64_OFS(ofs - prev_ofs))));
            } else {
                if ((prev_ofs != -1) && ((ofs - prev_ofs) > 0)) {
                    ofs = ofs - prev_ofs;
                    add_imm(reg_prev_out_addr, reg_prev_out_addr, ofs,
                            reg_tmp_imm);
                } else {
                    add_imm(reg_prev_out_addr, aux_reg_output_data, ofs,
                            reg_tmp_imm);
                }
                if (bwd_iload)
                    madd(r, r, reg_output_stride, reg_prev_out_addr);
                else
                    r = reg_prev_out_addr;
                str(vreg_accum(i_load, i_ur), ptr(r));

                prev_ofs = ofs_tmp;
            }
        }
        return prev_ofs;
    };

    auto prefetch_output = [=](int i_load, int i_ur) {
        int ofs;
        int bwd_iload
                = (i_load != 0) && one_of(jcp.prop_kind, backward_weights);
        auto r = (bwd_iload) ? reg_tmp_ofs : aux_reg_output_data;
        if (one_of(jcp.prop_kind, forward_training, forward_inference,
                    backward_data)) {
            ofs = (i_load * jcp.bcast_dim + i_ur) * jcp.load_block
                    * jcp.typesize_out;
        } else {
            ofs = jcp.typesize_out * jcp.load_block * i_ur;
        }
        std::string op = "LD";
        prefetch(op, 2, r, ofs);
    };

    auto init = [=]() {
        Label init_done;
        Label init_zero;

        if (jcp.with_sum) {
            for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                for (int i_ur = 0; i_ur < ur; ++i_ur) {
                    prefetch_output(i_load, i_ur);
                }
            }
        }

        if (jcp.with_bias
                && one_of(jcp.prop_kind, forward_training, forward_inference)) {

            tst(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
            b(EQ, init_zero);

            for (int i_load = 0; i_load < load_loop_blk; i_load++)
                for (int i_ur = 0; i_ur < ur; ++i_ur) {
                    bias_load(i_load, i_ur);
                }
            b(init_done);
        }

        L(init_zero);
        /* Zero-clear the accumulators */
        for (int i_load = 0; i_load < load_loop_blk; ++i_load)
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                fmov(vreg_accum_s(i_load, i_ur));
            }
        L(init_done);
    };

    auto store = [=]() {
        Label store_noadd;
        if (!jcp.with_sum) {
            tst(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
            b(NE, store_noadd);
        }

        int prev_ofs = -1;
        for (int i_ur = 0; i_ur < ur; ++i_ur)
            for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                auto r = vreg_accum_s(i_load, i_ur);
                prev_ofs = out_load(i_load, i_ur, prev_ofs);
                fadd(r, r, vreg_sum_s());
            }

        L(store_noadd);
        if (jcp.with_eltwise) {
#ifndef DISABLE_ELTWISE
            Label store_noeltwise;
            tst(reg_reduce_pos_flag, FLAG_REDUCE_LAST);
            b(EQ, store_noeltwise);
            eltwise_injector_->compute_vector_range(0, ur * load_loop_blk);
            L(store_noeltwise);
#else
            assert(!"fused eltwise error!");
#endif
        }

        prev_ofs = -1;
        for (int i_ur = 0; i_ur < ur; ++i_ur) {
            for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                prev_ofs = out_str(i_load, i_ur, prev_ofs);
            }
        }
    };

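    /* fma_block() performs the register-tile multiply: for each reduce step
     * it loads the weights for every i_load, broadcasts as many inputs as
     * the spare registers allow, then issues fmla over the
     * ur x load_loop_blk accumulators, interleaving the remaining broadcast
     * loads with the fmla stream. */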
    auto fma_block = [=](bool last_block) {
        assert(jcp.reduce_loop_unroll % jcp.fma_step == 0);

        int reduce_step = jcp.fma_step;
        int prev_bcast_ofs = -1;
        assert(reduce_dim_tail % reduce_step == 0);

        const int i_reduce_end = reduce_dim_tail && last_block
                ? reduce_dim_tail
                : jcp.reduce_loop_unroll;

        int bcast_reg_ofs = utils::rnd_up(ur * load_loop_blk, jcp.fma_step)
                + jcp.fma_step * load_loop_blk;
        int num_bcast_regs = 32 - bcast_reg_ofs;
        int bcast_reg_idx = 0;

        for (int i_reduce = 0; i_reduce < i_reduce_end;
                i_reduce += reduce_step) { // IC
            for (int i_load = 0; i_load < load_loop_blk; ++i_load) { // OC
                for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) {
                    load_load(i_reduce + i_fma, i_load, i_fma);
                }
            }

            int bcast_reg_startidx = bcast_reg_idx % num_bcast_regs;
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                if (i_ur >= num_bcast_regs) break;
                prev_bcast_ofs = bcast_load(i_reduce, i_ur, prev_bcast_ofs,
                        bcast_reg_ofs + (bcast_reg_idx % num_bcast_regs));
                bcast_reg_idx++;
            }

            for (int i_ur = 0; i_ur < ur; ++i_ur) {

                for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                    fmla(vreg_accum_s(i_load, i_ur), reg_p_all_ones,
                            vreg_load_s(i_load, 0),
                            ZRegS(bcast_reg_ofs
                                    + ((bcast_reg_startidx + i_ur)
                                            % num_bcast_regs)));
                }
                if ((num_bcast_regs + i_ur) < ur) {
                    prev_bcast_ofs = bcast_load(i_reduce, num_bcast_regs + i_ur,
                            prev_bcast_ofs,
                            bcast_reg_ofs + (bcast_reg_idx % num_bcast_regs));
                    bcast_reg_idx++;
                }
            }
        }
    };
    Label reduce_loop;
    Label reduce_loop_tail;

    mov(aux_reg_load_data, reg_load_data);

    mov(aux_reg_bcast_data, aux1_reg_bcast_data);
    init();

    mov(reduce_loop_iter, reg_reduce_loop_work);
    subs_imm(reduce_loop_iter, reduce_loop_iter, jcp.reduce_loop_unroll,
            reg_tmp_imm);
    b(LE, reduce_loop_tail);

    align(32);
    L(reduce_loop);
    {
        fma_block(false);
        add_imm(aux_reg_bcast_data, aux_reg_bcast_data,
                jcp.reduce_loop_bcast_step, reg_tmp_imm);
        add_imm(aux_reg_load_data, aux_reg_load_data, jcp.reduce_loop_load_step,
                reg_tmp_imm);
        subs_imm(reduce_loop_iter, reduce_loop_iter, jcp.reduce_loop_unroll,
                reg_tmp_imm);
        b(GT, reduce_loop);
    }

    L(reduce_loop_tail);
    fma_block(true);

    store();
}

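/* generate() emits the whole kernel: it unpacks the jit_1x1_conv_call_s
 * arguments passed in abi_param1, then dispatches to a load-loop body
 * specialized for the number of load blocks (1 to 6) still remaining. */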
void jit_sve_512_1x1_conv_kernel::generate() {
    preamble();

    /* Set all bits of the predicate register */
    ptrue(reg_p_all_ones.b);

    /* Pointers to the weight, input, and output data */
    ldr(reg_bcast_data, ptr(abi_param1, GET_OFF(bcast_data))); // Input
    ldr(reg_load_data, ptr(abi_param1, GET_OFF(load_data))); // Weight
    ldr(reg_output_data, ptr(abi_param1, GET_OFF(output_data))); // Output

    /* Pointer to the bias data if the layer has a bias */
    if (jcp.with_bias) ldr(reg_bias_data, ptr(abi_param1, GET_OFF(bias_data)));

    /* Get the workload of each loop */
    ldr(reg_load_loop_work, ptr(abi_param1, GET_OFF(load_dim)));
    ldr(reg_bcast_loop_work, ptr(abi_param1, GET_OFF(bcast_dim)));
    ldr(reg_reduce_loop_work, ptr(abi_param1, GET_OFF(reduce_dim)));

    /* A flag that controls the reduce loop */
    ldr(reg_reduce_pos_flag, ptr(abi_param1, GET_OFF(first_last_flag)));

    if (one_of(jcp.prop_kind, forward_training, forward_inference))
        mov(reg_relu_ns, reinterpret_cast<size_t>(&jcp.eltwise.alpha));

    if (jcp.prop_kind == backward_weights)
        ldr(reg_output_stride, ptr(abi_param1, GET_OFF(output_stride)));

    auto load_loop_body = [=](int load_loop_blk) {
        subs_imm(reg_load_loop_work, reg_load_loop_work,
                load_loop_blk * jcp.load_loop_iter_step, reg_tmp_imm);

        bcast_loop(load_loop_blk);
        add_imm(reg_load_data, reg_load_data,
                load_loop_blk * jcp.load_loop_load_step, reg_tmp_imm);
        switch (jcp.prop_kind) {
            case forward_training:
            case forward_inference:
                add_imm(reg_bias_data, reg_bias_data,
                        load_loop_blk * jcp.load_block * jcp.typesize_out,
                        reg_tmp_imm);
                add_imm(reg_output_data, reg_output_data,
                        load_loop_blk * jcp.load_block * jcp.typesize_out
                                * (is_out_layout_nxc(jcp)
                                                ? 1
                                                : (jcp.with_dw_conv
                                                                ? jcp.ow
                                                                : jcp.bcast_dim)),
                        reg_tmp_imm);
                break;
            case backward_data:
                add_imm(reg_output_data, reg_output_data,
                        load_loop_blk * jcp.load_block * jcp.typesize_out
                                * (is_out_layout_nxc(jcp) ? 1 : jcp.bcast_dim),
                        reg_tmp_imm);
                break;
            case backward_weights:
                for (int i_load = 0; i_load < load_loop_blk; i_load++)
                    add(reg_output_data, reg_output_data, reg_output_stride);
                break;
            default: assert(!"invalid prop_kind");
        }
    };

    const int simd_w = cpu_isa_traits<sve_512>::vlen / sizeof(float);
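    // For sve_512, vlen is 64 bytes, so simd_w is 16 fp32 lanes.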

    Label load_loop_blk[7];

    // ur upper bounds, with an implicit load_loop_block of {6, 5, 4, 3, 2, 1}
    static const int ur_cases_bcast[] = {2, 5, 6, 9, 14, 32};

    const int size_ur_cases = sizeof(ur_cases_bcast);

    const int *ur_cases = ur_cases_bcast;
    const int num_ur_cases = size_ur_cases / sizeof(*ur_cases);

    for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) {
        int label_idx = num_ur_cases - ur_idx - 1;
        if (jcp.nb_load > label_idx && jcp.ur <= ur_cases[ur_idx]) {
            cmp_imm(reg_load_loop_work, simd_w * (label_idx + 1), reg_tmp_imm);
            b(LE, load_loop_blk[label_idx]);
        }
    }

    for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) {
        int label_idx = num_ur_cases - ur_idx - 1;
        if (jcp.nb_load > label_idx && jcp.ur <= ur_cases[ur_idx]) {
            L(load_loop_blk[label_idx]);
            {
                if (label_idx == 0) {
                    cmp(reg_load_loop_work, 0);
                    b(LE, load_loop_blk[num_ur_cases]);
                }
                load_loop_body(label_idx + 1);
                if (label_idx - 1 > 0) {
                    cmp_imm(reg_load_loop_work, 2 * label_idx * simd_w,
                            reg_tmp_imm);
                    b(EQ, load_loop_blk[label_idx - 1]);
                }
                cmp_imm(reg_load_loop_work, label_idx * simd_w, reg_tmp_imm);
                b(GT, load_loop_blk[label_idx]);
            }
            for (int idx = label_idx - 1; idx > 0; --idx) {
                cmp_imm(reg_load_loop_work, simd_w * (idx + 1), reg_tmp_imm);
                b(EQ, load_loop_blk[idx]);
            }
            if (ur_idx < num_ur_cases - 2) {
                cmp_imm(reg_load_loop_work, simd_w, reg_tmp_imm);
                b(LE, load_loop_blk[0]);
            }
        }
    }
    L(load_loop_blk[num_ur_cases]);

    postamble();
    if (jcp.with_eltwise) {
#ifndef DISABLE_ELTWISE
        eltwise_injector_->prepare_table();
        binCommit();
#else
        assert(!"fused eltwise error");
#endif
    }
}

bool jit_sve_512_1x1_conv_kernel::post_ops_ok(
        jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {

    const auto &p = attr.post_ops_;

    auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
    auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
    auto is_convolution
            = [&](int idx) { return p.entry_[idx].is_convolution(); };

    int dw_idx = p.find(primitive_kind::convolution);
    int len = dw_idx != -1 ? dw_idx + 1 : p.len();

    switch (len) {
        case 0: return true; // no post_ops
        case 1: // eltwise OR sum OR convolution
            return is_eltwise(0) || is_sum(0) || is_convolution(0);
        case 2: // sum -> eltwise OR eltwise -> convolution
            return (is_sum(0) && is_eltwise(1))
                    || (is_eltwise(0) && is_convolution(1));
        default: return false;
    }

    return false;
}

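/* init_conf() validates that the problem fits this kernel (1x1 filter, no
 * padding, unit strides, blocked 16c data layouts, fp32) and derives the
 * blocking and threading parameters stored in jcp. */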
status_t jit_sve_512_1x1_conv_kernel::init_conf(jit_1x1_conv_conf_t &jcp,
        const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
        const primitive_attr_t &attr, int nthreads, bool reduce_src) {

    /* arch check */
    if (!mayiuse(sve_512)) return status::unimplemented;

    jcp.nthr = nthreads;

    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    const int simd_w = cpu_isa_traits<sve_512>::vlen / sizeof(float);
    const int ndims = src_d.ndims();
    /* Forward_[training, inference], backward_[data, weights] */
    jcp.prop_kind = cd.prop_kind;

    /* Check the group option */
    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    /* Batch size */
    jcp.mb = src_d.dims()[0];
    /* Channels */
    jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups;
    jcp.oc = jcp.oc_without_padding;
    jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups;
    jcp.ic = jcp.ic_without_padding;
    /* D, H, W */
    jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2];
    jcp.ow = dst_d.dims()[ndims - 1];
    /* Kernel size */
    jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];
    /* padding params */
    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
    jcp.l_pad = cd.padding[0][ndims - 3];
    /* stride params */
    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
    jcp.stride_w = cd.strides[ndims - 3];
    /* bias info */
    jcp.with_bias = pick_by_prop_kind(jcp.prop_kind, cd.bias_desc.format_kind,
                            format_kind::undef, cd.diff_bias_desc.format_kind)
            != format_kind::undef;

    /* Spatials */
    jcp.os = jcp.od * jcp.oh * jcp.ow;
    jcp.is = jcp.id * jcp.ih * jcp.iw;
    jcp.tr_is = rnd_up(jcp.is, 4);

    if (!post_ops_ok(jcp, attr)) return status::unimplemented;

    /* Depthwise conv check */
    const auto &p = attr.post_ops_;
    const int dw_conv_ind = p.find(primitive_kind::convolution);
    jcp.with_dw_conv = dw_conv_ind != -1;

    /* Post operation check */
    // Using dw_conv_ind as an upper bound below, as post-ops after it will
    // be handled in the depthwise convolution.
    jcp.with_sum = p.find(primitive_kind::sum, 0, dw_conv_ind) != -1;
    const int eltwise_ind = p.find(primitive_kind::eltwise, 0, dw_conv_ind);
    jcp.with_eltwise = eltwise_ind != -1;
    if (jcp.with_eltwise) {
#ifndef DISABLE_ELTWISE
        jcp.eltwise = p.entry_[eltwise_ind].eltwise;
        if (jcp.eltwise.alg == alg_kind::eltwise_pow)
            return status::unimplemented;
        if (dst_d.data_type() == data_type::s32) return status::unimplemented;
#else
        return status::unimplemented;
#endif
    }

    /* Data format check */
    const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
    const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
    jcp.src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
    jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
    bool is_data_layout_nxc
            = utils::everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag);
    auto required_dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;

    if (is_data_layout_nxc) return status::unimplemented;

    /* Channel padding check */
    bool ok_to_pad_channels
            = true && jcp.ngroups == 1 && src_d.data_type() == data_type::f32;

    /* Input and output channels must be multiples of simd_w */
    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, simd_w);
        jcp.ic = rnd_up(jcp.ic, simd_w);
    }

    bool args_ok = true && jcp.ngroups == 1 && jcp.src_tag == required_dat_tag
            && jcp.dst_tag == required_dat_tag
            && (jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0) && jcp.f_pad == 0
            && jcp.t_pad == 0 && jcp.l_pad == 0 && jcp.stride_w == 1
            && jcp.stride_h == 1 && jcp.stride_d == 1 && jcp.kd == 1
            && jcp.kh == 1 && jcp.kw == 1 && jcp.ow == jcp.iw
            && jcp.oh == jcp.ih && jcp.od == jcp.id; // enforce rpad = 0
    if (!args_ok) return status::unimplemented;

    /* The channel blocking size is simd_w */
    jcp.ic_block = jcp.oc_block = simd_w;

    jcp.ver = ver_sve_512;
    if (everyone_is(data_type::f32, src_d.data_type(), weights_d.data_type(),
                dst_d.data_type())) {
        const int is_bwd_d = jcp.prop_kind == backward_data;
        /* Set the weight data layout tag */
        format_tag_t wei_tag = with_groups
                ? pick(2 * ndims - 6 + is_bwd_d, gOIw16i16o, gIOw16o16i,
                        gOIhw16i16o, gIOhw16o16i, gOIdhw16i16o, gIOdhw16o16i)
                : pick(2 * ndims - 6 + is_bwd_d, OIw16i16o, IOw16o16i,
                        OIhw16i16o, IOhw16o16i, OIdhw16i16o, IOdhw16o16i);

        jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
        if (jcp.wei_tag != wei_tag) return status::unimplemented;

        jcp.fma_step = 1;
        jcp.typesize_in = sizeof(prec_traits<data_type::f32>::type);
        jcp.typesize_out = sizeof(prec_traits<data_type::f32>::type);
    } else {
        // TODO: currently, only fp32 is supported
        return status::unimplemented;
    }

    /* once all the formats are set, check the padding consistency */
    args_ok = true && jcp.ic <= src_d.padded_dims()[1]
            && jcp.oc <= dst_d.padded_dims()[1]
            && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
            && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
    if (!args_ok) return status::unimplemented;

    // TODO: Optimize the params below
    const int SMALL_SPATIAL = 10;
    const int BIG_SPATIAL = 65;
    const int BIG_LOAD_DIM = (jcp.reduce_dim >= 512) ? 256 : 512;

    int load_blocking {0};
    int load_blocking_max {0};
    int bcast_blocking {0};
    int bcast_blocking_max {0};
    int reduce_blocking {0};
    int reduce_blocking_max {0};

    jcp.load_grp_count = 1;

    // TODO: move the check functions into the platform files
    const int L1_capacity
            = platform::get_per_core_cache_size(1) / sizeof(float);
    const int L2_size = platform::get_per_core_cache_size(2) / sizeof(float);
    const int L2_capacity = (L2_size * 3) / 4;

    /* FWD, BWD data */
    if (one_of(jcp.prop_kind, forward_training, forward_inference,
                backward_data)) {

        if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
            /* Forward */
            if (jcp.with_dw_conv) jcp.ur = nstl::min(jcp.ow, jcp.ur);
            jcp.reduce_dim = jcp.ic; // src channels
            jcp.reduce_block = jcp.ic_block; // src simd_w

            jcp.load_dim = jcp.oc; // dst channels
            jcp.load_block = jcp.oc_block; // dst simd_w

            jcp.bcast_dim = jcp.is; // src H*W
        } else {
            /* Backward data */
            jcp.reduce_dim = jcp.oc; // src channels
            jcp.reduce_block = jcp.oc_block; // src simd_w

            jcp.load_dim = jcp.ic; // dst channels
            jcp.load_block = jcp.ic_block; // dst simd_w

            jcp.bcast_dim = jcp.os; // src H*W
        }

        /* # of consecutive channel elements */
        jcp.reduce_loop_unroll = jcp.reduce_block;

        /* Offset to move to the next 16 input channel elements with the same
         * H*W position */
        jcp.reduce_loop_bcast_step
                = jcp.reduce_loop_unroll * jcp.bcast_dim * jcp.typesize_in;

        /* Offset: 16o*16i (filter) */
        jcp.reduce_loop_load_step
                = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in;

        /* Offset: I/16 * 16o */
        jcp.load_loop_load_step
                = (utils::rnd_up(jcp.reduce_dim, jcp.reduce_block))
                * jcp.load_block * jcp.typesize_in;

        /* adjust the register blocking */
        int max_regs, min_regs, size_threshold, ur_step;

        /* spatial: H*D of dst */
        const int spatial
                = (one_of(jcp.prop_kind, forward_training, forward_inference))
                ? jcp.od * jcp.oh // forward
                : jcp.id * jcp.ih; // backward

        max_regs = 9; // max # of ur_w
        min_regs = 6; // min # of ur_w
        size_threshold = 14;
        ur_step = 1; // step size of the ur_w param search
        jcp.ur = 1;

        /*
         * H*D of dst > SMALL_SPATIAL
         */
        if (jcp.load_dim > 128 && jcp.load_dim < BIG_LOAD_DIM
                && spatial > SMALL_SPATIAL && spatial < BIG_SPATIAL
                && jcp.reduce_dim < 256) {
            max_regs = 6;
            min_regs = 5;
        }

        for (int ur_w = max_regs; ur_w >= min_regs; ur_w -= ur_step) {
            /*
             * H*D of dst >= size_threshold and (H*D of dst) % ur_w == 0,
             * or
             * H*D of dst < size_threshold and (H*W of dst) % ur_w == 0
             */
            if ((spatial >= size_threshold && spatial % ur_w == 0)
                    || (spatial < size_threshold && jcp.os % ur_w == 0)) {
                jcp.ur = ur_w;
                break;
            }
        }

        if (jcp.ur == 1) {
            // If no divisor was found, use min(max_regs, H*W of dst)
            jcp.ur = nstl::min(max_regs, jcp.os);
        }
        jcp.bcast_block = jcp.ur; // block size of bcast (input data)
        /* Step size of the dst address in bcast_loop() */
        jcp.bcast_loop_output_step = jcp.ur * jcp.typesize_out * jcp.load_block;
        jcp.bcast_loop_output_substep = -1; // unused

        /* Step size of the broadcast src address in bcast_loop() */
        jcp.bcast_loop_bcast_step = jcp.ur * jcp.typesize_in * jcp.reduce_block;
        jcp.bcast_loop_bcast_substep = -1; // unused

        jcp.load_loop_iter_step = jcp.load_block;

        if (jcp.prop_kind == backward_data)
            jcp.loop_order = loop_lbr;
        else
            jcp.loop_order = reduce_src ? loop_blr : loop_lbr;

        int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
        int nb_load = div_up(jcp.load_dim, jcp.load_block);

        reduce_blocking = jcp.reduce_dim;
        if (jcp.load_dim <= BIG_LOAD_DIM && spatial > SMALL_SPATIAL
                && spatial < BIG_SPATIAL) {
            reduce_blocking = nstl::min(jcp.reduce_dim, 80);
        } else if (spatial > SMALL_SPATIAL)
            reduce_blocking = nstl::min(jcp.reduce_dim, 512);
        else
            reduce_blocking = nstl::min(jcp.reduce_dim, 256);

        // Check input data cache aliasing.
        // For other ISAs the constants may need updating.
        // way_size approximates the capacity of one L2 cache way;
        // max_hits = 7 is an empirical value, about half of the 16 ways,
        // so about half of the set is left for other data - weights, dst.
        int way_size = (16 * 1024) / jcp.typesize_in;
        int max_hits = 7;
        if (jcp.bcast_dim * reduce_blocking > way_size * max_hits) {
            int nrb = reduce_blocking / simd_w;
            int sp = jcp.bcast_dim;
            int wl = way_size / simd_w;
            for (int start_off = 0; start_off < jcp.ur; start_off++) {
                for (int off = start_off, hits = 0; off < sp * nrb; off += wl) {
                    if (off % sp >= jcp.ur || ++hits < max_hits) continue;
                    int max_r_blocking = simd_w * nstl::max(1, (off + wl) / sp);
                    reduce_blocking
                            = nstl::min(reduce_blocking, max_r_blocking);
                    break;
                }
            }
        }

        if (reduce_blocking < jcp.reduce_dim) {
            if (jcp.prop_kind == backward_data)
                jcp.loop_order = reduce_src ? loop_lbr : loop_rlb;
            else
                jcp.loop_order = reduce_src ? loop_rbl : loop_rlb;
        }
        load_blocking = jcp.load_dim;

        /* Number of weight elements to be loaded for dst */
        int load_size = jcp.load_dim * jcp.reduce_dim;
        /* Number of elements to be broadcast from src */
        auto bcast_size
                = (dim_t)jcp.mb * jcp.ngroups * jcp.bcast_dim * jcp.reduce_dim;

        /* 12 cores per CMG */
        if (jcp.nthr <= 12 && jcp.mb < jcp.nthr
                && nb_load * nb_bcast > jcp.nthr) {
            // Some heuristic here
            float calc_koef = 0.01, best_cost = FLT_MAX;
            int n_lgc = jcp.nthr;
            float ratio = (float)load_size / (float)bcast_size;
            int best_lgc = ratio > 1 ? n_lgc : 1;
            auto calc_job_cost = [&](int lb, int tg, float mem_k) {
                int bb_size = jcp.mb * div_up(nb_bcast, tg);
                float calc_size = (float)(bb_size * jcp.ur)
                        * (lb * jcp.load_block) * jcp.reduce_dim;
                float mem_size = (float)(bb_size * jcp.ur + lb * jcp.load_block)
                        * jcp.reduce_dim;
                return calc_koef * calc_size + mem_k * mem_size;
            };
            for (int lgc, ilgc = 0; ilgc < n_lgc; ilgc++) {
                lgc = ratio > 1 ? n_lgc - ilgc : ilgc + 1;
                int min_lb = nb_load / lgc;
                int max_lb = div_up(nb_load, lgc);
                int min_tg = jcp.nthr / lgc;
                int max_tg = div_up(jcp.nthr, lgc);
                // Some heuristic here
                float mem_koef = (max_tg == 1) ? 1.f : 1.3f;
                float job_cost = 0.;
                if (jcp.nthr % lgc < nb_load % lgc) {
                    job_cost = calc_job_cost(max_lb, min_tg, mem_koef);
                } else {
                    auto job_cost1 = calc_job_cost(max_lb, max_tg, mem_koef);
                    auto job_cost2 = calc_job_cost(min_lb, min_tg, mem_koef);
                    job_cost = nstl::max(job_cost1, job_cost2);
                }

                if (job_cost < best_cost) {
                    best_lgc = lgc;
                    best_cost = job_cost;
                }
            }
            jcp.load_grp_count = best_lgc;
            load_blocking
                    = div_up(nb_load, jcp.load_grp_count) * jcp.load_block;
        } else {
            jcp.load_grp_count
                    = div_up(jcp.nthr, jcp.mb * jcp.ngroups * nb_bcast);
            jcp.load_grp_count = best_divider(jcp.nthr, jcp.load_grp_count,
                    2 * jcp.load_grp_count, false);
        }
        if (jcp.bcast_dim <= 49 && jcp.mb <= jcp.nthr && jcp.load_dim > 512
                && jcp.load_dim / jcp.reduce_dim >= 4) {
            jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2);
            load_blocking = jcp.load_block;
        }

        bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast,
                                 div_up(jcp.nthr, jcp.load_grp_count))
                * jcp.bcast_block;
        bcast_blocking = nstl::min(jcp.bcast_dim, bcast_blocking);
        bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block);

        int space_for_bcast = (L2_capacity - /* kernel_size - */
                2 * jcp.load_block * reduce_blocking - jcp.ur * reduce_blocking
                - 3 * 1024);
        if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity) space_for_bcast /= 2;

        int bcast_in_cache
                = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking);
        bcast_blocking = nstl::min(
                bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block));

        load_blocking_max = load_blocking;
        bcast_blocking_max = bcast_blocking * 3 / 2;
        reduce_blocking_max = reduce_blocking;

        jcp.ur_tail = (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) % jcp.ur;

    } else if (jcp.prop_kind == backward_weights) { /* BWD weights */

        jcp.reduce_dim = jcp.is;

        jcp.reduce_block = best_divider(jcp.reduce_dim, 7, 16, true);
        if (jcp.reduce_dim % jcp.reduce_block != 0)
            jcp.reduce_block = best_divider(jcp.iw, 4, jcp.iw, false);
        if (jcp.reduce_block > 256) { jcp.reduce_block = 1; }

        jcp.load_dim = jcp.oc;
        jcp.load_block = jcp.oc_block;

        jcp.bcast_dim = jcp.ic;
        jcp.bcast_block = jcp.ic_block;

        if (jcp.reduce_block <= 19) {
            // If reduce_block is big, the generated JIT code may be big
            // for small values of ur, because
            // reduce_loop_unroll = reduce_block
            jcp.ur = jcp.bcast_block / 2;
        } else {
            jcp.ur = jcp.bcast_block;
        }

        jcp.ur_tail = jcp.bcast_dim % jcp.bcast_block;
        jcp.reduce_loop_unroll = jcp.reduce_block;
        jcp.reduce_loop_bcast_step
                = jcp.typesize_in * jcp.reduce_loop_unroll * jcp.ic_block;
        jcp.reduce_loop_load_step
                = jcp.typesize_in * jcp.reduce_loop_unroll * jcp.oc_block;

        jcp.bcast_loop_output_step
                = jcp.oc_block * jcp.ic_block * jcp.typesize_out;
        jcp.bcast_loop_output_substep
                = jcp.oc_block * jcp.ur * jcp.typesize_out;
        jcp.bcast_loop_bcast_step = jcp.ic_block
                * utils::rnd_up(jcp.reduce_dim, jcp.reduce_block)
                * jcp.typesize_in;
        jcp.bcast_loop_bcast_substep = jcp.ur * jcp.typesize_in;

        jcp.load_loop_load_step = jcp.typesize_in * jcp.oc_block * jcp.os;
        jcp.load_loop_iter_step = jcp.oc_block;

        /* --- */
        balance(jcp);

        load_blocking = div_up(jcp.load_dim, jcp.load_block);
        load_blocking = best_divider(load_blocking, 16, load_blocking, false);
        load_blocking *= jcp.load_block;

        load_blocking_max = load_blocking;
        assert(jcp.load_dim % load_blocking == 0);

        int max_bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
        int min_bcast_blocking = 5;

        bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
        bcast_blocking = best_divider(
                bcast_blocking, min_bcast_blocking, max_bcast_blocking, false);
        bcast_blocking *= jcp.bcast_block;
        bcast_blocking_max = bcast_blocking;
        assert(jcp.bcast_dim % bcast_blocking == 0);

        // for reduction balance
        int max_reduce_blocking
                = nstl::min(L1_capacity / jcp.ur, jcp.reduce_dim);
        int min_reduce_blocking
                = nstl::min(L1_capacity / jcp.ur, nstl::max(jcp.iw, jcp.ih));
        reduce_blocking = best_divider(
                jcp.reduce_dim, min_reduce_blocking, max_reduce_blocking, true);
        reduce_blocking = nstl::max(
                rnd_dn(reduce_blocking, jcp.reduce_block), jcp.reduce_block);

        reduce_blocking_max = rnd_dn(reduce_blocking * 3 / 2, jcp.reduce_block);
    } else
        return status::unimplemented;

    assert(load_blocking);
    assert(load_blocking_max);
    assert(bcast_blocking);
    assert(bcast_blocking_max);
    assert(reduce_blocking);
    assert(reduce_blocking_max);

    assert(load_blocking % jcp.load_block == 0);
    assert(reduce_blocking % jcp.reduce_block == 0);
    assert(load_blocking_max % jcp.load_block == 0);
    assert(reduce_blocking_max % jcp.reduce_block == 0);
    assert(jcp.reduce_dim % jcp.reduce_block == 0);

    assert(jcp.bcast_block % jcp.ur == 0);

    jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
    jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
    jcp.nb_load_blocking = utils::div_up(load_blocking, jcp.load_block);
    jcp.nb_load_blocking_max = utils::div_up(load_blocking_max, jcp.load_block);
    jcp.nb_reduce_blocking = utils::div_up(reduce_blocking, jcp.reduce_block);
    jcp.nb_reduce_blocking_max
            = utils::div_up(reduce_blocking_max, jcp.reduce_block);

    jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
    jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
    jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);

    return status::success;
}

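/* init_scratchpad() books the padded-bias buffer when the output channels
 * are padded, and, for backward weights, one per-thread weight-reduction
 * buffer for each additional nthr_mb thread. */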
void jit_sve_512_1x1_conv_kernel::init_scratchpad(
        memory_tracking::registrar_t &scratchpad,
        const jit_1x1_conv_conf_t &jcp) {

    using namespace dnnl::impl::memory_tracking::names;

    // For nxc layout, bias is padded only for the bwd_w direction, as the
    // bias reduction kernels can't handle tails yet.
    if (jcp.with_bias && jcp.prop_kind != backward_data
            && (jcp.oc != jcp.oc_without_padding // blocked layout
                    || (jcp.prop_kind == backward_weights // nxc layout
                            && jcp.oc % jcp.oc_block != 0))) {

        const size_t nelems_padded_bias
                = jcp.ngroups * utils::rnd_up(jcp.oc, jcp.oc_block);
        scratchpad.book(
                key_conv_padded_bias, nelems_padded_bias, jcp.typesize_out);
    }

    if (jcp.prop_kind == backward_weights) {
        const size_t wei_size = (size_t)jcp.ngroups
                * rnd_up(jcp.oc, jcp.oc_block) * rnd_up(jcp.ic, jcp.ic_block);
        scratchpad.book(key_conv_wei_reduction, wei_size * (jcp.nthr_mb - 1),
                jcp.typesize_out);
    }
}

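/* balance() distributes jcp.nthr threads across the minibatch*reduce
 * dimension (nthr_mb), groups (nthr_g), load blocks (nthr_oc_b), and bcast
 * blocks (nthr_ic_b), picking the combination with the lowest estimated
 * per-thread memory cost. */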
/* BWD weights */
void jit_sve_512_1x1_conv_kernel::balance(jit_1x1_conv_conf_t &jcp) {
    int nthreads = jcp.nthr;
    // initialize the jcp reduction threading properties
    jcp.nthr = jcp.nthr_mb = jcp.nthr_g = jcp.nthr_oc_b = jcp.nthr_ic_b = 1;
    if (nthreads < jcp.ngroups) {
        /* simplification... fortunately it doesn't hurt much */
        return;
    }
    // bcast_dim: src H*W, bcast_block: ur (fwd, bwd_d)
    const int nb_bcast
            = div_up(jcp.bcast_dim, jcp.bcast_block); // # of H*W loops
    // load_dim: dst channels, load_block: simd_w
    const int nb_load
            = div_up(jcp.load_dim, jcp.load_block); // # of dst channel loops
    // reduce_dim: src channels, reduce_block: simd_w
    const int nb_reduce
            = div_up(jcp.reduce_dim, jcp.reduce_block); // # of src channel loops

    jcp.nthr_g = jcp.ngroups;
    const int nthr = nthreads / jcp.nthr_g;

    auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
        /* calculate the per-thread memory cost (read/write). The high-level
         * optimizer tries to minimize memory consumption. A few notes: (n1)
         * unclear why, but that essentially helps the first convolution...
         * (n2) assuming the reduction over minibatch is always there:
         * - instead of 8 it should be 5 here (write ~= 2 read):
         *   kernel: temporal workspace 1 write
         *   reduction: 1 read from workspace and 1 write to the diff_wei
         * - but experiments showed 8 works better than 5 or 6... */
        int bcast_koeff = 1;
        int load_koeff = 1;
        int output_koeff = 12;
        return 0
                + (size_t)bcast_koeff * div_up(jcp.mb * nb_reduce, nthr_mb)
                * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_bcast, nthr_ic_b)
                * jcp.ic_block * jcp.reduce_block / jcp.stride_h
                / jcp.stride_w /* (n1) */
                + (size_t)load_koeff * div_up(jcp.mb * nb_reduce, nthr_mb)
                * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_load, nthr_oc_b)
                * jcp.oc_block * jcp.reduce_block
                + (size_t)output_koeff /* (n2) */
                * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_load, nthr_oc_b)
                * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block * jcp.oc_block;
    };

    int nthr_mb = 1, nthr_oc_b = 1, nthr_ic_b = 1;
    auto best_mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);

    /* step 1: find the best thread distribution with the lowest memory cost */
    const int nthr_mb_max = nstl::min(nthr, jcp.mb * nb_reduce);
    for (nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
        const int nthr_par = nthr / nthr_mb;
        const int nthr_oc_b_max = nstl::min(nthr_par, nb_load);
        for (nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
            nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, nb_bcast);
            auto mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
            if (mem_cost <= best_mem_cost) {
                best_mem_cost = mem_cost;
                jcp.nthr_mb = nthr_mb;
                jcp.nthr_oc_b = nthr_oc_b;
                jcp.nthr_ic_b = nthr_ic_b;
            }
        }

        const bool ready_for_async = utils::one_of(jcp.ver, ver_fma);
        if (!ready_for_async && !dnnl_thr_syncable()) {
            assert(nthr_mb == 1);
            break;
        }
    }
    if (jcp.nthr_mb > nthreads / 2 && jcp.nthr_mb < nthreads)
        jcp.nthr_mb = nstl::min(jcp.mb, nthreads);

    jcp.nthr = jcp.nthr_mb * jcp.nthr_g * jcp.nthr_oc_b * jcp.nthr_ic_b;
    assert(jcp.nthr <= nthreads);
}

} // namespace aarch64
} // namespace cpu
} // namespace impl
} // namespace dnnl