Commit 3baefece authored by Xavier Thompson's avatar Xavier Thompson

Move wrapping logic into visitors

parent e3c3ebbf
...@@ -55,6 +55,7 @@ from ..Utils import open_new_file, replace_suffix, decode_filename, build_hex_ve ...@@ -55,6 +55,7 @@ from ..Utils import open_new_file, replace_suffix, decode_filename, build_hex_ve
from .Code import UtilityCode, IncludeCode from .Code import UtilityCode, IncludeCode
from .StringEncoding import EncodedString from .StringEncoding import EncodedString
from .Pythran import has_np_pythran from .Pythran import has_np_pythran
from .Visitor import VisitorTransform, CythonTransform
...@@ -83,10 +84,231 @@ def cypclass_iter_scopes(scope): ...@@ -83,10 +84,231 @@ def cypclass_iter_scopes(scope):
yield e, s yield e, s
#
# Visitor for wrapper cclass injection
#
# - Insert additional cclass wrapper nodes by returning lists of nodes
# => must run after NormalizeTree (otherwise single statements might not be held in a list)
#
class CypclassWrapperInjection(VisitorTransform):
"""
Synthesize and insert a wrapper c class at the module level for each cypclass that supports it.
- Even nested cypclasses have their wrapper at the module level.
- Must run after NormalizeTree.
"""
def __call__(self, root):
self.cypclass_wrappers_stack = []
self.nesting_stack = []
return super(CypclassWrapperInjection, self).__call__(root)
def visit_Node(self, node):
self.visitchildren(node)
return node
# TODO: can cypclasses be nested in something other than this ?
def visit_CStructOrUnionDefNode(self, node):
self.nesting_stack.append(node)
self.visitchildren(node)
self.nesting_stack.pop()
top_level = not self.nesting_stack
if top_level:
return_nodes = [node]
for wrapper in self.cypclass_wrappers_stack:
return_nodes.append(wrapper)
self.cypclass_wrappers_stack.clear()
return return_nodes
return node
def visit_CppClassNode(self, node):
if node.cypclass:
wrapper = self.synthesise_cypclass_wrapper_cclass(node)
if wrapper is not None:
self.cypclass_wrappers_stack.append(wrapper)
# visit children and return all wrappers when at the top level
return self.visit_CStructOrUnionDefNode(node)
def find_module_scope(self, scope):
module_scope = scope
while module_scope and not module_scope.is_module_scope:
module_scope = module_scope.outer_scope
return module_scope
def iter_wrapper_methods(self, wrapper_cclass):
for node in wrapper_cclass.body.stats:
if isinstance(node, Nodes.DefNode):
yield node
def synthesise_cypclass_wrapper_cclass(self, node):
if node.templates:
# Python wrapper for templated cypclasses not supported yet
# this is signaled to the compiler by not doing what is below
return None
# whether the is declared with ':' and a suite, or just a forward declaration
node_has_suite = node.attributes is not None
if not node_has_suite:
return None
# Todo: take nesting into account
cclass_name = EncodedString("%s_cyp_wrapper" % node.name)
from .ExprNodes import TupleNode
cclass_bases = TupleNode(node.pos, args=[])
# the underlying cyobject must come first thing after PyObject_HEAD in the memory layout
# long term, only the base class will declare the underlying attribute
underlying_cyobject = self.synthesise_underlying_cyobject_attribute(node)
stats = [underlying_cyobject]
for attr in node.attributes:
if isinstance(attr, Nodes.CFuncDefNode):
py_method_wrapper = self.synthesise_cypclass_method_wrapper(attr)
if py_method_wrapper:
stats.append(py_method_wrapper)
cclass_body = Nodes.StatListNode(pos=node.pos, stats=stats)
cclass_doc = EncodedString("Python Object wrapper for underlying cypclass %s" % node.name)
wrapper = Nodes.CClassDefNode(
node.pos,
visibility = 'private',
typedef_flag = 0,
api = 0,
module_name = "",
class_name = cclass_name,
as_name = cclass_name,
bases = cclass_bases,
objstruct_name = None,
typeobj_name = None,
check_size = None,
in_pxd = node.in_pxd,
doc = cclass_doc,
body = cclass_body,
is_cyp_wrapper = 1
)
node.cyp_wrapper = wrapper
return wrapper
underlying_name = "nogil_cyobject"
def synthesise_underlying_cyobject_attribute(self, node):
nested_names = [node.name for node in self.nesting_stack]
underlying_base_type = Nodes.CSimpleBaseTypeNode(
node.pos,
name = node.name,
module_path = nested_names,
is_basic_c_type = 0,
signed = 1,
complex = 0,
longness = 0,
is_self_arg = 0,
templates = None
)
underlying_name_declarator = Nodes.CNameDeclaratorNode(node.pos, name=self.underlying_name, cname=None)
underlying_cyobject = Nodes.CVarDefNode(
pos = node.pos,
visibility = 'private',
base_type = underlying_base_type,
declarators = [underlying_name_declarator],
in_pxd = node.in_pxd,
doc = None,
api = 0,
modifiers = [],
overridable = 0
)
return underlying_cyobject
# cypclass entries that take on a special name: reverse mapping
cycplass_special_entry_names = {
"<init>": "__init__"
}
def synthesise_cypclass_method_wrapper(self, cfunc_method):
return
if cfunc_method.is_static_method:
return # for now skip static methods
cfunc_declarator = cfunc_method.cfunc_declarator
py_name = cfunc_method.entry.name
# transform e.g. <init> back into __init__
try:
py_name = self.cycplass_special_entry_names[py_name]
except KeyError:
pass
py_args = cfunc_declarator.args
py_doc = cfunc_method.doc
arg_names = [arg.name for arg in py_args]
# C++ methods have an implict 'this', so the 'self' argument is skipped in the declarator
skipped_self = cfunc_method.cfunc_declarator.skipped_self
if not skipped_self:
print("Non static cypclass method without self argument ... ??")
# should not happen
return
from . import ExprNodes
self_name, self_type, self_pos, self_arg = skipped_self
type_entry = self_type.entry
type_arg = ExprNodes.NameNode(node.pos, name=type_entry.name)
type_arg.entry = type_entry
cfunc = ExprNodes.AttributeNode(cfunc_method.pos, obj=type_arg, attribute=self.underlying_name)
c_call = ExprNodes.SimpleCallNode(
cfunc_method.pos,
function=cfunc,
args=[ExprNodes.NameNode(cfunc_method.pos, name=n) for n in arg_names]
)
py_body = ReturnStatNode(pos=cfunc_method.pos, return_type=PyrexTypes.py_object_type, value=c_call)
return Nodes.DefNode(
cfunc_method.pos,
name = py_name,
args = py_args,
star_arg = None,
starstar_arg = None,
doc = py_doc,
body = py_body,
decorators = None,
is_async_def = 0,
return_type_annotation = None
)
#
# Post declaration analysis visitor for wrapped cypclasses
#
# - Associate the type of the wrapper cclass to the wrapped type
# => must run after AnalyseDeclarationsTransform
#
class CypclassPostDeclarationsVisitor(CythonTransform):
"""
Associate the type of each wrapper cclass to the wrapped type.
- Must run after the declarations analysis phase.
"""
# associate the type of the wrapper cclass to the type of the wrapped cypclass
def visit_CppClassNode(self, node):
if node.cypclass and node.cyp_wrapper:
node.entry.type.wrapper_type = node.cyp_wrapper.entry.type
self.visitchildren(node)
return node
#
# Cypclass code generation
# #
# Cypclass generation, originally authored by Gwenaël Samain, moved here from ModuleNode.py # - originally authored by Gwenaël Samain
# - moved here from ModuleNode.py
# #
def generate_cyp_class_deferred_definitions(env, code, definition): def generate_cyp_class_deferred_definitions(env, code, definition):
...@@ -580,7 +802,7 @@ def generate_cyp_class_wrapper_definition(type, wrapper_entry, constructor_entry ...@@ -580,7 +802,7 @@ def generate_cyp_class_wrapper_definition(type, wrapper_entry, constructor_entry
# initialise PyObject fields # initialise PyObject fields
if is_new_return_type and type.wrapper_type: if is_new_return_type and type.wrapper_type:
code.putln("if(self) {") code.putln("if(self) {")
code.putln("self->ob_cypyobject = new %s();" % type.wrapper_type.module_name ) code.putln("self->ob_cypyobject = new CyPyObject(); // fow now") # % type.wrapper_type.objstruct_cname
code.putln("self->ob_cypyobject->ob_refcnt = 0;") code.putln("self->ob_cypyobject->ob_refcnt = 0;")
code.putln("self->ob_cypyobject->ob_type = %s;" % type.wrapper_type.typeptr_cname) code.putln("self->ob_cypyobject->ob_type = %s;" % type.wrapper_type.typeptr_cname)
code.putln("}") code.putln("}")
......
...@@ -1507,11 +1507,13 @@ class CppClassNode(CStructOrUnionDefNode, BlockNode): ...@@ -1507,11 +1507,13 @@ class CppClassNode(CStructOrUnionDefNode, BlockNode):
# base_classes [CBaseTypeNode] # base_classes [CBaseTypeNode]
# templates [(string, bool)] or None # templates [(string, bool)] or None
# decorators [DecoratorNode] or None # decorators [DecoratorNode] or None
# cypclass boolean
# cyp_wrapper CClassDefNode or None # cyp_wrapper CClassDefNode or None
decorators = None decorators = None
cyp_wrapper = None cyp_wrapper = None
# child_attrs = ['attributes', 'cyp_wrapper']
def declare(self, env): def declare(self, env):
if not env.is_cpp(): if not env.is_cpp():
...@@ -1595,169 +1597,13 @@ class CppClassNode(CStructOrUnionDefNode, BlockNode): ...@@ -1595,169 +1597,13 @@ class CppClassNode(CStructOrUnionDefNode, BlockNode):
func.template_declaration = "template <typename %s>" % ", typename ".join(template_names) func.template_declaration = "template <typename %s>" % ", typename ".join(template_names)
self.body = StatListNode(self.pos, stats=defined_funcs) self.body = StatListNode(self.pos, stats=defined_funcs)
self.scope = scope self.scope = scope
if self.cypclass:
self.declare_cypclass_wrapper_cclass(env)
def find_module_scope(self, scope):
module_scope = scope
while module_scope and not module_scope.is_module_scope:
module_scope = module_scope.outer_scope
return module_scope
def declare_cypclass_wrapper_cclass(self, env):
module_scope = self.find_module_scope(env)
cclass_name = EncodedString("%s_cyp_wrapper" % self.name)
from .ExprNodes import TupleNode
cclass_bases = TupleNode(self.pos, args=[])
if self.templates:
# Python wrapper for templated cypclasses not supported yet
# this is signaled to the compiler by not doing what is below
return
if self.attributes is not None:
# the underlying cyobject must come first thing after PyObject_HEAD in the memory layout
# long term, only the base class will declare the underlying attribute
underlying_cyobject = self.synthesise_underlying_cyobject_attribute(env)
stats = [underlying_cyobject]
for attr in self.attributes:
if isinstance(attr, CFuncDefNode):
py_method_wrapper = self.synthesise_cypclass_method_wrapper(attr)
if py_method_wrapper:
stats.append(py_method_wrapper)
cclass_body = StatListNode(pos=self.pos, stats=stats)
else:
cclass_body = None
wrapper = CClassDefNode(
self.pos,
visibility = 'private',
typedef_flag = 0,
api = 0,
module_name = "",
class_name = cclass_name,
as_name = cclass_name,
bases = cclass_bases,
objstruct_name = None,
typeobj_name = None,
check_size = None,
in_pxd = self.in_pxd,
doc = EncodedString("Python Object wrapper for underlying cypclass %s" % self.name),
body = cclass_body,
is_cyp_wrapper = 1
)
if module_scope:
wrapper.declare(module_scope)
if self.scope:
wrapper.analyse_declarations(module_scope)
self._cyp_wrapper_analysed = 1
self.cyp_wrapper = wrapper
self.entry.type.wrapper_type = wrapper.entry.type
wrapper.entry.type.is_cyp_wrapper = 1
underlying_name = "nogil_cyobject"
def synthesise_underlying_cyobject_attribute(self, env):
nested_path = [] if env.is_module_scope else env.qualified_name.split(".")
underlying_base_type = CSimpleBaseTypeNode(
self.pos,
name = self.name,
module_path = nested_path,
is_basic_c_type = 0,
signed = 1,
complex = 0,
longness = 0,
is_self_arg = 0,
templates = None
)
underlying_name_declarator = CNameDeclaratorNode(self.pos, name=self.underlying_name, cname=None)
underlying_cyobject = CVarDefNode(
pos = self.pos,
visibility = 'private',
base_type = underlying_base_type,
declarators = [underlying_name_declarator],
in_pxd = self.in_pxd,
doc = None,
api = 0,
modifiers = [],
overridable = 0
)
return underlying_cyobject
# cypclass entries that take on a special name: reverse mapping
cycplass_special_entry_names = {
"<init>": "__init__"
}
def synthesise_cypclass_method_wrapper(self, cfunc_method):
if cfunc_method.is_static_method:
return # for now skip static methods
cfunc_declarator = cfunc_method.cfunc_declarator
py_name = cfunc_method.entry.name
# transform e.g. <init> back into __init__
try:
py_name = self.cycplass_special_entry_names[py_name]
except KeyError:
pass
py_args = cfunc_declarator.args
py_doc = cfunc_method.doc
arg_names = [arg.name for arg in py_args]
# C++ methods have an implict 'this', so the 'self' argument is skipped in the declarator
skipped_self = cfunc_method.cfunc_declarator.skipped_self
if not skipped_self:
print("Non static cypclass method without self argument ... ??")
# should not happen
return
from . import ExprNodes
self_name, self_type, self_pos, self_arg = skipped_self
type_entry = self_type.entry
type_arg = ExprNodes.NameNode(self.pos, name=type_entry.name)
type_arg.entry = type_entry
cfunc = ExprNodes.AttributeNode(cfunc_method.pos, obj=type_arg, attribute=self.underlying_name)
c_call = ExprNodes.SimpleCallNode(
cfunc_method.pos,
function=cfunc,
args=[ExprNodes.NameNode(cfunc_method.pos, name=n) for n in arg_names]
)
py_body = ReturnStatNode(pos=cfunc_method.pos, return_type=PyrexTypes.py_object_type, value=c_call)
return DefNode(
cfunc_method.pos,
name = py_name,
args = py_args,
star_arg = None,
starstar_arg = None,
doc = py_doc,
body = py_body,
decorators = None,
is_async_def = 0,
return_type_annotation = None
)
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.body = self.body.analyse_expressions(self.entry.type.scope) self.body = self.body.analyse_expressions(self.entry.type.scope)
if self.cyp_wrapper:
self.cyp_wrapper.analyse_expressions(env.global_scope())
return self return self
def generate_function_definitions(self, env, code): def generate_function_definitions(self, env, code):
self.body.generate_function_definitions(self.entry.type.scope, code) self.body.generate_function_definitions(self.entry.type.scope, code)
if self.cyp_wrapper:
self.cyp_wrapper.generate_function_definitions(env.global_scope(), code)
def generate_execution_code(self, code): def generate_execution_code(self, code):
self.body.generate_execution_code(code) self.body.generate_execution_code(code)
...@@ -5109,7 +4955,7 @@ class CClassDefNode(ClassDefNode): ...@@ -5109,7 +4955,7 @@ class CClassDefNode(ClassDefNode):
check_size = None check_size = None
decorators = None decorators = None
shadow = False shadow = False
is_cyp_wrapper = False is_cyp_wrapper = 0
def buffer_defaults(self, env): def buffer_defaults(self, env):
if not hasattr(self, '_buffer_defaults'): if not hasattr(self, '_buffer_defaults'):
......
...@@ -140,6 +140,7 @@ def inject_utility_code_stage_factory(context): ...@@ -140,6 +140,7 @@ def inject_utility_code_stage_factory(context):
def create_pipeline(context, mode, exclude_classes=()): def create_pipeline(context, mode, exclude_classes=()):
assert mode in ('pyx', 'py', 'pxd') assert mode in ('pyx', 'py', 'pxd')
from .Visitor import PrintTree from .Visitor import PrintTree
from .CypclassWrapper import CypclassWrapperInjection, CypclassPostDeclarationsVisitor
from .ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, PxdPostParse from .ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, PxdPostParse
from .ParseTreeTransforms import ForwardDeclareTypes, InjectGilHandling, AnalyseDeclarationsTransform from .ParseTreeTransforms import ForwardDeclareTypes, InjectGilHandling, AnalyseDeclarationsTransform
from .ParseTreeTransforms import AnalyseExpressionsTransform, FindInvalidUseOfFusedTypes from .ParseTreeTransforms import AnalyseExpressionsTransform, FindInvalidUseOfFusedTypes
...@@ -180,6 +181,7 @@ def create_pipeline(context, mode, exclude_classes=()): ...@@ -180,6 +181,7 @@ def create_pipeline(context, mode, exclude_classes=()):
# compilation stage. # compilation stage.
stages = [ stages = [
NormalizeTree(context), NormalizeTree(context),
CypclassWrapperInjection(),
PostParse(context), PostParse(context),
_specific_post_parse, _specific_post_parse,
TrackNumpyAttributes(), TrackNumpyAttributes(),
...@@ -196,6 +198,7 @@ def create_pipeline(context, mode, exclude_classes=()): ...@@ -196,6 +198,7 @@ def create_pipeline(context, mode, exclude_classes=()):
ForwardDeclareTypes(context), ForwardDeclareTypes(context),
InjectGilHandling(), InjectGilHandling(),
AnalyseDeclarationsTransform(context), AnalyseDeclarationsTransform(context),
CypclassPostDeclarationsVisitor(context),
AutoTestDictTransform(context), AutoTestDictTransform(context),
EmbedSignature(context), EmbedSignature(context),
ReplacePropertyNode(context), ReplacePropertyNode(context),
......
...@@ -3910,7 +3910,7 @@ class CypClassType(CppClassType): ...@@ -3910,7 +3910,7 @@ class CypClassType(CppClassType):
self.activable = activable self.activable = activable
self.wrapper_type = None # set during self.wrapper_type = None # set during
# allow conversion to Python only when wrapping is supported # allow conversion to Python only when there is a wrapper type
def create_to_py_utility_code(self, env): def create_to_py_utility_code(self, env):
if not self.wrapper_type: if not self.wrapper_type:
return False return False
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment