Commit fc0547e4 authored by Mark Florisson's avatar Mark Florisson

Allow type expressions in comparisons

parent ee914074
......@@ -141,7 +141,9 @@ class Context(object):
FlattenInListTransform(),
WithTransform(self),
DecoratorTransform(self),
# PrintTree(),
AnalyseDeclarationsTransform(self),
# PrintTree(),
AutoTestDictTransform(self),
EmbedSignature(self),
EarlyReplaceBuiltinCalls(self), ## Necessary?
......@@ -159,6 +161,7 @@ class Context(object):
DropRefcountingTransform(),
FinalOptimizePhase(self),
GilCheck(),
# PrintTree(),
]
def create_pyx_pipeline(self, options, result, py=False):
......
......@@ -2055,6 +2055,7 @@ class FusedCFuncDefNode(StatListNode):
cname = self.node.type.get_specific_cname(cname)
copied_node.entry.func_cname = copied_node.entry.cname = cname
# TransformBuiltinMethods(copied_node)
ParseTreeTransforms.ReplaceFusedTypeChecks(copied_node.local_scope)(copied_node)
......
......@@ -1912,11 +1912,13 @@ class ReplaceFusedTypeChecks(VisitorTransform):
def visit_IfClauseNode(self, node):
cond = node.condition
if isinstance(cond, ExprNodes.PrimaryCmpNode):
type1, type2 = self.get_types(cond)
op = cond.operator
type1 = cond.operand1.analyse_as_type(self.local_scope)
type2 = cond.operand2.analyse_as_type(self.local_scope)
type1 = self.specialize_type(type1, cond.operand1.pos)
if type1 and type2:
type1 = self.specialize_type(type1, cond.operand1.pos)
op = cond.operator
if op == 'is':
type2 = self.specialize_type(type2, cond.operand1.pos)
if type1.same_as(type2):
......@@ -1942,22 +1944,6 @@ class ReplaceFusedTypeChecks(VisitorTransform):
return node
def get_types(self, node):
if node.operand1.is_name and node.operand2.is_name:
return self.get_type(node.operand1), self.get_type(node.operand2)
return None, None
def get_type(self, node):
type = PyrexTypes.parse_basic_type(node.name)
if not type:
# Don't use self.lookup_type() as it will specialize
entry = self.local_scope.lookup(node.name)
if entry and entry.is_type:
type = entry.type
return type
def specialize_type(self, type, pos):
try:
return type.specialize(self.local_scope.fused_to_specific)
......
cimport cython
cimport check_fused_types_pxd
ctypedef char *string_t
ctypedef cython.fused_type(int, long, float, string_t) fused_t
ctypedef cython.fused_type(int, long) other_t
ctypedef cython.fused_type(int, float) unresolved_t
cdef func(fused_t a, other_t b):
cdef int int_a
......@@ -22,13 +22,13 @@ cdef func(fused_t a, other_t b):
print 'fused_t is string_t'
string_a = a
if fused_t in unresolved_t:
if fused_t in check_fused_types_pxd.unresolved_t:
print 'fused_t in unresolved_t'
if int in unresolved_t:
if int in check_fused_types_pxd.unresolved_t:
print 'int in unresolved_t'
if string_t in unresolved_t:
if string_t in check_fused_types_pxd.unresolved_t:
print 'string_t in unresolved_t'
......
cimport cython
ctypedef cython.fused_type(int, float) unresolved_t
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment