/*******************************************************************************
* Copyright 2018-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory_desc_wrapper.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"
#include "oneapi/dnnl/dnnl_debug.h"

#include "cpu/x64/jit_uni_reorder.hpp"

using namespace dnnl::impl::types;
using namespace dnnl::impl::status;

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

namespace tr {

/** ad-hoc structure to describe blocked memory layout as a flat list of
 * (logical dim id, size, stride) entries, one entry per inner block or
 * outer strided part of a dimension */
struct layout_desc_t {
    data_type_t dt; // data type of the tensor elements
    int ndims; // number of valid entries in id/dims/strides below
    dims_t id; // logical dimension index each entry belongs to
    dims_t dims; // extent of each entry (inner block or outer part)
    strides_t strides; // stride of each entry, in elements
};

/** For the dimension @p idx of @p md_, scales @p blk and @p tail by the
 * inner blocks that are nested inside (listed after) the innermost block
 * over @p idx. No-op (success) when @p tail is 0. Returns unimplemented
 * when the layout has more than two inner blocks or the accumulated block
 * size exceeds 256. */
static status_t compute_blk_and_tail(
        const memory_desc_t &md_, const int idx, int &blk, int &tail) {
    const auto md = memory_desc_wrapper(md_);
    const auto &bd = md.blocking_desc();
    // No tail in this dimension -> nothing to compute.
    if (tail == 0) return status::success;

    // Only supports inconsistent padding in single and double blocks
    // and the total block size <= 256
    for (int iblk = bd.inner_nblks - 1; iblk > 0; --iblk) {
        if (bd.inner_idxs[iblk] == idx) break;
        blk *= bd.inner_blks[iblk];
        tail *= bd.inner_blks[iblk];
    }
    if (bd.inner_nblks > 2 || blk > 256) return status::unimplemented;

    return status::success;
}

/** Locates in @p p the node that represents the chunk containing the tail
 * block of dimension @p blk_idx and stores its index in @p chunk_idx.
 * No-op (success) when the problem has no tails; invalid_arguments when no
 * node with the expected stride pair is found. */
static status_t compute_chunk_idx(const prb_t &p, const memory_desc_t &imd_,
        const memory_desc_t &omd_, const int blk_idx, int &chunk_idx) {
    const auto imd = memory_desc_wrapper(imd_);
    const auto omd = memory_desc_wrapper(omd_);
    const auto &ibd = imd.blocking_desc();
    const auto &obd = omd.blocking_desc();
    if (p.ip_tail == 0 && p.op_tail == 0) return status::success;

    // Expected node strides: input side advances by a whole outer step
    // (outer stride times the inner block), output side by its plain stride.
    // NOTE(review): `inner_blks` is indexed by `inner_idxs[blk_idx]` here
    // rather than by a position found by searching `inner_idxs` for
    // `blk_idx` -- verify this indexing is intentional.
    const ptrdiff_t is
            = ibd.strides[blk_idx] * obd.inner_blks[obd.inner_idxs[blk_idx]];
    const ptrdiff_t os = obd.strides[blk_idx];

    for (int i = blk_idx; i < omd.ndims(); ++i) {
        if (p.nodes[i].os == os && p.nodes[i].is == is) {
            chunk_idx = i;
            return status::success;
        }
    }

    return status::invalid_arguments;
}

/** Converts the blocked memory descriptor @p md_ into the flat layout
 * description @p ld.
 *
 * For every logical dimension one entry is appended per matching inner
 * block plus one entry for the outer (strided) part; the per-dimension
 * entries are then reversed so the outer part comes first. @p blocks holds
 * the combined block size of each dimension, @p ext_padding extra padding
 * elements added to the padded dims before dividing by the block size.
 *
 * Returns invalid_arguments for non-blocked descriptors or descriptors
 * with extra flags; success otherwise. */
status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_,
        layout_desc_t &ld, const dims_t &blocks, const dims_t &ext_padding) {
    const auto md = memory_desc_wrapper(md_);

    bool ok = true && md.is_blocking_desc() && md.extra().flags == 0;
    if (!ok) return invalid_arguments;

    const auto &bd = md.blocking_desc();

    ld.ndims = 0;
    ld.dt = md.data_type();

    // Appends one (id, dim, stride) entry to ld.
    auto P = [&ld](int id, int dim, ptrdiff_t stride) {
        assert((size_t)ld.ndims < sizeof(ld.dims) / sizeof(ld.dims[0]));
        ld.id[ld.ndims] = id;
        ld.dims[ld.ndims] = dim;
        ld.strides[ld.ndims] = stride;
        ++ld.ndims;
    };

    for (int d = 0; d < md.ndims(); ++d) {
        const int ld_ndims_start = ld.ndims;
        if (blocks[d] != 1) {
            // Emit every inner block that belongs to dimension d, innermost
            // first; `stride` accumulates the size of the blocks below it.
            stride_t stride = 1;
            for (int iblk = bd.inner_nblks - 1; iblk >= 0; --iblk) {
                if (bd.inner_idxs[iblk] == d) P(d, bd.inner_blks[iblk], stride);
                stride *= bd.inner_blks[iblk];
            }
        }
        // Outer part: padded (plus extra padding) extent in units of blocks.
        P(d, (md.padded_dims()[d] + ext_padding[d]) / blocks[d], bd.strides[d]);

        // TODO: NOW: revisit, do we need a reverse?
        // TODO: NOW: consider using strides instead of block sizes in md
        // reverse the order of dims
        for (int ld_d = 0; ld_d < (ld.ndims - ld_ndims_start) / 2; ++ld_d) {
            const int idx0 = ld_ndims_start + ld_d;
            const int idx1 = ld.ndims - 1 - ld_d;
            nstl::swap(ld.dims[idx0], ld.dims[idx1]);
            nstl::swap(ld.strides[idx0], ld.strides[idx1]);
        }
    }

    return success;
}

/** Fills the reorder problem descriptor @p p from the input/output memory
 * descriptors and the primitive attributes.
 *
 * Steps:
 *  - validate both descriptors (blocked, no runtime dims/strides, no zero
 *    dims) and the attributes (only runtime output scales plus at most one
 *    sum post-op);
 *  - detect an inconsistently padded (tail) dimension, supported for src
 *    only, and record tail sizes, blocks, and extra padding;
 *  - convert both descriptors into flat layout descriptors and zip them
 *    into the node list, splitting entries on the fly when the input and
 *    output granularities differ;
 *  - derive the scale type, per-node scale strides, start offsets and the
 *    sum (beta) scale.
 *
 * Returns unimplemented / invalid_arguments / runtime_error on the
 * corresponding failed check, success otherwise. */
status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd,
        const primitive_attr_t *attr) {
    auto im_d = memory_desc_wrapper(imd);
    auto om_d = memory_desc_wrapper(omd);

    // Post-ops are restricted to nothing or a single sum post-op.
    auto check_post_ops = [](const primitive_attr_t *attr) {
        const auto &po = attr->post_ops_;
        return po.len() == 0
                || (po.len() == 1 && po.contain(primitive_kind::sum, 0));
    };

    bool ok = im_d.is_blocking_desc() && om_d.is_blocking_desc()
            && !im_d.has_runtime_dims_or_strides() && !im_d.has_zero_dim()
            && !om_d.has_runtime_dims_or_strides() && !om_d.has_zero_dim()
            && attr->has_default_values(
                    primitive_attr_t::skip_mask_t::oscale_runtime
                    | primitive_attr_t::skip_mask_t::post_ops)
            && check_post_ops(attr);
    if (!ok) return unimplemented;

    dims_t iblocks, oblocks, ip_padding, op_padding;
    im_d.compute_blocks(iblocks);
    om_d.compute_blocks(oblocks);
    utils::array_set(ip_padding, 0, im_d.ndims());
    utils::array_set(op_padding, 0, om_d.ndims());

    /* padding_dim consistency check
     * only supports inconsistent padding for src
     * TODO: Add inconsistent padding support for dst */
    int ip_tail = 0; // tail size of the src padded dim w.r.t. dst blocks
    int op_tail = 0; // tail size of the dst padded dim w.r.t. src blocks
    int iblk_w_tail = 1; // src block size of the tail dimension
    int oblk_w_tail = 1; // dst block size of the tail dimension
    int blk_idx = 0; // index of the (single) tail dimension

    for (int d = 0; d < im_d.ndims(); ++d) {
        const int ip_tmp_dim = im_d.padded_dims()[d];
        const int op_tmp_dim = om_d.padded_dims()[d];
        const int ip_tmp_tail = ip_tmp_dim % oblocks[d];
        const int op_tmp_tail = op_tmp_dim % iblocks[d];

        // Consistent: padded dims match and neither side has a remainder
        // w.r.t. the other side's block.
        const bool pdim_consistent = ip_tmp_dim == op_tmp_dim
                && ip_tmp_tail == 0 && op_tmp_tail == 0;
        // Src-side tail: padding src up to the next dst block reproduces
        // the dst padded dim exactly, and no tail was found before.
        const bool pdim_tail = ip_tmp_tail > 0
                && (ip_tmp_dim + oblocks[d] - ip_tmp_tail) == op_tmp_dim
                && op_tmp_tail == 0 && ip_tail == 0;
        if (!pdim_consistent && !pdim_tail) return status::unimplemented;
        if (pdim_tail) {
            blk_idx = d;
            ip_tail = ip_tmp_tail;
            op_tail = op_tmp_tail;
            iblk_w_tail = iblocks[d];
            oblk_w_tail = oblocks[d];
            ip_padding[d] = oblocks[d] - ip_tmp_tail;
            op_padding[d] = iblocks[d] - op_tmp_tail;
        }
    }
    CHECK(compute_blk_and_tail(omd, blk_idx, oblk_w_tail, ip_tail));

    layout_desc_t ild, old;
    status_t status
            = cvt_mem_desc_to_layout_desc(imd, ild, iblocks, ip_padding);
    if (status != success) return status;
    status = cvt_mem_desc_to_layout_desc(omd, old, oblocks, op_padding);
    if (status != success) return status;

    p.itype = ild.dt;
    p.otype = old.dt;
    p.ip_tail = ip_tail;
    p.op_tail = op_tail;
    p.iblock = iblk_w_tail;
    p.oblock = oblk_w_tail;

    p.scale_type = attr->output_scales_.has_default_values()
            ? scale_type_t::NONE
            : (attr->output_scales_.mask_ == 0 ? scale_type_t::COMMON
                                               : scale_type_t::MANY);

    // Per-entry scale strides: row-major product over the output layout
    // entries whose logical dimension is selected by the scales mask.
    ptrdiff_t ss[max_ndims] = {0};
    if (p.scale_type == scale_type_t::MANY) {
        ptrdiff_t last_ss = 1;
        for (int d = old.ndims - 1; d >= 0; --d) {
            assert((d == 0 || old.id[d - 1] <= old.id[d])
                    && "logical dimensions should be in ascending order");
            if (attr->output_scales_.mask_ & (1 << old.id[d])) {
                ss[d] = last_ss;
                last_ss *= old.dims[d];
            }
        }
    }

    int ndims = 0;

    int i_pos = 0; /* state for input  -- current dimension */
    int o_pos = 0; /* state for output -- current dimension */

    // Zip both layouts into nodes. When the current input and output
    // entries differ in extent, consume the smaller one and keep the
    // quotient of the larger entry for the next iteration.
    while (i_pos < ild.ndims && o_pos < old.ndims) {
        assert(ild.id[i_pos] == old.id[o_pos]);

        assert(ndims < max_ndims);
        if (ndims == max_ndims) return runtime_error;

        if (ild.dims[i_pos] == old.dims[o_pos]) {
            // Same extent: one node consumes both entries.
            p.nodes[ndims].n = ild.dims[i_pos];
            p.nodes[ndims].is = ild.strides[i_pos];
            p.nodes[ndims].os = old.strides[o_pos];
            p.nodes[ndims].ss = ss[o_pos];
            ++ndims;
            ++i_pos;
            ++o_pos;
        } else if (ild.dims[i_pos] < old.dims[o_pos]) {
            // Input entry is finer: split the output entry.
            assert(old.dims[o_pos] % ild.dims[i_pos] == 0);
            int factor = old.dims[o_pos] / ild.dims[i_pos];
            p.nodes[ndims].n = ild.dims[i_pos];
            p.nodes[ndims].is = ild.strides[i_pos];
            p.nodes[ndims].os = old.strides[o_pos] * factor;
            p.nodes[ndims].ss = ss[o_pos] * factor;
            ++ndims;
            ++i_pos;
            old.dims[o_pos] = factor;
        } else if (ild.dims[i_pos] > old.dims[o_pos]) {
            // Output entry is finer: split the input entry.
            assert(ild.dims[i_pos] % old.dims[o_pos] == 0);
            int factor = ild.dims[i_pos] / old.dims[o_pos];
            p.nodes[ndims].n = old.dims[o_pos];
            p.nodes[ndims].is = ild.strides[i_pos] * factor;
            p.nodes[ndims].os = old.strides[o_pos];
            p.nodes[ndims].ss = ss[o_pos];
            ++ndims;
            ++o_pos;
            ild.dims[i_pos] = factor;
        }
    }
    int blk_chunk_idx = ndims - 1;
    CHECK(compute_chunk_idx(p, imd, omd, blk_idx, blk_chunk_idx));

    p.ndims = ndims;
    p.full_ndims = ndims;
    p.blk_chunk_idx = blk_chunk_idx;

    p.ioff = memory_desc_wrapper(imd).offset0();
    p.ooff = memory_desc_wrapper(omd).offset0();

    // Sum post-op scale becomes beta; 0 means no accumulation.
    const int sum_idx = attr->post_ops_.find(primitive_kind::sum);
    p.beta = sum_idx == -1 ? 0.f : attr->post_ops_.entry_[sum_idx].sum.scale;

    return success;
}

/** Verifies that, when the source has a tail block (p.ip_tail > 0), the
 * leading nodes of @p p replicate the inner blocks of @p md_ in reverse
 * order. Returns success when there is no tail or the orders match,
 * unimplemented otherwise. */
status_t prb_check_blk(prb_t &p, const memory_desc_t &md_) {
    const auto md = memory_desc_wrapper(md_);
    const auto &bd = md.blocking_desc();
    if (p.ip_tail == 0) return status::success;

    // Check if the inner blocks and p.nodes[blk].n in the first nblks
    // are equivalent in reverse order when there is a tail in the block
    // layout.
    const int nblk = bd.inner_nblks;
    for (int iblk = 0; iblk < nblk; ++iblk) {
        if (bd.inner_blks[nblk - iblk - 1]
                != static_cast<ptrdiff_t>(p.nodes[iblk].n))
            return status::unimplemented;
    }
    return status::success;
}

prb_normalize(prb_t & p)297 void prb_normalize(prb_t &p) {
298     for (int d = 0; d < p.ndims; ++d) {
299         int min_pos = d;
300         for (int j = d + 1; j < p.ndims; ++j) {
301             bool new_min = false || p.nodes[j].os < p.nodes[min_pos].os
302                     || (true && p.nodes[j].os == p.nodes[min_pos].os
303                             && p.nodes[j].n < p.nodes[min_pos].n);
304             if (new_min) min_pos = j;
305         }
306         if (min_pos != d) {
307             nstl::swap(p.nodes[d], p.nodes[min_pos]);
308             if (p.blk_chunk_idx == min_pos || p.blk_chunk_idx == d)
309                 p.blk_chunk_idx = p.blk_chunk_idx == min_pos ? d : min_pos;
310         }
311     }
312 }
313 
/** Folds adjacent nodes of @p p that together describe one contiguous
 * traversal (and drops trivial size-1 nodes), shrinking ndims/full_ndims
 * accordingly. Nodes adjacent to the tracked tail chunk (blk_chunk_idx)
 * are never folded when a tail is present, so tail handling keeps its
 * dedicated node. */
void prb_simplify(prb_t &p) {
#if defined(__GNUC__) && __GNUC__ >= 4
/* GCC produces bogus array subscript is above array bounds warning for
 * the `p.nodes[j - 1] = p.nodes[j]` line below, so disable it for now. */
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Warray-bounds"
#endif
    for (int d = 0; d < p.ndims - 1; ++d) {
        auto &this_node = p.nodes[d + 0];
        auto &next_node = p.nodes[d + 1];
        // Never fold into or across the block-chunk node with a tail.
        const bool skip_blk_idx = (p.ip_tail > 0 || p.op_tail > 0)
                && (p.blk_chunk_idx == d || p.blk_chunk_idx == d + 1);
        // Fold when the next node is trivial, or when all three of its
        // strides equal this node's extent times the corresponding stride,
        // i.e. the pair walks memory as one contiguous dimension.
        const bool fold = false
                || (next_node.n == static_cast<size_t>(1)
                        && !skip_blk_idx) // trivial case, just drop next node
                || (true // or real folding if possible
                        && !skip_blk_idx
                        && next_node.is
                                == static_cast<ptrdiff_t>(
                                        this_node.n * this_node.is)
                        && next_node.os
                                == static_cast<ptrdiff_t>(
                                        this_node.n * this_node.os)
                        && next_node.ss
                                == static_cast<ptrdiff_t>(
                                        this_node.n * this_node.ss));
        if (fold) {
            this_node.n *= next_node.n;
            // Shift the remaining nodes left over the folded one.
            for (int j = d + 2; j < p.ndims; ++j)
                p.nodes[j - 1] = p.nodes[j];
            if (d < p.blk_chunk_idx) --p.blk_chunk_idx;
            --p.ndims;
            --p.full_ndims;
            --d; // make another try
        }
    }
#if defined(__GNUC__) && __GNUC__ >= 4
#pragma GCC diagnostic pop
#endif
}

prb_node_split(prb_t & p,int dim,size_t n1)355 void prb_node_split(prb_t &p, int dim, size_t n1) {
356     assert(dim < p.ndims);
357     assert(p.ndims < max_ndims);
358     assert(p.nodes[dim].n % n1 == 0);
359 
360     p.ndims += 1;
361     p.full_ndims += 1;
362     if (dim < p.blk_chunk_idx) p.blk_chunk_idx += 1;
363 
364     for (int d = p.ndims; d > dim + 1; --d)
365         p.nodes[d] = p.nodes[d - 1];
366 
367     p.nodes[dim + 1].n = p.nodes[dim].n / n1;
368     p.nodes[dim + 1].is = p.nodes[dim].is * n1;
369     p.nodes[dim + 1].os = p.nodes[dim].os * n1;
370     p.nodes[dim + 1].ss = p.nodes[dim].ss * n1;
371 
372     p.nodes[dim].n = n1;
373 }
374 
prb_node_swap(prb_t & p,int d0,int d1)375 void prb_node_swap(prb_t &p, int d0, int d1) {
376     assert(d0 < p.ndims);
377     assert(d1 < p.ndims);
378     assert(p.ndims < max_ndims);
379 
380     if (d0 == d1) return;
381 
382     nstl::swap(p.nodes[d0], p.nodes[d1]);
383 }
384 
prb_node_move(prb_t & p,int d0,int d1)385 void prb_node_move(prb_t &p, int d0, int d1) {
386     assert(d0 < p.ndims);
387     assert(d1 < p.ndims);
388     assert(p.ndims < max_ndims);
389 
390     if (d0 == d1) return;
391 
392     node_t node = p.nodes[d0];
393 
394     if (d0 < d1)
395         for (int d = d0; d < d1; ++d)
396             p.nodes[d] = p.nodes[d + 1];
397     else
398         for (int d = d0; d > d1; --d)
399             p.nodes[d] = p.nodes[d - 1];
400 
401     p.nodes[d1] = node;
402 }
403 
prb_dump(const prb_t & p)404 void prb_dump(const prb_t &p) {
405     printf("@@@ type:%s:%s ndims:%d ", dnnl_dt2str(p.itype),
406             dnnl_dt2str(p.otype), p.ndims);
407     for (int d = 0; d < p.ndims; ++d)
408         printf("[%zu:%td:%td:%td]", p.nodes[d].n, p.nodes[d].is, p.nodes[d].os,
409                 p.nodes[d].ss);
410     printf(" off:%zu:%zu\n", p.ioff, p.ooff);
411 }
412 
} // namespace tr

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl