/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file storage_rewrite.cc
 * \brief Memory access pattern analysis and optimization.
 *  Re-write data access to enable memory sharing when possible.
 */
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/target_info.h>
#include <map>
#include <list>
#include <memory>
#include <vector>
#include <limits>
#include <algorithm>
#include <unordered_set>
#include <unordered_map>
#include "ir_util.h"
#include "../arithmetic/compute_expr.h"
#include "../runtime/thread_storage_scope.h"

namespace tvm {
namespace ir {

using runtime::StorageRank;
using runtime::StorageScope;

// Find a linear pattern of storage access,
// used for liveness analysis.
// Composite scopes (loop/thread_launch/IfThenElse) are represented by two points:
//   before_scope -> scope_body -> after_scope
//
// The linear_seq_ stores before_scope and after_scope.
// The accesses to the arrays are stored at the after_scope point.
//
// Define "scope" as the body of For/thread_launch/IfThenElse.
// This pass tries to detect the last point at which we need to keep the memory
// alive under the same scope as the allocate.
// The storage needs to be kept alive between the allocate and the last access.
// The free point is only inserted at the same scope as the allocate.
//
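// As an illustrative sketch (not a literal dump of the pass), a body such as
//
//   allocate A; for (i) { A[i] = f(i) }; ... = A[0]
//
// is linearized roughly as
//   [for: before_scope, for: after_scope (touches A), leaf stmt (touches A)]
// and the reverse scan in LivenessAnalysis finds the last entry touching A,
// which becomes its kill (free) point.
//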
class LinearAccessPatternFinder final : public IRVisitor {
 public:
  /*! \brief record the touch history of a statement. */
  struct StmtEntry {
    // The statement
    const Node* stmt;
    // The index in the linear_seq_ to point to end of the nested scope.
    // This is only set to non-zero if stmt is a nested scope.
    // if offset > 0, this is the begin entry and the end entry is at current_index + offset
    // if offset < 0, this is the end entry and the begin entry is at current_index + offset
    int64_t scope_pair_offset{0};
    // The buffer variables this statement touched.
    std::vector<const Variable*> touched;
  };
  // The scope of each allocation
  struct AllocEntry {
    // Scope used for allocation.
    StorageScope storage_scope;
    // scope level
    size_t level{0};
    // allocation stmt
    const Allocate* alloc{nullptr};
  };

  void Visit_(const Allocate* op) final {
    size_t level = scope_.size();
    const Variable* buf = op->buffer_var.get();
    auto it = alloc_info_.find(buf);
    CHECK(it != alloc_info_.end());
    CHECK(it->second.alloc == nullptr);
    it->second.alloc = op;
    it->second.level = level;
    IRVisitor::Visit_(op);
  }
  void Visit_(const Store* op) final {
    scope_.push_back(StmtEntry());
    // visit subexpr
    IRVisitor::Visit_(op);
    // Add write access.
    const Variable* buf = op->buffer_var.get();
    auto it = alloc_info_.find(buf);
    if (it != alloc_info_.end() && it->second.alloc) {
      CHECK_LT(it->second.level, scope_.size());
      scope_[it->second.level].touched.push_back(buf);
    }
    StmtEntry e = scope_.back();
    scope_.pop_back();
    if (e.touched.size() != 0) {
      e.stmt = op;
      linear_seq_.push_back(e);
    }
  }
  void Visit_(const Evaluate* op) final {
    scope_.push_back(StmtEntry());
    // visit subexpr
    IRVisitor::Visit_(op);
    StmtEntry e = scope_.back();
    scope_.pop_back();
    if (e.touched.size() != 0) {
      e.stmt = op;
      linear_seq_.push_back(e);
    }
  }
  void Visit_(const Load* op) final {
    // Add read access.
    IRVisitor::Visit_(op);
    const Variable* buf = op->buffer_var.get();
    auto it = alloc_info_.find(buf);
    if (it != alloc_info_.end() && it->second.alloc) {
      CHECK_LT(it->second.level, scope_.size())
          << "Load memory in places other than store.";
      scope_[it->second.level].touched.push_back(buf);
    }
  }
  void Visit_(const Call* op) final {
    if (op->is_intrinsic(intrinsic::tvm_address_of)) {
      const Load* l = op->args[0].as<Load>();
      this->Visit(l->index);
    } else {
      IRVisitor::Visit_(op);
    }
  }
  void Visit_(const Variable* buf) final {
    // A direct reference to the variable counts as a read.
    auto it = alloc_info_.find(buf);
    if (it != alloc_info_.end() && it->second.alloc) {
      CHECK_LT(it->second.level, scope_.size())
          << " buf=" << buf->name_hint;
      scope_[it->second.level].touched.push_back(buf);
    }
  }
  template<typename T>
  void VisitNewScope(const T* op) {
    scope_.push_back(StmtEntry());
    StmtEntry e;
    e.stmt = op;
    int64_t begin_index = static_cast<int64_t>(linear_seq_.size());
    // before scope.
    linear_seq_.push_back(e);
    IRVisitor::Visit_(op);
    // after scope.
    e.touched = std::move(scope_.back().touched);
    scope_.pop_back();
    int64_t end_index = static_cast<int64_t>(linear_seq_.size());
    CHECK_GT(end_index, begin_index);
    e.scope_pair_offset = begin_index - end_index;
    linear_seq_.push_back(e);
    // record the pointer to end index.
    CHECK_NE(end_index, 0U);
    linear_seq_[begin_index].scope_pair_offset = end_index - begin_index;
  }
  void Visit_(const AttrStmt* op) final {
    // Only record the outermost thread extent.
    if (op->attr_key == attr::thread_extent && !in_thread_env_) {
      in_thread_env_ = true;
      VisitNewScope(op);
      in_thread_env_ = false;
    } else if (op->attr_key == attr::extern_scope) {
      VisitNewScope(op);
    } else if (op->attr_key == attr::virtual_thread) {
      VisitNewScope(op);
    } else if (op->attr_key == attr::storage_scope) {
      const Variable* buf = op->node.as<Variable>();
      alloc_info_[buf].storage_scope =
          StorageScope::make(op->value.as<StringImm>()->value);
      IRVisitor::Visit_(op);
    } else {
      IRVisitor::Visit_(op);
    }
  }
  void Visit_(const IfThenElse* op) final {
    VisitNewScope(op);
  }

  void Visit_(const For* op) final {
    VisitNewScope(op);
  }

  void Visit_(const AssertStmt* op) final {
    VisitNewScope(op);
  }

  // linearized access sequence.
  std::vector<StmtEntry> linear_seq_;
  // The storage scope of each buffer
  std::unordered_map<const Variable*, AllocEntry> alloc_info_;

 private:
  // Whether already in thread env.
  bool in_thread_env_{false};
  // The scope stack.
  std::vector<StmtEntry> scope_;
};

// Verify if the statement can be safely run in an in-place fashion.
//
// Detect pattern: dst[index] = f(src[index])
//
// WARNING: the current detection algorithm cannot handle the case
// when a location in an array is written multiple times.
//
// For example, the following program will pass the check,
// but we cannot make A and B the same array.
//
//  A[0] = B[0] + 1
//  A[0] = B[0] + 1
//
// The high level code generator needs to ensure that the generated
// code only writes each location of the target array once.
//
// This is the case with IR generated by the current compute schedule.
// We explicitly return false if we find there is an extern block,
// which can contain arbitrary IR.
//
// Nevertheless, the inplace detector should be used with care.
// We may also consider introducing a condition checker that checks
// whether every index is visited only once, as an absolute sufficient condition.
//
// The code after inplace transformation is no longer idempotent.
//
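// Illustrative examples (not exhaustive) of what Check() accepts and rejects:
//
//   B[i] = A[i] * 2      // accepted: dst written purely from src at the same index
//   B[i] = A[i] + B[i]   // rejected: dst is also read (reduction)
//   B[i] = A[C[i]]       // rejected: nested load / index mismatch
//   B[i] = A[i + 1]      // rejected: load index differs from the store index
//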
class InplaceOpVerifier : public IRVisitor {
 public:
  bool Check(const Node* stmt,
             const Variable* dst,
             const Variable* src) {
    dst_ = dst;
    src_ = src;
    result_ = true;
    if (stmt->IsInstance<AttrStmt>()) {
      Visit_(static_cast<const AttrStmt*>(stmt));
    } else if (stmt->IsInstance<For>()) {
      Visit_(static_cast<const For*>(stmt));
    } else if (stmt->IsInstance<IfThenElse>()) {
      Visit_(static_cast<const IfThenElse*>(stmt));
    } else if (stmt->IsInstance<Store>()) {
      Visit_(static_cast<const Store*>(stmt));
    } else {
      return false;
    }
    return result_;
  }

  using IRVisitor::Visit_;

  void Visit(const NodeRef& e) final {
    if (!result_) return;
    IRVisitor::Visit(e);
  }

  void Visit_(const Variable* op) final {
    // assume all opaque accesses are unsafe
    if (op == dst_ || op == src_) {
      result_ = false; return;
    }
  }

  void Visit_(const Store* op) final {
    ++mem_nest_;
    this->Visit(op->index);
    --mem_nest_;
    if (op->buffer_var.get() == dst_) {
      store_ = op;
      this->Visit(op->value);
      this->Visit(op->predicate);
      store_ = nullptr;
    } else {
      this->Visit(op->value);
      this->Visit(op->predicate);
    }
  }

  void Visit_(const AttrStmt* op) final {
    // always reject extern code
    if (op->attr_key == attr::extern_scope ||
        op->attr_key == attr::volatile_scope) {
      result_ = false; return;
    }
    IRVisitor::Visit_(op);
  }

  void Visit_(const Load* op) final {
    const Variable* buf = op->buffer_var.get();
    // cannot read from dst_ (no reduction)
    if (buf == dst_) {
      result_ = false; return;
    }
    // do not allow indirect memory load
    if (mem_nest_ != 0) {
      result_ = false; return;
    }
    if (src_ == buf) {
      if (store_ == nullptr ||
          store_->value.type() != op->type ||
          !ir::Equal(store_->index, op->index)) {
        result_ = false; return;
      }
    }
    ++mem_nest_;
    IRVisitor::Visit_(op);
    --mem_nest_;
  }

 private:
  // result of the check
  bool result_{true};
  // destination memory
  const Variable* dst_;
  // source variable
  const Variable* src_;
  // counter of nested loads;
  // it is not safe to inplace when there is a nested load like A[B[i]]
  int mem_nest_{0};
  // The current store to be inspected
  const Store* store_{nullptr};
};

// Planner to plan and rewrite memory allocation.
class StoragePlanRewriter : public IRMutator {
 public:
  using StmtEntry = LinearAccessPatternFinder::StmtEntry;
  using AllocEntry = LinearAccessPatternFinder::AllocEntry;

  Stmt Rewrite(Stmt stmt, bool detect_inplace) {
    detect_inplace_ = detect_inplace;
    // plan the rewrite
    LinearAccessPatternFinder finder;
    finder.Visit(stmt);
    this->LivenessAnalysis(finder.linear_seq_);
    this->PlanMemory(finder.linear_seq_, finder.alloc_info_);
    this->PrepareNewAlloc();
    // start rewrite
    stmt = this->Mutate(stmt);
    if (attach_map_.count(nullptr)) {
      std::vector<Stmt> nest;
      for (StorageEntry* e : attach_map_.at(nullptr)) {
        // CHECK_EQ(e->scope.rank, 0);
        if (e->new_alloc.defined()) {
          nest.emplace_back(AttrStmt::make(
              e->alloc_var, attr::storage_scope,
              StringImm::make(e->scope.to_string()),
              Evaluate::make(0)));
          nest.push_back(e->new_alloc);
        }
      }
      stmt = MergeNest(nest, stmt);
    }
    return stmt;
  }
  Stmt Mutate_(const Store* op, const Stmt& s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Store>();
    auto it = alloc_map_.find(op->buffer_var.get());
    if (it == alloc_map_.end()) return stmt;
    return Store::make(it->second->alloc_var,
                       op->value,
                       RemapIndex(op->value.type(), op->index, it->second),
                       op->predicate);
  }
  Expr Mutate_(const Load* op, const Expr& e) final {
    Expr expr = IRMutator::Mutate_(op, e);
    op = expr.as<Load>();
    auto it = alloc_map_.find(op->buffer_var.get());
    if (it == alloc_map_.end()) return expr;
    return Load::make(op->type,
                      it->second->alloc_var,
                      RemapIndex(op->type, op->index, it->second),
                      op->predicate);
  }
  Expr Mutate_(const Variable* op, const Expr& e) final {
    auto it = alloc_map_.find(op);
    if (it != alloc_map_.end()) {
      if (it->second->bits_offset != 0) {
        LOG(WARNING) << "Use a merged buffer variable address, could cause error";
      }
      return it->second->alloc_var;
    } else {
      return e;
    }
  }
  Expr Mutate_(const Call* op, const Expr& e) final {
    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
      CHECK_EQ(op->args.size(), 5U);
      Type dtype = op->args[0].type();
      const Variable* buffer = op->args[1].as<Variable>();
      auto it = alloc_map_.find(buffer);
      if (it == alloc_map_.end()) return IRMutator::Mutate_(op, e);
      const StorageEntry* se = it->second;
      Expr offset = Mutate(op->args[2]);
      Expr extent = Mutate(op->args[3]);
      uint64_t elem_bits = dtype.bits() * dtype.lanes();
      CHECK_EQ(se->bits_offset % elem_bits, 0U);
      if (se->bits_offset != 0) {
        offset = make_const(offset.type(), se->bits_offset / elem_bits) + offset;
      }
      return Call::make(
          op->type, op->name,
          {op->args[0], se->alloc_var, offset, extent, op->args[4]},
          op->call_type);
    } else {
      return IRMutator::Mutate_(op, e);
    }
  }

  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    if (op->attr_key == attr::storage_scope) {
      return this->Mutate(op->body);
    } else if (op->attr_key == attr::thread_extent ||
               op->attr_key == attr::virtual_thread ||
               attr::IsPragmaKey(op->attr_key)) {
      // remake all the allocations at the attach scope.
      if (attach_map_.count(op)) {
        auto& svec = attach_map_[op];
        Stmt stmt = IRMutator::Mutate_(op, s);
        op = stmt.as<AttrStmt>();
        return AttrStmt::make(
            op->node, op->attr_key, op->value,
            MakeAttach(svec, op->body));
      } else {
        return IRMutator::Mutate_(op, s);
      }
    } else if (op->attr_key == attr::volatile_scope) {
      Stmt stmt = IRMutator::Mutate_(op, s);
      op = stmt.as<AttrStmt>();
      auto it = alloc_map_.find(op->node.as<Variable>());
      if (it == alloc_map_.end()) return stmt;
      return AttrStmt::make(
          it->second->alloc_var, op->attr_key, op->value, op->body);
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }
  Stmt Mutate_(const For* op, const Stmt& s) final {
    CHECK(op->for_type != ForType::Vectorized)
        << "VectorizeLoop before LiftStorageAlloc";
    // remake all the allocations at the attach scope.
    if (attach_map_.count(op)) {
      auto& svec = attach_map_[op];
      Stmt stmt = IRMutator::Mutate_(op, s);
      op = stmt.as<For>();
      return For::make(
          op->loop_var, op->min, op->extent, op->for_type, op->device_api,
          MakeAttach(svec, op->body));
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }

  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
    return this->Mutate(op->body);
  }

 private:
  struct StorageEntry {
    // The scope that this alloc attaches after.
    // For shared/local memory it is the beginning of the thread extent.
    // For global memory it is nullptr, meaning the beginning of everything.
    const Node* attach_scope_{nullptr};
    // The constant size of the buffer in bits, only used if it is constant
    uint64_t const_nbits{0};
    // The storage scope.
    StorageScope scope;
    // Allocs that share this entry.
    std::vector<const Allocate*> allocs;
    // The children of this entry, not including itself.
    std::vector<StorageEntry*> merged_children;
    // The replacement allocation, if any.
    Stmt new_alloc;
    // The var expr of the new allocation.
    VarExpr alloc_var;
    // The allocation element type.
    Type elem_type;
    // This is non-zero if this allocate is folded into another one;
    // the address (in bits) becomes alloc_var + bits_offset,
    // which can be effectively converted to an offset in the element type.
    // We need to convert bits_offset to an offset of the specific element type later.
    //
    // We use bits (instead of bytes) to support non-conventional indexing in hardware.
    // When we are merging buffers together, bits_offset is set to be aligned
    // to a certain value given by the max_simd_bits property of the special memory.
    //
    // This allows effective sharing among different types as long as their alignment
    // requirement fits into max_simd_bits.
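    //
    // Illustrative example (assumed sizes): if a float32 buffer of 64 elements
    // (2048 bits) and an int8 buffer of 100 elements (800 bits) are merged into
    // one tagged allocation with max_simd_bits = 128, the int8 child gets
    // bits_offset = 2048 (already a multiple of 128), and its indices are later
    // shifted by 2048 / 8 = 256 int8 elements in RemapIndex().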
    uint64_t bits_offset{0};
  };

  // Event entry in liveness analysis
  struct EventEntry {
    // variables we generate
    std::vector<const Variable*> gen;
    // variables we kill
    std::vector<const Variable*> kill;
  };

  Stmt MakeAttach(const std::vector<StorageEntry*>& svec,
                  Stmt body) {
    std::vector<Stmt> nest;
    for (StorageEntry* e : svec) {
      if (e->new_alloc.defined()) {
        nest.emplace_back(AttrStmt::make(
            e->alloc_var, attr::storage_scope,
            StringImm::make(e->scope.to_string()),
            Evaluate::make(0)));
        nest.push_back(e->new_alloc);
      }
    }
    return MergeNest(nest, body);
  }
  // Remap the index
  Expr RemapIndex(Type dtype, Expr index, StorageEntry* e) {
    if (e->bits_offset == 0) return index;
    uint64_t elem_bits = dtype.bits() * dtype.lanes();
    CHECK_EQ(e->bits_offset % elem_bits, 0U);
    return make_const(index.type(), e->bits_offset / elem_bits) + index;
  }
  // Prepare the new allocations
  void PrepareNewAlloc() {
    for (size_t i = 0; i < alloc_vec_.size(); ++i) {
      StorageEntry* e = alloc_vec_[i].get();
      attach_map_[e->attach_scope_].push_back(e);
    }
    // find allocation via attach map.
    for (auto &kv : attach_map_) {
      // find the element with the most bytes.
      std::vector<StorageEntry*>& vec = kv.second;
      // try to find merge, for tagged memory
      for (size_t i = 0; i < vec.size(); ++i) {
        StorageEntry* e = vec[i];
        if (e->scope.tag.length() != 0) {
          CHECK_NE(e->const_nbits, 0U)
              << "Special tagged memory must be const size";
          for (size_t j = 0; j < i; ++j) {
            if (e->scope == vec[j]->scope) {
              vec[j]->merged_children.push_back(e);
              break;
            }
          }
        }
      }
      // Start allocation
      for (size_t i = 0; i < vec.size(); ++i) {
        StorageEntry* e = vec[i];
        // already merged
        if (e->bits_offset != 0) continue;
        if (e->merged_children.size() != 0) {
          NewAllocTagMerged(e); continue;
        }
        // Get the allocation size.
        e->alloc_var = e->allocs[0]->buffer_var;
        Type alloc_type = e->allocs[0]->type;
        for (const Allocate* op : e->allocs) {
          if (op->type.lanes() > alloc_type.lanes()) {
            alloc_type = op->type;
          }
        }
        if (e->allocs.size() == 1) {
          // simply use the original allocation.
          Expr sz = arith::ComputeReduce<Mul>(e->allocs[0]->extents,
                                              make_const(Int(32), 1));
          e->new_alloc = Allocate::make(
              e->alloc_var, alloc_type, {sz},
              e->allocs[0]->condition, Evaluate::make(0));
          if (e->scope.tag.length() != 0) {
            MemoryInfo info = GetMemoryInfo(e->scope.to_string());
            uint64_t total_elem = e->const_nbits / e->elem_type.bits();
            CHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
                << "Allocation exceeds bound of memory tag " << e->scope.to_string();
          }
        } else {
          // Build a merged allocation
          Expr combo_size;
          for (const Allocate* op : e->allocs) {
            Expr sz = arith::ComputeReduce<Mul>(op->extents, make_const(Int(32), 1));
            auto nbits = op->type.bits() * op->type.lanes();
            if (const auto* imm = sz.as<IntImm>()) {
              if (imm->value > std::numeric_limits<int>::max() / nbits) {
                LOG(WARNING) << "The allocation requires: " << imm->value
                             << " * " << nbits
                             << " bits, which is greater than the maximum of"
                                " int32. The size is cast to int64."
                             << "\n";
                sz = make_const(Int(64), imm->value);
              }
            }
            // transform to bits
            auto sz_nbits = sz * nbits;
            if (combo_size.defined()) {
              combo_size = max(combo_size, sz_nbits);
            } else {
              combo_size = sz_nbits;
            }
          }
          // transform to the number of elements of alloc_type.
          auto type_bits = alloc_type.bits() * alloc_type.lanes();
          bool divided = analyzer_.CanProve(indexmod(combo_size, type_bits) == 0);
          combo_size = indexdiv(combo_size, type_bits);
          // round up when the size cannot be divided exactly
          if (!divided) {
            combo_size = combo_size + make_const(Int(32), 1);
          }
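          // For example (illustrative sizes): merging a 7-element float32
          // buffer (224 bits) with a 20-element float16 buffer (320 bits)
          // under alloc_type float32 gives combo_size = max(224, 320) = 320
          // bits; 320 / 32 = 10 divides exactly, so no extra element is added.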
          combo_size = ir::Simplify(combo_size);
          e->new_alloc = Allocate::make(
              e->alloc_var, alloc_type, {combo_size}, const_true(),
              Evaluate::make(0));
          if (e->scope.tag.length() != 0) {
            MemoryInfo info = GetMemoryInfo(e->scope.to_string());
            uint64_t total_elem = e->const_nbits / e->elem_type.bits();
            CHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
                << "Allocation exceeds bound of memory tag " << e->scope.to_string();
          }
        }
      }
    }
  }
  // New allocation for merged data
  void NewAllocTagMerged(StorageEntry* e) {
    CHECK_NE(e->scope.tag.length(), 0U);
    // allocate with element type.
    CHECK_NE(e->const_nbits, 0U);
    MemoryInfo info = GetMemoryInfo(e->scope.to_string());
    uint64_t total_bits = e->const_nbits;
    // By default, align to 32 bits.
    size_t align = 32;
    if (info.defined()) {
      align = info->max_simd_bits;
    }
    // Always align to max_simd_bits
    // so we can remap types by keeping this property
    if (total_bits % align != 0) {
      total_bits += align - (total_bits % align);
    }
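    // For example (illustrative): with align = 128 and total_bits = 200, the
    // rounding above yields 256, so every child's bits_offset below starts at
    // a multiple of max_simd_bits.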
    e->alloc_var = e->allocs[0]->buffer_var;
    for (StorageEntry* child : e->merged_children) {
      CHECK_NE(child->const_nbits, 0U);
      CHECK_NE(total_bits, 0U);
      child->bits_offset = total_bits;
      child->alloc_var = e->alloc_var;
      total_bits += child->const_nbits;
      if (total_bits % align != 0) {
        total_bits += align - (total_bits % align);
      }
    }
    uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes();
    Expr alloc_size = make_const(e->allocs[0]->extents[0].type(),
                                 (total_bits + type_bits - 1) / type_bits);
    e->new_alloc = Allocate::make(
        e->alloc_var, e->elem_type, {alloc_size}, const_true(),
        Evaluate::make(0));
    if (info.defined()) {
      CHECK_LE(total_bits, info->max_num_bits)
          << "Allocation exceeds bound of memory tag " << e->scope.to_string();
    }
  }
  // Liveness analysis to find the gen and kill point of each variable.
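  // Illustrative sketch: for a linearized sequence in which buffer A is touched
  // by entries {e2, e5, e7} (hypothetical names), the forward scan marks e2 as
  // A's gen event and the reverse scan marks e7 as A's kill event; PlanMemory
  // later calls Free(A) when it reaches the kill event.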
  void LivenessAnalysis(const std::vector<StmtEntry>& seq) {
    // find kill point, do a reverse linear scan.
    std::unordered_set<const Variable*> touched;
    for (size_t i = seq.size(); i != 0; --i) {
      const StmtEntry& s = seq[i - 1];
      for (const Variable* buffer : s.touched) {
        if (!touched.count(buffer)) {
          touched.insert(buffer);
          event_map_[s.stmt].kill.push_back(buffer);
        }
      }
    }
    // find gen point, do a forward scan
    touched.clear();
    for (size_t i = 0; i < seq.size(); ++i) {
      int64_t offset = seq[i].scope_pair_offset;
      if (offset < 0) continue;
      const StmtEntry& s = seq[i + offset];
      for (const Variable* buffer : s.touched) {
        if (!touched.count(buffer)) {
          touched.insert(buffer);
          event_map_[s.stmt].gen.push_back(buffer);
        }
      }
    }
  }
  void PlanNewScope(const Node* op) {
    if (thread_scope_ != nullptr) {
      CHECK(thread_scope_ == op);
      // erase all memory attached to this scope.
      for (auto it = const_free_map_.begin(); it != const_free_map_.end();) {
        if (it->second->attach_scope_ == op) {
          it = const_free_map_.erase(it);
        } else {
          ++it;
        }
      }
      for (auto it = sym_free_list_.begin(); it != sym_free_list_.end();) {
        if ((*it)->attach_scope_ == op) {
          it = sym_free_list_.erase(it);
        } else {
          ++it;
        }
      }
      thread_scope_ = nullptr;
    } else {
      thread_scope_ = op;
    }
  }

  // Memory plan algorithm
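  // Rough outline (informal): walk the linearized sequence in order; on each
  // gen event either reuse the storage of a buffer killed at the same stmt
  // (inplace, verified by InplaceOpVerifier) or call FindAlloc() to reuse a
  // compatible freed entry / create a new one; on each kill event return the
  // entry to the free lists via Free().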
  void PlanMemory(const std::vector<StmtEntry>& seq,
                  const std::unordered_map<const Variable*, AllocEntry>& alloc_info) {
    std::unordered_set<const Variable*> inplace_flag;

    for (size_t i = 0; i < seq.size(); ++i) {
      const StmtEntry& s = seq[i];
      auto it = event_map_.find(seq[i].stmt);

      // scope_pair_offset >= 0 means it is either
      // - a leaf stmt (offset = 0)
      // - the beginning of a scope (offset > 0)
      // In both cases, we need to handle the gen event correctly
      if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) {
        // Inplace operation detection
        // specially handle this
        bool detect_inplace = detect_inplace_ && (it->second.gen.size() <= 2);

        for (const Variable* var : it->second.gen) {
          CHECK(alloc_info.count(var));
          const AllocEntry& ae = alloc_info.at(var);
          StorageEntry* dst_entry = nullptr;
          // inplace detection
          if (detect_inplace) {
            // only one inplace var for s.stmt
            bool inplace_found = false;
            for (const Variable* src : it->second.kill) {
              if (!inplace_flag.count(src) && alloc_map_.count(src)) {
                InplaceOpVerifier visitor;
                StorageEntry* src_entry = alloc_map_.at(src);
                if (src_entry->scope == ae.storage_scope &&
                    src_entry->attach_scope_ == thread_scope_ &&
                    src_entry->elem_type == ae.alloc->type.element_of() &&
                    visitor.Check(s.stmt, var, src)) {
                  uint64_t const_nbits =
                      static_cast<uint64_t>(ae.alloc->constant_allocation_size()) *
                      ae.alloc->type.bits() *
                      ae.alloc->type.lanes();
                  if (src_entry->const_nbits == const_nbits && !inplace_found) {
                    // successfully inplace
                    dst_entry = src_entry;
                    inplace_flag.insert(src);
                    inplace_found = true;
                  }
                }
              }
            }
          }
          if (dst_entry == nullptr) {
            dst_entry = FindAlloc(ae.alloc, thread_scope_, ae.storage_scope);
          }
          dst_entry->allocs.emplace_back(ae.alloc);
          alloc_map_[var] = dst_entry;
        }
      }
      // enter/exit new scope
      if (s.stmt->IsInstance<AttrStmt>()) {
        const auto* op = static_cast<const AttrStmt*>(s.stmt);
        if (op->attr_key == attr::thread_extent ||
            op->attr_key == attr::virtual_thread ||
            attr::IsPragmaKey(op->attr_key)) {
          PlanNewScope(op);
        } else {
          CHECK(op->attr_key == attr::extern_scope);
        }
      } else if (s.stmt->IsInstance<For>()) {
        const auto* op = static_cast<const For*>(s.stmt);
        if (op->for_type == ForType::Parallel) {
          if (thread_scope_ == nullptr || thread_scope_ == op) {
            PlanNewScope(op);
          }
        }
      }
      // scope_pair_offset <= 0 means it is either
      // - a leaf stmt (offset = 0)
      // - the end of a scope (offset < 0)
      // In both cases, we need to handle the kill event correctly
      if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
        for (const Variable* var : it->second.kill) {
          // skip buffers that have already been replaced by inplace reuse
          if (!inplace_flag.count(var)) {
            this->Free(var);
          }
        }
      }
    }
  }
  // Allocate new storage entry.
  StorageEntry* NewAlloc(const Allocate* op,
                         const Node* attach_scope,
                         const StorageScope& scope,
                         size_t const_nbits) {
    CHECK(op != nullptr);
    // Reuse was not successful, allocate a new buffer.
    std::unique_ptr<StorageEntry> entry(new StorageEntry());
    entry->attach_scope_ = attach_scope;
    entry->scope = scope;
    entry->elem_type = op->type.element_of();
    entry->const_nbits = const_nbits;
    StorageEntry* e = entry.get();
    alloc_vec_.emplace_back(std::move(entry));
    return e;
  }

  StorageEntry* FindAlloc(const Allocate* op,
                          const Node* attach_scope,
                          const StorageScope& scope) {
    CHECK(op != nullptr);
    // skip planning for local variables;
    // the compiler can do a better job with register allocation.
    const uint64_t match_range = 16;
    uint64_t op_elem_bits = op->type.bits() * op->type.lanes();
    uint64_t const_nbits = static_cast<uint64_t>(
        op->constant_allocation_size() * op_elem_bits);
    // disable reuse of small arrays, they will be lowered to registers in LLVM.
    // This rule only applies when we are not using special (tagged) memory.
    if (scope.tag.length() == 0) {
      if (scope.rank >= StorageRank::kWarp || op->type.is_handle()) {
        return NewAlloc(op, attach_scope, scope, const_nbits);
      }
      if (const_nbits > 0 && const_nbits <= 32) {
        return NewAlloc(op, attach_scope, scope, const_nbits);
      }
    }
    if (const_nbits != 0) {
      // constant allocation.
      auto begin = const_free_map_.lower_bound(const_nbits / match_range);
      auto mid = const_free_map_.lower_bound(const_nbits);
      auto end = const_free_map_.upper_bound(const_nbits * match_range);
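      // Illustrative numbers: for const_nbits = 1024 and match_range = 16,
      // free entries between 1024 and 16384 bits are tried first (larger or
      // equal buffers), then entries down to 64 bits (smaller buffers, which
      // are grown to 1024 bits when reused).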
      // start looking at buffers that are bigger than the required size first
      for (auto it = mid; it != end; ++it) {
        StorageEntry *e = it->second;
        if (e->attach_scope_ != attach_scope) continue;
        if (e->scope != scope) continue;
        // when the offset does not divide evenly, no reuse, e.g. float4 vs float3
        if (e->bits_offset % op_elem_bits != 0) continue;
        e->const_nbits = std::max(const_nbits, e->const_nbits);
        const_free_map_.erase(it);
        return e;
      }
      // then start looking at smaller buffers.
      for (auto it = mid; it != begin;) {
        --it;
        StorageEntry *e = it->second;
        if (e->attach_scope_ != attach_scope) continue;
        if (e->scope != scope) continue;
        if (e->elem_type != op->type.element_of()) continue;
        e->const_nbits = std::max(const_nbits, e->const_nbits);
        const_free_map_.erase(it);
        return e;
      }
    } else {
      // Simple strategy: round robin.
      for (auto it = sym_free_list_.begin();
           it != sym_free_list_.end(); ++it) {
        StorageEntry* e = *it;
        if (e->attach_scope_ != attach_scope) continue;
        if (e->scope != scope) continue;
        if (e->elem_type != op->type.element_of()) continue;
        sym_free_list_.erase(it);
        return e;
      }
    }
    return NewAlloc(op, attach_scope, scope, const_nbits);
  }
  // simulated free.
  void Free(const Variable* var) {
    auto it = alloc_map_.find(var);
    CHECK(it != alloc_map_.end());
    StorageEntry* e = it->second;
    CHECK_NE(e->allocs.size(), 0U);

    // disable reuse of small arrays, they will be lowered to registers in LLVM.
    // This rule only applies when we are not using special (tagged) memory.
    if (e->scope.tag.length() == 0) {
      // Disable sharing of local memory.
      if (e->scope.rank >= StorageRank::kWarp ||
          e->allocs[0]->type.is_handle()) return;
      // disable reuse of small arrays
      if (e->const_nbits > 0 && e->const_nbits <= 32) return;
    }
    // normal free.
    if (e->const_nbits != 0) {
      const_free_map_.insert({e->const_nbits, e});
    } else {
      sym_free_list_.push_back(e);
    }
  }
  // thread scope.
  const Node* thread_scope_{nullptr};
  // whether to enable inplace detection.
  bool detect_inplace_{false};
  // The gen/kill events at each statement.
  std::unordered_map<const Node*, EventEntry> event_map_;
  // constant size free map.
  std::multimap<uint64_t, StorageEntry*> const_free_map_;
  // symbolic free list, for non-constant items.
  std::list<StorageEntry*> sym_free_list_;
  // The allocation attach map
  std::unordered_map<const Node*, std::vector<StorageEntry*> > attach_map_;
  // The allocation assign map
  std::unordered_map<const Variable*, StorageEntry*> alloc_map_;
  // The allocations
  std::vector<std::unique_ptr<StorageEntry> > alloc_vec_;
  // analyzer
  arith::Analyzer analyzer_;
};

// Turn alloc into vector alloc
// if all of its accesses are of the same vector type.
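// For example (illustrative): an allocation of 16 x float32 whose only
// accesses are float32x4 loads/stores is rewritten to 4 x float32x4,
// provided the last extent is provably divisible by the lane factor
// (checked with the modular-set analysis below).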
class VectorAllocRewriter : public IRMutator {
 public:
  Expr Mutate_(const Load* op, const Expr& e) final {
    UpdateTypeMap(op->buffer_var.get(), op->type);
    return IRMutator::Mutate_(op, e);
  }

  Stmt Mutate_(const Store* op, const Stmt& s) final {
    UpdateTypeMap(op->buffer_var.get(), op->value.type());
    return IRMutator::Mutate_(op, s);
  }
  Expr Mutate_(const Call* op, const Expr& e) final {
    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
      Type dtype = op->args[0].type();
      const Variable* buffer = op->args[1].as<Variable>();
      UpdateTypeMap(buffer, dtype);
    }
    return IRMutator::Mutate_(op, e);
  }

  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Allocate>();
    const auto& tvec = acc_map_[op->buffer_var.get()];

    if (tvec.size() == 1 &&
        tvec[0].element_of() == op->type.element_of() &&
        tvec[0].lanes() % op->type.lanes() == 0 &&
        tvec[0].lanes() != op->type.lanes()) {
      int factor = tvec[0].lanes() / op->type.lanes();
      Array<Expr> extents = op->extents;
      arith::ModularSet me = analyzer_.modular_set(extents[extents.size() - 1]);
      if (me->base % factor == 0 && me->coeff % factor == 0) {
        extents.Set(extents.size() - 1,
                    extents[extents.size() - 1] / make_const(extents[0].type(), factor));
        return Allocate::make(
            op->buffer_var, tvec[0], extents,
            op->condition, op->body);
      }
    }
    return stmt;
  }

  void UpdateTypeMap(const Variable* buffer, Type t) {
    auto& tvec = acc_map_[buffer];
    if (std::find(tvec.begin(), tvec.end(), t) == tvec.end()) {
      tvec.push_back(t);
    }
  }

  // Internal access map
  std::unordered_map<const Variable*, std::vector<Type> > acc_map_;
  // internal analyzer
  arith::Analyzer analyzer_;
};


LoweredFunc PointerValueTypeRewrite(LoweredFunc f) {
  auto n = make_node<LoweredFuncNode>(*f.operator->());
  VectorAllocRewriter rewriter;
  n->body = rewriter.Mutate(n->body);
  for (Var arg : f->args) {
    if (arg.type().is_handle()) {
      const auto& tvec = rewriter.acc_map_[arg.get()];
      if (tvec.size() == 1) {
        Expr dtype = make_const(tvec[0], 0);
        n->handle_data_type.Set(arg, dtype);
      } else {
        // always set the data type to be non-vectorized so
        // load/store can still work via scalarization
        if (tvec.size() != 0 && !n->handle_data_type.count(arg)) {
          Expr dtype = make_const(tvec[0].with_lanes(1), 0);
          n->handle_data_type.Set(arg, dtype);
        }
      }
    }
  }
  return LoweredFunc(n);
}

Stmt StorageRewrite(Stmt stmt) {
  stmt = StoragePlanRewriter().Rewrite(stmt, true);
  return VectorAllocRewriter().Mutate(stmt);
}
}  // namespace ir
}  // namespace tvm