Optimize.py 92.3 KB
Newer Older
1 2
import Nodes
import ExprNodes
3
import PyrexTypes
4
import Visitor
5 6 7 8
import Builtin
import UtilNodes
import TypeSlots
import Symtab
9
import Options
10
import Naming
11

12
from Code import UtilityCode
13
from StringEncoding import EncodedString, BytesLiteral
14
from Errors import error
15 16
from ParseTreeTransforms import SkipDeclarations

17 18
import codecs

19 20 21 22 23
# Compatibility shims so the module runs on a wide range of Python versions.
try:
    reduce
except NameError:
    # Python 3 no longer provides reduce() as a builtin.
    from functools import reduce

try:
    set
except NameError:
    # Very old Pythons only provide the ``sets`` module.
    from sets import Set as set

29 30 31 32
class FakePythonEnv(object):
    """Stand-in environment object for building type test nodes and the like.

    The only attribute it exposes is ``nogil``, which is always False.
    NOTE(review): presumably callers only read ``nogil`` from it -- confirm.
    """
    nogil = False

33
def unwrap_node(node):
    """Return the expression hidden behind any ResultRefNode wrappers.

    ResultRefNodes are followed through their ``expression`` attribute
    until a node of a different type is reached; that node is returned.
    Any other node is returned unchanged.
    """
    while isinstance(node, UtilNodes.ResultRefNode):
        node = node.expression
    return node
37 38

def is_common_value(a, b):
    """Return True if *a* and *b* refer to the same simple value.

    Both nodes are first unwrapped from ResultRefNodes.  Two NameNodes match
    when they have the same name.  Two AttributeNodes match when the first
    is not a Python attribute access, their base objects match recursively,
    and they access the same attribute name.  Anything else never matches.
    """
    a = unwrap_node(a)
    b = unwrap_node(b)
    if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
        return a.name == b.name
    if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
        # Python attribute lookups may run arbitrary code, so they are never
        # considered to be a common value.
        return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute
    return False

47 48 49 50
class IterationTransform(Visitor.VisitorTransform):
    """Transform some common for-in loop patterns into efficient C loops:

    - for-in-dict loop becomes a while loop calling PyDict_Next()
    - for-in-enumerate is replaced by an external counter variable
    - for-in-range loop becomes a plain C for loop
    - for-in loops over C arrays and C array/pointer slices become loops
      over a C pointer
    """
    # C signature of PyDict_Next(dict, &pos, &key, &value).
    PyDict_Next_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("dict",  PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("pos",   PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("key",   PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
            ])

    PyDict_Next_name = EncodedString("PyDict_Next")

    # Symbol table entry attached to the NameNode that calls PyDict_Next().
    PyDict_Next_entry = Symtab.Entry(
        PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)

    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ModuleNode(self, node):
        # The module scope is needed for the coercions created below.
        self.current_scope = node.scope
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        # Use the function's own scope while transforming its body and
        # restore the enclosing scope afterwards, even on errors.
        oldscope = self.current_scope
        self.current_scope = node.entry.scope
        try:
            self.visitchildren(node)
        finally:
            self.current_scope = oldscope
        return node

    def visit_ForInStatNode(self, node):
        # Transform nested loops first, then try to optimise this one.
        self.visitchildren(node)
        return self._optimise_for_loop(node)

    def _optimise_for_loop(self, node):
        """Return an optimised replacement for the ForInStatNode *node*,
        or *node* itself if none of the supported patterns applies.
        """
        iterator = node.iterator.sequence
        if iterator.type is Builtin.dict_type:
            # like iterating over dict.keys()
            return self._transform_dict_iteration(
                node, dict_obj=iterator, keys=True, values=False)

        # C array (slice) iteration?
        if isinstance(iterator, ExprNodes.SliceIndexNode) and \
               (iterator.base.type.is_array or iterator.base.type.is_ptr):
            return self._transform_carray_iteration(node, iterator)
        elif iterator.type.is_array:
            return self._transform_carray_iteration(node, iterator)
        elif not isinstance(iterator, ExprNodes.SimpleCallNode):
            return node

        function = iterator.function
        # dict iteration?
        if isinstance(function, ExprNodes.AttributeNode) and \
                function.obj.type == Builtin.dict_type:
            dict_obj = function.obj
            method = function.attribute

            keys = values = False
            if method == 'iterkeys':
                keys = True
            elif method == 'itervalues':
                values = True
            elif method == 'iteritems':
                keys = values = True
            else:
                return node
            return self._transform_dict_iteration(
                node, dict_obj, keys, values)

        # enumerate() ?
        if iterator.self is None and function.is_name and \
               function.entry and function.entry.is_builtin and \
               function.name == 'enumerate':
            return self._transform_enumerate_iteration(node, iterator)

        # range() iteration?
        if Options.convert_range and node.target.type.is_int:
            if iterator.self is None and function.is_name and \
                   function.entry and function.entry.is_builtin and \
                   function.name in ('range', 'xrange'):
                return self._transform_range_iteration(node, iterator)

        return node

    def _transform_carray_iteration(self, node, slice_node):
        """Turn iteration over a C array, or over a slice of a C array or
        pointer, into a ForFromStatNode that steps a C pointer from the
        start address up to (excluding) the stop address.

        Returns *node* unchanged when the end of the iteration is unknown:
        a slice without a stop bound, or an array of unknown size.
        """
        if isinstance(slice_node, ExprNodes.SliceIndexNode):
            slice_base = slice_node.base
            start = slice_node.start
            stop = slice_node.stop
            step = None
            if not stop:
                return node
        elif slice_node.type.is_array and slice_node.type.size is not None:
            slice_base = slice_node
            start = None
            stop = ExprNodes.IntNode(
                slice_node.pos, value=str(slice_node.type.size))
            step = None
        else:
            return node

        ptr_type = slice_base.type
        if ptr_type.is_array:
            ptr_type = ptr_type.element_ptr_type()
        carray_ptr = slice_base.coerce_to_simple(self.current_scope)

        if start and start.constant_result != 0:
            start_ptr_node = ExprNodes.AddNode(
                start.pos,
                operand1=carray_ptr,
                operator='+',
                operand2=start,
                type=ptr_type)
        else:
            start_ptr_node = carray_ptr

        stop_ptr_node = ExprNodes.AddNode(
            stop.pos,
            operand1=carray_ptr,
            operator='+',
            operand2=stop,
            type=ptr_type
            ).coerce_to_simple(self.current_scope)

        counter = UtilNodes.TempHandle(ptr_type)
        counter_temp = counter.ref(node.target.pos)

        if slice_base.type.is_string and node.target.type.is_pyobject:
            # special case: char* -> bytes
            target_value = ExprNodes.SliceIndexNode(
                node.target.pos,
                start=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                stop=ExprNodes.IntNode(node.target.pos, value='1',
                                       constant_result=1,
                                       type=PyrexTypes.c_int_type),
                base=counter_temp,
                type=Builtin.bytes_type,
                is_temp=1)
        else:
            # target = counter[0], i.e. dereference the current pointer
            target_value = ExprNodes.IndexNode(
                node.target.pos,
                index=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                base=counter_temp,
                is_buffer_access=False,
                type=ptr_type.base_type)

        if target_value.type != node.target.type:
            target_value = target_value.coerce_to(node.target.type,
                                                  self.current_scope)

        target_assign = Nodes.SingleAssignmentNode(
            pos = node.target.pos,
            lhs = node.target,
            rhs = target_value)

        body = Nodes.StatListNode(
            node.pos,
            stats = [target_assign, node.body])

        for_node = Nodes.ForFromStatNode(
            node.pos,
            bound1=start_ptr_node, relation1='<=',
            target=counter_temp,
            relation2='<', bound2=stop_ptr_node,
            step=step, body=body,
            else_clause=node.else_clause,
            from_range=True)

        return UtilNodes.TempsBlockNode(
            node.pos, temps=[counter],
            body=for_node)

    def _transform_enumerate_iteration(self, node, enumerate_function):
        """Rewrite 'for i, x in enumerate(seq)' into a loop over 'seq' that
        keeps the counter in a separate variable.

        Only handles a target of exactly two items whose first item is a
        plain name of C integer or Python object type; anything else is
        returned unchanged.  Reports an error for a wrong argument count.
        """
        args = enumerate_function.arg_tuple.args
        if len(args) == 0:
            error(enumerate_function.pos,
                  "enumerate() requires an iterable argument")
            return node
        elif len(args) > 1:
            error(enumerate_function.pos,
                  "enumerate() takes at most 1 argument")
            return node

        if not node.target.is_sequence_constructor:
            # leave this untouched for now
            return node
        targets = node.target.args
        if len(targets) != 2:
            # leave this untouched for now
            return node
        if not isinstance(targets[0], ExprNodes.NameNode):
            # leave this untouched for now
            return node

        enumerate_target, iterable_target = targets
        counter_type = enumerate_target.type

        if not counter_type.is_pyobject and not counter_type.is_int:
            # nothing we can do here, I guess
            return node

        # Counter variable, initialised to 0 before the loop.
        temp = UtilNodes.LetRefNode(ExprNodes.IntNode(enumerate_function.pos,
                                                      value='0',
                                                      type=counter_type,
                                                      constant_result=0))
        inc_expression = ExprNodes.AddNode(
            enumerate_function.pos,
            operand1 = temp,
            operand2 = ExprNodes.IntNode(node.pos, value='1',
                                         type=counter_type,
                                         constant_result=1),
            operator = '+',
            type = counter_type,
            is_temp = counter_type.is_pyobject
            )

        # At the top of each iteration: i = counter; counter = counter + 1
        loop_body = [
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = enumerate_target,
                rhs = temp),
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = temp,
                rhs = inc_expression)
            ]

        if isinstance(node.body, Nodes.StatListNode):
            node.body.stats = loop_body + node.body.stats
        else:
            loop_body.append(node.body)
            node.body = Nodes.StatListNode(
                node.body.pos,
                stats = loop_body)

        node.target = iterable_target
        node.item = node.item.coerce_to(iterable_target.type, self.current_scope)
        node.iterator.sequence = args[0]

        # recurse into loop to check for further optimisations
        return UtilNodes.LetNode(temp, self._optimise_for_loop(node))

    def _transform_range_iteration(self, node, range_function):
        """Rewrite 'for i in range(...)' / 'xrange(...)' into a C for-from loop.

        Applies only to calls with one to three positional arguments and,
        when a step is given, only if it is a non-zero compile-time integer
        constant (its sign selects the loop direction).  Returns *node*
        unchanged otherwise.
        """
        args = range_function.arg_tuple.args
        if not 1 <= len(args) <= 3:
            # invalid call: leave it to the regular call handling
            return node
        if len(args) < 3:
            step_pos = range_function.pos
            step_value = 1
            step = ExprNodes.IntNode(step_pos, value='1',
                                     constant_result=1)
        else:
            step = args[2]
            step_pos = step.pos
            if not isinstance(step.constant_result, (int, long)):
                # cannot determine step direction
                return node
            step_value = step.constant_result
            if step_value == 0:
                # will lead to an error elsewhere
                return node
            if step_value < 0 or not isinstance(step, ExprNodes.IntNode):
                # The loop is given the step magnitude; the direction is
                # encoded in the relations below.  Build a fresh node rather
                # than rewriting the user's literal in place, which would
                # leave its constant_result out of sync with its value.
                abs_step = abs(step_value)
                step = ExprNodes.IntNode(step_pos, value=str(abs_step),
                                         constant_result=abs_step)

        if step_value < 0:
            relation1 = '>='
            relation2 = '>'
        else:
            relation1 = '<='
            relation2 = '<'

        if len(args) == 1:
            bound1 = ExprNodes.IntNode(range_function.pos, value='0',
                                       constant_result=0)
            bound2 = args[0].coerce_to_integer(self.current_scope)
        else:
            bound1 = args[0].coerce_to_integer(self.current_scope)
            bound2 = args[1].coerce_to_integer(self.current_scope)
        step = step.coerce_to_integer(self.current_scope)

        if not bound2.is_literal:
            # stop bound must be immutable => keep it in a temp var
            bound2_is_temp = True
            bound2 = UtilNodes.LetRefNode(bound2)
        else:
            bound2_is_temp = False

        for_node = Nodes.ForFromStatNode(
            node.pos,
            target=node.target,
            bound1=bound1, relation1=relation1,
            relation2=relation2, bound2=bound2,
            step=step, body=node.body,
            else_clause=node.else_clause,
            from_range=True)

        if bound2_is_temp:
            for_node = UtilNodes.LetNode(bound2, for_node)

        return for_node

    def _transform_dict_iteration(self, node, dict_obj, keys, values):
        """Replace iteration over the dict *dict_obj* by a while loop that
        calls PyDict_Next().

        *keys* / *values* select which of the key and value slots are
        fetched and assigned to the loop target.  When both are requested,
        a two-item sequence target is unpacked into separate assignments,
        and a non-sequence target receives a (key, value) tuple.  Sequence
        targets of any other length are left unchanged.
        """
        py_object_ptr = PyrexTypes.c_void_ptr_type

        temps = []
        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
        temps.append(temp)
        dict_temp = temp.ref(dict_obj.pos)
        temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        temps.append(temp)
        pos_temp = temp.ref(node.pos)
        pos_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=pos_temp,
            type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
        if keys:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            key_temp = temp.ref(node.target.pos)
            key_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=key_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            # pass NULL for the slot we do not need
            key_temp_addr = key_temp = ExprNodes.NullNode(
                pos=node.target.pos)
        if values:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            value_temp = temp.ref(node.target.pos)
            value_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=value_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            # pass NULL for the slot we do not need
            value_temp_addr = value_temp = ExprNodes.NullNode(
                pos=node.target.pos)

        key_target = value_target = node.target
        tuple_target = None
        if keys and values:
            if node.target.is_sequence_constructor:
                if len(node.target.args) == 2:
                    key_target, value_target = node.target.args
                else:
                    # unusual case that may or may not lead to an error
                    return node
            else:
                tuple_target = node.target

        def coerce_object_to(obj_node, dest_type):
            # Returns (expression to assign, coercion statement or None).
            if dest_type.is_pyobject:
                if dest_type != obj_node.type:
                    if dest_type.is_extension_type or dest_type.is_builtin_type:
                        obj_node = ExprNodes.PyTypeTestNode(
                            obj_node, dest_type, self.current_scope, notnone=True)
                result = ExprNodes.TypecastNode(
                    obj_node.pos,
                    operand = obj_node,
                    type = dest_type)
                return (result, None)
            else:
                temp = UtilNodes.TempHandle(dest_type)
                temps.append(temp)
                temp_result = temp.ref(obj_node.pos)
                class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
                    def result(self):
                        return temp_result.result()
                    def generate_execution_code(self, code):
                        self.generate_result_code(code)
                return (temp_result, CoercedTempNode(dest_type, obj_node, self.current_scope))

        if isinstance(node.body, Nodes.StatListNode):
            body = node.body
        else:
            body = Nodes.StatListNode(pos = node.body.pos,
                                      stats = [node.body])

        if tuple_target:
            tuple_result = ExprNodes.TupleNode(
                pos = tuple_target.pos,
                args = [key_temp, value_temp],
                is_temp = 1,
                type = Builtin.tuple_type,
                )
            body.stats.insert(
                0, Nodes.SingleAssignmentNode(
                    pos = tuple_target.pos,
                    lhs = tuple_target,
                    rhs = tuple_result))
        else:
            # execute all coercions before the assignments
            coercion_stats = []
            assign_stats = []
            if keys:
                temp_result, coercion = coerce_object_to(
                    key_temp, key_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = key_temp.pos,
                        lhs = key_target,
                        rhs = temp_result))
            if values:
                temp_result, coercion = coerce_object_to(
                    value_temp, value_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = value_temp.pos,
                        lhs = value_target,
                        rhs = temp_result))
            body.stats[0:0] = coercion_stats + assign_stats

        result_code = [
            Nodes.SingleAssignmentNode(
                pos = dict_obj.pos,
                lhs = dict_temp,
                rhs = dict_obj),
            Nodes.SingleAssignmentNode(
                pos = node.pos,
                lhs = pos_temp,
                rhs = ExprNodes.IntNode(node.pos, value='0',
                                        constant_result=0)),
            Nodes.WhileStatNode(
                pos = node.pos,
                condition = ExprNodes.SimpleCallNode(
                    pos = dict_obj.pos,
                    type = PyrexTypes.c_bint_type,
                    function = ExprNodes.NameNode(
                        pos = dict_obj.pos,
                        name = self.PyDict_Next_name,
                        type = self.PyDict_Next_func_type,
                        entry = self.PyDict_Next_entry),
                    args = [dict_temp, pos_temp_addr,
                            key_temp_addr, value_temp_addr]
                    ),
                body = body,
                else_clause = node.else_clause
                )
            ]

        return UtilNodes.TempsBlockNode(
            node.pos, temps=temps,
            body=Nodes.StatListNode(
                node.pos,
                stats = result_code
                ))


504 505 506 507
class SwitchTransform(Visitor.VisitorTransform):
    """
    This transformation tries to turn long if statements into C switch statements.
    The requirement is that every clause be an (or of) var == value, where the var
    is common among all clauses and both var and value are ints.
    """
    NO_MATCH = (None, None, None)

    def extract_conditions(self, cond, allow_not_in):
        """Analyse *cond* and return (not_in, var, [values]) or NO_MATCH.

        Matches 'var == value' and chains of them joined by 'or'.  When
        *allow_not_in* is true, 'var != value' joined by 'and' is matched
        as well, with not_in set to True.
        """
        while True:
            if isinstance(cond, ExprNodes.CoerceToTempNode):
                cond = cond.arg
            elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
                # this is what we get from the FlattenInListTransform
                cond = cond.subexpression
            elif isinstance(cond, ExprNodes.TypecastNode):
                cond = cond.operand
            else:
                break

        if isinstance(cond, ExprNodes.PrimaryCmpNode):
            if cond.cascade is None and not cond.is_python_comparison():
                if cond.operator == '==':
                    not_in = False
                elif allow_not_in and cond.operator == '!=':
                    not_in = True
                else:
                    return self.NO_MATCH
                # this looks somewhat silly, but it does the right
                # checks for NameNode and AttributeNode
                if is_common_value(cond.operand1, cond.operand1):
                    if cond.operand2.is_literal:
                        return not_in, cond.operand1, [cond.operand2]
                    elif getattr(cond.operand2, 'entry', None) \
                             and cond.operand2.entry.is_const:
                        return not_in, cond.operand1, [cond.operand2]
                if is_common_value(cond.operand2, cond.operand2):
                    if cond.operand1.is_literal:
                        return not_in, cond.operand2, [cond.operand1]
                    elif getattr(cond.operand1, 'entry', None) \
                             and cond.operand1.entry.is_const:
                        return not_in, cond.operand2, [cond.operand1]
        elif isinstance(cond, ExprNodes.BoolBinopNode):
            if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'):
                allow_not_in = (cond.operator == 'and')
                not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in)
                not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in)
                if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2):
                    if (not not_in_1) or allow_not_in:
                        return not_in_1, t1, c1+c2
        return self.NO_MATCH

    def extract_common_conditions(self, common_var, condition, allow_not_in):
        """Like extract_conditions(), but additionally require that the
        tested variable matches *common_var* (if given) and that the
        variable and all values are C integers.
        """
        not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
        if var is None:
            return self.NO_MATCH
        elif common_var is not None and not is_common_value(var, common_var):
            return self.NO_MATCH
        elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
            return self.NO_MATCH
        return not_in, var, conditions

    def has_duplicate_values(self, condition_values):
        """Return True if two of the case values may be equal, or if a value
        cannot be identified at all (playing safe in that case).
        """
        # duplicated values don't work in a switch statement
        seen = set()
        for value in condition_values:
            if value.constant_result is not ExprNodes.not_a_constant:
                value_key = value.constant_result
            else:
                # this isn't completely safe as we don't know the
                # final C value, but this is about the best we can do
                value_key = getattr(getattr(value, 'entry', None), 'cname', None)
                if value_key is None:
                    # no way to identify this value => play safe
                    return True
            if value_key in seen:
                return True
            seen.add(value_key)
        return False

    def visit_IfStatNode(self, node):
        common_var = None
        cases = []
        for if_clause in node.if_clauses:
            _, common_var, conditions = self.extract_common_conditions(
                common_var, if_clause.condition, False)
            if common_var is None:
                self.visitchildren(node)
                return node
            cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
                                              conditions = conditions,
                                              body = if_clause.body))

        if sum([ len(case.conditions) for case in cases ]) < 2:
            self.visitchildren(node)
            return node
        if self.has_duplicate_values(sum([case.conditions for case in cases], [])):
            self.visitchildren(node)
            return node

        common_var = unwrap_node(common_var)
        switch_node = Nodes.SwitchStatNode(pos = node.pos,
                                           test = common_var,
                                           cases = cases,
                                           else_clause = node.else_clause)
        # The case bodies and the else clause have not been visited yet:
        # recurse so that nested if-chains and conditions get optimised too.
        self.visitchildren(switch_node)
        return switch_node

    def visit_CondExprNode(self, node):
        not_in, common_var, conditions = self.extract_common_conditions(
            None, node.test, True)
        if common_var is None \
               or len(conditions) < 2 \
               or self.has_duplicate_values(conditions):
            self.visitchildren(node)
            return node
        return self.build_simple_switch_statement(
            node, common_var, conditions, not_in,
            node.true_val, node.false_val)

    def visit_BoolBinopNode(self, node):
        not_in, common_var, conditions = self.extract_common_conditions(
            None, node, True)
        if common_var is None \
               or len(conditions) < 2 \
               or self.has_duplicate_values(conditions):
            self.visitchildren(node)
            return node

        return self.build_simple_switch_statement(
            node, common_var, conditions, not_in,
            ExprNodes.BoolNode(node.pos, value=True),
            ExprNodes.BoolNode(node.pos, value=False))

    def build_simple_switch_statement(self, node, common_var, conditions,
                                      not_in, true_val, false_val):
        """Build an expression that evaluates a single-case switch on
        *common_var* and yields *true_val* if it matches one of
        *conditions* (or, with *not_in*, if it matches none), else
        *false_val*.
        """
        result_ref = UtilNodes.ResultRefNode(node)
        true_body = Nodes.SingleAssignmentNode(
            node.pos,
            lhs = result_ref,
            rhs = true_val,
            first = True)
        false_body = Nodes.SingleAssignmentNode(
            node.pos,
            lhs = result_ref,
            rhs = false_val,
            first = True)

        if not_in:
            true_body, false_body = false_body, true_body

        cases = [Nodes.SwitchCaseNode(pos = node.pos,
                                      conditions = conditions,
                                      body = true_body)]

        common_var = unwrap_node(common_var)
        switch_node = Nodes.SwitchStatNode(pos = node.pos,
                                           test = common_var,
                                           cases = cases,
                                           else_clause = false_body)
        return UtilNodes.TempResultFromStatNode(result_ref, switch_node)

    visit_Node = Visitor.VisitorTransform.recurse_to_children
662
                              
663

664
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    This transformation flattens "x in [val1, ..., valn]" into a sequential list
    of comparisons.
    """

    def visit_PrimaryCmpNode(self, node):
        """Rewrite 'x in (a, b, ...)' as 'x == a or x == b or ...' and
        'x not in (a, b, ...)' as 'x != a and x != b and ...', evaluating
        'x' only once.  Other comparisons are returned unchanged.
        """
        self.visitchildren(node)
        if node.cascade is not None:
            return node
        elif node.operator == 'in':
            conjunction = 'or'
            eq_or_neq = '=='
        elif node.operator == 'not_in':
            conjunction = 'and'
            eq_or_neq = '!='
        else:
            return node

        if not isinstance(node.operand2, (ExprNodes.TupleNode,
                                          ExprNodes.ListNode,
                                          ExprNodes.SetNode)):
            return node

        args = node.operand2.args
        if len(args) == 0:
            # The result would be a constant, but replacing the comparison
            # by a BoolNode would drop the evaluation of the left operand,
            # which may have side effects.  Leave the node alone.
            return node

        # evaluate the left operand only once and reuse its value
        lhs = UtilNodes.ResultRefNode(node.operand1)

        # NOTE(review): the resulting 'or'/'and' chain short-circuits, so
        # items after the deciding comparison are never evaluated, whereas
        # the original container display evaluates all of its items.
        # Items with side effects may need to be pre-evaluated.
        conds = []
        for arg in args:
            cond = ExprNodes.PrimaryCmpNode(
                                pos = node.pos,
                                operand1 = lhs,
                                operator = eq_or_neq,
                                operand2 = arg,
                                cascade = None)
            conds.append(ExprNodes.TypecastNode(
                                pos = node.pos,
                                operand = cond,
                                type = PyrexTypes.c_bint_type))
        def concat(left, right):
            return ExprNodes.BoolBinopNode(
                                pos = node.pos,
                                operator = conjunction,
                                operand1 = left,
                                operand2 = right)

        condition = reduce(concat, conds)
        return UtilNodes.EvalWithTempExprNode(lhs, condition)

    visit_Node = Visitor.VisitorTransform.recurse_to_children
717 718


719 720 721 722 723 724
class DropRefcountingTransform(Visitor.VisitorTransform):
    """Drop ref-counting in places where that is known to be safe.

    So far only parallel assignments (e.g. 'a, b = b, a') whose operands
    are names or non-Python attribute chains are handled.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ParallelAssignmentNode(self, node):
        """Disable managed references for a safe parallel swap assignment.

        Every statement must be a SingleAssignmentNode whose two sides are
        accepted by _extract_operand(), and the name paths on the left must
        be exactly the name paths on the right, without duplicates.  In any
        other case 'node' is returned unchanged.  Note that the children of
        'node' are not visited.
        """
        lhs_names, rhs_names = [], []
        lhs_indices, rhs_indices = [], []
        coerced_temps = []

        for stat in node.stats:
            if not isinstance(stat, Nodes.SingleAssignmentNode):
                # includes CascadedAssignmentNode (FIXME: not supported yet)
                return node
            if not self._extract_operand(stat.lhs, lhs_names,
                                         lhs_indices, coerced_temps):
                return node
            if not self._extract_operand(stat.rhs, rhs_names,
                                         rhs_indices, coerced_temps):
                return node

        if lhs_names or rhs_names:
            # lhs/rhs names must be a non-redundant permutation
            lhs_paths = set([path for path, _ in lhs_names])
            rhs_paths = set([path for path, _ in rhs_names])
            if lhs_paths != rhs_paths or len(lhs_paths) != len(rhs_names):
                return node

        if lhs_indices or rhs_indices:
            # base name and index of the index nodes must also be a
            # non-redundant permutation
            lhs_ids = []
            for index_node in lhs_indices:
                index_id = self._extract_index_id(index_node)
                if not index_id:
                    return node
                lhs_ids.append(index_id)
            rhs_ids = []
            for index_node in rhs_indices:
                index_id = self._extract_index_id(index_node)
                if not index_id:
                    return node
                rhs_ids.append(index_id)

            lhs_id_set = set(lhs_ids)
            if lhs_id_set != set(rhs_ids) or len(lhs_id_set) != len(rhs_indices):
                return node

            # really supporting IndexNode requires support in
            # __Pyx_GetItemInt(), so let's stop short for now
            return node

        temp_args = [temp.arg for temp in coerced_temps]
        for temp in coerced_temps:
            temp.use_managed_ref = False

        for _, name_node in lhs_names + rhs_names:
            # the argument of a coerced temp keeps its managed reference
            if name_node not in temp_args:
                name_node.use_managed_ref = False

        for index_node in lhs_indices + rhs_indices:
            index_node.use_managed_ref = False

        return node

    def _extract_operand(self, node, names, indices, temps):
        """Classify one side of a swap assignment.

        Python-object operands that are a name or a chain of non-Python
        attributes are recorded in 'names' as (dotted_path, node).  Indexing
        of a list-typed name by a C integer is recorded in 'indices'.  A
        CoerceToTempNode wrapper is recorded in 'temps' and looked through.
        Returns True if the operand was recorded, False otherwise.
        """
        node = unwrap_node(node)
        if not node.type.is_pyobject:
            return False
        if isinstance(node, ExprNodes.CoerceToTempNode):
            temps.append(node)
            node = node.arg

        path_parts = []
        current = node
        while isinstance(current, ExprNodes.AttributeNode):
            if current.is_py_attr:
                return False
            path_parts.append(current.member)
            current = current.obj

        if isinstance(current, ExprNodes.NameNode):
            path_parts.append(current.name)
            path_parts.reverse()
            names.append(('.'.join(path_parts), node))
            return True

        if isinstance(node, ExprNodes.IndexNode):
            base = node.base
            if base.type != Builtin.list_type \
                    or not node.index.type.is_int \
                    or not isinstance(base, ExprNodes.NameNode):
                return False
            indices.append(node)
            return True

        return False

    def _extract_index_id(self, index_node):
        """Return (base_name, index_name) for 'name[name]', else None."""
        base = index_node.base
        index = index_node.index
        if not isinstance(index, ExprNodes.NameNode):
            # FIXME: constant indices (ConstNode) are not supported yet
            return None
        return (base.name, index.name)

834

835 836 837 838 839 840 841
class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
    """Optimize some common calls to builtin types *before* the type
    analysis phase and *after* the declarations analysis phase.

    This transform cannot make use of any argument types, but it can
    restructure the tree in a way that the type analysis phase can
    respond to.

    Introducing C function calls here may not be a good idea.  Move
    them to the OptimizeBuiltinCalls transform instead, which runs
    after type analysis.
    """
    # only intercept on call nodes
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_SimpleCallNode(self, node):
        """Pass 'builtin(args...)' calls to '_handle_simple_function_<name>'."""
        self.visitchildren(node)
        function = node.function
        if not self._function_is_builtin_name(function):
            return node
        return self._dispatch_to_handler(node, function, node.args)

    def visit_GeneralCallNode(self, node):
        """Pass builtin calls with keyword/star arguments to a handler.

        Only calls whose positional arguments form a literal tuple are
        considered.  A call that has a '**' argument but no explicit keyword
        arguments is left alone: it would otherwise be passed to a
        '_handle_simple_function_<name>' handler, which only sees the
        positional arguments and could silently drop the '**' mapping.
        """
        self.visitchildren(node)
        function = node.function
        if not self._function_is_builtin_name(function):
            return node
        arg_tuple = node.positional_args
        if not isinstance(arg_tuple, ExprNodes.TupleNode):
            return node
        kwargs = node.keyword_args
        if kwargs is None and node.starstar_arg:
            return node
        args = arg_tuple.args
        return self._dispatch_to_handler(
            node, function, args, kwargs)

    def _function_is_builtin_name(self, function):
        """Return True if 'function' is a name that resolves to the builtin scope."""
        if not function.is_name:
            return False
        entry = self.current_env().lookup(function.name)
        if not entry or getattr(entry, 'scope', None) is not Builtin.builtin_scope:
            return False
        return True

    def _dispatch_to_handler(self, node, function, args, kwargs=None):
        """Call '_handle_simple_function_<name>' (kwargs is None) or
        '_handle_general_function_<name>' if it exists; else return 'node'.
        """
        if kwargs is None:
            handler_name = '_handle_simple_function_%s' % function.name
        else:
            handler_name = '_handle_general_function_%s' % function.name
        handle_call = getattr(self, handler_name, None)
        if handle_call is not None:
            if kwargs is None:
                return handle_call(node, args)
            else:
                return handle_call(node, args, kwargs)
        return node

    def _inject_capi_function(self, node, cname, func_type, utility_code=None):
        """Replace the called function of 'node' by the C function 'cname'."""
        node.function = ExprNodes.PythonCapiFunctionNode(
            node.function.pos, node.function.name, cname, func_type,
            utility_code = utility_code)

    def _error_wrong_arg_count(self, function_name, node, args, expected=None):
        """Report a call of 'function_name' with an unexpected number of args.

        'expected' may be None/0 (unspecified), an int, or a descriptive
        string such as '2 or 3'.
        """
        if not expected: # None or 0
            arg_str = ''
        elif isinstance(expected, basestring) or expected > 1:
            arg_str = '...'
        elif expected == 1:
            arg_str = 'x'
        else:
            arg_str = ''
        if expected is not None:
            expected_str = 'expected %s, ' % expected
        else:
            expected_str = ''
        error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
            function_name, arg_str, expected_str, len(args)))

    # specific handlers for simple call nodes

    def _handle_simple_function_set(self, node, pos_args):
        """Replace set([a,b,...]) by a literal set {a,b,...} and
        set([ x for ... ]) by a literal { x for ... }.
        """
        arg_count = len(pos_args)
        if arg_count == 0:
            return ExprNodes.SetNode(node.pos, args=[],
                                     type=Builtin.set_type)
        if arg_count > 1:
            return node
        iterable = pos_args[0]
        if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)):
            return ExprNodes.SetNode(node.pos, args=iterable.args)
        elif isinstance(iterable, ExprNodes.ComprehensionNode) and \
                 isinstance(iterable.target, (ExprNodes.ListNode,
                                              ExprNodes.SetNode)):
            iterable.target = ExprNodes.SetNode(node.pos, args=[])
            iterable.pos = node.pos
            return iterable
        else:
            return node

    def _handle_simple_function_dict(self, node, pos_args):
        """Replace dict([ (a,b) for ... ]) by a literal { a:b for ... }.
        """
        if len(pos_args) != 1:
            return node
        arg = pos_args[0]
        if isinstance(arg, ExprNodes.ComprehensionNode) and \
               isinstance(arg.target, (ExprNodes.ListNode,
                                       ExprNodes.SetNode)):
            append_node = arg.append
            # only 2-item tuple/list results can become key/value pairs
            if isinstance(append_node.expr, (ExprNodes.TupleNode, ExprNodes.ListNode)) and \
                   len(append_node.expr.args) == 2:
                key_node, value_node = append_node.expr.args
                target_node = ExprNodes.DictNode(
                    pos=arg.target.pos, key_value_pairs=[])
                new_append_node = ExprNodes.DictComprehensionAppendNode(
                    append_node.pos, target=target_node,
                    key_expr=key_node, value_expr=value_node)
                arg.target = target_node
                arg.type = target_node.type
                replace_in = Visitor.RecursiveNodeReplacer(append_node, new_append_node)
                return replace_in(arg)
        return node

    def _handle_simple_function_float(self, node, pos_args):
        """Replace float() by the constant 0.0; report calls with >1 args."""
        if len(pos_args) == 0:
            return ExprNodes.FloatNode(node.pos, value='0.0')
        if len(pos_args) > 1:
            self._error_wrong_arg_count('float', node, pos_args, 1)
        return node

    # specific handlers for general call nodes

    def _handle_general_function_dict(self, node, pos_args, kwargs):
        """Replace dict(a=b,c=d,...) by the underlying keyword dict
        construction which is done anyway.
        """
        if len(pos_args) > 0:
            return node
        if not isinstance(kwargs, ExprNodes.DictNode):
            return node
        if node.starstar_arg:
            # we could optimize this by updating the kw dict instead
            return node
        return kwargs


982
class OptimizeBuiltinCalls(Visitor.EnvTransform):
Stefan Behnel's avatar
Stefan Behnel committed
983
    """Optimize some common methods calls and instantiation patterns
984 985 986 987 988
    for builtin types *after* the type analysis phase.

    Running after type analysis, this transform can only perform
    function replacements that do not alter the function return type
    in a way that was not anticipated by the type analysis.
989
    """
990 991
    # only intercept on call nodes
    visit_Node = Visitor.VisitorTransform.recurse_to_children
992

993
    def visit_GeneralCallNode(self, node):
        """Pass a keyword-argument call of a Python object to an optimiser.

        Only calls with a literal positional argument tuple and without a
        '**' argument are considered; anything else is returned unchanged.
        """
        self.visitchildren(node)
        function = node.function
        if not function.type.is_pyobject:
            return node
        arg_tuple = node.positional_args
        if not isinstance(arg_tuple, ExprNodes.TupleNode) or node.starstar_arg:
            return node
        return self._dispatch_to_handler(
            node, function, arg_tuple.args, node.keyword_args)
1006 1007 1008

    def visit_SimpleCallNode(self, node):
        """Pass a positional-only call to an optimiser.

        C function calls use their argument list directly.  Python calls
        are only considered when their arguments form a literal tuple.
        """
        self.visitchildren(node)
        function = node.function
        if not function.type.is_pyobject:
            return self._dispatch_to_handler(node, function, node.args)
        packed_args = node.arg_tuple
        if not isinstance(packed_args, ExprNodes.TupleNode):
            return node
        return self._dispatch_to_handler(node, function, packed_args.args)
1019

1020 1021
    ### cleanup to avoid redundant coercions to/from Python types

1022 1023 1024
    def _visit_PyTypeTestNode(self, node):
        # disabled (renamed with a leading underscore) - appears to break
        # assignments in some cases, and also drops a None check, which
        # might still be required
        """Remove a type test that became redundant after tree changes.

        If visiting the children replaced the tested argument by a new node
        whose type equals the tested type, that new node is returned.
        """
        previous_arg = node.arg
        self.visitchildren(node)
        current_arg = node.arg
        if current_arg is previous_arg or current_arg.type != node.type:
            return node
        return current_arg

1033 1034 1035 1036 1037 1038 1039 1040 1041
    def visit_TypecastNode(self, node):
        """Remove a cast whose target type equals the operand's type."""
        self.visitchildren(node)
        operand = node.operand
        if node.type == operand.type:
            # casting to the same type is a no-op
            return operand
        return node

1042 1043 1044 1045 1046 1047 1048
    def visit_CoerceToBooleanNode(self, node):
        """Test a C value directly instead of via C->Python->bool coercion.

        Applies when the argument is a CoerceToPyTypeNode producing a
        generic object or a bool; the wrapped C value is coerced to a
        boolean instead.
        """
        self.visitchildren(node)
        arg = node.arg
        is_redundant = isinstance(arg, ExprNodes.CoerceToPyTypeNode) and \
            arg.type in (PyrexTypes.py_object_type, Builtin.bool_type)
        if not is_redundant:
            return node
        return arg.arg.coerce_to_boolean(self.current_env())

1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065
    def visit_CoerceFromPyTypeNode(self, node):
        """Remove Python->C conversions that became redundant.

        Handles three cases: the argument is no longer a Python object
        (a plain C coercion suffices); the argument is a C->Python coercion
        whose C value fits the target type; and an int()/float() call whose
        result is converted to a C numeric type anyway.
        """
        self.visitchildren(node)
        arg = node.arg
        target_type = node.type
        if not arg.type.is_pyobject:
            # no Python conversion left at all, just do a C coercion instead
            if target_type == arg.type:
                return arg
            return arg.coerce_to(target_type, self.current_env())
        if isinstance(arg, ExprNodes.CoerceToPyTypeNode) \
                and arg.type is PyrexTypes.py_object_type \
                and target_type.assignable_from(arg.arg.type):
            # completely redundant C->Py->C coercion
            return arg.arg.coerce_to(target_type, self.current_env())
        if isinstance(arg, ExprNodes.SimpleCallNode) \
                and (target_type.is_int or target_type.is_float):
            return self._optimise_numeric_cast_call(node, arg)
        return node

    def _optimise_numeric_cast_call(self, node, arg):
        """Replace int(x)/float(x) on a C value by the value or a C cast.

        'node' is the Python->C coercion node wrapping the call 'arg'.  The
        call is only replaced for a builtin-typed name called with a single
        literal-tuple argument that was a C value before being coerced to
        a Python object.  Otherwise 'node' is returned unchanged.
        """
        function = arg.function
        is_candidate = isinstance(function, ExprNodes.NameNode) \
            and function.type.is_builtin_type \
            and isinstance(arg.arg_tuple, ExprNodes.TupleNode)
        if not is_candidate:
            return node
        call_args = arg.arg_tuple.args
        if len(call_args) != 1:
            return node
        value = call_args[0]
        if isinstance(value, ExprNodes.CoerceToPyTypeNode):
            value = value.arg
        elif value.type.is_pyobject:
            # play safe: Python conversion might work on all sorts of things
            return node

        if function.name == 'int':
            applies = value.type.is_int or node.type.is_int
        elif function.name == 'float':
            applies = value.type.is_float or node.type.is_float
        else:
            applies = False
        if not applies:
            return node

        if value.type == node.type:
            return value
        if node.type.assignable_from(value.type) or value.type.is_float:
            return ExprNodes.TypecastNode(
                node.pos, operand=value, type=node.type)
        return node

    ### dispatch to specific optimisers

1110 1111 1112 1113 1114 1115 1116
    def _find_handler(self, match_name, has_kwargs):
        """Look up the optimiser method for 'match_name'.

        Tries '_handle_general_<match_name>' (when has_kwargs is true) or
        '_handle_simple_<match_name>' first, then '_handle_any_<match_name>'.
        Returns None if neither exists.
        """
        if has_kwargs:
            call_type = 'general'
        else:
            call_type = 'simple'
        for kind in (call_type, 'any'):
            handler = getattr(self, '_handle_%s_%s' % (kind, match_name), None)
            if handler is not None:
                return handler
        return None

1117
    def _dispatch_to_handler(self, node, function, arg_list, kwargs=None):
        """Pass a call to the matching '_handle_*' optimiser, if there is one.

        Calls of builtin functions are looked up as
        '_handle_{simple|general|any}_function_<name>'.  Method calls on
        Python objects are looked up as '..._method_<type>_<attr>', falling
        back to '..._slot<attr>' for attributes in
        TypeSlots.method_name_to_slot and for '__new__'.  Method handlers
        get the object prepended to 'arg_list' (except for unbound method
        calls like 'list.append(L, x)') and an 'is_unbound_method' flag.

        Returns the handler's result, or 'node' unchanged if no handler
        applies.
        """
        if function.is_name:
            # we only consider functions that are either builtin
            # Python functions or builtins that were already replaced
            # into a C function call (defined in the builtin scope)
            if not function.entry:
                return node
            is_builtin = function.entry.is_builtin \
                         or getattr(function.entry, 'scope', None) is Builtin.builtin_scope
            if not is_builtin:
                return node
            function_handler = self._find_handler(
                "function_%s" % function.name, kwargs)
            if function_handler is None:
                return node
            if kwargs:
                return function_handler(node, arg_list, kwargs)
            else:
                return function_handler(node, arg_list)
        elif function.is_attribute and function.type.is_pyobject:
            attr_name = function.attribute
            self_arg = function.obj
            obj_type = self_arg.type
            is_unbound_method = False
            if obj_type.is_builtin_type:
                if obj_type is Builtin.type_type and self_arg.is_name and \
                         arg_list and arg_list[0].type.is_pyobject:
                    # calling an unbound method like 'list.append(L,x)'
                    # (ignoring 'type.mro()' here ...).  Only a plain name
                    # tells us which type is meant; other expressions typed
                    # as 'type' need not have a '.name' attribute.
                    type_name = self_arg.name
                    self_arg = None
                    is_unbound_method = True
                else:
                    type_name = obj_type.name
            else:
                type_name = "object" # safety measure
            method_handler = self._find_handler(
                "method_%s_%s" % (type_name, attr_name), kwargs)
            if method_handler is None:
                if attr_name in TypeSlots.method_name_to_slot \
                       or attr_name == '__new__':
                    method_handler = self._find_handler(
                        "slot%s" % attr_name, kwargs)
                if method_handler is None:
                    return node
            if self_arg is not None:
                arg_list = [self_arg] + list(arg_list)
            if kwargs:
                return method_handler(node, arg_list, kwargs, is_unbound_method)
            else:
                return method_handler(node, arg_list, is_unbound_method)
        else:
            return node
1170

1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186
    def _error_wrong_arg_count(self, function_name, node, args, expected=None):
        """Report a call of 'function_name' with an unexpected argument count.

        'expected' may be None/0 (unspecified), an int, or a descriptive
        string such as '2 or 3'; it selects the placeholder shown between
        the parentheses in the message.
        """
        if expected and (isinstance(expected, basestring) or expected > 1):
            arg_str = '...'
        elif expected == 1:
            arg_str = 'x'
        else:
            arg_str = ''
        expected_str = ''
        if expected is not None:
            expected_str = 'expected %s, ' % expected
        message = "%s(%s) called with wrong number of args, %sfound %d" % (
            function_name, arg_str, expected_str, len(args))
        error(node.pos, message)

1187 1188
    ### builtin types

1189 1190 1191 1192 1193 1194
    PyDict_Copy_func_type = PyrexTypes.CFuncType(
        Builtin.dict_type, [
            PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
            ])

    def _handle_simple_function_dict(self, node, pos_args):
        """Replace dict(some_dict) by PyDict_Copy(some_dict).

        Only applies when the single argument is typed as dict.  A None
        check is inserted first so that dict(None) still raises TypeError,
        with the same message as the tuple() optimisation uses.
        """
        if len(pos_args) != 1:
            return node
        arg = pos_args[0]
        if arg.type is not Builtin.dict_type:
            return node
        arg = ExprNodes.NoneCheckNode(
            arg, "PyExc_TypeError", "'NoneType' object is not iterable")
        return ExprNodes.PythonCapiCallNode(
            node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
            args = [arg],
            is_temp = node.is_temp
            )
1209

1210 1211 1212 1213 1214
    PyList_AsTuple_func_type = PyrexTypes.CFuncType(
        Builtin.tuple_type, [
            PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
            ])

    def _handle_simple_function_tuple(self, node, pos_args):
        """Replace tuple(some_list) by a call to PyList_AsTuple().

        Only applies to a single argument typed as list.  Unless it is a
        literal list or a comprehension, the argument is wrapped in a None
        check, which is stored back into pos_args[0].
        """
        if len(pos_args) != 1:
            return node
        list_arg = pos_args[0]
        if list_arg.type is not Builtin.list_type:
            return node
        needs_none_check = not isinstance(
            list_arg, (ExprNodes.ComprehensionNode, ExprNodes.ListNode))
        if needs_none_check:
            pos_args[0] = ExprNodes.NoneCheckNode(
                list_arg, "PyExc_TypeError",
                "'NoneType' object is not iterable")
        return ExprNodes.PythonCapiCallNode(
            node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
            args = pos_args,
            is_temp = node.is_temp
            )

1235 1236 1237 1238 1239 1240 1241 1242
    PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_double_type, [
            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
            ],
        exception_value = "((double)-1)",
        exception_check = True)

    def _handle_simple_function_float(self, node, pos_args):
        """Turn float(x) into the C value itself, a C cast, or a call to
        __Pyx_PyObject_AsDouble().
        """
        # Note: this requires the float() function to be typed as
        # returning a C 'double'
        if len(pos_args) != 1:
            self._error_wrong_arg_count('float', node, pos_args, 1)
            return node
        value = pos_args[0]
        if isinstance(value, ExprNodes.CoerceToPyTypeNode):
            # look through the C->Python coercion at the original C value
            value = value.arg
        value_type = value.type
        if value_type is PyrexTypes.c_double_type:
            return value
        if node.type.assignable_from(value_type) or value_type.is_numeric:
            return ExprNodes.TypecastNode(
                node.pos, operand=value, type=node.type)
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyObject_AsDouble",
            self.PyObject_AsDouble_func_type,
            args = pos_args,
            is_temp = node.is_temp,
            utility_code = pyobject_as_double_utility_code,
            py_name = "float")

1267 1268
    ### builtin functions

Stefan Behnel's avatar
Stefan Behnel committed
1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282
    PyObject_GetAttr2_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
            ])

    PyObject_GetAttr3_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
            ])

    def _handle_simple_function_getattr(self, node, pos_args):
        """Map getattr(o, name) to PyObject_GetAttr() and
        getattr(o, name, default) to __Pyx_GetAttr3(); report other
        argument counts as errors.
        """
        arg_count = len(pos_args)
        if arg_count == 2:
            return ExprNodes.PythonCapiCallNode(
                node.pos, "PyObject_GetAttr", self.PyObject_GetAttr2_func_type,
                args = pos_args,
                is_temp = node.is_temp)
        if arg_count == 3:
            return ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_GetAttr3", self.PyObject_GetAttr3_func_type,
                args = pos_args,
                is_temp = node.is_temp,
                utility_code = Builtin.getattr3_utility_code)
        self._error_wrong_arg_count('getattr', node, pos_args, '2 or 3')
        return node

1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311
    PyObject_GetIter_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
            ])

    PyCallIter_New_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("sentinel", PyrexTypes.py_object_type, None),
            ])

    def _handle_simple_function_iter(self, node, pos_args):
        """Map iter(o) to PyObject_GetIter() and iter(callable, sentinel)
        to PyCallIter_New(); report other argument counts as errors.
        """
        arg_count = len(pos_args)
        if arg_count == 1:
            return ExprNodes.PythonCapiCallNode(
                node.pos, "PyObject_GetIter", self.PyObject_GetIter_func_type,
                args = pos_args,
                is_temp = node.is_temp)
        if arg_count == 2:
            return ExprNodes.PythonCapiCallNode(
                node.pos, "PyCallIter_New", self.PyCallIter_New_func_type,
                args = pos_args,
                is_temp = node.is_temp)
        self._error_wrong_arg_count('iter', node, pos_args, '1 or 2')
        return node

1328 1329 1330 1331 1332
    Pyx_strlen_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_size_t_type, [
            PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)
            ])

    PyObject_Size_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)
            ])

    # maps a builtin type to the C-API function/macro returning its length;
    # a bound dict.get(), so unknown types give None
    _map_to_capi_len_function = {
        Builtin.unicode_type   : "PyUnicode_GET_SIZE",
        Builtin.str_type       : "Py_SIZE", # works in Py2 and Py3
        Builtin.bytes_type     : "__Pyx_PyBytes_GET_SIZE",
        Builtin.list_type      : "PyList_GET_SIZE",
        Builtin.tuple_type     : "PyTuple_GET_SIZE",
        Builtin.dict_type      : "PyDict_Size",
        Builtin.set_type       : "PySet_Size",
        Builtin.frozenset_type : "PySet_Size",
        }.get

    def _handle_simple_function_len(self, node, pos_args):
        """Map len(char*) to strlen() and len() of a known builtin type
        to its C-API size function, with a None check for non-literals.
        """
        if len(pos_args) != 1:
            self._error_wrong_arg_count('len', node, pos_args, 1)
            return node
        arg = pos_args[0]
        if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
            # measure the original C value rather than its Python wrapper
            arg = arg.arg
        arg_type = arg.type
        if arg_type.is_string:
            len_node = ExprNodes.PythonCapiCallNode(
                node.pos, "strlen", self.Pyx_strlen_func_type,
                args = [arg],
                is_temp = node.is_temp,
                utility_code = include_string_h_utility_code)
        elif not arg_type.is_pyobject:
            return node
        else:
            cfunc_name = self._map_to_capi_len_function(arg_type)
            if cfunc_name is None:
                return node
            if not arg.is_literal:
                arg = ExprNodes.NoneCheckNode(
                    arg, "PyExc_TypeError",
                    "object of type 'NoneType' has no len()")
            len_node = ExprNodes.PythonCapiCallNode(
                node.pos, cfunc_name, self.PyObject_Size_func_type,
                args = [arg],
                is_temp = node.is_temp)
        if node.type not in (PyrexTypes.c_size_t_type, PyrexTypes.c_py_ssize_t_type):
            len_node = len_node.coerce_to(node.type, self.current_env())
        return len_node
1382

1383 1384 1385 1386 1387 1388
    Pyx_Type_func_type = PyrexTypes.CFuncType(
        Builtin.type_type, [
            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
            ])

    def _handle_simple_function_type(self, node, pos_args):
        """Turn type(o) into the macro call Py_TYPE(o), cast to a generic
        Python object type.
        """
        if len(pos_args) != 1:
            return node
        type_lookup = ExprNodes.PythonCapiCallNode(
            node.pos, "Py_TYPE", self.Pyx_Type_func_type,
            args = pos_args,
            is_temp = False)
        return ExprNodes.CastNode(type_lookup, PyrexTypes.py_object_type)

1399 1400
    ### special methods

1401 1402 1403 1404 1405
    Pyx_tp_new_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("type", Builtin.type_type, None)
            ])

    def _handle_simple_slot__new__(self, node, args, is_unbound_method):
        """Turn 'exttype.__new__(exttype)' into a call of exttype->tp_new().

        Only unbound calls with a single argument are handled, and both the
        called-on object and the argument must be names typed as 'type' that
        denote the same type (same type entry or, lacking entries, the same
        name).
        """
        obj = node.function.obj
        if not is_unbound_method or len(args) != 1:
            return node
        type_arg = args[0]
        if not (obj.is_name and type_arg.is_name):
            # play safe
            return node
        if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
            # not a known type, play safe
            return node
        if type_arg.type_entry and obj.type_entry:
            if type_arg.type_entry != obj.type_entry:
                # different types - may or may not lead to an error at runtime
                return node
        elif obj.name != type_arg.name:
            return node
        # else: both name the same type without a known entry - good enough

        # FIXME: we could potentially look up the actual tp_new C
        # method of the extension type and call that instead of the
        # generic slot. That would also allow us to pass parameters
        # efficiently.

        if not type_arg.type_entry:
            # arbitrary variable, needs a None check for safety
            type_arg = ExprNodes.NoneCheckNode(
                type_arg, "PyExc_TypeError",
                "object.__new__(X): X is not a type object (NoneType)")

        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
            args = [type_arg],
            utility_code = tpnew_utility_code,
            is_temp = node.is_temp
            )

1446 1447 1448 1449 1450 1451 1452 1453
    ### methods of builtin types

    PyObject_Append_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
         PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None)])

    def _handle_simple_method_object_append(self, node, args, is_unbound_method):
        """Optimistic optimisation as X.append() is almost always
        referring to a list.

        The call becomes ``__Pyx_PyObject_Append(X, item)``.  Calls that
        do not pass exactly one item are left untouched.
        """
        if len(args) == 2:
            return ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
                args=args,
                is_temp=node.is_temp,
                utility_code=append_utility_code)
        return node

Robert Bradshaw's avatar
Robert Bradshaw committed
1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479
    PyObject_Pop_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None)])

    PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
         PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_long_type, None)])

    def _handle_simple_method_object_pop(self, node, args, is_unbound_method):
        """Optimistic optimisation as X.pop([n]) is almost always
        referring to a list.

        ``X.pop()`` becomes ``__Pyx_PyObject_Pop(X)``.  ``X.pop(n)``
        becomes ``__Pyx_PyObject_PopIndex(X, n)``, but only when ``n`` is
        a C integer that was coerced to a Python object and whose type
        ``widest_numeric_type()`` does not rank above Py_ssize_t.  In that
        case the uncoerced C value replaces ``args[1]`` in place.  Every
        other call is returned unchanged.
        """
        if len(args) == 1:
            return ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_PyObject_Pop", self.PyObject_Pop_func_type,
                args=args,
                is_temp=node.is_temp,
                utility_code=pop_utility_code)
        if len(args) != 2:
            return node

        index_arg = args[1]
        if not (isinstance(index_arg, ExprNodes.CoerceToPyTypeNode)
                and index_arg.arg.type.is_int):
            return node
        c_index = index_arg.arg
        ssize_t_type = PyrexTypes.c_py_ssize_t_type
        if PyrexTypes.widest_numeric_type(c_index.type, ssize_t_type) == ssize_t_type:
            # NOTE(review): the range check is against Py_ssize_t, but the C
            # signature above declares the index as 'long', which may be
            # narrower on some platforms - confirm this cannot truncate.
            args[1] = c_index
            return ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_PyObject_PopIndex", self.PyObject_PopIndex_func_type,
                args=args,
                is_temp=node.is_temp,
                utility_code=pop_index_utility_code)
        return node

1504 1505 1506 1507 1508 1509 1510
    PyList_Append_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_int_type,
        [PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
         PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None)],
        exception_value = "-1")

    def _handle_simple_method_list_append(self, node, args, is_unbound_method):
        """Call PyList_Append() instead of l.append().

        A call that does not pass exactly one item is reported as an
        argument-count error and left unchanged.
        """
        if len(args) == 2:
            return self._substitute_method_call(
                node, "PyList_Append", self.PyList_Append_func_type,
                'append', is_unbound_method, args)
        self._error_wrong_arg_count('list.append', node, args, 2)
        return node
1520

1521 1522 1523 1524 1525
    # C signature shared by list methods that take only the list itself and
    # return an int status (-1 signals an exception).
    single_param_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_int_type,
        [PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)],
        exception_value = "-1")

    def _handle_simple_method_list_sort(self, node, args, is_unbound_method):
        """Call PyList_Sort() instead of the 0-argument l.sort().

        A call with any extra positional argument is left for Python to
        handle; no error is reported here.
        """
        if len(args) == 1:
            return self._substitute_method_call(
                node, "PyList_Sort", self.single_param_func_type,
                'sort', is_unbound_method, args)
        return node
1535

1536
    def _handle_simple_method_list_reverse(self, node, args, is_unbound_method):
        """Call PyList_Reverse() instead of l.reverse().

        Any extra argument is reported as an argument-count error and the
        call is left unchanged.
        """
        if len(args) == 1:
            return self._substitute_method_call(
                node, "PyList_Reverse", self.single_param_func_type,
                'reverse', is_unbound_method, args)
        self._error_wrong_arg_count('list.reverse', node, args, 1)
        return node
1545

1546 1547 1548 1549 1550
    Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
         PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
         PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None)])

    def _handle_simple_method_dict_get(self, node, args, is_unbound_method):
        """Replace dict.get() by a call to PyDict_GetItem().

        The actual target is the ``__Pyx_PyDict_GetItemDefault`` helper.
        A missing default is filled in with ``None`` by appending to *args*
        in place.  Wrong arities are reported and left unchanged.
        """
        arg_count = len(args)
        if arg_count not in (2, 3):
            self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
            return node
        if arg_count == 2:
            args.append(ExprNodes.NoneNode(node.pos))

        return self._substitute_method_call(
            node, "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
            'get', is_unbound_method, args,
            utility_code = dict_getitem_default_utility_code)

1567 1568 1569

    ### unicode type methods

1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582
    PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType(
        Builtin.list_type,
        [PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
         PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None)])

    def _handle_simple_method_unicode_splitlines(self, node, args, is_unbound_method):
        """Replace unicode.splitlines(...) by a direct call to the
        corresponding C-API function.

        An omitted ``keepends`` flag defaults to False.  A given flag is
        coerced to a C boolean.
        """
        if not 1 <= len(args) <= 2:
            self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2")
            return node

        self._inject_bint_default_argument(node, args, 1, False)
        return self._substitute_method_call(
            node, "PyUnicode_Splitlines", self.PyUnicode_Splitlines_func_type,
            'splitlines', is_unbound_method, args)

Stefan Behnel's avatar
Stefan Behnel committed
1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606
    PyUnicode_Join_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type,
        [PyrexTypes.CFuncTypeArg("sep", Builtin.unicode_type, None),
         PyrexTypes.CFuncTypeArg("iterable", PyrexTypes.py_object_type, None)])

    def _handle_simple_method_unicode_join(self, node, args, is_unbound_method):
        """Replace unicode.join(...) by a direct call to the
        corresponding C-API function.

        Exactly one iterable argument is required; anything else is
        reported and left unchanged.
        """
        if len(args) == 2:
            return self._substitute_method_call(
                node, "PyUnicode_Join", self.PyUnicode_Join_func_type,
                'join', is_unbound_method, args)
        self._error_wrong_arg_count('unicode.join', node, args, 2)
        return node

1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623
    PyUnicode_Split_func_type = PyrexTypes.CFuncType(
        Builtin.list_type,
        [PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
         PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None),
         PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None)])

    def _handle_simple_method_unicode_split(self, node, args, is_unbound_method):
        """Replace unicode.split(...) by a direct call to the
        corresponding C-API function.

        A missing separator is passed as a C NULL.  A missing ``maxsplit``
        defaults to -1.  A given ``maxsplit`` is coerced to Py_ssize_t.
        """
        if not 1 <= len(args) <= 3:
            self._error_wrong_arg_count('unicode.split', node, args, "1-3")
            return node
        if len(args) == 1:
            args.append(ExprNodes.NullNode(node.pos))

        self._inject_int_default_argument(
            node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1")
        return self._substitute_method_call(
            node, "PyUnicode_Split", self.PyUnicode_Split_func_type,
            'split', is_unbound_method, args)

1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656
    PyUnicode_Tailmatch_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type,
        [PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
         PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
         PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
         PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
         PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None)],
        exception_value = '-1')

    def _handle_simple_method_unicode_endswith(self, node, args, is_unbound_method):
        # direction +1 selects the suffix check
        return self._inject_unicode_tailmatch(
            node, args, is_unbound_method, 'endswith', +1)

    def _handle_simple_method_unicode_startswith(self, node, args, is_unbound_method):
        # direction -1 selects the prefix check
        return self._inject_unicode_tailmatch(
            node, args, is_unbound_method, 'startswith', -1)

    def _inject_unicode_tailmatch(self, node, args, is_unbound_method,
                                  method_name, direction):
        """Replace unicode.startswith(...) and unicode.endswith(...)
        by a direct call to the corresponding C-API function.

        Missing ``start``/``end`` bounds default to 0 and PY_SSIZE_T_MAX.
        The direction is appended as a C int.  The C boolean result is
        wrapped back into a Python ``bool``.
        """
        if not 2 <= len(args) <= 4:
            self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
            return node

        ssize_t_type = PyrexTypes.c_py_ssize_t_type
        for position, default in ((2, "0"), (3, "PY_SSIZE_T_MAX")):
            self._inject_int_default_argument(node, args, position, ssize_t_type, default)
        args.append(ExprNodes.IntNode(
            node.pos, value=str(direction), type=PyrexTypes.c_int_type))

        tailmatch_call = self._substitute_method_call(
            node, "__Pyx_PyUnicode_Tailmatch", self.PyUnicode_Tailmatch_func_type,
            method_name, is_unbound_method, args,
            utility_code = unicode_tailmatch_utility_code)
        return ExprNodes.CoerceToPyTypeNode(
            tailmatch_call, self.current_env(), Builtin.bool_type)
1670

1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696
    # PyUnicode_Find() returns -1 for "not found", so -2 is its error value.
    PyUnicode_Find_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type,
        [PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
         PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
         PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
         PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
         PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None)],
        exception_value = '-2')

    def _handle_simple_method_unicode_find(self, node, args, is_unbound_method):
        # direction +1: search forward
        return self._inject_unicode_find(
            node, args, is_unbound_method, 'find', +1)

    def _handle_simple_method_unicode_rfind(self, node, args, is_unbound_method):
        # direction -1: search backward
        return self._inject_unicode_find(
            node, args, is_unbound_method, 'rfind', -1)

    def _inject_unicode_find(self, node, args, is_unbound_method,
                             method_name, direction):
        """Replace unicode.find(...) and unicode.rfind(...) by a
        direct call to the corresponding C-API function.

        Missing ``start``/``end`` bounds default to 0 and PY_SSIZE_T_MAX.
        The Py_ssize_t result is coerced back to a Python object.
        """
        if not 2 <= len(args) <= 4:
            self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
            return node

        ssize_t_type = PyrexTypes.c_py_ssize_t_type
        for position, default in ((2, "0"), (3, "PY_SSIZE_T_MAX")):
            self._inject_int_default_argument(node, args, position, ssize_t_type, default)
        args.append(ExprNodes.IntNode(
            node.pos, value=str(direction), type=PyrexTypes.c_int_type))

        find_call = self._substitute_method_call(
            node, "PyUnicode_Find", self.PyUnicode_Find_func_type,
            method_name, is_unbound_method, args)
        return ExprNodes.CoerceToPyTypeNode(
            find_call, self.current_env(), PyrexTypes.py_object_type)
1709

Stefan Behnel's avatar
Stefan Behnel committed
1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725
    PyUnicode_Count_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type,
        [PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
         PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
         PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
         PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None)],
        exception_value = '-1')

    def _handle_simple_method_unicode_count(self, node, args, is_unbound_method):
        """Replace unicode.count(...) by a direct call to the
        corresponding C-API function.

        Missing ``start``/``end`` bounds default to 0 and PY_SSIZE_T_MAX.
        The Py_ssize_t result is coerced back to a Python object.
        """
        if not 2 <= len(args) <= 4:
            self._error_wrong_arg_count('unicode.count', node, args, "2-4")
            return node

        ssize_t_type = PyrexTypes.c_py_ssize_t_type
        for position, default in ((2, "0"), (3, "PY_SSIZE_T_MAX")):
            self._inject_int_default_argument(node, args, position, ssize_t_type, default)

        count_call = self._substitute_method_call(
            node, "PyUnicode_Count", self.PyUnicode_Count_func_type,
            'count', is_unbound_method, args)
        return ExprNodes.CoerceToPyTypeNode(
            count_call, self.current_env(), PyrexTypes.py_object_type)
Stefan Behnel's avatar
Stefan Behnel committed
1736

Stefan Behnel's avatar
Stefan Behnel committed
1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751
    PyUnicode_Replace_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type,
        [PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
         PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
         PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None),
         PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None)])

    def _handle_simple_method_unicode_replace(self, node, args, is_unbound_method):
        """Replace unicode.replace(...) by a direct call to the
        corresponding C-API function.

        A missing ``maxcount`` defaults to -1.  A given one is coerced to
        Py_ssize_t.
        """
        if not 3 <= len(args) <= 4:
            self._error_wrong_arg_count('unicode.replace', node, args, "3-4")
            return node

        self._inject_int_default_argument(
            node, args, 3, PyrexTypes.c_py_ssize_t_type, "-1")
        return self._substitute_method_call(
            node, "PyUnicode_Replace", self.PyUnicode_Replace_func_type,
            'replace', is_unbound_method, args)

1759 1760
    PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            ])

    PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            ])

    # Encodings for which a dedicated PyUnicode_As<Name>String() /
    # PyUnicode_Decode<Name>() C-API function is looked up; the <Name>
    # part is derived in _find_special_codec_name().
    _special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
                          'unicode_escape', 'raw_unicode_escape']

    _special_codecs = [ (name, codecs.getencoder(name))
                        for name in _special_encodings ]

    def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
        """Replace unicode.encode(...) by a direct C-API call to the
        corresponding codec.

        * ``s.encode()`` becomes ``PyUnicode_AsEncodedString(s, NULL, NULL)``.
        * A unicode literal with a literal encoding is encoded at compile
          time when possible, producing a bytes literal.
        * With 'strict' error handling and a known special codec, the
          dedicated ``PyUnicode_As<Name>String()`` function is used.
        * Otherwise ``PyUnicode_AsEncodedString()`` is called with the
          encoding and error arguments.
        Non-string encoding/error arguments leave the call unchanged.
        """
        if len(args) < 1 or len(args) > 3:
            self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
            return node

        string_node = args[0]

        if len(args) == 1:
            null_node = ExprNodes.NullNode(node.pos)
            return self._substitute_method_call(
                node, "PyUnicode_AsEncodedString",
                self.PyUnicode_AsEncodedString_func_type,
                'encode', is_unbound_method, [string_node, null_node, null_node])

        parameters = self._unpack_encoding_and_error_mode(node.pos, args)
        if parameters is None:
            return node
        encoding, encoding_node, error_handling, error_handling_node = parameters

        if isinstance(string_node, ExprNodes.UnicodeNode):
            # constant, so try to do the encoding at compile time
            try:
                value = string_node.value.encode(encoding, error_handling)
            except Exception:
                # well, looks like we can't (unknown codec, unencodable
                # characters, or encoding/error mode only known at
                # runtime and therefore None here).  Do not use a bare
                # 'except:' so that KeyboardInterrupt/SystemExit still
                # propagate out of the compiler.
                pass
            else:
                value = BytesLiteral(value)
                value.encoding = encoding
                return ExprNodes.BytesNode(
                    string_node.pos, value=value, type=Builtin.bytes_type)

        if error_handling == 'strict':
            # try to find a specific encoder function
            codec_name = self._find_special_codec_name(encoding)
            if codec_name is not None:
                encode_function = "PyUnicode_As%sString" % codec_name
                return self._substitute_method_call(
                    node, encode_function,
                    self.PyUnicode_AsXyzString_func_type,
                    'encode', is_unbound_method, [string_node])

        return self._substitute_method_call(
            node, "PyUnicode_AsEncodedString",
            self.PyUnicode_AsEncodedString_func_type,
            'encode', is_unbound_method,
            [string_node, encoding_node, error_handling_node])

    PyUnicode_DecodeXyz_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            ])

    PyUnicode_Decode_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            ])

    # Special codecs whose PyUnicode_Decode<Name>() function does NOT have
    # the (string, size, errors) signature of PyUnicode_DecodeXyz_func_type.
    # PyUnicode_DecodeUTF16() takes an extra 'int *byteorder' argument, so
    # calling it with three arguments would produce uncompilable C code;
    # these codecs go through the generic PyUnicode_Decode() instead.
    _decode_signature_mismatch = ('UTF16',)

    def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method):
        """Replace char*.decode() by a direct C-API call to the
        corresponding codec, possibly resolving a slice on the char*.

        Handles ``(<char*>s)[start:stop].decode(...)`` and
        ``<char*>s.decode(...)``.  The string start is adjusted by pointer
        arithmetic.  The length is ``stop - start``, or ``strlen()`` when
        no stop is given.  Sub-expressions that are needed twice are
        evaluated once through temporaries.  Anything else is left to
        Python.
        """
        if len(args) < 1 or len(args) > 3:
            self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
            return node
        temps = []
        if isinstance(args[0], ExprNodes.SliceIndexNode):
            index_node = args[0]
            string_node = index_node.base
            if not string_node.type.is_string:
                # nothing to optimise here
                return node
            start, stop = index_node.start, index_node.stop
            if not start or start.constant_result == 0:
                start = None
            else:
                if start.type.is_pyobject:
                    start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
                if stop:
                    # 'start' is used twice (pointer offset and length
                    # computation below), so evaluate it only once
                    start = UtilNodes.LetRefNode(start)
                    temps.append(start)
                string_node = ExprNodes.AddNode(pos=start.pos,
                                                operand1=string_node,
                                                operator='+',
                                                operand2=start,
                                                is_temp=False,
                                                type=string_node.type
                                                )
            if stop and stop.type.is_pyobject:
                stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        elif isinstance(args[0], ExprNodes.CoerceToPyTypeNode) \
                 and args[0].arg.type.is_string:
            # use strlen() to find the string length, just as CPython would
            start = stop = None
            string_node = args[0].arg
        else:
            # let Python do its job
            return node

        if not stop:
            if start or not string_node.is_name:
                # the string is used twice (data pointer and strlen()),
                # so evaluate a non-trivial expression only once
                string_node = UtilNodes.LetRefNode(string_node)
                temps.append(string_node)
            stop = ExprNodes.PythonCapiCallNode(
                string_node.pos, "strlen", self.Pyx_strlen_func_type,
                    args = [string_node],
                    is_temp = False,
                    utility_code = include_string_h_utility_code,
                    ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        elif start:
            stop = ExprNodes.SubNode(
                pos = stop.pos,
                operand1 = stop,
                operator = '-',
                operand2 = start,
                is_temp = False,
                type = PyrexTypes.c_py_ssize_t_type
                )

        parameters = self._unpack_encoding_and_error_mode(node.pos, args)
        if parameters is None:
            return node
        encoding, encoding_node, error_handling, error_handling_node = parameters

        # try to find a specific decoder function
        codec_name = None
        if encoding is not None:
            codec_name = self._find_special_codec_name(encoding)
        if codec_name is not None and codec_name not in self._decode_signature_mismatch:
            decode_function = "PyUnicode_Decode%s" % codec_name
            node = ExprNodes.PythonCapiCallNode(
                node.pos, decode_function,
                self.PyUnicode_DecodeXyz_func_type,
                args = [string_node, stop, error_handling_node],
                is_temp = node.is_temp,
                )
        else:
            # generic decoder; encoding_node is NULL (default encoding),
            # a char* expression, or the literal encoding name
            node = ExprNodes.PythonCapiCallNode(
                node.pos, "PyUnicode_Decode",
                self.PyUnicode_Decode_func_type,
                args = [string_node, stop, encoding_node, error_handling_node],
                is_temp = node.is_temp,
                )

        # wrap innermost-first so that temporaries are set up in creation order
        for temp in temps[::-1]:
            node = UtilNodes.EvalWithTempExprNode(temp, node)
        return node
1932 1933 1934 1935 1936 1937 1938 1939 1940 1941 1942 1943 1944 1945 1946

    def _find_special_codec_name(self, encoding):
        try:
            requested_codec = codecs.getencoder(encoding)
        except:
            return None
        for name, codec in self._special_codecs:
            if codec == requested_codec:
                if '_' in name:
                    name = ''.join([ s.capitalize()
                                     for s in name.split('_')])
                return name
        return None

    def _unpack_encoding_and_error_mode(self, pos, args):
        """Extract the encoding and error mode from an encode()/decode() call.

        ``args[1]`` (encoding) and ``args[2]`` (errors) are optional.  Each
        one may be a string literal or a C string expression; a
        coercion-to-Python wrapper is stripped first.

        Returns ``(encoding, encoding_node, error_handling,
        error_handling_node)``:

        * ``encoding`` is the literal value, or None if not given or not a
          literal.
        * ``error_handling`` is the literal value, 'strict' if not given,
          or None if not a literal.
        * The nodes are C string nodes.  A shared NULL node is used for a
          missing encoding and for 'strict' error handling.

        Returns None if either argument is neither a literal nor a C string.
        """
        null_node = ExprNodes.NullNode(pos)
        literal_node_types = (ExprNodes.UnicodeNode, ExprNodes.StringNode,
                              ExprNodes.BytesNode)

        encoding = None
        encoding_node = null_node
        if len(args) >= 2:
            candidate = args[1]
            if isinstance(candidate, ExprNodes.CoerceToPyTypeNode):
                candidate = candidate.arg
            if isinstance(candidate, literal_node_types):
                encoding = candidate.value
                encoding_node = ExprNodes.BytesNode(candidate.pos, value=encoding,
                                                    type=PyrexTypes.c_char_ptr_type)
            elif candidate.type.is_string:
                encoding_node = candidate
            else:
                return None

        error_handling = 'strict'
        error_handling_node = null_node
        if len(args) == 3:
            candidate = args[2]
            if isinstance(candidate, ExprNodes.CoerceToPyTypeNode):
                candidate = candidate.arg
            if isinstance(candidate, literal_node_types):
                error_handling = candidate.value
                if error_handling != 'strict':
                    error_handling_node = ExprNodes.BytesNode(
                        candidate.pos, value=error_handling,
                        type=PyrexTypes.c_char_ptr_type)
            elif candidate.type.is_string:
                error_handling = None
                error_handling_node = candidate
            else:
                return None

        return (encoding, encoding_node, error_handling, error_handling_node)
1989

1990 1991 1992

    ### helpers

1993
    def _substitute_method_call(self, node, name, func_type,
                                attr_name, is_unbound_method, args=(),
                                utility_code=None):
        """Build a C-API call node ``name(*args)`` that replaces a method call.

        Unless the first argument (the 'self' object) is a literal, it is
        wrapped in a None check.  The check raises the same exception
        Python would raise: TypeError for an unbound method descriptor,
        AttributeError for a bound method lookup on None.
        The argument sequence is copied, never modified.
        """
        call_args = list(args)
        if call_args and not call_args[0].is_literal:
            if is_unbound_method:
                exc_type = "PyExc_TypeError"
                message = "descriptor '%s' requires a '%s' object but received a 'NoneType'" % (
                    attr_name, node.function.obj.name)
            else:
                exc_type = "PyExc_AttributeError"
                message = "'NoneType' object has no attribute '%s'" % attr_name
            call_args[0] = ExprNodes.NoneCheckNode(call_args[0], exc_type, message)
        return ExprNodes.PythonCapiCallNode(
            node.pos, name, func_type,
            args=call_args,
            is_temp=node.is_temp,
            utility_code=utility_code)

2016 2017 2018 2019 2020
    def _inject_int_default_argument(self, node, args, arg_index, type, default_value):
        """Ensure ``args[arg_index]`` is a C integer of the given type.

        If the argument is missing (``len(args) == arg_index``), an integer
        literal holding *default_value* is appended.  Otherwise the
        existing argument is coerced to *type* in place.
        """
        assert len(args) >= arg_index
        if len(args) > arg_index:
            args[arg_index] = args[arg_index].coerce_to(type, self.current_env())
        else:
            args.append(ExprNodes.IntNode(node.pos, value=str(default_value), type=type))
2022 2023 2024 2025 2026 2027

    def _inject_bint_default_argument(self, node, args, arg_index, default_value):
        """Ensure ``args[arg_index]`` is a C boolean.

        A missing argument (``len(args) == arg_index``) is appended as a
        boolean literal of *default_value*.  An existing one is coerced to
        a boolean in place.
        """
        assert len(args) >= arg_index
        if len(args) > arg_index:
            args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env())
        else:
            args.append(ExprNodes.BoolNode(node.pos, value=bool(default_value)))
2029

2030

2031 2032 2033 2034 2035 2036 2037 2038 2039 2040 2041 2042 2043 2044 2045 2046 2047 2048 2049 2050 2051 2052 2053 2054 2055 2056 2057 2058
# C helper behind the unicode.startswith()/endswith() optimisation.
# __Pyx_PyUnicode_Tailmatch() forwards a plain substring straight to
# PyUnicode_Tailmatch().  For a tuple it returns the first non-zero result
# of PyUnicode_Tailmatch() over the items, so the loop ends on either a
# match or an error result (-1, cf. exception_value='-1' on
# PyUnicode_Tailmatch_func_type), and returns 0 if no item matched.
unicode_tailmatch_utility_code = UtilityCode(
    # Python's unicode.startswith() and unicode.endswith() support a
    # tuple of prefixes/suffixes, whereas it's much more common to
    # test for a single unicode string.
proto = '''
static int __Pyx_PyUnicode_Tailmatch(PyObject* s, PyObject* substr, \
Py_ssize_t start, Py_ssize_t end, int direction);
''',
impl = '''
static int __Pyx_PyUnicode_Tailmatch(PyObject* s, PyObject* substr,
                                     Py_ssize_t start, Py_ssize_t end, int direction) {
    if (unlikely(PyTuple_Check(substr))) {
        int result;
        Py_ssize_t i;
        for (i = 0; i < PyTuple_GET_SIZE(substr); i++) {
            result = PyUnicode_Tailmatch(s, PyTuple_GET_ITEM(substr, i),
                                         start, end, direction);
            if (result) {
                return result;
            }
        }
        return 0;
    }
    return PyUnicode_Tailmatch(s, substr, start, end, direction);
}
''',
)

2059 2060
# C helper for the dict.get() optimisation (__Pyx_PyDict_GetItemDefault).
# It always returns a new reference, or NULL with an exception set.
# - Py3: uses PyDict_GetItemWithError() and falls back to 'default_value'
#   when the key is absent, so lookup errors propagate.
# - Py2: exact str/unicode/int keys use PyDict_GetItem() directly.  Other
#   keys call the dict's own .get() method, presumably so that exceptions
#   raised while hashing/comparing the key are not lost.  A 'None' default
#   is not passed on, which lets .get() use its own None default.
# NOTE: stray blame-page line numbers that had ended up inside the C source
# text (and would have broken compilation of the generated C) were removed.
dict_getitem_default_utility_code = UtilityCode(
proto = '''
static PyObject* __Pyx_PyDict_GetItemDefault(PyObject* d, PyObject* key, PyObject* default_value) {
    PyObject* value;
#if PY_MAJOR_VERSION >= 3
    value = PyDict_GetItemWithError(d, key);
    if (unlikely(!value)) {
        if (unlikely(PyErr_Occurred()))
            return NULL;
        value = default_value;
    }
    Py_INCREF(value);
#else
    if (PyString_CheckExact(key) || PyUnicode_CheckExact(key) || PyInt_CheckExact(key)) {
        /* these presumably have safe hash functions */
        value = PyDict_GetItem(d, key);
        if (unlikely(!value)) {
            value = default_value;
        }
        Py_INCREF(value);
    } else {
        PyObject *m;
        m = __Pyx_GetAttrString(d, "get");
        if (!m) return NULL;
        value = PyObject_CallFunctionObjArgs(m, key,
            (default_value == Py_None) ? NULL : default_value, NULL);
        Py_DECREF(m);
    }
#endif
    return value;
}
''',
impl = ""
)

# C helper for "L.append(x)".  Exact lists take the PyList_Append() fast path
# and return a new reference to None; any other object gets a real call of
# its "append" method.  Returns NULL on error.
append_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Append(PyObject* L, PyObject* x) {
    if (likely(PyList_CheckExact(L))) {
        if (PyList_Append(L, x) < 0) return NULL;
        Py_INCREF(Py_None);
        return Py_None; /* this is just to have an accurate signature */
    }
    else {
        PyObject *r, *m;
        m = __Pyx_GetAttrString(L, "append");
        if (!m) return NULL;
        r = PyObject_CallFunctionObjArgs(m, x, NULL);
        Py_DECREF(m);
        return r;
    }
}
""",
impl = ""
)
# C helper for "L.pop()".  For an exact list whose size stays above half its
# allocation after removing one item (so CPython would not reallocate), the
# last item is taken directly.  The list's reference is handed to the caller.
# Everything else gets a real call of the object's "pop" method.
pop_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Pop(PyObject* L) {
    PyObject *r, *m;
#if PY_VERSION_HEX >= 0x02040000
    if (likely(PyList_CheckExact(L))
            /* Check that both the size is positive and no reallocation shrinking needs to be done. */
            && likely(PyList_GET_SIZE(L) > (((PyListObject*)L)->allocated >> 1))) {
        Py_SIZE(L) -= 1;
        return PyList_GET_ITEM(L, PyList_GET_SIZE(L));
    }
#endif
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) return NULL;
    r = PyObject_CallObject(m, NULL);
    Py_DECREF(m);
    return r;
}
""",
impl = ""
)

# C helper for "L.pop(ix)".
#
# Fast path: for an exact list that would not be reallocated by the removal,
# the item is taken directly and later items are shifted down by one.
#
# Fallback: any other case (including an out-of-range index) calls the
# object's real "pop" method, which raises IndexError as Python does.
#
# The negative-index adjustment is done on a local copy ("cix").  Modifying
# "ix" itself would hand an already-adjusted index to the fallback: e.g.
# ix=-4 on a 3-item list would become pop(-1) and silently remove the last
# item instead of raising IndexError.
pop_index_utility_code = UtilityCode(
proto = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix);
""",
impl = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix) {
    PyObject *r, *m, *t, *py_ix;
#if PY_VERSION_HEX >= 0x02040000
    if (likely(PyList_CheckExact(L))) {
        Py_ssize_t size = PyList_GET_SIZE(L);
        if (likely(size > (((PyListObject*)L)->allocated >> 1))) {
            /* adjust a copy: the fallback below needs the original ix */
            Py_ssize_t cix = ix;
            if (cix < 0) {
                cix += size;
            }
            if (likely(0 <= cix && cix < size)) {
                Py_ssize_t i;
                PyObject* v = PyList_GET_ITEM(L, cix);
                Py_SIZE(L) -= 1;
                size -= 1;
                for(i=cix; i<size; i++) {
                    PyList_SET_ITEM(L, i, PyList_GET_ITEM(L, i+1));
                }
                return v;
            }
        }
    }
#endif
    py_ix = t = NULL;
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) goto bad;
    py_ix = PyInt_FromSsize_t(ix);
    if (!py_ix) goto bad;
    t = PyTuple_New(1);
    if (!t) goto bad;
    PyTuple_SET_ITEM(t, 0, py_ix);
    py_ix = NULL;
    r = PyObject_CallObject(m, t);
    Py_DECREF(m);
    Py_DECREF(t);
    return r;
bad:
    Py_XDECREF(m);
    Py_XDECREF(t);
    Py_XDECREF(py_ix);
    return NULL;
}
"""
)


# C helper converting an arbitrary object to a C double, like float(obj).
# Exact floats are unpacked inline by the macro.  Other objects go through,
# in order:
#   - nb_float, if the type provides it;
#   - PyFloat_FromString() for exact unicode/bytes;
#   - a real call of the float type.
# Returns -1.0 with a Python exception set on failure.
pyobject_as_double_utility_code = UtilityCode(
proto = '''
static double __Pyx__PyObject_AsDouble(PyObject* obj); /* proto */

#define __Pyx_PyObject_AsDouble(obj) \\
    ((likely(PyFloat_CheckExact(obj))) ? \\
     PyFloat_AS_DOUBLE(obj) : __Pyx__PyObject_AsDouble(obj))
''',
impl='''
static double __Pyx__PyObject_AsDouble(PyObject* obj) {
    PyObject* float_value;
    if (Py_TYPE(obj)->tp_as_number && Py_TYPE(obj)->tp_as_number->nb_float) {
        return PyFloat_AsDouble(obj);
    } else if (PyUnicode_CheckExact(obj) || PyBytes_CheckExact(obj)) {
#if PY_MAJOR_VERSION >= 3
        float_value = PyFloat_FromString(obj);
#else
        float_value = PyFloat_FromString(obj, 0);
#endif
    } else {
        PyObject* args = PyTuple_New(1);
        if (unlikely(!args)) goto bad;
        PyTuple_SET_ITEM(args, 0, obj);
        float_value = PyObject_Call((PyObject*)&PyFloat_Type, args, 0);
        PyTuple_SET_ITEM(args, 0, 0);
        Py_DECREF(args);
    }
    if (likely(float_value)) {
        double value = PyFloat_AS_DOUBLE(float_value);
        Py_DECREF(float_value);
        return value;
    }
bad:
    return (double)-1;
}
'''
)

# Utility code whose only job is to add "#include <string.h>" to the
# generated module's prototypes (it has no implementation section).
include_string_h_utility_code = UtilityCode(proto="\n#include <string.h>\n")


# C helper that creates a new instance by calling the type's tp_new slot
# directly, passing the module's shared empty tuple (Naming.empty_tuple) as
# positional args and NULL kwargs.
tpnew_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_tp_new(PyObject* type_obj) {
    return (PyObject*) (((PyTypeObject*)(type_obj))->tp_new(
        (PyTypeObject*)(type_obj), %(TUPLE)s, NULL));
}
""" % {'TUPLE' : Naming.empty_tuple}
)


class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
    """Calculate the result of constant expressions to store it in
    ``expr_node.constant_result``, and replace trivial cases by their
    constant result.
    """
    def _calculate_const(self, node):
        """Set ``node.constant_result`` unless it has been set already.

        The children are visited first.  The node only gets a real constant
        result if every child has one and the node's own
        ``calculate_constant_result()`` succeeds.  Otherwise it is left as
        ``ExprNodes.not_a_constant``.
        """
        if node.constant_result is not ExprNodes.constant_value_not_set:
            return

        # make sure we always set the value
        not_a_constant = ExprNodes.not_a_constant
        node.constant_result = not_a_constant

        # check if all children are constant
        children = self.visitchildren(node)
        for child_result in children.values():
            if type(child_result) is list:
                for child in child_result:
                    if child.constant_result is not_a_constant:
                        return
            elif child_result.constant_result is not_a_constant:
                return

        # now try to calculate the real constant value
        try:
            node.calculate_constant_result()
        except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
            # ignore all 'normal' errors here => no constant result
            pass
        except Exception:
            # this looks like a real error
            import traceback, sys
            traceback.print_exc(file=sys.stdout)

    # Literal node classes, ordered from narrowest to widest.
    NODE_TYPE_ORDER = (ExprNodes.CharNode, ExprNodes.IntNode,
                       ExprNodes.LongNode, ExprNodes.FloatNode)

    def _widest_node_class(self, *nodes):
        """Return the widest class (per NODE_TYPE_ORDER) among the exact
        classes of ``nodes``, or None if any of them is not listed there.
        """
        try:
            return self.NODE_TYPE_ORDER[
                max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
        except ValueError:
            return None

    def visit_ExprNode(self, node):
        self._calculate_const(node)
        return node

    def visit_BinopNode(self, node):
        """Replace a binary operation on two literal operands by a single
        literal node carrying the computed value.  Float results and
        non-literal operands are left unfolded.
        """
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if isinstance(node.constant_result, float):
            # We calculate float constants to make them available to
            # the compiler, but we do not aggregate them into a
            # constant node to prevent any loss of precision.
            return node
        if not node.operand1.is_literal or not node.operand2.is_literal:
            # We calculate other constants to make them available to
            # the compiler, but we only aggregate constant nodes
            # recursively, so non-const nodes are straight out.
            return node

        # now inject a new constant node with the calculated value
        try:
            type1, type2 = node.operand1.type, node.operand2.type
            if type1 is None or type2 is None:
                return node
        except AttributeError:
            return node

        if type1 is type2:
            new_node = node.operand1
        else:
            widest_type = PyrexTypes.widest_numeric_type(type1, type2)
            if type(node.operand1) is type(node.operand2):
                new_node = node.operand1
                new_node.type = widest_type
            elif type1 is widest_type:
                new_node = node.operand1
            elif type2 is widest_type:
                new_node = node.operand2
            else:
                target_class = self._widest_node_class(
                    node.operand1, node.operand2)
                if target_class is None:
                    return node
                new_node = target_class(pos=node.pos, type = widest_type)

        if isinstance(new_node, ExprNodes.CharNode):
            # Only reachable when both operands are char literals
            # (e.g. c'a' + c'b').  A char literal node cannot carry an
            # arbitrary integer such as "195", so emit an integer literal
            # of the same C type instead.
            new_node = ExprNodes.IntNode(node.pos, type=new_node.type)

        new_node.constant_result = node.constant_result
        if isinstance(new_node, ExprNodes.BoolNode):
            # Keep the computed value itself (e.g. False for True & False)
            # instead of its string form ("False").
            new_node.value = node.constant_result
        else:
            new_node.value = str(node.constant_result)
        return new_node

    # in the future, other nodes can have their own handler method here
    # that can replace them with a constant result node

    visit_Node = Visitor.VisitorTransform.recurse_to_children


class FinalOptimizePhase(Visitor.CythonTransform):
    """
    This visitor handles several commuting optimizations, and is run
    just before the C code generation phase.

    The optimizations currently implemented in this class are:
        - Eliminate None assignment and refcounting for first assignment.
        - isinstance -> typecheck for cdef types
    """
    def visit_SingleAssignmentNode(self, node):
        """Avoid redundant initialisation of local variables before their
        first assignment.
        """
        self.visitchildren(node)
        if node.first:
            lhs = node.lhs
            lhs.lhs_of_first_assignment = True
            if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
                # Have variable initialized to 0 rather than None
                lhs.entry.init_to_none = False
                lhs.entry.init = 0
        return node

    def visit_SimpleCallNode(self, node):
        """Replace generic calls to isinstance(x, type) by a more efficient
        type check.

        The call is rewritten in place to use the utility scope's
        PyObject_TypeCheck, with the second argument cast to PyTypeObject*.
        """
        self.visitchildren(node)
        function = node.function
        # Check the node class before touching .type, and make sure the
        # second argument exists before indexing it.
        if (isinstance(function, ExprNodes.NameNode)
                and function.type.is_cfunction
                and function.name == 'isinstance'
                and len(node.args) == 2):
            type_arg = node.args[1]
            if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
                from CythonScope import utility_scope
                function.entry = utility_scope.lookup('PyObject_TypeCheck')
                function.type = function.entry.type
                PyTypeObjectPtr = PyrexTypes.CPtrType(utility_scope.lookup('PyTypeObject').type)
                node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
        return node