1 #include "Deinterleave.h"
2 
3 #include "CSE.h"
4 #include "Debug.h"
5 #include "IREquality.h"
6 #include "IRMutator.h"
7 #include "IROperator.h"
8 #include "IRPrinter.h"
9 #include "ModulusRemainder.h"
10 #include "Scope.h"
11 #include "Simplify.h"
12 #include "Substitute.h"
13 
14 namespace Halide {
15 namespace Internal {
16 
17 using std::pair;
18 
19 namespace {
20 
// Collects a run of stores to a single buffer whose indices are ramps
// with a fixed stride, so that the caller (gather_stores) can fuse them
// into one interleaving store. Each collected store is replaced in the
// mutated stmt with a no-op (Evaluate 0), and the LetStmts the collected
// stores may depend on are appended to let_stmts so the caller can
// re-wrap the fused store with them. Collection stops permanently at the
// first construct that would make reordering the stores unsafe.
class StoreCollector : public IRMutator {
public:
    const std::string store_name;         // Buffer the stores must write to.
    const int store_stride, max_stores;   // Required ramp stride; max stores to collect.
    std::vector<Stmt> &let_stmts;         // Output: lets needed by the collected stores.
    std::vector<Stmt> &stores;            // Output: the collected Store nodes, in order.

    StoreCollector(const std::string &name, int stride, int ms,
                   std::vector<Stmt> &lets, std::vector<Stmt> &ss)
        : store_name(name), store_stride(stride), max_stores(ms),
          let_stmts(lets), stores(ss), collecting(true) {
    }

private:
    using IRMutator::visit;

    // Don't enter any inner constructs for which it's not safe to pull out stores.
    Stmt visit(const For *op) override {
        collecting = false;
        return op;
    }
    Stmt visit(const IfThenElse *op) override {
        collecting = false;
        return op;
    }
    Stmt visit(const ProducerConsumer *op) override {
        collecting = false;
        return op;
    }
    Stmt visit(const Allocate *op) override {
        collecting = false;
        return op;
    }
    Stmt visit(const Realize *op) override {
        collecting = false;
        return op;
    }

    // Once false, stays false: everything after the first unsafe
    // construct is returned unmodified.
    bool collecting;
    // These are lets that we've encountered since the last collected
    // store. If we collect another store, these "potential" lets
    // become lets used by the collected stores.
    std::vector<Stmt> potential_lets;

    Expr visit(const Load *op) override {
        if (!collecting) {
            return op;
        }

        // If we hit a load from the buffer we're trying to collect
        // stores for, stop collecting to avoid reordering loads and
        // stores from the same buffer.
        if (op->name == store_name) {
            collecting = false;
            return op;
        } else {
            return IRMutator::visit(op);
        }
    }

    Stmt visit(const Store *op) override {
        if (!collecting) {
            return op;
        }

        // By default, do nothing.
        Stmt stmt = op;

        if (stores.size() >= (size_t)max_stores) {
            // Already have enough stores.
            collecting = false;
            return stmt;
        }

        // Make sure this Store doesn't do anything that causes us to
        // stop collecting (e.g. its value loads from store_name, or
        // calls an impure function).
        stmt = IRMutator::visit(op);
        if (!collecting) {
            return stmt;
        }

        if (op->name != store_name) {
            // Not a store to the buffer we're looking for.
            return stmt;
        }

        const Ramp *r = op->index.as<Ramp>();
        if (!r || !is_const(r->stride, store_stride)) {
            // Store doesn't store to the ramp we're looking
            // for. Can't interleave it. Since we don't want to
            // reorder stores, stop collecting.
            collecting = false;
            return stmt;
        }

        // This store is good, collect it and replace with a no-op.
        stores.emplace_back(op);
        stmt = Evaluate::make(0);

        // Because we collected this store, we need to save the
        // potential lets since the last collected store.
        let_stmts.insert(let_stmts.end(), potential_lets.begin(), potential_lets.end());
        potential_lets.clear();
        return stmt;
    }

    Expr visit(const Call *op) override {
        if (!op->is_pure()) {
            // Avoid reordering calls to impure functions
            collecting = false;
            return op;
        } else {
            return IRMutator::visit(op);
        }
    }

    Stmt visit(const LetStmt *op) override {
        if (!collecting) {
            return op;
        }

        // Walk inside the let chain
        Stmt stmt = IRMutator::visit(op);

        // If we're still collecting, we need to save the entire let chain as potential lets.
        // Note the lets are also left in place in the mutated stmt; the
        // caller is expected to deal with any resulting duplication.
        if (collecting) {
            Stmt body;
            do {
                potential_lets.emplace_back(op);
                body = op->body;
            } while ((op = body.as<LetStmt>()));
        }
        return stmt;
    }

    Stmt visit(const Block *op) override {
        if (!collecting) {
            return op;
        }

        Stmt first = mutate(op->first);
        Stmt rest = op->rest;
        // We might have decided to stop collecting during mutation of first.
        if (collecting) {
            rest = mutate(rest);
        }
        return Block::make(first, rest);
    }
};
170 
// Replace with no-ops up to max_stores stores to `name` whose indices are
// ramps of the given stride, appending the collected stores to `stores`
// and any LetStmts they depend on to `lets`. Returns the mutated stmt.
//
// `lets` must be passed by reference: StoreCollector binds a reference to
// it and appends the "potential lets" encountered between collected
// stores. Taking it by value would silently discard those lets, so the
// caller could hoist the fused store above definitions it depends on.
Stmt collect_strided_stores(const Stmt &stmt, const std::string &name, int stride, int max_stores,
                            std::vector<Stmt> &lets, std::vector<Stmt> &stores) {

    StoreCollector collect(name, stride, max_stores, lets, stores);
    return collect.mutate(stmt);
}
177 
// Extracts the lanes starting_lane, starting_lane + lane_stride,
// starting_lane + 2*lane_stride, ... from a vector expression, producing
// a vector of new_lanes lanes. Pushes the lane extraction through the IR
// where it can (ramps, broadcasts, elementwise ops) and falls back to an
// explicit Shuffle where it can't.
class Deinterleaver : public IRGraphMutator {
public:
    Deinterleaver(int starting_lane, int lane_stride, int new_lanes, const Scope<> &lets)
        : starting_lane(starting_lane),
          lane_stride(lane_stride),
          new_lanes(new_lanes),
          external_lets(lets) {
    }

private:
    int starting_lane;  // First lane to keep.
    int lane_stride;    // Keep every lane_stride-th lane after that.
    int new_lanes;      // Number of lanes in the result.

    // lets for which we have even and odd lane specializations
    const Scope<> &external_lets;

    using IRMutator::visit;

    Expr visit(const VectorReduce *op) override {
        // Each output lane of the reduce consumes `factor` consecutive
        // input lanes; keep the input lane groups corresponding to the
        // output lanes we're extracting.
        std::vector<int> input_lanes;
        int factor = op->value.type().lanes() / op->type.lanes();
        for (int i = starting_lane; i < op->type.lanes(); i += lane_stride) {
            for (int j = 0; j < factor; j++) {
                input_lanes.push_back(i * factor + j);
            }
        }
        Expr in = Shuffle::make({op->value}, input_lanes);
        return VectorReduce::make(op->op, in, new_lanes);
    }

    Expr visit(const Broadcast *op) override {
        // Every lane of a broadcast is the same, so just narrow it.
        if (new_lanes == 1) {
            return op->value;
        }
        return Broadcast::make(op->value, new_lanes);
    }

    Expr visit(const Load *op) override {
        if (op->type.is_scalar()) {
            return op;
        } else {
            Type t = op->type.with_lanes(new_lanes);
            ModulusRemainder align = op->alignment;
            // TODO: Figure out the alignment of every nth lane
            if (starting_lane != 0) {
                align = ModulusRemainder();
            }
            return Load::make(t, op->name, mutate(op->index), op->image, op->param, mutate(op->predicate), align);
        }
    }

    Expr visit(const Ramp *op) override {
        // Every lane_stride-th element of a ramp is another ramp.
        Expr expr = op->base + starting_lane * op->stride;
        internal_assert(expr.type() == op->base.type());
        if (new_lanes > 1) {
            expr = Ramp::make(expr, op->stride * lane_stride, new_lanes);
        }
        return expr;
    }

    Expr visit(const Variable *op) override {
        if (op->type.is_scalar()) {
            return op;
        } else {

            Type t = op->type.with_lanes(new_lanes);
            // If the enclosing Interleaver defined lane-specialized
            // versions of this let, refer to those instead of shuffling.
            if (external_lets.contains(op->name) &&
                starting_lane == 0 &&
                lane_stride == 2) {
                return Variable::make(t, op->name + ".even_lanes", op->image, op->param, op->reduction_domain);
            } else if (external_lets.contains(op->name) &&
                       starting_lane == 1 &&
                       lane_stride == 2) {
                return Variable::make(t, op->name + ".odd_lanes", op->image, op->param, op->reduction_domain);
            } else if (external_lets.contains(op->name) &&
                       starting_lane == 0 &&
                       lane_stride == 3) {
                return Variable::make(t, op->name + ".lanes_0_of_3", op->image, op->param, op->reduction_domain);
            } else if (external_lets.contains(op->name) &&
                       starting_lane == 1 &&
                       lane_stride == 3) {
                return Variable::make(t, op->name + ".lanes_1_of_3", op->image, op->param, op->reduction_domain);
            } else if (external_lets.contains(op->name) &&
                       starting_lane == 2 &&
                       lane_stride == 3) {
                return Variable::make(t, op->name + ".lanes_2_of_3", op->image, op->param, op->reduction_domain);
            } else {
                // Uh-oh, we don't know how to deinterleave this vector expression
                // Make llvm do it
                std::vector<int> indices;
                for (int i = 0; i < new_lanes; i++) {
                    indices.push_back(starting_lane + lane_stride * i);
                }
                return Shuffle::make({op}, indices);
            }
        }
    }

    Expr visit(const Cast *op) override {
        if (op->type.is_scalar()) {
            return op;
        } else {
            Type t = op->type.with_lanes(new_lanes);
            return Cast::make(t, mutate(op->value));
        }
    }

    Expr visit(const Call *op) override {
        Type t = op->type.with_lanes(new_lanes);

        // Don't mutate scalars
        if (op->type.is_scalar()) {
            return op;
        } else if (op->is_intrinsic(Call::glsl_texture_load)) {
            // glsl_texture_load returns a <uint x 4> result. Deinterleave by
            // wrapping the call in a shuffle_vector
            std::vector<int> indices;
            for (int i = 0; i < new_lanes; i++) {
                indices.push_back(i * lane_stride + starting_lane);
            }
            return Shuffle::make({op}, indices);
        } else {

            // Vector calls are always parallel across the lanes, so we
            // can just deinterleave the args.

            // Beware of intrinsics for which this is not true!
            std::vector<Expr> args(op->args.size());
            for (size_t i = 0; i < args.size(); i++) {
                args[i] = mutate(op->args[i]);
            }

            return Call::make(t, op->name, args, op->call_type,
                              op->func, op->value_index, op->image, op->param);
        }
    }

    Expr visit(const Shuffle *op) override {
        if (op->is_interleave()) {
            // Special case where we can discard some of the vector arguments entirely.
            internal_assert(starting_lane >= 0 && starting_lane < lane_stride);
            if ((int)op->vectors.size() == lane_stride) {
                // Deinterleaving an interleave of the same stride just
                // picks out one of the original vectors.
                return op->vectors[starting_lane];
            } else if ((int)op->vectors.size() % lane_stride == 0) {
                // Pick up every lane-stride vector.
                std::vector<Expr> new_vectors(op->vectors.size() / lane_stride);
                for (size_t i = 0; i < new_vectors.size(); i++) {
                    new_vectors[i] = op->vectors[i * lane_stride + starting_lane];
                }
                return Shuffle::make_interleave(new_vectors);
            }
        }

        // Keep the same set of vectors and extract every nth numeric
        // arg to the shuffle.
        std::vector<int> indices;
        for (int i = 0; i < new_lanes; i++) {
            int idx = i * lane_stride + starting_lane;
            indices.push_back(op->indices[idx]);
        }
        return Shuffle::make(op->vectors, indices);
    }
};
342 
deinterleave(Expr e,int starting_lane,int lane_stride,int new_lanes,const Scope<> & lets)343 Expr deinterleave(Expr e, int starting_lane, int lane_stride, int new_lanes, const Scope<> &lets) {
344     e = substitute_in_all_lets(e);
345     Deinterleaver d(starting_lane, lane_stride, new_lanes, lets);
346     e = d.mutate(e);
347     e = common_subexpression_elimination(e);
348     return simplify(e);
349 }
350 }  // namespace
351 
// Return the odd-numbered lanes of e (lane 1, 3, 5, ...).
Expr extract_odd_lanes(const Expr &e, const Scope<> &lets) {
    const int lanes = e.type().lanes();
    internal_assert(lanes % 2 == 0);
    return deinterleave(e, 1, 2, lanes / 2, lets);
}
356 
// Return the even-numbered lanes of e (lane 0, 2, 4, ...).
Expr extract_even_lanes(const Expr &e, const Scope<> &lets) {
    const int lanes = e.type().lanes();
    internal_assert(lanes % 2 == 0);
    return deinterleave(e, 0, 2, (lanes + 1) / 2, lets);
}
361 
// Convenience overload: extract even lanes with no lane-specialized lets in scope.
Expr extract_even_lanes(const Expr &e) {
    internal_assert(e.type().lanes() % 2 == 0);
    Scope<> no_lets;
    return extract_even_lanes(e, no_lets);
}
367 
// Convenience overload: extract odd lanes with no lane-specialized lets in scope.
Expr extract_odd_lanes(const Expr &e) {
    internal_assert(e.type().lanes() % 2 == 0);
    Scope<> no_lets;
    return extract_odd_lanes(e, no_lets);
}
373 
// Return every third lane of e, starting at `lane` (which must be 0, 1, or 2).
Expr extract_mod3_lanes(const Expr &e, int lane, const Scope<> &lets) {
    const int lanes = e.type().lanes();
    internal_assert(lanes % 3 == 0);
    return deinterleave(e, lane, 3, (lanes + 2) / 3, lets);
}
378 
// Extract a single scalar lane from a vector expression.
Expr extract_lane(const Expr &e, int lane) {
    Scope<> no_lets;
    // Stride of lanes() with one output lane picks out exactly lane `lane`.
    return deinterleave(e, lane, e.type().lanes(), 1, no_lets);
}
383 
384 namespace {
385 
386 class Interleaver : public IRMutator {
387     Scope<> vector_lets;
388 
389     using IRMutator::visit;
390 
391     bool should_deinterleave;
392     int num_lanes;
393 
deinterleave_expr(Expr e)394     Expr deinterleave_expr(Expr e) {
395         if (e.type().lanes() <= num_lanes) {
396             // Just scalarize
397             return e;
398         } else if (num_lanes == 2) {
399             Expr a = extract_even_lanes(e, vector_lets);
400             Expr b = extract_odd_lanes(e, vector_lets);
401             return Shuffle::make_interleave({a, b});
402         } else if (num_lanes == 3) {
403             Expr a = extract_mod3_lanes(e, 0, vector_lets);
404             Expr b = extract_mod3_lanes(e, 1, vector_lets);
405             Expr c = extract_mod3_lanes(e, 2, vector_lets);
406             return Shuffle::make_interleave({a, b, c});
407         } else if (num_lanes == 4) {
408             Expr a = extract_even_lanes(e, vector_lets);
409             Expr b = extract_odd_lanes(e, vector_lets);
410             Expr aa = extract_even_lanes(a, vector_lets);
411             Expr ab = extract_odd_lanes(a, vector_lets);
412             Expr ba = extract_even_lanes(b, vector_lets);
413             Expr bb = extract_odd_lanes(b, vector_lets);
414             return Shuffle::make_interleave({aa, ba, ab, bb});
415         } else {
416             // Give up and don't do anything clever for >4
417             return e;
418         }
419     }
420 
421     template<typename T, typename Body>
visit_lets(const T * op)422     Body visit_lets(const T *op) {
423         // Visit an entire chain of lets in a single method to conserve stack space.
424         struct Frame {
425             const T *op;
426             Expr new_value;
427             ScopedBinding<> binding;
428             Frame(const T *op, Expr v, Scope<void> &scope)
429                 : op(op),
430                   new_value(std::move(v)),
431                   binding(new_value.type().is_vector(), scope, op->name) {
432             }
433         };
434         std::vector<Frame> frames;
435         Body result;
436 
437         do {
438             result = op->body;
439             frames.emplace_back(op, mutate(op->value), vector_lets);
440         } while ((op = result.template as<T>()));
441 
442         result = mutate(result);
443 
444         for (auto it = frames.rbegin(); it != frames.rend(); it++) {
445             Expr value = std::move(it->new_value);
446 
447             result = T::make(it->op->name, value, result);
448 
449             // For vector lets, we may additionally need a let defining the even and odd lanes only
450             if (value.type().is_vector()) {
451                 if (value.type().lanes() % 2 == 0) {
452                     result = T::make(it->op->name + ".even_lanes", extract_even_lanes(value, vector_lets), result);
453                     result = T::make(it->op->name + ".odd_lanes", extract_odd_lanes(value, vector_lets), result);
454                 }
455                 if (value.type().lanes() % 3 == 0) {
456                     result = T::make(it->op->name + ".lanes_0_of_3", extract_mod3_lanes(value, 0, vector_lets), result);
457                     result = T::make(it->op->name + ".lanes_1_of_3", extract_mod3_lanes(value, 1, vector_lets), result);
458                     result = T::make(it->op->name + ".lanes_2_of_3", extract_mod3_lanes(value, 2, vector_lets), result);
459                 }
460             }
461         }
462 
463         return result;
464     }
465 
visit(const Let * op)466     Expr visit(const Let *op) override {
467         return visit_lets<Let, Expr>(op);
468     }
469 
visit(const LetStmt * op)470     Stmt visit(const LetStmt *op) override {
471         return visit_lets<LetStmt, Stmt>(op);
472     }
473 
visit(const Mod * op)474     Expr visit(const Mod *op) override {
475         const Ramp *r = op->a.as<Ramp>();
476         for (int i = 2; i <= 4; ++i) {
477             if (r &&
478                 is_const(op->b, i) &&
479                 (r->type.lanes() % i) == 0) {
480                 should_deinterleave = true;
481                 num_lanes = i;
482                 break;
483             }
484         }
485         return IRMutator::visit(op);
486     }
487 
visit(const Div * op)488     Expr visit(const Div *op) override {
489         const Ramp *r = op->a.as<Ramp>();
490         for (int i = 2; i <= 4; ++i) {
491             if (r &&
492                 is_const(op->b, i) &&
493                 (r->type.lanes() % i) == 0) {
494                 should_deinterleave = true;
495                 num_lanes = i;
496                 break;
497             }
498         }
499         return IRMutator::visit(op);
500     }
501 
visit(const Call * op)502     Expr visit(const Call *op) override {
503         if (!op->is_pure()) {
504             // deinterleaving potentially changes the order of execution.
505             should_deinterleave = false;
506         }
507         return IRMutator::visit(op);
508     }
509 
visit(const Load * op)510     Expr visit(const Load *op) override {
511         bool old_should_deinterleave = should_deinterleave;
512         int old_num_lanes = num_lanes;
513 
514         should_deinterleave = false;
515         Expr idx = mutate(op->index);
516         bool should_deinterleave_idx = should_deinterleave;
517 
518         should_deinterleave = false;
519         Expr predicate = mutate(op->predicate);
520         bool should_deinterleave_predicate = should_deinterleave;
521 
522         Expr expr;
523         if (should_deinterleave_idx && (should_deinterleave_predicate || is_one(predicate))) {
524             // If we want to deinterleave both the index and predicate
525             // (or the predicate is one), then deinterleave the
526             // resulting load.
527             expr = Load::make(op->type, op->name, idx, op->image, op->param, predicate, op->alignment);
528             expr = deinterleave_expr(expr);
529         } else if (should_deinterleave_idx) {
530             // If we only want to deinterleave the index and not the
531             // predicate, deinterleave the index prior to the load.
532             idx = deinterleave_expr(idx);
533             expr = Load::make(op->type, op->name, idx, op->image, op->param, predicate, op->alignment);
534         } else if (should_deinterleave_predicate) {
535             // Similarly, deinterleave the predicate prior to the load
536             // if we don't want to deinterleave the index.
537             predicate = deinterleave_expr(predicate);
538             expr = Load::make(op->type, op->name, idx, op->image, op->param, predicate, op->alignment);
539         } else if (!idx.same_as(op->index) || !predicate.same_as(op->index)) {
540             expr = Load::make(op->type, op->name, idx, op->image, op->param, predicate, op->alignment);
541         } else {
542             expr = op;
543         }
544 
545         should_deinterleave = old_should_deinterleave;
546         num_lanes = old_num_lanes;
547         return expr;
548     }
549 
visit(const Store * op)550     Stmt visit(const Store *op) override {
551         bool old_should_deinterleave = should_deinterleave;
552         int old_num_lanes = num_lanes;
553 
554         should_deinterleave = false;
555         Expr idx = mutate(op->index);
556         if (should_deinterleave) {
557             idx = deinterleave_expr(idx);
558         }
559 
560         should_deinterleave = false;
561         Expr value = mutate(op->value);
562         if (should_deinterleave) {
563             value = deinterleave_expr(value);
564         }
565 
566         should_deinterleave = false;
567         Expr predicate = mutate(op->predicate);
568         if (should_deinterleave) {
569             predicate = deinterleave_expr(predicate);
570         }
571 
572         Stmt stmt = Store::make(op->name, value, idx, op->param, predicate, op->alignment);
573 
574         should_deinterleave = old_should_deinterleave;
575         num_lanes = old_num_lanes;
576 
577         return stmt;
578     }
579 
gather_stores(const Block * op)580     HALIDE_NEVER_INLINE Stmt gather_stores(const Block *op) {
581         const LetStmt *let = op->first.as<LetStmt>();
582         const Store *store = op->first.as<Store>();
583 
584         // Gather all the let stmts surrounding the first.
585         std::vector<Stmt> let_stmts;
586         while (let) {
587             let_stmts.emplace_back(let);
588             store = let->body.as<Store>();
589             let = let->body.as<LetStmt>();
590         }
591 
592         // There was no inner store.
593         if (!store) return Stmt();
594 
595         const Ramp *r0 = store->index.as<Ramp>();
596 
597         // It's not a store of a ramp index.
598         if (!r0) return Stmt();
599 
600         const int64_t *stride_ptr = as_const_int(r0->stride);
601 
602         // The stride isn't a constant or is <= 0
603         if (!stride_ptr || *stride_ptr < 1) return Stmt();
604 
605         const int64_t stride = *stride_ptr;
606         const int lanes = r0->lanes;
607         const int64_t expected_stores = stride == 1 ? lanes : stride;
608 
609         // Collect the rest of the stores.
610         std::vector<Stmt> stores;
611         stores.emplace_back(store);
612         Stmt rest = collect_strided_stores(op->rest, store->name,
613                                            stride, expected_stores,
614                                            let_stmts, stores);
615 
616         // Check the store collector didn't collect too many
617         // stores (that would be a bug).
618         internal_assert(stores.size() <= (size_t)expected_stores);
619 
620         // Not enough stores collected.
621         if (stores.size() != (size_t)expected_stores) return Stmt();
622 
623         Type t = store->value.type();
624         Expr base;
625         std::vector<Expr> args(stores.size());
626         std::vector<Expr> predicates(stores.size());
627 
628         int min_offset = 0;
629         std::vector<int> offsets(stores.size());
630 
631         std::string load_name;
632         Buffer<> load_image;
633         Parameter load_param;
634         for (size_t i = 0; i < stores.size(); ++i) {
635             const Ramp *ri = stores[i].as<Store>()->index.as<Ramp>();
636             internal_assert(ri);
637 
638             // Mismatched store vector laness.
639             if (ri->lanes != lanes) return Stmt();
640 
641             Expr diff = simplify(ri->base - r0->base);
642             const int64_t *offs = as_const_int(diff);
643 
644             // Difference between bases is not constant.
645             if (!offs) return Stmt();
646 
647             offsets[i] = *offs;
648             if (*offs < min_offset) {
649                 min_offset = *offs;
650             }
651 
652             if (stride == 1) {
653                 // Difference between bases is not a multiple of the lanes.
654                 if (*offs % lanes != 0) return Stmt();
655 
656                 // This case only triggers if we have an immediate load of the correct stride on the RHS.
657                 // TODO: Could we consider mutating the RHS so that we can handle more complex Expr's than just loads?
658                 const Load *load = stores[i].as<Store>()->value.as<Load>();
659                 if (!load) return Stmt();
660                 // TODO(psuriana): Predicated load is not currently handled.
661                 if (!is_one(load->predicate)) return Stmt();
662 
663                 const Ramp *ramp = load->index.as<Ramp>();
664                 if (!ramp) return Stmt();
665 
666                 // Load stride or lanes is not equal to the store lanes.
667                 if (!is_const(ramp->stride, lanes) || ramp->lanes != lanes) return Stmt();
668 
669                 if (i == 0) {
670                     load_name = load->name;
671                     load_image = load->image;
672                     load_param = load->param;
673                 } else {
674                     if (load->name != load_name) return Stmt();
675                 }
676             }
677         }
678 
679         // Gather the args for interleaving.
680         for (size_t i = 0; i < stores.size(); ++i) {
681             int j = offsets[i] - min_offset;
682             if (stride == 1) {
683                 j /= stores.size();
684             }
685 
686             if (j == 0) {
687                 base = stores[i].as<Store>()->index.as<Ramp>()->base;
688             }
689 
690             // The offset is not between zero and the stride.
691             if (j < 0 || (size_t)j >= stores.size()) return Stmt();
692 
693             // We already have a store for this offset.
694             if (args[j].defined()) return Stmt();
695 
696             if (stride == 1) {
697                 // Convert multiple dense vector stores of strided vector loads
698                 // into one dense vector store of interleaving dense vector loads.
699                 args[j] = Load::make(t, load_name, stores[i].as<Store>()->index,
700                                      load_image, load_param, const_true(t.lanes()), ModulusRemainder());
701             } else {
702                 args[j] = stores[i].as<Store>()->value;
703             }
704             predicates[j] = stores[i].as<Store>()->predicate;
705         }
706 
707         // One of the stores should have had the minimum offset.
708         internal_assert(base.defined());
709 
710         // Generate a single interleaving store.
711         t = t.with_lanes(lanes * stores.size());
712         Expr index = Ramp::make(base, make_one(base.type()), t.lanes());
713         Expr value = Shuffle::make_interleave(args);
714         Expr predicate = Shuffle::make_interleave(predicates);
715         Stmt new_store = Store::make(store->name, value, index, store->param, predicate, ModulusRemainder());
716 
717         // Continue recursively into the stuff that
718         // collect_strided_stores didn't collect.
719         Stmt stmt = Block::make(new_store, mutate(rest));
720 
721         // Rewrap the let statements we pulled off.
722         while (!let_stmts.empty()) {
723             const LetStmt *let = let_stmts.back().as<LetStmt>();
724             stmt = LetStmt::make(let->name, let->value, stmt);
725             let_stmts.pop_back();
726         }
727 
728         // Success!
729         return stmt;
730     }
731 
visit(const Block * op)732     Stmt visit(const Block *op) override {
733         Stmt s = gather_stores(op);
734         if (s.defined()) {
735             return s;
736         } else {
737             Stmt first = mutate(op->first);
738             Stmt rest = mutate(op->rest);
739             if (first.same_as(op->first) && rest.same_as(op->rest)) {
740                 return op;
741             } else {
742                 return Block::make(first, rest);
743             }
744         }
745     }
746 
747 public:
Interleaver()748     Interleaver()
749         : should_deinterleave(false) {
750     }
751 };
752 
753 }  // namespace
754 
// Public entry point: rewrite strided loads/stores in s as interleavings.
Stmt rewrite_interleavings(const Stmt &s) {
    Interleaver interleaver;
    return interleaver.mutate(s);
}
758 
759 namespace {
check(Expr a,const Expr & even,const Expr & odd)760 void check(Expr a, const Expr &even, const Expr &odd) {
761     a = simplify(a);
762     Expr correct_even = extract_even_lanes(a);
763     Expr correct_odd = extract_odd_lanes(a);
764     if (!equal(correct_even, even)) {
765         internal_error << correct_even << " != " << even << "\n";
766     }
767     if (!equal(correct_odd, odd)) {
768         internal_error << correct_odd << " != " << odd << "\n";
769     }
770 }
771 }  // namespace
772 
deinterleave_vector_test()773 void deinterleave_vector_test() {
774     std::pair<Expr, Expr> result;
775     Expr x = Variable::make(Int(32), "x");
776     Expr ramp = Ramp::make(x + 4, 3, 8);
777     Expr ramp_a = Ramp::make(x + 4, 6, 4);
778     Expr ramp_b = Ramp::make(x + 7, 6, 4);
779     Expr broadcast = Broadcast::make(x + 4, 16);
780     Expr broadcast_a = Broadcast::make(x + 4, 8);
781     const Expr &broadcast_b = broadcast_a;
782 
783     check(ramp, ramp_a, ramp_b);
784     check(broadcast, broadcast_a, broadcast_b);
785 
786     check(Load::make(ramp.type(), "buf", ramp, Buffer<>(), Parameter(), const_true(ramp.type().lanes()), ModulusRemainder()),
787           Load::make(ramp_a.type(), "buf", ramp_a, Buffer<>(), Parameter(), const_true(ramp_a.type().lanes()), ModulusRemainder()),
788           Load::make(ramp_b.type(), "buf", ramp_b, Buffer<>(), Parameter(), const_true(ramp_b.type().lanes()), ModulusRemainder()));
789 
790     Expr vec_x = Variable::make(Int(32, 4), "vec_x");
791     Expr vec_y = Variable::make(Int(32, 4), "vec_y");
792     check(Shuffle::make({vec_x, vec_y}, {0, 4, 2, 6, 4, 2, 3, 7, 1, 2, 3, 4}),
793           Shuffle::make({vec_x, vec_y}, {0, 2, 4, 3, 1, 3}),
794           Shuffle::make({vec_x, vec_y}, {4, 6, 2, 7, 2, 4}));
795 
796     std::cout << "deinterleave_vector test passed" << std::endl;
797 }
798 
799 }  // namespace Internal
800 }  // namespace Halide
801