1# -*- coding: utf-8 -*-
2"""The xonsh abstract syntax tree node."""
3# These are imported into our module namespace for the benefit of parser.py.
4# pylint: disable=unused-import
5import sys
6from ast import (
7    Module,
8    Num,
9    Expr,
10    Str,
11    Bytes,
12    UnaryOp,
13    UAdd,
14    USub,
15    Invert,
16    BinOp,
17    Add,
18    Sub,
19    Mult,
20    Div,
21    FloorDiv,
22    Mod,
23    Pow,
24    Compare,
25    Lt,
26    Gt,
27    LtE,
28    GtE,
29    Eq,
30    NotEq,
31    In,
32    NotIn,
33    Is,
34    IsNot,
35    Not,
36    BoolOp,
37    Or,
38    And,
39    Subscript,
40    Load,
41    Slice,
42    ExtSlice,
43    List,
44    Tuple,
45    Set,
46    Dict,
47    AST,
48    NameConstant,
49    Name,
50    GeneratorExp,
51    Store,
52    comprehension,
53    ListComp,
54    SetComp,
55    DictComp,
56    Assign,
57    AugAssign,
58    BitXor,
59    BitAnd,
60    BitOr,
61    LShift,
62    RShift,
63    Assert,
64    Delete,
65    Del,
66    Pass,
67    Raise,
68    Import,
69    alias,
70    ImportFrom,
71    Continue,
72    Break,
73    Yield,
74    YieldFrom,
75    Return,
76    IfExp,
77    Lambda,
78    arguments,
79    arg,
80    Call,
81    keyword,
82    Attribute,
83    Global,
84    Nonlocal,
85    If,
86    While,
87    For,
88    withitem,
89    With,
90    Try,
91    ExceptHandler,
92    FunctionDef,
93    ClassDef,
94    Starred,
95    NodeTransformer,
96    Interactive,
97    Expression,
98    Index,
99    literal_eval,
100    dump,
101    walk,
102    increment_lineno,
103)
104from ast import Ellipsis as EllipsisNode
105
106# pylint: enable=unused-import
107import textwrap
108import itertools
109
110from xonsh.tools import subproc_toks, find_next_break, get_logical_line
111from xonsh.platform import PYTHON_VERSION_INFO
112
113if PYTHON_VERSION_INFO >= (3, 5, 0):
114    # pylint: disable=unused-import
115    # pylint: disable=no-name-in-module
116    from ast import MatMult, AsyncFunctionDef, AsyncWith, AsyncFor, Await
117else:
118    MatMult = AsyncFunctionDef = AsyncWith = AsyncFor = Await = None
119
120if PYTHON_VERSION_INFO >= (3, 6, 0):
121    # pylint: disable=unused-import
122    # pylint: disable=no-name-in-module
123    from ast import JoinedStr, FormattedValue
124else:
125    JoinedStr = FormattedValue = None
126
# Tuple of the ``ast`` statement node classes imported above; presumably used
# with isinstance() checks elsewhere in xonsh — confirm against callers.
# NOTE(review): the async statement nodes (AsyncFunctionDef, AsyncFor,
# AsyncWith) are not listed, likely because they are None on Python < 3.5
# (see the version-guarded import above) — confirm whether callers need them.
STATEMENTS = (
    FunctionDef,
    ClassDef,
    Return,
    Delete,
    Assign,
    AugAssign,
    For,
    While,
    If,
    With,
    Raise,
    Try,
    Assert,
    Import,
    ImportFrom,
    Global,
    Nonlocal,
    Expr,
    Pass,
    Break,
    Continue,
)
150
151
def leftmostname(node):
    """Attempts to find the first name in the tree.

    Parameters
    ----------
    node : ast.AST or None
        The node to search.

    Returns
    -------
    str or None
        The ``id`` of the ``Name`` reached by repeatedly following the
        left-most operand of each supported node type, or ``None`` when no
        such name can be determined.
    """
    # JoinedStr is None on Pythons older than 3.6 (see the version-guarded
    # import at the top of this module); isinstance() rejects None inside a
    # type tuple, so drop it.
    string_types = tuple(t for t in (Str, Bytes, JoinedStr) if t is not None)
    if isinstance(node, Name):
        rtn = node.id
    elif isinstance(node, (BinOp, Compare)):
        rtn = leftmostname(node.left)
    elif isinstance(node, (Attribute, Subscript, Starred, Expr)):
        rtn = leftmostname(node.value)
    elif isinstance(node, Call):
        rtn = leftmostname(node.func)
    elif isinstance(node, UnaryOp):
        rtn = leftmostname(node.operand)
    elif isinstance(node, BoolOp):
        rtn = leftmostname(node.values[0])
    elif isinstance(node, Assign):
        rtn = leftmostname(node.targets[0])
    elif isinstance(node, string_types):
        # handles case of "./my executable"; JoinedStr has no ``s`` field,
        # so fall back to None rather than raising AttributeError.
        rtn = leftmostname(getattr(node, "s", None))
    elif isinstance(node, Tuple) and len(node.elts) > 0:
        # handles case of echo ,1,2,3
        rtn = leftmostname(node.elts[0])
    else:
        rtn = None
    return rtn
