From 8b6426210bae293805a082f0bb795223edd3f9e0 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw <robertwb@gmail.com> Date: Tue, 23 Aug 2016 02:08:37 -0700 Subject: [PATCH] Infer common parent of C++ classes for spanning type of pointers. --- Cython/Compiler/PyrexTypes.py | 24 ++++++++++++++++++++++++ tests/run/cpp_type_inference.pyx | 25 +++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py index 9edfb1715..2b0a2c4b7 100644 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@ -4,6 +4,7 @@ from __future__ import absolute_import +import collections import copy import re @@ -12,6 +13,7 @@ try: except NameError: from functools import reduce +from Cython.Utils import cached_function from .Code import UtilityCode, LazyUtilityCode, TempitaUtilityCode from . import StringEncoding from . import Naming @@ -4219,6 +4221,10 @@ def _spanning_type(type1, type2): return py_object_type return type2 elif type1.is_ptr and type2.is_ptr: + if type1.base_type.is_cpp_class and type2.base_type.is_cpp_class: + common_base = widest_cpp_type(type1.base_type, type2.base_type) + if common_base: + return CPtrType(common_base) # incompatible pointers, void* will do as a result return c_void_ptr_type else: @@ -4236,6 +4242,24 @@ def widest_extension_type(type1, type2): if type1 is None or type2 is None: return py_object_type +def widest_cpp_type(type1, type2): + @cached_function + def bases(type): + all = set() + for base in type.base_classes: + all.add(base) + all.update(bases(base)) + return all + common_bases = bases(type1).intersection(bases(type2)) + common_bases_bases = reduce(set.union, [bases(b) for b in common_bases], set()) + candidates = [b for b in common_bases if b not in common_bases_bases] + if len(candidates) == 1: + return candidates[0] + else: + # Fall back to void* for now. + return None + + def simple_c_type(signed, longness, name): # Find type descriptor for simple type given name and modifiers. # Returns None if arguments don't make sense. diff --git a/tests/run/cpp_type_inference.pyx b/tests/run/cpp_type_inference.pyx index c6a54646a..95631ebfb 100644 --- a/tests/run/cpp_type_inference.pyx +++ b/tests/run/cpp_type_inference.pyx @@ -1,6 +1,17 @@ # mode: run # tag: cpp, werror +cdef extern from "shapes.h" namespace "shapes": + cdef cppclass Shape: + float area() + + cdef cppclass Circle(Shape): + int radius + Circle(int) + + cdef cppclass Square(Shape): + Square(int) + from cython cimport typeof from cython.operator cimport dereference as d @@ -23,3 +34,17 @@ def test_reversed_vector_iteration(L): incr(it) print('%s: %s' % (typeof(a), a)) print(typeof(a)) + +def test_derived_types(int size, bint round): + """ + >>> test_derived_types(5, True) + Shape * + >>> test_derived_types(5, False) + Shape * + """ + if round: + ptr = new Circle(size) + else: + ptr = new Square(size) + print typeof(ptr) + del ptr -- 2.30.9