1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef GPU_JIT_CONV_TENSOR_HPP
18 #define GPU_JIT_CONV_TENSOR_HPP
19
20 #include <algorithm>
21 #include <array>
22 #include <iostream>
23 #include <sstream>
24 #include <string>
25 #include <thread>
26 #include <tuple>
27 #include <utility>
28 #include <vector>
29 #include <unordered_map>
30
31 #include "common/memory_desc_wrapper.hpp"
32 #include "gpu/jit/conv/ir.hpp"
33 #include "gpu/jit/conv/utils.hpp"
34
35 namespace dnnl {
36 namespace impl {
37 namespace gpu {
38 namespace jit {
39
// Rectangular tile of a multi-dimensional tensor: per-dimension sizes
// (dims) plus per-dimension start offsets. Start offsets are IR
// expressions and may be non-constant.
class tensor_t {
public:
    tensor_t() = default;

    // Tile with all-zero start offsets.
    tensor_t(const std::vector<dim_t> &dims)
        : tensor_t(dims, std::vector<expr_t>()) {}

    // Tile with symbolic start offsets; an empty `start` means all zeros.
    tensor_t(const std::vector<dim_t> &dims, const std::vector<expr_t> &start)
        : dims_(dims), start_(start) {
        if (start_.empty()) start_.resize(dims.size(), 0);
    }

    // Tile with constant start offsets (converted to expressions).
    tensor_t(const std::vector<dim_t> &dims, const std::vector<dim_t> &start)
        : tensor_t(dims) {
        start_.resize(start.size());
        for (size_t i = 0; i < start.size(); i++)
            start_[i] = start[i];
    }

    // Returns the size of dimension `idx`.
    dim_t operator()(int idx) const { return dims_[idx]; }

    // Returns the start offset of dimension `idx`.
    const expr_t &start(int idx) const { return start_[idx]; }

    int ndims() const { return int(dims_.size()); }

    // Total number of elements (product of all dimension sizes).
    dim_t elems() const {
        dim_t ret = 1;
        for (int i = 0; i < ndims(); i++)
            ret *= dims_[i];
        return ret;
    }

    const std::vector<dim_t> &dims() const { return dims_; }

    const std::vector<expr_t> &start() const { return start_; }

    bool is_empty() const { return dims_.empty(); }

    // Deep equality: dimension sizes and start expressions must match.
    bool is_equal(const tensor_t &other) const {
        if (ndims() != other.ndims()) return false;
        for (int i = 0; i < ndims(); i++) {
            if (dims_[i] != other.dims_[i]) return false;
            if (!start_[i].is_equal(other.start_[i])) return false;
        }
        return true;
    }

    std::string str() const {
        using ir_utils::operator<<;

        if (is_empty()) return "(nil)";
        std::ostringstream oss;
        oss << ir_utils::make_seq_print_helper(dims_, "x");
        if (!has_zero_start()) oss << " start: [" << start_ << "]";
        return oss.str();
    }

    IR_DEFINE_DUMP()

    bool has_zero_start() const {
        for (int i = 0; i < ndims(); i++)
            if (!is_zero(start_[i])) return false;
        return true;
    }

    // Converts a multi-dimensional index to a dense, row-major 1D offset.
    // Requires all start offsets to be zero.
    dim_t to_1d_offset(const std::vector<dim_t> &args) const {
        ir_assert(has_zero_start());

        dim_t off = 0;
        for (int i = 0; i < ndims(); i++) {
            off *= dims_[i];
            off += args[i];
        }
        return off;
    }

    // Returns `tile` shifted by this tensor's start offsets (i.e. a
    // sub-tile expressed in the enclosing tensor's coordinates).
    tensor_t create_sub_tensor(const tensor_t &tile) const {
        ir_assert(ndims() == tile.ndims()) << "Incompatible sizes.";
        std::vector<expr_t> new_start = start_;
        for (int i = 0; i < ndims(); i++)
            new_start[i] += tile.start(i);
        return tensor_t(tile.dims(), new_start);
    }

    // Replaces `from` with `to` in every start expression and simplifies.
    tensor_t substitute(const expr_t &from, const expr_t &to) const {
        tensor_t ret = *this;
        for (int i = 0; i < ndims(); i++) {
            ret.start_[i] = jit::substitute(ret.start_[i], from, to);
            ret.start_[i] = simplify(ret.start_[i]);
        }
        return ret;
    }

private:
    std::vector<dim_t> dims_;
    std::vector<expr_t> start_;
};
137
operator <<(std::ostream & out,const tensor_t & tensor)138 inline std::ostream &operator<<(std::ostream &out, const tensor_t &tensor) {
139 out << tensor.str();
140 return out;
141 }
142
// Multi-dimensional grid description: per-dimension sizes, offsets and
// index expressions (one index variable per dimension).
class grid_info_t {
public:
    grid_info_t() = default;
    grid_info_t(int ndims) : dims_(ndims), offs_(ndims), idxs_(ndims) {}
    grid_info_t(const std::vector<int> &dims, const std::vector<expr_t> &idxs)
        : grid_info_t(dims, {}, idxs) {}
    // Empty `offs` means all-zero offsets.
    grid_info_t(const std::vector<int> &dims, const std::vector<int> &offs,
            const std::vector<expr_t> &idxs)
        : dims_(dims), offs_(offs), idxs_(idxs) {
        if (offs_.empty()) offs_.resize(dims.size());
        ir_assert(dims_.size() == offs_.size());
        ir_assert(dims_.size() == idxs_.size());
    }

    bool operator==(const grid_info_t &other) const {
        if (ndims() != other.ndims()) return false;
        for (int i = 0; i < ndims(); i++) {
            if (dim(i) != other.dim(i)) return false;
            if (off(i) != other.off(i)) return false;
            if (!idx(i).is_equal(other.idx(i))) return false;
        }
        return true;
    }

    bool is_empty() const { return dims_.empty(); }

    int &dim(int dim_idx) { return dims_[dim_idx]; }
    int &off(int dim_idx) { return offs_[dim_idx]; }
    expr_t &idx(int dim_idx) { return idxs_[dim_idx]; }
    // Returns the dimension whose index variable is `idx_var`;
    // reports an error if it is not found.
    int dim_idx(const expr_t &idx_var) const {
        for (int i = 0; i < ndims(); i++) {
            if (idx(i).is_same(idx_var)) return i;
        }
        ir_error_not_expected() << "Index not found: " << idx_var;
        return -1;
    }

    const int &dim(int dim_idx) const { return dims_[dim_idx]; }
    const int &dim(const expr_t &idx_var) const {
        return dims_[dim_idx(idx_var)];
    }
    const int &off(int dim_idx) const { return offs_[dim_idx]; }
    const expr_t &idx(int dim_idx) const { return idxs_[dim_idx]; }

    int ndims() const { return int(dims_.size()); }
    // Total number of grid points (product of all dimension sizes).
    int elems() const {
        return utils::array_product(dims_.data(), dims_.size());
    }

    // Returns a new grid made of the listed dimensions, in the given order.
    grid_info_t sub_grid(std::initializer_list<int> old_dim_idxs) const {
        grid_info_t ret(int(old_dim_idxs.size()));
        int new_dim_idx = 0;
        for (auto old_dim_idx : old_dim_idxs) {
            ret.dim(new_dim_idx) = dim(old_dim_idx);
            ret.off(new_dim_idx) = off(old_dim_idx);
            ret.idx(new_dim_idx) = idx(old_dim_idx);
            new_dim_idx++;
        }
        return ret;
    }

    // Returns a copy with dimension `dim_idx` narrowed to the range
    // [new_off, new_off + new_dim). When new_off > 0 the dimension's index
    // variable is replaced by `new_idx` and `new_idx_value` receives the
    // shifted expression for it; otherwise `new_idx_value` is left empty.
    // The pre-slice dimensions are recorded for slice_condition().
    grid_info_t slice(int dim_idx, int new_off, int new_dim,
            const expr_t &new_idx, expr_t &new_idx_value) const {
        ir_assert(dim_idx >= 0 && dim_idx < ndims());
        ir_assert(new_dim > 0 && new_off >= 0);
        ir_assert(new_off + new_dim <= dims_[dim_idx]);

        grid_info_t ret = *this;
        ret.offs_[dim_idx] += new_off;
        ret.dims_[dim_idx] = new_dim;
        if (new_off > 0) {
            new_idx_value = ret.idxs_[dim_idx] - new_off;
            ret.idxs_[dim_idx] = new_idx;
        } else {
            new_idx_value = expr_t();
        }
        ret.parent_dims_ = (parent_dims_.empty() ? dims_ : parent_dims_);
        return ret;
    }

    // Halves the innermost dimension with an even size; returns the first
    // half when `first` is true, the second half otherwise. Returns an
    // empty grid if no dimension can be halved.
    grid_info_t halven(const expr_t &new_idx, int &dim_idx,
            expr_t &new_idx_value, bool first = true) const {
        for (int i = ndims() - 1; i >= 0; i--) {
            if (dim(i) == 1 || dim(i) % 2 != 0) continue;
            dim_idx = i;
            if (first) return slice(i, 0, dim(i) / 2, new_idx, new_idx_value);
            return slice(i, dim(i) / 2, dim(i) / 2, new_idx, new_idx_value);
        }
        return grid_info_t();
    }

    // Returns the boundary condition the index variables of a sliced grid
    // must satisfy relative to the parent grid, or an empty expression
    // when no condition is needed (not a slice, or the trivial condition).
    expr_t slice_condition() const {
        if (parent_dims_.empty()) return expr_t();
        expr_t ret(true);
        for (int i = 0; i < ndims(); i++) {
            auto &idx = idxs_[i];
            if (offs_[i] > 0) ret &= (idx >= 0);
            if (offs_[i] + dims_[i] < parent_dims_[i]) ret &= (idx < dims_[i]);
        }
        if (ret.is_equal(expr_t(true))) return expr_t();
        return ret;
    }

    std::string str() const {
        std::ostringstream oss;
        oss << ir_utils::make_seq_print_helper(dims_, "x");
        return oss.str();
    }

    IR_DEFINE_DUMP()

private:
    std::vector<int> dims_;
    std::vector<int> offs_;
    std::vector<expr_t> idxs_;

    // Dimensions of the grid this one was sliced from; empty when this
    // grid is not a slice. See slice().
    std::vector<int> parent_dims_;
};
261
operator <<(std::ostream & out,const grid_info_t & grid_info)262 inline std::ostream &operator<<(
263 std::ostream &out, const grid_info_t &grid_info) {
264 out << grid_info.str();
265 return out;
266 }
267
// Helper to incrementally split a grid into blocks, consuming dimensions
// starting from the innermost (last) one. Size-1 dimensions are skipped.
class grid_splitter_t {
public:
    grid_splitter_t(const grid_info_t &grid)
        : grid_(grid), cur_idx_(grid.ndims() - 1), cur_stride_(1) {
        skip_size_1_dims();
        ir_assert(cur_idx_ >= 0);
    }

    // Remaining (not yet consumed) size of the current dimension.
    int cur_block() const {
        if (is_empty()) return 1;

        return grid_.dim(cur_idx_) / cur_stride_;
    }

    // True when all dimensions are fully consumed.
    bool is_empty() const { return cur_idx_ == -1; }

    // True when a block of `size` evenly divides the current remainder.
    bool can_pop_block(int size) const {
        if (is_empty()) return false;
        return cur_block() % size == 0;
    }

    // Consumes a block of `size` and returns the index expression for it.
    // Defined out of line.
    expr_t pop_block(int size);

private:
    void skip_size_1_dims() {
        while (cur_idx_ >= 0 && grid_.dim(cur_idx_) == 1)
            cur_idx_--;
    }

    grid_info_t grid_;

    // Index of the dimension currently being consumed; -1 when done.
    int cur_idx_;
    // Part of the current dimension that is already consumed.
    int cur_stride_;
};
302
// Kind of a layout stride value.
enum class stride_kind_t {
    undef, // Not initialized.
    fixed, // Known compile-time constant stride.
    unknown, // Stride exists but its value is not known.
};
308
// Stride of a layout block: either a fixed compile-time value or unknown.
class stride_t {
public:
    stride_t() = default;

    // Implicit conversion from a fixed stride value.
    stride_t(dim_t stride) : stride_t(stride_kind_t::fixed, stride) {}

    bool operator==(const stride_t &other) const {
        return (kind_ == other.kind_) && (stride_ == other.stride_);
    }

    bool operator!=(const stride_t &other) const { return !operator==(other); }

    size_t get_hash() const { return ir_utils::get_hash(kind_, stride_); }

    // Conversion to the stride value; asserts the stride is fixed.
    operator dim_t() const {
        ir_assert(kind_ == stride_kind_t::fixed);
        return stride_;
    }

    bool is_fixed() const { return kind_ == stride_kind_t::fixed; }

    bool is_unknown() const { return kind_ == stride_kind_t::unknown; }

    // Product is fixed only when both operands are fixed; otherwise the
    // result becomes unknown.
    stride_t &operator*=(const stride_t &other) {
        if (is_fixed() && other.is_fixed()) {
            stride_ *= other.stride_;
        } else {
            set_unknown();
        }
        return *this;
    }

    // Quotient is fixed only when both operands are fixed; otherwise the
    // result becomes unknown. Uses integer division.
    stride_t &operator/=(const stride_t &other) {
        if (is_fixed() && other.is_fixed()) {
            stride_ /= other.stride_;
        } else {
            set_unknown();
        }
        return *this;
    }

    std::string str() const {
        std::ostringstream oss;
        if (is_fixed()) {
            oss << stride_;
        } else {
            oss << "(unknown)";
        }
        return oss.str();
    }

    IR_DEFINE_DUMP()

    static stride_t unknown() { return stride_t(stride_kind_t::unknown); }

private:
    stride_t(stride_kind_t kind, dim_t stride = 0)
        : kind_(kind), stride_(stride) {}

    void set_unknown() {
        kind_ = stride_kind_t::unknown;
        stride_ = 0;
    }

    stride_kind_t kind_ = stride_kind_t::undef;
    dim_t stride_ = 0;
};
376
operator <<(std::ostream & out,const stride_t & stride)377 inline std::ostream &operator<<(std::ostream &out, const stride_t &stride) {
378 out << stride.str();
379 return out;
380 }
381
// Stride multiplication; the result is unknown unless both are fixed.
inline stride_t operator*(const stride_t &a, const stride_t &b) {
    stride_t result(a);
    result *= b;
    return result;
}
386
// Mixed stride/scalar multiplication (scalar on the right).
inline stride_t operator*(const stride_t &a, dim_t b) {
    stride_t rhs(b);
    return a * rhs;
}
390
operator *(dim_t a,const stride_t & b)391 inline stride_t operator*(dim_t a, const stride_t &b) {
392 return stride_t(a) * b;
393 }
394
395 struct block_t {
396 block_t() = default;
397
block_tdnnl::impl::gpu::jit::block_t398 block_t(int dim_idx, dim_t block, const stride_t &stride)
399 : dim_idx(dim_idx), block(block), stride(stride) {}
400
is_equaldnnl::impl::gpu::jit::block_t401 bool is_equal(const block_t &other) const {
402 return (dim_idx == other.dim_idx) && (block == other.block)
403 && (stride == other.stride);
404 }
405
get_hashdnnl::impl::gpu::jit::block_t406 size_t get_hash() const {
407 return ir_utils::get_hash(dim_idx, block, stride);
408 }
409
strdnnl::impl::gpu::jit::block_t410 std::string str() const {
411 std::ostringstream oss;
412 oss << "block_t(dim_idx = " << dim_idx;
413 oss << ", block = " << block;
414 oss << ", stride = " << stride;
415 oss << ")";
416 return oss.str();
417 }
418
419 IR_DEFINE_DUMP()
420
421 int dim_idx; // Dimension index.
422 dim_t block; // Block size.
423 stride_t stride; // Stride between elements of the block.
424 };
425
operator <<(std::ostream & out,const block_t & b)426 inline std::ostream &operator<<(std::ostream &out, const block_t &b) {
427 out << b.str();
428 return out;
429 }
430
431 class layout_t {
432 public:
433 static const int max_ndims = 6;
434
layout_t()435 layout_t() : type_(type_t::undef()), ndims_(0), offset_(0) {
436 sanity_check();
437 }
438
439 layout_t(const type_t &type, const expr_t &offset,
440 const std::string &format, const std::vector<dim_t> &dims = {},
441 bool do_normalize = true);
442
layout_t(const memory_desc_wrapper & mdw,const std::string & format,bool do_normalize=true)443 layout_t(const memory_desc_wrapper &mdw, const std::string &format,
444 bool do_normalize = true)
445 : layout_t(mdw.data_type(), mdw.offset0(), format,
446 std::vector<dim_t>(
447 mdw.padded_dims(), mdw.padded_dims() + mdw.ndims()),
448 do_normalize) {}
449
layout_t(const memory_desc_wrapper & mdw,const char * format,bool do_normalize=true)450 layout_t(const memory_desc_wrapper &mdw, const char *format,
451 bool do_normalize = true)
452 : layout_t(mdw, std::string(format), do_normalize) {}
453
454 layout_t(const memory_desc_wrapper &mdw, bool do_normalize = true);
455
layout_t(const type_t & type,const expr_t & offset,const std::vector<dim_t> & dims,bool do_normalize=true)456 layout_t(const type_t &type, const expr_t &offset,
457 const std::vector<dim_t> &dims, bool do_normalize = true)
458 : type_(type), ndims_(int(dims.size())), offset_(offset) {
459 dim_t stride = 1;
460 for (int i = ndims_ - 1; i >= 0; i--) {
461 blocks_.emplace_back(i, dims[i], stride);
462 stride *= dims[i];
463 }
464 if (do_normalize) blocks_ = normalize_blocks(ndims_, blocks_);
465 sanity_check();
466 }
467
layout_t(const type_t & type,int ndims,const expr_t & offset,const std::vector<block_t> & blocks,bool do_normalize=true)468 layout_t(const type_t &type, int ndims, const expr_t &offset,
469 const std::vector<block_t> &blocks, bool do_normalize = true)
470 : type_(type), ndims_(ndims), offset_(offset), blocks_(blocks) {
471 if (do_normalize) blocks_ = normalize_blocks(ndims_, blocks_);
472 sanity_check();
473 }
474
layout_t(const type_t & type,const expr_t & offset,const layout_t & other,bool do_normalize)475 layout_t(const type_t &type, const expr_t &offset, const layout_t &other,
476 bool do_normalize)
477 : layout_t(type, other.ndims(), offset, other.blocks(), do_normalize) {}
478
is_empty() const479 bool is_empty() const { return ndims_ == 0; }
480
ndims() const481 int ndims() const { return ndims_; }
482
elems() const483 dim_t elems() const {
484 dim_t ret = 1;
485 for (auto &b : blocks_)
486 ret *= b.block;
487 return ret;
488 }
489
490 // Storage size in bytes.
size() const491 dim_t size() const {
492 if (is_empty()) return 0;
493 dim_t max_stride = 1;
494 for (auto &b : blocks_) {
495 max_stride = std::max(max_stride, dim_t(b.block * b.stride));
496 }
497 return max_stride * type().size();
498 }
499
500 template <typename T = expr_t>
offset(const std::vector<T> & args={},bool ignore_offset=false) const501 T offset(
502 const std::vector<T> &args = {}, bool ignore_offset = false) const {
503 if (args.empty()) return expr_cast<T>(offset_);
504
505 ir_assert(int(args.size()) == ndims()) << "Dimensions do not match.";
506
507 T off = 0;
508 auto _args = args;
509 for (auto &eb : enumerated_blocks()) {
510 auto &b = eb.second;
511 auto &idx = _args[b.dim_idx];
512 if (ir_utils::is_equal(idx, T(0))) continue;
513
514 // Do not use modulus for outermost blocks.
515 auto i = is_outermost(eb) ? idx : (idx % b.block);
516 off = i * dim_t(b.stride) + off;
517 idx /= b.block;
518 }
519 if (ignore_offset) return off;
520
521 T off0 = expr_cast<T>(offset_);
522 return off0 + off;
523 }
524
type() const525 const type_t &type() const { return type_; }
526
dims() const527 std::vector<dim_t> dims() const {
528 std::vector<dim_t> dims(ndims(), 1);
529 for (auto &b : blocks_) {
530 dims[b.dim_idx] *= b.block;
531 }
532 return dims;
533 }
534
dim(int dim_idx) const535 dim_t dim(int dim_idx) const {
536 dim_t ret = 1;
537 for (auto &b : blocks_) {
538 if (b.dim_idx == dim_idx) ret *= b.block;
539 }
540 return ret;
541 }
542
blocks() const543 const std::vector<block_t> &blocks() const { return blocks_; }
544
set_offset(const expr_t & offset)545 void set_offset(const expr_t &offset) { offset_ = offset; }
546
is_strictly_equal(const layout_t & other,bool compare_offset=true) const547 bool is_strictly_equal(
548 const layout_t &other, bool compare_offset = true) const {
549 if (!type_.is_equal(other.type_)) return false;
550 if (compare_offset && !offset_.is_equal(other.offset_)) return false;
551 if (!ir_utils::is_equal(blocks_, other.blocks_)) return false;
552 return true;
553 }
554
operator ==(const layout_t & other) const555 bool operator==(const layout_t &other) const { return is_equal(other); }
556
operator !=(const layout_t & other) const557 bool operator!=(const layout_t &other) const { return !operator==(other); }
558
is_equal(const layout_t & other,bool compare_offset=true) const559 bool is_equal(const layout_t &other, bool compare_offset = true) const {
560 return normalize().is_strictly_equal(other.normalize(), compare_offset);
561 }
562
get_hash() const563 size_t get_hash() const {
564 return ir_utils::get_hash(type_, ndims_, offset_, blocks_);
565 }
566
567 template <typename T>
operator ()(const std::vector<T> & args) const568 T operator()(const std::vector<T> &args) const {
569 return offset(args);
570 }
571
572 template <typename T = expr_t>
offset_in_bytes(const std::vector<T> & args={},bool ignore_offset=false) const573 T offset_in_bytes(
574 const std::vector<T> &args = {}, bool ignore_offset = false) const {
575 return offset(args, ignore_offset) * type().size();
576 }
577
desc_str(bool dnnl_style=false) const578 std::string desc_str(bool dnnl_style = false) const {
579 if (is_empty()) return "(nil)";
580 if (!dnnl_style && blocks_.empty()) return "(scalar)";
581 std::string ret;
582 stride_t dense_stride(1);
583 std::vector<bool> seen(ndims());
584 for (auto &eb : enumerated_blocks()) {
585 auto &b = eb.second;
586 std::string b_str;
587 if (dnnl_style && is_outermost(eb)) {
588 b_str.append(1, (seen[b.dim_idx] ? 'A' : 'a') + b.dim_idx);
589 } else {
590 b_str = std::to_string(b.block);
591 b_str.append(1, 'a' + b.dim_idx);
592 }
593 if (!dnnl_style) {
594 if (b.stride.is_unknown()) {
595 b_str.append(1, '?');
596 } else if (b.stride != dense_stride) {
597 b_str.append(1, '*');
598 }
599 }
600 ret = b_str + ret;
601 dense_stride = b.stride * b.block;
602 seen[b.dim_idx] = true;
603 }
604 return ret;
605 }
606
str() const607 std::string str() const {
608 if (is_empty()) return "(nil)";
609 std::ostringstream oss;
610 oss << desc_str();
611 if (!has_zero_offset()) oss << " offset: " << offset_;
612 return oss.str();
613 }
614
615 IR_DEFINE_DUMP()
616
617 memory_desc_t to_dnnl(const dim_t *dims_hint) const;
618
619 // Returns a vector of <block index, block> pairs.
620 // The innermost block (first) has index 0.
enumerated_blocks() const621 std::vector<std::pair<int, block_t>> enumerated_blocks() const {
622 std::vector<std::pair<int, block_t>> ret;
623 for (int i = 0; i < int(blocks_.size()); i++) {
624 ret.emplace_back(i, blocks_[i]);
625 }
626 return ret;
627 }
628
strides(int dim_idx) const629 std::vector<dim_t> strides(int dim_idx) const {
630 std::vector<dim_t> ret;
631 for (auto &b : blocks_)
632 if (b.dim_idx == dim_idx) ret.push_back(b.stride);
633 return ret;
634 }
635
636 // eb is <block index, block> pair, see enumerated_blocks().
is_outermost(const std::pair<int,block_t> & eb) const637 bool is_outermost(const std::pair<int, block_t> &eb) const {
638 return is_outermost(eb, blocks_);
639 }
640
is_plain() const641 bool is_plain() const {
642 std::vector<bool> seen(ndims());
643 for (auto &b : blocks_) {
644 if (seen[b.dim_idx]) return false;
645 seen[b.dim_idx] = true;
646 }
647 return true;
648 }
649
has_zero_offset() const650 bool has_zero_offset() const { return offset_.is_equal(expr_t(0)); }
651
has_unknown_strides() const652 bool has_unknown_strides() const {
653 for (auto &b : blocks_)
654 if (b.stride.is_unknown()) return true;
655 return false;
656 }
657
658 // Returns a canonical representation of the layout:
659 // - Size one blocks are removed
660 // - Consecutive dense blocks are merged
normalize() const661 layout_t normalize() const {
662 auto blocks = normalize_blocks(ndims(), blocks_);
663 return layout_t(type(), ndims(), offset(), blocks);
664 }
665
transpose() const666 layout_t transpose() const {
667 if (ndims() != 2) ir_error_not_expected();
668
669 // Flip: 0 -> 1, 1 -> 0.
670 auto blocks = blocks_;
671 for (auto &b : blocks)
672 b.dim_idx ^= 1;
673
674 return layout_t(type(), ndims(), offset(), blocks);
675 }
676
677 // Returns a new (sub-)layout that fully contains the passed sub-tensor.
678 // Strides are kept unchanged.
679 // Assumption: the original layout can be tiled by the passed sub-tensor.
680 // For example: XaYb4a2b can be tiled into 2x2 sub-tensors but it's not
681 // possible to tile it into 3x2 sub-tensors.
682 layout_t map(const tensor_t &tensor) const;
683
684 layout_t reinterpret(
685 const type_t &new_type, bool do_normalize = true) const;
686
retype(const type_t & new_type) const687 layout_t retype(const type_t &new_type) const {
688 auto ret = *this;
689 ret.type_ = new_type;
690 return ret;
691 }
692
is_dense() const693 bool is_dense() const {
694 stride_t stride = 1;
695 for (auto &b : blocks_) {
696 if (b.stride != stride) return false;
697 stride *= b.block;
698 }
699 return true;
700 }
701
702 // Returns true if the layout has at least n inner blocks. For example:
703 // NChw32n16c - 2 inner blocks.
is_n_blocked(int n) const704 bool is_n_blocked(int n) const {
705 int block_count[layout_t::max_ndims] = {0};
706 for (auto &b : blocks_)
707 block_count[b.dim_idx]++;
708
709 int ninner_blocks = 0;
710 stride_t stride = 1;
711 for (auto &b : blocks_) {
712 if (b.stride != stride) break; // Not dense anymore.
713 if (block_count[b.dim_idx] == 1) break; // Outer block.
714 stride *= b.block;
715 ir_assert(block_count[b.dim_idx] > 1);
716 block_count[b.dim_idx]--;
717 ninner_blocks++;
718 }
719
720 return ninner_blocks >= n;
721 }
722
723 // Returns a packed layout where all blocks are contiguous, without gaps.
make_dense() const724 layout_t make_dense() const {
725 dim_t stride = 1;
726 auto new_blocks = blocks_;
727 for (auto &b : new_blocks) {
728 b.stride = stride;
729 stride *= b.block;
730 }
731 return layout_t(type(), ndims(), 0, new_blocks);
732 }
733
make_strided(int _stride) const734 layout_t make_strided(int _stride) const {
735 stride_t stride = _stride;
736 auto new_blocks = blocks_;
737 for (auto &b : new_blocks) {
738 b.stride = stride;
739 stride *= b.block;
740 }
741 return layout_t(type(), ndims(), 0, new_blocks);
742 }
743
744 // Returns an equivalent layout where the specified block is split into two.
745 // block0 - inner block size.
746 // block1 - outer block size.
747 layout_t split_block(const std::pair<int, block_t> &eb, dim_t block0,
748 dim_t block1) const;
749
750 // Splits blocks so that they can be used to form `multi_blocks` without
751 // crossing the block boundaries. `multi_blocks` are ordered from innermost
752 // to outermost. Returns an empty layout if such a split is not possible.
753 // Example (all blocks are ordered from innermost to outermost):
754 // Input blocks: [4, 4, 2]
755 // Multi-blocks: [8, 2]
756 // Output blocks: [4, 2, 2, 2]
757 layout_t split_into_multi_blocks(
758 const std::vector<dim_t> &multi_blocks) const;
759
760 layout_t split_into_multi_blocks_with_hint(
761 std::vector<dim_t> &multi_blocks) const;
762
add_outer_block(int dim_idx,dim_t block,dim_t stride=-1) const763 layout_t add_outer_block(
764 int dim_idx, dim_t block, dim_t stride = -1) const {
765 if (stride == -1) stride = elems();
766 ir_assert(stride >= elems());
767 ir_assert(dim_idx < ndims());
768 auto new_blocks = blocks();
769 new_blocks.emplace_back(dim_idx, block, stride);
770 return layout_t(type(), ndims(), offset(), new_blocks);
771 }
772
773 tensor_t split_into_dense_tile(dim_t tile_elems, dim_t outer_block) const;
774
775 // Returns a tensor corresponding to the biggest innermost sub-layout so that
776 // 1) It consists of consecutive blocks only.
777 // 2) It contains less or equal than max_tile_elems elements.
778 // 3) It is dense if is_dense_tile is true.
779 tensor_t split_into_max_tile(
780 dim_t max_tile_elems, bool is_dense_tile) const;
781
split(const grid_info_t & grid) const782 tensor_t split(const grid_info_t &grid) const {
783 std::vector<dim_t> tile_dims(ndims(), 1);
784 ir_assert(elems() % grid.elems() == 0) << "Can't split across grid.";
785
786 dim_t cur_elems_per_tile = 1;
787 dim_t elems_per_tile = elems() / grid.elems();
788 for (auto &b : blocks()) {
789 dim_t block
790 = std::min(b.block, elems_per_tile / cur_elems_per_tile);
791 tile_dims[b.dim_idx] *= block;
792 cur_elems_per_tile *= block;
793 }
794 ir_assert(cur_elems_per_tile == elems_per_tile)
795 << "Can't split across grid.";
796
797 return split(tensor_t(tile_dims), grid);
798 }
799
split(const tensor_t & tile,const grid_info_t & grid,std::vector<block_t> * outer_blocks=nullptr) const800 tensor_t split(const tensor_t &tile, const grid_info_t &grid,
801 std::vector<block_t> *outer_blocks = nullptr) const {
802 ir_assert(ndims() == tile.ndims())
803 << "Number of dimensions doesn't match.";
804 ir_assert(tile.has_zero_start());
805
806 if (outer_blocks) outer_blocks->resize(0);
807
808 if (grid.elems() == 1) return tile;
809
810 dim_t total_elems = elems();
811 dim_t tile_elems = tile.elems();
812
813 grid_splitter_t grid_splitter(grid);
814 ir_assert(tile_elems * grid.elems() == total_elems)
815 << "Tile/grid dimensions do not match.";
816 MAYBE_UNUSED(total_elems);
817 MAYBE_UNUSED(tile_elems);
818
819 std::vector<dim_t> dims(tile.ndims(), 1);
820 std::vector<expr_t> start(tile.ndims(), 0);
821 std::vector<dim_t> rem_dims = tile.dims();
822 for (auto &eb : enumerated_blocks()) {
823 auto &b = eb.second;
824 if (b.block == 1) continue;
825
826 dim_t &e = rem_dims[b.dim_idx];
827 if (e > 1) {
828 if (e % b.block == 0) {
829 e /= b.block;
830 } else if (b.block % e == 0) {
831 auto tmp_layout = split_block(eb, e, b.block / e);
832 return tmp_layout.split(tile, grid, outer_blocks);
833 } else {
834 ir_error_not_expected() << "Can't split across grid.";
835 }
836 } else {
837 dim_t next_chunk
838 = math::gcd(b.block, grid_splitter.cur_block());
839 if (b.block == next_chunk) {
840 auto idx = grid_splitter.pop_block(next_chunk);
841 start[b.dim_idx] += idx * dims[b.dim_idx];
842 if (outer_blocks) outer_blocks->push_back(b);
843 } else if (b.block % next_chunk == 0) {
844 auto tmp_layout
845 = split_block(eb, next_chunk, b.block / next_chunk);
846 return tmp_layout.split(tile, grid, outer_blocks);
847 } else {
848 ir_error_not_expected() << "Can't split across grid.";
849 }
850 }
851 dims[b.dim_idx] *= b.block;
852 }
853 return tensor_t(tile.dims(), start);
854 }
855
856 // Iterates through tiles of the layout, calling `f` with relative offsets
857 // for each tile. The iteration order is defined by the layout blocks -
858 // absolute 1D offsets are increasing between callback calls.
859 template <typename F>
for_each_tile(const tensor_t & tile,const F & f) const860 void for_each_tile(const tensor_t &tile, const F &f) const {
861 ir_assert(tile.ndims() == ndims());
862 ir_assert(tile.has_zero_start());
863 for (int i = 0; i < ndims(); i++) {
864 ir_assert(dim(i) % tile.dims()[i] == 0);
865 }
866
867 int nblocks = int(blocks().size());
868 std::vector<dim_t> sub_blocks(nblocks);
869 for (int i = 0; i < nblocks; i++)
870 sub_blocks[i] = blocks()[i].block;
871
872 for (int i = 0; i < ndims(); i++) {
873 dim_t dim = tile.dims()[i];
874 for (auto &eb : enumerated_blocks()) {
875 auto &b = eb.second;
876 if (b.dim_idx != i) continue;
877 int block_idx = eb.first;
878 if (b.block >= dim) {
879 ir_assert(b.block % dim == 0);
880 sub_blocks[block_idx] = b.block / dim;
881 break;
882 }
883 sub_blocks[block_idx] = 1;
884 ir_assert(dim % b.block == 0);
885 dim /= b.block;
886 }
887 }
888
889 int ntiles = int(elems() / tile.elems());
890
891 std::vector<dim_t> sub_block_idxs(nblocks);
892 for (int i = 0; i < ntiles; i++) {
893 // Convert sub-block indices to dimension indices.
894 std::vector<dim_t> dims(ndims(), 1);
895 std::vector<dim_t> start(ndims());
896 for (int j = 0; j < nblocks; j++) {
897 auto &b = blocks()[j];
898 dim_t k = sub_block_idxs[j]
899 * (blocks()[j].block / sub_blocks[j]);
900 start[b.dim_idx] += dims[b.dim_idx] * k;
901 dims[b.dim_idx] *= b.block;
902 }
903
904 // Pass dimension offsets to the callback.
905 f(start);
906
907 // Move to the next vector of indices.
908 for (int j = 0; j < nblocks; j++) {
909 auto &idx = sub_block_idxs[j];
910 if (idx + 1 < sub_blocks[j]) {
911 idx++;
912 break;
913 }
914 idx = 0;
915 }
916 }
917 }
918
919 // eb is <block index, block> pair, see enumerated_blocks().
is_outermost(const std::pair<int,block_t> & eb,const std::vector<block_t> & blocks)920 static bool is_outermost(const std::pair<int, block_t> &eb,
921 const std::vector<block_t> &blocks) {
922 int dim_idx = eb.second.dim_idx;
923 for (int i = 0; i < int(blocks.size()); i++) {
924 if (blocks[i].dim_idx == dim_idx && i > eb.first) return false;
925 }
926 return true;
927 }
928
929 // Assume that layouts are normalized.
930 static void align_layouts(layout_t &a, layout_t &b);
931
normalize_blocks(int ndims,const std::vector<block_t> & blocks,bool remove_size_1_blocks=true)932 static std::vector<block_t> normalize_blocks(int ndims,
933 const std::vector<block_t> &blocks,
934 bool remove_size_1_blocks = true) {
935 auto new_blocks = blocks;
936
937 // Remove blocks of size 1.
938 if (remove_size_1_blocks) {
939 for (auto it = new_blocks.begin(); it != new_blocks.end();) {
940 if (it->block == 1) {
941 it = new_blocks.erase(it);
942 } else {
943 ++it;
944 }
945 }
946 }
947
948 // Merge same dimension blocks.
949 block_t prev_b;
950 prev_b.dim_idx = -1;
951 for (auto it = new_blocks.begin(); it != new_blocks.end();) {
952 if (it->dim_idx == prev_b.dim_idx
953 && it->stride == (prev_b.stride * prev_b.block)) {
954 auto &b = *(it - 1);
955 b.block *= it->block;
956 prev_b = b;
957 it = new_blocks.erase(it);
958 } else {
959 prev_b = *it;
960 ++it;
961 }
962 }
963
964 return new_blocks;
965 }
966
private:
    // Returns vector of <dimension index, block size> pairs.
    static std::vector<std::pair<int, dim_t>> parse_format(
            const std::string &format, int ndims_hint);

    // Returns vector of <dimension letter, block size> pairs.
    static std::vector<std::pair<char, dim_t>> parse_letter_blocks(
            const std::string &format);

    // Consistency checks of the layout state; implemented out of line.
    void sanity_check() const;

    // Helper for splitting the layout into the given multi-blocks; the
    // resulting block sizes are reported via out_multi_blocks.
    // Implemented out of line.
    layout_t split_into_multi_blocks_impl(
            const std::vector<dim_t> &multi_blocks,
            std::vector<dim_t> *out_multi_blocks) const;

    // Data type of the layout.
    type_t type_;

    // Number of dimensions.
    int ndims_;

    // Offset to the start of the layout (in elements of type).
    expr_t offset_;

    // Blocks ordered from innermost to outermost.
    std::vector<block_t> blocks_;
};
994
operator <<(std::ostream & out,const layout_t & layout)995 inline std::ostream &operator<<(std::ostream &out, const layout_t &layout) {
996 out << layout.str();
997 return out;
998 }
999
// Stores a mask (boolean bound-check) expression for every element of a dense
// layout. Masks are deduplicated: each unique expression gets an integer id
// (mask2ids_: expression -> id, id2masks_: id -> expression) and masks_ keeps
// one id per element, with -1 meaning "no mask set".
class mask_tensor_t {
public:
    mask_tensor_t() = default;

    // Creates a mask tensor over `layout` with all elements unset (-1).
    // The layout must be dense so element offsets are 0..elems()-1.
    mask_tensor_t(const layout_t &layout)
        : layout_(layout), masks_(layout.elems(), -1) {
        ir_assert(layout.is_dense());
    }

    // Constructs from pre-built id tables; masks.size() must match the
    // number of layout elements.
    mask_tensor_t(const layout_t &layout, const std::vector<int> &masks,
            const object_eq_map_t<expr_t, int> &mask2ids,
            const std::vector<expr_t> &id2masks)
        : layout_(layout)
        , masks_(masks)
        , mask2ids_(mask2ids)
        , id2masks_(id2masks) {
        ir_assert(int(masks.size()) == elems()) << "Incompatible size.";
    }

    const type_t &type() const { return layout_.type(); }

    const layout_t &layout() const { return layout_; }

    dim_t elems() const { return layout_.elems(); }

    // Assigns `mask` to the element at offset `off`. Empty masks are ignored.
    // Reuses the id of an identical expression seen before; otherwise
    // registers a new id.
    void set_mask(dim_t off, const expr_t &mask) {
        ir_assert(0 <= off && off < elems()) << "Incorrect offset.";
        if (mask.is_empty()) return;

        auto ret = mask2ids_.insert({mask, int(mask2ids_.size())});
        int id = ret.first->second;
        masks_[off] = id;

        if (ret.second) id2masks_.push_back(mask);
    }

    // Returns the mask at offset `off`. Precondition: the mask was set
    // (masks_[off] != -1); otherwise this indexes id2masks_ with -1.
    const expr_t &mask(dim_t off) const {
        ir_assert(0 <= off && off < elems());
        return id2masks_[masks_[off]];
    }

    // Simplifies every unique mask under the constraint set, then rebuilds
    // mask2ids_ and remaps element ids for masks that became identical after
    // simplification.
    void simplify(const constraint_set_t &cset) {
        for (auto &mask : id2masks_) {
            auto new_mask = jit::simplify(mask, cset);
            // Some complex expressions need more than one simplify() call.
            int max_tries = 5;
            for (int i = 0; i < max_tries; i++) {
                mask = new_mask;
                new_mask = jit::simplify(new_mask, cset);
                if (new_mask.is_equal(mask)) break;
            }
        }
        mask2ids_.clear();
        for (int i = 0; i < int(id2masks_.size()); i++) {
            auto ret = mask2ids_.insert({id2masks_[i], i});
            if (!ret.second) {
                // id i collapsed into an earlier identical mask; redirect
                // all elements referencing i to the surviving id.
                for (auto &m : masks_)
                    if (m == i) m = ret.first->second;
            }
        }
    }

    // Returns the masks restricted to `tile` as a new (dense) mask tensor.
    // NOTE(review): the source offset is layout_(tile_start) + layout_(sub_start),
    // which assumes layout_'s offset function is additive over these
    // coordinates — confirm for non-trivial layouts.
    mask_tensor_t map(const tensor_t &tile) const {
        auto tile_start = expr_cast<dim_t>(tile.start());
        auto sub_layout = layout_.map(tensor_t(tile.dims()));
        mask_tensor_t sub_mask(sub_layout);
        ir_utils::for_each(
                tile.dims(), [&](const std::vector<dim_t> &sub_start) {
                    dim_t sub_off = sub_layout(sub_start);
                    dim_t off = layout_(tile_start) + layout_(sub_start);
                    sub_mask.set_mask(sub_off, mask(off));
                });
        return sub_mask;
    }

    // Reinterprets the mask tensor with a different element type (a flat 1D
    // layout). All original elements covered by one new element must share
    // the same mask id; otherwise an empty tensor is returned.
    mask_tensor_t reinterpret(const type_t &new_type) const {
        ir_assert(!is_empty()) << "Can't reinterpret.";
        dim_t bytes = elems() * type().size();
        // NOTE(review): `&&` (not `||`) allows a non-divisible size when
        // bytes <= new_type.size() (widening to a single element);
        // presumably intentional — confirm.
        if (bytes % new_type.size() != 0 && bytes > new_type.size())
            return mask_tensor_t();
        int new_mask_size = std::max((int)(bytes / new_type.size()), 1);
        std::vector<int> new_masks(new_mask_size);
        for (dim_t i = 0; i < bytes; i += new_type.size()) {
            // INT_MAX acts as an "unset" sentinel: valid ids are always
            // < masks_.size(), so the first byte initializes mask_id below.
            int mask_id = std::numeric_limits<int>::max();
            for (int j = 0; j < new_type.size() && j < bytes; j++) {
                int cur_mask_id = masks_[(i + j) / type().size()];
                if (mask_id >= int(masks_.size())) {
                    mask_id = cur_mask_id;
                } else if (mask_id != cur_mask_id) {
                    // Mask is not consistent, can't reinterpret.
                    return mask_tensor_t();
                }
            }
            ir_assert(0 <= mask_id && mask_id < int(masks_.size()));
            new_masks[i / new_type.size()] = mask_id;
        }
        dim_t new_elmes = utils::div_up(bytes, new_type.size());
        layout_t _1d_layout(new_type, 0, std::vector<dim_t> {new_elmes});
        return mask_tensor_t(_1d_layout, new_masks, mask2ids_, id2masks_);
    }

    // Packs the masks into an nmasks-wide shuffle expression. Requires the
    // masks to repeat with period nmasks (element i must have the same mask
    // as element i % nmasks); returns an empty expr_t otherwise.
    expr_t to_expr(int nmasks) const {
        if (elems() % nmasks != 0) return expr_t();

        std::vector<expr_t> vec(nmasks);
        for (int i = 0; i < elems(); i++) {
            auto &channel_mask = vec[i % nmasks];
            auto &cur_mask = id2masks_[masks_[i]];
            if (channel_mask.is_empty()) {
                channel_mask = cur_mask;
                continue;
            }
            if (!channel_mask.is_equal(cur_mask)) return expr_t();
        }
        auto e = shuffle_t::make(vec);
        e = jit::simplify(e);
        e = jit::simplify_propagate_shuffle(e);
        return e;
    }

    bool is_empty() const { return layout_.is_empty(); }

    // One line per element: "mask #<i>: <expr>" or "(nil)" when unset.
    std::string str() const {
        std::ostringstream oss;
        for (int i = 0; i < int(elems()); i++) {
            if (i != 0) oss << std::endl;
            oss << "mask #" << i << ": ";
            if (masks_[i] == -1) {
                oss << "(nil)";
            } else {
                oss << id2masks_[masks_[i]];
            }
        }
        return oss.str();
    }

    IR_DEFINE_DUMP()

private:
    layout_t layout_;
    // Per-element mask id; -1 means "no mask".
    std::vector<int> masks_;

    // Deduplication tables: expression <-> id.
    object_eq_map_t<expr_t, int> mask2ids_;
    std::vector<expr_t> id2masks_;
};
1145
operator <<(std::ostream & out,const mask_tensor_t & mask_tensor)1146 inline std::ostream &operator<<(
1147 std::ostream &out, const mask_tensor_t &mask_tensor) {
1148 out << mask_tensor.str();
1149 return out;
1150 }
1151
// Describes how one tensor dimension of a view is computed: expr_ is the
// indexing expression in terms of view variables, mask_ is an optional
// bound-check predicate where placeholder_var() stands for the dimension
// value.
class tdim_info_t {
public:
    tdim_info_t() = default;

    tdim_info_t(const expr_t &expr, const expr_t &mask)
        : expr_(expr), mask_(mask) {}

    // Number of view variables registered via add_vvar().
    int nvargs() const { return nvargs_; }

    const expr_t &expr() const { return expr_; }

    const expr_t &mask() const { return mask_; }

    // Returns mask_ with placeholder_var() replaced by tvalue and each view
    // variable vvars[i] (if present) replaced by vvalues[i].
    expr_t mask(const expr_t &tvalue, const std::vector<expr_t> &vvars,
            const std::vector<expr_t> &vvalues) const {
        auto ret = substitute(mask_, placeholder_var(), tvalue);
        for (int i = 0; i < int(vvars.size()); i++) {
            if (contains_object(ret, vvars[i])) {
                ret = substitute(ret, vvars[i], vvalues[i]);
            }
        }
        return ret;
    }

    // View dimension index of the arg_idx-th registered view variable.
    int vidx(int arg_idx) const {
        ir_assert(arg_idx < nvargs());
        return vidxs_[arg_idx];
    }

    // Stride associated with the arg_idx-th registered view variable
    // (computed by compute_stride(), defined out of line).
    stride_t vstride(int arg_idx) const {
        ir_assert(arg_idx < nvargs());
        return vstrides_[arg_idx];
    }

    bool is_empty() const { return expr_.is_empty(); }

    // True when the dimension maps directly to a single view variable.
    bool is_identity() const { return is_var(expr_); }

    bool is_fixed_stride(int arg_idx) const {
        ir_assert(arg_idx < nvargs());
        return vstrides_[arg_idx].is_fixed();
    }

    // Registers a view variable used by expr_. At most max_nvargs (2)
    // variables can be registered per tensor dimension.
    void add_vvar(int vidx, const expr_t &varg) {
        ir_assert(nvargs_ + 1 <= max_nvargs);
        vidxs_[nvargs_] = vidx;
        vstrides_[nvargs_] = compute_stride(expr_, nvargs_, varg);
        nvargs_++;
    }

    // Placeholder variable substituted by the dimension value when
    // instantiating masks (see mask() above).
    static const expr_t &placeholder_var() {
        static expr_t ph_var = var_t::make(type_t::s32(), "_ph");
        return ph_var;
    }

private:
    static const int max_nvargs = 2;

    static stride_t compute_stride(const expr_t &e, int idx, const expr_t &var);

    expr_t expr_;

    int nvargs_ = 0;
    std::array<stride_t, max_nvargs> vstrides_;
    std::array<int, max_nvargs> vidxs_;
    expr_t mask_;
};
1219
1220 class view_t {
1221 public:
1222 view_t() = default;
1223
view_t(const std::vector<expr_t> & vvars,int ntdims)1224 view_t(const std::vector<expr_t> &vvars, int ntdims)
1225 : vvars_(vvars)
1226 , vdims_(vvars.size())
1227 , vstart_(vvars.size())
1228 , tdims_(ntdims) {}
1229
1230 // Constructs view from a layout.
view_t(const layout_t & layout,const std::vector<expr_t> & _vvars={},uint32_t bound_check_mask=0)1231 explicit view_t(const layout_t &layout,
1232 const std::vector<expr_t> &_vvars = {},
1233 uint32_t bound_check_mask = 0)
1234 : vvars_(_vvars)
1235 , vdims_(layout.dims())
1236 , vstart_(layout.ndims(), 0)
1237 , tdims_(layout.ndims())
1238 , tlayout_(layout) {
1239 if (vvars_.empty()) vvars_ = create_vvars(layout.ndims());
1240 for (int i = 0; i < nvdims(); i++) {
1241 expr_t i_mask;
1242 if ((bound_check_mask & (1 << i)) != 0)
1243 i_mask = (placeholder_var() < layout.dim(i));
1244 set_tdim(i, vvars_[i], i_mask);
1245 }
1246 }
1247
vvars() const1248 const std::vector<expr_t> &vvars() const { return vvars_; }
1249
vdims() const1250 const std::vector<dim_t> &vdims() const { return vdims_; }
1251
vstart(int vidx) const1252 expr_t vstart(int vidx) const { return vstart_[vidx]; }
1253
tlayout() const1254 const layout_t tlayout() const { return tlayout_; }
1255
nvdims() const1256 int nvdims() const { return int(vdims_.size()); }
1257
ntdims() const1258 int ntdims() const { return int(tdims_.size()); }
1259
velems() const1260 dim_t velems() const {
1261 dim_t ret = 1;
1262 for (int i = 0; i < nvdims(); i++)
1263 ret *= vdims_[i];
1264 return ret;
1265 }
1266
vvar(int idx) const1267 const expr_t &vvar(int idx) const {
1268 ir_assert(idx < nvdims());
1269 return vvars_[idx];
1270 }
1271
tdim(int idx) const1272 const tdim_info_t &tdim(int idx) const {
1273 ir_assert(idx < ntdims());
1274 return tdims_[idx];
1275 }
1276
set_tdim(int tidx,const expr_t & _texpr,expr_t mask={})1277 void set_tdim(int tidx, const expr_t &_texpr, expr_t mask = {}) {
1278 ir_assert(tdims_[tidx].is_empty());
1279
1280 auto texpr = simplify(_texpr);
1281 ir_assert(!is_const(texpr)) << "Tensor dimension can't be a constant.";
1282
1283 tdim_info_t tdim(texpr, mask);
1284 for (int i = 0; i < nvdims(); i++) {
1285 if (contains_object(texpr, vvars_[i])) tdim.add_vvar(i, vvars_[i]);
1286 }
1287 ir_assert(tdim.nvargs() > 0)
1288 << "Tensor dimension must have at least one "
1289 "view dimension that maps to it.";
1290 tdims_[tidx] = tdim;
1291 }
1292
set_vdim(const expr_t & varg,dim_t vdim,const expr_t & vstart=expr_t (0))1293 void set_vdim(
1294 const expr_t &varg, dim_t vdim, const expr_t &vstart = expr_t(0)) {
1295 int vidx = vvar_index(varg);
1296 ir_assert(vstart_[vidx].is_empty());
1297 vstart_[vidx] = vstart;
1298 vdims_[vidx] = vdim;
1299 }
1300
set_tlayout(const layout_t & tlayout)1301 void set_tlayout(const layout_t &tlayout) { tlayout_ = tlayout; }
1302
str() const1303 std::string str() const {
1304 using ir_utils::operator<<;
1305
1306 if (is_empty()) return "(nil)";
1307 std::ostringstream oss;
1308 oss << ir_utils::make_seq_print_helper(vdims_, "x");
1309 if (!has_zero_vstart()) oss << " vstart: [" << vstart_ << "]";
1310 oss << " tlayout: " << tlayout_;
1311 return oss.str();
1312 }
1313
IR_DEFINE_DUMP()1314 IR_DEFINE_DUMP()
1315
1316 bool is_empty() const { return vdims_.empty(); }
1317
has_zero_vstart() const1318 bool has_zero_vstart() const {
1319 for (int i = 0; i < nvdims(); i++)
1320 if (!is_zero(vstart_[i])) return false;
1321 return true;
1322 }
1323
has_tmask(int tidx) const1324 bool has_tmask(int tidx) const {
1325 ir_assert(tidx >= 0 && tidx < ntdims());
1326 return !tdims_[tidx].mask().is_empty();
1327 }
1328
type() const1329 const type_t &type() const { return tlayout_.type(); }
1330
offset(const std::vector<expr_t> & vargs={},bool ignore_offset=false) const1331 expr_t offset(const std::vector<expr_t> &vargs = {},
1332 bool ignore_offset = false) const {
1333 auto targs = cvt_vargs_to_targs(vargs);
1334 return tlayout_.offset(targs, ignore_offset);
1335 }
1336
offset_in_bytes(const std::vector<expr_t> & vargs={},bool ignore_offset=false) const1337 expr_t offset_in_bytes(const std::vector<expr_t> &vargs = {},
1338 bool ignore_offset = false) const {
1339 return offset(vargs, ignore_offset) * type().size();
1340 }
1341
vvar_index(const expr_t & vvar) const1342 int vvar_index(const expr_t &vvar) const {
1343 for (size_t i = 0; i < vvars_.size(); i++)
1344 if (vvar.is_same(vvars_[i])) return int(i);
1345 ir_error_not_expected() << "Can't find view dimension.";
1346 return -1;
1347 }
1348
1349 template <typename T>
operator ()(const std::vector<T> & vargs) const1350 T operator()(const std::vector<T> &vargs) const {
1351 auto targs = cvt_vargs_to_targs(vargs);
1352 return tlayout_(targs);
1353 }
1354
1355 view_t create_sub_view(const tensor_t &sub_tensor) const;
1356
retype(const type_t & new_type) const1357 view_t retype(const type_t &new_type) const {
1358 auto ret = *this;
1359 ret.tlayout_ = tlayout_.retype(new_type);
1360 return ret;
1361 }
1362
make_dense() const1363 view_t make_dense() const {
1364 auto ret = *this;
1365 ret.tlayout_ = tlayout_.make_dense();
1366 return ret;
1367 }
1368
can_convert_to_vlayout() const1369 bool can_convert_to_vlayout() const {
1370 if (nvdims() != ntdims()) return false;
1371 for (int i = 0; i < nvdims(); i++) {
1372 if (!tdims_[i].expr().is_same(vvars_[i])) return false;
1373 if (!tdims_[i].is_fixed_stride(0)) return false;
1374 }
1375 return true;
1376 }
1377
1378 // FIXME: Offset of the returned layout is always 0.
create_pseudo_vlayout() const1379 layout_t create_pseudo_vlayout() const {
1380 return create_pseudo_vlayout(tlayout_);
1381 }
1382
create_dense_vlayout() const1383 layout_t create_dense_vlayout() const {
1384 return create_pseudo_vlayout().make_dense();
1385 }
1386
create_vlayout(bool force_zero_offset=false) const1387 layout_t create_vlayout(bool force_zero_offset = false) const {
1388 ir_assert(can_convert_to_vlayout()) << "Can't convert view to layout.";
1389 if (force_zero_offset) return tlayout_.map(tensor_t(vdims_));
1390 return tlayout_.map(tensor_t(vdims_, vstart_));
1391 }
1392
vlayout_size() const1393 dim_t vlayout_size() const { return create_vlayout().size(); }
1394
has_same_vlayout(const view_t & other,bool compare_offset=true) const1395 bool has_same_vlayout(
1396 const view_t &other, bool compare_offset = true) const {
1397 return create_vlayout().is_equal(
1398 other.create_vlayout(), compare_offset);
1399 }
1400
split(const grid_info_t & grid,tensor_t & vtile) const1401 view_t split(const grid_info_t &grid, tensor_t &vtile) const {
1402 auto vlayout = create_pseudo_vlayout();
1403 vtile = vlayout.split(grid);
1404 return create_sub_view(vtile);
1405 }
1406
split(const grid_info_t & grid) const1407 view_t split(const grid_info_t &grid) const {
1408 tensor_t vtile;
1409 return split(grid, vtile);
1410 }
1411
1412 // Tile is assumed to be dense.
split_into_dense_tile(dim_t & tile_elems,dim_t & outer_block) const1413 tensor_t split_into_dense_tile(
1414 dim_t &tile_elems, dim_t &outer_block) const {
1415 auto vlayout = create_pseudo_vlayout();
1416 std::vector<dim_t> blocks = {tile_elems, outer_block};
1417 vlayout = vlayout.split_into_multi_blocks_with_hint(blocks);
1418 if (vlayout.is_empty()) return tensor_t();
1419 tile_elems = blocks[0];
1420 outer_block = blocks[1];
1421 return vlayout.split_into_dense_tile(tile_elems, outer_block);
1422 }
1423
1424 // Returns a tensor corresponding to the biggest innermost sub-layout so that
1425 // 1) It consists of consecutive blocks only.
1426 // 2) It contains less or equal than max_tile_elems elements.
1427 // 3) It is dense if is_dense_tile is true.
split_into_max_tile(dim_t max_tile_elems,bool is_dense_tile) const1428 tensor_t split_into_max_tile(
1429 dim_t max_tile_elems, bool is_dense_tile) const {
1430 auto vlayout = create_pseudo_vlayout();
1431 return vlayout.split_into_max_tile(max_tile_elems, is_dense_tile);
1432 }
1433
1434 template <typename F>
for_each_tile(const tensor_t & tile,const F & f) const1435 void for_each_tile(const tensor_t &tile, const F &f) const {
1436 auto vlayout = create_dense_vlayout();
1437 vlayout.for_each_tile(tile, f);
1438 }
1439
1440 view_t substitute(const expr_t &from, const expr_t &to) const;
1441
create_mask_tensor(const constraint_set_t & cset) const1442 mask_tensor_t create_mask_tensor(const constraint_set_t &cset) const {
1443 auto _vlayout = create_dense_vlayout();
1444 mask_tensor_t mask_tensor(_vlayout);
1445 std::vector<dim_t> vargs(nvdims());
1446 create_mask_tensor(mask_tensor, _vlayout, 0, vargs);
1447 mask_tensor.simplify(cset);
1448 return mask_tensor;
1449 }
1450
try_create_buffer_view(view_t & buf_view,view_t & inv_view) const1451 bool try_create_buffer_view(view_t &buf_view, view_t &inv_view) const {
1452 buf_view = view_t(create_vvars(ntdims()), ntdims());
1453 inv_view = view_t(vvars(), ntdims());
1454 for (int i = 0; i < nvdims(); i++) {
1455 inv_view.set_vdim(vvars()[i], vdims()[i]);
1456 }
1457 for (int i = 0; i < ntdims(); i++) {
1458 auto &tdim = tdims_[i];
1459 auto &buf_vvar = buf_view.vvars()[i];
1460 if (tdim.is_identity()) {
1461 int vidx = tdim.vidx(0);
1462 buf_view.set_vdim(buf_vvar, vdims()[vidx], vstart(vidx));
1463 buf_view.set_tdim(i, buf_vvar, tdim.mask());
1464 inv_view.set_tdim(i, tdim.expr());
1465 continue;
1466 }
1467 int buf_vdim = 0;
1468 bool ok = true;
1469 for (int j = 0; j < tdim.nvargs(); j++) {
1470 int vidx = tdim.vidx(j);
1471 auto &vvar = vvars()[vidx];
1472 int vdim = vdims()[vidx];
1473 if (vdim == 1) continue;
1474 auto A = tdim.expr();
1475 auto B = jit::substitute(A, vvar, vvar + 1);
1476 auto C = simplify(B - A);
1477 if (!is_const(C)) {
1478 ok = false;
1479 break;
1480 }
1481 buf_vdim += to_cpp<int>(C) * (vdim - 1);
1482 }
1483 buf_vdim++;
1484
1485 if (!ok) return false;
1486
1487 auto buf_vstart = tdim.expr();
1488 auto inv_vstart = tdim.expr();
1489 for (int j = 0; j < tdim.nvargs(); j++) {
1490 int vidx = tdim.vidx(j);
1491 buf_vstart = jit::substitute(
1492 buf_vstart, vvars()[vidx], vstart(vidx));
1493 inv_vstart
1494 = jit::substitute(inv_vstart, vvars()[vidx], expr_t(0));
1495 }
1496 buf_vstart = simplify(buf_vstart);
1497 inv_vstart = simplify(inv_vstart);
1498
1499 if (!is_const(inv_vstart)) return false;
1500
1501 buf_view.set_vdim(buf_vvar, buf_vdim, buf_vstart);
1502
1503 // Check that mask doesn't contain vvars - they can't be accessed
1504 // in the buffered view.
1505 auto &tmask = tdim.mask();
1506 for (auto &vvar : vvars()) {
1507 if (contains_object(tmask, vvar)) { return false; }
1508 }
1509
1510 buf_view.set_tdim(i, buf_vvar, tmask);
1511 inv_view.set_tdim(i, tdim.expr() - inv_vstart);
1512 }
1513 buf_view.set_tlayout(tlayout_);
1514 return true;
1515 }
1516
placeholder_var()1517 static const expr_t &placeholder_var() {
1518 return tdim_info_t::placeholder_var();
1519 }
1520
1521 static std::vector<expr_t> create_vvars(int nvdims);
1522
1523 private:
1524 template <typename SrcT = expr_t, typename DstT = SrcT>
cvt_vargs_to_targs(const std::vector<SrcT> & _vargs={}) const1525 std::vector<DstT> cvt_vargs_to_targs(
1526 const std::vector<SrcT> &_vargs = {}) const {
1527 std::vector<expr_t> vargs = expr_cast<expr_t>(_vargs);
1528 if (vargs.empty()) vargs.resize(nvdims(), 0);
1529
1530 for (int i = 0; i < nvdims(); i++) {
1531 if (!is_zero(vstart_[i])) vargs[i] += vstart_[i];
1532 }
1533
1534 std::vector<expr_t> targs(ntdims());
1535 for (int i = 0; i < ntdims(); i++) {
1536 targs[i] = tdims_[i].expr();
1537 for (int j = 0; j < nvdims(); j++) {
1538 targs[i] = jit::substitute(targs[i], vvars_[j], vargs[j]);
1539 }
1540 }
1541 for (int i = 0; i < ntdims(); i++) {
1542 targs[i] = const_fold(targs[i]);
1543 }
1544 return expr_cast<DstT>(targs);
1545 }
1546
1547 layout_t create_pseudo_vlayout(const layout_t &tlayout) const;
1548
create_mask_tensor(mask_tensor_t & mask_tensor,const layout_t & _vlayout,int vidx,std::vector<dim_t> & vargs) const1549 void create_mask_tensor(mask_tensor_t &mask_tensor,
1550 const layout_t &_vlayout, int vidx,
1551 std::vector<dim_t> &vargs) const {
1552 if (vidx == _vlayout.ndims()) {
1553 bool is_init = false;
1554 std::vector<expr_t> vvalues;
1555 std::vector<expr_t> targs;
1556 expr_t mask = bool_imm_t::make(true);
1557 for (int i = 0; i < ntdims(); i++) {
1558 auto &tdim = tdims_[i];
1559 if (tdim.mask().is_empty()) continue;
1560 if (!is_init) {
1561 // Lazily initialize values
1562 vvalues = vstart_;
1563 for (int i = 0; i < nvdims(); i++)
1564 vvalues[i] += vargs[i];
1565 targs = cvt_vargs_to_targs<dim_t, expr_t>(vargs);
1566 is_init = true;
1567 }
1568 mask &= tdim.mask(targs[i], vvars_, vvalues);
1569 }
1570 mask_tensor.set_mask(_vlayout(vargs), mask);
1571 return;
1572 }
1573
1574 for (int i = 0; i < vdims()[vidx]; i++) {
1575 vargs[vidx] = i;
1576 create_mask_tensor(mask_tensor, _vlayout, vidx + 1, vargs);
1577 }
1578 }
1579
1580 std::vector<expr_t> vvars_;
1581 std::vector<dim_t> vdims_;
1582 std::vector<expr_t> vstart_;
1583
1584 std::vector<tdim_info_t> tdims_;
1585 layout_t tlayout_;
1586 };
1587
operator <<(std::ostream & out,const view_t & view)1588 inline std::ostream &operator<<(std::ostream &out, const view_t &view) {
1589 out << view.str();
1590 return out;
1591 }
1592
// Maps old dimension indices to new dimension indices (several old indices
// may map to the same new index). Used to re-index layouts via map().
class dim_assignment_t {
public:
    dim_assignment_t() = default;

    // Initializes all old indices as unassigned (-1).
    dim_assignment_t(int old_ndims, int new_ndims)
        : old_ndims_(old_ndims)
        , new_ndims_(new_ndims)
        , assignments_(old_ndims, -1) {}

    // Records new_idx as the image of old_idx; both must be in range.
    void assign(int old_idx, int new_idx) {
        ir_assert(0 <= old_idx && old_idx < old_ndims_);
        ir_assert(0 <= new_idx && new_idx < new_ndims_);
        assignments_[old_idx] = new_idx;
    }

    // Assigns the same new index to several old indices.
    void assign(const std::vector<int> &old_idxes, int new_idx) {
        for (auto old_idx : old_idxes) {
            assign(old_idx, new_idx);
        }
    }

    // Returns the new index for old_idx, or -1 when unassigned.
    int operator[](int old_idx) const {
        ir_assert(old_idx >= 0 && old_idx < old_ndims());
        return assignments_[old_idx];
    }

    int old_ndims() const { return old_ndims_; }

    int new_ndims() const { return new_ndims_; }

    bool is_empty() const { return old_ndims_ == 0 && new_ndims_ == 0; }

    // Applies the assignment to a layout; implemented out of line.
    layout_t map(const layout_t &layout) const;

private:
    int old_ndims_ = 0;
    int new_ndims_ = 0;

    // assignments_[old_idx] = new_idx.
    std::vector<int> assignments_;
};
1634
1635 std::vector<dim_t> normalize_conv_dims(std::vector<dim_t> &dims,
1636 bool with_groups, int groups, bool is_dw, bool reduced_to_1d,
1637 bool add_groups, bool is_wei);
1638
1639 layout_t normalize_conv_layout(const layout_t &_layout, bool with_groups,
1640 int groups, bool is_dw, bool reduced_to_1d, bool add_groups,
1641 bool is_wei);
1642
1643 void normalize_conv_layouts(layout_t &src_layout, layout_t &wei_layout,
1644 layout_t &dst_layout, bool with_groups, int groups, bool is_dw,
1645 bool reduced_to_1d, bool add_groups);
1646
1647 } // namespace jit
1648 } // namespace gpu
1649 } // namespace impl
1650 } // namespace dnnl
1651
1652 #endif
1653