from ..core import (Add, Dummy, Expr, Integer, Mul, Symbol, Tuple, cacheit,
                    expand_log, expand_power_base, nan, oo)
from ..core.compatibility import is_sequence
from ..core.sympify import sympify
from ..utilities import default_sort_key
from ..utilities.iterables import uniq


class Order(Expr):
    r"""Represents the limiting behavior of a function.

    The formal definition for order symbol `O(f(x))` (Big O) is
    that `g(x) \in O(f(x))` as `x\to a` iff

    .. math:: \lim\limits_{x \rightarrow a} \sup
              \left|\frac{g(x)}{f(x)}\right| < \infty

    Parameters
    ==========

    expr : Expr
        an expression
    args : sequence of Symbols or pairs (Symbol, Expr), optional
        If only symbols are provided, i.e. no limit point is
        passed, then the limit point is assumed to be zero.  If no
        symbols are passed then all symbols in the expression are used.

    Examples
    ========

    The order of a function can be intuitively thought of as representing
    all terms of powers greater than or equal to the one specified.  For
    example, `O(x^3)` corresponds to any terms proportional to
    `x^3, x^4,\ldots` and any higher power.  For a polynomial, this leaves
    terms proportional to `x^2`, `x` and constants.

    >>> 1 + x + x**2 + x**3 + x**4 + O(x**3)
    1 + x + x**2 + O(x**3)

    ``O(f(x))`` is automatically transformed to ``O(f(x).as_leading_term(x))``:

    >>> O(x + x**2)
    O(x)
    >>> O(cos(x))
    O(1)

    Some arithmetic operations:

    >>> O(x)*x
    O(x**2)
    >>> O(x) - O(x)
    O(x)

    The Big O symbol is a set, so we support membership test:

    >>> x in O(x)
    True
    >>> O(1) in O(1, x)
    True
    >>> O(1, x) in O(1)
    False
    >>> O(x) in O(1, x)
    True
    >>> O(x**2) in O(x)
    True

    Limit points other than zero and multivariate Big O are also supported:

    >>> O(x) == O(x, (x, 0))
    True
    >>> O(x + x**2, (x, oo))
    O(x**2, (x, oo))
    >>> O(cos(x), (x, pi/2))
    O(x - pi/2, (x, pi/2))

    >>> O(1 + x*y)
    O(1, x, y)
    >>> O(1 + x*y, (x, 0), (y, 0))
    O(1, x, y)
    >>> O(1 + x*y, (x, oo), (y, oo))
    O(x*y, (x, oo), (y, oo))

    References
    ==========

    * https://en.wikipedia.org/wiki/Big_O_notation

    """

    is_Order = True

    @cacheit
    def __new__(cls, expr, *args, **kwargs):
        # Build a canonical Order: reduce ``expr`` to its leading term with
        # respect to the order variables at the limit point, then store it
        # as ``Order(expr, (var1, pt1), (var2, pt2), ...)`` with the
        # variables sorted.  May instead return ``nan``, ``0`` or an
        # existing Order unchanged.
        # NOTE(review): ``kwargs`` is accepted but not used in this method.
        expr = sympify(expr)

        if not args:
            if expr.is_Order:
                # Re-wrapping an Order: inherit its variables and points.
                variables = expr.variables
                point = expr.point
            else:
                # No variables given: use every free symbol, at point 0.
                variables = list(expr.free_symbols)
                point = [Integer(0)]*len(variables)
        else:
            args = list(args if is_sequence(args) else [args])
            variables, point = [], []
            if is_sequence(args[0]):
                # Arguments are (symbol, point) pairs.
                for a in args:
                    v, p = list(map(sympify, a))
                    variables.append(v)
                    point.append(p)
            else:
                # Arguments are bare symbols; the limit point defaults to 0.
                variables = list(map(sympify, args))
                point = [Integer(0)]*len(variables)

        if not all(isinstance(v, (Dummy, Symbol)) for v in variables):
            raise TypeError(f'Variables are not symbols, got {variables}')

        if len(list(uniq(variables))) != len(variables):
            raise ValueError(f'Variables are supposed to be unique symbols, got {variables}')

        if expr.is_Order:
            # Merge the inner Order's (variable, point) pairs with the
            # requested ones; one variable may not appear with two points.
            expr_vp = dict(expr.args[1:])
            new_vp = dict(expr_vp)
            vp = dict(zip(variables, point))
            for v, p in vp.items():
                if v in new_vp:
                    if p != new_vp[v]:
                        raise NotImplementedError(
                            'Mixing Order at different points is not supported.')
                else:
                    new_vp[v] = p
            if set(expr_vp) == set(new_vp):
                # Nothing new was requested: the inner Order is the answer.
                return expr
            else:
                variables = list(new_vp)
                point = [new_vp[v] for v in variables]

        if expr is nan:
            return nan

        # A limit point must not depend on any of the order variables.
        if any(x in p.free_symbols for x in variables for p in point):
            raise ValueError(f'Got {point} as a point.')

        if variables:
            # Only a single common limit point is supported.
            if any(p != point[0] for p in point):
                raise NotImplementedError
            # Move the limit point to 0: ``s`` replaces each variable by an
            # inverted (for +-oo) or shifted (for a finite point) Dummy, and
            # ``rs`` maps those Dummies back to the original variables.
            if point[0] in [oo, -oo]:
                s = {k: 1/Dummy() for k in variables}
                rs = {1/v: 1/k for k, v in s.items()}
            elif point[0] != 0:
                s = {k: Dummy() + point[0] for k in variables}
                rs = {v - point[0]: k - point[0] for k, v in s.items()}
            else:
                s = ()
                rs = ()

            expr = expr.subs(s)

            if expr.is_Add:
                from ..core import expand_multinomial
                expr = expand_multinomial(expr)

            # The symbols the leading term is taken in: the Dummies (keys of
            # ``rs``) if a substitution was made, else the variables.
            if s:
                args = tuple(r[0] for r in rs.items())
            else:
                args = tuple(variables)

            if len(variables) > 1:
                # XXX: better way?  We need this expand() to
                # workaround e.g: expr = x*(x + y).
                # (x*(x + y)).as_leading_term(x, y) currently returns
                # x*y (wrong order term!).  That's why we want to deal with
                # expand()'ed expr (handled in "if expr.is_Add" branch below).
                expr = expr.expand()

            if expr.is_Add:
                # Keep only the leading-order terms of the sum.
                lst = expr.extract_leading_order(args)
                expr = Add(*[f.expr for (e, f) in lst])

            elif expr:
                # Take the leading term and drop factors that do not
                # depend on the order symbols.
                expr = expr.as_leading_term(*args)
                expr = expr.as_independent(*args, as_Add=False)[1]

                expr = expand_power_base(expr)
                expr = expand_log(expr)

                if len(args) == 1:
                    # The definition of O(f(x)) symbol explicitly stated that
                    # the argument of f(x) is irrelevant.  That's why we can
                    # combine some power exponents (only "on top" of the
                    # expression tree for f(x)), e.g.:
                    # x**p * (-x)**q -> x**(p+q) for real p, q.
                    x = args[0]
                    margs = list(Mul.make_args(
                        expr.as_independent(x, as_Add=False)[1]))

                    for i, t in enumerate(margs):
                        if t.is_Pow:
                            b, q = t.base, t.exp
                            # (+-x)**q  ->  x**q
                            if b in (x, -x) and q.is_extended_real and not q.has(x):
                                margs[i] = x**q
                            # ((+-x)**r)**q  ->  x**(r*q)
                            elif b.is_Pow and not b.exp.has(x):
                                b, r = b.base, b.exp
                                if b in (x, -x) and r.is_extended_real:
                                    margs[i] = x**(r*q)
                            # (-(+-x)**r)**q  ->  x**(r*q)
                            elif b.is_Mul and b.args[0] == -1:
                                b = -b
                                if b.is_Pow and not b.exp.has(x):
                                    b, r = b.base, b.exp
                                    if b in (x, -x) and r.is_extended_real:
                                        margs[i] = x**(r*q)

                    expr = Mul(*margs)

            # Undo the change of variables made with ``s``.
            expr = expr.subs(rs)

        if expr == 0:
            return expr

        if expr.is_Order:
            expr = expr.expr

        # An expression free of the order variables normalizes to O(1).
        if not expr.has(*variables):
            expr = Integer(1)

        # create Order instance:
        vp = dict(zip(variables, point))
        variables.sort(key=default_sort_key)
        point = [vp[v] for v in variables]
        args = (expr,) + Tuple(*zip(variables, point))
        obj = Expr.__new__(cls, *args)
        return obj

    def _eval_nseries(self, x, n, logx):
        # An Order term is left as-is by series expansion.
        return self

    @property
    def expr(self):
        """The leading-term expression wrapped by this Order."""
        return self.args[0]

    @property
    def variables(self):
        """Tuple of the order variables (empty for e.g. ``O(1)``)."""
        if self.args[1:]:
            return tuple(x[0] for x in self.args[1:])
        else:
            return ()

    @property
    def point(self):
        """Tuple of limit points, parallel to ``variables``."""
        if self.args[1:]:
            return tuple(x[1] for x in self.args[1:])
        else:
            return ()

    @property
    def free_symbols(self):
        # Order variables count as free even if absent from ``expr``.
        return self.expr.free_symbols | set(self.variables)

    def _eval_power(self, other):
        # ``other`` is the exponent.  A nonnegative numeric power is applied
        # to the wrapped expression; for any other exponent not handled
        # below, None is returned.
        if other.is_Number and other.is_nonnegative:
            return self.func(self.expr**other, *self.args[1:])
        if other == O(1):
            return self

    def as_expr_variables(self, order_symbols):
        """Return ``(self.expr, pairs)`` where ``pairs`` is a tuple of
        (symbol, point) pairs.

        If ``order_symbols`` is None, the pairs are this Order's own.
        Otherwise the given pairs are merged with this Order's (given
        pairs take precedence) and sorted by symbol.  Raises
        NotImplementedError if the given points differ from this
        Order's point.
        """
        if order_symbols is None:
            order_symbols = self.args[1:]
        else:
            if (not all(o[1] == order_symbols[0][1] for o in order_symbols) and
                    not all(p == self.point[0] for p in self.point)):  # pragma: no cover
                raise NotImplementedError('Order at points other than 0 '
                                          f'or oo not supported, got {self.point} as a point.')
            if order_symbols and order_symbols[0][1] != self.point[0]:
                raise NotImplementedError(
                    'Multiplying Order at different points is not supported.')
            order_symbols = dict(order_symbols)
            for s, p in dict(self.args[1:]).items():
                if s not in order_symbols:
                    order_symbols[s] = p
            order_symbols = sorted(order_symbols.items(), key=lambda x: default_sort_key(x[0]))
        return self.expr, tuple(order_symbols)

    def removeO(self):
        # Dropping the Order term leaves nothing.
        return Integer(0)

    def getO(self):
        # The Order part of an Order is itself.
        return self

    @cacheit
    def contains(self, expr):
        """Membership test.

        Returns
        =======

        Boolean or None
            Return True if ``expr`` belongs to ``self``. Return False if
            ``self`` belongs to ``expr``. Return None if the inclusion
            relation cannot be determined.

        """
        from ..simplify import powsimp
        from .limits import Limit
        if expr == 0:
            return True
        if expr is nan:
            return False
        if expr.is_Order:
            if (not all(p == expr.point[0] for p in expr.point) and
                    not all(p == self.point[0] for p in self.point)):  # pragma: no cover
                raise NotImplementedError('Order at points other than 0 '
                                          f'or oo not supported, got {self.point} as a point.')
            else:
                # self and/or expr is O(1):
                if any(not p for p in [expr.point, self.point]):
                    point = self.point + expr.point
                    if point:
                        point = point[0]
                    else:
                        point = Integer(0)
                else:
                    point = self.point[0]
            if expr.expr == self.expr:
                # O(1) + O(1), O(1) + O(1, x), etc.
                return all(x in self.args[1:] for x in expr.args[1:])
            if expr.expr.is_Add:
                # A sum is contained iff every summand is.
                return all(self.contains(x) for x in expr.expr.args)
            if self.expr.is_Add and point == 0:
                return any(self.func(x, *self.args[1:]).contains(expr)
                           for x in self.expr.args)
            if self.variables and expr.variables:
                common_symbols = tuple(s for s in self.variables if s in expr.variables)
            elif self.variables:
                common_symbols = self.variables
            else:
                common_symbols = expr.variables
            if not common_symbols:
                return
            # Compare growth via lim self.expr/expr.expr in each common
            # symbol: a nonzero limit means ``expr`` is in ``self``, a zero
            # limit means it is not; an unevaluated Limit gives None.  All
            # symbols must agree, otherwise the result is None.
            r = None
            ratio = self.expr/expr.expr
            ratio = powsimp(ratio, deep=True, combine='exp')
            for s in common_symbols:
                l = Limit(ratio, s, point).doit(heuristics=False)
                if not isinstance(l, Limit):
                    l = l != 0
                else:
                    l = None
                if r is None:
                    r = l
                else:
                    if r != l:
                        return
            return r
        # A plain expression is tested by wrapping it at self's variables.
        obj = self.func(expr, *self.args[1:])
        return self.contains(obj)

    def __contains__(self, other):
        # ``in`` needs a definite answer; undecidable inclusion is an error.
        result = self.contains(other)
        if result is None:
            raise TypeError('contains did not evaluate to a bool')
        return result

    def _eval_subs(self, old, new):
        # Substituting an order variable also updates the variable list and,
        # when needed, the limit point.  Substitutions of anything else fall
        # through (None is returned).
        if old in self.variables:
            newexpr = self.expr.subs({old: new})
            i = self.variables.index(old)
            newvars = list(self.variables)
            newpt = list(self.point)
            if new.is_Symbol:
                # Plain renaming of the variable; the point is unchanged.
                newvars[i] = new
            else:
                syms = new.free_symbols
                if len(syms) == 1 or old in syms:
                    if old in syms:
                        var = self.variables[i]
                    else:
                        var = syms.pop()
                    # First, try to substitute self.point in the "new"
                    # expr to see if this is a fixed point.
                    # E.g.  O(y).subs({y: sin(x)})
                    point = new.subs({var: self.point[i]})
                    if point != self.point[i]:
                        # Otherwise invert ``old = new(var)`` and take the
                        # limit to find the corresponding point for ``var``.
                        from ..solvers import solve
                        d = Dummy()
                        res = solve(old - new.subs({var: d}), d)
                        point = d.subs(res[0]).limit(old, self.point[i])
                    newvars[i] = var
                    newpt[i] = point
                else:
                    # The variable is replaced by an expression in several
                    # other symbols: drop it from the Order.
                    del newvars[i], newpt[i]
                    # NOTE(review): ``syms`` is empty whenever this branch
                    # runs, so both extend() calls below add nothing.
                    if not syms and new == self.point[i]:
                        newvars.extend(syms)
                        newpt.extend([Integer(0)]*len(syms))
            return Order(newexpr, *zip(newvars, newpt))

    def _eval_conjugate(self):
        # Conjugate the wrapped expression, if it knows how.
        expr = self.expr._eval_conjugate()
        if expr is not None:
            return self.func(expr, *self.args[1:])

    def _eval_transpose(self):
        # Transpose the wrapped expression, if it knows how.
        expr = self.expr._eval_transpose()
        if expr is not None:
            return self.func(expr, *self.args[1:])

    def _eval_is_commutative(self):
        return self.expr.is_commutative


# Short alias used throughout (including in the doctests above).
O = Order