1#     Copyright 2021, Kay Hayen, mailto:kay.hayen@gmail.com
2#
3#     Part of "Nuitka", an optimizing Python compiler that is compatible and
4#     integrates with CPython, but also works on its own.
5#
6#     Licensed under the Apache License, Version 2.0 (the "License");
7#     you may not use this file except in compliance with the License.
8#     You may obtain a copy of the License at
9#
10#        http://www.apache.org/licenses/LICENSE-2.0
11#
12#     Unless required by applicable law or agreed to in writing, software
13#     distributed under the License is distributed on an "AS IS" BASIS,
14#     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15#     See the License for the specific language governing permissions and
16#     limitations under the License.
17#
18""" Helper functions for parsing the AST nodes and building the Nuitka node tree.
19
20"""
21
22import __future__
23
24import ast
25
26from nuitka import Constants, Options
27from nuitka.Errors import CodeTooComplexCode
28from nuitka.nodes.CallNodes import makeExpressionCall
29from nuitka.nodes.CodeObjectSpecs import CodeObjectSpec
30from nuitka.nodes.ConstantRefNodes import makeConstantRefNode
31from nuitka.nodes.ContainerMakingNodes import makeExpressionMakeTupleOrConstant
32from nuitka.nodes.DictionaryNodes import (
33    ExpressionKeyValuePair,
34    makeExpressionMakeDict,
35)
36from nuitka.nodes.ExceptionNodes import StatementReraiseException
37from nuitka.nodes.FrameNodes import (
38    StatementsFrameAsyncgen,
39    StatementsFrameCoroutine,
40    StatementsFrameFunction,
41    StatementsFrameGenerator,
42    StatementsFrameModule,
43)
44from nuitka.nodes.NodeBases import NodeBase
45from nuitka.nodes.NodeMakingHelpers import mergeStatements
46from nuitka.nodes.StatementNodes import StatementsSequence
47from nuitka.PythonVersions import python_version
48from nuitka.Tracing import optimization_logger, printLine
49
50
def dump(node):
    """Output the "ast.dump" text of an AST node through "printLine"."""
    node_text = ast.dump(node)

    printLine(node_text)
53
54
def getKind(node):
    """Give the kind of a node, which is the name of its class.

    Should the class name contain dots, only the part after the last one is
    used.
    """
    class_name = node.__class__.__name__

    # The last element of the partition is the whole name if there is no dot.
    return class_name.rpartition(".")[2]
57
58
def extractDocFromBody(node):
    """Split the body of an AST node into statements and doc string.

    A leading expression statement that is a "Str" or "Constant" is removed
    from the body, and its value becomes the doc string if it is a string.
    With the no docstrings Python flag, no doc string is given in that case.

    Returns:
        Tuple of the remaining body and the doc string, which may be None.
    """
    statements = node.body
    docstring = None

    # Work around ast.get_docstring breakage.
    if statements and getKind(statements[0]) == "Expr":
        first_value = statements[0].value
        value_kind = getKind(first_value)

        if value_kind == "Str":  # python3.7 or earlier
            docstring = first_value.s
            statements = statements[1:]
        elif value_kind == "Constant":  # python3.8
            # Only strings should be used, but all other constants can
            # immediately be ignored, it seems that e.g. Ellipsis is common.
            if type(first_value.value) is str:
                docstring = first_value.value

            statements = statements[1:]

        if Options.hasPythonFlagNoDocstrings():
            docstring = None

    return statements, docstring
79
80
def parseSourceCodeToAst(source_code, module_name, filename, line_offset):
    """Parse source code to an AST module node.

    Args:
        source_code: the source code text to parse
        module_name: only used for the error raised on too complex code
        filename: given to "ast.parse" and used for that error as well
        line_offset: if positive, line numbers of the result are increased
            by this amount

    Returns:
        The "Module" AST node.

    Raises:
        CodeTooComplexCode: if parsing hit the maximum recursion depth.
    """
    # Workaround: ast.parse cannot cope with some situations where a file is not
    # terminated by a new line.
    if not source_code.endswith("\n"):
        source_code += "\n"

    try:
        module_ast = ast.parse(source_code, filename)
    except RuntimeError as error:
        # Only the recursion limit is translated, anything else passes as is.
        if "maximum recursion depth" not in error.args[0]:
            raise

        raise CodeTooComplexCode(module_name, filename)

    assert getKind(module_ast) == "Module"

    if line_offset > 0:
        ast.increment_lineno(module_ast, line_offset)

    return module_ast
101
102
def detectFunctionBodyKind(nodes, start_value=None):
    """Detect the kind of function body from its AST nodes.

    The nodes are searched for "yield", "yield from", "await", "async with"
    and asynchronous comprehensions, without looking into the bodies of
    nested classes, functions and lambdas.

    Args:
        nodes: container of AST nodes making up the body, it is iterated and
            also used for membership checks
        start_value: optional indication to start out with, added as it is

    Returns:
        Tuple of function kind and flags. The function kind is "Function" if
        there was no indication, "Asyncgen" if both "Coroutine" and
        "Generator" were indicated, and otherwise the single indication
        found. The flags are a set that may contain "has_super", "has_exec"
        and "has_unqualified_exec".
    """
    # This is a complex mess, following the scope means a lot of checks need
    # to be done. pylint: disable=too-many-branches,too-many-statements

    indications = set()
    if start_value is not None:
        indications.add(start_value)

    flags = set()

    def _checkCoroutine(field):
        """Check only for co-routine nature of the field and only update that."""
        # TODO: This is clumsy code, trying to achieve what non-local does for
        # Python2 as well.

        # Work on an emptied set, so that only what this check adds is seen.
        old = set(indications)
        indications.clear()

        _check(field)

        if "Coroutine" in indications:
            old.add("Coroutine")

        # Restore previous indications, dropping e.g. "Generator" found here.
        indications.clear()
        indications.update(old)

    def _check(node):
        """Collect indications and flags from a node and its children."""
        node_class = node.__class__

        if node_class is ast.Yield:
            indications.add("Generator")
        elif python_version >= 0x300 and node_class is ast.YieldFrom:
            indications.add("Generator")
        elif python_version >= 0x350 and node_class in (ast.Await, ast.AsyncWith):
            indications.add("Coroutine")

        # Recurse to children, but do not cross scope boundary doing so.
        if node_class is ast.ClassDef:
            for name, field in ast.iter_fields(node):
                if name in ("name", "body"):
                    pass
                elif name in ("bases", "decorator_list", "keywords"):
                    for child in field:
                        _check(child)
                elif name == "starargs":
                    if field is not None:
                        _check(field)
                elif name == "kwargs":
                    if field is not None:
                        _check(field)
                else:
                    # Unknown fields are rejected rather than silently ignored.
                    assert False, (name, field, ast.dump(node))
        elif node_class in (ast.FunctionDef, ast.Lambda) or (
            python_version >= 0x350 and node_class is ast.AsyncFunctionDef
        ):
            for name, field in ast.iter_fields(node):
                if name in ("name", "body"):
                    pass
                elif name in ("bases", "decorator_list"):
                    for child in field:
                        _check(child)
                elif name == "args":
                    # Defaults and annotations are checked, the arguments
                    # themselves are not.
                    for child in field.defaults:
                        _check(child)

                    if python_version >= 0x300:
                        for child in node.args.kw_defaults:
                            if child is not None:
                                _check(child)

                        for child in node.args.args:
                            if child.annotation is not None:
                                _check(child.annotation)

                elif name == "returns":
                    if field is not None:
                        _check(field)
                elif name == "type_comment":
                    # Python3.8: We don't have structure here.
                    assert field is None or type(field) is str
                else:
                    assert False, (name, field, ast.dump(node))
        elif node_class is ast.GeneratorExp:
            for name, field in ast.iter_fields(node):
                if name == "name":
                    pass
                elif name in ("body", "comparators", "elt"):
                    if python_version >= 0x370:
                        _checkCoroutine(field)
                elif name == "generators":
                    # Only the iterable of the first "for" is checked fully.
                    _check(field[0].iter)

                    # New syntax in 3.7 allows these to be present in functions not
                    # declared with "async def", so we need to check them, but
                    # only if top level.
                    if python_version >= 0x370 and node in nodes:
                        for gen in field:
                            if gen.is_async:
                                indications.add("Coroutine")
                                break

                            # NOTE(review): "_checkCoroutine" has no return
                            # value, so this "break" is never taken.
                            if _checkCoroutine(gen):
                                break
                else:
                    assert False, (name, field, ast.dump(node))
        elif node_class is ast.ListComp and python_version >= 0x300:
            for name, field in ast.iter_fields(node):
                if name in ("name", "body", "comparators"):
                    pass
                elif name == "generators":
                    _check(field[0].iter)
                # NOTE(review): "body" is already handled by the first branch,
                # so only "elt" can get here.
                elif name in ("body", "elt"):
                    _check(field)
                else:
                    assert False, (name, field, ast.dump(node))
        elif python_version >= 0x270 and node_class is ast.SetComp:
            for name, field in ast.iter_fields(node):
                if name in ("name", "body", "comparators", "elt"):
                    pass
                elif name == "generators":
                    _check(field[0].iter)
                else:
                    assert False, (name, field, ast.dump(node))
        elif python_version >= 0x270 and node_class is ast.DictComp:
            for name, field in ast.iter_fields(node):
                if name in ("name", "body", "comparators", "key", "value"):
                    pass
                elif name == "generators":
                    _check(field[0].iter)
                else:
                    assert False, (name, field, ast.dump(node))
        elif python_version >= 0x370 and node_class is ast.comprehension:
            for name, field in ast.iter_fields(node):
                if name in ("name", "target"):
                    pass
                elif name == "iter":
                    # Top level comprehension iterators do not influence those.
                    if node not in nodes:
                        _check(field)
                elif name == "ifs":
                    for child in field:
                        _check(child)
                elif name == "is_async":
                    if field:
                        indications.add("Coroutine")
                else:
                    assert False, (name, field, ast.dump(node))
        elif node_class is ast.Name:
            if python_version >= 0x300 and node.id == "super":
                flags.add("has_super")
        elif python_version < 0x300 and node_class is ast.Exec:
            flags.add("has_exec")

            # An "exec" without explicit globals given.
            if node.globals is None:
                flags.add("has_unqualified_exec")

            for child in ast.iter_child_nodes(node):
                _check(child)
        elif python_version < 0x300 and node_class is ast.ImportFrom:
            # Star imports are flagged the same as "exec" is.
            for import_desc in node.names:
                if import_desc.name[0] == "*":
                    flags.add("has_exec")
            for child in ast.iter_child_nodes(node):
                _check(child)
        else:
            for child in ast.iter_child_nodes(node):
                _check(child)

    for node in nodes:
        _check(node)

    if indications:
        if "Coroutine" in indications and "Generator" in indications:
            function_kind = "Asyncgen"
        else:
            # If we found something, make sure we agree on all clues.
            assert len(indications) == 1, indications
            function_kind = indications.pop()
    else:
        function_kind = "Function"

    return function_kind, flags
285
286
# Dispatch tables for "buildNode", set via "setBuildingDispatchers". They are
# looked up by the kind of the AST node and hold the callables building it,
# called with (provider, node, source_ref), (node, source_ref) and
# (source_ref) keyword arguments respectively.
build_nodes_args3 = None
build_nodes_args2 = None
build_nodes_args1 = None
290
291
def setBuildingDispatchers(path_args3, path_args2, path_args1):
    """Set the dispatch tables that "buildNode" uses.

    Args:
        path_args3: table of builders called with provider, node, source_ref
        path_args2: table of builders called with node, source_ref
        path_args1: table of builders called with source_ref only
    """
    # Using global here, as this is really a singleton, in the form of a module,
    # and this is to break the cyclic dependency it has, pylint: disable=global-statement
    global build_nodes_args1, build_nodes_args2, build_nodes_args3

    build_nodes_args1 = path_args1
    build_nodes_args2 = path_args2
    build_nodes_args3 = path_args3
301
302
def buildNode(provider, node, source_ref, allow_none=False):
    """Build the Nuitka node for a single AST node.

    The kind of the AST node selects a builder from the dispatch tables set
    with "setBuildingDispatchers", which differ in the arguments the builders
    are called with. A "Pass" that has no builder gives no node.

    Args:
        provider: passed on to the builders that take three arguments
        node: AST node to build from, may be None only with "allow_none"
        source_ref: source reference, moved to the line of the node if it
            has a "lineno"
        allow_none: if true, None is accepted for "node" and as a result

    Returns:
        The built node, a "NodeBase" instance, or None where allowed.
    """
    if allow_none and node is None:
        return None

    try:
        node_kind = getKind(node)

        if hasattr(node, "lineno"):
            source_ref = source_ref.atLineNumber(node.lineno)

        if node_kind in build_nodes_args3:
            built = build_nodes_args3[node_kind](
                provider=provider, node=node, source_ref=source_ref
            )
        elif node_kind in build_nodes_args2:
            built = build_nodes_args2[node_kind](node=node, source_ref=source_ref)
        elif node_kind in build_nodes_args1:
            built = build_nodes_args1[node_kind](source_ref=source_ref)
        elif node_kind == "Pass":
            built = None
        else:
            assert False, ast.dump(node)

        if allow_none and built is None:
            return None

        assert isinstance(built, NodeBase), built

        return built
    except (SyntaxError, RuntimeError):
        # Syntax errors pass as they are. Runtime errors are very likely the
        # stack overflow, which we will turn into too complex code exception,
        # don't warn about it with a code dump then.
        raise
    except KeyboardInterrupt:
        # User interrupting is not a problem with the source, but tell where
        # we got interrupted.
        optimization_logger.info("Interrupted at '%s'." % source_ref)
        raise
    except:
        optimization_logger.warning(
            "Problem at '%s' with %s." % (source_ref, ast.dump(node))
        )
        raise
