#include "Halide.h"

#include <string>
#include <vector>

namespace {

std::vector<Halide::Func> func_vector(const std::string &name, int size) {
    std::vector<Halide::Func> funcs;
    for (int i = 0; i < size; i++) {
        funcs.emplace_back(Halide::Func{name + "_" + std::to_string(i)});
    }
    return funcs;
}

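// Interpolate the colors of a partially-specified image: the input is
// RGBA where alpha acts as a per-pixel weight, and the output fills in
// low-weight regions by blending the image down and back up a pyramid,
// then dividing out the accumulated weight.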
class Interpolate : public Halide::Generator<Interpolate> {
public:
    GeneratorParam<int> levels{"levels", 10};

    Input<Buffer<float>> input{"input", 3};
    Output<Buffer<float>> output{"output", 3};

    void generate() {
        Var x("x"), y("y"), c("c");

        // Input must have four color channels (RGBA).
        input.dim(2).set_bounds(0, 4);

        auto downsampled = func_vector("downsampled", levels);
        auto downx = func_vector("downx", levels);
        auto interpolated = func_vector("interpolated", levels);
        auto upsampled = func_vector("upsampled", levels);
        auto upsampledx = func_vector("upsampledx", levels);

        Func clamped = Halide::BoundaryConditions::repeat_edge(input);

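        // Premultiply color by the alpha weight at the base of the
        // pyramid, so that downsampling blends weighted colors and the
        // weights themselves in a single pass.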
        downsampled[0](x, y, c) = select(c < 3, clamped(x, y, c) * clamped(x, y, 3), clamped(x, y, 3));

        for (int l = 1; l < levels; ++l) {
            Func prev = downsampled[l - 1];

            if (l == 4) {
                // Also add a boundary condition at a middle pyramid
                // level, to prevent the footprint of the downsamplings
                // from extending too far off the base image. Otherwise
                // we look 512 pixels off each edge.
                Expr w = input.width() / (1 << l);
                Expr h = input.height() / (1 << l);
                prev = lambda(x, y, c, prev(clamp(x, 0, w), clamp(y, 0, h), c));
            }

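            // Downsample by a factor of two in each dimension using a
            // separable [1 2 1]/4 binomial filter: first in x, then in y.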
            downx[l](x, y, c) = (prev(x * 2 - 1, y, c) +
                                 2.0f * prev(x * 2, y, c) +
                                 prev(x * 2 + 1, y, c)) *
                                0.25f;
            downsampled[l](x, y, c) = (downx[l](x, y * 2 - 1, c) +
                                       2.0f * downx[l](x, y * 2, c) +
                                       downx[l](x, y * 2 + 1, c)) *
                                      0.25f;
        }
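        // Walk back up the pyramid from the coarsest level, bilinearly
        // upsampling each coarser result and compositing the current
        // level's premultiplied color over it, weighted by the alpha
        // still missing at this level.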
        interpolated[levels - 1](x, y, c) = downsampled[levels - 1](x, y, c);
        for (int l = levels - 2; l >= 0; --l) {
            upsampledx[l](x, y, c) = (interpolated[l + 1](x / 2, y, c) +
                                      interpolated[l + 1]((x + 1) / 2, y, c)) /
                                     2.0f;
            upsampled[l](x, y, c) = (upsampledx[l](x, y / 2, c) +
                                     upsampledx[l](x, (y + 1) / 2, c)) /
                                    2.0f;
            Expr alpha = 1.0f - downsampled[l](x, y, 3);
            interpolated[l](x, y, c) = (downsampled[l](x, y, c) +
                                        alpha * upsampled[l](x, y, c));
        }

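        // Un-premultiply: divide the interpolated color by the
        // interpolated weight to recover the final color.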
        Func normalize("normalize");
        normalize(x, y, c) = interpolated[0](x, y, c) / interpolated[0](x, y, 3);

        // Schedule
        if (auto_schedule) {
            output = normalize;
        } else {
            // 0.86ms on a 2060 RTX
            Var yo, yi, xo, xi, ci, xii, yii;
            if (get_target().has_gpu_feature()) {
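                // Tile the output into 32x32 blocks of 16x16 threads,
                // with each thread computing an unrolled 2x2 footprint
                // over all three output channels.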
                normalize
                    .bound(x, 0, input.width())
                    .bound(y, 0, input.height())
                    .bound(c, 0, 3)
                    .reorder(c, x, y)
                    .tile(x, y, xi, yi, 32, 32, TailStrategy::RoundUp)
                    .tile(xi, yi, xii, yii, 2, 2)
                    .gpu_blocks(x, y)
                    .gpu_threads(xi, yi)
                    .unroll(xii)
                    .unroll(yii)
                    .unroll(c);

                for (int l = 1; l < levels; l++) {
                    downsampled[l]
                        .compute_root()
                        .reorder(c, x, y)
                        .unroll(c)
                        .gpu_tile(x, y, xi, yi, 16, 16);
                }

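                // Compute every other interpolated level (3, 5, 7, ...)
                // at root; the levels in between are cheap enough to be
                // inlined into their consumers, or are scheduled inside
                // normalize below.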
                for (int l = 3; l < levels; l += 2) {
                    interpolated[l]
                        .compute_root()
                        .reorder(c, x, y)
                        .tile(x, y, xi, yi, 32, 32, TailStrategy::RoundUp)
                        .tile(xi, yi, xii, yii, 2, 2)
                        .gpu_blocks(x, y)
                        .gpu_threads(xi, yi)
                        .unroll(xii)
                        .unroll(yii)
                        .unroll(c);
                }

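                // The finest levels are computed per block within
                // normalize, so their intermediate results can stay
                // on-chip (in GPU shared memory).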
                upsampledx[1]
                    .compute_at(normalize, x)
                    .reorder(c, x, y)
                    .tile(x, y, xi, yi, 2, 1)
                    .unroll(xi)
                    .unroll(yi)
                    .unroll(c)
                    .gpu_threads(x, y);

                interpolated[1]
                    .compute_at(normalize, x)
                    .reorder(c, x, y)
                    .tile(x, y, xi, yi, 2, 2)
                    .unroll(xi)
                    .unroll(yi)
                    .unroll(c)
                    .gpu_threads(x, y);

                interpolated[2]
                    .compute_at(normalize, x)
                    .reorder(c, x, y)
                    .unroll(c)
                    .gpu_threads(x, y);

                output = normalize;
            } else {
                // 4.54ms on an Intel i9-9960X using 16 threads
                Var xo, xi, yo, yi;
                const int vec = natural_vector_size<float>();
                for (int l = 1; l < levels - 1; ++l) {
                    // The downsampled stages are referenced again
                    // during upsampling, so they must all be
                    // compute_root or redundantly recomputed, as in
                    // the local_laplacian app.
                    downsampled[l]
                        .compute_root()
                        .reorder(x, c, y)
                        .split(y, yo, yi, 8)
                        .parallel(yo)
                        .vectorize(x, vec);
                }

                // downsampled[0] takes too long to compute_root, so
                // we'll redundantly recompute it instead. Make a
                // separate clone of it for the first downsampling
                // stage so that we can schedule the two versions
                // separately.
                downsampled[0]
                    .clone_in(downx[1])
                    .store_at(downsampled[1], yo)
                    .compute_at(downsampled[1], yi)
                    .reorder(c, x, y)
                    .unroll(c)
                    .vectorize(x, vec);

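                // Vectorize the output in x, unroll the channels, and
                // parallelize across 32-scanline strips. The producers
                // below are stored per strip and computed per scanline,
                // which gives sliding-window reuse between scanlines.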
                normalize
                    .bound(x, 0, input.width())
                    .bound(y, 0, input.height())
                    .bound(c, 0, 3)
                    .split(x, xo, xi, vec)
                    .split(y, yo, yi, 32)
                    .reorder(xi, c, xo, yi, yo)
                    .unroll(c)
                    .vectorize(xi)
                    .parallel(yo);

                downsampled[0]
                    .store_at(normalize, yo)
                    .compute_at(normalize, yi)
                    .reorder(c, x, y)
                    .unroll(c)
                    .vectorize(x, vec);

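                // The interpolated levels are likewise stored per strip
                // and computed per scanline as normalize consumes them.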
                for (int l = 1; l < levels; l++) {
                    interpolated[l]
                        .store_at(normalize, yo)
                        .compute_at(normalize, yi)
                        .vectorize(x, vec);
                }

                output = normalize;
            }
        }

        // Estimates (for the autoscheduler; ignored otherwise)
        {
            input.dim(0).set_estimate(0, 1536);
            input.dim(1).set_estimate(0, 2560);
            input.dim(2).set_estimate(0, 4);
            output.dim(0).set_estimate(0, 1536);
            output.dim(1).set_estimate(0, 2560);
            output.dim(2).set_estimate(0, 3);
        }
    }
};

}  // namespace

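// Expose this generator to the Halide generator driver under the name
// "interpolate".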
HALIDE_REGISTER_GENERATOR(Interpolate, interpolate)