1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file src/tvm/relay/expr_functor.cc
22  * \brief A wrapper around ExprFunctor which functionally updates the AST.
23  *
24  * ExprMutator uses memoization and self return in order to amortize
25  * the cost of using functional updates.
26  */
27 #include <tvm/relay/analysis.h>
28 #include <tvm/relay/expr_functor.h>
29 #include <tvm/relay/pattern_functor.h>
30 #include "type_functor.h"
31 
32 namespace tvm {
33 namespace relay {
34 
/*!
 * \brief Mutate an expression, reusing a previously recorded result when one exists.
 * \param expr The expression to rewrite.
 * \return The value stored in memo_ for expr if present; otherwise the result of
 *         ExprFunctor::VisitExpr(expr), which is recorded in memo_ before returning.
 */
Expr ExprMutator::VisitExpr(const Expr& expr) {
  auto cached = this->memo_.find(expr);
  if (cached != this->memo_.end()) {
    return cached->second;
  }
  // Not seen yet: compute the rewritten expression, then remember it.
  Expr result = ExprFunctor::VisitExpr(expr);
  memo_[expr] = result;
  return result;
}
45 
VisitExpr_(const VarNode * op)46 Expr ExprMutator::VisitExpr_(const VarNode* op) {
47   if (op->type_annotation.defined()) {
48     auto type = this->VisitType(op->type_annotation);
49     if (!op->type_annotation.same_as(type)) {
50       return VarNode::make(op->vid, type);
51     }
52   }
53   // default case return self.
54   return GetRef<Expr>(op);
55 }
56 
VisitExpr_(const ConstantNode * op)57 Expr ExprMutator::VisitExpr_(const ConstantNode* op) {
58   return GetRef<Expr>(op);
59 }
60 
VisitExpr_(const GlobalVarNode * op)61 Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) {
62   return GetRef<Expr>(op);
63 }
64 
VisitExpr_(const OpNode * op)65 Expr ExprMutator::VisitExpr_(const OpNode* op) {
66   return GetRef<Expr>(op);
67 }
68 
VisitExpr_(const TupleNode * op)69 Expr ExprMutator::VisitExpr_(const TupleNode* op) {
70   tvm::Array<Expr> fields;
71   bool all_fields_unchanged = true;
72   for (auto field : op->fields) {
73     auto new_field = this->Mutate(field);
74     fields.push_back(new_field);
75     all_fields_unchanged &= new_field.same_as(field);
76   }
77 
78   if (all_fields_unchanged) {
79     return GetRef<Expr>(op);
80   } else {
81     return TupleNode::make(fields);
82   }
83 }
84 
VisitExpr_(const FunctionNode * op)85 Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
86   tvm::Array<TypeVar> ty_params;
87   bool all_ty_params_unchanged = true;
88 
89   for (auto ty_param : op->type_params) {
90     TypeVar new_ty_param = Downcast<TypeVar>(VisitType(ty_param));
91     ty_params.push_back(new_ty_param);
92     all_ty_params_unchanged &= new_ty_param.same_as(ty_param);
93   }
94 
95   tvm::Array<Var> params;
96   bool all_params_unchanged = true;
97   for (auto param : op->params) {
98     Var new_param = Downcast<Var>(this->Mutate(param));
99     params.push_back(new_param);
100     all_params_unchanged &= param.same_as(new_param);
101   }
102 
103   auto ret_type = this->VisitType(op->ret_type);
104   auto body = this->Mutate(op->body);
105 
106   if (all_ty_params_unchanged &&
107       all_params_unchanged &&
108       ret_type.same_as(op->ret_type) &&
109       body.same_as(op->body)) {
110     return GetRef<Expr>(op);
111   } else {
112     return FunctionNode::make(params, body, ret_type, ty_params, op->attrs);
113   }
114 }
115 
VisitExpr_(const CallNode * call_node)116 Expr ExprMutator::VisitExpr_(const CallNode* call_node) {
117   auto new_op = this->Mutate(call_node->op);
118   bool unchanged = call_node->op.same_as(new_op);
119 
120   tvm::Array<Type> ty_args;
121   for (auto ty_arg : call_node->type_args) {
122     auto new_ty_arg = this->VisitType(ty_arg);
123     ty_args.push_back(new_ty_arg);
124     unchanged &= new_ty_arg.same_as(ty_arg);
125   }
126 
127   tvm::Array<Expr> call_args;
128   for (auto arg : call_node->args) {
129     auto new_arg = this->Mutate(arg);
130     call_args.push_back(new_arg);
131     unchanged &= new_arg.same_as(arg);
132   }
133 
134   if (unchanged) {
135     return GetRef<Expr>(call_node);
136   } else {
137     return CallNode::make(new_op, call_args, call_node->attrs, ty_args);
138   }
139 }
140 
VisitExpr_(const LetNode * op)141 Expr ExprMutator::VisitExpr_(const LetNode* op) {
142   Var var = Downcast<Var>(this->Mutate(op->var));
143   auto value = this->Mutate(op->value);
144   auto body = this->Mutate(op->body);
145 
146   if (var.same_as(op->var) &&
147       value.same_as(op->value) &&
148       body.same_as(op->body)) {
149     return GetRef<Expr>(op);
150   } else {
151     return LetNode::make(var, value, body);
152   }
153 }
154 
VisitExpr_(const IfNode * op)155 Expr ExprMutator::VisitExpr_(const IfNode* op) {
156   auto guard = this->Mutate(op->cond);
157   auto true_b = this->Mutate(op->true_branch);
158   auto false_b = this->Mutate(op->false_branch);
159   if (op->cond.same_as(guard) &&
160       op->true_branch.same_as(true_b) &&
161       op->false_branch.same_as(false_b)) {
162     return GetRef<Expr>(op);;
163   } else {
164     return IfNode::make(guard, true_b, false_b);
165   }
166 }
167 
VisitExpr_(const TupleGetItemNode * g)168 Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
169   auto t = this->Mutate(g->tuple);
170   if (g->tuple == t) {
171     return GetRef<Expr>(g);
172   } else {
173     return TupleGetItemNode::make(t, g->index);
174   }
175 }
176 
VisitExpr_(const RefCreateNode * op)177 Expr ExprMutator::VisitExpr_(const RefCreateNode* op) {
178   Expr value = this->Mutate(op->value);
179   if (value.same_as(op->value)) {
180     return GetRef<Expr>(op);
181   } else {
182     return RefCreateNode::make(value);
183   }
184 }
185 
VisitExpr_(const RefReadNode * op)186 Expr ExprMutator::VisitExpr_(const RefReadNode* op) {
187   Expr ref = this->Mutate(op->ref);
188   if (ref.same_as(op->ref)) {
189     return GetRef<Expr>(op);
190   } else {
191     return RefReadNode::make(ref);
192   }
193 }
194 
VisitExpr_(const RefWriteNode * op)195 Expr ExprMutator::VisitExpr_(const RefWriteNode* op) {
196   Expr ref = this->Mutate(op->ref);
197   Expr value = this->Mutate(op->value);
198   if (ref.same_as(op->ref) && value.same_as(op->value)) {
199     return GetRef<Expr>(op);
200   } else {
201     return RefWriteNode::make(ref, value);
202   }
203 }
204 
VisitExpr_(const ConstructorNode * c)205 Expr ExprMutator::VisitExpr_(const ConstructorNode* c) {
206   return GetRef<Expr>(c);
207 }
208 
VisitExpr_(const MatchNode * m)209 Expr ExprMutator::VisitExpr_(const MatchNode* m) {
210   std::vector<Clause> clauses;
211   for (const Clause& p : m->clauses) {
212     clauses.push_back(VisitClause(p));
213   }
214   return MatchNode::make(VisitExpr(m->data), clauses, m->complete);
215 }
216 
/*!
 * \brief Mutate a match clause: the pattern first, then the right-hand side.
 * \param c The clause to rewrite.
 * \return The clause itself if both the visited pattern and the visited rhs are
 *         the same references as before; otherwise a new clause built from them.
 */
Clause ExprMutator::VisitClause(const Clause& c) {
  Pattern p = VisitPattern(c->lhs);
  Expr rhs = VisitExpr(c->rhs);
  // Self-return when nothing changed, consistent with the expression visitors.
  if (p.same_as(c->lhs) && rhs.same_as(c->rhs)) {
    return c;
  }
  return ClauseNode::make(p, rhs);
}
221 
/*! \brief Default pattern hook: returns the pattern it was given. */
Pattern ExprMutator::VisitPattern(const Pattern& p) {
  return p;
}
223 
/*! \brief Default type hook: returns the type it was given. */
Type ExprMutator::VisitType(const Type& t) {
  return t;
}
225 
VisitExpr(const Expr & expr)226 void ExprVisitor::VisitExpr(const Expr& expr) {
227   auto it = visit_counter_.find(expr.get());
228   if (it != visit_counter_.end()) {
229     ++it->second;
230   } else {
231     using TParent = ExprFunctor<void(const Expr&)>;
232     TParent::VisitExpr(expr);
233     visit_counter_.insert({expr.get(), 1});
234   }
235 }
236 
VisitExpr_(const VarNode * op)237 void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) {
238   if (op->type_annotation.defined()) {
239     this->VisitType(op->type_annotation);
240   }
241 }
242 
VisitExpr_(const GlobalVarNode * op)243 void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) {
244 }
245 
VisitExpr_(const ConstantNode * op)246 void ExprVisitor::ExprVisitor::VisitExpr_(const ConstantNode* op) {
247 }
248 
VisitExpr_(const TupleNode * op)249 void ExprVisitor::ExprVisitor::VisitExpr_(const TupleNode* op) {
250   for (auto field : op->fields) {
251     this->VisitExpr(field);
252   }
253 }
254 
VisitExpr_(const FunctionNode * op)255 void ExprVisitor::ExprVisitor::VisitExpr_(const FunctionNode* op) {
256   for (auto param : op->params) {
257     this->VisitExpr(param);
258   }
259 
260   this->VisitExpr(op->body);
261 }
262 
VisitExpr_(const CallNode * op)263 void ExprVisitor::VisitExpr_(const CallNode* op) {
264   this->VisitExpr(op->op);
265 
266   for (auto ty_arg : op->type_args) {
267     this->VisitType(ty_arg);
268   }
269 
270   for (auto arg : op->args) {
271     this->VisitExpr(arg);
272   }
273 }
274 
VisitExpr_(const LetNode * op)275 void ExprVisitor::VisitExpr_(const LetNode* op) {
276   this->VisitExpr(op->value);
277   this->VisitExpr(op->var);
278   this->VisitExpr(op->body);
279 }
280 
VisitExpr_(const IfNode * op)281 void ExprVisitor::VisitExpr_(const IfNode* op) {
282   this->VisitExpr(op->cond);
283   this->VisitExpr(op->true_branch);
284   this->VisitExpr(op->false_branch);
285 }
286 
VisitExpr_(const OpNode * op)287 void ExprVisitor::VisitExpr_(const OpNode* op) { return; }
288 
VisitExpr_(const TupleGetItemNode * op)289 void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) {
290   this->VisitExpr(op->tuple);
291 }
292 
VisitExpr_(const RefCreateNode * op)293 void ExprVisitor::ExprVisitor::VisitExpr_(const RefCreateNode* op) {
294   this->VisitExpr(op->value);
295 }
296 
VisitExpr_(const RefReadNode * op)297 void ExprVisitor::ExprVisitor::VisitExpr_(const RefReadNode* op) {
298   this->VisitExpr(op->ref);
299 }
300 
VisitExpr_(const RefWriteNode * op)301 void ExprVisitor::ExprVisitor::VisitExpr_(const RefWriteNode* op) {
302   this->VisitExpr(op->ref);
303   this->VisitExpr(op->value);
304 }
305 
VisitExpr_(const ConstructorNode * op)306 void ExprVisitor::VisitExpr_(const ConstructorNode* op) {
307   for (const Type& t : op->inputs) {
308     this->VisitType(t);
309   }
310   this->VisitType(op->belong_to);
311 }
312 
VisitExpr_(const MatchNode * op)313 void ExprVisitor::VisitExpr_(const MatchNode* op) {
314   this->VisitExpr(op->data);
315   for (const Clause& c : op->clauses) {
316     this->VisitClause(c);
317   }
318 }
319 
VisitClause(const Clause & op)320 void ExprVisitor::VisitClause(const Clause& op) {
321   this->VisitPattern(op->lhs);
322   this->VisitExpr(op->rhs);
323 }
324 
VisitPattern(const Pattern & p)325 void ExprVisitor::VisitPattern(const Pattern& p) { return; }
326 
VisitType(const Type & t)327 void ExprVisitor::VisitType(const Type& t) { return; }
328 
329 // visitor to implement apply
330 class ExprApplyVisit : public ExprVisitor {
331  public:
ExprApplyVisit(std::function<void (const Expr &)> f)332   explicit ExprApplyVisit(std::function<void(const Expr&)> f) : f_(f) {}
333 
VisitExpr(const Expr & e)334   void VisitExpr(const Expr& e) final {
335     if (visited_.count(e.get()) != 0) return;
336     visited_.insert(e.get());
337     ExprVisitor::VisitExpr(e);
338     f_(e);
339   }
340 
341  private:
342   std::function<void(const Expr&)> f_;
343   std::unordered_set<const Node*> visited_;
344 };
345 
PostOrderVisit(const Expr & e,std::function<void (const Expr &)> fvisit)346 void PostOrderVisit(const Expr& e, std::function<void(const Expr&)> fvisit) {
347   ExprApplyVisit(fvisit).VisitExpr(e);
348 }
349 
350 TVM_REGISTER_API("relay._analysis.post_order_visit")
__anon151807450102(Expr expr, PackedFunc f) 351 .set_body_typed<void(Expr, PackedFunc)>([](Expr expr, PackedFunc f) {
352     PostOrderVisit(expr, [f](const Expr& n) {
353         f(n);
354       });
355   });
356 
357 // Implement bind.
358 class ExprBinder : public ExprMutator, PatternMutator {
359  public:
ExprBinder(const tvm::Map<Var,Expr> & args_map)360   explicit ExprBinder(const tvm::Map<Var, Expr>& args_map)
361     : args_map_(args_map) {
362   }
363 
VisitExpr_(const LetNode * op)364   Expr VisitExpr_(const LetNode* op) final {
365     CHECK(!args_map_.count(op->var))
366         << "Cannot bind an internel variable in let";
367     return ExprMutator::VisitExpr_(op);
368   }
369 
VisitExpr_(const FunctionNode * op)370   Expr VisitExpr_(const FunctionNode* op) final {
371     for (Var param : op->params) {
372       CHECK(!args_map_.count(param))
373           << "Cannnot bind an internal function parameter";
374     }
375     return ExprMutator::VisitExpr_(op);
376   }
377 
VisitExpr_(const VarNode * op)378   Expr VisitExpr_(const VarNode* op) final {
379     auto id = GetRef<Var>(op);
380     auto it = args_map_.find(id);
381     if (it != args_map_.end()) {
382       return (*it).second;
383     } else {
384       return std::move(id);
385     }
386   }
387 
VisitPattern(const Pattern & p)388   Pattern VisitPattern(const Pattern& p) final {
389     return PatternMutator::VisitPattern(p);
390   }
391 
VisitClause(const Clause & c)392   Clause VisitClause(const Clause& c) final {
393     Pattern pat = VisitPattern(c->lhs);
394     return ClauseNode::make(pat, VisitExpr(c->rhs));
395   }
396 
VisitVar(const Var & v)397   Var VisitVar(const Var& v) final {
398     CHECK(!args_map_.count(v))
399       << "Cannnot bind an internal pattern variable";
400     return v;
401   }
402 
403  private:
404   const tvm::Map<Var, Expr>& args_map_;
405 };
406 
/*!
 * \brief Substitute variables in expr according to args_map.
 *
 * For a non-function expression the substitution is applied to the whole
 * expression with ExprBinder.
 *
 * For a function, the substitution is applied to the body and every parameter
 * that is a key of args_map is dropped from the parameter list. If that leaves
 * both the body and the parameter count unchanged, expr itself is returned.
 * Otherwise any variable that is free in the rebuilt function but was not free
 * in expr is appended as an extra parameter, so that the result has exactly as
 * many free variables as expr (enforced with CHECK_EQ).
 *
 * \param expr The expression (possibly a function) to bind into.
 * \param args_map The Var -> Expr substitution.
 * \return The bound expression.
 */
Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
  if (const FunctionNode* func = expr.as<FunctionNode>()) {
    Expr new_body = ExprBinder(args_map).VisitExpr(func->body);
    Array<Var> new_params;
    for (Var param : func->params) {
      if (!args_map.count(param)) {
        new_params.push_back(param);
      }
    }
    if (new_body.same_as(func->body) &&
        new_params.size() == func->params.size()) {
      return expr;
    }
    auto ret = FunctionNode::make(new_params,
                                  new_body,
                                  func->ret_type,
                                  func->type_params,
                                  func->attrs);
    // expr is never modified, so its free variables are computed once and
    // reused both for the membership set and for the final size check.
    const auto free_in_expr = FreeVars(expr);
    std::unordered_set<Var, NodeHash, NodeEqual> set;
    for (const auto& v : free_in_expr) {
      set.insert(v);
    }
    // Variables introduced by the substituted expressions become parameters.
    for (const auto& v : FreeVars(ret)) {
      if (set.count(v) == 0) {
        new_params.push_back(v);
      }
    }
    ret = FunctionNode::make(new_params,
                             new_body,
                             func->ret_type,
                             func->type_params,
                             func->attrs);
    CHECK_EQ(free_in_expr.size(), FreeVars(ret).size());
    return std::move(ret);
  } else {
    return ExprBinder(args_map).VisitExpr(expr);
  }
}
445 
446 TVM_REGISTER_API("relay._expr.Bind")
__anon151807450302(TVMArgs args, TVMRetValue* ret) 447 .set_body([](TVMArgs args, TVMRetValue* ret) {
448     NodeRef input = args[0];
449     if (input->IsInstance<ExprNode>()) {
450       *ret = Bind(Downcast<Expr>(input), args[1]);
451     } else {
452       CHECK(input->IsInstance<TypeNode>());
453       *ret = Bind(Downcast<Type>(input), args[1]);
454     }
455   });
456 }  // namespace relay
457 }  // namespace tvm
458