1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file src/tvm/relay/ir/alpha_equal.cc
22  * \brief Alpha equality check by deep comparing two nodes.
23  */
24 #include <tvm/ir_pass.h>
25 #include <tvm/relay/expr_functor.h>
26 #include <tvm/relay/pattern_functor.h>
27 #include <tvm/runtime/ndarray.h>
28 #include <tvm/relay/analysis.h>
29 #include <tvm/relay/op_attr_types.h>
30 #include <tvm/relay/attrs/nn.h>
31 #include "type_functor.h"
32 #include "../../lang/attr_functor.h"
33 namespace tvm {
34 namespace relay {
35 
36 // Alpha Equal handler for Relay.
37 class AlphaEqualHandler:
38       public AttrsEqualHandler,
39       public TypeFunctor<bool(const Type&, const Type&)>,
40       public ExprFunctor<bool(const Expr&, const Expr&)>,
41       public PatternFunctor<bool(const Pattern&, const Pattern&)> {
42  public:
AlphaEqualHandler(bool map_free_var,bool assert_mode)43   explicit AlphaEqualHandler(bool map_free_var, bool assert_mode)
44     : map_free_var_(map_free_var), assert_mode_(assert_mode) { }
45 
46   /*!
47    * Check equality of two nodes.
48    * \param lhs The left hand operand.
49    * \param rhs The right hand operand.
50    * \return The comparison result.
51    */
Equal(const NodeRef & lhs,const NodeRef & rhs)52   bool Equal(const NodeRef& lhs, const NodeRef& rhs) {
53     if (lhs.same_as(rhs)) return true;
54     if (!lhs.defined() || !rhs.defined()) return false;
55     if (lhs->IsInstance<TypeNode>()) {
56       if (!rhs->IsInstance<TypeNode>()) return false;
57       return TypeEqual(Downcast<Type>(lhs), Downcast<Type>(rhs));
58     }
59     if (lhs->IsInstance<ExprNode>()) {
60       if (!rhs->IsInstance<ExprNode>()) return false;
61       return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs));
62     }
63     if (const auto lhsm = lhs.as<ModuleNode>()) {
64       auto rhsm = rhs.as<ModuleNode>();
65       if (!rhsm) return false;
66       if (lhsm->functions.size() != rhsm->functions.size()) return false;
67       for (const auto& p : lhsm->functions) {
68         if (!Equal(p.second, rhsm->Lookup(p.first->name_hint))) return false;
69       }
70       if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
71       for (const auto& p : lhsm->type_definitions) {
72         if (!rhsm->ContainGlobalTypeVar(p.first->var->name_hint) ||
73             !Equal(p.second, rhsm->LookupDef(p.first->var->name_hint))) {
74           return false;
75         }
76       }
77       return true;
78     }
79     return AttrEqual(lhs, rhs);
80   }
81 
DoubleEqual(double l,double r)82   bool DoubleEqual(double l, double r) {
83     return true;
84   }
85   /*!
86    * Check equality of two attributes.
87    * \param lhs The left hand operand.
88    * \param rhs The right hand operand.
89    * \return The comparison result.
90    */
AttrEqual(const NodeRef & lhs,const NodeRef & rhs)91   bool AttrEqual(const NodeRef& lhs, const NodeRef& rhs) {
92     auto compute = [&]() {
93       if (&lhs == &rhs) return true;
94       if (auto lhsd = lhs.as<DictAttrsNode>()) {
95         auto rhsd = rhs.as<DictAttrsNode>();
96         if (!rhsd) return false;
97         if (lhsd->dict.size() != rhsd->dict.size()) return false;
98         for (const auto& k : lhsd->dict) {
99           if (!Equal(k.second, rhsd->dict[k.first])) return false;
100         }
101         return true;
102       }
103       if (auto lhsbn = lhs.as<BatchNormAttrs>()) {
104         auto rhsbn = rhs.as<BatchNormAttrs>();
105         if (!rhsbn) return false;
106         return (lhsbn->axis == rhsbn->axis)
107           && DoubleEqual(lhsbn->epsilon, rhsbn->epsilon)
108           && (lhsbn->center == rhsbn->center)
109           && (lhsbn->scale == rhsbn->scale);
110       }
111       return AttrsEqualHandler::Equal(lhs, rhs);
112     };
113     return Compare(compute(), lhs, rhs);
114   }
115   /*!
116    * Check equality of two types.
117    * \param lhs The left hand operand.
118    * \param rhs The right hand operand.
119    * \return the comparison result.
120    */
TypeEqual(const Type & lhs,const Type & rhs)121   bool TypeEqual(const Type& lhs, const Type& rhs) {
122     auto compute = [&]() {
123       if (lhs.same_as(rhs)) return true;
124       if (!lhs.defined() || !rhs.defined()) return false;
125       return this->VisitType(lhs, rhs);
126     };
127     return Compare(compute(), lhs, rhs);
128   }
129 
Compare(bool result,const NodeRef & lhs,const NodeRef & rhs)130   bool Compare(bool result, const NodeRef& lhs, const NodeRef& rhs) {
131     if (assert_mode_) {
132       CHECK(result) << "\n" << AsText(lhs, true) << "\nis not equal to:\n" << AsText(rhs, true);
133     }
134     return result;
135   }
136   /*!
137    * Check equality of two expressions.
138    *
139    * \note We run graph structural equality checking when comparing two Exprs.
140    *   This means that AlphaEqualHandler can only be used once for each pair.
141    *   The equality checker checks data-flow equvalence of the Expr DAG.
142    *   This function also runs faster as it memomizes equal_map.
143    *
144    * \param lhs The left hand operand.
145    * \param rhs The right hand operand.
146    * \return The comparison result.
147    */
ExprEqual(const Expr & lhs,const Expr & rhs)148   bool ExprEqual(const Expr& lhs, const Expr& rhs) {
149     auto compute = [&]() {
150       if (lhs.same_as(rhs)) return true;
151       if (!lhs.defined() || !rhs.defined()) return false;
152       auto it = equal_map_.find(lhs);
153       if (it != equal_map_.end()) {
154         return it->second.same_as(rhs);
155       }
156       if (this->VisitExpr(lhs, rhs)) {
157         equal_map_[lhs] = rhs;
158         return true;
159       } else {
160         return false;
161       }
162     };
163     return Compare(compute(), lhs, rhs);
164   }
165 
166  protected:
167   /*!
168    * \brief Check if data type equals each other.
169    * \param lhs The left hand operand.
170    * \param rhs The right hand operand.
171    * \return The compare result.
172    */
DataTypeEqual(const DataType & lhs,const DataType & rhs)173   bool DataTypeEqual(const DataType& lhs, const DataType& rhs) {
174     return lhs == rhs;
175   }
176   /*!
177    * \brief Check Equality of leaf node of the graph.
178    *  if map_free_var_ is set to true, try to map via equal node.
179    * \param lhs The left hand operand.
180    * \param rhs The right hand operand.
181    * \return The compare result.
182    */
LeafNodeEqual(const ObjectRef & lhs,const ObjectRef & rhs)183   bool LeafNodeEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
184     if (lhs.same_as(rhs)) return true;
185     auto it = equal_map_.find(lhs);
186     if (it != equal_map_.end()) {
187       return it->second.same_as(rhs);
188     } else {
189       if (map_free_var_) {
190         if (lhs->type_index() != rhs->type_index()) return false;
191         equal_map_[lhs] = rhs;
192         return true;
193       } else {
194         return false;
195       }
196     }
197   }
198   using AttrsEqualHandler::VisitAttr_;
VisitAttr_(const Variable * lhs,const ObjectRef & other)199   bool VisitAttr_(const Variable* lhs, const ObjectRef& other) final {
200     return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
201   }
202 
203   // Type equality
VisitType_(const TensorTypeNode * lhs,const Type & other)204   bool VisitType_(const TensorTypeNode* lhs, const Type& other) final {
205     if (const TensorTypeNode* rhs = other.as<TensorTypeNode>()) {
206       return (lhs->dtype == rhs->dtype &&
207               AttrEqual(lhs->shape, rhs->shape));
208     } else {
209       return false;
210     }
211   }
212 
VisitType_(const IncompleteTypeNode * lhs,const Type & other)213   bool VisitType_(const IncompleteTypeNode* lhs, const Type& other) final {
214     return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
215   }
216 
VisitType_(const TypeVarNode * lhs,const Type & other)217   bool VisitType_(const TypeVarNode* lhs, const Type& other) final {
218     if (const TypeVarNode* rhs = other.as<TypeVarNode>()) {
219       if (lhs->kind != rhs->kind) return false;
220       return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
221     } else {
222       return false;
223     }
224   }
225 
VisitType_(const FuncTypeNode * lhs,const Type & other)226   bool VisitType_(const FuncTypeNode* lhs, const Type& other) final {
227     if (const FuncTypeNode* rhs = other.as<FuncTypeNode>()) {
228       if (lhs->arg_types.size() != rhs->arg_types.size()) return false;
229       if (lhs->type_params.size() != rhs->type_params.size()) return false;
230       if (lhs->type_constraints.size() != rhs->type_constraints.size()) return false;
231       for (size_t i = 0; i < lhs->type_params.size(); ++i) {
232         if (lhs->type_params[i]->kind != rhs->type_params[i]->kind) {
233           return false;
234         }
235         equal_map_[lhs->type_params[i]] = rhs->type_params[i];
236         // set up type parameter equal
237         if (lhs->type_params[i]->kind == Kind::kShapeVar) {
238           // map variable
239           equal_map_[lhs->type_params[i]->var] = rhs->type_params[i]->var;
240         }
241       }
242       for (size_t i = 0; i < lhs->arg_types.size(); i++) {
243         if (!TypeEqual(lhs->arg_types[i], rhs->arg_types[i])) return false;
244       }
245       if (!TypeEqual(lhs->ret_type, rhs->ret_type)) return false;
246       for (size_t i = 0; i < lhs->type_constraints.size(); i++) {
247         if (!TypeEqual(lhs->type_constraints[i],
248                        rhs->type_constraints[i])) {
249           return false;
250         }
251       }
252       return true;
253     } else {
254       return false;
255     }
256   }
257 
VisitType_(const TypeRelationNode * lhs,const Type & other)258   bool VisitType_(const TypeRelationNode* lhs, const Type& other) final {
259     if (const TypeRelationNode* rhs = other.as<TypeRelationNode>()) {
260       if (lhs->func->name != rhs->func->name) return false;
261       if (lhs->num_inputs != rhs->num_inputs) return false;
262       if (!this->AttrEqual(lhs->attrs, rhs->attrs)) return false;
263       if (lhs->args.size() != rhs->args.size()) return false;
264       for (size_t i = 0; i < lhs->args.size(); ++i) {
265         if (!TypeEqual(lhs->args[i], rhs->args[i])) return false;
266       }
267       return true;
268     } else {
269       return false;
270     }
271   }
272 
VisitType_(const TupleTypeNode * lhs,const Type & other)273   bool VisitType_(const TupleTypeNode* lhs, const Type& other) final {
274     if (const TupleTypeNode* rhs = other.as<TupleTypeNode>()) {
275       if (lhs->fields.size() != rhs->fields.size()) return false;
276       for (size_t i = 0; i < lhs->fields.size(); ++i) {
277         if (!TypeEqual(lhs->fields[i], rhs->fields[i])) return false;
278       }
279       return true;
280     } else {
281       return false;
282     }
283   }
284 
VisitType_(const RefTypeNode * lhs,const Type & other)285   bool VisitType_(const RefTypeNode* lhs, const Type& other) final {
286     if (const RefTypeNode* rhs = other.as<RefTypeNode>()) {
287       return TypeEqual(lhs->value, rhs->value);
288     }
289     return false;
290   }
291 
VisitType_(const GlobalTypeVarNode * lhs,const Type & other)292   bool VisitType_(const GlobalTypeVarNode* lhs, const Type& other) final {
293     return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
294   }
295 
VisitType_(const TypeCallNode * lhs,const Type & other)296   bool VisitType_(const TypeCallNode* lhs, const Type& other) final {
297     const TypeCallNode* rhs = other.as<TypeCallNode>();
298     if (rhs == nullptr
299         || lhs->args.size() != rhs->args.size()
300         || !TypeEqual(lhs->func, rhs->func)) {
301       return false;
302     }
303 
304     for (size_t i = 0; i < lhs->args.size(); ++i) {
305       if (!TypeEqual(lhs->args[i], rhs->args[i])) {
306         return false;
307       }
308     }
309     return true;
310   }
311 
VisitType_(const TypeDataNode * lhs,const Type & other)312   bool VisitType_(const TypeDataNode* lhs, const Type& other) final {
313     const TypeDataNode* rhs = other.as<TypeDataNode>();
314     if (rhs == nullptr
315         || lhs->type_vars.size() != rhs->type_vars.size()
316         || !TypeEqual(lhs->header, rhs->header)) {
317       return false;
318     }
319     for (size_t i = 0; i < lhs->type_vars.size(); ++i) {
320       if (!TypeEqual(lhs->type_vars[i], rhs->type_vars[i])) {
321         return false;
322       }
323     }
324     for (size_t i = 0; i < lhs->constructors.size(); ++i) {
325       if (!ExprEqual(lhs->constructors[i], rhs->constructors[i])) {
326         return false;
327       }
328     }
329     return true;
330   }
331 
332   // Expr equal checking.
NDArrayEqual(const runtime::NDArray & lhs,const runtime::NDArray & rhs)333   bool NDArrayEqual(const runtime::NDArray& lhs,
334                     const runtime::NDArray& rhs) {
335     if (lhs.defined() != rhs.defined()) {
336       return false;
337     } else if (lhs.same_as(rhs)) {
338       return true;
339     } else {
340       auto ldt = lhs->dtype;
341       auto rdt = rhs->dtype;
342       CHECK_EQ(lhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
343       CHECK_EQ(rhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
344       if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) {
345         size_t data_size = runtime::GetDataSize(*lhs.operator->());
346         return std::memcmp(lhs->data, rhs->data, data_size) == 0;
347       } else {
348         return false;
349       }
350     }
351   }
352   // merge declaration of two variables together.
MergeVarDecl(const Var & lhs,const Var & rhs)353   bool MergeVarDecl(const Var& lhs, const Var& rhs) {
354     if (lhs.same_as(rhs)) return true;
355     if (!lhs.defined() || !rhs.defined()) return false;
356     if (!TypeEqual(lhs->type_annotation,
357                    rhs->type_annotation)) return false;
358     CHECK(!equal_map_.count(lhs))
359         << "Duplicated declaration of variable " <<  lhs;
360     equal_map_[lhs] = rhs;
361     return true;
362   }
363 
VisitExpr_(const VarNode * lhs,const Expr & other)364   bool VisitExpr_(const VarNode* lhs, const Expr& other) final {
365     // This function will only be triggered if we are matching free variables.
366     if (const VarNode* rhs = other.as<VarNode>()) {
367       if (lhs->name_hint() != rhs->name_hint()) return false;
368       if (!TypeEqual(lhs->type_annotation, rhs->type_annotation)) return false;
369       return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
370     } else {
371       return false;
372     }
373   }
374 
VisitExpr_(const GlobalVarNode * lhs,const Expr & other)375   bool VisitExpr_(const GlobalVarNode* lhs, const Expr& other) final {
376     if (const GlobalVarNode* rhs = other.as<GlobalVarNode>()) {
377       // use name equality for global var for now.
378       return lhs->name_hint == rhs->name_hint;
379     }
380     return false;
381   }
382 
VisitExpr_(const TupleNode * lhs,const Expr & other)383   bool VisitExpr_(const TupleNode* lhs, const Expr& other) final {
384     if (const TupleNode* rhs = other.as<TupleNode>()) {
385       if (lhs->fields.size() != rhs->fields.size()) return false;
386       for (size_t i = 0; i < lhs->fields.size(); ++i) {
387         if (!ExprEqual(lhs->fields[i], rhs->fields[i])) return false;
388       }
389       return true;
390     } else {
391       return false;
392     }
393   }
394 
VisitExpr_(const FunctionNode * lhs,const Expr & other)395   bool VisitExpr_(const FunctionNode* lhs, const Expr& other) final {
396     if (const FunctionNode* rhs = other.as<FunctionNode>()) {
397       if (lhs->params.size() != rhs->params.size()) return false;
398       if (lhs->type_params.size() != rhs->type_params.size()) return false;
399       // map type parameter to be the same
400       for (size_t i = 0; i < lhs->type_params.size(); ++i) {
401         if (lhs->type_params[i]->kind != rhs->type_params[i]->kind) return false;
402         equal_map_[lhs->type_params[i]] = rhs->type_params[i];
403       }
404       // check parameter type annotations
405       for (size_t i = 0; i < lhs->params.size(); ++i) {
406         if (!MergeVarDecl(lhs->params[i], rhs->params[i])) return false;
407       }
408       // check return types.
409       if (!TypeEqual(lhs->ret_type, rhs->ret_type)) return false;
410       if (!AttrEqual(lhs->attrs, rhs->attrs)) return false;
411       return ExprEqual(lhs->body, rhs->body);
412     } else {
413       return false;
414     }
415   }
416 
VisitExpr_(const CallNode * lhs,const Expr & other)417   bool VisitExpr_(const CallNode* lhs, const Expr& other) final {
418     if (const CallNode* rhs = other.as<CallNode>()) {
419       if (!ExprEqual(lhs->op, rhs->op)) return false;
420       if (lhs->args.size() != rhs->args.size()) return false;
421       // skip type_args check for primitive ops.
422       bool is_primitive = IsPrimitiveOp(lhs->op);
423       if (!is_primitive) {
424         if (lhs->type_args.size() != rhs->type_args.size()) {
425           return false;
426         }
427       }
428       for (size_t i = 0; i < lhs->args.size(); ++i) {
429         if (!ExprEqual(lhs->args[i], rhs->args[i])) {
430           return false;
431         }
432       }
433 
434       if (!is_primitive) {
435         for (size_t i = 0; i < lhs->type_args.size(); ++i) {
436           if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false;
437         }
438       }
439       return AttrEqual(lhs->attrs, rhs->attrs);
440     } else {
441       return false;
442     }
443   }
444 
VisitExpr_(const LetNode * lhs,const Expr & other)445   bool VisitExpr_(const LetNode* lhs, const Expr& other) final {
446     if (const LetNode* rhs = other.as<LetNode>()) {
447       if (!MergeVarDecl(lhs->var, rhs->var)) return false;
448       if (!ExprEqual(lhs->value, rhs->value)) return false;
449       return ExprEqual(lhs->body, rhs->body);
450     } else {
451       return false;
452     }
453   }
454 
VisitExpr_(const IfNode * lhs,const Expr & other)455   bool VisitExpr_(const IfNode* lhs, const Expr& other) final {
456     if (const IfNode* rhs = other.as<IfNode>()) {
457       return ExprEqual(lhs->cond, rhs->cond) &&
458           ExprEqual(lhs->true_branch, rhs->true_branch) &&
459           ExprEqual(lhs->false_branch, rhs->false_branch);
460     } else {
461       return false;
462     }
463   }
464 
VisitExpr_(const OpNode * lhs,const Expr & other)465   bool VisitExpr_(const OpNode* lhs, const Expr& other) final {
466     return lhs == other.get();
467   }
468 
VisitExpr_(const ConstantNode * lhs,const Expr & other)469   bool VisitExpr_(const ConstantNode* lhs, const Expr& other) final {
470     if (const ConstantNode* rhs = other.as<ConstantNode>()) {
471       return NDArrayEqual(lhs->data, rhs->data);
472     } else {
473       return false;
474     }
475   }
476 
VisitExpr_(const TupleGetItemNode * lhs,const Expr & other)477   bool VisitExpr_(const TupleGetItemNode* lhs, const Expr& other) final {
478     if (const TupleGetItemNode* rhs = other.as<TupleGetItemNode>()) {
479       return ExprEqual(lhs->tuple, rhs->tuple) && lhs->index == rhs->index;
480     } else {
481       return false;
482     }
483   }
484 
VisitExpr_(const RefCreateNode * lhs,const Expr & other)485   bool VisitExpr_(const RefCreateNode* lhs, const Expr& other) final {
486     if (const RefCreateNode* rhs = other.as<RefCreateNode>()) {
487       return ExprEqual(lhs->value, rhs->value);
488     } else {
489       return false;
490     }
491   }
492 
VisitExpr_(const RefReadNode * lhs,const Expr & other)493   bool VisitExpr_(const RefReadNode* lhs, const Expr& other) final {
494     if (const RefReadNode* rhs = other.as<RefReadNode>()) {
495       return ExprEqual(lhs->ref, rhs->ref);
496     } else {
497       return false;
498     }
499   }
500 
VisitExpr_(const RefWriteNode * lhs,const Expr & other)501   bool VisitExpr_(const RefWriteNode* lhs, const Expr& other) final {
502     if (const RefWriteNode* rhs = other.as<RefWriteNode>()) {
503       return ExprEqual(lhs->ref, rhs->ref) && ExprEqual(lhs->value, rhs->value);
504     } else {
505       return false;
506     }
507   }
508 
VisitExpr_(const ConstructorNode * lhs,const Expr & other)509   bool VisitExpr_(const ConstructorNode* lhs, const Expr& other) final {
510     if (const ConstructorNode* rhs = other.as<ConstructorNode>()) {
511       return lhs->name_hint == rhs->name_hint;
512     }
513     return false;
514   }
515 
ClauseEqual(const Clause & lhs,const Clause & rhs)516   bool ClauseEqual(const Clause& lhs, const Clause& rhs) {
517     return PatternEqual(lhs->lhs, rhs->lhs) && ExprEqual(lhs->rhs, rhs->rhs);
518   }
519 
PatternEqual(const Pattern & lhs,const Pattern & rhs)520   bool PatternEqual(const Pattern& lhs, const Pattern& rhs) {
521     return Compare(VisitPattern(lhs, rhs), lhs, rhs);
522   }
523 
VisitPattern_(const PatternWildcardNode * lhs,const Pattern & other)524   bool VisitPattern_(const PatternWildcardNode* lhs, const Pattern& other) final {
525     return other.as<PatternWildcardNode>();
526   }
527 
VisitPattern_(const PatternVarNode * lhs,const Pattern & other)528   bool VisitPattern_(const PatternVarNode* lhs, const Pattern& other) final {
529     if (const auto* rhs = other.as<PatternVarNode>()) {
530       return MergeVarDecl(lhs->var, rhs->var);
531     }
532     return false;
533   }
534 
VisitPattern_(const PatternConstructorNode * lhs,const Pattern & other)535   bool VisitPattern_(const PatternConstructorNode* lhs, const Pattern& other) final {
536     const auto* rhs = other.as<PatternConstructorNode>();
537     if (rhs == nullptr
538         || !ExprEqual(lhs->constructor, rhs->constructor)
539         || lhs->patterns.size() != rhs->patterns.size()) {
540       return false;
541     }
542 
543     for (size_t i = 0; i < lhs->patterns.size(); i++) {
544       if (!PatternEqual(lhs->patterns[i], rhs->patterns[i])) {
545         return false;
546       }
547     }
548     return true;
549   }
550 
VisitPattern_(const PatternTupleNode * lhs,const Pattern & other)551   bool VisitPattern_(const PatternTupleNode* lhs, const Pattern& other) final {
552     const auto* rhs = other.as<PatternTupleNode>();
553     if (rhs == nullptr
554         || lhs->patterns.size() != rhs->patterns.size()) {
555       return false;
556     }
557 
558     for (size_t i = 0; i < lhs->patterns.size(); i++) {
559       if (!PatternEqual(lhs->patterns[i], rhs->patterns[i])) {
560         return false;
561       }
562     }
563     return true;
564   }
565 
VisitExpr_(const MatchNode * lhs,const Expr & other)566   bool VisitExpr_(const MatchNode* lhs, const Expr& other) final {
567     const MatchNode* rhs = other.as<MatchNode>();
568 
569     if (rhs == nullptr
570         || !ExprEqual(lhs->data, rhs->data)
571         || lhs->clauses.size() != rhs->clauses.size()
572         || lhs->complete != rhs->complete) {
573       return false;
574     }
575 
576     for (size_t i = 0; i < lhs->clauses.size(); ++i) {
577       if (!ClauseEqual(lhs->clauses[i], rhs->clauses[i])) {
578         return false;
579       }
580     }
581     return true;
582   }
583 
584  private:
585   // whether to map open terms.
586   bool map_free_var_;
587   // if in assert mode, must return true, and will throw error otherwise.
588   bool assert_mode_;
589   // renaming of NodeRef to indicate two nodes equals to each other
590   std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> equal_map_;
591 };
592 
AlphaEqual(const Type & lhs,const Type & rhs)593 bool AlphaEqual(const Type& lhs, const Type& rhs) {
594   return AlphaEqualHandler(false, false).TypeEqual(lhs, rhs);
595 }
596 
AlphaEqual(const Expr & lhs,const Expr & rhs)597 bool AlphaEqual(const Expr& lhs, const Expr& rhs) {
598   return AlphaEqualHandler(false, false).ExprEqual(lhs, rhs);
599 }
600 
601 // TODO(@jroesch): move to correct namespace?
602 TVM_REGISTER_API("relay._make._alpha_equal")
__anon41ef03230402(NodeRef a, NodeRef b) 603 .set_body_typed<bool(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
604   return AlphaEqualHandler(false, false).Equal(a, b);
605 });
606 
607 TVM_REGISTER_API("relay._make._assert_alpha_equal")
__anon41ef03230502(NodeRef a, NodeRef b) 608 .set_body_typed<void(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
609   bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b);
610   CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not alpha equal";
611 });
612 
613 TVM_REGISTER_API("relay._make._graph_equal")
__anon41ef03230602(NodeRef a, NodeRef b) 614 .set_body_typed<bool(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
615   return AlphaEqualHandler(true, false).Equal(a, b);
616 });
617 
618 TVM_REGISTER_API("relay._make._assert_graph_equal")
__anon41ef03230702(NodeRef a, NodeRef b) 619 .set_body_typed<void(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
620   bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b);
621   CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal";
622 });
623 
624 }  // namespace relay
625 }  // namespace tvm
626