Commit 9aed0252 authored by Guillaume Hervier's avatar Guillaume Hervier

Implement threading for diffs checking on restore

parent b0371fcc
...@@ -149,6 +149,7 @@ def init_connection(remote_cmd): ...@@ -149,6 +149,7 @@ def init_connection(remote_cmd):
stdin, stdout = os.popen2(remote_cmd) stdin, stdout = os.popen2(remote_cmd)
conn_number = len(Globals.connections) conn_number = len(Globals.connections)
conn = connection.PipeConnection(stdout, stdin, conn_number) conn = connection.PipeConnection(stdout, stdin, conn_number)
conn.Client()
check_connection_version(conn, remote_cmd) check_connection_version(conn, remote_cmd)
Log("Registering connection %d" % conn_number, 7) Log("Registering connection %d" % conn_number, 7)
......
...@@ -22,7 +22,8 @@ ...@@ -22,7 +22,8 @@
from __future__ import generators from __future__ import generators
import types, os, tempfile, cPickle, shutil, traceback, \ import types, os, tempfile, cPickle, shutil, traceback, \
socket, sys, gzip socket, sys, gzip, threading
from pool import Pool
# The following EA and ACL modules may be used if available # The following EA and ACL modules may be used if available
try: import xattr try: import xattr
except ImportError: pass except ImportError: pass
...@@ -115,6 +116,8 @@ class LowLevelPipeConnection(Connection): ...@@ -115,6 +116,8 @@ class LowLevelPipeConnection(Connection):
"""inpipe is a file-type open for reading, outpipe for writing""" """inpipe is a file-type open for reading, outpipe for writing"""
self.inpipe = inpipe self.inpipe = inpipe
self.outpipe = outpipe self.outpipe = outpipe
self.write_lock = threading.RLock()
self.read_lock = threading.RLock()
def __str__(self): def __str__(self):
"""Return string version """Return string version
...@@ -128,6 +131,8 @@ class LowLevelPipeConnection(Connection): ...@@ -128,6 +131,8 @@ class LowLevelPipeConnection(Connection):
def _put(self, obj, req_num): def _put(self, obj, req_num):
"""Put an object into the pipe (will send raw if string)""" """Put an object into the pipe (will send raw if string)"""
self.write_lock.acquire()
log.Log.conn("sending", obj, req_num) log.Log.conn("sending", obj, req_num)
if type(obj) is types.StringType: self._putbuf(obj, req_num) if type(obj) is types.StringType: self._putbuf(obj, req_num)
elif isinstance(obj, connection.Connection):self._putconn(obj, req_num) elif isinstance(obj, connection.Connection):self._putconn(obj, req_num)
...@@ -140,6 +145,8 @@ class LowLevelPipeConnection(Connection): ...@@ -140,6 +145,8 @@ class LowLevelPipeConnection(Connection):
elif hasattr(obj, "next"): self._putiter(obj, req_num) elif hasattr(obj, "next"): self._putiter(obj, req_num)
else: self._putobj(obj, req_num) else: self._putobj(obj, req_num)
self.write_lock.release()
def _putobj(self, obj, req_num): def _putobj(self, obj, req_num):
"""Send a generic python obj down the outpipe""" """Send a generic python obj down the outpipe"""
self._write("o", cPickle.dumps(obj, 1), req_num) self._write("o", cPickle.dumps(obj, 1), req_num)
...@@ -229,7 +236,11 @@ class LowLevelPipeConnection(Connection): ...@@ -229,7 +236,11 @@ class LowLevelPipeConnection(Connection):
def _get(self): def _get(self):
"""Read an object from the pipe and return (req_num, value)""" """Read an object from the pipe and return (req_num, value)"""
self.read_lock.acquire()
header_string = self.inpipe.read(9) header_string = self.inpipe.read(9)
if len(header_string) == 0:
raise ConnectionQuit('EOF')
if not len(header_string) == 9: if not len(header_string) == 9:
raise ConnectionReadError("Truncated header string (problem " raise ConnectionReadError("Truncated header string (problem "
"probably originated remotely)") "probably originated remotely)")
...@@ -251,6 +262,8 @@ class LowLevelPipeConnection(Connection): ...@@ -251,6 +262,8 @@ class LowLevelPipeConnection(Connection):
assert format_string == "c", header_string assert format_string == "c", header_string
result = Globals.connection_dict[int(data)] result = Globals.connection_dict[int(data)]
log.Log.conn("received", result, req_num) log.Log.conn("received", result, req_num)
self.read_lock.release()
return (req_num, result) return (req_num, result)
def _getrorpath(self, raw_rorpath_buf): def _getrorpath(self, raw_rorpath_buf):
...@@ -276,6 +289,53 @@ class LowLevelPipeConnection(Connection): ...@@ -276,6 +289,53 @@ class LowLevelPipeConnection(Connection):
self.inpipe.close() self.inpipe.close()
class RequestNumberRegistry(object):
def __init__(self):
self._lock = threading.RLock()
self._next = 0
self._entries = {}
def get(self):
with self._lock:
if self._next >= 256:
return None
req_num = self._next
self.insert(req_num)
return req_num
def insert(self, req_num):
with self._lock:
if req_num in self._entries:
# Vacant slot
self._next = self._entries[req_num]
else:
self._next += 1
def remove(self, req_num):
with self._lock:
self._entries[req_num] = self._next
self._next = req_num
class AsyncRequest(object):
def __init__(self, req_num):
self.req_num = req_num
self.value = None
self.completed = threading.Event()
def set(self, value):
self.value = value
self.completed.set()
def get(self):
while not self.completed.is_set():
self.completed.wait()
return self.value
class PipeConnection(LowLevelPipeConnection): class PipeConnection(LowLevelPipeConnection):
"""Provide server and client functions for a Pipe Connection """Provide server and client functions for a Pipe Connection
...@@ -287,6 +347,17 @@ class PipeConnection(LowLevelPipeConnection): ...@@ -287,6 +347,17 @@ class PipeConnection(LowLevelPipeConnection):
client makes the first request, and the server listens first. client makes the first request, and the server listens first.
""" """
DISCARDED_RESULTS_FUNCTIONS = [
'log.Log.log_to_file',
'log.Log.close_logfile_allconn',
'rpath.setdata_local',
'Globals.set',
]
RUN_ON_MAIN_THREAD = [
'robust.install_signal_handlers',
]
def __init__(self, inpipe, outpipe, conn_number = 0): def __init__(self, inpipe, outpipe, conn_number = 0):
"""Init PipeConnection """Init PipeConnection
...@@ -298,45 +369,46 @@ class PipeConnection(LowLevelPipeConnection): ...@@ -298,45 +369,46 @@ class PipeConnection(LowLevelPipeConnection):
""" """
LowLevelPipeConnection.__init__(self, inpipe, outpipe) LowLevelPipeConnection.__init__(self, inpipe, outpipe)
self.conn_number = conn_number self.conn_number = conn_number
self.unused_request_numbers = {} self.request_numbers = RequestNumberRegistry()
for i in range(256): self.unused_request_numbers[i] = None self.requests = {}
self.pool = Pool(processes=4,
max_taskqueue_size=16)
self._read_thread = None
def __str__(self): return "PipeConnection %d" % self.conn_number def __str__(self): return "PipeConnection %d" % self.conn_number
def get_response(self, desired_req_num): def read_messages(self):
"""Read from pipe, responding to requests until req_num. while True:
try:
Sometimes after a request is sent, the other side will make req_num, obj = self._get()
another request before responding to the original one. In
that case, respond to the request. But return once the right
response is given.
"""
while 1:
try: req_num, object = self._get()
except ConnectionQuit: except ConnectionQuit:
self._put("quitting", self.get_new_req_num())
self._close() self._close()
return break
if req_num == desired_req_num: return object
else: if isinstance(obj, ConnectionRequest):
assert isinstance(object, ConnectionRequest) args = []
self.answer_request(object, req_num) for _ in range(obj.num_args):
arg_req_num, arg = self._get()
def answer_request(self, request, req_num): assert arg_req_num == req_num
"""Put the object requested by request down the pipe""" args.append(arg)
del self.unused_request_numbers[req_num]
argument_list = [] if Globals.server and obj.function_string in self.RUN_ON_MAIN_THREAD:
for i in range(request.num_args): self.answer_request(obj, args, req_num)
arg_req_num, arg = self._get() else:
assert arg_req_num == req_num self.pool.apply(self.answer_request, obj, args, req_num)
argument_list.append(arg) elif req_num in self.requests:
req = self.requests.pop(req_num)
req.set(obj)
self.request_numbers.remove(req_num)
def answer_request(self, request, args, req_num):
try: try:
Security.vet_request(request, argument_list) Security.vet_request(request, args)
result = apply(eval(request.function_string), argument_list) result = apply(eval(request.function_string), args)
except: result = self.extract_exception() except: result = self.extract_exception()
self._put(result, req_num)
self.unused_request_numbers[req_num] = None if request.function_string not in self.DISCARDED_RESULTS_FUNCTIONS:
self._put(result, req_num)
def extract_exception(self): def extract_exception(self):
"""Return active exception""" """Return active exception"""
...@@ -348,12 +420,24 @@ class PipeConnection(LowLevelPipeConnection): ...@@ -348,12 +420,24 @@ class PipeConnection(LowLevelPipeConnection):
"".join(traceback.format_tb(sys.exc_info()[2]))), 5) "".join(traceback.format_tb(sys.exc_info()[2]))), 5)
return sys.exc_info()[1] return sys.exc_info()[1]
def Client(self):
self._read_thread = read_thread = threading.Thread(target=self.read_messages)
read_thread.daemon = True
read_thread.start()
def Server(self): def Server(self):
"""Start server's read eval return loop""" """Start server's read eval return loop"""
Globals.server = 1 Globals.server = 1
Globals.connections.append(self) Globals.connections.append(self)
log.Log("Starting server", 6) log.Log("Starting server", 6)
self.get_response(-1) # self.get_response(-1)
self.read_messages()
def new_request(self):
req_num = self.get_new_req_num()
req = AsyncRequest(req_num)
self.requests[req_num] = req
return req
def reval(self, function_string, *args): def reval(self, function_string, *args):
"""Execute command on remote side """Execute command on remote side
...@@ -363,11 +447,20 @@ class PipeConnection(LowLevelPipeConnection): ...@@ -363,11 +447,20 @@ class PipeConnection(LowLevelPipeConnection):
function. function.
""" """
req_num = self.get_new_req_num() req = self.new_request()
self._put(ConnectionRequest(function_string, len(args)), req_num)
for arg in args: self._put(arg, req_num) self.write_lock.acquire()
result = self.get_response(req_num) self._put(ConnectionRequest(function_string, len(args)), req.req_num)
self.unused_request_numbers[req_num] = None for arg in args: self._put(arg, req.req_num)
self.write_lock.release()
if function_string in self.DISCARDED_RESULTS_FUNCTIONS:
result = None
del self.requests[req.req_num]
self.request_numbers.remove(req.req_num)
else:
result = req.get()
if isinstance(result, Exception): raise result if isinstance(result, Exception): raise result
elif isinstance(result, SystemExit): raise result elif isinstance(result, SystemExit): raise result
elif isinstance(result, KeyboardInterrupt): raise result elif isinstance(result, KeyboardInterrupt): raise result
...@@ -375,18 +468,19 @@ class PipeConnection(LowLevelPipeConnection): ...@@ -375,18 +468,19 @@ class PipeConnection(LowLevelPipeConnection):
def get_new_req_num(self): def get_new_req_num(self):
"""Allot a new request number and return it""" """Allot a new request number and return it"""
if not self.unused_request_numbers: req_num = self.request_numbers.get()
if req_num is None:
raise ConnectionError("Exhaused possible connection numbers") raise ConnectionError("Exhaused possible connection numbers")
req_num = self.unused_request_numbers.keys()[0]
del self.unused_request_numbers[req_num]
return req_num return req_num
def quit(self): def quit(self):
"""Close the associated pipes and tell server side to quit""" """Close the associated pipes and tell server side to quit"""
assert not Globals.server assert not Globals.server
self._putquit() self._putquit()
self._get() if self._read_thread is not None:
self._close() self._read_thread.join()
self.pool.stop()
self.pool.join()
def __getattr__(self, name): def __getattr__(self, name):
"""Intercept attributes to allow for . invocation""" """Intercept attributes to allow for . invocation"""
......
...@@ -127,7 +127,7 @@ class IterVirtualFile(UnwrapFile): ...@@ -127,7 +127,7 @@ class IterVirtualFile(UnwrapFile):
return_val = self.buffer[:real_len] return_val = self.buffer[:real_len]
self.buffer = self.buffer[real_len:] self.buffer = self.buffer[real_len:]
return return_val return return_val
def addtobuffer(self): def addtobuffer(self):
"""Read a chunk from the file and add it to the buffer""" """Read a chunk from the file and add it to the buffer"""
assert self.iwf.currently_in_file assert self.iwf.currently_in_file
...@@ -335,6 +335,8 @@ class MiscIterToFile(FileWrappingIter): ...@@ -335,6 +335,8 @@ class MiscIterToFile(FileWrappingIter):
elif currentobj is iterfile.MiscIterFlushRepeat: elif currentobj is iterfile.MiscIterFlushRepeat:
self.add_misc(currentobj) self.add_misc(currentobj)
return None return None
elif isinstance(currentobj, rpath.RPath):
self.addrpath(currentobj)
elif isinstance(currentobj, rpath.RORPath): elif isinstance(currentobj, rpath.RORPath):
self.addrorp(currentobj) self.addrorp(currentobj)
else: self.add_misc(currentobj) else: self.add_misc(currentobj)
...@@ -358,7 +360,19 @@ class MiscIterToFile(FileWrappingIter): ...@@ -358,7 +360,19 @@ class MiscIterToFile(FileWrappingIter):
self.array_buf.fromstring("r") self.array_buf.fromstring("r")
self.array_buf.fromstring(C.long2str(long(len(pickle)))) self.array_buf.fromstring(C.long2str(long(len(pickle))))
self.array_buf.fromstring(pickle) self.array_buf.fromstring(pickle)
def addrpath(self, rp):
if rp.file:
data = (rp.conn.conn_number, rp.base, rp.index, rp.data, 1)
self.next_in_line = rp.file
else:
data = (rp.conn.conn_number, rp.base, rp.index, rp.data, 0)
self.rorps_in_buffer += 1
pickle = cPickle.dumps(data, 1)
self.array_buf.fromstring("R")
self.array_buf.fromstring(C.long2str(long(len(pickle))))
self.array_buf.fromstring(pickle)
def addfinal(self): def addfinal(self):
"""Signal the end of the iterator to the other end""" """Signal the end of the iterator to the other end"""
self.array_buf.fromstring("z") self.array_buf.fromstring("z")
...@@ -383,9 +397,19 @@ class FileToMiscIter(IterWrappingFile): ...@@ -383,9 +397,19 @@ class FileToMiscIter(IterWrappingFile):
while not type: type, data = self._get() while not type: type, data = self._get()
if type == "z": raise StopIteration if type == "z": raise StopIteration
elif type == "r": return self.get_rorp(data) elif type == "r": return self.get_rorp(data)
elif type == "R": return self.get_rp(data)
elif type == "o": return data elif type == "o": return data
else: raise IterFileException("Bad file type %s" % (type,)) else: raise IterFileException("Bad file type %s" % (type,))
def get_rp(self, pickled_tuple):
conn_number, base, index, data_dict, num_files = pickled_tuple
rp = rpath.RPath(Globals.connection_dict[conn_number],
base, index, data_dict)
if num_files:
assert num_files == 1, "Only one file accepted right now"
rp.setfile(self.get_file())
return rp
def get_rorp(self, pickled_tuple): def get_rorp(self, pickled_tuple):
"""Return rorp that data represents""" """Return rorp that data represents"""
index, data_dict, num_files = pickled_tuple index, data_dict, num_files = pickled_tuple
...@@ -419,7 +443,7 @@ class FileToMiscIter(IterWrappingFile): ...@@ -419,7 +443,7 @@ class FileToMiscIter(IterWrappingFile):
type, length = self.buf[0], C.str2long(self.buf[1:8]) type, length = self.buf[0], C.str2long(self.buf[1:8])
data = self.buf[8:8+length] data = self.buf[8:8+length]
self.buf = self.buf[8+length:] self.buf = self.buf[8+length:]
if type in "oerh": return type, cPickle.loads(data) if type in "oerRh": return type, cPickle.loads(data)
else: return type, data else: return type, data
......
...@@ -135,6 +135,7 @@ class Logger: ...@@ -135,6 +135,7 @@ class Logger:
if verbosity <= 2 or Globals.server: termfp = sys.stderr if verbosity <= 2 or Globals.server: termfp = sys.stderr
else: termfp = sys.stdout else: termfp = sys.stdout
termfp.write(self.format(message, self.term_verbosity)) termfp.write(self.format(message, self.term_verbosity))
termfp.flush()
def conn(self, direction, result, req_num): def conn(self, direction, result, req_num):
"""Log some data on the connection """Log some data on the connection
......
# vim: set nolist noet ts=4:
# Copyright 2002, 2003, 2004, 2005 Ben Escoto
#
# This file is part of rdiff-backup.
#
# rdiff-backup is free software; you can redistribute it and/or modify
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# rdiff-backup is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with rdiff-backup; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
# USA
from collections import namedtuple
from multiprocessing import cpu_count
import Queue
import itertools
import threading
Job = namedtuple('Job', ['func', 'iterable', 'outqueue'])
Task = namedtuple('Task', ['func', 'args', 'index', 'outqueue'])
Result = namedtuple('Result', ['index', 'value'])
RUNNING = 0
STOPPED = 1
def worker(taskqueue):
while True:
task = taskqueue.get(True)
if task is None:
taskqueue.task_done()
break
if task.func is None:
# It means this task was the last of an iterable job
result = None
else:
value = task.func(*task.args)
result = Result(task.index, value)
task.outqueue.put(result, block=True)
taskqueue.task_done()
class Pool(object):
def __init__(self, processes=None,
max_taskqueue_size=0, max_jobqueue_size=0):
if processes is None:
processes = cpu_count()
self.processes = processes
self.state = STOPPED
# Init queues
self.taskqueue = Queue.Queue(maxsize=max_taskqueue_size)
self.jobqueue = Queue.Queue(maxsize=max_jobqueue_size)
# Init workers
self.workers = []
self.start_workers()
# Init task handler thread
self._job_handler_thread = self._start_handler_thread(self._job_handler,
self.jobqueue,
self.taskqueue)
def start_workers(self):
while len(self.workers) < self.processes:
w = self._start_handler_thread(worker, self.taskqueue)
self.workers.append(w)
for w in self.workers:
if not w.is_alive():
w.start()
def _start_handler_thread(self, func, *args):
thread = threading.Thread(target=func, args=args)
thread.daemon = True
thread.start()
return thread
def create_job(self, func, iterable, max_outqueue_size=0):
outqueue = Queue.Queue(maxsize=max_outqueue_size)
job = Job(func, iterable, outqueue)
self.jobqueue.put(job)
return job
def imap(self, func, iterable, max_outqueue_size=0):
iterable = itertools.imap(None, iterable)
job = self.create_job(func, iterable,
max_outqueue_size=max_outqueue_size)
return IMapIterator(job.outqueue)
def apply(self, func, *args):
job = self.create_job(func, [args])
return AsyncResult(job.outqueue)
def stop(self):
self.jobqueue.put(None, block=True)
def join(self, timeout=None):
self.stop()
self._job_handler_thread.join(timeout=timeout)
for w in self.workers:
w.join(timeout=timeout)
def _job_handler(self, jobqueue, taskqueue):
while True:
job = jobqueue.get(True)
if job is None:
for w in self.workers:
taskqueue.put(None)
break
for (index, args) in enumerate(job.iterable):
task = Task(job.func, args, index, job.outqueue)
taskqueue.put(task, block=True)
taskqueue.put(Task(None, None, None, job.outqueue), block=True)
jobqueue.task_done()
class IMapIterator(object):
def __init__(self, outqueue):
self.outqueue = outqueue
self.results = {}
self.index = 0
def __iter__(self):
return self
def next(self):
while True:
if self.index in self.results:
result = self.results.pop(self.index)
else:
result = self.outqueue.get(True)
if result is None:
raise StopIteration()
if result.index != self.index:
self.results[result.index] = result
continue
self.index += 1
return result.value
class AsyncResult(object):
def __init__(self, outqueue):
self.outqueue = outqueue
self.completed = False
self.value = None
def wait(self):
if self.completed:
return
self.value = self.outqueue.get(True)
self.completed = True
def get(self):
self.wait()
return self.value
...@@ -21,7 +21,8 @@ ...@@ -21,7 +21,8 @@
"""Read increment files and restore to original""" """Read increment files and restore to original"""
from __future__ import generators from __future__ import generators
import tempfile, os, cStringIO from pool import Pool as ThreadPool
import tempfile, os, cStringIO, itertools
import static, rorpiter, FilenameMapping, connection import static, rorpiter, FilenameMapping, connection
class RestoreError(Exception): pass class RestoreError(Exception): pass
...@@ -31,15 +32,43 @@ def Restore(mirror_rp, inc_rpath, target, restore_to_time): ...@@ -31,15 +32,43 @@ def Restore(mirror_rp, inc_rpath, target, restore_to_time):
MirrorS = mirror_rp.conn.restore.MirrorStruct MirrorS = mirror_rp.conn.restore.MirrorStruct
TargetS = target.conn.restore.TargetStruct TargetS = target.conn.restore.TargetStruct
pool = ThreadPool(max_taskqueue_size=8)
MirrorS.set_mirror_and_rest_times(restore_to_time) MirrorS.set_mirror_and_rest_times(restore_to_time)
MirrorS.initialize_rf_cache(mirror_rp, inc_rpath) MirrorS.initialize_rf_cache(mirror_rp, inc_rpath)
# we run this locally to retrieve RPath instead of RORPath objects # we run this locally to retrieve RPath instead of RORPath objects
# target_iter = TargetS.get_initial_iter(target) target_iter = TargetS.get_initial_iter(target)
target_iter = selection.Select(target).set_iter() # target_iter = selection.Select(target).set_iter()
diff_iter = MirrorS.get_diffs(target_iter) mir_iter = MirrorS.subtract_indicies(MirrorS.mirror_base.index,
MirrorS.get_mirror_rorp_iter())
collated = rorpiter.Collate2Iters(mir_iter, target_iter)
diff_iter = pool.imap(get_diff, collated, max_outqueue_size=8)
diff_iter = itertools.ifilter(lambda diff: diff is not None, diff_iter)
TargetS.patch(target, diff_iter) TargetS.patch(target, diff_iter)
pool.stop()
pool.join()
MirrorS.close_rf_cache() MirrorS.close_rf_cache()
def get_diff(args):
mir_rorp, target_rorp = args
if Globals.preserve_hardlinks and mir_rorp:
Hardlink.add_rorp(mir_rorp, target_rorp)
diff = None
if not (target_rorp and mir_rorp and target_rorp == mir_rorp and
(not Globals.preserve_hardlinks or
Hardlink.rorp_eq(mir_rorp, target_rorp))):
diff = MirrorStruct.get_diff(mir_rorp, target_rorp)
if Globals.preserve_hardlinks and mir_rorp:
Hardlink.del_rorp(mir_rorp)
return diff
def get_inclist(inc_rpath): def get_inclist(inc_rpath):
"""Returns increments with given base""" """Returns increments with given base"""
dirname, basename = inc_rpath.dirsplit() dirname, basename = inc_rpath.dirsplit()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment