#include <algorithm>
#include <utility>

#include "CSE.h"
#include "CodeGen_GPU_Dev.h"
#include "Deinterleave.h"
#include "ExprUsesVar.h"
#include "IREquality.h"
#include "IRMutator.h"
#include "IROperator.h"
#include "IRPrinter.h"
#include "Scope.h"
#include "Simplify.h"
#include "Solve.h"
#include "Substitute.h"
#include "VectorizeLoops.h"

namespace Halide {
namespace Internal {

using std::map;
using std::pair;
using std::string;
using std::vector;

namespace {

// For a given var, replace expressions like shuffle_vector(var, 4)
// with var.lane.4
class ReplaceShuffleVectors : public IRMutator {
    string var;

    using IRMutator::visit;

    Expr visit(const Shuffle *op) override {
        const Variable *v;
        if (op->indices.size() == 1 &&
            (v = op->vectors[0].as<Variable>()) &&
            v->name == var) {
            return Variable::make(op->type, var + ".lane." + std::to_string(op->indices[0]));
        } else {
            return IRMutator::visit(op);
        }
    }

public:
    ReplaceShuffleVectors(const string &v)
        : var(v) {
    }
};

/** Find the exact max and min lanes of a vector expression. Not
 * conservative like bounds_of_expr, but uses similar rules for some
 * common node types where it can be exact. Assumes any vector
 * variables defined externally also have .min_lane and .max_lane
 * versions in scope. */
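// For example, bounds_of_lanes(ramp(x, 1, 8) + broadcast(y, 8)) is
// exactly {x + y, x + y + 7}.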
Interval bounds_of_lanes(const Expr &e) {
    if (const Add *add = e.as<Add>()) {
        if (const Broadcast *b = add->b.as<Broadcast>()) {
            Interval ia = bounds_of_lanes(add->a);
            return {ia.min + b->value, ia.max + b->value};
        } else if (const Broadcast *b = add->a.as<Broadcast>()) {
            Interval ia = bounds_of_lanes(add->b);
            return {b->value + ia.min, b->value + ia.max};
        }
    } else if (const Sub *sub = e.as<Sub>()) {
        if (const Broadcast *b = sub->b.as<Broadcast>()) {
            Interval ia = bounds_of_lanes(sub->a);
            return {ia.min - b->value, ia.max - b->value};
        } else if (const Broadcast *b = sub->a.as<Broadcast>()) {
            Interval ia = bounds_of_lanes(sub->b);
            // Subtracting from a broadcast value reverses the bounds.
            return {b->value - ia.max, b->value - ia.min};
        }
    } else if (const Mul *mul = e.as<Mul>()) {
        if (const Broadcast *b = mul->b.as<Broadcast>()) {
            if (is_positive_const(b->value)) {
                Interval ia = bounds_of_lanes(mul->a);
                return {ia.min * b->value, ia.max * b->value};
            } else if (is_negative_const(b->value)) {
                Interval ia = bounds_of_lanes(mul->a);
                return {ia.max * b->value, ia.min * b->value};
            }
        } else if (const Broadcast *b = mul->a.as<Broadcast>()) {
            if (is_positive_const(b->value)) {
                Interval ia = bounds_of_lanes(mul->b);
                return {b->value * ia.min, b->value * ia.max};
            } else if (is_negative_const(b->value)) {
                Interval ia = bounds_of_lanes(mul->b);
                return {b->value * ia.max, b->value * ia.min};
            }
        }
    } else if (const Div *div = e.as<Div>()) {
        if (const Broadcast *b = div->b.as<Broadcast>()) {
            if (is_positive_const(b->value)) {
                Interval ia = bounds_of_lanes(div->a);
                return {ia.min / b->value, ia.max / b->value};
            } else if (is_negative_const(b->value)) {
                Interval ia = bounds_of_lanes(div->a);
                return {ia.max / b->value, ia.min / b->value};
            }
        }
    } else if (const And *and_ = e.as<And>()) {
        if (const Broadcast *b = and_->b.as<Broadcast>()) {
            Interval ia = bounds_of_lanes(and_->a);
            return {ia.min && b->value, ia.max && b->value};
        } else if (const Broadcast *b = and_->a.as<Broadcast>()) {
            Interval ia = bounds_of_lanes(and_->b);
            return {ia.min && b->value, ia.max && b->value};
        }
    } else if (const Or *or_ = e.as<Or>()) {
        if (const Broadcast *b = or_->b.as<Broadcast>()) {
            Interval ia = bounds_of_lanes(or_->a);
            return {ia.min || b->value, ia.max || b->value};
        } else if (const Broadcast *b = or_->a.as<Broadcast>()) {
            Interval ia = bounds_of_lanes(or_->b);
            return {ia.min || b->value, ia.max || b->value};
        }
    } else if (const Min *min = e.as<Min>()) {
        if (const Broadcast *b = min->b.as<Broadcast>()) {
            Interval ia = bounds_of_lanes(min->a);
            return {Min::make(ia.min, b->value), Min::make(ia.max, b->value)};
        } else if (const Broadcast *b = min->a.as<Broadcast>()) {
            Interval ia = bounds_of_lanes(min->b);
            return {Min::make(ia.min, b->value), Min::make(ia.max, b->value)};
        }
    } else if (const Max *max = e.as<Max>()) {
        if (const Broadcast *b = max->b.as<Broadcast>()) {
            Interval ia = bounds_of_lanes(max->a);
            return {Max::make(ia.min, b->value), Max::make(ia.max, b->value)};
        } else if (const Broadcast *b = max->a.as<Broadcast>()) {
            Interval ia = bounds_of_lanes(max->b);
            return {Max::make(ia.min, b->value), Max::make(ia.max, b->value)};
        }
    } else if (const Not *not_ = e.as<Not>()) {
        Interval ia = bounds_of_lanes(not_->a);
        return {!ia.max, !ia.min};
    } else if (const Ramp *r = e.as<Ramp>()) {
        Expr last_lane_idx = make_const(r->base.type(), r->lanes - 1);
        if (is_positive_const(r->stride)) {
            return {r->base, r->base + last_lane_idx * r->stride};
        } else if (is_negative_const(r->stride)) {
            return {r->base + last_lane_idx * r->stride, r->base};
        }
    } else if (const Broadcast *b = e.as<Broadcast>()) {
        return {b->value, b->value};
    } else if (const Variable *var = e.as<Variable>()) {
        return {Variable::make(var->type.element_of(), var->name + ".min_lane"),
                Variable::make(var->type.element_of(), var->name + ".max_lane")};
    } else if (const Let *let = e.as<Let>()) {
        Interval ia = bounds_of_lanes(let->value);
        Interval ib = bounds_of_lanes(let->body);
        if (expr_uses_var(ib.min, let->name + ".min_lane")) {
            ib.min = Let::make(let->name + ".min_lane", ia.min, ib.min);
        }
        if (expr_uses_var(ib.max, let->name + ".min_lane")) {
            ib.max = Let::make(let->name + ".min_lane", ia.min, ib.max);
        }
        if (expr_uses_var(ib.min, let->name + ".max_lane")) {
            ib.min = Let::make(let->name + ".max_lane", ia.max, ib.min);
        }
        if (expr_uses_var(ib.max, let->name + ".max_lane")) {
            ib.max = Let::make(let->name + ".max_lane", ia.max, ib.max);
        }
        if (expr_uses_var(ib.min, let->name)) {
            ib.min = Let::make(let->name, let->value, ib.min);
        }
        if (expr_uses_var(ib.max, let->name)) {
            ib.max = Let::make(let->name, let->value, ib.max);
        }
        return ib;
    }

    // Take the explicit min and max over the lanes
    if (e.type().is_bool()) {
        Expr min_lane = VectorReduce::make(VectorReduce::And, e, 1);
        Expr max_lane = VectorReduce::make(VectorReduce::Or, e, 1);
        return {min_lane, max_lane};
    } else {
        Expr min_lane = VectorReduce::make(VectorReduce::Min, e, 1);
        Expr max_lane = VectorReduce::make(VectorReduce::Max, e, 1);
        return {min_lane, max_lane};
    }
}

// A ramp with the lanes repeated (e.g. <0 0 2 2 4 4 6 6>)
// TODO(vksnk): With nested vectorization, this will be representable
// as ramp(broadcast(a, repetitions), broadcast(b, repetitions),
// lanes)
struct InterleavedRamp {
    Expr base, stride;
    int lanes, repetitions;
};
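// For example, <0 0 2 2 4 4 6 6> is
// InterleavedRamp{base = 0, stride = 2, lanes = 8, repetitions = 2}.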

bool is_interleaved_ramp(const Expr &e, const Scope<Expr> &scope, InterleavedRamp *result) {
    if (const Ramp *r = e.as<Ramp>()) {
        result->base = r->base;
        result->stride = r->stride;
        result->lanes = r->lanes;
        result->repetitions = 1;
        return true;
    } else if (const Broadcast *b = e.as<Broadcast>()) {
        result->base = b->value;
        result->stride = 0;
        result->lanes = b->lanes;
        result->repetitions = 0;
        return true;
    } else if (const Add *add = e.as<Add>()) {
        InterleavedRamp ra;
        if (is_interleaved_ramp(add->a, scope, &ra) &&
            is_interleaved_ramp(add->b, scope, result) &&
            (ra.repetitions == 0 ||
             result->repetitions == 0 ||
             ra.repetitions == result->repetitions)) {
            result->base = simplify(result->base + ra.base);
            result->stride = simplify(result->stride + ra.stride);
            if (!result->repetitions) {
                result->repetitions = ra.repetitions;
            }
            return true;
        }
    } else if (const Sub *sub = e.as<Sub>()) {
        InterleavedRamp ra;
        if (is_interleaved_ramp(sub->a, scope, &ra) &&
            is_interleaved_ramp(sub->b, scope, result) &&
            (ra.repetitions == 0 ||
             result->repetitions == 0 ||
             ra.repetitions == result->repetitions)) {
            result->base = simplify(ra.base - result->base);
            result->stride = simplify(ra.stride - result->stride);
            if (!result->repetitions) {
                result->repetitions = ra.repetitions;
            }
            return true;
        }
    } else if (const Mul *mul = e.as<Mul>()) {
        const int64_t *b = nullptr;
        if (is_interleaved_ramp(mul->a, scope, result) &&
            (b = as_const_int(mul->b))) {
            result->base = simplify(result->base * (int)(*b));
            result->stride = simplify(result->stride * (int)(*b));
            return true;
        }
    } else if (const Div *div = e.as<Div>()) {
        const int64_t *b = nullptr;
        if (is_interleaved_ramp(div->a, scope, result) &&
            (b = as_const_int(div->b)) &&
            is_one(result->stride) &&
            (result->repetitions == 1 ||
             result->repetitions == 0) &&
            can_prove((result->base % (int)(*b)) == 0)) {
            // TODO: Generalize this. Currently only matches
            // ramp(base*b, 1, lanes) / b
            // broadcast(base * b, lanes) / b
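            // e.g. ramp(2*y, 1, 8) / 2 becomes
            // InterleavedRamp{y, 1, 8, 2}, i.e. <y y y+1 y+1 y+2 y+2 y+3 y+3>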
            result->base = simplify(result->base / (int)(*b));
            result->repetitions *= (int)(*b);
            return true;
        }
    } else if (const Variable *var = e.as<Variable>()) {
        if (scope.contains(var->name)) {
            return is_interleaved_ramp(scope.get(var->name), scope, result);
        }
    }
    return false;
}

// Allocations inside vectorized loops grow an additional inner
// dimension to represent the separate copy of the allocation per
// vector lane. This means loads and stores to them need to be
// rewritten slightly.
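// For example, with 8 lanes an access to foo[x] becomes an access to
// foo[x*8 + v], where v is the index of the current lane.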
class RewriteAccessToVectorAlloc : public IRMutator {
    Expr var;
    string alloc;
    int lanes;

    using IRMutator::visit;

    Expr mutate_index(const string &a, Expr index) {
        index = mutate(index);
        if (a == alloc) {
            return index * lanes + var;
        } else {
            return index;
        }
    }

    ModulusRemainder mutate_alignment(const string &a, const ModulusRemainder &align) {
        if (a == alloc) {
            return align * lanes;
        } else {
            return align;
        }
    }

    Expr visit(const Load *op) override {
        return Load::make(op->type, op->name, mutate_index(op->name, op->index),
                          op->image, op->param, mutate(op->predicate), mutate_alignment(op->name, op->alignment));
    }

    Stmt visit(const Store *op) override {
        return Store::make(op->name, mutate(op->value), mutate_index(op->name, op->index),
                           op->param, mutate(op->predicate), mutate_alignment(op->name, op->alignment));
    }

public:
    RewriteAccessToVectorAlloc(const string &v, string a, int l)
        : var(Variable::make(Int(32), v)), alloc(std::move(a)), lanes(l) {
    }
};

class UsesGPUVars : public IRVisitor {
private:
    using IRVisitor::visit;
    void visit(const Variable *op) override {
        if (CodeGen_GPU_Dev::is_gpu_var(op->name)) {
            debug(3) << "Found gpu loop var: " << op->name << "\n";
            uses_gpu = true;
        }
    }

public:
    bool uses_gpu = false;
};

bool uses_gpu_vars(const Expr &s) {
    UsesGPUVars uses;
    s.accept(&uses);
    return uses.uses_gpu;
}

// Wrap a vectorized predicate around a Load/Store node.
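// For example, a store guarded by a vector condition can become a
// single predicated store that only writes the lanes for which the
// predicate is true, instead of being scalarized.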
class PredicateLoadStore : public IRMutator {
    string var;
    Expr vector_predicate;
    bool in_hexagon;
    const Target &target;
    int lanes;
    bool valid;
    bool vectorized;

    using IRMutator::visit;

    bool should_predicate_store_load(int bit_size) {
        if (in_hexagon) {
            internal_assert(target.features_any_of({Target::HVX_64, Target::HVX_128}))
                << "We are inside a Hexagon loop, but the target doesn't have Hexagon's features\n";
            return true;
        } else if (target.arch == Target::X86) {
            // Should only attempt to predicate store/load if there are
            // at least 4 lanes of 32 bits each.
            // TODO: disabling for now due to trunk LLVM breakage.
            // See: https://github.com/halide/Halide/issues/3534
            // return (bit_size == 32) && (lanes >= 4);
            return false;
        }
        // For other architectures, do not predicate vector load/store.
        return false;
    }

    Expr merge_predicate(Expr pred, const Expr &new_pred) {
        if (pred.type().lanes() == new_pred.type().lanes()) {
            Expr res = simplify(pred && new_pred);
            return res;
        }
        valid = false;
        return pred;
    }

    Expr visit(const Load *op) override {
        valid = valid && should_predicate_store_load(op->type.bits());
        if (!valid) {
            return op;
        }

        Expr predicate, index;
        if (!op->index.type().is_scalar()) {
            internal_assert(op->predicate.type().lanes() == lanes);
            internal_assert(op->index.type().lanes() == lanes);

            predicate = mutate(op->predicate);
            index = mutate(op->index);
        } else if (expr_uses_var(op->index, var)) {
            predicate = mutate(Broadcast::make(op->predicate, lanes));
            index = mutate(Broadcast::make(op->index, lanes));
        } else {
            return IRMutator::visit(op);
        }

        predicate = merge_predicate(predicate, vector_predicate);
        if (!valid) {
            return op;
        }
        vectorized = true;
        return Load::make(op->type, op->name, index, op->image, op->param, predicate, op->alignment);
    }

    Stmt visit(const Store *op) override {
        valid = valid && should_predicate_store_load(op->value.type().bits());
        if (!valid) {
            return op;
        }

        Expr predicate, value, index;
        if (!op->index.type().is_scalar()) {
            internal_assert(op->predicate.type().lanes() == lanes);
            internal_assert(op->index.type().lanes() == lanes);
            internal_assert(op->value.type().lanes() == lanes);

            predicate = mutate(op->predicate);
            value = mutate(op->value);
            index = mutate(op->index);
        } else if (expr_uses_var(op->index, var)) {
            predicate = mutate(Broadcast::make(op->predicate, lanes));
            value = mutate(Broadcast::make(op->value, lanes));
            index = mutate(Broadcast::make(op->index, lanes));
        } else {
            return IRMutator::visit(op);
        }

        predicate = merge_predicate(predicate, vector_predicate);
        if (!valid) {
            return op;
        }
        vectorized = true;
        return Store::make(op->name, value, index, op->param, predicate, op->alignment);
    }

    Expr visit(const Call *op) override {
        // We should not vectorize calls with side-effects
        valid = valid && op->is_pure();
        return IRMutator::visit(op);
    }

public:
    PredicateLoadStore(string v, const Expr &vpred, bool in_hexagon, const Target &t)
        : var(std::move(v)), vector_predicate(vpred), in_hexagon(in_hexagon), target(t),
          lanes(vpred.type().lanes()), valid(true), vectorized(false) {
        internal_assert(lanes > 1);
    }

    bool is_vectorized() const {
        return valid && vectorized;
    }
};

// Substitutes a vector for a scalar var in a Stmt. Used on the
// body of every vectorized loop.
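// For example, when vectorizing a loop over x with 8 lanes, each use
// of x in the body is replaced with ramp(x.min, 1, 8), and everything
// computed from x is widened to 8 lanes.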
class VectorSubs : public IRMutator {
    // The var we're vectorizing
    string var;

    // What we're replacing it with. Usually a ramp.
    Expr replacement;

    const Target &target;

    bool in_hexagon;  // Are we inside a Hexagon loop?

    // A suffix to attach to widened variables.
    string widening_suffix;

    // A scope containing lets and letstmts whose values became
    // vectors.
    Scope<Expr> scope;

    // The same set of Exprs, indexed by the vectorized var name
    Scope<Expr> vector_scope;

    // A stack of all containing lets. We need to reinject the scalar
    // version of them if we scalarize inner code.
    vector<pair<string, Expr>> containing_lets;

    // Widen an expression to the given number of lanes.
    Expr widen(Expr e, int lanes) {
        if (e.type().lanes() == lanes) {
            return e;
        } else if (e.type().lanes() == 1) {
            return Broadcast::make(e, lanes);
        } else {
            internal_error << "Mismatched vector lanes in VectorSubs\n";
        }
        return Expr();
    }

    using IRMutator::visit;

    Expr visit(const Cast *op) override {
        Expr value = mutate(op->value);
        if (value.same_as(op->value)) {
            return op;
        } else {
            Type t = op->type.with_lanes(value.type().lanes());
            return Cast::make(t, value);
        }
    }

    Expr visit(const Variable *op) override {
        string widened_name = op->name + widening_suffix;
        if (op->name == var) {
            return replacement;
        } else if (scope.contains(op->name)) {
            // If the variable appears in scope then we previously widened
            // it and we use the new widened name for the variable.
            return Variable::make(scope.get(op->name).type(), widened_name);
        } else {
            return op;
        }
    }

    template<typename T>
    Expr mutate_binary_operator(const T *op) {
        Expr a = mutate(op->a), b = mutate(op->b);
        if (a.same_as(op->a) && b.same_as(op->b)) {
            return op;
        } else {
            int w = std::max(a.type().lanes(), b.type().lanes());
            return T::make(widen(a, w), widen(b, w));
        }
    }

    Expr visit(const Add *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const Sub *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const Mul *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const Div *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const Mod *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const Min *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const Max *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const EQ *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const NE *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const LT *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const LE *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const GT *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const GE *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const And *op) override {
        return mutate_binary_operator(op);
    }
    Expr visit(const Or *op) override {
        return mutate_binary_operator(op);
    }

    Expr visit(const Select *op) override {
        Expr condition = mutate(op->condition);
        Expr true_value = mutate(op->true_value);
        Expr false_value = mutate(op->false_value);
        if (condition.same_as(op->condition) &&
            true_value.same_as(op->true_value) &&
            false_value.same_as(op->false_value)) {
            return op;
        } else {
            int lanes = std::max(true_value.type().lanes(), false_value.type().lanes());
            lanes = std::max(lanes, condition.type().lanes());
            // Widen the true and false values, but we don't have to widen the condition
            true_value = widen(true_value, lanes);
            false_value = widen(false_value, lanes);
            return Select::make(condition, true_value, false_value);
        }
    }

    Expr visit(const Load *op) override {
        Expr predicate = mutate(op->predicate);
        Expr index = mutate(op->index);

        if (predicate.same_as(op->predicate) && index.same_as(op->index)) {
            return op;
        } else {
            int w = index.type().lanes();
            predicate = widen(predicate, w);
            return Load::make(op->type.with_lanes(w), op->name, index, op->image,
                              op->param, predicate, op->alignment);
        }
    }

    Expr visit(const Call *op) override {
        // Widen the call by changing the lanes of all of its
        // arguments and its return type
        vector<Expr> new_args(op->args.size());
        bool changed = false;

        // Mutate the args
        int max_lanes = 0;
        for (size_t i = 0; i < op->args.size(); i++) {
            Expr old_arg = op->args[i];
            Expr new_arg = mutate(old_arg);
            if (!new_arg.same_as(old_arg)) changed = true;
            new_args[i] = new_arg;
            max_lanes = std::max(new_arg.type().lanes(), max_lanes);
        }

        if (!changed) {
            return op;
        } else if (op->name == Call::trace) {
            const int64_t *event = as_const_int(op->args[6]);
            internal_assert(event != nullptr);
            if (*event == halide_trace_begin_realization || *event == halide_trace_end_realization) {
                // Call::trace vectorizes uniquely for begin/end realization, because the coordinates
                // for these are actually min/extent pairs; we need to maintain the proper dimensionality
                // count and instead aggregate the widened values into a single pair.
                for (size_t i = 1; i <= 2; i++) {
                    const Call *call = new_args[i].as<Call>();
                    internal_assert(call && call->is_intrinsic(Call::make_struct));
                    if (i == 1) {
                        // values should always be empty for these events
                        internal_assert(call->args.empty());
                        continue;
                    }
                    vector<Expr> call_args(call->args.size());
                    for (size_t j = 0; j < call_args.size(); j += 2) {
                        Expr min_v = widen(call->args[j], max_lanes);
                        Expr extent_v = widen(call->args[j + 1], max_lanes);
                        Expr min_scalar = extract_lane(min_v, 0);
                        Expr max_scalar = min_scalar + extract_lane(extent_v, 0);
                        for (int k = 1; k < max_lanes; ++k) {
                            Expr min_k = extract_lane(min_v, k);
                            Expr extent_k = extract_lane(extent_v, k);
                            min_scalar = min(min_scalar, min_k);
                            max_scalar = max(max_scalar, min_k + extent_k);
                        }
                        call_args[j] = min_scalar;
                        call_args[j + 1] = max_scalar - min_scalar;
                    }
                    new_args[i] = Call::make(call->type.element_of(), Call::make_struct, call_args, Call::Intrinsic);
                }
            } else {
                // Call::trace vectorizes uniquely, because we want a
                // single trace call for the entire vector, instead of
                // scalarizing the call and tracing each element.
                for (size_t i = 1; i <= 2; i++) {
                    // Each struct should be a struct-of-vectors, not a
                    // vector of distinct structs.
                    const Call *call = new_args[i].as<Call>();
                    internal_assert(call && call->is_intrinsic(Call::make_struct));
                    // Widen the call args to have the same lanes as the max lanes found
                    vector<Expr> call_args(call->args.size());
                    for (size_t j = 0; j < call_args.size(); j++) {
                        call_args[j] = widen(call->args[j], max_lanes);
                    }
                    new_args[i] = Call::make(call->type.element_of(), Call::make_struct,
                                             call_args, Call::Intrinsic);
                }
                // One of the arguments to the trace helper
                // records the number of vector lanes in the type being
                // stored.
                new_args[5] = max_lanes;
                // One of the arguments to the trace helper
                // records the number of entries in the coordinates (which we just widened)
                if (max_lanes > 1) {
                    new_args[9] = new_args[9] * max_lanes;
                }
            }
            return Call::make(op->type, Call::trace, new_args, op->call_type);
        } else {
            // Widen the args to have the same lanes as the max lanes found
            for (size_t i = 0; i < new_args.size(); i++) {
                new_args[i] = widen(new_args[i], max_lanes);
            }
            return Call::make(op->type.with_lanes(max_lanes), op->name, new_args,
                              op->call_type, op->func, op->value_index, op->image, op->param);
        }
    }

    Expr visit(const Let *op) override {

        // Vectorize the let value and check to see if it was vectorized by
        // this mutator. The type of the expression might already be a
        // vector width.
        Expr mutated_value = mutate(op->value);
        bool was_vectorized = (!op->value.type().is_vector() &&
                               mutated_value.type().is_vector());

        // If the value was vectorized by this mutator, add a new name to
        // the scope for the vectorized value expression.
        string vectorized_name;
        if (was_vectorized) {
            vectorized_name = op->name + widening_suffix;
            scope.push(op->name, mutated_value);
            vector_scope.push(vectorized_name, mutated_value);
        }

        Expr mutated_body = mutate(op->body);

        InterleavedRamp ir;
        if (is_interleaved_ramp(mutated_value, vector_scope, &ir)) {
            return substitute(vectorized_name, mutated_value, mutated_body);
        } else if (mutated_value.same_as(op->value) &&
                   mutated_body.same_as(op->body)) {
            return op;
        } else if (was_vectorized) {
            scope.pop(op->name);
            vector_scope.pop(vectorized_name);
            return Let::make(vectorized_name, mutated_value, mutated_body);
        } else {
            return Let::make(op->name, mutated_value, mutated_body);
        }
    }

    Stmt visit(const LetStmt *op) override {
        Expr mutated_value = mutate(op->value);
        string mutated_name = op->name;

        // Check if the value was vectorized by this mutator.
        bool was_vectorized = (!op->value.type().is_vector() &&
                               mutated_value.type().is_vector());

        if (was_vectorized) {
            mutated_name += widening_suffix;
            scope.push(op->name, mutated_value);
            vector_scope.push(mutated_name, mutated_value);
            // Also keep track of the original let, in case inner code scalarizes.
            containing_lets.emplace_back(op->name, op->value);
        }

        Stmt mutated_body = mutate(op->body);

        if (was_vectorized) {
            containing_lets.pop_back();
            scope.pop(op->name);
            vector_scope.pop(mutated_name);

            // Inner code might have extracted my lanes using
            // extract_lane, which introduces a shuffle_vector. If
            // so we should define separate lets for the lanes and
            // get it to use those instead.
            mutated_body = ReplaceShuffleVectors(mutated_name).mutate(mutated_body);

            // Check if inner code wants my individual lanes.
            Type t = mutated_value.type();
            for (int i = 0; i < t.lanes(); i++) {
                string lane_name = mutated_name + ".lane." + std::to_string(i);
                if (stmt_uses_var(mutated_body, lane_name)) {
                    mutated_body =
                        LetStmt::make(lane_name, extract_lane(mutated_value, i), mutated_body);
                }
            }

            // Inner code may also have wanted my max or min lane
            bool uses_min_lane = stmt_uses_var(mutated_body, mutated_name + ".min_lane");
            bool uses_max_lane = stmt_uses_var(mutated_body, mutated_name + ".max_lane");

            if (uses_min_lane || uses_max_lane) {
                Interval i = bounds_of_lanes(mutated_value);

                if (uses_min_lane) {
                    mutated_body =
                        LetStmt::make(mutated_name + ".min_lane", i.min, mutated_body);
                }

                if (uses_max_lane) {
                    mutated_body =
                        LetStmt::make(mutated_name + ".max_lane", i.max, mutated_body);
                }
            }
        }

        InterleavedRamp ir;
        if (is_interleaved_ramp(mutated_value, vector_scope, &ir)) {
            return substitute(mutated_name, mutated_value, mutated_body);
        } else if (mutated_value.same_as(op->value) &&
                   mutated_body.same_as(op->body)) {
            return op;
        } else {
            return LetStmt::make(mutated_name, mutated_value, mutated_body);
        }
    }

    Stmt visit(const Provide *op) override {
        vector<Expr> new_args(op->args.size());
        vector<Expr> new_values(op->values.size());
        bool changed = false;

        // Mutate the args
        int max_lanes = 0;
        for (size_t i = 0; i < op->args.size(); i++) {
            Expr old_arg = op->args[i];
            Expr new_arg = mutate(old_arg);
            if (!new_arg.same_as(old_arg)) changed = true;
            new_args[i] = new_arg;
            max_lanes = std::max(new_arg.type().lanes(), max_lanes);
        }

        for (size_t i = 0; i < op->values.size(); i++) {
            Expr old_value = op->values[i];
            Expr new_value = mutate(old_value);
            if (!new_value.same_as(old_value)) changed = true;
            new_values[i] = new_value;
            max_lanes = std::max(new_value.type().lanes(), max_lanes);
        }

        if (!changed) {
            return op;
        } else {
            // Widen the args to have the same lanes as the max lanes found
            for (size_t i = 0; i < new_args.size(); i++) {
                new_args[i] = widen(new_args[i], max_lanes);
            }
            for (size_t i = 0; i < new_values.size(); i++) {
                new_values[i] = widen(new_values[i], max_lanes);
            }
            return Provide::make(op->name, new_values, new_args);
        }
    }

    Stmt visit(const Store *op) override {
        Expr predicate = mutate(op->predicate);
        Expr value = mutate(op->value);
        Expr index = mutate(op->index);

        if (predicate.same_as(op->predicate) && value.same_as(op->value) && index.same_as(op->index)) {
            return op;
        } else {
            int lanes = std::max(predicate.type().lanes(), std::max(value.type().lanes(), index.type().lanes()));
            return Store::make(op->name, widen(value, lanes), widen(index, lanes),
                               op->param, widen(predicate, lanes), op->alignment);
        }
    }

    Stmt visit(const AssertStmt *op) override {
        return (op->condition.type().lanes() > 1) ? scalarize(op) : op;
    }

    Stmt visit(const IfThenElse *op) override {
        Expr cond = mutate(op->condition);
        int lanes = cond.type().lanes();
        debug(3) << "Vectorizing over " << var << "\n"
                 << "Old: " << op->condition << "\n"
                 << "New: " << cond << "\n";

        Stmt then_case = mutate(op->then_case);
        Stmt else_case = mutate(op->else_case);

        if (lanes > 1) {
            // We have an if statement with a vector condition,
            // which would mean control flow divergence within the
            // SIMD lanes.

            bool vectorize_predicate = !uses_gpu_vars(cond);
            Stmt predicated_stmt;
            if (vectorize_predicate) {
                PredicateLoadStore p(var, cond, in_hexagon, target);
                predicated_stmt = p.mutate(then_case);
                vectorize_predicate = p.is_vectorized();
            }
            if (vectorize_predicate && else_case.defined()) {
                PredicateLoadStore p(var, !cond, in_hexagon, target);
                predicated_stmt = Block::make(predicated_stmt, p.mutate(else_case));
                vectorize_predicate = p.is_vectorized();
            }

            debug(4) << "IfThenElse should vectorize predicate over var " << var << "? " << vectorize_predicate << "; cond: " << cond << "\n";
            debug(4) << "Predicated stmt:\n"
                     << predicated_stmt << "\n";

            // First check if the condition is marked as likely.
            const Call *c = cond.as<Call>();
            if (c && (c->is_intrinsic(Call::likely) ||
                      c->is_intrinsic(Call::likely_if_innermost))) {

                // The meaning of the likely intrinsic is that
                // Halide should optimize for the case in which
                // *every* likely value is true. We can do that by
                // generating a scalar condition that checks if
                // the least-true lane is true.
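                // (For a boolean vector the minimum over the lanes is
                // the conjunction of all the lanes, so all_true holds
                // only when every lane is true.)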
                Expr all_true = bounds_of_lanes(c->args[0]).min;

                // Wrap it in the same flavor of likely
                all_true = Call::make(Bool(), c->name,
                                      {all_true}, Call::PureIntrinsic);

                if (!vectorize_predicate) {
                    // We should strip the likelies from the case
                    // that's going to scalarize, because it's no
                    // longer likely.
                    Stmt without_likelies =
                        IfThenElse::make(op->condition.as<Call>()->args[0],
                                         op->then_case, op->else_case);
                    Stmt stmt =
                        IfThenElse::make(all_true,
                                         then_case,
                                         scalarize(without_likelies));
                    debug(4) << "...With all_true likely: \n"
                             << stmt << "\n";
                    return stmt;
                } else {
                    Stmt stmt =
                        IfThenElse::make(all_true,
                                         then_case,
                                         predicated_stmt);
                    debug(4) << "...Predicated IfThenElse: \n"
                             << stmt << "\n";
                    return stmt;
                }
            } else {
                // It's some arbitrary vector condition.
                if (!vectorize_predicate) {
                    debug(4) << "...Scalarizing vector predicate: \n"
                             << Stmt(op) << "\n";
                    return scalarize(op);
                } else {
                    Stmt stmt = predicated_stmt;
                    debug(4) << "...Predicated IfThenElse: \n"
                             << stmt << "\n";
                    return stmt;
                }
            }
        } else {
            // It's an if statement on a scalar; we're OK to vectorize the innards.
            debug(3) << "Not scalarizing if then else\n";
            if (cond.same_as(op->condition) &&
                then_case.same_as(op->then_case) &&
                else_case.same_as(op->else_case)) {
                return op;
            } else {
                return IfThenElse::make(cond, then_case, else_case);
            }
        }
    }

    Stmt visit(const For *op) override {
        ForType for_type = op->for_type;
        if (for_type == ForType::Vectorized) {
            user_warning << "Warning: Encountered vector for loop over " << op->name
                         << " inside vector for loop over " << var << "."
                         << " Ignoring the vectorize directive for the inner for loop.\n";
            for_type = ForType::Serial;
        }

        Expr min = mutate(op->min);
        Expr extent = mutate(op->extent);

        Stmt body = op->body;

        if (min.type().is_vector()) {
            // Rebase the loop to zero and try again
            Expr var = Variable::make(Int(32), op->name);
            Stmt body = substitute(op->name, var + op->min, op->body);
            Stmt transformed = For::make(op->name, 0, op->extent, for_type, op->device_api, body);
            return mutate(transformed);
        }

        if (extent.type().is_vector()) {
            // We'll iterate up to the max over the lanes, but
            // inject an if statement inside the loop that stops
            // each lane from going too far.

            extent = bounds_of_lanes(extent).max;
            Expr var = Variable::make(Int(32), op->name);
            body = IfThenElse::make(likely(var < op->min + op->extent), body);
        }

        body = mutate(body);

        if (min.same_as(op->min) &&
            extent.same_as(op->extent) &&
            body.same_as(op->body) &&
            for_type == op->for_type) {
            return op;
        } else {
            return For::make(op->name, min, extent, for_type, op->device_api, body);
        }
    }

    Stmt visit(const Allocate *op) override {
        vector<Expr> new_extents;
        Expr new_expr;

        int lanes = replacement.type().lanes();

        // The new expanded dimension is innermost.
        new_extents.emplace_back(lanes);

        for (size_t i = 0; i < op->extents.size(); i++) {
            Expr extent = mutate(op->extents[i]);
            // For vector sizes, take the max over the lanes. Note
            // that we haven't changed the strides, which also may
            // vary per lane. This is a bit weird, but the way we
            // set up the vectorized memory means that lanes can't
            // clobber each others' memory, so it doesn't matter.
            if (extent.type().is_vector()) {
                extent = bounds_of_lanes(extent).max;
            }
            new_extents.push_back(extent);
        }

        if (op->new_expr.defined()) {
            new_expr = mutate(op->new_expr);
            user_assert(new_expr.type().is_scalar())
                << "Cannot vectorize an allocation with a varying new_expr per vector lane.\n";
        }

        Stmt body = op->body;

        // Rewrite loads and stores to this allocation like so:
        // foo[x] -> foo[x*lanes + v]
        string v = unique_name('v');
        body = RewriteAccessToVectorAlloc(v, op->name, lanes).mutate(body);

        scope.push(v, Ramp::make(0, 1, lanes));
        body = mutate(body);
        scope.pop(v);

        // Replace the widened 'v' with the actual ramp
        // foo[x*lanes + widened_v] -> foo[x*lanes + ramp(0, 1, lanes)]
        body = substitute(v + widening_suffix, Ramp::make(0, 1, lanes), body);

        // The variable itself could still exist inside an inner scalarized block.
        body = substitute(v, Variable::make(Int(32), var), body);

        return Allocate::make(op->name, op->type, op->memory_type, new_extents, op->condition, body, new_expr, op->free_function);
    }

    Stmt visit(const Atomic *op) override {
        // Recognize a few special cases that we can handle as within-vector reduction trees.
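        // For example, if the store index is a scalar g, an atomic
        // f[g] = f[g] + y, where y became a vector, turns into
        // f[g] = f[g] + vector_reduce(Add, y), which needs no
        // per-lane atomicity.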
        do {
            if (!op->mutex_name.empty()) {
                // We can't vectorize over a mutex
                break;
            }

            // f[x] = f[x] <op> y
            const Store *store = op->body.as<Store>();
            if (!store) break;

            VectorReduce::Operator reduce_op = VectorReduce::Add;
            Expr a, b;
            if (const Add *add = store->value.as<Add>()) {
                a = add->a;
                b = add->b;
                reduce_op = VectorReduce::Add;
            } else if (const Mul *mul = store->value.as<Mul>()) {
                a = mul->a;
                b = mul->b;
                reduce_op = VectorReduce::Mul;
            } else if (const Min *min = store->value.as<Min>()) {
                a = min->a;
                b = min->b;
                reduce_op = VectorReduce::Min;
            } else if (const Max *max = store->value.as<Max>()) {
                a = max->a;
                b = max->b;
                reduce_op = VectorReduce::Max;
            } else if (const Cast *cast_op = store->value.as<Cast>()) {
                if (cast_op->type.element_of() == UInt(8) &&
                    cast_op->value.type().is_bool()) {
                    if (const And *and_op = cast_op->value.as<And>()) {
                        a = and_op->a;
                        b = and_op->b;
                        reduce_op = VectorReduce::And;
                    } else if (const Or *or_op = cast_op->value.as<Or>()) {
                        a = or_op->a;
                        b = or_op->b;
                        reduce_op = VectorReduce::Or;
                    }
                }
            }

            if (!a.defined() || !b.defined()) {
                break;
            }

            // Bools get cast to uint8 for storage. Strip off that
            // cast around any load.
            if (b.type().is_bool()) {
                const Cast *cast_op = b.as<Cast>();
                if (cast_op) {
                    b = cast_op->value;
                }
            }
            if (a.type().is_bool()) {
                const Cast *cast_op = a.as<Cast>();
                if (cast_op) {
                    a = cast_op->value;
                }
            }

            if (a.as<Variable>() && !b.as<Variable>()) {
                std::swap(a, b);
            }

            // We require b to be a var, because it should have been lifted.
            const Variable *var_b = b.as<Variable>();
            const Load *load_a = a.as<Load>();

            if (!var_b ||
                !scope.contains(var_b->name) ||
                !load_a ||
                load_a->name != store->name ||
                !is_one(load_a->predicate) ||
                !is_one(store->predicate)) {
                break;
            }

            b = scope.get(var_b->name);
            Expr store_index = mutate(store->index);
            Expr load_index = mutate(load_a->index);

            // The load and store indices must be the same interleaved
            // ramp (or the same scalar, in the total reduction case).
            InterleavedRamp store_ir, load_ir;
            Expr test;
            if (store_index.type().is_scalar()) {
                test = simplify(load_index == store_index);
            } else if (is_interleaved_ramp(store_index, vector_scope, &store_ir) &&
                       is_interleaved_ramp(load_index, vector_scope, &load_ir) &&
                       store_ir.repetitions == load_ir.repetitions &&
                       store_ir.lanes == load_ir.lanes) {
                test = simplify(store_ir.base == load_ir.base &&
                                store_ir.stride == load_ir.stride);
            }

            if (!test.defined()) {
                break;
            }

            if (is_zero(test)) {
                break;
            } else if (!is_one(test)) {
                // TODO: try harder by substituting in more things in scope
                break;
            }

            int output_lanes = 1;
            if (store_index.type().is_scalar()) {
                // The index doesn't depend on the value being
                // vectorized, so it's a total reduction.

                b = VectorReduce::make(reduce_op, b, 1);
            } else {

                output_lanes = store_index.type().lanes() / store_ir.repetitions;

                store_index = Ramp::make(store_ir.base, store_ir.stride, output_lanes);
                b = VectorReduce::make(reduce_op, b, output_lanes);
            }

            Expr new_load = Load::make(load_a->type.with_lanes(output_lanes),
                                       load_a->name, store_index, load_a->image,
                                       load_a->param, const_true(output_lanes),
                                       ModulusRemainder{});

            switch (reduce_op) {
            case VectorReduce::Add:
                b = new_load + b;
                break;
            case VectorReduce::Mul:
                b = new_load * b;
                break;
            case VectorReduce::Min:
                b = min(new_load, b);
                break;
            case VectorReduce::Max:
                b = max(new_load, b);
                break;
            case VectorReduce::And:
                b = cast(new_load.type(), cast(b.type(), new_load) && b);
                break;
            case VectorReduce::Or:
                b = cast(new_load.type(), cast(b.type(), new_load) || b);
                break;
            }

            Stmt s = Store::make(store->name, b, store_index, store->param,
                                 const_true(b.type().lanes()), store->alignment);

            // We may still need the atomic node, if there was more
            // parallelism than just the vectorization.
            s = Atomic::make(op->producer_name, op->mutex_name, s);

            return s;
        } while (0);

        // In the general case, if a whole stmt has to be done
        // atomically, we need to serialize.
        return scalarize(op);
    }

    Stmt scalarize(Stmt s) {
        // Wrap a serial loop around it. Maybe LLVM will have
        // better luck vectorizing it.

        // We'll need the original scalar versions of any containing lets.
        for (size_t i = containing_lets.size(); i > 0; i--) {
            const auto &l = containing_lets[i - 1];
            s = LetStmt::make(l.first, l.second, s);
        }

        const Ramp *r = replacement.as<Ramp>();
        internal_assert(r) << "Expected replacement in VectorSubs to be a ramp\n";
        return For::make(var, r->base, r->lanes, ForType::Serial, DeviceAPI::None, s);
    }

    Expr scalarize(Expr e) {
        // This method returns a select tree that produces a
        // vector-wide result expression.
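        //
        // For lanes == 4 the tree has the form:
        //   select(replacement == broadcast(0), broadcast(e0),
        //   select(replacement == broadcast(1), broadcast(e1),
        //   select(replacement == broadcast(2), broadcast(e2),
        //          broadcast(e3))))
        // where e_i is e with lane i substituted for the vectorized var.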

        Expr result;
        int lanes = replacement.type().lanes();

        for (int i = lanes - 1; i >= 0; --i) {
            // Hide all the vector let values in scope with a scalar version
            // in the appropriate lane.
            for (Scope<Expr>::const_iterator iter = scope.cbegin(); iter != scope.cend(); ++iter) {
                string name = iter.name() + ".lane." + std::to_string(i);
                Expr lane = extract_lane(iter.value(), i);
                e = substitute(iter.name(), Variable::make(lane.type(), name), e);
            }

            // Replace uses of the vectorized variable with the extracted
            // lane expression
            e = substitute(var, i, e);

            if (i == lanes - 1) {
                result = Broadcast::make(e, lanes);
            } else {
                Expr cond = (replacement == Broadcast::make(i, lanes));
                result = Select::make(cond, Broadcast::make(e, lanes), result);
            }
        }

        return result;
    }

public:
    VectorSubs(string v, Expr r, bool in_hexagon, const Target &t)
        : var(std::move(v)), replacement(std::move(r)), target(t), in_hexagon(in_hexagon) {
        widening_suffix = ".x" + std::to_string(replacement.type().lanes());
    }
};  // class VectorSubs

class FindVectorizableExprsInAtomicNode : public IRMutator {
    // An Atomic node protects all accesses to a given buffer. We
    // consider a name "poisoned" if it depends on an access to this
    // buffer. We can't lift or vectorize anything that has been
    // poisoned.
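    // For example, inside atomic(f), a let y = f[x] + 1 poisons y:
    // neither y nor anything computed from it may be lifted out of
    // the atomic node.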
    Scope<> poisoned_names;
    bool poison = false;

    using IRMutator::visit;

    template<typename T>
    const T *visit_let(const T *op) {
        mutate(op->value);
        ScopedBinding<> bind_if(poison, poisoned_names, op->name);
        mutate(op->body);
        return op;
    }

    Stmt visit(const LetStmt *op) override {
        return visit_let(op);
    }

    Expr visit(const Let *op) override {
        return visit_let(op);
    }

    Expr visit(const Load *op) override {
        // Even if the load is bad, maybe we can lift the index
        IRMutator::visit(op);

        poison |= poisoned_names.contains(op->name);
        return op;
    }

    Expr visit(const Variable *op) override {
        poison = poisoned_names.contains(op->name);
        return op;
    }

    Stmt visit(const Store *op) override {
        // A store poisons all subsequent loads, but loads before the
        // first store can be lifted.
        mutate(op->index);
        mutate(op->value);
        poisoned_names.push(op->name);
        return op;
    }

    Expr visit(const Call *op) override {
        IRMutator::visit(op);
        poison |= !op->is_pure();
        return op;
    }

public:
    using IRMutator::mutate;

    Expr mutate(const Expr &e) override {
        bool old_poison = poison;
        poison = false;
        IRMutator::mutate(e);
        if (!poison) {
            liftable.insert(e);
        }
        poison |= old_poison;
        // We're not actually mutating anything. This class is only a
        // mutator so that we can override a generic mutate() method.
        return e;
    }

    FindVectorizableExprsInAtomicNode(const string &buf, const map<string, Function> &env) {
        poisoned_names.push(buf);
        auto it = env.find(buf);
        if (it != env.end()) {
            // Handle tuples
            size_t n = it->second.values().size();
            if (n > 1) {
                for (size_t i = 0; i < n; i++) {
                    poisoned_names.push(buf + "." + std::to_string(i));
                }
            }
        }
    }

    std::set<Expr, ExprCompare> liftable;
};

class LiftVectorizableExprsOutOfSingleAtomicNode : public IRMutator {
    const std::set<Expr, ExprCompare> &liftable;

    using IRMutator::visit;

    template<typename StmtOrExpr, typename LetStmtOrLet>
    StmtOrExpr visit_let(const LetStmtOrLet *op) {
        if (liftable.count(op->value)) {
            // Lift it under its current name to avoid having to
            // rewrite the variables in other lifted exprs.
            // TODO: duplicate non-overlapping liftable let stmts due to unrolling.
            lifted.emplace_back(op->name, op->value);
            return mutate(op->body);
        } else {
            return IRMutator::visit(op);
        }
    }

    Stmt visit(const LetStmt *op) override {
        return visit_let<Stmt>(op);
    }

    Expr visit(const Let *op) override {
        return visit_let<Expr>(op);
    }

public:
    map<Expr, string, IRDeepCompare> already_lifted;
    vector<pair<string, Expr>> lifted;

    using IRMutator::mutate;

    Expr mutate(const Expr &e) override {
        if (liftable.count(e) && !is_const(e) && !e.as<Variable>()) {
            auto it = already_lifted.find(e);
            string name;
            if (it != already_lifted.end()) {
                name = it->second;
            } else {
                name = unique_name('t');
                lifted.emplace_back(name, e);
                already_lifted.emplace(e, name);
            }
            return Variable::make(e.type(), name);
        } else {
            return IRMutator::mutate(e);
        }
    }

    LiftVectorizableExprsOutOfSingleAtomicNode(const std::set<Expr, ExprCompare> &liftable)
        : liftable(liftable) {
    }
};

class LiftVectorizableExprsOutOfAllAtomicNodes : public IRMutator {
    using IRMutator::visit;

    Stmt visit(const Atomic *op) override {
        FindVectorizableExprsInAtomicNode finder(op->producer_name, env);
        finder.mutate(op->body);
        LiftVectorizableExprsOutOfSingleAtomicNode lifter(finder.liftable);
        Stmt new_body = lifter.mutate(op->body);
        new_body = Atomic::make(op->producer_name, op->mutex_name, new_body);
        while (!lifter.lifted.empty()) {
            auto p = lifter.lifted.back();
            new_body = LetStmt::make(p.first, p.second, new_body);
            lifter.lifted.pop_back();
        }
        return new_body;
    }

    const map<string, Function> &env;

public:
    LiftVectorizableExprsOutOfAllAtomicNodes(const map<string, Function> &env)
        : env(env) {
    }
};

// Vectorize all loops marked as such in a Stmt
class VectorizeLoops : public IRMutator {
    const Target &target;
    bool in_hexagon;

    using IRMutator::visit;

    Stmt visit(const For *for_loop) override {
        bool old_in_hexagon = in_hexagon;
        if (for_loop->device_api == DeviceAPI::Hexagon) {
            in_hexagon = true;
        }

        Stmt stmt;
        if (for_loop->for_type == ForType::Vectorized) {
            const IntImm *extent = for_loop->extent.as<IntImm>();
            if (!extent || extent->value <= 1) {
                user_error << "Loop over " << for_loop->name
                           << " has extent " << for_loop->extent
                           << ". Can only vectorize loops over a "
                           << "constant extent > 1\n";
            }

            // Replace the var with a ramp within the body
            Expr for_var = Variable::make(Int(32), for_loop->name);
            Expr replacement = Ramp::make(for_loop->min, 1, extent->value);
            stmt = VectorSubs(for_loop->name, replacement, in_hexagon, target).mutate(for_loop->body);
        } else {
            stmt = IRMutator::visit(for_loop);
        }

        if (for_loop->device_api == DeviceAPI::Hexagon) {
            in_hexagon = old_in_hexagon;
        }

        return stmt;
    }

public:
    VectorizeLoops(const Target &t)
        : target(t), in_hexagon(false) {
    }
};

/** Check if all stores in a Stmt are to names in a given scope. Used
 * by RemoveUnnecessaryAtomics below. */
class AllStoresInScope : public IRVisitor {
    using IRVisitor::visit;
    void visit(const Store *op) override {
        result = result && s.contains(op->name);
    }

public:
    bool result = true;
    const Scope<> &s;
    AllStoresInScope(const Scope<> &s)
        : s(s) {
    }
};
bool all_stores_in_scope(const Stmt &stmt, const Scope<> &scope) {
    AllStoresInScope checker(scope);
    stmt.accept(&checker);
    return checker.result;
}

/** Drop any atomic nodes protecting buffers that are only accessed
 * from a single thread. */
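// For example, an allocation made inside the body of a parallel loop
// is private to that thread, so an atomic node that only stores to it
// can be dropped.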
class RemoveUnnecessaryAtomics : public IRMutator {
    using IRMutator::visit;

    // Allocations made from within this same thread
    bool in_thread = false;
    Scope<> local_allocs;

    Stmt visit(const Allocate *op) override {
        ScopedBinding<> bind(local_allocs, op->name);
        return IRMutator::visit(op);
    }

    Stmt visit(const Atomic *op) override {
        if (!in_thread || all_stores_in_scope(op->body, local_allocs)) {
            return mutate(op->body);
        } else {
            return op;
        }
    }

    Stmt visit(const For *op) override {
        if (is_parallel(op->for_type)) {
            ScopedValue<bool> old_in_thread(in_thread, true);
            Scope<> old_local_allocs;
            old_local_allocs.swap(local_allocs);
            Stmt s = IRMutator::visit(op);
            old_local_allocs.swap(local_allocs);
            return s;
        } else {
            return IRMutator::visit(op);
        }
    }
};

}  // namespace

Stmt vectorize_loops(const Stmt &stmt, const map<string, Function> &env, const Target &t) {
    // Limit the scope of atomic nodes to just the necessary stuff.
    // TODO: Should this be an earlier pass? It's probably a good idea
    // for non-vectorizing stuff too.
    Stmt s = LiftVectorizableExprsOutOfAllAtomicNodes(env).mutate(stmt);
    s = VectorizeLoops(t).mutate(s);
    s = RemoveUnnecessaryAtomics().mutate(s);
    return s;
}

}  // namespace Internal
}  // namespace Halide