1from ..core import (Add, Dummy, Expr, Integer, Mul, Symbol, Tuple, cacheit,
2                    expand_log, expand_power_base, nan, oo)
3from ..core.compatibility import is_sequence
4from ..core.sympify import sympify
5from ..utilities import default_sort_key
6from ..utilities.iterables import uniq
7
8
class Order(Expr):
    r"""Represents the limiting behavior of a function.

    The formal definition for order symbol `O(f(x))` (Big O) is
    that `g(x) \in O(f(x))` as `x\to a` iff

    .. math:: \limsup_{x \rightarrow a}
              \left|\frac{g(x)}{f(x)}\right| < \infty

    Parameters
    ==========

    expr : Expr
        an expression
    args : sequence of Symbols or (Symbol, Expr) pairs, optional
        If only symbols are provided, i.e. no limit points are
        passed, then the limit point is assumed to be zero.  If no
        symbols are passed then all symbols in the expression are used.

    Examples
    ========

    The order of a function can be intuitively thought of as representing
    all terms of powers greater than or equal to the one specified.  For
    example, `O(x^3)` corresponds to any terms proportional to
    `x^3, x^4,\ldots` and any higher power.  For a polynomial, this leaves
    terms proportional to `x^2`, `x` and constants.

    >>> 1 + x + x**2 + x**3 + x**4 + O(x**3)
    1 + x + x**2 + O(x**3)

    ``O(f(x))`` is automatically transformed to ``O(f(x).as_leading_term(x))``:

    >>> O(x + x**2)
    O(x)
    >>> O(cos(x))
    O(1)

    Some arithmetic operations:

    >>> O(x)*x
    O(x**2)
    >>> O(x) - O(x)
    O(x)

    The Big O symbol is a set, so we support membership tests:

    >>> x in O(x)
    True
    >>> O(1) in O(1, x)
    True
    >>> O(1, x) in O(1)
    False
    >>> O(x) in O(1, x)
    True
    >>> O(x**2) in O(x)
    True

    Limit points other than zero and multivariate Big O are also supported:

    >>> O(x) == O(x, (x, 0))
    True
    >>> O(x + x**2, (x, oo))
    O(x**2, (x, oo))
    >>> O(cos(x), (x, pi/2))
    O(x - pi/2, (x, pi/2))

    >>> O(1 + x*y)
    O(1, x, y)
    >>> O(1 + x*y, (x, 0), (y, 0))
    O(1, x, y)
    >>> O(1 + x*y, (x, oo), (y, oo))
    O(x*y, (x, oo), (y, oo))

    References
    ==========

    * https://en.wikipedia.org/wiki/Big_O_notation

    """

    is_Order = True

    @cacheit
    def __new__(cls, expr, *args, **kwargs):
        """Create an Order term, reducing ``expr`` to its leading order.

        Returns
        =======

        Expr
            An ``Order`` instance in canonical form, ``nan`` if ``expr``
            is ``nan``, zero if the leading order of ``expr`` is zero, or
            ``expr`` itself if it already is an Order over (at least) the
            requested variables.

        Raises
        ======

        TypeError
            If a variable is not a Symbol or Dummy.
        ValueError
            If a variable is repeated, or a limit point depends on one
            of the variables.
        NotImplementedError
            If the variables have different limit points.

        """
        expr = sympify(expr)

        # Collect the variables and their limit points.
        if not args:
            if expr.is_Order:
                variables = expr.variables
                point = expr.point
            else:
                variables = list(expr.free_symbols)
                point = [Integer(0)]*len(variables)
        else:
            # ``args`` is always a tuple here, since it comes from ``*args``.
            args = list(args)
            variables, point = [], []
            if is_sequence(args[0]):
                # (symbol, point) pairs
                for a in args:
                    v, p = list(map(sympify, a))
                    variables.append(v)
                    point.append(p)
            else:
                # bare symbols: the limit point defaults to zero
                variables = list(map(sympify, args))
                point = [Integer(0)]*len(variables)

        if not all(isinstance(v, (Dummy, Symbol)) for v in variables):
            raise TypeError(f'Variables are not symbols, got {variables}')

        if len(list(uniq(variables))) != len(variables):
            raise ValueError(f'Variables are supposed to be unique symbols, got {variables}')

        if expr.is_Order:
            # Merge the variables of the given Order with the requested ones.
            # If no new variable is added, the Order is returned unchanged.
            expr_vp = dict(expr.args[1:])
            new_vp = dict(expr_vp)
            vp = dict(zip(variables, point))
            for v, p in vp.items():
                if v in new_vp:
                    if p != new_vp[v]:
                        raise NotImplementedError(
                            'Mixing Order at different points is not supported.')
                else:
                    new_vp[v] = p
            if set(expr_vp) == set(new_vp):
                return expr
            else:
                variables = list(new_vp)
                point = [new_vp[v] for v in variables]

        if expr is nan:
            return nan

        if any(x in p.free_symbols for x in variables for p in point):
            raise ValueError(f'Got {point} as a point.')

        if variables:
            if any(p != point[0] for p in point):
                raise NotImplementedError(
                    'Mixing Order at different points is not supported.')
            # Move the limit point to 0: ``s`` rewrites each variable in
            # terms of a fresh Dummy and ``rs`` maps that Dummy back to an
            # expression in the variable (1/x for infinite points, x - a
            # for a finite nonzero point a).
            if point[0] in [oo, -oo]:
                s = {k: 1/Dummy() for k in variables}
                rs = {1/v: 1/k for k, v in s.items()}
            elif point[0] != 0:
                s = {k: Dummy() + point[0] for k in variables}
                rs = {v - point[0]: k - point[0] for k, v in s.items()}
            else:
                s = ()
                rs = ()

            expr = expr.subs(s)

            if expr.is_Add:
                from ..core import expand_multinomial
                expr = expand_multinomial(expr)

            # The leading order is computed with respect to the Dummies
            # if a substitution was made, else the variables themselves.
            if s:
                args = tuple(rs)
            else:
                args = tuple(variables)

            if len(variables) > 1:
                # XXX: better way?  We need this expand() to
                # workaround e.g: expr = x*(x + y).
                # (x*(x + y)).as_leading_term(x, y) currently returns
                # x*y (wrong order term!).  That's why we want to deal with
                # expand()'ed expr (handled in "if expr.is_Add" branch below).
                expr = expr.expand()

            if expr.is_Add:
                # Keep only the terms of leading order.
                lst = expr.extract_leading_order(args)
                expr = Add(*[f.expr for (e, f) in lst])

            elif expr:
                # Keep the leading term, dropping any factor that does not
                # depend on ``args``.
                expr = expr.as_leading_term(*args)
                expr = expr.as_independent(*args, as_Add=False)[1]

                expr = expand_power_base(expr)
                expr = expand_log(expr)

                if len(args) == 1:
                    # The definition of O(f(x)) symbol explicitly stated that
                    # the argument of f(x) is irrelevant.  That's why we can
                    # combine some power exponents (only "on top" of the
                    # expression tree for f(x)), e.g.:
                    # x**p * (-x)**q -> x**(p+q) for real p, q.
                    x = args[0]
                    margs = list(Mul.make_args(
                        expr.as_independent(x, as_Add=False)[1]))

                    for i, t in enumerate(margs):
                        if t.is_Pow:
                            b, q = t.base, t.exp
                            if b in (x, -x) and q.is_extended_real and not q.has(x):
                                margs[i] = x**q
                            elif b.is_Pow and not b.exp.has(x):
                                b, r = b.base, b.exp
                                if b in (x, -x) and r.is_extended_real:
                                    margs[i] = x**(r*q)
                            elif b.is_Mul and b.args[0] == -1:
                                b = -b
                                if b.is_Pow and not b.exp.has(x):
                                    b, r = b.base, b.exp
                                    if b in (x, -x) and r.is_extended_real:
                                        margs[i] = x**(r*q)

                    expr = Mul(*margs)

            # Undo the substitution of the limit point.
            expr = expr.subs(rs)

        if expr == 0:
            return expr

        if expr.is_Order:
            expr = expr.expr

        # An expression free of the variables is a constant: use O(1).
        if not expr.has(*variables):
            expr = Integer(1)

        # Create the Order instance.  The (variable, point) pairs are stored
        # sorted by ``default_sort_key`` of the variable, giving a canonical
        # argument order.
        vp = dict(zip(variables, point))
        variables.sort(key=default_sort_key)
        point = [vp[v] for v in variables]
        args = (expr,) + Tuple(*zip(variables, point))
        obj = Expr.__new__(cls, *args)
        return obj

    def _eval_nseries(self, x, n, logx):
        # An Order term is returned unchanged by series expansion.
        return self

    @property
    def expr(self):
        """The expression ``f`` of ``O(f)``."""
        return self.args[0]

    @property
    def variables(self):
        """Tuple of the Order's variables (empty for e.g. ``O(1)``)."""
        return tuple(pair[0] for pair in self.args[1:])

    @property
    def point(self):
        """Tuple of limit points, in the same order as ``variables``."""
        return tuple(pair[1] for pair in self.args[1:])

    @property
    def free_symbols(self):
        """Free symbols of ``expr`` together with the Order's variables."""
        return self.expr.free_symbols | set(self.variables)

    def _eval_power(self, other):
        """Evaluate ``self**other``.

        Nonnegative numeric exponents give ``O(expr**other)`` over the same
        variables, and an exponent equal to ``O(1)`` leaves ``self``
        unchanged.  Otherwise None is returned (the power stays unevaluated).

        """
        if other.is_Number and other.is_nonnegative:
            return self.func(self.expr**other, *self.args[1:])
        if other == O(1):
            return self

    def as_expr_variables(self, order_symbols):
        """Return ``self.expr`` with variables merged into ``order_symbols``.

        Parameters
        ==========

        order_symbols : sequence of (Symbol, point) pairs or None
            Variables collected so far.  If None, the variables of ``self``
            are used as they are.

        Returns
        =======

        tuple
            ``(self.expr, pairs)``.  When ``order_symbols`` is given,
            ``pairs`` holds its entries plus any variables of ``self`` not
            already present, sorted by ``default_sort_key`` of the symbol.

        Raises
        ======

        NotImplementedError
            If the limit point in ``order_symbols`` differs from that of
            ``self``.

        """
        if order_symbols is None:
            order_symbols = self.args[1:]
        else:
            if (not all(o[1] == order_symbols[0][1] for o in order_symbols) and
                    not all(p == self.point[0] for p in self.point)):  # pragma: no cover
                raise NotImplementedError('Order at points other than 0 '
                                          f'or oo not supported, got {self.point} as a point.')
            # An Order without variables (e.g. O(1)) has no limit point that
            # could conflict; without the ``self.point`` guard,
            # ``self.point[0]`` would raise IndexError here.
            if (order_symbols and self.point and
                    order_symbols[0][1] != self.point[0]):
                raise NotImplementedError(
                    'Multiplying Order at different points is not supported.')
            order_symbols = dict(order_symbols)
            for s, p in dict(self.args[1:]).items():
                if s not in order_symbols:
                    order_symbols[s] = p
            order_symbols = sorted(order_symbols.items(), key=lambda x: default_sort_key(x[0]))
        return self.expr, tuple(order_symbols)

    def removeO(self):
        """Return zero: an Order term has no non-Order part."""
        return Integer(0)

    def getO(self):
        """Return ``self``: an Order term is entirely its own Order part."""
        return self

    @cacheit
    def contains(self, expr):
        """Membership test.

        Returns
        =======

        Boolean or None
            Return True if ``expr`` belongs to ``self``.  Return False if
            ``self`` belongs to ``expr``.  Return None if the inclusion
            relation cannot be determined.

        """
        from ..simplify import powsimp
        from .limits import Limit
        if expr == 0:
            return True
        if expr is nan:
            return False
        if expr.is_Order:
            if (not all(p == expr.point[0] for p in expr.point) and
                    not all(p == self.point[0] for p in self.point)):  # pragma: no cover
                raise NotImplementedError('Order at points other than 0 '
                                          f'or oo not supported, got {self.point} as a point.')
            else:
                # self and/or expr is O(1):
                if any(not p for p in [expr.point, self.point]):
                    point = self.point + expr.point
                    if point:
                        point = point[0]
                    else:
                        point = Integer(0)
                else:
                    point = self.point[0]
            if expr.expr == self.expr:
                # O(1) + O(1), O(1) + O(1, x), etc.
                return all(x in self.args[1:] for x in expr.args[1:])
            if expr.expr.is_Add:
                return all(self.contains(x) for x in expr.expr.args)
            if self.expr.is_Add and point == 0:
                return any(self.func(x, *self.args[1:]).contains(expr)
                           for x in self.expr.args)
            if self.variables and expr.variables:
                common_symbols = tuple(s for s in self.variables if s in expr.variables)
            elif self.variables:
                common_symbols = self.variables
            else:
                common_symbols = expr.variables
            if not common_symbols:
                return
            # Compare growth via the limit of self.expr/expr.expr along each
            # common symbol: a nonzero limit gives True, a zero limit gives
            # False, an unevaluated limit gives None, and two different
            # answers give None.
            # NOTE(review): a None from the first symbol is overwritten by a
            # later determinate answer, while a None after a determinate one
            # returns None -- the result can depend on symbol order.
            r = None
            ratio = self.expr/expr.expr
            ratio = powsimp(ratio, deep=True, combine='exp')
            for s in common_symbols:
                l = Limit(ratio, s, point).doit(heuristics=False)
                if not isinstance(l, Limit):
                    l = l != 0
                else:
                    l = None
                if r is None:
                    r = l
                else:
                    if r != l:
                        return
            return r
        obj = self.func(expr, *self.args[1:])
        return self.contains(obj)

    def __contains__(self, other):
        """Support ``other in self``; raise TypeError if undetermined."""
        result = self.contains(other)
        if result is None:
            raise TypeError('contains did not evaluate to a bool')
        return result

    def _eval_subs(self, old, new):
        """Substitute ``new`` for ``old`` when ``old`` is an Order variable.

        Besides ``expr``, the corresponding variable and limit point are
        updated.  Returns None when ``old`` is not one of ``variables``, so
        that the default substitution applies.

        """
        if old in self.variables:
            newexpr = self.expr.subs({old: new})
            i = self.variables.index(old)
            newvars = list(self.variables)
            newpt = list(self.point)
            if new.is_Symbol:
                newvars[i] = new
            else:
                syms = new.free_symbols
                if len(syms) == 1 or old in syms:
                    if old in syms:
                        var = self.variables[i]
                    else:
                        var = syms.pop()
                    # First, try to substitute self.point in the "new"
                    # expr to see if this is a fixed point.
                    # E.g.  O(y).subs({y: sin(x)})
                    point = new.subs({var: self.point[i]})
                    if point != self.point[i]:
                        from ..solvers import solve
                        d = Dummy()
                        res = solve(old - new.subs({var: d}), d)
                        point = d.subs(res[0]).limit(old, self.point[i])
                    newvars[i] = var
                    newpt[i] = point
                else:
                    # ``new`` has no free symbols or several of them (none
                    # being ``old``): drop this variable.
                    del newvars[i], newpt[i]
            # The substitution may map this variable onto one that is
            # already present, e.g. O(x*y, x, y).subs({x: y}).  Merge
            # identical (variable, point) pairs, because the constructor
            # rejects repeated variables.  Pairs that share a variable but
            # have different points are kept, so the constructor still
            # raises for them.
            pairs = []
            for pair in zip(newvars, newpt):
                if pair not in pairs:
                    pairs.append(pair)
            return Order(newexpr, *pairs)

    def _eval_conjugate(self):
        """Conjugate ``expr``, if it knows how; otherwise return None."""
        expr = self.expr._eval_conjugate()
        if expr is not None:
            return self.func(expr, *self.args[1:])

    def _eval_transpose(self):
        """Transpose ``expr``, if it knows how; otherwise return None."""
        expr = self.expr._eval_transpose()
        if expr is not None:
            return self.func(expr, *self.args[1:])

    def _eval_is_commutative(self):
        # Commutativity is inherited from the wrapped expression.
        return self.expr.is_commutative
408
409
# Short alias for ``Order``, so that Big-O terms can be written as ``O(x)``.
O = Order
411