Commit ffc292c6 authored by Stefan Behnel's avatar Stefan Behnel

prevent compile time constant folding of non-portable string comparisons (e.g....

prevent compile time constant folding of non-portable string comparisons (e.g. between bytes and str)
parent 3b8250f6
...@@ -44,6 +44,12 @@ try: ...@@ -44,6 +44,12 @@ try:
except ImportError: except ImportError:
basestring = str # Python 3 basestring = str # Python 3
try:
from builtins import bytes
except ImportError:
bytes = str # Python 2
class NotConstant(object): class NotConstant(object):
_obj = None _obj = None
...@@ -1384,7 +1390,9 @@ class StringNode(PyConstNode): ...@@ -1384,7 +1390,9 @@ class StringNode(PyConstNode):
unicode_value = None unicode_value = None
def calculate_constant_result(self): def calculate_constant_result(self):
self.constant_result = self.value if self.unicode_value is not None:
# only the Unicode value is portable across Py2/3
self.constant_result = self.unicode_value
def as_sliced_node(self, start, stop, step=None): def as_sliced_node(self, start, stop, step=None):
value = type(self.value)(self.value[start:stop:step]) value = type(self.value)(self.value[start:stop:step])
...@@ -9635,6 +9643,11 @@ class CmpNode(object): ...@@ -9635,6 +9643,11 @@ class CmpNode(object):
def calculate_cascaded_constant_result(self, operand1_result): def calculate_cascaded_constant_result(self, operand1_result):
func = compile_time_binary_operators[self.operator] func = compile_time_binary_operators[self.operator]
operand2_result = self.operand2.constant_result operand2_result = self.operand2.constant_result
if (isinstance(operand1_result, (bytes, unicode)) and
isinstance(operand2_result, (bytes, unicode)) and
type(operand1_result) != type(operand2_result)):
# string comparison of different types isn't portable
return
result = func(operand1_result, operand2_result) result = func(operand1_result, operand2_result)
if self.cascade: if self.cascade:
self.cascade.calculate_cascaded_constant_result(operand2_result) self.cascade.calculate_cascaded_constant_result(operand2_result)
......
cimport cython
import sys
IS_PY3 = sys.version_info[0] >= 3
bstring1 = b"abcdefg" bstring1 = b"abcdefg"
bstring2 = b"1234567" bstring2 = b"1234567"
...@@ -211,3 +216,18 @@ def bytes_cascade_untyped_end(bytes s1, bytes s2): ...@@ -211,3 +216,18 @@ def bytes_cascade_untyped_end(bytes s1, bytes s2):
False False
""" """
return s1 == s2 == b"abcdefg" == (<object>bstring1) == bstring1 return s1 == s2 == b"abcdefg" == (<object>bstring1) == bstring1
@cython.test_assert_path_exists(
'//CondExprNode',
'//CondExprNode//PrimaryCmpNode',
'//CondExprNode//PrimaryCmpNode[@operator = "=="]',
'//CondExprNode//PrimaryCmpNode[@operator = "!="]',
)
def literal_compare_bytes_str():
"""
>>> literal_compare_bytes_str()
True
"""
# we must not constant fold the subexpressions as the result is Py2/3 sensitive
return b'abc' != 'abc' if IS_PY3 else b'abc' == 'abc'
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment