1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 /*! 21 * \file src/relay/ir/pattern_functor.cc 22 * \brief Implementations of visitors and mutators for ADT patterns. 23 */ 24 25 #include <tvm/relay/pattern_functor.h> 26 27 namespace tvm { 28 namespace relay { 29 Mutate(const Pattern & pat)30Pattern PatternMutator::Mutate(const Pattern& pat) { 31 return (*this)(pat); 32 } 33 VisitPattern_(const PatternWildcardNode * op)34Pattern PatternMutator::VisitPattern_(const PatternWildcardNode* op) { 35 return GetRef<Pattern>(op); 36 } 37 VisitPattern_(const PatternVarNode * op)38Pattern PatternMutator::VisitPattern_(const PatternVarNode* op) { 39 return PatternVarNode::make(VisitVar(op->var)); 40 } 41 VisitPattern_(const PatternConstructorNode * op)42Pattern PatternMutator::VisitPattern_(const PatternConstructorNode* op) { 43 std::vector<Pattern> pat; 44 for (const auto& p : op->patterns) { 45 pat.push_back(VisitPattern(p)); 46 } 47 return PatternConstructorNode::make(VisitConstructor(op->constructor), pat); 48 } 49 VisitPattern_(const PatternTupleNode * op)50Pattern PatternMutator::VisitPattern_(const PatternTupleNode* op) { 51 std::vector<Pattern> pat; 52 for (const auto& p : op->patterns) { 53 pat.push_back(VisitPattern(p)); 54 } 55 return PatternTupleNode::make(pat); 56 } 57 VisitType(const Type & t)58Type PatternMutator::VisitType(const Type& t) { 59 return t; 60 } 61 VisitVar(const Var & v)62Var PatternMutator::VisitVar(const Var& v) { 63 if (var_map_.count(v) == 0) { 64 var_map_.insert(std::pair<Var, Var>(v, 65 VarNode::make(v->name_hint(), 66 VisitType(v->type_annotation)))); 67 } 68 return var_map_.at(v); 69 } 70 VisitConstructor(const Constructor & v)71Constructor PatternMutator::VisitConstructor(const Constructor& v) { 72 return v; 73 } 74 VisitPattern_(const PatternWildcardNode * op)75void PatternVisitor::VisitPattern_(const PatternWildcardNode* op) { } 76 VisitPattern_(const PatternVarNode * op)77void PatternVisitor::VisitPattern_(const PatternVarNode* op) { 78 VisitVar(op->var); 79 } 80 VisitPattern_(const PatternConstructorNode * op)81void PatternVisitor::VisitPattern_(const PatternConstructorNode* op) { 82 VisitConstructor(op->constructor); 83 for (const auto& p : op->patterns) { 84 VisitPattern(p); 85 } 86 } 87 VisitPattern_(const PatternTupleNode * op)88void PatternVisitor::VisitPattern_(const PatternTupleNode* op) { 89 for (const auto& p : op->patterns) { 90 VisitPattern(p); 91 } 92 } 93 VisitType(const Type & t)94void PatternVisitor::VisitType(const Type& t) { } 95 VisitVar(const Var & v)96void PatternVisitor::VisitVar(const Var& v) { 97 VisitType(v->type_annotation); 98 } 99 VisitConstructor(const Constructor & c)100void PatternVisitor::VisitConstructor(const Constructor& c) { 101 for (const auto& inp : c->inputs) { 102 VisitType(inp); 103 } 104 } 105 106 } // namespace relay 107 } // namespace tvm 108