#include "AsyncProducers.h"
#include "ExprUsesVar.h"
#include "Function.h"
#include "IREquality.h"
#include "IRMutator.h"
#include "IROperator.h"

#include <utility>

namespace Halide {
namespace Internal {

using std::map;
using std::pair;
using std::set;
using std::string;
using std::vector;
16
17 /** A mutator which eagerly folds no-op stmts */
18 class NoOpCollapsingMutator : public IRMutator {
19 protected:
20 using IRMutator::visit;
21
visit(const LetStmt * op)22 Stmt visit(const LetStmt *op) override {
23 Stmt body = mutate(op->body);
24 if (is_no_op(body)) {
25 return body;
26 } else {
27 return LetStmt::make(op->name, op->value, body);
28 }
29 }
30
visit(const For * op)31 Stmt visit(const For *op) override {
32 Stmt body = mutate(op->body);
33 if (is_no_op(body)) {
34 return body;
35 } else {
36 return For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body);
37 }
38 }
39
visit(const Block * op)40 Stmt visit(const Block *op) override {
41 Stmt first = mutate(op->first);
42 Stmt rest = mutate(op->rest);
43 if (is_no_op(first)) {
44 return rest;
45 } else if (is_no_op(rest)) {
46 return first;
47 } else {
48 return Block::make(first, rest);
49 }
50 }
51
visit(const Fork * op)52 Stmt visit(const Fork *op) override {
53 Stmt first = mutate(op->first);
54 Stmt rest = mutate(op->rest);
55 if (is_no_op(first)) {
56 return rest;
57 } else if (is_no_op(rest)) {
58 return first;
59 } else {
60 return Fork::make(first, rest);
61 }
62 }
63
visit(const Realize * op)64 Stmt visit(const Realize *op) override {
65 Stmt body = mutate(op->body);
66 if (is_no_op(body)) {
67 return body;
68 } else {
69 return Realize::make(op->name, op->types, op->memory_type,
70 op->bounds, op->condition, body);
71 }
72 }
73
visit(const Allocate * op)74 Stmt visit(const Allocate *op) override {
75 Stmt body = mutate(op->body);
76 if (is_no_op(body)) {
77 return body;
78 } else {
79 return Allocate::make(op->name, op->type, op->memory_type,
80 op->extents, op->condition, body,
81 op->new_expr, op->free_function);
82 }
83 }
84
visit(const IfThenElse * op)85 Stmt visit(const IfThenElse *op) override {
86 Stmt then_case = mutate(op->then_case);
87 Stmt else_case = mutate(op->else_case);
88 if (is_no_op(then_case) && is_no_op(else_case)) {
89 return then_case;
90 } else {
91 return IfThenElse::make(op->condition, then_case, else_case);
92 }
93 }
94
visit(const Atomic * op)95 Stmt visit(const Atomic *op) override {
96 Stmt body = mutate(op->body);
97 if (is_no_op(body)) {
98 return body;
99 } else {
100 return Atomic::make(op->producer_name,
101 op->mutex_name,
102 std::move(body));
103 }
104 }
105 };
106
107 class GenerateProducerBody : public NoOpCollapsingMutator {
108 const string &func;
109 vector<Expr> sema;
110
111 using NoOpCollapsingMutator::visit;
112
113 // Preserve produce nodes and add synchronization
visit(const ProducerConsumer * op)114 Stmt visit(const ProducerConsumer *op) override {
115 if (op->name == func && op->is_producer) {
116 // Add post-synchronization
117 internal_assert(!sema.empty()) << "Duplicate produce node: " << op->name << "\n";
118 Stmt body = op->body;
119 while (!sema.empty()) {
120 Expr release = Call::make(Int(32), "halide_semaphore_release", {sema.back(), 1}, Call::Extern);
121 body = Block::make(body, Evaluate::make(release));
122 sema.pop_back();
123 }
124 return ProducerConsumer::make_produce(op->name, body);
125 } else {
126 Stmt body = mutate(op->body);
127 if (is_no_op(body) || op->is_producer) {
128 return body;
129 } else {
130 return ProducerConsumer::make(op->name, op->is_producer, body);
131 }
132 }
133 }
134
135 // Other stmt leaves get replaced with no-ops
visit(const Evaluate *)136 Stmt visit(const Evaluate *) override {
137 return Evaluate::make(0);
138 }
139
visit(const Provide *)140 Stmt visit(const Provide *) override {
141 return Evaluate::make(0);
142 }
143
visit(const Store * op)144 Stmt visit(const Store *op) override {
145 if (starts_with(op->name, func + ".folding_semaphore.") && ends_with(op->name, ".head")) {
146 // This is a counter associated with the producer side of a storage-folding semaphore. Keep it.
147 return op;
148 } else {
149 return Evaluate::make(0);
150 }
151 }
152
visit(const AssertStmt *)153 Stmt visit(const AssertStmt *) override {
154 return Evaluate::make(0);
155 }
156
visit(const Prefetch *)157 Stmt visit(const Prefetch *) override {
158 return Evaluate::make(0);
159 }
160
visit(const Acquire * op)161 Stmt visit(const Acquire *op) override {
162 Stmt body = mutate(op->body);
163 const Variable *var = op->semaphore.as<Variable>();
164 internal_assert(var);
165 if (is_no_op(body)) {
166 return body;
167 } else if (starts_with(var->name, func + ".folding_semaphore.")) {
168 // This is a storage-folding semaphore for the func we're producing. Keep it.
169 return Acquire::make(op->semaphore, op->count, body);
170 } else {
171 // This semaphore will end up on both sides of the fork,
172 // so we'd better duplicate it.
173 string cloned_acquire = var->name + unique_name('_');
174 cloned_acquires[var->name] = cloned_acquire;
175 return Acquire::make(Variable::make(type_of<halide_semaphore_t *>(), cloned_acquire), op->count, body);
176 }
177 }
178
visit(const Atomic * op)179 Stmt visit(const Atomic *op) override {
180 return Evaluate::make(0);
181 }
182
visit(const Call * op)183 Expr visit(const Call *op) override {
184 if (op->name == "halide_semaphore_init") {
185 internal_assert(op->args.size() == 2);
186 const Variable *var = op->args[0].as<Variable>();
187 internal_assert(var);
188 inner_semaphores.insert(var->name);
189 }
190 return op;
191 }
192
193 map<string, string> &cloned_acquires;
194 set<string> inner_semaphores;
195
196 public:
GenerateProducerBody(const string & f,const vector<Expr> & s,map<string,string> & a)197 GenerateProducerBody(const string &f, const vector<Expr> &s, map<string, string> &a)
198 : func(f), sema(s), cloned_acquires(a) {
199 }
200 };
201
202 class GenerateConsumerBody : public NoOpCollapsingMutator {
203 const string &func;
204 vector<Expr> sema;
205
206 using NoOpCollapsingMutator::visit;
207
visit(const ProducerConsumer * op)208 Stmt visit(const ProducerConsumer *op) override {
209 if (op->name == func) {
210 if (op->is_producer) {
211 // Remove the work entirely
212 return Evaluate::make(0);
213 } else {
214 // Synchronize on the work done by the producer before beginning consumption
215 Expr acquire_sema = sema.back();
216 sema.pop_back();
217 return Acquire::make(acquire_sema, 1, op);
218 }
219 } else {
220 return NoOpCollapsingMutator::visit(op);
221 }
222 }
223
visit(const Allocate * op)224 Stmt visit(const Allocate *op) override {
225 // Don't want to keep the producer's storage-folding tracker - it's dead code on the consumer side
226 if (starts_with(op->name, func + ".folding_semaphore.") && ends_with(op->name, ".head")) {
227 return mutate(op->body);
228 } else {
229 return NoOpCollapsingMutator::visit(op);
230 }
231 }
232
visit(const Store * op)233 Stmt visit(const Store *op) override {
234 if (starts_with(op->name, func + ".folding_semaphore.") && ends_with(op->name, ".head")) {
235 return Evaluate::make(0);
236 } else {
237 return NoOpCollapsingMutator::visit(op);
238 }
239 }
240
visit(const Acquire * op)241 Stmt visit(const Acquire *op) override {
242 // Don't want to duplicate any semaphore acquires.
243 // Ones from folding should go to the producer side.
244 const Variable *var = op->semaphore.as<Variable>();
245 internal_assert(var);
246 if (starts_with(var->name, func + ".folding_semaphore.")) {
247 return mutate(op->body);
248 } else {
249 return NoOpCollapsingMutator::visit(op);
250 }
251 }
252
253 public:
GenerateConsumerBody(const string & f,const vector<Expr> & s)254 GenerateConsumerBody(const string &f, const vector<Expr> &s)
255 : func(f), sema(s) {
256 }
257 };
258
259 class CloneAcquire : public IRMutator {
260 using IRMutator::visit;
261
262 const string &old_name;
263 Expr new_var;
264
visit(const Evaluate * op)265 Stmt visit(const Evaluate *op) override {
266 const Call *call = op->value.as<Call>();
267 const Variable *var = ((call && !call->args.empty()) ? call->args[0].as<Variable>() : nullptr);
268 if (var && var->name == old_name &&
269 (call->name == "halide_semaphore_release" ||
270 call->name == "halide_semaphore_init")) {
271 vector<Expr> args = call->args;
272 args[0] = new_var;
273 Stmt new_stmt =
274 Evaluate::make(Call::make(call->type, call->name, args, call->call_type));
275 return Block::make(op, new_stmt);
276 } else {
277 return op;
278 }
279 }
280
281 public:
CloneAcquire(const string & o,const string & new_name)282 CloneAcquire(const string &o, const string &new_name)
283 : old_name(o) {
284 new_var = Variable::make(type_of<halide_semaphore_t *>(), new_name);
285 }
286 };
287
288 class CountConsumeNodes : public IRVisitor {
289 const string &func;
290
291 using IRVisitor::visit;
292
visit(const ProducerConsumer * op)293 void visit(const ProducerConsumer *op) override {
294 if (op->name == func && !op->is_producer) {
295 count++;
296 }
297 IRVisitor::visit(op);
298 }
299
300 public:
CountConsumeNodes(const string & f)301 CountConsumeNodes(const string &f)
302 : func(f) {
303 }
304 int count = 0;
305 };
306
307 class ForkAsyncProducers : public IRMutator {
308 using IRMutator::visit;
309
310 const map<string, Function> &env;
311
312 map<string, string> cloned_acquires;
313
visit(const Realize * op)314 Stmt visit(const Realize *op) override {
315 auto it = env.find(op->name);
316 internal_assert(it != env.end());
317 Function f = it->second;
318 if (f.schedule().async()) {
319 Stmt body = op->body;
320
321 // Make two copies of the body, one which only does the
322 // producer, and one which only does the consumer. Inject
323 // synchronization to preserve dependencies. Put them in a
324 // task-parallel block.
325
326 // Make a semaphore per consume node
327 CountConsumeNodes consumes(op->name);
328 body.accept(&consumes);
329
330 vector<string> sema_names;
331 vector<Expr> sema_vars;
332 for (int i = 0; i < consumes.count; i++) {
333 sema_names.push_back(op->name + ".semaphore_" + std::to_string(i));
334 sema_vars.push_back(Variable::make(Handle(), sema_names.back()));
335 }
336
337 Stmt producer = GenerateProducerBody(op->name, sema_vars, cloned_acquires).mutate(body);
338 Stmt consumer = GenerateConsumerBody(op->name, sema_vars).mutate(body);
339
340 // Recurse on both sides
341 producer = mutate(producer);
342 consumer = mutate(consumer);
343
344 // Run them concurrently
345 body = Fork::make(producer, consumer);
346
347 for (const string &sema_name : sema_names) {
348 // Make a semaphore on the stack
349 Expr sema_space = Call::make(type_of<halide_semaphore_t *>(), "halide_make_semaphore",
350 {0}, Call::Extern);
351
352 // If there's a nested async producer, we may have
353 // recursively cloned this semaphore inside the mutation
354 // of the producer and consumer.
355 auto it = cloned_acquires.find(sema_name);
356 if (it != cloned_acquires.end()) {
357 body = CloneAcquire(sema_name, it->second).mutate(body);
358 body = LetStmt::make(it->second, sema_space, body);
359 }
360
361 body = LetStmt::make(sema_name, sema_space, body);
362 }
363
364 return Realize::make(op->name, op->types, op->memory_type,
365 op->bounds, op->condition, body);
366 } else {
367 return IRMutator::visit(op);
368 }
369 }
370
371 public:
ForkAsyncProducers(const map<string,Function> & e)372 ForkAsyncProducers(const map<string, Function> &e)
373 : env(e) {
374 }
375 };
376
377 // Lowers semaphore initialization from a call to
378 // "halide_make_semaphore" to an alloca followed by a call into the
379 // runtime to initialize. If something crashes before releasing a
380 // semaphore, the task system is responsible for propagating the
381 // failure to all branches of the fork. This depends on all semaphore
382 // acquires happening as part of the halide_do_parallel_tasks logic,
383 // not via explicit code in the closure. The current design for this
384 // does not propagate failures downward to subtasks of a failed
385 // fork. It assumes these will be able to reach completion in spite of
386 // the failure, which remains to be proven. (There is a test for the
387 // simple failure case, error_async_require_fail. One has not been
388 // written for the complex nested case yet.)
389 class InitializeSemaphores : public IRMutator {
390 using IRMutator::visit;
391
392 const Type sema_type = type_of<halide_semaphore_t *>();
393
visit(const LetStmt * op)394 Stmt visit(const LetStmt *op) override {
395 vector<const LetStmt *> frames;
396
397 // Find first op that is of sema_type
398 while (op && op->value.type() != sema_type) {
399 frames.push_back(op);
400 op = op->body.as<LetStmt>();
401 }
402
403 Stmt body;
404 if (op) {
405 body = mutate(op->body);
406 // Peel off any enclosing let expressions from the value
407 vector<pair<string, Expr>> lets;
408 Expr value = op->value;
409 while (const Let *l = value.as<Let>()) {
410 lets.emplace_back(l->name, l->value);
411 value = l->body;
412 }
413 const Call *call = value.as<Call>();
414 if (call && call->name == "halide_make_semaphore") {
415 internal_assert(call->args.size() == 1);
416
417 Expr sema_var = Variable::make(sema_type, op->name);
418 Expr sema_init = Call::make(Int(32), "halide_semaphore_init",
419 {sema_var, call->args[0]}, Call::Extern);
420 Expr sema_allocate = Call::make(sema_type, Call::alloca,
421 {(int)sizeof(halide_semaphore_t)}, Call::Intrinsic);
422 body = Block::make(Evaluate::make(sema_init), std::move(body));
423 body = LetStmt::make(op->name, std::move(sema_allocate), std::move(body));
424
425 // Re-wrap any other lets
426 for (auto it = lets.rbegin(); it != lets.rend(); it++) {
427 body = LetStmt::make(it->first, it->second, std::move(body));
428 }
429 }
430 } else {
431 body = mutate(frames.back()->body);
432 }
433
434 for (auto it = frames.rbegin(); it != frames.rend(); it++) {
435 Expr value = mutate((*it)->value);
436 if (value.same_as((*it)->value) && body.same_as((*it)->body)) {
437 body = *it;
438 } else {
439 body = LetStmt::make((*it)->name, std::move(value), std::move(body));
440 }
441 }
442 return body;
443 }
444
visit(const Call * op)445 Expr visit(const Call *op) override {
446 internal_assert(op->name != "halide_make_semaphore")
447 << "Call to halide_make_semaphore in unexpected place\n";
448 return op;
449 }
450 };
451
452 // Tighten the scope of consume nodes as much as possible to avoid needless synchronization.
453 class TightenProducerConsumerNodes : public IRMutator {
454 using IRMutator::visit;
455
make_producer_consumer(const string & name,bool is_producer,Stmt body,const Scope<int> & scope)456 Stmt make_producer_consumer(const string &name, bool is_producer, Stmt body, const Scope<int> &scope) {
457 if (const LetStmt *let = body.as<LetStmt>()) {
458 if (expr_uses_vars(let->value, scope)) {
459 return ProducerConsumer::make(name, is_producer, body);
460 } else {
461 return LetStmt::make(let->name, let->value, make_producer_consumer(name, is_producer, let->body, scope));
462 }
463 } else if (const Block *block = body.as<Block>()) {
464 // Check which sides it's used on
465 bool first = stmt_uses_vars(block->first, scope);
466 bool rest = stmt_uses_vars(block->rest, scope);
467 if (is_producer) {
468 return ProducerConsumer::make(name, is_producer, body);
469 } else if (first && rest) {
470 return Block::make(make_producer_consumer(name, is_producer, block->first, scope),
471 make_producer_consumer(name, is_producer, block->rest, scope));
472 } else if (first) {
473 return Block::make(make_producer_consumer(name, is_producer, block->first, scope), block->rest);
474 } else if (rest) {
475 return Block::make(block->first, make_producer_consumer(name, is_producer, block->rest, scope));
476 } else {
477 // Used on neither side?!
478 return body;
479 }
480 } else if (const ProducerConsumer *pc = body.as<ProducerConsumer>()) {
481 return ProducerConsumer::make(pc->name, pc->is_producer, make_producer_consumer(name, is_producer, pc->body, scope));
482 } else if (const Realize *r = body.as<Realize>()) {
483 return Realize::make(r->name, r->types, r->memory_type,
484 r->bounds, r->condition,
485 make_producer_consumer(name, is_producer, r->body, scope));
486 } else {
487 return ProducerConsumer::make(name, is_producer, body);
488 }
489 }
490
visit(const ProducerConsumer * op)491 Stmt visit(const ProducerConsumer *op) override {
492 Stmt body = mutate(op->body);
493 Scope<int> scope;
494 scope.push(op->name, 0);
495 Function f = env.find(op->name)->second;
496 if (f.outputs() == 1) {
497 scope.push(op->name + ".buffer", 0);
498 } else {
499 for (int i = 0; i < f.outputs(); i++) {
500 scope.push(op->name + "." + std::to_string(i) + ".buffer", 0);
501 }
502 }
503 return make_producer_consumer(op->name, op->is_producer, body, scope);
504 }
505
506 const map<string, Function> &env;
507
508 public:
TightenProducerConsumerNodes(const map<string,Function> & e)509 TightenProducerConsumerNodes(const map<string, Function> &e)
510 : env(e) {
511 }
512 };
513
514 // Broaden the scope of acquire nodes to pack trailing work into the
515 // same task and to potentially reduce the nesting depth of tasks.
516 class ExpandAcquireNodes : public IRMutator {
517 using IRMutator::visit;
518
visit(const Block * op)519 Stmt visit(const Block *op) override {
520 // Do an entire sequence of blocks in a single visit method to conserve stack space.
521 vector<Stmt> stmts;
522 Stmt result;
523 do {
524 stmts.push_back(mutate(op->first));
525 result = op->rest;
526 } while ((op = result.as<Block>()));
527
528 result = mutate(result);
529
530 vector<pair<Expr, Expr>> semaphores;
531 for (auto it = stmts.rbegin(); it != stmts.rend(); it++) {
532 Stmt s = *it;
533 while (const Acquire *a = s.as<Acquire>()) {
534 semaphores.emplace_back(a->semaphore, a->count);
535 s = a->body;
536 }
537 result = Block::make(s, result);
538 while (!semaphores.empty()) {
539 result = Acquire::make(semaphores.back().first, semaphores.back().second, result);
540 semaphores.pop_back();
541 }
542 }
543
544 return result;
545 }
546
visit(const Realize * op)547 Stmt visit(const Realize *op) override {
548 Stmt body = mutate(op->body);
549 if (const Acquire *a = body.as<Acquire>()) {
550 // Don't do the allocation until we have the
551 // semaphore. Reduces peak memory use.
552 return Acquire::make(a->semaphore, a->count,
553 mutate(Realize::make(op->name, op->types, op->memory_type,
554 op->bounds, op->condition, a->body)));
555 } else {
556 return Realize::make(op->name, op->types, op->memory_type,
557 op->bounds, op->condition, body);
558 }
559 }
560
visit(const LetStmt * op)561 Stmt visit(const LetStmt *op) override {
562 Stmt body = mutate(op->body);
563 const Acquire *a = body.as<Acquire>();
564 if (a &&
565 !expr_uses_var(a->semaphore, op->name) &&
566 !expr_uses_var(a->count, op->name)) {
567 return Acquire::make(a->semaphore, a->count,
568 LetStmt::make(op->name, op->value, a->body));
569 } else {
570 return LetStmt::make(op->name, op->value, body);
571 }
572 }
573
visit(const ProducerConsumer * op)574 Stmt visit(const ProducerConsumer *op) override {
575 Stmt body = mutate(op->body);
576 if (const Acquire *a = body.as<Acquire>()) {
577 return Acquire::make(a->semaphore, a->count,
578 mutate(ProducerConsumer::make(op->name, op->is_producer, a->body)));
579 } else {
580 return ProducerConsumer::make(op->name, op->is_producer, body);
581 }
582 }
583 };
584
585 class TightenForkNodes : public IRMutator {
586 using IRMutator::visit;
587
make_fork(const Stmt & first,const Stmt & rest)588 Stmt make_fork(const Stmt &first, const Stmt &rest) {
589 const LetStmt *lf = first.as<LetStmt>();
590 const LetStmt *lr = rest.as<LetStmt>();
591 const Realize *rf = first.as<Realize>();
592 const Realize *rr = rest.as<Realize>();
593 if (lf && lr &&
594 lf->name == lr->name &&
595 equal(lf->value, lr->value)) {
596 return LetStmt::make(lf->name, lf->value, make_fork(lf->body, lr->body));
597 } else if (lf && !stmt_uses_var(rest, lf->name)) {
598 return LetStmt::make(lf->name, lf->value, make_fork(lf->body, rest));
599 } else if (lr && !stmt_uses_var(first, lr->name)) {
600 return LetStmt::make(lr->name, lr->value, make_fork(first, lr->body));
601 } else if (rf && !stmt_uses_var(rest, rf->name)) {
602 return Realize::make(rf->name, rf->types, rf->memory_type,
603 rf->bounds, rf->condition, make_fork(rf->body, rest));
604 } else if (rr && !stmt_uses_var(first, rr->name)) {
605 return Realize::make(rr->name, rr->types, rr->memory_type,
606 rr->bounds, rr->condition, make_fork(first, rr->body));
607 } else {
608 return Fork::make(first, rest);
609 }
610 }
611
visit(const Fork * op)612 Stmt visit(const Fork *op) override {
613 Stmt first, rest;
614 {
615 ScopedValue<bool> old_in_fork(in_fork, true);
616 first = mutate(op->first);
617 rest = mutate(op->rest);
618 }
619
620 if (is_no_op(first)) {
621 return rest;
622 } else if (is_no_op(rest)) {
623 return first;
624 } else {
625 return make_fork(first, rest);
626 }
627 }
628
629 // This is also a good time to nuke any dangling allocations and lets in the fork children.
visit(const Realize * op)630 Stmt visit(const Realize *op) override {
631 Stmt body = mutate(op->body);
632 if (in_fork && !stmt_uses_var(body, op->name) && !stmt_uses_var(body, op->name + ".buffer")) {
633 return body;
634 } else {
635 return Realize::make(op->name, op->types, op->memory_type,
636 op->bounds, op->condition, body);
637 }
638 }
639
visit(const LetStmt * op)640 Stmt visit(const LetStmt *op) override {
641 Stmt body = mutate(op->body);
642 if (in_fork && !stmt_uses_var(body, op->name)) {
643 return body;
644 } else {
645 return LetStmt::make(op->name, op->value, body);
646 }
647 }
648
649 bool in_fork = false;
650 };
651
652 // TODO: merge semaphores?
653
fork_async_producers(Stmt s,const map<string,Function> & env)654 Stmt fork_async_producers(Stmt s, const map<string, Function> &env) {
655 s = TightenProducerConsumerNodes(env).mutate(s);
656 s = ForkAsyncProducers(env).mutate(s);
657 s = ExpandAcquireNodes().mutate(s);
658 s = TightenForkNodes().mutate(s);
659 s = InitializeSemaphores().mutate(s);
660 return s;
661 }
662
663 } // namespace Internal
664 } // namespace Halide
665