1"""
2There are three types of functions implemented in SymPy:
3
4    1) defined functions (in the sense that they can be evaluated) like
5       exp or sin; they have a name and a body:
6           f = exp
    2) undefined functions, which have a name but no body. Undefined
8       functions can be defined using a Function class as follows:
9           f = Function('f')
10       (the result will be a Function instance)
    3) anonymous functions (or lambda functions), which have a body (defined
12       with dummy variables) but have no name:
13           f = Lambda(x, exp(x)*x)
14           f = Lambda((x, y), exp(x)*y)
15    The fourth type of functions are composites, like (sin + cos)(x); these work in
16    SymPy core, but are not yet part of SymPy.
17
18    Examples
19    ========
20
21    >>> import sympy
22    >>> f = sympy.Function("f")
23    >>> from sympy.abc import x
24    >>> f(x)
25    f(x)
26    >>> print(sympy.srepr(f(x).func))
27    Function('f')
28    >>> f(x).args
29    (x,)
30
31"""
32
33from typing import Any, Dict as tDict, Optional, Set as tSet, Tuple as tTuple, Union
34
35from .add import Add
36from .assumptions import ManagedProperties
37from .basic import Basic, _atomic
38from .cache import cacheit
39from .compatibility import iterable, is_sequence, as_int, ordered, Iterable
40from .decorators import _sympifyit
41from .expr import Expr, AtomicExpr
42from .numbers import Rational, Float
43from .operations import LatticeOp
44from .rules import Transform
45from .singleton import S
46from .sympify import sympify
47
48from sympy.core.containers import Tuple, Dict
49from sympy.core.parameters import global_parameters
50from sympy.core.logic import fuzzy_and, fuzzy_or, fuzzy_not, FuzzyBool
51from sympy.utilities import default_sort_key
52from sympy.utilities.exceptions import SymPyDeprecationWarning
53from sympy.utilities.iterables import has_dups, sift
54from sympy.utilities.misc import filldedent
55
56import mpmath
57import mpmath.libmp as mlib
58
59import inspect
60from collections import Counter
61
def _coeff_isneg(a):
    """Return True if the leading Number is negative.

    The leading factor is found by descending at most one MatMul layer and
    then at most one Mul layer (always taking the first argument); only if
    that factor is a Number is its sign consulted.

    Examples
    ========

    >>> from sympy.core.function import _coeff_isneg
    >>> from sympy import S, Symbol, oo, pi
    >>> _coeff_isneg(-3*pi)
    True
    >>> _coeff_isneg(S(3))
    False
    >>> _coeff_isneg(-oo)
    True
    >>> _coeff_isneg(Symbol('n', negative=True)) # coeff is 1
    False

    For matrix expressions:

    >>> from sympy import MatrixSymbol, sqrt
    >>> A = MatrixSymbol("A", 3, 3)
    >>> _coeff_isneg(-sqrt(2)*A)
    True
    >>> _coeff_isneg(sqrt(2)*A)
    False
    """
    lead = a
    # The first argument of a MatMul may itself be a Mul (e.g. -sqrt(2)*A),
    # so both layers are peeled in turn.
    if lead.is_MatMul:
        lead = lead.args[0]
    if lead.is_Mul:
        lead = lead.args[0]
    # Non-numbers short-circuit to their (falsy) ``is_Number`` value.
    if not lead.is_Number:
        return lead.is_Number
    return lead.is_extended_negative
94
95
class PoleError(Exception):
    """Raised when a series or asymptotic expansion cannot be computed,
    e.g. when an expansion point turns out to be a pole."""
98
99
class ArgumentIndexError(ValueError):
    """Raised when a function is asked about an argument number it lacks.

    Conventionally raised as ``ArgumentIndexError(function, argindex)``
    (for example by ``Function.fdiff`` when ``argindex`` is out of range).
    """

    def __str__(self):
        if len(self.args) < 2:
            # Not raised with the conventional (function, index) pair:
            # use the default exception text instead of failing with an
            # IndexError while formatting the message.
            return super().__str__()
        return ("Invalid operation with argument number %s for Function %s" %
               (self.args[1], self.args[0]))
104
105
class BadSignatureError(TypeError):
    """Signals that a Lambda was constructed with an invalid signature."""
109
110
class BadArgumentsError(TypeError):
    """Signals that a Lambda was called with the wrong number of arguments."""
114
115
# Uses inspect.signature, which does not raise a Deprecation warning.
def arity(cls):
    """Return the arity of the function if it is known, else None.

    Explanation
    ===========

    When default values are specified for some arguments, they are
    optional and the arity is reported as a tuple of possible values.

    If ``cls`` has an ``eval`` attribute, the signature of ``eval`` is
    inspected instead of ``cls`` itself. Positional-only and
    positional-or-keyword parameters are counted; keyword-only parameters
    are ignored. ``None`` is returned when the callable accepts ``*args``
    or when no signature can be determined for it.

    Examples
    ========

    >>> from sympy.core.function import arity
    >>> from sympy import log
    >>> arity(lambda x: x)
    1
    >>> arity(log)
    (1, 2)
    >>> arity(lambda *x: sum(x)) is None
    True
    """
    eval_ = getattr(cls, 'eval', cls)

    try:
        parameters = inspect.signature(eval_).parameters.values()
    except ValueError:
        # No signature is available (e.g. for some builtins): arity unknown.
        return None

    positional_kinds = (inspect.Parameter.POSITIONAL_ONLY,
                        inspect.Parameter.POSITIONAL_OR_KEYWORD)
    required = optional = 0
    for p in parameters:
        if p.kind == p.VAR_POSITIONAL:
            # any number of arguments is accepted
            return None
        if p.kind in positional_kinds:
            # Identity test: ``==`` against an arbitrary default value
            # (e.g. an array) need not produce a plain bool.
            if p.default is p.empty:
                required += 1
            else:
                optional += 1
    if not optional:
        return required
    return tuple(range(required, required + optional + 1))
148
class FunctionClass(ManagedProperties):
    """
    Base class for function classes. FunctionClass is a subclass of type.

    Use Function('<function name>' [ , signature ]) to create
    undefined function classes.
    """
    _new = type.__new__

    def __init__(cls, *args, **kwargs):
        # Determine nargs, honoring (in this order) an explicit ``nargs``
        # keyword, a ``nargs`` attribute defined on the class itself, and
        # finally the arity of the class's ``eval``. ``arity`` inspects a
        # signature, so it is only called when the first two are absent.
        if 'nargs' in kwargs:
            nargs = kwargs.pop('nargs')
        elif 'nargs' in cls.__dict__:
            nargs = cls.__dict__['nargs']
        else:
            nargs = arity(cls)
        if nargs is None and 'nargs' not in cls.__dict__:
            # Inherit the already-canonicalized value from a base class.
            # Class attribute lookup follows ``cls.__mro__``, so this picks
            # the ``_nargs`` of the nearest class that defines it.
            nargs = getattr(cls, '_nargs', None)

        # Canonicalize nargs here; change to set in nargs.
        if is_sequence(nargs):
            if not nargs:
                raise ValueError(filldedent('''
                    Incorrectly specified nargs as %s:
                    if there are no arguments, it should be
                    `nargs = 0`;
                    if there are any number of arguments,
                    it should be
                    `nargs = None`''' % str(nargs)))
            nargs = tuple(ordered(set(nargs)))
        elif nargs is not None:
            nargs = (as_int(nargs),)
        # None means "any number of arguments"; otherwise a sorted tuple.
        cls._nargs = nargs

        super().__init__(*args, **kwargs)

    @property
    def __signature__(self):
        """
        Allow Python 3's inspect.signature to give a useful signature for
        Function subclasses.
        """
        # TODO: Look at nargs
        return inspect.signature(self.eval)

    @property
    def free_symbols(self):
        # An unapplied function class contains no symbols.
        return set()

    @property
    def xreplace(self):
        # Function needs args so we define a property that returns
        # a function that takes args...and then use that function
        # to return the right value
        return lambda rule, **_: rule.get(self, self)

    @property
    def nargs(self):
        """Return a set of the allowed number of arguments for the function.

        Examples
        ========

        >>> from sympy.core.function import Function
        >>> f = Function('f')

        If the function can take any number of arguments, the set of whole
        numbers is returned:

        >>> Function('f').nargs
        Naturals0

        If the function was initialized to accept one or more arguments, a
        corresponding set will be returned:

        >>> Function('f', nargs=1).nargs
        {1}
        >>> Function('f', nargs=(2, 1)).nargs
        {1, 2}

        The undefined function, after application, also has the nargs
        attribute; the actual number of arguments is always available by
        checking the ``args`` attribute:

        >>> f = Function('f')
        >>> f(1).nargs
        Naturals0
        >>> len(f(1).args)
        1
        """
        from sympy.sets.sets import FiniteSet
        # XXX it would be nice to handle this in __init__ but there are import
        # problems with trying to import FiniteSet there
        return FiniteSet(*self._nargs) if self._nargs else S.Naturals0

    def __repr__(cls):
        return cls.__name__
255
256
class Application(Basic, metaclass=FunctionClass):
    """
    Base class for applied functions.

    Explanation
    ===========

    Instances of Application represent the result of applying a function
    (of any kind) to its arguments.
    """

    is_Function = True

    @cacheit
    def __new__(cls, *args, **options):
        from sympy.sets.fancysets import Naturals0
        from sympy.sets.sets import FiniteSet

        args = [sympify(arg) for arg in args]
        evaluate = options.pop('evaluate', global_parameters.evaluate)
        # A ``nargs`` option (WildFunction and similar classes may pass one)
        # is accepted but deliberately discarded here.
        options.pop('nargs', None)
        if options:
            raise ValueError("Unknown options: %s" % options)

        if evaluate:
            evaluated = cls.eval(*args)
            if evaluated is not None:
                return evaluated

        obj = super().__new__(cls, *args, **options)

        # Normalize ``nargs`` on the new instance. A class-level ``nargs``
        # attribute (e.g. cos, Function subclasses, or an AppliedUndef built
        # with nargs=...) is canonicalized; when there is none (e.g. an
        # AppliedUndef without nargs, WildFunction) the ``_nargs`` value
        # computed by FunctionClass is used instead.
        missing = object()
        declared = getattr(obj, "nargs", missing)
        if declared is missing:
            nargs = obj._nargs  # note the underscore here
        elif is_sequence(declared):
            nargs = tuple(ordered(set(declared)))
        elif declared is None:
            nargs = None
        else:
            nargs = (as_int(declared),)
        # An empty or None spec means "any number of arguments".
        obj.nargs = FiniteSet(*nargs) if nargs else Naturals0()
        return obj

    @classmethod
    def eval(cls, *args):
        """
        Returns a canonical form of cls applied to arguments args.

        Explanation
        ===========

        The eval() method is called when the class cls is about to be
        instantiated and it should return either some simplified instance
        (possibly of some other class), or if the class cls should be
        unmodified, return None.

        Examples of eval() for the function "sign"
        ---------------------------------------------

        .. code-block:: python

            @classmethod
            def eval(cls, arg):
                if arg is S.NaN:
                    return S.NaN
                if arg.is_zero: return S.Zero
                if arg.is_positive: return S.One
                if arg.is_negative: return S.NegativeOne
                if isinstance(arg, Mul):
                    coeff, terms = arg.as_coeff_Mul(rational=True)
                    if coeff is not S.One:
                        return cls(coeff) * cls(terms)

        """
        # Base implementation: never simplifies.
        return None

    @property
    def func(self):
        # The class itself acts as the head of the expression.
        return self.__class__

    def _eval_subs(self, old, new):
        # Replace the function head itself, e.g. f(x).subs(f, g) -> g(x),
        # when ``new`` accepts this many arguments.
        if not (old.is_Function and new.is_Function):
            return None
        if not (callable(old) and callable(new)):
            return None
        if old == self.func and len(self.args) in new.nargs:
            return new(*[arg._subs(old, new) for arg in self.args])
        return None
357
358
class Function(Application, Expr):
    """
    Base class for applied mathematical functions.

    It also serves as a constructor for undefined function classes.

    Examples
    ========

    First example shows how to use Function as a constructor for undefined
    function classes:

    >>> from sympy import Function, Symbol
    >>> x = Symbol('x')
    >>> f = Function('f')
    >>> g = Function('g')(x)
    >>> f
    f
    >>> f(x)
    f(x)
    >>> g
    g(x)
    >>> f(x).diff(x)
    Derivative(f(x), x)
    >>> g.diff(x)
    Derivative(g(x), x)

    Assumptions can be passed to Function, and if function is initialized with a
    Symbol, the function inherits the name and assumptions associated with the Symbol:

    >>> f_real = Function('f', real=True)
    >>> f_real(x).is_real
    True
    >>> f_real_inherit = Function(Symbol('f', real=True))
    >>> f_real_inherit(x).is_real
    True

    Note that assumptions on a function are unrelated to the assumptions on
    the variable it is called on. If you want to add a relationship, subclass
    Function and define the appropriate ``_eval_is_assumption`` methods.

    In the following example Function is used as a base class for
    ``my_func`` that represents a mathematical function *my_func*. Suppose
    that it is well known, that *my_func(0)* is *1* and *my_func* at infinity
    goes to *0*, so we want those two simplifications to occur automatically.
    Suppose also that *my_func(x)* is real exactly when *x* is real. Here is
    an implementation that honours those requirements:

    >>> from sympy import Function, S, oo, I, sin
    >>> class my_func(Function):
    ...
    ...     @classmethod
    ...     def eval(cls, x):
    ...         if x.is_Number:
    ...             if x.is_zero:
    ...                 return S.One
    ...             elif x is S.Infinity:
    ...                 return S.Zero
    ...
    ...     def _eval_is_real(self):
    ...         return self.args[0].is_real
    ...
    >>> x = S('x')
    >>> my_func(0) + sin(0)
    1
    >>> my_func(oo)
    0
    >>> my_func(3.54).n() # Not yet implemented for my_func.
    my_func(3.54)
    >>> my_func(I).is_real
    False

    In order for ``my_func`` to become useful, several other methods would
    need to be implemented. See source code of some of the already
    implemented functions for more complete examples.

    Also, if the function can take more than one argument, then ``nargs``
    must be defined, e.g. if ``my_func`` can take one or two arguments
    then,

    >>> class my_func(Function):
    ...     nargs = (1, 2)
    ...
    >>>

    """

    @property
    def _diff_wrt(self):
        return False

    @cacheit
    def __new__(cls, *args, **options):
        # Handle calls like Function('f')
        if cls is Function:
            return UndefinedFunction(*args, **options)

        n = len(args)
        if n not in cls.nargs:
            # XXX: exception message must be in exactly this format to
            # make it work with NumPy's functions like vectorize(). See,
            # for example, https://github.com/numpy/numpy/issues/1697.
            # The ideal solution would be just to attach metadata to
            # the exception and change NumPy to take advantage of this.
            temp = ('%(name)s takes %(qual)s %(args)s '
                   'argument%(plural)s (%(given)s given)')
            raise TypeError(temp % {
                'name': cls,
                'qual': 'exactly' if len(cls.nargs) == 1 else 'at least',
                'args': min(cls.nargs),
                'plural': 's'*(min(cls.nargs) != 1),
                'given': n})

        evaluate = options.get('evaluate', global_parameters.evaluate)
        result = super().__new__(cls, *args, **options)
        # Auto-evalf only when *every* argument is Float-like (pr2 > 0); the
        # result is then evaluated at the highest argument precision.
        if evaluate and isinstance(result, cls) and result.args:
            pr2 = min(cls._should_evalf(a) for a in result.args)
            if pr2 > 0:
                pr = max(cls._should_evalf(a) for a in result.args)
                result = result.evalf(mlib.libmpf.prec_to_dps(pr))

        return result

    @classmethod
    def _should_evalf(cls, arg):
        """
        Decide if the function should automatically evalf().

        Explanation
        ===========

        By default (in this implementation), this happens if (and only if) the
        ARG is a floating point number.
        This function is used by __new__.

        Returns the precision to evalf to, or -1 if it shouldn't evalf.
        """
        from sympy.core.evalf import pure_complex
        if arg.is_Float:
            return arg._prec
        if not arg.is_Add:
            return -1
        # a + b*I with a Float real or imaginary part also qualifies
        m = pure_complex(arg)
        if m is None or not (m[0].is_Float or m[1].is_Float):
            return -1
        l = [i._prec for i in m if i.is_Float]
        l.append(-1)
        return max(l)

    @classmethod
    def class_key(cls):
        """Sort key for function classes: well-known functions get a fixed
        rank; other functions rank by whether they accept any number of
        arguments."""
        from sympy.sets.fancysets import Naturals0
        funcs = {
            'exp': 10,
            'log': 11,
            'sin': 20,
            'cos': 21,
            'tan': 22,
            'cot': 23,
            'sinh': 30,
            'cosh': 31,
            'tanh': 32,
            'coth': 33,
            'conjugate': 40,
            're': 41,
            'im': 42,
            'arg': 43,
        }
        name = cls.__name__

        try:
            i = funcs[name]
        except KeyError:
            i = 0 if isinstance(cls.nargs, Naturals0) else 10000

        return 4, i, name

    def _eval_evalf(self, prec):

        def _get_mpmath_func(fname):
            """Lookup mpmath function based on name"""
            if isinstance(self, AppliedUndef):
                # Shouldn't lookup in mpmath but might have ._imp_
                return None

            if not hasattr(mpmath, fname):
                from sympy.utilities.lambdify import MPMATH_TRANSLATIONS
                fname = MPMATH_TRANSLATIONS.get(fname, None)
                if fname is None:
                    return None
            return getattr(mpmath, fname)

        _eval_mpmath = getattr(self, '_eval_mpmath', None)
        if _eval_mpmath is None:
            func = _get_mpmath_func(self.func.__name__)
            args = self.args
        else:
            func, args = _eval_mpmath()

        # Fall-back evaluation
        if func is None:
            imp = getattr(self, '_imp_', None)
            if imp is None:
                return None
            try:
                return Float(imp(*[i.evalf(prec) for i in self.args]), prec)
            except (TypeError, ValueError):
                return None

        # Convert all args to mpf or mpc
        # Convert the arguments to *higher* precision than requested for the
        # final result.
        # XXX + 5 is a guess, it is similar to what is used in evalf.py. Should
        #     we be more intelligent about it?
        try:
            args = [arg._to_mpmath(prec + 5) for arg in args]
            def bad(m):
                from mpmath import mpf, mpc
                # the precision of an mpf value is the last element
                # if that is 1 (and m[1] is not 1 which would indicate a
                # power of 2), then the eval failed; so check that none of
                # the arguments failed to compute to a finite precision.
                # Note: An mpc value has two parts, the re and imag tuple;
                # check each of those parts, too. Anything else is allowed to
                # pass
                if isinstance(m, mpf):
                    m = m._mpf_
                    return m[1] !=1 and m[-1] == 1
                elif isinstance(m, mpc):
                    m, n = m._mpc_
                    return m[1] !=1 and m[-1] == 1 and \
                        n[1] !=1 and n[-1] == 1
                else:
                    return False
            if any(bad(a) for a in args):
                raise ValueError  # one or more args failed to compute with significance
        except ValueError:
            return

        with mpmath.workprec(prec):
            v = func(*args)

        return Expr._from_mpmath(v, prec)

    def _eval_derivative(self, s):
        # Chain rule, summed over the arguments that depend on s:
        # f(x).diff(s) -> x.diff(s) * f.fdiff(1)(s)
        terms = []
        for i, a in enumerate(self.args, 1):
            da = a.diff(s)
            if da.is_zero:
                continue
            try:
                df = self.fdiff(i)
            except ArgumentIndexError:
                # subclass fdiff does not handle this index: use the
                # generic (unevaluated) derivative
                df = Function.fdiff(self, i)
            terms.append(df * da)
        return Add(*terms)

    def _eval_is_commutative(self):
        return fuzzy_and(a.is_commutative for a in self.args)

    def _eval_is_meromorphic(self, x, a):
        if not self.args:
            return True
        # only the first argument may depend on x
        if any(arg.has(x) for arg in self.args[1:]):
            return False

        arg = self.args[0]
        if not arg._eval_is_meromorphic(x, a):
            return None

        return fuzzy_not(type(self).is_singular(arg.subs(x, a)))

    _singularities = None  # type: Union[FuzzyBool, tTuple[Expr, ...]]

    @classmethod
    def is_singular(cls, a):
        """
        Tests whether the argument is an essential singularity
        or a branch point, or the function is non-holomorphic.
        """
        ss = cls._singularities
        if ss in (True, None, False):
            return ss

        return fuzzy_or(a.is_infinite if s is S.ComplexInfinity
                        else (a - s).is_zero for s in ss)

    def as_base_exp(self):
        """
        Returns the function as the 2-tuple (base, exponent).
        """
        return self, S.One

    def _eval_aseries(self, n, args0, x, logx):
        """
        Compute an asymptotic expansion around args0, in terms of self.args.
        This function is only used internally by _eval_nseries and should not
        be called directly; derived classes can overwrite this to implement
        asymptotic expansions.
        """
        raise PoleError(filldedent('''
            Asymptotic expansion of %s around %s is
            not implemented.''' % (type(self), args0)))

    def _eval_nseries(self, x, n, logx, cdir=0):
        """
        This function does compute series for multivariate functions,
        but the expansion is always in terms of *one* variable.

        Examples
        ========

        >>> from sympy import atan2
        >>> from sympy.abc import x, y
        >>> atan2(x, y).series(x, n=2)
        atan2(0, y) + x/y + O(x**2)
        >>> atan2(x, y).series(y, n=2)
        -y/x + atan2(x, 0) + O(y**2)

        This function also computes asymptotic expansions, if necessary
        and possible:

        >>> from sympy import loggamma
        >>> loggamma(1/x)._eval_nseries(x,0,None)
        -1/x - log(x)/x + log(x)/2 + O(1)

        """
        from sympy import Order
        from sympy.core.symbol import uniquely_named_symbol
        from sympy.sets.sets import FiniteSet
        args = self.args
        args0 = [t.limit(x, 0) for t in args]
        if any(t.is_finite is False for t in args0):
            from sympy import oo, zoo, nan
            # XXX could use t.as_leading_term(x) here but it's a little
            # slower
            a = [t.compute_leading_term(x, logx=logx) for t in args]
            a0 = [t.limit(x, 0) for t in a]
            if any(t.has(oo, -oo, zoo, nan) for t in a0):
                return self._eval_aseries(n, args0, x, logx)
            # Careful: the argument goes to oo, but only logarithmically so. We
            # are supposed to do a power series expansion "around the
            # logarithmic term". e.g.
            #      f(1+x+log(x))
            #     -> f(1+logx) + x*f'(1+logx) + O(x**2)
            # where 'logx' is given in the argument
            a = [t._eval_nseries(x, n, logx) for t in args]
            z = [r - r0 for (r, r0) in zip(a, a0)]
            p = [Dummy() for _ in z]
            q = []
            v = None
            zv = None  # the x-dependent remainder paired with the Dummy v
            for ai, zi, pi in zip(a0, z, p):
                if zi.has(x):
                    if v is not None:
                        raise NotImplementedError
                    q.append(ai + pi)
                    v, zv = pi, zi
                else:
                    q.append(ai)
            e1 = self.func(*q)
            if v is None:
                return e1
            s = e1._eval_nseries(v, n, logx)
            o = s.getO()
            s = s.removeO()
            # Substitute back the remainder belonging to the argument that
            # was replaced by v (not whatever the loop variable last held).
            s = s.subs(v, zv).expand() + Order(o.expr.subs(v, zv), x)
            return s
        if (self.func.nargs is S.Naturals0
                or (self.func.nargs == FiniteSet(1) and args0[0])
                or any(c > 1 for c in self.func.nargs)):
            e = self
            e1 = e.expand()
            if e == e1:
                #for example when e = sin(x+1) or e = sin(cos(x))
                #let's try the general algorithm
                if len(e.args) == 1:
                    # issue 14411
                    e = e.func(e.args[0].cancel())
                term = e.subs(x, S.Zero)
                if term.is_finite is False or term is S.NaN:
                    raise PoleError("Cannot expand %s around 0" % (self))
                series = term
                fact = S.One

                # Taylor series via repeated differentiation at 0
                _x = uniquely_named_symbol('xi', self)
                e = e.subs(x, _x)
                for i in range(1, n):
                    fact *= Rational(i)
                    e = e.diff(_x)
                    subs = e.subs(_x, S.Zero)
                    if subs is S.NaN:
                        # try to evaluate a limit if we have to
                        subs = e.limit(_x, S.Zero)
                    if subs.is_finite is False:
                        raise PoleError("Cannot expand %s around 0" % (self))
                    term = subs*(x**i)/fact
                    term = term.expand()
                    series += term
                return series + Order(x**n, x)
            return e1.nseries(x, n=n, logx=logx)
        arg = self.args[0]
        l = []
        g = None
        # try to predict a number of terms needed
        nterms = n + 2
        cf = Order(arg.as_leading_term(x), x).getn()
        if cf != 0:
            nterms = (n/cf).ceiling()
        for i in range(nterms):
            g = self.taylor_term(i, arg, g)
            g = g.nseries(x, n=n, logx=logx)
            l.append(g)
        return Add(*l) + Order(x**n, x)

    def fdiff(self, argindex=1):
        """
        Returns the first derivative of the function with respect to its
        ``argindex``-th argument (1-based).

        Raises ArgumentIndexError if ``argindex`` is out of range.
        """
        if not (1 <= argindex <= len(self.args)):
            raise ArgumentIndexError(self, argindex)
        ix = argindex - 1
        A = self.args[ix]
        if A._diff_wrt:
            if len(self.args) == 1 or not A.is_Symbol:
                return _derivative_dispatch(self, A)
            for i, v in enumerate(self.args):
                if i != ix and A in v.free_symbols:
                    # it can't be in any other argument's free symbols
                    # issue 8510
                    break
            else:
                return _derivative_dispatch(self, A)

        # See issue 4624 and issue 4719, 5600 and 8510
        D = Dummy('xi_%i' % argindex, dummy_index=hash(A))
        args = self.args[:ix] + (D,) + self.args[ix + 1:]
        return Subs(Derivative(self.func(*args), D), D, A)

    def _eval_as_leading_term(self, x, logx=None, cdir=0):
        """Stub that should be overridden by new Functions to return
        the first non-zero term in a series if ever an x-dependent
        argument whose leading term vanishes as x -> 0 might be encountered.
        See, for example, cos._eval_as_leading_term.
        """
        from sympy import Order
        args = [a.as_leading_term(x, logx=logx) for a in self.args]
        o = Order(1, x)
        if any(x in a.free_symbols and o.contains(a) for a in args):
            # Whereas x and any finite number are contained in O(1, x),
            # expressions like 1/x are not. If any arg simplified to a
            # vanishing expression as x -> 0 (like x or x**2, but not
            # 3, 1/x, etc...) then the _eval_as_leading_term is needed
            # to supply the first non-zero term of the series,
            #
            # e.g. expression    leading term
            #      ----------    ------------
            #      cos(1/x)      cos(1/x)
            #      cos(cos(x))   cos(1)
            #      cos(x)        1        <- _eval_as_leading_term needed
            #      sin(x)        x        <- _eval_as_leading_term needed
            #
            raise NotImplementedError(
                '%s has no _eval_as_leading_term routine' % self.func)
        else:
            return self.func(*args)
829
830
class AppliedUndef(Function):
    """
    Base class for the expression obtained by calling an undefined
    function, e.g. ``f(x)`` where ``f = Function('f')``.
    """

    is_number = False

    def __new__(cls, *args, **options):
        args = [sympify(arg) for arg in args]
        # An unapplied undefined function (the class ``f`` rather than the
        # expression ``f(x)``) is not a valid argument.
        names = [arg.name for arg in args if isinstance(arg, UndefinedFunction)]
        if names:
            plural = 's' if len(names) > 1 else ''
            raise TypeError('Invalid argument: expecting an expression, not UndefinedFunction%s: %s' % (
                plural, ', '.join(names)))
        return super().__new__(cls, *args, **options)

    def _eval_as_leading_term(self, x, logx=None, cdir=0):
        # Nothing is known about an undefined function: it is its own
        # leading term.
        return self

    @property
    def _diff_wrt(self):
        """
        Allow differentiation with respect to an applied undefined function.

        Examples
        ========

        >>> from sympy import Function, Symbol
        >>> f = Function('f')
        >>> x = Symbol('x')
        >>> f(x)._diff_wrt
        True
        >>> f(x).diff(x)
        Derivative(f(x), x)
        """
        return True
868
869
class UndefSageHelper:
    """
    Descriptor that builds the Sage counterpart of an undefined function.

    Looked up on the class, it yields a zero-argument callable returning
    ``sage.function(<class name>)``. Looked up on an instance, it converts
    the instance's arguments with ``_sage_()`` immediately and yields a
    zero-argument callable that applies the Sage function to them.
    """
    def __get__(self, ins, typ):
        import sage.all as sage
        if ins is not None:
            sage_args = [a._sage_() for a in ins.args]
            return lambda: sage.function(ins.__class__.__name__)(*sage_args)
        return lambda: sage.function(typ.__name__)

_undef_sage_helper = UndefSageHelper()
883
class UndefinedFunction(FunctionClass):
    """
    The (meta)class of undefined functions.
    """
    def __new__(mcl, name, bases=(AppliedUndef,), __dict__=None, **kwargs):
        """Create a new undefined-function class called ``name``.

        ``name`` may be a string or a Symbol; for a Symbol, its assumptions
        are merged with any assumption keywords given. Assumptions become
        ``is_<assumption>`` class attributes, remaining keyword arguments
        become plain class attributes, and both are recorded in ``_kwargs``,
        which ``__eq__`` and ``__hash__`` compare.

        Raises TypeError if ``name`` is neither a string nor a Symbol.
        """
        from .symbol import _filter_assumptions
        # Allow Function('f', real=True)
        # and/or Function(Symbol('f', real=True))
        assumptions, kwargs = _filter_assumptions(kwargs)
        if isinstance(name, Symbol):
            assumptions = name._merge(assumptions)
            name = name.name
        elif not isinstance(name, str):
            raise TypeError('expecting string or Symbol for name')
        else:
            commutative = assumptions.get('commutative', None)
            assumptions = Symbol(name, **assumptions).assumptions0
            if commutative is None:
                assumptions.pop('commutative')
        # Work on a copy: a namespace supplied by the caller must not be
        # mutated, otherwise reusing it would leak the is_* attributes and
        # _kwargs of this class into the next one created from it.
        __dict__ = dict(__dict__) if __dict__ else {}
        # put the `is_*` for into __dict__
        __dict__.update({'is_%s' % k: v for k, v in assumptions.items()})
        # You can add other attributes, although they do have to be hashable
        # (but seriously, if you want to add anything other than assumptions,
        # just subclass Function)
        __dict__.update(kwargs)
        # add back the sanitized assumptions without the is_ prefix
        kwargs.update(assumptions)
        # Save these for __eq__
        __dict__.update({'_kwargs': kwargs})
        # do this for pickling
        __dict__['__module__'] = None
        obj = super().__new__(mcl, name, bases, __dict__)
        obj.name = name
        obj._sage_ = _undef_sage_helper
        return obj

    def __instancecheck__(cls, instance):
        # isinstance(f(x), f) is decided by the MRO of the instance's type.
        return cls in type(instance).__mro__

    _kwargs = {}  # type: tDict[str, Optional[bool]]

    def __hash__(self):
        return hash((self.class_key(), frozenset(self._kwargs.items())))

    def __eq__(self, other):
        # Two undefined functions are equal when they have the same metaclass,
        # class_key (which includes the name) and keyword/assumption record.
        return (isinstance(other, self.__class__) and
            self.class_key() == other.class_key() and
            self._kwargs == other._kwargs)

    def __ne__(self, other):
        return not self == other

    @property
    def _diff_wrt(self):
        return False
940
941
942# XXX: The type: ignore on WildFunction is because mypy complains:
943#
944# sympy/core/function.py:939: error: Cannot determine type of 'sort_key' in
945# base class 'Expr'
946#
947# Somehow this is because of the @cacheit decorator but it is not clear how to
948# fix it.
949
950
class WildFunction(Function, AtomicExpr):  # type: ignore
    """
    A WildFunction function matches any function (with its arguments).

    Examples
    ========

    >>> from sympy import WildFunction, Function, cos
    >>> from sympy.abc import x, y
    >>> F = WildFunction('F')
    >>> f = Function('f')
    >>> F.nargs
    Naturals0
    >>> x.match(F)
    >>> F.match(F)
    {F_: F_}
    >>> f(x).match(F)
    {F_: f(x)}
    >>> cos(x).match(F)
    {F_: cos(x)}
    >>> f(x, y).match(F)
    {F_: f(x, y)}

    To match functions with a given number of arguments, set ``nargs`` to the
    desired value at instantiation:

    >>> F = WildFunction('F', nargs=2)
    >>> F.nargs
    {2}
    >>> f(x).match(F)
    >>> f(x, y).match(F)
    {F_: f(x, y)}

    To match functions with a range of arguments, set ``nargs`` to a tuple
    containing the desired number of arguments, e.g. if ``nargs = (1, 2)``
    then functions with 1 or 2 arguments will be matched.

    >>> F = WildFunction('F', nargs=(1, 2))
    >>> F.nargs
    {1, 2}
    >>> f(x).match(F)
    {F_: f(x)}
    >>> f(x, y).match(F)
    {F_: f(x, y)}
    >>> f(x, y, 1).match(F)

    Passing ``nargs=None`` is the same as omitting ``nargs`` (any number of
    arguments is accepted):

    >>> WildFunction('F', nargs=None).nargs
    Naturals0

    """

    # XXX: What is this class attribute used for?
    include = set()  # type: tSet[Any]

    def __init__(self, name, **assumptions):
        """Record ``name`` and the set of argument counts this wild accepts.

        The ``nargs`` keyword may be a Set, an integer, a sequence of
        integers, or None. It defaults to ``S.Naturals0`` (any number of
        arguments), and None is treated the same as the default.
        """
        from sympy.sets.sets import Set, FiniteSet
        self.name = name
        nargs = assumptions.pop('nargs', S.Naturals0)
        if nargs is None:
            # None means "unrestricted", exactly like the default; without
            # this, FiniteSet(*None) below would raise a TypeError.
            nargs = S.Naturals0
        if not isinstance(nargs, Set):
            # Canonicalize nargs here.  See also FunctionClass.
            if is_sequence(nargs):
                nargs = tuple(ordered(set(nargs)))
            else:
                nargs = (as_int(nargs),)
            nargs = FiniteSet(*nargs)
        self.nargs = nargs

    def matches(self, expr, repl_dict=None, old=False):
        """Match ``expr`` if it is a function with an accepted arity.

        Returns a copy of ``repl_dict`` (or a new dict) extended with
        ``{self: expr}`` when ``expr`` is a Function/AppliedUndef whose
        number of arguments is in ``self.nargs``; otherwise returns None.
        ``old`` is accepted for interface compatibility and is unused here.
        """
        if not isinstance(expr, (AppliedUndef, Function)):
            return None
        if len(expr.args) not in self.nargs:
            return None

        # Never mutate the caller's dict; build a fresh one when none given.
        repl_dict = {} if repl_dict is None else repl_dict.copy()
        repl_dict[self] = expr
        return repl_dict
1024
1025
1026class Derivative(Expr):
1027    """
1028    Carries out differentiation of the given expression with respect to symbols.
1029
1030    Examples
1031    ========
1032
1033    >>> from sympy import Derivative, Function, symbols, Subs
1034    >>> from sympy.abc import x, y
1035    >>> f, g = symbols('f g', cls=Function)
1036
1037    >>> Derivative(x**2, x, evaluate=True)
1038    2*x
1039
1040    Denesting of derivatives retains the ordering of variables:
1041
1042        >>> Derivative(Derivative(f(x, y), y), x)
1043        Derivative(f(x, y), y, x)
1044
1045    Contiguously identical symbols are merged into a tuple giving
1046    the symbol and the count:
1047
1048        >>> Derivative(f(x), x, x, y, x)
1049        Derivative(f(x), (x, 2), y, x)
1050
1051    If the derivative cannot be performed, and evaluate is True, the
1052    order of the variables of differentiation will be made canonical:
1053
1054        >>> Derivative(f(x, y), y, x, evaluate=True)
1055        Derivative(f(x, y), x, y)
1056
1057    Derivatives with respect to undefined functions can be calculated:
1058
1059        >>> Derivative(f(x)**2, f(x), evaluate=True)
1060        2*f(x)
1061
1062    Such derivatives will show up when the chain rule is used to
    evaluate a derivative:
1064
1065        >>> f(g(x)).diff(x)
1066        Derivative(f(g(x)), g(x))*Derivative(g(x), x)
1067
1068    Substitution is used to represent derivatives of functions with
1069    arguments that are not symbols or functions:
1070
1071        >>> f(2*x + 3).diff(x) == 2*Subs(f(y).diff(y), y, 2*x + 3)
1072        True
1073
1074    Notes
1075    =====
1076
1077    Simplification of high-order derivatives:
1078
1079    Because there can be a significant amount of simplification that can be
1080    done when multiple differentiations are performed, results will be
1081    automatically simplified in a fairly conservative fashion unless the
1082    keyword ``simplify`` is set to False.
1083
1084        >>> from sympy import sqrt, diff, Function, symbols
1085        >>> from sympy.abc import x, y, z
1086        >>> f, g = symbols('f,g', cls=Function)
1087
1088        >>> e = sqrt((x + 1)**2 + x)
1089        >>> diff(e, (x, 5), simplify=False).count_ops()
1090        136
1091        >>> diff(e, (x, 5)).count_ops()
1092        30
1093
1094    Ordering of variables:
1095
1096    If evaluate is set to True and the expression cannot be evaluated, the
1097    list of differentiation symbols will be sorted, that is, the expression is
1098    assumed to have continuous derivatives up to the order asked.
1099
1100    Derivative wrt non-Symbols:
1101
1102    For the most part, one may not differentiate wrt non-symbols.
1103    For example, we do not allow differentiation wrt `x*y` because
1104    there are multiple ways of structurally defining where x*y appears
1105    in an expression: a very strict definition would make
1106    (x*y*z).diff(x*y) == 0. Derivatives wrt defined functions (like
1107    cos(x)) are not allowed, either:
1108
1109        >>> (x*y*z).diff(x*y)
1110        Traceback (most recent call last):
1111        ...
1112        ValueError: Can't calculate derivative wrt x*y.
1113
1114    To make it easier to work with variational calculus, however,
1115    derivatives wrt AppliedUndef and Derivatives are allowed.
1116    For example, in the Euler-Lagrange method one may write
1117    F(t, u, v) where u = f(t) and v = f'(t). These variables can be
1118    written explicitly as functions of time::
1119
1120        >>> from sympy.abc import t
1121        >>> F = Function('F')
1122        >>> U = f(t)
1123        >>> V = U.diff(t)
1124
1125    The derivative wrt f(t) can be obtained directly:
1126
1127        >>> direct = F(t, U, V).diff(U)
1128
1129    When differentiation wrt a non-Symbol is attempted, the non-Symbol
1130    is temporarily converted to a Symbol while the differentiation
1131    is performed and the same answer is obtained:
1132
1133        >>> indirect = F(t, U, V).subs(U, x).diff(x).subs(x, U)
1134        >>> assert direct == indirect
1135
1136    The implication of this non-symbol replacement is that all
1137    functions are treated as independent of other functions and the
1138    symbols are independent of the functions that contain them::
1139
1140        >>> x.diff(f(x))
1141        0
1142        >>> g(x).diff(f(x))
1143        0
1144
1145    It also means that derivatives are assumed to depend only
1146    on the variables of differentiation, not on anything contained
1147    within the expression being differentiated::
1148
1149        >>> F = f(x)
1150        >>> Fx = F.diff(x)
1151        >>> Fx.diff(F)  # derivative depends on x, not F
1152        0
1153        >>> Fxx = Fx.diff(x)
1154        >>> Fxx.diff(Fx)  # derivative depends on x, not Fx
1155        0
1156
1157    The last example can be made explicit by showing the replacement
1158    of Fx in Fxx with y:
1159
1160        >>> Fxx.subs(Fx, y)
1161        Derivative(y, x)
1162
1163        Since that in itself will evaluate to zero, differentiating
1164        wrt Fx will also be zero:
1165
1166        >>> _.doit()
1167        0
1168
1169    Replacing undefined functions with concrete expressions
1170
1171    One must be careful to replace undefined functions with expressions
1172    that contain variables consistent with the function definition and
    the variables of differentiation or else inconsistent results will
1174    be obtained. Consider the following example:
1175
1176    >>> eq = f(x)*g(y)
1177    >>> eq.subs(f(x), x*y).diff(x, y).doit()
1178    y*Derivative(g(y), y) + g(y)
1179    >>> eq.diff(x, y).subs(f(x), x*y).doit()
1180    y*Derivative(g(y), y)
1181
1182    The results differ because `f(x)` was replaced with an expression
1183    that involved both variables of differentiation. In the abstract
1184    case, differentiation of `f(x)` by `y` is 0; in the concrete case,
1185    the presence of `y` made that derivative nonvanishing and produced
1186    the extra `g(y)` term.
1187
1188    Defining differentiation for an object
1189
1190    An object must define ._eval_derivative(symbol) method that returns
1191    the differentiation result. This function only needs to consider the
1192    non-trivial case where expr contains symbol and it should call the diff()
1193    method internally (not _eval_derivative); Derivative should be the only
1194    one to call _eval_derivative.
1195
1196    Any class can allow derivatives to be taken with respect to
1197    itself (while indicating its scalar nature). See the
1198    docstring of Expr._diff_wrt.
1199
1200    See Also
1201    ========
1202    _sort_variable_count
1203    """
1204
1205    is_Derivative = True
1206
1207    @property
1208    def _diff_wrt(self):
1209        """An expression may be differentiated wrt a Derivative if
1210        it is in elementary form.
1211
1212        Examples
1213        ========
1214
1215        >>> from sympy import Function, Derivative, cos
1216        >>> from sympy.abc import x
1217        >>> f = Function('f')
1218
1219        >>> Derivative(f(x), x)._diff_wrt
1220        True
1221        >>> Derivative(cos(x), x)._diff_wrt
1222        False
1223        >>> Derivative(x + 1, x)._diff_wrt
1224        False
1225
1226        A Derivative might be an unevaluated form of what will not be
1227        a valid variable of differentiation if evaluated. For example,
1228
1229        >>> Derivative(f(f(x)), x).doit()
1230        Derivative(f(x), x)*Derivative(f(f(x)), f(x))
1231
1232        Such an expression will present the same ambiguities as arise
1233        when dealing with any other product, like ``2*x``, so ``_diff_wrt``
1234        is False:
1235
1236        >>> Derivative(f(f(x)), x)._diff_wrt
1237        False
1238        """
1239        return self.expr._diff_wrt and isinstance(self.doit(), Derivative)
1240
    def __new__(cls, expr, *variables, **kwargs):
        """Create a Derivative of ``expr`` with respect to ``variables``.

        Each entry of ``variables`` may be

        * an object with ``_diff_wrt`` True, taken once;
        * an Integer after such an object, giving its order (an order of 0
          removes that object again);
        * a ``(v, count)`` pair (a count of 0 is ignored); or
        * a sequence whose first element is itself a sequence; it is turned
          into an ``Array`` for differentiation by an array, optionally
          paired with a count.

        Empty tuples are skipped, and contiguous identical variables are
        merged by adding their counts. If no variables are given, ``expr``
        must have exactly one free symbol, which is then used. If it does
        not and ``expr.is_number`` is True, ``S.Zero`` is returned. If no
        variable remains with a nonzero order, ``expr`` is returned.

        Keyword arguments read here are ``evaluate`` (default False), which
        tries to carry out the differentiation, and ``simplify`` (default
        True). ``simplify`` applies ``factor_terms(signsimp(...))`` to an
        evaluated result when more than one derivative was taken.

        ValueError is raised if:

        * ``expr`` has no ``free_symbols`` set;
        * the variables cannot be inferred;
        * the first variable is an Integer;
        * an order is negative; or
        * a variable has ``_diff_wrt`` False.

        TypeError is raised if a variable is an UndefinedFunction, or if an
        Integer follows a ``(v, n)`` pair with ``n != 1``.
        """

        from sympy.matrices.common import MatrixCommon
        from sympy import Integer, MatrixExpr
        from sympy.tensor.array import Array, NDimArray
        from sympy.utilities.misc import filldedent

        expr = sympify(expr)
        # Only objects exposing a free_symbols *set* can be differentiated.
        symbols_or_none = getattr(expr, "free_symbols", None)
        has_symbol_set = isinstance(symbols_or_none, set)

        if not has_symbol_set:
            raise ValueError(filldedent('''
                Since there are no variables in the expression %s,
                it cannot be differentiated.''' % expr))

        # determine value for variables if it wasn't given
        if not variables:
            variables = expr.free_symbols
            if len(variables) != 1:
                if expr.is_number:
                    return S.Zero
                if len(variables) == 0:
                    raise ValueError(filldedent('''
                        Since there are no variables in the expression,
                        the variable(s) of differentiation must be supplied
                        to differentiate %s''' % expr))
                else:
                    raise ValueError(filldedent('''
                        Since there is more than one variable in the
                        expression, the variable(s) of differentiation
                        must be supplied to differentiate %s''' % expr))

        # Standardize the variables by sympifying them:
        variables = list(sympify(variables))

        # Split the list of variables into a list of the variables we are diff
        # wrt, where each element of the list has the form (s, count) where
        # s is the entity to diff wrt and count is the order of the
        # derivative.
        variable_count = []
        array_likes = (tuple, list, Tuple)

        for i, v in enumerate(variables):
            if isinstance(v, Integer):
                # An Integer sets the order of the preceding variable.
                if i == 0:
                    raise ValueError("First variable cannot be a number: %i" % v)
                count = v
                # NOTE(review): if every earlier entry was skipped (empty
                # tuple or zero count) variable_count is still empty here and
                # this line raises IndexError rather than a clear error.
                prev, prevcount = variable_count[-1]
                if prevcount != 1:
                    raise TypeError("tuple {} followed by number {}".format((prev, prevcount), v))
                if count == 0:
                    variable_count.pop()
                else:
                    variable_count[-1] = Tuple(prev, count)
            else:
                if isinstance(v, array_likes):
                    if len(v) == 0:
                        # Ignore empty tuples: Derivative(expr, ... , (), ... )
                        continue
                    if isinstance(v[0], array_likes):
                        # Derive by array: Derivative(expr, ... , [[x, y, z]], ... )
                        if len(v) == 1:
                            v = Array(v[0])
                            count = 1
                        else:
                            v, count = v
                            v = Array(v)
                    else:
                        v, count = v
                    if count == 0:
                        continue
                elif isinstance(v, UndefinedFunction):
                    raise TypeError(
                        "cannot differentiate wrt "
                        "UndefinedFunction: %s" % v)
                else:
                    count = 1
                variable_count.append(Tuple(v, count))

        # light evaluation of contiguous, identical
        # items: (x, 1), (x, 1) -> (x, 2)
        merged = []
        for t in variable_count:
            v, c = t
            if c.is_negative:
                raise ValueError(
                    'order of differentiation must be nonnegative')
            if merged and merged[-1][0] == v:
                c += merged[-1][1]
                if not c:
                    merged.pop()
                else:
                    merged[-1] = Tuple(v, c)
            else:
                merged.append(t)
        variable_count = merged

        # sanity check of variables of differentiation; we waited
        # until the counts were computed since some variables may
        # have been removed because the count was 0
        for v, c in variable_count:
            # v must have _diff_wrt True
            if not v._diff_wrt:
                __ = ''  # filler to make error message neater
                raise ValueError(filldedent('''
                    Can't calculate derivative wrt %s.%s''' % (v,
                    __)))

        # We make a special case for 0th derivative, because there is no
        # good way to unambiguously print this.
        if len(variable_count) == 0:
            return expr

        evaluate = kwargs.get('evaluate', False)

        if evaluate:
            if isinstance(expr, Derivative):
                expr = expr.canonical
            variable_count = [
                (v.canonical if isinstance(v, Derivative) else v, c)
                for v, c in variable_count]

            # Look for a quick exit if there are symbols that don't appear in
            # expression at all. Note, this cannot check non-symbols like
            # Derivatives as those can be created by intermediate
            # derivatives.
            zero = False
            free = expr.free_symbols
            for v, c in variable_count:
                vfree = v.free_symbols
                if c.is_positive and vfree:
                    if isinstance(v, AppliedUndef):
                        # these match exactly since
                        # x.diff(f(x)) == g(x).diff(f(x)) == 0
                        # and are not created by differentiation
                        D = Dummy()
                        if not expr.xreplace({v: D}).has(D):
                            zero = True
                            break
                    elif isinstance(v, MatrixExpr):
                        zero = False
                        break
                    elif isinstance(v, Symbol) and v not in free:
                        zero = True
                        break
                    else:
                        if not free & vfree:
                            # e.g. v is IndexedBase or Matrix
                            zero = True
                            break
            if zero:
                return cls._get_zero_with_shape_like(expr)

            # make the order of symbols canonical
            #TODO: check if assumption of discontinuous derivatives exist
            variable_count = cls._sort_variable_count(variable_count)

        # denest
        if isinstance(expr, Derivative):
            variable_count = list(expr.variable_count) + variable_count
            expr = expr.expr
            return _derivative_dispatch(expr, *variable_count, **kwargs)

        # we return here if evaluate is False or if there is no
        # _eval_derivative method
        if not evaluate or not hasattr(expr, '_eval_derivative'):
            # return an unevaluated Derivative
            if evaluate and variable_count == [(expr, 1)] and expr.is_scalar:
                # special hack providing evaluation for classes
                # that have defined is_scalar=True but have no
                # _eval_derivative defined
                return S.One
            return Expr.__new__(cls, expr, *variable_count)

        # evaluate the derivative by calling _eval_derivative method
        # of expr for each variable
        # -------------------------------------------------------------
        nderivs = 0  # how many derivatives were performed
        unhandled = []
        for i, (v, count) in enumerate(variable_count):

            old_expr = expr
            old_v = None

            is_symbol = v.is_symbol or isinstance(v,
                (Iterable, Tuple, MatrixCommon, NDimArray))

            if not is_symbol:
                # Differentiate wrt a temporary Dummy standing in for the
                # non-symbol; it is substituted back after differentiation.
                old_v = v
                v = Dummy('xi')
                expr = expr.xreplace({old_v: v})
                # Derivatives and UndefinedFunctions are independent
                # of all others (the check is on AppliedUndef, i.e.
                # applied undefined functions)
                clashing = not (isinstance(old_v, Derivative) or \
                    isinstance(old_v, AppliedUndef))
                if not v in expr.free_symbols and not clashing:
                    return expr.diff(v)  # expr's version of 0
                # NOTE(review): the condition tests ``not old_v.is_scalar``
                # while the comment below speaks of is_scalar=True -- confirm
                # which is intended.
                if not old_v.is_scalar and not hasattr(
                        old_v, '_eval_derivative'):
                    # special hack providing evaluation for classes
                    # that have defined is_scalar=True but have no
                    # _eval_derivative defined
                    expr *= old_v.diff(old_v)

            obj = cls._dispatch_eval_derivative_n_times(expr, v, count)
            if obj is not None and obj.is_zero:
                return obj

            nderivs += count

            if old_v is not None:
                if obj is not None:
                    # remove the dummy that was used
                    obj = obj.subs(v, old_v)
                # restore expr
                expr = old_expr

            if obj is None:
                # we've already checked for quick-exit conditions
                # that give 0 so the remaining variables
                # are contained in the expression but the expression
                # did not compute a derivative so we stop taking
                # derivatives
                unhandled = variable_count[i:]
                break

            expr = obj

        # what we have so far can be made canonical
        expr = expr.replace(
            lambda x: isinstance(x, Derivative),
            lambda x: x.canonical)

        if unhandled:
            if isinstance(expr, Derivative):
                unhandled = list(expr.variable_count) + unhandled
                expr = expr.expr
            expr = Expr.__new__(cls, expr, *unhandled)

        # ``== True`` because nderivs may be symbolic when a count is symbolic
        if (nderivs > 1) == True and kwargs.get('simplify', True):
            from sympy.core.exprtools import factor_terms
            from sympy.simplify.simplify import signsimp
            expr = factor_terms(signsimp(expr))
        return expr
1486
1487    @property
1488    def canonical(cls):
1489        return cls.func(cls.expr,
1490            *Derivative._sort_variable_count(cls.variable_count))
1491
1492    @classmethod
    def _sort_variable_count(cls, vc):
        """
        Sort (variable, count) pairs into canonical order while
        retaining order of variables that do not commute during
        differentiation:

        * symbols and functions commute with each other
        * derivatives commute with each other
        * a derivative doesn't commute with anything it contains
        * any other object is not allowed to commute if it has
          free symbols in common with another object

        Returns a new list of ``Tuple(variable, count)``; an empty input
        gives ``[]``.

        Examples
        ========

        >>> from sympy import Derivative, Function, symbols
        >>> vsort = Derivative._sort_variable_count
        >>> x, y, z = symbols('x y z')
        >>> f, g, h = symbols('f g h', cls=Function)

        Contiguous items are collapsed into one pair:

        >>> vsort([(x, 1), (x, 1)])
        [(x, 2)]
        >>> vsort([(y, 1), (f(x), 1), (y, 1), (f(x), 1)])
        [(y, 2), (f(x), 2)]

        Ordering is canonical.

        >>> def vsort0(*v):
        ...     # docstring helper to
        ...     # change vi -> (vi, 0), sort, and return vi vals
        ...     return [i[0] for i in vsort([(i, 0) for i in v])]

        >>> vsort0(y, x)
        [x, y]
        >>> vsort0(g(y), g(x), f(y))
        [f(y), g(x), g(y)]

        Symbols are sorted as far to the left as possible but never
        move to the left of a derivative having the same symbol in
        its variables; the same applies to AppliedUndef which are
        always sorted after Symbols:

        >>> dfx = f(x).diff(x)
        >>> assert vsort0(dfx, y) == [y, dfx]
        >>> assert vsort0(dfx, x) == [dfx, x]
        """
        from sympy.utilities.iterables import uniq, topological_sort
        if not vc:
            return []
        vc = list(vc)
        if len(vc) == 1:
            return [Tuple(*vc[0])]
        # Build a precedence graph over positions in vc: an edge (j, i)
        # means item j must stay before item i in the result.
        V = list(range(len(vc)))
        E = []
        v = lambda i: vc[i][0]
        # stand-in used when checking whether an AppliedUndef blocks
        D = Dummy()
        def _block(d, v, wrt=False):
            # return True if v should not come before d else False
            if d == v:
                return wrt
            if d.is_Symbol:
                return False
            if isinstance(d, Derivative):
                # a derivative blocks if any of its variables contain
                # v; the wrt flag will return True for an exact match
                # and will cause an AppliedUndef to block if v is in
                # the arguments
                if any(_block(k, v, wrt=True)
                        for k in d._wrt_variables):
                    return True
                return False
            if not wrt and isinstance(d, AppliedUndef):
                return False
            if v.is_Symbol:
                return v in d.free_symbols
            if isinstance(v, AppliedUndef):
                return _block(d.xreplace({v: D}), D)
            return d.free_symbols & v.free_symbols
        # only earlier items (j < i) can block later ones, so the original
        # relative order is kept wherever two items do not commute
        for i in range(len(vc)):
            for j in range(i):
                if _block(v(j), v(i)):
                    E.append((j,i))
        # this is the default ordering to use in case of ties
        O = dict(zip(ordered(uniq([i for i, c in vc])), range(len(vc))))
        ix = topological_sort((V, E), key=lambda i: O[v(i)])
        # merge counts of contiguously identical items
        # (the loop variable ``v`` rebinds the helper lambda, which is no
        # longer needed once topological_sort has returned)
        merged = []
        for v, c in [vc[i] for i in ix]:
            if merged and merged[-1][0] == v:
                merged[-1][1] += c
            else:
                merged.append([v, c])
        return [Tuple(*i) for i in merged]
1588
1589    def _eval_is_commutative(self):
1590        return self.expr.is_commutative
1591
1592    def _eval_derivative(self, v):
1593        # If v (the variable of differentiation) is not in
1594        # self.variables, we might be able to take the derivative.
1595        if v not in self._wrt_variables:
1596            dedv = self.expr.diff(v)
1597            if isinstance(dedv, Derivative):
1598                return dedv.func(dedv.expr, *(self.variable_count + dedv.variable_count))
1599            # dedv (d(self.expr)/dv) could have simplified things such that the
1600            # derivative wrt things in self.variables can now be done. Thus,
1601            # we set evaluate=True to see if there are any other derivatives
1602            # that can be done. The most common case is when dedv is a simple
1603            # number so that the derivative wrt anything else will vanish.
1604            return self.func(dedv, *self.variables, evaluate=True)
1605        # In this case v was in self.variables so the derivative wrt v has
1606        # already been attempted and was not computed, either because it
1607        # couldn't be or evaluate=False originally.
1608        variable_count = list(self.variable_count)
1609        variable_count.append((v, 1))
1610        return self.func(self.expr, *variable_count, evaluate=False)
1611
1612    def doit(self, **hints):
1613        expr = self.expr
1614        if hints.get('deep', True):
1615            expr = expr.doit(**hints)
1616        hints['evaluate'] = True
1617        rv = self.func(expr, *self.variable_count, **hints)
1618        if rv!= self and rv.has(Derivative):
1619            rv =  rv.doit(**hints)
1620        return rv
1621
1622    @_sympifyit('z0', NotImplementedError)
1623    def doit_numerically(self, z0):
1624        """
1625        Evaluate the derivative at z numerically.
1626
1627        When we can represent derivatives at a point, this should be folded
1628        into the normal evalf. For now, we need a special method.
1629        """
1630        if len(self.free_symbols) != 1 or len(self.variables) != 1:
1631            raise NotImplementedError('partials and higher order derivatives')
1632        z = list(self.free_symbols)[0]
1633
1634        def eval(x):
1635            f0 = self.expr.subs(z, Expr._from_mpmath(x, prec=mpmath.mp.prec))
1636            f0 = f0.evalf(mlib.libmpf.prec_to_dps(mpmath.mp.prec))
1637            return f0._to_mpmath(mpmath.mp.prec)
1638        return Expr._from_mpmath(mpmath.diff(eval,
1639                                             z0._to_mpmath(mpmath.mp.prec)),
1640                                 mpmath.mp.prec)
1641
1642    @property
1643    def expr(self):
1644        return self._args[0]
1645
1646    @property
1647    def _wrt_variables(self):
1648        # return the variables of differentiation without
1649        # respect to the type of count (int or symbolic)
1650        return [i[0] for i in self.variable_count]
1651
1652    @property
1653    def variables(self):
1654        # TODO: deprecate?  YES, make this 'enumerated_variables' and
1655        #       name _wrt_variables as variables
1656        # TODO: support for `d^n`?
1657        rv = []
1658        for v, count in self.variable_count:
1659            if not count.is_Integer:
1660                raise TypeError(filldedent('''
1661                Cannot give expansion for symbolic count. If you just
1662                want a list of all variables of differentiation, use
1663                _wrt_variables.'''))
1664            rv.extend([v]*count)
1665        return tuple(rv)
1666
1667    @property
1668    def variable_count(self):
1669        return self._args[1:]
1670
1671    @property
1672    def derivative_count(self):
1673        return sum([count for var, count in self.variable_count], 0)
1674
1675    @property
1676    def free_symbols(self):
1677        ret = self.expr.free_symbols
1678        # Add symbolic counts to free_symbols
1679        for var, count in self.variable_count:
1680            ret.update(count.free_symbols)
1681        return ret
1682
1683    @property
1684    def kind(self):
1685        return self.args[0].kind
1686
    def _eval_subs(self, old, new):
        """Substitute ``new`` for ``old`` in this Derivative.

        When the substitution cannot be carried out inside the Derivative
        (the cases labelled (0)-(3) below), the result is wrapped in an
        unevaluated ``Subs``. When ``old`` is a Derivative of the same
        expression, ``new`` replaces it directly (or replaces the matching
        sub-derivative, leaving the remaining differentiations).
        """
        # The substitution (old, new) cannot be done inside
        # Derivative(expr, vars) for a variety of reasons
        # as handled below.
        if old in self._wrt_variables:
            # first handle the counts
            expr = self.func(self.expr, *[(v, c.subs(old, new))
                for v, c in self.variable_count])
            if expr != self:
                # a count changed: continue the substitution on the
                # rebuilt derivative
                return expr._eval_subs(old, new)
            # quick exit case
            if not getattr(new, '_diff_wrt', False):
                # case (0): new is not a valid variable of
                # differentiation
                if isinstance(old, Symbol):
                    # don't introduce a new symbol if the old will do
                    return Subs(self, old, new)
                else:
                    xi = Dummy('xi')
                    return Subs(self.xreplace({old: xi}), xi, new)

        # If both are Derivatives with the same expr, check if old is
        # equivalent to self or if old is a subderivative of self.
        if old.is_Derivative and old.expr == self.expr:
            if self.canonical == old.canonical:
                return new

            # collections.Counter doesn't have __le__
            # (counts may be symbolic, so only a definite True counts)
            def _subset(a, b):
                return all((a[i] <= b[i]) == True for i in a)

            # with reversed(), the first (variable, count) pair wins if a
            # variable appears more than once
            old_vars = Counter(dict(reversed(old.variable_count)))
            self_vars = Counter(dict(reversed(self.variable_count)))
            if _subset(old_vars, self_vars):
                # differentiate `new` by whatever self has beyond `old`
                return _derivative_dispatch(new, *(self_vars - old_vars).items()).canonical

        args = list(self.args)
        newargs = list(x._subs(old, new) for x in args)
        if args[0] == old:
            # complete replacement of self.expr
            # we already checked that the new is valid so we know
            # it won't be a problem should it appear in variables
            return _derivative_dispatch(*newargs)

        if newargs[0] != args[0]:
            # case (1) can't change expr by introducing something that is in
            # the _wrt_variables if it was already in the expr
            # e.g.
            # for Derivative(f(x, g(y)), y), x cannot be replaced with
            # anything that has y in it; for f(g(x), g(y)).diff(g(y))
            # g(x) cannot be replaced with anything that has g(y)
            # (non-Symbol wrt variables are masked with Dummies so that
            # they can be compared through free_symbols)
            syms = {vi: Dummy() for vi in self._wrt_variables
                if not vi.is_Symbol}
            wrt = {syms.get(vi, vi) for vi in self._wrt_variables}
            forbidden = args[0].xreplace(syms).free_symbols & wrt
            nfree = new.xreplace(syms).free_symbols
            ofree = old.xreplace(syms).free_symbols
            if (nfree - ofree) & forbidden:
                return Subs(self, old, new)

        # pairs of (new variable, old variable) for each (variable, count)
        viter = ((i, j) for ((i, _), (j, _)) in zip(newargs[1:], args[1:]))
        if any(i != j for i, j in viter):  # a wrt-variable change
            # case (2) can't change vars by introducing a variable
            # that is contained in expr, e.g.
            # for Derivative(f(z, g(h(x), y)), y), y cannot be changed to
            # x, h(x), or g(h(x), y)
            for a in _atomic(self.expr, recursive=True):
                for i in range(1, len(newargs)):
                    vi, _ = newargs[i]
                    if a == vi and vi != args[i][0]:
                        return Subs(self, old, new)
            # more arg-wise checks
            vc = newargs[1:]
            oldv = self._wrt_variables
            newe = self.expr
            subs = []
            for i, (vi, ci) in enumerate(vc):
                if not vi._diff_wrt:
                    # case (3) invalid differentiation expression so
                    # create a replacement dummy
                    xi = Dummy('xi_%i' % i)
                    # replace the old valid variable with the dummy
                    # in the expression
                    newe = newe.xreplace({oldv[i]: xi})
                    # and replace the bad variable with the dummy
                    vc[i] = (xi, ci)
                    # and record the dummy with the new (invalid)
                    # differentiation expression
                    subs.append((xi, vi))

            if subs:
                # handle any residual substitution in the expression
                newe = newe._subs(old, new)
                # return the Subs-wrapped derivative
                return Subs(Derivative(newe, *vc), *zip(*subs))

        # everything was ok
        return _derivative_dispatch(*newargs)
1785
1786    def _eval_lseries(self, x, logx, cdir=0):
1787        dx = self.variables
1788        for term in self.expr.lseries(x, logx=logx, cdir=cdir):
1789            yield self.func(term, *dx)
1790
1791    def _eval_nseries(self, x, n, logx, cdir=0):
1792        arg = self.expr.nseries(x, n=n, logx=logx)
1793        o = arg.getO()
1794        dx = self.variables
1795        rv = [self.func(a, *dx) for a in Add.make_args(arg.removeO())]
1796        if o:
1797            rv.append(o/x)
1798        return Add(*rv)
1799
1800    def _eval_as_leading_term(self, x, logx=None, cdir=0):
1801        series_gen = self.expr.lseries(x)
1802        d = S.Zero
1803        for leading_term in series_gen:
1804            d = diff(leading_term, *self.variables)
1805            if d != 0:
1806                break
1807        return d
1808
1809    def as_finite_difference(self, points=1, x0=None, wrt=None):
1810        """ Expresses a Derivative instance as a finite difference.
1811
1812        Parameters
1813        ==========
1814
1815        points : sequence or coefficient, optional
1816            If sequence: discrete values (length >= order+1) of the
1817            independent variable used for generating the finite
1818            difference weights.
1819            If it is a coefficient, it will be used as the step-size
1820            for generating an equidistant sequence of length order+1
1821            centered around ``x0``. Default: 1 (step-size 1)
1822
1823        x0 : number or Symbol, optional
1824            the value of the independent variable (``wrt``) at which the
1825            derivative is to be approximated. Default: same as ``wrt``.
1826
1827        wrt : Symbol, optional
1828            "with respect to" the variable for which the (partial)
1829            derivative is to be approximated for. If not provided it
1830            is required that the derivative is ordinary. Default: ``None``.
1831
1832
1833        Examples
1834        ========
1835
1836        >>> from sympy import symbols, Function, exp, sqrt, Symbol
1837        >>> x, h = symbols('x h')
1838        >>> f = Function('f')
1839        >>> f(x).diff(x).as_finite_difference()
1840        -f(x - 1/2) + f(x + 1/2)
1841
1842        The default step size and number of points are 1 and
1843        ``order + 1`` respectively. We can change the step size by
1844        passing a symbol as a parameter:
1845
1846        >>> f(x).diff(x).as_finite_difference(h)
1847        -f(-h/2 + x)/h + f(h/2 + x)/h
1848
1849        We can also specify the discretized values to be used in a
1850        sequence:
1851
1852        >>> f(x).diff(x).as_finite_difference([x, x+h, x+2*h])
1853        -3*f(x)/(2*h) + 2*f(h + x)/h - f(2*h + x)/(2*h)
1854
1855        The algorithm is not restricted to use equidistant spacing, nor
1856        do we need to make the approximation around ``x0``, but we can get
1857        an expression estimating the derivative at an offset:
1858
1859        >>> e, sq2 = exp(1), sqrt(2)
1860        >>> xl = [x-h, x+h, x+e*h]
1861        >>> f(x).diff(x, 1).as_finite_difference(xl, x+h*sq2)  # doctest: +ELLIPSIS
1862        2*h*((h + sqrt(2)*h)/(2*h) - (-sqrt(2)*h + h)/(2*h))*f(E*h + x)/...
1863
1864        To approximate ``Derivative`` around ``x0`` using a non-equidistant
1865        spacing step, the algorithm supports assignment of undefined
1866        functions to ``points``:
1867
1868        >>> dx = Function('dx')
1869        >>> f(x).diff(x).as_finite_difference(points=dx(x), x0=x-h)
1870        -f(-h + x - dx(-h + x)/2)/dx(-h + x) + f(-h + x + dx(-h + x)/2)/dx(-h + x)
1871
1872        Partial derivatives are also supported:
1873
1874        >>> y = Symbol('y')
1875        >>> d2fdxdy=f(x,y).diff(x,y)
1876        >>> d2fdxdy.as_finite_difference(wrt=x)
1877        -Derivative(f(x - 1/2, y), y) + Derivative(f(x + 1/2, y), y)
1878
1879        We can apply ``as_finite_difference`` to ``Derivative`` instances in
1880        compound expressions using ``replace``:
1881
1882        >>> (1 + 42**f(x).diff(x)).replace(lambda arg: arg.is_Derivative,
1883        ...     lambda arg: arg.as_finite_difference())
1884        42**(-f(x - 1/2) + f(x + 1/2)) + 1
1885
1886
1887        See also
1888        ========
1889
1890        sympy.calculus.finite_diff.apply_finite_diff
1891        sympy.calculus.finite_diff.differentiate_finite
1892        sympy.calculus.finite_diff.finite_diff_weights
1893
1894        """
1895        from ..calculus.finite_diff import _as_finite_diff
1896        return _as_finite_diff(self, points, x0, wrt)
1897
1898    @classmethod
1899    def _get_zero_with_shape_like(cls, expr):
1900        return S.Zero
1901
1902    @classmethod
1903    def _dispatch_eval_derivative_n_times(cls, expr, v, count):
1904        # Evaluate the derivative `n` times.  If
1905        # `_eval_derivative_n_times` is not overridden by the current
1906        # object, the default in `Basic` will call a loop over
1907        # `_eval_derivative`:
1908        return expr._eval_derivative_n_times(v, count)
1909
1910
def _derivative_dispatch(expr, *variables, **kwargs):
    """Build an ``ArrayDerivative`` when the expression or any variable of
    differentiation is array-like, and a plain ``Derivative`` otherwise.

    A variable given as a ``(variable, count)`` sequence is judged by its
    first element.
    """
    from sympy.matrices.common import MatrixCommon
    from sympy import MatrixExpr
    from sympy import NDimArray
    array_types = (MatrixCommon, MatrixExpr, NDimArray, list, tuple, Tuple)

    def _is_array_like(spec):
        # (variable, count) pairs are classified by their variable
        target = spec[0] if isinstance(spec, (tuple, list, Tuple)) else spec
        return isinstance(target, array_types)

    if isinstance(expr, array_types) or any(map(_is_array_like, variables)):
        from sympy.tensor.array.array_derivatives import ArrayDerivative
        return ArrayDerivative(expr, *variables, **kwargs)
    return Derivative(expr, *variables, **kwargs)
1920
1921
class Lambda(Expr):
    """
    Lambda(x, expr) represents a lambda function similar to Python's
    'lambda x: expr'. A function of several variables is written as
    Lambda((x, y, ...), expr).

    Examples
    ========

    A simple example:

    >>> from sympy import Lambda
    >>> from sympy.abc import x
    >>> f = Lambda(x, x**2)
    >>> f(4)
    16

    For multivariate functions, use:

    >>> from sympy.abc import y, z, t
    >>> f2 = Lambda((x, y, z, t), x + y**z + t**z)
    >>> f2(1, 2, 3, 4)
    73

    It is also possible to unpack tuple arguments:

    >>> f = Lambda( ((x, y), z) , x + y + z)
    >>> f((1, 2), 3)
    6

    A handy shortcut for lots of arguments:

    >>> p = x, y, z
    >>> f = Lambda(p, x + y*z)
    >>> f(*p)
    x + y*z

    """
    is_Function = True

    def __new__(cls, signature, expr):
        # non-tuple iterables are still accepted but deprecated
        if iterable(signature) and not isinstance(signature, (tuple, Tuple)):
            SymPyDeprecationWarning(
                feature="non tuple iterable of argument symbols to Lambda",
                useinstead="tuple of argument symbols",
                issue=17474,
                deprecated_since_version="1.5").warn()
            signature = tuple(signature)
        # a lone symbol is treated as a 1-tuple signature
        sig = signature if iterable(signature) else (signature,)
        sig = sympify(sig)
        cls._check_signature(sig)

        # Lambda(x, x) collapses to the singleton identity function
        if len(sig) == 1 and sig[0] == expr:
            return S.IdentityFunction

        return Expr.__new__(cls, sig, sympify(expr))

    @classmethod
    def _check_signature(cls, sig):
        """Validate a signature: a Tuple, possibly nested, of distinct
        symbols. Raises BadSignatureError otherwise."""
        syms = set()

        def rcheck(args):
            for a in args:
                if a.is_symbol:
                    if a in syms:
                        raise BadSignatureError("Duplicate symbol %s" % a)
                    syms.add(a)
                elif isinstance(a, Tuple):
                    rcheck(a)
                else:
                    raise BadSignatureError("Lambda signature should be only tuples"
                        " and symbols, not %s" % a)

        if not isinstance(sig, Tuple):
            raise BadSignatureError("Lambda signature should be a tuple not %s" % sig)
        # Recurse through the signature:
        rcheck(sig)

    @property
    def signature(self):
        """The expected form of the arguments to be unpacked into variables"""
        return self._args[0]

    @property
    def expr(self):
        """The return value of the function"""
        return self._args[1]

    @property
    def variables(self):
        """The variables used in the internal representation of the function"""
        def _variables(args):
            # flatten the (possibly nested) signature depth-first
            if isinstance(args, Tuple):
                for arg in args:
                    yield from _variables(arg)
            else:
                yield args
        return tuple(_variables(self.signature))

    @property
    def nargs(self):
        """The number of top-level arguments, as a one-element FiniteSet."""
        from sympy.sets.sets import FiniteSet
        return FiniteSet(len(self.signature))

    bound_symbols = variables

    @property
    def free_symbols(self):
        """Free symbols of the body that are not bound by the signature."""
        return self.expr.free_symbols - set(self.variables)

    def __call__(self, *args):
        """Substitute ``args`` for the signature symbols in the body.

        Raises BadArgumentsError if the number of arguments differs from
        the signature length, or if a nested tuple argument does not match
        the corresponding nested part of the signature.
        """
        n = len(args)
        if n not in self.nargs:  # Lambda only ever has 1 value in nargs
            # XXX: exception message must be in exactly this format to
            # make it work with NumPy's functions like vectorize(). See,
            # for example, https://github.com/numpy/numpy/issues/1697.
            # The ideal solution would be just to attach metadata to
            # the exception and change NumPy to take advantage of this.
            ## XXX does this apply to Lambda? If not, remove this comment.
            temp = ('%(name)s takes exactly %(args)s '
                   'argument%(plural)s (%(given)s given)')
            raise BadArgumentsError(temp % {
                'name': self,
                'args': list(self.nargs)[0],
                'plural': 's'*(list(self.nargs)[0] != 1),
                'given': n})

        d = self._match_signature(self.signature, args)

        return self.expr.xreplace(d)

    def _match_signature(self, sig, args):
        """Map every symbol of ``sig`` to the corresponding (possibly
        nested) element of ``args``.

        Raises BadArgumentsError when a nested Tuple of the signature is
        paired with something that is not a tuple of the same length.
        """

        symargmap = {}

        def rmatch(pars, args):
            for par, arg in zip(pars, args):
                if par.is_symbol:
                    symargmap[par] = arg
                elif isinstance(par, Tuple):
                    # Compare the nested argument with the nested parameter
                    # (not the enclosing sequences): zip() truncates, so a
                    # length mismatch here would otherwise silently drop
                    # values or leave symbols unreplaced.
                    if not isinstance(arg, (tuple, Tuple)) or len(arg) != len(par):
                        raise BadArgumentsError("Can't match %s and %s" % (arg, par))
                    rmatch(par, arg)

        rmatch(sig, args)

        return symargmap

    @property
    def is_identity(self):
        """Return ``True`` if this ``Lambda`` is an identity function. """
        sig, body = self.signature, self.expr
        # A multi-argument identity returns its arguments as a Tuple equal
        # to the signature. A single-argument identity returns the lone
        # argument itself, which never equals the 1-tuple signature, so
        # that case is checked separately.
        return sig == body or (len(sig) == 1 and sig[0] == body)

    def _eval_evalf(self, prec):
        from sympy.core.evalf import prec_to_dps
        # only the body is evaluated; the signature is kept as is
        return self.func(self.args[0], self.args[1].evalf(n=prec_to_dps(prec)))
2078
2079
class Subs(Expr):
    """
    Represents unevaluated substitutions of an expression.

    ``Subs(expr, x, x0)`` represents the expression resulting
    from substituting x with x0 in expr.

    Parameters
    ==========

    expr : Expr
        An expression.

    x : tuple, variable
        A variable or list of distinct variables.

    x0 : tuple or list of tuples
        A point or list of evaluation points
        corresponding to those variables.

    Notes
    =====

    ``Subs`` objects are generally useful to represent unevaluated derivatives
    calculated at a point.

    The variables may be expressions, but they are subjected to the limitations
    of subs(), so it is usually a good practice to use only symbols for
    variables, since in that case there can be no ambiguity.

    There's no automatic expansion - use the method .doit() to effect all
    possible substitutions of the object and also of objects inside the
    expression.

    When evaluating derivatives at a point that is not a symbol, a Subs object
    is returned. One is also able to calculate derivatives of Subs objects - in
    this case the expression is always expanded (for the unevaluated form, use
    Derivative()).

    Examples
    ========

    >>> from sympy import Subs, Function, sin, cos
    >>> from sympy.abc import x, y, z
    >>> f = Function('f')

    Subs are created when a particular substitution cannot be made. The
    x in the derivative cannot be replaced with 0 because 0 is not a
    valid variable of differentiation:

    >>> f(x).diff(x).subs(x, 0)
    Subs(Derivative(f(x), x), x, 0)

    Once f is known, the derivative and evaluation at 0 can be done:

    >>> _.subs(f, sin).doit() == sin(x).diff(x).subs(x, 0) == cos(0)
    True

    Subs can also be created directly with one or more variables:

    >>> Subs(f(x)*sin(y) + z, (x, y), (0, 1))
    Subs(z + f(x)*sin(y), (x, y), (0, 1))
    >>> _.doit()
    z + f(0)*sin(1)

    Notes
    =====

    In order to allow expressions to combine before doit is done, a
    representation of the Subs expression is used internally to make
    expressions that are superficially different compare the same:

    >>> a, b = Subs(x, x, 0), Subs(y, y, 0)
    >>> a + b
    2*Subs(x, x, 0)

    This can lead to unexpected consequences when using methods
    like `has` that are cached:

    >>> s = Subs(x, x, 0)
    >>> s.has(x), s.has(y)
    (True, False)
    >>> ss = s.subs(x, y)
    >>> ss.has(x), ss.has(y)
    (True, False)
    >>> s, ss
    (Subs(x, x, 0), Subs(y, y, 0))
    """
    def __new__(cls, expr, variables, point, **assumptions):
        from sympy import Symbol

        if not is_sequence(variables, Tuple):
            variables = [variables]
        variables = Tuple(*variables)

        if has_dups(variables):
            repeated = [str(v) for v, i in Counter(variables).items() if i > 1]
            __ = ', '.join(repeated)
            raise ValueError(filldedent('''
                The following expressions appear more than once: %s
                ''' % __))

        point = Tuple(*(point if is_sequence(point, Tuple) else [point]))

        if len(point) != len(variables):
            raise ValueError('Number of point values must be the same as '
                             'the number of variables.')

        if not point:
            return sympify(expr)

        # denest
        if isinstance(expr, Subs):
            variables = expr.variables + variables
            point = expr.point + point
            expr = expr.expr
        else:
            expr = sympify(expr)

        # use symbols with names equal to the point value (with prepended _)
        # to give a variable-independent expression
        pre = "_"
        pts = sorted(set(point), key=default_sort_key)
        from sympy.printing import StrPrinter
        class CustomStrPrinter(StrPrinter):
            def _print_Dummy(self, expr):
                return str(expr) + str(expr.dummy_index)
        def mystr(expr, **settings):
            p = CustomStrPrinter(settings)
            return p.doprint(expr)
        while 1:
            s_pts = {p: Symbol(pre + mystr(p)) for p in pts}
            reps = [(v, s_pts[p])
                for v, p in zip(variables, point)]
            # if any underscore-prepended symbol is already a free symbol
            # and is a variable with a different point value, then there
            # is a clash, e.g. _0 clashes in Subs(_0 + _1, (_0, _1), (1, 0))
            # because the new symbol that would be created is _1 but _1
            # is already mapped to 0 so __0 and __1 are used for the new
            # symbols
            if any(r in expr.free_symbols and
                   r in variables and
                   Symbol(pre + mystr(point[variables.index(r)])) != r
                   for _, r in reps):
                pre += "_"
                continue
            break

        obj = Expr.__new__(cls, expr, Tuple(*variables), point)
        obj._expr = expr.xreplace(dict(reps))
        return obj

    def _eval_is_commutative(self):
        return self.expr.is_commutative

    def doit(self, **hints):
        """Carry out the substitution.

        Pairs that map a variable to itself are dropped first; if nothing
        remains, ``self.expr`` is returned. If the expression is a
        Derivative, substitutions for function classes (e.g. ``f -> cos``)
        are applied before the remaining substitutions and the
        differentiation. Unless ``deep=False`` is passed, ``doit`` is
        applied again to a result that differs from ``self``.
        """
        e, v, p = self.args

        # remove self mappings. Filter in a single pass: deleting entries
        # from ``v``/``p`` while enumerating the original pairs shifts the
        # indices, so with two or more self mappings the wrong pair (a
        # genuine substitution) would be dropped.
        kept = [(vi, pi) for vi, pi in zip(v, p) if vi != pi]
        if not kept:
            return self.expr
        v = tuple(vi for vi, _ in kept)
        p = tuple(pi for _, pi in kept)

        if isinstance(e, Derivative):
            # apply functions first, e.g. f -> cos
            undone = []
            for i, vi in enumerate(v):
                if isinstance(vi, FunctionClass):
                    e = e.subs(vi, p[i])
                else:
                    undone.append((vi, p[i]))
            if not isinstance(e, Derivative):
                e = e.doit()
            if isinstance(e, Derivative):
                # do Subs that aren't related to differentiation
                undone2 = []
                D = Dummy()
                arg = e.args[0]
                for vi, pi in undone:
                    if D not in e.xreplace({vi: D}).free_symbols:
                        if arg.has(vi):
                            e = e.subs(vi, pi)
                    else:
                        undone2.append((vi, pi))
                undone = undone2
                # differentiate wrt variables that are present
                wrt = []
                D = Dummy()
                expr = e.expr
                free = expr.free_symbols
                for vi, ci in e.variable_count:
                    if isinstance(vi, Symbol) and vi in free:
                        expr = expr.diff((vi, ci))
                    elif D in expr.subs(vi, D).free_symbols:
                        expr = expr.diff((vi, ci))
                    else:
                        wrt.append((vi, ci))
                # inject remaining subs
                rv = expr.subs(undone)
                # do remaining differentiation *in order given*
                for vc in wrt:
                    rv = rv.diff(vc)
            else:
                # inject remaining subs
                rv = e.subs(undone)
        else:
            rv = e.doit(**hints).subs(list(zip(v, p)))

        if hints.get('deep', True) and rv != self:
            rv = rv.doit(**hints)
        return rv

    def evalf(self, prec=None, **options):
        return self.doit().evalf(prec, **options)

    n = evalf

    @property
    def variables(self):
        """The variables to be evaluated"""
        return self._args[1]

    bound_symbols = variables

    @property
    def expr(self):
        """The expression on which the substitution operates"""
        return self._args[0]

    @property
    def point(self):
        """The values for which the variables are to be substituted"""
        return self._args[2]

    @property
    def free_symbols(self):
        return (self.expr.free_symbols - set(self.variables) |
            set(self.point.free_symbols))

    @property
    def expr_free_symbols(self):
        from sympy.utilities.exceptions import SymPyDeprecationWarning
        SymPyDeprecationWarning(feature="expr_free_symbols method",
                                issue=21494,
                                deprecated_since_version="1.9").warn()
        return (self.expr.expr_free_symbols - set(self.variables) |
            set(self.point.expr_free_symbols))

    def __eq__(self, other):
        if not isinstance(other, Subs):
            return False
        return self._hashable_content() == other._hashable_content()

    def __ne__(self, other):
        return not(self == other)

    def __hash__(self):
        return super().__hash__()

    def _hashable_content(self):
        # compare on the variable-independent internal expression plus
        # any (variable, point) pairs whose variable is absent from expr
        return (self._expr.xreplace(self.canonical_variables),
            ) + tuple(ordered([(v, p) for v, p in
            zip(self.variables, self.point) if not self.expr.has(v)]))

    def _eval_subs(self, old, new):
        # Subs doit will do the variables in order; the semantics
        # of subs for Subs is have the following invariant for
        # Subs object foo:
        #    foo.doit().subs(reps) == foo.subs(reps).doit()
        pt = list(self.point)
        if old in self.variables:
            if _atomic(new) == {new} and not any(
                    i.has(new) for i in self.args):
                # the substitution is neutral
                return self.xreplace({old: new})
            # any occurrence of old before this point will get
            # handled by replacements from here on
            i = self.variables.index(old)
            for j in range(i, len(self.variables)):
                pt[j] = pt[j]._subs(old, new)
            return self.func(self.expr, self.variables, pt)
        v = [i._subs(old, new) for i in self.variables]
        if v != list(self.variables):
            return self.func(self.expr, self.variables + (old,), pt + [new])
        expr = self.expr._subs(old, new)
        pt = [i._subs(old, new) for i in self.point]
        return self.func(expr, v, pt)

    def _eval_derivative(self, s):
        # Apply the chain rule of the derivative on the substitution variables:
        f = self.expr
        vp = V, P = self.variables, self.point
        val = Add.fromiter(p.diff(s)*Subs(f.diff(v), *vp).doit()
            for v, p in zip(V, P))

        # these are all the free symbols in the expr
        efree = f.free_symbols
        # some symbols like IndexedBase include themselves and args
        # as free symbols
        compound = {i for i in efree if len(i.free_symbols) > 1}
        # hide them and see what independent free symbols remain
        dums = {Dummy() for i in compound}
        masked = f.xreplace(dict(zip(compound, dums)))
        ifree = masked.free_symbols - dums
        # include the compound symbols
        free = ifree | compound
        # remove the variables already handled
        free -= set(V)
        # add back any free symbols of remaining compound symbols
        free |= {i for j in free & compound for i in j.free_symbols}
        # if symbols of s are in free then there is more to do
        if free & s.free_symbols:
            val += Subs(f.diff(s), self.variables, self.point).doit()
        return val

    def _eval_nseries(self, x, n, logx, cdir=0):
        if x in self.point:
            # x is the variable being substituted into
            apos = self.point.index(x)
            other = self.variables[apos]
        else:
            other = x
        arg = self.expr.nseries(other, n=n, logx=logx)
        o = arg.getO()
        terms = Add.make_args(arg.removeO())
        rv = Add(*[self.func(a, *self.args[1:]) for a in terms])
        if o:
            rv += o.subs(other, x)
        return rv

    def _eval_as_leading_term(self, x, logx=None, cdir=0):
        if x in self.point:
            ipos = self.point.index(x)
            xvar = self.variables[ipos]
            return self.expr.as_leading_term(xvar)
        if x in self.variables:
            # if `x` is a dummy variable, it means it won't exist after the
            # substitution has been performed:
            return self
        # The variable is independent of the substitution:
        return self.expr.as_leading_term(x)
2424
2425
def diff(f, *symbols, **kwargs):
    """
    Differentiate f with respect to symbols.

    Explanation
    ===========

    This is a thin wrapper that unifies .diff() and the Derivative class;
    its interface mirrors that of integrate(). The same shortcuts for
    multiple variables as with Derivative are available: for example,
    diff(f(x), x, x, x) and diff(f(x), x, 3) both return the third
    derivative of f(x).

    Pass evaluate=False to get an unevaluated Derivative. Note that when
    there are 0 symbols (such as ``diff(f(x), x, 0)``), the result is the
    function itself (the zeroth derivative), even if evaluate=False.

    Examples
    ========

    >>> from sympy import sin, cos, Function, diff
    >>> from sympy.abc import x, y
    >>> f = Function('f')

    >>> diff(sin(x), x)
    cos(x)
    >>> diff(f(x), x, x, x)
    Derivative(f(x), (x, 3))
    >>> diff(f(x), x, 3)
    Derivative(f(x), (x, 3))
    >>> diff(sin(x)*cos(y), x, 2, y, 2)
    sin(x)*cos(y)

    >>> type(diff(sin(x), x))
    cos
    >>> type(diff(sin(x), x, evaluate=False))
    <class 'sympy.core.function.Derivative'>
    >>> type(diff(sin(x), x, 0))
    sin
    >>> type(diff(sin(x), x, 0, evaluate=False))
    sin

    >>> diff(sin(x))
    cos(x)
    >>> diff(sin(x*y))
    Traceback (most recent call last):
    ...
    ValueError: specify differentiation variables to differentiate sin(x*y)

    Note that ``diff(sin(x))`` syntax is meant only for convenience
    in interactive sessions and should be avoided in library code.

    References
    ==========

    http://reference.wolfram.com/legacy/v5_2/Built-inFunctions/AlgebraicComputation/Calculus/D.html

    See Also
    ========

    Derivative
    idiff: computes the derivative implicitly

    """
    if not hasattr(f, 'diff'):
        # objects without their own ``diff`` go through the generic
        # dispatcher, evaluating unless the caller says otherwise
        return _derivative_dispatch(f, *symbols, **{'evaluate': True, **kwargs})
    return f.diff(*symbols, **kwargs)
2494
2495
def expand(e, deep=True, modulus=None, power_base=True, power_exp=True,
        mul=True, log=True, multinomial=True, basic=True, **hints):
    r"""
    Expand an expression using methods given as hints.

    Explanation
    ===========

    Hints evaluated unless explicitly set to False are:  ``basic``, ``log``,
    ``multinomial``, ``mul``, ``power_base``, and ``power_exp``.  The following
    hints are supported but not applied unless set to True:  ``complex``,
    ``func``, and ``trig``.  In addition, the following meta-hints are
    supported by some or all of the other hints:  ``frac``, ``numer``,
    ``denom``, ``modulus``, and ``force``.  ``deep`` is supported by all
    hints.  Additionally, subclasses of Expr may define their own hints or
    meta-hints.

    The argument ``e`` is passed through ``sympify`` first, so any input
    accepted by ``sympify`` may be given.

    The ``basic`` hint is used for any special rewriting of an object that
    should be done automatically (along with the other hints like ``mul``)
    when expand is called. This is a catch-all hint to handle any sort of
    expansion that may not be described by the existing hint names. To use
    this hint an object should override the ``_eval_expand_basic`` method.
    Objects may also define their own expand methods, which are not run by
    default.  See the API section below.

    If ``deep`` is set to ``True`` (the default), things like arguments of
    functions are recursively expanded.  Use ``deep=False`` to only expand on
    the top level.

    If the ``force`` hint is used, assumptions about variables will be ignored
    in making the expansion.

    Hints
    =====

    These hints are run by default

    mul
    ---

    Distributes multiplication over addition:

    >>> from sympy import cos, exp, sin
    >>> from sympy.abc import x, y, z
    >>> (y*(x + z)).expand(mul=True)
    x*y + y*z

    multinomial
    -----------

    Expand (x + y + ...)**n where n is a positive integer.

    >>> ((x + y + z)**2).expand(multinomial=True)
    x**2 + 2*x*y + 2*x*z + y**2 + 2*y*z + z**2

    power_exp
    ---------

    Expand addition in exponents into multiplied bases.

    >>> exp(x + y).expand(power_exp=True)
    exp(x)*exp(y)
    >>> (2**(x + y)).expand(power_exp=True)
    2**x*2**y

    power_base
    ----------

    Split powers of multiplied bases.

    This only happens by default if assumptions allow, or if the
    ``force`` meta-hint is used:

    >>> ((x*y)**z).expand(power_base=True)
    (x*y)**z
    >>> ((x*y)**z).expand(power_base=True, force=True)
    x**z*y**z
    >>> ((2*y)**z).expand(power_base=True)
    2**z*y**z

    Note that in some cases where this expansion always holds, SymPy performs
    it automatically:

    >>> (x*y)**2
    x**2*y**2

    log
    ---

    Pull out powers of an argument as a coefficient and split logs of
    products into sums of logs.

    Note that these only work if the arguments of the log function have the
    proper assumptions--the arguments must be positive and the exponents must
    be real--or else the ``force`` hint must be True:

    >>> from sympy import log, symbols
    >>> log(x**2*y).expand(log=True)
    log(x**2*y)
    >>> log(x**2*y).expand(log=True, force=True)
    2*log(x) + log(y)
    >>> x, y = symbols('x,y', positive=True)
    >>> log(x**2*y).expand(log=True)
    2*log(x) + log(y)

    basic
    -----

    This hint is intended primarily as a way for custom subclasses to enable
    expansion by default.

    These hints are not run by default:

    complex
    -------

    Split an expression into real and imaginary parts.

    >>> x, y = symbols('x,y')
    >>> (x + y).expand(complex=True)
    re(x) + re(y) + I*im(x) + I*im(y)
    >>> cos(x).expand(complex=True)
    -I*sin(re(x))*sinh(im(x)) + cos(re(x))*cosh(im(x))

    Note that this is just a wrapper around ``as_real_imag()``.  Most objects
    that wish to redefine ``_eval_expand_complex()`` should consider
    redefining ``as_real_imag()`` instead.

    func
    ----

    Expand other functions.

    >>> from sympy import gamma
    >>> gamma(x + 1).expand(func=True)
    x*gamma(x)

    trig
    ----

    Do trigonometric expansions.

    >>> cos(x + y).expand(trig=True)
    -sin(x)*sin(y) + cos(x)*cos(y)
    >>> sin(2*x).expand(trig=True)
    2*sin(x)*cos(x)

    Note that the forms of ``sin(n*x)`` and ``cos(n*x)`` in terms of ``sin(x)``
    and ``cos(x)`` are not unique, due to the identity `\sin^2(x) + \cos^2(x)
    = 1`.  The current implementation uses the form obtained from Chebyshev
    polynomials, but this may change.  See `this MathWorld article
    <http://mathworld.wolfram.com/Multiple-AngleFormulas.html>`_ for more
    information.

    Notes
    =====

    - You can shut off unwanted methods::

        >>> (exp(x + y)*(x + y)).expand()
        x*exp(x)*exp(y) + y*exp(x)*exp(y)
        >>> (exp(x + y)*(x + y)).expand(power_exp=False)
        x*exp(x + y) + y*exp(x + y)
        >>> (exp(x + y)*(x + y)).expand(mul=False)
        (x + y)*exp(x)*exp(y)

    - Use deep=False to only expand on the top level::

        >>> exp(x + exp(x + y)).expand()
        exp(x)*exp(exp(x)*exp(y))
        >>> exp(x + exp(x + y)).expand(deep=False)
        exp(x)*exp(exp(x + y))

    - Hints are applied in an arbitrary, but consistent order (in the current
      implementation, they are applied in alphabetical order, except
      multinomial comes before mul, but this may change).  Because of this,
      some hints may prevent expansion by other hints if they are applied
      first. For example, ``mul`` may distribute multiplications and prevent
      ``log`` and ``power_base`` from expanding them. Also, if ``mul`` is
      applied before ``multinomial``, the expression might not be fully
      distributed. The solution is to use the various ``expand_hint`` helper
      functions or to use ``hint=False`` to this function to finely control
      which hints are applied. Here are some examples::

        >>> from sympy import expand, expand_mul, expand_power_base
        >>> x, y, z = symbols('x,y,z', positive=True)

        >>> expand(log(x*(y + z)))
        log(x) + log(y + z)

      Here, we see that ``log`` was applied before ``mul``.  To get the mul
      expanded form, either of the following will work::

        >>> expand_mul(log(x*(y + z)))
        log(x*y + x*z)
        >>> expand(log(x*(y + z)), log=False)
        log(x*y + x*z)

      A similar thing can happen with the ``power_base`` hint::

        >>> expand((x*(y + z))**x)
        (x*y + x*z)**x

      To get the ``power_base`` expanded form, either of the following will
      work::

        >>> expand((x*(y + z))**x, mul=False)
        x**x*(y + z)**x
        >>> expand_power_base((x*(y + z))**x)
        x**x*(y + z)**x

        >>> expand((x + y)*y/x)
        y + y**2/x

      The parts of a rational expression can be targeted::

        >>> expand((x + y)*y/x/(x + 1), frac=True)
        (x*y + y**2)/(x**2 + x)
        >>> expand((x + y)*y/x/(x + 1), numer=True)
        (x*y + y**2)/(x*(x + 1))
        >>> expand((x + y)*y/x/(x + 1), denom=True)
        y*(x + y)/(x**2 + x)

    - The ``modulus`` meta-hint can be used to reduce the coefficients of an
      expression post-expansion::

        >>> expand((3*x + 1)**2)
        9*x**2 + 6*x + 1
        >>> expand((3*x + 1)**2, modulus=5)
        4*x**2 + x + 1

    - Either ``expand()`` the function or ``.expand()`` the method can be
      used.  Both are equivalent::

        >>> expand((x + 1)**2)
        x**2 + 2*x + 1
        >>> ((x + 1)**2).expand()
        x**2 + 2*x + 1

    API
    ===

    Objects can define their own expand hints by defining
    ``_eval_expand_hint()``.  The function should take the form::

        def _eval_expand_hint(self, **hints):
            # Only apply the method to the top-level expression
            ...

    See also the example below.  Objects should define ``_eval_expand_hint()``
    methods only if ``hint`` applies to that specific object.  The generic
    ``_eval_expand_hint()`` method defined in Expr will handle the no-op case.

    Each hint should be responsible for expanding that hint only.
    Furthermore, the expansion should be applied to the top-level expression
    only.  ``expand()`` takes care of the recursion that happens when
    ``deep=True``.

    You should only call ``_eval_expand_hint()`` methods directly if you are
    100% sure that the object has the method, as otherwise you are liable to
    get unexpected ``AttributeError``s.  Note, again, that you do not need to
    recursively apply the hint to args of your object: this is handled
    automatically by ``expand()``.  ``_eval_expand_hint()`` should
    generally not be used at all outside of an ``_eval_expand_hint()`` method.
    If you want to apply a specific expansion from within another method, use
    the public ``expand()`` function, method, or ``expand_hint()`` functions.

    In order for expand to work, objects must be rebuildable by their args,
    i.e., ``obj.func(*obj.args) == obj`` must hold.

    Expand methods are passed ``**hints`` so that expand hints may use
    'metahints'--hints that control how different expand methods are applied.
    For example, the ``force=True`` hint described above that causes
    ``expand(log=True)`` to ignore assumptions is such a metahint.  The
    ``deep`` meta-hint is handled exclusively by ``expand()`` and is not
    passed to ``_eval_expand_hint()`` methods.

    Note that expansion hints should generally be methods that perform some
    kind of 'expansion'.  For hints that simply rewrite an expression, use the
    .rewrite() API.

    Examples
    ========

    >>> from sympy import Expr, sympify
    >>> class MyClass(Expr):
    ...     def __new__(cls, *args):
    ...         args = sympify(args)
    ...         return Expr.__new__(cls, *args)
    ...
    ...     def _eval_expand_double(self, *, force=False, **hints):
    ...         '''
    ...         Doubles the args of MyClass.
    ...
    ...         If there are more than four args, doubling is not performed,
    ...         unless force=True is also used (False by default).
    ...         '''
    ...         if not force and len(self.args) > 4:
    ...             return self
    ...         return self.func(*(self.args + self.args))
    ...
    >>> a = MyClass(1, 2, MyClass(3, 4))
    >>> a
    MyClass(1, 2, MyClass(3, 4))
    >>> a.expand(double=True)
    MyClass(1, 2, MyClass(3, 4, 3, 4), 1, 2, MyClass(3, 4, 3, 4))
    >>> a.expand(double=True, deep=False)
    MyClass(1, 2, MyClass(3, 4), 1, 2, MyClass(3, 4))

    >>> b = MyClass(1, 2, 3, 4, 5)
    >>> b.expand(double=True)
    MyClass(1, 2, 3, 4, 5)
    >>> b.expand(double=True, force=True)
    MyClass(1, 2, 3, 4, 5, 1, 2, 3, 4, 5)

    See Also
    ========

    expand_log, expand_mul, expand_multinomial, expand_complex, expand_trig,
    expand_power_base, expand_power_exp, expand_func, sympy.simplify.hyperexpand.hyperexpand

    """
    # don't modify this; modify the Expr.expand method
    hints['power_base'] = power_base
    hints['power_exp'] = power_exp
    hints['mul'] = mul
    hints['log'] = log
    hints['multinomial'] = multinomial
    hints['basic'] = basic
    return sympify(e).expand(deep=deep, modulus=modulus, **hints)
2826
2827# This is a special application of two hints
2828
def _mexpand(expr, recursive=False):
    """Expand multinomials and then products in ``expr``.

    One pass applies ``expand_multinomial`` followed by ``expand_mul``.
    A single pass is not always enough to give a fully expanded
    expression (see test_issue_8247_8354 in test_arit); when
    ``recursive`` is True the pass is repeated until the result stops
    changing.  ``None`` is returned unchanged.
    """
    if expr is None:
        return None
    while True:
        previous = expr
        expr = expand_mul(expand_multinomial(previous))
        # stop after one pass, or once a pass leaves the expression as is
        if not recursive or expr == previous:
            return expr
2841
2842
2843# These are simple wrappers around single hints.
2844
2845
def expand_mul(expr, deep=True):
    """
    Wrapper around expand that only uses the mul hint.  See the expand
    docstring for more information.

    Every other default hint (``power_exp``, ``power_base``, ``basic``,
    ``multinomial`` and ``log``) is switched off.

    Examples
    ========

    >>> from sympy import symbols, expand_mul, exp, log
    >>> x, y = symbols('x,y', positive=True)
    >>> expand_mul(exp(x+y)*(x+y)*log(x*y**2))
    x*exp(x + y)*log(x*y**2) + y*exp(x + y)*log(x*y**2)

    """
    disabled = dict.fromkeys(
        ('power_exp', 'power_base', 'basic', 'multinomial', 'log'), False)
    return sympify(expr).expand(deep=deep, mul=True, **disabled)
2862
2863
def expand_multinomial(expr, deep=True):
    """
    Wrapper around expand that only uses the multinomial hint.  See the expand
    docstring for more information.

    Every other default hint (``mul``, ``power_exp``, ``power_base``,
    ``basic`` and ``log``) is switched off.

    Examples
    ========

    >>> from sympy import symbols, expand_multinomial, exp
    >>> x, y = symbols('x y', positive=True)
    >>> expand_multinomial((x + exp(x + 1))**2)
    x**2 + 2*x*exp(x + 1) + exp(2*x + 2)

    """
    disabled = dict.fromkeys(
        ('mul', 'power_exp', 'power_base', 'basic', 'log'), False)
    return sympify(expr).expand(deep=deep, multinomial=True, **disabled)
2880
2881
def expand_log(expr, deep=True, force=False, factor=False):
    """
    Wrapper around expand that only uses the log hint.  See the expand
    docstring for more information.

    Parameters
    ==========

    expr : Expr
        The expression to expand.  It is sympified before any processing,
        so any input accepted by ``sympify`` may be given (as with the
        other ``expand_*`` wrappers).
    deep : bool, optional
        Passed to ``expand``; if False only the top level is expanded.
    force : bool, optional
        Passed to ``expand`` as the ``force`` meta-hint, which makes the
        log expansion ignore assumptions on its arguments.
    factor : bool, optional
        Passed to ``expand`` as the ``factor`` meta-hint.  When it is False
        (the default), a pre-pass first visits every product whose
        numerator and denominator each contain a ``log`` factor with a
        Rational argument.  Each such product is re-expanded with
        ``factor=True`` followed by ``expand_mul``.  The result is kept only
        if it does not contain more ``log`` terms than before.

    Examples
    ========

    >>> from sympy import symbols, expand_log, exp, log
    >>> x, y = symbols('x,y', positive=True)
    >>> expand_log(exp(x+y)*(x+y)*log(x*y**2))
    (x + y)*(log(x) + 2*log(y))*exp(x + y)

    """
    from sympy import Mul, log

    # Sympify up front: the pre-pass below calls ``expr.replace``, which
    # plain Python inputs (ints, strings, ...) do not support.
    expr = sympify(expr)

    if factor is False:
        def _handle(x):
            # try the factored expansion; keep it only if it does not
            # increase the number of log terms
            x1 = expand_mul(expand_log(x, deep=deep, force=force, factor=True))
            if x1.count(log) <= x.count(log):
                return x1
            return x

        expr = expr.replace(
            lambda x: x.is_Mul and all(
                any(isinstance(i, log) and i.args[0].is_Rational
                    for i in Mul.make_args(j))
                for j in x.as_numer_denom()),
            _handle)

    return expr.expand(deep=deep, log=True, mul=False,
        power_exp=False, power_base=False, multinomial=False,
        basic=False, force=force, factor=factor)
2912
2913
def expand_func(expr, deep=True):
    """
    Wrapper around expand that only uses the func hint.  See the expand
    docstring for more information.

    All default hints (``basic``, ``log``, ``mul``, ``power_exp``,
    ``power_base`` and ``multinomial``) are switched off.

    Examples
    ========

    >>> from sympy import expand_func, gamma
    >>> from sympy.abc import x
    >>> expand_func(gamma(x + 2))
    x*(x + 1)*gamma(x)

    """
    disabled = dict.fromkeys(
        ('basic', 'log', 'mul', 'power_exp', 'power_base', 'multinomial'),
        False)
    return sympify(expr).expand(deep=deep, func=True, **disabled)
2930
2931
def expand_trig(expr, deep=True):
    """
    Wrapper around expand that only uses the trig hint.  See the expand
    docstring for more information.

    All default hints (``basic``, ``log``, ``mul``, ``power_exp``,
    ``power_base`` and ``multinomial``) are switched off.

    Examples
    ========

    >>> from sympy import expand_trig, sin
    >>> from sympy.abc import x, y
    >>> expand_trig(sin(x+y)*(x+y))
    (x + y)*(sin(x)*cos(y) + sin(y)*cos(x))

    """
    disabled = dict.fromkeys(
        ('basic', 'log', 'mul', 'power_exp', 'power_base', 'multinomial'),
        False)
    return sympify(expr).expand(deep=deep, trig=True, **disabled)
2948
2949
def expand_complex(expr, deep=True):
    """
    Wrapper around expand that only uses the complex hint.  See the expand
    docstring for more information.

    All default hints (``basic``, ``log``, ``mul``, ``power_exp``,
    ``power_base`` and ``multinomial``) are switched off.

    Examples
    ========

    >>> from sympy import expand_complex, exp, sqrt, I
    >>> from sympy.abc import z
    >>> expand_complex(exp(z))
    I*exp(re(z))*sin(im(z)) + exp(re(z))*cos(im(z))
    >>> expand_complex(sqrt(I))
    sqrt(2)/2 + sqrt(2)*I/2

    See Also
    ========

    sympy.core.expr.Expr.as_real_imag
    """
    disabled = dict.fromkeys(
        ('basic', 'log', 'mul', 'power_exp', 'power_base', 'multinomial'),
        False)
    return sympify(expr).expand(deep=deep, complex=True, **disabled)
2972
2973
def expand_power_base(expr, deep=True, force=False):
    """
    Wrapper around expand that only uses the power_base hint.

    A wrapper to expand(power_base=True) which separates a power with a base
    that is a Mul into a product of powers, without performing any other
    expansions, provided that assumptions about the power's base and exponent
    allow.

    deep=False (default is True) will only apply to the top-level expression.

    force=True (default is False) will cause the expansion to ignore
    assumptions about the base and exponent. When False, the expansion will
    only happen if the base is non-negative or the exponent is an integer.

    >>> from sympy.abc import x, y, z
    >>> from sympy import expand_power_base, sin, cos, exp

    >>> (x*y)**2
    x**2*y**2

    >>> (2*x)**y
    (2*x)**y
    >>> expand_power_base(_)
    2**y*x**y

    >>> expand_power_base((x*y)**z)
    (x*y)**z
    >>> expand_power_base((x*y)**z, force=True)
    x**z*y**z
    >>> expand_power_base(sin((x*y)**z), deep=False)
    sin((x*y)**z)
    >>> expand_power_base(sin((x*y)**z), force=True)
    sin(x**z*y**z)

    >>> expand_power_base((2*sin(x))**y + (2*cos(x))**y)
    2**y*sin(x)**y + 2**y*cos(x)**y

    >>> expand_power_base((2*exp(y))**x)
    2**x*exp(y)**x

    >>> expand_power_base((2*cos(x))**y)
    2**y*cos(x)**y

    Notice that sums are left untouched. If this is not the desired behavior,
    apply full ``expand()`` to the expression:

    >>> expand_power_base(((x+y)*z)**2)
    z**2*(x + y)**2
    >>> (((x+y)*z)**2).expand()
    x**2*z**2 + 2*x*y*z**2 + y**2*z**2

    >>> expand_power_base((2*y)**(1+z))
    2**(z + 1)*y**(z + 1)
    >>> ((2*y)**(1+z)).expand()
    2*2**z*y*y**z

    See Also
    ========

    expand

    """
    # only power_base stays on; every other default hint is disabled
    disabled = dict.fromkeys(
        ('log', 'mul', 'power_exp', 'multinomial', 'basic'), False)
    return sympify(expr).expand(deep=deep, power_base=True, force=force,
                                **disabled)
3040
3041
def expand_power_exp(expr, deep=True):
    """
    Wrapper around expand that only uses the power_exp hint.

    See the expand docstring for more information.  The ``complex`` hint and
    every other default hint are explicitly switched off.

    Examples
    ========

    >>> from sympy import expand_power_exp
    >>> from sympy.abc import x, y
    >>> expand_power_exp(x**(y + 2))
    x**2*x**y
    """
    disabled = dict.fromkeys(
        ('complex', 'basic', 'log', 'mul', 'power_base', 'multinomial'),
        False)
    return sympify(expr).expand(deep=deep, power_exp=True, **disabled)
3058
3059
def count_ops(expr, visual=False):
    """
    Return a representation (integer or expression) of the operations in expr.

    Parameters
    ==========

    expr : Expr
        If expr is an iterable, the sum of the op counts of the
        items will be returned.

    visual : bool, optional
        If ``False`` (default) then the sum of the coefficients of the
        visual expression will be returned.
        If ``True`` then the number of each type of operation is shown
        with the core class types (or their virtual equivalent) multiplied by the
        number of times they occur.

    Returns
    =======

    An ``int`` when ``visual`` is False, otherwise a SymPy expression
    (``S.Zero`` when no operations are found).  Objects that do not
    sympify to a SymPy ``Basic`` (e.g. ``None``) count as zero operations.

    Examples
    ========

    >>> from sympy.abc import a, b, x, y
    >>> from sympy import sin, count_ops

    Although there isn't a SUB object, minus signs are interpreted as
    either negations or subtractions:

    >>> (x - y).count_ops(visual=True)
    SUB
    >>> (-x).count_ops(visual=True)
    NEG

    Here, there are two Adds and a Pow:

    >>> (1 + a + b**2).count_ops(visual=True)
    2*ADD + POW

    In the following, an Add, Mul, Pow and two functions:

    >>> (sin(x)*x + sin(x)**2).count_ops(visual=True)
    ADD + MUL + POW + 2*SIN

    for a total of 5:

    >>> (sin(x)*x + sin(x)**2).count_ops(visual=False)
    5

    Note that "what you type" is not always what you get. The expression
    1/x/y is translated by sympy into 1/(x*y) so it gives a DIV and MUL rather
    than two DIVs:

    >>> (1/x/y).count_ops(visual=True)
    DIV + MUL

    The visual option can be used to demonstrate the difference in
    operations for expressions in different forms. Here, the Horner
    representation is compared with the expanded form of a polynomial:

    >>> eq=x*(1 + x*(2 + x*(3 + x)))
    >>> count_ops(eq.expand(), visual=True) - count_ops(eq, visual=True)
    -MUL + 3*POW

    The count_ops function also handles iterables:

    >>> count_ops([x, sin(x), None, True, x + 2], visual=False)
    2
    >>> count_ops([x, sin(x), None, True, x + 2], visual=True)
    ADD + SIN
    >>> count_ops({x: sin(x), x + 2: y + 1}, visual=True)
    2*ADD + SIN

    """
    from sympy import Integral, Sum, Symbol
    from sympy.core.relational import Relational
    from sympy.simplify.radsimp import fraction
    from sympy.logic.boolalg import BooleanFunction
    from sympy.utilities.misc import func_name

    expr = sympify(expr)
    if isinstance(expr, Expr) and not expr.is_Relational:

        ops = []
        args = [expr]
        NEG = Symbol('NEG')
        DIV = Symbol('DIV')
        SUB = Symbol('SUB')
        ADD = Symbol('ADD')
        EXP = Symbol('EXP')
        # Iterative walk over the expression tree: each node popped from
        # ``args`` may record operation symbols in ``ops`` and push the
        # sub-expressions that still need to be visited.
        while args:
            a = args.pop()

            if a.is_Rational:
                #-1/3 = NEG + DIV
                if a is not S.One:
                    if a.p < 0:
                        ops.append(NEG)
                    if a.q != 1:
                        ops.append(DIV)
                    continue
            elif a.is_Mul or a.is_MatMul:
                if _coeff_isneg(a):
                    ops.append(NEG)
                    if a.args[0] is S.NegativeOne:
                        a = a.as_two_terms()[1]
                    else:
                        a = -a
                n, d = fraction(a)
                if n.is_Integer:
                    ops.append(DIV)
                    if n < 0:
                        ops.append(NEG)
                    args.append(d)
                    continue  # won't be -Mul but could be Add
                elif d is not S.One:
                    if not d.is_Integer:
                        args.append(d)
                    ops.append(DIV)
                    args.append(n)
                    continue  # could be -Mul
            elif a.is_Add or a.is_MatAdd:
                aargs = list(a.args)
                negs = 0
                for i, ai in enumerate(aargs):
                    if _coeff_isneg(ai):
                        negs += 1
                        args.append(-ai)
                        if i > 0:
                            ops.append(SUB)
                    else:
                        args.append(ai)
                        if i > 0:
                            ops.append(ADD)
                if negs == len(aargs):  # -x - y = NEG + SUB
                    ops.append(NEG)
                elif _coeff_isneg(aargs[0]):  # -x + y = SUB, but already recorded ADD
                    ops.append(SUB - ADD)
                continue
            if a.is_Pow and a.exp is S.NegativeOne:
                ops.append(DIV)
                args.append(a.base)  # won't be -Mul but could be Add
                continue
            if a == S.Exp1:
                ops.append(EXP)
                continue
            if a.is_Pow and a.base == S.Exp1:
                ops.append(EXP)
                args.append(a.exp)
                continue
            if a.is_Mul or isinstance(a, LatticeOp):
                o = Symbol(a.func.__name__.upper())
                # count the args
                ops.append(o*(len(a.args) - 1))
            elif a.args and (
                    a.is_Pow or
                    a.is_Function or
                    isinstance(a, Derivative) or
                    isinstance(a, Integral) or
                    isinstance(a, Sum)):
                # if it's not in the list above we don't
                # consider a.func something to count, e.g.
                # Tuple, MatrixSymbol, etc...
                o = Symbol(a.func.__name__.upper())
                ops.append(o)

            if not a.is_Symbol:
                args.extend(a.args)

    elif isinstance(expr, Dict):
        ops = [count_ops(k, visual=visual) +
               count_ops(v, visual=visual) for k, v in expr.items()]
    elif iterable(expr):
        ops = [count_ops(i, visual=visual) for i in expr]
    elif isinstance(expr, (Relational, BooleanFunction)):
        ops = []
        for arg in expr.args:
            ops.append(count_ops(arg, visual=True))
        o = Symbol(func_name(expr, short=True).upper())
        ops.append(o)
    elif not isinstance(expr, Basic):
        # e.g. None: nothing to count
        ops = []
    else:
        # a Basic that is neither an Expr nor handled above: every node
        # with args is counted by its class name (Boolean nodes once per
        # extra arg)
        ops = []
        args = [expr]
        while args:
            a = args.pop()

            if a.args:
                o = Symbol(type(a).__name__.upper())
                if a.is_Boolean:
                    ops.append(o*(len(a.args)-1))
                else:
                    ops.append(o)
                args.extend(a.args)

    if not ops:
        if visual:
            return S.Zero
        return 0

    ops = Add(*ops)

    if visual:
        return ops

    if ops.is_Number:
        return int(ops)

    # each term is either a bare op symbol (count 1) or coeff*SYMBOL
    return sum(int((a.args or [1])[0]) for a in Add.make_args(ops))
3271
3272
def nfloat(expr, n=15, exponent=False, dkeys=False):
    """Make all Rationals in expr Floats except those in exponents
    (unless the ``exponent`` flag is set to True). When processing
    dictionaries, don't modify the keys unless ``dkeys=True``.

    Parameters
    ==========

    expr : object
        A SymPy expression, a matrix, or a container (``dict``, SymPy
        ``Dict`` or other iterable) of expressions; containers are
        processed element-wise and rebuilt with their own type.
    n : int, optional
        Number of digits of precision used for the created Floats.
    exponent : bool, optional
        If True, exponents of powers are converted as well.
    dkeys : bool, optional
        If True, dictionary keys are converted as well as values.

    Examples
    ========

    >>> from sympy.core.function import nfloat
    >>> from sympy.abc import x, y
    >>> from sympy import cos, pi, sqrt
    >>> nfloat(x**4 + x/2 + cos(pi/3) + 1 + sqrt(y))
    x**4 + 0.5*x + sqrt(y) + 1.5
    >>> nfloat(x**4 + sqrt(y), exponent=True)
    x**4.0 + y**0.5

    Container types are not modified:

    >>> type(nfloat((1, 2))) is tuple
    True
    """
    from sympy.core.power import Pow
    from sympy.polys.rootoftools import RootOf
    from sympy import MatrixBase

    # options forwarded unchanged to every recursive call
    kw = dict(n=n, exponent=exponent, dkeys=dkeys)

    if isinstance(expr, MatrixBase):
        return expr.applyfunc(lambda e: nfloat(e, **kw))

    # handling of iterable containers
    if iterable(expr, exclude=str):
        if isinstance(expr, (dict, Dict)):
            if dkeys:
                args = [(nfloat(k, **kw), nfloat(v, **kw))
                    for k, v in expr.items()]
            else:
                args = [(k, nfloat(v, **kw)) for k, v in expr.items()]
            if isinstance(expr, dict):
                return type(expr)(args)
            else:
                return expr.func(*args)
        elif isinstance(expr, Basic):
            return expr.func(*[nfloat(a, **kw) for a in expr.args])
        return type(expr)([nfloat(a, **kw) for a in expr])

    rv = sympify(expr)

    if rv.is_Number:
        return Float(rv, n)
    elif rv.is_number:
        # evalf doesn't always set the precision
        rv = rv.n(n)
        if rv.is_Number:
            rv = Float(rv.n(n), n)
        # otherwise pure_complex(rv) is likely True; return it as is
        return rv
    elif rv.is_Atom:
        return rv
    elif rv.is_Relational:
        args_nfloat = (nfloat(arg, **kw) for arg in rv.args)
        return rv.func(*args_nfloat)

    # watch out for RootOf instances that don't like to have
    # their exponents replaced with Dummies and also sometimes have
    # problems with evaluating at low precision (issue 6393)
    rv = rv.xreplace({ro: ro.n(n) for ro in rv.atoms(RootOf)})

    if not exponent:
        # Temporarily hide exponents behind Dummies so that evalf
        # leaves them alone, then restore the originals afterwards.
        # NOTE(review): each replacement is built from the original
        # ``p.base`` and xreplace does not descend into replacements, so
        # Pows nested inside a base (e.g. the sqrt in (1 + sqrt(x))**3)
        # may not be masked -- confirm whether their exponents should
        # stay exact.
        reps = [(p, Pow(p.base, Dummy())) for p in rv.atoms(Pow)]
        rv = rv.xreplace(dict(reps))
    rv = rv.n(n)
    if not exponent:
        rv = rv.xreplace({d.exp: p.exp for p, d in reps})
    else:
        # Pow._eval_evalf special cases Integer exponents so if
        # exponent is suppose to be handled we have to do so here
        rv = rv.xreplace(Transform(
            lambda x: Pow(x.base, Float(x.exp, n)),
            lambda x: x.is_Pow and x.exp.is_Integer))

    # recurse into function arguments with the same options
    return rv.xreplace(Transform(
        lambda x: x.func(*nfloat(x.args, **kw)),
        lambda x: isinstance(x, Function)))
3359
3360
3361from sympy.core.symbol import Dummy, Symbol
3362