"""
This module collects classes useful for approximate rewriting of expressions.
This can be beneficial when generating numeric code for which performance is
of greater importance than precision (e.g. for preconditioners used in iterative
methods).
"""

import math
from sympy import Interval
from sympy.calculus.singularities import is_increasing, is_decreasing
from sympy.codegen.rewriting import Optimization
from sympy.core.function import UndefinedFunction


class SumApprox(Optimization):
    """
    Approximates sum by neglecting small terms.

    Explanation
    ===========

    If terms are expressions which can be determined to be monotonic, then
    bounds for those expressions are added.

    Parameters
    ==========

    bounds : dict
        Mapping expressions to length 2 tuple of bounds (low, high).
        Bounds derived for monotonic terms are cached in this same mapping,
        i.e. the dict passed in is updated in place.
    reltol : number
        Threshold for when to ignore a term. Taken relative to the largest
        magnitude any term is guaranteed to have, i.e. the largest
        ``min(abs(low), abs(high))`` among terms whose bounds do not
        contain zero. A term is dropped when ``max(abs(low), abs(high))``
        falls below that magnitude times ``reltol``.

    Examples
    ========

    >>> from sympy import exp
    >>> from sympy.abc import x, y, z
    >>> from sympy.codegen.rewriting import optimize
    >>> from sympy.codegen.approximations import SumApprox
    >>> bounds = {x: (-1, 1), y: (1000, 2000), z: (-10, 3)}
    >>> sum_approx3 = SumApprox(bounds, reltol=1e-3)
    >>> sum_approx2 = SumApprox(bounds, reltol=1e-2)
    >>> sum_approx1 = SumApprox(bounds, reltol=1e-1)
    >>> expr = 3*(x + y + exp(z))
    >>> optimize(expr, [sum_approx3])
    3*(x + y + exp(z))
    >>> optimize(expr, [sum_approx2])
    3*y + 3*exp(z)
    >>> optimize(expr, [sum_approx1])
    3*y

    """

    def __init__(self, bounds, reltol, **kwargs):
        super().__init__(**kwargs)
        self.bounds = bounds
        self.reltol = reltol

    def __call__(self, expr):
        # Factor the expression, then rewrite every Add node in the result.
        return expr.factor().replace(self.query, self.value)

    def query(self, expr):
        """Match sums (``Add`` nodes)."""
        return expr.is_Add

    def value(self, add):
        """Return ``add`` with terms that are negligible under ``reltol`` removed.

        ``add`` is returned unchanged unless every term is either a number or
        has known bounds (given directly, or derived here for single-symbol
        terms that are monotonic over their symbol's bounds).
        """
        # First pass: derive and cache bounds for single-symbol terms that can
        # be shown to be monotonic over the interval of their symbol.
        for term in add.args:
            if (term.is_number or term in self.bounds
                    or len(term.free_symbols) != 1):
                continue
            fs, = term.free_symbols
            if fs not in self.bounds:
                continue
            intrvl = Interval(*self.bounds[fs])
            if is_increasing(term, intrvl, fs):
                self.bounds[term] = (
                    term.subs({fs: self.bounds[fs][0]}),
                    term.subs({fs: self.bounds[fs][1]})
                )
            elif is_decreasing(term, intrvl, fs):
                # Decreasing: the symbol's upper bound gives the term's low end.
                self.bounds[term] = (
                    term.subs({fs: self.bounds[fs][1]}),
                    term.subs({fs: self.bounds[fs][0]})
                )
            else:
                # Monotonicity could not be established, so this term cannot
                # be bounded here; leave the whole sum untouched.
                return add

        # Second pass requires a bound (or exact value) for every term.
        if not all(term.is_number or term in self.bounds for term in add.args):
            return add

        bounds = [(term, term) if term.is_number else self.bounds[term]
                  for term in add.args]
        largest_abs_guarantee = 0
        for lo, hi in bounds:
            if lo <= 0 <= hi:
                # The term may vanish, so it guarantees no magnitude at all.
                continue
            largest_abs_guarantee = max(largest_abs_guarantee,
                                        min(abs(lo), abs(hi)))
        # Keep a term only if it can reach reltol times the largest
        # guaranteed magnitude.
        new_terms = [term for term, (lo, hi) in zip(add.args, bounds)
                     if max(abs(lo), abs(hi)) >= largest_abs_guarantee*self.reltol]
        return add.func(*new_terms)


class SeriesApprox(Optimization):
    """ Approximates functions by expanding them as a series.

    Single-argument (non-undefined) functions of exactly one symbol with known
    bounds are expanded around the midpoint of those bounds. The lowest order
    that meets ``reltol`` at every check point is used.

    Parameters
    ==========

    bounds : dict
        Mapping expressions to length 2 tuple of bounds (low, high).
    reltol : number
        Largest acceptable relative error ``abs(1 - approx/exact)`` at the
        check points.
    max_order : int
        Largest order to include in series expansion
    n_point_checks : int (even)
        The validity of an expansion (with respect to reltol) is checked at
        discrete points (linearly spaced over the bounds of the variable). The
        number of points used in this numerical check is given by this number.

    Examples
    ========

    >>> from sympy import sin, pi
    >>> from sympy.abc import x, y
    >>> from sympy.codegen.rewriting import optimize
    >>> from sympy.codegen.approximations import SeriesApprox
    >>> bounds = {x: (-.1, .1), y: (pi-1, pi+1)}
    >>> series_approx2 = SeriesApprox(bounds, reltol=1e-2)
    >>> series_approx3 = SeriesApprox(bounds, reltol=1e-3)
    >>> series_approx8 = SeriesApprox(bounds, reltol=1e-8)
    >>> expr = sin(x)*sin(y)
    >>> optimize(expr, [series_approx2])
    x*(-y + (y - pi)**3/6 + pi)
    >>> optimize(expr, [series_approx3])
    (-x**3/6 + x)*sin(y)
    >>> optimize(expr, [series_approx8])
    sin(x)*sin(y)

    """
    def __init__(self, bounds, reltol, max_order=4, n_point_checks=4, **kwargs):
        super().__init__(**kwargs)
        self.bounds = bounds
        self.reltol = reltol
        self.max_order = max_order
        if n_point_checks % 2 == 1:
            # An odd count puts a check point at the midpoint, which is the
            # expansion point, where the series reproduces the function value.
            raise ValueError("Checking the solution at expansion point is not helpful")
        self.n_point_checks = n_point_checks
        # Decimal digits passed to evalf when evaluating the relative error.
        self._prec = math.ceil(-math.log10(self.reltol))

    def __call__(self, expr):
        # Factor the expression, then rewrite every matching function call.
        return expr.factor().replace(self.query, self.value)

    def query(self, expr):
        """Match single-argument function applications that are not undefined functions."""
        # For ``f = Function('f')`` the class ``f`` is an ``UndefinedFunction``
        # instance while the application ``f(x)`` is not, so the check has to
        # be made on ``expr.func`` rather than on ``expr`` itself.
        return (expr.is_Function and not isinstance(expr.func, UndefinedFunction)
                and len(expr.args) == 1)

    def _within_tol(self, val, ref):
        """Return True if ``abs(1 - val/ref)`` is certifiably at most ``reltol``."""
        err = abs((1 - val/ref).evalf(self._prec))
        # A zero reference value gives zoo/nan, and nan (like any unevaluated
        # expression) cannot be compared with reltol; count it as a failure.
        return bool(err.is_comparable) and not err > self.reltol

    def value(self, fexpr):
        """Return the cheapest acceptable series for ``fexpr``, or ``fexpr`` itself."""
        free_symbols = fexpr.free_symbols
        if len(free_symbols) != 1:
            return fexpr
        symb, = free_symbols
        if symb not in self.bounds:
            return fexpr
        lo, hi = self.bounds[symb]
        x0 = (lo + hi)/2
        # Check points are linearly spaced over [lo, hi]; the exact values there
        # do not depend on the expansion order, so compute them only once.
        n_checks = self.n_point_checks
        points = [lo + idx*(hi - lo)/(n_checks - 1) for idx in range(n_checks)]
        refs = [fexpr.xreplace({symb: x}) for x in points]
        cheapest = None
        # Walk down from the highest order: remember each order that passes
        # every check and stop at the first one that does not.
        for n in range(self.max_order+1, 0, -1):
            fseri = fexpr.series(symb, x0=x0, n=n).removeO()
            if all(self._within_tol(fseri.xreplace({symb: x}), ref)
                   for x, ref in zip(points, refs)):
                cheapest = fseri
            else:
                break

        if cheapest is None:
            return fexpr
        return cheapest