Commit c37bf1b0 authored by Xavier Thompson's avatar Xavier Thompson

Enforce const-correctness when locally aliasing const cypclass references

parent 238f9c24
...@@ -492,6 +492,7 @@ class CypclassLockTransform(Visitor.EnvTransform): ...@@ -492,6 +492,7 @@ class CypclassLockTransform(Visitor.EnvTransform):
def __call__(self, root): def __call__(self, root):
self.rlocked = defaultdict(int) self.rlocked = defaultdict(int)
self.wlocked = defaultdict(int) self.wlocked = defaultdict(int)
self.const = defaultdict(int)
self.reading = False self.reading = False
self.writing = False self.writing = False
self.deleting = False self.deleting = False
...@@ -542,8 +543,8 @@ class CypclassLockTransform(Visitor.EnvTransform): ...@@ -542,8 +543,8 @@ class CypclassLockTransform(Visitor.EnvTransform):
return written_node return written_node
ref_id = self.reference_identifier(written_node) ref_id = self.reference_identifier(written_node)
if ref_id: if ref_id:
if ref_id.type.is_const: if self.const[ref_id] > 0:
error(written_node.pos, "Reference '%s' is const but requires a write lock" % self.id_to_name(ref_id)) error(written_node.pos, "Local reference '%s' is const but requires a write lock" % self.id_to_name(ref_id))
return written_node return written_node
if not self.wlocked[ref_id] > 0: if not self.wlocked[ref_id] > 0:
if lock_mode == "checklock": if lock_mode == "checklock":
...@@ -580,14 +581,28 @@ class CypclassLockTransform(Visitor.EnvTransform): ...@@ -580,14 +581,28 @@ class CypclassLockTransform(Visitor.EnvTransform):
# else: should have caused a previous error # else: should have caused a previous error
return rhs return rhs
def mark_const_local_alias(self, lhs, rhs):
lhs_ref_id = self.reference_identifier(lhs)
rhs_ref_id = self.reference_identifier(rhs)
if not lhs_ref_id or not rhs_ref_id:
return
if self.const[rhs_ref_id] and lhs_ref_id.is_local:
self.const[lhs_ref_id] = 1
def visit_CFuncDefNode(self, node): def visit_CFuncDefNode(self, node):
cyp_class_args = (e for e in node.local_scope.arg_entries if e.type.is_cyp_class) cyp_class_args = (e for e in node.local_scope.arg_entries if e.type.is_cyp_class)
arg_locks = [] arg_locks = []
old_const = self.const.copy()
for arg in cyp_class_args: for arg in cyp_class_args:
# Mark each cypclass arguments as locked within the function body # Mark each cypclass arguments as locked within the function body
arg_locks.append(self.stacklock(arg, "rlocked" if arg.type.is_const else "wlocked")) if arg.type.is_const:
arg_locks.append(self.stacklock(arg, "rlocked"))
self.const[arg] = 1
else:
arg_locks.append(self.stacklock(arg, "wlocked"))
with_body = lambda: self.visit(node.body) with_body = lambda: self.visit(node.body)
self.with_nested_stacklocks(iter(arg_locks), with_body) self.with_nested_stacklocks(iter(arg_locks), with_body)
self.const = old_const
return node return node
def visit_LockCypclassNode(self, node): def visit_LockCypclassNode(self, node):
...@@ -627,6 +642,7 @@ class CypclassLockTransform(Visitor.EnvTransform): ...@@ -627,6 +642,7 @@ class CypclassLockTransform(Visitor.EnvTransform):
# Disallow re-binding a locked name # Disallow re-binding a locked name
error(node.lhs.pos, "Assigning to a locked cypclass reference") error(node.lhs.pos, "Assigning to a locked cypclass reference")
return node return node
self.mark_const_local_alias(node.lhs, node.rhs)
node.rhs = self.lockcheck_if_subscript_rhs(node.lhs, node.rhs) node.rhs = self.lockcheck_if_subscript_rhs(node.lhs, node.rhs)
with self.accesscontext(writing=True): with self.accesscontext(writing=True):
self.visit(node.lhs) self.visit(node.lhs)
...@@ -641,6 +657,8 @@ class CypclassLockTransform(Visitor.EnvTransform): ...@@ -641,6 +657,8 @@ class CypclassLockTransform(Visitor.EnvTransform):
# Disallow re-binding a locked name # Disallow re-binding a locked name
error(lhs.pos, "Assigning to a locked cypclass reference") error(lhs.pos, "Assigning to a locked cypclass reference")
return node return node
for lhs in node.lhs_list:
self.mark_const_local_alias(lhs, node.rhs)
for lhs in node.lhs_list: for lhs in node.lhs_list:
node.rhs = self.lockcheck_if_subscript_rhs(lhs, node.rhs) node.rhs = self.lockcheck_if_subscript_rhs(lhs, node.rhs)
with self.accesscontext(writing=True): with self.accesscontext(writing=True):
......
...@@ -39,6 +39,13 @@ cdef take_non_const(A a): ...@@ -39,6 +39,13 @@ cdef take_non_const(A a):
cdef take_const(const A a): cdef take_const(const A a):
take_non_const(a) take_non_const(a)
cdef propagate_local_const(const A a):
cdef A b = a
c = a
take_non_const(b)
take_non_const(c)
_ERRORS = u""" _ERRORS = u"""
20:4: Reference 'obj' is not correctly locked in this expression (write lock required) 20:4: Reference 'obj' is not correctly locked in this expression (write lock required)
21:4: Reference 'obj' is not correctly locked in this expression (read lock required) 21:4: Reference 'obj' is not correctly locked in this expression (read lock required)
...@@ -47,5 +54,7 @@ _ERRORS = u""" ...@@ -47,5 +54,7 @@ _ERRORS = u"""
25:4: Reference 'obj' is not correctly locked in this expression (read lock required) 25:4: Reference 'obj' is not correctly locked in this expression (read lock required)
26:21: Reference 'obj' is not correctly locked in this expression (read lock required) 26:21: Reference 'obj' is not correctly locked in this expression (read lock required)
32:17: Can only lock local variables or arguments 32:17: Can only lock local variables or arguments
40:19: Reference 'a' is const but requires a write lock 40:19: Local reference 'a' is const but requires a write lock
46:19: Local reference 'b' is const but requires a write lock
47:19: Local reference 'c' is const but requires a write lock
""" """
...@@ -51,3 +51,43 @@ def test_lock_traversal(n): ...@@ -51,3 +51,43 @@ def test_lock_traversal(n):
with wlocked contained: with wlocked contained:
argument_recursivity(contained, n) argument_recursivity(contained, n)
print contained.getter() print contained.getter()
cdef Container global_container
cdef int non_const_aliasing(const A a):
global global_container
global_container = Container()
global_container.object = a
b = global_container.object
with wlocked b:
b.setter(42)
return b.getter()
def test_non_const_aliasing():
"""
>>> test_non_const_aliasing()
42
"""
a = A()
with rlocked a:
return non_const_aliasing(a)
cdef A global_a
cdef int non_const_global_aliasing(const A a):
global global_a
global_a = a
b = global_a
with wlocked b:
b.setter(42)
return b.getter()
def test_non_const_global_aliasing():
"""
>>> test_non_const_global_aliasing()
42
"""
global global_a
a = A()
with rlocked a:
return non_const_global_aliasing(a)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment