/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19
/*!
 *
 * \file util.cc
 *
 * \brief Utility functions for Relay.
 */
26 #include <tvm/ir/type_functor.h>
27 #include <tvm/relay/analysis.h>
28 #include <tvm/relay/attrs/algorithm.h>
29 #include <tvm/relay/expr_functor.h>
30 #include <tvm/relay/op.h>
31 #include <tvm/relay/op_attr_types.h>
32 #include <tvm/relay/pattern_functor.h>
33
34 #include "../transforms/pass_util.h"
35
36 namespace tvm {
37 namespace relay {
38
39 template <typename T>
40 struct InsertionSet {
41 std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual> set;
42 std::vector<T> data;
Inserttvm::relay::InsertionSet43 void Insert(const T& t) {
44 if (set.count(t) == 0) {
45 set.insert(t);
46 data.push_back(t);
47 }
48 }
49 };
50
51 class TypeVarTVisitor : public TypeVisitor {
52 public:
TypeVarTVisitor(InsertionSet<TypeVar> * type_vars,InsertionSet<TypeVar> * bound_type_vars)53 TypeVarTVisitor(InsertionSet<TypeVar>* type_vars, InsertionSet<TypeVar>* bound_type_vars)
54 : type_vars_(type_vars), bound_type_vars_(bound_type_vars) {}
55
VisitType_(const TypeVarNode * tp)56 void VisitType_(const TypeVarNode* tp) final {
57 TypeVar var = GetRef<TypeVar>(tp);
58 type_vars_->Insert(var);
59 }
60
VisitType_(const FuncTypeNode * f)61 void VisitType_(const FuncTypeNode* f) final {
62 for (auto type_param : f->type_params) {
63 type_vars_->Insert(type_param);
64 bound_type_vars_->Insert(type_param);
65 }
66 TypeVisitor::VisitType_(f);
67 }
68
69 private:
70 InsertionSet<TypeVar>* type_vars_;
71 InsertionSet<TypeVar>* bound_type_vars_;
72 };
73
74 class TypeVarEVisitor : private ExprVisitor {
75 public:
TypeVarEVisitor(const IRModule & mod)76 explicit TypeVarEVisitor(const IRModule& mod) : mod_(mod) {}
77
CollectFree()78 Array<TypeVar> CollectFree() {
79 Array<TypeVar> ret;
80 for (const auto& v : type_vars_.data) {
81 if (bound_type_vars_.set.count(v) == 0) {
82 ret.push_back(v);
83 }
84 }
85 return ret;
86 }
87
CollectBound()88 Array<TypeVar> CollectBound() {
89 Array<TypeVar> ret;
90 for (const auto& v : bound_type_vars_.data) {
91 ret.push_back(v);
92 }
93 return ret;
94 }
95
CollectAll()96 Array<TypeVar> CollectAll() {
97 Array<TypeVar> ret;
98 for (const auto& v : type_vars_.data) {
99 ret.push_back(v);
100 }
101 return ret;
102 }
103
Free(const Expr & expr)104 Array<TypeVar> Free(const Expr& expr) {
105 VisitExpr(expr);
106 return CollectFree();
107 }
108
Free(const Type & type)109 Array<TypeVar> Free(const Type& type) {
110 VisitType(type);
111 return CollectFree();
112 }
113
Bound(const Expr & expr)114 Array<TypeVar> Bound(const Expr& expr) {
115 VisitExpr(expr);
116 return CollectBound();
117 }
118
Bound(const Type & type)119 Array<TypeVar> Bound(const Type& type) {
120 VisitType(type);
121 return CollectBound();
122 }
123
All(const Expr & expr)124 Array<TypeVar> All(const Expr& expr) {
125 VisitExpr(expr);
126 return CollectAll();
127 }
128
All(const Type & type)129 Array<TypeVar> All(const Type& type) {
130 VisitType(type);
131 return CollectAll();
132 }
133
VisitExpr_(const FunctionNode * f)134 void VisitExpr_(const FunctionNode* f) final {
135 for (const auto& tp : f->type_params) {
136 type_vars_.Insert(tp);
137 bound_type_vars_.Insert(tp);
138 }
139 ExprVisitor::VisitExpr_(f);
140 }
141
VisitExpr_(const ConstructorNode * cn)142 void VisitExpr_(const ConstructorNode* cn) final {
143 // for constructors, type vars will be bound in the module
144 auto data = mod_->LookupTypeDef(cn->belong_to);
145 for (const auto& tv : data->type_vars) {
146 type_vars_.Insert(tv);
147 bound_type_vars_.Insert(tv);
148 }
149 ExprVisitor::VisitExpr_(cn);
150 }
151
VisitType(const Type & t)152 void VisitType(const Type& t) final {
153 TypeVarTVisitor(&type_vars_, &bound_type_vars_).VisitType(t);
154 }
155
156 private:
157 InsertionSet<TypeVar> type_vars_;
158 InsertionSet<TypeVar> bound_type_vars_;
159 const IRModule& mod_;
160 };
161
162 class VarVisitor : protected ExprVisitor, protected PatternVisitor {
163 public:
Free(const Expr & expr)164 Array<Var> Free(const Expr& expr) {
165 this->VisitExpr(expr);
166 Array<Var> ret;
167 for (const auto& v : vars_.data) {
168 if (bound_vars_.set.count(v) == 0) {
169 ret.push_back(v);
170 }
171 }
172 return ret;
173 }
174
Collect()175 Array<Var> Collect() {
176 Array<Var> ret;
177 for (const auto& v : bound_vars_.data) {
178 ret.push_back(v);
179 }
180 return ret;
181 }
182
Bound(const Expr & expr)183 Array<Var> Bound(const Expr& expr) {
184 this->VisitExpr(expr);
185 return Collect();
186 }
187
Bound(const Pattern & pat)188 Array<Var> Bound(const Pattern& pat) {
189 this->VisitPattern(pat);
190 return Collect();
191 }
192
All(const Expr & expr)193 Array<Var> All(const Expr& expr) {
194 this->VisitExpr(expr);
195 Array<Var> ret;
196 for (const auto& v : vars_.data) {
197 ret.push_back(v);
198 }
199 return ret;
200 }
201
MarkBounded(const Var & v)202 void MarkBounded(const Var& v) {
203 bound_vars_.Insert(v);
204 vars_.Insert(v);
205 }
206
VisitExpr_(const VarNode * var)207 void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef<Var>(var)); }
208
VisitExpr_(const FunctionNode * op)209 void VisitExpr_(const FunctionNode* op) final {
210 for (const auto& param : op->params) {
211 MarkBounded(param);
212 }
213 VisitExpr(op->body);
214 }
215
VisitExpr_(const LetNode * op)216 void VisitExpr_(const LetNode* op) final {
217 Expr let = GetRef<Let>(op);
218 while (auto let_node = let.as<LetNode>()) {
219 MarkBounded(let_node->var);
220 VisitExpr(let_node->value);
221 let = let_node->body;
222 }
223 VisitExpr(let);
224 }
225
VisitPattern(const Pattern & p)226 void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); }
227
VisitPattern_(const PatternVarNode * op)228 void VisitPattern_(const PatternVarNode* op) final { MarkBounded(op->var); }
229
230 private:
231 InsertionSet<Var> vars_;
232 InsertionSet<Var> bound_vars_;
233 };
234
FreeTypeVars(const Expr & expr,const IRModule & mod)235 tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const IRModule& mod) {
236 return TypeVarEVisitor(mod).Free(expr);
237 }
238
FreeTypeVars(const Type & type,const IRModule & mod)239 tvm::Array<TypeVar> FreeTypeVars(const Type& type, const IRModule& mod) {
240 return TypeVarEVisitor(mod).Free(type);
241 }
242
BoundTypeVars(const Expr & expr,const IRModule & mod)243 tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const IRModule& mod) {
244 return TypeVarEVisitor(mod).Bound(expr);
245 }
246
BoundTypeVars(const Type & type,const IRModule & mod)247 tvm::Array<TypeVar> BoundTypeVars(const Type& type, const IRModule& mod) {
248 return TypeVarEVisitor(mod).Bound(type);
249 }
250
AllTypeVars(const Expr & expr,const IRModule & mod)251 tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const IRModule& mod) {
252 return TypeVarEVisitor(mod).All(expr);
253 }
254
AllTypeVars(const Type & type,const IRModule & mod)255 tvm::Array<TypeVar> AllTypeVars(const Type& type, const IRModule& mod) {
256 return TypeVarEVisitor(mod).All(type);
257 }
258
FreeVars(const Expr & expr)259 tvm::Array<Var> FreeVars(const Expr& expr) { return VarVisitor().Free(expr); }
260
BoundVars(const Expr & expr)261 tvm::Array<Var> BoundVars(const Expr& expr) { return VarVisitor().Bound(expr); }
262
BoundVars(const Pattern & pat)263 tvm::Array<Var> BoundVars(const Pattern& pat) { return VarVisitor().Bound(pat); }
264
AllVars(const Expr & expr)265 tvm::Array<Var> AllVars(const Expr& expr) { return VarVisitor().All(expr); }
266
267 TVM_REGISTER_GLOBAL("relay.analysis.free_vars").set_body_typed(FreeVars);
268
__anon383ebf080102(TVMArgs args, TVMRetValue* ret) 269 TVM_REGISTER_GLOBAL("relay.analysis.bound_vars").set_body([](TVMArgs args, TVMRetValue* ret) {
270 ObjectRef x = args[0];
271 if (x.as<ExprNode>()) {
272 *ret = BoundVars(Downcast<Expr>(x));
273 } else {
274 *ret = BoundVars(Downcast<Pattern>(x));
275 }
276 });
277
278 TVM_REGISTER_GLOBAL("relay.analysis.all_vars").set_body_typed(AllVars);
279
// The type-var queries accept either a Type or an Expr as their first argument,
// plus the module used to resolve constructor definitions.
TVM_REGISTER_GLOBAL("relay.analysis.free_type_vars").set_body([](TVMArgs args, TVMRetValue* ret) {
  const ObjectRef node = args[0];
  const IRModule mod = args[1];
  if (node.as<TypeNode>() != nullptr) {
    *ret = FreeTypeVars(Downcast<Type>(node), mod);
    return;
  }
  *ret = FreeTypeVars(Downcast<Expr>(node), mod);
});

TVM_REGISTER_GLOBAL("relay.analysis.bound_type_vars").set_body([](TVMArgs args, TVMRetValue* ret) {
  const ObjectRef node = args[0];
  const IRModule mod = args[1];
  if (node.as<TypeNode>() != nullptr) {
    *ret = BoundTypeVars(Downcast<Type>(node), mod);
    return;
  }
  *ret = BoundTypeVars(Downcast<Expr>(node), mod);
});

TVM_REGISTER_GLOBAL("relay.analysis.all_type_vars").set_body([](TVMArgs args, TVMRetValue* ret) {
  const ObjectRef node = args[0];
  const IRModule mod = args[1];
  if (node.as<TypeNode>() != nullptr) {
    *ret = AllTypeVars(Downcast<Type>(node), mod);
    return;
  }
  *ret = AllTypeVars(Downcast<Expr>(node), mod);
});
309
310 class DtypeCollector : protected ExprVisitor, protected TypeVisitor {
311 public:
VisitExpr(const Expr & expr)312 void VisitExpr(const Expr& expr) final {
313 if (expr->checked_type_.defined()) {
314 TypeVisitor::VisitType(expr->checked_type());
315 }
316 ExprVisitor::VisitExpr(expr);
317 }
318
VisitType_(const TensorTypeNode * op)319 void VisitType_(const TensorTypeNode* op) final { dtypes_.insert(DLDataType2String(op->dtype)); }
320
All(const Expr & expr)321 Array<String> All(const Expr& expr) {
322 VisitExpr(expr);
323
324 Array<String> res;
325 for (const auto& dtype : dtypes_) {
326 res.push_back(String(dtype));
327 }
328 return res;
329 }
330
331 private:
332 std::unordered_set<std::string> dtypes_;
333 };
334
AllDtypes(const Expr & expr)335 tvm::Array<String> AllDtypes(const Expr& expr) { return DtypeCollector().All(expr); }
336
337 TVM_REGISTER_GLOBAL("relay.analysis.all_dtypes").set_body_typed(AllDtypes);
338
339 /*!
340 * \brief Get reference counter of each internal ExprNode in body.
341 * \param body The body expression.
342 * \return The reference count mapping.
343 */
GetExprRefCount(const Expr & body)344 std::unordered_map<const Object*, size_t> GetExprRefCount(const Expr& body) {
345 class ExprRefCounter : private MixedModeVisitor {
346 public:
347 std::unordered_map<const Object*, size_t> Get(const Expr& body) {
348 this->VisitExpr(body);
349 return std::move(this->visit_counter_);
350 }
351 };
352 return ExprRefCounter().Get(body);
353 }
354
355 template <typename T>
IsNDArrayAllGreaterEqual(const runtime::NDArray & tensor,T value)356 bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
357 CHECK_EQ(tensor->ctx.device_type, kDLCPU);
358 CHECK(tensor->strides == nullptr);
359 CHECK_EQ(tensor->byte_offset, 0);
360 const T* data = static_cast<const T*>(tensor->data);
361 int64_t num_elems = 1;
362 for (int i = 0; i < tensor->ndim; ++i) {
363 num_elems *= tensor->shape[i];
364 }
365
366 for (int64_t i = 0; i < num_elems; i++) {
367 if (*data < value) {
368 return false;
369 }
370 data++;
371 }
372 return true;
373 }
374
IsAllPositiveConstant(const Expr & expr)375 bool IsAllPositiveConstant(const Expr& expr) {
376 // Cache the operators that are checked recursively to reduce lookup overhead.
377 static const auto& expand_dims_op = Op::Get("expand_dims");
378 static const auto& reshape_op = Op::Get("reshape");
379 static const auto& transpose_op = Op::Get("transpose");
380 static const auto& squeeze_op = Op::Get("squeeze");
381
382 // peel through a few common transform ops.
383 if (const auto* constant = expr.as<ConstantNode>()) {
384 const auto& tensor = constant->data;
385 const auto& dtype = tensor->dtype;
386 if (dtype.lanes != 1) {
387 return false;
388 } else if (dtype.code == kDLFloat && dtype.bits == 32) {
389 return IsNDArrayAllGreaterEqual<float>(tensor, 0);
390 } else if (dtype.code == kDLFloat && dtype.bits == 64) {
391 return IsNDArrayAllGreaterEqual<double>(tensor, 0);
392 } else if (dtype.code == kDLInt && dtype.bits == 8) {
393 return IsNDArrayAllGreaterEqual<int8_t>(tensor, 0);
394 } else if (dtype.code == kDLInt && dtype.bits == 32) {
395 return IsNDArrayAllGreaterEqual<int32_t>(tensor, 0);
396 } else if (dtype.code == kDLUInt && dtype.bits == 8) {
397 return IsNDArrayAllGreaterEqual<uint8_t>(tensor, 0);
398 } else if (dtype.code == kDLUInt && dtype.bits == 32) {
399 return IsNDArrayAllGreaterEqual<uint32_t>(tensor, 0);
400 } else {
401 return false;
402 }
403 } else if (const auto* op = expr.as<CallNode>()) {
404 // tail recursion.
405 if (op->op == expand_dims_op || op->op == reshape_op || op->op == transpose_op ||
406 op->op == squeeze_op) {
407 return IsAllPositiveConstant(op->args[0]);
408 } else {
409 return false;
410 }
411 } else {
412 return false;
413 }
414 }
415
TypeSubst(const Type & type,const TypeVar & tvar,const Type & subst)416 Type TypeSubst(const Type& type, const TypeVar& tvar, const Type& subst) {
417 return TypeSubst(type, tvm::Map<TypeVar, Type>({{tvar, subst}}));
418 }
419
TypeSubst(const Expr & expr,const TypeVar & tvar,const Type & subst)420 Expr TypeSubst(const Expr& expr, const TypeVar& tvar, const Type& subst) {
421 return TypeSubst(expr, tvm::Map<TypeVar, Type>({{tvar, subst}}));
422 }
423
TypeSubst(const Type & type,const tvm::Map<TypeVar,Type> & subst_map)424 Type TypeSubst(const Type& type, const tvm::Map<TypeVar, Type>& subst_map) {
425 return Bind(type, subst_map);
426 }
427
/*!
 * \brief Apply \p subst_map to every type appearing inside \p expr.
 *
 * Type annotations reached through the expression mutator, including those on
 * vars bound in match-clause patterns, are rewritten with the type-level
 * TypeSubst. The input and the output are both CHECKed to be well formed, and
 * the substitution must not change the number of free variables.
 */
Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) {
  class TypeSubstMutator : public ExprMutator, public PatternMutator {
   public:
    explicit TypeSubstMutator(const tvm::Map<TypeVar, Type>& subst_map) : subst_map_(subst_map) {}

    Type VisitType(const Type& t) final { return TypeSubst(t, subst_map_); }

    // Route pattern vars through the expression mutator so they are rewritten
    // (and memoized) the same way as vars referenced in expressions.
    Var VisitVar(const Var& v) final { return Downcast<Var>(VisitExpr(v)); }

    Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); }

    Clause VisitClause(const Clause& c) final {
      // The pattern is rewritten before the clause body; keep this order explicit.
      Pattern lhs = VisitPattern(c->lhs);
      Expr rhs = VisitExpr(c->rhs);
      return Clause(lhs, rhs);
    }

   private:
    const tvm::Map<TypeVar, Type>& subst_map_;
  };

  CHECK(WellFormed(expr));
  Expr substituted = TypeSubstMutator(subst_map).VisitExpr(expr);
  CHECK_EQ(FreeVars(expr).size(), FreeVars(substituted).size());
  CHECK(WellFormed(substituted));
  return substituted;
}
451
452 struct IsDynamicVisitor : public TypeVisitor {
453 bool is_dyn{false};
VisitType_tvm::relay::IsDynamicVisitor454 void VisitType_(const TensorTypeNode* tt) {
455 for (auto dim : tt->shape) {
456 if (dim.as<tir::IntImmNode>() == nullptr) {
457 is_dyn = true;
458 break;
459 }
460 }
461 }
462 };
463
IsDynamic(const Type & ty)464 bool IsDynamic(const Type& ty) {
465 IsDynamicVisitor v;
466 v.VisitType(ty);
467 return v.is_dyn;
468 }
469
470 TVM_REGISTER_GLOBAL("relay.ir.IsDynamic").set_body_typed(IsDynamic);
471
IsDataDependant(const CallNode * call)472 bool IsDataDependant(const CallNode* call) {
473 static auto tshape_data_dependant = Op::GetAttrMap<TShapeDataDependant>("TShapeDataDependant");
474 Op op = Downcast<Op>(call->op);
475
476 if (!tshape_data_dependant.count(op)) {
477 return false;
478 }
479
480 if (op->name == "strided_slice") {
481 if (const auto* attrs = call->attrs.as<StridedSliceAttrs>()) {
482 if (attrs->begin && attrs->end && attrs->strides) {
483 // not data dependant if begin, end and strides exist
484 return false;
485 }
486 }
487 }
488
489 return tshape_data_dependant[op];
490 }
491 } // namespace relay
492 } // namespace tvm
493