1from __future__ import absolute_import
2
3import itertools
4from time import time
5
6from . import Errors
7from . import DebugFlags
8from . import Options
9from .Errors import CompileError, InternalError, AbortError
10from . import Naming
11
12#
13# Really small pipeline stages
14#
def dumptree(t):
    """Print a dump of the tree and hand it back unchanged.

    Useful as a drop-in pipeline stage for quick debugging.
    """
    print(t.dump())
    return t
19
def abort_on_errors(node):
    """Stop the pipeline (by raising AbortError) if any errors were reported."""
    if Errors.num_errors:
        raise AbortError("pipeline break")
    return node
25
def parse_stage_factory(context):
    """Create the pipeline stage that parses an implementation (.pyx/.py) source."""
    def parse(compsrc):
        source_desc = compsrc.source_desc
        full_module_name = compsrc.full_module_name
        initial_pos = (source_desc, 1, 0)
        # Temporarily disable cimport-from-pyx while looking up the module scope,
        # then restore the previous setting.
        saved_cimport_from_pyx = Options.cimport_from_pyx
        Options.cimport_from_pyx = False
        scope = context.find_module(full_module_name, pos=initial_pos, need_pxd=0)
        Options.cimport_from_pyx = saved_cimport_from_pyx
        tree = context.parse(source_desc, scope, pxd=0,
                             full_module_name=full_module_name)
        tree.compilation_source = compsrc
        tree.scope = scope
        tree.is_pxd = False
        return tree
    return parse
40
def parse_pxd_stage_factory(context, scope, module_name):
    """Create the pipeline stage that parses a .pxd file into the given scope."""
    def parse(source_desc):
        tree = context.parse(
            source_desc, scope, pxd=True, full_module_name=module_name)
        tree.is_pxd = True
        tree.scope = scope
        return tree
    return parse
49
def generate_pyx_code_stage_factory(options, result):
    """Create the final stage that generates C code for the module node.

    Code generation writes into `result`; the stage also records which
    compilation source produced it, then returns `result`.
    """
    def generate_pyx_code_stage(module_node):
        module_node.process_implementation(options, result)
        result.compilation_source = module_node.compilation_source
        return result
    return generate_pyx_code_stage
56
57
def inject_pxd_code_stage_factory(context):
    """Create a stage that merges every parsed pxd tree into the module node."""
    def inject_pxd_code_stage(module_node):
        # The pxd names themselves are not needed here, only the trees/scopes.
        for statlistnode, scope in context.pxds.values():
            module_node.merge_in(statlistnode, scope)
        return module_node
    return inject_pxd_code_stage
64
65
def use_utility_code_definitions(scope, target, seen=None):
    """Recursively register utility code for all used entries in `scope`.

    Walks the scope's entries (descending into cimported module scopes via
    `as_module`) and tells `target` to use the utility code attached to each
    used entry, plus that code's direct requirements.  `seen` guards against
    revisiting entries.
    """
    if seen is None:
        seen = set()

    for entry in scope.entries.values():
        if entry in seen:
            continue
        seen.add(entry)
        definition = entry.utility_code_definition
        if entry.used and definition:
            target.use_utility_code(definition)
            for required_utility in definition.requires:
                target.use_utility_code(required_utility)
        elif entry.as_module:
            use_utility_code_definitions(entry.as_module, target, seen)
81
82
def sort_utility_codes(utilcodes):
    """Order utility code blocks so that dependencies come before dependents.

    Each block gets a float rank: one more than the smallest rank among its
    requirements (or 0 for leaf blocks), plus a tiny discovery-order offset
    that keeps the sort stable and deterministic.  Circular dependencies are
    tolerated by pre-seeding a rank of 0 before recursing.
    """
    ranks = {}

    def get_rank(utilcode):
        rank = ranks.get(utilcode)
        if rank is None:
            ranks[utilcode] = 0  # prevent infinite recursion on circular dependencies
            discovery_order = len(ranks)
            dep_ranks = [get_rank(dep) for dep in utilcode.requires or ()]
            ranks[utilcode] = rank = (
                1 + min(dep_ranks or [-1]) + discovery_order * 1e-8)
        return rank

    for utilcode in utilcodes:
        get_rank(utilcode)
    return sorted(ranks, key=ranks.get)
94
95
def normalize_deps(utilcodes):
    """Make equal utility-code dependencies share one canonical object.

    Rewrites every `requires` list in place so that dependencies that compare
    equal are replaced by a single representative instance (preferring the
    objects already present in `utilcodes`).
    """
    canonical = {utilcode: utilcode for utilcode in utilcodes}

    def unify_dep(dep):
        # Return the canonical instance for `dep`, registering it if new.
        return canonical.setdefault(dep, dep)

    for utilcode in utilcodes:
        utilcode.requires = [unify_dep(dep) for dep in utilcode.requires or ()]
110
111
def inject_utility_code_stage_factory(context):
    """Create the stage that merges all required utility code into the module.

    The stage sorts the module scope's utility code list into dependency
    order, canonicalises duplicate dependencies, then merges each utility
    code tree into the module node.
    """
    def inject_utility_code_stage(module_node):
        module_node.prepare_utility_code()
        use_utility_code_definitions(context.cython_scope, module_node.scope)
        module_node.scope.utility_code_list = sort_utility_codes(module_node.scope.utility_code_list)
        normalize_deps(module_node.scope.utility_code_list)
        # Set instead of list: membership is tested once per utility code and
        # per dependency, and utility codes are hashable (they are used as
        # dict keys in sort_utility_codes above).
        added = set()
        # Note: the list might be extended inside the loop (if some utility code
        # pulls in other utility code, explicitly or implicitly)
        for utilcode in module_node.scope.utility_code_list:
            if utilcode in added:
                continue
            added.add(utilcode)
            if utilcode.requires:
                for dep in utilcode.requires:
                    # Scan the live list (not a snapshot) on purpose: merge_in()
                    # below may append to it while we iterate.
                    if dep not in added and dep not in module_node.scope.utility_code_list:
                        module_node.scope.utility_code_list.append(dep)
            tree = utilcode.get_tree(cython_scope=context.cython_scope)
            if tree:
                module_node.merge_in(tree.with_compiler_directives(),
                                     tree.scope, merge_scope=True)
        return module_node
    return inject_utility_code_stage
135
136
137#
138# Pipeline factories
139#
140
def create_pipeline(context, mode, exclude_classes=()):
    """Build the list of transform stages shared by pyx/py/pxd compilation.

    mode: one of 'pyx', 'py', 'pxd' -- selects the mode-specific post-parse
        transform and the appropriate C-declaration check.
    exclude_classes: transform classes to drop from the resulting pipeline.

    Returns the stage list; `None` placeholders for inapplicable stages are
    kept (run_pipeline skips them).
    """
    assert mode in ('pyx', 'py', 'pxd')
    from .ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, PxdPostParse
    from .ParseTreeTransforms import ForwardDeclareTypes, InjectGilHandling, AnalyseDeclarationsTransform
    from .ParseTreeTransforms import AnalyseExpressionsTransform, FindInvalidUseOfFusedTypes
    from .ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
    from .ParseTreeTransforms import TrackNumpyAttributes, InterpretCompilerDirectives, TransformBuiltinMethods
    from .ParseTreeTransforms import ExpandInplaceOperators, ParallelRangeTransform
    from .ParseTreeTransforms import CalculateQualifiedNamesTransform
    from .TypeInference import MarkParallelAssignments, MarkOverflowingArithmetic
    from .ParseTreeTransforms import AdjustDefByDirectives, AlignFunctionDefinitions, AutoCpdefFunctionDefinitions
    from .ParseTreeTransforms import RemoveUnreachableCode, GilCheck, CoerceCppTemps
    from .FlowControl import ControlFlowAnalysis
    from .AnalysedTreeTransforms import AutoTestDictTransform
    from .AutoDocTransforms import EmbedSignature
    from .Optimize import FlattenInListTransform, SwitchTransform, IterationTransform
    from .Optimize import EarlyReplaceBuiltinCalls, OptimizeBuiltinCalls
    from .Optimize import InlineDefNodeCalls
    from .Optimize import ConstantFolding, FinalOptimizePhase
    from .Optimize import DropRefcountingTransform
    from .Optimize import ConsolidateOverflowCheck
    from .Buffer import IntroduceBufferAuxiliaryVars
    from .ModuleNode import check_c_declarations, check_c_declarations_pxd

    if mode == 'pxd':
        _check_c_declarations = check_c_declarations_pxd
        _specific_post_parse = PxdPostParse(context)
    else:
        _check_c_declarations = check_c_declarations
        _specific_post_parse = None

    # Aligning function definitions against the pxd is only needed for .py files.
    if mode == 'py':
        _align_function_definitions = AlignFunctionDefinitions(context)
    else:
        _align_function_definitions = None

    # NOTE: This is the "common" parts of the pipeline, which is also
    # code in pxd files. So it will be run multiple times in a
    # compilation stage.
    stages = [
        NormalizeTree(context),
        PostParse(context),
        _specific_post_parse,
        TrackNumpyAttributes(),
        InterpretCompilerDirectives(context, context.compiler_directives),
        ParallelRangeTransform(context),
        WithTransform(context),
        AdjustDefByDirectives(context),
        _align_function_definitions,
        MarkClosureVisitor(context),
        AutoCpdefFunctionDefinitions(context),
        RemoveUnreachableCode(context),
        ConstantFolding(),
        FlattenInListTransform(),
        DecoratorTransform(context),
        ForwardDeclareTypes(context),
        InjectGilHandling(),
        AnalyseDeclarationsTransform(context),
        AutoTestDictTransform(context),
        EmbedSignature(context),
        EarlyReplaceBuiltinCalls(context),  ## Necessary?
        TransformBuiltinMethods(context),
        MarkParallelAssignments(context),
        ControlFlowAnalysis(context),
        RemoveUnreachableCode(context),
        # MarkParallelAssignments(context),
        MarkOverflowingArithmetic(context),
        IntroduceBufferAuxiliaryVars(context),
        _check_c_declarations,
        InlineDefNodeCalls(context),
        AnalyseExpressionsTransform(context),
        FindInvalidUseOfFusedTypes(context),
        ExpandInplaceOperators(context),
        IterationTransform(context),
        SwitchTransform(context),
        OptimizeBuiltinCalls(context),  ## Necessary?
        CreateClosureClasses(context),  ## After all lookups and type inference
        CalculateQualifiedNamesTransform(context),
        ConsolidateOverflowCheck(context),
        DropRefcountingTransform(),
        FinalOptimizePhase(context),
        CoerceCppTemps(context),
        GilCheck(),
        ]
    # Filter by class so that None placeholders pass through unharmed.
    return [s for s in stages if s.__class__ not in exclude_classes]
232
def create_pyx_pipeline(context, options, result, py=False, exclude_classes=()):
    """Assemble the full implementation-file pipeline.

    Parses the source, runs the common transform pipeline, injects pxd and
    utility code, optionally adds tree assertions and gdb debug output, and
    finally generates C code into `result`.
    """
    mode = 'py' if py else 'pyx'

    test_support = []
    if options.evaluate_tree_assertions:
        from ..TestUtils import TreeAssertVisitor
        test_support.append(TreeAssertVisitor())

    if options.gdb_debug:
        from ..Debugger import DebugWriter  # requires Py2.5+
        from .ParseTreeTransforms import DebugTransform
        context.gdb_debug_outputwriter = DebugWriter.CythonDebugWriter(
            options.output_dir)
        debug_transform = [DebugTransform(context, options, result)]
    else:
        debug_transform = []

    return (
        [parse_stage_factory(context)]
        + create_pipeline(context, mode, exclude_classes=exclude_classes)
        + test_support
        + [inject_pxd_code_stage_factory(context),
           inject_utility_code_stage_factory(context),
           abort_on_errors]
        + debug_transform
        + [generate_pyx_code_stage_factory(options, result)]
    )
261
def create_pxd_pipeline(context, scope, module_name):
    """Build the pipeline for compiling a .pxd file.

    The pxd pipeline ends up with a CCodeWriter containing the
    code of the pxd, as well as a pxd scope.
    """
    from .CodeGeneration import ExtractPxdCode

    pipeline = [parse_pxd_stage_factory(context, scope, module_name)]
    pipeline.extend(create_pipeline(context, 'pxd'))
    pipeline.append(ExtractPxdCode())
    return pipeline
272
def create_py_pipeline(context, options, result):
    """Build the pipeline for a plain .py file: the pyx pipeline in 'py' mode."""
    return create_pyx_pipeline(context, options, result, py=True)
275
def create_pyx_as_pxd_pipeline(context, result):
    """Build a pipeline that processes a .pyx file only far enough to serve
    as its own .pxd: run the pyx pipeline up to (and including) declaration
    analysis, then replace the module body with an empty statement list while
    keeping the analysed scope.
    """
    from .ParseTreeTransforms import AlignFunctionDefinitions, \
        MarkClosureVisitor, WithTransform, AnalyseDeclarationsTransform
    from .Optimize import ConstantFolding, FlattenInListTransform
    from .Nodes import StatListNode
    pipeline = []
    # These transforms are excluded from the pyx pipeline for this use case;
    # the pipeline is cut short after declaration analysis below anyway.
    pyx_pipeline = create_pyx_pipeline(context, context.options, result,
                                       exclude_classes=[
                                           AlignFunctionDefinitions,
                                           MarkClosureVisitor,
                                           ConstantFolding,
                                           FlattenInListTransform,
                                           WithTransform
                                           ])
    from .Visitor import VisitorTransform
    class SetInPxdTransform(VisitorTransform):
        # A number of nodes have an "in_pxd" attribute which affects AnalyseDeclarationsTransform
        # (for example controlling pickling generation). Set it, to make sure we don't mix them up with
        # the importing main module.
        # FIXME: This should be done closer to the parsing step.
        def visit_StatNode(self, node):
            if hasattr(node, "in_pxd"):
                node.in_pxd = True
            self.visitchildren(node)
            return node

        visit_Node = VisitorTransform.recurse_to_children

    # Copy stages up to declaration analysis, inserting SetInPxdTransform
    # immediately before AnalyseDeclarationsTransform so nodes are marked
    # as pxd content before they are analysed.
    for stage in pyx_pipeline:
        pipeline.append(stage)
        if isinstance(stage, AnalyseDeclarationsTransform):
            pipeline.insert(-1, SetInPxdTransform())
            break  # This is the last stage we need.
    def fake_pxd(root):
        # Mark every non-cinclude entry as pxd-defined, mangling names of
        # non-extern entries so they do not clash in cimporting modules.
        for entry in root.scope.entries.values():
            if not entry.in_cinclude:
                entry.defined_in_pxd = 1
                if entry.name == entry.cname and entry.visibility != 'extern':
                    # Always mangle non-extern cimported entries.
                    entry.cname = entry.scope.mangle(Naming.func_prefix, entry.name)
        # Return an empty body together with the analysed scope.
        return StatListNode(root.pos, stats=[]), root.scope
    pipeline.append(fake_pxd)
    return pipeline
319
def insert_into_pipeline(pipeline, transform, before=None, after=None):
    """
    Insert a new transform into the pipeline after or before an instance of
    the given class. e.g.

        pipeline = insert_into_pipeline(pipeline, transform,
                                        after=AnalyseDeclarationsTransform)

    Raises ValueError if no instance of the given class is present.
    """
    assert before or after

    cls = before or after
    for i, t in enumerate(pipeline):
        if isinstance(t, cls):
            break
    else:
        # Previously a missing anchor class fell through with `i` pointing at
        # the last stage (or unbound for an empty pipeline), silently
        # inserting the transform in the wrong place.  Fail loudly instead.
        raise ValueError("%s not found in pipeline" % cls.__name__)

    if after:
        i += 1

    return pipeline[:i] + [transform] + pipeline[i:]
339
340#
341# Running a pipeline
342#
343
# Cache of exec()-generated per-phase wrapper functions used by run_pipeline()
# in verbose-debug mode, so each phase appears under its own name in profiles.
_pipeline_entry_points = {}
345
346
def run_pipeline(pipeline, source, printtree=True):
    """Run `source` through each phase of `pipeline` in order.

    Returns a (error, data) tuple: `error` is None on success (otherwise the
    CompileError/InternalError/AbortError that stopped the run) and `data` is
    the output of the last phase that executed.  `None` phases are skipped;
    PrintTree phases are skipped when `printtree` is False.
    """
    from .Visitor import PrintTree
    # Namespace for the exec()-generated wrappers below (debug mode only).
    exec_ns = globals().copy() if DebugFlags.debug_verbose_pipeline else None

    def run(phase, data):
        return phase(data)

    error = None
    data = source
    try:
        try:
            for phase in pipeline:
                if phase is not None:
                    if not printtree and isinstance(phase, PrintTree):
                        continue
                    if DebugFlags.debug_verbose_pipeline:
                        t = time()
                        print("Entering pipeline phase %r" % phase)
                        # create a new wrapper for each step to show the name in profiles
                        phase_name = getattr(phase, '__name__', type(phase).__name__)
                        try:
                            run = _pipeline_entry_points[phase_name]
                        except KeyError:
                            # NOTE(review): assumes phase_name is a valid Python
                            # identifier (function/class names are) -- exec would
                            # fail otherwise.
                            exec("def %s(phase, data): return phase(data)" % phase_name, exec_ns)
                            run = _pipeline_entry_points[phase_name] = exec_ns[phase_name]
                    data = run(phase, data)
                    if DebugFlags.debug_verbose_pipeline:
                        print("    %.3f seconds" % (time() - t))
        except CompileError as err:
            # err is set
            Errors.report_error(err, use_stack=False)
            error = err
    except InternalError as err:
        # Only raise if there was not an earlier error
        if Errors.num_errors == 0:
            raise
        error = err
    except AbortError as err:
        error = err
    return (error, data)
387