Commit 96f819bb authored by Jim Fulton's avatar Jim Fulton

Added a lock for calling cache methods to avoid a race condition

between calls from the storage and from the out-of-band
invalidation message handler.
parent ec9f5c8f
...@@ -144,16 +144,23 @@ file 0 and file 1. ...@@ -144,16 +144,23 @@ file 0 and file 1.
""" """
__version__ = "$Revision: 1.11 $"[11:-2] __version__ = "$Revision: 1.12 $"[11:-2]
import os, tempfile import os, tempfile
from struct import pack, unpack from struct import pack, unpack
from thread import allocate_lock
magic='ZEC0' magic='ZEC0'
class ClientCache: class ClientCache:
def __init__(self, storage='', size=20000000, client=None, var=None): def __init__(self, storage='', size=20000000, client=None, var=None):
# Allocate locks:
l=allocate_lock()
self._acquire=l.acquire
self._release=l.release
if client: if client:
# Create a persistent cache # Create a persistent cache
if var is None: var=os.path.join(INSTANCE_HOME,'var') if var is None: var=os.path.join(INSTANCE_HOME,'var')
...@@ -199,6 +206,8 @@ class ClientCache: ...@@ -199,6 +206,8 @@ class ClientCache:
self._current=current self._current=current
def open(self): def open(self):
self._acquire()
try:
self._index=index={} self._index=index={}
self._get=index.get self._get=index.get
serial={} serial={}
...@@ -209,8 +218,11 @@ class ClientCache: ...@@ -209,8 +218,11 @@ class ClientCache:
self._pos=read_index(index, serial, f[current], current) self._pos=read_index(index, serial, f[current], current)
return serial.items() return serial.items()
finally: self._release()
def invalidate(self, oid, version): def invalidate(self, oid, version):
self._acquire()
try:
p=self._get(oid, None) p=self._get(oid, None)
if p is None: return None if p is None: return None
f=self._f[p < 0] f=self._f[p < 0]
...@@ -224,8 +236,11 @@ class ClientCache: ...@@ -224,8 +236,11 @@ class ClientCache:
else: else:
del self._index[oid] del self._index[oid]
f.write('i') f.write('i')
finally: self._release()
def load(self, oid, version): def load(self, oid, version):
self._acquire()
try:
p=self._get(oid, None) p=self._get(oid, None)
if p is None: return None if p is None: return None
f=self._f[p < 0] f=self._f[p < 0]
...@@ -261,13 +276,16 @@ class ClientCache: ...@@ -261,13 +276,16 @@ class ClientCache:
dlen=unpack(">i", read(4))[0] dlen=unpack(">i", read(4))[0]
return read(dlen), read(8) return read(dlen), read(8)
finally: self._release()
def update(self, oid, serial, version, data): def update(self, oid, serial, version, data):
self._acquire()
try:
if version: if version:
# We need to find and include non-version data # We need to find and include non-version data
p=self._get(oid, None) p=self._get(oid, None)
if p is None: if p is None:
return self.store(oid, '', '', version, data, serial) return self._store(oid, '', '', version, data, serial)
f=self._f[p < 0] f=self._f[p < 0]
ap=abs(p) ap=abs(p)
seek=f.seek seek=f.seek
...@@ -277,23 +295,26 @@ class ClientCache: ...@@ -277,23 +295,26 @@ class ClientCache:
if len(h)==27 and h[8] in 'nv' and h[:8]==oid: if len(h)==27 and h[8] in 'nv' and h[:8]==oid:
tlen, vlen, dlen = unpack(">iHi", h[9:19]) tlen, vlen, dlen = unpack(">iHi", h[9:19])
else: else:
return self.store(oid, '', '', version, data, serial) return self._store(oid, '', '', version, data, serial)
if tlen <= 0 or vlen < 0 or dlen <= 0 or vlen+dlen > tlen: if tlen <= 0 or vlen < 0 or dlen <= 0 or vlen+dlen > tlen:
return self.store(oid, '', '', version, data, serial) return self._store(oid, '', '', version, data, serial)
if dlen: if dlen:
p=read(dlen) p=read(dlen)
s=h[19:] s=h[19:]
else: else:
return self.store(oid, '', '', version, data, serial) return self._store(oid, '', '', version, data, serial)
self.store(oid, p, s, version, data, serial) self._store(oid, p, s, version, data, serial)
else: else:
# Simple case, just store new data: # Simple case, just store new data:
self.store(oid, data, serial, '', None, None) self._store(oid, data, serial, '', None, None)
finally: self._release()
def modifiedInVersion(self, oid): def modifiedInVersion(self, oid):
self._acquire()
try:
p=self._get(oid, None) p=self._get(oid, None)
if p is None: return None if p is None: return None
f=self._f[p < 0] f=self._f[p < 0]
...@@ -314,8 +335,11 @@ class ClientCache: ...@@ -314,8 +335,11 @@ class ClientCache:
if not vlen: return '' if not vlen: return ''
seek(dlen, 1) seek(dlen, 1)
return read(vlen) return read(vlen)
finally: self._release()
def checkSize(self, size): def checkSize(self, size):
self._acquire()
try:
# Make sure we aren't going to exceed the target size. # Make sure we aren't going to exceed the target size.
# If we are, then flip the cache. # If we are, then flip the cache.
if self._pos+size > self._limit: if self._pos+size > self._limit:
...@@ -336,9 +360,15 @@ class ClientCache: ...@@ -336,9 +360,15 @@ class ClientCache:
self._f[current] = tempfile.TemporaryFile(suffix='.zec') self._f[current] = tempfile.TemporaryFile(suffix='.zec')
self._f[current].write(magic) self._f[current].write(magic)
self._pos=pos=4 self._pos=pos=4
finally: self._release()
def store(self, oid, p, s, version, pv, sv): def store(self, oid, p, s, version, pv, sv):
self._acquire()
try: self._store(oid, p, s, version, pv, sv)
finally: self._release()
def _store(self, oid, p, s, version, pv, sv):
if not s: if not s:
p='' p=''
s='\0\0\0\0\0\0\0\0' s='\0\0\0\0\0\0\0\0'
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment