1 #include "Halide.h"
2 #include <stdio.h>
3 
4 #include "testing.h"
5 
6 using namespace Halide;
7 
test_per_channel_select()8 int test_per_channel_select() {
9 
10     printf("Testing select of channel.\n");
11 
12     // This test must be run with an OpenGL target.
13     const Target target = get_jit_target_from_environment().with_feature(Target::OpenGL);
14 
15     Func gpu("gpu"), cpu("cpu");
16     Var x("x"), y("y"), c("c");
17 
18     gpu(x, y, c) = cast<uint8_t>(mux(c, {128, x, y, x * y}));
19     gpu.bound(c, 0, 4);
20     gpu.glsl(x, y, c);
21     gpu.compute_root();
22 
23     cpu(x, y, c) = gpu(x, y, c);
24 
25     Buffer<uint8_t> out(10, 10, 4);
26     cpu.realize(out, target);
27 
28     // Verify the result
29     if (!Testing::check_result<uint8_t>(out, [&](int x, int y, int c) {
30 	    switch (c) {
31 		case 0: return 128;
32 		case 1: return x;
33 		case 2: return y;
34 		default: return x*y;
35 	    } })) {
36         return 1;
37     }
38 
39     return 0;
40 }
41 
test_flag_scalar_select()42 int test_flag_scalar_select() {
43 
44     printf("Testing select of scalar value with flag.\n");
45 
46     // This test must be run with an OpenGL target.
47     const Target target = get_jit_target_from_environment().with_feature(Target::OpenGL);
48 
49     Func gpu("gpu"), cpu("cpu");
50     Var x("x"), y("y"), c("c");
51 
52     int flag_value = 0;
53 
54     Param<int> flag("flag");
55     flag.set(flag_value);
56 
57     gpu(x, y, c) = cast<uint8_t>(select(flag != 0, 128,
58                                         255));
59     gpu.bound(c, 0, 4);
60     gpu.glsl(x, y, c);
61     gpu.compute_root();
62 
63     // This should trigger a copy_to_host operation
64     cpu(x, y, c) = gpu(x, y, c);
65 
66     Buffer<uint8_t> out(10, 10, 4);
67     cpu.realize(out, target);
68 
69     // Verify the result
70     if (!Testing::check_result<uint8_t>(out, [&](int x, int y, int c) {
71             return !flag_value ? 255 : 128;
72         })) {
73         return 1;
74     }
75 
76     return 0;
77 }
78 
test_flag_pixel_select()79 int test_flag_pixel_select() {
80 
81     printf("Testing select of pixel value with flag.\n");
82 
83     // This test must be run with an OpenGL target.
84     const Target target = get_jit_target_from_environment().with_feature(Target::OpenGL);
85 
86     Func gpu("gpu"), cpu("cpu");
87     Var x("x"), y("y"), c("c");
88 
89     int flag_value = 0;
90 
91     Param<int> flag("flag");
92     flag.set(flag_value);
93 
94     Buffer<uint8_t> image(10, 10, 4);
95     for (int y = 0; y < image.height(); y++) {
96         for (int x = 0; x < image.width(); x++) {
97             for (int c = 0; c < image.channels(); c++) {
98                 image(x, y, c) = 128;
99             }
100         }
101     }
102 
103     gpu(x, y, c) = cast<uint8_t>(select(flag != 0, image(x, y, c),
104                                         255));
105     gpu.bound(c, 0, 4);
106     gpu.glsl(x, y, c);
107     gpu.compute_root();
108 
109     // This should trigger a copy_to_host operation
110     cpu(x, y, c) = gpu(x, y, c);
111 
112     Buffer<uint8_t> out(10, 10, 4);
113     cpu.realize(out, target);
114 
115     // Verify the result
116     if (!Testing::check_result<uint8_t>(out, [&](int x, int y, int c) {
117             return !flag_value ? 255 : 128;
118         })) {
119         return 1;
120     }
121 
122     return 0;
123 }
124 
test_nested_select()125 int test_nested_select() {
126 
127     printf("Testing nested select.\n");
128 
129     // This test must be run with an OpenGL target.
130     const Target target = get_jit_target_from_environment().with_feature(Target::OpenGL);
131 
132     // Define the algorithm.
133     Var x("x"), y("y"), c("c");
134     Func f("f");
135     Expr temp = cast<uint8_t>(select(x == 0, 1, 2));
136     f(x, y, c) = select(y == 0, temp, 255 - temp);
137 
138     // Schedule f to run on the GPU.
139     const int channels = 3;
140     f.bound(c, 0, channels).glsl(x, y, c);
141 
142     // Generate the result.
143     const int width = 10, height = 10;
144     Buffer<uint8_t> out = f.realize(width, height, channels, target);
145 
146     // Check the result.
147     int errors = 0;
148     out.for_each_element([&](int x, int y, int c) {
149         uint8_t temp = x == 0 ? 1 : 2;
150         uint8_t expected = y == 0 ? temp : 255 - temp;
151         uint8_t actual = out(x, y, c);
152         if (expected != actual && ++errors == 1) {
153             fprintf(stderr, "out(%d, %d, %d) = %d instead of %d\n",
154                     x, y, c, actual, expected);
155         }
156     });
157 
158     return errors;
159 }
160 
test_nested_select_varying()161 int test_nested_select_varying() {
162 
163     printf("Testing nested select with varying condition.\n");
164 
165     // This test must be run with an OpenGL target.
166     const Target target = get_jit_target_from_environment().with_feature(Target::OpenGL);
167 
168     // Define the algorithm.
169     Var x("x"), y("y"), c("c");
170     Func f("f");
171     Expr temp = cast<uint8_t>(select(x - c > 0, 1, 2));
172     f(x, y, c) = select(y == 0, temp, 255 - temp);
173 
174     // Schedule f to run on the GPU.
175     const int channels = 3;
176     f.bound(c, 0, channels).glsl(x, y, c);
177 
178     // Generate the result.
179     const int width = 10, height = 10;
180     Buffer<uint8_t> out = f.realize(width, height, channels, target);
181 
182     // Check the result.
183     int errors = 0;
184     out.for_each_element([&](int x, int y, int c) {
185         uint8_t temp = x - c > 0 ? 1 : 2;
186         uint8_t expected = y == 0 ? temp : 255 - temp;
187         uint8_t actual = out(x, y, c);
188         if (expected != actual && ++errors == 1) {
189             fprintf(stderr, "out(%d, %d, %d) = %d instead of %d\n",
190                     x, y, c, actual, expected);
191         }
192     });
193 
194     return errors;
195 }
196 
main()197 int main() {
198 
199     int err = 0;
200 
201     err |= test_per_channel_select();
202     err |= test_flag_scalar_select();
203     err |= test_flag_pixel_select();
204     err |= test_nested_select();
205     err |= test_nested_select_varying();
206 
207     if (err) {
208         printf("FAILED\n");
209         return 1;
210     }
211 
212     printf("Success!\n");
213     return 0;
214 }
215