1"""Tools for manipulation of expressions using paths. """
2
3from sympy.core import Basic
4
5
class EPath:
    r"""
    Manipulate expressions using paths.

    EPath grammar in EBNF notation::

        literal   ::= /[A-Za-z_][A-Za-z_0-9]*/
        number    ::= /-?\d+/
        type      ::= literal
        attribute ::= literal "?"
        all       ::= "*"
        slice     ::= "[" number? (":" number? (":" number?)?)? "]"
        range     ::= all | slice
        query     ::= (type | attribute) ("|" (type | attribute))*
        selector  ::= range | query range?
        path      ::= "/" selector ("/" selector)*

    The path string is compiled once, in the constructor, into a list of
    ``(attrs, types, span)`` triples (one per selector):

    * ``attrs`` -- attribute names (written with a trailing ``?``); a child
      matches if it has any of them (``hasattr``),
    * ``types`` -- class names; a child matches if the name of any class in
      its MRO is listed,
    * ``span`` -- ``None`` (all children), an ``int`` index or a ``slice``.

    A selector with neither attributes nor types matches every child.

    See the docstring of the epath() function.

    """

    __slots__ = ("_path", "_epath")

    def __new__(cls, path):
        """Construct new EPath.

        ``path`` is a path string; an existing EPath is returned unchanged.

        Raises ``ValueError`` for an empty or malformed path and
        ``NotImplementedError`` for a path not starting with ``/``.
        """
        if isinstance(path, EPath):
            return path

        if not path:
            raise ValueError("empty EPath")

        if path[0] != '/':
            raise NotImplementedError("non-root EPath")

        obj = object.__new__(cls)

        # Keep the original text for __repr__; compile the selectors in order.
        obj._path = path
        obj._epath = [cls._parse_selector(selector)
                      for selector in path[1:].split('/')]

        return obj

    @classmethod
    def _parse_selector(cls, selector):
        """Compile one selector string into an ``(attrs, types, span)`` triple. """
        selector = selector.strip()

        if not selector:
            raise ValueError("empty selector")

        # Measure the leading query part: identifiers separated by '|', with
        # attribute names carrying a trailing '?'.
        index = 0

        for c in selector:
            if c.isalnum() or c in '_|?':
                index += 1
            else:
                break

        attrs, types = [], []

        if index:
            query, selector = selector[:index], selector[index:]

            for element in query.split('|'):
                element = element.strip()

                if not element:
                    raise ValueError("empty element")

                if element.endswith('?'):
                    attrs.append(element[:-1])
                else:
                    types.append(element)

        span = None

        if selector != '*':
            if selector.startswith('['):
                i = selector.find(']')

                if i < 0:
                    raise ValueError("expected ']', got EOL")

                span = cls._parse_span(selector[1:i])
                selector = selector[i + 1:]

            if selector:
                raise ValueError("trailing characters in selector")

        return attrs, types, span

    @staticmethod
    def _parse_span(text):
        """Compile the contents of ``[...]`` into an ``int`` or a ``slice``. """
        if ':' not in text:
            return int(text)

        # A slice has at most start, stop and step.  Reject extra parts here
        # with ValueError like every other syntax error; passing them on would
        # make slice() raise TypeError instead.
        parts = text.split(':')

        if len(parts) > 3:
            raise ValueError("too many ':' in slice")

        return slice(*[int(part) if part else None for part in parts])

    def __repr__(self):
        return "%s(%r)" % (self.__class__.__name__, self._path)

    def _get_ordered_args(self, expr):
        """Sort ``expr.args`` using printing order. """
        if expr.is_Add:
            return expr.as_ordered_terms()
        elif expr.is_Mul:
            return expr.as_ordered_factors()
        else:
            return expr.args

    def _hasattrs(self, expr, attrs):
        """Check if ``expr`` has any of ``attrs``. """
        # '|' denotes a union, exactly as it does for types in _hastypes.
        return any(hasattr(expr, attr) for attr in attrs)

    def _hastypes(self, expr, types):
        """Check if ``expr`` is any of ``types``. """
        # Names are compared along the whole MRO so that base classes match
        # too.  ``__mro__`` is used instead of ``mro()`` because, when
        # ``expr`` is itself a class, ``expr.__class__`` is a metaclass whose
        # unbound ``mro()`` requires an argument and would raise TypeError.
        return any(klass.__name__ in types for klass in expr.__class__.__mro__)

    def _has(self, expr, attrs, types):
        """Apply ``_hasattrs`` and ``_hastypes`` to ``expr``. """
        if not (attrs or types):
            return True

        if attrs and self._hasattrs(expr, attrs):
            return True

        if types and self._hastypes(expr, types):
            return True

        return False

    def apply(self, expr, func, args=None, kwargs=None):
        """
        Modify parts of an expression selected by a path.

        ``func`` is called as ``func(part, *args, **kwargs)`` on every
        selected part and its result replaces that part.  SymPy objects are
        rebuilt with ``expr.func``; other iterables with their own class.

        Examples
        ========

        >>> from sympy.simplify.epathtools import EPath
        >>> from sympy import sin, cos, E
        >>> from sympy.abc import x, y, z, t

        >>> path = EPath("/*/[0]/Symbol")
        >>> expr = [((x, 1), 2), ((3, y), z)]

        >>> path.apply(expr, lambda expr: expr**2)
        [((x**2, 1), 2), ((3, y**2), z)]

        >>> path = EPath("/*/*/Symbol")
        >>> expr = t + sin(x + 1) + cos(x + y + E)

        >>> path.apply(expr, lambda expr: 2*expr)
        t + sin(2*x + 1) + cos(2*x + 2*y + E)

        """
        def _apply(path, expr, func):
            if not path:
                return func(expr)

            (attrs, types, span), path = path[0], path[1:]

            if isinstance(expr, Basic):
                if expr.is_Atom:
                    return expr

                children, basic = self._get_ordered_args(expr), True
            elif hasattr(expr, '__iter__'):
                children, basic = expr, False
            else:
                return expr

            children = list(children)

            if span is None:
                indices = range(len(children))
            elif isinstance(span, slice):
                indices = range(*span.indices(len(children)))
            else:
                indices = [span]

            for i in indices:
                try:
                    child = children[i]
                except IndexError:
                    # An out-of-range index simply selects nothing.
                    continue

                if self._has(child, attrs, types):
                    children[i] = _apply(path, child, func)

            if basic:
                return expr.func(*children)
            else:
                return expr.__class__(children)

        _args, _kwargs = args or (), kwargs or {}
        _func = lambda expr: func(expr, *_args, **_kwargs)

        return _apply(self._epath, expr, _func)

    def select(self, expr):
        """
        Retrieve parts of an expression selected by a path.

        Returns a list of the selected parts in traversal order.

        Examples
        ========

        >>> from sympy.simplify.epathtools import EPath
        >>> from sympy import sin, cos, E
        >>> from sympy.abc import x, y, z, t

        >>> path = EPath("/*/[0]/Symbol")
        >>> expr = [((x, 1), 2), ((3, y), z)]

        >>> path.select(expr)
        [x, y]

        >>> path = EPath("/*/*/Symbol")
        >>> expr = t + sin(x + 1) + cos(x + y + E)

        >>> path.select(expr)
        [x, x, y]

        """
        result = []

        def _select(path, expr):
            if not path:
                result.append(expr)
                return

            (attrs, types, span), path = path[0], path[1:]

            if isinstance(expr, Basic):
                children = self._get_ordered_args(expr)
            elif hasattr(expr, '__iter__'):
                children = expr
            else:
                return

            if span is not None:
                # Materialize first (as ``apply`` does) so that any iterable,
                # not only sequences, can be indexed or sliced.
                children = list(children)

                if isinstance(span, slice):
                    children = children[span]
                else:
                    try:
                        children = [children[span]]
                    except IndexError:
                        return

            for child in children:
                if self._has(child, attrs, types):
                    _select(path, child)

        _select(self._epath, expr)
        return result
279
280
def epath(path, expr=None, func=None, args=None, kwargs=None):
    r"""
    Manipulate parts of an expression selected by a path.

    Explanation
    ===========

    This function lets you manipulate large nested expressions in a single
    line of code, using techniques similar to those found in XML processing
    standards (e.g. XPath).

    If ``expr`` is ``None``, the compiled :class:`EPath` is returned. If
    ``func`` is ``None``, :func:`epath` returns the elements selected by the
    ``path``. Otherwise it applies ``func`` to each matching element and
    returns the modified expression.

    Creating an EPath object and calling its ``select`` and ``apply`` methods
    is more efficient, because the path string is then compiled only once.
    Use this function mainly as a convenient shortcut in interactive
    sessions.

    The supported syntax is:

    * select all: ``/*``
          Equivalent of ``for arg in args:``.
    * select slice: ``/[0]`` or ``/[1:5]`` or ``/[1:5:2]``
          Supports standard Python's slice syntax.
    * select by type: ``/list`` or ``/list|tuple``
          Emulates ``isinstance()``.
    * select by attribute: ``/__iter__?``
          Emulates ``hasattr()``.

    Parameters
    ==========

    path : str | EPath
        A path as a string or a compiled EPath.
    expr : Basic | iterable
        An expression or a container of expressions.
    func : callable (optional)
        A callable that will be applied to matching parts.
    args : tuple (optional)
        Additional positional arguments to ``func``.
    kwargs : dict (optional)
        Additional keyword arguments to ``func``.

    Examples
    ========

    >>> from sympy.simplify.epathtools import epath
    >>> from sympy import sin, cos, E
    >>> from sympy.abc import x, y, z, t

    >>> path = "/*/[0]/Symbol"
    >>> expr = [((x, 1), 2), ((3, y), z)]

    >>> epath(path, expr)
    [x, y]
    >>> epath(path, expr, lambda expr: expr**2)
    [((x**2, 1), 2), ((3, y**2), z)]

    >>> path = "/*/*/Symbol"
    >>> expr = t + sin(x + 1) + cos(x + y + E)

    >>> epath(path, expr)
    [x, x, y]
    >>> epath(path, expr, lambda expr: 2*expr)
    t + sin(2*x + 1) + cos(2*x + 2*y + E)

    """
    compiled = EPath(path)

    # No expression given: hand back the compiled path for reuse.
    if expr is None:
        return compiled

    if func is not None:
        return compiled.apply(expr, func, args, kwargs)

    return compiled.select(expr)
357