From 8b6426210bae293805a082f0bb795223edd3f9e0 Mon Sep 17 00:00:00 2001
From: Robert Bradshaw <robertwb@gmail.com>
Date: Tue, 23 Aug 2016 02:08:37 -0700
Subject: [PATCH] Infer common parent of C++ classes for spanning type of
 pointers.

---
 Cython/Compiler/PyrexTypes.py    | 24 ++++++++++++++++++++++++
 tests/run/cpp_type_inference.pyx | 25 +++++++++++++++++++++++++
 2 files changed, 49 insertions(+)

diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py
index 9edfb1715..2b0a2c4b7 100644
--- a/Cython/Compiler/PyrexTypes.py
+++ b/Cython/Compiler/PyrexTypes.py
@@ -4,6 +4,7 @@
 
 from __future__ import absolute_import
 
+import collections
 import copy
 import re
 
@@ -12,6 +13,7 @@ try:
 except NameError:
     from functools import reduce
 
+from Cython.Utils import cached_function
 from .Code import UtilityCode, LazyUtilityCode, TempitaUtilityCode
 from . import StringEncoding
 from . import Naming
@@ -4219,6 +4221,10 @@ def _spanning_type(type1, type2):
             return py_object_type
         return type2
     elif type1.is_ptr and type2.is_ptr:
+        if type1.base_type.is_cpp_class and type2.base_type.is_cpp_class:
+            common_base = widest_cpp_type(type1.base_type, type2.base_type)
+            if common_base:
+                return CPtrType(common_base)
         # incompatible pointers, void* will do as a result
         return c_void_ptr_type
     else:
@@ -4236,6 +4242,24 @@ def widest_extension_type(type1, type2):
         if type1 is None or type2 is None:
             return py_object_type
 
+def widest_cpp_type(type1, type2):
+    @cached_function
+    def bases(type):
+        all = set()
+        for base in type.base_classes:
+            all.add(base)
+            all.update(bases(base))
+        return all
+    common_bases = bases(type1).intersection(bases(type2))
+    common_bases_bases = reduce(set.union, [bases(b) for b in common_bases], set())
+    candidates = [b for b in common_bases if b not in common_bases_bases]
+    if len(candidates) == 1:
+        return candidates[0]
+    else:
+        # Fall back to void* for now.
+        return None
+
+
 def simple_c_type(signed, longness, name):
     # Find type descriptor for simple type given name and modifiers.
     # Returns None if arguments don't make sense.
diff --git a/tests/run/cpp_type_inference.pyx b/tests/run/cpp_type_inference.pyx
index c6a54646a..95631ebfb 100644
--- a/tests/run/cpp_type_inference.pyx
+++ b/tests/run/cpp_type_inference.pyx
@@ -1,6 +1,17 @@
 # mode: run
 # tag: cpp, werror
 
+cdef extern from "shapes.h" namespace "shapes":
+    cdef cppclass Shape:
+        float area()
+
+    cdef cppclass Circle(Shape):
+        int radius
+        Circle(int)
+
+    cdef cppclass Square(Shape):
+        Square(int)
+
 from cython cimport typeof
 
 from cython.operator cimport dereference as d
@@ -23,3 +34,17 @@ def test_reversed_vector_iteration(L):
         incr(it)
         print('%s: %s' % (typeof(a), a))
     print(typeof(a))
+
+def test_derived_types(int size, bint round):
+    """
+    >>> test_derived_types(5, True)
+    Shape *
+    >>> test_derived_types(5, False)
+    Shape *
+    """
+    if round:
+        ptr = new Circle(size)
+    else:
+        ptr = new Square(size)
+    print typeof(ptr)
+    del ptr
-- 
2.30.9