Commit a66de9b6 authored by Xavier Thompson's avatar Xavier Thompson

Cast underlying cyobject to base type and back safely instead of void *

parent e01817ce
...@@ -237,12 +237,20 @@ class CypclassWrapperInjection(CythonTransform): ...@@ -237,12 +237,20 @@ class CypclassWrapperInjection(CythonTransform):
return wrapper return wrapper
def synthesize_underlying_cyobject_attribute(self, node): def synthesize_underlying_cyobject_attribute(self, node):
base_type = node.entry.type.wrapped_base_type
void_type_node = Nodes.CSimpleBaseTypeNode( nesting_path = []
outer_scope = base_type.scope.outer_scope
while outer_scope and not outer_scope.is_module_scope:
nesting_path.append(outer_scope.name)
outer_scope = outer_scope.outer_scope
nesting_path.reverse()
base_type_node = Nodes.CSimpleBaseTypeNode(
node.pos, node.pos,
name = "void", name = base_type.name,
module_path = [], module_path = nesting_path,
is_basic_c_type = 1, is_basic_c_type = 0,
signed = 1, signed = 1,
complex = 0, complex = 0,
longness = 0, longness = 0,
...@@ -251,12 +259,11 @@ class CypclassWrapperInjection(CythonTransform): ...@@ -251,12 +259,11 @@ class CypclassWrapperInjection(CythonTransform):
) )
underlying_name_declarator = Nodes.CNameDeclaratorNode(node.pos, name=underlying_name, cname=None) underlying_name_declarator = Nodes.CNameDeclaratorNode(node.pos, name=underlying_name, cname=None)
underlying_name_declarator = Nodes.CPtrDeclaratorNode(node.pos, base=underlying_name_declarator)
underlying_cyobject = Nodes.CVarDefNode( underlying_cyobject = Nodes.CVarDefNode(
pos = node.pos, pos = node.pos,
visibility = 'private', visibility = 'private',
base_type = void_type_node, base_type = base_type_node,
declarators = [underlying_name_declarator], declarators = [underlying_name_declarator],
in_pxd = node.in_pxd, in_pxd = node.in_pxd,
doc = None, doc = None,
...@@ -765,11 +772,14 @@ def generate_cyp_class_wrapper_definition(type, wrapper_entry, constructor_entry ...@@ -765,11 +772,14 @@ def generate_cyp_class_wrapper_definition(type, wrapper_entry, constructor_entry
# initialise PyObject fields # initialise PyObject fields
if is_new_return_type and type.wrapper_type: if is_new_return_type and type.wrapper_type:
chain_casted_self = "self"
for first_base in type.first_base_iter():
chain_casted_self = "static_cast<%s *>(%s)" % (first_base.empty_declaration_code(), chain_casted_self)
objstruct_cname = type.wrapper_type.objstruct_cname objstruct_cname = type.wrapper_type.objstruct_cname
cclass_wrapper_base = type.wrapped_base_type.wrapper_type cclass_wrapper_base = type.wrapped_base_type.wrapper_type
code.putln("if(self) {") code.putln("if(self) {")
code.putln("%s * wrapper = new %s();" % (objstruct_cname, objstruct_cname)) code.putln("%s * wrapper = new %s();" % (objstruct_cname, objstruct_cname))
code.putln("((%s *)wrapper)->nogil_cyobject = self;" % cclass_wrapper_base.objstruct_cname) code.putln("((%s *)wrapper)->nogil_cyobject = %s;" % (cclass_wrapper_base.objstruct_cname, chain_casted_self))
code.putln("PyObject * wrapper_as_py = (PyObject *) wrapper;") code.putln("PyObject * wrapper_as_py = (PyObject *) wrapper;")
code.putln("wrapper_as_py->ob_refcnt = 0;") code.putln("wrapper_as_py->ob_refcnt = 0;")
code.putln("wrapper_as_py->ob_type = %s;" % type.wrapper_type.typeptr_cname) code.putln("wrapper_as_py->ob_type = %s;" % type.wrapper_type.typeptr_cname)
......
...@@ -1354,7 +1354,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1354,7 +1354,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
# internal classes (should) never need None inits, normal zeroing will do # internal classes (should) never need None inits, normal zeroing will do
py_attrs = [] py_attrs = []
# cyp_class attributes should not be treated as normal cpp_class attributes # unlike normal cpp_class attributes, cyp_class attributes are always held as pointers
cpp_class_attrs = [entry for entry in scope.var_entries cpp_class_attrs = [entry for entry in scope.var_entries
if entry.type.is_cpp_class and not entry.type.is_cyp_class] if entry.type.is_cpp_class and not entry.type.is_cyp_class]
...@@ -1534,7 +1534,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1534,7 +1534,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
_, (py_attrs, _, memoryview_slices) = scope.get_refcounted_entries() _, (py_attrs, _, memoryview_slices) = scope.get_refcounted_entries()
# cyp_class attributes should not be treated as normal cpp_class attributes # unlike normal cpp_class attributes, cyp_class attributes are always held as pointers
cpp_class_attrs = [entry for entry in scope.var_entries cpp_class_attrs = [entry for entry in scope.var_entries
if entry.type.is_cpp_class and not entry.type.is_cyp_class] if entry.type.is_cpp_class and not entry.type.is_cyp_class]
......
...@@ -5432,11 +5432,15 @@ class CypclassWrapperDefNode(CClassDefNode): ...@@ -5432,11 +5432,15 @@ class CypclassWrapperDefNode(CClassDefNode):
# > access the underlying cyobject from the self argument of the wrapper method # > access the underlying cyobject from the self argument of the wrapper method
underlying_obj = ExprNodes.AttributeNode(cfunc_method.pos, obj=self_obj, attribute=underlying_name) underlying_obj = ExprNodes.AttributeNode(cfunc_method.pos, obj=self_obj, attribute=underlying_name)
empty_declarator = CNameDeclaratorNode(cfunc_method.pos, name="", cname=None)
# > cast the underlying object back to this type
underlying_type = self.wrapped_cypclass.entry.type
cast_operation = underlying_obj
for derived_type in underlying_type.first_base_rev_iter():
cast_operation = ExprNodes.TypecastNode( cast_operation = ExprNodes.TypecastNode(
cfunc_method.pos, cfunc_method.pos,
type = self.wrapped_cypclass.entry.type, type = derived_type,
operand = underlying_obj, operand = cast_operation,
typecheck = False typecheck = False
) )
......
...@@ -3934,6 +3934,20 @@ class CypClassType(CppClassType): ...@@ -3934,6 +3934,20 @@ class CypClassType(CppClassType):
self.wrapped_base_type = base_type.wrapped_base_type self.wrapped_base_type = base_type.wrapped_base_type
break break
# iterate over the chain of first wrapped bases until the oldest wrapped base is reached
def first_base_iter(self):
type_item = self
while type_item is not self.wrapped_base_type:
type_item = type_item.first_wrapped_base
yield type_item
# iterate down the chain of first wrapped bases until this type is reached
def first_base_rev_iter(self):
if self is not self.wrapped_base_type:
for t in self.first_wrapped_base.first_base_rev_iter():
yield t
yield self
# allow conversion to Python only when there is a wrapper type # allow conversion to Python only when there is a wrapper type
def can_coerce_to_pyobject(self, env): def can_coerce_to_pyobject(self, env):
return self.wrapper_type is not None return self.wrapper_type is not None
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment