Commit ab78f93b authored by Stefan Behnel's avatar Stefan Behnel

adapt and apply major refactoring of IndexNode originally written by Mark Florisson

parent 7da49602
...@@ -201,7 +201,13 @@ class BufferEntry(object): ...@@ -201,7 +201,13 @@ class BufferEntry(object):
self.type = entry.type self.type = entry.type
self.cname = entry.buffer_aux.buflocal_nd_var.cname self.cname = entry.buffer_aux.buflocal_nd_var.cname
self.buf_ptr = "%s.rcbuffer->pybuffer.buf" % self.cname self.buf_ptr = "%s.rcbuffer->pybuffer.buf" % self.cname
self.buf_ptr_type = self.entry.type.buffer_ptr_type self.buf_ptr_type = entry.type.buffer_ptr_type
self.init_attributes()
def init_attributes(self):
self.shape = self.get_buf_shapevars()
self.strides = self.get_buf_stridevars()
self.suboffsets = self.get_buf_suboffsetvars()
def get_buf_suboffsetvars(self): def get_buf_suboffsetvars(self):
return self._for_all_ndim("%s.diminfo[%d].suboffsets") return self._for_all_ndim("%s.diminfo[%d].suboffsets")
......
...@@ -322,6 +322,13 @@ class ExprNode(Node): ...@@ -322,6 +322,13 @@ class ExprNode(Node):
is_string_literal = False is_string_literal = False
is_attribute = False is_attribute = False
is_subscript = False is_subscript = False
is_slice = False
is_buffer_access = False
is_memview_index = False
is_memview_slice = False
is_memview_broadcast = False
is_memview_copy_assignment = False
saved_subexpr_nodes = None saved_subexpr_nodes = None
is_temp = False is_temp = False
...@@ -330,9 +337,6 @@ class ExprNode(Node): ...@@ -330,9 +337,6 @@ class ExprNode(Node):
constant_result = constant_value_not_set constant_result = constant_value_not_set
# whether this node with a memoryview type should be broadcast
memslice_broadcast = False
child_attrs = property(fget=operator.attrgetter('subexprs')) child_attrs = property(fget=operator.attrgetter('subexprs'))
def not_implemented(self, method_name): def not_implemented(self, method_name):
...@@ -790,14 +794,12 @@ class ExprNode(Node): ...@@ -790,14 +794,12 @@ class ExprNode(Node):
if src.type.is_pyobject: if src.type.is_pyobject:
src = CoerceToMemViewSliceNode(src, dst_type, env) src = CoerceToMemViewSliceNode(src, dst_type, env)
elif src.type.is_array: elif src.type.is_array:
src = CythonArrayNode.from_carray(src, env).coerce_to( src = CythonArrayNode.from_carray(src, env).coerce_to(dst_type, env)
dst_type, env)
elif not src_type.is_error: elif not src_type.is_error:
error(self.pos, error(self.pos,
"Cannot convert '%s' to memoryviewslice" % "Cannot convert '%s' to memoryviewslice" % (src_type,))
(src_type,)) elif not src.type.conforms_to(dst_type, broadcast=self.is_memview_broadcast,
elif not MemoryView.src_conforms_to_dst( copying=self.is_memview_copy_assignment):
src.type, dst_type, broadcast=self.memslice_broadcast):
if src.type.dtype.same_as(dst_type.dtype): if src.type.dtype.same_as(dst_type.dtype):
msg = "Memoryview '%s' not conformable to memoryview '%s'." msg = "Memoryview '%s' not conformable to memoryview '%s'."
tup = src.type, dst_type tup = src.type, dst_type
...@@ -1834,10 +1836,6 @@ class NameNode(AtomicExprNode): ...@@ -1834,10 +1836,6 @@ class NameNode(AtomicExprNode):
self.gil_error() self.gil_error()
elif entry.is_pyglobal: elif entry.is_pyglobal:
self.gil_error() self.gil_error()
elif self.entry.type.is_memoryviewslice:
if self.cf_is_null or self.cf_maybe_null:
from . import MemoryView
MemoryView.err_if_nogil_initialized_check(self.pos, env)
gil_message = "Accessing Python global or builtin" gil_message = "Accessing Python global or builtin"
...@@ -2915,14 +2913,43 @@ class ParallelThreadIdNode(AtomicExprNode): #, Nodes.ParallelNode): ...@@ -2915,14 +2913,43 @@ class ParallelThreadIdNode(AtomicExprNode): #, Nodes.ParallelNode):
# #
#------------------------------------------------------------------- #-------------------------------------------------------------------
class IndexNode(ExprNode):
class _IndexingBaseNode(ExprNode):
# Base class for indexing nodes.
#
# base ExprNode the value being indexed
def is_ephemeral(self):
# in most cases, indexing will return a safe reference to an object in a container,
# so we consider the result safe if the base object is
return self.base.is_ephemeral() or self.base.type in (
basestring_type, str_type, bytes_type, unicode_type)
def check_const_addr(self):
return self.base.check_const_addr() and self.index.check_const()
def is_lvalue(self):
# NOTE: references currently have both is_reference and is_ptr
# set. Since pointers and references have different lvalue
# rules, we must be careful to separate the two.
if self.type.is_reference:
if self.type.ref_base_type.is_array:
# fixed-sized arrays aren't l-values
return False
elif self.type.is_ptr:
# non-const pointers can always be reassigned
return True
# Just about everything else returned by the index operator
# can be an lvalue.
return True
class IndexNode(_IndexingBaseNode):
# Sequence indexing. # Sequence indexing.
# #
# base ExprNode # base ExprNode
# index ExprNode # index ExprNode
# indices [ExprNode]
# type_indices [PyrexType] # type_indices [PyrexType]
# is_buffer_access boolean Whether this is a buffer access.
# #
# indices is used on buffer access, index on non-buffer access. # indices is used on buffer access, index on non-buffer access.
# The former contains a clean list of index parameters, the # The former contains a clean list of index parameters, the
...@@ -2931,33 +2958,18 @@ class IndexNode(ExprNode): ...@@ -2931,33 +2958,18 @@ class IndexNode(ExprNode):
# is_fused_index boolean Whether the index is used to specialize a # is_fused_index boolean Whether the index is used to specialize a
# c(p)def function # c(p)def function
subexprs = ['base', 'index', 'indices'] subexprs = ['base', 'index']
indices = None
type_indices = None type_indices = None
is_subscript = True is_subscript = True
is_fused_index = False is_fused_index = False
# Whether we're assigning to a buffer (in that case it needs to be
# writable)
writable_needed = False
# Whether we are indexing or slicing a memoryviewslice
memslice_index = False
memslice_slice = False
is_memslice_copy = False
memslice_ellipsis_noop = False
warned_untyped_idx = False
# set by SingleAssignmentNode after analyse_types()
is_memslice_scalar_assignment = False
def __init__(self, pos, index, **kw): def __init__(self, pos, index, **kw):
ExprNode.__init__(self, pos, index=index, **kw) ExprNode.__init__(self, pos, index=index, **kw)
self._index = index self._index = index
def calculate_constant_result(self): def calculate_constant_result(self):
self.constant_result = \ self.constant_result = self.base.constant_result[self.index.constant_result]
self.base.constant_result[self.index.constant_result]
def compile_time_value(self, denv): def compile_time_value(self, denv):
base = self.base.compile_time_value(denv) base = self.base.compile_time_value(denv)
...@@ -2967,18 +2979,7 @@ class IndexNode(ExprNode): ...@@ -2967,18 +2979,7 @@ class IndexNode(ExprNode):
except Exception as e: except Exception as e:
self.compile_time_value_error(e) self.compile_time_value_error(e)
def is_ephemeral(self):
# in most cases, indexing will return a safe reference to an object in a container,
# so we consider the result safe if the base object is
return self.base.is_ephemeral() or self.base.type in (
basestring_type, str_type, bytes_type, unicode_type)
def is_simple(self): def is_simple(self):
if self.is_buffer_access or self.memslice_index:
return False
elif self.memslice_slice:
return True
base = self.base base = self.base
return (base.is_simple() and self.index.is_simple() return (base.is_simple() and self.index.is_simple()
and base.type and (base.type.is_ptr or base.type.is_array)) and base.type and (base.type.is_ptr or base.type.is_array))
...@@ -3023,7 +3024,7 @@ class IndexNode(ExprNode): ...@@ -3023,7 +3024,7 @@ class IndexNode(ExprNode):
def infer_type(self, env): def infer_type(self, env):
base_type = self.base.infer_type(env) base_type = self.base.infer_type(env)
if isinstance(self.index, SliceNode): if self.index.is_slice:
# slicing! # slicing!
if base_type.is_string: if base_type.is_string:
# sliced C strings must coerce to Python # sliced C strings must coerce to Python
...@@ -3105,7 +3106,7 @@ class IndexNode(ExprNode): ...@@ -3105,7 +3106,7 @@ class IndexNode(ExprNode):
node = self.analyse_base_and_index_types(env, setting=True) node = self.analyse_base_and_index_types(env, setting=True)
if node.type.is_const: if node.type.is_const:
error(self.pos, "Assignment to const dereference") error(self.pos, "Assignment to const dereference")
if not node.is_lvalue(): if node is self and not node.is_lvalue():
error(self.pos, "Assignment to non-lvalue of type '%s'" % node.type) error(self.pos, "Assignment to non-lvalue of type '%s'" % node.type)
return node return node
...@@ -3114,19 +3115,6 @@ class IndexNode(ExprNode): ...@@ -3114,19 +3115,6 @@ class IndexNode(ExprNode):
# Note: This might be cleaned up by having IndexNode # Note: This might be cleaned up by having IndexNode
# parsed in a saner way and only construct the tuple if # parsed in a saner way and only construct the tuple if
# needed. # needed.
# Note that this function must leave IndexNode in a cloneable state.
# For buffers, self.index is packed out on the initial analysis, and
# when cloning self.indices is copied.
self.is_buffer_access = False
# a[...] = b
self.is_memslice_copy = False
# incomplete indexing, Ellipsis indexing or slicing
self.memslice_slice = False
# integer indexing
self.memslice_index = False
if analyse_base: if analyse_base:
self.base = self.base.analyse_types(env) self.base = self.base.analyse_types(env)
...@@ -3136,8 +3124,7 @@ class IndexNode(ExprNode): ...@@ -3136,8 +3124,7 @@ class IndexNode(ExprNode):
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
return self return self
is_slice = isinstance(self.index, SliceNode) is_slice = self.index.is_slice
if not env.directives['wraparound']: if not env.directives['wraparound']:
if is_slice: if is_slice:
check_negative_indices(self.index.start, self.index.stop) check_negative_indices(self.index.start, self.index.stop)
...@@ -3149,181 +3136,21 @@ class IndexNode(ExprNode): ...@@ -3149,181 +3136,21 @@ class IndexNode(ExprNode):
self.index = self.index.coerce_to_pyobject(env) self.index = self.index.coerce_to_pyobject(env)
is_memslice = self.base.type.is_memoryviewslice is_memslice = self.base.type.is_memoryviewslice
# Handle the case where base is a literal char* (and we expect a string, not an int) # Handle the case where base is a literal char* (and we expect a string, not an int)
if not is_memslice and (isinstance(self.base, BytesNode) or is_slice): if not is_memslice and (isinstance(self.base, BytesNode) or is_slice):
if self.base.type.is_string or not (self.base.type.is_ptr or self.base.type.is_array): if self.base.type.is_string or not (self.base.type.is_ptr or self.base.type.is_array):
self.base = self.base.coerce_to_pyobject(env) self.base = self.base.coerce_to_pyobject(env)
skip_child_analysis = False replacement_node = self.analyse_as_buffer_operation(env, getting)
buffer_access = False if replacement_node is not None:
return replacement_node
if self.indices:
indices = self.indices
elif isinstance(self.index, TupleNode):
indices = self.index.args
else:
indices = [self.index]
if (is_memslice and not self.indices and
isinstance(self.index, EllipsisNode)):
# Memoryviewslice copying
self.is_memslice_copy = True
elif is_memslice:
# memoryviewslice indexing or slicing
from . import MemoryView
skip_child_analysis = True
newaxes = [newaxis for newaxis in indices if newaxis.is_none]
have_slices, indices = MemoryView.unellipsify(indices,
newaxes,
self.base.type.ndim)
self.memslice_index = (not newaxes and
len(indices) == self.base.type.ndim)
axes = []
index_type = PyrexTypes.c_py_ssize_t_type
new_indices = []
if len(indices) - len(newaxes) > self.base.type.ndim:
self.type = error_type
error(indices[self.base.type.ndim].pos,
"Too many indices specified for type %s" %
self.base.type)
return self
axis_idx = 0
for i, index in enumerate(indices[:]):
index = index.analyse_types(env)
if not index.is_none:
access, packing = self.base.type.axes[axis_idx]
axis_idx += 1
if isinstance(index, SliceNode):
self.memslice_slice = True
if index.step.is_none:
axes.append((access, packing))
else:
axes.append((access, 'strided'))
# Coerce start, stop and step to temps of the right type
for attr in ('start', 'stop', 'step'):
value = getattr(index, attr)
if not value.is_none:
value = value.coerce_to(index_type, env)
#value = value.coerce_to_temp(env)
setattr(index, attr, value)
new_indices.append(value)
elif index.is_none:
self.memslice_slice = True
new_indices.append(index)
axes.append(('direct', 'strided'))
elif index.type.is_int or index.type.is_pyobject:
if index.type.is_pyobject and not self.warned_untyped_idx:
warning(index.pos, "Index should be typed for more "
"efficient access", level=2)
IndexNode.warned_untyped_idx = True
self.memslice_index = True
index = index.coerce_to(index_type, env)
indices[i] = index
new_indices.append(index)
else:
self.type = error_type
error(index.pos, "Invalid index for memoryview specified")
return self
self.memslice_index = self.memslice_index and not self.memslice_slice
self.original_indices = indices
# All indices with all start/stop/step for slices.
# We need to keep this around
self.indices = new_indices
self.env = env
elif self.base.type.is_buffer:
# Buffer indexing
if len(indices) == self.base.type.ndim:
buffer_access = True
skip_child_analysis = True
for x in indices:
x = x.analyse_types(env)
if not x.type.is_int:
buffer_access = False
if buffer_access and not self.base.type.is_memoryviewslice:
assert hasattr(self.base, "entry") # Must be a NameNode-like node
# On cloning, indices is cloned. Otherwise, unpack index into indices
assert not (buffer_access and isinstance(self.index, CloneNode))
self.nogil = env.nogil self.nogil = env.nogil
base_type = self.base.type
if buffer_access or self.memslice_index: if not base_type.is_cfunction:
#if self.base.type.is_memoryviewslice and not self.base.is_name: self.index = self.index.analyse_types(env)
# self.base = self.base.coerce_to_temp(env) self.original_index_type = self.index.type
self.base = self.base.coerce_to_simple(env)
self.indices = indices
self.index = None
self.type = self.base.type.dtype
self.is_buffer_access = True
self.buffer_type = self.base.type #self.base.entry.type
if getting and self.type.is_pyobject:
self.is_temp = True
if setting and self.base.type.is_memoryviewslice:
self.base.type.writable_needed = True
elif setting:
if not self.base.entry.type.writable:
error(self.pos, "Writing to readonly buffer")
else:
self.writable_needed = True
if self.base.type.is_buffer:
self.base.entry.buffer_aux.writable_needed = True
elif self.is_memslice_copy:
self.type = self.base.type
if getting:
self.memslice_ellipsis_noop = True
else:
self.memslice_broadcast = True
elif self.memslice_slice:
self.index = None
self.is_temp = True
self.use_managed_ref = True
if not MemoryView.validate_axes(self.pos, axes):
self.type = error_type
return self
self.type = PyrexTypes.MemoryViewSliceType(
self.base.type.dtype, axes)
if (self.base.type.is_memoryviewslice and not
self.base.is_name and not
self.base.result_in_temp()):
self.base = self.base.coerce_to_temp(env)
if setting:
self.memslice_broadcast = True
else:
base_type = self.base.type
if not base_type.is_cfunction:
if isinstance(self.index, TupleNode):
self.index = self.index.analyse_types(
env, skip_children=skip_child_analysis)
elif not skip_child_analysis:
self.index = self.index.analyse_types(env)
self.original_index_type = self.index.type
if base_type.is_unicode_char: if base_type.is_unicode_char:
# we infer Py_UNICODE/Py_UCS4 for unicode strings in some # we infer Py_UNICODE/Py_UCS4 for unicode strings in some
...@@ -3335,125 +3162,173 @@ class IndexNode(ExprNode): ...@@ -3335,125 +3162,173 @@ class IndexNode(ExprNode):
return self.base return self.base
self.base = self.base.coerce_to_pyobject(env) self.base = self.base.coerce_to_pyobject(env)
base_type = self.base.type base_type = self.base.type
if base_type.is_pyobject:
if self.index.type.is_int and base_type is not dict_type: if base_type.is_pyobject:
if (getting return self.analyse_as_pyobject(env, is_slice, getting, setting)
and (base_type in (list_type, tuple_type, bytearray_type)) elif base_type.is_ptr or base_type.is_array:
and (not self.index.type.signed return self.analyse_as_c_array(env, is_slice)
or not env.directives['wraparound'] elif base_type.is_cpp_class:
or (isinstance(self.index, IntNode) and return self.analyse_as_cpp(env, setting)
self.index.has_constant_result() and self.index.constant_result >= 0)) elif base_type.is_cfunction:
and not env.directives['boundscheck']): return self.analyse_as_c_function(env)
self.is_temp = 0 elif base_type.is_ctuple:
else: return self.analyse_as_c_tuple(env, getting, setting)
self.is_temp = 1 else:
self.index = self.index.coerce_to(PyrexTypes.c_py_ssize_t_type, env).coerce_to_simple(env) error(self.pos,
self.original_index_type.create_to_py_utility_code(env) "Attempting to index non-array type '%s'" %
else: base_type)
self.index = self.index.coerce_to_pyobject(env) self.type = PyrexTypes.error_type
self.is_temp = 1 return self
if self.index.type.is_int and base_type is unicode_type:
# Py_UNICODE/Py_UCS4 will automatically coerce to a unicode string def analyse_as_pyobject(self, env, is_slice, getting, setting):
# if required, so this is fast and safe base_type = self.base.type
self.type = PyrexTypes.c_py_ucs4_type if self.index.type.is_int and base_type is not dict_type:
elif self.index.type.is_int and base_type is bytearray_type: if (getting
if setting: and (base_type in (list_type, tuple_type, bytearray_type))
self.type = PyrexTypes.c_uchar_type and (not self.index.type.signed
else: or not env.directives['wraparound']
# not using 'uchar' to enable fast and safe error reporting as '-1' or (isinstance(self.index, IntNode) and
self.type = PyrexTypes.c_int_type self.index.has_constant_result() and self.index.constant_result >= 0))
elif is_slice and base_type in (bytes_type, str_type, unicode_type, list_type, tuple_type): and not env.directives['boundscheck']):
self.type = base_type self.is_temp = 0
else:
item_type = None
if base_type in (list_type, tuple_type) and self.index.type.is_int:
item_type = infer_sequence_item_type(
env, self.base, self.index, seq_type=base_type)
if item_type is None:
item_type = py_object_type
self.type = item_type
if base_type in (list_type, tuple_type, dict_type):
# do the None check explicitly (not in a helper) to allow optimising it away
self.base = self.base.as_none_safe_node("'NoneType' object is not subscriptable")
else: else:
if base_type.is_ptr or base_type.is_array: self.is_temp = 1
self.type = base_type.base_type self.index = self.index.coerce_to(PyrexTypes.c_py_ssize_t_type, env).coerce_to_simple(env)
if is_slice: self.original_index_type.create_to_py_utility_code(env)
self.type = base_type else:
elif self.index.type.is_pyobject: self.index = self.index.coerce_to_pyobject(env)
self.index = self.index.coerce_to( self.is_temp = 1
PyrexTypes.c_py_ssize_t_type, env)
elif not self.index.type.is_int: if self.index.type.is_int and base_type is unicode_type:
error(self.pos, # Py_UNICODE/Py_UCS4 will automatically coerce to a unicode string
"Invalid index type '%s'" % # if required, so this is fast and safe
self.index.type) self.type = PyrexTypes.c_py_ucs4_type
elif base_type.is_cpp_class: elif self.index.type.is_int and base_type is bytearray_type:
function = env.lookup_operator("[]", [self.base, self.index]) if setting:
if function is None: self.type = PyrexTypes.c_uchar_type
error(self.pos, "Indexing '%s' not supported for index type '%s'" % (base_type, self.index.type)) else:
self.type = PyrexTypes.error_type # not using 'uchar' to enable fast and safe error reporting as '-1'
self.result_code = "<error>" self.type = PyrexTypes.c_int_type
return self elif is_slice and base_type in (bytes_type, str_type, unicode_type, list_type, tuple_type):
func_type = function.type self.type = base_type
if func_type.is_ptr: else:
func_type = func_type.base_type item_type = None
self.index = self.index.coerce_to(func_type.args[0].type, env) if base_type in (list_type, tuple_type) and self.index.type.is_int:
self.type = func_type.return_type item_type = infer_sequence_item_type(
if setting and not func_type.return_type.is_reference: env, self.base, self.index, seq_type=base_type)
error(self.pos, "Can't set non-reference result '%s'" % self.type) if item_type is None:
elif base_type.is_cfunction: item_type = py_object_type
if base_type.is_fused: self.type = item_type
self.parse_indexed_fused_cdef(env) if base_type in (list_type, tuple_type, dict_type):
else: # do the None check explicitly (not in a helper) to allow optimising it away
self.type_indices = self.parse_index_as_types(env) self.base = self.base.as_none_safe_node("'NoneType' object is not subscriptable")
if base_type.templates is None:
error(self.pos, "Can only parameterize template functions.")
self.type = error_type
elif len(base_type.templates) != len(self.type_indices):
error(self.pos, "Wrong number of template arguments: expected %s, got %s" % (
(len(base_type.templates), len(self.type_indices))))
self.type = error_type
else:
self.type = base_type.specialize(dict(zip(base_type.templates, self.type_indices)))
elif base_type.is_ctuple:
if isinstance(self.index, IntNode) and self.index.has_constant_result():
index = self.index.constant_result
if -base_type.size <= index < base_type.size:
if index < 0:
index += base_type.size
self.type = base_type.components[index]
else:
error(self.pos,
"Index %s out of bounds for '%s'" %
(index, base_type))
self.type = PyrexTypes.error_type
else:
self.base = self.base.coerce_to_pyobject(env)
return self.analyse_base_and_index_types(env, getting=getting, setting=setting, analyse_base=False)
else:
error(self.pos,
"Attempting to index non-array type '%s'" %
base_type)
self.type = PyrexTypes.error_type
self.wrap_in_nonecheck_node(env, getting) self.wrap_in_nonecheck_node(env, getting)
return self return self
def wrap_in_nonecheck_node(self, env, getting): def analyse_as_c_array(self, env, is_slice):
if not env.directives['nonecheck'] or not self.base.may_be_none(): base_type = self.base.type
return self.type = base_type.base_type
if is_slice:
self.type = base_type
elif self.index.type.is_pyobject:
self.index = self.index.coerce_to(PyrexTypes.c_py_ssize_t_type, env)
elif not self.index.type.is_int:
error(self.pos, "Invalid index type '%s'" % self.index.type)
return self
if self.base.type.is_memoryviewslice: def analyse_as_cpp(self, env, setting):
if self.is_memslice_copy and not getting: base_type = self.base.type
msg = "Cannot assign to None memoryview slice" function = env.lookup_operator("[]", [self.base, self.index])
elif self.memslice_slice: if function is None:
msg = "Cannot slice None memoryview slice" error(self.pos, "Indexing '%s' not supported for index type '%s'" % (base_type, self.index.type))
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return self
func_type = function.type
if func_type.is_ptr:
func_type = func_type.base_type
self.index = self.index.coerce_to(func_type.args[0].type, env)
self.type = func_type.return_type
if setting and not func_type.return_type.is_reference:
error(self.pos, "Can't set non-reference result '%s'" % self.type)
return self
def analyse_as_c_function(self, env):
base_type = self.base.type
if base_type.is_fused:
self.parse_indexed_fused_cdef(env)
else:
self.type_indices = self.parse_index_as_types(env)
if base_type.templates is None:
error(self.pos, "Can only parameterize template functions.")
self.type = error_type
elif len(base_type.templates) != len(self.type_indices):
error(self.pos, "Wrong number of template arguments: expected %s, got %s" % (
(len(base_type.templates), len(self.type_indices))))
self.type = error_type
else:
self.type = base_type.specialize(dict(zip(base_type.templates, self.type_indices)))
return self
def analyse_as_c_tuple(self, env, getting, setting):
base_type = self.base.type
if isinstance(self.index, IntNode) and self.index.has_constant_result():
index = self.index.constant_result
if -base_type.size <= index < base_type.size:
if index < 0:
index += base_type.size
self.type = base_type.components[index]
else: else:
msg = "Cannot index None memoryview slice" error(self.pos,
"Index %s out of bounds for '%s'" %
(index, base_type))
self.type = PyrexTypes.error_type
return self
else:
self.base = self.base.coerce_to_pyobject(env)
return self.analyse_base_and_index_types(env, getting=getting, setting=setting, analyse_base=False)
def analyse_as_buffer_operation(self, env, getting):
"""
Analyse buffer indexing and memoryview indexing/slicing
"""
if isinstance(self.index, TupleNode):
indices = self.index.args
else: else:
msg = "'NoneType' object is not subscriptable" indices = [self.index]
self.base = self.base.as_none_safe_node(msg) base_type = self.base.type
replacement_node = None
if base_type.is_memoryviewslice:
# memoryviewslice indexing or slicing
from . import MemoryView
have_slices, indices, newaxes = MemoryView.unellipsify(indices, base_type.ndim)
if have_slices:
replacement_node = MemoryViewSliceNode(self.pos, indices=indices, base=self.base)
else:
replacement_node = MemoryViewIndexNode(self.pos, indices=indices, base=self.base)
elif base_type.is_buffer and len(indices) == base_type.ndim:
# Buffer indexing
is_buffer_access = True
for index in indices:
index = index.analyse_types(env)
if not index.type.is_int:
is_buffer_access = False
if is_buffer_access:
replacement_node = BufferIndexNode(self.pos, indices=indices, base=self.base)
# On cloning, indices is cloned. Otherwise, unpack index into indices.
assert not isinstance(self.index, CloneNode)
if replacement_node is not None:
replacement_node = replacement_node.analyse_types(env, getting)
return replacement_node
def wrap_in_nonecheck_node(self, env, getting):
if not env.directives['nonecheck'] or not self.base.may_be_none():
return
self.base = self.base.as_none_safe_node("'NoneType' object is not subscriptable")
def parse_index_as_types(self, env, required=True): def parse_index_as_types(self, env, required=True):
if isinstance(self.index, TupleNode): if isinstance(self.index, TupleNode):
...@@ -3563,43 +3438,8 @@ class IndexNode(ExprNode): ...@@ -3563,43 +3438,8 @@ class IndexNode(ExprNode):
gil_message = "Indexing Python object" gil_message = "Indexing Python object"
def nogil_check(self, env):
if self.is_buffer_access or self.memslice_index or self.memslice_slice:
if not self.memslice_slice and env.directives['boundscheck']:
# error(self.pos, "Cannot check buffer index bounds without gil; "
# "use boundscheck(False) directive")
warning(self.pos, "Use boundscheck(False) for faster access",
level=1)
if self.type.is_pyobject:
error(self.pos, "Cannot access buffer with object dtype without gil")
return
super(IndexNode, self).nogil_check(env)
def check_const_addr(self):
return self.base.check_const_addr() and self.index.check_const()
def is_lvalue(self):
# NOTE: references currently have both is_reference and is_ptr
# set. Since pointers and references have different lvalue
# rules, we must be careful to separate the two.
if self.type.is_reference:
if self.type.ref_base_type.is_array:
# fixed-sized arrays aren't l-values
return False
elif self.type.is_ptr:
# non-const pointers can always be reassigned
return True
# Just about everything else returned by the index operator
# can be an lvalue.
return True
def calculate_result_code(self): def calculate_result_code(self):
if self.is_buffer_access: if self.base.type in (list_type, tuple_type, bytearray_type):
return "(*%s)" % self.buffer_ptr_code
elif self.is_memslice_copy:
return self.base.result()
elif self.base.type in (list_type, tuple_type, bytearray_type):
if self.base.type is list_type: if self.base.type is list_type:
index_code = "PyList_GET_ITEM(%s, %s)" index_code = "PyList_GET_ITEM(%s, %s)"
elif self.base.type is tuple_type: elif self.base.type is tuple_type:
...@@ -3641,101 +3481,62 @@ class IndexNode(ExprNode): ...@@ -3641,101 +3481,62 @@ class IndexNode(ExprNode):
else: else:
return "" return ""
def generate_subexpr_evaluation_code(self, code):
self.base.generate_evaluation_code(code)
if self.type_indices is not None:
pass
elif self.indices is None:
self.index.generate_evaluation_code(code)
else:
for i in self.indices:
i.generate_evaluation_code(code)
def generate_subexpr_disposal_code(self, code):
self.base.generate_disposal_code(code)
if self.type_indices is not None:
pass
elif self.indices is None:
self.index.generate_disposal_code(code)
else:
for i in self.indices:
i.generate_disposal_code(code)
def free_subexpr_temps(self, code):
self.base.free_temps(code)
if self.indices is None:
self.index.free_temps(code)
else:
for i in self.indices:
i.free_temps(code)
def generate_result_code(self, code): def generate_result_code(self, code):
if self.is_buffer_access or self.memslice_index: if not self.is_temp:
buffer_entry, self.buffer_ptr_code = self.buffer_lookup_code(code) # all handled in self.calculate_result_code()
if self.type.is_pyobject: return
# is_temp is True, so must pull out value and incref it. if self.type.is_pyobject:
# NOTE: object temporary results for nodes are declared error_value = 'NULL'
# as PyObject *, so we need a cast if self.index.type.is_int:
code.putln("%s = (PyObject *) *%s;" % (self.temp_code, if self.base.type is list_type:
self.buffer_ptr_code)) function = "__Pyx_GetItemInt_List"
code.putln("__Pyx_INCREF((PyObject*)%s);" % self.temp_code) elif self.base.type is tuple_type:
function = "__Pyx_GetItemInt_Tuple"
elif self.memslice_slice:
self.put_memoryviewslice_slice_code(code)
elif self.is_temp:
if self.type.is_pyobject:
error_value = 'NULL'
if self.index.type.is_int:
if self.base.type is list_type:
function = "__Pyx_GetItemInt_List"
elif self.base.type is tuple_type:
function = "__Pyx_GetItemInt_Tuple"
else:
function = "__Pyx_GetItemInt"
code.globalstate.use_utility_code(
TempitaUtilityCode.load_cached("GetItemInt", "ObjectHandling.c"))
else: else:
if self.base.type is dict_type: function = "__Pyx_GetItemInt"
function = "__Pyx_PyDict_GetItem"
code.globalstate.use_utility_code(
UtilityCode.load_cached("DictGetItem", "ObjectHandling.c"))
else:
function = "PyObject_GetItem"
elif self.type.is_unicode_char and self.base.type is unicode_type:
assert self.index.type.is_int
function = "__Pyx_GetItemInt_Unicode"
error_value = '(Py_UCS4)-1'
code.globalstate.use_utility_code(
UtilityCode.load_cached("GetItemIntUnicode", "StringTools.c"))
elif self.base.type is bytearray_type:
assert self.index.type.is_int
assert self.type.is_int
function = "__Pyx_GetItemInt_ByteArray"
error_value = '-1'
code.globalstate.use_utility_code( code.globalstate.use_utility_code(
UtilityCode.load_cached("GetItemIntByteArray", "StringTools.c")) TempitaUtilityCode.load_cached("GetItemInt", "ObjectHandling.c"))
else: else:
assert False, "unexpected type %s and base type %s for indexing" % ( if self.base.type is dict_type:
self.type, self.base.type) function = "__Pyx_PyDict_GetItem"
code.globalstate.use_utility_code(
UtilityCode.load_cached("DictGetItem", "ObjectHandling.c"))
else:
function = "PyObject_GetItem"
elif self.type.is_unicode_char and self.base.type is unicode_type:
assert self.index.type.is_int
function = "__Pyx_GetItemInt_Unicode"
error_value = '(Py_UCS4)-1'
code.globalstate.use_utility_code(
UtilityCode.load_cached("GetItemIntUnicode", "StringTools.c"))
elif self.base.type is bytearray_type:
assert self.index.type.is_int
assert self.type.is_int
function = "__Pyx_GetItemInt_ByteArray"
error_value = '-1'
code.globalstate.use_utility_code(
UtilityCode.load_cached("GetItemIntByteArray", "StringTools.c"))
else:
assert False, "unexpected type %s and base type %s for indexing" % (
self.type, self.base.type)
if self.index.type.is_int: if self.index.type.is_int:
index_code = self.index.result() index_code = self.index.result()
else: else:
index_code = self.index.py_result() index_code = self.index.py_result()
code.putln( code.putln(
"%s = %s(%s, %s%s); if (unlikely(%s == %s)) %s;" % ( "%s = %s(%s, %s%s); if (unlikely(%s == %s)) %s;" % (
self.result(), self.result(),
function, function,
self.base.py_result(), self.base.py_result(),
index_code, index_code,
self.extra_index_params(code), self.extra_index_params(code),
self.result(), self.result(),
error_value, error_value,
code.error_goto(self.pos))) code.error_goto(self.pos)))
if self.type.is_pyobject: if self.type.is_pyobject:
code.put_gotref(self.py_result()) code.put_gotref(self.py_result())
def generate_setitem_code(self, value_code, code): def generate_setitem_code(self, value_code, code):
if self.index.type.is_int: if self.index.type.is_int:
...@@ -3770,57 +3571,20 @@ class IndexNode(ExprNode): ...@@ -3770,57 +3571,20 @@ class IndexNode(ExprNode):
self.extra_index_params(code), self.extra_index_params(code),
code.error_goto(self.pos))) code.error_goto(self.pos)))
def generate_buffer_setitem_code(self, rhs, code, op=""):
# Used from generate_assignment_code and InPlaceAssignmentNode
buffer_entry, ptrexpr = self.buffer_lookup_code(code)
if self.buffer_type.dtype.is_pyobject:
# Must manage refcounts. Decref what is already there
# and incref what we put in.
ptr = code.funcstate.allocate_temp(buffer_entry.buf_ptr_type,
manage_ref=False)
rhs_code = rhs.result()
code.putln("%s = %s;" % (ptr, ptrexpr))
code.put_gotref("*%s" % ptr)
code.putln("__Pyx_INCREF(%s); __Pyx_DECREF(*%s);" % (
rhs_code, ptr))
code.putln("*%s %s= %s;" % (ptr, op, rhs_code))
code.put_giveref("*%s" % ptr)
code.funcstate.release_temp(ptr)
else:
# Simple case
code.putln("*%s %s= %s;" % (ptrexpr, op, rhs.result()))
def generate_assignment_code(self, rhs, code, overloaded_assignment=False): def generate_assignment_code(self, rhs, code, overloaded_assignment=False):
generate_evaluation_code = (self.is_memslice_scalar_assignment or self.generate_subexpr_evaluation_code(code)
self.memslice_slice)
if generate_evaluation_code:
self.generate_evaluation_code(code)
else:
self.generate_subexpr_evaluation_code(code)
if self.is_buffer_access or self.memslice_index: if self.type.is_pyobject:
self.generate_buffer_setitem_code(rhs, code)
elif self.is_memslice_scalar_assignment:
self.generate_memoryviewslice_assign_scalar_code(rhs, code)
elif self.memslice_slice or self.is_memslice_copy:
self.generate_memoryviewslice_setslice_code(rhs, code)
elif self.type.is_pyobject:
self.generate_setitem_code(rhs.py_result(), code) self.generate_setitem_code(rhs.py_result(), code)
elif self.base.type is bytearray_type: elif self.base.type is bytearray_type:
value_code = self._check_byte_value(code, rhs) value_code = self._check_byte_value(code, rhs)
self.generate_setitem_code(value_code, code) self.generate_setitem_code(value_code, code)
else: else:
code.putln( code.putln(
"%s = %s;" % ( "%s = %s;" % (self.result(), rhs.result()))
self.result(), rhs.result()))
if generate_evaluation_code:
self.generate_disposal_code(code)
else:
self.generate_subexpr_disposal_code(code)
self.free_subexpr_temps(code)
self.generate_subexpr_disposal_code(code)
self.free_subexpr_temps(code)
rhs.generate_disposal_code(code) rhs.generate_disposal_code(code)
rhs.free_temps(code) rhs.free_temps(code)
...@@ -3884,27 +3648,88 @@ class IndexNode(ExprNode): ...@@ -3884,27 +3648,88 @@ class IndexNode(ExprNode):
self.generate_subexpr_disposal_code(code) self.generate_subexpr_disposal_code(code)
self.free_subexpr_temps(code) self.free_subexpr_temps(code)
def buffer_entry(self):
from . import Buffer, MemoryView
base = self.base class BufferIndexNode(_IndexingBaseNode):
if self.base.is_nonecheck: """
base = base.arg Indexing of buffers and memoryviews. This node is created during type
analysis from IndexNode and replaces it.
if base.is_name: Attributes:
entry = base.entry base - base node being indexed
else: indices - list of indexing expressions
# SimpleCallNode is_simple is not consistent with coerce_to_simple """
assert base.is_simple() or base.is_temp
cname = base.result()
entry = Symtab.Entry(cname, cname, self.base.type, self.base.pos)
if entry.type.is_buffer: subexprs = ['base', 'indices']
buffer_entry = Buffer.BufferEntry(entry)
else: is_buffer_access = True
buffer_entry = MemoryView.MemoryViewSliceBufferEntry(entry)
# Whether we're assigning to a buffer (in that case it needs to be writable)
writable_needed = False
def analyse_target_types(self, env):
self.analyse_types(env, getting=False)
def analyse_types(self, env, getting=True):
"""
Analyse types for buffer indexing only. Overridden by memoryview
indexing and slicing subclasses
"""
# self.indices are already analyzed
if not self.base.is_name:
error(self.pos, "Can only index buffer variables")
self.type = error_type
return self
if not getting:
if not self.base.entry.type.writable:
error(self.pos, "Writing to readonly buffer")
else:
self.writable_needed = True
if self.base.type.is_buffer:
self.base.entry.buffer_aux.writable_needed = True
self.none_error_message = "'NoneType' object is not subscriptable"
self.analyse_buffer_index(env, getting)
self.wrap_in_nonecheck_node(env)
return self
def analyse_buffer_index(self, env, getting):
self.base = self.base.coerce_to_simple(env)
self.type = self.base.type.dtype
self.buffer_type = self.base.type
if getting and self.type.is_pyobject:
self.is_temp = True
return buffer_entry def analyse_assignment(self, rhs):
"""
Called by IndexNode when this node is assigned to,
with the rhs of the assignment
"""
def wrap_in_nonecheck_node(self, env):
if not env.directives['nonecheck'] or not self.base.may_be_none():
return
self.base = self.base.as_none_safe_node(self.none_error_message)
def nogil_check(self, env):
if self.is_buffer_access or self.is_memview_index:
if env.directives['boundscheck']:
warning(self.pos, "Use boundscheck(False) for faster access",
level=1)
if self.type.is_pyobject:
error(self.pos, "Cannot access buffer with object dtype without gil")
self.type = error_type
def calculate_result_code(self):
return "(*%s)" % self.buffer_ptr_code
def buffer_entry(self):
base = self.base
if self.base.is_nonecheck:
base = base.arg
return base.type.get_entry(base)
def buffer_lookup_code(self, code): def buffer_lookup_code(self, code):
""" """
...@@ -3938,17 +3763,228 @@ class IndexNode(ExprNode): ...@@ -3938,17 +3763,228 @@ class IndexNode(ExprNode):
negative_indices=negative_indices, negative_indices=negative_indices,
in_nogil_context=self.in_nogil_context) in_nogil_context=self.in_nogil_context)
def put_memoryviewslice_slice_code(self, code): def generate_assignment_code(self, rhs, code, overloaded_assignment=False):
"memslice[:]" self.generate_subexpr_evaluation_code(code)
self.generate_buffer_setitem_code(rhs, code)
self.generate_subexpr_disposal_code(code)
self.free_subexpr_temps(code)
rhs.generate_disposal_code(code)
rhs.free_temps(code)
def generate_buffer_setitem_code(self, rhs, code, op=""):
# Used from generate_assignment_code and InPlaceAssignmentNode
buffer_entry, ptrexpr = self.buffer_lookup_code(code)
if self.buffer_type.dtype.is_pyobject:
# Must manage refcounts. Decref what is already there
# and incref what we put in.
ptr = code.funcstate.allocate_temp(buffer_entry.buf_ptr_type,
manage_ref=False)
rhs_code = rhs.result()
code.putln("%s = %s;" % (ptr, ptrexpr))
code.put_gotref("*%s" % ptr)
code.putln("__Pyx_INCREF(%s); __Pyx_DECREF(*%s);" % (
rhs_code, ptr))
code.putln("*%s %s= %s;" % (ptr, op, rhs_code))
code.put_giveref("*%s" % ptr)
code.funcstate.release_temp(ptr)
else:
# Simple case
code.putln("*%s %s= %s;" % (ptrexpr, op, rhs.result()))
def generate_result_code(self, code):
buffer_entry, self.buffer_ptr_code = self.buffer_lookup_code(code)
if self.type.is_pyobject:
# is_temp is True, so must pull out value and incref it.
# NOTE: object temporary results for nodes are declared
# as PyObject *, so we need a cast
code.putln("%s = (PyObject *) *%s;" % (self.result(), self.buffer_ptr_code))
code.putln("__Pyx_INCREF((PyObject*)%s);" % self.result())
class MemoryViewIndexNode(BufferIndexNode):
is_memview_index = True
is_buffer_access = False
warned_untyped_idx = False
def analyse_types(self, env, getting=True):
# memoryviewslice indexing or slicing
from . import MemoryView
indices = self.indices
have_slices, indices, newaxes = MemoryView.unellipsify(indices, self.base.type.ndim)
self.memslice_index = (not newaxes and len(indices) == self.base.type.ndim)
axes = []
index_type = PyrexTypes.c_py_ssize_t_type
new_indices = []
if len(indices) - len(newaxes) > self.base.type.ndim:
self.type = error_type
error(indices[self.base.type.ndim].pos,
"Too many indices specified for type %s" % self.base.type)
return self
axis_idx = 0
for i, index in enumerate(indices[:]):
index = index.analyse_types(env)
if index.is_none:
self.is_memview_slice = True
new_indices.append(index)
axes.append(('direct', 'strided'))
continue
access, packing = self.base.type.axes[axis_idx]
axis_idx += 1
if index.is_slice:
self.is_memview_slice = True
if index.step.is_none:
axes.append((access, packing))
else:
axes.append((access, 'strided'))
# Coerce start, stop and step to temps of the right type
for attr in ('start', 'stop', 'step'):
value = getattr(index, attr)
if not value.is_none:
value = value.coerce_to(index_type, env)
#value = value.coerce_to_temp(env)
setattr(index, attr, value)
new_indices.append(value)
elif index.type.is_int or index.type.is_pyobject:
if index.type.is_pyobject and not self.warned_untyped_idx:
warning(index.pos, "Index should be typed for more efficient access", level=2)
MemoryViewIndexNode.warned_untyped_idx = True
self.is_memview_index = True
index = index.coerce_to(index_type, env)
indices[i] = index
new_indices.append(index)
else:
self.type = error_type
error(index.pos, "Invalid index for memoryview specified, type %s" % index.type)
return self
### FIXME: replace by MemoryViewSliceNode if is_memview_slice ?
self.is_memview_index = self.is_memview_index and not self.is_memview_slice
self.indices = new_indices
# All indices with all start/stop/step for slices.
# We need to keep this around.
self.original_indices = indices
self.nogil = env.nogil
self.analyse_operation(env, getting, axes)
self.wrap_in_nonecheck_node(env)
return self
def analyse_operation(self, env, getting, axes):
self.none_error_message = "Cannot index None memoryview slice"
self.analyse_buffer_index(env, getting)
def analyse_broadcast_operation(self, rhs):
"""
Support broadcasting for slice assignment.
E.g.
m_2d[...] = m_1d # or,
m_1d[...] = m_2d # if the leading dimension has extent 1
"""
if self.type.is_memoryviewslice:
lhs = self
if lhs.is_memview_broadcast or rhs.is_memview_broadcast:
lhs.is_memview_broadcast = True
rhs.is_memview_broadcast = True
def analyse_as_memview_scalar_assignment(self, rhs):
lhs = self.analyse_assignment(rhs)
if lhs:
rhs.is_memview_copy_assignment = lhs.is_memview_copy_assignment
return lhs
return self
class MemoryViewSliceNode(MemoryViewIndexNode):
is_memview_slice = True
# No-op slicing operation, this node will be replaced
is_ellipsis_noop = False
is_memview_scalar_assignment = False
is_memview_index = False
is_memview_broadcast = False
def analyse_ellipsis_noop(self, env, getting):
"""Slicing operations needing no evaluation, i.e. m[...] or m[:, :]"""
### FIXME: replace directly
self.is_ellipsis_noop = all(
index.is_slice and index.start.is_none and index.stop.is_none and index.step.is_none
for index in self.indices)
if self.is_ellipsis_noop:
self.type = self.base.type
def analyse_operation(self, env, getting, axes):
from . import MemoryView
if not getting:
self.is_memview_broadcast = True
self.none_error_message = "Cannot assign to None memoryview slice"
else:
self.none_error_message = "Cannot slice None memoryview slice"
self.analyse_ellipsis_noop(env, getting)
if self.is_ellipsis_noop:
return
self.index = None
self.is_temp = True
self.use_managed_ref = True
if not MemoryView.validate_axes(self.pos, axes):
self.type = error_type
return
self.type = PyrexTypes.MemoryViewSliceType(self.base.type.dtype, axes)
if not (self.base.is_simple() or self.base.result_in_temp()):
self.base = self.base.coerce_to_temp(env)
def analyse_assignment(self, rhs):
if not rhs.type.is_memoryviewslice and (
self.type.dtype.assignable_from(rhs.type) or
rhs.type.is_pyobject):
# scalar assignment
return MemoryCopyScalar(self.pos, self)
else:
return MemoryCopySlice(self.pos, self)
def is_simple(self):
if self.is_ellipsis_noop:
# TODO: fix SimpleCallNode.is_simple()
return self.base.is_simple() or self.base.result_in_temp()
return self.result_in_temp()
def calculate_result_code(self):
"""This is called in case this is a no-op slicing node"""
return self.base.result()
def generate_result_code(self, code):
if self.is_ellipsis_noop:
return ### FIXME: remove
buffer_entry = self.buffer_entry() buffer_entry = self.buffer_entry()
have_gil = not self.in_nogil_context have_gil = not self.in_nogil_context
# TODO Mark: this is insane, do it better
have_slices = False have_slices = False
it = iter(self.indices) it = iter(self.indices)
for index in self.original_indices: for index in self.original_indices:
is_slice = isinstance(index, SliceNode) if index.is_slice:
have_slices = have_slices or is_slice have_slices = True
if is_slice:
if not index.start.is_none: if not index.start.is_none:
index.start = next(it) index.start = next(it)
if not index.stop.is_none: if not index.stop.is_none:
...@@ -3960,21 +3996,123 @@ class IndexNode(ExprNode): ...@@ -3960,21 +3996,123 @@ class IndexNode(ExprNode):
assert not list(it) assert not list(it)
buffer_entry.generate_buffer_slice_code(code, self.original_indices, buffer_entry.generate_buffer_slice_code(
self.result(), code, self.original_indices, self.result(),
have_gil=have_gil, have_gil=have_gil, have_slices=have_slices,
have_slices=have_slices, directives=code.globalstate.directives)
directives=code.globalstate.directives)
def generate_assignment_code(self, rhs, code, overloaded_assignment=False):
if self.is_ellipsis_noop:
self.generate_subexpr_evaluation_code(code)
else:
self.generate_evaluation_code(code)
if self.is_memview_scalar_assignment:
self.generate_memoryviewslice_assign_scalar_code(rhs, code)
else:
self.generate_memoryviewslice_setslice_code(rhs, code)
if self.is_ellipsis_noop:
self.generate_subexpr_disposal_code(code)
else:
self.generate_disposal_code(code)
rhs.generate_disposal_code(code)
rhs.free_temps(code)
class MemoryCopyNode(ExprNode):
"""
Wraps a memoryview slice for slice assignment.
dst: destination mememoryview slice
"""
subexprs = ['dst']
def __init__(self, pos, dst):
super(MemoryCopyNode, self).__init__(pos)
self.dst = dst
self.type = dst.type
def generate_assignment_code(self, rhs, code, overloaded_assignment=False):
self.dst.generate_evaluation_code(code)
self._generate_assignment_code(rhs, code)
self.dst.generate_disposal_code(code)
rhs.generate_disposal_code(code)
rhs.free_temps(code)
def generate_memoryviewslice_setslice_code(self, rhs, code):
"memslice1[...] = memslice2 or memslice1[:] = memslice2"
from . import MemoryView
MemoryView.copy_broadcast_memview_src_to_dst(rhs, self, code)
def generate_memoryviewslice_assign_scalar_code(self, rhs, code): class MemoryCopySlice(MemoryCopyNode):
"memslice1[...] = 0.0 or memslice1[:] = 0.0" """
Copy the contents of slice src to slice dst. Does not support indirect
slices.
memslice1[...] = memslice2
memslice1[:] = memslice2
"""
is_memview_copy_assignment = True
copy_slice_cname = "__pyx_memoryview_copy_contents"
def _generate_assignment_code(self, src, code):
dst = self.dst
src.type.assert_direct_dims(src.pos)
dst.type.assert_direct_dims(dst.pos)
code.putln(code.error_goto_if_neg(
"%s(%s, %s, %d, %d, %d)" % (self.copy_slice_cname,
src.result(), dst.result(),
src.type.ndim, dst.type.ndim,
dst.type.dtype.is_pyobject),
dst.pos))
class MemoryCopyScalar(MemoryCopyNode):
"""
Assign a scalar to a slice. dst must be simple, scalar will be assigned
to a correct type and not just something assignable.
memslice1[...] = 0.0
memslice1[:] = 0.0
"""
def __init__(self, pos, dst):
super(MemoryCopyScalar, self).__init__(pos, dst)
self.type = dst.type.dtype
def _generate_assignment_code(self, scalar, code):
from . import MemoryView from . import MemoryView
MemoryView.assign_scalar(self, rhs, code)
self.dst.type.assert_direct_dims(self.dst.pos)
dtype = self.dst.type.dtype
type_decl = dtype.declaration_code("")
slice_decl = self.dst.type.declaration_code("")
code.begin_block()
code.putln("%s __pyx_temp_scalar = %s;" % (type_decl, scalar.result()))
if self.dst.result_in_temp() or self.dst.is_simple():
dst_temp = self.dst.result()
else:
code.putln("%s __pyx_temp_slice = %s;" % (slice_decl, self.dst.result()))
dst_temp = "__pyx_temp_slice"
slice_iter_obj = MemoryView.slice_iter(self.dst.type, dst_temp,
self.dst.type.ndim, code)
p = slice_iter_obj.start_loops()
if dtype.is_pyobject:
code.putln("Py_DECREF(*(PyObject **) %s);" % p)
code.putln("*((%s *) %s) = __pyx_temp_scalar;" % (type_decl, p))
if dtype.is_pyobject:
code.putln("Py_INCREF(__pyx_temp_scalar);")
slice_iter_obj.end_loops()
code.end_block()
class SliceIndexNode(ExprNode): class SliceIndexNode(ExprNode):
...@@ -4428,7 +4566,7 @@ class SliceNode(ExprNode): ...@@ -4428,7 +4566,7 @@ class SliceNode(ExprNode):
# step ExprNode # step ExprNode
subexprs = ['start', 'stop', 'step'] subexprs = ['start', 'stop', 'step']
is_slice = True
type = slice_type type = slice_type
is_temp = 1 is_temp = 1
...@@ -4710,8 +4848,7 @@ class SimpleCallNode(CallNode): ...@@ -4710,8 +4848,7 @@ class SimpleCallNode(CallNode):
return return
elif hasattr(self.function, 'entry'): elif hasattr(self.function, 'entry'):
overloaded_entry = self.function.entry overloaded_entry = self.function.entry
elif (isinstance(self.function, IndexNode) and elif self.function.is_subscript and self.function.is_fused_index:
self.function.is_fused_index):
overloaded_entry = self.function.type.entry overloaded_entry = self.function.type.entry
else: else:
overloaded_entry = None overloaded_entry = None
...@@ -6014,7 +6151,7 @@ class AttributeNode(ExprNode): ...@@ -6014,7 +6151,7 @@ class AttributeNode(ExprNode):
self.is_memslice_transpose = True self.is_memslice_transpose = True
self.is_temp = True self.is_temp = True
self.use_managed_ref = True self.use_managed_ref = True
self.type = self.obj.type self.type = self.obj.type.transpose(self.pos)
return return
else: else:
obj_type.declare_attribute(self.attribute, env, self.pos) obj_type.declare_attribute(self.attribute, env, self.pos)
...@@ -6099,13 +6236,9 @@ class AttributeNode(ExprNode): ...@@ -6099,13 +6236,9 @@ class AttributeNode(ExprNode):
self.obj = self.obj.as_none_safe_node(msg, 'PyExc_AttributeError', self.obj = self.obj.as_none_safe_node(msg, 'PyExc_AttributeError',
format_args=format_args) format_args=format_args)
def nogil_check(self, env): def nogil_check(self, env):
if self.is_py_attr: if self.is_py_attr:
self.gil_error() self.gil_error()
elif self.type.is_memoryviewslice:
from . import MemoryView
MemoryView.err_if_nogil_initialized_check(self.pos, env, 'attribute')
gil_message = "Accessing Python attribute" gil_message = "Accessing Python attribute"
...@@ -9246,7 +9379,7 @@ class AmpersandNode(CUnopNode): ...@@ -9246,7 +9379,7 @@ class AmpersandNode(CUnopNode):
if argtype.is_memoryviewslice: if argtype.is_memoryviewslice:
self.error("Cannot take address of memoryview slice") self.error("Cannot take address of memoryview slice")
else: else:
self.error("Taking address of non-lvalue") self.error("Taking address of non-lvalue (type %s)" % argtype)
return self return self
if argtype.is_pyobject: if argtype.is_pyobject:
self.error("Cannot take address of Python variable") self.error("Cannot take address of Python variable")
...@@ -9434,6 +9567,7 @@ ERR_STEPS = ("Strides may only be given to indicate contiguity. " ...@@ -9434,6 +9567,7 @@ ERR_STEPS = ("Strides may only be given to indicate contiguity. "
ERR_NOT_POINTER = "Can only create cython.array from pointer or array" ERR_NOT_POINTER = "Can only create cython.array from pointer or array"
ERR_BASE_TYPE = "Pointer base type does not match cython.array base type" ERR_BASE_TYPE = "Pointer base type does not match cython.array base type"
class CythonArrayNode(ExprNode): class CythonArrayNode(ExprNode):
""" """
Used when a pointer of base_type is cast to a memoryviewslice with that Used when a pointer of base_type is cast to a memoryviewslice with that
...@@ -9474,8 +9608,6 @@ class CythonArrayNode(ExprNode): ...@@ -9474,8 +9608,6 @@ class CythonArrayNode(ExprNode):
array_dtype = self.base_type_node.base_type_node.analyse(env) array_dtype = self.base_type_node.base_type_node.analyse(env)
axes = self.base_type_node.axes axes = self.base_type_node.axes
MemoryView.validate_memslice_dtype(self.pos, array_dtype)
self.type = error_type self.type = error_type
self.shapes = [] self.shapes = []
ndim = len(axes) ndim = len(axes)
...@@ -9564,6 +9696,7 @@ class CythonArrayNode(ExprNode): ...@@ -9564,6 +9696,7 @@ class CythonArrayNode(ExprNode):
axes[-1] = ('direct', 'contig') axes[-1] = ('direct', 'contig')
self.coercion_type = PyrexTypes.MemoryViewSliceType(array_dtype, axes) self.coercion_type = PyrexTypes.MemoryViewSliceType(array_dtype, axes)
self.coercion_type.validate_memslice_dtype(self.pos)
self.type = self.get_cython_array_type(env) self.type = self.get_cython_array_type(env)
MemoryView.use_cython_array_utility_code(env) MemoryView.use_cython_array_utility_code(env)
env.use_utility_code(MemoryView.typeinfo_to_format_code) env.use_utility_code(MemoryView.typeinfo_to_format_code)
...@@ -11639,6 +11772,7 @@ class CoercionNode(ExprNode): ...@@ -11639,6 +11772,7 @@ class CoercionNode(ExprNode):
code.annotate((file, line, col-1), AnnotationItem( code.annotate((file, line, col-1), AnnotationItem(
style='coerce', tag='coerce', text='[%s] to [%s]' % (self.arg.type, self.type))) style='coerce', tag='coerce', text='[%s] to [%s]' % (self.arg.type, self.type)))
class CoerceToMemViewSliceNode(CoercionNode): class CoerceToMemViewSliceNode(CoercionNode):
""" """
Coerce an object to a memoryview slice. This holds a new reference in Coerce an object to a memoryview slice. This holds a new reference in
......
...@@ -200,7 +200,7 @@ class FusedCFuncDefNode(StatListNode): ...@@ -200,7 +200,7 @@ class FusedCFuncDefNode(StatListNode):
if arg.type.is_fused: if arg.type.is_fused:
arg.type = arg.type.specialize(fused_to_specific) arg.type = arg.type.specialize(fused_to_specific)
if arg.type.is_memoryviewslice: if arg.type.is_memoryviewslice:
MemoryView.validate_memslice_dtype(arg.pos, arg.type.dtype) arg.type.validate_memslice_dtype(arg.pos)
def create_new_local_scope(self, node, env, f2s): def create_new_local_scope(self, node, env, f2s):
""" """
......
...@@ -21,15 +21,11 @@ CF_ERR = "Invalid axis specification for a C/Fortran contiguous array." ...@@ -21,15 +21,11 @@ CF_ERR = "Invalid axis specification for a C/Fortran contiguous array."
ERR_UNINITIALIZED = ("Cannot check if memoryview %s is initialized without the " ERR_UNINITIALIZED = ("Cannot check if memoryview %s is initialized without the "
"GIL, consider using initializedcheck(False)") "GIL, consider using initializedcheck(False)")
def err_if_nogil_initialized_check(pos, env, name='variable'):
"This raises an exception at runtime now"
pass
#if env.nogil and env.directives['initializedcheck']:
#error(pos, ERR_UNINITIALIZED % name)
def concat_flags(*flags): def concat_flags(*flags):
return "(%s)" % "|".join(flags) return "(%s)" % "|".join(flags)
format_flag = "PyBUF_FORMAT" format_flag = "PyBUF_FORMAT"
memview_c_contiguous = "(PyBUF_C_CONTIGUOUS | PyBUF_FORMAT | PyBUF_WRITABLE)" memview_c_contiguous = "(PyBUF_C_CONTIGUOUS | PyBUF_FORMAT | PyBUF_WRITABLE)"
...@@ -71,18 +67,16 @@ memview_typeptr_cname = '__pyx_memoryview_type' ...@@ -71,18 +67,16 @@ memview_typeptr_cname = '__pyx_memoryview_type'
memview_objstruct_cname = '__pyx_memoryview_obj' memview_objstruct_cname = '__pyx_memoryview_obj'
memviewslice_cname = u'__Pyx_memviewslice' memviewslice_cname = u'__Pyx_memviewslice'
def put_init_entry(mv_cname, code): def put_init_entry(mv_cname, code):
code.putln("%s.data = NULL;" % mv_cname) code.putln("%s.data = NULL;" % mv_cname)
code.putln("%s.memview = NULL;" % mv_cname) code.putln("%s.memview = NULL;" % mv_cname)
def mangle_dtype_name(dtype):
# a dumb wrapper for now; move Buffer.mangle_dtype_name in here later?
from . import Buffer
return Buffer.mangle_dtype_name(dtype)
#def axes_to_str(axes): #def axes_to_str(axes):
# return "".join([access[0].upper()+packing[0] for (access, packing) in axes]) # return "".join([access[0].upper()+packing[0] for (access, packing) in axes])
def put_acquire_memoryviewslice(lhs_cname, lhs_type, lhs_pos, rhs, code, def put_acquire_memoryviewslice(lhs_cname, lhs_type, lhs_pos, rhs, code,
have_gil=False, first_assignment=True): have_gil=False, first_assignment=True):
"We can avoid decreffing the lhs if we know it is the first assignment" "We can avoid decreffing the lhs if we know it is the first assignment"
...@@ -103,6 +97,7 @@ def put_acquire_memoryviewslice(lhs_cname, lhs_type, lhs_pos, rhs, code, ...@@ -103,6 +97,7 @@ def put_acquire_memoryviewslice(lhs_cname, lhs_type, lhs_pos, rhs, code,
if not pretty_rhs: if not pretty_rhs:
code.funcstate.release_temp(rhstmp) code.funcstate.release_temp(rhstmp)
def put_assign_to_memviewslice(lhs_cname, rhs, rhs_cname, memviewslicetype, code, def put_assign_to_memviewslice(lhs_cname, rhs, rhs_cname, memviewslicetype, code,
have_gil=False, first_assignment=False): have_gil=False, first_assignment=False):
if not first_assignment: if not first_assignment:
...@@ -113,6 +108,7 @@ def put_assign_to_memviewslice(lhs_cname, rhs, rhs_cname, memviewslicetype, code ...@@ -113,6 +108,7 @@ def put_assign_to_memviewslice(lhs_cname, rhs, rhs_cname, memviewslicetype, code
code.putln("%s = %s;" % (lhs_cname, rhs_cname)) code.putln("%s = %s;" % (lhs_cname, rhs_cname))
def get_buf_flags(specs): def get_buf_flags(specs):
is_c_contig, is_f_contig = is_cf_contig(specs) is_c_contig, is_f_contig = is_cf_contig(specs)
...@@ -128,11 +124,13 @@ def get_buf_flags(specs): ...@@ -128,11 +124,13 @@ def get_buf_flags(specs):
else: else:
return memview_strided_access return memview_strided_access
def insert_newaxes(memoryviewtype, n): def insert_newaxes(memoryviewtype, n):
axes = [('direct', 'strided')] * n axes = [('direct', 'strided')] * n
axes.extend(memoryviewtype.axes) axes.extend(memoryviewtype.axes)
return PyrexTypes.MemoryViewSliceType(memoryviewtype.dtype, axes) return PyrexTypes.MemoryViewSliceType(memoryviewtype.dtype, axes)
def broadcast_types(src, dst): def broadcast_types(src, dst):
n = abs(src.ndim - dst.ndim) n = abs(src.ndim - dst.ndim)
if src.ndim < dst.ndim: if src.ndim < dst.ndim:
...@@ -140,37 +138,6 @@ def broadcast_types(src, dst): ...@@ -140,37 +138,6 @@ def broadcast_types(src, dst):
else: else:
return src, insert_newaxes(dst, n) return src, insert_newaxes(dst, n)
def src_conforms_to_dst(src, dst, broadcast=False):
'''
returns True if src conforms to dst, False otherwise.
If conformable, the types are the same, the ndims are equal, and each axis spec is conformable.
Any packing/access spec is conformable to itself.
'direct' and 'ptr' are conformable to 'full'.
'contig' and 'follow' are conformable to 'strided'.
Any other combo is not conformable.
'''
if src.dtype != dst.dtype:
return False
if src.ndim != dst.ndim:
if broadcast:
src, dst = broadcast_types(src, dst)
else:
return False
for src_spec, dst_spec in zip(src.axes, dst.axes):
src_access, src_packing = src_spec
dst_access, dst_packing = dst_spec
if src_access != dst_access and dst_access != 'full':
return False
if src_packing != dst_packing and dst_packing != 'strided':
return False
return True
def valid_memslice_dtype(dtype, i=0): def valid_memslice_dtype(dtype, i=0):
""" """
...@@ -204,22 +171,22 @@ def valid_memslice_dtype(dtype, i=0): ...@@ -204,22 +171,22 @@ def valid_memslice_dtype(dtype, i=0):
(dtype.is_typedef and valid_memslice_dtype(dtype.typedef_base_type)) (dtype.is_typedef and valid_memslice_dtype(dtype.typedef_base_type))
) )
def validate_memslice_dtype(pos, dtype):
if not valid_memslice_dtype(dtype):
error(pos, "Invalid base type for memoryview slice: %s" % dtype)
class MemoryViewSliceBufferEntry(Buffer.BufferEntry): class MemoryViewSliceBufferEntry(Buffer.BufferEntry):
"""
May be used during code generation time to be queried for
shape/strides/suboffsets attributes, or to perform indexing or slicing.
"""
def __init__(self, entry): def __init__(self, entry):
self.entry = entry self.entry = entry
self.type = entry.type self.type = entry.type
self.cname = entry.cname self.cname = entry.cname
self.buf_ptr = "%s.data" % self.cname self.buf_ptr = "%s.data" % self.cname
dtype = self.entry.type.dtype dtype = self.entry.type.dtype
dtype = PyrexTypes.CPtrType(dtype) self.buf_ptr_type = PyrexTypes.CPtrType(dtype)
self.init_attributes()
self.buf_ptr_type = dtype
def get_buf_suboffsetvars(self): def get_buf_suboffsetvars(self):
return self._for_all_ndim("%s.suboffsets[%d]") return self._for_all_ndim("%s.suboffsets[%d]")
...@@ -236,6 +203,10 @@ class MemoryViewSliceBufferEntry(Buffer.BufferEntry): ...@@ -236,6 +203,10 @@ class MemoryViewSliceBufferEntry(Buffer.BufferEntry):
return self._generate_buffer_lookup_code(code, axes) return self._generate_buffer_lookup_code(code, axes)
def _generate_buffer_lookup_code(self, code, axes, cast_result=True): def _generate_buffer_lookup_code(self, code, axes, cast_result=True):
"""
Generate a single expression that indexes the memory view slice
in each dimension.
"""
bufp = self.buf_ptr bufp = self.buf_ptr
type_decl = self.type.dtype.empty_declaration_code() type_decl = self.type.dtype.empty_declaration_code()
...@@ -286,7 +257,9 @@ class MemoryViewSliceBufferEntry(Buffer.BufferEntry): ...@@ -286,7 +257,9 @@ class MemoryViewSliceBufferEntry(Buffer.BufferEntry):
then it must be coercible to Py_ssize_t then it must be coercible to Py_ssize_t
Simply call __pyx_memoryview_slice_memviewslice with the right Simply call __pyx_memoryview_slice_memviewslice with the right
arguments. arguments, unless the dimension is omitted or a bare ':', in which
case we copy over the shape/strides/suboffsets attributes directly
for that dimension.
""" """
src = self.cname src = self.cname
...@@ -368,11 +341,13 @@ def empty_slice(pos): ...@@ -368,11 +341,13 @@ def empty_slice(pos):
return ExprNodes.SliceNode(pos, start=none, return ExprNodes.SliceNode(pos, start=none,
stop=none, step=none) stop=none, step=none)
def unellipsify(indices, newaxes, ndim):
def unellipsify(indices, ndim):
result = [] result = []
seen_ellipsis = False seen_ellipsis = False
have_slices = False have_slices = False
newaxes = [newaxis for newaxis in indices if newaxis.is_none]
n_indices = len(indices) - len(newaxes) n_indices = len(indices) - len(newaxes)
for index in indices: for index in indices:
...@@ -387,9 +362,7 @@ def unellipsify(indices, newaxes, ndim): ...@@ -387,9 +362,7 @@ def unellipsify(indices, newaxes, ndim):
result.extend([full_slice] * nslices) result.extend([full_slice] * nslices)
seen_ellipsis = True seen_ellipsis = True
else: else:
have_slices = (have_slices or have_slices = have_slices or index.is_slice or index.is_none
isinstance(index, ExprNodes.SliceNode) or
index.is_none)
result.append(index) result.append(index)
result_length = len(result) - len(newaxes) result_length = len(result) - len(newaxes)
...@@ -398,7 +371,8 @@ def unellipsify(indices, newaxes, ndim): ...@@ -398,7 +371,8 @@ def unellipsify(indices, newaxes, ndim):
nslices = ndim - result_length nslices = ndim - result_length
result.extend([empty_slice(indices[-1].pos)] * nslices) result.extend([empty_slice(indices[-1].pos)] * nslices)
return have_slices, result return have_slices, result, newaxes
def get_memoryview_flag(access, packing): def get_memoryview_flag(access, packing):
if access == 'full' and packing in ('strided', 'follow'): if access == 'full' and packing in ('strided', 'follow'):
...@@ -415,9 +389,11 @@ def get_memoryview_flag(access, packing): ...@@ -415,9 +389,11 @@ def get_memoryview_flag(access, packing):
assert (access, packing) == ('direct', 'contig'), (access, packing) assert (access, packing) == ('direct', 'contig'), (access, packing)
return 'contiguous' return 'contiguous'
def get_is_contig_func_name(c_or_f, ndim): def get_is_contig_func_name(c_or_f, ndim):
return "__pyx_memviewslice_is_%s_contig%d" % (c_or_f, ndim) return "__pyx_memviewslice_is_%s_contig%d" % (c_or_f, ndim)
def get_is_contig_utility(c_contig, ndim): def get_is_contig_utility(c_contig, ndim):
C = dict(context, ndim=ndim) C = dict(context, ndim=ndim)
if c_contig: if c_contig:
...@@ -430,88 +406,21 @@ def get_is_contig_utility(c_contig, ndim): ...@@ -430,88 +406,21 @@ def get_is_contig_utility(c_contig, ndim):
return utility return utility
def copy_src_to_dst_cname(): def slice_iter(slice_type, slice_result, ndim, code):
return "__pyx_memoryview_copy_contents"
def verify_direct_dimensions(node):
for access, packing in node.type.axes:
if access != 'direct':
error(node.pos, "All dimensions must be direct")
def copy_broadcast_memview_src_to_dst(src, dst, code):
"""
Copy the contents of slice src to slice dst. Does not support indirect
slices.
"""
verify_direct_dimensions(src)
verify_direct_dimensions(dst)
code.putln(code.error_goto_if_neg(
"%s(%s, %s, %d, %d, %d)" % (copy_src_to_dst_cname(),
src.result(), dst.result(),
src.type.ndim, dst.type.ndim,
dst.type.dtype.is_pyobject),
dst.pos))
def get_1d_fill_scalar_func(type, code):
dtype = type.dtype
type_decl = dtype.empty_declaration_code()
dtype_name = mangle_dtype_name(dtype)
context = dict(dtype_name=dtype_name, type_decl=type_decl)
utility = load_memview_c_utility("FillStrided1DScalar", context)
code.globalstate.use_utility_code(utility)
return '__pyx_fill_slice_%s' % dtype_name
def assign_scalar(dst, scalar, code):
"""
Assign a scalar to a slice. dst must be a temp, scalar will be assigned
to a correct type and not just something assignable.
"""
verify_direct_dimensions(dst)
dtype = dst.type.dtype
type_decl = dtype.empty_declaration_code()
slice_decl = dst.type.empty_declaration_code()
code.begin_block()
code.putln("%s __pyx_temp_scalar = %s;" % (type_decl, scalar.result()))
if dst.result_in_temp() or (dst.base.is_name and
isinstance(dst.index, ExprNodes.EllipsisNode)):
dst_temp = dst.result()
else:
code.putln("%s __pyx_temp_slice = %s;" % (slice_decl, dst.result()))
dst_temp = "__pyx_temp_slice"
# with slice_iter(dst.type, dst_temp, dst.type.ndim, code) as p:
slice_iter_obj = slice_iter(dst.type, dst_temp, dst.type.ndim, code)
p = slice_iter_obj.start_loops()
if dtype.is_pyobject:
code.putln("Py_DECREF(*(PyObject **) %s);" % p)
code.putln("*((%s *) %s) = __pyx_temp_scalar;" % (type_decl, p))
if dtype.is_pyobject:
code.putln("Py_INCREF(__pyx_temp_scalar);")
slice_iter_obj.end_loops()
code.end_block()
def slice_iter(slice_type, slice_temp, ndim, code):
if slice_type.is_c_contig or slice_type.is_f_contig: if slice_type.is_c_contig or slice_type.is_f_contig:
return ContigSliceIter(slice_type, slice_temp, ndim, code) return ContigSliceIter(slice_type, slice_result, ndim, code)
else: else:
return StridedSliceIter(slice_type, slice_temp, ndim, code) return StridedSliceIter(slice_type, slice_result, ndim, code)
class SliceIter(object): class SliceIter(object):
def __init__(self, slice_type, slice_temp, ndim, code): def __init__(self, slice_type, slice_result, ndim, code):
self.slice_type = slice_type self.slice_type = slice_type
self.slice_temp = slice_temp self.slice_result = slice_result
self.code = code self.code = code
self.ndim = ndim self.ndim = ndim
class ContigSliceIter(SliceIter): class ContigSliceIter(SliceIter):
def start_loops(self): def start_loops(self):
code = self.code code = self.code
...@@ -519,12 +428,12 @@ class ContigSliceIter(SliceIter): ...@@ -519,12 +428,12 @@ class ContigSliceIter(SliceIter):
type_decl = self.slice_type.dtype.empty_declaration_code() type_decl = self.slice_type.dtype.empty_declaration_code()
total_size = ' * '.join("%s.shape[%d]" % (self.slice_temp, i) total_size = ' * '.join("%s.shape[%d]" % (self.slice_result, i)
for i in range(self.ndim)) for i in range(self.ndim))
code.putln("Py_ssize_t __pyx_temp_extent = %s;" % total_size) code.putln("Py_ssize_t __pyx_temp_extent = %s;" % total_size)
code.putln("Py_ssize_t __pyx_temp_idx;") code.putln("Py_ssize_t __pyx_temp_idx;")
code.putln("%s *__pyx_temp_pointer = (%s *) %s.data;" % ( code.putln("%s *__pyx_temp_pointer = (%s *) %s.data;" % (
type_decl, type_decl, self.slice_temp)) type_decl, type_decl, self.slice_result))
code.putln("for (__pyx_temp_idx = 0; " code.putln("for (__pyx_temp_idx = 0; "
"__pyx_temp_idx < __pyx_temp_extent; " "__pyx_temp_idx < __pyx_temp_extent; "
"__pyx_temp_idx++) {") "__pyx_temp_idx++) {")
...@@ -536,19 +445,20 @@ class ContigSliceIter(SliceIter): ...@@ -536,19 +445,20 @@ class ContigSliceIter(SliceIter):
self.code.putln("}") self.code.putln("}")
self.code.end_block() self.code.end_block()
class StridedSliceIter(SliceIter): class StridedSliceIter(SliceIter):
def start_loops(self): def start_loops(self):
code = self.code code = self.code
code.begin_block() code.begin_block()
for i in range(self.ndim): for i in range(self.ndim):
t = i, self.slice_temp, i t = i, self.slice_result, i
code.putln("Py_ssize_t __pyx_temp_extent_%d = %s.shape[%d];" % t) code.putln("Py_ssize_t __pyx_temp_extent_%d = %s.shape[%d];" % t)
code.putln("Py_ssize_t __pyx_temp_stride_%d = %s.strides[%d];" % t) code.putln("Py_ssize_t __pyx_temp_stride_%d = %s.strides[%d];" % t)
code.putln("char *__pyx_temp_pointer_%d;" % i) code.putln("char *__pyx_temp_pointer_%d;" % i)
code.putln("Py_ssize_t __pyx_temp_idx_%d;" % i) code.putln("Py_ssize_t __pyx_temp_idx_%d;" % i)
code.putln("__pyx_temp_pointer_0 = %s.data;" % self.slice_temp) code.putln("__pyx_temp_pointer_0 = %s.data;" % self.slice_result)
for i in range(self.ndim): for i in range(self.ndim):
if i > 0: if i > 0:
......
...@@ -1054,8 +1054,8 @@ class MemoryViewSliceTypeNode(CBaseTypeNode): ...@@ -1054,8 +1054,8 @@ class MemoryViewSliceTypeNode(CBaseTypeNode):
if not MemoryView.validate_axes(self.pos, axes_specs): if not MemoryView.validate_axes(self.pos, axes_specs):
self.type = error_type self.type = error_type
else: else:
MemoryView.validate_memslice_dtype(self.pos, base_type)
self.type = PyrexTypes.MemoryViewSliceType(base_type, axes_specs) self.type = PyrexTypes.MemoryViewSliceType(base_type, axes_specs)
self.type.validate_memslice_dtype(self.pos)
self.use_memview_utilities(env) self.use_memview_utilities(env)
return self.type return self.type
...@@ -4896,26 +4896,14 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -4896,26 +4896,14 @@ class SingleAssignmentNode(AssignmentNode):
if unrolled_assignment: if unrolled_assignment:
return unrolled_assignment return unrolled_assignment
if self.lhs.memslice_broadcast or self.rhs.memslice_broadcast: if isinstance(self.lhs, ExprNodes.MemoryViewIndexNode):
self.lhs.memslice_broadcast = True self.lhs.analyse_broadcast_operation(self.rhs)
self.rhs.memslice_broadcast = True self.lhs = self.lhs.analyse_as_memview_scalar_assignment(self.rhs)
if (self.lhs.is_subscript and not self.rhs.type.is_memoryviewslice and
(self.lhs.memslice_slice or self.lhs.is_memslice_copy) and
(self.lhs.type.dtype.assignable_from(self.rhs.type) or
self.rhs.type.is_pyobject)):
# scalar slice assignment
self.lhs.is_memslice_scalar_assignment = True
dtype = self.lhs.type.dtype
elif self.lhs.type.is_array: elif self.lhs.type.is_array:
if not isinstance(self.lhs, ExprNodes.SliceIndexNode): if not isinstance(self.lhs, ExprNodes.SliceIndexNode):
# cannot assign to C array, only to its full slice # cannot assign to C array, only to its full slice
self.lhs = ExprNodes.SliceIndexNode( self.lhs = ExprNodes.SliceIndexNode(self.lhs.pos, base=self.lhs, start=None, stop=None)
self.lhs.pos, base=self.lhs, start=None, stop=None)
self.lhs = self.lhs.analyse_target_types(env) self.lhs = self.lhs.analyse_target_types(env)
dtype = self.lhs.type
else:
dtype = self.lhs.type
if self.lhs.type.is_cpp_class: if self.lhs.type.is_cpp_class:
op = env.lookup_operator_for_types(self.pos, '=', [self.lhs.type, self.rhs.type]) op = env.lookup_operator_for_types(self.pos, '=', [self.lhs.type, self.rhs.type])
...@@ -4923,9 +4911,10 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -4923,9 +4911,10 @@ class SingleAssignmentNode(AssignmentNode):
rhs = self.rhs rhs = self.rhs
self.is_overloaded_assignment = True self.is_overloaded_assignment = True
else: else:
rhs = self.rhs.coerce_to(dtype, env) rhs = self.rhs.coerce_to(self.lhs.type, env)
else: else:
rhs = self.rhs.coerce_to(dtype, env) rhs = self.rhs.coerce_to(self.lhs.type, env)
if use_temp or rhs.is_attribute or ( if use_temp or rhs.is_attribute or (
not rhs.is_name and not rhs.is_literal and not rhs.is_name and not rhs.is_literal and
rhs.type.is_pyobject): rhs.type.is_pyobject):
...@@ -5035,12 +5024,12 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -5035,12 +5024,12 @@ class SingleAssignmentNode(AssignmentNode):
assignments = [] assignments = []
for lhs, rhs in zip(lhs_list, rhs_list): for lhs, rhs in zip(lhs_list, rhs_list):
assignments.append(SingleAssignmentNode(self.pos, lhs=lhs, rhs=rhs, first=self.first)) assignments.append(SingleAssignmentNode(self.pos, lhs=lhs, rhs=rhs, first=self.first))
all = ParallelAssignmentNode(pos=self.pos, stats=assignments).analyse_expressions(env) node = ParallelAssignmentNode(pos=self.pos, stats=assignments).analyse_expressions(env)
if check_node: if check_node:
all = StatListNode(pos=self.pos, stats=[check_node, all]) node = StatListNode(pos=self.pos, stats=[check_node, node])
for ref in refs[::-1]: for ref in refs[::-1]:
all = UtilNodes.LetNode(ref, all) node = UtilNodes.LetNode(ref, node)
return all return node
def unroll_rhs(self, env): def unroll_rhs(self, env):
from . import ExprNodes from . import ExprNodes
...@@ -5059,7 +5048,7 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -5059,7 +5048,7 @@ class SingleAssignmentNode(AssignmentNode):
if self.lhs.type.is_ctuple: if self.lhs.type.is_ctuple:
# Handled directly. # Handled directly.
return return
from . import ExprNodes, UtilNodes from . import ExprNodes
if not isinstance(self.rhs, ExprNodes.TupleNode): if not isinstance(self.rhs, ExprNodes.TupleNode):
return return
...@@ -5261,8 +5250,7 @@ class InPlaceAssignmentNode(AssignmentNode): ...@@ -5261,8 +5250,7 @@ class InPlaceAssignmentNode(AssignmentNode):
self.lhs = self.lhs.analyse_target_types(env) self.lhs = self.lhs.analyse_target_types(env)
# When assigning to a fully indexed buffer or memoryview, coerce the rhs # When assigning to a fully indexed buffer or memoryview, coerce the rhs
if (self.lhs.is_subscript and if self.lhs.is_memview_index or self.lhs.is_buffer_access:
(self.lhs.memslice_index or self.lhs.is_buffer_access)):
self.rhs = self.rhs.coerce_to(self.lhs.type, env) self.rhs = self.rhs.coerce_to(self.lhs.type, env)
elif self.lhs.type.is_string and self.operator in '+-': elif self.lhs.type.is_string and self.operator in '+-':
# use pointer arithmetic for char* LHS instead of string concat # use pointer arithmetic for char* LHS instead of string concat
...@@ -5271,28 +5259,30 @@ class InPlaceAssignmentNode(AssignmentNode): ...@@ -5271,28 +5259,30 @@ class InPlaceAssignmentNode(AssignmentNode):
def generate_execution_code(self, code): def generate_execution_code(self, code):
code.mark_pos(self.pos) code.mark_pos(self.pos)
self.rhs.generate_evaluation_code(code) lhs, rhs = self.lhs, self.rhs
self.lhs.generate_subexpr_evaluation_code(code) rhs.generate_evaluation_code(code)
lhs.generate_subexpr_evaluation_code(code)
c_op = self.operator c_op = self.operator
if c_op == "//": if c_op == "//":
c_op = "/" c_op = "/"
elif c_op == "**": elif c_op == "**":
error(self.pos, "No C inplace power operator") error(self.pos, "No C inplace power operator")
if self.lhs.is_subscript and self.lhs.is_buffer_access: if lhs.is_buffer_access or lhs.is_memview_index:
if self.lhs.type.is_pyobject: if lhs.type.is_pyobject:
error(self.pos, "In-place operators not allowed on object buffers in this release.") error(self.pos, "In-place operators not allowed on object buffers in this release.")
if (c_op in ('/', '%') and self.lhs.type.is_int if c_op in ('/', '%') and lhs.type.is_int and not code.globalstate.directives['cdivision']:
and not code.globalstate.directives['cdivision']):
error(self.pos, "In-place non-c divide operators not allowed on int buffers.") error(self.pos, "In-place non-c divide operators not allowed on int buffers.")
self.lhs.generate_buffer_setitem_code(self.rhs, code, c_op) lhs.generate_buffer_setitem_code(rhs, code, c_op)
elif lhs.is_memview_slice:
error(self.pos, "Inplace operators not supported on memoryview slices")
else: else:
# C++ # C++
# TODO: make sure overload is declared # TODO: make sure overload is declared
code.putln("%s %s= %s;" % (self.lhs.result(), c_op, self.rhs.result())) code.putln("%s %s= %s;" % (lhs.result(), c_op, rhs.result()))
self.lhs.generate_subexpr_disposal_code(code) lhs.generate_subexpr_disposal_code(code)
self.lhs.free_subexpr_temps(code) lhs.free_subexpr_temps(code)
self.rhs.generate_disposal_code(code) rhs.generate_disposal_code(code)
self.rhs.free_temps(code) rhs.free_temps(code)
def annotate(self, code): def annotate(self, code):
self.lhs.annotate(code) self.lhs.annotate(code)
...@@ -6344,8 +6334,8 @@ class ForFromStatNode(LoopNode, StatNode): ...@@ -6344,8 +6334,8 @@ class ForFromStatNode(LoopNode, StatNode):
"for-from loop variable must be c numeric type or Python object") "for-from loop variable must be c numeric type or Python object")
if target_type.is_numeric: if target_type.is_numeric:
self.is_py_target = False self.is_py_target = False
if isinstance(self.target, ExprNodes.IndexNode) and self.target.is_buffer_access: if isinstance(self.target, ExprNodes.BufferIndexNode):
raise error(self.pos, "Buffer indexing not allowed as for loop target.") raise error(self.pos, "Buffer or memoryview slicing/indexing not allowed as for-loop target.")
self.loopvar_node = self.target self.loopvar_node = self.target
self.py_loopvar_node = None self.py_loopvar_node = None
else: else:
......
...@@ -132,7 +132,7 @@ class IterationTransform(Visitor.EnvTransform): ...@@ -132,7 +132,7 @@ class IterationTransform(Visitor.EnvTransform):
pos = node.pos pos = node.pos
result_ref = UtilNodes.ResultRefNode(node) result_ref = UtilNodes.ResultRefNode(node)
if isinstance(node.operand2, ExprNodes.IndexNode): if node.operand2.is_subscript:
base_type = node.operand2.base.type.base_type base_type = node.operand2.base.type.base_type
else: else:
base_type = node.operand2.type.base_type base_type = node.operand2.type.base_type
...@@ -442,7 +442,7 @@ class IterationTransform(Visitor.EnvTransform): ...@@ -442,7 +442,7 @@ class IterationTransform(Visitor.EnvTransform):
error(slice_node.pos, "C array iteration requires known end index") error(slice_node.pos, "C array iteration requires known end index")
return node return node
elif isinstance(slice_node, ExprNodes.IndexNode): elif slice_node.is_subscript:
assert isinstance(slice_node.index, ExprNodes.SliceNode) assert isinstance(slice_node.index, ExprNodes.SliceNode)
slice_base = slice_node.base slice_base = slice_node.base
index = slice_node.index index = slice_node.index
...@@ -564,7 +564,6 @@ class IterationTransform(Visitor.EnvTransform): ...@@ -564,7 +564,6 @@ class IterationTransform(Visitor.EnvTransform):
constant_result=0, constant_result=0,
type=PyrexTypes.c_int_type), type=PyrexTypes.c_int_type),
base=counter_temp, base=counter_temp,
is_buffer_access=False,
type=ptr_type.base_type) type=ptr_type.base_type)
if target_value.type != node.target.type: if target_value.type != node.target.type:
...@@ -1334,20 +1333,20 @@ class DropRefcountingTransform(Visitor.VisitorTransform): ...@@ -1334,20 +1333,20 @@ class DropRefcountingTransform(Visitor.VisitorTransform):
node = node.arg node = node.arg
name_path = [] name_path = []
obj_node = node obj_node = node
while isinstance(obj_node, ExprNodes.AttributeNode): while obj_node.is_attribute:
if obj_node.is_py_attr: if obj_node.is_py_attr:
return False return False
name_path.append(obj_node.member) name_path.append(obj_node.member)
obj_node = obj_node.obj obj_node = obj_node.obj
if isinstance(obj_node, ExprNodes.NameNode): if obj_node.is_name:
name_path.append(obj_node.name) name_path.append(obj_node.name)
names.append( ('.'.join(name_path[::-1]), node) ) names.append( ('.'.join(name_path[::-1]), node) )
elif isinstance(node, ExprNodes.IndexNode): elif node.is_subscript:
if node.base.type != Builtin.list_type: if node.base.type != Builtin.list_type:
return False return False
if not node.index.type.is_int: if not node.index.type.is_int:
return False return False
if not isinstance(node.base, ExprNodes.NameNode): if not node.base.is_name:
return False return False
indices.append(node) indices.append(node)
else: else:
...@@ -1979,7 +1978,7 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, ...@@ -1979,7 +1978,7 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
elif isinstance(arg, ExprNodes.SimpleCallNode): elif isinstance(arg, ExprNodes.SimpleCallNode):
if node.type.is_int or node.type.is_float: if node.type.is_int or node.type.is_float:
return self._optimise_numeric_cast_call(node, arg) return self._optimise_numeric_cast_call(node, arg)
elif isinstance(arg, ExprNodes.IndexNode) and not arg.is_buffer_access: elif arg.is_subscript:
index_node = arg.index index_node = arg.index
if isinstance(index_node, ExprNodes.CoerceToPyTypeNode): if isinstance(index_node, ExprNodes.CoerceToPyTypeNode):
index_node = index_node.arg index_node = index_node.arg
......
...@@ -17,7 +17,7 @@ from . import Builtin ...@@ -17,7 +17,7 @@ from . import Builtin
from .Visitor import VisitorTransform, TreeVisitor from .Visitor import VisitorTransform, TreeVisitor
from .Visitor import CythonTransform, EnvTransform, ScopeTrackingTransform from .Visitor import CythonTransform, EnvTransform, ScopeTrackingTransform
from .UtilNodes import LetNode, LetRefNode, ResultRefNode from .UtilNodes import LetNode, LetRefNode
from .TreeFragment import TreeFragment from .TreeFragment import TreeFragment
from .StringEncoding import EncodedString, _unicode from .StringEncoding import EncodedString, _unicode
from .Errors import error, warning, CompileError, InternalError from .Errors import error, warning, CompileError, InternalError
...@@ -1931,13 +1931,8 @@ class AnalyseExpressionsTransform(CythonTransform): ...@@ -1931,13 +1931,8 @@ class AnalyseExpressionsTransform(CythonTransform):
re-analyse the types. re-analyse the types.
""" """
self.visit_Node(node) self.visit_Node(node)
if node.is_fused_index and not node.type.is_error: if node.is_fused_index and not node.type.is_error:
node = node.base node = node.base
elif node.memslice_ellipsis_noop:
# memoryviewslice[...] expression, drop the IndexNode
node = node.base
return node return node
...@@ -1971,26 +1966,26 @@ class ExpandInplaceOperators(EnvTransform): ...@@ -1971,26 +1966,26 @@ class ExpandInplaceOperators(EnvTransform):
if lhs.type.is_cpp_class: if lhs.type.is_cpp_class:
# No getting around this exact operator here. # No getting around this exact operator here.
return node return node
if isinstance(lhs, ExprNodes.IndexNode) and lhs.is_buffer_access: if isinstance(lhs, ExprNodes.BufferIndexNode):
# There is code to handle this case. # There is code to handle this case in InPlaceAssignmentNode
return node return node
env = self.current_env() env = self.current_env()
def side_effect_free_reference(node, setting=False): def side_effect_free_reference(node, setting=False):
if isinstance(node, ExprNodes.NameNode): if node.is_name:
return node, [] return node, []
elif node.type.is_pyobject and not setting: elif node.type.is_pyobject and not setting:
node = LetRefNode(node) node = LetRefNode(node)
return node, [node] return node, [node]
elif isinstance(node, ExprNodes.IndexNode): elif node.is_subscript:
if node.is_buffer_access:
raise ValueError("Buffer access")
base, temps = side_effect_free_reference(node.base) base, temps = side_effect_free_reference(node.base)
index = LetRefNode(node.index) index = LetRefNode(node.index)
return ExprNodes.IndexNode(node.pos, base=base, index=index), temps + [index] return ExprNodes.IndexNode(node.pos, base=base, index=index), temps + [index]
elif isinstance(node, ExprNodes.AttributeNode): elif node.is_attribute:
obj, temps = side_effect_free_reference(node.obj) obj, temps = side_effect_free_reference(node.obj)
return ExprNodes.AttributeNode(node.pos, obj=obj, attribute=node.attribute), temps return ExprNodes.AttributeNode(node.pos, obj=obj, attribute=node.attribute), temps
elif isinstance(node, ExprNodes.BufferIndexNode):
raise ValueError("Don't allow things like attributes of buffer indexing operations")
else: else:
node = LetRefNode(node) node = LetRefNode(node)
return node, [node] return node, [node]
......
...@@ -541,7 +541,7 @@ class MemoryViewSliceType(PyrexType): ...@@ -541,7 +541,7 @@ class MemoryViewSliceType(PyrexType):
the *first* axis' packing spec and 'follow' for all other packing the *first* axis' packing spec and 'follow' for all other packing
specs. specs.
""" """
from . import MemoryView from . import Buffer, MemoryView
self.dtype = base_dtype self.dtype = base_dtype
self.axes = axes self.axes = axes
...@@ -555,7 +555,7 @@ class MemoryViewSliceType(PyrexType): ...@@ -555,7 +555,7 @@ class MemoryViewSliceType(PyrexType):
self.writable_needed = False self.writable_needed = False
if not self.dtype.is_fused: if not self.dtype.is_fused:
self.dtype_name = MemoryView.mangle_dtype_name(self.dtype) self.dtype_name = Buffer.mangle_dtype_name(self.dtype)
def __hash__(self): def __hash__(self):
return hash(self.__class__) ^ hash(self.dtype) ^ hash(tuple(self.axes)) return hash(self.__class__) ^ hash(self.dtype) ^ hash(tuple(self.axes))
...@@ -638,25 +638,28 @@ class MemoryViewSliceType(PyrexType): ...@@ -638,25 +638,28 @@ class MemoryViewSliceType(PyrexType):
elif attribute in ("copy", "copy_fortran"): elif attribute in ("copy", "copy_fortran"):
ndim = len(self.axes) ndim = len(self.axes)
to_axes_c = [('direct', 'contig')] follow_dim = [('direct', 'follow')]
to_axes_f = [('direct', 'contig')] contig_dim = [('direct', 'contig')]
if ndim - 1: to_axes_c = follow_dim * (ndim - 1) + contig_dim
to_axes_c = [('direct', 'follow')]*(ndim-1) + to_axes_c to_axes_f = contig_dim + follow_dim * (ndim -1)
to_axes_f = to_axes_f + [('direct', 'follow')]*(ndim-1)
to_memview_c = MemoryViewSliceType(self.dtype, to_axes_c) to_memview_c = MemoryViewSliceType(self.dtype, to_axes_c)
to_memview_f = MemoryViewSliceType(self.dtype, to_axes_f) to_memview_f = MemoryViewSliceType(self.dtype, to_axes_f)
for to_memview, cython_name in [(to_memview_c, "copy"), for to_memview, cython_name in [(to_memview_c, "copy"),
(to_memview_f, "copy_fortran")]: (to_memview_f, "copy_fortran")]:
entry = scope.declare_cfunction(cython_name, copy_func_type = CFuncType(
CFuncType(self, [CFuncTypeArg("memviewslice", self, None)]), to_memview,
pos=pos, [CFuncTypeArg("memviewslice", self, None)])
defining=1, copy_cname = MemoryView.copy_c_or_fortran_cname(to_memview)
cname=MemoryView.copy_c_or_fortran_cname(to_memview))
entry = scope.declare_cfunction(
cython_name,
copy_func_type, pos=pos, defining=1,
cname=copy_cname)
#entry.utility_code_definition = \ utility = MemoryView.get_copy_new_utility(pos, self, to_memview)
env.use_utility_code(MemoryView.get_copy_new_utility(pos, self, to_memview)) env.use_utility_code(utility)
MemoryView.use_cython_array_utility_code(env) MemoryView.use_cython_array_utility_code(env)
...@@ -684,9 +687,102 @@ class MemoryViewSliceType(PyrexType): ...@@ -684,9 +687,102 @@ class MemoryViewSliceType(PyrexType):
return True return True
def get_entry(self, node, cname=None, type=None):
from . import MemoryView, Symtab
if cname is None:
assert node.is_simple() or node.is_temp or node.is_elemental
cname = node.result()
if type is None:
type = node.type
entry = Symtab.Entry(cname, cname, type, node.pos)
return MemoryView.MemoryViewSliceBufferEntry(entry)
def conforms_to(self, dst, broadcast=False, copying=False):
"""
Returns True if src conforms to dst, False otherwise.
If conformable, the types are the same, the ndims are equal, and each axis spec is conformable.
Any packing/access spec is conformable to itself.
'direct' and 'ptr' are conformable to 'full'.
'contig' and 'follow' are conformable to 'strided'.
Any other combo is not conformable.
"""
from . import MemoryView
src = self
if src.dtype != dst.dtype:
return False
if src.ndim != dst.ndim:
if broadcast:
src, dst = MemoryView.broadcast_types(src, dst)
else:
return False
for src_spec, dst_spec in zip(src.axes, dst.axes):
src_access, src_packing = src_spec
dst_access, dst_packing = dst_spec
if src_access != dst_access and dst_access != 'full':
return False
if src_packing != dst_packing and dst_packing != 'strided' and not copying:
return False
return True
def valid_dtype(self, dtype, i=0):
"""
Return whether type dtype can be used as the base type of a
memoryview slice.
We support structs, numeric types and objects
"""
if dtype.is_complex and dtype.real_type.is_int:
return False
if dtype.is_struct and dtype.kind == 'struct':
for member in dtype.scope.var_entries:
if not self.valid_dtype(member.type):
return False
return True
return (
dtype.is_error or
# Pointers are not valid (yet)
# (dtype.is_ptr and valid_memslice_dtype(dtype.base_type)) or
(dtype.is_array and i < 8 and self.valid_dtype(dtype.base_type, i + 1)) or
dtype.is_numeric or
dtype.is_pyobject or
dtype.is_fused or # accept this as it will be replaced by specializations later
(dtype.is_typedef and self.valid_dtype(dtype.typedef_base_type))
)
def validate_memslice_dtype(self, pos):
if not self.valid_dtype(self.dtype):
error(pos, "Invalid base type for memoryview slice: %s" % self.dtype)
def assert_direct_dims(self, pos):
for access, packing in self.axes:
if access != 'direct':
error(pos, "All dimensions must be direct")
return False
return True
def transpose(self, pos):
if not self.assert_direct_dims(pos):
return error_type
return MemoryViewSliceType(self.dtype, self.axes[::-1])
def specialization_name(self): def specialization_name(self):
return super(MemoryViewSliceType,self).specialization_name() \ return '%s_%s' % (
+ '_' + self.specialization_suffix() super(MemoryViewSliceType,self).specialization_name(),
self.specialization_suffix())
def specialization_suffix(self): def specialization_suffix(self):
return "%s_%s" % (self.axes_to_name(), self.dtype_name) return "%s_%s" % (self.axes_to_name(), self.dtype_name)
...@@ -874,6 +970,11 @@ class BufferType(BaseType): ...@@ -874,6 +970,11 @@ class BufferType(BaseType):
self.negative_indices, self.cast) self.negative_indices, self.cast)
return self return self
def get_entry(self, node):
from . import Buffer
assert node.is_name
return Buffer.BufferEntry(node.entry)
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.base, name) return getattr(self.base, name)
......
...@@ -79,7 +79,7 @@ cdef extern from *: ...@@ -79,7 +79,7 @@ cdef extern from *:
size_t sizeof_dtype, int contig_flag, size_t sizeof_dtype, int contig_flag,
bint dtype_is_object) nogil except * bint dtype_is_object) nogil except *
bint slice_is_contig "__pyx_memviewslice_is_contig" ( bint slice_is_contig "__pyx_memviewslice_is_contig" (
{{memviewslice_name}} *mvs, char order, int ndim) nogil {{memviewslice_name}} mvs, char order, int ndim) nogil
bint slices_overlap "__pyx_slices_overlap" ({{memviewslice_name}} *slice1, bint slices_overlap "__pyx_slices_overlap" ({{memviewslice_name}} *slice1,
{{memviewslice_name}} *slice2, {{memviewslice_name}} *slice2,
int ndim, size_t itemsize) nogil int ndim, size_t itemsize) nogil
...@@ -578,13 +578,13 @@ cdef class memoryview(object): ...@@ -578,13 +578,13 @@ cdef class memoryview(object):
cdef {{memviewslice_name}} *mslice cdef {{memviewslice_name}} *mslice
cdef {{memviewslice_name}} tmp cdef {{memviewslice_name}} tmp
mslice = get_slice_from_memview(self, &tmp) mslice = get_slice_from_memview(self, &tmp)
return slice_is_contig(mslice, 'C', self.view.ndim) return slice_is_contig(mslice[0], 'C', self.view.ndim)
def is_f_contig(self): def is_f_contig(self):
cdef {{memviewslice_name}} *mslice cdef {{memviewslice_name}} *mslice
cdef {{memviewslice_name}} tmp cdef {{memviewslice_name}} tmp
mslice = get_slice_from_memview(self, &tmp) mslice = get_slice_from_memview(self, &tmp)
return slice_is_contig(mslice, 'F', self.view.ndim) return slice_is_contig(mslice[0], 'F', self.view.ndim)
def copy(self): def copy(self):
cdef {{memviewslice_name}} mslice cdef {{memviewslice_name}} mslice
...@@ -1195,7 +1195,7 @@ cdef void *copy_data_to_temp({{memviewslice_name}} *src, ...@@ -1195,7 +1195,7 @@ cdef void *copy_data_to_temp({{memviewslice_name}} *src,
if tmpslice.shape[i] == 1: if tmpslice.shape[i] == 1:
tmpslice.strides[i] = 0 tmpslice.strides[i] = 0
if slice_is_contig(src, order, ndim): if slice_is_contig(src[0], order, ndim):
memcpy(result, src.data, size) memcpy(result, src.data, size)
else: else:
copy_strided_to_strided(src, tmpslice, ndim, itemsize) copy_strided_to_strided(src, tmpslice, ndim, itemsize)
...@@ -1258,7 +1258,7 @@ cdef int memoryview_copy_contents({{memviewslice_name}} src, ...@@ -1258,7 +1258,7 @@ cdef int memoryview_copy_contents({{memviewslice_name}} src,
if slices_overlap(&src, &dst, ndim, itemsize): if slices_overlap(&src, &dst, ndim, itemsize):
# slices overlap, copy to temp, copy temp to dst # slices overlap, copy to temp, copy temp to dst
if not slice_is_contig(&src, order, ndim): if not slice_is_contig(src, order, ndim):
order = get_best_order(&dst, ndim) order = get_best_order(&dst, ndim)
tmpdata = copy_data_to_temp(&src, &tmp, order, ndim) tmpdata = copy_data_to_temp(&src, &tmp, order, ndim)
...@@ -1267,10 +1267,10 @@ cdef int memoryview_copy_contents({{memviewslice_name}} src, ...@@ -1267,10 +1267,10 @@ cdef int memoryview_copy_contents({{memviewslice_name}} src,
if not broadcasting: if not broadcasting:
# See if both slices have equal contiguity, in that case perform a # See if both slices have equal contiguity, in that case perform a
# direct copy. This only works when we are not broadcasting. # direct copy. This only works when we are not broadcasting.
if slice_is_contig(&src, 'C', ndim): if slice_is_contig(src, 'C', ndim):
direct_copy = slice_is_contig(&dst, 'C', ndim) direct_copy = slice_is_contig(dst, 'C', ndim)
elif slice_is_contig(&src, 'F', ndim): elif slice_is_contig(src, 'F', ndim):
direct_copy = slice_is_contig(&dst, 'F', ndim) direct_copy = slice_is_contig(dst, 'F', ndim)
if direct_copy: if direct_copy:
# Contiguous slices with same order # Contiguous slices with same order
......
...@@ -692,29 +692,29 @@ __pyx_slices_overlap({{memviewslice_name}} *slice1, ...@@ -692,29 +692,29 @@ __pyx_slices_overlap({{memviewslice_name}} *slice1,
////////// MemviewSliceIsCContig.proto ////////// ////////// MemviewSliceIsCContig.proto //////////
#define __pyx_memviewslice_is_c_contig{{ndim}}(slice) \ #define __pyx_memviewslice_is_c_contig{{ndim}}(slice) \
__pyx_memviewslice_is_contig(&slice, 'C', {{ndim}}) __pyx_memviewslice_is_contig(slice, 'C', {{ndim}})
////////// MemviewSliceIsFContig.proto ////////// ////////// MemviewSliceIsFContig.proto //////////
#define __pyx_memviewslice_is_f_contig{{ndim}}(slice) \ #define __pyx_memviewslice_is_f_contig{{ndim}}(slice) \
__pyx_memviewslice_is_contig(&slice, 'F', {{ndim}}) __pyx_memviewslice_is_contig(slice, 'F', {{ndim}})
////////// MemviewSliceIsContig.proto ////////// ////////// MemviewSliceIsContig.proto //////////
static int __pyx_memviewslice_is_contig(const {{memviewslice_name}} *mvs, static int __pyx_memviewslice_is_contig(const {{memviewslice_name}} mvs,
char order, int ndim); char order, int ndim);
////////// MemviewSliceIsContig ////////// ////////// MemviewSliceIsContig //////////
static int static int
__pyx_memviewslice_is_contig(const {{memviewslice_name}} *mvs, __pyx_memviewslice_is_contig(const {{memviewslice_name}} mvs,
char order, int ndim) char order, int ndim)
{ {
int i, index, step, start; int i, index, step, start;
Py_ssize_t itemsize = mvs->memview->view.itemsize; Py_ssize_t itemsize = mvs.memview->view.itemsize;
if (order == 'F') { if (order == 'F') {
step = 1; step = 1;
...@@ -726,10 +726,10 @@ __pyx_memviewslice_is_contig(const {{memviewslice_name}} *mvs, ...@@ -726,10 +726,10 @@ __pyx_memviewslice_is_contig(const {{memviewslice_name}} *mvs,
for (i = 0; i < ndim; i++) { for (i = 0; i < ndim; i++) {
index = start + step * i; index = start + step * i;
if (mvs->suboffsets[index] >= 0 || mvs->strides[index] != itemsize) if (mvs.suboffsets[index] >= 0 || mvs.strides[index] != itemsize)
return 0; return 0;
itemsize *= mvs->shape[index]; itemsize *= mvs.shape[index];
} }
return 1; return 1;
......
...@@ -14,6 +14,7 @@ from cython.view cimport array ...@@ -14,6 +14,7 @@ from cython.view cimport array
import numpy as np import numpy as np
cimport numpy as np cimport numpy as np
@testcase @testcase
def test_shape_stride_suboffset(): def test_shape_stride_suboffset():
u''' u'''
...@@ -47,6 +48,7 @@ def test_shape_stride_suboffset(): ...@@ -47,6 +48,7 @@ def test_shape_stride_suboffset():
print c_contig.strides[0], c_contig.strides[1], c_contig.strides[2] print c_contig.strides[0], c_contig.strides[1], c_contig.strides[2]
print c_contig.suboffsets[0], c_contig.suboffsets[1], c_contig.suboffsets[2] print c_contig.suboffsets[0], c_contig.suboffsets[1], c_contig.suboffsets[2]
@testcase @testcase
def test_copy_to(): def test_copy_to():
u''' u'''
...@@ -57,15 +59,19 @@ def test_copy_to(): ...@@ -57,15 +59,19 @@ def test_copy_to():
''' '''
cdef int[:, :, :] from_mvs, to_mvs cdef int[:, :, :] from_mvs, to_mvs
from_mvs = np.arange(8, dtype=np.int32).reshape(2,2,2) from_mvs = np.arange(8, dtype=np.int32).reshape(2,2,2)
cdef int *from_data = <int *> from_mvs._data cdef int *from_data = <int *> from_mvs._data
print ' '.join(str(from_data[i]) for i in range(2*2*2)) print ' '.join(str(from_data[i]) for i in range(2*2*2))
to_mvs = array((2,2,2), sizeof(int), 'i') to_mvs = array((2,2,2), sizeof(int), 'i')
to_mvs[...] = from_mvs to_mvs[...] = from_mvs
# TODO Mark: remove this _data attribute
cdef int *to_data = <int*>to_mvs._data cdef int *to_data = <int*>to_mvs._data
print ' '.join(str(from_data[i]) for i in range(2*2*2)) print ' '.join(str(from_data[i]) for i in range(2*2*2))
print ' '.join(str(to_data[i]) for i in range(2*2*2)) print ' '.join(str(to_data[i]) for i in range(2*2*2))
@testcase @testcase
def test_overlapping_copy(): def test_overlapping_copy():
""" """
...@@ -81,6 +87,22 @@ def test_overlapping_copy(): ...@@ -81,6 +87,22 @@ def test_overlapping_copy():
for i in range(10): for i in range(10):
assert slice[i] == 10 - 1 - i assert slice[i] == 10 - 1 - i
@testcase
def test_copy_return_type():
"""
>>> test_copy_return_type()
60.0
60.0
"""
cdef double[:, :, :] a = np.arange(5 * 5 * 5, dtype=np.float64).reshape(5, 5, 5)
cdef double[:, ::1] c_contig = a[..., 0].copy()
cdef double[::1, :] f_contig = a[..., 0].copy_fortran()
print(c_contig[2, 2])
print(f_contig[2, 2])
@testcase @testcase
def test_partly_overlapping(): def test_partly_overlapping():
""" """
...@@ -170,30 +192,34 @@ def test_copy_mismatch(): ...@@ -170,30 +192,34 @@ def test_copy_mismatch():
mv1[...] = mv2 mv1[...] = mv2
@testcase @testcase
def test_is_contiguous(): def test_is_contiguous():
u''' u"""
>>> test_is_contiguous() >>> test_is_contiguous()
True True one sized is_c/f_contig True True
False True is_c/f_contig False True
True False f_contig.copy().is_c/f_contig True False
True False f_contig.copy_fortran().is_c/f_contig False True
<BLANKLINE> one sized strided contig True True
False True strided False
True False """
'''
cdef int[::1, :, :] fort_contig = array((1,1,1), sizeof(int), 'i', mode='fortran') cdef int[::1, :, :] fort_contig = array((1,1,1), sizeof(int), 'i', mode='fortran')
print fort_contig.is_c_contig() , fort_contig.is_f_contig()
fort_contig = array((200,100,100), sizeof(int), 'i', mode='fortran')
print fort_contig.is_c_contig(), fort_contig.is_f_contig()
fort_contig = fort_contig.copy()
print fort_contig.is_c_contig(), fort_contig.is_f_contig()
cdef int[:,:,:] strided = fort_contig cdef int[:,:,:] strided = fort_contig
print strided.is_c_contig(), strided.is_f_contig()
print print 'one sized is_c/f_contig', fort_contig.is_c_contig(), fort_contig.is_f_contig()
fort_contig = fort_contig.copy_fortran() fort_contig = array((2,2,2), sizeof(int), 'i', mode='fortran')
print fort_contig.is_c_contig(), fort_contig.is_f_contig() print 'is_c/f_contig', fort_contig.is_c_contig(), fort_contig.is_f_contig()
print strided.is_c_contig(), strided.is_f_contig()
print 'f_contig.copy().is_c/f_contig', fort_contig.copy().is_c_contig(), \
fort_contig.copy().is_f_contig()
print 'f_contig.copy_fortran().is_c/f_contig', \
fort_contig.copy_fortran().is_c_contig(), \
fort_contig.copy_fortran().is_f_contig()
print 'one sized strided contig', strided.is_c_contig(), strided.is_f_contig()
print 'strided', strided[::2].is_c_contig()
@testcase @testcase
...@@ -272,6 +298,7 @@ def two_dee(): ...@@ -272,6 +298,7 @@ def two_dee():
print (<long*>mv3._data)[0] , (<long*>mv3._data)[1] , (<long*>mv3._data)[2] , (<long*>mv3._data)[3] print (<long*>mv3._data)[0] , (<long*>mv3._data)[1] , (<long*>mv3._data)[2] , (<long*>mv3._data)[3]
@testcase @testcase
def fort_two_dee(): def fort_two_dee():
u''' u'''
...@@ -283,7 +310,8 @@ def fort_two_dee(): ...@@ -283,7 +310,8 @@ def fort_two_dee():
1 2 3 -4 1 2 3 -4
''' '''
cdef array arr = array((2,2), sizeof(long), 'l', mode='fortran') cdef array arr = array((2,2), sizeof(long), 'l', mode='fortran')
cdef long[::1,:] mv1, mv2, mv3 cdef long[::1,:] mv1, mv2, mv4
cdef long[:, ::1] mv3
cdef long *arr_data cdef long *arr_data
arr_data = <long*>arr.data arr_data = <long*>arr.data
...@@ -311,6 +339,6 @@ def fort_two_dee(): ...@@ -311,6 +339,6 @@ def fort_two_dee():
print (<long*>mv3._data)[0], (<long*>mv3._data)[1], (<long*>mv3._data)[2], (<long*>mv3._data)[3] print (<long*>mv3._data)[0], (<long*>mv3._data)[1], (<long*>mv3._data)[2], (<long*>mv3._data)[3]
mv3 = mv3.copy_fortran() mv4 = mv3.copy_fortran()
print (<long*>mv3._data)[0], (<long*>mv3._data)[1], (<long*>mv3._data)[2], (<long*>mv3._data)[3] print (<long*>mv4._data)[0], (<long*>mv4._data)[1], (<long*>mv4._data)[2], (<long*>mv4._data)[3]
...@@ -163,6 +163,7 @@ def test_ellipsis_memoryview(array): ...@@ -163,6 +163,7 @@ def test_ellipsis_memoryview(array):
ae(e.shape[0], e_obj.shape[0]) ae(e.shape[0], e_obj.shape[0])
ae(e.strides[0], e_obj.strides[0]) ae(e.strides[0], e_obj.strides[0])
@testcase @testcase
def test_transpose(): def test_transpose():
""" """
...@@ -193,6 +194,20 @@ def test_transpose(): ...@@ -193,6 +194,20 @@ def test_transpose():
print a[3, 2], a.T[2, 3], a_obj[3, 2], a_obj.T[2, 3], numpy_obj[3, 2], numpy_obj.T[2, 3] print a[3, 2], a.T[2, 3], a_obj[3, 2], a_obj.T[2, 3], numpy_obj[3, 2], numpy_obj.T[2, 3]
@testcase
def test_transpose_type(a):
"""
>>> a = np.zeros((5, 10), dtype=np.float64)
>>> a[4, 6] = 9
>>> test_transpose_type(a)
9.0
"""
cdef double[:, ::1] m = a
cdef double[::1, :] m_transpose = a.T
print m_transpose[6, 4]
@testcase_numpy_1_5 @testcase_numpy_1_5
def test_numpy_like_attributes(cyarray): def test_numpy_like_attributes(cyarray):
""" """
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment