1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2007 Julien Pommier 5 // Copyright (C) 2014 Pedro Gonnet (pedro.gonnet@gmail.com) 6 // Copyright (C) 2009-2019 Gael Guennebaud <gael.guennebaud@inria.fr> 7 // 8 // This Source Code Form is subject to the terms of the Mozilla 9 // Public License v. 2.0. If a copy of the MPL was not distributed 10 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 11 12 /* The exp and log functions of this file initially come from 13 * Julien Pommier's sse math library: http://gruntthepeon.free.fr/ssemath/ 14 */ 15 16 #ifndef EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_H 17 #define EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_H 18 19 namespace Eigen { 20 namespace internal { 21 22 // Creates a Scalar integer type with same bit-width. 23 template<typename T> struct make_integer; 24 template<> struct make_integer<float> { typedef numext::int32_t type; }; 25 template<> struct make_integer<double> { typedef numext::int64_t type; }; 26 template<> struct make_integer<half> { typedef numext::int16_t type; }; 27 template<> struct make_integer<bfloat16> { typedef numext::int16_t type; }; 28 29 template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC 30 Packet pfrexp_generic_get_biased_exponent(const Packet& a) { 31 typedef typename unpacket_traits<Packet>::type Scalar; 32 typedef typename unpacket_traits<Packet>::integer_packet PacketI; 33 enum { mantissa_bits = numext::numeric_limits<Scalar>::digits - 1}; 34 return pcast<PacketI, Packet>(plogical_shift_right<mantissa_bits>(preinterpret<PacketI>(pabs(a)))); 35 } 36 37 // Safely applies frexp, correctly handles denormals. 38 // Assumes IEEE floating point format. 
// Splits every lane of `a` into a mantissa `m` (returned) and an integral
// power of two (written to `exponent`, stored as floating-point values) such
// that a == m * 2^exponent. For finite non-zero lanes the exponent field of
// the result is replaced by the one of 0.5, the sign and mantissa bits are kept.
// Zero, Inf and NaN lanes are returned unchanged with `exponent` set to zero.
template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
Packet pfrexp_generic(const Packet& a, Packet& exponent) {
  typedef typename unpacket_traits<Packet>::type Scalar;
  typedef typename make_unsigned<typename make_integer<Scalar>::type>::type ScalarUI;
  enum {
    TotalBits = sizeof(Scalar) * CHAR_BIT,
    MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
    ExponentBits = int(TotalBits) - int(MantissaBits) - 1
  };

  // Bit mask that keeps the sign and mantissa bits and clears the exponent field.
  EIGEN_CONSTEXPR ScalarUI scalar_sign_mantissa_mask =
      ~(((ScalarUI(1) << int(ExponentBits)) - ScalarUI(1)) << int(MantissaBits)); // ~0x7f800000
  const Packet sign_mantissa_mask = pset1frombits<Packet>(static_cast<ScalarUI>(scalar_sign_mantissa_mask));
  const Packet half = pset1<Packet>(Scalar(0.5));
  const Packet zero = pzero(a);
  const Packet normal_min = pset1<Packet>((numext::numeric_limits<Scalar>::min)()); // Minimum normal value, 2^-126

  // To handle denormals, normalize by multiplying by 2^(int(MantissaBits)+1).
  const Packet is_denormal = pcmp_lt(pabs(a), normal_min);
  EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(int(MantissaBits) + 1); // 24
  // The following cannot be constexpr because bfloat16(uint16_t) is not constexpr.
  const Scalar scalar_normalization_factor = Scalar(ScalarUI(1) << int(scalar_normalization_offset)); // 2^24
  const Packet normalization_factor = pset1<Packet>(scalar_normalization_factor);
  const Packet normalized_a = pselect(is_denormal, pmul(a, normalization_factor), a);

  // Determine exponent offset: -126 if normal, -126-24 if denormal
  // (the extra -24 undoes the normalization scaling applied above).
  const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(int(ExponentBits)-1)) - ScalarUI(2)); // -126
  Packet exponent_offset = pset1<Packet>(scalar_exponent_offset);
  const Packet normalization_offset = pset1<Packet>(-Scalar(scalar_normalization_offset)); // -24
  exponent_offset = pselect(is_denormal, padd(exponent_offset, normalization_offset), exponent_offset);

  // Determine exponent and mantissa from normalized_a.
  exponent = pfrexp_generic_get_biased_exponent(normalized_a);
  // Zero, Inf and NaN return 'a' unmodified, exponent is zero
  // (technically the exponent is unspecified for inf/NaN, but GCC/Clang set it to zero)
  // An all-ones biased exponent field (255 for float) identifies Inf and NaN lanes.
  const Scalar scalar_non_finite_exponent = Scalar((ScalarUI(1) << int(ExponentBits)) - ScalarUI(1));  // 255
  const Packet non_finite_exponent = pset1<Packet>(scalar_non_finite_exponent);
  const Packet is_zero_or_not_finite = por(pcmp_eq(a, zero), pcmp_eq(exponent, non_finite_exponent));
  // Keep sign and mantissa of normalized_a and splice in the exponent bits of 0.5.
  const Packet m = pselect(is_zero_or_not_finite, a, por(pand(normalized_a, sign_mantissa_mask), half));
  exponent = pselect(is_zero_or_not_finite, zero, padd(exponent, exponent_offset));
  return m;
}

// Safely applies ldexp, correctly handles overflows, underflows and denormals.
// Assumes IEEE floating point format.
//
// `exponent` holds the (integral) power of two as floating-point values; it is
// clamped and converted to the matching integer packet internally.
template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
Packet pldexp_generic(const Packet& a, const Packet& exponent) {
  // We want to return a * 2^exponent, allowing for all possible integer
  // exponents without overflowing or underflowing in intermediate
  // computations.
  //
  // Since 'a' and the output can be denormal, the maximum range of 'exponent'
  // to consider for a float is:
  //   -255-23 -> 255+23
  // Below -278 any finite float 'a' will become zero, and above +278 any
  // finite float will become inf, including when 'a' is the smallest possible
  // denormal.
  //
  // Unfortunately, 2^(278) cannot be represented using either one or two
  // finite normal floats, so we must split the scale factor into at least
  // three parts. It turns out to be faster to split 'exponent' into four
  // factors, since [exponent>>2] is much faster to compute than [exponent/3].
  //
  // Set e = min(max(exponent, -278), 278);
  //     b = floor(e/4);
  //     out = ((((a * 2^(b)) * 2^(b)) * 2^(b)) * 2^(e-3*b))
  //
  // This will avoid any intermediate overflows and correctly handle 0, inf,
  // NaN cases.
  typedef typename unpacket_traits<Packet>::integer_packet PacketI;
  typedef typename unpacket_traits<Packet>::type Scalar;
  typedef typename unpacket_traits<PacketI>::type ScalarI;
  enum {
    TotalBits = sizeof(Scalar) * CHAR_BIT,
    MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
    ExponentBits = int(TotalBits) - int(MantissaBits) - 1
  };

  const Packet max_exponent = pset1<Packet>(Scalar((ScalarI(1)<<int(ExponentBits)) + ScalarI(int(MantissaBits) - 1)));  // 278
  const PacketI bias = pset1<PacketI>((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1));  // 127
  const PacketI e = pcast<Packet, PacketI>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
  PacketI b = parithmetic_shift_right<2>(e); // floor(e/4);
  // Build 2^b directly from its bit pattern: (b + bias) placed in the exponent field.
  Packet c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias)));  // 2^b
  // The three multiplications are applied one at a time on purpose, see the
  // derivation above.
  Packet out = pmul(pmul(pmul(a, c), c), c);  // a * 2^(3b)
  b = psub(psub(psub(e, b), b), b); // e - 3b
  c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias)));  // 2^(e-3*b)
  out = pmul(out, c);
  return out;
}

//
Explicitly multiplies 130 // a * (2^e) 131 // clamping e to the range 132 // [NumTraits<Scalar>::min_exponent()-2, NumTraits<Scalar>::max_exponent()] 133 // 134 // This is approx 7x faster than pldexp_impl, but will prematurely over/underflow 135 // if 2^e doesn't fit into a normal floating-point Scalar. 136 // 137 // Assumes IEEE floating point format 138 template<typename Packet> 139 struct pldexp_fast_impl { 140 typedef typename unpacket_traits<Packet>::integer_packet PacketI; 141 typedef typename unpacket_traits<Packet>::type Scalar; 142 typedef typename unpacket_traits<PacketI>::type ScalarI; 143 enum { 144 TotalBits = sizeof(Scalar) * CHAR_BIT, 145 MantissaBits = numext::numeric_limits<Scalar>::digits - 1, 146 ExponentBits = int(TotalBits) - int(MantissaBits) - 1 147 }; 148 149 static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC 150 Packet run(const Packet& a, const Packet& exponent) { 151 const Packet bias = pset1<Packet>(Scalar((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1))); // 127 152 const Packet limit = pset1<Packet>(Scalar((ScalarI(1)<<int(ExponentBits)) - ScalarI(1))); // 255 153 // restrict biased exponent between 0 and 255 for float. 154 const PacketI e = pcast<Packet, PacketI>(pmin(pmax(padd(exponent, bias), pzero(limit)), limit)); // exponent + 127 155 // return a * (2^e) 156 return pmul(a, preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(e))); 157 } 158 }; 159 160 // Natural or base 2 logarithm. 161 // Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2) 162 // and m is in the range [sqrt(1/2),sqrt(2)). In this range, the logarithm can 163 // be easily approximated by a polynomial centered on m=1 for stability. 164 // TODO(gonnet): Further reduce the interval allowing for lower-degree 165 // polynomial interpolants -> ... -> profit! 
template <typename Packet, bool base2>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet plog_impl_float(const Packet _x)
{
  // `base2` selects log2(x) (true) or the natural logarithm (false).
  Packet x = _x;

  const Packet cst_1              = pset1<Packet>(1.0f);
  const Packet cst_neg_half       = pset1<Packet>(-0.5f);
  // The smallest non denormalized float number.
  const Packet cst_min_norm_pos   = pset1frombits<Packet>( 0x00800000u);
  const Packet cst_minus_inf      = pset1frombits<Packet>( 0xff800000u);
  const Packet cst_pos_inf        = pset1frombits<Packet>( 0x7f800000u);

  // Polynomial coefficients.
  const Packet cst_cephes_SQRTHF = pset1<Packet>(0.707106781186547524f);
  const Packet cst_cephes_log_p0 = pset1<Packet>(7.0376836292E-2f);
  const Packet cst_cephes_log_p1 = pset1<Packet>(-1.1514610310E-1f);
  const Packet cst_cephes_log_p2 = pset1<Packet>(1.1676998740E-1f);
  const Packet cst_cephes_log_p3 = pset1<Packet>(-1.2420140846E-1f);
  const Packet cst_cephes_log_p4 = pset1<Packet>(+1.4249322787E-1f);
  const Packet cst_cephes_log_p5 = pset1<Packet>(-1.6668057665E-1f);
  const Packet cst_cephes_log_p6 = pset1<Packet>(+2.0000714765E-1f);
  const Packet cst_cephes_log_p7 = pset1<Packet>(-2.4999993993E-1f);
  const Packet cst_cephes_log_p8 = pset1<Packet>(+3.3333331174E-1f);

  // Truncate input values to the minimum positive normal.
  // (Zero, negative and NaN lanes are patched up at the very end from _x.)
  x = pmax(x, cst_min_norm_pos);

  Packet e;
  // extract significand in the range [0.5,1) and exponent
  x = pfrexp(x,e);

  // part2: Shift the inputs from the range [0.5,1) to [sqrt(1/2),sqrt(2))
  // and shift by -1. The values are then centered around 0, which improves
  // the stability of the polynomial evaluation.
  //   if( x < SQRTHF ) {
  //     e -= 1;
  //     x = x + x - 1.0;
  //   } else { x = x - 1.0; }
  // Branch-free version of the above: `mask` is all-ones where x < SQRTHF, so
  // `tmp` is x in those lanes and zero elsewhere.
  Packet mask = pcmp_lt(x, cst_cephes_SQRTHF);
  Packet tmp = pand(x, mask);
  x = psub(x, cst_1);
  e = psub(e, pand(cst_1, mask));
  x = padd(x, tmp);

  Packet x2 = pmul(x, x);
  Packet x3 = pmul(x2, x);

  // Evaluate the polynomial approximant of degree 8 in three parts, probably
  // to improve instruction-level parallelism.
  Packet y, y1, y2;
  y  = pmadd(cst_cephes_log_p0, x, cst_cephes_log_p1);
  y1 = pmadd(cst_cephes_log_p3, x, cst_cephes_log_p4);
  y2 = pmadd(cst_cephes_log_p6, x, cst_cephes_log_p7);
  y  = pmadd(y, x, cst_cephes_log_p2);
  y1 = pmadd(y1, x, cst_cephes_log_p5);
  y2 = pmadd(y2, x, cst_cephes_log_p8);
  y  = pmadd(y, x3, y1);
  y  = pmadd(y, x3, y2);
  y  = pmul(y, x3);

  // log(1+x) ~= x - 0.5*x^2 + x^3*P(x)
  y = pmadd(cst_neg_half, x2, y);
  x = padd(x, y);

  // Add the logarithm of the exponent back to the result of the interpolation.
  if (base2) {
    // log2(m * 2^e) = log(m) * log2(e) + e
    const Packet cst_log2e = pset1<Packet>(static_cast<float>(EIGEN_LOG2E));
    x = pmadd(x, cst_log2e, e);
  } else {
    // log(m * 2^e) = e * ln(2) + log(m)
    const Packet cst_ln2 = pset1<Packet>(static_cast<float>(EIGEN_LN2));
    x = pmadd(e, cst_ln2, x);
  }

  Packet invalid_mask = pcmp_lt_or_nan(_x, pzero(_x));
  Packet iszero_mask  = pcmp_eq(_x,pzero(_x));
  Packet pos_inf_mask = pcmp_eq(_x,cst_pos_inf);
  // Filter out invalid inputs, i.e.:
  //  - negative arg will be NAN
  //  - 0 will be -INF
  //  - +INF will be +INF
  // (OR-ing with the all-ones invalid_mask turns the lane into a NaN bit pattern.)
  return pselect(iszero_mask, cst_minus_inf,
                              por(pselect(pos_inf_mask,cst_pos_inf,x), invalid_mask));
}

// Natural logarithm of a packet of floats.
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet plog_float(const Packet _x)
{
  return plog_impl_float<Packet, /* base2 */ false>(_x);
}

// Base 2 logarithm of a packet of floats.
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet plog2_float(const Packet _x)
{
  return plog_impl_float<Packet, /* base2 */ true>(_x);
}

/* Returns the base e (2.718...) or base 2 logarithm of x.
 * The argument is separated into its exponent and fractional parts.
 * The logarithm of the fraction in the interval [sqrt(1/2), sqrt(2)],
 * is approximated by
 *
 *     log(1+x) = x - 0.5 x**2 + x**3 P(x)/Q(x).
 *
 * for more detail see: http://www.netlib.org/cephes/
 */
template <typename Packet, bool base2>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet plog_impl_double(const Packet _x)
{
  // `base2` selects log2(x) (true) or the natural logarithm (false).
  Packet x = _x;

  const Packet cst_1              = pset1<Packet>(1.0);
  const Packet cst_neg_half       = pset1<Packet>(-0.5);
  // The smallest non denormalized double.
  const Packet cst_min_norm_pos   = pset1frombits<Packet>( static_cast<uint64_t>(0x0010000000000000ull));
  const Packet cst_minus_inf      = pset1frombits<Packet>( static_cast<uint64_t>(0xfff0000000000000ull));
  const Packet cst_pos_inf        = pset1frombits<Packet>( static_cast<uint64_t>(0x7ff0000000000000ull));


  // Polynomial Coefficients for log(1+x) = x - x**2/2 + x**3 P(x)/Q(x)
  //                             1/sqrt(2) <= x < sqrt(2)
  const Packet cst_cephes_SQRTHF = pset1<Packet>(0.70710678118654752440E0);
  const Packet cst_cephes_log_p0 = pset1<Packet>(1.01875663804580931796E-4);
  const Packet cst_cephes_log_p1 = pset1<Packet>(4.97494994976747001425E-1);
  const Packet cst_cephes_log_p2 = pset1<Packet>(4.70579119878881725854E0);
  const Packet cst_cephes_log_p3 = pset1<Packet>(1.44989225341610930846E1);
  const Packet cst_cephes_log_p4 = pset1<Packet>(1.79368678507819816313E1);
  const Packet cst_cephes_log_p5 = pset1<Packet>(7.70838733755885391666E0);

  const Packet cst_cephes_log_q0 = pset1<Packet>(1.0);
  const Packet cst_cephes_log_q1 = pset1<Packet>(1.12873587189167450590E1);
  const Packet cst_cephes_log_q2 = pset1<Packet>(4.52279145837532221105E1);
  const Packet cst_cephes_log_q3 = pset1<Packet>(8.29875266912776603211E1);
  const Packet cst_cephes_log_q4 = pset1<Packet>(7.11544750618563894466E1);
  const Packet cst_cephes_log_q5 = pset1<Packet>(2.31251620126765340583E1);

  // Truncate input values to the minimum positive normal.
  // (Zero, negative and NaN lanes are patched up at the very end from _x.)
  x = pmax(x, cst_min_norm_pos);

  Packet e;
  // extract significand in the range [0.5,1) and exponent
  x = pfrexp(x,e);

  // Shift the inputs from the range [0.5,1) to [sqrt(1/2),sqrt(2))
  // and shift by -1. The values are then centered around 0, which improves
  // the stability of the polynomial evaluation.
  //   if( x < SQRTHF ) {
  //     e -= 1;
  //     x = x + x - 1.0;
  //   } else { x = x - 1.0; }
  // Branch-free version of the above: `mask` is all-ones where x < SQRTHF, so
  // `tmp` is x in those lanes and zero elsewhere.
  Packet mask = pcmp_lt(x, cst_cephes_SQRTHF);
  Packet tmp = pand(x, mask);
  x = psub(x, cst_1);
  e = psub(e, pand(cst_1, mask));
  x = padd(x, tmp);

  Packet x2 = pmul(x, x);
  Packet x3 = pmul(x2, x);

  // Evaluate the rational approximant. Numerator and denominator are each
  // evaluated in two interleaved parts, probably to improve instruction-level
  // parallelism.
  //  y = x - 0.5*x^2 + x^3 * polevl( x, P, 5 ) / p1evl( x, Q, 5 ) );
  Packet y, y1, y_;
  // Numerator P(x), accumulated into y_.
  y  = pmadd(cst_cephes_log_p0, x, cst_cephes_log_p1);
  y1 = pmadd(cst_cephes_log_p3, x, cst_cephes_log_p4);
  y  = pmadd(y, x, cst_cephes_log_p2);
  y1 = pmadd(y1, x, cst_cephes_log_p5);
  y_ = pmadd(y, x3, y1);

  // Denominator Q(x), accumulated into y.
  y  = pmadd(cst_cephes_log_q0, x, cst_cephes_log_q1);
  y1 = pmadd(cst_cephes_log_q3, x, cst_cephes_log_q4);
  y  = pmadd(y, x, cst_cephes_log_q2);
  y1 = pmadd(y1, x, cst_cephes_log_q5);
  y  = pmadd(y, x3, y1);

  // x^3 * P(x) / Q(x)
  y_ = pmul(y_, x3);
  y  = pdiv(y_, y);

  y = pmadd(cst_neg_half, x2, y);
  x = padd(x, y);

  // Add the logarithm of the exponent back to the result of the interpolation.
  if (base2) {
    // log2(m * 2^e) = log(m) * log2(e) + e
    const Packet cst_log2e = pset1<Packet>(static_cast<double>(EIGEN_LOG2E));
    x = pmadd(x, cst_log2e, e);
  } else {
    // log(m * 2^e) = e * ln(2) + log(m)
    const Packet cst_ln2 = pset1<Packet>(static_cast<double>(EIGEN_LN2));
    x = pmadd(e, cst_ln2, x);
  }

  Packet invalid_mask = pcmp_lt_or_nan(_x, pzero(_x));
  Packet iszero_mask  = pcmp_eq(_x,pzero(_x));
  Packet pos_inf_mask = pcmp_eq(_x,cst_pos_inf);
  // Filter out invalid inputs, i.e.:
  //  - negative arg will be NAN
  //  - 0 will be -INF
  //  - +INF will be +INF
  // (OR-ing with the all-ones invalid_mask turns the lane into a NaN bit pattern.)
  return pselect(iszero_mask, cst_minus_inf,
                              por(pselect(pos_inf_mask,cst_pos_inf,x), invalid_mask));
}

// Natural logarithm of a packet of doubles.
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet plog_double(const Packet _x)
{
  return plog_impl_double<Packet, /* base2 */ false>(_x);
}

// Base 2 logarithm of a packet of doubles.
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet plog2_double(const Packet _x)
{
  return plog_impl_double<Packet, /* base2 */ true>(_x);
}

/** \internal \returns log(1 + x) computed using W. Kahan's formula.
    See: http://www.plunk.org/~hatch/rightway.php
 */
template<typename Packet>
Packet generic_plog1p(const Packet& x)
{
  typedef typename unpacket_traits<Packet>::type ScalarType;
  const Packet one = pset1<Packet>(ScalarType(1));
  Packet xp1 = padd(x, one);
  // Lanes where 1 + x rounds to exactly 1: the result is x itself.
  Packet small_mask = pcmp_eq(xp1, one);
  Packet log1 = plog(xp1);
  // Lanes where log(1 + x) == 1 + x; presumably this only holds for +inf
  // (same trick as the one documented in generic_expm1). x is returned there.
  Packet inf_mask = pcmp_eq(xp1, log1);
  // Kahan's correction: x * log(1 + x) / ((1 + x) - 1).
  Packet log_large = pmul(x, pdiv(log1, psub(xp1, one)));
  return pselect(por(small_mask, inf_mask), x, log_large);
}

/** \internal \returns exp(x)-1 computed using W. Kahan's formula.
405 See: http://www.plunk.org/~hatch/rightway.php 406 */ 407 template<typename Packet> 408 Packet generic_expm1(const Packet& x) 409 { 410 typedef typename unpacket_traits<Packet>::type ScalarType; 411 const Packet one = pset1<Packet>(ScalarType(1)); 412 const Packet neg_one = pset1<Packet>(ScalarType(-1)); 413 Packet u = pexp(x); 414 Packet one_mask = pcmp_eq(u, one); 415 Packet u_minus_one = psub(u, one); 416 Packet neg_one_mask = pcmp_eq(u_minus_one, neg_one); 417 Packet logu = plog(u); 418 // The following comparison is to catch the case where 419 // exp(x) = +inf. It is written in this way to avoid having 420 // to form the constant +inf, which depends on the packet 421 // type. 422 Packet pos_inf_mask = pcmp_eq(logu, u); 423 Packet expm1 = pmul(u_minus_one, pdiv(x, logu)); 424 expm1 = pselect(pos_inf_mask, u, expm1); 425 return pselect(one_mask, 426 x, 427 pselect(neg_one_mask, 428 neg_one, 429 expm1)); 430 } 431 432 433 // Exponential function. Works by writing "x = m*log(2) + r" where 434 // "m = floor(x/log(2)+1/2)" and "r" is the remainder. The result is then 435 // "exp(x) = 2^m*exp(r)" where exp(r) is in the range [-1,1). 
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet pexp_float(const Packet _x)
{
  const Packet cst_1      = pset1<Packet>(1.0f);
  const Packet cst_half   = pset1<Packet>(0.5f);
  // Clamping bounds applied to the input before the range reduction.
  const Packet cst_exp_hi = pset1<Packet>( 88.723f);
  const Packet cst_exp_lo = pset1<Packet>(-88.723f);

  const Packet cst_cephes_LOG2EF = pset1<Packet>(1.44269504088896341f);
  const Packet cst_cephes_exp_p0 = pset1<Packet>(1.9875691500E-4f);
  const Packet cst_cephes_exp_p1 = pset1<Packet>(1.3981999507E-3f);
  const Packet cst_cephes_exp_p2 = pset1<Packet>(8.3334519073E-3f);
  const Packet cst_cephes_exp_p3 = pset1<Packet>(4.1665795894E-2f);
  const Packet cst_cephes_exp_p4 = pset1<Packet>(1.6666665459E-1f);
  const Packet cst_cephes_exp_p5 = pset1<Packet>(5.0000001201E-1f);

  // Clamp x.
  Packet x = pmax(pmin(_x, cst_exp_hi), cst_exp_lo);

  // Express exp(x) as exp(m*ln(2) + r), start by extracting
  // m = floor(x/ln(2) + 0.5).
  Packet m = pfloor(pmadd(x, cst_cephes_LOG2EF, cst_half));

  // Get r = x - m*ln(2). If no FMA instructions are available, m*ln(2) is
  // subtracted out in two parts, m*C1+m*C2 = m*ln(2), to avoid accumulating
  // truncation errors.
  const Packet cst_cephes_exp_C1 = pset1<Packet>(-0.693359375f);
  const Packet cst_cephes_exp_C2 = pset1<Packet>(2.12194440e-4f);
  Packet r = pmadd(m, cst_cephes_exp_C1, x);
  r = pmadd(m, cst_cephes_exp_C2, r);

  Packet r2 = pmul(r, r);
  Packet r3 = pmul(r2, r);

  // Evaluate the polynomial approximant,improved by instruction-level parallelism.
  // The result is y = 1 + r + r^2 * (p5 + p4*r + p3*r^2 + r^3 * (p2 + p1*r + p0*r^2)).
  Packet y, y1, y2;
  y  = pmadd(cst_cephes_exp_p0, r, cst_cephes_exp_p1);
  y1 = pmadd(cst_cephes_exp_p3, r, cst_cephes_exp_p4);
  y2 = padd(r, cst_1);
  y  = pmadd(y, r, cst_cephes_exp_p2);
  y1 = pmadd(y1, r, cst_cephes_exp_p5);
  y  = pmadd(y, r3, y1);
  y  = pmadd(y, r2, y2);

  // Return 2^m * exp(r).
  // The max with the unclamped input presumably serves the same purpose as in
  // pexp_double: catching non-finite values in the input.
  // TODO: replace pldexp with faster implementation since y in [-1, 1).
  return pmax(pldexp(y,m), _x);
}

// Exponential of a packet of doubles, using the same range reduction as
// pexp_float followed by a rational approximant.
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet pexp_double(const Packet _x)
{
  Packet x = _x;

  const Packet cst_1 = pset1<Packet>(1.0);
  const Packet cst_2 = pset1<Packet>(2.0);
  const Packet cst_half = pset1<Packet>(0.5);

  // Clamping bounds applied to the input before the range reduction.
  const Packet cst_exp_hi = pset1<Packet>(709.784);
  const Packet cst_exp_lo = pset1<Packet>(-709.784);

  const Packet cst_cephes_LOG2EF = pset1<Packet>(1.4426950408889634073599);
  const Packet cst_cephes_exp_p0 = pset1<Packet>(1.26177193074810590878e-4);
  const Packet cst_cephes_exp_p1 = pset1<Packet>(3.02994407707441961300e-2);
  const Packet cst_cephes_exp_p2 = pset1<Packet>(9.99999999999999999910e-1);
  const Packet cst_cephes_exp_q0 = pset1<Packet>(3.00198505138664455042e-6);
  const Packet cst_cephes_exp_q1 = pset1<Packet>(2.52448340349684104192e-3);
  const Packet cst_cephes_exp_q2 = pset1<Packet>(2.27265548208155028766e-1);
  const Packet cst_cephes_exp_q3 = pset1<Packet>(2.00000000000000000009e0);
  const Packet cst_cephes_exp_C1 = pset1<Packet>(0.693145751953125);
  const Packet cst_cephes_exp_C2 = pset1<Packet>(1.42860682030941723212e-6);

  Packet tmp, fx;

  // clamp x
  x = pmax(pmin(x, cst_exp_hi), cst_exp_lo);
  // Express exp(x) as exp(g + n*log(2)).
  fx = pmadd(cst_cephes_LOG2EF, x, cst_half);

  // Get the integer modulus of log(2), i.e. the "n" described above.
  fx = pfloor(fx);

  // Get the remainder modulo log(2), i.e. the "g" described above. Subtract
  // n*log(2) out in two steps, i.e. n*C1 + n*C2, C1+C2=log2 to get the last
  // digits right.
  tmp = pmul(fx, cst_cephes_exp_C1);
  Packet z = pmul(fx, cst_cephes_exp_C2);
  x = psub(x, tmp);
  x = psub(x, z);

  Packet x2 = pmul(x, x);

  // Evaluate the numerator polynomial of the rational interpolant.
  Packet px = cst_cephes_exp_p0;
  px = pmadd(px, x2, cst_cephes_exp_p1);
  px = pmadd(px, x2, cst_cephes_exp_p2);
  px = pmul(px, x);

  // Evaluate the denominator polynomial of the rational interpolant.
  Packet qx = cst_cephes_exp_q0;
  qx = pmadd(qx, x2, cst_cephes_exp_q1);
  qx = pmadd(qx, x2, cst_cephes_exp_q2);
  qx = pmadd(qx, x2, cst_cephes_exp_q3);

  // I don't really get this bit, copied from the SSE2 routines, so...
  // TODO(gonnet): Figure out what is going on here, perhaps find a better
  // rational interpolant?
  // NOTE(review): this looks like the Cephes form
  //   exp(g) ~= 1 + 2*g*P(g^2) / (Q(g^2) - g*P(g^2))
  // with px = g*P(g^2) and qx = Q(g^2) -- confirm against the Cephes exp.c notes.
  x = pdiv(px, psub(qx, px));
  x = pmadd(cst_2, x, cst_1);

  // Construct the result 2^n * exp(g) = e * x. The max is used to catch
  // non-finite values in the input.
  // TODO: replace pldexp with faster implementation since x in [-1, 1).
  return pmax(pldexp(x,fx), _x);
}

// The following code is inspired by the following stack-overflow answer:
//   https://stackoverflow.com/questions/30463616/payne-hanek-algorithm-implementation-in-c/30465751#30465751
// It has been largely optimized:
//  - By-pass calls to frexp.
//  - Aligned loads of required 96 bits of 2/pi. This is accomplished by
//    (1) balancing the mantissa and exponent so that the required bits of 2/pi
//    are aligned on 8-bits, and (2) replicating the storage of the bits of 2/pi.
//  - Avoid a branch in rounding and extraction of the remaining fractional part.
// Overall, I measured a speed up higher than x2 on x86-64.
// Reduces a large float `xf` modulo pi/2: returns the remainder
// r = xf - q*pi/2 and stores the rounded multiple q in `*quadrant`.
// The only caller in this file passes finite, non-negative values that are
// at least `huge_th` (see psincos_float).
inline float trig_reduce_huge (float xf, int *quadrant)
{
  using Eigen::numext::int32_t;
  using Eigen::numext::uint32_t;
  using Eigen::numext::int64_t;
  using Eigen::numext::uint64_t;

  const double pio2_62 = 3.4061215800865545e-19;    // pi/2 * 2^-62
  const uint64_t zero_dot_five = uint64_t(1) << 61; // 0.5 in 2.62-bit fixed-point format

  // 192 bits of 2/pi for Payne-Hanek reduction
  // Bits are introduced by packet of 8 to enable aligned reads.
  // (each entry is the previous one shifted left by 8 bits with the next byte
  //  of 2/pi appended, so any byte-aligned 32-bit window is a single load)
  static const uint32_t two_over_pi [] =
  {
    0x00000028, 0x000028be, 0x0028be60, 0x28be60db,
    0xbe60db93, 0x60db9391, 0xdb939105, 0x9391054a,
    0x91054a7f, 0x054a7f09, 0x4a7f09d5, 0x7f09d5f4,
    0x09d5f47d, 0xd5f47d4d, 0xf47d4d37, 0x7d4d3770,
    0x4d377036, 0x377036d8, 0x7036d8a5, 0x36d8a566,
    0xd8a5664f, 0xa5664f10, 0x664f10e4, 0x4f10e410,
    0x10e41000, 0xe4100000
  };

  uint32_t xi = numext::bit_cast<uint32_t>(xf);
  // Below, -118 = -126 + 8.
  //   -126 is to get the exponent,
  //   +8 is to enable alignment of 2/pi's bits on 8 bits.
  // This is possible because the fractional part of x has only 24 meaningful bits.
  uint32_t e = (xi >> 23) - 118;
  // Extract the mantissa and shift it to align it wrt the exponent
  xi = ((xi & 0x007fffffu)| 0x00800000u) << (e & 0x7);

  // Select three 32-bit words of 2/pi, 32 bits apart (4 table entries of 8 bits each).
  uint32_t i = e >> 3;
  uint32_t twoopi_1  = two_over_pi[i-1];
  uint32_t twoopi_2  = two_over_pi[i+3];
  uint32_t twoopi_3  = two_over_pi[i+7];

  // Compute x * 2/pi in 2.62-bit fixed-point format.
  uint64_t p;
  p = uint64_t(xi) * twoopi_3;
  p = uint64_t(xi) * twoopi_2 + (p >> 32);
  // The product xi * twoopi_1 is evaluated in 32 bits: only its low 32 bits
  // survive the shift by 32 anyway.
  p = (uint64_t(xi * twoopi_1) << 32) + p;

  // Round to nearest: add 0.5 and extract integral part.
  uint64_t q = (p + zero_dot_five) >> 62;
  *quadrant = int(q);
  // Now it remains to compute "r = x - q*pi/2" with high accuracy,
  // since we have p=x/(pi/2) with high accuracy, we can more efficiently compute r as:
  //   r = (p-q)*pi/2,
  // where the product can be carried out with sufficient accuracy using double precision.
  p -= q<<62;
  // Reinterpreting p as signed yields a fractional part in [-0.5, 0.5).
  return float(double(int64_t(p)) * pio2_62);
}

// Computes sin(_x) (ComputeSine == true) or cos(_x) (ComputeSine == false) for
// a packet of floats: the argument is reduced to [-Pi/4, Pi/4] by subtracting a
// multiple of Pi/2, both a sine and a cosine polynomial are evaluated, and the
// result and its sign are selected from the quadrant index.
template<bool ComputeSine,typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
#if EIGEN_GNUC_AT_LEAST(4,4) && EIGEN_COMP_GNUC_STRICT
__attribute__((optimize("-fno-unsafe-math-optimizations")))
#endif
Packet psincos_float(const Packet& _x)
{
  typedef typename unpacket_traits<Packet>::integer_packet PacketI;

  const Packet  cst_2oPI            = pset1<Packet>(0.636619746685028076171875f); // 2/PI
  const Packet  cst_rounding_magic  = pset1<Packet>(12582912); // 1.5 * 2^23, magic constant for rounding
  const PacketI csti_1              = pset1<PacketI>(1);
  const Packet  cst_sign_mask        = pset1frombits<Packet>(0x80000000u);

  Packet x = pabs(_x);

  // Scale x by 2/Pi to find x's quadrant.
  Packet y = pmul(x, cst_2oPI);

  // Rounding trick:
  // adding the magic constant pushes the fractional bits out of the mantissa,
  // the barrier keeps the compiler from simplifying the add/sub pair away.
  Packet y_round = padd(y, cst_rounding_magic);
  EIGEN_OPTIMIZATION_BARRIER(y_round)
  PacketI y_int = preinterpret<PacketI>(y_round); // last 23 digits represent integer (if abs(x)<2^24)
  y = psub(y_round, cst_rounding_magic); // nearest integer to x*2/pi

  // Reduce x by y quadrants (multiples of Pi/2) to get: -Pi/4 <= x <= +Pi/4
  // using "Extended precision modular arithmetic"
  #if defined(EIGEN_HAS_SINGLE_INSTRUCTION_MADD)
  // This version requires true FMA for high accuracy
  // It provides a max error of 1ULP up to (with absolute_error < 5.9605e-08):
  const float huge_th = ComputeSine ? 117435.992f : 71476.0625f;
  x = pmadd(y, pset1<Packet>(-1.57079601287841796875f), x);
  x = pmadd(y, pset1<Packet>(-3.1391647326017846353352069854736328125e-07f), x);
  x = pmadd(y, pset1<Packet>(-5.390302529957764765544681040410068817436695098876953125e-15f), x);
  #else
  // Without true FMA, the previous set of coefficients maintain 1ULP accuracy
  // up to x<15.7 (for sin), but accuracy is immediately lost for x>15.7.
  // We thus use one more iteration to maintain 2ULPs up to reasonably large inputs.

  // The following set of coefficients maintain 1ULP up to 9.43 and 14.16 for sin and cos respectively.
  // and 2 ULP up to:
  const float huge_th = ComputeSine ? 25966.f : 18838.f;
  x = pmadd(y, pset1<Packet>(-1.5703125), x); // = 0xbfc90000
  EIGEN_OPTIMIZATION_BARRIER(x)
  x = pmadd(y, pset1<Packet>(-0.000483989715576171875), x); // = 0xb9fdc000
  EIGEN_OPTIMIZATION_BARRIER(x)
  x = pmadd(y, pset1<Packet>(1.62865035235881805419921875e-07), x); // = 0x342ee000
  x = pmadd(y, pset1<Packet>(5.5644315544167710640977020375430583953857421875e-11), x); // = 0x2e74b9ee

  // For the record, the following set of coefficients maintain 2ULP up
  // to a slightly larger range:
  // const float huge_th = ComputeSine ? 51981.f : 39086.125f;
  // but it slightly fails to maintain 1ULP for two values of sin below pi.
  // x = pmadd(y, pset1<Packet>(-3.140625/2.), x);
  // x = pmadd(y, pset1<Packet>(-0.00048351287841796875), x);
  // x = pmadd(y, pset1<Packet>(-3.13855707645416259765625e-07), x);
  // x = pmadd(y, pset1<Packet>(-6.0771006282767103812147979624569416046142578125e-11), x);

  // For the record, with only 3 iterations it is possible to maintain
  // 1 ULP up to 3PI (maybe more) and 2ULP up to 255.
  // The coefficients are: 0xbfc90f80, 0xb7354480, 0x2e74b9ee
  #endif

  // Slow path: if any lane is at or above huge_th, redo the reduction of those
  // lanes one by one with the scalar Payne-Hanek routine.
  if(predux_any(pcmp_le(pset1<Packet>(huge_th),pabs(_x))))
  {
    const int PacketSize = unpacket_traits<Packet>::size;
    EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) float vals[PacketSize];
    EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) float x_cpy[PacketSize];
    EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) int y_int2[PacketSize];
    pstoreu(vals, pabs(_x));
    pstoreu(x_cpy, x);
    pstoreu(y_int2, y_int);
    for(int k=0; k<PacketSize;++k)
    {
      float val = vals[k];
      // Non-finite lanes keep the result of the fast path.
      if(val>=huge_th && (numext::isfinite)(val))
        x_cpy[k] = trig_reduce_huge(val,&y_int2[k]);
    }
    x = ploadu<Packet>(x_cpy);
    y_int = ploadu<PacketI>(y_int2);
  }

  // Compute the sign to apply to the polynomial.
  // sin: sign = second_bit(y_int) xor signbit(_x)
  // cos: sign = second_bit(y_int+1)
  // (shifting left by 30 moves the second bit into the sign-bit position)
  Packet sign_bit = ComputeSine ? pxor(_x, preinterpret<Packet>(plogical_shift_left<30>(y_int)))
                                : preinterpret<Packet>(plogical_shift_left<30>(padd(y_int,csti_1)));
  sign_bit = pand(sign_bit, cst_sign_mask); // clear all but left most bit

  // Get the polynomial selection mask from the lowest bit of y_int
  // We'll calculate both (sin and cos) polynomials and then select from the two.
  Packet poly_mask = preinterpret<Packet>(pcmp_eq(pand(y_int, csti_1), pzero(y_int)));

  Packet x2 = pmul(x,x);

  // Evaluate the cos(x) polynomial. (-Pi/4 <= x <= Pi/4)
  Packet y1 =        pset1<Packet>(2.4372266125283204019069671630859375e-05f);
  y1 = pmadd(y1, x2, pset1<Packet>(-0.00138865201734006404876708984375f     ));
  y1 = pmadd(y1, x2, pset1<Packet>(0.041666619479656219482421875f           ));
  y1 = pmadd(y1, x2, pset1<Packet>(-0.5f));
  y1 = pmadd(y1, x2, pset1<Packet>(1.f));

  // Evaluate the sin(x) polynomial. (-Pi/4 <= x <= Pi/4)
  // octave/matlab code to compute those coefficients:
  //    x = (0:0.0001:pi/4)';
  //    A = [x.^3 x.^5 x.^7];
  //    w = ((1.-(x/(pi/4)).^2).^5)*2000+1;         # weights trading relative accuracy
  //    c = (A'*diag(w)*A)\(A'*diag(w)*(sin(x)-x)); # weighted LS, linear coeff forced to 1
  //    printf('%.64f\n %.64f\n%.64f\n', c(3), c(2), c(1))
  //
  Packet y2 =        pset1<Packet>(-0.0001959234114083702898469196984621021329076029360294342041015625f);
  y2 = pmadd(y2, x2, pset1<Packet>( 0.0083326873655616851693794799871284340042620897293090820312500000f));
  y2 = pmadd(y2, x2, pset1<Packet>(-0.1666666203982298255503735617821803316473960876464843750000000000f));
  y2 = pmul(y2, x2);
  y2 = pmadd(y2, x, x);

  // Select the correct result from the two polynomials.
  // (poly_mask is set where the lowest bit of y_int is zero.)
  y = ComputeSine ? pselect(poly_mask,y2,y1)
                  : pselect(poly_mask,y1,y2);

  // Update the sign and filter huge inputs
  return pxor(y, sign_bit);
}

// Sine of a packet of floats.
template<typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet psin_float(const Packet& x)
{
  return psincos_float<true>(x);
}

// Cosine of a packet of floats.
template<typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet pcos_float(const Packet& x)
{
  return psincos_float<false>(x);
}


template<typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet psqrt_complex(const Packet& a) {
  typedef typename unpacket_traits<Packet>::type Scalar;
  typedef typename Scalar::value_type RealScalar;
  typedef typename unpacket_traits<Packet>::as_real RealPacket;

  // Computes the principal sqrt of the complex numbers in the input.
770 // 771 // For example, for packets containing 2 complex numbers stored in interleaved format 772 // a = [a0, a1] = [x0, y0, x1, y1], 773 // where x0 = real(a0), y0 = imag(a0) etc., this function returns 774 // b = [b0, b1] = [u0, v0, u1, v1], 775 // such that b0^2 = a0, b1^2 = a1. 776 // 777 // To derive the formula for the complex square roots, let's consider the equation for 778 // a single complex square root of the number x + i*y. We want to find real numbers 779 // u and v such that 780 // (u + i*v)^2 = x + i*y <=> 781 // u^2 - v^2 + i*2*u*v = x + i*v. 782 // By equating the real and imaginary parts we get: 783 // u^2 - v^2 = x 784 // 2*u*v = y. 785 // 786 // For x >= 0, this has the numerically stable solution 787 // u = sqrt(0.5 * (x + sqrt(x^2 + y^2))) 788 // v = 0.5 * (y / u) 789 // and for x < 0, 790 // v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2))) 791 // u = 0.5 * (y / v) 792 // 793 // To avoid unnecessary over- and underflow, we compute sqrt(x^2 + y^2) as 794 // l = max(|x|, |y|) * sqrt(1 + (min(|x|, |y|) / max(|x|, |y|))^2) , 795 796 // In the following, without lack of generality, we have annotated the code, assuming 797 // that the input is a packet of 2 complex numbers. 798 // 799 // Step 1. Compute l = [l0, l0, l1, l1], where 800 // l0 = sqrt(x0^2 + y0^2), l1 = sqrt(x1^2 + y1^2) 801 // To avoid over- and underflow, we use the stable formula for each hypotenuse 802 // l0 = (min0 == 0 ? max0 : max0 * sqrt(1 + (min0/max0)**2)), 803 // where max0 = max(|x0|, |y0|), min0 = min(|x0|, |y0|), and similarly for l1. 
804 805 RealPacket a_abs = pabs(a.v); // [|x0|, |y0|, |x1|, |y1|] 806 RealPacket a_abs_flip = pcplxflip(Packet(a_abs)).v; // [|y0|, |x0|, |y1|, |x1|] 807 RealPacket a_max = pmax(a_abs, a_abs_flip); 808 RealPacket a_min = pmin(a_abs, a_abs_flip); 809 RealPacket a_min_zero_mask = pcmp_eq(a_min, pzero(a_min)); 810 RealPacket a_max_zero_mask = pcmp_eq(a_max, pzero(a_max)); 811 RealPacket r = pdiv(a_min, a_max); 812 const RealPacket cst_one = pset1<RealPacket>(RealScalar(1)); 813 RealPacket l = pmul(a_max, psqrt(padd(cst_one, pmul(r, r)))); // [l0, l0, l1, l1] 814 // Set l to a_max if a_min is zero. 815 l = pselect(a_min_zero_mask, a_max, l); 816 817 // Step 2. Compute [rho0, *, rho1, *], where 818 // rho0 = sqrt(0.5 * (l0 + |x0|)), rho1 = sqrt(0.5 * (l1 + |x1|)) 819 // We don't care about the imaginary parts computed here. They will be overwritten later. 820 const RealPacket cst_half = pset1<RealPacket>(RealScalar(0.5)); 821 Packet rho; 822 rho.v = psqrt(pmul(cst_half, padd(a_abs, l))); 823 824 // Step 3. Compute [rho0, eta0, rho1, eta1], where 825 // eta0 = (y0 / l0) / 2, and eta1 = (y1 / l1) / 2. 826 // set eta = 0 of input is 0 + i0. 827 RealPacket eta = pandnot(pmul(cst_half, pdiv(a.v, pcplxflip(rho).v)), a_max_zero_mask); 828 RealPacket real_mask = peven_mask(a.v); 829 Packet positive_real_result; 830 // Compute result for inputs with positive real part. 831 positive_real_result.v = pselect(real_mask, rho.v, eta); 832 833 // Step 4. Compute solution for inputs with negative real part: 834 // [|eta0|, sign(y0)*rho0, |eta1|, sign(y1)*rho1] 835 const RealScalar neg_zero = RealScalar(numext::bit_cast<float>(0x80000000u)); 836 const RealPacket cst_imag_sign_mask = pset1<Packet>(Scalar(RealScalar(0.0), neg_zero)).v; 837 RealPacket imag_signs = pand(a.v, cst_imag_sign_mask); 838 Packet negative_real_result; 839 // Notice that rho is positive, so taking it's absolute value is a noop. 
840 negative_real_result.v = por(pabs(pcplxflip(positive_real_result).v), imag_signs); 841 842 // Step 5. Select solution branch based on the sign of the real parts. 843 Packet negative_real_mask; 844 negative_real_mask.v = pcmp_lt(pand(real_mask, a.v), pzero(a.v)); 845 negative_real_mask.v = por(negative_real_mask.v, pcplxflip(negative_real_mask).v); 846 Packet result = pselect(negative_real_mask, negative_real_result, positive_real_result); 847 848 // Step 6. Handle special cases for infinities: 849 // * If z is (x,+∞), the result is (+∞,+∞) even if x is NaN 850 // * If z is (x,-∞), the result is (+∞,-∞) even if x is NaN 851 // * If z is (-∞,y), the result is (0*|y|,+∞) for finite or NaN y 852 // * If z is (+∞,y), the result is (+∞,0*|y|) for finite or NaN y 853 const RealPacket cst_pos_inf = pset1<RealPacket>(NumTraits<RealScalar>::infinity()); 854 Packet is_inf; 855 is_inf.v = pcmp_eq(a_abs, cst_pos_inf); 856 Packet is_real_inf; 857 is_real_inf.v = pand(is_inf.v, real_mask); 858 is_real_inf = por(is_real_inf, pcplxflip(is_real_inf)); 859 // prepare packet of (+∞,0*|y|) or (0*|y|,+∞), depending on the sign of the infinite real part. 860 Packet real_inf_result; 861 real_inf_result.v = pmul(a_abs, pset1<Packet>(Scalar(RealScalar(1.0), RealScalar(0.0))).v); 862 real_inf_result.v = pselect(negative_real_mask.v, pcplxflip(real_inf_result).v, real_inf_result.v); 863 // prepare packet of (+∞,+∞) or (+∞,-∞), depending on the sign of the infinite imaginary part. 
864 Packet is_imag_inf; 865 is_imag_inf.v = pandnot(is_inf.v, real_mask); 866 is_imag_inf = por(is_imag_inf, pcplxflip(is_imag_inf)); 867 Packet imag_inf_result; 868 imag_inf_result.v = por(pand(cst_pos_inf, real_mask), pandnot(a.v, real_mask)); 869 870 return pselect(is_imag_inf, imag_inf_result, 871 pselect(is_real_inf, real_inf_result,result)); 872 } 873 874 // TODO(rmlarsen): The following set of utilities for double word arithmetic 875 // should perhaps be refactored as a separate file, since it would be generally 876 // useful for special function implementation etc. Writing the algorithms in 877 // terms if a double word type would also make the code more readable. 878 879 // This function splits x into the nearest integer n and fractional part r, 880 // such that x = n + r holds exactly. 881 template<typename Packet> 882 EIGEN_STRONG_INLINE 883 void absolute_split(const Packet& x, Packet& n, Packet& r) { 884 n = pround(x); 885 r = psub(x, n); 886 } 887 888 // This function computes the sum {s, r}, such that x + y = s_hi + s_lo 889 // holds exactly, and s_hi = fl(x+y), if |x| >= |y|. 890 template<typename Packet> 891 EIGEN_STRONG_INLINE 892 void fast_twosum(const Packet& x, const Packet& y, Packet& s_hi, Packet& s_lo) { 893 s_hi = padd(x, y); 894 const Packet t = psub(s_hi, x); 895 s_lo = psub(y, t); 896 } 897 898 #ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD 899 // This function implements the extended precision product of 900 // a pair of floating point numbers. Given {x, y}, it computes the pair 901 // {p_hi, p_lo} such that x * y = p_hi + p_lo holds exactly and 902 // p_hi = fl(x * y). 903 template<typename Packet> 904 EIGEN_STRONG_INLINE 905 void twoprod(const Packet& x, const Packet& y, 906 Packet& p_hi, Packet& p_lo) { 907 p_hi = pmul(x, y); 908 p_lo = pmadd(x, y, pnegate(p_hi)); 909 } 910 911 #else 912 913 // This function implements the Veltkamp splitting. 
Given a floating point 914 // number x it returns the pair {x_hi, x_lo} such that x_hi + x_lo = x holds 915 // exactly and that half of the significant of x fits in x_hi. 916 // This is Algorithm 3 from Jean-Michel Muller, "Elementary Functions", 917 // 3rd edition, Birkh\"auser, 2016. 918 template<typename Packet> 919 EIGEN_STRONG_INLINE 920 void veltkamp_splitting(const Packet& x, Packet& x_hi, Packet& x_lo) { 921 typedef typename unpacket_traits<Packet>::type Scalar; 922 EIGEN_CONSTEXPR int shift = (NumTraits<Scalar>::digits() + 1) / 2; 923 const Scalar shift_scale = Scalar(uint64_t(1) << shift); // Scalar constructor not necessarily constexpr. 924 const Packet gamma = pmul(pset1<Packet>(shift_scale + Scalar(1)), x); 925 Packet rho = psub(x, gamma); 926 x_hi = padd(rho, gamma); 927 x_lo = psub(x, x_hi); 928 } 929 930 // This function implements Dekker's algorithm for products x * y. 931 // Given floating point numbers {x, y} computes the pair 932 // {p_hi, p_lo} such that x * y = p_hi + p_lo holds exactly and 933 // p_hi = fl(x * y). 934 template<typename Packet> 935 EIGEN_STRONG_INLINE 936 void twoprod(const Packet& x, const Packet& y, 937 Packet& p_hi, Packet& p_lo) { 938 Packet x_hi, x_lo, y_hi, y_lo; 939 veltkamp_splitting(x, x_hi, x_lo); 940 veltkamp_splitting(y, y_hi, y_lo); 941 942 p_hi = pmul(x, y); 943 p_lo = pmadd(x_hi, y_hi, pnegate(p_hi)); 944 p_lo = pmadd(x_hi, y_lo, p_lo); 945 p_lo = pmadd(x_lo, y_hi, p_lo); 946 p_lo = pmadd(x_lo, y_lo, p_lo); 947 } 948 949 #endif // EIGEN_HAS_SINGLE_INSTRUCTION_MADD 950 951 952 // This function implements Dekker's algorithm for the addition 953 // of two double word numbers represented by {x_hi, x_lo} and {y_hi, y_lo}. 954 // It returns the result as a pair {s_hi, s_lo} such that 955 // x_hi + x_lo + y_hi + y_lo = s_hi + s_lo holds exactly. 956 // This is Algorithm 5 from Jean-Michel Muller, "Elementary Functions", 957 // 3rd edition, Birkh\"auser, 2016. 
958 template<typename Packet> 959 EIGEN_STRONG_INLINE 960 void twosum(const Packet& x_hi, const Packet& x_lo, 961 const Packet& y_hi, const Packet& y_lo, 962 Packet& s_hi, Packet& s_lo) { 963 const Packet x_greater_mask = pcmp_lt(pabs(y_hi), pabs(x_hi)); 964 Packet r_hi_1, r_lo_1; 965 fast_twosum(x_hi, y_hi,r_hi_1, r_lo_1); 966 Packet r_hi_2, r_lo_2; 967 fast_twosum(y_hi, x_hi,r_hi_2, r_lo_2); 968 const Packet r_hi = pselect(x_greater_mask, r_hi_1, r_hi_2); 969 970 const Packet s1 = padd(padd(y_lo, r_lo_1), x_lo); 971 const Packet s2 = padd(padd(x_lo, r_lo_2), y_lo); 972 const Packet s = pselect(x_greater_mask, s1, s2); 973 974 fast_twosum(r_hi, s, s_hi, s_lo); 975 } 976 977 // This is a version of twosum for double word numbers, 978 // which assumes that |x_hi| >= |y_hi|. 979 template<typename Packet> 980 EIGEN_STRONG_INLINE 981 void fast_twosum(const Packet& x_hi, const Packet& x_lo, 982 const Packet& y_hi, const Packet& y_lo, 983 Packet& s_hi, Packet& s_lo) { 984 Packet r_hi, r_lo; 985 fast_twosum(x_hi, y_hi, r_hi, r_lo); 986 const Packet s = padd(padd(y_lo, r_lo), x_lo); 987 fast_twosum(r_hi, s, s_hi, s_lo); 988 } 989 990 // This is a version of twosum for adding a floating point number x to 991 // double word number {y_hi, y_lo} number, with the assumption 992 // that |x| >= |y_hi|. 993 template<typename Packet> 994 EIGEN_STRONG_INLINE 995 void fast_twosum(const Packet& x, 996 const Packet& y_hi, const Packet& y_lo, 997 Packet& s_hi, Packet& s_lo) { 998 Packet r_hi, r_lo; 999 fast_twosum(x, y_hi, r_hi, r_lo); 1000 const Packet s = padd(y_lo, r_lo); 1001 fast_twosum(r_hi, s, s_hi, s_lo); 1002 } 1003 1004 // This function implements the multiplication of a double word 1005 // number represented by {x_hi, x_lo} by a floating point number y. 
1006 // It returns the result as a pair {p_hi, p_lo} such that 1007 // (x_hi + x_lo) * y = p_hi + p_lo hold with a relative error 1008 // of less than 2*2^{-2p}, where p is the number of significand bit 1009 // in the floating point type. 1010 // This is Algorithm 7 from Jean-Michel Muller, "Elementary Functions", 1011 // 3rd edition, Birkh\"auser, 2016. 1012 template<typename Packet> 1013 EIGEN_STRONG_INLINE 1014 void twoprod(const Packet& x_hi, const Packet& x_lo, const Packet& y, 1015 Packet& p_hi, Packet& p_lo) { 1016 Packet c_hi, c_lo1; 1017 twoprod(x_hi, y, c_hi, c_lo1); 1018 const Packet c_lo2 = pmul(x_lo, y); 1019 Packet t_hi, t_lo1; 1020 fast_twosum(c_hi, c_lo2, t_hi, t_lo1); 1021 const Packet t_lo2 = padd(t_lo1, c_lo1); 1022 fast_twosum(t_hi, t_lo2, p_hi, p_lo); 1023 } 1024 1025 // This function implements the multiplication of two double word 1026 // numbers represented by {x_hi, x_lo} and {y_hi, y_lo}. 1027 // It returns the result as a pair {p_hi, p_lo} such that 1028 // (x_hi + x_lo) * (y_hi + y_lo) = p_hi + p_lo holds with a relative error 1029 // of less than 2*2^{-2p}, where p is the number of significand bit 1030 // in the floating point type. 1031 template<typename Packet> 1032 EIGEN_STRONG_INLINE 1033 void twoprod(const Packet& x_hi, const Packet& x_lo, 1034 const Packet& y_hi, const Packet& y_lo, 1035 Packet& p_hi, Packet& p_lo) { 1036 Packet p_hi_hi, p_hi_lo; 1037 twoprod(x_hi, x_lo, y_hi, p_hi_hi, p_hi_lo); 1038 Packet p_lo_hi, p_lo_lo; 1039 twoprod(x_hi, x_lo, y_lo, p_lo_hi, p_lo_lo); 1040 fast_twosum(p_hi_hi, p_hi_lo, p_lo_hi, p_lo_lo, p_hi, p_lo); 1041 } 1042 1043 // This function computes the reciprocal of a floating point number 1044 // with extra precision and returns the result as a double word. 1045 template <typename Packet> 1046 void doubleword_reciprocal(const Packet& x, Packet& recip_hi, Packet& recip_lo) { 1047 typedef typename unpacket_traits<Packet>::type Scalar; 1048 // 1. 
Approximate the reciprocal as the reciprocal of the high order element. 1049 Packet approx_recip = prsqrt(x); 1050 approx_recip = pmul(approx_recip, approx_recip); 1051 1052 // 2. Run one step of Newton-Raphson iteration in double word arithmetic 1053 // to get the bottom half. The NR iteration for reciprocal of 'a' is 1054 // x_{i+1} = x_i * (2 - a * x_i) 1055 1056 // -a*x_i 1057 Packet t1_hi, t1_lo; 1058 twoprod(pnegate(x), approx_recip, t1_hi, t1_lo); 1059 // 2 - a*x_i 1060 Packet t2_hi, t2_lo; 1061 fast_twosum(pset1<Packet>(Scalar(2)), t1_hi, t2_hi, t2_lo); 1062 Packet t3_hi, t3_lo; 1063 fast_twosum(t2_hi, padd(t2_lo, t1_lo), t3_hi, t3_lo); 1064 // x_i * (2 - a * x_i) 1065 twoprod(t3_hi, t3_lo, approx_recip, recip_hi, recip_lo); 1066 } 1067 1068 1069 // This function computes log2(x) and returns the result as a double word. 1070 template <typename Scalar> 1071 struct accurate_log2 { 1072 template <typename Packet> 1073 EIGEN_STRONG_INLINE 1074 void operator()(const Packet& x, Packet& log2_x_hi, Packet& log2_x_lo) { 1075 log2_x_hi = plog2(x); 1076 log2_x_lo = pzero(x); 1077 } 1078 }; 1079 1080 // This specialization uses a more accurate algorithm to compute log2(x) for 1081 // floats in [1/sqrt(2);sqrt(2)] with a relative accuracy of ~6.42e-10. 1082 // This additional accuracy is needed to counter the error-magnification 1083 // inherent in multiplying by a potentially large exponent in pow(x,y). 1084 // The minimax polynomial used was calculated using the Sollya tool. 1085 // See sollya.org. 
1086 template <> 1087 struct accurate_log2<float> { 1088 template <typename Packet> 1089 EIGEN_STRONG_INLINE 1090 void operator()(const Packet& z, Packet& log2_x_hi, Packet& log2_x_lo) { 1091 // The function log(1+x)/x is approximated in the interval 1092 // [1/sqrt(2)-1;sqrt(2)-1] by a degree 10 polynomial of the form 1093 // Q(x) = (C0 + x * (C1 + x * (C2 + x * (C3 + x * P(x))))), 1094 // where the degree 6 polynomial P(x) is evaluated in single precision, 1095 // while the remaining 4 terms of Q(x), as well as the final multiplication by x 1096 // to reconstruct log(1+x) are evaluated in extra precision using 1097 // double word arithmetic. C0 through C3 are extra precise constants 1098 // stored as double words. 1099 // 1100 // The polynomial coefficients were calculated using Sollya commands: 1101 // > n = 10; 1102 // > f = log2(1+x)/x; 1103 // > interval = [sqrt(0.5)-1;sqrt(2)-1]; 1104 // > p = fpminimax(f,n,[|double,double,double,double,single...|],interval,relative,floating); 1105 1106 const Packet p6 = pset1<Packet>( 9.703654795885e-2f); 1107 const Packet p5 = pset1<Packet>(-0.1690667718648f); 1108 const Packet p4 = pset1<Packet>( 0.1720575392246f); 1109 const Packet p3 = pset1<Packet>(-0.1789081543684f); 1110 const Packet p2 = pset1<Packet>( 0.2050433009862f); 1111 const Packet p1 = pset1<Packet>(-0.2404672354459f); 1112 const Packet p0 = pset1<Packet>( 0.2885761857032f); 1113 1114 const Packet C3_hi = pset1<Packet>(-0.360674142838f); 1115 const Packet C3_lo = pset1<Packet>(-6.13283912543e-09f); 1116 const Packet C2_hi = pset1<Packet>(0.480897903442f); 1117 const Packet C2_lo = pset1<Packet>(-1.44861207474e-08f); 1118 const Packet C1_hi = pset1<Packet>(-0.721347510815f); 1119 const Packet C1_lo = pset1<Packet>(-4.84483164698e-09f); 1120 const Packet C0_hi = pset1<Packet>(1.44269502163f); 1121 const Packet C0_lo = pset1<Packet>(2.01711713999e-08f); 1122 const Packet one = pset1<Packet>(1.0f); 1123 1124 const Packet x = psub(z, one); 1125 // Evaluate P(x) 
in working precision. 1126 // We evaluate it in multiple parts to improve instruction level 1127 // parallelism. 1128 Packet x2 = pmul(x,x); 1129 Packet p_even = pmadd(p6, x2, p4); 1130 p_even = pmadd(p_even, x2, p2); 1131 p_even = pmadd(p_even, x2, p0); 1132 Packet p_odd = pmadd(p5, x2, p3); 1133 p_odd = pmadd(p_odd, x2, p1); 1134 Packet p = pmadd(p_odd, x, p_even); 1135 1136 // Now evaluate the low-order tems of Q(x) in double word precision. 1137 // In the following, due to the alternating signs and the fact that 1138 // |x| < sqrt(2)-1, we can assume that |C*_hi| >= q_i, and use 1139 // fast_twosum instead of the slower twosum. 1140 Packet q_hi, q_lo; 1141 Packet t_hi, t_lo; 1142 // C3 + x * p(x) 1143 twoprod(p, x, t_hi, t_lo); 1144 fast_twosum(C3_hi, C3_lo, t_hi, t_lo, q_hi, q_lo); 1145 // C2 + x * p(x) 1146 twoprod(q_hi, q_lo, x, t_hi, t_lo); 1147 fast_twosum(C2_hi, C2_lo, t_hi, t_lo, q_hi, q_lo); 1148 // C1 + x * p(x) 1149 twoprod(q_hi, q_lo, x, t_hi, t_lo); 1150 fast_twosum(C1_hi, C1_lo, t_hi, t_lo, q_hi, q_lo); 1151 // C0 + x * p(x) 1152 twoprod(q_hi, q_lo, x, t_hi, t_lo); 1153 fast_twosum(C0_hi, C0_lo, t_hi, t_lo, q_hi, q_lo); 1154 1155 // log(z) ~= x * Q(x) 1156 twoprod(q_hi, q_lo, x, log2_x_hi, log2_x_lo); 1157 } 1158 }; 1159 1160 // This specialization uses a more accurate algorithm to compute log2(x) for 1161 // floats in [1/sqrt(2);sqrt(2)] with a relative accuracy of ~1.27e-18. 1162 // This additional accuracy is needed to counter the error-magnification 1163 // inherent in multiplying by a potentially large exponent in pow(x,y). 1164 // The minimax polynomial used was calculated using the Sollya tool. 1165 // See sollya.org. 
1166 1167 template <> 1168 struct accurate_log2<double> { 1169 template <typename Packet> 1170 EIGEN_STRONG_INLINE 1171 void operator()(const Packet& x, Packet& log2_x_hi, Packet& log2_x_lo) { 1172 // We use a transformation of variables: 1173 // r = c * (x-1) / (x+1), 1174 // such that 1175 // log2(x) = log2((1 + r/c) / (1 - r/c)) = f(r). 1176 // The function f(r) can be approximated well using an odd polynomial 1177 // of the form 1178 // P(r) = ((Q(r^2) * r^2 + C) * r^2 + 1) * r, 1179 // For the implementation of log2<double> here, Q is of degree 6 with 1180 // coefficient represented in working precision (double), while C is a 1181 // constant represented in extra precision as a double word to achieve 1182 // full accuracy. 1183 // 1184 // The polynomial coefficients were computed by the Sollya script: 1185 // 1186 // c = 2 / log(2); 1187 // trans = c * (x-1)/(x+1); 1188 // itrans = (1+x/c)/(1-x/c); 1189 // interval=[trans(sqrt(0.5)); trans(sqrt(2))]; 1190 // print(interval); 1191 // f = log2(itrans(x)); 1192 // p=fpminimax(f,[|1,3,5,7,9,11,13,15,17|],[|1,DD,double...|],interval,relative,floating); 1193 const Packet q12 = pset1<Packet>(2.87074255468000586e-9); 1194 const Packet q10 = pset1<Packet>(2.38957980901884082e-8); 1195 const Packet q8 = pset1<Packet>(2.31032094540014656e-7); 1196 const Packet q6 = pset1<Packet>(2.27279857398537278e-6); 1197 const Packet q4 = pset1<Packet>(2.31271023278625638e-5); 1198 const Packet q2 = pset1<Packet>(2.47556738444535513e-4); 1199 const Packet q0 = pset1<Packet>(2.88543873228900172e-3); 1200 const Packet C_hi = pset1<Packet>(0.0400377511598501157); 1201 const Packet C_lo = pset1<Packet>(-4.77726582251425391e-19); 1202 const Packet one = pset1<Packet>(1.0); 1203 1204 const Packet cst_2_log2e_hi = pset1<Packet>(2.88539008177792677); 1205 const Packet cst_2_log2e_lo = pset1<Packet>(4.07660016854549667e-17); 1206 // c * (x - 1) 1207 Packet num_hi, num_lo; 1208 twoprod(cst_2_log2e_hi, cst_2_log2e_lo, psub(x, one), num_hi, 
num_lo); 1209 // TODO(rmlarsen): Investigate if using the division algorithm by 1210 // Muller et al. is faster/more accurate. 1211 // 1 / (x + 1) 1212 Packet denom_hi, denom_lo; 1213 doubleword_reciprocal(padd(x, one), denom_hi, denom_lo); 1214 // r = c * (x-1) / (x+1), 1215 Packet r_hi, r_lo; 1216 twoprod(num_hi, num_lo, denom_hi, denom_lo, r_hi, r_lo); 1217 // r2 = r * r 1218 Packet r2_hi, r2_lo; 1219 twoprod(r_hi, r_lo, r_hi, r_lo, r2_hi, r2_lo); 1220 // r4 = r2 * r2 1221 Packet r4_hi, r4_lo; 1222 twoprod(r2_hi, r2_lo, r2_hi, r2_lo, r4_hi, r4_lo); 1223 1224 // Evaluate Q(r^2) in working precision. We evaluate it in two parts 1225 // (even and odd in r^2) to improve instruction level parallelism. 1226 Packet q_even = pmadd(q12, r4_hi, q8); 1227 Packet q_odd = pmadd(q10, r4_hi, q6); 1228 q_even = pmadd(q_even, r4_hi, q4); 1229 q_odd = pmadd(q_odd, r4_hi, q2); 1230 q_even = pmadd(q_even, r4_hi, q0); 1231 Packet q = pmadd(q_odd, r2_hi, q_even); 1232 1233 // Now evaluate the low order terms of P(x) in double word precision. 1234 // In the following, due to the increasing magnitude of the coefficients 1235 // and r being constrained to [-0.5, 0.5] we can use fast_twosum instead 1236 // of the slower twosum. 1237 // Q(r^2) * r^2 1238 Packet p_hi, p_lo; 1239 twoprod(r2_hi, r2_lo, q, p_hi, p_lo); 1240 // Q(r^2) * r^2 + C 1241 Packet p1_hi, p1_lo; 1242 fast_twosum(C_hi, C_lo, p_hi, p_lo, p1_hi, p1_lo); 1243 // (Q(r^2) * r^2 + C) * r^2 1244 Packet p2_hi, p2_lo; 1245 twoprod(r2_hi, r2_lo, p1_hi, p1_lo, p2_hi, p2_lo); 1246 // ((Q(r^2) * r^2 + C) * r^2 + 1) 1247 Packet p3_hi, p3_lo; 1248 fast_twosum(one, p2_hi, p2_lo, p3_hi, p3_lo); 1249 1250 // log(z) ~= ((Q(r^2) * r^2 + C) * r^2 + 1) * r 1251 twoprod(p3_hi, p3_lo, r_hi, r_lo, log2_x_hi, log2_x_lo); 1252 } 1253 }; 1254 1255 // This function computes exp2(x) (i.e. 2**x). 
1256 template <typename Scalar> 1257 struct fast_accurate_exp2 { 1258 template <typename Packet> 1259 EIGEN_STRONG_INLINE 1260 Packet operator()(const Packet& x) { 1261 // TODO(rmlarsen): Add a pexp2 packetop. 1262 return pexp(pmul(pset1<Packet>(Scalar(EIGEN_LN2)), x)); 1263 } 1264 }; 1265 1266 // This specialization uses a faster algorithm to compute exp2(x) for floats 1267 // in [-0.5;0.5] with a relative accuracy of 1 ulp. 1268 // The minimax polynomial used was calculated using the Sollya tool. 1269 // See sollya.org. 1270 template <> 1271 struct fast_accurate_exp2<float> { 1272 template <typename Packet> 1273 EIGEN_STRONG_INLINE 1274 Packet operator()(const Packet& x) { 1275 // This function approximates exp2(x) by a degree 6 polynomial of the form 1276 // Q(x) = 1 + x * (C + x * P(x)), where the degree 4 polynomial P(x) is evaluated in 1277 // single precision, and the remaining steps are evaluated with extra precision using 1278 // double word arithmetic. C is an extra precise constant stored as a double word. 1279 // 1280 // The polynomial coefficients were calculated using Sollya commands: 1281 // > n = 6; 1282 // > f = 2^x; 1283 // > interval = [-0.5;0.5]; 1284 // > p = fpminimax(f,n,[|1,double,single...|],interval,relative,floating); 1285 1286 const Packet p4 = pset1<Packet>(1.539513905e-4f); 1287 const Packet p3 = pset1<Packet>(1.340007293e-3f); 1288 const Packet p2 = pset1<Packet>(9.618283249e-3f); 1289 const Packet p1 = pset1<Packet>(5.550328270e-2f); 1290 const Packet p0 = pset1<Packet>(0.2402264923f); 1291 1292 const Packet C_hi = pset1<Packet>(0.6931471825f); 1293 const Packet C_lo = pset1<Packet>(2.36836577e-08f); 1294 const Packet one = pset1<Packet>(1.0f); 1295 1296 // Evaluate P(x) in working precision. 1297 // We evaluate even and odd parts of the polynomial separately 1298 // to gain some instruction level parallelism. 
1299 Packet x2 = pmul(x,x); 1300 Packet p_even = pmadd(p4, x2, p2); 1301 Packet p_odd = pmadd(p3, x2, p1); 1302 p_even = pmadd(p_even, x2, p0); 1303 Packet p = pmadd(p_odd, x, p_even); 1304 1305 // Evaluate the remaining terms of Q(x) with extra precision using 1306 // double word arithmetic. 1307 Packet p_hi, p_lo; 1308 // x * p(x) 1309 twoprod(p, x, p_hi, p_lo); 1310 // C + x * p(x) 1311 Packet q1_hi, q1_lo; 1312 twosum(p_hi, p_lo, C_hi, C_lo, q1_hi, q1_lo); 1313 // x * (C + x * p(x)) 1314 Packet q2_hi, q2_lo; 1315 twoprod(q1_hi, q1_lo, x, q2_hi, q2_lo); 1316 // 1 + x * (C + x * p(x)) 1317 Packet q3_hi, q3_lo; 1318 // Since |q2_hi| <= sqrt(2)-1 < 1, we can use fast_twosum 1319 // for adding it to unity here. 1320 fast_twosum(one, q2_hi, q3_hi, q3_lo); 1321 return padd(q3_hi, padd(q2_lo, q3_lo)); 1322 } 1323 }; 1324 1325 // in [-0.5;0.5] with a relative accuracy of 1 ulp. 1326 // The minimax polynomial used was calculated using the Sollya tool. 1327 // See sollya.org. 1328 template <> 1329 struct fast_accurate_exp2<double> { 1330 template <typename Packet> 1331 EIGEN_STRONG_INLINE 1332 Packet operator()(const Packet& x) { 1333 // This function approximates exp2(x) by a degree 10 polynomial of the form 1334 // Q(x) = 1 + x * (C + x * P(x)), where the degree 8 polynomial P(x) is evaluated in 1335 // single precision, and the remaining steps are evaluated with extra precision using 1336 // double word arithmetic. C is an extra precise constant stored as a double word. 
1337 // 1338 // The polynomial coefficients were calculated using Sollya commands: 1339 // > n = 11; 1340 // > f = 2^x; 1341 // > interval = [-0.5;0.5]; 1342 // > p = fpminimax(f,n,[|1,DD,double...|],interval,relative,floating); 1343 1344 const Packet p9 = pset1<Packet>(4.431642109085495276e-10); 1345 const Packet p8 = pset1<Packet>(7.073829923303358410e-9); 1346 const Packet p7 = pset1<Packet>(1.017822306737031311e-7); 1347 const Packet p6 = pset1<Packet>(1.321543498017646657e-6); 1348 const Packet p5 = pset1<Packet>(1.525273342728892877e-5); 1349 const Packet p4 = pset1<Packet>(1.540353045780084423e-4); 1350 const Packet p3 = pset1<Packet>(1.333355814685869807e-3); 1351 const Packet p2 = pset1<Packet>(9.618129107593478832e-3); 1352 const Packet p1 = pset1<Packet>(5.550410866481961247e-2); 1353 const Packet p0 = pset1<Packet>(0.240226506959101332); 1354 const Packet C_hi = pset1<Packet>(0.693147180559945286); 1355 const Packet C_lo = pset1<Packet>(4.81927865669806721e-17); 1356 const Packet one = pset1<Packet>(1.0); 1357 1358 // Evaluate P(x) in working precision. 1359 // We evaluate even and odd parts of the polynomial separately 1360 // to gain some instruction level parallelism. 1361 Packet x2 = pmul(x,x); 1362 Packet p_even = pmadd(p8, x2, p6); 1363 Packet p_odd = pmadd(p9, x2, p7); 1364 p_even = pmadd(p_even, x2, p4); 1365 p_odd = pmadd(p_odd, x2, p5); 1366 p_even = pmadd(p_even, x2, p2); 1367 p_odd = pmadd(p_odd, x2, p3); 1368 p_even = pmadd(p_even, x2, p0); 1369 p_odd = pmadd(p_odd, x2, p1); 1370 Packet p = pmadd(p_odd, x, p_even); 1371 1372 // Evaluate the remaining terms of Q(x) with extra precision using 1373 // double word arithmetic. 
1374 Packet p_hi, p_lo; 1375 // x * p(x) 1376 twoprod(p, x, p_hi, p_lo); 1377 // C + x * p(x) 1378 Packet q1_hi, q1_lo; 1379 twosum(p_hi, p_lo, C_hi, C_lo, q1_hi, q1_lo); 1380 // x * (C + x * p(x)) 1381 Packet q2_hi, q2_lo; 1382 twoprod(q1_hi, q1_lo, x, q2_hi, q2_lo); 1383 // 1 + x * (C + x * p(x)) 1384 Packet q3_hi, q3_lo; 1385 // Since |q2_hi| <= sqrt(2)-1 < 1, we can use fast_twosum 1386 // for adding it to unity here. 1387 fast_twosum(one, q2_hi, q3_hi, q3_lo); 1388 return padd(q3_hi, padd(q2_lo, q3_lo)); 1389 } 1390 }; 1391 1392 // This function implements the non-trivial case of pow(x,y) where x is 1393 // positive and y is (possibly) non-integer. 1394 // Formally, pow(x,y) = exp2(y * log2(x)), where exp2(x) is shorthand for 2^x. 1395 // TODO(rmlarsen): We should probably add this as a packet up 'ppow', to make it 1396 // easier to specialize or turn off for specific types and/or backends.x 1397 template <typename Packet> 1398 EIGEN_STRONG_INLINE Packet generic_pow_impl(const Packet& x, const Packet& y) { 1399 typedef typename unpacket_traits<Packet>::type Scalar; 1400 // Split x into exponent e_x and mantissa m_x. 1401 Packet e_x; 1402 Packet m_x = pfrexp(x, e_x); 1403 1404 // Adjust m_x to lie in [1/sqrt(2):sqrt(2)] to minimize absolute error in log2(m_x). 1405 EIGEN_CONSTEXPR Scalar sqrt_half = Scalar(0.70710678118654752440); 1406 const Packet m_x_scale_mask = pcmp_lt(m_x, pset1<Packet>(sqrt_half)); 1407 m_x = pselect(m_x_scale_mask, pmul(pset1<Packet>(Scalar(2)), m_x), m_x); 1408 e_x = pselect(m_x_scale_mask, psub(e_x, pset1<Packet>(Scalar(1))), e_x); 1409 1410 // Compute log2(m_x) with 6 extra bits of accuracy. 1411 Packet rx_hi, rx_lo; 1412 accurate_log2<Scalar>()(m_x, rx_hi, rx_lo); 1413 1414 // Compute the two terms {y * e_x, y * r_x} in f = y * log2(x) with doubled 1415 // precision using double word arithmetic. 
1416 Packet f1_hi, f1_lo, f2_hi, f2_lo; 1417 twoprod(e_x, y, f1_hi, f1_lo); 1418 twoprod(rx_hi, rx_lo, y, f2_hi, f2_lo); 1419 // Sum the two terms in f using double word arithmetic. We know 1420 // that |e_x| > |log2(m_x)|, except for the case where e_x==0. 1421 // This means that we can use fast_twosum(f1,f2). 1422 // In the case e_x == 0, e_x * y = f1 = 0, so we don't lose any 1423 // accuracy by violating the assumption of fast_twosum, because 1424 // it's a no-op. 1425 Packet f_hi, f_lo; 1426 fast_twosum(f1_hi, f1_lo, f2_hi, f2_lo, f_hi, f_lo); 1427 1428 // Split f into integer and fractional parts. 1429 Packet n_z, r_z; 1430 absolute_split(f_hi, n_z, r_z); 1431 r_z = padd(r_z, f_lo); 1432 Packet n_r; 1433 absolute_split(r_z, n_r, r_z); 1434 n_z = padd(n_z, n_r); 1435 1436 // We now have an accurate split of f = n_z + r_z and can compute 1437 // x^y = 2**{n_z + r_z) = exp2(r_z) * 2**{n_z}. 1438 // Since r_z is in [-0.5;0.5], we compute the first factor to high accuracy 1439 // using a specialized algorithm. Multiplication by the second factor can 1440 // be done exactly using pldexp(), since it is an integer power of 2. 1441 const Packet e_r = fast_accurate_exp2<Scalar>()(r_z); 1442 return pldexp(e_r, n_z); 1443 } 1444 1445 // Generic implementation of pow(x,y). 1446 template<typename Packet> 1447 EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS 1448 EIGEN_UNUSED 1449 Packet generic_pow(const Packet& x, const Packet& y) { 1450 typedef typename unpacket_traits<Packet>::type Scalar; 1451 1452 const Packet cst_pos_inf = pset1<Packet>(NumTraits<Scalar>::infinity()); 1453 const Packet cst_zero = pset1<Packet>(Scalar(0)); 1454 const Packet cst_one = pset1<Packet>(Scalar(1)); 1455 const Packet cst_nan = pset1<Packet>(NumTraits<Scalar>::quiet_NaN()); 1456 1457 const Packet abs_x = pabs(x); 1458 // Predicates for sign and magnitude of x. 
1459 const Packet x_is_zero = pcmp_eq(x, cst_zero); 1460 const Packet x_is_neg = pcmp_lt(x, cst_zero); 1461 const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf); 1462 const Packet abs_x_is_one = pcmp_eq(abs_x, cst_one); 1463 const Packet abs_x_is_gt_one = pcmp_lt(cst_one, abs_x); 1464 const Packet abs_x_is_lt_one = pcmp_lt(abs_x, cst_one); 1465 const Packet x_is_one = pandnot(abs_x_is_one, x_is_neg); 1466 const Packet x_is_neg_one = pand(abs_x_is_one, x_is_neg); 1467 const Packet x_is_nan = pandnot(ptrue(x), pcmp_eq(x, x)); 1468 1469 // Predicates for sign and magnitude of y. 1470 const Packet y_is_one = pcmp_eq(y, cst_one); 1471 const Packet y_is_zero = pcmp_eq(y, cst_zero); 1472 const Packet y_is_neg = pcmp_lt(y, cst_zero); 1473 const Packet y_is_pos = pandnot(ptrue(y), por(y_is_zero, y_is_neg)); 1474 const Packet y_is_nan = pandnot(ptrue(y), pcmp_eq(y, y)); 1475 const Packet abs_y_is_inf = pcmp_eq(pabs(y), cst_pos_inf); 1476 EIGEN_CONSTEXPR Scalar huge_exponent = 1477 (NumTraits<Scalar>::max_exponent() * Scalar(EIGEN_LN2)) / 1478 NumTraits<Scalar>::epsilon(); 1479 const Packet abs_y_is_huge = pcmp_le(pset1<Packet>(huge_exponent), pabs(y)); 1480 1481 // Predicates for whether y is integer and/or even. 
1482 const Packet y_is_int = pcmp_eq(pfloor(y), y); 1483 const Packet y_div_2 = pmul(y, pset1<Packet>(Scalar(0.5))); 1484 const Packet y_is_even = pcmp_eq(pround(y_div_2), y_div_2); 1485 1486 // Predicates encoding special cases for the value of pow(x,y) 1487 const Packet invalid_negative_x = pandnot(pandnot(pandnot(x_is_neg, abs_x_is_inf), 1488 y_is_int), 1489 abs_y_is_inf); 1490 const Packet pow_is_one = por(por(x_is_one, y_is_zero), 1491 pand(x_is_neg_one, 1492 por(abs_y_is_inf, pandnot(y_is_even, invalid_negative_x)))); 1493 const Packet pow_is_nan = por(invalid_negative_x, por(x_is_nan, y_is_nan)); 1494 const Packet pow_is_zero = por(por(por(pand(x_is_zero, y_is_pos), 1495 pand(abs_x_is_inf, y_is_neg)), 1496 pand(pand(abs_x_is_lt_one, abs_y_is_huge), 1497 y_is_pos)), 1498 pand(pand(abs_x_is_gt_one, abs_y_is_huge), 1499 y_is_neg)); 1500 const Packet pow_is_inf = por(por(por(pand(x_is_zero, y_is_neg), 1501 pand(abs_x_is_inf, y_is_pos)), 1502 pand(pand(abs_x_is_lt_one, abs_y_is_huge), 1503 y_is_neg)), 1504 pand(pand(abs_x_is_gt_one, abs_y_is_huge), 1505 y_is_pos)); 1506 1507 // General computation of pow(x,y) for positive x or negative x and integer y. 
1508 const Packet negate_pow_abs = pandnot(x_is_neg, y_is_even); 1509 const Packet pow_abs = generic_pow_impl(abs_x, y); 1510 return pselect(y_is_one, x, 1511 pselect(pow_is_one, cst_one, 1512 pselect(pow_is_nan, cst_nan, 1513 pselect(pow_is_inf, cst_pos_inf, 1514 pselect(pow_is_zero, cst_zero, 1515 pselect(negate_pow_abs, pnegate(pow_abs), pow_abs)))))); 1516 } 1517 1518 1519 1520 /* polevl (modified for Eigen) 1521 * 1522 * Evaluate polynomial 1523 * 1524 * 1525 * 1526 * SYNOPSIS: 1527 * 1528 * int N; 1529 * Scalar x, y, coef[N+1]; 1530 * 1531 * y = polevl<decltype(x), N>( x, coef); 1532 * 1533 * 1534 * 1535 * DESCRIPTION: 1536 * 1537 * Evaluates polynomial of degree N: 1538 * 1539 * 2 N 1540 * y = C + C x + C x +...+ C x 1541 * 0 1 2 N 1542 * 1543 * Coefficients are stored in reverse order: 1544 * 1545 * coef[0] = C , ..., coef[N] = C . 1546 * N 0 1547 * 1548 * The function p1evl() assumes that coef[N] = 1.0 and is 1549 * omitted from the array. Its calling arguments are 1550 * otherwise the same as polevl(). 1551 * 1552 * 1553 * The Eigen implementation is templatized. For best speed, store 1554 * coef as a const array (constexpr), e.g. 
1555 * 1556 * const double coef[] = {1.0, 2.0, 3.0, ...}; 1557 * 1558 */ 1559 template <typename Packet, int N> 1560 struct ppolevl { 1561 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const typename unpacket_traits<Packet>::type coeff[]) { 1562 EIGEN_STATIC_ASSERT((N > 0), YOU_MADE_A_PROGRAMMING_MISTAKE); 1563 return pmadd(ppolevl<Packet, N-1>::run(x, coeff), x, pset1<Packet>(coeff[N])); 1564 } 1565 }; 1566 1567 template <typename Packet> 1568 struct ppolevl<Packet, 0> { 1569 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const typename unpacket_traits<Packet>::type coeff[]) { 1570 EIGEN_UNUSED_VARIABLE(x); 1571 return pset1<Packet>(coeff[0]); 1572 } 1573 }; 1574 1575 /* chbevl (modified for Eigen) 1576 * 1577 * Evaluate Chebyshev series 1578 * 1579 * 1580 * 1581 * SYNOPSIS: 1582 * 1583 * int N; 1584 * Scalar x, y, coef[N], chebevl(); 1585 * 1586 * y = chbevl( x, coef, N ); 1587 * 1588 * 1589 * 1590 * DESCRIPTION: 1591 * 1592 * Evaluates the series 1593 * 1594 * N-1 1595 * - ' 1596 * y = > coef[i] T (x/2) 1597 * - i 1598 * i=0 1599 * 1600 * of Chebyshev polynomials Ti at argument x/2. 1601 * 1602 * Coefficients are stored in reverse order, i.e. the zero 1603 * order term is last in the array. Note N is the number of 1604 * coefficients, not the order. 1605 * 1606 * If coefficients are for the interval a to b, x must 1607 * have been transformed to x -> 2(2x - b - a)/(b-a) before 1608 * entering the routine. This maps x from (a, b) to (-1, 1), 1609 * over which the Chebyshev polynomials are defined. 1610 * 1611 * If the coefficients are for the inverted interval, in 1612 * which (a, b) is mapped to (1/b, 1/a), the transformation 1613 * required is x -> 2(2ab/x - b - a)/(b-a). If b is infinity, 1614 * this becomes x -> 4a/x - 1. 
1615 * 1616 * 1617 * 1618 * SPEED: 1619 * 1620 * Taking advantage of the recurrence properties of the 1621 * Chebyshev polynomials, the routine requires one more 1622 * addition per loop than evaluating a nested polynomial of 1623 * the same degree. 1624 * 1625 */ 1626 1627 template <typename Packet, int N> 1628 struct pchebevl { 1629 EIGEN_DEVICE_FUNC 1630 static EIGEN_STRONG_INLINE Packet run(Packet x, const typename unpacket_traits<Packet>::type coef[]) { 1631 typedef typename unpacket_traits<Packet>::type Scalar; 1632 Packet b0 = pset1<Packet>(coef[0]); 1633 Packet b1 = pset1<Packet>(static_cast<Scalar>(0.f)); 1634 Packet b2; 1635 1636 for (int i = 1; i < N; i++) { 1637 b2 = b1; 1638 b1 = b0; 1639 b0 = psub(pmadd(x, b1, pset1<Packet>(coef[i])), b2); 1640 } 1641 1642 return pmul(pset1<Packet>(static_cast<Scalar>(0.5f)), psub(b0, b2)); 1643 } 1644 }; 1645 1646 } // end namespace internal 1647 } // end namespace Eigen 1648 1649 #endif // EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_H 1650