Commit 827e5188 authored by da-woods's avatar da-woods Committed by GitHub

Stop cfunc/ccall/class applying to inner funcs/classes (GH-4575)

Fixes https://github.com/cython/cython/issues/4092

Nested cdef classes are not tested because they don't officially work
at this stage (see https://github.com/cython/cython/issues/4572).
It should be tested if/when they're fully supported.
parent e190d8a8
......@@ -385,6 +385,20 @@ directive_scopes = { # defaults to available everywhere
}
# a list of directives that (when used as a decorator) are only applied to
# the object they decorate and not to its children.
immediate_decorator_directives = {
'cfunc', 'ccall', 'cclass',
# function signature directives
'inline', 'exceptval', 'returns',
# class directives
'freelist', 'no_gc', 'no_gc_clear', 'type_version_tag', 'final',
'auto_pickle', 'internal',
# testing directives
'test_fail_if_path_exists', 'test_assert_path_exists',
}
def parse_directive_value(name, value, relaxed_bool=False):
"""
Parses value as an option value for the given name and returns
......
......@@ -874,6 +874,12 @@ class InterpretCompilerDirectives(CythonTransform):
node.cython_module_names = self.cython_module_names
return node
def visit_CompilerDirectivesNode(self, node):
old_directives, self.directives = self.directives, node.directives
self.visitchildren(node)
self.directives = old_directives
return node
# The following four functions track imports and cimports that
# begin with "cython"
def is_cython_directive(self, name):
......@@ -1181,17 +1187,36 @@ class InterpretCompilerDirectives(CythonTransform):
else:
assert False
def visit_with_directives(self, node, directives):
def visit_with_directives(self, node, directives, contents_directives):
# contents_directives may be None
if not directives:
assert not contents_directives
return self.visit_Node(node)
old_directives = self.directives
new_directives = Options.copy_inherited_directives(old_directives, **directives)
if contents_directives is not None:
new_contents_directives = Options.copy_inherited_directives(
old_directives, **contents_directives)
else:
new_contents_directives = new_directives
if new_directives == old_directives:
return self.visit_Node(node)
self.directives = new_directives
if (contents_directives is not None and
new_contents_directives != new_directives):
# we need to wrap the node body in a compiler directives node
node.body = Nodes.StatListNode(
node.body.pos,
stats=[
Nodes.CompilerDirectivesNode(
node.body.pos,
directives=new_contents_directives,
body=node.body)
]
)
retbody = self.visit_Node(node)
self.directives = old_directives
......@@ -1200,13 +1225,14 @@ class InterpretCompilerDirectives(CythonTransform):
return Nodes.CompilerDirectivesNode(
pos=retbody.pos, body=retbody, directives=new_directives)
# Handle decorators
def visit_FuncDefNode(self, node):
directives = self._extract_directives(node, 'function')
return self.visit_with_directives(node, directives)
directives, contents_directives = self._extract_directives(node, 'function')
return self.visit_with_directives(node, directives, contents_directives)
def visit_CVarDefNode(self, node):
directives = self._extract_directives(node, 'function')
directives, _ = self._extract_directives(node, 'function')
for name, value in directives.items():
if name == 'locals':
node.directive_locals = value
......@@ -1215,23 +1241,28 @@ class InterpretCompilerDirectives(CythonTransform):
node.pos,
"Cdef functions can only take cython.locals(), "
"staticmethod, or final decorators, got %s." % name))
return self.visit_with_directives(node, directives)
return self.visit_with_directives(node, directives, contents_directives=None)
def visit_CClassDefNode(self, node):
directives = self._extract_directives(node, 'cclass')
return self.visit_with_directives(node, directives)
directives, contents_directives = self._extract_directives(node, 'cclass')
return self.visit_with_directives(node, directives, contents_directives)
def visit_CppClassNode(self, node):
directives = self._extract_directives(node, 'cppclass')
return self.visit_with_directives(node, directives)
directives, contents_directives = self._extract_directives(node, 'cppclass')
return self.visit_with_directives(node, directives, contents_directives)
def visit_PyClassDefNode(self, node):
directives = self._extract_directives(node, 'class')
return self.visit_with_directives(node, directives)
directives, contents_directives = self._extract_directives(node, 'class')
return self.visit_with_directives(node, directives, contents_directives)
def _extract_directives(self, node, scope_name):
"""
Returns two dicts - directives applied to this function/class
and directives applied to its contents. They aren't always the
same (since e.g. cfunc should not be applied to inner functions)
"""
if not node.decorators:
return {}
return {}, {}
# Split the decorators into two lists -- real decorators and directives
directives = []
realdecs = []
......@@ -1266,8 +1297,8 @@ class InterpretCompilerDirectives(CythonTransform):
node.decorators = realdecs[::-1] + both[::-1]
# merge or override repeated directives
optdict = {}
for directive in directives:
name, value = directive
contents_optdict = {}
for name, value in directives:
if name in optdict:
old_value = optdict[name]
# keywords and arg lists can be merged, everything
......@@ -1280,7 +1311,9 @@ class InterpretCompilerDirectives(CythonTransform):
optdict[name] = value
else:
optdict[name] = value
return optdict
if name not in Options.immediate_decorator_directives:
contents_optdict[name] = value
return optdict, contents_optdict
# Handle with-statements
def visit_WithStatNode(self, node):
......@@ -1299,7 +1332,7 @@ class InterpretCompilerDirectives(CythonTransform):
if self.check_directive_scope(node.pos, name, 'with statement'):
directive_dict[name] = value
if directive_dict:
return self.visit_with_directives(node.body, directive_dict)
return self.visit_with_directives(node.body, directive_dict, contents_directives=None)
return self.visit_Node(node)
......
......@@ -844,6 +844,8 @@ class PrintTree(TreeVisitor):
result += "(type=%s, name=\"%s\")" % (repr(node.type), node.name)
elif isinstance(node, Nodes.DefNode):
result += "(name=\"%s\")" % node.name
elif isinstance(node, Nodes.CFuncDefNode):
result += "(name=\"%s\")" % node.declared_name()
elif isinstance(node, ExprNodes.AttributeNode):
result += "(type=%s, attribute=\"%s\")" % (repr(node.type), node.attribute)
elif isinstance(node, (ExprNodes.ConstNode, ExprNodes.PyConstNode)):
......
......@@ -319,6 +319,25 @@ def test_cdef_nogil(x):
return result
@cython.cfunc
@cython.inline
def has_inner_func(x):
# the inner function must remain a Python function
# (and inline must not be applied to it)
@cython.test_fail_if_path_exists("//CFuncDefNode")
def inner():
return x
return inner
def test_has_inner_func():
"""
>>> test_has_inner_func()
1
"""
return has_inner_func(1)()
@cython.locals(counts=cython.int[10], digit=cython.int)
def count_digits_in_carray(digits):
"""
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment