1 //===-- Implementation of hypotf function ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIBC_UTILS_FPUTIL_HYPOT_H
10 #define LLVM_LIBC_UTILS_FPUTIL_HYPOT_H
11 
12 #include "BasicOperations.h"
13 #include "FPBits.h"
14 #include "utils/CPP/TypeTraits.h"
15 
16 namespace __llvm_libc {
17 namespace fputil {
18 
19 namespace internal {
20 
21 template <typename T> static inline T findLeadingOne(T mant, int &shift_length);
22 
23 template <>
24 inline uint32_t findLeadingOne<uint32_t>(uint32_t mant, int &shift_length) {
25   shift_length = 0;
26   constexpr int nsteps = 5;
27   constexpr uint32_t bounds[nsteps] = {1 << 16, 1 << 8, 1 << 4, 1 << 2, 1 << 1};
28   constexpr int shifts[nsteps] = {16, 8, 4, 2, 1};
29   for (int i = 0; i < nsteps; ++i) {
30     if (mant >= bounds[i]) {
31       shift_length += shifts[i];
32       mant >>= shifts[i];
33     }
34   }
35   return 1U << shift_length;
36 }
37 
38 template <>
39 inline uint64_t findLeadingOne<uint64_t>(uint64_t mant, int &shift_length) {
40   shift_length = 0;
41   constexpr int nsteps = 6;
42   constexpr uint64_t bounds[nsteps] = {1ULL << 32, 1ULL << 16, 1ULL << 8,
43                                        1ULL << 4,  1ULL << 2,  1ULL << 1};
44   constexpr int shifts[nsteps] = {32, 16, 8, 4, 2, 1};
45   for (int i = 0; i < nsteps; ++i) {
46     if (mant >= bounds[i]) {
47       shift_length += shifts[i];
48       mant >>= shifts[i];
49     }
50   }
51   return 1ULL << shift_length;
52 }
53 
54 } // namespace internal
55 
56 template <typename T> struct DoubleLength;
57 
58 template <> struct DoubleLength<uint16_t> { using Type = uint32_t; };
59 
60 template <> struct DoubleLength<uint32_t> { using Type = uint64_t; };
61 
62 template <> struct DoubleLength<uint64_t> { using Type = __uint128_t; };
63 
64 // Correctly rounded IEEE 754 HYPOT(x, y) with round to nearest, ties to even.
65 //
66 // Algorithm:
67 //   -  Let a = max(|x|, |y|), b = min(|x|, |y|), then we have that:
68 //          a <= sqrt(a^2 + b^2) <= min(a + b, a*sqrt(2))
69 //   1. So if b < eps(a)/2, then HYPOT(x, y) = a.
70 //
71 //   -  Moreover, the exponent part of HYPOT(x, y) is either the same or 1 more
72 //      than the exponent part of a.
73 //
74 //   2. For the remaining cases, we will use the digit-by-digit (shift-and-add)
75 //      algorithm to compute SQRT(Z):
76 //
77 //   -  For Y = y0.y1...yn... = SQRT(Z),
78 //      let Y(n) = y0.y1...yn be the first n fractional digits of Y.
79 //
80 //   -  The nth scaled residual R(n) is defined to be:
81 //          R(n) = 2^n * (Z - Y(n)^2)
82 //
83 //   -  Since Y(n) = Y(n - 1) + yn * 2^(-n), the scaled residual
84 //      satisfies the following recurrence formula:
85 //          R(n) = 2*R(n - 1) - yn*(2*Y(n - 1) + 2^(-n)),
86 //      with the initial conditions:
87 //          Y(0) = y0, and R(0) = Z - y0.
88 //
89 //   -  So the nth fractional digit of Y = SQRT(Z) can be decided by:
90 //          yn = 1  if 2*R(n - 1) >= 2*Y(n - 1) + 2^(-n),
91 //               0  otherwise.
92 //
93 //   3. Precision analysis:
94 //
95 //   -  Notice that in the decision function:
96 //          2*R(n - 1) >= 2*Y(n - 1) + 2^(-n),
97 //      the right hand side only uses up to the 2^(-n)-bit, and both sides are
98 //      non-negative, so R(n - 1) can be truncated at the 2^(-(n + 1))-bit, so
99 //      that 2*R(n - 1) is corrected up to the 2^(-n)-bit.
100 //
101 //   -  Thus, in order to round SQRT(a^2 + b^2) correctly up to n-fractional
102 //      bits, we need to perform the summation (a^2 + b^2) correctly up to (2n +
103 //      2)-fractional bits, and the remaining bits are sticky bits (i.e. we only
104 //      care if they are 0 or > 0), and the comparisons, additions/subtractions
105 //      can be done in n-fractional bits precision.
106 //
107 //   -  For single precision (float), we can use uint64_t to store the sum a^2 +
108 //      b^2 exact up to (2n + 2)-fractional bits.
109 //
110 //   -  Then we can feed this sum into the digit-by-digit algorithm for SQRT(Z)
111 //      described above.
112 //
113 //
114 // Special cases:
115 //   - HYPOT(x, y) is +Inf if x or y is +Inf or -Inf; else
116 //   - HYPOT(x, y) is NaN if x or y is NaN.
117 //
118 template <typename T,
119           cpp::EnableIfType<cpp::IsFloatingPointType<T>::Value, int> = 0>
120 static inline T hypot(T x, T y) {
121   using FPBits_t = FPBits<T>;
122   using UIntType = typename FPBits<T>::UIntType;
123   using DUIntType = typename DoubleLength<UIntType>::Type;
124 
125   FPBits_t x_bits(x), y_bits(y);
126 
127   if (x_bits.isInf() || y_bits.isInf()) {
128     return FPBits_t::inf();
129   }
130   if (x_bits.isNaN()) {
131     return x;
132   }
133   if (y_bits.isNaN()) {
134     return y;
135   }
136 
137   uint16_t a_exp, b_exp, out_exp;
138   UIntType a_mant, b_mant;
139   DUIntType a_mant_sq, b_mant_sq;
140   bool sticky_bits;
141 
142   if ((x_bits.exponent >= y_bits.exponent + MantissaWidth<T>::value + 2) ||
143       (y == 0)) {
144     return abs(x);
145   } else if ((y_bits.exponent >=
146               x_bits.exponent + MantissaWidth<T>::value + 2) ||
147              (x == 0)) {
148     y_bits.sign = 0;
149     return abs(y);
150   }
151 
152   if (x >= y) {
153     a_exp = x_bits.exponent;
154     a_mant = x_bits.mantissa;
155     b_exp = y_bits.exponent;
156     b_mant = y_bits.mantissa;
157   } else {
158     a_exp = y_bits.exponent;
159     a_mant = y_bits.mantissa;
160     b_exp = x_bits.exponent;
161     b_mant = x_bits.mantissa;
162   }
163 
164   out_exp = a_exp;
165 
166   // Add an extra bit to simplify the final rounding bit computation.
167   constexpr UIntType one = UIntType(1) << (MantissaWidth<T>::value + 1);
168 
169   a_mant <<= 1;
170   b_mant <<= 1;
171 
172   UIntType leading_one;
173   int y_mant_width;
174   if (a_exp != 0) {
175     leading_one = one;
176     a_mant |= one;
177     y_mant_width = MantissaWidth<T>::value + 1;
178   } else {
179     leading_one = internal::findLeadingOne(a_mant, y_mant_width);
180   }
181 
182   if (b_exp != 0) {
183     b_mant |= one;
184   }
185 
186   a_mant_sq = static_cast<DUIntType>(a_mant) * a_mant;
187   b_mant_sq = static_cast<DUIntType>(b_mant) * b_mant;
188 
189   // At this point, a_exp >= b_exp > a_exp - 25, so in order to line up aSqMant
190   // and bSqMant, we need to shift bSqMant to the right by (a_exp - b_exp) bits.
191   // But before that, remember to store the losing bits to sticky.
192   // The shift length is for a^2 and b^2, so it's double of the exponent
193   // difference between a and b.
194   uint16_t shift_length = 2 * (a_exp - b_exp);
195   sticky_bits =
196       ((b_mant_sq & ((DUIntType(1) << shift_length) - DUIntType(1))) !=
197        DUIntType(0));
198   b_mant_sq >>= shift_length;
199 
200   DUIntType sum = a_mant_sq + b_mant_sq;
201   if (sum >= (DUIntType(1) << (2 * y_mant_width + 2))) {
202     // a^2 + b^2 >= 4* leading_one^2, so we will need an extra bit to the left.
203     if (leading_one == one) {
204       // For normal result, we discard the last 2 bits of the sum and increase
205       // the exponent.
206       sticky_bits = sticky_bits || ((sum & 0x3U) != 0);
207       sum >>= 2;
208       ++out_exp;
209       if (out_exp >= FPBits_t::maxExponent) {
210         return FPBits_t::inf();
211       }
212     } else {
213       // For denormal result, we simply move the leading bit of the result to
214       // the left by 1.
215       leading_one <<= 1;
216       ++y_mant_width;
217     }
218   }
219 
220   UIntType Y = leading_one;
221   UIntType R = static_cast<UIntType>(sum >> y_mant_width) - leading_one;
222   UIntType tailBits = static_cast<UIntType>(sum) & (leading_one - 1);
223 
224   for (UIntType current_bit = leading_one >> 1; current_bit;
225        current_bit >>= 1) {
226     R = (R << 1) + ((tailBits & current_bit) ? 1 : 0);
227     UIntType tmp = (Y << 1) + current_bit; // 2*y(n - 1) + 2^(-n)
228     if (R >= tmp) {
229       R -= tmp;
230       Y += current_bit;
231     }
232   }
233 
234   bool round_bit = Y & UIntType(1);
235   bool lsb = Y & UIntType(2);
236 
237   if (Y >= one) {
238     Y -= one;
239 
240     if (out_exp == 0) {
241       out_exp = 1;
242     }
243   }
244 
245   Y >>= 1;
246 
247   // Round to the nearest, tie to even.
248   if (round_bit && (lsb || sticky_bits || (R != 0))) {
249     ++Y;
250   }
251 
252   if (Y >= (one >> 1)) {
253     Y -= one >> 1;
254     ++out_exp;
255     if (out_exp >= FPBits_t::maxExponent) {
256       return FPBits_t::inf();
257     }
258   }
259 
260   Y |= static_cast<UIntType>(out_exp) << MantissaWidth<T>::value;
261   return *reinterpret_cast<T *>(&Y);
262 }
263 
264 } // namespace fputil
265 } // namespace __llvm_libc
266 
267 #endif // LLVM_LIBC_UTILS_FPUTIL_HYPOT_H
268