1 #include <algorithm>
2 #include <utility>
3 
4 #include "CSE.h"
5 #include "CodeGen_GPU_Dev.h"
6 #include "Deinterleave.h"
7 #include "ExprUsesVar.h"
8 #include "IREquality.h"
9 #include "IRMutator.h"
10 #include "IROperator.h"
11 #include "IRPrinter.h"
12 #include "Scope.h"
13 #include "Simplify.h"
14 #include "Solve.h"
15 #include "Substitute.h"
16 #include "VectorizeLoops.h"
17 
18 namespace Halide {
19 namespace Internal {
20 
21 using std::map;
22 using std::pair;
23 using std::string;
24 using std::vector;
25 
26 namespace {
27 
28 // For a given var, replace expressions like shuffle_vector(var, 4)
29 // with var.lane.4
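// For example (illustrative): extract_lane(t.x8, 4) shows up as a
// single-index shuffle of the widened variable t.x8, and is rewritten here
// to a reference to the scalar variable "t.x8.lane.4".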
30 class ReplaceShuffleVectors : public IRMutator {
31     string var;
32 
33     using IRMutator::visit;
34 
    Expr visit(const Shuffle *op) override {
36         const Variable *v;
37         if (op->indices.size() == 1 &&
38             (v = op->vectors[0].as<Variable>()) &&
39             v->name == var) {
40             return Variable::make(op->type, var + ".lane." + std::to_string(op->indices[0]));
41         } else {
42             return IRMutator::visit(op);
43         }
44     }
45 
46 public:
    ReplaceShuffleVectors(const string &v)
        : var(v) {
    }
50 };
51 
52 /** Find the exact max and min lanes of a vector expression. Not
53  * conservative like bounds_of_expr, but uses similar rules for some
54  * common node types where it can be exact. Assumes any vector
55  * variables defined externally also have .min_lane and .max_lane
56  * versions in scope. */
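// For example, bounds_of_lanes(ramp(x, 2, 8)) yields {x, x + 7 * 2}, and
// bounds_of_lanes(broadcast(y, 8)) yields {y, y}.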
Interval bounds_of_lanes(const Expr &e) {
58     if (const Add *add = e.as<Add>()) {
59         if (const Broadcast *b = add->b.as<Broadcast>()) {
60             Interval ia = bounds_of_lanes(add->a);
61             return {ia.min + b->value, ia.max + b->value};
62         } else if (const Broadcast *b = add->a.as<Broadcast>()) {
63             Interval ia = bounds_of_lanes(add->b);
64             return {b->value + ia.min, b->value + ia.max};
65         }
66     } else if (const Sub *sub = e.as<Sub>()) {
67         if (const Broadcast *b = sub->b.as<Broadcast>()) {
68             Interval ia = bounds_of_lanes(sub->a);
69             return {ia.min - b->value, ia.max - b->value};
70         } else if (const Broadcast *b = sub->a.as<Broadcast>()) {
71             Interval ia = bounds_of_lanes(sub->b);
            return {b->value - ia.max, b->value - ia.min};
73         }
74     } else if (const Mul *mul = e.as<Mul>()) {
75         if (const Broadcast *b = mul->b.as<Broadcast>()) {
76             if (is_positive_const(b->value)) {
77                 Interval ia = bounds_of_lanes(mul->a);
78                 return {ia.min * b->value, ia.max * b->value};
79             } else if (is_negative_const(b->value)) {
80                 Interval ia = bounds_of_lanes(mul->a);
81                 return {ia.max * b->value, ia.min * b->value};
82             }
83         } else if (const Broadcast *b = mul->a.as<Broadcast>()) {
84             if (is_positive_const(b->value)) {
85                 Interval ia = bounds_of_lanes(mul->b);
86                 return {b->value * ia.min, b->value * ia.max};
87             } else if (is_negative_const(b->value)) {
88                 Interval ia = bounds_of_lanes(mul->b);
89                 return {b->value * ia.max, b->value * ia.min};
90             }
91         }
92     } else if (const Div *div = e.as<Div>()) {
93         if (const Broadcast *b = div->b.as<Broadcast>()) {
94             if (is_positive_const(b->value)) {
95                 Interval ia = bounds_of_lanes(div->a);
96                 return {ia.min / b->value, ia.max / b->value};
97             } else if (is_negative_const(b->value)) {
98                 Interval ia = bounds_of_lanes(div->a);
99                 return {ia.max / b->value, ia.min / b->value};
100             }
101         }
102     } else if (const And *and_ = e.as<And>()) {
103         if (const Broadcast *b = and_->b.as<Broadcast>()) {
104             Interval ia = bounds_of_lanes(and_->a);
105             return {ia.min && b->value, ia.max && b->value};
106         } else if (const Broadcast *b = and_->a.as<Broadcast>()) {
107             Interval ia = bounds_of_lanes(and_->b);
108             return {ia.min && b->value, ia.max && b->value};
109         }
    } else if (const Or *or_ = e.as<Or>()) {
        if (const Broadcast *b = or_->b.as<Broadcast>()) {
            Interval ia = bounds_of_lanes(or_->a);
            return {ia.min || b->value, ia.max || b->value};
        } else if (const Broadcast *b = or_->a.as<Broadcast>()) {
            Interval ia = bounds_of_lanes(or_->b);
            return {ia.min || b->value, ia.max || b->value};
        }
118     } else if (const Min *min = e.as<Min>()) {
119         if (const Broadcast *b = min->b.as<Broadcast>()) {
120             Interval ia = bounds_of_lanes(min->a);
121             return {Min::make(ia.min, b->value), Min::make(ia.max, b->value)};
122         } else if (const Broadcast *b = min->a.as<Broadcast>()) {
123             Interval ia = bounds_of_lanes(min->b);
124             return {Min::make(ia.min, b->value), Min::make(ia.max, b->value)};
125         }
126     } else if (const Max *max = e.as<Max>()) {
127         if (const Broadcast *b = max->b.as<Broadcast>()) {
128             Interval ia = bounds_of_lanes(max->a);
129             return {Max::make(ia.min, b->value), Max::make(ia.max, b->value)};
130         } else if (const Broadcast *b = max->a.as<Broadcast>()) {
131             Interval ia = bounds_of_lanes(max->b);
132             return {Max::make(ia.min, b->value), Max::make(ia.max, b->value)};
133         }
134     } else if (const Not *not_ = e.as<Not>()) {
135         Interval ia = bounds_of_lanes(not_->a);
136         return {!ia.max, !ia.min};
137     } else if (const Ramp *r = e.as<Ramp>()) {
138         Expr last_lane_idx = make_const(r->base.type(), r->lanes - 1);
139         if (is_positive_const(r->stride)) {
140             return {r->base, r->base + last_lane_idx * r->stride};
141         } else if (is_negative_const(r->stride)) {
142             return {r->base + last_lane_idx * r->stride, r->base};
143         }
144     } else if (const Broadcast *b = e.as<Broadcast>()) {
145         return {b->value, b->value};
146     } else if (const Variable *var = e.as<Variable>()) {
147         return {Variable::make(var->type.element_of(), var->name + ".min_lane"),
148                 Variable::make(var->type.element_of(), var->name + ".max_lane")};
149     } else if (const Let *let = e.as<Let>()) {
150         Interval ia = bounds_of_lanes(let->value);
151         Interval ib = bounds_of_lanes(let->body);
152         if (expr_uses_var(ib.min, let->name + ".min_lane")) {
153             ib.min = Let::make(let->name + ".min_lane", ia.min, ib.min);
154         }
155         if (expr_uses_var(ib.max, let->name + ".min_lane")) {
156             ib.max = Let::make(let->name + ".min_lane", ia.min, ib.max);
157         }
158         if (expr_uses_var(ib.min, let->name + ".max_lane")) {
159             ib.min = Let::make(let->name + ".max_lane", ia.max, ib.min);
160         }
161         if (expr_uses_var(ib.max, let->name + ".max_lane")) {
162             ib.max = Let::make(let->name + ".max_lane", ia.max, ib.max);
163         }
164         if (expr_uses_var(ib.min, let->name)) {
165             ib.min = Let::make(let->name, let->value, ib.min);
166         }
167         if (expr_uses_var(ib.max, let->name)) {
168             ib.max = Let::make(let->name, let->value, ib.max);
169         }
170         return ib;
171     }
172 
173     // Take the explicit min and max over the lanes
174     if (e.type().is_bool()) {
175         Expr min_lane = VectorReduce::make(VectorReduce::And, e, 1);
176         Expr max_lane = VectorReduce::make(VectorReduce::Or, e, 1);
177         return {min_lane, max_lane};
178     } else {
179         Expr min_lane = VectorReduce::make(VectorReduce::Min, e, 1);
180         Expr max_lane = VectorReduce::make(VectorReduce::Max, e, 1);
181         return {min_lane, max_lane};
182     }
}
184 
185 // A ramp with the lanes repeated (e.g. <0 0 2 2 4 4 6 6>)
186 // TODO(vksnk): With nested vectorization, this will be representable
// as a ramp(broadcast(a, repetitions), broadcast(b, repetitions),
// lanes).
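// For example (illustrative), the interleaved ramp <0 0 2 2 4 4 6 6> above
// is described by base = 0, stride = 2, lanes = 8, repetitions = 2.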
189 struct InterleavedRamp {
190     Expr base, stride;
191     int lanes, repetitions;
192 };
193 
bool is_interleaved_ramp(const Expr &e, const Scope<Expr> &scope, InterleavedRamp *result) {
195     if (const Ramp *r = e.as<Ramp>()) {
196         result->base = r->base;
197         result->stride = r->stride;
198         result->lanes = r->lanes;
199         result->repetitions = 1;
200         return true;
201     } else if (const Broadcast *b = e.as<Broadcast>()) {
202         result->base = b->value;
203         result->stride = 0;
204         result->lanes = b->lanes;
205         result->repetitions = 0;
206         return true;
207     } else if (const Add *add = e.as<Add>()) {
208         InterleavedRamp ra;
209         if (is_interleaved_ramp(add->a, scope, &ra) &&
210             is_interleaved_ramp(add->b, scope, result) &&
211             (ra.repetitions == 0 ||
212              result->repetitions == 0 ||
213              ra.repetitions == result->repetitions)) {
214             result->base = simplify(result->base + ra.base);
215             result->stride = simplify(result->stride + ra.stride);
216             if (!result->repetitions) {
217                 result->repetitions = ra.repetitions;
218             }
219             return true;
220         }
221     } else if (const Sub *sub = e.as<Sub>()) {
222         InterleavedRamp ra;
223         if (is_interleaved_ramp(sub->a, scope, &ra) &&
224             is_interleaved_ramp(sub->b, scope, result) &&
225             (ra.repetitions == 0 ||
226              result->repetitions == 0 ||
227              ra.repetitions == result->repetitions)) {
228             result->base = simplify(ra.base - result->base);
229             result->stride = simplify(ra.stride - result->stride);
230             if (!result->repetitions) {
231                 result->repetitions = ra.repetitions;
232             }
233             return true;
234         }
235     } else if (const Mul *mul = e.as<Mul>()) {
236         const int64_t *b = nullptr;
237         if (is_interleaved_ramp(mul->a, scope, result) &&
238             (b = as_const_int(mul->b))) {
239             result->base = simplify(result->base * (int)(*b));
240             result->stride = simplify(result->stride * (int)(*b));
241             return true;
242         }
243     } else if (const Div *div = e.as<Div>()) {
244         const int64_t *b = nullptr;
245         if (is_interleaved_ramp(div->a, scope, result) &&
246             (b = as_const_int(div->b)) &&
247             is_one(result->stride) &&
248             (result->repetitions == 1 ||
249              result->repetitions == 0) &&
250             can_prove((result->base % (int)(*b)) == 0)) {
251             // TODO: Generalize this. Currently only matches
252             // ramp(base*b, 1, lanes) / b
253             // broadcast(base * b, lanes) / b
254             result->base = simplify(result->base / (int)(*b));
255             result->repetitions *= (int)(*b);
256             return true;
257         }
258     } else if (const Variable *var = e.as<Variable>()) {
259         if (scope.contains(var->name)) {
260             return is_interleaved_ramp(scope.get(var->name), scope, result);
261         }
262     }
263     return false;
264 }
265 
266 // Allocations inside vectorized loops grow an additional inner
267 // dimension to represent the separate copy of the allocation per
268 // vector lane. This means loads and stores to them need to be
269 // rewritten slightly.
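// For example (a sketch of the rewrite below): with 8 lanes, a load of
// foo[x] inside the vectorized body becomes foo[x * 8 + v], where v is a
// fresh variable that is itself later vectorized to ramp(0, 1, 8).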
270 class RewriteAccessToVectorAlloc : public IRMutator {
271     Expr var;
272     string alloc;
273     int lanes;
274 
275     using IRMutator::visit;
276 
    Expr mutate_index(const string &a, Expr index) {
278         index = mutate(index);
279         if (a == alloc) {
280             return index * lanes + var;
281         } else {
282             return index;
283         }
284     }
285 
    ModulusRemainder mutate_alignment(const string &a, const ModulusRemainder &align) {
287         if (a == alloc) {
288             return align * lanes;
289         } else {
290             return align;
291         }
292     }
293 
    Expr visit(const Load *op) override {
295         return Load::make(op->type, op->name, mutate_index(op->name, op->index),
296                           op->image, op->param, mutate(op->predicate), mutate_alignment(op->name, op->alignment));
297     }
298 
    Stmt visit(const Store *op) override {
300         return Store::make(op->name, mutate(op->value), mutate_index(op->name, op->index),
301                            op->param, mutate(op->predicate), mutate_alignment(op->name, op->alignment));
302     }
303 
304 public:
    RewriteAccessToVectorAlloc(const string &v, string a, int l)
306         : var(Variable::make(Int(32), v)), alloc(std::move(a)), lanes(l) {
307     }
308 };
309 
310 class UsesGPUVars : public IRVisitor {
311 private:
312     using IRVisitor::visit;
    void visit(const Variable *op) override {
314         if (CodeGen_GPU_Dev::is_gpu_var(op->name)) {
315             debug(3) << "Found gpu loop var: " << op->name << "\n";
316             uses_gpu = true;
317         }
318     }
319 
320 public:
321     bool uses_gpu = false;
322 };
323 
bool uses_gpu_vars(const Expr &s) {
325     UsesGPUVars uses;
326     s.accept(&uses);
327     return uses.uses_gpu;
328 }
329 
330 // Wrap a vectorized predicate around a Load/Store node.
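// For example (illustrative), a store that sits under a vector condition c
// can become a store whose predicate is (c && existing_predicate), instead
// of scalarizing the branch, when the target supports predicated access.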
331 class PredicateLoadStore : public IRMutator {
332     string var;
333     Expr vector_predicate;
334     bool in_hexagon;
335     const Target &target;
336     int lanes;
337     bool valid;
338     bool vectorized;
339 
340     using IRMutator::visit;
341 
    bool should_predicate_store_load(int bit_size) {
343         if (in_hexagon) {
344             internal_assert(target.features_any_of({Target::HVX_64, Target::HVX_128}))
345                 << "We are inside a hexagon loop, but the target doesn't have hexagon's features\n";
346             return true;
347         } else if (target.arch == Target::X86) {
348             // Should only attempt to predicate store/load if the lane size is
349             // no less than 4
350             // TODO: disabling for now due to trunk LLVM breakage.
351             // See: https://github.com/halide/Halide/issues/3534
352             // return (bit_size == 32) && (lanes >= 4);
353             return false;
354         }
        // For other architectures, do not predicate vector loads/stores.
356         return false;
357     }
358 
    Expr merge_predicate(Expr pred, const Expr &new_pred) {
360         if (pred.type().lanes() == new_pred.type().lanes()) {
361             Expr res = simplify(pred && new_pred);
362             return res;
363         }
364         valid = false;
365         return pred;
366     }
367 
    Expr visit(const Load *op) override {
369         valid = valid && should_predicate_store_load(op->type.bits());
370         if (!valid) {
371             return op;
372         }
373 
374         Expr predicate, index;
375         if (!op->index.type().is_scalar()) {
376             internal_assert(op->predicate.type().lanes() == lanes);
377             internal_assert(op->index.type().lanes() == lanes);
378 
379             predicate = mutate(op->predicate);
380             index = mutate(op->index);
381         } else if (expr_uses_var(op->index, var)) {
382             predicate = mutate(Broadcast::make(op->predicate, lanes));
383             index = mutate(Broadcast::make(op->index, lanes));
384         } else {
385             return IRMutator::visit(op);
386         }
387 
388         predicate = merge_predicate(predicate, vector_predicate);
389         if (!valid) {
390             return op;
391         }
392         vectorized = true;
393         return Load::make(op->type, op->name, index, op->image, op->param, predicate, op->alignment);
394     }
395 
    Stmt visit(const Store *op) override {
397         valid = valid && should_predicate_store_load(op->value.type().bits());
398         if (!valid) {
399             return op;
400         }
401 
402         Expr predicate, value, index;
403         if (!op->index.type().is_scalar()) {
404             internal_assert(op->predicate.type().lanes() == lanes);
405             internal_assert(op->index.type().lanes() == lanes);
406             internal_assert(op->value.type().lanes() == lanes);
407 
408             predicate = mutate(op->predicate);
409             value = mutate(op->value);
410             index = mutate(op->index);
411         } else if (expr_uses_var(op->index, var)) {
412             predicate = mutate(Broadcast::make(op->predicate, lanes));
413             value = mutate(Broadcast::make(op->value, lanes));
414             index = mutate(Broadcast::make(op->index, lanes));
415         } else {
416             return IRMutator::visit(op);
417         }
418 
419         predicate = merge_predicate(predicate, vector_predicate);
420         if (!valid) {
421             return op;
422         }
423         vectorized = true;
424         return Store::make(op->name, value, index, op->param, predicate, op->alignment);
425     }
426 
    Expr visit(const Call *op) override {
428         // We should not vectorize calls with side-effects
429         valid = valid && op->is_pure();
430         return IRMutator::visit(op);
431     }
432 
433 public:
    PredicateLoadStore(string v, const Expr &vpred, bool in_hexagon, const Target &t)
435         : var(std::move(v)), vector_predicate(vpred), in_hexagon(in_hexagon), target(t),
436           lanes(vpred.type().lanes()), valid(true), vectorized(false) {
437         internal_assert(lanes > 1);
438     }
439 
    bool is_vectorized() const {
441         return valid && vectorized;
442     }
443 };
444 
445 // Substitutes a vector for a scalar var in a Stmt. Used on the
446 // body of every vectorized loop.
447 class VectorSubs : public IRMutator {
448     // The var we're vectorizing
449     string var;
450 
451     // What we're replacing it with. Usually a ramp.
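    // (e.g. ramp(min_of_the_vectorized_loop, 1, lanes) for a loop of width 'lanes').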
452     Expr replacement;
453 
454     const Target &target;
455 
456     bool in_hexagon;  // Are we inside the hexagon loop?
457 
458     // A suffix to attach to widened variables.
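    // (e.g. ".x8" when vectorizing by 8 lanes; see the constructor below).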
459     string widening_suffix;
460 
461     // A scope containing lets and letstmts whose values became
462     // vectors.
463     Scope<Expr> scope;
464 
465     // The same set of Exprs, indexed by the vectorized var name
466     Scope<Expr> vector_scope;
467 
468     // A stack of all containing lets. We need to reinject the scalar
469     // version of them if we scalarize inner code.
470     vector<pair<string, Expr>> containing_lets;
471 
472     // Widen an expression to the given number of lanes.
    Expr widen(Expr e, int lanes) {
474         if (e.type().lanes() == lanes) {
475             return e;
476         } else if (e.type().lanes() == 1) {
477             return Broadcast::make(e, lanes);
478         } else {
479             internal_error << "Mismatched vector lanes in VectorSubs\n";
480         }
481         return Expr();
482     }
483 
484     using IRMutator::visit;
485 
    Expr visit(const Cast *op) override {
487         Expr value = mutate(op->value);
488         if (value.same_as(op->value)) {
489             return op;
490         } else {
491             Type t = op->type.with_lanes(value.type().lanes());
492             return Cast::make(t, value);
493         }
494     }
495 
    Expr visit(const Variable *op) override {
497         string widened_name = op->name + widening_suffix;
498         if (op->name == var) {
499             return replacement;
500         } else if (scope.contains(op->name)) {
501             // If the variable appears in scope then we previously widened
502             // it and we use the new widened name for the variable.
503             return Variable::make(scope.get(op->name).type(), widened_name);
504         } else {
505             return op;
506         }
507     }
508 
509     template<typename T>
    Expr mutate_binary_operator(const T *op) {
511         Expr a = mutate(op->a), b = mutate(op->b);
512         if (a.same_as(op->a) && b.same_as(op->b)) {
513             return op;
514         } else {
515             int w = std::max(a.type().lanes(), b.type().lanes());
516             return T::make(widen(a, w), widen(b, w));
517         }
518     }
519 
    Expr visit(const Add *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const Sub *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const Mul *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const Div *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const Mod *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const Min *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const Max *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const EQ *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const NE *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const LT *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const LE *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const GT *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const GE *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const And *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const Or *op) override {
        return mutate_binary_operator(op);
    }
565 
    Expr visit(const Select *op) override {
567         Expr condition = mutate(op->condition);
568         Expr true_value = mutate(op->true_value);
569         Expr false_value = mutate(op->false_value);
570         if (condition.same_as(op->condition) &&
571             true_value.same_as(op->true_value) &&
572             false_value.same_as(op->false_value)) {
573             return op;
574         } else {
575             int lanes = std::max(true_value.type().lanes(), false_value.type().lanes());
576             lanes = std::max(lanes, condition.type().lanes());
577             // Widen the true and false values, but we don't have to widen the condition
578             true_value = widen(true_value, lanes);
579             false_value = widen(false_value, lanes);
580             return Select::make(condition, true_value, false_value);
581         }
582     }
583 
    Expr visit(const Load *op) override {
585         Expr predicate = mutate(op->predicate);
586         Expr index = mutate(op->index);
587 
588         if (predicate.same_as(op->predicate) && index.same_as(op->index)) {
589             return op;
590         } else {
591             int w = index.type().lanes();
592             predicate = widen(predicate, w);
593             return Load::make(op->type.with_lanes(w), op->name, index, op->image,
594                               op->param, predicate, op->alignment);
595         }
596     }
597 
    Expr visit(const Call *op) override {
599         // Widen the call by changing the lanes of all of its
600         // arguments and its return type
601         vector<Expr> new_args(op->args.size());
602         bool changed = false;
603 
604         // Mutate the args
605         int max_lanes = 0;
606         for (size_t i = 0; i < op->args.size(); i++) {
607             Expr old_arg = op->args[i];
608             Expr new_arg = mutate(old_arg);
609             if (!new_arg.same_as(old_arg)) changed = true;
610             new_args[i] = new_arg;
611             max_lanes = std::max(new_arg.type().lanes(), max_lanes);
612         }
613 
614         if (!changed) {
615             return op;
616         } else if (op->name == Call::trace) {
617             const int64_t *event = as_const_int(op->args[6]);
618             internal_assert(event != nullptr);
619             if (*event == halide_trace_begin_realization || *event == halide_trace_end_realization) {
620                 // Call::trace vectorizes uniquely for begin/end realization, because the coordinates
621                 // for these are actually min/extent pairs; we need to maintain the proper dimensionality
622                 // count and instead aggregate the widened values into a single pair.
623                 for (size_t i = 1; i <= 2; i++) {
624                     const Call *call = new_args[i].as<Call>();
625                     internal_assert(call && call->is_intrinsic(Call::make_struct));
626                     if (i == 1) {
627                         // values should always be empty for these events
628                         internal_assert(call->args.empty());
629                         continue;
630                     }
631                     vector<Expr> call_args(call->args.size());
632                     for (size_t j = 0; j < call_args.size(); j += 2) {
633                         Expr min_v = widen(call->args[j], max_lanes);
634                         Expr extent_v = widen(call->args[j + 1], max_lanes);
635                         Expr min_scalar = extract_lane(min_v, 0);
636                         Expr max_scalar = min_scalar + extract_lane(extent_v, 0);
637                         for (int k = 1; k < max_lanes; ++k) {
638                             Expr min_k = extract_lane(min_v, k);
639                             Expr extent_k = extract_lane(extent_v, k);
640                             min_scalar = min(min_scalar, min_k);
641                             max_scalar = max(max_scalar, min_k + extent_k);
642                         }
643                         call_args[j] = min_scalar;
644                         call_args[j + 1] = max_scalar - min_scalar;
645                     }
646                     new_args[i] = Call::make(call->type.element_of(), Call::make_struct, call_args, Call::Intrinsic);
647                 }
648             } else {
649                 // Call::trace vectorizes uniquely, because we want a
650                 // single trace call for the entire vector, instead of
651                 // scalarizing the call and tracing each element.
652                 for (size_t i = 1; i <= 2; i++) {
653                     // Each struct should be a struct-of-vectors, not a
654                     // vector of distinct structs.
655                     const Call *call = new_args[i].as<Call>();
656                     internal_assert(call && call->is_intrinsic(Call::make_struct));
657                     // Widen the call args to have the same lanes as the max lanes found
658                     vector<Expr> call_args(call->args.size());
659                     for (size_t j = 0; j < call_args.size(); j++) {
660                         call_args[j] = widen(call->args[j], max_lanes);
661                     }
662                     new_args[i] = Call::make(call->type.element_of(), Call::make_struct,
663                                              call_args, Call::Intrinsic);
664                 }
665                 // One of the arguments to the trace helper
666                 // records the number of vector lanes in the type being
667                 // stored.
668                 new_args[5] = max_lanes;
669                 // One of the arguments to the trace helper
                // records the number of entries in the coordinates (which we just widened)
671                 if (max_lanes > 1) {
672                     new_args[9] = new_args[9] * max_lanes;
673                 }
674             }
675             return Call::make(op->type, Call::trace, new_args, op->call_type);
676         } else {
677             // Widen the args to have the same lanes as the max lanes found
678             for (size_t i = 0; i < new_args.size(); i++) {
679                 new_args[i] = widen(new_args[i], max_lanes);
680             }
681             return Call::make(op->type.with_lanes(max_lanes), op->name, new_args,
682                               op->call_type, op->func, op->value_index, op->image, op->param);
683         }
684     }
685 
    Expr visit(const Let *op) override {
687 
688         // Vectorize the let value and check to see if it was vectorized by
689         // this mutator. The type of the expression might already be vector
690         // width.
691         Expr mutated_value = mutate(op->value);
692         bool was_vectorized = (!op->value.type().is_vector() &&
693                                mutated_value.type().is_vector());
694 
695         // If the value was vectorized by this mutator, add a new name to
696         // the scope for the vectorized value expression.
697         string vectorized_name;
698         if (was_vectorized) {
699             vectorized_name = op->name + widening_suffix;
700             scope.push(op->name, mutated_value);
701             vector_scope.push(vectorized_name, mutated_value);
702         }
703 
704         Expr mutated_body = mutate(op->body);
705 
706         InterleavedRamp ir;
707         if (is_interleaved_ramp(mutated_value, vector_scope, &ir)) {
708             return substitute(vectorized_name, mutated_value, mutated_body);
709         } else if (mutated_value.same_as(op->value) &&
710                    mutated_body.same_as(op->body)) {
711             return op;
712         } else if (was_vectorized) {
713             scope.pop(op->name);
714             vector_scope.pop(vectorized_name);
715             return Let::make(vectorized_name, mutated_value, mutated_body);
716         } else {
717             return Let::make(op->name, mutated_value, mutated_body);
718         }
719     }
720 
    Stmt visit(const LetStmt *op) override {
722         Expr mutated_value = mutate(op->value);
723         string mutated_name = op->name;
724 
725         // Check if the value was vectorized by this mutator.
726         bool was_vectorized = (!op->value.type().is_vector() &&
727                                mutated_value.type().is_vector());
728 
729         if (was_vectorized) {
730             mutated_name += widening_suffix;
731             scope.push(op->name, mutated_value);
732             vector_scope.push(mutated_name, mutated_value);
733             // Also keep track of the original let, in case inner code scalarizes.
734             containing_lets.emplace_back(op->name, op->value);
735         }
736 
737         Stmt mutated_body = mutate(op->body);
738 
739         if (was_vectorized) {
740             containing_lets.pop_back();
741             scope.pop(op->name);
742             vector_scope.pop(mutated_name);
743 
744             // Inner code might have extracted my lanes using
745             // extract_lane, which introduces a shuffle_vector. If
746             // so we should define separate lets for the lanes and
747             // get it to use those instead.
748             mutated_body = ReplaceShuffleVectors(mutated_name).mutate(mutated_body);
749 
750             // Check if inner code wants my individual lanes.
751             Type t = mutated_value.type();
752             for (int i = 0; i < t.lanes(); i++) {
753                 string lane_name = mutated_name + ".lane." + std::to_string(i);
754                 if (stmt_uses_var(mutated_body, lane_name)) {
755                     mutated_body =
756                         LetStmt::make(lane_name, extract_lane(mutated_value, i), mutated_body);
757                 }
758             }
759 
760             // Inner code may also have wanted my max or min lane
761             bool uses_min_lane = stmt_uses_var(mutated_body, mutated_name + ".min_lane");
762             bool uses_max_lane = stmt_uses_var(mutated_body, mutated_name + ".max_lane");
763 
764             if (uses_min_lane || uses_max_lane) {
765                 Interval i = bounds_of_lanes(mutated_value);
766 
767                 if (uses_min_lane) {
768                     mutated_body =
769                         LetStmt::make(mutated_name + ".min_lane", i.min, mutated_body);
770                 }
771 
772                 if (uses_max_lane) {
773                     mutated_body =
774                         LetStmt::make(mutated_name + ".max_lane", i.max, mutated_body);
775                 }
776             }
777         }
778 
779         InterleavedRamp ir;
780         if (is_interleaved_ramp(mutated_value, vector_scope, &ir)) {
781             return substitute(mutated_name, mutated_value, mutated_body);
782         } else if (mutated_value.same_as(op->value) &&
783                    mutated_body.same_as(op->body)) {
784             return op;
785         } else {
786             return LetStmt::make(mutated_name, mutated_value, mutated_body);
787         }
788     }
789 
    Stmt visit(const Provide *op) override {
791         vector<Expr> new_args(op->args.size());
792         vector<Expr> new_values(op->values.size());
793         bool changed = false;
794 
795         // Mutate the args
796         int max_lanes = 0;
797         for (size_t i = 0; i < op->args.size(); i++) {
798             Expr old_arg = op->args[i];
799             Expr new_arg = mutate(old_arg);
800             if (!new_arg.same_as(old_arg)) changed = true;
801             new_args[i] = new_arg;
802             max_lanes = std::max(new_arg.type().lanes(), max_lanes);
803         }
804 
        for (size_t i = 0; i < op->values.size(); i++) {
806             Expr old_value = op->values[i];
807             Expr new_value = mutate(old_value);
808             if (!new_value.same_as(old_value)) changed = true;
809             new_values[i] = new_value;
810             max_lanes = std::max(new_value.type().lanes(), max_lanes);
811         }
812 
813         if (!changed) {
814             return op;
815         } else {
816             // Widen the args to have the same lanes as the max lanes found
817             for (size_t i = 0; i < new_args.size(); i++) {
818                 new_args[i] = widen(new_args[i], max_lanes);
819             }
820             for (size_t i = 0; i < new_values.size(); i++) {
821                 new_values[i] = widen(new_values[i], max_lanes);
822             }
823             return Provide::make(op->name, new_values, new_args);
824         }
825     }
826 
    Stmt visit(const Store *op) override {
828         Expr predicate = mutate(op->predicate);
829         Expr value = mutate(op->value);
830         Expr index = mutate(op->index);
831 
832         if (predicate.same_as(op->predicate) && value.same_as(op->value) && index.same_as(op->index)) {
833             return op;
834         } else {
835             int lanes = std::max(predicate.type().lanes(), std::max(value.type().lanes(), index.type().lanes()));
836             return Store::make(op->name, widen(value, lanes), widen(index, lanes),
837                                op->param, widen(predicate, lanes), op->alignment);
838         }
839     }
840 
    Stmt visit(const AssertStmt *op) override {
842         return (op->condition.type().lanes() > 1) ? scalarize(op) : op;
843     }
844 
    Stmt visit(const IfThenElse *op) override {
846         Expr cond = mutate(op->condition);
847         int lanes = cond.type().lanes();
848         debug(3) << "Vectorizing over " << var << "\n"
849                  << "Old: " << op->condition << "\n"
850                  << "New: " << cond << "\n";
851 
852         Stmt then_case = mutate(op->then_case);
853         Stmt else_case = mutate(op->else_case);
854 
855         if (lanes > 1) {
856             // We have an if statement with a vector condition,
857             // which would mean control flow divergence within the
858             // SIMD lanes.
859 
860             bool vectorize_predicate = !uses_gpu_vars(cond);
861             Stmt predicated_stmt;
862             if (vectorize_predicate) {
863                 PredicateLoadStore p(var, cond, in_hexagon, target);
864                 predicated_stmt = p.mutate(then_case);
865                 vectorize_predicate = p.is_vectorized();
866             }
867             if (vectorize_predicate && else_case.defined()) {
868                 PredicateLoadStore p(var, !cond, in_hexagon, target);
869                 predicated_stmt = Block::make(predicated_stmt, p.mutate(else_case));
870                 vectorize_predicate = p.is_vectorized();
871             }
872 
873             debug(4) << "IfThenElse should vectorize predicate over var " << var << "? " << vectorize_predicate << "; cond: " << cond << "\n";
874             debug(4) << "Predicated stmt:\n"
875                      << predicated_stmt << "\n";
876 
877             // First check if the condition is marked as likely.
878             const Call *c = cond.as<Call>();
879             if (c && (c->is_intrinsic(Call::likely) ||
880                       c->is_intrinsic(Call::likely_if_innermost))) {
881 
882                 // The meaning of the likely intrinsic is that
883                 // Halide should optimize for the case in which
884                 // *every* likely value is true. We can do that by
885                 // generating a scalar condition that checks if
886                 // the least-true lane is true.
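                // (For a vector boolean this is effectively the lane-wise
                // AND of the likely argument, so the fast path is taken
                // only when every lane of the condition holds.)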
887                 Expr all_true = bounds_of_lanes(c->args[0]).min;
888 
889                 // Wrap it in the same flavor of likely
890                 all_true = Call::make(Bool(), c->name,
891                                       {all_true}, Call::PureIntrinsic);
892 
893                 if (!vectorize_predicate) {
894                     // We should strip the likelies from the case
895                     // that's going to scalarize, because it's no
896                     // longer likely.
897                     Stmt without_likelies =
898                         IfThenElse::make(op->condition.as<Call>()->args[0],
899                                          op->then_case, op->else_case);
900                     Stmt stmt =
901                         IfThenElse::make(all_true,
902                                          then_case,
903                                          scalarize(without_likelies));
904                     debug(4) << "...With all_true likely: \n"
905                              << stmt << "\n";
906                     return stmt;
907                 } else {
908                     Stmt stmt =
909                         IfThenElse::make(all_true,
910                                          then_case,
911                                          predicated_stmt);
912                     debug(4) << "...Predicated IfThenElse: \n"
913                              << stmt << "\n";
914                     return stmt;
915                 }
916             } else {
917                 // It's some arbitrary vector condition.
918                 if (!vectorize_predicate) {
919                     debug(4) << "...Scalarizing vector predicate: \n"
920                              << Stmt(op) << "\n";
921                     return scalarize(op);
922                 } else {
923                     Stmt stmt = predicated_stmt;
924                     debug(4) << "...Predicated IfThenElse: \n"
925                              << stmt << "\n";
926                     return stmt;
927                 }
928             }
929         } else {
930             // It's an if statement on a scalar, we're ok to vectorize the innards.
931             debug(3) << "Not scalarizing if then else\n";
932             if (cond.same_as(op->condition) &&
933                 then_case.same_as(op->then_case) &&
934                 else_case.same_as(op->else_case)) {
935                 return op;
936             } else {
937                 return IfThenElse::make(cond, then_case, else_case);
938             }
939         }
940     }
941 
    Stmt visit(const For *op) override {
943         ForType for_type = op->for_type;
944         if (for_type == ForType::Vectorized) {
945             user_warning << "Warning: Encountered vector for loop over " << op->name
946                          << " inside vector for loop over " << var << "."
947                          << " Ignoring the vectorize directive for the inner for loop.\n";
948             for_type = ForType::Serial;
949         }
950 
951         Expr min = mutate(op->min);
952         Expr extent = mutate(op->extent);
953 
954         Stmt body = op->body;
955 
956         if (min.type().is_vector()) {
957             // Rebase the loop to zero and try again
958             Expr var = Variable::make(Int(32), op->name);
959             Stmt body = substitute(op->name, var + op->min, op->body);
960             Stmt transformed = For::make(op->name, 0, op->extent, for_type, op->device_api, body);
961             return mutate(transformed);
962         }
963 
964         if (extent.type().is_vector()) {
965             // We'll iterate up to the max over the lanes, but
966             // inject an if statement inside the loop that stops
967             // each lane from going too far.
968 
969             extent = bounds_of_lanes(extent).max;
970             Expr var = Variable::make(Int(32), op->name);
971             body = IfThenElse::make(likely(var < op->min + op->extent), body);
972         }
973 
974         body = mutate(body);
975 
976         if (min.same_as(op->min) &&
977             extent.same_as(op->extent) &&
978             body.same_as(op->body) &&
979             for_type == op->for_type) {
980             return op;
981         } else {
982             return For::make(op->name, min, extent, for_type, op->device_api, body);
983         }
984     }
985 
    Stmt visit(const Allocate *op) override {
987         vector<Expr> new_extents;
988         Expr new_expr;
989 
990         int lanes = replacement.type().lanes();
991 
992         // The new expanded dimension is innermost.
993         new_extents.emplace_back(lanes);
994 
995         for (size_t i = 0; i < op->extents.size(); i++) {
996             Expr extent = mutate(op->extents[i]);
997             // For vector sizes, take the max over the lanes. Note
998             // that we haven't changed the strides, which also may
999             // vary per lane. This is a bit weird, but the way we
1000             // set up the vectorized memory means that lanes can't
1001             // clobber each others' memory, so it doesn't matter.
1002             if (extent.type().is_vector()) {
1003                 extent = bounds_of_lanes(extent).max;
1004             }
1005             new_extents.push_back(extent);
1006         }
1007 
1008         if (op->new_expr.defined()) {
1009             new_expr = mutate(op->new_expr);
1010             user_assert(new_expr.type().is_scalar())
1011                 << "Cannot vectorize an allocation with a varying new_expr per vector lane.\n";
1012         }
1013 
1014         Stmt body = op->body;
1015 
1016         // Rewrite loads and stores to this allocation like so:
1017         // foo[x] -> foo[x*lanes + v]
1018         string v = unique_name('v');
1019         body = RewriteAccessToVectorAlloc(v, op->name, lanes).mutate(body);
1020 
1021         scope.push(v, Ramp::make(0, 1, lanes));
1022         body = mutate(body);
1023         scope.pop(v);
1024 
1025         // Replace the widened 'v' with the actual ramp
1026         // foo[x*lanes + widened_v] -> foo[x*lanes + ramp(0, 1, lanes)]
1027         body = substitute(v + widening_suffix, Ramp::make(0, 1, lanes), body);
1028 
1029         // The variable itself could still exist inside an inner scalarized block.
1030         body = substitute(v, Variable::make(Int(32), var), body);
1031 
1032         return Allocate::make(op->name, op->type, op->memory_type, new_extents, op->condition, body, new_expr, op->free_function);
1033     }
1034 
    Stmt visit(const Atomic *op) override {
1036         // Recognize a few special cases that we can handle as within-vector reduction trees.
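        // For example (illustrative), if the store index is scalar, as in
        //   f[0] = f[0] + g(x)
        // with x vectorized, the update becomes a single total reduction:
        //   f[0] = f[0] + vector_reduce(Add, <widened g>).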
1037         do {
1038             if (!op->mutex_name.empty()) {
1039                 // We can't vectorize over a mutex
1040                 break;
1041             }
1042 
1043             // f[x] = f[x] <op> y
1044             const Store *store = op->body.as<Store>();
1045             if (!store) break;
1046 
1047             VectorReduce::Operator reduce_op = VectorReduce::Add;
1048             Expr a, b;
1049             if (const Add *add = store->value.as<Add>()) {
1050                 a = add->a;
1051                 b = add->b;
1052                 reduce_op = VectorReduce::Add;
1053             } else if (const Mul *mul = store->value.as<Mul>()) {
1054                 a = mul->a;
1055                 b = mul->b;
1056                 reduce_op = VectorReduce::Mul;
1057             } else if (const Min *min = store->value.as<Min>()) {
1058                 a = min->a;
1059                 b = min->b;
1060                 reduce_op = VectorReduce::Min;
1061             } else if (const Max *max = store->value.as<Max>()) {
1062                 a = max->a;
1063                 b = max->b;
1064                 reduce_op = VectorReduce::Max;
1065             } else if (const Cast *cast_op = store->value.as<Cast>()) {
1066                 if (cast_op->type.element_of() == UInt(8) &&
1067                     cast_op->value.type().is_bool()) {
1068                     if (const And *and_op = cast_op->value.as<And>()) {
1069                         a = and_op->a;
1070                         b = and_op->b;
1071                         reduce_op = VectorReduce::And;
1072                     } else if (const Or *or_op = cast_op->value.as<Or>()) {
1073                         a = or_op->a;
1074                         b = or_op->b;
1075                         reduce_op = VectorReduce::Or;
1076                     }
1077                 }
1078             }
1079 
1080             if (!a.defined() || !b.defined()) {
1081                 break;
1082             }
1083 
1084             // Bools get cast to uint8 for storage. Strip off that
1085             // cast around any load.
1086             if (b.type().is_bool()) {
1087                 const Cast *cast_op = b.as<Cast>();
1088                 if (cast_op) {
1089                     b = cast_op->value;
1090                 }
1091             }
1092             if (a.type().is_bool()) {
                const Cast *cast_op = a.as<Cast>();
1094                 if (cast_op) {
1095                     a = cast_op->value;
1096                 }
1097             }
1098 
1099             if (a.as<Variable>() && !b.as<Variable>()) {
1100                 std::swap(a, b);
1101             }
1102 
1103             // We require b to be a var, because it should have been lifted.
1104             const Variable *var_b = b.as<Variable>();
1105             const Load *load_a = a.as<Load>();
1106 
1107             if (!var_b ||
1108                 !scope.contains(var_b->name) ||
1109                 !load_a ||
1110                 load_a->name != store->name ||
1111                 !is_one(load_a->predicate) ||
1112                 !is_one(store->predicate)) {
1113                 break;
1114             }
1115 
1116             b = scope.get(var_b->name);
1117             Expr store_index = mutate(store->index);
1118             Expr load_index = mutate(load_a->index);
1119 
1120             // The load and store indices must be the same interleaved
1121             // ramp (or the same scalar, in the total reduction case).
1122             InterleavedRamp store_ir, load_ir;
1123             Expr test;
1124             if (store_index.type().is_scalar()) {
1125                 test = simplify(load_index == store_index);
1126             } else if (is_interleaved_ramp(store_index, vector_scope, &store_ir) &&
1127                        is_interleaved_ramp(load_index, vector_scope, &load_ir) &&
1128                        store_ir.repetitions == load_ir.repetitions &&
1129                        store_ir.lanes == load_ir.lanes) {
1130                 test = simplify(store_ir.base == load_ir.base &&
1131                                 store_ir.stride == load_ir.stride);
1132             }
1133 
1134             if (!test.defined()) {
1135                 break;
1136             }
1137 
1138             if (is_zero(test)) {
1139                 break;
1140             } else if (!is_one(test)) {
1141                 // TODO: try harder by substituting in more things in scope
1142                 break;
1143             }
1144 
1145             int output_lanes = 1;
1146             if (store_index.type().is_scalar()) {
1147                 // The index doesn't depend on the value being
1148                 // vectorized, so it's a total reduction.
1149 
1150                 b = VectorReduce::make(reduce_op, b, 1);
1151             } else {
1152 
1153                 output_lanes = store_index.type().lanes() / store_ir.repetitions;
1154 
1155                 store_index = Ramp::make(store_ir.base, store_ir.stride, output_lanes);
1156                 b = VectorReduce::make(reduce_op, b, output_lanes);
1157             }
1158 
1159             Expr new_load = Load::make(load_a->type.with_lanes(output_lanes),
1160                                        load_a->name, store_index, load_a->image,
1161                                        load_a->param, const_true(output_lanes),
1162                                        ModulusRemainder{});
1163 
1164             switch (reduce_op) {
1165             case VectorReduce::Add:
1166                 b = new_load + b;
1167                 break;
1168             case VectorReduce::Mul:
1169                 b = new_load * b;
1170                 break;
1171             case VectorReduce::Min:
1172                 b = min(new_load, b);
1173                 break;
1174             case VectorReduce::Max:
1175                 b = max(new_load, b);
1176                 break;
1177             case VectorReduce::And:
1178                 b = cast(new_load.type(), cast(b.type(), new_load) && b);
1179                 break;
1180             case VectorReduce::Or:
1181                 b = cast(new_load.type(), cast(b.type(), new_load) || b);
1182                 break;
1183             }
1184 
1185             Stmt s = Store::make(store->name, b, store_index, store->param,
1186                                  const_true(b.type().lanes()), store->alignment);
1187 
1188             // We may still need the atomic node, if there was more
1189             // parallelism than just the vectorization.
1190             s = Atomic::make(op->producer_name, op->mutex_name, s);
1191 
1192             return s;
1193         } while (0);
1194 
1195         // In the general case, if a whole stmt has to be done
1196         // atomically, we need to serialize.
1197         return scalarize(op);
1198     }
1199 
    Stmt scalarize(Stmt s) {
1201         // Wrap a serial loop around it. Maybe LLVM will have
1202         // better luck vectorizing it.
1203 
1204         // We'll need the original scalar versions of any containing lets.
1205         for (size_t i = containing_lets.size(); i > 0; i--) {
1206             const auto &l = containing_lets[i - 1];
1207             s = LetStmt::make(l.first, l.second, s);
1208         }
1209 
1210         const Ramp *r = replacement.as<Ramp>();
1211         internal_assert(r) << "Expected replacement in VectorSubs to be a ramp\n";
1212         return For::make(var, r->base, r->lanes, ForType::Serial, DeviceAPI::None, s);
1213     }
1214 
    Expr scalarize(Expr e) {
        // This method returns a select tree that produces the scalarized
        // result as a vector expression with 'lanes' lanes.
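        // For lanes == 4 (illustrative), the tree has the shape
        //   select(replacement == broadcast(0, 4), broadcast(e with v=0, 4),
        //     select(replacement == broadcast(1, 4), broadcast(e with v=1, 4), ...))
        // with the last lane's value as the final default.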
1218 
1219         Expr result;
1220         int lanes = replacement.type().lanes();
1221 
1222         for (int i = lanes - 1; i >= 0; --i) {
1223             // Hide all the vector let values in scope with a scalar version
1224             // in the appropriate lane.
1225             for (Scope<Expr>::const_iterator iter = scope.cbegin(); iter != scope.cend(); ++iter) {
1226                 string name = iter.name() + ".lane." + std::to_string(i);
1227                 Expr lane = extract_lane(iter.value(), i);
1228                 e = substitute(iter.name(), Variable::make(lane.type(), name), e);
1229             }
1230 
1231             // Replace uses of the vectorized variable with the extracted
1232             // lane expression
1233             e = substitute(var, i, e);
1234 
1235             if (i == lanes - 1) {
1236                 result = Broadcast::make(e, lanes);
1237             } else {
1238                 Expr cond = (replacement == Broadcast::make(i, lanes));
1239                 result = Select::make(cond, Broadcast::make(e, lanes), result);
1240             }
1241         }
1242 
1243         return result;
1244     }
1245 
1246 public:
    VectorSubs(string v, Expr r, bool in_hexagon, const Target &t)
1248         : var(std::move(v)), replacement(std::move(r)), target(t), in_hexagon(in_hexagon) {
1249         widening_suffix = ".x" + std::to_string(replacement.type().lanes());
1250     }
};  // class VectorSubs
1252 
1253 class FindVectorizableExprsInAtomicNode : public IRMutator {
1254     // An Atomic node protects all accesses to a given buffer. We
1255     // consider a name "poisoned" if it depends on an access to this
1256     // buffer. We can't lift or vectorize anything that has been
1257     // poisoned.
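    // For example (illustrative), in f[x] = f[x] + g[x] the load of f[x]
    // poisons the whole right-hand side, but the load of g[x] on its own
    // is still liftable.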
1258     Scope<> poisoned_names;
1259     bool poison = false;
1260 
1261     using IRMutator::visit;
1262 
1263     template<typename T>
    const T *visit_let(const T *op) {
1265         mutate(op->value);
1266         ScopedBinding<> bind_if(poison, poisoned_names, op->name);
1267         mutate(op->body);
1268         return op;
1269     }
1270 
    Stmt visit(const LetStmt *op) override {
1272         return visit_let(op);
1273     }
1274 
    Expr visit(const Let *op) override {
1276         return visit_let(op);
1277     }
1278 
    Expr visit(const Load *op) override {
1280         // Even if the load is bad, maybe we can lift the index
1281         IRMutator::visit(op);
1282 
1283         poison |= poisoned_names.contains(op->name);
1284         return op;
1285     }
1286 
    Expr visit(const Variable *op) override {
        poison |= poisoned_names.contains(op->name);
1289         return op;
1290     }
1291 
    Stmt visit(const Store *op) override {
1293         // A store poisons all subsequent loads, but loads before the
1294         // first store can be lifted.
1295         mutate(op->index);
1296         mutate(op->value);
1297         poisoned_names.push(op->name);
1298         return op;
1299     }
1300 
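    // Impure calls can't be lifted out of the atomic node.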
    Expr visit(const Call *op) override {
        IRMutator::visit(op);
        poison |= !op->is_pure();
        return op;
    }

public:
    using IRMutator::mutate;

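    // Track poison while walking the IR: every subexpression that does not
    // depend on a poisoned name is recorded as liftable, i.e. safe to
    // compute outside the atomic node.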
    Expr mutate(const Expr &e) override {
        bool old_poison = poison;
        poison = false;
        IRMutator::mutate(e);
        if (!poison) {
            liftable.insert(e);
        }
        poison |= old_poison;
        // We're not actually mutating anything. This class is only a
        // mutator so that we can override a generic mutate() method.
        return e;
    }

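    // The buffer protected by the Atomic node starts out poisoned, along
    // with the names of its tuple components (buf.0, buf.1, ...) if the
    // Func produces a Tuple.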
    FindVectorizableExprsInAtomicNode(const string &buf, const map<string, Function> &env) {
        poisoned_names.push(buf);
        auto it = env.find(buf);
        if (it != env.end()) {
            // Handle tuples
            size_t n = it->second.values().size();
            if (n > 1) {
                for (size_t i = 0; i < n; i++) {
                    poisoned_names.push(buf + "." + std::to_string(i));
                }
            }
        }
    }

    std::set<Expr, ExprCompare> liftable;
};

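// Replace the liftable expressions found above with variables (or reuse the
// name of an enclosing let), recording the (name, value) pairs so the caller
// can rebind them outside the Atomic node.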
class LiftVectorizableExprsOutOfSingleAtomicNode : public IRMutator {
    const std::set<Expr, ExprCompare> &liftable;

    using IRMutator::visit;

    template<typename StmtOrExpr, typename LetStmtOrLet>
    StmtOrExpr visit_let(const LetStmtOrLet *op) {
        if (liftable.count(op->value)) {
            // Lift it under its current name to avoid having to
            // rewrite the variables in other lifted exprs.
            // TODO: duplicate non-overlapping liftable let stmts due to unrolling.
            lifted.emplace_back(op->name, op->value);
            return mutate(op->body);
        } else {
            return IRMutator::visit(op);
        }
    }

    Stmt visit(const LetStmt *op) override {
        return visit_let<Stmt>(op);
    }

    Expr visit(const Let *op) override {
        return visit_let<Expr>(op);
    }

public:
    map<Expr, string, IRDeepCompare> already_lifted;
    vector<pair<string, Expr>> lifted;

    using IRMutator::mutate;

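    // Liftable expressions that are neither constants nor bare variables are
    // replaced with a fresh variable; identical expressions share a single
    // binding via already_lifted.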
    Expr mutate(const Expr &e) override {
        if (liftable.count(e) && !is_const(e) && !e.as<Variable>()) {
            auto it = already_lifted.find(e);
            string name;
            if (it != already_lifted.end()) {
                name = it->second;
            } else {
                name = unique_name('t');
                lifted.emplace_back(name, e);
                already_lifted.emplace(e, name);
            }
            return Variable::make(e.type(), name);
        } else {
            return IRMutator::mutate(e);
        }
    }

    LiftVectorizableExprsOutOfSingleAtomicNode(const std::set<Expr, ExprCompare> &liftable)
        : liftable(liftable) {
    }
};

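// For each Atomic node in the Stmt, find the expressions in its body that can
// safely be computed outside the node and hoist them into LetStmts wrapped
// around it, shrinking the region that has to execute atomically.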
class LiftVectorizableExprsOutOfAllAtomicNodes : public IRMutator {
    using IRMutator::visit;

    Stmt visit(const Atomic *op) override {
        FindVectorizableExprsInAtomicNode finder(op->producer_name, env);
        finder.mutate(op->body);
        LiftVectorizableExprsOutOfSingleAtomicNode lifter(finder.liftable);
        Stmt new_body = lifter.mutate(op->body);
        new_body = Atomic::make(op->producer_name, op->mutex_name, new_body);
        while (!lifter.lifted.empty()) {
            auto p = lifter.lifted.back();
            new_body = LetStmt::make(p.first, p.second, new_body);
            lifter.lifted.pop_back();
        }
        return new_body;
    }

    const map<string, Function> &env;

public:
    LiftVectorizableExprsOutOfAllAtomicNodes(const map<string, Function> &env)
        : env(env) {
    }
};

// Vectorize all loops marked as such in a Stmt
class VectorizeLoops : public IRMutator {
    const Target &target;
    bool in_hexagon;

    using IRMutator::visit;

    Stmt visit(const For *for_loop) override {
        bool old_in_hexagon = in_hexagon;
        if (for_loop->device_api == DeviceAPI::Hexagon) {
            in_hexagon = true;
        }

        Stmt stmt;
        if (for_loop->for_type == ForType::Vectorized) {
            const IntImm *extent = for_loop->extent.as<IntImm>();
            if (!extent || extent->value <= 1) {
                user_error << "Loop over " << for_loop->name
                           << " has extent " << for_loop->extent
                           << ". Can only vectorize loops over a "
                           << "constant extent > 1\n";
            }

            // Replace the var with a ramp within the body
            Expr for_var = Variable::make(Int(32), for_loop->name);
            Expr replacement = Ramp::make(for_loop->min, 1, extent->value);
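            // The ramp covers iterations min, min+1, ..., min+extent-1, so
            // the vectorized body operates on all loop iterations at once.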
            stmt = VectorSubs(for_loop->name, replacement, in_hexagon, target).mutate(for_loop->body);
        } else {
            stmt = IRMutator::visit(for_loop);
        }

        if (for_loop->device_api == DeviceAPI::Hexagon) {
            in_hexagon = old_in_hexagon;
        }

        return stmt;
    }

public:
    VectorizeLoops(const Target &t)
        : target(t), in_hexagon(false) {
    }
};

/** Check if all stores in a Stmt are to names in a given scope. Used
    by RemoveUnnecessaryAtomics below. */
class AllStoresInScope : public IRVisitor {
    using IRVisitor::visit;
    void visit(const Store *op) override {
        result = result && s.contains(op->name);
    }

public:
    bool result = true;
    const Scope<> &s;
    AllStoresInScope(const Scope<> &s)
        : s(s) {
    }
};
bool all_stores_in_scope(const Stmt &stmt, const Scope<> &scope) {
    AllStoresInScope checker(scope);
    stmt.accept(&checker);
    return checker.result;
}

/** Drop any atomic nodes protecting buffers that are only accessed
 * from a single thread. */
class RemoveUnnecessaryAtomics : public IRMutator {
    using IRMutator::visit;

    // Allocations made from within this same thread
    bool in_thread = false;
    Scope<> local_allocs;

    Stmt visit(const Allocate *op) override {
        ScopedBinding<> bind(local_allocs, op->name);
        return IRMutator::visit(op);
    }

    Stmt visit(const Atomic *op) override {
        if (!in_thread || all_stores_in_scope(op->body, local_allocs)) {
            return mutate(op->body);
        } else {
            return op;
        }
    }

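    // A parallel loop starts a new thread context: allocations made outside
    // it may be shared between threads, so the set of thread-local
    // allocations is reset for the loop body.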
    Stmt visit(const For *op) override {
        if (is_parallel(op->for_type)) {
            ScopedValue<bool> old_in_thread(in_thread, true);
            Scope<> old_local_allocs;
            old_local_allocs.swap(local_allocs);
            Stmt s = IRMutator::visit(op);
            old_local_allocs.swap(local_allocs);
            return s;
        } else {
            return IRMutator::visit(op);
        }
    }
};

}  // namespace

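// Top-level driver: shrink the scope of atomic nodes, vectorize the marked
// loops, then drop atomic nodes protecting buffers only touched by a single
// thread.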
Stmt vectorize_loops(const Stmt &stmt, const map<string, Function> &env, const Target &t) {
    // Limit the scope of atomic nodes to just the necessary stuff.
    // TODO: Should this be an earlier pass? It's probably a good idea
    // for non-vectorizing stuff too.
    Stmt s = LiftVectorizableExprsOutOfAllAtomicNodes(env).mutate(stmt);
    s = VectorizeLoops(t).mutate(s);
    s = RemoveUnnecessaryAtomics().mutate(s);
    return s;
}

}  // namespace Internal
}  // namespace Halide