1 #include "RemoveUndef.h"
2 #include "IREquality.h"
3 #include "IRMutator.h"
4 #include "IROperator.h"
5 #include "Scope.h"
6 #include "Substitute.h"
7
8 namespace Halide {
9 namespace Internal {
10
11 using std::vector;
12
13 class RemoveUndef : public IRMutator {
14 public:
15 Expr predicate;
16
17 private:
18 using IRMutator::visit;
19
20 Scope<> dead_vars;
21
visit(const Variable * op)22 Expr visit(const Variable *op) override {
23 if (dead_vars.contains(op->name)) {
24 return Expr();
25 } else {
26 return op;
27 }
28 }
29
30 template<typename T>
mutate_binary_operator(const T * op)31 Expr mutate_binary_operator(const T *op) {
32 Expr a = mutate(op->a);
33 if (!a.defined()) return Expr();
34 Expr b = mutate(op->b);
35 if (!b.defined()) return Expr();
36 if (a.same_as(op->a) &&
37 b.same_as(op->b)) {
38 return op;
39 } else {
40 return T::make(std::move(a), std::move(b));
41 }
42 }
43
visit(const Cast * op)44 Expr visit(const Cast *op) override {
45 Expr value = mutate(op->value);
46 if (!value.defined()) return Expr();
47 if (value.same_as(op->value)) {
48 return op;
49 } else {
50 return Cast::make(op->type, std::move(value));
51 }
52 }
53
visit(const Add * op)54 Expr visit(const Add *op) override {
55 return mutate_binary_operator(op);
56 }
visit(const Sub * op)57 Expr visit(const Sub *op) override {
58 return mutate_binary_operator(op);
59 }
visit(const Mul * op)60 Expr visit(const Mul *op) override {
61 return mutate_binary_operator(op);
62 }
visit(const Div * op)63 Expr visit(const Div *op) override {
64 return mutate_binary_operator(op);
65 }
visit(const Mod * op)66 Expr visit(const Mod *op) override {
67 return mutate_binary_operator(op);
68 }
visit(const Min * op)69 Expr visit(const Min *op) override {
70 return mutate_binary_operator(op);
71 }
visit(const Max * op)72 Expr visit(const Max *op) override {
73 return mutate_binary_operator(op);
74 }
visit(const EQ * op)75 Expr visit(const EQ *op) override {
76 return mutate_binary_operator(op);
77 }
visit(const NE * op)78 Expr visit(const NE *op) override {
79 return mutate_binary_operator(op);
80 }
visit(const LT * op)81 Expr visit(const LT *op) override {
82 return mutate_binary_operator(op);
83 }
visit(const LE * op)84 Expr visit(const LE *op) override {
85 return mutate_binary_operator(op);
86 }
visit(const GT * op)87 Expr visit(const GT *op) override {
88 return mutate_binary_operator(op);
89 }
visit(const GE * op)90 Expr visit(const GE *op) override {
91 return mutate_binary_operator(op);
92 }
visit(const And * op)93 Expr visit(const And *op) override {
94 return mutate_binary_operator(op);
95 }
visit(const Or * op)96 Expr visit(const Or *op) override {
97 return mutate_binary_operator(op);
98 }
99
visit(const Not * op)100 Expr visit(const Not *op) override {
101 Expr a = mutate(op->a);
102 if (!a.defined()) return Expr();
103 if (a.same_as(op->a)) {
104 return op;
105 } else {
106 return Not::make(a);
107 }
108 }
109
visit(const Select * op)110 Expr visit(const Select *op) override {
111 Expr cond = mutate(op->condition);
112 Expr t = mutate(op->true_value);
113 Expr f = mutate(op->false_value);
114
115 if (!cond.defined()) {
116 return Expr();
117 }
118
119 if (!t.defined() && !f.defined()) {
120 return Expr();
121 }
122
123 if (!t.defined()) {
124 // Swap the cases so that we only need to deal with the
125 // case when false is not defined below.
126 cond = Not::make(cond);
127 t = f;
128 f = Expr();
129 }
130
131 if (!f.defined()) {
132 // We need to convert this to an if-then-else
133 if (predicate.defined()) {
134 predicate = predicate && cond;
135 } else {
136 predicate = cond;
137 }
138 return t;
139 } else if (cond.same_as(op->condition) &&
140 t.same_as(op->true_value) &&
141 f.same_as(op->false_value)) {
142 return op;
143 } else {
144 return Select::make(cond, t, f);
145 }
146 }
147
visit(const Load * op)148 Expr visit(const Load *op) override {
149 Expr pred = mutate(op->predicate);
150 if (!pred.defined()) return Expr();
151 Expr index = mutate(op->index);
152 if (!index.defined()) return Expr();
153 if (pred.same_as(op->predicate) && index.same_as(op->index)) {
154 return op;
155 } else {
156 return Load::make(op->type, op->name, index, op->image, op->param, pred, op->alignment);
157 }
158 }
159
visit(const Ramp * op)160 Expr visit(const Ramp *op) override {
161 Expr base = mutate(op->base);
162 if (!base.defined()) return Expr();
163 Expr stride = mutate(op->stride);
164 if (!stride.defined()) return Expr();
165 if (base.same_as(op->base) &&
166 stride.same_as(op->stride)) {
167 return op;
168 } else {
169 return Ramp::make(base, stride, op->lanes);
170 }
171 }
172
visit(const Broadcast * op)173 Expr visit(const Broadcast *op) override {
174 Expr value = mutate(op->value);
175 if (!value.defined()) return Expr();
176 if (value.same_as(op->value)) {
177 return op;
178 } else {
179 return Broadcast::make(value, op->lanes);
180 }
181 }
182
visit(const Call * op)183 Expr visit(const Call *op) override {
184 if (op->is_intrinsic(Call::undef)) {
185 return Expr();
186 }
187
188 vector<Expr> new_args(op->args.size());
189 bool changed = false;
190
191 // Mutate the args
192 for (size_t i = 0; i < op->args.size(); i++) {
193 Expr old_arg = op->args[i];
194 Expr new_arg = mutate(old_arg);
195 if (!new_arg.defined()) return Expr();
196 if (!new_arg.same_as(old_arg)) changed = true;
197 new_args[i] = new_arg;
198 }
199
200 if (!changed) {
201 return op;
202 } else {
203 return Call::make(op->type, op->name, new_args, op->call_type,
204 op->func, op->value_index, op->image, op->param);
205 }
206 }
207
208 template<typename T, typename Body>
visit_let(const T * op)209 Body visit_let(const T *op) {
210 // Visit an entire chain of lets in a single method to conserve stack space.
211 struct Frame {
212 const T *op;
213 Expr new_value;
214 ScopedBinding<> binding;
215 Frame(const T *op, Expr v, Scope<> &scope)
216 : op(op), new_value(std::move(v)),
217 binding(!new_value.defined(), scope, op->name) {
218 }
219 };
220 vector<Frame> frames;
221
222 Body result;
223 do {
224 frames.emplace_back(op, mutate(op->value), dead_vars);
225 result = op->body;
226 } while ((op = result.template as<T>()));
227
228 result = mutate(result);
229
230 if (result.defined()) {
231 for (auto it = frames.rbegin(); it != frames.rend(); it++) {
232 if (!it->new_value.defined()) continue;
233 predicate = substitute(it->op->name, it->new_value, predicate);
234 if (it->new_value.same_as(it->op->value) && result.same_as(it->op->body)) {
235 result = it->op;
236 } else {
237 result = T::make(it->op->name, std::move(it->new_value), result);
238 }
239 }
240 }
241
242 return result;
243 }
244
visit(const Let * op)245 Expr visit(const Let *op) override {
246 return visit_let<Let, Expr>(op);
247 }
248
visit(const LetStmt * op)249 Stmt visit(const LetStmt *op) override {
250 return visit_let<LetStmt, Stmt>(op);
251 }
252
visit(const AssertStmt * op)253 Stmt visit(const AssertStmt *op) override {
254 Expr condition = mutate(op->condition);
255 if (!condition.defined()) {
256 return Stmt();
257 }
258
259 Expr message = mutate(op->message);
260 if (!message.defined()) {
261 return Stmt();
262 }
263
264 if (condition.same_as(op->condition) && message.same_as(op->message)) {
265 return op;
266 } else {
267 return AssertStmt::make(condition, message);
268 }
269 }
270
visit(const ProducerConsumer * op)271 Stmt visit(const ProducerConsumer *op) override {
272 Stmt body = mutate(op->body);
273 if (!body.defined()) return Stmt();
274 if (body.same_as(op->body)) {
275 return op;
276 } else {
277 return ProducerConsumer::make(op->name, op->is_producer, body);
278 }
279 }
280
visit(const For * op)281 Stmt visit(const For *op) override {
282 Expr min = mutate(op->min);
283 if (!min.defined()) {
284 return Stmt();
285 }
286 Expr extent = mutate(op->extent);
287 if (!extent.defined()) {
288 return Stmt();
289 }
290 Stmt body = mutate(op->body);
291 if (!body.defined()) return Stmt();
292 if (min.same_as(op->min) &&
293 extent.same_as(op->extent) &&
294 body.same_as(op->body)) {
295 return op;
296 } else {
297 return For::make(op->name, min, extent, op->for_type, op->device_api, body);
298 }
299 }
300
visit(const Store * op)301 Stmt visit(const Store *op) override {
302 predicate = Expr();
303
304 Expr pred = mutate(op->predicate);
305 Expr value = mutate(op->value);
306 if (!value.defined()) {
307 return Stmt();
308 }
309
310 Expr index = mutate(op->index);
311 if (!index.defined()) {
312 return Stmt();
313 }
314
315 if (predicate.defined()) {
316 // This becomes a conditional store
317 Stmt stmt = IfThenElse::make(predicate, Store::make(op->name, value, index, op->param, pred, op->alignment));
318 predicate = Expr();
319 return stmt;
320 } else if (pred.same_as(op->predicate) &&
321 value.same_as(op->value) &&
322 index.same_as(op->index)) {
323 return op;
324 } else {
325 return Store::make(op->name, value, index, op->param, pred, op->alignment);
326 }
327 }
328
visit(const Provide * op)329 Stmt visit(const Provide *op) override {
330 predicate = Expr();
331
332 vector<Expr> new_args(op->args.size());
333 vector<Expr> new_values(op->values.size());
334 vector<Expr> args_predicates;
335 vector<Expr> values_predicates;
336 bool changed = false;
337
338 // Mutate the args
339 for (size_t i = 0; i < op->args.size(); i++) {
340 Expr old_arg = op->args[i];
341 predicate = Expr();
342 Expr new_arg = mutate(old_arg);
343 if (!new_arg.defined()) {
344 return Stmt();
345 }
346 args_predicates.push_back(predicate);
347 if (!new_arg.same_as(old_arg)) changed = true;
348 new_args[i] = new_arg;
349 }
350
351 for (size_t i = 1; i < args_predicates.size(); i++) {
352 user_assert(equal(args_predicates[i - 1], args_predicates[i]))
353 << "Conditionally-undef args in a Tuple should have the same conditions\n"
354 << " Condition " << i - 1 << ": " << args_predicates[i - 1] << "\n"
355 << " Condition " << i << ": " << args_predicates[i] << "\n";
356 }
357
358 bool all_values_undefined = true;
359 for (size_t i = 0; i < op->values.size(); i++) {
360 Expr old_value = op->values[i];
361 predicate = Expr();
362 Expr new_value = mutate(old_value);
363 if (!new_value.defined()) {
364 new_value = undef(old_value.type());
365 } else {
366 all_values_undefined = false;
367 values_predicates.push_back(predicate);
368 }
369 if (!new_value.same_as(old_value)) changed = true;
370 new_values[i] = new_value;
371 }
372
373 if (all_values_undefined) {
374 return Stmt();
375 }
376
377 for (size_t i = 1; i < values_predicates.size(); i++) {
378 user_assert(equal(values_predicates[i - 1], values_predicates[i]))
379 << "Conditionally-undef values in a Tuple should have the same conditions\n"
380 << " Condition " << i - 1 << ": " << values_predicates[i - 1] << "\n"
381 << " Condition " << i << ": " << values_predicates[i] << "\n";
382 }
383
384 if (predicate.defined()) {
385 Stmt stmt = IfThenElse::make(predicate, Provide::make(op->name, new_values, new_args));
386 predicate = Expr();
387 return stmt;
388 } else if (!changed) {
389 return op;
390 } else {
391 return Provide::make(op->name, new_values, new_args);
392 }
393 }
394
visit(const Allocate * op)395 Stmt visit(const Allocate *op) override {
396 std::vector<Expr> new_extents;
397 bool all_extents_unmodified = true;
398 for (size_t i = 0; i < op->extents.size(); i++) {
399 new_extents.push_back(mutate(op->extents[i]));
400 if (!new_extents.back().defined()) {
401 return Stmt();
402 }
403 all_extents_unmodified &= new_extents[i].same_as(op->extents[i]);
404 }
405 Stmt body = mutate(op->body);
406 if (!body.defined()) return Stmt();
407
408 Expr condition = mutate(op->condition);
409 if (!condition.defined()) return Stmt();
410
411 Expr new_expr;
412 if (op->new_expr.defined()) {
413 new_expr = mutate(op->new_expr);
414 }
415
416 if (all_extents_unmodified &&
417 body.same_as(op->body) &&
418 condition.same_as(op->condition) &&
419 new_expr.same_as(op->new_expr)) {
420 return op;
421 } else {
422 return Allocate::make(op->name, op->type, op->memory_type,
423 new_extents, condition, body, new_expr, op->free_function);
424 }
425 }
426
visit(const Free * op)427 Stmt visit(const Free *op) override {
428 return op;
429 }
430
visit(const Realize * op)431 Stmt visit(const Realize *op) override {
432 Region new_bounds(op->bounds.size());
433 bool bounds_changed = false;
434
435 // Mutate the bounds
436 for (size_t i = 0; i < op->bounds.size(); i++) {
437 Expr old_min = op->bounds[i].min;
438 Expr old_extent = op->bounds[i].extent;
439 Expr new_min = mutate(old_min);
440 if (!new_min.defined()) {
441 return Stmt();
442 }
443 Expr new_extent = mutate(old_extent);
444 if (!new_extent.defined()) {
445 return Stmt();
446 }
447 if (!new_min.same_as(old_min)) {
448 bounds_changed = true;
449 }
450 if (!new_extent.same_as(old_extent)) {
451 bounds_changed = true;
452 }
453 new_bounds[i] = Range(new_min, new_extent);
454 }
455
456 Stmt body = mutate(op->body);
457 if (!body.defined()) return Stmt();
458
459 Expr condition = mutate(op->condition);
460 if (!condition.defined()) return Stmt();
461
462 if (!bounds_changed &&
463 body.same_as(op->body) &&
464 condition.same_as(op->condition)) {
465 return op;
466 } else {
467 return Realize::make(op->name, op->types, op->memory_type, new_bounds, condition, body);
468 }
469 }
470
visit(const Block * op)471 Stmt visit(const Block *op) override {
472 // Visit a sequence of blocks in a single method to conserve stack space.
473 Stmt result;
474 vector<std::pair<const Block *, Stmt>> frames;
475
476 do {
477 Stmt next = mutate(op->first);
478 if (next.defined()) {
479 frames.emplace_back(op, std::move(next));
480 }
481 result = op->rest;
482 } while ((op = result.as<Block>()));
483
484 result = mutate(result);
485
486 for (auto it = frames.rbegin(); it != frames.rend(); it++) {
487 op = it->first;
488 Stmt new_first = std::move(it->second);
489 if (!result.defined()) {
490 result = new_first;
491 } else if (new_first.same_as(op->first) && result.same_as(op->rest)) {
492 result = op;
493 } else {
494 result = Block::make(new_first, result);
495 }
496 }
497 return result;
498 }
499
visit(const IfThenElse * op)500 Stmt visit(const IfThenElse *op) override {
501 Expr condition = mutate(op->condition);
502 if (!condition.defined()) {
503 return Stmt();
504 }
505 Stmt then_case = mutate(op->then_case);
506 Stmt else_case = mutate(op->else_case);
507
508 if (!then_case.defined() && !else_case.defined()) {
509 return Stmt();
510 }
511
512 if (!then_case.defined()) {
513 condition = Not::make(condition);
514 then_case = else_case;
515 else_case = Stmt();
516 }
517
518 if (condition.same_as(op->condition) &&
519 then_case.same_as(op->then_case) &&
520 else_case.same_as(op->else_case)) {
521 return op;
522 } else {
523 return IfThenElse::make(condition, then_case, else_case);
524 }
525 }
526
visit(const Evaluate * op)527 Stmt visit(const Evaluate *op) override {
528 Expr v = mutate(op->value);
529 if (!v.defined()) {
530 return Stmt();
531 } else if (v.same_as(op->value)) {
532 return op;
533 } else {
534 return Evaluate::make(v);
535 }
536 }
537 };
538
remove_undef(Stmt s)539 Stmt remove_undef(Stmt s) {
540 RemoveUndef r;
541 s = r.mutate(s);
542 internal_assert(!r.predicate.defined())
543 << "Undefined expression leaked outside of a Store node: "
544 << r.predicate << "\n";
545 return s;
546 }
547
548 } // namespace Internal
549 } // namespace Halide
550