Commit 766bfd93 authored by Kurt Smith's avatar Kurt Smith Committed by Mark Florisson

correct reference handling for memoryviewslices.

parent 985d50f3
...@@ -22,7 +22,7 @@ import Nodes ...@@ -22,7 +22,7 @@ import Nodes
from Nodes import Node from Nodes import Node
import PyrexTypes import PyrexTypes
from PyrexTypes import py_object_type, c_long_type, typecast, error_type, \ from PyrexTypes import py_object_type, c_long_type, typecast, error_type, \
unspecified_type unspecified_type, cython_memoryview_ptr_type
import TypeSlots import TypeSlots
from Builtin import list_type, tuple_type, set_type, dict_type, \ from Builtin import list_type, tuple_type, set_type, dict_type, \
unicode_type, str_type, bytes_type, type_type unicode_type, str_type, bytes_type, type_type
...@@ -1700,11 +1700,28 @@ class NameNode(AtomicExprNode): ...@@ -1700,11 +1700,28 @@ class NameNode(AtomicExprNode):
rhs.free_temps(code) rhs.free_temps(code)
def generate_acquire_memoryviewslice(self, rhs, code): def generate_acquire_memoryviewslice(self, rhs, code):
# to explicitly manange the memviewslice.memview object correctly.
import MemoryView import MemoryView
assert rhs.type.is_memoryviewslice
if not rhs.result_in_temp():
code.put_incref("%s.memview" % rhs.result(), cython_memoryview_ptr_type)
if self.entry.is_cglobal:
code.put_gotref("%s.memview" % self.result())
if not self.lhs_of_first_assignment:
if self.entry.is_local and not Options.init_local_none:
code.put_xdecref("%s.memview" % self.result(), cython_memoryview_ptr_type)
else:
code.put_decref("%s.memview" % self.result(), cython_memoryview_ptr_type)
if self.entry.is_cglobal:
code.put_giveref("%s.memview" % rhs.result())
MemoryView.put_assign_to_memviewslice(self.result(), rhs.result(), self.type, MemoryView.put_assign_to_memviewslice(self.result(), rhs.result(), self.type,
pos=self.pos, code=code) pos=self.pos, code=code)
if rhs.is_temp: if rhs.result_in_temp():
code.put_xdecref_clear("%s.memview" % rhs.result(), py_object_type) code.putln("%s.memview = 0;" % rhs.result())
def generate_acquire_buffer(self, rhs, code): def generate_acquire_buffer(self, rhs, code):
# rhstmp is only used in case the rhs is a complicated expression leading to # rhstmp is only used in case the rhs is a complicated expression leading to
...@@ -3949,7 +3966,7 @@ class AttributeNode(ExprNode): ...@@ -3949,7 +3966,7 @@ class AttributeNode(ExprNode):
MemoryView.put_assign_to_memviewslice(select_code, rhs.result(), self.type, MemoryView.put_assign_to_memviewslice(select_code, rhs.result(), self.type,
pos=self.pos, code=code) pos=self.pos, code=code)
if rhs.is_temp: if rhs.is_temp:
code.put_xdecref_clear("%s.memview" % rhs.result(), py_object_type) code.put_xdecref_clear("%s.memview" % rhs.result(), cython_memoryview_ptr_type)
if not self.type.is_memoryviewslice: if not self.type.is_memoryviewslice:
code.putln( code.putln(
"%s = %s;" % ( "%s = %s;" % (
...@@ -7620,7 +7637,7 @@ class CoerceToMemViewSliceNode(CoercionNode): ...@@ -7620,7 +7637,7 @@ class CoerceToMemViewSliceNode(CoercionNode):
code.putln("__pyx_viewaxis_init_memviewslice_from_memview" code.putln("__pyx_viewaxis_init_memviewslice_from_memview"
"((struct __pyx_obj_memoryview *)%s, %s, %d, sizeof(%s), \"%s\", &%s);" %\ "((struct __pyx_obj_memoryview *)%s, %s, %d, sizeof(%s), \"%s\", &%s);" %\
(memviewobj, spec_int_arr, ndim, itemsize, format, self.result())) (memviewobj, spec_int_arr, ndim, itemsize, format, self.result()))
code.put_gotref("%s.memview" % self.result()) code.put_gotref(code.as_pyobject("%s.memview" % self.result(), cython_memoryview_ptr_type))
code.funcstate.release_temp(memviewobj) code.funcstate.release_temp(memviewobj)
code.funcstate.release_temp(spec_int_arr) code.funcstate.release_temp(spec_int_arr)
......
...@@ -5,7 +5,7 @@ import Options ...@@ -5,7 +5,7 @@ import Options
import CythonScope import CythonScope
from Code import UtilityCode from Code import UtilityCode
from UtilityCode import CythonUtilityCode from UtilityCode import CythonUtilityCode
from PyrexTypes import py_object_type, cython_memoryview_type from PyrexTypes import py_object_type, cython_memoryview_ptr_type
START_ERR = "there must be nothing or the value 0 (zero) in the start slot." START_ERR = "there must be nothing or the value 0 (zero) in the start slot."
STOP_ERR = "Axis specification only allowed in the 'stop' slot." STOP_ERR = "Axis specification only allowed in the 'stop' slot."
...@@ -67,17 +67,11 @@ def format_from_type(base_type): ...@@ -67,17 +67,11 @@ def format_from_type(base_type):
def put_init_entry(mv_cname, code): def put_init_entry(mv_cname, code):
code.putln("%s.data = NULL;" % mv_cname) code.putln("%s.data = NULL;" % mv_cname)
code.put_init_to_py_none("%s.memview" % mv_cname, cython_memoryview_type) code.put_init_to_py_none("%s.memview" % mv_cname, cython_memoryview_ptr_type)
code.put_giveref("%s.memview" % mv_cname) code.put_giveref(code.as_pyobject("%s.memview" % mv_cname, cython_memoryview_ptr_type))
def put_assign_to_memviewslice(lhs_cname, rhs_cname, memviewslicetype, pos, code): def put_assign_to_memviewslice(lhs_cname, rhs_cname, memviewslicetype, pos, code):
# XXX: add error checks!
code.put_giveref("%s.memview" % (rhs_cname))
code.put_incref("%s.memview" % (rhs_cname), py_object_type)
code.put_gotref("%s.memview" % (lhs_cname))
code.put_xdecref("%s.memview" % (lhs_cname), py_object_type)
code.putln("%s.memview = %s.memview;" % (lhs_cname, rhs_cname)) code.putln("%s.memview = %s.memview;" % (lhs_cname, rhs_cname))
code.putln("%s.data = %s.data;" % (lhs_cname, rhs_cname)) code.putln("%s.data = %s.data;" % (lhs_cname, rhs_cname))
ndim = len(memviewslicetype.axes) ndim = len(memviewslicetype.axes)
......
...@@ -1237,7 +1237,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1237,7 +1237,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
for entry in memviewslice_attrs: for entry in memviewslice_attrs:
code.putln("p->%s.data = NULL;" % entry.cname) code.putln("p->%s.data = NULL;" % entry.cname)
code.put_init_to_py_none("p->%s.memview" % entry.cname, code.put_init_to_py_none("p->%s.memview" % entry.cname,
PyrexTypes.cython_memoryview_type, nanny=False) PyrexTypes.cython_memoryview_ptr_type, nanny=False)
entry = scope.lookup_here("__new__") entry = scope.lookup_here("__new__")
if entry and entry.is_special: if entry and entry.is_special:
if entry.trivial_signature: if entry.trivial_signature:
......
...@@ -17,7 +17,7 @@ from Errors import error, warning, InternalError, CompileError ...@@ -17,7 +17,7 @@ from Errors import error, warning, InternalError, CompileError
import Naming import Naming
import PyrexTypes import PyrexTypes
import TypeSlots import TypeSlots
from PyrexTypes import py_object_type, error_type, CFuncType from PyrexTypes import py_object_type, error_type, CTypedefType, CFuncType, cython_memoryview_ptr_type
from Symtab import ModuleScope, LocalScope, ClosureScope, \ from Symtab import ModuleScope, LocalScope, ClosureScope, \
StructOrUnionScope, PyClassScope, CClassScope, CppClassScope StructOrUnionScope, PyClassScope, CClassScope, CppClassScope
from Cython.Utils import open_new_file, replace_suffix from Cython.Utils import open_new_file, replace_suffix
...@@ -1424,12 +1424,14 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1424,12 +1424,14 @@ class FuncDefNode(StatNode, BlockNode):
if entry.type.is_pyobject: if entry.type.is_pyobject:
if (acquire_gil or entry.assignments) and not entry.in_closure: if (acquire_gil or entry.assignments) and not entry.in_closure:
code.put_var_incref(entry) code.put_var_incref(entry)
if entry.type.is_memoryviewslice:
code.put_incref("%s.memview" % entry.cname, cython_memoryview_ptr_type)
# ----- Initialise local buffer auxiliary variables # ----- Initialise local buffer auxiliary variables
for entry in lenv.var_entries + lenv.arg_entries: for entry in lenv.var_entries + lenv.arg_entries:
if entry.type.is_buffer and entry.buffer_aux.buflocal_nd_var.used: if entry.type.is_buffer and entry.buffer_aux.buflocal_nd_var.used:
Buffer.put_init_vars(entry, code) Buffer.put_init_vars(entry, code)
# ----- Initialise local memoryview slices # ----- Initialise local memoryviewslices
for entry in lenv.var_entries + lenv.arg_entries: for entry in lenv.var_entries:
if entry.type.is_memoryviewslice: if entry.type.is_memoryviewslice:
MemoryView.put_init_entry(entry.cname, code) MemoryView.put_init_entry(entry.cname, code)
# ----- Check and convert arguments # ----- Check and convert arguments
...@@ -1533,14 +1535,20 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1533,14 +1535,20 @@ class FuncDefNode(StatNode, BlockNode):
code.put_label(code.return_from_error_cleanup_label) code.put_label(code.return_from_error_cleanup_label)
for entry in lenv.var_entries: for entry in lenv.var_entries:
if not entry.used or entry.in_closure:
continue
if entry.type.is_memoryviewslice:
code.put_xdecref("%s.memview" % entry.cname, cython_memoryview_ptr_type)
if entry.type.is_pyobject: if entry.type.is_pyobject:
if entry.used and not entry.in_closure:
code.put_var_decref(entry) code.put_var_decref(entry)
# Decref any increfed args # Decref any increfed args
for entry in lenv.arg_entries: for entry in lenv.arg_entries:
if entry.type.is_pyobject: if entry.type.is_pyobject:
if (acquire_gil or entry.assignments) and not entry.in_closure: if (acquire_gil or entry.assignments) and not entry.in_closure:
code.put_var_decref(entry) code.put_var_decref(entry)
if entry.type.is_memoryviewslice:
code.put_decref("%s.memview" % entry.cname, cython_memoryview_ptr_type)
if self.needs_closure: if self.needs_closure:
code.put_decref(Naming.cur_scope_cname, lenv.scope_class.type) code.put_decref(Naming.cur_scope_cname, lenv.scope_class.type)
...@@ -1600,7 +1608,7 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1600,7 +1608,7 @@ class FuncDefNode(StatNode, BlockNode):
def declare_argument(self, env, arg): def declare_argument(self, env, arg):
if arg.type.is_void: if arg.type.is_void:
error(arg.pos, "Invalid use of 'void'") error(arg.pos, "Invalid use of 'void'")
elif not arg.type.is_complete() and not arg.type.is_array: elif not arg.type.is_complete() and not (arg.type.is_array or arg.type.is_memoryviewslice):
error(arg.pos, error(arg.pos,
"Argument type '%s' is incomplete" % arg.type) "Argument type '%s' is incomplete" % arg.type)
return env.declare_arg(arg.name, arg.type, arg.pos) return env.declare_arg(arg.name, arg.type, arg.pos)
......
...@@ -172,7 +172,7 @@ class PyrexType(BaseType): ...@@ -172,7 +172,7 @@ class PyrexType(BaseType):
def global_init_code(self, entry, code): def global_init_code(self, entry, code):
# abstract # abstract
raise NotImplementedError() pass
def public_decl(base_code, dll_linkage): def public_decl(base_code, dll_linkage):
...@@ -374,7 +374,7 @@ class MemoryViewSliceType(PyrexType): ...@@ -374,7 +374,7 @@ class MemoryViewSliceType(PyrexType):
def global_init_code(self, entry, code): def global_init_code(self, entry, code):
code.putln("%s.data = NULL;" % entry.cname) code.putln("%s.data = NULL;" % entry.cname)
code.put_init_to_py_none("%s.memview" % entry.cname, cython_memoryview_type, nanny=False) code.put_init_to_py_none("%s.memview" % entry.cname, cython_memoryview_ptr_type, nanny=False)
class BufferType(BaseType): class BufferType(BaseType):
# #
...@@ -2562,9 +2562,10 @@ c_pyx_buffer_ptr_type = CPtrType(c_pyx_buffer_type) ...@@ -2562,9 +2562,10 @@ c_pyx_buffer_ptr_type = CPtrType(c_pyx_buffer_type)
c_pyx_buffer_nd_type = CStructOrUnionType("__Pyx_LocalBuf_ND", "struct", c_pyx_buffer_nd_type = CStructOrUnionType("__Pyx_LocalBuf_ND", "struct",
None, 1, "__Pyx_LocalBuf_ND") None, 1, "__Pyx_LocalBuf_ND")
cython_memoryview_type = CPtrType(CStructOrUnionType("__pyx_obj_memoryview", "struct", cython_memoryview_type = CStructOrUnionType("__pyx_obj_memoryview", "struct",
None, 0, "__pyx_obj_memoryview")) None, 0, "__pyx_obj_memoryview")
cython_memoryview_ptr_type = CPtrType(cython_memoryview_type)
error_type = ErrorType() error_type = ErrorType()
unspecified_type = UnspecifiedType() unspecified_type = UnspecifiedType()
......
...@@ -7,12 +7,20 @@ u''' ...@@ -7,12 +7,20 @@ u'''
from cython.view cimport memoryview from cython.view cimport memoryview
from cython cimport array, PyBUF_C_CONTIGUOUS from cython cimport array, PyBUF_C_CONTIGUOUS
def init_obj():
return 3
cdef passmvs(float[:,::1] mvs, object foo):
mvs = array((10,10), itemsize=sizeof(float), format='f')
foo = init_obj()
def f(): def f():
cdef array arr = array(shape=(10,10), itemsize=sizeof(int), format='i') cdef array arr = array(shape=(10,10), itemsize=sizeof(int), format='i')
cdef memoryview mv = memoryview(arr, PyBUF_C_CONTIGUOUS) cdef memoryview mv = memoryview(arr, PyBUF_C_CONTIGUOUS)
def g(): def g():
cdef object obj = init_obj()
cdef int[::1] mview = array((10,), itemsize=sizeof(int), format='i') cdef int[::1] mview = array((10,), itemsize=sizeof(int), format='i')
obj = init_obj()
mview = array((10,), itemsize=sizeof(int), format='i') mview = array((10,), itemsize=sizeof(int), format='i')
cdef class Foo: cdef class Foo:
...@@ -33,9 +41,11 @@ cdef cdg(): ...@@ -33,9 +41,11 @@ cdef cdg():
cdef float[:,::1] global_mv = array((10,10), itemsize=sizeof(float), format='f') cdef float[:,::1] global_mv = array((10,10), itemsize=sizeof(float), format='f')
global_mv = array((10,10), itemsize=sizeof(float), format='f') global_mv = array((10,10), itemsize=sizeof(float), format='f')
cdef object global_obj
def call(): def call():
global global_mv global global_mv
passmvs(global_mv, global_obj)
global_mv = array((3,3), itemsize=sizeof(float), format='f') global_mv = array((3,3), itemsize=sizeof(float), format='f')
cdg() cdg()
f = Foo() f = Foo()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment