Commit 1da05d6e authored by Stefan Behnel's avatar Stefan Behnel

improve type inference for string %/+/* operations and use more direct C-API...

improve type inference for string %/+/* operations and use more direct C-API calls for these unicode operations
parent 08c7bf8b
...@@ -28,8 +28,8 @@ import PyrexTypes ...@@ -28,8 +28,8 @@ import PyrexTypes
from PyrexTypes import py_object_type, c_long_type, typecast, error_type, \ from PyrexTypes import py_object_type, c_long_type, typecast, error_type, \
unspecified_type unspecified_type
import TypeSlots import TypeSlots
from Builtin import list_type, tuple_type, set_type, dict_type, \ from Builtin import list_type, tuple_type, set_type, dict_type, type_type, \
unicode_type, str_type, bytes_type, bytearray_type, type_type unicode_type, str_type, bytes_type, bytearray_type, basestring_type
import Builtin import Builtin
import Symtab import Symtab
from Cython import Utils from Cython import Utils
...@@ -8792,30 +8792,20 @@ class BinopNode(ExprNode): ...@@ -8792,30 +8792,20 @@ class BinopNode(ExprNode):
type1 = Builtin.bytes_type type1 = Builtin.bytes_type
elif type1.is_pyunicode_ptr: elif type1.is_pyunicode_ptr:
type1 = Builtin.unicode_type type1 = Builtin.unicode_type
elif self.operator == '%' \ if type1.is_builtin_type or type2.is_builtin_type:
and type1 in (Builtin.str_type, Builtin.unicode_type): if type1 is type2 and self.operator in '**%+|&^':
# note that b'%s' % b'abc' doesn't work in Py3 # FIXME: at least these operators should be safe - others?
return type1 return type1
if type1.is_builtin_type: result_type = self.infer_builtin_types_operation(type1, type2)
if type1 is type2: if result_type is not None:
if self.operator in '**%+|&^': return result_type
# FIXME: at least these operators should be safe - others?
return type1
elif self.operator == '*':
if type1 in (Builtin.bytes_type, Builtin.str_type, Builtin.unicode_type):
return type1
# multiplication of containers/numbers with an
# integer value always (?) returns the same type
if type2.is_int:
return type1
elif type2.is_builtin_type and type1.is_int and self.operator == '*':
# multiplication of containers/numbers with an
# integer value always (?) returns the same type
return type2
return py_object_type return py_object_type
else: else:
return self.compute_c_result_type(type1, type2) return self.compute_c_result_type(type1, type2)
def infer_builtin_types_operation(self, type1, type2):
return None
def nogil_check(self, env): def nogil_check(self, env):
if self.is_py_operation(): if self.is_py_operation():
self.gil_error() self.gil_error()
...@@ -9019,14 +9009,15 @@ class NumBinopNode(BinopNode): ...@@ -9019,14 +9009,15 @@ class NumBinopNode(BinopNode):
"%": "PyNumber_Remainder", "%": "PyNumber_Remainder",
"**": "PyNumber_Power" "**": "PyNumber_Power"
} }
overflow_op_names = { overflow_op_names = {
"+": "add", "+": "add",
"-": "sub", "-": "sub",
"*": "mul", "*": "mul",
"<<": "lshift", "<<": "lshift",
} }
class IntBinopNode(NumBinopNode): class IntBinopNode(NumBinopNode):
# Binary operation taking integer arguments. # Binary operation taking integer arguments.
...@@ -9045,6 +9036,15 @@ class AddNode(NumBinopNode): ...@@ -9045,6 +9036,15 @@ class AddNode(NumBinopNode):
else: else:
return NumBinopNode.is_py_operation_types(self, type1, type2) return NumBinopNode.is_py_operation_types(self, type1, type2)
def infer_builtin_types_operation(self, type1, type2):
# b'abc' + 'abc' raises an exception in Py3,
# so we can safely infer the Py2 type for bytes here
string_types = [bytes_type, str_type, basestring_type, unicode_type] # Py2.4 lacks tuple.index()
if type1 in string_types and type2 in string_types:
return string_types[max(string_types.index(type1),
string_types.index(type2))]
return None
def compute_c_result_type(self, type1, type2): def compute_c_result_type(self, type1, type2):
#print "AddNode.compute_c_result_type:", type1, self.operator, type2 ### #print "AddNode.compute_c_result_type:", type1, self.operator, type2 ###
if (type1.is_ptr or type1.is_array) and (type2.is_int or type2.is_enum): if (type1.is_ptr or type1.is_array) and (type2.is_int or type2.is_enum):
...@@ -9055,6 +9055,16 @@ class AddNode(NumBinopNode): ...@@ -9055,6 +9055,16 @@ class AddNode(NumBinopNode):
return NumBinopNode.compute_c_result_type( return NumBinopNode.compute_c_result_type(
self, type1, type2) self, type1, type2)
def py_operation_function(self):
type1, type2 = self.operand1.type, self.operand2.type
if type1 is unicode_type or type2 is unicode_type:
if type1.is_builtin_type and type2.is_builtin_type:
if self.operand1.may_be_none() or self.operand2.may_be_none():
return '__Pyx_PyUnicode_Concat'
else:
return 'PyUnicode_Concat'
return super(AddNode, self).py_operation_function()
class SubNode(NumBinopNode): class SubNode(NumBinopNode):
# '-' operator. # '-' operator.
...@@ -9073,12 +9083,28 @@ class MulNode(NumBinopNode): ...@@ -9073,12 +9083,28 @@ class MulNode(NumBinopNode):
# '*' operator. # '*' operator.
def is_py_operation_types(self, type1, type2): def is_py_operation_types(self, type1, type2):
if (type1.is_string and type2.is_int) \ if ((type1.is_string and type2.is_int) or
or (type2.is_string and type1.is_int): (type2.is_string and type1.is_int)):
return 1 return 1
else: else:
return NumBinopNode.is_py_operation_types(self, type1, type2) return NumBinopNode.is_py_operation_types(self, type1, type2)
def infer_builtin_types_operation(self, type1, type2):
# let's assume that whatever builtin type you multiply a string with
# will either return a string of the same type or fail with an exception
string_types = (bytes_type, str_type, basestring_type, unicode_type)
if type1 in string_types and type2.is_builtin_type:
return type1
if type2 in string_types and type1.is_builtin_type:
return type2
# multiplication of containers/numbers with an integer value
# always (?) returns the same type
if type1.is_int:
return type2
if type2.is_int:
return type1
return None
class DivNode(NumBinopNode): class DivNode(NumBinopNode):
# '/' or '//' operator. # '/' or '//' operator.
...@@ -9218,9 +9244,9 @@ class DivNode(NumBinopNode): ...@@ -9218,9 +9244,9 @@ class DivNode(NumBinopNode):
return "(%s / %s)" % (op1, op2) return "(%s / %s)" % (op1, op2)
else: else:
return "__Pyx_div_%s(%s, %s)" % ( return "__Pyx_div_%s(%s, %s)" % (
self.type.specialization_name(), self.type.specialization_name(),
self.operand1.result(), self.operand1.result(),
self.operand2.result()) self.operand2.result())
class ModNode(DivNode): class ModNode(DivNode):
...@@ -9228,8 +9254,25 @@ class ModNode(DivNode): ...@@ -9228,8 +9254,25 @@ class ModNode(DivNode):
def is_py_operation_types(self, type1, type2): def is_py_operation_types(self, type1, type2):
return (type1.is_string return (type1.is_string
or type2.is_string or type2.is_string
or NumBinopNode.is_py_operation_types(self, type1, type2)) or NumBinopNode.is_py_operation_types(self, type1, type2))
def infer_builtin_types_operation(self, type1, type2):
# b'%s' % xyz raises an exception in Py3, so it's safe to infer the type for Py2
if type1 is unicode_type:
# None + xyz may be implemented by RHS
if type2.is_builtin_type or not self.operand1.may_be_none():
return type1
elif type1 in (bytes_type, str_type, basestring_type):
if type2 is unicode_type:
return type2
elif type2.is_numeric:
return type1
elif type1 is bytes_type and not type2.is_builtin_type:
return None # RHS might implement '% operator differently in Py3
else:
return basestring_type # either str or unicode, can't tell
return None
def zero_division_message(self): def zero_division_message(self):
if self.type.is_int: if self.type.is_int:
...@@ -9275,6 +9318,15 @@ class ModNode(DivNode): ...@@ -9275,6 +9318,15 @@ class ModNode(DivNode):
self.operand1.result(), self.operand1.result(),
self.operand2.result()) self.operand2.result())
def py_operation_function(self):
if self.operand1.type is unicode_type and self.operand2.type.is_builtin_type:
if self.operand1.may_be_none():
return '__Pyx_PyUnicode_Format'
else:
return 'PyUnicode_Format'
return super(ModNode, self).py_operation_function()
class PowNode(NumBinopNode): class PowNode(NumBinopNode):
# '**' operator. # '**' operator.
......
...@@ -157,6 +157,10 @@ ...@@ -157,6 +157,10 @@
#define __Pyx_PyUnicode_READ(k, d, i) ((k=k), (Py_UCS4)(((Py_UNICODE*)d)[i])) #define __Pyx_PyUnicode_READ(k, d, i) ((k=k), (Py_UCS4)(((Py_UNICODE*)d)[i]))
#endif #endif
#define __Pyx_PyUnicode_Format(a, b) ((unlikely((a) == Py_None)) ? PyNumber_Remainder(a, b) : PyUnicode_Format(a, b))
#define __Pyx_PyUnicode_Concat(a, b) ((unlikely((a) == Py_None) || unlikely((b) == Py_None)) ? \
PyNumber_Add(a, b) : PyUnicode_Concat(a, b))
#if PY_MAJOR_VERSION >= 3 #if PY_MAJOR_VERSION >= 3
#define PyBaseString_Type PyUnicode_Type #define PyBaseString_Type PyUnicode_Type
#define PyStringObject PyUnicodeObject #define PyStringObject PyUnicodeObject
......
...@@ -8,6 +8,9 @@ PY_VERSION = sys.version_info ...@@ -8,6 +8,9 @@ PY_VERSION = sys.version_info
text = u'ab jd sdflk as sa sadas asdas fsdf ' text = u'ab jd sdflk as sa sadas asdas fsdf '
sep = u' ' sep = u' '
format1 = u'abc%sdef'
format2 = u'abc%sdef%sghi'
unicode_sa = u'sa'
multiline_text = u'''\ multiline_text = u'''\
ab jd ab jd
...@@ -383,6 +386,122 @@ def in_test(unicode s, substring): ...@@ -383,6 +386,122 @@ def in_test(unicode s, substring):
return substring in s return substring in s
# unicode.__concat__(s, suffix)
def concat_any(unicode s, suffix):
"""
>>> concat(text, 'sa') == text + 'sa' or concat(text, 'sa')
True
>>> concat(None, 'sa') # doctest: +ELLIPSIS
Traceback (most recent call last):
TypeError: ...
>>> concat(text, None) # doctest: +ELLIPSIS
Traceback (most recent call last):
TypeError: ...
>>> class RAdd(object):
... def __radd__(self, other):
... return 123
>>> concat(None, 'sa') # doctest: +ELLIPSIS
Traceback (most recent call last):
TypeError: ...
"""
assert cython.typeof(s + suffix) == 'Python object', cython.typeof(s + suffix)
return s + suffix
def concat(unicode s, str suffix):
"""
>>> concat(text, 'sa') == text + 'sa' or concat(text, 'sa')
True
>>> concat(None, 'sa') # doctest: +ELLIPSIS
Traceback (most recent call last):
TypeError: ...
>>> concat(text, None) # doctest: +ELLIPSIS
Traceback (most recent call last):
TypeError: ...
>>> class RAdd(object):
... def __radd__(self, other):
... return 123
>>> concat(None, 'sa') # doctest: +ELLIPSIS
Traceback (most recent call last):
TypeError: ...
"""
assert cython.typeof(s + object()) == 'Python object', cython.typeof(s + object())
assert cython.typeof(s + suffix) == 'unicode object', cython.typeof(s + suffix)
return s + suffix
def concat_literal_str(str suffix):
"""
>>> concat_literal_str('sa') == 'abcsa' or concat_literal_str('sa')
True
>>> concat_literal_str(None) # doctest: +ELLIPSIS
Traceback (most recent call last):
TypeError: ...NoneType...
"""
assert cython.typeof(u'abc' + object()) == 'Python object', cython.typeof(u'abc' + object())
assert cython.typeof(u'abc' + suffix) == 'unicode object', cython.typeof(u'abc' + suffix)
return u'abc' + suffix
def concat_literal_unicode(unicode suffix):
"""
>>> concat_literal_unicode(unicode_sa) == 'abcsa' or concat_literal_unicode(unicode_sa)
True
>>> concat_literal_unicode(None) # doctest: +ELLIPSIS
Traceback (most recent call last):
TypeError: ...NoneType...
"""
assert cython.typeof(u'abc' + suffix) == 'unicode object', cython.typeof(u'abc' + suffix)
return u'abc' + suffix
# unicode.__mod__(format, values)
def mod_format(unicode s, values):
"""
>>> mod_format(format1, 'sa') == 'abcsadef' or mod_format(format1, 'sa')
True
>>> mod_format(format2, ('XYZ', 'ABC')) == 'abcXYZdefABCghi' or mod_format(format2, ('XYZ', 'ABC'))
True
>>> mod_format(None, 'sa') # doctest: +ELLIPSIS
Traceback (most recent call last):
TypeError: unsupported operand type(s) for %: 'NoneType' and 'str'
>>> class RMod(object):
... def __rmod__(self, other):
... return 123
>>> mod_format(None, RMod())
123
"""
assert cython.typeof(s % values) == 'Python object', cython.typeof(s % values)
return s % values
def mod_format_literal(values):
"""
>>> mod_format_literal('sa') == 'abcsadef' or mod_format(format1, 'sa')
True
>>> mod_format_literal(('sa',)) == 'abcsadef' or mod_format(format1, ('sa',))
True
>>> mod_format_literal(['sa']) == "abc['sa']def" or mod_format(format1, ['sa'])
True
"""
assert cython.typeof(u'abc%sdef' % values) == 'unicode object', cython.typeof(u'abc%sdef' % values)
return u'abc%sdef' % values
def mod_format_tuple(*values):
"""
>>> mod_format_tuple('sa') == 'abcsadef' or mod_format(format1, 'sa')
True
>>> mod_format_tuple()
Traceback (most recent call last):
TypeError: not enough arguments for format string
"""
assert cython.typeof(u'abc%sdef' % values) == 'unicode object', cython.typeof(u'abc%sdef' % values)
return u'abc%sdef' % values
# unicode.find(s, sub, [start, [end]]) # unicode.find(s, sub, [start, [end]])
@cython.test_fail_if_path_exists( @cython.test_fail_if_path_exists(
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment