Commit ef5401a4 authored by Julien Muchembled's avatar Julien Muchembled

Add a way to define network parameters in the registry and propagate them efficiently

parent aba0e94d
...@@ -3,7 +3,7 @@ import httplib, logging, socket ...@@ -3,7 +3,7 @@ import httplib, logging, socket
from BaseHTTPServer import BaseHTTPRequestHandler from BaseHTTPServer import BaseHTTPRequestHandler
from SocketServer import ThreadingTCPServer from SocketServer import ThreadingTCPServer
from urlparse import parse_qsl from urlparse import parse_qsl
from re6st import ctl, registry, utils from re6st import ctl, registry, utils, version
# To generate server ca and key with serial for 2001:db8:42::/48 # To generate server ca and key with serial for 2001:db8:42::/48
# openssl req -nodes -new -x509 -key ca.key -set_serial 0x120010db80042 -days 3650 -out ca.crt # openssl req -nodes -new -x509 -key ca.key -set_serial 0x120010db80042 -days 3650 -out ca.crt
...@@ -89,8 +89,15 @@ def main(): ...@@ -89,8 +89,15 @@ def main():
_('-v', '--verbose', default=1, type=int, _('-v', '--verbose', default=1, type=int,
help="Log level. 0 disables logging." help="Log level. 0 disables logging."
" Use SIGUSR1 to reopen log.") " Use SIGUSR1 to reopen log.")
_('--min-protocol', default=version.min_protocol, type=int,
help="Reject nodes that are too old. Current is %s." % version.protocol)
config = parser.parse_args() config = parser.parse_args()
if not version.min_protocol <= config.min_protocol <= version.protocol:
parser.error("--min-protocol: value must between %s and %s (included)"
% (version.min_protocol, version.protocol))
utils.setupLog(config.verbose, config.logfile) utils.setupLog(config.verbose, config.logfile)
server = registry.RegistryServer(config) server = registry.RegistryServer(config)
......
import logging, os, sqlite3, socket, subprocess, time import json, logging, os, sqlite3, socket, subprocess, time, zlib
from re6st.registry import RegistryClient from re6st.registry import RegistryClient
from . import utils from . import utils, version, x509
class Cache(object): class Cache(object):
# internal ip = temp arg/attribute
def __init__(self, db_path, registry, cert, db_size=200): def __init__(self, db_path, registry, cert, db_size=200):
self._prefix = cert.prefix self._prefix = cert.prefix
self._db_size = db_size self._db_size = db_size
...@@ -25,26 +24,28 @@ class Cache(object): ...@@ -25,26 +24,28 @@ class Cache(object):
try INTEGER NOT NULL DEFAULT 0)""") try INTEGER NOT NULL DEFAULT 0)""")
q("CREATE INDEX volatile.stat_try ON stat(try)") q("CREATE INDEX volatile.stat_try ON stat(try)")
q("INSERT INTO volatile.stat (peer) SELECT prefix FROM peer") q("INSERT INTO volatile.stat (peer) SELECT prefix FROM peer")
self._db.commit()
self._loadConfig(q("SELECT * FROM config"))
try: try:
a = q("SELECT value FROM config WHERE name='registry'").next()[0] cert.verifyVersion(self.version)
int(a, 2) except (AttributeError, x509.VerifyError):
except (StopIteration, ValueError):
logging.info("Asking registry its private IP...")
retry = 1 retry = 1
while True: while not self.updateConfig():
try:
a = self._registry.getPrefix(self._prefix)
int(a, 2)
break
except socket.error, e:
logging.warning(e)
time.sleep(retry) time.sleep(retry)
retry = min(60, retry * 2) retry = min(60, retry * 2)
q("INSERT OR REPLACE INTO config VALUES ('registry',?)", (a,)) else:
self._db.commit() if (# re6stnet upgraded after being unused for a long time.
self.registry_prefix = a self.protocol < version.protocol
logging.info("Cache initialized. Prefix of registry node is %s/%u", # Always query the registry at startup in case we were down
int(a, 2), len(a)) # when it tried to send us new parameters.
or self._prefix == self.registry_prefix):
self.updateConfig()
if version.protocol < self.min_protocol:
logging.critical("Your version of re6stnet is too old."
" Please update.")
sys.exit(1)
self.warnProtocol()
logging.info("Cache initialized.")
def _open(self, path): def _open(self, path):
db = sqlite3.connect(path, isolation_level=None) db = sqlite3.connect(path, isolation_level=None)
...@@ -59,6 +60,58 @@ class Cache(object): ...@@ -59,6 +60,58 @@ class Cache(object):
"value") "value")
return db return db
def _loadConfig(self, config):
cls = self.__class__
logging.debug("Loading network parameters:")
for k, v in config:
hasattr(cls, k) or setattr(self, k, v)
logging.debug("- %s: %r", k, v)
def updateConfig(self):
    """Fetch the network parameters from the registry and store them.

    The registry answer is a zlib-compressed JSON object. Its '' entry
    lists the keys whose values are base64-encoded; unicode keys and
    values are converted to byte strings. The 'delay_restart' entry is
    kept as an attribute only, everything else replaces the content of
    the `config` table and is loaded as attributes.

    Returns the list of parameter names that are new or whose value
    changed, or None if the registry could not be queried or its answer
    could not be decoded.
    """
    logging.info("Getting new network parameters from registry...")
    try:
        # TODO: When possible, the registry should be queried via the re6st.
        config = json.loads(zlib.decompress(
            self._registry.getNetworkConfig(self._prefix)))
        encoded_keys = config.pop('', ())
        config = dict((str(k), v.decode('base64') if k in encoded_keys else
                               str(v) if type(v) is unicode else v)
                      for k, v in config.iteritems())
    except socket.error as e:
        logging.warning(e)
        return
    except Exception:
        # Even if the response is authenticated, a mistake on the registry
        # should not kill the whole network in a few seconds.
        logging.exception("buggy registry ?")
        return
    # XXX: check version ?
    self.delay_restart = config.pop("delay_restart", 0)
    old = {}
    with self._db as db:
        remove = []
        for k, v in db.execute("SELECT * FROM config"):
            if k in config:
                old[k] = v
                continue
            # Parameter dropped by the registry: forget it here too.
            try:
                delattr(self, k)
            except AttributeError:
                pass
            remove.append(k)
        if remove:
            # Bind the names as parameters instead of interpolating them
            # into the statement: a name containing a quote must neither
            # break nor alter the query.
            db.execute("DELETE FROM config WHERE name IN (%s)"
                       % ",".join("?" * len(remove)), remove)
        db.executemany("INSERT OR REPLACE INTO config VALUES(?,?)",
                       config.iteritems())
    self._loadConfig(config.iteritems())
    return [k for k, v in config.iteritems()
            if k not in old or old[k] != v]
def warnProtocol(self):
    """Warn if the loaded `protocol` parameter is newer than ours."""
    if version.protocol >= self.protocol:
        return
    logging.warning("There's a new version of re6stnet:"
                    " you should update.")
def log(self): def log(self):
if logging.getLogger().isEnabledFor(5): if logging.getLogger().isEnabledFor(5):
logging.trace("Cache:") logging.trace("Cache:")
......
...@@ -20,7 +20,7 @@ Authenticated communication: ...@@ -20,7 +20,7 @@ Authenticated communication:
""" """
import base64, hmac, hashlib, httplib, inspect, json, logging import base64, hmac, hashlib, httplib, inspect, json, logging
import mailbox, os, random, select, smtplib, socket, sqlite3 import mailbox, os, random, select, smtplib, socket, sqlite3
import string, sys, threading, time, weakref import string, struct, sys, threading, time, weakref, zlib
from collections import defaultdict, deque from collections import defaultdict, deque
from datetime import datetime from datetime import datetime
from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler
...@@ -55,10 +55,12 @@ class RegistryServer(object): ...@@ -55,10 +55,12 @@ class RegistryServer(object):
utils.makedirs(os.path.dirname(self.config.db)) utils.makedirs(os.path.dirname(self.config.db))
self.db = sqlite3.connect(self.config.db, isolation_level=None, self.db = sqlite3.connect(self.config.db, isolation_level=None,
check_same_thread=False) check_same_thread=False)
self.db.text_factory = str
utils.sqliteCreateTable(self.db, "config", utils.sqliteCreateTable(self.db, "config",
"name TEXT PRIMARY KEY NOT NULL", "name TEXT PRIMARY KEY NOT NULL",
"value") "value")
self.prefix = self.getConfig("prefix", None) self.prefix = self.getConfig("prefix", None)
self.version = self.getConfig("version", "\0")
utils.sqliteCreateTable(self.db, "token", utils.sqliteCreateTable(self.db, "token",
"token TEXT PRIMARY KEY NOT NULL", "token TEXT PRIMARY KEY NOT NULL",
"email TEXT NOT NULL", "email TEXT NOT NULL",
...@@ -82,6 +84,9 @@ class RegistryServer(object): ...@@ -82,6 +84,9 @@ class RegistryServer(object):
weakref.proxy(self), self.network) weakref.proxy(self), self.network)
self.onTimeout() self.onTimeout()
if self.prefix:
with self.db:
self.updateNetworkConfig()
def getConfig(self, name, *default): def getConfig(self, name, *default):
r, = next(self.db.execute( r, = next(self.db.execute(
...@@ -92,6 +97,41 @@ class RegistryServer(object): ...@@ -92,6 +97,41 @@ class RegistryServer(object):
self.db.execute("INSERT OR REPLACE INTO config VALUES (?, ?)", self.db.execute("INSERT OR REPLACE INTO config VALUES (?, ?)",
name_value) name_value)
def updateNetworkConfig(self):
    """Rebuild the compressed network configuration (`network_config`).

    The parameters are serialized with sorted keys and compared with the
    'last_config' value stored in the database. When they differ, the
    network version is incremented, both values are saved and sendto() is
    called with the registry prefix and code 0.
    """
    params = {
        'protocol': version.protocol,
        'registry_prefix': self.prefix,
    }
    # Options copied as is from the command line.
    for option in ('min_protocol',):
        params[option] = getattr(self.config, option)
    serialized = json.dumps(params, sort_keys=True)
    if serialized != self.getConfig('last_config', None):
        next_version = 1 + self.decodeVersion(self.version)
        self.version = self.encodeVersion(next_version)
        self.setConfig('version', self.version)
        self.setConfig('last_config', serialized)
        self.sendto(self.prefix, 0)
    # The '' entry lists the keys whose values are base64-encoded.
    params[''] = ('version',)
    # Example to prevent all nodes from restarting at the same time:
    #   params['delay_restart'] = 600 * random.random()
    params['version'] = self.version.encode('base64')
    self.network_config = zlib.compress(json.dumps(params))
# The 3 first bits code the number of bytes.
def encodeVersion(self, version):
for n in xrange(8):
x = 32 << 8 * n
if version < x:
x = struct.pack("!Q", version + n * x)[7-n:]
return x + self.cert.sign(x)
version -= x
def decodeVersion(self, version):
n = ord(version[0]) >> 5
version, = struct.unpack("!Q", '\0' * (7 - n) + version[:n+1])
return sum((32 << 8 * n for n in xrange(n)),
version - (n * 32 << 8 * n))
def sendto(self, prefix, code): def sendto(self, prefix, code):
self.sock.sendto("%s\0%c" % (prefix, code), ('::1', tunnel.PORT)) self.sock.sendto("%s\0%c" % (prefix, code), ('::1', tunnel.PORT))
...@@ -307,6 +347,7 @@ class RegistryServer(object): ...@@ -307,6 +347,7 @@ class RegistryServer(object):
if self.prefix is None: if self.prefix is None:
self.prefix = prefix self.prefix = prefix
self.setConfig('prefix', prefix) self.setConfig('prefix', prefix)
self.updateNetworkConfig()
subject = req.get_subject() subject = req.get_subject()
subject.serialNumber = str(self.getSubjectSerial()) subject.serialNumber = str(self.getSubjectSerial())
return self.createCertificate(prefix, subject, req.get_pubkey()) return self.createCertificate(prefix, subject, req.get_pubkey())
...@@ -337,8 +378,8 @@ class RegistryServer(object): ...@@ -337,8 +378,8 @@ class RegistryServer(object):
cert.set_pubkey(pubkey) cert.set_pubkey(pubkey)
# Certificate serial, for revocation support. Contrary to # Certificate serial, for revocation support. Contrary to
# subject serial, it does not need to be as small as possible. # subject serial, it does not need to be as small as possible.
serial = 1 + self.getConfig('_serial', 0) serial = 1 + self.getConfig('serial', 0)
self.setConfig('_serial', serial) self.setConfig('serial', serial)
cert.set_serial_number(serial) cert.set_serial_number(serial)
cert.sign(self.cert.key, 'sha1') cert.sign(self.cert.key, 'sha1')
cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cert) cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
...@@ -367,8 +408,8 @@ class RegistryServer(object): ...@@ -367,8 +408,8 @@ class RegistryServer(object):
return crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert.ca) return crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert.ca)
@rpc @rpc
def getPrefix(self, cn): def getNetworkConfig(self, cn):
return self.prefix return self.network_config
@rpc @rpc
def getBootstrapPeer(self, cn): def getBootstrapPeer(self, cn):
......
...@@ -162,6 +162,11 @@ class TunnelKiller(object): ...@@ -162,6 +162,11 @@ class TunnelKiller(object):
class BaseTunnelManager(object): class BaseTunnelManager(object):
# TODO: To minimize downtime when network parameters change, we should do
# our best to not restart any process. Ideally, this list should be
# empty and the affected subprocesses reloaded.
NEED_RESTART = frozenset()
_forward = None _forward = None
def __init__(self, cache, cert, cert_renew, address=()): def __init__(self, cache, cert, cert_renew, address=()):
...@@ -172,6 +177,7 @@ class BaseTunnelManager(object): ...@@ -172,6 +177,7 @@ class BaseTunnelManager(object):
self._connecting = set() self._connecting = set()
self._connection_dict = {} self._connection_dict = {}
self._served = set() self._served = set()
self._version = cache.version
address_dict = defaultdict(list) address_dict = defaultdict(list)
for family, address in address: for family, address in address:
...@@ -186,12 +192,27 @@ class BaseTunnelManager(object): ...@@ -186,12 +192,27 @@ class BaseTunnelManager(object):
self.sock.bind(('::', PORT)) self.sock.bind(('::', PORT))
p = x509.Peer(self._prefix) p = x509.Peer(self._prefix)
self._next_invalidated = p.stop_date = cert_renew p.stop_date = cert_renew
self._peers = [p] self._peers = [p]
self._timeouts = [(cert_renew, self.invalidatePeers)]
def select(self, r, w, t): def select(self, r, w, t):
r[self.sock] = self.handlePeerEvent r[self.sock] = self.handlePeerEvent
t.append((self._next_invalidated, self.invalidatePeers)) t += self._timeouts
def selectTimeout(self, next, callback, force=True):
t = self._timeouts
for i, x in enumerate(t):
if x[1] == callback:
if not next:
logging.debug("timeout: removing %r (%s)", callback.__name__, next)
del t[i]
elif force or next < x[0]:
logging.debug("timeout: updating %r (%s)", callback.__name__, next)
t[i] = next, callback
return
logging.debug("timeout: adding %r (%s)", callback.__name__, next)
t.append((next, callback))
def invalidatePeers(self): def invalidatePeers(self):
next = float('inf') next = float('inf')
...@@ -206,11 +227,14 @@ class BaseTunnelManager(object): ...@@ -206,11 +227,14 @@ class BaseTunnelManager(object):
next = peer.stop_date next = peer.stop_date
for i in reversed(remove): for i in reversed(remove):
del self._peers[i] del self._peers[i]
self._next_invalidated = next self.selectTimeout(next, self.invalidatePeers)
def _getPeer(self, prefix):
return self._peers[bisect(self._peers, prefix) - 1]
def sendto(self, prefix, msg): def sendto(self, prefix, msg):
to = utils.ipFromBin(self._network + prefix), PORT to = utils.ipFromBin(self._network + prefix), PORT
peer = self._peers[bisect(self._peers, prefix) - 1] peer = self._getPeer(prefix)
if peer.prefix != prefix: if peer.prefix != prefix:
peer = x509.Peer(prefix) peer = x509.Peer(prefix)
insort(self._peers, peer) insort(self._peers, peer)
...@@ -259,7 +283,7 @@ class BaseTunnelManager(object): ...@@ -259,7 +283,7 @@ class BaseTunnelManager(object):
if len(msg) <= 4 or not sender.startswith(self._network): if len(msg) <= 4 or not sender.startswith(self._network):
return return
prefix = sender[len(self._network):] prefix = sender[len(self._network):]
peer = self._peers[bisect(self._peers, prefix) - 1] peer = self._getPeer(prefix)
msg = peer.decode(msg) msg = peer.decode(msg)
if type(msg) is tuple: if type(msg) is tuple:
seqno, msg = msg seqno, msg = msg
...@@ -274,7 +298,8 @@ class BaseTunnelManager(object): ...@@ -274,7 +298,8 @@ class BaseTunnelManager(object):
logging.debug('ignored new session key from %r', logging.debug('ignored new session key from %r',
address, exc_info=1) address, exc_info=1)
return return
self._sendto(to, "", peer) # ack peer.version = self._version \
if self._sendto(to, '\0' + self._version, peer) else ''
return return
if seqno: if seqno:
h = x509.fingerprint(self.cert.cert).digest() h = x509.fingerprint(self.cert.cert).digest()
...@@ -298,8 +323,7 @@ class BaseTunnelManager(object): ...@@ -298,8 +323,7 @@ class BaseTunnelManager(object):
insort(self._peers, peer) insort(self._peers, peer)
peer.cert = cert peer.cert = cert
peer.stop_date = stop_date peer.stop_date = stop_date
if stop_date < self._next_invalidated: self.selectTimeout(stop_date, self.invalidatePeers, False)
self._next_invalidated = stop_date
if seqno: if seqno:
self._sendto(to, peer.hello(self.cert)) self._sendto(to, peer.hello(self.cert))
else: else:
...@@ -330,7 +354,25 @@ class BaseTunnelManager(object): ...@@ -330,7 +354,25 @@ class BaseTunnelManager(object):
self._makeTunnel(peer, msg) self._makeTunnel(peer, msg)
else: else:
return ';'.join(self._address.itervalues()) return ';'.join(self._address.itervalues())
elif 2 <= code <= 3: # kill elif not code: # ver
if peer:
try:
if msg == self._version:
return
self.cert.verifyVersion(msg)
except x509.VerifyError:
pass
else:
if msg < self._version:
return self._version
self._version = msg
self.selectTimeout(time.time() + 1, self.newVersion)
finally:
if peer:
self._getPeer(peer).version = self._version
else:
self.selectTimeout(time.time() + 1, self.newVersion)
elif code <= 3: # kill
if peer: if peer:
try: try:
tunnel_killer = self._killing[peer] tunnel_killer = self._killing[peer]
...@@ -353,6 +395,33 @@ class BaseTunnelManager(object): ...@@ -353,6 +395,33 @@ class BaseTunnelManager(object):
for x in (self._connection_dict, self._served) for x in (self._connection_dict, self._served)
for x in x) for x in x)
@staticmethod
def _restart():
    """Timeout callback raising utils.ReexecException.

    NOTE(review): presumably caught by the main loop to re-execute the
    program -- confirm where ReexecException is handled.
    """
    reason = "Restart with new network parameters"
    raise utils.ReexecException(reason)
def broadcastVersion(self):
pass
def newVersion(self):
    """Timeout callback: reload the network parameters from the registry.

    If the cache fails to update (updateConfig() returns None), the call
    is rescheduled 5 minutes later. Otherwise, the timeout is removed,
    the new version is recorded and broadcast, and a restart is scheduled
    if a changed parameter is listed in NEED_RESTART or if this software
    is older than the minimum protocol.
    """
    changed = self.cache.updateConfig()
    if changed is None:
        logging.info("will retry to update network parameters in 5 minutes")
        self.selectTimeout(time.time() + 300, self.newVersion)
        return
    logging.info("changed: %r", changed)
    self.selectTimeout(None, self.newVersion)
    self._version = self.cache.version
    self.broadcastVersion()
    self.cache.warnProtocol()
    restart_needed = not self.NEED_RESTART.isdisjoint(changed)
    if restart_needed or version.protocol < self.cache.min_protocol:
        # Wait at least 1 second to broadcast new version to neighbours.
        # If re6stnet is too old, don't abort now, because a new version
        # may have been installed without restart.
        delay = 1 + self.cache.delay_restart
        self.selectTimeout(time.time() + delay, self._restart)
class TunnelManager(BaseTunnelManager): class TunnelManager(BaseTunnelManager):
...@@ -693,3 +762,13 @@ class TunnelManager(BaseTunnelManager): ...@@ -693,3 +762,13 @@ class TunnelManager(BaseTunnelManager):
family, address = self._ip_changed(ip) family, address = self._ip_changed(ip)
if address: if address:
self._address[family] = utils.dump_address(address) self._address[family] = utils.dump_address(address)
def broadcastVersion(self):
for prefix in self.ctl.neighbours:
if prefix:
peer = self._getPeer(prefix)
if peer.prefix != prefix:
self.sendto(prefix, None)
elif (peer.version < self._version and
self.sendto(prefix, '\0' + self._version)):
peer.version = self._version
...@@ -27,5 +27,13 @@ version = "0-%s.g%s" % (revision, short) ...@@ -27,5 +27,13 @@ version = "0-%s.g%s" % (revision, short)
if dirty: if dirty:
version += ".dirty" version += ".dirty"
# Because the software could be forked or have local changes/commits, the
# properties above can't be used to decide whether a peer runs an appropriate
# version: they are intended for the network admin.
# Only 'protocol' is important and it must be increased whenever there is
# a wish to force an update of nodes.
protocol = 1
min_protocol = 1
if __name__ == "__main__": if __name__ == "__main__":
print version print version
...@@ -149,6 +149,13 @@ class Cert(object): ...@@ -149,6 +149,13 @@ class Cert(object):
raise subprocess.CalledProcessError(p.returncode, 'openssl', err) raise subprocess.CalledProcessError(p.returncode, 'openssl', err)
return out return out
def verifyVersion(self, version):
    """Check the signature of a network version.

    The 3 highest bits of the first byte give the length of the number;
    what follows the number is passed to self.verify() as its signature.

    Raises VerifyError if `version` is empty or if an IndexError or
    crypto.Error is raised while checking it.
    """
    try:
        size = (ord(version[0]) >> 5) + 1
        number, signature = version[:size], version[size:]
        self.verify(signature, number)
    except (IndexError, crypto.Error):
        raise VerifyError(None, None, 'invalid network version')
class Peer(object): class Peer(object):
""" """
...@@ -175,6 +182,7 @@ class Peer(object): ...@@ -175,6 +182,7 @@ class Peer(object):
_hello = _last = 0 _hello = _last = 0
_key = newHmacSecret() _key = newHmacSecret()
stop_date = float('inf') stop_date = float('inf')
version = ''
def __init__(self, prefix): def __init__(self, prefix):
assert len(prefix) == 16 or prefix == ('0' * 14 + '1' + '0' * 65), prefix assert len(prefix) == 16 or prefix == ('0' * 14 + '1' + '0' * 65), prefix
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment