Commit c5366434 authored by Stefan Behnel's avatar Stefan Behnel

find common type for comparisons *before* coercing operands, to prevent...

find common type for comparisons *before* coercing operands, to prevent inconsistent types and loosing type information
parent c7a06feb
...@@ -5072,7 +5072,7 @@ class CmpNode(object): ...@@ -5072,7 +5072,7 @@ class CmpNode(object):
result = result and cascade.compile_time_value(operand2, denv) result = result and cascade.compile_time_value(operand2, denv)
return result return result
def try_coerce_to_int_cmp(self, env, op, operand1, operand2): def find_common_int_type(self, env, op, operand1, operand2):
# type1 != type2 and at least one of the types is not a C int # type1 != type2 and at least one of the types is not a C int
type1 = operand1.type type1 = operand1.type
type2 = operand2.type type2 = operand2.type
...@@ -5088,22 +5088,23 @@ class CmpNode(object): ...@@ -5088,22 +5088,23 @@ class CmpNode(object):
if type1.is_int: if type1.is_int:
if type2_can_be_int: if type2_can_be_int:
operand2 = operand2.coerce_to(type1, env) return type1
elif type2.is_int: elif type2.is_int:
if type1_can_be_int: if type1_can_be_int:
operand1 = operand1.coerce_to(type2, env) return type2
elif type1_can_be_int: elif type1_can_be_int:
if type2_can_be_int: if type2_can_be_int:
operand1 = operand1.coerce_to(PyrexTypes.c_uchar_type, env) return PyrexTypes.c_uchar_type
operand2 = operand2.coerce_to(PyrexTypes.c_uchar_type, env)
return operand1, operand2 return None
def coerce_operands(self, env, op, operand1, common_type=None): def find_common_type(self, env, op, operand1, common_type=None):
operand2 = self.operand2 operand2 = self.operand2
type1 = operand1.type type1 = operand1.type
type2 = operand2.type type2 = operand2.type
new_common_type = None
if type1 == str_type and (type2.is_string or type2 in (bytes_type, unicode_type)) or \ if type1 == str_type and (type2.is_string or type2 in (bytes_type, unicode_type)) or \
type2 == str_type and (type1.is_string or type1 in (bytes_type, unicode_type)): type2 == str_type and (type1.is_string or type1 in (bytes_type, unicode_type)):
error(self.pos, "Comparisons between bytes/unicode and str are not portable to Python 3") error(self.pos, "Comparisons between bytes/unicode and str are not portable to Python 3")
...@@ -5112,32 +5113,38 @@ class CmpNode(object): ...@@ -5112,32 +5113,38 @@ class CmpNode(object):
if op not in ('==', '!='): if op not in ('==', '!='):
error(self.pos, "complex types unordered") error(self.pos, "complex types unordered")
if operand1.type.is_pyobject: if operand1.type.is_pyobject:
operand2 = operand2.coerce_to(operand2.type, env) new_common_type = operand1.type
elif operand2.type.is_pyobject: elif operand2.type.is_pyobject:
operand1 = operand1.coerce_to(operand2.type, env) new_common_type = operand2.type
else: else:
common_type = PyrexTypes.widest_numeric_type(type1, type2) new_common_type = PyrexTypes.widest_numeric_type(type1, type2)
operand1 = operand1.coerce_to(common_type, env)
operand2 = operand2.coerce_to(common_type, env)
elif common_type is None or not common_type.is_pyobject: elif common_type is None or not common_type.is_pyobject:
if not type1.is_int or not type2.is_int: if not type1.is_int or not type2.is_int:
operand1, operand2 = self.try_coerce_to_int_cmp(env, op, operand1, operand2) new_common_type = self.find_common_int_type(env, op, operand1, operand2)
if new_common_type is None:
new_common_type = PyrexTypes.spanning_type(operand1.type, operand2.type)
if operand1.type.is_pyobject or operand2.type.is_pyobject: if common_type is None:
common_type = new_common_type
else:
# we could do a lot better by splitting the comparison # we could do a lot better by splitting the comparison
# into a non-Python part and a Python part, but this is # into a non-Python part and a Python part, but this is
# safer for now # safer for now
if operand1.type == operand2.type: common_type = PyrexTypes.spanning_type(common_type, new_common_type)
common_type = operand1.type
else:
common_type = py_object_type
if self.cascade: if self.cascade:
operand2 = self.cascade.coerce_operands(env, self.operator, operand2, common_type) common_type = self.cascade.find_common_type(env, self.operator, operand2, common_type)
self.operand2 = operand2 return common_type
return operand1
def coerce_operands_to(self, dst_type, env):
operand2 = self.operand2
if operand2.type != dst_type:
self.operand2 = operand2.coerce_to(dst_type, env)
if self.cascade:
self.cascade.coerce_operands_to(dst_type, env)
def is_python_comparison(self): def is_python_comparison(self):
return (self.has_python_operands() return (self.has_python_operands()
...@@ -5292,11 +5299,14 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -5292,11 +5299,14 @@ class PrimaryCmpNode(ExprNode, CmpNode):
self.operand1.analyse_types(env) self.operand1.analyse_types(env)
self.operand2.analyse_types(env) self.operand2.analyse_types(env)
if self.cascade: if self.cascade:
self.cascade.analyse_types(env, self.operand2) self.cascade.analyse_types(env)
self.operand1 = self.coerce_operands(env, self.operator, self.operand1)
self.is_pycmp = self.is_python_comparison() common_type = self.find_common_type(env, self.operator, self.operand1)
if self.is_pycmp: self.is_pycmp = common_type.is_pyobject
self.coerce_operands_to_pyobjects(env) if self.operand1.type != common_type:
self.operand1 = self.operand1.coerce_to(common_type, env)
self.coerce_operands_to(common_type, env)
if self.cascade: if self.cascade:
self.operand2 = self.operand2.coerce_to_simple(env) self.operand2 = self.operand2.coerce_to_simple(env)
self.cascade.coerce_cascaded_operands_to_temp(env) self.cascade.coerce_cascaded_operands_to_temp(env)
...@@ -5407,10 +5417,10 @@ class CascadedCmpNode(Node, CmpNode): ...@@ -5407,10 +5417,10 @@ class CascadedCmpNode(Node, CmpNode):
def type_dependencies(self, env): def type_dependencies(self, env):
return () return ()
def analyse_types(self, env, operand1): def analyse_types(self, env):
self.operand2.analyse_types(env) self.operand2.analyse_types(env)
if self.cascade: if self.cascade:
self.cascade.analyse_types(env, self.operand2) self.cascade.analyse_types(env)
def check_operand_types(self, env, operand1): def check_operand_types(self, env, operand1):
self.check_types(env, self.check_types(env,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment