1 #include "Halide.h" 2 3 using namespace Halide; 4 5 class Median3x3 : public Generator<Median3x3> { 6 private: mid(Expr a,Expr b,Expr c)7 static Expr mid(Expr a, Expr b, Expr c) { 8 return max(min(max(a, b), c), min(a, b)); 9 } 10 11 public: 12 // Takes an 8 bit image; one channel. 13 Input<Buffer<uint8_t>> input{"input", 2}; 14 // Outputs an 8 bit image; one channel. 15 Output<Buffer<uint8_t>> output{"output", 2}; 16 17 GeneratorParam<bool> use_parallel_sched{"use_parallel_sched", true}; 18 GeneratorParam<bool> use_prefetch_sched{"use_prefetch_sched", true}; 19 generate()20 void generate() { 21 bounded_input(x, y) = BoundaryConditions::repeat_edge(input)(x, y); 22 max_y(x, y) = max(bounded_input(x, y - 1), bounded_input(x, y), bounded_input(x, y + 1)); 23 min_y(x, y) = min(bounded_input(x, y - 1), bounded_input(x, y), bounded_input(x, y + 1)); 24 mid_y(x, y) = mid(bounded_input(x, y - 1), bounded_input(x, y), bounded_input(x, y + 1)); 25 26 minmax_x(x, y) = min(max_y(x - 1, y), max_y(x, y), max_y(x + 1, y)); 27 maxmin_x(x, y) = max(min_y(x - 1, y), min_y(x, y), min_y(x + 1, y)); 28 midmid_x(x, y) = mid(mid_y(x - 1, y), mid_y(x, y), mid_y(x + 1, y)); 29 30 output(x, y) = mid(minmax_x(x, y), maxmin_x(x, y), midmid_x(x, y)); 31 } 32 schedule()33 void schedule() { 34 Var xi{"xi"}, yi{"yi"}; 35 36 input.dim(0).set_min(0); 37 input.dim(1).set_min(0); 38 39 output.dim(0).set_min(0); 40 output.dim(1).set_min(0); 41 42 if (get_target().features_any_of({Target::HVX_64, Target::HVX_128})) { 43 const int vector_size = get_target().has_feature(Target::HVX_128) ? 128 : 64; 44 Expr input_stride = input.dim(1).stride(); 45 input.dim(1).set_stride((input_stride / vector_size) * vector_size); 46 47 Expr output_stride = output.dim(1).stride(); 48 output.dim(1).set_stride((output_stride / vector_size) * vector_size); 49 bounded_input 50 .compute_at(Func(output), y) 51 .align_storage(x, 128) 52 .vectorize(x, vector_size, TailStrategy::RoundUp); 53 output 54 .hexagon() 55 .tile(x, y, xi, yi, vector_size, 4) 56 .vectorize(xi) 57 .unroll(yi); 58 if (use_prefetch_sched) { 59 output.prefetch(input, y, 2); 60 } 61 if (use_parallel_sched) { 62 Var yo; 63 output.split(y, yo, y, 128).parallel(yo); 64 } 65 } else { 66 const int vector_size = natural_vector_size<uint8_t>(); 67 output 68 .vectorize(x, vector_size) 69 .parallel(y, 16); 70 } 71 } 72 73 private: 74 Var x{"x"}, y{"y"}; 75 Func max_y{"max_y"}, min_y{"min_y"}, mid_y{"mid_y"}; 76 Func minmax_x{"minmax_x"}, maxmin_x{"maxmin_x"}, midmid_x{"midmid_x"}; 77 Func bounded_input{"bounded_input"}; 78 }; 79 80 HALIDE_REGISTER_GENERATOR(Median3x3, median3x3) 81