1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file src/tvm/ir/expr.cc
22 * \brief The expression AST nodes of Relay.
23 */
24 #include <tvm/relay/expr.h>
25
26 namespace tvm {
27 namespace relay {
28
29 using tvm::IRPrinter;
30 using namespace tvm::runtime;
31
make(runtime::NDArray data)32 Constant ConstantNode::make(runtime::NDArray data) {
33 NodePtr<ConstantNode> n = make_node<ConstantNode>();
34 n->data = std::move(data);
35 return Constant(n);
36 }
37
38 TVM_REGISTER_NODE_TYPE(ConstantNode);
39
40 TVM_REGISTER_API("relay._make.Constant")
41 .set_body_typed(ConstantNode::make);
42
43 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon70068e3b0102(const ObjectRef& ref, IRPrinter* p) 44 .set_dispatch<ConstantNode>([](const ObjectRef& ref, IRPrinter* p) {
45 auto* node = static_cast<const ConstantNode*>(ref.get());
46 const PackedFunc* fprint = Registry::Get("relay._constant_repr");
47 CHECK(fprint) << "unable to find printing function for constants";
48 std::string data = (*fprint)(GetRef<Constant>(node));
49 p->stream << "Constant(" << data << ")";
50 });
51
tensor_type() const52 TensorType ConstantNode::tensor_type() const {
53 auto dtype = TVMType2Type(data->dtype);
54 Array<tvm::Expr> shape;
55 for (int i = 0; i < data->ndim; i++) {
56 CHECK_LE(data->shape[i], std::numeric_limits<int32_t>::max());
57 CHECK_GE(data->shape[i], std::numeric_limits<int32_t>::min());
58 shape.push_back(
59 tvm::ir::IntImm::make(Int(32), data->shape[i]));
60 }
61
62 return TensorTypeNode::make(shape, dtype);
63 }
64
make(tvm::Array<relay::Expr> fields)65 Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
66 NodePtr<TupleNode> n = make_node<TupleNode>();
67 n->fields = std::move(fields);
68 return Tuple(n);
69 }
70
71 TVM_REGISTER_NODE_TYPE(TupleNode);
72
73 TVM_REGISTER_API("relay._make.Tuple")
74 .set_body_typed(TupleNode::make);
75
76 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon70068e3b0202(const ObjectRef& ref, IRPrinter* p) 77 .set_dispatch<TupleNode>([](const ObjectRef& ref, IRPrinter* p) {
78 auto* node = static_cast<const TupleNode*>(ref.get());
79 p->stream << "Tuple(" << node->fields << ")";
80 });
81
82
make(Id vid,Type type_annotation)83 Var VarNode::make(Id vid, Type type_annotation) {
84 NodePtr<VarNode> n = make_node<VarNode>();
85 n->vid = std::move(vid);
86 n->type_annotation = std::move(type_annotation);
87 return Var(n);
88 }
89
make(std::string name_hint,Type type_annotation)90 Var VarNode::make(std::string name_hint, Type type_annotation) {
91 NodePtr<IdNode> n = make_node<IdNode>();
92 n->name_hint = std::move(name_hint);
93 return VarNode::make(Id(n), type_annotation);
94 }
95
96 TVM_REGISTER_NODE_TYPE(VarNode);
97
98 TVM_REGISTER_API("relay._make.Var")
99 .set_body_typed(static_cast<Var (*)(std::string, Type)>(VarNode::make));
100
101 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon70068e3b0302(const ObjectRef& ref, IRPrinter* p) 102 .set_dispatch<VarNode>([](const ObjectRef& ref, IRPrinter* p) {
103 auto* node = static_cast<const VarNode*>(ref.get());
104 p->stream << "Var(" << node->name_hint();
105 if (node->type_annotation.defined()) {
106 p->stream << ", ty=";
107 p->Print(node->type_annotation);
108 }
109 p->stream << ")";
110 });
111
make(std::string name_hint)112 GlobalVar GlobalVarNode::make(std::string name_hint) {
113 NodePtr<GlobalVarNode> n = make_node<GlobalVarNode>();
114 n->name_hint = std::move(name_hint);
115 return GlobalVar(n);
116 }
117
118 TVM_REGISTER_NODE_TYPE(GlobalVarNode);
119
120 TVM_REGISTER_API("relay._make.GlobalVar")
121 .set_body_typed(GlobalVarNode::make);
122
123 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon70068e3b0402(const ObjectRef& ref, IRPrinter* p) 124 .set_dispatch<GlobalVarNode>([](const ObjectRef& ref, IRPrinter* p) {
125 auto* node = static_cast<const GlobalVarNode*>(ref.get());
126 p->stream << "GlobalVar(" << node->name_hint << ")";
127 });
128
129
make(tvm::Array<Var> params,Expr body,Type ret_type,tvm::Array<TypeVar> type_params,tvm::Attrs attrs)130 Function FunctionNode::make(tvm::Array<Var> params,
131 Expr body,
132 Type ret_type,
133 tvm::Array<TypeVar> type_params,
134 tvm::Attrs attrs) {
135 NodePtr<FunctionNode> n = make_node<FunctionNode>();
136 CHECK(params.defined());
137 CHECK(type_params.defined());
138 n->params = std::move(params);
139 n->body = std::move(body);
140 n->ret_type = std::move(ret_type);
141 n->type_params = std::move(type_params);
142 n->attrs = std::move(attrs);
143 return Function(n);
144 }
145
func_type_annotation() const146 FuncType FunctionNode::func_type_annotation() const {
147 Array<Type> param_types;
148 for (auto param : this->params) {
149 Type param_type = (param->type_annotation.defined()) ? param->type_annotation
150 : IncompleteTypeNode::make(Kind::kType);
151 param_types.push_back(param_type);
152 }
153
154 Type ret_type = (this->ret_type.defined()) ? this->ret_type
155 : IncompleteTypeNode::make(Kind::kType);
156 return FuncTypeNode::make(param_types, ret_type, this->type_params, {});
157 }
158
IsPrimitive() const159 bool FunctionNode::IsPrimitive() const {
160 NodeRef res = FunctionGetAttr(GetRef<Function>(this), "Primitive");
161 const ir::IntImm* pval = res.as<ir::IntImm>();
162 return pval && pval->value != 0;
163 }
164
SetParams(const tvm::Map<Var,Constant> & parameters) const165 Function FunctionNode::SetParams(const tvm::Map<Var, Constant>& parameters) const {
166 return FunctionSetAttr(GetRef<Function>(this), "__params__", parameters);
167 }
168
169 TVM_REGISTER_API("relay._expr.FunctionSetParams")
170 .set_body_typed<Function(const Function&, const tvm::Map<Var, Constant>&)>(
__anon70068e3b0502(const Function& func, const tvm::Map<Var, Constant>& parameters) 171 [](const Function& func, const tvm::Map<Var, Constant>& parameters) {
172 return func->SetParams(parameters);
173 });
174
GetParams() const175 tvm::Map<Var, Constant> FunctionNode::GetParams() const {
176 auto node_ref = FunctionGetAttr(GetRef<Function>(this), "__params__");
177 return Downcast<tvm::Map<Var, Constant>>(node_ref);
178 }
179
180 TVM_REGISTER_API("relay._expr.FunctionGetParams")
__anon70068e3b0602(const Function& func) 181 .set_body_typed<tvm::Map<Var, Constant>(const Function&)>([](const Function& func) {
182 return func->GetParams();
183 });
184
FunctionGetAttr(const Function & func,const std::string & key)185 NodeRef FunctionGetAttr(const Function& func, const std::string& key) {
186 if (!func->attrs.defined()) { return NodeRef(); }
187
188 const DictAttrsNode* dict_attrs = func->attrs.as<DictAttrsNode>();
189 CHECK(dict_attrs);
190 auto it = dict_attrs->dict.find(key);
191 if (it != dict_attrs->dict.end()) {
192 return (*it).second;
193 } else {
194 return NodeRef();
195 }
196 }
197
// Return a new Function (built with FunctionNode::make from func's params,
// body, ret_type and type_params) whose attribute dictionary is func's
// dictionary with `key` set to `data`. Any existing value for `key` is
// overwritten. If func has no attributes, a dictionary holding only `key` is
// created.
Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data) {
  const DictAttrsNode* dattrs = func->attrs.as<DictAttrsNode>();
  // FunctionGetAttr requires defined attributes to be DictAttrs. Enforce the
  // same here rather than silently discarding non-dictionary attributes.
  CHECK(!func->attrs.defined() || dattrs)
      << "FunctionSetAttr expects function attributes to be DictAttrs";

  Map<std::string, NodeRef> dict;
  if (dattrs) {
    dict = dattrs->dict;
  }
  dict.Set(key, data);

  return FunctionNode::make(
      func->params,
      func->body,
      func->ret_type,
      func->type_params,
      DictAttrsNode::make(dict));
}
217
218 TVM_REGISTER_NODE_TYPE(FunctionNode);
219
220 TVM_REGISTER_API("relay._make.Function")
221 .set_body_typed(FunctionNode::make);
222
223 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon70068e3b0702(const ObjectRef& ref, IRPrinter* p) 224 .set_dispatch<FunctionNode>([](const ObjectRef& ref, IRPrinter* p) {
225 auto* node = static_cast<const FunctionNode*>(ref.get());
226 p->stream << "FunctionNode(" << node->params << ", " << node->ret_type
227 << ", " << node->body << ", " << node->type_params << ", "
228 << node->attrs << ")";
229 });
230
make(Expr op,Array<Expr> args,Attrs attrs,Array<Type> type_args)231 Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
232 Array<Type> type_args) {
233 NodePtr<CallNode> n = make_node<CallNode>();
234 n->op = std::move(op);
235 n->args = std::move(args);
236 n->attrs = std::move(attrs);
237 n->type_args = std::move(type_args);
238 return Call(n);
239 }
240
241 TVM_REGISTER_NODE_TYPE(CallNode);
242
243 TVM_REGISTER_API("relay._make.Call")
244 .set_body_typed(CallNode::make);
245
246 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon70068e3b0802(const ObjectRef& ref, IRPrinter* p) 247 .set_dispatch<CallNode>([](const ObjectRef& ref, IRPrinter* p) {
248 auto* node = static_cast<const CallNode*>(ref.get());
249 p->stream << "CallNode(" << node->op << ", " << node->args << ", "
250 << node->attrs << ", " << node->type_args << ")";
251 });
252
make(Var var,Expr value,Expr body)253 Let LetNode::make(Var var, Expr value, Expr body) {
254 NodePtr<LetNode> n = make_node<LetNode>();
255 n->var = std::move(var);
256 n->value = std::move(value);
257 n->body = std::move(body);
258 return Let(n);
259 }
260
261 TVM_REGISTER_NODE_TYPE(LetNode);
262
263 TVM_REGISTER_API("relay._make.Let")
264 .set_body_typed(LetNode::make);
265
266 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon70068e3b0902(const ObjectRef& ref, IRPrinter* p) 267 .set_dispatch<LetNode>([](const ObjectRef& ref, IRPrinter* p) {
268 auto* node = static_cast<const LetNode*>(ref.get());
269 p->stream << "LetNode(" << node->var << ", " << node->value
270 << ", " << node->body << ")";
271 });
272
make(Expr cond,Expr true_branch,Expr false_branch)273 If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) {
274 NodePtr<IfNode> n = make_node<IfNode>();
275 n->cond = std::move(cond);
276 n->true_branch = std::move(true_branch);
277 n->false_branch = std::move(false_branch);
278 return If(n);
279 }
280
281 TVM_REGISTER_NODE_TYPE(IfNode);
282
283 TVM_REGISTER_API("relay._make.If")
284 .set_body_typed(IfNode::make);
285
286 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon70068e3b0a02(const ObjectRef& ref, IRPrinter* p) 287 .set_dispatch<IfNode>([](const ObjectRef& ref, IRPrinter* p) {
288 auto* node = static_cast<const IfNode*>(ref.get());
289 p->stream << "IfNode(" << node->cond << ", " << node->true_branch
290 << ", " << node->false_branch << ")";
291 });
292
make(Expr tuple,int index)293 TupleGetItem TupleGetItemNode::make(Expr tuple, int index) {
294 NodePtr<TupleGetItemNode> n = make_node<TupleGetItemNode>();
295 n->tuple = std::move(tuple);
296 n->index = index;
297 return TupleGetItem(n);
298 }
299
300 TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
301
302 TVM_REGISTER_API("relay._make.TupleGetItem")
303 .set_body_typed(TupleGetItemNode::make);
304
305 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon70068e3b0b02(const ObjectRef& ref, IRPrinter* p) 306 .set_dispatch<TupleGetItemNode>([](const ObjectRef& ref, IRPrinter* p) {
307 auto* node = static_cast<const TupleGetItemNode*>(ref.get());
308 p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
309 });
310
make(Expr value)311 RefCreate RefCreateNode::make(Expr value) {
312 NodePtr<RefCreateNode> n = make_node<RefCreateNode>();
313 n->value = std::move(value);
314 return RefCreate(n);
315 }
316
317 TVM_REGISTER_NODE_TYPE(RefCreateNode);
318
319 TVM_REGISTER_API("relay._make.RefCreate")
320 .set_body_typed(RefCreateNode::make);
321
322 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon70068e3b0c02(const ObjectRef& ref, IRPrinter* p) 323 .set_dispatch<RefCreateNode>([](const ObjectRef& ref, IRPrinter* p) {
324 auto* node = static_cast<const RefCreateNode*>(ref.get());
325 p->stream << "RefCreateNode(" << node->value << ")";
326 });
327
make(Expr ref)328 RefRead RefReadNode::make(Expr ref) {
329 NodePtr<RefReadNode> n = make_node<RefReadNode>();
330 n->ref = std::move(ref);
331 return RefRead(n);
332 }
333
334 TVM_REGISTER_NODE_TYPE(RefReadNode);
335
336 TVM_REGISTER_API("relay._make.RefRead")
337 .set_body_typed(RefReadNode::make);
338
339 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon70068e3b0d02(const ObjectRef& ref, IRPrinter* p) 340 .set_dispatch<RefReadNode>([](const ObjectRef& ref, IRPrinter* p) {
341 auto* node = static_cast<const RefReadNode*>(ref.get());
342 p->stream << "RefReadNode(" << node->ref << ")";
343 });
344
make(Expr ref,Expr value)345 RefWrite RefWriteNode::make(Expr ref, Expr value) {
346 NodePtr<RefWriteNode> n = make_node<RefWriteNode>();
347 n->ref = std::move(ref);
348 n->value = std::move(value);
349 return RefWrite(n);
350 }
351
352 TVM_REGISTER_NODE_TYPE(RefWriteNode);
353
354 TVM_REGISTER_API("relay._make.RefWrite")
355 .set_body_typed(RefWriteNode::make);
356
357 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon70068e3b0e02(const ObjectRef& ref, IRPrinter* p) 358 .set_dispatch<RefWriteNode>([](const ObjectRef& ref, IRPrinter* p) {
359 auto* node = static_cast<const RefWriteNode*>(ref.get());
360 p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")";
361 });
362
363 TVM_REGISTER_API("relay._expr.TempExprRealize")
__anon70068e3b0f02(TempExpr temp) 364 .set_body_typed<Expr(TempExpr)>([](TempExpr temp) {
365 return temp->Realize();
366 });
367
368 TVM_REGISTER_API("relay._expr.FunctionSetAttr")
369 .set_body_typed<Function(Function, std::string, NodeRef)>(
__anon70068e3b1002(Function func, std::string name, NodeRef ref) 370 [](Function func, std::string name, NodeRef ref) {
371 return FunctionSetAttr(func, name, ref);
372 });
373
374 } // namespace relay
375 } // namespace tvm
376