1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file src/relay/ir/pattern_functor.cc
22  * \brief Implementations of visitors and mutators for ADT patterns.
23  */
24 
25 #include <tvm/relay/pattern_functor.h>
26 
27 namespace tvm {
28 namespace relay {
29 
/*!
 * \brief Public entry point of the mutator.
 *
 * Dispatches \p pat through this functor's call operator and hands back
 * whatever pattern the dispatch produces.
 *
 * \param pat The pattern to rewrite.
 * \return The rewritten pattern.
 */
Pattern PatternMutator::Mutate(const Pattern& pat) {
  PatternMutator& self = *this;
  return self(pat);
}
33 
VisitPattern_(const PatternWildcardNode * op)34 Pattern PatternMutator::VisitPattern_(const PatternWildcardNode* op) {
35   return GetRef<Pattern>(op);
36 }
37 
VisitPattern_(const PatternVarNode * op)38 Pattern PatternMutator::VisitPattern_(const PatternVarNode* op) {
39   return PatternVarNode::make(VisitVar(op->var));
40 }
41 
VisitPattern_(const PatternConstructorNode * op)42 Pattern PatternMutator::VisitPattern_(const PatternConstructorNode* op) {
43   std::vector<Pattern> pat;
44   for (const auto& p : op->patterns) {
45     pat.push_back(VisitPattern(p));
46   }
47   return PatternConstructorNode::make(VisitConstructor(op->constructor), pat);
48 }
49 
VisitPattern_(const PatternTupleNode * op)50 Pattern PatternMutator::VisitPattern_(const PatternTupleNode* op) {
51   std::vector<Pattern> pat;
52   for (const auto& p : op->patterns) {
53     pat.push_back(VisitPattern(p));
54   }
55   return PatternTupleNode::make(pat);
56 }
57 
/*!
 * \brief Default type hook: types are left as-is. Subclasses may override
 *        this to rewrite the type annotations of freshly created variables.
 */
Type PatternMutator::VisitType(const Type& t) {
  const Type& unchanged = t;
  return unchanged;
}
61 
/*!
 * \brief Maps a pattern-bound variable to its replacement.
 *
 * The first time \p v is seen, a fresh Var is created with the same name
 * hint and a type annotation passed through VisitType, and it is recorded in
 * var_map_. Later visits of the same \p v return that recorded Var, so every
 * occurrence of one variable is renamed consistently.
 *
 * \param v The original variable.
 * \return The (memoized) replacement variable.
 */
Var PatternMutator::VisitVar(const Var& v) {
  auto it = var_map_.find(v);
  if (it == var_map_.end()) {
    Var fresh = VarNode::make(v->name_hint(), VisitType(v->type_annotation));
    // Re-take the iterator from emplace: VisitType is virtual and could have
    // touched var_map_. If an entry for v already exists, the existing one
    // wins, as with the previous insert-then-at sequence.
    it = var_map_.emplace(v, fresh).first;
  }
  return it->second;
}
70 
/*!
 * \brief Default constructor hook: the constructor is returned unchanged.
 */
Constructor PatternMutator::VisitConstructor(const Constructor& v) {
  const Constructor& unchanged = v;
  return unchanged;
}
74 
VisitPattern_(const PatternWildcardNode * op)75 void PatternVisitor::VisitPattern_(const PatternWildcardNode* op) { }
76 
VisitPattern_(const PatternVarNode * op)77 void PatternVisitor::VisitPattern_(const PatternVarNode* op) {
78   VisitVar(op->var);
79 }
80 
VisitPattern_(const PatternConstructorNode * op)81 void PatternVisitor::VisitPattern_(const PatternConstructorNode* op) {
82   VisitConstructor(op->constructor);
83   for (const auto& p : op->patterns) {
84     VisitPattern(p);
85   }
86 }
87 
VisitPattern_(const PatternTupleNode * op)88 void PatternVisitor::VisitPattern_(const PatternTupleNode* op) {
89   for (const auto& p : op->patterns) {
90     VisitPattern(p);
91   }
92 }
93 
VisitType(const Type & t)94 void PatternVisitor::VisitType(const Type& t) { }
95 
VisitVar(const Var & v)96 void PatternVisitor::VisitVar(const Var& v) {
97   VisitType(v->type_annotation);
98 }
99 
VisitConstructor(const Constructor & c)100 void PatternVisitor::VisitConstructor(const Constructor& c) {
101   for (const auto& inp : c->inputs) {
102     VisitType(inp);
103   }
104 }
105 
106 }  // namespace relay
107 }  // namespace tvm
108