1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file src/tvm/ir/adt.cc
22 * \brief AST nodes for Relay algebraic data types (ADTs).
23 */
24 #include <tvm/relay/type.h>
25 #include <tvm/relay/adt.h>
26
27 namespace tvm {
28 namespace relay {
29
make()30 PatternWildcard PatternWildcardNode::make() {
31 NodePtr<PatternWildcardNode> n = make_node<PatternWildcardNode>();
32 return PatternWildcard(n);
33 }
34
35 TVM_REGISTER_NODE_TYPE(PatternWildcardNode);
36
37 TVM_REGISTER_API("relay._make.PatternWildcard")
38 .set_body_typed(PatternWildcardNode::make);
39
40 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon24a4c0ff0102(const ObjectRef& ref, IRPrinter* p) 41 .set_dispatch<PatternWildcardNode>([](const ObjectRef& ref, IRPrinter* p) {
42 p->stream << "PatternWildcardNode()";
43 });
44
make(tvm::relay::Var var)45 PatternVar PatternVarNode::make(tvm::relay::Var var) {
46 NodePtr<PatternVarNode> n = make_node<PatternVarNode>();
47 n->var = std::move(var);
48 return PatternVar(n);
49 }
50
51 TVM_REGISTER_NODE_TYPE(PatternVarNode);
52
53 TVM_REGISTER_API("relay._make.PatternVar")
54 .set_body_typed(PatternVarNode::make);
55
56 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon24a4c0ff0202(const ObjectRef& ref, IRPrinter* p) 57 .set_dispatch<PatternVarNode>([](const ObjectRef& ref, IRPrinter* p) {
58 auto* node = static_cast<const PatternVarNode*>(ref.get());
59 p->stream << "PatternVarNode(" << node->var << ")";
60 });
61
make(Constructor constructor,tvm::Array<Pattern> patterns)62 PatternConstructor PatternConstructorNode::make(Constructor constructor,
63 tvm::Array<Pattern> patterns) {
64 NodePtr<PatternConstructorNode> n = make_node<PatternConstructorNode>();
65 n->constructor = std::move(constructor);
66 n->patterns = std::move(patterns);
67 return PatternConstructor(n);
68 }
69
70 TVM_REGISTER_NODE_TYPE(PatternConstructorNode);
71
72 TVM_REGISTER_API("relay._make.PatternConstructor")
73 .set_body_typed(PatternConstructorNode::make);
74
75 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon24a4c0ff0302(const ObjectRef& ref, IRPrinter* p) 76 .set_dispatch<PatternConstructorNode>([](const ObjectRef& ref, IRPrinter* p) {
77 auto* node = static_cast<const PatternConstructorNode*>(ref.get());
78 p->stream << "PatternConstructorNode(" << node->constructor
79 << ", " << node->patterns << ")";
80 });
81
make(tvm::Array<Pattern> patterns)82 PatternTuple PatternTupleNode::make(tvm::Array<Pattern> patterns) {
83 NodePtr<PatternTupleNode> n = make_node<PatternTupleNode>();
84 n->patterns = std::move(patterns);
85 return PatternTuple(n);
86 }
87
88 TVM_REGISTER_NODE_TYPE(PatternTupleNode);
89
90 TVM_REGISTER_API("relay._make.PatternTuple")
91 .set_body_typed(PatternTupleNode::make);
92
93 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon24a4c0ff0402(const ObjectRef& ref, IRPrinter* p) 94 .set_dispatch<PatternTupleNode>([](const ObjectRef& ref, IRPrinter* p) {
95 auto* node = static_cast<const PatternTupleNode*>(ref.get());
96 p->stream << "PatternTupleNode(" << node->patterns << ")";
97 });
98
make(std::string name_hint,tvm::Array<Type> inputs,GlobalTypeVar belong_to)99 Constructor ConstructorNode::make(std::string name_hint,
100 tvm::Array<Type> inputs,
101 GlobalTypeVar belong_to) {
102 NodePtr<ConstructorNode> n = make_node<ConstructorNode>();
103 n->name_hint = std::move(name_hint);
104 n->inputs = std::move(inputs);
105 n->belong_to = std::move(belong_to);
106 return Constructor(n);
107 }
108
109 TVM_REGISTER_NODE_TYPE(ConstructorNode);
110
111 TVM_REGISTER_API("relay._make.Constructor")
112 .set_body_typed(ConstructorNode::make);
113
114 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon24a4c0ff0502(const ObjectRef& ref, IRPrinter* p) 115 .set_dispatch<ConstructorNode>([](const ObjectRef& ref, IRPrinter* p) {
116 auto* node = static_cast<const ConstructorNode*>(ref.get());
117 p->stream << "ConstructorNode(" << node->name_hint << ", "
118 << node->inputs << ", " << node->belong_to << ")";
119 });
120
make(GlobalTypeVar header,tvm::Array<TypeVar> type_vars,tvm::Array<Constructor> constructors)121 TypeData TypeDataNode::make(GlobalTypeVar header,
122 tvm::Array<TypeVar> type_vars,
123 tvm::Array<Constructor> constructors) {
124 NodePtr<TypeDataNode> n = make_node<TypeDataNode>();
125 n->header = std::move(header);
126 n->type_vars = std::move(type_vars);
127 n->constructors = std::move(constructors);
128 return TypeData(n);
129 }
130
131 TVM_REGISTER_NODE_TYPE(TypeDataNode);
132
133 TVM_REGISTER_API("relay._make.TypeData")
134 .set_body_typed(TypeDataNode::make);
135
136 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon24a4c0ff0602(const ObjectRef& ref, IRPrinter* p) 137 .set_dispatch<TypeDataNode>([](const ObjectRef& ref, IRPrinter* p) {
138 auto* node = static_cast<const TypeDataNode*>(ref.get());
139 p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", "
140 << node->constructors << ")";
141 });
142
make(Pattern lhs,Expr rhs)143 Clause ClauseNode::make(Pattern lhs, Expr rhs) {
144 NodePtr<ClauseNode> n = make_node<ClauseNode>();
145 n->lhs = std::move(lhs);
146 n->rhs = std::move(rhs);
147 return Clause(n);
148 }
149
150 TVM_REGISTER_NODE_TYPE(ClauseNode);
151
152 TVM_REGISTER_API("relay._make.Clause")
153 .set_body_typed(ClauseNode::make);
154
155 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon24a4c0ff0702(const ObjectRef& ref, IRPrinter* p) 156 .set_dispatch<ClauseNode>([](const ObjectRef& ref, IRPrinter* p) {
157 auto* node = static_cast<const ClauseNode*>(ref.get());
158 p->stream << "ClauseNode(" << node->lhs << ", "
159 << node->rhs << ")";
160 });
161
make(Expr data,tvm::Array<Clause> clauses,bool complete)162 Match MatchNode::make(Expr data, tvm::Array<Clause> clauses, bool complete) {
163 NodePtr<MatchNode> n = make_node<MatchNode>();
164 n->data = std::move(data);
165 n->clauses = std::move(clauses);
166 n->complete = complete;
167 return Match(n);
168 }
169
170 TVM_REGISTER_NODE_TYPE(MatchNode);
171
172 TVM_REGISTER_API("relay._make.Match")
173 .set_body_typed(MatchNode::make);
174
175 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon24a4c0ff0802(const ObjectRef& ref, IRPrinter* p) 176 .set_dispatch<MatchNode>([](const ObjectRef& ref, IRPrinter* p) {
177 auto* node = static_cast<const MatchNode*>(ref.get());
178 p->stream << "MatchNode(" << node->data << ", "
179 << node->clauses << ", " << node->complete << ")";
180 });
181
182 } // namespace relay
183 } // namespace tvm
184