348
349
def buildNodeList(provider, nodes, source_ref, allow_none=False):
    """Build the Nuitka nodes for a list of AST nodes.

    Each node is built with "buildNode", using the source reference at its
    own line if it has a "lineno". Results that are None are left out.

    Returns:
        List of the built nodes, empty if "nodes" is None.
    """
    built_nodes = []

    if nodes is None:
        return built_nodes

    for node in nodes:
        node_source_ref = (
            source_ref.atLineNumber(node.lineno)
            if hasattr(node, "lineno")
            else source_ref
        )

        built = buildNode(provider, node, node_source_ref, allow_none)

        if built is not None:
            built_nodes.append(built)

    return built_nodes
368
369
# Cached AST module used by "buildAnnotationNode" to host the annotation that
# is to be converted to a string, created on first use.
_host_node = None
371
372
def buildAnnotationNode(provider, node, source_ref):
    """Build the node for an annotation.

    On Python3.7 or higher and with the future annotations enabled for the
    parent module of the provider, the annotation becomes a constant with its
    string form, otherwise it is built like any other node.
    """
    if (
        python_version < 0x370
        or not provider.getParentModule().getFutureSpec().isFutureAnnotations()
    ):
        return buildNode(provider, node, source_ref)

    # Using global value for cache, to avoid creating it over and over,
    # avoiding the pylint: disable=global-statement
    global _host_node

    if _host_node is None:
        _host_node = ast.parse("x:1")

    # Put our annotation in place of the one for "x" in the host module.
    _host_node.body[0].annotation = node

    annotations_code = compile(
        _host_node,
        "<annotations>",
        "exec",
        __future__.CO_FUTURE_ANNOTATIONS,
        dont_inherit=True,
    )

    # Using exec here, to compile the ast node tree back to string,
    # there is no accessible "ast.unparse", and this works as a hack
    # to convert our node to a string annotation, pylint: disable=exec-used
    module_namespace = {}
    exec(annotations_code, module_namespace)

    annotation_text = module_namespace["__annotations__"]["x"]

    if Options.is_debug and python_version >= 0x390:
        # TODO: In Python3.9+, we should only use ast.unparse
        assert annotation_text == ast.unparse(node)

    return makeConstantRefNode(constant=annotation_text, source_ref=source_ref)
411
412
def makeModuleFrame(module, statements, source_ref):
    """Make the frame statements node of a compiled module.

    The code object is named "<module>" in full compatibility mode and for
    the main module, otherwise the full module name is included in it.
    """
    assert module.isCompiledPythonModule()

    if Options.is_fullcompat or module.isMainModule():
        code_name = "<module>"
    else:
        code_name = "<module %s>" % module.getFullName()

    # A module takes no arguments and has no local or free variable names.
    module_code_object = CodeObjectSpec(
        co_name=code_name,
        co_kind="Module",
        co_varnames=(),
        co_freevars=(),
        co_argcount=0,
        co_posonlyargcount=0,
        co_kwonlyargcount=0,
        co_has_starlist=False,
        co_has_stardict=False,
        co_filename=module.getRunTimeFilename(),
        co_lineno=source_ref.getLineNumber(),
        future_spec=module.getFutureSpec(),
    )

    return StatementsFrameModule(
        statements=statements,
        code_object=module_code_object,
        source_ref=source_ref,
    )
442
443
def buildStatementsNode(provider, nodes, source_ref):
    """Build a statements sequence from AST nodes.

    Returns:
        The "StatementsSequence" node, or None if "nodes" is None or no
        statement remains after building and merging.
    """
    # We are not creating empty statement sequences.
    if nodes is None:
        return None

    # Build as list of statements, throw away empty ones, and remove useless
    # nesting.
    built_statements = mergeStatements(
        buildNodeList(provider, nodes, source_ref, allow_none=True)
    )

    # We are not creating empty statement sequences. Might be empty, because
    # e.g. a global node generates not really a statement, or pass statements.
    if built_statements:
        return StatementsSequence(statements=built_statements, source_ref=source_ref)

    return None
460
461
def buildFrameNode(provider, nodes, code_object, source_ref):
    """Build a frame statements node from AST nodes.

    The frame class is chosen from the kind of body the provider is, where
    for an outline function the parent variable provider decides.

    Returns:
        The frame node, or None if "nodes" is None or no statement remains
        after building and merging.
    """
    # We are not creating empty statement sequences.
    if nodes is None:
        return None

    # Build as list of statements, throw away empty ones, and remove useless
    # nesting.
    built_statements = mergeStatements(
        buildNodeList(provider, nodes, source_ref, allow_none=True)
    )

    # We are not creating empty statement sequences. Might be empty, because
    # e.g. a global node generates not really a statement, or pass statements.
    if not built_statements:
        return None

    if provider.isExpressionOutlineFunction():
        provider = provider.getParentVariableProvider()

    # Class bodies use the same frame class as function bodies do.
    if provider.isExpressionFunctionBody() or provider.isExpressionClassBody():
        frame_class = StatementsFrameFunction
    elif provider.isExpressionGeneratorObjectBody():
        frame_class = StatementsFrameGenerator
    elif provider.isExpressionCoroutineObjectBody():
        frame_class = StatementsFrameCoroutine
    elif provider.isExpressionAsyncgenObjectBody():
        frame_class = StatementsFrameAsyncgen
    else:
        assert False, provider

    return frame_class(
        statements=built_statements, code_object=code_object, source_ref=source_ref
    )
500
501
def makeStatementsSequenceOrStatement(statements, source_ref):
    """Give a statements sequence only for more than one statement.

    A single statement is returned as it is. This is for constructs that may
    or may not have been unrolled here, so callers need not branch on that
    or always create a sequence.
    """
    if len(statements) <= 1:
        return statements[0]

    return StatementsSequence(
        statements=mergeStatements(statements), source_ref=source_ref
    )
516
517
def makeStatementsSequence(statements, allow_none, source_ref):
    """Make a statements sequence from statements, if there are any.

    Args:
        statements: the statements to merge into the sequence
        allow_none: if true, None values in "statements" are dropped first
        source_ref: source reference for the created sequence

    Returns:
        The "StatementsSequence" node, or None if there are no statements.
    """
    if allow_none:
        statements = tuple(entry for entry in statements if entry is not None)

    if not statements:
        return None

    return StatementsSequence(
        statements=mergeStatements(statements), source_ref=source_ref
    )
530
531
def makeStatementsSequenceFromStatement(statement):
    """Make a statements sequence from one statement, with its source reference."""
    merged = mergeStatements((statement,))

    return StatementsSequence(
        statements=merged, source_ref=statement.getSourceReference()
    )
537
538
def makeStatementsSequenceFromStatements(*statements):
    """Make a statements sequence from statements given as arguments.

    At least one statement must be given and None is not allowed. The source
    reference is taken from the first statement after merging.
    """
    assert statements
    assert None not in statements

    merged = mergeStatements(statements, allow_none=False)
    first_source_ref = merged[0].getSourceReference()

    return StatementsSequence(statements=merged, source_ref=first_source_ref)
548
549
def makeDictCreationOrConstant2(keys, values, source_ref):
    """Create a dictionary node, or a constant if all values are constant.

    Args:
        keys: the keys, Python strings here, as many as there are values
        values: the value nodes
        source_ref: source reference for the created node

    Returns:
        A constant reference node if every value is a constant reference,
        otherwise a dictionary making node. With values given, its compatible
        source reference is set to the one of the last value.
    """
    assert len(keys) == len(values)

    # Note: This would happen in optimization instead, but lets just do it
    # immediately to save some time.
    if all(value.isExpressionConstantRef() for value in values):
        # Unless told otherwise, create the dictionary in its full size, so
        # that no growing occurs and the constant becomes as similar as possible
        # before being marshaled.
        constant_values = [value.getCompileTimeConstant() for value in values]

        result = makeConstantRefNode(
            constant=Constants.createConstantDict(
                keys=keys, values=constant_values
            ),
            user_provided=True,
            source_ref=source_ref,
        )
    else:
        pairs = []

        for key, value in zip(keys, values):
            # The key constant uses the source reference of its value.
            key_node = makeConstantRefNode(
                constant=key,
                source_ref=value.getSourceReference(),
                user_provided=True,
            )

            pairs.append(
                ExpressionKeyValuePair(
                    key=key_node,
                    value=value,
                    source_ref=value.getSourceReference(),
                )
            )

        result = makeExpressionMakeDict(pairs=pairs, source_ref=source_ref)

    if values:
        last_value = values[-1]

        result.setCompatibleSourceReference(
            source_ref=last_value.getCompatibleSourceReference()
        )

    return result
598
599
def getStatementsAppended(statement_sequence, statements):
    """Make a statements sequence of the given one followed by "statements".

    The source reference of the result is the one of "statement_sequence".
    """
    combined = (statement_sequence, statements)

    return makeStatementsSequence(
        statements=combined,
        allow_none=False,
        source_ref=statement_sequence.getSourceReference(),
    )
606
607
def getStatementsPrepended(statement_sequence, statements):
    """Make a statements sequence of "statements" followed by the given one.

    The source reference of the result is the one of "statement_sequence".
    """
    combined = (statements, statement_sequence)

    return makeStatementsSequence(
        statements=combined,
        allow_none=False,
        source_ref=statement_sequence.getSourceReference(),
    )
614
615
def makeReraiseExceptionStatement(source_ref):
    """Create a "StatementReraiseException" node at the source reference."""
    reraise_statement = StatementReraiseException(source_ref=source_ref)

    return reraise_statement
618
619
def mangleName(name, owner):
    """Mangle names with leading "__" for usage in a class owner.

    Notes: This is the private name handling for Python classes. A name is
    only changed if it starts with "__", does not end with "__", and the
    owner is contained in a class whose name is not made of underscores
    only, as CPython does no transformation for such class names.

    Args:
        name: the name as it is written in the source code
        owner: node asked for the class it is contained in, by calling its
            "getContainingClassDictCreation" method

    Returns:
        The name prefixed with "_" and the class name without its leading
        underscores if mangling applies, otherwise the name unchanged.
    """

    if not name.startswith("__") or name.endswith("__"):
        return name

    # The mangling of function variable names depends on being inside a
    # class.
    class_container = owner.getContainingClassDictCreation()

    if class_container is None:
        return name

    class_name = class_container.getName().lstrip("_")

    # Nothing is left of a class name of only underscores, and then the name
    # must remain unchanged, rather than only getting a "_" prefix.
    if not class_name:
        return name

    return "_%s%s" % (class_name, name)
637
638
def makeCallNode(called, *args, **kwargs):
    """Make a call node for the called expression and given arguments.

    The last positional argument is the source reference, the ones before it
    are the positional arguments of the call. Keyword arguments become the
    keyword arguments of the call, with the names used as string keys.
    """
    source_ref = args[-1]
    positional = args[:-1]

    if positional:
        args_node = makeExpressionMakeTupleOrConstant(
            elements=positional, user_provided=True, source_ref=source_ref
        )
    else:
        args_node = None

    if kwargs:
        kw_node = makeDictCreationOrConstant2(
            keys=tuple(kwargs),
            values=tuple(kwargs.values()),
            source_ref=source_ref,
        )
    else:
        kw_node = None

    return makeExpressionCall(
        called=called, args=args_node, kw=kw_node, source_ref=source_ref
    )
661
662
# Stack of build context values, the last one being the current one. Starts
# out with None as the value for when nothing was pushed.
build_contexts = [None]
664
665
def pushBuildContext(value):
    """Make the value the current build context, on top of the previous ones."""
    new_context = value

    build_contexts.append(new_context)
668
669
def popBuildContext():
    """Remove the current build context, making the previous one current."""
    # The removed value is not of interest, and not returned.
    build_contexts.pop()
672
673
def getBuildContext():
    """Give the current build context, the one pushed last."""
    current_context = build_contexts[-1]

    return current_context
676