Optimize.py 152 KB
Newer Older
1
from Cython.Compiler import TypeSlots
2
from Cython.Compiler.ExprNodes import not_a_constant
3 4 5 6 7
import cython
cython.declare(UtilityCode=object, EncodedString=object, BytesLiteral=object,
               Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object,
               UtilNodes=object, Naming=object)

8 9
import Nodes
import ExprNodes
10
import PyrexTypes
11
import Visitor
12 13
import Builtin
import UtilNodes
14
import Options
15
import Naming
16

17
from Code import UtilityCode
18
from StringEncoding import EncodedString, BytesLiteral
19
from Errors import error
20 21
from ParseTreeTransforms import SkipDeclarations

22
import copy
23 24
import codecs

25
try:
26 27
    from __builtin__ import reduce
except ImportError:
28 29
    from functools import reduce

30 31 32 33 34
try:
    from __builtin__ import basestring
except ImportError:
    basestring = str # Python 3

35
def load_c_utility(name):
    """Return the cached utility code called ``name`` from "Optimize.c".

    Thin convenience wrapper around ``UtilityCode.load_cached()`` that
    fixes the utility source file name.
    """
    utility_file = "Optimize.c"
    return UtilityCode.load_cached(name, utility_file)
37

38 39 40 41 42
def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)):
    """Strip a single coercion wrapper from ``node``.

    If ``node`` is an instance of one of ``coercion_nodes``, its ``arg``
    is returned; any other node is returned unchanged.
    """
    return node.arg if isinstance(node, coercion_nodes) else node

43
def unwrap_node(node):
    """Follow ``ResultRefNode`` wrappers down to the wrapped expression.

    Repeatedly replaces the node by its ``expression`` attribute for as
    long as it is a ``UtilNodes.ResultRefNode`` and returns the first node
    that is not.
    """
    current = node
    while True:
        if not isinstance(current, UtilNodes.ResultRefNode):
            return current
        current = current.expression
47 48

def is_common_value(a, b):
    """Tell whether two nodes refer to the same name or attribute chain.

    Both nodes are first passed through ``unwrap_node()``.  Two name nodes
    match when their names are equal.  Two attribute nodes match when the
    first is not a Python attribute, their objects match recursively and
    their attribute names are equal.  Everything else yields False.
    """
    left = unwrap_node(a)
    right = unwrap_node(b)

    both_names = (isinstance(left, ExprNodes.NameNode)
                  and isinstance(right, ExprNodes.NameNode))
    if both_names:
        return left.name == right.name

    both_attributes = (isinstance(left, ExprNodes.AttributeNode)
                       and isinstance(right, ExprNodes.AttributeNode))
    if not both_attributes:
        return False
    return not left.is_py_attr and is_common_value(left.obj, right.obj) and left.attribute == right.attribute

57 58 59 60 61
def filter_none_node(node):
    """Map nodes whose constant result is None to a plain ``None``.

    Returns ``None`` when ``node`` is ``None`` or when its
    ``constant_result`` is ``None``; otherwise returns ``node`` itself.
    """
    if node is None:
        return None
    return None if node.constant_result is None else node

62
class IterationTransform(Visitor.EnvTransform):
63 64 65
    """Transform some common for-in loop patterns into efficient C loops:

    - for-in-dict loop becomes a while loop calling PyDict_Next()
    - for-in-enumerate is replaced by an external counter variable
    - for-in-range loop becomes a plain C for loop
    """
69 70
    def visit_PrimaryCmpNode(self, node):
        """Rewrite a pointer/array containment test as an explicit search loop.

        When ``node.is_ptr_contains()`` is true, the comparison is replaced
        by a statement-based expression equivalent to::

            for t in operand2:
                if operand1 == t:
                    res = True
                    break
            else:
                res = False

        For the 'not_in' operator the result is wrapped in a ``NotNode``.
        Any other comparison only has its children visited and is returned
        as is.
        """
        if not node.is_ptr_contains():
            self.visitchildren(node)
            return node

        pos = node.pos
        container = node.operand2
        result_ref = UtilNodes.ResultRefNode(node)

        # the loop variable takes the element type of the searched container
        if isinstance(container, ExprNodes.IndexNode):
            item_type = container.base.type.base_type
        else:
            item_type = container.type.base_type
        item_handle = UtilNodes.TempHandle(item_type)
        item = item_handle.ref(pos)

        # if operand1 == t: res = True; break
        match_test = ExprNodes.PrimaryCmpNode(
            pos, operator=u'==', operand1=node.operand1, operand2=item)
        found_stats = Nodes.StatListNode(
            pos,
            stats=[
                Nodes.SingleAssignmentNode(
                    pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)),
                Nodes.BreakStatNode(pos),
            ])
        match_stat = Nodes.IfStatNode(
            pos,
            if_clauses=[Nodes.IfClauseNode(pos, condition=match_test, body=found_stats)],
            else_clause=None)

        # for t in operand2: ... else: res = False
        not_found_stat = Nodes.SingleAssignmentNode(
            pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))
        search_loop = Nodes.ForInStatNode(
            pos,
            target=item,
            iterator=ExprNodes.IteratorNode(container.pos, sequence=container),
            body=match_stat,
            else_clause=not_found_stat)
        loop_block = UtilNodes.TempsBlockNode(
            pos, temps=[item_handle], body=search_loop)

        loop_block = loop_block.analyse_expressions(self.current_env())
        loop_block = self.visit(loop_block)
        replacement = UtilNodes.TempResultFromStatNode(result_ref, loop_block)

        if node.operator == 'not_in':
            replacement = ExprNodes.NotNode(pos, operand=replacement)
        return replacement
117

118 119
    def visit_ForInStatNode(self, node):
        """Visit the loop's children, then try to optimise the loop itself.

        Returns whatever ``_optimise_for_loop()`` produces for the loop and
        the sequence of its iterator.
        """
        self.visitchildren(node)
        sequence = node.iterator.sequence
        return self._optimise_for_loop(node, sequence)
121

122
    def _optimise_for_loop(self, node, iterator, reversed=False):
        """Pick the specialised loop transformation that fits ``iterator``.

        ``node`` is the for-in loop, ``iterator`` the expression iterated
        over and ``reversed`` tells whether the iteration runs inside a
        ``reversed()`` call.  Depending on the type or call form of
        ``iterator`` this delegates to the dict, C array, bytes, unicode,
        enumerate(), reversed() or range() transformation.  The loop node is
        returned unchanged when no transformation applies.
        """
        iterable_type = iterator.type

        if iterable_type is Builtin.dict_type:
            # like iterating over dict.keys()
            if reversed:
                # CPython raises an error here: not a sequence
                return node
            return self._transform_dict_iteration(
                node, dict_obj=iterator, method=None, keys=True, values=False)

        # C array (slice) iteration?
        if iterable_type.is_ptr or iterable_type.is_array:
            return self._transform_carray_iteration(node, iterator, reversed=reversed)
        if iterable_type is Builtin.bytes_type:
            return self._transform_bytes_iteration(node, iterator, reversed=reversed)
        if iterable_type is Builtin.unicode_type:
            return self._transform_unicode_iteration(node, iterator, reversed=reversed)

        # the rest is based on function calls
        if not isinstance(iterator, ExprNodes.SimpleCallNode):
            return node

        call_args = iterator.args
        if call_args is None:
            arg_tuple = iterator.arg_tuple
            arg_count = len(arg_tuple.args) if arg_tuple else 0
        else:
            arg_count = len(call_args)
            if arg_count and iterator.self is not None:
                # do not count the bound 'self' argument
                arg_count -= 1

        function = iterator.function

        # dict iteration?
        if function.is_attribute and not reversed and not arg_count:
            base_obj = iterator.self or function.obj
            method = function.attribute
            # in Py3, items() is equivalent to Py2's iteritems()
            is_safe_iter = self.global_scope().context.language_level >= 3

            if (not is_safe_iter and method in ('keys', 'values', 'items')
                    and isinstance(base_obj, ExprNodes.SimpleCallNode)):
                # try to reduce this to the corresponding .iter*() methods
                inner_function = base_obj.function
                if (inner_function.is_name and inner_function.name == 'dict'
                        and inner_function.entry
                        and inner_function.entry.is_builtin):
                    # e.g. dict(something).items() => safe to use .iter*()
                    is_safe_iter = True

            wants_keys = method == 'iterkeys' or (is_safe_iter and method == 'keys')
            wants_values = method == 'itervalues' or (is_safe_iter and method == 'values')
            wants_items = method == 'iteritems' or (is_safe_iter and method == 'items')
            keys = bool(wants_keys or wants_items)
            values = bool(wants_values or wants_items)

            if keys or values:
                return self._transform_dict_iteration(
                    node, base_obj, method, keys, values)

        # plain call of a builtin function (not a method call)?
        is_builtin_call = (iterator.self is None and function.is_name
                           and function.entry and function.entry.is_builtin)

        # enumerate/reversed ?
        if is_builtin_call and function.name in ('enumerate', 'reversed'):
            if reversed:
                # CPython raises an error here: not a sequence
                return node
            if function.name == 'enumerate':
                return self._transform_enumerate_iteration(node, iterator)
            return self._transform_reversed_iteration(node, iterator)

        # range() iteration?
        if Options.convert_range and node.target.type.is_int:
            if is_builtin_call and function.name in ('range', 'xrange'):
                return self._transform_range_iteration(node, iterator, reversed=reversed)

        return node
202

203 204 205 206 207 208 209 210 211 212
    def _transform_reversed_iteration(self, node, reversed_function):
        """Optimise a loop over ``reversed(...)``.

        Reports a compile error and returns the loop unchanged unless
        ``reversed()`` was called with exactly one argument.  For a list or
        tuple argument, the existing iterator is switched into reversed mode
        on a None-checked sequence.  Any other argument is handed back to
        ``_optimise_for_loop()`` with ``reversed=True``.
        """
        args = reversed_function.arg_tuple.args
        arg_count = len(args)
        if arg_count != 1:
            if arg_count == 0:
                message = "reversed() requires an iterable argument"
            else:
                message = "reversed() takes exactly 1 argument"
            error(reversed_function.pos, message)
            return node

        sequence = args[0]
        if sequence.type not in (Builtin.tuple_type, Builtin.list_type):
            return self._optimise_for_loop(node, sequence, reversed=True)

        # reversed(list/tuple)
        loop_iterator = node.iterator
        loop_iterator.sequence = sequence.as_none_safe_node("'NoneType' object is not iterable")
        loop_iterator.reversed = True
        return node
222

223 224 225 226 227 228 229 230 231 232
    # C function signature used for the "PyBytes_AS_STRING" call in
    # _transform_bytes_iteration(): takes a bytes object "s" and returns a
    # char pointer.
    PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_char_ptr_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
            ])

    # C function signature used for the "PyBytes_GET_SIZE" call in
    # _transform_bytes_iteration(): takes a bytes object "s" and returns a
    # Py_ssize_t.
    PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
            ])

233 234
    def _transform_bytes_iteration(self, node, slice_node, reversed=False):
        """Turn iteration over a bytes object into C array iteration.

        The loop is left untouched unless its target is a C integer or a
        bytes object.  Otherwise the (None-checked) bytes object is kept in
        a let-reference and the loop is rewritten by
        ``_transform_carray_iteration()`` over the slice
        ``PyBytes_AS_STRING(s)[:PyBytes_GET_SIZE(s)]``.
        """
        target_type = node.target.type
        if not (target_type.is_int or target_type is Builtin.bytes_type):
            # bytes iteration returns bytes objects in Py2, but
            # integers in Py3
            return node

        bytes_ref = UtilNodes.LetRefNode(
            slice_node.as_none_safe_node("'NoneType' is not iterable"))
        pos = slice_node.pos

        data_ptr = ExprNodes.PythonCapiCallNode(
            pos, "PyBytes_AS_STRING",
            self.PyBytes_AS_STRING_func_type,
            args=[bytes_ref],
            is_temp=0,
            )
        data_length = ExprNodes.PythonCapiCallNode(
            pos, "PyBytes_GET_SIZE",
            self.PyBytes_GET_SIZE_func_type,
            args=[bytes_ref],
            is_temp=0,
            )

        char_slice = ExprNodes.SliceIndexNode(
            pos,
            base=data_ptr,
            start=None,
            step=None,
            stop=data_length,
            type=data_ptr.type,
            is_temp=1,
            )
        array_loop = self._transform_carray_iteration(
            node, char_slice, reversed=reversed)
        return UtilNodes.LetNode(bytes_ref, array_loop)
270

271 272 273 274 275 276 277
    # C function signature used for the "__Pyx_PyUnicode_READ" call in
    # _transform_unicode_iteration(): (int kind, void* data, Py_ssize_t index)
    # returning a Py_UCS4 value.
    PyUnicode_READ_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ucs4_type, [
            PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_type, None),
            PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_type, None),
            PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None)
        ])

278 279 280 281 282 283 284 285
    # C function signature used for the "__Pyx_init_unicode_iteration" call in
    # _transform_unicode_iteration(): takes the Python object "s" plus pointers
    # through which length, data and kind are passed, and returns an int that
    # is declared with exception value -1.
    init_unicode_iteration_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_int_type, [
            PyrexTypes.CFuncTypeArg("s", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("length", PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_ptr_type, None),
            PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_ptr_type, None)
        ],
        exception_value = '-1')
286 287

    def _transform_unicode_iteration(self, node, slice_node, reversed=False):
        """Turn iteration over a unicode string into an index-based C loop.

        A string literal that can be encoded as Latin-1 is iterated as a C
        unsigned char array via ``_transform_carray_iteration()``.  In all
        other cases, ``__Pyx_init_unicode_iteration()`` sets up length, data
        and kind temps for the (None-checked) string and a
        ``ForFromStatNode`` reads one character per step with
        ``__Pyx_PyUnicode_READ()``, assigning it to the loop target.
        """
        if slice_node.is_literal:
            # try to reduce to byte iteration for plain Latin-1 strings
            latin1_value = None
            try:
                latin1_value = BytesLiteral(slice_node.value.encode('latin1'))
            except UnicodeEncodeError:
                pass
            if latin1_value is not None:
                literal_pos = slice_node.pos
                char_data = ExprNodes.BytesNode(
                    literal_pos, value=latin1_value,
                    constant_result=latin1_value,
                    type=PyrexTypes.c_char_ptr_type).coerce_to(
                        PyrexTypes.c_uchar_ptr_type, self.current_env())
                char_count = ExprNodes.IntNode(
                    literal_pos, value=str(len(latin1_value)),
                    constant_result=len(latin1_value),
                    type=PyrexTypes.c_py_ssize_t_type)
                literal_slice = ExprNodes.SliceIndexNode(
                    literal_pos,
                    base=char_data,
                    start=None,
                    stop=char_count,
                    type=Builtin.unicode_type,  # hint for Python conversion
                )
                return self._transform_carray_iteration(node, literal_slice, reversed)

        string_ref = UtilNodes.LetRefNode(
            slice_node.as_none_safe_node("'NoneType' is not iterable"))
        string_pos = slice_node.pos
        target_pos = node.target.pos

        zero_index = ExprNodes.IntNode(
            node.pos, value='0', constant_result=0, type=PyrexTypes.c_py_ssize_t_type)
        length_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        length_ref = length_temp.ref(node.pos)
        if reversed:
            # run from the length down to zero
            relation1, relation2 = '>', '>='
            first_bound, last_bound = length_ref, zero_index
        else:
            relation1, relation2 = '<=', '<'
            first_bound, last_bound = zero_index, length_ref

        kind_temp = UtilNodes.TempHandle(PyrexTypes.c_int_type)
        data_temp = UtilNodes.TempHandle(PyrexTypes.c_void_ptr_type)
        counter_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)

        # target = __Pyx_PyUnicode_READ(kind, data, counter)
        char_value = ExprNodes.PythonCapiCallNode(
            string_pos, "__Pyx_PyUnicode_READ",
            self.PyUnicode_READ_func_type,
            args=[kind_temp.ref(string_pos),
                  data_temp.ref(string_pos),
                  counter_temp.ref(target_pos)],
            is_temp=False,
            )
        if char_value.type != node.target.type:
            char_value = char_value.coerce_to(node.target.type,
                                              self.current_env())
        assign_char = Nodes.SingleAssignmentNode(
            pos=target_pos,
            lhs=node.target,
            rhs=char_value)
        loop_body = Nodes.StatListNode(
            node.pos,
            stats=[assign_char, node.body])

        index_loop = Nodes.ForFromStatNode(
            node.pos,
            bound1=first_bound, relation1=relation1,
            target=counter_temp.ref(target_pos),
            relation2=relation2, bound2=last_bound,
            step=None, body=loop_body,
            else_clause=node.else_clause,
            from_range=True)

        # __Pyx_init_unicode_iteration(s, &length, &data, &kind)
        init_args = [
            string_ref,
            ExprNodes.AmpersandNode(string_pos, operand=length_temp.ref(string_pos),
                                    type=PyrexTypes.c_py_ssize_t_ptr_type),
            ExprNodes.AmpersandNode(string_pos, operand=data_temp.ref(string_pos),
                                    type=PyrexTypes.c_void_ptr_ptr_type),
            ExprNodes.AmpersandNode(string_pos, operand=kind_temp.ref(string_pos),
                                    type=PyrexTypes.c_int_ptr_type),
        ]
        init_call = ExprNodes.PythonCapiCallNode(
            string_pos, "__Pyx_init_unicode_iteration",
            self.init_unicode_iteration_func_type,
            args=init_args,
            is_temp=True,
            result_is_used=False,
            utility_code=UtilityCode.load_cached("unicode_iter", "Optimize.c"),
            )
        init_stat = Nodes.ExprStatNode(node.pos, expr=init_call)

        temps_block = UtilNodes.TempsBlockNode(
            node.pos, temps=[counter_temp, length_temp, data_temp, kind_temp],
            body=Nodes.StatListNode(node.pos, stats=[init_stat, index_loop]))
        return UtilNodes.LetNode(string_ref, temps_block)
378

379
    def _transform_carray_iteration(self, node, slice_node, reversed=False):
        """Turn iteration over a C array or pointer slice into a pointer loop.

        ``slice_node`` may be a ``SliceIndexNode``, an ``IndexNode`` whose
        index is a ``SliceNode``, or a node of C array type with known size.
        The loop is rewritten as a ``ForFromStatNode`` that moves a temporary
        pointer from the start to the stop position and assigns the item it
        points at to the loop target.

        Returns the original ``node`` unchanged when the bounds or the step
        cannot be determined; a compile error is reported in that case unless
        the iterated base is a Python object.
        """
        neg_step = False
        if isinstance(slice_node, ExprNodes.SliceIndexNode):
            slice_base = slice_node.base
            start = filter_none_node(slice_node.start)
            stop = filter_none_node(slice_node.stop)
            step = None
            if not stop:
                if not slice_base.type.is_pyobject:
                    error(slice_node.pos, "C array iteration requires known end index")
                return node

        elif isinstance(slice_node, ExprNodes.IndexNode):
            assert isinstance(slice_node.index, ExprNodes.SliceNode)
            slice_base = slice_node.base
            index = slice_node.index
            start = filter_none_node(index.start)
            stop = filter_none_node(index.stop)
            step = filter_none_node(index.step)
            if step:
                # 'long' only exists in Python 2; fall back to plain 'int'
                # so that the check does not raise a NameError in Python 3
                try:
                    integer_types = (int, long)
                except NameError:
                    integer_types = (int,)
                if not isinstance(step.constant_result, integer_types) \
                       or step.constant_result == 0 \
                       or step.constant_result > 0 and not stop \
                       or step.constant_result < 0 and not start:
                    if not slice_base.type.is_pyobject:
                        error(step.pos, "C array iteration requires known step size and end index")
                    return node
                else:
                    # step sign is handled internally by ForFromStatNode
                    step_value = step.constant_result
                    if reversed:
                        step_value = -step_value
                    neg_step = step_value < 0
                    step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type,
                                             value=str(abs(step_value)),
                                             constant_result=abs(step_value))

        elif slice_node.type.is_array:
            if slice_node.type.size is None:
                error(slice_node.pos, "C array iteration requires known end index")
                return node
            # iterate over the complete array: [0:size]
            slice_base = slice_node
            start = None
            stop = ExprNodes.IntNode(
                slice_node.pos, value=str(slice_node.type.size),
                type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size)
            step = None

        else:
            if not slice_node.type.is_pyobject:
                error(slice_node.pos, "C array iteration requires known end index")
            return node

        if start:
            start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        if stop:
            stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        if stop is None:
            if neg_step:
                # a negative step without stop index runs down to index -1
                stop = ExprNodes.IntNode(
                    slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1)
            else:
                error(slice_node.pos, "C array iteration requires known step size and end index")
                return node

        if reversed:
            if not start:
                start = ExprNodes.IntNode(slice_node.pos, value="0",  constant_result=0,
                                          type=PyrexTypes.c_py_ssize_t_type)
            # if step was provided, it was already negated above
            start, stop = stop, start

        ptr_type = slice_base.type
        if ptr_type.is_array:
            ptr_type = ptr_type.element_ptr_type()
        carray_ptr = slice_base.coerce_to_simple(self.current_env())

        # pointer to the first item: base + start (or just base for start 0)
        if start and start.constant_result != 0:
            start_ptr_node = ExprNodes.AddNode(
                start.pos,
                operand1=carray_ptr,
                operator='+',
                operand2=start,
                type=ptr_type)
        else:
            start_ptr_node = carray_ptr

        # pointer to the end position: base + stop (or just base for stop 0)
        if stop and stop.constant_result != 0:
            stop_ptr_node = ExprNodes.AddNode(
                stop.pos,
                operand1=ExprNodes.CloneNode(carray_ptr),
                operator='+',
                operand2=stop,
                type=ptr_type
                ).coerce_to_simple(self.current_env())
        else:
            stop_ptr_node = ExprNodes.CloneNode(carray_ptr)

        counter = UtilNodes.TempHandle(ptr_type)
        counter_temp = counter.ref(node.target.pos)

        if slice_base.type.is_string and node.target.type.is_pyobject:
            # special case: char* -> bytes/unicode
            if slice_node.type is Builtin.unicode_type:
                target_value = ExprNodes.CastNode(
                    ExprNodes.DereferenceNode(
                        node.target.pos, operand=counter_temp,
                        type=ptr_type.base_type),
                    PyrexTypes.c_py_ucs4_type).coerce_to(
                        node.target.type, self.current_env())
            else:
                # char* -> bytes coercion requires slicing, not indexing
                target_value = ExprNodes.SliceIndexNode(
                    node.target.pos,
                    start=ExprNodes.IntNode(node.target.pos, value='0',
                                            constant_result=0,
                                            type=PyrexTypes.c_int_type),
                    stop=ExprNodes.IntNode(node.target.pos, value='1',
                                           constant_result=1,
                                           type=PyrexTypes.c_int_type),
                    base=counter_temp,
                    type=Builtin.bytes_type,
                    is_temp=1)
        elif node.target.type.is_ptr and not node.target.type.assignable_from(ptr_type.base_type):
            # Allow iteration with pointer target to avoid copy.
            target_value = counter_temp
        else:
            # TODO: can this safely be replaced with DereferenceNode() as above?
            target_value = ExprNodes.IndexNode(
                node.target.pos,
                index=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                base=counter_temp,
                is_buffer_access=False,
                type=ptr_type.base_type)

        if target_value.type != node.target.type:
            target_value = target_value.coerce_to(node.target.type,
                                                  self.current_env())

        target_assign = Nodes.SingleAssignmentNode(
            pos = node.target.pos,
            lhs = node.target,
            rhs = target_value)

        # assign the current item first, then run the original loop body
        body = Nodes.StatListNode(
            node.pos,
            stats = [target_assign, node.body])

        relation1, relation2 = self._find_for_from_node_relations(neg_step, reversed)

        for_node = Nodes.ForFromStatNode(
            node.pos,
            bound1=start_ptr_node, relation1=relation1,
            target=counter_temp,
            relation2=relation2, bound2=stop_ptr_node,
            step=step, body=body,
            else_clause=node.else_clause,
            from_range=True)

        return UtilNodes.TempsBlockNode(
            node.pos, temps=[counter],
            body=for_node)

544 545 546 547 548 549
    def _transform_enumerate_iteration(self, node, enumerate_function):
        """Replace a loop over ``enumerate(...)`` by an external counter.

        Reports a compile error and returns the loop unchanged unless
        ``enumerate()`` was called with one or two arguments.  The loop is
        also left alone unless its target unpacks into exactly two items and
        the counter target is a Python object or a C integer.  Otherwise two
        assignments are put in front of the loop body (counter target =
        counter; counter = counter + 1), the loop is pointed at the
        enumerated sequence directly and optimised further by
        ``_optimise_for_loop()``.
        """
        args = enumerate_function.arg_tuple.args
        arg_count = len(args)
        if arg_count == 0 or arg_count > 2:
            if arg_count == 0:
                message = "enumerate() requires an iterable argument"
            else:
                message = "enumerate() takes at most 2 arguments"
            error(enumerate_function.pos, message)
            return node

        loop_target = node.target
        if not loop_target.is_sequence_constructor or len(loop_target.args) != 2:
            # leave this untouched for now
            return node

        enumerate_target, iterable_target = loop_target.args
        counter_type = enumerate_target.type
        if not (counter_type.is_pyobject or counter_type.is_int):
            # nothing we can do here, I guess
            return node

        # initial counter value: explicit start argument or 0
        if arg_count == 2:
            start = unwrap_coerced_node(args[1]).coerce_to(counter_type, self.current_env())
        else:
            start = ExprNodes.IntNode(enumerate_function.pos,
                                      value='0',
                                      type=counter_type,
                                      constant_result=0)
        counter = UtilNodes.LetRefNode(start)

        one = ExprNodes.IntNode(node.pos, value='1',
                                type=counter_type,
                                constant_result=1)
        next_count = ExprNodes.AddNode(
            enumerate_function.pos,
            operand1=counter,
            operand2=one,
            operator='+',
            type=counter_type,
            #inplace = True,   # not worth using in-place operation for Py ints
            is_temp=counter_type.is_pyobject
            )

        counter_stats = [
            Nodes.SingleAssignmentNode(
                pos=enumerate_target.pos,
                lhs=enumerate_target,
                rhs=counter),
            Nodes.SingleAssignmentNode(
                pos=enumerate_target.pos,
                lhs=counter,
                rhs=next_count),
        ]

        # run the counter statements before the original loop body
        original_body = node.body
        if isinstance(original_body, Nodes.StatListNode):
            original_body.stats = counter_stats + original_body.stats
        else:
            node.body = Nodes.StatListNode(
                original_body.pos,
                stats=counter_stats + [original_body])

        node.target = iterable_target
        node.item = node.item.coerce_to(iterable_target.type, self.current_env())
        node.iterator.sequence = args[0]

        # recurse into loop to check for further optimisations
        optimised_loop = self._optimise_for_loop(node, node.iterator.sequence)
        return UtilNodes.LetNode(counter, optimised_loop)
616

617 618 619 620 621 622 623 624 625 626 627
    def _find_for_from_node_relations(self, neg_step_value, reversed):
        if reversed:
            if neg_step_value:
                return '<', '<='
            else:
                return '>', '>='
        else:
            if neg_step_value:
                return '>=', '>'
            else:
                return '<=', '<'
628

629
    def _transform_range_iteration(self, node, range_function, reversed=False):
630 631 632 633
        args = range_function.arg_tuple.args
        if len(args) < 3:
            step_pos = range_function.pos
            step_value = 1
634 635
            step = ExprNodes.IntNode(step_pos, value='1',
                                     constant_result=1)
636 637 638
        else:
            step = args[2]
            step_pos = step.pos
639
            if not isinstance(step.constant_result, (int, long)):
640 641
                # cannot determine step direction
                return node
642 643 644
            step_value = step.constant_result
            if step_value == 0:
                # will lead to an error elsewhere
645
                return node
646 647 648
            if reversed and step_value not in (1, -1):
                # FIXME: currently broken - requires calculation of the correct bounds
                return node
649
            if not isinstance(step, ExprNodes.IntNode):
650 651
                step = ExprNodes.IntNode(step_pos, value=str(step_value),
                                         constant_result=step_value)
652 653

        if len(args) == 1:
654 655
            bound1 = ExprNodes.IntNode(range_function.pos, value='0',
                                       constant_result=0)
656
            bound2 = args[0].coerce_to_integer(self.current_env())
657
        else:
658 659
            bound1 = args[0].coerce_to_integer(self.current_env())
            bound2 = args[1].coerce_to_integer(self.current_env())
660

661 662
        relation1, relation2 = self._find_for_from_node_relations(step_value < 0, reversed)

663 664 665 666 667 668 669 670 671 672
        if reversed:
            bound1, bound2 = bound2, bound1
            if step_value < 0:
                step_value = -step_value
        else:
            if step_value < 0:
                step_value = -step_value

        step.value = str(step_value)
        step.constant_result = step_value
673
        step = step.coerce_to_integer(self.current_env())
674

675
        if not bound2.is_literal:
676 677 678 679 680 681
            # stop bound must be immutable => keep it in a temp var
            bound2_is_temp = True
            bound2 = UtilNodes.LetRefNode(bound2)
        else:
            bound2_is_temp = False

682 683 684 685 686 687 688
        for_node = Nodes.ForFromStatNode(
            node.pos,
            target=node.target,
            bound1=bound1, relation1=relation1,
            relation2=relation2, bound2=bound2,
            step=step, body=node.body,
            else_clause=node.else_clause,
Magnus Lie Hetland's avatar
Magnus Lie Hetland committed
689
            from_range=True)
690 691 692 693

        if bound2_is_temp:
            for_node = UtilNodes.LetNode(bound2, for_node)

694 695
        return for_node

696
    def _transform_dict_iteration(self, node, dict_obj, method, keys, values):
        """Replace a for-in loop over a dict by a while loop on temps.

        node      -- the for-in loop node that gets replaced
        dict_obj  -- the expression node that is being iterated over
        method    -- name of the iteration method called on dict_obj,
                     or a false value for plain iteration
        keys, values -- flags that select what the loop target receives

        Returns a TempsBlockNode that wraps the new loop, or the unchanged
        node if the loop target cannot be split into a key and a value.
        """
        loop_pos = node.pos
        handles = []

        dict_handle = UtilNodes.TempHandle(PyrexTypes.py_object_type)
        handles.append(dict_handle)
        dict_temp = dict_handle.ref(dict_obj.pos)

        pos_handle = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        handles.append(pos_handle)
        pos_temp = pos_handle.ref(loop_pos)

        # decide which of the three possible targets the loop assigns to
        target = node.target
        key_target = value_target = tuple_target = None
        if not keys:
            value_target = target
        elif not values:
            key_target = target
        elif not target.is_sequence_constructor:
            tuple_target = target
        elif len(target.args) == 2:
            key_target, value_target = target.args
        else:
            # unusual case that may or may not lead to an error
            return node

        loop_body = node.body
        if not isinstance(loop_body, Nodes.StatListNode):
            loop_body = Nodes.StatListNode(pos = loop_body.pos,
                                           stats = [loop_body])

        # keep original length to guard against dict modification
        len_handle = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        handles.append(len_handle)
        dict_len_addr = ExprNodes.AmpersandNode(
            loop_pos, operand=len_handle.ref(dict_obj.pos),
            type=PyrexTypes.c_ptr_type(len_handle.type))

        is_dict_handle = UtilNodes.TempHandle(PyrexTypes.c_int_type)
        handles.append(is_dict_handle)
        is_dict_temp = is_dict_handle.ref(loop_pos)
        is_dict_addr = ExprNodes.AmpersandNode(
            loop_pos, operand=is_dict_temp,
            type=PyrexTypes.c_ptr_type(is_dict_handle.type))

        # the step node becomes the first statement of the loop body
        next_node = Nodes.DictIterationNextNode(
            dict_temp, len_handle.ref(dict_obj.pos), pos_temp,
            key_target, value_target, tuple_target,
            is_dict_temp)
        next_node = next_node.analyse_expressions(self.current_env())
        loop_body.stats.insert(0, next_node)

        if method:
            method_node = ExprNodes.StringNode(
                dict_obj.pos, is_identifier=True, value=method)
            dict_obj = dict_obj.as_none_safe_node(
                "'NoneType' object has no attribute '%s'",
                error = "PyExc_AttributeError",
                format_args = [method])
        else:
            method_node = ExprNodes.NullNode(dict_obj.pos)
            dict_obj = dict_obj.as_none_safe_node("'NoneType' object is not iterable")

        # evaluated on the None-checked node, just like the call arguments
        is_dict_flag = dict_obj.type is Builtin.dict_type and 1 or 0

        reset_position = Nodes.SingleAssignmentNode(
            loop_pos,
            lhs = pos_temp,
            rhs = ExprNodes.IntNode(loop_pos, value='0',
                                    constant_result=0))

        iterator_call = ExprNodes.PythonCapiCallNode(
            dict_obj.pos,
            "__Pyx_dict_iterator",
            self.PyDict_Iterator_func_type,
            utility_code = UtilityCode.load_cached("dict_iter", "Optimize.c"),
            args = [
                dict_obj,
                ExprNodes.IntNode(loop_pos, value=str(is_dict_flag),
                                  constant_result=is_dict_flag),
                method_node,
                dict_len_addr,
                is_dict_addr,
                ],
            is_temp=True,
        )
        setup_iterator = Nodes.SingleAssignmentNode(
            dict_obj.pos,
            lhs = dict_temp,
            rhs = iterator_call)

        while_loop = Nodes.WhileStatNode(
            loop_pos,
            condition = None,
            body = loop_body,
            else_clause = node.else_clause)

        return UtilNodes.TempsBlockNode(
            loop_pos, temps=handles,
            body=Nodes.StatListNode(
                loop_pos,
                stats = [reset_position, setup_iterator, while_loop]))

795 796 797 798 799 800 801 802 803
    # C function type used for the "__Pyx_dict_iterator" call that
    # _transform_dict_iteration() generates.  The call returns a Python
    # object and receives the iterated object, an "is_dict" int flag, the
    # method name object and two output pointers (original length and
    # "is_dict" flag).
    PyDict_Iterator_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("dict",  PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("is_dict",  PyrexTypes.c_int_type, None),
            PyrexTypes.CFuncTypeArg("method_name",  PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("p_orig_length",  PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("p_is_dict",  PyrexTypes.c_int_ptr_type, None),
            ])

804

805 806
class SwitchTransform(Visitor.VisitorTransform):
    """
    Turn chains of integer equality tests into C switch statements.

    A condition can be converted when every clause is a comparison
    'var == value' (or an 'or' of such comparisons), all clauses test
    the same var, and both var and the values are ints or enums.
    """
    NO_MATCH = (None, None, None)

    def extract_conditions(self, cond, allow_not_in):
        """Split a condition into (not_in, tested node, list of values).

        Returns NO_MATCH when the condition has no supported shape.
        'allow_not_in' enables the negated forms ('!=', 'not in', 'and').
        """
        # peel off wrapper nodes around the actual test
        unwrapping = True
        while unwrapping:
            if isinstance(cond, (ExprNodes.CoerceToTempNode,
                                 ExprNodes.CoerceToBooleanNode)):
                cond = cond.arg
            elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
                # this is what we get from the FlattenInListTransform
                cond = cond.subexpression
            elif isinstance(cond, ExprNodes.TypecastNode):
                cond = cond.operand
            else:
                unwrapping = False

        if isinstance(cond, ExprNodes.PrimaryCmpNode):
            if cond.cascade is not None:
                return self.NO_MATCH

            string_types = (ExprNodes.UnicodeNode, ExprNodes.BytesNode)
            if cond.is_c_string_contains() and \
                   isinstance(cond.operand2, string_types):
                literal = cond.operand2
                not_in = cond.operator == 'not_in'
                if not_in and not allow_not_in:
                    return self.NO_MATCH
                if isinstance(literal, ExprNodes.UnicodeNode) and \
                       literal.contains_surrogates():
                    # dealing with surrogates leads to different
                    # behaviour on wide and narrow Unicode
                    # platforms => refuse to optimise this case
                    return self.NO_MATCH
                return not_in, cond.operand1, self.extract_in_string_conditions(literal)

            if cond.is_python_comparison():
                return self.NO_MATCH

            operator = cond.operator
            if operator == '==':
                not_in = False
            elif allow_not_in and operator == '!=':
                not_in = True
            else:
                return self.NO_MATCH

            # try "var == value" first, then "value == var"
            for var, value in ((cond.operand1, cond.operand2),
                               (cond.operand2, cond.operand1)):
                # comparing a node with itself looks somewhat silly, but
                # it does the right checks for NameNode and AttributeNode
                if not is_common_value(var, var):
                    continue
                if value.is_literal:
                    return not_in, var, [value]
                entry = getattr(value, 'entry', None)
                if entry and entry.is_const:
                    return not_in, var, [value]
            return self.NO_MATCH

        if isinstance(cond, ExprNodes.BoolBinopNode):
            is_and = cond.operator == 'and'
            if cond.operator == 'or' or (allow_not_in and is_and):
                # negated tests are only accepted below an 'and'
                not_in_1, t1, c1 = self.extract_conditions(cond.operand1, is_and)
                not_in_2, t2, c2 = self.extract_conditions(cond.operand2, is_and)
                if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2):
                    if (not not_in_1) or is_and:
                        return not_in_1, t1, c1+c2

        return self.NO_MATCH

    def extract_in_string_conditions(self, string_literal):
        """Return one sorted, duplicate free value node per character."""
        pos = string_literal.pos
        if isinstance(string_literal, ExprNodes.UnicodeNode):
            return [ ExprNodes.IntNode(pos, value=str(charval),
                                       constant_result=charval)
                     for charval in sorted(set(map(ord, string_literal.value))) ]

        # this is a bit tricky as Py3's bytes type returns
        # integers on iteration, whereas Py2 returns 1-char byte
        # strings => always slice out single characters
        data = string_literal.value
        single_chars = set([ data[i:i+1] for i in range(len(data)) ])
        return [ ExprNodes.CharNode(pos, value=charval,
                                    constant_result=charval)
                 for charval in sorted(single_chars) ]

    def extract_common_conditions(self, common_var, condition, allow_not_in):
        """Like extract_conditions(), with additional consistency checks.

        The tested node must match 'common_var' (unless that is None) and
        the tested node and all values must have an int or enum type.
        """
        def is_switchable(expr_node):
            node_type = expr_node.type
            return node_type.is_int or node_type.is_enum

        not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
        if var is None:
            return self.NO_MATCH
        if common_var is not None and not is_common_value(var, common_var):
            return self.NO_MATCH
        if not is_switchable(var):
            return self.NO_MATCH
        unsuitable = [ cond for cond in conditions if not is_switchable(cond) ]
        if unsuitable:
            return self.NO_MATCH
        return not_in, var, conditions

    def has_duplicate_values(self, condition_values):
        """Tell if two value nodes would end up as the same switch case."""
        # duplicated values don't work in a switch statement
        seen = set()
        for value in condition_values:
            if value.has_constant_result():
                key = value.constant_result
            else:
                # this isn't completely safe as we don't know the
                # final C value, but this is about the best we can do
                try:
                    key = value.entry.cname
                except AttributeError:
                    return True  # play safe
            if key in seen:
                return True
            seen.add(key)
        return False

    def visit_IfStatNode(self, node):
        common_var = None
        cases = []
        for if_clause in node.if_clauses:
            _, common_var, conditions = self.extract_common_conditions(
                common_var, if_clause.condition, False)
            if common_var is None:
                # at least one clause does not fit => keep the if statement
                self.visitchildren(node)
                return node
            cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
                                              conditions = conditions,
                                              body = if_clause.body))

        condition_values = []
        for case in cases:
            condition_values.extend(case.conditions)
        if len(condition_values) < 2 \
               or self.has_duplicate_values(condition_values):
            self.visitchildren(node)
            return node

        return Nodes.SwitchStatNode(pos = node.pos,
                                    test = unwrap_node(common_var),
                                    cases = cases,
                                    else_clause = node.else_clause)

    def visit_CondExprNode(self, node):
        not_in, common_var, conditions = self.extract_common_conditions(
            None, node.test, True)
        if common_var is not None \
               and len(conditions) >= 2 \
               and not self.has_duplicate_values(conditions):
            return self.build_simple_switch_statement(
                node, common_var, conditions, not_in,
                node.true_val, node.false_val)
        self.visitchildren(node)
        return node

    def _switch_for_boolean_expression(self, node):
        # Shared implementation for expression nodes that are their own
        # test and evaluate to a plain True/False result.
        not_in, common_var, conditions = self.extract_common_conditions(
            None, node, True)
        if common_var is not None \
               and len(conditions) >= 2 \
               and not self.has_duplicate_values(conditions):
            return self.build_simple_switch_statement(
                node, common_var, conditions, not_in,
                ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
                ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
        self.visitchildren(node)
        return node

    def visit_BoolBinopNode(self, node):
        return self._switch_for_boolean_expression(node)

    def visit_PrimaryCmpNode(self, node):
        return self._switch_for_boolean_expression(node)

    def build_simple_switch_statement(self, node, common_var, conditions,
                                      not_in, true_val, false_val):
        """Build a switch with a single case that computes a result value.

        The result is 'true_val' if one of the conditions matches and
        'false_val' otherwise; 'not_in' swaps the two.
        """
        result_ref = UtilNodes.ResultRefNode(node)

        def assign_result(value):
            return Nodes.SingleAssignmentNode(
                node.pos,
                lhs = result_ref,
                rhs = value,
                first = True)

        if not_in:
            on_match, otherwise = false_val, true_val
        else:
            on_match, otherwise = true_val, false_val

        switch_node = Nodes.SwitchStatNode(
            pos = node.pos,
            test = unwrap_node(common_var),
            cases = [Nodes.SwitchCaseNode(pos = node.pos,
                                          conditions = conditions,
                                          body = assign_result(on_match))],
            else_clause = assign_result(otherwise))
        return UtilNodes.TempResultFromStatNode(result_ref, switch_node)

    def visit_EvalWithTempExprNode(self, node):
        # drop unused expression temp from FlattenInListTransform
        original = node.subexpression
        temp_ref = node.lazy_temp
        self.visitchildren(node)
        replaced = node.subexpression
        if replaced is original:
            return node
        # node was restructured => check if temp is still used
        if Visitor.tree_contains(replaced, temp_ref):
            return node
        return replaced

    visit_Node = Visitor.VisitorTransform.recurse_to_children
1029

1030

1031
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    Rewrite "x in [val1, ..., valn]" (and the "not in" form) as a chain
    of single comparisons of x against each value.
    """

    def visit_PrimaryCmpNode(self, node):
        self.visitchildren(node)
        if node.cascade is not None:
            return node

        operator = node.operator
        if operator == 'in':
            conjunction, eq_or_neq = 'or', '=='
        elif operator == 'not_in':
            conjunction, eq_or_neq = 'and', '!='
        else:
            return node

        sequence = node.operand2
        if not isinstance(sequence, (ExprNodes.TupleNode,
                                     ExprNodes.ListNode,
                                     ExprNodes.SetNode)):
            return node

        args = sequence.args
        if len(args) == 0:
            # note: lhs may have side effects
            return node

        lhs = UtilNodes.ResultRefNode(node.operand1)

        condition = None
        temps = []
        for arg in args:
            try:
                # Trial optimisation to avoid redundant temp
                # assignments.  However, since is_simple() is meant to
                # be called after type analysis, we ignore any errors
                # and just play safe in that case.
                is_simple_arg = arg.is_simple()
            except Exception:
                is_simple_arg = False
            if not is_simple_arg:
                # must evaluate all non-simple RHS before doing the comparisons
                arg = UtilNodes.LetRefNode(arg)
                temps.append(arg)

            comparison = ExprNodes.PrimaryCmpNode(
                pos = node.pos,
                operand1 = lhs,
                operator = eq_or_neq,
                operand2 = arg,
                cascade = None)
            test = ExprNodes.TypecastNode(
                pos = node.pos,
                operand = comparison,
                type = PyrexTypes.c_bint_type)

            # chain the tests from left to right
            if condition is None:
                condition = test
            else:
                condition = ExprNodes.BoolBinopNode(
                    pos = node.pos,
                    operator = conjunction,
                    operand1 = condition,
                    operand2 = test)

        new_node = UtilNodes.EvalWithTempExprNode(lhs, condition)
        # wrap in reverse order so that the first temp ends up outermost
        for temp in reversed(temps):
            new_node = UtilNodes.EvalWithTempExprNode(temp, new_node)
        return new_node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
1101 1102


1103 1104 1105 1106 1107 1108
class DropRefcountingTransform(Visitor.VisitorTransform):
    """Switch off ref-counting for assignments where that is safe.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ParallelAssignmentNode(self, node):
        """
        Parallel swap assignments like 'a,b = b,a' are safe.
        """
        left_names, right_names = [], []
        left_indices, right_indices = [], []
        temps = []

        for stat in node.stats:
            if not isinstance(stat, Nodes.SingleAssignmentNode):
                # FIXME: cascaded assignments are not handled either
                return node
            if not self._extract_operand(stat.lhs, left_names,
                                         left_indices, temps):
                return node
            if not self._extract_operand(stat.rhs, right_names,
                                         right_indices, temps):
                return node

        if left_names or right_names:
            # lhs/rhs names must be a non-redundant permutation
            left_paths = set([ path for path, _ in left_names ])
            right_paths = set([ path for path, _ in right_names ])
            if left_paths != right_paths:
                return node
            if len(left_paths) != len(right_names):
                return node

        if left_indices or right_indices:
            # base name and index of index nodes must be a
            # non-redundant permutation
            index_ids = []
            for index_node in left_indices + right_indices:
                index_id = self._extract_index_id(index_node)
                if not index_id:
                    return node
                index_ids.append(index_id)
            left_ids = set(index_ids[:len(left_indices)])
            right_ids = set(index_ids[len(left_indices):])

            if left_ids != right_ids:
                return node
            if len(left_ids) != len(right_indices):
                return node

            # really supporting IndexNode requires support in
            # __Pyx_GetItemInt(), so let's stop short for now
            return node

        temp_args = [ temp.arg for temp in temps ]
        for temp in temps:
            temp.use_managed_ref = False

        for _, name_node in left_names + right_names:
            # names wrapped in a temp are covered by the temp itself
            if name_node not in temp_args:
                name_node.use_managed_ref = False

        for index_node in left_indices + right_indices:
            index_node.use_managed_ref = False

        return node

    def _extract_operand(self, node, names, indices, temps):
        """Sort one assignment operand into 'names', 'indices' or 'temps'.

        Returns False if the operand is not supported, True otherwise.
        """
        node = unwrap_node(node)
        if not node.type.is_pyobject:
            return False
        if isinstance(node, ExprNodes.CoerceToTempNode):
            temps.append(node)
            node = node.arg

        # walk down attribute chains, collecting the path innermost last
        name_path = []
        obj_node = node
        while isinstance(obj_node, ExprNodes.AttributeNode):
            if obj_node.is_py_attr:
                return False
            name_path.append(obj_node.member)
            obj_node = obj_node.obj

        if isinstance(obj_node, ExprNodes.NameNode):
            name_path.append(obj_node.name)
            name_path.reverse()
            names.append( ('.'.join(name_path), node) )
            return True

        if not isinstance(node, ExprNodes.IndexNode):
            return False
        if node.base.type != Builtin.list_type:
            return False
        if not node.index.type.is_int:
            return False
        if not isinstance(node.base, ExprNodes.NameNode):
            return False
        indices.append(node)
        return True

    def _extract_index_id(self, index_node):
        """Return a (base name, index name) pair, or None if unsupported."""
        index = index_node.index
        if not isinstance(index, ExprNodes.NameNode):
            # FIXME: constant indices are not supported either
            return None
        return (index_node.base.name, index.name)

1218

1219 1220 1221 1222 1223 1224 1225
class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
    """Optimize some common calls to builtin types *before* the type
    analysis phase and *after* the declarations analysis phase.

    This transform cannot make use of any argument types, but it can
    restructure the tree in a way that the type analysis phase can
    respond to.
Stefan Behnel's avatar
Stefan Behnel committed
1226 1227 1228

    Introducing C function calls here may not be a good idea.  Move
    them to the OptimizeBuiltinCalls transform instead, which runs
Stefan Behnel's avatar
Stefan Behnel committed
1229
    after type analysis.
1230
    """
1231 1232
    # only intercept on call nodes
    visit_Node = Visitor.VisitorTransform.recurse_to_children
1233

1234 1235 1236 1237 1238 1239 1240
    def visit_SimpleCallNode(self, node):
        """Hand calls to builtin names over to a matching handler method.

        The children are visited first.  Calls to anything that is not
        a builtin name are returned unchanged.
        """
        self.visitchildren(node)
        callee = node.function
        if self._function_is_builtin_name(callee):
            return self._dispatch_to_handler(node, callee, node.args)
        return node

1241
    def visit_GeneralCallNode(self, node):
        """Hand keyword calls to builtin names over to a handler method.

        The children are visited first.  The node is returned unchanged
        unless the function is a builtin name and the positional
        arguments are available as a TupleNode.
        """
        self.visitchildren(node)
        callee = node.function
        if not self._function_is_builtin_name(callee):
            return node
        positional = node.positional_args
        if isinstance(positional, ExprNodes.TupleNode):
            return self._dispatch_to_handler(
                node, callee, positional.args, node.keyword_args)
        return node
1252

1253 1254 1255
    def _function_is_builtin_name(self, function):
        if not function.is_name:
            return False
1256
        env = self.current_env()
1257
        entry = env.lookup(function.name)
1258
        if entry is not env.builtin_scope().lookup_here(function.name):
1259
            return False
1260
        # if entry is None, it's at least an undeclared name, so likely builtin
1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298
        return True

    def _dispatch_to_handler(self, node, function, args, kwargs=None):
        if kwargs is None:
            handler_name = '_handle_simple_function_%s' % function.name
        else:
            handler_name = '_handle_general_function_%s' % function.name
        handle_call = getattr(self, handler_name, None)
        if handle_call is not None:
            if kwargs is None:
                return handle_call(node, args)
            else:
                return handle_call(node, args, kwargs)
        return node

    def _inject_capi_function(self, node, cname, func_type, utility_code=None):
        """Replace the function of a call node by a C-API function node.

        The new node keeps the position and the name of the old function.
        """
        old_function = node.function
        node.function = ExprNodes.PythonCapiFunctionNode(
            old_function.pos, old_function.name, cname, func_type,
            utility_code = utility_code)

    def _error_wrong_arg_count(self, function_name, node, args, expected=None):
        """Report a call with a wrong number of arguments at node.pos.

        'expected' may be a number, a string or None; it determines how
        the call signature is spelled in the error message.
        """
        arg_str = ''
        if expected: # neither None nor 0
            if isinstance(expected, basestring) or expected > 1:
                arg_str = '...'
            elif expected == 1:
                arg_str = 'x'

        expected_str = ''
        if expected is not None:
            expected_str = 'expected %s, ' % expected

        message = "%s(%s) called with wrong number of args, %sfound %d" % (
            function_name, arg_str, expected_str, len(args))
        error(node.pos, message)

    # specific handlers for simple call nodes

1299
    def _handle_simple_function_float(self, node, pos_args):
        """Simplify float() calls.

        float() without arguments becomes the literal 0.0, and an
        argument that already is a C double or a Python float replaces
        the call.  More than one argument is reported as an error.
        """
        arg_count = len(pos_args)
        if arg_count == 0:
            return ExprNodes.FloatNode(node.pos, value='0.0')
        if arg_count > 1:
            self._error_wrong_arg_count('float', node, pos_args, 1)
        first_arg = pos_args[0]
        float_types = (PyrexTypes.c_double_type, Builtin.float_type)
        if getattr(first_arg, 'type', None) in float_types:
            return first_arg
        return node

1309 1310 1311
    class YieldNodeCollector(Visitor.TreeVisitor):
        """Collects yield expressions and the statements that hold them.

        NOTE: the double underscore prefix of the visit methods below
        makes Python mangle their names, so they do not follow the
        visit_*() naming of the other visitors in this file -- presumably
        this is how they are disabled, see the XXX comment.
        """
        def __init__(self):
            Visitor.TreeVisitor.__init__(self)
            self.yield_nodes = []
            self.yield_stat_nodes = {}

        visit_Node = Visitor.TreeVisitor.visitchildren

        # XXX: disable inlining while it's not back supported
        def __visit_YieldExprNode(self, node):
            self.yield_nodes.append(node)
            self.visitchildren(node)

        def __visit_ExprStatNode(self, node):
            self.visitchildren(node)
            expr = node.expr
            if expr in self.yield_nodes:
                # remember the statement that wraps the yield expression
                self.yield_stat_nodes[expr] = node

        def __visit_GeneratorExpressionNode(self, node):
            # enable when we support generic generator expressions
            #
            # everything below this node is out of scope
            pass

1332
    def _find_single_yield_expression(self, node):
1333 1334 1335
        collector = self.YieldNodeCollector()
        collector.visitchildren(node)
        if len(collector.yield_nodes) != 1:
1336 1337
            return None, None
        yield_node = collector.yield_nodes[0]
1338 1339 1340 1341
        try:
            return (yield_node.arg, collector.yield_stat_nodes[yield_node])
        except KeyError:
            return None, None
1342

1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387
    def _handle_simple_function_all(self, node, pos_args):
        """Transform

        _result = all(x for L in LL for x in L)

        into

        for L in LL:
            for x in L:
                if not x:
                    _result = False
                    break
            else:
                continue
            break
        else:
            _result = True
        """
        return self._transform_any_all(node, pos_args, False)

    def _handle_simple_function_any(self, node, pos_args):
        """Transform

        _result = any(x for L in LL for x in L)

        into

        for L in LL:
            for x in L:
                if x:
                    _result = True
                    break
            else:
                continue
            break
        else:
            _result = False
        """
        return self._transform_any_all(node, pos_args, True)

    def _transform_any_all(self, node, pos_args, is_any):
        if len(pos_args) != 1:
            return node
        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
            return node
1388 1389
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop
1390 1391
        yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
        if yield_expression is None:
1392 1393 1394 1395 1396 1397 1398
            return node

        if is_any:
            condition = yield_expression
        else:
            condition = ExprNodes.NotNode(yield_expression.pos, operand = yield_expression)

1399
        result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.c_bint_type)
1400
        test_node = Nodes.IfStatNode(
1401
            yield_expression.pos,
1402 1403
            else_clause = None,
            if_clauses = [ Nodes.IfClauseNode(
1404
                yield_expression.pos,
1405 1406 1407 1408 1409 1410 1411
                condition = condition,
                body = Nodes.StatListNode(
                    node.pos,
                    stats = [
                        Nodes.SingleAssignmentNode(
                            node.pos,
                            lhs = result_ref,
1412
                            rhs = ExprNodes.BoolNode(yield_expression.pos, value = is_any,
1413 1414 1415 1416 1417 1418 1419 1420 1421
                                                     constant_result = is_any)),
                        Nodes.BreakStatNode(node.pos)
                        ])) ]
            )
        loop = loop_node
        while isinstance(loop.body, Nodes.LoopNode):
            next_loop = loop.body
            loop.body = Nodes.StatListNode(loop.body.pos, stats = [
                loop.body,
1422
                Nodes.BreakStatNode(yield_expression.pos)
1423
                ])
1424
            next_loop.else_clause = Nodes.ContinueStatNode(yield_expression.pos)
1425 1426 1427 1428
            loop = next_loop
        loop_node.else_clause = Nodes.SingleAssignmentNode(
            node.pos,
            lhs = result_ref,
1429
            rhs = ExprNodes.BoolNode(yield_expression.pos, value = not is_any,
1430 1431
                                     constant_result = not is_any))

1432
        Visitor.recursively_replace_node(loop_node, yield_stat_node, test_node)
1433

1434 1435
        return ExprNodes.InlinedGeneratorExpressionNode(
            gen_expr_node.pos, loop = loop_node, result_node = result_ref,
1436
            expr_scope = gen_expr_node.expr_scope, orig_func = is_any and 'any' or 'all')
1437

1438
    def _handle_simple_function_sorted(self, node, pos_args):
1439 1440 1441 1442 1443
        """Transform sorted(genexpr) and sorted([listcomp]) into
        [listcomp].sort().  CPython just reads the iterable into a
        list and calls .sort() on it.  Expanding the iterable in a
        listcomp is still faster and the result can be sorted in
        place.
1444 1445 1446
        """
        if len(pos_args) != 1:
            return node
1447
        if isinstance(pos_args[0], ExprNodes.ComprehensionNode) \
1448
               and pos_args[0].type is Builtin.list_type:
1449 1450 1451 1452 1453 1454 1455 1456
            listcomp_node = pos_args[0]
            loop_node = listcomp_node.loop
        elif isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
            gen_expr_node = pos_args[0]
            loop_node = gen_expr_node.loop
            yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
            if yield_expression is None:
                return node
1457

1458
            append_node = ExprNodes.ComprehensionAppendNode(
1459
                yield_expression.pos, expr = yield_expression)
1460

1461
            Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1462

1463
            listcomp_node = ExprNodes.ComprehensionNode(
1464
                gen_expr_node.pos, loop = loop_node,
1465 1466 1467
                append = append_node, type = Builtin.list_type,
                expr_scope = gen_expr_node.expr_scope,
                has_local_scope = True)
1468
            append_node.target = listcomp_node
1469 1470
        else:
            return node
1471

1472 1473
        result_node = UtilNodes.ResultRefNode(
            pos = loop_node.pos, type = Builtin.list_type, may_hold_none=False)
1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490
        listcomp_assign_node = Nodes.SingleAssignmentNode(
            node.pos, lhs = result_node, rhs = listcomp_node, first = True)

        sort_method = ExprNodes.AttributeNode(
            node.pos, obj = result_node, attribute = EncodedString('sort'),
            # entry ? type ?
            needs_none_check = False)
        sort_node = Nodes.ExprStatNode(
            node.pos, expr = ExprNodes.SimpleCallNode(
                node.pos, function = sort_method, args = []))

        sort_node.analyse_declarations(self.current_env())

        return UtilNodes.TempResultFromStatNode(
            result_node,
            Nodes.StatListNode(node.pos, stats = [ listcomp_assign_node, sort_node ]))

1491
    def _handle_simple_function_sum(self, node, pos_args):
Stefan Behnel's avatar
Stefan Behnel committed
1492 1493
        """Transform sum(genexpr) into an equivalent inlined aggregation loop.
        """
1494 1495
        if len(pos_args) not in (1,2):
            return node
1496 1497
        if not isinstance(pos_args[0], (ExprNodes.GeneratorExpressionNode,
                                        ExprNodes.ComprehensionNode)):
1498 1499 1500 1501
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop

1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515
        if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
            yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
            if yield_expression is None:
                return node
        else: # ComprehensionNode
            yield_stat_node = gen_expr_node.append
            yield_expression = yield_stat_node.expr
            try:
                if not yield_expression.is_literal or not yield_expression.type.is_int:
                    return node
            except AttributeError:
                return node # in case we don't have a type yet
            # special case: old Py2 backwards compatible "sum([int_const for ...])"
            # can safely be unpacked into a genexpr
1516 1517 1518 1519 1520 1521 1522 1523

        if len(pos_args) == 1:
            start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
        else:
            start = pos_args[1]

        result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type)
        add_node = Nodes.SingleAssignmentNode(
1524
            yield_expression.pos,
1525 1526 1527 1528
            lhs = result_ref,
            rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
            )

1529
        Visitor.recursively_replace_node(loop_node, yield_stat_node, add_node)
1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543

        exec_code = Nodes.StatListNode(
            node.pos,
            stats = [
                Nodes.SingleAssignmentNode(
                    start.pos,
                    lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
                    rhs = start,
                    first = True),
                loop_node
                ])

        return ExprNodes.InlinedGeneratorExpressionNode(
            gen_expr_node.pos, loop = exec_code, result_node = result_ref,
1544 1545
            expr_scope = gen_expr_node.expr_scope, orig_func = 'sum',
            has_local_scope = gen_expr_node.has_local_scope)
1546

1547 1548 1549 1550 1551 1552 1553 1554 1555 1556
    def _handle_simple_function_min(self, node, pos_args):
        return self._optimise_min_max(node, pos_args, '<')

    def _handle_simple_function_max(self, node, pos_args):
        return self._optimise_min_max(node, pos_args, '>')

    def _optimise_min_max(self, node, args, operator):
        """Replace min(a,b,...) and max(a,b,...) by explicit comparison code.
        """
        if len(args) <= 1:
1557 1558 1559 1560 1561
            if len(args) == 1 and args[0].is_sequence_constructor:
                args = args[0].args
            else:
                # leave this to Python
                return node
1562

1563
        cascaded_nodes = list(map(UtilNodes.ResultRefNode, args[1:]))
1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585

        last_result = args[0]
        for arg_node in cascaded_nodes:
            result_ref = UtilNodes.ResultRefNode(last_result)
            last_result = ExprNodes.CondExprNode(
                arg_node.pos,
                true_val = arg_node,
                false_val = result_ref,
                test = ExprNodes.PrimaryCmpNode(
                    arg_node.pos,
                    operand1 = arg_node,
                    operator = operator,
                    operand2 = result_ref,
                    )
                )
            last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result)

        for ref_node in cascaded_nodes[::-1]:
            last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result)

        return last_result

1586
    def _DISABLED_handle_simple_function_tuple(self, node, pos_args):
        """Replace tuple() by an empty tuple and tuple(genexpr) by a list
        comprehension that is converted into a tuple afterwards.
        """
        if not pos_args:
            return ExprNodes.TupleNode(node.pos, args=[], constant_result=())
        # This is a bit special - for iterables (including genexps),
        # Python actually overallocates and resizes a newly created
        # tuple incrementally while reading items, which we can't
        # easily do without explicit node support. Instead, we read
        # the items into a list and then copy them into a tuple of the
        # final size.  This takes up to twice as much memory, but will
        # have to do until we have real support for genexps.
        as_list = self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)
        if as_list is node:
            # nothing was transformed
            return node
        return ExprNodes.AsTupleNode(node.pos, arg=as_list)

1601
    def _handle_simple_function_list(self, node, pos_args):
        """Replace list() by an empty list literal and list(genexpr) by a
        list comprehension.
        """
        if pos_args:
            return self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)
        return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
1605 1606

    def _handle_simple_function_set(self, node, pos_args):
        """Replace set() by an empty set node and set(genexpr) by a set
        comprehension.
        """
        if pos_args:
            return self._transform_list_set_genexpr(node, pos_args, Builtin.set_type)
        return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
1610

1611
    def _transform_list_set_genexpr(self, node, pos_args, target_type):
1612 1613 1614 1615 1616 1617 1618 1619 1620
        """Replace set(genexpr) and list(genexpr) by a literal comprehension.
        """
        if len(pos_args) > 1:
            return node
        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop

1621 1622
        yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
        if yield_expression is None:
1623 1624 1625
            return node

        append_node = ExprNodes.ComprehensionAppendNode(
1626
            yield_expression.pos,
1627
            expr = yield_expression)
1628

1629
        Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1630

1631
        comp = ExprNodes.ComprehensionNode(
1632 1633 1634 1635 1636
            node.pos,
            has_local_scope = True,
            expr_scope = gen_expr_node.expr_scope,
            loop = loop_node,
            append = append_node,
1637 1638 1639
            type = target_type)
        append_node.target = comp
        return comp
1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652

    def _handle_simple_function_dict(self, node, pos_args):
        """Replace dict( (a,b) for ... ) by a literal { a:b for ... }.
        """
        if len(pos_args) == 0:
            return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={})
        if len(pos_args) > 1:
            return node
        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop

1653 1654
        yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
        if yield_expression is None:
1655 1656 1657 1658 1659 1660 1661 1662
            return node

        if not isinstance(yield_expression, ExprNodes.TupleNode):
            return node
        if len(yield_expression.args) != 2:
            return node

        append_node = ExprNodes.DictComprehensionAppendNode(
1663
            yield_expression.pos,
1664
            key_expr = yield_expression.args[0],
1665
            value_expr = yield_expression.args[1])
1666

1667
        Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1668 1669 1670 1671 1672 1673 1674

        dictcomp = ExprNodes.ComprehensionNode(
            node.pos,
            has_local_scope = True,
            expr_scope = gen_expr_node.expr_scope,
            loop = loop_node,
            append = append_node,
1675
            type = Builtin.dict_type)
1676 1677 1678
        append_node.target = dictcomp
        return dictcomp

1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690
    # specific handlers for general call nodes

    def _handle_general_function_dict(self, node, pos_args, kwargs):
        """Replace dict(a=b,c=d,...) by the underlying keyword dict
        construction which is done anyway.
        """
        if len(pos_args) > 0:
            return node
        if not isinstance(kwargs, ExprNodes.DictNode):
            return node
        return kwargs

1691 1692

class InlineDefNodeCalls(Visitor.NodeRefCleanupMixin, Visitor.EnvTransform):
    """Replace simple calls to names that are bound to a Python function
    node by a single assignment with InlinedDefNodeCallNode, if the
    'optimize.inline_defnode_calls' directive is set and the call can be
    inlined.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def get_constant_value_node(self, name_node):
        """Return the right-hand side of the only assignment to the name,
        or None if the name has no control flow state, may be unassigned,
        is unknown, or does not have exactly one assignment.
        """
        cf_state = name_node.cf_state
        if cf_state is None or cf_state.cf_is_null:
            return None
        entry = self.current_env().lookup(name_node.name)
        if not entry:
            return None
        assignments = entry.cf_assignments
        if not assignments or len(assignments) != 1:
            # not just a single assignment in all closures
            return None
        return assignments[0].rhs

    def visit_SimpleCallNode(self, node):
        """Try to replace the call by an inlined def-node call.  The
        original node is returned if inlining is disabled or impossible.
        """
        self.visitchildren(node)
        enabled = self.current_directives.get('optimize.inline_defnode_calls')
        if not (enabled and node.function.is_name):
            return node
        callee_name = node.function
        callee = self.get_constant_value_node(callee_name)
        if isinstance(callee, ExprNodes.PyCFunctionNode):
            candidate = ExprNodes.InlinedDefNodeCallNode(
                node.pos,
                function_name=callee_name,
                function=callee,
                args=node.args)
            if candidate.can_be_inlined():
                return self.replace(node, candidate)
        return node

1724

1725
class OptimizeBuiltinCalls(Visitor.MethodDispatcherTransform):
Stefan Behnel's avatar
Stefan Behnel committed
1726
    """Optimize some common methods calls and instantiation patterns
1727 1728 1729 1730 1731
    for builtin types *after* the type analysis phase.

    Running after type analysis, this transform can only perform
    function replacements that do not alter the function return type
    in a way that was not anticipated by the type analysis.
1732
    """
1733 1734
    ### cleanup to avoid redundant coercions to/from Python types

1735 1736 1737
    def _visit_PyTypeTestNode(self, node):
        # disabled - appears to break assignments in some cases, and
        # also drops a None check, which might still be required
1738 1739 1740 1741 1742 1743 1744 1745
        """Flatten redundant type checks after tree changes.
        """
        old_arg = node.arg
        self.visitchildren(node)
        if old_arg is node.arg or node.arg.type != node.type:
            return node
        return node.arg

1746 1747 1748
    def _visit_TypecastNode(self, node):
        # disabled - the user may have had a reason to put a type
        # cast, even if it looks redundant to Cython
1749 1750 1751 1752 1753 1754 1755 1756
        """
        Drop redundant type casts.
        """
        self.visitchildren(node)
        if node.type == node.operand.type:
            return node.operand
        return node

1757 1758 1759 1760 1761 1762 1763 1764 1765
    def visit_ExprStatNode(self, node):
        """Drop useless coercions.

        If the expression of the statement is a coercion to a Python
        type, the statement uses the coerced argument directly instead.
        """
        self.visitchildren(node)
        expr = node.expr
        if isinstance(expr, ExprNodes.CoerceToPyTypeNode):
            node.expr = expr.arg
        return node

1766 1767 1768 1769 1770
    def visit_CoerceToBooleanNode(self, node):
        """Drop redundant conversion nodes after tree changes.

        If the argument (ignoring a type test around it) is a coercion to
        a generic Python object or to 'bool', the coerced C value is
        converted to a boolean directly.
        """
        self.visitchildren(node)
        operand = node.arg
        if isinstance(operand, ExprNodes.PyTypeTestNode):
            operand = operand.arg
        if not isinstance(operand, ExprNodes.CoerceToPyTypeNode):
            return node
        if operand.type not in (PyrexTypes.py_object_type, Builtin.bool_type):
            return node
        return operand.arg.coerce_to_boolean(self.current_env())

1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791
    def visit_CoerceFromPyTypeNode(self, node):
        """Drop redundant conversion nodes after tree changes.

        Also, optimise away calls to Python's builtin int() and
        float() if the result is going to be coerced back into a C
        type anyway.
        """
        self.visitchildren(node)
        target_type = node.type
        arg = node.arg

        if not arg.type.is_pyobject:
            # no Python conversion left at all, just do a C coercion instead
            if node.type == arg.type:
                return arg
            return arg.coerce_to(target_type, self.current_env())

        if isinstance(arg, ExprNodes.PyTypeTestNode):
            arg = arg.arg

        if arg.is_literal:
            # numeric literals can be coerced at compile time
            matches_target = (
                target_type.is_int and isinstance(arg, ExprNodes.IntNode) or
                target_type.is_float and isinstance(arg, ExprNodes.FloatNode) or
                target_type.is_int and isinstance(arg, ExprNodes.BoolNode))
            if matches_target:
                return arg.coerce_to(target_type, self.current_env())
        elif isinstance(arg, ExprNodes.CoerceToPyTypeNode):
            if (arg.type is PyrexTypes.py_object_type
                    and target_type.assignable_from(arg.arg.type)):
                # completely redundant C->Py->C coercion
                return arg.arg.coerce_to(target_type, self.current_env())
        elif isinstance(arg, ExprNodes.SimpleCallNode):
            if target_type.is_int or target_type.is_float:
                return self._optimise_numeric_cast_call(node, arg)
        elif isinstance(arg, ExprNodes.IndexNode) and not arg.is_buffer_access:
            index_node = arg.index
            if isinstance(index_node, ExprNodes.CoerceToPyTypeNode):
                index_node = index_node.arg
            if index_node.type.is_int:
                return self._optimise_int_indexing(node, arg, index_node)
        return node

1816 1817 1818 1819 1820 1821 1822 1823 1824 1825 1826 1827
    # C signature passed to the "__Pyx_PyBytes_GetItemInt" call built in
    # _optimise_int_indexing(): (bytes, Py_ssize_t index, int check_bounds)
    # returning a C char.  "((char)-1)" is declared as the exception value,
    # together with an exception check.
    PyBytes_GetItemInt_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_char_type, [
            PyrexTypes.CFuncTypeArg("bytes", Builtin.bytes_type, None),
            PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("check_bounds", PyrexTypes.c_int_type, None),
            ],
        exception_value = "((char)-1)",
        exception_check = True)

    def _optimise_int_indexing(self, coerce_node, arg, index_node):
        """Optimise integer indexing whose result is coerced to a C type.

        Currently only handles "bytes[index] -> char / unsigned char",
        which becomes a call to __Pyx_PyBytes_GetItemInt().  'coerce_node'
        is the coercion node around the index node 'arg', 'index_node' is
        the integer index.  Returns 'coerce_node' unchanged if nothing was
        optimised.
        """
        env = self.current_env()
        # pass the 'boundscheck' directive on to the C helper as 0/1
        bound_check_bool = env.directives['boundscheck'] and 1 or 0

        if arg.base.type is not Builtin.bytes_type:
            return coerce_node
        if coerce_node.type not in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type):
            return coerce_node

        # bytes[index] -> char
        pos = coerce_node.pos
        bound_check_node = ExprNodes.IntNode(
            pos, value=str(bound_check_bool), constant_result=bound_check_bool)
        call_args = [
            arg.base.as_none_safe_node("'NoneType' object is not subscriptable"),
            index_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env),
            bound_check_node,
            ]
        item_node = ExprNodes.PythonCapiCallNode(
            pos, "__Pyx_PyBytes_GetItemInt",
            self.PyBytes_GetItemInt_func_type,
            args=call_args,
            is_temp=True,
            utility_code=UtilityCode.load_cached('bytes_index', 'StringTools.c'))
        if coerce_node.type is not PyrexTypes.c_char_type:
            # the helper returns a plain char
            item_node = item_node.coerce_to(coerce_node.type, env)
        return item_node

Stefan Behnel's avatar
Stefan Behnel committed
1850
    def _optimise_numeric_cast_call(self, node, arg):
        """Replace int(x) / float(x) calls by a C value or type cast if the
        result is coerced back to a C type anyway.

        'node' is the coercion node and 'arg' the call node that it
        coerces.  Returns 'node' unchanged if the call cannot be replaced.
        """
        function = arg.function
        is_simple_builtin_call = (
            isinstance(function, ExprNodes.NameNode)
            and function.type.is_builtin_type
            and isinstance(arg.arg_tuple, ExprNodes.TupleNode))
        if not is_simple_builtin_call:
            return node

        call_args = arg.arg_tuple.args
        if len(call_args) != 1:
            return node
        func_arg = call_args[0]
        if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
            func_arg = func_arg.arg
        elif func_arg.type.is_pyobject:
            # play safe: Python conversion might work on all sorts of things
            return node

        if function.name == 'int':
            applicable = func_arg.type.is_int or node.type.is_int
        elif function.name == 'float':
            applicable = func_arg.type.is_float or node.type.is_float
        else:
            applicable = False
        if not applicable:
            return node

        if func_arg.type == node.type:
            # no conversion needed at all
            return func_arg
        if node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
            return ExprNodes.TypecastNode(
                node.pos, operand=func_arg, type=node.type)
        return node

1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891 1892 1893 1894 1895 1896
    def _error_wrong_arg_count(self, function_name, node, args, expected=None):
        """Report a call with a wrong number of arguments at node.pos.

        'expected' may be None, a number or a descriptive string such as
        '0 or 1'.  It selects the argument placeholder in the message and
        is included in the text unless it is None.
        """
        # placeholder for the arguments in "name(...)"
        arg_str = ''
        if expected:
            if isinstance(expected, basestring) or expected > 1:
                arg_str = '...'
            elif expected == 1:
                arg_str = 'x'

        expected_str = ''
        if expected is not None:
            expected_str = 'expected %s, ' % expected

        message = "%s(%s) called with wrong number of args, %sfound %d" % (
            function_name, arg_str, expected_str, len(args))
        error(node.pos, message)

1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912 1913 1914 1915 1916 1917 1918 1919 1920 1921 1922 1923 1924 1925 1926 1927 1928 1929 1930 1931 1932 1933 1934 1935 1936 1937 1938 1939 1940 1941 1942
    ### generic fallbacks

    def _handle_function(self, node, function_name, function, arg_list, kwargs):
        # Generic fallback for function calls: nothing is optimised here,
        # the call node is returned unchanged.
        return node

    def _handle_method(self, node, type_name, attr_name, function,
                       arg_list, is_unbound_method, kwargs):
        """
        Try to inject C-API calls for unbound method calls to builtin types.
        While the method declarations in Builtin.py already handle this, we
        can additionally resolve bound and unbound methods here that were
        assigned to variables ahead of time.
        """
        if kwargs:
            return node
        if not function or not function.is_attribute or not function.obj.is_name:
            # cannot track unbound method calls over more than one indirection as
            # the names might have been reassigned in the meantime
            return node
        type_entry = self.current_env().lookup(type_name)
        if not type_entry:
            return node
        method = ExprNodes.AttributeNode(
            node.function.pos,
            obj=ExprNodes.NameNode(
                function.pos,
                name=type_name,
                entry=type_entry,
                type=type_entry.type),
            attribute=attr_name,
            is_called=True).analyse_as_unbound_cmethod_node(self.current_env())
        if method is None:
            return node
        args = node.args
        if args is None and node.arg_tuple:
            args = node.arg_tuple.args
        call_node = ExprNodes.SimpleCallNode(
            node.pos,
            function=method,
            args=args)
        if not is_unbound_method:
            call_node.self = function.obj
        call_node.analyse_c_function_call(self.current_env())
        call_node.analysed = True
        return call_node.coerce_to(node.type, self.current_env())

1943 1944
    ### builtin types

1945 1946 1947 1948 1949
    # C signature passed to the "PyDict_Copy" call built in
    # _handle_simple_function_dict(): takes a dict and returns a dict.
    PyDict_Copy_func_type = PyrexTypes.CFuncType(
        Builtin.dict_type, [
            PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
            ])

1950
    def _handle_simple_function_dict(self, node, function, pos_args):
1951
        """Replace dict(some_dict) by PyDict_Copy(some_dict).
1952
        """
1953
        if len(pos_args) != 1:
1954
            return node
1955
        arg = pos_args[0]
1956
        if arg.type is Builtin.dict_type:
1957
            arg = arg.as_none_safe_node("'NoneType' is not iterable")
1958 1959
            return ExprNodes.PythonCapiCallNode(
                node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
1960
                args = [arg],
1961 1962 1963
                is_temp = node.is_temp
                )
        return node
1964

1965 1966 1967 1968 1969
    # C signature passed to the "PyList_AsTuple" call built in
    # _handle_simple_function_tuple(): takes a list and returns a tuple.
    PyList_AsTuple_func_type = PyrexTypes.CFuncType(
        Builtin.tuple_type, [
            PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
            ])

1970
    def _handle_simple_function_tuple(self, node, function, pos_args):
1971 1972
        """Replace tuple([...]) by a call to PyList_AsTuple.
        """
1973
        if len(pos_args) != 1:
1974
            return node
1975
        list_arg = pos_args[0]
1976 1977 1978 1979
        if list_arg.type is not Builtin.list_type:
            return node
        if not isinstance(list_arg, (ExprNodes.ComprehensionNode,
                                     ExprNodes.ListNode)):
1980
            pos_args[0] = list_arg.as_none_safe_node(
1981
                "'NoneType' object is not iterable")
1982

1983 1984
        return ExprNodes.PythonCapiCallNode(
            node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
1985
            args = pos_args,
1986 1987 1988
            is_temp = node.is_temp
            )

1989 1990 1991 1992 1993 1994 1995
    # C signature passed to the "__Pyx_PyObject_AsDouble" call built in
    # _handle_simple_function_float(): takes a Python object and returns a
    # C double.  "((double)-1)" is declared as the exception value,
    # together with an exception check.
    PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_double_type, [
            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
            ],
        exception_value = "((double)-1)",
        exception_check = True)

1996
    def _handle_simple_function_set(self, node, function, pos_args):
1997 1998 1999 2000 2001 2002 2003 2004 2005 2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016
        if len(pos_args) == 1 and isinstance(pos_args[0], (ExprNodes.ListNode,
                                                           ExprNodes.TupleNode)):
            # We can optimise set([x,y,z]) safely into a set literal,
            # but only if we create all items before adding them -
            # adding an item may raise an exception if it is not
            # hashable, but creating the later items may have
            # side-effects.
            args = []
            temps = []
            for arg in pos_args[0].args:
                if not arg.is_simple():
                    arg = UtilNodes.LetRefNode(arg)
                    temps.append(arg)
                args.append(arg)
            result = ExprNodes.SetNode(node.pos, is_temp=1, args=args)
            for temp in temps[::-1]:
                result = UtilNodes.EvalWithTempExprNode(temp, result)
            return result
        return node

2017
    def _handle_simple_function_float(self, node, function, pos_args):
Stefan Behnel's avatar
Stefan Behnel committed
2018 2019 2020
        """Transform float() into either a C type cast or a faster C
        function call.
        """
2021 2022
        # Note: this requires the float() function to be typed as
        # returning a C 'double'
2023
        if len(pos_args) == 0:
Stefan Behnel's avatar
typo  
Stefan Behnel committed
2024
            return ExprNodes.FloatNode(
2025 2026 2027 2028
                node, value="0.0", constant_result=0.0
                ).coerce_to(Builtin.float_type, self.current_env())
        elif len(pos_args) != 1:
            self._error_wrong_arg_count('float', node, pos_args, '0 or 1')
2029 2030 2031 2032 2033 2034 2035
            return node
        func_arg = pos_args[0]
        if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
            func_arg = func_arg.arg
        if func_arg.type is PyrexTypes.c_double_type:
            return func_arg
        elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
2036 2037
            return ExprNodes.TypecastNode(
                node.pos, operand=func_arg, type=node.type)
2038 2039 2040 2041 2042
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyObject_AsDouble",
            self.PyObject_AsDouble_func_type,
            args = pos_args,
            is_temp = node.is_temp,
2043
            utility_code = load_c_utility('pyobject_as_double'),
2044 2045
            py_name = "float")

2046 2047 2048 2049 2050 2051 2052 2053 2054 2055
    # C signature passed to the "PyNumber_Int" call built in
    # _handle_simple_function_int(): takes and returns a Python object.
    PyNumber_Int_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("o", PyrexTypes.py_object_type, None)
            ])

    def _handle_simple_function_int(self, node, function, pos_args):
        """Transform int() into a faster C function call.
        """
        if len(pos_args) == 0:
            return ExprNodes.IntNode(node, value="0", constant_result=0,
Stefan Behnel's avatar
Stefan Behnel committed
2056
                                     type=PyrexTypes.py_object_type)
2057 2058 2059 2060 2061 2062 2063 2064 2065 2066 2067
        elif len(pos_args) != 1:
            return node  # int(x, base)
        func_arg = pos_args[0]
        if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
            return node  # handled in visit_CoerceFromPyTypeNode()
        if func_arg.type.is_pyobject and node.type.is_pyobject:
            return ExprNodes.PythonCapiCallNode(
                node.pos, "PyNumber_Int", self.PyNumber_Int_func_type,
                args=pos_args, is_temp=True)
        return node

2068
    def _handle_simple_function_bool(self, node, function, pos_args):
2069 2070
        """Transform bool(x) into a type coercion to a boolean.
        """
2071 2072 2073 2074 2075 2076
        if len(pos_args) == 0:
            return ExprNodes.BoolNode(
                node.pos, value=False, constant_result=False
                ).coerce_to(Builtin.bool_type, self.current_env())
        elif len(pos_args) != 1:
            self._error_wrong_arg_count('bool', node, pos_args, '0 or 1')
2077
            return node
Craig Citro's avatar
Craig Citro committed
2078
        else:
2079 2080 2081 2082 2083 2084
            # => !!<bint>(x)  to make sure it's exactly 0 or 1
            operand = pos_args[0].coerce_to_boolean(self.current_env())
            operand = ExprNodes.NotNode(node.pos, operand = operand)
            operand = ExprNodes.NotNode(node.pos, operand = operand)
            # coerce back to Python object as that's the result we are expecting
            return operand.coerce_to_pyobject(self.current_env())
2085

2086 2087
    ### builtin functions

2088 2089 2090 2091 2092
    # C signatures and lookup tables used by the len() optimisation below.

    # (char* bytes) -> size_t
    Pyx_strlen_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_size_t_type,
        [PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)])

    # (Py_UNICODE* unicode) -> size_t
    Pyx_Py_UNICODE_strlen_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_size_t_type,
        [PyrexTypes.CFuncTypeArg("unicode", PyrexTypes.c_py_unicode_ptr_type, None)])

    # (object obj) -> Py_ssize_t, with "-1" declared as the exception value
    PyObject_Size_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type,
        [PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)],
        exception_value="-1")

    # Maps a builtin type to the name of the C function/macro that returns
    # its length; yields None for types without an entry.
    _map_to_capi_len_function = {
        Builtin.unicode_type:   "__Pyx_PyUnicode_GET_LENGTH",
        Builtin.bytes_type:     "PyBytes_GET_SIZE",
        Builtin.list_type:      "PyList_GET_SIZE",
        Builtin.tuple_type:     "PyTuple_GET_SIZE",
        Builtin.dict_type:      "PyDict_Size",
        Builtin.set_type:       "PySet_Size",
        Builtin.frozenset_type: "PySet_Size",
        }.get

    # Qualified names of types for which len() is mapped to Py_SIZE().
    _ext_types_with_pysize = set(["cpython.array.array"])

2116
    def _handle_simple_function_len(self, node, function, pos_args):
2117 2118
        """Replace len(char*) by the equivalent call to strlen(),
        len(Py_UNICODE) by the equivalent Py_UNICODE_strlen() and
Stefan Behnel's avatar
Stefan Behnel committed
2119
        len(known_builtin_type) by an equivalent C-API call.
Stefan Behnel's avatar
Stefan Behnel committed
2120
        """
2121 2122 2123 2124 2125 2126
        if len(pos_args) != 1:
            self._error_wrong_arg_count('len', node, pos_args, 1)
            return node
        arg = pos_args[0]
        if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
            arg = arg.arg
2127 2128 2129 2130 2131
        if arg.type.is_string:
            new_node = ExprNodes.PythonCapiCallNode(
                node.pos, "strlen", self.Pyx_strlen_func_type,
                args = [arg],
                is_temp = node.is_temp,
2132
                utility_code = UtilityCode.load_cached("IncludeStringH", "StringTools.c"))
2133
        elif arg.type.is_pyunicode_ptr:
2134 2135 2136
            new_node = ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_Py_UNICODE_strlen", self.Pyx_Py_UNICODE_strlen_func_type,
                args = [arg],
2137
                is_temp = node.is_temp)
2138 2139 2140
        elif arg.type.is_pyobject:
            cfunc_name = self._map_to_capi_len_function(arg.type)
            if cfunc_name is None:
2141 2142 2143 2144 2145 2146
                arg_type = arg.type
                if ((arg_type.is_extension_type or arg_type.is_builtin_type)
                    and arg_type.entry.qualified_name in self._ext_types_with_pysize):
                    cfunc_name = 'Py_SIZE'
                else:
                    return node
2147 2148
            arg = arg.as_none_safe_node(
                "object of type 'NoneType' has no len()")
2149 2150 2151 2152
            new_node = ExprNodes.PythonCapiCallNode(
                node.pos, cfunc_name, self.PyObject_Size_func_type,
                args = [arg],
                is_temp = node.is_temp)
Stefan Behnel's avatar
Stefan Behnel committed
2153
        elif arg.type.is_unicode_char:
2154 2155
            return ExprNodes.IntNode(node.pos, value='1', constant_result=1,
                                     type=node.type)
2156
        else:
Stefan Behnel's avatar
Stefan Behnel committed
2157
            return node
2158
        if node.type not in (PyrexTypes.c_size_t_type, PyrexTypes.c_py_ssize_t_type):
2159
            new_node = new_node.coerce_to(node.type, self.current_env())
2160
        return new_node
2161

2162 2163 2164 2165 2166
    # C signature used for the Py_TYPE() call below: (object) -> type.
    Pyx_Type_func_type = PyrexTypes.CFuncType(
        Builtin.type_type,
        [PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)])

2167
    def _handle_simple_function_type(self, node, function, pos_args):
Stefan Behnel's avatar
Stefan Behnel committed
2168 2169
        """Replace type(o) by a macro call to Py_TYPE(o).
        """
2170
        if len(pos_args) != 1:
2171 2172
            return node
        node = ExprNodes.PythonCapiCallNode(
2173 2174 2175 2176
            node.pos, "Py_TYPE", self.Pyx_Type_func_type,
            args = pos_args,
            is_temp = False)
        return ExprNodes.CastNode(node, PyrexTypes.py_object_type)
2177

2178 2179 2180 2181 2182
    # C signature used for the type check calls built by the isinstance()
    # optimisation below: (object arg) -> bint.
    Py_type_check_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type,
        [PyrexTypes.CFuncTypeArg("arg", PyrexTypes.py_object_type, None)])

2183
    def _handle_simple_function_isinstance(self, node, function, pos_args):
2184 2185 2186 2187 2188 2189 2190 2191 2192
        """Replace isinstance() checks against builtin types by the
        corresponding C-API call.
        """
        if len(pos_args) != 2:
            return node
        arg, types = pos_args
        temp = None
        if isinstance(types, ExprNodes.TupleNode):
            types = types.args
2193 2194
            if arg.is_attribute or not arg.is_simple():
                arg = temp = UtilNodes.ResultRefNode(arg)
2195 2196 2197 2198 2199 2200 2201 2202 2203
        elif types.type is Builtin.type_type:
            types = [types]
        else:
            return node

        tests = []
        test_nodes = []
        env = self.current_env()
        for test_type_node in types:
Robert Bradshaw's avatar
Robert Bradshaw committed
2204
            builtin_type = None
Stefan Behnel's avatar
Stefan Behnel committed
2205
            if test_type_node.is_name:
Robert Bradshaw's avatar
Robert Bradshaw committed
2206 2207 2208 2209
                if test_type_node.entry:
                    entry = env.lookup(test_type_node.entry.name)
                    if entry and entry.type and entry.type.is_builtin_type:
                        builtin_type = entry.type
2210 2211 2212 2213 2214 2215
            if builtin_type is Builtin.type_type:
                # all types have type "type", but there's only one 'type'
                if entry.name != 'type' or not (
                        entry.scope and entry.scope.is_builtin_scope):
                    builtin_type = None
            if builtin_type is not None:
Robert Bradshaw's avatar
Robert Bradshaw committed
2216
                type_check_function = entry.type.type_check_function(exact=False)
2217 2218 2219
                if type_check_function in tests:
                    continue
                tests.append(type_check_function)
Robert Bradshaw's avatar
Robert Bradshaw committed
2220 2221 2222 2223 2224
                type_check_args = [arg]
            elif test_type_node.type is Builtin.type_type:
                type_check_function = '__Pyx_TypeCheck'
                type_check_args = [arg, test_type_node]
            else:
2225
                return node
2226 2227 2228 2229 2230 2231
            test_nodes.append(
                ExprNodes.PythonCapiCallNode(
                    test_type_node.pos, type_check_function, self.Py_type_check_func_type,
                    args = type_check_args,
                    is_temp = True,
                    ))
2232 2233 2234 2235 2236 2237 2238 2239 2240 2241 2242 2243

        def join_with_or(a,b, make_binop_node=ExprNodes.binop_node):
            or_node = make_binop_node(node.pos, 'or', a, b)
            or_node.type = PyrexTypes.c_bint_type
            or_node.is_temp = True
            return or_node

        test_node = reduce(join_with_or, test_nodes).coerce_to(node.type, env)
        if temp is not None:
            test_node = UtilNodes.EvalWithTempExprNode(temp, test_node)
        return test_node

2244
    def _handle_simple_function_ord(self, node, function, pos_args):
2245
        """Unpack ord(Py_UNICODE) and ord('X').
2246 2247 2248 2249 2250
        """
        if len(pos_args) != 1:
            return node
        arg = pos_args[0]
        if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
Stefan Behnel's avatar
Stefan Behnel committed
2251
            if arg.arg.type.is_unicode_char:
2252 2253 2254 2255 2256 2257
                return ExprNodes.TypecastNode(
                    arg.pos, operand=arg.arg, type=PyrexTypes.c_int_type
                    ).coerce_to(node.type, self.current_env())
        elif isinstance(arg, ExprNodes.UnicodeNode):
            if len(arg.value) == 1:
                return ExprNodes.IntNode(
2258
                    arg.pos, type=PyrexTypes.c_int_type,
2259 2260 2261 2262 2263
                    value=str(ord(arg.value)),
                    constant_result=ord(arg.value)
                    ).coerce_to(node.type, self.current_env())
        elif isinstance(arg, ExprNodes.StringNode):
            if arg.unicode_value and len(arg.unicode_value) == 1 \
2264
                    and ord(arg.unicode_value) <= 255:  # Py2/3 portability
2265
                return ExprNodes.IntNode(
2266
                    arg.pos, type=PyrexTypes.c_int_type,
2267 2268 2269
                    value=str(ord(arg.unicode_value)),
                    constant_result=ord(arg.unicode_value)
                    ).coerce_to(node.type, self.current_env())
2270 2271
        return node

2272 2273
    ### special methods

2274 2275
    # C signature used with "__Pyx_tp_new": (type, args tuple) -> object.
    Pyx_tp_new_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None),
         PyrexTypes.CFuncTypeArg("args", Builtin.tuple_type, None)])

    # C signature used with "__Pyx_tp_new_kwargs":
    # (type, args tuple, kwargs dict) -> object.
    Pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None),
         PyrexTypes.CFuncTypeArg("args", Builtin.tuple_type, None),
         PyrexTypes.CFuncTypeArg("kwargs", Builtin.dict_type, None)])

2287 2288
    def _handle_any_slot__new__(self, node, function, args,
                                is_unbound_method, kwargs=None):
2289
        """Replace 'exttype.__new__(exttype, ...)' by a call to exttype->tp_new()
2290
        """
2291
        obj = function.obj
2292
        if not is_unbound_method or len(args) < 1:
2293 2294 2295 2296 2297 2298 2299 2300 2301 2302 2303 2304 2305 2306
            return node
        type_arg = args[0]
        if not obj.is_name or not type_arg.is_name:
            # play safe
            return node
        if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
            # not a known type, play safe
            return node
        if not type_arg.type_entry or not obj.type_entry:
            if obj.name != type_arg.name:
                return node
            # otherwise, we know it's a type and we know it's the same
            # type for both - that should do
        elif type_arg.type_entry != obj.type_entry:
2307
            # different types - may or may not lead to an error at runtime
2308 2309
            return node

2310 2311 2312
        args_tuple = ExprNodes.TupleNode(node.pos, args=args[1:])
        args_tuple = args_tuple.analyse_types(
            self.current_env(), skip_children=True)
2313

2314 2315
        if type_arg.type_entry:
            ext_type = type_arg.type_entry.type
2316 2317 2318
            if (ext_type.is_extension_type and ext_type.typeobj_cname and
                    ext_type.scope.global_scope() == self.current_env().global_scope()):
                # known type in current module
2319 2320 2321 2322 2323 2324 2325 2326 2327 2328 2329 2330 2331 2332 2333 2334 2335 2336 2337 2338 2339
                tp_slot = TypeSlots.ConstructorSlot("tp_new", '__new__')
                slot_func_cname = TypeSlots.get_slot_function(ext_type.scope, tp_slot)
                if slot_func_cname:
                    cython_scope = self.context.cython_scope
                    PyTypeObjectPtr = PyrexTypes.CPtrType(
                        cython_scope.lookup('PyTypeObject').type)
                    pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType(
                        PyrexTypes.py_object_type, [
                            PyrexTypes.CFuncTypeArg("type",   PyTypeObjectPtr, None),
                            PyrexTypes.CFuncTypeArg("args",   PyrexTypes.py_object_type, None),
                            PyrexTypes.CFuncTypeArg("kwargs", PyrexTypes.py_object_type, None),
                            ])

                    type_arg = ExprNodes.CastNode(type_arg, PyTypeObjectPtr)
                    if not kwargs:
                        kwargs = ExprNodes.NullNode(node.pos, type=PyrexTypes.py_object_type)  # hack?
                    return ExprNodes.PythonCapiCallNode(
                        node.pos, slot_func_cname,
                        pyx_tp_new_kwargs_func_type,
                        args=[type_arg, args_tuple, kwargs],
                        is_temp=True)
2340
        else:
2341
            # arbitrary variable, needs a None check for safety
2342
            type_arg = type_arg.as_none_safe_node(
2343 2344
                "object.__new__(X): X is not a type object (NoneType)")

2345 2346 2347 2348 2349 2350 2351 2352 2353 2354 2355 2356 2357 2358
        utility_code = UtilityCode.load_cached('tp_new', 'ObjectHandling.c')
        if kwargs:
            return ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_tp_new_kwargs", self.Pyx_tp_new_kwargs_func_type,
                args=[type_arg, args_tuple, kwargs],
                utility_code=utility_code,
                is_temp=node.is_temp
                )
        else:
            return ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
                args=[type_arg, args_tuple],
                utility_code=utility_code,
                is_temp=node.is_temp
2359 2360
            )

2361 2362 2363
    ### methods of builtin types

    # C signature used with "__Pyx_PyObject_Append":
    # (list, item) -> return code, with "-1" declared as the exception value.
    PyObject_Append_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_returncode_type,
        [PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
         PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None)],
        exception_value="-1")
2369

2370
    def _handle_simple_method_object_append(self, node, function, args, is_unbound_method):
Stefan Behnel's avatar
Stefan Behnel committed
2371 2372 2373
        """Optimistic optimisation as X.append() is almost always
        referring to a list.
        """
2374
        if len(args) != 2 or node.result_is_used:
2375 2376
            return node

2377 2378
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
2379 2380 2381 2382 2383 2384
            args=args,
            may_return_none=False,
            is_temp=node.is_temp,
            result_is_used=False,
            utility_code=load_c_utility('append')
        )
2385

2386 2387 2388 2389 2390 2391 2392 2393 2394 2395 2396 2397 2398 2399 2400 2401 2402 2403 2404 2405 2406 2407 2408 2409 2410 2411 2412 2413 2414 2415 2416 2417 2418 2419 2420 2421 2422 2423 2424 2425 2426 2427 2428 2429 2430 2431 2432
    # C signature for appending a C int value to a bytearray:
    # (bytearray, int value) -> return code, "-1" declared as exception value.
    PyByteArray_Append_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_returncode_type,
        [PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None),
         PyrexTypes.CFuncTypeArg("value", PyrexTypes.c_int_type, None)],
        exception_value="-1")

    # Same as above, but the appended value is passed as a Python object.
    PyByteArray_AppendObject_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_returncode_type,
        [PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None),
         PyrexTypes.CFuncTypeArg("value", PyrexTypes.py_object_type, None)],
        exception_value="-1")

    def _handle_simple_method_bytearray_append(self, node, function, args, is_unbound_method):
        if len(args) != 2:
            return node
        func_name = "__Pyx_PyByteArray_Append"
        func_type = self.PyByteArray_Append_func_type

        value = unwrap_coerced_node(args[1])
        if value.type.is_int:
            value = value.coerce_to(PyrexTypes.c_int_type, self.current_env())
            utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c")
        elif value.is_string_literal:
            if not value.can_coerce_to_char_literal():
                return node
            value = value.coerce_to(PyrexTypes.c_char_type, self.current_env())
            utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c")
        elif value.type.is_pyobject:
            func_name = "__Pyx_PyByteArray_AppendObject"
            func_type = self.PyByteArray_AppendObject_func_type
            utility_code = UtilityCode.load_cached("ByteArrayAppendObject", "StringTools.c")
        else:
            return node

        new_node = ExprNodes.PythonCapiCallNode(
            node.pos, func_name, func_type,
            args=[args[0], value],
            may_return_none=False,
            is_temp=node.is_temp,
            utility_code=utility_code,
        )
        if node.result_is_used:
            new_node = new_node.coerce_to(node.type, self.current_env())
        return new_node

Robert Bradshaw's avatar
Robert Bradshaw committed
2433 2434 2435 2436 2437 2438 2439 2440 2441 2442 2443
    # C signature for pop() without an index: (list) -> object.
    PyObject_Pop_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None)])

    # C signature for pop(index): (list, long index) -> object.
    PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
         PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_long_type, None)])

2444 2445 2446 2447 2448
    def _handle_simple_method_list_pop(self, node, function, args, is_unbound_method):
        return self._handle_simple_method_object_pop(
            node, function, args, is_unbound_method, is_list=True)

    def _handle_simple_method_object_pop(self, node, function, args, is_unbound_method, is_list=False):
Stefan Behnel's avatar
Stefan Behnel committed
2449 2450 2451
        """Optimistic optimisation as X.pop([n]) is almost always
        referring to a list.
        """
2452 2453 2454 2455 2456 2457 2458 2459 2460 2461 2462
        if not args:
            return node
        args = args[:]
        if is_list:
            type_name = 'List'
            args[0] = args[0].as_none_safe_node(
                "'NoneType' object has no attribute '%s'",
                error="PyExc_AttributeError",
                format_args=['pop'])
        else:
            type_name = 'Object'
Robert Bradshaw's avatar
Robert Bradshaw committed
2463 2464
        if len(args) == 1:
            return ExprNodes.PythonCapiCallNode(
2465 2466 2467 2468 2469 2470 2471
                node.pos, "__Pyx_Py%s_Pop" % type_name,
                self.PyObject_Pop_func_type,
                args=args,
                may_return_none=True,
                is_temp=node.is_temp,
                utility_code=load_c_utility('pop'),
            )
Robert Bradshaw's avatar
Robert Bradshaw committed
2472
        elif len(args) == 2:
2473 2474 2475
            index = unwrap_coerced_node(args[1])
            if is_list or isinstance(index, ExprNodes.IntNode):
                index = index.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
2476 2477 2478 2479 2480
            if index.type.is_int:
                widest = PyrexTypes.widest_numeric_type(
                    index.type, PyrexTypes.c_py_ssize_t_type)
                if widest == PyrexTypes.c_py_ssize_t_type:
                    args[1] = index
Robert Bradshaw's avatar
Robert Bradshaw committed
2481
                    return ExprNodes.PythonCapiCallNode(
2482
                        node.pos, "__Pyx_Py%s_PopIndex" % type_name,
2483
                        self.PyObject_PopIndex_func_type,
2484 2485 2486 2487 2488
                        args=args,
                        may_return_none=True,
                        is_temp=node.is_temp,
                        utility_code=load_c_utility("pop_index"),
                    )
2489

Robert Bradshaw's avatar
Robert Bradshaw committed
2490 2491
        return node

2492
    # C signature for functions taking a single Python object:
    # (obj) -> return code, with "-1" declared as the exception value.
    single_param_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_returncode_type,
        [PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)],
        exception_value="-1")
2497

2498
    def _handle_simple_method_list_sort(self, node, function, args, is_unbound_method):
Stefan Behnel's avatar
Stefan Behnel committed
2499 2500
        """Call PyList_Sort() instead of the 0-argument l.sort().
        """
2501
        if len(args) != 1:
2502
            return node
2503
        return self._substitute_method_call(
2504
            node, function, "PyList_Sort", self.single_param_func_type,
2505
            'sort', is_unbound_method, args).coerce_to(node.type, self.current_env)
2506

2507 2508 2509 2510 2511
    # C signature used with "__Pyx_PyDict_GetItemDefault":
    # (dict, key, default) -> object.
    Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
         PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
         PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None)])
2513

2514
    def _handle_simple_method_dict_get(self, node, function, args, is_unbound_method):
Stefan Behnel's avatar
Stefan Behnel committed
2515 2516
        """Replace dict.get() by a call to PyDict_GetItem().
        """
2517 2518 2519 2520 2521 2522 2523
        if len(args) == 2:
            args.append(ExprNodes.NoneNode(node.pos))
        elif len(args) != 3:
            self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
            return node

        return self._substitute_method_call(
2524 2525
            node, function,
            "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
2526
            'get', is_unbound_method, args,
Stefan Behnel's avatar
Stefan Behnel committed
2527
            may_return_none = True,
2528
            utility_code = load_c_utility("dict_getitem_default"))
2529

2530 2531 2532 2533 2534
    # C signature used with "__Pyx_PyDict_SetDefault":
    # (dict, key, default, int is_safe_type) -> object.
    Pyx_PyDict_SetDefault_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
         PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
         PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
         PyrexTypes.CFuncTypeArg("is_safe_type", PyrexTypes.c_int_type, None)])

2538
    def _handle_simple_method_dict_setdefault(self, node, function, args, is_unbound_method):
2539 2540 2541 2542 2543 2544 2545
        """Replace dict.setdefault() by calls to PyDict_GetItem() and PyDict_SetItem().
        """
        if len(args) == 2:
            args.append(ExprNodes.NoneNode(node.pos))
        elif len(args) != 3:
            self._error_wrong_arg_count('dict.setdefault', node, args, "2 or 3")
            return node
2546 2547 2548 2549 2550 2551 2552 2553 2554
        key_type = args[1].type
        if key_type.is_builtin_type:
            is_safe_type = int(key_type.name in
                               'str bytes unicode float int long bool')
        elif key_type is PyrexTypes.py_object_type:
            is_safe_type = -1  # don't know
        else:
            is_safe_type = 0   # definitely not
        args.append(ExprNodes.IntNode(
2555
            node.pos, value=str(is_safe_type), constant_result=is_safe_type))
2556 2557

        return self._substitute_method_call(
2558 2559
            node, function,
            "__Pyx_PyDict_SetDefault", self.Pyx_PyDict_SetDefault_func_type,
2560
            'setdefault', is_unbound_method, args,
2561 2562
            may_return_none=True,
            utility_code=load_c_utility('dict_setdefault'))
2563

2564 2565 2566

    ### unicode type methods

2567 2568
    # C signature of the unicode character predicates: (Py_UCS4) -> bint.
    PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type,
        [PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None)])

2572
    def _inject_unicode_predicate(self, node, function, args, is_unbound_method):
2573 2574 2575 2576
        if is_unbound_method or len(args) != 1:
            return node
        ustring = args[0]
        if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
Stefan Behnel's avatar
Stefan Behnel committed
2577
               not ustring.arg.type.is_unicode_char:
2578 2579
            return node
        uchar = ustring.arg
2580
        method_name = function.attribute
2581 2582
        if method_name == 'istitle':
            # istitle() doesn't directly map to Py_UNICODE_ISTITLE()
2583 2584
            utility_code = UtilityCode.load_cached(
                "py_unicode_istitle", "StringTools.c")
2585 2586 2587 2588 2589
            function_name = '__Pyx_Py_UNICODE_ISTITLE'
        else:
            utility_code = None
            function_name = 'Py_UNICODE_%s' % method_name.upper()
        func_call = self._substitute_method_call(
2590 2591
            node, function,
            function_name, self.PyUnicode_uchar_predicate_func_type,
2592 2593 2594 2595 2596 2597 2598 2599 2600 2601 2602 2603 2604 2605 2606 2607 2608
            method_name, is_unbound_method, [uchar],
            utility_code = utility_code)
        if node.type.is_pyobject:
            func_call = func_call.coerce_to_pyobject(self.current_env)
        return func_call

    # All unicode character predicate methods share the generic handler
    # defined above, which derives the C function from the method name.
    _handle_simple_method_unicode_isalnum   = _inject_unicode_predicate
    _handle_simple_method_unicode_isalpha   = _inject_unicode_predicate
    _handle_simple_method_unicode_isdecimal = _inject_unicode_predicate
    _handle_simple_method_unicode_isdigit   = _inject_unicode_predicate
    _handle_simple_method_unicode_islower   = _inject_unicode_predicate
    _handle_simple_method_unicode_isnumeric = _inject_unicode_predicate
    _handle_simple_method_unicode_isspace   = _inject_unicode_predicate
    _handle_simple_method_unicode_istitle   = _inject_unicode_predicate
    _handle_simple_method_unicode_isupper   = _inject_unicode_predicate

    # C signature of the unicode character conversions: (Py_UCS4) -> Py_UCS4.
    PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ucs4_type,
        [PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None)])

Stefan Behnel's avatar
Stefan Behnel committed
2613
    def _inject_unicode_character_conversion(self, node, function, args, is_unbound_method):
2614 2615 2616 2617
        if is_unbound_method or len(args) != 1:
            return node
        ustring = args[0]
        if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
Stefan Behnel's avatar
Stefan Behnel committed
2618
               not ustring.arg.type.is_unicode_char:
2619 2620
            return node
        uchar = ustring.arg
2621
        method_name = function.attribute
2622 2623
        function_name = 'Py_UNICODE_TO%s' % method_name.upper()
        func_call = self._substitute_method_call(
2624 2625
            node, function,
            function_name, self.PyUnicode_uchar_conversion_func_type,
2626 2627 2628 2629 2630 2631 2632 2633 2634
            method_name, is_unbound_method, [uchar])
        if node.type.is_pyobject:
            func_call = func_call.coerce_to_pyobject(self.current_env)
        return func_call

    # The unicode case conversion methods share the generic handler defined
    # above, which derives the C function from the method name.
    _handle_simple_method_unicode_lower = _inject_unicode_character_conversion
    _handle_simple_method_unicode_upper = _inject_unicode_character_conversion
    _handle_simple_method_unicode_title = _inject_unicode_character_conversion

2635 2636 2637 2638 2639 2640
    # C signature used with "PyUnicode_Splitlines":
    # (unicode str, bint keepends) -> list.
    PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType(
        Builtin.list_type,
        [PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
         PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None)])

Stefan Behnel's avatar
Stefan Behnel committed
2641
    def _handle_simple_method_unicode_splitlines(self, node, function, args, is_unbound_method):
2642 2643 2644 2645 2646 2647
        """Replace unicode.splitlines(...) by a direct call to the
        corresponding C-API function.
        """
        if len(args) not in (1,2):
            self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2")
            return node
2648
        self._inject_bint_default_argument(node, args, 1, False)
2649 2650

        return self._substitute_method_call(
2651 2652
            node, function,
            "PyUnicode_Splitlines", self.PyUnicode_Splitlines_func_type,
2653 2654 2655 2656 2657 2658 2659 2660 2661 2662
            'splitlines', is_unbound_method, args)

    # C signature used with "PyUnicode_Split":
    # (unicode str, object sep, Py_ssize_t maxsplit) -> list.
    PyUnicode_Split_func_type = PyrexTypes.CFuncType(
        Builtin.list_type,
        [PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
         PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None),
         PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None)])

2663
    def _handle_simple_method_unicode_split(self, node, function, args, is_unbound_method):
2664 2665 2666 2667 2668 2669 2670 2671
        """Replace unicode.split(...) by a direct call to the
        corresponding C-API function.
        """
        if len(args) not in (1,2,3):
            self._error_wrong_arg_count('unicode.split', node, args, "1-3")
            return node
        if len(args) < 2:
            args.append(ExprNodes.NullNode(node.pos))
2672 2673
        self._inject_int_default_argument(
            node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1")
2674 2675

        return self._substitute_method_call(
2676 2677
            node, function,
            "PyUnicode_Split", self.PyUnicode_Split_func_type,
2678 2679
            'split', is_unbound_method, args)

2680
    # C signature of the tailmatch helpers used for startswith()/endswith():
    # (str, substring, Py_ssize_t start, Py_ssize_t end, int direction)
    # -> bint, with '-1' declared as the exception value.
    PyString_Tailmatch_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type,
        [PyrexTypes.CFuncTypeArg("str", PyrexTypes.py_object_type, None),  # bytes/str/unicode
         PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
         PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
         PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
         PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None)],
        exception_value='-1')

2690
    def _handle_simple_method_unicode_endswith(self, node, function, args, is_unbound_method):
        """Optimise unicode.endswith() through the shared tailmatch
        handler, using direction +1.
        """
        direction = +1
        return self._inject_tailmatch(
            node, function, args, is_unbound_method,
            'unicode', 'endswith',
            unicode_tailmatch_utility_code, direction)
2694

2695
    def _handle_simple_method_unicode_startswith(self, node, function, args, is_unbound_method):
        """Optimise unicode.startswith() through the shared tailmatch
        handler, using direction -1.
        """
        direction = -1
        return self._inject_tailmatch(
            node, function, args, is_unbound_method,
            'unicode', 'startswith',
            unicode_tailmatch_utility_code, direction)
2699

2700
    def _inject_tailmatch(self, node, function, args, is_unbound_method, type_name,
                          method_name, utility_code, direction):
        """Replace <type_name>.startswith(...) / <type_name>.endswith(...)
        by a call to the matching __Pyx_Py<Type>_Tailmatch() helper.

        Missing 'start' and 'end' arguments default to 0 and
        PY_SSIZE_T_MAX, and 'direction' is appended as a C int.  The
        C result is coerced to a Python bool object.  A call with an
        unsupported number of arguments is reported and returned
        unchanged.
        """
        if not (2 <= len(args) <= 4):
            self._error_wrong_arg_count('%s.%s' % (type_name, method_name), node, args, "2-4")
            return node

        ssize_t_type = PyrexTypes.c_py_ssize_t_type
        for arg_index, default in ((2, "0"), (3, "PY_SSIZE_T_MAX")):
            self._inject_int_default_argument(
                node, args, arg_index, ssize_t_type, default)
        direction_node = ExprNodes.IntNode(
            node.pos, value=str(direction), type=PyrexTypes.c_int_type)
        args.append(direction_node)

        helper_name = "__Pyx_Py%s_Tailmatch" % type_name.capitalize()
        tailmatch_call = self._substitute_method_call(
            node, function, helper_name,
            self.PyString_Tailmatch_func_type,
            method_name, is_unbound_method, args,
            utility_code = utility_code)
        return tailmatch_call.coerce_to(Builtin.bool_type, self.current_env())
2722

2723 2724 2725 2726 2727 2728 2729 2730 2731 2732
    # C signature used for the PyUnicode_Find() calls generated by
    # _inject_unicode_find():
    #   Py_ssize_t f(unicode str, object substring,
    #                Py_ssize_t start, Py_ssize_t end, int direction)
    # with -2 as the exception value.
    PyUnicode_Find_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
            ],
        exception_value = '-2')

2733
    def _handle_simple_method_unicode_find(self, node, function, args, is_unbound_method):
        """Optimise unicode.find(...) via _inject_unicode_find() with
        direction +1.
        """
        direction = +1
        return self._inject_unicode_find(
            node, function, args, is_unbound_method, 'find', direction)
2736

2737
    def _handle_simple_method_unicode_rfind(self, node, function, args, is_unbound_method):
        """Optimise unicode.rfind(...) via _inject_unicode_find() with
        direction -1.
        """
        direction = -1
        return self._inject_unicode_find(
            node, function, args, is_unbound_method, 'rfind', direction)
2740

2741
    def _inject_unicode_find(self, node, function, args, is_unbound_method,
                             method_name, direction):
        """Map unicode.find(...) / unicode.rfind(...) onto a direct
        call to the C-API function PyUnicode_Find().

        Missing 'start' and 'end' arguments default to 0 and
        PY_SSIZE_T_MAX, and 'direction' is appended as a C int.  The
        C result is coerced back to a Python object.  A call with an
        unsupported number of arguments is reported and returned
        unchanged.
        """
        if not (2 <= len(args) <= 4):
            self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
            return node

        ssize_t_type = PyrexTypes.c_py_ssize_t_type
        for arg_index, default in ((2, "0"), (3, "PY_SSIZE_T_MAX")):
            self._inject_int_default_argument(
                node, args, arg_index, ssize_t_type, default)
        direction_node = ExprNodes.IntNode(
            node.pos, value=str(direction), type=PyrexTypes.c_int_type)
        args.append(direction_node)

        find_call = self._substitute_method_call(
            node, function, "PyUnicode_Find", self.PyUnicode_Find_func_type,
            method_name, is_unbound_method, args)
        return find_call.coerce_to_pyobject(self.current_env())
2760

Stefan Behnel's avatar
Stefan Behnel committed
2761 2762 2763 2764 2765 2766 2767 2768 2769
    # C signature used for the PyUnicode_Count() calls generated by
    # _handle_simple_method_unicode_count():
    #   Py_ssize_t f(unicode str, object substring,
    #                Py_ssize_t start, Py_ssize_t end)
    # with -1 as the exception value.
    PyUnicode_Count_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
            ],
        exception_value = '-1')

2770
    def _handle_simple_method_unicode_count(self, node, function, args, is_unbound_method):
        """Map unicode.count(...) onto a direct call to the C-API
        function PyUnicode_Count().

        Missing 'start' and 'end' arguments default to 0 and
        PY_SSIZE_T_MAX.  The C result is coerced back to a Python
        object.  A call with an unsupported number of arguments is
        reported and returned unchanged.
        """
        if not (2 <= len(args) <= 4):
            self._error_wrong_arg_count('unicode.count', node, args, "2-4")
            return node

        ssize_t_type = PyrexTypes.c_py_ssize_t_type
        for arg_index, default in ((2, "0"), (3, "PY_SSIZE_T_MAX")):
            self._inject_int_default_argument(
                node, args, arg_index, ssize_t_type, default)

        count_call = self._substitute_method_call(
            node, function, "PyUnicode_Count", self.PyUnicode_Count_func_type,
            'count', is_unbound_method, args)
        return count_call.coerce_to_pyobject(self.current_env())
Stefan Behnel's avatar
Stefan Behnel committed
2786

Stefan Behnel's avatar
Stefan Behnel committed
2787 2788 2789 2790 2791 2792 2793 2794
    # C signature used for the PyUnicode_Replace() calls generated by
    # _handle_simple_method_unicode_replace():
    #   unicode f(unicode str, object substring, object replstr,
    #             Py_ssize_t maxcount)
    PyUnicode_Replace_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None),
            ])

2795
    def _handle_simple_method_unicode_replace(self, node, function, args, is_unbound_method):
        """Map unicode.replace(old, new [, maxcount]) onto a direct call
        to the C-API function PyUnicode_Replace().

        A missing maxcount defaults to -1.  A call with an unsupported
        number of arguments is reported and returned unchanged.
        """
        if not (3 <= len(args) <= 4):
            self._error_wrong_arg_count('unicode.replace', node, args, "3-4")
            return node
        self._inject_int_default_argument(
            node, args, 3, PyrexTypes.c_py_ssize_t_type, "-1")

        replace_call = self._substitute_method_call(
            node, function, "PyUnicode_Replace", self.PyUnicode_Replace_func_type,
            'replace', is_unbound_method, args)
        return replace_call

2809 2810
    # C signature used for generic PyUnicode_AsEncodedString() calls:
    #   bytes f(unicode obj, char* encoding, char* errors)
    PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            ])

    # C signature used for the codec specific PyUnicode_As<Codec>String()
    # calls:
    #   bytes f(unicode obj)
    PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            ])

    # Codec names from which the "PyUnicode_As%sString" and
    # "PyUnicode_Decode%s" function names are built (names containing
    # '_' are converted by _find_special_codec_name() first).
    _special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
                          'unicode_escape', 'raw_unicode_escape']

    # (name, encoder function) pairs, looked up once at class creation
    # time; _find_special_codec_name() compares requested encoders
    # against these.
    _special_codecs = [ (name, codecs.getencoder(name))
                        for name in _special_encodings ]
2826

2827
    def _handle_simple_method_unicode_encode(self, node, function, args, is_unbound_method):
        """Replace unicode.encode(...) by a direct C-API call to the
        corresponding codec.

        The following cases are tried in order:

        - no encoding given: call PyUnicode_AsEncodedString() with NULL
          for both the encoding and the error mode
        - literal string with a known encoding: encode it at compile
          time and return a bytes literal node
        - known encoding with 'strict' error handling and a dedicated
          encoder: call PyUnicode_As<Codec>String()
        - otherwise: call PyUnicode_AsEncodedString() with the unpacked
          encoding and error mode

        Returns the original node unchanged if the argument count is
        wrong or the encoding/error arguments cannot be unpacked.
        """
        if len(args) < 1 or len(args) > 3:
            self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
            return node

        string_node = args[0]

        if len(args) == 1:
            null_node = ExprNodes.NullNode(node.pos)
            return self._substitute_method_call(
                node, function, "PyUnicode_AsEncodedString",
                self.PyUnicode_AsEncodedString_func_type,
                'encode', is_unbound_method, [string_node, null_node, null_node])

        parameters = self._unpack_encoding_and_error_mode(node.pos, args)
        if parameters is None:
            return node
        encoding, encoding_node, error_handling, error_handling_node = parameters

        if encoding and isinstance(string_node, ExprNodes.UnicodeNode):
            # constant, so try to do the encoding at compile time
            try:
                value = string_node.value.encode(encoding, error_handling)
            except Exception:
                # well, looks like we can't - leave it to the runtime.
                # Only regular errors are ignored here, so that
                # KeyboardInterrupt and SystemExit still propagate.
                pass
            else:
                value = BytesLiteral(value)
                value.encoding = encoding
                return ExprNodes.BytesNode(
                    string_node.pos, value=value, type=Builtin.bytes_type)

        if encoding and error_handling == 'strict':
            # try to find a specific encoder function
            codec_name = self._find_special_codec_name(encoding)
            if codec_name is not None:
                encode_function = "PyUnicode_As%sString" % codec_name
                return self._substitute_method_call(
                    node, function, encode_function,
                    self.PyUnicode_AsXyzString_func_type,
                    'encode', is_unbound_method, [string_node])

        return self._substitute_method_call(
            node, function, "PyUnicode_AsEncodedString",
            self.PyUnicode_AsEncodedString_func_type,
            'encode', is_unbound_method,
            [string_node, encoding_node, error_handling_node])

2878
    # Pointer type for the codec specific PyUnicode_Decode<Codec>()
    # functions:
    #   unicode (*f)(char* string, Py_ssize_t size, char* errors)
    PyUnicode_DecodeXyz_func_ptr_type = PyrexTypes.CPtrType(PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            ]))

    # Signature of the '__Pyx_decode_c_string' helper (C char* input).
    _decode_c_string_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None),
            ])

    # Signature of the '__Pyx_decode_bytes' / '__Pyx_decode_bytearray'
    # helpers (Python object input).
    _decode_bytes_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None),
            ])

    # Signature of the '__Pyx_decode_cpp_string' helper.  It is built on
    # first use in _handle_simple_method_bytes_decode() because it needs
    # the C++ string type of the decoded node.
    _decode_cpp_string_func_type = None  # lazy init
2906

2907
    def _handle_simple_method_bytes_decode(self, node, function, args, is_unbound_method):
        """Replace char*.decode() by a direct C-API call to the
        corresponding codec, possibly resolving a slice on the char*.

        Besides C strings, this also handles C++ strings and Python
        bytes/bytearray objects.  The call is replaced by one of the
        '__Pyx_decode_*' helper functions from StringTools.c, which
        receives the string, the start/stop offsets, the encoding and
        error mode, and optionally a codec specific decoder function.

        Returns the original node unchanged if the argument count is
        wrong, the string type is not supported, or the encoding/error
        arguments cannot be unpacked.
        """
        if not (1 <= len(args) <= 3):
            self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
            return node

        # normalise input nodes
        string_node = args[0]
        start = stop = None
        if isinstance(string_node, ExprNodes.SliceIndexNode):
            # decode the sliced base directly and pass the slice bounds on
            index_node = string_node
            string_node = index_node.base
            start, stop = index_node.start, index_node.stop
            if not start or start.constant_result == 0:
                start = None
        if isinstance(string_node, ExprNodes.CoerceToPyTypeNode):
            # look at the node underneath the coercion to a Python object
            string_node = string_node.arg

        string_type = string_node.type
        if string_type in (Builtin.bytes_type, Builtin.bytearray_type):
            # Python objects may be None at runtime => add a None check
            if is_unbound_method:
                string_node = string_node.as_none_safe_node(
                    "descriptor '%s' requires a '%s' object but received a 'NoneType'",
                    format_args=['decode', string_type.name])
            else:
                string_node = string_node.as_none_safe_node(
                    "'NoneType' object has no attribute '%s'",
                    error="PyExc_AttributeError",
                    format_args=['decode'])
        elif not string_type.is_string and not string_type.is_cpp_string:
            # nothing to optimise here
            return node

        parameters = self._unpack_encoding_and_error_mode(node.pos, args)
        if parameters is None:
            return node
        encoding, encoding_node, error_handling, error_handling_node = parameters

        # make sure 'start' (and 'stop', if given) are C integers
        if not start:
            start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
        elif not start.type.is_int:
            start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        if stop and not stop.type.is_int:
            stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())

        # try to find a specific encoder function
        codec_name = None
        if encoding is not None:
            codec_name = self._find_special_codec_name(encoding)
        if codec_name is not None:
            # pass the codec specific decoder and NULL as encoding name
            decode_function = ExprNodes.RawCNameExprNode(
                node.pos, type=self.PyUnicode_DecodeXyz_func_ptr_type,
                cname="PyUnicode_Decode%s" % codec_name)
            encoding_node = ExprNodes.NullNode(node.pos)
        else:
            decode_function = ExprNodes.NullNode(node.pos)

        # build the helper function call
        temps = []
        if string_type.is_string:
            # C string
            if not stop:
                # use strlen() to find the string length, just as CPython would
                if not string_node.is_name:
                    string_node = UtilNodes.LetRefNode(string_node) # used twice
                    temps.append(string_node)
                stop = ExprNodes.PythonCapiCallNode(
                    string_node.pos, "strlen", self.Pyx_strlen_func_type,
                    args=[string_node],
                    is_temp=False,
                    utility_code=UtilityCode.load_cached("IncludeStringH", "StringTools.c"),
                ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
            helper_func_type = self._decode_c_string_func_type
            utility_code_name = 'decode_c_string'
        elif string_type.is_cpp_string:
            # C++ std::string
            if not stop:
                stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX',
                                         constant_result=ExprNodes.not_a_constant)
            if self._decode_cpp_string_func_type is None:
                # lazy init to reuse the C++ string type
                self._decode_cpp_string_func_type = PyrexTypes.CFuncType(
                    Builtin.unicode_type, [
                        PyrexTypes.CFuncTypeArg("string", string_type, None),
                        PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
                        PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
                        PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
                        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
                        PyrexTypes.CFuncTypeArg("decode_func", self.PyUnicode_DecodeXyz_func_ptr_type, None),
                    ])
            helper_func_type = self._decode_cpp_string_func_type
            utility_code_name = 'decode_cpp_string'
        else:
            # Python bytes/bytearray object
            if not stop:
                stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX',
                                         constant_result=ExprNodes.not_a_constant)
            helper_func_type = self._decode_bytes_func_type
            if string_type is Builtin.bytes_type:
                utility_code_name = 'decode_bytes'
            else:
                utility_code_name = 'decode_bytearray'

        node = ExprNodes.PythonCapiCallNode(
            node.pos, '__Pyx_%s' % utility_code_name, helper_func_type,
            args=[string_node, start, stop, encoding_node, error_handling_node, decode_function],
            is_temp=node.is_temp,
            utility_code=UtilityCode.load_cached(utility_code_name, 'StringTools.c'),
        )

        # wrap the call in the temp nodes created above, innermost last
        for temp in temps[::-1]:
            node = UtilNodes.EvalWithTempExprNode(temp, node)
        return node
3022

Stefan Behnel's avatar
Stefan Behnel committed
3023 3024
    # bytearray.decode() is handled by the same code as bytes.decode()
    _handle_simple_method_bytearray_decode = _handle_simple_method_bytes_decode

3025 3026 3027
    def _find_special_codec_name(self, encoding):
        try:
            requested_codec = codecs.getencoder(encoding)
Stefan Behnel's avatar
Stefan Behnel committed
3028
        except LookupError:
3029 3030 3031 3032
            return None
        for name, codec in self._special_codecs:
            if codec == requested_codec:
                if '_' in name:
Stefan Behnel's avatar
Stefan Behnel committed
3033 3034
                    name = ''.join([s.capitalize()
                                    for s in name.split('_')])
3035 3036 3037 3038
                return name
        return None

    def _unpack_encoding_and_error_mode(self, pos, args):
        """Extract the encoding and error mode from the arguments of an
        encode()/decode() call.

        args[0] is the string itself, args[1] the optional encoding and
        args[2] the optional error mode.  Returns a tuple
        (encoding, encoding_node, error_handling, error_handling_node),
        or None if one of the given arguments cannot be unpacked.  A
        missing encoding and a 'strict' (or missing) error mode are
        represented by a NullNode.
        """
        null_node = ExprNodes.NullNode(pos)

        encoding, encoding_node = None, null_node
        if len(args) >= 2:
            encoding, encoding_node = self._unpack_string_and_cstring_node(args[1])
            if encoding_node is None:
                return None

        error_handling, error_handling_node = 'strict', null_node
        if len(args) == 3:
            error_handling, error_handling_node = self._unpack_string_and_cstring_node(args[2])
            if error_handling_node is None:
                return None
            if error_handling == 'strict':
                error_handling_node = null_node

        return (encoding, encoding_node, error_handling, error_handling_node)
3060

3061 3062 3063 3064 3065 3066 3067 3068 3069 3070 3071 3072 3073 3074 3075 3076 3077 3078
    def _unpack_string_and_cstring_node(self, node):
        """Return a (string value, char* node) pair for a string argument.

        The value is only known for string literals, otherwise it is
        None.  For nodes that are neither string literals, bytes objects
        nor C strings, (None, None) is returned.
        """
        if isinstance(node, ExprNodes.CoerceToPyTypeNode):
            node = node.arg

        if isinstance(node, ExprNodes.UnicodeNode):
            text = node.value
            cstring_node = ExprNodes.BytesNode(
                node.pos, value=BytesLiteral(text.utf8encode()),
                type=PyrexTypes.c_char_ptr_type)
            return text, cstring_node

        if isinstance(node, (ExprNodes.StringNode, ExprNodes.BytesNode)):
            text = node.value.decode('ISO-8859-1')
            cstring_node = ExprNodes.BytesNode(
                node.pos, value=node.value, type=PyrexTypes.c_char_ptr_type)
            return text, cstring_node

        if node.type is Builtin.bytes_type:
            # value unknown at compile time, pass the bytes object as char*
            return None, node.coerce_to(PyrexTypes.c_char_ptr_type, self.current_env())

        if node.type.is_string:
            # already a C string, value unknown at compile time
            return None, node

        return None, None

3082
    def _handle_simple_method_str_endswith(self, node, function, args, is_unbound_method):
        """Optimise str.endswith(...) through the shared tailmatch
        injection, using direction +1.
        """
        direction = +1
        return self._inject_tailmatch(
            node, function, args, is_unbound_method,
            'str', 'endswith', str_tailmatch_utility_code, direction)
3086

3087
    def _handle_simple_method_str_startswith(self, node, function, args, is_unbound_method):
        """Optimise str.startswith(...) through the shared tailmatch
        injection, using direction -1.
        """
        direction = -1
        return self._inject_tailmatch(
            node, function, args, is_unbound_method,
            'str', 'startswith', str_tailmatch_utility_code, direction)
3091

3092
    def _handle_simple_method_bytes_endswith(self, node, function, args, is_unbound_method):
        """Optimise bytes.endswith(...) through the shared tailmatch
        injection, using direction +1.
        """
        direction = +1
        return self._inject_tailmatch(
            node, function, args, is_unbound_method,
            'bytes', 'endswith', bytes_tailmatch_utility_code, direction)
3096

3097
    def _handle_simple_method_bytes_startswith(self, node, function, args, is_unbound_method):
        """Optimise bytes.startswith(...) through the shared tailmatch
        injection, using direction -1.
        """
        direction = -1
        return self._inject_tailmatch(
            node, function, args, is_unbound_method,
            'bytes', 'startswith', bytes_tailmatch_utility_code, direction)
3101

Stefan Behnel's avatar
Stefan Behnel committed
3102 3103 3104 3105 3106 3107 3108 3109 3110 3111 3112 3113
    '''   # disabled for now, enable when we consider it worth it (see StringTools.c)
    def _handle_simple_method_bytearray_endswith(self, node, function, args, is_unbound_method):
        return self._inject_tailmatch(
            node, function, args, is_unbound_method, 'bytearray', 'endswith',
            bytes_tailmatch_utility_code, +1)

    def _handle_simple_method_bytearray_startswith(self, node, function, args, is_unbound_method):
        return self._inject_tailmatch(
            node, function, args, is_unbound_method, 'bytearray', 'startswith',
            bytes_tailmatch_utility_code, -1)
    '''

3114 3115
    ### helpers

3116
    def _substitute_method_call(self, node, function, name, func_type,
                                attr_name, is_unbound_method, args=(),
                                utility_code=None, is_temp=None,
                                may_return_none=ExprNodes.PythonCapiCallNode.may_return_none):
        """Build the PythonCapiCallNode that replaces a method call.

        'name' and 'func_type' describe the C function to call and
        'args' its arguments, the first one being the object that the
        method was called on.  Unless that object is a literal, it gets
        wrapped in a None check whose error message depends on whether
        the method was called bound or unbound.  'is_temp' defaults to
        the setting of the replaced node.
        """
        call_args = list(args)
        if call_args and not call_args[0].is_literal:
            if is_unbound_method:
                call_args[0] = call_args[0].as_none_safe_node(
                    "descriptor '%s' requires a '%s' object but received a 'NoneType'",
                    format_args=[attr_name, function.obj.name])
            else:
                call_args[0] = call_args[0].as_none_safe_node(
                    "'NoneType' object has no attribute '%s'",
                    error = "PyExc_AttributeError",
                    format_args = [attr_name])

        if is_temp is None:
            is_temp = node.is_temp

        return ExprNodes.PythonCapiCallNode(
            node.pos, name, func_type,
            args = call_args,
            is_temp = is_temp,
            utility_code = utility_code,
            may_return_none = may_return_none,
            result_is_used = node.result_is_used,
            )

3144 3145 3146
    def _inject_int_default_argument(self, node, args, arg_index, type, default_value):
        """Make sure args[arg_index] is a C integer of the given type.

        An argument that was passed gets coerced to 'type' in place.
        If the argument list ends right before 'arg_index', an IntNode
        with 'default_value' is appended instead.  Modifies 'args'.
        """
        assert len(args) >= arg_index
        if len(args) != arg_index:
            # argument was passed explicitly
            args[arg_index] = args[arg_index].coerce_to(type, self.current_env())
            return
        default_node = ExprNodes.IntNode(
            node.pos, value=str(default_value),
            type=type, constant_result=default_value)
        args.append(default_node)
3151 3152 3153 3154

    def _inject_bint_default_argument(self, node, args, arg_index, default_value):
        """Make sure args[arg_index] is a boolean value.

        An argument that was passed gets coerced to a boolean in place.
        If the argument list ends right before 'arg_index', a BoolNode
        with the truth value of 'default_value' is appended instead.
        Modifies 'args'.
        """
        assert len(args) >= arg_index
        if len(args) != arg_index:
            # argument was passed explicitly
            args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env())
            return
        flag = bool(default_value)
        args.append(ExprNodes.BoolNode(node.pos, value=flag,
                                       constant_result=flag))
3160

3161

3162 3163 3164
# Utility code from StringTools.c that the startswith()/endswith()
# optimisations above pass along with their tailmatch helper calls.
unicode_tailmatch_utility_code = UtilityCode.load_cached('unicode_tailmatch', 'StringTools.c')
bytes_tailmatch_utility_code = UtilityCode.load_cached('bytes_tailmatch', 'StringTools.c')
str_tailmatch_utility_code = UtilityCode.load_cached('str_tailmatch', 'StringTools.c')
3165

3166

3167 3168 3169 3170
class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
    """Calculate the result of constant expressions to store it in
    ``expr_node.constant_result``, and replace trivial cases by their
    constant result.
3171 3172 3173 3174 3175 3176 3177 3178 3179 3180 3181

    General rules:

    - We calculate float constants to make them available to the
      compiler, but we do not aggregate them into a single literal
      node to prevent any loss of precision.

    - We recursively calculate constants from non-literal nodes to
      make them available to the compiler, but we only aggregate
      literal nodes at each step.  Non-literal nodes are never merged
      into a single node.
3182
    """
3183

Mark Florisson's avatar
Mark Florisson committed
3184 3185 3186 3187 3188 3189 3190
    def __init__(self, reevaluate=False):
        """
        Set up the transform.

        The reevaluate argument specifies whether constant values that were
        previously computed should be recomputed.  It is stored as
        ``self.reevaluate`` and consulted by ``_calculate_const()``, which
        otherwise skips nodes whose ``constant_result`` was already set.
        """
        # initialise the base transform before storing our own configuration
        super(ConstantFolding, self).__init__()
        self.reevaluate = reevaluate
3191

3192
    def _calculate_const(self, node):
        """Visit the children of ``node`` and set ``node.constant_result``.

        Nodes that already carry a result are left alone unless
        ``self.reevaluate`` is set.  The result stays ``not_a_constant``
        when any child has no constant result, or when calculating the
        value fails with one of the 'normal' evaluation errors.  Any other
        exception is reported by printing its traceback to stdout.
        """
        if not (self.reevaluate or
                node.constant_result is ExprNodes.constant_value_not_set):
            return

        # make sure we always set the value
        unknown = ExprNodes.not_a_constant
        node.constant_result = unknown

        # give up unless every single child has a constant result
        for child_result in self.visitchildren(node).values():
            if type(child_result) is list:
                candidates = child_result
            else:
                candidates = [child_result]
            for child in candidates:
                if getattr(child, 'constant_result', unknown) is unknown:
                    return

        # all children are constant => try to calculate the real value
        try:
            node.calculate_constant_result()
        except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
            # ignore all 'normal' errors here => no constant result
            pass
        except Exception:
            # this looks like a real error
            import traceback, sys
            traceback.print_exc(file=sys.stdout)

3224 3225
    # literal node classes, ordered from the narrowest to the widest
    NODE_TYPE_ORDER = [
        ExprNodes.BoolNode,
        ExprNodes.CharNode,
        ExprNodes.IntNode,
        ExprNodes.FloatNode,
    ]

    def _widest_node_class(self, *nodes):
        """Return the class of ``nodes`` that comes last in NODE_TYPE_ORDER.

        Returns None if no nodes are passed or if the exact type of any
        node is not listed in NODE_TYPE_ORDER.
        """
        order = self.NODE_TYPE_ORDER
        try:
            widest_rank = max([order.index(type(item)) for item in nodes])
        except ValueError:
            # unknown node class, or nothing to compare
            return None
        return order[widest_rank]

3234 3235 3236 3237
    def _bool_node(self, node, value):
        """Build a ``BoolNode`` at the position of ``node`` that holds the
        truth value of ``value``, both as its value and constant result.
        """
        truth_value = bool(value)
        return ExprNodes.BoolNode(
            node.pos, value=truth_value, constant_result=truth_value)

3238 3239 3240 3241
    def visit_ExprNode(self, node):
        """Generic handler for expression nodes: calculate the constant
        result of ``node`` (see ``_calculate_const()``) and return the node
        itself unchanged.
        """
        self._calculate_const(node)
        return node

3242
    def visit_UnopNode(self, node):
        """Fold a unary operation on a literal operand into a literal node.

        Without a constant result, only the '!' operator is processed
        further (see ``_handle_NotNode()``).  With a constant result and a
        literal operand, '!' becomes a ``BoolNode``, any other operation on
        a ``BoolNode`` operand becomes a C int ``IntNode``, and '+' / '-'
        are passed on to their dedicated handlers.  Everything else is
        returned unchanged.
        """
        self._calculate_const(node)
        operator = node.operator
        if not node.has_constant_result():
            if operator != '!':
                return node
            return self._handle_NotNode(node)

        operand = node.operand
        if not operand.is_literal:
            return node

        if operator == '!':
            return self._bool_node(node, node.constant_result)
        if isinstance(operand, ExprNodes.BoolNode):
            int_value = int(node.constant_result)
            return ExprNodes.IntNode(
                node.pos, value=str(int_value),
                type=PyrexTypes.c_int_type,
                constant_result=int_value)
        if operator == '+':
            return self._handle_UnaryPlusNode(node)
        if operator == '-':
            return self._handle_UnaryMinusNode(node)
        return node

3262 3263 3264 3265 3266 3267 3268
    # maps a comparison operator to the operator with the opposite result,
    # returns None for operators that have no direct negation here
    _negate_operator = {
        'in': 'not_in',
        'not_in': 'in',
        'is': 'is_not',
        'is_not': 'is'
    }.get

    def _handle_NotNode(self, node):
        """Replace ``not (a OP b)`` by ``a NEGATED_OP b`` for the operators
        known to ``_negate_operator``.

        The operand is only rewritten if it is a plain, non-cascaded
        ``PrimaryCmpNode``.  In a cascaded comparison like
        ``not (a in b in c)``, negating the first operator alone would give
        ``a not in b in c``, i.e. ``(a not in b) and (b in c)``, which is
        not the negation of ``(a in b) and (b in c)``.  Such nodes are
        returned unchanged.

        Returns the folded comparison node, or ``node`` itself if nothing
        could be done.
        """
        operand = node.operand
        if isinstance(operand, ExprNodes.PrimaryCmpNode) and not operand.cascade:
            operator = self._negate_operator(operand.operator)
            if operator:
                # work on a shallow copy so the original comparison node
                # keeps its operator
                node = copy.copy(operand)
                node.operator = operator
                node = self.visit_PrimaryCmpNode(node)
        return node

3279
    def _handle_UnaryMinusNode(self, node):
        """Fold a unary minus into its literal operand by flipping the sign
        of the literal's value string.

        Applies to ``FloatNode`` operands, to operands of a signed C integer
        type, and to ``IntNode`` operands of Python object type.  The new
        literal keeps the operand type and takes over the constant result
        of ``node``.  Other operands leave ``node`` unchanged.
        """
        operand = node.operand
        operand_type = operand.type

        def flip_sign(literal):
            # strip an existing minus sign, otherwise prepend one
            return literal[1:] if literal.startswith('-') else '-' + literal

        if isinstance(operand, ExprNodes.FloatNode):
            # this is a safe operation
            return ExprNodes.FloatNode(
                node.pos, value=flip_sign(operand.value),
                type=operand_type,
                constant_result=node.constant_result)

        if ((operand_type.is_int and operand_type.signed) or
                (isinstance(operand, ExprNodes.IntNode) and operand_type.is_pyobject)):
            return ExprNodes.IntNode(
                node.pos, value=flip_sign(operand.value),
                type=operand_type,
                longness=operand.longness,
                constant_result=node.constant_result)
        return node

3301
    def _handle_UnaryPlusNode(self, node):
        """Drop a unary plus that does not change the constant value.

        Returns the operand if it has a constant result that equals the
        constant result of ``node``, otherwise ``node`` itself.
        """
        operand = node.operand
        is_noop = (operand.has_constant_result() and
                   node.constant_result == operand.constant_result)
        return operand if is_noop else node

3307 3308
    def visit_BoolBinopNode(self, node):
        """Short-circuit 'and' / 'or' when the first operand is constant.

        With a constant first operand, the node is replaced by the operand
        that decides its result: the second operand for a true 'and' or a
        false 'or', the first operand otherwise.  Without a constant first
        operand, ``node`` is returned unchanged.
        """
        self._calculate_const(node)
        first = node.operand1
        if not first.has_constant_result():
            return node
        first_is_true = bool(first.constant_result)
        is_and = node.operator == 'and'
        # 'and' moves on to the second operand after a true value,
        # any other operator ('or') does so after a false value
        if first_is_true == is_and:
            return node.operand2
        return first
3321

3322 3323 3324 3325
    def visit_BinopNode(self, node):
        """Replace a binary operation on two literals by a single literal.

        ``node`` is returned unchanged when it has no constant result, when
        the result is a float, when an operand is not a literal or has no
        type, or when the operand classes are not both listed in
        ``NODE_TYPE_ORDER``.  Otherwise, a new literal node of the widest
        operand class is created for the constant result.
        """
        self._calculate_const(node)
        result = node.constant_result
        if result is ExprNodes.not_a_constant or isinstance(result, float):
            return node
        left, right = node.operand1, node.operand2
        if not (left.is_literal and right.is_literal):
            return node

        # now inject a new constant node with the calculated value
        try:
            left_type, right_type = left.type, right.type
        except AttributeError:
            return node
        if left_type is None or right_type is None:
            return node

        if left_type.is_numeric and right_type.is_numeric:
            widest_type = PyrexTypes.widest_numeric_type(left_type, right_type)
        else:
            widest_type = PyrexTypes.py_object_type

        literal_class = self._widest_node_class(left, right)
        if literal_class is None:
            return node
        # C arithmetic results in at least an int type
        if literal_class is ExprNodes.BoolNode:
            if node.operator in '+-//<<%**>>':
                literal_class = ExprNodes.IntNode
        elif literal_class is ExprNodes.CharNode:
            if node.operator in '+-//<<%**>>&|^':
                literal_class = ExprNodes.IntNode

        if literal_class is not ExprNodes.IntNode:
            # only BoolNode keeps the plain result as its value
            if literal_class is ExprNodes.BoolNode:
                literal_value = result
            else:
                literal_value = str(result)
            return literal_class(
                pos=node.pos, type=widest_type,
                value=literal_value,
                constant_result=result)

        # both operands must be marked unsigned, the longer suffix wins
        unsigned = (getattr(left, 'unsigned', '') and
                    getattr(right, 'unsigned', ''))
        suffix_length = max(len(getattr(left, 'longness', '')),
                            len(getattr(right, 'longness', '')))
        int_result = int(result)
        folded = ExprNodes.IntNode(
            pos=node.pos,
            unsigned=unsigned, longness="LL"[:suffix_length],
            value=str(int_result),
            constant_result=int_result)
        # IntNode is smart about the type it chooses, so we just
        # make sure we were not smarter this time
        if widest_type.is_pyobject or folded.type.is_pyobject:
            folded.type = PyrexTypes.py_object_type
        else:
            folded.type = PyrexTypes.widest_numeric_type(widest_type, folded.type)
        return folded

3380
    def visit_MulNode(self, node):
        """Fold multiplications, with special handling of sequences.

        ``sequence * x`` and ``IntNode * sequence`` are handed to
        ``_calculate_constant_seq()``; every other multiplication is
        treated like a generic binary operation.
        """
        self._calculate_const(node)
        first, second = node.operand1, node.operand2
        if first.is_sequence_constructor:
            return self._calculate_constant_seq(node, first, second)
        if isinstance(first, ExprNodes.IntNode) and second.is_sequence_constructor:
            return self._calculate_constant_seq(node, second, first)
        return self.visit_BinopNode(node)

3389
    def _calculate_constant_seq(self, node, sequence_node, factor):
        """Merge the multiplication factor ``factor`` into ``sequence_node``.

        A factor of 1 or an empty sequence returns the sequence as it is.
        An integer factor <= 0 empties the sequence.  Otherwise, the factor
        becomes the ``mult_factor`` of the sequence.  If there already is
        one, the two are multiplied into a new ``IntNode`` when both are
        integer constants; if not, the node is handled as a generic binary
        operation instead.
        """
        factor_value = factor.constant_result
        needs_work = factor_value != 1 and sequence_node.args
        if not needs_work:
            return sequence_node

        int_types = (int, long)
        factor_is_int = isinstance(factor_value, int_types)
        if factor_is_int and factor_value <= 0:
            del sequence_node.args[:]
            sequence_node.mult_factor = None
        elif sequence_node.mult_factor is None:
            sequence_node.mult_factor = factor
        elif factor_is_int and isinstance(
                sequence_node.mult_factor.constant_result, int_types):
            value = sequence_node.mult_factor.constant_result * factor.constant_result
            sequence_node.mult_factor = ExprNodes.IntNode(
                sequence_node.mult_factor.pos,
                value=str(value), constant_result=value)
        else:
            # don't know if we can combine the factors, so don't
            return self.visit_BinopNode(node)
        return sequence_node

3408
    def visit_PrimaryCmpNode(self, node):
        """Fold the constant parts of a (possibly cascaded) comparison.

        A plain comparison with a constant result becomes a ``BoolNode``.
        A cascade is cut at every comparison with a constant result: a
        constant true comparison is dropped and starts a new partial
        cascade, a constant false comparison ends the whole expression.
        The remaining partial cascades are joined by 'and'.
        """
        # calculate constant partial results in the comparison cascade
        self.visitchildren(node, ['operand1'])
        lhs = node.operand1
        link = node
        while link is not None:
            self.visitchildren(link, ['operand2'])
            rhs = link.operand2
            link.constant_result = not_a_constant
            if lhs.has_constant_result() and rhs.has_constant_result():
                try:
                    link.calculate_cascaded_constant_result(lhs.constant_result)
                except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
                    pass  # ignore all 'normal' errors here => no constant result
            lhs, link = rhs, link.cascade

        if not node.cascade:
            if node.has_constant_result():
                return self._bool_node(node, node.constant_result)
            return node

        # collect partial cascades: [[value, CmpNode...], [value, CmpNode, ...], ...]
        cascades = [[node.operand1]]
        false_result = None
        link = node
        while True:
            if not link.has_constant_result():
                # not constant => append to current cascade
                cascades[-1].append(link)
            elif link.constant_result:
                # True => discard and start new cascade
                cascades.append([link.operand2])
            else:
                # False => short-circuit
                false_result = self._bool_node(link, False)
                break
            link = link.cascade
            if not link:
                break

        # rebuild one PrimaryCmpNode per partial cascade that still
        # contains at least one comparison
        cmp_nodes = []
        for cascade in cascades:
            if len(cascade) < 2:
                continue
            first_cmp = cascade[1]
            head = ExprNodes.PrimaryCmpNode(
                first_cmp.pos,
                operand1=cascade[0],
                operator=first_cmp.operator,
                operand2=first_cmp.operand2,
                constant_result=not_a_constant)
            cmp_nodes.append(head)

            # re-link the rest of the cascade and terminate it
            chain = [head] + cascade[2:]
            for current, successor in zip(chain, chain[1:] + [None]):
                current.cascade = successor

        if false_result is not None:
            # last cascade was constant False
            cmp_nodes.append(false_result)
        elif not cmp_nodes:
            # only constants, but no False result
            return self._bool_node(node, True)

        combined = cmp_nodes[0]
        if len(cmp_nodes) == 1:
            if combined.has_constant_result():
                return self._bool_node(combined, combined.constant_result)
            return combined
        for cmp_node in cmp_nodes[1:]:
            combined = ExprNodes.BoolBinopNode(
                combined.pos,
                operand1=combined,
                operator='and',
                operand2=cmp_node,
                constant_result=not_a_constant)
        return combined
3489

3490 3491
    def visit_CondExprNode(self, node):
        """Replace a conditional expression that has a constant test by the
        branch that the test selects; return ``node`` unchanged otherwise.
        """
        self._calculate_const(node)
        test = node.test
        if not test.has_constant_result():
            return node
        return node.true_val if test.constant_result else node.false_val

3499 3500 3501 3502 3503
    def visit_IfStatNode(self, node):
        """Eliminate dead code based on constant condition results.

        Clauses with a constant false condition are dropped.  The first
        clause with a constant true condition turns its body into the else
        clause and ends the scan.  If no clause is left, the statement is
        replaced by its else clause, or by an empty ``StatListNode`` if
        there is none.
        """
        self.visitchildren(node)
        kept_clauses = []
        for clause in node.if_clauses:
            condition = clause.condition
            if not condition.has_constant_result():
                # unknown result => normal runtime evaluation
                kept_clauses.append(clause)
                continue
            if condition.constant_result:
                # always true => subsequent clauses can safely be dropped
                node.else_clause = clause.body
                break
            # else: false => drop clause

        if kept_clauses:
            node.if_clauses = kept_clauses
            return node
        if node.else_clause:
            return node.else_clause
        return Nodes.StatListNode(node.pos, stats=[])
3521

3522 3523 3524
    def visit_SliceIndexNode(self, node):
        """Normalise constant slice bounds and slice constant bases.

        A start/stop bound that is missing or has the constant value None
        is reset to None on the node.  If the slice has a constant result,
        a sequence constructor base without ``mult_factor`` gets its
        arguments sliced in place and is returned, and a string literal
        base is replaced by its sliced node if it provides one.
        """
        self._calculate_const(node)

        def normalise_bound(attr_name):
            # returns the constant bound value, clearing None-valued bounds
            bound = getattr(node, attr_name)
            if bound is None or bound.constant_result is None:
                setattr(node, attr_name, None)
                return None
            return bound.constant_result

        start = normalise_bound('start')
        stop = normalise_bound('stop')

        # cut down sliced constant sequences
        if node.constant_result is not_a_constant:
            return node
        base = node.base
        if base.is_sequence_constructor and base.mult_factor is None:
            base.args = base.args[start:stop]
            return base
        if base.is_string_literal:
            sliced = base.as_sliced_node(start, stop)
            if sliced is not None:
                return sliced
        return node

3545 3546 3547 3548
    def visit_ComprehensionNode(self, node):
        """Turn a comprehension whose loop was reduced to an empty
        ``StatListNode`` into an empty list, set or dict literal, depending
        on the type of the comprehension.  Other nodes stay unchanged.
        """
        self.visitchildren(node)
        loop = node.loop
        if not isinstance(loop, Nodes.StatListNode) or loop.stats:
            return node

        # loop was pruned already => transform into literal
        result_type = node.type
        if result_type is Builtin.list_type:
            return ExprNodes.ListNode(
                node.pos, args=[], constant_result=[])
        if result_type is Builtin.set_type:
            return ExprNodes.SetNode(
                node.pos, args=[], constant_result=set())
        if result_type is Builtin.dict_type:
            return ExprNodes.DictNode(
                node.pos, key_value_pairs=[], constant_result={})
        return node

3560 3561 3562
    def visit_ForInStatNode(self, node):
        """Simplify for-in loops over sequence literals.

        A loop over an empty sequence literal is replaced by its else
        clause, or by an empty ``StatListNode`` if it has none.  A list
        literal that is iterated over is converted into a tuple.
        """
        self.visitchildren(node)
        sequence = node.iterator.sequence
        if not isinstance(sequence, ExprNodes.SequenceNode):
            return node

        if not sequence.args:
            if node.else_clause:
                return node.else_clause
            # don't break list comprehensions
            return Nodes.StatListNode(node.pos, stats=[])

        # iterating over a list literal? => tuples are more efficient
        if isinstance(sequence, ExprNodes.ListNode):
            node.iterator.sequence = sequence.as_tuple()
        return node

3575 3576
    def visit_WhileStatNode(self, node):
        """Simplify while loops that have a constant condition.

        A constant true condition is removed from the loop together with
        the else clause.  A constant false condition replaces the whole
        loop by its else clause (which may be None).
        """
        self.visitchildren(node)
        condition = node.condition
        if not condition or not condition.has_constant_result():
            return node
        if not condition.constant_result:
            # never entered => only the else clause remains
            return node.else_clause
        node.condition = None
        node.else_clause = None
        return node

3585 3586
    def visit_ExprStatNode(self, node):
        """Drop unused constant expressions.

        Returns None for an expression statement whose expression has a
        constant result, and ``node`` itself in all other cases.
        """
        self.visitchildren(node)
        expr = node.expr
        # ParallelRangeTransform may have put a non-expression here,
        # so check the node class before asking for a constant result
        if isinstance(expr, ExprNodes.ExprNode) and expr.has_constant_result():
            return None
        return node

3595 3596
    # in the future, other nodes can have their own handler method here
    # that can replace them with a constant result node
Stefan Behnel's avatar
Stefan Behnel committed
3597

3598
    visit_Node = Visitor.VisitorTransform.recurse_to_children
3599 3600


3601 3602 3603
class FinalOptimizePhase(Visitor.CythonTransform):
    """
    This visitor handles several commuting optimizations, and is run
    just before the C code generation phase.

    The optimizations currently implemented in this class are:
        - eliminate None assignment and refcounting for first assignment.
        - isinstance -> typecheck for cdef types
        - eliminate checks for None and/or types that became redundant after tree changes
    """
    def visit_SingleAssignmentNode(self, node):
        """Mark the target of a first assignment so that the redundant
        initialisation of the local variable can be avoided.
        """
        self.visitchildren(node)
        if node.first:
            node.lhs.lhs_of_first_assignment = True
        return node

    def visit_SimpleCallNode(self, node):
        """Turn a C-level call ``isinstance(x, type)``, where the second
        argument has the builtin type 'type', into a call to
        'PyObject_TypeCheck' with the argument cast to a 'PyTypeObject'
        pointer.
        """
        self.visitchildren(node)
        function = node.function
        if not (function.type.is_cfunction and isinstance(function, ExprNodes.NameNode)):
            return node
        if not (function.name == 'isinstance' and len(node.args) == 2):
            return node
        type_arg = node.args[1]
        if not (type_arg.type.is_builtin_type and type_arg.type.name == 'type'):
            return node

        cython_scope = self.context.cython_scope
        function.entry = cython_scope.lookup('PyObject_TypeCheck')
        function.type = function.entry.type
        type_object_ptr = PyrexTypes.CPtrType(cython_scope.lookup('PyTypeObject').type)
        node.args[1] = ExprNodes.CastNode(node.args[1], type_object_ptr)
        return node

    def visit_PyTypeTestNode(self, node):
        """Set ``notnone`` on type tests whose argument is known not to be
        None, which removes the alternative None case from the test.
        """
        self.visitchildren(node)
        if not node.notnone and not node.arg.may_be_none():
            node.notnone = True
        return node

    def visit_NoneCheckNode(self, node):
        """Replace a None check by its argument when the argument
        definitely does not carry a None value.
        """
        self.visitchildren(node)
        if node.arg.may_be_none():
            return node
        return node.arg
3656 3657 3658 3659 3660 3661 3662 3663 3664 3665 3666 3667 3668 3669 3670 3671 3672 3673 3674 3675 3676

class ConsolidateOverflowCheck(Visitor.CythonTransform):
    """
    This class facilitates the sharing of overflow checking among all nodes
    of a nested arithmetic expression.  For example, given the expression
    a*b + c, where a, b, and c are all possibly overflowing ints, the entire
    sequence will be evaluated and the overflow bit checked only at the end.
    """
    overflow_bit_node = None
    
    def visit_Node(self, node):
        """Visit the children of ``node`` with ``overflow_bit_node`` unset.

        A node reference that is currently set is put aside while the
        children are visited and put back in place afterwards.
        """
        outer_bit_node = self.overflow_bit_node
        if outer_bit_node is None:
            self.visitchildren(node)
            return node
        self.overflow_bit_node = None
        self.visitchildren(node)
        self.overflow_bit_node = outer_bit_node
        return node
    
    def visit_NumBinopNode(self, node):
        """Share one overflow check between nested arithmetic operations.

        The outermost node that has both ``overflow_check`` and
        ``overflow_fold`` set is remembered in ``self.overflow_bit_node``
        while its children are visited.  Such nodes further down get a
        reference to it as their ``overflow_bit_node`` and have their own
        ``overflow_check`` disabled.  Other nodes only have their children
        visited.
        """
        if not (node.overflow_check and node.overflow_fold):
            self.visitchildren(node)
            return node

        is_outermost = self.overflow_bit_node is None
        if is_outermost:
            self.overflow_bit_node = node
        else:
            node.overflow_bit_node = self.overflow_bit_node
            node.overflow_check = False
        self.visitchildren(node)
        if is_outermost:
            self.overflow_bit_node = None
        return node