1 #include "Deinterleave.h"
2
3 #include "CSE.h"
4 #include "Debug.h"
5 #include "IREquality.h"
6 #include "IRMutator.h"
7 #include "IROperator.h"
8 #include "IRPrinter.h"
9 #include "ModulusRemainder.h"
10 #include "Scope.h"
11 #include "Simplify.h"
12 #include "Substitute.h"
13
14 namespace Halide {
15 namespace Internal {
16
17 using std::pair;
18
19 namespace {
20
// Collects a consecutive run of stores to a single buffer whose indices
// are ramps with a fixed constant stride, so the caller can fuse them
// into one interleaving store. Each collected store is replaced by a
// no-op in the mutated stmt and appended to 'stores'; any LetStmts the
// collected stores might depend on are appended to 'let_stmts'.
class StoreCollector : public IRMutator {
public:
    const std::string store_name;        // Only stores to this buffer are collected.
    const int store_stride, max_stores;  // Required ramp stride; cap on number collected.
    std::vector<Stmt> &let_stmts;        // Output: lets needed by the collected stores.
    std::vector<Stmt> &stores;           // Output: the collected Store nodes.

    StoreCollector(const std::string &name, int stride, int ms,
                   std::vector<Stmt> &lets, std::vector<Stmt> &ss)
        : store_name(name), store_stride(stride), max_stores(ms),
          let_stmts(lets), stores(ss), collecting(true) {
    }

private:
    using IRMutator::visit;

    // Don't enter any inner constructs for which it's not safe to pull out stores.
    Stmt visit(const For *op) override {
        collecting = false;
        return op;
    }
    Stmt visit(const IfThenElse *op) override {
        collecting = false;
        return op;
    }
    Stmt visit(const ProducerConsumer *op) override {
        collecting = false;
        return op;
    }
    Stmt visit(const Allocate *op) override {
        collecting = false;
        return op;
    }
    Stmt visit(const Realize *op) override {
        collecting = false;
        return op;
    }

    // True while it is still safe to keep hoisting stores forward.
    // Once cleared it is never set again, so the mutator passes the
    // rest of the IR through untouched.
    bool collecting;
    // These are lets that we've encountered since the last collected
    // store. If we collect another store, these "potential" lets
    // become lets used by the collected stores.
    std::vector<Stmt> potential_lets;

    Expr visit(const Load *op) override {
        if (!collecting) {
            return op;
        }

        // If we hit a load from the buffer we're trying to collect
        // stores for, stop collecting to avoid reordering loads and
        // stores from the same buffer.
        if (op->name == store_name) {
            collecting = false;
            return op;
        } else {
            return IRMutator::visit(op);
        }
    }

    Stmt visit(const Store *op) override {
        if (!collecting) {
            return op;
        }

        // By default, do nothing.
        Stmt stmt = op;

        if (stores.size() >= (size_t)max_stores) {
            // Already have enough stores.
            collecting = false;
            return stmt;
        }

        // Make sure this Store doesn't do anything that causes us to
        // stop collecting (e.g. an impure call, or a load from the
        // same buffer inside its index or value).
        stmt = IRMutator::visit(op);
        if (!collecting) {
            return stmt;
        }

        if (op->name != store_name) {
            // Not a store to the buffer we're looking for.
            return stmt;
        }

        const Ramp *r = op->index.as<Ramp>();
        if (!r || !is_const(r->stride, store_stride)) {
            // Store doesn't store to the ramp we're looking
            // for. Can't interleave it. Since we don't want to
            // reorder stores, stop collecting.
            collecting = false;
            return stmt;
        }

        // This store is good, collect it and replace with a no-op.
        stores.emplace_back(op);
        stmt = Evaluate::make(0);

        // Because we collected this store, we need to save the
        // potential lets since the last collected store.
        let_stmts.insert(let_stmts.end(), potential_lets.begin(), potential_lets.end());
        potential_lets.clear();
        return stmt;
    }

    Expr visit(const Call *op) override {
        if (!op->is_pure()) {
            // Avoid reordering calls to impure functions
            collecting = false;
            return op;
        } else {
            return IRMutator::visit(op);
        }
    }

    Stmt visit(const LetStmt *op) override {
        if (!collecting) {
            return op;
        }

        // Walk inside the let chain
        Stmt stmt = IRMutator::visit(op);

        // If we're still collecting, we need to save the entire let chain as potential lets.
        // Note we record the original (un-mutated) LetStmt nodes.
        if (collecting) {
            Stmt body;
            do {
                potential_lets.emplace_back(op);
                body = op->body;
            } while ((op = body.as<LetStmt>()));
        }
        return stmt;
    }

    Stmt visit(const Block *op) override {
        if (!collecting) {
            return op;
        }

        Stmt first = mutate(op->first);
        Stmt rest = op->rest;
        // We might have decided to stop collecting during mutation of first.
        if (collecting) {
            rest = mutate(rest);
        }
        return Block::make(first, rest);
    }
};
170
// Replace up to 'max_stores' stores to buffer 'name' with stride-'stride'
// ramp indices by no-ops in 'stmt', returning the removed stores via
// 'stores'. The returned Stmt is the mutated input.
// NOTE(review): 'lets' is passed by value, so any lets the collector
// appends to it inside StoreCollector are discarded here — only 'stores'
// propagates back to the caller. Confirm this is intentional before
// changing the signature.
Stmt collect_strided_stores(const Stmt &stmt, const std::string &name, int stride, int max_stores,
                            std::vector<Stmt> lets, std::vector<Stmt> &stores) {

    StoreCollector collect(name, stride, max_stores, lets, stores);
    return collect.mutate(stmt);
}
177
// Extracts every lane_stride-th lane of a vector expression starting at
// starting_lane, producing a new_lanes-wide result. Where possible it
// pushes the lane selection down into the expression (ramps, broadcasts,
// loads, elementwise calls) rather than emitting a shuffle.
class Deinterleaver : public IRGraphMutator {
public:
    Deinterleaver(int starting_lane, int lane_stride, int new_lanes, const Scope<> &lets)
        : starting_lane(starting_lane),
          lane_stride(lane_stride),
          new_lanes(new_lanes),
          external_lets(lets) {
    }

private:
    int starting_lane;  // First lane to keep.
    int lane_stride;    // Keep every lane_stride-th lane after that.
    int new_lanes;      // Number of lanes in the result.

    // lets for which we have even and odd lane specializations
    const Scope<> &external_lets;

    using IRMutator::visit;

    Expr visit(const VectorReduce *op) override {
        // Each output lane of a VectorReduce reduces 'factor'
        // consecutive input lanes, so selecting output lanes means
        // selecting the corresponding groups of input lanes.
        std::vector<int> input_lanes;
        int factor = op->value.type().lanes() / op->type.lanes();
        for (int i = starting_lane; i < op->type.lanes(); i += lane_stride) {
            for (int j = 0; j < factor; j++) {
                input_lanes.push_back(i * factor + j);
            }
        }
        Expr in = Shuffle::make({op->value}, input_lanes);
        return VectorReduce::make(op->op, in, new_lanes);
    }

    Expr visit(const Broadcast *op) override {
        // All lanes of a broadcast are identical; just narrow it.
        if (new_lanes == 1) {
            return op->value;
        }
        return Broadcast::make(op->value, new_lanes);
    }

    Expr visit(const Load *op) override {
        if (op->type.is_scalar()) {
            return op;
        } else {
            // Deinterleave the index and predicate rather than the
            // loaded value, turning a dense load into a strided one.
            Type t = op->type.with_lanes(new_lanes);
            ModulusRemainder align = op->alignment;
            // TODO: Figure out the alignment of every nth lane
            if (starting_lane != 0) {
                align = ModulusRemainder();
            }
            return Load::make(t, op->name, mutate(op->index), op->image, op->param, mutate(op->predicate), align);
        }
    }

    Expr visit(const Ramp *op) override {
        // Every lane_stride-th element of a ramp is itself a ramp with
        // a scaled stride and shifted base.
        Expr expr = op->base + starting_lane * op->stride;
        internal_assert(expr.type() == op->base.type());
        if (new_lanes > 1) {
            expr = Ramp::make(expr, op->stride * lane_stride, new_lanes);
        }
        return expr;
    }

    Expr visit(const Variable *op) override {
        if (op->type.is_scalar()) {
            return op;
        } else {

            Type t = op->type.with_lanes(new_lanes);
            // If the enclosing pass emitted a specialized let for this
            // variable's even/odd (or mod-3) lanes, reference it
            // directly instead of shuffling.
            if (external_lets.contains(op->name) &&
                starting_lane == 0 &&
                lane_stride == 2) {
                return Variable::make(t, op->name + ".even_lanes", op->image, op->param, op->reduction_domain);
            } else if (external_lets.contains(op->name) &&
                       starting_lane == 1 &&
                       lane_stride == 2) {
                return Variable::make(t, op->name + ".odd_lanes", op->image, op->param, op->reduction_domain);
            } else if (external_lets.contains(op->name) &&
                       starting_lane == 0 &&
                       lane_stride == 3) {
                return Variable::make(t, op->name + ".lanes_0_of_3", op->image, op->param, op->reduction_domain);
            } else if (external_lets.contains(op->name) &&
                       starting_lane == 1 &&
                       lane_stride == 3) {
                return Variable::make(t, op->name + ".lanes_1_of_3", op->image, op->param, op->reduction_domain);
            } else if (external_lets.contains(op->name) &&
                       starting_lane == 2 &&
                       lane_stride == 3) {
                return Variable::make(t, op->name + ".lanes_2_of_3", op->image, op->param, op->reduction_domain);
            } else {
                // Uh-oh, we don't know how to deinterleave this vector expression
                // Make llvm do it
                std::vector<int> indices;
                for (int i = 0; i < new_lanes; i++) {
                    indices.push_back(starting_lane + lane_stride * i);
                }
                return Shuffle::make({op}, indices);
            }
        }
    }

    Expr visit(const Cast *op) override {
        if (op->type.is_scalar()) {
            return op;
        } else {
            // Casts are lanewise; deinterleave the operand.
            Type t = op->type.with_lanes(new_lanes);
            return Cast::make(t, mutate(op->value));
        }
    }

    Expr visit(const Call *op) override {
        Type t = op->type.with_lanes(new_lanes);

        // Don't mutate scalars
        if (op->type.is_scalar()) {
            return op;
        } else if (op->is_intrinsic(Call::glsl_texture_load)) {
            // glsl_texture_load returns a <uint x 4> result. Deinterleave by
            // wrapping the call in a shuffle_vector
            std::vector<int> indices;
            for (int i = 0; i < new_lanes; i++) {
                indices.push_back(i * lane_stride + starting_lane);
            }
            return Shuffle::make({op}, indices);
        } else {

            // Vector calls are always parallel across the lanes, so we
            // can just deinterleave the args.

            // Beware of intrinsics for which this is not true!
            std::vector<Expr> args(op->args.size());
            for (size_t i = 0; i < args.size(); i++) {
                args[i] = mutate(op->args[i]);
            }

            return Call::make(t, op->name, args, op->call_type,
                              op->func, op->value_index, op->image, op->param);
        }
    }

    Expr visit(const Shuffle *op) override {
        if (op->is_interleave()) {
            // Special case where we can discard some of the vector arguments entirely.
            internal_assert(starting_lane >= 0 && starting_lane < lane_stride);
            if ((int)op->vectors.size() == lane_stride) {
                return op->vectors[starting_lane];
            } else if ((int)op->vectors.size() % lane_stride == 0) {
                // Pick up every lane-stride vector.
                std::vector<Expr> new_vectors(op->vectors.size() / lane_stride);
                for (size_t i = 0; i < new_vectors.size(); i++) {
                    new_vectors[i] = op->vectors[i * lane_stride + starting_lane];
                }
                return Shuffle::make_interleave(new_vectors);
            }
        }

        // Keep the same set of vectors and extract every nth numeric
        // arg to the shuffle.
        std::vector<int> indices;
        for (int i = 0; i < new_lanes; i++) {
            int idx = i * lane_stride + starting_lane;
            indices.push_back(op->indices[idx]);
        }
        return Shuffle::make(op->vectors, indices);
    }
};
342
deinterleave(Expr e,int starting_lane,int lane_stride,int new_lanes,const Scope<> & lets)343 Expr deinterleave(Expr e, int starting_lane, int lane_stride, int new_lanes, const Scope<> &lets) {
344 e = substitute_in_all_lets(e);
345 Deinterleaver d(starting_lane, lane_stride, new_lanes, lets);
346 e = d.mutate(e);
347 e = common_subexpression_elimination(e);
348 return simplify(e);
349 }
350 } // namespace
351
// Pull out lanes 1, 3, 5, ... of an even-width vector expression.
Expr extract_odd_lanes(const Expr &e, const Scope<> &lets) {
    const int lanes = e.type().lanes();
    internal_assert(lanes % 2 == 0);
    return deinterleave(e, 1, 2, lanes / 2, lets);
}
356
// Pull out lanes 0, 2, 4, ... of an even-width vector expression.
Expr extract_even_lanes(const Expr &e, const Scope<> &lets) {
    const int lanes = e.type().lanes();
    internal_assert(lanes % 2 == 0);
    return deinterleave(e, 0, 2, (lanes + 1) / 2, lets);
}
361
// Convenience overload with no external let specializations in scope.
Expr extract_even_lanes(const Expr &e) {
    internal_assert(e.type().lanes() % 2 == 0);
    Scope<> no_lets;
    return extract_even_lanes(e, no_lets);
}
367
// Convenience overload with no external let specializations in scope.
Expr extract_odd_lanes(const Expr &e) {
    internal_assert(e.type().lanes() % 2 == 0);
    Scope<> no_lets;
    return extract_odd_lanes(e, no_lets);
}
373
// Pull out lanes congruent to 'lane' mod 3 from a vector expression
// whose width is a multiple of three.
Expr extract_mod3_lanes(const Expr &e, int lane, const Scope<> &lets) {
    const int lanes = e.type().lanes();
    internal_assert(lanes % 3 == 0);
    return deinterleave(e, lane, 3, (lanes + 2) / 3, lets);
}
378
// Extract a single scalar lane: a stride equal to the full lane count
// with one output lane selects exactly the requested element.
Expr extract_lane(const Expr &e, int lane) {
    Scope<> no_lets;
    return deinterleave(e, lane, e.type().lanes(), 1, no_lets);
}
383
384 namespace {
385
386 class Interleaver : public IRMutator {
387 Scope<> vector_lets;
388
389 using IRMutator::visit;
390
391 bool should_deinterleave;
392 int num_lanes;
393
deinterleave_expr(Expr e)394 Expr deinterleave_expr(Expr e) {
395 if (e.type().lanes() <= num_lanes) {
396 // Just scalarize
397 return e;
398 } else if (num_lanes == 2) {
399 Expr a = extract_even_lanes(e, vector_lets);
400 Expr b = extract_odd_lanes(e, vector_lets);
401 return Shuffle::make_interleave({a, b});
402 } else if (num_lanes == 3) {
403 Expr a = extract_mod3_lanes(e, 0, vector_lets);
404 Expr b = extract_mod3_lanes(e, 1, vector_lets);
405 Expr c = extract_mod3_lanes(e, 2, vector_lets);
406 return Shuffle::make_interleave({a, b, c});
407 } else if (num_lanes == 4) {
408 Expr a = extract_even_lanes(e, vector_lets);
409 Expr b = extract_odd_lanes(e, vector_lets);
410 Expr aa = extract_even_lanes(a, vector_lets);
411 Expr ab = extract_odd_lanes(a, vector_lets);
412 Expr ba = extract_even_lanes(b, vector_lets);
413 Expr bb = extract_odd_lanes(b, vector_lets);
414 return Shuffle::make_interleave({aa, ba, ab, bb});
415 } else {
416 // Give up and don't do anything clever for >4
417 return e;
418 }
419 }
420
421 template<typename T, typename Body>
visit_lets(const T * op)422 Body visit_lets(const T *op) {
423 // Visit an entire chain of lets in a single method to conserve stack space.
424 struct Frame {
425 const T *op;
426 Expr new_value;
427 ScopedBinding<> binding;
428 Frame(const T *op, Expr v, Scope<void> &scope)
429 : op(op),
430 new_value(std::move(v)),
431 binding(new_value.type().is_vector(), scope, op->name) {
432 }
433 };
434 std::vector<Frame> frames;
435 Body result;
436
437 do {
438 result = op->body;
439 frames.emplace_back(op, mutate(op->value), vector_lets);
440 } while ((op = result.template as<T>()));
441
442 result = mutate(result);
443
444 for (auto it = frames.rbegin(); it != frames.rend(); it++) {
445 Expr value = std::move(it->new_value);
446
447 result = T::make(it->op->name, value, result);
448
449 // For vector lets, we may additionally need a let defining the even and odd lanes only
450 if (value.type().is_vector()) {
451 if (value.type().lanes() % 2 == 0) {
452 result = T::make(it->op->name + ".even_lanes", extract_even_lanes(value, vector_lets), result);
453 result = T::make(it->op->name + ".odd_lanes", extract_odd_lanes(value, vector_lets), result);
454 }
455 if (value.type().lanes() % 3 == 0) {
456 result = T::make(it->op->name + ".lanes_0_of_3", extract_mod3_lanes(value, 0, vector_lets), result);
457 result = T::make(it->op->name + ".lanes_1_of_3", extract_mod3_lanes(value, 1, vector_lets), result);
458 result = T::make(it->op->name + ".lanes_2_of_3", extract_mod3_lanes(value, 2, vector_lets), result);
459 }
460 }
461 }
462
463 return result;
464 }
465
visit(const Let * op)466 Expr visit(const Let *op) override {
467 return visit_lets<Let, Expr>(op);
468 }
469
visit(const LetStmt * op)470 Stmt visit(const LetStmt *op) override {
471 return visit_lets<LetStmt, Stmt>(op);
472 }
473
visit(const Mod * op)474 Expr visit(const Mod *op) override {
475 const Ramp *r = op->a.as<Ramp>();
476 for (int i = 2; i <= 4; ++i) {
477 if (r &&
478 is_const(op->b, i) &&
479 (r->type.lanes() % i) == 0) {
480 should_deinterleave = true;
481 num_lanes = i;
482 break;
483 }
484 }
485 return IRMutator::visit(op);
486 }
487
visit(const Div * op)488 Expr visit(const Div *op) override {
489 const Ramp *r = op->a.as<Ramp>();
490 for (int i = 2; i <= 4; ++i) {
491 if (r &&
492 is_const(op->b, i) &&
493 (r->type.lanes() % i) == 0) {
494 should_deinterleave = true;
495 num_lanes = i;
496 break;
497 }
498 }
499 return IRMutator::visit(op);
500 }
501
visit(const Call * op)502 Expr visit(const Call *op) override {
503 if (!op->is_pure()) {
504 // deinterleaving potentially changes the order of execution.
505 should_deinterleave = false;
506 }
507 return IRMutator::visit(op);
508 }
509
visit(const Load * op)510 Expr visit(const Load *op) override {
511 bool old_should_deinterleave = should_deinterleave;
512 int old_num_lanes = num_lanes;
513
514 should_deinterleave = false;
515 Expr idx = mutate(op->index);
516 bool should_deinterleave_idx = should_deinterleave;
517
518 should_deinterleave = false;
519 Expr predicate = mutate(op->predicate);
520 bool should_deinterleave_predicate = should_deinterleave;
521
522 Expr expr;
523 if (should_deinterleave_idx && (should_deinterleave_predicate || is_one(predicate))) {
524 // If we want to deinterleave both the index and predicate
525 // (or the predicate is one), then deinterleave the
526 // resulting load.
527 expr = Load::make(op->type, op->name, idx, op->image, op->param, predicate, op->alignment);
528 expr = deinterleave_expr(expr);
529 } else if (should_deinterleave_idx) {
530 // If we only want to deinterleave the index and not the
531 // predicate, deinterleave the index prior to the load.
532 idx = deinterleave_expr(idx);
533 expr = Load::make(op->type, op->name, idx, op->image, op->param, predicate, op->alignment);
534 } else if (should_deinterleave_predicate) {
535 // Similarly, deinterleave the predicate prior to the load
536 // if we don't want to deinterleave the index.
537 predicate = deinterleave_expr(predicate);
538 expr = Load::make(op->type, op->name, idx, op->image, op->param, predicate, op->alignment);
539 } else if (!idx.same_as(op->index) || !predicate.same_as(op->index)) {
540 expr = Load::make(op->type, op->name, idx, op->image, op->param, predicate, op->alignment);
541 } else {
542 expr = op;
543 }
544
545 should_deinterleave = old_should_deinterleave;
546 num_lanes = old_num_lanes;
547 return expr;
548 }
549
visit(const Store * op)550 Stmt visit(const Store *op) override {
551 bool old_should_deinterleave = should_deinterleave;
552 int old_num_lanes = num_lanes;
553
554 should_deinterleave = false;
555 Expr idx = mutate(op->index);
556 if (should_deinterleave) {
557 idx = deinterleave_expr(idx);
558 }
559
560 should_deinterleave = false;
561 Expr value = mutate(op->value);
562 if (should_deinterleave) {
563 value = deinterleave_expr(value);
564 }
565
566 should_deinterleave = false;
567 Expr predicate = mutate(op->predicate);
568 if (should_deinterleave) {
569 predicate = deinterleave_expr(predicate);
570 }
571
572 Stmt stmt = Store::make(op->name, value, idx, op->param, predicate, op->alignment);
573
574 should_deinterleave = old_should_deinterleave;
575 num_lanes = old_num_lanes;
576
577 return stmt;
578 }
579
gather_stores(const Block * op)580 HALIDE_NEVER_INLINE Stmt gather_stores(const Block *op) {
581 const LetStmt *let = op->first.as<LetStmt>();
582 const Store *store = op->first.as<Store>();
583
584 // Gather all the let stmts surrounding the first.
585 std::vector<Stmt> let_stmts;
586 while (let) {
587 let_stmts.emplace_back(let);
588 store = let->body.as<Store>();
589 let = let->body.as<LetStmt>();
590 }
591
592 // There was no inner store.
593 if (!store) return Stmt();
594
595 const Ramp *r0 = store->index.as<Ramp>();
596
597 // It's not a store of a ramp index.
598 if (!r0) return Stmt();
599
600 const int64_t *stride_ptr = as_const_int(r0->stride);
601
602 // The stride isn't a constant or is <= 0
603 if (!stride_ptr || *stride_ptr < 1) return Stmt();
604
605 const int64_t stride = *stride_ptr;
606 const int lanes = r0->lanes;
607 const int64_t expected_stores = stride == 1 ? lanes : stride;
608
609 // Collect the rest of the stores.
610 std::vector<Stmt> stores;
611 stores.emplace_back(store);
612 Stmt rest = collect_strided_stores(op->rest, store->name,
613 stride, expected_stores,
614 let_stmts, stores);
615
616 // Check the store collector didn't collect too many
617 // stores (that would be a bug).
618 internal_assert(stores.size() <= (size_t)expected_stores);
619
620 // Not enough stores collected.
621 if (stores.size() != (size_t)expected_stores) return Stmt();
622
623 Type t = store->value.type();
624 Expr base;
625 std::vector<Expr> args(stores.size());
626 std::vector<Expr> predicates(stores.size());
627
628 int min_offset = 0;
629 std::vector<int> offsets(stores.size());
630
631 std::string load_name;
632 Buffer<> load_image;
633 Parameter load_param;
634 for (size_t i = 0; i < stores.size(); ++i) {
635 const Ramp *ri = stores[i].as<Store>()->index.as<Ramp>();
636 internal_assert(ri);
637
638 // Mismatched store vector laness.
639 if (ri->lanes != lanes) return Stmt();
640
641 Expr diff = simplify(ri->base - r0->base);
642 const int64_t *offs = as_const_int(diff);
643
644 // Difference between bases is not constant.
645 if (!offs) return Stmt();
646
647 offsets[i] = *offs;
648 if (*offs < min_offset) {
649 min_offset = *offs;
650 }
651
652 if (stride == 1) {
653 // Difference between bases is not a multiple of the lanes.
654 if (*offs % lanes != 0) return Stmt();
655
656 // This case only triggers if we have an immediate load of the correct stride on the RHS.
657 // TODO: Could we consider mutating the RHS so that we can handle more complex Expr's than just loads?
658 const Load *load = stores[i].as<Store>()->value.as<Load>();
659 if (!load) return Stmt();
660 // TODO(psuriana): Predicated load is not currently handled.
661 if (!is_one(load->predicate)) return Stmt();
662
663 const Ramp *ramp = load->index.as<Ramp>();
664 if (!ramp) return Stmt();
665
666 // Load stride or lanes is not equal to the store lanes.
667 if (!is_const(ramp->stride, lanes) || ramp->lanes != lanes) return Stmt();
668
669 if (i == 0) {
670 load_name = load->name;
671 load_image = load->image;
672 load_param = load->param;
673 } else {
674 if (load->name != load_name) return Stmt();
675 }
676 }
677 }
678
679 // Gather the args for interleaving.
680 for (size_t i = 0; i < stores.size(); ++i) {
681 int j = offsets[i] - min_offset;
682 if (stride == 1) {
683 j /= stores.size();
684 }
685
686 if (j == 0) {
687 base = stores[i].as<Store>()->index.as<Ramp>()->base;
688 }
689
690 // The offset is not between zero and the stride.
691 if (j < 0 || (size_t)j >= stores.size()) return Stmt();
692
693 // We already have a store for this offset.
694 if (args[j].defined()) return Stmt();
695
696 if (stride == 1) {
697 // Convert multiple dense vector stores of strided vector loads
698 // into one dense vector store of interleaving dense vector loads.
699 args[j] = Load::make(t, load_name, stores[i].as<Store>()->index,
700 load_image, load_param, const_true(t.lanes()), ModulusRemainder());
701 } else {
702 args[j] = stores[i].as<Store>()->value;
703 }
704 predicates[j] = stores[i].as<Store>()->predicate;
705 }
706
707 // One of the stores should have had the minimum offset.
708 internal_assert(base.defined());
709
710 // Generate a single interleaving store.
711 t = t.with_lanes(lanes * stores.size());
712 Expr index = Ramp::make(base, make_one(base.type()), t.lanes());
713 Expr value = Shuffle::make_interleave(args);
714 Expr predicate = Shuffle::make_interleave(predicates);
715 Stmt new_store = Store::make(store->name, value, index, store->param, predicate, ModulusRemainder());
716
717 // Continue recursively into the stuff that
718 // collect_strided_stores didn't collect.
719 Stmt stmt = Block::make(new_store, mutate(rest));
720
721 // Rewrap the let statements we pulled off.
722 while (!let_stmts.empty()) {
723 const LetStmt *let = let_stmts.back().as<LetStmt>();
724 stmt = LetStmt::make(let->name, let->value, stmt);
725 let_stmts.pop_back();
726 }
727
728 // Success!
729 return stmt;
730 }
731
visit(const Block * op)732 Stmt visit(const Block *op) override {
733 Stmt s = gather_stores(op);
734 if (s.defined()) {
735 return s;
736 } else {
737 Stmt first = mutate(op->first);
738 Stmt rest = mutate(op->rest);
739 if (first.same_as(op->first) && rest.same_as(op->rest)) {
740 return op;
741 } else {
742 return Block::make(first, rest);
743 }
744 }
745 }
746
747 public:
Interleaver()748 Interleaver()
749 : should_deinterleave(false) {
750 }
751 };
752
753 } // namespace
754
// Entry point: run the Interleaver pass over a statement.
Stmt rewrite_interleavings(const Stmt &s) {
    Interleaver interleaver;
    return interleaver.mutate(s);
}
758
759 namespace {
check(Expr a,const Expr & even,const Expr & odd)760 void check(Expr a, const Expr &even, const Expr &odd) {
761 a = simplify(a);
762 Expr correct_even = extract_even_lanes(a);
763 Expr correct_odd = extract_odd_lanes(a);
764 if (!equal(correct_even, even)) {
765 internal_error << correct_even << " != " << even << "\n";
766 }
767 if (!equal(correct_odd, odd)) {
768 internal_error << correct_odd << " != " << odd << "\n";
769 }
770 }
771 } // namespace
772
deinterleave_vector_test()773 void deinterleave_vector_test() {
774 std::pair<Expr, Expr> result;
775 Expr x = Variable::make(Int(32), "x");
776 Expr ramp = Ramp::make(x + 4, 3, 8);
777 Expr ramp_a = Ramp::make(x + 4, 6, 4);
778 Expr ramp_b = Ramp::make(x + 7, 6, 4);
779 Expr broadcast = Broadcast::make(x + 4, 16);
780 Expr broadcast_a = Broadcast::make(x + 4, 8);
781 const Expr &broadcast_b = broadcast_a;
782
783 check(ramp, ramp_a, ramp_b);
784 check(broadcast, broadcast_a, broadcast_b);
785
786 check(Load::make(ramp.type(), "buf", ramp, Buffer<>(), Parameter(), const_true(ramp.type().lanes()), ModulusRemainder()),
787 Load::make(ramp_a.type(), "buf", ramp_a, Buffer<>(), Parameter(), const_true(ramp_a.type().lanes()), ModulusRemainder()),
788 Load::make(ramp_b.type(), "buf", ramp_b, Buffer<>(), Parameter(), const_true(ramp_b.type().lanes()), ModulusRemainder()));
789
790 Expr vec_x = Variable::make(Int(32, 4), "vec_x");
791 Expr vec_y = Variable::make(Int(32, 4), "vec_y");
792 check(Shuffle::make({vec_x, vec_y}, {0, 4, 2, 6, 4, 2, 3, 7, 1, 2, 3, 4}),
793 Shuffle::make({vec_x, vec_y}, {0, 2, 4, 3, 1, 3}),
794 Shuffle::make({vec_x, vec_y}, {4, 6, 2, 7, 2, 4}));
795
796 std::cout << "deinterleave_vector test passed" << std::endl;
797 }
798
799 } // namespace Internal
800 } // namespace Halide
801