Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
C
caucase
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Labels
Merge Requests
2
Merge Requests
2
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Jobs
Commits
Open sidebar
nexedi
caucase
Commits
8ce08bf9
Commit
8ce08bf9
authored
Jul 13, 2018
by
Vincent Pelletier
Committed by
Vincent Pelletier
Sep 26, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
all: More python3 adaptations.
What was not picked up by 2to3.
parent
7f9e56cf
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
286 additions
and
220 deletions
+286
-220
caucase/ca.py
caucase/ca.py
+19
-18
caucase/cli.py
caucase/cli.py
+36
-25
caucase/client.py
caucase/client.py
+18
-19
caucase/http.py
caucase/http.py
+39
-33
caucase/http_wsgibase.py
caucase/http_wsgibase.py
+6
-5
caucase/storage.py
caucase/storage.py
+17
-12
caucase/test.py
caucase/test.py
+96
-75
caucase/utils.py
caucase/utils.py
+24
-10
caucase/wsgi.py
caucase/wsgi.py
+31
-23
No files found.
caucase/ca.py
View file @
8ce08bf9
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
Caucase - Certificate Authority for Users, Certificate Authority for SErvices
Caucase - Certificate Authority for Users, Certificate Authority for SErvices
"""
"""
from
__future__
import
absolute_import
from
__future__
import
absolute_import
from
binascii
import
hexlify
,
unhexlify
import
datetime
import
datetime
import
json
import
json
import
os
import
os
...
@@ -55,7 +56,7 @@ _SUBJECT_OID_DICT = {
...
@@ -55,7 +56,7 @@ _SUBJECT_OID_DICT = {
'GN'
:
x509
.
oid
.
NameOID
.
GIVEN_NAME
,
'GN'
:
x509
.
oid
.
NameOID
.
GIVEN_NAME
,
# pylint: enable=bad-whitespace
# pylint: enable=bad-whitespace
}
}
_BACKUP_MAGIC
=
'caucase
\
0
'
_BACKUP_MAGIC
=
b
'caucase
\
0
'
_CONFIG_NAME_AUTO_SIGN_CSR_AMOUNT
=
'auto_sign_csr_amount'
_CONFIG_NAME_AUTO_SIGN_CSR_AMOUNT
=
'auto_sign_csr_amount'
def
Extension
(
value
,
critical
):
def
Extension
(
value
,
critical
):
...
@@ -227,9 +228,9 @@ class CertificateAuthority(object):
...
@@ -227,9 +228,9 @@ class CertificateAuthority(object):
# Note: requested_amount is None when a known CSR is re-submitted
# Note: requested_amount is None when a known CSR is re-submitted
csr_id
,
requested_amount
=
self
.
_storage
.
appendCertificateSigningRequest
(
csr_id
,
requested_amount
=
self
.
_storage
.
appendCertificateSigningRequest
(
csr_pem
=
csr_pem
,
csr_pem
=
csr_pem
,
key_id
=
x509
.
SubjectKeyIdentifier
.
from_public_key
(
key_id
=
hexlify
(
x509
.
SubjectKeyIdentifier
.
from_public_key
(
csr
.
public_key
(),
csr
.
public_key
(),
).
digest
.
encode
(
'hex'
),
).
digest
),
override_limits
=
override_limits
,
override_limits
=
override_limits
,
)
)
if
requested_amount
is
not
None
and
\
if
requested_amount
is
not
None
and
\
...
@@ -632,8 +633,8 @@ class CertificateAuthority(object):
...
@@ -632,8 +633,8 @@ class CertificateAuthority(object):
current_crt_pem
=
utils
.
dump_certificate
(
key_pair
[
'crt'
])
current_crt_pem
=
utils
.
dump_certificate
(
key_pair
[
'crt'
])
result
.
append
(
utils
.
wrap
(
result
.
append
(
utils
.
wrap
(
{
{
'old_pem'
:
previous_crt_pem
,
'old_pem'
:
utils
.
toUnicode
(
previous_crt_pem
)
,
'new_pem'
:
current_crt_pem
,
'new_pem'
:
utils
.
toUnicode
(
current_crt_pem
)
,
},
},
previous_key
,
previous_key
,
self
.
digest_list
[
0
],
self
.
digest_list
[
0
],
...
@@ -799,31 +800,31 @@ class UserCertificateAuthority(CertificateAuthority):
...
@@ -799,31 +800,31 @@ class UserCertificateAuthority(CertificateAuthority):
continue
continue
public_key
=
crt
.
public_key
()
public_key
=
crt
.
public_key
()
key_list
.
append
({
key_list
.
append
({
'id'
:
x509
.
SubjectKeyIdentifier
.
from_public_ke
y
(
'id'
:
utils
.
toUnicode
(
hexlif
y
(
public_key
,
x509
.
SubjectKeyIdentifier
.
from_public_key
(
public_key
).
digest
,
)
.
digest
.
encode
(
'hex'
),
)),
'cipher'
:
{
'cipher'
:
{
'name'
:
'rsa_oaep_sha1_mgf1_sha1'
,
'name'
:
'rsa_oaep_sha1_mgf1_sha1'
,
},
},
'key'
:
public_key
.
encrypt
(
'key'
:
utils
.
toUnicode
(
hexlify
(
public_key
.
encrypt
(
signing_key
+
symetric_key
,
signing_key
+
symetric_key
,
OAEP
(
OAEP
(
mgf
=
MGF1
(
algorithm
=
hashes
.
SHA1
()),
mgf
=
MGF1
(
algorithm
=
hashes
.
SHA1
()),
algorithm
=
hashes
.
SHA1
(),
algorithm
=
hashes
.
SHA1
(),
label
=
None
,
label
=
None
,
),
),
)
.
encode
(
'hex'
),
)
)
),
})
})
if
not
key_list
:
if
not
key_list
:
# No users yet, backup is meaningless
# No users yet, backup is meaningless
return
False
return
False
header
=
json
.
dumps
({
header
=
utils
.
toBytes
(
json
.
dumps
({
'cipher'
:
{
'cipher'
:
{
'name'
:
'aes256_cbc_pkcs7_hmac_10M_sha256'
,
'name'
:
'aes256_cbc_pkcs7_hmac_10M_sha256'
,
'parameter'
:
iv
.
encode
(
'hex'
),
'parameter'
:
utils
.
toUnicode
(
hexlify
(
iv
)
),
},
},
'key_list'
:
key_list
,
'key_list'
:
key_list
,
})
})
)
padder
=
padding
.
PKCS7
(
128
).
padder
()
padder
=
padding
.
PKCS7
(
128
).
padder
()
write
(
_BACKUP_MAGIC
)
write
(
_BACKUP_MAGIC
)
write
(
struct
.
pack
(
'<I'
,
len
(
header
)))
write
(
struct
.
pack
(
'<I'
,
len
(
header
)))
...
@@ -877,11 +878,11 @@ class UserCertificateAuthority(CertificateAuthority):
...
@@ -877,11 +878,11 @@ class UserCertificateAuthority(CertificateAuthority):
if
header
[
'cipher'
][
'name'
]
!=
'aes256_cbc_pkcs7_hmac_10M_sha256'
:
if
header
[
'cipher'
][
'name'
]
!=
'aes256_cbc_pkcs7_hmac_10M_sha256'
:
raise
ValueError
(
'Unrecognised symetric cipher'
)
raise
ValueError
(
'Unrecognised symetric cipher'
)
private_key
=
utils
.
load_privatekey
(
key_pem
)
private_key
=
utils
.
load_privatekey
(
key_pem
)
key_id
=
x509
.
SubjectKeyIdentifier
.
from_public_key
(
key_id
=
hexlify
(
x509
.
SubjectKeyIdentifier
.
from_public_key
(
private_key
.
public_key
(),
private_key
.
public_key
(),
).
digest
.
encode
(
'hex'
)
).
digest
)
symetric_key_list
=
[
symetric_key_list
=
[
x
for
x
in
header
[
'key_list'
]
if
x
[
'id'
]
==
key_id
x
for
x
in
header
[
'key_list'
]
if
utils
.
toBytes
(
x
[
'id'
])
==
key_id
]
]
if
not
symetric_key_list
:
if
not
symetric_key_list
:
raise
ValueError
(
raise
ValueError
(
...
@@ -891,7 +892,7 @@ class UserCertificateAuthority(CertificateAuthority):
...
@@ -891,7 +892,7 @@ class UserCertificateAuthority(CertificateAuthority):
if
symetric_key_entry
[
'cipher'
][
'name'
]
!=
'rsa_oaep_sha1_mgf1_sha1'
:
if
symetric_key_entry
[
'cipher'
][
'name'
]
!=
'rsa_oaep_sha1_mgf1_sha1'
:
raise
ValueError
(
'Unrecognised asymetric cipher'
)
raise
ValueError
(
'Unrecognised asymetric cipher'
)
both_keys
=
private_key
.
decrypt
(
both_keys
=
private_key
.
decrypt
(
symetric_key_entry
[
'key'
].
decode
(
'hex'
),
unhexlify
(
symetric_key_entry
[
'key'
]
),
OAEP
(
OAEP
(
mgf
=
MGF1
(
algorithm
=
hashes
.
SHA1
()),
mgf
=
MGF1
(
algorithm
=
hashes
.
SHA1
()),
algorithm
=
hashes
.
SHA1
(),
algorithm
=
hashes
.
SHA1
(),
...
@@ -902,7 +903,7 @@ class UserCertificateAuthority(CertificateAuthority):
...
@@ -902,7 +903,7 @@ class UserCertificateAuthority(CertificateAuthority):
raise
ValueError
(
'Invalid key length'
)
raise
ValueError
(
'Invalid key length'
)
decryptor
=
Cipher
(
decryptor
=
Cipher
(
algorithms
.
AES
(
both_keys
[
32
:]),
algorithms
.
AES
(
both_keys
[
32
:]),
modes
.
CBC
(
header
[
'cipher'
][
'parameter'
].
decode
(
'hex'
)),
modes
.
CBC
(
unhexlify
(
header
[
'cipher'
][
'parameter'
]
)),
backend
=
_cryptography_backend
,
backend
=
_cryptography_backend
,
).
decryptor
()
).
decryptor
()
unpadder
=
padding
.
PKCS7
(
128
).
unpadder
()
unpadder
=
padding
.
PKCS7
(
128
).
unpadder
()
...
...
caucase/cli.py
View file @
8ce08bf9
...
@@ -20,6 +20,7 @@ Caucase - Certificate Authority for Users, Certificate Authority for SErvices
...
@@ -20,6 +20,7 @@ Caucase - Certificate Authority for Users, Certificate Authority for SErvices
"""
"""
from
__future__
import
absolute_import
,
print_function
from
__future__
import
absolute_import
,
print_function
import
argparse
import
argparse
from
binascii
import
hexlify
import
datetime
import
datetime
import
httplib
import
httplib
import
json
import
json
...
@@ -102,7 +103,7 @@ class CLICaucaseClient(object):
...
@@ -102,7 +103,7 @@ class CLICaucaseClient(object):
"""
"""
for
csr_id
,
csr_path
in
csr_id_path_list
:
for
csr_id
,
csr_path
in
csr_id_path_list
:
csr_pem
=
self
.
_client
.
getCertificateSigningRequest
(
int
(
csr_id
))
csr_pem
=
self
.
_client
.
getCertificateSigningRequest
(
int
(
csr_id
))
with
open
(
csr_path
,
'a'
)
as
csr_file
:
with
open
(
csr_path
,
'a
b
'
)
as
csr_file
:
csr_file
.
write
(
csr_pem
)
csr_file
.
write
(
csr_pem
)
def
getCRT
(
self
,
warning
,
error
,
crt_id_path_list
,
ca_list
):
def
getCRT
(
self
,
warning
,
error
,
crt_id_path_list
,
ca_list
):
...
@@ -157,7 +158,7 @@ class CLICaucaseClient(object):
...
@@ -157,7 +158,7 @@ class CLICaucaseClient(object):
)
)
error
=
True
error
=
True
continue
continue
with
open
(
crt_path
,
'a'
)
as
crt_file
:
with
open
(
crt_path
,
'a
b
'
)
as
crt_file
:
crt_file
.
write
(
crt_pem
)
crt_file
.
write
(
crt_pem
)
return
warning
,
error
return
warning
,
error
...
@@ -228,11 +229,17 @@ class CLICaucaseClient(object):
...
@@ -228,11 +229,17 @@ class CLICaucaseClient(object):
key_len
=
key_len
,
key_len
=
key_len
,
)
)
if
key_path
is
None
:
if
key_path
is
None
:
with
open
(
crt_path
,
'w'
)
as
crt_file
:
with
open
(
crt_path
,
'w
b
'
)
as
crt_file
:
crt_file
.
write
(
new_key_pem
)
crt_file
.
write
(
new_key_pem
)
crt_file
.
write
(
new_crt_pem
)
crt_file
.
write
(
new_crt_pem
)
else
:
else
:
with
open
(
crt_path
,
'w'
)
as
crt_file
,
open
(
key_path
,
'w'
)
as
key_file
:
with
open
(
crt_path
,
'wb'
,
)
as
crt_file
,
open
(
key_path
,
'wb'
,
)
as
key_file
:
key_file
.
write
(
new_key_pem
)
key_file
.
write
(
new_key_pem
)
crt_file
.
write
(
new_crt_pem
)
crt_file
.
write
(
new_crt_pem
)
updated
=
True
updated
=
True
...
@@ -250,7 +257,7 @@ class CLICaucaseClient(object):
...
@@ -250,7 +257,7 @@ class CLICaucaseClient(object):
),
),
)
)
for
entry
in
self
.
_client
.
getPendingCertificateRequestList
():
for
entry
in
self
.
_client
.
getPendingCertificateRequestList
():
csr
=
utils
.
load_certificate_request
(
entry
[
'csr'
]
)
csr
=
utils
.
load_certificate_request
(
utils
.
toBytes
(
entry
[
'csr'
])
)
print
(
print
(
'%20s | %r'
%
(
'%20s | %r'
%
(
entry
[
'id'
],
entry
[
'id'
],
...
@@ -264,7 +271,7 @@ class CLICaucaseClient(object):
...
@@ -264,7 +271,7 @@ class CLICaucaseClient(object):
--sign-csr
--sign-csr
"""
"""
for
csr_id
in
csr_id_list
:
for
csr_id
in
csr_id_list
:
self
.
_client
.
createCertificate
(
int
(
csr_id
))
self
.
_client
.
createCertificate
(
int
(
utils
.
toUnicode
(
csr_id
)
))
def
signCSRWith
(
self
,
csr_id_path_list
):
def
signCSRWith
(
self
,
csr_id_path_list
):
"""
"""
...
@@ -272,7 +279,7 @@ class CLICaucaseClient(object):
...
@@ -272,7 +279,7 @@ class CLICaucaseClient(object):
"""
"""
for
csr_id
,
csr_path
in
csr_id_path_list
:
for
csr_id
,
csr_path
in
csr_id_path_list
:
self
.
_client
.
createCertificate
(
self
.
_client
.
createCertificate
(
int
(
csr_id
),
int
(
utils
.
toUnicode
(
csr_id
)
),
template_csr
=
utils
.
getCertRequest
(
csr_path
),
template_csr
=
utils
.
getCertRequest
(
csr_path
),
)
)
...
@@ -763,7 +770,7 @@ def updater(argv=None, until=utils.until):
...
@@ -763,7 +770,7 @@ def updater(argv=None, until=utils.until):
# Still here ? Ok, wait a bit and try again.
# Still here ? Ok, wait a bit and try again.
until
(
datetime
.
datetime
.
utcnow
()
+
datetime
.
timedelta
(
0
,
60
))
until
(
datetime
.
datetime
.
utcnow
()
+
datetime
.
timedelta
(
0
,
60
))
else
:
else
:
with
open
(
args
.
crt
,
'a'
)
as
crt_file
:
with
open
(
args
.
crt
,
'a
b
'
)
as
crt_file
:
crt_file
.
write
(
crt_pem
)
crt_file
.
write
(
crt_pem
)
updated
=
True
updated
=
True
break
break
...
@@ -797,10 +804,11 @@ def updater(argv=None, until=utils.until):
...
@@ -797,10 +804,11 @@ def updater(argv=None, until=utils.until):
if
RetryingCaucaseClient
.
updateCRLFile
(
ca_url
,
args
.
crl
,
ca_crt_list
):
if
RetryingCaucaseClient
.
updateCRLFile
(
ca_url
,
args
.
crl
,
ca_crt_list
):
print
(
'Got new CRL'
)
print
(
'Got new CRL'
)
updated
=
True
updated
=
True
next_deadline
=
min
(
with
open
(
args
.
crl
,
'rb'
)
as
crl_file
:
next_deadline
,
next_deadline
=
min
(
utils
.
load_crl
(
open
(
args
.
crl
).
read
(),
ca_crt_list
).
next_update
,
next_deadline
,
)
utils
.
load_crl
(
crli_file
.
read
(),
ca_crt_list
).
next_update
,
)
if
args
.
crt
:
if
args
.
crt
:
crt_pem
,
key_pem
,
key_path
=
utils
.
getKeyPair
(
args
.
crt
,
args
.
key
)
crt_pem
,
key_pem
,
key_path
=
utils
.
getKeyPair
(
args
.
crt
,
args
.
key
)
crt
=
utils
.
load_certificate
(
crt_pem
,
ca_crt_list
,
None
)
crt
=
utils
.
load_certificate
(
crt_pem
,
ca_crt_list
,
None
)
...
@@ -812,16 +820,16 @@ def updater(argv=None, until=utils.until):
...
@@ -812,16 +820,16 @@ def updater(argv=None, until=utils.until):
key_len
=
args
.
key_len
,
key_len
=
args
.
key_len
,
)
)
if
key_path
is
None
:
if
key_path
is
None
:
with
open
(
args
.
crt
,
'w'
)
as
crt_file
:
with
open
(
args
.
crt
,
'w
b
'
)
as
crt_file
:
crt_file
.
write
(
new_key_pem
)
crt_file
.
write
(
new_key_pem
)
crt_file
.
write
(
new_crt_pem
)
crt_file
.
write
(
new_crt_pem
)
else
:
else
:
with
open
(
with
open
(
args
.
crt
,
args
.
crt
,
'w'
,
'w
b
'
,
)
as
crt_file
,
open
(
)
as
crt_file
,
open
(
key_path
,
key_path
,
'w'
,
'w
b
'
,
)
as
key_file
:
)
as
key_file
:
key_file
.
write
(
new_key_pem
)
key_file
.
write
(
new_key_pem
)
crt_file
.
write
(
new_crt_pem
)
crt_file
.
write
(
new_crt_pem
)
...
@@ -894,11 +902,11 @@ def rerequest(argv=None):
...
@@ -894,11 +902,11 @@ def rerequest(argv=None):
key_pem
=
utils
.
dump_privatekey
(
key
)
key_pem
=
utils
.
dump_privatekey
(
key
)
orig_umask
=
os
.
umask
(
0o177
)
orig_umask
=
os
.
umask
(
0o177
)
try
:
try
:
with
open
(
args
.
key
,
'w'
)
as
key_file
:
with
open
(
args
.
key
,
'w
b
'
)
as
key_file
:
key_file
.
write
(
key_pem
)
key_file
.
write
(
key_pem
)
finally
:
finally
:
os
.
umask
(
orig_umask
)
os
.
umask
(
orig_umask
)
with
open
(
args
.
csr
,
'w'
)
as
csr_file
:
with
open
(
args
.
csr
,
'w
b
'
)
as
csr_file
:
csr_file
.
write
(
csr_pem
)
csr_file
.
write
(
csr_pem
)
def
key_id
(
argv
=
None
):
def
key_id
(
argv
=
None
):
...
@@ -926,17 +934,20 @@ def key_id(argv=None):
...
@@ -926,17 +934,20 @@ def key_id(argv=None):
)
)
args
=
parser
.
parse_args
(
argv
)
args
=
parser
.
parse_args
(
argv
)
for
key_path
in
args
.
private_key
:
for
key_path
in
args
.
private_key
:
print
(
with
open
(
key_path
,
'rb'
)
as
key_file
:
key_path
,
print
(
x509
.
SubjectKeyIdentifier
.
from_public_key
(
key_path
,
utils
.
load_privatekey
(
open
(
key_path
).
read
()).
public_key
(),
utils
.
toUnicode
(
hexlify
(
).
digest
.
encode
(
'hex'
),
x509
.
SubjectKeyIdentifier
.
from_public_key
(
)
utils
.
load_privatekey
(
key_file
.
read
()).
public_key
(),
).
digest
,
)),
)
for
backup_path
in
args
.
backup
:
for
backup_path
in
args
.
backup
:
print
(
backup_path
)
print
(
backup_path
)
with
open
(
backup_path
)
as
backup_file
:
with
open
(
backup_path
,
'rb'
)
as
backup_file
:
magic
=
backup_file
.
read
(
8
)
magic
=
backup_file
.
read
(
8
)
if
magic
!=
'caucase
\
0
'
:
if
magic
!=
b
'caucase
\
0
'
:
raise
ValueError
(
'Invalid backup magic string'
)
raise
ValueError
(
'Invalid backup magic string'
)
header_len
,
=
struct
.
unpack
(
header_len
,
=
struct
.
unpack
(
'<I'
,
'<I'
,
...
...
caucase/client.py
View file @
8ce08bf9
...
@@ -69,7 +69,7 @@ class CaucaseClient(object):
...
@@ -69,7 +69,7 @@ class CaucaseClient(object):
"""
"""
if
not
os
.
path
.
exists
(
ca_crt_path
):
if
not
os
.
path
.
exists
(
ca_crt_path
):
ca_pem
=
cls
(
ca_url
=
url
).
getCACertificate
()
ca_pem
=
cls
(
ca_url
=
url
).
getCACertificate
()
with
open
(
ca_crt_path
,
'w'
)
as
ca_crt_file
:
with
open
(
ca_crt_path
,
'w
b
'
)
as
ca_crt_file
:
ca_crt_file
.
write
(
ca_pem
)
ca_crt_file
.
write
(
ca_pem
)
updated
=
True
updated
=
True
else
:
else
:
...
@@ -85,8 +85,8 @@ class CaucaseClient(object):
...
@@ -85,8 +85,8 @@ class CaucaseClient(object):
cls
(
ca_url
=
url
,
ca_crt_pem_list
=
ca_pem_list
).
getCACertificateChain
(),
cls
(
ca_url
=
url
,
ca_crt_pem_list
=
ca_pem_list
).
getCACertificateChain
(),
)
)
if
ca_pem_list
!=
loaded_ca_pem_list
:
if
ca_pem_list
!=
loaded_ca_pem_list
:
data
=
''
.
join
(
ca_pem_list
)
data
=
b
''
.
join
(
ca_pem_list
)
with
open
(
ca_crt_path
,
'w'
)
as
ca_crt_file
:
with
open
(
ca_crt_path
,
'w
b
'
)
as
ca_crt_file
:
ca_crt_file
.
write
(
data
)
ca_crt_file
.
write
(
data
)
updated
=
True
updated
=
True
return
updated
return
updated
...
@@ -107,13 +107,13 @@ class CaucaseClient(object):
...
@@ -107,13 +107,13 @@ class CaucaseClient(object):
Return whether an update happened.
Return whether an update happened.
"""
"""
if
os
.
path
.
exists
(
crl_path
):
if
os
.
path
.
exists
(
crl_path
):
my_crl
=
utils
.
load_crl
(
open
(
crl_path
).
read
(),
ca_list
)
my_crl
=
utils
.
load_crl
(
open
(
crl_path
,
'rb'
).
read
(),
ca_list
)
else
:
else
:
my_crl
=
None
my_crl
=
None
latest_crl_pem
=
cls
(
ca_url
=
url
).
getCertificateRevocationList
()
latest_crl_pem
=
cls
(
ca_url
=
url
).
getCertificateRevocationList
()
latest_crl
=
utils
.
load_crl
(
latest_crl_pem
,
ca_list
)
latest_crl
=
utils
.
load_crl
(
latest_crl_pem
,
ca_list
)
if
my_crl
is
None
or
latest_crl
.
signature
!=
my_crl
.
signature
:
if
my_crl
is
None
or
latest_crl
.
signature
!=
my_crl
.
signature
:
with
open
(
crl_path
,
'w'
)
as
crl_file
:
with
open
(
crl_path
,
'w
b
'
)
as
crl_file
:
crl_file
.
write
(
latest_crl_pem
)
crl_file
.
write
(
latest_crl_pem
)
return
True
return
True
return
False
return
False
...
@@ -138,7 +138,11 @@ class CaucaseClient(object):
...
@@ -138,7 +138,11 @@ class CaucaseClient(object):
ssl_context
=
ssl
.
create_default_context
(
ssl_context
=
ssl
.
create_default_context
(
# unicode object needed as we use PEM, otherwise create_default_context
# unicode object needed as we use PEM, otherwise create_default_context
# expects DER.
# expects DER.
cadata
=
''
.
join
(
http_ca_crt_pem_list
).
decode
(
'ascii'
)
if
http_ca_crt_pem_list
else
None
,
cadata
=
(
utils
.
toUnicode
(
''
.
join
(
http_ca_crt_pem_list
))
if
http_ca_crt_pem_list
else
None
),
)
)
if
not
http_ca_crt_pem_list
:
if
not
http_ca_crt_pem_list
:
ssl_context
.
check_hostname
=
False
ssl_context
.
check_hostname
=
False
...
@@ -191,13 +195,7 @@ class CaucaseClient(object):
...
@@ -191,13 +195,7 @@ class CaucaseClient(object):
"""
"""
[AUTHENTICATED] Retrieve all pending CSRs.
[AUTHENTICATED] Retrieve all pending CSRs.
"""
"""
return
[
return
json
.
loads
(
self
.
_https
(
'GET'
,
'/csr'
))
{
y
.
encode
(
'ascii'
):
z
.
encode
(
'ascii'
)
if
isinstance
(
z
,
unicode
)
else
z
for
y
,
z
in
x
.
iteritems
()
}
for
x
in
json
.
loads
(
self
.
_https
(
'GET'
,
'/csr'
))
]
def
createCertificateSigningRequest
(
self
,
csr
):
def
createCertificateSigningRequest
(
self
,
csr
):
"""
"""
...
@@ -254,14 +252,14 @@ class CaucaseClient(object):
...
@@ -254,14 +252,14 @@ class CaucaseClient(object):
continue
continue
if
not
found
:
if
not
found
:
found
=
utils
.
load_ca_certificate
(
found
=
utils
.
load_ca_certificate
(
payload
[
'old_pem'
].
encode
(
'ascii'
),
utils
.
toBytes
(
payload
[
'old_pem'
]
),
)
==
trust_anchor
)
==
trust_anchor
if
found
:
if
found
:
if
utils
.
load_ca_certificate
(
if
utils
.
load_ca_certificate
(
payload
[
'old_pem'
].
encode
(
'ascii'
),
utils
.
toBytes
(
payload
[
'old_pem'
]
),
)
!=
previous_ca
:
)
!=
previous_ca
:
raise
ValueError
(
'CA signature chain broken'
)
raise
ValueError
(
'CA signature chain broken'
)
new_pem
=
payload
[
'new_pem'
].
encode
(
'ascii'
)
new_pem
=
utils
.
toBytes
(
payload
[
'new_pem'
]
)
result
.
append
(
new_pem
)
result
.
append
(
new_pem
)
previous_ca
=
utils
.
load_ca_certificate
(
new_pem
)
previous_ca
=
utils
.
load_ca_certificate
(
new_pem
)
return
result
return
result
...
@@ -279,8 +277,8 @@ class CaucaseClient(object):
...
@@ -279,8 +277,8 @@ class CaucaseClient(object):
json
.
dumps
(
json
.
dumps
(
utils
.
wrap
(
utils
.
wrap
(
{
{
'crt_pem'
:
utils
.
dump_certificate
(
old_crt
),
'crt_pem'
:
utils
.
toUnicode
(
utils
.
dump_certificate
(
old_crt
)
),
'renew_csr_pem'
:
utils
.
dump_certificate_request
(
'renew_csr_pem'
:
utils
.
toUnicode
(
utils
.
dump_certificate_request
(
x509
.
CertificateSigningRequestBuilder
(
x509
.
CertificateSigningRequestBuilder
(
).
subject_name
(
).
subject_name
(
# Note: caucase server ignores this, but cryptography
# Note: caucase server ignores this, but cryptography
...
@@ -291,7 +289,7 @@ class CaucaseClient(object):
...
@@ -291,7 +289,7 @@ class CaucaseClient(object):
algorithm
=
utils
.
DEFAULT_DIGEST_CLASS
(),
algorithm
=
utils
.
DEFAULT_DIGEST_CLASS
(),
backend
=
_cryptography_backend
,
backend
=
_cryptography_backend
,
),
),
),
)
)
,
},
},
old_key
,
old_key
,
utils
.
DEFAULT_DIGEST
,
utils
.
DEFAULT_DIGEST
,
...
@@ -307,6 +305,7 @@ class CaucaseClient(object):
...
@@ -307,6 +305,7 @@ class CaucaseClient(object):
[ANONYMOUS] if key is provided.
[ANONYMOUS] if key is provided.
[AUTHENTICATED] if key is missing.
[AUTHENTICATED] if key is missing.
"""
"""
crt
=
utils
.
toUnicode
(
crt
)
if
key
:
if
key
:
method
=
self
.
_http
method
=
self
.
_http
data
=
utils
.
wrap
(
data
=
utils
.
wrap
(
...
...
caucase/http.py
View file @
8ce08bf9
...
@@ -70,7 +70,7 @@ def _createKey(path):
...
@@ -70,7 +70,7 @@ def _createKey(path):
"""
"""
return
os
.
fdopen
(
return
os
.
fdopen
(
os
.
open
(
path
,
os
.
O_WRONLY
|
os
.
O_CREAT
,
0o600
),
os
.
open
(
path
,
os
.
O_WRONLY
|
os
.
O_CREAT
,
0o600
),
'w'
,
'w
b
'
,
)
)
class
ThreadingWSGIServer
(
ThreadingMixIn
,
WSGIServer
):
class
ThreadingWSGIServer
(
ThreadingMixIn
,
WSGIServer
):
...
@@ -236,7 +236,7 @@ def getSSLContext(
...
@@ -236,7 +236,7 @@ def getSSLContext(
# implementation cross-check would have been nice.
# implementation cross-check would have been nice.
#ssl_context.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF
#ssl_context.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF
ssl_context
.
load_verify_locations
(
ssl_context
.
load_verify_locations
(
cadata
=
cau
.
getCACertificate
().
decode
(
'ascii'
),
cadata
=
utils
.
toUnicode
(
cau
.
getCACertificate
()
),
)
)
http_cas_certificate_list
=
http_cas
.
getCACertificateList
()
http_cas_certificate_list
=
http_cas
.
getCACertificateList
()
threshold_delta
=
datetime
.
timedelta
(
threshold
,
0
)
threshold_delta
=
datetime
.
timedelta
(
threshold
,
0
)
...
@@ -500,12 +500,12 @@ def main(argv=None, until=utils.until):
...
@@ -500,12 +500,12 @@ def main(argv=None, until=utils.until):
)
)
args
=
parser
.
parse_args
(
argv
)
args
=
parser
.
parse_args
(
argv
)
base_url
=
u'http://'
+
args
.
netloc
.
decode
(
'ascii'
)
base_url
=
u'http://'
+
utils
.
toUnicode
(
args
.
netloc
)
parsed_base_url
=
urlparse
(
base_url
)
parsed_base_url
=
urlparse
(
base_url
)
hostname
=
parsed_base_url
.
hostname
hostname
=
parsed_base_url
.
hostname
name_constraints_permited
=
[]
name_constraints_permited
=
[]
name_constraints_excluded
=
[]
name_constraints_excluded
=
[]
hostname_dnsname
=
hostname
.
decode
(
'ascii'
)
hostname_dnsname
=
utils
.
toUnicode
(
hostname
)
try
:
try
:
hostname_ip_address
=
ipaddress
.
ip_address
(
hostname_dnsname
)
hostname_ip_address
=
ipaddress
.
ip_address
(
hostname_dnsname
)
except
ValueError
:
except
ValueError
:
...
@@ -615,7 +615,7 @@ def main(argv=None, until=utils.until):
...
@@ -615,7 +615,7 @@ def main(argv=None, until=utils.until):
crt_life_time
=
args
.
service_crt_validity
,
crt_life_time
=
args
.
service_crt_validity
,
)
)
if
os
.
path
.
exists
(
args
.
cors_key_store
):
if
os
.
path
.
exists
(
args
.
cors_key_store
):
with
open
(
args
.
cors_key_store
)
as
cors_key_file
:
with
open
(
args
.
cors_key_store
,
'rb'
)
as
cors_key_file
:
cors_secret_list
=
json
.
load
(
cors_key_file
)
cors_secret_list
=
json
.
load
(
cors_key_file
)
else
:
else
:
cors_secret_list
=
[]
cors_secret_list
=
[]
...
@@ -761,7 +761,7 @@ def main(argv=None, until=utils.until):
...
@@ -761,7 +761,7 @@ def main(argv=None, until=utils.until):
tmp_backup_fd
,
tmp_backup_path
=
tempfile
.
mkstemp
(
tmp_backup_fd
,
tmp_backup_path
=
tempfile
.
mkstemp
(
prefix
=
'caucase_backup_'
,
prefix
=
'caucase_backup_'
,
)
)
with
os
.
fdopen
(
tmp_backup_fd
,
'w'
)
as
backup_file
:
with
os
.
fdopen
(
tmp_backup_fd
,
'w
b
'
)
as
backup_file
:
result
=
cau
.
doBackup
(
backup_file
.
write
)
result
=
cau
.
doBackup
(
backup_file
.
write
)
if
result
:
if
result
:
backup_path
=
os
.
path
.
join
(
backup_path
=
os
.
path
.
join
(
...
@@ -782,6 +782,7 @@ def main(argv=None, until=utils.until):
...
@@ -782,6 +782,7 @@ def main(argv=None, until=utils.until):
finally
:
finally
:
sys
.
stderr
.
write
(
'Exiting
\
n
'
)
sys
.
stderr
.
write
(
'Exiting
\
n
'
)
for
server
in
itertools
.
chain
(
http_list
,
https_list
):
for
server
in
itertools
.
chain
(
http_list
,
https_list
):
server
.
server_close
()
server
.
shutdown
()
server
.
shutdown
()
def
manage
(
argv
=
None
):
def
manage
(
argv
=
None
):
...
@@ -820,7 +821,7 @@ def manage(argv=None):
...
@@ -820,7 +821,7 @@ def manage(argv=None):
default
=
[],
default
=
[],
metavar
=
'PEM_FILE'
,
metavar
=
'PEM_FILE'
,
action
=
'append'
,
action
=
'append'
,
type
=
argparse
.
FileType
(
'r'
),
type
=
argparse
.
FileType
(
'r
b
'
),
help
=
'Import key pairs as initial service CA certificate. '
help
=
'Import key pairs as initial service CA certificate. '
'May be provided multiple times to import multiple key pairs. '
'May be provided multiple times to import multiple key pairs. '
'Keys and certificates may be in separate files. '
'Keys and certificates may be in separate files. '
...
@@ -846,7 +847,7 @@ def manage(argv=None):
...
@@ -846,7 +847,7 @@ def manage(argv=None):
default
=
[],
default
=
[],
metavar
=
'PEM_FILE'
,
metavar
=
'PEM_FILE'
,
action
=
'append'
,
action
=
'append'
,
type
=
argparse
.
FileType
(
'r'
),
type
=
argparse
.
FileType
(
'r
b
'
),
help
=
'Import service revocation list. Corresponding CA certificate must '
help
=
'Import service revocation list. Corresponding CA certificate must '
'be already present in the database (including added in the same run '
'be already present in the database (including added in the same run '
'using --import-ca).'
,
'using --import-ca).'
,
...
@@ -854,7 +855,7 @@ def manage(argv=None):
...
@@ -854,7 +855,7 @@ def manage(argv=None):
parser
.
add_argument
(
parser
.
add_argument
(
'--export-ca'
,
'--export-ca'
,
metavar
=
'PEM_FILE'
,
metavar
=
'PEM_FILE'
,
type
=
argparse
.
FileType
(
'w'
),
type
=
argparse
.
FileType
(
'w
b
'
),
help
=
'Export all CA certificates in a PEM file. Passphrase will be '
help
=
'Export all CA certificates in a PEM file. Passphrase will be '
'prompted to protect all keys.'
,
'prompted to protect all keys.'
,
)
)
...
@@ -873,30 +874,35 @@ def manage(argv=None):
...
@@ -873,30 +874,35 @@ def manage(argv=None):
# maybe user extracted their private key ?
# maybe user extracted their private key ?
key_pem
=
utils
.
getKey
(
backup_key_path
)
key_pem
=
utils
.
getKey
(
backup_key_path
)
cau_crt_life_time
=
args
.
user_crt_validity
cau_crt_life_time
=
args
.
user_crt_validity
with
open
(
backup_path
)
as
backup_file
:
with
open
(
with
open
(
backup_crt_path
,
'a'
)
as
new_crt_file
:
backup_path
,
new_crt_file
.
write
(
'rb'
,
UserCertificateAuthority
.
restoreBackup
(
)
as
backup_file
,
open
(
db_class
=
SQLite3Storage
,
backup_crt_path
,
db_path
=
db_path
,
'ab'
,
read
=
backup_file
.
read
,
)
as
new_crt_file
:
key_pem
=
key_pem
,
new_crt_file
.
write
(
csr_pem
=
utils
.
getCertRequest
(
backup_csr_path
),
UserCertificateAuthority
.
restoreBackup
(
db_kw
=
{
db_class
=
SQLite3Storage
,
'table_prefix'
:
'cau'
,
db_path
=
db_path
,
# max_csr_amount: not needed, renewal ignores quota
read
=
backup_file
.
read
,
# Effectively disables certificate expiration
key_pem
=
key_pem
,
'crt_keep_time'
:
cau_crt_life_time
,
csr_pem
=
utils
.
getCertRequest
(
backup_csr_path
),
'crt_read_keep_time'
:
cau_crt_life_time
,
db_kw
=
{
'enforce_unique_key_id'
:
True
,
'table_prefix'
:
'cau'
,
},
# max_csr_amount: not needed, renewal ignores quota
kw
=
{
# Effectively disables certificate expiration
# Disable CA cert renewal
'crt_keep_time'
:
cau_crt_life_time
,
'ca_key_size'
:
None
,
'crt_read_keep_time'
:
cau_crt_life_time
,
'crt_life_time'
:
cau_crt_life_time
,
'enforce_unique_key_id'
:
True
,
},
},
),
kw
=
{
)
# Disable CA cert renewal
'ca_key_size'
:
None
,
'crt_life_time'
:
cau_crt_life_time
,
},
),
)
if
args
.
import_ca
:
if
args
.
import_ca
:
import_ca_dict
=
defaultdict
(
import_ca_dict
=
defaultdict
(
(
lambda
:
{
'crt'
:
None
,
'key'
:
None
,
'from'
:
[]}),
(
lambda
:
{
'crt'
:
None
,
'key'
:
None
,
'from'
:
[]}),
...
...
caucase/http_wsgibase.py
View file @
8ce08bf9
...
@@ -22,6 +22,7 @@ Separate from .http because of different-licensed code in the middle.
...
@@ -22,6 +22,7 @@ Separate from .http because of different-licensed code in the middle.
"""
"""
from
__future__
import
absolute_import
from
__future__
import
absolute_import
from
wsgiref.simple_server
import
ServerHandler
from
wsgiref.simple_server
import
ServerHandler
from
.utils
import
toBytes
class
ProxyFile
(
object
):
class
ProxyFile
(
object
):
"""
"""
...
@@ -48,7 +49,7 @@ class ChunkedFile(ProxyFile):
...
@@ -48,7 +49,7 @@ class ChunkedFile(ProxyFile):
"""
"""
Read chunked data.
Read chunked data.
"""
"""
result
=
''
result
=
b
''
if
not
self
.
_at_eof
:
if
not
self
.
_at_eof
:
readline
=
self
.
readline
readline
=
self
.
readline
read
=
self
.
__getattr__
(
'read'
)
read
=
self
.
__getattr__
(
'read'
)
...
@@ -61,7 +62,7 @@ class ChunkedFile(ProxyFile):
...
@@ -61,7 +62,7 @@ class ChunkedFile(ProxyFile):
if
len
(
chunk_header
)
>
MAX_CHUNKED_HEADER_LENGTH
:
if
len
(
chunk_header
)
>
MAX_CHUNKED_HEADER_LENGTH
:
raise
ValueError
(
'Chunked encoding header too long'
)
raise
ValueError
(
'Chunked encoding header too long'
)
try
:
try
:
chunk_length
=
int
(
chunk_header
.
split
(
';'
,
1
)[
0
],
16
)
chunk_length
=
int
(
chunk_header
.
split
(
b
';'
,
1
)[
0
],
16
)
except
ValueError
:
except
ValueError
:
raise
ValueError
(
'Invalid chunked encoding header'
)
raise
ValueError
(
'Invalid chunked encoding header'
)
if
not
chunk_length
:
if
not
chunk_length
:
...
@@ -78,7 +79,7 @@ class ChunkedFile(ProxyFile):
...
@@ -78,7 +79,7 @@ class ChunkedFile(ProxyFile):
if
to_read
!=
chunk_length
:
if
to_read
!=
chunk_length
:
self
.
_chunk_remaining_length
=
chunk_length
-
to_read
self
.
_chunk_remaining_length
=
chunk_length
-
to_read
break
break
if
read
(
2
)
!=
'
\
r
\
n
'
:
if
read
(
2
)
!=
b
'
\
r
\
n
'
:
raise
ValueError
(
'Invalid chunked encoding separator'
)
raise
ValueError
(
'Invalid chunked encoding separator'
)
return
result
return
result
...
@@ -131,7 +132,7 @@ class CleanServerHandler(ServerHandler):
...
@@ -131,7 +132,7 @@ class CleanServerHandler(ServerHandler):
"""
"""
Emit "100 Continue" intermediate response.
Emit "100 Continue" intermediate response.
"""
"""
self
.
_write
(
'HTTP/%s 100 Continue
\
r
\
n
\
r
\
n
'
%
(
self
.
_write
(
b
'HTTP/%s 100 Continue
\
r
\
n
\
r
\
n
'
%
(
self
.
http_version
,
toBytes
(
self
.
http_version
)
,
))
))
self
.
_flush
()
self
.
_flush
()
caucase/storage.py
View file @
8ce08bf9
...
@@ -25,6 +25,7 @@ import sqlite3
...
@@ -25,6 +25,7 @@ import sqlite3
from
threading
import
local
from
threading
import
local
from
time
import
time
from
time
import
time
from
.exceptions
import
NoStorage
,
NotFound
,
Found
from
.exceptions
import
NoStorage
,
NotFound
,
Found
from
.utils
import
toBytes
,
toUnicode
__all__
=
(
'SQLite3Storage'
,
)
__all__
=
(
'SQLite3Storage'
,
)
...
@@ -207,8 +208,8 @@ class SQLite3Storage(local):
...
@@ -207,8 +208,8 @@ class SQLite3Storage(local):
)
)
return
[
return
[
{
{
'crt_pem'
:
x
[
'crt'
].
encode
(
'ascii'
),
'crt_pem'
:
toBytes
(
x
[
'crt'
]
),
'key_pem'
:
x
[
'key'
].
encode
(
'ascii'
),
'key_pem'
:
toBytes
(
x
[
'key'
]
),
}
}
for
x
in
db
.
cursor
().
execute
(
for
x
in
db
.
cursor
().
execute
(
'SELECT key, crt FROM %sca ORDER BY expiration_date ASC'
%
(
'SELECT key, crt FROM %sca ORDER BY expiration_date ASC'
%
(
...
@@ -326,7 +327,7 @@ class SQLite3Storage(local):
...
@@ -326,7 +327,7 @@ class SQLite3Storage(local):
)
)
if
result
is
None
:
if
result
is
None
:
raise
NotFound
raise
NotFound
return
result
[
'csr'
].
encode
(
'ascii'
)
return
toBytes
(
result
[
'csr'
]
)
def
getCertificateSigningRequestList
(
self
):
def
getCertificateSigningRequestList
(
self
):
"""
"""
...
@@ -338,7 +339,11 @@ class SQLite3Storage(local):
...
@@ -338,7 +339,11 @@ class SQLite3Storage(local):
return
[
return
[
{
{
'id'
:
str
(
x
[
'id'
]),
'id'
:
str
(
x
[
'id'
]),
'csr'
:
x
[
'csr'
].
encode
(
'ascii'
),
# XXX: because only call chain will end up serialising this value in
# json, and for some reason python3 json module refuses bytes.
# So rather than byte-ify (consistently with all PEM-encoded values)
# to then have to unicode-ify, just unicode-ify here.
'csr'
:
toUnicode
(
x
[
'csr'
]),
}
}
for
x
in
db
.
cursor
().
execute
(
for
x
in
db
.
cursor
().
execute
(
'SELECT id, csr FROM %scrt WHERE crt IS NULL'
%
(
'SELECT id, csr FROM %scrt WHERE crt IS NULL'
%
(
...
@@ -401,7 +406,7 @@ class SQLite3Storage(local):
...
@@ -401,7 +406,7 @@ class SQLite3Storage(local):
crt_id
,
crt_id
,
)
)
)
)
return
row
[
'crt'
].
encode
(
'ascii'
)
return
toBytes
(
row
[
'crt'
]
)
def
getCertificateByKeyIdentifier
(
self
,
key_id
):
def
getCertificateByKeyIdentifier
(
self
,
key_id
):
"""
"""
...
@@ -419,7 +424,7 @@ class SQLite3Storage(local):
...
@@ -419,7 +424,7 @@ class SQLite3Storage(local):
)
)
if
row
is
None
:
if
row
is
None
:
raise
NotFound
raise
NotFound
return
row
[
'crt'
].
encode
(
'ascii'
)
return
toBytes
(
row
[
'crt'
]
)
def
iterCertificates
(
self
):
def
iterCertificates
(
self
):
"""
"""
...
@@ -434,7 +439,7 @@ class SQLite3Storage(local):
...
@@ -434,7 +439,7 @@ class SQLite3Storage(local):
row
=
c
.
fetchone
()
row
=
c
.
fetchone
()
if
row
is
None
:
if
row
is
None
:
break
break
yield
row
[
'crt'
].
encode
(
'ascii'
)
yield
toBytes
(
row
[
'crt'
]
)
def
revoke
(
self
,
serial
,
expiration_date
):
def
revoke
(
self
,
serial
,
expiration_date
):
"""
"""
...
@@ -483,7 +488,7 @@ class SQLite3Storage(local):
...
@@ -483,7 +488,7 @@ class SQLite3Storage(local):
(
time
(),
)
(
time
(),
)
)
)
if
row
is
not
None
:
if
row
is
not
None
:
return
row
[
'crl'
].
encode
(
'ascii'
)
return
toBytes
(
row
[
'crl'
]
)
return
None
return
None
def
getNextCertificateRevocationListNumber
(
self
):
def
getNextCertificateRevocationListNumber
(
self
):
...
@@ -547,7 +552,7 @@ class SQLite3Storage(local):
...
@@ -547,7 +552,7 @@ class SQLite3Storage(local):
class (so not limited to table_prefix).
class (so not limited to table_prefix).
"""
"""
for
statement
in
self
.
_db
.
iterdump
():
for
statement
in
self
.
_db
.
iterdump
():
yield
statement
.
encode
(
'utf-8'
)
+
'
\
0
'
yield
toBytes
(
statement
,
'utf-8'
)
+
b
'
\
0
'
@
staticmethod
@
staticmethod
def
restore
(
db_path
,
restorator
):
def
restore
(
db_path
,
restorator
):
...
@@ -563,14 +568,14 @@ class SQLite3Storage(local):
...
@@ -563,14 +568,14 @@ class SQLite3Storage(local):
Produces chunks which correspond (in content, not necessarily in size)
Produces chunks which correspond (in content, not necessarily in size)
to what dumpIterator produces.
to what dumpIterator produces.
"""
"""
buf
=
''
buf
=
b
''
if
os
.
path
.
exists
(
db_path
):
if
os
.
path
.
exists
(
db_path
):
raise
ValueError
(
'%r exists, not restoring.'
%
(
db_path
,
))
raise
ValueError
(
'%r exists, not restoring.'
%
(
db_path
,
))
c
=
sqlite3
.
connect
(
db_path
,
isolation_level
=
None
).
cursor
()
c
=
sqlite3
.
connect
(
db_path
,
isolation_level
=
None
).
cursor
()
for
chunk
in
restorator
:
for
chunk
in
restorator
:
statement_list
=
(
buf
+
chunk
).
split
(
'
\
0
'
)
statement_list
=
(
buf
+
chunk
).
split
(
b
'
\
0
'
)
buf
=
statement_list
.
pop
()
buf
=
statement_list
.
pop
()
for
statement
in
statement_list
:
for
statement
in
statement_list
:
c
.
execute
(
(
statement
).
decode
(
'utf-8'
))
c
.
execute
(
toUnicode
(
statement
,
'utf-8'
))
if
buf
:
if
buf
:
raise
ValueError
(
'Short read, backup truncated ?'
)
raise
ValueError
(
'Short read, backup truncated ?'
)
caucase/test.py
View file @
8ce08bf9
...
@@ -22,12 +22,12 @@ Test suite
...
@@ -22,12 +22,12 @@ Test suite
"""
"""
from
__future__
import
absolute_import
from
__future__
import
absolute_import
from
Cookie
import
SimpleCookie
from
Cookie
import
SimpleCookie
from
cStringIO
import
StringIO
import
datetime
import
datetime
import
errno
import
errno
import
glob
import
glob
import
HTMLParser
import
HTMLParser
import
httplib
import
httplib
from
io
import
BytesIO
,
StringIO
import
ipaddress
import
ipaddress
import
json
import
json
import
os
import
os
...
@@ -48,7 +48,9 @@ from cryptography import x509
...
@@ -48,7 +48,9 @@ from cryptography import x509
from
cryptography.hazmat.backends
import
default_backend
from
cryptography.hazmat.backends
import
default_backend
from
caucase
import
cli
from
caucase
import
cli
from
caucase.client
import
CaucaseError
,
CaucaseClient
from
caucase.client
import
CaucaseError
,
CaucaseClient
from
caucase
import
http
# Do not import caucase.http into this namespace: 2to3 will import standard
# http module, which will then be masqued by caucase's http submodule.
import
caucase.http
from
caucase
import
utils
from
caucase
import
utils
from
caucase
import
exceptions
from
caucase
import
exceptions
from
caucase
import
wsgi
from
caucase
import
wsgi
...
@@ -106,11 +108,13 @@ def canConnect(address): # pragma: no cover
...
@@ -106,11 +108,13 @@ def canConnect(address): # pragma: no cover
otherwise.
otherwise.
"""
"""
try
:
try
:
socket
.
create_connection
(
address
)
sock
=
sock
et
.
create_connection
(
address
)
except
socket
.
error
as
e
:
except
socket
.
error
as
e
:
if
e
.
errno
==
errno
.
ECONNREFUSED
:
if
e
.
errno
==
errno
.
ECONNREFUSED
:
return
False
return
False
raise
raise
else
:
sock
.
close
()
return
True
return
True
def
retry
(
callback
,
try_count
=
200
,
try_delay
=
0.1
):
# pragma: no cover
def
retry
(
callback
,
try_count
=
200
,
try_delay
=
0.1
):
# pragma: no cover
...
@@ -129,7 +133,7 @@ def retry(callback, try_count=200, try_delay=0.1): # pragma: no cover
...
@@ -129,7 +133,7 @@ def retry(callback, try_count=200, try_delay=0.1): # pragma: no cover
class
FakeStreamRequest
(
object
):
class
FakeStreamRequest
(
object
):
"""
"""
For testing StreamRequestHandler subclasses
For testing StreamRequestHandler subclasses
(like http.CaucaseWSGIRequestHandler).
(like
caucase.
http.CaucaseWSGIRequestHandler).
"""
"""
def
__init__
(
self
,
rfile
,
wfile
):
def
__init__
(
self
,
rfile
,
wfile
):
"""
"""
...
@@ -144,6 +148,9 @@ class FakeStreamRequest(object):
...
@@ -144,6 +148,9 @@ class FakeStreamRequest(object):
"""
"""
return
self
.
_rfile
if
'r'
in
mode
else
self
.
_wfile
return
self
.
_rfile
if
'r'
in
mode
else
self
.
_wfile
def
sendall
(
self
,
data
,
flags
=
None
):
# pragma: no cover
self
.
_wfile
.
write
(
data
)
class
NoCloseFileProxy
(
object
):
class
NoCloseFileProxy
(
object
):
"""
"""
Intercept .close() calls, for example to allow reading StringIO content
Intercept .close() calls, for example to allow reading StringIO content
...
@@ -324,7 +331,7 @@ class CaucaseTest(unittest.TestCase):
...
@@ -324,7 +331,7 @@ class CaucaseTest(unittest.TestCase):
Returns its exit status.
Returns its exit status.
"""
"""
try
:
try
:
http
.
manage
(
caucase
.
http
.
manage
(
argv
=
(
argv
=
(
'--db'
,
self
.
_server_db
,
'--db'
,
self
.
_server_db
,
'--restore-backup'
,
'--restore-backup'
,
...
@@ -346,7 +353,7 @@ class CaucaseTest(unittest.TestCase):
...
@@ -346,7 +353,7 @@ class CaucaseTest(unittest.TestCase):
"""
"""
self
.
_server_until
=
until
=
UntilEvent
(
self
.
_server_event
)
self
.
_server_until
=
until
=
UntilEvent
(
self
.
_server_event
)
self
.
_server
=
server
=
threading
.
Thread
(
self
.
_server
=
server
=
threading
.
Thread
(
target
=
http
.
main
,
target
=
caucase
.
http
.
main
,
kwargs
=
{
kwargs
=
{
'argv'
:
(
'argv'
:
(
'--db'
,
self
.
_server_db
,
'--db'
,
self
.
_server_db
,
...
@@ -453,10 +460,10 @@ class CaucaseTest(unittest.TestCase):
...
@@ -453,10 +460,10 @@ class CaucaseTest(unittest.TestCase):
row
=
c
.
fetchone
()
row
=
c
.
fetchone
()
if
row
is
None
:
# pragma: no cover
if
row
is
None
:
# pragma: no cover
raise
Exception
(
'CA with serial %r not found'
%
(
serial
,
))
raise
Exception
(
'CA with serial %r not found'
%
(
serial
,
))
crt
=
utils
.
load_ca_certificate
(
row
[
'crt'
].
encode
(
'ascii'
))
crt
=
utils
.
load_ca_certificate
(
utils
.
toBytes
(
row
[
'crt'
]
))
if
crt
.
serial_number
==
serial
:
if
crt
.
serial_number
==
serial
:
new_crt
=
self
.
_setCertificateRemainingLifeTime
(
new_crt
=
self
.
_setCertificateRemainingLifeTime
(
key
=
utils
.
load_privatekey
(
row
[
'key'
].
encode
(
'ascii'
)),
key
=
utils
.
load_privatekey
(
utils
.
toBytes
(
row
[
'key'
]
)),
crt
=
crt
,
crt
=
crt
,
delta
=
delta
,
delta
=
delta
,
)
)
...
@@ -489,7 +496,7 @@ class CaucaseTest(unittest.TestCase):
...
@@ -489,7 +496,7 @@ class CaucaseTest(unittest.TestCase):
"""
"""
name
=
basename
+
'.key.pem'
name
=
basename
+
'.key.pem'
assert
not
os
.
path
.
exists
(
name
)
assert
not
os
.
path
.
exists
(
name
)
with
open
(
name
,
'w'
)
as
key_file
:
with
open
(
name
,
'w
b
'
)
as
key_file
:
key_file
.
write
(
utils
.
dump_privatekey
(
key_file
.
write
(
utils
.
dump_privatekey
(
utils
.
generatePrivateKey
(
key_len
=
key_len
),
utils
.
generatePrivateKey
(
key_len
=
key_len
),
))
))
...
@@ -516,7 +523,7 @@ class CaucaseTest(unittest.TestCase):
...
@@ -516,7 +523,7 @@ class CaucaseTest(unittest.TestCase):
"""
"""
name
=
basename
+
'.csr.pem'
name
=
basename
+
'.csr.pem'
assert
not
os
.
path
.
exists
(
name
)
assert
not
os
.
path
.
exists
(
name
)
with
open
(
name
,
'w'
)
as
csr_file
:
with
open
(
name
,
'w
b
'
)
as
csr_file
:
csr_file
.
write
(
csr_file
.
write
(
utils
.
dump_certificate_request
(
utils
.
dump_certificate_request
(
csr_builder
.
sign
(
csr_builder
.
sign
(
...
@@ -604,7 +611,8 @@ class CaucaseTest(unittest.TestCase):
...
@@ -604,7 +611,8 @@ class CaucaseTest(unittest.TestCase):
'--mode'
,
mode
,
'--mode'
,
mode
,
'--get-csr'
,
csr_id
,
csr2_path
,
'--get-csr'
,
csr_id
,
csr2_path
,
)
)
self
.
assertEqual
(
open
(
csr_path
).
read
(),
open
(
csr2_path
).
read
())
with
open
(
csr_path
,
'rb'
)
as
csr_file
,
open
(
csr2_path
,
'rb'
)
as
csr2_file
:
self
.
assertEqual
(
csr_file
.
read
(),
csr2_file
.
read
())
# Sign using user cert
# Sign using user cert
# Note: assuming user does not know the csr_id and keeps their own copy of
# Note: assuming user does not know the csr_id and keeps their own copy of
# issued certificates.
# issued certificates.
...
@@ -1143,12 +1151,14 @@ class CaucaseTest(unittest.TestCase):
...
@@ -1143,12 +1151,14 @@ class CaucaseTest(unittest.TestCase):
# Check renewed CRT filtering does not alter clean signed certificate
# Check renewed CRT filtering does not alter clean signed certificate
# content (especially, caucase auto-signed flag must not appear).
# content (especially, caucase auto-signed flag must not appear).
before_key
=
open
(
key_path
).
read
()
with
open
(
key_path
,
'rb'
)
as
key_file
:
before_key
=
key_file
.
read
()
self
.
_runClient
(
self
.
_runClient
(
'--threshold'
,
'100'
,
'--threshold'
,
'100'
,
'--renew-crt'
,
key_path
,
''
,
'--renew-crt'
,
key_path
,
''
,
)
)
after_key
=
open
(
key_path
).
read
()
with
open
(
key_path
,
'rb'
)
as
key_file
:
after_key
=
key_file
.
read
()
assert
before_key
!=
after_key
assert
before_key
!=
after_key
checkCRT
(
key_path
)
checkCRT
(
key_path
)
...
@@ -1215,7 +1225,7 @@ class CaucaseTest(unittest.TestCase):
...
@@ -1215,7 +1225,7 @@ class CaucaseTest(unittest.TestCase):
)
)
# As we will use this crt as trust anchor, we must make the client believe
# As we will use this crt as trust anchor, we must make the client believe
# it knew it all along.
# it knew it all along.
with
open
(
self
.
_client_user_ca_crt
,
'w'
)
as
client_user_ca_crt_file
:
with
open
(
self
.
_client_user_ca_crt
,
'w
b
'
)
as
client_user_ca_crt_file
:
client_user_ca_crt_file
.
write
(
new_cau_crt_pem
)
client_user_ca_crt_file
.
write
(
new_cau_crt_pem
)
self
.
_startServer
()
self
.
_startServer
()
new_user_key
=
self
.
_createAndApproveCertificate
(
new_user_key
=
self
.
_createAndApproveCertificate
(
...
@@ -1302,11 +1312,11 @@ class CaucaseTest(unittest.TestCase):
...
@@ -1302,11 +1312,11 @@ class CaucaseTest(unittest.TestCase):
self
.
_server_key
,
self
.
_server_key
,
crl
=
None
,
crl
=
None
,
)
)
with
open
(
self
.
_server_key
,
'w'
)
as
server_key_file
:
with
open
(
self
.
_server_key
,
'w
b
'
)
as
server_key_file
:
server_key_file
.
write
(
key_pem
)
server_key_file
.
write
(
key_pem
)
server_key_file
.
write
(
utils
.
dump_certificate
(
server_key_file
.
write
(
utils
.
dump_certificate
(
self
.
_setCertificateRemainingLifeTime
(
self
.
_setCertificateRemainingLifeTime
(
key
=
utils
.
load_privatekey
(
http_cas_key
.
encode
(
'ascii'
)),
key
=
utils
.
load_privatekey
(
utils
.
toBytes
(
http_cas_key
)),
crt
=
utils
.
load_certificate
(
crt
=
utils
.
load_certificate
(
crt_pem
,
crt_pem
,
[
[
...
@@ -1318,10 +1328,13 @@ class CaucaseTest(unittest.TestCase):
...
@@ -1318,10 +1328,13 @@ class CaucaseTest(unittest.TestCase):
)
)
))
))
server_key_file
.
write
(
ca_crt_pem
)
server_key_file
.
write
(
ca_crt_pem
)
reference_server_key
=
open
(
self
.
_server_key
).
read
()
def
readServerKey
():
with
open
(
self
.
_server_key
,
'rb'
)
as
server_key_file
:
return
server_key_file
.
read
()
reference_server_key
=
readServerKey
()
self
.
_startServer
()
self
.
_startServer
()
if
not
retry
(
if
not
retry
(
lambda
:
open
(
self
.
_server_key
).
read
()
!=
reference_server_key
,
lambda
:
readServerKey
()
!=
reference_server_key
,
):
# pragma: no cover
):
# pragma: no cover
raise
AssertionError
(
'Server did not renew its key pair within 1 second'
)
raise
AssertionError
(
'Server did not renew its key pair within 1 second'
)
# But user still trusts the server
# But user still trusts the server
...
@@ -1363,7 +1376,8 @@ class CaucaseTest(unittest.TestCase):
...
@@ -1363,7 +1376,8 @@ class CaucaseTest(unittest.TestCase):
utils
.
load_ca_certificate
(
x
)
utils
.
load_ca_certificate
(
x
)
for
x
in
utils
.
getCertList
(
self
.
_client_user_ca_crt
)
for
x
in
utils
.
getCertList
(
self
.
_client_user_ca_crt
)
]
]
cau_crl
=
open
(
self
.
_client_user_crl
).
read
()
with
open
(
self
.
_client_user_crl
,
'rb'
)
as
client_user_crl_file
:
cau_crl
=
client_user_crl_file
.
read
()
class
DummyCAU
(
object
):
class
DummyCAU
(
object
):
"""
"""
Mock CAU.
Mock CAU.
...
@@ -1382,7 +1396,7 @@ class CaucaseTest(unittest.TestCase):
...
@@ -1382,7 +1396,7 @@ class CaucaseTest(unittest.TestCase):
"""
"""
Return a dummy string as CA certificate
Return a dummy string as CA certificate
"""
"""
return
'notreallyPEM'
return
b
'notreallyPEM'
@
staticmethod
@
staticmethod
def
getCertificateRevocationList
():
def
getCertificateRevocationList
():
...
@@ -1441,7 +1455,7 @@ class CaucaseTest(unittest.TestCase):
...
@@ -1441,7 +1455,7 @@ class CaucaseTest(unittest.TestCase):
if
key
in
header_dict
:
# pragma: no cover
if
key
in
header_dict
:
# pragma: no cover
value
=
header_dict
[
key
]
+
','
+
value
value
=
header_dict
[
key
]
+
','
+
value
header_dict
[
key
]
=
value
header_dict
[
key
]
=
value
return
int
(
status
),
reason
,
header_dict
,
''
.
join
(
body
)
return
int
(
status
),
reason
,
header_dict
,
b
''
.
join
(
body
)
UNAUTHORISED_STATUS
=
401
UNAUTHORISED_STATUS
=
401
HATEOAS_HTTP_PREFIX
=
u"http://caucase.example.com:8000/base/path"
HATEOAS_HTTP_PREFIX
=
u"http://caucase.example.com:8000/base/path"
...
@@ -1841,7 +1855,7 @@ class CaucaseTest(unittest.TestCase):
...
@@ -1841,7 +1855,7 @@ class CaucaseTest(unittest.TestCase):
header_dict
[
'Content-Security-Policy'
],
header_dict
[
'Content-Security-Policy'
],
"frame-ancestors 'none'"
,
"frame-ancestors 'none'"
,
)
)
assertHTMLNoScriptAlert
(
body
)
assertHTMLNoScriptAlert
(
utils
.
toUnicode
(
body
)
)
# POST /cors sets cookie
# POST /cors sets cookie
def
getCORSPostEnvironment
(
kw
=
(),
input_dict
=
(
def
getCORSPostEnvironment
(
kw
=
(),
input_dict
=
(
(
'return_to'
,
return_url
),
(
'return_to'
,
return_url
),
...
@@ -2042,9 +2056,9 @@ class CaucaseTest(unittest.TestCase):
...
@@ -2042,9 +2056,9 @@ class CaucaseTest(unittest.TestCase):
table_prefix
=
'cau'
,
table_prefix
=
'cau'
,
).
dumpIterator
())
).
dumpIterator
())
CRL_INSERT
=
'INSERT INTO "caucrl" '
CRL_INSERT
=
b
'INSERT INTO "caucrl" '
CRT_INSERT
=
'INSERT INTO "caucrt" '
CRT_INSERT
=
b
'INSERT INTO "caucrt" '
REV_INSERT
=
'INSERT INTO "caurevoked" '
REV_INSERT
=
b
'INSERT INTO "caurevoked" '
def
filterBackup
(
backup
,
expect_rev
):
def
filterBackup
(
backup
,
expect_rev
):
"""
"""
Remove all lines which are know to differ between original batabase and
Remove all lines which are know to differ between original batabase and
...
@@ -2145,7 +2159,7 @@ class CaucaseTest(unittest.TestCase):
...
@@ -2145,7 +2159,7 @@ class CaucaseTest(unittest.TestCase):
user2_newnew_key_path
,
user2_newnew_key_path
,
)
)
user2_new_bare_key_path
=
user2_new_key_path
+
'.bare_key'
user2_new_bare_key_path
=
user2_new_key_path
+
'.bare_key'
with
open
(
user2_new_bare_key_path
,
'w'
)
as
bare_key_file
:
with
open
(
user2_new_bare_key_path
,
'w
b
'
)
as
bare_key_file
:
bare_key_file
.
write
(
utils
.
getKeyPair
(
user2_new_key_path
)[
1
])
bare_key_file
.
write
(
utils
.
getKeyPair
(
user2_new_key_path
)[
1
])
self
.
assertEqual
(
self
.
assertEqual
(
self
.
_restoreServer
(
self
.
_restoreServer
(
...
@@ -2174,13 +2188,13 @@ class CaucaseTest(unittest.TestCase):
...
@@ -2174,13 +2188,13 @@ class CaucaseTest(unittest.TestCase):
'--revoke-crt'
,
service_key
,
service_key
,
'--revoke-crt'
,
service_key
,
service_key
,
)
)
self
.
_runClient
()
self
.
_runClient
()
getBytePass_orig
=
http
.
getBytePass
getBytePass_orig
=
caucase
.
http
.
getBytePass
orig_stdout
=
sys
.
stdout
orig_stdout
=
sys
.
stdout
try
:
try
:
http
.
getBytePass
=
lambda
x
:
'test'
caucase
.
http
.
getBytePass
=
lambda
x
:
b
'test'
sys
.
stdout
=
stdout
=
StringIO
()
sys
.
stdout
=
stdout
=
StringIO
()
self
.
assertFalse
(
os
.
path
.
exists
(
exported_ca
),
exported_ca
)
self
.
assertFalse
(
os
.
path
.
exists
(
exported_ca
),
exported_ca
)
http
.
manage
(
caucase
.
http
.
manage
(
argv
=
(
argv
=
(
'--db'
,
self
.
_server_db
,
'--db'
,
self
.
_server_db
,
'--export-ca'
,
exported_ca
,
'--export-ca'
,
exported_ca
,
...
@@ -2189,7 +2203,7 @@ class CaucaseTest(unittest.TestCase):
...
@@ -2189,7 +2203,7 @@ class CaucaseTest(unittest.TestCase):
self
.
assertTrue
(
os
.
path
.
exists
(
exported_ca
),
exported_ca
)
self
.
assertTrue
(
os
.
path
.
exists
(
exported_ca
),
exported_ca
)
server_db2
=
self
.
_server_db
+
'2'
server_db2
=
self
.
_server_db
+
'2'
self
.
assertFalse
(
os
.
path
.
exists
(
server_db2
),
server_db2
)
self
.
assertFalse
(
os
.
path
.
exists
(
server_db2
),
server_db2
)
http
.
manage
(
caucase
.
http
.
manage
(
argv
=
(
argv
=
(
'--db'
,
server_db2
,
'--db'
,
server_db2
,
'--import-ca'
,
exported_ca
,
'--import-ca'
,
exported_ca
,
...
@@ -2208,7 +2222,7 @@ class CaucaseTest(unittest.TestCase):
...
@@ -2208,7 +2222,7 @@ class CaucaseTest(unittest.TestCase):
)
)
finally
:
finally
:
sys
.
stdout
=
orig_stdout
sys
.
stdout
=
orig_stdout
http
.
getBytePass
=
getBytePass_orig
caucase
.
http
.
getBytePass
=
getBytePass_orig
def
testWSGIBase
(
self
):
def
testWSGIBase
(
self
):
"""
"""
...
@@ -2220,10 +2234,10 @@ class CaucaseTest(unittest.TestCase):
...
@@ -2220,10 +2234,10 @@ class CaucaseTest(unittest.TestCase):
"""
"""
Trigger execution of app, with given request.
Trigger execution of app, with given request.
"""
"""
wfile
=
String
IO
()
wfile
=
Bytes
IO
()
http
.
CaucaseWSGIRequestHandler
(
caucase
.
http
.
CaucaseWSGIRequestHandler
(
FakeStreamRequest
(
FakeStreamRequest
(
StringIO
(
'
\
r
\
n
'
.
join
(
request_line_list
+
[
''
])),
BytesIO
(
b'
\
r
\
n
'
.
join
(
request_line_list
+
[
b
''
])),
NoCloseFileProxy
(
wfile
),
NoCloseFileProxy
(
wfile
),
),
),
(
'0.0.0.0'
,
0
),
(
'0.0.0.0'
,
0
),
...
@@ -2235,26 +2249,28 @@ class CaucaseTest(unittest.TestCase):
...
@@ -2235,26 +2249,28 @@ class CaucaseTest(unittest.TestCase):
"""
"""
Naive extraction of http status out of an http response.
Naive extraction of http status out of an http response.
"""
"""
_
,
code
,
_
=
response_line_list
[
0
].
split
(
' '
,
2
)
_
,
code
,
_
=
response_line_list
[
0
].
split
(
b
' '
,
2
)
return
int
(
code
)
return
int
(
code
)
def
getBody
(
response_line_list
):
def
getBody
(
response_line_list
):
"""
"""
Naive extraction of http response body.
Naive extraction of http response body.
"""
"""
return
'
\
r
\
n
'
.
join
(
response_line_list
[
response_line_list
.
index
(
''
)
+
1
:])
return
b'
\
r
\
n
'
.
join
(
response_line_list
[
response_line_list
.
index
(
b''
)
+
1
:],
)
self
.
assertEqual
(
self
.
assertEqual
(
getStatus
(
run
([
'GET /'
+
'a'
*
65537
])),
getStatus
(
run
([
b'GET /'
+
b
'a'
*
65537
])),
414
,
414
,
)
)
expect_continue_request
=
[
expect_continue_request
=
[
'PUT / HTTP/1.1'
,
b
'PUT / HTTP/1.1'
,
'Expect: 100-continue'
,
b
'Expect: 100-continue'
,
'Content-Length: 4'
,
b
'Content-Length: 4'
,
'Content-Type: text/plain'
,
b
'Content-Type: text/plain'
,
''
,
b
''
,
'Test'
,
b
'Test'
,
]
]
# No read: 200 OK
# No read: 200 OK
self
.
assertEqual
(
self
.
assertEqual
(
...
@@ -2271,7 +2287,7 @@ class CaucaseTest(unittest.TestCase):
...
@@ -2271,7 +2287,7 @@ class CaucaseTest(unittest.TestCase):
self
.
assertEqual
(
self
.
assertEqual
(
getStatus
(
run
(
getStatus
(
run
(
[
[
'PUT / HTTP/1.0'
,
b
'PUT / HTTP/1.0'
,
]
+
expect_continue_request
[
1
:],
]
+
expect_continue_request
[
1
:],
read_app
,
read_app
,
)),
)),
...
@@ -2279,19 +2295,19 @@ class CaucaseTest(unittest.TestCase):
...
@@ -2279,19 +2295,19 @@ class CaucaseTest(unittest.TestCase):
)
)
chunked_request
=
[
chunked_request
=
[
'PUT / HTTP/1.1'
,
b
'PUT / HTTP/1.1'
,
'Transfer-Encoding: chunked'
,
b
'Transfer-Encoding: chunked'
,
''
,
b
''
,
'f;some=extension'
,
b
'f;some=extension'
,
'123456789abcd
\
r
\
n
'
,
b
'123456789abcd
\
r
\
n
'
,
'3'
,
b
'3'
,
'ef0'
,
b
'ef0'
,
'0'
,
b
'0'
,
'X-Chunked-Trailer: blah'
b
'X-Chunked-Trailer: blah'
]
]
self
.
assertEqual
(
self
.
assertEqual
(
getBody
(
run
(
chunked_request
,
read_app
)),
getBody
(
run
(
chunked_request
,
read_app
)),
'123456789abcd
\
r
\
n
ef0'
,
b
'123456789abcd
\
r
\
n
ef0'
,
)
)
self
.
assertEqual
(
self
.
assertEqual
(
getBody
(
run
(
getBody
(
run
(
...
@@ -2300,7 +2316,7 @@ class CaucaseTest(unittest.TestCase):
...
@@ -2300,7 +2316,7 @@ class CaucaseTest(unittest.TestCase):
environ
[
'wsgi.input'
].
read
(),
environ
[
'wsgi.input'
].
read
(),
environ
[
'wsgi.input'
].
read
(),
environ
[
'wsgi.input'
].
read
(),
]))),
]))),
'123456789abcd
\
r
\
n
ef0'
,
b
'123456789abcd
\
r
\
n
ef0'
,
)
)
self
.
assertEqual
(
self
.
assertEqual
(
getBody
(
run
(
getBody
(
run
(
...
@@ -2309,7 +2325,7 @@ class CaucaseTest(unittest.TestCase):
...
@@ -2309,7 +2325,7 @@ class CaucaseTest(unittest.TestCase):
environ
[
'wsgi.input'
].
read
(
6
),
environ
[
'wsgi.input'
].
read
(
6
),
environ
[
'wsgi.input'
].
read
(),
environ
[
'wsgi.input'
].
read
(),
]))),
]))),
'123456789abcd
\
r
\
n
ef0'
,
b
'123456789abcd
\
r
\
n
ef0'
,
)
)
self
.
assertEqual
(
self
.
assertEqual
(
getBody
(
run
(
getBody
(
run
(
...
@@ -2317,44 +2333,44 @@ class CaucaseTest(unittest.TestCase):
...
@@ -2317,44 +2333,44 @@ class CaucaseTest(unittest.TestCase):
DummyApp
(
lambda
environ
:
[
DummyApp
(
lambda
environ
:
[
environ
[
'wsgi.input'
].
read
(
32
),
environ
[
'wsgi.input'
].
read
(
32
),
]))),
]))),
'123456789abcd
\
r
\
n
ef0'
,
b
'123456789abcd
\
r
\
n
ef0'
,
)
)
self
.
assertEqual
(
self
.
assertEqual
(
getStatus
(
run
([
getStatus
(
run
([
'PUT / HTTP/1.1'
,
b
'PUT / HTTP/1.1'
,
'Transfer-Encoding: chunked'
,
b
'Transfer-Encoding: chunked'
,
''
,
b
''
,
'1'
,
b
'1'
,
'abc'
,
# Chunk longer than advertised in header.
b
'abc'
,
# Chunk longer than advertised in header.
],
read_app
)),
],
read_app
)),
500
,
500
,
)
)
self
.
assertEqual
(
self
.
assertEqual
(
getStatus
(
run
([
getStatus
(
run
([
'PUT / HTTP/1.1'
,
b
'PUT / HTTP/1.1'
,
'Transfer-Encoding: chunked'
,
b
'Transfer-Encoding: chunked'
,
''
,
b
''
,
'y'
,
# Not a valid chunk header
b
'y'
,
# Not a valid chunk header
],
read_app
)),
],
read_app
)),
500
,
500
,
)
)
self
.
assertEqual
(
self
.
assertEqual
(
getStatus
(
run
([
getStatus
(
run
([
'PUT / HTTP/1.1'
,
b
'PUT / HTTP/1.1'
,
'Transfer-Encoding: chunked'
,
b
'Transfer-Encoding: chunked'
,
''
,
b
''
,
'f;'
+
'a'
*
65537
,
# header too long
b'f;'
+
b
'a'
*
65537
,
# header too long
],
read_app
)),
],
read_app
)),
500
,
500
,
)
)
self
.
assertEqual
(
self
.
assertEqual
(
getStatus
(
run
([
getStatus
(
run
([
'PUT / HTTP/1.1'
,
b
'PUT / HTTP/1.1'
,
'Transfer-Encoding: chunked'
,
b
'Transfer-Encoding: chunked'
,
''
,
b
''
,
'0'
,
b
'0'
,
'a'
*
65537
,
# trailer too long
b
'a'
*
65537
,
# trailer too long
],
read_app
)),
],
read_app
)),
500
,
500
,
)
)
...
@@ -2580,5 +2596,10 @@ class CaucaseTest(unittest.TestCase):
...
@@ -2580,5 +2596,10 @@ class CaucaseTest(unittest.TestCase):
self
.
assertEqual
(
os
.
stat
(
self
.
_server_db
).
st_mode
&
0o777
,
0o600
)
self
.
assertEqual
(
os
.
stat
(
self
.
_server_db
).
st_mode
&
0o777
,
0o600
)
self
.
assertEqual
(
os
.
stat
(
self
.
_server_key
).
st_mode
&
0o777
,
0o600
)
self
.
assertEqual
(
os
.
stat
(
self
.
_server_key
).
st_mode
&
0o777
,
0o600
)
if
getattr
(
CaucaseTest
,
'assertItemsEqual'
,
None
)
is
None
:
# Because python3 decided it should be named differently, and 2to3 cannot
# pick it up, and this code must remain python2-compatible... Yay !
CaucaseTest
.
assertItemsEqual
=
CaucaseTest
.
assertCountEqual
if
__name__
==
'__main__'
:
# pragma: no cover
if
__name__
==
'__main__'
:
# pragma: no cover
unittest
.
main
()
unittest
.
main
()
caucase/utils.py
View file @
8ce08bf9
...
@@ -21,6 +21,7 @@ Caucase - Certificate Authority for Users, Certificate Authority for SErvices
...
@@ -21,6 +21,7 @@ Caucase - Certificate Authority for Users, Certificate Authority for SErvices
Small-ish functions needed in many places.
Small-ish functions needed in many places.
"""
"""
from
__future__
import
absolute_import
from
__future__
import
absolute_import
from
binascii
import
a2b_base64
,
b2a_base64
from
collections
import
defaultdict
from
collections
import
defaultdict
import
datetime
import
datetime
import
json
import
json
...
@@ -273,19 +274,20 @@ def wrap(payload, key, digest):
...
@@ -273,19 +274,20 @@ def wrap(payload, key, digest):
"""
"""
Sign payload (which gets json-serialised) with key, using given digest.
Sign payload (which gets json-serialised) with key, using given digest.
"""
"""
payload
=
json
.
dumps
(
payload
).
encode
(
'utf-8'
)
payload
=
toBytes
(
json
.
dumps
(
payload
),
'utf-8'
)
hash_class
=
getattr
(
hashes
,
digest
.
upper
())
hash_class
=
getattr
(
hashes
,
digest
.
upper
())
return
{
return
{
'payload'
:
payload
,
'payload'
:
toUnicode
(
payload
)
,
'digest'
:
digest
,
'digest'
:
digest
,
'signature'
:
key
.
sign
(
# For some reason, python3 thinks that a b2a method should return bytes.
payload
+
digest
+
' '
,
'signature'
:
toUnicode
(
b2a_base64
(
key
.
sign
(
payload
+
toBytes
(
digest
)
+
b' '
,
padding
.
PSS
(
padding
.
PSS
(
mgf
=
padding
.
MGF1
(
hash_class
()),
mgf
=
padding
.
MGF1
(
hash_class
()),
salt_length
=
padding
.
PSS
.
MAX_LENGTH
,
salt_length
=
padding
.
PSS
.
MAX_LENGTH
,
),
),
hash_class
(),
hash_class
(),
)
.
encode
(
'base64'
),
)
)
),
}
}
def
nullWrap
(
payload
):
def
nullWrap
(
payload
):
...
@@ -308,10 +310,10 @@ def unwrap(wrapped, getCertificate, digest_list):
...
@@ -308,10 +310,10 @@ def unwrap(wrapped, getCertificate, digest_list):
Note: does *not* verify received certificate itself (validity, issuer, ...).
Note: does *not* verify received certificate itself (validity, issuer, ...).
"""
"""
# Check whether given digest is allowed
# Check whether given digest is allowed
digest
=
wrapped
[
'digest'
]
.
encode
(
'ascii'
)
digest
=
wrapped
[
'digest'
]
if
digest
not
in
digest_list
:
if
digest
not
in
digest_list
:
raise
cryptography
.
exceptions
.
UnsupportedAlgorithm
(
raise
cryptography
.
exceptions
.
UnsupportedAlgorithm
(
'%r is not in allowed digest list
'
,
'%r is not in allowed digest list
%r'
%
(
digest
,
digest_list
)
,
)
)
hash_class
=
getattr
(
hashes
,
digest
.
upper
())
hash_class
=
getattr
(
hashes
,
digest
.
upper
())
try
:
try
:
...
@@ -319,11 +321,11 @@ def unwrap(wrapped, getCertificate, digest_list):
...
@@ -319,11 +321,11 @@ def unwrap(wrapped, getCertificate, digest_list):
except
ValueError
:
except
ValueError
:
raise
NotJSON
raise
NotJSON
x509
.
load_pem_x509_certificate
(
x509
.
load_pem_x509_certificate
(
getCertificate
(
payload
).
encode
(
'ascii'
),
toBytes
(
getCertificate
(
payload
)
),
_cryptography_backend
,
_cryptography_backend
,
).
public_key
().
verify
(
).
public_key
().
verify
(
wrapped
[
'signature'
].
encode
(
'ascii'
).
decode
(
'base64'
),
a2b_base64
(
toBytes
(
wrapped
[
'signature'
])
),
wrapped
[
'payload'
].
encode
(
'utf-8'
)
+
digest
+
' '
,
toBytes
(
wrapped
[
'payload'
],
'utf-8'
)
+
toBytes
(
digest
)
+
b
' '
,
padding
.
PSS
(
padding
.
PSS
(
mgf
=
padding
.
MGF1
(
hash_class
()),
mgf
=
padding
.
MGF1
(
hash_class
()),
salt_length
=
padding
.
PSS
.
MAX_LENGTH
,
salt_length
=
padding
.
PSS
.
MAX_LENGTH
,
...
@@ -445,6 +447,18 @@ class SleepInterrupt(KeyboardInterrupt):
...
@@ -445,6 +447,18 @@ class SleepInterrupt(KeyboardInterrupt):
"""
"""
pass
pass
def
toUnicode
(
value
,
encoding
=
'ascii'
):
"""
Convert value to unicode object, if it is not already.
"""
return
value
if
isinstance
(
value
,
unicode
)
else
value
.
decode
(
encoding
)
def
toBytes
(
value
,
encoding
=
'ascii'
):
"""
Convert valye to bytes object, if it is not already.
"""
return
value
if
isinstance
(
value
,
bytes
)
else
value
.
encode
(
encoding
)
def
interruptibleSleep
(
duration
):
# pragma: no cover
def
interruptibleSleep
(
duration
):
# pragma: no cover
"""
"""
Like sleep, but raises SleepInterrupt when interrupted by KeyboardInterrupt
Like sleep, but raises SleepInterrupt when interrupted by KeyboardInterrupt
...
...
caucase/wsgi.py
View file @
8ce08bf9
...
@@ -19,11 +19,11 @@
...
@@ -19,11 +19,11 @@
Caucase - Certificate Authority for Users, Certificate Authority for SErvices
Caucase - Certificate Authority for Users, Certificate Authority for SErvices
"""
"""
from
__future__
import
absolute_import
from
__future__
import
absolute_import
from
cgi
import
escape
from
Cookie
import
SimpleCookie
,
CookieError
from
Cookie
import
SimpleCookie
,
CookieError
import
httplib
import
httplib
import
json
import
json
import
os
import
os
import
sys
import
threading
import
threading
import
time
import
time
import
traceback
import
traceback
...
@@ -34,10 +34,15 @@ import jwt
...
@@ -34,10 +34,15 @@ import jwt
from
.
import
utils
from
.
import
utils
from
.
import
exceptions
from
.
import
exceptions
if
sys
.
version_info
>=
(
3
,
):
# pragma: no cover
from
html
import
escape
else
:
# pragma: no cover
from
cgi
import
escape
__all__
=
(
'Application'
,
'CORSTokenManager'
)
__all__
=
(
'Application'
,
'CORSTokenManager'
)
# TODO: l10n
# TODO: l10n
CORS_FORM_TEMPLATE
=
'''
\
CORS_FORM_TEMPLATE
=
b
'''
\
<html>
<html>
<head>
<head>
<title>Caucase CORS access</title>
<title>Caucase CORS access</title>
...
@@ -213,11 +218,11 @@ class CORSTokenManager(object):
...
@@ -213,11 +218,11 @@ class CORSTokenManager(object):
key
=
os
.
urandom
(
32
)
key
=
os
.
urandom
(
32
)
secret_list
.
append
((
now
+
self
.
_secret_validity_period
,
key
))
secret_list
.
append
((
now
+
self
.
_secret_validity_period
,
key
))
self
.
_onNewKey
(
secret_list
)
self
.
_onNewKey
(
secret_list
)
return
jwt
.
encode
(
return
utils
.
toUnicode
(
jwt
.
encode
(
payload
=
{
'p'
:
payload
},
payload
=
{
'p'
:
payload
},
key
=
key
,
key
=
key
,
algorithm
=
'HS256'
,
algorithm
=
'HS256'
,
)
)
)
def
verify
(
self
,
token
,
default
=
None
):
def
verify
(
self
,
token
,
default
=
None
):
"""
"""
...
@@ -571,7 +576,7 @@ class Application(object):
...
@@ -571,7 +576,7 @@ class Application(object):
except
exceptions
.
NoStorage
:
except
exceptions
.
NoStorage
:
raise
InsufficientStorage
raise
InsufficientStorage
except
exceptions
.
NotJSON
:
except
exceptions
.
NotJSON
:
raise
BadRequest
(
'Invalid json payload'
)
raise
BadRequest
(
b
'Invalid json payload'
)
except
exceptions
.
CertificateAuthorityException
as
e
:
except
exceptions
.
CertificateAuthorityException
as
e
:
raise
BadRequest
(
str
(
e
))
raise
BadRequest
(
str
(
e
))
except
Exception
:
except
Exception
:
...
@@ -581,7 +586,7 @@ class Application(object):
...
@@ -581,7 +586,7 @@ class Application(object):
except
ApplicationError
as
e
:
except
ApplicationError
as
e
:
status
=
e
.
status
status
=
e
.
status
header_list
=
e
.
response_headers
header_list
=
e
.
response_headers
result
=
[
str
(
x
)
for
x
in
e
.
args
]
result
=
[
utils
.
toBytes
(
str
(
x
)
)
for
x
in
e
.
args
]
# Note: header_list and cors_header_list are expected to contain
# Note: header_list and cors_header_list are expected to contain
# distinct header sets. This may not always stay true for "Vary".
# distinct header sets. This may not always stay true for "Vary".
header_list
.
extend
(
cors_header_list
)
header_list
.
extend
(
cors_header_list
)
...
@@ -605,7 +610,7 @@ class Application(object):
...
@@ -605,7 +610,7 @@ class Application(object):
try
:
try
:
return
int
(
crt_id
)
return
int
(
crt_id
)
except
ValueError
:
except
ValueError
:
raise
BadRequest
(
'Invalid integer'
)
raise
BadRequest
(
b
'Invalid integer'
)
@
staticmethod
@
staticmethod
def
_read
(
environ
):
def
_read
(
environ
):
...
@@ -619,9 +624,9 @@ class Application(object):
...
@@ -619,9 +624,9 @@ class Application(object):
try
:
try
:
length
=
int
(
environ
.
get
(
'CONTENT_LENGTH'
)
or
MAX_BODY_LENGTH
)
length
=
int
(
environ
.
get
(
'CONTENT_LENGTH'
)
or
MAX_BODY_LENGTH
)
except
ValueError
:
except
ValueError
:
raise
BadRequest
(
'Invalid Content-Length'
)
raise
BadRequest
(
b
'Invalid Content-Length'
)
if
length
>
MAX_BODY_LENGTH
:
if
length
>
MAX_BODY_LENGTH
:
raise
TooLarge
(
'Content-Length limit exceeded'
)
raise
TooLarge
(
b
'Content-Length limit exceeded'
)
return
environ
[
'wsgi.input'
].
read
(
length
)
return
environ
[
'wsgi.input'
].
read
(
length
)
def
_authenticate
(
self
,
environ
,
header_list
):
def
_authenticate
(
self
,
environ
,
header_list
):
...
@@ -653,12 +658,12 @@ class Application(object):
...
@@ -653,12 +658,12 @@ class Application(object):
json decoding fails.
json decoding fails.
"""
"""
if
environ
.
get
(
'CONTENT_TYPE'
)
!=
'application/json'
:
if
environ
.
get
(
'CONTENT_TYPE'
)
!=
'application/json'
:
raise
BadRequest
(
'Bad Content-Type'
)
raise
BadRequest
(
b
'Bad Content-Type'
)
data
=
self
.
_read
(
environ
)
data
=
self
.
_read
(
environ
)
try
:
try
:
return
json
.
loads
(
data
)
return
json
.
loads
(
data
)
except
ValueError
:
except
ValueError
:
raise
BadRequest
(
'Invalid json'
)
raise
BadRequest
(
b
'Invalid json'
)
def
_createCORSCookie
(
self
,
environ
,
value
):
def
_createCORSCookie
(
self
,
environ
,
value
):
"""
"""
...
@@ -859,7 +864,10 @@ class Application(object):
...
@@ -859,7 +864,10 @@ class Application(object):
name
=
action
[
'name'
]
name
=
action
[
'name'
]
assert
name
not
in
hal_section_dict
,
name
assert
name
not
in
hal_section_dict
,
name
hal_section_dict
[
name
]
=
descriptor_dict
hal_section_dict
[
name
]
=
descriptor_dict
return
self
.
_returnFile
(
json
.
dumps
(
hal
),
'application/hal+json'
)
return
self
.
_returnFile
(
utils
.
toBytes
(
json
.
dumps
(
hal
)),
'application/hal+json'
,
)
def
getCORSForm
(
self
,
context
,
environ
):
# pylint: disable=unused-argument
def
getCORSForm
(
self
,
context
,
environ
):
# pylint: disable=unused-argument
"""
"""
...
@@ -881,9 +889,9 @@ class Application(object):
...
@@ -881,9 +889,9 @@ class Application(object):
raise
BadRequest
raise
BadRequest
return
self
.
_returnFile
(
return
self
.
_returnFile
(
CORS_FORM_TEMPLATE
%
{
CORS_FORM_TEMPLATE
%
{
'caucase'
:
escape
(
self
.
_http_url
,
quote
=
True
),
b'caucase'
:
utils
.
toBytes
(
escape
(
self
.
_http_url
,
quote
=
True
)
),
'return_to'
:
escape
(
return_to
,
quote
=
True
),
b'return_to'
:
utils
.
toBytes
(
escape
(
return_to
,
quote
=
True
)
),
'origin'
:
escape
(
origin
,
quote
=
True
),
b'origin'
:
utils
.
toBytes
(
escape
(
origin
,
quote
=
True
)
),
},
},
'text/html'
,
'text/html'
,
[
[
...
@@ -902,7 +910,7 @@ class Application(object):
...
@@ -902,7 +910,7 @@ class Application(object):
if
environ
[
'wsgi.url_scheme'
]
!=
'https'
:
if
environ
[
'wsgi.url_scheme'
]
!=
'https'
:
raise
NotFound
raise
NotFound
if
environ
.
get
(
'CONTENT_TYPE'
)
!=
'application/x-www-form-urlencoded'
:
if
environ
.
get
(
'CONTENT_TYPE'
)
!=
'application/x-www-form-urlencoded'
:
raise
BadRequest
(
'Unhandled Content-Type'
)
raise
BadRequest
(
b
'Unhandled Content-Type'
)
try
:
try
:
form_dict
=
parse_qs
(
self
.
_read
(
environ
),
strict_parsing
=
True
)
form_dict
=
parse_qs
(
self
.
_read
(
environ
),
strict_parsing
=
True
)
origin
,
=
form_dict
[
'origin'
]
origin
,
=
form_dict
[
'origin'
]
...
@@ -961,7 +969,7 @@ class Application(object):
...
@@ -961,7 +969,7 @@ class Application(object):
header_list
=
[]
header_list
=
[]
self
.
_authenticate
(
environ
,
header_list
)
self
.
_authenticate
(
environ
,
header_list
)
return
self
.
_returnFile
(
return
self
.
_returnFile
(
json
.
dumps
(
context
.
getCertificateRequestList
(
)),
utils
.
toBytes
(
json
.
dumps
(
context
.
getCertificateRequestList
()
)),
'application/json'
,
'application/json'
,
header_list
,
header_list
,
)
)
...
@@ -973,7 +981,7 @@ class Application(object):
...
@@ -973,7 +981,7 @@ class Application(object):
try
:
try
:
csr_id
=
context
.
appendCertificateSigningRequest
(
self
.
_read
(
environ
))
csr_id
=
context
.
appendCertificateSigningRequest
(
self
.
_read
(
environ
))
except
exceptions
.
NotACertificateSigningRequest
:
except
exceptions
.
NotACertificateSigningRequest
:
raise
BadRequest
(
'Not a valid certificate signing request'
)
raise
BadRequest
(
b
'Not a valid certificate signing request'
)
return
(
STATUS_CREATED
,
[(
'Location'
,
str
(
csr_id
))],
[])
return
(
STATUS_CREATED
,
[(
'Location'
,
str
(
csr_id
))],
[])
def
deletePendingCertificateRequest
(
self
,
context
,
environ
,
subpath
):
def
deletePendingCertificateRequest
(
self
,
context
,
environ
,
subpath
):
...
@@ -1013,7 +1021,7 @@ class Application(object):
...
@@ -1013,7 +1021,7 @@ class Application(object):
Handle GET /{context}/crt/ca.crt.json urls.
Handle GET /{context}/crt/ca.crt.json urls.
"""
"""
return
self
.
_returnFile
(
return
self
.
_returnFile
(
json
.
dumps
(
context
.
getValidCACertificateChain
(
)),
utils
.
toBytes
(
json
.
dumps
(
context
.
getValidCACertificateChain
()
)),
'application/json'
,
'application/json'
,
)
)
...
@@ -1050,7 +1058,7 @@ class Application(object):
...
@@ -1050,7 +1058,7 @@ class Application(object):
context
.
digest_list
,
context
.
digest_list
,
)
)
context
.
revoke
(
context
.
revoke
(
crt_pem
=
payload
[
'revoke_crt_pem'
].
encode
(
'ascii'
),
crt_pem
=
utils
.
toBytes
(
payload
[
'revoke_crt_pem'
]
),
)
)
return
(
STATUS_NO_CONTENT
,
header_list
,
[])
return
(
STATUS_NO_CONTENT
,
header_list
,
[])
...
@@ -1065,8 +1073,8 @@ class Application(object):
...
@@ -1065,8 +1073,8 @@ class Application(object):
)
)
return
self
.
_returnFile
(
return
self
.
_returnFile
(
context
.
renew
(
context
.
renew
(
crt_pem
=
payload
[
'crt_pem'
].
encode
(
'ascii'
),
crt_pem
=
utils
.
toBytes
(
payload
[
'crt_pem'
]
),
csr_pem
=
payload
[
'renew_csr_pem'
].
encode
(
'ascii'
),
csr_pem
=
utils
.
toBytes
(
payload
[
'renew_csr_pem'
]
),
),
),
'application/pkix-cert'
,
'application/pkix-cert'
,
)
)
...
@@ -1084,7 +1092,7 @@ class Application(object):
...
@@ -1084,7 +1092,7 @@ class Application(object):
elif
environ
.
get
(
'CONTENT_TYPE'
)
==
'application/pkcs10'
:
elif
environ
.
get
(
'CONTENT_TYPE'
)
==
'application/pkcs10'
:
template_csr
=
utils
.
load_certificate_request
(
body
)
template_csr
=
utils
.
load_certificate_request
(
body
)
else
:
else
:
raise
BadRequest
(
'Bad Content-Type'
)
raise
BadRequest
(
b
'Bad Content-Type'
)
header_list
=
[]
header_list
=
[]
self
.
_authenticate
(
environ
,
header_list
)
self
.
_authenticate
(
environ
,
header_list
)
context
.
createCertificate
(
context
.
createCertificate
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment