1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 /*!
20  * \file stmt_functor.cc
21  */
22 #include <tvm/runtime/registry.h>
23 #include <tvm/tir/stmt_functor.h>
24 
25 #include <functional>
26 
27 #include "functor_common.h"
28 
29 namespace tvm {
30 namespace tir {
31 
VisitStmt_(const LetStmtNode * op)32 void StmtVisitor::VisitStmt_(const LetStmtNode* op) {
33   this->VisitExpr(op->value);
34   this->VisitStmt(op->body);
35 }
36 
VisitStmt_(const AttrStmtNode * op)37 void StmtVisitor::VisitStmt_(const AttrStmtNode* op) {
38   this->VisitExpr(op->value);
39   this->VisitStmt(op->body);
40 }
41 
VisitStmt_(const ForNode * op)42 void StmtVisitor::VisitStmt_(const ForNode* op) {
43   this->VisitExpr(op->min);
44   this->VisitExpr(op->extent);
45   this->VisitStmt(op->body);
46 }
47 
VisitStmt_(const AllocateNode * op)48 void StmtVisitor::VisitStmt_(const AllocateNode* op) {
49   VisitArray(op->extents, [this](const PrimExpr& e) { this->VisitExpr(e); });
50   this->VisitStmt(op->body);
51   this->VisitExpr(op->condition);
52 }
53 
VisitStmt_(const StoreNode * op)54 void StmtVisitor::VisitStmt_(const StoreNode* op) {
55   this->VisitExpr(op->value);
56   this->VisitExpr(op->index);
57   this->VisitExpr(op->predicate);
58 }
59 
VisitStmt_(const BufferStoreNode * op)60 void StmtVisitor::VisitStmt_(const BufferStoreNode* op) {
61   this->VisitExpr(op->value);
62   VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
63 }
64 
VisitStmt_(const BufferRealizeNode * op)65 void StmtVisitor::VisitStmt_(const BufferRealizeNode* op) {
66   VisitArray(op->bounds, [this](const Range& r) {
67     this->VisitExpr(r->min);
68     this->VisitExpr(r->extent);
69   });
70   this->VisitExpr(op->condition);
71   this->VisitStmt(op->body);
72 }
73 
VisitStmt_(const IfThenElseNode * op)74 void StmtVisitor::VisitStmt_(const IfThenElseNode* op) {
75   this->VisitExpr(op->condition);
76   this->VisitStmt(op->then_case);
77   if (op->else_case.defined()) {
78     this->VisitStmt(op->else_case);
79   }
80 }
81 
VisitStmt_(const AssertStmtNode * op)82 void StmtVisitor::VisitStmt_(const AssertStmtNode* op) {
83   this->VisitExpr(op->condition);
84   this->VisitExpr(op->message);
85   this->VisitStmt(op->body);
86 }
87 
VisitStmt_(const ProducerStoreNode * op)88 void StmtVisitor::VisitStmt_(const ProducerStoreNode* op) {
89   VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
90   this->VisitExpr(op->value);
91 }
92 
VisitStmt_(const ProducerRealizeNode * op)93 void StmtVisitor::VisitStmt_(const ProducerRealizeNode* op) {
94   VisitArray(op->bounds, [this](const Range& r) {
95     this->VisitExpr(r->min);
96     this->VisitExpr(r->extent);
97   });
98   this->VisitStmt(op->body);
99   this->VisitExpr(op->condition);
100 }
101 
VisitStmt_(const PrefetchNode * op)102 void StmtVisitor::VisitStmt_(const PrefetchNode* op) {
103   VisitArray(op->bounds, [this](const Range& r) {
104     this->VisitExpr(r->min);
105     this->VisitExpr(r->extent);
106   });
107 }
108 
VisitStmt_(const SeqStmtNode * op)109 void StmtVisitor::VisitStmt_(const SeqStmtNode* op) {
110   VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
111 }
112 
VisitStmt_(const EvaluateNode * op)113 void StmtVisitor::VisitStmt_(const EvaluateNode* op) { this->VisitExpr(op->value); }
114 
115 class StmtMutator::Internal {
116  public:
Mutate(StmtMutator * self,const Array<PrimExpr> & arr)117   static Array<PrimExpr> Mutate(StmtMutator* self, const Array<PrimExpr>& arr) {
118     auto fmutate = [self](const PrimExpr& e) { return self->VisitExpr(e); };
119     return MutateArray(arr, fmutate, self->allow_copy_on_write_);
120   }
121 
Mutate(StmtMutator * self,const Array<Stmt> & arr)122   static Array<Stmt> Mutate(StmtMutator* self, const Array<Stmt>& arr) {
123     auto fmutate = [self](const Stmt& s) { return self->VisitStmt(s); };
124     return MutateArray(arr, fmutate, self->allow_copy_on_write_);
125   }
126 
Mutate(StmtMutator * self,const Array<Range> & arr)127   static Array<Range> Mutate(StmtMutator* self, const Array<Range>& arr) {
128     auto fmutate = [self](const Range& r) {
129       PrimExpr min = self->VisitExpr(r->min);
130       PrimExpr extent = self->VisitExpr(r->extent);
131       if (min.same_as(r->min) && extent.same_as(r->extent)) {
132         return r;
133       } else {
134         return Range::FromMinExtent(min, extent);
135       }
136     };
137     return MutateArray(arr, fmutate, self->allow_copy_on_write_);
138   }
139 };
140 
VisitStmt_(const AttrStmtNode * op)141 Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) {
142   PrimExpr value = this->VisitExpr(op->value);
143   Stmt body = this->VisitStmt(op->body);
144   if (value.same_as(op->value) && body.same_as(op->body)) {
145     return GetRef<Stmt>(op);
146   } else {
147     auto n = CopyOnWrite(op);
148     n->value = std::move(value);
149     n->body = std::move(body);
150     return Stmt(n);
151   }
152 }
153 
VisitStmt_(const LetStmtNode * op)154 Stmt StmtMutator::VisitStmt_(const LetStmtNode* op) {
155   PrimExpr value = this->VisitExpr(op->value);
156   Stmt body = this->VisitStmt(op->body);
157   if (value.same_as(op->value) && body.same_as(op->body)) {
158     return GetRef<Stmt>(op);
159   } else {
160     auto n = CopyOnWrite(op);
161     n->value = std::move(value);
162     n->body = std::move(body);
163     return Stmt(n);
164   }
165 }
166 
VisitStmt_(const ForNode * op)167 Stmt StmtMutator::VisitStmt_(const ForNode* op) {
168   PrimExpr min = this->VisitExpr(op->min);
169   PrimExpr extent = this->VisitExpr(op->extent);
170   Stmt body = this->VisitStmt(op->body);
171   if (min.same_as(op->min) && extent.same_as(op->extent) && body.same_as(op->body)) {
172     return GetRef<Stmt>(op);
173   } else {
174     auto n = CopyOnWrite(op);
175     n->min = std::move(min);
176     n->extent = std::move(extent);
177     n->body = std::move(body);
178     return Stmt(n);
179   }
180 }
181 
VisitStmt_(const AllocateNode * op)182 Stmt StmtMutator::VisitStmt_(const AllocateNode* op) {
183   Array<PrimExpr> extents = Internal::Mutate(this, op->extents);
184   Stmt body = this->VisitStmt(op->body);
185   PrimExpr condition = this->VisitExpr(op->condition);
186 
187   if (extents.same_as(op->extents) && body.same_as(op->body) && condition.same_as(op->condition)) {
188     return GetRef<Stmt>(op);
189   } else {
190     auto n = CopyOnWrite(op);
191     n->extents = std::move(extents);
192     n->body = std::move(body);
193     n->condition = std::move(condition);
194     return Stmt(n);
195   }
196 }
197 
VisitStmt_(const IfThenElseNode * op)198 Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) {
199   PrimExpr condition = this->VisitExpr(op->condition);
200   Stmt then_case = this->VisitStmt(op->then_case);
201   Stmt else_case;
202   if (op->else_case.defined()) {
203     else_case = this->VisitStmt(op->else_case);
204   }
205   if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
206       else_case.same_as(op->else_case)) {
207     return GetRef<Stmt>(op);
208   } else {
209     auto n = CopyOnWrite(op);
210     n->condition = std::move(condition);
211     n->then_case = std::move(then_case);
212     n->else_case = std::move(else_case);
213     return Stmt(n);
214   }
215 }
216 
VisitStmt_(const StoreNode * op)217 Stmt StmtMutator::VisitStmt_(const StoreNode* op) {
218   PrimExpr value = this->VisitExpr(op->value);
219   PrimExpr index = this->VisitExpr(op->index);
220   PrimExpr predicate = this->VisitExpr(op->predicate);
221   if (value.same_as(op->value) && index.same_as(op->index) && predicate.same_as(op->predicate)) {
222     return GetRef<Stmt>(op);
223   } else {
224     auto n = CopyOnWrite(op);
225     n->value = std::move(value);
226     n->index = std::move(index);
227     n->predicate = std::move(predicate);
228     return Stmt(n);
229   }
230 }
231 
VisitStmt_(const BufferStoreNode * op)232 Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) {
233   PrimExpr value = this->VisitExpr(op->value);
234   Array<PrimExpr> indices = Internal::Mutate(this, op->indices);
235 
236   if (value.same_as(op->value) && indices.same_as(op->indices)) {
237     return GetRef<Stmt>(op);
238   } else {
239     auto n = CopyOnWrite(op);
240     n->value = std::move(value);
241     n->indices = std::move(indices);
242     return Stmt(n);
243   }
244 }
245 
VisitStmt_(const BufferRealizeNode * op)246 Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) {
247   Region bounds = Internal::Mutate(this, op->bounds);
248   PrimExpr condition = this->VisitExpr(op->condition);
249   Stmt body = this->VisitStmt(op->body);
250 
251   if (bounds.same_as(op->bounds) && condition.same_as(op->condition) && body.same_as(op->body)) {
252     return GetRef<Stmt>(op);
253   } else {
254     auto n = CopyOnWrite(op);
255     n->bounds = std::move(bounds);
256     n->condition = std::move(condition);
257     n->body = std::move(body);
258     return Stmt(n);
259   }
260 }
261 
VisitStmt_(const ProducerStoreNode * op)262 Stmt StmtMutator::VisitStmt_(const ProducerStoreNode* op) {
263   Array<PrimExpr> indices = Internal::Mutate(this, op->indices);
264   PrimExpr value = this->VisitExpr(op->value);
265   if (indices.same_as(op->indices) && value.same_as(op->value)) {
266     return GetRef<Stmt>(op);
267   } else {
268     auto n = CopyOnWrite(op);
269     n->indices = std::move(indices);
270     n->value = std::move(value);
271     return Stmt(n);
272   }
273 }
274 
VisitStmt_(const ProducerRealizeNode * op)275 Stmt StmtMutator::VisitStmt_(const ProducerRealizeNode* op) {
276   Region bounds = Internal::Mutate(this, op->bounds);
277   Stmt body = this->VisitStmt(op->body);
278   PrimExpr condition = this->VisitExpr(op->condition);
279   if (bounds.same_as(op->bounds) && body.same_as(op->body) && condition.same_as(op->condition)) {
280     return GetRef<Stmt>(op);
281   } else {
282     auto n = CopyOnWrite(op);
283     n->bounds = std::move(bounds);
284     n->body = std::move(body);
285     n->condition = std::move(condition);
286     return Stmt(n);
287   }
288 }
289 
VisitStmt_(const PrefetchNode * op)290 Stmt StmtMutator::VisitStmt_(const PrefetchNode* op) {
291   Region bounds = Internal::Mutate(this, op->bounds);
292   if (bounds.same_as(op->bounds)) {
293     return GetRef<Stmt>(op);
294   } else {
295     auto n = CopyOnWrite(op);
296     n->bounds = std::move(bounds);
297     return Stmt(n);
298   }
299 }
300 
VisitStmt_(const SeqStmtNode * op)301 Stmt StmtMutator::VisitStmt_(const SeqStmtNode* op) {
302   Array<Stmt> seq = Internal::Mutate(this, op->seq);
303   if (seq.same_as(op->seq)) {
304     return GetRef<Stmt>(op);
305   } else {
306     auto n = CopyOnWrite(op);
307     n->seq = std::move(seq);
308     return Stmt(n);
309   }
310 }
311 
312 // advanced visit function for seqstmt.
VisitSeqStmt_(const SeqStmtNode * op,bool flatten_before_visit,std::function<Stmt (const Stmt &)> fmutate)313 Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit,
314                                 std::function<Stmt(const Stmt&)> fmutate) {
315   if (flatten_before_visit) {
316     // Pass 1, check if we need to flatten.
317     bool need_flatten = false;
318     for (size_t i = 0; i < op->seq.size(); ++i) {
319       Stmt tmp = (*op)[i];
320       if (tmp.as<SeqStmtNode>()) need_flatten = true;
321     }
322     flatten_before_visit = need_flatten;
323   }
324   // function to run the visit.
325   auto frunvisit = [&](const SeqStmtNode* op) {
326     Array<Stmt> seq = fmutate != nullptr ? MutateArray(op->seq, fmutate, allow_copy_on_write_)
327                                          : Internal::Mutate(this, op->seq);
328     if (seq.same_as(op->seq)) {
329       return GetRef<Stmt>(op);
330     } else {
331       auto n = CopyOnWrite(op);
332       n->seq = std::move(seq);
333       return Stmt(n);
334     }
335   };
336   if (flatten_before_visit) {
337     Array<Stmt> seq;
338     SeqStmt::Flattener flattener(&seq);
339     flattener(0, op->seq);
340     // NOTE: If copy on write is allowed
341     // the assignment to seq below will
342     // destruct the original seq.
343     //
344     // Such destruction removes duplicated reference
345     // count to children and still enables COW for
346     // child Stmt.
347     ObjectPtr<SeqStmtNode> n = CopyOnWrite(op);
348     n->seq = std::move(seq);
349     return frunvisit(n.operator->());
350   } else {
351     return frunvisit(op);
352   }
353 }
354 
VisitStmt_(const AssertStmtNode * op)355 Stmt StmtMutator::VisitStmt_(const AssertStmtNode* op) {
356   PrimExpr condition = this->VisitExpr(op->condition);
357   PrimExpr message = this->VisitExpr(op->message);
358   Stmt body = this->VisitStmt(op->body);
359 
360   if (condition.same_as(op->condition) && message.same_as(op->message) && body.same_as(op->body)) {
361     return GetRef<Stmt>(op);
362   } else {
363     auto n = CopyOnWrite(op);
364     n->condition = std::move(condition);
365     n->message = std::move(message);
366     n->body = std::move(body);
367     return Stmt(n);
368   }
369 }
370 
VisitStmt_(const EvaluateNode * op)371 Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) {
372   PrimExpr value = this->VisitExpr(op->value);
373   if (value.same_as(op->value)) {
374     return GetRef<Stmt>(op);
375   } else {
376     auto n = CopyOnWrite(op);
377     n->value = std::move(value);
378     return Stmt(n);
379   }
380 }
381 
382 // Implementations of IRTransform, PostOrderVisit and Substitute
383 class IRApplyVisit : public StmtExprVisitor {
384  public:
IRApplyVisit(std::function<void (const ObjectRef &)> f)385   explicit IRApplyVisit(std::function<void(const ObjectRef&)> f) : f_(f) {}
386 
VisitExpr(const PrimExpr & node)387   void VisitExpr(const PrimExpr& node) final {
388     if (visited_.count(node.get()) != 0) return;
389     visited_.insert(node.get());
390     ExprVisitor::VisitExpr(node);
391     f_(node);
392   }
393 
VisitStmt(const Stmt & node)394   void VisitStmt(const Stmt& node) final {
395     if (visited_.count(node.get()) != 0) return;
396     visited_.insert(node.get());
397     StmtVisitor::VisitStmt(node);
398     f_(node);
399   }
400 
401  private:
402   std::function<void(const ObjectRef&)> f_;
403   std::unordered_set<const Object*> visited_;
404 };
405 
PostOrderVisit(const ObjectRef & node,std::function<void (const ObjectRef &)> fvisit)406 void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit) {
407   if (node.as<StmtNode>()) {
408     IRApplyVisit visitor(fvisit);
409     visitor(Downcast<Stmt>(node));
410   } else {
411     IRApplyVisit visitor(fvisit);
412     visitor(Downcast<PrimExpr>(node));
413   }
414 }
415 
416 class IRTransformer final : public StmtExprMutator {
417  public:
IRTransformer(const runtime::PackedFunc & f_preorder,const runtime::PackedFunc & f_postorder,const std::unordered_set<uint32_t> & only_enable)418   IRTransformer(const runtime::PackedFunc& f_preorder, const runtime::PackedFunc& f_postorder,
419                 const std::unordered_set<uint32_t>& only_enable)
420       : f_preorder_(f_preorder), f_postorder_(f_postorder), only_enable_(only_enable) {}
421 
VisitStmt(const Stmt & stmt)422   Stmt VisitStmt(const Stmt& stmt) final {
423     return MutateInternal<Stmt>(stmt, [this](const Stmt& s) { return this->BaseVisitStmt(s); });
424   }
VisitExpr(const PrimExpr & expr)425   PrimExpr VisitExpr(const PrimExpr& expr) final {
426     return MutateInternal<PrimExpr>(expr,
427                                     [this](const PrimExpr& e) { return this->BaseVisitExpr(e); });
428   }
429 
430  private:
431   // NOTE: redirect to parent's call
432   // This is used to get around limitation of gcc-4.8
BaseVisitStmt(const Stmt & s)433   Stmt BaseVisitStmt(const Stmt& s) { return StmtMutator::VisitStmt(s); }
BaseVisitExpr(const PrimExpr & e)434   PrimExpr BaseVisitExpr(const PrimExpr& e) { return ExprMutator::VisitExpr(e); }
435 
436   template <typename T, typename F>
437   T MutateInternal(const T& node, F fmutate) {
438     if (only_enable_.size() && !only_enable_.count(node->type_index())) {
439       return fmutate(node);
440     }
441     if (f_preorder_ != nullptr) {
442       T pre = f_preorder_(node);
443       if (pre.defined()) return pre;
444     }
445     T new_node = fmutate(node);
446     if (f_postorder_ != nullptr) {
447       T post = f_postorder_(new_node);
448       if (post.defined()) return post;
449     }
450     return new_node;
451   }
452   // The functions
453   const runtime::PackedFunc& f_preorder_;
454   const runtime::PackedFunc& f_postorder_;
455   // type indices enabled.
456   const std::unordered_set<uint32_t>& only_enable_;
457 };
458 
IRTransform(Stmt ir_node,const runtime::PackedFunc & f_preorder,const runtime::PackedFunc & f_postorder,Optional<Array<String>> only_enable)459 Stmt IRTransform(Stmt ir_node, const runtime::PackedFunc& f_preorder,
460                  const runtime::PackedFunc& f_postorder, Optional<Array<String>> only_enable) {
461   std::unordered_set<uint32_t> only_type_index;
462   if (only_enable.defined()) {
463     for (auto s : only_enable.value()) {
464       only_type_index.insert(Object::TypeKey2Index(s.c_str()));
465     }
466   }
467   IRTransformer transform(f_preorder, f_postorder, only_type_index);
468   return transform(std::move(ir_node));
469 }
470 
471 class IRSubstitue : public StmtExprMutator {
472  public:
IRSubstitue(std::function<Optional<PrimExpr> (const Var &)> vmap)473   explicit IRSubstitue(std::function<Optional<PrimExpr>(const Var&)> vmap) : vmap_(vmap) {}
474 
VisitExpr_(const VarNode * op)475   PrimExpr VisitExpr_(const VarNode* op) final {
476     Var var = GetRef<Var>(op);
477     auto ret = vmap_(var);
478     if (ret.defined()) return ret.value();
479     return std::move(var);
480   }
481 
VisitExpr_(const LoadNode * op)482   PrimExpr VisitExpr_(const LoadNode* op) final {
483     // NOTE: we do not explicit recursivly mutate op->buffer_var
484     PrimExpr ret = StmtExprMutator::VisitExpr_(op);
485     op = ret.as<LoadNode>();
486     if (auto mapped_var = vmap_(op->buffer_var)) {
487       return Load(op->dtype, Downcast<Var>(mapped_var.value()), op->index, op->predicate);
488     } else {
489       return ret;
490     }
491   }
492 
VisitStmt_(const StoreNode * op)493   Stmt VisitStmt_(const StoreNode* op) final {
494     // NOTE: we do not explicit recursivly mutate op->buffer_var
495     Stmt ret = StmtExprMutator::VisitStmt_(op);
496     op = ret.as<StoreNode>();
497     if (auto mapped_var = vmap_(op->buffer_var)) {
498       return Store(Downcast<Var>(mapped_var.value()), op->value, op->index, op->predicate);
499     } else {
500       return ret;
501     }
502   }
503 
504  private:
505   std::function<Optional<PrimExpr>(const Var&)> vmap_;
506 };
507 
Substitute(Stmt stmt,std::function<Optional<PrimExpr> (const Var &)> vmap)508 Stmt Substitute(Stmt stmt, std::function<Optional<PrimExpr>(const Var&)> vmap) {
509   return IRSubstitue(vmap)(std::move(stmt));
510 }
511 
Substitute(PrimExpr expr,std::function<Optional<PrimExpr> (const Var &)> vmap)512 PrimExpr Substitute(PrimExpr expr, std::function<Optional<PrimExpr>(const Var&)> vmap) {
513   return IRSubstitue(vmap)(std::move(expr));
514 }
515 
516 TVM_REGISTER_GLOBAL("tir.IRTransform").set_body_typed(IRTransform);
517 
__anone8be9d1b0e02(ObjectRef node, PackedFunc f) 518 TVM_REGISTER_GLOBAL("tir.PostOrderVisit").set_body_typed([](ObjectRef node, PackedFunc f) {
519   tir::PostOrderVisit(node, [f](const ObjectRef& n) { f(n); });
520 });
521 
522 TVM_REGISTER_GLOBAL("tir.Substitute")
__anone8be9d1b1002(ObjectRef node, Map<Var, PrimExpr> vmap) 523     .set_body_typed([](ObjectRef node, Map<Var, PrimExpr> vmap) -> ObjectRef {
524       if (node->IsInstance<StmtNode>()) {
525         return Substitute(Downcast<Stmt>(node), vmap);
526       } else {
527         return Substitute(Downcast<PrimExpr>(node), vmap);
528       }
529     });
530 
531 }  // namespace tir
532 }  // namespace tvm
533