Commit 7b4079c5 authored by Stefan Behnel's avatar Stefan Behnel

Check some child nodes against the correct nogil context when they are...

Check some child nodes against the correct nogil context when they are actually being evaluated in the outer scope (e.g default arguments or annotations of a nogil function).
parent 04114b9b
...@@ -207,6 +207,9 @@ class Node(object): ...@@ -207,6 +207,9 @@ class Node(object):
# can either contain a single node or a list of nodes. See Visitor.py. # can either contain a single node or a list of nodes. See Visitor.py.
child_attrs = None child_attrs = None
# Subset of attributes that are evaluated in the outer scope (e.g. function default arguments).
outer_attrs = None
cf_state = None cf_state = None
# This may be an additional (or 'actual') type that will be checked when # This may be an additional (or 'actual') type that will be checked when
...@@ -222,6 +225,7 @@ class Node(object): ...@@ -222,6 +225,7 @@ class Node(object):
gil_message = "Operation" gil_message = "Operation"
nogil_check = None nogil_check = None
in_nogil_context = False # For use only during code generation.
def gil_error(self, env=None): def gil_error(self, env=None):
error(self.pos, "%s not allowed without gil" % self.gil_message) error(self.pos, "%s not allowed without gil" % self.gil_message)
...@@ -848,6 +852,7 @@ class CArgDeclNode(Node): ...@@ -848,6 +852,7 @@ class CArgDeclNode(Node):
# is_dynamic boolean Non-literal arg stored inside CyFunction # is_dynamic boolean Non-literal arg stored inside CyFunction
child_attrs = ["base_type", "declarator", "default", "annotation"] child_attrs = ["base_type", "declarator", "default", "annotation"]
outer_attrs = ["default", "annotation"]
is_self_arg = 0 is_self_arg = 0
is_type_arg = 0 is_type_arg = 0
...@@ -1680,10 +1685,6 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1680,10 +1685,6 @@ class FuncDefNode(StatNode, BlockNode):
return None return None
if not env.directives['annotation_typing'] or annotation.analyse_as_type(env) is None: if not env.directives['annotation_typing'] or annotation.analyse_as_type(env) is None:
annotation = annotation.analyse_types(env) annotation = annotation.analyse_types(env)
elif isinstance(self, CFuncDefNode):
# Discard invisible type annotations from cdef functions after applying them,
# as they might get in the way of @nogil declarations etc.
return None
return annotation return annotation
def analyse_annotations(self, env): def analyse_annotations(self, env):
...@@ -2741,6 +2742,7 @@ class DefNode(FuncDefNode): ...@@ -2741,6 +2742,7 @@ class DefNode(FuncDefNode):
# decorator_indirection IndirectionNode Used to remove __Pyx_Method_ClassMethod for fused functions # decorator_indirection IndirectionNode Used to remove __Pyx_Method_ClassMethod for fused functions
child_attrs = ["args", "star_arg", "starstar_arg", "body", "decorators", "return_type_annotation"] child_attrs = ["args", "star_arg", "starstar_arg", "body", "decorators", "return_type_annotation"]
outer_attrs = ["decorators", "return_type_annotation"]
is_staticmethod = False is_staticmethod = False
is_classmethod = False is_classmethod = False
......
...@@ -2861,24 +2861,33 @@ class GilCheck(VisitorTransform): ...@@ -2861,24 +2861,33 @@ class GilCheck(VisitorTransform):
self.nogil_declarator_only = False self.nogil_declarator_only = False
return super(GilCheck, self).__call__(root) return super(GilCheck, self).__call__(root)
def _visit_scoped_children(self, node, gil_state):
was_nogil = self.nogil
outer_attrs = node.outer_attrs
if outer_attrs and len(self.env_stack) > 1:
self.nogil = self.env_stack[-2].nogil
self.visitchildren(node, outer_attrs)
self.nogil = gil_state
self.visitchildren(node, exclude=outer_attrs)
self.nogil = was_nogil
def visit_FuncDefNode(self, node): def visit_FuncDefNode(self, node):
self.env_stack.append(node.local_scope) self.env_stack.append(node.local_scope)
was_nogil = self.nogil inner_nogil = node.local_scope.nogil
self.nogil = node.local_scope.nogil
if self.nogil: if inner_nogil:
self.nogil_declarator_only = True self.nogil_declarator_only = True
if self.nogil and node.nogil_check: if inner_nogil and node.nogil_check:
node.nogil_check(node.local_scope) node.nogil_check(node.local_scope)
self.visitchildren(node) self._visit_scoped_children(node, inner_nogil)
# This cannot be nested, so it doesn't need backup/restore # This cannot be nested, so it doesn't need backup/restore
self.nogil_declarator_only = False self.nogil_declarator_only = False
self.env_stack.pop() self.env_stack.pop()
self.nogil = was_nogil
return node return node
def visit_GILStatNode(self, node): def visit_GILStatNode(self, node):
...@@ -2886,9 +2895,9 @@ class GilCheck(VisitorTransform): ...@@ -2886,9 +2895,9 @@ class GilCheck(VisitorTransform):
node.nogil_check() node.nogil_check()
was_nogil = self.nogil was_nogil = self.nogil
self.nogil = (node.state == 'nogil') is_nogil = (node.state == 'nogil')
if was_nogil == self.nogil and not self.nogil_declarator_only: if was_nogil == is_nogil and not self.nogil_declarator_only:
if not was_nogil: if not was_nogil:
error(node.pos, "Trying to acquire the GIL while it is " error(node.pos, "Trying to acquire the GIL while it is "
"already held.") "already held.")
...@@ -2901,8 +2910,7 @@ class GilCheck(VisitorTransform): ...@@ -2901,8 +2910,7 @@ class GilCheck(VisitorTransform):
# which is wrapped in a StatListNode. Just unpack that. # which is wrapped in a StatListNode. Just unpack that.
node.finally_clause, = node.finally_clause.stats node.finally_clause, = node.finally_clause.stats
self.visitchildren(node) self._visit_scoped_children(node, is_nogil)
self.nogil = was_nogil
return node return node
def visit_ParallelRangeNode(self, node): def visit_ParallelRangeNode(self, node):
...@@ -2949,8 +2957,12 @@ class GilCheck(VisitorTransform): ...@@ -2949,8 +2957,12 @@ class GilCheck(VisitorTransform):
def visit_Node(self, node): def visit_Node(self, node):
if self.env_stack and self.nogil and node.nogil_check: if self.env_stack and self.nogil and node.nogil_check:
node.nogil_check(self.env_stack[-1]) node.nogil_check(self.env_stack[-1])
self.visitchildren(node) if node.outer_attrs:
node.in_nogil_context = self.nogil self._visit_scoped_children(node, self.nogil)
else:
self.visitchildren(node)
if self.nogil:
node.in_nogil_context = True
return node return node
......
...@@ -16,8 +16,9 @@ cdef class TreeVisitor: ...@@ -16,8 +16,9 @@ cdef class TreeVisitor:
cdef class VisitorTransform(TreeVisitor): cdef class VisitorTransform(TreeVisitor):
cdef dict _process_children(self, parent, attrs=*) cdef dict _process_children(self, parent, attrs=*)
cpdef visitchildren(self, parent, attrs=*) cpdef visitchildren(self, parent, attrs=*, exclude=*)
cdef list _flatten_list(self, list orig_list) cdef list _flatten_list(self, list orig_list)
cdef list _select_attrs(self, attrs, exclude)
cdef class CythonTransform(VisitorTransform): cdef class CythonTransform(VisitorTransform):
cdef public context cdef public context
......
...@@ -244,10 +244,16 @@ class VisitorTransform(TreeVisitor): ...@@ -244,10 +244,16 @@ class VisitorTransform(TreeVisitor):
was not, an exception will be raised. (Typically you want to ensure that you was not, an exception will be raised. (Typically you want to ensure that you
are within a StatListNode or similar before doing this.) are within a StatListNode or similar before doing this.)
""" """
def visitchildren(self, parent, attrs=None): def visitchildren(self, parent, attrs=None, exclude=None):
# generic def entry point for calls from Python subclasses # generic def entry point for calls from Python subclasses
if exclude is not None:
attrs = self._select_attrs(parent.child_attrs if attrs is None else attrs, exclude)
return self._process_children(parent, attrs) return self._process_children(parent, attrs)
@cython.final
def _select_attrs(self, attrs, exclude):
return [name for name in attrs if name not in exclude]
@cython.final @cython.final
def _process_children(self, parent, attrs=None): def _process_children(self, parent, attrs=None):
# fast cdef entry point for calls from Cython subclasses # fast cdef entry point for calls from Cython subclasses
......
...@@ -34,3 +34,21 @@ def two_dim(a: cython.double[:,:]): ...@@ -34,3 +34,21 @@ def two_dim(a: cython.double[:,:]):
""" """
a[0,0] *= 3 a[0,0] *= 3
return a[0,0], a[0,1], a.ndim return a[0,0], a[0,1], a.ndim
@cython.nogil
@cython.cfunc
def _one_dim_nogil_cfunc(a: cython.double[:]) -> cython.double:
a[0] *= 2
return a[0]
def one_dim_nogil_cfunc(a: cython.double[:]):
"""
>>> a = numpy.ones((10,), numpy.double)
>>> one_dim_nogil_cfunc(a)
2.0
"""
with cython.nogil:
result = _one_dim_nogil_cfunc(a)
return result
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment