Commit 656196c3 authored by Stefan Behnel's avatar Stefan Behnel

reduce Unicode iteration over Latin-1 literals to bytes iteration

parent 5f0b23cb
...@@ -285,6 +285,29 @@ class IterationTransform(Visitor.EnvTransform): ...@@ -285,6 +285,29 @@ class IterationTransform(Visitor.EnvTransform):
exception_value = '-1') exception_value = '-1')
def _transform_unicode_iteration(self, node, slice_node, reversed=False): def _transform_unicode_iteration(self, node, slice_node, reversed=False):
if slice_node.is_literal:
# try to reduce to byte iteration for plain Latin-1 strings
try:
bytes_value = BytesLiteral(slice_node.value.encode('latin1'))
except UnicodeEncodeError:
pass
else:
bytes_slice = ExprNodes.SliceIndexNode(
slice_node.pos,
base=ExprNodes.BytesNode(
slice_node.pos, value=bytes_value,
constant_result=bytes_value,
type=PyrexTypes.c_char_ptr_type).coerce_to(
PyrexTypes.c_uchar_ptr_type, self.current_env()),
start=None,
stop=ExprNodes.IntNode(
slice_node.pos, value=len(bytes_value),
constant_result=len(bytes_value),
type=PyrexTypes.c_py_ssize_t_type),
type=Builtin.unicode_type, # hint for Python conversion
)
return self._transform_carray_iteration(node, bytes_slice, reversed)
unpack_temp_node = UtilNodes.LetRefNode( unpack_temp_node = UtilNodes.LetRefNode(
slice_node.as_none_safe_node("'NoneType' is not iterable")) slice_node.as_none_safe_node("'NoneType' is not iterable"))
...@@ -455,22 +478,32 @@ class IterationTransform(Visitor.EnvTransform): ...@@ -455,22 +478,32 @@ class IterationTransform(Visitor.EnvTransform):
counter_temp = counter.ref(node.target.pos) counter_temp = counter.ref(node.target.pos)
if slice_base.type.is_string and node.target.type.is_pyobject: if slice_base.type.is_string and node.target.type.is_pyobject:
# special case: char* -> bytes # special case: char* -> bytes/unicode
target_value = ExprNodes.SliceIndexNode( if slice_node.type is Builtin.unicode_type:
node.target.pos, target_value = ExprNodes.CastNode(
start=ExprNodes.IntNode(node.target.pos, value='0', ExprNodes.DereferenceNode(
constant_result=0, node.target.pos, operand=counter_temp,
type=PyrexTypes.c_int_type), type=ptr_type.base_type),
stop=ExprNodes.IntNode(node.target.pos, value='1', PyrexTypes.c_py_ucs4_type).coerce_to(
constant_result=1, node.target.type, self.current_env())
type=PyrexTypes.c_int_type), else:
base=counter_temp, # char* -> bytes coercion requires slicing, not indexing
type=Builtin.bytes_type, target_value = ExprNodes.SliceIndexNode(
is_temp=1) node.target.pos,
start=ExprNodes.IntNode(node.target.pos, value='0',
constant_result=0,
type=PyrexTypes.c_int_type),
stop=ExprNodes.IntNode(node.target.pos, value='1',
constant_result=1,
type=PyrexTypes.c_int_type),
base=counter_temp,
type=Builtin.bytes_type,
is_temp=1)
elif node.target.type.is_ptr and not node.target.type.assignable_from(ptr_type.base_type): elif node.target.type.is_ptr and not node.target.type.assignable_from(ptr_type.base_type):
# Allow iteration with pointer target to avoid copy. # Allow iteration with pointer target to avoid copy.
target_value = counter_temp target_value = counter_temp
else: else:
# TODO: can this safely be replaced with DereferenceNode() as above?
target_value = ExprNodes.IndexNode( target_value = ExprNodes.IndexNode(
node.target.pos, node.target.pos,
index=ExprNodes.IntNode(node.target.pos, value='0', index=ExprNodes.IntNode(node.target.pos, value='0',
......
...@@ -291,7 +291,7 @@ def loop_over_unicode_literal(): ...@@ -291,7 +291,7 @@ def loop_over_unicode_literal():
""" """
# Py_UCS4 can represent any Unicode character # Py_UCS4 can represent any Unicode character
for uchar in 'abcdefg': for uchar in 'abcdefg':
pass assert uchar in 'abcdefg'
return cython.typeof(uchar) return cython.typeof(uchar)
def list_comp(): def list_comp():
......
...@@ -209,14 +209,31 @@ def count_lower_case_characters_slice_reversed(unicode ustring): ...@@ -209,14 +209,31 @@ def count_lower_case_characters_slice_reversed(unicode ustring):
count += 1 count += 1
return count return count
def loop_object_over_latin1_unicode_literal():
"""
>>> result = loop_object_over_latin1_unicode_literal()
>>> print(result[:-1])
abcdefg
>>> ord(result[-1]) == 0xD7
True
"""
cdef object uchar
chars = []
for uchar in u'abcdefg\xD7':
chars.append(uchar)
return u''.join(chars)
def loop_object_over_unicode_literal(): def loop_object_over_unicode_literal():
""" """
>>> print(loop_object_over_unicode_literal()) >>> result = loop_object_over_unicode_literal()
>>> print(result[:-1])
abcdefg abcdefg
>>> ord(result[-1]) == 0xF8FD
True
""" """
cdef object uchar cdef object uchar
chars = [] chars = []
for uchar in u'abcdefg': for uchar in u'abcdefg\uF8FD':
chars.append(uchar) chars.append(uchar)
return u''.join(chars) return u''.join(chars)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment