#include "Halide.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <functional>
#include <iomanip>
#include <ios>
#include <iostream>
#include <string>
7
8 using namespace Halide;
9
10 #if defined(__SSE2__) || defined(__AVX__)
11 #include <immintrin.h>
12 #endif
13
#ifdef __SSE2__
// Reference sum of squares of `in`, computed four lanes at a time with a
// separate multiply and add (no fused multiply-add intrinsic).
//
// Only the first (count / 4) * 4 elements are summed; any remainder is
// ignored. Each lane keeps its own running sum, and the four lane sums are
// then added in lane order.
//
// `in` does not need to be 16-byte aligned: unaligned loads are used.
// NOTE(review): depending on the compiler's floating-point contraction
// settings, the multiply and add may still be fused when FMA code generation
// is enabled -- confirm if the distinction from the FMA variant matters.
float no_fma_dot_prod_sse(const float *in, int count) {
    __m128 sum = _mm_setzero_ps();
    const int vector_count = count / 4;
    for (int i = 0; i < vector_count; i++) {
        const __m128 v = _mm_loadu_ps(in + 4 * i);
        const __m128 prod = _mm_mul_ps(v, v);
        sum = _mm_add_ps(prod, sum);
    }

    // Spill the accumulator to memory instead of aliasing it through a pointer cast.
    float lanes[4];
    _mm_storeu_ps(lanes, sum);

    float result = 0.0f;
    for (int i = 0; i < 4; i++) {
        result += lanes[i];
    }
    return result;
}
#endif
30
#if defined(__SSE2__) && defined(__FMA__)
// Reference sum of squares of `in`, computed four lanes at a time with a
// fused multiply-add per step.
//
// Only the first (count / 4) * 4 elements are summed; any remainder is
// ignored. Each lane keeps its own running sum, and the four lane sums are
// then added in lane order.
//
// `in` does not need to be 16-byte aligned: unaligned loads are used.
float fma_dot_prod_sse(const float *in, int count) {
    __m128 sum = _mm_setzero_ps();
    const int vector_count = count / 4;
    for (int i = 0; i < vector_count; i++) {
        const __m128 v = _mm_loadu_ps(in + 4 * i);
        sum = _mm_fmadd_ps(v, v, sum);
    }

    // Spill the accumulator to memory instead of aliasing it through a pointer cast.
    float lanes[4];
    _mm_storeu_ps(lanes, sum);

    float result = 0.0f;
    for (int i = 0; i < 4; i++) {
        result += lanes[i];
    }
    return result;
}
#endif
46
#if defined(__AVX__)
// Reference sum of squares of `in`, computed eight lanes at a time with a
// separate multiply and add (no fused multiply-add intrinsic).
//
// Only the first (count / 8) * 8 elements are summed; any remainder is
// ignored. Each lane keeps its own running sum, and the eight lane sums are
// then added in lane order.
//
// `in` does not need to be 32-byte aligned: unaligned loads are used.
// NOTE(review): depending on the compiler's floating-point contraction
// settings, the multiply and add may still be fused when FMA code generation
// is enabled -- confirm if the distinction from the FMA variant matters.
float no_fma_dot_prod_avx(const float *in, int count) {
    __m256 sum = _mm256_setzero_ps();
    const int vector_count = count / 8;
    for (int i = 0; i < vector_count; i++) {
        const __m256 v = _mm256_loadu_ps(in + 8 * i);
        const __m256 prod = _mm256_mul_ps(v, v);
        sum = _mm256_add_ps(prod, sum);
    }

    // Spill the accumulator to memory instead of aliasing it through a pointer cast.
    float lanes[8];
    _mm256_storeu_ps(lanes, sum);

    float result = 0.0f;
    for (int i = 0; i < 8; i++) {
        result += lanes[i];
    }
    return result;
}
#endif
63
#if defined(__AVX__) && defined(__FMA__)
// Reference sum of squares of `in`, computed eight lanes at a time with a
// fused multiply-add per step.
//
// Only the first (count / 8) * 8 elements are summed; any remainder is
// ignored. Each lane keeps its own running sum, and the eight lane sums are
// then added in lane order.
//
// `in` does not need to be 32-byte aligned: unaligned loads are used.
float fma_dot_prod_avx(const float *in, int count) {
    __m256 sum = _mm256_setzero_ps();
    const int vector_count = count / 8;
    for (int i = 0; i < vector_count; i++) {
        const __m256 v = _mm256_loadu_ps(in + 8 * i);
        sum = _mm256_fmadd_ps(v, v, sum);
    }

    // Spill the accumulator to memory instead of aliasing it through a pointer cast.
    float lanes[8];
    _mm256_storeu_ps(lanes, sum);

    float result = 0.0f;
    for (int i = 0; i < 8; i++) {
        result += lanes[i];
    }
    return result;
}
#endif
79
one_million_rando_floats()80 Buffer<float> one_million_rando_floats() {
81 Var x("x");
82 Func randos;
83 randos(x) = random_float();
84 return randos.realize(1e6);
85 }
86
87 ImageParam in(Float(32), 1);
88
term(Expr index)89 Expr term(Expr index) {
90 return in(index)*in(index);
91 }
92
// Whether pipeline expressions get wrapped in strict_float() when built.
enum class FloatStrictness {
    Default,
    Strict
};

// Strictness currently in effect; read by apply_strictness() and set by
// run_one_condition() before it builds its pipelines.
FloatStrictness global_strictness = FloatStrictness::Default;

// Label for a strictness setting, used in suffixes and log output.
// Anything other than Strict is reported as "default".
std::string strictness_to_string(FloatStrictness strictness) {
    return strictness == FloatStrictness::Strict ? "strict_float" : "default";
}
104
apply_strictness(Expr x)105 Expr apply_strictness(Expr x) {
106 if (global_strictness == FloatStrictness::Strict) {
107 return strict_float(x);
108 }
109 return x;
110 }
111
112 template<typename Accum>
simple_sum(int vectorize)113 Func simple_sum(int vectorize) {
114 Func total("total");
115 // Can't use rfactor because strict_float is not associative.
116 if (vectorize != 0) {
117 Func total_inner("total_inner");
118 RDom r_outer(0, in.width() / vectorize);
119 RDom r_lanes(0, vectorize);
120 Var i("i");
121 total_inner(i) = cast<Accum>(0);
122 total_inner(i) = apply_strictness(total_inner(i) + cast<Accum>(term(r_outer * vectorize + i)));
123 total() = cast<Accum>(0);
124 total() = apply_strictness(total() + total_inner(r_lanes));
125 total_inner.compute_at(total, Var::outermost());
126 total_inner.vectorize(i);
127 total_inner.update(0).vectorize(i);
128 } else {
129 RDom r(0, in.width(), "r");
130
131 total() = apply_strictness(cast<Accum>(0));
132 total() = apply_strictness(total() + cast<Accum>(term(r)));
133 }
134 #if 0
135 if (vectorize != 0) {
136 RVar rxo("rxo"), rxi("rxi");
137 Var u("u");
138 Func intm = total.update(0).split(r, rxo, rxi, vectorize).rfactor({{rxi, u}});
139 intm.compute_at(total, Var::outermost());
140 intm.vectorize(u, vectorize);
141 intm.update(0).vectorize(u, vectorize);
142 }
143 #endif
144 return lambda(apply_strictness(cast<float>(total())));
145 }
146
kahan_sum(int vectorize)147 Func kahan_sum(int vectorize) {
148 // Item 0 of the tuple valued k_sum is the sum and item 1 is an error compensation term.
149 // See: https://en.wikipedia.org/wiki/Kahan_summation_algorithm
150 Func k_sum("k_sum");
151
152 // rfactor cannot prove associativity for the non-strict formulation and strict_float is not associative.
153 if (vectorize != 0) {
154 Func k_sum_inner("k_sum_inner");
155 RDom r_outer(0, in.width() / vectorize);
156 RDom r_lanes(0, vectorize);
157 Var i("i");
158 k_sum_inner(i) = Tuple(0.0f, 0.0f);
159 k_sum_inner(i) = Tuple(apply_strictness(k_sum_inner(i)[0] + (term(r_outer * vectorize + i) - k_sum_inner(i)[1])),
160 apply_strictness((k_sum_inner(i)[0] + (term(r_outer * vectorize + i) - k_sum_inner(i)[1])) - k_sum_inner(i)[0]) - (term(r_outer * vectorize + i) - k_sum_inner(i)[1]));
161 k_sum() = Tuple(0.0f, 0.0f);
162 k_sum() = Tuple(apply_strictness(k_sum()[0] + (k_sum_inner(r_lanes)[0] - k_sum()[1])),
163 apply_strictness((k_sum()[0] + (k_sum_inner(r_lanes)[0] - k_sum()[1])) - k_sum()[0]) - (k_sum_inner(r_lanes)[0] - k_sum()[1]));
164 k_sum_inner.compute_at(k_sum, Var::outermost());
165 k_sum_inner.vectorize(i);
166 k_sum_inner.update(0).vectorize(i);
167 } else {
168 RDom r(0, in.width(), "r");
169
170 k_sum() = Tuple(0.0f, 0.0f);
171 k_sum() = Tuple(apply_strictness(k_sum()[0] + (term(r) - k_sum()[1])),
172 apply_strictness((k_sum()[0] + (term(r) - k_sum()[1])) - k_sum()[0]) - (term(r) - k_sum()[1]));
173 }
174
175 return lambda(k_sum()[0]);
176 }
177
eval(Func f,const Target & t,const std::string & name,const std::string & suffix,float expected)178 float eval(Func f, const Target &t, const std::string &name, const std::string &suffix, float expected) {
179 float val = ((Buffer<float>)f.realize(t))();
180 std::cout << " " << name << ": " << val;
181 if (expected != 0.0f) {
182 std::cout << " residual: " << val - expected;
183 }
184 std::cout << "\n";
185 return val;
186 }
187
run_one_condition(const Target & t,FloatStrictness strictness,Buffer<float> vals)188 void run_one_condition(const Target &t, FloatStrictness strictness, Buffer<float> vals) {
189 global_strictness = strictness;
190 std::string suffix = "_" + t.to_string() + "_" + strictness_to_string(strictness);
191
192 std::cout << " Target: " << t.to_string() << " Strictness: " << strictness_to_string(strictness) << "\n";
193
194 float simple_double = eval(simple_sum<double>(0), t, "simple_double", suffix, 0.0f);
195 float simple_double_vec_4 = eval(simple_sum<double>(4), t, "simple_double_vec_4", suffix, simple_double);
196 float simple_double_vec_8 = eval(simple_sum<double>(8), t, "simple_double_vec_8", suffix, simple_double);
197 float simple_float = eval(simple_sum<float>(0), t, "simple_float", suffix, simple_double);
198 float simple_float_vec_4 = eval(simple_sum<float>(4), t, "simple_float_vec_4", suffix, simple_double);
199 float simple_float_vec_8 = eval(simple_sum<float>(8), t, "simple_float_vec_8", suffix, simple_double);
200 float kahan = eval(kahan_sum(0), t, "kahan", suffix, simple_double);
201 float kahan_vec_4 = eval(kahan_sum(4), t, "kahan_vec_4", suffix, simple_double);
202 float kahan_vec_8 = eval(kahan_sum(8), t, "kahan_vec_8", suffix, simple_double);
203
204 #ifdef __SSE2__
205 float vec_dot_prod_4 = no_fma_dot_prod_sse(&vals(0), vals.width());
206 std::cout << " four wide no fma: " << vec_dot_prod_4 << " residual: " << vec_dot_prod_4 - simple_double << "\n";
207 #endif
208
209 #if defined(__SSE2__) && defined(__FMA__)
210 float fma_dot_prod_4 = fma_dot_prod_sse(&vals(0), vals.width());
211 std::cout << " four wide fma: " << fma_dot_prod_4 << " residual: " << fma_dot_prod_4 - simple_double << "\n";
212 #endif
213
214 #if defined(__AVX__)
215 float vec_dot_prod_8 = no_fma_dot_prod_avx(&vals(0), vals.width());
216 std::cout << " eight wide no fma: " << vec_dot_prod_8 << " residual: " << vec_dot_prod_8 - simple_double << "\n";
217 #endif
218
219 #if defined(__AVX__) && defined(__FMA__)
220 float fma_dot_prod_8 = fma_dot_prod_avx(&vals(0), vals.width());
221 std::cout << " eight wide fma: " << fma_dot_prod_8 << " residual: " << fma_dot_prod_8 - simple_double << "\n";
222 #endif
223
224 if (strictness == FloatStrictness::Strict) {
225 // assert kahan is more accurate than simple method
226 assert((fabs(simple_double - kahan) <= fabs(simple_double - simple_float)));
227 // assert vecotorized kahan is more accurate than simple method
228 assert((fabs(simple_double - kahan_vec_4) <= fabs(simple_double - simple_float)));
229 assert((fabs(simple_double - kahan_vec_8) <= fabs(simple_double - simple_float)));
230 // Just use some vars for now.
231 assert(simple_double_vec_4 != 0 && simple_double_vec_8 != 0 && simple_float_vec_4 != 0 && simple_float_vec_8 != 0);
232 }
233 }
234
run_all_conditions(const char * name,Buffer<float> & vals)235 void run_all_conditions(const char *name, Buffer<float> &vals) {
236 std::cout << "Running on " << name << " data:\n";
237
238 Target loose{get_jit_target_from_environment().without_feature(Target::StrictFloat)};
239 Target strict{loose.with_feature(Target::StrictFloat)};
240
241 run_one_condition(loose, FloatStrictness::Default, vals);
242 run_one_condition(strict, FloatStrictness::Default, vals);
243 run_one_condition(loose, FloatStrictness::Strict, vals);
244 run_one_condition(strict, FloatStrictness::Strict, vals);
245 }
246
block_transposed_by_n(Buffer<float> & buf,int vectorize)247 Buffer<float> block_transposed_by_n(Buffer<float> &buf, int vectorize) {
248 Buffer<float> result(buf.width());
249
250 int block_size = buf.width() / vectorize;
251 for (int32_t i = 0; i < block_size; i++) {
252 for (int32_t j = 0; j < vectorize; j++) {
253 result(i * vectorize + j) = buf(j * block_size + i);
254 }
255 }
256 return result;
257 }
258
main(int argc,char ** argv)259 int main(int argc, char **argv) {
260 std::cout << std::setprecision(10);
261 Buffer<float> vals = one_million_rando_floats();
262 Buffer<float> transposed;
263 in.set(vals);
264 // Clean up stmt file by asserting clean division. Also eliminates needing boundary conditions.
265 in.dim(0).set_bounds(0, 1000000);
266
267 // Random data, average case for error.
268 run_all_conditions("random", vals);
269 transposed = block_transposed_by_n(vals, 4);
270 in.set(transposed);
271 run_all_conditions("random transposed", transposed);
272
273 // Ascending, best case for error.
274 std::sort(vals.begin(), vals.end());
275 in.set(vals);
276 run_all_conditions("sorted ascending", vals);
277 transposed = block_transposed_by_n(vals, 4);
278 in.set(transposed);
279 run_all_conditions("sorted ascending transposed", transposed);
280
281 // Descending, worst case for error.
282 std::sort(vals.begin(), vals.end(), std::greater<float>());
283 in.set(vals);
284 run_all_conditions("sorted descending", vals);
285 transposed = block_transposed_by_n(vals, 4);
286 in.set(transposed);
287 run_all_conditions("sorted descending transposed", transposed);
288
289 printf("Success!\n");
290
291 return 0;
292 }
293