1 #include "Halide.h"
2 
3 using namespace Halide;
4 
5 namespace {
6 
// Constrains the layout of a buffer-like generator parameter `t` (used here
// with Input<Buffer<...>> and Output<Buffer<...>>) so that dimension 2 is the
// innermost, dense dimension holding exactly three elements, and each step
// along dimension 0 skips over one full group of those three elements.
template<typename T>
void set_interleaved(T &t) {
    // Moving one element along dimension 0 advances by a whole group of 3.
    t.dim(0).set_stride(3);
    // Dimension 2 starts at 0, has extent 3 and is contiguous in memory.
    t.dim(2).set_min(0).set_extent(3).set_stride(1);
}
11 
12 // Add two inputs
13 class NestedExternsCombine : public Generator<NestedExternsCombine> {
14 public:
15     Input<Buffer<float>> input_a{"input_a", 3};
16     Input<Buffer<float>> input_b{"input_b", 3};
17     Output<Buffer<>> combine{"combine"};  // unspecified type-and-dim will be inferred
18 
generate()19     void generate() {
20         Var x, y, c;
21         combine(x, y, c) = input_a(x, y, c) + input_b(x, y, c);
22     }
23 
schedule()24     void schedule() {
25         set_interleaved(input_a);
26         set_interleaved(input_b);
27         set_interleaved(combine);
28     }
29 };
30 
31 // Call two extern stages then pass the two results to another extern stage.
32 class NestedExternsInner : public Generator<NestedExternsInner> {
33 public:
34     Input<float> value{"value", 1.0f};
35     Output<Buffer<float>> inner{"inner", 3};
36 
generate()37     void generate() {
38         extern_stage_1.define_extern("nested_externs_leaf", {value}, Float(32), 3);
39         extern_stage_2.define_extern("nested_externs_leaf", {value + 1}, Float(32), 3);
40         extern_stage_combine.define_extern("nested_externs_combine",
41                                            {extern_stage_1, extern_stage_2}, Float(32), 3);
42         inner(x, y, c) = extern_stage_combine(x, y, c);
43     }
44 
schedule()45     void schedule() {
46         for (Func f : {extern_stage_1, extern_stage_2, extern_stage_combine}) {
47             auto args = f.args();
48             f.compute_root().reorder_storage(args[2], args[0], args[1]);
49         }
50         set_interleaved(inner);
51     }
52 
53 private:
54     Var x, y, c;
55     Func extern_stage_1, extern_stage_2, extern_stage_combine;
56 };
57 
58 // Basically a memset.
59 class NestedExternsLeaf : public Generator<NestedExternsLeaf> {
60 public:
61     Input<float> value{"value", 1.0f};
62     Output<Buffer<float>> leaf{"leaf", 3};
63 
generate()64     void generate() {
65         Var x, y, c;
66         leaf(x, y, c) = value;
67     }
68 
schedule()69     void schedule() {
70         set_interleaved(leaf);
71     }
72 };
73 
74 // Call two extern stages then pass the two results to another extern stage.
75 class NestedExternsRoot : public Generator<NestedExternsRoot> {
76 public:
77     Input<float> value{"value", 1.0f};
78     Output<Buffer<float>> root{"root", 3};
79 
generate()80     void generate() {
81         extern_stage_1.define_extern("nested_externs_inner", {value}, Float(32), 3);
82         extern_stage_2.define_extern("nested_externs_inner", {value + 1}, Float(32), 3);
83         extern_stage_combine.define_extern("nested_externs_combine",
84                                            {extern_stage_1, extern_stage_2}, Float(32), 3);
85         root(x, y, c) = extern_stage_combine(x, y, c);
86     }
87 
schedule()88     void schedule() {
89         for (Func f : {extern_stage_1, extern_stage_2, extern_stage_combine}) {
90             auto args = f.args();
91             f.compute_at(root, y).reorder_storage(args[2], args[0], args[1]);
92         }
93         set_interleaved(root);
94         root.reorder_storage(c, x, y);
95     }
96 
97 private:
98     Var x, y, c;
99     Func extern_stage_1, extern_stage_2, extern_stage_combine;
100 };
101 
102 }  // namespace
103 
// Register each generator under a name. The define_extern calls in
// NestedExternsInner and NestedExternsRoot refer to "nested_externs_leaf",
// "nested_externs_combine" and "nested_externs_inner" by these names.
// NOTE(review): this assumes the build emits each pipeline's function under
// its registered name. Confirm against the build rules.
HALIDE_REGISTER_GENERATOR(NestedExternsCombine, nested_externs_combine)
HALIDE_REGISTER_GENERATOR(NestedExternsInner, nested_externs_inner)
HALIDE_REGISTER_GENERATOR(NestedExternsLeaf, nested_externs_leaf)
HALIDE_REGISTER_GENERATOR(NestedExternsRoot, nested_externs_root)
108