"""
Defines ``StandardBaseContext``, which combines the mixin classes
imported below with a set of generic helper methods.
"""

from operator import gt, lt

from .libmp.backend import xrange

from .functions.functions import SpecialFunctions
from .functions.rszeta import RSCache
from .calculus.quadrature import QuadratureMethods
from .calculus.inverselaplace import LaplaceTransformInversionMethods
from .calculus.calculus import CalculusMethods
from .calculus.optimization import OptimizationMethods
from .calculus.odes import ODEMethods
from .matrices.matrices import MatrixMethods
from .matrices.calculus import MatrixCalculusMethods
from .matrices.linalg import LinearAlgebraMethods
from .matrices.eigen import Eigen
from .identification import IdentificationMethods
from .visualization import VisualizationMethods

from . import libmp


class Context(object):
    """Empty root class; ``StandardBaseContext`` lists it as first base."""
    pass


class StandardBaseContext(Context,
    SpecialFunctions,
    RSCache,
    QuadratureMethods,
    LaplaceTransformInversionMethods,
    CalculusMethods,
    MatrixMethods,
    MatrixCalculusMethods,
    LinearAlgebraMethods,
    Eigen,
    IdentificationMethods,
    OptimizationMethods,
    ODEMethods,
    VisualizationMethods):
    """
    Mixin aggregate providing generic arithmetic, summation and utility
    helpers. The methods rely on attributes that are not defined in this
    class (e.g. ``convert``, ``zero``, ``one``, ``eps``, ``prec``,
    ``mpf``, ``mpc``, ``mag``, ``ninf``); these must be supplied by a
    subclass or by one of the mixins.
    """

    NoConvergence = libmp.NoConvergence
    ComplexResult = libmp.ComplexResult

    def __init__(ctx):
        """Create the alias table and initialize the mixins that need it."""
        ctx._aliases = {}
        # Call those that need preinitialization (e.g. for wrappers)
        SpecialFunctions.__init__(ctx)
        RSCache.__init__(ctx)
        QuadratureMethods.__init__(ctx)
        LaplaceTransformInversionMethods.__init__(ctx)
        CalculusMethods.__init__(ctx)
        MatrixMethods.__init__(ctx)

    def _init_aliases(ctx):
        """
        Bind every name in ``ctx._aliases`` to the attribute it refers to.
        Aliases whose target attribute does not exist are skipped.
        """
        for alias, value in ctx._aliases.items():
            try:
                setattr(ctx, alias, getattr(ctx, value))
            except AttributeError:
                pass

    # When true, sum_accurately/mul_accurately stop after a single pass
    # instead of retrying at higher precision.
    _fixed_precision = False

    # XXX
    verbose = False

    def warn(ctx, msg):
        """Print *msg* prefixed with ``Warning:``."""
        print("Warning:", msg)

    def bad_domain(ctx, msg):
        """Raise ``ValueError`` with the message *msg*."""
        raise ValueError(msg)

    def _re(ctx, x):
        """Return ``x.real`` if *x* has that attribute, otherwise *x*."""
        if hasattr(x, "real"):
            return x.real
        return x

    def _im(ctx, x):
        """Return ``x.imag`` if *x* has that attribute, otherwise zero."""
        if hasattr(x, "imag"):
            return x.imag
        return ctx.zero

    def _as_points(ctx, x):
        """Return *x* unchanged."""
        return x

    def fneg(ctx, x, **kwargs):
        """Return the negation of ``convert(x)``; *kwargs* are ignored."""
        return -ctx.convert(x)

    def fadd(ctx, x, y, **kwargs):
        """Return ``convert(x) + convert(y)``; *kwargs* are ignored."""
        return ctx.convert(x)+ctx.convert(y)

    def fsub(ctx, x, y, **kwargs):
        """Return ``convert(x) - convert(y)``; *kwargs* are ignored."""
        return ctx.convert(x)-ctx.convert(y)

    def fmul(ctx, x, y, **kwargs):
        """Return ``convert(x) * convert(y)``; *kwargs* are ignored."""
        return ctx.convert(x)*ctx.convert(y)

    def fdiv(ctx, x, y, **kwargs):
        """Return ``convert(x) / convert(y)``; *kwargs* are ignored."""
        return ctx.convert(x)/ctx.convert(y)

    def fsum(ctx, args, absolute=False, squared=False):
        """
        Sum the items of the iterable *args*, starting from ``ctx.zero``.

        With ``absolute=True`` each item is replaced by its absolute value;
        with ``squared=True`` each item (or its absolute value) is squared
        before being added.
        """
        if absolute:
            if squared:
                return sum((abs(x)**2 for x in args), ctx.zero)
            return sum((abs(x) for x in args), ctx.zero)
        if squared:
            return sum((x**2 for x in args), ctx.zero)
        return sum(args, ctx.zero)

    def fdot(ctx, xs, ys=None, conjugate=False):
        """
        Return the sum of products ``x*y``, starting from ``ctx.zero``.

        If *ys* is given, the pairs are taken from ``zip(xs, ys)``;
        otherwise *xs* must itself yield ``(x, y)`` pairs. With
        ``conjugate=True`` each ``y`` is passed through ``ctx.conj`` first.
        """
        if ys is not None:
            xs = zip(xs, ys)
        if conjugate:
            cf = ctx.conj
            return sum((x*cf(y) for (x,y) in xs), ctx.zero)
        else:
            return sum((x*y for (x,y) in xs), ctx.zero)

    def fprod(ctx, args):
        """Return the product of the items of *args* (``ctx.one`` if empty)."""
        prod = ctx.one
        for arg in args:
            prod *= arg
        return prod

    def nprint(ctx, x, n=6, **kwargs):
        """
        Equivalent to ``print(nstr(x, n))``.
        """
        print(ctx.nstr(x, n, **kwargs))

    def chop(ctx, x, tol=None):
        """
        Chops off small real or imaginary parts, or converts
        numbers close to zero to exact zeros. The input can be a
        single number or an iterable::

            >>> from mpmath import *
            >>> mp.dps = 15; mp.pretty = False
            >>> chop(5+1e-10j, tol=1e-9)
            mpf('5.0')
            >>> nprint(chop([1.0, 1e-20, 3+1e-18j, -4, 2]))
            [1.0, 0.0, 3.0, -4.0, 2.0]

        The tolerance defaults to ``100*eps``.
        """
        if tol is None:
            tol = 100*ctx.eps
        try:
            x = ctx.convert(x)
            absx = abs(x)
            if absx < tol:
                return ctx.zero
            if ctx._is_complex_type(x):
                # A part is dropped when it is small both absolutely and
                # relative to the magnitude of the whole number.
                part_tol = max(tol, absx*tol)
                if abs(x.imag) < part_tol:
                    return x.real
                if abs(x.real) < part_tol:
                    return ctx.mpc(0, x.imag)
        except TypeError:
            # Not a scalar: chop matrices and iterables elementwise.
            if isinstance(x, ctx.matrix):
                return x.apply(lambda a: ctx.chop(a, tol))
            if hasattr(x, "__iter__"):
                return [ctx.chop(a, tol) for a in x]
        return x

    def almosteq(ctx, s, t, rel_eps=None, abs_eps=None):
        r"""
        Determine whether the difference between `s` and `t` is smaller
        than a given epsilon, either relatively or absolutely.

        Both a maximum relative difference and a maximum difference
        ('epsilons') may be specified. The absolute difference is
        defined as `|s-t|` and the relative difference is defined
        as `|s-t|/\max(|s|, |t|)`.

        If only one epsilon is given, both are set to the same value.
        If none is given, both epsilons are set to `2^{-p+m}` where
        `p` is the current working precision and `m` is a small
        integer. The default setting typically allows :func:`~mpmath.almosteq`
        to be used to check for mathematical equality
        in the presence of small rounding errors.

        **Examples**

            >>> from mpmath import *
            >>> mp.dps = 15
            >>> almosteq(3.141592653589793, 3.141592653589790)
            True
            >>> almosteq(3.141592653589793, 3.141592653589700)
            False
            >>> almosteq(3.141592653589793, 3.141592653589700, 1e-10)
            True
            >>> almosteq(1e-20, 2e-20)
            True
            >>> almosteq(1e-20, 2e-20, rel_eps=0, abs_eps=0)
            False

        """
        t = ctx.convert(t)
        if abs_eps is None and rel_eps is None:
            rel_eps = abs_eps = ctx.ldexp(1, -ctx.prec+4)
        if abs_eps is None:
            abs_eps = rel_eps
        elif rel_eps is None:
            rel_eps = abs_eps
        diff = abs(s-t)
        if diff <= abs_eps:
            return True
        # Relative error is measured against the larger of |s| and |t|.
        abss = abs(s)
        abst = abs(t)
        if abss < abst:
            err = diff/abst
        else:
            err = diff/abss
        return err <= rel_eps

    def arange(ctx, *args):
        r"""
        This is a generalized version of Python's :func:`~mpmath.range` function
        that accepts fractional endpoints and step sizes and
        returns a list of ``mpf`` instances. Like :func:`~mpmath.range`,
        :func:`~mpmath.arange` can be called with 1, 2 or 3 arguments:

        ``arange(b)``
            `[0, 1, 2, \ldots, x]`
        ``arange(a, b)``
            `[a, a+1, a+2, \ldots, x]`
        ``arange(a, b, h)``
            `[a, a+h, a+2h, \ldots, x]`

        where `b-1 \le x < b` (in the third case, `b-h \le x < b`).

        Like Python's :func:`~mpmath.range`, the endpoint is not included. To
        produce ranges where the endpoint is included, :func:`~mpmath.linspace`
        is more convenient.

        **Examples**

            >>> from mpmath import *
            >>> mp.dps = 15; mp.pretty = False
            >>> arange(4)
            [mpf('0.0'), mpf('1.0'), mpf('2.0'), mpf('3.0')]
            >>> arange(1, 2, 0.25)
            [mpf('1.0'), mpf('1.25'), mpf('1.5'), mpf('1.75')]
            >>> arange(1, -1, -0.75)
            [mpf('1.0'), mpf('0.25'), mpf('-0.5')]

        """
        if len(args) > 3:
            raise TypeError('arange expected at most 3 arguments, got %i'
                            % len(args))
        if len(args) < 1:
            raise TypeError('arange expected at least 1 argument, got %i'
                            % len(args))
        # set default
        a = 0
        dt = 1
        # interpret arguments
        if len(args) == 1:
            b = args[0]
        elif len(args) >= 2:
            a = args[0]
            b = args[1]
            if len(args) == 3:
                dt = args[2]
        a, b, dt = ctx.mpf(a), ctx.mpf(b), ctx.mpf(dt)
        assert a + dt != a, 'dt is too small and would cause an infinite loop'
        # adapt code for sign of dt
        if a > b:
            if dt > 0:
                return []
            op = gt
        else:
            if dt < 0:
                return []
            op = lt
        # create list; each point is computed as a + dt*i rather than by
        # repeatedly adding dt to the previous point
        result = []
        i = 0
        while 1:
            t = a + dt*i
            i += 1
            if op(t, b):
                result.append(t)
            else:
                break
        return result

    def linspace(ctx, *args, **kwargs):
        """
        ``linspace(a, b, n)`` returns a list of `n` evenly spaced
        samples from `a` to `b`. The syntax ``linspace(mpi(a,b), n)``
        is also valid.

        This function is often more convenient than :func:`~mpmath.arange`
        for partitioning an interval into subintervals, since
        the endpoint is included::

            >>> from mpmath import *
            >>> mp.dps = 15; mp.pretty = False
            >>> linspace(1, 4, 4)
            [mpf('1.0'), mpf('2.0'), mpf('3.0'), mpf('4.0')]

        You may also provide the keyword argument ``endpoint=False``::

            >>> linspace(1, 4, 4, endpoint=False)
            [mpf('1.0'), mpf('1.75'), mpf('2.5'), mpf('3.25')]

        """
        if len(args) == 3:
            a = ctx.mpf(args[0])
            b = ctx.mpf(args[1])
            n = int(args[2])
        elif len(args) == 2:
            assert hasattr(args[0], '_mpi_')
            a = args[0].a
            b = args[0].b
            n = int(args[1])
        else:
            raise TypeError('linspace expected 2 or 3 arguments, got %i'
                            % len(args))
        if n < 1:
            raise ValueError('n must be greater than 0')
        if kwargs.get('endpoint', True):
            if n == 1:
                return [ctx.mpf(a)]
            step = (b - a) / ctx.mpf(n - 1)
            y = [i*step + a for i in xrange(n)]
            # Store b itself so the last sample is exactly the endpoint.
            y[-1] = b
        else:
            step = (b - a) / ctx.mpf(n)
            y = [i*step + a for i in xrange(n)]
        return y

    def cos_sin(ctx, z, **kwargs):
        """Return the pair ``(cos(z), sin(z))``."""
        return ctx.cos(z, **kwargs), ctx.sin(z, **kwargs)

    def cospi_sinpi(ctx, z, **kwargs):
        """Return the pair ``(cospi(z), sinpi(z))``."""
        return ctx.cospi(z, **kwargs), ctx.sinpi(z, **kwargs)

    def _default_hyper_maxprec(ctx, p):
        """Return ``int(1000 * p**0.25 + 4*p)``."""
        return int(1000 * p**0.25 + 4*p)

    # Integer helpers re-exported from libmp as static methods.
    _gcd = staticmethod(libmp.gcd)
    list_primes = staticmethod(libmp.list_primes)
    isprime = staticmethod(libmp.isprime)
    bernfrac = staticmethod(libmp.bernfrac)
    moebius = staticmethod(libmp.moebius)
    _ifac = staticmethod(libmp.ifac)
    _eulernum = staticmethod(libmp.eulernum)
    _stirling1 = staticmethod(libmp.stirling1)
    _stirling2 = staticmethod(libmp.stirling2)

    def sum_accurately(ctx, terms, check_step=1):
        """
        Sum the terms yielded by ``terms()``, retrying at higher working
        precision while the measured cancellation is not covered by the
        extra precision in use.

        *terms* is a callable returning a fresh iterable of terms on each
        call (it is called once per attempt). Magnitudes are sampled on
        every *check_step*-th term. The working precision is restored on
        exit. If no nonzero term is sampled (for instance when the series
        is empty), the plain accumulated sum is returned.
        """
        prec = ctx.prec
        try:
            extraprec = 10
            while 1:
                ctx.prec = prec + extraprec + 5
                max_mag = ctx.ninf
                # None marks "no nonzero term sampled yet in this pass".
                sum_mag = None
                s = ctx.zero
                k = 0
                for term in terms():
                    s += term
                    if (not k % check_step) and term:
                        term_mag = ctx.mag(term)
                        max_mag = max(max_mag, term_mag)
                        sum_mag = ctx.mag(s)
                        # Stop summing once the term is negligible
                        # relative to the running sum.
                        if sum_mag - term_mag > ctx.prec:
                            break
                    k += 1
                if sum_mag is None:
                    # Nothing was measured, so there is no cancellation
                    # estimate to act on.
                    break
                cancellation = max_mag - sum_mag
                # A nan estimate compares unequal to itself; give up then.
                if cancellation != cancellation:
                    break
                if cancellation < extraprec or ctx._fixed_precision:
                    break
                extraprec += min(ctx.prec, cancellation)
            return s
        finally:
            ctx.prec = prec

    def mul_accurately(ctx, factors, check_step=1):
        """
        Multiply the factors yielded by ``factors()``, retrying at higher
        working precision while the measured cancellation is not covered
        by the extra precision in use.

        *factors* is a callable returning a fresh iterable of factors on
        each call (it is called once per attempt). Magnitudes of
        ``factor - 1`` and ``product - 1`` are sampled on every
        *check_step*-th factor. The working precision is restored on exit.
        If nothing is sampled (for instance when there are no factors),
        the plain accumulated product is returned.
        """
        prec = ctx.prec
        try:
            extraprec = 10
            while 1:
                ctx.prec = prec + extraprec + 5
                max_mag = ctx.ninf
                # None marks "no factor sampled yet in this pass".
                sum_mag = None
                one = ctx.one
                s = one
                k = 0
                for factor in factors():
                    s *= factor
                    term = factor - one
                    if (not k % check_step):
                        term_mag = ctx.mag(term)
                        max_mag = max(max_mag, term_mag)
                        sum_mag = ctx.mag(s-one)
                        # Stop once factor - 1 has magnitude below -prec.
                        if -term_mag > ctx.prec:
                            break
                    k += 1
                if sum_mag is None:
                    # Nothing was measured, so there is no cancellation
                    # estimate to act on.
                    break
                cancellation = max_mag - sum_mag
                # A nan estimate compares unequal to itself; give up then.
                if cancellation != cancellation:
                    break
                if cancellation < extraprec or ctx._fixed_precision:
                    break
                extraprec += min(ctx.prec, cancellation)
            return s
        finally:
            ctx.prec = prec

    def power(ctx, x, y):
        r"""Converts `x` and `y` to mpmath numbers and evaluates
        `x^y = \exp(y \log(x))`::

            >>> from mpmath import *
            >>> mp.dps = 30; mp.pretty = True
            >>> power(2, 0.5)
            1.41421356237309504880168872421

        This shows the leading few digits of a large Mersenne prime
        (performing the exact calculation ``2**43112609-1`` and
        displaying the result in Python would be very slow)::

            >>> power(2, 43112609)-1
            3.16470269330255923143453723949e+12978188
        """
        return ctx.convert(x) ** ctx.convert(y)

    def _zeta_int(ctx, n):
        """Return ``ctx.zeta(n)``."""
        return ctx.zeta(n)

    def maxcalls(ctx, f, N):
        """
        Return a wrapped copy of *f* that raises ``NoConvergence`` when *f*
        has been called more than *N* times::

            >>> from mpmath import *
            >>> mp.dps = 15
            >>> f = maxcalls(sin, 10)
            >>> print(sum(f(n) for n in range(10)))
            1.95520948210738
            >>> f(10) # doctest: +IGNORE_EXCEPTION_DETAIL
            Traceback (most recent call last):
              ...
            NoConvergence: maxcalls: function evaluated 10 times

        """
        # One-element list so the closure can update the count in place.
        counter = [0]
        def f_maxcalls_wrapped(*args, **kwargs):
            counter[0] += 1
            if counter[0] > N:
                raise ctx.NoConvergence("maxcalls: function evaluated %i times" % N)
            return f(*args, **kwargs)
        return f_maxcalls_wrapped

    def memoize(ctx, f):
        """
        Return a wrapped copy of *f* that caches computed values, i.e.
        a memoized copy of *f*. Values are only reused if the cached precision
        is equal to or higher than the working precision::

            >>> from mpmath import *
            >>> mp.dps = 15; mp.pretty = True
            >>> f = memoize(maxcalls(sin, 1))
            >>> f(2)
            0.909297426825682
            >>> f(2)
            0.909297426825682
            >>> mp.dps = 25
            >>> f(2) # doctest: +IGNORE_EXCEPTION_DETAIL
            Traceback (most recent call last):
              ...
            NoConvergence: maxcalls: function evaluated 1 times

        """
        f_cache = {}
        def f_cached(*args, **kwargs):
            # Arguments must be hashable; keyword arguments are keyed in
            # the order they were passed.
            if kwargs:
                key = args, tuple(kwargs.items())
            else:
                key = args
            prec = ctx.prec
            if key in f_cache:
                cprec, cvalue = f_cache[key]
                if cprec >= prec:
                    # NOTE(review): unary plus presumably re-rounds the
                    # cached value to the current precision -- confirm.
                    return +cvalue
            value = f(*args, **kwargs)
            f_cache[key] = (prec, value)
            return value
        f_cached.__name__ = f.__name__
        f_cached.__doc__ = f.__doc__
        return f_cached