1 #include "Halide.h"
2 
3 namespace {
4 
5 using namespace Halide;
6 
7 class NonLocalMeans : public Halide::Generator<NonLocalMeans> {
8 public:
9     Input<Buffer<float>> input{"input", 3};
10     Input<int> patch_size{"patch_size"};
11     Input<int> search_area{"search_area"};
12     Input<float> sigma{"sigma"};
13 
14     Output<Buffer<float>> non_local_means{"non_local_means", 3};
15 
generate()16     void generate() {
17         /* THE ALGORITHM */
18 
19         // This implements the basic description of non-local means found at
20         // https://en.wikipedia.org/wiki/Non-local_means
21 
22         Var x("x"), y("y"), c("c");
23 
24         Expr inv_sigma_sq = -1.0f / (sigma * sigma * patch_size * patch_size);
25 
26         // Add a boundary condition
27         Func clamped = BoundaryConditions::repeat_edge(input);
28 
29         // Define the difference images
30         Var dx("dx"), dy("dy");
31         Func dc("d");
32         dc(x, y, dx, dy, c) = pow(clamped(x, y, c) - clamped(x + dx, y + dy, c), 2);
33 
34         // Sum across color channels
35         RDom channels(0, 3);
36         Func d("d");
37         d(x, y, dx, dy) = sum(dc(x, y, dx, dy, channels));
38 
39         // Find the patch differences by blurring the difference images
40         RDom patch_dom(-(patch_size / 2), patch_size);
41         Func blur_d_y("blur_d_y");
42         blur_d_y(x, y, dx, dy) = sum(d(x, y + patch_dom, dx, dy));
43 
44         Func blur_d("blur_d");
45         blur_d(x, y, dx, dy) = sum(blur_d_y(x + patch_dom, y, dx, dy));
46 
47         // Compute the weights from the patch differences
48         Func w("w");
49         w(x, y, dx, dy) = fast_exp(blur_d(x, y, dx, dy) * inv_sigma_sq);
50 
51         // Add an alpha channel
52         Func clamped_with_alpha("clamped_with_alpha");
53         clamped_with_alpha(x, y, c) = mux(c, {clamped(x, y, 0), clamped(x, y, 1), clamped(x, y, 2), 1.0f});
54 
55         // Define a reduction domain for the search area
56         RDom s_dom(-(search_area / 2), search_area, -(search_area / 2), search_area);
57 
58         // Compute the sum of the pixels in the search area
59         Func non_local_means_sum("non_local_means_sum");
60         non_local_means_sum(x, y, c) += w(x, y, s_dom.x, s_dom.y) * clamped_with_alpha(x + s_dom.x, y + s_dom.y, c);
61 
62         non_local_means(x, y, c) =
63             clamp(non_local_means_sum(x, y, c) / non_local_means_sum(x, y, 3), 0.0f, 1.0f);
64 
65         /* THE SCHEDULE */
66 
67         // Require 3 channels for output
68         non_local_means.dim(2).set_bounds(0, 3);
69 
70         Var tx("tx"), ty("ty"), xi("xi"), yi("yi");
71 
72         /* ESTIMATES */
73         // (This can be useful in conjunction with RunGen and benchmarks as well
74         // as auto-schedule, so we do it in all cases.)
75         // Provide estimates on the input image
76         input.set_estimates({{0, 1536}, {0, 2560}, {0, 3}});
77         // Provide estimates on the parameters
78         patch_size.set_estimate(7);
79         search_area.set_estimate(7);
80         sigma.set_estimate(0.12f);
81         // Provide estimates on the output pipeline
82         non_local_means.set_estimates({{0, 1536}, {0, 2560}, {0, 3}});
83 
84         if (auto_schedule) {
85             // nothing
86         } else if (get_target().has_gpu_feature()) {
87             // 22 ms on a 2060 RTX
88             Var xii, yii;
89 
90             // We'll use 32x16 thread blocks throughout. This was
91             // found by just trying lots of sizes, but large thread
92             // blocks are particularly good in the blur_d stage to
93             // avoid doing wasted blurring work at tile boundaries
94             // (especially for large patch sizes).
95 
96             non_local_means.compute_root()
97                 .reorder(c, x, y)
98                 .unroll(c)
99                 .gpu_tile(x, y, xi, yi, 32, 16);
100 
101             non_local_means_sum.compute_root()
102                 .gpu_tile(x, y, xi, yi, 32, 16)
103                 .update()
104                 .reorder(c, s_dom.x, x, y, s_dom.y)
105                 .tile(x, y, xi, yi, 32, 16)
106                 .gpu_blocks(x, y)
107                 .gpu_threads(xi, yi)
108                 .unroll(c);
109 
110             // The patch size we're benchmarking for is 7, which
111             // implies an expansion of 6 pixels for footprint of the
112             // blur, so we'll size tiles of blur_d to be a multiple of
113             // the thread block size minus 6.
114             blur_d.compute_at(non_local_means_sum, s_dom.y)
115                 .tile(x, y, xi, yi, 128 - 6, 32 - 6)
116                 .tile(xi, yi, xii, yii, 32, 16)
117                 .gpu_threads(xii, yii)
118                 .gpu_blocks(x, y, dx);
119 
120             blur_d_y.compute_at(blur_d, x)
121                 .tile(x, y, xi, yi, 32, 16)
122                 .gpu_threads(xi, yi);
123 
124             d.compute_at(blur_d, x)
125                 .tile(x, y, xi, yi, 32, 16)
126                 .gpu_threads(xi, yi);
127 
128         } else {
129             // 64 ms on an Intel i9-9960X using 32 threads at 3.0 GHz
130 
131             const int vec = natural_vector_size<float>();
132 
133             non_local_means.compute_root()
134                 .reorder(c, x, y)
135                 .tile(x, y, tx, ty, x, y, 16, 8)
136                 .parallel(ty)
137                 .vectorize(x, vec);
138             blur_d_y.compute_at(non_local_means, tx)
139                 .reorder(y, x)
140                 .vectorize(x, vec);
141             d.compute_at(non_local_means, tx)
142                 .vectorize(x, vec);
143             non_local_means_sum.compute_at(non_local_means, x)
144                 .reorder(c, x, y)
145                 .bound(c, 0, 4)
146                 .unroll(c)
147                 .vectorize(x, vec);
148             non_local_means_sum.update(0)
149                 .reorder(c, x, y, s_dom.x, s_dom.y)
150                 .unroll(c)
151                 .vectorize(x, vec);
152             blur_d.compute_at(non_local_means_sum, x)
153                 .vectorize(x, vec);
154         }
155     }
156 };
157 
158 }  // namespace
159 
// Register the NonLocalMeans generator class under the registry name "nl_means".
HALIDE_REGISTER_GENERATOR(NonLocalMeans, nl_means)
161