#include "Halide.h"

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <functional>
#include <iomanip>
#include <ios>
#include <iostream>
#include <string>
7 
8 using namespace Halide;
9 
10 #if defined(__SSE2__) || defined(__AVX__)
11 #include <immintrin.h>
12 #endif
13 
#ifdef __SSE2__
// Sum of squares of in[0..count) using 4-wide SSE multiplies and adds (no FMA).
// The four lane partial sums are reduced in lane order 0..3, then any
// count % 4 trailing elements are added with scalar arithmetic.
// Unaligned loads are used, so `in` does not need 16-byte alignment.
float no_fma_dot_prod_sse(const float *in, int count) {
    __m128 sum = _mm_set1_ps(0.0f);
    const int vec_count = count / 4;
    for (int i = 0; i < vec_count; i++) {
        const __m128 v = _mm_loadu_ps(in + 4 * i);
        const __m128 prod = _mm_mul_ps(v, v);
        sum = _mm_add_ps(prod, sum);
    }
    // Spill the accumulator to memory instead of type-punning it through a float*.
    float lanes[4];
    _mm_storeu_ps(lanes, sum);
    float result = 0.0f;
    for (float lane : lanes) {
        result += lane;
    }
    // Elements that do not fill a whole vector.
    for (int i = vec_count * 4; i < count; i++) {
        result += in[i] * in[i];
    }
    return result;
}
#endif
30 
#if defined(__SSE2__) && defined(__FMA__)
// Sum of squares of in[0..count) using 4-wide fused multiply-adds.
// The four lane partial sums are reduced in lane order 0..3, then any
// count % 4 trailing elements are accumulated with scalar std::fma.
// Unaligned loads are used, so `in` does not need 16-byte alignment.
float fma_dot_prod_sse(const float *in, int count) {
    __m128 sum = _mm_set1_ps(0.0f);
    const int vec_count = count / 4;
    for (int i = 0; i < vec_count; i++) {
        const __m128 v = _mm_loadu_ps(in + 4 * i);
        sum = _mm_fmadd_ps(v, v, sum);
    }
    // Spill the accumulator to memory instead of type-punning it through a float*.
    float lanes[4];
    _mm_storeu_ps(lanes, sum);
    float result = 0.0f;
    for (float lane : lanes) {
        result += lane;
    }
    // Elements that do not fill a whole vector; keep them fused as well.
    for (int i = vec_count * 4; i < count; i++) {
        result = std::fma(in[i], in[i], result);
    }
    return result;
}
#endif
46 
#if defined(__AVX__)
// Sum of squares of in[0..count) using 8-wide AVX multiplies and adds (no FMA).
// The eight lane partial sums are reduced in lane order 0..7, then any
// count % 8 trailing elements are added with scalar arithmetic.
// Unaligned loads are used, so `in` does not need 32-byte alignment.
float no_fma_dot_prod_avx(const float *in, int count) {
    __m256 sum = _mm256_set1_ps(0.0f);
    const int vec_count = count / 8;
    for (int i = 0; i < vec_count; i++) {
        const __m256 v = _mm256_loadu_ps(in + 8 * i);
        const __m256 prod = _mm256_mul_ps(v, v);
        sum = _mm256_add_ps(prod, sum);
    }
    // Spill the accumulator to memory instead of type-punning it through a float*.
    float lanes[8];
    _mm256_storeu_ps(lanes, sum);
    float result = 0.0f;
    for (float lane : lanes) {
        result += lane;
    }
    // Elements that do not fill a whole vector.
    for (int i = vec_count * 8; i < count; i++) {
        result += in[i] * in[i];
    }
    return result;
}
#endif
63 
#if defined(__AVX__) && defined(__FMA__)
// Sum of squares of in[0..count) using 8-wide fused multiply-adds.
// The eight lane partial sums are reduced in lane order 0..7, then any
// count % 8 trailing elements are accumulated with scalar std::fma.
// Unaligned loads are used, so `in` does not need 32-byte alignment.
float fma_dot_prod_avx(const float *in, int count) {
    __m256 sum = _mm256_set1_ps(0.0f);
    const int vec_count = count / 8;
    for (int i = 0; i < vec_count; i++) {
        const __m256 v = _mm256_loadu_ps(in + 8 * i);
        sum = _mm256_fmadd_ps(v, v, sum);
    }
    // Spill the accumulator to memory instead of type-punning it through a float*.
    float lanes[8];
    _mm256_storeu_ps(lanes, sum);
    float result = 0.0f;
    for (float lane : lanes) {
        result += lane;
    }
    // Elements that do not fill a whole vector; keep them fused as well.
    for (int i = vec_count * 8; i < count; i++) {
        result = std::fma(in[i], in[i], result);
    }
    return result;
}
#endif
79 
one_million_rando_floats()80 Buffer<float> one_million_rando_floats() {
81     Var x("x");
82     Func randos;
83     randos(x) = random_float();
84     return randos.realize(1e6);
85 }
86 
87 ImageParam in(Float(32), 1);
88 
term(Expr index)89 Expr term(Expr index) {
90     return in(index)*in(index);
91 }
92 
// Selects whether apply_strictness() wraps pipeline expressions in strict_float().
enum class FloatStrictness {
    Default,  // Expressions are returned unchanged.
    Strict    // Expressions are wrapped in strict_float().
};

// Strictness consulted by apply_strictness() while pipelines are being built;
// run_one_condition() assigns it before constructing each set of pipelines.
FloatStrictness global_strictness{FloatStrictness::Default};
97 
strictness_to_string(FloatStrictness strictness)98 std::string strictness_to_string(FloatStrictness strictness) {
99     if (strictness == FloatStrictness::Strict) {
100         return "strict_float";
101     }
102     return "default";
103 }
104 
apply_strictness(Expr x)105 Expr apply_strictness(Expr x) {
106     if (global_strictness == FloatStrictness::Strict) {
107         return strict_float(x);
108     }
109     return x;
110 }
111 
112 template<typename Accum>
simple_sum(int vectorize)113 Func simple_sum(int vectorize) {
114     Func total("total");
115     // Can't use rfactor because strict_float is not associative.
116     if (vectorize != 0) {
117         Func total_inner("total_inner");
118         RDom r_outer(0, in.width() / vectorize);
119         RDom r_lanes(0, vectorize);
120         Var i("i");
121         total_inner(i) = cast<Accum>(0);
122         total_inner(i) = apply_strictness(total_inner(i) + cast<Accum>(term(r_outer * vectorize + i)));
123         total() = cast<Accum>(0);
124         total() = apply_strictness(total() + total_inner(r_lanes));
125         total_inner.compute_at(total, Var::outermost());
126         total_inner.vectorize(i);
127         total_inner.update(0).vectorize(i);
128     } else {
129         RDom r(0, in.width(), "r");
130 
131         total() = apply_strictness(cast<Accum>(0));
132         total() = apply_strictness(total() + cast<Accum>(term(r)));
133     }
134 #if 0
135     if (vectorize != 0) {
136         RVar rxo("rxo"), rxi("rxi");
137         Var u("u");
138         Func intm = total.update(0).split(r, rxo, rxi, vectorize).rfactor({{rxi, u}});
139         intm.compute_at(total, Var::outermost());
140         intm.vectorize(u, vectorize);
141         intm.update(0).vectorize(u, vectorize);
142     }
143 #endif
144     return lambda(apply_strictness(cast<float>(total())));
145 }
146 
kahan_sum(int vectorize)147 Func kahan_sum(int vectorize) {
148     // Item 0 of the tuple valued k_sum is the sum and item 1 is an error compensation term.
149     // See: https://en.wikipedia.org/wiki/Kahan_summation_algorithm
150     Func k_sum("k_sum");
151 
152     // rfactor cannot prove associativity for the non-strict formulation and strict_float is not associative.
153     if (vectorize != 0) {
154         Func k_sum_inner("k_sum_inner");
155         RDom r_outer(0, in.width() / vectorize);
156         RDom r_lanes(0, vectorize);
157         Var i("i");
158         k_sum_inner(i) = Tuple(0.0f, 0.0f);
159         k_sum_inner(i) = Tuple(apply_strictness(k_sum_inner(i)[0] + (term(r_outer * vectorize + i) - k_sum_inner(i)[1])),
160                                apply_strictness((k_sum_inner(i)[0] + (term(r_outer * vectorize + i) - k_sum_inner(i)[1])) - k_sum_inner(i)[0]) - (term(r_outer * vectorize + i) - k_sum_inner(i)[1]));
161         k_sum() = Tuple(0.0f, 0.0f);
162         k_sum() = Tuple(apply_strictness(k_sum()[0] + (k_sum_inner(r_lanes)[0] - k_sum()[1])),
163                         apply_strictness((k_sum()[0] + (k_sum_inner(r_lanes)[0] - k_sum()[1])) - k_sum()[0]) - (k_sum_inner(r_lanes)[0] - k_sum()[1]));
164         k_sum_inner.compute_at(k_sum, Var::outermost());
165         k_sum_inner.vectorize(i);
166         k_sum_inner.update(0).vectorize(i);
167     } else {
168         RDom r(0, in.width(), "r");
169 
170         k_sum() = Tuple(0.0f, 0.0f);
171         k_sum() = Tuple(apply_strictness(k_sum()[0] + (term(r) - k_sum()[1])),
172                         apply_strictness((k_sum()[0] + (term(r) - k_sum()[1])) - k_sum()[0]) - (term(r) - k_sum()[1]));
173     }
174 
175     return lambda(k_sum()[0]);
176 }
177 
eval(Func f,const Target & t,const std::string & name,const std::string & suffix,float expected)178 float eval(Func f, const Target &t, const std::string &name, const std::string &suffix, float expected) {
179     float val = ((Buffer<float>)f.realize(t))();
180     std::cout << "        " << name << ": " << val;
181     if (expected != 0.0f) {
182         std::cout << " residual: " << val - expected;
183     }
184     std::cout << "\n";
185     return val;
186 }
187 
run_one_condition(const Target & t,FloatStrictness strictness,Buffer<float> vals)188 void run_one_condition(const Target &t, FloatStrictness strictness, Buffer<float> vals) {
189     global_strictness = strictness;
190     std::string suffix = "_" + t.to_string() + "_" + strictness_to_string(strictness);
191 
192     std::cout << "    Target: " << t.to_string() << " Strictness: " << strictness_to_string(strictness) << "\n";
193 
194     float simple_double = eval(simple_sum<double>(0), t, "simple_double", suffix, 0.0f);
195     float simple_double_vec_4 = eval(simple_sum<double>(4), t, "simple_double_vec_4", suffix, simple_double);
196     float simple_double_vec_8 = eval(simple_sum<double>(8), t, "simple_double_vec_8", suffix, simple_double);
197     float simple_float = eval(simple_sum<float>(0), t, "simple_float", suffix, simple_double);
198     float simple_float_vec_4 = eval(simple_sum<float>(4), t, "simple_float_vec_4", suffix, simple_double);
199     float simple_float_vec_8 = eval(simple_sum<float>(8), t, "simple_float_vec_8", suffix, simple_double);
200     float kahan = eval(kahan_sum(0), t, "kahan", suffix, simple_double);
201     float kahan_vec_4 = eval(kahan_sum(4), t, "kahan_vec_4", suffix, simple_double);
202     float kahan_vec_8 = eval(kahan_sum(8), t, "kahan_vec_8", suffix, simple_double);
203 
204 #ifdef __SSE2__
205     float vec_dot_prod_4 = no_fma_dot_prod_sse(&vals(0), vals.width());
206     std::cout << "        four wide no fma: " << vec_dot_prod_4 << " residual: " << vec_dot_prod_4 - simple_double << "\n";
207 #endif
208 
209 #if defined(__SSE2__) && defined(__FMA__)
210     float fma_dot_prod_4 = fma_dot_prod_sse(&vals(0), vals.width());
211     std::cout << "        four wide fma: " << fma_dot_prod_4 << " residual: " << fma_dot_prod_4 - simple_double << "\n";
212 #endif
213 
214 #if defined(__AVX__)
215     float vec_dot_prod_8 = no_fma_dot_prod_avx(&vals(0), vals.width());
216     std::cout << "        eight wide no fma: " << vec_dot_prod_8 << " residual: " << vec_dot_prod_8 - simple_double << "\n";
217 #endif
218 
219 #if defined(__AVX__) && defined(__FMA__)
220     float fma_dot_prod_8 = fma_dot_prod_avx(&vals(0), vals.width());
221     std::cout << "        eight wide fma: " << fma_dot_prod_8 << " residual: " << fma_dot_prod_8 - simple_double << "\n";
222 #endif
223 
224     if (strictness == FloatStrictness::Strict) {
225         // assert kahan is more accurate than simple method
226         assert((fabs(simple_double - kahan) <= fabs(simple_double - simple_float)));
227         // assert vecotorized kahan is more accurate than simple method
228         assert((fabs(simple_double - kahan_vec_4) <= fabs(simple_double - simple_float)));
229         assert((fabs(simple_double - kahan_vec_8) <= fabs(simple_double - simple_float)));
230         // Just use some vars for now.
231         assert(simple_double_vec_4 != 0 && simple_double_vec_8 != 0 && simple_float_vec_4 != 0 && simple_float_vec_8 != 0);
232     }
233 }
234 
run_all_conditions(const char * name,Buffer<float> & vals)235 void run_all_conditions(const char *name, Buffer<float> &vals) {
236     std::cout << "Running on " << name << " data:\n";
237 
238     Target loose{get_jit_target_from_environment().without_feature(Target::StrictFloat)};
239     Target strict{loose.with_feature(Target::StrictFloat)};
240 
241     run_one_condition(loose, FloatStrictness::Default, vals);
242     run_one_condition(strict, FloatStrictness::Default, vals);
243     run_one_condition(loose, FloatStrictness::Strict, vals);
244     run_one_condition(strict, FloatStrictness::Strict, vals);
245 }
246 
block_transposed_by_n(Buffer<float> & buf,int vectorize)247 Buffer<float> block_transposed_by_n(Buffer<float> &buf, int vectorize) {
248     Buffer<float> result(buf.width());
249 
250     int block_size = buf.width() / vectorize;
251     for (int32_t i = 0; i < block_size; i++) {
252         for (int32_t j = 0; j < vectorize; j++) {
253             result(i * vectorize + j) = buf(j * block_size + i);
254         }
255     }
256     return result;
257 }
258 
main(int argc,char ** argv)259 int main(int argc, char **argv) {
260     std::cout << std::setprecision(10);
261     Buffer<float> vals = one_million_rando_floats();
262     Buffer<float> transposed;
263     in.set(vals);
264     // Clean up stmt file by asserting clean division. Also eliminates needing boundary conditions.
265     in.dim(0).set_bounds(0, 1000000);
266 
267     // Random data, average case for error.
268     run_all_conditions("random", vals);
269     transposed = block_transposed_by_n(vals, 4);
270     in.set(transposed);
271     run_all_conditions("random transposed", transposed);
272 
273     // Ascending, best case for error.
274     std::sort(vals.begin(), vals.end());
275     in.set(vals);
276     run_all_conditions("sorted ascending", vals);
277     transposed = block_transposed_by_n(vals, 4);
278     in.set(transposed);
279     run_all_conditions("sorted ascending transposed", transposed);
280 
281     // Descending, worst case for error.
282     std::sort(vals.begin(), vals.end(), std::greater<float>());
283     in.set(vals);
284     run_all_conditions("sorted descending", vals);
285     transposed = block_transposed_by_n(vals, 4);
286     in.set(transposed);
287     run_all_conditions("sorted descending transposed", transposed);
288 
289     printf("Success!\n");
290 
291     return 0;
292 }
293