1"""
2There are three types of functions implemented in Diofant:
3
4    1) defined functions (in the sense that they can be evaluated) like
5       exp or sin; they have a name and a body:
6           f = exp
    2) undefined functions, which have a name but no body. Undefined
       functions can be defined using the Function class as follows:
           f = Function('f')
       (the result is an instance of UndefinedFunction, i.e. a new
       function class)
    3) anonymous functions (or lambda functions), which have a body
       (defined with dummy variables) but no name:
13           f = Lambda(x, exp(x)*x)
14           f = Lambda((x, y), exp(x)*y)
15
16    Examples
17    ========
18
19    >>> f(x)
20    f(x)
21    >>> print(repr(f(x).func))
22    Function('f')
23    >>> f(x).args
24    (x,)
25
26"""
27
28from __future__ import annotations
29
30import collections
31import inspect
32import typing
33
34import mpmath
35import mpmath.libmp as mlib
36
37from ..utilities import default_sort_key, ordered
38from ..utilities.iterables import uniq
39from .add import Add
40from .assumptions import ManagedProperties
41from .basic import Basic
42from .cache import cacheit
43from .compatibility import as_int, is_sequence, iterable
44from .containers import Dict, Tuple
45from .decorators import _sympifyit
46from .evalf import PrecisionExhausted
47from .evaluate import global_evaluate
48from .expr import AtomicExpr, Expr
49from .logic import fuzzy_and
50from .numbers import Float, Integer, Rational, nan
51from .operations import LatticeOp
52from .rules import Transform
53from .singleton import S
54from .sympify import sympify
55
56
57def _coeff_isneg(a):
58    """Return True if the leading Number is negative.
59
60    Examples
61    ========
62
63    >>> _coeff_isneg(-3*pi)
64    True
65    >>> _coeff_isneg(Integer(3))
66    False
67    >>> _coeff_isneg(-oo)
68    True
69    >>> _coeff_isneg(Symbol('n', negative=True))  # coeff is 1
70    False
71
72    """
73    if a.is_Mul or a.is_MatMul:
74        a = a.args[0]
75    return a.is_Number and a.is_negative
76
77
class PoleError(Exception):
    """Signal that a series expansion ran into a pole.

    Used when an expression cannot be expanded around the requested point.
    """
80
81
class ArgumentIndexError(ValueError):
    """Raised when an invalid operation for positional argument happened.

    It is normally constructed as ``ArgumentIndexError(func, argindex)``
    (e.g. ``Function.fdiff`` raises ``ArgumentIndexError(self, argindex)``),
    where ``argindex`` is the 1-based position of the offending argument.

    """

    def __str__(self):
        # Without the (func, argindex) pair there is nothing to format; use
        # the plain ValueError message instead of raising IndexError from
        # inside __str__ (which would hide the original error in tracebacks).
        if len(self.args) < 2:
            return super().__str__()
        func, argindex = self.args[0], self.args[1]
        return ('Invalid operation with argument number %s for Function %s' %
                (argindex, func))
88
89
class FunctionClass(ManagedProperties):
    """
    Base class for function classes. FunctionClass is a subclass of type.

    Use Function('<function name>' [ , signature ]) to create
    undefined function classes.

    """

    def __init__(self, *args, **kwargs):
        """Initialize a new function class and record its allowed arities.

        The allowed numbers of arguments are taken, in order of priority,
        from a ``nargs`` keyword argument, a ``nargs`` attribute defined
        directly in this class's own body (inherited values are not
        consulted here), or the signature of the class's ``eval`` method.
        The canonical form is stored in ``self._nargs``: a sorted tuple of
        ints, or None when any number of arguments is allowed.

        Raises
        ======

        ValueError
            If ``nargs`` is given as an empty sequence.

        """
        # NOTE(review): ``assert`` is stripped under ``python -O``.
        assert hasattr(self, 'eval')
        evalargspec = inspect.getfullargspec(self.eval)
        if evalargspec.varargs:
            # ``eval(cls, *args)`` accepts any number of arguments.
            evalargs = None
        else:
            evalargs = len(evalargspec.args) - 1  # subtract 1 for cls
            if evalargspec.defaults:
                # if there are default args then they are optional; the
                # fewest args will occur when all defaults are used and
                # the most when none are used (i.e. all args are given)
                evalargs = tuple(range(evalargs - len(evalargspec.defaults),
                                       evalargs + 1))
        # honor kwarg value or class-defined value before using
        # the number of arguments in the eval function (if present)
        nargs = kwargs.pop('nargs', self.__dict__.get('nargs', evalargs))
        # NOTE(review): the tuple and dict are passed as two positional
        # arguments rather than unpacked; presumably ManagedProperties
        # tolerates this - confirm before changing.
        super().__init__(args, kwargs)

        # Canonicalize nargs here; change to set in nargs.
        if is_sequence(nargs):
            if not nargs:
                raise ValueError('Incorrectly specified nargs as %s' % str(nargs))
            nargs = tuple(ordered(set(nargs)))
        elif nargs is not None:
            # A single integer becomes a one-element tuple.
            nargs = as_int(nargs),
        self._nargs = nargs

    @property
    def __signature__(self):
        """
        Allow inspect.signature to give a useful signature for
        Function subclasses.

        """
        # TODO: Look at nargs
        return inspect.signature(self.eval)

    @property
    def nargs(self):
        """Return a set of the allowed number of arguments for the function.

        Examples
        ========

        If the function can take any number of arguments, the set of whole
        numbers is returned:

        >>> Function('f').nargs
        Naturals0()

        If the function was initialized to accept one or more arguments, a
        corresponding set will be returned:

        >>> Function('f', nargs=1).nargs
        {1}
        >>> Function('f', nargs=(2, 1)).nargs
        {1, 2}

        The undefined function, after application, also has the nargs
        attribute; the actual number of arguments is always available by
        checking the ``args`` attribute:

        >>> f(1).nargs
        Naturals0()
        >>> len(f(1).args)
        1

        """
        from ..sets.sets import FiniteSet

        # XXX it would be nice to handle this in __init__ but there are import
        # problems with trying to import FiniteSet there
        # ``_nargs`` is None (any arity) or a sorted tuple set by __init__.
        return FiniteSet(*self._nargs) if self._nargs else S.Naturals0

    def __repr__(self):
        # Undefined functions print as the constructor call that made them.
        if issubclass(self, AppliedUndef):
            return f'Function({self.__name__!r})'
        else:
            return self.__name__

    def __str__(self):
        return self.__name__
181
182
class Application(Expr, metaclass=FunctionClass):
    """
    Base class for applied functions.

    Instances of Application represent the result of applying an application of
    any type to any object.

    """

    # Flag read elsewhere as ``expr.is_Function``.
    is_Function = True

    @cacheit
    def __new__(cls, *args, **options):
        """Build an application of ``cls`` to ``args``, evaluating if allowed.

        All arguments are sympified.  Unless evaluation is disabled (the
        ``evaluate`` option, defaulting to ``global_evaluate[0]``), a
        ``nan`` argument makes the result ``nan``, and a non-None result of
        ``cls.eval(*args)`` is returned as-is.  Otherwise a new unevaluated
        instance is returned whose ``nargs`` attribute is a FiniteSet of the
        allowed argument counts, or ``Naturals0()`` when any count is allowed.

        Any ``nargs`` option is silently discarded.

        Raises
        ======

        ValueError
            If options other than ``evaluate`` and ``nargs`` are given.

        """
        from ..sets.fancysets import Naturals0
        from ..sets.sets import FiniteSet

        args = list(map(sympify, args))
        evaluate = options.pop('evaluate', global_evaluate[0])
        # WildFunction (and anything else like it) may have nargs defined
        # and we throw that value away here
        options.pop('nargs', None)

        if options:
            raise ValueError(f'Unknown options: {options}')

        if evaluate:
            # Any nan argument short-circuits the whole application to nan.
            if nan in args:
                return nan

            evaluated = cls.eval(*args)
            if evaluated is not None:
                return evaluated

        # ``options`` is necessarily empty at this point.
        obj = super().__new__(cls, *args, **options)

        # make nargs uniform here
        try:
            # things passing through here:
            #  - functions subclassed from Function (e.g. myfunc(1).nargs)
            #  - functions like cos(1).nargs
            #  - AppliedUndef with given nargs like Function('f', nargs=1)(1).nargs
            # Canonicalize nargs here
            if is_sequence(obj.nargs):
                nargs = tuple(ordered(set(obj.nargs)))
            elif obj.nargs is not None:
                nargs = as_int(obj.nargs),
            else:
                nargs = None
        except AttributeError:
            # things passing through here:
            #  - WildFunction('f').nargs
            #  - AppliedUndef with no nargs like Function('f')(1).nargs
            nargs = obj._nargs  # note the underscore here

        # Store the canonical set on the instance itself.
        obj.nargs = FiniteSet(*nargs) if nargs else Naturals0()
        return obj

    @classmethod
    def eval(cls, *args):
        """
        Returns a canonical form of cls applied to arguments args.

        The eval() method is called when the class cls is about to be
        instantiated and it should return either some simplified instance
        (possible of some other class), or if the class cls should be
        unmodified, return None.

        """
        return

    def _eval_subs(self, old, new):
        """Replace the applied function itself, keeping the arguments.

        When ``old`` is this expression's function and ``new`` is a function
        accepting ``len(self.args)`` arguments, return ``new(*self.args)``;
        otherwise return None (no substitution handled here).

        """
        if (old.is_Function and new.is_Function and old == self.func and
                len(self.args) in new.nargs):
            return new(*self.args)
257
258
class Function(Application, Expr):
    """Base class for applied mathematical functions.

    It also serves as a constructor for undefined function classes.

    Examples
    ========

    First example shows how to use Function as a constructor for undefined
    function classes:

    >>> g = g(x)
    >>> f
    f
    >>> f(x)
    f(x)
    >>> g
    g(x)
    >>> f(x).diff(x)
    Derivative(f(x), x)
    >>> g.diff(x)
    Derivative(g(x), x)

    In the following example Function is used as a base class for
    ``MyFunc`` that represents a mathematical function *MyFunc*. Suppose
    that it is well known, that *MyFunc(0)* is *1* and *MyFunc* at infinity
    goes to *0*, so we want those two simplifications to occur automatically.
    Suppose also that *MyFunc(x)* is real exactly when *x* is real. Here is
    an implementation that honours those requirements:

    >>> class MyFunc(Function):
    ...
    ...     @classmethod
    ...     def eval(cls, x):
    ...         if x.is_Number:
    ...             if x == 0:
    ...                 return Integer(1)
    ...             elif x is oo:
    ...                 return Integer(0)
    ...
    ...     def _eval_is_real(self):
    ...         return self.args[0].is_real
    ...
    >>> MyFunc(0) + sin(0)
    1
    >>> MyFunc(oo)
    0
    >>> MyFunc(3.54).evalf()  # Not yet implemented for MyFunc.
    MyFunc(3.54)
    >>> MyFunc(I).is_real
    False

    In order for ``MyFunc`` to become useful, several other methods would
    need to be implemented. See source code of some of the already
    implemented functions for more complete examples.

    Also, if the function can take more than one argument, then ``nargs``
    must be defined, e.g. if ``MyFunc`` can take one or two arguments
    then,

    >>> class MyFunc(Function):
    ...     nargs = (1, 2)
    ...
    >>>

    """

    @property
    def _diff_wrt(self):
        """Allow derivatives wrt functions.

        Examples
        ========

        >>> f(x)._diff_wrt
        True

        """
        return True

    @cacheit
    def __new__(cls, *args, **options):
        """Apply ``cls`` to ``args``, or create an undefined function class.

        ``Function('f', ...)`` returns a new UndefinedFunction.  For
        subclasses, the argument count is checked against ``cls.nargs``;
        when evaluation is enabled and every argument is floating point
        (see ``_should_evalf``), the result is numerically evaluated at the
        highest precision found among the arguments.

        Raises
        ======

        TypeError
            If ``len(args)`` is not in ``cls.nargs``.

        """
        # Handle calls like Function('f')
        if cls is Function:
            return UndefinedFunction(*args, **options)

        n = len(args)
        if n not in cls.nargs:
            # XXX: exception message must be in exactly this format to
            # make it work with NumPy's functions like vectorize(). See,
            # for example, https://github.com/numpy/numpy/issues/1697.
            # The ideal solution would be just to attach metadata to
            # the exception and change NumPy to take advantage of this.
            temp = ('%(name)s takes %(qual)s %(args)s '
                    'argument%(plural)s (%(given)s given)')
            raise TypeError(temp % {
                'name': cls,
                'qual': 'exactly' if len(cls.nargs) == 1 else 'at least',
                'args': min(cls.nargs),
                'plural': 's'*(min(cls.nargs) != 1),
                'given': n})

        evaluate = options.get('evaluate', global_evaluate[0])
        result = super().__new__(cls, *args, **options)
        # A zero-argument application (e.g. ``Function('f')()``) has no
        # arguments to inspect; min()/max() would raise on an empty sequence.
        if not evaluate or not isinstance(result, cls) or not result.args:
            return result

        # Evaluate numerically only if *every* argument is floating point.
        pr2 = min(cls._should_evalf(a) for a in result.args)
        if pr2 > 0:
            pr = max(cls._should_evalf(a) for a in result.args)
            return result.evalf(mlib.libmpf.prec_to_dps(pr), strict=False)
        return result

    @classmethod
    def _should_evalf(cls, arg):
        """
        Decide if the function should automatically evalf().

        By default (in this implementation), this happens if (and only if) the
        ARG is a floating point number.
        This function is used by __new__.

        Returns the binary precision of the Float (or of the most precise
        Float real/imaginary part of an Add), or -1 if there is none.

        """
        if arg.is_Float:
            return arg._prec
        if not arg.is_Add:
            return -1
        re, im = arg.as_real_imag()
        l = [a._prec for a in [re, im] if a.is_Float]
        l.append(-1)
        return max(l)

    @classmethod
    def class_key(cls):
        """Nice order of classes."""
        from ..sets.fancysets import Naturals0
        funcs = {
            'log': 11,
            'sin': 20,
            'cos': 21,
            'tan': 22,
            'cot': 23,
            'sinh': 30,
            'cosh': 31,
            'tanh': 32,
            'coth': 33,
            'conjugate': 40,
            're': 41,
            'im': 42,
            'arg': 43,
        }
        name = cls.__name__

        try:
            i = funcs[name]
        except KeyError:
            i = 0 if isinstance(cls.nargs, Naturals0) else 10000

        return 4, i, name

    def _eval_evalf(self, prec):
        # Lookup mpmath function based on name
        try:
            if isinstance(self.func, UndefinedFunction):
                # Shouldn't lookup in mpmath but might have ._imp_
                raise AttributeError
            fname = self.func.__name__
            if not hasattr(mpmath, fname):
                from ..utilities.lambdify import MPMATH_TRANSLATIONS
                fname = MPMATH_TRANSLATIONS[fname]
            func = getattr(mpmath, fname)
        except (AttributeError, KeyError):
            try:
                return Float(self._imp_(*[i.evalf(prec) for i in self.args]), prec)
            except (AttributeError, TypeError, ValueError, PrecisionExhausted):
                return

        # Convert all args to mpf or mpc
        # Convert the arguments to *higher* precision than requested for the
        # final result.
        # XXX + 5 is a guess, it is similar to what is used in evalf.py. Should
        #     we be more intelligent about it?
        try:
            args = [arg._to_mpmath(prec + 5) for arg in self.args]
        except ValueError:
            return

        with mpmath.workprec(prec):
            v = func(*args)

        return Expr._from_mpmath(v, prec)

    def _eval_derivative(self, s):
        """Differentiate wrt ``s`` by the chain rule over all arguments."""
        # f(x).diff(s) -> x.diff(s) * f.fdiff(1)(s)
        terms = []
        for i, a in enumerate(self.args, start=1):
            da = a.diff(s)
            if da == 0:
                continue
            try:
                df = self.fdiff(i)
            except ArgumentIndexError:
                df = Function.fdiff(self, i)
            terms.append(df * da)
        return Add(*terms)

    def _eval_is_commutative(self):
        return fuzzy_and(a.is_commutative for a in self.args)

    def as_base_exp(self):
        """Returns the method as the 2-tuple (base, exponent)."""
        return self, Integer(1)

    def _eval_aseries(self, n, args0, x, logx):
        """
        Compute an asymptotic expansion around args0, in terms of self.args.
        This function is only used internally by _eval_nseries and should not
        be called directly; derived classes can overwrite this to implement
        asymptotic expansions.

        """
        from ..utilities.misc import filldedent
        raise PoleError(filldedent("""
            Asymptotic expansion of %s around %s is
            not implemented.""" % (type(self), args0)))

    def _eval_nseries(self, x, n, logx):
        """
        This function does compute series for multivariate functions,
        but the expansion is always in terms of *one* variable.

        Examples
        ========

        >>> atan2(x, y).series(x, n=2)
        atan2(0, y) + x/y + O(x**2)
        >>> atan2(x, y).series(y, n=2)
        -y/x + atan2(x, 0) + O(y**2)

        This function also computes asymptotic expansions, if necessary
        and possible:

        >>> loggamma(1/x)._eval_nseries(x, 0, None)
        -1/x - log(x)/x + log(x)/2 + O(1)

        """
        from ..series import Order
        from ..sets.sets import FiniteSet
        from .symbol import Dummy
        args = self.args
        args0 = [t.limit(x, 0) for t in args]
        if any(isinstance(t, Expr) and t.is_finite is False for t in args0):
            from .numbers import oo, zoo

            # XXX could use t.as_leading_term(x) here but it's a little
            # slower
            a = [t.compute_leading_term(x, logx=logx) for t in args]
            a0 = [t.limit(x, 0) for t in a]
            if any(t.has(oo, -oo, zoo, nan) for t in a0):
                return self._eval_aseries(n, args0, x, logx)
            # Careful: the argument goes to oo, but only logarithmically so. We
            # are supposed to do a power series expansion "around the
            # logarithmic term". e.g.
            #      f(1+x+log(x))
            #     -> f(1+logx) + x*f'(1+logx) + O(x**2)
            # where 'logx' is given in the argument
            a = [t._eval_nseries(x, n, logx) for t in args]
            z = [r - r0 for (r, r0) in zip(a, a0)]
            q = []
            # ``v`` is a dummy standing for the single x-dependent offset and
            # ``vz`` is that offset.  Both are remembered here so that the
            # back-substitution below uses the offset of the x-dependent
            # argument, not whatever ``zi`` the loop happened to end on.
            v = vz = None
            for ai, zi in zip(a0, z):
                if zi.has(x):
                    if v is not None:
                        raise NotImplementedError
                    v, vz = Dummy(), zi
                    q.append(ai + v)
                else:
                    q.append(ai)
            e1 = self.func(*q)
            if v is None:
                return e1
            s = e1._eval_nseries(v, n, logx)
            o = s.getO()
            s = s.removeO().subs({v: vz}).expand()
            if o is None:
                # The expansion in v was exact: there is no order term.
                return s
            return s + Order(o.expr.subs({v: vz}), x)
        if (self.func.nargs is S.Naturals0
                or (self.func.nargs == FiniteSet(1) and args0[0])
                or any(c > 1 for c in self.func.nargs)):
            e = self
            e1 = e.expand()
            if e == e1:
                # for example when e = sin(x+1) or e = sin(cos(x))
                # let's try the general algorithm
                term = e.subs({x: 0})
                if term.is_finite is False:
                    raise PoleError(f'Cannot expand {self} around 0')
                series = term
                fact = Integer(1)
                _x = Dummy('x', real=True, positive=True)
                e = e.subs({x: _x})
                # Add the Taylor terms of orders 1 .. n-1.
                for i in range(1, n):
                    fact *= Rational(i)
                    e = e.diff(_x)
                    subs = e.subs({_x: 0})
                    term = subs*(x**i)/fact
                    term = term.expand()
                    series += term
                return series + Order(x**n, x)
            return e1.nseries(x, n=n, logx=logx)
        arg = self.args[0]
        f_series = order = Integer(0)
        i, terms = 0, []
        while order == 0 or i <= n:
            term = self.taylor_term(i, arg, *terms)
            term = term.nseries(x, n=n, logx=logx)
            terms.append(term)
            if term:
                f_series += term
            order = Order(term, x)
            i += 1
        return f_series + order

    def fdiff(self, argindex=1):
        """Returns the first derivative of the function."""
        from .symbol import Dummy

        if not (1 <= argindex <= len(self.args)):
            raise ArgumentIndexError(self, argindex)

        if self.args[argindex - 1].is_Symbol:
            for i, arg in enumerate(self.args):
                if i == argindex - 1:
                    continue
                # See issue sympy/sympy#8510
                if self.args[argindex - 1] in arg.free_symbols:
                    break
            else:
                return Derivative(self, self.args[argindex - 1], evaluate=False)
        # See issue sympy/sympy#4624 and issue sympy/sympy#4719
        # and issue sympy/sympy#5600
        arg_dummy = Dummy(f'xi_{argindex:d}')
        arg_dummy.dummy_index = hash(self.args[argindex - 1])
        new_args = list(self.args)
        new_args[argindex-1] = arg_dummy
        return Subs(Derivative(self.func(*new_args), arg_dummy),
                    (arg_dummy, self.args[argindex - 1]))

    def _eval_as_leading_term(self, x):
        """Stub that should be overridden by new Functions to return
        the first non-zero term in a series if ever an x-dependent
        argument whose leading term vanishes as x -> 0 might be encountered.
        See, for example, cos._eval_as_leading_term.

        """
        from ..series import Order
        args = [a.as_leading_term(x) for a in self.args]
        o = Order(1, x)
        if any(x in a.free_symbols and o.contains(a) for a in args):
            # Whereas x and any finite number are contained in O(1, x),
            # expressions like 1/x are not. If any arg simplified to a
            # vanishing expression as x -> 0 (like x or x**2, but not
            # 3, 1/x, etc...) then the _eval_as_leading_term is needed
            # to supply the first non-zero term of the series,
            #
            # e.g. expression    leading term
            #      ----------    ------------
            #      cos(1/x)      cos(1/x)
            #      cos(cos(x))   cos(1)
            #      cos(x)        1        <- _eval_as_leading_term needed
            #      sin(x)        x        <- _eval_as_leading_term needed
            #
            raise NotImplementedError(
                f'{self.func} has no _eval_as_leading_term routine')
        else:
            return self.func(*args)
637
638
class AppliedUndef(Function):
    """
    Base class for the expressions obtained by applying an undefined
    function (a class created by ``Function('name')``) to arguments.

    """

    def __new__(cls, *args, **options):
        # Sympify up front, before handing off to the (cached)
        # Function.__new__.
        sympified = [sympify(arg) for arg in args]
        return super().__new__(cls, *sympified, **options)

    def _eval_as_leading_term(self, x):
        # Nothing is known about an undefined function, so the expression
        # is returned unchanged as its own leading term.
        return self
653
654
class UndefinedFunction(FunctionClass):
    """Metaclass of undefined functions.

    ``UndefinedFunction(name, **kwargs)`` builds a new subclass of
    AppliedUndef called ``name`` whose class namespace is ``kwargs``.

    """

    def __new__(cls, name, **kwargs):
        undef_cls = type.__new__(cls, name, (AppliedUndef,), kwargs)
        undef_cls.__module__ = None
        return undef_cls

    def __instancecheck__(self, instance):
        # ``isinstance(expr, f)`` holds when ``f`` appears anywhere in the
        # MRO of ``expr``'s type.
        mro = type(instance).__mro__
        return self in mro

    def __eq__(self, other):
        # Equal when ``other`` is an instance of the same metaclass and the
        # two classes share the same class_key().
        if not isinstance(other, self.__class__):
            return False
        return self.class_key() == other.class_key()

    def __hash__(self):
        # Needed because defining __eq__ would otherwise make the class
        # unhashable.
        # NOTE(review): this hash is the inherited (identity-based) one while
        # __eq__ compares class_key(), so two equal classes may hash
        # differently - confirm this is intended before relying on it.
        return super().__hash__()
672
673
class WildFunction(Function, AtomicExpr):
    """
    A WildFunction function matches any function (with its arguments).

    Examples
    ========

    >>> F = WildFunction('F')
    >>> F.nargs
    Naturals0()
    >>> x.match(F)
    >>> F.match(F)
    {F_: F_}
    >>> f(x).match(F)
    {F_: f(x)}
    >>> cos(x).match(F)
    {F_: cos(x)}
    >>> f(x, y).match(F)
    {F_: f(x, y)}

    To match functions with a given number of arguments, set ``nargs`` to the
    desired value at instantiation:

    >>> F = WildFunction('F', nargs=2)
    >>> F.nargs
    {2}
    >>> f(x).match(F)
    >>> f(x, y).match(F)
    {F_: f(x, y)}

    To match functions with a range of arguments, set ``nargs`` to a tuple
    containing the desired number of arguments, e.g. if ``nargs = (1, 2)``
    then functions with 1 or 2 arguments will be matched.

    >>> F = WildFunction('F', nargs=(1, 2))
    >>> F.nargs
    {1, 2}
    >>> f(x).match(F)
    {F_: f(x)}
    >>> f(x, y).match(F)
    {F_: f(x, y)}
    >>> f(x, y, 1).match(F)

    """

    include: set[typing.Any] = set()

    def __init__(self, name, **assumptions):
        """Store ``name`` and the set of argument counts this wild accepts.

        ``nargs`` may be a Set (used as-is), a sequence of ints, or a single
        int; it defaults to ``S.Naturals0`` (any number of arguments).  Other
        keyword arguments are not used by this method.

        """
        from ..sets.sets import FiniteSet, Set
        self.name = name
        nargs = assumptions.pop('nargs', S.Naturals0)
        if not isinstance(nargs, Set):
            # Canonicalize nargs here.  See also FunctionClass.
            if is_sequence(nargs):
                nargs = tuple(ordered(set(nargs)))
            else:
                nargs = as_int(nargs),
            nargs = FiniteSet(*nargs)
        self.nargs = nargs

    def _matches(self, expr, repl_dict=None):
        """Helper method for match()

        Returns a copy of ``repl_dict`` extended with ``{self: expr}`` when
        ``expr`` is a function application with an accepted number of
        arguments, else None.  ``repl_dict`` itself is never mutated.

        See Also
        ========

        diofant.core.basic.Basic.matches

        """
        # AppliedUndef subclasses Function, so this covers applied undefined
        # functions as well as defined ones such as cos.
        if not isinstance(expr, Function):
            return
        if len(expr.args) not in self.nargs:
            return

        # Avoid a shared mutable default; always work on a fresh dict.
        repl_dict = {} if repl_dict is None else repl_dict.copy()
        repl_dict[self] = expr
        return repl_dict
751
752
753class Derivative(Expr):
754    """
755    Carries out differentiation of the given expression with respect to symbols.
756
757    expr must define ._eval_derivative(symbol) method that returns
758    the differentiation result. This function only needs to consider the
759    non-trivial case where expr contains symbol and it should call the diff()
760    method internally (not _eval_derivative); Derivative should be the only
761    one to call _eval_derivative.
762
763    Simplification of high-order derivatives:
764
765    Because there can be a significant amount of simplification that can be
766    done when multiple differentiations are performed, results will be
767    automatically simplified in a fairly conservative fashion unless the
768    keyword ``simplify`` is set to False.
769
770        >>> e = sqrt((x + 1)**2 + x)
771        >>> diff(e, (x, 5), simplify=False).count_ops()
772        136
773        >>> diff(e, (x, 5)).count_ops()
774        30
775
776    Ordering of variables:
777
778    If evaluate is set to True and the expression can not be evaluated, the
779    list of differentiation symbols will be sorted, that is, the expression is
780    assumed to have continuous derivatives up to the order asked. This sorting
781    assumes that derivatives wrt Symbols commute, derivatives wrt non-Symbols
782    commute, but Symbol and non-Symbol derivatives don't commute with each
783    other.
784
785    Derivative wrt non-Symbols:
786
787    This class also allows derivatives wrt non-Symbols that have _diff_wrt
788    set to True, such as Function and Derivative. When a derivative wrt a non-
789    Symbol is attempted, the non-Symbol is temporarily converted to a Symbol
790    while the differentiation is performed.
791
792    Note that this may seem strange, that Derivative allows things like
793    f(g(x)).diff(g(x)), or even f(cos(x)).diff(cos(x)).  The motivation for
794    allowing this syntax is to make it easier to work with variational calculus
795    (i.e., the Euler-Lagrange method).  The best way to understand this is that
796    the action of derivative with respect to a non-Symbol is defined by the
797    above description:  the object is substituted for a Symbol and the
798    derivative is taken with respect to that.  This action is only allowed for
799    objects for which this can be done unambiguously, for example Function and
800    Derivative objects.  Note that this leads to what may appear to be
801    mathematically inconsistent results.  For example::
802
803        >>> (2*cos(x)).diff(cos(x))
804        2
805        >>> (2*sqrt(1 - sin(x)**2)).diff(cos(x))
806        0
807
808    This appears wrong because in fact 2*cos(x) and 2*sqrt(1 - sin(x)**2) are
809    identically equal.  However this is the wrong way to think of this.  Think
810    of it instead as if we have something like this::
811
812        >>> from diofant.abc import s
813        >>> def f(u):
814        ...     return 2*u
815        ...
816        >>> def g(u):
817        ...     return 2*sqrt(1 - u**2)
818        ...
819        >>> f(cos(x))
820        2*cos(x)
821        >>> g(sin(x))
822        2*sqrt(-sin(x)**2 + 1)
823        >>> f(c).diff(c)
824        2
825        >>> f(c).diff(c)
826        2
827        >>> g(s).diff(c)
828        0
829        >>> g(sin(x)).diff(cos(x))
830        0
831
832    Here, the Symbols c and s act just like the functions cos(x) and sin(x),
833    respectively. Think of 2*cos(x) as f(c).subs({c: cos(x)}) (or f(c) *at*
834    c = cos(x)) and 2*sqrt(1 - sin(x)**2) as g(s).subs({s: sin(x)}) (or g(s) *at*
835    s = sin(x)), where f(u) == 2*u and g(u) == 2*sqrt(1 - u**2).  Here, we
836    define the function first and evaluate it at the function, but we can
837    actually unambiguously do this in reverse in Diofant, because
838    expr.subs({Function: Symbol}) is well-defined:  just structurally replace the
839    function everywhere it appears in the expression.
840
841    This is the same notational convenience used in the Euler-Lagrange method
842    when one says F(t, f(t), f'(t)).diff(f(t)).  What is actually meant is
843    that the expression in question is represented by some F(t, u, v) at u =
844    f(t) and v = f'(t), and F(t, f(t), f'(t)).diff(f(t)) simply means F(t, u,
845    v).diff(u) at u = f(t).
846
847    We do not allow derivatives to be taken with respect to expressions where this
848    is not so well defined.  For example, we do not allow expr.diff(x*y)
849    because there are multiple ways of structurally defining where x*y appears
850    in an expression, some of which may surprise the reader (for example, a
851    very strict definition would have that (x*y*z).diff(x*y) == 0).
852
853        >>> (x*y*z).diff(x*y)
854        Traceback (most recent call last):
855        ...
856        ValueError: Can't differentiate wrt the variable: x*y, 1
857
858    Note that this definition also fits in nicely with the definition of the
859    chain rule.  Note how the chain rule in Diofant is defined using unevaluated
860    Subs objects::
861
862        >>> f, g = symbols('f g', cls=Function)
863        >>> f(2*g(x)).diff(x)
864        2*Derivative(g(x), x)*Subs(Derivative(f(_xi_1), _xi_1), (_xi_1, 2*g(x)))
865        >>> f(g(x)).diff(x)
866        Derivative(g(x), x)*Subs(Derivative(f(_xi_1), _xi_1), (_xi_1, g(x)))
867
868    Finally, note that, to be consistent with variational calculus, and to
869    ensure that the definition of substituting a Function for a Symbol in an
870    expression is well-defined, derivatives of functions are assumed to not be
871    related to the function.  In other words, we have::
872
873        >>> diff(f(x), x).diff(f(x))
874        0
875
876    The same is true for derivatives of different orders::
877
878        >>> diff(f(x), (x, 2)).diff(diff(f(x), (x, 1)))
879        0
880        >>> diff(f(x), (x, 1)).diff(diff(f(x), (x, 2)))
881        0
882
883    Note, any class can allow derivatives to be taken with respect to itself.
884
885    Examples
886    ========
887
888    Some basic examples:
889
890        >>> Derivative(x**2, x, evaluate=True)
891        2*x
892        >>> Derivative(Derivative(f(x, y), x), y)
893        Derivative(f(x, y), x, y)
894        >>> Derivative(f(x), (x, 3))
895        Derivative(f(x), x, x, x)
896        >>> Derivative(f(x, y), y, x, evaluate=True)
897        Derivative(f(x, y), x, y)
898
899    Now some derivatives wrt functions:
900
901        >>> Derivative(f(x)**2, f(x), evaluate=True)
902        2*f(x)
903        >>> Derivative(f(g(x)), x, evaluate=True)
904        Derivative(g(x), x)*Subs(Derivative(f(_xi_1), _xi_1), (_xi_1, g(x)))
905
906    """
907
908    is_Derivative = True
909
910    @property
911    def _diff_wrt(self):
912        """Allow derivatives wrt Derivatives if it contains a function.
913
914        Examples
915        ========
916
917            >>> Derivative(f(x), x)._diff_wrt
918            True
919            >>> Derivative(x**2, x)._diff_wrt
920            False
921
922        """
923        if self.expr.is_Function:
924            return True
925        else:
926            return False
927
    def __new__(cls, expr, *args, **assumptions):
        """Create a (possibly evaluated) derivative of ``expr``.

        Parameters
        ==========

        expr : Expr
            The expression to differentiate; it is sympified first.
        *args
            Differentiation variables, each either a bare variable or a
            ``(variable, count)`` pair.  If omitted, ``expr`` must have
            exactly one free symbol, which is then used.
        **assumptions
            ``evaluate`` (default False) is popped and controls whether the
            derivative is actually computed.  ``simplify`` (default True) is
            read to decide whether to post-process the result with
            ``signsimp``/``factor_terms`` after more than one derivative
            step.  All remaining keywords (including ``simplify``) are passed
            to ``Expr.__new__`` when an unevaluated object is built directly.

        Returns
        =======

        ``expr`` itself if every count is zero, ``Integer(0)`` when the
        derivative is found to vanish, otherwise the computed derivative
        or an unevaluated ``Derivative`` over the variables that could
        not be handled.

        Raises
        ======

        ValueError
            If no variables are given and ``expr`` does not have exactly one
            free symbol, or if some variable's ``_diff_wrt`` is false.

        """
        from .symbol import Dummy

        expr = sympify(expr)

        # No variables were given: this is allowed only when expr has exactly
        # one free symbol, which then becomes the differentiation variable.
        if not args:
            variables = expr.free_symbols
            args = tuple(variables)
            if len(variables) != 1:
                from ..utilities.misc import filldedent
                raise ValueError(filldedent("""
                    The variable(s) of differentiation
                    must be supplied to differentiate %s""" % expr))

        # Standardize the args by sympifying them and appending a count of 1
        # to every bare variable: diff(e, x) -> diff(e, (x, 1)).
        args = list(sympify(args))
        for i, a in enumerate(args):
            if not isinstance(a, Tuple):
                args[i] = (a, Integer(1))

        # Validate every variable and drop those with a zero count.
        variable_count = []
        all_zero = True
        for v, count in args:
            if not v._diff_wrt:
                from ..utilities.misc import filldedent
                # NOTE(review): only counts 1-3 get special suffixes, so e.g.
                # a count of 21 is reported as "21th".
                ordinal = 'st' if count == 1 else 'nd' if count == 2 else 'rd' if count == 3 else 'th'
                raise ValueError(filldedent("""
                Can\'t calculate %s%s derivative wrt %s.""" % (count, ordinal, v)))
            if count:
                if all_zero:
                    all_zero = False
                variable_count.append(Tuple(v, count))

        # We make a special case for 0th derivative, because there is no
        # good way to unambiguously print this.
        if all_zero:
            return expr

        # Pop evaluate because it is not really an assumption and we will need
        # to track it carefully below.
        evaluate = assumptions.pop('evaluate', False)

        # Look for a quick exit if there are symbols that don't appear in
        # expression at all. Note, this cannot check non-symbols like
        # functions and Derivatives as those can be created by intermediate
        # derivatives.
        if evaluate:
            symbol_set = {sc[0] for sc in variable_count if sc[0].is_Symbol}
            if symbol_set.difference(expr.free_symbols):
                return Integer(0)

        # We make a generator so as to only generate a variable when necessary.
        # If a high order of derivative is requested and the expr becomes 0
        # after a few differentiations, then we won't need the other variables.
        # Each variable is yielded ``count`` times.
        variablegen = (v for v, count in variable_count for i in range(count))

        # If we can't compute the derivative of expr (but we wanted to) and
        # expr is itself not a Derivative, finish building an unevaluated
        # derivative class by calling Expr.__new__.
        if (not (hasattr(expr, '_eval_derivative') and evaluate) and
                (not isinstance(expr, Derivative))):
            variables = list(variablegen)
            # If we wanted to evaluate, we sort the variables into standard
            # order for later comparisons. This is too aggressive if evaluate
            # is False, so we don't do it in that case.
            if evaluate:
                # TODO: check if assumption of discontinuous derivatives exist
                variables = cls._sort_variables(variables)
            # Here we *don't* need to reinject evaluate into assumptions
            # because we are done with it and it is not an assumption that
            # Expr knows about.
            obj = Expr.__new__(cls, expr, *variables, **assumptions)
            return obj

        # Compute the derivative now by repeatedly calling the
        # _eval_derivative method of expr for each variable. When this method
        # returns None, the derivative couldn't be computed wrt that variable
        # and we save the variable for later.
        unhandled_variables = []

        # Once we encounter a non-symbol that is unhandled, we stop taking
        # derivatives entirely. This is because derivatives wrt functions
        # don't commute with derivatives wrt symbols and we can't safely
        # continue.
        unhandled_non_symbol = False
        nderivs = 0  # number of _eval_derivative calls (even ones giving None)
        for v in variablegen:
            is_symbol = v.is_Symbol

            if unhandled_non_symbol:
                obj = None
            else:
                if not is_symbol:
                    # Differentiate wrt a non-symbol by temporarily replacing
                    # it with a Dummy symbol.
                    # NOTE(review): ``i`` is the leftover index from the
                    # argument-standardization loop above, not a per-variable
                    # counter; it only affects the Dummy's name.
                    new_v = Dummy(f'xi_{i:d}')
                    new_v.dummy_index = hash(v)
                    expr = expr.xreplace({v: new_v})
                    old_v = v
                    v = new_v
                obj = expr._eval_derivative(v)
                nderivs += 1
                if not is_symbol:
                    if obj is not None:
                        if obj.is_Derivative and not old_v.is_Symbol:
                            # Derivative evaluated at a generic point, i.e.
                            # that is not a symbol.
                            obj = Subs(obj, (v, old_v))
                        else:
                            obj = obj.xreplace({v: old_v})
                    v = old_v

            if obj is None:
                unhandled_variables.append(v)
                if not is_symbol:
                    unhandled_non_symbol = True
            elif obj == 0:
                return Integer(0)
            else:
                expr = obj

        if unhandled_variables:
            unhandled_variables = cls._sort_variables(unhandled_variables)
            expr = Expr.__new__(cls, expr, *unhandled_variables, **assumptions)
        else:
            # We got a Derivative at the end of it all, and we rebuild it by
            # sorting its variables.
            if isinstance(expr, Derivative):
                expr = cls(
                    expr.expr, *cls._sort_variables(expr.variables)
                )

        if nderivs > 1 and assumptions.get('simplify', True):
            from ..simplify.simplify import signsimp
            from .exprtools import factor_terms
            expr = factor_terms(signsimp(expr))
        return expr
1066
1067    @classmethod
1068    def _sort_variables(cls, vars):
1069        """Sort variables, but disallow sorting of non-symbols.
1070
1071        When taking derivatives, the following rules usually hold:
1072
1073        * Derivative wrt different symbols commute.
1074        * Derivative wrt different non-symbols commute.
1075        * Derivatives wrt symbols and non-symbols don't commute.
1076
1077        Examples
1078        ========
1079
1080        >>> vsort = Derivative._sort_variables
1081
1082        >>> vsort((x, y, z))
1083        [x, y, z]
1084
1085        >>> vsort((h(x), g(x), f(x)))
1086        [f(x), g(x), h(x)]
1087
1088        >>> vsort((z, y, x, h(x), g(x), f(x)))
1089        [x, y, z, f(x), g(x), h(x)]
1090
1091        >>> vsort((x, f(x), y, f(y)))
1092        [x, f(x), y, f(y)]
1093
1094        >>> vsort((y, x, g(x), f(x), z, h(x), y, x))
1095        [x, y, f(x), g(x), z, h(x), x, y]
1096
1097        >>> vsort((z, y, f(x), x, f(x), g(x)))
1098        [y, z, f(x), x, f(x), g(x)]
1099
1100        >>> vsort((z, y, f(x), x, f(x), g(x), z, z, y, x))
1101        [y, z, f(x), x, f(x), g(x), x, y, z, z]
1102
1103        """
1104        sorted_vars = []
1105        symbol_part = []
1106        non_symbol_part = []
1107        for v in vars:
1108            if not v.is_Symbol:
1109                if len(symbol_part) > 0:
1110                    sorted_vars.extend(sorted(symbol_part,
1111                                              key=default_sort_key))
1112                    symbol_part = []
1113                non_symbol_part.append(v)
1114            else:
1115                if len(non_symbol_part) > 0:
1116                    sorted_vars.extend(sorted(non_symbol_part,
1117                                              key=default_sort_key))
1118                    non_symbol_part = []
1119                symbol_part.append(v)
1120        if len(non_symbol_part) > 0:
1121            sorted_vars.extend(sorted(non_symbol_part,
1122                                      key=default_sort_key))
1123        if len(symbol_part) > 0:
1124            sorted_vars.extend(sorted(symbol_part,
1125                                      key=default_sort_key))
1126        return sorted_vars
1127
1128    def _eval_is_commutative(self):
1129        return self.expr.is_commutative
1130
1131    def _eval_derivative(self, v):
1132        # If the variable s we are diff wrt is not in self.variables, we
1133        # assume that we might be able to take the derivative.
1134        if v not in self.variables:
1135            obj = self.expr.diff(v)
1136            if obj == 0:
1137                return Integer(0)
1138            if isinstance(obj, Derivative):
1139                return obj.func(obj.expr, *(self.variables + obj.variables))
1140            # The derivative wrt s could have simplified things such that the
1141            # derivative wrt things in self.variables can now be done. Thus,
1142            # we set evaluate=True to see if there are any other derivatives
1143            # that can be done. The most common case is when obj is a simple
1144            # number so that the derivative wrt anything else will vanish.
1145            return self.func(obj, *self.variables, evaluate=True)
1146        # In this case s was in self.variables so the derivatve wrt s has
1147        # already been attempted and was not computed, either because it
1148        # couldn't be or evaluate=False originally.
1149        return self.func(self.expr, *(self.variables + (v, )), evaluate=False)
1150
1151    def doit(self, **hints):
1152        """Evaluate objects that are not evaluated by default.
1153
1154        See Also
1155        ========
1156
1157        diofant.core.basic.Basic.doit
1158
1159        """
1160        expr = self.expr
1161        if hints.get('deep', True):
1162            expr = expr.doit(**hints)
1163        hints['evaluate'] = True
1164        return self.func(expr, *self.variables, **hints)
1165
1166    @_sympifyit('z0', NotImplementedError)
1167    def doit_numerically(self, z0):
1168        """
1169        Evaluate the derivative at z numerically.
1170
1171        When we can represent derivatives at a point, this should be folded
1172        into the normal evalf. For now, we need a special method.
1173
1174        """
1175        from .expr import Expr
1176        if len(self.free_symbols) != 1 or len(self.variables) != 1:
1177            raise NotImplementedError('partials and higher order derivatives')
1178        z = list(self.free_symbols)[0]
1179
1180        def eval(x):
1181            f0 = self.expr.subs({z: Expr._from_mpmath(x, prec=mpmath.mp.prec)})
1182            f0 = f0.evalf(mlib.libmpf.prec_to_dps(mpmath.mp.prec), strict=False)
1183            return f0._to_mpmath(mpmath.mp.prec)
1184        return Expr._from_mpmath(mpmath.diff(eval,
1185                                             z0._to_mpmath(mpmath.mp.prec)),
1186                                 mpmath.mp.prec)
1187
1188    @property
1189    def expr(self):
1190        """Return expression."""
1191        return self.args[0]
1192
1193    @property
1194    def variables(self):
1195        """Return tuple of symbols, wrt derivative is taken."""
1196        return self.args[1:]
1197
1198    @property
1199    def free_symbols(self):
1200        """Return from the atoms of self those which are free symbols.
1201
1202        See Also
1203        ========
1204
1205        diofant.core.basic.Basic.free_symbols
1206
1207        """
1208        return self.expr.free_symbols
1209
1210    def _eval_subs(self, old, new):
1211        if old in self.variables and not new._diff_wrt:
1212            # issue sympy/sympy#4719
1213            return Subs(self, (old, new))
1214        # If both are Derivatives with the same expr, check if old is
1215        # equivalent to self or if old is a subderivative of self.
1216        if old.is_Derivative and old.expr == self.expr:
1217            # Check if canonnical order of variables is equal.
1218            old_vars = collections.Counter(old.variables)
1219            self_vars = collections.Counter(self.variables)
1220            if old_vars == self_vars:
1221                return new
1222
1223            # collections.Counter doesn't have __le__
1224            def _subset(a, b):
1225                return all(a[i] <= b[i] for i in a)
1226
1227            if _subset(old_vars, self_vars):
1228                return Derivative(new, *(self_vars - old_vars).elements())
1229
1230        return Derivative(*(x._subs(old, new) for x in self.args))
1231
1232    def _eval_lseries(self, x, logx):
1233        for term in self.expr.series(x, n=None, logx=logx):
1234            yield self.func(term, *self.variables)
1235
1236    def _eval_nseries(self, x, n, logx):
1237        arg = self.expr.nseries(x, n=n, logx=logx)
1238        o = arg.getO()
1239        rv = [self.func(a, *self.variables) for a in Add.make_args(arg.removeO())]
1240        if o:
1241            rv.append(o/x)
1242        return Add(*rv)
1243
1244    def _eval_as_leading_term(self, x):
1245        return self.func(self.expr.as_leading_term(x), *self.variables)
1246
1247
class Lambda(Expr):
    """
    Lambda(x, expr) represents an anonymous function, analogous to Python's
    'lambda x: expr'.  A function of several variables is written as
    Lambda((x, y, ...), expr).

    A simple example:

    >>> f = Lambda(x, x**2)
    >>> f(4)
    16

    For multivariate functions, use:

    >>> f2 = Lambda((x, y, z, t), x + y**z + t**z)
    >>> f2(1, 2, 3, 4)
    73

    A handy shortcut for lots of arguments:

    >>> p = x, y, z
    >>> f = Lambda(p, x + y*z)
    >>> f(*p)
    x + y*z

    """

    is_Function = True

    def __new__(cls, variables, expr):
        """Build ``Lambda(variables, expr)``.

        ``variables`` may be a single Symbol or an iterable of Symbols.
        Returns ``S.IdentityFunction`` when the only variable is the body
        itself.

        Raises
        ======

        TypeError
            If any variable is not a Symbol.

        """
        from ..sets.sets import FiniteSet
        if iterable(variables):
            params = list(variables)
        else:
            params = [variables]
        for param in params:
            if not getattr(param, 'is_Symbol', False):
                raise TypeError(f'variable is not a symbol: {param}')
        if len(params) == 1 and params[0] == expr:
            return S.IdentityFunction

        arity = len(params)
        result = Expr.__new__(cls, Tuple(*params), sympify(expr))
        result.nargs = FiniteSet(arity)
        return result

    @property
    def variables(self):
        """Tuple of the bound variables of this function."""
        head = self.args[0]
        return head

    @property
    def expr(self):
        """The body of the function, i.e. its return value."""
        body = self.args[1]
        return body

    @property
    def free_symbols(self):
        """Return the free symbols of the body minus the bound variables.

        See Also
        ========

        diofant.core.basic.Basic.free_symbols

        """
        body_symbols = self.expr.free_symbols
        bound = set(self.variables)
        return body_symbols - bound

    def __call__(self, *args):
        given = len(args)
        if given in self.nargs:
            mapping = dict(zip(self.variables, args))
            return self.expr.xreplace(mapping)

        # Lambda only ever has 1 value in nargs.
        # XXX: exception message must be in exactly this format to
        # make it work with NumPy's functions like vectorize(). See,
        # for example, https://github.com/numpy/numpy/issues/1697.
        # The ideal solution would be just to attach metadata to
        # the exception and change NumPy to take advantage of this.
        # XXX does this apply to Lambda? If not, remove this comment.
        expected = list(self.nargs)[0]
        template = ('%(name)s takes exactly %(args)s '
                    'argument%(plural)s (%(given)s given)')
        raise TypeError(template % {
            'name': self,
            'args': expected,
            'plural': 's'*(expected != 1),
            'given': given})

    def __eq__(self, other):
        # Two Lambdas are equal when they have the same arity and their
        # bodies agree after renaming other's variables to ours.
        if not isinstance(other, Lambda) or self.nargs != other.nargs:
            return False
        renaming = dict(zip(other.args[0], self.args[0]))
        return self.args[1] == other.args[1].xreplace(renaming)

    def __hash__(self):
        return super().__hash__()

    def _hashable_content(self):
        # Hash the body with canonically renamed variables, consistent
        # with __eq__ ignoring variable names.
        canonical_body = self.expr.xreplace(self.canonical_variables)
        return (canonical_body,)
1347
class Subs(Expr):
    """
    Represents unevaluated substitutions of an expression.

    ``Subs`` receives at least 2 arguments: an expression followed by one or
    more ``(old, new)`` pairs, each giving an expression to substitute and
    the value to put in its place.

    ``Subs`` objects are generally useful to represent unevaluated derivatives
    calculated at a point.

    The variables may be expressions, but they are subjected to the limitations
    of subs(), so it is usually a good practice to use only symbols for
    variables, since in that case there can be no ambiguity.

    There's no automatic expansion - use the method .doit() to effect all
    possible substitutions of the object and also of objects inside the
    expression.

    When evaluating derivatives at a point that is not a symbol, a Subs object
    is returned. One is also able to calculate derivatives of Subs objects - in
    this case the expression is always expanded (for the unevaluated form, use
    Derivative()).

    Examples
    ========

    >>> e = Subs(f(x).diff(x), (x, y))
    >>> e.subs({y: 0})
    Subs(Derivative(f(x), x), (x, 0))
    >>> e.subs({f: sin}).doit()
    cos(y)

    >>> Subs(f(x)*sin(y) + z, (x, 0), (y, 1))
    Subs(z + f(x)*sin(y), (x, 0), (y, 1))
    >>> _.doit()
    z + f(0)*sin(1)

    """

    def __new__(cls, expr, *args, **assumptions):
        """Create an unevaluated substitution.

        Parameters
        ==========

        expr : Expr
            The expression (sympified) on which the substitution operates.
        *args : pairs
            One or more ``(variable, point)`` pairs.
        **assumptions
            Accepted for interface compatibility; not used.

        Raises
        ======

        ValueError
            If no pair is given, an argument is not a 2-element sequence,
            or a variable is repeated.

        """
        from .symbol import Symbol
        args = sympify(args)
        # ``all`` is vacuously True for an empty iterable, so at least one
        # pair must be required explicitly; otherwise ``zip(*args)`` below
        # would fail with an unhelpful unpacking error.
        if args and all(is_sequence(pair) and len(pair) == 2 for pair in args):
            variables, point = zip(*args)
        else:
            raise ValueError('Subs support two or more arguments')

        if tuple(uniq(variables)) != variables:
            repeated = [v for v in set(variables) if variables.count(v) > 1]
            raise ValueError('cannot substitute expressions %s more than '
                             'once.' % repeated)

        expr = sympify(expr)

        # Use symbols named after the point values (with a prepended _) to
        # build a variable-independent expression ``_expr``; __eq__ and
        # _hashable_content compare this expression.
        pre = '_'
        pts = sorted(set(point), key=default_sort_key)
        from ..printing import StrPrinter

        class CustomStrPrinter(StrPrinter):
            # Append dummy_index so that Dummies with equal names but
            # different indices get distinct symbol names.
            def _print_Dummy(self, expr):
                return str(expr) + str(expr.dummy_index)

        def mystr(expr, **settings):
            p = CustomStrPrinter(settings)
            return p.doprint(expr)

        while True:
            s_pts = {p: Symbol(pre + mystr(p)) for p in pts}
            reps = [(v, s_pts[p])
                    for v, p in zip(variables, point)]
            # if any underscore-prepended symbol is already a free symbol
            # and is a variable with a different point value, then there
            # is a clash, e.g. _0 clashes in Subs(_0 + _1, (_0, 1), (_1, 0))
            # because the new symbol that would be created is _1 but _1
            # is already mapped to 0 so __0 and __1 are used for the new
            # symbols
            if any(r in expr.free_symbols and
                   r in variables and
                   Symbol(pre + mystr(point[variables.index(r)])) != r
                   for _, r in reps):
                pre += '_'
                continue
            break

        obj = Expr.__new__(cls, expr, *sympify(tuple(zip(variables, point))))
        obj._expr = expr.subs(reps)
        return obj

    def _eval_is_commutative(self):
        return (self.expr.is_commutative and
                all(p.is_commutative for p in self.point))

    def doit(self, **hints):
        """Evaluate objects that are not evaluated by default.

        The inner expression is evaluated with ``hints`` and then all the
        substitutions are performed.

        See Also
        ========

        diofant.core.basic.Basic.doit

        """
        return self.expr.doit(**hints).subs(list(zip(self.variables, self.point)))

    def evalf(self, dps=15, **options):
        """Evaluate the given formula to an accuracy of dps decimal digits.

        The substitution is carried out (via ``doit``) before evaluating.

        See Also
        ========

        diofant.core.evalf.EvalfMixin.evalf

        """
        return self.doit().evalf(dps, **options)

    #:
    n = evalf

    @property
    def variables(self):
        """The variables to be evaluated."""
        return Tuple(*tuple(zip(*self.args[1:])))[0]

    @property
    def expr(self):
        """The expression on which the substitution operates."""
        return self.args[0]

    @property
    def point(self):
        """The values for which the variables are to be substituted."""
        return Tuple(*tuple(zip(*self.args[1:])))[1]

    @property
    def free_symbols(self):
        """Return from the atoms of self those which are free symbols.

        These are the free symbols of the expression other than the
        substituted variables, together with the free symbols of the points.

        See Also
        ========

        diofant.core.basic.Basic.free_symbols

        """
        return ((self.expr.free_symbols - set(self.variables)) |
                set(self.point.free_symbols))

    def __eq__(self, other):
        # Compare the variable-independent form built in __new__.
        if not isinstance(other, Subs):
            return False
        return self._expr == other._expr

    def __hash__(self):
        return super().__hash__()

    def _hashable_content(self):
        return self._expr.xreplace(self.canonical_variables),

    def _eval_subs(self, old, new):
        # The substituted variables are bound, so replacing one of them
        # leaves the object unchanged.
        if old in self.variables:
            return self

        # ``old`` is the same substitution up to renaming of the variables.
        if isinstance(old, Subs) and self.point == old.point:
            if self.expr.subs(zip(self.variables, old.variables)) == old.expr:
                return new
        # Falling through returns None: no special handling here.

    def _eval_derivative(self, s):
        # Chain rule:
        #   d/ds Subs(e, (v, p)) = Subs(de/ds)       (only if s is not a v)
        #                        + sum_v dp/ds * Subs(de/dv)
        return Add((self.func(self.expr.diff(s), *self.args[1:]).doit()
                    if s not in self.variables else Integer(0)),
                   *[p.diff(s)*self.func(self.expr.diff(v), *self.args[1:]).doit()
                     for v, p in zip(self.variables, self.point)])
1519
1520
def diff(f, *args, **kwargs):
    """
    Differentiate f with respect to symbols.

    This is a thin wrapper that unifies .diff() and the Derivative class;
    its interface resembles that of integrate().  The same shortcuts for
    multiple variables as with Derivative are available: for example,
    diff(f(x), x, x, x) and diff(f(x), (x, 3)) both return the third
    derivative of f(x).

    Unlike ``Derivative``, evaluation is on by default.  Pass
    evaluate=False to get an unevaluated Derivative object.  Note that if
    there are 0 symbols (such as diff(f(x), (x, 0))), the result is the
    function itself (the zeroth derivative), even if evaluate=False.

    Examples
    ========

    >>> diff(sin(x), x)
    cos(x)
    >>> diff(f(x), x, x, x)
    Derivative(f(x), x, x, x)
    >>> diff(f(x), (x, 3))
    Derivative(f(x), x, x, x)
    >>> diff(sin(x)*cos(y), (x, 2), (y, 2))
    sin(x)*cos(y)

    >>> type(diff(sin(x), x))
    cos
    >>> type(diff(sin(x), x, evaluate=False))
    <class 'diofant.core.function.Derivative'>
    >>> type(diff(sin(x), (x, 0)))
    sin
    >>> type(diff(sin(x), (x, 0), evaluate=False))
    sin

    >>> diff(sin(x))
    cos(x)
    >>> diff(sin(x*y))
    Traceback (most recent call last):
    ...
    ValueError: The variable(s) of differentiation must be supplied to differentiate sin(x*y)

    Note that ``diff(sin(x))`` syntax is meant only for convenience
    in interactive sessions and should be avoided in library code.

    References
    ==========

    * https://reference.wolfram.com/legacy/v5_2/Built-inFunctions/AlgebraicComputation/Calculus/D.html

    See Also
    ========

    Derivative
    diofant.geometry.util.idiff: computes the derivative implicitly

    """
    # Explicit keyword arguments (including evaluate=False) take precedence.
    options = {'evaluate': True}
    options.update(kwargs)
    return Derivative(f, *args, **options)
1580
1581
1582def expand(e, deep=True, modulus=None, power_base=True, power_exp=True,
1583           mul=True, log=True, multinomial=True, basic=True, **hints):
1584    r"""Expand an expression using methods given as hints.
1585
1586    Hints evaluated unless explicitly set to False are:  ``basic``, ``log``,
1587    ``multinomial``, ``mul``, ``power_base``, and ``power_exp``.  The following
1588    hints are supported but not applied unless set to True:  ``complex``,
1589    ``func``, and ``trig``.  In addition, the following meta-hints are
1590    supported by some or all of the other hints:  ``frac``, ``numer``,
1591    ``denom``, ``modulus``, and ``force``.  ``deep`` is supported by all
1592    hints.  Additionally, subclasses of Expr may define their own hints or
1593    meta-hints.
1594
1595    Parameters
1596    ==========
1597
1598    basic : boolean, optional
1599        This hint is used for any special
1600        rewriting of an object that should be done automatically (along with
1601        the other hints like ``mul``) when expand is called. This is a catch-all
1602        hint to handle any sort of expansion that may not be described by
1603        the existing hint names.
1604
1605    deep : boolean, optional
1606        If ``deep`` is set to ``True`` (the default), things like arguments of
1607        functions are recursively expanded.  Use ``deep=False`` to only expand on
1608        the top level.
1609
1610    mul : boolean, optional
        Distributes multiplication over addition:
1612
1613        >>> (y*(x + z)).expand(mul=True)
1614        x*y + y*z
1615
1616    multinomial : boolean, optional
1617        Expand (x + y + ...)**n where n is a positive integer.
1618
1619        >>> ((x + y + z)**2).expand(multinomial=True)
1620        x**2 + 2*x*y + 2*x*z + y**2 + 2*y*z + z**2
1621
1622    power_exp : boolean, optional
1623        Expand addition in exponents into multiplied bases.
1624
1625        >>> exp(x + y).expand(power_exp=True)
1626        E**x*E**y
1627        >>> (2**(x + y)).expand(power_exp=True)
1628        2**x*2**y
1629
1630    power_base : boolean, optional
1631        Split powers of multiplied bases.
1632
1633        This only happens by default if assumptions allow, or if the
1634        ``force`` meta-hint is used:
1635
1636        >>> ((x*y)**z).expand(power_base=True)
1637        (x*y)**z
1638        >>> ((x*y)**z).expand(power_base=True, force=True)
1639        x**z*y**z
1640        >>> ((2*y)**z).expand(power_base=True)
1641        2**z*y**z
1642
1643        Note that in some cases where this expansion always holds, Diofant performs
1644        it automatically:
1645
1646        >>> (x*y)**2
1647        x**2*y**2
1648
1649    log : boolean, optional
        Pull out powers of an argument as a coefficient and split logs of
        products into sums of logs.
1652
1653        Note that these only work if the arguments of the log function have the
1654        proper assumptions--the arguments must be positive and the exponents must
1655        be real--or else the ``force`` hint must be True:
1656
1657        >>> log(x**2*y).expand(log=True)
1658        log(x**2*y)
1659        >>> log(x**2*y).expand(log=True, force=True)
1660        2*log(x) + log(y)
1661        >>> x, y = symbols('x y', positive=True)
1662        >>> log(x**2*y).expand(log=True)
1663        2*log(x) + log(y)
1664
1665    complex : boolean, optional
1666        Split an expression into real and imaginary parts.
1667
1668        >>> x, y = symbols('x y')
1669        >>> (x + y).expand(complex=True)
1670        re(x) + re(y) + I*im(x) + I*im(y)
1671        >>> cos(x).expand(complex=True)
1672        -I*sin(re(x))*sinh(im(x)) + cos(re(x))*cosh(im(x))
1673
1674        Note that this is just a wrapper around ``as_real_imag()``.  Most objects
1675        that wish to redefine ``_eval_expand_complex()`` should consider
1676        redefining ``as_real_imag()`` instead.
1677
    func : boolean, optional
1679        Expand other functions.
1680
1681        >>> gamma(x + 1).expand(func=True)
1682        x*gamma(x)
1683
1684    trig : boolean, optional
1685        Do trigonometric expansions.
1686
1687        >>> cos(x + y).expand(trig=True)
1688        -sin(x)*sin(y) + cos(x)*cos(y)
1689        >>> sin(2*x).expand(trig=True)
1690        2*sin(x)*cos(x)
1691
1692        Note that the forms of ``sin(n*x)`` and ``cos(n*x)`` in terms of ``sin(x)``
1693        and ``cos(x)`` are not unique, due to the identity `\sin^2(x) + \cos^2(x)
1694        = 1`.  The current implementation uses the form obtained from Chebyshev
1695        polynomials, but this may change.
1696
1697    force : boolean, optional
1698        If the ``force`` hint is used, assumptions about variables will be ignored
1699        in making the expansion.
1700
1701
1702    Notes
1703    =====
1704
1705    - You can shut off unwanted methods::
1706
1707        >>> (exp(x + y)*(x + y)).expand()
1708        E**x*E**y*x + E**x*E**y*y
1709        >>> (exp(x + y)*(x + y)).expand(power_exp=False)
1710        E**(x + y)*x + E**(x + y)*y
1711        >>> (exp(x + y)*(x + y)).expand(mul=False)
1712        E**x*E**y*(x + y)
1713
1714    - Use deep=False to only expand on the top level::
1715
1716        >>> exp(x + exp(x + y)).expand()
1717        E**x*E**(E**x*E**y)
1718        >>> exp(x + exp(x + y)).expand(deep=False)
1719        E**(E**(x + y))*E**x
1720
1721    - Hints are applied in an arbitrary, but consistent order (in the current
1722      implementation, they are applied in alphabetical order, except
1723      multinomial comes before mul, but this may change).  Because of this,
1724      some hints may prevent expansion by other hints if they are applied
1725      first. For example, ``mul`` may distribute multiplications and prevent
1726      ``log`` and ``power_base`` from expanding them. Also, if ``mul`` is
1727      applied before ``multinomial``, the expression might not be fully
1728      distributed. The solution is to use the various ``expand_hint`` helper
1729      functions or to use ``hint=False`` to this function to finely control
1730      which hints are applied. Here are some examples::
1731
1732        >>> x, y, z = symbols('x y z', positive=True)
1733
1734        >>> expand(log(x*(y + z)))
1735        log(x) + log(y + z)
1736
1737      Here, we see that ``log`` was applied before ``mul``.  To get the mul
1738      expanded form, either of the following will work::
1739
1740        >>> expand_mul(log(x*(y + z)))
1741        log(x*y + x*z)
1742        >>> expand(log(x*(y + z)), log=False)
1743        log(x*y + x*z)
1744
1745      A similar thing can happen with the ``power_base`` hint::
1746
1747        >>> expand((x*(y + z))**x)
1748        (x*y + x*z)**x
1749
1750      To get the ``power_base`` expanded form, either of the following will
1751      work::
1752
1753        >>> expand((x*(y + z))**x, mul=False)
1754        x**x*(y + z)**x
1755        >>> expand_power_base((x*(y + z))**x)
1756        x**x*(y + z)**x
1757
1758        >>> expand((x + y)*y/x)
1759        y + y**2/x
1760
1761      The parts of a rational expression can be targeted::
1762
1763        >>> expand((x + y)*y/x/(x + 1), frac=True)
1764        (x*y + y**2)/(x**2 + x)
1765        >>> expand((x + y)*y/x/(x + 1), numer=True)
1766        (x*y + y**2)/(x*(x + 1))
1767        >>> expand((x + y)*y/x/(x + 1), denom=True)
1768        y*(x + y)/(x**2 + x)
1769
1770    - The ``modulus`` meta-hint can be used to reduce the coefficients of an
1771      expression post-expansion::
1772
1773        >>> expand((3*x + 1)**2)
1774        9*x**2 + 6*x + 1
1775        >>> expand((3*x + 1)**2, modulus=5)
1776        4*x**2 + x + 1
1777
1778    - Either ``expand()`` the function or ``.expand()`` the method can be
1779      used.  Both are equivalent::
1780
1781        >>> expand((x + 1)**2)
1782        x**2 + 2*x + 1
1783        >>> ((x + 1)**2).expand()
1784        x**2 + 2*x + 1
1785
1786
1787    - Objects can define their own expand hints by defining
1788      ``_eval_expand_hint()``.  The function should take the form::
1789
1790        def _eval_expand_hint(self, **hints):
1791            # Only apply the method to the top-level expression
1792            ...
1793
1794      See also the example below.  Objects should define ``_eval_expand_hint()``
1795      methods only if ``hint`` applies to that specific object.  The generic
1796      ``_eval_expand_hint()`` method defined in Expr will handle the no-op case.
1797
1798      Each hint should be responsible for expanding that hint only.
1799      Furthermore, the expansion should be applied to the top-level expression
1800      only.  ``expand()`` takes care of the recursion that happens when
1801      ``deep=True``.
1802
1803      You should only call ``_eval_expand_hint()`` methods directly if you are
1804      100% sure that the object has the method, as otherwise you are liable to
1805      get unexpected ``AttributeError``'s.  Note, again, that you do not need to
1806      recursively apply the hint to args of your object: this is handled
1807      automatically by ``expand()``.  ``_eval_expand_hint()`` should
1808      generally not be used at all outside of an ``_eval_expand_hint()`` method.
1809      If you want to apply a specific expansion from within another method, use
1810      the public ``expand()`` function, method, or ``expand_hint()`` functions.
1811
1812      In order for expand to work, objects must be rebuildable by their args,
1813      i.e., ``obj.func(*obj.args) == obj`` must hold.
1814
1815      Expand methods are passed ``**hints`` so that expand hints may use
1816      'metahints'--hints that control how different expand methods are applied.
1817      For example, the ``force=True`` hint described above that causes
1818      ``expand(log=True)`` to ignore assumptions is such a metahint.  The
1819      ``deep`` meta-hint is handled exclusively by ``expand()`` and is not
1820      passed to ``_eval_expand_hint()`` methods.
1821
1822      Note that expansion hints should generally be methods that perform some
1823      kind of 'expansion'.  For hints that simply rewrite an expression, use the
1824      .rewrite() API.
1825
1826    Examples
1827    ========
1828
1829    >>> class MyClass(Expr):
1830    ...     def __new__(cls, *args):
1831    ...         args = sympify(args)
1832    ...         return Expr.__new__(cls, *args)
1833    ...
1834    ...     def _eval_expand_double(self, **hints):
1835    ...         # Doubles the args of MyClass.
1836    ...         # If there more than four args, doubling is not performed,
1837    ...         # unless force=True is also used (False by default).
1838    ...         force = hints.pop('force', False)
1839    ...         if not force and len(self.args) > 4:
1840    ...             return self
1841    ...         return self.func(*(self.args + self.args))
1842    ...
1843    >>> a = MyClass(1, 2, MyClass(3, 4))
1844    >>> a
1845    MyClass(1, 2, MyClass(3, 4))
1846    >>> a.expand(double=True)
1847    MyClass(1, 2, MyClass(3, 4, 3, 4), 1, 2, MyClass(3, 4, 3, 4))
1848    >>> a.expand(double=True, deep=False)
1849    MyClass(1, 2, MyClass(3, 4), 1, 2, MyClass(3, 4))
1850
1851    >>> b = MyClass(1, 2, 3, 4, 5)
1852    >>> b.expand(double=True)
1853    MyClass(1, 2, 3, 4, 5)
1854    >>> b.expand(double=True, force=True)
1855    MyClass(1, 2, 3, 4, 5, 1, 2, 3, 4, 5)
1856
1857    See Also
1858    ========
1859
1860    expand_log, expand_mul, expand_multinomial, expand_complex, expand_trig,
1861    expand_power_base, expand_power_exp, expand_func,
1862    diofant.simplify.hyperexpand.hyperexpand
1863
1864    References
1865    ==========
1866
1867    * https://mathworld.wolfram.com/Multiple-AngleFormulas.html
1868
1869    """
1870    # don't modify this; modify the Expr.expand method
1871    hints['power_base'] = power_base
1872    hints['power_exp'] = power_exp
1873    hints['mul'] = mul
1874    hints['log'] = log
1875    hints['multinomial'] = multinomial
1876    hints['basic'] = basic
1877    return sympify(e).expand(deep=deep, modulus=modulus, **hints)
1878
# A special combination of two hints: multinomial expansion followed by mul expansion.
1880
1881
def _mexpand(expr, recursive=False):
    """Apply ``expand_multinomial`` and then ``expand_mul`` to ``expr``.

    A single pass may not always be sufficient to give a fully expanded
    expression (see test_sympyissue_8247_8354 in test_arit).  With
    ``recursive=True`` the pair of expansions is repeated until the result
    stops changing; otherwise exactly one pass is made.
    """
    prior, current = None, expr
    while prior != current:
        prior = current
        current = expand_mul(expand_multinomial(prior))
        if not recursive:
            break
    return current
1892
1893
1894# These are simple wrappers around single hints.
1895
1896
def expand_mul(expr, deep=True):
    """
    Expand ``expr`` using the ``mul`` hint alone.

    All other standard hints (``power_exp``, ``power_base``, ``basic``,
    ``multinomial`` and ``log``) are switched off.  See the expand docstring
    for more information about the individual hints.

    Examples
    ========

    >>> x, y = symbols('x y', positive=True)
    >>> expand_mul(exp(x+y)*(x+y)*log(x*y**2))
    E**(x + y)*x*log(x*y**2) + E**(x + y)*y*log(x*y**2)

    """
    only_mul = {'mul': True, 'power_exp': False, 'power_base': False,
                'basic': False, 'multinomial': False, 'log': False}
    target = sympify(expr)
    return target.expand(deep=deep, **only_mul)
1912
1913
def expand_multinomial(expr, deep=True):
    """
    Expand ``expr`` using the ``multinomial`` hint alone.

    Powers of sums are expanded, while the ``mul``, ``power_exp``,
    ``power_base``, ``basic`` and ``log`` hints are disabled.  See the expand
    docstring for more information.

    Examples
    ========

    >>> x, y = symbols('x y', positive=True)
    >>> expand_multinomial((x + exp(x + 1))**2)
    2*E**(x + 1)*x + E**(2*x + 2) + x**2

    """
    flags = dict(mul=False, power_exp=False, power_base=False,
                 basic=False, multinomial=True, log=False)
    return sympify(expr).expand(deep=deep, **flags)
1929
1930
def expand_log(expr, deep=True, force=False):
    """
    Expand ``expr`` using the ``log`` hint alone.

    The remaining standard hints are disabled.  ``force`` is forwarded to
    ``expand``; when it is true, assumptions about variables are ignored
    while expanding.  See the expand docstring for more information.

    Examples
    ========

    >>> x, y = symbols('x y', positive=True)
    >>> expand_log(exp(x+y)*(x+y)*log(x*y**2))
    E**(x + y)*(x + y)*(log(x) + 2*log(y))

    """
    target = sympify(expr)
    flags = {'log': True, 'mul': False, 'power_exp': False,
             'power_base': False, 'multinomial': False, 'basic': False,
             'force': force}
    return target.expand(deep=deep, **flags)
1947
1948
def expand_func(expr, deep=True):
    """
    Expand ``expr`` using the ``func`` hint alone.

    Special functions are rewritten through their own expansion rules.  The
    ``basic``, ``log``, ``mul``, ``power_exp``, ``power_base`` and
    ``multinomial`` hints are disabled.  See the expand docstring for more
    information.

    Examples
    ========

    >>> expand_func(gamma(x + 2))
    x*(x + 1)*gamma(x)

    """
    flags = dict(func=True, basic=False, log=False, mul=False,
                 power_exp=False, power_base=False, multinomial=False)
    target = sympify(expr)
    return target.expand(deep=deep, **flags)
1963
1964
def expand_trig(expr, deep=True):
    """
    Expand ``expr`` using the ``trig`` hint alone.

    Trigonometric functions of sums and multiples are expanded.  The
    ``basic``, ``log``, ``mul``, ``power_exp``, ``power_base`` and
    ``multinomial`` hints are disabled.  See the expand docstring for more
    information.

    Examples
    ========

    >>> expand_trig(sin(x+y)*(x+y))
    (x + y)*(sin(x)*cos(y) + sin(y)*cos(x))

    """
    trig_only = {'trig': True, 'basic': False, 'log': False, 'mul': False,
                 'power_exp': False, 'power_base': False,
                 'multinomial': False}
    return sympify(expr).expand(deep=deep, **trig_only)
1979
1980
def expand_complex(expr, deep=True):
    """
    Expand ``expr`` using the ``complex`` hint alone.

    The expression is split into explicit real and imaginary parts.  The
    ``basic``, ``log``, ``mul``, ``power_exp``, ``power_base`` and
    ``multinomial`` hints are disabled.  See the expand docstring for more
    information.

    Examples
    ========

    >>> expand_complex(exp(z))
    E**re(z)*I*sin(im(z)) + E**re(z)*cos(im(z))
    >>> expand_complex(sqrt(I))
    sqrt(2)/2 + sqrt(2)*I/2

    See Also
    ========

    diofant.core.expr.Expr.as_real_imag

    """
    target = sympify(expr)
    flags = dict(complex=True, basic=False, log=False, mul=False,
                 power_exp=False, power_base=False, multinomial=False)
    return target.expand(deep=deep, **flags)
2002
2003
def expand_power_base(expr, deep=True, force=False):
    """
    Expand ``expr`` using the ``power_base`` hint alone.

    This is shorthand for ``expand(power_base=True)`` with every other
    standard hint turned off.  A power whose base is a Mul is split into a
    product of powers, provided the assumptions on the base and exponent
    allow it.  No other expansion is performed.

    With ``deep=False`` (default is True), only the top-level expression is
    processed.

    With ``force=True`` (default is False), assumptions about the base and
    exponent are ignored.  When False, the split happens only if the base is
    non-negative or the exponent is an integer.

    >>> (x*y)**2
    x**2*y**2

    >>> (2*x)**y
    (2*x)**y
    >>> expand_power_base(_)
    2**y*x**y

    >>> expand_power_base((x*y)**z)
    (x*y)**z
    >>> expand_power_base((x*y)**z, force=True)
    x**z*y**z
    >>> expand_power_base(sin((x*y)**z), deep=False)
    sin((x*y)**z)
    >>> expand_power_base(sin((x*y)**z), force=True)
    sin(x**z*y**z)

    >>> expand_power_base((2*sin(x))**y + (2*cos(x))**y)
    2**y*sin(x)**y + 2**y*cos(x)**y

    >>> expand_power_base((2*exp(y))**x)
    2**x*(E**y)**x

    >>> expand_power_base((2*cos(x))**y)
    2**y*cos(x)**y

    Sums are left untouched.  If that is not what you want, apply the full
    ``expand()`` to the expression:

    >>> expand_power_base(((x+y)*z)**2)
    z**2*(x + y)**2
    >>> (((x+y)*z)**2).expand()
    x**2*z**2 + 2*x*y*z**2 + y**2*z**2

    >>> expand_power_base((2*y)**(1+z))
    2**(z + 1)*y**(z + 1)
    >>> ((2*y)**(1+z)).expand()
    2*2**z*y*y**z

    See Also
    ========

    expand

    """
    base_only = {'log': False, 'mul': False, 'power_exp': False,
                 'power_base': True, 'multinomial': False, 'basic': False,
                 'force': force}
    target = sympify(expr)
    return target.expand(deep=deep, **base_only)
2067
2068
def expand_power_exp(expr, deep=True):
    """
    Expand ``expr`` using the ``power_exp`` hint alone.

    A power whose exponent is a sum is split into a product of powers.  The
    ``complex``, ``basic``, ``log``, ``mul``, ``power_base`` and
    ``multinomial`` hints are disabled.

    Examples
    ========

    >>> expand_power_exp(x**(y + 2))
    x**2*x**y

    See Also
    ========

    expand

    """
    flags = dict(complex=False, basic=False, log=False, mul=False,
                 power_exp=True, power_base=False, multinomial=False)
    return sympify(expr).expand(deep=deep, **flags)
2087
2088
def count_ops(expr, visual=False):
    """
    Return a representation (integer or expression) of the operations in expr.

    If ``visual`` is ``False`` (default) then the sum of the coefficients of the
    visual expression will be returned.

    If ``visual`` is ``True`` then the number of each type of operation is shown
    with the core class types (or their virtual equivalent) multiplied by the
    number of times they occur.

    If expr is an iterable, the sum of the op counts of the
    items will be returned.  For a plain ``dict``, both keys and values are
    counted.

    ``Basic`` objects that are not ``Expr`` instances are counted by visiting
    every node that has args and recording its class name in upper case.
    Non-iterable objects that are not ``Basic`` contribute no operations.

    Examples
    ========

    Although there isn't a SUB object, minus signs are interpreted as
    either negations or subtractions:

    >>> (x - y).count_ops(visual=True)
    SUB
    >>> (-x).count_ops(visual=True)
    NEG

    Here, there are two Adds and a Pow:

    >>> (1 + a + b**2).count_ops(visual=True)
    2*ADD + POW

    In the following, an Add, Mul, Pow and two functions:

    >>> (sin(x)*x + sin(x)**2).count_ops(visual=True)
    ADD + MUL + POW + 2*SIN

    for a total of 5:

    >>> (sin(x)*x + sin(x)**2).count_ops(visual=False)
    5

    Note that "what you type" is not always what you get. The expression
    1/x/y is translated by diofant into 1/(x*y) so it gives a DIV and MUL rather
    than two DIVs:

    >>> (1/x/y).count_ops(visual=True)
    DIV + MUL

    The visual option can be used to demonstrate the difference in
    operations for expressions in different forms. Here, the Horner
    representation is compared with the expanded form of a polynomial:

    >>> eq = x*(1 + x*(2 + x*(3 + x)))
    >>> count_ops(eq.expand(), visual=True) - count_ops(eq, visual=True)
    -MUL + 3*POW

    The count_ops function also handles iterables:

    >>> count_ops([x, sin(x), None, True, x + 2], visual=False)
    2
    >>> count_ops([x, sin(x), None, True, x + 2], visual=True)
    ADD + SIN
    >>> count_ops({x: sin(x), x + 2: y + 1}, visual=True)
    2*ADD + SIN

    """
    from ..integrals import Integral
    from ..logic.boolalg import BooleanFunction
    from ..simplify.radsimp import fraction
    from .symbol import Symbol

    expr = sympify(expr)

    # Containers: recurse into the items (keys and values for a plain dict);
    # each recursive result is combined by the Add(*ops) below.
    if type(expr) is dict:
        ops = [count_ops(k, visual=visual) +
               count_ops(v, visual=visual) for k, v in expr.items()]
    elif iterable(expr):
        ops = [count_ops(i, visual=visual) for i in expr]
    elif isinstance(expr, Expr):

        # Walk the expression tree with an explicit stack (``args``),
        # collecting one op symbol (optionally times a multiplicity) per
        # operation found.
        ops = []
        args = [expr]
        # Pseudo-operations: there are no NEG/DIV/SUB classes, and ADD is
        # tracked separately from SUB, so these are plain marker Symbols.
        NEG = Symbol('NEG')
        DIV = Symbol('DIV')
        SUB = Symbol('SUB')
        ADD = Symbol('ADD')
        while args:
            a = args.pop()

            if a.is_Rational:
                # -1/3 = NEG + DIV
                if a != 1:
                    if a.numerator < 0:
                        ops.append(NEG)
                    if a.denominator != 1:
                        ops.append(DIV)
                    continue
            elif a.is_Mul:
                # Record a leading negative coefficient as NEG and strip it
                # before looking for a division.
                if _coeff_isneg(a):
                    ops.append(NEG)
                    if a.args[0] == -1:
                        a = a.as_two_terms()[1]
                    else:
                        a = -a
                n, d = fraction(a)
                if n.is_Integer:
                    ops.append(DIV)
                    args.append(d)
                    continue  # won't be -Mul but could be Add
                elif d != 1:
                    if not d.is_Integer:
                        args.append(d)
                    ops.append(DIV)
                    args.append(n)
                    continue  # could be -Mul
            elif a.is_Add:
                # Terms after the first count as ADD or SUB depending on
                # their sign; negative terms are pushed negated so their
                # own ops are counted without an extra NEG.
                aargs = list(a.args)
                negs = 0
                for i, ai in enumerate(aargs):
                    if _coeff_isneg(ai):
                        negs += 1
                        args.append(-ai)
                        if i > 0:
                            ops.append(SUB)
                    else:
                        args.append(ai)
                        if i > 0:
                            ops.append(ADD)
                if negs == len(aargs):  # -x - y = NEG + SUB
                    ops.append(NEG)
                elif _coeff_isneg(aargs[0]):  # -x + y = SUB, but already recorded ADD
                    ops.append(SUB - ADD)
                continue
            # NOTE(review): this tests ``expr`` (known to be an Expr in this
            # branch), not the current node ``a``, so it is presumably never
            # true unless some class derives from both Expr and
            # BooleanFunction -- confirm.  If it ever did fire, ``ops = []``
            # would discard every operation collected so far.
            elif isinstance(expr, BooleanFunction):
                ops = []
                for arg in expr.args:
                    ops.append(count_ops(arg, visual=True))
                o = Symbol(expr.func.__name__.upper())
                ops.append(o)
                continue
            # x**-1 is reported as a division of the base.
            if a.is_Pow and a.exp == -1:
                ops.append(DIV)
                args.append(a.base)  # won't be -Mul but could be Add
                continue
            # Generic operations are named after their class, upper-cased
            # (e.g. MUL, POW, SIN).
            if (a.is_Mul or
                a.is_Pow or
                a.is_Function or
                isinstance(a, Derivative) or
                    isinstance(a, Integral)):

                o = Symbol(a.func.__name__.upper())
                # count the args
                # n-ary Mul/LatticeOp with k args performs k - 1 operations.
                if (a.is_Mul or isinstance(a, LatticeOp)):
                    ops.append(o*(len(a.args) - 1))
                else:
                    ops.append(o)
            # Descend into the children of every non-Symbol node.
            if not a.is_Symbol:
                args.extend(a.args)

    elif not isinstance(expr, Basic):
        ops = []
    else:
        # Non-Expr Basic: one op per node that has args, named by its class.
        ops = []
        args = [expr]
        while args:
            a = args.pop()
            if a.args:
                o = Symbol(a.func.__name__.upper())
                ops.append(o)
                args.extend(a.args)

    if not ops:
        if visual:
            return Integer(0)
        return 0

    ops = Add(*ops)

    if visual:
        return ops

    if ops.is_Number:
        return int(ops)

    # Each term is either a bare op Symbol (no args -> counts 1) or a
    # coefficient times a Symbol, whose first arg is taken as the count.
    return sum(int((a.args or [1])[0]) for a in Add.make_args(ops))
2273
2274
def nfloat(expr, n=15, exponent=False):
    """Make all Rationals in expr Floats except those in exponents
    (unless the ``exponent`` flag is set to True).

    Parameters
    ==========

    expr : expression or iterable
        The object to convert.  For a ``dict`` or ``Dict``, only the values
        are converted and the keys are kept as they are.  Any other iterable
        (except ``str``) is rebuilt with its own type from the converted
        items.
    n : int, optional
        Precision passed to ``Float`` and ``evalf`` (default 15).
    exponent : bool, optional
        If True, exponents are converted as well; integer exponents become
        Floats too.

    Examples
    ========

    >>> nfloat(x**4 + x/2 + cos(pi/3) + 1 + sqrt(y))
    x**4 + 0.5*x + sqrt(y) + 1.5
    >>> nfloat(x**4 + sqrt(y), exponent=True)
    x**4.0 + y**0.5

    """
    from ..polys.rootoftools import RootOf
    from .power import Pow
    from .symbol import Dummy

    # Containers: convert element-wise (dict-like: values only).
    if iterable(expr, exclude=(str,)):
        if isinstance(expr, (dict, Dict)):
            return type(expr)([(k, nfloat(v, n, exponent)) for k, v in
                               list(expr.items())])
        return type(expr)([nfloat(a, n, exponent) for a in expr])
    rv = sympify(expr)

    if rv.is_Number:
        return Float(rv, n)
    elif rv.is_number:
        # evalf doesn't always set the precision
        rv = rv.evalf(n)
        if rv.is_Number:
            rv = Float(rv, n)
        else:
            pass  # pure_complex(rv) is likely True
        return rv

    # watch out for RootOf instances that don't like to have
    # their exponents replaced with Dummies and also sometimes have
    # problems with evaluating at low precision (issue sympy/sympy#6393)
    rv = rv.xreplace({ro: ro.evalf(n) for ro in rv.atoms(RootOf)})

    # Shield exponents from evalf by temporarily swapping each Pow's exponent
    # for a fresh Dummy; the originals are put back right after evalf.
    if not exponent:
        reps = [(p, Pow(p.base, Dummy())) for p in rv.atoms(Pow)]
        rv = rv.xreplace(dict(reps))
    rv = rv.evalf(n, strict=False)
    if not exponent:
        rv = rv.xreplace({d.exp: p.exp for p, d in reps})
    else:
        # Pow._eval_evalf special cases Integer exponents so if
        # exponent is suppose to be handled we have to do so here
        rv = rv.xreplace(Transform(
            lambda x: Pow(x.base, Float(x.exp, n)),
            lambda x: x.is_Pow and x.exp.is_Integer))

    # Finally convert the arguments of every Function node as well;
    # presumably evalf leaves Rationals inside unevaluated function calls
    # untouched -- TODO confirm.
    return rv.xreplace(Transform(
        lambda x: x.func(*nfloat(x.args, n, exponent)),
        lambda x: isinstance(x, Function)))
2331