1 #include "Halide.h"
2
3 using namespace Halide;
4
5 class CountConditionals : public Internal::IRVisitor {
6 public:
7 int count = 0;
8 int count_if = 0;
9 int count_select = 0;
10 bool in_produce = false;
11
12 private:
13 using Internal::IRVisitor::visit;
14
visit(const Internal::Select * op)15 void visit(const Internal::Select *op) override {
16 if (in_produce) {
17 count++;
18 count_select++;
19 }
20 Internal::IRVisitor::visit(op);
21 }
22
visit(const Internal::IfThenElse * op)23 void visit(const Internal::IfThenElse *op) override {
24 if (in_produce) {
25 count++;
26 count_if++;
27 }
28 Internal::IRVisitor::visit(op);
29 }
30
visit(const Internal::ProducerConsumer * op)31 void visit(const Internal::ProducerConsumer *op) override {
32 if (op->is_producer) {
33 bool old_in_produce = in_produce;
34 in_produce = true;
35 Internal::IRVisitor::visit(op);
36 in_produce = old_in_produce;
37 } else {
38 IRVisitor::visit(op);
39 }
40 }
41 };
42
main(int argc,char ** argv)43 int main(int argc, char **argv) {
44 printf("Running inequality condition test\n");
45 {
46 // Loop iterations that would be no-ops should be trimmed off.
47 Func f;
48 Var x;
49 f(x) = x;
50 f(x) += select(x > 10 && x < 20, 1, 0);
51 f(x) += select(x < 10, 0, 1);
52 f(x) *= select(x > 20 && x < 30, 2, 1);
53 f(x) = select(x >= 60 && x <= 100, 100 - f(x), f(x));
54
55 // There should be no selects or ifs after trim_no_ops runs
56 Module m = f.compile_to_module({});
57 CountConditionals s;
58 m.functions().front().body.accept(&s);
59 if (s.count != 0) {
60 std::cerr << "There were conditionals in the lowered code: \n"
61 << m.functions().front().body << "\n";
62 return -1;
63 }
64
65 // Also check the output is correct
66 Buffer<int> im = f.realize(100);
67 for (int x = 0; x < im.width(); x++) {
68 int correct = x;
69 correct += (x > 10 && x < 20) ? 1 : 0;
70 correct += (x < 10) ? 0 : 1;
71 correct *= (x > 20 && x < 30) ? 2 : 1;
72 correct = (x >= 60 && x <= 100) ? (100 - correct) : correct;
73 if (im(x) != correct) {
74 printf("im(%d) = %d instead of %d\n",
75 x, im(x), correct);
76 return -1;
77 }
78 }
79 }
80
81 printf("Running equality condition test\n");
82 {
83 // Loop iterations that would be no-ops should be trimmed off. trim_no_ops
84 // should be able to handle equality as well.
85 Func f;
86 Var x, y;
87 f(x, y) = x + y;
88 f(x, y) += select((x == 10) && (x < y), 1, 0);
89 Module m = f.compile_to_module({});
90
91 // There should be no selects after trim_no_ops runs
92 CountConditionals s;
93 m.functions().front().body.accept(&s);
94 if (s.count != 0) {
95 std::cerr << "There were selects in the lowered code: \n"
96 << m.functions().front().body << "\n";
97 return -1;
98 }
99
100 // Also check the output is correct
101 Buffer<int> im = f.realize(100, 100);
102 for (int y = 0; y < im.height(); y++) {
103 for (int x = 0; x < im.width(); x++) {
104 int correct = x + y;
105 correct += ((x == 10) && (x < y)) ? 1 : 0;
106 if (im(x, y) != correct) {
107 printf("im(%d, %d) = %d instead of %d\n",
108 x, y, im(x, y), correct);
109 return -1;
110 }
111 }
112 }
113 }
114
115 printf("Running tiled histogram test\n");
116 {
117 // Test a tiled histogram
118 Func f;
119 Var x, y;
120 f(x, y) = cast<uint8_t>(random_int());
121 f.compute_root();
122
123 Func hist;
124 {
125 RDom r(0, 10, 0, 10, 0, 10, 0, 10);
126 Expr xi = r[0] + r[2] * 10, yi = r[1] + r[3] * 10;
127 hist(x) = 0;
128 hist(f(clamp(xi, 0, 73), clamp(yi, 0, 73))) +=
129 select(xi >= 0 && xi <= 73 && yi >= 0 && yi <= 73, 1, 0);
130
131 Module m = hist.compile_to_module({});
132 CountConditionals s;
133 m.functions().front().body.accept(&s);
134 if (s.count != 0) {
135 std::cerr << "There were selects in the lowered code: \n"
136 << m.functions().front().body << "\n";
137 return -1;
138 }
139 }
140 Buffer<int> hist_result = hist.realize(256);
141
142 // Also check the output is correct.
143 Func true_hist;
144 {
145 RDom r(0, 74, 0, 74);
146 true_hist(x) = 0;
147 true_hist(f(r.x, r.y)) += 1;
148 }
149 Buffer<int> true_hist_result = true_hist.realize(256);
150
151 for (int i = 0; i < 256; i++) {
152 if (hist_result(i) != true_hist_result(i)) {
153 printf("hist(%d) = %d instead of %d\n",
154 i, hist_result(i), true_hist_result(i));
155 return -1;
156 }
157 }
158 }
159
160 printf("Running tiled iteration over triangle test\n");
161 {
162 // Test tiled iteration over a triangle, where the condition is an
163 // if statement instead of a select.
164 Func f;
165 Var x, y;
166 f(x, y) = select(2 * x < y, 5, undef<int>());
167
168 Var xi, yi;
169 f.tile(x, y, xi, yi, 4, 4);
170
171 // Check there are no if statements.
172 Module m = f.compile_to_module({});
173 CountConditionals s;
174 m.functions().front().body.accept(&s);
175 if (s.count != 0) {
176 std::cerr << "There were selects or ifs in the lowered code: \n"
177 << m.functions().front().body << "\n";
178 return -1;
179 }
180 }
181
182 // Test tiled iteration on the gpu if there is support for GPU.
183 // The gpu loop variable should not depend on outer gpu loop var.
184 if (!get_jit_target_from_environment().has_gpu_feature()) {
185 // TODO: split this test apart so the GPU pieces can be split appropriately
186 // printf("[SKIP] No GPU target enabled.\n");
187 printf("Success!\n");
188 return 0;
189 }
190
191 printf("Running select is not simplified on gpu test\n");
192 {
193 Func f;
194 Var x, y;
195 f(x, y) = x + y;
196
197 RDom r(0, 100, 0, 100);
198 f(r.x, r.y) += select((r.x < r.y) && (r.x == 10), 3, undef<int>());
199
200 RVar rxi, ryi;
201 f.update(0).gpu_tile(r.x, r.y, rxi, ryi, 4, 4);
202
203 Buffer<int> im = f.realize(200, 200);
204
205 // There should be no selects after trim_no_ops runs. The select should
206 // be lifted out as if condition. We can't trim gpu loop r.x based on the
207 // if condition since it depends on gpu outer loop r.y
208 Target gpu_target(get_host_target());
209 gpu_target.set_feature(Target::CUDA);
210 Module m = f.compile_to_module({}, "", gpu_target);
211 CountConditionals s;
212 m.functions().front().body.accept(&s);
213 if (s.count_select != 0) {
214 std::cerr << "There were selects in the lowered code: \n"
215 << m.functions().front().body << "\n";
216 return -1;
217 }
218 if (s.count_if != 1) {
219 std::cerr << "There should be 1 if in the lowered code: \n"
220 << m.functions().front().body << "\n";
221 return -1;
222 }
223
224 for (int y = 0; y < im.height(); y++) {
225 for (int x = 0; x < im.width(); x++) {
226 int correct = x + y;
227 if ((x == 10) && (0 <= y && y <= 99)) {
228 correct += (x < y) ? 3 : 0;
229 }
230 if (im(x, y) != correct) {
231 printf("im(%d, %d) = %d instead of %d\n",
232 x, y, im(x, y), correct);
233 return -1;
234 }
235 }
236 }
237 }
238
239 printf("Success!\n");
240 return 0;
241 }
242