1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file inject_virtual_thread.cc
22  */
23 #include <tvm/ir.h>
24 #include <tvm/ir_visitor.h>
25 #include <tvm/ir_mutator.h>
26 #include <tvm/ir_pass.h>
27 #include <unordered_set>
28 #include "../arithmetic/compute_expr.h"
29 
30 namespace tvm {
31 namespace ir {
32 
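// Checks whether an expression reads any variable from a given "touched" set,
// and records the variables it reads and the buffers it writes.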
34 class ExprTouched final : public IRVisitor {
35  public:
ExprTouched(const std::unordered_set<const Variable * > & touched,bool check_write)36   explicit ExprTouched(const std::unordered_set<const Variable*> &touched,
37                        bool check_write)
38       : touched_var_(touched), check_write_(check_write) {}
39 
Visit(const NodeRef & n)40   void Visit(const NodeRef& n) final {
41     // early stopping
42     if (expr_touched_ && !check_write_) return;
43     IRVisitor::Visit(n);
44   }
Visit_(const Load * op)45   void Visit_(const Load *op) final {
46     HandleUseVar(op->buffer_var.get());
47     IRVisitor::Visit_(op);
48   }
Visit_(const Variable * op)49   void Visit_(const Variable *op) final {
50     HandleUseVar(op);
51   }
Visit_(const Call * op)52   void Visit_(const Call *op) final {
53     if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
54       int rw_mask = 0;
55       CHECK(arith::GetConstInt(op->args[4], &rw_mask));
56       const Variable* buffer_var = op->args[1].as<Variable>();
57       CHECK(buffer_var);
58       // read
59       if (rw_mask & 1) {
60         HandleUseVar(buffer_var);
61       }
62       if (rw_mask & 2) {
63         HandleWriteVar(buffer_var);
64       }
65       this->Visit(op->args[2]);
66     } else {
67       IRVisitor::Visit_(op);
68     }
69   }
HandleUseVar(const Variable * var)70   void HandleUseVar(const Variable* var) {
71     auto it = touched_var_.find(var);
72     if (it != touched_var_.end()) {
73       expr_touched_ = true;
74     }
75     // rember the used vars
76     // in case the var get touched later in a loop.
77     if (!expr_touched_) {
78       used_vars_.push_back(var);
79     }
80   }
HandleWriteVar(const Variable * var)81   void HandleWriteVar(const Variable* var) {
82     write_vars_.push_back(var);
83   }
84   // the fields.
85   bool expr_touched_{false};
86   std::vector<const Variable*> used_vars_;
87   std::vector<const Variable*> write_vars_;
88   const std::unordered_set<const Variable*>& touched_var_;
89   bool check_write_;
90 };
91 
// Analyze whether variables and buffers are invariant to the value of var.
93 class VarTouchedAnalysis : public IRVisitor {
94  public:
Visit_(const LetStmt * op)95   void Visit_(const LetStmt *op) {
96     ExprTouched tc(touched_var_, false);
97     tc.Visit(op->value);
98     Record(op->var.get(), tc);
99     this->Visit(op->body);
100   }
Visit_(const Store * op)101   void Visit_(const Store *op) {
102     ExprTouched tc(touched_var_, false);
103     tc.Visit(op->value);
104     tc.Visit(op->index);
105     Record(op->buffer_var.get(), tc);
106   }
Visit_(const For * op)107   void Visit_(const For *op) {
108     ExprTouched tc(touched_var_, false);
109     tc.Visit(op->min);
110     tc.Visit(op->extent);
111     Record(op->loop_var.get(), tc);
112     this->Visit(op->body);
113   }
114   // external function call
Visit_(const Evaluate * op)115   void Visit_(const Evaluate *op) {
116     ExprTouched tc(touched_var_, true);
117     tc.Visit(op->value);
118     for (const Variable* var : tc.write_vars_) {
119       Record(var, tc);
120     }
121   }
Visit_(const Allocate * op)122   void Visit_(const Allocate *op) {
123     ExprTouched tc(touched_var_, false);
124     for (size_t i = 0; i < op->extents.size(); ++i) {
125       tc.Visit(op->extents[i]);
126     }
127     tc.Visit(op->condition);
128     if (op->new_expr.defined()) {
129       tc.Visit(op->new_expr);
130     }
131     Record(op->buffer_var.get(), tc);
132     this->Visit(op->body);
133   }
Record(const Variable * var,const ExprTouched & tc)134   void Record(const Variable* var,
135               const ExprTouched& tc) {
136     if (touched_var_.count(var)) return;
137     if (tc.expr_touched_) {
138       touched_var_.insert(var);
139     } else {
140       for (const Variable* r : tc.used_vars_) {
141         if (r != var) {
142           affect_[r].push_back(var);
143         }
144       }
145     }
146   }
147 
148   std::unordered_set<const Variable*>
TouchedVar(const Stmt & stmt,const Variable * var)149   TouchedVar(const Stmt& stmt,
150              const Variable* var) {
151     touched_var_.insert(var);
152     this->Visit(stmt);
153     // do a DFS to push affect around dependency.
154     std::vector<const Variable*> pending(
155         touched_var_.begin(), touched_var_.end());
156     while (!pending.empty()) {
157       const Variable* v = pending.back();
158       pending.pop_back();
159       for (const Variable* r : affect_[v]) {
160         if (!touched_var_.count(r)) {
161           touched_var_.insert(r);
162           pending.push_back(r);
163         }
164       }
165     }
166     return std::move(touched_var_);
167   }
168 
169  private:
170   // Whether variable is touched by the thread variable.
171   std::unordered_set<const Variable*> touched_var_;
172   // x -> all the buffers x read from
173   std::unordered_map<const Variable*,
174                      std::vector<const Variable*> > affect_;
175 };
176 
177 
178 // Inject virtual thread loop
179 // rewrite the buffer access pattern when necessary.
180 class VTInjector : public IRMutator {
181  public:
182   using IRMutator::Mutate;
183   // constructor
VTInjector(Var var,int num_threads,const std::unordered_set<const Variable * > & touched_var,bool allow_share)184   VTInjector(Var var,
185              int num_threads,
186              const std::unordered_set<const Variable*>& touched_var,
187              bool allow_share)
188       : var_(var), num_threads_(num_threads),
189         touched_var_(touched_var), allow_share_(allow_share) {
190   }
191   // Inject VTLoop when needed.
Mutate(Stmt stmt)192   Stmt Mutate(Stmt stmt) final {
193     CHECK(!visit_touched_var_);
194     stmt = IRMutator::Mutate(stmt);
195     if (visit_touched_var_ || trigger_base_inject_) {
196       if (!vt_loop_injected_)  {
197         return InjectVTLoop(stmt, false);
198       }
199       visit_touched_var_ = false;
200       trigger_base_inject_ = false;
201     }
202     return stmt;
203   }
204   // Variable
Mutate_(const Variable * op,const Expr & e)205   Expr Mutate_(const Variable *op, const Expr& e) final {
206     CHECK(!alloc_remap_.count(op))
207         << "Buffer address may get rewritten in virtual thread";
208     if (touched_var_.count(op)) {
209       visit_touched_var_ = true;
210     }
211     return e;
212   }
RewriteIndex(Expr index,Expr alloc_extent) const213   Expr RewriteIndex(Expr index, Expr alloc_extent) const {
214     return index + var_ * alloc_extent;
215   }
216   // Load
Mutate_(const Load * op,const Expr & e)217   Expr Mutate_(const Load* op, const Expr& e) final {
218     Expr expr = IRMutator::Mutate_(op, e);
219     op = expr.as<Load>();
220     if (touched_var_.count(op->buffer_var.get())) {
221       visit_touched_var_ = true;
222     }
223     auto it = alloc_remap_.find(op->buffer_var.get());
224     if (it != alloc_remap_.end()) {
225       return Load::make(op->type, op->buffer_var,
226                         RewriteIndex(op->index, it->second),
227                         op->predicate);
228     } else {
229       return expr;
230     }
231   }
232   // Expression.
Mutate_(const Call * op,const Expr & e)233   Expr Mutate_(const Call* op, const Expr& e) final {
234     if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
235       CHECK_EQ(op->args.size(), 5U);
236       Type dtype = op->args[0].type();
237       const Variable* buffer = op->args[1].as<Variable>();
238       auto it = alloc_remap_.find(buffer);
239       if (it == alloc_remap_.end()) return IRMutator::Mutate_(op, e);
240       visit_touched_var_ = true;
241       Expr offset = Mutate(op->args[2]);
242       Expr extent = Mutate(op->args[3]);
243       Expr stride =
244           it->second / make_const(offset.type(), dtype.lanes());
245       offset = stride * var_ + offset;
246       return Call::make(
247           op->type, op->name,
248           {op->args[0], op->args[1], offset, extent, op->args[4]},
249           op->call_type);
250     } else if (op->is_intrinsic(intrinsic::tvm_context_id)) {
251       return allow_share_ ? e : var_;
252     } else {
253       return IRMutator::Mutate_(op, e);
254     }
255   }
Mutate_(const Evaluate * op,const Stmt & s)256   Stmt Mutate_(const Evaluate* op, const Stmt& s) final {
257     trigger_base_inject_ = !allow_share_;
258     return IRMutator::Mutate_(op, s);
259   }
260   // Store
Mutate_(const Store * op,const Stmt & s)261   Stmt Mutate_(const Store* op, const Stmt& s) final {
262     Stmt stmt = IRMutator::Mutate_(op, s);
263     op = stmt.as<Store>();
264     if (touched_var_.count(op->buffer_var.get())) {
265       visit_touched_var_ = true;
266     }
267     trigger_base_inject_ = !allow_share_;
268     auto it = alloc_remap_.find(op->buffer_var.get());
269     if (it != alloc_remap_.end()) {
270       return Store::make(op->buffer_var,
271                          op->value,
272                          RewriteIndex(op->index, it->second),
273                          op->predicate);
274     } else {
275       return stmt;
276     }
277   }
278   // Attribute
Mutate_(const AttrStmt * op,const Stmt & s)279   Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
280     Expr value = Mutate(op->value);
281     if (visit_touched_var_ && !vt_loop_injected_) {
282       return InjectVTLoop(s, true);
283     } else if (!allow_share_ && !vt_loop_injected_ &&
284                (op->attr_key == attr::coproc_uop_scope ||
285                 op->attr_key == attr::coproc_scope)) {
286       return InjectVTLoop(s, true);
287     } else {
288       Stmt body = Mutate(op->body);
289       if (value.same_as(op->value) &&
290           body.same_as(op->body)) {
291         return s;
292       } else {
293         return AttrStmt::make(op->node, op->attr_key, value, body);
294       }
295     }
296   }
297   // LetStmt
Mutate_(const LetStmt * op,const Stmt & s)298   Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
299     Expr value = this->Mutate(op->value);
300     if (visit_touched_var_ && !vt_loop_injected_) {
301       return InjectVTLoop(s, true);
302     }
303     visit_touched_var_ = false;
304     Stmt body = Mutate(op->body);
305     if (value.same_as(op->value) &&
306         body.same_as(op->body)) {
307       return s;
308     } else {
309       return LetStmt::make(op->var, value, body);
310     }
311   }
312   // For
Mutate_(const For * op,const Stmt & s)313   Stmt Mutate_(const For* op, const Stmt& s) final {
314     CHECK(is_zero(op->min));
315     Expr extent = Mutate(op->extent);
316     if (visit_touched_var_ && !vt_loop_injected_) {
317       Stmt stmt = InjectVTLoop(s, true);
318       ++max_loop_depth_;
319       return stmt;
320     }
321     visit_touched_var_ = false;
322     Stmt body = Mutate(op->body);
323     ++max_loop_depth_;
324     if (extent.same_as(op->extent) &&
325         body.same_as(op->body)) {
326       return s;
327     } else {
328       return For::make(
329           op->loop_var, op->min, extent, op->for_type, op->device_api, body);
330     }
331   }
332   // IfThenElse
Mutate_(const IfThenElse * op,const Stmt & s)333   Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
334     Expr condition = this->Mutate(op->condition);
335     if (visit_touched_var_ && !vt_loop_injected_) {
336       return InjectVTLoop(s, true);
337     }
338     visit_touched_var_ = false;
339     CHECK_EQ(max_loop_depth_, 0);
340     Stmt then_case = this->Mutate(op->then_case);
341     Stmt else_case;
342     if (op->else_case.defined()) {
343       int temp = max_loop_depth_;
344       max_loop_depth_ = 0;
345       else_case = this->Mutate(op->else_case);
346       max_loop_depth_ = std::max(temp, max_loop_depth_);
347     }
348     if (condition.same_as(op->condition) &&
349         then_case.same_as(op->then_case) &&
350         else_case.same_as(op->else_case)) {
351       return s;
352     } else {
353       return IfThenElse::make(condition, then_case, else_case);
354     }
355   }
356   // Block
Mutate_(const Block * op,const Stmt & s)357   Stmt Mutate_(const Block* op, const Stmt& s) final {
358     CHECK_EQ(max_loop_depth_, 0);
359     Stmt first = this->Mutate(op->first);
360     int temp = max_loop_depth_;
361     max_loop_depth_ = 0;
362     Stmt rest = this->Mutate(op->rest);
363     max_loop_depth_ = std::max(max_loop_depth_, temp);
364     if (first.same_as(op->first) &&
365         rest.same_as(op->rest)) {
366       return s;
367     } else {
368       return Block::make(first, rest);
369     }
370   }
371   // Allocate
Mutate_(const Allocate * op,const Stmt & s)372   Stmt Mutate_(const Allocate* op, const Stmt& s) final {
373     if (op->new_expr.defined() && !vt_loop_injected_) {
374       return InjectVTLoop(s, true);
375     }
376     Expr condition = Mutate(op->condition);
377     if (visit_touched_var_ && !vt_loop_injected_) {
378       return InjectVTLoop(s, true);
379     }
380 
381     bool changed = false;
382     Array<Expr> extents;
383     for (size_t i = 0; i < op->extents.size(); i++) {
384       Expr new_ext = Mutate(op->extents[i]);
385       if (visit_touched_var_ && !vt_loop_injected_) {
386         return InjectVTLoop(s, true);
387       }
388       if (!new_ext.same_as(op->extents[i])) changed = true;
389       extents.push_back(new_ext);
390     }
391     visit_touched_var_ = false;
392 
393     Stmt body;
394     // always rewrite if not allow sharing.
395     if (touched_var_.count(op->buffer_var.get()) || !allow_share_) {
396       // place v on highest dimension.
397       Expr stride = arith::ComputeReduce<Mul>(
398           op->extents, Expr()) * op->type.lanes();
399       Array<Expr> other;
400       other.push_back(make_const(op->extents[0].type(), num_threads_));
401       for (Expr e : extents) {
402         other.push_back(e);
403       }
404       extents = other;
405       changed = true;
406       // mark this buffer get touched.
407       alloc_remap_[op->buffer_var.get()] = stride;
408       // Mutate the body.
409       body = Mutate(op->body);
410     } else {
411       // Mutate the body.
412       body = Mutate(op->body);
413     }
414     if (!changed &&
415         body.same_as(op->body) &&
416         condition.same_as(op->condition)) {
417       return s;
418     } else {
419       return Allocate::make(
420           op->buffer_var, op->type,
421           extents, condition, body,
422           op->new_expr, op->free_function);
423     }
424   }
425 
426   // inject vthread loop
InjectVTLoop(Stmt stmt,bool before_mutation)427   Stmt InjectVTLoop(Stmt stmt, bool before_mutation) {
428     CHECK(!vt_loop_injected_);
429     // reset the flags
430     visit_touched_var_ = false;
431     trigger_base_inject_ = false;
432     vt_loop_injected_ = true;
433     if (before_mutation) {
434       stmt = this->Mutate(stmt);
435     }
436     // reset the flags after processing.
437     vt_loop_injected_ = false;
438     visit_touched_var_ = false;
439     // only unroll if number of vthreads are small
440     if (max_loop_depth_ == 0 && num_threads_ < 16) {
441       // do unrolling if it is inside innermost content.
442       Stmt blk = Substitute(stmt, {{var_, make_zero(var_.type())}});
443       for (int i = 1; i < num_threads_; ++i) {
444         blk = Block::make(
445             blk, Substitute(stmt, {{var_, make_const(var_.type(), i)}}));
446       }
447       return blk;
448     } else {
449       // insert a for loop
450       Var idx(var_->name_hint + ".s", var_->type);
451       Map<Var, Expr> values{{var_, idx}};
452       stmt = Substitute(stmt, values);
453       return For::make(idx, make_zero(idx.type()),
454                        make_const(idx.type(), num_threads_),
455                        ForType::Serial, DeviceAPI::None, stmt);
456     }
457   }
458 
459  private:
460   // vthread variable
461   Var var_;
462   // the threads/lanes
463   int num_threads_;
464   // whethe the loop is already injected.
465   bool vt_loop_injected_{false};
466   // whether current expression get touched.
467   bool visit_touched_var_{false};
468   // Trigger base stmt
469   bool trigger_base_inject_{false};
470   // the counter of loops in after mutation.
471   int max_loop_depth_{0};
472   // The variables that get touched.
473   const std::unordered_set<const Variable*>& touched_var_;
474   // Whether allow shareding.
475   bool allow_share_;
476   // The allocations that get touched -> extent
477   std::unordered_map<const Variable*, Expr> alloc_remap_;
478 };
479 
480 
481 class VirtualThreadInjector : public IRMutator {
482  public:
Mutate_(const AttrStmt * op,const Stmt & s)483   Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
484     Stmt stmt = IRMutator::Mutate_(op, s);
485     op = stmt.as<AttrStmt>();
486     if (op->attr_key == attr::virtual_thread) {
487       IterVar iv = Downcast<IterVar>(op->node);
488       bool allow_share = iv->thread_tag == "vthread";
489       int nthread = static_cast<int>(op->value.as<IntImm>()->value);
490       VarTouchedAnalysis vs;
491       auto touched = vs.TouchedVar(op->body, iv->var.get());
492       VTInjector injecter(iv->var, nthread, touched, allow_share);
493       return injecter.Mutate(op->body);
494     } else {
495       return stmt;
496     }
497   }
498 
Mutate_(const Provide * op,const Stmt & s)499   Stmt Mutate_(const Provide* op, const Stmt& s) final {
500     LOG(FATAL) << "Need to call StorageFlatten first";
501     return s;
502   }
503 };
504 
InjectVirtualThread(Stmt stmt)505 Stmt InjectVirtualThread(Stmt stmt) {
506   stmt = VirtualThreadInjector().Mutate(stmt);
507   return ConvertSSA(stmt);
508 }
509 
510 }  // namespace ir
511 }  // namespace tvm
512