/*******************************************************************************
* Copyright 2021 Intel Corporation
* Copyright 2021 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
17 
18 #ifndef CPU_AARCH64_JIT_UNI_1X1_CONV_UTILS_HPP
19 #define CPU_AARCH64_JIT_UNI_1X1_CONV_UTILS_HPP
20 
#include <cassert>
#include <cfloat>

#include "common/convolution_pd.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory_tracking.hpp"
#include "common/nstl.hpp"
#include "common/primitive_iterator.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/aarch64/jit_generator.hpp"
#include "cpu/aarch64/jit_primitive_conf.hpp"
31 
32 namespace dnnl {
33 namespace impl {
34 namespace cpu {
35 namespace aarch64 {
36 
/* State of the "reduce to unit stride" (rtus) transformation that lets the
 * 1x1 convolution kernels (unit-stride only) handle strided problems via a
 * dense scratch copy. Filled in by rtus_prepare()/rtus_prepare_space_info(). */
struct reduce_to_unit_stride_t {
    convolution_desc_t conv_d_; // modified conv descriptor with unit strides
    bool reduce_src_; // true when the rtus transformation is enabled
    size_t space_per_thread_; // scratch elements each thread needs for its copy
};
42 
/* 1x1-kernel does not support non-unit strides so far, so the idea is:
 *  - for fwd or bwd_weights: to copy src to a scratch memory (with strides
 *    equal to 1) and then call the kernel
 *  - for bwd_data: reduce the problem to the one with unit stride by
 *    performing computations in a scratch memory (with strides equal to 1)
 *    and then copy the result to diff_src */
/* Decides whether rtus applies to the given 1x1 convolution and, if so,
 * rewrites the descriptors in place:
 *  - self->rtus_.reduce_src_ is set and a copy of *conv_d with unit strides
 *    and zero padding is stored in self->rtus_.conv_d_;
 *  - conv_d and src_d (in/out parameters) are redirected to the modified
 *    descriptors, so the caller transparently configures the kernel for the
 *    equivalent unit-stride problem.
 * Applicability per the checks below: 3D/4D only, at least one non-unit
 * stride, zero padding, and dst spatial dims matching src exactly through
 * the strides; src must be in a blocked (8c/16c) or channels-last layout. */
template <typename conv_pd_t>
inline void rtus_prepare(conv_pd_t *self, const convolution_desc_t *&conv_d,
        const memory_desc_t *&src_d, const memory_desc_t *dst_d) {
    const int ndims = src_d->ndims;

    bool rtus_applicable = utils::one_of(ndims, 3, 4);
    if (ndims == 3)
        // 1D case: a single spatial stride; s32 src is explicitly excluded
        rtus_applicable = rtus_applicable && conv_d->strides[0] != 1
                && conv_d->src_desc.data_type != data_type::s32;
    else
        rtus_applicable = rtus_applicable
                && (conv_d->strides[0] != 1 || conv_d->strides[1] != 1);
    for (int d = 2; d < ndims; ++d) {
        /* TODO: relax these conditions (by improving reducer) */
        rtus_applicable = rtus_applicable && conv_d->padding[0][d - 2] == 0
                && dst_d->dims[d] * conv_d->strides[d - 2] == src_d->dims[d];
    }
    if (!rtus_applicable) return;

    const auto dat_tag = ndims == 3
            ? memory_desc_wrapper(src_d).matches_one_of_tag(
                    format_tag::nCw8c, format_tag::nCw16c, format_tag::nwc)
            : memory_desc_wrapper(src_d).matches_one_of_tag(
                    format_tag::nChw8c, format_tag::nChw16c, format_tag::nhwc);
    if (dat_tag == format_tag::undef) return;

    const bool is_nspc
            = utils::one_of(dat_tag, format_tag::nwc, format_tag::nhwc);
    if (is_nspc && !mayiuse(sve_256)) return;

    // rtus is applicable, configure it.
    self->rtus_.reduce_src_ = true;
    conv_d = &(self->rtus_.conv_d_ = *conv_d);
    self->rtus_.conv_d_.strides[0] = 1;
    if (ndims == 4) self->rtus_.conv_d_.strides[1] = 1;
    utils::array_set(self->rtus_.conv_d_.padding[0], 0, 2);
    if (ndims == 4) utils::array_set(self->rtus_.conv_d_.padding[1], 0, 2);
    // The (diff_)src descriptor inherits the dst spatial dims (the reduced
    // image) while keeping the original channel count and data type.
    const int ic = src_d->dims[1];
    if (self->desc()->prop_kind == prop_kind::backward_data) {
        data_type_t data_type = self->rtus_.conv_d_.diff_src_desc.data_type;
        src_d = &(self->rtus_.conv_d_.diff_src_desc = *dst_d);
        self->rtus_.conv_d_.diff_src_desc.dims[1] = ic;
        self->rtus_.conv_d_.diff_src_desc.data_type = data_type;
        memory_desc_wrapper::compute_blocking(
                self->rtus_.conv_d_.diff_src_desc, dat_tag);
    } else {
        data_type_t data_type = self->rtus_.conv_d_.src_desc.data_type;
        src_d = &(self->rtus_.conv_d_.src_desc = *dst_d);
        self->rtus_.conv_d_.src_desc.dims[1] = ic;
        self->rtus_.conv_d_.src_desc.data_type = data_type;
        memory_desc_wrapper::compute_blocking(
                self->rtus_.conv_d_.src_desc, dat_tag);
    }
}
103 
104 template <typename conv_pd_t>
rtus_prepare_space_info(conv_pd_t * self,memory_tracking::registrar_t & scratchpad,int max_threads)105 inline void rtus_prepare_space_info(conv_pd_t *self,
106         memory_tracking::registrar_t &scratchpad, int max_threads) {
107     if (!self->rtus_.reduce_src_) return;
108     const auto &jcp = self->jcp_;
109     const bool is_nspc
110             = utils::one_of(jcp.src_tag, format_tag::nhwc, format_tag::nwc);
111 
112     const size_t factor = utils::pick_by_prop_kind(self->desc()->prop_kind,
113             jcp.nb_reduce, jcp.nb_load_blocking_max, jcp.nb_bcast_blocking);
114     size_t typesize
115             = types::data_type_size(self->invariant_src_md()->data_type);
116 
117     self->rtus_.space_per_thread_
118             = is_nspc ? jcp.is * jcp.ic : factor * jcp.is * jcp.ic_block;
119     scratchpad.book(memory_tracking::names::key_conv_rtus_space,
120             max_threads * self->rtus_.space_per_thread_, typesize);
121 }
122 
/* JIT copy kernel implementing the rtus data movement: shuttles data between
 * the strided source image and the dense unit-stride workspace.
 *  - src_to_ws_ == true  (fwd / bwd_weights): gather src pixels into ws;
 *  - src_to_ws_ == false (bwd_data): scatter ws back into (diff_)src,
 *    zero-filling the positions a strided convolution never writes. */
template <cpu_isa_t isa>
struct rtus_driver_t : public jit_generator {

    /* Argument structure passed to the generated kernel via abi_param1. */
    struct call_params_t {
        const void *ws; /* reduced image (w/ strides = 1) */
        const void *src; /* source image (w/ non-unit strides) */
        size_t icb; /* input-channel block count to process */
        size_t os; /* output-space size; scaled to bytes in generate() */
        size_t iw_start; /* starting width position within the current row */
    };

    DECLARE_CPU_JIT_AUX_FUNCTIONS(rtus_driver_t)

    // registers holding the kernel arguments (loaded in generate())
    Xbyak_aarch64::XReg reg_ws = x5;
    Xbyak_aarch64::XReg reg_src = x6;
    Xbyak_aarch64::XReg reg_icb = x7;
    Xbyak_aarch64::XReg reg_os = x8;
    Xbyak_aarch64::XReg reg_iw_start = x9;

    // loop-carried state of loop_is()
    Xbyak_aarch64::XReg reg_cur_os = x10;
    Xbyak_aarch64::XReg reg_cur_iw = x11;
    Xbyak_aarch64::XReg reg_cur_src = x12;
    Xbyak_aarch64::XReg reg_cur_src_fin = reg_cur_iw; /* just reuse */

    Xbyak_aarch64::PReg tail_mask = p1;

    // nspc section (the nspc path itself is not implemented, see below)
    Xbyak_aarch64::XReg reg_cur_icb = x13;
    Xbyak_aarch64::XReg reg_tail_mask = x14;
    Xbyak_aarch64::XReg reg_icb_remainder = x15;
    Xbyak_aarch64::XReg reg_ws_copy = x16;
    Xbyak_aarch64::XReg reg_tmp_imm = x17;
    // NOTE(review): x18 is the platform register on some AArch64 ABIs
    // (e.g. Apple, Windows) -- presumably fine for the targets this code
    // supports, but confirm before porting.
    Xbyak_aarch64::XReg reg_tmp = x18;

    Xbyak_aarch64::ZReg reg_zero = Xbyak_aarch64::ZReg(0); // all-zero vector
    Xbyak_aarch64::ZReg reg_v = Xbyak_aarch64::ZReg(1); // data shuttle

    int iw_, stride_w_; // source row width and width stride
    int src_step_h_, src_step_icb_, ws_step_icb_, vlen_, vlen_shift_;
    bool src_to_ws_; // copy direction, see class comment
    size_t typesize_; // bytes per element of the source data type
    int ic_, ic_tail_; // channel count and remainder vs the SIMD width
    bool is_nspc_; // channels-last layout requested

    /* Stores the geometry and precomputes vlen_ (vector bytes), vlen_shift_
     * (log2 of the per-step size) and ic_tail_; code emission itself happens
     * later in generate(). */
    rtus_driver_t(int iw, int stride_w, int src_step_h, int src_step_icb,
            int ws_step_icb, bool src_to_ws, size_t typesize, int ic,
            bool is_nspc = false)
        : iw_(iw)
        , stride_w_(stride_w)
        , src_step_h_(src_step_h)
        , src_step_icb_(src_step_icb)
        , ws_step_icb_(ws_step_icb)
        , src_to_ws_(src_to_ws)
        , typesize_(typesize)
        , ic_(ic)
        , is_nspc_(is_nspc) {
        using namespace Xbyak_aarch64;

        assert(ic_ > 0);

        // Pick the vector register type: only sve_512 (with 4-byte elements
        // in the blocked case) is supported; everything else asserts.
        auto Vmm = [=](int idx, size_t typesize) {
            ZReg res = ZReg(idx);
            if (is_nspc_) {
                switch (isa) {
                    case sve_512: res = ZReg(idx); break;
                    default: assert(!"Not supported isa"); res = ZReg(idx);
                }
                return res;
            }
            switch (isa) {
                case sve_512:
                    switch (typesize) {
                        case 4: res = ZReg(idx); break;
                        default:
                            assert(!"Not supported typesize");
                            res = ZReg(idx);
                    }
            }
            return res;
        };

        reg_zero = Vmm(0, typesize);
        reg_v = Vmm(1, typesize);

        vlen_ = reg_v.getBit() / 8; // vector length in bytes
        vlen_shift_ = 0;

        // log2 of the step size: element size for nspc, vector bytes for
        // blocked layouts (used by generate() to turn counts into bytes)
        int tvlen = is_nspc_ ? typesize_ : vlen_;
        while (tvlen > 1) {
            tvlen /= 2;
            vlen_shift_++;
        }

        const int simd_w = vlen_ / sizeof(float);
        ic_tail_ = ic_ % simd_w;
    }

    /* Emits the copy loop over the output space for one channel block
     * (blocked layouts). Each iteration moves one vector between ws and the
     * strided src; assumes reg_os holds a byte count that is a multiple of
     * vlen_ -- presumably guaranteed by the jcp setup, TODO confirm. */
    void loop_is() {
        using namespace Xbyak_aarch64;

        mov(reg_cur_src, reg_src);
        mov(reg_cur_iw, reg_iw_start);
        mov(reg_cur_os, reg_os);

        Label is_loop;
        L(is_loop);

        if (src_to_ws_) {
            // gather: strided src -> dense ws
            ldr(reg_v, ptr(reg_cur_src));
            str(reg_v, ptr(reg_ws));
        } else {
            // scatter: dense ws -> strided src, zeroing the skipped pixels
            ldr(reg_v, ptr(reg_ws));
            str(reg_v, ptr(reg_cur_src));
            for (int w = 1; w < stride_w_; ++w) {
                add_imm(reg_tmp, reg_cur_src, w * vlen_, reg_tmp_imm);
                str(reg_zero, ptr(reg_tmp));
            }
        }

        add_imm(reg_ws, reg_ws, vlen_, reg_tmp_imm);
        add_imm(reg_cur_src, reg_cur_src, stride_w_ * vlen_, reg_tmp_imm);

        // for 1d or stride_h=1 convolutions the loop over h should be skipped
        if (!(src_step_icb_ == iw_ || src_step_h_ == iw_)) {
            Label skip_h_step;
            add_imm(reg_cur_iw, reg_cur_iw, stride_w_, reg_tmp_imm);
            cmp(reg_cur_iw, iw_);
            b(LT, skip_h_step);

            if (src_to_ws_) {
                // jump over the source rows skipped by stride_h
                add_imm(reg_cur_src, reg_cur_src, (src_step_h_ - iw_) * vlen_,
                        reg_tmp_imm);
            } else {
                // zero-fill the source rows skipped by stride_h
                mov(reg_cur_src_fin, reg_cur_src);
                add_imm(reg_cur_src_fin, reg_cur_src_fin,
                        (src_step_h_ - iw_) * vlen_, reg_tmp_imm);
                Label ih_loop;
                L(ih_loop);

                for (int w = 0; w < stride_w_; ++w) {
                    add_imm(reg_tmp, reg_cur_src, w * vlen_, reg_tmp_imm);
                    str(reg_zero, ptr(reg_tmp));
                }

                add_imm(reg_cur_src, reg_cur_src, stride_w_ * vlen_,
                        reg_tmp_imm);
                cmp(reg_cur_src, reg_cur_src_fin);
                b(LT, ih_loop);
            }
            mov(reg_cur_iw, 0);
            L(skip_h_step);
        }

        subs_imm(reg_cur_os, reg_cur_os, vlen_, reg_tmp_imm);
        b(NE, is_loop);

        /* restore dst */
        sub(reg_ws, reg_ws, reg_os);
    }

    /* Channels-last (nspc) variant: not implemented on AArch64 yet;
     * generate() asserts before this path would be taken. */
    void loop_is_nspc() {}

    /* Top-level emission: loads call_params_t, prepares the zero vector for
     * the scatter direction, then loops over input-channel blocks invoking
     * loop_is() for each. */
    void generate() override {
        using namespace Xbyak_aarch64;
        assert(isa == sve_512);

        preamble();
#define READ_PARAM(what) \
    ldr(reg_##what, \
            ptr(abi_param1, \
                    static_cast<int32_t>(offsetof(call_params_t, what))))
        READ_PARAM(src);
        READ_PARAM(icb);
        READ_PARAM(os);
        READ_PARAM(iw_start);
        READ_PARAM(ws);
#undef READ_PARAM

        if (!src_to_ws_) {
            // scatter direction needs an all-zero vector to fill the gaps
            switch (reg_zero.getBit() / 8) {
                case 64 /*ZReg*/: {
                    Xbyak_aarch64::ZRegS zreg_s(reg_zero.getIdx());
                    fmov(zreg_s); // zero clear
                    break;
                }
                default: assert(!"rtus kernel failure");
            }
        }
        if (is_nspc_) {
            assert(!"loop_is_nspc error"); // nspc path not implemented
        } else {
            lsl(reg_os, reg_os, vlen_shift_); // element count -> bytes

            Label icb_loop;
            L(icb_loop);

            loop_is();

            add_imm(reg_ws, reg_ws, ws_step_icb_ * vlen_, reg_tmp_imm);
            add_imm(reg_src, reg_src, src_step_icb_ * vlen_, reg_tmp_imm);

            subs_imm(reg_icb, reg_icb, vlen_ / typesize_, reg_tmp_imm);
            b(NE, icb_loop);
        }

        postamble();
    }
};
331 
332 template <cpu_isa_t isa, typename conv_t>
init_rtus_driver(conv_t * self)333 inline status_t init_rtus_driver(conv_t *self) {
334     const auto &conf = *self->pd();
335     if (!conf.rtus_.reduce_src_) return status::success;
336 
337     const auto &cd = *conf.desc();
338     const int ndims = conf.ndims();
339     const int stride_h = (conf.ndims() == 3) ? 1 : cd.strides[0];
340     const int stride_w = cd.strides[ndims - 3];
341 
342     const bool is_bwd_data = cd.prop_kind == prop_kind::backward_data;
343     const auto &src_d = is_bwd_data ? *conf.diff_src_md() : *conf.src_md();
344 
345     const int ih = ndims == 3 ? 1 : src_d.dims[2];
346     const int iw = src_d.dims[ndims - 1];
347     const int ic = src_d.dims[1];
348 
349     const auto src_tag = memory_desc_wrapper(src_d).matches_one_of_tag(
350             format_tag::nhwc, format_tag::nwc);
351     const bool is_nspc = src_tag != format_tag::undef;
352     const int src_step_h = stride_h * iw;
353     const int src_step_icb = !is_nspc ? ih * iw : 1;
354     const int ws_step_icb = !is_nspc ? conf.jcp_.is : 1;
355     const bool src_to_ws = !is_bwd_data;
356     const size_t typesize
357             = types::data_type_size(self->pd()->invariant_src_md()->data_type);
358 
359     CHECK(safe_ptr_assign(self->rtus_driver_,
360             new rtus_driver_t<isa>(iw, stride_w, src_step_h, src_step_icb,
361                     ws_step_icb, src_to_ws, typesize, ic, is_nspc)));
362     return self->rtus_driver_->create_kernel();
363 }
364 
best_divider(int value,int min_divider,int max_divider,bool find_max,int step=1)365 inline int best_divider(int value, int min_divider, int max_divider,
366         bool find_max, int step = 1) {
367     using namespace dnnl::impl::utils;
368     max_divider = nstl::max(1, nstl::min(max_divider, value));
369     min_divider = nstl::max(1, nstl::min(min_divider, max_divider));
370 
371     auto loss_ratio = [](int total, int chunk) {
372         return float(rnd_up(total, chunk) - total) / rnd_up(total, chunk);
373     };
374 
375     float min_loss = FLT_MAX;
376     int x_divider = max_divider;
377     for (int divider = max_divider; divider >= min_divider; divider -= step) {
378         const float loss = loss_ratio(value, divider);
379         if ((find_max && loss < min_loss) || (!find_max && loss <= min_loss)) {
380             min_loss = loss;
381             x_divider = divider;
382         }
383     }
384     return x_divider;
385 }
386 
387 typedef jit_1x1_conv_conf_t jcp_t;
388 
is_bcast_layout_nxc(const jcp_t & jcp)389 inline bool is_bcast_layout_nxc(const jcp_t &jcp) {
390     switch (jcp.prop_kind) {
391         case prop_kind::forward_training:
392         case prop_kind::forward_inference:
393         case prop_kind::backward_weights:
394             return utils::one_of(jcp.src_tag, format_tag::ndhwc,
395                     format_tag::nhwc, format_tag::nwc);
396         case prop_kind::backward_data:
397             return utils::one_of(jcp.dst_tag, format_tag::ndhwc,
398                     format_tag::nhwc, format_tag::nwc);
399         default: assert(!"invalid prop_kind"); return false;
400     }
401 }
402 
is_load_layout_nxc(const jcp_t & jcp)403 inline bool is_load_layout_nxc(const jcp_t &jcp) {
404     return jcp.prop_kind == prop_kind::backward_weights
405             && utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc,
406                     format_tag::nwc);
407 }
408 
is_out_layout_nxc(const jcp_t & jcp)409 inline bool is_out_layout_nxc(const jcp_t &jcp) {
410     switch (jcp.prop_kind) {
411         case prop_kind::forward_training:
412         case prop_kind::forward_inference:
413             return utils::one_of(jcp.dst_tag, format_tag::ndhwc,
414                     format_tag::nhwc, format_tag::nwc);
415         case prop_kind::backward_data:
416             return utils::one_of(jcp.src_tag, format_tag::ndhwc,
417                     format_tag::nhwc, format_tag::nwc);
418         case prop_kind::backward_weights: return false;
419         default: assert(!"invalid prop_kind"); return false;
420     }
421 }
422 
get_bcast_u_offset(const jcp_t & jcp)423 inline size_t get_bcast_u_offset(const jcp_t &jcp) {
424     return is_bcast_layout_nxc(jcp) ? jcp.ic : jcp.ic_block;
425 }
426 
get_bcast_j_offset(const jcp_t & jcp)427 inline size_t get_bcast_j_offset(const jcp_t &jcp) {
428     return is_bcast_layout_nxc(jcp) ? jcp.reduce_dim : jcp.reduce_loop_unroll;
429 }
430 
get_bcast_offset(const jcp_t & jcp,int u,int j)431 inline size_t get_bcast_offset(const jcp_t &jcp, int u, int j) {
432     size_t offset;
433     if (utils::one_of(jcp.prop_kind, prop_kind::forward_training,
434                 prop_kind::forward_inference, prop_kind::backward_data)) {
435         assert(jcp.reduce_loop_unroll == jcp.reduce_block);
436         if (is_bcast_layout_nxc(jcp) || u != jcp.reduce_loop_unroll) {
437             offset = j * get_bcast_j_offset(jcp) + u;
438         } else {
439             offset = (jcp.bcast_dim + j) * get_bcast_j_offset(jcp);
440         }
441     } else {
442         offset = u * get_bcast_u_offset(jcp) + j;
443     }
444     return sizeof(float) * offset;
445 }
446 
get_load_u_offset(const jcp_t & jcp)447 inline size_t get_load_u_offset(const jcp_t &jcp) {
448     return is_load_layout_nxc(jcp) ? jcp.oc : jcp.oc_block;
449 }
450 
get_load_i_offset(const jcp_t & jcp)451 inline size_t get_load_i_offset(const jcp_t &jcp) {
452     return is_load_layout_nxc(jcp) ? jcp.oc_block : jcp.os;
453 }
454 
get_load_bwd_w_offset(const jcp_t & jcp,int i,int u0)455 inline size_t get_load_bwd_w_offset(const jcp_t &jcp, int i, int u0) {
456     if (is_load_layout_nxc(jcp)) {
457         return i * get_load_i_offset(jcp) + u0 * get_load_u_offset(jcp);
458     } else {
459         return (i * get_load_i_offset(jcp) + u0) * get_load_u_offset(jcp);
460     }
461 }
462 
get_output_i_offset(const jcp_t & jcp)463 inline size_t get_output_i_offset(const jcp_t &jcp) {
464     if (is_out_layout_nxc(jcp)) {
465         return jcp.load_block;
466     } else {
467         return (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) * jcp.load_block;
468     }
469 }
470 
get_output_j_offset(const jcp_t & jcp)471 inline size_t get_output_j_offset(const jcp_t &jcp) {
472     return is_out_layout_nxc(jcp) ? jcp.load_dim : jcp.load_block;
473 }
474 
get_load_loop_output_fwd_offset(const jcp_t & jcp,int load_loop_blk)475 inline size_t get_load_loop_output_fwd_offset(
476         const jcp_t &jcp, int load_loop_blk) {
477     size_t offset = load_loop_blk * jcp.oc_block * sizeof(float);
478     if (!is_out_layout_nxc(jcp)) {
479         offset *= jcp.with_dw_conv ? jcp.ow : jcp.os;
480     }
481     return offset;
482 }
483 
get_load_loop_output_bwd_d_offset(const jcp_t & jcp,int load_loop_blk)484 inline size_t get_load_loop_output_bwd_d_offset(
485         const jcp_t &jcp, int load_loop_blk) {
486     size_t offset = load_loop_blk * jcp.ic_block * sizeof(float);
487     if (!is_out_layout_nxc(jcp)) { offset *= jcp.os; }
488     return offset;
489 }
490 
491 } // namespace aarch64
492 } // namespace cpu
493 } // namespace impl
494 } // namespace dnnl
495 
496 #endif
497