Commit ba776430 authored by Thomas Gambier's avatar Thomas Gambier

software/slapos_master: take improvements of test_balancer from erp5

Especially, we apply commit 341d42e3.
parent ec91edf6
...@@ -10,25 +10,24 @@ import tempfile ...@@ -10,25 +10,24 @@ import tempfile
import time import time
import urllib.parse import urllib.parse
from http.server import BaseHTTPRequestHandler from http.server import BaseHTTPRequestHandler
from typing import Dict
from unittest import mock from unittest import mock
import OpenSSL.SSL
import pexpect
import psutil
import requests
from cryptography import x509 from cryptography import x509
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID from cryptography.x509.oid import NameOID
from slapos.testing.testcase import ManagedResource import OpenSSL.SSL
from slapos.testing.utils import (CrontabMixin, ManagedHTTPServer, import pexpect
findFreeTCPPort) import psutil
import requests
from slapos.testing.caucase import CaucaseCertificate, CaucaseService
from slapos.testing.utils import CrontabMixin, ManagedHTTPServer
from . import ERP5InstanceTestCase, setUpModule from . import ERP5InstanceTestCase, setUpModule
setUpModule # pyflakes _ = setUpModule
class EchoHTTPServer(ManagedHTTPServer): class EchoHTTPServer(ManagedHTTPServer):
...@@ -36,8 +35,7 @@ class EchoHTTPServer(ManagedHTTPServer): ...@@ -36,8 +35,7 @@ class EchoHTTPServer(ManagedHTTPServer):
encoded in json. encoded in json.
""" """
class RequestHandler(BaseHTTPRequestHandler): class RequestHandler(BaseHTTPRequestHandler):
def do_GET(self): def do_GET(self) -> None:
# type: () -> None
self.send_response(200) self.send_response(200)
self.send_header("Content-Type", "application/json") self.send_header("Content-Type", "application/json")
response = json.dumps( response = json.dumps(
...@@ -59,8 +57,7 @@ class EchoHTTP11Server(ManagedHTTPServer): ...@@ -59,8 +57,7 @@ class EchoHTTP11Server(ManagedHTTPServer):
""" """
class RequestHandler(BaseHTTPRequestHandler): class RequestHandler(BaseHTTPRequestHandler):
protocol_version = 'HTTP/1.1' protocol_version = 'HTTP/1.1'
def do_GET(self): def do_GET(self) -> None:
# type: () -> None
self.send_response(200) self.send_response(200)
self.send_header("Content-Type", "application/json") self.send_header("Content-Type", "application/json")
response = json.dumps( response = json.dumps(
...@@ -77,61 +74,6 @@ class EchoHTTP11Server(ManagedHTTPServer): ...@@ -77,61 +74,6 @@ class EchoHTTP11Server(ManagedHTTPServer):
log_message = logging.getLogger(__name__ + '.EchoHTTP11Server').info log_message = logging.getLogger(__name__ + '.EchoHTTP11Server').info
class CaucaseService(ManagedResource):
"""A caucase service.
"""
url = None # type: str
directory = None # type: str
_caucased_process = None # type: subprocess.Popen
def open(self):
# type: () -> None
# start a caucased and server certificate.
software_release_root_path = os.path.join(
self._cls.slap._software_root,
hashlib.md5(self._cls.getSoftwareURL().encode()).hexdigest(),
)
caucased_path = os.path.join(software_release_root_path, 'bin', 'caucased')
self.directory = tempfile.mkdtemp()
caucased_dir = os.path.join(self.directory, 'caucased')
os.mkdir(caucased_dir)
os.mkdir(os.path.join(caucased_dir, 'user'))
os.mkdir(os.path.join(caucased_dir, 'service'))
backend_caucased_netloc = f'{self._cls._ipv4_address}:{findFreeTCPPort(self._cls._ipv4_address)}'
self.url = 'http://' + backend_caucased_netloc
self._caucased_process = subprocess.Popen(
[
caucased_path,
'--db', os.path.join(caucased_dir, 'caucase.sqlite'),
'--server-key', os.path.join(caucased_dir, 'server.key.pem'),
'--netloc', backend_caucased_netloc,
'--service-auto-approve-count', '1',
],
# capture subprocess output not to pollute test's own stdout
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
for _ in range(30):
try:
if requests.get(self.url).status_code == 200:
break
except Exception:
pass
time.sleep(1)
else:
raise RuntimeError('caucased failed to start.')
def close(self):
# type: () -> None
self._caucased_process.terminate()
self._caucased_process.wait()
self._caucased_process.stdout.close()
shutil.rmtree(self.directory)
class BalancerTestCase(ERP5InstanceTestCase): class BalancerTestCase(ERP5InstanceTestCase):
@classmethod @classmethod
...@@ -139,8 +81,7 @@ class BalancerTestCase(ERP5InstanceTestCase): ...@@ -139,8 +81,7 @@ class BalancerTestCase(ERP5InstanceTestCase):
return 'balancer' return 'balancer'
@classmethod @classmethod
def _getInstanceParameterDict(cls): def _getInstanceParameterDict(cls) -> dict:
# type: () -> Dict
return { return {
'shared-certificate-authority-path': os.path.join( 'shared-certificate-authority-path': os.path.join(
'~', 'srv', 'ssl'), '~', 'srv', 'ssl'),
...@@ -177,11 +118,10 @@ class BalancerTestCase(ERP5InstanceTestCase): ...@@ -177,11 +118,10 @@ class BalancerTestCase(ERP5InstanceTestCase):
} }
@classmethod @classmethod
def getInstanceParameterDict(cls): def getInstanceParameterDict(cls) -> dict:
# type: () -> Dict
return {'_': json.dumps(cls._getInstanceParameterDict())} return {'_': json.dumps(cls._getInstanceParameterDict())}
def setUp(self): def setUp(self) -> None:
self.default_balancer_url = json.loads( self.default_balancer_url = json.loads(
self.computer_partition.getConnectionParameterDict()['_'])['default'] self.computer_partition.getConnectionParameterDict()['_'])['default']
...@@ -192,15 +132,16 @@ class SlowHTTPServer(ManagedHTTPServer): ...@@ -192,15 +132,16 @@ class SlowHTTPServer(ManagedHTTPServer):
Timeout is 2 seconds by default, and can be specified in the path of the URL Timeout is 2 seconds by default, and can be specified in the path of the URL
""" """
class RequestHandler(BaseHTTPRequestHandler): class RequestHandler(BaseHTTPRequestHandler):
def do_GET(self): def do_GET(self) -> None:
# type: () -> None
self.send_response(200)
self.send_header("Content-Type", "text/plain")
timeout = 2 timeout = 2
if self.path == '/': # for health checks
timeout = 0
try: try:
timeout = int(self.path[1:]) timeout = int(self.path.split('/')[5])
except ValueError: except (ValueError, IndexError):
pass pass
self.send_response(200)
self.send_header("Content-Type", "text/plain")
time.sleep(timeout) time.sleep(timeout)
self.end_headers() self.end_headers()
self.wfile.write(b"OK\n") self.wfile.write(b"OK\n")
...@@ -208,6 +149,28 @@ class SlowHTTPServer(ManagedHTTPServer): ...@@ -208,6 +149,28 @@ class SlowHTTPServer(ManagedHTTPServer):
log_message = logging.getLogger(__name__ + '.SlowHTTPServer').info log_message = logging.getLogger(__name__ + '.SlowHTTPServer').info
class TestTimeout(BalancerTestCase, CrontabMixin):
__partition_reference__ = 't'
@classmethod
def _getInstanceParameterDict(cls) -> dict:
parameter_dict = super()._getInstanceParameterDict()
# use a slow server instead
parameter_dict['dummy_http_server'] = [[cls.getManagedResource("slow_web_server", SlowHTTPServer).netloc, 1, False]]
# and set timeout of 1 second
parameter_dict['timeout-dict'] = {'default': 1}
return parameter_dict
def test_timeout(self) -> None:
self.assertEqual(
requests.get(
urllib.parse.urljoin(self.default_balancer_zope_url, '/1'),
verify=False).status_code,
requests.codes.ok)
self.assertEqual(
requests.get(
urllib.parse.urljoin(self.default_balancer_zope_url, '/5'),
verify=False).status_code,
requests.codes.gateway_timeout)
class TestLog(BalancerTestCase, CrontabMixin): class TestLog(BalancerTestCase, CrontabMixin):
...@@ -215,15 +178,13 @@ class TestLog(BalancerTestCase, CrontabMixin): ...@@ -215,15 +178,13 @@ class TestLog(BalancerTestCase, CrontabMixin):
""" """
__partition_reference__ = 'l' __partition_reference__ = 'l'
@classmethod @classmethod
def _getInstanceParameterDict(cls): def _getInstanceParameterDict(cls) -> dict:
# type: () -> Dict
parameter_dict = super()._getInstanceParameterDict() parameter_dict = super()._getInstanceParameterDict()
# use a slow server instead # use a slow server instead, so that we can test logs with slow requests
parameter_dict['dummy_http_server'] = [[cls.getManagedResource("slow_web_server", SlowHTTPServer).netloc, 1, False]] parameter_dict['dummy_http_server'] = [[cls.getManagedResource("slow_web_server", SlowHTTPServer).netloc, 1, False]]
return parameter_dict return parameter_dict
def test_access_log_format(self): def test_access_log_format(self) -> None:
# type: () -> None
requests.get( requests.get(
urllib.parse.urljoin(self.default_balancer_url, '/url_path'), urllib.parse.urljoin(self.default_balancer_url, '/url_path'),
verify=False, verify=False,
...@@ -248,31 +209,25 @@ class TestLog(BalancerTestCase, CrontabMixin): ...@@ -248,31 +209,25 @@ class TestLog(BalancerTestCase, CrontabMixin):
self.assertGreater(request_time, 2 * 1000000) self.assertGreater(request_time, 2 * 1000000)
self.assertLess(request_time, 20 * 1000000) self.assertLess(request_time, 20 * 1000000)
def test_access_log_apachedex_report(self): def test_access_log_apachedex_report(self) -> None:
# type: () -> None
# make a request so that we have something in the logs # make a request so that we have something in the logs
requests.get(self.default_balancer_url, verify=False) requests.get(self.default_balancer_url, verify=False)
# crontab for apachedex is executed # crontab for apachedex is executed
self._executeCrontabAtDate('generate-apachedex-report', '23:59') self._executeCrontabAtDate('generate-apachedex-report', '23:59')
# it creates a report for the day # it creates a report for the day
apachedex_report, = glob.glob( apachedex_report, = (
os.path.join( self.computer_partition_root_path
self.computer_partition_root_path, / 'srv'
'srv', / 'monitor'
'monitor', / 'private'
'private', / 'apachedex').glob('ApacheDex-*.html')
'apachedex', report_text = apachedex_report.read_text()
'ApacheDex-*.html',
))
with open(apachedex_report) as f:
report_text = f.read()
self.assertIn('APacheDEX', report_text) self.assertIn('APacheDEX', report_text)
# having this table means that apachedex could parse some lines. # having this table means that apachedex could parse some lines.
self.assertIn('<h2>Hits per status code</h2>', report_text) self.assertIn('<h2>Hits per status code</h2>', report_text)
def test_access_log_rotation(self): def test_access_log_rotation(self) -> None:
# type: () -> None
# run logrotate a first time so that it create state files # run logrotate a first time so that it create state files
self._executeCrontabAtDate('logrotate', '2000-01-01') self._executeCrontabAtDate('logrotate', '2000-01-01')
...@@ -298,7 +253,7 @@ class TestLog(BalancerTestCase, CrontabMixin): ...@@ -298,7 +253,7 @@ class TestLog(BalancerTestCase, CrontabMixin):
self.assertTrue(os.path.exists(rotated_log_file + '.xz')) self.assertTrue(os.path.exists(rotated_log_file + '.xz'))
self.assertFalse(os.path.exists(rotated_log_file)) self.assertFalse(os.path.exists(rotated_log_file))
def test_error_log(self): def test_error_log(self) -> None:
# stop backend server # stop backend server
backend_server = self.getManagedResource("slow_web_server", SlowHTTPServer) backend_server = self.getManagedResource("slow_web_server", SlowHTTPServer)
self.addCleanup(backend_server.open) self.addCleanup(backend_server.open)
...@@ -308,8 +263,8 @@ class TestLog(BalancerTestCase, CrontabMixin): ...@@ -308,8 +263,8 @@ class TestLog(BalancerTestCase, CrontabMixin):
self.assertEqual( self.assertEqual(
requests.get(self.default_balancer_url, verify=False).status_code, requests.get(self.default_balancer_url, verify=False).status_code,
requests.codes.service_unavailable) requests.codes.service_unavailable)
with open(os.path.join(self.computer_partition_root_path, 'var', 'log', 'apache-error.log')) as error_log_file: error_log_file = self.computer_partition_root_path / 'var' / 'log' / 'apache-error.log'
error_line = error_log_file.read().splitlines()[-1] error_line = error_log_file.read_text().splitlines()[-1]
self.assertIn('apache.conf -D FOREGROUND', error_line) self.assertIn('apache.conf -D FOREGROUND', error_line)
# this log also include a timestamp # this log also include a timestamp
# This regex is for haproxy mostly, so keep it commented for now, until we can # This regex is for haproxy mostly, so keep it commented for now, until we can
...@@ -320,7 +275,9 @@ class TestLog(BalancerTestCase, CrontabMixin): ...@@ -320,7 +275,9 @@ class TestLog(BalancerTestCase, CrontabMixin):
class BalancerCookieHTTPServer(ManagedHTTPServer): class BalancerCookieHTTPServer(ManagedHTTPServer):
"""An HTTP Server which can set balancer cookie. """An HTTP Server which can set balancer cookie.
This server set cookie when requested /set-cookie path. This server set cookie when requested /set-cookie path (actually
/VirtualHostBase/https/{host}/VirtualHostRoot/set-cookie , which is
added by balancer proxy)
The reply body is the name used when registering this resource The reply body is the name used when registering this resource
using getManagedResource. This way we can assert which using getManagedResource. This way we can assert which
...@@ -331,8 +288,7 @@ class BalancerCookieHTTPServer(ManagedHTTPServer): ...@@ -331,8 +288,7 @@ class BalancerCookieHTTPServer(ManagedHTTPServer):
def RequestHandler(self): def RequestHandler(self):
server = self server = self
class RequestHandler(BaseHTTPRequestHandler): class RequestHandler(BaseHTTPRequestHandler):
def do_GET(self): def do_GET(self) -> None:
# type: () -> None
self.send_response(200) self.send_response(200)
self.send_header("Content-Type", "text/plain") self.send_header("Content-Type", "text/plain")
if self.path == '/set_cookie': if self.path == '/set_cookie':
...@@ -353,8 +309,7 @@ class TestBalancer(BalancerTestCase): ...@@ -353,8 +309,7 @@ class TestBalancer(BalancerTestCase):
""" """
__partition_reference__ = 'b' __partition_reference__ = 'b'
@classmethod @classmethod
def _getInstanceParameterDict(cls): def _getInstanceParameterDict(cls) -> dict:
# type: () -> Dict
parameter_dict = super()._getInstanceParameterDict() parameter_dict = super()._getInstanceParameterDict()
# use two backend servers # use two backend servers
...@@ -364,14 +319,14 @@ class TestBalancer(BalancerTestCase): ...@@ -364,14 +319,14 @@ class TestBalancer(BalancerTestCase):
] ]
return parameter_dict return parameter_dict
def test_balancer_round_robin(self): def test_balancer_round_robin(self) -> None:
# requests are by default balanced to both servers # requests are by default balanced to both servers
self.assertEqual( self.assertEqual(
{requests.get(self.default_balancer_url, verify=False).text for _ in range(10)}, {requests.get(self.default_balancer_url, verify=False).text for _ in range(10)},
{'backend_web_server1', 'backend_web_server2'} {'backend_web_server1', 'backend_web_server2'}
) )
def test_balancer_server_down(self): def test_balancer_server_down(self) -> None:
# if one backend is down, it is excluded from balancer # if one backend is down, it is excluded from balancer
self.getManagedResource("backend_web_server2", BalancerCookieHTTPServer).close() self.getManagedResource("backend_web_server2", BalancerCookieHTTPServer).close()
self.addCleanup(self.getManagedResource("backend_web_server2", BalancerCookieHTTPServer).open) self.addCleanup(self.getManagedResource("backend_web_server2", BalancerCookieHTTPServer).open)
...@@ -380,7 +335,7 @@ class TestBalancer(BalancerTestCase): ...@@ -380,7 +335,7 @@ class TestBalancer(BalancerTestCase):
{'backend_web_server1',} {'backend_web_server1',}
) )
def test_balancer_set_cookie(self): def test_balancer_set_cookie(self) -> None:
# if backend provides a "SERVERID" cookie, balancer will overwrite it with the # if backend provides a "SERVERID" cookie, balancer will overwrite it with the
# backend selected by balancing algorithm # backend selected by balancing algorithm
self.assertIn( self.assertIn(
...@@ -388,7 +343,7 @@ class TestBalancer(BalancerTestCase): ...@@ -388,7 +343,7 @@ class TestBalancer(BalancerTestCase):
('default-0', 'default-1'), ('default-0', 'default-1'),
) )
def test_balancer_respects_sticky_cookie(self): def test_balancer_respects_sticky_cookie(self) -> None:
# if request is made with the sticky cookie, the client stick on one balancer # if request is made with the sticky cookie, the client stick on one balancer
cookies = dict(SERVERID='default-1') cookies = dict(SERVERID='default-1')
self.assertEqual( self.assertEqual(
...@@ -409,8 +364,7 @@ class TestTestRunnerEntryPoints(BalancerTestCase): ...@@ -409,8 +364,7 @@ class TestTestRunnerEntryPoints(BalancerTestCase):
""" """
__partition_reference__ = 't' __partition_reference__ = 't'
@classmethod @classmethod
def _getInstanceParameterDict(cls): def _getInstanceParameterDict(cls) -> dict:
# type: () -> Dict
parameter_dict = super()._getInstanceParameterDict() parameter_dict = super()._getInstanceParameterDict()
parameter_dict['dummy_http_server-test-runner-address-list'] = [ parameter_dict['dummy_http_server-test-runner-address-list'] = [
...@@ -429,7 +383,7 @@ class TestTestRunnerEntryPoints(BalancerTestCase): ...@@ -429,7 +383,7 @@ class TestTestRunnerEntryPoints(BalancerTestCase):
] ]
return parameter_dict return parameter_dict
def test_use_proper_backend(self): def test_use_proper_backend(self) -> None:
# requests are directed to proper backend based on URL path # requests are directed to proper backend based on URL path
test_runner_url_list = self.getRootPartitionConnectionParameterDict( test_runner_url_list = self.getRootPartitionConnectionParameterDict(
)['default-test-runner-url-list'] )['default-test-runner-url-list']
...@@ -482,8 +436,7 @@ class TestHTTP(BalancerTestCase): ...@@ -482,8 +436,7 @@ class TestHTTP(BalancerTestCase):
"""Check HTTP protocol with a HTTP/1.1 backend """Check HTTP protocol with a HTTP/1.1 backend
""" """
@classmethod @classmethod
def _getInstanceParameterDict(cls): def _getInstanceParameterDict(cls) -> dict:
# type: () -> Dict
parameter_dict = super()._getInstanceParameterDict() parameter_dict = super()._getInstanceParameterDict()
# use a HTTP/1.1 server instead # use a HTTP/1.1 server instead
parameter_dict['dummy_http_server'] = [[cls.getManagedResource("HTTP/1.1 Server", EchoHTTP11Server).netloc, 1, False]] parameter_dict['dummy_http_server'] = [[cls.getManagedResource("HTTP/1.1 Server", EchoHTTP11Server).netloc, 1, False]]
...@@ -491,8 +444,7 @@ class TestHTTP(BalancerTestCase): ...@@ -491,8 +444,7 @@ class TestHTTP(BalancerTestCase):
__partition_reference__ = 'h' __partition_reference__ = 'h'
def test_http_version(self): def test_http_version(self) -> None:
# type: () -> None
self.assertEqual( self.assertEqual(
subprocess.check_output([ subprocess.check_output([
'curl', 'curl',
...@@ -508,8 +460,7 @@ class TestHTTP(BalancerTestCase): ...@@ -508,8 +460,7 @@ class TestHTTP(BalancerTestCase):
b'1.1', b'1.1',
) )
def test_keep_alive(self): def test_keep_alive(self) -> None:
# type: () -> None
# when doing two requests, connection is established only once # when doing two requests, connection is established only once
with requests.Session() as session: with requests.Session() as session:
session.verify = False session.verify = False
...@@ -539,13 +490,14 @@ class ContentTypeHTTPServer(ManagedHTTPServer): ...@@ -539,13 +490,14 @@ class ContentTypeHTTPServer(ManagedHTTPServer):
For example when requested http://host/text/plain it will reply For example when requested http://host/text/plain it will reply
with Content-Type: text/plain header. with Content-Type: text/plain header.
This actually uses a URL like this to support zope style virtual host:
GET /VirtualHostBase/https/{host}/VirtualHostRoot/text/plain
The body is always "OK" The body is always "OK"
""" """
class RequestHandler(BaseHTTPRequestHandler): class RequestHandler(BaseHTTPRequestHandler):
protocol_version = 'HTTP/1.1' protocol_version = 'HTTP/1.1'
def do_GET(self): def do_GET(self) -> None:
# type: () -> None
self.send_response(200) self.send_response(200)
if self.path == '/': if self.path == '/':
self.send_header("Content-Length", '0') self.send_header("Content-Length", '0')
...@@ -565,8 +517,7 @@ class TestContentEncoding(BalancerTestCase): ...@@ -565,8 +517,7 @@ class TestContentEncoding(BalancerTestCase):
""" """
__partition_reference__ = 'ce' __partition_reference__ = 'ce'
@classmethod @classmethod
def _getInstanceParameterDict(cls): def _getInstanceParameterDict(cls) -> dict:
# type: () -> Dict
parameter_dict = super()._getInstanceParameterDict() parameter_dict = super()._getInstanceParameterDict()
parameter_dict['dummy_http_server'] = [ parameter_dict['dummy_http_server'] = [
[cls.getManagedResource("content_type_server", ContentTypeHTTPServer).netloc, 1, False], [cls.getManagedResource("content_type_server", ContentTypeHTTPServer).netloc, 1, False],
...@@ -575,8 +526,7 @@ class TestContentEncoding(BalancerTestCase): ...@@ -575,8 +526,7 @@ class TestContentEncoding(BalancerTestCase):
# Disabled test until we can rework on it for apache, or drop # Disabled test until we can rework on it for apache, or drop
# apache on the backend. # apache on the backend.
def disabled_test_gzip_encoding(self): def disabled_test_gzip_encoding(self) -> None:
# type: () -> None
for content_type in ( for content_type in (
'text/cache-manifest', 'text/cache-manifest',
'text/html', 'text/html',
...@@ -596,10 +546,7 @@ class TestContentEncoding(BalancerTestCase): ...@@ -596,10 +546,7 @@ class TestContentEncoding(BalancerTestCase):
'application/font-woff2', 'application/font-woff2',
'application/x-font-opentype', 'application/x-font-opentype',
'application/wasm',): 'application/wasm',):
resp = requests.get( resp = requests.get(urllib.parse.urljoin(self.default_balancer_url, content_type), verify=False)
urllib.parse.urljoin(self.default_balancer_url, content_type),
verify=False,
headers={"Accept-Encoding": "gzip, deflate",})
self.assertEqual(resp.headers['Content-Type'], content_type) self.assertEqual(resp.headers['Content-Type'], content_type)
self.assertEqual( self.assertEqual(
resp.headers.get('Content-Encoding'), resp.headers.get('Content-Encoding'),
...@@ -607,121 +554,12 @@ class TestContentEncoding(BalancerTestCase): ...@@ -607,121 +554,12 @@ class TestContentEncoding(BalancerTestCase):
'{} uses wrong encoding: {}'.format(content_type, resp.headers.get('Content-Encoding'))) '{} uses wrong encoding: {}'.format(content_type, resp.headers.get('Content-Encoding')))
self.assertEqual(resp.text, 'OK') self.assertEqual(resp.text, 'OK')
def test_no_gzip_encoding(self): def test_no_gzip_encoding(self) -> None:
# type: () -> None resp = requests.get(urllib.parse.urljoin(self.default_balancer_zope_url, '/image/png'), verify=False)
resp = requests.get(urllib.parse.urljoin(self.default_balancer_url, '/image/png'), verify=False)
self.assertNotIn('Content-Encoding', resp.headers) self.assertNotIn('Content-Encoding', resp.headers)
self.assertEqual(resp.text, 'OK') self.assertEqual(resp.text, 'OK')
class CaucaseCertificate(ManagedResource):
"""A certificate signed by a caucase service.
"""
ca_crt_file = None # type: str
crl_file = None # type: str
csr_file = None # type: str
cert_file = None # type: str
key_file = None # type: str
def open(self):
# type: () -> None
self.tmpdir = tempfile.mkdtemp()
self.ca_crt_file = os.path.join(self.tmpdir, 'ca-crt.pem')
self.crl_file = os.path.join(self.tmpdir, 'ca-crl.pem')
self.csr_file = os.path.join(self.tmpdir, 'csr.pem')
self.cert_file = os.path.join(self.tmpdir, 'crt.pem')
self.key_file = os.path.join(self.tmpdir, 'key.pem')
def close(self):
# type: () -> None
shutil.rmtree(self.tmpdir)
@property
def _caucase_path(self):
# type: () -> str
"""path of caucase executable.
"""
software_release_root_path = os.path.join(
self._cls.slap._software_root,
hashlib.md5(self._cls.getSoftwareURL().encode()).hexdigest(),
)
return os.path.join(software_release_root_path, 'bin', 'caucase')
def request(self, common_name, caucase):
# type: (str, CaucaseService) -> None
"""Generate certificate and request signature to the caucase service.
This overwrite any previously requested certificate for this instance.
"""
cas_args = [
self._caucase_path,
'--ca-url', caucase.url,
'--ca-crt', self.ca_crt_file,
'--crl', self.crl_file,
]
key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
backend=default_backend()
)
with open(self.key_file, 'wb') as f:
f.write(
key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
))
csr = x509.CertificateSigningRequestBuilder().subject_name(
x509.Name([
x509.NameAttribute(
NameOID.COMMON_NAME,
common_name,
),
])).sign(
key,
hashes.SHA256(),
default_backend(),
)
with open(self.csr_file, 'wb') as f:
f.write(csr.public_bytes(serialization.Encoding.PEM))
csr_id = subprocess.check_output(
cas_args + [
'--send-csr', self.csr_file,
],
).split()[0].decode()
assert csr_id
for _ in range(30):
if not subprocess.call(
cas_args + [
'--get-crt', csr_id, self.cert_file,
],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
) == 0:
break
else:
time.sleep(1)
else:
raise RuntimeError('getting service certificate failed.')
with open(self.cert_file) as cert_file:
assert 'BEGIN CERTIFICATE' in cert_file.read()
def revoke(self, caucase):
# type: (str, CaucaseService) -> None
"""Revoke the client certificate on this caucase instance.
"""
subprocess.check_call([
self._caucase_path,
'--ca-url', caucase.url,
'--ca-crt', self.ca_crt_file,
'--crl', self.crl_file,
'--revoke-crt', self.cert_file, self.key_file,
])
class TestServerTLSProvidedCertificate(BalancerTestCase): class TestServerTLSProvidedCertificate(BalancerTestCase):
"""Check that certificate and key can be provided as instance parameters. """Check that certificate and key can be provided as instance parameters.
...@@ -729,8 +567,7 @@ class TestServerTLSProvidedCertificate(BalancerTestCase): ...@@ -729,8 +567,7 @@ class TestServerTLSProvidedCertificate(BalancerTestCase):
__partition_reference__ = 's' __partition_reference__ = 's'
@classmethod @classmethod
def _getInstanceParameterDict(cls): def _getInstanceParameterDict(cls) -> dict:
# type: () -> Dict
server_caucase = cls.getManagedResource('server_caucase', CaucaseService) server_caucase = cls.getManagedResource('server_caucase', CaucaseService)
server_certificate = cls.getManagedResource('server_certificate', CaucaseCertificate) server_certificate = cls.getManagedResource('server_certificate', CaucaseCertificate)
server_certificate.request(cls._ipv4_address, server_caucase) server_certificate.request(cls._ipv4_address, server_caucase)
...@@ -741,8 +578,7 @@ class TestServerTLSProvidedCertificate(BalancerTestCase): ...@@ -741,8 +578,7 @@ class TestServerTLSProvidedCertificate(BalancerTestCase):
parameter_dict['ssl']['key'] = f.read() parameter_dict['ssl']['key'] = f.read()
return parameter_dict return parameter_dict
def test_certificate_validates_with_provided_ca(self): def test_certificate_validates_with_provided_ca(self) -> None:
# type: () -> None
server_certificate = self.getManagedResource("server_certificate", CaucaseCertificate) server_certificate = self.getManagedResource("server_certificate", CaucaseCertificate)
requests.get(self.default_balancer_url, verify=server_certificate.ca_crt_file) requests.get(self.default_balancer_url, verify=server_certificate.ca_crt_file)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment