Commit c61cab22 authored by Julien Muchembled's avatar Julien Muchembled

Implement automatic renewal of client certificate

parent e24eb3f5
#!/usr/bin/python #!/usr/bin/python
import argparse, atexit, errno, os, subprocess, sqlite3, sys import argparse, atexit, errno, os, subprocess, sqlite3, sys, time
from OpenSSL import crypto from OpenSSL import crypto
from re6st import registry, utils from re6st import registry, utils
...@@ -139,7 +139,12 @@ def main(): ...@@ -139,7 +139,12 @@ def main():
os.ftruncate(cert_fd, len(cert)) os.ftruncate(cert_fd, len(cert))
os.close(cert_fd) os.close(cert_fd)
print "Certificate setup complete." cert = loadCert(cert)
not_after = utils.notAfter(cert)
print("Setup complete. Certificate is valid until %s UTC"
" and will be automatically renewed after %s UTC" % (
time.asctime(time.gmtime(not_after)),
time.asctime(time.gmtime(not_after - registry.RENEW_PERIOD))))
if not os.path.lexists(conf_path): if not os.path.lexists(conf_path):
create(conf_path, """\ create(conf_path, """\
...@@ -160,7 +165,7 @@ dh %s ...@@ -160,7 +165,7 @@ dh %s
""" % (config.registry, ca_path, cert_path, key_path, dh_path)) """ % (config.registry, ca_path, cert_path, key_path, dh_path))
print "Sample configuration file created." print "Sample configuration file created."
cn = utils.subnetFromCert(loadCert(cert)) cn = utils.subnetFromCert(cert)
subnet = network + utils.binFromSubnet(cn) subnet = network + utils.binFromSubnet(cn)
print "Your subnet: %s/%u (CN=%s)" \ print "Your subnet: %s/%u (CN=%s)" \
% (utils.ipFromBin(subnet), len(subnet), cn) % (utils.ipFromBin(subnet), len(subnet), cn)
......
...@@ -8,6 +8,7 @@ from urllib import splittype, splithost, splitport, urlencode ...@@ -8,6 +8,7 @@ from urllib import splittype, splithost, splitport, urlencode
from . import tunnel, utils from . import tunnel, utils
HMAC_HEADER = "Re6stHMAC" HMAC_HEADER = "Re6stHMAC"
RENEW_PERIOD = 30 * 86400
class getcallargs(type): class getcallargs(type):
...@@ -190,29 +191,39 @@ class RegistryServer(object): ...@@ -190,29 +191,39 @@ class RegistryServer(object):
(token,)).next() (token,)).next()
except StopIteration: except StopIteration:
return return
self.db.execute("DELETE FROM token WHERE token = ?", (token,)) self.db.execute("DELETE FROM token WHERE token = ?",
(token,))
# Get a new prefix
prefix = self._getPrefix(prefix_len) prefix = self._getPrefix(prefix_len)
self.db.execute("UPDATE cert SET email = ? WHERE prefix = ?",
(email, prefix))
return self._createCertificate(prefix, req.get_subject(),
req.get_pubkey())
# Create certificate def _createCertificate(self, client_prefix, subject, pubkey):
cert = crypto.X509() cert = crypto.X509()
cert.set_serial_number(0) # required for libssl < 1.0 cert.set_serial_number(0) # required for libssl < 1.0
cert.gmtime_adj_notBefore(0) cert.gmtime_adj_notBefore(0)
cert.gmtime_adj_notAfter(self.cert_duration) cert.gmtime_adj_notAfter(self.cert_duration)
cert.set_issuer(self.ca.get_subject()) cert.set_issuer(self.ca.get_subject())
subject = req.get_subject() subject.CN = "%u/%u" % (int(client_prefix, 2), len(client_prefix))
subject.CN = "%u/%u" % (int(prefix, 2), prefix_len)
cert.set_subject(subject) cert.set_subject(subject)
cert.set_pubkey(req.get_pubkey()) cert.set_pubkey(pubkey)
cert.sign(self.key, 'sha1') cert.sign(self.key, 'sha1')
cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cert) cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
self.db.execute("UPDATE cert SET cert = ? WHERE prefix = ?",
# Insert certificate into db (cert, client_prefix))
self.db.execute("UPDATE cert SET email = ?, cert = ? WHERE prefix = ?", (email, cert, prefix))
return cert return cert
def renewCertificate(self, cn):
with self.lock:
with self.db:
pem = self._getCert(cn)
cert = crypto.load_certificate(crypto.FILETYPE_PEM, pem)
if utils.notAfter(cert) - RENEW_PERIOD < time.time():
pem = self._createCertificate(cn, cert.get_subject(),
cert.get_pubkey())
return pem
def getCa(self): def getCa(self):
return crypto.dump_certificate(crypto.FILETYPE_PEM, self.ca) return crypto.dump_certificate(crypto.FILETYPE_PEM, self.ca)
......
import argparse, errno, logging, os, shlex, signal, socket import argparse, calendar, errno, logging, os, shlex, signal, socket
import struct, subprocess, textwrap, threading, time import struct, subprocess, textwrap, threading, time
logging_levels = logging.WARNING, logging.INFO, logging.DEBUG, 5 logging_levels = logging.WARNING, logging.INFO, logging.DEBUG, 5
...@@ -132,6 +132,9 @@ def networkFromCa(ca): ...@@ -132,6 +132,9 @@ def networkFromCa(ca):
def subnetFromCert(cert): def subnetFromCert(cert):
return cert.get_subject().CN return cert.get_subject().CN
def notAfter(cert):
return calendar.timegm(time.strptime(cert.get_notAfter(),'%Y%m%d%H%M%SZ'))
def dump_address(address): def dump_address(address):
return ';'.join(map(','.join, address)) return ';'.join(map(','.join, address))
......
...@@ -4,8 +4,10 @@ import sqlite3, subprocess, sys, time, traceback ...@@ -4,8 +4,10 @@ import sqlite3, subprocess, sys, time, traceback
from collections import deque from collections import deque
from OpenSSL import crypto from OpenSSL import crypto
from re6st import db, plib, tunnel, utils from re6st import db, plib, tunnel, utils
from re6st.registry import RegistryClient from re6st.registry import RegistryClient, RENEW_PERIOD
class ReexecException(Exception):
pass
def getConfig(): def getConfig():
parser = utils.ArgParser(fromfile_prefix_chars='@', parser = utils.ArgParser(fromfile_prefix_chars='@',
...@@ -112,6 +114,34 @@ def getConfig(): ...@@ -112,6 +114,34 @@ def getConfig():
return parser.parse_args() return parser.parse_args()
def maybe_renew(path, cert, info, renew):
while True:
next_renew = utils.notAfter(cert) - RENEW_PERIOD
if time.time() < next_renew:
return cert, next_renew
try:
pem = renew()
if not pem or pem == crypto.dump_certificate(
crypto.FILETYPE_PEM, cert):
exc_info = 0
break
cert = crypto.load_certificate(crypto.FILETYPE_PEM, pem)
except Exception:
exc_info = 1
break
new_path = path + '.new'
with open(new_path, 'w') as f:
f.write(pem)
os.rename(new_path, path)
logging.info("%s renewed until %s UTC",
info, time.asctime(time.gmtime(utils.notAfter(cert))))
logging.error("%s not renewed. Will retry tomorrow.",
info, exc_info=exc_info)
return cert, time.time() + 86400
def exit(status):
exit.status = status
os.kill(os.getpid(), signal.SIGTERM)
def main(): def main():
# Get arguments # Get arguments
...@@ -142,7 +172,15 @@ def main(): ...@@ -142,7 +172,15 @@ def main():
plib.ovpn_log = config.log plib.ovpn_log = config.log
signal.signal(signal.SIGHUP, lambda *args: sys.exit(-1)) signal.signal(signal.SIGHUP, lambda *args: sys.exit(-1))
signal.signal(signal.SIGTERM, lambda *args: sys.exit()) signal.signal(signal.SIGTERM, lambda *args:
sys.exit(getattr(exit, 'status', None)))
registry = RegistryClient(config.registry, config.key, ca)
cert, next_renew = maybe_renew(config.cert, cert, "Certificate",
lambda: registry.renewCertificate(prefix))
ca, ca_renew = maybe_renew(config.ca, ca, "CA Certificate", registry.getCa)
if next_renew > ca_renew:
next_renew = ca_renew
if config.max_clients is None: if config.max_clients is None:
config.max_clients = config.client_count * 2 config.max_clients = config.client_count * 2
...@@ -232,7 +270,6 @@ def main(): ...@@ -232,7 +270,6 @@ def main():
# Create and open read_only pipe to get server events # Create and open read_only pipe to get server events
r_pipe, write_pipe = os.pipe() r_pipe, write_pipe = os.pipe()
read_pipe = os.fdopen(r_pipe) read_pipe = os.fdopen(r_pipe)
registry = RegistryClient(config.registry, config.key, ca)
peer_db = db.PeerDB(db_path, registry, config.key, prefix) peer_db = db.PeerDB(db_path, registry, config.key, prefix)
tunnel_manager = tunnel.TunnelManager(write_pipe, peer_db, tunnel_manager = tunnel.TunnelManager(write_pipe, peer_db,
config.openvpn_args, timeout, config.tunnel_refresh, config.openvpn_args, timeout, config.tunnel_refresh,
...@@ -313,7 +350,12 @@ def main(): ...@@ -313,7 +350,12 @@ def main():
# main loop # main loop
if tunnel_manager is None: if tunnel_manager is None:
sys.exit(os.WEXITSTATUS(os.wait()[1])) t = threading.Thread(target=lambda:
exit(os.WEXITSTATUS(os.wait()[1])))
t.daemon = True
t.start()
time.sleep(max(0, next_renew - time.time()))
raise ReexecException("Restart to renew certificate")
cleanup += tunnel_manager.delInterfaces, tunnel_manager.killAll cleanup += tunnel_manager.delInterfaces, tunnel_manager.killAll
while True: while True:
next = tunnel_manager.next_refresh next = tunnel_manager.next_refresh
...@@ -333,6 +375,8 @@ def main(): ...@@ -333,6 +375,8 @@ def main():
t = time.time() t = time.time()
if t >= tunnel_manager.next_refresh: if t >= tunnel_manager.next_refresh:
tunnel_manager.refresh() tunnel_manager.refresh()
if t >= next_renew:
raise ReexecException("Restart to renew certificate")
if forwarder and t >= forwarder.next_refresh: if forwarder and t >= forwarder.next_refresh:
forwarder.refresh() forwarder.refresh()
finally: finally:
...@@ -344,16 +388,18 @@ def main(): ...@@ -344,16 +388,18 @@ def main():
except sqlite3.Error: except sqlite3.Error:
logging.exception("Restarting with empty cache") logging.exception("Restarting with empty cache")
os.rename(db_path, db_path + '.bak') os.rename(db_path, db_path + '.bak')
try: except ReexecException, e:
sys.exitfunc() logging.info(e)
finally:
os.execvp(sys.argv[0], sys.argv)
except KeyboardInterrupt: except KeyboardInterrupt:
return 0 return 0
except Exception: except Exception:
f = traceback.format_exception(*sys.exc_info()) f = traceback.format_exception(*sys.exc_info())
logging.error('%s%s', f.pop(), ''.join(f)) logging.error('%s%s', f.pop(), ''.join(f))
sys.exit(1) sys.exit(1)
try:
sys.exitfunc()
finally:
os.execvp(sys.argv[0], sys.argv)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment