Commit 3e25a388 authored by Robert Bradshaw's avatar Robert Bradshaw

Allow declaring C++ classes in Cython files.

This is necessary to use some C++ APIs, and ugly if not impossible
to work around using cname specifiers and external files.
parent ce8776b3
...@@ -2010,7 +2010,7 @@ class IteratorNode(ExprNode): ...@@ -2010,7 +2010,7 @@ class IteratorNode(ExprNode):
elif sequence_type.is_cpp_class: elif sequence_type.is_cpp_class:
begin = sequence_type.scope.lookup("begin") begin = sequence_type.scope.lookup("begin")
if begin is not None: if begin is not None:
return begin.type.base_type.return_type return begin.type.return_type
elif sequence_type.is_pyobject: elif sequence_type.is_pyobject:
return sequence_type return sequence_type
return py_object_type return py_object_type
...@@ -2022,25 +2022,23 @@ class IteratorNode(ExprNode): ...@@ -2022,25 +2022,23 @@ class IteratorNode(ExprNode):
begin = sequence_type.scope.lookup("begin") begin = sequence_type.scope.lookup("begin")
end = sequence_type.scope.lookup("end") end = sequence_type.scope.lookup("end")
if (begin is None if (begin is None
or not begin.type.is_ptr or not begin.type.is_cfunction
or not begin.type.base_type.is_cfunction or begin.type.args):
or begin.type.base_type.args):
error(self.pos, "missing begin() on %s" % self.sequence.type) error(self.pos, "missing begin() on %s" % self.sequence.type)
self.type = error_type self.type = error_type
return return
if (end is None if (end is None
or not end.type.is_ptr or not end.type.is_cfunction
or not end.type.base_type.is_cfunction or end.type.args):
or end.type.base_type.args):
error(self.pos, "missing end() on %s" % self.sequence.type) error(self.pos, "missing end() on %s" % self.sequence.type)
self.type = error_type self.type = error_type
return return
iter_type = begin.type.base_type.return_type iter_type = begin.type.return_type
if iter_type.is_cpp_class: if iter_type.is_cpp_class:
if env.lookup_operator_for_types( if env.lookup_operator_for_types(
self.pos, self.pos,
"!=", "!=",
[iter_type, end.type.base_type.return_type]) is None: [iter_type, end.type.return_type]) is None:
error(self.pos, "missing operator!= on result of begin() on %s" % self.sequence.type) error(self.pos, "missing operator!= on result of begin() on %s" % self.sequence.type)
self.type = error_type self.type = error_type
return return
...@@ -2054,7 +2052,7 @@ class IteratorNode(ExprNode): ...@@ -2054,7 +2052,7 @@ class IteratorNode(ExprNode):
return return
self.type = iter_type self.type = iter_type
elif iter_type.is_ptr: elif iter_type.is_ptr:
if not (iter_type == end.type.base_type.return_type): if not (iter_type == end.type.return_type):
error(self.pos, "incompatible types for begin() and end()") error(self.pos, "incompatible types for begin() and end()")
self.type = iter_type self.type = iter_type
else: else:
...@@ -2234,7 +2232,7 @@ class NextNode(AtomicExprNode): ...@@ -2234,7 +2232,7 @@ class NextNode(AtomicExprNode):
if iterator_type.is_ptr or iterator_type.is_array: if iterator_type.is_ptr or iterator_type.is_array:
return iterator_type.base_type return iterator_type.base_type
elif iterator_type.is_cpp_class: elif iterator_type.is_cpp_class:
item_type = env.lookup_operator_for_types(self.pos, "*", [iterator_type]).type.base_type.return_type item_type = env.lookup_operator_for_types(self.pos, "*", [iterator_type]).type.return_type
if item_type.is_reference: if item_type.is_reference:
item_type = item_type.ref_base_type item_type = item_type.ref_base_type
return item_type return item_type
...@@ -2587,7 +2585,7 @@ class IndexNode(ExprNode): ...@@ -2587,7 +2585,7 @@ class IndexNode(ExprNode):
] ]
index_func = env.lookup_operator('[]', operands) index_func = env.lookup_operator('[]', operands)
if index_func is not None: if index_func is not None:
return index_func.type.base_type.return_type return index_func.type.return_type
# may be slicing or indexing, we don't know # may be slicing or indexing, we don't know
if base_type in (unicode_type, str_type): if base_type in (unicode_type, str_type):
......
...@@ -612,7 +612,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -612,7 +612,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
type = entry.type type = entry.type
if type.is_typedef: # Must test this first! if type.is_typedef: # Must test this first!
pass pass
elif type.is_struct_or_union: elif type.is_struct_or_union or type.is_cpp_class:
self.generate_struct_union_predeclaration(entry, code) self.generate_struct_union_predeclaration(entry, code)
elif type.is_extension_type: elif type.is_extension_type:
self.generate_objstruct_predeclaration(type, code) self.generate_objstruct_predeclaration(type, code)
...@@ -627,6 +627,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -627,6 +627,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
self.generate_enum_definition(entry, code) self.generate_enum_definition(entry, code)
elif type.is_struct_or_union: elif type.is_struct_or_union:
self.generate_struct_union_definition(entry, code) self.generate_struct_union_definition(entry, code)
elif type.is_cpp_class:
self.generate_cpp_class_definition(entry, code)
elif type.is_extension_type: elif type.is_extension_type:
self.generate_objstruct_definition(type, code) self.generate_objstruct_definition(type, code)
...@@ -666,6 +668,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -666,6 +668,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
def generate_struct_union_predeclaration(self, entry, code): def generate_struct_union_predeclaration(self, entry, code):
type = entry.type type = entry.type
if type.is_cpp_class and type.templates:
code.putln("template <class %s>" % ", class ".join([T.declaration_code("") for T in type.templates]))
code.putln(self.sue_predeclaration(type, type.kind, type.cname)) code.putln(self.sue_predeclaration(type, type.kind, type.cname))
def sue_header_footer(self, type, kind, name): def sue_header_footer(self, type, kind, name):
...@@ -709,6 +713,28 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -709,6 +713,28 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln(" #pragma pack(pop)") code.putln(" #pragma pack(pop)")
code.putln("#endif") code.putln("#endif")
def generate_cpp_class_definition(self, entry, code):
code.mark_pos(entry.pos)
type = entry.type
scope = type.scope
if scope:
if type.templates:
code.putln("template <class %s>" % ", class ".join([T.declaration_code("") for T in type.templates]))
# Just let everything be public.
code.put("struct %s" % type.cname)
if type.base_classes:
base_class_decl = ", public ".join(
[base_class.declaration_code("") for base_class in type.base_classes])
code.put(" : public %s" % base_class_decl)
code.putln(" {")
for attr in scope.var_entries:
if attr.type.is_cfunction:
code.put("virtual ")
code.putln(
"%s;" %
attr.type.declaration_code(attr.cname))
code.putln("};")
def generate_enum_definition(self, entry, code): def generate_enum_definition(self, entry, code):
code.mark_pos(entry.pos) code.mark_pos(entry.pos)
type = entry.type type = entry.type
......
...@@ -1162,7 +1162,7 @@ class CStructOrUnionDefNode(StatNode): ...@@ -1162,7 +1162,7 @@ class CStructOrUnionDefNode(StatNode):
pass pass
class CppClassNode(CStructOrUnionDefNode): class CppClassNode(CStructOrUnionDefNode, BlockNode):
# name string # name string
# cname string or None # cname string or None
...@@ -1197,11 +1197,30 @@ class CppClassNode(CStructOrUnionDefNode): ...@@ -1197,11 +1197,30 @@ class CppClassNode(CStructOrUnionDefNode):
if self.entry is None: if self.entry is None:
return return
self.entry.is_cpp_class = 1 self.entry.is_cpp_class = 1
scope.class_namespace = self.entry.type.declaration_code("")
defined_funcs = []
if self.attributes is not None: if self.attributes is not None:
if self.in_pxd and not env.in_cinclude: if self.in_pxd and not env.in_cinclude:
self.entry.defined_in_pxd = 1 self.entry.defined_in_pxd = 1
for attr in self.attributes: for attr in self.attributes:
attr.analyse_declarations(scope) attr.analyse_declarations(scope)
if isinstance(attr, CFuncDefNode):
defined_funcs.append(attr)
self.body = StatListNode(self.pos, stats=defined_funcs)
self.scope = scope
def analyse_expressions(self, env):
self.body.analyse_expressions(self.entry.type.scope)
def generate_function_definitions(self, env, code):
self.body.generate_function_definitions(self.entry.type.scope, code)
def generate_execution_code(self, code):
self.body.generate_execution_code(code)
def annotate(self, code):
self.body.annotate(code)
class CEnumDefNode(StatNode): class CEnumDefNode(StatNode):
# name string or None # name string or None
...@@ -2122,7 +2141,7 @@ class CFuncDefNode(FuncDefNode): ...@@ -2122,7 +2141,7 @@ class CFuncDefNode(FuncDefNode):
if cname is None: if cname is None:
cname = self.entry.func_cname cname = self.entry.func_cname
entity = type.function_header_code(cname, ', '.join(arg_decls)) entity = type.function_header_code(cname, ', '.join(arg_decls))
if self.entry.visibility == 'private': if self.entry.visibility == 'private' and '::' not in cname:
storage_class = "static " storage_class = "static "
else: else:
storage_class = "" storage_class = ""
......
...@@ -1640,6 +1640,12 @@ if VALUE is not None: ...@@ -1640,6 +1640,12 @@ if VALUE is not None:
node.analyse_declarations(self.env_stack[-1]) node.analyse_declarations(self.env_stack[-1])
return node return node
def visit_CppClassNode(self, node):
if node.visibility == 'extern':
return None
else:
return self.visit_ClassDefNode(node)
def visit_CStructOrUnionDefNode(self, node): def visit_CStructOrUnionDefNode(self, node):
# Create a wrapper node if needed. # Create a wrapper node if needed.
# We want to use the struct type information (so it can't happen # We want to use the struct type information (so it can't happen
......
...@@ -2513,8 +2513,6 @@ def p_cdef_statement(s, ctx): ...@@ -2513,8 +2513,6 @@ def p_cdef_statement(s, ctx):
error(pos, "Extension types cannot be declared cpdef") error(pos, "Extension types cannot be declared cpdef")
return p_c_class_definition(s, pos, ctx) return p_c_class_definition(s, pos, ctx)
elif s.sy == 'IDENT' and s.systring == 'cppclass': elif s.sy == 'IDENT' and s.systring == 'cppclass':
if ctx.visibility != 'extern':
error(pos, "C++ classes need to be declared extern")
return p_cpp_class_definition(s, pos, ctx) return p_cpp_class_definition(s, pos, ctx)
elif s.sy == 'IDENT' and s.systring in struct_enum_union: elif s.sy == 'IDENT' and s.systring in struct_enum_union:
if ctx.level not in ('module', 'module_pxd'): if ctx.level not in ('module', 'module_pxd'):
...@@ -2706,7 +2704,7 @@ def p_c_func_or_var_declaration(s, pos, ctx): ...@@ -2706,7 +2704,7 @@ def p_c_func_or_var_declaration(s, pos, ctx):
assignable = 1, nonempty = 1) assignable = 1, nonempty = 1)
declarator.overridable = ctx.overridable declarator.overridable = ctx.overridable
if s.sy == ':': if s.sy == ':':
if ctx.level not in ('module', 'c_class', 'module_pxd', 'c_class_pxd') and not ctx.templates: if ctx.level not in ('module', 'c_class', 'module_pxd', 'c_class_pxd', 'cpp_class') and not ctx.templates:
s.error("C function definition not allowed here") s.error("C function definition not allowed here")
doc, suite = p_suite(s, Ctx(level = 'function'), with_doc = 1) doc, suite = p_suite(s, Ctx(level = 'function'), with_doc = 1)
result = Nodes.CFuncDefNode(pos, result = Nodes.CFuncDefNode(pos,
...@@ -3055,7 +3053,7 @@ def p_cpp_class_definition(s, pos, ctx): ...@@ -3055,7 +3053,7 @@ def p_cpp_class_definition(s, pos, ctx):
s.expect('NEWLINE') s.expect('NEWLINE')
s.expect_indent() s.expect_indent()
attributes = [] attributes = []
body_ctx = Ctx(visibility = ctx.visibility) body_ctx = Ctx(visibility = ctx.visibility, level='cpp_class')
body_ctx.templates = templates body_ctx.templates = templates
while s.sy != 'DEDENT': while s.sy != 'DEDENT':
if s.systring == 'cppclass': if s.systring == 'cppclass':
......
...@@ -3005,6 +3005,11 @@ class CppClassType(CType): ...@@ -3005,6 +3005,11 @@ class CppClassType(CType):
exception_check = True exception_check = True
namespace = None namespace = None
# For struct-like declaration.
kind = "struct"
packed = False
typedef_flag = False
subtypes = ['templates'] subtypes = ['templates']
def __init__(self, name, scope, cname, base_classes, templates = None, template_type = None): def __init__(self, name, scope, cname, base_classes, templates = None, template_type = None):
......
...@@ -486,10 +486,11 @@ class Scope(object): ...@@ -486,10 +486,11 @@ class Scope(object):
def declare_cpp_class(self, name, scope, def declare_cpp_class(self, name, scope,
pos, cname = None, base_classes = (), pos, cname = None, base_classes = (),
visibility = 'extern', templates = None): visibility = 'extern', templates = None):
if visibility != 'extern':
error(pos, "C++ classes may only be extern")
if cname is None: if cname is None:
if self.in_cinclude or (visibility != 'private'):
cname = name cname = name
else:
cname = self.mangle(Naming.type_prefix, name)
base_classes = list(base_classes) base_classes = list(base_classes)
entry = self.lookup_here(name) entry = self.lookup_here(name)
if not entry: if not entry:
...@@ -497,6 +498,7 @@ class Scope(object): ...@@ -497,6 +498,7 @@ class Scope(object):
name, scope, cname, base_classes, templates = templates) name, scope, cname, base_classes, templates = templates)
entry = self.declare_type(name, type, pos, cname, entry = self.declare_type(name, type, pos, cname,
visibility = visibility, defining = scope is not None) visibility = visibility, defining = scope is not None)
self.sue_entries.append(entry)
else: else:
if not (entry.is_type and entry.type.is_cpp_class): if not (entry.is_type and entry.type.is_cpp_class):
error(pos, "'%s' redeclared " % name) error(pos, "'%s' redeclared " % name)
...@@ -525,6 +527,7 @@ class Scope(object): ...@@ -525,6 +527,7 @@ class Scope(object):
entry.type.scope.declare_inherited_cpp_attributes(base_class.scope) entry.type.scope.declare_inherited_cpp_attributes(base_class.scope)
if entry.type.scope: if entry.type.scope:
declare_inherited_attributes(entry, base_classes) declare_inherited_attributes(entry, base_classes)
entry.type.scope.declare_var(name="this", cname="this", type=PyrexTypes.CPtrType(entry.type), pos=entry.pos)
if self.is_cpp_class_scope: if self.is_cpp_class_scope:
entry.type.namespace = self.outer_scope.lookup(self.name).type entry.type.namespace = self.outer_scope.lookup(self.name).type
return entry return entry
...@@ -1992,6 +1995,7 @@ class CppClassScope(Scope): ...@@ -1992,6 +1995,7 @@ class CppClassScope(Scope):
is_cpp_class_scope = 1 is_cpp_class_scope = 1
default_constructor = None default_constructor = None
class_namespace = None
def __init__(self, name, outer_scope, templates=None): def __init__(self, name, outer_scope, templates=None):
Scope.__init__(self, name, outer_scope, None) Scope.__init__(self, name, outer_scope, None)
...@@ -2010,10 +2014,10 @@ class CppClassScope(Scope): ...@@ -2010,10 +2014,10 @@ class CppClassScope(Scope):
# Add an entry for an attribute. # Add an entry for an attribute.
if not cname: if not cname:
cname = name cname = name
if type.is_cfunction:
type = PyrexTypes.CPtrType(type)
entry = self.declare(name, cname, type, pos, visibility) entry = self.declare(name, cname, type, pos, visibility)
entry.is_variable = 1 entry.is_variable = 1
if type.is_cfunction and self.class_namespace:
entry.func_cname = "%s::%s" % (self.class_namespace, cname)
self.var_entries.append(entry) self.var_entries.append(entry)
if type.is_pyobject and not allow_pyobject: if type.is_pyobject and not allow_pyobject:
error(pos, error(pos,
......
...@@ -350,6 +350,12 @@ class EnvTransform(CythonTransform): ...@@ -350,6 +350,12 @@ class EnvTransform(CythonTransform):
self.env_stack.pop() self.env_stack.pop()
return node return node
def visit_CStructOrUnionDefNode(self, node):
self.env_stack.append((node, node.scope))
self.visitchildren(node)
self.env_stack.pop()
return node
def visit_ScopedExprNode(self, node): def visit_ScopedExprNode(self, node):
if node.expr_scope: if node.expr_scope:
self.env_stack.append((node, node.expr_scope)) self.env_stack.append((node, node.expr_scope))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment