Commit 9cb557c3 authored by da-woods's avatar da-woods Committed by GitHub

Handle `for x in cpp_function_call()` (GH-3667)

Fixes https://github.com/cython/cython/issues/3663

This ensures that rvalues here are saved as temps, while keeping the
existing behaviour for `for x in deref(vec)`, where the pointer for vec
is copied, meaning it doesn't crash if vec is reassigned.

The bit of this change liable to have the biggest effect is that I've
changed the result type of dereference(x) and x[0] (where x is a c++
type) to a reference rather than value type. I think this is OK because
it matches what C++ does. If that isn't a sensible change then I can
probably inspect the loop sequence more closely to try to detect this.
parent c42ad917
...@@ -2658,7 +2658,6 @@ class IteratorNode(ExprNode): ...@@ -2658,7 +2658,6 @@ class IteratorNode(ExprNode):
type = py_object_type type = py_object_type
iter_func_ptr = None iter_func_ptr = None
counter_cname = None counter_cname = None
cpp_iterator_cname = None
reversed = False # currently only used for list/tuple types (see Optimize.py) reversed = False # currently only used for list/tuple types (see Optimize.py)
is_async = False is_async = False
...@@ -2671,7 +2670,7 @@ class IteratorNode(ExprNode): ...@@ -2671,7 +2670,7 @@ class IteratorNode(ExprNode):
# C array iteration will be transformed later on # C array iteration will be transformed later on
self.type = self.sequence.type self.type = self.sequence.type
elif self.sequence.type.is_cpp_class: elif self.sequence.type.is_cpp_class:
self.analyse_cpp_types(env) return CppIteratorNode(self.pos, sequence=self.sequence).analyse_types(env)
else: else:
self.sequence = self.sequence.coerce_to_pyobject(env) self.sequence = self.sequence.coerce_to_pyobject(env)
if self.sequence.type in (list_type, tuple_type): if self.sequence.type in (list_type, tuple_type):
...@@ -2701,65 +2700,10 @@ class IteratorNode(ExprNode): ...@@ -2701,65 +2700,10 @@ class IteratorNode(ExprNode):
return sequence_type return sequence_type
return py_object_type return py_object_type
def analyse_cpp_types(self, env):
sequence_type = self.sequence.type
if sequence_type.is_ptr:
sequence_type = sequence_type.base_type
begin = sequence_type.scope.lookup("begin")
end = sequence_type.scope.lookup("end")
if (begin is None
or not begin.type.is_cfunction
or begin.type.args):
error(self.pos, "missing begin() on %s" % self.sequence.type)
self.type = error_type
return
if (end is None
or not end.type.is_cfunction
or end.type.args):
error(self.pos, "missing end() on %s" % self.sequence.type)
self.type = error_type
return
iter_type = begin.type.return_type
if iter_type.is_cpp_class:
if env.lookup_operator_for_types(
self.pos,
"!=",
[iter_type, end.type.return_type]) is None:
error(self.pos, "missing operator!= on result of begin() on %s" % self.sequence.type)
self.type = error_type
return
if env.lookup_operator_for_types(self.pos, '++', [iter_type]) is None:
error(self.pos, "missing operator++ on result of begin() on %s" % self.sequence.type)
self.type = error_type
return
if env.lookup_operator_for_types(self.pos, '*', [iter_type]) is None:
error(self.pos, "missing operator* on result of begin() on %s" % self.sequence.type)
self.type = error_type
return
self.type = iter_type
elif iter_type.is_ptr:
if not (iter_type == end.type.return_type):
error(self.pos, "incompatible types for begin() and end()")
self.type = iter_type
else:
error(self.pos, "result type of begin() on %s must be a C++ class or pointer" % self.sequence.type)
self.type = error_type
return
def generate_result_code(self, code): def generate_result_code(self, code):
sequence_type = self.sequence.type sequence_type = self.sequence.type
if sequence_type.is_cpp_class: if sequence_type.is_cpp_class:
if self.sequence.is_name: assert False, "Should have been changed to CppIteratorNode"
# safe: C++ won't allow you to reassign to class references
begin_func = "%s.begin" % self.sequence.result()
else:
sequence_type = PyrexTypes.c_ptr_type(sequence_type)
self.cpp_iterator_cname = code.funcstate.allocate_temp(sequence_type, manage_ref=False)
code.putln("%s = &%s;" % (self.cpp_iterator_cname, self.sequence.result()))
begin_func = "%s->begin" % self.cpp_iterator_cname
# TODO: Limit scope.
code.putln("%s = %s();" % (self.result(), begin_func))
return
if sequence_type.is_array or sequence_type.is_ptr: if sequence_type.is_array or sequence_type.is_ptr:
raise InternalError("for in carray slice not transformed") raise InternalError("for in carray slice not transformed")
...@@ -2855,21 +2799,7 @@ class IteratorNode(ExprNode): ...@@ -2855,21 +2799,7 @@ class IteratorNode(ExprNode):
sequence_type = self.sequence.type sequence_type = self.sequence.type
if self.reversed: if self.reversed:
code.putln("if (%s < 0) break;" % self.counter_cname) code.putln("if (%s < 0) break;" % self.counter_cname)
if sequence_type.is_cpp_class: if sequence_type is list_type:
if self.cpp_iterator_cname:
end_func = "%s->end" % self.cpp_iterator_cname
else:
end_func = "%s.end" % self.sequence.result()
# TODO: Cache end() call?
code.putln("if (!(%s != %s())) break;" % (
self.result(),
end_func))
code.putln("%s = *%s;" % (
result_name,
self.result()))
code.putln("++%s;" % self.result())
return
elif sequence_type is list_type:
self.generate_next_sequence_item('List', result_name, code) self.generate_next_sequence_item('List', result_name, code)
return return
elif sequence_type is tuple_type: elif sequence_type is tuple_type:
...@@ -2908,8 +2838,109 @@ class IteratorNode(ExprNode): ...@@ -2908,8 +2838,109 @@ class IteratorNode(ExprNode):
if self.iter_func_ptr: if self.iter_func_ptr:
code.funcstate.release_temp(self.iter_func_ptr) code.funcstate.release_temp(self.iter_func_ptr)
self.iter_func_ptr = None self.iter_func_ptr = None
if self.cpp_iterator_cname: ExprNode.free_temps(self, code)
code.funcstate.release_temp(self.cpp_iterator_cname)
class CppIteratorNode(ExprNode):
# Iteration over a C++ container.
# Created at the analyse_types stage by IteratorNode
cpp_sequence_cname = None
cpp_attribute_op = "."
is_temp = True
subexprs = ['sequence']
def analyse_types(self, env):
sequence_type = self.sequence.type
if sequence_type.is_ptr:
sequence_type = sequence_type.base_type
begin = sequence_type.scope.lookup("begin")
end = sequence_type.scope.lookup("end")
if (begin is None
or not begin.type.is_cfunction
or begin.type.args):
error(self.pos, "missing begin() on %s" % self.sequence.type)
self.type = error_type
return self
if (end is None
or not end.type.is_cfunction
or end.type.args):
error(self.pos, "missing end() on %s" % self.sequence.type)
self.type = error_type
return self
iter_type = begin.type.return_type
if iter_type.is_cpp_class:
if env.lookup_operator_for_types(
self.pos,
"!=",
[iter_type, end.type.return_type]) is None:
error(self.pos, "missing operator!= on result of begin() on %s" % self.sequence.type)
self.type = error_type
return self
if env.lookup_operator_for_types(self.pos, '++', [iter_type]) is None:
error(self.pos, "missing operator++ on result of begin() on %s" % self.sequence.type)
self.type = error_type
return self
if env.lookup_operator_for_types(self.pos, '*', [iter_type]) is None:
error(self.pos, "missing operator* on result of begin() on %s" % self.sequence.type)
self.type = error_type
return self
self.type = iter_type
elif iter_type.is_ptr:
if not (iter_type == end.type.return_type):
error(self.pos, "incompatible types for begin() and end()")
self.type = iter_type
else:
error(self.pos, "result type of begin() on %s must be a C++ class or pointer" % self.sequence.type)
self.type = error_type
return self
def generate_result_code(self, code):
sequence_type = self.sequence.type
# essentially 3 options:
if self.sequence.is_name or self.sequence.is_attribute:
# 1) is a name and can be accessed directly;
# assigning to it may break the container, but that's the responsibility
# of the user
code.putln("%s = %s%sbegin();" % (self.result(),
self.sequence.result(),
self.cpp_attribute_op))
else:
# (while it'd be nice to limit the scope of the loop temp, it's essentially
# impossible to do while supporting generators)
temp_type = sequence_type
if temp_type.is_reference:
# 2) Sequence is a reference (often obtained by dereferencing a pointer);
# make the temp a pointer so we are not sensitive to users reassigning
# the pointer than it came from
temp_type = PyrexTypes.CPtrType(sequence_type.ref_base_type)
if temp_type.is_ptr:
self.cpp_attribute_op = "->"
# 3) (otherwise) sequence comes from a function call or similar, so we must
# create a temp to store it in
self.cpp_sequence_cname = code.funcstate.allocate_temp(temp_type, manage_ref=False)
code.putln("%s = %s%s;" % (self.cpp_sequence_cname,
"&" if temp_type.is_ptr else "",
self.sequence.move_result_rhs()))
code.putln("%s = %s%sbegin();" % (self.result(), self.cpp_sequence_cname,
self.cpp_attribute_op))
def generate_iter_next_result_code(self, result_name, code):
# end call isn't cached to support containers that allow adding while iterating
# (much as this is usually a bad idea)
code.putln("if (!(%s != %s%send())) break;" % (
self.result(),
self.cpp_sequence_cname or self.sequence.result(),
self.cpp_attribute_op))
code.putln("%s = *%s;" % (
result_name,
self.result()))
code.putln("++%s;" % self.result())
def free_temps(self, code):
if self.cpp_sequence_cname:
code.funcstate.release_temp(self.cpp_sequence_cname)
# skip over IteratorNode since we don't use any of the temps it does
ExprNode.free_temps(self, code) ExprNode.free_temps(self, code)
...@@ -3793,6 +3824,8 @@ class IndexNode(_IndexingBaseNode): ...@@ -3793,6 +3824,8 @@ class IndexNode(_IndexingBaseNode):
def analyse_as_c_array(self, env, is_slice): def analyse_as_c_array(self, env, is_slice):
base_type = self.base.type base_type = self.base.type
self.type = base_type.base_type self.type = base_type.base_type
if self.type.is_cpp_class:
self.type = PyrexTypes.CReferenceType(self.type)
if is_slice: if is_slice:
self.type = base_type self.type = base_type
elif self.index.type.is_pyobject: elif self.index.type.is_pyobject:
...@@ -10313,7 +10346,7 @@ class DereferenceNode(CUnopNode): ...@@ -10313,7 +10346,7 @@ class DereferenceNode(CUnopNode):
def analyse_c_operation(self, env): def analyse_c_operation(self, env):
if self.operand.type.is_ptr: if self.operand.type.is_ptr:
self.type = self.operand.type.base_type self.type = PyrexTypes.CReferenceType(self.operand.type.base_type)
else: else:
self.type_error() self.type_error()
......
...@@ -140,3 +140,40 @@ def test_iteration_in_generator_reassigned(): ...@@ -140,3 +140,40 @@ def test_iteration_in_generator_reassigned():
if vint is not orig_vint: if vint is not orig_vint:
del vint del vint
del orig_vint del orig_vint
cdef extern from *:
"""
std::vector<int> make_vec1() {
std::vector<int> vint;
vint.push_back(1);
vint.push_back(2);
return vint;
}
"""
cdef vector[int] make_vec1() except +
cdef vector[int] make_vec2() except *:
return make_vec1()
cdef vector[int] make_vec3():
try:
return make_vec1()
except:
pass
def test_iteration_from_function_call():
"""
>>> test_iteration_from_function_call()
1
2
1
2
1
2
"""
for i in make_vec1():
print(i)
for i in make_vec2():
print(i)
for i in make_vec3():
print(i)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment