1 #include "Halide.h"
2
3 using namespace Halide;
4
5 namespace {
6
// Constrain the layout of a buffer-like object: dimension 0 is given
// stride 3, and dimension 2 is given min 0, extent 3 and stride 1.
// T is any type exposing the chained dim()/set_*() calls used below; in
// this file it is applied to generator inputs and outputs.
template<typename T>
void set_interleaved(T &t) {
    // The extent of dimension 2 is also the stride of dimension 0.
    constexpr int kChannels = 3;
    t.dim(0)
        .set_stride(kChannels)
        .dim(2)
        .set_min(0)
        .set_extent(kChannels)
        .set_stride(1);
}
11
12 // Add two inputs
13 class NestedExternsCombine : public Generator<NestedExternsCombine> {
14 public:
15 Input<Buffer<float>> input_a{"input_a", 3};
16 Input<Buffer<float>> input_b{"input_b", 3};
17 Output<Buffer<>> combine{"combine"}; // unspecified type-and-dim will be inferred
18
generate()19 void generate() {
20 Var x, y, c;
21 combine(x, y, c) = input_a(x, y, c) + input_b(x, y, c);
22 }
23
schedule()24 void schedule() {
25 set_interleaved(input_a);
26 set_interleaved(input_b);
27 set_interleaved(combine);
28 }
29 };
30
31 // Call two extern stages then pass the two results to another extern stage.
32 class NestedExternsInner : public Generator<NestedExternsInner> {
33 public:
34 Input<float> value{"value", 1.0f};
35 Output<Buffer<float>> inner{"inner", 3};
36
generate()37 void generate() {
38 extern_stage_1.define_extern("nested_externs_leaf", {value}, Float(32), 3);
39 extern_stage_2.define_extern("nested_externs_leaf", {value + 1}, Float(32), 3);
40 extern_stage_combine.define_extern("nested_externs_combine",
41 {extern_stage_1, extern_stage_2}, Float(32), 3);
42 inner(x, y, c) = extern_stage_combine(x, y, c);
43 }
44
schedule()45 void schedule() {
46 for (Func f : {extern_stage_1, extern_stage_2, extern_stage_combine}) {
47 auto args = f.args();
48 f.compute_root().reorder_storage(args[2], args[0], args[1]);
49 }
50 set_interleaved(inner);
51 }
52
53 private:
54 Var x, y, c;
55 Func extern_stage_1, extern_stage_2, extern_stage_combine;
56 };
57
58 // Basically a memset.
59 class NestedExternsLeaf : public Generator<NestedExternsLeaf> {
60 public:
61 Input<float> value{"value", 1.0f};
62 Output<Buffer<float>> leaf{"leaf", 3};
63
generate()64 void generate() {
65 Var x, y, c;
66 leaf(x, y, c) = value;
67 }
68
schedule()69 void schedule() {
70 set_interleaved(leaf);
71 }
72 };
73
74 // Call two extern stages then pass the two results to another extern stage.
75 class NestedExternsRoot : public Generator<NestedExternsRoot> {
76 public:
77 Input<float> value{"value", 1.0f};
78 Output<Buffer<float>> root{"root", 3};
79
generate()80 void generate() {
81 extern_stage_1.define_extern("nested_externs_inner", {value}, Float(32), 3);
82 extern_stage_2.define_extern("nested_externs_inner", {value + 1}, Float(32), 3);
83 extern_stage_combine.define_extern("nested_externs_combine",
84 {extern_stage_1, extern_stage_2}, Float(32), 3);
85 root(x, y, c) = extern_stage_combine(x, y, c);
86 }
87
schedule()88 void schedule() {
89 for (Func f : {extern_stage_1, extern_stage_2, extern_stage_combine}) {
90 auto args = f.args();
91 f.compute_at(root, y).reorder_storage(args[2], args[0], args[1]);
92 }
93 set_interleaved(root);
94 root.reorder_storage(c, x, y);
95 }
96
97 private:
98 Var x, y, c;
99 Func extern_stage_1, extern_stage_2, extern_stage_combine;
100 };
101
102 } // namespace
103
// Register each generator class under a name. The names
// nested_externs_combine, nested_externs_inner and nested_externs_leaf are
// the same strings passed to define_extern() by the generators above.
104 HALIDE_REGISTER_GENERATOR(NestedExternsCombine, nested_externs_combine)
105 HALIDE_REGISTER_GENERATOR(NestedExternsInner, nested_externs_inner)
106 HALIDE_REGISTER_GENERATOR(NestedExternsLeaf, nested_externs_leaf)
107 HALIDE_REGISTER_GENERATOR(NestedExternsRoot, nested_externs_root)
108