1import math
2from sympy import Interval
3from sympy.calculus.singularities import is_increasing, is_decreasing
4from sympy.codegen.rewriting import Optimization
5from sympy.core.function import UndefinedFunction
6
7"""
This module collects classes useful for approximate rewriting of expressions.
9This can be beneficial when generating numeric code for which performance is
10of greater importance than precision (e.g. for preconditioners used in iterative
11methods).
12"""
13
class SumApprox(Optimization):
    """
    Approximates sum by neglecting small terms.

    Explanation
    ===========

    If terms are expressions which can be determined to be monotonic, then
    bounds for those expressions are added. Such derived bounds are cached on
    the instance (in ``self.bounds``, a copy of the mapping passed in), so the
    caller's ``bounds`` dict is never modified.

    Parameters
    ==========

    bounds : dict
        Mapping expressions to length 2 tuple of bounds (low, high).
    reltol : number
        Threshold for when to ignore a term. A term is dropped when the largest
        absolute value it can take is smaller than ``reltol`` times the largest
        magnitude that some term is guaranteed to reach (only terms whose
        bounds do not contain zero provide such a guarantee).

    Examples
    ========

    >>> from sympy import exp
    >>> from sympy.abc import x, y, z
    >>> from sympy.codegen.rewriting import optimize
    >>> from sympy.codegen.approximations import SumApprox
    >>> bounds = {x: (-1, 1), y: (1000, 2000), z: (-10, 3)}
    >>> sum_approx3 = SumApprox(bounds, reltol=1e-3)
    >>> sum_approx2 = SumApprox(bounds, reltol=1e-2)
    >>> sum_approx1 = SumApprox(bounds, reltol=1e-1)
    >>> expr = 3*(x + y + exp(z))
    >>> optimize(expr, [sum_approx3])
    3*(x + y + exp(z))
    >>> optimize(expr, [sum_approx2])
    3*y + 3*exp(z)
    >>> optimize(expr, [sum_approx1])
    3*y

    """

    def __init__(self, bounds, reltol, **kwargs):
        super().__init__(**kwargs)
        # Copy: ``value`` caches bounds of derived terms in this mapping, and
        # those must not leak into the caller's (possibly shared) dict.
        self.bounds = dict(bounds)
        self.reltol = reltol

    def __call__(self, expr):
        return expr.factor().replace(self.query, self.value)

    def query(self, expr):
        return expr.is_Add

    def value(self, add):
        """Return ``add`` with terms that are negligible under ``reltol`` removed.

        ``add`` is returned unchanged when bounds cannot be established for
        every term.
        """
        # First pass: derive bounds for terms depending on exactly one bounded
        # symbol, provided the term is monotonic over that symbol's interval.
        for term in add.args:
            if term.is_number or term in self.bounds or len(term.free_symbols) != 1:
                continue
            fs, = term.free_symbols
            if fs not in self.bounds:
                continue
            sym_bounds = self.bounds[fs]
            intrvl = Interval(*sym_bounds)
            if is_increasing(term, intrvl, fs):
                self.bounds[term] = (
                    term.subs({fs: sym_bounds[0]}),
                    term.subs({fs: sym_bounds[1]})
                )
            elif is_decreasing(term, intrvl, fs):
                # Decreasing: the upper end of the symbol's range gives the
                # lower end of the term's range.
                self.bounds[term] = (
                    term.subs({fs: sym_bounds[1]}),
                    term.subs({fs: sym_bounds[0]})
                )
            else:
                # Monotonicity could not be established: no safe bounds.
                return add

        if not all(term.is_number or term in self.bounds for term in add.args):
            return add

        term_bounds = [(term, term) if term.is_number else self.bounds[term]
                       for term in add.args]

        # Largest magnitude that some term is guaranteed to reach; terms whose
        # range includes zero guarantee nothing.
        largest_abs_guarantee = 0
        for lo, hi in term_bounds:
            if lo <= 0 <= hi:
                continue
            largest_abs_guarantee = max(largest_abs_guarantee,
                                        min(abs(lo), abs(hi)))

        threshold = largest_abs_guarantee*self.reltol
        new_terms = [term for term, (lo, hi) in zip(add.args, term_bounds)
                     if max(abs(lo), abs(hi)) >= threshold]
        return add.func(*new_terms)
101
102
class SeriesApprox(Optimization):
    """ Approximates functions by expanding them as a series.

    Explanation
    ===========

    Applications of (non-undefined) functions with a single argument whose
    only free symbol has an entry in ``bounds`` are replaced by a truncated
    Taylor series around the midpoint of that symbol's bounds. Expansions are
    tried from the highest order downwards; the shortest expansion accepted
    before the first one that fails the ``reltol`` check is used. If even the
    highest-order expansion fails, the function is left untouched.

    Parameters
    ==========

    bounds : dict
        Mapping expressions to length 2 tuple of bounds (low, high). Only
        entries for symbols are consulted.
    reltol : number
        Maximum allowed relative error ``abs(1 - series/function)`` at each
        check point.
    max_order : int
        Largest power of ``(x - x0)`` to include in the series expansion.
    n_point_checks : int (even, at least 2)
        The validity of an expansion (with respect to reltol) is checked at
        discrete points (linearly spaced over the bounds of the variable). The
        number of points used in this numerical check is given by this number.
        A ``ValueError`` is raised if it is odd or smaller than 2.

    Examples
    ========

    >>> from sympy import sin, pi
    >>> from sympy.abc import x, y
    >>> from sympy.codegen.rewriting import optimize
    >>> from sympy.codegen.approximations import SeriesApprox
    >>> bounds = {x: (-.1, .1), y: (pi-1, pi+1)}
    >>> series_approx2 = SeriesApprox(bounds, reltol=1e-2)
    >>> series_approx3 = SeriesApprox(bounds, reltol=1e-3)
    >>> series_approx8 = SeriesApprox(bounds, reltol=1e-8)
    >>> expr = sin(x)*sin(y)
    >>> optimize(expr, [series_approx2])
    x*(-y + (y - pi)**3/6 + pi)
    >>> optimize(expr, [series_approx3])
    (-x**3/6 + x)*sin(y)
    >>> optimize(expr, [series_approx8])
    sin(x)*sin(y)

    """
    def __init__(self, bounds, reltol, max_order=4, n_point_checks=4, **kwargs):
        super().__init__(**kwargs)
        self.bounds = bounds
        self.reltol = reltol
        self.max_order = max_order
        if n_point_checks % 2 == 1:
            # An odd number of evenly spaced points (endpoints included) always
            # hits the midpoint, i.e. the expansion point, where the series is
            # exact.
            raise ValueError("Checking the solution at expansion point is not helpful")
        if n_point_checks < 2:
            # Without check points every expansion would be accepted.
            raise ValueError("n_point_checks must be at least 2")
        self.n_point_checks = n_point_checks
        # Number of decimal digits used when evaluating the relative error.
        self._prec = math.ceil(-math.log10(self.reltol))

    def __call__(self, expr):
        return expr.factor().replace(self.query, self.value)

    def query(self, expr):
        # For an applied undefined function ``f(x)`` it is the class ``f``
        # (i.e. ``expr.func``) that is an instance of UndefinedFunction, not
        # ``f(x)`` itself. Such functions have no computable series.
        return (expr.is_Function and not isinstance(expr.func, UndefinedFunction)
                and len(expr.args) == 1)

    def value(self, fexpr):
        """Return the shortest acceptable series of ``fexpr``, or ``fexpr``."""
        free_symbols = fexpr.free_symbols
        if len(free_symbols) != 1:
            return fexpr
        symb, = free_symbols
        if symb not in self.bounds:
            return fexpr
        lo, hi = self.bounds[symb]
        x0 = (lo + hi)/2
        cheapest = None
        # ``n`` is the order of the dropped O-term: n=max_order+1 keeps powers
        # up to max_order. Stop at the first order that fails the check.
        for n in range(self.max_order+1, 0, -1):
            fseri = fexpr.series(symb, x0=x0, n=n).removeO()
            n_ok = True
            for idx in range(self.n_point_checks):
                # Evenly spaced points including both endpoints.
                pt = lo + idx*(hi - lo)/(self.n_point_checks - 1)
                val = fseri.xreplace({symb: pt})
                ref = fexpr.xreplace({symb: pt})
                if abs((1 - val/ref).evalf(self._prec)) > self.reltol:
                    n_ok = False
                    break

            if n_ok:
                cheapest = fseri
            else:
                break

        if cheapest is None:
            return fexpr
        return cheapest
188