1 #include "Simplify_Internal.h"
2 
3 #include "IRMutator.h"
4 #include "Substitute.h"
5 
6 namespace Halide {
7 namespace Internal {
8 
9 using std::pair;
10 using std::string;
11 using std::vector;
12 
// Simplify an IfThenElse statement.
//
// The condition is simplified first. If it is a constant, only the live
// branch is kept; a likely() or likely_if_innermost() wrapper is peeled off
// before that test. Otherwise each branch is simplified with the condition
// assumed true (then) or false (else). Structure common to both branches is
// then pulled out of the if where possible.
Stmt Simplify::visit(const IfThenElse *op) {
    Expr condition = mutate(op->condition, nullptr);

    // If (likely(true)) ...
    // The likely-style intrinsics are looked through when testing for
    // constants and when learning facts. Any rebuilt IfThenElse below still
    // uses the wrapped `condition`.
    const Call *call = condition.as<Call>();
    Expr unwrapped_condition = condition;
    if (call &&
        (call->is_intrinsic(Call::likely) ||
         call->is_intrinsic(Call::likely_if_innermost))) {
        unwrapped_condition = call->args[0];
    }

    // If (true) ...
    if (is_one(unwrapped_condition)) {
        return mutate(op->then_case);
    }

    // If (false) ...
    if (is_zero(unwrapped_condition)) {
        if (op->else_case.defined()) {
            return mutate(op->else_case);
        } else {
            return Evaluate::make(0);
        }
    }

    Stmt then_case, else_case;
    {
        // Simplify the then branch while the condition is learned as true.
        auto f = scoped_truth(unwrapped_condition);
        // Also substitute the entire condition
        // (uses of the original, unsimplified condition expression inside
        // the branch are replaced by a constant true of matching lanes).
        then_case = substitute(op->condition, const_true(condition.type().lanes()), op->then_case);
        then_case = mutate(then_case);
    }
    {
        // Simplify the else branch while the condition is learned as false.
        // op->else_case may be undefined here (see the defined() check above).
        auto f = scoped_falsehood(unwrapped_condition);
        else_case = substitute(op->condition, const_false(condition.type().lanes()), op->else_case);
        else_case = mutate(else_case);
    }

    // If both sides are no-ops, bail out.
    if (is_no_op(then_case) && is_no_op(else_case)) {
        return then_case;
    }

    // Pull out common nodes
    // Identical branches make the test irrelevant.
    if (equal(then_case, else_case)) {
        return then_case;
    }
    const Acquire *then_acquire = then_case.as<Acquire>();
    const Acquire *else_acquire = else_case.as<Acquire>();
    const ProducerConsumer *then_pc = then_case.as<ProducerConsumer>();
    const ProducerConsumer *else_pc = else_case.as<ProducerConsumer>();
    const Block *then_block = then_case.as<Block>();
    const Block *else_block = else_case.as<Block>();
    const For *then_for = then_case.as<For>();
    // The branches are tried in order; the first matching rewrite wins. Each
    // hoisting rewrite re-simplifies the smaller IfThenElse it leaves behind.
    if (then_acquire &&
        else_acquire &&
        equal(then_acquire->semaphore, else_acquire->semaphore) &&
        equal(then_acquire->count, else_acquire->count)) {
        // Both branches acquire the same semaphore/count: hoist the Acquire.
        return Acquire::make(then_acquire->semaphore, then_acquire->count,
                             mutate(IfThenElse::make(condition, then_acquire->body, else_acquire->body)));
    } else if (then_pc &&
               else_pc &&
               then_pc->name == else_pc->name &&
               then_pc->is_producer == else_pc->is_producer) {
        // Both branches are the same producer/consumer node: hoist it.
        return ProducerConsumer::make(then_pc->name, then_pc->is_producer,
                                      mutate(IfThenElse::make(condition, then_pc->body, else_pc->body)));
    } else if (then_block &&
               else_block &&
               equal(then_block->first, else_block->first)) {
        // Common leading statement: emit it unconditionally first.
        return Block::make(then_block->first,
                           mutate(IfThenElse::make(condition, then_block->rest, else_block->rest)));
    } else if (then_block &&
               else_block &&
               equal(then_block->rest, else_block->rest)) {
        // Common trailing statement(s): emit them unconditionally afterwards.
        return Block::make(mutate(IfThenElse::make(condition, then_block->first, else_block->first)),
                           then_block->rest);
    } else if (then_block && equal(then_block->first, else_case)) {
        // then = {S; R}, else = S  ==>  S; if (c) R
        return Block::make(else_case,
                           mutate(IfThenElse::make(condition, then_block->rest)));
    } else if (then_block && equal(then_block->rest, else_case)) {
        // then = {F; S}, else = S  ==>  if (c) F; S
        return Block::make(mutate(IfThenElse::make(condition, then_block->first)),
                           else_case);
    } else if (else_block && equal(then_case, else_block->first)) {
        // then = S, else = {S; R}  ==>  S; if (!c) R
        return Block::make(then_case,
                           mutate(IfThenElse::make(condition, Evaluate::make(0), else_block->rest)));
    } else if (else_block && equal(then_case, else_block->rest)) {
        // then = S, else = {F; S}  ==>  if (!c) F; S
        return Block::make(mutate(IfThenElse::make(condition, Evaluate::make(0), else_block->first)),
                           then_case);
    } else if (then_for &&
               !else_case.defined() &&
               equal(unwrapped_condition, 0 < then_for->extent)) {
        // This guard is redundant
        // (it only checks that the loop it wraps has a positive extent).
        return then_case;
    } else if (condition.same_as(op->condition) &&
               then_case.same_as(op->then_case) &&
               else_case.same_as(op->else_case)) {
        return op;
    } else {
        return IfThenElse::make(condition, then_case, else_case);
    }
}
115 
visit(const AssertStmt * op)116 Stmt Simplify::visit(const AssertStmt *op) {
117     Expr cond = mutate(op->condition, nullptr);
118 
119     // The message is only evaluated when the condition is false
120     Expr message;
121     {
122         auto f = scoped_falsehood(cond);
123         message = mutate(op->message, nullptr);
124     }
125 
126     if (is_zero(cond)) {
127         // Usually, assert(const-false) should generate a warning;
128         // in at least one case (specialize_fail()), we want to suppress
129         // the warning, because the assertion is generated internally
130         // by Halide and is expected to always fail.
131         const Call *call = message.as<Call>();
132         const bool const_false_conditions_expected =
133             call && call->name == "halide_error_specialize_fail";
134         if (!const_false_conditions_expected) {
135             user_warning << "This pipeline is guaranteed to fail an assertion at runtime: \n"
136                          << message << "\n";
137         }
138     } else if (is_one(cond)) {
139         return Evaluate::make(0);
140     }
141 
142     if (cond.same_as(op->condition) && message.same_as(op->message)) {
143         return op;
144     } else {
145         return AssertStmt::make(cond, message);
146     }
147 }
148 
visit(const For * op)149 Stmt Simplify::visit(const For *op) {
150     ExprInfo min_bounds, extent_bounds;
151     Expr new_min = mutate(op->min, &min_bounds);
152     Expr new_extent = mutate(op->extent, &extent_bounds);
153 
154     ScopedValue<bool> old_in_vector_loop(in_vector_loop,
155                                          (in_vector_loop ||
156                                           op->for_type == ForType::Vectorized));
157 
158     bool bounds_tracked = false;
159     if (min_bounds.min_defined || (min_bounds.max_defined && extent_bounds.max_defined)) {
160         min_bounds.max += extent_bounds.max - 1;
161         min_bounds.max_defined &= extent_bounds.max_defined;
162         min_bounds.alignment = ModulusRemainder{};
163         bounds_tracked = true;
164         bounds_and_alignment_info.push(op->name, min_bounds);
165     }
166 
167     Stmt new_body = mutate(op->body);
168 
169     if (bounds_tracked) {
170         bounds_and_alignment_info.pop(op->name);
171     }
172 
173     if (is_no_op(new_body)) {
174         return new_body;
175     } else if (extent_bounds.max_defined &&
176                extent_bounds.max <= 0) {
177         return Evaluate::make(0);
178     } else if (is_one(new_extent) &&
179                op->device_api == DeviceAPI::None) {
180         Stmt s = LetStmt::make(op->name, new_min, new_body);
181         return mutate(s);
182     } else if (extent_bounds.max_defined &&
183                extent_bounds.max == 1 &&
184                !in_vector_loop &&
185                op->device_api == DeviceAPI::None) {
186         // If we're inside a vector loop we don't want to rewrite a
187         // for loop of extent at most one into an if, because the
188         // vectorization pass deals with those differently to an
189         // if. If the extent depends on the vectorized variable, the
190         // for loop gets an all-true vectorized case, but an if
191         // statement just gets scalarized.
192         Stmt s = LetStmt::make(op->name, new_min, new_body);
193         return mutate(IfThenElse::make(0 < new_extent, s));
194     } else if (op->min.same_as(new_min) &&
195                op->extent.same_as(new_extent) &&
196                op->body.same_as(new_body)) {
197         return op;
198     } else {
199         return For::make(op->name, new_min, new_extent, op->for_type, op->device_api, new_body);
200     }
201 }
202 
visit(const Provide * op)203 Stmt Simplify::visit(const Provide *op) {
204     found_buffer_reference(op->name, op->args.size());
205 
206     vector<Expr> new_args(op->args.size());
207     vector<Expr> new_values(op->values.size());
208     bool changed = false;
209 
210     // Mutate the args
211     for (size_t i = 0; i < op->args.size(); i++) {
212         const Expr &old_arg = op->args[i];
213         Expr new_arg = mutate(old_arg, nullptr);
214         if (!new_arg.same_as(old_arg)) changed = true;
215         new_args[i] = new_arg;
216     }
217 
218     for (size_t i = 0; i < op->values.size(); i++) {
219         const Expr &old_value = op->values[i];
220         Expr new_value = mutate(old_value, nullptr);
221         if (!new_value.same_as(old_value)) changed = true;
222         new_values[i] = new_value;
223     }
224 
225     if (!changed) {
226         return op;
227     } else {
228         return Provide::make(op->name, new_values, new_args);
229     }
230 }
231 
visit(const Store * op)232 Stmt Simplify::visit(const Store *op) {
233     found_buffer_reference(op->name);
234 
235     Expr predicate = mutate(op->predicate, nullptr);
236     Expr value = mutate(op->value, nullptr);
237 
238     ExprInfo index_info;
239     Expr index = mutate(op->index, &index_info);
240 
241     ExprInfo base_info;
242     if (const Ramp *r = index.as<Ramp>()) {
243         mutate(r->base, &base_info);
244     }
245     base_info.alignment = ModulusRemainder::intersect(base_info.alignment, index_info.alignment);
246 
247     const Load *load = value.as<Load>();
248     const Broadcast *scalar_pred = predicate.as<Broadcast>();
249 
250     ModulusRemainder align = ModulusRemainder::intersect(op->alignment, base_info.alignment);
251 
252     if (is_zero(predicate)) {
253         // Predicate is always false
254         return Evaluate::make(0);
255     } else if (scalar_pred && !is_one(scalar_pred->value)) {
256         return IfThenElse::make(scalar_pred->value,
257                                 Store::make(op->name, value, index, op->param, const_true(value.type().lanes()), align));
258     } else if (is_undef(value) || (load && load->name == op->name && equal(load->index, index))) {
259         // foo[x] = foo[x] or foo[x] = undef is a no-op
260         return Evaluate::make(0);
261     } else if (predicate.same_as(op->predicate) && value.same_as(op->value) && index.same_as(op->index) && align == op->alignment) {
262         return op;
263     } else {
264         return Store::make(op->name, value, index, op->param, predicate, align);
265     }
266 }
267 
visit(const Allocate * op)268 Stmt Simplify::visit(const Allocate *op) {
269     std::vector<Expr> new_extents;
270     bool all_extents_unmodified = true;
271     for (size_t i = 0; i < op->extents.size(); i++) {
272         new_extents.push_back(mutate(op->extents[i], nullptr));
273         all_extents_unmodified &= new_extents[i].same_as(op->extents[i]);
274     }
275     Stmt body = mutate(op->body);
276     Expr condition = mutate(op->condition, nullptr);
277     Expr new_expr;
278     if (op->new_expr.defined()) {
279         new_expr = mutate(op->new_expr, nullptr);
280     }
281     const IfThenElse *body_if = body.as<IfThenElse>();
282     if (body_if &&
283         op->condition.defined() &&
284         equal(op->condition, body_if->condition)) {
285         // We can move the allocation into the if body case. The
286         // else case must not use it.
287         Stmt stmt = Allocate::make(op->name, op->type, op->memory_type,
288                                    new_extents, condition, body_if->then_case,
289                                    new_expr, op->free_function);
290         return IfThenElse::make(body_if->condition, stmt, body_if->else_case);
291     } else if (all_extents_unmodified &&
292                body.same_as(op->body) &&
293                condition.same_as(op->condition) &&
294                new_expr.same_as(op->new_expr)) {
295         return op;
296     } else {
297         return Allocate::make(op->name, op->type, op->memory_type,
298                               new_extents, condition, body,
299                               new_expr, op->free_function);
300     }
301 }
302 
visit(const Evaluate * op)303 Stmt Simplify::visit(const Evaluate *op) {
304     Expr value = mutate(op->value, nullptr);
305 
306     // Rewrite Lets inside an evaluate as LetStmts outside the Evaluate.
307     vector<pair<string, Expr>> lets;
308     while (const Let *let = value.as<Let>()) {
309         lets.emplace_back(let->name, let->value);
310         value = let->body;
311     }
312 
313     if (value.same_as(op->value)) {
314         internal_assert(lets.empty());
315         return op;
316     } else {
317         // Rewrap the lets outside the evaluate node
318         Stmt stmt = Evaluate::make(value);
319         for (size_t i = lets.size(); i > 0; i--) {
320             stmt = LetStmt::make(lets[i - 1].first, lets[i - 1].second, stmt);
321         }
322         return stmt;
323     }
324 }
325 
visit(const ProducerConsumer * op)326 Stmt Simplify::visit(const ProducerConsumer *op) {
327     Stmt body = mutate(op->body);
328 
329     if (is_no_op(body)) {
330         return Evaluate::make(0);
331     } else if (body.same_as(op->body)) {
332         return op;
333     } else {
334         return ProducerConsumer::make(op->name, op->is_producer, body);
335     }
336 }
337 
visit(const Block * op)338 Stmt Simplify::visit(const Block *op) {
339     Stmt first = mutate(op->first);
340     Stmt rest = op->rest;
341 
342     if (const AssertStmt *first_assert = first.as<AssertStmt>()) {
343         // Handle an entire sequence of asserts here to avoid a deeply
344         // nested stack.  We won't be popping any knowledge until
345         // after the end of this chain of asserts, so we can use a
346         // single ScopedFact and progressively add knowledge to it.
347         ScopedFact knowledge(this);
348         vector<Stmt> result;
349         result.push_back(first);
350         knowledge.learn_true(first_assert->condition);
351 
352         // Loop invariants: 'first' has already been mutated and is in
353         // the result list. 'first' was an AssertStmt before it was
354         // mutated, and its condition has been captured in
355         // 'knowledge'. 'rest' has not been mutated and is not in the
356         // result list.
357         const Block *rest_block;
358         while ((rest_block = rest.as<Block>()) &&
359                (first_assert = rest_block->first.as<AssertStmt>())) {
360             first = mutate(first_assert);
361             rest = rest_block->rest;
362             result.push_back(first);
363             if ((first_assert = first.as<AssertStmt>())) {
364                 // If it didn't fold away to trivially true or false,
365                 // learn the condition.
366                 knowledge.learn_true(first_assert->condition);
367             }
368         }
369 
370         result.push_back(mutate(rest));
371 
372         return Block::make(result);
373 
374     } else {
375         rest = mutate(op->rest);
376     }
377 
378     // Check if both halves start with a let statement.
379     const LetStmt *let_first = first.as<LetStmt>();
380     const LetStmt *let_rest = rest.as<LetStmt>();
381     const Block *block_rest = rest.as<Block>();
382     const IfThenElse *if_first = first.as<IfThenElse>();
383     const IfThenElse *if_next =
384         rest.as<IfThenElse>() ? rest.as<IfThenElse>() : (block_rest ? block_rest->first.as<IfThenElse>() : nullptr);
385     Stmt if_rest = block_rest ? block_rest->rest : Stmt();
386 
387     if (is_no_op(first) &&
388         is_no_op(rest)) {
389         return Evaluate::make(0);
390     } else if (is_no_op(first)) {
391         return rest;
392     } else if (is_no_op(rest)) {
393         return first;
394     } else if (let_first &&
395                let_rest &&
396                equal(let_first->value, let_rest->value) &&
397                is_pure(let_first->value)) {
398 
399         // Do both first and rest start with the same let statement (occurs when unrolling).
400         Stmt new_block = mutate(Block::make(let_first->body, let_rest->body));
401 
402         // We need to make a new name since we're pulling it out to a
403         // different scope.
404         string var_name = unique_name('t');
405         Expr new_var = Variable::make(let_first->value.type(), var_name);
406         new_block = substitute(let_first->name, new_var, new_block);
407         new_block = substitute(let_rest->name, new_var, new_block);
408 
409         return LetStmt::make(var_name, let_first->value, new_block);
410     } else if (if_first &&
411                if_next &&
412                equal(if_first->condition, if_next->condition) &&
413                is_pure(if_first->condition)) {
414         // Two ifs with matching conditions
415         Stmt then_case = mutate(Block::make(if_first->then_case, if_next->then_case));
416         Stmt else_case;
417         if (if_first->else_case.defined() && if_next->else_case.defined()) {
418             else_case = mutate(Block::make(if_first->else_case, if_next->else_case));
419         } else if (if_first->else_case.defined()) {
420             // We already simplified the body of the ifs.
421             else_case = if_first->else_case;
422         } else {
423             else_case = if_next->else_case;
424         }
425         Stmt result = IfThenElse::make(if_first->condition, then_case, else_case);
426         if (if_rest.defined()) {
427             result = Block::make(result, if_rest);
428         }
429         return result;
430     } else if (if_first &&
431                if_next &&
432                !if_next->else_case.defined() &&
433                is_pure(if_first->condition) &&
434                is_pure(if_next->condition) &&
435                is_one(mutate((if_first->condition && if_next->condition) == if_next->condition, nullptr))) {
436         // Two ifs where the second condition is tighter than
437         // the first condition.  The second if can be nested
438         // inside the first one, because if it's true the
439         // first one must also be true.
440         Stmt then_case = mutate(Block::make(if_first->then_case, if_next));
441         Stmt else_case = mutate(if_first->else_case);
442         Stmt result = IfThenElse::make(if_first->condition, then_case, else_case);
443         if (if_rest.defined()) {
444             result = Block::make(result, if_rest);
445         }
446         return result;
447     } else if (op->first.same_as(first) &&
448                op->rest.same_as(rest)) {
449         return op;
450     } else {
451         return Block::make(first, rest);
452     }
453 }
454 
visit(const Realize * op)455 Stmt Simplify::visit(const Realize *op) {
456     Region new_bounds;
457     bool bounds_changed;
458 
459     // Mutate the bounds
460     std::tie(new_bounds, bounds_changed) = mutate_region(this, op->bounds, nullptr);
461 
462     Stmt body = mutate(op->body);
463     Expr condition = mutate(op->condition, nullptr);
464     if (!bounds_changed &&
465         body.same_as(op->body) &&
466         condition.same_as(op->condition)) {
467         return op;
468     }
469     return Realize::make(op->name, op->types, op->memory_type, new_bounds,
470                          std::move(condition), std::move(body));
471 }
472 
visit(const Prefetch * op)473 Stmt Simplify::visit(const Prefetch *op) {
474     Stmt body = mutate(op->body);
475     Expr condition = mutate(op->condition, nullptr);
476 
477     if (is_zero(op->condition)) {
478         // Predicate is always false
479         return body;
480     }
481 
482     Region new_bounds;
483     bool bounds_changed;
484 
485     // Mutate the bounds
486     std::tie(new_bounds, bounds_changed) = mutate_region(this, op->bounds, nullptr);
487 
488     if (!bounds_changed &&
489         body.same_as(op->body) &&
490         condition.same_as(op->condition)) {
491         return op;
492     } else {
493         return Prefetch::make(op->name, op->types, new_bounds, op->prefetch, std::move(condition), std::move(body));
494     }
495 }
496 
visit(const Free * op)497 Stmt Simplify::visit(const Free *op) {
498     return op;
499 }
500 
visit(const Acquire * op)501 Stmt Simplify::visit(const Acquire *op) {
502     Expr sema = mutate(op->semaphore, nullptr);
503     Expr count = mutate(op->count, nullptr);
504     Stmt body = mutate(op->body);
505     if (sema.same_as(op->semaphore) &&
506         body.same_as(op->body) &&
507         count.same_as(op->count)) {
508         return op;
509     } else {
510         return Acquire::make(std::move(sema), std::move(count), std::move(body));
511     }
512 }
513 
visit(const Fork * op)514 Stmt Simplify::visit(const Fork *op) {
515     Stmt first = mutate(op->first);
516     Stmt rest = mutate(op->rest);
517     if (is_no_op(first)) {
518         return rest;
519     } else if (is_no_op(rest)) {
520         return first;
521     } else if (op->first.same_as(first) &&
522                op->rest.same_as(rest)) {
523         return op;
524     } else {
525         return Fork::make(first, rest);
526     }
527 }
528 
visit(const Atomic * op)529 Stmt Simplify::visit(const Atomic *op) {
530     Stmt body = mutate(op->body);
531     if (is_no_op(body)) {
532         return Evaluate::make(0);
533     } else if (body.same_as(op->body)) {
534         return op;
535     } else {
536         return Atomic::make(op->producer_name,
537                             op->mutex_name,
538                             std::move(body));
539     }
540 }
541 
542 }  // namespace Internal
543 }  // namespace Halide
544