1 #include "Halide.h"
2 
3 using namespace Halide;
4 
5 class CountConditionals : public Internal::IRVisitor {
6 public:
7     int count = 0;
8     int count_if = 0;
9     int count_select = 0;
10     bool in_produce = false;
11 
12 private:
13     using Internal::IRVisitor::visit;
14 
visit(const Internal::Select * op)15     void visit(const Internal::Select *op) override {
16         if (in_produce) {
17             count++;
18             count_select++;
19         }
20         Internal::IRVisitor::visit(op);
21     }
22 
visit(const Internal::IfThenElse * op)23     void visit(const Internal::IfThenElse *op) override {
24         if (in_produce) {
25             count++;
26             count_if++;
27         }
28         Internal::IRVisitor::visit(op);
29     }
30 
visit(const Internal::ProducerConsumer * op)31     void visit(const Internal::ProducerConsumer *op) override {
32         if (op->is_producer) {
33             bool old_in_produce = in_produce;
34             in_produce = true;
35             Internal::IRVisitor::visit(op);
36             in_produce = old_in_produce;
37         } else {
38             IRVisitor::visit(op);
39         }
40     }
41 };
42 
main(int argc,char ** argv)43 int main(int argc, char **argv) {
44     printf("Running inequality condition test\n");
45     {
46         // Loop iterations that would be no-ops should be trimmed off.
47         Func f;
48         Var x;
49         f(x) = x;
50         f(x) += select(x > 10 && x < 20, 1, 0);
51         f(x) += select(x < 10, 0, 1);
52         f(x) *= select(x > 20 && x < 30, 2, 1);
53         f(x) = select(x >= 60 && x <= 100, 100 - f(x), f(x));
54 
55         // There should be no selects or ifs after trim_no_ops runs
56         Module m = f.compile_to_module({});
57         CountConditionals s;
58         m.functions().front().body.accept(&s);
59         if (s.count != 0) {
60             std::cerr << "There were conditionals in the lowered code: \n"
61                       << m.functions().front().body << "\n";
62             return -1;
63         }
64 
65         // Also check the output is correct
66         Buffer<int> im = f.realize(100);
67         for (int x = 0; x < im.width(); x++) {
68             int correct = x;
69             correct += (x > 10 && x < 20) ? 1 : 0;
70             correct += (x < 10) ? 0 : 1;
71             correct *= (x > 20 && x < 30) ? 2 : 1;
72             correct = (x >= 60 && x <= 100) ? (100 - correct) : correct;
73             if (im(x) != correct) {
74                 printf("im(%d) = %d instead of %d\n",
75                        x, im(x), correct);
76                 return -1;
77             }
78         }
79     }
80 
81     printf("Running equality condition test\n");
82     {
83         // Loop iterations that would be no-ops should be trimmed off. trim_no_ops
84         // should be able to handle equality as well.
85         Func f;
86         Var x, y;
87         f(x, y) = x + y;
88         f(x, y) += select((x == 10) && (x < y), 1, 0);
89         Module m = f.compile_to_module({});
90 
91         // There should be no selects after trim_no_ops runs
92         CountConditionals s;
93         m.functions().front().body.accept(&s);
94         if (s.count != 0) {
95             std::cerr << "There were selects in the lowered code: \n"
96                       << m.functions().front().body << "\n";
97             return -1;
98         }
99 
100         // Also check the output is correct
101         Buffer<int> im = f.realize(100, 100);
102         for (int y = 0; y < im.height(); y++) {
103             for (int x = 0; x < im.width(); x++) {
104                 int correct = x + y;
105                 correct += ((x == 10) && (x < y)) ? 1 : 0;
106                 if (im(x, y) != correct) {
107                     printf("im(%d, %d) = %d instead of %d\n",
108                            x, y, im(x, y), correct);
109                     return -1;
110                 }
111             }
112         }
113     }
114 
115     printf("Running tiled histogram test\n");
116     {
117         // Test a tiled histogram
118         Func f;
119         Var x, y;
120         f(x, y) = cast<uint8_t>(random_int());
121         f.compute_root();
122 
123         Func hist;
124         {
125             RDom r(0, 10, 0, 10, 0, 10, 0, 10);
126             Expr xi = r[0] + r[2] * 10, yi = r[1] + r[3] * 10;
127             hist(x) = 0;
128             hist(f(clamp(xi, 0, 73), clamp(yi, 0, 73))) +=
129                 select(xi >= 0 && xi <= 73 && yi >= 0 && yi <= 73, 1, 0);
130 
131             Module m = hist.compile_to_module({});
132             CountConditionals s;
133             m.functions().front().body.accept(&s);
134             if (s.count != 0) {
135                 std::cerr << "There were selects in the lowered code: \n"
136                           << m.functions().front().body << "\n";
137                 return -1;
138             }
139         }
140         Buffer<int> hist_result = hist.realize(256);
141 
142         // Also check the output is correct.
143         Func true_hist;
144         {
145             RDom r(0, 74, 0, 74);
146             true_hist(x) = 0;
147             true_hist(f(r.x, r.y)) += 1;
148         }
149         Buffer<int> true_hist_result = true_hist.realize(256);
150 
151         for (int i = 0; i < 256; i++) {
152             if (hist_result(i) != true_hist_result(i)) {
153                 printf("hist(%d) = %d instead of %d\n",
154                        i, hist_result(i), true_hist_result(i));
155                 return -1;
156             }
157         }
158     }
159 
160     printf("Running tiled iteration over triangle test\n");
161     {
162         // Test tiled iteration over a triangle, where the condition is an
163         // if statement instead of a select.
164         Func f;
165         Var x, y;
166         f(x, y) = select(2 * x < y, 5, undef<int>());
167 
168         Var xi, yi;
169         f.tile(x, y, xi, yi, 4, 4);
170 
171         // Check there are no if statements.
172         Module m = f.compile_to_module({});
173         CountConditionals s;
174         m.functions().front().body.accept(&s);
175         if (s.count != 0) {
176             std::cerr << "There were selects or ifs in the lowered code: \n"
177                       << m.functions().front().body << "\n";
178             return -1;
179         }
180     }
181 
182     // Test tiled iteration on the gpu if there is support for GPU.
183     // The gpu loop variable should not depend on outer gpu loop var.
184     if (!get_jit_target_from_environment().has_gpu_feature()) {
185         // TODO: split this test apart so the GPU pieces can be split appropriately
186         // printf("[SKIP] No GPU target enabled.\n");
187         printf("Success!\n");
188         return 0;
189     }
190 
191     printf("Running select is not simplified on gpu test\n");
192     {
193         Func f;
194         Var x, y;
195         f(x, y) = x + y;
196 
197         RDom r(0, 100, 0, 100);
198         f(r.x, r.y) += select((r.x < r.y) && (r.x == 10), 3, undef<int>());
199 
200         RVar rxi, ryi;
201         f.update(0).gpu_tile(r.x, r.y, rxi, ryi, 4, 4);
202 
203         Buffer<int> im = f.realize(200, 200);
204 
205         // There should be no selects after trim_no_ops runs. The select should
206         // be lifted out as if condition. We can't trim gpu loop r.x based on the
207         // if condition since it depends on gpu outer loop r.y
208         Target gpu_target(get_host_target());
209         gpu_target.set_feature(Target::CUDA);
210         Module m = f.compile_to_module({}, "", gpu_target);
211         CountConditionals s;
212         m.functions().front().body.accept(&s);
213         if (s.count_select != 0) {
214             std::cerr << "There were selects in the lowered code: \n"
215                       << m.functions().front().body << "\n";
216             return -1;
217         }
218         if (s.count_if != 1) {
219             std::cerr << "There should be 1 if in the lowered code: \n"
220                       << m.functions().front().body << "\n";
221             return -1;
222         }
223 
224         for (int y = 0; y < im.height(); y++) {
225             for (int x = 0; x < im.width(); x++) {
226                 int correct = x + y;
227                 if ((x == 10) && (0 <= y && y <= 99)) {
228                     correct += (x < y) ? 3 : 0;
229                 }
230                 if (im(x, y) != correct) {
231                     printf("im(%d, %d) = %d instead of %d\n",
232                            x, y, im(x, y), correct);
233                     return -1;
234                 }
235             }
236         }
237     }
238 
239     printf("Success!\n");
240     return 0;
241 }
242