1 //===-- Implementation of fmaf function -----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "src/math/fmaf.h"
10 #include "src/__support/common.h"
11 
12 #include "utils/FPUtil/FEnv.h"
13 #include "utils/FPUtil/FPBits.h"
14 
15 namespace __llvm_libc {
16 
17 LLVM_LIBC_FUNCTION(float, fmaf, (float x, float y, float z)){
18   // Product is exact.
19   double prod = static_cast<double>(x) * static_cast<double>(y);
20   double z_d = static_cast<double>(z);
21   double sum = prod + z_d;
22   fputil::FPBits<double> bit_prod(prod), bitz(z_d), bit_sum(sum);
23 
24   if (!(bit_sum.isInfOrNaN() || bit_sum.isZero())) {
25     // Since the sum is computed in double precision, rounding might happen
26     // (for instance, when bitz.exponent > bit_prod.exponent + 5, or
27     // bit_prod.exponent > bitz.exponent + 40).  In that case, when we round
28     // the sum back to float, double rounding error might occur.
29     // A concrete example of this phenomenon is as follows:
30     //   x = y = 1 + 2^(-12), z = 2^(-53)
31     // The exact value of x*y + z is 1 + 2^(-11) + 2^(-24) + 2^(-53)
32     // So when rounding to float, fmaf(x, y, z) = 1 + 2^(-11) + 2^(-23)
33     // On the other hand, with the default rounding mode,
34     //   double(x*y + z) = 1 + 2^(-11) + 2^(-24)
35     // and casting again to float gives us:
36     //   float(double(x*y + z)) = 1 + 2^(-11).
37     //
38     // In order to correct this possible double rounding error, first we use
39     // Dekker's 2Sum algorithm to find t such that sum - t = prod + z exactly,
40     // assuming the (default) rounding mode is round-to-the-nearest,
41     // tie-to-even.  Moreover, t satisfies the condition that t < eps(sum),
42     // i.e., t.exponent < sum.exponent - 52. So if t is not 0, meaning rounding
43     // occurs when computing the sum, we just need to use t to adjust (any) last
44     // bit of sum, so that the sticky bits used when rounding sum to float are
45     // correct (when it matters).
46     fputil::FPBits<double> t(
47         (bit_prod.exponent >= bitz.exponent)
48             ? ((static_cast<double>(bit_sum) - bit_prod) - bitz)
49             : ((static_cast<double>(bit_sum) - bitz) - bit_prod));
50 
51     // Update sticky bits if t != 0.0 and the least (52 - 23 - 1 = 28) bits are
52     // zero.
53     if (!t.isZero() && ((bit_sum.mantissa & 0xfff'ffffULL) == 0)) {
54       if (bit_sum.sign != t.sign) {
55         ++bit_sum.mantissa;
56       } else if (bit_sum.mantissa) {
57         --bit_sum.mantissa;
58       }
59     }
60   }
61 
62   return static_cast<float>(static_cast<double>(bit_sum));
63 }
64 
65 } // namespace __llvm_libc
66