1 /*******************************************************************************
2 * Copyright 2018-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18
19 #include "common/c_types_map.hpp"
20 #include "common/dnnl_thread.hpp"
21 #include "common/memory_desc_wrapper.hpp"
22 #include "common/nstl.hpp"
23 #include "common/type_helpers.hpp"
24 #include "common/utils.hpp"
25 #include "oneapi/dnnl/dnnl_debug.h"
26
27 #include "cpu/x64/jit_uni_reorder.hpp"
28
29 using namespace dnnl::impl::types;
30 using namespace dnnl::impl::status;
31
32 namespace dnnl {
33 namespace impl {
34 namespace cpu {
35 namespace x64 {
36
37 namespace tr {
38
39 /** ad-hoc structure to describe blocked memory layout */
40 struct layout_desc_t {
41 data_type_t dt;
42 int ndims;
43 dims_t id;
44 dims_t dims;
45 strides_t strides;
46 };
47
compute_blk_and_tail(const memory_desc_t & md_,const int idx,int & blk,int & tail)48 static status_t compute_blk_and_tail(
49 const memory_desc_t &md_, const int idx, int &blk, int &tail) {
50 const auto md = memory_desc_wrapper(md_);
51 const auto &bd = md.blocking_desc();
52 if (tail == 0) return status::success;
53
54 // Only supports inconsistent padding in single and double blocks
55 // and the total block size <= 256
56 for (int iblk = bd.inner_nblks - 1; iblk > 0; --iblk) {
57 if (bd.inner_idxs[iblk] == idx) break;
58 blk *= bd.inner_blks[iblk];
59 tail *= bd.inner_blks[iblk];
60 }
61 if (bd.inner_nblks > 2 || blk > 256) return status::unimplemented;
62
63 return status::success;
64 }
65
compute_chunk_idx(const prb_t & p,const memory_desc_t & imd_,const memory_desc_t & omd_,const int blk_idx,int & chunk_idx)66 static status_t compute_chunk_idx(const prb_t &p, const memory_desc_t &imd_,
67 const memory_desc_t &omd_, const int blk_idx, int &chunk_idx) {
68 const auto imd = memory_desc_wrapper(imd_);
69 const auto omd = memory_desc_wrapper(omd_);
70 const auto &ibd = imd.blocking_desc();
71 const auto &obd = omd.blocking_desc();
72 if (p.ip_tail == 0 && p.op_tail == 0) return status::success;
73
74 const ptrdiff_t is
75 = ibd.strides[blk_idx] * obd.inner_blks[obd.inner_idxs[blk_idx]];
76 const ptrdiff_t os = obd.strides[blk_idx];
77
78 for (int i = blk_idx; i < omd.ndims(); ++i) {
79 if (p.nodes[i].os == os && p.nodes[i].is == is) {
80 chunk_idx = i;
81 return status::success;
82 }
83 }
84
85 return status::invalid_arguments;
86 }
87
cvt_mem_desc_to_layout_desc(const memory_desc_t & md_,layout_desc_t & ld,const dims_t & blocks,const dims_t & ext_padding)88 status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_,
89 layout_desc_t &ld, const dims_t &blocks, const dims_t &ext_padding) {
90 const auto md = memory_desc_wrapper(md_);
91
92 bool ok = true && md.is_blocking_desc() && md.extra().flags == 0;
93 if (!ok) return invalid_arguments;
94
95 const auto &bd = md.blocking_desc();
96
97 ld.ndims = 0;
98 ld.dt = md.data_type();
99
100 auto P = [&ld](int id, int dim, ptrdiff_t stride) {
101 assert((size_t)ld.ndims < sizeof(ld.dims) / sizeof(ld.dims[0]));
102 ld.id[ld.ndims] = id;
103 ld.dims[ld.ndims] = dim;
104 ld.strides[ld.ndims] = stride;
105 ++ld.ndims;
106 };
107
108 for (int d = 0; d < md.ndims(); ++d) {
109 const int ld_ndims_start = ld.ndims;
110 if (blocks[d] != 1) {
111 stride_t stride = 1;
112 for (int iblk = bd.inner_nblks - 1; iblk >= 0; --iblk) {
113 if (bd.inner_idxs[iblk] == d) P(d, bd.inner_blks[iblk], stride);
114 stride *= bd.inner_blks[iblk];
115 }
116 }
117 P(d, (md.padded_dims()[d] + ext_padding[d]) / blocks[d], bd.strides[d]);
118
119 // TODO: NOW: revisit, do we need a reverse?
120 // TODO: NOW: consider using strides instead of block sizes in md
121 // reverse the order of dims
122 for (int ld_d = 0; ld_d < (ld.ndims - ld_ndims_start) / 2; ++ld_d) {
123 const int idx0 = ld_ndims_start + ld_d;
124 const int idx1 = ld.ndims - 1 - ld_d;
125 nstl::swap(ld.dims[idx0], ld.dims[idx1]);
126 nstl::swap(ld.strides[idx0], ld.strides[idx1]);
127 }
128 }
129
130 return success;
131 }
132
prb_init(prb_t & p,const memory_desc_t & imd,const memory_desc_t & omd,const primitive_attr_t * attr)133 status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd,
134 const primitive_attr_t *attr) {
135 auto im_d = memory_desc_wrapper(imd);
136 auto om_d = memory_desc_wrapper(omd);
137
138 auto check_post_ops = [](const primitive_attr_t *attr) {
139 const auto &po = attr->post_ops_;
140 return po.len() == 0
141 || (po.len() == 1 && po.contain(primitive_kind::sum, 0));
142 };
143
144 bool ok = im_d.is_blocking_desc() && om_d.is_blocking_desc()
145 && !im_d.has_runtime_dims_or_strides() && !im_d.has_zero_dim()
146 && !om_d.has_runtime_dims_or_strides() && !om_d.has_zero_dim()
147 && attr->has_default_values(
148 primitive_attr_t::skip_mask_t::oscale_runtime
149 | primitive_attr_t::skip_mask_t::post_ops)
150 && check_post_ops(attr);
151 if (!ok) return unimplemented;
152
153 dims_t iblocks, oblocks, ip_padding, op_padding;
154 im_d.compute_blocks(iblocks);
155 om_d.compute_blocks(oblocks);
156 utils::array_set(ip_padding, 0, im_d.ndims());
157 utils::array_set(op_padding, 0, om_d.ndims());
158
159 /* padding_dim consistency check
160 * only supports inconsitent padding for src
161 * TODO: Add inconsistent padding support for dst */
162 int ip_tail = 0;
163 int op_tail = 0;
164 int iblk_w_tail = 1;
165 int oblk_w_tail = 1;
166 int blk_idx = 0;
167
168 for (int d = 0; d < im_d.ndims(); ++d) {
169 const int ip_tmp_dim = im_d.padded_dims()[d];
170 const int op_tmp_dim = om_d.padded_dims()[d];
171 const int ip_tmp_tail = ip_tmp_dim % oblocks[d];
172 const int op_tmp_tail = op_tmp_dim % iblocks[d];
173
174 const bool pdim_consistent = ip_tmp_dim == op_tmp_dim
175 && ip_tmp_tail == 0 && op_tmp_tail == 0;
176 const bool pdim_tail = ip_tmp_tail > 0
177 && (ip_tmp_dim + oblocks[d] - ip_tmp_tail) == op_tmp_dim
178 && op_tmp_tail == 0 && ip_tail == 0;
179 if (!pdim_consistent && !pdim_tail) return status::unimplemented;
180 if (pdim_tail) {
181 blk_idx = d;
182 ip_tail = ip_tmp_tail;
183 op_tail = op_tmp_tail;
184 iblk_w_tail = iblocks[d];
185 oblk_w_tail = oblocks[d];
186 ip_padding[d] = oblocks[d] - ip_tmp_tail;
187 op_padding[d] = iblocks[d] - op_tmp_tail;
188 }
189 }
190 CHECK(compute_blk_and_tail(omd, blk_idx, oblk_w_tail, ip_tail));
191
192 layout_desc_t ild, old;
193 status_t status
194 = cvt_mem_desc_to_layout_desc(imd, ild, iblocks, ip_padding);
195 if (status != success) return status;
196 status = cvt_mem_desc_to_layout_desc(omd, old, oblocks, op_padding);
197 if (status != success) return status;
198
199 p.itype = ild.dt;
200 p.otype = old.dt;
201 p.ip_tail = ip_tail;
202 p.op_tail = op_tail;
203 p.iblock = iblk_w_tail;
204 p.oblock = oblk_w_tail;
205
206 p.scale_type = attr->output_scales_.has_default_values()
207 ? scale_type_t::NONE
208 : (attr->output_scales_.mask_ == 0 ? scale_type_t::COMMON
209 : scale_type_t::MANY);
210
211 ptrdiff_t ss[max_ndims] = {0};
212 if (p.scale_type == scale_type_t::MANY) {
213 ptrdiff_t last_ss = 1;
214 for (int d = old.ndims - 1; d >= 0; --d) {
215 assert((d == 0 || old.id[d - 1] <= old.id[d])
216 && "logical dimensions should be in ascending order");
217 if (attr->output_scales_.mask_ & (1 << old.id[d])) {
218 ss[d] = last_ss;
219 last_ss *= old.dims[d];
220 }
221 }
222 }
223
224 int ndims = 0;
225
226 int i_pos = 0; /* state for input -- current dimension */
227 int o_pos = 0; /* state for output -- current dimension */
228
229 while (i_pos < ild.ndims && o_pos < old.ndims) {
230 assert(ild.id[i_pos] == old.id[o_pos]);
231
232 assert(ndims < max_ndims);
233 if (ndims == max_ndims) return runtime_error;
234
235 if (ild.dims[i_pos] == old.dims[o_pos]) {
236 p.nodes[ndims].n = ild.dims[i_pos];
237 p.nodes[ndims].is = ild.strides[i_pos];
238 p.nodes[ndims].os = old.strides[o_pos];
239 p.nodes[ndims].ss = ss[o_pos];
240 ++ndims;
241 ++i_pos;
242 ++o_pos;
243 } else if (ild.dims[i_pos] < old.dims[o_pos]) {
244 assert(old.dims[o_pos] % ild.dims[i_pos] == 0);
245 int factor = old.dims[o_pos] / ild.dims[i_pos];
246 p.nodes[ndims].n = ild.dims[i_pos];
247 p.nodes[ndims].is = ild.strides[i_pos];
248 p.nodes[ndims].os = old.strides[o_pos] * factor;
249 p.nodes[ndims].ss = ss[o_pos] * factor;
250 ++ndims;
251 ++i_pos;
252 old.dims[o_pos] = factor;
253 } else if (ild.dims[i_pos] > old.dims[o_pos]) {
254 assert(ild.dims[i_pos] % old.dims[o_pos] == 0);
255 int factor = ild.dims[i_pos] / old.dims[o_pos];
256 p.nodes[ndims].n = old.dims[o_pos];
257 p.nodes[ndims].is = ild.strides[i_pos] * factor;
258 p.nodes[ndims].os = old.strides[o_pos];
259 p.nodes[ndims].ss = ss[o_pos];
260 ++ndims;
261 ++o_pos;
262 ild.dims[i_pos] = factor;
263 }
264 }
265 int blk_chunk_idx = ndims - 1;
266 CHECK(compute_chunk_idx(p, imd, omd, blk_idx, blk_chunk_idx));
267
268 p.ndims = ndims;
269 p.full_ndims = ndims;
270 p.blk_chunk_idx = blk_chunk_idx;
271
272 p.ioff = memory_desc_wrapper(imd).offset0();
273 p.ooff = memory_desc_wrapper(omd).offset0();
274
275 const int sum_idx = attr->post_ops_.find(primitive_kind::sum);
276 p.beta = sum_idx == -1 ? 0.f : attr->post_ops_.entry_[sum_idx].sum.scale;
277
278 return success;
279 }
280
prb_check_blk(prb_t & p,const memory_desc_t & md_)281 status_t prb_check_blk(prb_t &p, const memory_desc_t &md_) {
282 const auto md = memory_desc_wrapper(md_);
283 const auto &bd = md.blocking_desc();
284 if (p.ip_tail == 0) return status::success;
285
286 // Check if the inner blocks and p.nodes[blk].n in the firsti nblks
287 // is equivalent in reverse order when has tail in block layout.
288 const int nblk = bd.inner_nblks;
289 for (int iblk = 0; iblk < nblk; ++iblk) {
290 if (bd.inner_blks[nblk - iblk - 1]
291 != static_cast<ptrdiff_t>(p.nodes[iblk].n))
292 return status::unimplemented;
293 }
294 return status::success;
295 }
296
prb_normalize(prb_t & p)297 void prb_normalize(prb_t &p) {
298 for (int d = 0; d < p.ndims; ++d) {
299 int min_pos = d;
300 for (int j = d + 1; j < p.ndims; ++j) {
301 bool new_min = false || p.nodes[j].os < p.nodes[min_pos].os
302 || (true && p.nodes[j].os == p.nodes[min_pos].os
303 && p.nodes[j].n < p.nodes[min_pos].n);
304 if (new_min) min_pos = j;
305 }
306 if (min_pos != d) {
307 nstl::swap(p.nodes[d], p.nodes[min_pos]);
308 if (p.blk_chunk_idx == min_pos || p.blk_chunk_idx == d)
309 p.blk_chunk_idx = p.blk_chunk_idx == min_pos ? d : min_pos;
310 }
311 }
312 }
313
prb_simplify(prb_t & p)314 void prb_simplify(prb_t &p) {
315 #if defined(__GNUC__) && __GNUC__ >= 4
316 /* GCC produces bogus array subscript is above array bounds warning for
317 * the `p.nodes[j - 1] = p.nodes[j]` line below, so disable it for now. */
318 #pragma GCC diagnostic push
319 #pragma GCC diagnostic ignored "-Warray-bounds"
320 #endif
321 for (int d = 0; d < p.ndims - 1; ++d) {
322 auto &this_node = p.nodes[d + 0];
323 auto &next_node = p.nodes[d + 1];
324 const bool skip_blk_idx = (p.ip_tail > 0 || p.op_tail > 0)
325 && (p.blk_chunk_idx == d || p.blk_chunk_idx == d + 1);
326 const bool fold = false
327 || (next_node.n == static_cast<size_t>(1)
328 && !skip_blk_idx) // trivial case, just drop next node
329 || (true // or real folding if possible
330 && !skip_blk_idx
331 && next_node.is
332 == static_cast<ptrdiff_t>(
333 this_node.n * this_node.is)
334 && next_node.os
335 == static_cast<ptrdiff_t>(
336 this_node.n * this_node.os)
337 && next_node.ss
338 == static_cast<ptrdiff_t>(
339 this_node.n * this_node.ss));
340 if (fold) {
341 this_node.n *= next_node.n;
342 for (int j = d + 2; j < p.ndims; ++j)
343 p.nodes[j - 1] = p.nodes[j];
344 if (d < p.blk_chunk_idx) --p.blk_chunk_idx;
345 --p.ndims;
346 --p.full_ndims;
347 --d; // make another try
348 }
349 }
350 #if defined(__GNUC__) && __GNUC__ >= 4
351 #pragma GCC diagnostic pop
352 #endif
353 }
354
prb_node_split(prb_t & p,int dim,size_t n1)355 void prb_node_split(prb_t &p, int dim, size_t n1) {
356 assert(dim < p.ndims);
357 assert(p.ndims < max_ndims);
358 assert(p.nodes[dim].n % n1 == 0);
359
360 p.ndims += 1;
361 p.full_ndims += 1;
362 if (dim < p.blk_chunk_idx) p.blk_chunk_idx += 1;
363
364 for (int d = p.ndims; d > dim + 1; --d)
365 p.nodes[d] = p.nodes[d - 1];
366
367 p.nodes[dim + 1].n = p.nodes[dim].n / n1;
368 p.nodes[dim + 1].is = p.nodes[dim].is * n1;
369 p.nodes[dim + 1].os = p.nodes[dim].os * n1;
370 p.nodes[dim + 1].ss = p.nodes[dim].ss * n1;
371
372 p.nodes[dim].n = n1;
373 }
374
prb_node_swap(prb_t & p,int d0,int d1)375 void prb_node_swap(prb_t &p, int d0, int d1) {
376 assert(d0 < p.ndims);
377 assert(d1 < p.ndims);
378 assert(p.ndims < max_ndims);
379
380 if (d0 == d1) return;
381
382 nstl::swap(p.nodes[d0], p.nodes[d1]);
383 }
384
prb_node_move(prb_t & p,int d0,int d1)385 void prb_node_move(prb_t &p, int d0, int d1) {
386 assert(d0 < p.ndims);
387 assert(d1 < p.ndims);
388 assert(p.ndims < max_ndims);
389
390 if (d0 == d1) return;
391
392 node_t node = p.nodes[d0];
393
394 if (d0 < d1)
395 for (int d = d0; d < d1; ++d)
396 p.nodes[d] = p.nodes[d + 1];
397 else
398 for (int d = d0; d > d1; --d)
399 p.nodes[d] = p.nodes[d - 1];
400
401 p.nodes[d1] = node;
402 }
403
prb_dump(const prb_t & p)404 void prb_dump(const prb_t &p) {
405 printf("@@@ type:%s:%s ndims:%d ", dnnl_dt2str(p.itype),
406 dnnl_dt2str(p.otype), p.ndims);
407 for (int d = 0; d < p.ndims; ++d)
408 printf("[%zu:%td:%td:%td]", p.nodes[d].n, p.nodes[d].is, p.nodes[d].os,
409 p.nodes[d].ss);
410 printf(" off:%zu:%zu\n", p.ioff, p.ooff);
411 }
412
413 } // namespace tr
414
415 } // namespace x64
416 } // namespace cpu
417 } // namespace impl
418 } // namespace dnnl
419