1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file inject_virtual_thread.cc
 * \brief Lower the virtual thread annotation: rewrite buffer accesses
 *        and either unroll the body per thread or wrap it in a loop.
 */
23 #include <tvm/ir.h>
24 #include <tvm/ir_visitor.h>
25 #include <tvm/ir_mutator.h>
26 #include <tvm/ir_pass.h>
27 #include <unordered_set>
28 #include "../arithmetic/compute_expr.h"
29
30 namespace tvm {
31 namespace ir {
32
33 // If expression is touched by var.
34 class ExprTouched final : public IRVisitor {
35 public:
ExprTouched(const std::unordered_set<const Variable * > & touched,bool check_write)36 explicit ExprTouched(const std::unordered_set<const Variable*> &touched,
37 bool check_write)
38 : touched_var_(touched), check_write_(check_write) {}
39
Visit(const NodeRef & n)40 void Visit(const NodeRef& n) final {
41 // early stopping
42 if (expr_touched_ && !check_write_) return;
43 IRVisitor::Visit(n);
44 }
Visit_(const Load * op)45 void Visit_(const Load *op) final {
46 HandleUseVar(op->buffer_var.get());
47 IRVisitor::Visit_(op);
48 }
Visit_(const Variable * op)49 void Visit_(const Variable *op) final {
50 HandleUseVar(op);
51 }
Visit_(const Call * op)52 void Visit_(const Call *op) final {
53 if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
54 int rw_mask = 0;
55 CHECK(arith::GetConstInt(op->args[4], &rw_mask));
56 const Variable* buffer_var = op->args[1].as<Variable>();
57 CHECK(buffer_var);
58 // read
59 if (rw_mask & 1) {
60 HandleUseVar(buffer_var);
61 }
62 if (rw_mask & 2) {
63 HandleWriteVar(buffer_var);
64 }
65 this->Visit(op->args[2]);
66 } else {
67 IRVisitor::Visit_(op);
68 }
69 }
HandleUseVar(const Variable * var)70 void HandleUseVar(const Variable* var) {
71 auto it = touched_var_.find(var);
72 if (it != touched_var_.end()) {
73 expr_touched_ = true;
74 }
75 // rember the used vars
76 // in case the var get touched later in a loop.
77 if (!expr_touched_) {
78 used_vars_.push_back(var);
79 }
80 }
HandleWriteVar(const Variable * var)81 void HandleWriteVar(const Variable* var) {
82 write_vars_.push_back(var);
83 }
84 // the fields.
85 bool expr_touched_{false};
86 std::vector<const Variable*> used_vars_;
87 std::vector<const Variable*> write_vars_;
88 const std::unordered_set<const Variable*>& touched_var_;
89 bool check_write_;
90 };
91
92 // Analyze if the buffers are invariant to value of var
93 class VarTouchedAnalysis : public IRVisitor {
94 public:
Visit_(const LetStmt * op)95 void Visit_(const LetStmt *op) {
96 ExprTouched tc(touched_var_, false);
97 tc.Visit(op->value);
98 Record(op->var.get(), tc);
99 this->Visit(op->body);
100 }
Visit_(const Store * op)101 void Visit_(const Store *op) {
102 ExprTouched tc(touched_var_, false);
103 tc.Visit(op->value);
104 tc.Visit(op->index);
105 Record(op->buffer_var.get(), tc);
106 }
Visit_(const For * op)107 void Visit_(const For *op) {
108 ExprTouched tc(touched_var_, false);
109 tc.Visit(op->min);
110 tc.Visit(op->extent);
111 Record(op->loop_var.get(), tc);
112 this->Visit(op->body);
113 }
114 // external function call
Visit_(const Evaluate * op)115 void Visit_(const Evaluate *op) {
116 ExprTouched tc(touched_var_, true);
117 tc.Visit(op->value);
118 for (const Variable* var : tc.write_vars_) {
119 Record(var, tc);
120 }
121 }
Visit_(const Allocate * op)122 void Visit_(const Allocate *op) {
123 ExprTouched tc(touched_var_, false);
124 for (size_t i = 0; i < op->extents.size(); ++i) {
125 tc.Visit(op->extents[i]);
126 }
127 tc.Visit(op->condition);
128 if (op->new_expr.defined()) {
129 tc.Visit(op->new_expr);
130 }
131 Record(op->buffer_var.get(), tc);
132 this->Visit(op->body);
133 }
Record(const Variable * var,const ExprTouched & tc)134 void Record(const Variable* var,
135 const ExprTouched& tc) {
136 if (touched_var_.count(var)) return;
137 if (tc.expr_touched_) {
138 touched_var_.insert(var);
139 } else {
140 for (const Variable* r : tc.used_vars_) {
141 if (r != var) {
142 affect_[r].push_back(var);
143 }
144 }
145 }
146 }
147
148 std::unordered_set<const Variable*>
TouchedVar(const Stmt & stmt,const Variable * var)149 TouchedVar(const Stmt& stmt,
150 const Variable* var) {
151 touched_var_.insert(var);
152 this->Visit(stmt);
153 // do a DFS to push affect around dependency.
154 std::vector<const Variable*> pending(
155 touched_var_.begin(), touched_var_.end());
156 while (!pending.empty()) {
157 const Variable* v = pending.back();
158 pending.pop_back();
159 for (const Variable* r : affect_[v]) {
160 if (!touched_var_.count(r)) {
161 touched_var_.insert(r);
162 pending.push_back(r);
163 }
164 }
165 }
166 return std::move(touched_var_);
167 }
168
169 private:
170 // Whether variable is touched by the thread variable.
171 std::unordered_set<const Variable*> touched_var_;
172 // x -> all the buffers x read from
173 std::unordered_map<const Variable*,
174 std::vector<const Variable*> > affect_;
175 };
176
177
// Inject virtual thread loop
// rewrite the buffer access pattern when necessary.
class VTInjector : public IRMutator {
 public:
  using IRMutator::Mutate;
  // constructor
  // var: the virtual thread variable to be eliminated.
  // num_threads: the number of virtual threads.
  // touched_var: variables whose value depends on var
  //   (computed by VarTouchedAnalysis).
  // allow_share: whether read-only data may be shared across the
  //   virtual threads (true for the "vthread" tag).
  VTInjector(Var var,
             int num_threads,
             const std::unordered_set<const Variable*>& touched_var,
             bool allow_share)
      : var_(var), num_threads_(num_threads),
        touched_var_(touched_var), allow_share_(allow_share) {
  }
  // Inject VTLoop when needed.
  // A statement that touched var (or must run once per thread) and has
  // not yet been wrapped gets the virtual-thread loop injected around it.
  Stmt Mutate(Stmt stmt) final {
    CHECK(!visit_touched_var_);
    stmt = IRMutator::Mutate(stmt);
    if (visit_touched_var_ || trigger_base_inject_) {
      if (!vt_loop_injected_) {
        return InjectVTLoop(stmt, false);
      }
      // already inside an injected loop: just reset the flags.
      visit_touched_var_ = false;
      trigger_base_inject_ = false;
    }
    return stmt;
  }
  // Variable
  Expr Mutate_(const Variable *op, const Expr& e) final {
    CHECK(!alloc_remap_.count(op))
        << "Buffer address may get rewritten in virtual thread";
    if (touched_var_.count(op)) {
      // mark that the enclosing statement depends on the thread var.
      visit_touched_var_ = true;
    }
    return e;
  }
  // Shift the index into the thread-local slice of an expanded buffer:
  // thread i owns the region [i * alloc_extent, (i + 1) * alloc_extent).
  Expr RewriteIndex(Expr index, Expr alloc_extent) const {
    return index + var_ * alloc_extent;
  }
  // Load
  Expr Mutate_(const Load* op, const Expr& e) final {
    Expr expr = IRMutator::Mutate_(op, e);
    op = expr.as<Load>();
    if (touched_var_.count(op->buffer_var.get())) {
      visit_touched_var_ = true;
    }
    auto it = alloc_remap_.find(op->buffer_var.get());
    if (it != alloc_remap_.end()) {
      // buffer was expanded per-thread: shift the load index.
      return Load::make(op->type, op->buffer_var,
                        RewriteIndex(op->index, it->second),
                        op->predicate);
    } else {
      return expr;
    }
  }
  // Expression.
  Expr Mutate_(const Call* op, const Expr& e) final {
    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
      // args: {dtype, buffer_var, offset, extent, rw_mask}
      CHECK_EQ(op->args.size(), 5U);
      Type dtype = op->args[0].type();
      const Variable* buffer = op->args[1].as<Variable>();
      auto it = alloc_remap_.find(buffer);
      if (it == alloc_remap_.end()) return IRMutator::Mutate_(op, e);
      visit_touched_var_ = true;
      Expr offset = Mutate(op->args[2]);
      Expr extent = Mutate(op->args[3]);
      // per-thread slice stride; the remapped extent counts scalar
      // elements, so divide by the lanes of the access dtype.
      Expr stride =
          it->second / make_const(offset.type(), dtype.lanes());
      offset = stride * var_ + offset;
      return Call::make(
          op->type, op->name,
          {op->args[0], op->args[1], offset, extent, op->args[4]},
          op->call_type);
    } else if (op->is_intrinsic(intrinsic::tvm_context_id)) {
      // without sharing each virtual thread gets its own context id.
      return allow_share_ ? e : var_;
    } else {
      return IRMutator::Mutate_(op, e);
    }
  }
  Stmt Mutate_(const Evaluate* op, const Stmt& s) final {
    // side-effecting calls must execute once per thread when not sharing.
    trigger_base_inject_ = !allow_share_;
    return IRMutator::Mutate_(op, s);
  }
  // Store
  Stmt Mutate_(const Store* op, const Stmt& s) final {
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Store>();
    if (touched_var_.count(op->buffer_var.get())) {
      visit_touched_var_ = true;
    }
    // stores must be replicated per thread when sharing is not allowed.
    trigger_base_inject_ = !allow_share_;
    auto it = alloc_remap_.find(op->buffer_var.get());
    if (it != alloc_remap_.end()) {
      return Store::make(op->buffer_var,
                         op->value,
                         RewriteIndex(op->index, it->second),
                         op->predicate);
    } else {
      return stmt;
    }
  }
  // Attribute
  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    Expr value = Mutate(op->value);
    if (visit_touched_var_ && !vt_loop_injected_) {
      return InjectVTLoop(s, true);
    } else if (!allow_share_ && !vt_loop_injected_ &&
               (op->attr_key == attr::coproc_uop_scope ||
                op->attr_key == attr::coproc_scope)) {
      // coprocessor scopes must be replicated per thread.
      return InjectVTLoop(s, true);
    } else {
      Stmt body = Mutate(op->body);
      if (value.same_as(op->value) &&
          body.same_as(op->body)) {
        return s;
      } else {
        return AttrStmt::make(op->node, op->attr_key, value, body);
      }
    }
  }
  // LetStmt
  Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
    Expr value = this->Mutate(op->value);
    if (visit_touched_var_ && !vt_loop_injected_) {
      // bound value depends on the thread var: wrap the whole let.
      return InjectVTLoop(s, true);
    }
    visit_touched_var_ = false;
    Stmt body = Mutate(op->body);
    if (value.same_as(op->value) &&
        body.same_as(op->body)) {
      return s;
    } else {
      return LetStmt::make(op->var, value, body);
    }
  }
  // For
  Stmt Mutate_(const For* op, const Stmt& s) final {
    CHECK(is_zero(op->min));
    Expr extent = Mutate(op->extent);
    if (visit_touched_var_ && !vt_loop_injected_) {
      Stmt stmt = InjectVTLoop(s, true);
      ++max_loop_depth_;
      return stmt;
    }
    visit_touched_var_ = false;
    Stmt body = Mutate(op->body);
    ++max_loop_depth_;
    if (extent.same_as(op->extent) &&
        body.same_as(op->body)) {
      return s;
    } else {
      return For::make(
          op->loop_var, op->min, extent, op->for_type, op->device_api, body);
    }
  }
  // IfThenElse
  Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
    Expr condition = this->Mutate(op->condition);
    if (visit_touched_var_ && !vt_loop_injected_) {
      return InjectVTLoop(s, true);
    }
    visit_touched_var_ = false;
    CHECK_EQ(max_loop_depth_, 0);
    Stmt then_case = this->Mutate(op->then_case);
    Stmt else_case;
    if (op->else_case.defined()) {
      // loop depth is the max over the two branches.
      int temp = max_loop_depth_;
      max_loop_depth_ = 0;
      else_case = this->Mutate(op->else_case);
      max_loop_depth_ = std::max(temp, max_loop_depth_);
    }
    if (condition.same_as(op->condition) &&
        then_case.same_as(op->then_case) &&
        else_case.same_as(op->else_case)) {
      return s;
    } else {
      return IfThenElse::make(condition, then_case, else_case);
    }
  }
  // Block
  Stmt Mutate_(const Block* op, const Stmt& s) final {
    CHECK_EQ(max_loop_depth_, 0);
    Stmt first = this->Mutate(op->first);
    // loop depth is the max over the two sub-statements.
    int temp = max_loop_depth_;
    max_loop_depth_ = 0;
    Stmt rest = this->Mutate(op->rest);
    max_loop_depth_ = std::max(max_loop_depth_, temp);
    if (first.same_as(op->first) &&
        rest.same_as(op->rest)) {
      return s;
    } else {
      return Block::make(first, rest);
    }
  }
  // Allocate
  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
    if (op->new_expr.defined() && !vt_loop_injected_) {
      // user-specified address cannot be expanded per thread.
      return InjectVTLoop(s, true);
    }
    Expr condition = Mutate(op->condition);
    if (visit_touched_var_ && !vt_loop_injected_) {
      return InjectVTLoop(s, true);
    }

    bool changed = false;
    Array<Expr> extents;
    for (size_t i = 0; i < op->extents.size(); i++) {
      Expr new_ext = Mutate(op->extents[i]);
      if (visit_touched_var_ && !vt_loop_injected_) {
        return InjectVTLoop(s, true);
      }
      if (!new_ext.same_as(op->extents[i])) changed = true;
      extents.push_back(new_ext);
    }
    visit_touched_var_ = false;

    Stmt body;
    // always rewrite if not allow sharing.
    if (touched_var_.count(op->buffer_var.get()) || !allow_share_) {
      // place v on highest dimension.
      // per-thread slice size = product of extents * lanes.
      Expr stride = arith::ComputeReduce<Mul>(
          op->extents, Expr()) * op->type.lanes();
      Array<Expr> other;
      other.push_back(make_const(op->extents[0].type(), num_threads_));
      for (Expr e : extents) {
        other.push_back(e);
      }
      extents = other;
      changed = true;
      // mark this buffer get touched.
      alloc_remap_[op->buffer_var.get()] = stride;
      // Mutate the body.
      body = Mutate(op->body);
    } else {
      // Mutate the body.
      body = Mutate(op->body);
    }
    if (!changed &&
        body.same_as(op->body) &&
        condition.same_as(op->condition)) {
      return s;
    } else {
      return Allocate::make(
          op->buffer_var, op->type,
          extents, condition, body,
          op->new_expr, op->free_function);
    }
  }

  // inject vthread loop
  // before_mutation: true when stmt has not been mutated yet and must
  //   be processed (with vt_loop_injected_ set) before substitution.
  Stmt InjectVTLoop(Stmt stmt, bool before_mutation) {
    CHECK(!vt_loop_injected_);
    // reset the flags
    visit_touched_var_ = false;
    trigger_base_inject_ = false;
    vt_loop_injected_ = true;
    if (before_mutation) {
      stmt = this->Mutate(stmt);
    }
    // reset the flags after processing.
    vt_loop_injected_ = false;
    visit_touched_var_ = false;
    // only unroll if number of vthreads are small
    if (max_loop_depth_ == 0 && num_threads_ < 16) {
      // do unrolling if it is inside innermost content.
      // emit one copy of stmt per thread with var_ substituted.
      Stmt blk = Substitute(stmt, {{var_, make_zero(var_.type())}});
      for (int i = 1; i < num_threads_; ++i) {
        blk = Block::make(
            blk, Substitute(stmt, {{var_, make_const(var_.type(), i)}}));
      }
      return blk;
    } else {
      // insert a for loop
      Var idx(var_->name_hint + ".s", var_->type);
      Map<Var, Expr> values{{var_, idx}};
      stmt = Substitute(stmt, values);
      return For::make(idx, make_zero(idx.type()),
                       make_const(idx.type(), num_threads_),
                       ForType::Serial, DeviceAPI::None, stmt);
    }
  }

 private:
  // vthread variable
  Var var_;
  // the threads/lanes
  int num_threads_;
  // whether the loop is already injected.
  bool vt_loop_injected_{false};
  // whether current expression get touched.
  bool visit_touched_var_{false};
  // Trigger injection at the base statement (per-thread side effects).
  bool trigger_base_inject_{false};
  // the counter of loops in after mutation.
  int max_loop_depth_{0};
  // The variables that get touched.
  const std::unordered_set<const Variable*>& touched_var_;
  // Whether to allow sharing of read-only data across threads.
  bool allow_share_;
  // The allocations that get touched -> extent
  std::unordered_map<const Variable*, Expr> alloc_remap_;
};
479
480
481 class VirtualThreadInjector : public IRMutator {
482 public:
Mutate_(const AttrStmt * op,const Stmt & s)483 Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
484 Stmt stmt = IRMutator::Mutate_(op, s);
485 op = stmt.as<AttrStmt>();
486 if (op->attr_key == attr::virtual_thread) {
487 IterVar iv = Downcast<IterVar>(op->node);
488 bool allow_share = iv->thread_tag == "vthread";
489 int nthread = static_cast<int>(op->value.as<IntImm>()->value);
490 VarTouchedAnalysis vs;
491 auto touched = vs.TouchedVar(op->body, iv->var.get());
492 VTInjector injecter(iv->var, nthread, touched, allow_share);
493 return injecter.Mutate(op->body);
494 } else {
495 return stmt;
496 }
497 }
498
Mutate_(const Provide * op,const Stmt & s)499 Stmt Mutate_(const Provide* op, const Stmt& s) final {
500 LOG(FATAL) << "Need to call StorageFlatten first";
501 return s;
502 }
503 };
504
InjectVirtualThread(Stmt stmt)505 Stmt InjectVirtualThread(Stmt stmt) {
506 stmt = VirtualThreadInjector().Mutate(stmt);
507 return ConvertSSA(stmt);
508 }
509
510 } // namespace ir
511 } // namespace tvm
512