1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file tvm/ir/expr.h
22  * \brief Base expr nodes in TVM.
23  */
24 #ifndef TVM_IR_EXPR_H_
25 #define TVM_IR_EXPR_H_
26 
27 #include <tvm/ir/span.h>
28 #include <tvm/ir/type.h>
29 #include <tvm/node/container.h>
30 #include <tvm/node/node.h>
31 #include <tvm/runtime/object.h>
32 
33 #include <algorithm>
34 #include <limits>
35 #include <string>
36 #include <type_traits>
37 
38 namespace tvm {
39 
40 using tvm::runtime::String;
41 
42 /*!
43  * \brief Base type of all the expressions.
44  * \sa Expr
45  */
46 class BaseExprNode : public Object {
47  public:
48   static constexpr const char* _type_key = "BaseExpr";
49   static constexpr const bool _type_has_method_sequal_reduce = true;
50   static constexpr const bool _type_has_method_shash_reduce = true;
51   static constexpr const uint32_t _type_child_slots = 58;
52   TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object);
53 };
54 
55 /*!
56  * \brief Managed reference to BaseExprNode.
57  * \sa BaseExprNode
58  */
59 class BaseExpr : public ObjectRef {
60  public:
61   TVM_DEFINE_OBJECT_REF_METHODS(BaseExpr, ObjectRef, BaseExprNode);
62 };
63 
64 /*!
65  * \brief Base node of all primitive expressions.
66  *
67  *  A primitive expression deals with low-level
68  *  POD data types and handles without
69  *  doing life-cycle management for objects.
70  *
71  *  PrimExpr is used in the low-level code
72  *  optimizations and integer analysis.
73  *
74  * \sa PrimExpr
75  */
76 class PrimExprNode : public BaseExprNode {
77  public:
78   /*!
79    * \brief The runtime data type of the primitive expression.
80    *
81    * runtime::DataType(dtype) provides coarse grained type information
82    * during compile time and runtime. It is eagerly built in
83    * PrimExpr expression construction and can be used for
84    * quick type checking.
85    *
86    * dtype is sufficient to decide the Type of the PrimExpr
87    * when it corresponds to POD value types such as i32.
88    *
89    * When dtype is DataType::Handle(), the expression could corresponds to
90    * a more fine-grained Type, and we can get the type by running lazy type inference.
91    */
92   DataType dtype;
93 
94   static constexpr const char* _type_key = "PrimExpr";
95   static constexpr const uint32_t _type_child_slots = 34;
96   TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode);
97 };
98 
99 /*!
100  * \brief Reference to PrimExprNode.
101  * \sa PrimExprNode
102  */
103 class PrimExpr : public BaseExpr {
104  public:
105   /*!
106    * \brief construct from integer.
107    * \param value The value to be constructed.
108    */
109   TVM_DLL PrimExpr(int32_t value);  // NOLINT(*)
110   /*!
111    * \brief construct from float.
112    * \param value The value to be constructed.
113    */
114   TVM_DLL PrimExpr(float value);  // NOLINT(*)
115 
116   /*! \return the data type of this expression. */
dtype()117   DataType dtype() const { return static_cast<const PrimExprNode*>(get())->dtype; }
118 
119   TVM_DEFINE_OBJECT_REF_METHODS(PrimExpr, BaseExpr, PrimExprNode);
120 
121  private:
122   // Internal function for conversion.
123   friend struct runtime::PackedFuncValueConverter<PrimExpr>;
124   TVM_DLL static PrimExpr FromObject_(ObjectRef ref);
125 };
126 
127 /*!
128  * \brief Base node of all non-primitive expressions.
129  *
130  * RelayExpr supports tensor types, functions and ADT as
131  * first class citizens. The life-cycle of the corresponding
132  * objects are implicitly managed by the language.
133  *
134  * \sa RelayExpr
135  */
136 class RelayExprNode : public BaseExprNode {
137  public:
138   /*!
139    * \brief Span that points to the original source code.
140    *        Reserved debug information.
141    */
142   mutable Span span;
143   /*!
144    * \brief Stores the result of type inference(type checking).
145    *
146    * \note This can be undefined before type inference.
147    *       This value is discarded during serialization.
148    */
149   mutable Type checked_type_ = Type(nullptr);
150   /*!
151    * \return The checked_type
152    */
153   inline const Type& checked_type() const;
154   /*!
155    * \brief Check if the inferred(checked) type of the Expr
156    *  is backed by a TTypeNode and return it.
157    *
158    * \note This function will thrown an error if the node type
159    *       of this Expr is not TTypeNode.
160    *
161    * \return The corresponding TTypeNode pointer.
162    * \tparam The specific TypeNode we look for.
163    */
164   template <typename TTypeNode>
165   inline const TTypeNode* type_as() const;
166 
167   static constexpr const char* _type_key = "RelayExpr";
168   static constexpr const uint32_t _type_child_slots = 22;
169   TVM_DECLARE_BASE_OBJECT_INFO(RelayExprNode, BaseExprNode);
170 };
171 
172 /*!
173  * \brief Managed reference to RelayExprNode.
174  * \sa RelayExprNode
175  */
176 class RelayExpr : public BaseExpr {
177  public:
178   TVM_DEFINE_OBJECT_REF_METHODS(RelayExpr, BaseExpr, RelayExprNode);
179 };
180 
181 class GlobalVar;
182 /*!
183  * \brief Global variable that lives in the top-level module.
184  *
185  * A GlobalVar only refers to function definitions.
186  * This is used to enable recursive calls between function.
187  *
188  * \sa GlobalVarNode
189  */
190 class GlobalVarNode : public RelayExprNode {
191  public:
192   /*! \brief The name of the variable, this only acts as a hint. */
193   String name_hint;
194 
195   void VisitAttrs(AttrVisitor* v) {
196     v->Visit("name_hint", &name_hint);
197     v->Visit("span", &span);
198     v->Visit("_checked_type_", &checked_type_);
199   }
200 
201   bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const {
202     // name matters for global var.
203     return equal(name_hint, other->name_hint) && equal.FreeVarEqualImpl(this, other);
204   }
205 
206   void SHashReduce(SHashReducer hash_reduce) const {
207     hash_reduce(name_hint);
208     hash_reduce.FreeVarHashImpl(this);
209   }
210 
211   static constexpr const char* _type_key = "GlobalVar";
212   TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode);
213 };
214 
215 /*!
216  * \brief Managed reference to GlobalVarNode.
217  * \sa GlobalVarNode
218  */
219 class GlobalVar : public RelayExpr {
220  public:
221   TVM_DLL explicit GlobalVar(String name_hint);
222 
223   TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode);
224 };
225 
226 // PrimExprs that are useful as runtime containers.
227 //
228 /*!
229  * \brief Constant integer literals in the program.
230  * \sa IntImm
231  */
232 class IntImmNode : public PrimExprNode {
233  public:
234   /*! \brief the Internal value. */
235   int64_t value;
236 
237   void VisitAttrs(AttrVisitor* v) {
238     v->Visit("dtype", &dtype);
239     v->Visit("value", &value);
240   }
241 
242   bool SEqualReduce(const IntImmNode* other, SEqualReducer equal) const {
243     return equal(dtype, other->dtype) && equal(value, other->value);
244   }
245 
246   void SHashReduce(SHashReducer hash_reduce) const {
247     hash_reduce(dtype);
248     hash_reduce(value);
249   }
250 
251   static constexpr const char* _type_key = "IntImm";
252   TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode);
253 };
254 
255 /*!
256  * \brief Managed reference class to IntImmNode.
257  *
258  * \sa IntImmNode
259  */
260 class IntImm : public PrimExpr {
261  public:
262   /*!
263    * \brief Constructor.
264    * \param dtype The data type of the value.
265    * \param value The internal value.
266    */
267   TVM_DLL IntImm(DataType dtype, int64_t value);
268 
269   TVM_DEFINE_OBJECT_REF_METHODS(IntImm, PrimExpr, IntImmNode);
270 };
271 
272 /*!
273  * \brief Constant floating point literals in the program.
274  * \sa FloatImm
275  */
276 class FloatImmNode : public PrimExprNode {
277  public:
278   /*! \brief The constant value content. */
279   double value;
280 
281   void VisitAttrs(AttrVisitor* v) {
282     v->Visit("dtype", &dtype);
283     v->Visit("value", &value);
284   }
285 
286   bool SEqualReduce(const FloatImmNode* other, SEqualReducer equal) const {
287     return equal(dtype, other->dtype) && equal(value, other->value);
288   }
289 
290   void SHashReduce(SHashReducer hash_reduce) const {
291     hash_reduce(dtype);
292     hash_reduce(value);
293   }
294 
295   static constexpr const char* _type_key = "FloatImm";
296   TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode);
297 };
298 
299 /*!
300  * \brief Managed reference class to FloatImmNode.
301  *
302  * \sa FloatImmNode
303  */
304 class FloatImm : public PrimExpr {
305  public:
306   /*!
307    * \brief Constructor.
308    * \param dtype The data type of the value.
309    * \param value The internal value.
310    */
311   TVM_DLL FloatImm(DataType dtype, double value);
312 
313   TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode);
314 };
315 
316 /*!
317  * \brief Boolean constant.
318  *
319  *  This reference type is useful to add additional compile-time
320  *  type checks and helper functions for Integer equal comparisons.
321  */
322 class Bool : public IntImm {
323  public:
324   explicit Bool(bool value) : IntImm(DataType::Bool(), value) {}
325   Bool operator!() const { return Bool((*this)->value == 0); }
326   operator bool() const { return (*this)->value != 0; }
327 
328   TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bool, IntImm, IntImmNode);
329 };
330 
331 // Overload operators to make sure we have the most fine grained types.
332 inline Bool operator||(const Bool& a, bool b) { return Bool(a.operator bool() || b); }
333 inline Bool operator||(bool a, const Bool& b) { return Bool(a || b.operator bool()); }
334 inline Bool operator||(const Bool& a, const Bool& b) {
335   return Bool(a.operator bool() || b.operator bool());
336 }
337 inline Bool operator&&(const Bool& a, bool b) { return Bool(a.operator bool() && b); }
338 inline Bool operator&&(bool a, const Bool& b) { return Bool(a && b.operator bool()); }
339 inline Bool operator&&(const Bool& a, const Bool& b) {
340   return Bool(a.operator bool() && b.operator bool());
341 }
342 
343 /*!
344  * \brief Container of constant int that adds more constructors.
345  *
346  * This is used to store and automate type check
347  * attributes that must be constant integer.
348  *
349  * \sa IntImm
350  */
351 class Integer : public IntImm {
352  public:
353   Integer() {}
354   /*!
355    * \brief constructor from node.
356    */
357   explicit Integer(ObjectPtr<Object> node) : IntImm(node) {}
358   /*!
359    * \brief Construct integer from int value.
360    */
361   Integer(int value) : IntImm(DataType::Int(32), value) {}  // NOLINT(*)
362   /*!
363    * \brief Construct integer from int imm.
364    * \param other The other value.
365    */
366   Integer(IntImm other) : IntImm(std::move(other)) {}  // NOLINT(*)
367   /*!
368    * \brief Constructor from enum
369    * \tparam Enum The enum type.
370    * \param value The enum value.
371    */
372   template <typename Enum, typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
373   explicit Integer(Enum value) : Integer(static_cast<int>(value)) {
374     static_assert(std::is_same<int, typename std::underlying_type<Enum>::type>::value,
375                   "declare enum to be enum int to use visitor");
376   }
377   /*!
378    * \brief Assign an expression to integer.
379    * \param other another expression.
380    */
381   Integer& operator=(const IntImm& other) {
382     data_ = ObjectRef::GetDataPtr<Object>(other);
383     return *this;
384   }
385   /*!
386    * \brief convert to int64_t
387    */
388   operator int64_t() const {
389     CHECK(data_ != nullptr) << " Trying to reference a null Integer";
390     return (*this)->value;
391   }
392   // comparators
393   Bool operator==(int other) const {
394     if (data_ == nullptr) return Bool(false);
395     return Bool((*this)->value == other);
396   }
397   Bool operator!=(int other) const { return !(*this == other); }
398   template <typename Enum, typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
399   Bool operator==(Enum other) const {
400     return *this == static_cast<int>(other);
401   }
402   template <typename Enum, typename = typename std::enable_if<std::is_enum<Enum>::value>::type>
403   Bool operator!=(Enum other) const {
404     return *this != static_cast<int>(other);
405   }
406 };
407 
408 /*! \brief range over one dimension */
409 class RangeNode : public Object {
410  public:
411   /*! \brief beginning of the node */
412   PrimExpr min;
413   /*! \brief the extend of range */
414   PrimExpr extent;
415   /*! \brief constructor */
416   RangeNode() {}
417   RangeNode(PrimExpr min, PrimExpr extent) : min(min), extent(extent) {}
418 
419   void VisitAttrs(AttrVisitor* v) {
420     v->Visit("min", &min);
421     v->Visit("extent", &extent);
422   }
423 
424   bool SEqualReduce(const RangeNode* other, SEqualReducer equal) const {
425     return equal(min, other->min) && equal(extent, other->extent);
426   }
427 
428   void SHashReduce(SHashReducer hash_reduce) const {
429     hash_reduce(min);
430     hash_reduce(extent);
431   }
432 
433   static constexpr const char* _type_key = "Range";
434   static constexpr const bool _type_has_method_sequal_reduce = true;
435   static constexpr const bool _type_has_method_shash_reduce = true;
436   TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object);
437 };
438 
439 /*! \brief Range constainer  */
440 class Range : public ObjectRef {
441  public:
442   /*!
443    * \brief constructor by begin and end
444    * \param begin The begin of the range.
445    * \param end The end of the range.
446    */
447   TVM_DLL Range(PrimExpr begin, PrimExpr end);
448   /*!
449    * \brief construct a new range with min and extent
450    *  The corresponding constructor is removed,
451    *  because that is counter convention of tradition meaning
452    *  of range(begin, end)
453    *
454    * \param min The minimum range.
455    * \param extent The extent of the range.
456    */
457   static Range FromMinExtent(PrimExpr min, PrimExpr extent);
458   // declare range.
459   TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode);
460 };
461 
462 // implementataions
463 inline const Type& RelayExprNode::checked_type() const {
464   CHECK(checked_type_.defined()) << "internal error: the type checker has "
465                                  << "not populated the checked_type "
466                                  << "field for " << GetRef<RelayExpr>(this);
467   return this->checked_type_;
468 }
469 
470 template <typename TTypeNode>
471 inline const TTypeNode* RelayExprNode::type_as() const {
472   static_assert(std::is_base_of<TypeNode, TTypeNode>::value,
473                 "TType must be a special case of type");
474   CHECK(checked_type_.defined())
475       << "Type inference for this Expr has not completed. Try to call infer_type pass.";
476   const TTypeNode* node = checked_type_.as<TTypeNode>();
477   CHECK(node != nullptr) << "Expected type to be " << TTypeNode::_type_key << ", but get "
478                          << checked_type_->GetTypeKey();
479   return node;
480 }
481 
482 }  // namespace tvm
483 
484 namespace tvm {
485 namespace runtime {
486 // common rule for RetValue and ArgValue
487 template <>
488 struct PackedFuncValueConverter<PrimExpr> {
489   static PrimExpr From(const TVMPODValue_& val) {
490     if (val.type_code() == kTVMNullptr) {
491       return PrimExpr(ObjectPtr<Object>(nullptr));
492     }
493     if (val.type_code() == kDLInt) {
494       return PrimExpr(val.operator int());
495     }
496     if (val.type_code() == kDLFloat) {
497       return PrimExpr(static_cast<float>(val.operator double()));
498     }
499 
500     return PrimExpr::FromObject_(val.AsObjectRef<ObjectRef>());
501   }
502 };
503 
504 template <>
505 struct PackedFuncValueConverter<tvm::Integer> {
506   static tvm::Integer From(const TVMPODValue_& val) {
507     if (val.type_code() == kTVMNullptr) {
508       return Integer(ObjectPtr<Object>(nullptr));
509     }
510     if (val.type_code() == kTVMArgInt) {
511       return Integer(val.operator int());
512     }
513     return val.AsObjectRef<tvm::Integer>();
514   }
515 };
516 
517 template <>
518 struct PackedFuncValueConverter<tvm::Bool> {
519   static tvm::Bool From(const TVMPODValue_& val) {
520     if (val.type_code() == kTVMNullptr) {
521       return Bool(ObjectPtr<Object>(nullptr));
522     }
523     if (val.type_code() == kTVMArgInt) {
524       int v = val.operator int();
525       CHECK(v == 0 || v == 1) << "ValueError: boolean value can only be 0 or 1, but get " << v;
526       return Bool(static_cast<bool>(v));
527     }
528     return val.AsObjectRef<tvm::Bool>();
529   }
530 };
531 
532 }  // namespace runtime
533 }  // namespace tvm
534 #endif  // TVM_IR_EXPR_H_
535