Commit e9de51f0 authored by Vincent Pelletier's avatar Vincent Pelletier

all: Finalise python3 support.

Basically, wrap stdout and stderr whenever they do not have an encoding
with an ascii-encoding writer, and write unicode to stdout & stderr.
wsgi.errors is defined in the reference implementation as being a StringIO,
so follow that.
Stop using argparse.FileType to get rid of python3 "file not closed"
errors.
Also, fix setup access to CHANGES.txt .
Also, fix 2to3 involvement.
Also, replace test.captureStdout with extra tool arguments.
parent e9cd6586
...@@ -51,7 +51,7 @@ class RetryingCaucaseClient(CaucaseClient): ...@@ -51,7 +51,7 @@ class RetryingCaucaseClient(CaucaseClient):
Retries every 10 seconds. Retries every 10 seconds.
""" """
_until = staticmethod(utils.until) _until = staticmethod(utils.until)
_log_file = sys.stdout _log_file = utils.toUnicodeWritableStream(sys.stdout)
def _request(self, connection, method, url, body=None, headers=None): def _request(self, connection, method, url, body=None, headers=None):
while True: while True:
...@@ -79,10 +79,10 @@ class RetryingCaucaseClient(CaucaseClient): ...@@ -79,10 +79,10 @@ class RetryingCaucaseClient(CaucaseClient):
# letting non-printable characters through. # letting non-printable characters through.
next_try = datetime.datetime.utcnow() + datetime.timedelta(0, 10) next_try = datetime.datetime.utcnow() + datetime.timedelta(0, 10)
print( print(
'Got a network error, retrying at %s, %s: %r' % ( u'Got a network error, retrying at %s, %s: %r' % (
next_try.strftime(b'%Y-%m-%d %H:%M:%S +0000'), next_try.strftime(u'%Y-%m-%d %H:%M:%S +0000'),
exception.__class__.__name__, exception.__class__.__name__,
str(exception), unicode(exception),
), ),
file=self._log_file, file=self._log_file,
) )
...@@ -104,8 +104,14 @@ class CLICaucaseClient(object): ...@@ -104,8 +104,14 @@ class CLICaucaseClient(object):
# Note: this class it more to reduce local variable scopes (avoiding # Note: this class it more to reduce local variable scopes (avoiding
# accidental mixups) in each methods than about API declaration. # accidental mixups) in each methods than about API declaration.
def __init__(self, client): def __init__(self, client, stdout, stderr):
self._client = client self._client = client
self._stdout = stdout
self._stderr = stderr
def _print(self, *args, **kw):
kw.setdefault('file', self._stdout)
print(*args, **kw)
def putCSR(self, csr_path_list): def putCSR(self, csr_path_list):
""" """
...@@ -115,9 +121,9 @@ class CLICaucaseClient(object): ...@@ -115,9 +121,9 @@ class CLICaucaseClient(object):
csr_pem = utils.getCertRequest(csr_path) csr_pem = utils.getCertRequest(csr_path)
# Quick sanity check # Quick sanity check
utils.load_certificate_request(csr_pem) utils.load_certificate_request(csr_pem)
print( self._print(
self._client.createCertificateSigningRequest(csr_pem), self._client.createCertificateSigningRequest(csr_pem),
utils.toBytes(csr_path), csr_path,
) )
def getCSR(self, csr_id_path_list): def getCSR(self, csr_id_path_list):
...@@ -145,41 +151,41 @@ class CLICaucaseClient(object): ...@@ -145,41 +151,41 @@ class CLICaucaseClient(object):
except CaucaseError as e: except CaucaseError as e:
if e.args[0] != httplib.NOT_FOUND: if e.args[0] != httplib.NOT_FOUND:
raise raise
print(crt_id, b'not found - maybe CSR was rejected ?') self._print(crt_id, 'not found - maybe CSR was rejected ?')
error = True error = True
else: else:
print(crt_id, b'CSR still pending') self._print(crt_id, 'CSR still pending')
warning = True warning = True
else: else:
print(crt_id, end=' ') self._print(crt_id, end=' ')
if utils.isCertificateAutoSigned(utils.load_certificate( if utils.isCertificateAutoSigned(utils.load_certificate(
crt_pem, crt_pem,
ca_list, ca_list,
None, None,
)): )):
print(b'was (originally) automatically approved') self._print('was (originally) automatically approved')
else: else:
print(b'was (originally) manually approved') self._print('was (originally) manually approved')
if os.path.exists(crt_path): if os.path.exists(crt_path):
try: try:
key_pem = utils.getKey(crt_path) key_pem = utils.getKey(crt_path)
except ValueError: except ValueError:
print( self._print(
b'Expected to find exactly one privatekey key in %s, skipping' % ( 'Expected to find exactly one privatekey key in %s, skipping' % (
crt_path, crt_path,
), ),
file=sys.stderr, file=self._stderr,
) )
error = True error = True
continue continue
try: try:
utils.validateCertAndKey(crt_pem, key_pem) utils.validateCertAndKey(crt_pem, key_pem)
except ValueError: except ValueError:
print( self._print(
b'Key in %s does not match retrieved certificate, skipping' % ( 'Key in %s does not match retrieved certificate, skipping' % (
crt_path, crt_path,
), ),
file=sys.stderr, file=self._stderr,
) )
error = True error = True
continue continue
...@@ -195,11 +201,11 @@ class CLICaucaseClient(object): ...@@ -195,11 +201,11 @@ class CLICaucaseClient(object):
try: try:
crt, key, _ = utils.getKeyPair(crt_path, key_path) crt, key, _ = utils.getKeyPair(crt_path, key_path)
except ValueError: except ValueError:
print( self._print(
b'Could not find (exactly) one matching key pair in %s, skipping' % ( 'Could not find (exactly) one matching key pair in %s, skipping' % (
[x for x in set((crt_path, key_path)) if x], [x for x in set((crt_path, key_path)) if x],
), ),
file=sys.stderr, file=self._stderr,
) )
error = True error = True
continue continue
...@@ -225,11 +231,11 @@ class CLICaucaseClient(object): ...@@ -225,11 +231,11 @@ class CLICaucaseClient(object):
key_path, key_path,
) )
except ValueError: except ValueError:
print( self._print(
b'Could not find (exactly) one matching key pair in %s, skipping' % ( 'Could not find (exactly) one matching key pair in %s, skipping' % (
[x for x in set((crt_path, key_path)) if x], [x for x in set((crt_path, key_path)) if x],
), ),
file=sys.stderr, file=self._stderr,
) )
error = True error = True
continue continue
...@@ -240,13 +246,13 @@ class CLICaucaseClient(object): ...@@ -240,13 +246,13 @@ class CLICaucaseClient(object):
None, None,
) )
except exceptions.CertificateVerificationError: except exceptions.CertificateVerificationError:
print( self._print(
crt_path, crt_path,
b'was not signed by this CA, revoked or otherwise invalid, skipping', 'was not signed by this CA, revoked or otherwise invalid, skipping',
) )
continue continue
if renewal_deadline < old_crt.not_valid_after: if renewal_deadline < old_crt.not_valid_after:
print(crt_path, b'did not reach renew threshold, not renewing') self._print(crt_path, 'did not reach renew threshold, not renewing')
continue continue
new_key_pem, new_crt_pem = self._client.renewCertificate( new_key_pem, new_crt_pem = self._client.renewCertificate(
old_crt=old_crt, old_crt=old_crt,
...@@ -274,22 +280,22 @@ class CLICaucaseClient(object): ...@@ -274,22 +280,22 @@ class CLICaucaseClient(object):
""" """
--list-csr --list-csr
""" """
print(b'-- pending', mode, b'CSRs --') self._print('-- pending', mode, 'CSRs --')
print( self._print(
b'%20s | %s' % ( '%20s | %s' % (
b'csr_id', 'csr_id',
b'subject preview (fetch csr and check full content !)', 'subject preview (fetch csr and check full content !)',
), ),
) )
for entry in self._client.getPendingCertificateRequestList(): for entry in self._client.getPendingCertificateRequestList():
csr = utils.load_certificate_request(utils.toBytes(entry['csr'])) csr = utils.load_certificate_request(utils.toBytes(entry['csr']))
print( self._print(
b'%20s | %r' % ( '%20s | %r' % (
utils.toBytes(entry['id']), entry['id'],
utils.toBytes(repr(csr.subject)), repr(csr.subject),
), ),
) )
print(b'-- end of pending', mode, b'CSRs --') self._print('-- end of pending', mode, 'CSRs --')
def signCSR(self, csr_id_list): def signCSR(self, csr_id_list):
""" """
...@@ -332,11 +338,11 @@ class CLICaucaseClient(object): ...@@ -332,11 +338,11 @@ class CLICaucaseClient(object):
# authenticated revocations). # authenticated revocations).
crt_pem = utils.getCert(crt_path) crt_pem = utils.getCert(crt_path)
except ValueError: except ValueError:
print( self._print(
b'Could not load a single certificate in %s, skipping' % ( 'Could not load a single certificate in %s, skipping' % (
crt_path, crt_path,
), ),
file=sys.stderr, file=self._stderr,
) )
self._client.revokeCertificate(crt_pem) self._client.revokeCertificate(crt_pem)
return error return error
...@@ -348,7 +354,7 @@ class CLICaucaseClient(object): ...@@ -348,7 +354,7 @@ class CLICaucaseClient(object):
for serial in serial_list: for serial in serial_list:
self._client.revokeSerial(serial) self._client.revokeSerial(serial)
def main(argv=None): def main(argv=None, stdout=sys.stdout, stderr=sys.stderr):
""" """
Command line caucase client entry point. Command line caucase client entry point.
""" """
...@@ -547,6 +553,8 @@ def main(argv=None): ...@@ -547,6 +553,8 @@ def main(argv=None):
'Use --revoke and --revoke-other-crt whenever possible.', 'Use --revoke and --revoke-other-crt whenever possible.',
) )
args = parser.parse_args(argv) args = parser.parse_args(argv)
stdout = utils.toUnicodeWritableStream(stdout)
stderr = utils.toUnicodeWritableStream(stderr)
sign_csr_id_set = set(args.sign_csr) sign_csr_id_set = set(args.sign_csr)
sign_with_csr_id_set = {x for x, _ in args.sign_csr_with} sign_with_csr_id_set = {x for x, _ in args.sign_csr_with}
...@@ -556,9 +564,9 @@ def main(argv=None): ...@@ -556,9 +564,9 @@ def main(argv=None):
sign_csr_id_set.intersection(sign_with_csr_id_set) sign_csr_id_set.intersection(sign_with_csr_id_set)
): ):
print( print(
b'A given CSR_ID cannot be in more than one of --sign-csr, ' 'A given CSR_ID cannot be in more than one of --sign-csr, '
b'--sign-csr-with and --reject-csr', '--sign-csr-with and --reject-csr',
file=sys.stderr, file=stderr,
) )
raise SystemExit(STATUS_ERROR) raise SystemExit(STATUS_ERROR)
...@@ -585,6 +593,8 @@ def main(argv=None): ...@@ -585,6 +593,8 @@ def main(argv=None):
ca_crt_pem_list=utils.getCertList(args.ca_crt), ca_crt_pem_list=utils.getCertList(args.ca_crt),
user_key=args.user_key, user_key=args.user_key,
), ),
stdout=stdout,
stderr=stderr,
) )
ca_list = [ ca_list = [
utils.load_ca_certificate(x) utils.load_ca_certificate(x)
...@@ -802,12 +812,12 @@ def updater(argv=None, until=utils.until): ...@@ -802,12 +812,12 @@ def updater(argv=None, until=utils.until):
ca_crt_pem_list=utils.getCertList(args.cas_ca) ca_crt_pem_list=utils.getCertList(args.cas_ca)
) )
if args.crt and not utils.hasOneCert(args.crt): if args.crt and not utils.hasOneCert(args.crt):
print(b'Bootstraping...') print('Bootstraping...')
csr_pem = utils.getCertRequest(args.csr) csr_pem = utils.getCertRequest(args.csr)
# Quick sanity check before bothering server # Quick sanity check before bothering server
utils.load_certificate_request(csr_pem) utils.load_certificate_request(csr_pem)
csr_id = client.createCertificateSigningRequest(csr_pem) csr_id = client.createCertificateSigningRequest(csr_pem)
print(b'Waiting for signature of', csr_id) print('Waiting for signature of', csr_id)
while True: while True:
try: try:
crt_pem = client.getCertificate(csr_id) crt_pem = client.getCertificate(csr_id)
...@@ -825,12 +835,12 @@ def updater(argv=None, until=utils.until): ...@@ -825,12 +835,12 @@ def updater(argv=None, until=utils.until):
crt_file.write(crt_pem) crt_file.write(crt_pem)
updated = True updated = True
break break
print(b'Bootstrap done') print('Bootstrap done')
next_deadline = datetime.datetime.utcnow() next_deadline = datetime.datetime.utcnow()
while True: while True:
print( print(
b'Next wake-up at', 'Next wake-up at',
next_deadline.strftime(b'%Y-%m-%d %H:%M:%S +0000'), next_deadline.strftime('%Y-%m-%d %H:%M:%S +0000'),
) )
now = until(next_deadline) now = until(next_deadline)
next_deadline = now + max_sleep next_deadline = now + max_sleep
...@@ -843,7 +853,7 @@ def updater(argv=None, until=utils.until): ...@@ -843,7 +853,7 @@ def updater(argv=None, until=utils.until):
ca_crt_pem_list=utils.getCertList(args.cas_ca) ca_crt_pem_list=utils.getCertList(args.cas_ca)
) )
if RetryingCaucaseClient.updateCAFile(ca_url, args.ca): if RetryingCaucaseClient.updateCAFile(ca_url, args.ca):
print(b'Got new CA') print('Got new CA')
updated = True updated = True
# Note: CRL expiration should happen several time during CA renewal # Note: CRL expiration should happen several time during CA renewal
# period, so it should not be needed to keep track of CA expiration # period, so it should not be needed to keep track of CA expiration
...@@ -853,7 +863,7 @@ def updater(argv=None, until=utils.until): ...@@ -853,7 +863,7 @@ def updater(argv=None, until=utils.until):
for x in utils.getCertList(args.ca) for x in utils.getCertList(args.ca)
] ]
if RetryingCaucaseClient.updateCRLFile(ca_url, args.crl, ca_crt_list): if RetryingCaucaseClient.updateCRLFile(ca_url, args.crl, ca_crt_list):
print(b'Got new CRL') print('Got new CRL')
updated = True updated = True
with open(args.crl, 'rb') as crl_file: with open(args.crl, 'rb') as crl_file:
next_deadline = min( next_deadline = min(
...@@ -867,7 +877,7 @@ def updater(argv=None, until=utils.until): ...@@ -867,7 +877,7 @@ def updater(argv=None, until=utils.until):
crt_pem, key_pem, key_path = utils.getKeyPair(args.crt, args.key) crt_pem, key_pem, key_path = utils.getKeyPair(args.crt, args.key)
crt = utils.load_certificate(crt_pem, ca_crt_list, None) crt = utils.load_certificate(crt_pem, ca_crt_list, None)
if crt.not_valid_after - threshold <= now: if crt.not_valid_after - threshold <= now:
print(b'Renewing', args.crt) print('Renewing', args.crt)
new_key_pem, new_crt_pem = client.renewCertificate( new_key_pem, new_crt_pem = client.renewCertificate(
old_crt=crt, old_crt=crt,
old_key=utils.load_privatekey(key_pem), old_key=utils.load_privatekey(key_pem),
...@@ -901,7 +911,7 @@ def updater(argv=None, until=utils.until): ...@@ -901,7 +911,7 @@ def updater(argv=None, until=utils.until):
if args.on_renew is not None: if args.on_renew is not None:
status = os.system(args.on_renew) status = os.system(args.on_renew)
if status: if status:
print(b'Renewal hook exited with status:', status, file=sys.stderr) print('Renewal hook exited with status:', status, file=sys.stderr)
raise SystemExit(STATUS_ERROR) raise SystemExit(STATUS_ERROR)
updated = False updated = False
except (utils.SleepInterrupt, SystemExit): except (utils.SleepInterrupt, SystemExit):
...@@ -972,7 +982,7 @@ def rerequest(argv=None): ...@@ -972,7 +982,7 @@ def rerequest(argv=None):
with open(args.csr, 'wb') as csr_file: with open(args.csr, 'wb') as csr_file:
csr_file.write(csr_pem) csr_file.write(csr_pem)
def key_id(argv=None): def key_id(argv=None, stdout=sys.stdout):
""" """
Displays key identifier from private key, and the list of acceptable key Displays key identifier from private key, and the list of acceptable key
identifiers for a given backup file. identifiers for a given backup file.
...@@ -1001,6 +1011,7 @@ def key_id(argv=None): ...@@ -1001,6 +1011,7 @@ def key_id(argv=None):
'identifiers of.', 'identifiers of.',
) )
args = parser.parse_args(argv) args = parser.parse_args(argv)
stdout = utils.toUnicodeWritableStream(stdout)
for key_path in args.private_key: for key_path in args.private_key:
with open(key_path, 'rb') as key_file: with open(key_path, 'rb') as key_file:
print( print(
...@@ -1010,9 +1021,10 @@ def key_id(argv=None): ...@@ -1010,9 +1021,10 @@ def key_id(argv=None):
utils.load_privatekey(key_file.read()).public_key(), utils.load_privatekey(key_file.read()).public_key(),
).digest, ).digest,
), ),
file=stdout,
) )
for backup_path in args.backup: for backup_path in args.backup:
print(backup_path) print(backup_path, file=stdout)
with open(backup_path, 'rb') as backup_file: with open(backup_path, 'rb') as backup_file:
magic = backup_file.read(8) magic = backup_file.read(8)
if magic != b'caucase\0': if magic != b'caucase\0':
...@@ -1022,4 +1034,4 @@ def key_id(argv=None): ...@@ -1022,4 +1034,4 @@ def key_id(argv=None):
backup_file.read(struct.calcsize('<I')), backup_file.read(struct.calcsize('<I')),
) )
for key_entry in json.loads(backup_file.read(header_len))['key_list']: for key_entry in json.loads(backup_file.read(header_len))['key_list']:
print(b' ', key_entry['id'].encode('utf-8')) print(' ', key_entry['id'].encode('utf-8'), file=stdout)
...@@ -173,8 +173,12 @@ class CaucaseWSGIRequestHandler(WSGIRequestHandler): ...@@ -173,8 +173,12 @@ class CaucaseWSGIRequestHandler(WSGIRequestHandler):
remote_user_name = '-' remote_user_name = '-'
def __init__(self, *args, **kw): def __init__(self, *args, **kw):
self._log_file = kw.pop('log_file', sys.stdout) self._log_file = utils.toUnicodeWritableStream(
self._error_file = kw.pop('error_file', sys.stderr) kw.pop('log_file', sys.stdout),
)
self._error_file = utils.toUnicodeWritableStream(
kw.pop('error_file', sys.stderr),
)
WSGIRequestHandler.__init__(self, *args, **kw) WSGIRequestHandler.__init__(self, *args, **kw)
def log_date_time_string(self): def log_date_time_string(self):
...@@ -580,8 +584,10 @@ def main( ...@@ -580,8 +584,10 @@ def main(
help='Number of days between backups. default: %(default)s', help='Number of days between backups. default: %(default)s',
) )
args = parser.parse_args(argv) args = parser.parse_args(argv)
log_file = utils.toUnicodeWritableStream(log_file)
error_file = utils.toUnicodeWritableStream(error_file)
base_url = u'http://' + utils.toUnicode(args.netloc) base_url = 'http://' + utils.toUnicode(args.netloc)
parsed_base_url = urlparse(base_url) parsed_base_url = urlparse(base_url)
hostname = parsed_base_url.hostname hostname = parsed_base_url.hostname
name_constraints_permited = [] name_constraints_permited = []
...@@ -881,7 +887,7 @@ def main( ...@@ -881,7 +887,7 @@ def main(
server.server_close() server.server_close()
server.shutdown() server.shutdown()
def manage(argv=None): def manage(argv=None, stdout=sys.stdout):
""" """
caucased database management tool. caucased database management tool.
""" """
...@@ -922,7 +928,6 @@ def manage(argv=None): ...@@ -922,7 +928,6 @@ def manage(argv=None):
default=[], default=[],
metavar='PEM_FILE', metavar='PEM_FILE',
action='append', action='append',
type=argparse.FileType('rb'),
help='Import key pairs as initial service CA certificate. ' help='Import key pairs as initial service CA certificate. '
'May be provided multiple times to import multiple key pairs. ' 'May be provided multiple times to import multiple key pairs. '
'Keys and certificates may be in separate files. ' 'Keys and certificates may be in separate files. '
...@@ -948,7 +953,6 @@ def manage(argv=None): ...@@ -948,7 +953,6 @@ def manage(argv=None):
default=[], default=[],
metavar='PEM_FILE', metavar='PEM_FILE',
action='append', action='append',
type=argparse.FileType('rb'),
help='Import service revocation list. Corresponding CA certificate must ' help='Import service revocation list. Corresponding CA certificate must '
'be already present in the database (including added in the same run ' 'be already present in the database (including added in the same run '
'using --import-ca).', 'using --import-ca).',
...@@ -956,11 +960,11 @@ def manage(argv=None): ...@@ -956,11 +960,11 @@ def manage(argv=None):
parser.add_argument( parser.add_argument(
'--export-ca', '--export-ca',
metavar='PEM_FILE', metavar='PEM_FILE',
type=argparse.FileType('wb'),
help='Export all CA certificates in a PEM file. Passphrase will be ' help='Export all CA certificates in a PEM file. Passphrase will be '
'prompted to protect all keys.', 'prompted to protect all keys.',
) )
args = parser.parse_args(argv) args = parser.parse_args(argv)
stdout = utils.toUnicodeWritableStream(stdout)
db_path = args.db db_path = args.db
if args.restore_backup: if args.restore_backup:
( (
...@@ -1008,9 +1012,11 @@ def manage(argv=None): ...@@ -1008,9 +1012,11 @@ def manage(argv=None):
import_ca_dict = defaultdict( import_ca_dict = defaultdict(
(lambda: {'crt': None, 'key': None, 'from': []}), (lambda: {'crt': None, 'key': None, 'from': []}),
) )
for ca_file in args.import_ca: for import_ca in args.import_ca:
for index, component in enumerate(pem.parse(ca_file.read())): with open(import_ca, 'rb') as ca_file:
name = '%r, block %i' % (ca_file.name, index) ca_data = ca_file.read()
for index, component in enumerate(pem.parse(ca_data)):
name = '%r, block %i' % (import_ca, index)
if isinstance(component, pem.Certificate): if isinstance(component, pem.Certificate):
component_name = 'crt' component_name = 'crt'
component_value = x509.load_pem_x509_certificate( component_value = x509.load_pem_x509_certificate(
...@@ -1053,11 +1059,16 @@ def manage(argv=None): ...@@ -1053,11 +1059,16 @@ def manage(argv=None):
found_from = ', '.join(ca_pair['from']) found_from = ', '.join(ca_pair['from'])
crt = ca_pair['crt'] crt = ca_pair['crt']
if crt is None: if crt is None:
print(b'No certificate correspond to', found_from, b'- skipping') print(
'No certificate correspond to',
found_from,
'- skipping',
file=stdout,
)
continue continue
expiration = utils.datetime2timestamp(crt.not_valid_after) expiration = utils.datetime2timestamp(crt.not_valid_after)
if expiration < now: if expiration < now:
print(b'Skipping expired certificate from', found_from) print('Skipping expired certificate from', found_from, file=stdout)
del import_ca_dict[identifier] del import_ca_dict[identifier]
continue continue
if not args.import_bad_ca: if not args.import_bad_ca:
...@@ -1076,11 +1087,16 @@ def manage(argv=None): ...@@ -1076,11 +1087,16 @@ def manage(argv=None):
or not key_usage.key_cert_sign or not key_usage.crl_sign or not key_usage.key_cert_sign or not key_usage.crl_sign
) )
if failed: if failed:
print(b'Skipping non-CA certificate from', found_from) print('Skipping non-CA certificate from', found_from, file=stdout)
continue continue
key = ca_pair['key'] key = ca_pair['key']
if key is None: if key is None:
print(b'No private key correspond to', found_from, b'- skipping') print(
'No private key correspond to',
found_from,
'- skipping',
file=stdout,
)
continue continue
imported += 1 imported += 1
cas_db.appendCAKeyPair( cas_db.appendCAKeyPair(
...@@ -1092,7 +1108,7 @@ def manage(argv=None): ...@@ -1092,7 +1108,7 @@ def manage(argv=None):
) )
if not imported: if not imported:
raise ValueError('No CA certificate imported') raise ValueError('No CA certificate imported')
print(b'Imported %i CA certificates' % imported) print('Imported %i CA certificates' % imported, file=stdout)
if args.import_crl: if args.import_crl:
db = SQLite3Storage(db_path, table_prefix='cas') db = SQLite3Storage(db_path, table_prefix='cas')
trusted_ca_crt_set = [ trusted_ca_crt_set = [
...@@ -1104,8 +1120,10 @@ def manage(argv=None): ...@@ -1104,8 +1120,10 @@ def manage(argv=None):
for x in trusted_ca_crt_set for x in trusted_ca_crt_set
) )
already_revoked_count = revoked_count = 0 already_revoked_count = revoked_count = 0
for crl_file in args.import_crl: for import_crl in args.import_crl:
for revoked in utils.load_crl(crl_file.read(), trusted_ca_crt_set): with open(import_crl, 'rb') as crl_file:
crl_data = crl_file.read()
for revoked in utils.load_crl(crl_data, trusted_ca_crt_set):
try: try:
db.revoke( db.revoke(
revoked.serial_number, revoked.serial_number,
...@@ -1115,28 +1133,31 @@ def manage(argv=None): ...@@ -1115,28 +1133,31 @@ def manage(argv=None):
already_revoked_count += 1 already_revoked_count += 1
else: else:
revoked_count += 1 revoked_count += 1
print(b'Revoked %i certificates (%i were already revoked)' % ( print(
revoked_count, 'Revoked %i certificates (%i were already revoked)' % (
already_revoked_count, revoked_count,
)) already_revoked_count,
),
file=stdout,
)
if args.export_ca is not None: if args.export_ca is not None:
encryption_algorithm = serialization.BestAvailableEncryption( encryption_algorithm = serialization.BestAvailableEncryption(
getBytePass('CA export passphrase: ') getBytePass('CA export passphrase: ')
) )
write = args.export_ca.write with open(args.export_ca, 'wb') as export_ca_file:
for key_pair in SQLite3Storage( write = export_ca_file.write
db_path, for key_pair in SQLite3Storage(
table_prefix='cas', db_path,
).getCAKeyPairList(): table_prefix='cas',
write( ).getCAKeyPairList():
key_pair['crt_pem'] + serialization.load_pem_private_key( write(
key_pair['key_pem'], key_pair['crt_pem'] + serialization.load_pem_private_key(
None, key_pair['key_pem'],
_cryptography_backend, None,
).private_bytes( _cryptography_backend,
encoding=serialization.Encoding.PEM, ).private_bytes(
format=serialization.PrivateFormat.TraditionalOpenSSL, encoding=serialization.Encoding.PEM,
encryption_algorithm=encryption_algorithm, format=serialization.PrivateFormat.TraditionalOpenSSL,
), encryption_algorithm=encryption_algorithm,
) ),
args.export_ca.close() )
...@@ -25,7 +25,6 @@ Test suite ...@@ -25,7 +25,6 @@ Test suite
""" """
# pylint: disable=too-many-lines, too-many-public-methods # pylint: disable=too-many-lines, too-many-public-methods
from __future__ import absolute_import from __future__ import absolute_import
import contextlib
from Cookie import SimpleCookie from Cookie import SimpleCookie
import datetime import datetime
import errno import errno
...@@ -305,25 +304,14 @@ def print_buffer_on_error(func): ...@@ -305,25 +304,14 @@ def print_buffer_on_error(func):
try: try:
return func(self, *args, **kw) return func(self, *args, **kw)
except Exception: # pragma: no cover except Exception: # pragma: no cover
sys.stdout.write(utils.toBytes(os.linesep)) stdout = utils.toUnicodeWritableStream(sys.stdout)
sys.stdout.write(self.caucase_test_output.getvalue()) stdout.write(os.linesep)
stdout.write(
self.caucase_test_output.getvalue().decode('ascii', 'replace'),
)
raise raise
return wrapper return wrapper
@contextlib.contextmanager
def captureStdout():
"""
Replace stdout with a BytesIO object for the duration of the context manager,
and provide it to caller.
"""
orig_stdout = sys.stdout
sys.stdout = stdout = BytesIO()
try:
yield stdout
finally:
sys.stdout = orig_stdout
@unittest.skipIf(sys.version_info >= (3, ), 'Caucase currently supports python 2 only')
class CaucaseTest(unittest.TestCase): class CaucaseTest(unittest.TestCase):
""" """
Test a complete caucase setup: spawn a caucase-http server on CAUCASE_NETLOC Test a complete caucase setup: spawn a caucase-http server on CAUCASE_NETLOC
...@@ -356,6 +344,9 @@ class CaucaseTest(unittest.TestCase): ...@@ -356,6 +344,9 @@ class CaucaseTest(unittest.TestCase):
self._server_backup_path = os.path.join(server_dir, 'backup') self._server_backup_path = os.path.join(server_dir, 'backup')
self._server_cors_store = os.path.join(server_dir, 'cors.key') self._server_cors_store = os.path.join(server_dir, 'cors.key')
# pylint: enable=bad-whitespace # pylint: enable=bad-whitespace
# Using a BytesIO for caucased output here, because stdout/stderr do not
# necessarily have a known encoding, for example when output is a pipe
# (to a file, ...). caucased must deal with this.
self.caucase_test_output = BytesIO() self.caucase_test_output = BytesIO()
os.mkdir(self._server_backup_path) os.mkdir(self._server_backup_path)
...@@ -497,7 +488,7 @@ class CaucaseTest(unittest.TestCase): ...@@ -497,7 +488,7 @@ class CaucaseTest(unittest.TestCase):
) )
def _getClientCRL(self): def _getClientCRL(self):
with open(self._client_crl) as crl_pem_file: with open(self._client_crl, 'rb') as crl_pem_file:
return x509.load_pem_x509_crl( return x509.load_pem_x509_crl(
crl_pem_file.read(), crl_pem_file.read(),
_cryptography_backend _cryptography_backend
...@@ -612,20 +603,24 @@ class CaucaseTest(unittest.TestCase): ...@@ -612,20 +603,24 @@ class CaucaseTest(unittest.TestCase):
Returns stdout. Returns stdout.
""" """
with captureStdout() as stdout: # Using a BytesIO for caucased output here, because stdout/stderr do not
try: # necessarily have a known encoding, for example when output is a pipe
cli.main( # (to a file, ...). caucase must deal with this.
argv=( stdout = BytesIO()
'--ca-url', self._caucase_url, try:
'--ca-crt', self._client_ca_crt, cli.main(
'--user-ca-crt', self._client_user_ca_crt, argv=(
'--crl', self._client_crl, '--ca-url', self._caucase_url,
'--user-crl', self._client_user_crl, '--ca-crt', self._client_ca_crt,
) + argv, '--user-ca-crt', self._client_user_ca_crt,
) '--crl', self._client_crl,
except SystemExit: '--user-crl', self._client_user_crl,
pass ) + argv,
return stdout.getvalue() stdout=stdout,
)
except SystemExit:
pass
return stdout.getvalue().decode('ascii')
@staticmethod @staticmethod
def _setCertificateRemainingLifeTime(key, crt, delta): def _setCertificateRemainingLifeTime(key, crt, delta):
...@@ -1676,7 +1671,10 @@ class CaucaseTest(unittest.TestCase): ...@@ -1676,7 +1671,10 @@ class CaucaseTest(unittest.TestCase):
""" """
Non-standard shorthand for invoking the WSGI application. Non-standard shorthand for invoking the WSGI application.
""" """
environ.setdefault('wsgi.errors', self.caucase_test_output) environ.setdefault(
'wsgi.errors',
utils.toUnicodeWritableStream(self.caucase_test_output),
)
environ.setdefault('wsgi.url_scheme', 'http') environ.setdefault('wsgi.url_scheme', 'http')
environ.setdefault('SERVER_NAME', server_name) environ.setdefault('SERVER_NAME', server_name)
environ.setdefault('SERVER_PORT', str(server_http_port)) environ.setdefault('SERVER_PORT', str(server_http_port))
...@@ -2294,29 +2292,31 @@ class CaucaseTest(unittest.TestCase): ...@@ -2294,29 +2292,31 @@ class CaucaseTest(unittest.TestCase):
os.unlink(self._server_db) os.unlink(self._server_db)
os.unlink(self._server_key) os.unlink(self._server_key)
with captureStdout() as stdout: stdout = BytesIO()
cli.key_id([ cli.key_id(
'--private-key', user_key_path, user2_key_path, user2_new_key_path, ['--private-key', user_key_path, user2_key_path, user2_new_key_path],
]) stdout=stdout,
)
key_id_dict = dict( key_id_dict = dict(
line.rsplit(' ', 1) line.decode('ascii').rsplit(' ', 1)
for line in stdout.getvalue().splitlines() for line in stdout.getvalue().splitlines()
) )
key_id = key_id_dict.pop(user_key_path) key_id = key_id_dict.pop(user_key_path)
key2_id = key_id_dict.pop(user2_key_path) key2_id = key_id_dict.pop(user2_key_path)
new_key2_id = key_id_dict.pop(user2_new_key_path) new_key2_id = key_id_dict.pop(user2_new_key_path)
self.assertFalse(key_id_dict) self.assertFalse(key_id_dict)
with captureStdout() as stdout: stdout = BytesIO()
cli.key_id([ cli.key_id(
'--backup', backup_path, ['--backup', backup_path],
]) stdout=stdout,
)
self.assertItemsEqual( self.assertItemsEqual(
[ [
backup_path, backup_path,
' ' + key_id, ' ' + key_id,
' ' + key2_id, ' ' + key2_id,
], ],
stdout.getvalue().splitlines(), stdout.getvalue().decode('ascii').splitlines(),
) )
try: try:
...@@ -2410,17 +2410,18 @@ class CaucaseTest(unittest.TestCase): ...@@ -2410,17 +2410,18 @@ class CaucaseTest(unittest.TestCase):
if not backup_path_list: # pragma: no cover if not backup_path_list: # pragma: no cover
raise AssertionError('Backup file not created after 1 second') raise AssertionError('Backup file not created after 1 second')
backup_path, = glob.glob(backup_glob) backup_path, = glob.glob(backup_glob)
with captureStdout() as stdout: stdout = BytesIO()
cli.key_id([ cli.key_id(
'--backup', backup_path, ['--backup', backup_path],
]) stdout=stdout,
)
self.assertItemsEqual( self.assertItemsEqual(
[ [
backup_path, backup_path,
' ' + key_id, ' ' + key_id,
' ' + new_key2_id, ' ' + new_key2_id,
], ],
stdout.getvalue().splitlines(), stdout.getvalue().decode('ascii').splitlines(),
) )
# Now, push a lot of data to exercise chunked checksum in backup & # Now, push a lot of data to exercise chunked checksum in backup &
...@@ -2444,17 +2445,18 @@ class CaucaseTest(unittest.TestCase): ...@@ -2444,17 +2445,18 @@ class CaucaseTest(unittest.TestCase):
if not backup_path_list: # pragma: no cover if not backup_path_list: # pragma: no cover
raise AssertionError('Backup file took too long to be created') raise AssertionError('Backup file took too long to be created')
backup_path, = glob.glob(backup_glob) backup_path, = glob.glob(backup_glob)
with captureStdout() as stdout: stdout = BytesIO()
cli.key_id([ cli.key_id(
'--backup', backup_path, ['--backup', backup_path],
]) stdout=stdout,
)
self.assertItemsEqual( self.assertItemsEqual(
[ [
backup_path, backup_path,
' ' + key_id, ' ' + key_id,
' ' + new_key2_id, ' ' + new_key2_id,
], ],
stdout.getvalue().splitlines(), stdout.getvalue().decode('ascii').splitlines(),
) )
self._stopServer() self._stopServer()
os.unlink(self._server_db) os.unlink(self._server_db)
...@@ -2510,23 +2512,24 @@ class CaucaseTest(unittest.TestCase): ...@@ -2510,23 +2512,24 @@ class CaucaseTest(unittest.TestCase):
self.assertTrue(os.path.exists(exported_ca), exported_ca) self.assertTrue(os.path.exists(exported_ca), exported_ca)
server_db2 = self._server_db + '2' server_db2 = self._server_db + '2'
self.assertFalse(os.path.exists(server_db2), server_db2) self.assertFalse(os.path.exists(server_db2), server_db2)
with captureStdout() as stdout: stdout = BytesIO()
caucase.http.manage( caucase.http.manage(
argv=( argv=(
'--db', server_db2, '--db', server_db2,
'--import-ca', exported_ca, '--import-ca', exported_ca,
'--import-crl', self._client_crl, '--import-crl', self._client_crl,
# Twice, for code coverage... # Twice, for code coverage...
'--import-crl', self._client_crl, '--import-crl', self._client_crl,
), ),
) stdout=stdout,
)
self.assertTrue(os.path.exists(server_db2), server_db2) self.assertTrue(os.path.exists(server_db2), server_db2)
self.assertEqual( self.assertEqual(
[ [
'Imported 1 CA certificates', 'Imported 1 CA certificates',
'Revoked 1 certificates (1 were already revoked)', 'Revoked 1 certificates (1 were already revoked)',
], ],
stdout.getvalue().splitlines(), stdout.getvalue().decode('ascii').splitlines(),
) )
finally: finally:
caucase.http.getBytePass = getBytePass_orig caucase.http.getBytePass = getBytePass_orig
...@@ -2729,7 +2732,7 @@ class CaucaseTest(unittest.TestCase): ...@@ -2729,7 +2732,7 @@ class CaucaseTest(unittest.TestCase):
until_network_issue = UntilEvent(network_issue_event) until_network_issue = UntilEvent(network_issue_event)
# pylint: disable=protected-access # pylint: disable=protected-access
cli.RetryingCaucaseClient._until = until_network_issue cli.RetryingCaucaseClient._until = until_network_issue
cli.RetryingCaucaseClient._log_file = self.caucase_test_output cli.RetryingCaucaseClient._log_file = StringIO()
# pylint: enable=protected-access # pylint: enable=protected-access
until_network_issue.action = ON_EVENT_EXPIRE until_network_issue.action = ON_EVENT_EXPIRE
original_HTTPConnection = cli.RetryingCaucaseClient.HTTPConnection original_HTTPConnection = cli.RetryingCaucaseClient.HTTPConnection
......
...@@ -26,6 +26,7 @@ Small-ish functions needed in many places. ...@@ -26,6 +26,7 @@ Small-ish functions needed in many places.
from __future__ import absolute_import, print_function from __future__ import absolute_import, print_function
from binascii import a2b_base64, b2a_base64 from binascii import a2b_base64, b2a_base64
import calendar import calendar
import codecs
from collections import defaultdict from collections import defaultdict
import datetime import datetime
import email import email
...@@ -499,6 +500,17 @@ def toBytes(value, encoding='ascii'): ...@@ -499,6 +500,17 @@ def toBytes(value, encoding='ascii'):
""" """
return value if isinstance(value, bytes) else value.encode(encoding) return value if isinstance(value, bytes) else value.encode(encoding)
def toUnicodeWritableStream(writable_stream, encoding='ascii'):
"""
Convert writable_stream into a writable stream accepting unicode.
If writable_stream already accepts unicode, returns it.
Otherwise, returns a writable stream accepting unicode, and sending it to
writable_stream encoded with given encoding.
"""
if getattr(writable_stream, 'encoding', None) is not None:
return writable_stream
return codecs.getwriter(encoding)(writable_stream)
def interruptibleSleep(duration): # pragma: no cover def interruptibleSleep(duration): # pragma: no cover
""" """
Like sleep, but raises SleepInterrupt when interrupted by KeyboardInterrupt Like sleep, but raises SleepInterrupt when interrupted by KeyboardInterrupt
......
...@@ -20,15 +20,10 @@ ...@@ -20,15 +20,10 @@
# See https://www.nexedi.com/licensing for rationale and options. # See https://www.nexedi.com/licensing for rationale and options.
from setuptools import setup, find_packages from setuptools import setup, find_packages
import glob
import os
import sys
import versioneer import versioneer
long_description = open("README.rst").read() + "\n" with open("README.rst") as readme, open("CHANGES.txt") as changes:
for f in sorted(glob.glob(os.path.join('caucase', 'README.*.rst'))): long_description = readme.read() + "\n" + changes.read() + "\n"
long_description += '\n' + open(f).read() + '\n'
long_description += open("CHANGES.txt").read() + "\n"
setup( setup(
name='caucase', name='caucase',
...@@ -71,5 +66,5 @@ setup( ...@@ -71,5 +66,5 @@ setup(
] ]
}, },
test_suite='caucase.test', test_suite='caucase.test',
use_2to3=sys.version_info >= (3, ), use_2to3=True,
) )
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment