1 //===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #ifndef LLVM_LIBC_UTILS_FPUTIL_SQRT_H
10 #define LLVM_LIBC_UTILS_FPUTIL_SQRT_H
11
12 #include "FPBits.h"
13 #include "PlatformDefs.h"
14
15 #include "utils/CPP/TypeTraits.h"
16
17 namespace __llvm_libc {
18 namespace fputil {
19
20 namespace internal {
21
22 template <typename T>
23 static inline void normalize(int &exponent,
24 typename FPBits<T>::UIntType &mantissa);
25
26 template <> inline void normalize<float>(int &exponent, uint32_t &mantissa) {
27 // Use binary search to shift the leading 1 bit.
28 // With MantissaWidth<float> = 23, it will take
29 // ceil(log2(23)) = 5 steps checking the mantissa bits as followed:
30 // Step 1: 0000 0000 0000 XXXX XXXX XXXX
31 // Step 2: 0000 00XX XXXX XXXX XXXX XXXX
32 // Step 3: 000X XXXX XXXX XXXX XXXX XXXX
33 // Step 4: 00XX XXXX XXXX XXXX XXXX XXXX
34 // Step 5: 0XXX XXXX XXXX XXXX XXXX XXXX
35 constexpr int nsteps = 5; // = ceil(log2(MantissaWidth))
36 constexpr uint32_t bounds[nsteps] = {1 << 12, 1 << 18, 1 << 21, 1 << 22,
37 1 << 23};
38 constexpr int shifts[nsteps] = {12, 6, 3, 2, 1};
39
40 for (int i = 0; i < nsteps; ++i) {
41 if (mantissa < bounds[i]) {
42 exponent -= shifts[i];
43 mantissa <<= shifts[i];
44 }
45 }
46 }
47
48 template <> inline void normalize<double>(int &exponent, uint64_t &mantissa) {
49 // Use binary search to shift the leading 1 bit similar to float.
50 // With MantissaWidth<double> = 52, it will take
51 // ceil(log2(52)) = 6 steps checking the mantissa bits.
52 constexpr int nsteps = 6; // = ceil(log2(MantissaWidth))
53 constexpr uint64_t bounds[nsteps] = {1ULL << 26, 1ULL << 39, 1ULL << 46,
54 1ULL << 49, 1ULL << 51, 1ULL << 52};
55 constexpr int shifts[nsteps] = {27, 14, 7, 4, 2, 1};
56
57 for (int i = 0; i < nsteps; ++i) {
58 if (mantissa < bounds[i]) {
59 exponent -= shifts[i];
60 mantissa <<= shifts[i];
61 }
62 }
63 }
64
65 #ifdef LONG_DOUBLE_IS_DOUBLE
66 template <>
67 inline void normalize<long double>(int &exponent, uint64_t &mantissa) {
68 normalize<double>(exponent, mantissa);
69 }
70 #elif !defined(SPECIAL_X86_LONG_DOUBLE)
71 template <>
72 inline void normalize<long double>(int &exponent, __uint128_t &mantissa) {
73 // Use binary search to shift the leading 1 bit similar to float.
74 // With MantissaWidth<long double> = 112, it will take
75 // ceil(log2(112)) = 7 steps checking the mantissa bits.
76 constexpr int nsteps = 7; // = ceil(log2(MantissaWidth))
77 constexpr __uint128_t bounds[nsteps] = {
78 __uint128_t(1) << 56, __uint128_t(1) << 84, __uint128_t(1) << 98,
79 __uint128_t(1) << 105, __uint128_t(1) << 109, __uint128_t(1) << 111,
80 __uint128_t(1) << 112};
81 constexpr int shifts[nsteps] = {57, 29, 15, 8, 4, 2, 1};
82
83 for (int i = 0; i < nsteps; ++i) {
84 if (mantissa < bounds[i]) {
85 exponent -= shifts[i];
86 mantissa <<= shifts[i];
87 }
88 }
89 }
90 #endif
91
92 } // namespace internal
93
94 // Correctly rounded IEEE 754 SQRT with round to nearest, ties to even.
95 // Shift-and-add algorithm.
96 template <typename T,
97 cpp::EnableIfType<cpp::IsFloatingPointType<T>::Value, int> = 0>
sqrt(T x)98 static inline T sqrt(T x) {
99 using UIntType = typename FPBits<T>::UIntType;
100 constexpr UIntType One = UIntType(1) << MantissaWidth<T>::value;
101
102 FPBits<T> bits(x);
103
104 if (bits.isInfOrNaN()) {
105 if (bits.getSign() && (bits.getMantissa() == 0)) {
106 // sqrt(-Inf) = NaN
107 return FPBits<T>::buildNaN(One >> 1);
108 } else {
109 // sqrt(NaN) = NaN
110 // sqrt(+Inf) = +Inf
111 return x;
112 }
113 } else if (bits.isZero()) {
114 // sqrt(+0) = +0
115 // sqrt(-0) = -0
116 return x;
117 } else if (bits.getSign()) {
118 // sqrt( negative numbers ) = NaN
119 return FPBits<T>::buildNaN(One >> 1);
120 } else {
121 int xExp = bits.getExponent();
122 UIntType xMant = bits.getMantissa();
123
124 // Step 1a: Normalize denormal input and append hiddent bit to the mantissa
125 if (bits.getUnbiasedExponent() == 0) {
126 ++xExp; // let xExp be the correct exponent of One bit.
127 internal::normalize<T>(xExp, xMant);
128 } else {
129 xMant |= One;
130 }
131
132 // Step 1b: Make sure the exponent is even.
133 if (xExp & 1) {
134 --xExp;
135 xMant <<= 1;
136 }
137
138 // After step 1b, x = 2^(xExp) * xMant, where xExp is even, and
139 // 1 <= xMant < 4. So sqrt(x) = 2^(xExp / 2) * y, with 1 <= y < 2.
140 // Notice that the output of sqrt is always in the normal range.
141 // To perform shift-and-add algorithm to find y, let denote:
142 // y(n) = 1.y_1 y_2 ... y_n, we can define the nth residue to be:
143 // r(n) = 2^n ( xMant - y(n)^2 ).
144 // That leads to the following recurrence formula:
145 // r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n-1) ]
146 // with the initial conditions: y(0) = 1, and r(0) = x - 1.
147 // So the nth digit y_n of the mantissa of sqrt(x) can be found by:
148 // y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1)
149 // 0 otherwise.
150 UIntType y = One;
151 UIntType r = xMant - One;
152
153 for (UIntType current_bit = One >> 1; current_bit; current_bit >>= 1) {
154 r <<= 1;
155 UIntType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1)
156 if (r >= tmp) {
157 r -= tmp;
158 y += current_bit;
159 }
160 }
161
162 // We compute one more iteration in order to round correctly.
163 bool lsb = y & 1; // Least significant bit
164 bool rb = false; // Round bit
165 r <<= 2;
166 UIntType tmp = (y << 2) + 1;
167 if (r >= tmp) {
168 r -= tmp;
169 rb = true;
170 }
171
172 // Remove hidden bit and append the exponent field.
173 xExp = ((xExp >> 1) + FPBits<T>::exponentBias);
174
175 y = (y - One) | (static_cast<UIntType>(xExp) << MantissaWidth<T>::value);
176 // Round to nearest, ties to even
177 if (rb && (lsb || (r != 0))) {
178 ++y;
179 }
180
181 return *reinterpret_cast<T *>(&y);
182 }
183 }
184
185 } // namespace fputil
186 } // namespace __llvm_libc
187
188 #ifdef SPECIAL_X86_LONG_DOUBLE
189 #include "SqrtLongDoubleX86.h"
190 #endif // SPECIAL_X86_LONG_DOUBLE
191
192 #endif // LLVM_LIBC_UTILS_FPUTIL_SQRT_H
193