Optimize.py 127 KB
Newer Older
1 2 3 4 5 6 7

import cython
from cython import set
cython.declare(UtilityCode=object, EncodedString=object, BytesLiteral=object,
               Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object,
               UtilNodes=object, Naming=object)

8 9
import Nodes
import ExprNodes
10
import PyrexTypes
11
import Visitor
12 13 14 15
import Builtin
import UtilNodes
import TypeSlots
import Symtab
16
import Options
17
import Naming
18

19
from Code import UtilityCode
20
from StringEncoding import EncodedString, BytesLiteral
21
from Errors import error
22 23
from ParseTreeTransforms import SkipDeclarations

24 25
import codecs

26
try:
27 28
    from __builtin__ import reduce
except ImportError:
29 30
    from functools import reduce

31 32 33 34 35
try:
    from __builtin__ import basestring
except ImportError:
    basestring = str # Python 3

36 37 38 39
class FakePythonEnv(object):
    """Minimal stand-in environment, used when creating type test nodes etc."""
    # the fake environment never reports itself as a nogil context
    nogil = False

40 41 42 43 44
def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)):
    """Return the argument of a coercion node, or 'node' itself otherwise.

    Only a single level of coercion is removed.  'coercion_nodes' is the
    class (or tuple of classes) that counts as a coercion wrapper.
    """
    if not isinstance(node, coercion_nodes):
        # not a coercion wrapper => nothing to strip
        return node
    return node.arg

45
def unwrap_node(node):
    """Strip all enclosing ResultRefNode wrappers from 'node'.

    Follows the '.expression' attribute for as long as the current node is a
    UtilNodes.ResultRefNode and returns the first node that is not one.
    """
    while isinstance(node, UtilNodes.ResultRefNode):
        node = node.expression
    return node
49 50

def is_common_value(a, b):
    """Check whether the nodes 'a' and 'b' refer to the same name or attribute.

    Both nodes are first stripped of ResultRefNode wrappers.  Two NameNodes
    match if their names are equal.  Two AttributeNodes match if 'a' is not a
    Python attribute access, their objects match recursively and the attribute
    names are equal.  Any other combination yields False.
    """
    a = unwrap_node(a)
    b = unwrap_node(b)
    if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
        return a.name == b.name
    if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
        return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute
    return False

59 60 61 62
class IterationTransform(Visitor.VisitorTransform):
    """Transform some common for-in loop patterns into efficient C loops:

    - for-in-dict loop becomes a while loop calling PyDict_Next()
    - for-in-enumerate is replaced by an external counter variable
    - for-in-range loop becomes a plain C for loop
    - for-in over a C array/pointer (slice), or over bytes/unicode with a
      C integer target, becomes a C for loop over a pointer
    - 'in' / 'not in' tests for which is_ptr_contains() holds are expanded
      into such a for-in loop
    """
    # C signature of PyDict_Next(), used to build the call in the while loop
    PyDict_Next_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("dict",  PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("pos",   PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("key",   PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
            ])

    PyDict_Next_name = EncodedString("PyDict_Next")

    # symbol table entry shared by all generated PyDict_Next() name nodes
    PyDict_Next_entry = Symtab.Entry(
        PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)

    # nodes without a dedicated visit_*() method are simply recursed into
    visit_Node = Visitor.VisitorTransform.recurse_to_children
Stefan Behnel's avatar
Stefan Behnel committed
80

81 82
    def visit_ModuleNode(self, node):
        """Record the module scope, then transform the module's children."""
        # the module scope is the initial "current" scope; it is also kept
        # separately because the dict iteration check reads the language
        # level from it
        self.current_scope = node.scope
        self.module_scope = node.scope
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        """Transform a function body with the function's scope as current scope."""
        enclosing_scope = self.current_scope
        self.current_scope = node.entry.scope
        self.visitchildren(node)
        # back to the scope that was active before entering the function
        self.current_scope = enclosing_scope
        return node
93

94 95
    def visit_PrimaryCmpNode(self, node):
        """Expand a pointer/array containment test into an explicit loop.

        If node.is_ptr_contains() holds, the comparison is replaced by a
        for-in loop over operand2 that stores the boolean outcome in a result
        reference (negated for 'not_in').  The generated loop is analysed and
        run through this transform again.  Any other comparison is only
        recursed into and returned as is.
        """
        if not node.is_ptr_contains():
            self.visitchildren(node)
            return node

        # Generated structure:
        #
        # for t in operand2:
        #     if operand1 == t:
        #         res = True
        #         break
        # else:
        #     res = False

        pos = node.pos
        result_ref = UtilNodes.ResultRefNode(node)
        # the loop variable has the element type of the searched container
        if isinstance(node.operand2, ExprNodes.IndexNode):
            base_type = node.operand2.base.type.base_type
        else:
            base_type = node.operand2.type.base_type
        target_handle = UtilNodes.TempHandle(base_type)
        target = target_handle.ref(pos)
        cmp_node = ExprNodes.PrimaryCmpNode(
            pos, operator=u'==', operand1=node.operand1, operand2=target)
        if_body = Nodes.StatListNode(
            pos,
            stats = [Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)),
                     Nodes.BreakStatNode(pos)])
        if_node = Nodes.IfStatNode(
            pos,
            if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_body)],
            else_clause=None)
        for_loop = UtilNodes.TempsBlockNode(
            pos,
            temps = [target_handle],
            body = Nodes.ForInStatNode(
                pos,
                target=target,
                iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2),
                body=if_node,
                else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))))
        for_loop.analyse_expressions(self.current_scope)
        # let the freshly built for-in loop be optimised as well
        for_loop = self(for_loop)
        new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop)

        if node.operator == 'not_in':
            new_node = ExprNodes.NotNode(pos, operand=new_node)
        return new_node
144

145 146
    def visit_ForInStatNode(self, node):
        """Transform the loop's children first, then try to optimise the loop itself."""
        self.visitchildren(node)
        return self._optimise_for_loop(node)
148

149
    def _optimise_for_loop(self, node):
        """Dispatch a for-in loop to the matching specialised transformation.

        The decision is based on the iterated expression
        (node.iterator.sequence): a dict, a C array/pointer, a bytes/unicode
        object, a call to one of the dict iteration methods, or a call to the
        builtins enumerate(), range() or xrange().  Returns the replacement
        node, or 'node' unchanged if no pattern applies.
        """
        iterator = node.iterator.sequence
        if iterator.type is Builtin.dict_type:
            # like iterating over dict.keys()
            return self._transform_dict_iteration(
                node, dict_obj=iterator, keys=True, values=False)

        # C array (slice) iteration?
        if iterator.type.is_ptr or iterator.type.is_array:
            return self._transform_carray_iteration(node, iterator)
        if iterator.type in (Builtin.bytes_type, Builtin.unicode_type):
            return self._transform_string_iteration(node, iterator)

        # the rest is based on function calls
        if not isinstance(iterator, ExprNodes.SimpleCallNode):
            return node

        function = iterator.function
        # dict iteration?
        if isinstance(function, ExprNodes.AttributeNode) and \
                function.obj.type == Builtin.dict_type:
            dict_obj = function.obj
            method = function.attribute

            # keys()/values()/items() are only treated as iteration methods
            # from language level 3 on
            is_py3 = self.module_scope.context.language_level >= 3
            keys = values = False
            if method == 'iterkeys' or (is_py3 and method == 'keys'):
                keys = True
            elif method == 'itervalues' or (is_py3 and method == 'values'):
                values = True
            elif method == 'iteritems' or (is_py3 and method == 'items'):
                keys = values = True
            else:
                return node
            return self._transform_dict_iteration(
                node, dict_obj, keys, values)

        # enumerate() and range() are only recognised as plain calls to names
        # that resolve to builtins
        is_builtin_call = iterator.self is None and function.is_name and \
                          function.entry and function.entry.is_builtin

        # enumerate() ?
        if is_builtin_call and function.name == 'enumerate':
            return self._transform_enumerate_iteration(node, iterator)

        # range() iteration?
        if Options.convert_range and node.target.type.is_int:
            if is_builtin_call and function.name in ('range', 'xrange'):
                return self._transform_range_iteration(node, iterator)

        return node
206

207
    # C signatures of the CPython macros used by _transform_string_iteration()
    # to get at the character buffer and the length of unicode/bytes objects
    PyUnicode_AS_UNICODE_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_unicode_ptr_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
            ])

    PyUnicode_GET_SIZE_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
            ])

    PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_char_ptr_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
            ])

    PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
            ])

    def _transform_string_iteration(self, node, slice_node):
        """Iterate over a bytes/unicode object through its C character buffer.

        Only applies when the loop target is a C integer; otherwise the loop
        is handed to _transform_carray_iteration() unchanged.  The string is
        evaluated once into a None-checked temporary, and the loop then runs
        over the slice buffer[:length] built from the *_AS_* / *_GET_SIZE
        C-API calls.
        """
        if not node.target.type.is_int:
            return self._transform_carray_iteration(node, slice_node)
        if slice_node.type is Builtin.unicode_type:
            unpack_func = "PyUnicode_AS_UNICODE"
            len_func = "PyUnicode_GET_SIZE"
            unpack_func_type = self.PyUnicode_AS_UNICODE_func_type
            len_func_type = self.PyUnicode_GET_SIZE_func_type
        elif slice_node.type is Builtin.bytes_type:
            unpack_func = "PyBytes_AS_STRING"
            unpack_func_type = self.PyBytes_AS_STRING_func_type
            len_func = "PyBytes_GET_SIZE"
            len_func_type = self.PyBytes_GET_SIZE_func_type
        else:
            return node

        # both C-API calls below refer to the same temporary
        unpack_temp_node = UtilNodes.LetRefNode(
            slice_node.as_none_safe_node("'NoneType' is not iterable"))

        slice_base_node = ExprNodes.PythonCapiCallNode(
            slice_node.pos, unpack_func, unpack_func_type,
            args = [unpack_temp_node],
            is_temp = 0,
            )
        len_node = ExprNodes.PythonCapiCallNode(
            slice_node.pos, len_func, len_func_type,
            args = [unpack_temp_node],
            is_temp = 0,
            )

        return UtilNodes.LetNode(
            unpack_temp_node,
            self._transform_carray_iteration(
                node,
                ExprNodes.SliceIndexNode(
                    slice_node.pos,
                    base = slice_base_node,
                    start = None,
                    step = None,
                    stop = len_node,
                    type = slice_base_node.type,
                    is_temp = 1,
                    )))

271
    def _transform_carray_iteration(self, node, slice_node):
        """Rewrite iteration over a C array or pointer slice as a C for loop.

        'slice_node' may be a SliceIndexNode (base[start:stop]), an IndexNode
        whose index carries start/stop/step, or a plain C array expression.
        The loop is replaced by a ForFromStatNode that runs a pointer temp
        from base+start to base+stop and assigns the element (or, for a
        pointer target, the pointer itself) to the original loop target at
        the top of the body.

        Returns 'node' unchanged when the bounds or the step cannot be
        determined; for non-Python-object bases a compile error is reported
        in that case.
        """
        neg_step = False
        if isinstance(slice_node, ExprNodes.SliceIndexNode):
            slice_base = slice_node.base
            start = slice_node.start
            stop = slice_node.stop
            step = None
            if not stop:
                if not slice_base.type.is_pyobject:
                    error(slice_node.pos, "C array iteration requires known end index")
                return node
        elif isinstance(slice_node, ExprNodes.IndexNode):
            # slice_node.index must be a SliceNode
            slice_base = slice_node.base
            index = slice_node.index
            start = index.start
            stop = index.stop
            step = index.step
            if step:
                if step.constant_result is None:
                    step = None
                elif not isinstance(step.constant_result, (int,long)) \
                       or step.constant_result == 0 \
                       or step.constant_result > 0 and not stop \
                       or step.constant_result < 0 and not start:
                    if not slice_base.type.is_pyobject:
                        error(step.pos, "C array iteration requires known step size and end index")
                    return node
                else:
                    # step sign is handled internally by ForFromStatNode
                    neg_step = step.constant_result < 0
                    abs_step = abs(step.constant_result)
                    # IntNode values are strings everywhere else in this module
                    step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type,
                                             value=str(abs_step),
                                             constant_result=abs_step)
        elif slice_node.type.is_array:
            if slice_node.type.size is None:
                # 'step' is not bound in this branch => report at the array node
                error(slice_node.pos, "C array iteration requires known end index")
                return node
            slice_base = slice_node
            start = None
            stop = ExprNodes.IntNode(
                slice_node.pos, value=str(slice_node.type.size),
                type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size)
            step = None

        else:
            if not slice_node.type.is_pyobject:
                error(slice_node.pos, "C array iteration requires known end index")
            return node

        if start:
            if start.constant_result is None:
                start = None
            else:
                start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
        if stop:
            if stop.constant_result is None:
                stop = None
            else:
                stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
        if stop is None:
            if neg_step:
                # running backwards without an explicit stop: end before index 0
                stop = ExprNodes.IntNode(
                    slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1)
            else:
                error(slice_node.pos, "C array iteration requires known step size and end index")
                return node

        ptr_type = slice_base.type
        if ptr_type.is_array:
            ptr_type = ptr_type.element_ptr_type()
        carray_ptr = slice_base.coerce_to_simple(self.current_scope)

        if start and start.constant_result != 0:
            start_ptr_node = ExprNodes.AddNode(
                start.pos,
                operand1=carray_ptr,
                operator='+',
                operand2=start,
                type=ptr_type)
        else:
            start_ptr_node = carray_ptr

        stop_ptr_node = ExprNodes.AddNode(
            stop.pos,
            operand1=ExprNodes.CloneNode(carray_ptr),
            operator='+',
            operand2=stop,
            type=ptr_type
            ).coerce_to_simple(self.current_scope)

        # the pointer temp that is stepped through the array
        counter = UtilNodes.TempHandle(ptr_type)
        counter_temp = counter.ref(node.target.pos)

        if slice_base.type.is_string and node.target.type.is_pyobject:
            # special case: char* -> bytes
            target_value = ExprNodes.SliceIndexNode(
                node.target.pos,
                start=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                stop=ExprNodes.IntNode(node.target.pos, value='1',
                                       constant_result=1,
                                       type=PyrexTypes.c_int_type),
                base=counter_temp,
                type=Builtin.bytes_type,
                is_temp=1)
        elif node.target.type.is_ptr and not node.target.type.assignable_from(ptr_type.base_type):
            # Allow iteration with pointer target to avoid copy.
            target_value = counter_temp
        else:
            target_value = ExprNodes.IndexNode(
                node.target.pos,
                index=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                base=counter_temp,
                is_buffer_access=False,
                type=ptr_type.base_type)

        if target_value.type != node.target.type:
            target_value = target_value.coerce_to(node.target.type,
                                                  self.current_scope)

        target_assign = Nodes.SingleAssignmentNode(
            pos = node.target.pos,
            lhs = node.target,
            rhs = target_value)

        body = Nodes.StatListNode(
            node.pos,
            stats = [target_assign, node.body])

        for_node = Nodes.ForFromStatNode(
            node.pos,
            bound1=start_ptr_node, relation1=neg_step and '>=' or '<=',
            target=counter_temp,
            relation2=neg_step and '>' or '<', bound2=stop_ptr_node,
            step=step, body=body,
            else_clause=node.else_clause,
            from_range=True)

        return UtilNodes.TempsBlockNode(
            node.pos, temps=[counter],
            body=for_node)

417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445
    def _transform_enumerate_iteration(self, node, enumerate_function):
        """Replace 'for i, x in enumerate(seq)' by a loop over 'seq' plus a counter.

        The counter lives in a LetRefNode temp initialised to 0.  At the top
        of each iteration it is assigned to the original counter target and
        then incremented.  The loop itself is rewritten in place to iterate
        over the enumerate() argument and is passed to _optimise_for_loop()
        again.

        Reports a compile error for zero or more than one argument.  Loops
        whose target is not a 2-item sequence starting with a plain name, or
        whose counter is neither a Python object nor a C integer, are
        returned unchanged.
        """
        args = enumerate_function.arg_tuple.args
        if len(args) == 0:
            error(enumerate_function.pos,
                  "enumerate() requires an iterable argument")
            return node
        elif len(args) > 1:
            error(enumerate_function.pos,
                  "enumerate() takes at most 1 argument")
            return node

        if not node.target.is_sequence_constructor:
            # leave this untouched for now
            return node
        targets = node.target.args
        if len(targets) != 2:
            # leave this untouched for now
            return node
        if not isinstance(targets[0], ExprNodes.NameNode):
            # leave this untouched for now
            return node

        enumerate_target, iterable_target = targets
        counter_type = enumerate_target.type

        if not counter_type.is_pyobject and not counter_type.is_int:
            # nothing we can do here, I guess
            return node

        temp = UtilNodes.LetRefNode(ExprNodes.IntNode(enumerate_function.pos,
                                                      value='0',
                                                      type=counter_type,
                                                      constant_result=0))
        inc_expression = ExprNodes.AddNode(
            enumerate_function.pos,
            operand1 = temp,
            operand2 = ExprNodes.IntNode(node.pos, value='1',
                                         type=counter_type,
                                         constant_result=1),
            operator = '+',
            type = counter_type,
            is_temp = counter_type.is_pyobject
            )

        # "i = counter; counter = counter + 1", executed before the user's body
        loop_body = [
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = enumerate_target,
                rhs = temp),
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = temp,
                rhs = inc_expression)
            ]

        if isinstance(node.body, Nodes.StatListNode):
            node.body.stats = loop_body + node.body.stats
        else:
            loop_body.append(node.body)
            node.body = Nodes.StatListNode(
                node.body.pos,
                stats = loop_body)

        node.target = iterable_target
        node.item = node.item.coerce_to(iterable_target.type, self.current_scope)
        node.iterator.sequence = enumerate_function.arg_tuple.args[0]

        # recurse into loop to check for further optimisations
        return UtilNodes.LetNode(temp, self._optimise_for_loop(node))
486

487 488 489 490 491
    def _transform_range_iteration(self, node, range_function):
        """Replace 'for i in range(...)' by a ForFromStatNode (C for loop).

        Requires a step whose constant value is a non-zero integer so that
        the loop direction is known; the default step is 1.  A stop bound
        that is not a literal is kept in a LetRefNode temp.

        Returns 'node' unchanged if the call does not have one to three
        arguments or the step direction cannot be determined.
        """
        args = range_function.arg_tuple.args
        if not 1 <= len(args) <= 3:
            # invalid call signature: leave it alone so that the error is
            # raised elsewhere instead of failing on args[0] below or
            # silently dropping the surplus arguments
            return node

        if len(args) < 3:
            step_pos = range_function.pos
            step_value = 1
            step = ExprNodes.IntNode(step_pos, value='1',
                                     constant_result=1)
        else:
            step = args[2]
            step_pos = step.pos
            if not isinstance(step.constant_result, (int, long)):
                # cannot determine step direction
                return node
            step_value = step.constant_result
            if step_value == 0:
                # will lead to an error elsewhere
                return node
            if not isinstance(step, ExprNodes.IntNode):
                step = ExprNodes.IntNode(step_pos, value=str(step_value),
                                         constant_result=step_value)

        if step_value < 0:
            # the loop relations carry the direction, the step node only the
            # absolute value
            step.value = str(-step_value)
            relation1 = '>='
            relation2 = '>'
        else:
            relation1 = '<='
            relation2 = '<'

        if len(args) == 1:
            bound1 = ExprNodes.IntNode(range_function.pos, value='0',
                                       constant_result=0)
            bound2 = args[0].coerce_to_integer(self.current_scope)
        else:
            bound1 = args[0].coerce_to_integer(self.current_scope)
            bound2 = args[1].coerce_to_integer(self.current_scope)
        step = step.coerce_to_integer(self.current_scope)

        if not bound2.is_literal:
            # stop bound must be immutable => keep it in a temp var
            bound2_is_temp = True
            bound2 = UtilNodes.LetRefNode(bound2)
        else:
            bound2_is_temp = False

        for_node = Nodes.ForFromStatNode(
            node.pos,
            target=node.target,
            bound1=bound1, relation1=relation1,
            relation2=relation2, bound2=bound2,
            step=step, body=node.body,
            else_clause=node.else_clause,
            from_range=True)

        if bound2_is_temp:
            for_node = UtilNodes.LetNode(bound2, for_node)

        return for_node

Stefan Behnel's avatar
Stefan Behnel committed
546
    def _transform_dict_iteration(self, node, dict_obj, keys, values):
        """Replace a for-in loop over a dict by a while loop calling PyDict_Next().

        'keys' and 'values' select which of the two outputs of PyDict_Next()
        are requested; NULL is passed for the other one.  The dict and the
        iteration position are kept in temps.  The assignments to the
        original loop target(s) are inserted at the top of the loop body,
        preceded by any coercions to non-Python target types.

        Returns 'node' unchanged if both keys and values are requested but
        the target is a sequence of other than two items.
        """
        py_object_ptr = PyrexTypes.c_void_ptr_type

        temps = []
        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
        temps.append(temp)
        dict_temp = temp.ref(dict_obj.pos)
        temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        temps.append(temp)
        pos_temp = temp.ref(node.pos)
        pos_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=pos_temp,
            type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
        if keys:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            key_temp = temp.ref(node.target.pos)
            key_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=key_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            key_temp_addr = key_temp = ExprNodes.NullNode(
                pos=node.target.pos)
        if values:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            value_temp = temp.ref(node.target.pos)
            value_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=value_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            value_temp_addr = value_temp = ExprNodes.NullNode(
                pos=node.target.pos)

        key_target = value_target = node.target
        tuple_target = None
        if keys and values:
            if node.target.is_sequence_constructor:
                if len(node.target.args) == 2:
                    key_target, value_target = node.target.args
                else:
                    # unusual case that may or may not lead to an error
                    return node
            else:
                # single target receiving the (key, value) tuple
                tuple_target = node.target

        def coerce_object_to(obj_node, dest_type):
            # Returns (value node to assign, coercion statement or None).
            # Python object targets get a (type tested) cast; other targets
            # get a new temp that a separate coercion statement fills in.
            if dest_type.is_pyobject:
                if dest_type != obj_node.type:
                    if dest_type.is_extension_type or dest_type.is_builtin_type:
                        obj_node = ExprNodes.PyTypeTestNode(
                            obj_node, dest_type, self.current_scope, notnone=True)
                result = ExprNodes.TypecastNode(
                    obj_node.pos,
                    operand = obj_node,
                    type = dest_type)
                return (result, None)
            else:
                temp = UtilNodes.TempHandle(dest_type)
                temps.append(temp)
                temp_result = temp.ref(obj_node.pos)
                class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
                    # coercion node that writes its result into 'temp_result'
                    # and can be used as a statement
                    def result(self):
                        return temp_result.result()
                    def generate_execution_code(self, code):
                        self.generate_result_code(code)
                return (temp_result, CoercedTempNode(dest_type, obj_node, self.current_scope))

        if isinstance(node.body, Nodes.StatListNode):
            body = node.body
        else:
            body = Nodes.StatListNode(pos = node.body.pos,
                                      stats = [node.body])

        if tuple_target:
            tuple_result = ExprNodes.TupleNode(
                pos = tuple_target.pos,
                args = [key_temp, value_temp],
                is_temp = 1,
                type = Builtin.tuple_type,
                )
            body.stats.insert(
                0, Nodes.SingleAssignmentNode(
                    pos = tuple_target.pos,
                    lhs = tuple_target,
                    rhs = tuple_result))
        else:
            # execute all coercions before the assignments
            coercion_stats = []
            assign_stats = []
            if keys:
                temp_result, coercion = coerce_object_to(
                    key_temp, key_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = key_temp.pos,
                        lhs = key_target,
                        rhs = temp_result))
            if values:
                temp_result, coercion = coerce_object_to(
                    value_temp, value_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = value_temp.pos,
                        lhs = value_target,
                        rhs = temp_result))
            body.stats[0:0] = coercion_stats + assign_stats

        # dict_temp = dict_obj; pos = 0; while PyDict_Next(...): body
        result_code = [
            Nodes.SingleAssignmentNode(
                pos = dict_obj.pos,
                lhs = dict_temp,
                rhs = dict_obj),
            Nodes.SingleAssignmentNode(
                pos = node.pos,
                lhs = pos_temp,
                rhs = ExprNodes.IntNode(node.pos, value='0',
                                        constant_result=0)),
            Nodes.WhileStatNode(
                pos = node.pos,
                condition = ExprNodes.SimpleCallNode(
                    pos = dict_obj.pos,
                    type = PyrexTypes.c_bint_type,
                    function = ExprNodes.NameNode(
                        pos = dict_obj.pos,
                        name = self.PyDict_Next_name,
                        type = self.PyDict_Next_func_type,
                        entry = self.PyDict_Next_entry),
                    args = [dict_temp, pos_temp_addr,
                            key_temp_addr, value_temp_addr]
                    ),
                body = body,
                else_clause = node.else_clause
                )
            ]

        return UtilNodes.TempsBlockNode(
            node.pos, temps=temps,
            body=Nodes.StatListNode(
                node.pos,
                stats = result_code
                ))


694 695
class SwitchTransform(Visitor.VisitorTransform):
    """
    This transformation tries to turn long if statements into C switch statements.
    The requirement is that every clause be an (or of) var == value, where the var
    is common among all clauses and both var and value are ints.
    """
    # returned by the extract_*() methods when a condition cannot be mapped
    NO_MATCH = (None, None, None)

    def extract_conditions(self, cond, allow_not_in):
        """Split a condition into (not_in, tested operand, list of values).

        Returns NO_MATCH unless 'cond' is a comparison (or a boolean
        chain of comparisons) of one common operand against literals
        or const entries.  'allow_not_in' additionally accepts the
        negated forms ('!=', 'not_in' and 'and' chains of them).
        """
        # strip wrapper nodes to get at the actual comparison
        while True:
            if isinstance(cond, ExprNodes.CoerceToTempNode):
                cond = cond.arg
            elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
                # this is what we get from the FlattenInListTransform
                cond = cond.subexpression
            elif isinstance(cond, ExprNodes.TypecastNode):
                cond = cond.operand
            else:
                break

        if isinstance(cond, ExprNodes.PrimaryCmpNode):
            if cond.cascade is not None:
                return self.NO_MATCH
            elif cond.is_c_string_contains() and \
                   isinstance(cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
                not_in = cond.operator == 'not_in'
                if not_in and not allow_not_in:
                    return self.NO_MATCH
                if isinstance(cond.operand2, ExprNodes.UnicodeNode) and \
                       cond.operand2.contains_surrogates():
                    # dealing with surrogates leads to different
                    # behaviour on wide and narrow Unicode
                    # platforms => refuse to optimise this case
                    return self.NO_MATCH
                return not_in, cond.operand1, self.extract_in_string_conditions(cond.operand2)
            elif not cond.is_python_comparison():
                if cond.operator == '==':
                    not_in = False
                elif allow_not_in and cond.operator == '!=':
                    not_in = True
                else:
                    return self.NO_MATCH
                # this looks somewhat silly, but it does the right
                # checks for NameNode and AttributeNode
                if is_common_value(cond.operand1, cond.operand1):
                    if cond.operand2.is_literal:
                        return not_in, cond.operand1, [cond.operand2]
                    elif getattr(cond.operand2, 'entry', None) \
                             and cond.operand2.entry.is_const:
                        return not_in, cond.operand1, [cond.operand2]
                if is_common_value(cond.operand2, cond.operand2):
                    if cond.operand1.is_literal:
                        return not_in, cond.operand2, [cond.operand1]
                    elif getattr(cond.operand1, 'entry', None) \
                             and cond.operand1.entry.is_const:
                        return not_in, cond.operand2, [cond.operand1]
        elif isinstance(cond, ExprNodes.BoolBinopNode):
            if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'):
                # negated tests can only be chained by 'and'
                allow_not_in = (cond.operator == 'and')
                not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in)
                not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in)
                if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2):
                    if (not not_in_1) or allow_not_in:
                        return not_in_1, t1, c1+c2
        return self.NO_MATCH

    def extract_in_string_conditions(self, string_literal):
        """Return one sorted constant node per distinct character of
        a unicode/bytes literal (IntNode resp. CharNode).
        """
        if isinstance(string_literal, ExprNodes.UnicodeNode):
            charvals = list(map(ord, set(string_literal.value)))
            charvals.sort()
            return [ ExprNodes.IntNode(string_literal.pos, value=str(charval),
                                       constant_result=charval)
                     for charval in charvals ]
        else:
            # this is a bit tricky as Py3's bytes type returns
            # integers on iteration, whereas Py2 returns 1-char byte
            # strings
            characters = string_literal.value
            characters = list(set([ characters[i:i+1] for i in range(len(characters)) ]))
            characters.sort()
            return [ ExprNodes.CharNode(string_literal.pos, value=charval,
                                        constant_result=charval)
                     for charval in characters ]

    def extract_common_conditions(self, common_var, condition, allow_not_in):
        """Like extract_conditions(), but additionally require that the
        tested operand matches 'common_var' (if given) and that the
        operand and all values have C integer types.
        """
        not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
        if var is None:
            return self.NO_MATCH
        elif common_var is not None and not is_common_value(var, common_var):
            return self.NO_MATCH
        elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
            return self.NO_MATCH
        return not_in, var, conditions

    def has_duplicate_values(self, condition_values):
        """Return True if two of the case values are (or may be) equal.

        Duplicated values don't work in a switch statement.
        """
        seen = set()
        for value in condition_values:
            if value.constant_result is not ExprNodes.not_a_constant:
                if value.constant_result in seen:
                    return True
                seen.add(value.constant_result)
            else:
                # this isn't completely safe as we don't know the
                # final C value, but this is about the best we can do
                try:
                    cname = value.entry.cname
                except AttributeError:
                    # no entry to compare by => play safe, don't optimise
                    return True
                if cname in seen:
                    return True
                seen.add(cname)
        return False

    def visit_IfStatNode(self, node):
        """Replace an if-elif chain by a SwitchStatNode if all clauses
        test the same integer operand against distinct values.
        """
        common_var = None
        cases = []
        for if_clause in node.if_clauses:
            _, common_var, conditions = self.extract_common_conditions(
                common_var, if_clause.condition, False)
            if common_var is None:
                self.visitchildren(node)
                return node
            cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
                                              conditions = conditions,
                                              body = if_clause.body))

        # a switch is not worth it for a single value
        if sum([ len(case.conditions) for case in cases ]) < 2:
            self.visitchildren(node)
            return node
        if self.has_duplicate_values(sum([case.conditions for case in cases], [])):
            self.visitchildren(node)
            return node

        common_var = unwrap_node(common_var)
        switch_node = Nodes.SwitchStatNode(pos = node.pos,
                                           test = common_var,
                                           cases = cases,
                                           else_clause = node.else_clause)
        return switch_node

    def visit_CondExprNode(self, node):
        """Map 'a if <switchable test> else b' to a switch statement
        that assigns one of the two values to a temp result.
        """
        not_in, common_var, conditions = self.extract_common_conditions(
            None, node.test, True)
        if common_var is None \
               or len(conditions) < 2 \
               or self.has_duplicate_values(conditions):
            self.visitchildren(node)
            return node
        return self.build_simple_switch_statement(
            node, common_var, conditions, not_in,
            node.true_val, node.false_val)

    def visit_BoolBinopNode(self, node):
        """Map a switchable and/or chain to a switch statement that
        evaluates to True or False.
        """
        not_in, common_var, conditions = self.extract_common_conditions(
            None, node, True)
        if common_var is None \
               or len(conditions) < 2 \
               or self.has_duplicate_values(conditions):
            self.visitchildren(node)
            return node

        return self.build_simple_switch_statement(
            node, common_var, conditions, not_in,
            ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
            ExprNodes.BoolNode(node.pos, value=False, constant_result=False))

    def visit_PrimaryCmpNode(self, node):
        """Map a switchable comparison (e.g. a character containment
        test) to a switch statement that evaluates to True or False.
        """
        not_in, common_var, conditions = self.extract_common_conditions(
            None, node, True)
        if common_var is None \
               or len(conditions) < 2 \
               or self.has_duplicate_values(conditions):
            self.visitchildren(node)
            return node

        return self.build_simple_switch_statement(
            node, common_var, conditions, not_in,
            ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
            ExprNodes.BoolNode(node.pos, value=False, constant_result=False))

    def build_simple_switch_statement(self, node, common_var, conditions,
                                      not_in, true_val, false_val):
        """Build a single-case switch that assigns 'true_val' on a
        match and 'false_val' otherwise (swapped if 'not_in' is set),
        wrapped in a node that yields the assigned result.
        """
        result_ref = UtilNodes.ResultRefNode(node)
        true_body = Nodes.SingleAssignmentNode(
            node.pos,
            lhs = result_ref,
            rhs = true_val,
            first = True)
        false_body = Nodes.SingleAssignmentNode(
            node.pos,
            lhs = result_ref,
            rhs = false_val,
            first = True)

        if not_in:
            true_body, false_body = false_body, true_body

        cases = [Nodes.SwitchCaseNode(pos = node.pos,
                                      conditions = conditions,
                                      body = true_body)]

        common_var = unwrap_node(common_var)
        switch_node = Nodes.SwitchStatNode(pos = node.pos,
                                           test = common_var,
                                           cases = cases,
                                           else_clause = false_body)
        return UtilNodes.TempResultFromStatNode(result_ref, switch_node)

    visit_Node = Visitor.VisitorTransform.recurse_to_children
898

899

900
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    This transformation flattens "x in [val1, ..., valn]" into a sequential list
    of comparisons.
    """

    def visit_PrimaryCmpNode(self, node):
        """Rewrite 'x in (a, b, ...)' into 'x == a or x == b or ...' and
        'x not in (a, b, ...)' into 'x != a and x != b and ...'.

        Only non-cascaded tests against tuple, list or set displays
        are transformed; any other node is returned unchanged.
        """
        self.visitchildren(node)
        if node.cascade is not None:
            return node
        elif node.operator == 'in':
            conjunction = 'or'
            eq_or_neq = '=='
        elif node.operator == 'not_in':
            conjunction = 'and'
            eq_or_neq = '!='
        else:
            return node

        if not isinstance(node.operand2, (ExprNodes.TupleNode,
                                          ExprNodes.ListNode,
                                          ExprNodes.SetNode)):
            return node

        args = node.operand2.args
        if len(args) == 0:
            # Do not fold this into a constant: the left operand must
            # still be evaluated since it may have side effects or fail.
            return node

        # evaluate the left operand only once for all comparisons
        lhs = UtilNodes.ResultRefNode(node.operand1)

        conds = []
        temps = []
        for arg in args:
            if not arg.is_simple():
                # must evaluate all non-simple RHS before doing the comparisons
                arg = UtilNodes.LetRefNode(arg)
                temps.append(arg)
            cond = ExprNodes.PrimaryCmpNode(
                                pos = node.pos,
                                operand1 = lhs,
                                operator = eq_or_neq,
                                operand2 = arg,
                                cascade = None)
            conds.append(ExprNodes.TypecastNode(
                                pos = node.pos,
                                operand = cond,
                                type = PyrexTypes.c_bint_type))
        def concat(left, right):
            return ExprNodes.BoolBinopNode(
                                pos = node.pos,
                                operator = conjunction,
                                operand1 = left,
                                operand2 = right)

        condition = reduce(concat, conds)
        new_node = UtilNodes.EvalWithTempExprNode(lhs, condition)
        # wrap in reverse order so that the first temp ends up outermost
        for temp in temps[::-1]:
            new_node = UtilNodes.EvalWithTempExprNode(temp, new_node)
        return new_node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
961 962


963 964 965 966 967 968
class DropRefcountingTransform(Visitor.VisitorTransform):
    """Drop ref-counting in safe places.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ParallelAssignmentNode(self, node):
        """
        Parallel swap assignments like 'a,b = b,a' are safe.

        If the assigned names are a non-redundant permutation of the
        assigned values, 'use_managed_ref' gets disabled on the
        participating nodes.  The node itself is always returned.
        """
        left_names, right_names = [], []
        left_indices, right_indices = [], []
        temps = []

        for stat in node.stats:
            if isinstance(stat, Nodes.SingleAssignmentNode):
                if not self._extract_operand(stat.lhs, left_names,
                                             left_indices, temps):
                    return node
                if not self._extract_operand(stat.rhs, right_names,
                                             right_indices, temps):
                    return node
            elif isinstance(stat, Nodes.CascadedAssignmentNode):
                # FIXME
                return node
            else:
                return node

        if left_names or right_names:
            # lhs/rhs names must be a non-redundant permutation
            lnames = [ path for path, n in left_names ]
            rnames = [ path for path, n in right_names ]
            if set(lnames) != set(rnames):
                return node
            if len(set(lnames)) != len(right_names):
                return node

        if left_indices or right_indices:
            # base name and index of index nodes must be a
            # non-redundant permutation
            lindices = []
            for lhs_node in left_indices:
                index_id = self._extract_index_id(lhs_node)
                if not index_id:
                    return node
                lindices.append(index_id)
            rindices = []
            for rhs_node in right_indices:
                index_id = self._extract_index_id(rhs_node)
                if not index_id:
                    return node
                rindices.append(index_id)

            if set(lindices) != set(rindices):
                return node
            if len(set(lindices)) != len(right_indices):
                return node

            # really supporting IndexNode requires support in
            # __Pyx_GetItemInt(), so let's stop short for now
            return node

        temp_args = [t.arg for t in temps]
        for temp in temps:
            temp.use_managed_ref = False

        for _, name_node in left_names + right_names:
            # nodes wrapped by a temp were handled through the temp above
            if name_node not in temp_args:
                name_node.use_managed_ref = False

        for index_node in left_indices + right_indices:
            index_node.use_managed_ref = False

        return node

    def _extract_operand(self, node, names, indices, temps):
        """Classify one assignment operand.

        Appends a (dotted name path, node) pair to 'names' for plain
        or attribute names, the node to 'indices' for int-indexed
        list names, and CoerceToTempNode wrappers to 'temps'.
        Returns False for anything else, including non-Python-object
        operands and Python attribute lookups.
        """
        node = unwrap_node(node)
        if not node.type.is_pyobject:
            return False
        if isinstance(node, ExprNodes.CoerceToTempNode):
            temps.append(node)
            node = node.arg
        name_path = []
        obj_node = node
        while isinstance(obj_node, ExprNodes.AttributeNode):
            if obj_node.is_py_attr:
                return False
            name_path.append(obj_node.member)
            obj_node = obj_node.obj
        if isinstance(obj_node, ExprNodes.NameNode):
            name_path.append(obj_node.name)
            # the path was collected innermost-last, so reverse it
            names.append( ('.'.join(name_path[::-1]), node) )
        elif isinstance(node, ExprNodes.IndexNode):
            if node.base.type != Builtin.list_type:
                return False
            if not node.index.type.is_int:
                return False
            if not isinstance(node.base, ExprNodes.NameNode):
                return False
            indices.append(node)
        else:
            return False
        return True

    def _extract_index_id(self, index_node):
        """Return a (base name, index name) key for an index node, or
        None if the index is not a plain name.
        """
        base = index_node.base
        index = index_node.index
        if isinstance(index, ExprNodes.NameNode):
            index_val = index.name
        elif isinstance(index, ExprNodes.ConstNode):
            # FIXME:
            return None
        else:
            return None
        return (base.name, index_val)

1078

1079 1080 1081 1082 1083 1084 1085
class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
    """Optimize some common calls to builtin types *before* the type
    analysis phase and *after* the declarations analysis phase.

    This transform cannot make use of any argument types, but it can
    restructure the tree in a way that the type analysis phase can
    respond to.
Stefan Behnel's avatar
Stefan Behnel committed
1086 1087 1088 1089

    Introducing C function calls here may not be a good idea.  Move
    them to the OptimizeBuiltinCalls transform instead, which runs
    after type analysis.
1090
    """
1091 1092
    # only intercept on call nodes
    visit_Node = Visitor.VisitorTransform.recurse_to_children
1093

1094 1095 1096 1097 1098 1099 1100
    def visit_SimpleCallNode(self, node):
        """Dispatch a positional-only call to its handler if the called
        function is a builtin name; otherwise return the node as is.
        """
        self.visitchildren(node)
        called = node.function
        if self._function_is_builtin_name(called):
            return self._dispatch_to_handler(node, called, node.args)
        return node

1101
    def visit_GeneralCallNode(self, node):
        """Dispatch a call with keyword arguments to its handler if the
        called function is a builtin name and the positional arguments
        are given as a literal tuple; otherwise return the node as is.
        """
        self.visitchildren(node)
        function = node.function
        if not self._function_is_builtin_name(function):
            return node
        arg_tuple = node.positional_args
        if not isinstance(arg_tuple, ExprNodes.TupleNode):
            return node
        args = arg_tuple.args
        return self._dispatch_to_handler(
            node, function, args, node.keyword_args)
1112

1113 1114 1115
    def _function_is_builtin_name(self, function):
        if not function.is_name:
            return False
1116 1117 1118
        env = self.current_env()
        entry = env.lookup(function.name)
        if entry is not env.builtin_scope().lookup_here(function.name):
1119
            return False
1120
        # if entry is None, it's at least an undeclared name, so likely builtin
1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158
        return True

    def _dispatch_to_handler(self, node, function, args, kwargs=None):
        if kwargs is None:
            handler_name = '_handle_simple_function_%s' % function.name
        else:
            handler_name = '_handle_general_function_%s' % function.name
        handle_call = getattr(self, handler_name, None)
        if handle_call is not None:
            if kwargs is None:
                return handle_call(node, args)
            else:
                return handle_call(node, args, kwargs)
        return node

    def _inject_capi_function(self, node, cname, func_type, utility_code=None):
        """Replace the called function of 'node' by a
        PythonCapiFunctionNode for 'cname', keeping the position and
        name of the original function node.
        """
        py_function = node.function
        node.function = ExprNodes.PythonCapiFunctionNode(
            py_function.pos, py_function.name, cname, func_type,
            utility_code = utility_code)

    def _error_wrong_arg_count(self, function_name, node, args, expected=None):
        """Report a call with the wrong number of arguments.

        'expected' may be None, a number or a descriptive string; it
        determines how the expected signature is displayed.
        """
        arg_str = ''
        if expected: # neither None nor 0
            if isinstance(expected, basestring) or expected > 1:
                arg_str = '...'
            elif expected == 1:
                arg_str = 'x'

        expected_str = ''
        if expected is not None:
            expected_str = 'expected %s, ' % expected

        message = "%s(%s) called with wrong number of args, %sfound %d" % (
            function_name, arg_str, expected_str, len(args))
        error(node.pos, message)

    # specific handlers for simple call nodes

1159 1160 1161 1162 1163 1164 1165
    def _handle_simple_function_float(self, node, pos_args):
        if len(pos_args) == 0:
            return ExprNodes.FloatNode(node.pos, value='0.0')
        if len(pos_args) > 1:
            self._error_wrong_arg_count('float', node, pos_args, 1)
        return node

1166 1167 1168
    class YieldNodeCollector(Visitor.TreeVisitor):
        """Collects yield expressions ('yield_nodes') and maps each one
        to the expression statement wrapping it ('yield_stat_nodes').
        """
        def __init__(self):
            Visitor.TreeVisitor.__init__(self)
            self.yield_stat_nodes = {}
            self.yield_nodes = []

        visit_Node = Visitor.TreeVisitor.visitchildren

        # XXX: disable inlining while it's not back supported
        # NOTE(review): the '__' prefix makes these names private
        # (name-mangled), so presumably the visitor dispatch no
        # longer finds them -- confirm against Visitor.TreeVisitor.
        def __visit_YieldExprNode(self, node):
            self.yield_nodes.append(node)
            self.visitchildren(node)

        def __visit_ExprStatNode(self, node):
            self.visitchildren(node)
            if node.expr in self.yield_nodes:
                self.yield_stat_nodes[node.expr] = node

        def __visit_GeneratorExpressionNode(self, node):
            # enable when we support generic generator expressions
            #
            # everything below this node is out of scope
            pass

1189
    def _find_single_yield_expression(self, node):
1190 1191 1192
        collector = self.YieldNodeCollector()
        collector.visitchildren(node)
        if len(collector.yield_nodes) != 1:
1193 1194
            return None, None
        yield_node = collector.yield_nodes[0]
1195 1196 1197 1198
        try:
            return (yield_node.arg, collector.yield_stat_nodes[yield_node])
        except KeyError:
            return None, None
1199

1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244
    def _handle_simple_function_all(self, node, pos_args):
        """Transform

        _result = all(x for L in LL for x in L)

        into

        for L in LL:
            for x in L:
                if not x:
                    _result = False
                    break
            else:
                continue
            break
        else:
            _result = True
        """
        return self._transform_any_all(node, pos_args, False)

    def _handle_simple_function_any(self, node, pos_args):
        """Transform

        _result = any(x for L in LL for x in L)

        into

        for L in LL:
            for x in L:
                if x:
                    _result = True
                    break
            else:
                continue
            break
        else:
            _result = False
        """
        return self._transform_any_all(node, pos_args, True)

    def _transform_any_all(self, node, pos_args, is_any):
        if len(pos_args) != 1:
            return node
        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
            return node
1245 1246
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop
1247 1248
        yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
        if yield_expression is None:
1249 1250 1251 1252 1253 1254 1255
            return node

        if is_any:
            condition = yield_expression
        else:
            condition = ExprNodes.NotNode(yield_expression.pos, operand = yield_expression)

1256
        result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.c_bint_type)
1257
        test_node = Nodes.IfStatNode(
1258
            yield_expression.pos,
1259 1260
            else_clause = None,
            if_clauses = [ Nodes.IfClauseNode(
1261
                yield_expression.pos,
1262 1263 1264 1265 1266 1267 1268
                condition = condition,
                body = Nodes.StatListNode(
                    node.pos,
                    stats = [
                        Nodes.SingleAssignmentNode(
                            node.pos,
                            lhs = result_ref,
1269
                            rhs = ExprNodes.BoolNode(yield_expression.pos, value = is_any,
1270 1271 1272 1273 1274 1275 1276 1277 1278
                                                     constant_result = is_any)),
                        Nodes.BreakStatNode(node.pos)
                        ])) ]
            )
        loop = loop_node
        while isinstance(loop.body, Nodes.LoopNode):
            next_loop = loop.body
            loop.body = Nodes.StatListNode(loop.body.pos, stats = [
                loop.body,
1279
                Nodes.BreakStatNode(yield_expression.pos)
1280
                ])
1281
            next_loop.else_clause = Nodes.ContinueStatNode(yield_expression.pos)
1282 1283 1284 1285
            loop = next_loop
        loop_node.else_clause = Nodes.SingleAssignmentNode(
            node.pos,
            lhs = result_ref,
1286
            rhs = ExprNodes.BoolNode(yield_expression.pos, value = not is_any,
1287 1288
                                     constant_result = not is_any))

1289
        Visitor.recursively_replace_node(loop_node, yield_stat_node, test_node)
1290

1291 1292
        return ExprNodes.InlinedGeneratorExpressionNode(
            gen_expr_node.pos, loop = loop_node, result_node = result_ref,
1293
            expr_scope = gen_expr_node.expr_scope, orig_func = is_any and 'any' or 'all')
1294

1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341
    def _handle_simple_function_sorted(self, node, pos_args):
        """Transform sorted(genexpr) into [listcomp].sort().  CPython
        just reads the iterable into a list and calls .sort() on it.
        Expanding the iterable in a listcomp is still faster.
        """
        if len(pos_args) != 1:
            return node
        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop
        yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
        if yield_expression is None:
            return node

        result_node = UtilNodes.ResultRefNode(
            pos = loop_node.pos, type = Builtin.list_type, may_hold_none=False)

        target = ExprNodes.ListNode(node.pos, args = [])
        append_node = ExprNodes.ComprehensionAppendNode(
            yield_expression.pos, expr = yield_expression,
            target = ExprNodes.CloneNode(target))

        Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)

        listcomp_node = ExprNodes.ComprehensionNode(
            gen_expr_node.pos, loop = loop_node, target = target,
            append = append_node, type = Builtin.list_type,
            expr_scope = gen_expr_node.expr_scope,
            has_local_scope = True)
        listcomp_assign_node = Nodes.SingleAssignmentNode(
            node.pos, lhs = result_node, rhs = listcomp_node, first = True)

        sort_method = ExprNodes.AttributeNode(
            node.pos, obj = result_node, attribute = EncodedString('sort'),
            # entry ? type ?
            needs_none_check = False)
        sort_node = Nodes.ExprStatNode(
            node.pos, expr = ExprNodes.SimpleCallNode(
                node.pos, function = sort_method, args = []))

        sort_node.analyse_declarations(self.current_env())

        return UtilNodes.TempResultFromStatNode(
            result_node,
            Nodes.StatListNode(node.pos, stats = [ listcomp_assign_node, sort_node ]))

1342
    def _handle_simple_function_sum(self, node, pos_args):
Stefan Behnel's avatar
Stefan Behnel committed
1343 1344
        """Transform sum(genexpr) into an equivalent inlined aggregation loop.
        """
1345 1346
        if len(pos_args) not in (1,2):
            return node
1347 1348
        if not isinstance(pos_args[0], (ExprNodes.GeneratorExpressionNode,
                                        ExprNodes.ComprehensionNode)):
1349 1350 1351 1352
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop

1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366
        if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
            yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
            if yield_expression is None:
                return node
        else: # ComprehensionNode
            yield_stat_node = gen_expr_node.append
            yield_expression = yield_stat_node.expr
            try:
                if not yield_expression.is_literal or not yield_expression.type.is_int:
                    return node
            except AttributeError:
                return node # in case we don't have a type yet
            # special case: old Py2 backwards compatible "sum([int_const for ...])"
            # can safely be unpacked into a genexpr
1367 1368 1369 1370 1371 1372 1373 1374

        if len(pos_args) == 1:
            start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
        else:
            start = pos_args[1]

        result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type)
        add_node = Nodes.SingleAssignmentNode(
1375
            yield_expression.pos,
1376 1377 1378 1379
            lhs = result_ref,
            rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
            )

1380
        Visitor.recursively_replace_node(loop_node, yield_stat_node, add_node)
1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394

        exec_code = Nodes.StatListNode(
            node.pos,
            stats = [
                Nodes.SingleAssignmentNode(
                    start.pos,
                    lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
                    rhs = start,
                    first = True),
                loop_node
                ])

        return ExprNodes.InlinedGeneratorExpressionNode(
            gen_expr_node.pos, loop = exec_code, result_node = result_ref,
1395 1396
            expr_scope = gen_expr_node.expr_scope, orig_func = 'sum',
            has_local_scope = gen_expr_node.has_local_scope)
1397

1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410
    def _handle_simple_function_min(self, node, pos_args):
        return self._optimise_min_max(node, pos_args, '<')

    def _handle_simple_function_max(self, node, pos_args):
        return self._optimise_min_max(node, pos_args, '>')

    def _optimise_min_max(self, node, args, operator):
        """Replace min(a,b,...) and max(a,b,...) by explicit comparison code.
        """
        if len(args) <= 1:
            # leave this to Python
            return node

1411
        cascaded_nodes = list(map(UtilNodes.ResultRefNode, args[1:]))
1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433

        last_result = args[0]
        for arg_node in cascaded_nodes:
            result_ref = UtilNodes.ResultRefNode(last_result)
            last_result = ExprNodes.CondExprNode(
                arg_node.pos,
                true_val = arg_node,
                false_val = result_ref,
                test = ExprNodes.PrimaryCmpNode(
                    arg_node.pos,
                    operand1 = arg_node,
                    operator = operator,
                    operand2 = result_ref,
                    )
                )
            last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result)

        for ref_node in cascaded_nodes[::-1]:
            last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result)

        return last_result

1434
    def _DISABLED_handle_simple_function_tuple(self, node, pos_args):
        """Replace tuple() by an empty tuple literal and tuple(genexpr)
        by a list comprehension that is converted into a tuple.
        """
        if not pos_args:
            return ExprNodes.TupleNode(node.pos, args=[], constant_result=())
        # This is a bit special - for iterables (including genexps),
        # Python actually overallocates and resizes a newly created
        # tuple incrementally while reading items, which we can't
        # easily do without explicit node support. Instead, we read
        # the items into a list and then copy them into a tuple of the
        # final size.  This takes up to twice as much memory, but will
        # have to do until we have real support for genexps.
        list_result = self._transform_list_set_genexpr(
            node, pos_args, ExprNodes.ListNode)
        if list_result is node:
            # nothing was transformed
            return node
        return ExprNodes.AsTupleNode(node.pos, arg=list_result)

1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468
    def _handle_simple_function_list(self, node, pos_args):
        """Replace list() by an empty list literal; otherwise try to turn
        list(genexpr) into a list comprehension.
        """
        if not pos_args:
            return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
        list_node_class = ExprNodes.ListNode
        return self._transform_list_set_genexpr(node, pos_args, list_node_class)

    def _handle_simple_function_set(self, node, pos_args):
        """Replace set() by an empty set literal; otherwise try to turn
        set(genexpr) into a set comprehension.
        """
        if not pos_args:
            return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
        set_node_class = ExprNodes.SetNode
        return self._transform_list_set_genexpr(node, pos_args, set_node_class)

    def _transform_list_set_genexpr(self, node, pos_args, container_node_class):
        """Replace set(genexpr) and list(genexpr) by a literal comprehension.
        """
        if len(pos_args) > 1:
            return node
        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop

1469 1470
        yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
        if yield_expression is None:
1471 1472 1473 1474
            return node

        target_node = container_node_class(node.pos, args=[])
        append_node = ExprNodes.ComprehensionAppendNode(
1475
            yield_expression.pos,
1476
            expr = yield_expression,
Stefan Behnel's avatar
Stefan Behnel committed
1477
            target = ExprNodes.CloneNode(target_node))
1478

1479
        Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502

        setcomp = ExprNodes.ComprehensionNode(
            node.pos,
            has_local_scope = True,
            expr_scope = gen_expr_node.expr_scope,
            loop = loop_node,
            append = append_node,
            target = target_node)
        append_node.target = setcomp
        return setcomp

    def _handle_simple_function_dict(self, node, pos_args):
        """Replace dict( (a,b) for ... ) by a literal { a:b for ... }.
        """
        if len(pos_args) == 0:
            return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={})
        if len(pos_args) > 1:
            return node
        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop

1503 1504
        yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
        if yield_expression is None:
1505 1506 1507 1508 1509 1510 1511 1512 1513
            return node

        if not isinstance(yield_expression, ExprNodes.TupleNode):
            return node
        if len(yield_expression.args) != 2:
            return node

        target_node = ExprNodes.DictNode(node.pos, key_value_pairs=[])
        append_node = ExprNodes.DictComprehensionAppendNode(
1514
            yield_expression.pos,
1515 1516
            key_expr = yield_expression.args[0],
            value_expr = yield_expression.args[1],
Stefan Behnel's avatar
Stefan Behnel committed
1517
            target = ExprNodes.CloneNode(target_node))
1518

1519
        Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530

        dictcomp = ExprNodes.ComprehensionNode(
            node.pos,
            has_local_scope = True,
            expr_scope = gen_expr_node.expr_scope,
            loop = loop_node,
            append = append_node,
            target = target_node)
        append_node.target = dictcomp
        return dictcomp

1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546
    # specific handlers for general call nodes

    def _handle_general_function_dict(self, node, pos_args, kwargs):
        """Replace dict(a=b,c=d,...) by the underlying keyword dict
        construction which is done anyway.
        """
        if len(pos_args) > 0:
            return node
        if not isinstance(kwargs, ExprNodes.DictNode):
            return node
        if node.starstar_arg:
            # we could optimize this by updating the kw dict instead
            return node
        return kwargs


1547
class OptimizeBuiltinCalls(Visitor.EnvTransform):
Stefan Behnel's avatar
Stefan Behnel committed
1548
    """Optimize some common methods calls and instantiation patterns
1549 1550 1551 1552 1553
    for builtin types *after* the type analysis phase.

    Running after type analysis, this transform can only perform
    function replacements that do not alter the function return type
    in a way that was not anticipated by the type analysis.
1554
    """
1555 1556
    # only intercept on call nodes
    visit_Node = Visitor.VisitorTransform.recurse_to_children
1557

1558
    def visit_GeneralCallNode(self, node):
        """Visit the children of a general call node, then hand the call
        over to _dispatch_to_handler().

        Only calls to Python objects with a literal positional argument
        tuple and without a '**' argument are dispatched; everything
        else is returned unchanged.
        """
        self.visitchildren(node)
        function = node.function
        if not function.type.is_pyobject:
            return node
        positional = node.positional_args
        if not isinstance(positional, ExprNodes.TupleNode) or node.starstar_arg:
            return node
        return self._dispatch_to_handler(
            node, function, positional.args, node.keyword_args)
1571 1572 1573

    def visit_SimpleCallNode(self, node):
        """Visit the children of a simple call node, then hand the call
        over to _dispatch_to_handler().

        For calls to Python objects, the arguments are taken from the
        argument tuple (the node is returned unchanged if that is not a
        TupleNode); otherwise the node's own argument list is used.
        """
        self.visitchildren(node)
        function = node.function
        if not function.type.is_pyobject:
            return self._dispatch_to_handler(node, function, node.args)
        arg_tuple = node.arg_tuple
        if not isinstance(arg_tuple, ExprNodes.TupleNode):
            return node
        return self._dispatch_to_handler(node, function, arg_tuple.args)
1584

1585 1586
    ### cleanup to avoid redundant coercions to/from Python types

1587 1588 1589
    def _visit_PyTypeTestNode(self, node):
        # disabled - appears to break assignments in some cases, and
        # also drops a None check, which might still be required
1590 1591 1592 1593 1594 1595 1596 1597
        """Flatten redundant type checks after tree changes.
        """
        old_arg = node.arg
        self.visitchildren(node)
        if old_arg is node.arg or node.arg.type != node.type:
            return node
        return node.arg

1598 1599 1600 1601 1602 1603 1604 1605 1606
    def visit_TypecastNode(self, node):
        """
        Drop type casts whose operand already has the target type.
        """
        self.visitchildren(node)
        if node.type != node.operand.type:
            return node
        return node.operand

1607 1608 1609 1610 1611 1612 1613 1614 1615
    def visit_ExprStatNode(self, node):
        """
        Strip a coercion to a Python object from an expression
        statement by replacing it with its argument.
        """
        self.visitchildren(node)
        expression = node.expr
        if isinstance(expression, ExprNodes.CoerceToPyTypeNode):
            node.expr = expression.arg
        return node

1616 1617 1618 1619 1620
    def visit_CoerceToBooleanNode(self, node):
        """Drop conversion nodes that became redundant after tree changes.

        If the argument (ignoring a type test) is a C value that was
        coerced to a generic Python object or to a Python bool, the
        original C value is coerced to a boolean directly.
        """
        self.visitchildren(node)
        inner = node.arg
        if isinstance(inner, ExprNodes.PyTypeTestNode):
            inner = inner.arg
        if not isinstance(inner, ExprNodes.CoerceToPyTypeNode):
            return node
        if inner.type in (PyrexTypes.py_object_type, Builtin.bool_type):
            return inner.arg.coerce_to_boolean(self.current_env())
        return node

1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641
    def visit_CoerceFromPyTypeNode(self, node):
        """Drop conversion nodes that became redundant after tree changes.

        Also, optimise away calls to Python's builtin int() and
        float() if the result is going to be coerced back into a C
        type anyway, and hand integer indexing of non-buffer objects
        over to _optimise_int_indexing().
        """
        self.visitchildren(node)
        arg = node.arg
        if not arg.type.is_pyobject:
            # no Python conversion left at all, just do a C coercion instead
            if node.type == arg.type:
                return arg
            return arg.coerce_to(node.type, self.current_env())

        if isinstance(arg, ExprNodes.PyTypeTestNode):
            arg = arg.arg

        if isinstance(arg, ExprNodes.CoerceToPyTypeNode) \
               and arg.type is PyrexTypes.py_object_type \
               and node.type.assignable_from(arg.arg.type):
            # completely redundant C->Py->C coercion
            return arg.arg.coerce_to(node.type, self.current_env())

        if isinstance(arg, ExprNodes.SimpleCallNode):
            if node.type.is_int or node.type.is_float:
                return self._optimise_numeric_cast_call(node, arg)
            return node

        if isinstance(arg, ExprNodes.IndexNode) and not arg.is_buffer_access:
            index_node = arg.index
            if isinstance(index_node, ExprNodes.CoerceToPyTypeNode):
                # look at the C index value, not its Python wrapper
                index_node = index_node.arg
            if index_node.type.is_int:
                return self._optimise_int_indexing(node, arg, index_node)
        return node

1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672
    # C function type used by _optimise_int_indexing() for the call to
    # "__Pyx_PyBytes_GetItemInt": takes a bytes object, a Py_ssize_t
    # index and an int 'check_bounds' flag, and returns a C char.
    # The declared exception value is "((char)-1)" with exception_check
    # enabled.
    PyBytes_GetItemInt_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_char_type, [
            PyrexTypes.CFuncTypeArg("bytes", Builtin.bytes_type, None),
            PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("check_bounds", PyrexTypes.c_int_type, None),
            ],
        exception_value = "((char)-1)",
        exception_check = True)

    def _optimise_int_indexing(self, coerce_node, arg, index_node):
        """Optimise integer indexing whose result is coerced to a C type.

        Currently handles 'bytes[index]' coerced to a C char or
        unsigned char, which becomes a call to __Pyx_PyBytes_GetItemInt
        that honours the 'boundscheck' directive.  In all other cases,
        'coerce_node' is returned unchanged.
        """
        env = self.current_env()
        if env.directives['boundscheck']:
            bound_check_bool = 1
        else:
            bound_check_bool = 0

        if arg.base.type is not Builtin.bytes_type:
            return coerce_node
        target_type = coerce_node.type
        if target_type not in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type):
            return coerce_node

        # bytes[index] -> char
        bound_check_node = ExprNodes.IntNode(
            coerce_node.pos, value=str(bound_check_bool),
            constant_result=bound_check_bool)
        call_args = [
            arg.base.as_none_safe_node("'NoneType' object is not subscriptable"),
            index_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env),
            bound_check_node,
            ]
        node = ExprNodes.PythonCapiCallNode(
            coerce_node.pos, "__Pyx_PyBytes_GetItemInt",
            self.PyBytes_GetItemInt_func_type,
            args = call_args,
            is_temp = True,
            utility_code=bytes_index_utility_code)
        if coerce_node.type is not PyrexTypes.c_char_type:
            # the C function returns a plain char
            node = node.coerce_to(coerce_node.type, env)
        return node

Stefan Behnel's avatar
Stefan Behnel committed
1694
    def _optimise_numeric_cast_call(self, node, arg):
        """Replace int(x) / float(x) on a C value by a C type cast (or
        by the value itself) when the result is coerced back to a C type.

        'node' is the coercion node and 'arg' the call node it wraps.
        'node' is returned unchanged whenever the call is not a
        one-argument call of a builtin type by name, or when the
        argument is a plain Python object.
        """
        function = arg.function
        if not (isinstance(function, ExprNodes.NameNode)
                and function.type.is_builtin_type
                and isinstance(arg.arg_tuple, ExprNodes.TupleNode)):
            return node
        call_args = arg.arg_tuple.args
        if len(call_args) != 1:
            return node
        func_arg = call_args[0]
        if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
            func_arg = func_arg.arg
        elif func_arg.type.is_pyobject:
            # play safe: Python conversion might work on all sorts of things
            return node

        # int() and float() are handled alike, they only differ in the
        # type property that makes the cast applicable
        if function.name == 'int':
            applicable = func_arg.type.is_int or node.type.is_int
        elif function.name == 'float':
            applicable = func_arg.type.is_float or node.type.is_float
        else:
            applicable = False

        if applicable:
            if func_arg.type == node.type:
                return func_arg
            if node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
                return ExprNodes.TypecastNode(
                    node.pos, operand=func_arg, type=node.type)
        return node

    ### dispatch to specific optimisers

1727 1728 1729 1730 1731 1732 1733
    def _find_handler(self, match_name, has_kwargs):
        call_type = has_kwargs and 'general' or 'simple'
        handler = getattr(self, '_handle_%s_%s' % (call_type, match_name), None)
        if handler is None:
            handler = getattr(self, '_handle_any_%s' % match_name, None)
        return handler

1734
    def _dispatch_to_handler(self, node, function, arg_list, kwargs=None):
1735
        if function.is_name:
1736 1737 1738
            # we only consider functions that are either builtin
            # Python functions or builtins that were already replaced
            # into a C function call (defined in the builtin scope)
1739 1740
            if not function.entry:
                return node
1741 1742
            is_builtin = function.entry.is_builtin or \
                         function.entry is self.current_env().builtin_scope().lookup_here(function.name)
1743 1744
            if not is_builtin:
                return node
1745 1746 1747 1748 1749
            function_handler = self._find_handler(
                "function_%s" % function.name, kwargs)
            if function_handler is None:
                return node
            if kwargs:
1750
                return function_handler(node, arg_list, kwargs)
1751
            else:
1752 1753
                return function_handler(node, arg_list)
        elif function.is_attribute and function.type.is_pyobject:
Stefan Behnel's avatar
Stefan Behnel committed
1754
            attr_name = function.attribute
1755 1756
            self_arg = function.obj
            obj_type = self_arg.type
1757
            is_unbound_method = False
1758 1759 1760 1761 1762 1763 1764
            if obj_type.is_builtin_type:
                if obj_type is Builtin.type_type and arg_list and \
                         arg_list[0].type.is_pyobject:
                    # calling an unbound method like 'list.append(L,x)'
                    # (ignoring 'type.mro()' here ...)
                    type_name = function.obj.name
                    self_arg = None
1765
                    is_unbound_method = True
1766 1767 1768
                else:
                    type_name = obj_type.name
            else:
1769
                type_name = "object" # safety measure
1770
            method_handler = self._find_handler(
Stefan Behnel's avatar
Stefan Behnel committed
1771
                "method_%s_%s" % (type_name, attr_name), kwargs)
1772
            if method_handler is None:
Stefan Behnel's avatar
Stefan Behnel committed
1773 1774 1775 1776
                if attr_name in TypeSlots.method_name_to_slot \
                       or attr_name == '__new__':
                    method_handler = self._find_handler(
                        "slot%s" % attr_name, kwargs)
1777 1778
                if method_handler is None:
                    return node
1779 1780 1781
            if self_arg is not None:
                arg_list = [self_arg] + list(arg_list)
            if kwargs:
1782
                return method_handler(node, arg_list, kwargs, is_unbound_method)
1783
            else:
1784
                return method_handler(node, arg_list, is_unbound_method)
1785
        else:
1786
            return node
1787

1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803
    def _error_wrong_arg_count(self, function_name, node, args, expected=None):
        """Report a compile error for a call with a wrong argument count.

        'expected' may be None, a number or a descriptive string such
        as '0 or 1'.  It determines the argument placeholder shown in
        the message ('', 'x' or '...') and, unless it is None, the
        'expected ..., ' part.
        """
        arg_str = ''
        if expected: # neither None nor 0
            if isinstance(expected, basestring) or expected > 1:
                arg_str = '...'
            elif expected == 1:
                arg_str = 'x'

        expected_str = ''
        if expected is not None:
            expected_str = 'expected %s, ' % expected

        message = "%s(%s) called with wrong number of args, %sfound %d" % (
            function_name, arg_str, expected_str, len(args))
        error(node.pos, message)

1804 1805
    ### builtin types

1806 1807 1808 1809 1810 1811
    # C function type used for the "PyDict_Copy" call that replaces
    # dict(some_dict): takes a dict and returns a dict.
    PyDict_Copy_func_type = PyrexTypes.CFuncType(
        Builtin.dict_type, [
            PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
            ])

    def _handle_simple_function_dict(self, node, pos_args):
1812
        """Replace dict(some_dict) by PyDict_Copy(some_dict).
1813
        """
1814
        if len(pos_args) != 1:
1815
            return node
1816
        arg = pos_args[0]
1817
        if arg.type is Builtin.dict_type:
1818
            arg = arg.as_none_safe_node("'NoneType' is not iterable")
1819 1820
            return ExprNodes.PythonCapiCallNode(
                node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
1821
                args = [arg],
1822 1823 1824
                is_temp = node.is_temp
                )
        return node
1825

1826 1827 1828 1829 1830
    # C function type used for the "PyList_AsTuple" call that replaces
    # tuple(some_list): takes a list and returns a tuple.
    PyList_AsTuple_func_type = PyrexTypes.CFuncType(
        Builtin.tuple_type, [
            PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
            ])

1831
    def _handle_simple_function_tuple(self, node, pos_args):
1832 1833
        """Replace tuple([...]) by a call to PyList_AsTuple.
        """
1834
        if len(pos_args) != 1:
1835
            return node
1836
        list_arg = pos_args[0]
1837 1838 1839 1840
        if list_arg.type is not Builtin.list_type:
            return node
        if not isinstance(list_arg, (ExprNodes.ComprehensionNode,
                                     ExprNodes.ListNode)):
1841
            pos_args[0] = list_arg.as_none_safe_node(
1842
                "'NoneType' object is not iterable")
1843

1844 1845
        return ExprNodes.PythonCapiCallNode(
            node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
1846
            args = pos_args,
1847 1848 1849
            is_temp = node.is_temp
            )

1850 1851 1852 1853 1854 1855 1856 1857
    # C function type used by the float() handler for the call to
    # "__Pyx_PyObject_AsDouble": takes a Python object and returns a
    # C double.  The declared exception value is "((double)-1)" with
    # exception_check enabled.
    PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_double_type, [
            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
            ],
        exception_value = "((double)-1)",
        exception_check = True)

    def _handle_simple_function_float(self, node, pos_args):
Stefan Behnel's avatar
Stefan Behnel committed
1858 1859 1860
        """Transform float() into either a C type cast or a faster C
        function call.
        """
1861 1862
        # Note: this requires the float() function to be typed as
        # returning a C 'double'
1863
        if len(pos_args) == 0:
Stefan Behnel's avatar
typo  
Stefan Behnel committed
1864
            return ExprNodes.FloatNode(
1865 1866 1867 1868
                node, value="0.0", constant_result=0.0
                ).coerce_to(Builtin.float_type, self.current_env())
        elif len(pos_args) != 1:
            self._error_wrong_arg_count('float', node, pos_args, '0 or 1')
1869 1870 1871 1872 1873 1874 1875
            return node
        func_arg = pos_args[0]
        if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
            func_arg = func_arg.arg
        if func_arg.type is PyrexTypes.c_double_type:
            return func_arg
        elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
1876 1877
            return ExprNodes.TypecastNode(
                node.pos, operand=func_arg, type=node.type)
1878 1879 1880 1881 1882 1883 1884 1885
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyObject_AsDouble",
            self.PyObject_AsDouble_func_type,
            args = pos_args,
            is_temp = node.is_temp,
            utility_code = pyobject_as_double_utility_code,
            py_name = "float")

1886 1887 1888
    def _handle_simple_function_bool(self, node, pos_args):
        """Transform bool(x) into a type coercion to a boolean.
        """
1889 1890 1891 1892 1893 1894
        if len(pos_args) == 0:
            return ExprNodes.BoolNode(
                node.pos, value=False, constant_result=False
                ).coerce_to(Builtin.bool_type, self.current_env())
        elif len(pos_args) != 1:
            self._error_wrong_arg_count('bool', node, pos_args, '0 or 1')
1895
            return node
Craig Citro's avatar
Craig Citro committed
1896
        else:
1897 1898 1899 1900 1901 1902
            # => !!<bint>(x)  to make sure it's exactly 0 or 1
            operand = pos_args[0].coerce_to_boolean(self.current_env())
            operand = ExprNodes.NotNode(node.pos, operand = operand)
            operand = ExprNodes.NotNode(node.pos, operand = operand)
            # coerce back to Python object as that's the result we are expecting
            return operand.coerce_to_pyobject(self.current_env())
1903

1904 1905
    ### builtin functions

1906 1907 1908 1909 1910
    # C function type used by the len() handler for the "strlen" call:
    # takes a char* and returns a size_t.
    Pyx_strlen_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_size_t_type, [
            PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)
            ])

1911 1912 1913 1914 1915 1916 1917
    # C function type used by the len() handler for the size functions
    # named in _map_to_capi_len_function: takes a Python object and
    # returns a Py_ssize_t.
    PyObject_Size_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)
            ])

    # Lookup function (a bound dict.get) that maps a builtin type to
    # the name of the C-API function or macro returning its length.
    # Returns None for types without an entry.
    _map_to_capi_len_function = dict([
        (Builtin.unicode_type,   "PyUnicode_GET_SIZE"),
        (Builtin.bytes_type,     "PyBytes_GET_SIZE"),
        (Builtin.list_type,      "PyList_GET_SIZE"),
        (Builtin.tuple_type,     "PyTuple_GET_SIZE"),
        (Builtin.dict_type,      "PyDict_Size"),
        (Builtin.set_type,       "PySet_Size"),
        (Builtin.frozenset_type, "PySet_Size"),
        ]).get

1926
    def _handle_simple_function_len(self, node, pos_args):
Stefan Behnel's avatar
Stefan Behnel committed
1927 1928
        """Replace len(char*) by the equivalent call to strlen() and
        len(known_builtin_type) by an equivalent C-API call.
Stefan Behnel's avatar
Stefan Behnel committed
1929
        """
1930 1931 1932 1933 1934 1935
        if len(pos_args) != 1:
            self._error_wrong_arg_count('len', node, pos_args, 1)
            return node
        arg = pos_args[0]
        if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
            arg = arg.arg
1936 1937 1938 1939 1940
        if arg.type.is_string:
            new_node = ExprNodes.PythonCapiCallNode(
                node.pos, "strlen", self.Pyx_strlen_func_type,
                args = [arg],
                is_temp = node.is_temp,
1941
                utility_code = Builtin.include_string_h_utility_code)
1942 1943 1944 1945
        elif arg.type.is_pyobject:
            cfunc_name = self._map_to_capi_len_function(arg.type)
            if cfunc_name is None:
                return node
1946 1947
            arg = arg.as_none_safe_node(
                "object of type 'NoneType' has no len()")
1948 1949 1950 1951
            new_node = ExprNodes.PythonCapiCallNode(
                node.pos, cfunc_name, self.PyObject_Size_func_type,
                args = [arg],
                is_temp = node.is_temp)
Stefan Behnel's avatar
Stefan Behnel committed
1952
        elif arg.type.is_unicode_char:
1953 1954
            return ExprNodes.IntNode(node.pos, value='1', constant_result=1,
                                     type=node.type)
1955
        else:
Stefan Behnel's avatar
Stefan Behnel committed
1956
            return node
1957
        if node.type not in (PyrexTypes.c_size_t_type, PyrexTypes.c_py_ssize_t_type):
1958
            new_node = new_node.coerce_to(node.type, self.current_env())
1959
        return new_node
1960

1961 1962 1963 1964 1965 1966
    # C function type used for the "Py_TYPE" call that replaces
    # type(o): takes a Python object and returns a type object.
    Pyx_Type_func_type = PyrexTypes.CFuncType(
        Builtin.type_type, [
            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
            ])

    def _handle_simple_function_type(self, node, pos_args):
Stefan Behnel's avatar
Stefan Behnel committed
1967 1968
        """Replace type(o) by a macro call to Py_TYPE(o).
        """
1969
        if len(pos_args) != 1:
1970 1971
            return node
        node = ExprNodes.PythonCapiCallNode(
1972 1973 1974 1975
            node.pos, "Py_TYPE", self.Pyx_Type_func_type,
            args = pos_args,
            is_temp = False)
        return ExprNodes.CastNode(node, PyrexTypes.py_object_type)
1976

1977 1978 1979 1980 1981 1982 1983 1984 1985 1986 1987 1988 1989 1990 1991 1992 1993 1994 1995 1996 1997 1998 1999 2000 2001
    # C function type used by the isinstance() handler for its type
    # check calls: declared as taking one Python object and returning
    # a C bint.
    Py_type_check_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("arg", PyrexTypes.py_object_type, None)
            ])

    def _handle_simple_function_isinstance(self, node, pos_args):
        """Replace isinstance() checks against builtin types by the
        corresponding C-API call.
        """
        if len(pos_args) != 2:
            return node
        arg, types = pos_args
        temp = None
        if isinstance(types, ExprNodes.TupleNode):
            types = types.args
            arg = temp = UtilNodes.ResultRefNode(arg)
        elif types.type is Builtin.type_type:
            types = [types]
        else:
            return node

        tests = []
        test_nodes = []
        env = self.current_env()
        for test_type_node in types:
Robert Bradshaw's avatar
Robert Bradshaw committed
2002 2003 2004 2005 2006 2007 2008 2009
            builtin_type = None
            if isinstance(test_type_node, ExprNodes.NameNode):
                if test_type_node.entry:
                    entry = env.lookup(test_type_node.entry.name)
                    if entry and entry.type and entry.type.is_builtin_type:
                        builtin_type = entry.type
            if builtin_type and builtin_type is not Builtin.type_type:
                type_check_function = entry.type.type_check_function(exact=False)
2010 2011 2012
                if type_check_function in tests:
                    continue
                tests.append(type_check_function)
Robert Bradshaw's avatar
Robert Bradshaw committed
2013 2014 2015 2016 2017
                type_check_args = [arg]
            elif test_type_node.type is Builtin.type_type:
                type_check_function = '__Pyx_TypeCheck'
                type_check_args = [arg, test_type_node]
            else:
2018
                return node
2019 2020 2021 2022 2023 2024
            test_nodes.append(
                ExprNodes.PythonCapiCallNode(
                    test_type_node.pos, type_check_function, self.Py_type_check_func_type,
                    args = type_check_args,
                    is_temp = True,
                    ))
2025 2026 2027 2028 2029 2030 2031 2032 2033 2034 2035 2036

        def join_with_or(a,b, make_binop_node=ExprNodes.binop_node):
            or_node = make_binop_node(node.pos, 'or', a, b)
            or_node.type = PyrexTypes.c_bint_type
            or_node.is_temp = True
            return or_node

        test_node = reduce(join_with_or, test_nodes).coerce_to(node.type, env)
        if temp is not None:
            test_node = UtilNodes.EvalWithTempExprNode(temp, test_node)
        return test_node

2037 2038 2039 2040 2041 2042 2043
    def _handle_simple_function_ord(self, node, pos_args):
        """Unpack ord(Py_UNICODE).
        """
        if len(pos_args) != 1:
            return node
        arg = pos_args[0]
        if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
Stefan Behnel's avatar
Stefan Behnel committed
2044
            if arg.arg.type.is_unicode_char:
2045 2046 2047
                return arg.arg.coerce_to(node.type, self.current_env())
        return node

2048 2049
    ### special methods

2050 2051 2052 2053 2054
    # C function type used for the "__Pyx_tp_new" call that replaces
    # 'exttype.__new__(exttype)': takes a type object and returns a
    # Python object.
    Pyx_tp_new_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("type", Builtin.type_type, None)
            ])

Stefan Behnel's avatar
Stefan Behnel committed
2055
    def _handle_simple_slot__new__(self, node, args, is_unbound_method):
        """Replace 'exttype.__new__(exttype)' by a call to exttype->tp_new()

        The call is only rewritten when it is an unbound method call with
        exactly one argument and both the object and the argument are
        plain names of type 'type' that refer to the same type.  In every
        other case the original node is returned.
        """
        obj = node.function.obj
        if not is_unbound_method or len(args) != 1:
            return node
        type_arg = args[0]
        if not (obj.is_name and type_arg.is_name):
            # play safe
            return node
        type_type = Builtin.type_type
        if obj.type != type_type or type_arg.type != type_type:
            # not a known type, play safe
            return node

        if type_arg.type_entry and obj.type_entry:
            if type_arg.type_entry != obj.type_entry:
                # different types - may or may not lead to an error at runtime
                return node
        elif obj.name != type_arg.name:
            return node
        # otherwise, we know it's a type and we know it's the same
        # type for both - that should do

        # FIXME: we could potentially look up the actual tp_new C
        # method of the extension type and call that instead of the
        # generic slot. That would also allow us to pass parameters
        # efficiently.

        if not type_arg.type_entry:
            # arbitrary variable, needs a None check for safety
            none_message = "object.__new__(X): X is not a type object (NoneType)"
            type_arg = type_arg.as_none_safe_node(none_message)

        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
            args = [type_arg],
            utility_code = tpnew_utility_code,
            is_temp = node.is_temp
            )

2094 2095 2096 2097 2098 2099 2100 2101
    ### methods of builtin types

    # C signature used for the "__Pyx_PyObject_Append" helper call:
    # (object, item) -> Python object.
    PyObject_Append_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
            ])

2102
    def _handle_simple_method_object_append(self, node, args, is_unbound_method):
        """Optimistic optimisation as X.append() is almost always
        referring to a list.

        Only calls with exactly two arguments (the object and the item)
        are redirected to the __Pyx_PyObject_Append() helper; anything
        else is returned unchanged.
        """
        if len(args) == 2:
            return ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
                args = args,
                may_return_none = True,
                is_temp = node.is_temp,
                utility_code = append_utility_code
                )
        return node

Robert Bradshaw's avatar
Robert Bradshaw committed
2117 2118 2119 2120 2121 2122 2123 2124 2125 2126 2127 2128
    # C signature used for the "__Pyx_PyObject_Pop" helper call (no index).
    PyObject_Pop_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            ])

    # C signature used for the "__Pyx_PyObject_PopIndex" helper call; the
    # index is declared as a C long.
    PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_long_type, None),
            ])

    def _handle_simple_method_object_pop(self, node, args, is_unbound_method):
        """Optimistic optimisation as X.pop([n]) is almost always
        referring to a list.

        X.pop() maps to __Pyx_PyObject_Pop().  X.pop(n) maps to
        __Pyx_PyObject_PopIndex(), but only if n is a C integer that was
        coerced to a Python object and whose type is not wider than
        Py_ssize_t; the coercion is then stripped (in place in 'args').
        All other calls are returned unchanged.
        """
        arg_count = len(args)
        if arg_count == 1:
            return ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_PyObject_Pop", self.PyObject_Pop_func_type,
                args = args,
                may_return_none = True,
                is_temp = node.is_temp,
                utility_code = pop_utility_code
                )
        if arg_count != 2:
            return node

        index = args[1]
        if not isinstance(index, ExprNodes.CoerceToPyTypeNode):
            return node
        if not index.arg.type.is_int:
            return node

        ssize_t_type = PyrexTypes.c_py_ssize_t_type
        widest_type = PyrexTypes.widest_numeric_type(index.arg.type, ssize_t_type)
        if not (widest_type == ssize_t_type):
            return node

        # pass the plain C integer instead of its Python object wrapper
        args[1] = index.arg
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyObject_PopIndex", self.PyObject_PopIndex_func_type,
            args = args,
            may_return_none = True,
            is_temp = node.is_temp,
            utility_code = pop_index_utility_code
            )

2155 2156
    # list.pop() is handled by the same code as the generic object.pop()
    _handle_simple_method_list_pop = _handle_simple_method_object_pop

2157
    # C signature of a function that takes one Python object and returns a
    # C return code, with -1 declared as its exception value.
    single_param_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_returncode_type,
        [PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)],
        exception_value = "-1")
2162

2163
    def _handle_simple_method_list_sort(self, node, args, is_unbound_method):
Stefan Behnel's avatar
Stefan Behnel committed
2164 2165
        """Call PyList_Sort() instead of the 0-argument l.sort().
        """
2166
        if len(args) != 1:
2167
            return node
2168
        return self._substitute_method_call(
2169
            node, "PyList_Sort", self.single_param_func_type,
2170
            'sort', is_unbound_method, args).coerce_to(node.type, self.current_env)
2171

2172 2173 2174 2175 2176
    # C signature used for the "__Pyx_PyDict_GetItemDefault" helper call:
    # (dict, key, default) -> Python object.
    Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
        ])
2178 2179

    def _handle_simple_method_dict_get(self, node, args, is_unbound_method):
        """Replace dict.get() by a call to __Pyx_PyDict_GetItemDefault().

        A missing default argument is filled in with None (appended to
        'args' in place).  A wrong argument count is reported and the
        original node returned.
        """
        if len(args) not in (2, 3):
            self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
            return node
        if len(args) == 2:
            # no explicit default given
            args.append(ExprNodes.NoneNode(node.pos))

        return self._substitute_method_call(
            node, "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
            'get', is_unbound_method, args,
            may_return_none = True,
            utility_code = dict_getitem_default_utility_code)

2194 2195 2196

    ### unicode type methods

2197 2198
    # C signature of the unicode character predicates (Py_UNICODE_IS...):
    # a single Py_UCS4 character in, a C boolean out.
    PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type,
        [PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None)])

    def _inject_unicode_predicate(self, node, args, is_unbound_method):
        if is_unbound_method or len(args) != 1:
            return node
        ustring = args[0]
        if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
Stefan Behnel's avatar
Stefan Behnel committed
2207
               not ustring.arg.type.is_unicode_char:
2208 2209 2210 2211 2212 2213 2214 2215 2216 2217 2218 2219 2220 2221 2222 2223 2224 2225 2226 2227 2228 2229 2230 2231 2232 2233 2234 2235 2236
            return node
        uchar = ustring.arg
        method_name = node.function.attribute
        if method_name == 'istitle':
            # istitle() doesn't directly map to Py_UNICODE_ISTITLE()
            utility_code = py_unicode_istitle_utility_code
            function_name = '__Pyx_Py_UNICODE_ISTITLE'
        else:
            utility_code = None
            function_name = 'Py_UNICODE_%s' % method_name.upper()
        func_call = self._substitute_method_call(
            node, function_name, self.PyUnicode_uchar_predicate_func_type,
            method_name, is_unbound_method, [uchar],
            utility_code = utility_code)
        if node.type.is_pyobject:
            func_call = func_call.coerce_to_pyobject(self.current_env)
        return func_call

    # All of these unicode predicate methods share one implementation,
    # which derives the C function name from the method name.
    _handle_simple_method_unicode_isalnum   = _inject_unicode_predicate
    _handle_simple_method_unicode_isalpha   = _inject_unicode_predicate
    _handle_simple_method_unicode_isdecimal = _inject_unicode_predicate
    _handle_simple_method_unicode_isdigit   = _inject_unicode_predicate
    _handle_simple_method_unicode_islower   = _inject_unicode_predicate
    _handle_simple_method_unicode_isnumeric = _inject_unicode_predicate
    _handle_simple_method_unicode_isspace   = _inject_unicode_predicate
    _handle_simple_method_unicode_istitle   = _inject_unicode_predicate
    _handle_simple_method_unicode_isupper   = _inject_unicode_predicate

    # C signature of the unicode character conversions (Py_UNICODE_TO...):
    # a single Py_UCS4 character in, a Py_UCS4 character out.
    PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ucs4_type,
        [PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None)])

    def _inject_unicode_character_conversion(self, node, args, is_unbound_method):
        if is_unbound_method or len(args) != 1:
            return node
        ustring = args[0]
        if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
Stefan Behnel's avatar
Stefan Behnel committed
2246
               not ustring.arg.type.is_unicode_char:
2247 2248 2249 2250 2251 2252 2253 2254 2255 2256 2257 2258 2259 2260 2261
            return node
        uchar = ustring.arg
        method_name = node.function.attribute
        function_name = 'Py_UNICODE_TO%s' % method_name.upper()
        func_call = self._substitute_method_call(
            node, function_name, self.PyUnicode_uchar_conversion_func_type,
            method_name, is_unbound_method, [uchar])
        if node.type.is_pyobject:
            func_call = func_call.coerce_to_pyobject(self.current_env)
        return func_call

    # All of these unicode conversion methods share one implementation,
    # which derives the C function name from the method name.
    _handle_simple_method_unicode_lower = _inject_unicode_character_conversion
    _handle_simple_method_unicode_upper = _inject_unicode_character_conversion
    _handle_simple_method_unicode_title = _inject_unicode_character_conversion

2262 2263 2264 2265 2266 2267 2268 2269 2270 2271 2272 2273 2274
    # C signature used for the "PyUnicode_Splitlines" call:
    # (unicode, bint keepends) -> list.
    PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType(
        Builtin.list_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None),
            ])

    def _handle_simple_method_unicode_splitlines(self, node, args, is_unbound_method):
        """Replace unicode.splitlines(...) by a direct call to the
        corresponding C-API function.

        The optional 'keepends' argument defaults to False and is
        otherwise coerced to a C boolean.
        """
        if not 1 <= len(args) <= 2:
            self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2")
            return node

        self._inject_bint_default_argument(node, args, 1, False)

        c_func_type = self.PyUnicode_Splitlines_func_type
        return self._substitute_method_call(
            node, "PyUnicode_Splitlines", c_func_type,
            'splitlines', is_unbound_method, args)

    # C signature used for the "PyUnicode_Split" call:
    # (unicode, object sep, Py_ssize_t maxsplit) -> list.
    PyUnicode_Split_func_type = PyrexTypes.CFuncType(
        Builtin.list_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None),
            ]
        )

    def _handle_simple_method_unicode_split(self, node, args, is_unbound_method):
        """Replace unicode.split(...) by a direct call to the
        corresponding C-API function.

        A missing separator is passed as NULL, a missing 'maxsplit'
        as -1; both defaults are added to 'args' in place.
        """
        if not 1 <= len(args) <= 3:
            self._error_wrong_arg_count('unicode.split', node, args, "1-3")
            return node
        if len(args) == 1:
            # no separator given
            args.append(ExprNodes.NullNode(node.pos))
        maxsplit_type = PyrexTypes.c_py_ssize_t_type
        self._inject_int_default_argument(node, args, 2, maxsplit_type, "-1")

        c_func_type = self.PyUnicode_Split_func_type
        return self._substitute_method_call(
            node, "PyUnicode_Split", c_func_type,
            'split', is_unbound_method, args)

2305 2306 2307 2308 2309 2310 2311 2312 2313 2314 2315 2316 2317 2318 2319 2320 2321 2322 2323 2324 2325 2326 2327 2328 2329 2330
    # C signature used for the "__Pyx_PyUnicode_Tailmatch" helper call:
    # (unicode, substring, start, end, direction) -> bint, with -1
    # declared as the exception value.
    PyUnicode_Tailmatch_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
            ],
        exception_value = '-1')

    def _handle_simple_method_unicode_endswith(self, node, args, is_unbound_method):
        return self._inject_unicode_tailmatch(
            node, args, is_unbound_method, 'endswith', +1)

    def _handle_simple_method_unicode_startswith(self, node, args, is_unbound_method):
        return self._inject_unicode_tailmatch(
            node, args, is_unbound_method, 'startswith', -1)

    def _inject_unicode_tailmatch(self, node, args, is_unbound_method,
                                  method_name, direction):
        """Replace unicode.startswith(...) and unicode.endswith(...)
        by a direct call to the corresponding C-API function.

        Missing 'start'/'end' arguments default to 0 and PY_SSIZE_T_MAX.
        The direction is appended to 'args' as a C int, and the result
        of the C call is coerced to a Python bool.
        """
        if not 2 <= len(args) <= 4:
            self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
            return node

        index_type = PyrexTypes.c_py_ssize_t_type
        for arg_index, default in ((2, "0"), (3, "PY_SSIZE_T_MAX")):
            self._inject_int_default_argument(
                node, args, arg_index, index_type, default)

        direction_node = ExprNodes.IntNode(
            node.pos, value=str(direction), type=PyrexTypes.c_int_type)
        args.append(direction_node)

        method_call = self._substitute_method_call(
            node, "__Pyx_PyUnicode_Tailmatch", self.PyUnicode_Tailmatch_func_type,
            method_name, is_unbound_method, args,
            utility_code = unicode_tailmatch_utility_code)
        return method_call.coerce_to(Builtin.bool_type, self.current_env())
2343

2344 2345 2346 2347 2348 2349 2350 2351 2352 2353 2354 2355 2356 2357 2358 2359 2360 2361 2362 2363 2364 2365 2366 2367 2368 2369
    # C signature used for the "PyUnicode_Find" call:
    # (unicode, substring, start, end, direction) -> Py_ssize_t, with -2
    # declared as the exception value.
    PyUnicode_Find_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
            ],
        exception_value = '-2')

    def _handle_simple_method_unicode_find(self, node, args, is_unbound_method):
        return self._inject_unicode_find(
            node, args, is_unbound_method, 'find', +1)

    def _handle_simple_method_unicode_rfind(self, node, args, is_unbound_method):
        return self._inject_unicode_find(
            node, args, is_unbound_method, 'rfind', -1)

    def _inject_unicode_find(self, node, args, is_unbound_method,
                             method_name, direction):
        """Replace unicode.find(...) and unicode.rfind(...) by a
        direct call to the corresponding C-API function.

        Missing 'start'/'end' arguments default to 0 and PY_SSIZE_T_MAX.
        The direction is appended to 'args' as a C int, and the C result
        is coerced to a Python object.
        """
        if not 2 <= len(args) <= 4:
            self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
            return node

        index_type = PyrexTypes.c_py_ssize_t_type
        self._inject_int_default_argument(node, args, 2, index_type, "0")
        self._inject_int_default_argument(node, args, 3, index_type, "PY_SSIZE_T_MAX")

        direction_node = ExprNodes.IntNode(
            node.pos, value=str(direction), type=PyrexTypes.c_int_type)
        args.append(direction_node)

        method_call = self._substitute_method_call(
            node, "PyUnicode_Find", self.PyUnicode_Find_func_type,
            method_name, is_unbound_method, args)
        return method_call.coerce_to_pyobject(self.current_env())
2381

Stefan Behnel's avatar
Stefan Behnel committed
2382 2383 2384 2385 2386 2387 2388 2389 2390 2391 2392 2393 2394 2395 2396 2397
    # C signature used for the "PyUnicode_Count" call:
    # (unicode, substring, start, end) -> Py_ssize_t, with -1 declared as
    # the exception value.
    PyUnicode_Count_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
            ],
        exception_value = '-1')

    def _handle_simple_method_unicode_count(self, node, args, is_unbound_method):
        """Replace unicode.count(...) by a direct call to the
        corresponding C-API function.

        Missing 'start'/'end' arguments default to 0 and PY_SSIZE_T_MAX;
        the C result is coerced to a Python object.
        """
        if not 2 <= len(args) <= 4:
            self._error_wrong_arg_count('unicode.count', node, args, "2-4")
            return node

        index_type = PyrexTypes.c_py_ssize_t_type
        for arg_index, default in ((2, "0"), (3, "PY_SSIZE_T_MAX")):
            self._inject_int_default_argument(
                node, args, arg_index, index_type, default)

        method_call = self._substitute_method_call(
            node, "PyUnicode_Count", self.PyUnicode_Count_func_type,
            'count', is_unbound_method, args)
        return method_call.coerce_to_pyobject(self.current_env())
Stefan Behnel's avatar
Stefan Behnel committed
2407

Stefan Behnel's avatar
Stefan Behnel committed
2408 2409 2410 2411 2412 2413 2414 2415 2416 2417 2418 2419 2420 2421 2422
    # C signature used for the "PyUnicode_Replace" call:
    # (unicode, substring, replstr, Py_ssize_t maxcount) -> unicode.
    PyUnicode_Replace_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None),
            ])

    def _handle_simple_method_unicode_replace(self, node, args, is_unbound_method):
        """Replace unicode.replace(...) by a direct call to the
        corresponding C-API function.

        A missing 'maxcount' argument defaults to -1.
        """
        if not 3 <= len(args) <= 4:
            self._error_wrong_arg_count('unicode.replace', node, args, "3-4")
            return node

        maxcount_type = PyrexTypes.c_py_ssize_t_type
        self._inject_int_default_argument(node, args, 3, maxcount_type, "-1")

        c_func_type = self.PyUnicode_Replace_func_type
        return self._substitute_method_call(
            node, "PyUnicode_Replace", c_func_type,
            'replace', is_unbound_method, args)

2430 2431
    # C signature used for the generic "PyUnicode_AsEncodedString" call:
    # (unicode, char* encoding, char* errors) -> bytes.
    PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type,
        [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
        ])

    # C signature used for the codec specific "PyUnicode_As<Codec>String"
    # calls: (unicode) -> bytes.
    PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type,
        [PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None)])

    # Codec names that are used to build codec specific C function names.
    _special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
                          'unicode_escape', 'raw_unicode_escape']

    # (name, encoder function) pairs that requested encodings are
    # compared against in _find_special_codec_name().
    _special_codecs = [ (name, codecs.getencoder(name))
                        for name in _special_encodings ]
2447 2448

    def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
Stefan Behnel's avatar
Stefan Behnel committed
2449 2450 2451
        """Replace unicode.encode(...) by a direct C-API call to the
        corresponding codec.
        """
2452
        if len(args) < 1 or len(args) > 3:
2453
            self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
2454 2455 2456 2457 2458
            return node

        string_node = args[0]

        if len(args) == 1:
2459
            null_node = ExprNodes.NullNode(node.pos)
2460 2461 2462 2463 2464
            return self._substitute_method_call(
                node, "PyUnicode_AsEncodedString",
                self.PyUnicode_AsEncodedString_func_type,
                'encode', is_unbound_method, [string_node, null_node, null_node])

2465 2466 2467 2468 2469 2470 2471 2472 2473 2474 2475 2476 2477 2478 2479 2480 2481 2482 2483 2484 2485 2486 2487 2488 2489 2490 2491 2492 2493 2494 2495 2496 2497 2498 2499 2500 2501 2502 2503
        parameters = self._unpack_encoding_and_error_mode(node.pos, args)
        if parameters is None:
            return node
        encoding, encoding_node, error_handling, error_handling_node = parameters

        if isinstance(string_node, ExprNodes.UnicodeNode):
            # constant, so try to do the encoding at compile time
            try:
                value = string_node.value.encode(encoding, error_handling)
            except:
                # well, looks like we can't
                pass
            else:
                value = BytesLiteral(value)
                value.encoding = encoding
                return ExprNodes.BytesNode(
                    string_node.pos, value=value, type=Builtin.bytes_type)

        if error_handling == 'strict':
            # try to find a specific encoder function
            codec_name = self._find_special_codec_name(encoding)
            if codec_name is not None:
                encode_function = "PyUnicode_As%sString" % codec_name
                return self._substitute_method_call(
                    node, encode_function,
                    self.PyUnicode_AsXyzString_func_type,
                    'encode', is_unbound_method, [string_node])

        return self._substitute_method_call(
            node, "PyUnicode_AsEncodedString",
            self.PyUnicode_AsEncodedString_func_type,
            'encode', is_unbound_method,
            [string_node, encoding_node, error_handling_node])

    # C signature used for the codec specific "PyUnicode_Decode<Codec>"
    # calls: (char* string, Py_ssize_t size, char* errors) -> unicode.
    PyUnicode_DecodeXyz_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type,
        [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
        ])

    # C signature used for the generic "PyUnicode_Decode" call:
    # (char* string, Py_ssize_t size, char* encoding, char* errors) -> unicode.
    PyUnicode_Decode_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type,
        [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
        ])
2513 2514

    def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method):
        """Replace char*.decode() by a direct C-API call to the
        corresponding codec, possibly resolving a slice on the char*.

        Handles a slice of a C string (the start offset is added to the
        pointer, the length becomes stop - start) and a plain C string
        that was coerced to a Python object (the length is taken from
        strlen()).  Everything else is left to Python.
        """
        if len(args) < 1 or len(args) > 3:
            self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
            return node

        ssize_t_type = PyrexTypes.c_py_ssize_t_type
        let_refs = []
        source = args[0]
        if isinstance(source, ExprNodes.SliceIndexNode):
            string_node = source.base
            if not string_node.type.is_string:
                # nothing to optimise here
                return node
            start, stop = source.start, source.stop
            if start and not (start.constant_result == 0):
                if start.type.is_pyobject:
                    start = start.coerce_to(ssize_t_type, self.current_env())
                if stop:
                    # start is used twice: for the pointer and the length
                    start = UtilNodes.LetRefNode(start)
                    let_refs.append(start)
                string_node = ExprNodes.AddNode(pos=start.pos,
                                                operand1=string_node,
                                                operator='+',
                                                operand2=start,
                                                is_temp=False,
                                                type=string_node.type
                                                )
            else:
                start = None
            if stop and stop.type.is_pyobject:
                stop = stop.coerce_to(ssize_t_type, self.current_env())
        elif (isinstance(source, ExprNodes.CoerceToPyTypeNode)
                  and source.arg.type.is_string):
            # use strlen() to find the string length, just as CPython would
            start = stop = None
            string_node = source.arg
        else:
            # let Python do its job
            return node

        if not stop:
            if start or not string_node.is_name:
                string_node = UtilNodes.LetRefNode(string_node)
                let_refs.append(string_node)
            strlen_call = ExprNodes.PythonCapiCallNode(
                string_node.pos, "strlen", self.Pyx_strlen_func_type,
                    args = [string_node],
                    is_temp = False,
                    utility_code = Builtin.include_string_h_utility_code,
                    )
            stop = strlen_call.coerce_to(ssize_t_type, self.current_env())
        elif start:
            stop = ExprNodes.SubNode(
                pos = stop.pos,
                operand1 = stop,
                operator = '-',
                operand2 = start,
                is_temp = False,
                type = ssize_t_type
                )

        parameters = self._unpack_encoding_and_error_mode(node.pos, args)
        if parameters is None:
            return node
        encoding, encoding_node, error_handling, error_handling_node = parameters

        # try to find a specific encoder function
        codec_name = None
        if encoding is not None:
            codec_name = self._find_special_codec_name(encoding)

        if codec_name is not None:
            c_function = "PyUnicode_Decode%s" % codec_name
            c_func_type = self.PyUnicode_DecodeXyz_func_type
            call_args = [string_node, stop, error_handling_node]
        else:
            c_function = "PyUnicode_Decode"
            c_func_type = self.PyUnicode_Decode_func_type
            call_args = [string_node, stop, encoding_node, error_handling_node]
        node = ExprNodes.PythonCapiCallNode(
            node.pos, c_function, c_func_type,
            args = call_args,
            is_temp = node.is_temp,
            )

        # wrap the call in its temps, last created temp innermost
        for temp in reversed(let_refs):
            node = UtilNodes.EvalWithTempExprNode(temp, node)
        return node
2603 2604 2605 2606 2607 2608 2609 2610 2611 2612 2613 2614 2615 2616 2617

    def _find_special_codec_name(self, encoding):
        try:
            requested_codec = codecs.getencoder(encoding)
        except:
            return None
        for name, codec in self._special_codecs:
            if codec == requested_codec:
                if '_' in name:
                    name = ''.join([ s.capitalize()
                                     for s in name.split('_')])
                return name
        return None

    def _unpack_encoding_and_error_mode(self, pos, args):
        """Extract the 'encoding' and 'errors' arguments of an
        encode()/decode() call from args[1] and args[2].

        Returns a tuple (encoding, encoding_node, error_handling,
        error_handling_node), or None if an argument is neither a string
        literal, a bytes object nor a C string.  The plain values are
        None unless the argument is a literal; the nodes are char* nodes
        that can be passed to the C-API.  A missing argument, as well as
        a literal 'strict' error mode, is represented by a NULL node.
        """
        null_node = ExprNodes.NullNode(pos)
        literal_node_types = (ExprNodes.UnicodeNode, ExprNodes.StringNode,
                              ExprNodes.BytesNode)
        char_ptr_type = PyrexTypes.c_char_ptr_type

        if len(args) < 2:
            encoding = None
            encoding_node = null_node
        else:
            encoding_node = args[1]
            if isinstance(encoding_node, ExprNodes.CoerceToPyTypeNode):
                encoding_node = encoding_node.arg
            if isinstance(encoding_node, literal_node_types):
                encoding = encoding_node.value
                encoding_node = ExprNodes.BytesNode(
                    encoding_node.pos, value=encoding, type=char_ptr_type)
            elif encoding_node.type is Builtin.bytes_type:
                encoding = None
                encoding_node = encoding_node.coerce_to(
                    char_ptr_type, self.current_env())
            elif encoding_node.type.is_string:
                encoding = None
            else:
                return None

        if len(args) != 3:
            error_handling = 'strict'
            error_handling_node = null_node
        else:
            error_handling_node = args[2]
            if isinstance(error_handling_node, ExprNodes.CoerceToPyTypeNode):
                error_handling_node = error_handling_node.arg
            if isinstance(error_handling_node, literal_node_types):
                error_handling = error_handling_node.value
                if error_handling == 'strict':
                    error_handling_node = null_node
                else:
                    error_handling_node = ExprNodes.BytesNode(
                        error_handling_node.pos, value=error_handling,
                        type=char_ptr_type)
            elif error_handling_node.type is Builtin.bytes_type:
                error_handling = None
                error_handling_node = error_handling_node.coerce_to(
                    char_ptr_type, self.current_env())
            elif error_handling_node.type.is_string:
                error_handling = None
            else:
                return None

        return (encoding, encoding_node, error_handling, error_handling_node)
2668

2669 2670 2671

    ### helpers

2672
    def _substitute_method_call(self, node, name, func_type,
                                attr_name, is_unbound_method, args=(),
                                utility_code=None,
                                may_return_none=ExprNodes.PythonCapiCallNode.may_return_none):
        """Build the C-API call node that replaces the method call 'node'.

        'name' and 'func_type' describe the C function to call and
        'attr_name' is the name of the replaced Python method.  Unless
        the first argument (the object the method is called on) is a
        literal, it gets wrapped in a None check that raises the error
        Python would raise for the original call.  The caller's 'args'
        sequence is copied, not modified.
        """
        call_args = list(args)
        if call_args and not call_args[0].is_literal:
            if is_unbound_method:
                message = "descriptor '%s' requires a '%s' object but received a 'NoneType'" % (
                    attr_name, node.function.obj.name)
                call_args[0] = call_args[0].as_none_safe_node(message)
            else:
                message = "'NoneType' object has no attribute '%s'" % attr_name
                call_args[0] = call_args[0].as_none_safe_node(
                    message, error = "PyExc_AttributeError")
        return ExprNodes.PythonCapiCallNode(
            node.pos, name, func_type,
            args = call_args,
            is_temp = node.is_temp,
            utility_code = utility_code,
            may_return_none = may_return_none,
            )

2696 2697 2698
    def _inject_int_default_argument(self, node, args, arg_index, type, default_value):
        """Make sure args[arg_index] is a C integer of the given type.

        An existing argument is coerced to 'type' in place.  If the
        argument list ends right before 'arg_index', an integer node
        holding 'default_value' is appended instead.
        """
        arg_count = len(args)
        assert arg_count >= arg_index
        if arg_count != arg_index:
            args[arg_index] = args[arg_index].coerce_to(type, self.current_env())
            return
        default_node = ExprNodes.IntNode(
            node.pos, value=str(default_value),
            type=type, constant_result=default_value)
        args.append(default_node)
2703 2704 2705 2706

    def _inject_bint_default_argument(self, node, args, arg_index, default_value):
        """Make sure args[arg_index] is a boolean value.

        An existing argument is coerced to a boolean in place.  If the
        argument list ends right before 'arg_index', a boolean node
        holding bool(default_value) is appended instead.
        """
        arg_count = len(args)
        assert arg_count >= arg_index
        if arg_count != arg_index:
            args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env())
            return
        default_value = bool(default_value)
        default_node = ExprNodes.BoolNode(
            node.pos, value=default_value, constant_result=default_value)
        args.append(default_node)
2712

2713

2714 2715 2716 2717
# Py_UNICODE_ISTITLE() doesn't match unicode.istitle() as the latter
# additionally allows characters that comply with Py_UNICODE_ISUPPER().
# The helper takes a Py_UNICODE before Python 3.2a2 and a Py_UCS4 after.
py_unicode_istitle_utility_code = UtilityCode(
proto = '''
#if PY_VERSION_HEX < 0x030200A2
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UNICODE uchar); /* proto */
#else
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UCS4 uchar); /* proto */
#endif
''',
impl = '''
#if PY_VERSION_HEX < 0x030200A2
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UNICODE uchar) {
#else
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UCS4 uchar) {
#endif
    return Py_UNICODE_ISTITLE(uchar) || Py_UNICODE_ISUPPER(uchar);
}
''')

2734 2735 2736 2737 2738 2739 2740 2741 2742 2743 2744 2745 2746 2747 2748 2749 2750 2751 2752 2753 2754 2755 2756 2757 2758 2759 2760
# C helper __Pyx_PyUnicode_Tailmatch(s, substr, start, end, direction):
# wraps PyUnicode_Tailmatch() so that 'substr' may also be a tuple.  For a
# tuple, each item is tried in order and the first non-zero result is
# returned (this includes a negative error result); 0 is returned when no
# item matched.  Any other 'substr' is passed straight to
# PyUnicode_Tailmatch().
unicode_tailmatch_utility_code = UtilityCode(
    # Python's unicode.startswith() and unicode.endswith() support a
    # tuple of prefixes/suffixes, whereas it's much more common to
    # test for a single unicode string.  The tuple case is therefore
    # marked unlikely() in the C code below.
proto = '''
static int __Pyx_PyUnicode_Tailmatch(PyObject* s, PyObject* substr, \
Py_ssize_t start, Py_ssize_t end, int direction);
''',
impl = '''
static int __Pyx_PyUnicode_Tailmatch(PyObject* s, PyObject* substr,
                                     Py_ssize_t start, Py_ssize_t end, int direction) {
    if (unlikely(PyTuple_Check(substr))) {
        int result;
        Py_ssize_t i;
        for (i = 0; i < PyTuple_GET_SIZE(substr); i++) {
            result = PyUnicode_Tailmatch(s, PyTuple_GET_ITEM(substr, i),
                                         start, end, direction);
            if (result) {
                return result;
            }
        }
        return 0;
    }
    return PyUnicode_Tailmatch(s, substr, start, end, direction);
}
''',
)
2761

2762 2763
# C helper __Pyx_PyDict_GetItemDefault(d, key, default_value): returns a new
# reference to d[key], or to default_value when the key is missing, and NULL
# on error.
#
# - Python 3: uses PyDict_GetItemWithError() and checks PyErr_Occurred() to
#   tell a missing key from a failed lookup.
# - Python 2: uses PyDict_GetItem() only for exact str/unicode/int keys;
#   for any other key type it calls the object's "get" method instead.
#
# The whole function lives in 'proto'; 'impl' is intentionally empty.
dict_getitem_default_utility_code = UtilityCode(
proto = '''
static PyObject* __Pyx_PyDict_GetItemDefault(PyObject* d, PyObject* key, PyObject* default_value) {
    PyObject* value;
#if PY_MAJOR_VERSION >= 3
    value = PyDict_GetItemWithError(d, key);
    if (unlikely(!value)) {
        if (unlikely(PyErr_Occurred()))
            return NULL;
        value = default_value;
    }
    Py_INCREF(value);
#else
    if (PyString_CheckExact(key) || PyUnicode_CheckExact(key) || PyInt_CheckExact(key)) {
        /* these presumably have safe hash functions */
        value = PyDict_GetItem(d, key);
        if (unlikely(!value)) {
            value = default_value;
        }
        Py_INCREF(value);
    } else {
        PyObject *m;
        m = __Pyx_GetAttrString(d, "get");
        if (!m) return NULL;
        value = PyObject_CallFunctionObjArgs(m, key,
            (default_value == Py_None) ? NULL : default_value, NULL);
        Py_DECREF(m);
    }
#endif
    return value;
}
''',
impl = ""
)

2797 2798
# C helper __Pyx_PyObject_Append(L, x): for an exact list, appends through
# PyList_Append() and returns a new reference to None; for any other object,
# looks up and calls its "append" method and returns that call's result.
# Returns NULL on error.
#
# The whole function lives in 'proto'; 'impl' is intentionally empty.
append_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Append(PyObject* L, PyObject* x) {
    if (likely(PyList_CheckExact(L))) {
        if (PyList_Append(L, x) < 0) return NULL;
        Py_INCREF(Py_None);
        return Py_None; /* this is just to have an accurate signature */
    }
    else {
        PyObject *r, *m;
        m = __Pyx_GetAttrString(L, "append");
        if (!m) return NULL;
        r = PyObject_CallFunctionObjArgs(m, x, NULL);
        Py_DECREF(m);
        return r;
    }
}
""",
impl = ""
)
2817 2818


Robert Bradshaw's avatar
Robert Bradshaw committed
2819 2820
# C helper __Pyx_PyObject_Pop(L): pops the last item of L.
#
# Fast path (CPython >= 2.4): taken for an exact list whose size is larger
# than half of its allocated capacity, i.e. the list is non-empty and no
# shrinking reallocation is needed.  The size field is decremented directly
# and the former last item is returned.
# Every other case looks up and calls the object's "pop" method.
#
# The whole function lives in 'proto'; 'impl' is intentionally empty.
pop_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Pop(PyObject* L) {
    PyObject *r, *m;
#if PY_VERSION_HEX >= 0x02040000
    if (likely(PyList_CheckExact(L))
            /* Check that both the size is positive and no reallocation shrinking needs to be done. */
            && likely(PyList_GET_SIZE(L) > (((PyListObject*)L)->allocated >> 1))) {
        Py_SIZE(L) -= 1;
        return PyList_GET_ITEM(L, PyList_GET_SIZE(L));
    }
#endif
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) return NULL;
    r = PyObject_CallObject(m, NULL);
    Py_DECREF(m);
    return r;
}
""",
impl = ""
)

# C helper __Pyx_PyObject_PopIndex(L, ix): pops the item at index ix of L.
#
# Fast path (CPython >= 2.4): taken for an exact list whose size is larger
# than half of its allocated capacity and whose (normalised) index is in
# range.  The item is taken out, the size field is decremented and the
# following items are shifted down by one.
# Every other case calls the object's "pop" method with the index.
#
# The fast path normalises a negative index in a local copy ('cix').  The
# generic fallback must receive the caller's original index: if the adjusted
# value were passed on, an out-of-range negative index such as -5 on a list
# of three items would become -2 and silently pop the wrong item instead of
# raising IndexError.
pop_index_utility_code = UtilityCode(
proto = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix);
""",
impl = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix) {
    PyObject *r, *m, *t, *py_ix;
#if PY_VERSION_HEX >= 0x02040000
    if (likely(PyList_CheckExact(L))) {
        Py_ssize_t size = PyList_GET_SIZE(L);
        if (likely(size > (((PyListObject*)L)->allocated >> 1))) {
            Py_ssize_t cix = ix;
            if (cix < 0) {
                cix += size;
            }
            if (likely(0 <= cix && cix < size)) {
                Py_ssize_t i;
                PyObject* v = PyList_GET_ITEM(L, cix);
                Py_SIZE(L) -= 1;
                size -= 1;
                for(i=cix; i<size; i++) {
                    PyList_SET_ITEM(L, i, PyList_GET_ITEM(L, i+1));
                }
                return v;
            }
        }
    }
#endif
    py_ix = t = NULL;
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) goto bad;
    py_ix = PyInt_FromSsize_t(ix);
    if (!py_ix) goto bad;
    t = PyTuple_New(1);
    if (!t) goto bad;
    PyTuple_SET_ITEM(t, 0, py_ix);
    py_ix = NULL;
    r = PyObject_CallObject(m, t);
    Py_DECREF(m);
    Py_DECREF(t);
    return r;
bad:
    Py_XDECREF(m);
    Py_XDECREF(t);
    Py_XDECREF(py_ix);
    return NULL;
}
"""
)


2891 2892 2893 2894 2895 2896 2897 2898 2899 2900 2901 2902 2903
# C helper __Pyx_PyObject_AsDouble(obj): converts a Python object to a C
# double.
#
# The macro handles exact float objects inline and hands everything else to
# __Pyx__PyObject_AsDouble(), which
# - calls PyFloat_AsDouble() for objects that provide the nb_float slot,
# - parses exact unicode/bytes objects with PyFloat_FromString(),
# - calls the float type with the object as its only argument otherwise.
# The function returns (double)-1 on failure.
#
# In the last case the argument tuple only borrows 'obj': the slot is reset
# to 0 before the tuple is released, so the object's reference count is left
# untouched.
pyobject_as_double_utility_code = UtilityCode(
proto = '''
static double __Pyx__PyObject_AsDouble(PyObject* obj); /* proto */

#define __Pyx_PyObject_AsDouble(obj) \\
    ((likely(PyFloat_CheckExact(obj))) ? \\
     PyFloat_AS_DOUBLE(obj) : __Pyx__PyObject_AsDouble(obj))
''',
impl='''
static double __Pyx__PyObject_AsDouble(PyObject* obj) {
    PyObject* float_value;
    if (Py_TYPE(obj)->tp_as_number && Py_TYPE(obj)->tp_as_number->nb_float) {
        return PyFloat_AsDouble(obj);
    } else if (PyUnicode_CheckExact(obj) || PyBytes_CheckExact(obj)) {
#if PY_MAJOR_VERSION >= 3
        float_value = PyFloat_FromString(obj);
#else
        float_value = PyFloat_FromString(obj, 0);
#endif
    } else {
        PyObject* args = PyTuple_New(1);
        if (unlikely(!args)) goto bad;
        PyTuple_SET_ITEM(args, 0, obj);
        float_value = PyObject_Call((PyObject*)&PyFloat_Type, args, 0);
        PyTuple_SET_ITEM(args, 0, 0);
        Py_DECREF(args);
    }
    if (likely(float_value)) {
        double value = PyFloat_AS_DOUBLE(float_value);
        Py_DECREF(float_value);
        return value;
    }
bad:
    return (double)-1;
}
'''
)


2930 2931 2932 2933 2934 2935 2936 2937
# C helper __Pyx_PyBytes_GetItemInt(bytes, index, check_bounds): returns the
# char at 'index' of a bytes object.  Negative indices count from the end.
# When 'check_bounds' is non-zero, an out-of-range index sets IndexError and
# returns -1; when it is zero, no range check is performed at all.
bytes_index_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* unicode, Py_ssize_t index, int check_bounds); /* proto */
""",
impl = """
static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* bytes, Py_ssize_t index, int check_bounds) {
    if (check_bounds) {
        if (unlikely(index >= PyBytes_GET_SIZE(bytes)) |
            ((index < 0) & unlikely(index < -PyBytes_GET_SIZE(bytes)))) {
            PyErr_Format(PyExc_IndexError, "string index out of range");
            return -1;
        }
    }
    if (index < 0)
        index += PyBytes_GET_SIZE(bytes);
    return PyBytes_AS_STRING(bytes)[index];
}
"""
)


2951 2952
# C helper __Pyx_tp_new(type_obj): calls the tp_new slot of 'type_obj'
# directly, passing the type itself, the empty tuple named by
# Naming.empty_tuple as positional arguments and NULL as keyword arguments.
# The whole function lives in 'proto'.
tpnew_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_tp_new(PyObject* type_obj) {
    return (PyObject*) (((PyTypeObject*)(type_obj))->tp_new(
        (PyTypeObject*)(type_obj), %(TUPLE)s, NULL));
}
""" % {'TUPLE' : Naming.empty_tuple}
)


2961 2962 2963 2964
class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
    """Calculate the result of constant expressions to store it in
    ``expr_node.constant_result``, and replace trivial cases by their
    constant result.

    General rules:

    - We calculate float constants to make them available to the
      compiler, but we do not aggregate them into a single literal
      node to prevent any loss of precision.

    - We recursively calculate constants from non-literal nodes to
      make them available to the compiler, but we only aggregate
      literal nodes at each step.  Non-literal nodes are never merged
      into a single node.
    """
    def _calculate_const(self, node):
        """Visit the children of ``node`` and set ``node.constant_result``.

        Nodes that already carry a result are left alone.  The result is
        first set to ``not_a_constant`` and only replaced by a computed
        value if every visited child has a constant result and
        ``node.calculate_constant_result()`` succeeds.
        """
        if node.constant_result is not ExprNodes.constant_value_not_set:
            return

        # make sure we always set the value
        not_a_constant = ExprNodes.not_a_constant
        node.constant_result = not_a_constant

        # check if all children are constant
        children = self.visitchildren(node)
        for child_result in children.values():
            if type(child_result) is list:
                for child in child_result:
                    if getattr(child, 'constant_result', not_a_constant) is not_a_constant:
                        return
            elif getattr(child_result, 'constant_result', not_a_constant) is not_a_constant:
                return

        # now try to calculate the real constant value
        try:
            node.calculate_constant_result()
        except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
            # ignore all 'normal' errors here => no constant result
            pass
        except Exception:
            # this looks like a real error: report it, but keep compiling
            # with the node marked as not constant
            import sys
            import traceback
            traceback.print_exc(file=sys.stdout)

    # literal node classes that can be merged, ordered from narrowest to widest
    NODE_TYPE_ORDER = [ExprNodes.CharNode, ExprNodes.IntNode,
                       ExprNodes.LongNode, ExprNodes.FloatNode]

    def _widest_node_class(self, *nodes):
        """Return the widest class (by NODE_TYPE_ORDER) among ``nodes``.

        Returns None if any node's exact class is not in NODE_TYPE_ORDER.
        """
        try:
            return self.NODE_TYPE_ORDER[
                max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
        except ValueError:
            return None

    def _negate_literal(self, value):
        """Return the literal string ``value`` with its sign flipped.

        An existing leading '-' is stripped instead of adding a second one.
        A folded inner negation already produces a value such as '-5', and
        prefixing that again would build '--5', which is not a valid
        numeric literal.
        """
        if value.startswith('-'):
            return value[1:]
        return '-' + value

    def visit_ExprNode(self, node):
        """Compute the constant result of a generic expression node."""
        self._calculate_const(node)
        return node

    def visit_UnopNode(self, node):
        """Fold unary operators applied to literal operands.

        Returns the node unchanged unless it has a constant result and its
        operand is a literal.
        """
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if not node.operand.is_literal:
            return node
        if isinstance(node.operand, ExprNodes.BoolNode):
            # a unary operator on a bool literal yields a C int literal
            return ExprNodes.IntNode(node.pos, value = str(node.constant_result),
                                     type = PyrexTypes.c_int_type,
                                     constant_result = node.constant_result)
        if node.operator == '+':
            return self._handle_UnaryPlusNode(node)
        elif node.operator == '-':
            return self._handle_UnaryMinusNode(node)
        return node

    def _handle_UnaryMinusNode(self, node):
        """Replace ``-literal`` by a literal node of the negated value.

        Handles long, float and int literals (ints only if the type is a
        signed C integer or a Python object); other operands keep the node.
        """
        operand = node.operand
        if isinstance(operand, ExprNodes.LongNode):
            return ExprNodes.LongNode(node.pos, value = self._negate_literal(operand.value),
                                      constant_result = node.constant_result)
        if isinstance(operand, ExprNodes.FloatNode):
            # this is a safe operation
            return ExprNodes.FloatNode(node.pos, value = self._negate_literal(operand.value),
                                       constant_result = node.constant_result)
        node_type = operand.type
        if (node_type.is_int and node_type.signed) or \
               (isinstance(operand, ExprNodes.IntNode) and node_type.is_pyobject):
            return ExprNodes.IntNode(node.pos, value = self._negate_literal(operand.value),
                                     type = node_type,
                                     longness = operand.longness,
                                     constant_result = node.constant_result)
        return node

    def _handle_UnaryPlusNode(self, node):
        """Drop a unary plus that does not change the operand's value."""
        if node.constant_result == node.operand.constant_result:
            return node.operand
        return node

    def visit_BoolBinopNode(self, node):
        """Fold ``and``/``or`` of two literals into the operand it selects."""
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if not node.operand1.is_literal or not node.operand2.is_literal:
            return node

        if node.constant_result == node.operand1.constant_result and node.operand1.is_literal:
            return node.operand1
        elif node.constant_result == node.operand2.constant_result and node.operand2.is_literal:
            return node.operand2
        else:
            # FIXME: we could do more ...
            return node

    def visit_BinopNode(self, node):
        """Fold a binary operation on two literals into a single literal.

        Float results are never merged (see the class docstring), and the
        node is kept whenever the operand classes or types do not allow a
        safe replacement.
        """
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if isinstance(node.constant_result, float):
            return node
        operand1, operand2 = node.operand1, node.operand2
        if not operand1.is_literal or not operand2.is_literal:
            return node

        # now inject a new constant node with the calculated value
        try:
            type1, type2 = operand1.type, operand2.type
            if type1 is None or type2 is None:
                return node
        except AttributeError:
            return node

        if type1.is_numeric and type2.is_numeric:
            widest_type = PyrexTypes.widest_numeric_type(type1, type2)
        else:
            widest_type = PyrexTypes.py_object_type
        target_class = self._widest_node_class(operand1, operand2)
        if target_class is None:
            return node
        elif target_class is ExprNodes.IntNode:
            # the result is unsigned only if both operands are, and it is
            # as long as the longer of the two operands
            unsigned = getattr(operand1, 'unsigned', '') and \
                       getattr(operand2, 'unsigned', '')
            longness = "LL"[:max(len(getattr(operand1, 'longness', '')),
                                 len(getattr(operand2, 'longness', '')))]
            new_node = ExprNodes.IntNode(pos=node.pos,
                                         unsigned = unsigned, longness = longness,
                                         value = str(node.constant_result),
                                         constant_result = node.constant_result)
            # IntNode is smart about the type it chooses, so we just
            # make sure we were not smarter this time
            if widest_type.is_pyobject or new_node.type.is_pyobject:
                new_node.type = PyrexTypes.py_object_type
            else:
                new_node.type = PyrexTypes.widest_numeric_type(widest_type, new_node.type)
        else:
            # The value format depends on the class of the node we create,
            # not on the binary operation node itself (which can never be a
            # BoolNode).
            if target_class is ExprNodes.BoolNode:
                node_value = node.constant_result
            else:
                node_value = str(node.constant_result)
            new_node = target_class(pos=node.pos, type = widest_type,
                                    value = node_value,
                                    constant_result = node.constant_result)
        return new_node

    def visit_PrimaryCmpNode(self, node):
        """Replace a comparison with a constant result by a bool literal."""
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        bool_result = bool(node.constant_result)
        return ExprNodes.BoolNode(node.pos, value=bool_result,
                                  constant_result=bool_result)

    def visit_IfStatNode(self, node):
        """Eliminate dead code based on constant condition results.

        Clauses with a constant false condition are dropped.  The first
        clause with a constant true condition becomes the else clause and
        ends the scan.  If no runtime-evaluated clause remains, the
        statement is replaced by its else clause, or by an empty statement
        list if there is none.
        """
        self.visitchildren(node)
        if_clauses = []
        for if_clause in node.if_clauses:
            condition_result = if_clause.get_constant_condition_result()
            if condition_result is None:
                # unknown result => normal runtime evaluation
                if_clauses.append(if_clause)
            elif condition_result == True:
                # subsequent clauses can safely be dropped
                node.else_clause = if_clause.body
                break
            else:
                assert condition_result == False
        if if_clauses:
            node.if_clauses = if_clauses
            return node
        if node.else_clause is not None:
            return node.else_clause
        # Nothing is left to execute.  Return an empty statement list, not
        # None, so that a parent holding this statement in a single
        # attribute is not left with None in its place.
        return Nodes.StatListNode(node.pos, stats=[])

    # in the future, other nodes can have their own handler method here
    # that can replace them with a constant result node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
3157 3158


3159 3160 3161
class FinalOptimizePhase(Visitor.CythonTransform):
    """
    This visitor handles several commuting optimizations, and is run
    just before the C code generation phase.

    The optimizations currently implemented in this class are:
        - eliminate None assignment and refcounting for first assignment.
        - isinstance -> typecheck for cdef types
        - eliminate checks for None and/or types that became redundant after tree changes
    """
    def visit_SingleAssignmentNode(self, node):
        """Avoid redundant initialisation of local variables before their
        first assignment.
        """
        self.visitchildren(node)
        if node.first:
            lhs = node.lhs
            lhs.lhs_of_first_assignment = True
            if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
                # Have variable initialized to 0 rather than None
                lhs.entry.init_to_none = False
                lhs.entry.init = 0
        return node

    def visit_SimpleCallNode(self, node):
        """Replace generic calls to isinstance(x, type) by a more efficient
        type check.

        The call is rewritten in place to 'PyObject_TypeCheck', looked up
        in the utility scope, with the second argument cast to a
        'PyTypeObject' pointer.  Only two-argument calls whose second
        argument has the builtin type 'type' are rewritten.
        """
        self.visitchildren(node)
        if node.function.type.is_cfunction and isinstance(node.function, ExprNodes.NameNode):
            # the length check keeps args[1] from raising IndexError on a
            # call that does not have exactly two arguments
            if node.function.name == 'isinstance' and len(node.args) == 2:
                type_arg = node.args[1]
                if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
                    from CythonScope import utility_scope
                    node.function.entry = utility_scope.lookup('PyObject_TypeCheck')
                    node.function.type = node.function.entry.type
                    PyTypeObjectPtr = PyrexTypes.CPtrType(utility_scope.lookup('PyTypeObject').type)
                    node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
        return node

    def visit_PyTypeTestNode(self, node):
        """Remove tests for alternatively allowed None values from
        type tests when we know that the argument cannot be None
        anyway.
        """
        self.visitchildren(node)
        if not node.notnone:
            if not node.arg.may_be_none():
                node.notnone = True
        return node