/*******************************************************************************
* Copyright 2021 Intel Corporation
* Copyright 2021 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_AARCH64_JIT_UNI_1X1_CONV_UTILS_HPP
#define CPU_AARCH64_JIT_UNI_1X1_CONV_UTILS_HPP

#include <cfloat> // FLT_MAX, used by best_divider() below

#include "common/convolution_pd.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory_tracking.hpp"
#include "common/nstl.hpp"
#include "common/primitive_iterator.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/aarch64/jit_generator.hpp"
#include "cpu/aarch64/jit_primitive_conf.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace aarch64 {

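/* Parameters of the "reduce to unit stride" (rtus) transformation: the copy
 * of the convolution descriptor rewritten with unit strides, a flag telling
 * whether the transformation is enabled, and the per-thread scratch size (in
 * elements) booked for the densified source. */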
struct reduce_to_unit_stride_t {
    convolution_desc_t conv_d_;
    bool reduce_src_;
    size_t space_per_thread_;
};

/* 1x1-kernel does not support non-unit strides so far, so the idea is:
 *  - for fwd or bwd_weights: to copy src to a scratch memory (with strides
 *    equal to 1) and then call the kernel
 *  - for bwd_data: reduce the problem to the one with unit stride by
 *    performing computations in a scratch memory (with strides equal to 1)
 *    and then copy the result to diff_src */
template <typename conv_pd_t>
inline void rtus_prepare(conv_pd_t *self, const convolution_desc_t *&conv_d,
        const memory_desc_t *&src_d, const memory_desc_t *dst_d) {
    const int ndims = src_d->ndims;

    bool rtus_applicable = utils::one_of(ndims, 3, 4);
    if (ndims == 3)
        rtus_applicable = rtus_applicable && conv_d->strides[0] != 1
                && conv_d->src_desc.data_type != data_type::s32;
    else
        rtus_applicable = rtus_applicable
                && (conv_d->strides[0] != 1 || conv_d->strides[1] != 1);
    for (int d = 2; d < ndims; ++d) {
        /* TODO: relax these conditions (by improving reducer) */
        rtus_applicable = rtus_applicable && conv_d->padding[0][d - 2] == 0
                && dst_d->dims[d] * conv_d->strides[d - 2] == src_d->dims[d];
    }
    if (!rtus_applicable) return;

    const auto dat_tag = ndims == 3
            ? memory_desc_wrapper(src_d).matches_one_of_tag(
                    format_tag::nCw8c, format_tag::nCw16c, format_tag::nwc)
            : memory_desc_wrapper(src_d).matches_one_of_tag(
                    format_tag::nChw8c, format_tag::nChw16c, format_tag::nhwc);
    if (dat_tag == format_tag::undef) return;

    const bool is_nspc
            = utils::one_of(dat_tag, format_tag::nwc, format_tag::nhwc);
    if (is_nspc && !mayiuse(sve_256)) return;

    // rtus is applicable, configure it.
    self->rtus_.reduce_src_ = true;
    conv_d = &(self->rtus_.conv_d_ = *conv_d);
    self->rtus_.conv_d_.strides[0] = 1;
    if (ndims == 4) self->rtus_.conv_d_.strides[1] = 1;
    utils::array_set(self->rtus_.conv_d_.padding[0], 0, 2);
    if (ndims == 4) utils::array_set(self->rtus_.conv_d_.padding[1], 0, 2);
    const int ic = src_d->dims[1];
    if (self->desc()->prop_kind == prop_kind::backward_data) {
        data_type_t data_type = self->rtus_.conv_d_.diff_src_desc.data_type;
        src_d = &(self->rtus_.conv_d_.diff_src_desc = *dst_d);
        self->rtus_.conv_d_.diff_src_desc.dims[1] = ic;
        self->rtus_.conv_d_.diff_src_desc.data_type = data_type;
        memory_desc_wrapper::compute_blocking(
                self->rtus_.conv_d_.diff_src_desc, dat_tag);
    } else {
        data_type_t data_type = self->rtus_.conv_d_.src_desc.data_type;
        src_d = &(self->rtus_.conv_d_.src_desc = *dst_d);
        self->rtus_.conv_d_.src_desc.dims[1] = ic;
        self->rtus_.conv_d_.src_desc.data_type = data_type;
        memory_desc_wrapper::compute_blocking(
                self->rtus_.conv_d_.src_desc, dat_tag);
    }
}

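/* Books the per-thread scratchpad region that holds the densified
 * (unit-stride) source: jcp.is * jcp.ic elements per thread for nspc sources,
 * factor * jcp.is * jcp.ic_block elements otherwise, replicated over
 * max_threads. */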
template <typename conv_pd_t>
inline void rtus_prepare_space_info(conv_pd_t *self,
        memory_tracking::registrar_t &scratchpad, int max_threads) {
    if (!self->rtus_.reduce_src_) return;
    const auto &jcp = self->jcp_;
    const bool is_nspc
            = utils::one_of(jcp.src_tag, format_tag::nhwc, format_tag::nwc);

    const size_t factor = utils::pick_by_prop_kind(self->desc()->prop_kind,
            jcp.nb_reduce, jcp.nb_load_blocking_max, jcp.nb_bcast_blocking);
    size_t typesize
            = types::data_type_size(self->invariant_src_md()->data_type);

    self->rtus_.space_per_thread_
            = is_nspc ? jcp.is * jcp.ic : factor * jcp.is * jcp.ic_block;
    scratchpad.book(memory_tracking::names::key_conv_rtus_space,
            max_threads * self->rtus_.space_per_thread_, typesize);
}

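/* JIT kernel that performs the actual reduction to unit stride. Depending on
 * src_to_ws_ it either gathers the strided source rows into the dense
 * workspace (fwd / bwd_weights) or scatters the workspace back into diff_src,
 * zero-filling the elements that the strided convolution skips (bwd_data). */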
template <cpu_isa_t isa>
struct rtus_driver_t : public jit_generator {

    struct call_params_t {
        const void *ws; /* reduced image (w/ strides = 1) */
        const void *src; /* source image (w/ non-unit strides) */
        size_t icb;
        size_t os;
        size_t iw_start;
    };

    DECLARE_CPU_JIT_AUX_FUNCTIONS(rtus_driver_t)

    Xbyak_aarch64::XReg reg_ws = x5;
    Xbyak_aarch64::XReg reg_src = x6;
    Xbyak_aarch64::XReg reg_icb = x7;
    Xbyak_aarch64::XReg reg_os = x8;
    Xbyak_aarch64::XReg reg_iw_start = x9;

    Xbyak_aarch64::XReg reg_cur_os = x10;
    Xbyak_aarch64::XReg reg_cur_iw = x11;
    Xbyak_aarch64::XReg reg_cur_src = x12;
    Xbyak_aarch64::XReg reg_cur_src_fin = reg_cur_iw; /* just reuse */

    Xbyak_aarch64::PReg tail_mask = p1;

    // nspc section
    Xbyak_aarch64::XReg reg_cur_icb = x13;
    Xbyak_aarch64::XReg reg_tail_mask = x14;
    Xbyak_aarch64::XReg reg_icb_remainder = x15;
    Xbyak_aarch64::XReg reg_ws_copy = x16;
    Xbyak_aarch64::XReg reg_tmp_imm = x17;
    Xbyak_aarch64::XReg reg_tmp = x18;

    Xbyak_aarch64::ZReg reg_zero = Xbyak_aarch64::ZReg(0);
    Xbyak_aarch64::ZReg reg_v = Xbyak_aarch64::ZReg(1);

    int iw_, stride_w_;
    int src_step_h_, src_step_icb_, ws_step_icb_, vlen_, vlen_shift_;
    bool src_to_ws_;
    size_t typesize_;
    int ic_, ic_tail_;
    bool is_nspc_;

    rtus_driver_t(int iw, int stride_w, int src_step_h, int src_step_icb,
            int ws_step_icb, bool src_to_ws, size_t typesize, int ic,
            bool is_nspc = false)
        : iw_(iw)
        , stride_w_(stride_w)
        , src_step_h_(src_step_h)
        , src_step_icb_(src_step_icb)
        , ws_step_icb_(ws_step_icb)
        , src_to_ws_(src_to_ws)
        , typesize_(typesize)
        , ic_(ic)
        , is_nspc_(is_nspc) {
        using namespace Xbyak_aarch64;

        assert(ic_ > 0);

        auto Vmm = [=](int idx, size_t typesize) {
            ZReg res = ZReg(idx);
            if (is_nspc_) {
                switch (isa) {
                    case sve_512: res = ZReg(idx); break;
                    default: assert(!"Not supported isa"); res = ZReg(idx);
                }
                return res;
            }
            switch (isa) {
                case sve_512:
                    switch (typesize) {
                        case 4: res = ZReg(idx); break;
                        default:
                            assert(!"Not supported typesize");
                            res = ZReg(idx);
                    }
            }
            return res;
        };

        reg_zero = Vmm(0, typesize);
        reg_v = Vmm(1, typesize);

        vlen_ = reg_v.getBit() / 8;
        vlen_shift_ = 0;

        int tvlen = is_nspc_ ? typesize_ : vlen_;
        while (tvlen > 1) {
            tvlen /= 2;
            vlen_shift_++;
        }

        const int simd_w = vlen_ / sizeof(float);
        ic_tail_ = ic_ % simd_w;
    }

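    /* Copies (or restores) one ic block: walks the output spatial points,
     * moving one vector register per point between the dense workspace and
     * the strided source, and handles row boundaries by skipping (or, for the
     * ws-to-src direction, zero-filling) the source rows that the strided
     * convolution does not touch. */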
    void loop_is() {
        using namespace Xbyak_aarch64;

        mov(reg_cur_src, reg_src);
        mov(reg_cur_iw, reg_iw_start);
        mov(reg_cur_os, reg_os);

        Label is_loop;
        L(is_loop);

        if (src_to_ws_) {
            ldr(reg_v, ptr(reg_cur_src));
            str(reg_v, ptr(reg_ws));
        } else {
            ldr(reg_v, ptr(reg_ws));
            str(reg_v, ptr(reg_cur_src));
            for (int w = 1; w < stride_w_; ++w) {
                add_imm(reg_tmp, reg_cur_src, w * vlen_, reg_tmp_imm);
                str(reg_zero, ptr(reg_tmp));
            }
        }

        add_imm(reg_ws, reg_ws, vlen_, reg_tmp_imm);
        add_imm(reg_cur_src, reg_cur_src, stride_w_ * vlen_, reg_tmp_imm);

        // for 1d or stride_h=1 convolutions the loop over h should be skipped
        if (!(src_step_icb_ == iw_ || src_step_h_ == iw_)) {
            Label skip_h_step;
            add_imm(reg_cur_iw, reg_cur_iw, stride_w_, reg_tmp_imm);
            cmp(reg_cur_iw, iw_);
            b(LT, skip_h_step);

            if (src_to_ws_) {
                add_imm(reg_cur_src, reg_cur_src, (src_step_h_ - iw_) * vlen_,
                        reg_tmp_imm);
            } else {
                mov(reg_cur_src_fin, reg_cur_src);
                add_imm(reg_cur_src_fin, reg_cur_src_fin,
                        (src_step_h_ - iw_) * vlen_, reg_tmp_imm);
                Label ih_loop;
                L(ih_loop);

                for (int w = 0; w < stride_w_; ++w) {
                    add_imm(reg_tmp, reg_cur_src, w * vlen_, reg_tmp_imm);
                    str(reg_zero, ptr(reg_tmp));
                }

                add_imm(reg_cur_src, reg_cur_src, stride_w_ * vlen_,
                        reg_tmp_imm);
                cmp(reg_cur_src, reg_cur_src_fin);
                b(LT, ih_loop);
            }
            mov(reg_cur_iw, 0);
            L(skip_h_step);
        }

        subs_imm(reg_cur_os, reg_cur_os, vlen_, reg_tmp_imm);
        b(NE, is_loop);

        /* restore dst */
        sub(reg_ws, reg_ws, reg_os);
    }

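    /* Placeholder for the channels-last (nspc) copy loop: no code is emitted
     * here, and generate() below asserts if is_nspc_ is set. */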
    void loop_is_nspc() {}

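    /* Kernel entry point: reads call_params_t fields via abi_param1,
     * zero-clears reg_zero for the ws-to-src direction, and loops over ic
     * blocks calling loop_is() for each of them. */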
    void generate() override {
        using namespace Xbyak_aarch64;
        assert(isa == sve_512);

        preamble();
#define READ_PARAM(what) \
    ldr(reg_##what, \
            ptr(abi_param1, \
                    static_cast<int32_t>(offsetof(call_params_t, what))))
        READ_PARAM(src);
        READ_PARAM(icb);
        READ_PARAM(os);
        READ_PARAM(iw_start);
        READ_PARAM(ws);
#undef READ_PARAM

        if (!src_to_ws_) {
            switch (reg_zero.getBit() / 8) {
                case 64 /*ZReg*/: {
                    Xbyak_aarch64::ZRegS zreg_s(reg_zero.getIdx());
                    fmov(zreg_s); // zero clear
                    break;
                }
                default: assert(!"rtus kernel failure");
            }
        }
        if (is_nspc_) {
            assert(!"loop_is_nspc error");
        } else {
            lsl(reg_os, reg_os, vlen_shift_);

            Label icb_loop;
            L(icb_loop);

            loop_is();

            add_imm(reg_ws, reg_ws, ws_step_icb_ * vlen_, reg_tmp_imm);
            add_imm(reg_src, reg_src, src_step_icb_ * vlen_, reg_tmp_imm);

            subs_imm(reg_icb, reg_icb, vlen_ / typesize_, reg_tmp_imm);
            b(NE, icb_loop);
        }

        postamble();
    }
};

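/* Creates the rtus driver for a prepared convolution pd: derives the copy
 * parameters (strides, per-row and per-block steps, layout, type size) from
 * the descriptor and the source memory descriptor, then builds the JIT
 * kernel. */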
template <cpu_isa_t isa, typename conv_t>
inline status_t init_rtus_driver(conv_t *self) {
    const auto &conf = *self->pd();
    if (!conf.rtus_.reduce_src_) return status::success;

    const auto &cd = *conf.desc();
    const int ndims = conf.ndims();
    const int stride_h = (conf.ndims() == 3) ? 1 : cd.strides[0];
    const int stride_w = cd.strides[ndims - 3];

    const bool is_bwd_data = cd.prop_kind == prop_kind::backward_data;
    const auto &src_d = is_bwd_data ? *conf.diff_src_md() : *conf.src_md();

    const int ih = ndims == 3 ? 1 : src_d.dims[2];
    const int iw = src_d.dims[ndims - 1];
    const int ic = src_d.dims[1];

    const auto src_tag = memory_desc_wrapper(src_d).matches_one_of_tag(
            format_tag::nhwc, format_tag::nwc);
    const bool is_nspc = src_tag != format_tag::undef;
    const int src_step_h = stride_h * iw;
    const int src_step_icb = !is_nspc ? ih * iw : 1;
    const int ws_step_icb = !is_nspc ? conf.jcp_.is : 1;
    const bool src_to_ws = !is_bwd_data;
    const size_t typesize
            = types::data_type_size(self->pd()->invariant_src_md()->data_type);

    CHECK(safe_ptr_assign(self->rtus_driver_,
            new rtus_driver_t<isa>(iw, stride_w, src_step_h, src_step_icb,
                    ws_step_icb, src_to_ws, typesize, ic, is_nspc)));
    return self->rtus_driver_->create_kernel();
}

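/* Returns the divider in [min_divider, max_divider] that minimizes the
 * relative padding loss (rnd_up(value, divider) - value) / rnd_up(value,
 * divider). On ties, find_max = true prefers the largest such divider and
 * find_max = false the smallest. E.g. best_divider(24, 2, 16, true) == 12,
 * since 24 is divisible by 12 (zero loss). */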
inline int best_divider(int value, int min_divider, int max_divider,
        bool find_max, int step = 1) {
    using namespace dnnl::impl::utils;
    max_divider = nstl::max(1, nstl::min(max_divider, value));
    min_divider = nstl::max(1, nstl::min(min_divider, max_divider));

    auto loss_ratio = [](int total, int chunk) {
        return float(rnd_up(total, chunk) - total) / rnd_up(total, chunk);
    };

    float min_loss = FLT_MAX;
    int x_divider = max_divider;
    for (int divider = max_divider; divider >= min_divider; divider -= step) {
        const float loss = loss_ratio(value, divider);
        if ((find_max && loss < min_loss) || (!find_max && loss <= min_loss)) {
            min_loss = loss;
            x_divider = divider;
        }
    }
    return x_divider;
}

typedef jit_1x1_conv_conf_t jcp_t;

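/* Layout queries: whether the broadcast, load, and output tensors of the 1x1
 * convolution use a channels-last (nxc) format. Which memory tag is consulted
 * depends on the propagation kind, since the roles of src and dst swap
 * between forward, backward-data, and backward-weights. */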
inline bool is_bcast_layout_nxc(const jcp_t &jcp) {
    switch (jcp.prop_kind) {
        case prop_kind::forward_training:
        case prop_kind::forward_inference:
        case prop_kind::backward_weights:
            return utils::one_of(jcp.src_tag, format_tag::ndhwc,
                    format_tag::nhwc, format_tag::nwc);
        case prop_kind::backward_data:
            return utils::one_of(jcp.dst_tag, format_tag::ndhwc,
                    format_tag::nhwc, format_tag::nwc);
        default: assert(!"invalid prop_kind"); return false;
    }
}

inline bool is_load_layout_nxc(const jcp_t &jcp) {
    return jcp.prop_kind == prop_kind::backward_weights
            && utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc,
                    format_tag::nwc);
}

inline bool is_out_layout_nxc(const jcp_t &jcp) {
    switch (jcp.prop_kind) {
        case prop_kind::forward_training:
        case prop_kind::forward_inference:
            return utils::one_of(jcp.dst_tag, format_tag::ndhwc,
                    format_tag::nhwc, format_tag::nwc);
        case prop_kind::backward_data:
            return utils::one_of(jcp.src_tag, format_tag::ndhwc,
                    format_tag::nhwc, format_tag::nwc);
        case prop_kind::backward_weights: return false;
        default: assert(!"invalid prop_kind"); return false;
    }
}

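/* Offset helpers used by the 1x1 kernels to step through the bcast, load and
 * output tensors, selecting the dense (nxc) or blocked stride depending on
 * the layout queries above. */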
inline size_t get_bcast_u_offset(const jcp_t &jcp) {
    return is_bcast_layout_nxc(jcp) ? jcp.ic : jcp.ic_block;
}

inline size_t get_bcast_j_offset(const jcp_t &jcp) {
    return is_bcast_layout_nxc(jcp) ? jcp.reduce_dim : jcp.reduce_loop_unroll;
}

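/* Unlike the helpers above, get_bcast_offset returns a byte offset: the
 * element offset for point (u, j) of the bcast tensor scaled by
 * sizeof(float). */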
inline size_t get_bcast_offset(const jcp_t &jcp, int u, int j) {
    size_t offset;
    if (utils::one_of(jcp.prop_kind, prop_kind::forward_training,
                prop_kind::forward_inference, prop_kind::backward_data)) {
        assert(jcp.reduce_loop_unroll == jcp.reduce_block);
        if (is_bcast_layout_nxc(jcp) || u != jcp.reduce_loop_unroll) {
            offset = j * get_bcast_j_offset(jcp) + u;
        } else {
            offset = (jcp.bcast_dim + j) * get_bcast_j_offset(jcp);
        }
    } else {
        offset = u * get_bcast_u_offset(jcp) + j;
    }
    return sizeof(float) * offset;
}

inline size_t get_load_u_offset(const jcp_t &jcp) {
    return is_load_layout_nxc(jcp) ? jcp.oc : jcp.oc_block;
}

inline size_t get_load_i_offset(const jcp_t &jcp) {
    return is_load_layout_nxc(jcp) ? jcp.oc_block : jcp.os;
}

inline size_t get_load_bwd_w_offset(const jcp_t &jcp, int i, int u0) {
    if (is_load_layout_nxc(jcp)) {
        return i * get_load_i_offset(jcp) + u0 * get_load_u_offset(jcp);
    } else {
        return (i * get_load_i_offset(jcp) + u0) * get_load_u_offset(jcp);
    }
}

inline size_t get_output_i_offset(const jcp_t &jcp) {
    if (is_out_layout_nxc(jcp)) {
        return jcp.load_block;
    } else {
        return (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) * jcp.load_block;
    }
}

inline size_t get_output_j_offset(const jcp_t &jcp) {
    return is_out_layout_nxc(jcp) ? jcp.load_dim : jcp.load_block;
}

inline size_t get_load_loop_output_fwd_offset(
        const jcp_t &jcp, int load_loop_blk) {
    size_t offset = load_loop_blk * jcp.oc_block * sizeof(float);
    if (!is_out_layout_nxc(jcp)) {
        offset *= jcp.with_dw_conv ? jcp.ow : jcp.os;
    }
    return offset;
}

inline size_t get_load_loop_output_bwd_d_offset(
        const jcp_t &jcp, int load_loop_blk) {
    size_t offset = load_loop_blk * jcp.ic_block * sizeof(float);
    if (!is_out_layout_nxc(jcp)) { offset *= jcp.os; }
    return offset;
}

} // namespace aarch64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif