Commit 50a5d1d3 authored by Vincent Pelletier's avatar Vincent Pelletier

http: Separate bind from netloc.

netloc is the public access point to a caucase instance.
bind is the private access point to a caucase instance, which may be
different (ex: NAT). Allow overriding netloc address with --bind.
As a consequence, add support for multiple binds: a netloc may resolve to
multiple addresses (ex: one IPv4, one global IPv6 and one Unique Local
Address).
As a further consequence, systematically disable automatic IPv4 binding
when binding to an IPv6 address.
Also, allow overriding netloc port with --base-port. The same port pair
will be used on all bound hosts.
Share SSL context between multiple https sockets.
To increase binding visibility, print bindings, and print when exiting.
parent b1e05975
...@@ -24,6 +24,7 @@ from collections import defaultdict ...@@ -24,6 +24,7 @@ from collections import defaultdict
import datetime import datetime
from getpass import getpass from getpass import getpass
import glob import glob
import itertools
import os import os
import socket import socket
from SocketServer import ThreadingMixIn from SocketServer import ThreadingMixIn
...@@ -77,11 +78,14 @@ class ThreadingWSGIServer(ThreadingMixIn, WSGIServer): ...@@ -77,11 +78,14 @@ class ThreadingWSGIServer(ThreadingMixIn, WSGIServer):
def __init__(self, server_address, *args, **kw): def __init__(self, server_address, *args, **kw):
self.address_family, _, _, _, _ = socket.getaddrinfo(*server_address)[0] self.address_family, _, _, _, _ = socket.getaddrinfo(*server_address)[0]
assert self.address_family in (socket.AF_INET, socket.AF_INET6), (
self.address_family,
)
WSGIServer.__init__(self, server_address, *args, **kw) WSGIServer.__init__(self, server_address, *args, **kw)
def server_bind(self):
if self.address_family == socket.AF_INET6:
# Separate IPv6 and IPv4 port spaces
self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
WSGIServer.server_bind(self)
class CaucaseWSGIRequestHandler(WSGIRequestHandler): class CaucaseWSGIRequestHandler(WSGIRequestHandler):
""" """
Make WSGIRequestHandler logging more apache-like. Make WSGIRequestHandler logging more apache-like.
...@@ -146,19 +150,17 @@ def startServerThread(server): ...@@ -146,19 +150,17 @@ def startServerThread(server):
server_thread.daemon = True server_thread.daemon = True
server_thread.start() server_thread.start()
def updateSSLContext( def getSSLContext(
https,
key_len, key_len,
threshold, threshold,
server_key_path, server_key_path,
hostname, hostname,
cau, cau,
cas, cas,
wrap=False,
): ):
""" """
Build a new SSLContext with updated CA certificates, CRL and server key pair, Build a new SSLContext with updated CA certificates, CRL and server key pair,
apply it to <https>.socket and return the datetime of next update. and return it along with the datetime of next update.
""" """
ssl_context = ssl.create_default_context( ssl_context = ssl.create_default_context(
purpose=ssl.Purpose.CLIENT_AUTH, purpose=ssl.Purpose.CLIENT_AUTH,
...@@ -258,18 +260,14 @@ def updateSSLContext( ...@@ -258,18 +260,14 @@ def updateSSLContext(
crt_file.write(new_key_pem) crt_file.write(new_key_pem)
crt_file.write(new_crt_pem) crt_file.write(new_crt_pem)
ssl_context.load_cert_chain(server_key_path) ssl_context.load_cert_chain(server_key_path)
if wrap: return (
https.socket = ssl_context.wrap_socket( ssl_context,
sock=https.socket, utils.load_certificate(
server_side=True,
)
else:
https.socket.context = ssl_context
return utils.load_certificate(
utils.getCert(server_key_path), utils.getCert(server_key_path),
cas_certificate_list, cas_certificate_list,
None, None,
).not_valid_after - threshold_delta ).not_valid_after - threshold_delta,
)
def main(argv=None, until=utils.until): def main(argv=None, until=utils.until):
""" """
...@@ -291,16 +289,29 @@ def main(argv=None, until=utils.until): ...@@ -291,16 +289,29 @@ def main(argv=None, until=utils.until):
parser.add_argument( parser.add_argument(
'--netloc', '--netloc',
required=True, required=True,
help='<host>[:<port>] of HTTP socket. ' help='<host>[:<port>] at which certificate verificators may reach this '
'HTTPS socket netloc will be deduced following caucase rules: if port is ' 'service. This value is embedded in generated certificates (as CRL '
'80 or not provided, https port will be 443, else it will be port + 1. ' 'distribution point, as CA certificate common name, possibly more). '
'If not provided, http port will be picked among available ports and ' 'See --base-port for how https port is derived from this port. '
'https port will be the next port. Also, signed certificates will not '
'contain a CRL distribution point URL. If https port is not available, '
'this program will exit with an aerror status. '
'Note on encoding: only ascii is currently supported. Non-ascii may be ' 'Note on encoding: only ascii is currently supported. Non-ascii may be '
'provided idna-encoded.', 'provided idna-encoded.',
) )
parser.add_argument(
'--base-port',
type=int,
help='Port at which caucase locally binds to provide HTTP service. '
'If this port is 80, HTTPS service is provided on port 443, otherwise '
'it is provided on --base-port + 1. '
'If derived HTTPS port is not available, caucase will exit with an error '
'status. default: --netloc\'s port, or 80',
)
parser.add_argument(
'--bind',
default=[],
action='append',
help='Address on which caucase locally binds. '
'default: addresses --netloc\'s <host> resolves into.',
)
parser.add_argument( parser.add_argument(
'--threshold', '--threshold',
default=31, default=31,
...@@ -392,9 +403,10 @@ def main(argv=None, until=utils.until): ...@@ -392,9 +403,10 @@ def main(argv=None, until=utils.until):
backup_group.add_argument( backup_group.add_argument(
'--backup-directory', '--backup-directory',
help='Backup directory path. Backups will be periodically stored in ' help='Backup directory path. Backups will be periodically stored in '
'given directory, encrypted with all certificates which are valid at the ' 'given directory, encrypted with all user certificates which are valid '
'time of backup generation. Any one of the associated private keys can ' 'at backup generation time. Any one of the associated private keys can '
'decypher it. If not set, no backup will be created.', 'decypher it. If not set or no user certificate exists, no backup will '
'be created.',
) )
backup_group.add_argument( backup_group.add_argument(
'--backup-period', '--backup-period',
...@@ -407,7 +419,8 @@ def main(argv=None, until=utils.until): ...@@ -407,7 +419,8 @@ def main(argv=None, until=utils.until):
base_url = u'http://' + args.netloc.decode('ascii') base_url = u'http://' + args.netloc.decode('ascii')
parsed_base_url = urlparse(base_url) parsed_base_url = urlparse(base_url)
hostname = parsed_base_url.hostname hostname = parsed_base_url.hostname
http_port = parsed_base_url.port http_port = parsed_base_url.port if args.base_port is None else args.base_port
https_port = 443 if http_port == 80 else http_port + 1
cau_crt_life_time = args.user_crt_validity cau_crt_life_time = args.user_crt_validity
cau = UserCertificateAuthority( cau = UserCertificateAuthority(
storage=SQLite3Storage( storage=SQLite3Storage(
...@@ -447,29 +460,64 @@ def main(argv=None, until=utils.until): ...@@ -447,29 +460,64 @@ def main(argv=None, until=utils.until):
lock_auto_sign_csr_amount=args.lock_auto_approve_count, lock_auto_sign_csr_amount=args.lock_auto_approve_count,
) )
application = Application(cau=cau, cas=cas) application = Application(cau=cau, cas=cas)
http = make_server( http_list = []
host=hostname, https_list = []
known_host_set = set()
for bind in args.bind or [hostname]:
for family, _, _, _, sockaddr in socket.getaddrinfo(
bind,
0,
socket.AF_UNSPEC,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
):
if family == socket.AF_INET:
host, _ = sockaddr
elif family == socket.AF_INET6:
host, _, _, _ = sockaddr
else:
continue
if host in known_host_set:
continue
known_host_set.add(host)
sys.stderr.write(
'Listening on [%s]:%i-%i\n' % (
host,
http_port,
https_port,
),
)
http_list.append(
make_server(
host=host,
port=http_port, port=http_port,
app=application, app=application,
server_class=ThreadingWSGIServer, server_class=ThreadingWSGIServer,
handler_class=CaucaseWSGIRequestHandler, handler_class=CaucaseWSGIRequestHandler,
),
) )
https = make_server( https_list.append(
host=hostname, make_server(
port=443 if http_port == 80 else http_port + 1, host=host,
port=https_port,
app=application, app=application,
server_class=ThreadingWSGIServer, server_class=ThreadingWSGIServer,
handler_class=CaucaseSSLWSGIRequestHandler, handler_class=CaucaseSSLWSGIRequestHandler,
),
) )
next_deadline = next_ssl_update = updateSSLContext( ssl_context, next_ssl_update = getSSLContext(
https=https,
key_len=args.key_len, key_len=args.key_len,
threshold=args.threshold, threshold=args.threshold,
server_key_path=args.server_key, server_key_path=args.server_key,
hostname=hostname, hostname=hostname,
cau=cau, cau=cau,
cas=cas, cas=cas,
wrap=True, )
next_deadline = next_ssl_update
for https in https_list:
https.socket = ssl_context.wrap_socket(
sock=https.socket,
server_side=True,
) )
if args.backup_directory: if args.backup_directory:
backup_period = datetime.timedelta(args.backup_period, 0) backup_period = datetime.timedelta(args.backup_period, 0)
...@@ -488,14 +536,13 @@ def main(argv=None, until=utils.until): ...@@ -488,14 +536,13 @@ def main(argv=None, until=utils.until):
) )
else: else:
next_backup = None next_backup = None
startServerThread(http) for server in itertools.chain(http_list, https_list):
startServerThread(https) startServerThread(server)
try: try:
while True: while True:
now = until(next_deadline) now = until(next_deadline)
if now >= next_ssl_update: if now >= next_ssl_update:
next_ssl_update = updateSSLContext( ssl_context, next_ssl_update = getSSLContext(
https=https,
key_len=args.key_len, key_len=args.key_len,
threshold=args.threshold, threshold=args.threshold,
server_key_path=args.server_key, server_key_path=args.server_key,
...@@ -503,6 +550,8 @@ def main(argv=None, until=utils.until): ...@@ -503,6 +550,8 @@ def main(argv=None, until=utils.until):
cau=cau, cau=cau,
cas=cas, cas=cas,
) )
for https in https_list:
https.socket.context = ssl_context
if next_backup is None: if next_backup is None:
next_deadline = next_ssl_update next_deadline = next_ssl_update
else: else:
...@@ -529,8 +578,9 @@ def main(argv=None, until=utils.until): ...@@ -529,8 +578,9 @@ def main(argv=None, until=utils.until):
except utils.SleepInterrupt: except utils.SleepInterrupt:
pass pass
finally: finally:
https.shutdown() sys.stderr.write('Exiting\n')
http.shutdown() for server in itertools.chain(http_list, https_list):
server.shutdown()
def manage(argv=None): def manage(argv=None):
""" """
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment