/*******************************************************************************
* Copyright 2021 Intel Corporation
* Copyright 2021 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>
#include <float.h>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory.hpp"
#include "common/memory_tracking.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/aarch64/cpu_barrier.hpp"
#include "cpu/platform.hpp"

#include "cpu/aarch64/jit_sve_512_1x1_conv_kernel.hpp"

#include "cpu/aarch64/jit_uni_1x1_conv_utils.hpp"

#define GET_OFF(field) \
    static_cast<int32_t>(offsetof(jit_1x1_conv_call_s, field))
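/* GET_OFF() yields the byte offset of a member of jit_1x1_conv_call_s, so the
 * kernel can load its arguments from the struct pointed to by abi_param1. */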

namespace dnnl {
namespace impl {
namespace cpu {
namespace aarch64 {

using namespace dnnl::impl::format_tag;
using namespace dnnl::impl::prop_kind;
using namespace dnnl::impl::utils;

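/* bcast_loop() walks the broadcast (spatial) dimension in blocks of
 * jcp.bcast_block, emitting one reduce_loop() call per jcp.ur sub-step and
 * advancing the input and output pointers between sub-steps. */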
void jit_sve_512_1x1_conv_kernel::bcast_loop(int load_loop_blk) {

    mov(aux1_reg_bcast_data, reg_bcast_data);
    mov(aux_reg_bcast_data, reg_bcast_data);
    mov(aux_reg_output_data, reg_output_data);
    mov(reg_bcast_loop_iter, reg_bcast_loop_work);

    Label bcast_loop;
    Label bcast_loop_tail;
    Label large_tail;

    cmp_imm(reg_bcast_loop_iter, jcp.bcast_block, reg_tmp_imm);
    b(LT, bcast_loop_tail);

    L(bcast_loop);
    {
        assert(jcp.bcast_block % jcp.ur == 0);
        int num_substeps = jcp.bcast_block / jcp.ur;
        assert(num_substeps > 0 && num_substeps < 10);
        for (int i = 0; i < num_substeps; i++) {
            if (i + 1 == num_substeps) L(large_tail);
            reduce_loop(load_loop_blk, jcp.ur, i, false);
            if (i < num_substeps - 1) {
                add_imm(aux1_reg_bcast_data, aux1_reg_bcast_data,
                        jcp.bcast_loop_bcast_substep, reg_tmp_imm);
                add_imm(aux_reg_output_data, aux_reg_output_data,
                        jcp.bcast_loop_output_substep, reg_tmp_imm);
            } else {
                add_imm(aux1_reg_bcast_data, aux1_reg_bcast_data,
                        jcp.bcast_loop_bcast_step
                                - (num_substeps - 1)
                                        * jcp.bcast_loop_bcast_substep,
                        reg_tmp_imm);
                add_imm(aux_reg_output_data, aux_reg_output_data,
                        jcp.bcast_loop_output_step
                                - (num_substeps - 1)
                                        * jcp.bcast_loop_output_substep,
                        reg_tmp_imm);
            }
            subs_imm(reg_bcast_loop_iter, reg_bcast_loop_iter, jcp.ur,
                    reg_tmp_imm);
        }
        cmp_imm(reg_bcast_loop_iter, jcp.bcast_block, reg_tmp_imm);
        b(GE, bcast_loop);
    }

    L(bcast_loop_tail);
    if (jcp.ur_tail) {
        Label bcast_loop_tail_out;
        if (jcp.ur_tail >= jcp.ur) {
            cmp_imm(reg_bcast_loop_iter, jcp.ur, reg_tmp_imm);
            b(GE, large_tail);
        }
        if (jcp.ur_tail % jcp.ur) {
            cmp(reg_bcast_loop_iter, 0);
            b(LE, bcast_loop_tail_out);
            reduce_loop(load_loop_blk, jcp.ur_tail % jcp.ur, 0, true);
            L(bcast_loop_tail_out);
        }
    }
}

void jit_sve_512_1x1_conv_kernel::reduce_loop(
        int load_loop_blk, int ur, int substep, bool wraparound) {

    const bool out_layout_nxc = is_out_layout_nxc(jcp);
    const bool load_layout_nxc = is_load_layout_nxc(jcp);
    const bool bcast_layout_nxc = is_bcast_layout_nxc(jcp);
    const int reduce_dim_tail = jcp.reduce_dim % jcp.reduce_block;

    auto vreg_sum = [=]() { return ZReg(31); };
    auto vreg_sum_s = [=]() { return ZRegS(31); };

    auto vreg_load = [=](int i_load, int i_fma) {
        return ZReg(utils::rnd_up(ur * load_loop_blk, jcp.fma_step)
                + jcp.fma_step * i_load + i_fma);
    };
    auto vreg_load_s = [=](int i_load, int i_fma) {
        return ZRegS(utils::rnd_up(ur * load_loop_blk, jcp.fma_step)
                + jcp.fma_step * i_load + i_fma);
    };

    auto vreg_accum = [=](int i_load, int i_ur) {
        return ZReg(i_ur * load_loop_blk + i_load);
    };
    auto vreg_accum_s = [=](int i_load, int i_ur) {
        return ZRegS(i_ur * load_loop_blk + i_load);
    };
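    /* Vector register usage: accumulators occupy z0..z(ur*load_loop_blk-1),
     * the weight (load) registers follow them, the remaining registers are
     * rotated as broadcast sources, and z31 also serves as the sum operand. */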

    auto bias_load = [=](int i_load, int i_ur) {
        int ofs = jcp.typesize_out * jcp.oc_block * i_load;
        if (ldr_imm_check(ofs)) {
            ldr(vreg_accum(i_load, i_ur),
                    ptr(reg_bias_data, static_cast<int32_t>(VL64_OFS(ofs))));
        } else {
            add_imm(reg_tmp_ofs, reg_bias_data, ofs, reg_tmp_imm);
            ldr(vreg_accum(i_load, i_ur), ptr(reg_tmp_ofs));
        }
    };

    auto bcast_load = [=](int i_reduce, int i_ur, int prev_ofs, int bcast_idx) {
        assert(i_ur < jcp.ur);
        assert(i_reduce <= jcp.reduce_loop_unroll);
        int ofs;
        if (one_of(jcp.prop_kind, forward_training, forward_inference,
                    backward_data)) {
            assert(jcp.reduce_loop_unroll == jcp.reduce_block);
            const int reduce_mul = bcast_layout_nxc ? jcp.reduce_dim
                                                    : jcp.reduce_loop_unroll;
            ofs = (i_reduce == jcp.reduce_loop_unroll)
                    ? (jcp.bcast_dim + i_ur) * reduce_mul
                    : i_ur * reduce_mul + i_reduce;
        } else {
            int rmul = bcast_layout_nxc ? jcp.ic : jcp.ic_block;
            ofs = i_reduce * rmul + i_ur;
        }

        ofs = jcp.typesize_in * ofs;
        int tmp_ofs = ofs;
        if (ld1rw_imm_check(ofs)) {
            ld1rw(ZRegS(bcast_idx), reg_p_all_ones,
                    ptr(aux_reg_bcast_data, static_cast<int32_t>(ofs)));
        } else {
            if ((prev_ofs != -1) && ld1rw_imm_check(ofs - prev_ofs)) {
                ld1rw(ZRegS(bcast_idx), reg_p_all_ones,
                        ptr(reg_prev_bcast_addr,
                                static_cast<int32_t>((ofs - prev_ofs))));
            } else {
                if ((prev_ofs != -1) && ((ofs - prev_ofs) >= 0)) {
                    ofs = ofs - prev_ofs;
                    add_imm(reg_prev_bcast_addr, reg_prev_bcast_addr, ofs,
                            reg_tmp_imm);
                } else {
                    add_imm(reg_prev_bcast_addr, aux_reg_bcast_data, ofs,
                            reg_tmp_imm);
                }
                prev_ofs = tmp_ofs;

                ld1rw(ZRegS(bcast_idx), reg_p_all_ones,
                        ptr(reg_prev_bcast_addr));
            }
        }
        return prev_ofs;
    };
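    /* bcast_load() broadcasts one fp32 input element to a whole vector via
     * ld1rw; reg_prev_bcast_addr tracks the last materialized base address so
     * that subsequent offsets stay within the ld1rw immediate range. */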

    auto load_load = [=](int i_reduce, int i_load, int i_fma) {
        int ofs;
        int u0 = i_reduce % jcp.reduce_loop_unroll;
        int u1 = i_reduce / jcp.reduce_loop_unroll;
        int lmul = jcp.load_block
                * (load_layout_nxc ? 1
                                   : utils::rnd_up(
                                           jcp.reduce_dim, jcp.reduce_block));
        int rmul = load_layout_nxc ? jcp.load_dim : jcp.load_block;
        ofs = i_load * lmul + u0 * rmul;
        ofs = u1 * jcp.reduce_loop_load_step + jcp.typesize_in * ofs;

        if (ldr_imm_check(ofs)) {
            ofs = VL64_OFS(ofs);
            ldr(vreg_load(i_load, i_fma),
                    ptr(aux_reg_load_data, static_cast<int32_t>(ofs)));
        } else {
            add_imm(reg_tmp_ofs, aux_reg_load_data, ofs, reg_tmp_imm);
            ldr(vreg_load(i_load, i_fma), ptr(reg_tmp_ofs));
        }
    };

    auto out_load = [=](int i_load, int i_ur, int prev_ofs) {
        int ofs, ofs_tmp;
        int bwd_iload
                = (i_load != 0) && one_of(jcp.prop_kind, backward_weights);
        auto r = (bwd_iload) ? reg_tmp_ofs : aux_reg_output_data;

        if (one_of(jcp.prop_kind, forward_training, forward_inference,
                    backward_data)) {
            int i_load_shift = out_layout_nxc
                    ? jcp.load_block
                    : (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim)
                            * jcp.load_block;
            int i_ur_shift = out_layout_nxc ? jcp.load_dim : jcp.load_block;
            ofs = (i_load * i_load_shift + i_ur * i_ur_shift)
                    * jcp.typesize_out;
        } else {
            ofs = jcp.typesize_out * jcp.load_block * i_ur;
        }

        ofs_tmp = ofs;

        if (bwd_iload) mov(r, i_load);
        if (ldr_imm_check(ofs)) {
            if (bwd_iload) madd(r, r, reg_output_stride, aux_reg_output_data);
            ldr(vreg_sum(), ptr(r, static_cast<int32_t>(VL64_OFS(ofs))));
        } else {
            if ((prev_ofs != -1) && ((ofs - prev_ofs) > 0)
                    && (VL64_OFS(ofs - prev_ofs) <= LDRMAX)) {
                if (bwd_iload)
                    madd(r, r, reg_output_stride, reg_prev_out_addr);
                else
                    r = reg_prev_out_addr;
                ldr(vreg_sum(),
                        ptr(r, static_cast<int32_t>(VL64_OFS(ofs - prev_ofs))));
            } else {
                if ((prev_ofs != -1) && ((ofs - prev_ofs) > 0)) {
                    ofs = ofs - prev_ofs;
                    add_imm(reg_prev_out_addr, reg_prev_out_addr, ofs,
                            reg_tmp_imm);
                } else {
                    add_imm(reg_prev_out_addr, aux_reg_output_data, ofs,
                            reg_tmp_imm);
                }
                if (bwd_iload)
                    madd(r, r, reg_output_stride, reg_prev_out_addr);
                else
                    r = reg_prev_out_addr;
                ldr(vreg_sum(), ptr(r));

                prev_ofs = ofs_tmp;
            }
        }
        return prev_ofs;
    };
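    /* out_load() above and out_str() below use the same previous-address trick
     * as bcast_load() to keep ldr/str offsets within the immediate range; for
     * backward_weights the per-i_load displacement is i_load * reg_output_stride,
     * folded into the base address via madd. */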

    auto out_str = [=](int i_load, int i_ur, int prev_ofs) {
        int ofs, ofs_tmp;
        int bwd_iload
                = (i_load != 0) && one_of(jcp.prop_kind, backward_weights);
        auto r = (bwd_iload) ? reg_tmp_ofs : aux_reg_output_data;
        if (one_of(jcp.prop_kind, forward_training, forward_inference,
                    backward_data)) {
            ofs = (i_load * jcp.bcast_dim + i_ur) * jcp.load_block
                    * jcp.typesize_out;
        } else {
            ofs = jcp.typesize_out * jcp.load_block * i_ur;
        }
        ofs_tmp = ofs;

        if (bwd_iload) mov(r, i_load);
        if (str_imm_check(ofs)) {
            if (bwd_iload) madd(r, r, reg_output_stride, aux_reg_output_data);
            str(vreg_accum(i_load, i_ur),
                    ptr(r, static_cast<int32_t>(VL64_OFS(ofs))));
        } else {
            if ((prev_ofs != -1) && str_imm_check(ofs - prev_ofs)) {
                if (bwd_iload)
                    madd(r, r, reg_output_stride, reg_prev_out_addr);
                else
                    r = reg_prev_out_addr;
                str(vreg_accum(i_load, i_ur),
                        ptr(r, static_cast<int32_t>(VL64_OFS(ofs - prev_ofs))));
            } else {
                if ((prev_ofs != -1) && ((ofs - prev_ofs) > 0)) {
                    ofs = ofs - prev_ofs;
                    add_imm(reg_prev_out_addr, reg_prev_out_addr, ofs,
                            reg_tmp_imm);
                } else {
                    add_imm(reg_prev_out_addr, aux_reg_output_data, ofs,
                            reg_tmp_imm);
                }
                if (bwd_iload)
                    madd(r, r, reg_output_stride, reg_prev_out_addr);
                else
                    r = reg_prev_out_addr;
                str(vreg_accum(i_load, i_ur), ptr(r));

                prev_ofs = ofs_tmp;
            }
        }
        return prev_ofs;
    };

    auto prefetch_output = [=](int i_load, int i_ur) {
        int ofs;
        int bwd_iload
                = (i_load != 0) && one_of(jcp.prop_kind, backward_weights);
        auto r = (bwd_iload) ? reg_tmp_ofs : aux_reg_output_data;
        if (one_of(jcp.prop_kind, forward_training, forward_inference,
                    backward_data)) {
            ofs = (i_load * jcp.bcast_dim + i_ur) * jcp.load_block
                    * jcp.typesize_out;
        } else {
            ofs = jcp.typesize_out * jcp.load_block * i_ur;
        }
        std::string op = "LD";
        prefetch(op, 2, r, ofs);
    };

    auto init = [=]() {
        Label init_done;
        Label init_zero;

        if (jcp.with_sum) {
            for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                for (int i_ur = 0; i_ur < ur; ++i_ur) {
                    prefetch_output(i_load, i_ur);
                }
            }
        }

        if (jcp.with_bias
                && one_of(jcp.prop_kind, forward_training, forward_inference)) {

            tst(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
            b(EQ, init_zero);

            for (int i_load = 0; i_load < load_loop_blk; i_load++)
                for (int i_ur = 0; i_ur < ur; ++i_ur) {
                    bias_load(i_load, i_ur);
                }
            b(init_done);
        }

        L(init_zero);
        /* Zero clear */
        for (int i_load = 0; i_load < load_loop_blk; ++i_load)
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                fmov(vreg_accum_s(i_load, i_ur));
            }
        L(init_done);
    };

    auto store = [=]() {
        Label store_noadd;
        if (!jcp.with_sum) {
            tst(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
            b(NE, store_noadd);
        }

        int prev_ofs = -1;
        for (int i_ur = 0; i_ur < ur; ++i_ur)
            for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                auto r = vreg_accum_s(i_load, i_ur);
                prev_ofs = out_load(i_load, i_ur, prev_ofs);
                fadd(r, r, vreg_sum_s());
            }

        L(store_noadd);
        if (jcp.with_eltwise) {
#ifndef DISABLE_ELTWISE
            Label store_noeltwise;
            tst(reg_reduce_pos_flag, FLAG_REDUCE_LAST);
            b(EQ, store_noeltwise);
            eltwise_injector_->compute_vector_range(0, ur * load_loop_blk);
            L(store_noeltwise);
#else
            assert(!"fused eltwise error!");
#endif
        }

        prev_ofs = -1;
        for (int i_ur = 0; i_ur < ur; ++i_ur) {
            for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                prev_ofs = out_str(i_load, i_ur, prev_ofs);
            }
        }
    };

    auto fma_block = [=](bool last_block) {
        assert(jcp.reduce_loop_unroll % jcp.fma_step == 0);

        int reduce_step = jcp.fma_step;
        int prev_bcast_ofs = -1;
        assert(reduce_dim_tail % reduce_step == 0);

        const int i_reduce_end = reduce_dim_tail && last_block
                ? reduce_dim_tail
                : jcp.reduce_loop_unroll;

        int bcast_reg_ofs = utils::rnd_up(ur * load_loop_blk, jcp.fma_step)
                + jcp.fma_step * load_loop_blk;
        int num_bcast_regs = 32 - bcast_reg_ofs;
        int bcast_reg_idx = 0;
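        /* The broadcast registers are used round-robin: load them ahead of the
         * FMA loop and refill a slot as soon as its value has been consumed. */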

        for (int i_reduce = 0; i_reduce < i_reduce_end;
                i_reduce += reduce_step) { // IC
            for (int i_load = 0; i_load < load_loop_blk; ++i_load) { // OC
                for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) {
                    load_load(i_reduce + i_fma, i_load, i_fma);
                }
            }

            int bcast_reg_startidx = bcast_reg_idx % num_bcast_regs;
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                if (i_ur >= num_bcast_regs) break;
                prev_bcast_ofs = bcast_load(i_reduce, i_ur, prev_bcast_ofs,
                        bcast_reg_ofs + (bcast_reg_idx % num_bcast_regs));
                bcast_reg_idx++;
            }

            for (int i_ur = 0; i_ur < ur; ++i_ur) {

                for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                    fmla(vreg_accum_s(i_load, i_ur), reg_p_all_ones,
                            vreg_load_s(i_load, 0),
                            ZRegS(bcast_reg_ofs
                                    + ((bcast_reg_startidx + i_ur)
                                            % num_bcast_regs)));
                }
                if ((num_bcast_regs + i_ur) < ur) {
                    prev_bcast_ofs = bcast_load(i_reduce, num_bcast_regs + i_ur,
                            prev_bcast_ofs,
                            bcast_reg_ofs + (bcast_reg_idx % num_bcast_regs));
                    bcast_reg_idx++;
                }
            }
        }
    };
    Label reduce_loop;
    Label reduce_loop_tail;

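    /* Reduce-loop driver: initialize the accumulators, iterate over the reduce
     * (channel) dimension in chunks of jcp.reduce_loop_unroll, handle the
     * remainder in a tail block, then store the results. */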
    mov(aux_reg_load_data, reg_load_data);

    mov(aux_reg_bcast_data, aux1_reg_bcast_data);
    init();

    mov(reduce_loop_iter, reg_reduce_loop_work);
    subs_imm(reduce_loop_iter, reduce_loop_iter, jcp.reduce_loop_unroll,
            reg_tmp_imm);
    b(LE, reduce_loop_tail);

    align(32);
    L(reduce_loop);
    {
        fma_block(false);
        add_imm(aux_reg_bcast_data, aux_reg_bcast_data,
                jcp.reduce_loop_bcast_step, reg_tmp_imm);
        add_imm(aux_reg_load_data, aux_reg_load_data, jcp.reduce_loop_load_step,
                reg_tmp_imm);
        subs_imm(reduce_loop_iter, reduce_loop_iter, jcp.reduce_loop_unroll,
                reg_tmp_imm);
        b(GT, reduce_loop);
    }

    L(reduce_loop_tail);
    fma_block(true);

    store();
}

void jit_sve_512_1x1_conv_kernel::generate() {
    preamble();

    /* Set an all-ones predicate register */
    ptrue(reg_p_all_ones.b);

    /* Pointers to the input, weight, and output data */
    ldr(reg_bcast_data, ptr(abi_param1, GET_OFF(bcast_data))); // Input
    ldr(reg_load_data, ptr(abi_param1, GET_OFF(load_data))); // Weight
    ldr(reg_output_data, ptr(abi_param1, GET_OFF(output_data))); // Output

    /* Pointer to bias data if the layer has a bias */
    if (jcp.with_bias) ldr(reg_bias_data, ptr(abi_param1, GET_OFF(bias_data)));

    /* Get workloads of each loop */
    ldr(reg_load_loop_work, ptr(abi_param1, GET_OFF(load_dim)));
    ldr(reg_bcast_loop_work, ptr(abi_param1, GET_OFF(bcast_dim)));
    ldr(reg_reduce_loop_work, ptr(abi_param1, GET_OFF(reduce_dim)));

    /* A flag for controlling the reduce loop */
    ldr(reg_reduce_pos_flag, ptr(abi_param1, GET_OFF(first_last_flag)));

    if (one_of(jcp.prop_kind, forward_training, forward_inference))
        mov(reg_relu_ns, reinterpret_cast<size_t>(&jcp.eltwise.alpha));

    if (jcp.prop_kind == backward_weights)
        ldr(reg_output_stride, ptr(abi_param1, GET_OFF(output_stride)));

    auto load_loop_body = [=](int load_loop_blk) {
        subs_imm(reg_load_loop_work, reg_load_loop_work,
                load_loop_blk * jcp.load_loop_iter_step, reg_tmp_imm);

        bcast_loop(load_loop_blk);
        add_imm(reg_load_data, reg_load_data,
                load_loop_blk * jcp.load_loop_load_step, reg_tmp_imm);
        switch (jcp.prop_kind) {
            case forward_training:
            case forward_inference:
                add_imm(reg_bias_data, reg_bias_data,
                        load_loop_blk * jcp.load_block * jcp.typesize_out,
                        reg_tmp_imm);
                add_imm(reg_output_data, reg_output_data,
                        load_loop_blk * jcp.load_block * jcp.typesize_out
                                * (is_out_layout_nxc(jcp)
                                                ? 1
                                                : (jcp.with_dw_conv
                                                                ? jcp.ow
                                                                : jcp.bcast_dim)),
                        reg_tmp_imm);
                break;
            case backward_data:
                add_imm(reg_output_data, reg_output_data,
                        load_loop_blk * jcp.load_block * jcp.typesize_out
                                * (is_out_layout_nxc(jcp) ? 1 : jcp.bcast_dim),
                        reg_tmp_imm);
                break;
            case backward_weights:
                for (int i_load = 0; i_load < load_loop_blk; i_load++)
                    add(reg_output_data, reg_output_data, reg_output_stride);
                break;
            default: assert(!"invalid prop_kind");
        }
    };

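    /* Dispatch below: pick the widest load_loop_blk (1..6) that is compatible
     * with the chosen jcp.ur and the remaining load_dim work, then loop until
     * all output-channel blocks are processed. */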
    const int simd_w = cpu_isa_traits<sve_512>::vlen / sizeof(float);

    Label load_loop_blk[7];

    // with an implicit load_loop_block {6, 5, 4, 3, 2, 1}
    static const int ur_cases_bcast[] = {2, 5, 6, 9, 14, 32};

    const int size_ur_cases = sizeof(ur_cases_bcast);

    const int *ur_cases = ur_cases_bcast;
    const int num_ur_cases = size_ur_cases / sizeof(*ur_cases);

    for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) {
        int label_idx = num_ur_cases - ur_idx - 1;
        if (jcp.nb_load > label_idx && jcp.ur <= ur_cases[ur_idx]) {
            cmp_imm(reg_load_loop_work, simd_w * (label_idx + 1), reg_tmp_imm);
            b(LE, load_loop_blk[label_idx]);
        }
    }

    for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) {
        int label_idx = num_ur_cases - ur_idx - 1;
        if (jcp.nb_load > label_idx && jcp.ur <= ur_cases[ur_idx]) {
            L(load_loop_blk[label_idx]);
            {
                if (label_idx == 0) {
                    cmp(reg_load_loop_work, 0);
                    b(LE, load_loop_blk[num_ur_cases]);
                }
                load_loop_body(label_idx + 1);
                if (label_idx - 1 > 0) {
                    cmp_imm(reg_load_loop_work, 2 * label_idx * simd_w,
                            reg_tmp_imm);
                    b(EQ, load_loop_blk[label_idx - 1]);
                }
                cmp_imm(reg_load_loop_work, label_idx * simd_w, reg_tmp_imm);
                b(GT, load_loop_blk[label_idx]);
            }
            for (int idx = label_idx - 1; idx > 0; --idx) {
                cmp_imm(reg_load_loop_work, simd_w * (idx + 1), reg_tmp_imm);
                b(EQ, load_loop_blk[idx]);
            }
            if (ur_idx < num_ur_cases - 2) {
                cmp_imm(reg_load_loop_work, simd_w, reg_tmp_imm);
                b(LE, load_loop_blk[0]);
            }
        }
    }
    L(load_loop_blk[num_ur_cases]);

    postamble();
    if (jcp.with_eltwise) {
#ifndef DISABLE_ELTWISE
        eltwise_injector_->prepare_table();
        binCommit();
#else
        assert(!"fused eltwise error");
#endif
    }
}

bool jit_sve_512_1x1_conv_kernel::post_ops_ok(
        jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {

    const auto &p = attr.post_ops_;

    auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
    auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
    auto is_convolution
            = [&](int idx) { return p.entry_[idx].is_convolution(); };

    int dw_idx = p.find(primitive_kind::convolution);
    int len = dw_idx != -1 ? dw_idx + 1 : p.len();

    switch (len) {
        case 0: return true; // no post_ops
        case 1: // eltwise OR sum OR convolution
            return is_eltwise(0) || is_sum(0) || is_convolution(0);
        case 2: // sum -> eltwise OR eltwise -> convolution
            return (is_sum(0) && is_eltwise(1))
                    || (is_eltwise(0) && is_convolution(1));
        default: return false;
    }

    return false;
}

status_t jit_sve_512_1x1_conv_kernel::init_conf(jit_1x1_conv_conf_t &jcp,
        const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
        const primitive_attr_t &attr, int nthreads, bool reduce_src) {

    /* arch check */
    if (!mayiuse(sve_512)) return status::unimplemented;

    jcp.nthr = nthreads;

    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    const int simd_w = cpu_isa_traits<sve_512>::vlen / sizeof(float);
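    /* For 512-bit SVE this gives simd_w = 16 fp32 lanes per vector. */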
    const int ndims = src_d.ndims();
    /* Forward_[training, inference], backward_[data, weight] */
    jcp.prop_kind = cd.prop_kind;

    /* Check group option */
    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    /* Batchsize */
    jcp.mb = src_d.dims()[0];
    /* Channel */
    jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups;
    jcp.oc = jcp.oc_without_padding;
    jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups;
    jcp.ic = jcp.ic_without_padding;
    /* D, H, W */
    jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2];
    jcp.ow = dst_d.dims()[ndims - 1];
    /* Kernel size */
    jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];
    /* padding params */
    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
    jcp.l_pad = cd.padding[0][ndims - 3];
    /* stride params */
    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
    jcp.stride_w = cd.strides[ndims - 3];
    /* bias info */
    jcp.with_bias = pick_by_prop_kind(jcp.prop_kind, cd.bias_desc.format_kind,
                            format_kind::undef, cd.diff_bias_desc.format_kind)
            != format_kind::undef;

    /* Spatials */
    jcp.os = jcp.od * jcp.oh * jcp.ow;
    jcp.is = jcp.id * jcp.ih * jcp.iw;
    jcp.tr_is = rnd_up(jcp.is, 4);

    if (!post_ops_ok(jcp, attr)) return status::unimplemented;

    /* Depthwise conv check */
    const auto &p = attr.post_ops_;
    const int dw_conv_ind = p.find(primitive_kind::convolution);
    jcp.with_dw_conv = dw_conv_ind != -1;

    /* Post operation check */
    // Using dw_conv_ind as upper-bound below, as post-ops after it will be
    // handled in depthwise convolution.
    jcp.with_sum = p.find(primitive_kind::sum, 0, dw_conv_ind) != -1;
    const int eltwise_ind = p.find(primitive_kind::eltwise, 0, dw_conv_ind);
    jcp.with_eltwise = eltwise_ind != -1;
    if (jcp.with_eltwise) {
#ifndef DISABLE_ELTWISE
        jcp.eltwise = p.entry_[eltwise_ind].eltwise;
        if (jcp.eltwise.alg == alg_kind::eltwise_pow)
            return status::unimplemented;
        if (dst_d.data_type() == data_type::s32) return status::unimplemented;
#else
        return status::unimplemented;
#endif
    }

    /* Data format check */
    const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
    const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
    jcp.src_tag = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
    jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c);
    bool is_data_layout_nxc
            = utils::everyone_is(dat_tag_nxc, jcp.src_tag, jcp.dst_tag);
    auto required_dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c;

    if (is_data_layout_nxc) return status::unimplemented;

    /* Channel padding check */
    bool ok_to_pad_channels
            = true && jcp.ngroups == 1 && src_d.data_type() == data_type::f32;

    /* Input and output channels must be multiples of simd_w */
    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, simd_w);
        jcp.ic = rnd_up(jcp.ic, simd_w);
    }

    bool args_ok = true && jcp.ngroups == 1 && jcp.src_tag == required_dat_tag
            && jcp.dst_tag == required_dat_tag
            && (jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0) && jcp.f_pad == 0
            && jcp.t_pad == 0 && jcp.l_pad == 0 && jcp.stride_w == 1
            && jcp.stride_h == 1 && jcp.stride_d == 1 && jcp.kd == 1
            && jcp.kh == 1 && jcp.kw == 1 && jcp.ow == jcp.iw
            && jcp.oh == jcp.ih && jcp.od == jcp.id; // enforce rpad=0
    if (!args_ok) return status::unimplemented;

    /* Channel blocking size is simd_w */
    jcp.ic_block = jcp.oc_block = simd_w;

    jcp.ver = ver_sve_512;
    if (everyone_is(data_type::f32, src_d.data_type(), weights_d.data_type(),
                dst_d.data_type())) {
        const int is_bwd_d = jcp.prop_kind == backward_data;
        /* Set weight data layout tag */
        format_tag_t wei_tag = with_groups
                ? pick(2 * ndims - 6 + is_bwd_d, gOIw16i16o, gIOw16o16i,
                        gOIhw16i16o, gIOhw16o16i, gOIdhw16i16o, gIOdhw16o16i)
                : pick(2 * ndims - 6 + is_bwd_d, OIw16i16o, IOw16o16i,
                        OIhw16i16o, IOhw16o16i, OIdhw16i16o, IOdhw16o16i);

        jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
        if (jcp.wei_tag != wei_tag) return status::unimplemented;

        jcp.fma_step = 1;
        jcp.typesize_in = sizeof(prec_traits<data_type::f32>::type);
        jcp.typesize_out = sizeof(prec_traits<data_type::f32>::type);
    } else {
        // TODO: currently only fp32 is supported
        return status::unimplemented;
    }

    /* once all the formats are set, check the padding consistency */
    args_ok = true && jcp.ic <= src_d.padded_dims()[1]
            && jcp.oc <= dst_d.padded_dims()[1]
            && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
            && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
    if (!args_ok) return status::unimplemented;

    // TODO: Optimize below params
    const int SMALL_SPATIAL = 10;
    const int BIG_SPATIAL = 65;
    const int BIG_LOAD_DIM = (jcp.reduce_dim >= 512) ? 256 : 512;

    int load_blocking {0};
    int load_blocking_max {0};
    int bcast_blocking {0};
    int bcast_blocking_max {0};
    int reduce_blocking {0};
    int reduce_blocking_max {0};
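    /* The *_blocking values below partition the load (output-channel), bcast
     * (spatial), and reduce (input-channel) dimensions so that each block of
     * work fits the cache hierarchy and divides evenly among threads. */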

    jcp.load_grp_count = 1;

    // TODO: move check funcs into platform files
    const int L1_capacity
            = platform::get_per_core_cache_size(1) / sizeof(float);
    const int L2_size = platform::get_per_core_cache_size(2) / sizeof(float);
    const int L2_capacity = (L2_size * 3) / 4;

    /* FWD, BWD data */
    if (one_of(jcp.prop_kind, forward_training, forward_inference,
                backward_data)) {

        if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
            /* Forward */
            if (jcp.with_dw_conv) jcp.ur = nstl::min(jcp.ow, jcp.ur);
            jcp.reduce_dim = jcp.ic; // src channel
            jcp.reduce_block = jcp.ic_block; // src simd_w

            jcp.load_dim = jcp.oc; // dst channel
            jcp.load_block = jcp.oc_block; // dst simd_w

            jcp.bcast_dim = jcp.is; // src H*W
        } else {
            /* Backward data */
            jcp.reduce_dim = jcp.oc; // src channel
            jcp.reduce_block = jcp.oc_block; // src simd_w

            jcp.load_dim = jcp.ic; // dst channel
            jcp.load_block = jcp.ic_block; // dst simd_w

            jcp.bcast_dim = jcp.os; // src H*W
        }

        /* # of consecutive channel elements */
        jcp.reduce_loop_unroll = jcp.reduce_block;

        /* Offset to move to the next 16 input channel elements with the same H*W position */
        jcp.reduce_loop_bcast_step
                = jcp.reduce_loop_unroll * jcp.bcast_dim * jcp.typesize_in;

        /* Offset: 16o*16i (filter) */
        jcp.reduce_loop_load_step
                = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in;

        /* Offset: I/16 * 16o */
        jcp.load_loop_load_step
                = (utils::rnd_up(jcp.reduce_dim, jcp.reduce_block))
                * jcp.load_block * jcp.typesize_in;

        /* adjusting register blocking */
        int max_regs, min_regs, size_threshold, ur_step;

        /* spatial: D*H of dst (fwd) or of src (bwd_d) */
        const int spatial
                = (one_of(jcp.prop_kind, forward_training, forward_inference))
                ? jcp.od * jcp.oh // forward
                : jcp.id * jcp.ih; // backward

        max_regs = 9; // max # of ur_w
        min_regs = 6; // min # of ur_w
        size_threshold = 14;
        ur_step = 1; // step size of ur_w param checking
        jcp.ur = 1;

        /* Mid-sized shapes (SMALL_SPATIAL < spatial < BIG_SPATIAL with
         * moderate channel counts) use a narrower unroll range. */
        if (jcp.load_dim > 128 && jcp.load_dim < BIG_LOAD_DIM
                && spatial > SMALL_SPATIAL && spatial < BIG_SPATIAL
                && jcp.reduce_dim < 256) {
            max_regs = 6;
            min_regs = 5;
        }

        for (int ur_w = max_regs; ur_w >= min_regs; ur_w -= ur_step) {
            /*
             *  H*D of dst >= size_threshold, (H*D of dst) % ur_w == 0
             *  or
             *  H*D of dst < size_threshold, (H*W of dst) % ur_w == 0
             */
            if ((spatial >= size_threshold && spatial % ur_w == 0)
                    || (spatial < size_threshold && jcp.os % ur_w == 0)) {
                jcp.ur = ur_w;
                break;
            }
        }

        if (jcp.ur == 1) {
            // If no ur_w divides the spatial size, use min(max_regs, jcp.os)
            jcp.ur = nstl::min(max_regs, jcp.os);
        }
        jcp.bcast_block = jcp.ur; // block size of bcast (input data)
        /* Byte step applied to the dst address in bcast_loop() */
        jcp.bcast_loop_output_step = jcp.ur * jcp.typesize_out * jcp.load_block;
        jcp.bcast_loop_output_substep = -1; // unused

        /* Byte step applied to the broadcast (src) address in bcast_loop() */
        jcp.bcast_loop_bcast_step = jcp.ur * jcp.typesize_in * jcp.reduce_block;
        jcp.bcast_loop_bcast_substep = -1; // unused

        jcp.load_loop_iter_step = jcp.load_block;

        if (jcp.prop_kind == backward_data)
            jcp.loop_order = loop_lbr;
        else
            jcp.loop_order = reduce_src ? loop_blr : loop_lbr;

        int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
        int nb_load = div_up(jcp.load_dim, jcp.load_block);

        reduce_blocking = jcp.reduce_dim;
        if (jcp.load_dim <= BIG_LOAD_DIM && spatial > SMALL_SPATIAL
                && spatial < BIG_SPATIAL) {
            reduce_blocking = nstl::min(jcp.reduce_dim, 80);
        } else if (spatial > SMALL_SPATIAL)
            reduce_blocking = nstl::min(jcp.reduce_dim, 512);
        else
            reduce_blocking = nstl::min(jcp.reduce_dim, 256);

        // Check input data cache aliasing.
        // For other ISAs these constants may need updating.
        // 64 * 1024 is chosen due to 1MB L2 16-way cache.
        // 7 is an empirical value, about half of the 16 ways,
        // so we leave about half of each set for other data (weights, dst).
        int way_size = (16 * 1024) / jcp.typesize_in;
        int max_hits = 7;
        if (jcp.bcast_dim * reduce_blocking > way_size * max_hits) {
            int nrb = reduce_blocking / simd_w;
            int sp = jcp.bcast_dim;
            int wl = way_size / simd_w;
            for (int start_off = 0; start_off < jcp.ur; start_off++) {
                for (int off = start_off, hits = 0; off < sp * nrb; off += wl) {
                    if (off % sp >= jcp.ur || ++hits < max_hits) continue;
                    int max_r_blocking = simd_w * nstl::max(1, (off + wl) / sp);
                    reduce_blocking
                            = nstl::min(reduce_blocking, max_r_blocking);
                    break;
                }
            }
        }

        if (reduce_blocking < jcp.reduce_dim) {
            if (jcp.prop_kind == backward_data)
                jcp.loop_order = reduce_src ? loop_lbr : loop_rlb;
            else
                jcp.loop_order = reduce_src ? loop_rbl : loop_rlb;
        }
        load_blocking = jcp.load_dim;

        /* Number of weight elements to be loaded for dest */
        int load_size = jcp.load_dim * jcp.reduce_dim;
        /* Number of elements to be broadcasted from src */
        auto bcast_size
                = (dim_t)jcp.mb * jcp.ngroups * jcp.bcast_dim * jcp.reduce_dim;

        /* 12 cores per CMG */
        if (jcp.nthr <= 12 && jcp.mb < jcp.nthr
                && nb_load * nb_bcast > jcp.nthr) {
            // Some heuristic here
            float calc_koef = 0.01, best_cost = FLT_MAX;
            int n_lgc = jcp.nthr;
            float ratio = (float)load_size / (float)bcast_size;
            int best_lgc = ratio > 1 ? n_lgc : 1;
            auto calc_job_cost = [&](int lb, int tg, float mem_k) {
                int bb_size = jcp.mb * div_up(nb_bcast, tg);
                float calc_size = (float)(bb_size * jcp.ur)
                        * (lb * jcp.load_block) * jcp.reduce_dim;
                float mem_size = (float)(bb_size * jcp.ur + lb * jcp.load_block)
                        * jcp.reduce_dim;
                return calc_koef * calc_size + mem_k * mem_size;
            };
            for (int lgc, ilgc = 0; ilgc < n_lgc; ilgc++) {
                lgc = ratio > 1 ? n_lgc - ilgc : ilgc + 1;
                int min_lb = nb_load / lgc;
                int max_lb = div_up(nb_load, lgc);
                int min_tg = jcp.nthr / lgc;
                int max_tg = div_up(jcp.nthr, lgc);
                // Some heuristic here
                float mem_koef = (max_tg == 1) ? 1.f : 1.3f;
                float job_cost = 0.;
                if (jcp.nthr % lgc < nb_load % lgc) {
                    job_cost = calc_job_cost(max_lb, min_tg, mem_koef);
                } else {
                    auto job_cost1 = calc_job_cost(max_lb, max_tg, mem_koef);
                    auto job_cost2 = calc_job_cost(min_lb, min_tg, mem_koef);
                    job_cost = nstl::max(job_cost1, job_cost2);
                }

                if (job_cost < best_cost) {
                    best_lgc = lgc;
                    best_cost = job_cost;
                }
            }
            jcp.load_grp_count = best_lgc;
            load_blocking
                    = div_up(nb_load, jcp.load_grp_count) * jcp.load_block;
        } else {
            jcp.load_grp_count
                    = div_up(jcp.nthr, jcp.mb * jcp.ngroups * nb_bcast);
            jcp.load_grp_count = best_divider(jcp.nthr, jcp.load_grp_count,
                    2 * jcp.load_grp_count, false);
        }
        if (jcp.bcast_dim <= 49 && jcp.mb <= jcp.nthr && jcp.load_dim > 512
                && jcp.load_dim / jcp.reduce_dim >= 4) {
            jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2);
            load_blocking = jcp.load_block;
        }

        bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast,
                                 div_up(jcp.nthr, jcp.load_grp_count))
                * jcp.bcast_block;
        bcast_blocking = nstl::min(jcp.bcast_dim, bcast_blocking);
        bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block);

        int space_for_bcast = (L2_capacity - /* kernel_size - */
                2 * jcp.load_block * reduce_blocking - jcp.ur * reduce_blocking
                - 3 * 1024);
        if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity) space_for_bcast /= 2;

        int bcast_in_cache
                = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking);
        bcast_blocking = nstl::min(
                bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block));

        load_blocking_max = load_blocking;
        bcast_blocking_max = bcast_blocking * 3 / 2;
        reduce_blocking_max = reduce_blocking;

        jcp.ur_tail = (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) % jcp.ur;

    } else if (jcp.prop_kind == backward_weights) { /* BWD weight */

        jcp.reduce_dim = jcp.is;

        jcp.reduce_block = best_divider(jcp.reduce_dim, 7, 16, true);
        if (jcp.reduce_dim % jcp.reduce_block != 0)
            jcp.reduce_block = best_divider(jcp.iw, 4, jcp.iw, false);
        if (jcp.reduce_block > 256) { jcp.reduce_block = 1; }

        jcp.load_dim = jcp.oc;
        jcp.load_block = jcp.oc_block;

        jcp.bcast_dim = jcp.ic;
        jcp.bcast_block = jcp.ic_block;

        if (jcp.reduce_block <= 19) {
            // if reduce_block is big then generated JIT code may be big
            // for small values of ur because reduce_loop_unroll = reduce_block
            jcp.ur = jcp.bcast_block / 2;
        } else {
            jcp.ur = jcp.bcast_block;
        }

        jcp.ur_tail = jcp.bcast_dim % jcp.bcast_block;
        jcp.reduce_loop_unroll = jcp.reduce_block;
        jcp.reduce_loop_bcast_step
                = jcp.typesize_in * jcp.reduce_loop_unroll * jcp.ic_block;
        jcp.reduce_loop_load_step
                = jcp.typesize_in * jcp.reduce_loop_unroll * jcp.oc_block;

        jcp.bcast_loop_output_step
                = jcp.oc_block * jcp.ic_block * jcp.typesize_out;
        jcp.bcast_loop_output_substep
                = jcp.oc_block * jcp.ur * jcp.typesize_out;
        jcp.bcast_loop_bcast_step = jcp.ic_block
                * utils::rnd_up(jcp.reduce_dim, jcp.reduce_block)
                * jcp.typesize_in;
        jcp.bcast_loop_bcast_substep = jcp.ur * jcp.typesize_in;

        jcp.load_loop_load_step = jcp.typesize_in * jcp.oc_block * jcp.os;
        jcp.load_loop_iter_step = jcp.oc_block;

        /* --- */
        balance(jcp);

        load_blocking = div_up(jcp.load_dim, jcp.load_block);
        load_blocking = best_divider(load_blocking, 16, load_blocking, false);
        load_blocking *= jcp.load_block;

        load_blocking_max = load_blocking;
        assert(jcp.load_dim % load_blocking == 0);

        int max_bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
        int min_bcast_blocking = 5;

        bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
        bcast_blocking = best_divider(
                bcast_blocking, min_bcast_blocking, max_bcast_blocking, false);
        bcast_blocking *= jcp.bcast_block;
        bcast_blocking_max = bcast_blocking;
        assert(jcp.bcast_dim % bcast_blocking == 0);

        // for reduction balance
        int max_reduce_blocking
                = nstl::min(L1_capacity / jcp.ur, jcp.reduce_dim);
        int min_reduce_blocking
                = nstl::min(L1_capacity / jcp.ur, nstl::max(jcp.iw, jcp.ih));
        reduce_blocking = best_divider(
                jcp.reduce_dim, min_reduce_blocking, max_reduce_blocking, true);
        reduce_blocking = nstl::max(
                rnd_dn(reduce_blocking, jcp.reduce_block), jcp.reduce_block);

        reduce_blocking_max = rnd_dn(reduce_blocking * 3 / 2, jcp.reduce_block);
    } else
        return status::unimplemented;

    assert(load_blocking);
    assert(load_blocking_max);
    assert(bcast_blocking);
    assert(bcast_blocking_max);
    assert(reduce_blocking);
    assert(reduce_blocking_max);

    assert(load_blocking % jcp.load_block == 0);
    assert(reduce_blocking % jcp.reduce_block == 0);
    assert(load_blocking_max % jcp.load_block == 0);
    assert(reduce_blocking_max % jcp.reduce_block == 0);
    assert(jcp.reduce_dim % jcp.reduce_block == 0);

    assert(jcp.bcast_block % jcp.ur == 0);

    jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
    jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
    jcp.nb_load_blocking = utils::div_up(load_blocking, jcp.load_block);
    jcp.nb_load_blocking_max = utils::div_up(load_blocking_max, jcp.load_block);
    jcp.nb_reduce_blocking = utils::div_up(reduce_blocking, jcp.reduce_block);
    jcp.nb_reduce_blocking_max
            = utils::div_up(reduce_blocking_max, jcp.reduce_block);

    jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
    jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
    jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);

    return status::success;
}

void jit_sve_512_1x1_conv_kernel::init_scratchpad(
        memory_tracking::registrar_t &scratchpad,
        const jit_1x1_conv_conf_t &jcp) {

    using namespace dnnl::impl::memory_tracking::names;

    // For nxc layout, bias is padded only for the bwd_wb direction, as bias
    // reduction kernels can't handle tails yet.
    if (jcp.with_bias && jcp.prop_kind != backward_data
            && (jcp.oc != jcp.oc_without_padding // blocked layout
                    || (jcp.prop_kind == backward_weights // nxc layout
                            && jcp.oc % jcp.oc_block != 0))) {

        const size_t nelems_padded_bias
                = jcp.ngroups * utils::rnd_up(jcp.oc, jcp.oc_block);
        scratchpad.book(
                key_conv_padded_bias, nelems_padded_bias, jcp.typesize_out);
    }

    if (jcp.prop_kind == backward_weights) {
        const size_t wei_size = (size_t)jcp.ngroups
                * rnd_up(jcp.oc, jcp.oc_block) * rnd_up(jcp.ic, jcp.ic_block);
        scratchpad.book(key_conv_wei_reduction, wei_size * (jcp.nthr_mb - 1),
                jcp.typesize_out);
    }
}

/* BWD W */
void jit_sve_512_1x1_conv_kernel::balance(jit_1x1_conv_conf_t &jcp) {
    int nthreads = jcp.nthr;
    // initialize jcp reduction threading properties
    jcp.nthr = jcp.nthr_mb = jcp.nthr_g = jcp.nthr_oc_b = jcp.nthr_ic_b = 1;
    if (nthreads < jcp.ngroups) {
        /* simplification... fortunately it doesn't hurt much */
        return;
    }
    // bcast_dim: src H*W, bcast_block: ur (fwd, bwd_d)
    const int nb_bcast
            = div_up(jcp.bcast_dim, jcp.bcast_block); // # of H*W loop
    // load_dim: dst channel, load_block: simd_w
    const int nb_load
            = div_up(jcp.load_dim, jcp.load_block); // # of dst channel loop
    // reduce_dim: src channel, reduce_block: simd_w
    const int nb_reduce
            = div_up(jcp.reduce_dim, jcp.reduce_block); // # of src channel loop

    jcp.nthr_g = jcp.ngroups;
    const int nthr = nthreads / jcp.nthr_g;

    auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
        /* calculate per thread memory cost (read/write). high level
        * optimizer tries to minimize memory consumption. few notes: (n1)
        * unclear why, but that essentially helps first convolution...
        *  (n2) assuming the reduction over minibatch is always there:
        *    - instead of 8 it should be 5 here (write ~= 2 read):
        *      kernel: temporal workspace 1 write
        *      reduction: 1 read from workspace and 1 write to the diff_wei
        *    - but experiments showed 8 works better than 5 or 6... */
        int bcast_koeff = 1;
        int load_koeff = 1;
        int output_koeff = 12;
        return 0
                + (size_t)bcast_koeff * div_up(jcp.mb * nb_reduce, nthr_mb)
                * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_bcast, nthr_ic_b)
                * jcp.ic_block * jcp.reduce_block / jcp.stride_h
                / jcp.stride_w /* (n1) */
                + (size_t)load_koeff * div_up(jcp.mb * nb_reduce, nthr_mb)
                * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_load, nthr_oc_b)
                * jcp.oc_block * jcp.reduce_block
                + (size_t)output_koeff /* (n2) */
                * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_load, nthr_oc_b)
                * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block * jcp.oc_block;
    };

    int nthr_mb = 1, nthr_oc_b = 1, nthr_ic_b = 1;
    auto best_mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);

    /* step 1: find the best thread distribution with lowest memory cost */
    const int nthr_mb_max = nstl::min(nthr, jcp.mb * nb_reduce);
    for (nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
        const int nthr_par = nthr / nthr_mb;
        const int nthr_oc_b_max = nstl::min(nthr_par, nb_load);
        for (nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
            nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, nb_bcast);
            auto mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
            if (mem_cost <= best_mem_cost) {
                best_mem_cost = mem_cost;
                jcp.nthr_mb = nthr_mb;
                jcp.nthr_oc_b = nthr_oc_b;
                jcp.nthr_ic_b = nthr_ic_b;
            }
        }

        const bool ready_for_async = utils::one_of(jcp.ver, ver_fma);
        if (!ready_for_async && !dnnl_thr_syncable()) {
            assert(nthr_mb == 1);
            break;
        }
    }
    if (jcp.nthr_mb > nthreads / 2 && jcp.nthr_mb < nthreads)
        jcp.nthr_mb = nstl::min(jcp.mb, nthreads);

    jcp.nthr = jcp.nthr_mb * jcp.nthr_g * jcp.nthr_oc_b * jcp.nthr_ic_b;
    assert(jcp.nthr <= nthreads);
}

} // namespace aarch64
} // namespace cpu
} // namespace impl
} // namespace dnnl