/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file storage_rewrite.cc
 * \brief Memory access pattern analysis and optimization.
 *  Re-write data access to enable memory sharing when possible.
 */
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/target_info.h>
#include <map>
#include <list>
#include <vector>
#include <memory>
#include <limits>
#include <algorithm>
#include <utility>
#include <unordered_set>
#include <unordered_map>
#include "ir_util.h"
#include "../arithmetic/compute_expr.h"
#include "../runtime/thread_storage_scope.h"

namespace tvm {
namespace ir {

using runtime::StorageRank;
using runtime::StorageScope;
// Find a linear pattern of storage access,
// used for liveness analysis.
// Composite scopes (loop/thread_launch/IfThenElse) are represented by two points:
// before_scope -> scope_body -> after_scope
//
// The linear_seq_ stores the before_scope and after_scope points.
// The accesses to the arrays are stored at the after_scope point.
//
// Define "scope" as the body of For/thread_launch/IfThenElse.
// This pass tries to detect the last point at which we need to keep the memory
// alive under the same scope as the allocate.
// The storage needs to be kept alive between the allocate and the last access.
// The free point is only inserted at the same scope as the allocate.
//
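// As a rough illustration (hypothetical IR, not taken from a real schedule),
// a program such as
//
//   allocate A
//   for (i, 0, n) {
//     A[i] = A[i] + 1
//   }
//
// is linearized into two entries for the For scope:
//
//   [k]   before_scope(For), scope_pair_offset = +1
//   [k+1] after_scope(For),  scope_pair_offset = -1, touched = {A, A}
//
// i.e. the accesses inside the loop body are attributed to the after_scope
// point of the scope at the allocation level, which is what the liveness
// analysis below consumes.
//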
class LinearAccessPatternFinder final : public IRVisitor {
 public:
  /*! \brief Record the touch history of a statement. */
  struct StmtEntry {
    // The statement
    const Node* stmt;
    // The index in linear_seq_ that points to the end of the nested scope.
    // This is only set to non-zero if stmt is a nested scope.
    // If offset > 0, this is the begin entry and the end entry is at current_index + offset.
    // If offset < 0, this is the end entry and the begin entry is at current_index + offset.
    int64_t scope_pair_offset{0};
    // The buffer variables this statement touched.
    std::vector<const Variable*> touched;
  };
  // The scope of each allocation
  struct AllocEntry {
    // Scope used for allocation.
    StorageScope storage_scope;
    // scope level
    size_t level{0};
    // allocation stmt
    const Allocate* alloc{nullptr};
  };

  void Visit_(const Allocate* op) final {
    size_t level = scope_.size();
    const Variable* buf = op->buffer_var.get();
    auto it = alloc_info_.find(buf);
    CHECK(it != alloc_info_.end());
    CHECK(it->second.alloc == nullptr);
    it->second.alloc = op;
    it->second.level = level;
    IRVisitor::Visit_(op);
  }
  void Visit_(const Store* op) final {
    scope_.push_back(StmtEntry());
    // visit subexpr
    IRVisitor::Visit_(op);
    // Add write access.
    const Variable* buf = op->buffer_var.get();
    auto it = alloc_info_.find(buf);
    if (it != alloc_info_.end() && it->second.alloc) {
      CHECK_LT(it->second.level, scope_.size());
      scope_[it->second.level].touched.push_back(buf);
    }
    StmtEntry e = scope_.back();
    scope_.pop_back();
    if (e.touched.size() != 0) {
      e.stmt = op;
      linear_seq_.push_back(e);
    }
  }
  void Visit_(const Evaluate* op) final {
    scope_.push_back(StmtEntry());
    // visit subexpr
    IRVisitor::Visit_(op);
    StmtEntry e = scope_.back();
    scope_.pop_back();
    if (e.touched.size() != 0) {
      e.stmt = op;
      linear_seq_.push_back(e);
    }
  }
  void Visit_(const Load* op) final {
    // Add read access.
    IRVisitor::Visit_(op);
    const Variable* buf = op->buffer_var.get();
    auto it = alloc_info_.find(buf);
    if (it != alloc_info_.end() && it->second.alloc) {
      CHECK_LT(it->second.level, scope_.size())
          << "Load memory in places other than store.";
      scope_[it->second.level].touched.push_back(buf);
    }
  }
  void Visit_(const Call* op) final {
    if (op->is_intrinsic(intrinsic::tvm_address_of)) {
      const Load* l = op->args[0].as<Load>();
      this->Visit(l->index);
    } else {
      IRVisitor::Visit_(op);
    }
  }
  void Visit_(const Variable* buf) final {
    // A direct reference to the variable counts as a read.
    auto it = alloc_info_.find(buf);
    if (it != alloc_info_.end() && it->second.alloc) {
      CHECK_LT(it->second.level, scope_.size())
          << " buf=" << buf->name_hint;
      scope_[it->second.level].touched.push_back(buf);
    }
  }
  template<typename T>
  void VisitNewScope(const T* op) {
    scope_.push_back(StmtEntry());
    StmtEntry e;
    e.stmt = op;
    int64_t begin_index = static_cast<int64_t>(linear_seq_.size());
    // before scope.
    linear_seq_.push_back(e);
    IRVisitor::Visit_(op);
    // after scope.
    e.touched = std::move(scope_.back().touched);
    scope_.pop_back();
    int64_t end_index = static_cast<int64_t>(linear_seq_.size());
    CHECK_GT(end_index, begin_index);
    e.scope_pair_offset = begin_index - end_index;
    linear_seq_.push_back(e);
    // record the pointer to the end index.
    CHECK_NE(end_index, 0U);
    linear_seq_[begin_index].scope_pair_offset = end_index - begin_index;
  }
  void Visit_(const AttrStmt* op) final {
    // Only record the outermost thread extent.
    if (op->attr_key == attr::thread_extent && !in_thread_env_) {
      in_thread_env_ = true;
      VisitNewScope(op);
      in_thread_env_ = false;
    } else if (op->attr_key == attr::extern_scope) {
      VisitNewScope(op);
    } else if (op->attr_key == attr::virtual_thread) {
      VisitNewScope(op);
    } else if (op->attr_key == attr::storage_scope) {
      const Variable* buf = op->node.as<Variable>();
      alloc_info_[buf].storage_scope =
          StorageScope::make(op->value.as<StringImm>()->value);
      IRVisitor::Visit_(op);
    } else {
      IRVisitor::Visit_(op);
    }
  }
  void Visit_(const IfThenElse* op) final {
    VisitNewScope(op);
  }

  void Visit_(const For* op) final {
    VisitNewScope(op);
  }

  void Visit_(const AssertStmt* op) final {
    VisitNewScope(op);
  }

  // linearized access sequence.
  std::vector<StmtEntry> linear_seq_;
  // The allocation info of each buffer.
  std::unordered_map<const Variable*, AllocEntry> alloc_info_;

 private:
  // Whether already in thread env.
  bool in_thread_env_{false};
  // The scope stack.
  std::vector<StmtEntry> scope_;
};

// Verify if a statement can be executed safely in an inplace fashion.
//
// Detect pattern: dst[index] = f(src[index])
//
// WARNING: the current detection algorithm cannot handle the case
// where a location in an array is written multiple times.
//
// For example, the following program will pass the check,
// but we cannot make A and B the same array.
//
//  A[0] = B[0] + 1
//  A[0] = B[0] + 1
//
// The high-level code generator needs to ensure that the generated
// code writes each location of the target array only once.
//
// This is the case for IR generated by the current compute schedule.
// We explicitly return false if we find an extern block,
// which can contain arbitrary IR.
//
// Nevertheless, the inplace detector should be used with care.
// We may also consider introducing a condition checker that verifies
// that every index is visited only once, as an absolutely sufficient
// condition.
//
// The code after the inplace transformation is no longer idempotent.
//
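// As a rough illustration (hypothetical IR, assuming A and B share the same
// element type), the verifier accepts
//
//   for (i, 0, n) {
//     A[i] = B[i] * 2 + 1
//   }
//
// when checked with dst = A and src = B, because every load of B uses the
// same index and type as the store into A.  It rejects patterns such as
//
//   A[i] = B[i + 1] + B[i]      // load index differs from store index
//   A[i] = A[i] + B[i]          // reads dst itself (reduction)
//   A[B[i]] = B[i]              // indirect addressing through src
//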
class InplaceOpVerifier : public IRVisitor {
 public:
  bool Check(const Node* stmt,
             const Variable* dst,
             const Variable* src) {
    dst_ = dst;
    src_ = src;
    result_ = true;
    if (stmt->IsInstance<AttrStmt>()) {
      Visit_(static_cast<const AttrStmt*>(stmt));
    } else if (stmt->IsInstance<For>()) {
      Visit_(static_cast<const For*>(stmt));
    } else if (stmt->IsInstance<IfThenElse>()) {
      Visit_(static_cast<const IfThenElse*>(stmt));
    } else if (stmt->IsInstance<Store>()) {
      Visit_(static_cast<const Store*>(stmt));
    } else {
      return false;
    }
    return result_;
  }

  using IRVisitor::Visit_;

  void Visit(const NodeRef& e) final {
    if (!result_) return;
    IRVisitor::Visit(e);
  }

  void Visit_(const Variable* op) final {
    // assume all opaque accesses are unsafe
    if (op == dst_ || op == src_) {
      result_ = false; return;
    }
  }

  void Visit_(const Store* op) final {
    ++mem_nest_;
    this->Visit(op->index);
    --mem_nest_;
    if (op->buffer_var.get() == dst_) {
      store_ = op;
      this->Visit(op->value);
      this->Visit(op->predicate);
      store_ = nullptr;
    } else {
      this->Visit(op->value);
      this->Visit(op->predicate);
    }
  }

  void Visit_(const AttrStmt* op) final {
    // always reject extern code
    if (op->attr_key == attr::extern_scope ||
        op->attr_key == attr::volatile_scope) {
      result_ = false; return;
    }
    IRVisitor::Visit_(op);
  }

  void Visit_(const Load* op) final {
    const Variable* buf = op->buffer_var.get();
    // cannot read from dst_ (no reduction)
    if (buf == dst_) {
      result_ = false; return;
    }
    // do not allow indirect memory load
    if (mem_nest_ != 0) {
      result_ = false; return;
    }
    if (src_ == buf) {
      if (store_ == nullptr ||
          store_->value.type() != op->type ||
          !ir::Equal(store_->index, op->index)) {
        result_ = false; return;
      }
    }
    ++mem_nest_;
    IRVisitor::Visit_(op);
    --mem_nest_;
  }


 private:
  // result of the check
  bool result_{true};
  // destination memory
  const Variable* dst_;
  // source variable
  const Variable* src_;
  // Nesting counter of loads;
  // it is not safe to inplace when there is a nested load like A[B[i]].
  int mem_nest_{0};
  // The current store to be inspected
  const Store* store_{nullptr};
};

// Planner to plan and rewrite memory allocation.
class StoragePlanRewriter : public IRMutator {
 public:
  using StmtEntry = LinearAccessPatternFinder::StmtEntry;
  using AllocEntry = LinearAccessPatternFinder::AllocEntry;

  Stmt Rewrite(Stmt stmt, bool detect_inplace) {
    detect_inplace_ = detect_inplace;
    // plan the rewrite
    LinearAccessPatternFinder finder;
    finder.Visit(stmt);
    this->LivenessAnalysis(finder.linear_seq_);
    this->PlanMemory(finder.linear_seq_, finder.alloc_info_);
    this->PrepareNewAlloc();
    // start rewrite
    stmt = this->Mutate(stmt);
    if (attach_map_.count(nullptr)) {
      std::vector<Stmt> nest;
      for (StorageEntry* e : attach_map_.at(nullptr)) {
        // CHECK_EQ(e->scope.rank, 0);
        if (e->new_alloc.defined()) {
          nest.emplace_back(AttrStmt::make(
              e->alloc_var, attr::storage_scope,
              StringImm::make(e->scope.to_string()),
              Evaluate::make(0)));
          nest.push_back(e->new_alloc);
        }
      }
      stmt = MergeNest(nest, stmt);
    }
    return stmt;
  }
  Stmt Mutate_(const Store* op, const Stmt& s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Store>();
    auto it = alloc_map_.find(op->buffer_var.get());
    if (it == alloc_map_.end()) return stmt;
    return Store::make(it->second->alloc_var,
                       op->value,
                       RemapIndex(op->value.type(), op->index, it->second),
                       op->predicate);
  }
  Expr Mutate_(const Load* op, const Expr& e) final {
    Expr expr = IRMutator::Mutate_(op, e);
    op = expr.as<Load>();
    auto it = alloc_map_.find(op->buffer_var.get());
    if (it == alloc_map_.end()) return expr;
    return Load::make(op->type,
                      it->second->alloc_var,
                      RemapIndex(op->type, op->index, it->second),
                      op->predicate);
  }
  Expr Mutate_(const Variable* op, const Expr& e) final {
    auto it = alloc_map_.find(op);
    if (it != alloc_map_.end()) {
      if (it->second->bits_offset != 0) {
        LOG(WARNING) << "Use a merged buffer variable address, could cause error";
      }
      return it->second->alloc_var;
    } else {
      return e;
    }
  }
  Expr Mutate_(const Call* op, const Expr& e) final {
    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
      CHECK_EQ(op->args.size(), 5U);
      Type dtype = op->args[0].type();
      const Variable* buffer = op->args[1].as<Variable>();
      auto it = alloc_map_.find(buffer);
      if (it == alloc_map_.end()) return IRMutator::Mutate_(op, e);
      const StorageEntry* se = it->second;
      Expr offset = Mutate(op->args[2]);
      Expr extent = Mutate(op->args[3]);
      uint64_t elem_bits = dtype.bits() * dtype.lanes();
      CHECK_EQ(se->bits_offset % elem_bits, 0U);
      if (se->bits_offset != 0) {
        offset = make_const(offset.type(), se->bits_offset / elem_bits) + offset;
      }
      return Call::make(
          op->type, op->name,
          {op->args[0], se->alloc_var, offset, extent, op->args[4]},
          op->call_type);
    } else {
      return IRMutator::Mutate_(op, e);
    }
  }

  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    if (op->attr_key == attr::storage_scope) {
      return this->Mutate(op->body);
    } else if (op->attr_key == attr::thread_extent ||
               op->attr_key == attr::virtual_thread ||
               attr::IsPragmaKey(op->attr_key)) {
      // remake all the allocations at the attach scope.
      if (attach_map_.count(op)) {
        auto& svec = attach_map_[op];
        Stmt stmt = IRMutator::Mutate_(op, s);
        op = stmt.as<AttrStmt>();
        return AttrStmt::make(
            op->node, op->attr_key, op->value,
            MakeAttach(svec, op->body));
      } else {
        return IRMutator::Mutate_(op, s);
      }
    } else if (op->attr_key == attr::volatile_scope) {
      Stmt stmt = IRMutator::Mutate_(op, s);
      op = stmt.as<AttrStmt>();
      auto it = alloc_map_.find(op->node.as<Variable>());
      if (it == alloc_map_.end()) return stmt;
      return AttrStmt::make(
          it->second->alloc_var, op->attr_key, op->value, op->body);
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }
  Stmt Mutate_(const For* op, const Stmt& s) final {
    CHECK(op->for_type != ForType::Vectorized)
        << "VectorizeLoop before LiftStorageAlloc";
    // remake all the allocations at the attach scope.
    if (attach_map_.count(op)) {
      auto& svec = attach_map_[op];
      Stmt stmt = IRMutator::Mutate_(op, s);
      op = stmt.as<For>();
      return For::make(
          op->loop_var, op->min, op->extent, op->for_type, op->device_api,
          MakeAttach(svec, op->body));
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }

  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
    return this->Mutate(op->body);
  }

 private:
  struct StorageEntry {
    // The scope that this allocation is attached to.
    // For shared/local memory it is the beginning of the thread extent.
    // For global memory it is nullptr, meaning the beginning of everything.
    const Node* attach_scope_{nullptr};
    // The constant size of the buffer in bits, only used if it is constant.
    uint64_t const_nbits{0};
    // The storage scope.
    StorageScope scope;
    // Allocs that share this entry.
    std::vector<const Allocate*> allocs;
    // The children of this entry, not including itself.
    std::vector<StorageEntry*> merged_children;
    // The replacement allocation, if any.
    Stmt new_alloc;
    // The var expr of the new allocation.
    VarExpr alloc_var;
    // The allocation element type.
    Type elem_type;
    // This is non-zero if this allocate is folded into another one;
    // the address (in bits) becomes alloc_var + bits_offset,
    // which can be effectively converted to the element type.
    // We need to convert bits_offset to an offset of the specific element type later.
    //
    // We use bits (instead of bytes) to support non-conventional indexing in hardware.
    // When we are merging buffers together, bits_offset is set to be aligned
    // to a certain value given by the max_simd_bits property of the special memory.
    //
    // This allows effective sharing among different types as long as their alignment
    // requirement fits into max_simd_bits.
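    //
    // As a small worked example (numbers chosen arbitrarily for illustration):
    // if a tagged memory has max_simd_bits = 128 and the parent entry already
    // occupies 200 bits, a child folded into it gets bits_offset = 256
    // (200 rounded up to the next multiple of 128).  An int32 access into
    // the child is then remapped with an element offset of 256 / 32 = 8.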
    uint64_t bits_offset{0};
  };

  // Event entry in liveness analysis:
  // records the variables generated and killed at each node.
  struct EventEntry {
    // variables we generate
    std::vector<const Variable*> gen;
    // variables we kill
    std::vector<const Variable*> kill;
  };

  Stmt MakeAttach(const std::vector<StorageEntry*>& svec,
                  Stmt body) {
    std::vector<Stmt> nest;
    for (StorageEntry* e : svec) {
      if (e->new_alloc.defined()) {
        nest.emplace_back(AttrStmt::make(
            e->alloc_var, attr::storage_scope,
            StringImm::make(e->scope.to_string()),
            Evaluate::make(0)));
        nest.push_back(e->new_alloc);
      }
    }
    return MergeNest(nest, body);
  }
  // Remap the index with the element offset implied by bits_offset.
  Expr RemapIndex(Type dtype, Expr index, StorageEntry* e) {
    if (e->bits_offset == 0) return index;
    uint64_t elem_bits = dtype.bits() * dtype.lanes();
    CHECK_EQ(e->bits_offset % elem_bits, 0U);
    return make_const(index.type(), e->bits_offset / elem_bits) + index;
  }
  // Prepare the new allocations
  void PrepareNewAlloc() {
    for (size_t i = 0; i < alloc_vec_.size(); ++i) {
      StorageEntry* e = alloc_vec_[i].get();
      attach_map_[e->attach_scope_].push_back(e);
    }
    // find allocations via the attach map.
    for (auto &kv : attach_map_) {
      // find the element with the largest number of bytes.
      std::vector<StorageEntry*>& vec = kv.second;
      // try to find entries to merge, for tagged memory
      for (size_t i = 0; i < vec.size(); ++i) {
        StorageEntry* e = vec[i];
        if (e->scope.tag.length() != 0) {
          CHECK_NE(e->const_nbits, 0U)
              << "Special tagged memory must be const size";
          for (size_t j = 0; j < i; ++j) {
            if (e->scope == vec[j]->scope) {
              vec[j]->merged_children.push_back(e);
              break;
            }
          }
        }
      }
      // Start allocation
      for (size_t i = 0; i < vec.size(); ++i) {
        StorageEntry* e = vec[i];
        // already merged
        if (e->bits_offset != 0) continue;
        if (e->merged_children.size() != 0) {
          NewAllocTagMerged(e); continue;
        }
        // Get the allocation size.
        e->alloc_var = e->allocs[0]->buffer_var;
        Type alloc_type = e->allocs[0]->type;
        for (const Allocate* op : e->allocs) {
          if (op->type.lanes() > alloc_type.lanes()) {
            alloc_type = op->type;
          }
        }
        if (e->allocs.size() == 1) {
          // simply use the original allocation.
          Expr sz = arith::ComputeReduce<Mul>(e->allocs[0]->extents,
                                              make_const(Int(32), 1));
          e->new_alloc = Allocate::make(
              e->alloc_var, alloc_type, {sz},
              e->allocs[0]->condition, Evaluate::make(0));
          if (e->scope.tag.length() != 0) {
            MemoryInfo info = GetMemoryInfo(e->scope.to_string());
            uint64_t total_elem = e->const_nbits / e->elem_type.bits();
            CHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
                << "Allocation exceeds bound of memory tag " << e->scope.to_string();
          }
        } else {
          // Build a merged allocation
          Expr combo_size;
          for (const Allocate* op : e->allocs) {
            Expr sz = arith::ComputeReduce<Mul>(op->extents, make_const(Int(32), 1));
            auto nbits = op->type.bits() * op->type.lanes();
            if (const auto* imm = sz.as<IntImm>()) {
              if (imm->value > std::numeric_limits<int>::max() / nbits) {
                LOG(WARNING) << "The allocation requires: " << imm->value
                             << " * " << nbits
                             << " bits, which is greater than the maximum of"
                                " int32. The size is cast to int64."
                             << "\n";
                sz = make_const(Int(64), imm->value);
              }
            }
            // transform to bits
            auto sz_nbits = sz * nbits;
            if (combo_size.defined()) {
              combo_size = max(combo_size, sz_nbits);
            } else {
              combo_size = sz_nbits;
            }
          }
          // transform to the number of elements of alloc_type
          auto type_bits = alloc_type.bits() * alloc_type.lanes();
          bool divided = analyzer_.CanProve(indexmod(combo_size, type_bits) == 0);
          combo_size = indexdiv(combo_size, type_bits);
          // round up when the size is not exactly divisible
          if (!divided) {
            combo_size = combo_size + make_const(Int(32), 1);
          }
          combo_size = ir::Simplify(combo_size);
          e->new_alloc = Allocate::make(
              e->alloc_var, alloc_type, {combo_size}, const_true(),
              Evaluate::make(0));
          if (e->scope.tag.length() != 0) {
            MemoryInfo info = GetMemoryInfo(e->scope.to_string());
            uint64_t total_elem = e->const_nbits / e->elem_type.bits();
            CHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
                << "Allocation exceeds bound of memory tag " << e->scope.to_string();
          }
        }
      }
    }
  }
  // New allocation for merged data
  void NewAllocTagMerged(StorageEntry* e) {
    CHECK_NE(e->scope.tag.length(), 0U);
    // allocate with element type.
    CHECK_NE(e->const_nbits, 0U);
    MemoryInfo info = GetMemoryInfo(e->scope.to_string());
    uint64_t total_bits = e->const_nbits;
    // By default, align to 32 bits.
    size_t align = 32;
    if (info.defined()) {
      align = info->max_simd_bits;
    }
    // Always align to max_simd_bits
    // so we can remap types by keeping this property
    if (total_bits % align != 0) {
      total_bits += align - (total_bits % align);
    }
    e->alloc_var = e->allocs[0]->buffer_var;
    for (StorageEntry* child : e->merged_children) {
      CHECK_NE(child->const_nbits, 0U);
      CHECK_NE(total_bits, 0U);
      child->bits_offset = total_bits;
      child->alloc_var = e->alloc_var;
      total_bits += child->const_nbits;
      if (total_bits % align != 0) {
        total_bits += align - (total_bits % align);
      }
    }
    uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes();
    Expr alloc_size = make_const(e->allocs[0]->extents[0].type(),
                                 (total_bits + type_bits - 1) / type_bits);
    e->new_alloc = Allocate::make(
        e->alloc_var, e->elem_type, {alloc_size}, const_true(),
        Evaluate::make(0));
    if (info.defined()) {
      CHECK_LE(total_bits, info->max_num_bits)
          << "Allocation exceeds bound of memory tag " << e->scope.to_string();
    }
  }
  // Liveness analysis to find the gen and kill point of each variable.
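  //
  // For instance (a hypothetical sequence of leaf entries, where offsets are 0):
  // if the entries touch buffers in the order [A], [A, B], [B], the forward
  // scan records gen(A) at the first entry and gen(B) at the second, while
  // the reverse scan records kill(B) at the third entry and kill(A) at the
  // second.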
  void LivenessAnalysis(const std::vector<StmtEntry>& seq) {
    // find kill points, do a reverse linear scan.
    std::unordered_set<const Variable*> touched;
    for (size_t i = seq.size(); i != 0; --i) {
      const StmtEntry& s = seq[i - 1];
      for (const Variable* buffer : s.touched) {
        if (!touched.count(buffer)) {
          touched.insert(buffer);
          event_map_[s.stmt].kill.push_back(buffer);
        }
      }
    }
    // find gen points, do a forward scan.
    touched.clear();
    for (size_t i = 0; i < seq.size(); ++i) {
      int64_t offset = seq[i].scope_pair_offset;
      if (offset < 0) continue;
      const StmtEntry& s = seq[i + offset];
      for (const Variable* buffer : s.touched) {
        if (!touched.count(buffer)) {
          touched.insert(buffer);
          event_map_[s.stmt].gen.push_back(buffer);
        }
      }
    }
  }
  void PlanNewScope(const Node* op) {
    if (thread_scope_ != nullptr) {
      CHECK(thread_scope_ == op);
      // erase all memory attached to this scope.
      for (auto it = const_free_map_.begin(); it != const_free_map_.end();) {
        if (it->second->attach_scope_ == op) {
          it = const_free_map_.erase(it);
        } else {
          ++it;
        }
      }
      for (auto it = sym_free_list_.begin(); it != sym_free_list_.end();) {
        if ((*it)->attach_scope_ == op) {
          it = sym_free_list_.erase(it);
        } else {
          ++it;
        }
      }
      thread_scope_ = nullptr;
    } else {
      thread_scope_ = op;
    }
  }

  // Memory plan algorithm
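  //
  // Sketch of the per-entry flow (informal summary of the code below): for
  // every entry in the linearized sequence, buffers generated at that point
  // either reuse the storage of a buffer killed at the same statement
  // (inplace), reuse a compatible entry from the free lists (FindAlloc),
  // or get a fresh StorageEntry (NewAlloc); buffers killed at that point
  // are returned to the free lists via Free().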
  void PlanMemory(const std::vector<StmtEntry>& seq,
                  const std::unordered_map<const Variable*, AllocEntry>& alloc_info) {
    std::unordered_set<const Variable*> inplace_flag;

    for (size_t i = 0; i < seq.size(); ++i) {
      const StmtEntry& s = seq[i];
      auto it = event_map_.find(seq[i].stmt);

      // scope_pair_offset >= 0 means it is either
      // - a leaf stmt (offset = 0)
      // - the beginning of a scope (offset > 0)
      // In both cases, we need to handle the gen event correctly
      if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) {
        // Inplace operation detection; handle this specially.
        bool detect_inplace = detect_inplace_ && (it->second.gen.size() <= 2);

        for (const Variable* var : it->second.gen) {
          CHECK(alloc_info.count(var));
          const AllocEntry& ae = alloc_info.at(var);
          StorageEntry* dst_entry = nullptr;
          // inplace detection
          if (detect_inplace) {
            // only one inplace var for s.stmt
            bool inplace_found = false;
            for (const Variable* src : it->second.kill) {
              if (!inplace_flag.count(src) && alloc_map_.count(src)) {
                InplaceOpVerifier visitor;
                StorageEntry* src_entry = alloc_map_.at(src);
                if (src_entry->scope == ae.storage_scope &&
                    src_entry->attach_scope_ == thread_scope_ &&
                    src_entry->elem_type == ae.alloc->type.element_of() &&
                    visitor.Check(s.stmt, var, src)) {
                  uint64_t const_nbits =
                      static_cast<uint64_t>(ae.alloc->constant_allocation_size()) *
                      ae.alloc->type.bits() *
                      ae.alloc->type.lanes();
                  if (src_entry->const_nbits == const_nbits && !inplace_found) {
                    // successfully inplace
                    dst_entry = src_entry;
                    inplace_flag.insert(src);
                    inplace_found = true;
                  }
                }
              }
            }
          }
          if (dst_entry == nullptr) {
            dst_entry = FindAlloc(ae.alloc, thread_scope_, ae.storage_scope);
          }
          dst_entry->allocs.emplace_back(ae.alloc);
          alloc_map_[var] = dst_entry;
        }
      }
      // enter/exit new scope
      if (s.stmt->IsInstance<AttrStmt>()) {
        const auto* op = static_cast<const AttrStmt*>(s.stmt);
        if (op->attr_key == attr::thread_extent ||
            op->attr_key == attr::virtual_thread ||
            attr::IsPragmaKey(op->attr_key)) {
          PlanNewScope(op);
        } else {
          CHECK(op->attr_key == attr::extern_scope);
        }
      } else if (s.stmt->IsInstance<For>()) {
        const auto* op = static_cast<const For*>(s.stmt);
        if (op->for_type == ForType::Parallel) {
          if (thread_scope_ == nullptr || thread_scope_ == op) {
            PlanNewScope(op);
          }
        }
      }
      // scope_pair_offset <= 0 means it is either
      // - a leaf stmt (offset = 0)
      // - the end of a scope (offset < 0)
      // In both cases, we need to handle the kill event correctly
      if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
        for (const Variable* var : it->second.kill) {
          // skip buffers which have already been replaced by inplace
          if (!inplace_flag.count(var)) {
            this->Free(var);
          }
        }
      }
    }
  }
  // Allocate a new storage entry.
  StorageEntry* NewAlloc(const Allocate* op,
                         const Node* attach_scope,
                         const StorageScope& scope,
                         size_t const_nbits) {
    CHECK(op != nullptr);
    // Re-use not successful, allocate a new buffer.
    std::unique_ptr<StorageEntry> entry(new StorageEntry());
    entry->attach_scope_ = attach_scope;
    entry->scope = scope;
    entry->elem_type = op->type.element_of();
    entry->const_nbits = const_nbits;
    StorageEntry* e = entry.get();
    alloc_vec_.emplace_back(std::move(entry));
    return e;
  }

  StorageEntry* FindAlloc(const Allocate* op,
                          const Node* attach_scope,
                          const StorageScope& scope) {
    CHECK(op != nullptr);
    // skip planning for local variables;
    // the compiler can do a better job with register allocation.
    const uint64_t match_range = 16;
    uint64_t op_elem_bits = op->type.bits() * op->type.lanes();
    uint64_t const_nbits = static_cast<uint64_t>(
        op->constant_allocation_size() * op_elem_bits);
    // disable reuse of small arrays, they will be lowered to registers in LLVM.
    // This rule only applies if we are using non-special memory.
    if (scope.tag.length() == 0) {
      if (scope.rank >= StorageRank::kWarp || op->type.is_handle()) {
        return NewAlloc(op, attach_scope, scope, const_nbits);
      }
      if (const_nbits > 0 && const_nbits <= 32) {
        return NewAlloc(op, attach_scope, scope, const_nbits);
      }
    }
    if (const_nbits != 0) {
      // constant allocation.
      auto begin = const_free_map_.lower_bound(const_nbits / match_range);
      auto mid = const_free_map_.lower_bound(const_nbits);
      auto end = const_free_map_.upper_bound(const_nbits * match_range);
      // start by looking at buffers that are bigger than the required size
      for (auto it = mid; it != end; ++it) {
        StorageEntry *e = it->second;
        if (e->attach_scope_ != attach_scope) continue;
        if (e->scope != scope) continue;
        // when not exactly divisible, no reuse, e.g. float4 vs float3
        if (e->bits_offset % op_elem_bits != 0) continue;
        e->const_nbits = std::max(const_nbits, e->const_nbits);
        const_free_map_.erase(it);
        return e;
      }
      // then start looking at smaller buffers.
      for (auto it = mid; it != begin;) {
        --it;
        StorageEntry *e = it->second;
        if (e->attach_scope_ != attach_scope) continue;
        if (e->scope != scope) continue;
        if (e->elem_type != op->type.element_of()) continue;
        e->const_nbits = std::max(const_nbits, e->const_nbits);
        const_free_map_.erase(it);
        return e;
      }
    } else {
      // Simple strategy: round robin.
      for (auto it = sym_free_list_.begin();
           it != sym_free_list_.end(); ++it) {
        StorageEntry* e = *it;
        if (e->attach_scope_ != attach_scope) continue;
        if (e->scope != scope) continue;
        if (e->elem_type != op->type.element_of()) continue;
        sym_free_list_.erase(it);
        return e;
      }
    }
    return NewAlloc(op, attach_scope, scope, const_nbits);
  }
  // simulated free.
  void Free(const Variable* var) {
    auto it = alloc_map_.find(var);
    CHECK(it != alloc_map_.end());
    StorageEntry* e = it->second;
    CHECK_NE(e->allocs.size(), 0U);

    // disable reuse of small arrays, they will be lowered to registers in LLVM.
    // This rule only applies if we are using non-special memory.
    if (e->scope.tag.length() == 0) {
      // Disable sharing of local memory.
      if (e->scope.rank >= StorageRank::kWarp ||
          e->allocs[0]->type.is_handle()) return;
      // disable reuse of small arrays
      if (e->const_nbits > 0 && e->const_nbits <= 32) return;
    }
    // normal free.
    if (e->const_nbits != 0) {
      const_free_map_.insert({e->const_nbits, e});
    } else {
      sym_free_list_.push_back(e);
    }
  }
  // thread scope.
  const Node* thread_scope_{nullptr};
  // whether to enable inplace detection.
  bool detect_inplace_{false};
  // Locations of free ops.
  std::unordered_map<const Node*, EventEntry> event_map_;
  // constant size free map.
  std::multimap<uint64_t, StorageEntry*> const_free_map_;
  // symbolic free list, for non-constant items.
  std::list<StorageEntry*> sym_free_list_;
  // The allocation attach map
  std::unordered_map<const Node*, std::vector<StorageEntry*> > attach_map_;
  // The allocation assignment map
  std::unordered_map<const Variable*, StorageEntry*> alloc_map_;
  // The allocations
  std::vector<std::unique_ptr<StorageEntry> > alloc_vec_;
  // analyzer
  arith::Analyzer analyzer_;
};

// Turn an alloc into a vector alloc
// if all its accesses are of the same vector type.
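//
// For example (an illustrative sketch, not tied to a specific workload):
// an allocation
//
//   allocate A[float32 * 64]
//
// whose only loads/stores are of type float32x4 is rewritten to
//
//   allocate A[float32x4 * 16]
//
// provided the last extent is provably divisible by the lane factor.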
class VectorAllocRewriter : public IRMutator {
 public:
  Expr Mutate_(const Load* op, const Expr& e) final {
    UpdateTypeMap(op->buffer_var.get(), op->type);
    return IRMutator::Mutate_(op, e);
  }

  Stmt Mutate_(const Store* op, const Stmt& s) final {
    UpdateTypeMap(op->buffer_var.get(), op->value.type());
    return IRMutator::Mutate_(op, s);
  }
  Expr Mutate_(const Call* op, const Expr& e) final {
    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
      Type dtype = op->args[0].type();
      const Variable* buffer = op->args[1].as<Variable>();
      UpdateTypeMap(buffer, dtype);
    }
    return IRMutator::Mutate_(op, e);
  }

  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Allocate>();
    const auto& tvec = acc_map_[op->buffer_var.get()];

    if (tvec.size() == 1 &&
        tvec[0].element_of() == op->type.element_of() &&
        tvec[0].lanes() % op->type.lanes() == 0 &&
        tvec[0].lanes() != op->type.lanes()) {
      int factor = tvec[0].lanes() / op->type.lanes();
      Array<Expr> extents = op->extents;
      arith::ModularSet me = analyzer_.modular_set(extents[extents.size() - 1]);
      if (me->base % factor == 0 && me->coeff % factor == 0) {
        extents.Set(extents.size() - 1,
                    extents[extents.size() - 1] / make_const(extents[0].type(), factor));
        return Allocate::make(
            op->buffer_var, tvec[0], extents,
            op->condition, op->body);
      }
    }
    return stmt;
  }

  void UpdateTypeMap(const Variable* buffer, Type t) {
    auto& tvec = acc_map_[buffer];
    if (std::find(tvec.begin(), tvec.end(), t) == tvec.end()) {
      tvec.push_back(t);
    }
  }

  // Internal access map
  std::unordered_map<const Variable*, std::vector<Type> > acc_map_;
  // internal analyzer
  arith::Analyzer analyzer_;
};


LoweredFunc PointerValueTypeRewrite(LoweredFunc f) {
  auto n = make_node<LoweredFuncNode>(*f.operator->());
  VectorAllocRewriter rewriter;
  n->body = rewriter.Mutate(n->body);
  for (Var arg : f->args) {
    if (arg.type().is_handle()) {
      const auto& tvec = rewriter.acc_map_[arg.get()];
      if (tvec.size() == 1) {
        Expr dtype = make_const(tvec[0], 0);
        n->handle_data_type.Set(arg, dtype);
      } else {
        // always set the data type to be non-vectorized so
        // load/store can still work via scalarization
        if (tvec.size() != 0 && !n->handle_data_type.count(arg)) {
          Expr dtype = make_const(tvec[0].with_lanes(1), 0);
          n->handle_data_type.Set(arg, dtype);
        }
      }
    }
  }
  return LoweredFunc(n);
}

Stmt StorageRewrite(Stmt stmt) {
  stmt = StoragePlanRewriter().Rewrite(stmt, true);
  return VectorAllocRewriter().Mutate(stmt);
}
}  // namespace ir
}  // namespace tvm