Commit 8ce08bf9 authored by Vincent Pelletier's avatar Vincent Pelletier Committed by Vincent Pelletier

all: More python3 adaptations.

What was not picked up by 2to3.
parent 7f9e56cf
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
Caucase - Certificate Authority for Users, Certificate Authority for SErvices Caucase - Certificate Authority for Users, Certificate Authority for SErvices
""" """
from __future__ import absolute_import from __future__ import absolute_import
from binascii import hexlify, unhexlify
import datetime import datetime
import json import json
import os import os
...@@ -55,7 +56,7 @@ _SUBJECT_OID_DICT = { ...@@ -55,7 +56,7 @@ _SUBJECT_OID_DICT = {
'GN': x509.oid.NameOID.GIVEN_NAME, 'GN': x509.oid.NameOID.GIVEN_NAME,
# pylint: enable=bad-whitespace # pylint: enable=bad-whitespace
} }
_BACKUP_MAGIC = 'caucase\0' _BACKUP_MAGIC = b'caucase\0'
_CONFIG_NAME_AUTO_SIGN_CSR_AMOUNT = 'auto_sign_csr_amount' _CONFIG_NAME_AUTO_SIGN_CSR_AMOUNT = 'auto_sign_csr_amount'
def Extension(value, critical): def Extension(value, critical):
...@@ -227,9 +228,9 @@ class CertificateAuthority(object): ...@@ -227,9 +228,9 @@ class CertificateAuthority(object):
# Note: requested_amount is None when a known CSR is re-submitted # Note: requested_amount is None when a known CSR is re-submitted
csr_id, requested_amount = self._storage.appendCertificateSigningRequest( csr_id, requested_amount = self._storage.appendCertificateSigningRequest(
csr_pem=csr_pem, csr_pem=csr_pem,
key_id=x509.SubjectKeyIdentifier.from_public_key( key_id=hexlify(x509.SubjectKeyIdentifier.from_public_key(
csr.public_key(), csr.public_key(),
).digest.encode('hex'), ).digest),
override_limits=override_limits, override_limits=override_limits,
) )
if requested_amount is not None and \ if requested_amount is not None and \
...@@ -632,8 +633,8 @@ class CertificateAuthority(object): ...@@ -632,8 +633,8 @@ class CertificateAuthority(object):
current_crt_pem = utils.dump_certificate(key_pair['crt']) current_crt_pem = utils.dump_certificate(key_pair['crt'])
result.append(utils.wrap( result.append(utils.wrap(
{ {
'old_pem': previous_crt_pem, 'old_pem': utils.toUnicode(previous_crt_pem),
'new_pem': current_crt_pem, 'new_pem': utils.toUnicode(current_crt_pem),
}, },
previous_key, previous_key,
self.digest_list[0], self.digest_list[0],
...@@ -799,31 +800,31 @@ class UserCertificateAuthority(CertificateAuthority): ...@@ -799,31 +800,31 @@ class UserCertificateAuthority(CertificateAuthority):
continue continue
public_key = crt.public_key() public_key = crt.public_key()
key_list.append({ key_list.append({
'id': x509.SubjectKeyIdentifier.from_public_key( 'id': utils.toUnicode(hexlify(
public_key, x509.SubjectKeyIdentifier.from_public_key(public_key).digest,
).digest.encode('hex'), )),
'cipher': { 'cipher': {
'name': 'rsa_oaep_sha1_mgf1_sha1', 'name': 'rsa_oaep_sha1_mgf1_sha1',
}, },
'key': public_key.encrypt( 'key': utils.toUnicode(hexlify(public_key.encrypt(
signing_key + symetric_key, signing_key + symetric_key,
OAEP( OAEP(
mgf=MGF1(algorithm=hashes.SHA1()), mgf=MGF1(algorithm=hashes.SHA1()),
algorithm=hashes.SHA1(), algorithm=hashes.SHA1(),
label=None, label=None,
), ),
).encode('hex'), ))),
}) })
if not key_list: if not key_list:
# No users yet, backup is meaningless # No users yet, backup is meaningless
return False return False
header = json.dumps({ header = utils.toBytes(json.dumps({
'cipher': { 'cipher': {
'name': 'aes256_cbc_pkcs7_hmac_10M_sha256', 'name': 'aes256_cbc_pkcs7_hmac_10M_sha256',
'parameter': iv.encode('hex'), 'parameter': utils.toUnicode(hexlify(iv)),
}, },
'key_list': key_list, 'key_list': key_list,
}) }))
padder = padding.PKCS7(128).padder() padder = padding.PKCS7(128).padder()
write(_BACKUP_MAGIC) write(_BACKUP_MAGIC)
write(struct.pack('<I', len(header))) write(struct.pack('<I', len(header)))
...@@ -877,11 +878,11 @@ class UserCertificateAuthority(CertificateAuthority): ...@@ -877,11 +878,11 @@ class UserCertificateAuthority(CertificateAuthority):
if header['cipher']['name'] != 'aes256_cbc_pkcs7_hmac_10M_sha256': if header['cipher']['name'] != 'aes256_cbc_pkcs7_hmac_10M_sha256':
raise ValueError('Unrecognised symetric cipher') raise ValueError('Unrecognised symetric cipher')
private_key = utils.load_privatekey(key_pem) private_key = utils.load_privatekey(key_pem)
key_id = x509.SubjectKeyIdentifier.from_public_key( key_id = hexlify(x509.SubjectKeyIdentifier.from_public_key(
private_key.public_key(), private_key.public_key(),
).digest.encode('hex') ).digest)
symetric_key_list = [ symetric_key_list = [
x for x in header['key_list'] if x['id'] == key_id x for x in header['key_list'] if utils.toBytes(x['id']) == key_id
] ]
if not symetric_key_list: if not symetric_key_list:
raise ValueError( raise ValueError(
...@@ -891,7 +892,7 @@ class UserCertificateAuthority(CertificateAuthority): ...@@ -891,7 +892,7 @@ class UserCertificateAuthority(CertificateAuthority):
if symetric_key_entry['cipher']['name'] != 'rsa_oaep_sha1_mgf1_sha1': if symetric_key_entry['cipher']['name'] != 'rsa_oaep_sha1_mgf1_sha1':
raise ValueError('Unrecognised asymetric cipher') raise ValueError('Unrecognised asymetric cipher')
both_keys = private_key.decrypt( both_keys = private_key.decrypt(
symetric_key_entry['key'].decode('hex'), unhexlify(symetric_key_entry['key']),
OAEP( OAEP(
mgf=MGF1(algorithm=hashes.SHA1()), mgf=MGF1(algorithm=hashes.SHA1()),
algorithm=hashes.SHA1(), algorithm=hashes.SHA1(),
...@@ -902,7 +903,7 @@ class UserCertificateAuthority(CertificateAuthority): ...@@ -902,7 +903,7 @@ class UserCertificateAuthority(CertificateAuthority):
raise ValueError('Invalid key length') raise ValueError('Invalid key length')
decryptor = Cipher( decryptor = Cipher(
algorithms.AES(both_keys[32:]), algorithms.AES(both_keys[32:]),
modes.CBC(header['cipher']['parameter'].decode('hex')), modes.CBC(unhexlify(header['cipher']['parameter'])),
backend=_cryptography_backend, backend=_cryptography_backend,
).decryptor() ).decryptor()
unpadder = padding.PKCS7(128).unpadder() unpadder = padding.PKCS7(128).unpadder()
......
...@@ -20,6 +20,7 @@ Caucase - Certificate Authority for Users, Certificate Authority for SErvices ...@@ -20,6 +20,7 @@ Caucase - Certificate Authority for Users, Certificate Authority for SErvices
""" """
from __future__ import absolute_import, print_function from __future__ import absolute_import, print_function
import argparse import argparse
from binascii import hexlify
import datetime import datetime
import httplib import httplib
import json import json
...@@ -102,7 +103,7 @@ class CLICaucaseClient(object): ...@@ -102,7 +103,7 @@ class CLICaucaseClient(object):
""" """
for csr_id, csr_path in csr_id_path_list: for csr_id, csr_path in csr_id_path_list:
csr_pem = self._client.getCertificateSigningRequest(int(csr_id)) csr_pem = self._client.getCertificateSigningRequest(int(csr_id))
with open(csr_path, 'a') as csr_file: with open(csr_path, 'ab') as csr_file:
csr_file.write(csr_pem) csr_file.write(csr_pem)
def getCRT(self, warning, error, crt_id_path_list, ca_list): def getCRT(self, warning, error, crt_id_path_list, ca_list):
...@@ -157,7 +158,7 @@ class CLICaucaseClient(object): ...@@ -157,7 +158,7 @@ class CLICaucaseClient(object):
) )
error = True error = True
continue continue
with open(crt_path, 'a') as crt_file: with open(crt_path, 'ab') as crt_file:
crt_file.write(crt_pem) crt_file.write(crt_pem)
return warning, error return warning, error
...@@ -228,11 +229,17 @@ class CLICaucaseClient(object): ...@@ -228,11 +229,17 @@ class CLICaucaseClient(object):
key_len=key_len, key_len=key_len,
) )
if key_path is None: if key_path is None:
with open(crt_path, 'w') as crt_file: with open(crt_path, 'wb') as crt_file:
crt_file.write(new_key_pem) crt_file.write(new_key_pem)
crt_file.write(new_crt_pem) crt_file.write(new_crt_pem)
else: else:
with open(crt_path, 'w') as crt_file, open(key_path, 'w') as key_file: with open(
crt_path,
'wb',
) as crt_file, open(
key_path,
'wb',
) as key_file:
key_file.write(new_key_pem) key_file.write(new_key_pem)
crt_file.write(new_crt_pem) crt_file.write(new_crt_pem)
updated = True updated = True
...@@ -250,7 +257,7 @@ class CLICaucaseClient(object): ...@@ -250,7 +257,7 @@ class CLICaucaseClient(object):
), ),
) )
for entry in self._client.getPendingCertificateRequestList(): for entry in self._client.getPendingCertificateRequestList():
csr = utils.load_certificate_request(entry['csr']) csr = utils.load_certificate_request(utils.toBytes(entry['csr']))
print( print(
'%20s | %r' % ( '%20s | %r' % (
entry['id'], entry['id'],
...@@ -264,7 +271,7 @@ class CLICaucaseClient(object): ...@@ -264,7 +271,7 @@ class CLICaucaseClient(object):
--sign-csr --sign-csr
""" """
for csr_id in csr_id_list: for csr_id in csr_id_list:
self._client.createCertificate(int(csr_id)) self._client.createCertificate(int(utils.toUnicode(csr_id)))
def signCSRWith(self, csr_id_path_list): def signCSRWith(self, csr_id_path_list):
""" """
...@@ -272,7 +279,7 @@ class CLICaucaseClient(object): ...@@ -272,7 +279,7 @@ class CLICaucaseClient(object):
""" """
for csr_id, csr_path in csr_id_path_list: for csr_id, csr_path in csr_id_path_list:
self._client.createCertificate( self._client.createCertificate(
int(csr_id), int(utils.toUnicode(csr_id)),
template_csr=utils.getCertRequest(csr_path), template_csr=utils.getCertRequest(csr_path),
) )
...@@ -763,7 +770,7 @@ def updater(argv=None, until=utils.until): ...@@ -763,7 +770,7 @@ def updater(argv=None, until=utils.until):
# Still here ? Ok, wait a bit and try again. # Still here ? Ok, wait a bit and try again.
until(datetime.datetime.utcnow() + datetime.timedelta(0, 60)) until(datetime.datetime.utcnow() + datetime.timedelta(0, 60))
else: else:
with open(args.crt, 'a') as crt_file: with open(args.crt, 'ab') as crt_file:
crt_file.write(crt_pem) crt_file.write(crt_pem)
updated = True updated = True
break break
...@@ -797,9 +804,10 @@ def updater(argv=None, until=utils.until): ...@@ -797,9 +804,10 @@ def updater(argv=None, until=utils.until):
if RetryingCaucaseClient.updateCRLFile(ca_url, args.crl, ca_crt_list): if RetryingCaucaseClient.updateCRLFile(ca_url, args.crl, ca_crt_list):
print('Got new CRL') print('Got new CRL')
updated = True updated = True
with open(args.crl, 'rb') as crl_file:
next_deadline = min( next_deadline = min(
next_deadline, next_deadline,
utils.load_crl(open(args.crl).read(), ca_crt_list).next_update, utils.load_crl(crli_file.read(), ca_crt_list).next_update,
) )
if args.crt: if args.crt:
crt_pem, key_pem, key_path = utils.getKeyPair(args.crt, args.key) crt_pem, key_pem, key_path = utils.getKeyPair(args.crt, args.key)
...@@ -812,16 +820,16 @@ def updater(argv=None, until=utils.until): ...@@ -812,16 +820,16 @@ def updater(argv=None, until=utils.until):
key_len=args.key_len, key_len=args.key_len,
) )
if key_path is None: if key_path is None:
with open(args.crt, 'w') as crt_file: with open(args.crt, 'wb') as crt_file:
crt_file.write(new_key_pem) crt_file.write(new_key_pem)
crt_file.write(new_crt_pem) crt_file.write(new_crt_pem)
else: else:
with open( with open(
args.crt, args.crt,
'w', 'wb',
) as crt_file, open( ) as crt_file, open(
key_path, key_path,
'w', 'wb',
) as key_file: ) as key_file:
key_file.write(new_key_pem) key_file.write(new_key_pem)
crt_file.write(new_crt_pem) crt_file.write(new_crt_pem)
...@@ -894,11 +902,11 @@ def rerequest(argv=None): ...@@ -894,11 +902,11 @@ def rerequest(argv=None):
key_pem = utils.dump_privatekey(key) key_pem = utils.dump_privatekey(key)
orig_umask = os.umask(0o177) orig_umask = os.umask(0o177)
try: try:
with open(args.key, 'w') as key_file: with open(args.key, 'wb') as key_file:
key_file.write(key_pem) key_file.write(key_pem)
finally: finally:
os.umask(orig_umask) os.umask(orig_umask)
with open(args.csr, 'w') as csr_file: with open(args.csr, 'wb') as csr_file:
csr_file.write(csr_pem) csr_file.write(csr_pem)
def key_id(argv=None): def key_id(argv=None):
...@@ -926,17 +934,20 @@ def key_id(argv=None): ...@@ -926,17 +934,20 @@ def key_id(argv=None):
) )
args = parser.parse_args(argv) args = parser.parse_args(argv)
for key_path in args.private_key: for key_path in args.private_key:
with open(key_path, 'rb') as key_file:
print( print(
key_path, key_path,
utils.toUnicode(hexlify(
x509.SubjectKeyIdentifier.from_public_key( x509.SubjectKeyIdentifier.from_public_key(
utils.load_privatekey(open(key_path).read()).public_key(), utils.load_privatekey(key_file.read()).public_key(),
).digest.encode('hex'), ).digest,
)),
) )
for backup_path in args.backup: for backup_path in args.backup:
print(backup_path) print(backup_path)
with open(backup_path) as backup_file: with open(backup_path, 'rb') as backup_file:
magic = backup_file.read(8) magic = backup_file.read(8)
if magic != 'caucase\0': if magic != b'caucase\0':
raise ValueError('Invalid backup magic string') raise ValueError('Invalid backup magic string')
header_len, = struct.unpack( header_len, = struct.unpack(
'<I', '<I',
......
...@@ -69,7 +69,7 @@ class CaucaseClient(object): ...@@ -69,7 +69,7 @@ class CaucaseClient(object):
""" """
if not os.path.exists(ca_crt_path): if not os.path.exists(ca_crt_path):
ca_pem = cls(ca_url=url).getCACertificate() ca_pem = cls(ca_url=url).getCACertificate()
with open(ca_crt_path, 'w') as ca_crt_file: with open(ca_crt_path, 'wb') as ca_crt_file:
ca_crt_file.write(ca_pem) ca_crt_file.write(ca_pem)
updated = True updated = True
else: else:
...@@ -85,8 +85,8 @@ class CaucaseClient(object): ...@@ -85,8 +85,8 @@ class CaucaseClient(object):
cls(ca_url=url, ca_crt_pem_list=ca_pem_list).getCACertificateChain(), cls(ca_url=url, ca_crt_pem_list=ca_pem_list).getCACertificateChain(),
) )
if ca_pem_list != loaded_ca_pem_list: if ca_pem_list != loaded_ca_pem_list:
data = ''.join(ca_pem_list) data = b''.join(ca_pem_list)
with open(ca_crt_path, 'w') as ca_crt_file: with open(ca_crt_path, 'wb') as ca_crt_file:
ca_crt_file.write(data) ca_crt_file.write(data)
updated = True updated = True
return updated return updated
...@@ -107,13 +107,13 @@ class CaucaseClient(object): ...@@ -107,13 +107,13 @@ class CaucaseClient(object):
Return whether an update happened. Return whether an update happened.
""" """
if os.path.exists(crl_path): if os.path.exists(crl_path):
my_crl = utils.load_crl(open(crl_path).read(), ca_list) my_crl = utils.load_crl(open(crl_path, 'rb').read(), ca_list)
else: else:
my_crl = None my_crl = None
latest_crl_pem = cls(ca_url=url).getCertificateRevocationList() latest_crl_pem = cls(ca_url=url).getCertificateRevocationList()
latest_crl = utils.load_crl(latest_crl_pem, ca_list) latest_crl = utils.load_crl(latest_crl_pem, ca_list)
if my_crl is None or latest_crl.signature != my_crl.signature: if my_crl is None or latest_crl.signature != my_crl.signature:
with open(crl_path, 'w') as crl_file: with open(crl_path, 'wb') as crl_file:
crl_file.write(latest_crl_pem) crl_file.write(latest_crl_pem)
return True return True
return False return False
...@@ -138,7 +138,11 @@ class CaucaseClient(object): ...@@ -138,7 +138,11 @@ class CaucaseClient(object):
ssl_context = ssl.create_default_context( ssl_context = ssl.create_default_context(
# unicode object needed as we use PEM, otherwise create_default_context # unicode object needed as we use PEM, otherwise create_default_context
# expects DER. # expects DER.
cadata=''.join(http_ca_crt_pem_list).decode('ascii') if http_ca_crt_pem_list else None, cadata=(
utils.toUnicode(''.join(http_ca_crt_pem_list))
if http_ca_crt_pem_list
else None
),
) )
if not http_ca_crt_pem_list: if not http_ca_crt_pem_list:
ssl_context.check_hostname = False ssl_context.check_hostname = False
...@@ -191,13 +195,7 @@ class CaucaseClient(object): ...@@ -191,13 +195,7 @@ class CaucaseClient(object):
""" """
[AUTHENTICATED] Retrieve all pending CSRs. [AUTHENTICATED] Retrieve all pending CSRs.
""" """
return [ return json.loads(self._https('GET', '/csr'))
{
y.encode('ascii'): z.encode('ascii') if isinstance(z, unicode) else z
for y, z in x.iteritems()
}
for x in json.loads(self._https('GET', '/csr'))
]
def createCertificateSigningRequest(self, csr): def createCertificateSigningRequest(self, csr):
""" """
...@@ -254,14 +252,14 @@ class CaucaseClient(object): ...@@ -254,14 +252,14 @@ class CaucaseClient(object):
continue continue
if not found: if not found:
found = utils.load_ca_certificate( found = utils.load_ca_certificate(
payload['old_pem'].encode('ascii'), utils.toBytes(payload['old_pem']),
) == trust_anchor ) == trust_anchor
if found: if found:
if utils.load_ca_certificate( if utils.load_ca_certificate(
payload['old_pem'].encode('ascii'), utils.toBytes(payload['old_pem']),
) != previous_ca: ) != previous_ca:
raise ValueError('CA signature chain broken') raise ValueError('CA signature chain broken')
new_pem = payload['new_pem'].encode('ascii') new_pem = utils.toBytes(payload['new_pem'])
result.append(new_pem) result.append(new_pem)
previous_ca = utils.load_ca_certificate(new_pem) previous_ca = utils.load_ca_certificate(new_pem)
return result return result
...@@ -279,8 +277,8 @@ class CaucaseClient(object): ...@@ -279,8 +277,8 @@ class CaucaseClient(object):
json.dumps( json.dumps(
utils.wrap( utils.wrap(
{ {
'crt_pem': utils.dump_certificate(old_crt), 'crt_pem': utils.toUnicode(utils.dump_certificate(old_crt)),
'renew_csr_pem': utils.dump_certificate_request( 'renew_csr_pem': utils.toUnicode(utils.dump_certificate_request(
x509.CertificateSigningRequestBuilder( x509.CertificateSigningRequestBuilder(
).subject_name( ).subject_name(
# Note: caucase server ignores this, but cryptography # Note: caucase server ignores this, but cryptography
...@@ -291,7 +289,7 @@ class CaucaseClient(object): ...@@ -291,7 +289,7 @@ class CaucaseClient(object):
algorithm=utils.DEFAULT_DIGEST_CLASS(), algorithm=utils.DEFAULT_DIGEST_CLASS(),
backend=_cryptography_backend, backend=_cryptography_backend,
), ),
), )),
}, },
old_key, old_key,
utils.DEFAULT_DIGEST, utils.DEFAULT_DIGEST,
...@@ -307,6 +305,7 @@ class CaucaseClient(object): ...@@ -307,6 +305,7 @@ class CaucaseClient(object):
[ANONYMOUS] if key is provided. [ANONYMOUS] if key is provided.
[AUTHENTICATED] if key is missing. [AUTHENTICATED] if key is missing.
""" """
crt = utils.toUnicode(crt)
if key: if key:
method = self._http method = self._http
data = utils.wrap( data = utils.wrap(
......
...@@ -70,7 +70,7 @@ def _createKey(path): ...@@ -70,7 +70,7 @@ def _createKey(path):
""" """
return os.fdopen( return os.fdopen(
os.open(path, os.O_WRONLY | os.O_CREAT, 0o600), os.open(path, os.O_WRONLY | os.O_CREAT, 0o600),
'w', 'wb',
) )
class ThreadingWSGIServer(ThreadingMixIn, WSGIServer): class ThreadingWSGIServer(ThreadingMixIn, WSGIServer):
...@@ -236,7 +236,7 @@ def getSSLContext( ...@@ -236,7 +236,7 @@ def getSSLContext(
# implementation cross-check would have been nice. # implementation cross-check would have been nice.
#ssl_context.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF #ssl_context.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF
ssl_context.load_verify_locations( ssl_context.load_verify_locations(
cadata=cau.getCACertificate().decode('ascii'), cadata=utils.toUnicode(cau.getCACertificate()),
) )
http_cas_certificate_list = http_cas.getCACertificateList() http_cas_certificate_list = http_cas.getCACertificateList()
threshold_delta = datetime.timedelta(threshold, 0) threshold_delta = datetime.timedelta(threshold, 0)
...@@ -500,12 +500,12 @@ def main(argv=None, until=utils.until): ...@@ -500,12 +500,12 @@ def main(argv=None, until=utils.until):
) )
args = parser.parse_args(argv) args = parser.parse_args(argv)
base_url = u'http://' + args.netloc.decode('ascii') base_url = u'http://' + utils.toUnicode(args.netloc)
parsed_base_url = urlparse(base_url) parsed_base_url = urlparse(base_url)
hostname = parsed_base_url.hostname hostname = parsed_base_url.hostname
name_constraints_permited = [] name_constraints_permited = []
name_constraints_excluded = [] name_constraints_excluded = []
hostname_dnsname = hostname.decode('ascii') hostname_dnsname = utils.toUnicode(hostname)
try: try:
hostname_ip_address = ipaddress.ip_address(hostname_dnsname) hostname_ip_address = ipaddress.ip_address(hostname_dnsname)
except ValueError: except ValueError:
...@@ -615,7 +615,7 @@ def main(argv=None, until=utils.until): ...@@ -615,7 +615,7 @@ def main(argv=None, until=utils.until):
crt_life_time=args.service_crt_validity, crt_life_time=args.service_crt_validity,
) )
if os.path.exists(args.cors_key_store): if os.path.exists(args.cors_key_store):
with open(args.cors_key_store) as cors_key_file: with open(args.cors_key_store, 'rb') as cors_key_file:
cors_secret_list = json.load(cors_key_file) cors_secret_list = json.load(cors_key_file)
else: else:
cors_secret_list = [] cors_secret_list = []
...@@ -761,7 +761,7 @@ def main(argv=None, until=utils.until): ...@@ -761,7 +761,7 @@ def main(argv=None, until=utils.until):
tmp_backup_fd, tmp_backup_path = tempfile.mkstemp( tmp_backup_fd, tmp_backup_path = tempfile.mkstemp(
prefix='caucase_backup_', prefix='caucase_backup_',
) )
with os.fdopen(tmp_backup_fd, 'w') as backup_file: with os.fdopen(tmp_backup_fd, 'wb') as backup_file:
result = cau.doBackup(backup_file.write) result = cau.doBackup(backup_file.write)
if result: if result:
backup_path = os.path.join( backup_path = os.path.join(
...@@ -782,6 +782,7 @@ def main(argv=None, until=utils.until): ...@@ -782,6 +782,7 @@ def main(argv=None, until=utils.until):
finally: finally:
sys.stderr.write('Exiting\n') sys.stderr.write('Exiting\n')
for server in itertools.chain(http_list, https_list): for server in itertools.chain(http_list, https_list):
server.server_close()
server.shutdown() server.shutdown()
def manage(argv=None): def manage(argv=None):
...@@ -820,7 +821,7 @@ def manage(argv=None): ...@@ -820,7 +821,7 @@ def manage(argv=None):
default=[], default=[],
metavar='PEM_FILE', metavar='PEM_FILE',
action='append', action='append',
type=argparse.FileType('r'), type=argparse.FileType('rb'),
help='Import key pairs as initial service CA certificate. ' help='Import key pairs as initial service CA certificate. '
'May be provided multiple times to import multiple key pairs. ' 'May be provided multiple times to import multiple key pairs. '
'Keys and certificates may be in separate files. ' 'Keys and certificates may be in separate files. '
...@@ -846,7 +847,7 @@ def manage(argv=None): ...@@ -846,7 +847,7 @@ def manage(argv=None):
default=[], default=[],
metavar='PEM_FILE', metavar='PEM_FILE',
action='append', action='append',
type=argparse.FileType('r'), type=argparse.FileType('rb'),
help='Import service revocation list. Corresponding CA certificate must ' help='Import service revocation list. Corresponding CA certificate must '
'be already present in the database (including added in the same run ' 'be already present in the database (including added in the same run '
'using --import-ca).', 'using --import-ca).',
...@@ -854,7 +855,7 @@ def manage(argv=None): ...@@ -854,7 +855,7 @@ def manage(argv=None):
parser.add_argument( parser.add_argument(
'--export-ca', '--export-ca',
metavar='PEM_FILE', metavar='PEM_FILE',
type=argparse.FileType('w'), type=argparse.FileType('wb'),
help='Export all CA certificates in a PEM file. Passphrase will be ' help='Export all CA certificates in a PEM file. Passphrase will be '
'prompted to protect all keys.', 'prompted to protect all keys.',
) )
...@@ -873,8 +874,13 @@ def manage(argv=None): ...@@ -873,8 +874,13 @@ def manage(argv=None):
# maybe user extracted their private key ? # maybe user extracted their private key ?
key_pem = utils.getKey(backup_key_path) key_pem = utils.getKey(backup_key_path)
cau_crt_life_time = args.user_crt_validity cau_crt_life_time = args.user_crt_validity
with open(backup_path) as backup_file: with open(
with open(backup_crt_path, 'a') as new_crt_file: backup_path,
'rb',
) as backup_file, open(
backup_crt_path,
'ab',
) as new_crt_file:
new_crt_file.write( new_crt_file.write(
UserCertificateAuthority.restoreBackup( UserCertificateAuthority.restoreBackup(
db_class=SQLite3Storage, db_class=SQLite3Storage,
......
...@@ -22,6 +22,7 @@ Separate from .http because of different-licensed code in the middle. ...@@ -22,6 +22,7 @@ Separate from .http because of different-licensed code in the middle.
""" """
from __future__ import absolute_import from __future__ import absolute_import
from wsgiref.simple_server import ServerHandler from wsgiref.simple_server import ServerHandler
from .utils import toBytes
class ProxyFile(object): class ProxyFile(object):
""" """
...@@ -48,7 +49,7 @@ class ChunkedFile(ProxyFile): ...@@ -48,7 +49,7 @@ class ChunkedFile(ProxyFile):
""" """
Read chunked data. Read chunked data.
""" """
result = '' result = b''
if not self._at_eof: if not self._at_eof:
readline = self.readline readline = self.readline
read = self.__getattr__('read') read = self.__getattr__('read')
...@@ -61,7 +62,7 @@ class ChunkedFile(ProxyFile): ...@@ -61,7 +62,7 @@ class ChunkedFile(ProxyFile):
if len(chunk_header) > MAX_CHUNKED_HEADER_LENGTH: if len(chunk_header) > MAX_CHUNKED_HEADER_LENGTH:
raise ValueError('Chunked encoding header too long') raise ValueError('Chunked encoding header too long')
try: try:
chunk_length = int(chunk_header.split(';', 1)[0], 16) chunk_length = int(chunk_header.split(b';', 1)[0], 16)
except ValueError: except ValueError:
raise ValueError('Invalid chunked encoding header') raise ValueError('Invalid chunked encoding header')
if not chunk_length: if not chunk_length:
...@@ -78,7 +79,7 @@ class ChunkedFile(ProxyFile): ...@@ -78,7 +79,7 @@ class ChunkedFile(ProxyFile):
if to_read != chunk_length: if to_read != chunk_length:
self._chunk_remaining_length = chunk_length - to_read self._chunk_remaining_length = chunk_length - to_read
break break
if read(2) != '\r\n': if read(2) != b'\r\n':
raise ValueError('Invalid chunked encoding separator') raise ValueError('Invalid chunked encoding separator')
return result return result
...@@ -131,7 +132,7 @@ class CleanServerHandler(ServerHandler): ...@@ -131,7 +132,7 @@ class CleanServerHandler(ServerHandler):
""" """
Emit "100 Continue" intermediate response. Emit "100 Continue" intermediate response.
""" """
self._write('HTTP/%s 100 Continue\r\n\r\n' % ( self._write(b'HTTP/%s 100 Continue\r\n\r\n' % (
self.http_version, toBytes(self.http_version),
)) ))
self._flush() self._flush()
...@@ -25,6 +25,7 @@ import sqlite3 ...@@ -25,6 +25,7 @@ import sqlite3
from threading import local from threading import local
from time import time from time import time
from .exceptions import NoStorage, NotFound, Found from .exceptions import NoStorage, NotFound, Found
from .utils import toBytes, toUnicode
__all__ = ('SQLite3Storage', ) __all__ = ('SQLite3Storage', )
...@@ -207,8 +208,8 @@ class SQLite3Storage(local): ...@@ -207,8 +208,8 @@ class SQLite3Storage(local):
) )
return [ return [
{ {
'crt_pem': x['crt'].encode('ascii'), 'crt_pem': toBytes(x['crt']),
'key_pem': x['key'].encode('ascii'), 'key_pem': toBytes(x['key']),
} }
for x in db.cursor().execute( for x in db.cursor().execute(
'SELECT key, crt FROM %sca ORDER BY expiration_date ASC' % ( 'SELECT key, crt FROM %sca ORDER BY expiration_date ASC' % (
...@@ -326,7 +327,7 @@ class SQLite3Storage(local): ...@@ -326,7 +327,7 @@ class SQLite3Storage(local):
) )
if result is None: if result is None:
raise NotFound raise NotFound
return result['csr'].encode('ascii') return toBytes(result['csr'])
def getCertificateSigningRequestList(self): def getCertificateSigningRequestList(self):
""" """
...@@ -338,7 +339,11 @@ class SQLite3Storage(local): ...@@ -338,7 +339,11 @@ class SQLite3Storage(local):
return [ return [
{ {
'id': str(x['id']), 'id': str(x['id']),
'csr': x['csr'].encode('ascii'), # XXX: because only call chain will end up serialising this value in
# json, and for some reason python3 json module refuses bytes.
# So rather than byte-ify (consistently with all PEM-encoded values)
# to then have to unicode-ify, just unicode-ify here.
'csr': toUnicode(x['csr']),
} }
for x in db.cursor().execute( for x in db.cursor().execute(
'SELECT id, csr FROM %scrt WHERE crt IS NULL' % ( 'SELECT id, csr FROM %scrt WHERE crt IS NULL' % (
...@@ -401,7 +406,7 @@ class SQLite3Storage(local): ...@@ -401,7 +406,7 @@ class SQLite3Storage(local):
crt_id, crt_id,
) )
) )
return row['crt'].encode('ascii') return toBytes(row['crt'])
def getCertificateByKeyIdentifier(self, key_id): def getCertificateByKeyIdentifier(self, key_id):
""" """
...@@ -419,7 +424,7 @@ class SQLite3Storage(local): ...@@ -419,7 +424,7 @@ class SQLite3Storage(local):
) )
if row is None: if row is None:
raise NotFound raise NotFound
return row['crt'].encode('ascii') return toBytes(row['crt'])
def iterCertificates(self): def iterCertificates(self):
""" """
...@@ -434,7 +439,7 @@ class SQLite3Storage(local): ...@@ -434,7 +439,7 @@ class SQLite3Storage(local):
row = c.fetchone() row = c.fetchone()
if row is None: if row is None:
break break
yield row['crt'].encode('ascii') yield toBytes(row['crt'])
def revoke(self, serial, expiration_date): def revoke(self, serial, expiration_date):
""" """
...@@ -483,7 +488,7 @@ class SQLite3Storage(local): ...@@ -483,7 +488,7 @@ class SQLite3Storage(local):
(time(), ) (time(), )
) )
if row is not None: if row is not None:
return row['crl'].encode('ascii') return toBytes(row['crl'])
return None return None
def getNextCertificateRevocationListNumber(self): def getNextCertificateRevocationListNumber(self):
...@@ -547,7 +552,7 @@ class SQLite3Storage(local): ...@@ -547,7 +552,7 @@ class SQLite3Storage(local):
class (so not limited to table_prefix). class (so not limited to table_prefix).
""" """
for statement in self._db.iterdump(): for statement in self._db.iterdump():
yield statement.encode('utf-8') + '\0' yield toBytes(statement, 'utf-8') + b'\0'
@staticmethod @staticmethod
def restore(db_path, restorator): def restore(db_path, restorator):
...@@ -563,14 +568,14 @@ class SQLite3Storage(local): ...@@ -563,14 +568,14 @@ class SQLite3Storage(local):
Produces chunks which correspond (in content, not necessarily in size) Produces chunks which correspond (in content, not necessarily in size)
to what dumpIterator produces. to what dumpIterator produces.
""" """
buf = '' buf = b''
if os.path.exists(db_path): if os.path.exists(db_path):
raise ValueError('%r exists, not restoring.' % (db_path, )) raise ValueError('%r exists, not restoring.' % (db_path, ))
c = sqlite3.connect(db_path, isolation_level=None).cursor() c = sqlite3.connect(db_path, isolation_level=None).cursor()
for chunk in restorator: for chunk in restorator:
statement_list = (buf + chunk).split('\0') statement_list = (buf + chunk).split(b'\0')
buf = statement_list.pop() buf = statement_list.pop()
for statement in statement_list: for statement in statement_list:
c.execute((statement).decode('utf-8')) c.execute(toUnicode(statement, 'utf-8'))
if buf: if buf:
raise ValueError('Short read, backup truncated ?') raise ValueError('Short read, backup truncated ?')
...@@ -22,12 +22,12 @@ Test suite ...@@ -22,12 +22,12 @@ Test suite
""" """
from __future__ import absolute_import from __future__ import absolute_import
from Cookie import SimpleCookie from Cookie import SimpleCookie
from cStringIO import StringIO
import datetime import datetime
import errno import errno
import glob import glob
import HTMLParser import HTMLParser
import httplib import httplib
from io import BytesIO, StringIO
import ipaddress import ipaddress
import json import json
import os import os
...@@ -48,7 +48,9 @@ from cryptography import x509 ...@@ -48,7 +48,9 @@ from cryptography import x509
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from caucase import cli from caucase import cli
from caucase.client import CaucaseError, CaucaseClient from caucase.client import CaucaseError, CaucaseClient
from caucase import http # Do not import caucase.http into this namespace: 2to3 will import standard
# http module, which will then be masqued by caucase's http submodule.
import caucase.http
from caucase import utils from caucase import utils
from caucase import exceptions from caucase import exceptions
from caucase import wsgi from caucase import wsgi
...@@ -106,11 +108,13 @@ def canConnect(address): # pragma: no cover ...@@ -106,11 +108,13 @@ def canConnect(address): # pragma: no cover
otherwise. otherwise.
""" """
try: try:
socket.create_connection(address) sock = socket.create_connection(address)
except socket.error as e: except socket.error as e:
if e.errno == errno.ECONNREFUSED: if e.errno == errno.ECONNREFUSED:
return False return False
raise raise
else:
sock.close()
return True return True
def retry(callback, try_count=200, try_delay=0.1): # pragma: no cover def retry(callback, try_count=200, try_delay=0.1): # pragma: no cover
...@@ -129,7 +133,7 @@ def retry(callback, try_count=200, try_delay=0.1): # pragma: no cover ...@@ -129,7 +133,7 @@ def retry(callback, try_count=200, try_delay=0.1): # pragma: no cover
class FakeStreamRequest(object): class FakeStreamRequest(object):
""" """
For testing StreamRequestHandler subclasses For testing StreamRequestHandler subclasses
(like http.CaucaseWSGIRequestHandler). (like caucase.http.CaucaseWSGIRequestHandler).
""" """
def __init__(self, rfile, wfile): def __init__(self, rfile, wfile):
""" """
...@@ -144,6 +148,9 @@ class FakeStreamRequest(object): ...@@ -144,6 +148,9 @@ class FakeStreamRequest(object):
""" """
return self._rfile if 'r' in mode else self._wfile return self._rfile if 'r' in mode else self._wfile
def sendall(self, data, flags=None): # pragma: no cover
self._wfile.write(data)
class NoCloseFileProxy(object): class NoCloseFileProxy(object):
""" """
Intercept .close() calls, for example to allow reading StringIO content Intercept .close() calls, for example to allow reading StringIO content
...@@ -324,7 +331,7 @@ class CaucaseTest(unittest.TestCase): ...@@ -324,7 +331,7 @@ class CaucaseTest(unittest.TestCase):
Returns its exit status. Returns its exit status.
""" """
try: try:
http.manage( caucase.http.manage(
argv=( argv=(
'--db', self._server_db, '--db', self._server_db,
'--restore-backup', '--restore-backup',
...@@ -346,7 +353,7 @@ class CaucaseTest(unittest.TestCase): ...@@ -346,7 +353,7 @@ class CaucaseTest(unittest.TestCase):
""" """
self._server_until = until = UntilEvent(self._server_event) self._server_until = until = UntilEvent(self._server_event)
self._server = server = threading.Thread( self._server = server = threading.Thread(
target=http.main, target=caucase.http.main,
kwargs={ kwargs={
'argv': ( 'argv': (
'--db', self._server_db, '--db', self._server_db,
...@@ -453,10 +460,10 @@ class CaucaseTest(unittest.TestCase): ...@@ -453,10 +460,10 @@ class CaucaseTest(unittest.TestCase):
row = c.fetchone() row = c.fetchone()
if row is None: # pragma: no cover if row is None: # pragma: no cover
raise Exception('CA with serial %r not found' % (serial, )) raise Exception('CA with serial %r not found' % (serial, ))
crt = utils.load_ca_certificate(row['crt'].encode('ascii')) crt = utils.load_ca_certificate(utils.toBytes(row['crt']))
if crt.serial_number == serial: if crt.serial_number == serial:
new_crt = self._setCertificateRemainingLifeTime( new_crt = self._setCertificateRemainingLifeTime(
key=utils.load_privatekey(row['key'].encode('ascii')), key=utils.load_privatekey(utils.toBytes(row['key'])),
crt=crt, crt=crt,
delta=delta, delta=delta,
) )
...@@ -489,7 +496,7 @@ class CaucaseTest(unittest.TestCase): ...@@ -489,7 +496,7 @@ class CaucaseTest(unittest.TestCase):
""" """
name = basename + '.key.pem' name = basename + '.key.pem'
assert not os.path.exists(name) assert not os.path.exists(name)
with open(name, 'w') as key_file: with open(name, 'wb') as key_file:
key_file.write(utils.dump_privatekey( key_file.write(utils.dump_privatekey(
utils.generatePrivateKey(key_len=key_len), utils.generatePrivateKey(key_len=key_len),
)) ))
...@@ -516,7 +523,7 @@ class CaucaseTest(unittest.TestCase): ...@@ -516,7 +523,7 @@ class CaucaseTest(unittest.TestCase):
""" """
name = basename + '.csr.pem' name = basename + '.csr.pem'
assert not os.path.exists(name) assert not os.path.exists(name)
with open(name, 'w') as csr_file: with open(name, 'wb') as csr_file:
csr_file.write( csr_file.write(
utils.dump_certificate_request( utils.dump_certificate_request(
csr_builder.sign( csr_builder.sign(
...@@ -604,7 +611,8 @@ class CaucaseTest(unittest.TestCase): ...@@ -604,7 +611,8 @@ class CaucaseTest(unittest.TestCase):
'--mode', mode, '--mode', mode,
'--get-csr', csr_id, csr2_path, '--get-csr', csr_id, csr2_path,
) )
self.assertEqual(open(csr_path).read(), open(csr2_path).read()) with open(csr_path, 'rb') as csr_file, open(csr2_path, 'rb') as csr2_file:
self.assertEqual(csr_file.read(), csr2_file.read())
# Sign using user cert # Sign using user cert
# Note: assuming user does not know the csr_id and keeps their own copy of # Note: assuming user does not know the csr_id and keeps their own copy of
# issued certificates. # issued certificates.
...@@ -1143,12 +1151,14 @@ class CaucaseTest(unittest.TestCase): ...@@ -1143,12 +1151,14 @@ class CaucaseTest(unittest.TestCase):
# Check renewed CRT filtering does not alter clean signed certificate # Check renewed CRT filtering does not alter clean signed certificate
# content (especially, caucase auto-signed flag must not appear). # content (especially, caucase auto-signed flag must not appear).
before_key = open(key_path).read() with open(key_path, 'rb') as key_file:
before_key = key_file.read()
self._runClient( self._runClient(
'--threshold', '100', '--threshold', '100',
'--renew-crt', key_path, '', '--renew-crt', key_path, '',
) )
after_key = open(key_path).read() with open(key_path, 'rb') as key_file:
after_key = key_file.read()
assert before_key != after_key assert before_key != after_key
checkCRT(key_path) checkCRT(key_path)
...@@ -1215,7 +1225,7 @@ class CaucaseTest(unittest.TestCase): ...@@ -1215,7 +1225,7 @@ class CaucaseTest(unittest.TestCase):
) )
# As we will use this crt as trust anchor, we must make the client believe # As we will use this crt as trust anchor, we must make the client believe
# it knew it all along. # it knew it all along.
with open(self._client_user_ca_crt, 'w') as client_user_ca_crt_file: with open(self._client_user_ca_crt, 'wb') as client_user_ca_crt_file:
client_user_ca_crt_file.write(new_cau_crt_pem) client_user_ca_crt_file.write(new_cau_crt_pem)
self._startServer() self._startServer()
new_user_key = self._createAndApproveCertificate( new_user_key = self._createAndApproveCertificate(
...@@ -1302,11 +1312,11 @@ class CaucaseTest(unittest.TestCase): ...@@ -1302,11 +1312,11 @@ class CaucaseTest(unittest.TestCase):
self._server_key, self._server_key,
crl=None, crl=None,
) )
with open(self._server_key, 'w') as server_key_file: with open(self._server_key, 'wb') as server_key_file:
server_key_file.write(key_pem) server_key_file.write(key_pem)
server_key_file.write(utils.dump_certificate( server_key_file.write(utils.dump_certificate(
self._setCertificateRemainingLifeTime( self._setCertificateRemainingLifeTime(
key=utils.load_privatekey(http_cas_key.encode('ascii')), key=utils.load_privatekey(utils.toBytes(http_cas_key)),
crt=utils.load_certificate( crt=utils.load_certificate(
crt_pem, crt_pem,
[ [
...@@ -1318,10 +1328,13 @@ class CaucaseTest(unittest.TestCase): ...@@ -1318,10 +1328,13 @@ class CaucaseTest(unittest.TestCase):
) )
)) ))
server_key_file.write(ca_crt_pem) server_key_file.write(ca_crt_pem)
reference_server_key = open(self._server_key).read() def readServerKey():
with open(self._server_key, 'rb') as server_key_file:
return server_key_file.read()
reference_server_key = readServerKey()
self._startServer() self._startServer()
if not retry( if not retry(
lambda: open(self._server_key).read() != reference_server_key, lambda: readServerKey() != reference_server_key,
): # pragma: no cover ): # pragma: no cover
raise AssertionError('Server did not renew its key pair within 1 second') raise AssertionError('Server did not renew its key pair within 1 second')
# But user still trusts the server # But user still trusts the server
...@@ -1363,7 +1376,8 @@ class CaucaseTest(unittest.TestCase): ...@@ -1363,7 +1376,8 @@ class CaucaseTest(unittest.TestCase):
utils.load_ca_certificate(x) utils.load_ca_certificate(x)
for x in utils.getCertList(self._client_user_ca_crt) for x in utils.getCertList(self._client_user_ca_crt)
] ]
cau_crl = open(self._client_user_crl).read() with open(self._client_user_crl, 'rb') as client_user_crl_file:
cau_crl = client_user_crl_file.read()
class DummyCAU(object): class DummyCAU(object):
""" """
Mock CAU. Mock CAU.
...@@ -1382,7 +1396,7 @@ class CaucaseTest(unittest.TestCase): ...@@ -1382,7 +1396,7 @@ class CaucaseTest(unittest.TestCase):
""" """
Return a dummy string as CA certificate Return a dummy string as CA certificate
""" """
return 'notreallyPEM' return b'notreallyPEM'
@staticmethod @staticmethod
def getCertificateRevocationList(): def getCertificateRevocationList():
...@@ -1441,7 +1455,7 @@ class CaucaseTest(unittest.TestCase): ...@@ -1441,7 +1455,7 @@ class CaucaseTest(unittest.TestCase):
if key in header_dict: # pragma: no cover if key in header_dict: # pragma: no cover
value = header_dict[key] + ',' + value value = header_dict[key] + ',' + value
header_dict[key] = value header_dict[key] = value
return int(status), reason, header_dict, ''.join(body) return int(status), reason, header_dict, b''.join(body)
UNAUTHORISED_STATUS = 401 UNAUTHORISED_STATUS = 401
HATEOAS_HTTP_PREFIX = u"http://caucase.example.com:8000/base/path" HATEOAS_HTTP_PREFIX = u"http://caucase.example.com:8000/base/path"
...@@ -1841,7 +1855,7 @@ class CaucaseTest(unittest.TestCase): ...@@ -1841,7 +1855,7 @@ class CaucaseTest(unittest.TestCase):
header_dict['Content-Security-Policy'], header_dict['Content-Security-Policy'],
"frame-ancestors 'none'", "frame-ancestors 'none'",
) )
assertHTMLNoScriptAlert(body) assertHTMLNoScriptAlert(utils.toUnicode(body))
# POST /cors sets cookie # POST /cors sets cookie
def getCORSPostEnvironment(kw=(), input_dict=( def getCORSPostEnvironment(kw=(), input_dict=(
('return_to', return_url), ('return_to', return_url),
...@@ -2042,9 +2056,9 @@ class CaucaseTest(unittest.TestCase): ...@@ -2042,9 +2056,9 @@ class CaucaseTest(unittest.TestCase):
table_prefix='cau', table_prefix='cau',
).dumpIterator()) ).dumpIterator())
CRL_INSERT = 'INSERT INTO "caucrl" ' CRL_INSERT = b'INSERT INTO "caucrl" '
CRT_INSERT = 'INSERT INTO "caucrt" ' CRT_INSERT = b'INSERT INTO "caucrt" '
REV_INSERT = 'INSERT INTO "caurevoked" ' REV_INSERT = b'INSERT INTO "caurevoked" '
def filterBackup(backup, expect_rev): def filterBackup(backup, expect_rev):
""" """
Remove all lines which are know to differ between original batabase and Remove all lines which are know to differ between original batabase and
...@@ -2145,7 +2159,7 @@ class CaucaseTest(unittest.TestCase): ...@@ -2145,7 +2159,7 @@ class CaucaseTest(unittest.TestCase):
user2_newnew_key_path, user2_newnew_key_path,
) )
user2_new_bare_key_path = user2_new_key_path + '.bare_key' user2_new_bare_key_path = user2_new_key_path + '.bare_key'
with open(user2_new_bare_key_path, 'w') as bare_key_file: with open(user2_new_bare_key_path, 'wb') as bare_key_file:
bare_key_file.write(utils.getKeyPair(user2_new_key_path)[1]) bare_key_file.write(utils.getKeyPair(user2_new_key_path)[1])
self.assertEqual( self.assertEqual(
self._restoreServer( self._restoreServer(
...@@ -2174,13 +2188,13 @@ class CaucaseTest(unittest.TestCase): ...@@ -2174,13 +2188,13 @@ class CaucaseTest(unittest.TestCase):
'--revoke-crt', service_key, service_key, '--revoke-crt', service_key, service_key,
) )
self._runClient() self._runClient()
getBytePass_orig = http.getBytePass getBytePass_orig = caucase.http.getBytePass
orig_stdout = sys.stdout orig_stdout = sys.stdout
try: try:
http.getBytePass = lambda x: 'test' caucase.http.getBytePass = lambda x: b'test'
sys.stdout = stdout = StringIO() sys.stdout = stdout = StringIO()
self.assertFalse(os.path.exists(exported_ca), exported_ca) self.assertFalse(os.path.exists(exported_ca), exported_ca)
http.manage( caucase.http.manage(
argv=( argv=(
'--db', self._server_db, '--db', self._server_db,
'--export-ca', exported_ca, '--export-ca', exported_ca,
...@@ -2189,7 +2203,7 @@ class CaucaseTest(unittest.TestCase): ...@@ -2189,7 +2203,7 @@ class CaucaseTest(unittest.TestCase):
self.assertTrue(os.path.exists(exported_ca), exported_ca) self.assertTrue(os.path.exists(exported_ca), exported_ca)
server_db2 = self._server_db + '2' server_db2 = self._server_db + '2'
self.assertFalse(os.path.exists(server_db2), server_db2) self.assertFalse(os.path.exists(server_db2), server_db2)
http.manage( caucase.http.manage(
argv=( argv=(
'--db', server_db2, '--db', server_db2,
'--import-ca', exported_ca, '--import-ca', exported_ca,
...@@ -2208,7 +2222,7 @@ class CaucaseTest(unittest.TestCase): ...@@ -2208,7 +2222,7 @@ class CaucaseTest(unittest.TestCase):
) )
finally: finally:
sys.stdout = orig_stdout sys.stdout = orig_stdout
http.getBytePass = getBytePass_orig caucase.http.getBytePass = getBytePass_orig
def testWSGIBase(self): def testWSGIBase(self):
""" """
...@@ -2220,10 +2234,10 @@ class CaucaseTest(unittest.TestCase): ...@@ -2220,10 +2234,10 @@ class CaucaseTest(unittest.TestCase):
""" """
Trigger execution of app, with given request. Trigger execution of app, with given request.
""" """
wfile = StringIO() wfile = BytesIO()
http.CaucaseWSGIRequestHandler( caucase.http.CaucaseWSGIRequestHandler(
FakeStreamRequest( FakeStreamRequest(
StringIO('\r\n'.join(request_line_list + [''])), BytesIO(b'\r\n'.join(request_line_list + [b''])),
NoCloseFileProxy(wfile), NoCloseFileProxy(wfile),
), ),
('0.0.0.0', 0), ('0.0.0.0', 0),
...@@ -2235,26 +2249,28 @@ class CaucaseTest(unittest.TestCase): ...@@ -2235,26 +2249,28 @@ class CaucaseTest(unittest.TestCase):
""" """
Naive extraction of http status out of an http response. Naive extraction of http status out of an http response.
""" """
_, code, _ = response_line_list[0].split(' ', 2) _, code, _ = response_line_list[0].split(b' ', 2)
return int(code) return int(code)
def getBody(response_line_list): def getBody(response_line_list):
""" """
Naive extraction of http response body. Naive extraction of http response body.
""" """
return '\r\n'.join(response_line_list[response_line_list.index('') + 1:]) return b'\r\n'.join(
response_line_list[response_line_list.index(b'') + 1:],
)
self.assertEqual( self.assertEqual(
getStatus(run(['GET /' + 'a' * 65537])), getStatus(run([b'GET /' + b'a' * 65537])),
414, 414,
) )
expect_continue_request = [ expect_continue_request = [
'PUT / HTTP/1.1', b'PUT / HTTP/1.1',
'Expect: 100-continue', b'Expect: 100-continue',
'Content-Length: 4', b'Content-Length: 4',
'Content-Type: text/plain', b'Content-Type: text/plain',
'', b'',
'Test', b'Test',
] ]
# No read: 200 OK # No read: 200 OK
self.assertEqual( self.assertEqual(
...@@ -2271,7 +2287,7 @@ class CaucaseTest(unittest.TestCase): ...@@ -2271,7 +2287,7 @@ class CaucaseTest(unittest.TestCase):
self.assertEqual( self.assertEqual(
getStatus(run( getStatus(run(
[ [
'PUT / HTTP/1.0', b'PUT / HTTP/1.0',
] + expect_continue_request[1:], ] + expect_continue_request[1:],
read_app, read_app,
)), )),
...@@ -2279,19 +2295,19 @@ class CaucaseTest(unittest.TestCase): ...@@ -2279,19 +2295,19 @@ class CaucaseTest(unittest.TestCase):
) )
chunked_request = [ chunked_request = [
'PUT / HTTP/1.1', b'PUT / HTTP/1.1',
'Transfer-Encoding: chunked', b'Transfer-Encoding: chunked',
'', b'',
'f;some=extension', b'f;some=extension',
'123456789abcd\r\n', b'123456789abcd\r\n',
'3', b'3',
'ef0', b'ef0',
'0', b'0',
'X-Chunked-Trailer: blah' b'X-Chunked-Trailer: blah'
] ]
self.assertEqual( self.assertEqual(
getBody(run(chunked_request, read_app)), getBody(run(chunked_request, read_app)),
'123456789abcd\r\nef0', b'123456789abcd\r\nef0',
) )
self.assertEqual( self.assertEqual(
getBody(run( getBody(run(
...@@ -2300,7 +2316,7 @@ class CaucaseTest(unittest.TestCase): ...@@ -2300,7 +2316,7 @@ class CaucaseTest(unittest.TestCase):
environ['wsgi.input'].read(), environ['wsgi.input'].read(),
environ['wsgi.input'].read(), environ['wsgi.input'].read(),
]))), ]))),
'123456789abcd\r\nef0', b'123456789abcd\r\nef0',
) )
self.assertEqual( self.assertEqual(
getBody(run( getBody(run(
...@@ -2309,7 +2325,7 @@ class CaucaseTest(unittest.TestCase): ...@@ -2309,7 +2325,7 @@ class CaucaseTest(unittest.TestCase):
environ['wsgi.input'].read(6), environ['wsgi.input'].read(6),
environ['wsgi.input'].read(), environ['wsgi.input'].read(),
]))), ]))),
'123456789abcd\r\nef0', b'123456789abcd\r\nef0',
) )
self.assertEqual( self.assertEqual(
getBody(run( getBody(run(
...@@ -2317,44 +2333,44 @@ class CaucaseTest(unittest.TestCase): ...@@ -2317,44 +2333,44 @@ class CaucaseTest(unittest.TestCase):
DummyApp(lambda environ: [ DummyApp(lambda environ: [
environ['wsgi.input'].read(32), environ['wsgi.input'].read(32),
]))), ]))),
'123456789abcd\r\nef0', b'123456789abcd\r\nef0',
) )
self.assertEqual( self.assertEqual(
getStatus(run([ getStatus(run([
'PUT / HTTP/1.1', b'PUT / HTTP/1.1',
'Transfer-Encoding: chunked', b'Transfer-Encoding: chunked',
'', b'',
'1', b'1',
'abc', # Chunk longer than advertised in header. b'abc', # Chunk longer than advertised in header.
], read_app)), ], read_app)),
500, 500,
) )
self.assertEqual( self.assertEqual(
getStatus(run([ getStatus(run([
'PUT / HTTP/1.1', b'PUT / HTTP/1.1',
'Transfer-Encoding: chunked', b'Transfer-Encoding: chunked',
'', b'',
'y', # Not a valid chunk header b'y', # Not a valid chunk header
], read_app)), ], read_app)),
500, 500,
) )
self.assertEqual( self.assertEqual(
getStatus(run([ getStatus(run([
'PUT / HTTP/1.1', b'PUT / HTTP/1.1',
'Transfer-Encoding: chunked', b'Transfer-Encoding: chunked',
'', b'',
'f;' + 'a' * 65537, # header too long b'f;' + b'a' * 65537, # header too long
], read_app)), ], read_app)),
500, 500,
) )
self.assertEqual( self.assertEqual(
getStatus(run([ getStatus(run([
'PUT / HTTP/1.1', b'PUT / HTTP/1.1',
'Transfer-Encoding: chunked', b'Transfer-Encoding: chunked',
'', b'',
'0', b'0',
'a' * 65537, # trailer too long b'a' * 65537, # trailer too long
], read_app)), ], read_app)),
500, 500,
) )
...@@ -2580,5 +2596,10 @@ class CaucaseTest(unittest.TestCase): ...@@ -2580,5 +2596,10 @@ class CaucaseTest(unittest.TestCase):
self.assertEqual(os.stat(self._server_db).st_mode & 0o777, 0o600) self.assertEqual(os.stat(self._server_db).st_mode & 0o777, 0o600)
self.assertEqual(os.stat(self._server_key).st_mode & 0o777, 0o600) self.assertEqual(os.stat(self._server_key).st_mode & 0o777, 0o600)
if getattr(CaucaseTest, 'assertItemsEqual', None) is None:
# Because python3 decided it should be named differently, and 2to3 cannot
# pick it up, and this code must remain python2-compatible... Yay !
CaucaseTest.assertItemsEqual = CaucaseTest.assertCountEqual
if __name__ == '__main__': # pragma: no cover if __name__ == '__main__': # pragma: no cover
unittest.main() unittest.main()
...@@ -21,6 +21,7 @@ Caucase - Certificate Authority for Users, Certificate Authority for SErvices ...@@ -21,6 +21,7 @@ Caucase - Certificate Authority for Users, Certificate Authority for SErvices
Small-ish functions needed in many places. Small-ish functions needed in many places.
""" """
from __future__ import absolute_import from __future__ import absolute_import
from binascii import a2b_base64, b2a_base64
from collections import defaultdict from collections import defaultdict
import datetime import datetime
import json import json
...@@ -273,19 +274,20 @@ def wrap(payload, key, digest): ...@@ -273,19 +274,20 @@ def wrap(payload, key, digest):
""" """
Sign payload (which gets json-serialised) with key, using given digest. Sign payload (which gets json-serialised) with key, using given digest.
""" """
payload = json.dumps(payload).encode('utf-8') payload = toBytes(json.dumps(payload), 'utf-8')
hash_class = getattr(hashes, digest.upper()) hash_class = getattr(hashes, digest.upper())
return { return {
'payload': payload, 'payload': toUnicode(payload),
'digest': digest, 'digest': digest,
'signature': key.sign( # For some reason, python3 thinks that a b2a method should return bytes.
payload + digest + ' ', 'signature': toUnicode(b2a_base64(key.sign(
payload + toBytes(digest) + b' ',
padding.PSS( padding.PSS(
mgf=padding.MGF1(hash_class()), mgf=padding.MGF1(hash_class()),
salt_length=padding.PSS.MAX_LENGTH, salt_length=padding.PSS.MAX_LENGTH,
), ),
hash_class(), hash_class(),
).encode('base64'), ))),
} }
def nullWrap(payload): def nullWrap(payload):
...@@ -308,10 +310,10 @@ def unwrap(wrapped, getCertificate, digest_list): ...@@ -308,10 +310,10 @@ def unwrap(wrapped, getCertificate, digest_list):
Note: does *not* verify received certificate itself (validity, issuer, ...). Note: does *not* verify received certificate itself (validity, issuer, ...).
""" """
# Check whether given digest is allowed # Check whether given digest is allowed
digest = wrapped['digest'].encode('ascii') digest = wrapped['digest']
if digest not in digest_list: if digest not in digest_list:
raise cryptography.exceptions.UnsupportedAlgorithm( raise cryptography.exceptions.UnsupportedAlgorithm(
'%r is not in allowed digest list', '%r is not in allowed digest list %r' % (digest, digest_list),
) )
hash_class = getattr(hashes, digest.upper()) hash_class = getattr(hashes, digest.upper())
try: try:
...@@ -319,11 +321,11 @@ def unwrap(wrapped, getCertificate, digest_list): ...@@ -319,11 +321,11 @@ def unwrap(wrapped, getCertificate, digest_list):
except ValueError: except ValueError:
raise NotJSON raise NotJSON
x509.load_pem_x509_certificate( x509.load_pem_x509_certificate(
getCertificate(payload).encode('ascii'), toBytes(getCertificate(payload)),
_cryptography_backend, _cryptography_backend,
).public_key().verify( ).public_key().verify(
wrapped['signature'].encode('ascii').decode('base64'), a2b_base64(toBytes(wrapped['signature'])),
wrapped['payload'].encode('utf-8') + digest + ' ', toBytes(wrapped['payload'], 'utf-8') + toBytes(digest) + b' ',
padding.PSS( padding.PSS(
mgf=padding.MGF1(hash_class()), mgf=padding.MGF1(hash_class()),
salt_length=padding.PSS.MAX_LENGTH, salt_length=padding.PSS.MAX_LENGTH,
...@@ -445,6 +447,18 @@ class SleepInterrupt(KeyboardInterrupt): ...@@ -445,6 +447,18 @@ class SleepInterrupt(KeyboardInterrupt):
""" """
pass pass
def toUnicode(value, encoding='ascii'):
"""
Convert value to unicode object, if it is not already.
"""
return value if isinstance(value, unicode) else value.decode(encoding)
def toBytes(value, encoding='ascii'):
"""
Convert valye to bytes object, if it is not already.
"""
return value if isinstance(value, bytes) else value.encode(encoding)
def interruptibleSleep(duration): # pragma: no cover def interruptibleSleep(duration): # pragma: no cover
""" """
Like sleep, but raises SleepInterrupt when interrupted by KeyboardInterrupt Like sleep, but raises SleepInterrupt when interrupted by KeyboardInterrupt
......
...@@ -19,11 +19,11 @@ ...@@ -19,11 +19,11 @@
Caucase - Certificate Authority for Users, Certificate Authority for SErvices Caucase - Certificate Authority for Users, Certificate Authority for SErvices
""" """
from __future__ import absolute_import from __future__ import absolute_import
from cgi import escape
from Cookie import SimpleCookie, CookieError from Cookie import SimpleCookie, CookieError
import httplib import httplib
import json import json
import os import os
import sys
import threading import threading
import time import time
import traceback import traceback
...@@ -34,10 +34,15 @@ import jwt ...@@ -34,10 +34,15 @@ import jwt
from . import utils from . import utils
from . import exceptions from . import exceptions
if sys.version_info >= (3, ): # pragma: no cover
from html import escape
else: # pragma: no cover
from cgi import escape
__all__ = ('Application', 'CORSTokenManager') __all__ = ('Application', 'CORSTokenManager')
# TODO: l10n # TODO: l10n
CORS_FORM_TEMPLATE = '''\ CORS_FORM_TEMPLATE = b'''\
<html> <html>
<head> <head>
<title>Caucase CORS access</title> <title>Caucase CORS access</title>
...@@ -213,11 +218,11 @@ class CORSTokenManager(object): ...@@ -213,11 +218,11 @@ class CORSTokenManager(object):
key = os.urandom(32) key = os.urandom(32)
secret_list.append((now + self._secret_validity_period, key)) secret_list.append((now + self._secret_validity_period, key))
self._onNewKey(secret_list) self._onNewKey(secret_list)
return jwt.encode( return utils.toUnicode(jwt.encode(
payload={'p': payload}, payload={'p': payload},
key=key, key=key,
algorithm='HS256', algorithm='HS256',
) ))
def verify(self, token, default=None): def verify(self, token, default=None):
""" """
...@@ -571,7 +576,7 @@ class Application(object): ...@@ -571,7 +576,7 @@ class Application(object):
except exceptions.NoStorage: except exceptions.NoStorage:
raise InsufficientStorage raise InsufficientStorage
except exceptions.NotJSON: except exceptions.NotJSON:
raise BadRequest('Invalid json payload') raise BadRequest(b'Invalid json payload')
except exceptions.CertificateAuthorityException as e: except exceptions.CertificateAuthorityException as e:
raise BadRequest(str(e)) raise BadRequest(str(e))
except Exception: except Exception:
...@@ -581,7 +586,7 @@ class Application(object): ...@@ -581,7 +586,7 @@ class Application(object):
except ApplicationError as e: except ApplicationError as e:
status = e.status status = e.status
header_list = e.response_headers header_list = e.response_headers
result = [str(x) for x in e.args] result = [utils.toBytes(str(x)) for x in e.args]
# Note: header_list and cors_header_list are expected to contain # Note: header_list and cors_header_list are expected to contain
# distinct header sets. This may not always stay true for "Vary". # distinct header sets. This may not always stay true for "Vary".
header_list.extend(cors_header_list) header_list.extend(cors_header_list)
...@@ -605,7 +610,7 @@ class Application(object): ...@@ -605,7 +610,7 @@ class Application(object):
try: try:
return int(crt_id) return int(crt_id)
except ValueError: except ValueError:
raise BadRequest('Invalid integer') raise BadRequest(b'Invalid integer')
@staticmethod @staticmethod
def _read(environ): def _read(environ):
...@@ -619,9 +624,9 @@ class Application(object): ...@@ -619,9 +624,9 @@ class Application(object):
try: try:
length = int(environ.get('CONTENT_LENGTH') or MAX_BODY_LENGTH) length = int(environ.get('CONTENT_LENGTH') or MAX_BODY_LENGTH)
except ValueError: except ValueError:
raise BadRequest('Invalid Content-Length') raise BadRequest(b'Invalid Content-Length')
if length > MAX_BODY_LENGTH: if length > MAX_BODY_LENGTH:
raise TooLarge('Content-Length limit exceeded') raise TooLarge(b'Content-Length limit exceeded')
return environ['wsgi.input'].read(length) return environ['wsgi.input'].read(length)
def _authenticate(self, environ, header_list): def _authenticate(self, environ, header_list):
...@@ -653,12 +658,12 @@ class Application(object): ...@@ -653,12 +658,12 @@ class Application(object):
json decoding fails. json decoding fails.
""" """
if environ.get('CONTENT_TYPE') != 'application/json': if environ.get('CONTENT_TYPE') != 'application/json':
raise BadRequest('Bad Content-Type') raise BadRequest(b'Bad Content-Type')
data = self._read(environ) data = self._read(environ)
try: try:
return json.loads(data) return json.loads(data)
except ValueError: except ValueError:
raise BadRequest('Invalid json') raise BadRequest(b'Invalid json')
def _createCORSCookie(self, environ, value): def _createCORSCookie(self, environ, value):
""" """
...@@ -859,7 +864,10 @@ class Application(object): ...@@ -859,7 +864,10 @@ class Application(object):
name = action['name'] name = action['name']
assert name not in hal_section_dict, name assert name not in hal_section_dict, name
hal_section_dict[name] = descriptor_dict hal_section_dict[name] = descriptor_dict
return self._returnFile(json.dumps(hal), 'application/hal+json') return self._returnFile(
utils.toBytes(json.dumps(hal)),
'application/hal+json',
)
def getCORSForm(self, context, environ): # pylint: disable=unused-argument def getCORSForm(self, context, environ): # pylint: disable=unused-argument
""" """
...@@ -881,9 +889,9 @@ class Application(object): ...@@ -881,9 +889,9 @@ class Application(object):
raise BadRequest raise BadRequest
return self._returnFile( return self._returnFile(
CORS_FORM_TEMPLATE % { CORS_FORM_TEMPLATE % {
'caucase': escape(self._http_url, quote=True), b'caucase': utils.toBytes(escape(self._http_url, quote=True)),
'return_to': escape(return_to, quote=True), b'return_to': utils.toBytes(escape(return_to, quote=True)),
'origin': escape(origin, quote=True), b'origin': utils.toBytes(escape(origin, quote=True)),
}, },
'text/html', 'text/html',
[ [
...@@ -902,7 +910,7 @@ class Application(object): ...@@ -902,7 +910,7 @@ class Application(object):
if environ['wsgi.url_scheme'] != 'https': if environ['wsgi.url_scheme'] != 'https':
raise NotFound raise NotFound
if environ.get('CONTENT_TYPE') != 'application/x-www-form-urlencoded': if environ.get('CONTENT_TYPE') != 'application/x-www-form-urlencoded':
raise BadRequest('Unhandled Content-Type') raise BadRequest(b'Unhandled Content-Type')
try: try:
form_dict = parse_qs(self._read(environ), strict_parsing=True) form_dict = parse_qs(self._read(environ), strict_parsing=True)
origin, = form_dict['origin'] origin, = form_dict['origin']
...@@ -961,7 +969,7 @@ class Application(object): ...@@ -961,7 +969,7 @@ class Application(object):
header_list = [] header_list = []
self._authenticate(environ, header_list) self._authenticate(environ, header_list)
return self._returnFile( return self._returnFile(
json.dumps(context.getCertificateRequestList()), utils.toBytes(json.dumps(context.getCertificateRequestList())),
'application/json', 'application/json',
header_list, header_list,
) )
...@@ -973,7 +981,7 @@ class Application(object): ...@@ -973,7 +981,7 @@ class Application(object):
try: try:
csr_id = context.appendCertificateSigningRequest(self._read(environ)) csr_id = context.appendCertificateSigningRequest(self._read(environ))
except exceptions.NotACertificateSigningRequest: except exceptions.NotACertificateSigningRequest:
raise BadRequest('Not a valid certificate signing request') raise BadRequest(b'Not a valid certificate signing request')
return (STATUS_CREATED, [('Location', str(csr_id))], []) return (STATUS_CREATED, [('Location', str(csr_id))], [])
def deletePendingCertificateRequest(self, context, environ, subpath): def deletePendingCertificateRequest(self, context, environ, subpath):
...@@ -1013,7 +1021,7 @@ class Application(object): ...@@ -1013,7 +1021,7 @@ class Application(object):
Handle GET /{context}/crt/ca.crt.json urls. Handle GET /{context}/crt/ca.crt.json urls.
""" """
return self._returnFile( return self._returnFile(
json.dumps(context.getValidCACertificateChain()), utils.toBytes(json.dumps(context.getValidCACertificateChain())),
'application/json', 'application/json',
) )
...@@ -1050,7 +1058,7 @@ class Application(object): ...@@ -1050,7 +1058,7 @@ class Application(object):
context.digest_list, context.digest_list,
) )
context.revoke( context.revoke(
crt_pem=payload['revoke_crt_pem'].encode('ascii'), crt_pem=utils.toBytes(payload['revoke_crt_pem']),
) )
return (STATUS_NO_CONTENT, header_list, []) return (STATUS_NO_CONTENT, header_list, [])
...@@ -1065,8 +1073,8 @@ class Application(object): ...@@ -1065,8 +1073,8 @@ class Application(object):
) )
return self._returnFile( return self._returnFile(
context.renew( context.renew(
crt_pem=payload['crt_pem'].encode('ascii'), crt_pem=utils.toBytes(payload['crt_pem']),
csr_pem=payload['renew_csr_pem'].encode('ascii'), csr_pem=utils.toBytes(payload['renew_csr_pem']),
), ),
'application/pkix-cert', 'application/pkix-cert',
) )
...@@ -1084,7 +1092,7 @@ class Application(object): ...@@ -1084,7 +1092,7 @@ class Application(object):
elif environ.get('CONTENT_TYPE') == 'application/pkcs10': elif environ.get('CONTENT_TYPE') == 'application/pkcs10':
template_csr = utils.load_certificate_request(body) template_csr = utils.load_certificate_request(body)
else: else:
raise BadRequest('Bad Content-Type') raise BadRequest(b'Bad Content-Type')
header_list = [] header_list = []
self._authenticate(environ, header_list) self._authenticate(environ, header_list)
context.createCertificate( context.createCertificate(
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment