"""
Low-level functions for arbitrary-precision floating-point arithmetic.
"""

__docformat__ = 'plaintext'

import math

from bisect import bisect

import sys

# Importing random is slow
#from random import getrandbits
getrandbits = None

from .backend import (MPZ, MPZ_TYPE, MPZ_ZERO, MPZ_ONE, MPZ_TWO, MPZ_FIVE,
    BACKEND, STRICT, HASH_MODULUS, HASH_BITS, gmpy, sage, sage_utils)

from .libintmath import (giant_steps,
    trailtable, bctable, lshift, rshift, bitcount, trailing,
    sqrt_fixed, numeral, isqrt, isqrt_fast, sqrtrem,
    bin_to_radix)

# We don't pickle tuples directly for the following reasons:
#   1: pickle uses str() for ints, which is inefficient when they are large
#   2: pickle doesn't work for gmpy mpzs
# Both problems are solved by using hex()

if BACKEND == 'sage':
    def to_pickable(x):
        """Return the raw mpf x with its mantissa replaced by hex(man).

        In this (sage) variant the string returned by hex() is stored
        as-is."""
        sign, man, exp, bc = x
        return sign, hex(man), exp, bc
else:
    def to_pickable(x):
        """Return the raw mpf x with its mantissa replaced by a hex string.

        The first two characters of hex(man) (the '0x' prefix for a
        non-negative mantissa) are dropped."""
        sign, man, exp, bc = x
        return sign, hex(man)[2:], exp, bc

def from_pickable(x):
    """Inverse of to_pickable: rebuild the raw mpf tuple, parsing the
    mantissa back from its base-16 string with MPZ(man, 16)."""
    sign, man, exp, bc = x
    return (sign, MPZ(man, 16), exp, bc)

class ComplexResult(ValueError):
    # NOTE(review): presumably raised when a real-valued operation would
    # need a complex result; it is not raised anywhere in this part of the
    # file -- confirm against callers.
    pass

# Fall back to an identity function where the builtin intern() does not
# exist (the lookup raises NameError).
try:
    intern
except NameError:
    intern = lambda x: x

# All supported rounding modes
round_nearest = intern('n')
round_floor = intern('f')
round_ceiling = intern('c')
round_up = intern('u')
round_down = intern('d')
round_fast = round_down

def prec_to_dps(n):
    """Return number of accurate decimals that can be represented
    with a precision of n bits."""
    # 3.3219280948873626 is log2(10), i.e. bits per decimal digit
    return max(1, int(round(int(n)/3.3219280948873626)-1))

def dps_to_prec(n):
    """Return the number of bits required to represent n decimals
    accurately."""
    return max(1, int(round((int(n)+1)*3.3219280948873626)))

def repr_dps(n):
    """Return the number of decimal digits required to represent
    a number with n-bit precision so that it can be uniquely
    reconstructed from the representation."""
    dps = prec_to_dps(n)
    # Special-cased value: a dps of 15 maps to 17 digits rather than dps + 3
    if dps == 15:
        return 17
    return dps + 3

#----------------------------------------------------------------------------#
#                    Some commonly needed float values                       #
#----------------------------------------------------------------------------#

# Regular number format:
# (-1)**sign * mantissa * 2**exponent, plus bitcount of mantissa
fzero = (0, MPZ_ZERO, 0, 0)
fnzero = (1, MPZ_ZERO, 0, 0)
fone = (0, MPZ_ONE, 0, 1)
fnone = (1, MPZ_ONE, 0, 1)
ftwo = (0, MPZ_ONE, 1, 1)
ften = (0, MPZ_FIVE, 1, 3)
fhalf = (0, MPZ_ONE, -1, 1)

# Arbitrary encoding for special numbers: zero mantissa, nonzero exponent
fnan = (0, MPZ_ZERO, -123, -1)
finf = (0, MPZ_ZERO, -456, -2)
fninf = (1, MPZ_ZERO, -789, -3)

# Was 1e1000; this is broken in Python 2.4
math_float_inf = 1e300 * 1e300


#----------------------------------------------------------------------------#
#                                  Rounding                                  #
#----------------------------------------------------------------------------#

# This function can be used to round a mantissa generally. However,
# we will try to do most rounding inline for efficiency.
def round_int(x, n, rnd):
    """Return the integer x shifted right by n bits (x / 2**n), rounded
    according to the rounding mode rnd.

    With round_nearest, exact ties go to the even result. If rnd is not
    one of the five rounding modes, no branch matches and None is
    returned implicitly."""
    if rnd == round_nearest:
        if x >= 0:
            # t keeps one extra bit below the result; its lowest bit is the
            # first discarded bit.
            t = x >> (n-1)
            # Round up if the first discarded bit is set and either the
            # result would be odd (t & 2) or further discarded bits are set.
            if t & 1 and ((t & 2) or (x & h_mask[n<300][n])):
                return (t>>1)+1
            else:
                return t>>1
        else:
            return -round_int(-x, n, rnd)
    if rnd == round_floor:
        return x >> n
    if rnd == round_ceiling:
        return -((-x) >> n)
    if rnd == round_down:
        if x >= 0:
            return x >> n
        return -((-x) >> n)
    if rnd == round_up:
        if x >= 0:
            return -((-x) >> n)
        return x >> n

# These masks are used to pick out segments of numbers to determine
# which direction to round when rounding to nearest.
class h_mask_big:
    # Computes the mask (1 << (n-1)) - 1 on demand, i.e. the low n-1 bits.
    def __getitem__(self, n):
        return (MPZ_ONE<<(n-1))-1

# Precomputed masks for n < 300; h_mask[n<300][n] selects the precomputed
# list when n < 300 (index 1) and the on-demand object otherwise (index 0).
h_mask_small = [0]+[((MPZ_ONE<<(_-1))-1) for _ in range(1, 300)]
h_mask = [h_mask_big(), h_mask_small]

# The >> operator rounds to floor.
# shifts_down[rnd][sign]
# tells whether this is the right direction to use, or if the
# number should be negated before shifting
shifts_down = {round_floor:(1,0), round_ceiling:(0,1),
    round_down:(1,1), round_up:(0,0)}


#----------------------------------------------------------------------------#
#                          Normalization of raw mpfs                         #
#----------------------------------------------------------------------------#

# This function is called almost every time an mpf is created.
# It has been optimized accordingly.

def _normalize(sign, man, exp, bc, prec, rnd):
    """
    Create a raw mpf tuple with value (-1)**sign * man * 2**exp and
    normalized mantissa. The mantissa is rounded in the specified
    direction if its size exceeds the precision. Trailing zero bits
    are also stripped from the mantissa to ensure that the
    representation is canonical.

    Conditions on the input:
    * The input must represent a regular (finite) number
    * The sign bit must be 0 or 1
    * The mantissa must be positive
    * The exponent must be an integer
    * The bitcount must be exact

    If these conditions are not met, use from_man_exp, mpf_pos, or any
    of the conversion functions to create normalized raw mpf tuples.
    """
    if not man:
        return fzero
    # Cut mantissa down to size if larger than target precision
    n = bc - prec
    if n > 0:
        if rnd == round_nearest:
            # Keep one extra bit; round half to even
            t = man >> (n-1)
            if t & 1 and ((t & 2) or (man & h_mask[n<300][n])):
                man = (t>>1)+1
            else:
                man = t>>1
        elif shifts_down[rnd][sign]:
            man >>= n
        else:
            man = -((-man)>>n)
        exp += n
        bc = prec
    # Strip trailing bits
    if not man & 1:
        t = trailtable[int(man & 255)]
        if not t:
            # Low byte is all zero: drop whole bytes first
            while not man & 255:
                man >>= 8
                exp += 8
                bc -= 8
            t = trailtable[int(man & 255)]
        man >>= t
        exp += t
        bc -= t
    # Bit count can be wrong if the input mantissa was 1 less than
    # a power of 2 and got rounded up, thereby adding an extra bit.
    # With trailing bits removed, all powers of two have mantissa 1,
    # so this is easy to check for.
    if man == 1:
        bc = 1
    return sign, man, exp, bc

def _normalize1(sign, man, exp, bc, prec, rnd):
    """same as normalize, but with the added condition that
       man is odd or zero
    """
    if not man:
        return fzero
    # An odd mantissa that already fits needs no rounding and no stripping
    if bc <= prec:
        return sign, man, exp, bc
    n = bc - prec
    if rnd == round_nearest:
        t = man >> (n-1)
        if t & 1 and ((t & 2) or (man & h_mask[n<300][n])):
            man = (t>>1)+1
        else:
            man = t>>1
    elif shifts_down[rnd][sign]:
        man >>= n
    else:
        man = -((-man)>>n)
    exp += n
    bc = prec
    # Strip trailing bits
    if not man & 1:
        t = trailtable[int(man & 255)]
        if not t:
            while not man & 255:
                man >>= 8
                exp += 8
                bc -= 8
            t = trailtable[int(man & 255)]
        man >>= t
        exp += t
        bc -= t
    # Bit count can be wrong if the input mantissa was 1 less than
    # a power of 2 and got rounded up, thereby adding an extra bit.
    # With trailing bits removed, all powers of two have mantissa 1,
    # so this is easy to check for.
    if man == 1:
        bc = 1
    return sign, man, exp, bc

# Integer types accepted for exponents/bitcounts; 'long' only exists on
# Python 2 (NameError otherwise).
try:
    _exp_types = (int, long)
except NameError:
    _exp_types = (int,)

def strict_normalize(sign, man, exp, bc, prec, rnd):
    """Additional checks on the components of an mpf. Enable tests by setting
       the environment variable MPMATH_STRICT to Y."""
    assert type(man) == MPZ_TYPE
    assert type(bc) in _exp_types
    assert type(exp) in _exp_types
    assert bc == bitcount(man)
    return _normalize(sign, man, exp, bc, prec, rnd)

def strict_normalize1(sign, man, exp, bc, prec, rnd):
    """Additional checks on the components of an mpf. Enable tests by setting
       the environment variable MPMATH_STRICT to Y."""
    assert type(man) == MPZ_TYPE
    assert type(bc) in _exp_types
    assert type(exp) in _exp_types
    assert bc == bitcount(man)
    assert (not man) or (man & 1)
    return _normalize1(sign, man, exp, bc, prec, rnd)

# Prefer backend-provided implementations when available
if BACKEND == 'gmpy' and '_mpmath_normalize' in dir(gmpy):
    _normalize = gmpy._mpmath_normalize
    _normalize1 = gmpy._mpmath_normalize

if BACKEND == 'sage':
    _normalize = _normalize1 = sage_utils.normalize

if STRICT:
    normalize = strict_normalize
    normalize1 = strict_normalize1
else:
    normalize = _normalize
    normalize1 = _normalize1

#----------------------------------------------------------------------------#
#                            Conversion functions                            #
#----------------------------------------------------------------------------#

def from_man_exp(man, exp, prec=None, rnd=round_fast):
    """Create raw mpf from (man, exp) pair. The mantissa may be signed.
    If no precision is specified, the mantissa is stored exactly."""
    man = MPZ(man)
    sign = 0
    if man < 0:
        sign = 1
        man = -man
    # Table lookup for small mantissas, general bitcount otherwise
    if man < 1024:
        bc = bctable[int(man)]
    else:
        bc = bitcount(man)
    if not prec:
        if not man:
            return fzero
        if not man & 1:
            # Fast path: exactly one trailing zero bit
            if man & 2:
                return (sign, man >> 1, exp + 1, bc - 1)
            t = trailtable[int(man & 255)]
            if not t:
                while not man & 255:
                    man >>= 8
                    exp += 8
                    bc -= 8
                t = trailtable[int(man & 255)]
            man >>= t
            exp += t
            bc -= t
        return (sign, man, exp, bc)
    return normalize(sign, man, exp, bc, prec, rnd)

# Exact raw mpfs for the integers -10..256, built with the pure-Python
# from_man_exp before any backend override below.
int_cache = dict((n, from_man_exp(n, 0)) for n in range(-10, 257))

if BACKEND == 'gmpy' and '_mpmath_create' in dir(gmpy):
    from_man_exp = gmpy._mpmath_create

if BACKEND == 'sage':
    from_man_exp = sage_utils.from_man_exp

def from_int(n, prec=0, rnd=round_fast):
    """Create a raw mpf from an integer. If no precision is specified,
    the mantissa is stored exactly."""
    if not prec:
        if n in int_cache:
            return int_cache[n]
    return from_man_exp(n, 0, prec, rnd)

def to_man_exp(s):
    """Return (man, exp) of a raw mpf. Raise an error if inf/nan."""
    sign, man, exp, bc = s
    if (not man) and exp:
        raise ValueError("mantissa and exponent are undefined for %s" % man)
    return man, exp

def to_int(s, rnd=None):
    """Convert a raw mpf to the nearest int. Rounding is done down by
    default (same as int(float) in Python), but can be changed. If the
    input is inf/nan, an exception is raised."""
    sign, man, exp, bc = s
    if (not man) and exp:
        raise ValueError("cannot convert inf or nan to int")
    # Non-negative exponent: the value is already an integer
    if exp >= 0:
        if sign:
            return (-man) << exp
        return man << exp
    # Make default rounding fast
    if not rnd:
        if sign:
            return -(man >> (-exp))
        else:
            return man >> (-exp)
    if sign:
        return round_int(-man, -exp, rnd)
    else:
        return round_int(man, -exp, rnd)

def mpf_round_int(s, rnd):
    """Round the raw mpf s to an integer-valued raw mpf using rnd.

    Special values (inf/nan) and numbers with a non-negative exponent
    are returned unchanged. For numbers with exp+bc < 1 only
    round_ceiling, round_floor and round_nearest are handled; any other
    mode raises NotImplementedError."""
    sign, man, exp, bc = s
    if (not man) and exp:
        return s
    if exp >= 0:
        return s
    mag = exp+bc
    if mag < 1:
        if rnd == round_ceiling:
            if sign: return fzero
            else:    return fone
        elif rnd == round_floor:
            if sign: return fnone
            else:    return fzero
        elif rnd == round_nearest:
            if mag < 0 or man == MPZ_ONE: return fzero
            elif sign: return fnone
            else:      return fone
        else:
            raise NotImplementedError
    # Keep only the mag bits that lie above the binary point
    return mpf_pos(s, min(bc, mag), rnd)

def mpf_floor(s, prec=0, rnd=round_fast):
    """Round s to an integer with round_floor; if prec is given, the
    result is additionally rounded to prec bits using rnd."""
    v = mpf_round_int(s, round_floor)
    if prec:
        v = mpf_pos(v, prec, rnd)
    return v

def mpf_ceil(s, prec=0, rnd=round_fast):
    """Round s to an integer with round_ceiling; if prec is given, the
    result is additionally rounded to prec bits using rnd."""
    v = mpf_round_int(s, round_ceiling)
    if prec:
        v = mpf_pos(v, prec, rnd)
    return v

def mpf_nint(s, prec=0, rnd=round_fast):
    """Round s to an integer with round_nearest; if prec is given, the
    result is additionally rounded to prec bits using rnd."""
    v = mpf_round_int(s, round_nearest)
    if prec:
        v = mpf_pos(v, prec, rnd)
    return v

def mpf_frac(s, prec=0, rnd=round_fast):
    """Return s - mpf_floor(s), computed with mpf_sub at the given
    precision and rounding."""
    return mpf_sub(s, mpf_floor(s), prec, rnd)

def from_float(x, prec=53, rnd=round_fast):
    """Create a raw mpf from a Python float, rounding if necessary.
    If prec >= 53, the result is guaranteed to represent exactly the
    same number as the input. If prec is not specified, use prec=53."""
    # frexp only raises an exception for nan on some platforms
    if x != x:
        return fnan
    # in Python2.5 math.frexp gives an exception for float infinity
    # in Python2.6 it returns (float infinity, 0)
    try:
        m, e = math.frexp(x)
    except Exception:
        # Catch ordinary errors only, so that KeyboardInterrupt and
        # SystemExit are not swallowed here.
        if x == math_float_inf: return finf
        if x == -math_float_inf: return fninf
        return fnan
    if x == math_float_inf: return finf
    if x == -math_float_inf: return fninf
    return from_man_exp(int(m*(1<<53)), e-53, prec, rnd)

def from_npfloat(x, prec=113, rnd=round_fast):
    """Create a raw mpf from a numpy float, rounding if necessary.
    If prec >= 113, the result is guaranteed to represent exactly the
    same number as the input. If prec is not specified, use prec=113."""
    y = float(x)
    if x == y: # ldexp overflows for float16
        return from_float(y, prec, rnd)
    # numpy is imported lazily, only when x is not exactly a Python float
    import numpy as np
    if np.isfinite(x):
        m, e = np.frexp(x)
        return from_man_exp(int(np.ldexp(m, 113)), int(e-113), prec, rnd)
    if np.isposinf(x): return finf
    if np.isneginf(x): return fninf
    return fnan

def from_Decimal(x, prec=None, rnd=round_fast):
    """Create a raw mpf from a Decimal, rounding if necessary.
    If prec is not specified, use the equivalent bit precision
    of the number of significant digits in x."""
    if x.is_nan(): return fnan
    if x.is_infinite(): return fninf if x.is_signed() else finf
    if prec is None:
        prec = int(len(x.as_tuple()[1])*3.3219280948873626)
    return from_str(str(x), prec, rnd)

def to_float(s, strict=False, rnd=round_fast):
    """
    Convert a raw mpf to a Python float. The result is exact if the
    bitcount of s is <= 53 and no underflow/overflow occurs.

    If the number is too large or too small to represent as a regular
    float, it will be converted to inf or 0.0. Setting strict=True
    forces an OverflowError to be raised instead.

    Warning: with a directed rounding mode, the correct nearest representable
    floating-point number in the specified direction might not be computed
    in case of overflow or (gradual) underflow.
    """
    sign, man, exp, bc = s
    if not man:
        if s == fzero: return 0.0
        if s == finf: return math_float_inf
        if s == fninf: return -math_float_inf
        # Anything else with a zero mantissa is mapped to a float nan
        return math_float_inf/math_float_inf
    if bc > 53:
        sign, man, exp, bc = normalize1(sign, man, exp, bc, 53, rnd)
    if sign:
        man = -man
    try:
        return math.ldexp(man, exp)
    except OverflowError:
        if strict:
            raise
        # Overflow to infinity
        if exp + bc > 0:
            if sign:
                return -math_float_inf
            else:
                return math_float_inf
        # Underflow to zero
        return 0.0

def from_rational(p, q, prec, rnd=round_fast):
    """Create a raw mpf from a rational number p/q, round if
    necessary."""
    return mpf_div(from_int(p), from_int(q), prec, rnd)

def to_rational(s):
    """Convert a raw mpf to a rational number. Return integers (p, q)
    such that s = p/q exactly.

    Raises ValueError if s is nan or an infinity, since none of them can
    be written as p/q."""
    sign, man, exp, bc = s
    if sign:
        man = -man
    # Special values have a zero mantissa and a nonzero exponent. Checking
    # only for nan would let +inf/-inf through and silently return a zero
    # numerator.
    if (not man) and exp:
        raise ValueError("cannot convert %s to a rational number" % man)
    if exp >= 0:
        return man * (1<<exp), 1
    else:
        return man, 1<<(-exp)

def to_fixed(s, prec):
    """Convert a raw mpf to a fixed-point big integer"""
    sign, man, exp, bc = s
    # The result is value * 2**prec; right shifts of the signed mantissa
    # round toward negative infinity.
    offset = exp + prec
    if sign:
        if offset >= 0: return (-man) << offset
        else:           return (-man) >> (-offset)
    else:
        if offset >= 0: return man << offset
        else:           return man >> (-offset)


##############################################################################
##############################################################################

#----------------------------------------------------------------------------#
# Arithmetic operations, etc.
#----------------------------------------------------------------------------#

def mpf_rand(prec):
    """Return a raw mpf chosen randomly from [0, 1), with prec bits
    in the mantissa."""
    global getrandbits
    # random is imported lazily on first use and its getrandbits cached
    if not getrandbits:
        import random
        getrandbits = random.getrandbits
    return from_man_exp(getrandbits(prec), -prec, prec, round_floor)

def mpf_eq(s, t):
    """Test equality of two raw mpfs. This is simply tuple comparison
    unless either number is nan, in which case the result is False."""
    if not s[1] or not t[1]:
        if s == fnan or t == fnan:
            return False
    return s == t

def mpf_hash(s):
    """Return a hash for the raw mpf s.

    On Python >= 3.2 this follows Python's modular hashing scheme for
    numeric types, using HASH_MODULUS and HASH_BITS. On older versions it
    uses the hash of the equivalent float, falling back to hashing the
    tuple itself when the float conversion overflows."""
    # Duplicate the new hash algorithm introduced in Python 3.2.
    if sys.version_info >= (3, 2):
        ssign, sman, sexp, sbc = s

        # Handle special numbers
        if not sman:
            if s == fnan: return sys.hash_info.nan
            if s == finf: return sys.hash_info.inf
            if s == fninf: return -sys.hash_info.inf
        h = sman % HASH_MODULUS
        # Reduce the exponent so the shift below stays small
        if sexp >= 0:
            sexp = sexp % HASH_BITS
        else:
            sexp = HASH_BITS - 1 - ((-1 - sexp) % HASH_BITS)
        h = (h << sexp) % HASH_MODULUS
        if ssign: h = -h
        # -1 is never a valid hash value in CPython; ints and floats use -2
        # in its place, so assign (not compare) to stay compatible.
        if h == -1: h = -2
        return int(h)
    else:
        try:
            # Try to be compatible with hash values for floats and ints
            return hash(to_float(s, strict=1))
        except OverflowError:
            # We must unfortunately sacrifice compatibility with ints here.
            # We could do hash(man << exp) when the exponent is positive, but
            # this would cause unreasonable inefficiency for large numbers.
            return hash(s)

def mpf_cmp(s, t):
    """Compare the raw mpfs s and t. Return -1 if s < t, 0 if s == t,
    and 1 if s > t. (Same convention as Python's cmp() function.)"""

    # In principle, a comparison amounts to determining the sign of s-t.
    # A full subtraction is relatively slow, however, so we first try to
    # look at the components.
    ssign, sman, sexp, sbc = s
    tsign, tman, texp, tbc = t

    # Handle zeros and special numbers
    if not sman or not tman:
        if s == fzero: return -mpf_sign(t)
        if t == fzero: return mpf_sign(s)
        if s == t: return 0
        # Follow same convention as Python's cmp for float nan
        if t == fnan: return 1
        if s == finf: return 1
        if t == fninf: return 1
        return -1
    # Different sides of zero
    if ssign != tsign:
        if not ssign: return 1
        return -1
    # This reduces to direct integer comparison
    if sexp == texp:
        if sman == tman:
            return 0
        if sman > tman:
            if ssign: return -1
            else:     return 1
        else:
            if ssign: return 1
            else:     return -1
    # Check position of the highest set bit in each number. If
    # different, there is certainly an inequality.
    a = sbc + sexp
    b = tbc + texp
    if ssign:
        if a < b: return 1
        if a > b: return -1
    else:
        if a < b: return -1
        if a > b: return 1

    # Both numbers have the same highest bit. Subtract to find
    # how the lower bits compare.
    delta = mpf_sub(s, t, 5, round_floor)
    if delta[0]:
        return -1
    return 1

def mpf_lt(s, t):
    """Return True if s < t; always False if either argument is nan."""
    if s == fnan or t == fnan:
        return False
    return mpf_cmp(s, t) < 0

def mpf_le(s, t):
    """Return True if s <= t; always False if either argument is nan."""
    if s == fnan or t == fnan:
        return False
    return mpf_cmp(s, t) <= 0

def mpf_gt(s, t):
    """Return True if s > t; always False if either argument is nan."""
    if s == fnan or t == fnan:
        return False
    return mpf_cmp(s, t) > 0

def mpf_ge(s, t):
    """Return True if s >= t; always False if either argument is nan."""
    if s == fnan or t == fnan:
        return False
    return mpf_cmp(s, t) >= 0

def mpf_min_max(seq):
    """Return (minimum, maximum) of a non-empty sequence of raw mpfs,
    compared with mpf_lt / mpf_gt."""
    # Local names chosen so the builtins min/max are not shadowed
    lo = hi = seq[0]
    for x in seq[1:]:
        if mpf_lt(x, lo): lo = x
        if mpf_gt(x, hi): hi = x
    return lo, hi

def mpf_pos(s, prec=0, rnd=round_fast):
    """Calculate 0+s for a raw mpf (i.e., just round s to the specified
    precision)."""
    if prec:
        sign, man, exp, bc = s
        if (not man) and exp:
            return s
        return normalize1(sign, man, exp, bc, prec, rnd)
    return s

def mpf_neg(s, prec=None, rnd=round_fast):
    """Negate a raw mpf (return -s), rounding the result to the
    specified precision. The prec argument can be omitted to do the
    operation exactly."""
    sign, man, exp, bc = s
    if not man:
        if exp:
            if s == finf: return fninf
            if s == fninf: return finf
        # Zero and nan are returned unchanged
        return s
    if not prec:
        return (1-sign, man, exp, bc)
    return normalize1(1-sign, man, exp, bc, prec, rnd)

def mpf_abs(s, prec=None, rnd=round_fast):
    """Return abs(s) of the raw mpf s, rounded to the specified
    precision. The prec argument can be omitted to generate an
    exact result."""
    sign, man, exp, bc = s
    if (not man) and exp:
        if s == fninf:
            return finf
        return s
    if not prec:
        if sign:
            return (0, man, exp, bc)
        return s
    return normalize1(0, man, exp, bc, prec, rnd)

def mpf_sign(s):
    """Return -1, 0, or 1 (as a Python int, not a raw mpf) depending on
    whether s is negative, zero, or positive. (Nan is taken to give 0.)"""
    sign, man, exp, bc = s
    if not man:
        if s == finf: return 1
        if s == fninf: return -1
        return 0
    return (-1) ** sign

def mpf_add(s, t, prec=0, rnd=round_fast, _sub=0):
    """
    Add the two raw mpf values s and t.

    With prec=0, no rounding is performed. Note that this can
    produce a very large mantissa (potentially too large to fit
    in memory) if exponents are far apart.
    """
    ssign, sman, sexp, sbc = s
    tsign, tman, texp, tbc = t
    # _sub=1 flips the sign of t, turning the addition into s - t
    tsign ^= _sub
    # Standard case: two nonzero, regular numbers
    if sman and tman:
        offset = sexp - texp
        if offset:
            if offset > 0:
                # Outside precision range; only need to perturb
                if offset > 100 and prec:
                    delta = sbc + sexp - tbc - texp
                    if delta > prec + 4:
                        offset = prec + 4
                        sman <<= offset
                        if tsign == ssign: sman += 1
                        else:              sman -= 1
                        return normalize1(ssign, sman, sexp-offset,
                            bitcount(sman), prec, rnd)
                # Add
                if ssign == tsign:
                    man = tman + (sman << offset)
                # Subtract
                else:
                    if ssign: man = tman - (sman << offset)
                    else:     man = (sman << offset) - tman
                    if man >= 0:
                        ssign = 0
                    else:
                        man = -man
                        ssign = 1
                bc = bitcount(man)
                return normalize1(ssign, man, texp, bc, prec or bc, rnd)
            elif offset < 0:
                # Outside precision range; only need to perturb
                if offset < -100 and prec:
                    delta = tbc + texp - sbc - sexp
                    if delta > prec + 4:
                        offset = prec + 4
                        tman <<= offset
                        if ssign == tsign: tman += 1
                        else:              tman -= 1
                        return normalize1(tsign, tman, texp-offset,
                            bitcount(tman), prec, rnd)
                # Add
                if ssign == tsign:
                    man = sman + (tman << -offset)
                # Subtract
                else:
                    if tsign: man = sman - (tman << -offset)
                    else:     man = (tman << -offset) - sman
                    if man >= 0:
                        ssign = 0
                    else:
                        man = -man
                        ssign = 1
                bc = bitcount(man)
                return normalize1(ssign, man, sexp, bc, prec or bc, rnd)
        # Equal exponents; no shifting necessary
        if ssign == tsign:
            man = tman + sman
        else:
            if ssign: man = tman - sman
            else:     man = sman - tman
            if man >= 0:
                ssign = 0
            else:
                man = -man
                ssign = 1
        bc = bitcount(man)
        return normalize(ssign, man, texp, bc, prec or bc, rnd)
    # Handle zeros and special numbers
    if _sub:
        t = mpf_neg(t)
    if not sman:
        if sexp:
            # s is inf/nan: it wins unless t is a different special value
            if s == t or tman or not texp:
                return s
            return fnan
        if tman:
            return normalize1(tsign, tman, texp, tbc, prec or tbc, rnd)
        return t
    if texp:
        return t
    if sman:
        return normalize1(ssign, sman, sexp, sbc, prec or sbc, rnd)
    return s

def mpf_sub(s, t, prec=0, rnd=round_fast):
    """Return the difference of two raw mpfs, s-t. This function is
    simply a wrapper of mpf_add that changes the sign of t."""
    return mpf_add(s, t, prec, rnd, 1)

def mpf_sum(xs, prec=0, rnd=round_fast, absolute=False):
    """
    Sum a list of mpf values efficiently and accurately
    (typically no temporary roundoff occurs). If prec=0,
    the final result will not be rounded either.

    There may be roundoff error or cancellation if extremely
    large exponent differences occur.

    With absolute=True, sums the absolute values.
    """
    # Running sum is kept as the signed integer man times 2**exp
    man = 0
    exp = 0
    max_extra_prec = prec*2 or 1000000  # XXX
    special = None
    for x in xs:
        xsign, xman, xexp, xbc = x
        if xman:
            if xsign and not absolute:
                xman = -xman
            delta = xexp - exp
            if xexp >= exp:
                # x much larger than existing sum?
                # first: quick test
                if (delta > max_extra_prec) and \
                    ((not man) or delta-bitcount(abs(man)) > max_extra_prec):
                    man = xman
                    exp = xexp
                else:
                    man += (xman << delta)
            else:
                delta = -delta
                # x much smaller than existing sum?
                if delta-xbc > max_extra_prec:
                    if not man:
                        man, exp = xman, xexp
                else:
                    man = (man << delta) + xman
                    exp = xexp
        elif xexp:
            # Special values (inf/nan) are accumulated separately
            if absolute:
                x = mpf_abs(x)
            special = mpf_add(special or fzero, x, 1)
    # Will be inf or nan
    if special:
        return special
    return from_man_exp(man, exp, prec, rnd)

def gmpy_mpf_mul(s, t, prec=0, rnd=round_fast):
    """Multiply two raw mpfs"""
    ssign, sman, sexp, sbc = s
    tsign, tman, texp, tbc = t
    sign = ssign ^ tsign
    man = sman*tman
    if man:
        bc = bitcount(man)
        if prec:
            return normalize1(sign, man, sexp+texp, bc, prec, rnd)
        else:
            return (sign, man, sexp+texp, bc)
    # At least one operand is zero, inf or nan
    s_special = (not sman) and sexp
    t_special = (not tman) and texp
    if not s_special and not t_special:
        return fzero
    if fnan in (s, t): return fnan
    if (not tman) and texp: s, t = t, s
    if t == fzero: return fnan
    return {1:finf, -1:fninf}[mpf_sign(s) * mpf_sign(t)]

def gmpy_mpf_mul_int(s, n, prec, rnd=round_fast):
    """Multiply by a Python integer."""
    sign, man, exp, bc = s
    if not man:
        return mpf_mul(s, from_int(n), prec, rnd)
    if not n:
        return fzero
    if n < 0:
        sign ^= 1
        n = -n
    man *= n
    return normalize(sign, man, exp, bitcount(man), prec, rnd)

def python_mpf_mul(s, t, prec=0, rnd=round_fast):
    """Multiply two raw mpfs"""
    ssign, sman, sexp, sbc = s
    tsign, tman, texp, tbc = t
    sign = ssign ^ tsign
    man = sman*tman
    if man:
        # The product has sbc+tbc-1 or sbc+tbc bits; test the top bit
        # instead of recounting.
        bc = sbc + tbc - 1
        bc += int(man>>bc)
        if prec:
            return normalize1(sign, man, sexp+texp, bc, prec, rnd)
        else:
            return (sign, man, sexp+texp, bc)
    # At least one operand is zero, inf or nan
    s_special = (not sman) and sexp
    t_special = (not tman) and texp
    if not s_special and not t_special:
        return fzero
    if fnan in (s, t): return fnan
    if (not tman) and texp: s, t = t, s
    if t == fzero: return fnan
    return {1:finf, -1:fninf}[mpf_sign(s) * mpf_sign(t)]

def python_mpf_mul_int(s, n, prec, rnd=round_fast):
    """Multiply by a Python integer."""
    sign, man, exp, bc = s
    if not man:
        return mpf_mul(s, from_int(n), prec, rnd)
    if not n:
        return fzero
    if n < 0:
        sign ^= 1
        n = -n
    man *= n
    # Generally n will be small
    if n < 1024:
        bc += bctable[int(n)] - 1
    else:
        bc += bitcount(n) - 1
    bc += int(man>>bc)
    return normalize(sign, man, exp, bc, prec, rnd)


if BACKEND == 'gmpy':
    mpf_mul = gmpy_mpf_mul
    mpf_mul_int = gmpy_mpf_mul_int
else:
    mpf_mul = python_mpf_mul
    mpf_mul_int = python_mpf_mul_int

def mpf_shift(s, n):
    """Quickly multiply the raw mpf s by 2**n without rounding."""
    sign, man, exp, bc = s
    if not man:
        return s
    return sign, man, exp+n, bc

def mpf_frexp(x):
    """Convert x = y*2**n to (y, n) with abs(y) in [0.5, 1) if nonzero"""
    sign, man, exp, bc = x
    if not man:
        if x == fzero:
            return (fzero, 0)
        else:
            # inf/nan cannot be decomposed
            raise ValueError
    return mpf_shift(x, -bc-exp), bc+exp

def mpf_div(s, t, prec, rnd=round_fast):
    """Floating-point division"""
    ssign, sman, sexp, sbc = s
    tsign, tman, texp, tbc = t
    if not sman or not tman:
        if s == fzero:
            if t == fzero: raise ZeroDivisionError
            if t == fnan: return fnan
            return fzero
        if t == fzero:
            raise ZeroDivisionError
        s_special = (not sman) and sexp
        t_special = (not tman) and texp
        if s_special and t_special:
            return fnan
        if s == fnan or t == fnan:
            return fnan
        if not t_special:
            if t == fzero:
                return fnan
            return {1:finf, -1:fninf}[mpf_sign(s) * mpf_sign(t)]
        return fzero
    sign = ssign ^ tsign
    # Dividing by a power of two only changes the exponent
    if tman == 1:
        return normalize1(sign, sman, sexp-texp, sbc, prec, rnd)
    # Same strategy as for addition: if there is a remainder, perturb
    # the result a few bits outside the precision range before rounding
    extra = prec - sbc + tbc + 5
    if extra < 5:
        extra = 5
    quot, rem = divmod(sman<<extra, tman)
    if rem:
        quot = (quot<<1) + 1
        extra += 1
        return normalize1(sign, quot, sexp-texp-extra, bitcount(quot), prec, rnd)
    return normalize(sign, quot, sexp-texp-extra, bitcount(quot), prec, rnd)

def mpf_rdiv_int(n, t, prec, rnd=round_fast):
    """Floating-point division n/t with a Python integer as numerator"""
    sign, man, exp, bc = t
    if not n or not man:
        return mpf_div(from_int(n), t, prec, rnd)
    if n < 0:
        sign ^= 1
        n = -n
    extra = prec + bc + 5
    quot, rem = divmod(n<<extra, man)
    if rem:
        quot = (quot<<1) + 1
        extra += 1
        return normalize1(sign, quot, -exp-extra, bitcount(quot), prec, rnd)
    return normalize(sign, quot, -exp-extra, bitcount(quot), prec, rnd)

def mpf_mod(s, t, prec, rnd=round_fast):
    """Return s % t for raw mpfs, computed with Python's integer %
    on the mantissas aligned to a common exponent. Returns nan if
    either argument is inf or nan."""
    ssign, sman, sexp, sbc = s
    tsign, tman, texp, tbc = t
    if ((not sman) and sexp) or ((not tman) and texp):
        return fnan
    # Important special case: do nothing if t is larger
    if ssign == tsign and texp > sexp+sbc:
        return s
    # Another important special case: this allows us to do e.g. x % 1.0
    # to find the fractional part of x, and it will work when x is huge.
    if tman == 1 and sexp > texp+tbc:
        return fzero
    base = min(sexp, texp)
    sman = (-1)**ssign * sman
    tman = (-1)**tsign * tman
    man = (sman << (sexp-base)) % (tman << (texp-base))
    if man >= 0:
        sign = 0
    else:
        man = -man
        sign = 1
    return normalize(sign, man, base, bitcount(man), prec, rnd)

# Rounding mode to use for x when computing 1/x with a given mode
reciprocal_rnd = {
  round_down : round_up,
  round_up : round_down,
  round_floor : round_ceiling,
  round_ceiling : round_floor,
  round_nearest : round_nearest
}

# Rounding mode to use for x when computing -x with a given mode
negative_rnd = {
  round_down : round_down,
  round_up : round_up,
  round_floor : round_ceiling,
  round_ceiling : round_floor,
  round_nearest : round_nearest
}

def mpf_pow_int(s, n, prec, rnd=round_fast):
    """Compute s**n, where s is a raw mpf and n is a Python integer."""
    sign, man, exp, bc = s

    if (not man) and exp:
        if s == finf:
            if n > 0: return s
            if n == 0: return fnan
            return fzero
        if s == fninf:
            if n > 0: return [finf, fninf][n & 1]
            if n == 0: return fnan
            return fzero
        return fnan

    n = int(n)
    if n == 0: return fone
    if n == 1: return mpf_pos(s, prec, rnd)
    if n == 2:
        _, man, exp, bc = s
        if not man:
            return fzero
        man = man*man
        if man == 1:
            return (0, MPZ_ONE, exp+exp, 1)
        bc = bc + bc - 2
        bc += bctable[int(man>>bc)]
        return normalize1(0, man, exp+exp, bc, prec, rnd)
    if n == -1: return mpf_div(fone, s, prec, rnd)
    if n < 0:
        # Compute the positive power with extra precision and the
        # reciprocal rounding mode, then invert.
        inverse = mpf_pow_int(s, -n, prec+5, reciprocal_rnd[rnd])
        return mpf_div(fone, inverse, prec, rnd)

    # Negative only for a negative base raised to an odd power
    result_sign = sign & n

    # Use exact integer power when the exact mantissa is small
    if man == 1:
        return (result_sign, MPZ_ONE, exp*n, 1)
    if bc*n < 1000:
        man **= n
        return normalize1(result_sign, man, exp*n, bitcount(man), prec, rnd)

    # Use directed rounding all the way through to maintain rigorous
    # bounds for interval arithmetic
    rounds_down = (rnd == round_nearest) or \
        shifts_down[rnd][result_sign]

    # Now we perform binary exponentiation. Need to estimate precision
    # to avoid rounding errors from temporary operations. Roughly log_2(n)
    # operations are performed.
    workprec = prec + 4*bitcount(n) + 4
    _, pm, pe, pbc = fone
    while 1:
        if n & 1:
            pm = pm*man
            pe = pe+exp
            pbc += bc - 2
            pbc = pbc + bctable[int(pm >> pbc)]
            if pbc > workprec:
                if rounds_down:
                    pm = pm >> (pbc-workprec)
                else:
                    pm = -((-pm) >> (pbc-workprec))
                pe += pbc - workprec
                pbc = workprec
            n -= 1
            if not n:
                break
        man = man*man
        exp = exp+exp
        bc = bc + bc - 2
        bc = bc + bctable[int(man >> bc)]
        if bc > workprec:
            if rounds_down:
                man = man >> (bc-workprec)
            else:
                man = -((-man) >> (bc-workprec))
            exp += bc - workprec
            bc = workprec
        n = n // 2

    return normalize(result_sign, pm, pe, pbc, prec, rnd)


def mpf_perturb(x, eps_sign, prec, rnd):
    """
    For nonzero x, calculate x + eps with directed rounding, where
    eps < prec relatively and eps has the given sign (0 for
    positive, 1 for negative).

    With rounding to nearest, this is taken to simply normalize
    x to the given precision.
    """
    if rnd == round_nearest:
        return mpf_pos(x, prec, rnd)
    sign, man, exp, bc = x
    eps = (eps_sign, MPZ_ONE, exp+bc-prec-1, 1)
    if sign:
        away = (rnd in (round_down, round_ceiling)) ^ eps_sign
    else:
        away = (rnd in (round_up, round_ceiling)) ^ eps_sign
    if away:
        return mpf_add(x, eps, prec, rnd)
    else:
        return mpf_pos(x, prec, rnd)


#----------------------------------------------------------------------------#
#                              Radix conversion                              #
#----------------------------------------------------------------------------#

def to_digits_exp(s, dps):
    """Helper function for representing the floating-point number s as
    a decimal with dps digits. Returns (sign, string, exponent) where
    sign is '' or '-', string is the digit string, and exponent is
    the decimal exponent as an int.

    If inexact, the decimal representation is rounded toward zero."""

    # Extract sign first so it doesn't mess up the string digit count
    if s[0]:
        sign = '-'
        s = mpf_neg(s)
    else:
        sign = ''
    _sign, man, exp, bc = s

    if not man:
        return '', '0', 0

    bitprec = int(dps * math.log(10,2)) + 10

    # Cut down to size
    # TODO: account for precision when doing this
    exp_from_1 = exp + bc
    if abs(exp_from_1) > 3500:
        from .libelefun import mpf_ln2, mpf_ln10
        # Set b = int(exp * log(2)/log(10))
        # If exp is huge, we must use high-precision arithmetic to
        # find the nearest power of ten
        expprec = bitcount(abs(exp)) + 5
        tmp = from_int(exp)
        tmp = mpf_mul(tmp, mpf_ln2(expprec))
        tmp = mpf_div(tmp, mpf_ln10(expprec), expprec)
        b = to_int(tmp)
        s = mpf_div(s, mpf_pow_int(ften, b, bitprec), bitprec)
        _sign, man, exp, bc = s
        exponent = b
    else:
        exponent = 0

    # First, calculate mantissa digits by converting to a binary
    # fixed-point number and then converting that number
to 1195 # a decimal fixed-point number. 1196 fixprec = max(bitprec - exp - bc, 0) 1197 fixdps = int(fixprec / math.log(10,2) + 0.5) 1198 sf = to_fixed(s, fixprec) 1199 sd = bin_to_radix(sf, fixprec, 10, fixdps) 1200 digits = numeral(sd, base=10, size=dps) 1201 1202 exponent += len(digits) - fixdps - 1 1203 return sign, digits, exponent 1204 1205def to_str(s, dps, strip_zeros=True, min_fixed=None, max_fixed=None, 1206 show_zero_exponent=False): 1207 """ 1208 Convert a raw mpf to a decimal floating-point literal with at 1209 most `dps` decimal digits in the mantissa (not counting extra zeros 1210 that may be inserted for visual purposes). 1211 1212 The number will be printed in fixed-point format if the position 1213 of the leading digit is strictly between min_fixed 1214 (default = min(-dps/3,-5)) and max_fixed (default = dps). 1215 1216 To force fixed-point format always, set min_fixed = -inf, 1217 max_fixed = +inf. To force floating-point format, set 1218 min_fixed >= max_fixed. 1219 1220 The literal is formatted so that it can be parsed back to a number 1221 by to_str, float() or Decimal(). 1222 """ 1223 1224 # Special numbers 1225 if not s[1]: 1226 if s == fzero: 1227 if dps: t = '0.0' 1228 else: t = '.0' 1229 if show_zero_exponent: 1230 t += 'e+0' 1231 return t 1232 if s == finf: return '+inf' 1233 if s == fninf: return '-inf' 1234 if s == fnan: return 'nan' 1235 raise ValueError 1236 1237 if min_fixed is None: min_fixed = min(-(dps//3), -5) 1238 if max_fixed is None: max_fixed = dps 1239 1240 # to_digits_exp rounds to floor. 
1241 # This sometimes kills some instances of "...00001" 1242 sign, digits, exponent = to_digits_exp(s, dps+3) 1243 1244 # No digits: show only .0; round exponent to nearest 1245 if not dps: 1246 if digits[0] in '56789': 1247 exponent += 1 1248 digits = ".0" 1249 1250 else: 1251 # Rounding up kills some instances of "...99999" 1252 if len(digits) > dps and digits[dps] in '56789': 1253 digits = digits[:dps] 1254 i = dps - 1 1255 while i >= 0 and digits[i] == '9': 1256 i -= 1 1257 if i >= 0: 1258 digits = digits[:i] + str(int(digits[i]) + 1) + '0' * (dps - i - 1) 1259 else: 1260 digits = '1' + '0' * (dps - 1) 1261 exponent += 1 1262 else: 1263 digits = digits[:dps] 1264 1265 # Prettify numbers close to unit magnitude 1266 if min_fixed < exponent < max_fixed: 1267 if exponent < 0: 1268 digits = ("0"*int(-exponent)) + digits 1269 split = 1 1270 else: 1271 split = exponent + 1 1272 if split > dps: 1273 digits += "0"*(split-dps) 1274 exponent = 0 1275 else: 1276 split = 1 1277 1278 digits = (digits[:split] + "." 
+ digits[split:]) 1279 1280 if strip_zeros: 1281 # Clean up trailing zeros 1282 digits = digits.rstrip('0') 1283 if digits[-1] == ".": 1284 digits += "0" 1285 1286 if exponent == 0 and dps and not show_zero_exponent: return sign + digits 1287 if exponent >= 0: return sign + digits + "e+" + str(exponent) 1288 if exponent < 0: return sign + digits + "e" + str(exponent) 1289 1290def str_to_man_exp(x, base=10): 1291 """Helper function for from_str.""" 1292 x = x.lower().rstrip('l') 1293 # Verify that the input is a valid float literal 1294 float(x) 1295 # Split into mantissa, exponent 1296 parts = x.split('e') 1297 if len(parts) == 1: 1298 exp = 0 1299 else: # == 2 1300 x = parts[0] 1301 exp = int(parts[1]) 1302 # Look for radix point in mantissa 1303 parts = x.split('.') 1304 if len(parts) == 2: 1305 a, b = parts[0], parts[1].rstrip('0') 1306 exp -= len(b) 1307 x = a + b 1308 x = MPZ(int(x, base)) 1309 return x, exp 1310 1311special_str = {'inf':finf, '+inf':finf, '-inf':fninf, 'nan':fnan} 1312 1313def from_str(x, prec, rnd=round_fast): 1314 """Create a raw mpf from a decimal literal, rounding in the 1315 specified direction if the input number cannot be represented 1316 exactly as a binary floating-point number with the given number of 1317 bits. The literal syntax accepted is the same as for Python 1318 floats. 1319 1320 TODO: the rounding does not work properly for large exponents. 
1321 """ 1322 x = x.lower().strip() 1323 if x in special_str: 1324 return special_str[x] 1325 1326 if '/' in x: 1327 p, q = x.split('/') 1328 p, q = p.rstrip('l'), q.rstrip('l') 1329 return from_rational(int(p), int(q), prec, rnd) 1330 1331 man, exp = str_to_man_exp(x, base=10) 1332 1333 # XXX: appropriate cutoffs & track direction 1334 # note no factors of 5 1335 if abs(exp) > 400: 1336 s = from_int(man, prec+10) 1337 s = mpf_mul(s, mpf_pow_int(ften, exp, prec+10), prec, rnd) 1338 else: 1339 if exp >= 0: 1340 s = from_int(man * 10**exp, prec, rnd) 1341 else: 1342 s = from_rational(man, 10**-exp, prec, rnd) 1343 return s 1344 1345# Binary string conversion. These are currently mainly used for debugging 1346# and could use some improvement in the future 1347 1348def from_bstr(x): 1349 man, exp = str_to_man_exp(x, base=2) 1350 man = MPZ(man) 1351 sign = 0 1352 if man < 0: 1353 man = -man 1354 sign = 1 1355 bc = bitcount(man) 1356 return normalize(sign, man, exp, bc, bc, round_floor) 1357 1358def to_bstr(x): 1359 sign, man, exp, bc = x 1360 return ['','-'][sign] + numeral(man, size=bitcount(man), base=2) + ("e%i" % exp) 1361 1362 1363#----------------------------------------------------------------------------# 1364# Square roots # 1365#----------------------------------------------------------------------------# 1366 1367 1368def mpf_sqrt(s, prec, rnd=round_fast): 1369 """ 1370 Compute the square root of a nonnegative mpf value. The 1371 result is correctly rounded. 
1372 """ 1373 sign, man, exp, bc = s 1374 if sign: 1375 raise ComplexResult("square root of a negative number") 1376 if not man: 1377 return s 1378 if exp & 1: 1379 exp -= 1 1380 man <<= 1 1381 bc += 1 1382 elif man == 1: 1383 return normalize1(sign, man, exp//2, bc, prec, rnd) 1384 shift = max(4, 2*prec-bc+4) 1385 shift += shift & 1 1386 if rnd in 'fd': 1387 man = isqrt(man<<shift) 1388 else: 1389 man, rem = sqrtrem(man<<shift) 1390 # Perturb up 1391 if rem: 1392 man = (man<<1)+1 1393 shift += 2 1394 return from_man_exp(man, (exp-shift)//2, prec, rnd) 1395 1396def mpf_hypot(x, y, prec, rnd=round_fast): 1397 """Compute the Euclidean norm sqrt(x**2 + y**2) of two raw mpfs 1398 x and y.""" 1399 if y == fzero: return mpf_abs(x, prec, rnd) 1400 if x == fzero: return mpf_abs(y, prec, rnd) 1401 hypot2 = mpf_add(mpf_mul(x,x), mpf_mul(y,y), prec+4) 1402 return mpf_sqrt(hypot2, prec, rnd) 1403 1404 1405if BACKEND == 'sage': 1406 try: 1407 import sage.libs.mpmath.ext_libmp as ext_lib 1408 mpf_add = ext_lib.mpf_add 1409 mpf_sub = ext_lib.mpf_sub 1410 mpf_mul = ext_lib.mpf_mul 1411 mpf_div = ext_lib.mpf_div 1412 mpf_sqrt = ext_lib.mpf_sqrt 1413 except ImportError: 1414 pass 1415