Commit 00dd1dfb authored by Stefan Behnel's avatar Stefan Behnel

optimised for-in-reversed(array) etc., including char*, bytes and unicode

parent c7e5af8b
......@@ -160,15 +160,9 @@ class IterationTransform(Visitor.VisitorTransform):
return self._transform_carray_iteration(node, plain_iterator)
if iterator.type.is_ptr or iterator.type.is_array:
if reversed:
# TODO: implement
return node
return self._transform_carray_iteration(node, iterator)
return self._transform_carray_iteration(node, iterator, reversed=reversed)
if iterator.type in (Builtin.bytes_type, Builtin.unicode_type):
if reversed:
# TODO: implement
return node
return self._transform_string_iteration(node, iterator)
return self._transform_string_iteration(node, iterator, reversed=reversed)
# the rest is based on function calls
if not isinstance(iterator, ExprNodes.SimpleCallNode):
......@@ -253,7 +247,7 @@ class IterationTransform(Visitor.VisitorTransform):
PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
])
def _transform_string_iteration(self, node, slice_node):
def _transform_string_iteration(self, node, slice_node, reversed=False):
if not node.target.type.is_int:
return self._transform_carray_iteration(node, slice_node)
if slice_node.type is Builtin.unicode_type:
......@@ -295,9 +289,10 @@ class IterationTransform(Visitor.VisitorTransform):
stop = len_node,
type = slice_base_node.type,
is_temp = 1,
)))
),
reversed = reversed))
def _transform_carray_iteration(self, node, slice_node):
def _transform_carray_iteration(self, node, slice_node, reversed=False):
neg_step = False
if isinstance(slice_node, ExprNodes.SliceIndexNode):
slice_base = slice_node.base
......@@ -327,10 +322,13 @@ class IterationTransform(Visitor.VisitorTransform):
return node
else:
# step sign is handled internally by ForFromStatNode
neg_step = step.constant_result < 0
step_value = step.constant_result
if reversed:
step_value = -step_value
neg_step = step_value < 0
step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type,
value=abs(step.constant_result),
constant_result=abs(step.constant_result))
value=str(abs(step_value)),
constant_result=abs(step_value))
elif slice_node.type.is_array:
if slice_node.type.size is None:
error(step.pos, "C array iteration requires known end index")
......@@ -365,6 +363,13 @@ class IterationTransform(Visitor.VisitorTransform):
error(slice_node.pos, "C array iteration requires known step size and end index")
return node
if reversed:
if not start:
start = ExprNodes.IntNode(slice_node.pos, value="0", constant_result=0,
type=PyrexTypes.c_py_ssize_t_type)
# if step was provided, it was already negated above
start, stop = stop, start
ptr_type = slice_base.type
if ptr_type.is_array:
ptr_type = ptr_type.element_ptr_type()
......@@ -380,13 +385,16 @@ class IterationTransform(Visitor.VisitorTransform):
else:
start_ptr_node = carray_ptr
stop_ptr_node = ExprNodes.AddNode(
stop.pos,
operand1=ExprNodes.CloneNode(carray_ptr),
operator='+',
operand2=stop,
type=ptr_type
).coerce_to_simple(self.current_scope)
if stop and stop.constant_result != 0:
stop_ptr_node = ExprNodes.AddNode(
stop.pos,
operand1=ExprNodes.CloneNode(carray_ptr),
operator='+',
operand2=stop,
type=ptr_type
).coerce_to_simple(self.current_scope)
else:
stop_ptr_node = ExprNodes.CloneNode(carray_ptr)
counter = UtilNodes.TempHandle(ptr_type)
counter_temp = counter.ref(node.target.pos)
......@@ -430,11 +438,13 @@ class IterationTransform(Visitor.VisitorTransform):
node.pos,
stats = [target_assign, node.body])
relation1, relation2 = self._find_for_from_node_relations(neg_step, reversed)
for_node = Nodes.ForFromStatNode(
node.pos,
bound1=start_ptr_node, relation1=neg_step and '>=' or '<=',
bound1=start_ptr_node, relation1=relation1,
target=counter_temp,
relation2=neg_step and '>' or '<', bound2=stop_ptr_node,
relation2=relation2, bound2=stop_ptr_node,
step=step, body=body,
else_clause=node.else_clause,
from_range=True)
......@@ -511,7 +521,19 @@ class IterationTransform(Visitor.VisitorTransform):
node.iterator.sequence = enumerate_function.arg_tuple.args[0]
# recurse into loop to check for further optimisations
return UtilNodes.LetNode(temp, self._optimise_for_loop(node))
return UtilNodes.LetNode(temp, self._optimise_for_loop(node, node.iterator.sequence))
def _find_for_from_node_relations(self, neg_step_value, reversed):
if reversed:
if neg_step_value:
return '<', '<='
else:
return '>', '>='
else:
if neg_step_value:
return '>=', '>'
else:
return '<=', '<'
def _transform_range_iteration(self, node, range_function, reversed=False):
args = range_function.arg_tuple.args
......@@ -542,23 +564,15 @@ class IterationTransform(Visitor.VisitorTransform):
bound1 = args[0].coerce_to_integer(self.current_scope)
bound2 = args[1].coerce_to_integer(self.current_scope)
relation1, relation2 = self._find_for_from_node_relations(step_value < 0, reversed)
if reversed:
bound1, bound2 = bound2, bound1
if step_value < 0:
step_value = -step_value
relation1 = '<'
relation2 = '<='
else:
relation1 = '>'
relation2 = '>='
else:
if step_value < 0:
step_value = -step_value
relation1 = '>='
relation2 = '>'
else:
relation1 = '<='
relation2 = '<'
step.value = str(step_value)
step.constant_result = step_value
......
......@@ -82,6 +82,9 @@ class MarkAssignments(CythonTransform):
'+',
sequence.args[0],
sequence.args[2]))
elif function.name == 'reversed' and len(sequence.args) == 1:
sequence = sequence.args[0]
if not is_special:
# A for-loop basically translates to subsequent calls to
# __getitem__(), so using an IndexNode here allows us to
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment