ParseTreeTransforms.py 91.1 KB
Newer Older
1 2

import cython
3 4 5 6 7 8 9 10 11 12 13
cython.declare(PyrexTypes=object, Naming=object, ExprNodes=object, Nodes=object,
               Options=object, UtilNodes=object, ModuleNode=object,
               LetNode=object, LetRefNode=object, TreeFragment=object,
               TemplateTransform=object, EncodedString=object,
               error=object, warning=object, copy=object)

import PyrexTypes
import Naming
import ExprNodes
import Nodes
import Options
14
import Builtin
15

16
from Cython.Compiler.Visitor import VisitorTransform, TreeVisitor
17
from Cython.Compiler.Visitor import CythonTransform, EnvTransform, ScopeTrackingTransform
18
from Cython.Compiler.ModuleNode import ModuleNode
19
from Cython.Compiler.UtilNodes import LetNode, LetRefNode, ResultRefNode
20
from Cython.Compiler.TreeFragment import TreeFragment, TemplateTransform
21
from Cython.Compiler.StringEncoding import EncodedString
22
from Cython.Compiler.Errors import error, warning, CompileError, InternalError
23

24
import copy
25

26 27 28 29 30 31 32 33 34 35 36 37

class NameNodeCollector(TreeVisitor):
    """Collect all NameNodes of a (sub-)tree in the ``name_nodes``
    attribute.

    NameNodes are recorded in visiting order.  The visitor does not
    descend into a NameNode once it has been recorded.
    """
    def __init__(self):
        super(NameNodeCollector, self).__init__()
        # NameNode instances found so far, in visiting order
        self.name_nodes = []

    def visit_NameNode(self, node):
        # Record the node; its children are deliberately not visited.
        self.name_nodes.append(node)

    def visit_Node(self, node):
        # Generic fallback: keep walking through all child attributes.
        self._visitchildren(node, None)

41

42
class SkipDeclarations(object):
    """
    Mixin that stops a transform from descending into declaration nodes.

    Variable and function declarations can often have a deep tree structure,
    and yet most transformations don't need to descend to this depth.

    Declaration nodes are removed after AnalyseDeclarationsTransform, so there
    is no need to use this for transformations after that point.

    Every handler below returns the node unchanged and does not visit
    its children.
    """
    def visit_CTypeDefNode(self, node):
        return node

    def visit_CVarDefNode(self, node):
        return node

    def visit_CDeclaratorNode(self, node):
        return node

    def visit_CBaseTypeNode(self, node):
        return node

    def visit_CEnumDefNode(self, node):
        return node

    def visit_CStructOrUnionDefNode(self, node):
        return node

68
class NormalizeTree(CythonTransform):
    """
    This transform fixes up a few things after parsing
    in order to make the parse tree more suitable for
    transforms.

    a) After parsing, blocks with only one statement will
    be represented by that statement, not by a StatListNode.
    When doing transforms this is annoying and inconsistent,
    as one cannot in general remove a statement in a consistent
    way and so on. This transform wraps any single statements
    in a StatListNode containing a single statement.

    b) The PassStatNode is a noop and serves no purpose beyond
    plugging such one-statement blocks; i.e., once parsed a
    "pass" can just as well be represented using an empty
    StatListNode. This means less special cases to worry about
    in subsequent transforms (one always checks to see if a
    StatListNode has no children to see if the block is empty).
    """

    def __init__(self, context):
        super(NormalizeTree, self).__init__(context)
        # True while the node being visited is a direct child of a
        # statement list (or of a node that acts as one).
        self.is_in_statlist = False
        # True while visiting anything below an expression node.
        self.is_in_expr = False

    def visit_ExprNode(self, node):
        # Remember the flag so that it is restored after the subtree.
        was_in_expr = self.is_in_expr
        self.is_in_expr = True
        self.visitchildren(node)
        self.is_in_expr = was_in_expr
        return node

    def visit_StatNode(self, node, is_listcontainer=False):
        """Visit a statement and wrap it in a StatListNode if it is
        neither inside a statement list nor inside an expression.

        ``is_listcontainer`` tells whether the children of ``node``
        should be treated as members of a statement list.
        """
        was_in_statlist = self.is_in_statlist
        self.is_in_statlist = is_listcontainer
        self.visitchildren(node)
        self.is_in_statlist = was_in_statlist
        if not self.is_in_statlist and not self.is_in_expr:
            return Nodes.StatListNode(pos=node.pos, stats=[node])
        else:
            return node

    def visit_StatListNode(self, node):
        # Note: the flag is reset to False afterwards, not restored.
        self.is_in_statlist = True
        self.visitchildren(node)
        self.is_in_statlist = False
        return node

    def visit_ParallelAssignmentNode(self, node):
        return self.visit_StatNode(node, True)

    def visit_CEnumDefNode(self, node):
        return self.visit_StatNode(node, True)

    def visit_CStructOrUnionDefNode(self, node):
        return self.visit_StatNode(node, True)

    # Eliminate PassStatNode
    def visit_PassStatNode(self, node):
        if not self.is_in_statlist:
            # A lone "pass" becomes an empty block.
            return Nodes.StatListNode(pos=node.pos, stats=[])
        else:
            # Inside a statement list the "pass" is simply dropped.
            return []

    def visit_CDeclaratorNode(self, node):
        # Declarators are left untouched and not descended into.
        return node
135

136

137 138 139
class PostParseError(CompileError):
    """Compile error raised for problems detected by the post-parse
    transforms in this module."""

# Error strings are checked by unit tests, so define them as named
# module-level constants instead of repeating the literals.
ERR_CDEF_INCLASS = 'Cannot assign default value to fields in cdef classes, structs or unions'
ERR_BUF_DEFAULTS = 'Invalid buffer defaults specification (see docs)'
ERR_INVALID_SPECIALATTR_TYPE = 'Special attributes must not have a type declared'
143
class PostParse(ScopeTrackingTransform):
    """
    Basic interpretation of the parse tree, as well as validity
    checking that can be done on a very basic level on the parse
    tree (while still not being a problem with the basic syntax,
    as such).

    Specifically:
    - Default values to cdef assignments are turned into single
    assignments following the declaration (everywhere but in class
    bodies, where they raise a compile error)

    - Interpret some node structures into Python runtime values.
    Some nodes take compile-time arguments (currently:
    TemplatedTypeNode[args] and __cythonbufferdefaults__ = {args}),
    which should be interpreted. This happens in a general way
    and other steps should be taken to ensure validity.

    Type arguments cannot be interpreted in this way.

    - For __cythonbufferdefaults__ the arguments are checked for
    validity.

    TemplatedTypeNode has its directives interpreted:
    Any first positional argument goes into the "dtype" attribute,
    any "ndim" keyword argument goes into the "ndim" attribute and
    so on. Also it is checked that the directive combination is valid.
    - __cythonbufferdefaults__ attributes are parsed and put into the
    type information.

    Note: Currently Parsing.py does a lot of interpretation and
    reorganization that can be refactored into this transform
    if a more pure Abstract Syntax Tree is wanted.
    """

    def __init__(self, context):
        super(PostParse, self).__init__(context)
        # Maps special attribute names that may appear in cdef class
        # bodies to the method that interprets their declaration.
        self.specialattribute_handlers = {
            '__cythonbufferdefaults__' : self.handle_bufferdefaults
        }

    def visit_ModuleNode(self, node):
        # Counters used to give each lambda / generator expression of
        # the module a distinct name.
        self.lambda_counter = 1
        self.genexpr_counter = 1
        return super(PostParse, self).visit_ModuleNode(node)

    def visit_LambdaNode(self, node):
        # unpack a lambda expression into the corresponding DefNode
        lambda_id = self.lambda_counter
        self.lambda_counter += 1
        node.lambda_name = EncodedString(u'lambda%d' % lambda_id)

        # A lambda whose expression yields becomes an expression
        # statement; otherwise the expression is the return value.
        collector = YieldNodeCollector()
        collector.visitchildren(node.result_expr)
        if collector.yields or isinstance(node.result_expr, ExprNodes.YieldExprNode):
            body = Nodes.ExprStatNode(
                node.result_expr.pos, expr=node.result_expr)
        else:
            body = Nodes.ReturnStatNode(
                node.result_expr.pos, value=node.result_expr)
        node.def_node = Nodes.DefNode(
            node.pos, name=node.name, lambda_name=node.lambda_name,
            args=node.args, star_arg=node.star_arg,
            starstar_arg=node.starstar_arg,
            body=body, doc=None)
        self.visitchildren(node)
        return node

    def visit_GeneratorExpressionNode(self, node):
        # unpack a generator expression into the corresponding DefNode
        genexpr_id = self.genexpr_counter
        self.genexpr_counter += 1
        node.genexpr_name = EncodedString(u'genexpr%d' % genexpr_id)

        node.def_node = Nodes.DefNode(node.pos, name=node.name,
                                      doc=None,
                                      args=[], star_arg=None,
                                      starstar_arg=None,
                                      body=node.loop)
        self.visitchildren(node)
        return node

    # cdef variables
    def handle_bufferdefaults(self, decl):
        """Store the dict literal assigned to __cythonbufferdefaults__
        (and its position) on the current scope node.

        Raises PostParseError if the default value is not a DictNode.
        """
        if not isinstance(decl.default, ExprNodes.DictNode):
            raise PostParseError(decl.pos, ERR_BUF_DEFAULTS)
        self.scope_node.buffer_defaults_node = decl.default
        self.scope_node.buffer_defaults_pos = decl.pos

    def visit_CVarDefNode(self, node):
        """Move default values of cdef declarations into separate
        assignments that follow the declaration.

        Returns the list ``[node, assignment, ...]``, or None if the
        declaration was invalid (the error is reported as non-fatal).
        """
        # This assumes only plain names and pointers are assignable on
        # declaration. Also, it makes use of the fact that a cdef decl
        # must appear before the first use, so we don't have to deal with
        # "i = 3; cdef int i = i" and can simply move the nodes around.
        try:
            self.visitchildren(node)
            stats = [node]
            newdecls = []
            for decl in node.declarators:
                # strip pointer declarators to find the declared name
                declbase = decl
                while isinstance(declbase, Nodes.CPtrDeclaratorNode):
                    declbase = declbase.base
                if isinstance(declbase, Nodes.CNameDeclaratorNode):
                    if declbase.default is not None:
                        if self.scope_type in ('cclass', 'pyclass', 'struct'):
                            if isinstance(self.scope_node, Nodes.CClassDefNode):
                                handler = self.specialattribute_handlers.get(decl.name)
                                if handler:
                                    if decl is not declbase:
                                        raise PostParseError(decl.pos, ERR_INVALID_SPECIALATTR_TYPE)
                                    handler(decl)
                                    continue # Remove declaration
                            raise PostParseError(decl.pos, ERR_CDEF_INCLASS)
                        first_assignment = self.scope_type != 'module'
                        stats.append(Nodes.SingleAssignmentNode(node.pos,
                            lhs=ExprNodes.NameNode(node.pos, name=declbase.name),
                            rhs=declbase.default, first=first_assignment))
                        declbase.default = None
                newdecls.append(decl)
            node.declarators = newdecls
            return stats
        except PostParseError as e:
            # An error in a cdef clause is ok, simply remove the declaration
            # and try to move on to report more errors
            self.context.nonfatal_error(e)
            return None

    # Split parallel assignments (a,b = b,a) into separate partial
    # assignments that are executed rhs-first using temps.  This
    # restructuring must be applied before type analysis so that known
    # types on rhs and lhs can be matched directly.  It is required in
    # the case that the types cannot be coerced to a Python type in
    # order to assign from a tuple.

    def visit_SingleAssignmentNode(self, node):
        self.visitchildren(node)
        return self._visit_assignment_node(node, [node.lhs, node.rhs])

    def visit_CascadedAssignmentNode(self, node):
        self.visitchildren(node)
        return self._visit_assignment_node(node, node.lhs_list + [node.rhs])

    def _visit_assignment_node(self, node, expr_list):
        """Flatten parallel assignments into separate single
        assignments or cascaded assignments.

        ``expr_list`` holds all LHS expressions followed by the RHS.
        The original node is returned unchanged when fewer than two
        of these are sequence constructors.
        """
        if sum([ 1 for expr in expr_list if expr.is_sequence_constructor ]) < 2:
            # no parallel assignments => nothing to do
            return node

        expr_list_list = []
        flatten_parallel_assignments(expr_list, expr_list_list)
        temp_refs = []
        eliminate_rhs_duplicates(expr_list_list, temp_refs)

        # build one assignment node per flattened [lhs..., rhs] list
        nodes = []
        for flat_exprs in expr_list_list:
            lhs_list = flat_exprs[:-1]
            rhs = flat_exprs[-1]
            if len(lhs_list) == 1:
                assignment = Nodes.SingleAssignmentNode(rhs.pos,
                    lhs = lhs_list[0], rhs = rhs)
            else:
                assignment = Nodes.CascadedAssignmentNode(rhs.pos,
                    lhs_list = lhs_list, rhs = rhs)
            nodes.append(assignment)

        if len(nodes) == 1:
            assign_node = nodes[0]
        else:
            assign_node = Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes)

        if temp_refs:
            duplicates_and_temps = [ (temp.expression, temp)
                                     for temp in temp_refs ]
            sort_common_subsequences(duplicates_and_temps)
            # Wrap in reverse order so that the first temp in the
            # sorted list ends up as the outermost LetNode.
            for _, temp_ref in duplicates_and_temps[::-1]:
                assign_node = LetNode(temp_ref, assign_node)

        return assign_node

    def _flatten_sequence(self, seq, result):
        """Append the args of ``seq`` to ``result``, recursively
        expanding nested sequence constructors.  Returns ``result``."""
        for arg in seq.args:
            if arg.is_sequence_constructor:
                self._flatten_sequence(arg, result)
            else:
                result.append(arg)
        return result

    def visit_DelStatNode(self, node):
        # "del (a, (b, c))" is handled as "del a, b, c"
        self.visitchildren(node)
        node.args = self._flatten_sequence(node, [])
        return node


337 338 339 340 341 342
def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence):
    """Replace rhs items by LetRefNodes if they appear more than once.
    Creates a sequence of LetRefNodes that set up the required temps
    and appends them to ref_node_sequence.  The input list is modified
    in-place.

    Each entry of ``expr_list_list`` is a list whose last item is the
    rhs expression.  Nodes are used as set members and dict keys, so
    whether two occurrences count as the same depends on how the node
    class hashes and compares.
    """
    seen_nodes = set()
    ref_nodes = {}
    def find_duplicates(node):
        if node.is_literal or node.is_name:
            # no need to replace those; can't include attributes here
            # as their access is not necessarily side-effect free
            return
        if node in seen_nodes:
            if node not in ref_nodes:
                ref_node = LetRefNode(node)
                ref_nodes[node] = ref_node
                ref_node_sequence.append(ref_node)
        else:
            seen_nodes.add(node)
            if node.is_sequence_constructor:
                for item in node.args:
                    find_duplicates(item)

    for expr_list in expr_list_list:
        find_duplicates(expr_list[-1])
    if not ref_nodes:
        # nothing occurs twice => leave the input untouched
        return

    def substitute_nodes(node):
        if node in ref_nodes:
            return ref_nodes[node]
        elif node.is_sequence_constructor:
            node.args = list(map(substitute_nodes, node.args))
        return node

    # replace nodes inside of the common subexpressions
    for node in ref_nodes:
        if node.is_sequence_constructor:
            node.args = list(map(substitute_nodes, node.args))

    # replace common subexpressions on all rhs items
    for expr_list in expr_list_list:
        expr_list[-1] = substitute_nodes(expr_list[-1])

def sort_common_subsequences(items):
    """Sort items/subsequences so that all items and subsequences that
    an item contains appear before the item itself.  This is needed
    because each rhs item must only be evaluated once, so its value
    must be evaluated first and then reused when packing sequences
    that contain it.

    This implies a partial order, and the sort must be stable to
    preserve the original order as much as possible, so we use a
    simple insertion sort (which is very fast for short sequences, the
    normal case in practice).

    ``items`` is a list of (expression, reference node) pairs and is
    reordered in place; nothing is returned.
    """
    def contains(seq, x):
        # identity search through arbitrarily nested sequences
        for item in seq:
            if item is x:
                return True
            elif item.is_sequence_constructor and contains(item.args, x):
                return True
        return False
    def lower_than(a,b):
        return b.is_sequence_constructor and contains(b.args, a)

    for pos, item in enumerate(items):
        key = item[1] # the reference node (LetRefNode) which has already been injected into the sequences
        # find the left-most earlier expression that contains the key
        new_pos = pos
        for i in range(pos-1, -1, -1):
            if lower_than(key, items[i][0]):
                new_pos = i
        if new_pos != pos:
            # shift the entries in between one slot to the right
            for i in range(pos, new_pos, -1):
                items[i] = items[i-1]
            items[new_pos] = item
415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430

def flatten_parallel_assignments(input, output):
    """Split a (possibly cascaded) assignment between sequence
    constructors into assignments between their individual elements.

    ``input`` is a list of expression nodes: all LHSs followed by the
    RHS.  The resulting [lhs..., rhs] lists are appended to ``output``.
    Size mismatches are reported through ``error()`` and the offending
    [lhs, rhs] pair is passed through unsplit.
    """
    #  The input is a list of expression nodes, representing the LHSs
    #  and RHS of one (possibly cascaded) assignment statement.  For
    #  sequence constructors, rearranges the matching parts of both
    #  sides into a list of equivalent assignments between the
    #  individual elements.  This transformation is applied
    #  recursively, so that nested structures get matched as well.
    rhs = input[-1]
    if not rhs.is_sequence_constructor or not sum([lhs.is_sequence_constructor for lhs in input[:-1]]):
        # nothing to split: pass the assignment through as it is
        output.append(input)
        return

    # LHS targets that receive the complete RHS value
    complete_assignments = []

    rhs_args = rhs.args
    rhs_size = len(rhs_args)
    # one list of targets per RHS item
    lhs_targets = [ [] for _ in range(rhs_size) ]
    starred_assignments = []
    for lhs in input[:-1]:
        if not lhs.is_sequence_constructor:
            if lhs.is_starred:
                error(lhs.pos, "starred assignment target must be in a list or tuple")
            complete_assignments.append(lhs)
            continue
        lhs_size = len(lhs.args)
        starred_targets = sum([1 for expr in lhs.args if expr.is_starred])
        if starred_targets > 1:
            error(lhs.pos, "more than 1 starred expression in assignment")
            output.append([lhs,rhs])
            continue
        elif lhs_size - starred_targets > rhs_size:
            error(lhs.pos, "need more than %d value%s to unpack"
                  % (rhs_size, (rhs_size != 1) and 's' or ''))
            output.append([lhs,rhs])
            continue
        elif starred_targets:
            map_starred_assignment(lhs_targets, starred_assignments,
                                   lhs.args, rhs_args)
        elif lhs_size < rhs_size:
            error(lhs.pos, "too many values to unpack (expected %d, got %d)"
                  % (lhs_size, rhs_size))
            output.append([lhs,rhs])
            continue
        else:
            for targets, expr in zip(lhs_targets, lhs.args):
                targets.append(expr)

    if complete_assignments:
        complete_assignments.append(rhs)
        output.append(complete_assignments)

    # recursively flatten partial assignments
    for cascade, rhs_item in zip(lhs_targets, rhs_args):
        if cascade:
            cascade.append(rhs_item)
            flatten_parallel_assignments(cascade, output)

    # recursively flatten starred assignments
    for cascade in starred_assignments:
        if cascade[0].is_sequence_constructor:
            flatten_parallel_assignments(cascade, output)
        else:
            output.append(cascade)

def map_starred_assignment(lhs_targets, starred_assignments, lhs_args, rhs_args):
    """Distribute the targets of an LHS containing one starred item.

    Raises InternalError if ``lhs_args`` contains no starred item
    among the positions covered by ``lhs_targets``.
    """
    # Appends the fixed-position LHS targets to the target list that
    # appear left and right of the starred argument.
    #
    # The starred_assignments list receives a new tuple
    # (lhs_target, rhs_values_list) that maps the remaining arguments
    # (those that match the starred target) to a list.

    # left side of the starred target
    for i, (targets, expr) in enumerate(zip(lhs_targets, lhs_args)):
        if expr.is_starred:
            starred = i
            lhs_remaining = len(lhs_args) - i - 1
            break
        targets.append(expr)
    else:
        raise InternalError("no starred arg found when splitting starred assignment")

    # right side of the starred target
    # (if lhs_remaining is 0, the second zip() argument is empty, so
    # the loop does not run even though the slice [-0:] is the full list)
    for targets, expr in zip(lhs_targets[-lhs_remaining:],
                             lhs_args[starred + 1:]):
        targets.append(expr)

    # the starred target itself, must be assigned a (potentially empty) list
    target = lhs_args[starred].target # unpack starred node
    starred_rhs = rhs_args[starred:]
    if lhs_remaining:
        starred_rhs = starred_rhs[:-lhs_remaining]
    if starred_rhs:
        pos = starred_rhs[0].pos
    else:
        pos = target.pos
    starred_assignments.append([
        target, ExprNodes.ListNode(pos=pos, args=starred_rhs)])


515
class PxdPostParse(CythonTransform, SkipDeclarations):
    """
    Basic interpretation/validity checking that should only be
    done on pxd trees.

    A lot of this checking currently happens in the parser; but
    what is listed below happens here.

    - "def" functions are let through only if they fill the
    getbuffer/releasebuffer slots

    - cdef functions are let through only if they are on the
    top level and are declared "inline"
    """
    ERR_INLINE_ONLY = "function definition in pxd file must be declared 'cdef inline'"
    ERR_NOGO_WITH_INLINE = "inline function definition in pxd file cannot be '%s'"

    def __call__(self, node):
        # 'pxd' marks the top level of the pxd file
        self.scope_type = 'pxd'
        return super(PxdPostParse, self).__call__(node)

    def visit_CClassDefNode(self, node):
        # Track that we are inside a cdef class while visiting its body.
        old = self.scope_type
        self.scope_type = 'cclass'
        self.visitchildren(node)
        self.scope_type = old
        return node

    def visit_FuncDefNode(self, node):
        """Return the function node if it is allowed in a pxd file;
        otherwise report a non-fatal error and return None."""
        # FuncDefNode always come with an implementation (without
        # an imp they are CVarDefNodes..)
        err = self.ERR_INLINE_ONLY

        if (isinstance(node, Nodes.DefNode) and self.scope_type == 'cclass'
            and node.name in ('__getbuffer__', '__releasebuffer__')):
            err = None # allow these slots

        if isinstance(node, Nodes.CFuncDefNode):
            if u'inline' in node.modifiers and self.scope_type == 'pxd':
                node.inline_in_pxd = True
                if node.visibility != 'private':
                    err = self.ERR_NOGO_WITH_INLINE % node.visibility
                elif node.api:
                    err = self.ERR_NOGO_WITH_INLINE % 'api'
                else:
                    err = None # allow inline function
            else:
                err = self.ERR_INLINE_ONLY

        if err:
            self.context.nonfatal_error(PostParseError(node.pos, err))
            return None
        else:
            return node
569

570
class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
571
    """
572
    After parsing, directives can be stored in a number of places:
573 574
    - #cython-comments at the top of the file (stored in ModuleNode)
    - Command-line arguments overriding these
575 576
    - @cython.directivename decorators
    - with cython.directivename: statements
577

578
    This transform is responsible for interpreting these various sources
579
    and store the directive in two ways:
580 581 582 583 584 585 586 587 588 589 590
    - Set the directives attribute of the ModuleNode for global directives.
    - Use a CompilerDirectivesNode to override directives for a subtree.

    (The first one is primarily to not have to modify with the tree
    structure, so that ModuleNode stay on top.)

    The directives are stored in dictionaries from name to value in effect.
    Each such dictionary is always filled in for all possible directives,
    using default values where no value is given by the user.

    The available directives are controlled in Options.py.
591 592 593

    Note that we have to run this prior to analysis, and so some minor
    duplication of functionality has to occur: We manually track cimports
594
    and which names the "cython" module may have been imported to.
595
    """
596
    unop_method_nodes = {
597
        'typeof': ExprNodes.TypeofNode,
598

599 600 601 602 603 604
        'operator.address': ExprNodes.AmpersandNode,
        'operator.dereference': ExprNodes.DereferenceNode,
        'operator.preincrement' : ExprNodes.inc_dec_constructor(True, '++'),
        'operator.predecrement' : ExprNodes.inc_dec_constructor(True, '--'),
        'operator.postincrement': ExprNodes.inc_dec_constructor(False, '++'),
        'operator.postdecrement': ExprNodes.inc_dec_constructor(False, '--'),
605

606
        # For backwards compatibility.
607
        'address': ExprNodes.AmpersandNode,
608
    }
Robert Bradshaw's avatar
Robert Bradshaw committed
609 610

    binop_method_nodes = {
611
        'operator.comma'        : ExprNodes.c_binop_constructor(','),
Robert Bradshaw's avatar
Robert Bradshaw committed
612
    }
613

Stefan Behnel's avatar
Stefan Behnel committed
614
    special_methods = cython.set(['declare', 'union', 'struct', 'typedef', 'sizeof',
Mark Florisson's avatar
Mark Florisson committed
615
                                  'cast', 'pointer', 'compiled', 'NULL', 'parallel'])
Stefan Behnel's avatar
Stefan Behnel committed
616
    special_methods.update(unop_method_nodes.keys())
617

Mark Florisson's avatar
Mark Florisson committed
618 619 620 621 622 623 624
    valid_parallel_directives = cython.set([
        "parallel",
        "prange",
        "threadid",
#        "threadsavailable",
    ])

625
    def __init__(self, context, compilation_directive_defaults):
        """Set up the transform.

        ``compilation_directive_defaults`` maps directive names to
        values; the values are deep-copied and the keys converted with
        ``unicode()``.
        """
        super(InterpretCompilerDirectives, self).__init__(context)
        # NOTE(review): ``unicode`` is a Python 2 builtin - confirm how
        # this is meant to run on Python 3.
        self.compilation_directive_defaults = {}
        for key, value in compilation_directive_defaults.items():
            self.compilation_directive_defaults[unicode(key)] = copy.deepcopy(value)
        # names under which the "cython" module has been (c)imported
        self.cython_module_names = cython.set()
        # local name -> directive name (relative to the cython module)
        self.directive_names = {}
        # local name -> qualified "cython.parallel..." name
        self.parallel_directives = {}
633

634
    def check_directive_scope(self, pos, directive, scope):
        """Check that ``directive`` may be used in ``scope``.

        Returns False (after reporting a non-fatal error) if
        Options.directive_scopes lists legal scopes for the directive
        and ``scope`` is not among them.  Otherwise returns True; an
        unknown directive is reported through ``error()`` but still
        yields True.
        """
        legal_scopes = Options.directive_scopes.get(directive, None)
        if legal_scopes and scope not in legal_scopes:
            self.context.nonfatal_error(PostParseError(pos, 'The %s compiler directive '
                                        'is not allowed in %s scope' % (directive, scope)))
            return False
        else:
            if directive not in Options.directive_defaults:
                error(pos, "Invalid directive: '%s'." % (directive,))
            return True
644

645
    # Set up processing and handle the cython: comments.
646
    def visit_ModuleNode(self, node):
        """Compute the module-level directives and process the tree.

        Directive comments that are not legal at module scope are
        removed.  The remaining ones are layered over the compilation
        defaults, which in turn are layered over Options.directive_defaults.
        """
        # Iterate over a snapshot of the keys: entries are deleted from
        # the dict inside the loop, which must not happen while
        # iterating the dict itself.
        for key in list(node.directive_comments):
            if not self.check_directive_scope(node.pos, key, 'module'):
                # NOTE(review): wrong_scope_error() is not defined in the
                # visible part of this class - confirm it exists.
                self.wrong_scope_error(node.pos, key, 'module')
                del node.directive_comments[key]

        directives = copy.deepcopy(Options.directive_defaults)
        directives.update(copy.deepcopy(self.compilation_directive_defaults))
        directives.update(node.directive_comments)
        self.directives = directives
        node.directives = directives
        # Shared dict: it is filled in while the children are visited.
        node.parallel_directives = self.parallel_directives
        self.visitchildren(node)
        node.cython_module_names = self.cython_module_names
        return node

662 663 664 665 666 667 668
    # The following four functions track imports and cimports that
    # begin with "cython"
    def is_cython_directive(self, name):
        """Tell whether ``name`` is a directive, a special method of the
        cython module, or something PyrexTypes.parse_basic_type() accepts.

        Returns True for the first two cases and otherwise the result
        of parse_basic_type() itself.
        """
        if name in Options.directive_types:
            return True
        if name in self.special_methods:
            return True
        return PyrexTypes.parse_basic_type(name)

Mark Florisson's avatar
Mark Florisson committed
669
    def is_parallel_directive(self, full_name, pos):
        """
        Checks to see if full_name (e.g. cython.parallel.prange) is a valid
        parallel directive. If it is a star import it also updates the
        parallel_directives.

        Returns True if full_name is "cython.parallel" or lies below
        it, False otherwise.  A name below cython.parallel that is not
        a known directive is reported through ``error()`` but still
        yields True.
        """
        result = (full_name + ".").startswith("cython.parallel.")

        if result:
            directive = full_name.split('.')
            if full_name == u"cython.parallel":
                # The module itself ("from cython cimport parallel") is
                # valid; the caller registers the local name.
                pass
            elif full_name == u"cython.parallel.*":
                for name in self.valid_parallel_directives:
                    self.parallel_directives[name] = u"cython.parallel.%s" % name
            elif (len(directive) != 3 or
                  directive[-1] not in self.valid_parallel_directives):
                error(pos, "No such directive: %s" % full_name)

        return result

688 689
    def visit_CImportStatNode(self, node):
        """Track "cimport cython" and "cimport cython.xyz" statements.

        A plain "cimport cython" is recorded and kept in the tree.
        A cimport of anything below "cython." is recorded and removed
        from the tree (None is returned).  All other cimports are
        returned unchanged.
        """
        if node.module_name == u"cython":
            self.cython_module_names.add(node.as_name or u"cython")
        elif node.module_name.startswith(u"cython."):
            if node.module_name.startswith(u"cython.parallel."):
                error(node.pos, node.module_name + " is not a module")
            if node.module_name == u"cython.parallel":
                if node.as_name and node.as_name != u"cython":
                    self.parallel_directives[node.as_name] = node.module_name
                else:
                    self.cython_module_names.add(u"cython")
                    self.parallel_directives[
                                    u"cython.parallel"] = node.module_name
            elif node.as_name:
                # strip the leading "cython." (7 characters)
                self.directive_names[node.as_name] = node.module_name[7:]
            else:
                self.cython_module_names.add(u"cython")
            # if this cimport was a compiler directive, we don't
            # want to leave the cimport node sitting in the tree
            return None
        return node
709

710
    def visit_FromCImportStatNode(self, node):
        """
        Track "from cython[.sub] cimport a, b as c, ...".

        Names that are parallel directives or compiler directives are
        recorded and removed from the statement.  Returns None when nothing
        is left to cimport, otherwise the node with the remaining names.
        Statements for other modules are returned untouched.
        """
        module_name = node.module_name
        if module_name != u"cython" and not module_name.startswith(u"cython."):
            return node

        # "" for the cython module itself, "sub." for "cython.sub"
        prefix = (module_name + u".")[7:]
        remaining = []

        for pos, name, as_name, kind in node.imported_names:
            full_name = prefix + name
            qualified_name = u"cython." + full_name

            if self.is_parallel_directive(qualified_name, node.pos):
                # from cython cimport parallel, or
                # from cython.parallel cimport parallel, prange, ...
                self.parallel_directives[as_name or name] = qualified_name
                continue

            if not self.is_cython_directive(full_name):
                remaining.append((pos, name, as_name, kind))
                continue

            if as_name is None:
                as_name = full_name
            self.directive_names[as_name] = full_name
            if kind is not None:
                self.context.nonfatal_error(PostParseError(pos,
                    "Compiler directive imports must be plain imports"))

        if not remaining:
            return None

        node.imported_names = remaining
        return node
740

Robert Bradshaw's avatar
Robert Bradshaw committed
741
    def visit_FromImportStatNode(self, node):
        """
        Track "from cython[.sub] import a, b, ..." (the Python-level
        counterpart of visit_FromCImportStatNode).

        Imported names that are parallel or compiler directives are
        recorded under the name of their target node and removed from the
        statement.  Returns None when no item is left, otherwise the node.
        """
        module_name = node.module.module_name.value
        if module_name != u"cython" and not module_name.startswith(u"cython."):
            return node

        prefix = (module_name + u".")[7:]
        kept_items = []
        for name, name_node in node.items:
            full_name = prefix + name
            qualified_name = u"cython." + full_name
            if self.is_parallel_directive(qualified_name, node.pos):
                self.parallel_directives[name_node.name] = qualified_name
            elif self.is_cython_directive(full_name):
                self.directive_names[name_node.name] = full_name
            else:
                kept_items.append((name, name_node))

        if not kept_items:
            return None
        node.items = kept_items
        return node

760
    def visit_SingleAssignmentNode(self, node):
        """
        Treat "import cython [as X]" / "import cython.parallel [as X]" like
        the corresponding cimport.

        Such an assignment (its right hand side is an ImportNode) is turned
        into a CImportStatNode and handed to visit_CImportStatNode, whose
        result is returned.  Imports of other modules are returned as they
        are; for all other assignments only the children are visited.
        """
        if not isinstance(node.rhs, ExprNodes.ImportNode):
            self.visitchildren(node)
            return node

        module_name = node.rhs.module_name.value
        if (module_name != u"cython" and
                not (module_name + u".").startswith(u"cython.parallel.")):
            return node

        cimport_node = Nodes.CImportStatNode(node.pos,
                                             module_name = module_name,
                                             as_name = node.lhs.name)
        return self.visit_CImportStatNode(cimport_node)
779

780 781 782
    def visit_NameNode(self, node):
        """
        Annotate a name: names bound to the cython module get
        is_cython_module set, all others get the directive they alias (or
        None) stored in cython_attribute.
        """
        name = node.name
        if name not in self.cython_module_names:
            node.cython_attribute = self.directive_names.get(name)
        else:
            node.is_cython_module = True
        return node
786

787
    def try_to_parse_directives(self, node):
        """
        Interpret ``node`` as the contents of a directive (in a with
        statement or decorator).

        Returns a list of (directivename, value) pairs if it is one,
        otherwise None.  A directive type other than bool that is used
        without a call raises PostParseError.
        """
        if isinstance(node, ExprNodes.CallNode):
            callee = node.function
            self.visit(callee)
            optname = callee.as_cython_attribute()
            if not optname:
                return None
            directivetype = Options.directive_types.get(optname)
            if not directivetype:
                return None

            args, kwds = node.explicit_args_kwds()
            parsed = []
            if kwds is not None and directivetype is not dict:
                # Keywords that name a sub-directive ("optname.key") are
                # parsed on their own; the rest stays in kwds.
                leftover = []
                for item in kwds.key_value_pairs:
                    key, value = item
                    sub_optname = "%s.%s" % (optname, key.value)
                    if Options.directive_types.get(sub_optname):
                        parsed.append(self.try_to_parse_directive(
                            sub_optname, [value], None, item.pos))
                    else:
                        leftover.append(item)
                if leftover:
                    kwds.key_value_pairs = leftover
                else:
                    kwds = None
                if parsed and not kwds and not args:
                    # only sub-directives were given
                    return parsed
            parsed.append(self.try_to_parse_directive(
                optname, args, kwds, callee.pos))
            return parsed

        if isinstance(node, (ExprNodes.AttributeNode, ExprNodes.NameNode)):
            self.visit(node)
            optname = node.as_cython_attribute()
            if optname:
                directivetype = Options.directive_types.get(optname)
                if directivetype is bool:
                    return [(optname, True)]
                if directivetype is None:
                    return [(optname, None)]
                raise PostParseError(
                    node.pos, "The '%s' directive should be used as a function call." % optname)

        return None
829

830 831
    def try_to_parse_directive(self, optname, args, kwds, pos):
        """
        Convert the arguments of a single directive call into its value.

        optname -- name of the directive (looked up in Options.directive_types)
        args    -- list of positional argument nodes
        kwds    -- dict node holding the keyword arguments, or None
        pos     -- position used for error reporting

        Returns an (optname, value) tuple.  A single None argument selects
        the default from Options.directive_defaults.  Raises PostParseError
        if the arguments do not fit the directive type.
        """
        directivetype = Options.directive_types.get(optname)
        if len(args) == 1 and isinstance(args[0], ExprNodes.NoneNode):
            return optname, Options.directive_defaults[optname]
        elif directivetype is bool:
            if kwds is not None or len(args) != 1 or not isinstance(args[0], ExprNodes.BoolNode):
                raise PostParseError(pos,
                    'The %s directive takes one compile-time boolean argument' % optname)
            return (optname, args[0].value)
        elif directivetype is str:
            if kwds is not None or len(args) != 1 or not isinstance(args[0], (ExprNodes.StringNode,
                                                                              ExprNodes.UnicodeNode)):
                raise PostParseError(pos,
                    'The %s directive takes one compile-time string argument' % optname)
            return (optname, str(args[0].value))
        elif directivetype is dict:
            if len(args) != 0:
                raise PostParseError(pos,
                    'The %s directive takes no prepositional arguments' % optname)
            if kwds is None:
                # Called without any keyword arguments: there is no dict
                # node to read from, the value is simply empty.
                return optname, {}
            return optname, dict([(key.value, value) for key, value in kwds.key_value_pairs])
        elif directivetype is list:
            # Count the keyword items rather than calling len() on the
            # dict node itself.
            if kwds is not None and len(kwds.key_value_pairs) != 0:
                raise PostParseError(pos,
                    'The %s directive takes no keyword arguments' % optname)
            return optname, [ str(arg.value) for arg in args ]
        else:
            assert False

858 859 860 861 862
    def visit_with_directives(self, body, directives):
        """
        Visit ``body`` (a StatListNode) with ``directives`` layered on top
        of the currently active ones and return it wrapped in a
        CompilerDirectivesNode carrying the merged directives.  The
        previously active directives are put back afterwards.
        """
        saved = self.directives
        merged = copy.copy(saved)
        merged.update(directives)

        self.directives = merged
        assert isinstance(body, Nodes.StatListNode), body
        visited_body = self.visit_Node(body)
        wrapper = Nodes.CompilerDirectivesNode(pos=visited_body.pos,
                                               body=visited_body,
                                               directives=merged)
        self.directives = saved
        return wrapper
869

870
    # Handle decorators
871
    def visit_FuncDefNode(self, node):
        """
        Apply directive decorators of a function: when there are any, the
        function is wrapped in a statement list and visited with those
        directives in effect; otherwise it is visited normally.
        """
        directives = self._extract_directives(node, 'function')
        if directives:
            wrapped = Nodes.StatListNode(node.pos, stats=[node])
            return self.visit_with_directives(wrapped, directives)
        return self.visit_Node(node)

    def visit_CVarDefNode(self, node):
        """
        Handle directive decorators on a cdef declaration.  Only "locals"
        is accepted, its value is stored in node.directive_locals; every
        other parsed directive is reported as a non-fatal error.  The node
        itself is always returned.
        """
        if node.decorators:
            for dec in node.decorators:
                parsed = self.try_to_parse_directives(dec.decorator)
                for directive in parsed or ():
                    if directive is None or directive[0] != u'locals':
                        self.context.nonfatal_error(PostParseError(dec.pos,
                            "Cdef functions can only take cython.locals() decorator."))
                    else:
                        node.directive_locals = directive[1]
        return node

    def visit_CClassDefNode(self, node):
        """
        Apply directive decorators of a cdef class: when there are any, the
        class is wrapped in a statement list and visited with those
        directives in effect; otherwise it is visited normally.
        """
        directives = self._extract_directives(node, 'cclass')
        if directives:
            wrapped = Nodes.StatListNode(node.pos, stats=[node])
            return self.visit_with_directives(wrapped, directives)
        return self.visit_Node(node)

897 898 899 900
    def visit_PyClassDefNode(self, node):
        """
        Apply directive decorators of a Python class: when there are any,
        the class is wrapped in a statement list and visited with those
        directives in effect; otherwise it is visited normally.
        """
        directives = self._extract_directives(node, 'class')
        if directives:
            wrapped = Nodes.StatListNode(node.pos, stats=[node])
            return self.visit_with_directives(wrapped, directives)
        return self.visit_Node(node)

904 905 906 907 908 909 910 911 912 913 914 915
    def _extract_directives(self, node, scope_name):
        if not node.decorators:
            return {}
        # Split the decorators into two lists -- real decorators and directives
        directives = []
        realdecs = []
        for dec in node.decorators:
            new_directives = self.try_to_parse_directives(dec.decorator)
            if new_directives is not None:
                for directive in new_directives:
                    if self.check_directive_scope(node.pos, directive[0], scope_name):
                        directives.append(directive)
916
            else:
917
                realdecs.append(dec)
918
        if realdecs and isinstance(node, (Nodes.CFuncDefNode, Nodes.CClassDefNode)):
919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934
            raise PostParseError(realdecs[0].pos, "Cdef functions/classes cannot take arbitrary decorators.")
        else:
            node.decorators = realdecs
        # merge or override repeated directives
        optdict = {}
        directives.reverse() # Decorators coming first take precedence
        for directive in directives:
            name, value = directive
            if name in optdict:
                old_value = optdict[name]
                # keywords and arg lists can be merged, everything
                # else overrides completely
                if isinstance(old_value, dict):
                    old_value.update(value)
                elif isinstance(old_value, list):
                    old_value.extend(value)
935 936
                else:
                    optdict[name] = value
937 938 939 940
            else:
                optdict[name] = value
        return optdict

941 942
    # Handle with statements
    def visit_WithStatNode(self, node):
        """
        Handle "with <directive>:" statements.

        "nogil" / "gil" turn the statement into a GILStatNode, which is
        then visited.  Other directives allowed in a with statement are
        collected and the body is visited with them in effect.  A directive
        with statement that has an "as" target is reported as a non-fatal
        error.  Ordinary with statements are visited normally.
        """
        collected = {}
        for directive in self.try_to_parse_directives(node.manager) or []:
            if directive is None:
                continue
            if node.target is not None:
                self.context.nonfatal_error(
                    PostParseError(node.pos, "Compiler directive with statements cannot contain 'as'"))
                continue
            name, value = directive
            if name in ('nogil', 'gil'):
                # special case: in pure mode, "with nogil" spells "with cython.nogil"
                gil_node = Nodes.GILStatNode(node.pos, state = name, body = node.body)
                return self.visit_Node(gil_node)
            if self.check_directive_scope(node.pos, name, 'with statement'):
                collected[name] = value

        if collected:
            return self.visit_with_directives(node.body, collected)
        return self.visit_Node(node)
960

Mark Florisson's avatar
Mark Florisson committed
961 962 963 964 965 966 967

class ParallelRangeTransform(CythonTransform, SkipDeclarations):
    """
    Replace uses of the cython.parallel directives by dedicated nodes. The
    parallel_directives come from the module node, set there by
    InterpretCompilerDirectives.

        with nogil, cython.parallel.parallel(): -> ParallelWithBlockNode
            print cython.parallel.threadid()    -> ParallelThreadIdNode
            for i in cython.parallel.prange(...):  -> ParallelRangeNode
                ...

    cython.parallel.threadsavailable() has no node assigned at the moment
    (its entry in directive_to_node is commented out).
    """

    # a list of names, maps 'cython.parallel.prange' in the code to
    # ['cython', 'parallel', 'prange']
    parallel_directive = None

    # Indicates whether a namenode in an expression is the cython module
    namenode_is_cython_module = False

    # Keep track of whether we are the context manager of a 'with' statement
    in_context_manager_section = False

    # Keep track of whether we are in a parallel range section
    in_prange = False

    # One of 'prange' or 'with parallel'. This is used to disallow closely
    # nested 'with parallel:' blocks
    state = None

    # Fully qualified directive name -> node class that replaces its use
    directive_to_node = {
        u"cython.parallel.parallel": Nodes.ParallelWithBlockNode,
        # u"cython.parallel.threadsavailable": ExprNodes.ParallelThreadsAvailableNode,
        u"cython.parallel.threadid": ExprNodes.ParallelThreadIdNode,
        u"cython.parallel.prange": Nodes.ParallelRangeNode,
    }

    def node_is_parallel_directive(self, node):
        "True for names bound to a parallel directive or to the cython module"
        if node.name in self.parallel_directives:
            return True
        return node.is_cython_module

    def get_directive_class_node(self, node):
        """
        Work out which parallel directive has been collected in
        self.parallel_directive and return the Node class registered for
        it in directive_to_node.

        E.g. for a cython.parallel.prange() call we return ParallelRangeNode

        An unknown directive is reported as an error at node.pos and None
        is returned.  The collected directive state is reset in any case.
        """
        parts = self.parallel_directive
        if self.namenode_is_cython_module:
            directive = '.'.join(parts)
        else:
            # The first component is a local name, expand it to the
            # qualified name it was imported from.
            qualified_head = self.parallel_directives[parts[0]]
            directive = '%s.%s' % (qualified_head, '.'.join(parts[1:]))
            directive = directive.rstrip('.')

        node_class = self.directive_to_node.get(directive)
        if node_class is None:
            error(node.pos, "Invalid directive: %s" % directive)

        self.namenode_is_cython_module = False
        self.parallel_directive = None

        return node_class

    def visit_ModuleNode(self, node):
        """
        Take over the parallel directives imported by the module and walk
        the tree.  Without any imported parallel directive there is nothing
        to do and the module is returned unvisited.
        """
        if not node.parallel_directives:
            # No parallel directives were imported, so they can't be used :)
            return node

        self.parallel_directives = node.parallel_directives
        self.assignment_stack = []
        return self.visit_Node(node)

    def visit_NameNode(self, node):
        "Start collecting a directive when the name may refer to one"
        if self.node_is_parallel_directive(node):
            self.namenode_is_cython_module = node.is_cython_module
            self.parallel_directive = [node.name]
        return node

    def visit_AttributeNode(self, node):
        "Extend the directive being collected by the attribute name"
        self.visitchildren(node)
        collected = self.parallel_directive
        if collected:
            collected.append(node.attribute)
        return node

    def visit_CallNode(self, node):
        "Replace calls of parallel directives by the corresponding node"
        self.visit(node.function)
        if not self.parallel_directive:
            return node

        # We are a parallel directive, replace this node with the
        # corresponding ParallelSomethingSomething node

        if not isinstance(node, ExprNodes.GeneralCallNode):
            args = node.args
            kwargs = {}
        else:
            args = node.positional_args.args
            kwargs = node.keyword_args

        node_class = self.get_directive_class_node(node)
        if not node_class:
            return node

        # Note: in case of a parallel() the body is set by
        # visit_WithStatNode
        return node_class(node.pos, args=args, kwargs=kwargs)

    def visit_WithStatNode(self, node):
        "Rewrite with cython.parallel.parallel() blocks"
        manager = self.visit(node.manager)

        if isinstance(manager, Nodes.ParallelWithBlockNode):
            if self.state == 'parallel with':
                error(node.manager.pos,
                      "Closely nested 'with parallel:' blocks are disallowed")

            self.state = 'parallel with'
            visited_body = self.visit(node.body)
            self.state = None

            manager.body = visited_body
            return manager

        if self.parallel_directive:
            node_class = self.get_directive_class_node(node)

            if not node_class:
                # There was an error, stop here and now
                return None

            if node_class is Nodes.ParallelWithBlockNode:
                error(node.pos, "The parallel directive must be called")
                return None

        node.body = self.visit(node.body)
        return node

    def visit_ForInStatNode(self, node):
        "Rewrite 'for i in cython.parallel.prange(...):'"
        self.visit(node.iterator)
        self.visit(node.target)

        outer_in_prange = self.in_prange
        outer_state = self.state

        is_prange = isinstance(node.iterator.sequence,
                               Nodes.ParallelRangeNode)
        self.in_prange = is_prange

        if is_prange:
            # This will replace the entire ForInStatNode, so copy the
            # attributes
            prange_node = node.iterator.sequence
            prange_node.target = node.target
            prange_node.body = node.body
            prange_node.else_clause = node.else_clause

            node = prange_node

            if not isinstance(node.target, ExprNodes.NameNode):
                error(node.target.pos,
                      "Can only iterate over an iteration variable")

            self.state = 'prange'

        self.visit(node.body)
        self.state = outer_state
        self.in_prange = outer_in_prange

        self.visit(node.else_clause)
        return node

    def ensure_not_in_prange(name):
        "Creates error checking functions for break, continue and return"
        message = name + " not allowed in a parallel range section"

        def visit_method(self, node):
            if self.in_prange:
                error(node.pos, message)

            # Do a visit for 'return'
            self.visitchildren(node)
            return node

        return visit_method

    visit_BreakStatNode = ensure_not_in_prange("break")
    visit_ContinueStatNode = ensure_not_in_prange("continue")
    visit_ReturnStatNode = ensure_not_in_prange("return")

    def visit(self, node):
        "Visit a node that may be None"
        if node is None:
            return None
        return super(ParallelRangeTransform, self).visit(node)
Mark Florisson's avatar
Mark Florisson committed
1159 1160


1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215
class WithTransform(CythonTransform, SkipDeclarations):
    """
    Expand the body of each with statement into an explicit
    try/except/finally structure that performs the exit call.
    """

    def visit_WithStatNode(self, node):
        """
        Rebuild node.body as

            try:
                try:
                    [target = <temp>]
                    <original body>
                except:
                    if not <exit call with the exception info>:
                        raise
            finally:
                <exit call with (None, None, None)>

        and move a target, if present, into the assignment at the start of
        the body.
        """
        self.visitchildren(node, 'body')
        pos = node.pos
        body = node.body
        target = node.target

        node.target_temp = ExprNodes.TempNode(pos, type=PyrexTypes.py_object_type)
        if target is not None:
            node.has_target = True
            bind_target = Nodes.WithTargetAssignmentStatNode(
                pos, lhs = target, rhs = node.target_temp)
            body = Nodes.StatListNode(pos, stats = [bind_target, body])
            node.target = None

        # Exception path: call exit with the exception info and re-raise
        # unless the call result is true.
        excinfo_target = ResultRefNode(
            pos=pos, type=Builtin.tuple_type, may_hold_none=False)
        exit_with_exception = ExprNodes.WithExitCallNode(
            pos, with_stat = node, args = excinfo_target)
        reraise_clause = Nodes.IfClauseNode(
            pos,
            condition = ExprNodes.NotNode(pos, operand = exit_with_exception),
            body = Nodes.ReraiseStatNode(pos),
            )
        except_clause = Nodes.ExceptClauseNode(
            pos,
            body = Nodes.IfStatNode(
                pos, if_clauses = [reraise_clause], else_clause = None),
            pattern = None,
            target = None,
            excinfo_target = excinfo_target,
            )
        guarded_body = Nodes.TryExceptStatNode(
            pos, body = body,
            except_clauses = [except_clause],
            else_clause = None,
            )

        # Finally clause: exit call with three None arguments.
        no_exception = ExprNodes.TupleNode(
            pos, args = [ExprNodes.NoneNode(pos) for _ in range(3)])
        exit_call = Nodes.ExprStatNode(
            pos, expr = ExprNodes.WithExitCallNode(
                pos, with_stat = node, args = no_exception))

        node.body = Nodes.TryFinallyStatNode(
            pos, body = guarded_body,
            finally_clause = exit_call,
            handle_error_case = False,
            )
        return node

    def visit_ExprNode(self, node):
        # With statements are never inside expressions.
        return node

1216

1217
class DecoratorTransform(CythonTransform, SkipDeclarations):
    """
    Turn decorators on functions and classes into an explicit
    reassignment "name = dec1(dec2(name))" placed after the definition.
    """

    def visit_DefNode(self, func_node):
        self.visitchildren(func_node)
        if func_node.decorators:
            return self._handle_decorators(func_node, func_node.name)
        return func_node

    def visit_CClassDefNode(self, class_node):
        # This doesn't currently work, so it's disabled.
        #
        # Problem: assignments to cdef class names do not work.  They
        # would require an additional check anyway, as the extension
        # type must not change its C type, so decorators cannot
        # replace an extension type, just alter it and return it.
        #
        # Once supported, this would call
        # self._handle_decorators(class_node, class_node.class_name).

        self.visitchildren(class_node)
        if class_node.decorators:
            error(class_node.pos,
                  "Decorators not allowed on cdef classes (used on type '%s')" % class_node.class_name)
        return class_node

    def visit_ClassDefNode(self, class_node):
        self.visitchildren(class_node)
        if class_node.decorators:
            return self._handle_decorators(class_node, class_node.name)
        return class_node

    def _handle_decorators(self, node, name):
        """
        Return [node, reassignment], where the reassignment binds ``name``
        to the result of applying the decorators of ``node`` to it, the
        decorator closest to the definition being applied first.
        """
        decorated = ExprNodes.NameNode(node.pos, name = name)
        for decorator in reversed(node.decorators):
            decorated = ExprNodes.SimpleCallNode(
                decorator.pos,
                function = decorator.decorator,
                args = [decorated])

        reassignment = Nodes.SingleAssignmentNode(
            node.pos,
            lhs = ExprNodes.NameNode(node.pos, name = name),
            rhs = decorated)
        return [node, reassignment]
1264

1265

1266
class AnalyseDeclarationsTransform(CythonTransform):
    """
    Run declaration analysis over the tree.

    The transform keeps two stacks while walking the tree:
    ``env_stack`` holds the scope that declarations are analysed in, and
    ``seen_vars_stack`` holds, per module/function, the set of names
    already referenced, which is used to warn about cdef variables that
    are declared after their first use.
    """

    # Code templates for the properties generated for cdef class attributes.
    basic_property = TreeFragment(u"""
property NAME:
    def __get__(self):
        return ATTR
    def __set__(self, value):
        ATTR = value
    """, level='c_class')
    basic_pyobject_property = TreeFragment(u"""
property NAME:
    def __get__(self):
        return ATTR
    def __set__(self, value):
        ATTR = value
    def __del__(self):
        ATTR = None
    """, level='c_class')
    basic_property_ro = TreeFragment(u"""
property NAME:
    def __get__(self):
        return ATTR
    """, level='c_class')

    # Templates used by the (currently disabled) struct/union wrapper
    # generation in visit_CStructOrUnionDefNode().
    struct_or_union_wrapper = TreeFragment(u"""
cdef class NAME:
    cdef TYPE value
    def __init__(self, MEMBER=None):
        cdef int count
        count = 0
        INIT_ASSIGNMENTS
        if IS_UNION and count > 1:
            raise ValueError, "At most one union member should be specified."
    def __str__(self):
        return STR_FORMAT % MEMBER_TUPLE
    def __repr__(self):
        return REPR_FORMAT % MEMBER_TUPLE
    """)

    init_assignment = TreeFragment(u"""
if VALUE is not None:
    ATTR = VALUE
    count += 1
    """)

    def __call__(self, root):
        """Reset the scope/name stacks and run the transform on ``root``."""
        self.env_stack = [root.scope]
        # needed to determine if a cdef var is declared after it's used.
        self.seen_vars_stack = []
        return super(AnalyseDeclarationsTransform, self).__call__(root)

    def visit_NameNode(self, node):
        """Record the referenced name in the innermost set of seen names."""
        self.seen_vars_stack[-1].add(node.name)
        return node

    def visit_ModuleNode(self, node):
        """Analyse module level declarations, then visit the children."""
        self.seen_vars_stack.append(cython.set())
        node.analyse_declarations(self.env_stack[-1])
        self.visitchildren(node)
        self.seen_vars_stack.pop()
        return node

    def visit_LambdaNode(self, node):
        """Analyse the lambda's declarations in the current scope."""
        node.analyse_declarations(self.env_stack[-1])
        self.visitchildren(node)
        return node

    def visit_ClassDefNode(self, node):
        """Visit the class children with the class scope as current scope."""
        self.env_stack.append(node.scope)
        self.visitchildren(node)
        self.env_stack.pop()
        return node

    def visit_CClassDefNode(self, node):
        """Visit a cdef class and append properties for its attributes.

        For an implemented class scope, every variable entry flagged
        ``needs_property`` gets a property node that is analysed, visited
        and appended to the class body.
        """
        node = self.visit_ClassDefNode(node)
        if node.scope and node.scope.implemented:
            stats = []
            for entry in node.scope.var_entries:
                if entry.needs_property:
                    prop = self.create_Property(entry)
                    prop.analyse_declarations(node.scope)
                    self.visit(prop)
                    stats.append(prop)
            if stats:
                node.body.stats += stats
        return node

    def visit_FuncDefNode(self, node):
        """Analyse the declarations of a function body in its local scope.

        Declares the arguments and the types given through
        ``directive_locals`` (skipping names already declared in the local
        scope), analyses the body declarations and then visits the
        children with the local scope pushed on ``env_stack``.
        """
        self.seen_vars_stack.append(cython.set())
        lenv = node.local_scope
        node.body.analyse_control_flow(lenv) # this will be totally refactored
        node.declare_arguments(lenv)
        for var, type_node in node.directive_locals.items():
            if not lenv.lookup_here(var):   # don't redeclare args
                declared_type = type_node.analyse_as_type(lenv)
                if declared_type:
                    lenv.declare_var(var, declared_type, type_node.pos)
                else:
                    error(type_node.pos, "Not a type")
        node.body.analyse_declarations(lenv)

        if lenv.nogil and lenv.has_with_gil_block:
            # Acquire the GIL for cleanup in 'nogil' functions, by wrapping
            # the entire function body in try/finally.
            # The corresponding release will be taken care of by
            # Nodes.FuncDefNode.generate_function_definitions()
            node.body = Nodes.NogilTryFinallyStatNode(
                node.body.pos,
                body = node.body,
                finally_clause = Nodes.EnsureGILNode(node.body.pos),
            )

        self.env_stack.append(lenv)
        self.visitchildren(node)
        self.env_stack.pop()
        self.seen_vars_stack.pop()
        return node

    def visit_ScopedExprNode(self, node):
        """Analyse the declarations of an expression that may own a scope."""
        env = self.env_stack[-1]
        node.analyse_declarations(env)
        # the node may or may not have a local scope
        if node.has_local_scope:
            # Start from a copy of the enclosing set of seen names so that
            # names seen inside the expression do not leak out of it.
            self.seen_vars_stack.append(cython.set(self.seen_vars_stack[-1]))
            self.env_stack.append(node.expr_scope)
            node.analyse_scoped_declarations(node.expr_scope)
            self.visitchildren(node)
            self.env_stack.pop()
            self.seen_vars_stack.pop()
        else:
            node.analyse_scoped_declarations(env)
            self.visitchildren(node)
        return node

    def visit_TempResultFromStatNode(self, node):
        """Visit the children first, then analyse the node's declarations."""
        self.visitchildren(node)
        node.analyse_declarations(self.env_stack[-1])
        return node

    def visit_CStructOrUnionDefNode(self, node):
        """Drop the struct/union definition node.

        Returns None unconditionally.  The wrapper class generation that
        follows the early return is currently disabled and unreachable.
        """
        # Create a wrapper node if needed.
        # We want to use the struct type information (so it can't happen
        # before this phase) but also create new objects to be declared
        # (so it can't happen later).
        # Note that we don't return the original node, as it is
        # never used after this phase.
        if True: # private (default)
            return None

        # --- unreachable while the early return above is in place ---
        self_value = ExprNodes.AttributeNode(
            pos = node.pos,
            obj = ExprNodes.NameNode(pos=node.pos, name=u"self"),
            attribute = EncodedString(u"value"))
        var_entries = node.entry.type.scope.var_entries
        attributes = []
        for entry in var_entries:
            attributes.append(ExprNodes.AttributeNode(pos = entry.pos,
                                                      obj = self_value,
                                                      attribute = entry.name))
        # __init__ assignments
        init_assignments = []
        for entry, attr in zip(var_entries, attributes):
            # TODO: branch on visibility
            init_assignments.append(self.init_assignment.substitute({
                    u"VALUE": ExprNodes.NameNode(entry.pos, name = entry.name),
                    u"ATTR": attr,
                }, pos = entry.pos))

        # create the class
        str_format = u"%s(%s)" % (node.entry.type.name, ("%s, " * len(attributes))[:-2])
        wrapper_class = self.struct_or_union_wrapper.substitute({
            u"INIT_ASSIGNMENTS": Nodes.StatListNode(node.pos, stats = init_assignments),
            u"IS_UNION": ExprNodes.BoolNode(node.pos, value = not node.entry.type.is_struct),
            u"MEMBER_TUPLE": ExprNodes.TupleNode(node.pos, args=attributes),
            u"STR_FORMAT": ExprNodes.StringNode(node.pos, value = EncodedString(str_format)),
            u"REPR_FORMAT": ExprNodes.StringNode(node.pos, value = EncodedString(str_format.replace("%s", "%r"))),
        }, pos = node.pos).stats[0]
        wrapper_class.class_name = node.name
        wrapper_class.shadow = True
        class_body = wrapper_class.body.stats

        # fix value type
        assert isinstance(class_body[0].base_type, Nodes.CSimpleBaseTypeNode)
        class_body[0].base_type.name = node.name

        # fix __init__ arguments
        init_method = class_body[1]
        assert isinstance(init_method, Nodes.DefNode) and init_method.name == '__init__'
        arg_template = init_method.args[1]
        if not node.entry.type.is_struct:
            arg_template.kw_only = True
        del init_method.args[1]
        for entry, attr in zip(var_entries, attributes):
            arg = copy.deepcopy(arg_template)
            arg.declarator.name = entry.name
            init_method.args.append(arg)

        # setters/getters
        for entry, attr in zip(var_entries, attributes):
            # TODO: branch on visibility
            if entry.type.is_pyobject:
                template = self.basic_pyobject_property
            else:
                template = self.basic_property
            prop = template.substitute({
                    u"ATTR": attr,
                }, pos = entry.pos).stats[0]
            prop.name = entry.name
            wrapper_class.body.stats.append(prop)

        wrapper_class.analyse_declarations(self.env_stack[-1])
        return self.visit_CClassDefNode(wrapper_class)

    # Some nodes are no longer needed after declaration
    # analysis and can be dropped. The analysis was performed
    # on these nodes in a separate recursive process from the
    # enclosing function or module, so we can simply drop them.
    def visit_CDeclaratorNode(self, node):
        """Keep the declarator, but visit its children."""
        # necessary to ensure that all CNameDeclaratorNodes are visited.
        self.visitchildren(node)
        return node

    def visit_CTypeDefNode(self, node):
        """Keep ctypedef nodes as they are, without descending into them."""
        return node

    def visit_CBaseTypeNode(self, node):
        """Drop base type nodes."""
        return None

    def visit_CEnumDefNode(self, node):
        """Keep enum definitions with 'public' visibility, drop all others."""
        if node.visibility == 'public':
            return node
        else:
            return None

    def visit_CNameDeclaratorNode(self, node):
        """Warn when a cdef variable is declared after it was referenced.

        The warning is issued if the name was already seen in the current
        module/function and the entry is either unknown or is neither
        'extern' nor part of a cdef class scope.
        """
        if node.name in self.seen_vars_stack[-1]:
            entry = self.env_stack[-1].lookup(node.name)
            if (entry is None or entry.visibility != 'extern'
                and not entry.scope.is_c_class_scope):
                warning(node.pos, "cdef variable '%s' declared after it is used" % node.name, 2)
        self.visitchildren(node)
        return node

    def visit_CVarDefNode(self, node):
        """Visit the children of a cdef variable definition, then drop it."""
        # to ensure all CNameDeclaratorNodes are visited.
        self.visitchildren(node)
        return None

    def create_Property(self, entry):
        """Build the property node that exposes a cdef class attribute.

        'public' attributes get a read/write property (with an additional
        ``__del__`` for Python object attributes), 'readonly' attributes
        get a getter only.  When docstrings and the ``embedsignature``
        directive are enabled, a "name: type [= default]" docstring is
        attached to the property.

        Raises InternalError for any other visibility, since no property
        template exists for it.
        """
        if entry.visibility == 'public':
            if entry.type.is_pyobject:
                template = self.basic_pyobject_property
            else:
                template = self.basic_property
        elif entry.visibility == 'readonly':
            template = self.basic_property_ro
        else:
            # Previously this fell through and failed further down with an
            # UnboundLocalError on 'template'.
            raise InternalError(
                "cannot create a property for attribute '%s' with visibility '%s'"
                % (entry.name, entry.visibility))
        prop = template.substitute({
                u"ATTR": ExprNodes.AttributeNode(pos=entry.pos,
                                                 obj=ExprNodes.NameNode(pos=entry.pos, name="self"),
                                                 attribute=entry.name),
            }, pos=entry.pos).stats[0]
        prop.name = entry.name
        # ---------------------------------------
        # XXX This should go to AutoDocTransforms
        # ---------------------------------------
        if (Options.docstrings and
            self.current_directives['embedsignature']):
            attr_name = entry.name
            type_name = entry.type.declaration_code("", for_display=1)
            default_value = ''
            if not entry.type.is_pyobject:
                type_name = "'%s'" % type_name
            elif entry.type.is_extension_type:
                type_name = entry.type.module_name + '.' + type_name
            if entry.init is not None:
                default_value = ' = ' + entry.init
            elif entry.init_to_none:
                default_value = ' = ' + repr(None)
            docstring = attr_name + ': ' + type_name + default_value
            prop.doc = EncodedString(docstring)
        # ---------------------------------------
        return prop
1548

1549

1550
class AnalyseExpressionsTransform(CythonTransform):
    """
    Run type inference and expression analysis for the module body,
    for every function body and for expressions with their own scope.
    """

    def visit_ModuleNode(self, node):
        """Infer types in the module scope, then analyse the module body."""
        module_scope = node.scope
        module_scope.infer_types()
        node.body.analyse_expressions(module_scope)
        self.visitchildren(node)
        return node

    def visit_FuncDefNode(self, node):
        """Infer types in the local scope, then analyse the function body."""
        local_scope = node.local_scope
        local_scope.infer_types()
        node.body.analyse_expressions(local_scope)
        self.visitchildren(node)
        return node

    def visit_ScopedExprNode(self, node):
        """Analyse an expression in its own scope, if it has one."""
        if not node.has_local_scope:
            self.visitchildren(node)
            return node
        expr_scope = node.expr_scope
        expr_scope.infer_types()
        node.analyse_scoped_expressions(expr_scope)
        self.visitchildren(node)
        return node
1570

1571

1572
class ExpandInplaceOperators(EnvTransform):
    """
    Rewrite in-place assignments (``a op= b``) into a plain assignment of
    a binary operation, evaluating sub-expressions of the target only
    once by binding them to LetRefNodes.
    """

    def visit_InPlaceAssignmentNode(self, node):
        """Expand one in-place assignment node.

        The node is returned unchanged when the target is a C++ class
        value or a buffer access (including a buffer access nested in
        the target).  Otherwise a SingleAssignmentNode is returned,
        wrapped in one LetNode per temporary reference that was needed.
        """
        lhs = node.lhs
        rhs = node.rhs
        if lhs.type.is_cpp_class:
            # No getting around this exact operator here.
            return node
        if isinstance(lhs, ExprNodes.IndexNode) and lhs.is_buffer_access:
            # There is code to handle this case.
            return node

        env = self.current_env()
        def side_effect_free_reference(expr, setting=False):
            # Returns (reference, temps): an expression that can be
            # evaluated repeatedly, and the LetRefNodes it depends on.
            # Raises ValueError for buffer accesses inside the target.
            if isinstance(expr, ExprNodes.NameNode):
                return expr, []
            elif expr.type.is_pyobject and not setting:
                expr = LetRefNode(expr)
                return expr, [expr]
            elif isinstance(expr, ExprNodes.IndexNode):
                if expr.is_buffer_access:
                    raise ValueError("Buffer access")
                base, temps = side_effect_free_reference(expr.base)
                index = LetRefNode(expr.index)
                return ExprNodes.IndexNode(expr.pos, base=base, index=index), temps + [index]
            elif isinstance(expr, ExprNodes.AttributeNode):
                obj, temps = side_effect_free_reference(expr.obj)
                return ExprNodes.AttributeNode(expr.pos, obj=obj, attribute=expr.attribute), temps
            else:
                expr = LetRefNode(expr)
                return expr, [expr]
        try:
            lhs, let_ref_nodes = side_effect_free_reference(lhs, setting=True)
        except ValueError:
            return node
        # Second node with the same attributes as the target, used as the
        # left operand of the binary operation.
        dup = lhs.__class__(**lhs.__dict__)
        binop = ExprNodes.binop_node(node.pos,
                                     operator = node.operator,
                                     operand1 = dup,
                                     operand2 = rhs,
                                     inplace=True)
        # Manually analyse types for new node.
        lhs.analyse_target_types(env)
        dup.analyse_types(env)
        binop.analyse_operation(env)
        node = Nodes.SingleAssignmentNode(
            node.pos,
            lhs = lhs,
            rhs=binop.coerce_to(lhs.type, env))
        # Use LetRefNode to avoid side effects.
        let_ref_nodes.reverse()
        for t in let_ref_nodes:
            node = LetNode(t, node)
        return node

    def visit_ExprNode(self, node):
        """Do not descend into expressions."""
        # In-place assignments can't happen within an expression.
        return node


1632 1633
class AlignFunctionDefinitions(CythonTransform):
    """
    Applies the signatures declared in a .pxd file to the matching
    def functions and classes of a .py file.
    """

    def visit_ModuleNode(self, node):
        """Remember the module scope and directives, then visit children."""
        self.scope = node.scope
        self.directives = node.directives
        self.visitchildren(node)
        return node

    def _report_redeclaration(self, node, pxd_def):
        # Emits the 'redeclared' error for node, plus a pointer to the
        # earlier declaration when its position is known.
        error(node.pos, "'%s' redeclared" % node.name)
        if pxd_def.pos:
            error(pxd_def.pos, "previous declaration here")

    def visit_PyClassDefNode(self, node):
        """Convert a Python class into a cdef class if it is declared as one.

        A class without a matching declaration is returned unchanged.  A
        class whose name is declared as something other than a cdef class
        is reported as redeclared and dropped.
        """
        pxd_def = self.scope.lookup(node.name)
        if not pxd_def:
            return node
        if pxd_def.is_cclass:
            return self.visit_CClassDefNode(node.as_cclass(), pxd_def)
        self._report_redeclaration(node, pxd_def)
        return None

    def visit_CClassDefNode(self, node, pxd_def=None):
        """Visit the class body using the declared class scope for lookups.

        When a declaration is found (or passed in), ``self.scope`` is
        switched to the scope of the declared type while the children are
        visited and switched back afterwards.
        """
        if pxd_def is None:
            pxd_def = self.scope.lookup(node.class_name)
        if not pxd_def:
            self.visitchildren(node)
            return node
        saved_scope = self.scope
        self.scope = pxd_def.type.scope
        self.visitchildren(node)
        self.scope = saved_scope
        return node

    def visit_DefNode(self, node):
        """Turn a def function into a C function where appropriate.

        A function with a declaration that does not live in a builtin
        scope is converted using that declaration; if the declaration is
        not a C function, the function is reported as redeclared and
        dropped.  Otherwise, at module level with the ``auto_cpdef``
        directive set, compatible functions are converted as well.
        """
        pxd_def = self.scope.lookup(node.name)
        if pxd_def and not (pxd_def.scope and pxd_def.scope.is_builtin_scope):
            if not pxd_def.is_cfunction:
                self._report_redeclaration(node, pxd_def)
                return None
            node = node.as_cfunction(pxd_def)
        elif (self.scope.is_module_scope and self.directives['auto_cpdef']
              and node.is_cdef_func_compatible()):
            node = node.as_cfunction(scope=self.scope)
        # The children are deliberately not visited here; that should be
        # enabled once nested cdef functions are allowed.
        return node
1683

1684

1685 1686 1687 1688
class YieldNodeCollector(TreeVisitor):
    """
    Collect the yield expressions and return statements that belong
    directly to one function body.

    Nested classes, functions, lambdas and generator expressions are not
    entered.  ``yields`` and ``returns`` hold the collected nodes and
    ``has_return_value`` is set once a return statement with a value has
    been seen.
    """

    def __init__(self):
        super(YieldNodeCollector, self).__init__()
        self.yields = []
        self.returns = []
        self.has_return_value = False

    def visit_Node(self, node):
        return self.visitchildren(node)

    def visit_YieldExprNode(self, node):
        # A yield that follows a 'return <value>' is reported as an error.
        # NOTE(review): the message wording reads oddly for this case; it
        # is left unchanged because it is user-visible output.
        if self.has_return_value:
            error(node.pos, "'yield' outside function")
        self.yields.append(node)
        self.visitchildren(node)

    def visit_ReturnStatNode(self, node):
        # Descend into the return value first: a yield nested inside it
        # must be collected like any other, otherwise it goes unnoticed.
        self.visitchildren(node)
        if node.value:
            self.has_return_value = True
            if self.yields:
                error(node.pos, "'return' with argument inside generator")
        self.returns.append(node)

    # The following node types own their yields/returns, so they are
    # not entered.

    def visit_ClassDefNode(self, node):
        pass

    def visit_FuncDefNode(self, node):
        pass

    def visit_LambdaNode(self, node):
        pass

    def visit_GeneratorExpressionNode(self, node):
        pass

1721
class MarkClosureVisitor(CythonTransform):
    """
    Set ``needs_closure`` on function and lambda nodes and replace
    functions that contain yield expressions by generator nodes.

    ``self.needs_closure`` is cleared before a function, lambda or module
    is entered and set again after a nested function, lambda or class has
    been processed.
    """

    def visit_ModuleNode(self, node):
        self.needs_closure = False
        self.visitchildren(node)
        return node

    def visit_FuncDefNode(self, node):
        """Mark the function and wrap it into a generator if it yields."""
        self.needs_closure = False
        self.visitchildren(node)
        node.needs_closure = self.needs_closure
        # Signal to the enclosing function that it contains a function.
        self.needs_closure = True

        yield_finder = YieldNodeCollector()
        yield_finder.visitchildren(node)
        if not yield_finder.yields:
            return node

        # Number the yield expressions starting at 1.
        label = 0
        for yield_expr in yield_finder.yields:
            label += 1
            yield_expr.label_num = label

        body_node = Nodes.GeneratorBodyDefNode(
            pos=node.pos, name=node.name, body=node.body)
        return Nodes.GeneratorDefNode(
            pos=node.pos,
            name=node.name,
            args=node.args,
            star_arg=node.star_arg,
            starstar_arg=node.starstar_arg,
            doc=node.doc,
            decorators=node.decorators,
            gbody=body_node,
            lambda_name=node.lambda_name)

    def visit_CFuncDefNode(self, node):
        """Mark a cdef function and reject it if it needs a closure."""
        # The result of visit_FuncDefNode() is not used; the original
        # node is always returned.
        self.visit_FuncDefNode(node)
        if node.needs_closure:
            error(node.pos, "closures inside cdef functions not yet supported")
        return node

    def visit_LambdaNode(self, node):
        self.needs_closure = False
        self.visitchildren(node)
        node.needs_closure = self.needs_closure
        self.needs_closure = True
        return node

    def visit_ClassDefNode(self, node):
        self.visitchildren(node)
        self.needs_closure = True
        return node
Stefan Behnel's avatar
Stefan Behnel committed
1773

1774
class CreateClosureClasses(CythonTransform):
    # Output closure classes in module scope for all functions
    # that really need it.
    #
    # path             list of the function nodes currently being visited
    # in_lambda        True while the children of a lambda are visited
    # generator_class  cached type of the shared generator base class

    def __init__(self, context):
        super(CreateClosureClasses, self).__init__(context)
        self.path = []
        self.in_lambda = False
        self.generator_class = None

    def visit_ModuleNode(self, node):
        """Remember the module scope, then visit the children."""
        self.module_scope = node.scope
        self.visitchildren(node)
        return node

    def create_generator_class(self, target_module_scope, pos):
        """Declare the internal generator base class once and return its type.

        The class is declared in ``target_module_scope`` on the first call;
        later calls return the cached type.
        """
        if self.generator_class:
            return self.generator_class
        # XXX: make generator class creation cleaner
        entry = target_module_scope.declare_c_class(name='__pyx_Generator',
                    objstruct_cname='__pyx_Generator_object',
                    typeobj_cname='__pyx_Generator_type',
                    pos=pos, defining=True, implementing=True)
        klass = entry.type.scope
        klass.is_internal = True
        klass.directives = {'final': True}

        body_type = PyrexTypes.create_typedef_type('generator_body',
                                                   PyrexTypes.c_void_ptr_type,
                                                   '__pyx_generator_body_t')
        klass.declare_var(pos=pos, name='body', cname='body',
                          type=body_type, is_cdef=True)
        klass.declare_var(pos=pos, name='is_running', cname='is_running', type=PyrexTypes.c_int_type,
                          is_cdef=True)
        klass.declare_var(pos=pos, name='resume_label', cname='resume_label', type=PyrexTypes.c_int_type,
                          is_cdef=True)
        klass.declare_var(pos=pos, name='exc_type', cname='exc_type',
                          type=PyrexTypes.py_object_type, is_cdef=True)
        klass.declare_var(pos=pos, name='exc_value', cname='exc_value',
                          type=PyrexTypes.py_object_type, is_cdef=True)
        klass.declare_var(pos=pos, name='exc_traceback', cname='exc_traceback',
                          type=PyrexTypes.py_object_type, is_cdef=True)

        import TypeSlots
        e = klass.declare_pyfunction('send', pos)
        e.func_cname = '__Pyx_Generator_Send'
        e.signature = TypeSlots.binaryfunc

        e = klass.declare_pyfunction('close', pos)
        e.func_cname = '__Pyx_Generator_Close'
        e.signature = TypeSlots.unaryfunc

        e = klass.declare_pyfunction('throw', pos)
        e.func_cname = '__Pyx_Generator_Throw'
        e.signature = TypeSlots.pyfunction_signature

        e = klass.declare_var('__iter__', PyrexTypes.py_object_type, pos, visibility='public')
        e.func_cname = 'PyObject_SelfIter'

        e = klass.declare_var('__next__', PyrexTypes.py_object_type, pos, visibility='public')
        e.func_cname = '__Pyx_Generator_Next'

        self.generator_class = entry.type
        return self.generator_class

    def find_entries_used_in_closures(self, node):
        """Split the local entries of ``node`` by their closure flags.

        Returns two lists of (name, entry) pairs: the entries flagged
        ``from_closure`` and, of the remaining ones, those flagged
        ``in_closure``.
        """
        from_closure = []
        in_closure = []
        for name, entry in node.local_scope.entries.items():
            if entry.from_closure:
                from_closure.append((name, entry))
            elif entry.in_closure:
                in_closure.append((name, entry))
        return from_closure, in_closure

    def create_class_from_scope(self, node, target_module_scope, inner_node=None):
        """Declare the closure class for the function ``node``, if needed.

        Sets ``needs_closure`` and ``needs_outer_scope`` on the node.  A
        class is declared in ``target_module_scope`` when the function is
        a generator or has entries flagged ``in_closure``; a function that
        only uses outer variables reuses the scope class of the enclosing
        scope instead.  Generator body nodes are skipped.

        Raises InternalError when a nested def node without an assignment
        node is encountered.
        """
        # skip generator body
        if node.is_generator_body:
            return
        # move local variables into closure
        if node.is_generator:
            for entry in node.local_scope.entries.values():
                if not entry.from_closure:
                    entry.in_closure = True

        from_closure, in_closure = self.find_entries_used_in_closures(node)
        in_closure.sort()

        # Now from the beginning
        node.needs_closure = False
        node.needs_outer_scope = False

        func_scope = node.local_scope
        # Find the innermost enclosing scope that is not a class scope.
        cscope = node.entry.scope
        while cscope.is_py_class_scope or cscope.is_c_class_scope:
            cscope = cscope.outer_scope

        if not from_closure and (self.path or inner_node):
            if not inner_node:
                if not node.assmt:
                    raise InternalError("DefNode does not have assignment node")
                inner_node = node.assmt.rhs
            inner_node.needs_self_code = False
            node.needs_outer_scope = False

        base_type = None
        if node.is_generator:
            base_type = self.create_generator_class(target_module_scope, node.pos)
        elif not in_closure and not from_closure:
            return
        elif not in_closure:
            func_scope.is_passthrough = True
            func_scope.scope_class = cscope.scope_class
            node.needs_outer_scope = True
            return

        as_name = '%s_%s' % (target_module_scope.next_id(Naming.closure_class_prefix), node.entry.cname)

        entry = target_module_scope.declare_c_class(
            name=as_name, pos=node.pos, defining=True,
            implementing=True, base_type=base_type)

        func_scope.scope_class = entry
        class_scope = entry.type.scope
        class_scope.is_internal = True
        class_scope.directives = {'final': True}

        if from_closure:
            assert cscope.is_closure_scope
            class_scope.declare_var(pos=node.pos,
                                    name=Naming.outer_scope_cname,
                                    cname=Naming.outer_scope_cname,
                                    type=cscope.scope_class.type,
                                    is_cdef=True)
            node.needs_outer_scope = True
        for name, local_entry in in_closure:
            closure_entry = class_scope.declare_var(pos=local_entry.pos,
                                    name=local_entry.name,
                                    cname=local_entry.cname,
                                    type=local_entry.type,
                                    is_cdef=True)
            if local_entry.is_declared_generic:
                closure_entry.is_declared_generic = 1
        node.needs_closure = True
        # Do it here because other classes are already checked
        target_module_scope.check_c_class(func_scope.scope_class)

    def visit_LambdaNode(self, node):
        """Create the closure class of a lambda, then visit its children."""
        was_in_lambda = self.in_lambda
        self.in_lambda = True
        self.create_class_from_scope(node.def_node, self.module_scope, node)
        self.visitchildren(node)
        self.in_lambda = was_in_lambda
        return node

    def visit_FuncDefNode(self, node):
        """Create the closure class of a function where one is needed.

        Inside a lambda only the children are visited.  Otherwise the
        function is processed, and its children visited, only if it needs
        a closure or is nested in a function that was processed.
        """
        if self.in_lambda:
            self.visitchildren(node)
            return node
        if node.needs_closure or self.path:
            self.create_class_from_scope(node, self.module_scope)
            self.path.append(node)
            self.visitchildren(node)
            self.path.pop()
        return node
1939 1940 1941 1942 1943 1944 1945


class GilCheck(VisitorTransform):
    """
    Call `node.nogil_check(env)` on each node that defines it while inside
    a `nogil` environment, so that nodes can report Python operations that
    need the GIL.

    Additionally, report errors for closely nested with gil or with nogil
    statements. The latter would abort Python.
    """

    def __call__(self, root):
        """Reset the GIL tracking state and run the transform over `root`."""
        # Stack of scopes; the top one is passed to nogil_check().
        self.env_stack = [root.scope]
        # Whether the code currently being visited runs without the GIL.
        self.nogil = False

        # True for 'cdef func() nogil:' functions, as the GIL may be held while
        # calling this function (thus contained 'nogil' blocks may be valid).
        self.nogil_declarator_only = False
        return super(GilCheck, self).__call__(root)

    def visit_FuncDefNode(self, node):
        """
        Visit a function body using the nogil state of its local scope and
        restore the previous state afterwards.
        """
        self.env_stack.append(node.local_scope)
        was_nogil = self.nogil
        self.nogil = node.local_scope.nogil

        if self.nogil:
            self.nogil_declarator_only = True

        if self.nogil and node.nogil_check:
            node.nogil_check(node.local_scope)

        self.visitchildren(node)

        # This cannot be nested, so it doesn't need backup/restore
        self.nogil_declarator_only = False

        self.env_stack.pop()
        self.nogil = was_nogil
        return node

    def visit_GILStatNode(self, node):
        """
        Handle a 'with gil:' / 'with nogil:' statement.

        Reports an error when the statement would re-acquire an already held
        GIL or release an already released one, unless we are directly inside
        a function that is only declared 'nogil'.
        """
        if self.nogil and node.nogil_check:
            node.nogil_check()

        was_nogil = self.nogil
        self.nogil = (node.state == 'nogil')

        if was_nogil == self.nogil and not self.nogil_declarator_only:
            if not was_nogil:
                error(node.pos, "Trying to acquire the GIL while it is "
                                "already held.")
            else:
                error(node.pos, "Trying to release the GIL while it was "
                                "previously released.")

        if isinstance(node.finally_clause, Nodes.StatListNode):
            # The finally clause of the GILStatNode is a GILExitNode,
            # which is wrapped in a StatListNode. Just unpack that.
            node.finally_clause, = node.finally_clause.stats

        if node.state == 'gil':
            # Read back by visit_TryFinallyStatNode().
            self.seen_with_gil_statement = True

        self.visitchildren(node)
        self.nogil = was_nogil
        return node

    def visit_ParallelRangeNode(self, node):
        """
        Check a prange() loop.

        A prange(..., nogil=True) is rewritten into a 'with nogil:' block
        wrapping the loop; otherwise the loop must already be in nogil code.
        """
        if node.is_nogil:
            # Clear the flag so that the wrapped node is not wrapped again
            # when it is visited as the body of the new GILStatNode.
            node.is_nogil = False
            node = Nodes.GILStatNode(node.pos, state='nogil', body=node)
            return self.visit_GILStatNode(node)

        if not self.nogil:
            error(node.pos, "prange() can only be used without the GIL")
            # Forget about any GIL-related errors that may occur in the body
            return None

        node.nogil_check(self.env_stack[-1])
        self.visitchildren(node)
        return node

    def visit_ParallelWithBlockNode(self, node):
        """
        Check a parallel 'with' block, which is only allowed in nogil code.

        Returns None (dropping the node) after reporting an error, and the
        node itself otherwise.
        """
        if not self.nogil:
            error(node.pos, "The parallel section may only be used without "
                            "the GIL")
            return None

        if node.nogil_check:
            # It does not currently implement this, but test for it anyway to
            # avoid potential future surprises
            node.nogil_check(self.env_stack[-1])

        self.visitchildren(node)
        # Must return the node: falling off the end would return None, which
        # is the value the other visitors here use to discard a node.
        return node

    def visit_TryFinallyStatNode(self, node):
        """
        Take care of try/finally statements in nogil code sections. The
        'try' must contain a 'with gil:' statement somewhere.
        """
        if not self.nogil or isinstance(node, Nodes.GILStatNode):
            return self.visit_Node(node)

        node.nogil_check = None
        node.is_try_finally_in_nogil = True

        # First, visit the body and check for errors
        self.seen_with_gil_statement = False
        self.visitchildren(node.body)

        if not self.seen_with_gil_statement:
            error(node.pos, "Cannot use try/finally in nogil sections unless "
                            "it contains a 'with gil' statement.")

        self.visitchildren(node.finally_clause)
        return node

    def visit_Node(self, node):
        """Run the node's nogil check (if any) in nogil code, then recurse."""
        if self.env_stack and self.nogil and node.nogil_check:
            node.nogil_check(self.env_stack[-1])
        self.visitchildren(node)
        return node

2063

Robert Bradshaw's avatar
Robert Bradshaw committed
2064 2065
class TransformBuiltinMethods(EnvTransform):
    """
    Replace references to and calls of special names with dedicated tree
    nodes: 'cython.*' attributes and functions (compiled, NULL, cast(),
    sizeof(), cdiv(), ...) as well as zero-argument locals() and dir() calls.
    """

    def visit_SingleAssignmentNode(self, node):
        """Drop declaration-only assignments, recurse into all others."""
        if node.declaration_only:
            return None
        else:
            self.visitchildren(node)
            return node

    def visit_AttributeNode(self, node):
        self.visitchildren(node)
        return self.visit_cython_attribute(node)

    def visit_NameNode(self, node):
        return self.visit_cython_attribute(node)

    def visit_cython_attribute(self, node):
        """
        Replace a cython attribute reference by its node equivalent.

        `node.as_cython_attribute()` yields the attribute name, or a false
        value for ordinary names, in which case the node is returned as is.
        Names that are neither handled here nor a basic type are reported as
        errors.
        """
        attribute = node.as_cython_attribute()
        if attribute:
            if attribute == u'compiled':
                node = ExprNodes.BoolNode(node.pos, value=True)
            elif attribute == u'NULL':
                node = ExprNodes.NullNode(node.pos)
            elif attribute in (u'set', u'frozenset'):
                node = ExprNodes.NameNode(node.pos, name=EncodedString(attribute),
                                          entry=self.current_env().builtin_scope().lookup_here(attribute))
            elif not PyrexTypes.parse_basic_type(attribute):
                error(node.pos, u"'%s' not a valid cython attribute or is being used incorrectly" % attribute)
        return node

    def _inject_locals(self, node, func_name):
        """
        Expand a call to the locals() or dir() builtin.

        `func_name` is 'locals' or anything else (treated as 'dir').  If the
        name is declared in the current scope, the call is left alone.  A
        zero-argument locals() becomes a dict node mapping the names in the
        current scope to name nodes, a zero-argument dir() becomes a list
        node of those names.  Calls with a wrong number of arguments are
        reported and returned unchanged.
        """
        # locals()/dir() builtins
        lenv = self.current_env()
        entry = lenv.lookup_here(func_name)
        if entry:
            # not the builtin
            return node
        pos = node.pos
        if func_name == 'locals':
            if len(node.args) > 0:
                # Use the call's position: the transform has no 'pos' itself.
                error(pos, "Builtin 'locals()' called with wrong number of args, expected 0, got %d"
                      % len(node.args))
                return node
            items = [ ExprNodes.DictItemNode(pos,
                                             key=ExprNodes.StringNode(pos, value=var),
                                             value=ExprNodes.NameNode(pos, name=var))
                      for var in lenv.entries ]
            return ExprNodes.DictNode(pos, key_value_pairs=items)
        else:
            if len(node.args) > 1:
                # Use the call's position: the transform has no 'pos' itself.
                error(pos, "Builtin 'dir()' called with wrong number of args, expected 0-1, got %d"
                      % len(node.args))
                return node
            elif len(node.args) == 1:
                # optimised in Builtin.py
                return node
            items = [ ExprNodes.StringNode(pos, value=var) for var in lenv.entries ]
            return ExprNodes.ListNode(pos, args=items)

    def visit_SimpleCallNode(self, node):
        """
        Rewrite calls to dir()/locals() and to 'cython.*' functions.

        Calls with the wrong number of arguments are reported as errors and
        left as plain call nodes.  Whatever node results is then visited
        recursively and returned.
        """
        if isinstance(node.function, ExprNodes.NameNode):
            func_name = node.function.name
            if func_name in ('dir', 'locals'):
                return self._inject_locals(node, func_name)

        # cython.foo
        function = node.function.as_cython_attribute()
        if function:
            if function in InterpretCompilerDirectives.unop_method_nodes:
                if len(node.args) != 1:
                    error(node.function.pos, u"%s() takes exactly one argument" % function)
                else:
                    node = InterpretCompilerDirectives.unop_method_nodes[function](node.function.pos, operand=node.args[0])
            elif function in InterpretCompilerDirectives.binop_method_nodes:
                if len(node.args) != 2:
                    error(node.function.pos, u"%s() takes exactly two arguments" % function)
                else:
                    node = InterpretCompilerDirectives.binop_method_nodes[function](node.function.pos, operand1=node.args[0], operand2=node.args[1])
            elif function == u'cast':
                if len(node.args) != 2:
                    error(node.function.pos, u"cast() takes exactly two arguments")
                else:
                    cast_type = node.args[0].analyse_as_type(self.current_env())
                    if cast_type:
                        node = ExprNodes.TypecastNode(node.function.pos, type=cast_type, operand=node.args[1])
                    else:
                        error(node.args[0].pos, "Not a type")
            elif function == u'sizeof':
                if len(node.args) != 1:
                    error(node.function.pos, u"sizeof() takes exactly one argument")
                else:
                    # sizeof() accepts either a type or an expression.
                    sizeof_type = node.args[0].analyse_as_type(self.current_env())
                    if sizeof_type:
                        node = ExprNodes.SizeofTypeNode(node.function.pos, arg_type=sizeof_type)
                    else:
                        node = ExprNodes.SizeofVarNode(node.function.pos, operand=node.args[0])
            elif function == 'cmod':
                if len(node.args) != 2:
                    error(node.function.pos, u"cmod() takes exactly two arguments")
                else:
                    node = ExprNodes.binop_node(node.function.pos, '%', node.args[0], node.args[1])
                    node.cdivision = True
            elif function == 'cdiv':
                if len(node.args) != 2:
                    error(node.function.pos, u"cdiv() takes exactly two arguments")
                else:
                    node = ExprNodes.binop_node(node.function.pos, '/', node.args[0], node.args[1])
                    node.cdivision = True
            elif function == u'set':
                node.function = ExprNodes.NameNode(node.pos, name=EncodedString('set'))
            else:
                error(node.function.pos, u"'%s' not a valid cython language construct" % function)

        self.visitchildren(node)
        return node
2179 2180


Mark Florisson's avatar
Mark Florisson committed
2181
class DebugTransform(CythonTransform):
    """
    Write debug information for this Cython module.

    The information is emitted as nested start()/end() element calls on the
    context's debug output writer: a 'Module' element containing 'Functions'
    and 'Globals'.
    """

    def __init__(self, context, options, result):
        super(DebugTransform, self).__init__(context)
        # qualified names of all functions seen so far
        self.visited = cython.set()
        # our treebuilder and debug output writer
        # (see Cython.Debugger.debug_output.CythonDebugWriter)
        self.tb = self.context.gdb_debug_outputwriter
        #self.c_output_file = options.output_file
        self.c_output_file = result.c_file

        # Closure support, basically treat nested functions as if the AST were
        # never nested
        self.nested_funcdefs = []

        # tells visit_NameNode whether it should register step-into functions
        self.register_stepinto = False

    def visit_ModuleNode(self, node):
        """
        Serialize the module: its functions (including nested ones and the
        module body as a pseudo function) followed by its global variables.

        The 'Module' element is deliberately left open here.
        """
        self.tb.module_name = node.full_module_name
        attrs = dict(
            module_name=node.full_module_name,
            filename=node.pos[0].filename,
            c_filename=self.c_output_file)

        self.tb.start('Module', attrs)

        # serialize functions
        self.tb.start('Functions')
        # First, serialize functions normally...
        self.visitchildren(node)

        # ... then, serialize nested functions
        # (visiting one may append further nested functions to the list,
        # which this loop then picks up as well)
        for nested_funcdef in self.nested_funcdefs:
            self.visit_FuncDefNode(nested_funcdef)

        self.register_stepinto = True
        self.serialize_modulenode_as_function(node)
        self.register_stepinto = False
        self.tb.end('Functions')

        # 2.3 compatibility. Serialize global variables
        self.tb.start('Globals')
        entries = {}

        # Skip functions that were already serialized, internal '__pyx_'
        # names, C functions and extension types.
        for k, v in node.scope.entries.iteritems():
            if (v.qualified_name not in self.visited and not
                v.name.startswith('__pyx_') and not
                v.type.is_cfunction and not
                v.type.is_extension_type):
                entries[k]= v

        self.serialize_local_variables(entries)
        self.tb.end('Globals')
        # self.tb.end('Module') # end Module after the line number mapping in
        # Cython.Compiler.ModuleNode.ModuleNode._serialize_lineno_map
        return node

    def visit_FuncDefNode(self, node):
        """
        Serialize a function as a 'Function' element with its locals,
        arguments and step-into functions.

        Wrapper functions are skipped.  Functions met while step-into
        registration is active (i.e. inside another function or the module
        body pass) are only queued in `self.nested_funcdefs`.
        """
        self.visited.add(node.local_scope.qualified_name)

        if getattr(node, 'is_wrapper', False):
            return node

        if self.register_stepinto:
            self.nested_funcdefs.append(node)
            return node

        # node.entry.visibility = 'extern'
        if node.py_func is None:
            pf_cname = ''
        else:
            pf_cname = node.py_func.entry.func_cname

        attrs = dict(
            name=node.entry.name,
            cname=node.entry.func_cname,
            pf_cname=pf_cname,
            qualified_name=node.local_scope.qualified_name,
            lineno=str(node.pos[1]))

        self.tb.start('Function', attrs=attrs)

        self.tb.start('Locals')
        self.serialize_local_variables(node.local_scope.entries)
        self.tb.end('Locals')

        self.tb.start('Arguments')
        for arg in node.local_scope.arg_entries:
            self.tb.start(arg.name)
            self.tb.end(arg.name)
        self.tb.end('Arguments')

        self.tb.start('StepIntoFunctions')
        self.register_stepinto = True
        self.visitchildren(node)
        self.register_stepinto = False
        self.tb.end('StepIntoFunctions')
        self.tb.end('Function')

        return node

    def visit_NameNode(self, node):
        """
        Emit a 'StepIntoFunction' element for called C functions while
        step-into registration is active.
        """
        if (self.register_stepinto and
            node.type.is_cfunction and
            getattr(node, 'is_called', False) and
            node.entry.func_cname is not None):
            # don't check node.entry.in_cinclude, as 'cdef extern: ...'
            # declared functions are not 'in_cinclude'.
            # This means we will list called 'cdef' functions as
            # "step into functions", but this is not an issue as they will be
            # recognized as Cython functions anyway.
            attrs = dict(name=node.entry.func_cname)
            self.tb.start('StepIntoFunction', attrs=attrs)
            self.tb.end('StepIntoFunction')

        self.visitchildren(node)
        return node

    def serialize_modulenode_as_function(self, node):
        """
        Serialize the module-level code as a function so the debugger will know
        it's a "relevant frame" and it will know where to set the breakpoint
        for 'break modulename'.

        The pseudo function is written twice, once under the 'init<name>'
        C name and once under 'PyInit_<name>'.
        """
        name = node.full_module_name.rpartition('.')[-1]

        cname_py2 = 'init' + name
        cname_py3 = 'PyInit_' + name

        py2_attrs = dict(
            name=name,
            cname=cname_py2,
            pf_cname='',
            # Ignore the qualified_name, breakpoints should be set using
            # `cy break modulename:lineno` for module-level breakpoints.
            qualified_name='',
            lineno='1',
            is_initmodule_function="True",
        )

        py3_attrs = dict(py2_attrs, cname=cname_py3)

        self._serialize_modulenode_as_function(node, py2_attrs)
        self._serialize_modulenode_as_function(node, py3_attrs)

    def _serialize_modulenode_as_function(self, node, attrs):
        """
        Write one 'Function' element for the module body using `attrs`,
        with the module scope's entries as locals and no arguments.
        """
        self.tb.start('Function', attrs=attrs)

        self.tb.start('Locals')
        self.serialize_local_variables(node.scope.entries)
        self.tb.end('Locals')

        self.tb.start('Arguments')
        self.tb.end('Arguments')

        self.tb.start('StepIntoFunctions')
        self.register_stepinto = True
        self.visitchildren(node)
        self.register_stepinto = False
        self.tb.end('StepIntoFunctions')

        self.tb.end('Function')

    def serialize_local_variables(self, entries):
        """
        Write a 'LocalVar' element for every entry in the `entries` mapping,
        recording its name, C name, qualified name, kind ('PythonObject' or
        'CObject') and line number.
        """
        for entry in entries.values():
            if entry.type.is_pyobject:
                vartype = 'PythonObject'
            else:
                vartype = 'CObject'

            if entry.from_closure:
                # We're dealing with a closure where a variable from an outer
                # scope is accessed, get it from the scope object.
                cname = '%s->%s' % (Naming.cur_scope_cname,
                                    entry.outer_entry.cname)

                qname = '%s.%s.%s' % (entry.scope.outer_scope.qualified_name,
                                      entry.scope.name,
                                      entry.name)
            elif entry.in_closure:
                cname = '%s->%s' % (Naming.cur_scope_cname,
                                    entry.cname)
                qname = entry.qualified_name
            else:
                cname = entry.cname
                qname = entry.qualified_name

            if not entry.pos:
                # this happens for variables that are not in the user's code,
                # e.g. for the global __builtins__, __doc__, etc. We can just
                # set the lineno to 0 for those.
                lineno = '0'
            else:
                lineno = str(entry.pos[1])

            attrs = dict(
                name=entry.name,
                cname=cname,
                qualified_name=qname,
                type=vartype,
                lineno=lineno)

            self.tb.start('LocalVar', attrs)
            self.tb.end('LocalVar')
2389