1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 /*! 21 * \file tvm/ir/expr.h 22 * \brief Base expr nodes in TVM. 23 */ 24 #ifndef TVM_IR_EXPR_H_ 25 #define TVM_IR_EXPR_H_ 26 27 #include <tvm/ir/span.h> 28 #include <tvm/ir/type.h> 29 #include <tvm/node/container.h> 30 #include <tvm/node/node.h> 31 #include <tvm/runtime/object.h> 32 33 #include <algorithm> 34 #include <limits> 35 #include <string> 36 #include <type_traits> 37 38 namespace tvm { 39 40 using tvm::runtime::String; 41 42 /*! 43 * \brief Base type of all the expressions. 44 * \sa Expr 45 */ 46 class BaseExprNode : public Object { 47 public: 48 static constexpr const char* _type_key = "BaseExpr"; 49 static constexpr const bool _type_has_method_sequal_reduce = true; 50 static constexpr const bool _type_has_method_shash_reduce = true; 51 static constexpr const uint32_t _type_child_slots = 58; 52 TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object); 53 }; 54 55 /*! 56 * \brief Managed reference to BaseExprNode. 57 * \sa BaseExprNode 58 */ 59 class BaseExpr : public ObjectRef { 60 public: 61 TVM_DEFINE_OBJECT_REF_METHODS(BaseExpr, ObjectRef, BaseExprNode); 62 }; 63 64 /*! 65 * \brief Base node of all primitive expressions. 66 * 67 * A primitive expression deals with low-level 68 * POD data types and handles without 69 * doing life-cycle management for objects. 70 * 71 * PrimExpr is used in the low-level code 72 * optimizations and integer analysis. 73 * 74 * \sa PrimExpr 75 */ 76 class PrimExprNode : public BaseExprNode { 77 public: 78 /*! 79 * \brief The runtime data type of the primitive expression. 80 * 81 * runtime::DataType(dtype) provides coarse grained type information 82 * during compile time and runtime. It is eagerly built in 83 * PrimExpr expression construction and can be used for 84 * quick type checking. 85 * 86 * dtype is sufficient to decide the Type of the PrimExpr 87 * when it corresponds to POD value types such as i32. 88 * 89 * When dtype is DataType::Handle(), the expression could corresponds to 90 * a more fine-grained Type, and we can get the type by running lazy type inference. 91 */ 92 DataType dtype; 93 94 static constexpr const char* _type_key = "PrimExpr"; 95 static constexpr const uint32_t _type_child_slots = 34; 96 TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode); 97 }; 98 99 /*! 100 * \brief Reference to PrimExprNode. 101 * \sa PrimExprNode 102 */ 103 class PrimExpr : public BaseExpr { 104 public: 105 /*! 106 * \brief construct from integer. 107 * \param value The value to be constructed. 108 */ 109 TVM_DLL PrimExpr(int32_t value); // NOLINT(*) 110 /*! 111 * \brief construct from float. 112 * \param value The value to be constructed. 113 */ 114 TVM_DLL PrimExpr(float value); // NOLINT(*) 115 116 /*! \return the data type of this expression. */ dtype()117 DataType dtype() const { return static_cast<const PrimExprNode*>(get())->dtype; } 118 119 TVM_DEFINE_OBJECT_REF_METHODS(PrimExpr, BaseExpr, PrimExprNode); 120 121 private: 122 // Internal function for conversion. 123 friend struct runtime::PackedFuncValueConverter<PrimExpr>; 124 TVM_DLL static PrimExpr FromObject_(ObjectRef ref); 125 }; 126 127 /*! 128 * \brief Base node of all non-primitive expressions. 129 * 130 * RelayExpr supports tensor types, functions and ADT as 131 * first class citizens. The life-cycle of the corresponding 132 * objects are implicitly managed by the language. 133 * 134 * \sa RelayExpr 135 */ 136 class RelayExprNode : public BaseExprNode { 137 public: 138 /*! 139 * \brief Span that points to the original source code. 140 * Reserved debug information. 141 */ 142 mutable Span span; 143 /*! 144 * \brief Stores the result of type inference(type checking). 145 * 146 * \note This can be undefined before type inference. 147 * This value is discarded during serialization. 148 */ 149 mutable Type checked_type_ = Type(nullptr); 150 /*! 151 * \return The checked_type 152 */ 153 inline const Type& checked_type() const; 154 /*! 155 * \brief Check if the inferred(checked) type of the Expr 156 * is backed by a TTypeNode and return it. 157 * 158 * \note This function will thrown an error if the node type 159 * of this Expr is not TTypeNode. 160 * 161 * \return The corresponding TTypeNode pointer. 162 * \tparam The specific TypeNode we look for. 163 */ 164 template <typename TTypeNode> 165 inline const TTypeNode* type_as() const; 166 167 static constexpr const char* _type_key = "RelayExpr"; 168 static constexpr const uint32_t _type_child_slots = 22; 169 TVM_DECLARE_BASE_OBJECT_INFO(RelayExprNode, BaseExprNode); 170 }; 171 172 /*! 173 * \brief Managed reference to RelayExprNode. 174 * \sa RelayExprNode 175 */ 176 class RelayExpr : public BaseExpr { 177 public: 178 TVM_DEFINE_OBJECT_REF_METHODS(RelayExpr, BaseExpr, RelayExprNode); 179 }; 180 181 class GlobalVar; 182 /*! 183 * \brief Global variable that lives in the top-level module. 184 * 185 * A GlobalVar only refers to function definitions. 186 * This is used to enable recursive calls between function. 187 * 188 * \sa GlobalVarNode 189 */ 190 class GlobalVarNode : public RelayExprNode { 191 public: 192 /*! \brief The name of the variable, this only acts as a hint. */ 193 String name_hint; 194 195 void VisitAttrs(AttrVisitor* v) { 196 v->Visit("name_hint", &name_hint); 197 v->Visit("span", &span); 198 v->Visit("_checked_type_", &checked_type_); 199 } 200 201 bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const { 202 // name matters for global var. 203 return equal(name_hint, other->name_hint) && equal.FreeVarEqualImpl(this, other); 204 } 205 206 void SHashReduce(SHashReducer hash_reduce) const { 207 hash_reduce(name_hint); 208 hash_reduce.FreeVarHashImpl(this); 209 } 210 211 static constexpr const char* _type_key = "GlobalVar"; 212 TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode); 213 }; 214 215 /*! 216 * \brief Managed reference to GlobalVarNode. 217 * \sa GlobalVarNode 218 */ 219 class GlobalVar : public RelayExpr { 220 public: 221 TVM_DLL explicit GlobalVar(String name_hint); 222 223 TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode); 224 }; 225 226 // PrimExprs that are useful as runtime containers. 227 // 228 /*! 229 * \brief Constant integer literals in the program. 230 * \sa IntImm 231 */ 232 class IntImmNode : public PrimExprNode { 233 public: 234 /*! \brief the Internal value. */ 235 int64_t value; 236 237 void VisitAttrs(AttrVisitor* v) { 238 v->Visit("dtype", &dtype); 239 v->Visit("value", &value); 240 } 241 242 bool SEqualReduce(const IntImmNode* other, SEqualReducer equal) const { 243 return equal(dtype, other->dtype) && equal(value, other->value); 244 } 245 246 void SHashReduce(SHashReducer hash_reduce) const { 247 hash_reduce(dtype); 248 hash_reduce(value); 249 } 250 251 static constexpr const char* _type_key = "IntImm"; 252 TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode); 253 }; 254 255 /*! 256 * \brief Managed reference class to IntImmNode. 257 * 258 * \sa IntImmNode 259 */ 260 class IntImm : public PrimExpr { 261 public: 262 /*! 263 * \brief Constructor. 264 * \param dtype The data type of the value. 265 * \param value The internal value. 266 */ 267 TVM_DLL IntImm(DataType dtype, int64_t value); 268 269 TVM_DEFINE_OBJECT_REF_METHODS(IntImm, PrimExpr, IntImmNode); 270 }; 271 272 /*! 273 * \brief Constant floating point literals in the program. 274 * \sa FloatImm 275 */ 276 class FloatImmNode : public PrimExprNode { 277 public: 278 /*! \brief The constant value content. */ 279 double value; 280 281 void VisitAttrs(AttrVisitor* v) { 282 v->Visit("dtype", &dtype); 283 v->Visit("value", &value); 284 } 285 286 bool SEqualReduce(const FloatImmNode* other, SEqualReducer equal) const { 287 return equal(dtype, other->dtype) && equal(value, other->value); 288 } 289 290 void SHashReduce(SHashReducer hash_reduce) const { 291 hash_reduce(dtype); 292 hash_reduce(value); 293 } 294 295 static constexpr const char* _type_key = "FloatImm"; 296 TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode); 297 }; 298 299 /*! 300 * \brief Managed reference class to FloatImmNode. 301 * 302 * \sa FloatImmNode 303 */ 304 class FloatImm : public PrimExpr { 305 public: 306 /*! 307 * \brief Constructor. 308 * \param dtype The data type of the value. 309 * \param value The internal value. 310 */ 311 TVM_DLL FloatImm(DataType dtype, double value); 312 313 TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode); 314 }; 315 316 /*! 317 * \brief Boolean constant. 318 * 319 * This reference type is useful to add additional compile-time 320 * type checks and helper functions for Integer equal comparisons. 321 */ 322 class Bool : public IntImm { 323 public: 324 explicit Bool(bool value) : IntImm(DataType::Bool(), value) {} 325 Bool operator!() const { return Bool((*this)->value == 0); } 326 operator bool() const { return (*this)->value != 0; } 327 328 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bool, IntImm, IntImmNode); 329 }; 330 331 // Overload operators to make sure we have the most fine grained types. 332 inline Bool operator||(const Bool& a, bool b) { return Bool(a.operator bool() || b); } 333 inline Bool operator||(bool a, const Bool& b) { return Bool(a || b.operator bool()); } 334 inline Bool operator||(const Bool& a, const Bool& b) { 335 return Bool(a.operator bool() || b.operator bool()); 336 } 337 inline Bool operator&&(const Bool& a, bool b) { return Bool(a.operator bool() && b); } 338 inline Bool operator&&(bool a, const Bool& b) { return Bool(a && b.operator bool()); } 339 inline Bool operator&&(const Bool& a, const Bool& b) { 340 return Bool(a.operator bool() && b.operator bool()); 341 } 342 343 /*! 344 * \brief Container of constant int that adds more constructors. 345 * 346 * This is used to store and automate type check 347 * attributes that must be constant integer. 348 * 349 * \sa IntImm 350 */ 351 class Integer : public IntImm { 352 public: 353 Integer() {} 354 /*! 355 * \brief constructor from node. 356 */ 357 explicit Integer(ObjectPtr<Object> node) : IntImm(node) {} 358 /*! 359 * \brief Construct integer from int value. 360 */ 361 Integer(int value) : IntImm(DataType::Int(32), value) {} // NOLINT(*) 362 /*! 363 * \brief Construct integer from int imm. 364 * \param other The other value. 365 */ 366 Integer(IntImm other) : IntImm(std::move(other)) {} // NOLINT(*) 367 /*! 368 * \brief Constructor from enum 369 * \tparam Enum The enum type. 370 * \param value The enum value. 371 */ 372 template <typename Enum, typename = typename std::enable_if<std::is_enum<Enum>::value>::type> 373 explicit Integer(Enum value) : Integer(static_cast<int>(value)) { 374 static_assert(std::is_same<int, typename std::underlying_type<Enum>::type>::value, 375 "declare enum to be enum int to use visitor"); 376 } 377 /*! 378 * \brief Assign an expression to integer. 379 * \param other another expression. 380 */ 381 Integer& operator=(const IntImm& other) { 382 data_ = ObjectRef::GetDataPtr<Object>(other); 383 return *this; 384 } 385 /*! 386 * \brief convert to int64_t 387 */ 388 operator int64_t() const { 389 CHECK(data_ != nullptr) << " Trying to reference a null Integer"; 390 return (*this)->value; 391 } 392 // comparators 393 Bool operator==(int other) const { 394 if (data_ == nullptr) return Bool(false); 395 return Bool((*this)->value == other); 396 } 397 Bool operator!=(int other) const { return !(*this == other); } 398 template <typename Enum, typename = typename std::enable_if<std::is_enum<Enum>::value>::type> 399 Bool operator==(Enum other) const { 400 return *this == static_cast<int>(other); 401 } 402 template <typename Enum, typename = typename std::enable_if<std::is_enum<Enum>::value>::type> 403 Bool operator!=(Enum other) const { 404 return *this != static_cast<int>(other); 405 } 406 }; 407 408 /*! \brief range over one dimension */ 409 class RangeNode : public Object { 410 public: 411 /*! \brief beginning of the node */ 412 PrimExpr min; 413 /*! \brief the extend of range */ 414 PrimExpr extent; 415 /*! \brief constructor */ 416 RangeNode() {} 417 RangeNode(PrimExpr min, PrimExpr extent) : min(min), extent(extent) {} 418 419 void VisitAttrs(AttrVisitor* v) { 420 v->Visit("min", &min); 421 v->Visit("extent", &extent); 422 } 423 424 bool SEqualReduce(const RangeNode* other, SEqualReducer equal) const { 425 return equal(min, other->min) && equal(extent, other->extent); 426 } 427 428 void SHashReduce(SHashReducer hash_reduce) const { 429 hash_reduce(min); 430 hash_reduce(extent); 431 } 432 433 static constexpr const char* _type_key = "Range"; 434 static constexpr const bool _type_has_method_sequal_reduce = true; 435 static constexpr const bool _type_has_method_shash_reduce = true; 436 TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object); 437 }; 438 439 /*! \brief Range constainer */ 440 class Range : public ObjectRef { 441 public: 442 /*! 443 * \brief constructor by begin and end 444 * \param begin The begin of the range. 445 * \param end The end of the range. 446 */ 447 TVM_DLL Range(PrimExpr begin, PrimExpr end); 448 /*! 449 * \brief construct a new range with min and extent 450 * The corresponding constructor is removed, 451 * because that is counter convention of tradition meaning 452 * of range(begin, end) 453 * 454 * \param min The minimum range. 455 * \param extent The extent of the range. 456 */ 457 static Range FromMinExtent(PrimExpr min, PrimExpr extent); 458 // declare range. 459 TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode); 460 }; 461 462 // implementataions 463 inline const Type& RelayExprNode::checked_type() const { 464 CHECK(checked_type_.defined()) << "internal error: the type checker has " 465 << "not populated the checked_type " 466 << "field for " << GetRef<RelayExpr>(this); 467 return this->checked_type_; 468 } 469 470 template <typename TTypeNode> 471 inline const TTypeNode* RelayExprNode::type_as() const { 472 static_assert(std::is_base_of<TypeNode, TTypeNode>::value, 473 "TType must be a special case of type"); 474 CHECK(checked_type_.defined()) 475 << "Type inference for this Expr has not completed. Try to call infer_type pass."; 476 const TTypeNode* node = checked_type_.as<TTypeNode>(); 477 CHECK(node != nullptr) << "Expected type to be " << TTypeNode::_type_key << ", but get " 478 << checked_type_->GetTypeKey(); 479 return node; 480 } 481 482 } // namespace tvm 483 484 namespace tvm { 485 namespace runtime { 486 // common rule for RetValue and ArgValue 487 template <> 488 struct PackedFuncValueConverter<PrimExpr> { 489 static PrimExpr From(const TVMPODValue_& val) { 490 if (val.type_code() == kTVMNullptr) { 491 return PrimExpr(ObjectPtr<Object>(nullptr)); 492 } 493 if (val.type_code() == kDLInt) { 494 return PrimExpr(val.operator int()); 495 } 496 if (val.type_code() == kDLFloat) { 497 return PrimExpr(static_cast<float>(val.operator double())); 498 } 499 500 return PrimExpr::FromObject_(val.AsObjectRef<ObjectRef>()); 501 } 502 }; 503 504 template <> 505 struct PackedFuncValueConverter<tvm::Integer> { 506 static tvm::Integer From(const TVMPODValue_& val) { 507 if (val.type_code() == kTVMNullptr) { 508 return Integer(ObjectPtr<Object>(nullptr)); 509 } 510 if (val.type_code() == kTVMArgInt) { 511 return Integer(val.operator int()); 512 } 513 return val.AsObjectRef<tvm::Integer>(); 514 } 515 }; 516 517 template <> 518 struct PackedFuncValueConverter<tvm::Bool> { 519 static tvm::Bool From(const TVMPODValue_& val) { 520 if (val.type_code() == kTVMNullptr) { 521 return Bool(ObjectPtr<Object>(nullptr)); 522 } 523 if (val.type_code() == kTVMArgInt) { 524 int v = val.operator int(); 525 CHECK(v == 0 || v == 1) << "ValueError: boolean value can only be 0 or 1, but get " << v; 526 return Bool(static_cast<bool>(v)); 527 } 528 return val.AsObjectRef<tvm::Bool>(); 529 } 530 }; 531 532 } // namespace runtime 533 } // namespace tvm 534 #endif // TVM_IR_EXPR_H_ 535