177
178
def get_lineno(node, default=0):
    """Return ``node.lineno`` when the node has one, otherwise *default*."""
    try:
        return node.lineno
    except AttributeError:
        return default
182
183
def min_line(node):
    """Return the smallest ``lineno`` in the subtree rooted at *node*.

    Sub-nodes lacking a ``lineno`` are treated as sitting on *node*'s own
    line (or line 0 if *node* has no ``lineno`` either).
    """
    root_line = getattr(node, "lineno", 0)
    return min(getattr(child, "lineno", root_line) for child in walk(node))
188
189
def max_line(node):
    """Return the largest ``lineno`` in the subtree rooted at *node*.

    Sub-nodes lacking a ``lineno`` count as line 0.
    """
    return max(getattr(child, "lineno", 0) for child in walk(node))
193
194
def get_col(node, default=-1):
    """Return ``node.col_offset`` when the node has one, otherwise *default*."""
    try:
        return node.col_offset
    except AttributeError:
        return default
198
199
def min_col(node):
    """Return the smallest ``col_offset`` in the subtree rooted at *node*.

    Sub-nodes lacking a ``col_offset`` are counted at *node*'s own offset.
    *node* itself must have a ``col_offset`` (AttributeError otherwise).
    """
    root_col = node.col_offset
    return min(getattr(child, "col_offset", root_col) for child in walk(node))
203
204
def max_col(node):
    """Return the right-most column reached by *node* or any sub-node.

    A precomputed ``max_col`` attribute on *node* wins when it is not None.
    Otherwise the sub-node with the greatest ``col_offset`` is located and
    its estimated textual length (via ``node_len``) is added to its offset.
    """
    cached = getattr(node, "max_col", None)
    if cached is None:
        rightmost = max(walk(node), key=get_col)
        cached = rightmost.col_offset + node_len(rightmost)
    return cached
213
214
def node_len(node):
    """Estimate the printed length of *node* from the names inside it.

    Every ``Name`` contributes the length of its identifier and every
    ``Attribute`` contributes one character for the dot plus its attribute
    name (when that is a string); other node kinds contribute nothing.
    """

    def measure(sub):
        if isinstance(sub, Name):
            return len(sub.id)
        if isinstance(sub, Attribute):
            # one character for the "." plus the attribute name itself
            return 1 + (len(sub.attr) if isinstance(sub.attr, str) else 0)
        # more node kinds may need counting as further cases turn up
        return 0

    return sum(measure(sub) for sub in walk(node))
225
226
def get_id(node, default=None):
    """Return ``node.id`` when the node has one, otherwise *default*."""
    try:
        return node.id
    except AttributeError:
        return default
230
231
def gather_names(node):
    """Collect the ``id`` of every node in *node*'s subtree that has one."""
    found = {getattr(sub, "id", None) for sub in walk(node)}
    found.discard(None)
    return found
237
238
def get_id_ctx(node):
    """Return the pair ``(node.id, node.ctx)``.

    When *node* has no ``id`` (or it is None), ``(None, None)`` is returned
    instead.
    """
    nid = getattr(node, "id", None)
    return (None, None) if nid is None else (nid, node.ctx)
245
246
def gather_load_store_names(node):
    """Split the names in *node*'s subtree by how they are used.

    Returns a ``(load, store)`` pair of sets: names whose ``ctx`` is
    ``Load`` go to ``load``; every other context (e.g. ``Store``, ``Del``)
    goes to ``store``.
    """
    loads, stores = set(), set()
    for sub in walk(node):
        name = getattr(sub, "id", None)
        if name is None:
            continue
        bucket = loads if isinstance(sub.ctx, Load) else stores
        bucket.add(name)
    return (loads, stores)
261
262
def has_elts(x):
    """Report whether *x* is an AST node carrying an ``elts`` attribute."""
    if not isinstance(x, AST):
        return False
    return hasattr(x, "elts")
266
267
def xonsh_call(name, args, lineno=None, col=None):
    """Creates the AST node for calling a function of a given name.

    Parameters
    ----------
    name : str
        Name of the function to call; it is referenced as a ``Name`` load.
    args : list of ast.AST
        Positional argument expressions.
    lineno : int, optional
        Line number given to both the ``Call`` and its ``func`` ``Name``.
    col : int, optional
        Column offset given to both the ``Call`` and its ``func`` ``Name``.

    Returns
    -------
    ast.Call
        A call with no keyword arguments.
    """
    # ``starargs``/``kwargs`` were removed from ast.Call in Python 3.5; passing
    # them only created stray attributes, and Python 3.13+ warns about
    # unknown AST constructor keywords.
    return Call(
        func=Name(id=name, ctx=Load(), lineno=lineno, col_offset=col),
        args=args,
        keywords=[],
        lineno=lineno,
        col_offset=col,
    )
279
280
def isdescendable(node):
    """Report whether the transformer should recurse into *node*.

    At present only ``UnaryOp`` and ``BoolOp`` nodes qualify.
    """
    return any(isinstance(node, kind) for kind in (UnaryOp, BoolOp))
286
287
class CtxAwareTransformer(NodeTransformer):
    """Transforms a xonsh AST to use subprocess calls when an expression
    uses names that are not known in the current context.
    This assumes that such an expression statement is instead parseable as
    a subprocess.
    """

    def __init__(self, parser):
        """Parameters
        ----------
        parser : xonsh.Parser
            A parse instance to try to parse subprocess statements with.
        """
        super().__init__()
        self.parser = parser
        self.input = None
        # stack of name scopes; populated by ctxvisit()
        self.contexts = []
        self.lines = None
        self.mode = None
        self._nwith = 0
        self.filename = "<xonsh-code>"
        self.debug_level = 0

    def ctxvisit(self, node, inp, ctx, mode="exec", filename=None, debug_level=0):
        """Transforms the node in a context-dependent way.

        Parameters
        ----------
        node : ast.AST
            A syntax tree to transform.
        inp : str
            The input code in string format; its lines are re-read when an
            expression is re-parsed as a subprocess.
        ctx : set
            The root context: names already known to be defined. It is used
            with set difference in ``is_in_scope``, so it must be set-like.
        mode : str, optional
            Parser mode passed on to the parser; ``"eval"`` also changes how
            the subprocess column span is computed.
        filename : str, optional
            File we are to transform.
        debug_level : int, optional
            Debugging level to use in lexing and parsing.

        Returns
        -------
        node : ast.AST
            The transformed node.
        """
        self.filename = self.filename if filename is None else filename
        self.debug_level = debug_level
        self.lines = inp.splitlines()
        # [root context, module-level scope]; nested scopes are pushed later
        self.contexts = [ctx, set()]
        self.mode = mode
        self._nwith = 0
        try:
            node = self.visit(node)
        finally:
            # drop per-call state even when visiting fails, so stale lines or
            # scopes are never observed after this call
            del self.lines, self.contexts, self.mode
            self._nwith = 0
        return node

    def ctxupdate(self, iterable):
        """Update the most recent context with the names in *iterable*."""
        self.contexts[-1].update(iterable)

    def ctxadd(self, value):
        """Add a name to the most recent context."""
        self.contexts[-1].add(value)

    def ctxremove(self, value):
        """Remove a name from the innermost context that contains it."""
        for ctx in reversed(self.contexts):
            if value in ctx:
                ctx.remove(value)
                break

    def try_subproc_toks(self, node, strip_expr=False):
        """Tries to parse the line of the node as a subprocess.

        Parameters
        ----------
        node : ast.AST
            The node whose source line should be re-parsed.
        strip_expr : bool, optional
            If true and the result is an ``Expr``, return its ``value``.

        Returns
        -------
        ast.AST
            The re-parsed node, or *node* itself when the line cannot be
            wrapped or the wrapped code fails to parse.
        """
        line, nlogical, _ = get_logical_line(self.lines, node.lineno - 1)
        if self.mode == "eval":
            mincol = len(line) - len(line.lstrip())
            maxcol = None
        else:
            mincol = max(min_col(node) - 1, 0)
            maxcol = max_col(node)
            if mincol == maxcol:
                maxcol = find_next_break(line, mincol=mincol, lexer=self.parser.lexer)
            elif nlogical > 1:
                maxcol = None
            elif maxcol < len(line) and line[maxcol] == ";":
                pass
            else:
                maxcol += 1
        spline = subproc_toks(
            line,
            mincol=mincol,
            maxcol=maxcol,
            returnline=False,
            lexer=self.parser.lexer,
        )
        if spline is None or len(spline) < len(line[mincol:maxcol]) + 2:
            # failed to get something consistent, try greedy wrap
            # The +2 comes from "![]" being length 3, minus 1 since maxcol
            # is one beyond the total length for slicing
            spline = subproc_toks(
                line,
                mincol=mincol,
                maxcol=maxcol,
                returnline=False,
                lexer=self.parser.lexer,
                greedy=True,
            )
        if spline is None:
            return node
        try:
            newnode = self.parser.parse(
                spline,
                mode=self.mode,
                filename=self.filename,
                debug_level=(self.debug_level > 2),
            )
            newnode = newnode.body
            if not isinstance(newnode, AST):
                # take the first (and only) Expr
                newnode = newnode[0]
            increment_lineno(newnode, n=node.lineno - 1)
            newnode.col_offset = node.col_offset
            if self.debug_level > 1:
                msg = "{0}:{1}:{2}{3} - {4}\n" "{0}:{1}:{2}{3} + {5}"
                mstr = "" if maxcol is None else ":" + str(maxcol)
                msg = msg.format(self.filename, node.lineno, mincol, mstr, line, spline)
                print(msg, file=sys.stderr)
        except SyntaxError:
            newnode = node
        if strip_expr and isinstance(newnode, Expr):
            newnode = newnode.value
        return newnode

    def is_in_scope(self, node):
        """Determines whether every name *node* loads (but does not itself
        store) is defined in one of the current contexts.
        """
        names, store = gather_load_store_names(node)
        names -= store
        if not names:
            return True
        inscope = False
        for ctx in reversed(self.contexts):
            names -= ctx
            if not names:
                inscope = True
                break
        return inscope

    #
    # Replacement visitors
    #

    def visit_Expression(self, node):
        """Handle visiting an expression body."""
        if isdescendable(node.body):
            node.body = self.visit(node.body)
        body = node.body
        inscope = self.is_in_scope(body)
        if not inscope:
            node.body = self.try_subproc_toks(body)
        return node

    def visit_Expr(self, node):
        """Handle visiting an expression."""
        if isdescendable(node.value):
            node.value = self.visit(node.value)  # this allows diving into BoolOps
        if self.is_in_scope(node) or isinstance(node.value, Lambda):
            return node
        else:
            newnode = self.try_subproc_toks(node)
            if not isinstance(newnode, Expr):
                newnode = Expr(
                    value=newnode, lineno=node.lineno, col_offset=node.col_offset
                )
                if hasattr(node, "max_lineno"):
                    newnode.max_lineno = node.max_lineno
                    newnode.max_col = node.max_col
            return newnode

    def visit_UnaryOp(self, node):
        """Handle visiting an unary operands, like not."""
        if isdescendable(node.operand):
            node.operand = self.visit(node.operand)
        operand = node.operand
        inscope = self.is_in_scope(operand)
        if not inscope:
            node.operand = self.try_subproc_toks(operand, strip_expr=True)
        return node

    def visit_BoolOp(self, node):
        """Handle visiting an boolean operands, like and/or."""
        for i, val in enumerate(node.values):
            if isdescendable(val):
                val = node.values[i] = self.visit(val)
            if not self.is_in_scope(val):
                node.values[i] = self.try_subproc_toks(val, strip_expr=True)
        return node

    #
    # Context aggregator visitors
    #

    def visit_Assign(self, node):
        """Handle visiting an assignment statement."""
        ups = set()
        for targ in node.targets:
            if isinstance(targ, (Tuple, List)):
                ups.update(self._unpacked_names(targ))
            elif isinstance(targ, BinOp):
                newnode = self.try_subproc_toks(node)
                if newnode is node:
                    ups.add(leftmostname(targ))
                else:
                    return newnode
            else:
                ups.add(leftmostname(targ))
        # leftmostname() returns None for targets it cannot resolve
        ups.discard(None)
        self.ctxupdate(ups)
        return node

    def _unpacked_names(self, targ):
        """Yield the leftmost name of each element of a tuple/list target,
        descending into nested tuple/list elements so ``(a, b), c`` yields
        ``a``, ``b`` and ``c``.
        """
        for elt in targ.elts:
            if isinstance(elt, (Tuple, List)):
                yield from self._unpacked_names(elt)
            else:
                yield leftmostname(elt)

    def visit_Import(self, node):
        """Handle visiting a import statement."""
        for name in node.names:
            if name.asname is None:
                self.ctxadd(name.name)
            else:
                self.ctxadd(name.asname)
        return node

    def visit_ImportFrom(self, node):
        """Handle visiting a "from ... import ..." statement."""
        for name in node.names:
            if name.asname is None:
                self.ctxadd(name.name)
            else:
                self.ctxadd(name.asname)
        return node

    def visit_With(self, node):
        """Handle visiting a with statement."""
        for item in node.items:
            if item.optional_vars is not None:
                self.ctxupdate(gather_names(item.optional_vars))
        self._nwith += 1
        self.generic_visit(node)
        self._nwith -= 1
        return node

    # ``async with`` has the same ``items``/``body`` fields as ``with``
    visit_AsyncWith = visit_With

    def visit_For(self, node):
        """Handle visiting a for statement."""
        targ = node.target
        self.ctxupdate(gather_names(targ))
        self.generic_visit(node)
        return node

    # ``async for`` has the same ``target``/``iter``/``body`` fields as ``for``
    visit_AsyncFor = visit_For

    def visit_FunctionDef(self, node):
        """Handle visiting a function definition.

        The function name goes into the enclosing context; all parameter
        names (positional-only, regular, keyword-only, ``*args`` and
        ``**kwargs``) go into a fresh context used while visiting the body.
        """
        self.ctxadd(node.name)
        self.contexts.append(set())
        args = node.args
        # ``posonlyargs`` only exists on Python 3.8+
        argchain = [getattr(args, "posonlyargs", ()), args.args, args.kwonlyargs]
        if args.vararg is not None:
            argchain.append((args.vararg,))
        if args.kwarg is not None:
            argchain.append((args.kwarg,))
        self.ctxupdate(a.arg for a in itertools.chain.from_iterable(argchain))
        self.generic_visit(node)
        self.contexts.pop()
        return node

    # ``async def`` has the same fields as ``def`` and scopes identically
    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_ClassDef(self, node):
        """Handle visiting a class definition."""
        self.ctxadd(node.name)
        self.contexts.append(set())
        self.generic_visit(node)
        self.contexts.pop()
        return node

    def visit_Delete(self, node):
        """Handle visiting a del statement."""
        for targ in node.targets:
            if isinstance(targ, Name):
                self.ctxremove(targ.id)
        self.generic_visit(node)
        return node

    def visit_Try(self, node):
        """Handle visiting a try statement."""
        for handler in node.handlers:
            if handler.name is not None:
                self.ctxadd(handler.name)
        self.generic_visit(node)
        return node

    # ``try: ... except* ...`` (Python 3.11+) has the same handler fields
    visit_TryStar = visit_Try

    def visit_Global(self, node):
        """Handle visiting a global statement."""
        self.contexts[1].update(node.names)  # contexts[1] is the global ctx
        self.generic_visit(node)
        return node
586
587
def pdump(s, **kwargs):
    """Pretty-format an AST node, or an already dumped AST string.

    AST nodes are first rendered with ``ast.dump`` (``kwargs`` are forwarded
    to it) with a newline inserted after every comma. The text is then split
    at its first opening bracket and the last closing bracket of the same
    kind. The enclosed text is formatted recursively and indented by one
    space. Matching is purely textual, not a balanced-bracket parse.
    """
    if isinstance(s, AST):
        s = dump(s, **kwargs).replace(",", ",\n")
    if not s:
        return s
    hits = [pos for pos in (s.find(ch) for ch in "([{") if pos != -1]
    if not hits:
        return s
    start = min(hits)
    closer = ")]}"["([{".index(s[start])]
    stop = s.rfind(closer)
    head = s[: start + 1] + "\n"
    if stop <= start:
        # no closer after the opener: format everything that follows it
        return head + textwrap.indent(pdump(s[start + 1 :]), " ")
    inner = textwrap.indent(pdump(s[start + 1 : stop]), " ")
    tail = "\n" + s[stop:]
    if any(ch in tail for ch in "([{"):
        tail = pdump(tail)
    return head + inner + tail
611
612
def pprint_ast(s, *, sep=None, end=None, file=None, flush=False, **kwargs):
    """Print the :func:`pdump` rendering of *s*.

    ``sep``, ``end``, ``file`` and ``flush`` are handed to ``print``; any
    remaining keyword arguments go to ``pdump``.
    """
    text = pdump(s, **kwargs)
    print(text, sep=sep, end=end, file=file, flush=flush)
616
617
618#
619# Private helpers
620#
621
622
def _getblockattr(name, lineno, col):
    """Build the AST for ``getattr(name, '__xonsh_block__', False)``.

    Every generated node is placed at (*lineno*, *col*).
    """
    loc = {"lineno": lineno, "col_offset": col}
    target = Name(id=name, ctx=Load(), **loc)
    attr_name = Str(s="__xonsh_block__", **loc)
    fallback = NameConstant(value=False, **loc)
    return xonsh_call(
        "getattr", args=[target, attr_name, fallback], lineno=lineno, col=col
    )
635