1 #include "Halide.h" 2 #include "halide_trace_config.h" 3 4 namespace { 5 6 class BilateralGrid : public Halide::Generator<BilateralGrid> { 7 public: 8 GeneratorParam<int> s_sigma{"s_sigma", 8}; 9 10 Input<Buffer<float>> input{"input", 2}; 11 Input<float> r_sigma{"r_sigma"}; 12 13 Output<Buffer<float>> bilateral_grid{"bilateral_grid", 2}; 14 generate()15 void generate() { 16 Var x("x"), y("y"), z("z"), c("c"); 17 18 // Add a boundary condition 19 Func clamped = Halide::BoundaryConditions::repeat_edge(input); 20 21 // Construct the bilateral grid 22 RDom r(0, s_sigma, 0, s_sigma); 23 Expr val = clamped(x * s_sigma + r.x - s_sigma / 2, y * s_sigma + r.y - s_sigma / 2); 24 val = clamp(val, 0.0f, 1.0f); 25 26 Expr zi = cast<int>(val * (1.0f / r_sigma) + 0.5f); 27 28 Func histogram("histogram"); 29 histogram(x, y, z, c) = 0.0f; 30 histogram(x, y, zi, c) += mux(c, {val, 1.0f}); 31 32 // Blur the grid using a five-tap filter 33 Func blurx("blurx"), blury("blury"), blurz("blurz"); 34 blurz(x, y, z, c) = (histogram(x, y, z - 2, c) + 35 histogram(x, y, z - 1, c) * 4 + 36 histogram(x, y, z, c) * 6 + 37 histogram(x, y, z + 1, c) * 4 + 38 histogram(x, y, z + 2, c)); 39 blurx(x, y, z, c) = (blurz(x - 2, y, z, c) + 40 blurz(x - 1, y, z, c) * 4 + 41 blurz(x, y, z, c) * 6 + 42 blurz(x + 1, y, z, c) * 4 + 43 blurz(x + 2, y, z, c)); 44 blury(x, y, z, c) = (blurx(x, y - 2, z, c) + 45 blurx(x, y - 1, z, c) * 4 + 46 blurx(x, y, z, c) * 6 + 47 blurx(x, y + 1, z, c) * 4 + 48 blurx(x, y + 2, z, c)); 49 50 // Take trilinear samples to compute the output 51 val = clamp(input(x, y), 0.0f, 1.0f); 52 Expr zv = val * (1.0f / r_sigma); 53 zi = cast<int>(zv); 54 Expr zf = zv - zi; 55 Expr xf = cast<float>(x % s_sigma) / s_sigma; 56 Expr yf = cast<float>(y % s_sigma) / s_sigma; 57 Expr xi = x / s_sigma; 58 Expr yi = y / s_sigma; 59 Func interpolated("interpolated"); 60 interpolated(x, y, c) = 61 lerp(lerp(lerp(blury(xi, yi, zi, c), blury(xi + 1, yi, zi, c), xf), 62 lerp(blury(xi, yi + 1, zi, c), blury(xi + 1, yi + 1, zi, c), xf), yf), 63 lerp(lerp(blury(xi, yi, zi + 1, c), blury(xi + 1, yi, zi + 1, c), xf), 64 lerp(blury(xi, yi + 1, zi + 1, c), blury(xi + 1, yi + 1, zi + 1, c), xf), yf), 65 zf); 66 67 // Normalize 68 bilateral_grid(x, y) = interpolated(x, y, 0) / interpolated(x, y, 1); 69 70 /* ESTIMATES */ 71 // (This can be useful in conjunction with RunGen and benchmarks as well 72 // as auto-schedule, so we do it in all cases.) 73 // Provide estimates on the input image 74 input.set_estimates({{0, 1536}, {0, 2560}}); 75 // Provide estimates on the parameters 76 r_sigma.set_estimate(0.1f); 77 // TODO: Compute estimates from the parameter values 78 histogram.set_estimate(z, -2, 16); 79 blurz.set_estimate(z, 0, 12); 80 blurx.set_estimate(z, 0, 12); 81 blury.set_estimate(z, 0, 12); 82 bilateral_grid.set_estimates({{0, 1536}, {0, 2560}}); 83 84 if (auto_schedule) { 85 // nothing 86 } else if (get_target().has_gpu_feature()) { 87 // 0.50ms on an RTX 2060 88 89 Var xi("xi"), yi("yi"), zi("zi"); 90 91 // Schedule blurz in 8x8 tiles. This is a tile in 92 // grid-space, which means it represents something like 93 // 64x64 pixels in the input (if s_sigma is 8). 94 blurz.compute_root().reorder(c, z, x, y).gpu_tile(x, y, xi, yi, 8, 8); 95 96 // Schedule histogram to happen per-tile of blurz, with 97 // intermediate results in shared memory. This means histogram 98 // and blurz makes a three-stage kernel: 99 // 1) Zero out the 8x8 set of histograms 100 // 2) Compute those histogram by iterating over lots of the input image 101 // 3) Blur the set of histograms in z 102 histogram.reorder(c, z, x, y).compute_at(blurz, x).gpu_threads(x, y); 103 histogram.update().reorder(c, r.x, r.y, x, y).gpu_threads(x, y).unroll(c); 104 105 // Schedule the remaining blurs and the sampling at the end similarly. 106 blurx 107 .compute_root() 108 .reorder(c, x, y, z) 109 .reorder_storage(c, x, y, z) 110 .vectorize(c) 111 .unroll(y, 2, TailStrategy::RoundUp) 112 .gpu_tile(x, y, z, xi, yi, zi, 32, 8, 1, TailStrategy::RoundUp); 113 blury 114 .compute_root() 115 .reorder(c, x, y, z) 116 .reorder_storage(c, x, y, z) 117 .vectorize(c) 118 .unroll(y, 2, TailStrategy::RoundUp) 119 .gpu_tile(x, y, z, xi, yi, zi, 32, 8, 1, TailStrategy::RoundUp); 120 bilateral_grid.compute_root().gpu_tile(x, y, xi, yi, 32, 8); 121 interpolated.compute_at(bilateral_grid, xi).vectorize(c); 122 } else { 123 // CPU schedule. 124 125 // 3.98ms on an Intel i9-9960X using 32 threads at 3.7 GHz 126 // using target x86-64-avx2. This is a little less 127 // SIMD-friendly than some of the other apps, so we 128 // benefit from hyperthreading, and don't benefit from 129 // AVX-512, which on my machine reduces the clock to 3.0 130 // GHz. 131 132 blurz.compute_root() 133 .reorder(c, z, x, y) 134 .parallel(y) 135 .vectorize(x, 8) 136 .unroll(c); 137 histogram.compute_at(blurz, y); 138 histogram.update() 139 .reorder(c, r.x, r.y, x, y) 140 .unroll(c); 141 blurx.compute_root() 142 .reorder(c, x, y, z) 143 .parallel(z) 144 .vectorize(x, 8) 145 .unroll(c); 146 blury.compute_root() 147 .reorder(c, x, y, z) 148 .parallel(z) 149 .vectorize(x, 8) 150 .unroll(c); 151 bilateral_grid.compute_root() 152 .parallel(y) 153 .vectorize(x, 8); 154 } 155 156 /* Optional tags to specify layout for HalideTraceViz */ 157 { 158 Halide::Trace::FuncConfig cfg; 159 cfg.pos.x = 100; 160 cfg.pos.y = 300; 161 input.add_trace_tag(cfg.to_trace_tag()); 162 163 cfg.pos.x = 1564; 164 bilateral_grid.add_trace_tag(cfg.to_trace_tag()); 165 } 166 { 167 Halide::Trace::FuncConfig cfg; 168 cfg.strides = {{1, 0}, {0, 1}, {40, 0}}; 169 cfg.zoom = 3; 170 171 cfg.max = 32; 172 cfg.pos.x = 550; 173 cfg.pos.y = 100; 174 histogram.add_trace_tag(cfg.to_trace_tag()); 175 176 cfg.max = 512; 177 cfg.pos.y = 300; 178 blurz.add_trace_tag(cfg.to_trace_tag()); 179 180 cfg.max = 8192; 181 cfg.pos.y = 500; 182 blurx.add_trace_tag(cfg.to_trace_tag()); 183 184 cfg.max = 131072; 185 cfg.pos.y = 700; 186 blury.add_trace_tag(cfg.to_trace_tag()); 187 } 188 { 189 // GlobalConfig applies to the entire visualization pipeline; 190 // you can set this tag on any Func that is realized, but only 191 // the last one seen will be used. (Since the tags are emitted in 192 // an arbitrary order, emitting only one such tag is the best practice). 193 // Note also that since the global settings are often context-dependent 194 // (eg the output size and timestep may vary depending on the 195 // input data), it's often more useful to specify these on the 196 // command line. 197 Halide::Trace::GlobalConfig global_cfg; 198 global_cfg.timestep = 1000; 199 200 bilateral_grid.add_trace_tag(global_cfg.to_trace_tag()); 201 } 202 } 203 }; 204 205 } // namespace 206 207 HALIDE_REGISTER_GENERATOR(BilateralGrid, bilateral_grid) 208