1 #include "Halide.h"
2 
3 using namespace Halide;
4 
// Schedule variants that AutoVizDemo::schedule() can apply, ordered from
// simplest to most hand-tuned. Values are spelled out explicitly.
enum ScheduleType {
    Naive = 0,      // "naive": compute_root() scheduling for the stages
    LessNaive = 1,  // "lessnaive": adds parallel()/vectorize() to larger stages
    Complex = 2     // "complex": uses compute_at() and tiling
};
10 
// This is a dumbed-down version of the Resize generator, intended solely
// to demonstrate use of HalideTraceViz auto-layout mode; it has multiple
// schedules, ranging from 'naive' to 'complex', intended to demonstrate
// how even basic auto-layout of tracing can be useful.
//
// The approach of using an enum for 'naive-or-complex-schedule' is an expedient
// for this purpose; it shouldn't be mimicked in most real-world code.
18 class AutoVizDemo : public Halide::Generator<AutoVizDemo> {
19 public:
20     GeneratorParam<ScheduleType> schedule_type{"schedule_type", Naive, {{"naive", Naive}, {"lessnaive", LessNaive}, {"complex", Complex}}};
21 
22     // If we statically know whether we're upsampling or downsampling,
23     // we can generate different pipelines (we want to reorder the
24     // resample in x and in y).
25     GeneratorParam<bool> upsample{"upsample", false};
26 
27     Input<Buffer<float>> input{"input", 3};
28     Input<float> scale_factor{"scale_factor"};
29     Output<Buffer<float>> output{"output", 3};
30 
31     // Common Vars
32     Var x, y, c, k;
33 
34     // Intermediate Funcs
35     Func as_float, clamped, resized_x, resized_y,
36         unnormalized_kernel_x, unnormalized_kernel_y,
37         kernel_x, kernel_y,
38         kernel_sum_x, kernel_sum_y;
39 
generate()40     void generate() {
41 
42         clamped = BoundaryConditions::repeat_edge(input,
43                                                   {{input.dim(0).min(), input.dim(0).extent()},
44                                                    {input.dim(1).min(), input.dim(1).extent()}});
45 
46         // Handle different types by just casting to float
47         as_float(x, y, c) = cast<float>(clamped(x, y, c));
48 
49         // For downscaling, widen the interpolation kernel to perform lowpass
50         // filtering.
51 
52         Expr kernel_scaling = upsample ? Expr(1.0f) : scale_factor;
53 
54         Expr kernel_radius = 0.5f / kernel_scaling;
55 
56         Expr kernel_taps = ceil(1.0f / kernel_scaling);
57 
58         // source[xy] are the (non-integer) coordinates inside the source image
59         Expr sourcex = (x + 0.5f) / scale_factor - 0.5f;
60         Expr sourcey = (y + 0.5f) / scale_factor - 0.5f;
61 
62         // Initialize interpolation kernels. Since we allow an arbitrary
63         // scaling factor, the filter coefficients are different for each x
64         // and y coordinate.
65         Expr beginx = cast<int>(ceil(sourcex - kernel_radius));
66         Expr beginy = cast<int>(ceil(sourcey - kernel_radius));
67 
68         RDom r(0, kernel_taps);
69 
70         auto kernel = [](Expr x) -> Expr {
71             Expr xx = abs(x);
72             return select(abs(x) <= 0.5f, 1.0f, 0.0f);
73         };
74         unnormalized_kernel_x(x, k) = kernel((k + beginx - sourcex) * kernel_scaling);
75         unnormalized_kernel_y(y, k) = kernel((k + beginy - sourcey) * kernel_scaling);
76 
77         kernel_sum_x(x) = sum(unnormalized_kernel_x(x, r), "kernel_sum_x");
78         kernel_sum_y(y) = sum(unnormalized_kernel_y(y, r), "kernel_sum_y");
79 
80         kernel_x(x, k) = unnormalized_kernel_x(x, k) / kernel_sum_x(x);
81         kernel_y(y, k) = unnormalized_kernel_y(y, k) / kernel_sum_y(y);
82 
83         // Perform separable resizing. The resize in x vectorizes
84         // poorly compared to the resize in y, so do it first if we're
85         // upsampling, and do it second if we're downsampling.
86         Func resized;
87         if (upsample) {
88             resized_x(x, y, c) = sum(kernel_x(x, r) * as_float(r + beginx, y, c), "resized_x");
89             resized_y(x, y, c) = sum(kernel_y(y, r) * resized_x(x, r + beginy, c), "resized_y");
90             resized = resized_y;
91         } else {
92             resized_y(x, y, c) = sum(kernel_y(y, r) * as_float(x, r + beginy, c), "resized_y");
93             resized_x(x, y, c) = sum(kernel_x(x, r) * resized_y(r + beginx, y, c), "resized_x");
94             resized = resized_x;
95         }
96 
97         if (input.type().is_float()) {
98             output(x, y, c) = clamp(resized(x, y, c), 0.0f, 1.0f);
99         } else {
100             output(x, y, c) = saturating_cast(input.type(), resized(x, y, c));
101         }
102     }
103 
schedule()104     void schedule() {
105         Var xi, yi;
106         if (schedule_type == Naive) {
107             // naive: compute_root() everything
108             unnormalized_kernel_x
109                 .compute_root();
110             kernel_sum_x
111                 .compute_root();
112             kernel_x
113                 .compute_root();
114             unnormalized_kernel_y
115                 .compute_root();
116             kernel_sum_y
117                 .compute_root();
118             kernel_y
119                 .compute_root();
120             as_float
121                 .compute_root();
122             resized_x
123                 .compute_root();
124             output
125                 .compute_root();
126         } else if (schedule_type == LessNaive) {
127             // less-naive: add vectorization and parallelism to 'large' realizations;
128             // use compute_at for as_float calculation
129             unnormalized_kernel_x
130                 .compute_root();
131             kernel_sum_x
132                 .compute_root();
133             kernel_x
134                 .compute_root();
135 
136             unnormalized_kernel_y
137                 .compute_root();
138             kernel_sum_y
139                 .compute_root();
140             kernel_y
141                 .compute_root();
142 
143             as_float
144                 .compute_at(resized_x, y);
145             resized_x
146                 .compute_root()
147                 .parallel(y);
148             output
149                 .compute_root()
150                 .parallel(y)
151                 .vectorize(x, 8);
152         } else if (schedule_type == Complex) {
153             // complex: use compute_at() and tiling intelligently.
154             unnormalized_kernel_x
155                 .compute_at(kernel_x, x)
156                 .vectorize(x);
157             kernel_sum_x
158                 .compute_at(kernel_x, x)
159                 .vectorize(x);
160             kernel_x
161                 .compute_root()
162                 .reorder(k, x)
163                 .vectorize(x, 8);
164 
165             unnormalized_kernel_y
166                 .compute_at(kernel_y, y)
167                 .vectorize(y, 8);
168             kernel_sum_y
169                 .compute_at(kernel_y, y)
170                 .vectorize(y);
171             kernel_y
172                 .compute_at(output, y)
173                 .reorder(k, y)
174                 .vectorize(y, 8);
175 
176             if (upsample) {
177                 as_float
178                     .compute_at(output, y)
179                     .vectorize(x, 8);
180                 resized_x
181                     .compute_at(output, x)
182                     .vectorize(x, 8);
183                 output
184                     .tile(x, y, xi, yi, 16, 64)
185                     .parallel(y)
186                     .vectorize(xi);
187             } else {
188                 resized_y
189                     .compute_at(output, y)
190                     .vectorize(x, 8);
191                 resized_x
192                     .compute_at(output, xi);
193                 output
194                     .tile(x, y, xi, yi, 32, 8)
195                     .parallel(y)
196                     .vectorize(xi);
197             }
198         }
199     }
200 };
201 
// Register AutoVizDemo with Halide's generator registry under the name
// "auto_viz_demo".
HALIDE_REGISTER_GENERATOR(AutoVizDemo, auto_viz_demo);
203