1 #include "AsyncProducers.h"
2 #include "ExprUsesVar.h"
3 #include "Function.h"
4 #include "IREquality.h"
5 #include "IRMutator.h"
6 #include "IROperator.h"
7 
8 namespace Halide {
9 namespace Internal {
10 
11 using std::map;
12 using std::pair;
13 using std::set;
14 using std::string;
15 using std::vector;
16 
17 /** A mutator which eagerly folds no-op stmts */
18 class NoOpCollapsingMutator : public IRMutator {
19 protected:
20     using IRMutator::visit;
21 
visit(const LetStmt * op)22     Stmt visit(const LetStmt *op) override {
23         Stmt body = mutate(op->body);
24         if (is_no_op(body)) {
25             return body;
26         } else {
27             return LetStmt::make(op->name, op->value, body);
28         }
29     }
30 
visit(const For * op)31     Stmt visit(const For *op) override {
32         Stmt body = mutate(op->body);
33         if (is_no_op(body)) {
34             return body;
35         } else {
36             return For::make(op->name, op->min, op->extent, op->for_type, op->device_api, body);
37         }
38     }
39 
visit(const Block * op)40     Stmt visit(const Block *op) override {
41         Stmt first = mutate(op->first);
42         Stmt rest = mutate(op->rest);
43         if (is_no_op(first)) {
44             return rest;
45         } else if (is_no_op(rest)) {
46             return first;
47         } else {
48             return Block::make(first, rest);
49         }
50     }
51 
visit(const Fork * op)52     Stmt visit(const Fork *op) override {
53         Stmt first = mutate(op->first);
54         Stmt rest = mutate(op->rest);
55         if (is_no_op(first)) {
56             return rest;
57         } else if (is_no_op(rest)) {
58             return first;
59         } else {
60             return Fork::make(first, rest);
61         }
62     }
63 
visit(const Realize * op)64     Stmt visit(const Realize *op) override {
65         Stmt body = mutate(op->body);
66         if (is_no_op(body)) {
67             return body;
68         } else {
69             return Realize::make(op->name, op->types, op->memory_type,
70                                  op->bounds, op->condition, body);
71         }
72     }
73 
visit(const Allocate * op)74     Stmt visit(const Allocate *op) override {
75         Stmt body = mutate(op->body);
76         if (is_no_op(body)) {
77             return body;
78         } else {
79             return Allocate::make(op->name, op->type, op->memory_type,
80                                   op->extents, op->condition, body,
81                                   op->new_expr, op->free_function);
82         }
83     }
84 
visit(const IfThenElse * op)85     Stmt visit(const IfThenElse *op) override {
86         Stmt then_case = mutate(op->then_case);
87         Stmt else_case = mutate(op->else_case);
88         if (is_no_op(then_case) && is_no_op(else_case)) {
89             return then_case;
90         } else {
91             return IfThenElse::make(op->condition, then_case, else_case);
92         }
93     }
94 
visit(const Atomic * op)95     Stmt visit(const Atomic *op) override {
96         Stmt body = mutate(op->body);
97         if (is_no_op(body)) {
98             return body;
99         } else {
100             return Atomic::make(op->producer_name,
101                                 op->mutex_name,
102                                 std::move(body));
103         }
104     }
105 };
106 
107 class GenerateProducerBody : public NoOpCollapsingMutator {
108     const string &func;
109     vector<Expr> sema;
110 
111     using NoOpCollapsingMutator::visit;
112 
113     // Preserve produce nodes and add synchronization
visit(const ProducerConsumer * op)114     Stmt visit(const ProducerConsumer *op) override {
115         if (op->name == func && op->is_producer) {
116             // Add post-synchronization
117             internal_assert(!sema.empty()) << "Duplicate produce node: " << op->name << "\n";
118             Stmt body = op->body;
119             while (!sema.empty()) {
120                 Expr release = Call::make(Int(32), "halide_semaphore_release", {sema.back(), 1}, Call::Extern);
121                 body = Block::make(body, Evaluate::make(release));
122                 sema.pop_back();
123             }
124             return ProducerConsumer::make_produce(op->name, body);
125         } else {
126             Stmt body = mutate(op->body);
127             if (is_no_op(body) || op->is_producer) {
128                 return body;
129             } else {
130                 return ProducerConsumer::make(op->name, op->is_producer, body);
131             }
132         }
133     }
134 
135     // Other stmt leaves get replaced with no-ops
visit(const Evaluate *)136     Stmt visit(const Evaluate *) override {
137         return Evaluate::make(0);
138     }
139 
visit(const Provide *)140     Stmt visit(const Provide *) override {
141         return Evaluate::make(0);
142     }
143 
visit(const Store * op)144     Stmt visit(const Store *op) override {
145         if (starts_with(op->name, func + ".folding_semaphore.") && ends_with(op->name, ".head")) {
146             // This is a counter associated with the producer side of a storage-folding semaphore. Keep it.
147             return op;
148         } else {
149             return Evaluate::make(0);
150         }
151     }
152 
visit(const AssertStmt *)153     Stmt visit(const AssertStmt *) override {
154         return Evaluate::make(0);
155     }
156 
visit(const Prefetch *)157     Stmt visit(const Prefetch *) override {
158         return Evaluate::make(0);
159     }
160 
visit(const Acquire * op)161     Stmt visit(const Acquire *op) override {
162         Stmt body = mutate(op->body);
163         const Variable *var = op->semaphore.as<Variable>();
164         internal_assert(var);
165         if (is_no_op(body)) {
166             return body;
167         } else if (starts_with(var->name, func + ".folding_semaphore.")) {
168             // This is a storage-folding semaphore for the func we're producing. Keep it.
169             return Acquire::make(op->semaphore, op->count, body);
170         } else {
171             // This semaphore will end up on both sides of the fork,
172             // so we'd better duplicate it.
173             string cloned_acquire = var->name + unique_name('_');
174             cloned_acquires[var->name] = cloned_acquire;
175             return Acquire::make(Variable::make(type_of<halide_semaphore_t *>(), cloned_acquire), op->count, body);
176         }
177     }
178 
visit(const Atomic * op)179     Stmt visit(const Atomic *op) override {
180         return Evaluate::make(0);
181     }
182 
visit(const Call * op)183     Expr visit(const Call *op) override {
184         if (op->name == "halide_semaphore_init") {
185             internal_assert(op->args.size() == 2);
186             const Variable *var = op->args[0].as<Variable>();
187             internal_assert(var);
188             inner_semaphores.insert(var->name);
189         }
190         return op;
191     }
192 
193     map<string, string> &cloned_acquires;
194     set<string> inner_semaphores;
195 
196 public:
GenerateProducerBody(const string & f,const vector<Expr> & s,map<string,string> & a)197     GenerateProducerBody(const string &f, const vector<Expr> &s, map<string, string> &a)
198         : func(f), sema(s), cloned_acquires(a) {
199     }
200 };
201 
202 class GenerateConsumerBody : public NoOpCollapsingMutator {
203     const string &func;
204     vector<Expr> sema;
205 
206     using NoOpCollapsingMutator::visit;
207 
visit(const ProducerConsumer * op)208     Stmt visit(const ProducerConsumer *op) override {
209         if (op->name == func) {
210             if (op->is_producer) {
211                 // Remove the work entirely
212                 return Evaluate::make(0);
213             } else {
214                 // Synchronize on the work done by the producer before beginning consumption
215                 Expr acquire_sema = sema.back();
216                 sema.pop_back();
217                 return Acquire::make(acquire_sema, 1, op);
218             }
219         } else {
220             return NoOpCollapsingMutator::visit(op);
221         }
222     }
223 
visit(const Allocate * op)224     Stmt visit(const Allocate *op) override {
225         // Don't want to keep the producer's storage-folding tracker - it's dead code on the consumer side
226         if (starts_with(op->name, func + ".folding_semaphore.") && ends_with(op->name, ".head")) {
227             return mutate(op->body);
228         } else {
229             return NoOpCollapsingMutator::visit(op);
230         }
231     }
232 
visit(const Store * op)233     Stmt visit(const Store *op) override {
234         if (starts_with(op->name, func + ".folding_semaphore.") && ends_with(op->name, ".head")) {
235             return Evaluate::make(0);
236         } else {
237             return NoOpCollapsingMutator::visit(op);
238         }
239     }
240 
visit(const Acquire * op)241     Stmt visit(const Acquire *op) override {
242         // Don't want to duplicate any semaphore acquires.
243         // Ones from folding should go to the producer side.
244         const Variable *var = op->semaphore.as<Variable>();
245         internal_assert(var);
246         if (starts_with(var->name, func + ".folding_semaphore.")) {
247             return mutate(op->body);
248         } else {
249             return NoOpCollapsingMutator::visit(op);
250         }
251     }
252 
253 public:
GenerateConsumerBody(const string & f,const vector<Expr> & s)254     GenerateConsumerBody(const string &f, const vector<Expr> &s)
255         : func(f), sema(s) {
256     }
257 };
258 
259 class CloneAcquire : public IRMutator {
260     using IRMutator::visit;
261 
262     const string &old_name;
263     Expr new_var;
264 
visit(const Evaluate * op)265     Stmt visit(const Evaluate *op) override {
266         const Call *call = op->value.as<Call>();
267         const Variable *var = ((call && !call->args.empty()) ? call->args[0].as<Variable>() : nullptr);
268         if (var && var->name == old_name &&
269             (call->name == "halide_semaphore_release" ||
270              call->name == "halide_semaphore_init")) {
271             vector<Expr> args = call->args;
272             args[0] = new_var;
273             Stmt new_stmt =
274                 Evaluate::make(Call::make(call->type, call->name, args, call->call_type));
275             return Block::make(op, new_stmt);
276         } else {
277             return op;
278         }
279     }
280 
281 public:
CloneAcquire(const string & o,const string & new_name)282     CloneAcquire(const string &o, const string &new_name)
283         : old_name(o) {
284         new_var = Variable::make(type_of<halide_semaphore_t *>(), new_name);
285     }
286 };
287 
288 class CountConsumeNodes : public IRVisitor {
289     const string &func;
290 
291     using IRVisitor::visit;
292 
visit(const ProducerConsumer * op)293     void visit(const ProducerConsumer *op) override {
294         if (op->name == func && !op->is_producer) {
295             count++;
296         }
297         IRVisitor::visit(op);
298     }
299 
300 public:
CountConsumeNodes(const string & f)301     CountConsumeNodes(const string &f)
302         : func(f) {
303     }
304     int count = 0;
305 };
306 
307 class ForkAsyncProducers : public IRMutator {
308     using IRMutator::visit;
309 
310     const map<string, Function> &env;
311 
312     map<string, string> cloned_acquires;
313 
visit(const Realize * op)314     Stmt visit(const Realize *op) override {
315         auto it = env.find(op->name);
316         internal_assert(it != env.end());
317         Function f = it->second;
318         if (f.schedule().async()) {
319             Stmt body = op->body;
320 
321             // Make two copies of the body, one which only does the
322             // producer, and one which only does the consumer. Inject
323             // synchronization to preserve dependencies. Put them in a
324             // task-parallel block.
325 
326             // Make a semaphore per consume node
327             CountConsumeNodes consumes(op->name);
328             body.accept(&consumes);
329 
330             vector<string> sema_names;
331             vector<Expr> sema_vars;
332             for (int i = 0; i < consumes.count; i++) {
333                 sema_names.push_back(op->name + ".semaphore_" + std::to_string(i));
334                 sema_vars.push_back(Variable::make(Handle(), sema_names.back()));
335             }
336 
337             Stmt producer = GenerateProducerBody(op->name, sema_vars, cloned_acquires).mutate(body);
338             Stmt consumer = GenerateConsumerBody(op->name, sema_vars).mutate(body);
339 
340             // Recurse on both sides
341             producer = mutate(producer);
342             consumer = mutate(consumer);
343 
344             // Run them concurrently
345             body = Fork::make(producer, consumer);
346 
347             for (const string &sema_name : sema_names) {
348                 // Make a semaphore on the stack
349                 Expr sema_space = Call::make(type_of<halide_semaphore_t *>(), "halide_make_semaphore",
350                                              {0}, Call::Extern);
351 
352                 // If there's a nested async producer, we may have
353                 // recursively cloned this semaphore inside the mutation
354                 // of the producer and consumer.
355                 auto it = cloned_acquires.find(sema_name);
356                 if (it != cloned_acquires.end()) {
357                     body = CloneAcquire(sema_name, it->second).mutate(body);
358                     body = LetStmt::make(it->second, sema_space, body);
359                 }
360 
361                 body = LetStmt::make(sema_name, sema_space, body);
362             }
363 
364             return Realize::make(op->name, op->types, op->memory_type,
365                                  op->bounds, op->condition, body);
366         } else {
367             return IRMutator::visit(op);
368         }
369     }
370 
371 public:
ForkAsyncProducers(const map<string,Function> & e)372     ForkAsyncProducers(const map<string, Function> &e)
373         : env(e) {
374     }
375 };
376 
377 // Lowers semaphore initialization from a call to
378 // "halide_make_semaphore" to an alloca followed by a call into the
379 // runtime to initialize. If something crashes before releasing a
380 // semaphore, the task system is responsible for propagating the
381 // failure to all branches of the fork. This depends on all semaphore
382 // acquires happening as part of the halide_do_parallel_tasks logic,
383 // not via explicit code in the closure.  The current design for this
384 // does not propagate failures downward to subtasks of a failed
385 // fork. It assumes these will be able to reach completion in spite of
386 // the failure, which remains to be proven. (There is a test for the
387 // simple failure case, error_async_require_fail. One has not been
388 // written for the complex nested case yet.)
389 class InitializeSemaphores : public IRMutator {
390     using IRMutator::visit;
391 
392     const Type sema_type = type_of<halide_semaphore_t *>();
393 
visit(const LetStmt * op)394     Stmt visit(const LetStmt *op) override {
395         vector<const LetStmt *> frames;
396 
397         // Find first op that is of sema_type
398         while (op && op->value.type() != sema_type) {
399             frames.push_back(op);
400             op = op->body.as<LetStmt>();
401         }
402 
403         Stmt body;
404         if (op) {
405             body = mutate(op->body);
406             // Peel off any enclosing let expressions from the value
407             vector<pair<string, Expr>> lets;
408             Expr value = op->value;
409             while (const Let *l = value.as<Let>()) {
410                 lets.emplace_back(l->name, l->value);
411                 value = l->body;
412             }
413             const Call *call = value.as<Call>();
414             if (call && call->name == "halide_make_semaphore") {
415                 internal_assert(call->args.size() == 1);
416 
417                 Expr sema_var = Variable::make(sema_type, op->name);
418                 Expr sema_init = Call::make(Int(32), "halide_semaphore_init",
419                                             {sema_var, call->args[0]}, Call::Extern);
420                 Expr sema_allocate = Call::make(sema_type, Call::alloca,
421                                                 {(int)sizeof(halide_semaphore_t)}, Call::Intrinsic);
422                 body = Block::make(Evaluate::make(sema_init), std::move(body));
423                 body = LetStmt::make(op->name, std::move(sema_allocate), std::move(body));
424 
425                 // Re-wrap any other lets
426                 for (auto it = lets.rbegin(); it != lets.rend(); it++) {
427                     body = LetStmt::make(it->first, it->second, std::move(body));
428                 }
429             }
430         } else {
431             body = mutate(frames.back()->body);
432         }
433 
434         for (auto it = frames.rbegin(); it != frames.rend(); it++) {
435             Expr value = mutate((*it)->value);
436             if (value.same_as((*it)->value) && body.same_as((*it)->body)) {
437                 body = *it;
438             } else {
439                 body = LetStmt::make((*it)->name, std::move(value), std::move(body));
440             }
441         }
442         return body;
443     }
444 
visit(const Call * op)445     Expr visit(const Call *op) override {
446         internal_assert(op->name != "halide_make_semaphore")
447             << "Call to halide_make_semaphore in unexpected place\n";
448         return op;
449     }
450 };
451 
452 // Tighten the scope of consume nodes as much as possible to avoid needless synchronization.
453 class TightenProducerConsumerNodes : public IRMutator {
454     using IRMutator::visit;
455 
make_producer_consumer(const string & name,bool is_producer,Stmt body,const Scope<int> & scope)456     Stmt make_producer_consumer(const string &name, bool is_producer, Stmt body, const Scope<int> &scope) {
457         if (const LetStmt *let = body.as<LetStmt>()) {
458             if (expr_uses_vars(let->value, scope)) {
459                 return ProducerConsumer::make(name, is_producer, body);
460             } else {
461                 return LetStmt::make(let->name, let->value, make_producer_consumer(name, is_producer, let->body, scope));
462             }
463         } else if (const Block *block = body.as<Block>()) {
464             // Check which sides it's used on
465             bool first = stmt_uses_vars(block->first, scope);
466             bool rest = stmt_uses_vars(block->rest, scope);
467             if (is_producer) {
468                 return ProducerConsumer::make(name, is_producer, body);
469             } else if (first && rest) {
470                 return Block::make(make_producer_consumer(name, is_producer, block->first, scope),
471                                    make_producer_consumer(name, is_producer, block->rest, scope));
472             } else if (first) {
473                 return Block::make(make_producer_consumer(name, is_producer, block->first, scope), block->rest);
474             } else if (rest) {
475                 return Block::make(block->first, make_producer_consumer(name, is_producer, block->rest, scope));
476             } else {
477                 // Used on neither side?!
478                 return body;
479             }
480         } else if (const ProducerConsumer *pc = body.as<ProducerConsumer>()) {
481             return ProducerConsumer::make(pc->name, pc->is_producer, make_producer_consumer(name, is_producer, pc->body, scope));
482         } else if (const Realize *r = body.as<Realize>()) {
483             return Realize::make(r->name, r->types, r->memory_type,
484                                  r->bounds, r->condition,
485                                  make_producer_consumer(name, is_producer, r->body, scope));
486         } else {
487             return ProducerConsumer::make(name, is_producer, body);
488         }
489     }
490 
visit(const ProducerConsumer * op)491     Stmt visit(const ProducerConsumer *op) override {
492         Stmt body = mutate(op->body);
493         Scope<int> scope;
494         scope.push(op->name, 0);
495         Function f = env.find(op->name)->second;
496         if (f.outputs() == 1) {
497             scope.push(op->name + ".buffer", 0);
498         } else {
499             for (int i = 0; i < f.outputs(); i++) {
500                 scope.push(op->name + "." + std::to_string(i) + ".buffer", 0);
501             }
502         }
503         return make_producer_consumer(op->name, op->is_producer, body, scope);
504     }
505 
506     const map<string, Function> &env;
507 
508 public:
TightenProducerConsumerNodes(const map<string,Function> & e)509     TightenProducerConsumerNodes(const map<string, Function> &e)
510         : env(e) {
511     }
512 };
513 
514 // Broaden the scope of acquire nodes to pack trailing work into the
515 // same task and to potentially reduce the nesting depth of tasks.
516 class ExpandAcquireNodes : public IRMutator {
517     using IRMutator::visit;
518 
visit(const Block * op)519     Stmt visit(const Block *op) override {
520         // Do an entire sequence of blocks in a single visit method to conserve stack space.
521         vector<Stmt> stmts;
522         Stmt result;
523         do {
524             stmts.push_back(mutate(op->first));
525             result = op->rest;
526         } while ((op = result.as<Block>()));
527 
528         result = mutate(result);
529 
530         vector<pair<Expr, Expr>> semaphores;
531         for (auto it = stmts.rbegin(); it != stmts.rend(); it++) {
532             Stmt s = *it;
533             while (const Acquire *a = s.as<Acquire>()) {
534                 semaphores.emplace_back(a->semaphore, a->count);
535                 s = a->body;
536             }
537             result = Block::make(s, result);
538             while (!semaphores.empty()) {
539                 result = Acquire::make(semaphores.back().first, semaphores.back().second, result);
540                 semaphores.pop_back();
541             }
542         }
543 
544         return result;
545     }
546 
visit(const Realize * op)547     Stmt visit(const Realize *op) override {
548         Stmt body = mutate(op->body);
549         if (const Acquire *a = body.as<Acquire>()) {
550             // Don't do the allocation until we have the
551             // semaphore. Reduces peak memory use.
552             return Acquire::make(a->semaphore, a->count,
553                                  mutate(Realize::make(op->name, op->types, op->memory_type,
554                                                       op->bounds, op->condition, a->body)));
555         } else {
556             return Realize::make(op->name, op->types, op->memory_type,
557                                  op->bounds, op->condition, body);
558         }
559     }
560 
visit(const LetStmt * op)561     Stmt visit(const LetStmt *op) override {
562         Stmt body = mutate(op->body);
563         const Acquire *a = body.as<Acquire>();
564         if (a &&
565             !expr_uses_var(a->semaphore, op->name) &&
566             !expr_uses_var(a->count, op->name)) {
567             return Acquire::make(a->semaphore, a->count,
568                                  LetStmt::make(op->name, op->value, a->body));
569         } else {
570             return LetStmt::make(op->name, op->value, body);
571         }
572     }
573 
visit(const ProducerConsumer * op)574     Stmt visit(const ProducerConsumer *op) override {
575         Stmt body = mutate(op->body);
576         if (const Acquire *a = body.as<Acquire>()) {
577             return Acquire::make(a->semaphore, a->count,
578                                  mutate(ProducerConsumer::make(op->name, op->is_producer, a->body)));
579         } else {
580             return ProducerConsumer::make(op->name, op->is_producer, body);
581         }
582     }
583 };
584 
585 class TightenForkNodes : public IRMutator {
586     using IRMutator::visit;
587 
make_fork(const Stmt & first,const Stmt & rest)588     Stmt make_fork(const Stmt &first, const Stmt &rest) {
589         const LetStmt *lf = first.as<LetStmt>();
590         const LetStmt *lr = rest.as<LetStmt>();
591         const Realize *rf = first.as<Realize>();
592         const Realize *rr = rest.as<Realize>();
593         if (lf && lr &&
594             lf->name == lr->name &&
595             equal(lf->value, lr->value)) {
596             return LetStmt::make(lf->name, lf->value, make_fork(lf->body, lr->body));
597         } else if (lf && !stmt_uses_var(rest, lf->name)) {
598             return LetStmt::make(lf->name, lf->value, make_fork(lf->body, rest));
599         } else if (lr && !stmt_uses_var(first, lr->name)) {
600             return LetStmt::make(lr->name, lr->value, make_fork(first, lr->body));
601         } else if (rf && !stmt_uses_var(rest, rf->name)) {
602             return Realize::make(rf->name, rf->types, rf->memory_type,
603                                  rf->bounds, rf->condition, make_fork(rf->body, rest));
604         } else if (rr && !stmt_uses_var(first, rr->name)) {
605             return Realize::make(rr->name, rr->types, rr->memory_type,
606                                  rr->bounds, rr->condition, make_fork(first, rr->body));
607         } else {
608             return Fork::make(first, rest);
609         }
610     }
611 
visit(const Fork * op)612     Stmt visit(const Fork *op) override {
613         Stmt first, rest;
614         {
615             ScopedValue<bool> old_in_fork(in_fork, true);
616             first = mutate(op->first);
617             rest = mutate(op->rest);
618         }
619 
620         if (is_no_op(first)) {
621             return rest;
622         } else if (is_no_op(rest)) {
623             return first;
624         } else {
625             return make_fork(first, rest);
626         }
627     }
628 
629     // This is also a good time to nuke any dangling allocations and lets in the fork children.
visit(const Realize * op)630     Stmt visit(const Realize *op) override {
631         Stmt body = mutate(op->body);
632         if (in_fork && !stmt_uses_var(body, op->name) && !stmt_uses_var(body, op->name + ".buffer")) {
633             return body;
634         } else {
635             return Realize::make(op->name, op->types, op->memory_type,
636                                  op->bounds, op->condition, body);
637         }
638     }
639 
visit(const LetStmt * op)640     Stmt visit(const LetStmt *op) override {
641         Stmt body = mutate(op->body);
642         if (in_fork && !stmt_uses_var(body, op->name)) {
643             return body;
644         } else {
645             return LetStmt::make(op->name, op->value, body);
646         }
647     }
648 
649     bool in_fork = false;
650 };
651 
652 // TODO: merge semaphores?
653 
fork_async_producers(Stmt s,const map<string,Function> & env)654 Stmt fork_async_producers(Stmt s, const map<string, Function> &env) {
655     s = TightenProducerConsumerNodes(env).mutate(s);
656     s = ForkAsyncProducers(env).mutate(s);
657     s = ExpandAcquireNodes().mutate(s);
658     s = TightenForkNodes().mutate(s);
659     s = InitializeSemaphores().mutate(s);
660     return s;
661 }
662 
663 }  // namespace Internal
664 }  // namespace Halide
665