1 /*******************************************************************************
2 * Copyright 2018-2021 Intel Corporation
3 * Copyright 2020 FUJITSU LIMITED
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *******************************************************************************/
17
18 #include <assert.h>
19
20 #include "common/c_types_map.hpp"
21 #include "common/dnnl_thread.hpp"
22 #include "common/memory_desc_wrapper.hpp"
23 #include "common/nstl.hpp"
24 #include "common/type_helpers.hpp"
25 #include "common/utils.hpp"
26 #include "dnnl_debug.h"
27
28 #include "cpu/aarch64/jit_uni_reorder.hpp"
29
30 using namespace dnnl::impl::types;
31 using namespace dnnl::impl::status;
32
33 namespace dnnl {
34 namespace impl {
35 namespace cpu {
36 namespace aarch64 {
37
38 namespace tr {
39
40 /** ad-hoc structure to describe blocked memory layout */
/** ad-hoc structure to describe blocked memory layout */
struct layout_desc_t {
    data_type_t dt; // element data type, taken from the memory descriptor
    int ndims; // number of flat entries below (inner blocks count separately)
    dims_t id; // logical dimension index that each entry belongs to
    dims_t dims; // size of each entry (an inner block or the outer remainder)
    strides_t strides; // stride of each entry, in elements
};
48
cvt_mem_desc_to_layout_desc(const memory_desc_t & md_,layout_desc_t & ld,const dims_t & blocks)49 status_t cvt_mem_desc_to_layout_desc(
50 const memory_desc_t &md_, layout_desc_t &ld, const dims_t &blocks) {
51 const auto md = memory_desc_wrapper(md_);
52
53 bool ok = true && md.is_blocking_desc() && md.extra().flags == 0;
54 if (!ok) return invalid_arguments;
55
56 const auto &bd = md.blocking_desc();
57
58 ld.ndims = 0;
59 ld.dt = md.data_type();
60
61 auto P = [&ld](int id, int dim, ptrdiff_t stride) {
62 assert((size_t)ld.ndims < sizeof(ld.dims) / sizeof(ld.dims[0]));
63 ld.id[ld.ndims] = id;
64 ld.dims[ld.ndims] = dim;
65 ld.strides[ld.ndims] = stride;
66 ++ld.ndims;
67 };
68
69 for (int d = 0; d < md.ndims(); ++d) {
70 const int ld_ndims_start = ld.ndims;
71 if (blocks[d] != 1) {
72 stride_t stride = 1;
73 for (int iblk = bd.inner_nblks - 1; iblk >= 0; --iblk) {
74 if (bd.inner_idxs[iblk] == d) P(d, bd.inner_blks[iblk], stride);
75 stride *= bd.inner_blks[iblk];
76 }
77 }
78 P(d, md.padded_dims()[d] / blocks[d], bd.strides[d]);
79
80 // TODO: NOW: revisit, do we need a reverse?
81 // TODO: NOW: consider using strides instead of block sizes in md
82 // reverse the order of dims
83 for (int ld_d = 0; ld_d < (ld.ndims - ld_ndims_start) / 2; ++ld_d) {
84 const int idx0 = ld_ndims_start + ld_d;
85 const int idx1 = ld.ndims - 1 - ld_d;
86 nstl::swap(ld.dims[idx0], ld.dims[idx1]);
87 nstl::swap(ld.strides[idx0], ld.strides[idx1]);
88 }
89 }
90
91 return success;
92 }
93
/** Initializes the reorder problem descriptor @p p from the input/output
 *  memory descriptors and the primitive attributes.
 *
 *  Both descriptors are flattened into layout_desc_t form and then merged
 *  node by node; dimensions whose input/output sizes differ (due to
 *  different blocking) are split until they match. Returns unimplemented
 *  for unsupported cases, runtime_error on internal inconsistency, success
 *  otherwise. */
status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd,
        const primitive_attr_t *attr) {
    auto im_d = memory_desc_wrapper(imd);
    auto om_d = memory_desc_wrapper(omd);

    // Only no post-ops at all, or a single (non-required-scale) sum post-op,
    // is supported.
    auto check_post_ops = [](const primitive_attr_t *attr) {
        const auto &po = attr->post_ops_;
        return po.len() == 0 || (po.len() == 1 && po.entry_[0].is_sum(false));
    };

    // Supported: blocking descriptors with static, non-zero dims; attributes
    // may only deviate from defaults in runtime output scales and post-ops.
    bool ok = im_d.is_blocking_desc() && om_d.is_blocking_desc()
            && !im_d.has_runtime_dims_or_strides() && !im_d.has_zero_dim()
            && !om_d.has_runtime_dims_or_strides() && !om_d.has_zero_dim()
            && attr->has_default_values(
                    primitive_attr_t::skip_mask_t::oscale_runtime
                    | primitive_attr_t::skip_mask_t::post_ops)
            && check_post_ops(attr);
    if (!ok) return unimplemented;

    dims_t iblocks, oblocks;
    im_d.compute_blocks(iblocks);
    om_d.compute_blocks(oblocks);

    /* padding_dim consistency check */
    // Each padded dim must agree between sides and be divisible by both
    // sides' block sizes, otherwise the split logic below cannot work.
    for (int d = 0; d < im_d.ndims(); ++d) {
        const auto pdim = im_d.padded_dims()[d];
        bool ok = true && pdim == om_d.padded_dims()[d]
                && pdim % iblocks[d] == 0 && pdim % oblocks[d] == 0;
        if (!ok) return unimplemented;
    }

    // Flatten both descriptors; 'old' here is the Output Layout Descriptor.
    layout_desc_t ild, old;
    status_t status = cvt_mem_desc_to_layout_desc(imd, ild, iblocks);
    if (status != success) return status;
    status = cvt_mem_desc_to_layout_desc(omd, old, oblocks);
    if (status != success) return status;

    p.itype = ild.dt;
    p.otype = old.dt;

    // COMMON: a single scale for the whole tensor (mask == 0);
    // MANY: a per-dimension scale array selected by the mask.
    p.scale_type = attr->output_scales_.has_default_values()
            ? scale_type_t::NONE
            : (attr->output_scales_.mask_ == 0 ? scale_type_t::COMMON
                                               : scale_type_t::MANY);

    // Scale strides per output-layout entry: dense, innermost-last, over the
    // dimensions selected by the output-scales mask only.
    ptrdiff_t ss[max_ndims] = {0};
    if (p.scale_type == scale_type_t::MANY) {
        ptrdiff_t last_ss = 1;
        for (int d = old.ndims - 1; d >= 0; --d) {
            assert((d == 0 || old.id[d - 1] <= old.id[d])
                    && "logical dimensions should be in ascending order");
            if (attr->output_scales_.mask_ & (1 << old.id[d])) {
                ss[d] = last_ss;
                last_ss *= old.dims[d];
            }
        }
    }

    int ndims = 0;

    int i_pos = 0; /* state for input -- current dimension */
    int o_pos = 0; /* state for output -- current dimension */

    // Merge the two flattened layouts into p.nodes. Matching-size entries
    // are emitted directly; mismatched entries are split by the ratio of
    // their sizes and the smaller side advances.
    while (i_pos < ild.ndims && o_pos < old.ndims) {
        assert(ild.id[i_pos] == old.id[o_pos]);
        if (ild.id[i_pos] != old.id[o_pos]) return runtime_error;

        assert(ndims < max_ndims);
        if (ndims == max_ndims) return runtime_error;

        if (ild.dims[i_pos] == old.dims[o_pos]) {
            // Sizes agree: take the node as-is from both sides.
            p.nodes[ndims].n = ild.dims[i_pos];
            p.nodes[ndims].is = ild.strides[i_pos];
            p.nodes[ndims].os = old.strides[o_pos];
            p.nodes[ndims].ss = ss[o_pos];
            ++ndims;
            ++i_pos;
            ++o_pos;
        } else if (ild.dims[i_pos] < old.dims[o_pos]) {
            // Input entry is smaller: emit it and shrink the output entry
            // to the remaining factor; output strides scale accordingly.
            assert(old.dims[o_pos] % ild.dims[i_pos] == 0);
            int factor = old.dims[o_pos] / ild.dims[i_pos];
            p.nodes[ndims].n = ild.dims[i_pos];
            p.nodes[ndims].is = ild.strides[i_pos];
            p.nodes[ndims].os = old.strides[o_pos] * factor;
            p.nodes[ndims].ss = ss[o_pos] * factor;
            ++ndims;
            ++i_pos;
            old.dims[o_pos] = factor;
        } else if (ild.dims[i_pos] > old.dims[o_pos]) {
            // Output entry is smaller: symmetric case of the above.
            assert(ild.dims[i_pos] % old.dims[o_pos] == 0);
            int factor = ild.dims[i_pos] / old.dims[o_pos];
            p.nodes[ndims].n = old.dims[o_pos];
            p.nodes[ndims].is = ild.strides[i_pos] * factor;
            p.nodes[ndims].os = old.strides[o_pos];
            p.nodes[ndims].ss = ss[o_pos];
            ++ndims;
            ++o_pos;
            ild.dims[i_pos] = factor;
        }
    }
    p.ndims = ndims;

    p.ioff = memory_desc_wrapper(imd).offset0();
    p.ooff = memory_desc_wrapper(omd).offset0();

    // beta is the sum post-op scale (0 means "no accumulation").
    const int sum_idx = attr->post_ops_.find(primitive_kind::sum);
    p.beta = sum_idx == -1 ? 0.f : attr->post_ops_.entry_[sum_idx].sum.scale;

    return success;
}
204
prb_normalize(prb_t & p)205 void prb_normalize(prb_t &p) {
206 for (int d = 0; d < p.ndims; ++d) {
207 int min_pos = d;
208 for (int j = d + 1; j < p.ndims; ++j) {
209 bool new_min = false || p.nodes[j].os < p.nodes[min_pos].os
210 || (true && p.nodes[j].os == p.nodes[min_pos].os
211 && p.nodes[j].n < p.nodes[min_pos].n);
212 if (new_min) min_pos = j;
213 }
214 if (min_pos != d) nstl::swap(p.nodes[d], p.nodes[min_pos]);
215 }
216 }
217
prb_simplify(prb_t & p)218 void prb_simplify(prb_t &p) {
219 #if defined(__GNUC__) && __GNUC__ >= 4
220 /* GCC produces bogus array subscript is above array bounds warning for
221 * the `p.nodes[j - 1] = p.nodes[j]` line below, so disable it for now. */
222 #pragma GCC diagnostic push
223 #pragma GCC diagnostic ignored "-Warray-bounds"
224 #endif
225 for (int d = 0; d < p.ndims - 1; ++d) {
226 auto &this_node = p.nodes[d + 0];
227 auto &next_node = p.nodes[d + 1];
228 const bool fold = false
229 || next_node.n == (size_t)1 // trivial case, just drop next node
230 || (true // or real folding if possible
231 && next_node.is == (ptrdiff_t)this_node.n * this_node.is
232 && next_node.os == (ptrdiff_t)this_node.n * this_node.os
233 && next_node.ss
234 == (ptrdiff_t)this_node.n * this_node.ss);
235 if (fold) {
236 this_node.n *= next_node.n;
237 for (int j = d + 2; j < p.ndims; ++j)
238 p.nodes[j - 1] = p.nodes[j];
239 --p.ndims;
240 --d; // make another try
241 }
242 }
243 #if defined(__GNUC__) && __GNUC__ >= 4
244 #pragma GCC diagnostic pop
245 #endif
246 }
247
prb_node_split(prb_t & p,int dim,size_t n1)248 void prb_node_split(prb_t &p, int dim, size_t n1) {
249 assert(dim < p.ndims);
250 assert(p.ndims < max_ndims);
251 assert(p.nodes[dim].n % n1 == 0);
252
253 p.ndims += 1;
254
255 for (int d = p.ndims; d > dim + 1; --d)
256 p.nodes[d] = p.nodes[d - 1];
257
258 p.nodes[dim + 1].n = p.nodes[dim].n / n1;
259 p.nodes[dim + 1].is = p.nodes[dim].is * n1;
260 p.nodes[dim + 1].os = p.nodes[dim].os * n1;
261 p.nodes[dim + 1].ss = p.nodes[dim].ss * n1;
262
263 p.nodes[dim].n = n1;
264 }
265
prb_node_swap(prb_t & p,int d0,int d1)266 void prb_node_swap(prb_t &p, int d0, int d1) {
267 assert(d0 < p.ndims);
268 assert(d1 < p.ndims);
269 assert(p.ndims < max_ndims);
270
271 if (d0 == d1) return;
272
273 nstl::swap(p.nodes[d0], p.nodes[d1]);
274 }
275
prb_node_move(prb_t & p,int d0,int d1)276 void prb_node_move(prb_t &p, int d0, int d1) {
277 assert(d0 < p.ndims);
278 assert(d1 < p.ndims);
279 assert(p.ndims < max_ndims);
280
281 if (d0 == d1) return;
282
283 node_t node = p.nodes[d0];
284
285 if (d0 < d1)
286 for (int d = d0; d < d1; ++d)
287 p.nodes[d] = p.nodes[d + 1];
288 else
289 for (int d = d0; d > d1; --d)
290 p.nodes[d] = p.nodes[d - 1];
291
292 p.nodes[d1] = node;
293 }
294
prb_dump(const prb_t & p)295 void prb_dump(const prb_t &p) {
296 printf("@@@ type:%s:%s ndims:%d ", dnnl_dt2str(p.itype),
297 dnnl_dt2str(p.otype), p.ndims);
298 for (int d = 0; d < p.ndims; ++d)
299 printf("[%zu:%td:%td:%td]", p.nodes[d].n, p.nodes[d].is, p.nodes[d].os,
300 p.nodes[d].ss);
301 printf(" off:%zu:%zu\n", p.ioff, p.ooff);
302 }
303
304 } // namespace tr
305
306 } // namespace aarch64
307 } // namespace cpu
308 } // namespace impl
309 } // namespace dnnl
310