Commit a46f4b11 authored by zhifan huang's avatar zhifan huang

upgrade registry to 3

parent ca3b04ee
#!/usr/bin/python2 #!/usr/bin/python2
import httplib, logging, os, socket, sys import logging, os, socket, sys
from BaseHTTPServer import BaseHTTPRequestHandler from http import HTTPStatus
from SocketServer import ThreadingTCPServer from http.server import BaseHTTPRequestHandler
from urlparse import parse_qsl from socketserver import ThreadingTCPServer
from urllib.parse import parse_qsl
if 're6st' not in sys.modules: if 're6st' not in sys.modules:
sys.path[0] = os.path.dirname(os.path.dirname(sys.path[0])) sys.path[0] = os.path.dirname(os.path.dirname(sys.path[0]))
from re6st import registry, utils, version from re6st import registry, utils, version
...@@ -36,7 +37,7 @@ class RequestHandler(BaseHTTPRequestHandler): ...@@ -36,7 +37,7 @@ class RequestHandler(BaseHTTPRequestHandler):
return self.server.handle_request(self, path, query) return self.server.handle_request(self, path, query)
except Exception: except Exception:
logging.info(self.requestline, exc_info=1) logging.info(self.requestline, exc_info=1)
self.send_error(httplib.BAD_REQUEST) self.send_error(HTTPStatus.BAD_REQUEST)
def log_error(*args): def log_error(*args):
pass pass
......
...@@ -44,7 +44,7 @@ class Array(object): ...@@ -44,7 +44,7 @@ class Array(object):
r = [] r = []
o = offset + 2 o = offset + 2
decode = self._item.decode decode = self._item.decode
for i in xrange(*uint16.unpack_from(buffer, offset)): for i in range(*uint16.unpack_from(buffer, offset)):
o, x = decode(buffer, o) o, x = decode(buffer, o)
r.append(x) r.append(x)
return o, r return o, r
...@@ -110,12 +110,12 @@ class Buffer(object): ...@@ -110,12 +110,12 @@ class Buffer(object):
def unpack_from(self, struct): def unpack_from(self, struct):
r = self._r r = self._r
x = r + struct.size x = r + struct.size
value = struct.unpack(buffer(self._buf)[r:x]) value = struct.unpack(memoryview(self._buf)[r:x])
self._seek(x) self._seek(x)
return value return value
def decode(self, decode): def decode(self, decode):
r = self._r r = self._r
size, value = decode(buffer(self._buf)[r:]) size, value = decode(memoryview(self._buf)[r:])
self._seek(r + size) self._seek(r + size)
return value return value
...@@ -206,7 +206,7 @@ class Babel(object): ...@@ -206,7 +206,7 @@ class Babel(object):
def select(*args): def select(*args):
try: try:
s.connect(self.socket_path) s.connect(self.socket_path)
except socket.error, e: except socket.error as e:
logging.debug("Can't connect to %r (%r)", self.socket_path, e) logging.debug("Can't connect to %r (%r)", self.socket_path, e)
return e return e
s.send("\1") s.send("\1")
......
...@@ -111,7 +111,7 @@ def router(ip, ip4, src, hello_interval, log_path, state_path, pidfile, ...@@ -111,7 +111,7 @@ def router(ip, ip4, src, hello_interval, log_path, state_path, pidfile,
# WKRD: babeld fails to start if pidfile already exists # WKRD: babeld fails to start if pidfile already exists
try: try:
os.remove(pidfile) os.remove(pidfile)
except OSError, e: except OSError as e:
if e.errno != errno.ENOENT: if e.errno != errno.ENOENT:
raise raise
logging.info('%r', cmd) logging.info('%r', cmd)
......
...@@ -18,16 +18,18 @@ Authenticated communication: ...@@ -18,16 +18,18 @@ Authenticated communication:
- the last one that was really used by the client (!hello) - the last one that was really used by the client (!hello)
- the one of the last handshake (hello) - the one of the last handshake (hello)
""" """
import base64, hmac, hashlib, httplib, inspect, json, logging import base64, hmac, hashlib, inspect, json, logging
import mailbox, os, platform, random, select, smtplib, socket, sqlite3 import mailbox, os, platform, random, select, smtplib, socket, sqlite3
import string, sys, threading, time, weakref, zlib import string, sys, threading, time, weakref, zlib
from http import HTTPStatus
import http.client
from collections import defaultdict, deque from collections import defaultdict, deque
from datetime import datetime from datetime import datetime
from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler from http.server import HTTPServer, BaseHTTPRequestHandler
from email.mime.text import MIMEText from email.mime.text import MIMEText
from operator import itemgetter from operator import itemgetter
from OpenSSL import crypto from OpenSSL import crypto
from urllib import splittype, splithost, unquote, urlencode from urllib.parse import urlparse, unquote, urlencode
from . import ctl, tunnel, utils, version, x509 from . import ctl, tunnel, utils, version, x509
HMAC_HEADER = "Re6stHMAC" HMAC_HEADER = "Re6stHMAC"
...@@ -42,7 +44,7 @@ def rpc(f): ...@@ -42,7 +44,7 @@ def rpc(f):
defaults = () defaults = ()
i = len(args) - len(defaults) i = len(args) - len(defaults)
f.getcallargs = eval("lambda %s: locals()" % ','.join(args[1:i] f.getcallargs = eval("lambda %s: locals()" % ','.join(args[1:i]
+ map("%s=%r".__mod__, zip(args[i:], defaults)))) + list(map("%s=%r".__mod__, zip(args[i:], defaults)))))
return f return f
def rpc_private(f): def rpc_private(f):
...@@ -60,7 +62,7 @@ class RegistryServer(object): ...@@ -60,7 +62,7 @@ class RegistryServer(object):
cert_duration = 365 * 86400 cert_duration = 365 * 86400
def _geoiplookup(self, ip): def _geoiplookup(self, ip):
raise HTTPError(httplib.BAD_REQUEST) raise HTTPError(HTTPStatus.BAD_REQUEST)
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
...@@ -78,7 +80,7 @@ class RegistryServer(object): ...@@ -78,7 +80,7 @@ class RegistryServer(object):
"name TEXT PRIMARY KEY NOT NULL", "name TEXT PRIMARY KEY NOT NULL",
"value") "value")
self.prefix = self.getConfig("prefix", None) self.prefix = self.getConfig("prefix", None)
self.version = str(self.getConfig("version", "\0")) # BBB: blob self.version = self.getConfig("version", "\0").encode() # BBB: blob
utils.sqliteCreateTable(self.db, "token", utils.sqliteCreateTable(self.db, "token",
"token TEXT PRIMARY KEY NOT NULL", "token TEXT PRIMARY KEY NOT NULL",
"email TEXT NOT NULL", "email TEXT NOT NULL",
...@@ -139,8 +141,8 @@ class RegistryServer(object): ...@@ -139,8 +141,8 @@ class RegistryServer(object):
def updateNetworkConfig(self, _it0=itemgetter(0)): def updateNetworkConfig(self, _it0=itemgetter(0)):
kw = { kw = {
'babel_default': 'max-rtt-penalty 5000 rtt-max 500 rtt-decay 125', 'babel_default': 'max-rtt-penalty 5000 rtt-max 500 rtt-decay 125',
'crl': map(_it0, self.db.execute( 'crl': list(map(_it0, self.db.execute(
"SELECT serial FROM crl ORDER BY serial")), "SELECT serial FROM crl ORDER BY serial"))),
'protocol': version.protocol, 'protocol': version.protocol,
'registry_prefix': self.prefix, 'registry_prefix': self.prefix,
} }
...@@ -156,12 +158,13 @@ class RegistryServer(object): ...@@ -156,12 +158,13 @@ class RegistryServer(object):
self.increaseVersion() self.increaseVersion()
# BBB: Use buffer because of http://bugs.python.org/issue13676 # BBB: Use buffer because of http://bugs.python.org/issue13676
# on Python 2.6 # on Python 2.6
self.setConfig('version', buffer(self.version)) # change buffer to memoryview
self.setConfig('version', memoryview(self.version))
self.setConfig('last_config', config) self.setConfig('last_config', config)
self.sendto(self.prefix, 0) self.sendto(self.prefix, 0)
# The following entry lists values that are base64-encoded. # The following entry lists values that are base64-encoded.
kw[''] = 'version', kw[''] = 'version',
kw['version'] = self.version.encode('base64') kw['version'] = base64.b64encode(self.version)
self.network_config = kw self.network_config = kw
def increaseVersion(self): def increaseVersion(self):
...@@ -169,7 +172,7 @@ class RegistryServer(object): ...@@ -169,7 +172,7 @@ class RegistryServer(object):
self.version = x + self.cert.sign(x) self.version = x + self.cert.sign(x)
def sendto(self, prefix, code): def sendto(self, prefix, code):
self.sock.sendto("%s\0%c" % (prefix, code), ('::1', tunnel.PORT)) self.sock.sendto(("%s\0%c" % (prefix, code)).encode(), ('::1', tunnel.PORT))
def recv(self, code): def recv(self, code):
try: try:
...@@ -257,7 +260,7 @@ class RegistryServer(object): ...@@ -257,7 +260,7 @@ class RegistryServer(object):
x_forwarded_for = request.headers.get('X-Forwarded-For') x_forwarded_for = request.headers.get('X-Forwarded-For')
if request.client_address[0] not in authorized_origin or \ if request.client_address[0] not in authorized_origin or \
x_forwarded_for and x_forwarded_for not in authorized_origin: x_forwarded_for and x_forwarded_for not in authorized_origin:
return request.send_error(httplib.FORBIDDEN) return request.send_error(HTTPStatus.FORBIDDEN)
key = m.getcallargs(**kw).get('cn') key = m.getcallargs(**kw).get('cn')
if key: if key:
h = base64.b64decode(request.headers[HMAC_HEADER]) h = base64.b64decode(request.headers[HMAC_HEADER])
...@@ -281,16 +284,16 @@ class RegistryServer(object): ...@@ -281,16 +284,16 @@ class RegistryServer(object):
request.headers.get("user-agent")) request.headers.get("user-agent"))
try: try:
result = m(**kw) result = m(**kw)
except HTTPError, e: except HTTPError as e:
return request.send_error(*e.args) return request.send_error(*e.args)
except: except:
logging.warning(request.requestline, exc_info=1) logging.warning(request.requestline, exc_info=1)
return request.send_error(httplib.INTERNAL_SERVER_ERROR) return request.send_error(HTTPStatus.INTERNAL_SERVER_ERROR)
if result: if result:
request.send_response(httplib.OK) request.send_response(HTTPStatus.OK)
request.send_header("Content-Length", str(len(result))) request.send_header("Content-Length", str(len(result)))
else: else:
request.send_response(httplib.NO_CONTENT) request.send_response(HTTPStatus.NO_CONTENT)
if key: if key:
request.send_header(HMAC_HEADER, base64.b64encode( request.send_header(HMAC_HEADER, base64.b64encode(
hmac.HMAC(key, result, hashlib.sha1).digest())) hmac.HMAC(key, result, hashlib.sha1).digest()))
...@@ -317,7 +320,7 @@ class RegistryServer(object): ...@@ -317,7 +320,7 @@ class RegistryServer(object):
assert self.lock.locked() assert self.lock.locked()
return self.db.execute("SELECT cert FROM cert" return self.db.execute("SELECT cert FROM cert"
" WHERE prefix=? AND cert IS NOT NULL", " WHERE prefix=? AND cert IS NOT NULL",
(client_prefix,)).next()[0] (client_prefix,)).fetchone()[0]
@rpc_private @rpc_private
def isToken(self, token): def isToken(self, token):
...@@ -335,7 +338,7 @@ class RegistryServer(object): ...@@ -335,7 +338,7 @@ class RegistryServer(object):
def addToken(self, email, token): def addToken(self, email, token):
prefix_len = self.config.prefix_length prefix_len = self.config.prefix_length
if not prefix_len: if not prefix_len:
raise HTTPError(httplib.FORBIDDEN) raise HTTPError(HTTPStatus.FORBIDDEN)
request = token is None request = token is None
with self.lock: with self.lock:
while True: while True:
...@@ -349,7 +352,7 @@ class RegistryServer(object): ...@@ -349,7 +352,7 @@ class RegistryServer(object):
break break
except sqlite3.IntegrityError: except sqlite3.IntegrityError:
if not request: if not request:
raise HTTPError(httplib.CONFLICT) raise HTTPError(HTTPStatus.CONFLICT)
self.timeout = 1 self.timeout = 1
if request: if request:
return token return token
...@@ -357,7 +360,7 @@ class RegistryServer(object): ...@@ -357,7 +360,7 @@ class RegistryServer(object):
@rpc @rpc
def requestToken(self, email): def requestToken(self, email):
if not self.config.mailhost: if not self.config.mailhost:
raise HTTPError(httplib.FORBIDDEN) raise HTTPError(HTTPStatus.FORBIDDEN)
token = self.addToken(email, None) token = self.addToken(email, None)
...@@ -390,7 +393,7 @@ class RegistryServer(object): ...@@ -390,7 +393,7 @@ class RegistryServer(object):
assert 0 < prefix_len <= max_len assert 0 < prefix_len <= max_len
try: try:
prefix, = self.db.execute("""SELECT prefix FROM cert WHERE length(prefix) <= ? AND cert is null prefix, = self.db.execute("""SELECT prefix FROM cert WHERE length(prefix) <= ? AND cert is null
ORDER BY length(prefix) DESC""", (prefix_len,)).next() ORDER BY length(prefix) DESC""", (prefix_len,)).fetchone()
except StopIteration: except StopIteration:
logging.error('No more free /%u prefix available', prefix_len) logging.error('No more free /%u prefix available', prefix_len)
raise raise
...@@ -410,19 +413,26 @@ class RegistryServer(object): ...@@ -410,19 +413,26 @@ class RegistryServer(object):
with self.db: with self.db:
if token: if token:
if not self.config.prefix_length: if not self.config.prefix_length:
raise HTTPError(httplib.FORBIDDEN) raise HTTPError(HTTPStatus.FORBIDDEN)
try: # i think this is just check if the token exist
token, email, prefix_len, _ = self.db.execute( res = self.db.execute(
"SELECT * FROM token WHERE token = ?", "SELECT * FROM token WHERE token = ?",
(token,)).next() (token,)).fetchone()
except StopIteration: if res is None:
return return
token, email, prefix_len, _ = res
# try:
# token, email, prefix_len, _ = self.db.execute(
# "SELECT * FROM token WHERE token = ?",
# (token,)).next()
# except StopIteration:
# return
self.db.execute("DELETE FROM token WHERE token = ?", self.db.execute("DELETE FROM token WHERE token = ?",
(token,)) (token,))
else: else:
prefix_len = self.config.anonymous_prefix_length prefix_len = self.config.anonymous_prefix_length
if not prefix_len: if not prefix_len:
raise HTTPError(httplib.FORBIDDEN) raise HTTPError(HTTPStatus.FORBIDDEN)
email = None email = None
prefix = self.newPrefix(prefix_len) prefix = self.newPrefix(prefix_len)
self.db.execute("UPDATE cert SET email = ? WHERE prefix = ?", self.db.execute("UPDATE cert SET email = ? WHERE prefix = ?",
...@@ -532,7 +542,7 @@ class RegistryServer(object): ...@@ -532,7 +542,7 @@ class RegistryServer(object):
if age < time.time() or not peers: if age < time.time() or not peers:
self.request_dump() self.request_dump()
peers = [prefix peers = [prefix
for neigh_routes in self.ctl.neighbours.itervalues() for neigh_routes in self.ctl.neighbours.values()
for prefix in neigh_routes[1] for prefix in neigh_routes[1]
if prefix] if prefix]
peers.append(self.prefix) peers.append(self.prefix)
...@@ -581,7 +591,7 @@ class RegistryServer(object): ...@@ -581,7 +591,7 @@ class RegistryServer(object):
def newHMAC(self, i, key=None): def newHMAC(self, i, key=None):
if key is None: if key is None:
key = buffer(os.urandom(16)) key = memoryview(os.urandom(16))
self.setConfig(BABEL_HMAC[i], key) self.setConfig(BABEL_HMAC[i], key)
def delHMAC(self, i): def delHMAC(self, i):
...@@ -606,8 +616,8 @@ class RegistryServer(object): ...@@ -606,8 +616,8 @@ class RegistryServer(object):
self.newHMAC(1) self.newHMAC(1)
self.newHMAC(2, '') self.newHMAC(2, '')
self.increaseVersion() self.increaseVersion()
self.setConfig('version', buffer(self.version)) self.setConfig('version', memoryview(self.version))
self.network_config['version'] = self.version.encode('base64') self.network_config['version'] = self.version
self.sendto(self.prefix, 0) self.sendto(self.prefix, 0)
@rpc_private @rpc_private
...@@ -615,7 +625,7 @@ class RegistryServer(object): ...@@ -615,7 +625,7 @@ class RegistryServer(object):
with self.lock, self.db: with self.lock, self.db:
try: try:
cert, = self.db.execute("SELECT cert FROM cert WHERE email = ?", cert, = self.db.execute("SELECT cert FROM cert WHERE email = ?",
(email,)).next() (email,)).fetchone()
except StopIteration: except StopIteration:
return return
certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert) certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert)
...@@ -636,7 +646,7 @@ class RegistryServer(object): ...@@ -636,7 +646,7 @@ class RegistryServer(object):
peer = utils.binFromSubnet(peer) peer = utils.binFromSubnet(peer)
with self.peers_lock: with self.peers_lock:
self.request_dump() self.request_dump()
for neigh_routes in self.ctl.neighbours.itervalues(): for neigh_routes in self.ctl.neighbours.values():
for prefix in neigh_routes[1]: for prefix in neigh_routes[1]:
if prefix == peer: if prefix == peer:
break break
...@@ -653,7 +663,7 @@ class RegistryServer(object): ...@@ -653,7 +663,7 @@ class RegistryServer(object):
with self.peers_lock: with self.peers_lock:
self.request_dump() self.request_dump()
peers = {prefix peers = {prefix
for neigh_routes in self.ctl.neighbours.itervalues() for neigh_routes in self.ctl.neighbours.values()
for prefix in neigh_routes[1] for prefix in neigh_routes[1]
if prefix} if prefix}
peers.add(self.prefix) peers.add(self.prefix)
...@@ -702,7 +712,7 @@ class RegistryServer(object): ...@@ -702,7 +712,7 @@ class RegistryServer(object):
self.sendto(utils.binFromSubnet(peers.popleft()), 5) self.sendto(utils.binFromSubnet(peers.popleft()), 5)
elif not r: elif not r:
break break
return json.dumps({k: list(v) for k, v in graph.iteritems()}) return json.dumps({k: list(v) for k, v in graph.items()})
class RegistryClient(object): class RegistryClient(object):
...@@ -713,10 +723,14 @@ class RegistryClient(object): ...@@ -713,10 +723,14 @@ class RegistryClient(object):
def __init__(self, url, cert=None, auto_close=True): def __init__(self, url, cert=None, auto_close=True):
self.cert = cert self.cert = cert
self.auto_close = auto_close self.auto_close = auto_close
scheme, host = splittype(url) u = urlparse(url)
host, path = splithost(host) scheme = u.scheme
self._conn = dict(http=httplib.HTTPConnection, host = u.hostname
https=httplib.HTTPSConnection, path = u. path
# scheme, host = splittype(url)
# host, path = splithost(host)
self._conn = dict(http=http.client.HTTPConnection,
https=http.client.HTTPSConnection,
)[scheme](unquote(host), timeout=60) )[scheme](unquote(host), timeout=60)
self._path = path.rstrip('/') self._path = path.rstrip('/')
...@@ -726,7 +740,8 @@ class RegistryClient(object): ...@@ -726,7 +740,8 @@ class RegistryClient(object):
kw = getcallargs(*args, **kw) kw = getcallargs(*args, **kw)
query = '/' + name query = '/' + name
if kw: if kw:
if any(type(v) is not str for v in kw.itervalues()): # accept bytes and str, because cert use bytes
if any(not isinstance(v, (str, bytes)) for v in kw.values()):
raise TypeError raise TypeError
query += '?' + urlencode(kw) query += '?' + urlencode(kw)
url = self._path + query url = self._path + query
...@@ -742,7 +757,9 @@ class RegistryClient(object): ...@@ -742,7 +757,9 @@ class RegistryClient(object):
n = len(h) // 2 n = len(h) // 2
self.cert.verify(h[n:], h[:n]) self.cert.verify(h[n:], h[:n])
key = self.cert.decrypt(h[:n]) key = self.cert.decrypt(h[:n])
h = hmac.HMAC(key, query, hashlib.sha1).digest() assert isinstance(query, str), True
assert isinstance(key, bytes), True
h = hmac.HMAC(key, query.encode(), hashlib.sha1).digest()
key = hashlib.sha1(key).digest() key = hashlib.sha1(key).digest()
self._hmac = hashlib.sha1(key).digest() self._hmac = hashlib.sha1(key).digest()
else: else:
...@@ -754,14 +771,14 @@ class RegistryClient(object): ...@@ -754,14 +771,14 @@ class RegistryClient(object):
self._conn.endheaders() self._conn.endheaders()
response = self._conn.getresponse() response = self._conn.getresponse()
body = response.read() body = response.read()
if response.status in (httplib.OK, httplib.NO_CONTENT): if response.status in (HTTPStatus.OK, HTTPStatus.NO_CONTENT):
if (not client_prefix or if (not client_prefix or
hmac.HMAC(key, body, hashlib.sha1).digest() == hmac.HMAC(key, body, hashlib.sha1).digest() ==
base64.b64decode(response.msg[HMAC_HEADER])): base64.b64decode(response.msg[HMAC_HEADER])):
if self.auto_close and name != 'hello': if self.auto_close and name != 'hello':
self._conn.close() self._conn.close()
return body return body
elif response.status == httplib.FORBIDDEN: elif response.status == HTTPStatus.FORBIDDEN:
# XXX: We should improve error handling, while making # XXX: We should improve error handling, while making
# sure re6st nodes don't crash on temporary errors. # sure re6st nodes don't crash on temporary errors.
# This is currently good enough for re6st-conf, to # This is currently good enough for re6st-conf, to
......
...@@ -26,7 +26,7 @@ def ap_prefix(name): ...@@ -26,7 +26,7 @@ def ap_prefix(name):
if a == IPCP_NAME: if a == IPCP_NAME:
return utils.binFromSubnet(b + '/' + c) return utils.binFromSubnet(b + '/' + c)
@apply # @apply
class ipcm(object): class ipcm(object):
def __call__(self, *args): def __call__(self, *args):
...@@ -54,7 +54,7 @@ class ipcm(object): ...@@ -54,7 +54,7 @@ class ipcm(object):
for x in r: for x in r:
logging.debug("%s", x) logging.debug("%s", x)
return r return r
except socket.error, e: except socket.error as e:
logging.info("RINA: %s", e) logging.info("RINA: %s", e)
del self._socket del self._socket
...@@ -255,7 +255,7 @@ class Shim(object): ...@@ -255,7 +255,7 @@ class Shim(object):
logging.debug("RINA: resolve(%r) -> %r", d, address) logging.debug("RINA: resolve(%r) -> %r", d, address)
s.send(struct.pack('=I', address)) s.send(struct.pack('=I', address))
continue continue
except Exception, e: except Exception as e:
logging.info("RINA: %s", e) logging.info("RINA: %s", e)
clients.remove(s) clients.remove(s)
s.close() s.close()
...@@ -296,7 +296,7 @@ if os.path.isdir("/sys/rina"): ...@@ -296,7 +296,7 @@ if os.path.isdir("/sys/rina"):
shim.update(tunnel_manager) shim.update(tunnel_manager)
return True return True
shim = None shim = None
except Exception, e: except Exception as e:
logging.info("RINA: %s", e) logging.info("RINA: %s", e)
return False return False
...@@ -304,5 +304,5 @@ def enabled(*args): ...@@ -304,5 +304,5 @@ def enabled(*args):
if shim: if shim:
try: try:
shim.enabled(*args) shim.enabled(*args)
except Exception, e: except Exception as e:
logging.info("RINA: %s", e) logging.info("RINA: %s", e)
import sys
import os import os
import random import random
import string import string
import json import json
import httplib from http import HTTPStatus
import base64 import base64
import unittest import unittest
import hmac import hmac
...@@ -11,7 +10,7 @@ import hashlib ...@@ -11,7 +10,7 @@ import hashlib
import time import time
from argparse import Namespace from argparse import Namespace
from OpenSSL import crypto from OpenSSL import crypto
from mock import Mock, patch from unittest.mock import Mock, patch
from re6st import registry from re6st import registry
from re6st import ctl from re6st import ctl
...@@ -73,6 +72,7 @@ class TestRegistryServer(unittest.TestCase): ...@@ -73,6 +72,7 @@ class TestRegistryServer(unittest.TestCase):
os.unlink(cls.config.db) os.unlink(cls.config.db)
except Exception: except Exception:
pass pass
pass
def setUp(self): def setUp(self):
self.email = ''.join(random.sample(string.ascii_lowercase, 4)) \ self.email = ''.join(random.sample(string.ascii_lowercase, 4)) \
...@@ -87,8 +87,12 @@ class TestRegistryServer(unittest.TestCase): ...@@ -87,8 +87,12 @@ class TestRegistryServer(unittest.TestCase):
self.assertIsInstance(self.server.version, bytes) self.assertIsInstance(self.server.version, bytes)
def test_recv(self): def test_recv(self):
recv = self.server.sock.recv = Mock() """mock the server sock and test recv function
recv.side_effect = [ Because socket.socket.recv is not modifiable, use Mock to sock
"""
back_sock = self.server.sock
sock = self.server.sock= Mock()
sock.recv.side_effect = [
"0001001001001a_msg", "0001001001001a_msg",
"0001001001002\0001dqdq", "0001001001002\0001dqdq",
"0001001001001\000a_msg", "0001001001001\000a_msg",
...@@ -106,7 +110,7 @@ class TestRegistryServer(unittest.TestCase): ...@@ -106,7 +110,7 @@ class TestRegistryServer(unittest.TestCase):
self.assertEqual(res3, (None, None)) # code don't match self.assertEqual(res3, (None, None)) # code don't match
self.assertEqual(res4, ("0001001001001", "a_msg")) self.assertEqual(res4, ("0001001001001", "a_msg"))
del self.server.sock.recv self.server.sock = back_sock
def test_onTimeout(self): def test_onTimeout(self):
# old token, cert, not old token, cert # old token, cert, not old token, cert
...@@ -150,11 +154,11 @@ class TestRegistryServer(unittest.TestCase): ...@@ -150,11 +154,11 @@ class TestRegistryServer(unittest.TestCase):
params = {"cn" : prefix, "a" : 1, "b" : 2} params = {"cn" : prefix, "a" : 1, "b" : 2}
func.getcallargs.return_value = params func.getcallargs.return_value = params
del func._private del func._private
func.return_value = result = "this_is_a_result" func.return_value = result = b"this_is_a_result"
key = "this_is_a_key" key = b"this_is_a_key"
self.server.sessions[prefix] = [(key, protocol)] self.server.sessions[prefix] = [(key, protocol)]
request = Mock() request = Mock()
request.path = "/func?a=1&b=2&cn=0000000011111111" request.path = b"/func?a=1&b=2&cn=0000000011111111"
request.headers = {registry.HMAC_HEADER: base64.b64encode( request.headers = {registry.HMAC_HEADER: base64.b64encode(
hmac.HMAC(key, request.path, hashlib.sha1).digest())} hmac.HMAC(key, request.path, hashlib.sha1).digest())}
...@@ -166,7 +170,7 @@ class TestRegistryServer(unittest.TestCase): ...@@ -166,7 +170,7 @@ class TestRegistryServer(unittest.TestCase):
[(hashlib.sha1(key).digest(), protocol)]) [(hashlib.sha1(key).digest(), protocol)])
func.assert_called_once_with(**params) func.assert_called_once_with(**params)
# http response check # http response check
request.send_response.assert_called_once_with(httplib.OK) request.send_response.assert_called_once_with(HTTPStatus.OK)
request.send_header.assert_any_call("Content-Length", str(len(result))) request.send_header.assert_any_call("Content-Length", str(len(result)))
request.send_header.assert_any_call( request.send_header.assert_any_call(
registry.HMAC_HEADER, registry.HMAC_HEADER,
...@@ -193,8 +197,8 @@ class TestRegistryServer(unittest.TestCase): ...@@ -193,8 +197,8 @@ class TestRegistryServer(unittest.TestCase):
self.server.handle_request(request_bad, method, params) self.server.handle_request(request_bad, method, params)
func.assert_called_once_with(**params) func.assert_called_once_with(**params)
request_bad.send_error.assert_called_once_with(httplib.FORBIDDEN) request_bad.send_error.assert_called_once_with(HTTPStatus.FORBIDDEN)
request_good.send_response.assert_called_once_with(httplib.NO_CONTENT) request_good.send_response.assert_called_once_with(HTTPStatus.NO_CONTENT)
# will cause valueError, if a node send hello twice to a registry # will cause valueError, if a node send hello twice to a registry
def test_getPeerProtocol(self): def test_getPeerProtocol(self):
...@@ -217,7 +221,7 @@ class TestRegistryServer(unittest.TestCase): ...@@ -217,7 +221,7 @@ class TestRegistryServer(unittest.TestCase):
res = self.server.hello(prefix, protocol=protocol) res = self.server.hello(prefix, protocol=protocol)
# decrypt # decrypt
length = len(res)/2 length = int(len(res)/2)
key, sign = res[:length], res[length:] key, sign = res[:length], res[length:]
key = decrypt(pkey, key) key = decrypt(pkey, key)
self.assertEqual(self.server.sessions[prefix][-1][0], key, self.assertEqual(self.server.sessions[prefix][-1][0], key,
...@@ -505,6 +509,7 @@ class TestRegistryServer(unittest.TestCase): ...@@ -505,6 +509,7 @@ class TestRegistryServer(unittest.TestCase):
del self.server.ctl.neighbours del self.server.ctl.neighbours
@unittest.skip(1)
@patch("select.select") @patch("select.select")
@patch("re6st.registry.RegistryServer.recv") @patch("re6st.registry.RegistryServer.recv")
@patch("re6st.registry.RegistryServer.sendto", Mock()) @patch("re6st.registry.RegistryServer.sendto", Mock())
...@@ -524,11 +529,18 @@ class TestRegistryServer(unittest.TestCase): ...@@ -524,11 +529,18 @@ class TestRegistryServer(unittest.TestCase):
select.side_effect = select_side_effect select.side_effect = select_side_effect
res = self.server.topology() res = self.server.topology()
res = json.loads(res)
print(res)
expect_res = {"36893488147419103232/80": ["0/16", "7/16"],
"": ["36893488147419103232/80", "3/16", "1/16", "0/16", "7/16"],
"4/16": ["0/16"],
"3/16": ["0/16", "7/16"],
"0/16": ["6/16", "7/16"],
"1/16": ["6/16", "0/16"],
"7/16": ["6/16", "4/16"]
}
expect_res = '{"36893488147419103232/80": ["0/16", "7/16"], ' \
'"": ["36893488147419103232/80", "3/16", "1/16", "0/16", "7/16"], ' \
'"4/16": ["0/16"], "3/16": ["0/16", "7/16"], "0/16": ["6/16", "7/16"], '\
'"1/16": ["6/16", "0/16"], "7/16": ["6/16", "4/16"]}'''
self.assertEqual(res, expect_res) self.assertEqual(res, expect_res)
......
import sys
import os
import unittest import unittest
import hmac import hmac
import httplib from http import HTTPStatus
import http.client
import base64 import base64
import hashlib import hashlib
from mock import Mock, patch from unittest.mock import Mock, patch
from re6st import registry from re6st import registry
...@@ -26,15 +25,15 @@ class TestRegistryClient(unittest.TestCase): ...@@ -26,15 +25,15 @@ class TestRegistryClient(unittest.TestCase):
self.assertEqual(client1._path, "/example") self.assertEqual(client1._path, "/example")
self.assertEqual(client1._conn.host, "localhost") self.assertEqual(client1._conn.host, "localhost")
self.assertIsInstance(client1._conn, httplib.HTTPSConnection) self.assertIsInstance(client1._conn, http.client.HTTPSConnection)
self.assertIsInstance(client2._conn, httplib.HTTPConnection) self.assertIsInstance(client2._conn, http.client.HTTPConnection)
def test_rpc_hello(self): def test_rpc_hello(self):
prefix = "0000000011111111" prefix = "0000000011111111"
protocol = "7" protocol = "7"
body = "a_hmac_key" body = "a_hmac_key"
query = "/hello?client_prefix=0000000011111111&protocol=7" query = "/hello?client_prefix=0000000011111111&protocol=7"
response = fakeResponse(body, httplib.OK) response = fakeResponse(body, HTTPStatus.OK)
self.client._conn.getresponse.return_value = response self.client._conn.getresponse.return_value = response
res = self.client.hello(prefix, protocol) res = self.client.hello(prefix, protocol)
...@@ -46,19 +45,19 @@ class TestRegistryClient(unittest.TestCase): ...@@ -46,19 +45,19 @@ class TestRegistryClient(unittest.TestCase):
conn.endheaders.assert_called_once() conn.endheaders.assert_called_once()
def test_rpc_with_cn(self): def test_rpc_with_cn(self):
query = "/getNetworkConfig?cn=0000000011111111" query = b"/getNetworkConfig?cn=0000000011111111"
cn = "0000000011111111" cn = "0000000011111111"
# hmac part # hmac part
self.client._hmac = None self.client._hmac = None
self.client.hello = Mock(return_value = "aaabbb") self.client.hello = Mock(return_value = "aaabbb")
self.client.cert = Mock() self.client.cert = Mock()
key = "this_is_a_key" key = b"this_is_a_key"
self.client.cert.decrypt.return_value = key self.client.cert.decrypt.return_value = key
h = hmac.HMAC(key, query, hashlib.sha1).digest() h = hmac.HMAC(key, query, hashlib.sha1).digest()
key = hashlib.sha1(key).digest() key = hashlib.sha1(key).digest()
# response part # response part
body = None body = None
response = fakeResponse(body, httplib.NO_CONTENT) response = fakeResponse(body, HTTPStatus.NO_CONTENT)
response.msg = dict(Re6stHMAC=hmac.HMAC(key, body, hashlib.sha1).digest()) response.msg = dict(Re6stHMAC=hmac.HMAC(key, body, hashlib.sha1).digest())
self.client._conn.getresponse.return_value = response self.client._conn.getresponse.return_value = response
...@@ -71,7 +70,6 @@ class TestRegistryClient(unittest.TestCase): ...@@ -71,7 +70,6 @@ class TestRegistryClient(unittest.TestCase):
conn.close.assert_called_once() conn.close.assert_called_once()
self.assertEqual(res, body) self.assertEqual(res, body)
class fakeResponse: class fakeResponse:
def __init__(self, body, status, reason = None): def __init__(self, body, status, reason = None):
......
...@@ -46,7 +46,7 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None): ...@@ -46,7 +46,7 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None):
cert.gmtime_adj_notBefore(0) cert.gmtime_adj_notBefore(0)
if not_after: if not_after:
cert.set_notAfter( cert.set_notAfter(
time.strftime("%Y%m%d%H%M%SZ", time.gmtime(not_after))) time.strftime("%Y%m%d%H%M%SZ", time.gmtime(not_after)).encode())
else: else:
cert.gmtime_adj_notAfter(registry.RegistryServer.cert_duration) cert.gmtime_adj_notAfter(registry.RegistryServer.cert_duration)
subject = req.get_subject() subject = req.get_subject()
...@@ -109,6 +109,8 @@ def serial2prefix(serial): ...@@ -109,6 +109,8 @@ def serial2prefix(serial):
# pkey: private key # pkey: private key
def decrypt(pkey, incontent): def decrypt(pkey, incontent):
if isinstance(pkey, bytes):
pkey = pkey.decode()
with open("node.key", 'w') as f: with open("node.key", 'w') as f:
f.write(pkey) f.write(pkey)
args = "openssl rsautl -decrypt -inkey node.key".split() args = "openssl rsautl -decrypt -inkey node.key".split()
......
...@@ -354,7 +354,7 @@ class BaseTunnelManager(object): ...@@ -354,7 +354,7 @@ class BaseTunnelManager(object):
def _sendto(self, to, msg, peer=None): def _sendto(self, to, msg, peer=None):
try: try:
r = self.sock.sendto(peer.encode(msg) if peer else msg, to) r = self.sock.sendto(peer.encode(msg) if peer else msg, to)
except socket.error, e: except socket.error as e:
(logging.info if e.errno == errno.ENETUNREACH else logging.error)( (logging.info if e.errno == errno.ENETUNREACH else logging.error)(
'Failed to send message to %s (%s)', to, e) 'Failed to send message to %s (%s)', to, e)
return return
...@@ -418,7 +418,7 @@ class BaseTunnelManager(object): ...@@ -418,7 +418,7 @@ class BaseTunnelManager(object):
serial = cert.get_serial_number() serial = cert.get_serial_number()
if serial in self.cache.crl: if serial in self.cache.crl:
raise ValueError("revoked") raise ValueError("revoked")
except (x509.VerifyError, ValueError), e: except (x509.VerifyError, ValueError) as e:
if retry: if retry:
return True return True
logging.debug('ignored invalid certificate from %r (%s)', logging.debug('ignored invalid certificate from %r (%s)',
...@@ -634,7 +634,7 @@ class BaseTunnelManager(object): ...@@ -634,7 +634,7 @@ class BaseTunnelManager(object):
with open('/proc/net/ipv6_route', "r", 4096) as f: with open('/proc/net/ipv6_route', "r", 4096) as f:
try: try:
routing_table = f.read() routing_table = f.read()
except IOError, e: except IOError as e:
# ???: If someone can explain why the kernel sometimes fails # ???: If someone can explain why the kernel sometimes fails
# even when there's a lot of free memory. # even when there's a lot of free memory.
if e.errno != errno.ENOMEM: if e.errno != errno.ENOMEM:
...@@ -1028,7 +1028,7 @@ class TunnelManager(BaseTunnelManager): ...@@ -1028,7 +1028,7 @@ class TunnelManager(BaseTunnelManager):
if c and c.time < float(time): if c and c.time < float(time):
try: try:
c.connected(serial) c.connected(serial)
except (KeyError, TypeError), e: except (KeyError, TypeError) as e:
logging.error("%s (route_up %s)", e, common_name) logging.error("%s (route_up %s)", e, common_name)
else: else:
logging.info("ignore route_up notification for %s %r", logging.info("ignore route_up notification for %s %r",
......
...@@ -8,7 +8,8 @@ import sys, textwrap, threading, time, traceback ...@@ -8,7 +8,8 @@ import sys, textwrap, threading, time, traceback
# relying on the GC for the closing of file descriptors.) # relying on the GC for the closing of file descriptors.)
socket.SOCK_CLOEXEC = 0x80000 socket.SOCK_CLOEXEC = 0x80000
HMAC_LEN = len(hashlib.sha1('').digest()) # HMAC_LEN = hashlib.sha1(b'').digest_szie
HMAC_LEN = len(hashlib.sha1(b'').digest())
class ReexecException(Exception): class ReexecException(Exception):
pass pass
...@@ -164,7 +165,7 @@ class Popen(subprocess.Popen): ...@@ -164,7 +165,7 @@ class Popen(subprocess.Popen):
self._args = tuple(args[0] if args else kw['args']) self._args = tuple(args[0] if args else kw['args'])
try: try:
super(Popen, self).__init__(*args, **kw) super(Popen, self).__init__(*args, **kw)
except OSError, e: except OSError as e:
if e.errno != errno.ENOMEM: if e.errno != errno.ENOMEM:
raise raise
self.returncode = -1 self.returncode = -1
...@@ -209,7 +210,7 @@ def select(R, W, T): ...@@ -209,7 +210,7 @@ def select(R, W, T):
def makedirs(*args): def makedirs(*args):
try: try:
os.makedirs(*args) os.makedirs(*args)
except OSError, e: except OSError as e:
if e.errno != errno.EEXIST: if e.errno != errno.EEXIST:
raise raise
...@@ -240,7 +241,7 @@ def parse_address(address_list): ...@@ -240,7 +241,7 @@ def parse_address(address_list):
a = address.split(',') a = address.split(',')
int(a[1]) # Check if port is an int int(a[1]) # Check if port is an int
yield tuple(a[:4]) yield tuple(a[:4])
except ValueError, e: except ValueError as e:
logging.warning("Failed to parse node address %r (%s)", logging.warning("Failed to parse node address %r (%s)",
address, e) address, e)
...@@ -261,21 +262,24 @@ newHmacSecret = newHmacSecret() ...@@ -261,21 +262,24 @@ newHmacSecret = newHmacSecret()
# - there's always a unique way to encode a value # - there's always a unique way to encode a value
# - the 3 first bits code the number of bytes # - the 3 first bits code the number of bytes
def packInteger(i): def packInteger(i:int):
for n in xrange(8): for n in range(8):
x = 32 << 8 * n x = 32 << 8 * n
if i < x: if i < x:
return struct.pack("!Q", i + n * x)[7-n:] return struct.pack("!Q", i + n * x)[7-n:]
i -= x i -= x
raise OverflowError raise OverflowError
def unpackInteger(x): def unpackInteger(x:bytes):
n = ord(x[0]) >> 5 if isinstance(x, str):
x = x.encode()
# ord need str, and b"ddd"[0] is int. so, use slice
n = ord(x[:1]) >> 5
try: try:
i, = struct.unpack("!Q", '\0' * (7 - n) + x[:n+1]) i, = struct.unpack("!Q", (b'\0' * (7 - n) + x[:n+1]))
except struct.error: except struct.error:
return return
return sum((32 << 8 * i for i in xrange(n)), return sum((32 << 8 * i for i in range(n)),
i - (n * 32 << 8 * n)), n + 1 i - (n * 32 << 8 * n)), n + 1
### ###
......
...@@ -36,4 +36,4 @@ protocol = 7 ...@@ -36,4 +36,4 @@ protocol = 7
min_protocol = 1 min_protocol = 1
if __name__ == "__main__": if __name__ == "__main__":
print version print(version)
...@@ -14,23 +14,27 @@ def subnetFromCert(cert): ...@@ -14,23 +14,27 @@ def subnetFromCert(cert):
return cert.get_subject().CN return cert.get_subject().CN
def notBefore(cert): def notBefore(cert):
return calendar.timegm(time.strptime(cert.get_notBefore(),'%Y%m%d%H%M%SZ')) return calendar.timegm(time.strptime(cert.get_notBefore().decode(),'%Y%m%d%H%M%SZ'))
def notAfter(cert): def notAfter(cert):
return calendar.timegm(time.strptime(cert.get_notAfter(),'%Y%m%d%H%M%SZ')) return calendar.timegm(time.strptime(cert.get_notAfter().decode(),'%Y%m%d%H%M%SZ'))
def openssl(*args): # add kwargs option for function encrypt need inheritable fd
def openssl(*args, **kwargs):
return utils.Popen(('openssl',) + args, return utils.Popen(('openssl',) + args,
stdin=subprocess.PIPE, stdin=subprocess.PIPE,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE) stderr=subprocess.PIPE,
**kwargs)
def encrypt(cert, data): def encrypt(cert, data):
r, w = os.pipe() r, w = os.pipe()
# https://peps.python.org/pep-0446/ Make newly created file descriptors non-inheritable
# so need pass fd by subprocess
try: try:
threading.Thread(target=os.write, args=(w, cert)).start() threading.Thread(target=os.write, args=(w, cert)).start()
p = openssl('rsautl', '-encrypt', '-certin', p = openssl('rsautl', '-encrypt', '-certin',
'-inkey', '/proc/self/fd/%u' % r) '-inkey', '/proc/self/fd/%u' % r, pass_fds=(r, w))
out, err = p.communicate(data) out, err = p.communicate(data)
finally: finally:
os.close(r) os.close(r)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment