import math
import cmath

from ..libmp.backend import xrange


class SpecialFunctions(object):
    """
    This class implements special functions using high-level code.

    Elementary and some other functions (e.g. gamma function, basecase
    hypergeometric series) are assumed to be predefined by the context as
    "builtins" or "low-level" functions.
    """
    # Registry filled by the defun / defun_wrapped decorators below:
    # maps function name -> (function, wrap flag).
    defined_functions = {}

    # The series for the Jacobi theta functions converge for |q| < 1;
    # in the current implementation they throw a ValueError for
    # abs(q) > THETA_Q_LIM
    THETA_Q_LIM = 1 - 10**-7

    def __init__(self):
        # Install every registered function on the concrete context class.
        cls = self.__class__
        for name in cls.defined_functions:
            f, wrap = cls.defined_functions[name]
            cls._wrap_specfun(name, f, wrap)

        # Frequently used rational constants, built once per context.
        self.mpq_1 = self._mpq((1,1))
        self.mpq_0 = self._mpq((0,1))
        self.mpq_1_2 = self._mpq((1,2))
        self.mpq_3_2 = self._mpq((3,2))
        self.mpq_1_4 = self._mpq((1,4))
        self.mpq_1_16 = self._mpq((1,16))
        self.mpq_3_16 = self._mpq((3,16))
        self.mpq_5_2 = self._mpq((5,2))
        self.mpq_3_4 = self._mpq((3,4))
        self.mpq_7_4 = self._mpq((7,4))
        self.mpq_5_4 = self._mpq((5,4))
        self.mpq_1_3 = self._mpq((1,3))
        self.mpq_2_3 = self._mpq((2,3))
        self.mpq_4_3 = self._mpq((4,3))
        self.mpq_1_6 = self._mpq((1,6))
        self.mpq_5_6 = self._mpq((5,6))
        self.mpq_5_3 = self._mpq((5,3))

        self._misc_const_cache = {}

        # Alternative public names -> canonical function names.
        self._aliases.update({
            'phase' : 'arg',
            'conjugate' : 'conj',
            'nthroot' : 'root',
            'polygamma' : 'psi',
            'hurwitz' : 'zeta',
            #'digamma' : 'psi0',
            #'trigamma' : 'psi1',
            #'tetragamma' : 'psi2',
            #'pentagamma' : 'psi3',
            'fibonacci' : 'fib',
            'factorial' : 'fac',
        })

        self.zetazero_memoized = self.memoize(self.zetazero)

    # Default: install f unchanged (the wrap flag is ignored here);
    # concrete contexts may override this to add wrapping.
    @classmethod
    def _wrap_specfun(cls, name, f, wrap):
        setattr(cls, name, f)

    # Optional fast versions of common functions in common cases.
    # If not overridden, default (generic hypergeometric series)
    # implementations will be used
    def _besselj(ctx, n, z): raise NotImplementedError
    def _erf(ctx, z): raise NotImplementedError
    def _erfc(ctx, z): raise NotImplementedError
    def _gamma_upper_int(ctx, z, a): raise NotImplementedError
    def _expint_int(ctx, n, z): raise NotImplementedError
    def _zeta(ctx, s): raise NotImplementedError
    def _zetasum_fast(ctx, s, a, n, derivatives, reflect): raise NotImplementedError
    def _ei(ctx, z): raise NotImplementedError
    def _e1(ctx, z): raise NotImplementedError
    def _ci(ctx, z): raise NotImplementedError
    def _si(ctx, z): raise NotImplementedError
    def _altzeta(ctx, s): raise NotImplementedError


# Register f to be installed on contexts with the wrap flag set.
def defun_wrapped(f):
    SpecialFunctions.defined_functions[f.__name__] = f, True
    return f

# Register f to be installed on contexts without wrapping.
def defun(f):
    SpecialFunctions.defined_functions[f.__name__] = f, False
    return f

# Attach f directly to SpecialFunctions (no registry entry).
def defun_static(f):
    setattr(SpecialFunctions, f.__name__, f)
    return f

# Reciprocal trigonometric / hyperbolic functions.

@defun_wrapped
def cot(ctx, z): return ctx.one / ctx.tan(z)

@defun_wrapped
def sec(ctx, z): return ctx.one / ctx.cos(z)

@defun_wrapped
def csc(ctx, z): return ctx.one / ctx.sin(z)

@defun_wrapped
def coth(ctx, z): return ctx.one / ctx.tanh(z)

@defun_wrapped
def sech(ctx, z): return ctx.one / ctx.cosh(z)

@defun_wrapped
def csch(ctx, z): return ctx.one / ctx.sinh(z)

# Inverse reciprocal functions; z == 0 is special-cased where 1/z
# would otherwise divide by zero.

@defun_wrapped
def acot(ctx, z):
    if not z:
        return ctx.pi * 0.5
    else:
        return ctx.atan(ctx.one / z)

@defun_wrapped
def asec(ctx, z): return ctx.acos(ctx.one / z)

@defun_wrapped
def acsc(ctx, z): return ctx.asin(ctx.one / z)

@defun_wrapped
def acoth(ctx, z):
    if not z:
        return ctx.pi * 0.5j
    else:
        return ctx.atanh(ctx.one / z)


@defun_wrapped
def asech(ctx, z): return ctx.acosh(ctx.one / z)

@defun_wrapped
def acsch(ctx, z): return ctx.asinh(ctx.one / z)

# sign(x): x itself for zero/nan, +-1 for real x, x/|x| for complex x.
@defun
def sign(ctx, x):
    x = ctx.convert(x)
    if not x or ctx.isnan(x):
        return x
    if ctx._is_real_type(x):
        if x > 0:
            return ctx.one
        else:
            return -ctx.one
    return x / abs(x)

# Arithmetic-geometric mean; the b == 1 case uses the dedicated agm1.
@defun
def agm(ctx, a, b=1):
    if b == 1:
        return ctx.agm1(a)
    a = ctx.convert(a)
    b = ctx.convert(b)
    return ctx._agm(a, b)

# sin(x)/x with the removable singularity at 0 filled in; x+1 keeps the
# argument's type for x == 0.
@defun_wrapped
def sinc(ctx, x):
    if ctx.isinf(x):
        return 1/x
    if not x:
        return x+1
    return ctx.sin(x)/x

# sin(pi*x)/(pi*x), same special cases as sinc.
@defun_wrapped
def sincpi(ctx, x):
    if ctx.isinf(x):
        return 1/x
    if not x:
        return x+1
    return ctx.sinpi(x)/(ctx.pi*x)

# exp(x) - 1 without catastrophic cancellation for small x.
# TODO: tests; improve implementation
@defun_wrapped
def expm1(ctx, x):
    if not x:
        return ctx.zero
    # exp(x) - 1 ~ x
    if ctx.mag(x) < -ctx.prec:
        return x + 0.5*x**2
    # TODO: accurately eval the smaller of the real/imag parts
    return ctx.sum_accurately(lambda: iter([ctx.exp(x),-1]),1)

# log(1 + x) without catastrophic cancellation for small x.
@defun_wrapped
def log1p(ctx, x):
    if not x:
        return ctx.zero
    if ctx.mag(x) < -ctx.prec:
        return x - 0.5*x**2
    # Form 1+x at doubled precision so the low bits of x survive.
    return ctx.log(ctx.fadd(1, x, prec=2*ctx.prec))

# x**y - 1, accurate when the result is close to zero.
@defun_wrapped
def powm1(ctx, x, y):
    mag = ctx.mag
    one = ctx.one
    w = x**y - one
    M = mag(w)
    # Only moderate cancellation
    if M > -8:
        return w
    # Check for the only possible exact cases
    if not w:
        if (not y) or (x in (1, -1, 1j, -1j) and ctx.isint(y)):
            return w
    magy = mag(y)
    lnx = ctx.ln(x)
    # Small y: x^y - 1 ~ log(x)*y + O(log(x)^2 * y^2)
    if magy + mag(lnx) < -ctx.prec:
        return lnx*y + (lnx*y)**2/2
    # TODO: accurately eval the smaller of the real/imag part
    return ctx.sum_accurately(lambda: iter([x**y, -1]), 1)

# The root of unity exp(2*pi*i*k/n), returning exact values for the
# cases k/n in {0, 1/2, 1/4, 3/4} (mod 1).
@defun
def _rootof1(ctx, k, n):
    k = int(k)
    n = int(n)
    k %= n
    if not k:
        return ctx.one
    elif 2*k == n:
        return -ctx.one
    elif 4*k == n:
        return ctx.j
    elif 4*k == 3*n:
        return -ctx.j
    return ctx.expjpi(2*ctx.mpf(k)/n)

# k-th branch of the n-th root of x (k = 0 is the principal root).
@defun
def root(ctx, x, n, k=0):
    n = int(n)
    x = ctx.convert(x)
    if k:
        # Special case: there is an exact real root
        if (n & 1 and 2*k == n-1) and (not ctx.im(x)) and (ctx.re(x) < 0):
            return -ctx.root(-x, n)
        # Multiply by root of unity
        prec = ctx.prec
        try:
            ctx.prec += 10
            v = ctx.root(x, n, 0) * ctx._rootof1(k, n)
        finally:
            ctx.prec = prec
        return +v
    return ctx._nthroot(x, n)

# All n-th roots of unity (only the primitive ones if primitive=True),
# computed with 10 guard bits and rounded to the working precision.
@defun
def unitroots(ctx, n, primitive=False):
    gcd = ctx._gcd
    prec = ctx.prec
    try:
        ctx.prec += 10
        if primitive:
            v = [ctx._rootof1(k,n) for k in range(n) if gcd(k,n) == 1]
        else:
            # TODO: this can be done *much* faster
            v = [ctx._rootof1(k,n) for k in range(n)]
    finally:
        ctx.prec = prec
    return [+x for x in v]

# Complex argument via atan2(im, re).
@defun
def arg(ctx, x):
    x = ctx.convert(x)
    re = ctx._re(x)
    im = ctx._im(x)
    return ctx.atan2(im, re)

@defun
def fabs(ctx, x):
    return abs(ctx.convert(x))

@defun
def re(ctx, x):
    x = ctx.convert(x)
    if hasattr(x, "real"): # py2.5 doesn't have .real/.imag for all numbers
        return x.real
    return x

@defun
def im(ctx, x):
    x = ctx.convert(x)
    if hasattr(x, "imag"): # py2.5 doesn't have .real/.imag for all numbers
        return x.imag
    return ctx.zero

@defun
def conj(ctx, x):
    x = ctx.convert(x)
    try:
        return x.conjugate()
    except AttributeError:
        # Types without .conjugate() are returned unchanged.
        return x

# (|z|, arg(z))
@defun
def polar(ctx, z):
    return (ctx.fabs(z), ctx.arg(z))

# r * (cos(phi) + i sin(phi))
@defun_wrapped
def rect(ctx, r, phi):
    return r * ctx.mpc(*ctx.cos_sin(phi))

# Natural log, or log base b computed as ln(x)/ln(b) with 20 guard bits.
@defun
def log(ctx, x, b=None):
    if b is None:
        return ctx.ln(x)
    wp = ctx.prec + 20
    return ctx.ln(x, prec=wp) / ctx.ln(b, prec=wp)

@defun
def log10(ctx, x):
    return ctx.log(x, 10)

@defun
def fmod(ctx, x, y):
    return ctx.convert(x) % ctx.convert(y)

@defun
def degrees(ctx, x):
    return x / ctx.degree

@defun
def radians(ctx, x):
    return x * ctx.degree

def _lambertw_special(ctx, z, k):
    """Lambert W for non-normal z (zero, infinities, nan)."""
    # W(0,0) = 0; all other branches are singular
    if not z:
        if not k:
            return z
        return ctx.ninf + z
    if z == ctx.inf:
        if k == 0:
            return z
        else:
            return z + 2*k*ctx.pi*ctx.j
    if z == ctx.ninf:
        return (-z) + (2*k+1)*ctx.pi*ctx.j
    # Some kind of nan or complex inf/nan?
    return ctx.ln(z)

def _lambertw_approx_hybrid(z, k):
    """
    Rough machine-precision starting value for W_k(z), k in (0, -1),
    computed in Python floats/complex from piecewise Taylor expansions
    and the asymptotic series.
    """
    imag_sign = 0
    if hasattr(z, "imag"):
        x = float(z.real)
        y = z.imag
        if y:
            imag_sign = (-1) ** (y < 0)
        y = float(y)
    else:
        x = float(z)
        y = 0.0
        imag_sign = 0
    # hack to work regardless of whether Python supports -0.0
    if not y:
        y = 0.0
    z = complex(x,y)
    if k == 0:
        if -4.0 < y < 4.0 and -1.0 < x < 2.5:
            if imag_sign:
                # Taylor series in upper/lower half-plane
                if y > 1.00: return (0.876+0.645j) + (0.118-0.174j)*(z-(0.75+2.5j))
                if y > 0.25: return (0.505+0.204j) + (0.375-0.132j)*(z-(0.75+0.5j))
                if y < -1.00: return (0.876-0.645j) + (0.118+0.174j)*(z-(0.75-2.5j))
                if y < -0.25: return (0.505-0.204j) + (0.375+0.132j)*(z-(0.75-0.5j))
            # Taylor series near -1
            if x < -0.5:
                if imag_sign >= 0:
                    return (-0.318+1.34j) + (-0.697-0.593j)*(z+1)
                else:
                    return (-0.318-1.34j) + (-0.697+0.593j)*(z+1)
            # return real type
            r = -0.367879441171442
            if (not imag_sign) and x > r:
                z = x
            # Singularity near -1/e
            if x < -0.2:
                return -1 + 2.33164398159712*(z-r)**0.5 - 1.81218788563936*(z-r)
            # Taylor series near 0
            if x < 0.5: return z
            # Simple linear approximation
            return 0.2 + 0.3*z
        if (not imag_sign) and x > 0.0:
            L1 = math.log(x); L2 = math.log(L1)
        else:
            L1 = cmath.log(z); L2 = cmath.log(L1)
    elif k == -1:
        # return real type
        r = -0.367879441171442
        if (not imag_sign) and r < x < 0.0:
            z = x
        if (imag_sign >= 0) and y < 0.1 and -0.6 < x < -0.2:
            return -1 - 2.33164398159712*(z-r)**0.5 - 1.81218788563936*(z-r)
        if (not imag_sign) and -0.2 <= x < 0.0:
            L1 = math.log(-x)
            return L1 - math.log(-L1)
        else:
            if imag_sign == -1 and (not y) and x < 0.0:
                L1 = cmath.log(z) - 3.1415926535897932j
            else:
                L1 = cmath.log(z) - 6.2831853071795865j
            L2 = cmath.log(L1)
    # Asymptotic expansion W ~ L1 - L2 + L2/L1 + ...
    return L1 - L2 + L2/L1 + L2*(L2-2)/(2*L1**2)

def _lambertw_series(ctx, z, k, tol):
    """
    Return rough approximation for W_k(z) from an asymptotic series,
    sufficiently accurate for the Halley iteration to converge to
    the correct value.

    Returns a pair (w, done); done is True when the branch-point series
    already converged to within 2**-tol and no iteration is needed.
    May leave ctx.prec modified; the caller restores it.
    """
    magz = ctx.mag(z)
    if (-10 < magz < 900) and (-1000 < k < 1000):
        # Near the branch point at -1/e
        if magz < 1 and abs(z+0.36787944117144) < 0.05:
            if k == 0 or (k == -1 and ctx._im(z) >= 0) or \
               (k == 1 and ctx._im(z) < 0):
                delta = ctx.sum_accurately(lambda: [z, ctx.exp(-1)])
                cancellation = -ctx.mag(delta)
                ctx.prec += cancellation
                # Use series given in Corless et al.
                p = ctx.sqrt(2*(ctx.e*z+1))
                ctx.prec -= cancellation
                u = {0:ctx.mpf(-1), 1:ctx.mpf(1)}
                a = {0:ctx.mpf(2), 1:ctx.mpf(-1)}
                if k != 0:
                    p = -p
                s = ctx.zero
                # The series converges, so we could use it directly, but unless
                # *extremely* close, it is better to just use the first few
                # terms to get a good approximation for the iteration
                for l in xrange(max(2,cancellation)):
                    if l not in u:
                        a[l] = ctx.fsum(u[j]*u[l+1-j] for j in xrange(2,l))
                        u[l] = (l-1)*(u[l-2]/2+a[l-2]/4)/(l+1)-a[l]/2-u[l-1]/(l+1)
                    term = u[l] * p**l
                    s += term
                    if ctx.mag(term) < -tol:
                        return s, True
                # NOTE(review): this leaves ctx.prec raised; lambertw
                # restores the caller's precision afterwards.
                ctx.prec += cancellation//2
                return s, False
        if k == 0 or k == -1:
            return _lambertw_approx_hybrid(z, k), False
    if k == 0:
        if magz < -1:
            return z*(1-z), False
        L1 = ctx.ln(z)
        L2 = ctx.ln(L1)
    elif k == -1 and (not ctx._im(z)) and (-0.36787944117144 < ctx._re(z) < 0):
        L1 = ctx.ln(-z)
        return L1 - ctx.ln(-L1), False
    else:
        # This holds both as z -> 0 and z -> inf.
        # Relative error is O(1/log(z)).
        L1 = ctx.ln(z) + 2j*ctx.pi*k
        L2 = ctx.ln(L1)
    return L1 - L2 + L2/L1 + L2*(L2-2)/(2*L1**2), False

# Branch k of the Lambert W function: solves w*exp(w) = z.
@defun
def lambertw(ctx, z, k=0):
    z = ctx.convert(z)
    k = int(k)
    if not ctx.isnormal(z):
        return _lambertw_special(ctx, z, k)
    prec = ctx.prec
    # Restore the caller's precision even if evaluation raises;
    # _lambertw_series may also change ctx.prec.
    try:
        ctx.prec += 20 + ctx.mag(k or 1)
        tol = ctx.prec - 5
        w, done = _lambertw_series(ctx, z, k, tol)
        if not done:
            # Use Halley iteration to solve w*exp(w) = z
            two = ctx.mpf(2)
            for i in xrange(100):
                ew = ctx.exp(w)
                wew = w*ew
                wewz = wew-z
                wn = w - wewz/(wew+ew-(w+two)*wewz/(two*w+two))
                converged = ctx.mag(wn-w) <= ctx.mag(wn) - tol
                w = wn
                if converged:
                    break
            else:
                # Runs only when all 100 steps finished without converging.
                ctx.warn("Lambert W iteration failed to converge for z = %s" % z)
    finally:
        ctx.prec = prec
    return +w

# Bell numbers / Bell polynomials B_n(x), generalized to non-integer n
# via the exponential-type series in _polyexp.
@defun_wrapped
def bell(ctx, n, x=1):
    x = ctx.convert(x)
    if not n:
        if ctx.isnan(x):
            return x
        return type(x)(1)
    if ctx.isinf(x) or ctx.isinf(n) or ctx.isnan(x) or ctx.isnan(n):
        return x**n
    if n == 1: return x
    if n == 2: return x*(x+1)
    if x == 0: return ctx.sincpi(n)
    return _polyexp(ctx, n, x, True) / ctx.exp(x)

def _polyexp(ctx, n, x, extra=False):
    """
    Sum of k**n * x**k / k! for k >= 1, plus sincpi(n) when extra is set.
    """
    def _terms():
        if extra:
            yield ctx.sincpi(n)
        t = x
        k = 1
        while 1:
            yield k**n * t
            k += 1
            t = t*x/k
    return ctx.sum_accurately(_terms, check_step=4)

# Polyexponential function, with closed forms for s in (0, 1, 2).
@defun_wrapped
def polyexp(ctx, s, z):
    if ctx.isinf(z) or ctx.isinf(s) or ctx.isnan(z) or ctx.isnan(s):
        return z**s
    if z == 0: return z*s
    if s == 0: return ctx.expm1(z)
    if s == 1: return ctx.exp(z)*z
    if s == 2: return ctx.exp(z)*z*(z+1)
    return _polyexp(ctx, s, z)

# n-th cyclotomic polynomial evaluated at z.
@defun_wrapped
def cyclotomic(ctx, n, z):
    n = int(n)
    if n < 0:
        raise ValueError("n cannot be negative")
    p = ctx.one
    if n == 0:
        return p
    if n == 1:
        return z - p
    if n == 2:
        return z + p
    # Use divisor product representation. Unfortunately, this sometimes
    # includes singularities for roots of unity, which we have to cancel out.
    # Matching zeros/poles pairwise, we have (1-z^a)/(1-z^b) ~ a/b + O(z-1).
    a_prod = 1
    b_prod = 1
    num_zeros = 0
    num_poles = 0
    for d in range(1,n+1):
        if not n % d:
            w = ctx.moebius(n//d)
            # Use powm1 because it is important that we get 0 only
            # if it really is exactly 0
            b = -ctx.powm1(z, d)
            if b:
                p *= b**w
            else:
                if w == 1:
                    a_prod *= d
                    num_zeros += 1
                elif w == -1:
                    b_prod *= d
                    num_poles += 1
    if num_zeros:
        if num_zeros > num_poles:
            p *= 0
        else:
            p *= a_prod
            p /= b_prod
    return p

def _int_nthroot_floor(n, k):
    """
    Exact floor(n ** (1/k)) for integers n >= 0 and k >= 1.

    Integer Newton iteration started from a power of two that is at least
    the true root; the iterates decrease until they reach the floor root.
    """
    if n < 2:
        return n
    s = 1 << ((n.bit_length() + k - 1) // k)
    while True:
        u = ((k - 1) * s + n // s**(k - 1)) // k
        if u >= s:
            return s
        s = u

@defun
def mangoldt(ctx, n):
    r"""
    Evaluates the von Mangoldt function `\Lambda(n) = \log p`
    if `n = p^k` a power of a prime, and `\Lambda(n) = 0` otherwise.

    **Examples**

        >>> from mpmath import *
        >>> mp.dps = 25; mp.pretty = True
        >>> [mangoldt(n) for n in range(-2,3)]
        [0.0, 0.0, 0.0, 0.0, 0.6931471805599453094172321]
        >>> mangoldt(6)
        0.0
        >>> mangoldt(7)
        1.945910149055313305105353
        >>> mangoldt(8)
        0.6931471805599453094172321
        >>> fsum(mangoldt(n) for n in range(101))
        94.04531122935739224600493
        >>> fsum(mangoldt(n) for n in range(10001))
        10013.39669326311478372032

    """
    n = int(n)
    if n < 2:
        return ctx.zero
    if n % 2 == 0:
        # Must be a power of two
        if n & (n-1) == 0:
            return +ctx.ln2
        else:
            return ctx.zero
    # Look for a small factor: if one divides n, n must be a pure
    # power of it (every quotient step must divide exactly).
    for p in (3,5,7,11,13,17,19,23,29,31):
        if not n % p:
            q = n // p
            while q > 1:
                q, r = divmod(q, p)
                if r:
                    return ctx.zero
            return ctx.ln(p)
    if ctx.isprime(n):
        return ctx.ln(n)
    # Perfect power test with exact integer roots (no size limit).
    k = 2
    while 1:
        p = _int_nthroot_floor(n, k)
        if p < 2:
            return ctx.zero
        if p ** k == n:
            if ctx.isprime(p):
                return ctx.ln(p)
        k += 1

# Stirling numbers of the first kind; exact=True returns a Python int.
@defun
def stirling1(ctx, n, k, exact=False):
    v = ctx._stirling1(int(n), int(k))
    if exact:
        return int(v)
    else:
        return ctx.mpf(v)

# Stirling numbers of the second kind; exact=True returns a Python int.
@defun
def stirling2(ctx, n, k, exact=False):
    v = ctx._stirling2(int(n), int(k))
    if exact:
        return int(v)
    else:
        return ctx.mpf(v)