# Patch summary (reviewer preamble). Patch tools normally skip text that
# precedes the first `diff --git` header -- confirm for the tool in use.
#
# NOTE(review): in this copy the hunk bodies have their line breaks collapsed,
# while the hunk headers (e.g. "@@ -8696,9 +8696,9 @@", "@@ -0,0 +1,193 @@")
# still count lines; the original line structure presumably has to be restored
# before the patch can be applied.
#
# Cython/Compiler/ExprNodes.py (class names taken from the hunk headers)
#   * CmpNode.find_special_bool_compare_function gains an explicit `operand1`
#     parameter and reads `operand1.type` instead of `self.operand1.type`.
#   * PrimaryCmpNode: the existing call site now passes `self.operand1`.
#     When a cascade exists, `self.cascade.optimise_comparison(env,
#     self.operand2)` is called right after `self.operand2` is coerced to a
#     simple node and before `coerce_cascaded_operands_to_temp(env)`.
#   * CascadedCmpNode.optimise_comparison(env, operand1) is new: it calls
#     find_special_bool_compare_function with the supplied left operand, then
#     recurses into `self.cascade`, handing over its own `operand2` as the
#     next left operand.
#     NOTE(review): the result of find_special_bool_compare_function is
#     discarded in optimise_comparison, whereas the PrimaryCmpNode call site
#     uses it to set `is_pycmp = False` and `is_temp = True` -- confirm that
#     cascaded nodes need no equivalent state update.
#
# tests/run/string_comparison.pyx (new file)
#   * Doctest functions for ==, != (argument vs. argument, argument vs.
#     literal) and a cascaded `s1 == s2 == <literal>` comparison, each for
#     `unicode`, `str` and `bytes` typed arguments.
#   * unicode_cascade_untyped_end is wrapped in a ''' ... ''' string, i.e. it
#     is disabled, with the in-file note "currently crashes".
diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index f89939d072e0f51ecf6c5a814d99d03568640a24..f5eceebee51832cd7935cbd70c70e54962e93a3a 100755 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -8696,9 +8696,9 @@ class CmpNode(object): return (container_type.is_ptr or container_type.is_array) \ and not container_type.is_string - def find_special_bool_compare_function(self, env): + def find_special_bool_compare_function(self, env, operand1): if self.operator in ('==', '!='): - type1, type2 = self.operand1.type, self.operand2.type + type1, type2 = operand1.type, self.operand2.type if type1.is_pyobject and type2.is_pyobject: if type1 is Builtin.unicode_type or type2 is Builtin.unicode_type: env.use_utility_code(UtilityCode.load_cached("UnicodeEquals", "StringTools.c")) @@ -8901,7 +8901,7 @@ class PrimaryCmpNode(ExprNode, CmpNode): self.operand2 = self.operand2.as_none_safe_node("'NoneType' object is not iterable") common_type = py_object_type self.is_pycmp = True - elif self.find_special_bool_compare_function(env): + elif self.find_special_bool_compare_function(env, self.operand1): common_type = None # if coercion needed, the method call above has already done it self.is_pycmp = False # result is bint self.is_temp = True # must check for error return @@ -8916,6 +8916,7 @@ class PrimaryCmpNode(ExprNode, CmpNode): if self.cascade: self.operand2 = self.operand2.coerce_to_simple(env) + self.cascade.optimise_comparison(env, self.operand2) self.cascade.coerce_cascaded_operands_to_temp(env) if self.is_python_result(): self.type = PyrexTypes.py_object_type @@ -9079,6 +9080,11 @@ class CascadedCmpNode(Node, CmpNode): def has_python_operands(self): return self.operand2.type.is_pyobject + def optimise_comparison(self, env, operand1): + self.find_special_bool_compare_function(env, operand1) + if self.cascade: + self.cascade.optimise_comparison(env, self.operand2) + def coerce_operands_to_pyobjects(self, env): self.operand2 = 
self.operand2.coerce_to_pyobject(env) if self.operand2.type is dict_type and self.operator in ('in', 'not_in'): diff --git a/tests/run/string_comparison.pyx b/tests/run/string_comparison.pyx new file mode 100644 index 0000000000000000000000000000000000000000..590e77be6452058a06e6aaf341d7bdd24cb4e3eb --- /dev/null +++ b/tests/run/string_comparison.pyx @@ -0,0 +1,193 @@ + +bstring1 = b"abcdefg" +bstring2 = b"1234567" + +string1 = "abcdefg" +string2 = "1234567" + +ustring1 = u"abcdefg" +ustring2 = u"1234567" + +# unicode + +def unicode_eq(unicode s1, unicode s2): + """ + >>> unicode_eq(ustring1, ustring1) + True + >>> unicode_eq(ustring1+ustring2, ustring1+ustring2) + True + >>> unicode_eq(ustring1, ustring2) + False + """ + return s1 == s2 + +def unicode_neq(unicode s1, unicode s2): + """ + >>> unicode_neq(ustring1, ustring1) + False + >>> unicode_neq(ustring1+ustring2, ustring1+ustring2) + False + >>> unicode_neq(ustring1, ustring2) + True + """ + return s1 != s2 + +def unicode_literal_eq(unicode s): + """ + >>> unicode_literal_eq(ustring1) + True + >>> unicode_literal_eq((ustring1+ustring2)[:len(ustring1)]) + True + >>> unicode_literal_eq(ustring2) + False + """ + return s == u"abcdefg" + +def unicode_literal_neq(unicode s): + """ + >>> unicode_literal_neq(ustring1) + False + >>> unicode_literal_neq((ustring1+ustring2)[:len(ustring1)]) + False + >>> unicode_literal_neq(ustring2) + True + """ + return s != u"abcdefg" + +def unicode_cascade(unicode s1, unicode s2): + """ + >>> unicode_cascade(ustring1, ustring1) + True + >>> unicode_cascade(ustring1, (ustring1+ustring2)[:len(ustring1)]) + True + >>> unicode_cascade(ustring1, ustring2) + False + """ + return s1 == s2 == u"abcdefg" + +''' # NOTE: currently crashes +def unicode_cascade_untyped_end(unicode s1, unicode s2): + """ + >>> unicode_cascade_untyped_end(ustring1, ustring1) + True + >>> unicode_cascade_untyped_end(ustring1, (ustring1+ustring2)[:len(ustring1)]) + True + >>> unicode_cascade_untyped_end(ustring1, 
ustring2) + False + """ + return s1 == s2 == u"abcdefg" == (<object>ustring1) == ustring1 +''' + +# str + +def str_eq(str s1, str s2): + """ + >>> str_eq(string1, string1) + True + >>> str_eq(string1+string2, string1+string2) + True + >>> str_eq(string1, string2) + False + """ + return s1 == s2 + +def str_neq(str s1, str s2): + """ + >>> str_neq(string1, string1) + False + >>> str_neq(string1+string2, string1+string2) + False + >>> str_neq(string1, string2) + True + """ + return s1 != s2 + +def str_literal_eq(str s): + """ + >>> str_literal_eq(string1) + True + >>> str_literal_eq((string1+string2)[:len(string1)]) + True + >>> str_literal_eq(string2) + False + """ + return s == "abcdefg" + +def str_literal_neq(str s): + """ + >>> str_literal_neq(string1) + False + >>> str_literal_neq((string1+string2)[:len(string1)]) + False + >>> str_literal_neq(string2) + True + """ + return s != "abcdefg" + +def str_cascade(str s1, str s2): + """ + >>> str_cascade(string1, string1) + True + >>> str_cascade(string1, (string1+string2)[:len(string1)]) + True + >>> str_cascade(string1, string2) + False + """ + return s1 == s2 == "abcdefg" + +# bytes + +def bytes_eq(bytes s1, bytes s2): + """ + >>> bytes_eq(bstring1, bstring1) + True + >>> bytes_eq(bstring1+bstring2, bstring1+bstring2) + True + >>> bytes_eq(bstring1, bstring2) + False + """ + return s1 == s2 + +def bytes_neq(bytes s1, bytes s2): + """ + >>> bytes_neq(bstring1, bstring1) + False + >>> bytes_neq(bstring1+bstring2, bstring1+bstring2) + False + >>> bytes_neq(bstring1, bstring2) + True + """ + return s1 != s2 + +def bytes_literal_eq(bytes s): + """ + >>> bytes_literal_eq(bstring1) + True + >>> bytes_literal_eq((bstring1+bstring2)[:len(bstring1)]) + True + >>> bytes_literal_eq(bstring2) + False + """ + return s == b"abcdefg" + +def bytes_literal_neq(bytes s): + """ + >>> bytes_literal_neq(bstring1) + False + >>> bytes_literal_neq((bstring1+bstring2)[:len(bstring1)]) + False + >>> bytes_literal_neq(bstring2) + True + """ + 
return s != b"abcdefg" + +def bytes_cascade(bytes s1, bytes s2): + """ + >>> bytes_cascade(bstring1, bstring1) + True + >>> bytes_cascade(bstring1, (bstring1+bstring2)[:len(bstring1)]) + True + >>> bytes_cascade(bstring1, bstring2) + False + """ + return s1 == s2 == b"abcdefg"