1 #include "Halide.h" 2 3 namespace { 4 5 using namespace Halide; 6 7 class NonLocalMeans : public Halide::Generator<NonLocalMeans> { 8 public: 9 Input<Buffer<float>> input{"input", 3}; 10 Input<int> patch_size{"patch_size"}; 11 Input<int> search_area{"search_area"}; 12 Input<float> sigma{"sigma"}; 13 14 Output<Buffer<float>> non_local_means{"non_local_means", 3}; 15 generate()16 void generate() { 17 /* THE ALGORITHM */ 18 19 // This implements the basic description of non-local means found at 20 // https://en.wikipedia.org/wiki/Non-local_means 21 22 Var x("x"), y("y"), c("c"); 23 24 Expr inv_sigma_sq = -1.0f / (sigma * sigma * patch_size * patch_size); 25 26 // Add a boundary condition 27 Func clamped = BoundaryConditions::repeat_edge(input); 28 29 // Define the difference images 30 Var dx("dx"), dy("dy"); 31 Func dc("d"); 32 dc(x, y, dx, dy, c) = pow(clamped(x, y, c) - clamped(x + dx, y + dy, c), 2); 33 34 // Sum across color channels 35 RDom channels(0, 3); 36 Func d("d"); 37 d(x, y, dx, dy) = sum(dc(x, y, dx, dy, channels)); 38 39 // Find the patch differences by blurring the difference images 40 RDom patch_dom(-(patch_size / 2), patch_size); 41 Func blur_d_y("blur_d_y"); 42 blur_d_y(x, y, dx, dy) = sum(d(x, y + patch_dom, dx, dy)); 43 44 Func blur_d("blur_d"); 45 blur_d(x, y, dx, dy) = sum(blur_d_y(x + patch_dom, y, dx, dy)); 46 47 // Compute the weights from the patch differences 48 Func w("w"); 49 w(x, y, dx, dy) = fast_exp(blur_d(x, y, dx, dy) * inv_sigma_sq); 50 51 // Add an alpha channel 52 Func clamped_with_alpha("clamped_with_alpha"); 53 clamped_with_alpha(x, y, c) = mux(c, {clamped(x, y, 0), clamped(x, y, 1), clamped(x, y, 2), 1.0f}); 54 55 // Define a reduction domain for the search area 56 RDom s_dom(-(search_area / 2), search_area, -(search_area / 2), search_area); 57 58 // Compute the sum of the pixels in the search area 59 Func non_local_means_sum("non_local_means_sum"); 60 non_local_means_sum(x, y, c) += w(x, y, s_dom.x, s_dom.y) * clamped_with_alpha(x + s_dom.x, y + s_dom.y, c); 61 62 non_local_means(x, y, c) = 63 clamp(non_local_means_sum(x, y, c) / non_local_means_sum(x, y, 3), 0.0f, 1.0f); 64 65 /* THE SCHEDULE */ 66 67 // Require 3 channels for output 68 non_local_means.dim(2).set_bounds(0, 3); 69 70 Var tx("tx"), ty("ty"), xi("xi"), yi("yi"); 71 72 /* ESTIMATES */ 73 // (This can be useful in conjunction with RunGen and benchmarks as well 74 // as auto-schedule, so we do it in all cases.) 75 // Provide estimates on the input image 76 input.set_estimates({{0, 1536}, {0, 2560}, {0, 3}}); 77 // Provide estimates on the parameters 78 patch_size.set_estimate(7); 79 search_area.set_estimate(7); 80 sigma.set_estimate(0.12f); 81 // Provide estimates on the output pipeline 82 non_local_means.set_estimates({{0, 1536}, {0, 2560}, {0, 3}}); 83 84 if (auto_schedule) { 85 // nothing 86 } else if (get_target().has_gpu_feature()) { 87 // 22 ms on a 2060 RTX 88 Var xii, yii; 89 90 // We'll use 32x16 thread blocks throughout. This was 91 // found by just trying lots of sizes, but large thread 92 // blocks are particularly good in the blur_d stage to 93 // avoid doing wasted blurring work at tile boundaries 94 // (especially for large patch sizes). 95 96 non_local_means.compute_root() 97 .reorder(c, x, y) 98 .unroll(c) 99 .gpu_tile(x, y, xi, yi, 32, 16); 100 101 non_local_means_sum.compute_root() 102 .gpu_tile(x, y, xi, yi, 32, 16) 103 .update() 104 .reorder(c, s_dom.x, x, y, s_dom.y) 105 .tile(x, y, xi, yi, 32, 16) 106 .gpu_blocks(x, y) 107 .gpu_threads(xi, yi) 108 .unroll(c); 109 110 // The patch size we're benchmarking for is 7, which 111 // implies an expansion of 6 pixels for footprint of the 112 // blur, so we'll size tiles of blur_d to be a multiple of 113 // the thread block size minus 6. 114 blur_d.compute_at(non_local_means_sum, s_dom.y) 115 .tile(x, y, xi, yi, 128 - 6, 32 - 6) 116 .tile(xi, yi, xii, yii, 32, 16) 117 .gpu_threads(xii, yii) 118 .gpu_blocks(x, y, dx); 119 120 blur_d_y.compute_at(blur_d, x) 121 .tile(x, y, xi, yi, 32, 16) 122 .gpu_threads(xi, yi); 123 124 d.compute_at(blur_d, x) 125 .tile(x, y, xi, yi, 32, 16) 126 .gpu_threads(xi, yi); 127 128 } else { 129 // 64 ms on an Intel i9-9960X using 32 threads at 3.0 GHz 130 131 const int vec = natural_vector_size<float>(); 132 133 non_local_means.compute_root() 134 .reorder(c, x, y) 135 .tile(x, y, tx, ty, x, y, 16, 8) 136 .parallel(ty) 137 .vectorize(x, vec); 138 blur_d_y.compute_at(non_local_means, tx) 139 .reorder(y, x) 140 .vectorize(x, vec); 141 d.compute_at(non_local_means, tx) 142 .vectorize(x, vec); 143 non_local_means_sum.compute_at(non_local_means, x) 144 .reorder(c, x, y) 145 .bound(c, 0, 4) 146 .unroll(c) 147 .vectorize(x, vec); 148 non_local_means_sum.update(0) 149 .reorder(c, x, y, s_dom.x, s_dom.y) 150 .unroll(c) 151 .vectorize(x, vec); 152 blur_d.compute_at(non_local_means_sum, x) 153 .vectorize(x, vec); 154 } 155 } 156 }; 157 158 } // namespace 159 160 HALIDE_REGISTER_GENERATOR(NonLocalMeans, nl_means) 161