1# -*- encoding: utf-8 -*-
2# Copyright 2020 the authors.
3# This file is part of Hy, which is free software licensed under the Expat
4# license. See the LICENSE.
5
6from hy.models import (HyObject, HyExpression, HyKeyword, HyInteger, HyComplex,
7                       HyString, HyBytes, HySymbol, HyFloat, HyList, HySet,
8                       HyDict, HySequence, wrap_value)
9from hy.model_patterns import (FORM, SYM, KEYWORD, STR, sym, brackets, whole,
10                               notpexpr, dolike, pexpr, times, Tag, tag, unpack)
11from funcparserlib.parser import some, many, oneplus, maybe, NoParseError
12from hy.errors import (HyCompileError, HyTypeError, HyLanguageError,
13                       HySyntaxError, HyEvalError, HyInternalError)
14
15from hy.lex import mangle, unmangle, hy_parse, parse_one_thing, LexException
16
17from hy._compat import (PY36, PY38, reraise)
18from hy.macros import require, load_macros, macroexpand, tag_macroexpand
19
20import hy.core
21
22import re
23import textwrap
24import pkgutil
25import traceback
26import itertools
27import importlib
28import inspect
29import types
30import ast
31import sys
32import copy
33import builtins
34import __future__
35
36from collections import defaultdict
37from functools import reduce
38
39
# Quasiquote nesting level used by a plain `quote`: it can never be counted
# down to zero by an unquote, so nothing inside a plain quote is unquoted.
Inf = float("inf")


# Compiler flags passed to every `compile()` call made by `ast_compile`.
hy_ast_compile_flags = (
    __future__.CO_FUTURE_DIVISION
    | __future__.CO_FUTURE_PRINT_FUNCTION
)
45
46
def ast_compile(a, filename, mode):
    """Compile a Python AST into a code object using Hy's compiler flags.

    Parameters
    ----------
    a : instance of `ast.AST`
        The tree to compile.

    filename : str
        Filename used for run-time error messages

    mode: str
        `compile` mode parameter

    Returns
    -------
    out : instance of `types.CodeType`
    """
    code = compile(a, filename, mode, flags=hy_ast_compile_flags)
    return code
65
66
def calling_module(n=1):
    """Get the module calling, if available.

    As a fallback, this will import a module using the calling frame's
    globals value of `__name__`.

    Parameters
    ----------
    n: int, optional
        The number of levels up the stack from this function call.
        The default is one level up.

    Returns
    -------
    out: types.ModuleType
        The module at stack level `n + 1`, or `None` if it can't be
        determined (including when the stack is not that deep).
    """
    # Walk the frame links directly instead of calling `inspect.stack`,
    # which builds a FrameInfo (with file lookups) for *every* frame on the
    # stack just so one of them can be picked out.
    frame_up = inspect.currentframe()
    try:
        for _ in range(n + 1):
            if frame_up is None:
                break
            frame_up = frame_up.f_back
        if frame_up is None:
            return None

        module = inspect.getmodule(frame_up)
        if module is None:
            # This works for modules like `__main__`
            module_name = frame_up.f_globals.get('__name__', None)
            if module_name:
                try:
                    module = importlib.import_module(module_name)
                except ImportError:
                    pass
        return module
    finally:
        # Drop the frame reference promptly to avoid keeping frames (and
        # their locals) alive through reference cycles.
        del frame_up
95
96
def ast_str(x, piecewise=False):
    """Mangle `x` (via `hy.lex.mangle`) for use as a Python identifier.

    With `piecewise`, `x` is treated as a dotted path: each non-empty
    component is mangled on its own and the dots are preserved (empty
    components, e.g. from a leading dot, stay empty).
    """
    if not piecewise:
        return mangle(x)
    parts = x.split(".")
    return ".".join(mangle(part) if part else "" for part in parts)
101
102
# Registries filled in by the `special` and `builds_model` decorators:
# mangled special-form name -> (compiler method, pattern), and
# model type -> compiler method.
_special_form_compilers = {}
_model_compilers = {}
# AST node types that can carry decorators.
_decoratables = (ast.FunctionDef, ast.ClassDef, ast.AsyncFunctionDef)
# _bad_roots are fake special operators, which are used internally
# by other special forms (e.g., `except` in `try`) but can't be
# used to construct special forms themselves.
_bad_roots = tuple(map(ast_str, (
    "unquote", "unquote-splice", "unpack-mapping", "except")))
111
112
def named_constant(expr, v):
    """Build an AST node for the constant `v`, positioned at `expr`.

    Under PY36 this is an `ast.Constant`; otherwise it falls back to an
    `ast.Name` whose id is `str(v)`.
    """
    if not PY36:
        return asty.Name(expr, id=str(v), ctx=ast.Load())
    return asty.Constant(expr, value=v)
117
118
def special(names, pattern):
    """Declare special operators. The decorated method and the given pattern
    is assigned to _special_form_compilers for each of the listed names.

    `names` is a single name or a list of names. A list entry may also be a
    `(condition, name)` tuple, in which case the name is only registered
    when `condition` is truthy. `pattern` is wrapped with `whole`.
    """
    full_pattern = whole(pattern)
    name_list = names if isinstance(names, list) else [names]

    def register(fn):
        for entry in name_list:
            if isinstance(entry, tuple):
                enabled, entry = entry
                if not enabled:
                    continue
            _special_form_compilers[ast_str(entry)] = (fn, full_pattern)
        return fn

    return register
132
133
def builds_model(*model_types):
    """Register the decorated method in _model_compilers for each of
    `model_types`, returning the method unchanged."""
    def register(fn):
        _model_compilers.update(dict.fromkeys(model_types, fn))
        return fn
    return register
141
142
# Provide asty.Foo(x, ...) as shorthand for
# ast.Foo(..., lineno=x.start_line, col_offset=x.start_column) or
# ast.Foo(..., lineno=x.lineno, col_offset=x.col_offset)
class Asty(object):
    """Factory for `ast` nodes that copies position info from `x`.

    The first access to `asty.Foo` caches a static method building
    `ast.Foo` on the class, so later lookups bypass `__getattr__`.
    Only AST node classes are supported; any other name raises
    AttributeError like a normal missing attribute.
    """

    def __getattr__(self, name):
        node_type = getattr(ast, name, None)
        # Refuse anything that isn't an AST node class (typos, dunder probes
        # from introspection tools, helpers like `ast.dump`) instead of
        # caching a bogus constructor on the class forever.
        if not (isinstance(node_type, type) and issubclass(node_type, ast.AST)):
            raise AttributeError(
                "%r is not an AST node type" % (name,))

        def build(x, **kwargs):
            return node_type(
                lineno=getattr(
                    x, 'start_line', getattr(x, 'lineno', None)),
                col_offset=getattr(
                    x, 'start_column', getattr(x, 'col_offset', None)),
                **kwargs)

        setattr(Asty, name, staticmethod(build))
        return getattr(Asty, name)
asty = Asty()
156
157
class Result(object):
    """
    Smart representation of the result of a hy->AST compilation

    This object tries to reconcile the hy world, where everything can be used
    as an expression, with the Python world, where statements and expressions
    need to coexist.

    To do so, we represent a compiler result as a list of statements `stmts`,
    terminated by an expression context `expr`. The expression context is used
    when the compiler needs to use the result as an expression.

    Results are chained by addition: adding two results together returns a
    Result representing the succession of the two Results' statements, with
    the second Result's expression context.

    We make sure that a non-empty expression context does not get clobbered by
    adding more results, by checking accesses to the expression context. We
    assume that the context has been used, or deliberately ignored, if it has
    been accessed.

    The Result object is interoperable with python AST objects: when an AST
    object gets added to a Result object, it gets converted on-the-fly.

    Attributes
    ----------
    imports : defaultdict(set)
        Autoimports requested by this result: module name -> set of names
        (a `None` name requests the module itself).
    stmts : list of `ast.stmt`
        Statements to run before the expression context.
    temp_variables : list
        Temporary `ast.Name` / function-definition nodes that `rename`
        may retarget.
    """
    __slots__ = ("imports", "stmts", "temp_variables",
                 "_expr", "__used_expr")

    def __init__(self, *args, **kwargs):
        if args:
            # emulate kw-only args for future bits.
            raise TypeError("Yo: Hacker: don't pass me real args, dingus")

        self.imports = defaultdict(set)
        self.stmts = []
        self.temp_variables = []
        self._expr = None

        self.__used_expr = False

        # XXX: Make sure we only have AST where we should.
        for kwarg in kwargs:
            if kwarg not in ["imports", "stmts", "expr", "temp_variables"]:
                raise TypeError(
                    "%s() got an unexpected keyword argument '%s'" % (
                        self.__class__.__name__, kwarg))
            setattr(self, kwarg, kwargs[kwarg])

    @property
    def expr(self):
        # Reading the expression context marks it as used.
        self.__used_expr = True
        return self._expr

    @expr.setter
    def expr(self, value):
        self.__used_expr = False
        self._expr = value

    @property
    def lineno(self):
        """Line of the expression context, else of the last statement."""
        if self._expr is not None:
            return self._expr.lineno
        if self.stmts:
            return self.stmts[-1].lineno
        return None

    @property
    def col_offset(self):
        """Column of the expression context, else of the last statement."""
        if self._expr is not None:
            return self._expr.col_offset
        if self.stmts:
            return self.stmts[-1].col_offset
        return None

    def add_imports(self, mod, imports):
        """Autoimport `imports` from `mod`"""
        self.imports[mod].update(imports)

    def is_expr(self):
        """Check whether I am a pure expression (truthy only when there is
        an expression context and no statements or imports)."""
        return self._expr and not (self.imports or self.stmts)

    @property
    def force_expr(self):
        """Force the expression context of the Result.

        If there is no expression context, we return a "None" expression.
        """
        if self.expr:
            return self.expr
        return (ast.Constant if PY36 else ast.Name)(
            **({'value': None} if PY36 else {'id': 'None', 'ctx': ast.Load()}),
            lineno=self.stmts[-1].lineno if self.stmts else 0,
            col_offset=self.stmts[-1].col_offset if self.stmts else 0)

    def expr_as_stmt(self):
        """Convert the Result's expression context to a statement

        This is useful when we want to use the stored expression in a
        statement context (for instance in a code branch).

        We drop ast.Names if they are appended to statements, as they
        can't have any side effect. "Bare" names still get converted to
        statements.

        If there is no expression context, return an empty result.
        """
        if self.expr and not (isinstance(self.expr, ast.Name) and self.stmts):
            return Result() + asty.Expr(self.expr, value=self.expr)
        return Result()

    def rename(self, new_name):
        """Rename the Result's temporary variables to a `new_name`.

        We know how to handle ast.Names and ast.FunctionDefs.

        Raises TypeError for any other kind of temporary.
        """
        new_name = ast_str(new_name)
        for var in self.temp_variables:
            if isinstance(var, ast.Name):
                var.id = new_name
                # NOTE(review): `arg` is not a field of ast.Name; setting it
                # looks vestigial — confirm nothing reads it before removing.
                var.arg = new_name
            elif isinstance(var, (ast.FunctionDef, ast.AsyncFunctionDef)):
                var.name = new_name
            else:
                raise TypeError("Don't know how to rename a %s!" % (
                    var.__class__.__name__))
        self.temp_variables = []

    def __add__(self, other):
        # If we add an ast statement, convert it first
        if isinstance(other, ast.stmt):
            return self + Result(stmts=[other])

        # If we add an ast expression, clobber the expression context
        if isinstance(other, ast.expr):
            return self + Result(expr=other)

        if isinstance(other, ast.excepthandler):
            return self + Result(stmts=[other])

        if not isinstance(other, Result):
            raise TypeError("Can't add %r with non-compiler result %r" % (
                self, other))

        # Check for expression context clobbering
        # NOTE(review): reading `self.expr` sets the used flag before
        # `not self.__used_expr` is evaluated, so this branch can never run.
        # Fixing it (reading `self._expr`) would start emitting warnings for
        # existing code paths, so it is left as is.
        if self.expr and not self.__used_expr:
            traceback.print_stack()
            print("Bad boy clobbered expr %s with %s" % (
                ast.dump(self.expr),
                ast.dump(other.expr)))

        # Fairly obvious addition. Only `other`'s imports and temporaries
        # are carried over.
        result = Result()
        result.imports = other.imports
        result.stmts = self.stmts + other.stmts
        result.expr = other.expr
        result.temp_variables = other.temp_variables

        return result

    def __str__(self):
        # `imports` maps module names (strings) to sets of names, which may
        # contain None, so it is rendered directly: `ast.dump` only accepts
        # AST nodes and raised TypeError here whenever imports were present.
        return (
            "Result(imports=[%s], stmts=[%s], expr=%s)"
            % (
                ", ".join("%r: %r" % (mod, sorted(names, key=str))
                          for mod, names in self.imports.items()),
                ", ".join(ast.dump(x) for x in self.stmts),
                ast.dump(self.expr) if self.expr else None,
            ))
325
326
def is_unpack(kind, x):
    """Return whether `x` is a non-empty `(unpack-KIND ...)` expression."""
    if not isinstance(x, HyExpression) or len(x) == 0:
        return False
    head = x[0]
    return isinstance(head, HySymbol) and head == "unpack-" + kind
332
333
def make_hy_model(outer, x, rest):
    """Build an `outer` model from the items of `x` followed by `rest`.

    Plain `str` items become HySymbols, a plain `list` item contributes its
    first element unchanged (a way to splice in an existing model), and
    anything else is used as-is. A falsy `rest` adds nothing.
    """
    items = []
    for a in x:
        if type(a) is str:
            items.append(HySymbol(a))
        elif type(a) is list:
            items.append(a[0])
        else:
            items.append(a)
    return outer(items + (rest or []))
def mkexpr(*items, **kwargs):
    """Build a HyExpression from `items` (see `make_hy_model`); an optional
    `rest` keyword is appended after them."""
    rest = kwargs.get('rest')
    return make_hy_model(HyExpression, items, rest)
def mklist(*items, **kwargs):
    """Build a HyList from `items` (see `make_hy_model`); an optional
    `rest` keyword is appended after them."""
    rest = kwargs.get('rest')
    return make_hy_model(HyList, items, rest)
344
345
# Parse an optional `(annotate* FORM)` setting, producing just the FORM.
OPTIONAL_ANNOTATION = maybe(
    pexpr(sym("annotate*") + FORM) >> (lambda parsed: parsed[0]))
348
349
def is_annotate_expression(model):
    """Return True if `model` is a non-empty `(annotate* ...)` expression.

    Always returns a bool, so an empty HyExpression yields False rather
    than the (falsy) expression object itself.
    """
    return bool(isinstance(model, HyExpression) and model
                and isinstance(model[0], HySymbol)
                and model[0] == HySymbol("annotate*"))
353
354
355class HyASTCompiler(object):
356    """A Hy-to-Python AST compiler"""
357
358    def __init__(self, module, filename=None, source=None):
359        """
360        Parameters
361        ----------
362        module: str or types.ModuleType
363            Module name or object in which the Hy tree is evaluated.
364        filename: str, optional
365            The name of the file for the source to be compiled.
366            This is optional information for informative error messages and
367            debugging.
368        source: str, optional
369            The source for the file, if any, being compiled.  This is optional
370            information for informative error messages and debugging.
371        """
372        self.anon_var_count = 0
373        self.imports = defaultdict(set)
374        self.temp_if = None
375
376        if not inspect.ismodule(module):
377            self.module = importlib.import_module(module)
378        else:
379            self.module = module
380
381        self.module_name = self.module.__name__
382
383        self.filename = filename
384        self.source = source
385
386        # Hy expects these to be present, so we prep the module for Hy
387        # compilation.
388        self.module.__dict__.setdefault('__macros__', {})
389        self.module.__dict__.setdefault('__tags__', {})
390
391        self.can_use_stdlib = not self.module_name.startswith("hy.core")
392
393        self._stdlib = {}
394
395        # Everything in core needs to be explicit (except for
396        # the core macros, which are built with the core functions).
397        if self.can_use_stdlib:
398            # Load stdlib macros into the module namespace.
399            load_macros(self.module)
400
401            # Populate _stdlib.
402            for stdlib_module in hy.core.STDLIB:
403                mod = importlib.import_module(stdlib_module)
404                for e in map(ast_str, getattr(mod, 'EXPORTS', [])):
405                    self._stdlib[e] = stdlib_module
406
407    def get_anon_var(self):
408        self.anon_var_count += 1
409        return "_hy_anon_var_%s" % self.anon_var_count
410
411    def update_imports(self, result):
412        """Retrieve the imports from the result object"""
413        for mod in result.imports:
414            self.imports[mod].update(result.imports[mod])
415
416    def imports_as_stmts(self, expr):
417        """Convert the Result's imports to statements"""
418        ret = Result()
419        for module, names in self.imports.items():
420            if None in names:
421                ret += self.compile(mkexpr('import', module).replace(expr))
422            names = sorted(name for name in names if name)
423            if names:
424                ret += self.compile(mkexpr('import',
425                    mklist(module, mklist(*names))))
426        self.imports = defaultdict(set)
427        return ret.stmts
428
429    def compile_atom(self, atom):
430        # Compilation methods may mutate the atom, so copy it first.
431        atom = copy.copy(atom)
432        return Result() + _model_compilers[type(atom)](self, atom)
433
    def compile(self, tree):
        """Compile the Hy model `tree` into a Result.

        `None` compiles to an empty Result. Autoimports requested by the
        compiled result are merged into this compiler via `update_imports`.

        Raises
        ------
        HyCompileError
            Re-raised unchanged when it comes from a nested `compile`, or
            raised here to wrap any unexpected exception, with the formatted
            traceback in its message.
        HyLanguageError
            User-facing language errors (presumably including the
            HySyntaxError objects built by `_syntax_error`) are re-raised with
            their original traceback.
        """
        if tree is None:
            return Result()
        try:
            ret = self.compile_atom(tree)
            self.update_imports(ret)
            return ret
        except HyCompileError:
            # compile calls compile, so we're going to have multiple raise
            # nested; so let's re-raise this exception, let's not wrap it in
            # another HyCompileError!
            raise
        except HyLanguageError as e:
            # These are expected errors that should be passed to the user.
            reraise(type(e), e, sys.exc_info()[2])
        except Exception as e:
            # These are unexpected errors that will--hopefully--never be seen
            # by the user. The traceback must be formatted here, while the
            # exception is still being handled.
            f_exc = traceback.format_exc()
            exc_msg = "Internal Compiler Bug ��\n⤷ {}".format(f_exc)
            reraise(HyCompileError, HyCompileError(exc_msg), sys.exc_info()[2])
455
456    def _syntax_error(self, expr, message):
457        return HySyntaxError(message, expr, self.filename, self.source)
458
    def _compile_collect(self, exprs, with_kwargs=False, dict_display=False):
        """Collect the expression contexts from a list of compiled expression.

        This returns a list of the expression contexts, and the sum of the
        Result objects passed as arguments.

        Parameters
        ----------
        exprs : iterable of Hy models
            The forms to compile, in order.
        with_kwargs : bool, optional
            If true, a `HyKeyword` consumes the following form as its value
            and becomes an `ast.keyword`, and `(unpack-mapping X)` becomes an
            `ast.keyword` with `arg=None`.
        dict_display : bool, optional
            If true, `(unpack-mapping X)` appends `None` followed by X's
            expression to the collected expressions (presumably consumed as
            alternating dict keys/values, where a `None` key means `**X` —
            confirm against the caller). Takes precedence over `with_kwargs`.

        Returns
        -------
        (compiled_exprs, ret, keywords)
            The expression contexts, the summed Result holding every
            statement in order, and the list of `ast.keyword` nodes.

        Raises
        ------
        HySyntaxError
            (built by `_syntax_error`) for a keyword with no following value,
            or for the empty keyword.
        """
        compiled_exprs = []
        ret = Result()
        keywords = []

        # An explicit iterator, so the keyword branch can pull the next form
        # off it as the keyword's value.
        exprs_iter = iter(exprs)
        for expr in exprs_iter:

            if is_unpack("mapping", expr):
                ret += self.compile(expr[1])
                if dict_display:
                    compiled_exprs.append(None)
                    compiled_exprs.append(ret.force_expr)
                elif with_kwargs:
                    keywords.append(asty.keyword(
                        expr, arg=None, value=ret.force_expr))
                # NOTE(review): with neither flag set, the unpacked mapping is
                # compiled (its statements kept) but its value is silently
                # dropped — confirm whether this should be an error.

            elif with_kwargs and isinstance(expr, HyKeyword):
                try:
                    value = next(exprs_iter)
                except StopIteration:
                    raise self._syntax_error(expr,
                        "Keyword argument {kw} needs a value.".format(kw=expr))

                if not expr:
                    raise self._syntax_error(expr,
                        "Can't call a function with the empty keyword")

                compiled_value = self.compile(value)
                ret += compiled_value

                # Drop the first character of the keyword's string form
                # (presumably the leading ':') to get the argument name.
                arg = str(expr)[1:]
                keywords.append(asty.keyword(
                    expr, arg=ast_str(arg), value=compiled_value.force_expr))

            else:
                ret += self.compile(expr)
                compiled_exprs.append(ret.force_expr)

        return compiled_exprs, ret, keywords
505
506    def _compile_branch(self, exprs):
507        """Make a branch out of an iterable of Result objects
508
509        This generates a Result from the given sequence of Results, forcing each
510        expression context as a statement before the next result is used.
511
512        We keep the expression context of the last argument for the returned Result
513        """
514        ret = Result()
515        for x in map(self.compile, exprs[:-1]):
516            ret += x
517            ret += x.expr_as_stmt()
518        if exprs:
519            ret += self.compile(exprs[-1])
520        return ret
521
522    def _storeize(self, expr, name, func=None):
523        """Return a new `name` object with an ast.Store() context"""
524        if not func:
525            func = ast.Store
526
527        if isinstance(name, Result):
528            if not name.is_expr():
529                raise self._syntax_error(expr,
530                    "Can't assign or delete a non-expression")
531            name = name.expr
532
533        if isinstance(name, (ast.Tuple, ast.List)):
534            typ = type(name)
535            new_elts = []
536            for x in name.elts:
537                new_elts.append(self._storeize(expr, x, func))
538            new_name = typ(elts=new_elts)
539        elif isinstance(name, ast.Name):
540            new_name = ast.Name(id=name.id)
541        elif isinstance(name, ast.Subscript):
542            new_name = ast.Subscript(value=name.value, slice=name.slice)
543        elif isinstance(name, ast.Attribute):
544            new_name = ast.Attribute(value=name.value, attr=name.attr)
545        elif isinstance(name, ast.Starred):
546            new_name = ast.Starred(
547                value=self._storeize(expr, name.value, func))
548        else:
549            raise self._syntax_error(expr,
550                "Can't assign or delete a %s" % type(expr).__name__)
551
552        new_name.ctx = func()
553        ast.copy_location(new_name, name)
554        return new_name
555
556    def _render_quoted_form(self, form, level):
557        """
558        Render a quoted form as a new HyExpression.
559
560        `level` is the level of quasiquoting of the current form. We can
561        unquote if level is 0.
562
563        Returns a three-tuple (`imports`, `expression`, `splice`).
564
565        The `splice` return value is used to mark `unquote-splice`d forms.
566        We need to distinguish them as want to concatenate them instead of
567        just nesting them.
568        """
569
570        op = None
571        if isinstance(form, HyExpression) and form and (
572                isinstance(form[0], HySymbol)):
573            op = unmangle(ast_str(form[0]))
574        if level == 0 and op in ("unquote", "unquote-splice"):
575            if len(form) != 2:
576                raise HyTypeError("`%s' needs 1 argument, got %s" % op, len(form) - 1,
577                                  self.filename, form, self.source)
578            return set(), form[1], op == "unquote-splice"
579        elif op == "quasiquote":
580            level += 1
581        elif op in ("unquote", "unquote-splice"):
582            level -= 1
583
584        name = form.__class__.__name__
585        imports = set([name])
586        body = [form]
587
588        if isinstance(form, HySequence):
589            contents = []
590            for x in form:
591                f_imps, f_contents, splice = self._render_quoted_form(x, level)
592                imports.update(f_imps)
593                if splice:
594                    contents.append(HyExpression([
595                        HySymbol("list"),
596                        HyExpression([HySymbol("or"), f_contents, HyList()])]))
597                else:
598                    contents.append(HyList([f_contents]))
599            if form:
600                # If there are arguments, they can be spliced
601                # so we build a sum...
602                body = [HyExpression([HySymbol("+"), HyList()] + contents)]
603            else:
604                body = [HyList()]
605
606        elif isinstance(form, HySymbol):
607            body = [HyString(form)]
608
609        elif isinstance(form, HyKeyword):
610            body = [HyString(form.name)]
611
612        elif isinstance(form, HyString):
613            if form.is_format:
614                # Ensure that this f-string isn't evaluated right now.
615                body = [
616                    copy.copy(form),
617                    HyKeyword("is_format"),
618                    form.is_format,
619                ]
620                body[0].is_format = False
621            if form.brackets is not None:
622                body.extend([HyKeyword("brackets"), form.brackets])
623
624        ret = HyExpression([HySymbol(name)] + body).replace(form)
625        return imports, ret, False
626
627    @special(["quote", "quasiquote"], [FORM])
628    def compile_quote(self, expr, root, arg):
629        level = Inf if root == "quote" else 0   # Only quasiquotes can unquote
630        imports, stmts, _ = self._render_quoted_form(arg, level)
631        ret = self.compile(stmts)
632        ret.add_imports("hy", imports)
633        return ret
634
635    @special("unpack-iterable", [FORM])
636    def compile_unpack_iterable(self, expr, root, arg):
637        ret = self.compile(arg)
638        ret += asty.Starred(expr, value=ret.force_expr, ctx=ast.Load())
639        return ret
640
641    @special("do", [many(FORM)])
642    def compile_do(self, expr, root, body):
643        return self._compile_branch(body)
644
645    @special("raise", [maybe(FORM), maybe(sym(":from") + FORM)])
646    def compile_raise_expression(self, expr, root, exc, cause):
647        ret = Result()
648
649        if exc is not None:
650            exc = self.compile(exc)
651            ret += exc
652            exc = exc.force_expr
653
654        if cause is not None:
655            cause = self.compile(cause)
656            ret += cause
657            cause = cause.force_expr
658
659        return ret + asty.Raise(
660            expr, type=ret.expr, exc=exc,
661            inst=None, tback=None, cause=cause)
662
    @special("try",
       [many(notpexpr("except", "else", "finally")),
        many(pexpr(sym("except"),
            brackets() | brackets(FORM) | brackets(SYM, FORM),
            many(FORM))),
        maybe(dolike("else")),
        maybe(dolike("finally"))])
    def compile_try_expression(self, expr, root, body, catchers, orelse, finalbody):
        """Compile `(try BODY... (except [...] ...)... (else ...) (finally ...))`.

        The value of the whole form is stored in an anonymous variable
        (`return_var`). When there is no `else`, the body assigns it. When
        there is an `else`, that branch assigns it instead. Each `except`
        handler assigns it too (see `_compile_catch_expression`). The
        returned Result's expression reads that variable.

        Raises
        ------
        HySyntaxError
            For `else` without any `except`, or when there is neither an
            `except` nor a `finally` clause.
        """
        body = self._compile_branch(body)

        # Temporary holding the value of the whole `try` form.
        return_var = asty.Name(
            expr, id=ast_str(self.get_anon_var()), ctx=ast.Store())

        # Statements needed to evaluate the exception types accumulate here
        # and are emitted before the `Try`. Each compiled catcher ends with
        # its ExceptHandler, which is popped off into `handlers`.
        handler_results = Result()
        handlers = []
        for catcher in catchers:
            handler_results += self._compile_catch_expression(
                catcher, return_var, *catcher)
            handlers.append(handler_results.stmts.pop())

        if orelse is None:
            orelse = []
        else:
            orelse = self._compile_branch(orelse)
            orelse += asty.Assign(expr, targets=[return_var],
                                  value=orelse.force_expr)
            # NOTE(review): after adding the Assign, `orelse` has no
            # expression context, so this adds nothing.
            orelse += orelse.expr_as_stmt()
            orelse = orelse.stmts

        if finalbody is None:
            finalbody = []
        else:
            # The `finally` value is discarded; only its effects matter.
            finalbody = self._compile_branch(finalbody)
            finalbody += finalbody.expr_as_stmt()
            finalbody = finalbody.stmts

        # Using (else) without (except) is verboten!
        if orelse and not handlers:
            raise self._syntax_error(expr,
                "`try' cannot have `else' without `except'")
        # Likewise a bare (try) or (try BODY).
        if not (handlers or finalbody):
            raise self._syntax_error(expr,
                "`try' must have an `except' or `finally' clause")

        # Listing `return_var` as a temporary lets `Result.rename` retarget
        # the store of the try's value.
        returnable = Result(
            expr=asty.Name(expr, id=return_var.id, ctx=ast.Load()),
            temp_variables=[return_var])
        # With an `else`, the body's value is discarded (the `else` supplies
        # the result); otherwise the body's value is the result.
        body += body.expr_as_stmt() if orelse else asty.Assign(
            expr, targets=[return_var], value=body.force_expr)
        body = body.stmts or [asty.Pass(expr)]

        x = asty.Try(
            expr,
            body=body,
            handlers=handlers,
            orelse=orelse,
            finalbody=finalbody)
        return handler_results + x + returnable
722
723    def _compile_catch_expression(self, expr, var, exceptions, body):
724        # exceptions catch should be either:
725        # [[list of exceptions]]
726        # or
727        # [variable [list of exceptions]]
728        # or
729        # [variable exception]
730        # or
731        # [exception]
732        # or
733        # []
734
735        name = None
736        if len(exceptions) == 2:
737            name = ast_str(exceptions[0])
738
739        exceptions_list = exceptions[-1] if exceptions else HyList()
740        if isinstance(exceptions_list, HyList):
741            if len(exceptions_list):
742                # [FooBar BarFoo] → catch Foobar and BarFoo exceptions
743                elts, types, _ = self._compile_collect(exceptions_list)
744                types += asty.Tuple(exceptions_list, elts=elts, ctx=ast.Load())
745            else:
746                # [] → all exceptions caught
747                types = Result()
748        else:
749            types = self.compile(exceptions_list)
750
751        body = self._compile_branch(body)
752        body += asty.Assign(expr, targets=[var], value=body.force_expr)
753        body += body.expr_as_stmt()
754
755        return types + asty.ExceptHandler(
756            expr, type=types.expr, name=name,
757            body=body.stmts or [asty.Pass(expr)])
758
759    @special("if*", [FORM, FORM, maybe(FORM)])
760    def compile_if(self, expr, _, cond, body, orel_expr):
761        cond = self.compile(cond)
762        body = self.compile(body)
763
764        nested = root = False
765        orel = Result()
766        if orel_expr is not None:
767            if isinstance(orel_expr, HyExpression) and isinstance(orel_expr[0],
768               HySymbol) and orel_expr[0] == 'if*':
769                # Nested ifs: don't waste temporaries
770                root = self.temp_if is None
771                nested = True
772                self.temp_if = self.temp_if or self.get_anon_var()
773            orel = self.compile(orel_expr)
774
775        if not cond.stmts and isinstance(cond.force_expr, ast.Name):
776            name = cond.force_expr.id
777            branch = None
778            if name == 'True':
779                branch = body
780            elif name in ('False', 'None'):
781                branch = orel
782            if branch is not None:
783                if self.temp_if and branch.stmts:
784                    name = asty.Name(expr,
785                                     id=ast_str(self.temp_if),
786                                     ctx=ast.Store())
787
788                    branch += asty.Assign(expr,
789                                          targets=[name],
790                                          value=body.force_expr)
791
792                return branch
793
794        # We want to hoist the statements from the condition
795        ret = cond
796
797        if body.stmts or orel.stmts:
798            # We have statements in our bodies
799            # Get a temporary variable for the result storage
800            var = self.temp_if or self.get_anon_var()
801            name = asty.Name(expr,
802                             id=ast_str(var),
803                             ctx=ast.Store())
804
805            # Store the result of the body
806            body += asty.Assign(expr,
807                                targets=[name],
808                                value=body.force_expr)
809
810            # and of the else clause
811            if not nested or not orel.stmts or (not root and
812               var != self.temp_if):
813                orel += asty.Assign(expr,
814                                    targets=[name],
815                                    value=orel.force_expr)
816
817            # Then build the if
818            ret += asty.If(expr,
819                           test=ret.force_expr,
820                           body=body.stmts,
821                           orelse=orel.stmts)
822
823            # And make our expression context our temp variable
824            expr_name = asty.Name(expr, id=ast_str(var), ctx=ast.Load())
825
826            ret += Result(expr=expr_name, temp_variables=[expr_name, name])
827        else:
828            # Just make that an if expression
829            ret += asty.IfExp(expr,
830                              test=ret.force_expr,
831                              body=body.force_expr,
832                              orelse=orel.force_expr)
833
834        if root:
835            self.temp_if = None
836
837        return ret
838
839    @special(["break", "continue"], [])
840    def compile_break_or_continue_expression(self, expr, root):
841        return (asty.Break if root == "break" else asty.Continue)(expr)
842
843    @special("assert", [FORM, maybe(FORM)])
844    def compile_assert_expression(self, expr, root, test, msg):
845        if msg is None or type(msg) is HySymbol:
846            ret = self.compile(test)
847            return ret + asty.Assert(
848                expr,
849                test=ret.force_expr,
850                msg=(None if msg is None else self.compile(msg).force_expr))
851
852        # The `msg` part may involve statements, which we only
853        # want to be executed if the assertion fails. Rewrite the
854        # form to set `msg` to a variable.
855        msg_var = self.get_anon_var()
856        return self.compile(mkexpr(
857            'if*', mkexpr('and', '__debug__', mkexpr('not', [test])),
858                mkexpr('do',
859                    mkexpr('setv', msg_var, [msg]),
860                    mkexpr('assert', 'False', msg_var))).replace(expr))
861
862    @special(["global", "nonlocal"], [oneplus(SYM)])
863    def compile_global_or_nonlocal(self, expr, root, syms):
864        node = asty.Global if root == "global" else asty.Nonlocal
865        return node(expr, names=list(map(ast_str, syms)))
866
867    @special("yield", [maybe(FORM)])
868    def compile_yield_expression(self, expr, root, arg):
869        ret = Result()
870        if arg is not None:
871            ret += self.compile(arg)
872        return ret + asty.Yield(expr, value=ret.force_expr)
873
874    @special(["yield-from", "await"], [FORM])
875    def compile_yield_from_or_await_expression(self, expr, root, arg):
876        ret = Result() + self.compile(arg)
877        node = asty.YieldFrom if root == "yield-from" else asty.Await
878        return ret + node(expr, value=ret.force_expr)
879
880    @special("get", [FORM, oneplus(FORM)])
881    def compile_index_expression(self, expr, name, obj, indices):
882        indices, ret, _ = self._compile_collect(indices)
883        ret += self.compile(obj)
884
885        for ix in indices:
886            ret += asty.Subscript(
887                expr,
888                value=ret.force_expr,
889                slice=ast.Index(value=ix),
890                ctx=ast.Load())
891
892        return ret
893
894    @special(".", [FORM, many(SYM | brackets(FORM))])
895    def compile_attribute_access(self, expr, name, invocant, keys):
896        ret = self.compile(invocant)
897
898        for attr in keys:
899            if isinstance(attr, HySymbol):
900                ret += asty.Attribute(attr,
901                                      value=ret.force_expr,
902                                      attr=ast_str(attr),
903                                      ctx=ast.Load())
904            else: # attr is a HyList
905                compiled_attr = self.compile(attr[0])
906                ret = compiled_attr + ret + asty.Subscript(
907                    attr,
908                    value=ret.force_expr,
909                    slice=ast.Index(value=compiled_attr.force_expr),
910                    ctx=ast.Load())
911
912        return ret
913
914    @special("del", [many(FORM)])
915    def compile_del_expression(self, expr, name, args):
916        if not args:
917            return asty.Pass(expr)
918
919        del_targets = []
920        ret = Result()
921        for target in args:
922            compiled_target = self.compile(target)
923            ret += compiled_target
924            del_targets.append(self._storeize(target, compiled_target,
925                                              ast.Del))
926
927        return ret + asty.Delete(expr, targets=del_targets)
928
929    @special("cut", [FORM, maybe(FORM), maybe(FORM), maybe(FORM)])
930    def compile_cut_expression(self, expr, name, obj, lower, upper, step):
931        ret = [Result()]
932        def c(e):
933            ret[0] += self.compile(e)
934            return ret[0].force_expr
935
936        s = asty.Subscript(
937            expr,
938            value=c(obj),
939            slice=asty.Slice(expr,
940                lower=c(lower), upper=c(upper), step=c(step)),
941            ctx=ast.Load())
942        return ret[0] + s
943
944    @special("with-decorator", [oneplus(FORM)])
945    def compile_decorate_expression(self, expr, name, args):
946        decs, fn = args[:-1], self.compile(args[-1])
947        if not fn.stmts or not isinstance(fn.stmts[-1], _decoratables):
948            raise self._syntax_error(args[-1],
949                "Decorated a non-function")
950        decs, ret, _ = self._compile_collect(decs)
951        fn.stmts[-1].decorator_list = decs + fn.stmts[-1].decorator_list
952        return ret + fn
953
954    @special(["with*", "with/a*"],
955             [brackets(FORM, maybe(FORM)), many(FORM)])
956    def compile_with_expression(self, expr, root, args, body):
957        thing, ctx = (None, args[0]) if args[1] is None else args
958        if thing is not None:
959            thing = self._storeize(thing, self.compile(thing))
960        ctx = self.compile(ctx)
961
962        body = self._compile_branch(body)
963
964        # Store the result of the body in a tempvar
965        var = self.get_anon_var()
966        name = asty.Name(expr, id=ast_str(var), ctx=ast.Store())
967        body += asty.Assign(expr, targets=[name], value=body.force_expr)
968        # Initialize the tempvar to None in case the `with` exits
969        # early with an exception.
970        initial_assign = asty.Assign(
971            expr, targets=[name], value=named_constant(expr, None))
972
973        node = asty.With if root == "with*" else asty.AsyncWith
974        the_with = node(expr,
975                        context_expr=ctx.force_expr,
976                        optional_vars=thing,
977                        body=body.stmts,
978                        items=[ast.withitem(context_expr=ctx.force_expr,
979                                            optional_vars=thing)])
980
981        ret = Result(stmts=[initial_assign]) + ctx + the_with
982        # And make our expression context our temp variable
983        expr_name = asty.Name(expr, id=ast_str(var), ctx=ast.Load())
984
985        ret += Result(expr=expr_name)
986        # We don't give the Result any temp_vars because we don't want
987        # Result.rename to touch `name`. Otherwise, initial_assign will
988        # clobber any preexisting value of the renamed-to variable.
989
990        return ret
991
992    @special(",", [many(FORM)])
993    def compile_tuple(self, expr, root, args):
994        elts, ret, _ = self._compile_collect(args)
995        return ret + asty.Tuple(expr, elts=elts, ctx=ast.Load())
996
    # Grammar for the clauses of `for`, `lfor`, `sfor`, `gfor` and `dfor`.
    # Each clause is tagged with its kind:
    #   :setv TARGET VALUE  -> 'setv' (bind TARGET to VALUE)
    #   :if COND            -> 'if'   (filter: skip unless COND is true)
    #   :do FORM            -> 'do'   (evaluate FORM for its side effects)
    #   :async TARGET ITER  -> 'afor' (asynchronous iteration)
    #   TARGET ITER         -> 'for'  (ordinary iteration)
    # The alternatives are tried in order; keep the untagged 'for' case
    # last, presumably because FORM would also match the keyword symbols.
    _loopers = many(
        tag('setv', sym(":setv") + FORM + FORM) |
        tag('if', sym(":if") + FORM) |
        tag('do', sym(":do") + FORM) |
        tag('afor', sym(":async") + FORM + FORM) |
        tag('for', FORM + FORM))
1003    @special(["for"], [brackets(_loopers),
1004        many(notpexpr("else")) + maybe(dolike("else"))])
1005    @special(["lfor", "sfor", "gfor"], [_loopers, FORM])
1006    @special(["dfor"], [_loopers, brackets(FORM, FORM)])
1007    def compile_comprehension(self, expr, root, parts, final):
1008        node_class = {
1009            "for":  asty.For,
1010            "lfor": asty.ListComp,
1011            "dfor": asty.DictComp,
1012            "sfor": asty.SetComp,
1013            "gfor": asty.GeneratorExp}[root]
1014        is_for = root == "for"
1015
1016        orel = []
1017        if is_for:
1018            # Get the `else`.
1019            body, else_expr = final
1020            if else_expr is not None:
1021                orel.append(self._compile_branch(else_expr))
1022                orel[0] += orel[0].expr_as_stmt()
1023        else:
1024            # Get the final value (and for dictionary
1025            # comprehensions, the final key).
1026            if node_class is asty.DictComp:
1027                key, elt = map(self.compile, final)
1028            else:
1029                key = None
1030                elt = self.compile(final)
1031
1032        # Compile the parts.
1033        if is_for:
1034            parts = parts[0]
1035        if not parts:
1036            return Result(expr=ast.parse({
1037                asty.For: "None",
1038                asty.ListComp: "[]",
1039                asty.DictComp: "{}",
1040                asty.SetComp: "{1}.__class__()",
1041                asty.GeneratorExp: "(_ for _ in [])"}[node_class]).body[0].value)
1042        parts = [
1043            Tag(p.tag, self.compile(p.value) if p.tag in ["if", "do"] else [
1044                self._storeize(p.value[0], self.compile(p.value[0])),
1045                self.compile(p.value[1])])
1046            for p in parts]
1047
1048        # Produce a result.
1049        if (is_for or elt.stmts or (key is not None and key.stmts) or
1050            any(p.tag == 'do' or (p.value[1].stmts if p.tag in ("for", "afor", "setv") else p.value.stmts)
1051                for p in parts)):
1052            # The desired comprehension can't be expressed as a
1053            # real Python comprehension. We'll write it as a nested
1054            # loop in a function instead.
1055            def f(parts):
1056                # This function is called recursively to construct
1057                # the nested loop.
1058                if not parts:
1059                    if is_for:
1060                        if body:
1061                            bd = self._compile_branch(body)
1062                            return bd + bd.expr_as_stmt()
1063                        return Result(stmts=[asty.Pass(expr)])
1064                    if node_class is asty.DictComp:
1065                        ret = key + elt
1066                        val = asty.Tuple(
1067                            key, ctx=ast.Load(),
1068                            elts=[key.force_expr, elt.force_expr])
1069                    else:
1070                        ret = elt
1071                        val = elt.force_expr
1072                    return ret + asty.Expr(
1073                        elt, value=asty.Yield(elt, value=val))
1074                (tagname, v), parts = parts[0], parts[1:]
1075                if tagname in ("for", "afor"):
1076                    orelse = orel and orel.pop().stmts
1077                    node = asty.AsyncFor if tagname == "afor" else asty.For
1078                    return v[1] + node(
1079                        v[1], target=v[0], iter=v[1].force_expr, body=f(parts).stmts,
1080                        orelse=orelse)
1081                elif tagname == "setv":
1082                    return v[1] + asty.Assign(
1083                        v[1], targets=[v[0]], value=v[1].force_expr) + f(parts)
1084                elif tagname == "if":
1085                    return v + asty.If(
1086                        v, test=v.force_expr, body=f(parts).stmts, orelse=[])
1087                elif tagname == "do":
1088                    return v + v.expr_as_stmt() + f(parts)
1089                else:
1090                    raise ValueError("can't happen")
1091            if is_for:
1092                return f(parts)
1093            fname = self.get_anon_var()
1094            # Define the generator function.
1095            ret = Result() + asty.FunctionDef(
1096                expr,
1097                name=fname,
1098                args=ast.arguments(
1099                    args=[], vararg=None, kwarg=None, posonlyargs=[],
1100                    kwonlyargs=[], kw_defaults=[], defaults=[]),
1101                body=f(parts).stmts,
1102                decorator_list=[])
1103            # Immediately call the new function. Unless the user asked
1104            # for a generator, wrap the call in `[].__class__(...)` or
1105            # `{}.__class__(...)` or `{1}.__class__(...)` to get the
1106            # right type. We don't want to just use e.g. `list(...)`
1107            # because the name `list` might be rebound.
1108            return ret + Result(expr=ast.parse(
1109                "{}({}())".format(
1110                    {asty.ListComp: "[].__class__",
1111                     asty.DictComp: "{}.__class__",
1112                     asty.SetComp: "{1}.__class__",
1113                     asty.GeneratorExp: ""}[node_class],
1114                    fname)).body[0].value)
1115
1116        # We can produce a real comprehension.
1117        generators = []
1118        for tagname, v in parts:
1119            if tagname in ("for", "afor"):
1120                generators.append(ast.comprehension(
1121                    target=v[0], iter=v[1].expr, ifs=[],
1122                    is_async=int(tagname == "afor")))
1123            elif tagname == "setv":
1124                generators.append(ast.comprehension(
1125                    target=v[0],
1126                    iter=asty.Tuple(v[1], elts=[v[1].expr], ctx=ast.Load()),
1127                    ifs=[], is_async=0))
1128            elif tagname == "if":
1129                generators[-1].ifs.append(v.expr)
1130            else:
1131                raise ValueError("can't happen")
1132        if node_class is asty.DictComp:
1133            return asty.DictComp(expr, key=key.expr, value=elt.expr, generators=generators)
1134        return node_class(expr, elt=elt.expr, generators=generators)
1135
1136    @special(["not", "~"], [FORM])
1137    def compile_unary_operator(self, expr, root, arg):
1138        ops = {"not": ast.Not,
1139               "~": ast.Invert}
1140        operand = self.compile(arg)
1141        return operand + asty.UnaryOp(
1142            expr, op=ops[root](), operand=operand.force_expr)
1143
    # A symbol containing no dots. Used below for `:as` aliases and for the
    # individual names listed in `import`/`require`, where a dotted path is
    # not accepted.
    _symn = some(lambda x: isinstance(x, HySymbol) and "." not in x)
1145
    @special(["import", "require"], [many(
        SYM |
        brackets(SYM, sym(":as"), _symn) |
        brackets(SYM, brackets(many(_symn + maybe(sym(":as") + _symn)))))])
    def compile_import_or_require(self, expr, root, entries):
        """Compile `(import ...)` and `(require ...)` forms.

        Each entry has one of these shapes:

        - ``foo``: the whole module;
        - ``[foo :as bar]``: the whole module, under the name ``bar``;
        - ``[foo [a b :as c ...]]``: selected names, optionally renamed;
        - ``[foo [*]]``: all names (``*`` must be the only name listed).

        For `import`, one Python `import` / `from ... import` statement is
        emitted per entry. Leading dots in the module name become the
        relative-import level.

        For `require`, `hy.macros.require` is called immediately, so the
        macros are usable for the rest of this compilation. If that call
        returns a true value, an equivalent run-time `hy.macros.require`
        call is compiled in as well. The module name (or its alias) is
        passed as the `prefix` argument.
        """
        ret = Result()

        for entry in entries:
            # Defaults; "ALL" with an empty prefix survives only for `[foo [*]]`.
            assignments = "ALL"
            prefix = ""

            if isinstance(entry, HySymbol):
                # e.g., (import foo)
                module, prefix = entry, entry
            elif isinstance(entry, HyList) and isinstance(entry[1], HySymbol):
                # e.g., (import [foo :as bar])
                module, prefix = entry
            else:
                # e.g., (import [foo [bar baz :as MyBaz bing]])
                # or (import [foo [*]])
                module, kids = entry
                kids = kids[0]
                if (HySymbol('*'), None) in kids:
                    if len(kids) != 1:
                        star = kids[kids.index((HySymbol('*'), None))][0]
                        raise self._syntax_error(star,
                            "* in an import name list must be on its own")
                else:
                    # (name, alias) pairs; an unaliased name maps to itself.
                    assignments = [(k, v or k) for k, v in kids]

            ast_module = ast_str(module, piecewise=True)

            if root == "import":
                # Leading dots mark a relative import; their count is the level.
                module = ast_module.lstrip(".")
                level = len(ast_module) - len(module)
                if assignments == "ALL" and prefix == "":
                    # `[foo [*]]` -> `from foo import *`
                    node = asty.ImportFrom
                    names = [ast.alias(name="*", asname=None)]
                elif assignments == "ALL":
                    # `foo` / `[foo :as bar]` -> `import foo [as bar]`
                    node = asty.Import
                    prefix = ast_str(prefix, piecewise=True)
                    names = [ast.alias(
                        name=ast_module,
                        asname=prefix if prefix != module else None)]
                else:
                    # `[foo [a :as b]]` -> `from foo import a [as b]`
                    node = asty.ImportFrom
                    names = [
                        ast.alias(
                            name=ast_str(k),
                            asname=None if v == k else ast_str(v))
                        for k, v in assignments]
                ret += node(
                    expr, module=module or None, names=names, level=level)

            elif require(ast_module, self.module, assignments=assignments,
                         prefix=prefix):
                # Actually calling `require` is necessary for macro expansions
                # occurring during compilation.
                self.imports['hy.macros'].update([None])
                # The `require` we're creating in AST is the same as above, but used at
                # run-time (e.g. when modules are loaded via bytecode).
                ret += self.compile(HyExpression([
                    HySymbol('hy.macros.require'),
                    HyString(ast_module),
                    HySymbol('None'),
                    HyKeyword('assignments'),
                    (HyString("ALL") if assignments == "ALL" else
                        [[HyString(k), HyString(v)] for k, v in assignments]),
                    HyKeyword('prefix'),
                    HyString(prefix)]).replace(expr))
                ret += ret.expr_as_stmt()

        return ret
1219
    @special(["and", "or"], [many(FORM)])
    def compile_logical_or_and_and_operator(self, expr, operator, args):
        """Compile `(and ...)` / `(or ...)` with short-circuit evaluation.

        - No arguments: the result is `True` (`and`) or `None` (`or`).
        - One argument: the result is that argument.
        - All arguments compile to pure expressions: a single `ast.BoolOp`
          is emitted.
        - Otherwise: the operator is expanded into nested `if` statements.
          Each value is assigned in turn to a temporary variable, and
          evaluation stops as soon as the result is decided.
        """
        ops = {"and": (ast.And, True),
               "or": (ast.Or, None)}
        opnode, default = ops[operator]
        osym = expr[0]
        if len(args) == 0:
            return named_constant(osym, default)
        elif len(args) == 1:
            return self.compile(args[0])
        ret = Result()
        values = list(map(self.compile, args))
        if any(value.stmts for value in values):
            # Compile it to an if...else sequence
            # e.g. for `(and a b c)`:
            #   tmp = a
            #   if tmp:
            #       tmp = b
            #       if tmp:
            #           tmp = c
            # (`or` tests `not tmp` instead of `tmp`.)
            var = self.get_anon_var()
            name = asty.Name(osym, id=var, ctx=ast.Store())
            expr_name = asty.Name(osym, id=var, ctx=ast.Load())
            temp_variables = [name, expr_name]

            def make_assign(value, node=None):
                # Assign `value` to the temp variable, positioned at `node`.
                # Every new target Name is recorded in temp_variables
                # (presumably so a later Result.rename retargets it too).
                positioned_name = asty.Name(
                    node or osym, id=var, ctx=ast.Store())
                temp_variables.append(positioned_name)
                return asty.Assign(
                    node or osym, targets=[positioned_name], value=value)

            # `root` collects the top-level statements; `current` is the
            # statement list being filled (the body of the innermost `if`).
            current = root = []
            for i, value in enumerate(values):
                # `node` supplies the source position for the `if` below.
                if value.stmts:
                    node = value.stmts[0]
                    current.extend(value.stmts)
                else:
                    node = value.expr
                current.append(make_assign(value.force_expr, value.force_expr))
                if i == len(values)-1:
                    # Skip a redundant 'if'.
                    break
                if operator == "and":
                    cond = expr_name
                elif operator == "or":
                    cond = asty.UnaryOp(node, op=ast.Not(), operand=expr_name)
                current.append(asty.If(node, test=cond, body=[], orelse=[]))
                current = current[-1].body
            # Fold the collected statements into the Result.
            ret = sum(root, ret)
            ret += Result(expr=expr_name, temp_variables=temp_variables)
        else:
            ret += asty.BoolOp(osym,
                               op=opnode(),
                               values=[value.force_expr for value in values])
        return ret
1270
1271    _c_ops = {"=": ast.Eq, "!=": ast.NotEq,
1272             "<": ast.Lt, "<=": ast.LtE,
1273             ">": ast.Gt, ">=": ast.GtE,
1274             "is": ast.Is, "is-not": ast.IsNot,
1275             "in": ast.In, "not-in": ast.NotIn}
1276    _c_ops = {ast_str(k): v for k, v in _c_ops.items()}
1277    def _get_c_op(self, sym):
1278        k = ast_str(sym)
1279        if k not in self._c_ops:
1280            raise self._syntax_error(sym,
1281                "Illegal comparison operator: " + str(sym))
1282        return self._c_ops[k]()
1283
1284    @special(["=", "is", "<", "<=", ">", ">="], [oneplus(FORM)])
1285    @special(["!=", "is-not", "in", "not-in"], [times(2, Inf, FORM)])
1286    def compile_compare_op_expression(self, expr, root, args):
1287        if len(args) == 1:
1288            return (self.compile(args[0]) +
1289                    named_constant(expr, True))
1290
1291        ops = [self._get_c_op(root) for _ in args[1:]]
1292        exprs, ret, _ = self._compile_collect(args)
1293        return ret + asty.Compare(
1294            expr, left=exprs[0], ops=ops, comparators=exprs[1:])
1295
1296    @special("cmp", [FORM, many(SYM + FORM)])
1297    def compile_chained_comparison(self, expr, root, arg1, args):
1298        ret = self.compile(arg1)
1299        arg1 = ret.force_expr
1300
1301        ops = [self._get_c_op(op) for op, _ in args]
1302        args, ret2, _ = self._compile_collect(
1303            [x for _, x in args])
1304
1305        return ret + ret2 + asty.Compare(expr,
1306            left=arg1, ops=ops, comparators=args)
1307
    # Arithmetic/bitwise operator symbols mapped to
    # (AST operator class, aggregation operator).
    # The aggregation operator is used for augmented assignment with three or
    # more arguments: `(-= x a b)` is compiled as `(-= x (+ a b))`.
    # `None` means the augmented form accepts exactly one value.
    m_ops = {"+": (ast.Add, "+"),
             "/": (ast.Div, "*"),
             "//": (ast.FloorDiv, "*"),
             "*": (ast.Mult, "*"),
             "-": (ast.Sub, "+"),
             "%": (ast.Mod, None),
             "**": (ast.Pow, "**"),
             "<<": (ast.LShift, "+"),
             ">>": (ast.RShift, "+"),
             "|": (ast.BitOr, "|"),
             "^": (ast.BitXor, None),
             "&": (ast.BitAnd, "&"),
             "@": (ast.MatMult, "@")}
1323
1324    @special(["+", "*", "|"], [many(FORM)])
1325    @special(["-", "/", "&", "@"], [oneplus(FORM)])
1326    @special(["**", "//", "<<", ">>"], [times(2, Inf, FORM)])
1327    @special(["%", "^"], [times(2, 2, FORM)])
1328    def compile_maths_expression(self, expr, root, args):
1329        if len(args) == 0:
1330            # Return the identity element for this operator.
1331            return asty.Num(expr, n=(
1332                {"+": 0, "|": 0, "*": 1}[root]))
1333
1334        if len(args) == 1:
1335            if root == "/":
1336                # Compute the reciprocal of the argument.
1337                args = [HyInteger(1).replace(expr), args[0]]
1338            elif root in ("+", "-"):
1339                # Apply unary plus or unary minus to the argument.
1340                op = {"+": ast.UAdd, "-": ast.USub}[root]()
1341                ret = self.compile(args[0])
1342                return ret + asty.UnaryOp(expr, op=op, operand=ret.force_expr)
1343            else:
1344                # Return the argument unchanged.
1345                return self.compile(args[0])
1346
1347        op = self.m_ops[root][0]
1348        right_associative = root == "**"
1349        ret = self.compile(args[-1 if right_associative else 0])
1350        for child in args[-2 if right_associative else 1 ::
1351                          -1 if right_associative else 1]:
1352            left_expr = ret.force_expr
1353            ret += self.compile(child)
1354            right_expr = ret.force_expr
1355            if right_associative:
1356                left_expr, right_expr = right_expr, left_expr
1357            ret += asty.BinOp(expr, left=left_expr, op=op(), right=right_expr)
1358
1359        return ret
1360
    # Augmented-assignment symbols (`+=`, `-=`, ...) mapped to the same
    # (AST operator class, aggregation operator) pairs as in `m_ops`.
    a_ops = {x + "=": v for x, v in m_ops.items()}
1362
1363    @special([x for x, (_, v) in a_ops.items() if v is not None], [FORM, oneplus(FORM)])
1364    @special([x for x, (_, v) in a_ops.items() if v is None], [FORM, times(1, 1, FORM)])
1365    def compile_augassign_expression(self, expr, root, target, values):
1366        if len(values) > 1:
1367            return self.compile(mkexpr(root, [target],
1368                mkexpr(self.a_ops[root][1], rest=values)).replace(expr))
1369
1370        op = self.a_ops[root][0]
1371        target = self._storeize(target, self.compile(target))
1372        ret = self.compile(values[0])
1373        return ret + asty.AugAssign(
1374            expr, target=target, value=ret.force_expr, op=op())
1375
1376    @special("setv", [many(OPTIONAL_ANNOTATION + FORM + FORM)])
1377    @special((PY38, "setx"), [times(1, 1, SYM + FORM)])
1378    def compile_def_expression(self, expr, root, decls):
1379        if not decls:
1380            return named_constant(expr, None)
1381
1382        result = Result()
1383        is_assignment_expr = root == HySymbol("setx")
1384        for decl in decls:
1385            if is_assignment_expr:
1386                ann = None
1387                name, value = decl
1388            else:
1389                ann, name, value = decl
1390
1391            result += self._compile_assign(ann, name, value,
1392                                           is_assignment_expr=is_assignment_expr)
1393        return result
1394
1395    @special(["annotate*"], [FORM, FORM])
1396    def compile_basic_annotation(self, expr, root, ann, target):
1397        return self._compile_assign(ann, target, None)
1398
1399    def _compile_assign(self, ann, name, value, *, is_assignment_expr = False):
1400        # Ensure that assignment expressions have a result and no annotation.
1401        assert not is_assignment_expr or (value is not None and ann is None)
1402
1403        ld_name = self.compile(name)
1404
1405        annotate_only = value is None
1406        if annotate_only:
1407            result = Result()
1408        else:
1409            result = self.compile(value)
1410
1411        invalid_name = False
1412        if ann is not None:
1413            # An annotation / annotated assignment is more strict with the target expression.
1414            invalid_name = not isinstance(ld_name.expr, (ast.Name, ast.Attribute, ast.Subscript))
1415        else:
1416            invalid_name = (str(name) in ("None", "True", "False")
1417                            or isinstance(ld_name.expr, ast.Call))
1418
1419        if invalid_name:
1420            raise self._syntax_error(name, "illegal target for {}".format(
1421                                        "annotation" if annotate_only else "assignment"))
1422
1423        if (result.temp_variables
1424                and isinstance(name, HySymbol)
1425                and '.' not in name):
1426            result.rename(name)
1427            if not is_assignment_expr:
1428                # Throw away .expr to ensure that (setv ...) returns None.
1429                result.expr = None
1430        else:
1431            st_name = self._storeize(name, ld_name)
1432
1433            if ann is not None:
1434                ann_result = self.compile(ann)
1435                result = ann_result + result
1436
1437            if is_assignment_expr:
1438                node = asty.NamedExpr
1439            elif ann is not None:
1440                if not PY36:
1441                    raise self._syntax_error(name, "Variable annotations are not supported on "
1442                                                   "Python <=3.6")
1443
1444                node = lambda x, **kw: asty.AnnAssign(x, annotation=ann_result.force_expr,
1445                                                      simple=int(isinstance(name, HySymbol)),
1446                                                      **kw)
1447            else:
1448                node = asty.Assign
1449
1450            result += node(
1451                name if hasattr(name, "start_line") else result,
1452                value=result.force_expr if not annotate_only else None,
1453                target=st_name, targets=[st_name])
1454
1455        return result
1456
    @special(["while"], [FORM, many(notpexpr("else")), maybe(dolike("else"))])
    def compile_while_expression(self, expr, root, cond, body, else_expr):
        """Compile `(while cond body ... [(else ...)])` to a Python `while` loop.

        The body's final value is turned into a statement (its value is not
        kept), and an empty body becomes `pass`. If the condition needs
        statements, the loop is rewritten around a temporary condition
        variable (see below) so those statements run on every iteration.
        An `else` clause becomes the loop's `orelse` branch.
        """
        cond_compiled = self.compile(cond)

        body = self._compile_branch(body)
        body += body.expr_as_stmt()
        body_stmts = body.stmts or [asty.Pass(expr)]

        if cond_compiled.stmts:
            # We need to ensure the statements for the condition are
            # executed on every iteration. Rewrite the loop to use a
            # single anonymous variable as the condition, i.e.:
            #  anon_var = True
            #  while anon_var:
            #    condition stmts...
            #    anon_var = condition expr
            #    if anon_var:
            #      while loop body
            cond_var = asty.Name(cond, id=self.get_anon_var(), ctx=ast.Load())
            def make_not(operand):
                return asty.UnaryOp(cond, op=ast.Not(), operand=operand)

            body_stmts = cond_compiled.stmts + [
                asty.Assign(cond, targets=[self._storeize(cond, cond_var)],
                            # Cast the condition to a bool in case it's mutable and
                            # changes its truth value, but use (not (not ...)) instead of
                            # `bool` in case `bool` has been redefined.
                            value=make_not(make_not(cond_compiled.force_expr))),
                asty.If(cond, test=cond_var, body=body_stmts, orelse=[]),
            ]

            # The `anon_var = True` prelude; the loop then tests `anon_var`.
            cond_compiled = (Result()
                + asty.Assign(cond, targets=[self._storeize(cond, cond_var)],
                              value=named_constant(cond, True))
                + cond_var)

        orel = Result()
        if else_expr is not None:
            orel = self._compile_branch(else_expr)
            orel += orel.expr_as_stmt()

        ret = cond_compiled + asty.While(
            expr, test=cond_compiled.force_expr,
            body=body_stmts,
            orelse=orel.stmts)

        return ret
1504
    # A symbol that is not one of the lambda-list markers (`&optional`,
    # `&rest`, `&kwonly`, `&kwargs`), i.e. a plain parameter name.
    NASYM = some(lambda x: isinstance(x, HySymbol) and x not in (
        "&optional", "&rest", "&kwonly", "&kwargs"))
1507    @special(["fn", "fn*", "fn/a"], [
1508        # The starred version is for internal use (particularly, in the
1509        # definition of `defn`). It ensures that a FunctionDef is
1510        # produced rather than a Lambda.
1511        OPTIONAL_ANNOTATION,
1512        brackets(
1513            many(OPTIONAL_ANNOTATION + NASYM),
1514            maybe(sym("&optional") + many(OPTIONAL_ANNOTATION
1515                                            + (NASYM | brackets(SYM, FORM)))),
1516            maybe(sym("&rest") + OPTIONAL_ANNOTATION + NASYM),
1517            maybe(sym("&kwonly") + many(OPTIONAL_ANNOTATION
1518                                        + (NASYM | brackets(SYM, FORM)))),
1519            maybe(sym("&kwargs") + OPTIONAL_ANNOTATION + NASYM)),
1520        many(FORM)])
1521    def compile_function_def(self, expr, root, returns, params, body):
1522        force_functiondef = root in ("fn*", "fn/a")
1523        node = asty.AsyncFunctionDef if root == "fn/a" else asty.FunctionDef
1524        ret = Result()
1525
1526        # NOTE: Our evaluation order of return type annotations is
1527        # different from Python: Python evalautes them after the argument
1528        # annotations / defaults (as that's where they are in the source),
1529        # but Hy evaluates them *first*, since here they come before the #
1530        # argument list. Therefore, it would be more confusing for
1531        # readability to evaluate them after like Python.
1532
1533        ret = Result()
1534        returns_ann = None
1535        if returns is not None:
1536            returns_result = self.compile(returns)
1537            ret += returns_result
1538
1539        mandatory, optional, rest, kwonly, kwargs = params
1540
1541        optional = optional or []
1542        kwonly = kwonly or []
1543
1544        mandatory_ast, _, ret = self._compile_arguments_set(mandatory, False, ret)
1545        optional_ast, optional_defaults, ret = self._compile_arguments_set(optional, True, ret)
1546        kwonly_ast, kwonly_defaults, ret = self._compile_arguments_set(kwonly, False, ret)
1547
1548        rest_ast = kwargs_ast = None
1549
1550        if rest is not None:
1551            [rest_ast], _, ret = self._compile_arguments_set([rest], False, ret)
1552        if kwargs is not None:
1553            [kwargs_ast], _, ret = self._compile_arguments_set([kwargs], False, ret)
1554
1555        args = ast.arguments(
1556            args=mandatory_ast + optional_ast, defaults=optional_defaults,
1557            vararg=rest_ast,
1558            posonlyargs=[],
1559            kwonlyargs=kwonly_ast, kw_defaults=kwonly_defaults,
1560            kwarg=kwargs_ast)
1561
1562        body = self._compile_branch(body)
1563
1564        if not force_functiondef and not body.stmts and returns is None:
1565            return ret + asty.Lambda(expr, args=args, body=body.force_expr)
1566
1567        if body.expr:
1568            body += asty.Return(body.expr, value=body.expr)
1569
1570        name = self.get_anon_var()
1571
1572        ret += node(expr,
1573                    name=name,
1574                    args=args,
1575                    body=body.stmts or [asty.Pass(expr)],
1576                    decorator_list=[],
1577                    returns=returns_result.force_expr if returns is not None else None)
1578
1579        ast_name = asty.Name(expr, id=name, ctx=ast.Load())
1580        ret += Result(expr=ast_name, temp_variables=[ast_name, ret.stmts[-1]])
1581        return ret
1582
1583    def _compile_arguments_set(self, decls, implicit_default_none, ret):
1584        args_ast = []
1585        args_defaults = []
1586
1587        for ann, decl in decls:
1588            default = None
1589
1590            # funcparserlib will check to make sure that the only times we
1591            # ever have a HyList here are due to a default value.
1592            if isinstance(decl, HyList):
1593                sym, default = decl
1594            else:
1595                sym = decl
1596                if implicit_default_none:
1597                    default = HySymbol('None').replace(sym)
1598
1599            if ann is not None:
1600                ret += self.compile(ann)
1601                ann_ast = ret.force_expr
1602            else:
1603                ann_ast = None
1604
1605            if default is not None:
1606                ret += self.compile(default)
1607                args_defaults.append(ret.force_expr)
1608            else:
1609                # Note that the only time any None should ever appear here
1610                # is in kwargs, since the order of those with defaults vs
1611                # those without isn't significant in the same way as
1612                # positional args.
1613                args_defaults.append(None)
1614
1615            args_ast.append(asty.arg(sym, arg=ast_str(sym), annotation=ann_ast))
1616
1617        return args_ast, args_defaults, ret
1618
1619    @special("return", [maybe(FORM)])
1620    def compile_return(self, expr, root, arg):
1621        ret = Result()
1622        if arg is None:
1623            return asty.Return(expr, value=None)
1624        ret += self.compile(arg)
1625        return ret + asty.Return(expr, value=ret.force_expr)
1626
1627    @special("defclass", [
1628        SYM,
1629        maybe(brackets(many(FORM)) + maybe(STR) + many(FORM))])
1630    def compile_class_expression(self, expr, root, name, rest):
1631        base_list, docstring, body = rest or ([[]], None, [])
1632
1633        bases_expr, bases, keywords = (
1634            self._compile_collect(base_list[0], with_kwargs=True))
1635
1636        bodyr = Result()
1637
1638        if docstring is not None:
1639            bodyr += self.compile(docstring).expr_as_stmt()
1640
1641        for e in body:
1642            e = self.compile(self._rewire_init(
1643                macroexpand(e, self.module, self)))
1644            bodyr += e + e.expr_as_stmt()
1645
1646        return bases + asty.ClassDef(
1647            expr,
1648            decorator_list=[],
1649            name=ast_str(name),
1650            keywords=keywords,
1651            starargs=None,
1652            kwargs=None,
1653            bases=bases_expr,
1654            body=bodyr.stmts or [asty.Pass(expr)])
1655
1656    def _rewire_init(self, expr):
1657        "Given a (setv …) form, append None to definitions of __init__."
1658
1659        if not (isinstance(expr, HyExpression)
1660                and len(expr) > 1
1661                and isinstance(expr[0], HySymbol)
1662                and expr[0] == HySymbol("setv")):
1663            return expr
1664
1665        new_args = []
1666        decls = list(expr[1:])
1667        while decls:
1668            if is_annotate_expression(decls[0]):
1669                # Handle annotations.
1670                ann = decls.pop(0)
1671            else:
1672                ann = None
1673
1674            k, v = (decls.pop(0), decls.pop(0))
1675            if ast_str(k) == "__init__" and isinstance(v, HyExpression):
1676                v += HyExpression([HySymbol("None")])
1677
1678            if ann is not None:
1679                new_args.append(ann)
1680
1681            new_args.extend((k, v))
1682        return HyExpression([HySymbol("setv")] + new_args).replace(expr)
1683
1684    @special("dispatch-tag-macro", [STR, FORM])
1685    def compile_dispatch_tag_macro(self, expr, root, tag, arg):
1686        return self.compile(tag_macroexpand(
1687            HyString(mangle(tag)).replace(tag),
1688            arg,
1689            self.module))
1690
1691    @special(["eval-and-compile", "eval-when-compile"], [many(FORM)])
1692    def compile_eval_and_compile(self, expr, root, body):
1693        new_expr = HyExpression([HySymbol("do").replace(expr[0])]).replace(expr)
1694
1695        try:
1696            hy_eval(new_expr + body,
1697                    self.module.__dict__,
1698                    self.module,
1699                    filename=self.filename,
1700                    source=self.source)
1701        except HyInternalError:
1702            # Unexpected "meta" compilation errors need to be treated
1703            # like normal (unexpected) compilation errors at this level
1704            # (or the compilation level preceding this one).
1705            raise
1706        except Exception as e:
1707            # These could be expected Hy language errors (e.g. syntax errors)
1708            # or regular Python runtime errors that do not signify errors in
1709            # the compilation *process* (although compilation did technically
1710            # fail).
1711            # We wrap these exceptions and pass them through.
1712            reraise(HyEvalError,
1713                    HyEvalError(str(e),
1714                                self.filename,
1715                                body,
1716                                self.source),
1717                    sys.exc_info()[2])
1718
1719        return (self._compile_branch(body)
1720                if ast_str(root) == "eval_and_compile"
1721                else Result())
1722
1723    @special(["py", "pys"], [STR])
1724    def compile_inline_python(self, expr, root, code):
1725        exec_mode = root == HySymbol("pys")
1726
1727        try:
1728            o = ast.parse(
1729                textwrap.dedent(code) if exec_mode else code,
1730                self.filename,
1731                'exec' if exec_mode else 'eval').body
1732        except (SyntaxError, ValueError if PY36 else TypeError) as e:
1733            raise self._syntax_error(
1734                expr,
1735                "Python parse error in '{}': {}".format(root, e))
1736
1737        return Result(stmts=o) if exec_mode else o
1738
    @builds_model(HyExpression)
    def compile_expression(self, expr, *, allow_annotation_expression=False):
        """Compile a parenthesized Hy expression.

        After macro expansion, the expression is compiled as one of:
        a special form (its head symbol names an entry of
        `_special_form_compilers`), a method call (head symbol starting
        with `.`), a flattened annotation expression, or an ordinary call.

        `allow_annotation_expression` permits an `annotate*` head, which is
        otherwise rejected with "The special form ... is not allowed here".
        """
        # Perform macro expansions
        expr = macroexpand(expr, self.module, self)
        if not isinstance(expr, HyExpression):
            # Go through compile again if the type changed.
            return self.compile(expr)

        if not expr:
            raise self._syntax_error(expr,
                "empty expressions are not allowed at top level")

        # Split the head of the form from its arguments.
        args = list(expr)
        root = args.pop(0)
        # Set by the method-call branch below; otherwise `root` is compiled.
        func = None

        if isinstance(root, HySymbol):
            # First check if `root` is a special operator, unless it has an
            # `unpack-iterable` in it, since Python's operators (`+`,
            # etc.) can't unpack. An exception to this exception is that
            # tuple literals (`,`) can unpack. Finally, we allow unpacking in
            # `.` forms here so the user gets a better error message.
            sroot = ast_str(root)

            bad_root = sroot in _bad_roots or (sroot == ast_str("annotate*")
                                               and not allow_annotation_expression)

            if (sroot in _special_form_compilers or bad_root) and (
                    sroot in (mangle(","), mangle(".")) or
                    not any(is_unpack("iterable", x) for x in args)):
                if bad_root:
                    raise self._syntax_error(expr,
                        "The special form '{}' is not allowed here".format(root))
                # `sroot` is a special operator. Get the build method and
                # pattern-match the arguments.
                build_method, pattern = _special_form_compilers[sroot]
                try:
                    parse_tree = pattern.parse(args)
                except NoParseError as e:
                    # Point the error at the offending argument (offset by
                    # one for the head), clamped to the last element.
                    raise self._syntax_error(
                        expr[min(e.state.pos + 1, len(expr) - 1)],
                        "parse error for special form '{}': {}".format(
                            root, e.msg.replace("<EOF>", "end of form")))
                return Result() + build_method(
                    self, expr, unmangle(sroot), *parse_tree)

            if root.startswith("."):
                # (.split "test test") -> "test test".split()
                # (.a.b.c x v1 v2) -> (.c (. x a b) v1 v2) ->  x.a.b.c(v1, v2)

                # Get the method name (the last named attribute
                # in the chain of attributes)
                attrs = [HySymbol(a).replace(root) for a in root.split(".")[1:]]
                root = attrs.pop()

                # Get the object we're calling the method on
                # (extracted with the attribute access DSL)
                # Skip past keywords and their arguments.
                try:
                    kws, obj, rest = (
                        many(KEYWORD + FORM | unpack("mapping")) +
                        FORM +
                        many(FORM)).parse(args)
                except NoParseError:
                    raise self._syntax_error(expr,
                        "attribute access requires object")
                # Reconstruct `args` to exclude `obj`.
                args = [x for p in kws for x in p] + list(rest)
                if is_unpack("iterable", obj):
                    raise self._syntax_error(obj,
                        "can't call a method on an unpacking form")
                # Compile `(. obj a b ...)` for the intermediate attributes.
                func = self.compile(HyExpression(
                    [HySymbol(".").replace(root), obj] +
                    attrs))

                # And get the method
                func += asty.Attribute(root,
                                       value=func.force_expr,
                                       attr=ast_str(root),
                                       ctx=ast.Load())

        elif is_annotate_expression(root):
            # Flatten and compile the annotation expression.
            ann_expr = HyExpression(root + args).replace(root)
            return self.compile_expression(ann_expr, allow_annotation_expression=True)

        # Ordinary call: compile the head unless the method branch did.
        if not func:
            func = self.compile(root)

        args, ret, keywords = self._compile_collect(args, with_kwargs=True)

        return func + ret + asty.Call(
            expr, func=func.expr, args=args, keywords=keywords)
1832
1833    @builds_model(HyInteger, HyFloat, HyComplex)
1834    def compile_numeric_literal(self, x):
1835        f = {HyInteger: int,
1836             HyFloat: float,
1837             HyComplex: complex}[type(x)]
1838        return asty.Num(x, n=f(x))
1839
1840    @builds_model(HySymbol)
1841    def compile_symbol(self, symbol):
1842        if "." in symbol:
1843            glob, local = symbol.rsplit(".", 1)
1844
1845            if not glob:
1846                raise self._syntax_error(symbol,
1847                    'cannot access attribute on anything other than a name (in order to get attributes of expressions, use `(. <expression> {attr})` or `(.{attr} <expression>)`)'.format(attr=local))
1848
1849            if not local:
1850                raise self._syntax_error(symbol,
1851                    'cannot access empty attribute')
1852
1853            glob = HySymbol(glob).replace(symbol)
1854            ret = self.compile_symbol(glob)
1855
1856            return asty.Attribute(
1857                symbol,
1858                value=ret,
1859                attr=ast_str(local),
1860                ctx=ast.Load())
1861
1862        if self.can_use_stdlib and ast_str(symbol) in self._stdlib:
1863            self.imports[self._stdlib[ast_str(symbol)]].add(ast_str(symbol))
1864
1865        if ast_str(symbol) in ("None", "False", "True"):
1866            return named_constant(symbol, ast.literal_eval(ast_str(symbol)))
1867
1868        return asty.Name(symbol, id=ast_str(symbol), ctx=ast.Load())
1869
1870    @builds_model(HyKeyword)
1871    def compile_keyword(self, obj):
1872        ret = Result()
1873        ret += asty.Call(
1874            obj,
1875            func=asty.Name(obj, id="HyKeyword", ctx=ast.Load()),
1876            args=[asty.Str(obj, s=obj.name)],
1877            keywords=[])
1878        ret.add_imports("hy", {"HyKeyword"})
1879        return ret
1880
1881    @builds_model(HyString, HyBytes)
1882    def compile_string(self, string):
1883        if type(string) is HyString and string.is_format:
1884            # This is a format string (a.k.a. an f-string).
1885            return self._format_string(string, str(string))
1886        node = asty.Bytes if type(string) is HyBytes else asty.Str
1887        f = bytes if type(string) is HyBytes else str
1888        return node(string, s=f(string))
1889
1890    def _format_string(self, string, rest, allow_recursion=True):
1891        values = []
1892        ret = Result()
1893
1894        while True:
1895           # Look for the next replacement field, and get the
1896           # plain text before it.
1897           match = re.search(r'\{\{?|\}\}?', rest)
1898           if match:
1899              literal_chars = rest[: match.start()]
1900              if match.group() == '}':
1901                  raise self._syntax_error(string,
1902                      "f-string: single '}' is not allowed")
1903              if match.group() in ('{{', '}}'):
1904                  # Doubled braces just add a single brace to the text.
1905                  literal_chars += match.group()[0]
1906              rest = rest[match.end() :]
1907           else:
1908              literal_chars = rest
1909              rest = ""
1910           if literal_chars:
1911               values.append(asty.Str(string, s = literal_chars))
1912           if not rest:
1913               break
1914           if match.group() != '{':
1915               continue
1916
1917           # Look for the end of the replacement field, allowing
1918           # one more level of matched braces, but no deeper, and only
1919           # if we can recurse.
1920           match = re.match(
1921               r'(?: \{ [^{}]* \} | [^{}]+ )* \}'
1922                   if allow_recursion
1923                   else r'[^{}]* \}',
1924               rest, re.VERBOSE)
1925           if not match:
1926              raise self._syntax_error(string, 'f-string: mismatched braces')
1927           item = rest[: match.end() - 1]
1928           rest = rest[match.end() :]
1929
1930           # Parse the first form.
1931           try:
1932               model, item = parse_one_thing(item)
1933           except (ValueError, LexException) as e:
1934               raise self._syntax_error(string, "f-string: " + str(e))
1935
1936           # Look for a conversion character.
1937           item = item.lstrip()
1938           conversion = None
1939           if item.startswith('!'):
1940               conversion = item[1]
1941               item = item[2:].lstrip()
1942
1943           # Look for a format specifier.
1944           format_spec = None
1945           if item.startswith(':'):
1946               if allow_recursion:
1947                   ret += self._format_string(string,
1948                       item[1:],
1949                       allow_recursion=False)
1950                   format_spec = ret.force_expr
1951               else:
1952                   format_spec = asty.Str(string, s=item[1:])
1953                   if PY36:
1954                       format_spec = asty.JoinedStr(string, values=[format_spec])
1955           elif item:
1956               raise self._syntax_error(string,
1957                   "f-string: trailing junk in field")
1958
1959           # Now, having finished compiling any recursively included
1960           # forms, we can compile the first form that we parsed.
1961           ret += self.compile(model)
1962
1963           if PY36:
1964               values.append(asty.FormattedValue(
1965                   string,
1966                   conversion = -1 if conversion is None else ord(conversion),
1967                   format_spec = format_spec,
1968                   value = ret.force_expr))
1969           else:
1970               # Make an expression like:
1971               #    "{!r:{}}".format(value, format_spec)
1972               values.append(asty.Call(string,
1973                   func = asty.Attribute(
1974                       string,
1975                       value = asty.Str(string, s =
1976                           '{' +
1977                           ('!' + conversion if conversion else '') +
1978                           ':{}}'),
1979                       attr = 'format', ctx = ast.Load()),
1980                   args = [
1981                       ret.force_expr,
1982                       format_spec or asty.Str(string, s = "")],
1983                   keywords = [], starargs = None, kwargs = None))
1984
1985        return ret + (
1986           asty.JoinedStr(string, values = values)
1987           if PY36
1988           else reduce(
1989              lambda x, y:
1990                  asty.BinOp(string, left = x, op = ast.Add(), right = y),
1991              values))
1992
1993    @builds_model(HyList, HySet)
1994    def compile_list(self, expression):
1995        elts, ret, _ = self._compile_collect(expression)
1996        node = {HyList: asty.List, HySet: asty.Set}[type(expression)]
1997        return ret + node(expression, elts=elts, ctx=ast.Load())
1998
1999    @builds_model(HyDict)
2000    def compile_dict(self, m):
2001        keyvalues, ret, _ = self._compile_collect(m, dict_display=True)
2002        return ret + asty.Dict(m, keys=keyvalues[::2], values=keyvalues[1::2])
2003
2004
def get_compiler_module(module=None, compiler=None, calling_frame=False):
    """Resolve the module object that Hy code is compiled/evaluated in.

    Priority: `compiler.module` (when present and truthy), then `module`.
    A string is turned into a module: a name wrapped in angle brackets
    (e.g. ``"<string>"``) creates a fresh, empty module of that name;
    any other string is imported. If nothing was found and
    `calling_frame` is true, the caller's module is used.

    Raises `TypeError` if the result is not a module.
    """
    compiler_module = getattr(compiler, 'module', None)
    if compiler_module:
        module = compiler_module

    if isinstance(module, str):
        is_placeholder = module.startswith('<') and module.endswith('>')
        module = (types.ModuleType(module) if is_placeholder
                  else importlib.import_module(ast_str(module, piecewise=True)))

    if calling_frame and not module:
        module = calling_module(n=2)

    if inspect.ismodule(module):
        return module
    raise TypeError('Invalid module type: {}'.format(type(module)))
2025
2026
def hy_eval(hytree, locals=None, module=None, ast_callback=None,
            compiler=None, filename=None, source=None):
    """Evaluates a quoted expression and returns the value.

    If you're evaluating hand-crafted AST trees, make sure the line numbers
    are set properly.  Try `fix_missing_locations` and related functions in the
    Python `ast` library.

    Examples
    --------
       => (eval '(print "Hello World"))
       "Hello World"

    If you want to evaluate a string, use ``read-str`` to convert it to a
    form first:
       => (eval (read-str "(+ 1 1)"))
       2

    Parameters
    ----------
    hytree: HyObject
        The Hy AST object to evaluate.

    locals: dict, optional
        Local environment in which to evaluate the Hy tree.  Defaults to the
        local namespace of the calling frame.

    module: str or types.ModuleType, optional
        Module, or name of the module, to which the Hy tree is assigned and
        the global values are taken.
        The module associated with `compiler` takes priority over this value.
        When neither `module` nor `compiler` is specified, the calling frame's
        module is used.

    ast_callback: callable, optional
        A callback that is passed the Hy compiled tree and resulting
        expression object, in that order, after compilation but before
        evaluation.

    compiler: HyASTCompiler, optional
        An existing Hy compiler to use for compilation.  Also serves as
        the `module` value when given.

    filename: str, optional
        The filename corresponding to the source for `tree`.  This will be
        overridden by the `filename` field of `tree`, if any; otherwise, it
        defaults to "<string>".  It is used for compiling the resulting
        Python code; when `compiler` is given, the compiler's own filename
        is used for Hy compilation.

    source: str, optional
        A string containing the source code for `tree`.  This will be
        overridden by the `source` field of `tree`, if any, and is then
        passed on to `hy_compile`.

    Returns
    -------
    out : Result of evaluating the Hy compiled tree.

    Raises
    ------
    TypeError
        If `locals` is given and is not a dict, or no module can be found.
    """

    module = get_compiler_module(module, compiler, True)

    if locals is None:
        # The caller's frame. Unlike `inspect.stack()`, this doesn't build a
        # FrameInfo (with source-context lookup) for every frame on the stack.
        frame = inspect.currentframe().f_back
        locals = frame.f_locals
        if not isinstance(locals, dict):
            # Since Python 3.13 (PEP 667), a function frame's `f_locals` is
            # a write-through proxy rather than a dict. Take a snapshot,
            # which matches what earlier Python versions returned. At module
            # level `f_locals` is still the globals dict itself.
            locals = dict(locals)

    if not isinstance(locals, dict):
        raise TypeError("Locals must be a dictionary")

    # Does the Hy AST object come with its own information?
    filename = getattr(hytree, 'filename', filename) or '<string>'
    source = getattr(hytree, 'source', source)

    _ast, expr = hy_compile(hytree, module, get_expr=True,
                            compiler=compiler, filename=filename,
                            source=source)

    if ast_callback:
        ast_callback(_ast, expr)

    # Two-step eval: eval() the body of the exec call
    eval(ast_compile(_ast, filename, "exec"),
         module.__dict__, locals)

    # Then eval the expression context and return that
    return eval(ast_compile(expr, filename, "eval"),
                module.__dict__, locals)
2115
2116
2117def _module_file_source(module_name, filename, source):
2118    """Try to obtain missing filename and source information from a module name
2119    without actually loading the module.
2120    """
2121    if filename is None or source is None:
2122        mod_loader = pkgutil.get_loader(module_name)
2123        if mod_loader:
2124            if filename is None:
2125                filename = mod_loader.get_filename(module_name)
2126            if source is None:
2127                source = mod_loader.get_source(module_name)
2128
2129    # We need a non-None filename.
2130    filename = filename or '<string>'
2131
2132    return filename, source
2133
2134
def hy_compile(tree, module, root=ast.Module, get_expr=False,
               compiler=None, filename=None, source=None):
    """Compile a HyObject tree into a Python AST Module.

    Parameters
    ----------
    tree: HyObject
        The Hy AST object to compile.  Other values are accepted if
        `wrap_value` can promote them to a HyObject.

    module: str or types.ModuleType
        Module, or name of the module, in which the Hy tree is evaluated.
        The module associated with `compiler` takes priority over this value,
        so this may be `None` when `compiler` is given.

    root: ast object, optional (ast.Module)
        Root object for the Python AST tree.

    get_expr: bool, optional (False)
        If true, return a tuple with `(root_obj, last_expression)`, where
        `last_expression` is wrapped in an `ast.Expression`.  If false, the
        last expression is appended to the body as a statement instead.

    compiler: HyASTCompiler, optional
        An existing Hy compiler to use for compilation.  Also serves as
        the `module` value when given.

    filename: str, optional
        The filename corresponding to the source for `tree`.  This will be
        overridden by the `filename` field of `tree`, if any.  It is only
        used when a new compiler is created; when `compiler` is given, the
        compiler keeps its own filename.

    source: str, optional
        A string containing the source code for `tree`.  This will be
        overridden by the `source` field of `tree`, if any.  Like
        `filename`, it is only used when a new compiler is created.

    Returns
    -------
    out : A Python AST tree (or a tuple; see `get_expr`)

    Raises
    ------
    TypeError
        If no module can be determined, or `tree` can't be promoted to a
        HyObject.
    """
    # `get_compiler_module` already imports/creates a module from a string
    # and raises TypeError for anything that isn't a module.
    module = get_compiler_module(module, compiler, False)

    filename = getattr(tree, 'filename', filename)
    source = getattr(tree, 'source', source)

    tree = wrap_value(tree)
    if not isinstance(tree, HyObject):
        raise TypeError("`tree` must be a HyObject or capable of "
                        "being promoted to one")

    compiler = compiler or HyASTCompiler(module, filename=filename, source=source)
    result = compiler.compile(tree)
    expr = result.force_expr

    if not get_expr:
        result += result.expr_as_stmt()

    body = []

    # Pull out a single docstring and prepend to the resulting body.
    if (len(result.stmts) > 0 and
        issubclass(root, ast.Module) and
        isinstance(result.stmts[0], ast.Expr) and
        isinstance(result.stmts[0].value, ast.Str)):

        body += [result.stmts.pop(0)]

    # Stable sort: `from __future__ import ...` statements move to the
    # front (Python requires them first); everything else keeps its order.
    body += sorted(compiler.imports_as_stmts(tree) + result.stmts,
                   key=lambda a: not (isinstance(a, ast.ImportFrom) and
                                      a.module == '__future__'))

    ret = root(body=body, type_ignores=[])

    if get_expr:
        expr = ast.Expression(body=expr)
        ret = (ret, expr)

    return ret
2222