Commit 5237c20f authored by Guillaume Hervier's avatar Guillaume Hervier

Refactor pool result queue to limit size

parent cebfe1e5
...@@ -35,8 +35,7 @@ def get_signature_vf(rorp): ...@@ -35,8 +35,7 @@ def get_signature_vf(rorp):
size = os.path.getsize(rorp.path) size = os.path.getsize(rorp.path)
file2 = open(rorp.path, 'rb') file2 = open(rorp.path, 'rb')
signature_fp = librsync.SigFile(file2, find_blocksize(size)) signature_fp = librsync.SigFile(file2, find_blocksize(size))
vf = connection.VirtualFile.new(signature_fp) return signature_fp
return vf
def find_blocksize(file_len): def find_blocksize(file_len):
"""Return a reasonable block size to use on files of length file_len """Return a reasonable block size to use on files of length file_len
......
...@@ -22,9 +22,10 @@ from multiprocessing import cpu_count ...@@ -22,9 +22,10 @@ from multiprocessing import cpu_count
import Queue import Queue
import itertools import itertools
import threading import threading
import sync
Job = namedtuple('Job', ['func', 'iterable', 'outqueue', 'options']) Job = namedtuple('Job', ['func', 'iterable', 'outqueue', 'options'])
Task = namedtuple('Task', ['func', 'args', 'index', 'outqueue', 'options']) Task = namedtuple('Task', ['func', 'args', 'out', 'options'])
Result = namedtuple('Result', ['index', 'value']) Result = namedtuple('Result', ['index', 'value'])
...@@ -35,18 +36,19 @@ def worker(taskqueue): ...@@ -35,18 +36,19 @@ def worker(taskqueue):
taskqueue.task_done() taskqueue.task_done()
break break
if task.func is None: if callable(task.out):
# It means this task was the last of an iterable job args = task.options.get('out_args', ())
result = None out = task.out(*args)
else: else:
try: out = task.out
value = task.func(*task.args)
except Exception as e:
value = e
result = Result(task.index, value)
task.outqueue.put(result, block=True) try:
value = task.func(*task.args)
except Exception as e:
value = e
result = value
out.set(result)
taskqueue.task_done() taskqueue.task_done()
...@@ -66,9 +68,7 @@ class Pool(object): ...@@ -66,9 +68,7 @@ class Pool(object):
self.start_workers() self.start_workers()
# Init task handler thread # Init task handler thread
self._job_handler_thread = self._start_handler_thread(self._job_handler, self._job_handler_thread = self._start_handler_thread(self._job_handler)
self.jobqueue,
self.taskqueue)
def start_workers(self): def start_workers(self):
while len(self.workers) < self.processes: while len(self.workers) < self.processes:
...@@ -88,13 +88,21 @@ class Pool(object): ...@@ -88,13 +88,21 @@ class Pool(object):
def create_job(self, func, iterable, **options): def create_job(self, func, iterable, **options):
max_outqueue_size = options.pop('max_outqueue_size', 0) max_outqueue_size = options.pop('max_outqueue_size', 0)
outqueue = Queue.Queue(maxsize=max_outqueue_size) outqueue = JobResultQueue(maxsize=max_outqueue_size)
job = Job(func, iterable, outqueue, options) job = Job(func, iterable, outqueue, options)
self.jobqueue.put(job) self.jobqueue.put(job, block=True)
return job return job
def create_task(self, func, args, out=None, **options):
if out is None:
out = sync.AsyncValue()
task = Task(func, args, out, options)
self.taskqueue.put(task, block=True)
return task
def imap(self, func, iterable, **options): def imap(self, func, iterable, **options):
iterable = itertools.imap(None, iterable) iterable = itertools.imap(None, iterable)
...@@ -104,9 +112,9 @@ class Pool(object): ...@@ -104,9 +112,9 @@ class Pool(object):
return IMapIterator(job.outqueue) return IMapIterator(job.outqueue)
def apply(self, func, *args, **options): def apply(self, func, *args, **options):
job = self.create_job(func, [args], **options) task = self.create_task(func, args, **options)
return AsyncResult(job.outqueue) return task.out
def stop(self): def stop(self):
self.jobqueue.put(None, block=True) self.jobqueue.put(None, block=True)
...@@ -117,20 +125,22 @@ class Pool(object): ...@@ -117,20 +125,22 @@ class Pool(object):
for w in self.workers: for w in self.workers:
w.join(timeout=timeout) w.join(timeout=timeout)
def _job_handler(self, jobqueue, taskqueue): def _job_handler(self):
while True: while True:
job = jobqueue.get(True) job = self.jobqueue.get(True)
if job is None: if job is None:
for w in self.workers: for w in self.workers:
taskqueue.put(None) self.taskqueue.put(None)
break break
for (index, args) in enumerate(job.iterable): for (index, args) in enumerate(job.iterable):
task = Task(job.func, args, index, job.outqueue, job.options) out = job.outqueue.slot(index)
taskqueue.put(task, block=True) task = self.create_task(job.func, args, out, **job.options)
taskqueue.put(Task(None, None, None, job.outqueue, job.options), block=True) # task = self.create_task(job.func, args, job.outqueue.slot,
# out_args=(index,), **job.options)
job.outqueue.set_length(index + 1)
jobqueue.task_done() self.jobqueue.task_done()
class IMapIterator(object): class IMapIterator(object):
def __init__(self, outqueue): def __init__(self, outqueue):
...@@ -143,34 +153,55 @@ class IMapIterator(object): ...@@ -143,34 +153,55 @@ class IMapIterator(object):
def next(self): def next(self):
while True: while True:
if self.index in self.results: slot = self.outqueue.get(self.index)
result = self.results.pop(self.index) if slot is None:
else: raise StopIteration()
result = self.outqueue.get(True)
if result is None:
raise StopIteration()
if isinstance(result.value, Exception):
raise result.value
if result.index != self.index:
self.results[result.index] = result
continue
self.index += 1 result = slot.get()
return result.value if isinstance(result, Exception):
raise result
class AsyncResult(object): self.index += 1
def __init__(self, outqueue): return result
self.outqueue = outqueue
self.completed = False class JobResultQueue(object):
self.value = None def __init__(self, maxsize=None):
self.maxsize = maxsize
def wait(self):
if self.completed: self.slots = {}
return self.updated = threading.Condition(threading.Lock())
self.value = self.outqueue.get(True) self.length = None
self.completed = True
def slot(self, index):
def get(self): self.updated.acquire()
self.wait() try:
return self.value while len(self.slots) == self.maxsize:
self.updated.wait()
slot = self.slots[index] = sync.AsyncValue()
self.updated.notify()
return slot
finally:
self.updated.release()
def get(self, index):
self.updated.acquire()
try:
while index not in self.slots:
if self.length is not None and index >= self.length:
return None
self.updated.wait()
slot = self.slots.pop(index)
self.updated.notify()
return slot
finally:
self.updated.release()
def set_length(self, length):
self.updated.acquire()
try:
self.length = length
self.updated.notify()
finally:
self.updated.release()
...@@ -49,7 +49,6 @@ def Restore(mirror_rp, inc_rpath, target, restore_to_time): ...@@ -49,7 +49,6 @@ def Restore(mirror_rp, inc_rpath, target, restore_to_time):
TargetS.patch(target, diff_iter) TargetS.patch(target, diff_iter)
pool.stop()
pool.join() pool.join()
MirrorS.close_rf_cache() MirrorS.close_rf_cache()
...@@ -293,8 +292,7 @@ class MirrorStruct: ...@@ -293,8 +292,7 @@ class MirrorStruct:
file_fp = cls.rf_cache.get_fp(expanded_index, mir_rorp) file_fp = cls.rf_cache.get_fp(expanded_index, mir_rorp)
if (target_rorp): if (target_rorp):
# a file is already there, we can create a diff # a file is already there, we can create a diff
signature_vf = target_rorp.conn.Rdiff.get_signature_vf(target_rorp) target_signature = target_rorp.conn.Rdiff.get_signature_vf(target_rorp)
target_signature = connection.VirtualFile(target_rorp.conn, signature_vf)
file_fobj = opener.lazy_open(file_fp.name, 'rb') file_fobj = opener.lazy_open(file_fp.name, 'rb')
delta_fp = mir_rorp.get_delta(target_signature, file_fobj) delta_fp = mir_rorp.get_delta(target_signature, file_fobj)
mir_rorp.setfile(delta_fp) mir_rorp.setfile(delta_fp)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment