Commit 0c0fffa2 authored by Denis Bilenko's avatar Denis Bilenko

test_ssl.py: copy all new tests and updates from Python 2.7 (but make it pass...

test_ssl.py: copy all new tests and updates from Python 2.7 (but make it pass on older versions too)
parent 733aa47c
# Test the support for SSL and sockets
from gevent import monkey; monkey.patch_all(aggressive=True)
from gevent import monkey; monkey.patch_all()
import sys
import unittest
......@@ -7,15 +7,27 @@ import test_support
import asyncore
import socket
import select
import time
import gc
import os
import errno
import pprint
import urllib, urlparse
import traceback
import weakref
from BaseHTTPServer import HTTPServer
from SimpleHTTPServer import SimpleHTTPRequestHandler
try:
bytearray
except NameError:
bytearray = None
try:
memoryview
except NameError:
memoryview = None
# Optionally test SSL support, if we have it in the tested platform
skip_expected = False
try:
......@@ -24,55 +36,36 @@ except ImportError:
skip_expected = True
HOST = 'localhost'
CERTFILE = None
SVN_PYTHON_ORG_ROOT_CERT = None
CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir, "keycert.pem")
SVN_PYTHON_ORG_ROOT_CERT = os.path.join(os.path.dirname(__file__) or os.curdir, "https_svn_python_org_root.pem")
def handle_error(prefix):
exc_format = ' '.join(traceback.format_exception(*sys.exc_info()))
if test_support.verbose:
sys.stdout.write(prefix + exc_format)
def testSimpleSSLwrap(self):
class BasicTests(unittest.TestCase):
def test_sslwrap_simple(self):
# A crude test for the legacy API
try:
ssl.sslwrap_simple(socket.socket(socket.AF_INET))
except IOError, e:
if e.errno == 32: # broken pipe when ssl_sock.do_handshake(), this test doesn't care about that
if e[0] == 32: # broken pipe when ssl_sock.do_handshake(), this test doesn't care about that
pass
else:
raise
try:
ssl.sslwrap_simple(socket.socket(socket.AF_INET)._sock)
except IOError, e:
if e.errno == 32: # broken pipe when ssl_sock.do_handshake(), this test doesn't care about that
if e[0] == 32: # broken pipe when ssl_sock.do_handshake(), this test doesn't care about that
pass
else:
raise
class BasicTests(unittest.TestCase):
def testSSLconnect(self):
if not test_support.is_resource_enabled('network'):
return
s = ssl.wrap_socket(socket.socket(socket.AF_INET),
cert_reqs=ssl.CERT_NONE)
s.connect(("svn.python.org", 443))
c = s.getpeercert()
if c:
raise test_support.TestFailed("Peer cert %s shouldn't be here!")
s.close()
# this should fail because we have no verification certs
s = ssl.wrap_socket(socket.socket(socket.AF_INET),
cert_reqs=ssl.CERT_REQUIRED)
try:
try:
s.connect(("svn.python.org", 443))
except ssl.SSLError:
pass
finally:
s.close()
def testCrucialConstants(self):
def test_constants(self):
ssl.PROTOCOL_SSLv2
ssl.PROTOCOL_SSLv23
ssl.PROTOCOL_SSLv3
......@@ -81,7 +74,7 @@ class BasicTests(unittest.TestCase):
ssl.CERT_OPTIONAL
ssl.CERT_REQUIRED
def testRAND(self):
def test_random(self):
v = ssl.RAND_status()
if test_support.verbose:
sys.stdout.write("\n RAND_status is %d (%s)\n"
......@@ -95,7 +88,7 @@ class BasicTests(unittest.TestCase):
print "didn't raise TypeError"
ssl.RAND_add("this is a random string", 75.0)
def testParseCert(self):
def test_parse_cert(self):
# note that this uses an 'unofficial' function in _ssl.c,
# provided solely for this test, to exercise the certificate
# parsing code
......@@ -103,8 +96,7 @@ class BasicTests(unittest.TestCase):
if test_support.verbose:
sys.stdout.write("\n" + pprint.pformat(p) + "\n")
def testDERtoPEM(self):
def test_DER_to_PEM(self):
pem = open(SVN_PYTHON_ORG_ROOT_CERT, 'r').read()
d1 = ssl.PEM_cert_to_DER_cert(pem)
p2 = ssl.DER_cert_to_PEM_cert(d1)
......@@ -112,9 +104,80 @@ class BasicTests(unittest.TestCase):
if (d1 != d2):
raise test_support.TestFailed("PEM-to-DER or DER-to-PEM translation failed")
def _test_openssl_version(self):
n = ssl.OPENSSL_VERSION_NUMBER
t = ssl.OPENSSL_VERSION_INFO
s = ssl.OPENSSL_VERSION
self.assertIsInstance(n, (int, long))
self.assertIsInstance(t, tuple)
self.assertIsInstance(s, str)
# Some sanity checks follow
# >= 0.9
self.assertGreaterEqual(n, 0x900000)
# < 2.0
self.assertLess(n, 0x20000000)
major, minor, fix, patch, status = t
self.assertGreaterEqual(major, 0)
self.assertLess(major, 2)
self.assertGreaterEqual(minor, 0)
self.assertLess(minor, 256)
self.assertGreaterEqual(fix, 0)
self.assertLess(fix, 256)
self.assertGreaterEqual(patch, 0)
self.assertLessEqual(patch, 26)
self.assertGreaterEqual(status, 0)
self.assertLessEqual(status, 15)
# Version string as returned by OpenSSL, the format might change
self.assertTrue(s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, fix)),
(s, t))
def test_openssl_version(self):
try:
self._test_openssl_version()
except AttributeError:
if sys.version_info[:2] >= (2, 7):
raise
def _test_ciphers(self):
remote = ("svn.python.org", 443)
s = ssl.wrap_socket(socket.socket(socket.AF_INET),
cert_reqs=ssl.CERT_NONE, ciphers="ALL")
s.connect(remote)
s = ssl.wrap_socket(socket.socket(socket.AF_INET),
cert_reqs=ssl.CERT_NONE, ciphers="DEFAULT")
s.connect(remote)
# Error checking occurs when connecting, because the SSL context
# isn't created before.
s = ssl.wrap_socket(socket.socket(socket.AF_INET),
cert_reqs=ssl.CERT_NONE, ciphers="^$:,;?*'dorothyx")
try:
s.connect(remote)
except ssl.SSLError, ex:
if "No cipher can be selected" not in str(ex):
raise
def test_ciphers(self):
try:
self._test_ciphers()
except TypeError, ex:
if 'sslwrap() takes at most 7 arguments (8 given)' in str(ex) and sys.version_info[:2] <= (2, 6):
pass
else:
raise
def test_refcycle(self):
# Issue #7943: an SSL object doesn't create reference cycles with
# itself.
s = socket.socket(socket.AF_INET)
ss = ssl.wrap_socket(s)
wr = weakref.ref(ss)
del ss
self.assertEqual(wr(), None)
class NetworkedTests(unittest.TestCase):
def testConnect(self):
def test_connect(self):
s = ssl.wrap_socket(socket.socket(socket.AF_INET),
cert_reqs=ssl.CERT_NONE)
s.connect(("svn.python.org", 443))
......@@ -138,16 +201,33 @@ class NetworkedTests(unittest.TestCase):
s = ssl.wrap_socket(socket.socket(socket.AF_INET),
cert_reqs=ssl.CERT_REQUIRED,
ca_certs=SVN_PYTHON_ORG_ROOT_CERT)
try:
try:
s.connect(("svn.python.org", 443))
except ssl.SSLError, x:
raise test_support.TestFailed("Unexpected exception %s" % x)
finally:
s.close()
#@unittest.skipIf(os.name == "nt", "Can't use a socket as a file under Windows")
def test_makefile_close(self):
# Issue #5238: creating a file-like object with makefile() shouldn't
# delay closing the underlying "real socket" (here tested with its
# file descriptor, hence skipping the test under Windows).
ss = ssl.wrap_socket(socket.socket(socket.AF_INET))
ss.connect(("svn.python.org", 443))
fd = ss.fileno()
f = ss.makefile()
f.close()
# The fd is still open
os.read(fd, 0)
# Closing the SSL socket should close the fd too
ss.close()
gc.collect()
try:
os.read(fd, 0)
except OSError, ex:
if ex[0] != errno.EBADF:
raise
def testNonBlockingHandshake(self):
def test_non_blocking_handshake(self):
s = socket.socket(socket.AF_INET)
s.connect(("svn.python.org", 443))
s.setblocking(False)
......@@ -171,8 +251,7 @@ class NetworkedTests(unittest.TestCase):
if test_support.verbose:
sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count)
def testFetchServerCert(self):
def test_get_server_certificate(self):
pem = ssl.get_server_certificate(("svn.python.org", 443))
if not pem:
raise test_support.TestFailed("No server certificate on svn.python.org:443!")
......@@ -191,12 +270,38 @@ class NetworkedTests(unittest.TestCase):
if test_support.verbose:
sys.stdout.write("\nVerified certificate for svn.python.org:443 is\n%s\n" % pem)
def test_algorithms(self):
# Issue #8484: all algorithms should be available when verifying a
# certificate.
# SHA256 was added in OpenSSL 0.9.8
if not hasattr(ssl, 'OPENSSL_VERSION_INFO'):
return
if ssl.OPENSSL_VERSION_INFO < (0, 9, 8, 0, 15):
self.skipTest("SHA256 not available on %r" % ssl.OPENSSL_VERSION)
# NOTE: https://sha256.tbs-internet.com is another possible test host
remote = ("sha2.hboeck.de", 443)
sha256_cert = os.path.join(os.path.dirname(__file__), "sha256.pem")
s = ssl.wrap_socket(socket.socket(socket.AF_INET),
cert_reqs=ssl.CERT_REQUIRED,
ca_certs=sha256_cert,)
try:
s.connect(remote)
if test_support.verbose:
sys.stdout.write("\nCipher with %r is %r\n" %
(remote, s.cipher()))
sys.stdout.write("Certificate is:\n%s\n" %
pprint.pformat(s.getpeercert()))
finally:
s.close()
try:
import threading
except ImportError:
_have_threads = False
else:
if not hasattr(threading.Thread, 'daemon'):
threading.Thread.daemon = property(threading.Thread.isDaemon, threading.Thread.setDaemon)
_have_threads = True
......@@ -215,7 +320,7 @@ else:
self.sock.setblocking(1)
self.sslconn = None
threading.Thread.__init__(self)
self.setDaemon(True)
self.daemon = True
def show_conn_details(self):
if self.server.certreqs == ssl.CERT_REQUIRED:
......@@ -229,27 +334,25 @@ else:
if test_support.verbose and self.server.chatty:
sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n")
def wrap_conn (self):
def wrap_conn(self):
try:
self.sslconn = ssl.wrap_socket(self.sock, server_side=True,
certfile=self.server.certificate,
ssl_version=self.server.protocol,
ca_certs=self.server.cacerts,
cert_reqs=self.server.certreqs)
except:
cert_reqs=self.server.certreqs,
ciphers=self.server.ciphers)
except ssl.SSLError:
# XXX Various errors can have happened here, for example
# a mismatching protocol version, an invalid certificate,
# or a low-level bug. This should be made more discriminating.
if self.server.chatty:
handle_error("\n server: bad connection attempt from " +
str(self.sock.getpeername()) + ":\n")
self.close()
if not self.server.expect_bad_connects:
# here, we want to stop the server, because this shouldn't
# happen in the context of our test case
self.running = False
# normally, we'd just stop here, but for the test
# harness, we want to stop the server
self.server.stop()
return False
else:
return True
......@@ -271,7 +374,7 @@ else:
else:
self.sock._sock.close()
def run (self):
def run(self):
self.running = True
if not self.server.starttls_server:
if isinstance(self.sock, ssl.SSLSocket):
......@@ -320,13 +423,11 @@ else:
# normally, we'd just stop here, but for the test
# harness, we want to stop the server
self.server.stop()
except:
handle_error('')
def __init__(self, certificate, ssl_version=None,
certreqs=None, cacerts=None, expect_bad_connects=False,
certreqs=None, cacerts=None,
chatty=True, connectionchatty=False, starttls_server=False,
wrap_accepting_socket=False):
wrap_accepting_socket=False, ciphers=None):
if ssl_version is None:
ssl_version = ssl.PROTOCOL_TLSv1
......@@ -336,7 +437,7 @@ else:
self.protocol = ssl_version
self.certreqs = certreqs
self.cacerts = cacerts
self.expect_bad_connects = expect_bad_connects
self.ciphers = ciphers
self.chatty = chatty
self.connectionchatty = connectionchatty
self.starttls_server = starttls_server
......@@ -347,20 +448,21 @@ else:
certfile=self.certificate,
cert_reqs = self.certreqs,
ca_certs = self.cacerts,
ssl_version = self.protocol)
ssl_version = self.protocol,
ciphers = self.ciphers)
if test_support.verbose and self.chatty:
sys.stdout.write(' server: wrapped server socket as %s\n' % str(self.sock))
self.port = test_support.bind_port(self.sock)
self.active = False
threading.Thread.__init__(self)
self.setDaemon(False)
self.daemon = False
def start (self, flag=None):
def start(self, flag=None):
self.flag = flag
threading.Thread.start(self)
def run (self):
self.sock.settimeout(0.5)
def run(self):
self.sock.settimeout(0.05)
self.sock.listen(5)
self.active = True
if self.flag:
......@@ -378,25 +480,23 @@ else:
pass
except KeyboardInterrupt:
self.stop()
except:
if self.chatty:
handle_error("Test server failure:\n")
self.sock.close()
def stop (self):
def stop(self):
self.active = False
class AsyncoreEchoServer(threading.Thread):
class EchoServer (asyncore.dispatcher):
class EchoServer(asyncore.dispatcher):
class ConnectionHandler (asyncore.dispatcher_with_send):
class ConnectionHandler(asyncore.dispatcher_with_send):
def __init__(self, conn, certfile):
asyncore.dispatcher_with_send.__init__(self, conn)
self.socket = ssl.wrap_socket(conn, server_side=True,
certfile=certfile,
do_handshake_on_connect=True)
do_handshake_on_connect=False)
self._ssl_accepting = True
def readable(self):
if isinstance(self.socket, ssl.SSLSocket):
......@@ -404,8 +504,28 @@ else:
self.handle_read_event()
return True
def _do_ssl_handshake(self):
try:
self.socket.do_handshake()
except ssl.SSLError, err:
if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
ssl.SSL_ERROR_WANT_WRITE):
return
elif err.args[0] == ssl.SSL_ERROR_EOF:
return self.handle_close()
raise
except socket.error, err:
if err.args[0] == errno.ECONNABORTED:
return self.handle_close()
else:
self._ssl_accepting = False
def handle_read(self):
if self._ssl_accepting:
self._do_ssl_handshake()
else:
data = self.recv(1024)
if data and data.strip() != 'over':
self.send(data.lower())
def handle_close(self):
......@@ -438,26 +558,23 @@ else:
self.server = self.EchoServer(certfile)
self.port = self.server.port
threading.Thread.__init__(self)
self.setDaemon(True)
self.daemon = True
def __str__(self):
return "<%s %s>" % (self.__class__.__name__, self.server)
def start (self, flag=None):
def start(self, flag=None):
self.flag = flag
threading.Thread.start(self)
def run (self):
def run(self):
self.active = True
if self.flag:
self.flag.set()
while self.active:
try:
asyncore.loop(1)
except:
pass
asyncore.loop(0.05)
def stop (self):
def stop(self):
self.active = False
self.server.close()
......@@ -466,12 +583,9 @@ else:
class HTTPSServer(HTTPServer):
def __init__(self, server_address, RequestHandlerClass, certfile):
HTTPServer.__init__(self, server_address, RequestHandlerClass)
# we assume the certfile contains both private key and certificate
self.certfile = certfile
self.active = False
self.active_lock = threading.Lock()
self.allow_reuse_address = True
def __str__(self):
......@@ -480,64 +594,27 @@ else:
self.server_name,
self.server_port))
def get_request (self):
def get_request(self):
# override this to wrap socket with SSL
sock, addr = self.socket.accept()
sslconn = ssl.wrap_socket(sock, server_side=True,
certfile=self.certfile)
return sslconn, addr
# The methods overridden below this are mainly so that we
# can run it in a thread and be able to stop it from another
# You probably wouldn't need them in other uses.
def server_activate(self):
# We want to run this in a thread for testing purposes,
# so we override this to set timeout, so that we get
# a chance to stop the server
self.socket.settimeout(0.5)
HTTPServer.server_activate(self)
# for Python 2.5 and older:
def serve_forever(self):
# We want this to run in a thread, so we use a slightly
# modified version of "forever".
self.active = True
while 1:
def serve_forever(self, poll_interval):
try:
# We need to lock while handling the request.
# Another thread can close the socket after self.active
# has been checked and before the request is handled.
# This causes an exception when using the closed socket.
self.active_lock.acquire()
try:
if not self.active:
break
return HTTPServer.serve_forever(self, poll_interval)
except TypeError:
for _ in xrange(100):
self.handle_request()
finally:
self.active_lock.release()
except socket.timeout:
pass
except KeyboardInterrupt:
def shutdown(self):
self.server_close()
return
except:
sys.stdout.write(''.join(traceback.format_exception(*sys.exc_info())))
break
time.sleep(0.1)
def server_close(self):
# Again, we want this to run in a thread, so we need to override
# close to clear the "active" flag, so that serve_forever() will
# terminate.
self.active_lock.acquire()
try:
HTTPServer.server_close(self)
self.active = False
finally:
self.active_lock.release()
class RootedHTTPRequestHandler(SimpleHTTPRequestHandler):
# need to override translate_path to get a known root,
# instead of using os.curdir, since the test could be
# run from anywhere
......@@ -582,34 +659,34 @@ else:
def __init__(self, certfile):
self.flag = None
self.active = False
self.RootedHTTPRequestHandler.root = os.path.split(CERTFILE)[0]
self.port = test_support.find_unused_port()
self.server = self.HTTPSServer(
(HOST, self.port), self.RootedHTTPRequestHandler, certfile)
(HOST, 0), self.RootedHTTPRequestHandler, certfile)
self.port = self.server.server_port
threading.Thread.__init__(self)
self.setDaemon(True)
self.daemon = True
def __str__(self):
return "<%s %s>" % (self.__class__.__name__, self.server)
def start (self, flag=None):
def start(self, flag=None):
self.flag = flag
threading.Thread.start(self)
def run (self):
self.active = True
def run(self):
if self.flag:
self.flag.set()
self.server.serve_forever()
self.active = False
self.server.serve_forever(0.05)
def stop (self):
self.active = False
self.server.server_close()
def stop(self):
self.server.shutdown()
def badCertTest (certfile):
def bad_cert_test(certfile):
"""
Launch a server with CERT_REQUIRED, and check that trying to
connect to it with the given client certificate fails.
"""
server = ThreadedEchoServer(CERTFILE,
certreqs=ssl.CERT_REQUIRED,
cacerts=CERTFILE, chatty=False)
......@@ -631,21 +708,24 @@ else:
if test_support.verbose:
sys.stdout.write("\nsocket.error is %s\n" % x[1])
else:
raise test_support.TestFailed(
"Use of invalid cert should have failed!")
raise AssertionError("Use of invalid cert should have failed!")
finally:
server.stop()
server.join()
def serverParamsTest (certfile, protocol, certreqs, cacertsfile,
def server_params_test(certfile, protocol, certreqs, cacertsfile,
client_certfile, client_protocol=None, indata="FOO\n",
chatty=True, connectionchatty=False,
ciphers=None, chatty=True, connectionchatty=False,
wrap_accepting_socket=False):
"""
Launch a server, connect a client to it and try various reads
and writes.
"""
server = ThreadedEchoServer(certfile,
certreqs=certreqs,
ssl_version=protocol,
cacerts=cacertsfile,
ciphers=ciphers,
chatty=chatty,
connectionchatty=connectionchatty,
wrap_accepting_socket=wrap_accepting_socket)
......@@ -656,30 +736,32 @@ else:
# try to connect
if client_protocol is None:
client_protocol = protocol
try:
try:
s = ssl.wrap_socket(socket.socket(),
certfile=client_certfile,
ca_certs=cacertsfile,
ciphers=ciphers,
cert_reqs=certreqs,
ssl_version=client_protocol)
s.connect((HOST, server.port))
except ssl.SSLError, x:
raise test_support.TestFailed("Unexpected SSL error: " + str(x))
except Exception, x:
raise test_support.TestFailed("Unexpected exception: " + str(x))
else:
args = [indata]
if sys.version_info[:2] >= (2, 7):
# bytearray fails on Python2.6
args.append(bytearray(indata))
if memoryview is not None:
args.append(memoryview(indata))
for arg in args:
if connectionchatty:
if test_support.verbose:
sys.stdout.write(
" client: sending %s...\n" % (repr(indata)))
s.write(indata)
" client: sending %s...\n" % (repr(arg)))
s.write(arg)
outdata = s.read()
if connectionchatty:
if test_support.verbose:
sys.stdout.write(" client: read %s\n" % repr(outdata))
if outdata != indata.lower():
raise test_support.TestFailed(
raise AssertionError(
"bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
% (outdata[:min(len(outdata),20)], len(outdata),
indata[:min(len(indata),20)].lower(), len(indata)))
......@@ -692,35 +774,45 @@ else:
server.stop()
server.join()
def tryProtocolCombo (server_protocol,
def try_protocol_combo(server_protocol,
client_protocol,
expectedToWork,
expect_success,
certsreqs=None):
if certsreqs is None:
certsreqs = ssl.CERT_NONE
if certsreqs == ssl.CERT_NONE:
certtype = "CERT_NONE"
elif certsreqs == ssl.CERT_OPTIONAL:
certtype = "CERT_OPTIONAL"
elif certsreqs == ssl.CERT_REQUIRED:
certtype = "CERT_REQUIRED"
certtype = {
ssl.CERT_NONE: "CERT_NONE",
ssl.CERT_OPTIONAL: "CERT_OPTIONAL",
ssl.CERT_REQUIRED: "CERT_REQUIRED",
}[certsreqs]
if test_support.verbose:
formatstr = (expectedToWork and " %s->%s %s\n") or " {%s->%s} %s\n"
formatstr = (expect_success and " %s->%s %s\n") or " {%s->%s} %s\n"
sys.stdout.write(formatstr %
(ssl.get_protocol_name(client_protocol),
ssl.get_protocol_name(server_protocol),
certtype))
try:
serverParamsTest(CERTFILE, server_protocol, certsreqs,
CERTFILE, CERTFILE, client_protocol, chatty=False)
except test_support.TestFailed:
if expectedToWork:
# NOTE: we must enable "ALL" ciphers, otherwise an SSLv23 client
# will send an SSLv3 hello (rather than SSLv2) starting from
# OpenSSL 1.0.0 (see issue #8322).
if sys.version_info >= (2, 7):
ciphers = 'ALL'
else:
ciphers = None
server_params_test(CERTFILE, server_protocol, certsreqs,
CERTFILE, CERTFILE, client_protocol,
ciphers=ciphers, chatty=False)
# Protocol mismatch can result in either an SSLError, or a
# "Connection reset by peer" error.
except ssl.SSLError:
if expect_success:
raise
except socket.error, e:
if expect_success or e[0] != errno.ECONNRESET:
raise
else:
if not expectedToWork:
raise test_support.TestFailed(
if not expect_success:
raise AssertionError(
"Client protocol %s succeeded with server protocol %s!"
% (ssl.get_protocol_name(client_protocol),
ssl.get_protocol_name(server_protocol)))
......@@ -728,53 +820,61 @@ else:
class ThreadedTests(unittest.TestCase):
def testRudeShutdown(self):
def test_rude_shutdown(self):
"""A brutal shutdown of an SSL server should raise an IOError
in the client when attempting handshake.
"""
listener_ready = threading.Event()
listener_gone = threading.Event()
port = test_support.find_unused_port()
# `listener` runs in a thread. It opens a socket listening on
# PORT, and sits in an accept() until the main thread connects.
# Then it rudely closes the socket, and sets Event `listener_gone`
# to let the main thread know the socket is gone.
def listener():
s = socket.socket()
s.bind((HOST, port))
port = test_support.bind_port(s, HOST)
# `listener` runs in a thread. It sits in an accept() until
# the main thread connects. Then it rudely closes the socket,
# and sets Event `listener_gone` to let the main thread know
# the socket is gone.
def listener():
s.listen(5)
listener_ready.set()
s.accept()
s = None # reclaim the socket object, which also closes it
s.close()
listener_gone.set()
def connector():
listener_ready.wait()
s = socket.socket()
s.connect((HOST, port))
c = socket.socket()
c.connect((HOST, port))
listener_gone.wait()
try:
ssl_sock = ssl.wrap_socket(s)
except ssl.SSLError:
ssl_sock = ssl.wrap_socket(c)
except IOError:
pass
except ssl.SSLError:
# in pypi/ssl package (used on Python 2.5 and 2.4), ssl.SSLError is not a subclass of IOError
# so we accepting it here as well
if sys.version_info >= (2, 6):
raise
else:
raise test_support.TestFailed(
'connecting to closed SSL socket should have failed')
t = threading.Thread(target=listener)
t.start()
try:
connector()
finally:
t.join()
def testEcho (self):
def test_echo(self):
"""Basic test of an SSL client connecting to a server"""
if test_support.verbose:
sys.stdout.write("\n")
serverParamsTest(CERTFILE, ssl.PROTOCOL_TLSv1, ssl.CERT_NONE,
server_params_test(CERTFILE, ssl.PROTOCOL_TLSv1, ssl.CERT_NONE,
CERTFILE, CERTFILE, ssl.PROTOCOL_TLSv1,
chatty=True, connectionchatty=True)
def testReadCert(self):
def test_getpeercert(self):
if test_support.verbose:
sys.stdout.write("\n")
s2 = socket.socket()
......@@ -788,7 +888,6 @@ else:
# wait for it to start
flag.wait()
# try to connect
try:
try:
s = ssl.wrap_socket(socket.socket(),
certfile=CERTFILE,
......@@ -796,25 +895,13 @@ else:
cert_reqs=ssl.CERT_REQUIRED,
ssl_version=ssl.PROTOCOL_SSLv23)
s.connect((HOST, server.port))
except ssl.SSLError, x:
raise test_support.TestFailed(
"Unexpected SSL error: " + str(x))
except Exception, x:
raise test_support.TestFailed(
"Unexpected exception: " + str(x))
else:
if not s:
raise test_support.TestFailed(
"Can't SSL-handshake with test server")
cert = s.getpeercert()
if not cert:
raise test_support.TestFailed(
"Can't get peer certificate.")
self.assertTrue(cert, "Can't get peer certificate.")
cipher = s.cipher()
if test_support.verbose:
sys.stdout.write(pprint.pformat(cert) + '\n')
sys.stdout.write("Connection cipher is " + str(cipher) + '.\n')
if not cert.has_key('subject'):
if 'subject' not in cert:
raise test_support.TestFailed(
"No subject field in certificate: %s." %
pprint.pformat(cert))
......@@ -828,74 +915,82 @@ else:
server.stop()
server.join()
def testNULLcert(self):
badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir,
def test_empty_cert(self):
"""Connecting with an empty cert file"""
bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir,
"nullcert.pem"))
def testMalformedCert(self):
badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir,
def test_malformed_cert(self):
"""Connecting with a badly formatted certificate (syntax error)"""
bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir,
"badcert.pem"))
def testWrongCert(self):
badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir,
def test_nonexisting_cert(self):
"""Connecting with a non-existing cert file"""
bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir,
"wrongcert.pem"))
def testMalformedKey(self):
badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir,
def test_malformed_key(self):
"""Connecting with a badly formatted key (syntax error)"""
bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir,
"badkey.pem"))
def testProtocolSSL2(self):
def test_protocol_sslv2(self):
"""Connecting to an SSLv2 server with various client options"""
if test_support.verbose:
sys.stdout.write("\n")
tryProtocolCombo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True)
tryProtocolCombo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL)
tryProtocolCombo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED)
tryProtocolCombo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, True)
tryProtocolCombo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False)
tryProtocolCombo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False)
def testProtocolSSL23(self):
try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True)
try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL)
try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED)
try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, True)
try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False)
try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False)
def test_protocol_sslv23(self):
"""Connecting to an SSLv23 server with various client options"""
if test_support.verbose:
sys.stdout.write("\n")
try:
tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv2, True)
except test_support.TestFailed, x:
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv2, True)
except (ssl.SSLError, socket.error), x:
# this fails on some older versions of OpenSSL (0.9.7l, for instance)
if test_support.verbose:
sys.stdout.write(
" SSL2 client to SSL23 server test unexpectedly failed:\n %s\n"
% str(x))
tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True)
tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True)
tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True)
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True)
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True)
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True)
tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL)
tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL)
tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL)
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL)
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL)
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL)
tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED)
tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED)
tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED)
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED)
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED)
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED)
def testProtocolSSL3(self):
def test_protocol_sslv3(self):
"""Connecting to an SSLv3 server with various client options"""
if test_support.verbose:
sys.stdout.write("\n")
tryProtocolCombo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True)
tryProtocolCombo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL)
tryProtocolCombo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED)
tryProtocolCombo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False)
tryProtocolCombo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False)
tryProtocolCombo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False)
def testProtocolTLS1(self):
try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True)
try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL)
try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED)
try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False)
try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False)
try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False)
def test_protocol_tlsv1(self):
"""Connecting to a TLSv1 server with various client options"""
if test_support.verbose:
sys.stdout.write("\n")
tryProtocolCombo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True)
tryProtocolCombo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL)
tryProtocolCombo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED)
tryProtocolCombo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False)
tryProtocolCombo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False)
tryProtocolCombo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv23, False)
def testSTARTTLS (self):
try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True)
try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL)
try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED)
try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False)
try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False)
try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv23, False)
def test_starttls(self):
"""Switching from clear text to encrypted and back again."""
msgs = ("msg 1", "MSG 2", "STARTTLS", "MSG 3", "msg 4", "ENDTLS", "msg 5", "msg 6")
server = ThreadedEchoServer(CERTFILE,
......@@ -909,14 +1004,10 @@ else:
flag.wait()
# try to connect
wrapped = False
try:
try:
s = socket.socket()
s.setblocking(1)
s.connect((HOST, server.port))
except Exception, x:
raise test_support.TestFailed("Unexpected exception: " + str(x))
else:
if test_support.verbose:
sys.stdout.write("\n")
for indata in msgs:
......@@ -931,6 +1022,7 @@ else:
outdata = s.recv(1024)
if (indata == "STARTTLS" and
outdata.strip().lower().startswith("ok")):
# STARTTLS ok, switch to secure mode
if test_support.verbose:
sys.stdout.write(
" client: read %s from server, starting TLS...\n"
......@@ -939,6 +1031,7 @@ else:
wrapped = True
elif (indata == "ENDTLS" and
outdata.strip().lower().startswith("ok")):
# ENDTLS ok, switch back to clear text
if test_support.verbose:
sys.stdout.write(
" client: read %s from server, ending TLS...\n"
......@@ -960,15 +1053,14 @@ else:
server.stop()
server.join()
def testSocketServer(self):
def test_socketserver(self):
"""Using a SocketServer to create and manage SSL connections."""
server = SocketServerHTTPSServer(CERTFILE)
flag = threading.Event()
server.start(flag)
# wait for it to start
flag.wait()
# try to connect
try:
try:
if test_support.verbose:
sys.stdout.write('\n')
......@@ -986,31 +1078,22 @@ else:
" client: read %d bytes from remote server '%s'\n"
% (len(d2), server))
f.close()
except:
msg = ''.join(traceback.format_exception(*sys.exc_info()))
if test_support.verbose:
sys.stdout.write('\n' + msg)
raise test_support.TestFailed(msg)
else:
if not (d1 == d2):
raise test_support.TestFailed(
"Couldn't fetch data from HTTPS server")
self.assertEqual(d1, d2)
finally:
server.stop()
server.join()
def testWrappedAccept (self):
def test_wrapped_accept(self):
"""Check the accept() method on SSL sockets."""
if test_support.verbose:
sys.stdout.write("\n")
serverParamsTest(CERTFILE, ssl.PROTOCOL_SSLv23, ssl.CERT_REQUIRED,
server_params_test(CERTFILE, ssl.PROTOCOL_SSLv23, ssl.CERT_REQUIRED,
CERTFILE, CERTFILE, ssl.PROTOCOL_SSLv23,
chatty=True, connectionchatty=True,
wrap_accepting_socket=True)
def testAsyncoreServer (self):
def test_asyncore_server(self):
"""Check the example asyncore integration."""
indata = "TEST MESSAGE of mixed case\n"
if test_support.verbose:
......@@ -1021,15 +1104,9 @@ else:
# wait for it to start
flag.wait()
# try to connect
try:
try:
s = ssl.wrap_socket(socket.socket())
s.connect(('127.0.0.1', server.port))
except ssl.SSLError, x:
raise test_support.TestFailed("Unexpected SSL error: " + str(x))
except Exception, x:
raise test_support.TestFailed("Unexpected exception: " + str(x))
else:
if test_support.verbose:
sys.stdout.write(
" client: sending %s...\n" % (repr(indata)))
......@@ -1051,9 +1128,8 @@ else:
# wait for server thread to end
server.join()
def testAllRecvAndSendMethods(self):
def test_recv_send(self):
"""Test recv(), send() and friends."""
if test_support.verbose:
sys.stdout.write("\n")
......@@ -1068,8 +1144,6 @@ else:
# wait for it to start
flag.wait()
# try to connect
try:
try:
s = ssl.wrap_socket(socket.socket(),
server_side=False,
certfile=CERTFILE,
......@@ -1077,15 +1151,7 @@ else:
cert_reqs=ssl.CERT_NONE,
ssl_version=ssl.PROTOCOL_TLSv1)
s.connect((HOST, server.port))
except ssl.SSLError, x:
raise test_support.TestFailed("Unexpected SSL error: " + str(x))
except Exception, x:
raise test_support.TestFailed("Unexpected exception: " + str(x))
else:
try:
bytearray
except NameError:
bytearray = None
# helper methods for standardising recv* method signatures
def _recv_into():
b = bytearray("\0"*100)
......@@ -1179,40 +1245,59 @@ else:
server.stop()
server.join()
def test_handshake_timeout(self):
# Issue #5103: SSL handshake must respect the socket timeout
server = socket.socket(socket.AF_INET)
host = "127.0.0.1"
port = test_support.bind_port(server)
started = threading.Event()
finish = False
def serve():
server.listen(5)
started.set()
conns = []
while not finish:
r, w, e = select.select([server], [], [], 0.1)
if server in r:
# Let the socket hang around rather than having
# it closed by garbage collection.
conns.append(server.accept()[0])
t = threading.Thread(target=serve)
t.start()
started.wait()
def test_main(verbose=False):
if skip_expected:
raise test_support.TestSkipped("No SSL support")
global CERTFILE, SVN_PYTHON_ORG_ROOT_CERT
CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir,
"keycert.pem")
SVN_PYTHON_ORG_ROOT_CERT = os.path.join(
os.path.dirname(__file__) or os.curdir,
"https_svn_python_org_root.pem")
if (not os.path.exists(CERTFILE) or
not os.path.exists(SVN_PYTHON_ORG_ROOT_CERT)):
raise test_support.TestFailed("Can't read certificate files!")
TESTPORT = test_support.find_unused_port()
if not TESTPORT:
raise test_support.TestFailed("Can't find open port to test servers on!")
tests = [BasicTests]
#if test_support.is_resource_enabled('network'):
tests.append(NetworkedTests)
#if _have_threads:
# thread_info = test_support.threading_setup()
# if thread_info and test_support.is_resource_enabled('network'):
tests.append(ThreadedTests)
test_support.run_unittest(*tests)
try:
try:
c = socket.socket(socket.AF_INET)
c.settimeout(0.2)
c.connect((host, port))
# Will attempt handshake and time out
try:
ssl.wrap_socket(c)
except ssl.SSLError, ex:
if 'timed out' not in str(ex):
raise
finally:
c.close()
try:
c = socket.socket(socket.AF_INET)
c.settimeout(0.2)
c = ssl.wrap_socket(c)
# Will attempt handshake and time out
try:
c.connect((host, port))
except ssl.SSLError, ex:
if 'timed out' not in str(ex):
raise
finally:
c.close()
finally:
finish = True
t.join()
server.close()
#if _have_threads:
# test_support.threading_cleanup(*thread_info)
if __name__ == "__main__":
test_main()
unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment