Commit 8791871e authored by Mark Florisson's avatar Mark Florisson

Disallow yield in parallel sections

parent 4657fff3
...@@ -987,9 +987,6 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations): ...@@ -987,9 +987,6 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations):
# Keep track of whether we are the context manager of a 'with' statement # Keep track of whether we are the context manager of a 'with' statement
in_context_manager_section = False in_context_manager_section = False
# Keep track of whether we are in a parallel range section
in_prange = False
# One of 'prange' or 'with parallel'. This is used to disallow closely # One of 'prange' or 'with parallel'. This is used to disallow closely
# nested 'with parallel:' blocks # nested 'with parallel:' blocks
state = None state = None
...@@ -1082,7 +1079,7 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations): ...@@ -1082,7 +1079,7 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations):
if isinstance(newnode, Nodes.ParallelWithBlockNode): if isinstance(newnode, Nodes.ParallelWithBlockNode):
if self.state == 'parallel with': if self.state == 'parallel with':
error(node.manager.pos, error(node.manager.pos,
"Closely nested 'with parallel:' blocks are disallowed") "Closely nested parallel with blocks are disallowed")
self.state = 'parallel with' self.state = 'parallel with'
body = self.visit(node.body) body = self.visit(node.body)
...@@ -1109,12 +1106,11 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations): ...@@ -1109,12 +1106,11 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations):
self.visit(node.iterator) self.visit(node.iterator)
self.visit(node.target) self.visit(node.target)
was_in_prange = self.in_prange in_prange = isinstance(node.iterator.sequence,
self.in_prange = isinstance(node.iterator.sequence,
Nodes.ParallelRangeNode) Nodes.ParallelRangeNode)
previous_state = self.state previous_state = self.state
if self.in_prange: if in_prange:
# This will replace the entire ForInStatNode, so copy the # This will replace the entire ForInStatNode, so copy the
# attributes # attributes
parallel_range_node = node.iterator.sequence parallel_range_node = node.iterator.sequence
...@@ -1133,8 +1129,6 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations): ...@@ -1133,8 +1129,6 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations):
self.visit(node.body) self.visit(node.body)
self.state = previous_state self.state = previous_state
self.in_prange = was_in_prange
self.visit(node.else_clause) self.visit(node.else_clause)
return node return node
......
...@@ -202,36 +202,9 @@ class MarkAssignments(CythonTransform): ...@@ -202,36 +202,9 @@ class MarkAssignments(CythonTransform):
return node return node
def visit_BreakStatNode(self, node): def visit_YieldExprNode(self, node):
if self.parallel_block_stack: if self.parallel_block_stack:
parnode = self.parallel_block_stack[-1] error(node.pos, "Yield not allowed in parallel sections")
parnode.break_label_used = True
if not parnode.is_prange and parnode.parent:
parnode.parent.break_label_used = True
return node
def visit_ContinueStatNode(self, node):
if self.parallel_block_stack:
parnode = self.parallel_block_stack[-1]
parnode.continue_label_used = True
if not parnode.is_prange and parnode.parent:
parnode.parent.continue_label_used = True
return node
def visit_ReturnStatNode(self, node):
for parnode in self.parallel_block_stack:
parnode.return_label_used = True
return node
def visit_GilStatNode(self, node):
if node.state == 'gil':
for parnode in self.parallel_block_stack:
parnode.error_label_used = True
return node return node
......
...@@ -40,17 +40,17 @@ with nogil, cython.parallel.parallel: ...@@ -40,17 +40,17 @@ with nogil, cython.parallel.parallel:
pass pass
cdef int y cdef int y
# this is not valid
for i in prange(10, nogil=True): for i in prange(10, nogil=True):
i = y * 4 i = y * 4
y = i y = i
# this is valid
for i in prange(10, nogil=True): for i in prange(10, nogil=True):
y = i y = i
i = y * 4 i = y * 4
y = i y = i
with nogil, cython.parallel.parallel(): with nogil, cython.parallel.parallel():
i = y i = y
y = i y = i
...@@ -65,6 +65,17 @@ with nogil, cython.parallel.parallel("invalid"): ...@@ -65,6 +65,17 @@ with nogil, cython.parallel.parallel("invalid"):
with nogil, cython.parallel.parallel(invalid=True): with nogil, cython.parallel.parallel(invalid=True):
pass pass
def f(x):
cdef int i
with nogil, cython.parallel.parallel():
with gil:
yield x
for i in prange(10):
with gil:
yield x
_ERRORS = u""" _ERRORS = u"""
e_cython_parallel.pyx:3:8: cython.parallel.parallel is not a module e_cython_parallel.pyx:3:8: cython.parallel.parallel is not a module
e_cython_parallel.pyx:4:0: No such directive: cython.parallel.something e_cython_parallel.pyx:4:0: No such directive: cython.parallel.something
...@@ -77,11 +88,13 @@ c_cython_parallel.pyx:21:29: The parallel section may only be used without the G ...@@ -77,11 +88,13 @@ c_cython_parallel.pyx:21:29: The parallel section may only be used without the G
e_cython_parallel.pyx:27:10: target may not be a Python object as we don't have the GIL e_cython_parallel.pyx:27:10: target may not be a Python object as we don't have the GIL
e_cython_parallel.pyx:30:9: Can only iterate over an iteration variable e_cython_parallel.pyx:30:9: Can only iterate over an iteration variable
e_cython_parallel.pyx:33:10: Must be of numeric type, not int * e_cython_parallel.pyx:33:10: Must be of numeric type, not int *
e_cython_parallel.pyx:36:33: Closely nested 'with parallel:' blocks are disallowed e_cython_parallel.pyx:36:33: Closely nested parallel with blocks are disallowed
e_cython_parallel.pyx:39:12: The parallel directive must be called e_cython_parallel.pyx:39:12: The parallel directive must be called
e_cython_parallel.pyx:45:10: Expression value depends on previous loop iteration, cannot execute in parallel e_cython_parallel.pyx:45:10: Expression value depends on previous loop iteration, cannot execute in parallel
e_cython_parallel.pyx:55:9: Expression depends on an uninitialized thread-private variable e_cython_parallel.pyx:55:9: Expression depends on an uninitialized thread-private variable
e_cython_parallel.pyx:60:6: Reduction operator '*' is inconsistent with previous reduction operator '+' e_cython_parallel.pyx:60:6: Reduction operator '*' is inconsistent with previous reduction operator '+'
e_cython_parallel.pyx:62:36: cython.parallel.parallel() does not take positional arguments e_cython_parallel.pyx:62:36: cython.parallel.parallel() does not take positional arguments
e_cython_parallel.pyx:65:36: Invalid keyword argument: invalid e_cython_parallel.pyx:65:36: Invalid keyword argument: invalid
e_cython_parallel.pyx:73:12: Yield not allowed in parallel sections
e_cython_parallel.pyx:77:16: Yield not allowed in parallel sections
""" """
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment