1 #include "Halide.h"
2 #include "halide_trace_config.h"
3 
4 namespace {
5 
6 class BilateralGrid : public Halide::Generator<BilateralGrid> {
7 public:
8     GeneratorParam<int> s_sigma{"s_sigma", 8};
9 
10     Input<Buffer<float>> input{"input", 2};
11     Input<float> r_sigma{"r_sigma"};
12 
13     Output<Buffer<float>> bilateral_grid{"bilateral_grid", 2};
14 
generate()15     void generate() {
16         Var x("x"), y("y"), z("z"), c("c");
17 
18         // Add a boundary condition
19         Func clamped = Halide::BoundaryConditions::repeat_edge(input);
20 
21         // Construct the bilateral grid
22         RDom r(0, s_sigma, 0, s_sigma);
23         Expr val = clamped(x * s_sigma + r.x - s_sigma / 2, y * s_sigma + r.y - s_sigma / 2);
24         val = clamp(val, 0.0f, 1.0f);
25 
26         Expr zi = cast<int>(val * (1.0f / r_sigma) + 0.5f);
27 
28         Func histogram("histogram");
29         histogram(x, y, z, c) = 0.0f;
30         histogram(x, y, zi, c) += mux(c, {val, 1.0f});
31 
32         // Blur the grid using a five-tap filter
33         Func blurx("blurx"), blury("blury"), blurz("blurz");
34         blurz(x, y, z, c) = (histogram(x, y, z - 2, c) +
35                              histogram(x, y, z - 1, c) * 4 +
36                              histogram(x, y, z, c) * 6 +
37                              histogram(x, y, z + 1, c) * 4 +
38                              histogram(x, y, z + 2, c));
39         blurx(x, y, z, c) = (blurz(x - 2, y, z, c) +
40                              blurz(x - 1, y, z, c) * 4 +
41                              blurz(x, y, z, c) * 6 +
42                              blurz(x + 1, y, z, c) * 4 +
43                              blurz(x + 2, y, z, c));
44         blury(x, y, z, c) = (blurx(x, y - 2, z, c) +
45                              blurx(x, y - 1, z, c) * 4 +
46                              blurx(x, y, z, c) * 6 +
47                              blurx(x, y + 1, z, c) * 4 +
48                              blurx(x, y + 2, z, c));
49 
50         // Take trilinear samples to compute the output
51         val = clamp(input(x, y), 0.0f, 1.0f);
52         Expr zv = val * (1.0f / r_sigma);
53         zi = cast<int>(zv);
54         Expr zf = zv - zi;
55         Expr xf = cast<float>(x % s_sigma) / s_sigma;
56         Expr yf = cast<float>(y % s_sigma) / s_sigma;
57         Expr xi = x / s_sigma;
58         Expr yi = y / s_sigma;
59         Func interpolated("interpolated");
60         interpolated(x, y, c) =
61             lerp(lerp(lerp(blury(xi, yi, zi, c), blury(xi + 1, yi, zi, c), xf),
62                       lerp(blury(xi, yi + 1, zi, c), blury(xi + 1, yi + 1, zi, c), xf), yf),
63                  lerp(lerp(blury(xi, yi, zi + 1, c), blury(xi + 1, yi, zi + 1, c), xf),
64                       lerp(blury(xi, yi + 1, zi + 1, c), blury(xi + 1, yi + 1, zi + 1, c), xf), yf),
65                  zf);
66 
67         // Normalize
68         bilateral_grid(x, y) = interpolated(x, y, 0) / interpolated(x, y, 1);
69 
70         /* ESTIMATES */
71         // (This can be useful in conjunction with RunGen and benchmarks as well
72         // as auto-schedule, so we do it in all cases.)
73         // Provide estimates on the input image
74         input.set_estimates({{0, 1536}, {0, 2560}});
75         // Provide estimates on the parameters
76         r_sigma.set_estimate(0.1f);
77         // TODO: Compute estimates from the parameter values
78         histogram.set_estimate(z, -2, 16);
79         blurz.set_estimate(z, 0, 12);
80         blurx.set_estimate(z, 0, 12);
81         blury.set_estimate(z, 0, 12);
82         bilateral_grid.set_estimates({{0, 1536}, {0, 2560}});
83 
84         if (auto_schedule) {
85             // nothing
86         } else if (get_target().has_gpu_feature()) {
87             // 0.50ms on an RTX 2060
88 
89             Var xi("xi"), yi("yi"), zi("zi");
90 
91             // Schedule blurz in 8x8 tiles. This is a tile in
92             // grid-space, which means it represents something like
93             // 64x64 pixels in the input (if s_sigma is 8).
94             blurz.compute_root().reorder(c, z, x, y).gpu_tile(x, y, xi, yi, 8, 8);
95 
96             // Schedule histogram to happen per-tile of blurz, with
97             // intermediate results in shared memory. This means histogram
98             // and blurz makes a three-stage kernel:
99             // 1) Zero out the 8x8 set of histograms
100             // 2) Compute those histogram by iterating over lots of the input image
101             // 3) Blur the set of histograms in z
102             histogram.reorder(c, z, x, y).compute_at(blurz, x).gpu_threads(x, y);
103             histogram.update().reorder(c, r.x, r.y, x, y).gpu_threads(x, y).unroll(c);
104 
105             // Schedule the remaining blurs and the sampling at the end similarly.
106             blurx
107                 .compute_root()
108                 .reorder(c, x, y, z)
109                 .reorder_storage(c, x, y, z)
110                 .vectorize(c)
111                 .unroll(y, 2, TailStrategy::RoundUp)
112                 .gpu_tile(x, y, z, xi, yi, zi, 32, 8, 1, TailStrategy::RoundUp);
113             blury
114                 .compute_root()
115                 .reorder(c, x, y, z)
116                 .reorder_storage(c, x, y, z)
117                 .vectorize(c)
118                 .unroll(y, 2, TailStrategy::RoundUp)
119                 .gpu_tile(x, y, z, xi, yi, zi, 32, 8, 1, TailStrategy::RoundUp);
120             bilateral_grid.compute_root().gpu_tile(x, y, xi, yi, 32, 8);
121             interpolated.compute_at(bilateral_grid, xi).vectorize(c);
122         } else {
123             // CPU schedule.
124 
125             // 3.98ms on an Intel i9-9960X using 32 threads at 3.7 GHz
126             // using target x86-64-avx2. This is a little less
127             // SIMD-friendly than some of the other apps, so we
128             // benefit from hyperthreading, and don't benefit from
129             // AVX-512, which on my machine reduces the clock to 3.0
130             // GHz.
131 
132             blurz.compute_root()
133                 .reorder(c, z, x, y)
134                 .parallel(y)
135                 .vectorize(x, 8)
136                 .unroll(c);
137             histogram.compute_at(blurz, y);
138             histogram.update()
139                 .reorder(c, r.x, r.y, x, y)
140                 .unroll(c);
141             blurx.compute_root()
142                 .reorder(c, x, y, z)
143                 .parallel(z)
144                 .vectorize(x, 8)
145                 .unroll(c);
146             blury.compute_root()
147                 .reorder(c, x, y, z)
148                 .parallel(z)
149                 .vectorize(x, 8)
150                 .unroll(c);
151             bilateral_grid.compute_root()
152                 .parallel(y)
153                 .vectorize(x, 8);
154         }
155 
156         /* Optional tags to specify layout for HalideTraceViz */
157         {
158             Halide::Trace::FuncConfig cfg;
159             cfg.pos.x = 100;
160             cfg.pos.y = 300;
161             input.add_trace_tag(cfg.to_trace_tag());
162 
163             cfg.pos.x = 1564;
164             bilateral_grid.add_trace_tag(cfg.to_trace_tag());
165         }
166         {
167             Halide::Trace::FuncConfig cfg;
168             cfg.strides = {{1, 0}, {0, 1}, {40, 0}};
169             cfg.zoom = 3;
170 
171             cfg.max = 32;
172             cfg.pos.x = 550;
173             cfg.pos.y = 100;
174             histogram.add_trace_tag(cfg.to_trace_tag());
175 
176             cfg.max = 512;
177             cfg.pos.y = 300;
178             blurz.add_trace_tag(cfg.to_trace_tag());
179 
180             cfg.max = 8192;
181             cfg.pos.y = 500;
182             blurx.add_trace_tag(cfg.to_trace_tag());
183 
184             cfg.max = 131072;
185             cfg.pos.y = 700;
186             blury.add_trace_tag(cfg.to_trace_tag());
187         }
188         {
189             // GlobalConfig applies to the entire visualization pipeline;
190             // you can set this tag on any Func that is realized, but only
191             // the last one seen will be used. (Since the tags are emitted in
192             // an arbitrary order, emitting only one such tag is the best practice).
193             // Note also that since the global settings are often context-dependent
194             // (eg the output size and timestep may vary depending on the
195             // input data), it's often more useful to specify these on the
196             // command line.
197             Halide::Trace::GlobalConfig global_cfg;
198             global_cfg.timestep = 1000;
199 
200             bilateral_grid.add_trace_tag(global_cfg.to_trace_tag());
201         }
202     }
203 };
204 
205 }  // namespace
206 
// Register the BilateralGrid generator class with Halide under the
// registry name "bilateral_grid".
HALIDE_REGISTER_GENERATOR(BilateralGrid, bilateral_grid)
208