Pipeline.py 12.8 KB
Newer Older
1 2
from __future__ import absolute_import

3 4 5
import itertools
from time import time

6 7 8 9 10 11
from . import Errors
from . import DebugFlags
from . import Options
from .Visitor import CythonTransform
from .Errors import CompileError, InternalError, AbortError
from . import Naming
12 13 14 15 16 17 18 19 20 21 22 23

#
# Really small pipeline stages
#
def dumptree(t):
    """Debugging stage: print ``t.dump()`` and return *t* unchanged.

    Handy for dropping into a pipeline list while investigating a transform.
    """
    # Call form prints the same single value on Python 2 and also works on 3.
    print(t.dump())
    return t

def abort_on_errors(node):
    """Pipeline stage that lets *node* through only when no errors were reported.

    Raises AbortError("pipeline break") as soon as Errors.num_errors is
    non-zero, which stops the remaining stages from running.
    """
    if Errors.num_errors == 0:
        return node
    raise AbortError("pipeline break")

def parse_stage_factory(context):
    """Return the first pipeline stage, which parses a compilation source.

    The returned stage takes a compilation source object (reading its
    ``source_desc`` and ``full_module_name``), looks up the module scope and
    parses the source into a tree. The tree is annotated with
    ``compilation_source``, ``scope`` and ``is_pxd = False``.
    """
    def parse(compsrc):
        source_desc = compsrc.source_desc
        full_module_name = compsrc.full_module_name
        initial_pos = (source_desc, 1, 0)
        # The module lookup runs with Options.cimport_from_pyx switched off.
        # Restore it in a ``finally`` so an exception from find_module() does
        # not leave this global option stuck at False.
        saved_cimport_from_pyx, Options.cimport_from_pyx = Options.cimport_from_pyx, False
        try:
            scope = context.find_module(full_module_name, pos=initial_pos, need_pxd=0)
        finally:
            Options.cimport_from_pyx = saved_cimport_from_pyx
        tree = context.parse(source_desc, scope, pxd=0, full_module_name=full_module_name)
        tree.compilation_source = compsrc
        tree.scope = scope
        tree.is_pxd = False
        return tree
    return parse

def parse_pxd_stage_factory(context, scope, module_name):
    """Return a stage that parses a .pxd source description into a tree.

    The tree is parsed in pxd mode for *module_name* and tagged with the given
    *scope* and ``is_pxd = True``.
    """
    def parse(source_desc):
        pxd_tree = context.parse(source_desc, scope, pxd=True,
                                 full_module_name=module_name)
        pxd_tree.scope, pxd_tree.is_pxd = scope, True
        return pxd_tree
    return parse

def generate_pyx_code_stage_factory(options, result):
    """Return the final stage, which generates code for the module into *result*.

    The stage calls ``module_node.process_implementation(options, result)``,
    copies the module's ``compilation_source`` onto *result*, and returns
    *result* (not the tree) as the pipeline's output.
    """
    def generate_pyx_code_stage(module_node):
        module_node.process_implementation(options, result)
        origin = module_node.compilation_source
        result.compilation_source = origin
        return result
    return generate_pyx_code_stage

def inject_pxd_code_stage_factory(context):
    """Return a stage that merges every parsed pxd into the module tree.

    ``context.pxds`` maps names to ``(statlistnode, scope)`` pairs. Each pair is
    passed to ``module_node.merge_in``, and the module node is returned.
    """
    def inject_pxd_code_stage(module_node):
        # .values(): the key is unused, and itervalues()/iteritems() are Py2-only.
        for statlistnode, scope in context.pxds.values():
            module_node.merge_in(statlistnode, scope)
        return module_node
    return inject_pxd_code_stage

65 66 67 68
def use_utility_code_definitions(scope, target, seen=None):
    """Register the utility code needed by the used entries of *scope* with *target*.

    For every entry in ``scope.entries`` that is marked ``used`` and has a
    ``utility_code_definition``, that utility code and each code listed in its
    ``requires`` are passed to ``target.use_utility_code()``. Other entries
    that refer to a module (``entry.as_module``) are searched recursively.

    ``seen`` holds the entries already visited. It is shared across the
    recursion, so no entry is processed twice and cycles between module
    scopes terminate.
    """
    if seen is None:
        seen = set()

    # .values() works on Python 2 and 3; itervalues() is Python 2 only.
    for entry in scope.entries.values():
        if entry in seen:
            continue
        seen.add(entry)

        if entry.used and entry.utility_code_definition:
            definition = entry.utility_code_definition
            target.use_utility_code(definition)
            # 'requires' may be empty or None. The utility-code injection
            # stage also guards for a falsy value before iterating it.
            for required_utility in definition.requires or ():
                target.use_utility_code(required_utility)
        elif entry.as_module:
            use_utility_code_definitions(entry.as_module, target, seen)
80 81 82

def inject_utility_code_stage_factory(context):
    """Return a stage that merges the module's utility code trees into it.

    The stage first registers utility code used from ``context.cython_scope``.
    It then walks ``module_node.scope.utility_code_list`` once per distinct
    code, queues any not-yet-listed dependencies (``requires``) onto the end
    of that same list, and merges each code's tree (if ``get_tree()`` returns
    one) into the module with ``merge_scope=True``.
    """
    def inject_utility_code_stage(module_node):
        use_utility_code_definitions(context.cython_scope, module_node.scope)

        processed = []
        # The loop iterates the live list on purpose. Dependencies appended
        # below, and any utility code pulled in by merge_in(), are visited
        # later in this same loop.
        for code in module_node.scope.utility_code_list:
            if code in processed:
                continue
            processed.append(code)

            for requirement in (code.requires or ()):
                if requirement not in processed and \
                        requirement not in module_node.scope.utility_code_list:
                    module_node.scope.utility_code_list.append(requirement)

            utility_tree = code.get_tree()
            if utility_tree:
                module_node.merge_in(utility_tree.body, utility_tree.scope,
                                     merge_scope=True)
        return module_node
    return inject_utility_code_stage
99

Mark Florisson's avatar
Mark Florisson committed
100 101 102 103 104 105 106 107
class UseUtilityCodeDefinitions(CythonTransform):
    # Stopgap transform: registers utility code referenced by the entries of
    # name and attribute nodes with the scope of the transformed tree. This
    # belongs in the code generation phase of the relevant nodes, and should
    # move there once generating CythonUtilityCode at that time is safe.

    def __call__(self, node):
        # Keep the root's scope; process_entry() registers code with it.
        self.scope = node.scope
        return super(UseUtilityCodeDefinitions, self).__call__(node)

    def process_entry(self, entry):
        # Register the entry's utility_code and utility_code_definition,
        # skipping whichever is unset. A missing entry is ignored.
        if not entry:
            return
        for code in (entry.utility_code, entry.utility_code_definition):
            if code:
                self.scope.use_utility_code(code)

    def visit_AttributeNode(self, node):
        # Note: does not descend into the node's children.
        self.process_entry(node.entry)
        return node

    def visit_NameNode(self, node):
        # Both the value entry and the type entry may carry utility code.
        self.process_entry(node.entry)
        self.process_entry(node.type_entry)
        return node
122

123 124 125 126
#
# Pipeline factories
#

127
def create_pipeline(context, mode, exclude_classes=()):
    """Build the list of tree transforms shared by .pyx, .py and .pxd compilation.

    ``mode`` must be 'pyx', 'py' or 'pxd'. It selects the mode-specific stages:
    PxdPostParse and check_c_declarations_pxd for 'pxd', check_c_declarations
    otherwise, and AlignFunctionDefinitions only for 'py'. A mode-specific
    slot that does not apply stays in the list as ``None`` (run_pipeline skips
    None phases).

    A stage whose exact class appears in ``exclude_classes`` is dropped.
    Subclasses are not matched.
    """
    assert mode in ('pyx', 'py', 'pxd')
    from .ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, PxdPostParse
    from .ParseTreeTransforms import ForwardDeclareTypes, AnalyseDeclarationsTransform
    from .ParseTreeTransforms import AnalyseExpressionsTransform, FindInvalidUseOfFusedTypes
    from .ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
    from .ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
    from .ParseTreeTransforms import ExpandInplaceOperators, ParallelRangeTransform
    from .ParseTreeTransforms import CalculateQualifiedNamesTransform
    from .TypeInference import MarkParallelAssignments, MarkOverflowingArithmetic
    from .ParseTreeTransforms import AdjustDefByDirectives, AlignFunctionDefinitions
    from .ParseTreeTransforms import RemoveUnreachableCode, GilCheck
    from .FlowControl import ControlFlowAnalysis
    from .AnalysedTreeTransforms import AutoTestDictTransform
    from .AutoDocTransforms import EmbedSignature
    from .Optimize import FlattenInListTransform, SwitchTransform, IterationTransform
    from .Optimize import EarlyReplaceBuiltinCalls, OptimizeBuiltinCalls
    from .Optimize import InlineDefNodeCalls
    from .Optimize import ConstantFolding, FinalOptimizePhase
    from .Optimize import DropRefcountingTransform
    from .Optimize import ConsolidateOverflowCheck
    from .Buffer import IntroduceBufferAuxiliaryVars
    from .ModuleNode import check_c_declarations, check_c_declarations_pxd

    if mode == 'pxd':
        _check_c_declarations = check_c_declarations_pxd
        _specific_post_parse = PxdPostParse(context)
    else:
        _check_c_declarations = check_c_declarations
        _specific_post_parse = None

    if mode == 'py':
        _align_function_definitions = AlignFunctionDefinitions(context)
    else:
        _align_function_definitions = None

    # NOTE: These are the "common" parts of the pipeline. They also process
    # the code of pxd files, so they may run several times during a single
    # compilation.
    stages = [
        NormalizeTree(context),
        PostParse(context),
        _specific_post_parse,
        InterpretCompilerDirectives(context, context.compiler_directives),
        ParallelRangeTransform(context),
        AdjustDefByDirectives(context),
        WithTransform(context),
        MarkClosureVisitor(context),
        _align_function_definitions,
        RemoveUnreachableCode(context),
        ConstantFolding(),
        FlattenInListTransform(),
        DecoratorTransform(context),
        ForwardDeclareTypes(context),
        AnalyseDeclarationsTransform(context),
        AutoTestDictTransform(context),
        EmbedSignature(context),
        EarlyReplaceBuiltinCalls(context),  ## Necessary?
        TransformBuiltinMethods(context),
        MarkParallelAssignments(context),
        ControlFlowAnalysis(context),
        RemoveUnreachableCode(context),
        # MarkParallelAssignments(context),
        MarkOverflowingArithmetic(context),
        IntroduceBufferAuxiliaryVars(context),
        _check_c_declarations,
        InlineDefNodeCalls(context),
        AnalyseExpressionsTransform(context),
        FindInvalidUseOfFusedTypes(context),
        ExpandInplaceOperators(context),
        IterationTransform(context),
        SwitchTransform(context),
        OptimizeBuiltinCalls(context),  ## Necessary?
        CreateClosureClasses(context),  ## After all lookups and type inference
        CalculateQualifiedNamesTransform(context),
        ConsolidateOverflowCheck(context),
        DropRefcountingTransform(),
        FinalOptimizePhase(context),
        GilCheck(),
        UseUtilityCodeDefinitions(context),
        ]
    return [stage for stage in stages
            if stage.__class__ not in exclude_classes]
215

216
def create_pyx_pipeline(context, options, result, py=False, exclude_classes=()):
    """Assemble the complete pipeline for compiling a .pyx module (or .py when *py*).

    Stage order:
      1. parsing;
      2. the common transforms from create_pipeline (minus *exclude_classes*);
      3. a TreeAssertVisitor when ``options.evaluate_tree_assertions`` is set;
      4. pxd and utility-code injection;
      5. an error check;
      6. a DebugTransform when ``options.gdb_debug`` is set;
      7. code generation into *result*.
    """
    mode = 'py' if py else 'pyx'

    test_support = []
    if options.evaluate_tree_assertions:
        from ..TestUtils import TreeAssertVisitor
        test_support.append(TreeAssertVisitor())

    debug_transform = []
    if options.gdb_debug:
        from ..Debugger import DebugWriter # requires Py2.5+
        from .ParseTreeTransforms import DebugTransform
        context.gdb_debug_outputwriter = DebugWriter.CythonDebugWriter(
            options.output_dir)
        debug_transform.append(DebugTransform(context, options, result))

    stages = [parse_stage_factory(context)]
    stages.extend(create_pipeline(context, mode, exclude_classes=exclude_classes))
    stages.extend(test_support)
    stages.extend([inject_pxd_code_stage_factory(context),
                   inject_utility_code_stage_factory(context),
                   abort_on_errors])
    stages.extend(debug_transform)
    stages.append(generate_pyx_code_stage_factory(options, result))
    return stages

def create_pxd_pipeline(context, scope, module_name):
    """Assemble the pipeline for a .pxd file: parse, common transforms, extract.

    Per the original design note, running it yields a CCodeWriter holding the
    pxd's code together with the pxd scope.
    """
    from .CodeGeneration import ExtractPxdCode

    stages = [parse_pxd_stage_factory(context, scope, module_name)]
    stages += create_pipeline(context, 'pxd')
    stages.append(ExtractPxdCode())
    return stages

def create_py_pipeline(context, options, result):
    """Assemble the pipeline for a pure-Python module (create_pyx_pipeline in 'py' mode)."""
    pure_python = True
    return create_pyx_pipeline(context, options, result, py=pure_python)

def create_pyx_as_pxd_pipeline(context, result):
    """Build a pipeline that analyses a .pyx file as if it were a .pxd.

    It uses the pyx pipeline with a few transforms excluded, cut off right
    after AnalyseDeclarationsTransform. A final stage then marks the module's
    non-cinclude entries as defined in a pxd, mangles the cnames of non-extern
    entries whose cname equals their name, and returns an empty StatListNode
    together with the module scope.
    """
    from .ParseTreeTransforms import (AlignFunctionDefinitions,
        MarkClosureVisitor, WithTransform, AnalyseDeclarationsTransform)
    from .Optimize import ConstantFolding, FlattenInListTransform
    from .Nodes import StatListNode

    skipped = [
        AlignFunctionDefinitions,
        MarkClosureVisitor,
        ConstantFolding,
        FlattenInListTransform,
        WithTransform,
    ]
    stages = []
    for stage in create_pyx_pipeline(context, context.options, result,
                                     exclude_classes=skipped):
        stages.append(stage)
        # Declaration analysis is the last stage needed here.
        if isinstance(stage, AnalyseDeclarationsTransform):
            break

    def fake_pxd(root):
        for entry in root.scope.entries.values():
            if entry.in_cinclude:
                continue
            entry.defined_in_pxd = 1
            if entry.name == entry.cname and entry.visibility != 'extern':
                # Always mangle non-extern cimported entries.
                entry.cname = entry.scope.mangle(Naming.func_prefix, entry.name)
        return StatListNode(root.pos, stats=[]), root.scope

    stages.append(fake_pxd)
    return stages

289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308
def insert_into_pipeline(pipeline, transform, before=None, after=None):
    """
    Insert a new transform into the pipeline after or before an instance of
    the given class. e.g.

        pipeline = insert_into_pipeline(pipeline, transform,
                                        after=AnalyseDeclarationsTransform)

    Pass exactly one of *before* / *after* (a class). The first stage that is
    an instance of it is the anchor. If no stage matches (including an empty
    pipeline), *transform* is appended at the end.

    Returns a new list; *pipeline* itself is not modified.
    """
    assert before or after

    cls = before or after
    for index, stage in enumerate(pipeline):
        if isinstance(stage, cls):
            if after:
                index += 1
            break
    else:
        # No anchor found. Previously this raised UnboundLocalError on an
        # empty pipeline, and with before= it inserted ahead of the last stage.
        index = len(pipeline)

    return pipeline[:index] + [transform] + pipeline[index:]

309 310 311 312
#
# Running a pipeline
#

313
def run_pipeline(pipeline, source, printtree=True):
    """Feed *source* through the stages of *pipeline* in order.

    Each non-None stage is called with the previous stage's result. When
    *printtree* is false, PrintTree stages are skipped. With
    DebugFlags.debug_verbose_pipeline set, each phase and its run time are
    printed.

    Returns ``(error, data)``. ``error`` is None on success, or the
    CompileError / InternalError / AbortError that stopped the pipeline.
    ``data`` is the output of the last stage that completed.

    A CompileError is reported through Errors.report_error(). An InternalError
    is re-raised when no errors had been counted (Errors.num_errors == 0).
    """
    from .Visitor import PrintTree

    error = None
    data = source
    # Nested try: an InternalError/AbortError raised inside the CompileError
    # handler (e.g. by report_error) is still caught by the outer handlers.
    try:
        try:
            for phase in pipeline:
                if phase is None:
                    continue
                if DebugFlags.debug_verbose_pipeline:
                    t = time()
                    print("Entering pipeline phase %r" % phase)
                if not printtree and isinstance(phase, PrintTree):
                    continue
                data = phase(data)
                if DebugFlags.debug_verbose_pipeline:
                    print("    %.3f seconds" % (time() - t))
        except CompileError as err:
            # err is set
            Errors.report_error(err)
            error = err
    except InternalError as err:
        # Only raise if there was not an earlier error
        if Errors.num_errors == 0:
            raise
        error = err
    except AbortError as err:
        error = err
    return (error, data)