1 #ifndef HALIDE_SIMPLIFY_VISITORS_H
2 #define HALIDE_SIMPLIFY_VISITORS_H
3 
4 /** \file
5  * The simplifier is separated into multiple compilation units with
6  * this single shared header to speed up the build. This file is not
7  * exported in Halide.h. */
8 
9 #include "Bounds.h"
10 #include "IRMatch.h"
11 #include "IRVisitor.h"
12 #include "Scope.h"
13 
14 // Because this file is only included by the simplify methods and
15 // doesn't go into Halide.h, we're free to use any old names for our
16 // macros.
17 
// Compile-time switches for simplifier tracing. Setting
// LOG_EXPR_MUTATIONS to 1 selects the variant of
// Simplify::mutate(const Expr &, ExprInfo *) that prints each Expr
// before and after simplification; LOG_STMT_MUTATIONS does the same
// for Simplify::mutate(const Stmt &). Both are off by default.
#define LOG_EXPR_MUTATIONS 0
#define LOG_STMT_MUTATIONS 0

// On old compilers, some visitors would use large stack frames,
// because they use expression templates that generate large numbers
// of temporary objects when they are built and matched against. If we
// wrap the expressions that imply lots of temporaries in a lambda, we
// can get these large frames out of the recursive path.
//
// EVAL_IN_LAMBDA(x) evaluates x inside an immediately-invoked lambda
// that captures everything by reference and is tagged with
// HALIDE_NEVER_INLINE (defined elsewhere; presumably it keeps the
// lambda's frame separate from the caller's), and yields the value of
// x.
#define EVAL_IN_LAMBDA(x) (([&]() HALIDE_NEVER_INLINE { return (x); })())
27 
28 namespace Halide {
29 namespace Internal {
30 
31 class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> {
32     using Super = VariadicVisitor<Simplify, Expr, Stmt>;
33 
34 public:
35     Simplify(bool r, const Scope<Interval> *bi, const Scope<ModulusRemainder> *ai);
36 
37     struct ExprInfo {
38         // We track constant integer bounds when they exist
39         int64_t min = 0, max = 0;
40         bool min_defined = false, max_defined = false;
41         // And the alignment of integer variables
42         ModulusRemainder alignment;
43 
trim_bounds_using_alignmentExprInfo44         void trim_bounds_using_alignment() {
45             if (alignment.modulus == 0) {
46                 min_defined = max_defined = true;
47                 min = max = alignment.remainder;
48             } else if (alignment.modulus > 1) {
49                 if (min_defined) {
50                     int64_t new_min = min - mod_imp(min, alignment.modulus) + alignment.remainder;
51                     if (new_min < min) {
52                         new_min += alignment.modulus;
53                     }
54                     min = new_min;
55                 }
56                 if (max_defined) {
57                     int64_t new_max = max - mod_imp(max, alignment.modulus) + alignment.remainder;
58                     if (new_max > max) {
59                         new_max -= alignment.modulus;
60                     }
61                     max = new_max;
62                 }
63             }
64 
65             if (min_defined && max_defined && min == max) {
66                 alignment.modulus = 0;
67                 alignment.remainder = min;
68             }
69         }
70 
71         // Mix in existing knowledge about this Expr
intersectExprInfo72         void intersect(const ExprInfo &other) {
73             if (min_defined && other.min_defined) {
74                 min = std::max(min, other.min);
75             } else if (other.min_defined) {
76                 min_defined = true;
77                 min = other.min;
78             }
79 
80             if (max_defined && other.max_defined) {
81                 max = std::min(max, other.max);
82             } else if (other.max_defined) {
83                 max_defined = true;
84                 max = other.max;
85             }
86 
87             alignment = ModulusRemainder::intersect(alignment, other.alignment);
88 
89             trim_bounds_using_alignment();
90         }
91     };
92 
93 #if (LOG_EXPR_MUTATORIONS || LOG_STMT_MUTATIONS)
94     static int debug_indent;
95 #endif
96 
97 #if LOG_EXPR_MUTATIONS
mutate(const Expr & e,ExprInfo * b)98     Expr mutate(const Expr &e, ExprInfo *b) {
99         const std::string spaces(debug_indent, ' ');
100         debug(1) << spaces << "Simplifying Expr: " << e << "\n";
101         debug_indent++;
102         Expr new_e = Super::dispatch(e, b);
103         debug_indent--;
104         if (!new_e.same_as(e)) {
105             debug(1)
106                 << spaces << "Before: " << e << "\n"
107                 << spaces << "After:  " << new_e << "\n";
108         }
109         internal_assert(e.type() == new_e.type());
110         return new_e;
111     }
112 
113 #else
114     HALIDE_ALWAYS_INLINE
mutate(const Expr & e,ExprInfo * b)115     Expr mutate(const Expr &e, ExprInfo *b) {
116         Expr new_e = Super::dispatch(e, b);
117         internal_assert(new_e.type() == e.type()) << e << " -> " << new_e << "\n";
118         return new_e;
119     }
120 #endif
121 
122 #if LOG_STMT_MUTATIONS
mutate(const Stmt & s)123     Stmt mutate(const Stmt &s) {
124         const std::string spaces(debug_indent, ' ');
125         debug(1) << spaces << "Simplifying Stmt: " << s << "\n";
126         debug_indent++;
127         Stmt new_s = Super::dispatch(s);
128         debug_indent--;
129         if (!new_s.same_as(s)) {
130             debug(1)
131                 << spaces << "Before: " << s << "\n"
132                 << spaces << "After:  " << new_s << "\n";
133         }
134         return new_s;
135     }
136 #else
mutate(const Stmt & s)137     Stmt mutate(const Stmt &s) {
138         return Super::dispatch(s);
139     }
140 #endif
141 
142     bool remove_dead_lets;
143     bool no_float_simplify;
144 
145     HALIDE_ALWAYS_INLINE
may_simplify(const Type & t)146     bool may_simplify(const Type &t) const {
147         return !no_float_simplify || !t.is_float();
148     }
149 
150     // Returns true iff t is an integral type where overflow is undefined
151     HALIDE_ALWAYS_INLINE
no_overflow_int(Type t)152     bool no_overflow_int(Type t) {
153         return t.is_int() && t.bits() >= 32;
154     }
155 
156     HALIDE_ALWAYS_INLINE
no_overflow_scalar_int(Type t)157     bool no_overflow_scalar_int(Type t) {
158         return t.is_scalar() && no_overflow_int(t);
159     }
160 
161     // Returns true iff t does not have a well defined overflow behavior.
162     HALIDE_ALWAYS_INLINE
no_overflow(Type t)163     bool no_overflow(Type t) {
164         return t.is_float() || no_overflow_int(t);
165     }
166 
167     struct VarInfo {
168         Expr replacement;
169         int old_uses, new_uses;
170     };
171 
172     // Tracked for all let vars
173     Scope<VarInfo> var_info;
174 
175     // Only tracked for integer let vars
176     Scope<ExprInfo> bounds_and_alignment_info;
177 
178     // Symbols used by rewrite rules
179     IRMatcher::Wild<0> x;
180     IRMatcher::Wild<1> y;
181     IRMatcher::Wild<2> z;
182     IRMatcher::Wild<3> w;
183     IRMatcher::Wild<4> u;
184     IRMatcher::Wild<5> v;
185     IRMatcher::WildConst<0> c0;
186     IRMatcher::WildConst<1> c1;
187     IRMatcher::WildConst<2> c2;
188     IRMatcher::WildConst<3> c3;
189 
190     // Tracks whether or not we're inside a vector loop. Certain
191     // transformations are not a good idea if the code is to be
192     // vectorized.
193     bool in_vector_loop = false;
194 
195     // If we encounter a reference to a buffer (a Load, Store, Call,
196     // or Provide), there's an implicit dependence on some associated
197     // symbols.
198     void found_buffer_reference(const std::string &name, size_t dimensions = 0);
199 
200     // Wrappers for as_const_foo that are more convenient to use in
201     // the large chains of conditions in the visit methods
202     // below. Unlike the versions in IROperator, these only match
203     // scalars.
204     bool const_float(const Expr &e, double *f);
205     bool const_int(const Expr &e, int64_t *i);
206     bool const_uint(const Expr &e, uint64_t *u);
207 
208     // Put the args to a commutative op in a canonical order
209     HALIDE_ALWAYS_INLINE
should_commute(const Expr & a,const Expr & b)210     bool should_commute(const Expr &a, const Expr &b) {
211         if (a.node_type() < b.node_type()) return true;
212         if (a.node_type() > b.node_type()) return false;
213 
214         if (a.node_type() == IRNodeType::Variable) {
215             const Variable *va = a.as<Variable>();
216             const Variable *vb = b.as<Variable>();
217             return va->name.compare(vb->name) > 0;
218         }
219 
220         return false;
221     }
222 
223     std::set<Expr, IRDeepCompare> truths, falsehoods;
224 
225     struct ScopedFact {
226         Simplify *simplify;
227 
228         std::vector<const Variable *> pop_list;
229         std::vector<const Variable *> bounds_pop_list;
230         std::vector<Expr> truths, falsehoods;
231 
232         void learn_false(const Expr &fact);
233         void learn_true(const Expr &fact);
234         void learn_upper_bound(const Variable *v, int64_t val);
235         void learn_lower_bound(const Variable *v, int64_t val);
236 
ScopedFactScopedFact237         ScopedFact(Simplify *s)
238             : simplify(s) {
239         }
240         ~ScopedFact();
241 
242         // allow move but not copy
243         ScopedFact(const ScopedFact &that) = delete;
244         ScopedFact(ScopedFact &&that) = default;
245     };
246 
247     // Tell the simplifier to learn from and exploit a boolean
248     // condition, over the lifetime of the returned object.
scoped_truth(const Expr & fact)249     ScopedFact scoped_truth(const Expr &fact) {
250         ScopedFact f(this);
251         f.learn_true(fact);
252         return f;
253     }
254 
255     // Tell the simplifier to assume a boolean condition is false over
256     // the lifetime of the returned object.
scoped_falsehood(const Expr & fact)257     ScopedFact scoped_falsehood(const Expr &fact) {
258         ScopedFact f(this);
259         f.learn_false(fact);
260         return f;
261     }
262 
263     template<typename T>
264     Expr hoist_slice_vector(Expr e);
265 
mutate_let_body(const Stmt & s,ExprInfo *)266     Stmt mutate_let_body(const Stmt &s, ExprInfo *) {
267         return mutate(s);
268     }
mutate_let_body(const Expr & e,ExprInfo * bounds)269     Expr mutate_let_body(const Expr &e, ExprInfo *bounds) {
270         return mutate(e, bounds);
271     }
272 
273     template<typename T, typename Body>
274     Body simplify_let(const T *op, ExprInfo *bounds);
275 
276     Expr visit(const IntImm *op, ExprInfo *bounds);
277     Expr visit(const UIntImm *op, ExprInfo *bounds);
278     Expr visit(const FloatImm *op, ExprInfo *bounds);
279     Expr visit(const StringImm *op, ExprInfo *bounds);
280     Expr visit(const Broadcast *op, ExprInfo *bounds);
281     Expr visit(const Cast *op, ExprInfo *bounds);
282     Expr visit(const Variable *op, ExprInfo *bounds);
283     Expr visit(const Add *op, ExprInfo *bounds);
284     Expr visit(const Sub *op, ExprInfo *bounds);
285     Expr visit(const Mul *op, ExprInfo *bounds);
286     Expr visit(const Div *op, ExprInfo *bounds);
287     Expr visit(const Mod *op, ExprInfo *bounds);
288     Expr visit(const Min *op, ExprInfo *bounds);
289     Expr visit(const Max *op, ExprInfo *bounds);
290     Expr visit(const EQ *op, ExprInfo *bounds);
291     Expr visit(const NE *op, ExprInfo *bounds);
292     Expr visit(const LT *op, ExprInfo *bounds);
293     Expr visit(const LE *op, ExprInfo *bounds);
294     Expr visit(const GT *op, ExprInfo *bounds);
295     Expr visit(const GE *op, ExprInfo *bounds);
296     Expr visit(const And *op, ExprInfo *bounds);
297     Expr visit(const Or *op, ExprInfo *bounds);
298     Expr visit(const Not *op, ExprInfo *bounds);
299     Expr visit(const Select *op, ExprInfo *bounds);
300     Expr visit(const Ramp *op, ExprInfo *bounds);
301     Stmt visit(const IfThenElse *op);
302     Expr visit(const Load *op, ExprInfo *bounds);
303     Expr visit(const Call *op, ExprInfo *bounds);
304     Expr visit(const Shuffle *op, ExprInfo *bounds);
305     Expr visit(const VectorReduce *op, ExprInfo *bounds);
306     Expr visit(const Let *op, ExprInfo *bounds);
307     Stmt visit(const LetStmt *op);
308     Stmt visit(const AssertStmt *op);
309     Stmt visit(const For *op);
310     Stmt visit(const Provide *op);
311     Stmt visit(const Store *op);
312     Stmt visit(const Allocate *op);
313     Stmt visit(const Evaluate *op);
314     Stmt visit(const ProducerConsumer *op);
315     Stmt visit(const Block *op);
316     Stmt visit(const Realize *op);
317     Stmt visit(const Prefetch *op);
318     Stmt visit(const Free *op);
319     Stmt visit(const Acquire *op);
320     Stmt visit(const Fork *op);
321     Stmt visit(const Atomic *op);
322 };
323 
324 }  // namespace Internal
325 }  // namespace Halide
326 
327 #endif
328