1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <cctype>
18 #include <sstream>
19 #include <thread>
20
21 #include "gpu/jit/conv/tensor.hpp"
22
23 namespace dnnl {
24 namespace impl {
25 namespace gpu {
26 namespace jit {
27
layout_t(const type_t & type,const expr_t & offset,const std::string & format,const std::vector<dim_t> & dims,bool do_normalize)28 layout_t::layout_t(const type_t &type, const expr_t &offset,
29 const std::string &format, const std::vector<dim_t> &dims,
30 bool do_normalize)
31 : type_(type), offset_(offset) {
32 auto parts = parse_format(format, int(dims.size()));
33 ndims_ = 0;
34 for (auto &p : parts) {
35 int dim_idx = p.first;
36 dim_t block = p.second;
37 ndims_ = std::max(ndims_, dim_idx + 1);
38 if (block == 0 && dims.empty())
39 ir_error_not_expected()
40 << "Dimensions are missing. Can't deduce them from "
41 "the format.";
42 }
43 if (!dims.empty() && ndims_ != int(dims.size())) {
44 ir_error_not_expected() << "Format and dimensions do not match.";
45 }
46
47 dim_t stride = 1;
48 // Iterate from right to left (innermost to outermost).
49 for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
50 int dim_idx = it->first;
51 dim_t block = it->second;
52 if (block == 0) {
53 dim_t full_block = 1;
54 for (auto &b : blocks_)
55 if (b.dim_idx == dim_idx) full_block *= b.block;
56
57 block = utils::div_up(dims[dim_idx], full_block);
58 }
59
60 blocks_.emplace_back(dim_idx, block, stride);
61 stride = block * stride;
62 }
63
64 if (do_normalize) blocks_ = normalize_blocks(ndims_, blocks_);
65 sanity_check();
66 }
67
layout_t(const memory_desc_wrapper & mdw,bool do_normalize)68 layout_t::layout_t(const memory_desc_wrapper &mdw, bool do_normalize)
69 : type_(mdw.data_type()), offset_(mdw.offset0()) {
70 ir_assert(mdw.is_blocking_desc()) << "Expected blocking memory descriptor.";
71
72 ndims_ = mdw.ndims();
73 auto &blocking = mdw.blocking_desc();
74 auto *padded_dims = mdw.padded_dims();
75
76 dim_t stride = 1;
77 std::vector<dim_t> full_blocks(ndims_, 1);
78 for (int i = blocking.inner_nblks - 1; i >= 0; i--) {
79 int dim_idx = blocking.inner_idxs[i];
80 dim_t block = blocking.inner_blks[i];
81 blocks_.emplace_back(dim_idx, block, stride);
82 stride *= block;
83 full_blocks[dim_idx] *= block;
84 }
85
86 for (int i = 0; i < ndims_; i++) {
87 dim_t block = padded_dims[i] / full_blocks[i];
88 blocks_.emplace_back(i, block, blocking.strides[i]);
89 }
90
91 // Sort outer blocks by their stride.
92 std::sort(blocks_.begin() + blocking.inner_nblks, blocks_.end(),
93 [](const block_t &a, const block_t &b) {
94 if (a.stride == b.stride) return a.dim_idx > b.dim_idx;
95 return a.stride < b.stride;
96 });
97
98 if (do_normalize) blocks_ = normalize_blocks(ndims_, blocks_);
99 sanity_check();
100 }
101
// Converts this layout to a DNNL blocked memory descriptor. `dims_hint`
// supplies the logical (non-padded) dimensions; padded dimensions are
// reconstructed from the block sizes.
memory_desc_t layout_t::to_dnnl(const dim_t *dims_hint) const {
    memory_desc_t md = {};
    md.ndims = ndims();
    std::copy(dims_hint, dims_hint + ndims(), md.dims);
    md.data_type = jit::to_dnnl(type_);
    md.offset0 = to_cpp<dim_t>(offset_);
    md.format_kind = format_kind::blocked;

    auto &blk = md.format_desc.blocking;
    bool seen[DNNL_MAX_NDIMS] = {};

    bool in_inner_block = false;
    dim_t prev_stride = 0;

    // Walk blocks from outermost to innermost. The first block seen for a
    // dimension is its outer block (defines the dimension's stride); any
    // subsequent block of an already-seen dimension is an inner block.
    for (auto it = blocks_.rbegin(); it != blocks_.rend(); ++it) {
        auto &b = *it;
        if (!seen[b.dim_idx]) {
            // Outer block. All outer blocks must precede the inner blocks,
            // otherwise the layout can't be expressed as a blocking desc.
            ir_assert(!in_inner_block);
            MAYBE_UNUSED(in_inner_block);
            blk.strides[b.dim_idx] = b.stride;
            md.padded_dims[b.dim_idx] = b.block;
        } else {
            // Inner block: accumulate into padded dims and append to the
            // inner-block lists of the descriptor.
            md.padded_dims[b.dim_idx] *= b.block;
            blk.inner_idxs[blk.inner_nblks] = b.dim_idx;
            blk.inner_blks[blk.inner_nblks] = b.block;
            blk.inner_nblks++;
            if (prev_stride > 0) {
                // Inner block must be dense.
                ir_assert(prev_stride == b.block * b.stride);
            }
            prev_stride = b.stride;
            in_inner_block = true;
        }
        seen[b.dim_idx] = true;
    }

    return md;
}
142
// Returns the layout of the sub-tensor `tensor` viewed through this layout:
// blocks are consumed to cover exactly `tensor.dims()` elements, and the
// offset is shifted to the element at `tensor.start()`. If a block doesn't
// evenly divide the remaining dimension, the block is split and mapping
// restarts recursively.
layout_t layout_t::map(const tensor_t &tensor) const {
    if (ndims() != tensor.ndims())
        ir_error_not_expected() << "Dimensions do not match.";

    std::vector<dim_t> remaining_dims = tensor.dims();
    std::vector<block_t> mapped_blocks;

    for (auto &eb : enumerated_blocks()) {
        block_t &b = eb.second;
        bool b_is_outermost = is_outermost(eb);

        dim_t block = b.block;
        dim_t &rem_dim = remaining_dims[b.dim_idx];
        if (rem_dim == 1) {
            if (b_is_outermost) {
                // This is to have similarity between the current and
                // mapped layouts.
                mapped_blocks.emplace_back(b.dim_idx, 1, b.stride);
            }
            continue;
        }
        if (b_is_outermost) {
            // The outermost block absorbs whatever remains of the dimension.
            block = rem_dim;
        } else if (rem_dim % block != 0) {
            // Try to split the current block and start mapping from
            // scratch.
            if (block % rem_dim == 0)
                return split_block(eb, rem_dim, block / rem_dim).map(tensor);

            ir_error_not_expected() << "Can't map tensor layout.";
        }
        rem_dim /= block;
        mapped_blocks.emplace_back(b.dim_idx, block, b.stride);
    }

    // Every requested dimension must be fully covered by blocks.
    for (auto &d : remaining_dims) {
        ir_assert(d == 1) << "Can't map tensor layout.";
        MAYBE_UNUSED(d);
    }

    // operator()(start) computes the linear offset of the start element.
    return layout_t(type(), ndims(), operator()(tensor.start()), mapped_blocks);
}
185
// Reinterprets the layout with a different element type while preserving the
// underlying byte layout: the innermost block and all other strides are
// rescaled by the size ratio of the two types. Returns an empty layout
// (after reporting an error) when the sizes are not multiples of each other
// or the innermost block/strides are not divisible by the ratio.
layout_t layout_t::reinterpret(
        const type_t &new_type, bool do_normalize) const {
    int old_size = type().size();
    int new_size = new_type.size();
    if (new_size == old_size) return *this;

    // Rescale the element offset from old-type units to new-type units.
    expr_t new_offset = 0;
    if (!has_zero_offset()) {
        ir_assert(is_const(offset_)) << "Expected constant offset.";
        int64_t off = to_cpp<int64_t>(offset_) * old_size;
        ir_assert(off % new_size == 0);
        new_offset = off / new_size;
    }

    if (old_size % new_size != 0 && new_size % old_size != 0) {
        ir_error_not_expected();
        return layout_t();
    }

    auto new_blocks = blocks_;
    if (new_blocks.empty()) {
        ir_error_not_expected() << "Can't reinterpret.";
        return layout_t();
    }

    if (new_size < old_size) {
        // Smaller elements: the innermost block grows by `factor`, the
        // strides of all other blocks grow accordingly.
        int factor = (old_size / new_size);
        auto &b0 = new_blocks.front();
        b0.block *= factor;
        // Recompute strides.
        for (auto &b : new_blocks) {
            if (&b == &b0) continue;
            b.stride *= factor;
        }
    } else {
        // Larger elements: the innermost block shrinks by `factor`; both the
        // block and all other strides must be divisible by it.
        int factor = (new_size / old_size);
        auto &b0 = new_blocks.front();
        if (b0.block % factor != 0) {
            ir_error_not_expected();
            return layout_t();
        }
        b0.block /= factor;
        // Recompute strides.
        for (auto &b : new_blocks) {
            if (&b == &b0) continue;
            if (b.stride % factor != 0) {
                ir_error_not_expected();
                return layout_t();
            }
            b.stride /= factor;
        }
    }

    return layout_t(new_type, ndims(), new_offset, new_blocks, do_normalize);
}
241
split_block(const std::pair<int,block_t> & eb,dim_t block0,dim_t block1) const242 layout_t layout_t::split_block(
243 const std::pair<int, block_t> &eb, dim_t block0, dim_t block1) const {
244 int block_idx = eb.first;
245 auto &b = eb.second;
246 ir_assert(b.block == block0 * block1) << "Incompatible block sizes.";
247 MAYBE_UNUSED(b);
248
249 auto new_blocks = blocks_;
250
251 block_t &b0 = new_blocks[block_idx];
252 block_t b1 = b0;
253
254 b0.block = block0;
255 b1.block = block1;
256 b1.stride = b0.stride * block0;
257
258 new_blocks.insert(new_blocks.begin() + block_idx + 1, b1);
259
260 return layout_t(
261 type(), ndims(), offset(), new_blocks, /*do_normalize=*/false);
262 }
263
// Splits the layout so that its innermost blocks form the exact element
// counts requested in `multi_blocks`. Returns an empty layout when the
// split is impossible.
layout_t layout_t::split_into_multi_blocks(
        const std::vector<dim_t> &multi_blocks) const {
    return split_into_multi_blocks_impl(multi_blocks, nullptr);
}
268
// Same as split_into_multi_blocks() but treats `multi_blocks` as a hint:
// smaller blocks are allowed, and the actually achieved sizes are written
// back into `multi_blocks`.
layout_t layout_t::split_into_multi_blocks_with_hint(
        std::vector<dim_t> &multi_blocks) const {
    return split_into_multi_blocks_impl(multi_blocks, &multi_blocks);
}
273
// Finds a tile covering `tile_elems` dense innermost elements, optionally
// extended by `outer_block` additional elements on top. Returns an empty
// tensor when the innermost blocks are not dense or the requested sizes
// cannot be formed from whole blocks.
tensor_t layout_t::split_into_dense_tile(
        dim_t tile_elems, dim_t outer_block) const {
    stride_t dense_stride = 1;
    dim_t cur_tile_elems = 1;
    dim_t cur_outer_block = 1;
    bool in_tile = (tile_elems != 1);
    std::vector<dim_t> tile_dims(ndims(), 1);
    for (auto &b : blocks()) {
        if (b.block == 1) continue;
        if (in_tile) {
            // Tile part: each block must continue exactly where the previous
            // one ended (dense), otherwise there is no dense tile.
            if (b.stride.is_unknown()) return tensor_t();
            if (dense_stride != b.stride) return tensor_t();
            dense_stride = b.block * b.stride;
            cur_tile_elems *= b.block;
            tile_dims[b.dim_idx] *= b.block;
            ir_assert(cur_tile_elems <= tile_elems);
            if (cur_tile_elems == tile_elems) in_tile = false;
        } else {
            // Outer part: density is not required here.
            if (outer_block == 1) break;
            cur_outer_block *= b.block;
            tile_dims[b.dim_idx] *= b.block;
            ir_assert(cur_outer_block <= outer_block);
            if (cur_outer_block == outer_block) break;
        }
    }
    // Both parts must be matched exactly by whole blocks.
    if (cur_tile_elems != tile_elems) return tensor_t();
    if (cur_outer_block != outer_block) return tensor_t();
    return tensor_t(tile_dims);
}
303
// Returns the largest tile of at most `max_tile_elems` elements formed from
// the innermost blocks, splitting a block when it only partially fits. With
// `is_dense_tile` the collected blocks must be contiguous in memory.
tensor_t layout_t::split_into_max_tile(
        dim_t max_tile_elems, bool is_dense_tile) const {
    stride_t dense_stride = 1;
    std::vector<dim_t> tile_dims(ndims(), 1);
    dim_t cur_elems = 1;
    for (auto &eb : enumerated_blocks()) {
        auto &b = eb.second;
        if (b.block == 1) continue;
        if (b.block * cur_elems <= max_tile_elems) {
            // The whole block fits into the tile.
            if (is_dense_tile) {
                // Dense mode: the block must start where the previous ended.
                if (b.stride.is_unknown()) break;
                if (dense_stride != b.stride) break;
                dense_stride = b.block * b.stride;
            }
            cur_elems *= b.block;
            tile_dims[b.dim_idx] *= b.block;
            continue;
        }
        // Only part of the block fits: take the largest divisor that keeps
        // the tile within the limit, split the block, and retry.
        dim_t max_block = utils::max_div(b.block, max_tile_elems / cur_elems);
        if (max_block == 1) break;
        auto tmp_layout = split_block(eb, max_block, b.block / max_block);
        return tmp_layout.split_into_max_tile(max_tile_elems, is_dense_tile);
    }
    return tensor_t(tile_dims);
}
329
// Makes two layouts block-compatible per dimension: walks the blocks of each
// dimension in both layouts in parallel and, on the first size mismatch,
// splits the larger block down to the GCD of the pair. Note that after a
// split the inner loop breaks and processing moves to the next dimension —
// only one split per layout per dimension happens in a single call.
void layout_t::align_layouts(layout_t &a, layout_t &b) {
    for (int i = 0; i < a.ndims(); i++) {
        // Snapshot the block lists; they become stale after split_block(),
        // which is why the loop below breaks right after splitting.
        auto a_blocks = a.blocks();
        auto b_blocks = b.blocks();

        int a_max = int(a_blocks.size());
        int b_max = int(b_blocks.size());
        int a_idx = 0;
        int b_idx = 0;

        for (;;) {
            // Advance both cursors to the next block of dimension i.
            while (a_idx < a_max && a_blocks[a_idx].dim_idx != i)
                a_idx++;
            while (b_idx < b_max && b_blocks[b_idx].dim_idx != i)
                b_idx++;

            if (a_idx >= a_max || b_idx >= b_max) break;

            auto &ab = a_blocks[a_idx];
            auto &bb = b_blocks[b_idx];
            dim_t common_block = math::gcd(ab.block, bb.block);
            if (ab.block == common_block && bb.block == common_block) {
                // Blocks already agree; move to the next pair.
                a_idx++;
                b_idx++;
                continue;
            }

            // Split the mismatching block(s) down to the common size.
            if (ab.block != common_block) {
                a = a.split_block(
                        {a_idx, ab}, common_block, ab.block / common_block);
            }
            if (bb.block != common_block) {
                b = b.split_block(
                        {b_idx, bb}, common_block, bb.block / common_block);
            }
            break;
        }
    }
}
369
parse_letter_blocks(const std::string & format)370 std::vector<std::pair<char, dim_t>> layout_t::parse_letter_blocks(
371 const std::string &format) {
372 std::vector<std::pair<char, dim_t>> ret;
373
374 std::stringstream ss(format);
375 while (!ss.eof()) {
376 int next = ss.peek();
377 if (ss.eof()) break;
378 dim_t block = 0;
379 while (std::isdigit(next)) {
380 block = 10 * block + (next - '0');
381 ss.ignore(1);
382 next = ss.peek();
383 }
384 char letter = char(ss.peek());
385 ir_assert(!ss.eof()) << "EOF is unexpected.";
386 ss.ignore(1);
387 ret.emplace_back(letter, block);
388 }
389 return ret;
390 }
391
parse_format(const std::string & format,int ndims_hint)392 std::vector<std::pair<int, dim_t>> layout_t::parse_format(
393 const std::string &format, int ndims_hint) {
394 bool seen_letters[DNNL_MAX_NDIMS] = {};
395 int letter_ndims = 0;
396 for (char c = 'a'; c < 'a' + DNNL_MAX_NDIMS; c++) {
397 if (format.find(c) != std::string::npos) {
398 seen_letters[c - 'a'] = true;
399 MAYBE_UNUSED(seen_letters);
400 letter_ndims++;
401 }
402 }
403
404 for (int i = 0; i < DNNL_MAX_NDIMS; i++) {
405 ir_assert(seen_letters[i] == (i < letter_ndims));
406 }
407
408 auto letter_blocks = parse_letter_blocks(format);
409
410 std::vector<std::pair<int, dim_t>> parts;
411 for (auto &p : letter_blocks) {
412 char letter = p.first;
413 dim_t block = p.second;
414 if (letter != 'x') {
415 int dim_idx = std::tolower(letter) - 'a';
416 parts.emplace_back(dim_idx, block);
417 } else {
418 ir_assert(ndims_hint >= letter_ndims);
419 for (int i = letter_ndims; i < ndims_hint; i++) {
420 parts.emplace_back(i, 0);
421 }
422 }
423 }
424
425 return parts;
426 }
427
sanity_check() const428 void layout_t::sanity_check() const {
429 #ifdef NDEBUG
430 return;
431 #endif
432 if (is_empty()) return;
433
434 for (auto &b : blocks_) {
435 ir_assert(b.block > 0) << "Incorrect block size.";
436 MAYBE_UNUSED(b);
437 }
438 }
439
// Shared implementation for split_into_multi_blocks(_with_hint): consumes
// blocks innermost-first to build consecutive groups of the requested
// element counts, splitting a block when it overshoots a group. In hint mode
// (`out_multi_blocks` non-null) smaller groups are tolerated and the achieved
// sizes are reported; otherwise any shortfall yields an empty layout.
layout_t layout_t::split_into_multi_blocks_impl(
        const std::vector<dim_t> &multi_blocks,
        std::vector<dim_t> *out_multi_blocks) const {
    if (is_empty()) return *this;

    bool allow_smaller_blocks = bool(out_multi_blocks);
    layout_t tmp(*this);
    std::vector<dim_t> rem_elems = multi_blocks;
    std::vector<dim_t> cur_elems(rem_elems.size(), 1);
    for (auto &eb : tmp.enumerated_blocks()) {
        auto &b = eb.second;
        for (int i = 0; i < int(rem_elems.size()); i++) {
            auto &e = rem_elems[i];
            // Skip groups that are already complete.
            if (e == 1) continue;
            if (b.block > e) {
                // Try to split this block.
                int next_block = utils::max_div(b.block, e);
                if (next_block == 1) return layout_t();
                return tmp.split_block(eb, next_block, b.block / next_block)
                        .split_into_multi_blocks_impl(
                                multi_blocks, out_multi_blocks);
            }
            if (e % b.block != 0) {
                // Block doesn't divide the group evenly; only acceptable in
                // hint mode.
                if (!allow_smaller_blocks) return layout_t();
            }
            e /= b.block;
            cur_elems[i] *= b.block;
            break;
        }
    }
    // Verify each group reached its requested size (or report what was
    // achieved in hint mode).
    for (int i = 0; i < int(cur_elems.size()); i++) {
        if (cur_elems[i] != multi_blocks[i]) {
            if (!allow_smaller_blocks) return layout_t();
        }
        if (out_multi_blocks) (*out_multi_blocks)[i] = cur_elems[i];
    }
    return tmp;
}
478
// Pops a factor of `size` from the current grid dimension and returns the
// index expression selecting that factor. Once the current dimension is
// fully consumed, advances to the next (lower-index) dimension.
expr_t grid_splitter_t::pop_block(int size) {
    ir_assert(size > 1);
    ir_assert(can_pop_block(size));

    int new_stride = cur_stride_ * size;

    auto idx_expr = grid_.idx(cur_idx_);
    // Strip the already-consumed inner part of the index...
    if (cur_stride_ != 1) idx_expr /= cur_stride_;
    // ...and mask off the not-yet-consumed outer part.
    if (new_stride != grid_.dim(cur_idx_)) idx_expr %= size;

    cur_stride_ = new_stride;
    if (cur_stride_ == grid_.dim(cur_idx_)) {
        // Move to the next dimension.
        cur_idx_--;
        skip_size_1_dims();
        cur_stride_ = 1;
    }
    return idx_expr;
}
498
compute_stride(const expr_t & e,int idx,const expr_t & var)499 stride_t tdim_info_t::compute_stride(
500 const expr_t &e, int idx, const expr_t &var) {
501 // e == var -> fixed stride.
502 if (e.is_same(var)) return stride_t(1);
503
504 auto vars = find_objects<var_t>(e);
505
506 auto e0 = e;
507 auto e1 = substitute(e, var, var + 1);
508 auto e_stride = simplify(e1 - e0);
509
510 if (is_const(e_stride)) return stride_t(to_cpp<dim_t>(e_stride));
511
512 // Stride is not a constant.
513 return stride_t::unknown();
514 }
515
create_sub_view(const tensor_t & sub_tensor) const516 view_t view_t::create_sub_view(const tensor_t &sub_tensor) const {
517 ir_assert(sub_tensor.ndims() == nvdims()) << "Dimensions don't match.";
518
519 auto ret = *this;
520 ret.vdims_ = sub_tensor.dims();
521 for (int i = 0; i < nvdims(); i++) {
522 auto &i_start = sub_tensor.start()[i];
523 if (is_zero(i_start)) continue;
524 auto &s = ret.vstart_[i];
525 s += i_start;
526 s = simplify(s);
527 }
528 return ret;
529 }
530
substitute(const expr_t & from,const expr_t & to) const531 view_t view_t::substitute(const expr_t &from, const expr_t &to) const {
532 view_t ret = *this;
533 for (int i = 0; i < nvdims(); i++) {
534 ret.vstart_[i] = jit::substitute(ret.vstart_[i], from, to);
535 ret.vstart_[i] = simplify(ret.vstart_[i]);
536 }
537 return ret;
538 }
539
create_vvars(int nvdims)540 std::vector<expr_t> view_t::create_vvars(int nvdims) {
541 static const int max_nvdims = 128;
542 static std::vector<expr_t> _vvars;
543 static std::once_flag initialized;
544 std::call_once(initialized, [&]() {
545 for (int i = 0; i < max_nvdims; i++)
546 _vvars.push_back(
547 var_t::make(type_t::s32(), "_" + std::to_string(i)));
548 });
549
550 ir_assert(nvdims <= max_nvdims) << "Too many dimensions: " << nvdims;
551 return std::vector<expr_t>(_vvars.begin(), _vvars.begin() + nvdims);
552 }
553
// Builds a pseudo-layout over the view dimensions from the tensor layout:
// each tensor block is re-expressed in terms of the view variables mapped
// through tdims_. Non-identity tdim expressions are only allowed in the
// outermost blocks.
layout_t view_t::create_pseudo_vlayout(const layout_t &tlayout) const {
    ir_assert(!tlayout.is_empty());

    std::vector<dim_t> rem_vdims = vdims_;
    std::vector<block_t> blocks;

    for (auto &teb : tlayout.enumerated_blocks()) {
        block_t &tb = teb.second;
        bool tb_is_outermost = tlayout.is_outermost(teb);
        dim_t tblock = tb.block;

        auto &tinfo = tdims_[tb.dim_idx];
        if (tb_is_outermost) {
            // Outermost block: emit one block per view variable that feeds
            // this tensor dimension, innermost variable first.
            bool is_first = true;
            for (int i = tinfo.nvargs() - 1; i >= 0; i--) {
                int vidx = tinfo.vidx(i);
                if (rem_vdims[vidx] == 1) continue;

                // When expression contains 2+ variables, use unknown
                // stride unless the view variable is the innermost.
                stride_t stride
                        = (is_first ? tinfo.vstride(i) : stride_t::unknown());
                blocks.emplace_back(
                        vidx, rem_vdims[vidx], stride * stride_t(tb.stride));
                rem_vdims[vidx] = 1;
                is_first = false;
            }
            continue;
        }

        // Inner blocks require a direct 1:1 tdim -> vdim mapping.
        ir_assert(tinfo.is_identity()) << "Can't create pseudo-layout.";

        int vidx = tinfo.vidx(0);
        dim_t &rem_vdim = rem_vdims[vidx];
        if (rem_vdim == 1) continue;

        // NOTE(review): tb_is_outermost is always false here (the outermost
        // case `continue`s above), so this branch looks unreachable — confirm
        // before removing.
        if (tb_is_outermost) {
            tblock = rem_vdim;
            rem_vdim = 1;
        } else if (rem_vdim % tblock == 0) {
            rem_vdim /= tblock;
        } else if (rem_vdim % tblock != 0) {
            // Try to split the current block and start from scratch.
            if (tblock % rem_vdim == 0) {
                auto tmp_layout
                        = tlayout.split_block(teb, rem_vdim, tblock / rem_vdim);
                return create_pseudo_vlayout(tmp_layout);
            }

            ir_error_not_expected() << "Can't create pseudo-layout.";
        }
        blocks.emplace_back(tb.dim_idx, tblock, tb.stride);
    }

    // Every view dimension must be fully covered.
    for (auto &d : rem_vdims) {
        ir_assert(d == 1) << "Can't create pseudo-layout.";
        MAYBE_UNUSED(d);
    }

    return layout_t(tlayout.type(), nvdims(), 0, blocks);
}
615
// Remaps the dimension indices of `layout` through assignments_; blocks of
// dimensions assigned to -1 are dropped. The remapping must preserve the
// total number of elements.
layout_t dim_assignment_t::map(const layout_t &layout) const {
    std::vector<block_t> mapped;
    for (auto &blk : layout.blocks()) {
        int idx = assignments_[blk.dim_idx];
        if (idx == -1) continue; // Drop this block.
        block_t remapped = blk;
        remapped.dim_idx = idx;
        mapped.push_back(remapped);
    }
    mapped = layout_t::normalize_blocks(new_ndims(), mapped,
            /*remove_size_1_blocks=*/false);
    layout_t ret(layout.type(), new_ndims(), layout.offset(), mapped,
            /*do_normalize=*/false);
    ir_assert(layout.elems() == ret.elems())
            << "Assignment doesn't preserve number of elements.";
    return ret;
}
633
634 // Adds size one spatial dimensions according to input parameters. Spatial
635 // dimensions are assumed to be the last dimensions.
layout_t normalize_conv_spatial(
        const layout_t &layout, int old_sp_ndims, bool reduced_to_1d) {
    int old_ndims = layout.ndims();
    int new_ndims = old_ndims - old_sp_ndims + 3;
    int non_sp_ndims = old_ndims - old_sp_ndims;

    dim_assignment_t to_3d(old_ndims, new_ndims);
    for (int i = 0; i < old_ndims; i++) {
        if (i < non_sp_ndims) {
            // Non-spatial dimensions keep their positions.
            to_3d.assign(i, i);
        } else {
            // Spatial dimensions are right-aligned in the 3 spatial slots
            // (D, H, W); with reduced_to_1d everything collapses into W.
            int sp_idx = reduced_to_1d ? 2 : 3 - (old_ndims - i);
            to_3d.assign(i, new_ndims - (3 - sp_idx));
        }
    }
    return to_3d.map(layout);
}
655
// Inserts a new size-1 dimension at position `dim_idx`, shifting the indices
// of all dimensions at or above that position by one.
layout_t insert_dimension(const layout_t &layout, int dim_idx) {
    auto shifted_blocks = layout.blocks();
    for (auto &blk : shifted_blocks)
        if (blk.dim_idx >= dim_idx) blk.dim_idx++;
    return layout_t(layout.type(), layout.ndims() + 1, layout.offset(),
            shifted_blocks,
            /*do_normalize=*/false);
}
665
// Removes dimension `dim_idx` (which must have size 1), shifting the indices
// of all higher dimensions down by one.
layout_t remove_size_1_dimension(const layout_t &layout, int dim_idx) {
    ir_assert(0 <= dim_idx && dim_idx < layout.ndims());
    ir_assert(layout.dim(dim_idx) == 1);
    dim_assignment_t a(layout.ndims(), layout.ndims() - 1);
    for (int i = 0; i < layout.ndims(); i++) {
        if (i == dim_idx) continue;
        a.assign(i, (i < dim_idx) ? i : i - 1);
    }
    return a.map(layout);
}
676
// Splits dimension `dim_idx` into two: the result has one extra dimension
// where index `dim_idx` holds the outer part of size `outer_block` and
// index `dim_idx + 1` holds the inner remainder.
layout_t split_dimension(
        const layout_t &_layout, int dim_idx, int outer_block) {
    int rem_inner_block
            = ir_utils::safe_divide(_layout.dim(dim_idx), outer_block);
    // After insertion the original dimension lives at dim_idx + 1.
    auto layout = insert_dimension(_layout, dim_idx);
    std::vector<block_t> new_blocks;
    for (auto &eb : layout.enumerated_blocks()) {
        auto &b = eb.second;
        if (b.dim_idx != dim_idx + 1) {
            new_blocks.push_back(b);
            continue;
        }
        if (b.block % rem_inner_block == 0) {
            // This block straddles the inner/outer boundary: split it into
            // the inner remainder and the outer part.
            new_blocks.emplace_back(dim_idx + 1, rem_inner_block, b.stride);
            new_blocks.emplace_back(dim_idx, b.block / rem_inner_block,
                    dim_t(b.stride) * rem_inner_block);
            rem_inner_block = 1;
        } else {
            // The whole block belongs to the inner part.
            new_blocks.push_back(b);
            rem_inner_block = ir_utils::safe_divide(rem_inner_block, b.block);
        }
    }

    // Remove inner blocks with size one.
    std::vector<block_t> _new_blocks;
    std::vector<bool> seen(layout.ndims());
    for (auto it = new_blocks.rbegin(); it != new_blocks.rend(); ++it) {
        if (it->block == 1 && seen[it->dim_idx]) continue;
        _new_blocks.push_back(*it);
        seen[it->dim_idx] = true;
    }
    std::reverse(_new_blocks.begin(), _new_blocks.end());
    return layout_t(layout.type(), layout.ndims(), layout.offset(), _new_blocks,
            /*do_normalize=*/false);
}
712
layout_t normalize_conv_groups(const layout_t &layout, bool with_groups,
        int groups, bool is_dw, bool add_groups, bool is_wei) {
    // Nothing to do when the groups presence already matches the request.
    if (with_groups == add_groups) return layout;

    if (is_wei) {
        // Weights: add or drop the leading groups dimension.
        ir_assert(groups == 1)
                << "Adding/removing groups can be done only for single group.";
        return add_groups ? insert_dimension(layout, 0)
                          : remove_size_1_dimension(layout, 0);
    }

    // Source/destination: split the channel dimension into groups.
    ir_assert(!with_groups) << "Unexpected groups in source/destination.";
    if (is_dw) groups = layout.dim(1);
    return split_dimension(layout, /*dim_idx=*/1, groups);
}
727
layout_t normalize_conv_layout(const layout_t &_layout, bool with_groups,
        int groups, bool is_dw, bool reduced_to_1d, bool add_groups,
        bool is_wei) {
    // A grouped layout carries one extra non-spatial dimension.
    int old_sp_ndims = _layout.ndims() - (with_groups ? 3 : 2);

    // First canonicalize spatial dims to 3D, then fix up groups.
    auto ret = normalize_conv_spatial(_layout, old_sp_ndims, reduced_to_1d);
    return normalize_conv_groups(
            ret, with_groups, groups, is_dw, add_groups, is_wei);
}
740
normalize_conv_dims(std::vector<dim_t> & dims,bool with_groups,int groups,bool is_dw,bool reduced_to_1d,bool add_groups,bool is_wei)741 std::vector<dim_t> normalize_conv_dims(std::vector<dim_t> &dims,
742 bool with_groups, int groups, bool is_dw, bool reduced_to_1d,
743 bool add_groups, bool is_wei) {
744 layout_t dummy_layout(type_t::u8(), 0, dims);
745 return normalize_conv_layout(dummy_layout, with_groups, groups, is_dw,
746 reduced_to_1d, add_groups, is_wei)
747 .dims();
748 }
749
normalize_conv_layouts(layout_t & src_layout,layout_t & wei_layout,layout_t & dst_layout,bool with_groups,int groups,bool is_dw,bool reduced_to_1d,bool add_groups)750 void normalize_conv_layouts(layout_t &src_layout, layout_t &wei_layout,
751 layout_t &dst_layout, bool with_groups, int groups, bool is_dw,
752 bool reduced_to_1d, bool add_groups) {
753 src_layout = normalize_conv_layout(src_layout, /*with_groups=*/false,
754 groups, is_dw, reduced_to_1d, add_groups, /*is_wei=*/false);
755 wei_layout = normalize_conv_layout(wei_layout, with_groups, groups, is_dw,
756 reduced_to_1d, add_groups, /*is_wei=*/true);
757 dst_layout = normalize_conv_layout(dst_layout, /*with_groups=*/false,
758 groups, is_dw, reduced_to_1d, add_groups, /*is_wei=*/false);
759 }
760
761 } // namespace jit
762 } // namespace gpu
763 } // namespace impl
764 } // namespace dnnl
765