1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  *
22  * \file util.cc
23  *
24  * \brief Utility functions for Relay.
25  */
26 #include <tvm/ir/type_functor.h>
27 #include <tvm/relay/analysis.h>
28 #include <tvm/relay/attrs/algorithm.h>
29 #include <tvm/relay/expr_functor.h>
30 #include <tvm/relay/op.h>
31 #include <tvm/relay/op_attr_types.h>
32 #include <tvm/relay/pattern_functor.h>
33 
34 #include "../transforms/pass_util.h"
35 
36 namespace tvm {
37 namespace relay {
38 
39 template <typename T>
40 struct InsertionSet {
41   std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual> set;
42   std::vector<T> data;
Inserttvm::relay::InsertionSet43   void Insert(const T& t) {
44     if (set.count(t) == 0) {
45       set.insert(t);
46       data.push_back(t);
47     }
48   }
49 };
50 
51 class TypeVarTVisitor : public TypeVisitor {
52  public:
TypeVarTVisitor(InsertionSet<TypeVar> * type_vars,InsertionSet<TypeVar> * bound_type_vars)53   TypeVarTVisitor(InsertionSet<TypeVar>* type_vars, InsertionSet<TypeVar>* bound_type_vars)
54       : type_vars_(type_vars), bound_type_vars_(bound_type_vars) {}
55 
VisitType_(const TypeVarNode * tp)56   void VisitType_(const TypeVarNode* tp) final {
57     TypeVar var = GetRef<TypeVar>(tp);
58     type_vars_->Insert(var);
59   }
60 
VisitType_(const FuncTypeNode * f)61   void VisitType_(const FuncTypeNode* f) final {
62     for (auto type_param : f->type_params) {
63       type_vars_->Insert(type_param);
64       bound_type_vars_->Insert(type_param);
65     }
66     TypeVisitor::VisitType_(f);
67   }
68 
69  private:
70   InsertionSet<TypeVar>* type_vars_;
71   InsertionSet<TypeVar>* bound_type_vars_;
72 };
73 
74 class TypeVarEVisitor : private ExprVisitor {
75  public:
TypeVarEVisitor(const IRModule & mod)76   explicit TypeVarEVisitor(const IRModule& mod) : mod_(mod) {}
77 
CollectFree()78   Array<TypeVar> CollectFree() {
79     Array<TypeVar> ret;
80     for (const auto& v : type_vars_.data) {
81       if (bound_type_vars_.set.count(v) == 0) {
82         ret.push_back(v);
83       }
84     }
85     return ret;
86   }
87 
CollectBound()88   Array<TypeVar> CollectBound() {
89     Array<TypeVar> ret;
90     for (const auto& v : bound_type_vars_.data) {
91       ret.push_back(v);
92     }
93     return ret;
94   }
95 
CollectAll()96   Array<TypeVar> CollectAll() {
97     Array<TypeVar> ret;
98     for (const auto& v : type_vars_.data) {
99       ret.push_back(v);
100     }
101     return ret;
102   }
103 
Free(const Expr & expr)104   Array<TypeVar> Free(const Expr& expr) {
105     VisitExpr(expr);
106     return CollectFree();
107   }
108 
Free(const Type & type)109   Array<TypeVar> Free(const Type& type) {
110     VisitType(type);
111     return CollectFree();
112   }
113 
Bound(const Expr & expr)114   Array<TypeVar> Bound(const Expr& expr) {
115     VisitExpr(expr);
116     return CollectBound();
117   }
118 
Bound(const Type & type)119   Array<TypeVar> Bound(const Type& type) {
120     VisitType(type);
121     return CollectBound();
122   }
123 
All(const Expr & expr)124   Array<TypeVar> All(const Expr& expr) {
125     VisitExpr(expr);
126     return CollectAll();
127   }
128 
All(const Type & type)129   Array<TypeVar> All(const Type& type) {
130     VisitType(type);
131     return CollectAll();
132   }
133 
VisitExpr_(const FunctionNode * f)134   void VisitExpr_(const FunctionNode* f) final {
135     for (const auto& tp : f->type_params) {
136       type_vars_.Insert(tp);
137       bound_type_vars_.Insert(tp);
138     }
139     ExprVisitor::VisitExpr_(f);
140   }
141 
VisitExpr_(const ConstructorNode * cn)142   void VisitExpr_(const ConstructorNode* cn) final {
143     // for constructors, type vars will be bound in the module
144     auto data = mod_->LookupTypeDef(cn->belong_to);
145     for (const auto& tv : data->type_vars) {
146       type_vars_.Insert(tv);
147       bound_type_vars_.Insert(tv);
148     }
149     ExprVisitor::VisitExpr_(cn);
150   }
151 
VisitType(const Type & t)152   void VisitType(const Type& t) final {
153     TypeVarTVisitor(&type_vars_, &bound_type_vars_).VisitType(t);
154   }
155 
156  private:
157   InsertionSet<TypeVar> type_vars_;
158   InsertionSet<TypeVar> bound_type_vars_;
159   const IRModule& mod_;
160 };
161 
162 class VarVisitor : protected ExprVisitor, protected PatternVisitor {
163  public:
Free(const Expr & expr)164   Array<Var> Free(const Expr& expr) {
165     this->VisitExpr(expr);
166     Array<Var> ret;
167     for (const auto& v : vars_.data) {
168       if (bound_vars_.set.count(v) == 0) {
169         ret.push_back(v);
170       }
171     }
172     return ret;
173   }
174 
Collect()175   Array<Var> Collect() {
176     Array<Var> ret;
177     for (const auto& v : bound_vars_.data) {
178       ret.push_back(v);
179     }
180     return ret;
181   }
182 
Bound(const Expr & expr)183   Array<Var> Bound(const Expr& expr) {
184     this->VisitExpr(expr);
185     return Collect();
186   }
187 
Bound(const Pattern & pat)188   Array<Var> Bound(const Pattern& pat) {
189     this->VisitPattern(pat);
190     return Collect();
191   }
192 
All(const Expr & expr)193   Array<Var> All(const Expr& expr) {
194     this->VisitExpr(expr);
195     Array<Var> ret;
196     for (const auto& v : vars_.data) {
197       ret.push_back(v);
198     }
199     return ret;
200   }
201 
MarkBounded(const Var & v)202   void MarkBounded(const Var& v) {
203     bound_vars_.Insert(v);
204     vars_.Insert(v);
205   }
206 
VisitExpr_(const VarNode * var)207   void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef<Var>(var)); }
208 
VisitExpr_(const FunctionNode * op)209   void VisitExpr_(const FunctionNode* op) final {
210     for (const auto& param : op->params) {
211       MarkBounded(param);
212     }
213     VisitExpr(op->body);
214   }
215 
VisitExpr_(const LetNode * op)216   void VisitExpr_(const LetNode* op) final {
217     Expr let = GetRef<Let>(op);
218     while (auto let_node = let.as<LetNode>()) {
219       MarkBounded(let_node->var);
220       VisitExpr(let_node->value);
221       let = let_node->body;
222     }
223     VisitExpr(let);
224   }
225 
VisitPattern(const Pattern & p)226   void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); }
227 
VisitPattern_(const PatternVarNode * op)228   void VisitPattern_(const PatternVarNode* op) final { MarkBounded(op->var); }
229 
230  private:
231   InsertionSet<Var> vars_;
232   InsertionSet<Var> bound_vars_;
233 };
234 
FreeTypeVars(const Expr & expr,const IRModule & mod)235 tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const IRModule& mod) {
236   return TypeVarEVisitor(mod).Free(expr);
237 }
238 
FreeTypeVars(const Type & type,const IRModule & mod)239 tvm::Array<TypeVar> FreeTypeVars(const Type& type, const IRModule& mod) {
240   return TypeVarEVisitor(mod).Free(type);
241 }
242 
BoundTypeVars(const Expr & expr,const IRModule & mod)243 tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const IRModule& mod) {
244   return TypeVarEVisitor(mod).Bound(expr);
245 }
246 
BoundTypeVars(const Type & type,const IRModule & mod)247 tvm::Array<TypeVar> BoundTypeVars(const Type& type, const IRModule& mod) {
248   return TypeVarEVisitor(mod).Bound(type);
249 }
250 
AllTypeVars(const Expr & expr,const IRModule & mod)251 tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const IRModule& mod) {
252   return TypeVarEVisitor(mod).All(expr);
253 }
254 
AllTypeVars(const Type & type,const IRModule & mod)255 tvm::Array<TypeVar> AllTypeVars(const Type& type, const IRModule& mod) {
256   return TypeVarEVisitor(mod).All(type);
257 }
258 
FreeVars(const Expr & expr)259 tvm::Array<Var> FreeVars(const Expr& expr) { return VarVisitor().Free(expr); }
260 
BoundVars(const Expr & expr)261 tvm::Array<Var> BoundVars(const Expr& expr) { return VarVisitor().Bound(expr); }
262 
BoundVars(const Pattern & pat)263 tvm::Array<Var> BoundVars(const Pattern& pat) { return VarVisitor().Bound(pat); }
264 
AllVars(const Expr & expr)265 tvm::Array<Var> AllVars(const Expr& expr) { return VarVisitor().All(expr); }
266 
267 TVM_REGISTER_GLOBAL("relay.analysis.free_vars").set_body_typed(FreeVars);
268 
__anon383ebf080102(TVMArgs args, TVMRetValue* ret) 269 TVM_REGISTER_GLOBAL("relay.analysis.bound_vars").set_body([](TVMArgs args, TVMRetValue* ret) {
270   ObjectRef x = args[0];
271   if (x.as<ExprNode>()) {
272     *ret = BoundVars(Downcast<Expr>(x));
273   } else {
274     *ret = BoundVars(Downcast<Pattern>(x));
275   }
276 });
277 
278 TVM_REGISTER_GLOBAL("relay.analysis.all_vars").set_body_typed(AllVars);
279 
__anon383ebf080202(TVMArgs args, TVMRetValue* ret) 280 TVM_REGISTER_GLOBAL("relay.analysis.free_type_vars").set_body([](TVMArgs args, TVMRetValue* ret) {
281   ObjectRef x = args[0];
282   IRModule mod = args[1];
283   if (x.as<TypeNode>()) {
284     *ret = FreeTypeVars(Downcast<Type>(x), mod);
285   } else {
286     *ret = FreeTypeVars(Downcast<Expr>(x), mod);
287   }
288 });
289 
__anon383ebf080302(TVMArgs args, TVMRetValue* ret) 290 TVM_REGISTER_GLOBAL("relay.analysis.bound_type_vars").set_body([](TVMArgs args, TVMRetValue* ret) {
291   ObjectRef x = args[0];
292   IRModule mod = args[1];
293   if (x.as<TypeNode>()) {
294     *ret = BoundTypeVars(Downcast<Type>(x), mod);
295   } else {
296     *ret = BoundTypeVars(Downcast<Expr>(x), mod);
297   }
298 });
299 
__anon383ebf080402(TVMArgs args, TVMRetValue* ret) 300 TVM_REGISTER_GLOBAL("relay.analysis.all_type_vars").set_body([](TVMArgs args, TVMRetValue* ret) {
301   ObjectRef x = args[0];
302   IRModule mod = args[1];
303   if (x.as<TypeNode>()) {
304     *ret = AllTypeVars(Downcast<Type>(x), mod);
305   } else {
306     *ret = AllTypeVars(Downcast<Expr>(x), mod);
307   }
308 });
309 
310 class DtypeCollector : protected ExprVisitor, protected TypeVisitor {
311  public:
VisitExpr(const Expr & expr)312   void VisitExpr(const Expr& expr) final {
313     if (expr->checked_type_.defined()) {
314       TypeVisitor::VisitType(expr->checked_type());
315     }
316     ExprVisitor::VisitExpr(expr);
317   }
318 
VisitType_(const TensorTypeNode * op)319   void VisitType_(const TensorTypeNode* op) final { dtypes_.insert(DLDataType2String(op->dtype)); }
320 
All(const Expr & expr)321   Array<String> All(const Expr& expr) {
322     VisitExpr(expr);
323 
324     Array<String> res;
325     for (const auto& dtype : dtypes_) {
326       res.push_back(String(dtype));
327     }
328     return res;
329   }
330 
331  private:
332   std::unordered_set<std::string> dtypes_;
333 };
334 
AllDtypes(const Expr & expr)335 tvm::Array<String> AllDtypes(const Expr& expr) { return DtypeCollector().All(expr); }
336 
337 TVM_REGISTER_GLOBAL("relay.analysis.all_dtypes").set_body_typed(AllDtypes);
338 
339 /*!
340  * \brief Get reference counter of each internal ExprNode in body.
341  * \param body The body expression.
342  * \return The reference count mapping.
343  */
GetExprRefCount(const Expr & body)344 std::unordered_map<const Object*, size_t> GetExprRefCount(const Expr& body) {
345   class ExprRefCounter : private MixedModeVisitor {
346    public:
347     std::unordered_map<const Object*, size_t> Get(const Expr& body) {
348       this->VisitExpr(body);
349       return std::move(this->visit_counter_);
350     }
351   };
352   return ExprRefCounter().Get(body);
353 }
354 
355 template <typename T>
IsNDArrayAllGreaterEqual(const runtime::NDArray & tensor,T value)356 bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
357   CHECK_EQ(tensor->ctx.device_type, kDLCPU);
358   CHECK(tensor->strides == nullptr);
359   CHECK_EQ(tensor->byte_offset, 0);
360   const T* data = static_cast<const T*>(tensor->data);
361   int64_t num_elems = 1;
362   for (int i = 0; i < tensor->ndim; ++i) {
363     num_elems *= tensor->shape[i];
364   }
365 
366   for (int64_t i = 0; i < num_elems; i++) {
367     if (*data < value) {
368       return false;
369     }
370     data++;
371   }
372   return true;
373 }
374 
IsAllPositiveConstant(const Expr & expr)375 bool IsAllPositiveConstant(const Expr& expr) {
376   // Cache the operators that are checked recursively to reduce lookup overhead.
377   static const auto& expand_dims_op = Op::Get("expand_dims");
378   static const auto& reshape_op = Op::Get("reshape");
379   static const auto& transpose_op = Op::Get("transpose");
380   static const auto& squeeze_op = Op::Get("squeeze");
381 
382   // peel through a few common transform ops.
383   if (const auto* constant = expr.as<ConstantNode>()) {
384     const auto& tensor = constant->data;
385     const auto& dtype = tensor->dtype;
386     if (dtype.lanes != 1) {
387       return false;
388     } else if (dtype.code == kDLFloat && dtype.bits == 32) {
389       return IsNDArrayAllGreaterEqual<float>(tensor, 0);
390     } else if (dtype.code == kDLFloat && dtype.bits == 64) {
391       return IsNDArrayAllGreaterEqual<double>(tensor, 0);
392     } else if (dtype.code == kDLInt && dtype.bits == 8) {
393       return IsNDArrayAllGreaterEqual<int8_t>(tensor, 0);
394     } else if (dtype.code == kDLInt && dtype.bits == 32) {
395       return IsNDArrayAllGreaterEqual<int32_t>(tensor, 0);
396     } else if (dtype.code == kDLUInt && dtype.bits == 8) {
397       return IsNDArrayAllGreaterEqual<uint8_t>(tensor, 0);
398     } else if (dtype.code == kDLUInt && dtype.bits == 32) {
399       return IsNDArrayAllGreaterEqual<uint32_t>(tensor, 0);
400     } else {
401       return false;
402     }
403   } else if (const auto* op = expr.as<CallNode>()) {
404     // tail recursion.
405     if (op->op == expand_dims_op || op->op == reshape_op || op->op == transpose_op ||
406         op->op == squeeze_op) {
407       return IsAllPositiveConstant(op->args[0]);
408     } else {
409       return false;
410     }
411   } else {
412     return false;
413   }
414 }
415 
TypeSubst(const Type & type,const TypeVar & tvar,const Type & subst)416 Type TypeSubst(const Type& type, const TypeVar& tvar, const Type& subst) {
417   return TypeSubst(type, tvm::Map<TypeVar, Type>({{tvar, subst}}));
418 }
419 
TypeSubst(const Expr & expr,const TypeVar & tvar,const Type & subst)420 Expr TypeSubst(const Expr& expr, const TypeVar& tvar, const Type& subst) {
421   return TypeSubst(expr, tvm::Map<TypeVar, Type>({{tvar, subst}}));
422 }
423 
TypeSubst(const Type & type,const tvm::Map<TypeVar,Type> & subst_map)424 Type TypeSubst(const Type& type, const tvm::Map<TypeVar, Type>& subst_map) {
425   return Bind(type, subst_map);
426 }
427 
TypeSubst(const Expr & expr,const tvm::Map<TypeVar,Type> & subst_map)428 Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) {
429   class TypeSubstMutator : public ExprMutator, public PatternMutator {
430    public:
431     explicit TypeSubstMutator(const tvm::Map<TypeVar, Type>& subst_map) : subst_map_(subst_map) {}
432     Type VisitType(const Type& t) final { return TypeSubst(t, subst_map_); }
433     Var VisitVar(const Var& v) final { return Downcast<Var>(VisitExpr(v)); }
434 
435     Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); }
436 
437     Clause VisitClause(const Clause& c) final {
438       Pattern pat = VisitPattern(c->lhs);
439       return Clause(pat, VisitExpr(c->rhs));
440     }
441 
442    private:
443     const tvm::Map<TypeVar, Type>& subst_map_;
444   };
445   CHECK(WellFormed(expr));
446   auto ret = TypeSubstMutator(subst_map).VisitExpr(expr);
447   CHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size());
448   CHECK(WellFormed(ret));
449   return ret;
450 }
451 
452 struct IsDynamicVisitor : public TypeVisitor {
453   bool is_dyn{false};
VisitType_tvm::relay::IsDynamicVisitor454   void VisitType_(const TensorTypeNode* tt) {
455     for (auto dim : tt->shape) {
456       if (dim.as<tir::IntImmNode>() == nullptr) {
457         is_dyn = true;
458         break;
459       }
460     }
461   }
462 };
463 
IsDynamic(const Type & ty)464 bool IsDynamic(const Type& ty) {
465   IsDynamicVisitor v;
466   v.VisitType(ty);
467   return v.is_dyn;
468 }
469 
470 TVM_REGISTER_GLOBAL("relay.ir.IsDynamic").set_body_typed(IsDynamic);
471 
IsDataDependant(const CallNode * call)472 bool IsDataDependant(const CallNode* call) {
473   static auto tshape_data_dependant = Op::GetAttrMap<TShapeDataDependant>("TShapeDataDependant");
474   Op op = Downcast<Op>(call->op);
475 
476   if (!tshape_data_dependant.count(op)) {
477     return false;
478   }
479 
480   if (op->name == "strided_slice") {
481     if (const auto* attrs = call->attrs.as<StridedSliceAttrs>()) {
482       if (attrs->begin && attrs->end && attrs->strides) {
483         // not data dependant if begin, end and strides exist
484         return false;
485       }
486     }
487   }
488 
489   return tshape_data_dependant[op];
490 }
491 }  // namespace relay
492 }  // namespace tvm
493