1 #include "Halide.h"
2 
3 using namespace Halide;
4 
5 class Median3x3 : public Generator<Median3x3> {
6 private:
mid(Expr a,Expr b,Expr c)7     static Expr mid(Expr a, Expr b, Expr c) {
8         return max(min(max(a, b), c), min(a, b));
9     }
10 
11 public:
12     // Takes an 8 bit image; one channel.
13     Input<Buffer<uint8_t>> input{"input", 2};
14     // Outputs an 8 bit image; one channel.
15     Output<Buffer<uint8_t>> output{"output", 2};
16 
17     GeneratorParam<bool> use_parallel_sched{"use_parallel_sched", true};
18     GeneratorParam<bool> use_prefetch_sched{"use_prefetch_sched", true};
19 
generate()20     void generate() {
21         bounded_input(x, y) = BoundaryConditions::repeat_edge(input)(x, y);
22         max_y(x, y) = max(bounded_input(x, y - 1), bounded_input(x, y), bounded_input(x, y + 1));
23         min_y(x, y) = min(bounded_input(x, y - 1), bounded_input(x, y), bounded_input(x, y + 1));
24         mid_y(x, y) = mid(bounded_input(x, y - 1), bounded_input(x, y), bounded_input(x, y + 1));
25 
26         minmax_x(x, y) = min(max_y(x - 1, y), max_y(x, y), max_y(x + 1, y));
27         maxmin_x(x, y) = max(min_y(x - 1, y), min_y(x, y), min_y(x + 1, y));
28         midmid_x(x, y) = mid(mid_y(x - 1, y), mid_y(x, y), mid_y(x + 1, y));
29 
30         output(x, y) = mid(minmax_x(x, y), maxmin_x(x, y), midmid_x(x, y));
31     }
32 
schedule()33     void schedule() {
34         Var xi{"xi"}, yi{"yi"};
35 
36         input.dim(0).set_min(0);
37         input.dim(1).set_min(0);
38 
39         output.dim(0).set_min(0);
40         output.dim(1).set_min(0);
41 
42         if (get_target().features_any_of({Target::HVX_64, Target::HVX_128})) {
43             const int vector_size = get_target().has_feature(Target::HVX_128) ? 128 : 64;
44             Expr input_stride = input.dim(1).stride();
45             input.dim(1).set_stride((input_stride / vector_size) * vector_size);
46 
47             Expr output_stride = output.dim(1).stride();
48             output.dim(1).set_stride((output_stride / vector_size) * vector_size);
49             bounded_input
50                 .compute_at(Func(output), y)
51                 .align_storage(x, 128)
52                 .vectorize(x, vector_size, TailStrategy::RoundUp);
53             output
54                 .hexagon()
55                 .tile(x, y, xi, yi, vector_size, 4)
56                 .vectorize(xi)
57                 .unroll(yi);
58             if (use_prefetch_sched) {
59                 output.prefetch(input, y, 2);
60             }
61             if (use_parallel_sched) {
62                 Var yo;
63                 output.split(y, yo, y, 128).parallel(yo);
64             }
65         } else {
66             const int vector_size = natural_vector_size<uint8_t>();
67             output
68                 .vectorize(x, vector_size)
69                 .parallel(y, 16);
70         }
71     }
72 
73 private:
74     Var x{"x"}, y{"y"};
75     Func max_y{"max_y"}, min_y{"min_y"}, mid_y{"mid_y"};
76     Func minmax_x{"minmax_x"}, maxmin_x{"maxmin_x"}, midmid_x{"midmid_x"};
77     Func bounded_input{"bounded_input"};
78 };
79 
// Registers the Median3x3 generator class under the name "median3x3".
HALIDE_REGISTER_GENERATOR(Median3x3, median3x3)
81