1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef GPU_JIT_CONV_TENSOR_HPP
18 #define GPU_JIT_CONV_TENSOR_HPP
19 
20 #include <algorithm>
21 #include <array>
22 #include <iostream>
23 #include <sstream>
24 #include <string>
25 #include <thread>
26 #include <tuple>
27 #include <utility>
28 #include <vector>
29 #include <unordered_map>
30 
31 #include "common/memory_desc_wrapper.hpp"
32 #include "gpu/jit/conv/ir.hpp"
33 #include "gpu/jit/conv/utils.hpp"
34 
35 namespace dnnl {
36 namespace impl {
37 namespace gpu {
38 namespace jit {
39 
40 class tensor_t {
41 public:
42     tensor_t() = default;
43 
tensor_t(const std::vector<dim_t> & dims)44     tensor_t(const std::vector<dim_t> &dims)
45         : tensor_t(dims, std::vector<expr_t>()) {}
46 
tensor_t(const std::vector<dim_t> & dims,const std::vector<expr_t> & start)47     tensor_t(const std::vector<dim_t> &dims, const std::vector<expr_t> &start)
48         : dims_(dims), start_(start) {
49         if (start_.empty()) start_.resize(dims.size(), 0);
50     }
51 
tensor_t(const std::vector<dim_t> & dims,const std::vector<dim_t> & start)52     tensor_t(const std::vector<dim_t> &dims, const std::vector<dim_t> &start)
53         : tensor_t(dims) {
54         start_.resize(start.size());
55         for (size_t i = 0; i < start.size(); i++)
56             start_[i] = start[i];
57     }
58 
operator ()(int idx) const59     dim_t operator()(int idx) const { return dims_[idx]; }
60 
start(int idx) const61     const expr_t &start(int idx) const { return start_[idx]; }
62 
ndims() const63     int ndims() const { return int(dims_.size()); }
64 
elems() const65     dim_t elems() const {
66         dim_t ret = 1;
67         for (int i = 0; i < ndims(); i++)
68             ret *= dims_[i];
69         return ret;
70     }
71 
dims() const72     const std::vector<dim_t> &dims() const { return dims_; }
73 
start() const74     const std::vector<expr_t> &start() const { return start_; }
75 
is_empty() const76     bool is_empty() const { return dims_.empty(); }
77 
is_equal(const tensor_t & other) const78     bool is_equal(const tensor_t &other) const {
79         if (ndims() != other.ndims()) return false;
80         for (int i = 0; i < ndims(); i++) {
81             if (dims_[i] != other.dims_[i]) return false;
82             if (!start_[i].is_equal(other.start_[i])) return false;
83         }
84         return true;
85     }
86 
str() const87     std::string str() const {
88         using ir_utils::operator<<;
89 
90         if (is_empty()) return "(nil)";
91         std::ostringstream oss;
92         oss << ir_utils::make_seq_print_helper(dims_, "x");
93         if (!has_zero_start()) oss << " start: [" << start_ << "]";
94         return oss.str();
95     }
96 
IR_DEFINE_DUMP()97     IR_DEFINE_DUMP()
98 
99     bool has_zero_start() const {
100         for (int i = 0; i < ndims(); i++)
101             if (!is_zero(start_[i])) return false;
102         return true;
103     }
104 
to_1d_offset(const std::vector<dim_t> & args) const105     dim_t to_1d_offset(const std::vector<dim_t> &args) const {
106         ir_assert(has_zero_start());
107 
108         dim_t off = 0;
109         for (int i = 0; i < ndims(); i++) {
110             off *= dims_[i];
111             off += args[i];
112         }
113         return off;
114     }
115 
create_sub_tensor(const tensor_t & tile) const116     tensor_t create_sub_tensor(const tensor_t &tile) const {
117         ir_assert(ndims() == tile.ndims()) << "Incompatible sizes.";
118         std::vector<expr_t> new_start = start_;
119         for (int i = 0; i < ndims(); i++)
120             new_start[i] += tile.start(i);
121         return tensor_t(tile.dims(), new_start);
122     }
123 
substitute(const expr_t & from,const expr_t & to) const124     tensor_t substitute(const expr_t &from, const expr_t &to) const {
125         tensor_t ret = *this;
126         for (int i = 0; i < ndims(); i++) {
127             ret.start_[i] = jit::substitute(ret.start_[i], from, to);
128             ret.start_[i] = simplify(ret.start_[i]);
129         }
130         return ret;
131     }
132 
133 private:
134     std::vector<dim_t> dims_;
135     std::vector<expr_t> start_;
136 };
137 
operator <<(std::ostream & out,const tensor_t & tensor)138 inline std::ostream &operator<<(std::ostream &out, const tensor_t &tensor) {
139     out << tensor.str();
140     return out;
141 }
142 
143 class grid_info_t {
144 public:
145     grid_info_t() = default;
grid_info_t(int ndims)146     grid_info_t(int ndims) : dims_(ndims), offs_(ndims), idxs_(ndims) {}
grid_info_t(const std::vector<int> & dims,const std::vector<expr_t> & idxs)147     grid_info_t(const std::vector<int> &dims, const std::vector<expr_t> &idxs)
148         : grid_info_t(dims, {}, idxs) {}
grid_info_t(const std::vector<int> & dims,const std::vector<int> & offs,const std::vector<expr_t> & idxs)149     grid_info_t(const std::vector<int> &dims, const std::vector<int> &offs,
150             const std::vector<expr_t> &idxs)
151         : dims_(dims), offs_(offs), idxs_(idxs) {
152         if (offs_.empty()) offs_.resize(dims.size());
153         ir_assert(dims_.size() == offs_.size());
154         ir_assert(dims_.size() == idxs_.size());
155     }
156 
operator ==(const grid_info_t & other) const157     bool operator==(const grid_info_t &other) const {
158         if (ndims() != other.ndims()) return false;
159         for (int i = 0; i < ndims(); i++) {
160             if (dim(i) != other.dim(i)) return false;
161             if (off(i) != other.off(i)) return false;
162             if (!idx(i).is_equal(other.idx(i))) return false;
163         }
164         return true;
165     }
166 
is_empty() const167     bool is_empty() const { return dims_.empty(); }
168 
dim(int dim_idx)169     int &dim(int dim_idx) { return dims_[dim_idx]; }
off(int dim_idx)170     int &off(int dim_idx) { return offs_[dim_idx]; }
idx(int dim_idx)171     expr_t &idx(int dim_idx) { return idxs_[dim_idx]; }
dim_idx(const expr_t & idx_var) const172     int dim_idx(const expr_t &idx_var) const {
173         for (int i = 0; i < ndims(); i++) {
174             if (idx(i).is_same(idx_var)) return i;
175         }
176         ir_error_not_expected() << "Index not found: " << idx_var;
177         return -1;
178     }
179 
dim(int dim_idx) const180     const int &dim(int dim_idx) const { return dims_[dim_idx]; }
dim(const expr_t & idx_var) const181     const int &dim(const expr_t &idx_var) const {
182         return dims_[dim_idx(idx_var)];
183     }
off(int dim_idx) const184     const int &off(int dim_idx) const { return offs_[dim_idx]; }
idx(int dim_idx) const185     const expr_t &idx(int dim_idx) const { return idxs_[dim_idx]; }
186 
ndims() const187     int ndims() const { return int(dims_.size()); }
elems() const188     int elems() const {
189         return utils::array_product(dims_.data(), dims_.size());
190     }
191 
sub_grid(std::initializer_list<int> old_dim_idxs) const192     grid_info_t sub_grid(std::initializer_list<int> old_dim_idxs) const {
193         grid_info_t ret(int(old_dim_idxs.size()));
194         int new_dim_idx = 0;
195         for (auto old_dim_idx : old_dim_idxs) {
196             ret.dim(new_dim_idx) = dim(old_dim_idx);
197             ret.off(new_dim_idx) = off(old_dim_idx);
198             ret.idx(new_dim_idx) = idx(old_dim_idx);
199             new_dim_idx++;
200         }
201         return ret;
202     }
203 
slice(int dim_idx,int new_off,int new_dim,const expr_t & new_idx,expr_t & new_idx_value) const204     grid_info_t slice(int dim_idx, int new_off, int new_dim,
205             const expr_t &new_idx, expr_t &new_idx_value) const {
206         ir_assert(dim_idx >= 0 && dim_idx < ndims());
207         ir_assert(new_dim > 0 && new_off >= 0);
208         ir_assert(new_off + new_dim <= dims_[dim_idx]);
209 
210         grid_info_t ret = *this;
211         ret.offs_[dim_idx] += new_off;
212         ret.dims_[dim_idx] = new_dim;
213         if (new_off > 0) {
214             new_idx_value = ret.idxs_[dim_idx] - new_off;
215             ret.idxs_[dim_idx] = new_idx;
216         } else {
217             new_idx_value = expr_t();
218         }
219         ret.parent_dims_ = (parent_dims_.empty() ? dims_ : parent_dims_);
220         return ret;
221     }
222 
halven(const expr_t & new_idx,int & dim_idx,expr_t & new_idx_value,bool first=true) const223     grid_info_t halven(const expr_t &new_idx, int &dim_idx,
224             expr_t &new_idx_value, bool first = true) const {
225         for (int i = ndims() - 1; i >= 0; i--) {
226             if (dim(i) == 1 || dim(i) % 2 != 0) continue;
227             dim_idx = i;
228             if (first) return slice(i, 0, dim(i) / 2, new_idx, new_idx_value);
229             return slice(i, dim(i) / 2, dim(i) / 2, new_idx, new_idx_value);
230         }
231         return grid_info_t();
232     }
233 
slice_condition() const234     expr_t slice_condition() const {
235         if (parent_dims_.empty()) return expr_t();
236         expr_t ret(true);
237         for (int i = 0; i < ndims(); i++) {
238             auto &idx = idxs_[i];
239             if (offs_[i] > 0) ret &= (idx >= 0);
240             if (offs_[i] + dims_[i] < parent_dims_[i]) ret &= (idx < dims_[i]);
241         }
242         if (ret.is_equal(expr_t(true))) return expr_t();
243         return ret;
244     }
245 
str() const246     std::string str() const {
247         std::ostringstream oss;
248         oss << ir_utils::make_seq_print_helper(dims_, "x");
249         return oss.str();
250     }
251 
252     IR_DEFINE_DUMP()
253 
254 private:
255     std::vector<int> dims_;
256     std::vector<int> offs_;
257     std::vector<expr_t> idxs_;
258 
259     std::vector<int> parent_dims_;
260 };
261 
operator <<(std::ostream & out,const grid_info_t & grid_info)262 inline std::ostream &operator<<(
263         std::ostream &out, const grid_info_t &grid_info) {
264     out << grid_info.str();
265     return out;
266 }
267 
268 class grid_splitter_t {
269 public:
grid_splitter_t(const grid_info_t & grid)270     grid_splitter_t(const grid_info_t &grid)
271         : grid_(grid), cur_idx_(grid.ndims() - 1), cur_stride_(1) {
272         skip_size_1_dims();
273         ir_assert(cur_idx_ >= 0);
274     }
275 
cur_block() const276     int cur_block() const {
277         if (is_empty()) return 1;
278 
279         return grid_.dim(cur_idx_) / cur_stride_;
280     }
281 
is_empty() const282     bool is_empty() const { return cur_idx_ == -1; }
283 
can_pop_block(int size) const284     bool can_pop_block(int size) const {
285         if (is_empty()) return false;
286         return cur_block() % size == 0;
287     }
288 
289     expr_t pop_block(int size);
290 
291 private:
skip_size_1_dims()292     void skip_size_1_dims() {
293         while (cur_idx_ >= 0 && grid_.dim(cur_idx_) == 1)
294             cur_idx_--;
295     }
296 
297     grid_info_t grid_;
298 
299     int cur_idx_;
300     int cur_stride_;
301 };
302 
// Kind of a stride value.
enum class stride_kind_t {
    undef = 0, // Not set (state of a default-constructed stride_t).
    fixed = 1, // Stride value is known.
    unknown = 2, // Stride value is not known.
};
308 
309 class stride_t {
310 public:
311     stride_t() = default;
312 
stride_t(dim_t stride)313     stride_t(dim_t stride) : stride_t(stride_kind_t::fixed, stride) {}
314 
operator ==(const stride_t & other) const315     bool operator==(const stride_t &other) const {
316         return (kind_ == other.kind_) && (stride_ == other.stride_);
317     }
318 
operator !=(const stride_t & other) const319     bool operator!=(const stride_t &other) const { return !operator==(other); }
320 
get_hash() const321     size_t get_hash() const { return ir_utils::get_hash(kind_, stride_); }
322 
operator dim_t() const323     operator dim_t() const {
324         ir_assert(kind_ == stride_kind_t::fixed);
325         return stride_;
326     }
327 
is_fixed() const328     bool is_fixed() const { return kind_ == stride_kind_t::fixed; }
329 
is_unknown() const330     bool is_unknown() const { return kind_ == stride_kind_t::unknown; }
331 
operator *=(const stride_t & other)332     stride_t &operator*=(const stride_t &other) {
333         if (is_fixed() && other.is_fixed()) {
334             stride_ *= other.stride_;
335         } else {
336             set_unknown();
337         }
338         return *this;
339     }
340 
operator /=(const stride_t & other)341     stride_t &operator/=(const stride_t &other) {
342         if (is_fixed() && other.is_fixed()) {
343             stride_ /= other.stride_;
344         } else {
345             set_unknown();
346         }
347         return *this;
348     }
349 
str() const350     std::string str() const {
351         std::ostringstream oss;
352         if (is_fixed()) {
353             oss << stride_;
354         } else {
355             oss << "(unknown)";
356         }
357         return oss.str();
358     }
359 
IR_DEFINE_DUMP()360     IR_DEFINE_DUMP()
361 
362     static stride_t unknown() { return stride_t(stride_kind_t::unknown); }
363 
364 private:
stride_t(stride_kind_t kind,dim_t stride=0)365     stride_t(stride_kind_t kind, dim_t stride = 0)
366         : kind_(kind), stride_(stride) {}
367 
set_unknown()368     void set_unknown() {
369         kind_ = stride_kind_t::unknown;
370         stride_ = 0;
371     }
372 
373     stride_kind_t kind_ = stride_kind_t::undef;
374     dim_t stride_ = 0;
375 };
376 
operator <<(std::ostream & out,const stride_t & stride)377 inline std::ostream &operator<<(std::ostream &out, const stride_t &stride) {
378     out << stride.str();
379     return out;
380 }
381 
// Product of two strides, see stride_t::operator*=.
inline stride_t operator*(const stride_t &a, const stride_t &b) {
    stride_t product(a);
    product *= b;
    return product;
}
386 
// Product of a stride and an integer; the integer is treated as a fixed stride.
inline stride_t operator*(const stride_t &a, dim_t b) {
    const stride_t rhs(b);
    return a * rhs;
}
390 
operator *(dim_t a,const stride_t & b)391 inline stride_t operator*(dim_t a, const stride_t &b) {
392     return stride_t(a) * b;
393 }
394 
395 struct block_t {
396     block_t() = default;
397 
block_tdnnl::impl::gpu::jit::block_t398     block_t(int dim_idx, dim_t block, const stride_t &stride)
399         : dim_idx(dim_idx), block(block), stride(stride) {}
400 
is_equaldnnl::impl::gpu::jit::block_t401     bool is_equal(const block_t &other) const {
402         return (dim_idx == other.dim_idx) && (block == other.block)
403                 && (stride == other.stride);
404     }
405 
get_hashdnnl::impl::gpu::jit::block_t406     size_t get_hash() const {
407         return ir_utils::get_hash(dim_idx, block, stride);
408     }
409 
strdnnl::impl::gpu::jit::block_t410     std::string str() const {
411         std::ostringstream oss;
412         oss << "block_t(dim_idx = " << dim_idx;
413         oss << ", block = " << block;
414         oss << ", stride = " << stride;
415         oss << ")";
416         return oss.str();
417     }
418 
419     IR_DEFINE_DUMP()
420 
421     int dim_idx; // Dimension index.
422     dim_t block; // Block size.
423     stride_t stride; // Stride between elements of the block.
424 };
425 
operator <<(std::ostream & out,const block_t & b)426 inline std::ostream &operator<<(std::ostream &out, const block_t &b) {
427     out << b.str();
428     return out;
429 }
430 
431 class layout_t {
432 public:
433     static const int max_ndims = 6;
434 
layout_t()435     layout_t() : type_(type_t::undef()), ndims_(0), offset_(0) {
436         sanity_check();
437     }
438 
439     layout_t(const type_t &type, const expr_t &offset,
440             const std::string &format, const std::vector<dim_t> &dims = {},
441             bool do_normalize = true);
442 
layout_t(const memory_desc_wrapper & mdw,const std::string & format,bool do_normalize=true)443     layout_t(const memory_desc_wrapper &mdw, const std::string &format,
444             bool do_normalize = true)
445         : layout_t(mdw.data_type(), mdw.offset0(), format,
446                 std::vector<dim_t>(
447                         mdw.padded_dims(), mdw.padded_dims() + mdw.ndims()),
448                 do_normalize) {}
449 
layout_t(const memory_desc_wrapper & mdw,const char * format,bool do_normalize=true)450     layout_t(const memory_desc_wrapper &mdw, const char *format,
451             bool do_normalize = true)
452         : layout_t(mdw, std::string(format), do_normalize) {}
453 
454     layout_t(const memory_desc_wrapper &mdw, bool do_normalize = true);
455 
layout_t(const type_t & type,const expr_t & offset,const std::vector<dim_t> & dims,bool do_normalize=true)456     layout_t(const type_t &type, const expr_t &offset,
457             const std::vector<dim_t> &dims, bool do_normalize = true)
458         : type_(type), ndims_(int(dims.size())), offset_(offset) {
459         dim_t stride = 1;
460         for (int i = ndims_ - 1; i >= 0; i--) {
461             blocks_.emplace_back(i, dims[i], stride);
462             stride *= dims[i];
463         }
464         if (do_normalize) blocks_ = normalize_blocks(ndims_, blocks_);
465         sanity_check();
466     }
467 
layout_t(const type_t & type,int ndims,const expr_t & offset,const std::vector<block_t> & blocks,bool do_normalize=true)468     layout_t(const type_t &type, int ndims, const expr_t &offset,
469             const std::vector<block_t> &blocks, bool do_normalize = true)
470         : type_(type), ndims_(ndims), offset_(offset), blocks_(blocks) {
471         if (do_normalize) blocks_ = normalize_blocks(ndims_, blocks_);
472         sanity_check();
473     }
474 
layout_t(const type_t & type,const expr_t & offset,const layout_t & other,bool do_normalize)475     layout_t(const type_t &type, const expr_t &offset, const layout_t &other,
476             bool do_normalize)
477         : layout_t(type, other.ndims(), offset, other.blocks(), do_normalize) {}
478 
is_empty() const479     bool is_empty() const { return ndims_ == 0; }
480 
ndims() const481     int ndims() const { return ndims_; }
482 
elems() const483     dim_t elems() const {
484         dim_t ret = 1;
485         for (auto &b : blocks_)
486             ret *= b.block;
487         return ret;
488     }
489 
490     // Storage size in bytes.
size() const491     dim_t size() const {
492         if (is_empty()) return 0;
493         dim_t max_stride = 1;
494         for (auto &b : blocks_) {
495             max_stride = std::max(max_stride, dim_t(b.block * b.stride));
496         }
497         return max_stride * type().size();
498     }
499 
500     template <typename T = expr_t>
offset(const std::vector<T> & args={},bool ignore_offset=false) const501     T offset(
502             const std::vector<T> &args = {}, bool ignore_offset = false) const {
503         if (args.empty()) return expr_cast<T>(offset_);
504 
505         ir_assert(int(args.size()) == ndims()) << "Dimensions do not match.";
506 
507         T off = 0;
508         auto _args = args;
509         for (auto &eb : enumerated_blocks()) {
510             auto &b = eb.second;
511             auto &idx = _args[b.dim_idx];
512             if (ir_utils::is_equal(idx, T(0))) continue;
513 
514             // Do not use modulus for outermost blocks.
515             auto i = is_outermost(eb) ? idx : (idx % b.block);
516             off = i * dim_t(b.stride) + off;
517             idx /= b.block;
518         }
519         if (ignore_offset) return off;
520 
521         T off0 = expr_cast<T>(offset_);
522         return off0 + off;
523     }
524 
type() const525     const type_t &type() const { return type_; }
526 
dims() const527     std::vector<dim_t> dims() const {
528         std::vector<dim_t> dims(ndims(), 1);
529         for (auto &b : blocks_) {
530             dims[b.dim_idx] *= b.block;
531         }
532         return dims;
533     }
534 
dim(int dim_idx) const535     dim_t dim(int dim_idx) const {
536         dim_t ret = 1;
537         for (auto &b : blocks_) {
538             if (b.dim_idx == dim_idx) ret *= b.block;
539         }
540         return ret;
541     }
542 
blocks() const543     const std::vector<block_t> &blocks() const { return blocks_; }
544 
set_offset(const expr_t & offset)545     void set_offset(const expr_t &offset) { offset_ = offset; }
546 
is_strictly_equal(const layout_t & other,bool compare_offset=true) const547     bool is_strictly_equal(
548             const layout_t &other, bool compare_offset = true) const {
549         if (!type_.is_equal(other.type_)) return false;
550         if (compare_offset && !offset_.is_equal(other.offset_)) return false;
551         if (!ir_utils::is_equal(blocks_, other.blocks_)) return false;
552         return true;
553     }
554 
operator ==(const layout_t & other) const555     bool operator==(const layout_t &other) const { return is_equal(other); }
556 
operator !=(const layout_t & other) const557     bool operator!=(const layout_t &other) const { return !operator==(other); }
558 
is_equal(const layout_t & other,bool compare_offset=true) const559     bool is_equal(const layout_t &other, bool compare_offset = true) const {
560         return normalize().is_strictly_equal(other.normalize(), compare_offset);
561     }
562 
get_hash() const563     size_t get_hash() const {
564         return ir_utils::get_hash(type_, ndims_, offset_, blocks_);
565     }
566 
567     template <typename T>
operator ()(const std::vector<T> & args) const568     T operator()(const std::vector<T> &args) const {
569         return offset(args);
570     }
571 
572     template <typename T = expr_t>
offset_in_bytes(const std::vector<T> & args={},bool ignore_offset=false) const573     T offset_in_bytes(
574             const std::vector<T> &args = {}, bool ignore_offset = false) const {
575         return offset(args, ignore_offset) * type().size();
576     }
577 
desc_str(bool dnnl_style=false) const578     std::string desc_str(bool dnnl_style = false) const {
579         if (is_empty()) return "(nil)";
580         if (!dnnl_style && blocks_.empty()) return "(scalar)";
581         std::string ret;
582         stride_t dense_stride(1);
583         std::vector<bool> seen(ndims());
584         for (auto &eb : enumerated_blocks()) {
585             auto &b = eb.second;
586             std::string b_str;
587             if (dnnl_style && is_outermost(eb)) {
588                 b_str.append(1, (seen[b.dim_idx] ? 'A' : 'a') + b.dim_idx);
589             } else {
590                 b_str = std::to_string(b.block);
591                 b_str.append(1, 'a' + b.dim_idx);
592             }
593             if (!dnnl_style) {
594                 if (b.stride.is_unknown()) {
595                     b_str.append(1, '?');
596                 } else if (b.stride != dense_stride) {
597                     b_str.append(1, '*');
598                 }
599             }
600             ret = b_str + ret;
601             dense_stride = b.stride * b.block;
602             seen[b.dim_idx] = true;
603         }
604         return ret;
605     }
606 
str() const607     std::string str() const {
608         if (is_empty()) return "(nil)";
609         std::ostringstream oss;
610         oss << desc_str();
611         if (!has_zero_offset()) oss << " offset: " << offset_;
612         return oss.str();
613     }
614 
615     IR_DEFINE_DUMP()
616 
617     memory_desc_t to_dnnl(const dim_t *dims_hint) const;
618 
619     // Returns a vector of <block index, block> pairs.
620     // The innermost block (first) has index 0.
enumerated_blocks() const621     std::vector<std::pair<int, block_t>> enumerated_blocks() const {
622         std::vector<std::pair<int, block_t>> ret;
623         for (int i = 0; i < int(blocks_.size()); i++) {
624             ret.emplace_back(i, blocks_[i]);
625         }
626         return ret;
627     }
628 
strides(int dim_idx) const629     std::vector<dim_t> strides(int dim_idx) const {
630         std::vector<dim_t> ret;
631         for (auto &b : blocks_)
632             if (b.dim_idx == dim_idx) ret.push_back(b.stride);
633         return ret;
634     }
635 
636     // eb is <block index, block> pair, see enumerated_blocks().
is_outermost(const std::pair<int,block_t> & eb) const637     bool is_outermost(const std::pair<int, block_t> &eb) const {
638         return is_outermost(eb, blocks_);
639     }
640 
is_plain() const641     bool is_plain() const {
642         std::vector<bool> seen(ndims());
643         for (auto &b : blocks_) {
644             if (seen[b.dim_idx]) return false;
645             seen[b.dim_idx] = true;
646         }
647         return true;
648     }
649 
has_zero_offset() const650     bool has_zero_offset() const { return offset_.is_equal(expr_t(0)); }
651 
has_unknown_strides() const652     bool has_unknown_strides() const {
653         for (auto &b : blocks_)
654             if (b.stride.is_unknown()) return true;
655         return false;
656     }
657 
658     // Returns a canonical representation of the layout:
659     // - Size one blocks are removed
660     // - Consecutive dense blocks are merged
normalize() const661     layout_t normalize() const {
662         auto blocks = normalize_blocks(ndims(), blocks_);
663         return layout_t(type(), ndims(), offset(), blocks);
664     }
665 
transpose() const666     layout_t transpose() const {
667         if (ndims() != 2) ir_error_not_expected();
668 
669         // Flip: 0 -> 1, 1 -> 0.
670         auto blocks = blocks_;
671         for (auto &b : blocks)
672             b.dim_idx ^= 1;
673 
674         return layout_t(type(), ndims(), offset(), blocks);
675     }
676 
677     // Returns a new (sub-)layout that fully contains the passed sub-tensor.
678     // Strides are kept unchanged.
679     // Assumption: the original layout can be tiled by the passed sub-tensor.
680     // For example: XaYb4a2b can be tiled into 2x2 sub-tensors but it's not
681     // possible to tile it into 3x2 sub-tensors.
682     layout_t map(const tensor_t &tensor) const;
683 
684     layout_t reinterpret(
685             const type_t &new_type, bool do_normalize = true) const;
686 
retype(const type_t & new_type) const687     layout_t retype(const type_t &new_type) const {
688         auto ret = *this;
689         ret.type_ = new_type;
690         return ret;
691     }
692 
is_dense() const693     bool is_dense() const {
694         stride_t stride = 1;
695         for (auto &b : blocks_) {
696             if (b.stride != stride) return false;
697             stride *= b.block;
698         }
699         return true;
700     }
701 
702     // Returns true if the layout has at least n inner blocks. For example:
703     // NChw32n16c - 2 inner blocks.
is_n_blocked(int n) const704     bool is_n_blocked(int n) const {
705         int block_count[layout_t::max_ndims] = {0};
706         for (auto &b : blocks_)
707             block_count[b.dim_idx]++;
708 
709         int ninner_blocks = 0;
710         stride_t stride = 1;
711         for (auto &b : blocks_) {
712             if (b.stride != stride) break; // Not dense anymore.
713             if (block_count[b.dim_idx] == 1) break; // Outer block.
714             stride *= b.block;
715             ir_assert(block_count[b.dim_idx] > 1);
716             block_count[b.dim_idx]--;
717             ninner_blocks++;
718         }
719 
720         return ninner_blocks >= n;
721     }
722 
723     // Returns a packed layout where all blocks are contiguous, without gaps.
make_dense() const724     layout_t make_dense() const {
725         dim_t stride = 1;
726         auto new_blocks = blocks_;
727         for (auto &b : new_blocks) {
728             b.stride = stride;
729             stride *= b.block;
730         }
731         return layout_t(type(), ndims(), 0, new_blocks);
732     }
733 
make_strided(int _stride) const734     layout_t make_strided(int _stride) const {
735         stride_t stride = _stride;
736         auto new_blocks = blocks_;
737         for (auto &b : new_blocks) {
738             b.stride = stride;
739             stride *= b.block;
740         }
741         return layout_t(type(), ndims(), 0, new_blocks);
742     }
743 
744     // Returns an equivalent layout where the specified block is split into two.
745     // block0 - inner block size.
746     // block1 - outer block size.
747     layout_t split_block(const std::pair<int, block_t> &eb, dim_t block0,
748             dim_t block1) const;
749 
750     // Splits blocks so that they can be used to form `multi_blocks` without
751     // crossing the block boundaries. `multi_blocks` are ordered from innermost
752     // to outermost. Returns an empty layout if such a split is not possible.
753     // Example (all blocks are ordered from innermost to outermost):
754     //     Input blocks:  [4, 4, 2]
755     //     Multi-blocks:  [8, 2]
756     //     Output blocks: [4, 2, 2, 2]
757     layout_t split_into_multi_blocks(
758             const std::vector<dim_t> &multi_blocks) const;
759 
760     layout_t split_into_multi_blocks_with_hint(
761             std::vector<dim_t> &multi_blocks) const;
762 
add_outer_block(int dim_idx,dim_t block,dim_t stride=-1) const763     layout_t add_outer_block(
764             int dim_idx, dim_t block, dim_t stride = -1) const {
765         if (stride == -1) stride = elems();
766         ir_assert(stride >= elems());
767         ir_assert(dim_idx < ndims());
768         auto new_blocks = blocks();
769         new_blocks.emplace_back(dim_idx, block, stride);
770         return layout_t(type(), ndims(), offset(), new_blocks);
771     }
772 
773     tensor_t split_into_dense_tile(dim_t tile_elems, dim_t outer_block) const;
774 
775     // Returns a tensor corresponding to the biggest innermost sub-layout so that
776     // 1) It consists of consecutive blocks only.
    // 2) It contains at most max_tile_elems elements.
778     // 3) It is dense if is_dense_tile is true.
779     tensor_t split_into_max_tile(
780             dim_t max_tile_elems, bool is_dense_tile) const;
781 
split(const grid_info_t & grid) const782     tensor_t split(const grid_info_t &grid) const {
783         std::vector<dim_t> tile_dims(ndims(), 1);
784         ir_assert(elems() % grid.elems() == 0) << "Can't split across grid.";
785 
786         dim_t cur_elems_per_tile = 1;
787         dim_t elems_per_tile = elems() / grid.elems();
788         for (auto &b : blocks()) {
789             dim_t block
790                     = std::min(b.block, elems_per_tile / cur_elems_per_tile);
791             tile_dims[b.dim_idx] *= block;
792             cur_elems_per_tile *= block;
793         }
794         ir_assert(cur_elems_per_tile == elems_per_tile)
795                 << "Can't split across grid.";
796 
797         return split(tensor_t(tile_dims), grid);
798     }
799 
    // Splits the layout across `grid` using the given zero-start `tile`.
    //
    // Returns a tensor with the dimensions of `tile` whose start is built
    // from the values returned by grid_splitter_t::pop_block() for the blocks
    // that lie outside of the tile. If `outer_blocks` is non-null it is
    // cleared first and then receives every block that was assigned to the
    // grid.
    //
    // Whenever a block only partially matches the tile or the current grid
    // block, the block is split in two via split_block() and the whole
    // procedure restarts recursively on the resulting layout.
    tensor_t split(const tensor_t &tile, const grid_info_t &grid,
            std::vector<block_t> *outer_blocks = nullptr) const {
        ir_assert(ndims() == tile.ndims())
                << "Number of dimensions doesn't match.";
        ir_assert(tile.has_zero_start());

        if (outer_blocks) outer_blocks->resize(0);

        // Nothing to distribute for a single-element grid.
        if (grid.elems() == 1) return tile;

        dim_t total_elems = elems();
        dim_t tile_elems = tile.elems();

        grid_splitter_t grid_splitter(grid);
        ir_assert(tile_elems * grid.elems() == total_elems)
                << "Tile/grid dimensions do not match.";
        // The two values above are referenced by the assertion only
        // (presumably it may be compiled out, hence MAYBE_UNUSED).
        MAYBE_UNUSED(total_elems);
        MAYBE_UNUSED(tile_elems);

        // dims: per-dimension extent covered by the blocks visited so far.
        // start: per-dimension start of the returned tensor.
        // rem_dims: per-dimension part of the tile not yet matched to blocks.
        std::vector<dim_t> dims(tile.ndims(), 1);
        std::vector<expr_t> start(tile.ndims(), 0);
        std::vector<dim_t> rem_dims = tile.dims();
        for (auto &eb : enumerated_blocks()) {
            auto &b = eb.second;
            if (b.block == 1) continue;

            dim_t &e = rem_dims[b.dim_idx];
            if (e > 1) {
                // The tile still has extent left in this dimension.
                if (e % b.block == 0) {
                    // The block lies entirely inside the tile.
                    e /= b.block;
                } else if (b.block % e == 0) {
                    // The tile ends inside the block: split the block into
                    // (e, b.block / e) and retry with the new layout.
                    auto tmp_layout = split_block(eb, e, b.block / e);
                    return tmp_layout.split(tile, grid, outer_blocks);
                } else {
                    ir_error_not_expected() << "Can't split across grid.";
                }
            } else {
                // The tile is exhausted in this dimension: the block is
                // matched against the current grid block.
                dim_t next_chunk
                        = math::gcd(b.block, grid_splitter.cur_block());
                if (b.block == next_chunk) {
                    // The whole block is taken from the grid splitter.
                    auto idx = grid_splitter.pop_block(next_chunk);
                    start[b.dim_idx] += idx * dims[b.dim_idx];
                    if (outer_blocks) outer_blocks->push_back(b);
                } else if (b.block % next_chunk == 0) {
                    // Only a part of the block matches: split the block into
                    // (next_chunk, b.block / next_chunk) and retry.
                    auto tmp_layout
                            = split_block(eb, next_chunk, b.block / next_chunk);
                    return tmp_layout.split(tile, grid, outer_blocks);
                } else {
                    ir_error_not_expected() << "Can't split across grid.";
                }
            }
            // Account for the visited block in the covered extent.
            dims[b.dim_idx] *= b.block;
        }
        return tensor_t(tile.dims(), start);
    }
855 
856     // Iterates through tiles of the layout, calling `f` with relative offsets
857     // for each tile. The iteration order is defined by the layout blocks -
858     // absolute 1D offsets are increasing between callback calls.
859     template <typename F>
for_each_tile(const tensor_t & tile,const F & f) const860     void for_each_tile(const tensor_t &tile, const F &f) const {
861         ir_assert(tile.ndims() == ndims());
862         ir_assert(tile.has_zero_start());
863         for (int i = 0; i < ndims(); i++) {
864             ir_assert(dim(i) % tile.dims()[i] == 0);
865         }
866 
867         int nblocks = int(blocks().size());
868         std::vector<dim_t> sub_blocks(nblocks);
869         for (int i = 0; i < nblocks; i++)
870             sub_blocks[i] = blocks()[i].block;
871 
872         for (int i = 0; i < ndims(); i++) {
873             dim_t dim = tile.dims()[i];
874             for (auto &eb : enumerated_blocks()) {
875                 auto &b = eb.second;
876                 if (b.dim_idx != i) continue;
877                 int block_idx = eb.first;
878                 if (b.block >= dim) {
879                     ir_assert(b.block % dim == 0);
880                     sub_blocks[block_idx] = b.block / dim;
881                     break;
882                 }
883                 sub_blocks[block_idx] = 1;
884                 ir_assert(dim % b.block == 0);
885                 dim /= b.block;
886             }
887         }
888 
889         int ntiles = int(elems() / tile.elems());
890 
891         std::vector<dim_t> sub_block_idxs(nblocks);
892         for (int i = 0; i < ntiles; i++) {
893             // Convert sub-block indices to dimension indices.
894             std::vector<dim_t> dims(ndims(), 1);
895             std::vector<dim_t> start(ndims());
896             for (int j = 0; j < nblocks; j++) {
897                 auto &b = blocks()[j];
898                 dim_t k = sub_block_idxs[j]
899                         * (blocks()[j].block / sub_blocks[j]);
900                 start[b.dim_idx] += dims[b.dim_idx] * k;
901                 dims[b.dim_idx] *= b.block;
902             }
903 
904             // Pass dimension offsets to the callback.
905             f(start);
906 
907             // Move to the next vector of indices.
908             for (int j = 0; j < nblocks; j++) {
909                 auto &idx = sub_block_idxs[j];
910                 if (idx + 1 < sub_blocks[j]) {
911                     idx++;
912                     break;
913                 }
914                 idx = 0;
915             }
916         }
917     }
918 
919     // eb is <block index, block> pair, see enumerated_blocks().
is_outermost(const std::pair<int,block_t> & eb,const std::vector<block_t> & blocks)920     static bool is_outermost(const std::pair<int, block_t> &eb,
921             const std::vector<block_t> &blocks) {
922         int dim_idx = eb.second.dim_idx;
923         for (int i = 0; i < int(blocks.size()); i++) {
924             if (blocks[i].dim_idx == dim_idx && i > eb.first) return false;
925         }
926         return true;
927     }
928 
929     // Assume that layouts are normalized.
930     static void align_layouts(layout_t &a, layout_t &b);
931 
normalize_blocks(int ndims,const std::vector<block_t> & blocks,bool remove_size_1_blocks=true)932     static std::vector<block_t> normalize_blocks(int ndims,
933             const std::vector<block_t> &blocks,
934             bool remove_size_1_blocks = true) {
935         auto new_blocks = blocks;
936 
937         // Remove blocks of size 1.
938         if (remove_size_1_blocks) {
939             for (auto it = new_blocks.begin(); it != new_blocks.end();) {
940                 if (it->block == 1) {
941                     it = new_blocks.erase(it);
942                 } else {
943                     ++it;
944                 }
945             }
946         }
947 
948         // Merge same dimension blocks.
949         block_t prev_b;
950         prev_b.dim_idx = -1;
951         for (auto it = new_blocks.begin(); it != new_blocks.end();) {
952             if (it->dim_idx == prev_b.dim_idx
953                     && it->stride == (prev_b.stride * prev_b.block)) {
954                 auto &b = *(it - 1);
955                 b.block *= it->block;
956                 prev_b = b;
957                 it = new_blocks.erase(it);
958             } else {
959                 prev_b = *it;
960                 ++it;
961             }
962         }
963 
964         return new_blocks;
965     }
966 
967 private:
968     // Returns vector of <dimension index, block size> pairs.
969     static std::vector<std::pair<int, dim_t>> parse_format(
970             const std::string &format, int ndims_hint);
971 
972     // Returns vector of <dimension letter, block size> pairs.
973     static std::vector<std::pair<char, dim_t>> parse_letter_blocks(
974             const std::string &format);
975 
976     void sanity_check() const;
977 
978     layout_t split_into_multi_blocks_impl(
979             const std::vector<dim_t> &multi_blocks,
980             std::vector<dim_t> *out_multi_blocks) const;
981 
982     // Data type of the layout.
983     type_t type_;
984 
985     // Number of dimensions.
986     int ndims_;
987 
988     // Offset to the start of the layout (in elements of type).
989     expr_t offset_;
990 
991     // Blocks ordered from innermost to outermost.
992     std::vector<block_t> blocks_;
993 };
994 
operator <<(std::ostream & out,const layout_t & layout)995 inline std::ostream &operator<<(std::ostream &out, const layout_t &layout) {
996     out << layout.str();
997     return out;
998 }
999 
1000 class mask_tensor_t {
1001 public:
1002     mask_tensor_t() = default;
1003 
mask_tensor_t(const layout_t & layout)1004     mask_tensor_t(const layout_t &layout)
1005         : layout_(layout), masks_(layout.elems(), -1) {
1006         ir_assert(layout.is_dense());
1007     }
1008 
mask_tensor_t(const layout_t & layout,const std::vector<int> & masks,const object_eq_map_t<expr_t,int> & mask2ids,const std::vector<expr_t> & id2masks)1009     mask_tensor_t(const layout_t &layout, const std::vector<int> &masks,
1010             const object_eq_map_t<expr_t, int> &mask2ids,
1011             const std::vector<expr_t> &id2masks)
1012         : layout_(layout)
1013         , masks_(masks)
1014         , mask2ids_(mask2ids)
1015         , id2masks_(id2masks) {
1016         ir_assert(int(masks.size()) == elems()) << "Incompatible size.";
1017     }
1018 
type() const1019     const type_t &type() const { return layout_.type(); }
1020 
layout() const1021     const layout_t &layout() const { return layout_; }
1022 
elems() const1023     dim_t elems() const { return layout_.elems(); }
1024 
set_mask(dim_t off,const expr_t & mask)1025     void set_mask(dim_t off, const expr_t &mask) {
1026         ir_assert(0 <= off && off < elems()) << "Incorrect offset.";
1027         if (mask.is_empty()) return;
1028 
1029         auto ret = mask2ids_.insert({mask, int(mask2ids_.size())});
1030         int id = ret.first->second;
1031         masks_[off] = id;
1032 
1033         if (ret.second) id2masks_.push_back(mask);
1034     }
1035 
mask(dim_t off) const1036     const expr_t &mask(dim_t off) const {
1037         ir_assert(0 <= off && off < elems());
1038         return id2masks_[masks_[off]];
1039     }
1040 
simplify(const constraint_set_t & cset)1041     void simplify(const constraint_set_t &cset) {
1042         for (auto &mask : id2masks_) {
1043             auto new_mask = jit::simplify(mask, cset);
1044             // Some complex expressions need more than one simplify() call.
1045             int max_tries = 5;
1046             for (int i = 0; i < max_tries; i++) {
1047                 mask = new_mask;
1048                 new_mask = jit::simplify(new_mask, cset);
1049                 if (new_mask.is_equal(mask)) break;
1050             }
1051         }
1052         mask2ids_.clear();
1053         for (int i = 0; i < int(id2masks_.size()); i++) {
1054             auto ret = mask2ids_.insert({id2masks_[i], i});
1055             if (!ret.second) {
1056                 for (auto &m : masks_)
1057                     if (m == i) m = ret.first->second;
1058             }
1059         }
1060     }
1061 
map(const tensor_t & tile) const1062     mask_tensor_t map(const tensor_t &tile) const {
1063         auto tile_start = expr_cast<dim_t>(tile.start());
1064         auto sub_layout = layout_.map(tensor_t(tile.dims()));
1065         mask_tensor_t sub_mask(sub_layout);
1066         ir_utils::for_each(
1067                 tile.dims(), [&](const std::vector<dim_t> &sub_start) {
1068                     dim_t sub_off = sub_layout(sub_start);
1069                     dim_t off = layout_(tile_start) + layout_(sub_start);
1070                     sub_mask.set_mask(sub_off, mask(off));
1071                 });
1072         return sub_mask;
1073     }
1074 
reinterpret(const type_t & new_type) const1075     mask_tensor_t reinterpret(const type_t &new_type) const {
1076         ir_assert(!is_empty()) << "Can't reinterpret.";
1077         dim_t bytes = elems() * type().size();
1078         if (bytes % new_type.size() != 0 && bytes > new_type.size())
1079             return mask_tensor_t();
1080         int new_mask_size = std::max((int)(bytes / new_type.size()), 1);
1081         std::vector<int> new_masks(new_mask_size);
1082         for (dim_t i = 0; i < bytes; i += new_type.size()) {
1083             int mask_id = std::numeric_limits<int>::max();
1084             for (int j = 0; j < new_type.size() && j < bytes; j++) {
1085                 int cur_mask_id = masks_[(i + j) / type().size()];
1086                 if (mask_id >= int(masks_.size())) {
1087                     mask_id = cur_mask_id;
1088                 } else if (mask_id != cur_mask_id) {
1089                     // Mask is not consistent, can't reinterpret.
1090                     return mask_tensor_t();
1091                 }
1092             }
1093             ir_assert(0 <= mask_id && mask_id < int(masks_.size()));
1094             new_masks[i / new_type.size()] = mask_id;
1095         }
1096         dim_t new_elmes = utils::div_up(bytes, new_type.size());
1097         layout_t _1d_layout(new_type, 0, std::vector<dim_t> {new_elmes});
1098         return mask_tensor_t(_1d_layout, new_masks, mask2ids_, id2masks_);
1099     }
1100 
to_expr(int nmasks) const1101     expr_t to_expr(int nmasks) const {
1102         if (elems() % nmasks != 0) return expr_t();
1103 
1104         std::vector<expr_t> vec(nmasks);
1105         for (int i = 0; i < elems(); i++) {
1106             auto &channel_mask = vec[i % nmasks];
1107             auto &cur_mask = id2masks_[masks_[i]];
1108             if (channel_mask.is_empty()) {
1109                 channel_mask = cur_mask;
1110                 continue;
1111             }
1112             if (!channel_mask.is_equal(cur_mask)) return expr_t();
1113         }
1114         auto e = shuffle_t::make(vec);
1115         e = jit::simplify(e);
1116         e = jit::simplify_propagate_shuffle(e);
1117         return e;
1118     }
1119 
is_empty() const1120     bool is_empty() const { return layout_.is_empty(); }
1121 
str() const1122     std::string str() const {
1123         std::ostringstream oss;
1124         for (int i = 0; i < int(elems()); i++) {
1125             if (i != 0) oss << std::endl;
1126             oss << "mask #" << i << ": ";
1127             if (masks_[i] == -1) {
1128                 oss << "(nil)";
1129             } else {
1130                 oss << id2masks_[masks_[i]];
1131             }
1132         }
1133         return oss.str();
1134     }
1135 
1136     IR_DEFINE_DUMP()
1137 
1138 private:
1139     layout_t layout_;
1140     std::vector<int> masks_;
1141 
1142     object_eq_map_t<expr_t, int> mask2ids_;
1143     std::vector<expr_t> id2masks_;
1144 };
1145 
operator <<(std::ostream & out,const mask_tensor_t & mask_tensor)1146 inline std::ostream &operator<<(
1147         std::ostream &out, const mask_tensor_t &mask_tensor) {
1148     out << mask_tensor.str();
1149     return out;
1150 }
1151 
1152 class tdim_info_t {
1153 public:
1154     tdim_info_t() = default;
1155 
tdim_info_t(const expr_t & expr,const expr_t & mask)1156     tdim_info_t(const expr_t &expr, const expr_t &mask)
1157         : expr_(expr), mask_(mask) {}
1158 
nvargs() const1159     int nvargs() const { return nvargs_; }
1160 
expr() const1161     const expr_t &expr() const { return expr_; }
1162 
mask() const1163     const expr_t &mask() const { return mask_; }
1164 
mask(const expr_t & tvalue,const std::vector<expr_t> & vvars,const std::vector<expr_t> & vvalues) const1165     expr_t mask(const expr_t &tvalue, const std::vector<expr_t> &vvars,
1166             const std::vector<expr_t> &vvalues) const {
1167         auto ret = substitute(mask_, placeholder_var(), tvalue);
1168         for (int i = 0; i < int(vvars.size()); i++) {
1169             if (contains_object(ret, vvars[i])) {
1170                 ret = substitute(ret, vvars[i], vvalues[i]);
1171             }
1172         }
1173         return ret;
1174     }
1175 
vidx(int arg_idx) const1176     int vidx(int arg_idx) const {
1177         ir_assert(arg_idx < nvargs());
1178         return vidxs_[arg_idx];
1179     }
1180 
vstride(int arg_idx) const1181     stride_t vstride(int arg_idx) const {
1182         ir_assert(arg_idx < nvargs());
1183         return vstrides_[arg_idx];
1184     }
1185 
is_empty() const1186     bool is_empty() const { return expr_.is_empty(); }
1187 
is_identity() const1188     bool is_identity() const { return is_var(expr_); }
1189 
is_fixed_stride(int arg_idx) const1190     bool is_fixed_stride(int arg_idx) const {
1191         ir_assert(arg_idx < nvargs());
1192         return vstrides_[arg_idx].is_fixed();
1193     }
1194 
add_vvar(int vidx,const expr_t & varg)1195     void add_vvar(int vidx, const expr_t &varg) {
1196         ir_assert(nvargs_ + 1 <= max_nvargs);
1197         vidxs_[nvargs_] = vidx;
1198         vstrides_[nvargs_] = compute_stride(expr_, nvargs_, varg);
1199         nvargs_++;
1200     }
1201 
placeholder_var()1202     static const expr_t &placeholder_var() {
1203         static expr_t ph_var = var_t::make(type_t::s32(), "_ph");
1204         return ph_var;
1205     }
1206 
1207 private:
1208     static const int max_nvargs = 2;
1209 
1210     static stride_t compute_stride(const expr_t &e, int idx, const expr_t &var);
1211 
1212     expr_t expr_;
1213 
1214     int nvargs_ = 0;
1215     std::array<stride_t, max_nvargs> vstrides_;
1216     std::array<int, max_nvargs> vidxs_;
1217     expr_t mask_;
1218 };
1219 
1220 class view_t {
1221 public:
1222     view_t() = default;
1223 
view_t(const std::vector<expr_t> & vvars,int ntdims)1224     view_t(const std::vector<expr_t> &vvars, int ntdims)
1225         : vvars_(vvars)
1226         , vdims_(vvars.size())
1227         , vstart_(vvars.size())
1228         , tdims_(ntdims) {}
1229 
1230     // Constructs view from a layout.
view_t(const layout_t & layout,const std::vector<expr_t> & _vvars={},uint32_t bound_check_mask=0)1231     explicit view_t(const layout_t &layout,
1232             const std::vector<expr_t> &_vvars = {},
1233             uint32_t bound_check_mask = 0)
1234         : vvars_(_vvars)
1235         , vdims_(layout.dims())
1236         , vstart_(layout.ndims(), 0)
1237         , tdims_(layout.ndims())
1238         , tlayout_(layout) {
1239         if (vvars_.empty()) vvars_ = create_vvars(layout.ndims());
1240         for (int i = 0; i < nvdims(); i++) {
1241             expr_t i_mask;
1242             if ((bound_check_mask & (1 << i)) != 0)
1243                 i_mask = (placeholder_var() < layout.dim(i));
1244             set_tdim(i, vvars_[i], i_mask);
1245         }
1246     }
1247 
vvars() const1248     const std::vector<expr_t> &vvars() const { return vvars_; }
1249 
vdims() const1250     const std::vector<dim_t> &vdims() const { return vdims_; }
1251 
vstart(int vidx) const1252     expr_t vstart(int vidx) const { return vstart_[vidx]; }
1253 
tlayout() const1254     const layout_t tlayout() const { return tlayout_; }
1255 
nvdims() const1256     int nvdims() const { return int(vdims_.size()); }
1257 
ntdims() const1258     int ntdims() const { return int(tdims_.size()); }
1259 
velems() const1260     dim_t velems() const {
1261         dim_t ret = 1;
1262         for (int i = 0; i < nvdims(); i++)
1263             ret *= vdims_[i];
1264         return ret;
1265     }
1266 
vvar(int idx) const1267     const expr_t &vvar(int idx) const {
1268         ir_assert(idx < nvdims());
1269         return vvars_[idx];
1270     }
1271 
tdim(int idx) const1272     const tdim_info_t &tdim(int idx) const {
1273         ir_assert(idx < ntdims());
1274         return tdims_[idx];
1275     }
1276 
set_tdim(int tidx,const expr_t & _texpr,expr_t mask={})1277     void set_tdim(int tidx, const expr_t &_texpr, expr_t mask = {}) {
1278         ir_assert(tdims_[tidx].is_empty());
1279 
1280         auto texpr = simplify(_texpr);
1281         ir_assert(!is_const(texpr)) << "Tensor dimension can't be a constant.";
1282 
1283         tdim_info_t tdim(texpr, mask);
1284         for (int i = 0; i < nvdims(); i++) {
1285             if (contains_object(texpr, vvars_[i])) tdim.add_vvar(i, vvars_[i]);
1286         }
1287         ir_assert(tdim.nvargs() > 0)
1288                 << "Tensor dimension must have at least one "
1289                    "view dimension that maps to it.";
1290         tdims_[tidx] = tdim;
1291     }
1292 
set_vdim(const expr_t & varg,dim_t vdim,const expr_t & vstart=expr_t (0))1293     void set_vdim(
1294             const expr_t &varg, dim_t vdim, const expr_t &vstart = expr_t(0)) {
1295         int vidx = vvar_index(varg);
1296         ir_assert(vstart_[vidx].is_empty());
1297         vstart_[vidx] = vstart;
1298         vdims_[vidx] = vdim;
1299     }
1300 
set_tlayout(const layout_t & tlayout)1301     void set_tlayout(const layout_t &tlayout) { tlayout_ = tlayout; }
1302 
str() const1303     std::string str() const {
1304         using ir_utils::operator<<;
1305 
1306         if (is_empty()) return "(nil)";
1307         std::ostringstream oss;
1308         oss << ir_utils::make_seq_print_helper(vdims_, "x");
1309         if (!has_zero_vstart()) oss << " vstart: [" << vstart_ << "]";
1310         oss << " tlayout: " << tlayout_;
1311         return oss.str();
1312     }
1313 
IR_DEFINE_DUMP()1314     IR_DEFINE_DUMP()
1315 
1316     bool is_empty() const { return vdims_.empty(); }
1317 
has_zero_vstart() const1318     bool has_zero_vstart() const {
1319         for (int i = 0; i < nvdims(); i++)
1320             if (!is_zero(vstart_[i])) return false;
1321         return true;
1322     }
1323 
has_tmask(int tidx) const1324     bool has_tmask(int tidx) const {
1325         ir_assert(tidx >= 0 && tidx < ntdims());
1326         return !tdims_[tidx].mask().is_empty();
1327     }
1328 
type() const1329     const type_t &type() const { return tlayout_.type(); }
1330 
offset(const std::vector<expr_t> & vargs={},bool ignore_offset=false) const1331     expr_t offset(const std::vector<expr_t> &vargs = {},
1332             bool ignore_offset = false) const {
1333         auto targs = cvt_vargs_to_targs(vargs);
1334         return tlayout_.offset(targs, ignore_offset);
1335     }
1336 
offset_in_bytes(const std::vector<expr_t> & vargs={},bool ignore_offset=false) const1337     expr_t offset_in_bytes(const std::vector<expr_t> &vargs = {},
1338             bool ignore_offset = false) const {
1339         return offset(vargs, ignore_offset) * type().size();
1340     }
1341 
vvar_index(const expr_t & vvar) const1342     int vvar_index(const expr_t &vvar) const {
1343         for (size_t i = 0; i < vvars_.size(); i++)
1344             if (vvar.is_same(vvars_[i])) return int(i);
1345         ir_error_not_expected() << "Can't find view dimension.";
1346         return -1;
1347     }
1348 
1349     template <typename T>
operator ()(const std::vector<T> & vargs) const1350     T operator()(const std::vector<T> &vargs) const {
1351         auto targs = cvt_vargs_to_targs(vargs);
1352         return tlayout_(targs);
1353     }
1354 
1355     view_t create_sub_view(const tensor_t &sub_tensor) const;
1356 
retype(const type_t & new_type) const1357     view_t retype(const type_t &new_type) const {
1358         auto ret = *this;
1359         ret.tlayout_ = tlayout_.retype(new_type);
1360         return ret;
1361     }
1362 
make_dense() const1363     view_t make_dense() const {
1364         auto ret = *this;
1365         ret.tlayout_ = tlayout_.make_dense();
1366         return ret;
1367     }
1368 
can_convert_to_vlayout() const1369     bool can_convert_to_vlayout() const {
1370         if (nvdims() != ntdims()) return false;
1371         for (int i = 0; i < nvdims(); i++) {
1372             if (!tdims_[i].expr().is_same(vvars_[i])) return false;
1373             if (!tdims_[i].is_fixed_stride(0)) return false;
1374         }
1375         return true;
1376     }
1377 
1378     // FIXME: Offset of the returned layout is always 0.
create_pseudo_vlayout() const1379     layout_t create_pseudo_vlayout() const {
1380         return create_pseudo_vlayout(tlayout_);
1381     }
1382 
create_dense_vlayout() const1383     layout_t create_dense_vlayout() const {
1384         return create_pseudo_vlayout().make_dense();
1385     }
1386 
create_vlayout(bool force_zero_offset=false) const1387     layout_t create_vlayout(bool force_zero_offset = false) const {
1388         ir_assert(can_convert_to_vlayout()) << "Can't convert view to layout.";
1389         if (force_zero_offset) return tlayout_.map(tensor_t(vdims_));
1390         return tlayout_.map(tensor_t(vdims_, vstart_));
1391     }
1392 
vlayout_size() const1393     dim_t vlayout_size() const { return create_vlayout().size(); }
1394 
has_same_vlayout(const view_t & other,bool compare_offset=true) const1395     bool has_same_vlayout(
1396             const view_t &other, bool compare_offset = true) const {
1397         return create_vlayout().is_equal(
1398                 other.create_vlayout(), compare_offset);
1399     }
1400 
split(const grid_info_t & grid,tensor_t & vtile) const1401     view_t split(const grid_info_t &grid, tensor_t &vtile) const {
1402         auto vlayout = create_pseudo_vlayout();
1403         vtile = vlayout.split(grid);
1404         return create_sub_view(vtile);
1405     }
1406 
split(const grid_info_t & grid) const1407     view_t split(const grid_info_t &grid) const {
1408         tensor_t vtile;
1409         return split(grid, vtile);
1410     }
1411 
1412     // Tile is assumed to be dense.
split_into_dense_tile(dim_t & tile_elems,dim_t & outer_block) const1413     tensor_t split_into_dense_tile(
1414             dim_t &tile_elems, dim_t &outer_block) const {
1415         auto vlayout = create_pseudo_vlayout();
1416         std::vector<dim_t> blocks = {tile_elems, outer_block};
1417         vlayout = vlayout.split_into_multi_blocks_with_hint(blocks);
1418         if (vlayout.is_empty()) return tensor_t();
1419         tile_elems = blocks[0];
1420         outer_block = blocks[1];
1421         return vlayout.split_into_dense_tile(tile_elems, outer_block);
1422     }
1423 
1424     // Returns a tensor corresponding to the biggest innermost sub-layout so that
1425     // 1) It consists of consecutive blocks only.
1426     // 2) It contains less or equal than max_tile_elems elements.
1427     // 3) It is dense if is_dense_tile is true.
split_into_max_tile(dim_t max_tile_elems,bool is_dense_tile) const1428     tensor_t split_into_max_tile(
1429             dim_t max_tile_elems, bool is_dense_tile) const {
1430         auto vlayout = create_pseudo_vlayout();
1431         return vlayout.split_into_max_tile(max_tile_elems, is_dense_tile);
1432     }
1433 
1434     template <typename F>
for_each_tile(const tensor_t & tile,const F & f) const1435     void for_each_tile(const tensor_t &tile, const F &f) const {
1436         auto vlayout = create_dense_vlayout();
1437         vlayout.for_each_tile(tile, f);
1438     }
1439 
1440     view_t substitute(const expr_t &from, const expr_t &to) const;
1441 
create_mask_tensor(const constraint_set_t & cset) const1442     mask_tensor_t create_mask_tensor(const constraint_set_t &cset) const {
1443         auto _vlayout = create_dense_vlayout();
1444         mask_tensor_t mask_tensor(_vlayout);
1445         std::vector<dim_t> vargs(nvdims());
1446         create_mask_tensor(mask_tensor, _vlayout, 0, vargs);
1447         mask_tensor.simplify(cset);
1448         return mask_tensor;
1449     }
1450 
try_create_buffer_view(view_t & buf_view,view_t & inv_view) const1451     bool try_create_buffer_view(view_t &buf_view, view_t &inv_view) const {
1452         buf_view = view_t(create_vvars(ntdims()), ntdims());
1453         inv_view = view_t(vvars(), ntdims());
1454         for (int i = 0; i < nvdims(); i++) {
1455             inv_view.set_vdim(vvars()[i], vdims()[i]);
1456         }
1457         for (int i = 0; i < ntdims(); i++) {
1458             auto &tdim = tdims_[i];
1459             auto &buf_vvar = buf_view.vvars()[i];
1460             if (tdim.is_identity()) {
1461                 int vidx = tdim.vidx(0);
1462                 buf_view.set_vdim(buf_vvar, vdims()[vidx], vstart(vidx));
1463                 buf_view.set_tdim(i, buf_vvar, tdim.mask());
1464                 inv_view.set_tdim(i, tdim.expr());
1465                 continue;
1466             }
1467             int buf_vdim = 0;
1468             bool ok = true;
1469             for (int j = 0; j < tdim.nvargs(); j++) {
1470                 int vidx = tdim.vidx(j);
1471                 auto &vvar = vvars()[vidx];
1472                 int vdim = vdims()[vidx];
1473                 if (vdim == 1) continue;
1474                 auto A = tdim.expr();
1475                 auto B = jit::substitute(A, vvar, vvar + 1);
1476                 auto C = simplify(B - A);
1477                 if (!is_const(C)) {
1478                     ok = false;
1479                     break;
1480                 }
1481                 buf_vdim += to_cpp<int>(C) * (vdim - 1);
1482             }
1483             buf_vdim++;
1484 
1485             if (!ok) return false;
1486 
1487             auto buf_vstart = tdim.expr();
1488             auto inv_vstart = tdim.expr();
1489             for (int j = 0; j < tdim.nvargs(); j++) {
1490                 int vidx = tdim.vidx(j);
1491                 buf_vstart = jit::substitute(
1492                         buf_vstart, vvars()[vidx], vstart(vidx));
1493                 inv_vstart
1494                         = jit::substitute(inv_vstart, vvars()[vidx], expr_t(0));
1495             }
1496             buf_vstart = simplify(buf_vstart);
1497             inv_vstart = simplify(inv_vstart);
1498 
1499             if (!is_const(inv_vstart)) return false;
1500 
1501             buf_view.set_vdim(buf_vvar, buf_vdim, buf_vstart);
1502 
1503             // Check that mask doesn't contain vvars - they can't be accessed
1504             // in the buffered view.
1505             auto &tmask = tdim.mask();
1506             for (auto &vvar : vvars()) {
1507                 if (contains_object(tmask, vvar)) { return false; }
1508             }
1509 
1510             buf_view.set_tdim(i, buf_vvar, tmask);
1511             inv_view.set_tdim(i, tdim.expr() - inv_vstart);
1512         }
1513         buf_view.set_tlayout(tlayout_);
1514         return true;
1515     }
1516 
placeholder_var()1517     static const expr_t &placeholder_var() {
1518         return tdim_info_t::placeholder_var();
1519     }
1520 
1521     static std::vector<expr_t> create_vvars(int nvdims);
1522 
1523 private:
1524     template <typename SrcT = expr_t, typename DstT = SrcT>
cvt_vargs_to_targs(const std::vector<SrcT> & _vargs={}) const1525     std::vector<DstT> cvt_vargs_to_targs(
1526             const std::vector<SrcT> &_vargs = {}) const {
1527         std::vector<expr_t> vargs = expr_cast<expr_t>(_vargs);
1528         if (vargs.empty()) vargs.resize(nvdims(), 0);
1529 
1530         for (int i = 0; i < nvdims(); i++) {
1531             if (!is_zero(vstart_[i])) vargs[i] += vstart_[i];
1532         }
1533 
1534         std::vector<expr_t> targs(ntdims());
1535         for (int i = 0; i < ntdims(); i++) {
1536             targs[i] = tdims_[i].expr();
1537             for (int j = 0; j < nvdims(); j++) {
1538                 targs[i] = jit::substitute(targs[i], vvars_[j], vargs[j]);
1539             }
1540         }
1541         for (int i = 0; i < ntdims(); i++) {
1542             targs[i] = const_fold(targs[i]);
1543         }
1544         return expr_cast<DstT>(targs);
1545     }
1546 
1547     layout_t create_pseudo_vlayout(const layout_t &tlayout) const;
1548 
create_mask_tensor(mask_tensor_t & mask_tensor,const layout_t & _vlayout,int vidx,std::vector<dim_t> & vargs) const1549     void create_mask_tensor(mask_tensor_t &mask_tensor,
1550             const layout_t &_vlayout, int vidx,
1551             std::vector<dim_t> &vargs) const {
1552         if (vidx == _vlayout.ndims()) {
1553             bool is_init = false;
1554             std::vector<expr_t> vvalues;
1555             std::vector<expr_t> targs;
1556             expr_t mask = bool_imm_t::make(true);
1557             for (int i = 0; i < ntdims(); i++) {
1558                 auto &tdim = tdims_[i];
1559                 if (tdim.mask().is_empty()) continue;
1560                 if (!is_init) {
1561                     // Lazily initialize values
1562                     vvalues = vstart_;
1563                     for (int i = 0; i < nvdims(); i++)
1564                         vvalues[i] += vargs[i];
1565                     targs = cvt_vargs_to_targs<dim_t, expr_t>(vargs);
1566                     is_init = true;
1567                 }
1568                 mask &= tdim.mask(targs[i], vvars_, vvalues);
1569             }
1570             mask_tensor.set_mask(_vlayout(vargs), mask);
1571             return;
1572         }
1573 
1574         for (int i = 0; i < vdims()[vidx]; i++) {
1575             vargs[vidx] = i;
1576             create_mask_tensor(mask_tensor, _vlayout, vidx + 1, vargs);
1577         }
1578     }
1579 
    // View variables; cvt_vargs_to_targs() substitutes them in the tensor
    // dimension expressions.
    std::vector<expr_t> vvars_;
    // NOTE(review): presumably the per-dimension sizes of the view (what
    // vdims() exposes) - confirm.
    std::vector<dim_t> vdims_;
    // Per-dimension start values; added to the view coordinates before they
    // are converted to tensor coordinates.
    std::vector<expr_t> vstart_;

    // One entry per tensor dimension; each provides an expression (expr())
    // and a possibly empty mask (mask()).
    std::vector<tdim_info_t> tdims_;
    // NOTE(review): presumably the layout of the underlying tensor - confirm.
    layout_t tlayout_;
1586 };
1587 
operator <<(std::ostream & out,const view_t & view)1588 inline std::ostream &operator<<(std::ostream &out, const view_t &view) {
1589     out << view.str();
1590     return out;
1591 }
1592 
1593 class dim_assignment_t {
1594 public:
1595     dim_assignment_t() = default;
1596 
dim_assignment_t(int old_ndims,int new_ndims)1597     dim_assignment_t(int old_ndims, int new_ndims)
1598         : old_ndims_(old_ndims)
1599         , new_ndims_(new_ndims)
1600         , assignments_(old_ndims, -1) {}
1601 
assign(int old_idx,int new_idx)1602     void assign(int old_idx, int new_idx) {
1603         ir_assert(0 <= old_idx && old_idx < old_ndims_);
1604         ir_assert(0 <= new_idx && new_idx < new_ndims_);
1605         assignments_[old_idx] = new_idx;
1606     }
1607 
assign(const std::vector<int> & old_idxes,int new_idx)1608     void assign(const std::vector<int> &old_idxes, int new_idx) {
1609         for (auto old_idx : old_idxes) {
1610             assign(old_idx, new_idx);
1611         }
1612     }
1613 
operator [](int old_idx) const1614     int operator[](int old_idx) const {
1615         ir_assert(old_idx >= 0 && old_idx < old_ndims());
1616         return assignments_[old_idx];
1617     }
1618 
old_ndims() const1619     int old_ndims() const { return old_ndims_; }
1620 
new_ndims() const1621     int new_ndims() const { return new_ndims_; }
1622 
is_empty() const1623     bool is_empty() const { return old_ndims_ == 0 && new_ndims_ == 0; }
1624 
1625     layout_t map(const layout_t &layout) const;
1626 
1627 private:
1628     int old_ndims_ = 0;
1629     int new_ndims_ = 0;
1630 
1631     // assignments_[old_idx] = new_idx.
1632     std::vector<int> assignments_;
1633 };
1634 
// Convolution normalization helpers. Only the declarations are visible here;
// the definitions live elsewhere.
// NOTE(review): the precise meaning of the flags (with_groups, is_dw,
// reduced_to_1d, add_groups, is_wei) is not visible from this header section
// - consult the definitions before relying on any of them.

// Returns a vector of dimensions derived from `dims`.
// NOTE(review): `dims` is taken by non-const reference; whether it is
// modified cannot be determined from the declaration.
std::vector<dim_t> normalize_conv_dims(std::vector<dim_t> &dims,
        bool with_groups, int groups, bool is_dw, bool reduced_to_1d,
        bool add_groups, bool is_wei);

// Layout counterpart of normalize_conv_dims(): takes a layout and the same
// set of flags and returns a layout.
layout_t normalize_conv_layout(const layout_t &_layout, bool with_groups,
        int groups, bool is_dw, bool reduced_to_1d, bool add_groups,
        bool is_wei);

// Takes the source, weights and destination layouts by non-const reference
// and returns nothing.
// NOTE(review): presumably updates all three in place - confirm.
void normalize_conv_layouts(layout_t &src_layout, layout_t &wei_layout,
        layout_t &dst_layout, bool with_groups, int groups, bool is_dw,
        bool reduced_to_1d, bool add_groups);
1646 
1647 } // namespace jit
1648 } // namespace gpu
1649 } // namespace impl
1650 } // namespace dnnl
1651 
1652 #endif
1653