1# -*- coding: utf-8 -*-
2
3from __future__ import absolute_import, division, print_function
4
5from itertools import chain
6
7import numpy as np
8from sym import Backend
9from sym.util import banded_jacobian, check_transforms
10
11from .core import NeqSys, _ensure_3args
12
13
14def _map2(cb, iterable):
15    if cb is None:  # identity function is assumed
16        return iterable
17    else:
18        return map(cb, iterable)
19
20
def _map2l(cb, iterable):
    """ Eager (Python 2 style, list-returning) variant of ``_map2``.

    ``cb=None`` means identity, in which case the items are simply copied
    into a new list.
    """
    mapped = _map2(cb, iterable)
    return list(mapped)
23
24
class SymbolicSys(NeqSys):
    """ Symbolically defined system of non-linear equations.

    This object is analogous to :class:`pyneqsys.NeqSys` but instead of
    providing a callable, the user provides symbolic expressions.

    Parameters
    ----------
    x : iterable of Symbols
        The unknowns (``len(x)`` is used as the number of unknowns).
    exprs : iterable of expressions for ``f``
        (``len(exprs)`` is used as the number of equations).
    params : iterable of Symbols (optional)
        list of symbols appearing in exprs which are parameters
    jac : ImmutableMatrix or bool
        If ``True``:
            - Calculate Jacobian from ``exprs``.
        If ``False``:
            - Do not compute Jacobian (numeric approximation); no Jacobian
              callback is handed to :class:`pyneqsys.NeqSys`.
        If ImmutableMatrix:
            - User provided expressions for the Jacobian.
    backend : str or sym.Backend
        See documentation of `sym.Backend \
<https://pythonhosted.org/sym/sym.html#sym.backend.Backend>`_.
    module : str
        ``module`` keyword argument passed to ``backend.Lambdify``
        (default: ``'numpy'``).
    \\*\\*kwargs:
        See :py:class:`pyneqsys.core.NeqSys`. When ``band`` is given it is
        also used by :meth:`get_jac` to build a banded Jacobian.

    Examples
    --------
    >>> import sympy as sp
    >>> e = sp.exp
    >>> x = x0, x1 = sp.symbols('x:2')
    >>> params = a, b = sp.symbols('a b')
    >>> neqsys = SymbolicSys(x, [a*(1 - x0), b*(x1 - x0**2)], params)
    >>> xout, sol = neqsys.solve('scipy', [-10, -5], [1, 10])
    >>> print(np.allclose(xout, [1, 1]))
    True
    >>> print(neqsys.get_jac()[0, 0])
    -a

    Notes
    -----
    When using SymPy as the backend, a limited number of unknowns is supported.
    The reason is that (currently) ``sympy.lambdify`` has an upper limit on
    number of arguments.

    """

    def __init__(self, x, exprs, params=(), jac=True, backend=None, **kwargs):
        self.x = x
        self.exprs = exprs
        self.params = params
        self._jac = jac
        self.be = Backend(backend)
        self.nf, self.nx = len(exprs), len(x)  # needed by get_*_cb
        self.band = kwargs.get('band', None)  # needed by get_*_cb
        # ``module`` is consumed here (not forwarded to NeqSys).
        self.module = kwargs.pop('module', 'numpy')
        super(SymbolicSys, self).__init__(self.nf, self.nx,
                                          self._get_f_cb(),
                                          self._get_j_cb(),
                                          **kwargs)

    @classmethod
    def from_callback(cls, cb, nx=None, nparams=None, **kwargs):
        """ Generate a SymbolicSys instance from a callback.

        Parameters
        ----------
        cb : callable
            Should have the signature ``cb(x, p, backend) -> list of exprs``.
        nx : int
            Number of unknowns, when not given it is deduced from ``kwargs['names']``.
        nparams : int
            Number of parameters, when not given it is deduced from ``kwargs['param_names']``.

        \\*\\*kwargs :
            Keyword arguments passed on to :class:`SymbolicSys`. See also :class:`pyneqsys.NeqSys`.

        Raises
        ------
        ValueError
            When ``nx`` cannot be determined, or when ``nx``/``nparams`` are
            inconsistent with ``names``/``param_names`` while ``x_by_name``/
            ``par_by_name`` is set.

        Examples
        --------
        >>> symbolicsys = SymbolicSys.from_callback(lambda x, p, be: [
        ...     x[0]*x[1] - p[0],
        ...     be.exp(-x[0]) + be.exp(-x[1]) - p[0]**-2
        ... ], 2, 1)
        ...

        """
        x_by_name = kwargs.get('x_by_name', False)
        par_by_name = kwargs.get('par_by_name', False)
        if x_by_name:
            if 'names' not in kwargs:
                raise ValueError("Need ``names`` in kwargs.")
            if nx is None:
                nx = len(kwargs['names'])
            elif nx != len(kwargs['names']):
                raise ValueError("Inconsistency between nx and length of ``names``.")
        if par_by_name:
            if 'param_names' not in kwargs:
                raise ValueError("Need ``param_names`` in kwargs.")
            if nparams is None:
                nparams = len(kwargs['param_names'])
            elif nparams != len(kwargs['param_names']):
                raise ValueError("Inconsistency between ``nparams`` and length of ``param_names``.")

        if nparams is None:
            nparams = 0

        if nx is None:
            raise ValueError("Need ``nx`` or ``names`` together with ``x_by_name==True``.")
        be = Backend(kwargs.pop('backend', None))
        x = be.real_symarray('x', nx)
        p = be.real_symarray('p', nparams)
        # When requested, the callback receives dicts keyed by name instead
        # of the raw symbol arrays.
        _x = dict(zip(kwargs['names'], x)) if x_by_name else x
        _p = dict(zip(kwargs['param_names'], p)) if par_by_name else p
        try:
            exprs = cb(_x, _p, be)
        except TypeError:
            exprs = _ensure_3args(cb)(_x, _p, be)
        return cls(x, exprs, p, backend=be, **kwargs)

    def get_jac(self):
        """ Return the jacobian of the expressions.

        Returns the symbolically derived Jacobian (banded when ``self.band``
        is set) if ``jac=True`` was given, ``False`` if ``jac=False``, and
        the user-provided ``jac`` object otherwise.
        """
        if self._jac is True:
            if self.band is None:
                f = self.be.Matrix(self.nf, 1, self.exprs)
                _x = self.be.Matrix(self.nx, 1, self.x)
                return f.jacobian(_x)
            else:
                # Banded
                return self.be.Matrix(banded_jacobian(
                    self.exprs, self.x, *self.band))
        elif self._jac is False:
            return False
        else:
            return self._jac

    def _lambdify_exprs(self, exprs):
        """ Lambdify ``exprs`` w.r.t. the unknowns followed by the parameters.

        ``module``/``dtype`` keywords are tried first; backends whose
        ``Lambdify`` rejects them with ``TypeError`` get the plain call.
        """
        args = list(chain(self.x, self.params))
        kw = dict(module=self.module, dtype=object if self.module == 'mpmath' else None)
        try:
            return self.be.Lambdify(args, exprs, **kw)
        except TypeError:
            return self.be.Lambdify(args, exprs)

    def _get_f_cb(self):
        """ Build ``f(x, params)`` evaluating ``exprs`` numerically. """
        cb = self._lambdify_exprs(self.exprs)

        def f(x, params):
            # The lambdified callback takes x and params as one flat vector.
            return cb(np.concatenate((x, params), axis=-1))
        return f

    def _get_j_cb(self):
        """ Build ``j(x, params)`` evaluating the Jacobian numerically.

        Returns ``None`` when the Jacobian is disabled (``jac=False``) rather
        than lambdifying the literal ``False``.
        """
        jac = self.get_jac()  # computed once, even on the TypeError fallback
        if jac is False:
            return None
        cb = self._lambdify_exprs(jac)

        def j(x, params):
            return cb(np.concatenate((x, params), axis=-1))
        return j

    _use_symbol_latex_names = True

    def _repr_latex_(self):  # pretty printing in Jupyter notebook
        from ._sympy import NeqSysTexPrinter
        if self.latex_names and (self.latex_param_names if len(self.params) else True):
            pretty = {s: n for s, n in chain(
                zip(self.x, self.latex_names) if self._use_symbol_latex_names else [],
                zip(self.params, self.latex_param_names)
            )}
        else:
            pretty = {}

        return '$%s$' % NeqSysTexPrinter(dict(symbol_names=pretty)).doprint(self.exprs)
196
197
class TransformedSys(SymbolicSys):
    """ A system which transforms the equations and variables internally

    Can be used to reformulate a problem in a numerically more stable form.

    Parameters
    ----------
    x : iterable of variables
    exprs : iterable of expressions
         Expressions to find root for (untransformed).
    transf : iterable of pairs of expressions
        Forward, backward transformed instances of x.
    params : iterable of symbols
    post_adj : callable (default: None)
        To tweak expression after transformation.
    \\*\\*kwargs :
        Keyword arguments passed onto :class:`SymbolicSys`.

    """
    _use_symbol_latex_names = False  # symbols have been transformed

    def __init__(self, x, exprs, transf, params=(), post_adj=None, **kwargs):
        forward, backward = zip(*transf)
        self.fw, self.bw = forward, backward
        check_transforms(forward, backward, x)
        # Every x_i in the user's expressions is replaced by its forward form.
        substituted = [expr.subs(zip(x, forward)) for expr in exprs]

        def _to_internal(xarr, params):
            # Applied to the user's input via the backward transform.
            return self.bw_cb(xarr), params

        def _to_external(xarr, params):
            # Applied to the solver's output via the forward transform.
            return self.fw_cb(xarr), params

        super(TransformedSys, self).__init__(
            x, _map2l(post_adj, substituted), params,
            pre_processors=[_to_internal],
            post_processors=[_to_external],
            **kwargs)
        # Looked up lazily by the processors above, so defining them after
        # the base-class initializer is fine.
        self.fw_cb = self.be.Lambdify(x, forward)
        self.bw_cb = self.be.Lambdify(x, backward)

    @classmethod
    def from_callback(cls, cb, transf_cbs, nx, nparams=0, pre_adj=None,
                      **kwargs):
        """ Generate a TransformedSys instance from a callback

        Parameters
        ----------
        cb : callable
            Should have the signature ``cb(x, p, backend) -> list of exprs``.
            The callback ``cb`` should return *untransformed* expressions.
        transf_cbs : pair or iterable of pairs of callables
            Callables for forward- and backward-transformations. Each
            callable should take a single parameter (expression) and
            return a single expression.
        nx : int
            Number of unknowns.
        nparams : int
            Number of parameters.
        pre_adj : callable, optional
            To tweak expression prior to transformation. Takes a
            single argument (expression) and returns a single
            rewritten expression.
        \\*\\*kwargs :
            Keyword arguments passed on to :class:`TransformedSys`. See also
            :class:`SymbolicSys` and :class:`pyneqsys.NeqSys`.

        Examples
        --------
        >>> import sympy as sp
        >>> transformed = TransformedSys.from_callback(lambda x, p, be: [
        ...     x[0]*x[1] - p[0],
        ...     be.exp(-x[0]) + be.exp(-x[1]) - p[0]**-2
        ... ], (sp.log, sp.exp), 2, 1)
        ...


        """
        be = Backend(kwargs.pop('backend', None))
        x = be.real_symarray('x', nx)
        p = be.real_symarray('p', nparams)
        try:
            # First interpretation: one (forward, backward) pair per unknown.
            transf = []
            for idx, xi in enumerate(x):
                transf.append((transf_cbs[idx][0](xi), transf_cbs[idx][1](xi)))
        except TypeError:
            # Fallback: a single (forward, backward) pair shared by all unknowns.
            transf = zip(_map2(transf_cbs[0], x), _map2(transf_cbs[1], x))
        try:
            untransformed = cb(x, p, be)
        except TypeError:
            untransformed = _ensure_3args(cb)(x, p, be)
        adjusted = _map2l(pre_adj, untransformed)
        return cls(x, adjusted, transf, p, backend=be, **kwargs)
282
283
def linear_rref(A, b, Matrix=None, S=None):
    """ Transform a linear system to reduced row-echelon form

    Transforms both the matrix and right-hand side of a linear
    system of equations to reduced row echelon form

    Parameters
    ----------
    A : Matrix-like
        Iterable of rows.
    b : iterable
        Right-hand side, paired with the rows of ``A`` via ``zip`` (surplus
        entries in the longer of the two are ignored).
    Matrix : class (optional)
        Matrix class constructed from a list of rows; must support
        ``rref()`` and 2-D slicing. Defaults to ``sympy.Matrix``.
    S : callable (optional)
        Applied to every entry before the matrix is built. Defaults to
        ``sympy.S``.

    Returns
    -------
    A', b' - transformed versions (the first ``len(pivots)`` rows of the
    reduced augmented matrix, split into coefficients and right-hand side)

    """
    if Matrix is None:
        from sympy import Matrix
    if S is None:
        from sympy import S
    augmented = []
    for row, rhs in zip(A, b):
        entries = list(row) + [rhs]
        augmented.append([S(entry) for entry in entries])
    reduced, pivots = Matrix(augmented).rref()
    n_pivots = len(pivots)
    return reduced[:n_pivots, :-1], reduced[:n_pivots, -1]
310
311
def linear_exprs(A, x, b=None, rref=False, Matrix=None):
    """ Returns Ax - b

    Parameters
    ----------
    A : matrix_like of numbers
        Of shape (len(b), len(x)); iterated row by row.
    x : iterable of symbols
        Must support ``len`` and repeated iteration.
    b : array_like of numbers (default: None)
        When ``None``, assume zeros, one per row of ``A`` (so no equation
        is dropped for non-square ``A``).
    Matrix : class
        When ``rref == True``: A matrix class which supports slicing,
        and methods ``__mul__`` and ``rref``. Defaults to ``sympy.Matrix``.
    rref : bool
        Calculate the reduced row echelon form of (A | -b).

    Returns
    -------
    A list of the elements in the resulting column vector.

    """
    # Materialize the rows once: the row count is needed for the default
    # ``b`` and ``A`` may be a one-shot iterable.
    rows = list(A)
    if b is None:
        # One zero per equation; ``[0]*len(x)`` would make ``zip`` silently
        # drop trailing equations of an overdetermined system.
        b = [0]*len(rows)
    if rref:
        rA, rb = linear_rref(rows, b, Matrix)
        if Matrix is None:
            from sympy import Matrix
        return [lhs - rhs for lhs, rhs in zip(rA * Matrix(len(x), 1, x), rb)]
    return [sum(coeff*xi for coeff, xi in zip(row, x)) - rhs
            for row, rhs in zip(rows, b)]
343