#include "Halide.h"
#include <algorithm>
#include <cstdint>
#include <limits>
#include <stdio.h>
4
5 using namespace Halide;
6
main(int argc,char ** argv)7 int main(int argc, char **argv) {
8 // Compute the variance of a 3x3 patch about each pixel
9 RDom r(-1, 3, -1, 3);
10
11 // Test a complex summation
12 Func input;
13 Var x, y, z;
14 input(x, y) = cast<float>(x * y + 1);
15
16 Func local_variance;
17 Expr input_val = input(x + r.x, y + r.y);
18 Expr local_mean = sum(input_val) / 9.0f;
19 local_variance(x, y) = sum(input_val * input_val) / 81.0f - local_mean * local_mean;
20
21 Buffer<float> result = local_variance.realize(10, 10);
22
23 for (int y = 0; y < 10; y++) {
24 for (int x = 0; x < 10; x++) {
25 float local_mean = 0;
26 float local_variance = 0;
27 for (int rx = -1; rx < 2; rx++) {
28 for (int ry = -1; ry < 2; ry++) {
29 float val = (x + rx) * (y + ry) + 1.0f;
30 local_mean += val;
31 local_variance += val * val;
32 }
33 }
34 local_mean /= 9.0f;
35 float correct = local_variance / 81.0f - local_mean * local_mean;
36 float r = result(x, y);
37 float delta = correct - r;
38 if (delta < -0.001 || delta > 0.001) {
39 printf("result(%d, %d) was %f instead of %f\n", x, y, r, correct);
40 return -1;
41 }
42 }
43 }
44
45 // Test the other reductions.
46 Func local_product, local_max, local_min;
47 local_product(x, y) = product(input_val);
48 local_max(x, y) = maximum(input_val);
49 local_min(x, y) = minimum(input_val);
50
51 // Try a separable form of minimum too, so we test two reductions
52 // in one pipeline.
53 Func min_x, min_y;
54 RDom kx(-1, 3), ky(-1, 3);
55 min_x(x, y) = minimum(input(x + kx, y));
56 min_y(x, y) = minimum(min_x(x, y + ky));
57
58 // Vectorize them all, to make life more interesting.
59 local_product.vectorize(x, 4);
60 local_max.vectorize(x, 4);
61 local_min.vectorize(x, 4);
62 min_y.vectorize(x, 4);
63
64 Buffer<float> prod_im = local_product.realize(10, 10);
65 Buffer<float> max_im = local_max.realize(10, 10);
66 Buffer<float> min_im = local_min.realize(10, 10);
67 Buffer<float> min_im_separable = min_y.realize(10, 10);
68
69 for (int y = 0; y < 10; y++) {
70 for (int x = 0; x < 10; x++) {
71 float correct_prod = 1.0f;
72 float correct_min = 1e10f;
73 float correct_max = -1e10f;
74 for (int rx = -1; rx < 2; rx++) {
75 for (int ry = -1; ry < 2; ry++) {
76 float val = (x + rx) * (y + ry) + 1.0f;
77 correct_prod *= val;
78 correct_min = std::min(correct_min, val);
79 correct_max = std::max(correct_max, val);
80 }
81 }
82
83 float delta;
84 delta = (correct_prod + 10) / (prod_im(x, y) + 10);
85 if (delta < 0.99 || delta > 1.01) {
86 printf("prod_im(%d, %d) = %f instead of %f\n", x, y, prod_im(x, y), correct_prod);
87 return -1;
88 }
89
90 delta = correct_min - min_im(x, y);
91 if (delta < -0.001 || delta > 0.001) {
92 printf("min_im(%d, %d) = %f instead of %f\n", x, y, min_im(x, y), correct_min);
93 return -1;
94 }
95
96 delta = correct_min - min_im_separable(x, y);
97 if (delta < -0.001 || delta > 0.001) {
98 printf("min_im(%d, %d) = %f instead of %f\n", x, y, min_im_separable(x, y), correct_min);
99 return -1;
100 }
101
102 delta = correct_max - max_im(x, y);
103 if (delta < -0.001 || delta > 0.001) {
104 printf("max_im(%d, %d) = %f instead of %f\n", x, y, max_im(x, y), correct_max);
105 return -1;
106 }
107 }
108 }
109
110 // Verify that all inline reductions compile with implicit argument syntax.
111 Buffer<float> input_3d = lambda(x, y, z, x * 100.0f + y * 10.0f + ((z + 5 % 10))).realize(10, 10, 10);
112 RDom all_z(input_3d.min(2), input_3d.extent(2));
113
114 Func sum_implicit;
115 sum_implicit(_) = sum(input_3d(_, all_z));
116 Buffer<float> sum_implicit_im = sum_implicit.realize(10, 10);
117
118 Func product_implicit;
119 product_implicit(_) = product(input_3d(_, all_z));
120 Buffer<float> product_implicit_im = product_implicit.realize(10, 10);
121
122 Func min_implicit;
123 min_implicit(_) = minimum(input_3d(_, all_z));
124 Buffer<float> min_implicit_im = min_implicit.realize(10, 10);
125
126 Func max_implicit;
127 max_implicit(_, y) = maximum(input_3d(_, y, all_z));
128 Buffer<float> max_implicit_im = max_implicit.realize(10, 10);
129
130 Func argmin_implicit;
131 argmin_implicit(_) = argmin(input_3d(_, all_z))[0];
132 Buffer<int32_t> argmin_implicit_im = argmin_implicit.realize(10, 10);
133
134 Func argmax_implicit;
135 argmax_implicit(x, _) = argmax(input_3d(x, _, all_z))[0];
136 Buffer<int32_t> argmax_implicit_im = argmax_implicit.realize(10, 10);
137
138 // Verify that the min of negative floats and doubles is correct
139 // (this used to be buggy due to the minimum float being the
140 // smallest positive float instead of the smallest float).
141 float result_f32 = evaluate<float>(minimum(RDom(0, 11) * -0.5f));
142 if (result_f32 != -5.0f) {
143 printf("minimum is %f instead of -5.0f\n", result_f32);
144 return -1;
145 }
146
147 double result_f64 = evaluate<double>(minimum(RDom(0, 11) * cast<double>(-0.5f)));
148 if (result_f64 != -5.0) {
149 printf("minimum is %f instead of -5.0\n", result_f64);
150 return -1;
151 }
152
153 // Check that min of a bunch of infinities is infinity.
154 // Be sure to use strict_float() so that LLVM doesn't optimize away
155 // the infinities.
156 const float inf_f32 = std::numeric_limits<float>::infinity();
157 const double inf_f64 = std::numeric_limits<double>::infinity();
158 result_f32 = evaluate<float>(minimum(strict_float(RDom(1, 10) * inf_f32)));
159 if (result_f32 != inf_f32) {
160 printf("minimum is %f instead of infinity\n", result_f32);
161 return -1;
162 }
163 result_f64 = evaluate<double>(minimum(strict_float(RDom(1, 10) * Expr(inf_f64))));
164 if (result_f64 != inf_f64) {
165 printf("minimum is %f instead of infinity\n", result_f64);
166 return -1;
167 }
168 result_f32 = evaluate<float>(maximum(strict_float(RDom(1, 10) * -inf_f32)));
169 if (result_f32 != -inf_f32) {
170 printf("maximum is %f instead of -infinity\n", result_f32);
171 return -1;
172 }
173 result_f64 = evaluate<double>(maximum(strict_float(RDom(1, 10) * Expr(-inf_f64))));
174 if (result_f64 != -inf_f64) {
175 printf("maximum is %f instead of -infinity\n", result_f64);
176 return -1;
177 }
178
179 printf("Success!\n");
180 return 0;
181 }
182