Commit 017b73ae authored by Mark Florisson's avatar Mark Florisson

slice assignment + broadcasting leading newaxis dimensions

parent f7e5e0b1
...@@ -220,6 +220,9 @@ class ExprNode(Node): ...@@ -220,6 +220,9 @@ class ExprNode(Node):
constant_result = constant_value_not_set constant_result = constant_value_not_set
# whether this node with a memoryview type should be broadcast
memslice_broadcast = False
try: try:
_get_child_attrs = operator.attrgetter('subexprs') _get_child_attrs = operator.attrgetter('subexprs')
except AttributeError: except AttributeError:
...@@ -643,7 +646,8 @@ class ExprNode(Node): ...@@ -643,7 +646,8 @@ class ExprNode(Node):
error(self.pos, error(self.pos,
"Cannot convert '%s' to memoryviewslice" % "Cannot convert '%s' to memoryviewslice" %
(src_type,)) (src_type,))
elif not MemoryView.src_conforms_to_dst(src.type, dst_type): elif not MemoryView.src_conforms_to_dst(
src.type, dst_type, broadcast=self.memslice_broadcast):
if src.type.dtype.same_as(dst_type.dtype): if src.type.dtype.same_as(dst_type.dtype):
msg = "Memoryview '%s' not conformable to memoryview '%s'." msg = "Memoryview '%s' not conformable to memoryview '%s'."
tup = src.type, dst_type tup = src.type, dst_type
...@@ -2339,6 +2343,8 @@ class IndexNode(ExprNode): ...@@ -2339,6 +2343,8 @@ class IndexNode(ExprNode):
# Whether we are indexing or slicing a memoryviewslice # Whether we are indexing or slicing a memoryviewslice
memslice_index = False memslice_index = False
memslice_slice = False memslice_slice = False
is_memslice_copy = False
memslice_ellipsis_noop = False
warned_untyped_idx = False warned_untyped_idx = False
def __init__(self, pos, index, *args, **kw): def __init__(self, pos, index, *args, **kw):
...@@ -2610,7 +2616,10 @@ class IndexNode(ExprNode): ...@@ -2610,7 +2616,10 @@ class IndexNode(ExprNode):
self.type = self.base.type self.type = self.base.type
self.is_memoryviewslice_access = True self.is_memoryviewslice_access = True
if getting: if getting:
error(self.pos, "memoryviews currently support setting only.") self.memslice_ellipsis_noop = True
else:
self.is_memslice_copy = True
self.memslice_broadcast = True
elif self.memslice_slice: elif self.memslice_slice:
self.index = None self.index = None
...@@ -2618,6 +2627,8 @@ class IndexNode(ExprNode): ...@@ -2618,6 +2627,8 @@ class IndexNode(ExprNode):
self.use_managed_ref = True self.use_managed_ref = True
self.type = PyrexTypes.MemoryViewSliceType( self.type = PyrexTypes.MemoryViewSliceType(
self.base.type.dtype, axes) self.base.type.dtype, axes)
if setting:
self.memslice_broadcast = True
else: else:
base_type = self.base.type base_type = self.base.type
...@@ -2816,6 +2827,8 @@ class IndexNode(ExprNode): ...@@ -2816,6 +2827,8 @@ class IndexNode(ExprNode):
def calculate_result_code(self): def calculate_result_code(self):
if self.is_buffer_access: if self.is_buffer_access:
return "(*%s)" % self.buffer_ptr_code return "(*%s)" % self.buffer_ptr_code
elif self.is_memslice_copy:
return self.base.result()
elif self.base.type is list_type: elif self.base.type is list_type:
return "PyList_GET_ITEM(%s, %s)" % (self.base.result(), self.index.result()) return "PyList_GET_ITEM(%s, %s)" % (self.base.result(), self.index.result())
elif self.base.type is tuple_type: elif self.base.type is tuple_type:
...@@ -2840,7 +2853,7 @@ class IndexNode(ExprNode): ...@@ -2840,7 +2853,7 @@ class IndexNode(ExprNode):
def generate_subexpr_evaluation_code(self, code): def generate_subexpr_evaluation_code(self, code):
self.base.generate_evaluation_code(code) self.base.generate_evaluation_code(code)
if not self.indices: if self.indices is None:
self.index.generate_evaluation_code(code) self.index.generate_evaluation_code(code)
else: else:
for i in self.indices: for i in self.indices:
...@@ -2848,7 +2861,7 @@ class IndexNode(ExprNode): ...@@ -2848,7 +2861,7 @@ class IndexNode(ExprNode):
def generate_subexpr_disposal_code(self, code): def generate_subexpr_disposal_code(self, code):
self.base.generate_disposal_code(code) self.base.generate_disposal_code(code)
if not self.indices: if self.indices is None:
self.index.generate_disposal_code(code) self.index.generate_disposal_code(code)
else: else:
for i in self.indices: for i in self.indices:
...@@ -2856,7 +2869,7 @@ class IndexNode(ExprNode): ...@@ -2856,7 +2869,7 @@ class IndexNode(ExprNode):
def free_subexpr_temps(self, code): def free_subexpr_temps(self, code):
self.base.free_temps(code) self.base.free_temps(code)
if not self.indices: if self.indices is None:
self.index.free_temps(code) self.index.free_temps(code)
else: else:
for i in self.indices: for i in self.indices:
...@@ -2945,15 +2958,6 @@ class IndexNode(ExprNode): ...@@ -2945,15 +2958,6 @@ class IndexNode(ExprNode):
self.extra_index_params(), self.extra_index_params(),
code.error_goto(self.pos))) code.error_goto(self.pos)))
def generate_memoryviewslice_copy_code(self, rhs, code):
import MemoryView
code.putln(
code.error_goto_if_neg(
"%s(&%s, &%s, %d)" % (MemoryView.copy_src_to_dst_cname(),
rhs.result(), self.base.result(),
self.type.ndim),
self.pos))
def generate_buffer_setitem_code(self, rhs, code, op=""): def generate_buffer_setitem_code(self, rhs, code, op=""):
# Used from generate_assignment_code and InPlaceAssignmentNode # Used from generate_assignment_code and InPlaceAssignmentNode
if code.globalstate.directives['nonecheck'] and not self.memslice_index: if code.globalstate.directives['nonecheck'] and not self.memslice_index:
...@@ -2984,8 +2988,7 @@ class IndexNode(ExprNode): ...@@ -2984,8 +2988,7 @@ class IndexNode(ExprNode):
if self.is_buffer_access or self.memslice_index: if self.is_buffer_access or self.memslice_index:
self.generate_buffer_setitem_code(rhs, code) self.generate_buffer_setitem_code(rhs, code)
elif self.memslice_slice: elif self.memslice_slice:
error(rhs.pos, "Slice assignment not supported yet") self.generate_memoryviewslice_setslice_code(rhs, code)
#self.generate_memoryviewslice_setslice_code(rhs, code)
elif self.is_memoryviewslice_access: elif self.is_memoryviewslice_access:
self.generate_memoryviewslice_copy_code(rhs, code) self.generate_memoryviewslice_copy_code(rhs, code)
elif self.type.is_pyobject: elif self.type.is_pyobject:
...@@ -3040,6 +3043,7 @@ class IndexNode(ExprNode): ...@@ -3040,6 +3043,7 @@ class IndexNode(ExprNode):
return buffer_entry return buffer_entry
def buffer_lookup_code(self, code): def buffer_lookup_code(self, code):
"ndarray[1, 2, 3] and memslice[1, 2, 3]"
# Assign indices to temps # Assign indices to temps
index_temps = [code.funcstate.allocate_temp(i.type, manage_ref=False) index_temps = [code.funcstate.allocate_temp(i.type, manage_ref=False)
for i in self.indices] for i in self.indices]
...@@ -3066,6 +3070,7 @@ class IndexNode(ExprNode): ...@@ -3066,6 +3070,7 @@ class IndexNode(ExprNode):
negative_indices=negative_indices) negative_indices=negative_indices)
def put_memoryviewslice_slice_code(self, code): def put_memoryviewslice_slice_code(self, code):
"memslice[:]"
buffer_entry = self.buffer_entry() buffer_entry = self.buffer_entry()
have_gil = not self.in_nogil_context have_gil = not self.in_nogil_context
buffer_entry.generate_buffer_slice_code(code, buffer_entry.generate_buffer_slice_code(code,
...@@ -3073,6 +3078,17 @@ class IndexNode(ExprNode): ...@@ -3073,6 +3078,17 @@ class IndexNode(ExprNode):
self.result(), self.result(),
have_gil=have_gil) have_gil=have_gil)
def generate_memoryviewslice_setslice_code(self, rhs, code):
"memslice1[:] = memslice2"
import MemoryView
self.generate_evaluation_code(code)
MemoryView.copy_broadcast_memview_src_to_dst(rhs, self, code)
def generate_memoryviewslice_copy_code(self, rhs, code):
"memslice1[...] = memslice2"
import MemoryView
MemoryView.copy_broadcast_memview_src_to_dst(rhs, self, code)
def put_nonecheck(self, code): def put_nonecheck(self, code):
code.globalstate.use_utility_code(raise_noneindex_error_utility_code) code.globalstate.use_utility_code(raise_noneindex_error_utility_code)
code.putln("if (%s) {" % code.unlikely("%s == Py_None") % self.base.result_as(PyrexTypes.py_object_type)) code.putln("if (%s) {" % code.unlikely("%s == Py_None") % self.base.result_as(PyrexTypes.py_object_type))
......
...@@ -128,8 +128,19 @@ def get_buf_flags(specs): ...@@ -128,8 +128,19 @@ def get_buf_flags(specs):
else: else:
return memview_strided_access return memview_strided_access
def insert_newaxes(memoryviewtype, n):
axes = [('direct', 'strided')] * n
axes.extend(memoryviewtype.axes)
return PyrexTypes.MemoryViewSliceType(memoryviewtype.dtype, axes)
def broadcast_types(src, dst):
n = abs(src.ndim - dst.ndim)
if src.ndim < dst.ndim:
return insert_newaxes(src, n), dst
else:
return src, insert_newaxes(dst, n)
def src_conforms_to_dst(src, dst): def src_conforms_to_dst(src, dst, broadcast=False):
''' '''
returns True if src conforms to dst, False otherwise. returns True if src conforms to dst, False otherwise.
...@@ -144,7 +155,11 @@ def src_conforms_to_dst(src, dst): ...@@ -144,7 +155,11 @@ def src_conforms_to_dst(src, dst):
if src.dtype != dst.dtype: if src.dtype != dst.dtype:
return False return False
if len(src.axes) != len(dst.axes):
if src.ndim != dst.ndim:
if broadcast:
src, dst = broadcast_types(src, dst)
else:
return False return False
for src_spec, dst_spec in zip(src.axes, dst.axes): for src_spec, dst_spec in zip(src.axes, dst.axes):
...@@ -412,6 +427,62 @@ def get_is_contig_utility(c_contig, ndim): ...@@ -412,6 +427,62 @@ def get_is_contig_utility(c_contig, ndim):
def copy_src_to_dst_cname(): def copy_src_to_dst_cname():
return "__pyx_memoryview_copy_contents" return "__pyx_memoryview_copy_contents"
def verify_direct_dimensions(node):
for access, packing in node.type.axes:
if access != 'direct':
error(self.pos, "All dimensions must be direct")
return False
return True
def broadcast(src, dst, src_temp, dst_temp, code):
"Perform an in-place broadcast of slices src and dst"
if src.type.ndim != dst.type.ndim:
code.putln("__pyx_memoryview_broadcast_inplace(&%s, &%s, %d, %d);" % (
src_temp, dst_temp, src.type.ndim, dst.type.ndim))
return max(src.type.ndim, dst.type.ndim)
return src.type.ndim
def copy_broadcast_memview_src_to_dst_inplace(src, dst, src_temp, dst_temp, code):
"""
It is hard to check for overlapping memory with indirect slices,
so we currently don't support them.
"""
if not verify_direct_dimensions(src): return
if not verify_direct_dimensions(dst): return
ndim = broadcast(src, dst, src_temp, dst_temp, code)
call = "%s(&%s, &%s, %d)" % (copy_src_to_dst_cname(),
src_temp, dst_temp, ndim)
code.putln(code.error_goto_if_neg(call, dst.pos))
def copy_broadcast_memview_src_to_dst(src, dst, code):
# Note: do not use code.funcstate.allocate_temp to allocate temps, as
# temps will be acquisition counted (so we would need new
# references, as any sudden exception would cause a jump leading to
# a decref before we can nullify our slice)
src_tmp = None
dst_tmp = None
code.begin_block()
if src.type.ndim < dst.type.ndim and not src.result_in_temp():
src_tmp = '__pyx_slice_tmp1'
code.putln("%s %s = %s;" % (memviewslice_cname, src_tmp, src.result()))
if dst.type.ndim < src.type.ndim and not dst.result_in_temp():
dst_tmp = '__pyx+_slice_tmp2'
code.putln("%s %s = %s;" % (memviewslice_cname, dst_tmp, dst.result()))
copy_broadcast_memview_src_to_dst_inplace(src, dst,
src_tmp or src.result(),
dst_tmp or dst.result(),
code)
code.end_block()
def copy_c_or_fortran_cname(memview): def copy_c_or_fortran_cname(memview):
if memview.is_c_contig: if memview.is_c_contig:
c_or_f = 'c' c_or_f = 'c'
......
...@@ -4727,6 +4727,11 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -4727,6 +4727,11 @@ class SingleAssignmentNode(AssignmentNode):
self.rhs.analyse_types(env) self.rhs.analyse_types(env)
self.lhs.analyse_target_types(env) self.lhs.analyse_target_types(env)
self.lhs.gil_assignment_check(env) self.lhs.gil_assignment_check(env)
if self.lhs.memslice_broadcast or self.rhs.memslice_broadcast:
self.lhs.memslice_broadcast = True
self.rhs.memslice_broadcast = True
self.rhs = self.rhs.coerce_to(self.lhs.type, env) self.rhs = self.rhs.coerce_to(self.lhs.type, env)
if use_temp: if use_temp:
self.rhs = self.rhs.coerce_to_temp(env) self.rhs = self.rhs.coerce_to_temp(env)
......
...@@ -1779,6 +1779,10 @@ class AnalyseExpressionsTransform(CythonTransform): ...@@ -1779,6 +1779,10 @@ class AnalyseExpressionsTransform(CythonTransform):
if node.is_fused_index and node.type is not PyrexTypes.error_type: if node.is_fused_index and node.type is not PyrexTypes.error_type:
node = node.base node = node.base
if node.memslice_ellipsis_noop:
# memoryviewslice[...] expression, drop the IndexNode
node = node.base
return node return node
......
...@@ -66,9 +66,9 @@ cdef class array: ...@@ -66,9 +66,9 @@ cdef class array:
cdef int idx cdef int idx
# cdef Py_ssize_t dim, stride # cdef Py_ssize_t dim, stride
idx = 0 idx = 0
for dim in shape: for idx, dim in enumerate(shape):
if dim <= 0: if dim <= 0:
raise ValueError("Invalid shape.") raise ValueError("Invalid shape in axis %d: %d." % (idx, dim))
self._shape[idx] = dim self._shape[idx] = dim
idx += 1 idx += 1
...@@ -243,7 +243,7 @@ cdef extern from *: ...@@ -243,7 +243,7 @@ cdef extern from *:
{{memviewslice_name}} slice_copy_contig "__pyx_memoryview_copy_new_contig"( {{memviewslice_name}} slice_copy_contig "__pyx_memoryview_copy_new_contig"(
__Pyx_memviewslice *from_mvs, __Pyx_memviewslice *from_mvs,
char *mode, int ndim, char *mode, int ndim,
size_t sizeof_dtype, int contig_flag) nogil size_t sizeof_dtype, int contig_flag) nogil except *
bint slice_is_contig "__pyx_memviewslice_is_contig" ( bint slice_is_contig "__pyx_memviewslice_is_contig" (
{{memviewslice_name}} *mvs, char order, int ndim) nogil {{memviewslice_name}} *mvs, char order, int ndim) nogil
bint slices_overlap "__pyx_slices_overlap" ({{memviewslice_name}} *slice1, bint slices_overlap "__pyx_slices_overlap" ({{memviewslice_name}} *slice1,
...@@ -347,9 +347,43 @@ cdef class memoryview(object): ...@@ -347,9 +347,43 @@ cdef class memoryview(object):
@cname('__pyx_memoryview_setitem') @cname('__pyx_memoryview_setitem')
def __setitem__(memoryview self, object index, object value): def __setitem__(memoryview self, object index, object value):
have_slices, index = _unellipsify(index, self.view.ndim) have_slices, index = _unellipsify(index, self.view.ndim)
if have_slices: if have_slices:
raise NotImplementedError("Slice assignment not supported yet") obj = self.is_slice(value)
if obj:
self.setitem_slice_assignment(index, obj)
else:
self.setitem_slice_assign_scalar(index, value)
else:
self.setitem_indexed(index, value)
cdef is_slice(self, obj):
if not isinstance(obj, memoryview):
try:
obj = memoryview(obj, self.flags|PyBUF_ANY_CONTIGUOUS)
except TypeError:
return None
return obj
cdef setitem_slice_assignment(self, index, src):
cdef {{memviewslice_name}} dst_slice
cdef {{memviewslice_name}} src_slice
dst = self[index]
get_slice_from_memview(dst, &dst_slice)
slice_copy(src, &src_slice)
if dst.ndim != src.ndim:
broadcast_inplace(&src_slice, &dst_slice, src.ndim, dst.ndim)
memoryview_copy_contents(&src_slice, &dst_slice,
max(src.ndim, dst.ndim))
cdef setitem_slice_assign_scalar(self, index, value):
raise ValueError("Scalar assignment currently unsupported")
cdef setitem_indexed(self, index, value):
cdef char *itemp = self.get_item_pointer(index) cdef char *itemp = self.get_item_pointer(index)
self.assign_item_from_object(itemp, value) self.assign_item_from_object(itemp, value)
...@@ -1106,8 +1140,8 @@ cdef int memoryview_copy_contents({{memviewslice_name}} *src, ...@@ -1106,8 +1140,8 @@ cdef int memoryview_copy_contents({{memviewslice_name}} *src,
if src.shape[i] != dst.shape[i]: if src.shape[i] != dst.shape[i]:
with gil: with gil:
raise ValueError( raise ValueError(
"memoryview shapes are not the same in dimension %d, " "memoryview shapes are not the same in dimension %d "
"got %d and %d" % (i, src.shape[i], dst.shape[i])) "(got %d and %d)" % (i, dst.shape[i], src.shape[i]))
if slices_overlap(src, dst, ndim, itemsize): if slices_overlap(src, dst, ndim, itemsize):
# slices overlap, copy to temp, copy temp to dst # slices overlap, copy to temp, copy temp to dst
...@@ -1151,6 +1185,38 @@ cdef int memoryview_copy_contents({{memviewslice_name}} *src, ...@@ -1151,6 +1185,38 @@ cdef int memoryview_copy_contents({{memviewslice_name}} *src,
copy_strided_to_strided(src, dst, ndim, itemsize) copy_strided_to_strided(src, dst, ndim, itemsize)
return 0 return 0
@cname('__pyx_memoryview_broadcast_inplace')
cdef void broadcast_inplace({{memviewslice_name}} *slice1,
{{memviewslice_name}} *slice2,
int ndim1,
int ndim2) nogil:
"""
Broadcast the slice with the least dimensions to prepend empty
dimensions.
"""
cdef int i
cdef int offset = ndim1 - ndim2
cdef int ndim
cdef {{memviewslice_name}} *slice
if offset < 0:
slice = slice1
offset = -offset
ndim = ndim1
else:
slice = slice2
ndim = ndim2
for i in range(ndim - 1, -1, -1):
slice.shape[i + offset] = slice.shape[i]
slice.strides[i + offset] = slice.strides[i]
slice.suboffsets[i + offset] = slice.suboffsets[i]
for i in range(offset):
slice.shape[i] = 1
slice.strides[i] = slice.strides[0]
slice.suboffsets[i] = -1
############### BufferFormatFromTypeInfo ############### ############### BufferFormatFromTypeInfo ###############
cdef extern from *: cdef extern from *:
ctypedef struct __Pyx_StructField ctypedef struct __Pyx_StructField
......
...@@ -483,8 +483,16 @@ __pyx_memoryview_copy_new_contig(const __Pyx_memviewslice *from_mvs, ...@@ -483,8 +483,16 @@ __pyx_memoryview_copy_new_contig(const __Pyx_memviewslice *from_mvs,
__Pyx_RefNannySetupContext("__pyx_memoryview_copy_new_contig"); __Pyx_RefNannySetupContext("__pyx_memoryview_copy_new_contig");
for (i = 0; i < ndim; i++) {
if (from_mvs->suboffsets[i] >= 0) {
PyErr_Format(PyExc_ValueError, "Cannot copy memoryview slice with "
"indirect dimensions (axis %d)", i);
goto fail;
}
}
shape_tuple = PyTuple_New(ndim); shape_tuple = PyTuple_New(ndim);
if(unlikely(!shape_tuple)) { if (unlikely(!shape_tuple)) {
goto fail; goto fail;
} }
__Pyx_GOTREF(shape_tuple); __Pyx_GOTREF(shape_tuple);
......
...@@ -171,7 +171,7 @@ def test_copy_mismatch(): ...@@ -171,7 +171,7 @@ def test_copy_mismatch():
>>> test_copy_mismatch() >>> test_copy_mismatch()
Traceback (most recent call last): Traceback (most recent call last):
... ...
ValueError: memoryview shapes are not the same in dimension 0, got 1 and 2 ValueError: memoryview shapes are not the same in dimension 0 (got 2 and 1)
''' '''
cdef int[:,:,::1] mv1 = array((2,2,3), sizeof(int), 'i') cdef int[:,:,::1] mv1 = array((2,2,3), sizeof(int), 'i')
cdef int[:,:,::1] mv2 = array((1,2,3), sizeof(int), 'i') cdef int[:,:,::1] mv2 = array((1,2,3), sizeof(int), 'i')
......
...@@ -1724,3 +1724,60 @@ def test_object_indices(): ...@@ -1724,3 +1724,60 @@ def test_object_indices():
for j in range(3): for j in range(3):
print myslice[j] print myslice[j]
@testcase
def test_ellipsis_expr():
"""
>>> test_ellipsis_expr()
8
"""
cdef int[10] a
cdef int[:] m = a
m[4] = 8
m[...] = m[...]
print m[4]
@testcase
def test_slice_assignment():
"""
>>> test_slice_assignment()
"""
cdef int carray[10][100]
cdef int i, j
for i in range(10):
for j in range(100):
carray[i][j] = i * 10 + j
cdef int[:, :] m = carray
cdef int[:, :] copy = m[-6:-1, 60:65].copy()
m[...] = m[::-1, ::-1]
m[:, :] = m[::-1, ::-1]
m[-5:, -5:] = m[-6:-1, 60:65]
for i in range(5):
for j in range(5):
assert copy[i, j] == m[-5 + i, -5 + j]
@testcase
def test_slice_assignment_broadcast_leading_dimensions():
"""
>>> test_slice_assignment_broadcast_leading_dimensions()
"""
cdef int array1[1][10]
cdef int array2[10]
cdef int i
for i in range(10):
array1[0][i] = i
cdef int[:, :] a = array1
cdef int[:] b = array2
b[:] = a[:, :]
b = b[::-1]
a[:, :] = b[:]
for i in range(10):
assert a[0, i] == b[i] == 10 - 1 - i
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment