1 #include "RemoveUndef.h"
2 #include "IREquality.h"
3 #include "IRMutator.h"
4 #include "IROperator.h"
5 #include "Scope.h"
6 #include "Substitute.h"
7 
8 namespace Halide {
9 namespace Internal {
10 
11 using std::vector;
12 
13 class RemoveUndef : public IRMutator {
14 public:
15     Expr predicate;
16 
17 private:
18     using IRMutator::visit;
19 
20     Scope<> dead_vars;
21 
visit(const Variable * op)22     Expr visit(const Variable *op) override {
23         if (dead_vars.contains(op->name)) {
24             return Expr();
25         } else {
26             return op;
27         }
28     }
29 
30     template<typename T>
mutate_binary_operator(const T * op)31     Expr mutate_binary_operator(const T *op) {
32         Expr a = mutate(op->a);
33         if (!a.defined()) return Expr();
34         Expr b = mutate(op->b);
35         if (!b.defined()) return Expr();
36         if (a.same_as(op->a) &&
37             b.same_as(op->b)) {
38             return op;
39         } else {
40             return T::make(std::move(a), std::move(b));
41         }
42     }
43 
visit(const Cast * op)44     Expr visit(const Cast *op) override {
45         Expr value = mutate(op->value);
46         if (!value.defined()) return Expr();
47         if (value.same_as(op->value)) {
48             return op;
49         } else {
50             return Cast::make(op->type, std::move(value));
51         }
52     }
53 
visit(const Add * op)54     Expr visit(const Add *op) override {
55         return mutate_binary_operator(op);
56     }
visit(const Sub * op)57     Expr visit(const Sub *op) override {
58         return mutate_binary_operator(op);
59     }
visit(const Mul * op)60     Expr visit(const Mul *op) override {
61         return mutate_binary_operator(op);
62     }
visit(const Div * op)63     Expr visit(const Div *op) override {
64         return mutate_binary_operator(op);
65     }
visit(const Mod * op)66     Expr visit(const Mod *op) override {
67         return mutate_binary_operator(op);
68     }
visit(const Min * op)69     Expr visit(const Min *op) override {
70         return mutate_binary_operator(op);
71     }
visit(const Max * op)72     Expr visit(const Max *op) override {
73         return mutate_binary_operator(op);
74     }
visit(const EQ * op)75     Expr visit(const EQ *op) override {
76         return mutate_binary_operator(op);
77     }
visit(const NE * op)78     Expr visit(const NE *op) override {
79         return mutate_binary_operator(op);
80     }
visit(const LT * op)81     Expr visit(const LT *op) override {
82         return mutate_binary_operator(op);
83     }
visit(const LE * op)84     Expr visit(const LE *op) override {
85         return mutate_binary_operator(op);
86     }
visit(const GT * op)87     Expr visit(const GT *op) override {
88         return mutate_binary_operator(op);
89     }
visit(const GE * op)90     Expr visit(const GE *op) override {
91         return mutate_binary_operator(op);
92     }
visit(const And * op)93     Expr visit(const And *op) override {
94         return mutate_binary_operator(op);
95     }
visit(const Or * op)96     Expr visit(const Or *op) override {
97         return mutate_binary_operator(op);
98     }
99 
visit(const Not * op)100     Expr visit(const Not *op) override {
101         Expr a = mutate(op->a);
102         if (!a.defined()) return Expr();
103         if (a.same_as(op->a)) {
104             return op;
105         } else {
106             return Not::make(a);
107         }
108     }
109 
visit(const Select * op)110     Expr visit(const Select *op) override {
111         Expr cond = mutate(op->condition);
112         Expr t = mutate(op->true_value);
113         Expr f = mutate(op->false_value);
114 
115         if (!cond.defined()) {
116             return Expr();
117         }
118 
119         if (!t.defined() && !f.defined()) {
120             return Expr();
121         }
122 
123         if (!t.defined()) {
124             // Swap the cases so that we only need to deal with the
125             // case when false is not defined below.
126             cond = Not::make(cond);
127             t = f;
128             f = Expr();
129         }
130 
131         if (!f.defined()) {
132             // We need to convert this to an if-then-else
133             if (predicate.defined()) {
134                 predicate = predicate && cond;
135             } else {
136                 predicate = cond;
137             }
138             return t;
139         } else if (cond.same_as(op->condition) &&
140                    t.same_as(op->true_value) &&
141                    f.same_as(op->false_value)) {
142             return op;
143         } else {
144             return Select::make(cond, t, f);
145         }
146     }
147 
visit(const Load * op)148     Expr visit(const Load *op) override {
149         Expr pred = mutate(op->predicate);
150         if (!pred.defined()) return Expr();
151         Expr index = mutate(op->index);
152         if (!index.defined()) return Expr();
153         if (pred.same_as(op->predicate) && index.same_as(op->index)) {
154             return op;
155         } else {
156             return Load::make(op->type, op->name, index, op->image, op->param, pred, op->alignment);
157         }
158     }
159 
visit(const Ramp * op)160     Expr visit(const Ramp *op) override {
161         Expr base = mutate(op->base);
162         if (!base.defined()) return Expr();
163         Expr stride = mutate(op->stride);
164         if (!stride.defined()) return Expr();
165         if (base.same_as(op->base) &&
166             stride.same_as(op->stride)) {
167             return op;
168         } else {
169             return Ramp::make(base, stride, op->lanes);
170         }
171     }
172 
visit(const Broadcast * op)173     Expr visit(const Broadcast *op) override {
174         Expr value = mutate(op->value);
175         if (!value.defined()) return Expr();
176         if (value.same_as(op->value)) {
177             return op;
178         } else {
179             return Broadcast::make(value, op->lanes);
180         }
181     }
182 
visit(const Call * op)183     Expr visit(const Call *op) override {
184         if (op->is_intrinsic(Call::undef)) {
185             return Expr();
186         }
187 
188         vector<Expr> new_args(op->args.size());
189         bool changed = false;
190 
191         // Mutate the args
192         for (size_t i = 0; i < op->args.size(); i++) {
193             Expr old_arg = op->args[i];
194             Expr new_arg = mutate(old_arg);
195             if (!new_arg.defined()) return Expr();
196             if (!new_arg.same_as(old_arg)) changed = true;
197             new_args[i] = new_arg;
198         }
199 
200         if (!changed) {
201             return op;
202         } else {
203             return Call::make(op->type, op->name, new_args, op->call_type,
204                               op->func, op->value_index, op->image, op->param);
205         }
206     }
207 
208     template<typename T, typename Body>
visit_let(const T * op)209     Body visit_let(const T *op) {
210         // Visit an entire chain of lets in a single method to conserve stack space.
211         struct Frame {
212             const T *op;
213             Expr new_value;
214             ScopedBinding<> binding;
215             Frame(const T *op, Expr v, Scope<> &scope)
216                 : op(op), new_value(std::move(v)),
217                   binding(!new_value.defined(), scope, op->name) {
218             }
219         };
220         vector<Frame> frames;
221 
222         Body result;
223         do {
224             frames.emplace_back(op, mutate(op->value), dead_vars);
225             result = op->body;
226         } while ((op = result.template as<T>()));
227 
228         result = mutate(result);
229 
230         if (result.defined()) {
231             for (auto it = frames.rbegin(); it != frames.rend(); it++) {
232                 if (!it->new_value.defined()) continue;
233                 predicate = substitute(it->op->name, it->new_value, predicate);
234                 if (it->new_value.same_as(it->op->value) && result.same_as(it->op->body)) {
235                     result = it->op;
236                 } else {
237                     result = T::make(it->op->name, std::move(it->new_value), result);
238                 }
239             }
240         }
241 
242         return result;
243     }
244 
visit(const Let * op)245     Expr visit(const Let *op) override {
246         return visit_let<Let, Expr>(op);
247     }
248 
visit(const LetStmt * op)249     Stmt visit(const LetStmt *op) override {
250         return visit_let<LetStmt, Stmt>(op);
251     }
252 
visit(const AssertStmt * op)253     Stmt visit(const AssertStmt *op) override {
254         Expr condition = mutate(op->condition);
255         if (!condition.defined()) {
256             return Stmt();
257         }
258 
259         Expr message = mutate(op->message);
260         if (!message.defined()) {
261             return Stmt();
262         }
263 
264         if (condition.same_as(op->condition) && message.same_as(op->message)) {
265             return op;
266         } else {
267             return AssertStmt::make(condition, message);
268         }
269     }
270 
visit(const ProducerConsumer * op)271     Stmt visit(const ProducerConsumer *op) override {
272         Stmt body = mutate(op->body);
273         if (!body.defined()) return Stmt();
274         if (body.same_as(op->body)) {
275             return op;
276         } else {
277             return ProducerConsumer::make(op->name, op->is_producer, body);
278         }
279     }
280 
visit(const For * op)281     Stmt visit(const For *op) override {
282         Expr min = mutate(op->min);
283         if (!min.defined()) {
284             return Stmt();
285         }
286         Expr extent = mutate(op->extent);
287         if (!extent.defined()) {
288             return Stmt();
289         }
290         Stmt body = mutate(op->body);
291         if (!body.defined()) return Stmt();
292         if (min.same_as(op->min) &&
293             extent.same_as(op->extent) &&
294             body.same_as(op->body)) {
295             return op;
296         } else {
297             return For::make(op->name, min, extent, op->for_type, op->device_api, body);
298         }
299     }
300 
visit(const Store * op)301     Stmt visit(const Store *op) override {
302         predicate = Expr();
303 
304         Expr pred = mutate(op->predicate);
305         Expr value = mutate(op->value);
306         if (!value.defined()) {
307             return Stmt();
308         }
309 
310         Expr index = mutate(op->index);
311         if (!index.defined()) {
312             return Stmt();
313         }
314 
315         if (predicate.defined()) {
316             // This becomes a conditional store
317             Stmt stmt = IfThenElse::make(predicate, Store::make(op->name, value, index, op->param, pred, op->alignment));
318             predicate = Expr();
319             return stmt;
320         } else if (pred.same_as(op->predicate) &&
321                    value.same_as(op->value) &&
322                    index.same_as(op->index)) {
323             return op;
324         } else {
325             return Store::make(op->name, value, index, op->param, pred, op->alignment);
326         }
327     }
328 
visit(const Provide * op)329     Stmt visit(const Provide *op) override {
330         predicate = Expr();
331 
332         vector<Expr> new_args(op->args.size());
333         vector<Expr> new_values(op->values.size());
334         vector<Expr> args_predicates;
335         vector<Expr> values_predicates;
336         bool changed = false;
337 
338         // Mutate the args
339         for (size_t i = 0; i < op->args.size(); i++) {
340             Expr old_arg = op->args[i];
341             predicate = Expr();
342             Expr new_arg = mutate(old_arg);
343             if (!new_arg.defined()) {
344                 return Stmt();
345             }
346             args_predicates.push_back(predicate);
347             if (!new_arg.same_as(old_arg)) changed = true;
348             new_args[i] = new_arg;
349         }
350 
351         for (size_t i = 1; i < args_predicates.size(); i++) {
352             user_assert(equal(args_predicates[i - 1], args_predicates[i]))
353                 << "Conditionally-undef args in a Tuple should have the same conditions\n"
354                 << "  Condition " << i - 1 << ": " << args_predicates[i - 1] << "\n"
355                 << "  Condition " << i << ": " << args_predicates[i] << "\n";
356         }
357 
358         bool all_values_undefined = true;
359         for (size_t i = 0; i < op->values.size(); i++) {
360             Expr old_value = op->values[i];
361             predicate = Expr();
362             Expr new_value = mutate(old_value);
363             if (!new_value.defined()) {
364                 new_value = undef(old_value.type());
365             } else {
366                 all_values_undefined = false;
367                 values_predicates.push_back(predicate);
368             }
369             if (!new_value.same_as(old_value)) changed = true;
370             new_values[i] = new_value;
371         }
372 
373         if (all_values_undefined) {
374             return Stmt();
375         }
376 
377         for (size_t i = 1; i < values_predicates.size(); i++) {
378             user_assert(equal(values_predicates[i - 1], values_predicates[i]))
379                 << "Conditionally-undef values in a Tuple should have the same conditions\n"
380                 << "  Condition " << i - 1 << ": " << values_predicates[i - 1] << "\n"
381                 << "  Condition " << i << ": " << values_predicates[i] << "\n";
382         }
383 
384         if (predicate.defined()) {
385             Stmt stmt = IfThenElse::make(predicate, Provide::make(op->name, new_values, new_args));
386             predicate = Expr();
387             return stmt;
388         } else if (!changed) {
389             return op;
390         } else {
391             return Provide::make(op->name, new_values, new_args);
392         }
393     }
394 
visit(const Allocate * op)395     Stmt visit(const Allocate *op) override {
396         std::vector<Expr> new_extents;
397         bool all_extents_unmodified = true;
398         for (size_t i = 0; i < op->extents.size(); i++) {
399             new_extents.push_back(mutate(op->extents[i]));
400             if (!new_extents.back().defined()) {
401                 return Stmt();
402             }
403             all_extents_unmodified &= new_extents[i].same_as(op->extents[i]);
404         }
405         Stmt body = mutate(op->body);
406         if (!body.defined()) return Stmt();
407 
408         Expr condition = mutate(op->condition);
409         if (!condition.defined()) return Stmt();
410 
411         Expr new_expr;
412         if (op->new_expr.defined()) {
413             new_expr = mutate(op->new_expr);
414         }
415 
416         if (all_extents_unmodified &&
417             body.same_as(op->body) &&
418             condition.same_as(op->condition) &&
419             new_expr.same_as(op->new_expr)) {
420             return op;
421         } else {
422             return Allocate::make(op->name, op->type, op->memory_type,
423                                   new_extents, condition, body, new_expr, op->free_function);
424         }
425     }
426 
visit(const Free * op)427     Stmt visit(const Free *op) override {
428         return op;
429     }
430 
visit(const Realize * op)431     Stmt visit(const Realize *op) override {
432         Region new_bounds(op->bounds.size());
433         bool bounds_changed = false;
434 
435         // Mutate the bounds
436         for (size_t i = 0; i < op->bounds.size(); i++) {
437             Expr old_min = op->bounds[i].min;
438             Expr old_extent = op->bounds[i].extent;
439             Expr new_min = mutate(old_min);
440             if (!new_min.defined()) {
441                 return Stmt();
442             }
443             Expr new_extent = mutate(old_extent);
444             if (!new_extent.defined()) {
445                 return Stmt();
446             }
447             if (!new_min.same_as(old_min)) {
448                 bounds_changed = true;
449             }
450             if (!new_extent.same_as(old_extent)) {
451                 bounds_changed = true;
452             }
453             new_bounds[i] = Range(new_min, new_extent);
454         }
455 
456         Stmt body = mutate(op->body);
457         if (!body.defined()) return Stmt();
458 
459         Expr condition = mutate(op->condition);
460         if (!condition.defined()) return Stmt();
461 
462         if (!bounds_changed &&
463             body.same_as(op->body) &&
464             condition.same_as(op->condition)) {
465             return op;
466         } else {
467             return Realize::make(op->name, op->types, op->memory_type, new_bounds, condition, body);
468         }
469     }
470 
visit(const Block * op)471     Stmt visit(const Block *op) override {
472         // Visit a sequence of blocks in a single method to conserve stack space.
473         Stmt result;
474         vector<std::pair<const Block *, Stmt>> frames;
475 
476         do {
477             Stmt next = mutate(op->first);
478             if (next.defined()) {
479                 frames.emplace_back(op, std::move(next));
480             }
481             result = op->rest;
482         } while ((op = result.as<Block>()));
483 
484         result = mutate(result);
485 
486         for (auto it = frames.rbegin(); it != frames.rend(); it++) {
487             op = it->first;
488             Stmt new_first = std::move(it->second);
489             if (!result.defined()) {
490                 result = new_first;
491             } else if (new_first.same_as(op->first) && result.same_as(op->rest)) {
492                 result = op;
493             } else {
494                 result = Block::make(new_first, result);
495             }
496         }
497         return result;
498     }
499 
visit(const IfThenElse * op)500     Stmt visit(const IfThenElse *op) override {
501         Expr condition = mutate(op->condition);
502         if (!condition.defined()) {
503             return Stmt();
504         }
505         Stmt then_case = mutate(op->then_case);
506         Stmt else_case = mutate(op->else_case);
507 
508         if (!then_case.defined() && !else_case.defined()) {
509             return Stmt();
510         }
511 
512         if (!then_case.defined()) {
513             condition = Not::make(condition);
514             then_case = else_case;
515             else_case = Stmt();
516         }
517 
518         if (condition.same_as(op->condition) &&
519             then_case.same_as(op->then_case) &&
520             else_case.same_as(op->else_case)) {
521             return op;
522         } else {
523             return IfThenElse::make(condition, then_case, else_case);
524         }
525     }
526 
visit(const Evaluate * op)527     Stmt visit(const Evaluate *op) override {
528         Expr v = mutate(op->value);
529         if (!v.defined()) {
530             return Stmt();
531         } else if (v.same_as(op->value)) {
532             return op;
533         } else {
534             return Evaluate::make(v);
535         }
536     }
537 };
538 
remove_undef(Stmt s)539 Stmt remove_undef(Stmt s) {
540     RemoveUndef r;
541     s = r.mutate(s);
542     internal_assert(!r.predicate.defined())
543         << "Undefined expression leaked outside of a Store node: "
544         << r.predicate << "\n";
545     return s;
546 }
547 
548 }  // namespace Internal
549 }  // namespace Halide
550