Commit cbd03841 authored by Xavier Thompson's avatar Xavier Thompson

Fix 'consume' operations with typecast operands

parent a09e8c79
...@@ -11378,6 +11378,7 @@ class ConsumeNode(ExprNode): ...@@ -11378,6 +11378,7 @@ class ConsumeNode(ExprNode):
# generate_runtime_check boolean used internally # generate_runtime_check boolean used internally
# check_refcount_only boolean used internally # check_refcount_only boolean used internally
# operand_is_named boolean used internally # operand_is_named boolean used internally
# solid_operand ExprNode used internally
subexprs = ['operand'] subexprs = ['operand']
...@@ -11409,10 +11410,16 @@ class ConsumeNode(ExprNode): ...@@ -11409,10 +11410,16 @@ class ConsumeNode(ExprNode):
self.generate_runtime_check = True self.generate_runtime_check = True
self.check_refcount_only = False self.check_refcount_only = False
self.type = PyrexTypes.cyp_class_qualified_type(operand_type, 'iso~') self.type = PyrexTypes.cyp_class_qualified_type(operand_type, 'iso~')
self.operand_is_named = self.operand.is_name or self.operand.is_attribute solid_operand = self.operand
self.is_temp = self.operand_is_named or (self.generate_runtime_check and not self.operand.is_temp) while isinstance(solid_operand, TypecastNode) and not solid_operand.is_temp:
if not solid_operand.operand.type.is_cyp_class:
break
solid_operand = solid_operand.operand
self.operand_is_named = solid_operand.is_name or solid_operand.is_attribute
self.is_temp = self.operand_is_named or (self.generate_runtime_check and not solid_operand.is_temp)
self.solid_operand = solid_operand
if self.operand_is_named: if self.operand_is_named:
self.operand.entry.is_consumed = True solid_operand.entry.is_consumed = True
return self return self
def may_be_none(self): def may_be_none(self):
...@@ -11453,7 +11460,7 @@ class ConsumeNode(ExprNode): ...@@ -11453,7 +11460,7 @@ class ConsumeNode(ExprNode):
code.error_goto(self.pos)) code.error_goto(self.pos))
code.putln("}") code.putln("}")
if self.operand_is_named: if self.operand_is_named:
code.putln("%s = NULL;" % self.operand.result()) code.putln("%s = NULL;" % self.solid_operand.result())
def generate_post_assignment_code(self, code): def generate_post_assignment_code(self, code):
if self.is_temp: if self.is_temp:
......
...@@ -82,6 +82,37 @@ def test_nogil_consume_aliased_leaf(): ...@@ -82,6 +82,37 @@ def test_nogil_consume_aliased_leaf():
return 0 return 0
cdef cypclass Convertible:
Leaf __Leaf__(self):
return Leaf()
def test_consume_isolated_cast_named_leaf():
"""
>>> test_consume_isolated_cast_named_leaf()
0
"""
leaf = Leaf()
try:
l = consume <Leaf> leaf
if leaf is not NULL:
return -1
return 0
except TypeError as e:
print(e)
return -2
def test_consume_isolated_cast_converted_leaf():
"""
>>> test_consume_isolated_cast_converted_leaf()
0
"""
try:
l = consume <Leaf> Convertible()
return 0
except TypeError as e:
print(e)
return -2
cdef cypclass Field: cdef cypclass Field:
Field foo(self, Field other): Field foo(self, Field other):
return other return other
......
...@@ -144,3 +144,114 @@ def test_consume_and_drop_field(): ...@@ -144,3 +144,114 @@ def test_consume_and_drop_field():
print("consumed") print("consumed")
return 0 return 0
def test_consume_cast_name():
"""
>>> test_consume_cast_name()
Refcounted destroyed
0
"""
r0 = Refcounted()
if Cy_GETREF(r0) != 2:
return -1
cdef Refcounted r1 = consume <Refcounted> r0
if r0 is not NULL:
return -2
if Cy_GETREF(r1) != 2:
return -3
return 0
def test_consume_cast_constructed():
"""
>>> test_consume_cast_constructed()
Refcounted destroyed
0
"""
cdef Refcounted r = consume <Refcounted> Refcounted()
if Cy_GETREF(r) != 2:
return -1
return 0
def test_consume_cast_field():
"""
>>> test_consume_cast_field()
Refcounted destroyed
0
"""
cdef Refcounted r = consume <Refcounted> Origin().field
if Cy_GETREF(r) != 2:
return -1
return 0
cdef cypclass Convertible:
Refcounted __Refcounted__(self):
return Refcounted()
__dealloc__(self) with gil:
print("Convertible destroyed")
def test_consume_converted_name():
"""
>>> test_consume_converted_name()
Convertible destroyed
Refcounted destroyed
0
"""
c = Convertible()
if Cy_GETREF(c) != 2:
return -1
cdef Refcounted r = consume <Refcounted> c
if c is NULL:
return -2
if Cy_GETREF(c) != 2:
return -3
if Cy_GETREF(r) != 2:
return -4
del c
return 0
def test_consume_converted_constructed():
"""
>>> test_consume_converted_constructed()
Convertible destroyed
Refcounted destroyed
0
"""
cdef Refcounted r = consume <Refcounted> Convertible()
if Cy_GETREF(r) != 2:
return -1
return 0
cdef cypclass OriginConvertible:
Convertible field
__init__(self):
self.field = Convertible()
def test_consume_converted_field():
"""
>>> test_consume_converted_field()
Convertible destroyed
Refcounted destroyed
0
"""
o = OriginConvertible()
if Cy_GETREF(o.field) != 2:
return -1
cdef Refcounted r = consume <Refcounted> o.field
if o.field is NULL:
return -2
if Cy_GETREF(o.field) != 2:
return -3
if Cy_GETREF(r) != 2:
return -4
return 0
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment