1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
 * \file src/relay/ir/expr.cc
22  * \brief The expression AST nodes of Relay.
23  */
24 #include <tvm/relay/expr.h>
25 
26 namespace tvm {
27 namespace relay {
28 
29 using tvm::IRPrinter;
30 using namespace tvm::runtime;
31 
make(runtime::NDArray data)32 Constant ConstantNode::make(runtime::NDArray data) {
33   NodePtr<ConstantNode> n = make_node<ConstantNode>();
34   n->data = std::move(data);
35   return Constant(n);
36 }
37 
38 TVM_REGISTER_NODE_TYPE(ConstantNode);
39 
40 TVM_REGISTER_API("relay._make.Constant")
41 .set_body_typed(ConstantNode::make);
42 
43 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon1ba907250102(const ObjectRef& ref, IRPrinter* p) 44 .set_dispatch<ConstantNode>([](const ObjectRef& ref, IRPrinter* p) {
45     auto* node = static_cast<const ConstantNode*>(ref.get());
46     const PackedFunc* fprint = Registry::Get("relay._constant_repr");
47     CHECK(fprint) << "unable to find printing function for constants";
48     std::string data = (*fprint)(GetRef<Constant>(node));
49     p->stream << "Constant(" << data << ")";
50   });
51 
tensor_type() const52 TensorType ConstantNode::tensor_type() const {
53   auto dtype = TVMType2Type(data->dtype);
54   Array<tvm::Expr> shape;
55   for (int i = 0; i < data->ndim; i++) {
56     CHECK_LE(data->shape[i], std::numeric_limits<int32_t>::max());
57     CHECK_GE(data->shape[i], std::numeric_limits<int32_t>::min());
58     shape.push_back(
59         tvm::ir::IntImm::make(Int(32), data->shape[i]));
60   }
61 
62   return TensorTypeNode::make(shape, dtype);
63 }
64 
make(tvm::Array<relay::Expr> fields)65 Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
66   NodePtr<TupleNode> n = make_node<TupleNode>();
67   n->fields = std::move(fields);
68   return Tuple(n);
69 }
70 
71 TVM_REGISTER_NODE_TYPE(TupleNode);
72 
73 TVM_REGISTER_API("relay._make.Tuple")
74 .set_body_typed(TupleNode::make);
75 
76 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon1ba907250202(const ObjectRef& ref, IRPrinter* p) 77 .set_dispatch<TupleNode>([](const ObjectRef& ref, IRPrinter* p) {
78     auto* node = static_cast<const TupleNode*>(ref.get());
79     p->stream << "Tuple(" << node->fields << ")";
80   });
81 
82 
make(Id vid,Type type_annotation)83 Var VarNode::make(Id vid, Type type_annotation) {
84   NodePtr<VarNode> n = make_node<VarNode>();
85   n->vid = std::move(vid);
86   n->type_annotation = std::move(type_annotation);
87   return Var(n);
88 }
89 
make(std::string name_hint,Type type_annotation)90 Var VarNode::make(std::string name_hint, Type type_annotation) {
91   NodePtr<IdNode> n = make_node<IdNode>();
92   n->name_hint = std::move(name_hint);
93   return VarNode::make(Id(n), type_annotation);
94 }
95 
96 TVM_REGISTER_NODE_TYPE(VarNode);
97 
98 TVM_REGISTER_API("relay._make.Var")
99 .set_body_typed(static_cast<Var (*)(std::string, Type)>(VarNode::make));
100 
101 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon1ba907250302(const ObjectRef& ref, IRPrinter* p) 102 .set_dispatch<VarNode>([](const ObjectRef& ref, IRPrinter* p) {
103     auto* node = static_cast<const VarNode*>(ref.get());
104     p->stream << "Var(" << node->name_hint();
105     if (node->type_annotation.defined()) {
106       p->stream << ", ty=";
107       p->Print(node->type_annotation);
108     }
109     p->stream << ")";
110   });
111 
make(std::string name_hint)112 GlobalVar GlobalVarNode::make(std::string name_hint) {
113   NodePtr<GlobalVarNode> n = make_node<GlobalVarNode>();
114   n->name_hint = std::move(name_hint);
115   return GlobalVar(n);
116 }
117 
118 TVM_REGISTER_NODE_TYPE(GlobalVarNode);
119 
120 TVM_REGISTER_API("relay._make.GlobalVar")
121 .set_body_typed(GlobalVarNode::make);
122 
123 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon1ba907250402(const ObjectRef& ref, IRPrinter* p) 124 .set_dispatch<GlobalVarNode>([](const ObjectRef& ref, IRPrinter* p) {
125     auto* node = static_cast<const GlobalVarNode*>(ref.get());
126     p->stream << "GlobalVar(" << node->name_hint << ")";
127   });
128 
129 
make(tvm::Array<Var> params,Expr body,Type ret_type,tvm::Array<TypeVar> type_params,tvm::Attrs attrs)130 Function FunctionNode::make(tvm::Array<Var> params,
131                             Expr body,
132                             Type ret_type,
133                             tvm::Array<TypeVar> type_params,
134                             tvm::Attrs attrs) {
135   NodePtr<FunctionNode> n = make_node<FunctionNode>();
136   CHECK(params.defined());
137   CHECK(type_params.defined());
138   n->params = std::move(params);
139   n->body = std::move(body);
140   n->ret_type = std::move(ret_type);
141   n->type_params = std::move(type_params);
142   n->attrs = std::move(attrs);
143   return Function(n);
144 }
145 
func_type_annotation() const146 FuncType FunctionNode::func_type_annotation() const {
147   Array<Type> param_types;
148   for (auto param : this->params) {
149     Type param_type = (param->type_annotation.defined()) ? param->type_annotation
150       : IncompleteTypeNode::make(Kind::kType);
151     param_types.push_back(param_type);
152   }
153 
154   Type ret_type = (this->ret_type.defined()) ? this->ret_type
155     : IncompleteTypeNode::make(Kind::kType);
156   return FuncTypeNode::make(param_types, ret_type, this->type_params, {});
157 }
158 
IsPrimitive() const159 bool FunctionNode::IsPrimitive() const {
160   NodeRef res = FunctionGetAttr(GetRef<Function>(this), "Primitive");
161   const ir::IntImm* pval = res.as<ir::IntImm>();
162   return pval && pval->value != 0;
163 }
164 
SetParams(const tvm::Map<Var,Constant> & parameters) const165 Function FunctionNode::SetParams(const tvm::Map<Var, Constant>& parameters) const {
166   return FunctionSetAttr(GetRef<Function>(this), "__params__", parameters);
167 }
168 
169 TVM_REGISTER_API("relay._expr.FunctionSetParams")
170 .set_body_typed<Function(const Function&, const tvm::Map<Var, Constant>&)>(
__anon1ba907250502(const Function& func, const tvm::Map<Var, Constant>& parameters) 171   [](const Function& func, const tvm::Map<Var, Constant>& parameters) {
172     return func->SetParams(parameters);
173 });
174 
GetParams() const175 tvm::Map<Var, Constant> FunctionNode::GetParams() const {
176   auto node_ref = FunctionGetAttr(GetRef<Function>(this), "__params__");
177   return Downcast<tvm::Map<Var, Constant>>(node_ref);
178 }
179 
180 TVM_REGISTER_API("relay._expr.FunctionGetParams")
__anon1ba907250602(const Function& func) 181 .set_body_typed<tvm::Map<Var, Constant>(const Function&)>([](const Function& func) {
182   return func->GetParams();
183 });
184 
FunctionGetAttr(const Function & func,const std::string & key)185 NodeRef FunctionGetAttr(const Function& func, const std::string& key) {
186   if (!func->attrs.defined()) { return NodeRef(); }
187 
188   const DictAttrsNode* dict_attrs = func->attrs.as<DictAttrsNode>();
189   CHECK(dict_attrs);
190   auto it = dict_attrs->dict.find(key);
191   if (it != dict_attrs->dict.end()) {
192     return (*it).second;
193   } else {
194     return NodeRef();
195   }
196 }
197 
/*!
 * \brief Return a copy of a function with one attribute set.
 *
 * Existing DictAttrs entries are carried over, and `key` is inserted or
 * overwritten. If the function's attrs are undefined or are not DictAttrs,
 * the result's attributes contain only `key`.
 *
 * \param func The source function (not modified).
 * \param key The attribute name.
 * \param data The attribute value.
 * \return A new Function with the same params, body, ret_type and
 *         type_params, and the updated attribute dictionary.
 */
Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data) {
  Map<std::string, NodeRef> dict;
  if (const DictAttrsNode* dattrs = func->attrs.as<DictAttrsNode>()) {
    dict = dattrs->dict;
  }
  dict.Set(key, data);
  Attrs func_attrs = DictAttrsNode::make(dict);

  return FunctionNode::make(func->params, func->body, func->ret_type,
                            func->type_params, func_attrs);
}
217 
218 TVM_REGISTER_NODE_TYPE(FunctionNode);
219 
220 TVM_REGISTER_API("relay._make.Function")
221 .set_body_typed(FunctionNode::make);
222 
223 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon1ba907250702(const ObjectRef& ref, IRPrinter* p) 224 .set_dispatch<FunctionNode>([](const ObjectRef& ref, IRPrinter* p) {
225   auto* node = static_cast<const FunctionNode*>(ref.get());
226   p->stream << "FunctionNode(" << node->params << ", " << node->ret_type
227             << ", " << node->body << ", " << node->type_params << ", "
228             << node->attrs << ")";
229 });
230 
make(Expr op,Array<Expr> args,Attrs attrs,Array<Type> type_args)231 Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
232                     Array<Type> type_args) {
233   NodePtr<CallNode> n = make_node<CallNode>();
234   n->op = std::move(op);
235   n->args = std::move(args);
236   n->attrs = std::move(attrs);
237   n->type_args = std::move(type_args);
238   return Call(n);
239 }
240 
241 TVM_REGISTER_NODE_TYPE(CallNode);
242 
243 TVM_REGISTER_API("relay._make.Call")
244 .set_body_typed(CallNode::make);
245 
246 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon1ba907250802(const ObjectRef& ref, IRPrinter* p) 247 .set_dispatch<CallNode>([](const ObjectRef& ref, IRPrinter* p) {
248   auto* node = static_cast<const CallNode*>(ref.get());
249   p->stream << "CallNode(" << node->op << ", " << node->args << ", "
250             << node->attrs << ", " << node->type_args << ")";
251   });
252 
make(Var var,Expr value,Expr body)253 Let LetNode::make(Var var, Expr value, Expr body) {
254   NodePtr<LetNode> n = make_node<LetNode>();
255   n->var = std::move(var);
256   n->value = std::move(value);
257   n->body = std::move(body);
258   return Let(n);
259 }
260 
261 TVM_REGISTER_NODE_TYPE(LetNode);
262 
263 TVM_REGISTER_API("relay._make.Let")
264 .set_body_typed(LetNode::make);
265 
266 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon1ba907250902(const ObjectRef& ref, IRPrinter* p) 267 .set_dispatch<LetNode>([](const ObjectRef& ref, IRPrinter* p) {
268   auto* node = static_cast<const LetNode*>(ref.get());
269   p->stream << "LetNode(" << node->var << ", " << node->value
270             << ", " << node->body << ")";
271 });
272 
make(Expr cond,Expr true_branch,Expr false_branch)273 If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) {
274   NodePtr<IfNode> n = make_node<IfNode>();
275   n->cond = std::move(cond);
276   n->true_branch = std::move(true_branch);
277   n->false_branch = std::move(false_branch);
278   return If(n);
279 }
280 
281 TVM_REGISTER_NODE_TYPE(IfNode);
282 
283 TVM_REGISTER_API("relay._make.If")
284 .set_body_typed(IfNode::make);
285 
286 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon1ba907250a02(const ObjectRef& ref, IRPrinter* p) 287 .set_dispatch<IfNode>([](const ObjectRef& ref, IRPrinter* p) {
288   auto* node = static_cast<const IfNode*>(ref.get());
289   p->stream << "IfNode(" << node->cond << ", " << node->true_branch
290             << ", " << node->false_branch << ")";
291 });
292 
make(Expr tuple,int index)293 TupleGetItem TupleGetItemNode::make(Expr tuple, int index) {
294   NodePtr<TupleGetItemNode> n = make_node<TupleGetItemNode>();
295   n->tuple = std::move(tuple);
296   n->index = index;
297   return TupleGetItem(n);
298 }
299 
300 TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
301 
302 TVM_REGISTER_API("relay._make.TupleGetItem")
303 .set_body_typed(TupleGetItemNode::make);
304 
305 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon1ba907250b02(const ObjectRef& ref, IRPrinter* p) 306 .set_dispatch<TupleGetItemNode>([](const ObjectRef& ref, IRPrinter* p) {
307   auto* node = static_cast<const TupleGetItemNode*>(ref.get());
308   p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
309 });
310 
make(Expr value)311 RefCreate RefCreateNode::make(Expr value) {
312   NodePtr<RefCreateNode> n = make_node<RefCreateNode>();
313   n->value = std::move(value);
314   return RefCreate(n);
315 }
316 
317 TVM_REGISTER_NODE_TYPE(RefCreateNode);
318 
319 TVM_REGISTER_API("relay._make.RefCreate")
320 .set_body_typed(RefCreateNode::make);
321 
322 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon1ba907250c02(const ObjectRef& ref, IRPrinter* p) 323 .set_dispatch<RefCreateNode>([](const ObjectRef& ref, IRPrinter* p) {
324   auto* node = static_cast<const RefCreateNode*>(ref.get());
325   p->stream << "RefCreateNode(" << node->value << ")";
326 });
327 
make(Expr ref)328 RefRead RefReadNode::make(Expr ref) {
329   NodePtr<RefReadNode> n = make_node<RefReadNode>();
330   n->ref = std::move(ref);
331   return RefRead(n);
332 }
333 
334 TVM_REGISTER_NODE_TYPE(RefReadNode);
335 
336 TVM_REGISTER_API("relay._make.RefRead")
337 .set_body_typed(RefReadNode::make);
338 
339 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon1ba907250d02(const ObjectRef& ref, IRPrinter* p) 340 .set_dispatch<RefReadNode>([](const ObjectRef& ref, IRPrinter* p) {
341   auto* node = static_cast<const RefReadNode*>(ref.get());
342   p->stream << "RefReadNode(" << node->ref << ")";
343 });
344 
make(Expr ref,Expr value)345 RefWrite RefWriteNode::make(Expr ref, Expr value) {
346   NodePtr<RefWriteNode> n = make_node<RefWriteNode>();
347   n->ref = std::move(ref);
348   n->value = std::move(value);
349   return RefWrite(n);
350 }
351 
352 TVM_REGISTER_NODE_TYPE(RefWriteNode);
353 
354 TVM_REGISTER_API("relay._make.RefWrite")
355 .set_body_typed(RefWriteNode::make);
356 
357 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
__anon1ba907250e02(const ObjectRef& ref, IRPrinter* p) 358 .set_dispatch<RefWriteNode>([](const ObjectRef& ref, IRPrinter* p) {
359   auto* node = static_cast<const RefWriteNode*>(ref.get());
360   p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")";
361 });
362 
363 TVM_REGISTER_API("relay._expr.TempExprRealize")
__anon1ba907250f02(TempExpr temp) 364 .set_body_typed<Expr(TempExpr)>([](TempExpr temp) {
365   return temp->Realize();
366 });
367 
368 TVM_REGISTER_API("relay._expr.FunctionSetAttr")
369 .set_body_typed<Function(Function, std::string, NodeRef)>(
__anon1ba907251002(Function func, std::string name, NodeRef ref) 370   [](Function func, std::string name, NodeRef ref) {
371     return FunctionSetAttr(func, name, ref);
372 });
373 
374 }  // namespace relay
375 }  // namespace tvm
376