Commit cc11ca32 authored by Xavier Thompson's avatar Xavier Thompson

Make cypclass wrappers mirror single inheritance of underlying cypclass

parent b35dcc29
...@@ -102,6 +102,7 @@ class CypclassWrapperInjection(VisitorTransform): ...@@ -102,6 +102,7 @@ class CypclassWrapperInjection(VisitorTransform):
def __call__(self, root): def __call__(self, root):
self.cypclass_wrappers_stack = [] self.cypclass_wrappers_stack = []
self.nesting_stack = [] self.nesting_stack = []
self.module_scope = root.scope
return super(CypclassWrapperInjection, self).__call__(root) return super(CypclassWrapperInjection, self).__call__(root)
def visit_Node(self, node): def visit_Node(self, node):
...@@ -154,17 +155,45 @@ class CypclassWrapperInjection(VisitorTransform): ...@@ -154,17 +155,45 @@ class CypclassWrapperInjection(VisitorTransform):
if not node_has_suite: if not node_has_suite:
return None return None
if len(node.base_classes) > 1:
return None
# TODO: take nesting into account for the name # TODO: take nesting into account for the name
cclass_name = EncodedString("%s_cyp_wrapper" % node.name) cclass_name = EncodedString("%s_cyp_wrapper" % node.name)
from .ExprNodes import TupleNode from .ExprNodes import TupleNode
cclass_bases = TupleNode(node.pos, args=[]) bases_args = []
if node.base_classes:
first_base = node.base_classes[0]
if isinstance(first_base, Nodes.CSimpleBaseTypeNode) and first_base.templates is None:
first_base_name = first_base.name
builtin_entry = self.module_scope.lookup(first_base_name)
if builtin_entry is not None:
return
wrapped_first_base = Nodes.CSimpleBaseTypeNode(
first_base.pos,
name = "%s_cyp_wrapper" % first_base_name,
module_path = [],
is_basic_c_type = first_base.is_basic_c_type,
signed = first_base.signed,
complex = first_base.complex,
longness = first_base.longness,
is_self_arg = first_base.is_self_arg,
templates = None
)
bases_args.append(wrapped_first_base)
cclass_bases = TupleNode(node.pos, args=bases_args)
# the underlying cyobject must come first thing after PyObject_HEAD in the memory layout # the underlying cyobject must come first thing after PyObject_HEAD in the memory layout
# long term, only the base class will declare the underlying attribute # long term, only the base class will declare the underlying attribute
underlying_cyobject = self.synthesize_underlying_cyobject_attribute(node)
stats = [underlying_cyobject] stats = []
if not bases_args:
underlying_cyobject = self.synthesize_underlying_cyobject_attribute(node)
stats.append(underlying_cyobject)
cclass_body = Nodes.StatListNode(pos=node.pos, stats=stats) cclass_body = Nodes.StatListNode(pos=node.pos, stats=stats)
cclass_doc = EncodedString("Python Object wrapper for underlying cypclass %s" % node.name) cclass_doc = EncodedString("Python Object wrapper for underlying cypclass %s" % node.name)
wrapper = Nodes.CypclassWrapperDefNode( wrapper = Nodes.CypclassWrapperDefNode(
...@@ -182,11 +211,11 @@ class CypclassWrapperInjection(VisitorTransform): ...@@ -182,11 +211,11 @@ class CypclassWrapperInjection(VisitorTransform):
in_pxd = node.in_pxd, in_pxd = node.in_pxd,
doc = cclass_doc, doc = cclass_doc,
body = cclass_body, body = cclass_body,
wrapped_cypclass = node wrapped_cypclass = node,
) )
return wrapper return wrapper
def synthesize_underlying_cyobject_attribute(self, node): def synthesize_underlying_cyobject_attribute(self, node):
nested_names = [node.name for node in self.nesting_stack] nested_names = [node.name for node in self.nesting_stack]
...@@ -717,9 +746,10 @@ def generate_cyp_class_wrapper_definition(type, wrapper_entry, constructor_entry ...@@ -717,9 +746,10 @@ def generate_cyp_class_wrapper_definition(type, wrapper_entry, constructor_entry
# initialise PyObject fields # initialise PyObject fields
if is_new_return_type and type.wrapper_type: if is_new_return_type and type.wrapper_type:
objstruct_cname = type.wrapper_type.objstruct_cname objstruct_cname = type.wrapper_type.objstruct_cname
cclass_wrapper_base = type.wrapped_base_type.wrapper_type
code.putln("if(self) {") code.putln("if(self) {")
code.putln("%s * wrapper = new %s();" % (objstruct_cname, objstruct_cname)) code.putln("%s * wrapper = new %s();" % (objstruct_cname, objstruct_cname))
code.putln("wrapper->nogil_cyobject = self;") code.putln("((%s *)wrapper)->nogil_cyobject = self;" % cclass_wrapper_base.objstruct_cname)
code.putln("PyObject * wrapper_as_py = (PyObject *) wrapper;") code.putln("PyObject * wrapper_as_py = (PyObject *) wrapper;")
code.putln("wrapper_as_py->ob_refcnt = 0;") code.putln("wrapper_as_py->ob_refcnt = 0;")
code.putln("wrapper_as_py->ob_type = %s;" % type.wrapper_type.typeptr_cname) code.putln("wrapper_as_py->ob_type = %s;" % type.wrapper_type.typeptr_cname)
......
...@@ -5323,7 +5323,7 @@ class CClassDefNode(ClassDefNode): ...@@ -5323,7 +5323,7 @@ class CClassDefNode(ClassDefNode):
class CypclassWrapperDefNode(CClassDefNode): class CypclassWrapperDefNode(CClassDefNode):
# wrapped_cypclass CppClassNode The wrapped cypclass # wrapped_cypclass CppClassNode The wrapped cypclass
is_cyp_wrapper = 1 is_cyp_wrapper = 1
...@@ -5416,7 +5416,15 @@ class CypclassWrapperDefNode(CClassDefNode): ...@@ -5416,7 +5416,15 @@ class CypclassWrapperDefNode(CClassDefNode):
# > access the method of the underlying cyobject from the self argument of the wrapper method # > access the method of the underlying cyobject from the self argument of the wrapper method
underlying_obj = ExprNodes.AttributeNode(cfunc_method.pos, obj=self_obj, attribute=underlying_name) underlying_obj = ExprNodes.AttributeNode(cfunc_method.pos, obj=self_obj, attribute=underlying_name)
cfunc = ExprNodes.AttributeNode(cfunc_method.pos, obj=underlying_obj, attribute=cfunc_name) empty_declarator = CNameDeclaratorNode(cfunc_method.pos, name="", cname=None)
cast_underlying_obj = ExprNodes.TypecastNode(
cfunc_method.pos,
type = self.wrapped_cypclass.entry.type,
operand = underlying_obj,
typecheck = False
)
cfunc = ExprNodes.AttributeNode(cfunc_method.pos, obj=cast_underlying_obj, attribute=cfunc_name)
# > call to the underlying method # > call to the underlying method
c_call = ExprNodes.SimpleCallNode( c_call = ExprNodes.SimpleCallNode(
......
...@@ -3900,15 +3900,30 @@ class CppClassType(CType): ...@@ -3900,15 +3900,30 @@ class CppClassType(CType):
class CypClassType(CppClassType): class CypClassType(CppClassType):
# lock_mode string (tri-state: "nolock"/"checklock"/"autolock") # lock_mode string (tri-state: "nolock"/"checklock"/"autolock")
# wrapper_type PyExtensionType or None the type of the cclass wrapper # wrapper_type PyExtensionType or None the type of the cclass wrapper
# wrapped_base_type CypClassType or None the type of the oldest wrapped cypclass base
is_cyp_class = 1 is_cyp_class = 1
to_py_function = None to_py_function = None
def __init__(self, name, scope, cname, base_classes, templates=None, template_type=None, nogil=0, lock_mode=None, activable=False): def __init__(self, name, scope, cname, base_classes, templates=None, template_type=None, nogil=0, lock_mode=None, activable=False):
CppClassType.__init__(self, name, scope, cname, base_classes, templates, template_type, nogil) CppClassType.__init__(self, name, scope, cname, base_classes, templates, template_type, nogil)
if base_classes:
self.find_wrapped_base_type(base_classes)
self.lock_mode = lock_mode if lock_mode else "autolock" self.lock_mode = lock_mode if lock_mode else "autolock"
self.activable = activable self.activable = activable
self.wrapper_type = None # set during self.wrapper_type = None
self.wrapped_base_type = None
def find_wrapped_base_type(self, base_classes):
first_wrapped_cypclass_base = None
for base_type in base_classes:
if base_type.is_cyp_class and base_type.wrapper_type:
first_wrapped_cypclass_base = base_type
break
if first_wrapped_cypclass_base:
self.wrapped_base_type = first_wrapped_cypclass_base.wrapped_base_type
else:
self.wrapped_base_type = self
# allow conversion to Python only when there is a wrapper type # allow conversion to Python only when there is a wrapper type
def can_coerce_to_pyobject(self, env): def can_coerce_to_pyobject(self, env):
...@@ -3923,7 +3938,7 @@ class CypClassType(CppClassType): ...@@ -3923,7 +3938,7 @@ class CypClassType(CppClassType):
def create_from_py_utility_code(self, env): def create_from_py_utility_code(self, env):
if not self.wrapper_type: if not self.wrapper_type:
return False return False
wrapper_objstruct = self.wrapper_type.objstruct_cname wrapper_objstruct = self.wrapped_base_type.wrapper_type.objstruct_cname
underlying_type_name = self.cname underlying_type_name = self.cname
self.from_py_function = "__Pyx_PyObject_AsCyObject<%s, %s>" % (wrapper_objstruct, underlying_type_name) self.from_py_function = "__Pyx_PyObject_AsCyObject<%s, %s>" % (wrapper_objstruct, underlying_type_name)
return True return True
......
...@@ -697,6 +697,8 @@ class Scope(object): ...@@ -697,6 +697,8 @@ class Scope(object):
entry.already_declared_here() entry.already_declared_here()
else: else:
entry.type.base_classes = base_classes entry.type.base_classes = base_classes
if cypclass:
entry.type.find_wrapped_base_type(base_classes)
if templates or entry.type.templates: if templates or entry.type.templates:
if templates != entry.type.templates: if templates != entry.type.templates:
error(pos, "Template parameters do not match previous declaration") error(pos, "Template parameters do not match previous declaration")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment