Commit 404c8968 authored by Robert Bradshaw's avatar Robert Bradshaw

Add support for external C++ template functions.

The syntax follows that of template classes, namely

    cdef T foo[T](T, ...)
parent 93e0ec35
...@@ -2257,7 +2257,7 @@ class IteratorNode(ExprNode): ...@@ -2257,7 +2257,7 @@ class IteratorNode(ExprNode):
elif sequence_type.is_pyobject: elif sequence_type.is_pyobject:
return sequence_type return sequence_type
return py_object_type return py_object_type
def analyse_cpp_types(self, env): def analyse_cpp_types(self, env):
sequence_type = self.sequence.type sequence_type = self.sequence.type
if sequence_type.is_ptr: if sequence_type.is_ptr:
...@@ -2721,6 +2721,7 @@ class IndexNode(ExprNode): ...@@ -2721,6 +2721,7 @@ class IndexNode(ExprNode):
# base ExprNode # base ExprNode
# index ExprNode # index ExprNode
# indices [ExprNode] # indices [ExprNode]
# type_indices [PyrexType]
# is_buffer_access boolean Whether this is a buffer access. # is_buffer_access boolean Whether this is a buffer access.
# #
# indices is used on buffer access, index on non-buffer access. # indices is used on buffer access, index on non-buffer access.
...@@ -2732,6 +2733,7 @@ class IndexNode(ExprNode): ...@@ -2732,6 +2733,7 @@ class IndexNode(ExprNode):
subexprs = ['base', 'index', 'indices'] subexprs = ['base', 'index', 'indices']
indices = None indices = None
type_indices = None
is_subscript = True is_subscript = True
is_fused_index = False is_fused_index = False
...@@ -3103,8 +3105,7 @@ class IndexNode(ExprNode): ...@@ -3103,8 +3105,7 @@ class IndexNode(ExprNode):
else: else:
base_type = self.base.type base_type = self.base.type
fused_index_operation = base_type.is_cfunction and base_type.is_fused if not base_type.is_cfunction:
if not fused_index_operation:
if isinstance(self.index, TupleNode): if isinstance(self.index, TupleNode):
self.index = self.index.analyse_types( self.index = self.index.analyse_types(
env, skip_children=skip_child_analysis) env, skip_children=skip_child_analysis)
...@@ -3188,8 +3189,17 @@ class IndexNode(ExprNode): ...@@ -3188,8 +3189,17 @@ class IndexNode(ExprNode):
self.type = func_type.return_type self.type = func_type.return_type
if setting and not func_type.return_type.is_reference: if setting and not func_type.return_type.is_reference:
error(self.pos, "Can't set non-reference result '%s'" % self.type) error(self.pos, "Can't set non-reference result '%s'" % self.type)
elif fused_index_operation: elif base_type.is_cfunction:
self.parse_indexed_fused_cdef(env) if base_type.is_fused:
self.parse_indexed_fused_cdef(env)
else:
self.type_indices = self.parse_index_as_types(env)
if base_type.templates is None:
error(self.pos, "Can only parameterize template functions.")
elif len(base_type.templates) != len(self.type_indices):
error(self.pos, "Wrong number of template arguments: expected %s, got %s" % (
(len(base_type.templates), len(self.type_indices))))
self.type = base_type.specialize(dict(zip(base_type.templates, self.type_indices)))
else: else:
error(self.pos, error(self.pos,
"Attempting to index non-array type '%s'" % "Attempting to index non-array type '%s'" %
...@@ -3215,6 +3225,20 @@ class IndexNode(ExprNode): ...@@ -3215,6 +3225,20 @@ class IndexNode(ExprNode):
self.base = self.base.as_none_safe_node(msg) self.base = self.base.as_none_safe_node(msg)
def parse_index_as_types(self, env, required=True):
if isinstance(self.index, TupleNode):
indices = self.index.args
else:
indices = [self.index]
type_indices = []
for index in indices:
type_indices.append(index.analyse_as_type(env))
if type_indices[-1] is None:
if required:
error(index.pos, "not parsable as a type")
return None
return type_indices
def parse_indexed_fused_cdef(self, env): def parse_indexed_fused_cdef(self, env):
""" """
Interpret fused_cdef_func[specific_type1, ...] Interpret fused_cdef_func[specific_type1, ...]
...@@ -3234,16 +3258,12 @@ class IndexNode(ExprNode): ...@@ -3234,16 +3258,12 @@ class IndexNode(ExprNode):
if self.index.is_name or self.index.is_attribute: if self.index.is_name or self.index.is_attribute:
positions.append(self.index.pos) positions.append(self.index.pos)
specific_types.append(self.index.analyse_as_type(env))
elif isinstance(self.index, TupleNode): elif isinstance(self.index, TupleNode):
for arg in self.index.args: for arg in self.index.args:
positions.append(arg.pos) positions.append(arg.pos)
specific_type = arg.analyse_as_type(env) specific_types = self.parse_index_as_types(env, required=False)
specific_types.append(specific_type)
else:
specific_types = [False]
if not Utils.all(specific_types): if specific_types is None:
self.index = self.index.analyse_types(env) self.index = self.index.analyse_types(env)
if not self.base.entry.as_variable: if not self.base.entry.as_variable:
...@@ -3362,6 +3382,10 @@ class IndexNode(ExprNode): ...@@ -3362,6 +3382,10 @@ class IndexNode(ExprNode):
index_code = "((unsigned char)(PyByteArray_AS_STRING(%s)[%s]))" index_code = "((unsigned char)(PyByteArray_AS_STRING(%s)[%s]))"
else: else:
assert False, "unexpected base type in indexing: %s" % self.base.type assert False, "unexpected base type in indexing: %s" % self.base.type
elif self.base.type.is_cfunction:
return "%s<%s>" % (
self.base.result(),
",".join([param.declaration_code("") for param in self.type_indices]))
else: else:
if (self.type.is_ptr or self.type.is_array) and self.type == self.base.type: if (self.type.is_ptr or self.type.is_array) and self.type == self.base.type:
error(self.pos, "Invalid use of pointer slice") error(self.pos, "Invalid use of pointer slice")
...@@ -3388,7 +3412,9 @@ class IndexNode(ExprNode): ...@@ -3388,7 +3412,9 @@ class IndexNode(ExprNode):
def generate_subexpr_evaluation_code(self, code): def generate_subexpr_evaluation_code(self, code):
self.base.generate_evaluation_code(code) self.base.generate_evaluation_code(code)
if self.indices is None: if self.type_indices is not None:
pass
elif self.indices is None:
self.index.generate_evaluation_code(code) self.index.generate_evaluation_code(code)
else: else:
for i in self.indices: for i in self.indices:
...@@ -3396,7 +3422,9 @@ class IndexNode(ExprNode): ...@@ -3396,7 +3422,9 @@ class IndexNode(ExprNode):
def generate_subexpr_disposal_code(self, code): def generate_subexpr_disposal_code(self, code):
self.base.generate_disposal_code(code) self.base.generate_disposal_code(code)
if self.indices is None: if self.type_indices is not None:
pass
elif self.indices is None:
self.index.generate_disposal_code(code) self.index.generate_disposal_code(code)
else: else:
for i in self.indices: for i in self.indices:
...@@ -3866,7 +3894,7 @@ class SliceIndexNode(ExprNode): ...@@ -3866,7 +3894,7 @@ class SliceIndexNode(ExprNode):
if (dst_type not in (bytes_type, bytearray_type) if (dst_type not in (bytes_type, bytearray_type)
and not env.directives['c_string_encoding']): and not env.directives['c_string_encoding']):
error(self.pos, error(self.pos,
"default encoding required for conversion from '%s' to '%s'" % "default encoding required for conversion from '%s' to '%s'" %
(self.base.type, dst_type)) (self.base.type, dst_type))
self.type = dst_type self.type = dst_type
return super(SliceIndexNode, self).coerce_to(dst_type, env) return super(SliceIndexNode, self).coerce_to(dst_type, env)
...@@ -3876,7 +3904,7 @@ class SliceIndexNode(ExprNode): ...@@ -3876,7 +3904,7 @@ class SliceIndexNode(ExprNode):
error(self.pos, error(self.pos,
"Slicing is not currently supported for '%s'." % self.type) "Slicing is not currently supported for '%s'." % self.type)
return return
base_result = self.base.result() base_result = self.base.result()
result = self.result() result = self.result()
start_code = self.start_code() start_code = self.start_code()
...@@ -3929,8 +3957,8 @@ class SliceIndexNode(ExprNode): ...@@ -3929,8 +3957,8 @@ class SliceIndexNode(ExprNode):
code.error_goto_if_null(result, self.pos))) code.error_goto_if_null(result, self.pos)))
elif self.base.type is unicode_type: elif self.base.type is unicode_type:
code.globalstate.use_utility_code( code.globalstate.use_utility_code(
UtilityCode.load_cached("PyUnicode_Substring", "StringTools.c")) UtilityCode.load_cached("PyUnicode_Substring", "StringTools.c"))
code.putln( code.putln(
"%s = __Pyx_PyUnicode_Substring(%s, %s, %s); %s" % ( "%s = __Pyx_PyUnicode_Substring(%s, %s, %s); %s" % (
result, result,
...@@ -10599,7 +10627,7 @@ class CoerceToPyTypeNode(CoercionNode): ...@@ -10599,7 +10627,7 @@ class CoerceToPyTypeNode(CoercionNode):
if (type not in (bytes_type, bytearray_type) if (type not in (bytes_type, bytearray_type)
and not env.directives['c_string_encoding']): and not env.directives['c_string_encoding']):
error(arg.pos, error(arg.pos,
"default encoding required for conversion from '%s' to '%s'" % "default encoding required for conversion from '%s' to '%s'" %
(arg.type, type)) (arg.type, type))
self.type = type self.type = type
else: else:
......
...@@ -19,8 +19,8 @@ import Naming ...@@ -19,8 +19,8 @@ import Naming
import PyrexTypes import PyrexTypes
import TypeSlots import TypeSlots
from PyrexTypes import py_object_type, error_type from PyrexTypes import py_object_type, error_type
from Symtab import ModuleScope, LocalScope, ClosureScope, \ from Symtab import (ModuleScope, LocalScope, ClosureScope,
StructOrUnionScope, PyClassScope, CppClassScope StructOrUnionScope, PyClassScope, CppClassScope, TemplateScope)
from Code import UtilityCode from Code import UtilityCode
from StringEncoding import EncodedString, escape_byte_string, split_string_literal from StringEncoding import EncodedString, escape_byte_string, split_string_literal
import Options import Options
...@@ -465,6 +465,9 @@ class CDeclaratorNode(Node): ...@@ -465,6 +465,9 @@ class CDeclaratorNode(Node):
calling_convention = "" calling_convention = ""
def analyse_templates(self):
# Only C++ functions have templates.
return None
class CNameDeclaratorNode(CDeclaratorNode): class CNameDeclaratorNode(CDeclaratorNode):
# name string The Cython name being declared # name string The Cython name being declared
...@@ -523,7 +526,7 @@ class CArrayDeclaratorNode(CDeclaratorNode): ...@@ -523,7 +526,7 @@ class CArrayDeclaratorNode(CDeclaratorNode):
child_attrs = ["base", "dimension"] child_attrs = ["base", "dimension"]
def analyse(self, base_type, env, nonempty = 0): def analyse(self, base_type, env, nonempty = 0):
if base_type.is_cpp_class: if base_type.is_cpp_class or base_type.is_cfunction:
from ExprNodes import TupleNode from ExprNodes import TupleNode
if isinstance(self.dimension, TupleNode): if isinstance(self.dimension, TupleNode):
args = self.dimension.args args = self.dimension.args
...@@ -565,6 +568,7 @@ class CArrayDeclaratorNode(CDeclaratorNode): ...@@ -565,6 +568,7 @@ class CArrayDeclaratorNode(CDeclaratorNode):
class CFuncDeclaratorNode(CDeclaratorNode): class CFuncDeclaratorNode(CDeclaratorNode):
# base CDeclaratorNode # base CDeclaratorNode
# args [CArgDeclNode] # args [CArgDeclNode]
# templates [TemplatePlaceholderType]
# has_varargs boolean # has_varargs boolean
# exception_value ConstNode # exception_value ConstNode
# exception_check boolean True if PyErr_Occurred check needed # exception_check boolean True if PyErr_Occurred check needed
...@@ -575,6 +579,28 @@ class CFuncDeclaratorNode(CDeclaratorNode): ...@@ -575,6 +579,28 @@ class CFuncDeclaratorNode(CDeclaratorNode):
overridable = 0 overridable = 0
optional_arg_count = 0 optional_arg_count = 0
templates = None
def analyse_templates(self):
if isinstance(self.base, CArrayDeclaratorNode):
from ExprNodes import TupleNode, NameNode
template_node = self.base.dimension
if isinstance(template_node, TupleNode):
template_nodes = template_node.args
elif isinstance(template_node, NameNode):
template_nodes = [template_node]
else:
error(template_node.pos, "Template arguments must be a list of names")
self.templates = []
for template in template_nodes:
if isinstance(template, NameNode):
self.templates.append(PyrexTypes.TemplatePlaceholderType(template.name))
else:
error(template.pos, "Template arguments must be a list of names")
self.base = self.base.base
return self.templates
else:
return None
def analyse(self, return_type, env, nonempty = 0, directive_locals = {}): def analyse(self, return_type, env, nonempty = 0, directive_locals = {}):
if nonempty: if nonempty:
...@@ -659,7 +685,8 @@ class CFuncDeclaratorNode(CDeclaratorNode): ...@@ -659,7 +685,8 @@ class CFuncDeclaratorNode(CDeclaratorNode):
optional_arg_count = self.optional_arg_count, optional_arg_count = self.optional_arg_count,
exception_value = exc_val, exception_check = exc_check, exception_value = exc_val, exception_check = exc_check,
calling_convention = self.base.calling_convention, calling_convention = self.base.calling_convention,
nogil = self.nogil, with_gil = self.with_gil, is_overridable = self.overridable) nogil = self.nogil, with_gil = self.with_gil, is_overridable = self.overridable,
templates = self.templates)
if self.optional_arg_count: if self.optional_arg_count:
if func_type.is_fused: if func_type.is_fused:
...@@ -892,7 +919,7 @@ class CSimpleBaseTypeNode(CBaseTypeNode): ...@@ -892,7 +919,7 @@ class CSimpleBaseTypeNode(CBaseTypeNode):
else: else:
scope = None scope = None
break break
if scope is None: if scope is None:
# Maybe it's a cimport. # Maybe it's a cimport.
scope = env.find_imported_module(self.module_path, self.pos) scope = env.find_imported_module(self.module_path, self.pos)
...@@ -1164,6 +1191,21 @@ class CVarDefNode(StatNode): ...@@ -1164,6 +1191,21 @@ class CVarDefNode(StatNode):
if not dest_scope: if not dest_scope:
dest_scope = env dest_scope = env
self.dest_scope = dest_scope self.dest_scope = dest_scope
if self.declarators:
templates = self.declarators[0].analyse_templates()
else:
templates = None
if templates is not None:
if self.visibility != 'extern':
error(self.pos, "Only extern functions allowed")
if len(self.declarators) > 1:
error(self.declarators[1].pos, "Can't multiply declare template types")
env = TemplateScope('func_template', env)
env.directives = env.outer_scope.directives
for template_param in templates:
env.declare_type(template_param.name, template_param, self.pos)
base_type = self.base_type.analyse(env) base_type = self.base_type.analyse(env)
if base_type.is_fused and not self.in_pxd and (env.is_c_class_scope or if base_type.is_fused and not self.in_pxd and (env.is_c_class_scope or
...@@ -1175,12 +1217,12 @@ class CVarDefNode(StatNode): ...@@ -1175,12 +1217,12 @@ class CVarDefNode(StatNode):
visibility = self.visibility visibility = self.visibility
for declarator in self.declarators: for declarator in self.declarators:
if (len(self.declarators) > 1 if (len(self.declarators) > 1
and not isinstance(declarator, CNameDeclaratorNode) and not isinstance(declarator, CNameDeclaratorNode)
and env.directives['warn.multiple_declarators']): and env.directives['warn.multiple_declarators']):
warning(declarator.pos, "Non-trivial type declarators in shared declaration.", 1) warning(declarator.pos, "Non-trivial type declarators in shared declaration.", 1)
if isinstance(declarator, CFuncDeclaratorNode): if isinstance(declarator, CFuncDeclaratorNode):
name_declarator, type = declarator.analyse(base_type, env, directive_locals=self.directive_locals) name_declarator, type = declarator.analyse(base_type, env, directive_locals=self.directive_locals)
else: else:
......
...@@ -119,7 +119,7 @@ class UseUtilityCodeDefinitions(CythonTransform): ...@@ -119,7 +119,7 @@ class UseUtilityCodeDefinitions(CythonTransform):
self.process_entry(node.entry) self.process_entry(node.entry)
self.process_entry(node.type_entry) self.process_entry(node.type_entry)
return node return node
# #
# Pipeline factories # Pipeline factories
# #
......
...@@ -2525,11 +2525,6 @@ class CFuncType(CType): ...@@ -2525,11 +2525,6 @@ class CFuncType(CType):
return '(%s)' % s return '(%s)' % s
def specialize(self, values): def specialize(self, values):
if self.templates is None:
new_templates = None
else:
new_templates = [v.specialize(values) for v in self.templates]
result = CFuncType(self.return_type.specialize(values), result = CFuncType(self.return_type.specialize(values),
[arg.specialize(values) for arg in self.args], [arg.specialize(values) for arg in self.args],
has_varargs = self.has_varargs, has_varargs = self.has_varargs,
...@@ -2540,7 +2535,7 @@ class CFuncType(CType): ...@@ -2540,7 +2535,7 @@ class CFuncType(CType):
with_gil = self.with_gil, with_gil = self.with_gil,
is_overridable = self.is_overridable, is_overridable = self.is_overridable,
optional_arg_count = self.optional_arg_count, optional_arg_count = self.optional_arg_count,
templates = new_templates) templates = self.templates)
result.from_fused = self.is_fused result.from_fused = self.is_fused
return result return result
......
...@@ -276,7 +276,7 @@ class Scope(object): ...@@ -276,7 +276,7 @@ class Scope(object):
# qualified_name string "modname" or "modname.classname" # qualified_name string "modname" or "modname.classname"
# Python strings in this scope # Python strings in this scope
# nogil boolean In a nogil section # nogil boolean In a nogil section
# directives dict Helper variable for the recursive # directives dict Helper variable for the recursive
# analysis, contains directive values. # analysis, contains directive values.
# is_internal boolean Is only used internally (simpler setup) # is_internal boolean Is only used internally (simpler setup)
...@@ -2195,7 +2195,7 @@ class CppClassScope(Scope): ...@@ -2195,7 +2195,7 @@ class CppClassScope(Scope):
entry.pos, entry.pos,
entry.cname, entry.cname,
entry.visibility) entry.visibility)
return scope return scope
...@@ -2237,3 +2237,8 @@ class CConstScope(Scope): ...@@ -2237,3 +2237,8 @@ class CConstScope(Scope):
entry = copy.copy(entry) entry = copy.copy(entry)
entry.type = PyrexTypes.c_const_type(entry.type) entry.type = PyrexTypes.c_const_type(entry.type)
return entry return entry
class TemplateScope(Scope):
def __init__(self, name, outer_scope):
Scope.__init__(self, name, outer_scope, None)
self.directives = outer_scope.directives
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment