1 //===-- Common header for FMA implementations -------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIBC_UTILS_FPUTIL_GENERIC_FMA_H
10 #define LLVM_LIBC_UTILS_FPUTIL_GENERIC_FMA_H
11 
12 #include "utils/CPP/TypeTraits.h"
13 
14 namespace __llvm_libc {
15 namespace fputil {
16 namespace generic {
17 
18 template <typename T>
fma(T x,T y,T z)19 static inline cpp::EnableIfType<cpp::IsSame<T, float>::Value, T> fma(T x, T y,
20                                                                      T z) {
21   // Product is exact.
22   double prod = static_cast<double>(x) * static_cast<double>(y);
23   double z_d = static_cast<double>(z);
24   double sum = prod + z_d;
25   fputil::FPBits<double> bit_prod(prod), bitz(z_d), bit_sum(sum);
26 
27   if (!(bit_sum.isInfOrNaN() || bit_sum.isZero())) {
28     // Since the sum is computed in double precision, rounding might happen
29     // (for instance, when bitz.exponent > bit_prod.exponent + 5, or
30     // bit_prod.exponent > bitz.exponent + 40).  In that case, when we round
31     // the sum back to float, double rounding error might occur.
32     // A concrete example of this phenomenon is as follows:
33     //   x = y = 1 + 2^(-12), z = 2^(-53)
34     // The exact value of x*y + z is 1 + 2^(-11) + 2^(-24) + 2^(-53)
35     // So when rounding to float, fmaf(x, y, z) = 1 + 2^(-11) + 2^(-23)
36     // On the other hand, with the default rounding mode,
37     //   double(x*y + z) = 1 + 2^(-11) + 2^(-24)
38     // and casting again to float gives us:
39     //   float(double(x*y + z)) = 1 + 2^(-11).
40     //
41     // In order to correct this possible double rounding error, first we use
42     // Dekker's 2Sum algorithm to find t such that sum - t = prod + z exactly,
43     // assuming the (default) rounding mode is round-to-the-nearest,
44     // tie-to-even.  Moreover, t satisfies the condition that t < eps(sum),
45     // i.e., t.exponent < sum.exponent - 52. So if t is not 0, meaning rounding
46     // occurs when computing the sum, we just need to use t to adjust (any) last
47     // bit of sum, so that the sticky bits used when rounding sum to float are
48     // correct (when it matters).
49     fputil::FPBits<double> t(
50         (bit_prod.getUnbiasedExponent() >= bitz.getUnbiasedExponent())
51             ? ((double(bit_sum) - double(bit_prod)) - double(bitz))
52             : ((double(bit_sum) - double(bitz)) - double(bit_prod)));
53 
54     // Update sticky bits if t != 0.0 and the least (52 - 23 - 1 = 28) bits are
55     // zero.
56     if (!t.isZero() && ((bit_sum.getMantissa() & 0xfff'ffffULL) == 0)) {
57       if (bit_sum.getSign() != t.getSign()) {
58         bit_sum.setMantissa(bit_sum.getMantissa() + 1);
59       } else if (bit_sum.getMantissa()) {
60         bit_sum.setMantissa(bit_sum.getMantissa() - 1);
61       }
62     }
63   }
64 
65   return static_cast<float>(static_cast<double>(bit_sum));
66 }
67 
68 } // namespace generic
69 } // namespace fputil
70 } // namespace __llvm_libc
71 
72 #endif // LLVM_LIBC_UTILS_FPUTIL_GENERIC_FMA_H
73