Commit 5292701b authored by Kurt Smith's avatar Kurt Smith Committed by Mark Florisson

memoryviewslices support in-place copying through to_arr[...] indexing

parent a5953c6c
...@@ -401,7 +401,7 @@ class ExprNode(Node): ...@@ -401,7 +401,7 @@ class ExprNode(Node):
# By default, any expression based on Python objects is # By default, any expression based on Python objects is
# prevented in nogil environments. Subtypes must override # prevented in nogil environments. Subtypes must override
# this if they can work without the GIL. # this if they can work without the GIL.
if self.type.is_pyobject: if self.type and self.type.is_pyobject:
self.gil_error() self.gil_error()
def gil_assignment_check(self, env): def gil_assignment_check(self, env):
...@@ -2320,6 +2320,7 @@ class IndexNode(ExprNode): ...@@ -2320,6 +2320,7 @@ class IndexNode(ExprNode):
# For buffers, self.index is packed out on the initial analysis, and # For buffers, self.index is packed out on the initial analysis, and
# when cloning self.indices is copied. # when cloning self.indices is copied.
self.is_buffer_access = False self.is_buffer_access = False
self.is_memoryviewslice_access = False
self.base.analyse_types(env) self.base.analyse_types(env)
if self.base.type.is_error: if self.base.type.is_error:
...@@ -2340,6 +2341,7 @@ class IndexNode(ExprNode): ...@@ -2340,6 +2341,7 @@ class IndexNode(ExprNode):
skip_child_analysis = False skip_child_analysis = False
buffer_access = False buffer_access = False
memoryviewslice_access = False
if self.base.type.is_buffer: if self.base.type.is_buffer:
if self.indices: if self.indices:
indices = self.indices indices = self.indices
...@@ -2358,6 +2360,12 @@ class IndexNode(ExprNode): ...@@ -2358,6 +2360,12 @@ class IndexNode(ExprNode):
if buffer_access: if buffer_access:
assert hasattr(self.base, "entry") # Must be a NameNode-like node assert hasattr(self.base, "entry") # Must be a NameNode-like node
if self.base.type.is_memoryviewslice:
assert hasattr(self.base, "entry")
if self.indices or not isinstance(self.index, EllipsisNode):
error(self.pos, "Memoryviews currently support ellipsis indexing only.")
else: memoryviewslice_access = True
# On cloning, indices is cloned. Otherwise, unpack index into indices # On cloning, indices is cloned. Otherwise, unpack index into indices
assert not (buffer_access and isinstance(self.index, CloneNode)) assert not (buffer_access and isinstance(self.index, CloneNode))
...@@ -2375,6 +2383,13 @@ class IndexNode(ExprNode): ...@@ -2375,6 +2383,13 @@ class IndexNode(ExprNode):
error(self.pos, "Writing to readonly buffer") error(self.pos, "Writing to readonly buffer")
else: else:
self.base.entry.buffer_aux.writable_needed = True self.base.entry.buffer_aux.writable_needed = True
elif memoryviewslice_access:
self.type = self.base.type
self.is_memoryviewslice_access = True
if getting:
error(self.pos, "memoryviews currently support setting only.")
else: else:
base_type = self.base.type base_type = self.base.type
if isinstance(self.index, TupleNode): if isinstance(self.index, TupleNode):
...@@ -2593,6 +2608,14 @@ class IndexNode(ExprNode): ...@@ -2593,6 +2608,14 @@ class IndexNode(ExprNode):
self.extra_index_params(), self.extra_index_params(),
code.error_goto(self.pos))) code.error_goto(self.pos)))
def generate_memoryviewslice_setitem_code(self, rhs, code, op=""):
assert isinstance(self.index, EllipsisNode)
import MemoryView
util_code = MemoryView.CopyContentsFuncUtilCode(rhs.type, self.type)
func_name = util_code.copy_contents_name
code.putln(code.error_goto_if_neg("%s(&%s, &%s)" % (func_name, rhs.result(), self.base.result()), self.pos))
code.globalstate.use_utility_code(util_code)
def generate_buffer_setitem_code(self, rhs, code, op=""): def generate_buffer_setitem_code(self, rhs, code, op=""):
# Used from generate_assignment_code and InPlaceAssignmentNode # Used from generate_assignment_code and InPlaceAssignmentNode
if code.globalstate.directives['nonecheck']: if code.globalstate.directives['nonecheck']:
...@@ -2619,6 +2642,8 @@ class IndexNode(ExprNode): ...@@ -2619,6 +2642,8 @@ class IndexNode(ExprNode):
self.generate_subexpr_evaluation_code(code) self.generate_subexpr_evaluation_code(code)
if self.is_buffer_access: if self.is_buffer_access:
self.generate_buffer_setitem_code(rhs, code) self.generate_buffer_setitem_code(rhs, code)
elif self.is_memoryviewslice_access:
self.generate_memoryviewslice_setitem_code(rhs, code)
elif self.type.is_pyobject: elif self.type.is_pyobject:
self.generate_setitem_code(rhs.py_result(), code) self.generate_setitem_code(rhs.py_result(), code)
else: else:
......
...@@ -218,7 +218,7 @@ static int %s(const __Pyx_memviewslice mvs) { ...@@ -218,7 +218,7 @@ static int %s(const __Pyx_memviewslice mvs) {
int i, ndim = mvs.memview->view.ndim; int i, ndim = mvs.memview->view.ndim;
Py_ssize_t itemsize = mvs.memview->view.itemsize; Py_ssize_t itemsize = mvs.memview->view.itemsize;
unsigned long size = 0; long size = 0;
""" % func_name """ % func_name
if c_or_f == 'fortran': if c_or_f == 'fortran':
...@@ -358,7 +358,7 @@ def get_copy_contents_func(from_mvs, to_mvs, cfunc_name): ...@@ -358,7 +358,7 @@ def get_copy_contents_func(from_mvs, to_mvs, cfunc_name):
# XXX: we only support direct access for now. # XXX: we only support direct access for now.
for (access, packing) in from_mvs.axes: for (access, packing) in from_mvs.axes:
if access != 'direct': if access != 'direct':
raise NotImplementedError("only direct access supported currently.") raise NotImplementedError("currently only direct access is supported.")
code_decl = ("static int %(cfunc_name)s(const __Pyx_memviewslice *from_mvs," code_decl = ("static int %(cfunc_name)s(const __Pyx_memviewslice *from_mvs,"
"__Pyx_memviewslice *to_mvs); /* proto */" % {'cfunc_name' : cfunc_name}) "__Pyx_memviewslice *to_mvs); /* proto */" % {'cfunc_name' : cfunc_name})
...@@ -372,16 +372,30 @@ static int %(cfunc_name)s(const __Pyx_memviewslice *from_mvs, __Pyx_memviewslice ...@@ -372,16 +372,30 @@ static int %(cfunc_name)s(const __Pyx_memviewslice *from_mvs, __Pyx_memviewslice
struct __pyx_obj_memoryview *temp_memview = 0; struct __pyx_obj_memoryview *temp_memview = 0;
char *temp_data = 0; char *temp_data = 0;
''' % {'cfunc_name' : cfunc_name} int ndim_idx = 0;
for(ndim_idx=0; ndim_idx<%(ndim)d; ndim_idx++) {
if(from_mvs->diminfo[ndim_idx].shape != to_mvs->diminfo[ndim_idx].shape) {
PyErr_Format(PyExc_ValueError,
"memoryview shapes not the same in dimension %%d", ndim_idx);
return -1;
}
}
''' % {'cfunc_name' : cfunc_name, 'ndim' : ndim}
# raise NotImplementedError("put in shape checking code here!!!")
INDENT = " "
dtype_decl = from_mvs.dtype.declaration_code("")
last_idx = ndim-1
if to_mvs.is_c_contig or to_mvs.is_f_contig:
if to_mvs.is_c_contig: if to_mvs.is_c_contig:
start, stop, step = 0, ndim, 1 start, stop, step = 0, ndim, 1
elif to_mvs.is_f_contig: elif to_mvs.is_f_contig:
start, stop, step = ndim-1, -1, -1 start, stop, step = ndim-1, -1, -1
else:
assert False
INDENT = " "
for i, idx in enumerate(range(start, stop, step)): for i, idx in enumerate(range(start, stop, step)):
# the crazy indexing is to account for the fortran indexing. # the crazy indexing is to account for the fortran indexing.
...@@ -404,10 +418,27 @@ static int %(cfunc_name)s(const __Pyx_memviewslice *from_mvs, __Pyx_memviewslice ...@@ -404,10 +418,27 @@ static int %(cfunc_name)s(const __Pyx_memviewslice *from_mvs, __Pyx_memviewslice
code_impl += INDENT*(k+2) + "idx%(k)d = i%(k)d * stride%(k)d;\n" % {'k' : k} code_impl += INDENT*(k+2) + "idx%(k)d = i%(k)d * stride%(k)d;\n" % {'k' : k}
# the inner part of the loop. # the inner part of the loop.
dtype_decl = from_mvs.dtype.declaration_code("") code_impl += INDENT*(ndim+1)+"memcpy(to_buf, from_buf+idx%(last_idx)d, sizeof(%(dtype_decl)s));\n" % locals()
last_idx = ndim-1 code_impl += INDENT*(ndim+1)+"to_buf += sizeof(%(dtype_decl)s);\n" % locals()
code_impl += INDENT*ndim+"memcpy(to_buf, from_buf+idx%(last_idx)d, sizeof(%(dtype_decl)s));\n" % locals()
code_impl += INDENT*ndim+"to_buf += sizeof(%(dtype_decl)s);\n" % locals()
else:
code_impl += INDENT+"/* 'f' prefix is for the 'from' memview, 't' prefix is for the 'to' memview */\n"
for i in range(ndim):
code_impl += INDENT+"char *fi%d = 0, *ti%d = 0, *end%d = 0;\n" % (i,i,i)
code_impl += INDENT+"Py_ssize_t fstride%(i)d = from_mvs->diminfo[%(i)d].strides;\n" % {'i':i}
code_impl += INDENT+"Py_ssize_t fshape%(i)d = from_mvs->diminfo[%(i)d].shape;\n" % {'i':i}
code_impl += INDENT+"Py_ssize_t tstride%(i)d = to_mvs->diminfo[%(i)d].strides;\n" % {'i':i}
# code_impl += INDENT+"Py_ssize_t tshape%(i)d = to_mvs->diminfo[%(i)d].shape;\n" % {'i':i}
code_impl += INDENT+"end0 = fshape0 * fstride0 + from_mvs->data;\n"
code_impl += INDENT+"for(fi0=from_buf, ti0=to_buf; fi0 < end0; fi0 += fstride0, ti0 += tstride0) {\n"
for i in range(1, ndim):
code_impl += INDENT*(i+1)+"end%(i)d = fshape%(i)d * fstride%(i)d + fi%(im1)d;\n" % {'i' : i, 'im1' : i-1}
code_impl += INDENT*(i+1)+"for(fi%(i)d=fi%(im1)d, ti%(i)d=ti%(im1)d; fi%(i)d < end%(i)d; fi%(i)d += fstride%(i)d, ti%(i)d += tstride%(i)d) {\n" % {'i':i, 'im1':i-1}
code_impl += INDENT*(ndim+1)+"*(%(dtype_decl)s*)(ti%(last_idx)d) = *(%(dtype_decl)s*)(fi%(last_idx)d);\n" % locals()
# for-loop closing braces # for-loop closing braces
for k in range(ndim-1, -1, -1): for k in range(ndim-1, -1, -1):
......
u''' u'''
>>> test_copy_mismatch()
Traceback (most recent call last):
...
ValueError: memoryview shapes not the same in dimension 0
>>> test_copy_to()
0 1 2 3 4 5 6 7
0 1 2 3 4 5 6 7
0 1 2 3 4 5 6 7
>>> test_is_contiguous() >>> test_is_contiguous()
1 1 1 1
0 1 0 1
...@@ -50,6 +58,29 @@ AttributeError: 'NoneType' object has no attribute '_data' ...@@ -50,6 +58,29 @@ AttributeError: 'NoneType' object has no attribute '_data'
cimport cython cimport cython
from cython cimport array from cython cimport array
import numpy as np
cimport numpy as np
def test_copy_to():
cdef int[:,:,:] from_mvs, to_mvs
from_mvs = np.arange(8, dtype=np.int32).reshape(2,2,2)
cdef int *from_dta = <int*>from_mvs._data
for i in range(2*2*2):
print from_dta[i],
print
# for i in range(2*2*2):
# from_dta[i] = i
to_mvs = array((2,2,2), sizeof(int), 'i')
to_mvs[...] = from_mvs
cdef int *to_data = <int*>to_mvs._data
for i in range(2*2*2):
print from_dta[i],
print
for i in range(2*2*2):
print to_data[i],
print
@cython.nonecheck(True) @cython.nonecheck(True)
def test_nonecheck1(): def test_nonecheck1():
cdef int[:,:,:] uninitialized cdef int[:,:,:] uninitialized
...@@ -75,6 +106,12 @@ def test_nonecheck5(): ...@@ -75,6 +106,12 @@ def test_nonecheck5():
cdef int[:,:,:] uninitialized cdef int[:,:,:] uninitialized
uninitialized._data uninitialized._data
def test_copy_mismatch():
cdef int[:,:,::1] mv1 = array((2,2,3), sizeof(int), 'i')
cdef int[:,:,::1] mv2 = array((1,2,3), sizeof(int), 'i')
mv1[...] = mv2
def test_is_contiguous(): def test_is_contiguous():
cdef int[::1, :, :] fort_contig = array((1,1,1), sizeof(int), 'i', mode='fortran') cdef int[::1, :, :] fort_contig = array((1,1,1), sizeof(int), 'i', mode='fortran')
print fort_contig.is_c_contig() , fort_contig.is_f_contig() print fort_contig.is_c_contig() , fort_contig.is_f_contig()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment