1 #include "Halide.h"
2 #include <algorithm>
3 #include <cmath>
4 #include <future>
5 #include <math.h>
6 #include <random>
7 #include <stdio.h>
8 #include <string.h>
9 
10 using namespace Halide;
11 
12 // Make some functions for turning types into strings
13 template<typename A>
14 const char *string_of_type();
15 
16 #define DECL_SOT(name)                   \
17     template<>                           \
18     const char *string_of_type<name>() { \
19         return #name;                    \
20     }
21 
22 DECL_SOT(uint8_t);
23 DECL_SOT(int8_t);
24 DECL_SOT(uint16_t);
25 DECL_SOT(int16_t);
26 DECL_SOT(uint32_t);
27 DECL_SOT(int32_t);
28 DECL_SOT(float);
29 DECL_SOT(double);
30 DECL_SOT(float16_t);
31 DECL_SOT(bfloat16_t);
32 
// Remainder of x divided by y for element type A. Floating-point types are
// handled by explicit specializations because operator% does not apply to them.
template<typename A>
A mod(A x, A y);

// Single-precision remainder via the math library.
template<>
float mod(float x, float y) {
    const float remainder = std::fmod(x, y);
    return remainder;
}

// Double-precision remainder via the math library.
template<>
double mod(double x, double y) {
    const double remainder = std::fmod(x, y);
    return remainder;
}
45 
46 template<>
mod(float16_t x,float16_t y)47 float16_t mod(float16_t x, float16_t y) {
48     return float16_t(fmod(float(x), float(y)));
49 }
50 
51 template<>
mod(bfloat16_t x,bfloat16_t y)52 bfloat16_t mod(bfloat16_t x, bfloat16_t y) {
53     return bfloat16_t(fmod(float(x), float(y)));
54 }
55 
// Generic remainder: whatever the built-in % operator yields for A,
// converted back to A.
template<typename A>
A mod(A x, A y) {
    const A remainder = x % y;
    return remainder;
}
60 
// Result comparison used by the checks that tolerate rounding differences.
// By default two values match only when they compare equal.
template<typename A>
bool close_enough(A x, A y) {
    const bool identical = (x == y);
    return identical;
}

// float: match when the absolute difference is below 1e-4.
template<>
bool close_enough<float>(float x, float y) {
    const float gap = std::fabs(x - y);
    return gap < 1e-4;
}

// double: match when the absolute difference is below 1e-5.
template<>
bool close_enough<double>(double x, double y) {
    const double gap = std::fabs(x - y);
    return gap < 1e-5;
}
75 
76 template<>
close_enough(float16_t x,float16_t y)77 bool close_enough<float16_t>(float16_t x, float16_t y) {
78     if (x == y) return true;
79     float16_t upper = float16_t::make_from_bits(x.to_bits() + 2);
80     float16_t lower = float16_t::make_from_bits(x.to_bits() - 2);
81     if (lower > upper) std::swap(lower, upper);
82     return y >= lower && y <= upper;
83 }
84 
85 template<>
close_enough(bfloat16_t x,bfloat16_t y)86 bool close_enough<bfloat16_t>(bfloat16_t x, bfloat16_t y) {
87     if (x == y) return true;
88     bfloat16_t upper = bfloat16_t::make_from_bits(x.to_bits() + 2);
89     bfloat16_t lower = bfloat16_t::make_from_bits(x.to_bits() - 2);
90     if (lower > upper) std::swap(lower, upper);
91     return (y >= lower) && (y <= upper);
92 }
93 
// Reference division for the element type under test. The generic form is
// used for integer types: it first removes a remainder that has been shifted
// into [0, y) (for positive y), so the quotient rounds toward negative
// infinity instead of toward zero.
template<typename T>
T divide(T x, T y) {
    const auto wrapped_remainder = ((x % y) + y) % y;
    return (x - wrapped_remainder) / y;
}

// float: ordinary floating-point division.
template<>
float divide(float x, float y) {
    x /= y;
    return x;
}

// double: ordinary floating-point division.
template<>
double divide(double x, double y) {
    x /= y;
    return x;
}
108 
109 template<>
divide(float16_t x,float16_t y)110 float16_t divide(float16_t x, float16_t y) {
111     return x / y;
112 }
113 
114 template<>
divide(bfloat16_t x,bfloat16_t y)115 bfloat16_t divide(bfloat16_t x, bfloat16_t y) {
116     return x / y;
117 }
118 
// Scalar reference for absolute difference: subtracts the smaller operand
// from the larger one, so the subtraction never runs "backwards".
template<typename A>
A absd(A x, A y) {
    if (x > y) {
        return x - y;
    }
    return y - x;
}
123 
// Returns the low 23 bits of the float's in-memory representation (the
// fraction field of a 32-bit IEEE float), with sign and exponent masked away.
int mantissa(float x) {
    uint32_t raw = 0;
    // Copy the bytes rather than casting pointers to read the representation.
    memcpy(&raw, &x, sizeof(raw));
    const uint32_t fraction_mask = 0x007fffffu;
    return static_cast<int>(raw & fraction_mask);
}
129 
// Type trait: with_unsigned<T>::type is the unsigned integer type of the same
// width when T is a signed fixed-width integer, and T itself for every other
// type.
template<typename T>
struct with_unsigned {
    using type = T;
};

template<>
struct with_unsigned<int8_t> {
    using type = uint8_t;
};

template<>
struct with_unsigned<int16_t> {
    using type = uint16_t;
};

template<>
struct with_unsigned<int32_t> {
    using type = uint32_t;
};

template<>
struct with_unsigned<int64_t> {
    using type = uint64_t;
};
154 
155 template<typename A>
test(int lanes,int seed)156 bool test(int lanes, int seed) {
157     const int W = 320;
158     const int H = 16;
159 
160     const int verbose = false;
161 
162     printf("Testing %sx%d\n", string_of_type<A>(), lanes);
163 
164     // use std::mt19937 instead of rand() to ensure consistent behavior on all systems
165     std::mt19937 rng(seed);
166     std::uniform_int_distribution<> dis(0, 1023);
167 
168     Buffer<A> input(W + 16, H + 16);
169     for (int y = 0; y < H + 16; y++) {
170         for (int x = 0; x < W + 16; x++) {
171             // We must ensure that the result of casting is not out-of-range:
172             // float->int casts are UB if the result doesn't fit.
173             input(x, y) = (A)(dis(rng) * 0.0625 + 1.0);
174             if ((A)(-1) < (A)(0)) {
175                 input(x, y) -= (A)(10);
176             }
177         }
178     }
179     Var x, y;
180 
181     // Add
182     {
183         if (verbose) printf("Add\n");
184         Func f1;
185         f1(x, y) = input(x, y) + input(x + 1, y);
186         f1.vectorize(x, lanes);
187         Buffer<A> im1 = f1.realize(W, H);
188 
189         for (int y = 0; y < H; y++) {
190             for (int x = 0; x < W; x++) {
191                 A correct = input(x, y) + input(x + 1, y);
192                 if (im1(x, y) != correct) {
193                     printf("im1(%d, %d) = %f instead of %f\n", x, y, (double)(im1(x, y)), (double)(correct));
194                     return false;
195                 }
196             }
197         }
198     }
199 
200     // Sub
201     {
202         if (verbose) printf("Subtract\n");
203         Func f2;
204         f2(x, y) = input(x, y) - input(x + 1, y);
205         f2.vectorize(x, lanes);
206         Buffer<A> im2 = f2.realize(W, H);
207 
208         for (int y = 0; y < H; y++) {
209             for (int x = 0; x < W; x++) {
210                 A correct = input(x, y) - input(x + 1, y);
211                 if (im2(x, y) != correct) {
212                     printf("im2(%d, %d) = %f instead of %f\n", x, y, (double)(im2(x, y)), (double)(correct));
213                     return false;
214                 }
215             }
216         }
217     }
218 
219     // Mul
220     {
221         if (verbose) printf("Multiply\n");
222         Func f3;
223         f3(x, y) = input(x, y) * input(x + 1, y);
224         f3.vectorize(x, lanes);
225         Buffer<A> im3 = f3.realize(W, H);
226 
227         for (int y = 0; y < H; y++) {
228             for (int x = 0; x < W; x++) {
229                 A correct = input(x, y) * input(x + 1, y);
230                 if (im3(x, y) != correct) {
231                     printf("im3(%d, %d) = %f instead of %f\n", x, y, (double)(im3(x, y)), (double)(correct));
232                     return false;
233                 }
234             }
235         }
236     }
237 
238     // select
239     {
240         if (verbose) printf("Select\n");
241         Func f4;
242         f4(x, y) = select(input(x, y) > input(x + 1, y), input(x + 2, y), input(x + 3, y));
243         f4.vectorize(x, lanes);
244         Buffer<A> im4 = f4.realize(W, H);
245 
246         for (int y = 0; y < H; y++) {
247             for (int x = 0; x < W; x++) {
248                 A correct = input(x, y) > input(x + 1, y) ? input(x + 2, y) : input(x + 3, y);
249                 if (im4(x, y) != correct) {
250                     printf("im4(%d, %d) = %f instead of %f\n", x, y, (double)(im4(x, y)), (double)(correct));
251                     return false;
252                 }
253             }
254         }
255     }
256 
257     // Gather
258     {
259         if (verbose) printf("Gather\n");
260         Func f5;
261         Expr xCoord = clamp(cast<int>(input(x, y)), 0, W - 1);
262         Expr yCoord = clamp(cast<int>(input(x + 1, y)), 0, H - 1);
263         f5(x, y) = input(xCoord, yCoord);
264         f5.vectorize(x, lanes);
265         Buffer<A> im5 = f5.realize(W, H);
266 
267         for (int y = 0; y < H; y++) {
268             for (int x = 0; x < W; x++) {
269                 int xCoord = (int)(input(x, y));
270                 if (xCoord >= W) xCoord = W - 1;
271                 if (xCoord < 0) xCoord = 0;
272 
273                 int yCoord = (int)(input(x + 1, y));
274                 if (yCoord >= H) yCoord = H - 1;
275                 if (yCoord < 0) yCoord = 0;
276 
277                 A correct = input(xCoord, yCoord);
278 
279                 if (im5(x, y) != correct) {
280                     printf("im5(%d, %d) = %f instead of %f\n", x, y, (double)(im5(x, y)), (double)(correct));
281                     return false;
282                 }
283             }
284         }
285     }
286 
287     // Gather and scatter with constant but unknown stride
288     {
289         Func f5a;
290         f5a(x, y) = input(x, y) * cast<A>(2);
291         f5a.vectorize(y, lanes);
292         Buffer<A> im5a = f5a.realize(W, H);
293 
294         for (int y = 0; y < H; y++) {
295             for (int x = 0; x < W; x++) {
296                 A correct = input(x, y) * ((A)(2));
297                 if (im5a(x, y) != correct) {
298                     printf("im5a(%d, %d) = %f instead of %f\n", x, y, (double)(im5a(x, y)), (double)(correct));
299                     return false;
300                 }
301             }
302         }
303     }
304 
305     // Scatter
306     {
307         if (verbose) printf("Scatter\n");
308         Func f6;
309         // Set one entry in each column high
310         f6(x, y) = 0;
311         f6(x, clamp(x * x, 0, H - 1)) = 1;
312 
313         f6.update().vectorize(x, lanes);
314 
315         Buffer<int> im6 = f6.realize(W, H);
316 
317         for (int x = 0; x < W; x++) {
318             int yCoord = x * x;
319             if (yCoord >= H) yCoord = H - 1;
320             if (yCoord < 0) yCoord = 0;
321             for (int y = 0; y < H; y++) {
322                 int correct = y == yCoord ? 1 : 0;
323                 if (im6(x, y) != correct) {
324                     printf("im6(%d, %d) = %d instead of %d\n", x, y, im6(x, y), correct);
325                     return false;
326                 }
327             }
328         }
329     }
330 
331     // Min/max
332     {
333         if (verbose) printf("Min/max\n");
334         Func f7;
335         f7(x, y) = clamp(input(x, y), cast<A>(10), cast<A>(20));
336         f7.vectorize(x, lanes);
337         Buffer<A> im7 = f7.realize(W, H);
338 
339         for (int y = 0; y < H; y++) {
340             for (int x = 0; x < W; x++) {
341                 if (im7(x, y) < (A)10 || im7(x, y) > (A)20) {
342                     printf("im7(%d, %d) = %f\n", x, y, (double)(im7(x, y)));
343                     return false;
344                 }
345             }
346         }
347     }
348 
349     // Extern function call
350     {
351         if (verbose) printf("External call to hypot\n");
352         Func f8;
353         f8(x, y) = hypot(1.1f, cast<float>(input(x, y)));
354         f8.vectorize(x, lanes);
355         Buffer<float> im8 = f8.realize(W, H);
356 
357         for (int y = 0; y < H; y++) {
358             for (int x = 0; x < W; x++) {
359                 float correct = hypotf(1.1f, (float)input(x, y));
360                 if (!close_enough(im8(x, y), correct)) {
361                     printf("im8(%d, %d) = %f instead of %f\n",
362                            x, y, (double)im8(x, y), correct);
363                     return false;
364                 }
365             }
366         }
367     }
368 
369     // Div
370     {
371         if (verbose) printf("Division\n");
372         Func f9;
373         f9(x, y) = input(x, y) / clamp(input(x + 1, y), cast<A>(1), cast<A>(3));
374         f9.vectorize(x, lanes);
375         Buffer<A> im9 = f9.realize(W, H);
376 
377         for (int y = 0; y < H; y++) {
378             for (int x = 0; x < W; x++) {
379                 A clamped = input(x + 1, y);
380                 if (clamped < (A)1) clamped = (A)1;
381                 if (clamped > (A)3) clamped = (A)3;
382                 A correct = divide(input(x, y), clamped);
383                 // We allow floating point division to take some liberties with accuracy
384                 if (!close_enough(im9(x, y), correct)) {
385                     printf("im9(%d, %d) = %f/%f = %f instead of %f\n",
386                            x, y,
387                            (double)input(x, y), (double)clamped,
388                            (double)(im9(x, y)), (double)(correct));
389                     return false;
390                 }
391             }
392         }
393     }
394 
395     // Divide by small constants
396     {
397         if (verbose) printf("Dividing by small constants\n");
398         for (int c = 2; c < 16; c++) {
399             Func f10;
400             f10(x, y) = (input(x, y)) / cast<A>(Expr(c));
401             f10.vectorize(x, lanes);
402             Buffer<A> im10 = f10.realize(W, H);
403 
404             for (int y = 0; y < H; y++) {
405                 for (int x = 0; x < W; x++) {
406                     A correct = divide(input(x, y), (A)c);
407 
408                     if (!close_enough(im10(x, y), correct)) {
409                         printf("im10(%d, %d) = %f/%d = %f instead of %f\n", x, y,
410                                (double)(input(x, y)), c,
411                                (double)(im10(x, y)),
412                                (double)(correct));
413                         printf("Error when dividing by %d\n", c);
414                         return false;
415                     }
416                 }
417             }
418         }
419     }
420 
421     // Interleave
422     {
423         if (verbose) printf("Interleaving store\n");
424         Func f11;
425         f11(x, y) = select((x % 2) == 0, input(x / 2, y), input(x / 2, y + 1));
426         f11.vectorize(x, lanes);
427         Buffer<A> im11 = f11.realize(W, H);
428 
429         for (int y = 0; y < H; y++) {
430             for (int x = 0; x < W; x++) {
431                 A correct = ((x % 2) == 0) ? input(x / 2, y) : input(x / 2, y + 1);
432                 if (im11(x, y) != correct) {
433                     printf("im11(%d, %d) = %f instead of %f\n", x, y, (double)(im11(x, y)), (double)(correct));
434                     return false;
435                 }
436             }
437         }
438     }
439 
440     // Reverse
441     {
442         if (verbose) printf("Reversing\n");
443         Func f12;
444         f12(x, y) = input(W - 1 - x, H - 1 - y);
445         f12.vectorize(x, lanes);
446         Buffer<A> im12 = f12.realize(W, H);
447 
448         for (int y = 0; y < H; y++) {
449             for (int x = 0; x < W; x++) {
450                 A correct = input(W - 1 - x, H - 1 - y);
451                 if (im12(x, y) != correct) {
452                     printf("im12(%d, %d) = %f instead of %f\n", x, y, (double)(im12(x, y)), (double)(correct));
453                     return false;
454                 }
455             }
456         }
457     }
458 
459     // Unaligned load with known shift
460     {
461         if (verbose) printf("Unaligned load\n");
462         Func f13;
463         f13(x, y) = input(x + 3, y);
464         f13.vectorize(x, lanes);
465         Buffer<A> im13 = f13.realize(W, H);
466 
467         for (int y = 0; y < H; y++) {
468             for (int x = 0; x < W; x++) {
469                 A correct = input(x + 3, y);
470                 if (im13(x, y) != correct) {
471                     printf("im13(%d, %d) = %f instead of %f\n", x, y, (double)(im13(x, y)), (double)(correct));
472                 }
473             }
474         }
475     }
476 
477     // Absolute value
478     {
479         if (!type_of<A>().is_uint()) {
480             if (verbose) printf("Absolute value\n");
481             Func f14;
482             f14(x, y) = cast<A>(abs(input(x, y)));
483             Buffer<A> im14 = f14.realize(W, H);
484 
485             for (int y = 0; y < H; y++) {
486                 for (int x = 0; x < W; x++) {
487                     A correct = input(x, y);
488                     if (correct <= A(0)) correct = -correct;
489                     if (im14(x, y) != correct) {
490                         printf("im14(%d, %d) = %f instead of %f\n", x, y, (double)(im14(x, y)), (double)(correct));
491                     }
492                 }
493             }
494         }
495     }
496 
497     // pmaddwd
498     {
499         if (type_of<A>() == Int(16)) {
500             if (verbose) printf("pmaddwd\n");
501             Func f15, f16;
502             f15(x, y) = cast<int>(input(x, y)) * input(x, y + 2) + cast<int>(input(x, y + 1)) * input(x, y + 3);
503             f16(x, y) = cast<int>(input(x, y)) * input(x, y + 2) - cast<int>(input(x, y + 1)) * input(x, y + 3);
504             f15.vectorize(x, lanes);
505             f16.vectorize(x, lanes);
506             Buffer<int32_t> im15 = f15.realize(W, H);
507             Buffer<int32_t> im16 = f16.realize(W, H);
508             for (int y = 0; y < H; y++) {
509                 for (int x = 0; x < W; x++) {
510                     int correct15 = int(input(x, y) * input(x, y + 2) + input(x, y + 1) * input(x, y + 3));
511                     int correct16 = int(input(x, y) * input(x, y + 2) - input(x, y + 1) * input(x, y + 3));
512                     if (im15(x, y) != correct15) {
513                         printf("im15(%d, %d) = %d instead of %d\n", x, y, im15(x, y), correct15);
514                     }
515                     if (im16(x, y) != correct16) {
516                         printf("im16(%d, %d) = %d instead of %d\n", x, y, im16(x, y), correct16);
517                     }
518                 }
519             }
520         }
521     }
522 
523     // Fast exp, log, and pow
524     if (type_of<A>() == Float(32)) {
525         if (verbose) printf("Fast transcendentals\n");
526         Buffer<float> im15, im16, im17, im18, im19, im20;
527         Expr a = input(x, y) * 0.5f;
528         Expr b = input((x + 1) % W, y) * 0.5f;
529         {
530             Func f15;
531             f15(x, y) = log(a);
532             im15 = f15.realize(W, H);
533         }
534         {
535             Func f16;
536             f16(x, y) = exp(b);
537             im16 = f16.realize(W, H);
538         }
539         {
540             Func f17;
541             f17(x, y) = pow(a, b / 16.0f);
542             im17 = f17.realize(W, H);
543         }
544         {
545             Func f18;
546             f18(x, y) = fast_log(a);
547             im18 = f18.realize(W, H);
548         }
549         {
550             Func f19;
551             f19(x, y) = fast_exp(b);
552             im19 = f19.realize(W, H);
553         }
554         {
555             Func f20;
556             f20(x, y) = fast_pow(a, b / 16.0f);
557             im20 = f20.realize(W, H);
558         }
559 
560         int worst_log_mantissa = 0;
561         int worst_exp_mantissa = 0;
562         int worst_pow_mantissa = 0;
563         int worst_fast_log_mantissa = 0;
564         int worst_fast_exp_mantissa = 0;
565         int worst_fast_pow_mantissa = 0;
566 
567         for (int y = 0; y < H; y++) {
568             for (int x = 0; x < W; x++) {
569                 float a = float(input(x, y)) * 0.5f;
570                 float b = float(input((x + 1) % W, y)) * 0.5f;
571                 float correct_log = logf(a);
572                 float correct_exp = expf(b);
573                 float correct_pow = powf(a, b / 16.0f);
574 
575                 int correct_log_mantissa = mantissa(correct_log);
576                 int correct_exp_mantissa = mantissa(correct_exp);
577                 int correct_pow_mantissa = mantissa(correct_pow);
578 
579                 int log_mantissa = mantissa(im15(x, y));
580                 int exp_mantissa = mantissa(im16(x, y));
581                 int pow_mantissa = mantissa(im17(x, y));
582 
583                 int fast_log_mantissa = mantissa(im18(x, y));
584                 int fast_exp_mantissa = mantissa(im19(x, y));
585                 int fast_pow_mantissa = mantissa(im20(x, y));
586 
587                 int log_mantissa_error = abs(log_mantissa - correct_log_mantissa);
588                 int exp_mantissa_error = abs(exp_mantissa - correct_exp_mantissa);
589                 int pow_mantissa_error = abs(pow_mantissa - correct_pow_mantissa);
590                 int fast_log_mantissa_error = abs(fast_log_mantissa - correct_log_mantissa);
591                 int fast_exp_mantissa_error = abs(fast_exp_mantissa - correct_exp_mantissa);
592                 int fast_pow_mantissa_error = abs(fast_pow_mantissa - correct_pow_mantissa);
593 
594                 worst_log_mantissa = std::max(worst_log_mantissa, log_mantissa_error);
595                 worst_exp_mantissa = std::max(worst_exp_mantissa, exp_mantissa_error);
596 
597                 if (a >= 0) {
598                     worst_pow_mantissa = std::max(worst_pow_mantissa, pow_mantissa_error);
599                 }
600 
601                 if (std::isfinite(correct_log)) {
602                     worst_fast_log_mantissa = std::max(worst_fast_log_mantissa, fast_log_mantissa_error);
603                 }
604 
605                 if (std::isfinite(correct_exp)) {
606                     worst_fast_exp_mantissa = std::max(worst_fast_exp_mantissa, fast_exp_mantissa_error);
607                 }
608 
609                 if (std::isfinite(correct_pow) && a > 0) {
610                     worst_fast_pow_mantissa = std::max(worst_fast_pow_mantissa, fast_pow_mantissa_error);
611                 }
612 
613                 if (log_mantissa_error > 8) {
614                     printf("log(%f) = %1.10f instead of %1.10f (mantissa: %d vs %d)\n",
615                            a, im15(x, y), correct_log, correct_log_mantissa, log_mantissa);
616                 }
617                 if (exp_mantissa_error > 32) {
618                     // Actually good to the last 2 bits of the mantissa with sse4.1 / avx
619                     printf("exp(%f) = %1.10f instead of %1.10f (mantissa: %d vs %d)\n",
620                            b, im16(x, y), correct_exp, correct_exp_mantissa, exp_mantissa);
621                 }
622                 if (a >= 0 && pow_mantissa_error > 64) {
623                     printf("pow(%f, %f) = %1.10f instead of %1.10f (mantissa: %d vs %d)\n",
624                            a, b / 16.0f, im17(x, y), correct_pow, correct_pow_mantissa, pow_mantissa);
625                 }
626                 if (std::isfinite(correct_log) && fast_log_mantissa_error > 64) {
627                     printf("fast_log(%f) = %1.10f instead of %1.10f (mantissa: %d vs %d)\n",
628                            a, im18(x, y), correct_log, correct_log_mantissa, fast_log_mantissa);
629                 }
630                 if (std::isfinite(correct_exp) && fast_exp_mantissa_error > 64) {
631                     printf("fast_exp(%f) = %1.10f instead of %1.10f (mantissa: %d vs %d)\n",
632                            b, im19(x, y), correct_exp, correct_exp_mantissa, fast_exp_mantissa);
633                 }
634                 if (a >= 0 && std::isfinite(correct_pow) && fast_pow_mantissa_error > 128) {
635                     printf("fast_pow(%f, %f) = %1.10f instead of %1.10f (mantissa: %d vs %d)\n",
636                            a, b / 16.0f, im20(x, y), correct_pow, correct_pow_mantissa, fast_pow_mantissa);
637                 }
638             }
639         }
640 
641         /*
642         printf("log mantissa error: %d\n", worst_log_mantissa);
643         printf("exp mantissa error: %d\n", worst_exp_mantissa);
644         printf("pow mantissa error: %d\n", worst_pow_mantissa);
645         printf("fast_log mantissa error: %d\n", worst_fast_log_mantissa);
646         printf("fast_exp mantissa error: %d\n", worst_fast_exp_mantissa);
647         printf("fast_pow mantissa error: %d\n", worst_fast_pow_mantissa);
648         */
649     }
650 
651     // Lerp (where the weight is the same type as the values)
652     {
653         if (verbose) printf("Lerp\n");
654         Func f21;
655         Expr weight = input(x + 2, y);
656         Type t = type_of<A>();
657         if (t.is_float()) {
658             weight = clamp(weight, cast<A>(0), cast<A>(1));
659         } else if (t.is_int()) {
660             weight = cast(UInt(t.bits(), t.lanes()), max(0, weight));
661         }
662         f21(x, y) = lerp(input(x, y), input(x + 1, y), weight);
663         Buffer<A> im21 = f21.realize(W, H);
664 
665         for (int y = 0; y < H; y++) {
666             for (int x = 0; x < W; x++) {
667                 double a = (double)(input(x, y));
668                 double b = (double)(input(x + 1, y));
669                 double w = (double)(input(x + 2, y));
670                 if (w < 0) w = 0;
671                 if (!t.is_float()) {
672                     uint64_t divisor = 1;
673                     divisor <<= t.bits();
674                     divisor -= 1;
675                     w /= divisor;
676                 }
677                 w = std::min(std::max(w, 0.0), 1.0);
678 
679                 double lerped = (a * (1.0 - w) + b * w);
680                 if (!t.is_float()) {
681                     lerped = floor(lerped + 0.5);
682                 }
683                 A correct = (A)(lerped);
684                 if (im21(x, y) != correct) {
685                     printf("lerp(%f, %f, %f) = %f instead of %f\n", a, b, w, (double)(im21(x, y)), (double)(correct));
686                     return false;
687                 }
688             }
689         }
690     }
691 
692     // Absolute difference
693     {
694         if (verbose) printf("Absolute difference\n");
695         Func f22;
696         f22(x, y) = absd(input(x, y), input(x + 1, y));
697         f22.vectorize(x, lanes);
698         Buffer<typename with_unsigned<A>::type> im22 = f22.realize(W, H);
699 
700         for (int y = 0; y < H; y++) {
701             for (int x = 0; x < W; x++) {
702                 using T = typename with_unsigned<A>::type;
703                 T correct = T(absd((double)input(x, y), (double)input(x + 1, y)));
704                 if (im22(x, y) != correct) {
705                     printf("im22(%d, %d) = %f instead of %f\n", x, y, (double)(im22(x, y)), (double)(correct));
706                     return false;
707                 }
708             }
709         }
710     }
711 
712     return true;
713 }
714 
main(int argc,char ** argv)715 int main(int argc, char **argv) {
716 
717     int seed = argc > 1 ? atoi(argv[1]) : time(nullptr);
718     std::cout << "vector_math test seed: " << seed << std::endl;
719 
720     // Only native vector widths - llvm doesn't handle others well
721     Halide::Internal::ThreadPool<bool> pool;
722     std::vector<std::future<bool>> futures;
723     futures.push_back(pool.async(test<float>, 4, seed));
724     futures.push_back(pool.async(test<float>, 8, seed));
725     futures.push_back(pool.async(test<double>, 2, seed));
726     futures.push_back(pool.async(test<uint8_t>, 16, seed));
727     futures.push_back(pool.async(test<int8_t>, 16, seed));
728     futures.push_back(pool.async(test<uint16_t>, 8, seed));
729     futures.push_back(pool.async(test<int16_t>, 8, seed));
730     futures.push_back(pool.async(test<uint32_t>, 4, seed));
731     futures.push_back(pool.async(test<int32_t>, 4, seed));
732     futures.push_back(pool.async(test<bfloat16_t>, 8, seed));
733     futures.push_back(pool.async(test<bfloat16_t>, 16, seed));
734     futures.push_back(pool.async(test<float16_t>, 8, seed));
735     futures.push_back(pool.async(test<float16_t>, 16, seed));
736     bool ok = true;
737     for (auto &f : futures) {
738         ok &= f.get();
739     }
740 
741     if (!ok) return -1;
742     printf("Success!\n");
743     return 0;
744 }
745