/*******************************************************************************
* Copyright 2018-2021 Intel Corporation
* Copyright 2020 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory_desc_wrapper.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"
#include "dnnl_debug.h"

#include "cpu/aarch64/jit_uni_reorder.hpp"

using namespace dnnl::impl::types;
using namespace dnnl::impl::status;

namespace dnnl {
namespace impl {
namespace cpu {
namespace aarch64 {

namespace tr {

/** ad-hoc structure to describe blocked memory layout */
struct layout_desc_t {
    data_type_t dt;
    int ndims;
    dims_t id;
    dims_t dims;
    strides_t strides;
};

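/** Flattens a blocking memory descriptor into a layout_desc_t: every logical
 * dimension contributes one entry per blocking level (its outer part plus any
 * inner blocks), each carrying the logical dim id, the block size, and the
 * stride in elements, ordered from outermost to innermost. E.g., for an
 * nChw16c-like layout the channel dimension appears twice: once as C/16 with
 * the outer stride and once as the 16-element inner block with stride 1. */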
status_t cvt_mem_desc_to_layout_desc(
        const memory_desc_t &md_, layout_desc_t &ld, const dims_t &blocks) {
    const auto md = memory_desc_wrapper(md_);

    bool ok = true && md.is_blocking_desc() && md.extra().flags == 0;
    if (!ok) return invalid_arguments;

    const auto &bd = md.blocking_desc();

    ld.ndims = 0;
    ld.dt = md.data_type();

    auto P = [&ld](int id, int dim, ptrdiff_t stride) {
        assert((size_t)ld.ndims < sizeof(ld.dims) / sizeof(ld.dims[0]));
        ld.id[ld.ndims] = id;
        ld.dims[ld.ndims] = dim;
        ld.strides[ld.ndims] = stride;
        ++ld.ndims;
    };

    for (int d = 0; d < md.ndims(); ++d) {
        const int ld_ndims_start = ld.ndims;
        if (blocks[d] != 1) {
            stride_t stride = 1;
            for (int iblk = bd.inner_nblks - 1; iblk >= 0; --iblk) {
                if (bd.inner_idxs[iblk] == d) P(d, bd.inner_blks[iblk], stride);
                stride *= bd.inner_blks[iblk];
            }
        }
        P(d, md.padded_dims()[d] / blocks[d], bd.strides[d]);

        // TODO: NOW: revisit, do we need a reverse?
        // TODO: NOW: consider using strides instead of block sizes in md
        // reverse the order of dims
        for (int ld_d = 0; ld_d < (ld.ndims - ld_ndims_start) / 2; ++ld_d) {
            const int idx0 = ld_ndims_start + ld_d;
            const int idx1 = ld.ndims - 1 - ld_d;
            nstl::swap(ld.dims[idx0], ld.dims[idx1]);
            nstl::swap(ld.strides[idx0], ld.strides[idx1]);
        }
    }

    return success;
}

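/** Initializes the reorder problem @p p from the input/output memory
 * descriptors and the attributes. Both descriptors are flattened with
 * cvt_mem_desc_to_layout_desc() and then matched entry by entry: whenever the
 * current input and output entries disagree in size, the larger one is split
 * so that every emitted node covers the same number of elements on both
 * sides. Each node records its size (n), input stride (is), output stride
 * (os), and scale stride (ss). */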
status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd,
        const primitive_attr_t *attr) {
    auto im_d = memory_desc_wrapper(imd);
    auto om_d = memory_desc_wrapper(omd);

    auto check_post_ops = [](const primitive_attr_t *attr) {
        const auto &po = attr->post_ops_;
        return po.len() == 0 || (po.len() == 1 && po.entry_[0].is_sum(false));
    };

    bool ok = im_d.is_blocking_desc() && om_d.is_blocking_desc()
            && !im_d.has_runtime_dims_or_strides() && !im_d.has_zero_dim()
            && !om_d.has_runtime_dims_or_strides() && !om_d.has_zero_dim()
            && attr->has_default_values(
                    primitive_attr_t::skip_mask_t::oscale_runtime
                    | primitive_attr_t::skip_mask_t::post_ops)
            && check_post_ops(attr);
    if (!ok) return unimplemented;

    dims_t iblocks, oblocks;
    im_d.compute_blocks(iblocks);
    om_d.compute_blocks(oblocks);

    /* padded_dims consistency check */
    for (int d = 0; d < im_d.ndims(); ++d) {
        const auto pdim = im_d.padded_dims()[d];
        bool ok = true && pdim == om_d.padded_dims()[d]
                && pdim % iblocks[d] == 0 && pdim % oblocks[d] == 0;
        if (!ok) return unimplemented;
    }

    layout_desc_t ild, old;
    status_t status = cvt_mem_desc_to_layout_desc(imd, ild, iblocks);
    if (status != success) return status;
    status = cvt_mem_desc_to_layout_desc(omd, old, oblocks);
    if (status != success) return status;

    p.itype = ild.dt;
    p.otype = old.dt;

    p.scale_type = attr->output_scales_.has_default_values()
            ? scale_type_t::NONE
            : (attr->output_scales_.mask_ == 0 ? scale_type_t::COMMON
                                               : scale_type_t::MANY);

    ptrdiff_t ss[max_ndims] = {0};
    if (p.scale_type == scale_type_t::MANY) {
        ptrdiff_t last_ss = 1;
        for (int d = old.ndims - 1; d >= 0; --d) {
            assert((d == 0 || old.id[d - 1] <= old.id[d])
                    && "logical dimensions should be in ascending order");
            if (attr->output_scales_.mask_ & (1 << old.id[d])) {
                ss[d] = last_ss;
                last_ss *= old.dims[d];
            }
        }
    }

    int ndims = 0;

    int i_pos = 0; /* state for input  -- current dimension */
    int o_pos = 0; /* state for output -- current dimension */

    while (i_pos < ild.ndims && o_pos < old.ndims) {
        assert(ild.id[i_pos] == old.id[o_pos]);
        if (ild.id[i_pos] != old.id[o_pos]) return runtime_error;

        assert(ndims < max_ndims);
        if (ndims == max_ndims) return runtime_error;

        if (ild.dims[i_pos] == old.dims[o_pos]) {
            p.nodes[ndims].n = ild.dims[i_pos];
            p.nodes[ndims].is = ild.strides[i_pos];
            p.nodes[ndims].os = old.strides[o_pos];
            p.nodes[ndims].ss = ss[o_pos];
            ++ndims;
            ++i_pos;
            ++o_pos;
        } else if (ild.dims[i_pos] < old.dims[o_pos]) {
            // the input entry is the smaller one: emit it whole and leave
            // the remaining factor of the output entry for the next pass
            assert(old.dims[o_pos] % ild.dims[i_pos] == 0);
            int factor = old.dims[o_pos] / ild.dims[i_pos];
            p.nodes[ndims].n = ild.dims[i_pos];
            p.nodes[ndims].is = ild.strides[i_pos];
            p.nodes[ndims].os = old.strides[o_pos] * factor;
            p.nodes[ndims].ss = ss[o_pos] * factor;
            ++ndims;
            ++i_pos;
            old.dims[o_pos] = factor;
        } else if (ild.dims[i_pos] > old.dims[o_pos]) {
            // symmetric case: the output entry is the smaller one
            assert(ild.dims[i_pos] % old.dims[o_pos] == 0);
            int factor = ild.dims[i_pos] / old.dims[o_pos];
            p.nodes[ndims].n = old.dims[o_pos];
            p.nodes[ndims].is = ild.strides[i_pos] * factor;
            p.nodes[ndims].os = old.strides[o_pos];
            p.nodes[ndims].ss = ss[o_pos];
            ++ndims;
            ++o_pos;
            ild.dims[i_pos] = factor;
        }
    }
    p.ndims = ndims;

    p.ioff = memory_desc_wrapper(imd).offset0();
    p.ooff = memory_desc_wrapper(omd).offset0();

    const int sum_idx = attr->post_ops_.find(primitive_kind::sum);
    p.beta = sum_idx == -1 ? 0.f : attr->post_ops_.entry_[sum_idx].sum.scale;

    return success;
}

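/** Sorts the problem nodes so that the output strides come in ascending
 * order; ties are broken in favor of the smaller node size. */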
void prb_normalize(prb_t &p) {
    for (int d = 0; d < p.ndims; ++d) {
        int min_pos = d;
        for (int j = d + 1; j < p.ndims; ++j) {
            bool new_min = false || p.nodes[j].os < p.nodes[min_pos].os
                    || (true && p.nodes[j].os == p.nodes[min_pos].os
                            && p.nodes[j].n < p.nodes[min_pos].n);
            if (new_min) min_pos = j;
        }
        if (min_pos != d) nstl::swap(p.nodes[d], p.nodes[min_pos]);
    }
}

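/** Folds a node into its neighbor whenever the next node is trivial (n == 1)
 * or is contiguous with the current one in the input, output, and scale
 * spaces, i.e. each of its strides equals the current node's size times the
 * corresponding stride. E.g., with zero scale strides, {n=4, is=1, os=1}
 * followed by {n=8, is=4, os=4} folds into the single node {n=32, is=1,
 * os=1}. */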
void prb_simplify(prb_t &p) {
#if defined(__GNUC__) && __GNUC__ >= 4
/* GCC produces a bogus "array subscript is above array bounds" warning for
 * the `p.nodes[j - 1] = p.nodes[j]` line below, so disable it for now. */
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Warray-bounds"
#endif
    for (int d = 0; d < p.ndims - 1; ++d) {
        auto &this_node = p.nodes[d + 0];
        auto &next_node = p.nodes[d + 1];
        const bool fold = false
                || next_node.n == (size_t)1 // trivial case, just drop next node
                || (true // or real folding if possible
                        && next_node.is == (ptrdiff_t)this_node.n * this_node.is
                        && next_node.os == (ptrdiff_t)this_node.n * this_node.os
                        && next_node.ss
                                == (ptrdiff_t)this_node.n * this_node.ss);
        if (fold) {
            this_node.n *= next_node.n;
            for (int j = d + 2; j < p.ndims; ++j)
                p.nodes[j - 1] = p.nodes[j];
            --p.ndims;
            --d; // make another try
        }
    }
#if defined(__GNUC__) && __GNUC__ >= 4
#pragma GCC diagnostic pop
#endif
}

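/** Splits node @p dim of size n into an inner node of size @p n1 that keeps
 * the original strides, followed by an outer node of size n / n1 with the
 * strides scaled by n1. E.g., splitting {n=16, is=1, os=2} with n1 = 4 yields
 * {n=4, is=1, os=2} followed by {n=4, is=4, os=8}. */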
void prb_node_split(prb_t &p, int dim, size_t n1) {
    assert(dim < p.ndims);
    assert(p.ndims < max_ndims);
    assert(p.nodes[dim].n % n1 == 0);

    p.ndims += 1;

    // shift the tail nodes one position up to make room for the new node;
    // start at the new last index (p.ndims was already incremented)
    for (int d = p.ndims - 1; d > dim + 1; --d)
        p.nodes[d] = p.nodes[d - 1];

    p.nodes[dim + 1].n = p.nodes[dim].n / n1;
    p.nodes[dim + 1].is = p.nodes[dim].is * n1;
    p.nodes[dim + 1].os = p.nodes[dim].os * n1;
    p.nodes[dim + 1].ss = p.nodes[dim].ss * n1;

    p.nodes[dim].n = n1;
}

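/** Swaps the problem nodes at positions @p d0 and @p d1. */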
void prb_node_swap(prb_t &p, int d0, int d1) {
    assert(d0 < p.ndims);
    assert(d1 < p.ndims);
    assert(p.ndims < max_ndims);

    if (d0 == d1) return;

    nstl::swap(p.nodes[d0], p.nodes[d1]);
}

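/** Moves the node at position @p d0 to position @p d1, shifting the nodes in
 * between accordingly. */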
void prb_node_move(prb_t &p, int d0, int d1) {
    assert(d0 < p.ndims);
    assert(d1 < p.ndims);
    assert(p.ndims < max_ndims);

    if (d0 == d1) return;

    node_t node = p.nodes[d0];

    if (d0 < d1)
        for (int d = d0; d < d1; ++d)
            p.nodes[d] = p.nodes[d + 1];
    else
        for (int d = d0; d > d1; --d)
            p.nodes[d] = p.nodes[d - 1];

    p.nodes[d1] = node;
}

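/** Dumps the problem to stdout: the input/output types, the number of
 * dimensions, one [n:is:os:ss] quadruple per node, and the input/output
 * offsets. */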
void prb_dump(const prb_t &p) {
    printf("@@@ type:%s:%s ndims:%d ", dnnl_dt2str(p.itype),
            dnnl_dt2str(p.otype), p.ndims);
    for (int d = 0; d < p.ndims; ++d)
        printf("[%zu:%td:%td:%td]", p.nodes[d].n, p.nodes[d].is, p.nodes[d].os,
                p.nodes[d].ss);
    printf(" off:%td:%td\n", p.ioff, p.ooff);
}

} // namespace tr

} // namespace aarch64
} // namespace cpu
} // namespace impl
} // namespace dnnl