1 #ifndef HALIDE_EXPR_H 2 #define HALIDE_EXPR_H 3 4 /** \file 5 * Base classes for Halide expressions (\ref Halide::Expr) and statements (\ref Halide::Internal::Stmt) 6 */ 7 8 #include <string> 9 #include <vector> 10 11 #include "IntrusivePtr.h" 12 #include "Type.h" 13 14 namespace Halide { 15 16 struct bfloat16_t; 17 struct float16_t; 18 19 namespace Internal { 20 21 class IRMutator; 22 class IRVisitor; 23 24 /** All our IR node types get unique IDs for the purposes of RTTI */ 25 enum class IRNodeType { 26 // Exprs, in order of strength. Code in IRMatch.h and the 27 // simplifier relies on this order for canonicalization of 28 // expressions, so you may need to update those modules if you 29 // change this list. 30 IntImm, 31 UIntImm, 32 FloatImm, 33 StringImm, 34 Broadcast, 35 Cast, 36 Variable, 37 Add, 38 Sub, 39 Mod, 40 Mul, 41 Div, 42 Min, 43 Max, 44 EQ, 45 NE, 46 LT, 47 LE, 48 GT, 49 GE, 50 And, 51 Or, 52 Not, 53 Select, 54 Load, 55 Ramp, 56 Call, 57 Let, 58 Shuffle, 59 VectorReduce, 60 // Stmts 61 LetStmt, 62 AssertStmt, 63 ProducerConsumer, 64 For, 65 Acquire, 66 Store, 67 Provide, 68 Allocate, 69 Free, 70 Realize, 71 Block, 72 Fork, 73 IfThenElse, 74 Evaluate, 75 Prefetch, 76 Atomic 77 }; 78 79 constexpr IRNodeType StrongestExprNodeType = IRNodeType::Shuffle; 80 81 /** The abstract base classes for a node in the Halide IR. */ 82 struct IRNode { 83 84 /** We use the visitor pattern to traverse IR nodes throughout the 85 * compiler, so we have a virtual accept method which accepts 86 * visitors. 87 */ 88 virtual void accept(IRVisitor *v) const = 0; IRNodeIRNode89 IRNode(IRNodeType t) 90 : node_type(t) { 91 } 92 virtual ~IRNode() = default; 93 94 /** These classes are all managed with intrusive reference 95 * counting, so we also track a reference count. It's mutable 96 * so that we can do reference counting even through const 97 * references to IR nodes. 98 */ 99 mutable RefCount ref_count; 100 101 /** Each IR node subclass has a unique identifier. We can compare 102 * these values to do runtime type identification. We don't 103 * compile with rtti because that injects run-time type 104 * identification stuff everywhere (and often breaks when linking 105 * external libraries compiled without it), and we only want it 106 * for IR nodes. One might want to put this value in the vtable, 107 * but that adds another level of indirection, and for Exprs we 108 * have 32 free bits in between the ref count and the Type 109 * anyway, so this doesn't increase the memory footprint of an IR node. 110 */ 111 IRNodeType node_type; 112 }; 113 114 template<> 115 inline RefCount &ref_count<IRNode>(const IRNode *t) noexcept { 116 return t->ref_count; 117 } 118 119 template<> 120 inline void destroy<IRNode>(const IRNode *t) { 121 delete t; 122 } 123 124 /** IR nodes are split into expressions and statements. These are 125 similar to expressions and statements in C - expressions 126 represent some value and have some type (e.g. x + 3), and 127 statements are side-effecting pieces of code that do not 128 represent a value (e.g. assert(x > 3)) */ 129 130 /** A base class for statement nodes. They have no properties or 131 methods beyond base IR nodes for now. */ 132 struct BaseStmtNode : public IRNode { BaseStmtNodeBaseStmtNode133 BaseStmtNode(IRNodeType t) 134 : IRNode(t) { 135 } 136 virtual Stmt mutate_stmt(IRMutator *v) const = 0; 137 }; 138 139 /** A base class for expression nodes. They all contain their types 140 * (e.g. Int(32), Float(32)) */ 141 struct BaseExprNode : public IRNode { BaseExprNodeBaseExprNode142 BaseExprNode(IRNodeType t) 143 : IRNode(t) { 144 } 145 virtual Expr mutate_expr(IRMutator *v) const = 0; 146 Type type; 147 }; 148 149 /** We use the "curiously recurring template pattern" to avoid 150 duplicated code in the IR Nodes. These classes live between the 151 abstract base classes and the actual IR Nodes in the 152 inheritance hierarchy. It provides an implementation of the 153 accept function necessary for the visitor pattern to work, and 154 a concrete instantiation of a unique IRNodeType per class. */ 155 template<typename T> 156 struct ExprNode : public BaseExprNode { 157 void accept(IRVisitor *v) const override; 158 Expr mutate_expr(IRMutator *v) const override; ExprNodeExprNode159 ExprNode() 160 : BaseExprNode(T::_node_type) { 161 } 162 ~ExprNode() override = default; 163 }; 164 165 template<typename T> 166 struct StmtNode : public BaseStmtNode { 167 void accept(IRVisitor *v) const override; 168 Stmt mutate_stmt(IRMutator *v) const override; StmtNodeStmtNode169 StmtNode() 170 : BaseStmtNode(T::_node_type) { 171 } 172 ~StmtNode() override = default; 173 }; 174 175 /** IR nodes are passed around opaque handles to them. This is a 176 base class for those handles. It manages the reference count, 177 and dispatches visitors. */ 178 struct IRHandle : public IntrusivePtr<const IRNode> { 179 HALIDE_ALWAYS_INLINE 180 IRHandle() = default; 181 182 HALIDE_ALWAYS_INLINE IRHandleIRHandle183 IRHandle(const IRNode *p) 184 : IntrusivePtr<const IRNode>(p) { 185 } 186 187 /** Dispatch to the correct visitor method for this node. E.g. if 188 * this node is actually an Add node, then this will call 189 * IRVisitor::visit(const Add *) */ acceptIRHandle190 void accept(IRVisitor *v) const { 191 ptr->accept(v); 192 } 193 194 /** Downcast this ir node to its actual type (e.g. Add, or 195 * Select). This returns nullptr if the node is not of the requested 196 * type. Example usage: 197 * 198 * if (const Add *add = node->as<Add>()) { 199 * // This is an add node 200 * } 201 */ 202 template<typename T> asIRHandle203 const T *as() const { 204 if (ptr && ptr->node_type == T::_node_type) { 205 return (const T *)ptr; 206 } 207 return nullptr; 208 } 209 node_typeIRHandle210 IRNodeType node_type() const { 211 return ptr->node_type; 212 } 213 }; 214 215 /** Integer constants */ 216 struct IntImm : public ExprNode<IntImm> { 217 int64_t value; 218 219 static const IntImm *make(Type t, int64_t value); 220 221 static const IRNodeType _node_type = IRNodeType::IntImm; 222 }; 223 224 /** Unsigned integer constants */ 225 struct UIntImm : public ExprNode<UIntImm> { 226 uint64_t value; 227 228 static const UIntImm *make(Type t, uint64_t value); 229 230 static const IRNodeType _node_type = IRNodeType::UIntImm; 231 }; 232 233 /** Floating point constants */ 234 struct FloatImm : public ExprNode<FloatImm> { 235 double value; 236 237 static const FloatImm *make(Type t, double value); 238 239 static const IRNodeType _node_type = IRNodeType::FloatImm; 240 }; 241 242 /** String constants */ 243 struct StringImm : public ExprNode<StringImm> { 244 std::string value; 245 246 static const StringImm *make(const std::string &val); 247 248 static const IRNodeType _node_type = IRNodeType::StringImm; 249 }; 250 251 } // namespace Internal 252 253 /** A fragment of Halide syntax. It's implemented as reference-counted 254 * handle to a concrete expression node, but it's immutable, so you 255 * can treat it as a value type. */ 256 struct Expr : public Internal::IRHandle { 257 /** Make an undefined expression */ 258 HALIDE_ALWAYS_INLINE 259 Expr() = default; 260 261 /** Make an expression from a concrete expression node pointer (e.g. Add) */ 262 HALIDE_ALWAYS_INLINE ExprExpr263 Expr(const Internal::BaseExprNode *n) 264 : IRHandle(n) { 265 } 266 267 /** Make an expression representing numeric constants of various types. */ 268 // @{ ExprExpr269 explicit Expr(int8_t x) 270 : IRHandle(Internal::IntImm::make(Int(8), x)) { 271 } ExprExpr272 explicit Expr(int16_t x) 273 : IRHandle(Internal::IntImm::make(Int(16), x)) { 274 } ExprExpr275 Expr(int32_t x) 276 : IRHandle(Internal::IntImm::make(Int(32), x)) { 277 } ExprExpr278 explicit Expr(int64_t x) 279 : IRHandle(Internal::IntImm::make(Int(64), x)) { 280 } ExprExpr281 explicit Expr(uint8_t x) 282 : IRHandle(Internal::UIntImm::make(UInt(8), x)) { 283 } ExprExpr284 explicit Expr(uint16_t x) 285 : IRHandle(Internal::UIntImm::make(UInt(16), x)) { 286 } ExprExpr287 explicit Expr(uint32_t x) 288 : IRHandle(Internal::UIntImm::make(UInt(32), x)) { 289 } ExprExpr290 explicit Expr(uint64_t x) 291 : IRHandle(Internal::UIntImm::make(UInt(64), x)) { 292 } ExprExpr293 Expr(float16_t x) 294 : IRHandle(Internal::FloatImm::make(Float(16), (double)x)) { 295 } ExprExpr296 Expr(bfloat16_t x) 297 : IRHandle(Internal::FloatImm::make(BFloat(16), (double)x)) { 298 } ExprExpr299 Expr(float x) 300 : IRHandle(Internal::FloatImm::make(Float(32), x)) { 301 } ExprExpr302 explicit Expr(double x) 303 : IRHandle(Internal::FloatImm::make(Float(64), x)) { 304 } 305 // @} 306 307 /** Make an expression representing a const string (i.e. a StringImm) */ ExprExpr308 Expr(const std::string &s) 309 : IRHandle(Internal::StringImm::make(s)) { 310 } 311 312 /** Override get() to return a BaseExprNode * instead of an IRNode * */ 313 HALIDE_ALWAYS_INLINE getExpr314 const Internal::BaseExprNode *get() const { 315 return (const Internal::BaseExprNode *)ptr; 316 } 317 318 /** Get the type of this expression node */ 319 HALIDE_ALWAYS_INLINE typeExpr320 Type type() const { 321 return get()->type; 322 } 323 }; 324 325 /** This lets you use an Expr as a key in a map of the form 326 * map<Expr, Foo, ExprCompare> */ 327 struct ExprCompare { operatorExprCompare328 bool operator()(const Expr &a, const Expr &b) const { 329 return a.get() < b.get(); 330 } 331 }; 332 333 /** A single-dimensional span. Includes all numbers between min and 334 * (min + extent - 1). */ 335 struct Range { 336 Expr min, extent; 337 338 Range() = default; 339 Range(const Expr &min_in, const Expr &extent_in); 340 }; 341 342 /** A multi-dimensional box. The outer product of the elements */ 343 typedef std::vector<Range> Region; 344 345 /** An enum describing different address spaces to be used with Func::store_in. */ 346 enum class MemoryType { 347 /** Let Halide select a storage type automatically */ 348 Auto, 349 350 /** Heap/global memory. Allocated using halide_malloc, or 351 * halide_device_malloc */ 352 Heap, 353 354 /** Stack memory. Allocated using alloca. Requires a constant 355 * size. Corresponds to per-thread local memory on the GPU. If all 356 * accesses are at constant coordinates, may be promoted into the 357 * register file at the discretion of the register allocator. */ 358 Stack, 359 360 /** Register memory. The allocation should be promoted into the 361 * register file. All stores must be at constant coordinates. May 362 * be spilled to the stack at the discretion of the register 363 * allocator. */ 364 Register, 365 366 /** Allocation is stored in GPU shared memory. Also known as 367 * "local" in OpenCL, and "threadgroup" in metal. Can be shared 368 * across GPU threads within the same block. */ 369 GPUShared, 370 371 /** Allocate Locked Cache Memory to act as local memory */ 372 LockedCache, 373 /** Vector Tightly Coupled Memory. HVX (Hexagon) local memory available on 374 * v65+. This memory has higher performance and lower power. Ideal for 375 * intermediate buffers. Necessary for vgather-vscatter instructions 376 * on Hexagon */ 377 VTCM, 378 }; 379 380 namespace Internal { 381 382 /** An enum describing a type of loop traversal. Used in schedules, 383 * and in the For loop IR node. Serial is a conventional ordered for 384 * loop. Iterations occur in increasing order, and each iteration must 385 * appear to have finished before the next begins. Parallel, GPUBlock, 386 * and GPUThread are parallel and unordered: iterations may occur in 387 * any order, and multiple iterations may occur 388 * simultaneously. Vectorized and GPULane are parallel and 389 * synchronous: they act as if all iterations occur at the same time 390 * in lockstep. */ 391 enum class ForType { 392 Serial, 393 Parallel, 394 Vectorized, 395 Unrolled, 396 Extern, 397 GPUBlock, 398 GPUThread, 399 GPULane, 400 }; 401 402 /** Check if for_type executes for loop iterations in parallel and unordered. */ 403 bool is_unordered_parallel(ForType for_type); 404 405 /** Returns true if for_type executes for loop iterations in parallel. */ 406 bool is_parallel(ForType for_type); 407 408 /** A reference-counted handle to a statement node. */ 409 struct Stmt : public IRHandle { 410 Stmt() = default; StmtStmt411 Stmt(const BaseStmtNode *n) 412 : IRHandle(n) { 413 } 414 415 /** Override get() to return a BaseStmtNode * instead of an IRNode * */ 416 HALIDE_ALWAYS_INLINE getStmt417 const BaseStmtNode *get() const { 418 return (const Internal::BaseStmtNode *)ptr; 419 } 420 421 /** This lets you use a Stmt as a key in a map of the form 422 * map<Stmt, Foo, Stmt::Compare> */ 423 struct Compare { operatorStmt::Compare424 bool operator()(const Stmt &a, const Stmt &b) const { 425 return a.ptr < b.ptr; 426 } 427 }; 428 }; 429 430 } // namespace Internal 431 } // namespace Halide 432 433 #endif 434