1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file src/tvm/relay/expr_functor.cc
22 * \brief A wrapper around ExprFunctor which functionally updates the AST.
23 *
24 * ExprMutator uses memoization and self return in order to amortize
25 * the cost of using functional updates.
26 */
27 #include <tvm/relay/analysis.h>
28 #include <tvm/relay/expr_functor.h>
29 #include <tvm/relay/pattern_functor.h>
30 #include "type_functor.h"
31
32 namespace tvm {
33 namespace relay {
34
// Memoizing entry point of the mutator.
// Looks `expr` up in memo_ first; on a miss, dispatches through
// ExprFunctor::VisitExpr and records the result in memo_ before returning it.
Expr ExprMutator::VisitExpr(const Expr& expr) {
  auto cached = this->memo_.find(expr);
  if (cached != this->memo_.end()) {
    return cached->second;
  }
  Expr result = ExprFunctor::VisitExpr(expr);
  memo_[expr] = result;
  return result;
}
45
VisitExpr_(const VarNode * op)46 Expr ExprMutator::VisitExpr_(const VarNode* op) {
47 if (op->type_annotation.defined()) {
48 auto type = this->VisitType(op->type_annotation);
49 if (!op->type_annotation.same_as(type)) {
50 return VarNode::make(op->vid, type);
51 }
52 }
53 // default case return self.
54 return GetRef<Expr>(op);
55 }
56
VisitExpr_(const ConstantNode * op)57 Expr ExprMutator::VisitExpr_(const ConstantNode* op) {
58 return GetRef<Expr>(op);
59 }
60
VisitExpr_(const GlobalVarNode * op)61 Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) {
62 return GetRef<Expr>(op);
63 }
64
VisitExpr_(const OpNode * op)65 Expr ExprMutator::VisitExpr_(const OpNode* op) {
66 return GetRef<Expr>(op);
67 }
68
VisitExpr_(const TupleNode * op)69 Expr ExprMutator::VisitExpr_(const TupleNode* op) {
70 tvm::Array<Expr> fields;
71 bool all_fields_unchanged = true;
72 for (auto field : op->fields) {
73 auto new_field = this->Mutate(field);
74 fields.push_back(new_field);
75 all_fields_unchanged &= new_field.same_as(field);
76 }
77
78 if (all_fields_unchanged) {
79 return GetRef<Expr>(op);
80 } else {
81 return TupleNode::make(fields);
82 }
83 }
84
VisitExpr_(const FunctionNode * op)85 Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
86 tvm::Array<TypeVar> ty_params;
87 bool all_ty_params_unchanged = true;
88
89 for (auto ty_param : op->type_params) {
90 TypeVar new_ty_param = Downcast<TypeVar>(VisitType(ty_param));
91 ty_params.push_back(new_ty_param);
92 all_ty_params_unchanged &= new_ty_param.same_as(ty_param);
93 }
94
95 tvm::Array<Var> params;
96 bool all_params_unchanged = true;
97 for (auto param : op->params) {
98 Var new_param = Downcast<Var>(this->Mutate(param));
99 params.push_back(new_param);
100 all_params_unchanged &= param.same_as(new_param);
101 }
102
103 auto ret_type = this->VisitType(op->ret_type);
104 auto body = this->Mutate(op->body);
105
106 if (all_ty_params_unchanged &&
107 all_params_unchanged &&
108 ret_type.same_as(op->ret_type) &&
109 body.same_as(op->body)) {
110 return GetRef<Expr>(op);
111 } else {
112 return FunctionNode::make(params, body, ret_type, ty_params, op->attrs);
113 }
114 }
115
VisitExpr_(const CallNode * call_node)116 Expr ExprMutator::VisitExpr_(const CallNode* call_node) {
117 auto new_op = this->Mutate(call_node->op);
118 bool unchanged = call_node->op.same_as(new_op);
119
120 tvm::Array<Type> ty_args;
121 for (auto ty_arg : call_node->type_args) {
122 auto new_ty_arg = this->VisitType(ty_arg);
123 ty_args.push_back(new_ty_arg);
124 unchanged &= new_ty_arg.same_as(ty_arg);
125 }
126
127 tvm::Array<Expr> call_args;
128 for (auto arg : call_node->args) {
129 auto new_arg = this->Mutate(arg);
130 call_args.push_back(new_arg);
131 unchanged &= new_arg.same_as(arg);
132 }
133
134 if (unchanged) {
135 return GetRef<Expr>(call_node);
136 } else {
137 return CallNode::make(new_op, call_args, call_node->attrs, ty_args);
138 }
139 }
140
VisitExpr_(const LetNode * op)141 Expr ExprMutator::VisitExpr_(const LetNode* op) {
142 Var var = Downcast<Var>(this->Mutate(op->var));
143 auto value = this->Mutate(op->value);
144 auto body = this->Mutate(op->body);
145
146 if (var.same_as(op->var) &&
147 value.same_as(op->value) &&
148 body.same_as(op->body)) {
149 return GetRef<Expr>(op);
150 } else {
151 return LetNode::make(var, value, body);
152 }
153 }
154
VisitExpr_(const IfNode * op)155 Expr ExprMutator::VisitExpr_(const IfNode* op) {
156 auto guard = this->Mutate(op->cond);
157 auto true_b = this->Mutate(op->true_branch);
158 auto false_b = this->Mutate(op->false_branch);
159 if (op->cond.same_as(guard) &&
160 op->true_branch.same_as(true_b) &&
161 op->false_branch.same_as(false_b)) {
162 return GetRef<Expr>(op);;
163 } else {
164 return IfNode::make(guard, true_b, false_b);
165 }
166 }
167
VisitExpr_(const TupleGetItemNode * g)168 Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
169 auto t = this->Mutate(g->tuple);
170 if (g->tuple == t) {
171 return GetRef<Expr>(g);
172 } else {
173 return TupleGetItemNode::make(t, g->index);
174 }
175 }
176
VisitExpr_(const RefCreateNode * op)177 Expr ExprMutator::VisitExpr_(const RefCreateNode* op) {
178 Expr value = this->Mutate(op->value);
179 if (value.same_as(op->value)) {
180 return GetRef<Expr>(op);
181 } else {
182 return RefCreateNode::make(value);
183 }
184 }
185
VisitExpr_(const RefReadNode * op)186 Expr ExprMutator::VisitExpr_(const RefReadNode* op) {
187 Expr ref = this->Mutate(op->ref);
188 if (ref.same_as(op->ref)) {
189 return GetRef<Expr>(op);
190 } else {
191 return RefReadNode::make(ref);
192 }
193 }
194
VisitExpr_(const RefWriteNode * op)195 Expr ExprMutator::VisitExpr_(const RefWriteNode* op) {
196 Expr ref = this->Mutate(op->ref);
197 Expr value = this->Mutate(op->value);
198 if (ref.same_as(op->ref) && value.same_as(op->value)) {
199 return GetRef<Expr>(op);
200 } else {
201 return RefWriteNode::make(ref, value);
202 }
203 }
204
VisitExpr_(const ConstructorNode * c)205 Expr ExprMutator::VisitExpr_(const ConstructorNode* c) {
206 return GetRef<Expr>(c);
207 }
208
VisitExpr_(const MatchNode * m)209 Expr ExprMutator::VisitExpr_(const MatchNode* m) {
210 std::vector<Clause> clauses;
211 for (const Clause& p : m->clauses) {
212 clauses.push_back(VisitClause(p));
213 }
214 return MatchNode::make(VisitExpr(m->data), clauses, m->complete);
215 }
216
// Visits the pattern (via VisitPattern) and then the right-hand side (via
// VisitExpr) of a match clause. The input clause is returned as-is when both
// kept their identity, so an unchanged clause does not allocate a new node.
Clause ExprMutator::VisitClause(const Clause& c) {
  Pattern p = VisitPattern(c->lhs);
  Expr rhs = VisitExpr(c->rhs);
  if (p.same_as(c->lhs) && rhs.same_as(c->rhs)) {
    return c;
  }
  return ClauseNode::make(p, rhs);
}
221
// Default pattern hook: returns the pattern unchanged.
Pattern ExprMutator::VisitPattern(const Pattern& p) {
  return p;
}
223
// Default type hook: returns the type unchanged.
Type ExprMutator::VisitType(const Type& t) {
  return t;
}
225
VisitExpr(const Expr & expr)226 void ExprVisitor::VisitExpr(const Expr& expr) {
227 auto it = visit_counter_.find(expr.get());
228 if (it != visit_counter_.end()) {
229 ++it->second;
230 } else {
231 using TParent = ExprFunctor<void(const Expr&)>;
232 TParent::VisitExpr(expr);
233 visit_counter_.insert({expr.get(), 1});
234 }
235 }
236
VisitExpr_(const VarNode * op)237 void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) {
238 if (op->type_annotation.defined()) {
239 this->VisitType(op->type_annotation);
240 }
241 }
242
VisitExpr_(const GlobalVarNode * op)243 void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) {
244 }
245
VisitExpr_(const ConstantNode * op)246 void ExprVisitor::ExprVisitor::VisitExpr_(const ConstantNode* op) {
247 }
248
VisitExpr_(const TupleNode * op)249 void ExprVisitor::ExprVisitor::VisitExpr_(const TupleNode* op) {
250 for (auto field : op->fields) {
251 this->VisitExpr(field);
252 }
253 }
254
VisitExpr_(const FunctionNode * op)255 void ExprVisitor::ExprVisitor::VisitExpr_(const FunctionNode* op) {
256 for (auto param : op->params) {
257 this->VisitExpr(param);
258 }
259
260 this->VisitExpr(op->body);
261 }
262
VisitExpr_(const CallNode * op)263 void ExprVisitor::VisitExpr_(const CallNode* op) {
264 this->VisitExpr(op->op);
265
266 for (auto ty_arg : op->type_args) {
267 this->VisitType(ty_arg);
268 }
269
270 for (auto arg : op->args) {
271 this->VisitExpr(arg);
272 }
273 }
274
VisitExpr_(const LetNode * op)275 void ExprVisitor::VisitExpr_(const LetNode* op) {
276 this->VisitExpr(op->value);
277 this->VisitExpr(op->var);
278 this->VisitExpr(op->body);
279 }
280
VisitExpr_(const IfNode * op)281 void ExprVisitor::VisitExpr_(const IfNode* op) {
282 this->VisitExpr(op->cond);
283 this->VisitExpr(op->true_branch);
284 this->VisitExpr(op->false_branch);
285 }
286
VisitExpr_(const OpNode * op)287 void ExprVisitor::VisitExpr_(const OpNode* op) { return; }
288
VisitExpr_(const TupleGetItemNode * op)289 void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) {
290 this->VisitExpr(op->tuple);
291 }
292
VisitExpr_(const RefCreateNode * op)293 void ExprVisitor::ExprVisitor::VisitExpr_(const RefCreateNode* op) {
294 this->VisitExpr(op->value);
295 }
296
VisitExpr_(const RefReadNode * op)297 void ExprVisitor::ExprVisitor::VisitExpr_(const RefReadNode* op) {
298 this->VisitExpr(op->ref);
299 }
300
VisitExpr_(const RefWriteNode * op)301 void ExprVisitor::ExprVisitor::VisitExpr_(const RefWriteNode* op) {
302 this->VisitExpr(op->ref);
303 this->VisitExpr(op->value);
304 }
305
VisitExpr_(const ConstructorNode * op)306 void ExprVisitor::VisitExpr_(const ConstructorNode* op) {
307 for (const Type& t : op->inputs) {
308 this->VisitType(t);
309 }
310 this->VisitType(op->belong_to);
311 }
312
VisitExpr_(const MatchNode * op)313 void ExprVisitor::VisitExpr_(const MatchNode* op) {
314 this->VisitExpr(op->data);
315 for (const Clause& c : op->clauses) {
316 this->VisitClause(c);
317 }
318 }
319
VisitClause(const Clause & op)320 void ExprVisitor::VisitClause(const Clause& op) {
321 this->VisitPattern(op->lhs);
322 this->VisitExpr(op->rhs);
323 }
324
VisitPattern(const Pattern & p)325 void ExprVisitor::VisitPattern(const Pattern& p) { return; }
326
VisitType(const Type & t)327 void ExprVisitor::VisitType(const Type& t) { return; }
328
329 // visitor to implement apply
330 class ExprApplyVisit : public ExprVisitor {
331 public:
ExprApplyVisit(std::function<void (const Expr &)> f)332 explicit ExprApplyVisit(std::function<void(const Expr&)> f) : f_(f) {}
333
VisitExpr(const Expr & e)334 void VisitExpr(const Expr& e) final {
335 if (visited_.count(e.get()) != 0) return;
336 visited_.insert(e.get());
337 ExprVisitor::VisitExpr(e);
338 f_(e);
339 }
340
341 private:
342 std::function<void(const Expr&)> f_;
343 std::unordered_set<const Node*> visited_;
344 };
345
PostOrderVisit(const Expr & e,std::function<void (const Expr &)> fvisit)346 void PostOrderVisit(const Expr& e, std::function<void(const Expr&)> fvisit) {
347 ExprApplyVisit(fvisit).VisitExpr(e);
348 }
349
350 TVM_REGISTER_API("relay._analysis.post_order_visit")
__anon151807450102(Expr expr, PackedFunc f) 351 .set_body_typed<void(Expr, PackedFunc)>([](Expr expr, PackedFunc f) {
352 PostOrderVisit(expr, [f](const Expr& n) {
353 f(n);
354 });
355 });
356
357 // Implement bind.
358 class ExprBinder : public ExprMutator, PatternMutator {
359 public:
ExprBinder(const tvm::Map<Var,Expr> & args_map)360 explicit ExprBinder(const tvm::Map<Var, Expr>& args_map)
361 : args_map_(args_map) {
362 }
363
VisitExpr_(const LetNode * op)364 Expr VisitExpr_(const LetNode* op) final {
365 CHECK(!args_map_.count(op->var))
366 << "Cannot bind an internel variable in let";
367 return ExprMutator::VisitExpr_(op);
368 }
369
VisitExpr_(const FunctionNode * op)370 Expr VisitExpr_(const FunctionNode* op) final {
371 for (Var param : op->params) {
372 CHECK(!args_map_.count(param))
373 << "Cannnot bind an internal function parameter";
374 }
375 return ExprMutator::VisitExpr_(op);
376 }
377
VisitExpr_(const VarNode * op)378 Expr VisitExpr_(const VarNode* op) final {
379 auto id = GetRef<Var>(op);
380 auto it = args_map_.find(id);
381 if (it != args_map_.end()) {
382 return (*it).second;
383 } else {
384 return std::move(id);
385 }
386 }
387
VisitPattern(const Pattern & p)388 Pattern VisitPattern(const Pattern& p) final {
389 return PatternMutator::VisitPattern(p);
390 }
391
VisitClause(const Clause & c)392 Clause VisitClause(const Clause& c) final {
393 Pattern pat = VisitPattern(c->lhs);
394 return ClauseNode::make(pat, VisitExpr(c->rhs));
395 }
396
VisitVar(const Var & v)397 Var VisitVar(const Var& v) final {
398 CHECK(!args_map_.count(v))
399 << "Cannnot bind an internal pattern variable";
400 return v;
401 }
402
403 private:
404 const tvm::Map<Var, Expr>& args_map_;
405 };
406
// Binds the variables in `args_map` inside `expr`.
//
// For a non-function expression this simply runs ExprBinder over it.
//
// For a function, the body is rewritten with ExprBinder and the parameters
// that appear in `args_map` are dropped. If nothing changed, `expr` itself is
// returned. Otherwise, any variable that is free in the rewritten function
// but was not free in `expr` is appended to the parameter list, so that the
// result has as many free variables as the input (verified by CHECK_EQ).
Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
  const FunctionNode* func = expr.as<FunctionNode>();
  if (func == nullptr) {
    return ExprBinder(args_map).VisitExpr(expr);
  }

  Expr new_body = ExprBinder(args_map).VisitExpr(func->body);
  Array<Var> new_params;
  for (Var param : func->params) {
    if (!args_map.count(param)) {
      new_params.push_back(param);
    }
  }
  if (new_body.same_as(func->body) &&
      new_params.size() == func->params.size()) {
    return expr;
  }

  // Provisional function, used only to discover which free variables the
  // substitution introduced.
  auto ret = FunctionNode::make(new_params,
                                new_body,
                                func->ret_type,
                                func->type_params,
                                func->attrs);

  // `expr` is not modified, so its free variables are computed once and
  // reused for both the membership set and the final sanity check.
  const auto original_free_vars = FreeVars(expr);
  std::unordered_set<Var, NodeHash, NodeEqual> set;
  for (const auto& v : original_free_vars) {
    set.insert(v);
  }
  for (const auto& v : FreeVars(ret)) {
    if (set.count(v) == 0) {
      new_params.push_back(v);
    }
  }

  ret = FunctionNode::make(new_params,
                           new_body,
                           func->ret_type,
                           func->type_params,
                           func->attrs);
  CHECK_EQ(original_free_vars.size(), FreeVars(ret).size());
  return std::move(ret);
}
445
446 TVM_REGISTER_API("relay._expr.Bind")
__anon151807450302(TVMArgs args, TVMRetValue* ret) 447 .set_body([](TVMArgs args, TVMRetValue* ret) {
448 NodeRef input = args[0];
449 if (input->IsInstance<ExprNode>()) {
450 *ret = Bind(Downcast<Expr>(input), args[1]);
451 } else {
452 CHECK(input->IsInstance<TypeNode>());
453 *ret = Bind(Downcast<Type>(input), args[1]);
454 }
455 });
456 } // namespace relay
457 } // namespace tvm
458