1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file src/tvm/ir/adt.cc
22  * \brief AST nodes for Relay algebraic data types (ADTs).
23  */
24 #include <tvm/relay/type.h>
25 #include <tvm/relay/adt.h>
26 
27 namespace tvm {
28 namespace relay {
29 
make()30 PatternWildcard PatternWildcardNode::make() {
31   NodePtr<PatternWildcardNode> n = make_node<PatternWildcardNode>();
32   return PatternWildcard(n);
33 }
34 
35 TVM_REGISTER_NODE_TYPE(PatternWildcardNode);
36 
37 TVM_REGISTER_API("relay._make.PatternWildcard")
38 .set_body_typed(PatternWildcardNode::make);
39 
40 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon24a4c0ff0102(const ObjectRef& ref, IRPrinter* p) 41 .set_dispatch<PatternWildcardNode>([](const ObjectRef& ref, IRPrinter* p) {
42   p->stream << "PatternWildcardNode()";
43 });
44 
make(tvm::relay::Var var)45 PatternVar PatternVarNode::make(tvm::relay::Var var) {
46   NodePtr<PatternVarNode> n = make_node<PatternVarNode>();
47   n->var = std::move(var);
48   return PatternVar(n);
49 }
50 
51 TVM_REGISTER_NODE_TYPE(PatternVarNode);
52 
53 TVM_REGISTER_API("relay._make.PatternVar")
54 .set_body_typed(PatternVarNode::make);
55 
56 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon24a4c0ff0202(const ObjectRef& ref, IRPrinter* p) 57 .set_dispatch<PatternVarNode>([](const ObjectRef& ref, IRPrinter* p) {
58   auto* node = static_cast<const PatternVarNode*>(ref.get());
59   p->stream << "PatternVarNode(" << node->var << ")";
60 });
61 
make(Constructor constructor,tvm::Array<Pattern> patterns)62 PatternConstructor PatternConstructorNode::make(Constructor constructor,
63                                                 tvm::Array<Pattern> patterns) {
64   NodePtr<PatternConstructorNode> n = make_node<PatternConstructorNode>();
65   n->constructor = std::move(constructor);
66   n->patterns = std::move(patterns);
67   return PatternConstructor(n);
68 }
69 
70 TVM_REGISTER_NODE_TYPE(PatternConstructorNode);
71 
72 TVM_REGISTER_API("relay._make.PatternConstructor")
73 .set_body_typed(PatternConstructorNode::make);
74 
75 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon24a4c0ff0302(const ObjectRef& ref, IRPrinter* p) 76 .set_dispatch<PatternConstructorNode>([](const ObjectRef& ref, IRPrinter* p) {
77   auto* node = static_cast<const PatternConstructorNode*>(ref.get());
78   p->stream << "PatternConstructorNode(" << node->constructor
79             << ", " << node->patterns << ")";
80 });
81 
make(tvm::Array<Pattern> patterns)82 PatternTuple PatternTupleNode::make(tvm::Array<Pattern> patterns) {
83   NodePtr<PatternTupleNode> n = make_node<PatternTupleNode>();
84   n->patterns = std::move(patterns);
85   return PatternTuple(n);
86 }
87 
88 TVM_REGISTER_NODE_TYPE(PatternTupleNode);
89 
90 TVM_REGISTER_API("relay._make.PatternTuple")
91 .set_body_typed(PatternTupleNode::make);
92 
93 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon24a4c0ff0402(const ObjectRef& ref, IRPrinter* p) 94 .set_dispatch<PatternTupleNode>([](const ObjectRef& ref, IRPrinter* p) {
95   auto* node = static_cast<const PatternTupleNode*>(ref.get());
96   p->stream << "PatternTupleNode(" << node->patterns << ")";
97 });
98 
make(std::string name_hint,tvm::Array<Type> inputs,GlobalTypeVar belong_to)99 Constructor ConstructorNode::make(std::string name_hint,
100                                   tvm::Array<Type> inputs,
101                                   GlobalTypeVar belong_to) {
102   NodePtr<ConstructorNode> n = make_node<ConstructorNode>();
103   n->name_hint = std::move(name_hint);
104   n->inputs = std::move(inputs);
105   n->belong_to = std::move(belong_to);
106   return Constructor(n);
107 }
108 
109 TVM_REGISTER_NODE_TYPE(ConstructorNode);
110 
111 TVM_REGISTER_API("relay._make.Constructor")
112 .set_body_typed(ConstructorNode::make);
113 
114 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon24a4c0ff0502(const ObjectRef& ref, IRPrinter* p) 115 .set_dispatch<ConstructorNode>([](const ObjectRef& ref, IRPrinter* p) {
116   auto* node = static_cast<const ConstructorNode*>(ref.get());
117   p->stream << "ConstructorNode(" << node->name_hint << ", "
118             << node->inputs << ", " << node->belong_to << ")";
119 });
120 
make(GlobalTypeVar header,tvm::Array<TypeVar> type_vars,tvm::Array<Constructor> constructors)121 TypeData TypeDataNode::make(GlobalTypeVar header,
122                             tvm::Array<TypeVar> type_vars,
123                             tvm::Array<Constructor> constructors) {
124   NodePtr<TypeDataNode> n = make_node<TypeDataNode>();
125   n->header = std::move(header);
126   n->type_vars = std::move(type_vars);
127   n->constructors = std::move(constructors);
128   return TypeData(n);
129 }
130 
131 TVM_REGISTER_NODE_TYPE(TypeDataNode);
132 
133 TVM_REGISTER_API("relay._make.TypeData")
134 .set_body_typed(TypeDataNode::make);
135 
136 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon24a4c0ff0602(const ObjectRef& ref, IRPrinter* p) 137 .set_dispatch<TypeDataNode>([](const ObjectRef& ref, IRPrinter* p) {
138   auto* node = static_cast<const TypeDataNode*>(ref.get());
139   p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", "
140             << node->constructors << ")";
141 });
142 
make(Pattern lhs,Expr rhs)143 Clause ClauseNode::make(Pattern lhs, Expr rhs) {
144   NodePtr<ClauseNode> n = make_node<ClauseNode>();
145   n->lhs = std::move(lhs);
146   n->rhs = std::move(rhs);
147   return Clause(n);
148 }
149 
150 TVM_REGISTER_NODE_TYPE(ClauseNode);
151 
152 TVM_REGISTER_API("relay._make.Clause")
153 .set_body_typed(ClauseNode::make);
154 
155 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon24a4c0ff0702(const ObjectRef& ref, IRPrinter* p) 156 .set_dispatch<ClauseNode>([](const ObjectRef& ref, IRPrinter* p) {
157     auto* node = static_cast<const ClauseNode*>(ref.get());
158   p->stream << "ClauseNode(" << node->lhs << ", "
159             << node->rhs << ")";
160   });
161 
make(Expr data,tvm::Array<Clause> clauses,bool complete)162 Match MatchNode::make(Expr data, tvm::Array<Clause> clauses, bool complete) {
163   NodePtr<MatchNode> n = make_node<MatchNode>();
164   n->data = std::move(data);
165   n->clauses = std::move(clauses);
166   n->complete = complete;
167   return Match(n);
168 }
169 
170 TVM_REGISTER_NODE_TYPE(MatchNode);
171 
172 TVM_REGISTER_API("relay._make.Match")
173 .set_body_typed(MatchNode::make);
174 
175 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon24a4c0ff0802(const ObjectRef& ref, IRPrinter* p) 176 .set_dispatch<MatchNode>([](const ObjectRef& ref, IRPrinter* p) {
177   auto* node = static_cast<const MatchNode*>(ref.get());
178   p->stream << "MatchNode(" << node->data << ", "
179             << node->clauses << ", " << node->complete << ")";
180 });
181 
182 }  // namespace relay
183 }  // namespace tvm
184