1 /*******************************************************************************
2 * Copyright 2018-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <cassert>
18 #include <set>
19
20 #include "common/c_types_map.hpp"
21 #include "common/dnnl_thread.hpp"
22 #include "common/memory_desc_wrapper.hpp"
23 #include "common/nstl.hpp"
24 #include "common/type_helpers.hpp"
25 #include "common/utils.hpp"
26 #include "oneapi/dnnl/dnnl_debug.h"
27
28 #include "cpu/x64/jit_uni_reorder.hpp"
29
// Debug tracing for this file.
// Uncomment the define below (or define TR_DEBUG for the build) to enable the
// DEBUG(...) blocks. Otherwise DEBUG(...) expands to nothing, and its argument
// is not compiled at all.
// #define TR_DEBUG
#ifdef TR_DEBUG
#define DEBUg(...) \
    do { \
        __VA_ARGS__ \
    } while (0)
#else
#define DEBUg(...)
#endif
#define DEBUG(...) DEBUg(__VA_ARGS__)
40
41 using namespace dnnl::impl::types;
42 using namespace dnnl::impl::status;
43
44 namespace dnnl {
45 namespace impl {
46 namespace cpu {
47 namespace x64 {
48
49 namespace tr {
50
51 /** ad-hoc structure to describe blocked memory layout */
52 struct layout_desc_t {
layout_desc_tdnnl::impl::cpu::x64::tr::layout_desc_t53 layout_desc_t()
54 : dt(dnnl_data_type_undef)
55 , ndims(0)
56 , id {-1}
57 , dims {0}
58 , tails {0}
59 , is_blk {false}
60 , strides {0} {}
61 data_type_t dt;
62 int ndims;
63 dims_t id;
64 dims_t dims;
65 dims_t tails;
66 bool is_blk[DNNL_MAX_NDIMS];
67 strides_t strides;
68 };
69
cvt_mem_desc_to_layout_desc(const memory_desc_t & md_,layout_desc_t & ld,const dims_t & blocks,const dims_t & external_padding,const dims_t & tails)70 status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_,
71 layout_desc_t &ld, const dims_t &blocks, const dims_t &external_padding,
72 const dims_t &tails) {
73 static constexpr bool it_is_blk = true;
74
75 const auto md = memory_desc_wrapper(md_);
76
77 if (!md.is_blocking_desc()) return invalid_arguments;
78
79 const auto &bd = md.blocking_desc();
80
81 ld.ndims = 0;
82 ld.dt = md.data_type();
83
84 auto add_dim = [&ld](int id, int dim, int tail, bool is_blk,
85 ptrdiff_t stride) {
86 assert((size_t)ld.ndims < sizeof(ld.dims) / sizeof(ld.dims[0]));
87 ld.id[ld.ndims] = id;
88 ld.dims[ld.ndims] = dim;
89 ld.strides[ld.ndims] = stride;
90 ld.tails[ld.ndims] = tail;
91 ld.is_blk[ld.ndims] = is_blk;
92 ++ld.ndims;
93 };
94
95 for (int d = 0; d < md.ndims(); ++d) {
96 const int ld_ndims_start = ld.ndims;
97 if (blocks[d] != 1) {
98 stride_t stride = 1;
99 int tail = tails[d];
100 for (int iblk = bd.inner_nblks - 1; iblk >= 0; --iblk) {
101 if (bd.inner_idxs[iblk] == d) {
102 const int inner_tail = tail % bd.inner_blks[iblk];
103 add_dim(d, bd.inner_blks[iblk], inner_tail, it_is_blk,
104 stride);
105 tail = utils::div_up(tail, bd.inner_blks[iblk]);
106 }
107 stride *= bd.inner_blks[iblk];
108 }
109 }
110
111 const int dim_with_external_padding
112 = (md.padded_dims()[d] + external_padding[d]) / blocks[d];
113 const int padded_dim = md.padded_dims()[d] / blocks[d];
114 const int tail = dim_with_external_padding != padded_dim
115 ? dim_with_external_padding
116 - (dim_with_external_padding - padded_dim)
117 : 0;
118
119 add_dim(d, dim_with_external_padding, tail, !it_is_blk, bd.strides[d]);
120
121 // TODO: NOW: revisit, do we need a reverse?
122 // TODO: NOW: consider using strides instead of block sizes in md
123 // reverse the order of dims
124 for (int ld_d = 0; ld_d < (ld.ndims - ld_ndims_start) / 2; ++ld_d) {
125 const int idx0 = ld_ndims_start + ld_d;
126 const int idx1 = ld.ndims - 1 - ld_d;
127 nstl::swap(ld.dims[idx0], ld.dims[idx1]);
128 nstl::swap(ld.strides[idx0], ld.strides[idx1]);
129 nstl::swap(ld.tails[idx0], ld.tails[idx1]);
130 nstl::swap(ld.is_blk[idx0], ld.is_blk[idx1]);
131 }
132 }
133
134 return success;
135 }
136
is_with_groups(const memory_desc_t & dst_md)137 static bool is_with_groups(const memory_desc_t &dst_md) {
138 using namespace memory_extra_flags;
139 auto dst_d = memory_desc_wrapper(dst_md);
140 const int grp_bit = 1 << 1;
141 auto check_flag_and_mask = [&](int flag, int mask) {
142 return (dst_d.extra().flags & flag) && (mask & grp_bit);
143 };
144
145 return check_flag_and_mask(
146 compensation_conv_s8s8, dst_d.extra().compensation_mask)
147 || check_flag_and_mask(compensation_conv_asymmetric_src,
148 dst_d.extra().asymm_compensation_mask);
149 }
150
prb_set_compensation_strides(prb_t & p)151 static void prb_set_compensation_strides(prb_t &p) {
152 const auto compensation_needed = p.req_s8s8_comp || p.req_asymmetric_comp;
153 if (!compensation_needed) return;
154 int mask = p.compensation_mask;
155 ptrdiff_t cs = 1;
156 for (int d = 0; d < p.ndims; ++d) {
157 if (mask & (1 << p.nodes[d].dim_id)) {
158 p.nodes[d].cs = cs;
159 cs = cs * p.nodes[d].n;
160 }
161 }
162 }
163
prb_init(prb_t & p,const memory_desc_t & imd,const memory_desc_t & omd,const primitive_attr_t * attr)164 status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd,
165 const primitive_attr_t *attr) {
166 auto im_d = memory_desc_wrapper(imd);
167 auto om_d = memory_desc_wrapper(omd);
168
169 auto check_post_ops = [](const primitive_attr_t *attr) {
170 const auto &po = attr->post_ops_;
171 return po.len() == 0 || (po.len() == 1 && po.entry_[0].is_sum(false));
172 };
173
174 bool ok = im_d.is_blocking_desc() && om_d.is_blocking_desc()
175 && !im_d.has_runtime_dims_or_strides() && !im_d.has_zero_dim()
176 && !om_d.has_runtime_dims_or_strides() && !om_d.has_zero_dim()
177 && attr->has_default_values(
178 primitive_attr_t::skip_mask_t::oscale_runtime
179 | primitive_attr_t::skip_mask_t::zero_points_runtime
180 | primitive_attr_t::skip_mask_t::post_ops)
181 && check_post_ops(attr);
182 if (!ok) return unimplemented;
183
184 bool is_tail_present = false;
185 dims_t iblocks, oblocks, i_tails, o_tails, i_paddings, o_paddings;
186 im_d.compute_blocks(iblocks);
187 om_d.compute_blocks(oblocks);
188
189 for (int d = 0; d < om_d.ndims(); ++d) {
190 const auto dim = om_d.dims()[d];
191 const auto pdim = om_d.padded_dims()[d];
192 const auto cblock = oblocks[d];
193 // do not allow excess pdim other than required for rounding-up of dim.
194 if (utils::rnd_up(dim, cblock) != pdim) return unimplemented;
195 }
196
197 utils::array_set(i_tails, 0, im_d.ndims());
198 utils::array_set(o_tails, 0, om_d.ndims());
199 utils::array_set(i_paddings, 0, im_d.ndims());
200 utils::array_set(o_paddings, 0, om_d.ndims());
201
202 for (int d = 0; d < im_d.ndims(); ++d) {
203 const int i_dim = im_d.dims()[d];
204 const int o_dim = om_d.dims()[d];
205 const int i_tail = i_dim % iblocks[d];
206 const int o_tail = o_dim % oblocks[d];
207
208 if (o_tail > 0) {
209 is_tail_present = true;
210 o_tails[d] = o_tail;
211 o_paddings[d] = oblocks[d] - o_tail;
212 }
213
214 if (i_tail > 0) {
215 is_tail_present = true;
216 i_tails[d] = i_tail;
217 i_paddings[d] = iblocks[d] - i_tail;
218 }
219 }
220
221 // To compute input layout description we need to pass output paddings
222 // which will be used to compute input dims rounded up to multiple of
223 // output dims. Analogous applies to output layout description.
224 // This is demanded by the algorithm of nodes creation.
225 // Example:
226 // input:
227 // format: abc
228 // size: 77, 15, 3
229 // o_padding: 3, 17, 0
230 // returns ild: 80, 32, 3
231 // output:
232 // format: ABc16b16a2b
233 // size: 77, 15, 3
234 // i_padding: 0, 0, 0
235 // returns old: 5, 16, 1, 16, 2, 3
236 layout_desc_t ild, old;
237 CHECK(cvt_mem_desc_to_layout_desc(imd, ild, iblocks, o_paddings, i_tails));
238 CHECK(cvt_mem_desc_to_layout_desc(omd, old, oblocks, i_paddings, o_tails));
239
240 p.itype = ild.dt;
241 p.otype = old.dt;
242 p.is_tail_present = is_tail_present;
243 p.req_src_zp = *attr->zero_points_.get(DNNL_ARG_SRC);
244 p.req_dst_zp = *attr->zero_points_.get(DNNL_ARG_DST);
245 p.scale_type = attr->output_scales_.has_default_values()
246 ? scale_type_t::NONE
247 : (attr->output_scales_.mask_ == 0 ? scale_type_t::COMMON
248 : scale_type_t::MANY);
249 p.scale_adjust = (om_d.extra().flags & memory_extra_flags::scale_adjust)
250 ? om_d.extra().scale_adjust
251 : 1.f;
252 p.req_s8s8_comp
253 = om_d.extra().flags & memory_extra_flags::compensation_conv_s8s8;
254 p.req_asymmetric_comp = om_d.extra().flags
255 & memory_extra_flags::compensation_conv_asymmetric_src;
256
257 const bool with_groups = is_with_groups(omd);
258
259 auto mask_ok = [&](bool check, int mask) {
260 return IMPLICATION(check, mask == (with_groups ? 0x3 : 0x1));
261 };
262
263 if (!mask_ok(p.req_s8s8_comp, om_d.extra().compensation_mask)
264 || !mask_ok(p.req_asymmetric_comp,
265 om_d.extra().asymm_compensation_mask))
266 return status::unimplemented;
267
268 ptrdiff_t ss[max_ndims] = {0}; // scales strides
269 if (p.scale_type == scale_type_t::MANY) {
270 const int mask = attr->output_scales_.mask_;
271 ptrdiff_t dense_stride = 1;
272 ptrdiff_t last_stride = 1;
273 for (int d = old.ndims - 1; d >= 0; --d) {
274 assert((d == 0 || old.id[d - 1] <= old.id[d])
275 && "logical dimensions should be in ascending order");
276 if (mask & (1 << old.id[d])) {
277 if ((d + 1) < old.ndims && old.id[d + 1] != old.id[d]
278 && (mask & (1 << old.id[d + 1]))) {
279 dense_stride = dense_stride * imd.dims[old.id[d + 1]];
280 last_stride = dense_stride;
281 }
282 ss[d] = last_stride;
283 last_stride *= old.dims[d];
284 }
285 }
286 }
287
288 const auto compensation_needed = p.req_s8s8_comp || p.req_asymmetric_comp;
289 if (compensation_needed) {
290 p.compensation_mask = p.req_s8s8_comp
291 ? om_d.extra().compensation_mask
292 : (p.req_asymmetric_comp ? om_d.extra().asymm_compensation_mask
293 : tr::prb_t::invalid_comp_mask);
294
295 if (p.compensation_mask == tr::prb_t::asymmetric_comp_mask)
296 return unimplemented;
297
298 assert(p.compensation_mask == tr::prb_t::standard_comp_mask
299 || p.compensation_mask == tr::prb_t::comp_mask_with_groups);
300 }
301
302 int ndims = 0;
303
304 int i_pos = 0; /* state for input -- current dimension */
305 int o_pos = 0; /* state for output -- current dimension */
306
307 while (i_pos < ild.ndims && o_pos < old.ndims) {
308 assert(ild.id[i_pos] == old.id[o_pos]);
309
310 assert(ndims < max_ndims);
311 if (ndims == max_ndims) return runtime_error;
312
313 if (ild.dims[i_pos] == old.dims[o_pos]) {
314 p.nodes[ndims].n = ild.dims[i_pos];
315 p.nodes[ndims].dim_id = old.id[o_pos];
316 p.nodes[ndims].tail_size = old.tails[o_pos];
317 p.nodes[ndims].is_zero_pad_needed
318 = old.is_blk[o_pos] && old.tails[o_pos] > 0;
319 p.nodes[ndims].is = ild.strides[i_pos];
320 p.nodes[ndims].os = old.strides[o_pos];
321 p.nodes[ndims].ss = ss[o_pos];
322 ++ndims;
323 ++i_pos;
324 ++o_pos;
325 } else if (ild.dims[i_pos] < old.dims[o_pos]) {
326 // old must be divisible by ild or we will not be
327 // able to create valid nodes. The problem appears
328 // when stag=Acdb48a and dtag=Acdb32a for example.
329 if (ild.dims[i_pos] == 0 || old.dims[o_pos] % ild.dims[i_pos] != 0)
330 return status::unimplemented;
331
332 int factor = old.dims[o_pos] / ild.dims[i_pos];
333
334 const size_t tail_of_upper_dim
335 = utils::div_up(old.tails[o_pos], factor) == ild.dims[i_pos]
336 ? 0
337 : utils::div_up(old.tails[o_pos], factor);
338 const size_t tail_of_lower_dim = old.tails[o_pos] % factor;
339
340 p.nodes[ndims].n = ild.dims[i_pos];
341 p.nodes[ndims].dim_id = old.id[o_pos];
342 p.nodes[ndims].tail_size = tail_of_upper_dim;
343 p.nodes[ndims].is_zero_pad_needed
344 = old.is_blk[o_pos] && tail_of_upper_dim > 0;
345 p.nodes[ndims].is = ild.strides[i_pos];
346 p.nodes[ndims].os = old.strides[o_pos] * factor;
347 p.nodes[ndims].ss = ss[o_pos] * factor;
348 ++ndims;
349 ++i_pos;
350 old.dims[o_pos] = factor;
351 old.tails[o_pos] = tail_of_lower_dim;
352 } else if (ild.dims[i_pos] > old.dims[o_pos]) {
353 // ild must be divisible by old or we will not be
354 // able to create valid nodes. The problem appears
355 // when stag=Acdb32a and dtag=Acdb48a for example.
356 if (old.dims[o_pos] == 0 || ild.dims[i_pos] % old.dims[o_pos] != 0)
357 return status::unimplemented;
358
359 int factor = ild.dims[i_pos] / old.dims[o_pos];
360 p.nodes[ndims].n = old.dims[o_pos];
361 p.nodes[ndims].dim_id = old.id[o_pos];
362 p.nodes[ndims].tail_size = old.tails[o_pos];
363 p.nodes[ndims].is_zero_pad_needed
364 = old.is_blk[o_pos] && old.tails[o_pos] > 0;
365 p.nodes[ndims].is = ild.strides[i_pos] * factor;
366 p.nodes[ndims].os = old.strides[o_pos];
367 p.nodes[ndims].ss = ss[o_pos];
368 ++ndims;
369 ++o_pos;
370 ild.dims[i_pos] = factor;
371 }
372 }
373
374 p.ndims = ndims;
375 p.full_ndims = ndims;
376
377 p.ioff = memory_desc_wrapper(imd).offset0();
378 p.ooff = memory_desc_wrapper(omd).offset0();
379
380 const int sum_idx = attr->post_ops_.find(primitive_kind::sum);
381 p.beta = sum_idx == -1 ? 0.f : attr->post_ops_.entry_[sum_idx].sum.scale;
382
383 DEBUG({
384 printf("init : ");
385 prb_dump(prb);
386 });
387 // Sort the prb array in increasing sizes of the output stride
388 prb_normalize(p);
389 DEBUG({
390 printf("norm : ");
391 prb_dump(prb);
392 });
393
394 // compensation strides require prb_normalized
395 prb_set_compensation_strides(p);
396
397 /* Combine the variables, which appear together on both
398 * sides of the reorder */
399 prb_simplify(p);
400 DEBUG({
401 printf("smpl : ");
402 prb_dump(prb);
403 });
404
405 return success;
406 }
407
prb_normalize(prb_t & p)408 void prb_normalize(prb_t &p) {
409 for (int d = 0; d < p.ndims; ++d) {
410 int min_pos = d;
411 for (int j = d + 1; j < p.ndims; ++j) {
412 bool new_min = false || p.nodes[j].os < p.nodes[min_pos].os
413 || (true && p.nodes[j].os == p.nodes[min_pos].os
414 && p.nodes[j].n < p.nodes[min_pos].n);
415 if (new_min) min_pos = j;
416 }
417 if (min_pos != d) { nstl::swap(p.nodes[d], p.nodes[min_pos]); }
418 }
419 }
420
prb_node_dependency(prb_t & prb)421 void prb_node_dependency(prb_t &prb) {
422 for (int i = 0; i < prb.ndims; i++) {
423 tr::node_t &node = prb.nodes[i];
424 node.parent_node_id = node_t::empty_field;
425 for (int j = i + 1; j < prb.ndims; j++) {
426 const tr::node_t &potential_parent_node = prb.nodes[j];
427 if (!potential_parent_node.is_dim_id_empty()
428 && potential_parent_node.dim_id == node.dim_id) {
429 node.parent_node_id = j;
430 break;
431 }
432 }
433 }
434 }
435
prb_simplify(prb_t & p)436 void prb_simplify(prb_t &p) {
437 #if defined(__GNUC__) && __GNUC__ >= 4
438 /* GCC produces bogus array subscript is above array bounds warning for
439 * the `p.nodes[j - 1] = p.nodes[j]` line below, so disable it for now. */
440 #pragma GCC diagnostic push
441 #pragma GCC diagnostic ignored "-Warray-bounds"
442 #endif
443
444 const auto skip_dim_combining = [&p](const int node_id) -> bool {
445 return (p.is_tail_in_one_of_child_nodes(node_id)
446 && p.nodes[node_id].n > 1)
447 || p.nodes[node_id].tail_size > 0;
448 };
449
450 if (p.is_tail_present) prb_node_dependency(p);
451
452 for (int d = 0; d < p.ndims - 1; ++d) {
453 auto &this_node = p.nodes[d + 0];
454 auto &next_node = p.nodes[d + 1];
455 const bool skip_dims_combining
456 = skip_dim_combining(d) || skip_dim_combining(d + 1);
457 const bool fold = false
458 || (next_node.n == static_cast<size_t>(1)
459 && !skip_dims_combining) // trivial case, just drop next node
460 || (true // or real folding if possible
461 && !skip_dims_combining
462 && next_node.is
463 == static_cast<ptrdiff_t>(
464 this_node.n * this_node.is)
465 && next_node.os
466 == static_cast<ptrdiff_t>(
467 this_node.n * this_node.os)
468 && next_node.ss
469 == static_cast<ptrdiff_t>(
470 this_node.n * this_node.ss)
471 && next_node.cs
472 == static_cast<ptrdiff_t>(
473 this_node.n * this_node.cs));
474 if (fold) {
475 this_node.n *= next_node.n;
476 this_node.dim_id = node_t::empty_field;
477 this_node.is_zero_pad_needed = false;
478 for (int j = d + 2; j < p.ndims; ++j)
479 p.nodes[j - 1] = p.nodes[j];
480 --p.ndims;
481 --p.full_ndims;
482 --d; // make another try
483 if (p.is_tail_present) prb_node_dependency(p);
484 }
485 }
486 #if defined(__GNUC__) && __GNUC__ >= 4
487 #pragma GCC diagnostic pop
488 #endif
489 }
490
prb_node_split(prb_t & p,int dim,size_t new_node_size)491 void prb_node_split(prb_t &p, int dim, size_t new_node_size) {
492 assert(dim < p.ndims);
493 assert(p.ndims < max_ndims);
494 assert(p.nodes[dim].n % new_node_size == 0);
495
496 p.ndims += 1;
497 p.full_ndims += 1;
498
499 for (int d = p.ndims; d > dim + 1; --d)
500 p.nodes[d] = p.nodes[d - 1];
501
502 const size_t upper_node_size = p.nodes[dim].n / new_node_size;
503 const size_t lower_node_size = new_node_size;
504 p.nodes[dim + 1].n = upper_node_size;
505 p.nodes[dim].n = lower_node_size;
506
507 const bool is_tail = p.nodes[dim].tail_size > 0;
508 const size_t upper_node_tail
509 = utils::div_up(p.nodes[dim].tail_size, lower_node_size)
510 == upper_node_size
511 ? 0
512 : utils::div_up(p.nodes[dim].tail_size, lower_node_size);
513 const size_t lower_node_tail = p.nodes[dim].tail_size % lower_node_size;
514 p.nodes[dim].tail_size = is_tail ? lower_node_tail : 0;
515 p.nodes[dim + 1].tail_size = is_tail ? upper_node_tail : 0;
516
517 p.nodes[dim + 1].is_zero_pad_needed
518 = p.nodes[dim].is_zero_pad_needed && p.nodes[dim + 1].tail_size > 0;
519 p.nodes[dim].is_zero_pad_needed
520 = p.nodes[dim].is_zero_pad_needed && p.nodes[dim].tail_size > 0;
521
522 p.nodes[dim + 1].dim_id = p.nodes[dim].dim_id;
523 p.nodes[dim + 1].is = p.nodes[dim].is * lower_node_size;
524 p.nodes[dim + 1].os = p.nodes[dim].os * lower_node_size;
525 p.nodes[dim + 1].ss = p.nodes[dim].ss * lower_node_size;
526 p.nodes[dim + 1].cs = p.nodes[dim].cs * lower_node_size;
527 }
528
prb_node_swap(prb_t & p,int d0,int d1)529 void prb_node_swap(prb_t &p, int d0, int d1) {
530 assert(d0 < p.ndims);
531 assert(d1 < p.ndims);
532 assert(p.ndims < max_ndims);
533
534 if (d0 == d1) return;
535
536 nstl::swap(p.nodes[d0], p.nodes[d1]);
537 }
538
prb_node_move(prb_t & p,int d0,int d1)539 void prb_node_move(prb_t &p, int d0, int d1) {
540 assert(d0 < p.ndims);
541 assert(d1 < p.ndims);
542 assert(p.ndims < max_ndims);
543
544 if (d0 == d1) return;
545
546 node_t node = p.nodes[d0];
547
548 if (d0 < d1)
549 for (int d = d0; d < d1; ++d)
550 p.nodes[d] = p.nodes[d + 1];
551 else
552 for (int d = d0; d > d1; --d)
553 p.nodes[d] = p.nodes[d - 1];
554
555 p.nodes[d1] = node;
556 }
557
prb_dump(const prb_t & p)558 void prb_dump(const prb_t &p) {
559 printf("@@@ type:%s:%s ndims:%d ", dnnl_dt2str(p.itype),
560 dnnl_dt2str(p.otype), p.ndims);
561 for (int d = 0; d < p.ndims; ++d)
562 printf("[%zu:%zu:%d:%d:%s:%td:%td:%td:%td]", p.nodes[d].n,
563 p.nodes[d].tail_size, p.nodes[d].dim_id,
564 p.nodes[d].parent_node_id,
565 p.nodes[d].is_zero_pad_needed ? "true" : "false", p.nodes[d].is,
566 p.nodes[d].os, p.nodes[d].ss, p.nodes[d].cs);
567 printf(" off:%zu:%zu\n", p.ioff, p.ooff);
568 }
569
570 } // namespace tr
571
572 } // namespace x64
573 } // namespace cpu
574 } // namespace impl
575 } // namespace dnnl
576