# TypeInference.py
from functools import reduce

from Errors import error, message
import ExprNodes
import Nodes
import Builtin
import PyrexTypes
from Cython import Utils
from PyrexTypes import py_object_type, unspecified_type
from Visitor import CythonTransform, EnvTransform
9

Robert Bradshaw's avatar
Robert Bradshaw committed
10

11
class TypedExprNode(ExprNodes.ExprNode):
    """Expression stand-in that carries only a result type.

    Used for declaring assignments of a specified type without a known
    entry.
    """

    def __init__(self, type):
        # Only the type is recorded; no other node state is initialised here.
        self.type = type

Robert Bradshaw's avatar
Robert Bradshaw committed
16
# Shared placeholder right-hand side typed as a generic Python object; passed
# to mark_assignment() wherever no more specific value expression exists.
object_expr = TypedExprNode(py_object_type)
17

18 19 20 21 22

class MarkParallelAssignments(EnvTransform):
    """Collect assignments made inside parallel blocks (prange, with parallel).

    Every assignment-like construct visited is routed through
    mark_assignment().  While at least one parallel block is open, the
    innermost block node records, per assigned entry, the position of the
    first assignment and the most recent in-place operator, and collects the
    assigned target nodes.

    Perhaps it's better to move it to ControlFlowAnalysis.
    """

    # tells us whether we're in a normal loop
    in_loop = False

    # Set after "Only prange() may be nested" has been reported so that the
    # error is not repeated; reset at the end of visit_ParallelStatNode().
    parallel_errors = False

    def __init__(self, context):
        # Track the parallel block scopes (with parallel, for i in prange())
        self.parallel_block_stack = []
        super(MarkParallelAssignments, self).__init__(context)

    def mark_assignment(self, lhs, rhs, inplace_op=None):
        """Record an assignment to `lhs` on the innermost parallel block.

        `lhs` is the assignment target, `rhs` the assigned expression (not
        inspected here) and `inplace_op` the operator of an in-place
        assignment, or None for a plain assignment.  Sequence targets are
        unpacked recursively, each item being treated as a plain assignment
        of a Python object.  Targets that are neither names/arguments nor
        sequences are ignored.
        """
        if isinstance(lhs, (ExprNodes.NameNode, Nodes.PyArgDeclNode)):
            if lhs.entry is None:
                # TODO: This shouldn't happen...
                return

            if self.parallel_block_stack:
                parallel_node = self.parallel_block_stack[-1]
                previous_assignment = parallel_node.assignments.get(lhs.entry)

                # If there was a previous assignment to the variable, keep the
                # previous assignment position
                if previous_assignment:
                    pos, previous_inplace_op = previous_assignment

                    if (inplace_op and previous_inplace_op and
                            inplace_op != previous_inplace_op):
                        # x += y; x *= y
                        t = (inplace_op, previous_inplace_op)
                        error(lhs.pos,
                              "Reduction operator '%s' is inconsistent "
                              "with previous reduction operator '%s'" % t)
                else:
                    pos = lhs.pos

                parallel_node.assignments[lhs.entry] = (pos, inplace_op)
                parallel_node.assigned_nodes.append(lhs)

        elif isinstance(lhs, ExprNodes.SequenceNode):
            for arg in lhs.args:
                self.mark_assignment(arg, object_expr)
        else:
            # Could use this info to infer cdef class attributes...
            pass

    def visit_WithTargetAssignmentStatNode(self, node):
        self.mark_assignment(node.lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_SingleAssignmentNode(self, node):
        self.mark_assignment(node.lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_CascadedAssignmentNode(self, node):
        # a = b = rhs: every target receives the same right-hand side
        for lhs in node.lhs_list:
            self.mark_assignment(lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_InPlaceAssignmentNode(self, node):
        # The operator is passed along so that conflicting reduction
        # operators on the same variable can be detected.
        self.mark_assignment(node.lhs, node.create_binop_node(), node.operator)
        self.visitchildren(node)
        return node

    def visit_ForInStatNode(self, node):
        """Mark the loop target(s) of a for-in loop as assigned.

        Calls to the builtins reversed(), enumerate() and range()/xrange()
        are looked through so that the target is marked with an expression
        that reflects the iterated values.
        """
        # TODO: Remove redundancy with range optimization...
        is_special = False
        sequence = node.iterator.sequence
        target = node.target
        if isinstance(sequence, ExprNodes.SimpleCallNode):
            function = sequence.function
            if sequence.self is None and function.is_name:
                entry = self.current_env().lookup(function.name)
                if not entry or entry.is_builtin:
                    if function.name == 'reversed' and len(sequence.args) == 1:
                        # reversed(x) yields the same items as x
                        sequence = sequence.args[0]
                    elif function.name == 'enumerate' and len(sequence.args) == 1:
                        if target.is_sequence_constructor and len(target.args) == 2:
                            iterator = sequence.args[0]
                            if iterator.is_name:
                                iterator_type = iterator.infer_type(self.current_env())
                                if iterator_type.is_builtin_type:
                                    # assume that builtin types have a length within Py_ssize_t
                                    self.mark_assignment(
                                        target.args[0],
                                        ExprNodes.IntNode(target.pos, value='PY_SSIZE_T_MAX',
                                                          type=PyrexTypes.c_py_ssize_t_type))
                                    # continue with the item target and the
                                    # enumerated sequence itself
                                    target = target.args[1]
                                    sequence = sequence.args[0]
        if isinstance(sequence, ExprNodes.SimpleCallNode):
            function = sequence.function
            if sequence.self is None and function.is_name:
                entry = self.current_env().lookup(function.name)
                if not entry or entry.is_builtin:
                    if function.name in ('range', 'xrange'):
                        is_special = True
                        # the target takes the values of the first two
                        # arguments ...
                        for arg in sequence.args[:2]:
                            self.mark_assignment(target, arg)
                        # ... and, with a step, of "first argument + step"
                        if len(sequence.args) > 2:
                            self.mark_assignment(
                                target,
                                ExprNodes.binop_node(node.pos,
                                                     '+',
                                                     sequence.args[0],
                                                     sequence.args[2]))

        if not is_special:
            # A for-loop basically translates to subsequent calls to
            # __getitem__(), so using an IndexNode here allows us to
            # naturally infer the base type of pointers, C arrays,
            # Python strings, etc., while correctly falling back to an
            # object type when the base type cannot be handled.
            self.mark_assignment(target, ExprNodes.IndexNode(
                node.pos,
                base = sequence,
                index = ExprNodes.IntNode(node.pos, value = '0')))

        self.visitchildren(node)
        return node

    def visit_ForFromStatNode(self, node):
        self.mark_assignment(node.target, node.bound1)
        if node.step is not None:
            self.mark_assignment(node.target,
                    ExprNodes.binop_node(node.pos,
                                         '+',
                                         node.bound1,
                                         node.step))
        self.visitchildren(node)
        return node

    def visit_WhileStatNode(self, node):
        self.visitchildren(node)
        return node

    def visit_ExceptClauseNode(self, node):
        if node.target is not None:
            self.mark_assignment(node.target, object_expr)
        self.visitchildren(node)
        return node

    def visit_FromCImportStatNode(self, node):
        # Can't be assigned to...  The node is handed back unchanged, like
        # every other visitor in this class does, instead of returning None.
        return node

    def visit_FromImportStatNode(self, node):
        for name, target in node.items:
            if name != "*":
                self.mark_assignment(target, object_expr)
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        # use fake expressions with the right result type
        if node.star_arg:
            self.mark_assignment(
                node.star_arg, TypedExprNode(Builtin.tuple_type))
        if node.starstar_arg:
            self.mark_assignment(
                node.starstar_arg, TypedExprNode(Builtin.dict_type))
        EnvTransform.visit_FuncDefNode(self, node)
        return node

    def visit_DelStatNode(self, node):
        # deleting a name is recorded like an assignment to it
        for arg in node.args:
            self.mark_assignment(arg, arg)
        self.visitchildren(node)
        return node

    def visit_ParallelStatNode(self, node):
        """Set parent/is_parallel on a parallel block and visit its contents.

        The node is pushed onto parallel_block_stack while its children are
        visited, so that mark_assignment() attributes assignments to it.
        Reports "Only prange() may be nested" at most once per offending
        nest.
        """
        if self.parallel_block_stack:
            node.parent = self.parallel_block_stack[-1]
        else:
            node.parent = None

        nested = False
        if node.is_prange:
            if not node.parent:
                node.is_parallel = True
            else:
                node.is_parallel = (node.parent.is_prange or not
                                    node.parent.is_parallel)
                nested = node.parent.is_prange
        else:
            node.is_parallel = True
            # Note: nested with parallel() blocks are handled by
            # ParallelRangeTransform!
            # nested = node.parent
            nested = node.parent and node.parent.is_prange

        self.parallel_block_stack.append(node)

        nested = nested or len(self.parallel_block_stack) > 2
        if not self.parallel_errors and nested and not node.is_prange:
            error(node.pos, "Only prange() may be nested")
            self.parallel_errors = True

        if node.is_prange:
            # Visit only body, target and args while the prange node is on
            # the stack; the else clause is visited after popping it.
            child_attrs = node.child_attrs
            node.child_attrs = ['body', 'target', 'args']
            self.visitchildren(node)
            node.child_attrs = child_attrs

            self.parallel_block_stack.pop()
            if node.else_clause:
                node.else_clause = self.visit(node.else_clause)
        else:
            self.visitchildren(node)
            self.parallel_block_stack.pop()

        self.parallel_errors = False
        return node

    def visit_YieldExprNode(self, node):
        if self.parallel_block_stack:
            error(node.pos, "Yield not allowed in parallel sections")

        return node

    def visit_ReturnStatNode(self, node):
        # remember whether the return statement sits inside a parallel block
        node.in_parallel = bool(self.parallel_block_stack)
        return node


Craig Citro's avatar
Craig Citro committed
248
class MarkOverflowingArithmetic(CythonTransform):
    """Flag entries whose names take part in potentially overflowing arithmetic.

    While walking the tree, `might_overflow` tells whether the node being
    visited sits below a "dangerous" node.  Name nodes seen in that state get
    `might_overflow = True` set on their entry.  Nodes are classified as:

    * dangerous: binary operators other than &, | and ^, unary minus and
      in-place assignments - the flag is switched on for their children;
    * neutral: the &, | and ^ operators and other unary operators - the flag
      is passed through unchanged;
    * safe: every other node - the flag is switched off for their children.

    Names that are assigned an integer literal accepted by
    Utils.long_literal() are flagged as well.
    """

    # It may be possible to integrate this with the above for
    # performance improvements (though likely not worth it).

    might_overflow = False

    def __call__(self, root):
        # Scopes of the enclosing functions, innermost last.
        self.env_stack = []
        self.env = root.scope
        return super(MarkOverflowingArithmetic, self).__call__(root)

    def visit_safe_node(self, node):
        # Children are visited with the flag off; the previous state is
        # restored afterwards.
        self.might_overflow, saved = False, self.might_overflow
        self.visitchildren(node)
        self.might_overflow = saved
        return node

    def visit_neutral_node(self, node):
        self.visitchildren(node)
        return node

    def visit_dangerous_node(self, node):
        # Children are visited with the flag on; the previous state is
        # restored afterwards.
        self.might_overflow, saved = True, self.might_overflow
        self.visitchildren(node)
        self.might_overflow = saved
        return node

    def visit_FuncDefNode(self, node):
        # Names inside the function are looked up in its local scope.
        self.env_stack.append(self.env)
        self.env = node.local_scope
        self.visit_safe_node(node)
        self.env = self.env_stack.pop()
        return node

    def visit_NameNode(self, node):
        if self.might_overflow:
            entry = node.entry or self.env.lookup(node.name)
            if entry:
                entry.might_overflow = True
        return node

    def visit_BinopNode(self, node):
        # Exact membership test: "in '&|^'" would be a substring test and
        # would also accept '' or multi-character strings such as '&|'.
        if node.operator in ('&', '|', '^'):
            return self.visit_neutral_node(node)
        else:
            return self.visit_dangerous_node(node)

    visit_UnopNode = visit_neutral_node

    visit_UnaryMinusNode = visit_dangerous_node

    visit_InPlaceAssignmentNode = visit_dangerous_node

    visit_Node = visit_safe_node

    def visit_assignment(self, lhs, rhs):
        """Flag the entry of `lhs` if `rhs` is a long integer literal.

        Only applies when `lhs` is a name node and `rhs` is an IntNode whose
        value Utils.long_literal() accepts.
        """
        if (isinstance(rhs, ExprNodes.IntNode)
                and isinstance(lhs, ExprNodes.NameNode)
                and Utils.long_literal(rhs.value)):
            entry = lhs.entry or self.env.lookup(lhs.name)
            if entry:
                entry.might_overflow = True

    def visit_SingleAssignmentNode(self, node):
        self.visit_assignment(node.lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_CascadedAssignmentNode(self, node):
        for lhs in node.lhs_list:
            self.visit_assignment(lhs, node.rhs)
        self.visitchildren(node)
        return node
Robert Bradshaw's avatar
Robert Bradshaw committed
322

Stefan Behnel's avatar
Stefan Behnel committed
323
class PyObjectTypeInferer(object):
    """
    Trivial inferer: anything without a declared type is a PyObject.
    """
    def infer_types(self, scope):
        """
        Give every entry of `scope` whose type is still unspecified the
        generic Python object type.  Entries that already have a type are
        left untouched.
        """
        for entry in scope.entries.values():
            if entry.type is not unspecified_type:
                continue
            entry.type = py_object_type

Stefan Behnel's avatar
Stefan Behnel committed
335
class SimpleAssignmentTypeInferer(object):
    """
    Very basic type inference.

    Note: in order to support cross-closure type inference, this must be
    applied to nested scopes in top-down order.
    """
    def set_entry_type(self, entry, entry_type):
        """Set `entry_type` on `entry` and on each of entry.all_entries()."""
        entry.type = entry_type
        for e in entry.all_entries():
            e.type = entry_type

    def infer_types(self, scope):
        """Infer a type for every entry of `scope` that has none yet.

        The 'infer_types' directive selects the mode: True uses
        aggressive_spanning_type(), None ("safe mode") uses
        safe_spanning_type(), and any other value disables inference, in
        which case all untyped entries become Python objects.

        Inference first resolves the assignments whose dependencies are all
        resolved, falls back to partial inference for the remaining
        (circular) ones, assigns a spanning type to each entry and then
        re-infers until the entry types no longer change.  With the
        'infer_types.verbose' directive, each inferred type is reported.
        """
        enabled = scope.directives['infer_types']
        verbose = scope.directives['infer_types.verbose']

        if enabled == True:
            spanning_type = aggressive_spanning_type
        elif enabled is None: # safe mode
            spanning_type = safe_spanning_type
        else:
            for entry in scope.entries.values():
                if entry.type is unspecified_type:
                    self.set_entry_type(entry, py_object_type)
            return

        # Set of assignments that still need a type
        assignments = set()
        # Assignments whose type is known (or has been inferred)
        assmts_resolved = set()
        # assignment -> set of assignments its type depends on
        dependencies = {}
        # assignment -> name nodes its type depends on
        assmt_to_names = {}

        for entry in scope.entries.values():
            for assmt in entry.cf_assignments:
                names = assmt.type_dependencies()
                assmt_to_names[assmt] = names
                assmts = set()
                for node in names:
                    assmts.update(node.cf_state)
                dependencies[assmt] = assmts
            if entry.type is unspecified_type:
                assignments.update(entry.cf_assignments)
            else:
                assmts_resolved.update(entry.cf_assignments)

        def infer_name_node_type(node):
            types = [assmt.inferred_type for assmt in node.cf_state]
            if not types:
                node_type = py_object_type
            else:
                # Use the entry of the name itself rather than whatever entry
                # the enclosing function's loop variable happens to hold.
                entry = node.entry
                node_type = spanning_type(
                    types, entry.might_overflow, entry.pos)
            node.inferred_type = node_type

        def infer_name_node_type_partial(node):
            types = [assmt.inferred_type for assmt in node.cf_state
                     if assmt.inferred_type is not None]
            if not types:
                return
            # See infer_name_node_type() regarding the entry used here.
            entry = node.entry
            return spanning_type(types, entry.might_overflow, entry.pos)

        def resolve_assignments(assignments):
            # Resolve every assignment whose dependencies are all resolved;
            # resolved assignments are removed from `assignments` in place.
            resolved = set()
            for assmt in assignments:
                deps = dependencies[assmt]
                # All assignments are resolved
                if assmts_resolved.issuperset(deps):
                    for node in assmt_to_names[assmt]:
                        infer_name_node_type(node)
                    # Resolve assmt
                    assmt.infer_type()
                    assmts_resolved.add(assmt)
                    resolved.add(assmt)
            assignments.difference_update(resolved)
            return resolved

        def partial_infer(assmt):
            # Infer `assmt` from the already known types of its dependencies;
            # gives up if any dependency has no known type at all.
            partial_types = []
            for node in assmt_to_names[assmt]:
                partial_type = infer_name_node_type_partial(node)
                if partial_type is None:
                    return False
                partial_types.append((node, partial_type))
            for node, partial_type in partial_types:
                node.inferred_type = partial_type
            assmt.infer_type()
            return True

        partial_assmts = set()
        def resolve_partial(assignments):
            # try to handle circular references
            partials = set()
            for assmt in assignments:
                if assmt in partial_assmts:
                    continue
                # One attempt per assignment is enough: partial_infer()
                # already looks at all of its dependencies.
                if partial_infer(assmt):
                    partials.add(assmt)
                    assmts_resolved.add(assmt)
            partial_assmts.update(partials)
            return partials

        # Infer assignments
        while True:
            if not resolve_assignments(assignments):
                if not resolve_partial(assignments):
                    break
        inferred = set()
        # First pass
        for entry in scope.entries.values():
            if entry.type is not unspecified_type:
                continue
            entry_type = py_object_type
            if assmts_resolved.issuperset(entry.cf_assignments):
                types = [assmt.inferred_type for assmt in entry.cf_assignments]
                if types and Utils.all(types):
                    entry_type = spanning_type(
                        types, entry.might_overflow, entry.pos)
                    inferred.add(entry)
            self.set_entry_type(entry, entry_type)

        def reinfer():
            # Re-run inference with the current entry types; reports whether
            # any entry type changed.
            dirty = False
            for entry in inferred:
                types = [assmt.infer_type()
                         for assmt in entry.cf_assignments]
                new_type = spanning_type(types, entry.might_overflow, entry.pos)
                if new_type != entry.type:
                    self.set_entry_type(entry, new_type)
                    dirty = True
            return dirty

        # types propagation
        while reinfer():
            pass

        if verbose:
            for entry in inferred:
                message(entry.pos, "inferred '%s' to be of type '%s'" % (
                    entry.name, entry.type))

Robert Bradshaw's avatar
Robert Bradshaw committed
478

479 480
def find_spanning_type(type1, type2):
    """Return a type able to hold values of both `type1` and `type2`.

    Identical types span to themselves; mixing 'bint' with any other type
    yields a Python object; everything else is delegated to
    PyrexTypes.spanning_type().  C float, C double and Python float results
    are all reported as C double.
    """
    if type1 is not type2:
        bint = PyrexTypes.c_bint_type
        if type1 is bint or type2 is bint:
            # type inference can break the coercion back to a Python bool
            # if it returns an arbitrary int type here
            return py_object_type
        spanned = PyrexTypes.spanning_type(type1, type2)
    else:
        spanned = type1

    float_like = (PyrexTypes.c_double_type, PyrexTypes.c_float_type,
                  Builtin.float_type)
    if spanned in float_like:
        # Python's float type is just a C double, so it's safe to
        # use the C type instead
        return PyrexTypes.c_double_type
    return spanned

495
def aggressive_spanning_type(types, might_overflow, pos):
    """Span `types` and accept whatever comes out.

    The spanning type of all `types` is computed with find_spanning_type();
    a reference wrapper and then a const wrapper are removed from it.  For
    C++ classes, check_nullary_constructor() is called with `pos`.
    `might_overflow` is accepted for signature compatibility with
    safe_spanning_type() and is not used.
    """
    spanned = reduce(find_spanning_type, types)
    # (flag, attribute holding the wrapped type), outermost wrapper first
    wrappers = (
        ('is_reference', 'ref_base_type'),
        ('is_const', 'const_base_type'),
    )
    for flag, base_attr in wrappers:
        if getattr(spanned, flag):
            spanned = getattr(spanned, base_attr)
    if spanned.is_cpp_class:
        spanned.check_nullary_constructor(pos)
    return spanned
504

505
def safe_spanning_type(types, might_overflow, pos):
    """Span `types`, but only keep results that are safe to infer.

    The spanning type of all `types` is computed with find_spanning_type()
    and stripped of its reference and const wrappers.  The result is
    returned if it is a Python type other than 'str', a C double, a 'bint',
    a pointer, a C++ class, a struct, or an integer type while
    `might_overflow` is false.  In all other cases the generic Python object
    type is returned.  For C++ classes, check_nullary_constructor() is
    called with `pos`.
    """
    result_type = reduce(find_spanning_type, types)
    # Strip the reference before the const qualifier (the same order as in
    # aggressive_spanning_type()): for a reference to a const type, the
    # const wrapper only becomes visible once the reference is removed.
    if result_type.is_reference:
        result_type = result_type.ref_base_type
    if result_type.is_const:
        result_type = result_type.const_base_type
    if result_type.is_cpp_class:
        result_type.check_nullary_constructor(pos)
    if result_type.is_pyobject:
        # In theory, any specific Python type is always safe to
        # infer. However, inferring str can cause some existing code
        # to break, since we are also now much more strict about
        # coercion from str to char *. See trac #553.
        if result_type.name == 'str':
            return py_object_type
        else:
            return result_type
    elif result_type is PyrexTypes.c_double_type:
        # Python's float type is just a C double, so it's safe to use
        # the C type instead
        return result_type
    elif result_type is PyrexTypes.c_bint_type:
        # find_spanning_type() only returns 'bint' for clean boolean
        # operations without other int types, so this is safe, too
        return result_type
    elif result_type.is_ptr:
        # Any pointer except (signed|unsigned|) char* can't implicitly
        # become a PyObject, and inferring char* is now accepted, too.
        return result_type
    elif result_type.is_cpp_class:
        # These can't implicitly become Python objects either.
        return result_type
    elif result_type.is_struct:
        # Though we have struct -> object for some structs, this is uncommonly
        # used, won't arise in pure Python, and there shouldn't be side
        # effects, so I'm declaring this safe.
        return result_type
    # TODO: double complex should be OK as well, but we need
    # to make sure everything is supported.
    elif result_type.is_int and not might_overflow:
        return result_type
    return py_object_type

548

Robert Bradshaw's avatar
Robert Bradshaw committed
549 550
def get_type_inferer():
    """Return a new SimpleAssignmentTypeInferer instance."""
    inferer = SimpleAssignmentTypeInferer()
    return inferer