Commit 61399d52 authored by Jason Madden's avatar Jason Madden

Fix #1044 by always closing opened sockets before raising

Also enable ResoureWarnings by default in the test suite and fix a
bunch that showed up.
parent be540604
......@@ -90,6 +90,12 @@
- ``socket.send()`` now catches ``EPROTYPE`` on macOS to handle a race
condition during shutdown. Fixed in :pr:`1035` by Jay Oster.
- :func:`gevent.socket.create_connection` now properly cleans up open
sockets if connecting or binding raises a :exc:`BaseException` like
:exc:`KeyboardInterrupt`, :exc:`greenlet.GreenletExit` or
:exc:`gevent.timeout.Timeout`. Reported in :issue:`1044` by
kochelmonster.
- Update c-ares to 1.13.0. See :issue:`990`.
1.2.2 (2017-06-05)
......
......@@ -716,8 +716,13 @@ def main():
def _get_script_help():
from inspect import getargspec
patch_all_args = getargspec(patch_all)[0] # pylint:disable=deprecated-method
# pylint:disable=deprecated-method
import inspect
try:
getter = inspect.getfullargspec # deprecated in 3.5, un-deprecated in 3.6
except AttributeError:
getter = inspect.getargspec
patch_all_args = getter(patch_all)[0]
modules = [x for x in patch_all_args if 'patch_' + x in globals()]
script_help = """gevent.monkey - monkey patch the standard modules to use gevent.
......
......@@ -74,7 +74,13 @@ def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT, source_address=N
host, port = address
err = None
for res in getaddrinfo(host, port, 0 if has_ipv6 else AF_INET, SOCK_STREAM):
# getaddrinfo is documented as returning a list, but our interface
# is pluggable, so be sure it does.
addrs = list(getaddrinfo(host, port, 0 if has_ipv6 else AF_INET, SOCK_STREAM))
if not addrs:
raise error("getaddrinfo returns an empty list")
for res in addrs:
af, socktype, proto, _, sa = res
sock = None
try:
......@@ -84,24 +90,34 @@ def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT, source_address=N
if source_address:
sock.bind(source_address)
sock.connect(sa)
return sock
except error as ex:
# without exc_clear(), if connect() fails once, the socket is referenced by the frame in exc_info
# and the next bind() fails (see test__socket.TestCreateConnection)
# that does not happen with regular sockets though, because _socket.socket.connect() is a built-in.
# this is similar to "getnameinfo loses a reference" failure in test_socket.py
if not PY3:
sys.exc_clear() # pylint:disable=no-member,useless-suppression
except error:
if sock is not None:
sock.close()
err = ex
if err is not None:
if res is addrs[-1]:
raise
# without exc_clear(), if connect() fails once, the socket
# is referenced by the frame in exc_info and the next
# bind() fails (see test__socket.TestCreateConnection)
# that does not happen with regular sockets though,
# because _socket.socket.connect() is a built-in. this is
# similar to "getnameinfo loses a reference" failure in
# test_socket.py
try:
raise err # pylint:disable=raising-bad-type
finally:
err = None
c = sys.exc_clear
except AttributeError:
pass # Python 3 doesn't have this
else:
raise error("getaddrinfo returns an empty list")
c()
except BaseException:
# Things like GreenletExit, Timeout and KeyboardInterrupt.
# These get raised immediately, being sure to
# close the socket
if sock is not None:
sock.close()
raise
else:
return sock
# This is promised to be in the __all__ of the _source, but, for circularity reasons,
# we implement it in this module. Mostly for documentation purposes, put it
......
......@@ -430,6 +430,10 @@ class TestCase(TestCaseMetaClass("NewBase", (BaseTestCase,), {})):
super(TestCase, cls).tearDownClass()
def _close_on_teardown(self, resource):
"""
*resource* either has a ``close`` method, or is a
callable.
"""
if 'close_on_teardown' not in self.__dict__:
self.close_on_teardown = []
self.close_on_teardown.append(resource)
......
......@@ -47,17 +47,18 @@ def TESTRUNNER(tests=None):
if tests:
atexit.register(os.system, 'rm -f */@test*')
basic_args = [sys.executable, '-u', '-W', 'ignore', '-m' 'monkey_test']
for filename in tests:
if filename in version_tests:
util.log("Overriding %s from %s with file from %s", filename, directory, full_directory)
continue
yield [sys.executable, '-u', '-m', 'monkey_test', filename], options.copy()
yield [sys.executable, '-u', '-m', 'monkey_test', '--Event', filename], options.copy()
yield basic_args + [filename], options.copy()
yield basic_args + ['--Event', filename], options.copy()
options['cwd'] = full_directory
for filename in version_tests:
yield [sys.executable, '-u', '-m', 'monkey_test', filename], options.copy()
yield [sys.executable, '-u', '-m', 'monkey_test', '--Event', filename], options.copy()
yield basic_args + [filename], options.copy()
yield basic_args + ['--Event', filename], options.copy()
def main():
......
......@@ -17,7 +17,8 @@ class Test_udp_client(TestCase):
server = DatagramServer('127.0.0.1:9000', handle)
server.start()
try:
run([sys.executable, '-u', 'udp_client.py', 'Test_udp_client'], timeout=10, cwd='../../examples/')
run([sys.executable, '-W', 'ignore' '-u', 'udp_client.py', 'Test_udp_client'],
timeout=10, cwd='../../examples/')
finally:
server.close()
self.assertEqual(log, [b'Test_udp_client'])
......
......@@ -18,6 +18,7 @@
# THE SOFTWARE.
from greentest import TestCase, main, tcp_listener
from greentest import skipOnPyPy
import gevent
from gevent import socket
import sys
......@@ -87,17 +88,19 @@ class TestGreenIo(TestCase):
did_it_work(server)
server_greenlet.kill()
@skipOnPyPy("GC is different")
def test_del_closes_socket(self):
if PYPY:
return
timer = gevent.Timeout.start_new(0.5)
def accept_once(listener):
# delete/overwrite the original conn
# object, only keeping the file object around
# closing the file object should close everything
# XXX: This is not exactly true on Python 3.
# This produces a ResourceWarning.
try:
conn, addr = listener.accept()
conn, _ = listener.accept()
conn = conn.makefile(mode='wb')
conn.write(b'hello\n')
conn.close()
......
......@@ -32,14 +32,13 @@ DELAY = 0.1
class TestCloseSocketWhilePolling(greentest.TestCase):
def test(self):
try:
with self.assertRaises(Exception):
sock = socket.socket()
self._close_on_teardown(sock)
get_hub().loop.timer(0, sock.close)
sock.connect(('python.org', 81))
except Exception:
gevent.sleep(0)
else:
assert False, 'expected an error here'
class TestExceptionInMainloop(greentest.TestCase):
......
from __future__ import print_function
import os
from gevent import monkey; monkey.patch_all()
import re
import socket
import ssl
import threading
import unittest
import errno
from greentest import TestCase
dirname = os.path.dirname(os.path.abspath(__file__))
certfile = os.path.join(dirname, '2.7/keycert.pem')
pid = os.getpid()
......@@ -27,13 +28,13 @@ except ImportError:
psutil = None
class Test(unittest.TestCase):
class Test(TestCase):
extra_allowed_open_states = ()
def tearDown(self):
self.extra_allowed_open_states = ()
unittest.TestCase.tearDown(self)
super(Test, self).tearDown()
def assert_raises_EBADF(self, func):
try:
......@@ -156,6 +157,7 @@ class TestSocket(Test):
listener.listen(1)
connector = socket.socket()
self._close_on_teardown(connector)
def connect():
connector.connect(('127.0.0.1', port))
......@@ -180,6 +182,7 @@ class TestSocket(Test):
listener.listen(1)
connector = socket.socket()
self._close_on_teardown(connector)
def connect():
connector.connect(('127.0.0.1', port))
......@@ -213,6 +216,7 @@ class TestSocket(Test):
listener.listen(1)
connector = socket.socket()
self._close_on_teardown(connector)
def connect():
connector.connect(('127.0.0.1', port))
......@@ -282,10 +286,12 @@ class TestSSL(Test):
listener.listen(1)
connector = socket.socket()
self._close_on_teardown(connector)
def connect():
connector.connect(('127.0.0.1', port))
ssl.wrap_socket(connector)
x = ssl.wrap_socket(connector)
self._close_on_teardown(x)
t = threading.Thread(target=connect)
t.start()
......@@ -303,15 +309,18 @@ class TestSSL(Test):
def test_server_makefile1(self):
listener = socket.socket()
self._close_on_teardown(listener)
listener.bind(('127.0.0.1', 0))
port = listener.getsockname()[1]
listener.listen(1)
connector = socket.socket()
self._close_on_teardown(connector)
def connect():
connector.connect(('127.0.0.1', port))
ssl.wrap_socket(connector)
x = ssl.wrap_socket(connector)
self._close_on_teardown(x)
t = threading.Thread(target=connect)
t.start()
......@@ -338,10 +347,12 @@ class TestSSL(Test):
listener.listen(1)
connector = socket.socket()
self._close_on_teardown(connector)
def connect():
connector.connect(('127.0.0.1', port))
ssl.wrap_socket(connector)
x = ssl.wrap_socket(connector)
self._close_on_teardown(x)
t = threading.Thread(target=connect)
t.start()
......@@ -372,10 +383,12 @@ class TestSSL(Test):
listener = ssl.wrap_socket(listener, keyfile=certfile, certfile=certfile)
connector = socket.socket()
self._close_on_teardown(connector)
def connect():
connector.connect(('127.0.0.1', port))
ssl.wrap_socket(connector)
x = ssl.wrap_socket(connector)
self._close_on_teardown(x)
t = threading.Thread(target=connect)
t.start()
......
......@@ -70,7 +70,7 @@ def init_server():
def handle_request(s, raise_on_timeout):
try:
conn, address = s.accept()
conn, _ = s.accept()
except socket.timeout:
if raise_on_timeout:
raise
......@@ -83,7 +83,7 @@ def handle_request(s, raise_on_timeout):
res = conn.send(b'bye')
#print('handle_request - sent %r' % res)
#print('handle_request - conn refcount: %s' % sys.getrefcount(conn))
#conn.close()
conn.close()
def make_request(port):
......@@ -96,7 +96,7 @@ def make_request(port):
res = s.recv(100)
assert res == b'bye', repr(res)
#print('make_request - recvd %r' % res)
#s.close()
s.close()
def run_interaction(run_client):
......
......@@ -9,8 +9,11 @@ try:
assert weakref.ref(Dummy())() is None
from gevent import socket
assert weakref.ref(socket.socket())() is None
s = socket.socket()
r = weakref.ref(s)
s.close()
del s
assert r() is None
except AssertionError:
import sys
if hasattr(sys, 'pypy_version_info'):
......
......@@ -322,19 +322,59 @@ def get_port():
class TestCreateConnection(greentest.TestCase):
__timeout__ = 5
__timeout__ = 5000
def test(self):
try:
def test_refuses(self):
with self.assertRaises(socket.error) as cm:
socket.create_connection((greentest.DEFAULT_BIND_ADDR, get_port()),
timeout=30,
source_address=('', get_port()))
except socket.error as ex:
if 'refused' not in str(ex).lower():
raise
else:
raise AssertionError('create_connection did not raise socket.error as expected')
ex = cm.exception
self.assertIn('refused', str(ex).lower())
def test_base_exception(self):
# such as a GreenletExit or a gevent.timeout.Timeout
class E(BaseException):
pass
class MockSocket(object):
created = ()
closed = False
def __init__(self, *_):
MockSocket.created += (self,)
def connect(self, _):
raise E()
def close(self):
self.closed = True
def mockgetaddrinfo(*_):
return [(1, 2, 3, 3, 5),]
import gevent.socket as gsocket
# Make sure we're monkey patched
self.assertEqual(gsocket.create_connection, socket.create_connection)
orig_socket = gsocket.socket
orig_getaddrinfo = gsocket.getaddrinfo
try:
gsocket.socket = MockSocket
gsocket.getaddrinfo = mockgetaddrinfo
with self.assertRaises(E):
socket.create_connection(('host', 'port'))
self.assertEqual(1, len(MockSocket.created))
self.assertTrue(MockSocket.created[0].closed)
finally:
MockSocket.created = ()
gsocket.socket = orig_socket
gsocket.getaddrinfo = orig_getaddrinfo
class TestFunctions(greentest.TestCase):
......
......@@ -34,6 +34,7 @@ class TestSocketErrors(greentest.TestCase):
def test_connection_refused(self):
s = socket()
self._close_on_teardown(s)
try:
s.connect(('127.0.0.1', 81))
except error as ex:
......
......@@ -12,12 +12,13 @@ def _send(socket):
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.connect(('127.0.0.1', 12345))
getattr(sock, meth)(anStructure)
sock.close()
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.connect(('127.0.0.1', 12345))
sock.settimeout(1.0)
getattr(sock, meth)(anStructure)
sock.close()
def TestSendBuiltinSocket():
import socket
......
......@@ -5,36 +5,42 @@ import greentest
class Test(greentest.TestCase):
def start(self):
server = None
acceptor = None
server_port = None
def _accept(self):
conn, _ = self.server.accept()
self._close_on_teardown(conn)
def setUp(self):
super(Test, self).setUp()
self.server = socket.socket()
self._close_on_teardown(self.server)
self.server.bind(('127.0.0.1', 0))
self.server.listen(1)
self.server_port = self.server.getsockname()[1]
self.acceptor = gevent.spawn(self.server.accept)
self.acceptor = gevent.spawn(self._accept)
def stop(self):
self.server.close()
def tearDown(self):
self.acceptor.kill()
self.server.close()
del self.acceptor
del self.server
super(Test, self).tearDown()
def test(self):
self.start()
try:
sock = socket.socket()
self._close_on_teardown(sock)
sock.connect(('127.0.0.1', self.server_port))
try:
sock.settimeout(0.1)
try:
result = sock.recv(1024)
raise AssertionError('Expected timeout to be raised, instead recv() returned %r' % (result, ))
except socket.error as ex:
with self.assertRaises(socket.error) as cm:
sock.recv(1024)
ex = cm.exception
self.assertEqual(ex.args, ('timed out',))
self.assertEqual(str(ex), 'timed out')
finally:
sock.close()
finally:
self.stop()
if __name__ == '__main__':
......
......@@ -71,7 +71,8 @@ class Test(greentest.TestCase):
def test_communicate(self):
p = subprocess.Popen([sys.executable, "-c",
p = subprocess.Popen([sys.executable, "-W", "ignore",
"-c",
'import sys,os;'
'sys.stderr.write("pineapple");'
'sys.stdout.write(sys.stdin.read())'],
......@@ -91,7 +92,9 @@ class Test(greentest.TestCase):
# Native string all the things. See https://github.com/gevent/gevent/issues/1039
p = subprocess.Popen(
[
sys.executable, "-c",
sys.executable,
"-W", "ignore",
"-c",
'import sys,os;'
'sys.stderr.write("pineapple\\r\\n\\xff\\xff\\xf2\\xf9\\r\\n");'
'sys.stdout.write(sys.stdin.read())'
......
......@@ -5,11 +5,7 @@ from gevent import queue as Queue
import threading
import time
import unittest
try:
from test import support as test_support
except ImportError:
from test import test_support
from _six import xrange
QUEUE_SIZE = 5
......@@ -48,7 +44,7 @@ class _TriggerThread(threading.Thread):
# is supposed to raise an exception, call do_exceptional_blocking_test()
# instead.
class BlockingTestMixin:
class BlockingTestMixin(object):
def do_blocking_test(self, block_func, block_args, trigger_func, trigger_args):
self.t = _TriggerThread(trigger_func, trigger_args)
......@@ -65,18 +61,13 @@ class BlockingTestMixin:
return self.result
# Call this instead if block_func is supposed to raise an exception.
def do_exceptional_blocking_test(self,block_func, block_args, trigger_func,
def do_exceptional_blocking_test(self, block_func, block_args, trigger_func,
trigger_args, expected_exception_class):
self.t = _TriggerThread(trigger_func, trigger_args)
self.t.start()
try:
try:
with self.assertRaises(expected_exception_class):
block_func(*block_args)
except expected_exception_class:
raise
else:
self.fail("expected exception of kind %r" %
expected_exception_class)
finally:
self.t.join(10) # make sure the thread terminates
if self.t.isAlive():
......@@ -87,6 +78,8 @@ class BlockingTestMixin:
class BaseQueueTest(unittest.TestCase, BlockingTestMixin):
type2test = Queue.Queue
def setUp(self):
self.cum = 0
self.cumlock = threading.Lock()
......@@ -100,26 +93,26 @@ class BaseQueueTest(unittest.TestCase, BlockingTestMixin):
q.put(222)
q.put(444)
target_first_items = dict(
Queue = 111,
LifoQueue = 444,
PriorityQueue = 111)
Queue=111,
LifoQueue=444,
PriorityQueue=111)
actual_first_item = (q.peek(), q.get())
self.assertEquals(actual_first_item,
self.assertEqual(actual_first_item,
(target_first_items[q.__class__.__name__],
target_first_items[q.__class__.__name__]),
"q.peek() and q.get() are not equal!")
target_order = dict(Queue = [333, 222, 444],
LifoQueue = [222, 333, 111],
PriorityQueue = [222, 333, 444])
target_order = dict(Queue=[333, 222, 444],
LifoQueue=[222, 333, 111],
PriorityQueue=[222, 333, 444])
actual_order = [q.get(), q.get(), q.get()]
self.assertEquals(actual_order, target_order[q.__class__.__name__],
self.assertEqual(actual_order, target_order[q.__class__.__name__],
"Didn't seem to queue the correct data!")
for i in range(QUEUE_SIZE-1):
q.put(i)
self.assert_(not q.empty(), "Queue should not be empty")
self.assert_(not q.full(), "Queue should not be full")
self.assertFalse(q.empty(), "Queue should not be empty")
self.assertFalse(q.full(), "Queue should not be full")
q.put(999)
self.assert_(q.full(), "Queue should be full")
self.assertTrue(q.full(), "Queue should be full")
try:
q.put(888, block=0)
self.fail("Didn't appear to block with a full queue")
......@@ -130,14 +123,14 @@ class BaseQueueTest(unittest.TestCase, BlockingTestMixin):
self.fail("Didn't appear to time-out with a full queue")
except Queue.Full:
pass
self.assertEquals(q.qsize(), QUEUE_SIZE)
self.assertEqual(q.qsize(), QUEUE_SIZE)
# Test a blocking put
self.do_blocking_test(q.put, (888,), q.get, ())
self.do_blocking_test(q.put, (888, True, 10), q.get, ())
# Empty it
for i in range(QUEUE_SIZE):
q.get()
self.assert_(q.empty(), "Queue should be empty")
self.assertTrue(q.empty(), "Queue should be empty")
try:
q.get(block=0)
self.fail("Didn't appear to block with an empty queue")
......@@ -164,14 +157,14 @@ class BaseQueueTest(unittest.TestCase, BlockingTestMixin):
def queue_join_test(self, q):
self.cum = 0
for i in (0,1):
for i in (0, 1):
threading.Thread(target=self.worker, args=(q,)).start()
for i in xrange(100):
for i in range(100):
q.put(i)
q.join()
self.assertEquals(self.cum, sum(range(100)),
self.assertEqual(self.cum, sum(range(100)),
"q.join() did not block until all tasks were done")
for i in (0,1):
for i in (0, 1):
q.put(None) # instruct the threads to close
q.join() # verify that you can join twice
......@@ -227,10 +220,6 @@ class BaseQueueTest(unittest.TestCase, BlockingTestMixin):
self.simple_queue_test(q)
self.simple_queue_test(q)
class QueueTest(BaseQueueTest):
type2test = Queue.Queue
class LifoQueueTest(BaseQueueTest):
type2test = Queue.LifoQueue
......@@ -274,79 +263,59 @@ class FailingQueueTest(unittest.TestCase, BlockingTestMixin):
q.put(i)
# Test a failing non-blocking put.
q.fail_next_put = True
try:
with self.assertRaises(FailingQueueException):
q.put("oops", block=0)
self.fail("The queue didn't fail when it should have")
except FailingQueueException:
pass
q.fail_next_put = True
try:
with self.assertRaises(FailingQueueException):
q.put("oops", timeout=0.1)
self.fail("The queue didn't fail when it should have")
except FailingQueueException:
pass
q.put(999)
self.assert_(q.full(), "Queue should be full")
self.assertTrue(q.full(), "Queue should be full")
# Test a failing blocking put
q.fail_next_put = True
try:
with self.assertRaises(FailingQueueException):
self.do_blocking_test(q.put, (888,), q.get, ())
self.fail("The queue didn't fail when it should have")
except FailingQueueException:
pass
# Check the Queue isn't damaged.
# put failed, but get succeeded - re-add
q.put(999)
# Test a failing timeout put
q.fail_next_put = True
try:
self.do_exceptional_blocking_test(q.put, (888, True, 10), q.get, (),
FailingQueueException)
self.fail("The queue didn't fail when it should have")
except FailingQueueException:
pass
# Check the Queue isn't damaged.
# put failed, but get succeeded - re-add
q.put(999)
self.assert_(q.full(), "Queue should be full")
self.assertTrue(q.full(), "Queue should be full")
q.get()
self.assert_(not q.full(), "Queue should not be full")
self.assertFalse(q.full(), "Queue should not be full")
q.put(999)
self.assert_(q.full(), "Queue should be full")
self.assertTrue(q.full(), "Queue should be full")
# Test a blocking put
self.do_blocking_test(q.put, (888,), q.get, ())
# Empty it
for i in range(QUEUE_SIZE):
q.get()
self.assert_(q.empty(), "Queue should be empty")
self.assertTrue(q.empty(), "Queue should be empty")
q.put("first")
q.fail_next_get = True
try:
with self.assertRaises(FailingQueueException):
q.get()
self.fail("The queue didn't fail when it should have")
except FailingQueueException:
pass
self.assert_(not q.empty(), "Queue should not be empty")
self.assertFalse(q.empty(), "Queue should not be empty")
q.fail_next_get = True
try:
with self.assertRaises(FailingQueueException):
q.get(timeout=0.1)
self.fail("The queue didn't fail when it should have")
except FailingQueueException:
pass
self.assert_(not q.empty(), "Queue should not be empty")
self.assertFalse(q.empty(), "Queue should not be empty")
q.get()
self.assert_(q.empty(), "Queue should be empty")
self.assertTrue(q.empty(), "Queue should be empty")
q.fail_next_get = True
try:
self.do_exceptional_blocking_test(q.get, (), q.put, ('empty',),
FailingQueueException)
self.fail("The queue didn't fail when it should have")
except FailingQueueException:
pass
# put succeeded, but get failed.
self.assert_(not q.empty(), "Queue should not be empty")
self.assertFalse(q.empty(), "Queue should not be empty")
q.get()
self.assert_(q.empty(), "Queue should be empty")
self.assertTrue(q.empty(), "Queue should be empty")
def test_failing_queue(self):
# Test to make sure a queue is functioning correctly.
......@@ -356,10 +325,5 @@ class FailingQueueTest(unittest.TestCase, BlockingTestMixin):
self.failing_queue_test(q)
def test_main():
test_support.run_unittest(QueueTest, LifoQueueTest, PriorityQueueTest,
FailingQueueTest)
if __name__ == "__main__":
test_main()
unittest.main()
......@@ -297,6 +297,16 @@ def main():
config_data = f.read()
six.exec_(config_data, config)
FAILING_TESTS = config['FAILING_TESTS']
if 'PYTHONWARNINGS' not in os.environ and not sys.warnoptions:
# Enable default warnings such as ResourceWarning.
# On Python 3[.6], the system site.py module has
# "open(fullname, 'rU')" which produces the warning that
# 'U' is deprecated, so ignore warnings from site.py
os.environ['PYTHONWARNINGS'] = 'default,ignore:::site:'
if 'PYTHONFAULTHANDLER' not in os.environ:
os.environ['PYTHONFAULTHANDLER'] = 'true'
tests = discover(options.tests, options.ignore, coverage)
if options.discover:
for cmd, options in tests:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment