diff --git a/CHANGES.rst b/CHANGES.rst index 72a089874f1559c6dcdc7b2507b7fe72f6bf3079..deefd34c94bce9b8541934063d913a40858fecb7 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -28,6 +28,8 @@ Features added * Exception type tests have slightly lower overhead. This fixes ticket 868. +* C++ classes can now be declared with default template parameters. + Bugs fixed ---------- diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index e98a2419d9308df6a4748a97f9115462bb8d4ef0..5b39da43632b7fc370771c4a7f8b1944afb91f69 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -888,7 +888,8 @@ class ExprNode(Node): # Added the string comparison, since for c types that # is enough, but Cython gets confused when the types are # in different pxi files. - if not (str(src.type) == str(dst_type) or dst_type.assignable_from(src_type)): + # TODO: Remove this hack and require shared declarations. + if not (src.type == dst_type or str(src.type) == str(dst_type) or dst_type.assignable_from(src_type)): self.fail_assignment(dst_type) return src diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index 69b6ff14f987f10247646cc069d8e581cf4148d2..60b3533603372401da12f4cd1e8698a76c0b7f29 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -1403,7 +1403,7 @@ class CppClassNode(CStructOrUnionDefNode, BlockNode): # attributes [CVarDefNode] or None # entry Entry # base_classes [CBaseTypeNode] - # templates [string] or None + # templates [(string, bool)] or None # decorators [DecoratorNode] or None decorators = None @@ -1412,25 +1412,31 @@ class CppClassNode(CStructOrUnionDefNode, BlockNode): if self.templates is None: template_types = None else: - template_types = [PyrexTypes.TemplatePlaceholderType(template_name) for template_name in self.templates] + template_types = [PyrexTypes.TemplatePlaceholderType(template_name, not required) + for template_name, required in self.templates] + num_optional_templates = sum(not required for _, required in self.templates) + if num_optional_templates and not all(required for _, required in self.templates[:-num_optional_templates]): + error(self.pos, "Required template parameters must precede optional template parameters.") self.entry = env.declare_cpp_class( self.name, None, self.pos, self.cname, base_classes = [], visibility = self.visibility, templates = template_types) def analyse_declarations(self, env): + if self.templates is None: + template_types = template_names = None + else: + template_names = [template_name for template_name, _ in self.templates] + template_types = [PyrexTypes.TemplatePlaceholderType(template_name, not required) + for template_name, required in self.templates] scope = None if self.attributes is not None: - scope = CppClassScope(self.name, env, templates = self.templates) + scope = CppClassScope(self.name, env, templates = template_names) def base_ok(base_class): if base_class.is_cpp_class or base_class.is_struct: return True else: error(self.pos, "Base class '%s' not a struct or class." % base_class) base_class_types = filter(base_ok, [b.analyse(scope or env) for b in self.base_classes]) - if self.templates is None: - template_types = None - else: - template_types = [PyrexTypes.TemplatePlaceholderType(template_name) for template_name in self.templates] self.entry = env.declare_cpp_class( self.name, scope, self.pos, self.cname, base_class_types, visibility = self.visibility, templates = template_types) @@ -1455,7 +1461,7 @@ class CppClassNode(CStructOrUnionDefNode, BlockNode): for func in func_attributes(self.attributes): defined_funcs.append(func) if self.templates is not None: - func.template_declaration = "template <typename %s>" % ", typename ".join(self.templates) + func.template_declaration = "template <typename %s>" % ", typename ".join(template_names) self.body = StatListNode(self.pos, stats=defined_funcs) self.scope = scope diff --git a/Cython/Compiler/Parsing.py b/Cython/Compiler/Parsing.py index 0fd79728df12791a0d160d1fc73dde9118c8bc8d..c6a27298db1ffd3d69f10b9786d928cfe527ebbd 100644 --- a/Cython/Compiler/Parsing.py +++ b/Cython/Compiler/Parsing.py @@ -3363,6 +3363,16 @@ def p_module(s, pxd, full_module_name, ctx=Ctx): full_module_name = full_module_name, directive_comments = directive_comments) +def p_template_definition(s): + name = p_ident(s) + if s.sy == '=': + s.expect('=') + s.expect('*') + required = False + else: + required = True + return name, required + def p_cpp_class_definition(s, pos, ctx): # s.sy == 'cppclass' s.next() @@ -3375,19 +3385,21 @@ def p_cpp_class_definition(s, pos, ctx): error(pos, "Qualified class name not allowed C++ class") if s.sy == '[': s.next() - templates = [p_ident(s)] + templates = [p_template_definition(s)] while s.sy == ',': s.next() - templates.append(p_ident(s)) + templates.append(p_template_definition(s)) s.expect(']') + template_names = [name for name, required in templates] else: templates = None + template_names = None if s.sy == '(': s.next() - base_classes = [p_c_base_type(s, templates = templates)] + base_classes = [p_c_base_type(s, templates = template_names)] while s.sy == ',': s.next() - base_classes.append(p_c_base_type(s, templates = templates)) + base_classes.append(p_c_base_type(s, templates = template_names)) s.expect(')') else: base_classes = [] @@ -3400,7 +3412,7 @@ def p_cpp_class_definition(s, pos, ctx): s.expect_indent() attributes = [] body_ctx = Ctx(visibility = ctx.visibility, level='cpp_class', nogil=nogil or ctx.nogil) - body_ctx.templates = templates + body_ctx.templates = template_names while s.sy != 'DEDENT': if s.sy != 'pass': attributes.append(p_cpp_class_attribute(s, body_ctx)) diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py index 85792d1f37c0852a68f139d67754096b550cf19d..8fa41a3b314ff59294b6d9c991f799a77fe6e6e1 100644 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@ -3398,7 +3398,6 @@ builtin_cpp_conversions = ("std::pair", "std::set", "std::unordered_set", "std::map", "std::unordered_map") - class CppClassType(CType): # name string # cname string @@ -3425,6 +3424,7 @@ class CppClassType(CType): self.operators = [] self.templates = templates self.template_type = template_type + self.num_optional_templates = sum(is_optional_template_param(T) for T in templates or ()) self.specializations = {} self.is_cpp_string = cname in cpp_string_conversions @@ -3554,6 +3554,13 @@ class CppClassType(CType): if not self.is_template_type(): error(pos, "'%s' type is not a template" % self) return error_type + if len(self.templates) - self.num_optional_templates <= len(template_values) < len(self.templates): + partial_specialization = self.declaration_code('', template_params=template_values) + template_values = template_values + [ + TemplatePlaceholderType("%s %s::%s" % ( + TemplatePlaceholderType.UNDECLARABLE_DEFAULT, partial_specialization, param.name), + True) + for param in self.templates[-self.num_optional_templates:]] if len(self.templates) != len(template_values): error(pos, "%s templated type receives %d arguments, got %d" % (self.name, len(self.templates), len(template_values))) @@ -3601,10 +3608,14 @@ class CppClassType(CType): return None def declaration_code(self, entity_code, - for_display = 0, dll_linkage = None, pyrex = 0): + for_display = 0, dll_linkage = None, pyrex = 0, + template_params = None): + if template_params is None: + template_params = self.templates if self.templates: template_strings = [param.declaration_code('', for_display, None, pyrex) - for param in self.templates] + for param in template_params + if not is_optional_template_param(param)] if for_display: brackets = "[%s]" else: @@ -3673,11 +3684,17 @@ class CppClassType(CType): class TemplatePlaceholderType(CType): - def __init__(self, name): + UNDECLARABLE_DEFAULT = "undeclarable default " + + def __init__(self, name, optional=False): self.name = name + self.optional = optional def declaration_code(self, entity_code, for_display = 0, dll_linkage = None, pyrex = 0): + if self.name.startswith(self.UNDECLARABLE_DEFAULT) and not for_display: + error(None, "Can't declare variable of type '%s'" + % self.name[len(self.UNDECLARABLE_DEFAULT) + 1:]) if entity_code: return self.name + " " + entity_code else: @@ -3713,6 +3730,9 @@ class TemplatePlaceholderType(CType): else: return False +def is_optional_template_param(type): + return isinstance(type, TemplatePlaceholderType) and type.optional + class CEnumType(CType): # name string diff --git a/Cython/Compiler/TypeInference.py b/Cython/Compiler/TypeInference.py index 8743fcdc6384fee4497ecd5f44a7fbd02ddab517..bd34eea53d77f8009005587965f555f648790e8d 100644 --- a/Cython/Compiler/TypeInference.py +++ b/Cython/Compiler/TypeInference.py @@ -398,7 +398,7 @@ class SimpleAssignmentTypeInferer(object): else: entry = node.entry node_type = spanning_type( - types, entry.might_overflow, entry.pos) + types, entry.might_overflow, entry.pos, scope) node.inferred_type = node_type def infer_name_node_type_partial(node): @@ -407,7 +407,7 @@ class SimpleAssignmentTypeInferer(object): if not types: return entry = node.entry - return spanning_type(types, entry.might_overflow, entry.pos) + return spanning_type(types, entry.might_overflow, entry.pos, scope) def resolve_assignments(assignments): resolved = set() @@ -464,7 +464,7 @@ class SimpleAssignmentTypeInferer(object): types = [assmt.inferred_type for assmt in entry.cf_assignments] if types and all(types): entry_type = spanning_type( - types, entry.might_overflow, entry.pos) + types, entry.might_overflow, entry.pos, scope) inferred.add(entry) self.set_entry_type(entry, entry_type) @@ -473,7 +473,7 @@ class SimpleAssignmentTypeInferer(object): for entry in inferred: types = [assmt.infer_type() for assmt in entry.cf_assignments] - new_type = spanning_type(types, entry.might_overflow, entry.pos) + new_type = spanning_type(types, entry.might_overflow, entry.pos, scope) if new_type != entry.type: self.set_entry_type(entry, new_type) dirty = True @@ -516,10 +516,10 @@ def simply_type(result_type, pos): result_type = PyrexTypes.c_ptr_type(result_type.base_type) return result_type -def aggressive_spanning_type(types, might_overflow, pos): +def aggressive_spanning_type(types, might_overflow, pos, scope): return simply_type(reduce(find_spanning_type, types), pos) -def safe_spanning_type(types, might_overflow, pos): +def safe_spanning_type(types, might_overflow, pos, scope): result_type = simply_type(reduce(find_spanning_type, types), pos) if result_type.is_pyobject: # In theory, any specific Python type is always safe to @@ -554,6 +554,8 @@ def safe_spanning_type(types, might_overflow, pos): # to make sure everything is supported. elif (result_type.is_int or result_type.is_enum) and not might_overflow: return result_type + elif not result_type.can_coerce_to_pyobject(scope): + return result_type return py_object_type diff --git a/tests/run/cpp_templates.pyx b/tests/run/cpp_templates.pyx index 1f7eb551bb906f34726219c37081316aaf1f2c05..b68f2f5218a77751a25786c84c5024510de83e18 100644 --- a/tests/run/cpp_templates.pyx +++ b/tests/run/cpp_templates.pyx @@ -3,12 +3,15 @@ from cython.operator import dereference as deref cdef extern from "cpp_templates_helper.h": - cdef cppclass Wrap[T]: + cdef cppclass Wrap[T, S=*]: Wrap(T) void set(T) T get() bint operator==(Wrap[T]) + S get_alt_type() + void set_alt_type(S) + cdef cppclass Pair[T1,T2]: Pair(T1,T2) T1 first() @@ -57,6 +60,29 @@ def test_double(double x, double y): finally: del a, b + +def test_default_template_arguments(double x): + """ + >>> test_default_template_arguments(3.5) + (3.5, 3.0) + """ + try: + a = new Wrap[double](x) + b = new Wrap[double, int](x) + +# ax = a.get_alt_type() +# a.set_alt_type(ax) + a.set_alt_type(a.get_alt_type()) + +# bx = b.get_alt_type() +# b.set_alt_type(bx) + b.set_alt_type(b.get_alt_type()) + + return a.get(), b.get() + finally: + del a + + def test_pair(int i, double x): """ >>> test_pair(1, 1.5) diff --git a/tests/run/cpp_templates_helper.h b/tests/run/cpp_templates_helper.h index 544d2d5ade414fa12246357c0482876190ebcb6f..a685db6afcf36a53f776b25826cbc01e81e1f670 100644 --- a/tests/run/cpp_templates_helper.h +++ b/tests/run/cpp_templates_helper.h @@ -1,4 +1,4 @@ -template <class T> +template <typename T, typename S=T> class Wrap { T value; public: @@ -6,6 +6,9 @@ public: void set(T v) { value = v; } T get(void) { return value; } bool operator==(Wrap<T> other) { return value == other.value; } + + S get_alt_type(void) { return (S) value; } + void set_alt_type(S v) { value = (T) v; } }; template <class T1, class T2>