1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file src/tvm/relay/ir/alpha_equal.cc
22 * \brief Alpha equality check by deep comparing two nodes.
23 */
#include <tvm/ir_pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/nn.h>
#include <cmath>
#include <cstring>
#include "type_functor.h"
#include "../../lang/attr_functor.h"
33 namespace tvm {
34 namespace relay {
35
36 // Alpha Equal handler for Relay.
37 class AlphaEqualHandler:
38 public AttrsEqualHandler,
39 public TypeFunctor<bool(const Type&, const Type&)>,
40 public ExprFunctor<bool(const Expr&, const Expr&)>,
41 public PatternFunctor<bool(const Pattern&, const Pattern&)> {
42 public:
AlphaEqualHandler(bool map_free_var,bool assert_mode)43 explicit AlphaEqualHandler(bool map_free_var, bool assert_mode)
44 : map_free_var_(map_free_var), assert_mode_(assert_mode) { }
45
46 /*!
47 * Check equality of two nodes.
48 * \param lhs The left hand operand.
49 * \param rhs The right hand operand.
50 * \return The comparison result.
51 */
Equal(const NodeRef & lhs,const NodeRef & rhs)52 bool Equal(const NodeRef& lhs, const NodeRef& rhs) {
53 if (lhs.same_as(rhs)) return true;
54 if (!lhs.defined() || !rhs.defined()) return false;
55 if (lhs->IsInstance<TypeNode>()) {
56 if (!rhs->IsInstance<TypeNode>()) return false;
57 return TypeEqual(Downcast<Type>(lhs), Downcast<Type>(rhs));
58 }
59 if (lhs->IsInstance<ExprNode>()) {
60 if (!rhs->IsInstance<ExprNode>()) return false;
61 return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs));
62 }
63 if (const auto lhsm = lhs.as<ModuleNode>()) {
64 auto rhsm = rhs.as<ModuleNode>();
65 if (!rhsm) return false;
66 if (lhsm->functions.size() != rhsm->functions.size()) return false;
67 for (const auto& p : lhsm->functions) {
68 if (!Equal(p.second, rhsm->Lookup(p.first->name_hint))) return false;
69 }
70 if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
71 for (const auto& p : lhsm->type_definitions) {
72 if (!rhsm->ContainGlobalTypeVar(p.first->var->name_hint) ||
73 !Equal(p.second, rhsm->LookupDef(p.first->var->name_hint))) {
74 return false;
75 }
76 }
77 return true;
78 }
79 return AttrEqual(lhs, rhs);
80 }
81
DoubleEqual(double l,double r)82 bool DoubleEqual(double l, double r) {
83 return true;
84 }
85 /*!
86 * Check equality of two attributes.
87 * \param lhs The left hand operand.
88 * \param rhs The right hand operand.
89 * \return The comparison result.
90 */
AttrEqual(const NodeRef & lhs,const NodeRef & rhs)91 bool AttrEqual(const NodeRef& lhs, const NodeRef& rhs) {
92 auto compute = [&]() {
93 if (&lhs == &rhs) return true;
94 if (auto lhsd = lhs.as<DictAttrsNode>()) {
95 auto rhsd = rhs.as<DictAttrsNode>();
96 if (!rhsd) return false;
97 if (lhsd->dict.size() != rhsd->dict.size()) return false;
98 for (const auto& k : lhsd->dict) {
99 if (!Equal(k.second, rhsd->dict[k.first])) return false;
100 }
101 return true;
102 }
103 if (auto lhsbn = lhs.as<BatchNormAttrs>()) {
104 auto rhsbn = rhs.as<BatchNormAttrs>();
105 if (!rhsbn) return false;
106 return (lhsbn->axis == rhsbn->axis)
107 && DoubleEqual(lhsbn->epsilon, rhsbn->epsilon)
108 && (lhsbn->center == rhsbn->center)
109 && (lhsbn->scale == rhsbn->scale);
110 }
111 return AttrsEqualHandler::Equal(lhs, rhs);
112 };
113 return Compare(compute(), lhs, rhs);
114 }
115 /*!
116 * Check equality of two types.
117 * \param lhs The left hand operand.
118 * \param rhs The right hand operand.
119 * \return the comparison result.
120 */
TypeEqual(const Type & lhs,const Type & rhs)121 bool TypeEqual(const Type& lhs, const Type& rhs) {
122 auto compute = [&]() {
123 if (lhs.same_as(rhs)) return true;
124 if (!lhs.defined() || !rhs.defined()) return false;
125 return this->VisitType(lhs, rhs);
126 };
127 return Compare(compute(), lhs, rhs);
128 }
129
Compare(bool result,const NodeRef & lhs,const NodeRef & rhs)130 bool Compare(bool result, const NodeRef& lhs, const NodeRef& rhs) {
131 if (assert_mode_) {
132 CHECK(result) << "\n" << AsText(lhs, true) << "\nis not equal to:\n" << AsText(rhs, true);
133 }
134 return result;
135 }
136 /*!
137 * Check equality of two expressions.
138 *
139 * \note We run graph structural equality checking when comparing two Exprs.
140 * This means that AlphaEqualHandler can only be used once for each pair.
141 * The equality checker checks data-flow equvalence of the Expr DAG.
142 * This function also runs faster as it memomizes equal_map.
143 *
144 * \param lhs The left hand operand.
145 * \param rhs The right hand operand.
146 * \return The comparison result.
147 */
ExprEqual(const Expr & lhs,const Expr & rhs)148 bool ExprEqual(const Expr& lhs, const Expr& rhs) {
149 auto compute = [&]() {
150 if (lhs.same_as(rhs)) return true;
151 if (!lhs.defined() || !rhs.defined()) return false;
152 auto it = equal_map_.find(lhs);
153 if (it != equal_map_.end()) {
154 return it->second.same_as(rhs);
155 }
156 if (this->VisitExpr(lhs, rhs)) {
157 equal_map_[lhs] = rhs;
158 return true;
159 } else {
160 return false;
161 }
162 };
163 return Compare(compute(), lhs, rhs);
164 }
165
166 protected:
167 /*!
168 * \brief Check if data type equals each other.
169 * \param lhs The left hand operand.
170 * \param rhs The right hand operand.
171 * \return The compare result.
172 */
DataTypeEqual(const DataType & lhs,const DataType & rhs)173 bool DataTypeEqual(const DataType& lhs, const DataType& rhs) {
174 return lhs == rhs;
175 }
176 /*!
177 * \brief Check Equality of leaf node of the graph.
178 * if map_free_var_ is set to true, try to map via equal node.
179 * \param lhs The left hand operand.
180 * \param rhs The right hand operand.
181 * \return The compare result.
182 */
LeafNodeEqual(const ObjectRef & lhs,const ObjectRef & rhs)183 bool LeafNodeEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
184 if (lhs.same_as(rhs)) return true;
185 auto it = equal_map_.find(lhs);
186 if (it != equal_map_.end()) {
187 return it->second.same_as(rhs);
188 } else {
189 if (map_free_var_) {
190 if (lhs->type_index() != rhs->type_index()) return false;
191 equal_map_[lhs] = rhs;
192 return true;
193 } else {
194 return false;
195 }
196 }
197 }
198 using AttrsEqualHandler::VisitAttr_;
VisitAttr_(const Variable * lhs,const ObjectRef & other)199 bool VisitAttr_(const Variable* lhs, const ObjectRef& other) final {
200 return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
201 }
202
203 // Type equality
VisitType_(const TensorTypeNode * lhs,const Type & other)204 bool VisitType_(const TensorTypeNode* lhs, const Type& other) final {
205 if (const TensorTypeNode* rhs = other.as<TensorTypeNode>()) {
206 return (lhs->dtype == rhs->dtype &&
207 AttrEqual(lhs->shape, rhs->shape));
208 } else {
209 return false;
210 }
211 }
212
VisitType_(const IncompleteTypeNode * lhs,const Type & other)213 bool VisitType_(const IncompleteTypeNode* lhs, const Type& other) final {
214 return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
215 }
216
VisitType_(const TypeVarNode * lhs,const Type & other)217 bool VisitType_(const TypeVarNode* lhs, const Type& other) final {
218 if (const TypeVarNode* rhs = other.as<TypeVarNode>()) {
219 if (lhs->kind != rhs->kind) return false;
220 return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
221 } else {
222 return false;
223 }
224 }
225
VisitType_(const FuncTypeNode * lhs,const Type & other)226 bool VisitType_(const FuncTypeNode* lhs, const Type& other) final {
227 if (const FuncTypeNode* rhs = other.as<FuncTypeNode>()) {
228 if (lhs->arg_types.size() != rhs->arg_types.size()) return false;
229 if (lhs->type_params.size() != rhs->type_params.size()) return false;
230 if (lhs->type_constraints.size() != rhs->type_constraints.size()) return false;
231 for (size_t i = 0; i < lhs->type_params.size(); ++i) {
232 if (lhs->type_params[i]->kind != rhs->type_params[i]->kind) {
233 return false;
234 }
235 equal_map_[lhs->type_params[i]] = rhs->type_params[i];
236 // set up type parameter equal
237 if (lhs->type_params[i]->kind == Kind::kShapeVar) {
238 // map variable
239 equal_map_[lhs->type_params[i]->var] = rhs->type_params[i]->var;
240 }
241 }
242 for (size_t i = 0; i < lhs->arg_types.size(); i++) {
243 if (!TypeEqual(lhs->arg_types[i], rhs->arg_types[i])) return false;
244 }
245 if (!TypeEqual(lhs->ret_type, rhs->ret_type)) return false;
246 for (size_t i = 0; i < lhs->type_constraints.size(); i++) {
247 if (!TypeEqual(lhs->type_constraints[i],
248 rhs->type_constraints[i])) {
249 return false;
250 }
251 }
252 return true;
253 } else {
254 return false;
255 }
256 }
257
VisitType_(const TypeRelationNode * lhs,const Type & other)258 bool VisitType_(const TypeRelationNode* lhs, const Type& other) final {
259 if (const TypeRelationNode* rhs = other.as<TypeRelationNode>()) {
260 if (lhs->func->name != rhs->func->name) return false;
261 if (lhs->num_inputs != rhs->num_inputs) return false;
262 if (!this->AttrEqual(lhs->attrs, rhs->attrs)) return false;
263 if (lhs->args.size() != rhs->args.size()) return false;
264 for (size_t i = 0; i < lhs->args.size(); ++i) {
265 if (!TypeEqual(lhs->args[i], rhs->args[i])) return false;
266 }
267 return true;
268 } else {
269 return false;
270 }
271 }
272
VisitType_(const TupleTypeNode * lhs,const Type & other)273 bool VisitType_(const TupleTypeNode* lhs, const Type& other) final {
274 if (const TupleTypeNode* rhs = other.as<TupleTypeNode>()) {
275 if (lhs->fields.size() != rhs->fields.size()) return false;
276 for (size_t i = 0; i < lhs->fields.size(); ++i) {
277 if (!TypeEqual(lhs->fields[i], rhs->fields[i])) return false;
278 }
279 return true;
280 } else {
281 return false;
282 }
283 }
284
VisitType_(const RefTypeNode * lhs,const Type & other)285 bool VisitType_(const RefTypeNode* lhs, const Type& other) final {
286 if (const RefTypeNode* rhs = other.as<RefTypeNode>()) {
287 return TypeEqual(lhs->value, rhs->value);
288 }
289 return false;
290 }
291
VisitType_(const GlobalTypeVarNode * lhs,const Type & other)292 bool VisitType_(const GlobalTypeVarNode* lhs, const Type& other) final {
293 return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
294 }
295
VisitType_(const TypeCallNode * lhs,const Type & other)296 bool VisitType_(const TypeCallNode* lhs, const Type& other) final {
297 const TypeCallNode* rhs = other.as<TypeCallNode>();
298 if (rhs == nullptr
299 || lhs->args.size() != rhs->args.size()
300 || !TypeEqual(lhs->func, rhs->func)) {
301 return false;
302 }
303
304 for (size_t i = 0; i < lhs->args.size(); ++i) {
305 if (!TypeEqual(lhs->args[i], rhs->args[i])) {
306 return false;
307 }
308 }
309 return true;
310 }
311
VisitType_(const TypeDataNode * lhs,const Type & other)312 bool VisitType_(const TypeDataNode* lhs, const Type& other) final {
313 const TypeDataNode* rhs = other.as<TypeDataNode>();
314 if (rhs == nullptr
315 || lhs->type_vars.size() != rhs->type_vars.size()
316 || !TypeEqual(lhs->header, rhs->header)) {
317 return false;
318 }
319 for (size_t i = 0; i < lhs->type_vars.size(); ++i) {
320 if (!TypeEqual(lhs->type_vars[i], rhs->type_vars[i])) {
321 return false;
322 }
323 }
324 for (size_t i = 0; i < lhs->constructors.size(); ++i) {
325 if (!ExprEqual(lhs->constructors[i], rhs->constructors[i])) {
326 return false;
327 }
328 }
329 return true;
330 }
331
332 // Expr equal checking.
NDArrayEqual(const runtime::NDArray & lhs,const runtime::NDArray & rhs)333 bool NDArrayEqual(const runtime::NDArray& lhs,
334 const runtime::NDArray& rhs) {
335 if (lhs.defined() != rhs.defined()) {
336 return false;
337 } else if (lhs.same_as(rhs)) {
338 return true;
339 } else {
340 auto ldt = lhs->dtype;
341 auto rdt = rhs->dtype;
342 CHECK_EQ(lhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
343 CHECK_EQ(rhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
344 if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) {
345 size_t data_size = runtime::GetDataSize(*lhs.operator->());
346 return std::memcmp(lhs->data, rhs->data, data_size) == 0;
347 } else {
348 return false;
349 }
350 }
351 }
352 // merge declaration of two variables together.
MergeVarDecl(const Var & lhs,const Var & rhs)353 bool MergeVarDecl(const Var& lhs, const Var& rhs) {
354 if (lhs.same_as(rhs)) return true;
355 if (!lhs.defined() || !rhs.defined()) return false;
356 if (!TypeEqual(lhs->type_annotation,
357 rhs->type_annotation)) return false;
358 CHECK(!equal_map_.count(lhs))
359 << "Duplicated declaration of variable " << lhs;
360 equal_map_[lhs] = rhs;
361 return true;
362 }
363
VisitExpr_(const VarNode * lhs,const Expr & other)364 bool VisitExpr_(const VarNode* lhs, const Expr& other) final {
365 // This function will only be triggered if we are matching free variables.
366 if (const VarNode* rhs = other.as<VarNode>()) {
367 if (lhs->name_hint() != rhs->name_hint()) return false;
368 if (!TypeEqual(lhs->type_annotation, rhs->type_annotation)) return false;
369 return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
370 } else {
371 return false;
372 }
373 }
374
VisitExpr_(const GlobalVarNode * lhs,const Expr & other)375 bool VisitExpr_(const GlobalVarNode* lhs, const Expr& other) final {
376 if (const GlobalVarNode* rhs = other.as<GlobalVarNode>()) {
377 // use name equality for global var for now.
378 return lhs->name_hint == rhs->name_hint;
379 }
380 return false;
381 }
382
VisitExpr_(const TupleNode * lhs,const Expr & other)383 bool VisitExpr_(const TupleNode* lhs, const Expr& other) final {
384 if (const TupleNode* rhs = other.as<TupleNode>()) {
385 if (lhs->fields.size() != rhs->fields.size()) return false;
386 for (size_t i = 0; i < lhs->fields.size(); ++i) {
387 if (!ExprEqual(lhs->fields[i], rhs->fields[i])) return false;
388 }
389 return true;
390 } else {
391 return false;
392 }
393 }
394
VisitExpr_(const FunctionNode * lhs,const Expr & other)395 bool VisitExpr_(const FunctionNode* lhs, const Expr& other) final {
396 if (const FunctionNode* rhs = other.as<FunctionNode>()) {
397 if (lhs->params.size() != rhs->params.size()) return false;
398 if (lhs->type_params.size() != rhs->type_params.size()) return false;
399 // map type parameter to be the same
400 for (size_t i = 0; i < lhs->type_params.size(); ++i) {
401 if (lhs->type_params[i]->kind != rhs->type_params[i]->kind) return false;
402 equal_map_[lhs->type_params[i]] = rhs->type_params[i];
403 }
404 // check parameter type annotations
405 for (size_t i = 0; i < lhs->params.size(); ++i) {
406 if (!MergeVarDecl(lhs->params[i], rhs->params[i])) return false;
407 }
408 // check return types.
409 if (!TypeEqual(lhs->ret_type, rhs->ret_type)) return false;
410 if (!AttrEqual(lhs->attrs, rhs->attrs)) return false;
411 return ExprEqual(lhs->body, rhs->body);
412 } else {
413 return false;
414 }
415 }
416
VisitExpr_(const CallNode * lhs,const Expr & other)417 bool VisitExpr_(const CallNode* lhs, const Expr& other) final {
418 if (const CallNode* rhs = other.as<CallNode>()) {
419 if (!ExprEqual(lhs->op, rhs->op)) return false;
420 if (lhs->args.size() != rhs->args.size()) return false;
421 // skip type_args check for primitive ops.
422 bool is_primitive = IsPrimitiveOp(lhs->op);
423 if (!is_primitive) {
424 if (lhs->type_args.size() != rhs->type_args.size()) {
425 return false;
426 }
427 }
428 for (size_t i = 0; i < lhs->args.size(); ++i) {
429 if (!ExprEqual(lhs->args[i], rhs->args[i])) {
430 return false;
431 }
432 }
433
434 if (!is_primitive) {
435 for (size_t i = 0; i < lhs->type_args.size(); ++i) {
436 if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false;
437 }
438 }
439 return AttrEqual(lhs->attrs, rhs->attrs);
440 } else {
441 return false;
442 }
443 }
444
VisitExpr_(const LetNode * lhs,const Expr & other)445 bool VisitExpr_(const LetNode* lhs, const Expr& other) final {
446 if (const LetNode* rhs = other.as<LetNode>()) {
447 if (!MergeVarDecl(lhs->var, rhs->var)) return false;
448 if (!ExprEqual(lhs->value, rhs->value)) return false;
449 return ExprEqual(lhs->body, rhs->body);
450 } else {
451 return false;
452 }
453 }
454
VisitExpr_(const IfNode * lhs,const Expr & other)455 bool VisitExpr_(const IfNode* lhs, const Expr& other) final {
456 if (const IfNode* rhs = other.as<IfNode>()) {
457 return ExprEqual(lhs->cond, rhs->cond) &&
458 ExprEqual(lhs->true_branch, rhs->true_branch) &&
459 ExprEqual(lhs->false_branch, rhs->false_branch);
460 } else {
461 return false;
462 }
463 }
464
VisitExpr_(const OpNode * lhs,const Expr & other)465 bool VisitExpr_(const OpNode* lhs, const Expr& other) final {
466 return lhs == other.get();
467 }
468
VisitExpr_(const ConstantNode * lhs,const Expr & other)469 bool VisitExpr_(const ConstantNode* lhs, const Expr& other) final {
470 if (const ConstantNode* rhs = other.as<ConstantNode>()) {
471 return NDArrayEqual(lhs->data, rhs->data);
472 } else {
473 return false;
474 }
475 }
476
VisitExpr_(const TupleGetItemNode * lhs,const Expr & other)477 bool VisitExpr_(const TupleGetItemNode* lhs, const Expr& other) final {
478 if (const TupleGetItemNode* rhs = other.as<TupleGetItemNode>()) {
479 return ExprEqual(lhs->tuple, rhs->tuple) && lhs->index == rhs->index;
480 } else {
481 return false;
482 }
483 }
484
VisitExpr_(const RefCreateNode * lhs,const Expr & other)485 bool VisitExpr_(const RefCreateNode* lhs, const Expr& other) final {
486 if (const RefCreateNode* rhs = other.as<RefCreateNode>()) {
487 return ExprEqual(lhs->value, rhs->value);
488 } else {
489 return false;
490 }
491 }
492
VisitExpr_(const RefReadNode * lhs,const Expr & other)493 bool VisitExpr_(const RefReadNode* lhs, const Expr& other) final {
494 if (const RefReadNode* rhs = other.as<RefReadNode>()) {
495 return ExprEqual(lhs->ref, rhs->ref);
496 } else {
497 return false;
498 }
499 }
500
VisitExpr_(const RefWriteNode * lhs,const Expr & other)501 bool VisitExpr_(const RefWriteNode* lhs, const Expr& other) final {
502 if (const RefWriteNode* rhs = other.as<RefWriteNode>()) {
503 return ExprEqual(lhs->ref, rhs->ref) && ExprEqual(lhs->value, rhs->value);
504 } else {
505 return false;
506 }
507 }
508
VisitExpr_(const ConstructorNode * lhs,const Expr & other)509 bool VisitExpr_(const ConstructorNode* lhs, const Expr& other) final {
510 if (const ConstructorNode* rhs = other.as<ConstructorNode>()) {
511 return lhs->name_hint == rhs->name_hint;
512 }
513 return false;
514 }
515
ClauseEqual(const Clause & lhs,const Clause & rhs)516 bool ClauseEqual(const Clause& lhs, const Clause& rhs) {
517 return PatternEqual(lhs->lhs, rhs->lhs) && ExprEqual(lhs->rhs, rhs->rhs);
518 }
519
PatternEqual(const Pattern & lhs,const Pattern & rhs)520 bool PatternEqual(const Pattern& lhs, const Pattern& rhs) {
521 return Compare(VisitPattern(lhs, rhs), lhs, rhs);
522 }
523
VisitPattern_(const PatternWildcardNode * lhs,const Pattern & other)524 bool VisitPattern_(const PatternWildcardNode* lhs, const Pattern& other) final {
525 return other.as<PatternWildcardNode>();
526 }
527
VisitPattern_(const PatternVarNode * lhs,const Pattern & other)528 bool VisitPattern_(const PatternVarNode* lhs, const Pattern& other) final {
529 if (const auto* rhs = other.as<PatternVarNode>()) {
530 return MergeVarDecl(lhs->var, rhs->var);
531 }
532 return false;
533 }
534
VisitPattern_(const PatternConstructorNode * lhs,const Pattern & other)535 bool VisitPattern_(const PatternConstructorNode* lhs, const Pattern& other) final {
536 const auto* rhs = other.as<PatternConstructorNode>();
537 if (rhs == nullptr
538 || !ExprEqual(lhs->constructor, rhs->constructor)
539 || lhs->patterns.size() != rhs->patterns.size()) {
540 return false;
541 }
542
543 for (size_t i = 0; i < lhs->patterns.size(); i++) {
544 if (!PatternEqual(lhs->patterns[i], rhs->patterns[i])) {
545 return false;
546 }
547 }
548 return true;
549 }
550
VisitPattern_(const PatternTupleNode * lhs,const Pattern & other)551 bool VisitPattern_(const PatternTupleNode* lhs, const Pattern& other) final {
552 const auto* rhs = other.as<PatternTupleNode>();
553 if (rhs == nullptr
554 || lhs->patterns.size() != rhs->patterns.size()) {
555 return false;
556 }
557
558 for (size_t i = 0; i < lhs->patterns.size(); i++) {
559 if (!PatternEqual(lhs->patterns[i], rhs->patterns[i])) {
560 return false;
561 }
562 }
563 return true;
564 }
565
VisitExpr_(const MatchNode * lhs,const Expr & other)566 bool VisitExpr_(const MatchNode* lhs, const Expr& other) final {
567 const MatchNode* rhs = other.as<MatchNode>();
568
569 if (rhs == nullptr
570 || !ExprEqual(lhs->data, rhs->data)
571 || lhs->clauses.size() != rhs->clauses.size()
572 || lhs->complete != rhs->complete) {
573 return false;
574 }
575
576 for (size_t i = 0; i < lhs->clauses.size(); ++i) {
577 if (!ClauseEqual(lhs->clauses[i], rhs->clauses[i])) {
578 return false;
579 }
580 }
581 return true;
582 }
583
584 private:
585 // whether to map open terms.
586 bool map_free_var_;
587 // if in assert mode, must return true, and will throw error otherwise.
588 bool assert_mode_;
589 // renaming of NodeRef to indicate two nodes equals to each other
590 std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> equal_map_;
591 };
592
AlphaEqual(const Type & lhs,const Type & rhs)593 bool AlphaEqual(const Type& lhs, const Type& rhs) {
594 return AlphaEqualHandler(false, false).TypeEqual(lhs, rhs);
595 }
596
AlphaEqual(const Expr & lhs,const Expr & rhs)597 bool AlphaEqual(const Expr& lhs, const Expr& rhs) {
598 return AlphaEqualHandler(false, false).ExprEqual(lhs, rhs);
599 }
600
601 // TODO(@jroesch): move to correct namespace?
602 TVM_REGISTER_API("relay._make._alpha_equal")
__anon41ef03230402(NodeRef a, NodeRef b) 603 .set_body_typed<bool(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
604 return AlphaEqualHandler(false, false).Equal(a, b);
605 });
606
607 TVM_REGISTER_API("relay._make._assert_alpha_equal")
__anon41ef03230502(NodeRef a, NodeRef b) 608 .set_body_typed<void(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
609 bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b);
610 CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not alpha equal";
611 });
612
613 TVM_REGISTER_API("relay._make._graph_equal")
__anon41ef03230602(NodeRef a, NodeRef b) 614 .set_body_typed<bool(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
615 return AlphaEqualHandler(true, false).Equal(a, b);
616 });
617
618 TVM_REGISTER_API("relay._make._assert_graph_equal")
__anon41ef03230702(NodeRef a, NodeRef b) 619 .set_body_typed<void(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
620 bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b);
621 CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal";
622 });
623
624 } // namespace relay
625 } // namespace tvm
626