Commit ecd0ae5b authored by Stefan Behnel's avatar Stefan Behnel

merge

parents dae3f8f1 89680db6
...@@ -2013,7 +2013,8 @@ class IndexNode(ExprNode): ...@@ -2013,7 +2013,8 @@ class IndexNode(ExprNode):
# Handle the case where base is a literal char* (and we expect a string, not an int) # Handle the case where base is a literal char* (and we expect a string, not an int)
if isinstance(self.base, BytesNode) or is_slice: if isinstance(self.base, BytesNode) or is_slice:
self.base = self.base.coerce_to_pyobject(env) if self.base.type.is_string or not (self.base.type.is_ptr or self.base.type.is_array):
self.base = self.base.coerce_to_pyobject(env)
skip_child_analysis = False skip_child_analysis = False
buffer_access = False buffer_access = False
...@@ -2094,7 +2095,7 @@ class IndexNode(ExprNode): ...@@ -2094,7 +2095,7 @@ class IndexNode(ExprNode):
if self.index.type.is_pyobject: if self.index.type.is_pyobject:
self.index = self.index.coerce_to( self.index = self.index.coerce_to(
PyrexTypes.c_py_ssize_t_type, env) PyrexTypes.c_py_ssize_t_type, env)
if not self.index.type.is_int: elif not self.index.type.is_int:
error(self.pos, error(self.pos,
"Invalid index type '%s'" % "Invalid index type '%s'" %
self.index.type) self.index.type)
...@@ -5997,10 +5998,11 @@ class CmpNode(object): ...@@ -5997,10 +5998,11 @@ class CmpNode(object):
(op, operand1.type, operand2.type)) (op, operand1.type, operand2.type))
def is_python_comparison(self): def is_python_comparison(self):
return not self.is_c_string_contains() and ( return (not self.is_ptr_contains()
self.has_python_operands() and not self.is_c_string_contains()
or (self.cascade and self.cascade.is_python_comparison()) and (self.has_python_operands()
or self.operator in ('in', 'not_in')) or (self.cascade and self.cascade.is_python_comparison())
or self.operator in ('in', 'not_in')))
def coerce_operands_to(self, dst_type, env): def coerce_operands_to(self, dst_type, env):
operand2 = self.operand2 operand2 = self.operand2
...@@ -6012,7 +6014,8 @@ class CmpNode(object): ...@@ -6012,7 +6014,8 @@ class CmpNode(object):
def is_python_result(self): def is_python_result(self):
return ((self.has_python_operands() and return ((self.has_python_operands() and
self.operator not in ('is', 'is_not', 'in', 'not_in') and self.operator not in ('is', 'is_not', 'in', 'not_in') and
not self.is_c_string_contains()) not self.is_c_string_contains() and
not self.is_ptr_contains())
or (self.cascade and self.cascade.is_python_result())) or (self.cascade and self.cascade.is_python_result()))
def is_c_string_contains(self): def is_c_string_contains(self):
...@@ -6021,6 +6024,16 @@ class CmpNode(object): ...@@ -6021,6 +6024,16 @@ class CmpNode(object):
and (self.operand2.type.is_string or self.operand2.type is bytes_type)) or and (self.operand2.type.is_string or self.operand2.type is bytes_type)) or
(self.operand1.type is PyrexTypes.c_py_unicode_type (self.operand1.type is PyrexTypes.c_py_unicode_type
and self.operand2.type is unicode_type)) and self.operand2.type is unicode_type))
def is_ptr_contains(self):
if self.operator in ('in', 'not_in'):
iterator = self.operand2
if iterator.type.is_ptr or iterator.type.is_array:
return iterator.type.base_type is not PyrexTypes.c_char_type
if (isinstance(iterator, IndexNode) and
isinstance(iterator.index, (SliceNode, CoerceFromPyTypeNode)) and
(iterator.base.type.is_array or iterator.base.type.is_ptr)):
return iterator.base.type.base_type is not PyrexTypes.c_char_type
def generate_operation_code(self, code, result_code, def generate_operation_code(self, code, result_code,
operand1, op , operand2): operand1, op , operand2):
...@@ -6216,6 +6229,12 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -6216,6 +6229,12 @@ class PrimaryCmpNode(ExprNode, CmpNode):
env.use_utility_code(char_in_bytes_utility_code) env.use_utility_code(char_in_bytes_utility_code)
self.operand2 = self.operand2.as_none_safe_node( self.operand2 = self.operand2.as_none_safe_node(
"argument of type 'NoneType' is not iterable") "argument of type 'NoneType' is not iterable")
elif self.is_ptr_contains():
if self.cascade:
error(self.pos, "Cascading comparison not yet supported for 'val in sliced pointer'.")
self.type = PyrexTypes.c_bint_type
# Will be transformed by IterationTransform
return
else: else:
common_type = py_object_type common_type = py_object_type
self.is_pycmp = True self.is_pycmp = True
......
...@@ -4295,6 +4295,9 @@ class ForInStatNode(LoopNode, StatNode): ...@@ -4295,6 +4295,9 @@ class ForInStatNode(LoopNode, StatNode):
self.target.analyse_target_types(env) self.target.analyse_target_types(env)
self.iterator.analyse_expressions(env) self.iterator.analyse_expressions(env)
self.item = ExprNodes.NextNode(self.iterator, env) self.item = ExprNodes.NextNode(self.iterator, env)
if not self.target.type.assignable_from(self.item.type) and \
(self.iterator.sequence.type.is_ptr or self.iterator.sequence.type.is_array):
self.item.type = self.iterator.sequence.type.base_type
self.item = self.item.coerce_to(self.target.type, env) self.item = self.item.coerce_to(self.target.type, env)
self.body.analyse_expressions(env) self.body.analyse_expressions(env)
if self.else_clause: if self.else_clause:
......
...@@ -82,11 +82,62 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -82,11 +82,62 @@ class IterationTransform(Visitor.VisitorTransform):
self.visitchildren(node) self.visitchildren(node)
self.current_scope = oldscope self.current_scope = oldscope
return node return node
def visit_PrimaryCmpNode(self, node):
if node.is_ptr_contains():
# for t in operand2:
# if operand1 == t:
# res = True
# break
# else:
# res = False
pos = node.pos
res_handle = UtilNodes.TempHandle(PyrexTypes.c_bint_type)
res = res_handle.ref(pos)
result_ref = UtilNodes.ResultRefNode(node)
if isinstance(node.operand2, ExprNodes.IndexNode):
base_type = node.operand2.base.type.base_type
else:
base_type = node.operand2.type.base_type
target_handle = UtilNodes.TempHandle(base_type)
target = target_handle.ref(pos)
cmp_node = ExprNodes.PrimaryCmpNode(
pos, operator=u'==', operand1=node.operand1, operand2=target)
if_body = Nodes.StatListNode(
pos,
stats = [Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)),
Nodes.BreakStatNode(pos)])
if_node = Nodes.IfStatNode(
pos,
if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_body)],
else_clause=None)
for_loop = UtilNodes.TempsBlockNode(
pos,
temps = [target_handle],
body = Nodes.ForInStatNode(
pos,
target=target,
iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2),
body=if_node,
else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))))
for_loop.analyse_expressions(self.current_scope)
for_loop = self(for_loop)
new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop)
if node.operator == 'not_in':
new_node = ExprNodes.NotNode(pos, operand=new_node)
return new_node
else:
self.visitchildren(node)
return node
def visit_ForInStatNode(self, node): def visit_ForInStatNode(self, node):
self.visitchildren(node) self.visitchildren(node)
return self._optimise_for_loop(node) return self._optimise_for_loop(node)
def _optimise_for_loop(self, node): def _optimise_for_loop(self, node):
iterator = node.iterator.sequence iterator = node.iterator.sequence
if iterator.type is Builtin.dict_type: if iterator.type is Builtin.dict_type:
......
...@@ -53,6 +53,8 @@ cdef extern from "Python.h": ...@@ -53,6 +53,8 @@ cdef extern from "Python.h":
# Delete attribute named attr_name, for object o. Returns -1 on # Delete attribute named attr_name, for object o. Returns -1 on
# failure. This is the equivalent of the Python statement "del # failure. This is the equivalent of the Python statement "del
# o.attr_name". # o.attr_name".
int Py_LT, Py_LE, Py_EQ, Py_NE, Py_GT, Py_GE
object PyObject_RichCompare(object o1, object o2, int opid) object PyObject_RichCompare(object o1, object o2, int opid)
# Return value: New reference. # Return value: New reference.
......
from libc.stdlib cimport malloc, free
from cpython.object cimport Py_EQ, Py_NE
def double_ptr_slice(x, L, int a, int b):
"""
>>> L = list(range(10))
>>> double_ptr_slice(5, L, 0, 10)
>>> double_ptr_slice(6, L, 0, 10)
>>> double_ptr_slice(None, L, 0, 10)
>>> double_ptr_slice(0, L, 3, 7)
>>> double_ptr_slice(5, L, 3, 7)
>>> double_ptr_slice(9, L, 3, 7)
>>> double_ptr_slice(EqualsEvens(), L, 0, 10)
>>> double_ptr_slice(EqualsEvens(), L, 1, 10)
"""
cdef double *L_c = NULL
try:
L_c = <double*>malloc(len(L) * sizeof(double))
for i, a in enumerate(L):
L_c[i] = L[i]
assert (x in L_c[:b]) == (x in L[:b])
assert (x in L_c[a:b]) == (x in L[a:b])
assert (x in L_c[a:b:2]) == (x in L[a:b:2])
finally:
free(L_c)
def void_ptr_slice(py_x, L, int a, int b):
"""
>>> L = list(range(10))
>>> void_ptr_slice(5, L, 0, 10)
>>> void_ptr_slice(6, L, 0, 10)
>>> void_ptr_slice(None, L, 0, 10)
>>> void_ptr_slice(0, L, 3, 7)
>>> void_ptr_slice(5, L, 3, 7)
>>> void_ptr_slice(9, L, 3, 7)
"""
# I'm using the fact that small Python ints are cached.
cdef void **L_c = NULL
cdef void *x = <void*>py_x
try:
L_c = <void**>malloc(len(L) * sizeof(void*))
for i, a in enumerate(L):
L_c[i] = <void*>L[i]
assert (x in L_c[:b]) == (py_x in L[:b])
assert (x in L_c[a:b]) == (py_x in L[a:b])
# assert (x in L_c[a:b:2]) == (py_x in L[a:b:2])
finally:
free(L_c)
cdef class EqualsEvens:
"""
>>> e = EqualsEvens()
>>> e == 2
True
>>> e == 5
False
>>> [e == k for k in range(4)]
[True, False, True, False]
"""
def __richcmp__(self, other, int op):
if op == Py_EQ:
return other % 2 == 0
elif op == Py_NE:
return other % 2 == 1
else:
return False
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment