1 #include "Halide.h"
2 
3 using namespace Halide;
4 
// Resampling filters supported by the generator. The numeric values are
// spelled out because they are used directly as indices into kernel_info,
// so the two must stay in the same order.
enum InterpolationType {
    Box = 0,
    Linear = 1,
    Cubic = 2,
    Lanczos = 3,
};
11 
kernel_box(Expr x)12 Expr kernel_box(Expr x) {
13     Expr xx = abs(x);
14     return select(xx <= 0.5f, 1.0f, 0.0f);
15 }
16 
kernel_linear(Expr x)17 Expr kernel_linear(Expr x) {
18     Expr xx = abs(x);
19     return select(xx < 1.0f, 1.0f - xx, 0.0f);
20 }
21 
kernel_cubic(Expr x)22 Expr kernel_cubic(Expr x) {
23     Expr xx = abs(x);
24     Expr xx2 = xx * xx;
25     Expr xx3 = xx2 * xx;
26     float a = -0.5f;
27 
28     return select(xx < 1.0f, (a + 2.0f) * xx3 - (a + 3.0f) * xx2 + 1,
29                   select(xx < 2.0f, a * xx3 - 5 * a * xx2 + 8 * a * xx - 4.0f * a,
30                          0.0f));
31 }
32 
sinc(Expr x)33 Expr sinc(Expr x) {
34     x *= 3.14159265359f;
35     return sin(x) / x;
36 }
37 
kernel_lanczos(Expr x)38 Expr kernel_lanczos(Expr x) {
39     Expr value = sinc(x) * sinc(x / 3);
40     value = select(x == 0.0f, 1.0f, value);        // Take care of singularity at zero
41     value = select(x > 3 || x < -3, 0.0f, value);  // Clamp to zero out of bounds
42     return value;
43 }
44 
45 struct KernelInfo {
46     const char *name;
47     int taps;
48     Expr (*kernel)(Expr);
49 };
50 
51 static KernelInfo kernel_info[] = {
52     {"box", 1, kernel_box},
53     {"linear", 2, kernel_linear},
54     {"cubic", 4, kernel_cubic},
55     {"lanczos", 6, kernel_lanczos}};
56 
57 class Resize : public Halide::Generator<Resize> {
58 public:
59     GeneratorParam<InterpolationType> interpolation_type{"interpolation_type", Cubic, {{"box", Box}, {"linear", Linear}, {"cubic", Cubic}, {"lanczos", Lanczos}}};
60 
61     // If we statically know whether we're upsampling or downsampling,
62     // we can generate different pipelines (we want to reorder the
63     // resample in x and in y).
64     GeneratorParam<bool> upsample{"upsample", false};
65 
66     Input<Buffer<>> input{"input", 3};
67     Input<float> scale_factor{"scale_factor"};
68     Output<Buffer<>> output{"output", 3};
69 
70     // Common Vars
71     Var x, y, c, k;
72 
73     // Intermediate Funcs
74     Func as_float, clamped, resized_x, resized_y,
75         unnormalized_kernel_x, unnormalized_kernel_y,
76         kernel_x, kernel_y,
77         kernel_sum_x, kernel_sum_y;
78 
generate()79     void generate() {
80 
81         clamped = BoundaryConditions::repeat_edge(input,
82                                                   {{input.dim(0).min(), input.dim(0).extent()},
83                                                    {input.dim(1).min(), input.dim(1).extent()}});
84 
85         // Handle different types by just casting to float
86         as_float(x, y, c) = cast<float>(clamped(x, y, c));
87 
88         // For downscaling, widen the interpolation kernel to perform lowpass
89         // filtering.
90 
91         Expr kernel_scaling = upsample ? Expr(1.0f) : scale_factor;
92 
93         Expr kernel_radius = 0.5f * kernel_info[interpolation_type].taps / kernel_scaling;
94 
95         Expr kernel_taps = ceil(kernel_info[interpolation_type].taps / kernel_scaling);
96 
97         // source[xy] are the (non-integer) coordinates inside the source image
98         Expr sourcex = (x + 0.5f) / scale_factor - 0.5f;
99         Expr sourcey = (y + 0.5f) / scale_factor - 0.5f;
100 
101         // Initialize interpolation kernels. Since we allow an arbitrary
102         // scaling factor, the filter coefficients are different for each x
103         // and y coordinate.
104         Expr beginx = cast<int>(ceil(sourcex - kernel_radius));
105         Expr beginy = cast<int>(ceil(sourcey - kernel_radius));
106 
107         RDom r(0, cast<int>(kernel_taps));
108         const KernelInfo &info = kernel_info[interpolation_type];
109 
110         unnormalized_kernel_x(x, k) = info.kernel((k + beginx - sourcex) * kernel_scaling);
111         unnormalized_kernel_y(y, k) = info.kernel((k + beginy - sourcey) * kernel_scaling);
112 
113         kernel_sum_x(x) = sum(unnormalized_kernel_x(x, r), "kernel_sum_x");
114         kernel_sum_y(y) = sum(unnormalized_kernel_y(y, r), "kernel_sum_y");
115 
116         kernel_x(x, k) = unnormalized_kernel_x(x, k) / kernel_sum_x(x);
117         kernel_y(y, k) = unnormalized_kernel_y(y, k) / kernel_sum_y(y);
118 
119         // Perform separable resizing. The resize in x vectorizes
120         // poorly compared to the resize in y, so do it first if we're
121         // upsampling, and do it second if we're downsampling.
122         Func resized;
123         if (upsample) {
124             resized_x(x, y, c) = sum(kernel_x(x, r) * as_float(r + beginx, y, c), "resized_x");
125             resized_y(x, y, c) = sum(kernel_y(y, r) * resized_x(x, r + beginy, c), "resized_y");
126             resized = resized_y;
127         } else {
128             resized_y(x, y, c) = sum(kernel_y(y, r) * as_float(x, r + beginy, c), "resized_y");
129             resized_x(x, y, c) = sum(kernel_x(x, r) * resized_y(r + beginx, y, c), "resized_x");
130             resized = resized_x;
131         }
132 
133         if (input.type().is_float()) {
134             output(x, y, c) = clamp(resized(x, y, c), 0.0f, 1.0f);
135         } else {
136             output(x, y, c) = saturating_cast(input.type(), resized(x, y, c));
137         }
138     }
139 
schedule()140     void schedule() {
141         Var xi, yi;
142         unnormalized_kernel_x
143             .compute_at(kernel_x, x)
144             .vectorize(x);
145         kernel_sum_x
146             .compute_at(kernel_x, x)
147             .vectorize(x);
148         kernel_x
149             .compute_root()
150             .reorder(k, x)
151             .vectorize(x, 8);
152 
153         unnormalized_kernel_y
154             .compute_at(kernel_y, y)
155             .vectorize(y, 8);
156         kernel_sum_y
157             .compute_at(kernel_y, y)
158             .vectorize(y);
159         kernel_y
160             .compute_at(output, y)
161             .reorder(k, y)
162             .vectorize(y, 8);
163 
164         if (upsample) {
165             output
166                 .tile(x, y, xi, yi, 16, 64)
167                 .parallel(y)
168                 .vectorize(xi);
169             resized_x
170                 .compute_at(output, x)
171                 .vectorize(x, 8);
172             as_float
173                 .compute_at(output, y)
174                 .vectorize(x, 8);
175         } else {
176             output
177                 .tile(x, y, xi, yi, 32, 8)
178                 .parallel(y)
179                 .vectorize(xi);
180             resized_y
181                 .compute_at(output, y)
182                 .vectorize(x, 8);
183             resized_x
184                 .compute_at(output, xi);
185         }
186 
187         // Allow the input and output to have arbitrary memory layout,
188         // and add some specializations for a few common cases. If
189         // your case is not covered (e.g. planar input, packed rgb
190         // output), you could add a new specialization here.
191         output.dim(0).set_stride(Expr());
192         input.dim(0).set_stride(Expr());
193 
194         Expr planar = (output.dim(0).stride() == 1 &&
195                        input.dim(0).stride() == 1);
196         Expr packed_rgb = (output.dim(0).stride() == 3 &&
197                            output.dim(2).stride() == 1 &&
198                            output.dim(2).min() == 0 &&
199                            output.dim(2).extent() == 3 &&
200                            input.dim(0).stride() == 3 &&
201                            input.dim(2).stride() == 1 &&
202                            input.dim(2).min() == 0 &&
203                            input.dim(2).extent() == 3);
204         Expr packed_rgba = (output.dim(0).stride() == 4 &&
205                             output.dim(2).stride() == 1 &&
206                             output.dim(2).min() == 0 &&
207                             output.dim(2).extent() == 4 &&
208                             input.dim(0).stride() == 4 &&
209                             input.dim(2).stride() == 1 &&
210                             input.dim(2).min() == 0 &&
211                             input.dim(2).extent() == 4);
212 
213         output.specialize(planar);
214 
215         output.specialize(packed_rgb)
216             .reorder(c, xi, yi, x, y)
217             .unroll(c);
218 
219         output.specialize(packed_rgba)
220             .reorder(c, xi, yi, x, y)
221             .unroll(c);
222     }
223 };
224 
// Register the Resize generator class under the name "resize".
HALIDE_REGISTER_GENERATOR(Resize, resize);
226