1 #ifndef HALIDE_SIMPLIFY_VISITORS_H 2 #define HALIDE_SIMPLIFY_VISITORS_H 3 4 /** \file 5 * The simplifier is separated into multiple compilation units with 6 * this single shared header to speed up the build. This file is not 7 * exported in Halide.h. */ 8 9 #include "Bounds.h" 10 #include "IRMatch.h" 11 #include "IRVisitor.h" 12 #include "Scope.h" 13 14 // Because this file is only included by the simplify methods and 15 // doesn't go into Halide.h, we're free to use any old names for our 16 // macros. 17 18 #define LOG_EXPR_MUTATIONS 0 19 #define LOG_STMT_MUTATIONS 0 20 21 // On old compilers, some visitors would use large stack frames, 22 // because they use expression templates that generate large numbers 23 // of temporary objects when they are built and matched against. If we 24 // wrap the expressions that imply lots of temporaries in a lambda, we 25 // can get these large frames out of the recursive path. 26 #define EVAL_IN_LAMBDA(x) (([&]() HALIDE_NEVER_INLINE { return (x); })()) 27 28 namespace Halide { 29 namespace Internal { 30 31 class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> { 32 using Super = VariadicVisitor<Simplify, Expr, Stmt>; 33 34 public: 35 Simplify(bool r, const Scope<Interval> *bi, const Scope<ModulusRemainder> *ai); 36 37 struct ExprInfo { 38 // We track constant integer bounds when they exist 39 int64_t min = 0, max = 0; 40 bool min_defined = false, max_defined = false; 41 // And the alignment of integer variables 42 ModulusRemainder alignment; 43 trim_bounds_using_alignmentExprInfo44 void trim_bounds_using_alignment() { 45 if (alignment.modulus == 0) { 46 min_defined = max_defined = true; 47 min = max = alignment.remainder; 48 } else if (alignment.modulus > 1) { 49 if (min_defined) { 50 int64_t new_min = min - mod_imp(min, alignment.modulus) + alignment.remainder; 51 if (new_min < min) { 52 new_min += alignment.modulus; 53 } 54 min = new_min; 55 } 56 if (max_defined) { 57 int64_t new_max = max - mod_imp(max, alignment.modulus) + alignment.remainder; 58 if (new_max > max) { 59 new_max -= alignment.modulus; 60 } 61 max = new_max; 62 } 63 } 64 65 if (min_defined && max_defined && min == max) { 66 alignment.modulus = 0; 67 alignment.remainder = min; 68 } 69 } 70 71 // Mix in existing knowledge about this Expr intersectExprInfo72 void intersect(const ExprInfo &other) { 73 if (min_defined && other.min_defined) { 74 min = std::max(min, other.min); 75 } else if (other.min_defined) { 76 min_defined = true; 77 min = other.min; 78 } 79 80 if (max_defined && other.max_defined) { 81 max = std::min(max, other.max); 82 } else if (other.max_defined) { 83 max_defined = true; 84 max = other.max; 85 } 86 87 alignment = ModulusRemainder::intersect(alignment, other.alignment); 88 89 trim_bounds_using_alignment(); 90 } 91 }; 92 93 #if (LOG_EXPR_MUTATORIONS || LOG_STMT_MUTATIONS) 94 static int debug_indent; 95 #endif 96 97 #if LOG_EXPR_MUTATIONS mutate(const Expr & e,ExprInfo * b)98 Expr mutate(const Expr &e, ExprInfo *b) { 99 const std::string spaces(debug_indent, ' '); 100 debug(1) << spaces << "Simplifying Expr: " << e << "\n"; 101 debug_indent++; 102 Expr new_e = Super::dispatch(e, b); 103 debug_indent--; 104 if (!new_e.same_as(e)) { 105 debug(1) 106 << spaces << "Before: " << e << "\n" 107 << spaces << "After: " << new_e << "\n"; 108 } 109 internal_assert(e.type() == new_e.type()); 110 return new_e; 111 } 112 113 #else 114 HALIDE_ALWAYS_INLINE mutate(const Expr & e,ExprInfo * b)115 Expr mutate(const Expr &e, ExprInfo *b) { 116 Expr new_e = Super::dispatch(e, b); 117 internal_assert(new_e.type() == e.type()) << e << " -> " << new_e << "\n"; 118 return new_e; 119 } 120 #endif 121 122 #if LOG_STMT_MUTATIONS mutate(const Stmt & s)123 Stmt mutate(const Stmt &s) { 124 const std::string spaces(debug_indent, ' '); 125 debug(1) << spaces << "Simplifying Stmt: " << s << "\n"; 126 debug_indent++; 127 Stmt new_s = Super::dispatch(s); 128 debug_indent--; 129 if (!new_s.same_as(s)) { 130 debug(1) 131 << spaces << "Before: " << s << "\n" 132 << spaces << "After: " << new_s << "\n"; 133 } 134 return new_s; 135 } 136 #else mutate(const Stmt & s)137 Stmt mutate(const Stmt &s) { 138 return Super::dispatch(s); 139 } 140 #endif 141 142 bool remove_dead_lets; 143 bool no_float_simplify; 144 145 HALIDE_ALWAYS_INLINE may_simplify(const Type & t)146 bool may_simplify(const Type &t) const { 147 return !no_float_simplify || !t.is_float(); 148 } 149 150 // Returns true iff t is an integral type where overflow is undefined 151 HALIDE_ALWAYS_INLINE no_overflow_int(Type t)152 bool no_overflow_int(Type t) { 153 return t.is_int() && t.bits() >= 32; 154 } 155 156 HALIDE_ALWAYS_INLINE no_overflow_scalar_int(Type t)157 bool no_overflow_scalar_int(Type t) { 158 return t.is_scalar() && no_overflow_int(t); 159 } 160 161 // Returns true iff t does not have a well defined overflow behavior. 162 HALIDE_ALWAYS_INLINE no_overflow(Type t)163 bool no_overflow(Type t) { 164 return t.is_float() || no_overflow_int(t); 165 } 166 167 struct VarInfo { 168 Expr replacement; 169 int old_uses, new_uses; 170 }; 171 172 // Tracked for all let vars 173 Scope<VarInfo> var_info; 174 175 // Only tracked for integer let vars 176 Scope<ExprInfo> bounds_and_alignment_info; 177 178 // Symbols used by rewrite rules 179 IRMatcher::Wild<0> x; 180 IRMatcher::Wild<1> y; 181 IRMatcher::Wild<2> z; 182 IRMatcher::Wild<3> w; 183 IRMatcher::Wild<4> u; 184 IRMatcher::Wild<5> v; 185 IRMatcher::WildConst<0> c0; 186 IRMatcher::WildConst<1> c1; 187 IRMatcher::WildConst<2> c2; 188 IRMatcher::WildConst<3> c3; 189 190 // Tracks whether or not we're inside a vector loop. Certain 191 // transformations are not a good idea if the code is to be 192 // vectorized. 193 bool in_vector_loop = false; 194 195 // If we encounter a reference to a buffer (a Load, Store, Call, 196 // or Provide), there's an implicit dependence on some associated 197 // symbols. 198 void found_buffer_reference(const std::string &name, size_t dimensions = 0); 199 200 // Wrappers for as_const_foo that are more convenient to use in 201 // the large chains of conditions in the visit methods 202 // below. Unlike the versions in IROperator, these only match 203 // scalars. 204 bool const_float(const Expr &e, double *f); 205 bool const_int(const Expr &e, int64_t *i); 206 bool const_uint(const Expr &e, uint64_t *u); 207 208 // Put the args to a commutative op in a canonical order 209 HALIDE_ALWAYS_INLINE should_commute(const Expr & a,const Expr & b)210 bool should_commute(const Expr &a, const Expr &b) { 211 if (a.node_type() < b.node_type()) return true; 212 if (a.node_type() > b.node_type()) return false; 213 214 if (a.node_type() == IRNodeType::Variable) { 215 const Variable *va = a.as<Variable>(); 216 const Variable *vb = b.as<Variable>(); 217 return va->name.compare(vb->name) > 0; 218 } 219 220 return false; 221 } 222 223 std::set<Expr, IRDeepCompare> truths, falsehoods; 224 225 struct ScopedFact { 226 Simplify *simplify; 227 228 std::vector<const Variable *> pop_list; 229 std::vector<const Variable *> bounds_pop_list; 230 std::vector<Expr> truths, falsehoods; 231 232 void learn_false(const Expr &fact); 233 void learn_true(const Expr &fact); 234 void learn_upper_bound(const Variable *v, int64_t val); 235 void learn_lower_bound(const Variable *v, int64_t val); 236 ScopedFactScopedFact237 ScopedFact(Simplify *s) 238 : simplify(s) { 239 } 240 ~ScopedFact(); 241 242 // allow move but not copy 243 ScopedFact(const ScopedFact &that) = delete; 244 ScopedFact(ScopedFact &&that) = default; 245 }; 246 247 // Tell the simplifier to learn from and exploit a boolean 248 // condition, over the lifetime of the returned object. scoped_truth(const Expr & fact)249 ScopedFact scoped_truth(const Expr &fact) { 250 ScopedFact f(this); 251 f.learn_true(fact); 252 return f; 253 } 254 255 // Tell the simplifier to assume a boolean condition is false over 256 // the lifetime of the returned object. scoped_falsehood(const Expr & fact)257 ScopedFact scoped_falsehood(const Expr &fact) { 258 ScopedFact f(this); 259 f.learn_false(fact); 260 return f; 261 } 262 263 template<typename T> 264 Expr hoist_slice_vector(Expr e); 265 mutate_let_body(const Stmt & s,ExprInfo *)266 Stmt mutate_let_body(const Stmt &s, ExprInfo *) { 267 return mutate(s); 268 } mutate_let_body(const Expr & e,ExprInfo * bounds)269 Expr mutate_let_body(const Expr &e, ExprInfo *bounds) { 270 return mutate(e, bounds); 271 } 272 273 template<typename T, typename Body> 274 Body simplify_let(const T *op, ExprInfo *bounds); 275 276 Expr visit(const IntImm *op, ExprInfo *bounds); 277 Expr visit(const UIntImm *op, ExprInfo *bounds); 278 Expr visit(const FloatImm *op, ExprInfo *bounds); 279 Expr visit(const StringImm *op, ExprInfo *bounds); 280 Expr visit(const Broadcast *op, ExprInfo *bounds); 281 Expr visit(const Cast *op, ExprInfo *bounds); 282 Expr visit(const Variable *op, ExprInfo *bounds); 283 Expr visit(const Add *op, ExprInfo *bounds); 284 Expr visit(const Sub *op, ExprInfo *bounds); 285 Expr visit(const Mul *op, ExprInfo *bounds); 286 Expr visit(const Div *op, ExprInfo *bounds); 287 Expr visit(const Mod *op, ExprInfo *bounds); 288 Expr visit(const Min *op, ExprInfo *bounds); 289 Expr visit(const Max *op, ExprInfo *bounds); 290 Expr visit(const EQ *op, ExprInfo *bounds); 291 Expr visit(const NE *op, ExprInfo *bounds); 292 Expr visit(const LT *op, ExprInfo *bounds); 293 Expr visit(const LE *op, ExprInfo *bounds); 294 Expr visit(const GT *op, ExprInfo *bounds); 295 Expr visit(const GE *op, ExprInfo *bounds); 296 Expr visit(const And *op, ExprInfo *bounds); 297 Expr visit(const Or *op, ExprInfo *bounds); 298 Expr visit(const Not *op, ExprInfo *bounds); 299 Expr visit(const Select *op, ExprInfo *bounds); 300 Expr visit(const Ramp *op, ExprInfo *bounds); 301 Stmt visit(const IfThenElse *op); 302 Expr visit(const Load *op, ExprInfo *bounds); 303 Expr visit(const Call *op, ExprInfo *bounds); 304 Expr visit(const Shuffle *op, ExprInfo *bounds); 305 Expr visit(const VectorReduce *op, ExprInfo *bounds); 306 Expr visit(const Let *op, ExprInfo *bounds); 307 Stmt visit(const LetStmt *op); 308 Stmt visit(const AssertStmt *op); 309 Stmt visit(const For *op); 310 Stmt visit(const Provide *op); 311 Stmt visit(const Store *op); 312 Stmt visit(const Allocate *op); 313 Stmt visit(const Evaluate *op); 314 Stmt visit(const ProducerConsumer *op); 315 Stmt visit(const Block *op); 316 Stmt visit(const Realize *op); 317 Stmt visit(const Prefetch *op); 318 Stmt visit(const Free *op); 319 Stmt visit(const Acquire *op); 320 Stmt visit(const Fork *op); 321 Stmt visit(const Atomic *op); 322 }; 323 324 } // namespace Internal 325 } // namespace Halide 326 327 #endif 328