1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19 /*!
20 * \file stmt_functor.cc
21 */
22 #include <tvm/runtime/registry.h>
23 #include <tvm/tir/stmt_functor.h>
24
25 #include <functional>
26
27 #include "functor_common.h"
28
29 namespace tvm {
30 namespace tir {
31
VisitStmt_(const LetStmtNode * op)32 void StmtVisitor::VisitStmt_(const LetStmtNode* op) {
33 this->VisitExpr(op->value);
34 this->VisitStmt(op->body);
35 }
36
VisitStmt_(const AttrStmtNode * op)37 void StmtVisitor::VisitStmt_(const AttrStmtNode* op) {
38 this->VisitExpr(op->value);
39 this->VisitStmt(op->body);
40 }
41
VisitStmt_(const ForNode * op)42 void StmtVisitor::VisitStmt_(const ForNode* op) {
43 this->VisitExpr(op->min);
44 this->VisitExpr(op->extent);
45 this->VisitStmt(op->body);
46 }
47
VisitStmt_(const AllocateNode * op)48 void StmtVisitor::VisitStmt_(const AllocateNode* op) {
49 VisitArray(op->extents, [this](const PrimExpr& e) { this->VisitExpr(e); });
50 this->VisitStmt(op->body);
51 this->VisitExpr(op->condition);
52 }
53
VisitStmt_(const StoreNode * op)54 void StmtVisitor::VisitStmt_(const StoreNode* op) {
55 this->VisitExpr(op->value);
56 this->VisitExpr(op->index);
57 this->VisitExpr(op->predicate);
58 }
59
VisitStmt_(const BufferStoreNode * op)60 void StmtVisitor::VisitStmt_(const BufferStoreNode* op) {
61 this->VisitExpr(op->value);
62 VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
63 }
64
VisitStmt_(const BufferRealizeNode * op)65 void StmtVisitor::VisitStmt_(const BufferRealizeNode* op) {
66 VisitArray(op->bounds, [this](const Range& r) {
67 this->VisitExpr(r->min);
68 this->VisitExpr(r->extent);
69 });
70 this->VisitExpr(op->condition);
71 this->VisitStmt(op->body);
72 }
73
VisitStmt_(const IfThenElseNode * op)74 void StmtVisitor::VisitStmt_(const IfThenElseNode* op) {
75 this->VisitExpr(op->condition);
76 this->VisitStmt(op->then_case);
77 if (op->else_case.defined()) {
78 this->VisitStmt(op->else_case);
79 }
80 }
81
VisitStmt_(const AssertStmtNode * op)82 void StmtVisitor::VisitStmt_(const AssertStmtNode* op) {
83 this->VisitExpr(op->condition);
84 this->VisitExpr(op->message);
85 this->VisitStmt(op->body);
86 }
87
VisitStmt_(const ProducerStoreNode * op)88 void StmtVisitor::VisitStmt_(const ProducerStoreNode* op) {
89 VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
90 this->VisitExpr(op->value);
91 }
92
VisitStmt_(const ProducerRealizeNode * op)93 void StmtVisitor::VisitStmt_(const ProducerRealizeNode* op) {
94 VisitArray(op->bounds, [this](const Range& r) {
95 this->VisitExpr(r->min);
96 this->VisitExpr(r->extent);
97 });
98 this->VisitStmt(op->body);
99 this->VisitExpr(op->condition);
100 }
101
VisitStmt_(const PrefetchNode * op)102 void StmtVisitor::VisitStmt_(const PrefetchNode* op) {
103 VisitArray(op->bounds, [this](const Range& r) {
104 this->VisitExpr(r->min);
105 this->VisitExpr(r->extent);
106 });
107 }
108
VisitStmt_(const SeqStmtNode * op)109 void StmtVisitor::VisitStmt_(const SeqStmtNode* op) {
110 VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
111 }
112
VisitStmt_(const EvaluateNode * op)113 void StmtVisitor::VisitStmt_(const EvaluateNode* op) { this->VisitExpr(op->value); }
114
115 class StmtMutator::Internal {
116 public:
Mutate(StmtMutator * self,const Array<PrimExpr> & arr)117 static Array<PrimExpr> Mutate(StmtMutator* self, const Array<PrimExpr>& arr) {
118 auto fmutate = [self](const PrimExpr& e) { return self->VisitExpr(e); };
119 return MutateArray(arr, fmutate, self->allow_copy_on_write_);
120 }
121
Mutate(StmtMutator * self,const Array<Stmt> & arr)122 static Array<Stmt> Mutate(StmtMutator* self, const Array<Stmt>& arr) {
123 auto fmutate = [self](const Stmt& s) { return self->VisitStmt(s); };
124 return MutateArray(arr, fmutate, self->allow_copy_on_write_);
125 }
126
Mutate(StmtMutator * self,const Array<Range> & arr)127 static Array<Range> Mutate(StmtMutator* self, const Array<Range>& arr) {
128 auto fmutate = [self](const Range& r) {
129 PrimExpr min = self->VisitExpr(r->min);
130 PrimExpr extent = self->VisitExpr(r->extent);
131 if (min.same_as(r->min) && extent.same_as(r->extent)) {
132 return r;
133 } else {
134 return Range::FromMinExtent(min, extent);
135 }
136 };
137 return MutateArray(arr, fmutate, self->allow_copy_on_write_);
138 }
139 };
140
VisitStmt_(const AttrStmtNode * op)141 Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) {
142 PrimExpr value = this->VisitExpr(op->value);
143 Stmt body = this->VisitStmt(op->body);
144 if (value.same_as(op->value) && body.same_as(op->body)) {
145 return GetRef<Stmt>(op);
146 } else {
147 auto n = CopyOnWrite(op);
148 n->value = std::move(value);
149 n->body = std::move(body);
150 return Stmt(n);
151 }
152 }
153
VisitStmt_(const LetStmtNode * op)154 Stmt StmtMutator::VisitStmt_(const LetStmtNode* op) {
155 PrimExpr value = this->VisitExpr(op->value);
156 Stmt body = this->VisitStmt(op->body);
157 if (value.same_as(op->value) && body.same_as(op->body)) {
158 return GetRef<Stmt>(op);
159 } else {
160 auto n = CopyOnWrite(op);
161 n->value = std::move(value);
162 n->body = std::move(body);
163 return Stmt(n);
164 }
165 }
166
VisitStmt_(const ForNode * op)167 Stmt StmtMutator::VisitStmt_(const ForNode* op) {
168 PrimExpr min = this->VisitExpr(op->min);
169 PrimExpr extent = this->VisitExpr(op->extent);
170 Stmt body = this->VisitStmt(op->body);
171 if (min.same_as(op->min) && extent.same_as(op->extent) && body.same_as(op->body)) {
172 return GetRef<Stmt>(op);
173 } else {
174 auto n = CopyOnWrite(op);
175 n->min = std::move(min);
176 n->extent = std::move(extent);
177 n->body = std::move(body);
178 return Stmt(n);
179 }
180 }
181
VisitStmt_(const AllocateNode * op)182 Stmt StmtMutator::VisitStmt_(const AllocateNode* op) {
183 Array<PrimExpr> extents = Internal::Mutate(this, op->extents);
184 Stmt body = this->VisitStmt(op->body);
185 PrimExpr condition = this->VisitExpr(op->condition);
186
187 if (extents.same_as(op->extents) && body.same_as(op->body) && condition.same_as(op->condition)) {
188 return GetRef<Stmt>(op);
189 } else {
190 auto n = CopyOnWrite(op);
191 n->extents = std::move(extents);
192 n->body = std::move(body);
193 n->condition = std::move(condition);
194 return Stmt(n);
195 }
196 }
197
VisitStmt_(const IfThenElseNode * op)198 Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) {
199 PrimExpr condition = this->VisitExpr(op->condition);
200 Stmt then_case = this->VisitStmt(op->then_case);
201 Stmt else_case;
202 if (op->else_case.defined()) {
203 else_case = this->VisitStmt(op->else_case);
204 }
205 if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
206 else_case.same_as(op->else_case)) {
207 return GetRef<Stmt>(op);
208 } else {
209 auto n = CopyOnWrite(op);
210 n->condition = std::move(condition);
211 n->then_case = std::move(then_case);
212 n->else_case = std::move(else_case);
213 return Stmt(n);
214 }
215 }
216
VisitStmt_(const StoreNode * op)217 Stmt StmtMutator::VisitStmt_(const StoreNode* op) {
218 PrimExpr value = this->VisitExpr(op->value);
219 PrimExpr index = this->VisitExpr(op->index);
220 PrimExpr predicate = this->VisitExpr(op->predicate);
221 if (value.same_as(op->value) && index.same_as(op->index) && predicate.same_as(op->predicate)) {
222 return GetRef<Stmt>(op);
223 } else {
224 auto n = CopyOnWrite(op);
225 n->value = std::move(value);
226 n->index = std::move(index);
227 n->predicate = std::move(predicate);
228 return Stmt(n);
229 }
230 }
231
VisitStmt_(const BufferStoreNode * op)232 Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) {
233 PrimExpr value = this->VisitExpr(op->value);
234 Array<PrimExpr> indices = Internal::Mutate(this, op->indices);
235
236 if (value.same_as(op->value) && indices.same_as(op->indices)) {
237 return GetRef<Stmt>(op);
238 } else {
239 auto n = CopyOnWrite(op);
240 n->value = std::move(value);
241 n->indices = std::move(indices);
242 return Stmt(n);
243 }
244 }
245
VisitStmt_(const BufferRealizeNode * op)246 Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) {
247 Region bounds = Internal::Mutate(this, op->bounds);
248 PrimExpr condition = this->VisitExpr(op->condition);
249 Stmt body = this->VisitStmt(op->body);
250
251 if (bounds.same_as(op->bounds) && condition.same_as(op->condition) && body.same_as(op->body)) {
252 return GetRef<Stmt>(op);
253 } else {
254 auto n = CopyOnWrite(op);
255 n->bounds = std::move(bounds);
256 n->condition = std::move(condition);
257 n->body = std::move(body);
258 return Stmt(n);
259 }
260 }
261
VisitStmt_(const ProducerStoreNode * op)262 Stmt StmtMutator::VisitStmt_(const ProducerStoreNode* op) {
263 Array<PrimExpr> indices = Internal::Mutate(this, op->indices);
264 PrimExpr value = this->VisitExpr(op->value);
265 if (indices.same_as(op->indices) && value.same_as(op->value)) {
266 return GetRef<Stmt>(op);
267 } else {
268 auto n = CopyOnWrite(op);
269 n->indices = std::move(indices);
270 n->value = std::move(value);
271 return Stmt(n);
272 }
273 }
274
VisitStmt_(const ProducerRealizeNode * op)275 Stmt StmtMutator::VisitStmt_(const ProducerRealizeNode* op) {
276 Region bounds = Internal::Mutate(this, op->bounds);
277 Stmt body = this->VisitStmt(op->body);
278 PrimExpr condition = this->VisitExpr(op->condition);
279 if (bounds.same_as(op->bounds) && body.same_as(op->body) && condition.same_as(op->condition)) {
280 return GetRef<Stmt>(op);
281 } else {
282 auto n = CopyOnWrite(op);
283 n->bounds = std::move(bounds);
284 n->body = std::move(body);
285 n->condition = std::move(condition);
286 return Stmt(n);
287 }
288 }
289
VisitStmt_(const PrefetchNode * op)290 Stmt StmtMutator::VisitStmt_(const PrefetchNode* op) {
291 Region bounds = Internal::Mutate(this, op->bounds);
292 if (bounds.same_as(op->bounds)) {
293 return GetRef<Stmt>(op);
294 } else {
295 auto n = CopyOnWrite(op);
296 n->bounds = std::move(bounds);
297 return Stmt(n);
298 }
299 }
300
VisitStmt_(const SeqStmtNode * op)301 Stmt StmtMutator::VisitStmt_(const SeqStmtNode* op) {
302 Array<Stmt> seq = Internal::Mutate(this, op->seq);
303 if (seq.same_as(op->seq)) {
304 return GetRef<Stmt>(op);
305 } else {
306 auto n = CopyOnWrite(op);
307 n->seq = std::move(seq);
308 return Stmt(n);
309 }
310 }
311
312 // advanced visit function for seqstmt.
VisitSeqStmt_(const SeqStmtNode * op,bool flatten_before_visit,std::function<Stmt (const Stmt &)> fmutate)313 Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit,
314 std::function<Stmt(const Stmt&)> fmutate) {
315 if (flatten_before_visit) {
316 // Pass 1, check if we need to flatten.
317 bool need_flatten = false;
318 for (size_t i = 0; i < op->seq.size(); ++i) {
319 Stmt tmp = (*op)[i];
320 if (tmp.as<SeqStmtNode>()) need_flatten = true;
321 }
322 flatten_before_visit = need_flatten;
323 }
324 // function to run the visit.
325 auto frunvisit = [&](const SeqStmtNode* op) {
326 Array<Stmt> seq = fmutate != nullptr ? MutateArray(op->seq, fmutate, allow_copy_on_write_)
327 : Internal::Mutate(this, op->seq);
328 if (seq.same_as(op->seq)) {
329 return GetRef<Stmt>(op);
330 } else {
331 auto n = CopyOnWrite(op);
332 n->seq = std::move(seq);
333 return Stmt(n);
334 }
335 };
336 if (flatten_before_visit) {
337 Array<Stmt> seq;
338 SeqStmt::Flattener flattener(&seq);
339 flattener(0, op->seq);
340 // NOTE: If copy on write is allowed
341 // the assignment to seq below will
342 // destruct the original seq.
343 //
344 // Such destruction removes duplicated reference
345 // count to children and still enables COW for
346 // child Stmt.
347 ObjectPtr<SeqStmtNode> n = CopyOnWrite(op);
348 n->seq = std::move(seq);
349 return frunvisit(n.operator->());
350 } else {
351 return frunvisit(op);
352 }
353 }
354
VisitStmt_(const AssertStmtNode * op)355 Stmt StmtMutator::VisitStmt_(const AssertStmtNode* op) {
356 PrimExpr condition = this->VisitExpr(op->condition);
357 PrimExpr message = this->VisitExpr(op->message);
358 Stmt body = this->VisitStmt(op->body);
359
360 if (condition.same_as(op->condition) && message.same_as(op->message) && body.same_as(op->body)) {
361 return GetRef<Stmt>(op);
362 } else {
363 auto n = CopyOnWrite(op);
364 n->condition = std::move(condition);
365 n->message = std::move(message);
366 n->body = std::move(body);
367 return Stmt(n);
368 }
369 }
370
VisitStmt_(const EvaluateNode * op)371 Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) {
372 PrimExpr value = this->VisitExpr(op->value);
373 if (value.same_as(op->value)) {
374 return GetRef<Stmt>(op);
375 } else {
376 auto n = CopyOnWrite(op);
377 n->value = std::move(value);
378 return Stmt(n);
379 }
380 }
381
382 // Implementations of IRTransform, PostOrderVisit and Substitute
383 class IRApplyVisit : public StmtExprVisitor {
384 public:
IRApplyVisit(std::function<void (const ObjectRef &)> f)385 explicit IRApplyVisit(std::function<void(const ObjectRef&)> f) : f_(f) {}
386
VisitExpr(const PrimExpr & node)387 void VisitExpr(const PrimExpr& node) final {
388 if (visited_.count(node.get()) != 0) return;
389 visited_.insert(node.get());
390 ExprVisitor::VisitExpr(node);
391 f_(node);
392 }
393
VisitStmt(const Stmt & node)394 void VisitStmt(const Stmt& node) final {
395 if (visited_.count(node.get()) != 0) return;
396 visited_.insert(node.get());
397 StmtVisitor::VisitStmt(node);
398 f_(node);
399 }
400
401 private:
402 std::function<void(const ObjectRef&)> f_;
403 std::unordered_set<const Object*> visited_;
404 };
405
PostOrderVisit(const ObjectRef & node,std::function<void (const ObjectRef &)> fvisit)406 void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit) {
407 if (node.as<StmtNode>()) {
408 IRApplyVisit visitor(fvisit);
409 visitor(Downcast<Stmt>(node));
410 } else {
411 IRApplyVisit visitor(fvisit);
412 visitor(Downcast<PrimExpr>(node));
413 }
414 }
415
416 class IRTransformer final : public StmtExprMutator {
417 public:
IRTransformer(const runtime::PackedFunc & f_preorder,const runtime::PackedFunc & f_postorder,const std::unordered_set<uint32_t> & only_enable)418 IRTransformer(const runtime::PackedFunc& f_preorder, const runtime::PackedFunc& f_postorder,
419 const std::unordered_set<uint32_t>& only_enable)
420 : f_preorder_(f_preorder), f_postorder_(f_postorder), only_enable_(only_enable) {}
421
VisitStmt(const Stmt & stmt)422 Stmt VisitStmt(const Stmt& stmt) final {
423 return MutateInternal<Stmt>(stmt, [this](const Stmt& s) { return this->BaseVisitStmt(s); });
424 }
VisitExpr(const PrimExpr & expr)425 PrimExpr VisitExpr(const PrimExpr& expr) final {
426 return MutateInternal<PrimExpr>(expr,
427 [this](const PrimExpr& e) { return this->BaseVisitExpr(e); });
428 }
429
430 private:
431 // NOTE: redirect to parent's call
432 // This is used to get around limitation of gcc-4.8
BaseVisitStmt(const Stmt & s)433 Stmt BaseVisitStmt(const Stmt& s) { return StmtMutator::VisitStmt(s); }
BaseVisitExpr(const PrimExpr & e)434 PrimExpr BaseVisitExpr(const PrimExpr& e) { return ExprMutator::VisitExpr(e); }
435
436 template <typename T, typename F>
437 T MutateInternal(const T& node, F fmutate) {
438 if (only_enable_.size() && !only_enable_.count(node->type_index())) {
439 return fmutate(node);
440 }
441 if (f_preorder_ != nullptr) {
442 T pre = f_preorder_(node);
443 if (pre.defined()) return pre;
444 }
445 T new_node = fmutate(node);
446 if (f_postorder_ != nullptr) {
447 T post = f_postorder_(new_node);
448 if (post.defined()) return post;
449 }
450 return new_node;
451 }
452 // The functions
453 const runtime::PackedFunc& f_preorder_;
454 const runtime::PackedFunc& f_postorder_;
455 // type indices enabled.
456 const std::unordered_set<uint32_t>& only_enable_;
457 };
458
IRTransform(Stmt ir_node,const runtime::PackedFunc & f_preorder,const runtime::PackedFunc & f_postorder,Optional<Array<String>> only_enable)459 Stmt IRTransform(Stmt ir_node, const runtime::PackedFunc& f_preorder,
460 const runtime::PackedFunc& f_postorder, Optional<Array<String>> only_enable) {
461 std::unordered_set<uint32_t> only_type_index;
462 if (only_enable.defined()) {
463 for (auto s : only_enable.value()) {
464 only_type_index.insert(Object::TypeKey2Index(s.c_str()));
465 }
466 }
467 IRTransformer transform(f_preorder, f_postorder, only_type_index);
468 return transform(std::move(ir_node));
469 }
470
471 class IRSubstitue : public StmtExprMutator {
472 public:
IRSubstitue(std::function<Optional<PrimExpr> (const Var &)> vmap)473 explicit IRSubstitue(std::function<Optional<PrimExpr>(const Var&)> vmap) : vmap_(vmap) {}
474
VisitExpr_(const VarNode * op)475 PrimExpr VisitExpr_(const VarNode* op) final {
476 Var var = GetRef<Var>(op);
477 auto ret = vmap_(var);
478 if (ret.defined()) return ret.value();
479 return std::move(var);
480 }
481
VisitExpr_(const LoadNode * op)482 PrimExpr VisitExpr_(const LoadNode* op) final {
483 // NOTE: we do not explicit recursivly mutate op->buffer_var
484 PrimExpr ret = StmtExprMutator::VisitExpr_(op);
485 op = ret.as<LoadNode>();
486 if (auto mapped_var = vmap_(op->buffer_var)) {
487 return Load(op->dtype, Downcast<Var>(mapped_var.value()), op->index, op->predicate);
488 } else {
489 return ret;
490 }
491 }
492
VisitStmt_(const StoreNode * op)493 Stmt VisitStmt_(const StoreNode* op) final {
494 // NOTE: we do not explicit recursivly mutate op->buffer_var
495 Stmt ret = StmtExprMutator::VisitStmt_(op);
496 op = ret.as<StoreNode>();
497 if (auto mapped_var = vmap_(op->buffer_var)) {
498 return Store(Downcast<Var>(mapped_var.value()), op->value, op->index, op->predicate);
499 } else {
500 return ret;
501 }
502 }
503
504 private:
505 std::function<Optional<PrimExpr>(const Var&)> vmap_;
506 };
507
Substitute(Stmt stmt,std::function<Optional<PrimExpr> (const Var &)> vmap)508 Stmt Substitute(Stmt stmt, std::function<Optional<PrimExpr>(const Var&)> vmap) {
509 return IRSubstitue(vmap)(std::move(stmt));
510 }
511
Substitute(PrimExpr expr,std::function<Optional<PrimExpr> (const Var &)> vmap)512 PrimExpr Substitute(PrimExpr expr, std::function<Optional<PrimExpr>(const Var&)> vmap) {
513 return IRSubstitue(vmap)(std::move(expr));
514 }
515
516 TVM_REGISTER_GLOBAL("tir.IRTransform").set_body_typed(IRTransform);
517
// FFI entry point for PostOrderVisit: the packed callback `f` is invoked on each node.
TVM_REGISTER_GLOBAL("tir.PostOrderVisit").set_body_typed([](ObjectRef node, PackedFunc f) {
  auto forward_to_packed = [f](const ObjectRef& n) { f(n); };
  tir::PostOrderVisit(node, forward_to_packed);
});
521
522 TVM_REGISTER_GLOBAL("tir.Substitute")
__anone8be9d1b1002(ObjectRef node, Map<Var, PrimExpr> vmap) 523 .set_body_typed([](ObjectRef node, Map<Var, PrimExpr> vmap) -> ObjectRef {
524 if (node->IsInstance<StmtNode>()) {
525 return Substitute(Downcast<Stmt>(node), vmap);
526 } else {
527 return Substitute(Downcast<PrimExpr>(node), vmap);
528 }
529 });
530
531 } // namespace tir
532 } // namespace tvm
533