1 #include "Simplify_Internal.h"
2
3 #include "IRMutator.h"
4 #include "Substitute.h"
5
6 namespace Halide {
7 namespace Internal {
8
9 using std::pair;
10 using std::string;
11 using std::vector;
12
// Simplify an IfThenElse statement.
//
// A condition that simplifies to a constant (optionally wrapped in a
// likely / likely_if_innermost intrinsic) selects one branch outright.
// Otherwise each branch is simplified with the condition assumed true
// (then case) or false (else case). Then any structure the two simplified
// branches share is hoisted out of the if: identical branches, a common
// enclosing Acquire or ProducerConsumer, or a common leading/trailing
// Block statement.
Stmt Simplify::visit(const IfThenElse *op) {
    Expr condition = mutate(op->condition, nullptr);

    // If (likely(true)) ...
    // Look through a likely/likely_if_innermost wrapper for the constant
    // checks and the facts learned below. 'condition' keeps the wrapper,
    // so every IfThenElse rebuilt from it further down retains the hint.
    const Call *call = condition.as<Call>();
    Expr unwrapped_condition = condition;
    if (call &&
        (call->is_intrinsic(Call::likely) ||
         call->is_intrinsic(Call::likely_if_innermost))) {
        unwrapped_condition = call->args[0];
    }

    // If (true) ...
    if (is_one(unwrapped_condition)) {
        return mutate(op->then_case);
    }

    // If (false) ...
    if (is_zero(unwrapped_condition)) {
        if (op->else_case.defined()) {
            return mutate(op->else_case);
        } else {
            // No else branch: the whole statement does nothing.
            return Evaluate::make(0);
        }
    }

    Stmt then_case, else_case;
    {
        // Simplify the then branch with the condition known to hold.
        auto f = scoped_truth(unwrapped_condition);
        // Also substitute the entire condition: any verbatim occurrence
        // of the original (unsimplified) condition expression inside the
        // branch is replaced by true.
        then_case = substitute(op->condition, const_true(condition.type().lanes()), op->then_case);
        then_case = mutate(then_case);
    }
    {
        // Simplify the else branch with the condition known to be false,
        // replacing verbatim occurrences of the original condition by false.
        auto f = scoped_falsehood(unwrapped_condition);
        else_case = substitute(op->condition, const_false(condition.type().lanes()), op->else_case);
        else_case = mutate(else_case);
    }

    // If both sides are no-ops, bail out.
    if (is_no_op(then_case) && is_no_op(else_case)) {
        return then_case;
    }

    // Pull out common nodes
    // Identical branches: the condition does not matter.
    if (equal(then_case, else_case)) {
        return then_case;
    }
    const Acquire *then_acquire = then_case.as<Acquire>();
    const Acquire *else_acquire = else_case.as<Acquire>();
    const ProducerConsumer *then_pc = then_case.as<ProducerConsumer>();
    const ProducerConsumer *else_pc = else_case.as<ProducerConsumer>();
    const Block *then_block = then_case.as<Block>();
    const Block *else_block = else_case.as<Block>();
    const For *then_for = then_case.as<For>();
    if (then_acquire &&
        else_acquire &&
        equal(then_acquire->semaphore, else_acquire->semaphore) &&
        equal(then_acquire->count, else_acquire->count)) {
        // Both branches acquire the same semaphore/count: hoist the
        // Acquire and branch only on its body.
        return Acquire::make(then_acquire->semaphore, then_acquire->count,
                             mutate(IfThenElse::make(condition, then_acquire->body, else_acquire->body)));
    } else if (then_pc &&
               else_pc &&
               then_pc->name == else_pc->name &&
               then_pc->is_producer == else_pc->is_producer) {
        // Both branches are the same producer (or consumer) node: hoist it.
        return ProducerConsumer::make(then_pc->name, then_pc->is_producer,
                                      mutate(IfThenElse::make(condition, then_pc->body, else_pc->body)));
    } else if (then_block &&
               else_block &&
               equal(then_block->first, else_block->first)) {
        // if (c) {A; B} else {A; C}  ->  A; if (c) B else C
        return Block::make(then_block->first,
                           mutate(IfThenElse::make(condition, then_block->rest, else_block->rest)));
    } else if (then_block &&
               else_block &&
               equal(then_block->rest, else_block->rest)) {
        // if (c) {A; C} else {B; C}  ->  if (c) A else B; C
        return Block::make(mutate(IfThenElse::make(condition, then_block->first, else_block->first)),
                           then_block->rest);
    } else if (then_block && equal(then_block->first, else_case)) {
        // if (c) {A; B} else A  ->  A; if (c) B
        return Block::make(else_case,
                           mutate(IfThenElse::make(condition, then_block->rest)));
    } else if (then_block && equal(then_block->rest, else_case)) {
        // if (c) {A; B} else B  ->  if (c) A; B
        return Block::make(mutate(IfThenElse::make(condition, then_block->first)),
                           else_case);
    } else if (else_block && equal(then_case, else_block->first)) {
        // if (c) A else {A; B}  ->  A; if (!c) B   (expressed with an empty then case)
        return Block::make(then_case,
                           mutate(IfThenElse::make(condition, Evaluate::make(0), else_block->rest)));
    } else if (else_block && equal(then_case, else_block->rest)) {
        // if (c) B else {A; B}  ->  if (!c) A; B   (expressed with an empty then case)
        return Block::make(mutate(IfThenElse::make(condition, Evaluate::make(0), else_block->first)),
                           then_case);
    } else if (then_for &&
               !else_case.defined() &&
               equal(unwrapped_condition, 0 < then_for->extent)) {
        // This guard is redundant
        // (if (0 < extent) for (...; extent) ... with no else: the loop
        // already executes no iterations when its extent is not positive).
        return then_case;
    } else if (condition.same_as(op->condition) &&
               then_case.same_as(op->then_case) &&
               else_case.same_as(op->else_case)) {
        // Nothing changed: hand back the original node.
        return op;
    } else {
        return IfThenElse::make(condition, then_case, else_case);
    }
}
115
visit(const AssertStmt * op)116 Stmt Simplify::visit(const AssertStmt *op) {
117 Expr cond = mutate(op->condition, nullptr);
118
119 // The message is only evaluated when the condition is false
120 Expr message;
121 {
122 auto f = scoped_falsehood(cond);
123 message = mutate(op->message, nullptr);
124 }
125
126 if (is_zero(cond)) {
127 // Usually, assert(const-false) should generate a warning;
128 // in at least one case (specialize_fail()), we want to suppress
129 // the warning, because the assertion is generated internally
130 // by Halide and is expected to always fail.
131 const Call *call = message.as<Call>();
132 const bool const_false_conditions_expected =
133 call && call->name == "halide_error_specialize_fail";
134 if (!const_false_conditions_expected) {
135 user_warning << "This pipeline is guaranteed to fail an assertion at runtime: \n"
136 << message << "\n";
137 }
138 } else if (is_one(cond)) {
139 return Evaluate::make(0);
140 }
141
142 if (cond.same_as(op->condition) && message.same_as(op->message)) {
143 return op;
144 } else {
145 return AssertStmt::make(cond, message);
146 }
147 }
148
// Simplify a For loop.
//
// The min and extent are simplified first. When bounds for the loop
// variable can be derived from them, they are made visible while the body
// is simplified. Afterwards:
//  - an empty body, or an extent known to be <= 0, yields a no-op;
//  - an extent of exactly one (no device API) becomes a LetStmt;
//  - an extent known to be at most one (no device API, not inside a
//    vectorized loop) becomes a LetStmt guarded by `0 < extent`.
Stmt Simplify::visit(const For *op) {
    ExprInfo min_bounds, extent_bounds;
    Expr new_min = mutate(op->min, &min_bounds);
    Expr new_extent = mutate(op->extent, &extent_bounds);

    // Mark that we are inside a vectorized loop if this loop (or an
    // enclosing one) is vectorized. ScopedValue is presumably restoring the
    // previous value when this visit returns; confirm in its definition.
    ScopedValue<bool> old_in_vector_loop(in_vector_loop,
                                         (in_vector_loop ||
                                          op->for_type == ForType::Vectorized));

    // The loop variable takes values in [min, min + extent - 1]. Its lower
    // bound is the lower bound of the min. Its upper bound is
    // max(min) + max(extent) - 1, which is only valid when both maxima are
    // known. Nothing is known about the variable's alignment.
    bool bounds_tracked = false;
    if (min_bounds.min_defined || (min_bounds.max_defined && extent_bounds.max_defined)) {
        min_bounds.max += extent_bounds.max - 1;
        min_bounds.max_defined &= extent_bounds.max_defined;
        min_bounds.alignment = ModulusRemainder{};
        bounds_tracked = true;
        bounds_and_alignment_info.push(op->name, min_bounds);
    }

    Stmt new_body = mutate(op->body);

    // The loop variable's bounds are only meaningful inside the body.
    if (bounds_tracked) {
        bounds_and_alignment_info.pop(op->name);
    }

    if (is_no_op(new_body)) {
        return new_body;
    } else if (extent_bounds.max_defined &&
               extent_bounds.max <= 0) {
        // The loop never runs.
        return Evaluate::make(0);
    } else if (is_one(new_extent) &&
               op->device_api == DeviceAPI::None) {
        // Exactly one iteration: bind the loop variable to the min and
        // re-simplify the result.
        Stmt s = LetStmt::make(op->name, new_min, new_body);
        return mutate(s);
    } else if (extent_bounds.max_defined &&
               extent_bounds.max == 1 &&
               !in_vector_loop &&
               op->device_api == DeviceAPI::None) {
        // If we're inside a vector loop we don't want to rewrite a
        // for loop of extent at most one into an if, because the
        // vectorization pass deals with those differently to an
        // if. If the extent depends on the vectorized variable, the
        // for loop gets an all-true vectorized case, but an if
        // statement just gets scalarized.
        Stmt s = LetStmt::make(op->name, new_min, new_body);
        return mutate(IfThenElse::make(0 < new_extent, s));
    } else if (op->min.same_as(new_min) &&
               op->extent.same_as(new_extent) &&
               op->body.same_as(new_body)) {
        return op;
    } else {
        return For::make(op->name, new_min, new_extent, op->for_type, op->device_api, new_body);
    }
}
202
visit(const Provide * op)203 Stmt Simplify::visit(const Provide *op) {
204 found_buffer_reference(op->name, op->args.size());
205
206 vector<Expr> new_args(op->args.size());
207 vector<Expr> new_values(op->values.size());
208 bool changed = false;
209
210 // Mutate the args
211 for (size_t i = 0; i < op->args.size(); i++) {
212 const Expr &old_arg = op->args[i];
213 Expr new_arg = mutate(old_arg, nullptr);
214 if (!new_arg.same_as(old_arg)) changed = true;
215 new_args[i] = new_arg;
216 }
217
218 for (size_t i = 0; i < op->values.size(); i++) {
219 const Expr &old_value = op->values[i];
220 Expr new_value = mutate(old_value, nullptr);
221 if (!new_value.same_as(old_value)) changed = true;
222 new_values[i] = new_value;
223 }
224
225 if (!changed) {
226 return op;
227 } else {
228 return Provide::make(op->name, new_values, new_args);
229 }
230 }
231
visit(const Store * op)232 Stmt Simplify::visit(const Store *op) {
233 found_buffer_reference(op->name);
234
235 Expr predicate = mutate(op->predicate, nullptr);
236 Expr value = mutate(op->value, nullptr);
237
238 ExprInfo index_info;
239 Expr index = mutate(op->index, &index_info);
240
241 ExprInfo base_info;
242 if (const Ramp *r = index.as<Ramp>()) {
243 mutate(r->base, &base_info);
244 }
245 base_info.alignment = ModulusRemainder::intersect(base_info.alignment, index_info.alignment);
246
247 const Load *load = value.as<Load>();
248 const Broadcast *scalar_pred = predicate.as<Broadcast>();
249
250 ModulusRemainder align = ModulusRemainder::intersect(op->alignment, base_info.alignment);
251
252 if (is_zero(predicate)) {
253 // Predicate is always false
254 return Evaluate::make(0);
255 } else if (scalar_pred && !is_one(scalar_pred->value)) {
256 return IfThenElse::make(scalar_pred->value,
257 Store::make(op->name, value, index, op->param, const_true(value.type().lanes()), align));
258 } else if (is_undef(value) || (load && load->name == op->name && equal(load->index, index))) {
259 // foo[x] = foo[x] or foo[x] = undef is a no-op
260 return Evaluate::make(0);
261 } else if (predicate.same_as(op->predicate) && value.same_as(op->value) && index.same_as(op->index) && align == op->alignment) {
262 return op;
263 } else {
264 return Store::make(op->name, value, index, op->param, predicate, align);
265 }
266 }
267
visit(const Allocate * op)268 Stmt Simplify::visit(const Allocate *op) {
269 std::vector<Expr> new_extents;
270 bool all_extents_unmodified = true;
271 for (size_t i = 0; i < op->extents.size(); i++) {
272 new_extents.push_back(mutate(op->extents[i], nullptr));
273 all_extents_unmodified &= new_extents[i].same_as(op->extents[i]);
274 }
275 Stmt body = mutate(op->body);
276 Expr condition = mutate(op->condition, nullptr);
277 Expr new_expr;
278 if (op->new_expr.defined()) {
279 new_expr = mutate(op->new_expr, nullptr);
280 }
281 const IfThenElse *body_if = body.as<IfThenElse>();
282 if (body_if &&
283 op->condition.defined() &&
284 equal(op->condition, body_if->condition)) {
285 // We can move the allocation into the if body case. The
286 // else case must not use it.
287 Stmt stmt = Allocate::make(op->name, op->type, op->memory_type,
288 new_extents, condition, body_if->then_case,
289 new_expr, op->free_function);
290 return IfThenElse::make(body_if->condition, stmt, body_if->else_case);
291 } else if (all_extents_unmodified &&
292 body.same_as(op->body) &&
293 condition.same_as(op->condition) &&
294 new_expr.same_as(op->new_expr)) {
295 return op;
296 } else {
297 return Allocate::make(op->name, op->type, op->memory_type,
298 new_extents, condition, body,
299 new_expr, op->free_function);
300 }
301 }
302
visit(const Evaluate * op)303 Stmt Simplify::visit(const Evaluate *op) {
304 Expr value = mutate(op->value, nullptr);
305
306 // Rewrite Lets inside an evaluate as LetStmts outside the Evaluate.
307 vector<pair<string, Expr>> lets;
308 while (const Let *let = value.as<Let>()) {
309 lets.emplace_back(let->name, let->value);
310 value = let->body;
311 }
312
313 if (value.same_as(op->value)) {
314 internal_assert(lets.empty());
315 return op;
316 } else {
317 // Rewrap the lets outside the evaluate node
318 Stmt stmt = Evaluate::make(value);
319 for (size_t i = lets.size(); i > 0; i--) {
320 stmt = LetStmt::make(lets[i - 1].first, lets[i - 1].second, stmt);
321 }
322 return stmt;
323 }
324 }
325
visit(const ProducerConsumer * op)326 Stmt Simplify::visit(const ProducerConsumer *op) {
327 Stmt body = mutate(op->body);
328
329 if (is_no_op(body)) {
330 return Evaluate::make(0);
331 } else if (body.same_as(op->body)) {
332 return op;
333 } else {
334 return ProducerConsumer::make(op->name, op->is_producer, body);
335 }
336 }
337
// Simplify a Block (first; rest).
//
// A run of leading AssertStmts is handled iteratively. Each assert's
// condition is learned as true while simplifying the statements after it.
// Otherwise both halves are simplified, and then:
//  - no-op halves are dropped;
//  - two halves that begin with LetStmts binding the same pure value share
//    a single LetStmt under a fresh name;
//  - an IfThenElse followed by another IfThenElse with the same pure
//    condition are merged;
//  - an IfThenElse followed by an else-less IfThenElse whose condition
//    implies the first one's is nested inside the first.
Stmt Simplify::visit(const Block *op) {
    Stmt first = mutate(op->first);
    Stmt rest = op->rest;

    if (const AssertStmt *first_assert = first.as<AssertStmt>()) {
        // Handle an entire sequence of asserts here to avoid a deeply
        // nested stack. We won't be popping any knowledge until
        // after the end of this chain of asserts, so we can use a
        // single ScopedFact and progressively add knowledge to it.
        ScopedFact knowledge(this);
        vector<Stmt> result;
        result.push_back(first);
        knowledge.learn_true(first_assert->condition);

        // Loop invariants: 'first' has already been mutated and is in
        // the result list. 'first' was an AssertStmt before it was
        // mutated, and its condition has been captured in
        // 'knowledge'. 'rest' has not been mutated and is not in the
        // result list.
        const Block *rest_block;
        while ((rest_block = rest.as<Block>()) &&
               (first_assert = rest_block->first.as<AssertStmt>())) {
            first = mutate(first_assert);
            rest = rest_block->rest;
            result.push_back(first);
            if ((first_assert = first.as<AssertStmt>())) {
                // If it didn't fold away to trivially true or false,
                // learn the condition.
                knowledge.learn_true(first_assert->condition);
            }
        }

        // The non-assert remainder is simplified with every learned
        // condition in scope.
        result.push_back(mutate(rest));

        return Block::make(result);

    } else {
        rest = mutate(op->rest);
    }

    // Check if both halves start with a let statement.
    const LetStmt *let_first = first.as<LetStmt>();
    const LetStmt *let_rest = rest.as<LetStmt>();
    const Block *block_rest = rest.as<Block>();
    const IfThenElse *if_first = first.as<IfThenElse>();
    // The IfThenElse that immediately follows 'first': either 'rest'
    // itself, or the first statement of 'rest' when 'rest' is a Block.
    const IfThenElse *if_next =
        rest.as<IfThenElse>() ? rest.as<IfThenElse>() : (block_rest ? block_rest->first.as<IfThenElse>() : nullptr);
    // Whatever follows 'if_next' inside 'rest' (undefined when 'rest' is
    // not a Block). Appended after any merged if built below.
    Stmt if_rest = block_rest ? block_rest->rest : Stmt();

    if (is_no_op(first) &&
        is_no_op(rest)) {
        return Evaluate::make(0);
    } else if (is_no_op(first)) {
        return rest;
    } else if (is_no_op(rest)) {
        return first;
    } else if (let_first &&
               let_rest &&
               equal(let_first->value, let_rest->value) &&
               is_pure(let_first->value)) {

        // Do both first and rest start with the same let statement (occurs when unrolling).
        Stmt new_block = mutate(Block::make(let_first->body, let_rest->body));

        // We need to make a new name since we're pulling it out to a
        // different scope.
        string var_name = unique_name('t');
        Expr new_var = Variable::make(let_first->value.type(), var_name);
        new_block = substitute(let_first->name, new_var, new_block);
        new_block = substitute(let_rest->name, new_var, new_block);

        return LetStmt::make(var_name, let_first->value, new_block);
    } else if (if_first &&
               if_next &&
               equal(if_first->condition, if_next->condition) &&
               is_pure(if_first->condition)) {
        // Two ifs with matching conditions
        Stmt then_case = mutate(Block::make(if_first->then_case, if_next->then_case));
        Stmt else_case;
        if (if_first->else_case.defined() && if_next->else_case.defined()) {
            else_case = mutate(Block::make(if_first->else_case, if_next->else_case));
        } else if (if_first->else_case.defined()) {
            // We already simplified the body of the ifs.
            else_case = if_first->else_case;
        } else {
            else_case = if_next->else_case;
        }
        Stmt result = IfThenElse::make(if_first->condition, then_case, else_case);
        if (if_rest.defined()) {
            result = Block::make(result, if_rest);
        }
        return result;
    } else if (if_first &&
               if_next &&
               !if_next->else_case.defined() &&
               is_pure(if_first->condition) &&
               is_pure(if_next->condition) &&
               is_one(mutate((if_first->condition && if_next->condition) == if_next->condition, nullptr))) {
        // Two ifs where the second condition is tighter than
        // the first condition. The second if can be nested
        // inside the first one, because if it's true the
        // first one must also be true.
        // (The implication is proven by simplifying
        // (c1 && c2) == c2 to a constant true.)
        Stmt then_case = mutate(Block::make(if_first->then_case, if_next));
        Stmt else_case = mutate(if_first->else_case);
        Stmt result = IfThenElse::make(if_first->condition, then_case, else_case);
        if (if_rest.defined()) {
            result = Block::make(result, if_rest);
        }
        return result;
    } else if (op->first.same_as(first) &&
               op->rest.same_as(rest)) {
        return op;
    } else {
        return Block::make(first, rest);
    }
}
454
visit(const Realize * op)455 Stmt Simplify::visit(const Realize *op) {
456 Region new_bounds;
457 bool bounds_changed;
458
459 // Mutate the bounds
460 std::tie(new_bounds, bounds_changed) = mutate_region(this, op->bounds, nullptr);
461
462 Stmt body = mutate(op->body);
463 Expr condition = mutate(op->condition, nullptr);
464 if (!bounds_changed &&
465 body.same_as(op->body) &&
466 condition.same_as(op->condition)) {
467 return op;
468 }
469 return Realize::make(op->name, op->types, op->memory_type, new_bounds,
470 std::move(condition), std::move(body));
471 }
472
visit(const Prefetch * op)473 Stmt Simplify::visit(const Prefetch *op) {
474 Stmt body = mutate(op->body);
475 Expr condition = mutate(op->condition, nullptr);
476
477 if (is_zero(op->condition)) {
478 // Predicate is always false
479 return body;
480 }
481
482 Region new_bounds;
483 bool bounds_changed;
484
485 // Mutate the bounds
486 std::tie(new_bounds, bounds_changed) = mutate_region(this, op->bounds, nullptr);
487
488 if (!bounds_changed &&
489 body.same_as(op->body) &&
490 condition.same_as(op->condition)) {
491 return op;
492 } else {
493 return Prefetch::make(op->name, op->types, new_bounds, op->prefetch, std::move(condition), std::move(body));
494 }
495 }
496
visit(const Free * op)497 Stmt Simplify::visit(const Free *op) {
498 return op;
499 }
500
visit(const Acquire * op)501 Stmt Simplify::visit(const Acquire *op) {
502 Expr sema = mutate(op->semaphore, nullptr);
503 Expr count = mutate(op->count, nullptr);
504 Stmt body = mutate(op->body);
505 if (sema.same_as(op->semaphore) &&
506 body.same_as(op->body) &&
507 count.same_as(op->count)) {
508 return op;
509 } else {
510 return Acquire::make(std::move(sema), std::move(count), std::move(body));
511 }
512 }
513
visit(const Fork * op)514 Stmt Simplify::visit(const Fork *op) {
515 Stmt first = mutate(op->first);
516 Stmt rest = mutate(op->rest);
517 if (is_no_op(first)) {
518 return rest;
519 } else if (is_no_op(rest)) {
520 return first;
521 } else if (op->first.same_as(first) &&
522 op->rest.same_as(rest)) {
523 return op;
524 } else {
525 return Fork::make(first, rest);
526 }
527 }
528
visit(const Atomic * op)529 Stmt Simplify::visit(const Atomic *op) {
530 Stmt body = mutate(op->body);
531 if (is_no_op(body)) {
532 return Evaluate::make(0);
533 } else if (body.same_as(op->body)) {
534 return op;
535 } else {
536 return Atomic::make(op->producer_name,
537 op->mutex_name,
538 std::move(body));
539 }
540 }
541
542 } // namespace Internal
543 } // namespace Halide
544