Commit 784e3789 authored by Jason Madden's avatar Jason Madden

Refactor test__server[_pywsgi].py for finer control. Adjust the timeouts for pypy/libuv.

parent ea983355
...@@ -26,7 +26,6 @@ import sys ...@@ -26,7 +26,6 @@ import sys
import types import types
import unittest import unittest
from unittest import TestCase as BaseTestCase from unittest import TestCase as BaseTestCase
from unittest.util import safe_repr
import time import time
import os import os
from os.path import basename, splitext from os.path import basename, splitext
...@@ -153,7 +152,7 @@ else: ...@@ -153,7 +152,7 @@ else:
skipIf = unittest.skipIf skipIf = unittest.skipIf
EXPECT_POOR_TIMER_RESOLUTION = PYPY3 or RUNNING_ON_APPVEYOR EXPECT_POOR_TIMER_RESOLUTION = PYPY3 or RUNNING_ON_APPVEYOR or (LIBUV and PYPY)
skipOnLibuv = _do_not_skip skipOnLibuv = _do_not_skip
skipOnLibuvOnCI = _do_not_skip skipOnLibuvOnCI = _do_not_skip
...@@ -167,6 +166,18 @@ if LIBUV: ...@@ -167,6 +166,18 @@ if LIBUV:
if PYPY: if PYPY:
skipOnLibuvOnCIOnPyPy = unittest.skip skipOnLibuvOnCIOnPyPy = unittest.skip
CONN_ABORTED_ERRORS = []
try:
from errno import WSAECONNABORTED
CONN_ABORTED_ERRORS.append(WSAECONNABORTED)
except ImportError:
pass
from errno import ECONNRESET
CONN_ABORTED_ERRORS.append(ECONNRESET)
CONN_ABORTED_ERRORS = frozenset(CONN_ABORTED_ERRORS)
class ExpectedException(Exception): class ExpectedException(Exception):
"""An exception whose traceback should be ignored by the hub""" """An exception whose traceback should be ignored by the hub"""
...@@ -622,7 +633,7 @@ class TestCase(TestCaseMetaClass("NewBase", (BaseTestCase,), {})): ...@@ -622,7 +633,7 @@ class TestCase(TestCaseMetaClass("NewBase", (BaseTestCase,), {})):
def assertMonkeyPatchedFuncSignatures(self, mod_name, func_names=(), exclude=()): def assertMonkeyPatchedFuncSignatures(self, mod_name, func_names=(), exclude=()):
# We use inspect.getargspec because it's the only thing available # We use inspect.getargspec because it's the only thing available
# in Python 2.7, but it is deprecated # in Python 2.7, but it is deprecated
# pylint:disable=deprecated-method # pylint:disable=deprecated-method,too-many-locals
import inspect import inspect
import warnings import warnings
from gevent.monkey import get_original from gevent.monkey import get_original
...@@ -738,8 +749,7 @@ class _DelayWaitMixin(object): ...@@ -738,8 +749,7 @@ class _DelayWaitMixin(object):
timeout = gevent.Timeout.start_new(0.001, ref=False) timeout = gevent.Timeout.start_new(0.001, ref=False)
try: try:
with self.assertRaises(gevent.Timeout) as exc: with self.assertRaises(gevent.Timeout) as exc:
result = self.wait(timeout=1) self.wait(timeout=1)
self.assertIs(exc.exception, timeout) self.assertIs(exc.exception, timeout)
finally: finally:
timeout.cancel() timeout.cancel()
...@@ -760,7 +770,7 @@ class GenericWaitTestCase(_DelayWaitMixin, TestCase): ...@@ -760,7 +770,7 @@ class GenericWaitTestCase(_DelayWaitMixin, TestCase):
def test_returns_none_after_timeout(self): def test_returns_none_after_timeout(self):
result = self._wait_and_check() result = self._wait_and_check()
# join and wait simply return after timeout expires # join and wait simply return after timeout expires
assert result is None, repr(result) self.assertIsNone(result)
class GenericGetTestCase(_DelayWaitMixin, TestCase): class GenericGetTestCase(_DelayWaitMixin, TestCase):
...@@ -787,10 +797,10 @@ class GenericGetTestCase(_DelayWaitMixin, TestCase): ...@@ -787,10 +797,10 @@ class GenericGetTestCase(_DelayWaitMixin, TestCase):
def test_raises_timeout_Timeout_exc_customized(self): def test_raises_timeout_Timeout_exc_customized(self):
error = RuntimeError('expected error') error = RuntimeError('expected error')
timeout = gevent.Timeout(self._default_wait_timeout, exception=error) timeout = gevent.Timeout(self._default_wait_timeout, exception=error)
try: with self.assertRaises(RuntimeError) as exc:
self._wait_and_check(timeout=timeout) self._wait_and_check(timeout=timeout)
except RuntimeError as ex:
self.assertIs(ex, error) self.assertIs(exc.exception, error)
self.cleanup() self.cleanup()
......
...@@ -52,21 +52,12 @@ from gevent.pywsgi import Input ...@@ -52,21 +52,12 @@ from gevent.pywsgi import Input
CONTENT_LENGTH = 'Content-Length' CONTENT_LENGTH = 'Content-Length'
CONN_ABORTED_ERRORS = [] CONN_ABORTED_ERRORS = greentest.CONN_ABORTED_ERRORS
server_implements_chunked = True server_implements_chunked = True
server_implements_pipeline = True server_implements_pipeline = True
server_implements_100continue = True server_implements_100continue = True
DEBUG = '-v' in sys.argv DEBUG = '-v' in sys.argv
try:
from errno import WSAECONNABORTED
CONN_ABORTED_ERRORS.append(WSAECONNABORTED)
except ImportError:
pass
from errno import ECONNRESET
CONN_ABORTED_ERRORS.append(ECONNRESET)
REASONS = {200: 'OK', REASONS = {200: 'OK',
500: 'Internal Server Error'} 500: 'Internal Server Error'}
......
from __future__ import print_function from __future__ import print_function
import unittest
import errno
import os
import greentest import greentest
from greentest import PY3 from greentest import PY3
from gevent import socket from gevent import socket
import gevent import gevent
from gevent.server import StreamServer from gevent.server import StreamServer
import errno
import os
# Timeouts very flaky on appveyor and PyPy3 # Timeouts very flaky on appveyor and PyPy3
_DEFAULT_SOCKET_TIMEOUT = 0.1 if not greentest.EXPECT_POOR_TIMER_RESOLUTION else 1.0 _DEFAULT_SOCKET_TIMEOUT = 0.1 if not greentest.EXPECT_POOR_TIMER_RESOLUTION else 1.0
...@@ -13,14 +16,14 @@ _DEFAULT_SOCKET_TIMEOUT = 0.1 if not greentest.EXPECT_POOR_TIMER_RESOLUTION else ...@@ -13,14 +16,14 @@ _DEFAULT_SOCKET_TIMEOUT = 0.1 if not greentest.EXPECT_POOR_TIMER_RESOLUTION else
class SimpleStreamServer(StreamServer): class SimpleStreamServer(StreamServer):
def handle(self, client_socket, address): def handle(self, client_socket, _address):
fd = client_socket.makefile() fd = client_socket.makefile()
try: try:
request_line = fd.readline() request_line = fd.readline()
if not request_line: if not request_line:
return return
try: try:
method, path, rest = request_line.split(' ', 3) _method, path, _rest = request_line.split(' ', 3)
except Exception: except Exception:
print('Failed to parse request line: %r' % (request_line, )) print('Failed to parse request line: %r' % (request_line, ))
raise raise
...@@ -38,39 +41,42 @@ class SimpleStreamServer(StreamServer): ...@@ -38,39 +41,42 @@ class SimpleStreamServer(StreamServer):
fd.close() fd.close()
class Settings: class _Settings(object):
ServerClass = StreamServer ServerClass = StreamServer
ServerSubClass = SimpleStreamServer ServerSubClass = SimpleStreamServer
restartable = True restartable = True
close_socket_detected = True close_socket_detected = True
@staticmethod @staticmethod
def assertAcceptedConnectionError(self): def assertAcceptedConnectionError(inst):
conn = self.makefile() conn = inst.makefile()
result = conn.read() result = conn.read()
assert not result, repr(result) inst.assertFalse(result)
assert500 = assertAcceptedConnectionError assert500 = assertAcceptedConnectionError
@staticmethod @staticmethod
def assert503(self): def assert503(inst):
# regular reads timeout # regular reads timeout
self.assert500() inst.assert500()
# attempt to send anything reset the connection # attempt to send anything reset the connection
try: try:
self.send_request() inst.send_request()
except socket.error as ex: except socket.error as ex:
if ex.args[0] != errno.ECONNRESET: if ex.args[0] not in greentest.CONN_ABORTED_ERRORS:
raise raise
@staticmethod @staticmethod
def assertPoolFull(self): def assertPoolFull(inst):
self.assertRaises(socket.timeout, self.assertRequestSucceeded, timeout=0.01) with inst.assertRaises(socket.timeout):
inst.assertRequestSucceeded(timeout=0.01)
class TestCase(greentest.TestCase): class TestCase(greentest.TestCase):
__timeout__ = greentest.LARGE_TIMEOUT __timeout__ = greentest.LARGE_TIMEOUT
Settings = _Settings
server = None
def cleanup(self): def cleanup(self):
if getattr(self, 'server', None) is not None: if getattr(self, 'server', None) is not None:
...@@ -130,27 +136,24 @@ class TestCase(greentest.TestCase): ...@@ -130,27 +136,24 @@ class TestCase(greentest.TestCase):
return conn return conn
def assertConnectionRefused(self): def assertConnectionRefused(self):
try: with self.assertRaises(socket.error) as exc:
conn = self.makefile() conn = self.makefile()
try:
raise AssertionError('Connection was not refused: %r' % (conn._sock, ))
finally:
conn.close() conn.close()
except socket.error as ex:
if ex.args[0] not in (errno.ECONNREFUSED, errno.EADDRNOTAVAIL): ex = exc.exception
raise self.assertIn(ex.args[0], (errno.ECONNREFUSED, errno.EADDRNOTAVAIL))
def assert500(self): def assert500(self):
Settings.assert500(self) self.Settings.assert500(self)
def assert503(self): def assert503(self):
Settings.assert503(self) self.Settings.assert503(self)
def assertAcceptedConnectionError(self): def assertAcceptedConnectionError(self):
Settings.assertAcceptedConnectionError(self) self.Settings.assertAcceptedConnectionError(self)
def assertPoolFull(self): def assertPoolFull(self):
Settings.assertPoolFull(self) self.Settings.assertPoolFull(self)
def assertNotAccepted(self): def assertNotAccepted(self):
conn = self.makefile() conn = self.makefile()
...@@ -185,11 +188,10 @@ class TestCase(greentest.TestCase): ...@@ -185,11 +188,10 @@ class TestCase(greentest.TestCase):
self.server.stop() self.server.stop()
self.assertConnectionRefused() self.assertConnectionRefused()
def report_netstat(self, msg): def report_netstat(self, _msg):
if 0: # At one point this would call 'sudo netstat -anp | grep PID'
print(msg) # with os.system. We can probably do better with psutil.
os.system('sudo netstat -anp | grep %s' % os.getpid()) return
print('^^^^^')
def _create_server(self): def _create_server(self):
return self.ServerSubClass((greentest.DEFAULT_BIND_ADDR, 0)) return self.ServerSubClass((greentest.DEFAULT_BIND_ADDR, 0))
...@@ -221,12 +223,14 @@ class TestCase(greentest.TestCase): ...@@ -221,12 +223,14 @@ class TestCase(greentest.TestCase):
def ServerClass(self, *args, **kwargs): def ServerClass(self, *args, **kwargs):
kwargs.setdefault('spawn', self.get_spawn()) kwargs.setdefault('spawn', self.get_spawn())
return Settings.ServerClass(*args, **kwargs) return self.Settings.ServerClass(*args, **kwargs)
def ServerSubClass(self, *args, **kwargs): def ServerSubClass(self, *args, **kwargs):
kwargs.setdefault('spawn', self.get_spawn()) kwargs.setdefault('spawn', self.get_spawn())
return Settings.ServerSubClass(*args, **kwargs) return self.Settings.ServerSubClass(*args, **kwargs)
def get_spawn(self):
return None
class TestDefaultSpawn(TestCase): class TestDefaultSpawn(TestCase):
...@@ -237,7 +241,7 @@ class TestDefaultSpawn(TestCase): ...@@ -237,7 +241,7 @@ class TestDefaultSpawn(TestCase):
self.report_netstat('before start') self.report_netstat('before start')
self.start_server() self.start_server()
self.report_netstat('after start') self.report_netstat('after start')
if restartable and Settings.restartable: if restartable and self.Settings.restartable:
self.server.stop_accepting() self.server.stop_accepting()
self.report_netstat('after stop_accepting') self.report_netstat('after stop_accepting')
self.assertNotAccepted() self.assertNotAccepted()
...@@ -341,13 +345,6 @@ class TestDefaultSpawn(TestCase): ...@@ -341,13 +345,6 @@ class TestDefaultSpawn(TestCase):
self.expect_one_error() self.expect_one_error()
self.assertAcceptedConnectionError() self.assertAcceptedConnectionError()
self.assert_error(ExpectedError, error) self.assert_error(ExpectedError, error)
return
if Settings.restartable:
assert not self.server.started
else:
assert self.server.started
gevent.sleep(0.1)
assert self.server.started
def test_server_repr_when_handle_is_instancemethod(self): def test_server_repr_when_handle_is_instancemethod(self):
# PR 501 # PR 501
...@@ -424,9 +421,9 @@ class TestNoneSpawn(TestCase): ...@@ -424,9 +421,9 @@ class TestNoneSpawn(TestCase):
self._test_invalid_callback() self._test_invalid_callback()
def test_assertion_in_blocking_func(self): def test_assertion_in_blocking_func(self):
def sleep(*args): def sleep(*_args):
gevent.sleep(0) gevent.sleep(0)
self.server = Settings.ServerClass((greentest.DEFAULT_BIND_ADDR, 0), sleep, spawn=None) self.server = self.Settings.ServerClass((greentest.DEFAULT_BIND_ADDR, 0), sleep, spawn=None)
self.server.start() self.server.start()
self.expect_one_error() self.expect_one_error()
self.assert500() self.assert500()
...@@ -437,32 +434,28 @@ class ExpectedError(Exception): ...@@ -437,32 +434,28 @@ class ExpectedError(Exception):
pass pass
if hasattr(socket, 'ssl'):
class TestSSLSocketNotAllowed(TestCase): class TestSSLSocketNotAllowed(TestCase):
switch_expected = False switch_expected = False
def get_spawn(self): def get_spawn(self):
return gevent.spawn return gevent.spawn
@unittest.skipUnless(hasattr(socket, 'ssl'), "Uses socket.ssl")
def test(self): def test(self):
from gevent.socket import ssl, socket from gevent.socket import ssl
listener = socket() from gevent.socket import socket as gsocket
listener = gsocket()
listener.bind(('0.0.0.0', 0)) listener.bind(('0.0.0.0', 0))
listener.listen(5) listener.listen(5)
listener = ssl(listener) listener = ssl(listener)
self.assertRaises(TypeError, self.ServerSubClass, listener) self.assertRaises(TypeError, self.ServerSubClass, listener)
try: def _file(name, here=os.path.dirname(__file__)):
__import__('ssl')
except ImportError:
pass
else:
def _file(name, here=os.path.dirname(__file__)):
return os.path.abspath(os.path.join(here, name)) return os.path.abspath(os.path.join(here, name))
class TestSSLGetCertificate(TestCase): class TestSSLGetCertificate(TestCase):
def _create_server(self): def _create_server(self):
return self.ServerSubClass((greentest.DEFAULT_BIND_ADDR, 0), return self.ServerSubClass((greentest.DEFAULT_BIND_ADDR, 0),
......
import socket
import greentest import greentest
import gevent import gevent
from gevent import pywsgi from gevent import pywsgi
import test__server import test__server
from test__server import *
from test__server import Settings as server_Settings
def application(self, environ, start_response): def application(environ, start_response):
if environ['PATH_INFO'] == '/': if environ['PATH_INFO'] == '/':
start_response("200 OK", []) start_response("200 OK", [])
return [b"PONG"] return [b"PONG"]
if environ['PATH_INFO'] == '/ping': if environ['PATH_INFO'] == '/ping':
start_response("200 OK", []) start_response("200 OK", [])
return [b"PONG"] return [b"PONG"]
elif environ['PATH_INFO'] == '/short': if environ['PATH_INFO'] == '/short':
gevent.sleep(0.5) gevent.sleep(0.5)
start_response("200 OK", []) start_response("200 OK", [])
return [] return []
elif environ['PATH_INFO'] == '/long': if environ['PATH_INFO'] == '/long':
gevent.sleep(10) gevent.sleep(10)
start_response("200 OK", []) start_response("200 OK", [])
return [] return []
else:
start_response("404 pywsgi WTF?", []) start_response("404 pywsgi WTF?", [])
return [] return []
class SimpleWSGIServer(pywsgi.WSGIServer): class SimpleWSGIServer(pywsgi.WSGIServer):
application = application application = staticmethod(application)
internal_error_start = b'HTTP/1.1 500 Internal Server Error\n'.replace(b'\n', b'\r\n') internal_error_start = b'HTTP/1.1 500 Internal Server Error\n'.replace(b'\n', b'\r\n')
...@@ -41,7 +42,7 @@ Content-length: 31 ...@@ -41,7 +42,7 @@ Content-length: 31
Service Temporarily Unavailable'''.replace(b'\n', b'\r\n') Service Temporarily Unavailable'''.replace(b'\n', b'\r\n')
class Settings: class Settings(object):
ServerClass = pywsgi.WSGIServer ServerClass = pywsgi.WSGIServer
ServerSubClass = SimpleWSGIServer ServerSubClass = SimpleWSGIServer
close_socket_detected = True close_socket_detected = True
...@@ -49,36 +50,51 @@ class Settings: ...@@ -49,36 +50,51 @@ class Settings:
close_socket_detected = False close_socket_detected = False
@staticmethod @staticmethod
def assert500(self): def assert500(inst):
conn = self.makefile() conn = inst.makefile()
conn.write(b'GET / HTTP/1.0\r\n\r\n') conn.write(b'GET / HTTP/1.0\r\n\r\n')
result = conn.read() result = conn.read()
assert result.startswith(internal_error_start), (result, internal_error_start) inst.assertTrue(result.startswith(internal_error_start),
assert result.endswith(internal_error_end), (result, internal_error_end) (result, internal_error_start))
inst.assertTrue(result.endswith(internal_error_end),
assertAcceptedConnectionError = assert500 (result, internal_error_end))
@staticmethod @staticmethod
def assert503(self): def assert503(inst):
conn = self.makefile() conn = inst.makefile()
conn.write(b'GET / HTTP/1.0\r\n\r\n') conn.write(b'GET / HTTP/1.0\r\n\r\n')
result = conn.read() result = conn.read()
assert result == internal_error503, (result, internal_error503) inst.assertEqual(result, internal_error503)
@staticmethod @staticmethod
def assertPoolFull(self): def assertPoolFull(inst):
self.assertRaises(socket.timeout, self.assertRequestSucceeded) with inst.assertRaises(socket.timeout):
inst.assertRequestSucceeded()
@staticmethod @staticmethod
def assertAcceptedConnectionError(self): def assertAcceptedConnectionError(inst):
conn = self.makefile() conn = inst.makefile()
result = conn.read() result = conn.read()
assert not result, repr(result) inst.assertFalse(result)
class TestCase(test__server.TestCase):
Settings = Settings
class TestDefaultSpawn(test__server.TestDefaultSpawn):
Settings = Settings
class TestSSLSocketNotAllowed(test__server.TestSSLSocketNotAllowed):
Settings = Settings
class TestRawSpawn(test__server.TestRawSpawn):
Settings = Settings
test__server.Settings = Settings class TestSSLGetCertificate(test__server.TestSSLGetCertificate):
Settings = Settings
del TestNoneSpawn class TestPoolSpawn(test__server.TestPoolSpawn):
Settings = Settings
if __name__ == '__main__': if __name__ == '__main__':
greentest.main() greentest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment