1 #ifndef HALIDE_EXPR_H
2 #define HALIDE_EXPR_H
3 
4 /** \file
5  * Base classes for Halide expressions (\ref Halide::Expr) and statements (\ref Halide::Internal::Stmt)
6  */
7 
8 #include <string>
9 #include <vector>
10 
11 #include "IntrusivePtr.h"
12 #include "Type.h"
13 
14 namespace Halide {
15 
16 struct bfloat16_t;
17 struct float16_t;
18 
19 namespace Internal {
20 
21 class IRMutator;
22 class IRVisitor;
23 
/** Every IR node type gets a unique ID, which is used for run-time
 * type identification of IR nodes. */
enum class IRNodeType {
    // Exprs, in order of strength. Code in IRMatch.h and the
    // simplifier relies on this order for canonicalization of
    // expressions, so you may need to update those modules if you
    // change this list. If you append a new Expr type, also update
    // StrongestExprNodeType below.
    IntImm,
    UIntImm,
    FloatImm,
    StringImm,
    Broadcast,
    Cast,
    Variable,
    Add,
    Sub,
    Mod,
    Mul,
    Div,
    Min,
    Max,
    EQ,
    NE,
    LT,
    LE,
    GT,
    GE,
    And,
    Or,
    Not,
    Select,
    Load,
    Ramp,
    Call,
    Let,
    Shuffle,
    VectorReduce,
    // Stmts
    LetStmt,
    AssertStmt,
    ProducerConsumer,
    For,
    Acquire,
    Store,
    Provide,
    Allocate,
    Free,
    Realize,
    Block,
    Fork,
    IfThenElse,
    Evaluate,
    Prefetch,
    Atomic
};

/** The strongest Expr node type, i.e. the last Expr entry in
 * IRNodeType. Every enumerator after it is a Stmt. */
constexpr IRNodeType StrongestExprNodeType = IRNodeType::VectorReduce;

// Guard against the constant drifting out of sync with the enum: the
// strongest Expr must sit immediately before the first Stmt.
static_assert(static_cast<int>(StrongestExprNodeType) + 1 == static_cast<int>(IRNodeType::LetStmt),
              "StrongestExprNodeType must be the last Expr entry in IRNodeType");
80 
81 /** The abstract base classes for a node in the Halide IR. */
82 struct IRNode {
83 
84     /** We use the visitor pattern to traverse IR nodes throughout the
85      * compiler, so we have a virtual accept method which accepts
86      * visitors.
87      */
88     virtual void accept(IRVisitor *v) const = 0;
IRNodeIRNode89     IRNode(IRNodeType t)
90         : node_type(t) {
91     }
92     virtual ~IRNode() = default;
93 
94     /** These classes are all managed with intrusive reference
95      * counting, so we also track a reference count. It's mutable
96      * so that we can do reference counting even through const
97      * references to IR nodes.
98      */
99     mutable RefCount ref_count;
100 
101     /** Each IR node subclass has a unique identifier. We can compare
102      * these values to do runtime type identification. We don't
103      * compile with rtti because that injects run-time type
104      * identification stuff everywhere (and often breaks when linking
105      * external libraries compiled without it), and we only want it
106      * for IR nodes. One might want to put this value in the vtable,
107      * but that adds another level of indirection, and for Exprs we
108      * have 32 free bits in between the ref count and the Type
109      * anyway, so this doesn't increase the memory footprint of an IR node.
110      */
111     IRNodeType node_type;
112 };
113 
114 template<>
115 inline RefCount &ref_count<IRNode>(const IRNode *t) noexcept {
116     return t->ref_count;
117 }
118 
119 template<>
120 inline void destroy<IRNode>(const IRNode *t) {
121     delete t;
122 }
123 
124 /** IR nodes are split into expressions and statements. These are
125    similar to expressions and statements in C - expressions
126    represent some value and have some type (e.g. x + 3), and
127    statements are side-effecting pieces of code that do not
128    represent a value (e.g. assert(x > 3)) */
129 
130 /** A base class for statement nodes. They have no properties or
131    methods beyond base IR nodes for now. */
132 struct BaseStmtNode : public IRNode {
BaseStmtNodeBaseStmtNode133     BaseStmtNode(IRNodeType t)
134         : IRNode(t) {
135     }
136     virtual Stmt mutate_stmt(IRMutator *v) const = 0;
137 };
138 
139 /** A base class for expression nodes. They all contain their types
140  * (e.g. Int(32), Float(32)) */
141 struct BaseExprNode : public IRNode {
BaseExprNodeBaseExprNode142     BaseExprNode(IRNodeType t)
143         : IRNode(t) {
144     }
145     virtual Expr mutate_expr(IRMutator *v) const = 0;
146     Type type;
147 };
148 
149 /** We use the "curiously recurring template pattern" to avoid
150    duplicated code in the IR Nodes. These classes live between the
151    abstract base classes and the actual IR Nodes in the
152    inheritance hierarchy. It provides an implementation of the
153    accept function necessary for the visitor pattern to work, and
154    a concrete instantiation of a unique IRNodeType per class. */
155 template<typename T>
156 struct ExprNode : public BaseExprNode {
157     void accept(IRVisitor *v) const override;
158     Expr mutate_expr(IRMutator *v) const override;
ExprNodeExprNode159     ExprNode()
160         : BaseExprNode(T::_node_type) {
161     }
162     ~ExprNode() override = default;
163 };
164 
165 template<typename T>
166 struct StmtNode : public BaseStmtNode {
167     void accept(IRVisitor *v) const override;
168     Stmt mutate_stmt(IRMutator *v) const override;
StmtNodeStmtNode169     StmtNode()
170         : BaseStmtNode(T::_node_type) {
171     }
172     ~StmtNode() override = default;
173 };
174 
175 /** IR nodes are passed around opaque handles to them. This is a
176    base class for those handles. It manages the reference count,
177    and dispatches visitors. */
178 struct IRHandle : public IntrusivePtr<const IRNode> {
179     HALIDE_ALWAYS_INLINE
180     IRHandle() = default;
181 
182     HALIDE_ALWAYS_INLINE
IRHandleIRHandle183     IRHandle(const IRNode *p)
184         : IntrusivePtr<const IRNode>(p) {
185     }
186 
187     /** Dispatch to the correct visitor method for this node. E.g. if
188      * this node is actually an Add node, then this will call
189      * IRVisitor::visit(const Add *) */
acceptIRHandle190     void accept(IRVisitor *v) const {
191         ptr->accept(v);
192     }
193 
194     /** Downcast this ir node to its actual type (e.g. Add, or
195      * Select). This returns nullptr if the node is not of the requested
196      * type. Example usage:
197      *
198      * if (const Add *add = node->as<Add>()) {
199      *   // This is an add node
200      * }
201      */
202     template<typename T>
asIRHandle203     const T *as() const {
204         if (ptr && ptr->node_type == T::_node_type) {
205             return (const T *)ptr;
206         }
207         return nullptr;
208     }
209 
node_typeIRHandle210     IRNodeType node_type() const {
211         return ptr->node_type;
212     }
213 };
214 
215 /** Integer constants */
216 struct IntImm : public ExprNode<IntImm> {
217     int64_t value;
218 
219     static const IntImm *make(Type t, int64_t value);
220 
221     static const IRNodeType _node_type = IRNodeType::IntImm;
222 };
223 
224 /** Unsigned integer constants */
225 struct UIntImm : public ExprNode<UIntImm> {
226     uint64_t value;
227 
228     static const UIntImm *make(Type t, uint64_t value);
229 
230     static const IRNodeType _node_type = IRNodeType::UIntImm;
231 };
232 
233 /** Floating point constants */
234 struct FloatImm : public ExprNode<FloatImm> {
235     double value;
236 
237     static const FloatImm *make(Type t, double value);
238 
239     static const IRNodeType _node_type = IRNodeType::FloatImm;
240 };
241 
242 /** String constants */
243 struct StringImm : public ExprNode<StringImm> {
244     std::string value;
245 
246     static const StringImm *make(const std::string &val);
247 
248     static const IRNodeType _node_type = IRNodeType::StringImm;
249 };
250 
251 }  // namespace Internal
252 
253 /** A fragment of Halide syntax. It's implemented as reference-counted
254  * handle to a concrete expression node, but it's immutable, so you
255  * can treat it as a value type. */
256 struct Expr : public Internal::IRHandle {
257     /** Make an undefined expression */
258     HALIDE_ALWAYS_INLINE
259     Expr() = default;
260 
261     /** Make an expression from a concrete expression node pointer (e.g. Add) */
262     HALIDE_ALWAYS_INLINE
ExprExpr263     Expr(const Internal::BaseExprNode *n)
264         : IRHandle(n) {
265     }
266 
267     /** Make an expression representing numeric constants of various types. */
268     // @{
ExprExpr269     explicit Expr(int8_t x)
270         : IRHandle(Internal::IntImm::make(Int(8), x)) {
271     }
ExprExpr272     explicit Expr(int16_t x)
273         : IRHandle(Internal::IntImm::make(Int(16), x)) {
274     }
ExprExpr275     Expr(int32_t x)
276         : IRHandle(Internal::IntImm::make(Int(32), x)) {
277     }
ExprExpr278     explicit Expr(int64_t x)
279         : IRHandle(Internal::IntImm::make(Int(64), x)) {
280     }
ExprExpr281     explicit Expr(uint8_t x)
282         : IRHandle(Internal::UIntImm::make(UInt(8), x)) {
283     }
ExprExpr284     explicit Expr(uint16_t x)
285         : IRHandle(Internal::UIntImm::make(UInt(16), x)) {
286     }
ExprExpr287     explicit Expr(uint32_t x)
288         : IRHandle(Internal::UIntImm::make(UInt(32), x)) {
289     }
ExprExpr290     explicit Expr(uint64_t x)
291         : IRHandle(Internal::UIntImm::make(UInt(64), x)) {
292     }
ExprExpr293     Expr(float16_t x)
294         : IRHandle(Internal::FloatImm::make(Float(16), (double)x)) {
295     }
ExprExpr296     Expr(bfloat16_t x)
297         : IRHandle(Internal::FloatImm::make(BFloat(16), (double)x)) {
298     }
ExprExpr299     Expr(float x)
300         : IRHandle(Internal::FloatImm::make(Float(32), x)) {
301     }
ExprExpr302     explicit Expr(double x)
303         : IRHandle(Internal::FloatImm::make(Float(64), x)) {
304     }
305     // @}
306 
307     /** Make an expression representing a const string (i.e. a StringImm) */
ExprExpr308     Expr(const std::string &s)
309         : IRHandle(Internal::StringImm::make(s)) {
310     }
311 
312     /** Override get() to return a BaseExprNode * instead of an IRNode * */
313     HALIDE_ALWAYS_INLINE
getExpr314     const Internal::BaseExprNode *get() const {
315         return (const Internal::BaseExprNode *)ptr;
316     }
317 
318     /** Get the type of this expression node */
319     HALIDE_ALWAYS_INLINE
typeExpr320     Type type() const {
321         return get()->type;
322     }
323 };
324 
325 /** This lets you use an Expr as a key in a map of the form
326  * map<Expr, Foo, ExprCompare> */
327 struct ExprCompare {
operatorExprCompare328     bool operator()(const Expr &a, const Expr &b) const {
329         return a.get() < b.get();
330     }
331 };
332 
333 /** A single-dimensional span. Includes all numbers between min and
334  * (min + extent - 1). */
335 struct Range {
336     Expr min, extent;
337 
338     Range() = default;
339     Range(const Expr &min_in, const Expr &extent_in);
340 };
341 
342 /** A multi-dimensional box. The outer product of the elements */
343 typedef std::vector<Range> Region;
344 
/** The different address spaces that can be requested with
 * Func::store_in. */
enum class MemoryType {
    /** Halide picks a storage type automatically. */
    Auto,

    /** Heap/global memory, allocated using halide_malloc or
     * halide_device_malloc. */
    Heap,

    /** Stack memory, allocated using alloca. The size must be
     * constant. On the GPU this corresponds to per-thread local
     * memory. When every access is at constant coordinates, the
     * register allocator may choose to promote it into the register
     * file. */
    Stack,

    /** Register memory: the allocation should be promoted into the
     * register file. Every store must be at constant coordinates.
     * The register allocator may choose to spill it to the stack. */
    Register,

    /** GPU shared memory, called "local" in OpenCL and "threadgroup"
     * in metal. It can be shared across the GPU threads of a single
     * block. */
    GPUShared,

    /** Allocate Locked Cache Memory to act as local memory */
    LockedCache,

    /** Vector Tightly Coupled Memory: HVX (Hexagon) local memory,
     * available on v65+. It offers higher performance at lower
     * power, which makes it ideal for intermediate buffers, and it
     * is necessary for vgather-vscatter instructions on Hexagon. */
    VTCM,
};
379 
380 namespace Internal {
381 
/** The ways a loop can be traversed. Used in schedules, and in the
 * For loop IR node.
 *
 * - Serial is a conventional ordered for loop: iterations run in
 *   increasing order, and each one must appear to have finished
 *   before the next begins.
 * - Parallel, GPUBlock, and GPUThread are parallel and unordered:
 *   iterations may run in any order, and several may run
 *   simultaneously.
 * - Vectorized and GPULane are parallel and synchronous: they behave
 *   as if all iterations happen at the same time, in lockstep. */
enum class ForType {
    Serial,
    Parallel,
    Vectorized,
    Unrolled,
    Extern,
    GPUBlock,
    GPUThread,
    GPULane,
};
401 
402 /** Check if for_type executes for loop iterations in parallel and unordered. */
403 bool is_unordered_parallel(ForType for_type);
404 
405 /** Returns true if for_type executes for loop iterations in parallel. */
406 bool is_parallel(ForType for_type);
407 
408 /** A reference-counted handle to a statement node. */
409 struct Stmt : public IRHandle {
410     Stmt() = default;
StmtStmt411     Stmt(const BaseStmtNode *n)
412         : IRHandle(n) {
413     }
414 
415     /** Override get() to return a BaseStmtNode * instead of an IRNode * */
416     HALIDE_ALWAYS_INLINE
getStmt417     const BaseStmtNode *get() const {
418         return (const Internal::BaseStmtNode *)ptr;
419     }
420 
421     /** This lets you use a Stmt as a key in a map of the form
422      * map<Stmt, Foo, Stmt::Compare> */
423     struct Compare {
operatorStmt::Compare424         bool operator()(const Stmt &a, const Stmt &b) const {
425             return a.ptr < b.ptr;
426         }
427     };
428 };
429 
430 }  // namespace Internal
431 }  // namespace Halide
432 
433 #endif
434