1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include <cctype>
18 #include <sstream>
19 #include <thread>
20 
21 #include "gpu/jit/conv/tensor.hpp"
22 
23 namespace dnnl {
24 namespace impl {
25 namespace gpu {
26 namespace jit {
27 
layout_t(const type_t & type,const expr_t & offset,const std::string & format,const std::vector<dim_t> & dims,bool do_normalize)28 layout_t::layout_t(const type_t &type, const expr_t &offset,
29         const std::string &format, const std::vector<dim_t> &dims,
30         bool do_normalize)
31     : type_(type), offset_(offset) {
32     auto parts = parse_format(format, int(dims.size()));
33     ndims_ = 0;
34     for (auto &p : parts) {
35         int dim_idx = p.first;
36         dim_t block = p.second;
37         ndims_ = std::max(ndims_, dim_idx + 1);
38         if (block == 0 && dims.empty())
39             ir_error_not_expected()
40                     << "Dimensions are missing. Can't deduce them from "
41                        "the format.";
42     }
43     if (!dims.empty() && ndims_ != int(dims.size())) {
44         ir_error_not_expected() << "Format and dimensions do not match.";
45     }
46 
47     dim_t stride = 1;
48     // Iterate from right to left (innermost to outermost).
49     for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
50         int dim_idx = it->first;
51         dim_t block = it->second;
52         if (block == 0) {
53             dim_t full_block = 1;
54             for (auto &b : blocks_)
55                 if (b.dim_idx == dim_idx) full_block *= b.block;
56 
57             block = utils::div_up(dims[dim_idx], full_block);
58         }
59 
60         blocks_.emplace_back(dim_idx, block, stride);
61         stride = block * stride;
62     }
63 
64     if (do_normalize) blocks_ = normalize_blocks(ndims_, blocks_);
65     sanity_check();
66 }
67 
layout_t(const memory_desc_wrapper & mdw,bool do_normalize)68 layout_t::layout_t(const memory_desc_wrapper &mdw, bool do_normalize)
69     : type_(mdw.data_type()), offset_(mdw.offset0()) {
70     ir_assert(mdw.is_blocking_desc()) << "Expected blocking memory descriptor.";
71 
72     ndims_ = mdw.ndims();
73     auto &blocking = mdw.blocking_desc();
74     auto *padded_dims = mdw.padded_dims();
75 
76     dim_t stride = 1;
77     std::vector<dim_t> full_blocks(ndims_, 1);
78     for (int i = blocking.inner_nblks - 1; i >= 0; i--) {
79         int dim_idx = blocking.inner_idxs[i];
80         dim_t block = blocking.inner_blks[i];
81         blocks_.emplace_back(dim_idx, block, stride);
82         stride *= block;
83         full_blocks[dim_idx] *= block;
84     }
85 
86     for (int i = 0; i < ndims_; i++) {
87         dim_t block = padded_dims[i] / full_blocks[i];
88         blocks_.emplace_back(i, block, blocking.strides[i]);
89     }
90 
91     // Sort outer blocks by their stride.
92     std::sort(blocks_.begin() + blocking.inner_nblks, blocks_.end(),
93             [](const block_t &a, const block_t &b) {
94                 if (a.stride == b.stride) return a.dim_idx > b.dim_idx;
95                 return a.stride < b.stride;
96             });
97 
98     if (do_normalize) blocks_ = normalize_blocks(ndims_, blocks_);
99     sanity_check();
100 }
101 
to_dnnl(const dim_t * dims_hint) const102 memory_desc_t layout_t::to_dnnl(const dim_t *dims_hint) const {
103     memory_desc_t md = {};
104     md.ndims = ndims();
105     std::copy(dims_hint, dims_hint + ndims(), md.dims);
106     md.data_type = jit::to_dnnl(type_);
107     md.offset0 = to_cpp<dim_t>(offset_);
108     md.format_kind = format_kind::blocked;
109 
110     auto &blk = md.format_desc.blocking;
111     bool seen[DNNL_MAX_NDIMS] = {};
112 
113     bool in_inner_block = false;
114     dim_t prev_stride = 0;
115 
116     for (auto it = blocks_.rbegin(); it != blocks_.rend(); ++it) {
117         auto &b = *it;
118         if (!seen[b.dim_idx]) {
119             // Outer block.
120             ir_assert(!in_inner_block);
121             MAYBE_UNUSED(in_inner_block);
122             blk.strides[b.dim_idx] = b.stride;
123             md.padded_dims[b.dim_idx] = b.block;
124         } else {
125             // Inner block.
126             md.padded_dims[b.dim_idx] *= b.block;
127             blk.inner_idxs[blk.inner_nblks] = b.dim_idx;
128             blk.inner_blks[blk.inner_nblks] = b.block;
129             blk.inner_nblks++;
130             if (prev_stride > 0) {
131                 // Inner block must be dense.
132                 ir_assert(prev_stride == b.block * b.stride);
133             }
134             prev_stride = b.stride;
135             in_inner_block = true;
136         }
137         seen[b.dim_idx] = true;
138     }
139 
140     return md;
141 }
142 
map(const tensor_t & tensor) const143 layout_t layout_t::map(const tensor_t &tensor) const {
144     if (ndims() != tensor.ndims())
145         ir_error_not_expected() << "Dimensions do not match.";
146 
147     std::vector<dim_t> remaining_dims = tensor.dims();
148     std::vector<block_t> mapped_blocks;
149 
150     for (auto &eb : enumerated_blocks()) {
151         block_t &b = eb.second;
152         bool b_is_outermost = is_outermost(eb);
153 
154         dim_t block = b.block;
155         dim_t &rem_dim = remaining_dims[b.dim_idx];
156         if (rem_dim == 1) {
157             if (b_is_outermost) {
158                 // This is to have similarity between the current and
159                 // mapped layouts.
160                 mapped_blocks.emplace_back(b.dim_idx, 1, b.stride);
161             }
162             continue;
163         }
164         if (b_is_outermost) {
165             block = rem_dim;
166         } else if (rem_dim % block != 0) {
167             // Try to split the current block and start mapping from
168             // scratch.
169             if (block % rem_dim == 0)
170                 return split_block(eb, rem_dim, block / rem_dim).map(tensor);
171 
172             ir_error_not_expected() << "Can't map tensor layout.";
173         }
174         rem_dim /= block;
175         mapped_blocks.emplace_back(b.dim_idx, block, b.stride);
176     }
177 
178     for (auto &d : remaining_dims) {
179         ir_assert(d == 1) << "Can't map tensor layout.";
180         MAYBE_UNUSED(d);
181     }
182 
183     return layout_t(type(), ndims(), operator()(tensor.start()), mapped_blocks);
184 }
185 
reinterpret(const type_t & new_type,bool do_normalize) const186 layout_t layout_t::reinterpret(
187         const type_t &new_type, bool do_normalize) const {
188     int old_size = type().size();
189     int new_size = new_type.size();
190     if (new_size == old_size) return *this;
191 
192     expr_t new_offset = 0;
193     if (!has_zero_offset()) {
194         ir_assert(is_const(offset_)) << "Expected constant offset.";
195         int64_t off = to_cpp<int64_t>(offset_) * old_size;
196         ir_assert(off % new_size == 0);
197         new_offset = off / new_size;
198     }
199 
200     if (old_size % new_size != 0 && new_size % old_size != 0) {
201         ir_error_not_expected();
202         return layout_t();
203     }
204 
205     auto new_blocks = blocks_;
206     if (new_blocks.empty()) {
207         ir_error_not_expected() << "Can't reinterpret.";
208         return layout_t();
209     }
210 
211     if (new_size < old_size) {
212         int factor = (old_size / new_size);
213         auto &b0 = new_blocks.front();
214         b0.block *= factor;
215         // Recompute strides.
216         for (auto &b : new_blocks) {
217             if (&b == &b0) continue;
218             b.stride *= factor;
219         }
220     } else {
221         int factor = (new_size / old_size);
222         auto &b0 = new_blocks.front();
223         if (b0.block % factor != 0) {
224             ir_error_not_expected();
225             return layout_t();
226         }
227         b0.block /= factor;
228         // Recompute strides.
229         for (auto &b : new_blocks) {
230             if (&b == &b0) continue;
231             if (b.stride % factor != 0) {
232                 ir_error_not_expected();
233                 return layout_t();
234             }
235             b.stride /= factor;
236         }
237     }
238 
239     return layout_t(new_type, ndims(), new_offset, new_blocks, do_normalize);
240 }
241 
split_block(const std::pair<int,block_t> & eb,dim_t block0,dim_t block1) const242 layout_t layout_t::split_block(
243         const std::pair<int, block_t> &eb, dim_t block0, dim_t block1) const {
244     int block_idx = eb.first;
245     auto &b = eb.second;
246     ir_assert(b.block == block0 * block1) << "Incompatible block sizes.";
247     MAYBE_UNUSED(b);
248 
249     auto new_blocks = blocks_;
250 
251     block_t &b0 = new_blocks[block_idx];
252     block_t b1 = b0;
253 
254     b0.block = block0;
255     b1.block = block1;
256     b1.stride = b0.stride * block0;
257 
258     new_blocks.insert(new_blocks.begin() + block_idx + 1, b1);
259 
260     return layout_t(
261             type(), ndims(), offset(), new_blocks, /*do_normalize=*/false);
262 }
263 
split_into_multi_blocks(const std::vector<dim_t> & multi_blocks) const264 layout_t layout_t::split_into_multi_blocks(
265         const std::vector<dim_t> &multi_blocks) const {
266     return split_into_multi_blocks_impl(multi_blocks, nullptr);
267 }
268 
split_into_multi_blocks_with_hint(std::vector<dim_t> & multi_blocks) const269 layout_t layout_t::split_into_multi_blocks_with_hint(
270         std::vector<dim_t> &multi_blocks) const {
271     return split_into_multi_blocks_impl(multi_blocks, &multi_blocks);
272 }
273 
split_into_dense_tile(dim_t tile_elems,dim_t outer_block) const274 tensor_t layout_t::split_into_dense_tile(
275         dim_t tile_elems, dim_t outer_block) const {
276     stride_t dense_stride = 1;
277     dim_t cur_tile_elems = 1;
278     dim_t cur_outer_block = 1;
279     bool in_tile = (tile_elems != 1);
280     std::vector<dim_t> tile_dims(ndims(), 1);
281     for (auto &b : blocks()) {
282         if (b.block == 1) continue;
283         if (in_tile) {
284             if (b.stride.is_unknown()) return tensor_t();
285             if (dense_stride != b.stride) return tensor_t();
286             dense_stride = b.block * b.stride;
287             cur_tile_elems *= b.block;
288             tile_dims[b.dim_idx] *= b.block;
289             ir_assert(cur_tile_elems <= tile_elems);
290             if (cur_tile_elems == tile_elems) in_tile = false;
291         } else {
292             if (outer_block == 1) break;
293             cur_outer_block *= b.block;
294             tile_dims[b.dim_idx] *= b.block;
295             ir_assert(cur_outer_block <= outer_block);
296             if (cur_outer_block == outer_block) break;
297         }
298     }
299     if (cur_tile_elems != tile_elems) return tensor_t();
300     if (cur_outer_block != outer_block) return tensor_t();
301     return tensor_t(tile_dims);
302 }
303 
split_into_max_tile(dim_t max_tile_elems,bool is_dense_tile) const304 tensor_t layout_t::split_into_max_tile(
305         dim_t max_tile_elems, bool is_dense_tile) const {
306     stride_t dense_stride = 1;
307     std::vector<dim_t> tile_dims(ndims(), 1);
308     dim_t cur_elems = 1;
309     for (auto &eb : enumerated_blocks()) {
310         auto &b = eb.second;
311         if (b.block == 1) continue;
312         if (b.block * cur_elems <= max_tile_elems) {
313             if (is_dense_tile) {
314                 if (b.stride.is_unknown()) break;
315                 if (dense_stride != b.stride) break;
316                 dense_stride = b.block * b.stride;
317             }
318             cur_elems *= b.block;
319             tile_dims[b.dim_idx] *= b.block;
320             continue;
321         }
322         dim_t max_block = utils::max_div(b.block, max_tile_elems / cur_elems);
323         if (max_block == 1) break;
324         auto tmp_layout = split_block(eb, max_block, b.block / max_block);
325         return tmp_layout.split_into_max_tile(max_tile_elems, is_dense_tile);
326     }
327     return tensor_t(tile_dims);
328 }
329 
align_layouts(layout_t & a,layout_t & b)330 void layout_t::align_layouts(layout_t &a, layout_t &b) {
331     for (int i = 0; i < a.ndims(); i++) {
332         auto a_blocks = a.blocks();
333         auto b_blocks = b.blocks();
334 
335         int a_max = int(a_blocks.size());
336         int b_max = int(b_blocks.size());
337         int a_idx = 0;
338         int b_idx = 0;
339 
340         for (;;) {
341             while (a_idx < a_max && a_blocks[a_idx].dim_idx != i)
342                 a_idx++;
343             while (b_idx < b_max && b_blocks[b_idx].dim_idx != i)
344                 b_idx++;
345 
346             if (a_idx >= a_max || b_idx >= b_max) break;
347 
348             auto &ab = a_blocks[a_idx];
349             auto &bb = b_blocks[b_idx];
350             dim_t common_block = math::gcd(ab.block, bb.block);
351             if (ab.block == common_block && bb.block == common_block) {
352                 a_idx++;
353                 b_idx++;
354                 continue;
355             }
356 
357             if (ab.block != common_block) {
358                 a = a.split_block(
359                         {a_idx, ab}, common_block, ab.block / common_block);
360             }
361             if (bb.block != common_block) {
362                 b = b.split_block(
363                         {b_idx, bb}, common_block, bb.block / common_block);
364             }
365             break;
366         }
367     }
368 }
369 
parse_letter_blocks(const std::string & format)370 std::vector<std::pair<char, dim_t>> layout_t::parse_letter_blocks(
371         const std::string &format) {
372     std::vector<std::pair<char, dim_t>> ret;
373 
374     std::stringstream ss(format);
375     while (!ss.eof()) {
376         int next = ss.peek();
377         if (ss.eof()) break;
378         dim_t block = 0;
379         while (std::isdigit(next)) {
380             block = 10 * block + (next - '0');
381             ss.ignore(1);
382             next = ss.peek();
383         }
384         char letter = char(ss.peek());
385         ir_assert(!ss.eof()) << "EOF is unexpected.";
386         ss.ignore(1);
387         ret.emplace_back(letter, block);
388     }
389     return ret;
390 }
391 
parse_format(const std::string & format,int ndims_hint)392 std::vector<std::pair<int, dim_t>> layout_t::parse_format(
393         const std::string &format, int ndims_hint) {
394     bool seen_letters[DNNL_MAX_NDIMS] = {};
395     int letter_ndims = 0;
396     for (char c = 'a'; c < 'a' + DNNL_MAX_NDIMS; c++) {
397         if (format.find(c) != std::string::npos) {
398             seen_letters[c - 'a'] = true;
399             MAYBE_UNUSED(seen_letters);
400             letter_ndims++;
401         }
402     }
403 
404     for (int i = 0; i < DNNL_MAX_NDIMS; i++) {
405         ir_assert(seen_letters[i] == (i < letter_ndims));
406     }
407 
408     auto letter_blocks = parse_letter_blocks(format);
409 
410     std::vector<std::pair<int, dim_t>> parts;
411     for (auto &p : letter_blocks) {
412         char letter = p.first;
413         dim_t block = p.second;
414         if (letter != 'x') {
415             int dim_idx = std::tolower(letter) - 'a';
416             parts.emplace_back(dim_idx, block);
417         } else {
418             ir_assert(ndims_hint >= letter_ndims);
419             for (int i = letter_ndims; i < ndims_hint; i++) {
420                 parts.emplace_back(i, 0);
421             }
422         }
423     }
424 
425     return parts;
426 }
427 
sanity_check() const428 void layout_t::sanity_check() const {
429 #ifdef NDEBUG
430     return;
431 #endif
432     if (is_empty()) return;
433 
434     for (auto &b : blocks_) {
435         ir_assert(b.block > 0) << "Incorrect block size.";
436         MAYBE_UNUSED(b);
437     }
438 }
439 
split_into_multi_blocks_impl(const std::vector<dim_t> & multi_blocks,std::vector<dim_t> * out_multi_blocks) const440 layout_t layout_t::split_into_multi_blocks_impl(
441         const std::vector<dim_t> &multi_blocks,
442         std::vector<dim_t> *out_multi_blocks) const {
443     if (is_empty()) return *this;
444 
445     bool allow_smaller_blocks = bool(out_multi_blocks);
446     layout_t tmp(*this);
447     std::vector<dim_t> rem_elems = multi_blocks;
448     std::vector<dim_t> cur_elems(rem_elems.size(), 1);
449     for (auto &eb : tmp.enumerated_blocks()) {
450         auto &b = eb.second;
451         for (int i = 0; i < int(rem_elems.size()); i++) {
452             auto &e = rem_elems[i];
453             if (e == 1) continue;
454             if (b.block > e) {
455                 // Try to split this block.
456                 int next_block = utils::max_div(b.block, e);
457                 if (next_block == 1) return layout_t();
458                 return tmp.split_block(eb, next_block, b.block / next_block)
459                         .split_into_multi_blocks_impl(
460                                 multi_blocks, out_multi_blocks);
461             }
462             if (e % b.block != 0) {
463                 if (!allow_smaller_blocks) return layout_t();
464             }
465             e /= b.block;
466             cur_elems[i] *= b.block;
467             break;
468         }
469     }
470     for (int i = 0; i < int(cur_elems.size()); i++) {
471         if (cur_elems[i] != multi_blocks[i]) {
472             if (!allow_smaller_blocks) return layout_t();
473         }
474         if (out_multi_blocks) (*out_multi_blocks)[i] = cur_elems[i];
475     }
476     return tmp;
477 }
478 
pop_block(int size)479 expr_t grid_splitter_t::pop_block(int size) {
480     ir_assert(size > 1);
481     ir_assert(can_pop_block(size));
482 
483     int new_stride = cur_stride_ * size;
484 
485     auto idx_expr = grid_.idx(cur_idx_);
486     if (cur_stride_ != 1) idx_expr /= cur_stride_;
487     if (new_stride != grid_.dim(cur_idx_)) idx_expr %= size;
488 
489     cur_stride_ = new_stride;
490     if (cur_stride_ == grid_.dim(cur_idx_)) {
491         // Move to the next dimension.
492         cur_idx_--;
493         skip_size_1_dims();
494         cur_stride_ = 1;
495     }
496     return idx_expr;
497 }
498 
compute_stride(const expr_t & e,int idx,const expr_t & var)499 stride_t tdim_info_t::compute_stride(
500         const expr_t &e, int idx, const expr_t &var) {
501     // e == var -> fixed stride.
502     if (e.is_same(var)) return stride_t(1);
503 
504     auto vars = find_objects<var_t>(e);
505 
506     auto e0 = e;
507     auto e1 = substitute(e, var, var + 1);
508     auto e_stride = simplify(e1 - e0);
509 
510     if (is_const(e_stride)) return stride_t(to_cpp<dim_t>(e_stride));
511 
512     // Stride is not a constant.
513     return stride_t::unknown();
514 }
515 
create_sub_view(const tensor_t & sub_tensor) const516 view_t view_t::create_sub_view(const tensor_t &sub_tensor) const {
517     ir_assert(sub_tensor.ndims() == nvdims()) << "Dimensions don't match.";
518 
519     auto ret = *this;
520     ret.vdims_ = sub_tensor.dims();
521     for (int i = 0; i < nvdims(); i++) {
522         auto &i_start = sub_tensor.start()[i];
523         if (is_zero(i_start)) continue;
524         auto &s = ret.vstart_[i];
525         s += i_start;
526         s = simplify(s);
527     }
528     return ret;
529 }
530 
substitute(const expr_t & from,const expr_t & to) const531 view_t view_t::substitute(const expr_t &from, const expr_t &to) const {
532     view_t ret = *this;
533     for (int i = 0; i < nvdims(); i++) {
534         ret.vstart_[i] = jit::substitute(ret.vstart_[i], from, to);
535         ret.vstart_[i] = simplify(ret.vstart_[i]);
536     }
537     return ret;
538 }
539 
create_vvars(int nvdims)540 std::vector<expr_t> view_t::create_vvars(int nvdims) {
541     static const int max_nvdims = 128;
542     static std::vector<expr_t> _vvars;
543     static std::once_flag initialized;
544     std::call_once(initialized, [&]() {
545         for (int i = 0; i < max_nvdims; i++)
546             _vvars.push_back(
547                     var_t::make(type_t::s32(), "_" + std::to_string(i)));
548     });
549 
550     ir_assert(nvdims <= max_nvdims) << "Too many dimensions: " << nvdims;
551     return std::vector<expr_t>(_vvars.begin(), _vvars.begin() + nvdims);
552 }
553 
create_pseudo_vlayout(const layout_t & tlayout) const554 layout_t view_t::create_pseudo_vlayout(const layout_t &tlayout) const {
555     ir_assert(!tlayout.is_empty());
556 
557     std::vector<dim_t> rem_vdims = vdims_;
558     std::vector<block_t> blocks;
559 
560     for (auto &teb : tlayout.enumerated_blocks()) {
561         block_t &tb = teb.second;
562         bool tb_is_outermost = tlayout.is_outermost(teb);
563         dim_t tblock = tb.block;
564 
565         auto &tinfo = tdims_[tb.dim_idx];
566         if (tb_is_outermost) {
567             bool is_first = true;
568             for (int i = tinfo.nvargs() - 1; i >= 0; i--) {
569                 int vidx = tinfo.vidx(i);
570                 if (rem_vdims[vidx] == 1) continue;
571 
572                 // When expression contains 2+ variables, use unknown
573                 // stride unless the view variable is the innermost.
574                 stride_t stride
575                         = (is_first ? tinfo.vstride(i) : stride_t::unknown());
576                 blocks.emplace_back(
577                         vidx, rem_vdims[vidx], stride * stride_t(tb.stride));
578                 rem_vdims[vidx] = 1;
579                 is_first = false;
580             }
581             continue;
582         }
583 
584         ir_assert(tinfo.is_identity()) << "Can't create pseudo-layout.";
585 
586         int vidx = tinfo.vidx(0);
587         dim_t &rem_vdim = rem_vdims[vidx];
588         if (rem_vdim == 1) continue;
589 
590         if (tb_is_outermost) {
591             tblock = rem_vdim;
592             rem_vdim = 1;
593         } else if (rem_vdim % tblock == 0) {
594             rem_vdim /= tblock;
595         } else if (rem_vdim % tblock != 0) {
596             // Try to split the current block and start from scratch.
597             if (tblock % rem_vdim == 0) {
598                 auto tmp_layout
599                         = tlayout.split_block(teb, rem_vdim, tblock / rem_vdim);
600                 return create_pseudo_vlayout(tmp_layout);
601             }
602 
603             ir_error_not_expected() << "Can't create pseudo-layout.";
604         }
605         blocks.emplace_back(tb.dim_idx, tblock, tb.stride);
606     }
607 
608     for (auto &d : rem_vdims) {
609         ir_assert(d == 1) << "Can't create pseudo-layout.";
610         MAYBE_UNUSED(d);
611     }
612 
613     return layout_t(tlayout.type(), nvdims(), 0, blocks);
614 }
615 
// Applies this dimension assignment to `layout`:
// - each block is relabelled with its new dimension index;
// - blocks of dimensions assigned to -1 are dropped;
// - the result is normalized, keeping size-1 blocks.
// Asserts that the element count is preserved.
layout_t dim_assignment_t::map(const layout_t &layout) const {
    std::vector<block_t> mapped;
    for (const auto &blk : layout.blocks()) {
        const int to = assignments_[blk.dim_idx];
        if (to == -1) continue; // Drop this block.
        mapped.push_back(blk);
        mapped.back().dim_idx = to;
    }
    mapped = layout_t::normalize_blocks(new_ndims(), mapped,
            /*remove_size_1_blocks=*/false);
    layout_t ret(layout.type(), new_ndims(), layout.offset(), mapped,
            /*do_normalize=*/false);
    ir_assert(layout.elems() == ret.elems())
            << "Assignment doesn't preserve number of elements.";
    return ret;
}
633 
634 // Adds size one spatial dimensions according to input parameters. Spatial
635 // dimensions are assumed to be the last dimensions.
layout_t normalize_conv_spatial(
        const layout_t &layout, int old_sp_ndims, bool reduced_to_1d) {
    // Leading (non-spatial) dimensions keep their index. The trailing
    // `old_sp_ndims` dimensions are right-aligned onto the last three
    // dimensions of the result; with `reduced_to_1d` they all go to the
    // last one.
    const int old_ndims = layout.ndims();
    const int n_non_sp = old_ndims - old_sp_ndims;
    const int new_ndims = n_non_sp + 3;

    dim_assignment_t to_3d(old_ndims, new_ndims);
    for (int i = 0; i < n_non_sp; i++)
        to_3d.assign(i, i);
    for (int i = n_non_sp; i < old_ndims; i++) {
        const int new_idx
                = reduced_to_1d ? new_ndims - 1 : new_ndims - (old_ndims - i);
        to_3d.assign(i, new_idx);
    }
    return to_3d.map(layout);
}
655 
// Returns a copy of `layout` with a new dimension inserted at `dim_idx`.
// Blocks of dimensions at or after `dim_idx` are shifted up by one; no block
// is added for the new dimension. The result is not normalized.
layout_t insert_dimension(const layout_t &layout, int dim_idx) {
    std::vector<block_t> shifted = layout.blocks();
    for (auto &blk : shifted)
        blk.dim_idx += (blk.dim_idx >= dim_idx) ? 1 : 0;
    return layout_t(layout.type(), layout.ndims() + 1, layout.offset(),
            shifted, /*do_normalize=*/false);
}
665 
// Drops dimension `dim_idx` (asserted to have size 1) from `layout` and
// renumbers the following dimensions down by one.
layout_t remove_size_1_dimension(const layout_t &layout, int dim_idx) {
    const int ndims = layout.ndims();
    ir_assert(0 <= dim_idx && dim_idx < ndims);
    ir_assert(layout.dim(dim_idx) == 1);
    dim_assignment_t a(ndims, ndims - 1);
    for (int i = 0; i < ndims; i++) {
        if (i != dim_idx) a.assign(i, (i > dim_idx) ? i - 1 : i);
    }
    return a.map(layout);
}
676 
// Splits dimension `dim_idx` of `_layout` into two dimensions:
// - a new outer dimension at `dim_idx` of size `outer_block`;
// - the inner remainder at `dim_idx + 1`, of size dim / outer_block (must
//   divide exactly).
// Blocks of the original dimension are walked innermost first. The first
// block whose size is a multiple of the still-needed inner size is split
// between the two dimensions; all later blocks go to the outer dimension.
layout_t split_dimension(
        const layout_t &_layout, int dim_idx, int outer_block) {
    // dim_t, not int: dimension sizes are dim_t and must not be truncated.
    dim_t rem_inner_block
            = ir_utils::safe_divide(_layout.dim(dim_idx), outer_block);
    auto layout = insert_dimension(_layout, dim_idx);
    std::vector<block_t> new_blocks;
    for (auto &eb : layout.enumerated_blocks()) {
        auto &b = eb.second;
        if (b.dim_idx != dim_idx + 1) {
            new_blocks.push_back(b);
            continue;
        }
        if (b.block % rem_inner_block == 0) {
            // Inner part stays with the original (now dim_idx + 1)
            // dimension, the rest goes to the new outer dimension. Once
            // rem_inner_block is 1, whole blocks go to the outer dimension,
            // plus a size-1 inner block that the pass below drops unless it
            // is the outermost block of its dimension.
            new_blocks.emplace_back(dim_idx + 1, rem_inner_block, b.stride);
            new_blocks.emplace_back(dim_idx, b.block / rem_inner_block,
                    dim_t(b.stride) * rem_inner_block);
            rem_inner_block = 1;
        } else {
            new_blocks.push_back(b);
            rem_inner_block = ir_utils::safe_divide(rem_inner_block, b.block);
        }
    }

    // Remove inner blocks with size one.
    std::vector<block_t> _new_blocks;
    std::vector<bool> seen(layout.ndims());
    for (auto it = new_blocks.rbegin(); it != new_blocks.rend(); ++it) {
        if (it->block == 1 && seen[it->dim_idx]) continue;
        _new_blocks.push_back(*it);
        seen[it->dim_idx] = true;
    }
    std::reverse(_new_blocks.begin(), _new_blocks.end());
    return layout_t(layout.type(), layout.ndims(), layout.offset(), _new_blocks,
            /*do_normalize=*/false);
}
712 
// Adds or removes the groups dimension so that its presence matches
// `add_groups`.
// - Weights: a groups dimension is inserted at, or removed from, index 0;
//   this is asserted to be valid only for a single group.
// - Source/destination (asserted to have no groups): dimension 1 is split
//   into an outer dimension of size `groups` (dimension 1's full size for
//   depthwise) and the inner remainder.
layout_t normalize_conv_groups(const layout_t &layout, bool with_groups,
        int groups, bool is_dw, bool add_groups, bool is_wei) {
    if (with_groups == add_groups) return layout;

    if (!is_wei) {
        ir_assert(!with_groups) << "Unexpected groups in source/destination.";
        const int outer = is_dw ? int(layout.dim(1)) : groups;
        return split_dimension(layout, /*dim_idx=*/1, outer);
    }

    ir_assert(groups == 1)
            << "Adding/removing groups can be done only for single group.";
    return add_groups ? insert_dimension(layout, 0)
                      : remove_size_1_dimension(layout, 0);
}
727 
// Normalizes a convolution tensor layout in two steps:
// 1. Bring the spatial dimensions to 3D (normalize_conv_spatial()).
// 2. Add or remove the groups dimension (normalize_conv_groups()).
// The leading non-spatial dimensions number 3 with groups and 2 without.
layout_t normalize_conv_layout(const layout_t &_layout, bool with_groups,
        int groups, bool is_dw, bool reduced_to_1d, bool add_groups,
        bool is_wei) {
    const int n_non_sp = with_groups ? 3 : 2;
    const int old_sp_ndims = _layout.ndims() - n_non_sp;
    auto sp_normalized
            = normalize_conv_spatial(_layout, old_sp_ndims, reduced_to_1d);
    return normalize_conv_groups(
            sp_normalized, with_groups, groups, is_dw, add_groups, is_wei);
}
740 
normalize_conv_dims(std::vector<dim_t> & dims,bool with_groups,int groups,bool is_dw,bool reduced_to_1d,bool add_groups,bool is_wei)741 std::vector<dim_t> normalize_conv_dims(std::vector<dim_t> &dims,
742         bool with_groups, int groups, bool is_dw, bool reduced_to_1d,
743         bool add_groups, bool is_wei) {
744     layout_t dummy_layout(type_t::u8(), 0, dims);
745     return normalize_conv_layout(dummy_layout, with_groups, groups, is_dw,
746             reduced_to_1d, add_groups, is_wei)
747             .dims();
748 }
749 
normalize_conv_layouts(layout_t & src_layout,layout_t & wei_layout,layout_t & dst_layout,bool with_groups,int groups,bool is_dw,bool reduced_to_1d,bool add_groups)750 void normalize_conv_layouts(layout_t &src_layout, layout_t &wei_layout,
751         layout_t &dst_layout, bool with_groups, int groups, bool is_dw,
752         bool reduced_to_1d, bool add_groups) {
753     src_layout = normalize_conv_layout(src_layout, /*with_groups=*/false,
754             groups, is_dw, reduced_to_1d, add_groups, /*is_wei=*/false);
755     wei_layout = normalize_conv_layout(wei_layout, with_groups, groups, is_dw,
756             reduced_to_1d, add_groups, /*is_wei=*/true);
757     dst_layout = normalize_conv_layout(dst_layout, /*with_groups=*/false,
758             groups, is_dw, reduced_to_1d, add_groups, /*is_wei=*/false);
759 }
760 
761 } // namespace jit
762 } // namespace gpu
763 } // namespace impl
764 } // namespace dnnl
765