Commit 88d36895 authored by Jason Madden's avatar Jason Madden

Update 3.6pypy tests. All pass on OSX with pypy3.6-7.3

parent b09e74ed
...@@ -854,6 +854,10 @@ if PYPY and PY3: ...@@ -854,6 +854,10 @@ if PYPY and PY3:
'test_subprocess.POSIXProcessTestCase.test_pass_fds_inheritable', 'test_subprocess.POSIXProcessTestCase.test_pass_fds_inheritable',
'test_subprocess.POSIXProcessTestCase.test_pipe_cloexec', 'test_subprocess.POSIXProcessTestCase.test_pipe_cloexec',
# This passes various "invalid" strings and expects a ValueError. not sure why
# we don't see errors on Linux.
'test_subprocess.ProcessTestCase.test_invalid_env',
# The below are new with 5.10.1 # The below are new with 5.10.1
# These fail with 'OSError: received malformed or improperly truncated ancillary data' # These fail with 'OSError: received malformed or improperly truncated ancillary data'
'test_socket.RecvmsgSCMRightsStreamTest.testCmsgTruncLen0', 'test_socket.RecvmsgSCMRightsStreamTest.testCmsgTruncLen0',
...@@ -872,12 +876,16 @@ if PYPY and PY3: ...@@ -872,12 +876,16 @@ if PYPY and PY3:
'test_ssl.ThreadedTests.test_protocol_sslv3', 'test_ssl.ThreadedTests.test_protocol_sslv3',
'test_ssl.ThreadedTests.test_protocol_tlsv1', 'test_ssl.ThreadedTests.test_protocol_tlsv1',
'test_ssl.ThreadedTests.test_protocol_tlsv1_1', 'test_ssl.ThreadedTests.test_protocol_tlsv1_1',
# Similar, they fail without monkey-patching.
'test_ssl.TestPostHandshakeAuth.test_pha_no_pha_client',
'test_ssl.TestPostHandshakeAuth.test_pha_optional',
'test_ssl.TestPostHandshakeAuth.test_pha_required',
# This gets None instead of http1.1, even without gevent # This gets None instead of http1.1, even without gevent
'test_ssl.ThreadedTests.test_npn_protocols', 'test_ssl.ThreadedTests.test_npn_protocols',
# This fails to decode a filename even without gevent, # This fails to decode a filename even without gevent,
# at least on High Sierarr. # at least on High Sierra. Newer versions of the tests actually skip this.
'test_httpservers.SimpleHTTPServerTestCase.test_undecodable_filename', 'test_httpservers.SimpleHTTPServerTestCase.test_undecodable_filename',
] ]
......
DH Parameters: (3072 bit)
prime:
00:ff:ff:ff:ff:ff:ff:ff:ff:ad:f8:54:58:a2:bb:
4a:9a:af:dc:56:20:27:3d:3c:f1:d8:b9:c5:83:ce:
2d:36:95:a9:e1:36:41:14:64:33:fb:cc:93:9d:ce:
24:9b:3e:f9:7d:2f:e3:63:63:0c:75:d8:f6:81:b2:
02:ae:c4:61:7a:d3:df:1e:d5:d5:fd:65:61:24:33:
f5:1f:5f:06:6e:d0:85:63:65:55:3d:ed:1a:f3:b5:
57:13:5e:7f:57:c9:35:98:4f:0c:70:e0:e6:8b:77:
e2:a6:89:da:f3:ef:e8:72:1d:f1:58:a1:36:ad:e7:
35:30:ac:ca:4f:48:3a:79:7a:bc:0a:b1:82:b3:24:
fb:61:d1:08:a9:4b:b2:c8:e3:fb:b9:6a:da:b7:60:
d7:f4:68:1d:4f:42:a3:de:39:4d:f4:ae:56:ed:e7:
63:72:bb:19:0b:07:a7:c8:ee:0a:6d:70:9e:02:fc:
e1:cd:f7:e2:ec:c0:34:04:cd:28:34:2f:61:91:72:
fe:9c:e9:85:83:ff:8e:4f:12:32:ee:f2:81:83:c3:
fe:3b:1b:4c:6f:ad:73:3b:b5:fc:bc:2e:c2:20:05:
c5:8e:f1:83:7d:16:83:b2:c6:f3:4a:26:c1:b2:ef:
fa:88:6b:42:38:61:1f:cf:dc:de:35:5b:3b:65:19:
03:5b:bc:34:f4:de:f9:9c:02:38:61:b4:6f:c9:d6:
e6:c9:07:7a:d9:1d:26:91:f7:f7:ee:59:8c:b0:fa:
c1:86:d9:1c:ae:fe:13:09:85:13:92:70:b4:13:0c:
93:bc:43:79:44:f4:fd:44:52:e2:d7:4d:d3:64:f2:
e2:1e:71:f5:4b:ff:5c:ae:82:ab:9c:9d:f6:9e:e8:
6d:2b:c5:22:36:3a:0d:ab:c5:21:97:9b:0d:ea:da:
1d:bf:9a:42:d5:c4:48:4e:0a:bc:d0:6b:fa:53:dd:
ef:3c:1b:20:ee:3f:d5:9d:7c:25:e4:1d:2b:66:c6:
2e:37:ff:ff:ff:ff:ff:ff:ff:ff
generator: 2 (0x2)
recommended-private-length: 276 bits
-----BEGIN DH PARAMETERS-----
MIIBjAKCAYEA//////////+t+FRYortKmq/cViAnPTzx2LnFg84tNpWp4TZBFGQz
+8yTnc4kmz75fS/jY2MMddj2gbICrsRhetPfHtXV/WVhJDP1H18GbtCFY2VVPe0a
87VXE15/V8k1mE8McODmi3fipona8+/och3xWKE2rec1MKzKT0g6eXq8CrGCsyT7
YdEIqUuyyOP7uWrat2DX9GgdT0Kj3jlN9K5W7edjcrsZCwenyO4KbXCeAvzhzffi
7MA0BM0oNC9hkXL+nOmFg/+OTxIy7vKBg8P+OxtMb61zO7X8vC7CIAXFjvGDfRaD
ssbzSibBsu/6iGtCOGEfz9zeNVs7ZRkDW7w09N75nAI4YbRvydbmyQd62R0mkff3
7lmMsPrBhtkcrv4TCYUTknC0EwyTvEN5RPT9RFLi103TZPLiHnH1S/9croKrnJ32
nuhtK8UiNjoNq8Uhl5sN6todv5pC1cRITgq80Gv6U93vPBsg7j/VnXwl5B0rZsYu
N///////////AgECAgIBFA==
-----END DH PARAMETERS-----
...@@ -32,6 +32,9 @@ class Bunch(object): ...@@ -32,6 +32,9 @@ class Bunch(object):
self.started = [] self.started = []
self.finished = [] self.finished = []
self._can_exit = not wait_before_exit self._can_exit = not wait_before_exit
self.wait_thread = support.wait_threads_exit()
self.wait_thread.__enter__()
def task(): def task():
tid = threading.get_ident() tid = threading.get_ident()
self.started.append(tid) self.started.append(tid)
...@@ -41,6 +44,7 @@ class Bunch(object): ...@@ -41,6 +44,7 @@ class Bunch(object):
self.finished.append(tid) self.finished.append(tid)
while not self._can_exit: while not self._can_exit:
_wait() _wait()
try: try:
for i in range(n): for i in range(n):
start_new_thread(task, ()) start_new_thread(task, ())
...@@ -55,6 +59,8 @@ class Bunch(object): ...@@ -55,6 +59,8 @@ class Bunch(object):
def wait_for_finished(self): def wait_for_finished(self):
while len(self.finished) < self.n: while len(self.finished) < self.n:
_wait() _wait()
# Wait for threads exit
self.wait_thread.__exit__(None, None, None)
def do_finish(self): def do_finish(self):
self._can_exit = True self._can_exit = True
...@@ -222,20 +228,23 @@ class LockTests(BaseLockTests): ...@@ -222,20 +228,23 @@ class LockTests(BaseLockTests):
# Lock needs to be released before re-acquiring. # Lock needs to be released before re-acquiring.
lock = self.locktype() lock = self.locktype()
phase = [] phase = []
def f(): def f():
lock.acquire() lock.acquire()
phase.append(None) phase.append(None)
lock.acquire() lock.acquire()
phase.append(None) phase.append(None)
start_new_thread(f, ())
while len(phase) == 0: with support.wait_threads_exit():
_wait() start_new_thread(f, ())
_wait() while len(phase) == 0:
self.assertEqual(len(phase), 1) _wait()
lock.release()
while len(phase) == 1:
_wait() _wait()
self.assertEqual(len(phase), 2) self.assertEqual(len(phase), 1)
lock.release()
while len(phase) == 1:
_wait()
self.assertEqual(len(phase), 2)
def test_different_thread(self): def test_different_thread(self):
# Lock can be released from a different thread. # Lock can be released from a different thread.
...@@ -306,6 +315,7 @@ class RLockTests(BaseLockTests): ...@@ -306,6 +315,7 @@ class RLockTests(BaseLockTests):
self.assertRaises(RuntimeError, lock.release) self.assertRaises(RuntimeError, lock.release)
finally: finally:
b.do_finish() b.do_finish()
b.wait_for_finished()
def test__is_owned(self): def test__is_owned(self):
lock = self.locktype() lock = self.locktype()
...@@ -397,12 +407,13 @@ class EventTests(BaseTestCase): ...@@ -397,12 +407,13 @@ class EventTests(BaseTestCase):
# cleared before the waiting thread is woken up. # cleared before the waiting thread is woken up.
evt = self.eventtype() evt = self.eventtype()
results = [] results = []
timeout = 0.250
N = 5 N = 5
def f(): def f():
results.append(evt.wait(1)) results.append(evt.wait(timeout * 4))
b = Bunch(f, N) b = Bunch(f, N)
b.wait_for_started() b.wait_for_started()
time.sleep(0.5) time.sleep(timeout)
evt.set() evt.set()
evt.clear() evt.clear()
b.wait_for_finished() b.wait_for_finished()
...@@ -463,21 +474,28 @@ class ConditionTests(BaseTestCase): ...@@ -463,21 +474,28 @@ class ConditionTests(BaseTestCase):
# construct. In particular, it is possible that this can no longer # construct. In particular, it is possible that this can no longer
# be conveniently guaranteed should their implementation ever change. # be conveniently guaranteed should their implementation ever change.
N = 5 N = 5
ready = []
results1 = [] results1 = []
results2 = [] results2 = []
phase_num = 0 phase_num = 0
def f(): def f():
cond.acquire() cond.acquire()
ready.append(phase_num)
result = cond.wait() result = cond.wait()
cond.release() cond.release()
results1.append((result, phase_num)) results1.append((result, phase_num))
cond.acquire() cond.acquire()
ready.append(phase_num)
result = cond.wait() result = cond.wait()
cond.release() cond.release()
results2.append((result, phase_num)) results2.append((result, phase_num))
b = Bunch(f, N) b = Bunch(f, N)
b.wait_for_started() b.wait_for_started()
_wait() # first wait, to ensure all workers settle into cond.wait() before
# we continue. See issues #8799 and #30727.
while len(ready) < 5:
_wait()
ready.clear()
self.assertEqual(results1, []) self.assertEqual(results1, [])
# Notify 3 threads at first # Notify 3 threads at first
cond.acquire() cond.acquire()
...@@ -489,9 +507,9 @@ class ConditionTests(BaseTestCase): ...@@ -489,9 +507,9 @@ class ConditionTests(BaseTestCase):
_wait() _wait()
self.assertEqual(results1, [(True, 1)] * 3) self.assertEqual(results1, [(True, 1)] * 3)
self.assertEqual(results2, []) self.assertEqual(results2, [])
# first wait, to ensure all workers settle into cond.wait() before # make sure all awaken workers settle into cond.wait()
# we continue. See issue #8799 while len(ready) < 3:
_wait() _wait()
# Notify 5 threads: they might be in their first or second wait # Notify 5 threads: they might be in their first or second wait
cond.acquire() cond.acquire()
cond.notify(5) cond.notify(5)
...@@ -502,7 +520,9 @@ class ConditionTests(BaseTestCase): ...@@ -502,7 +520,9 @@ class ConditionTests(BaseTestCase):
_wait() _wait()
self.assertEqual(results1, [(True, 1)] * 3 + [(True, 2)] * 2) self.assertEqual(results1, [(True, 1)] * 3 + [(True, 2)] * 2)
self.assertEqual(results2, [(True, 2)] * 3) self.assertEqual(results2, [(True, 2)] * 3)
_wait() # make sure all workers settle into cond.wait() # make sure all workers settle into cond.wait()
while len(ready) < 5:
_wait()
# Notify all threads: they are all in their second wait # Notify all threads: they are all in their second wait
cond.acquire() cond.acquire()
cond.notify_all() cond.notify_all()
...@@ -612,13 +632,14 @@ class BaseSemaphoreTests(BaseTestCase): ...@@ -612,13 +632,14 @@ class BaseSemaphoreTests(BaseTestCase):
sem = self.semtype(7) sem = self.semtype(7)
sem.acquire() sem.acquire()
N = 10 N = 10
sem_results = []
results1 = [] results1 = []
results2 = [] results2 = []
phase_num = 0 phase_num = 0
def f(): def f():
sem.acquire() sem_results.append(sem.acquire())
results1.append(phase_num) results1.append(phase_num)
sem.acquire() sem_results.append(sem.acquire())
results2.append(phase_num) results2.append(phase_num)
b = Bunch(f, 10) b = Bunch(f, 10)
b.wait_for_started() b.wait_for_started()
...@@ -642,6 +663,7 @@ class BaseSemaphoreTests(BaseTestCase): ...@@ -642,6 +663,7 @@ class BaseSemaphoreTests(BaseTestCase):
# Final release, to let the last thread finish # Final release, to let the last thread finish
sem.release() sem.release()
b.wait_for_finished() b.wait_for_finished()
self.assertEqual(sem_results, [True] * (6 + 7 + 6 + 1))
def test_try_acquire(self): def test_try_acquire(self):
sem = self.semtype(2) sem = self.semtype(2)
......
-----BEGIN X509 CRL----- -----BEGIN X509 CRL-----
MIIBpjCBjwIBATANBgkqhkiG9w0BAQUFADBNMQswCQYDVQQGEwJYWTEmMCQGA1UE MIICJjCBjwIBATANBgkqhkiG9w0BAQsFADBNMQswCQYDVQQGEwJYWTEmMCQGA1UE
CgwdUHl0aG9uIFNvZnR3YXJlIEZvdW5kYXRpb24gQ0ExFjAUBgNVBAMMDW91ci1j CgwdUHl0aG9uIFNvZnR3YXJlIEZvdW5kYXRpb24gQ0ExFjAUBgNVBAMMDW91ci1j
YS1zZXJ2ZXIXDTEzMTEyMTE3MDg0N1oXDTIzMDkzMDE3MDg0N1qgDjAMMAoGA1Ud YS1zZXJ2ZXIXDTE4MDgyOTE0MjMxNloXDTI4MDcwNzE0MjMxNlqgDjAMMAoGA1Ud
FAQDAgEAMA0GCSqGSIb3DQEBBQUAA4IBAQCNJXC2mVKauEeN3LlQ3ZtM5gkH3ExH FAQDAgEAMA0GCSqGSIb3DQEBCwUAA4IBgQCPhrtGSbuvxPAI3YWQFDB4iOWdBnVk
+i4bmJjtJn497WwvvoIeUdrmVXgJQR93RtV37hZwN0SXMLlNmUZPH4rHhihayw4m ugW1lsifmCsE86FfID0EwUut1SRHlksltMtcoULMEIdu8yMLWci++4ve22EEuMKT
unCzVj/OhCCY7/TPjKuJ1O/0XhaLBpBVjQN7R/1ujoRKbSia/CD3vcn7Fqxzw7LK HUc3T/wBIuQUhA7U4deFG8CZPAxRpNoK470y7dkD4OVf0Gxa6WYDl9z8mXKmWCB9
fSRCKRGTj1CZiuxrphtFchwALXSiFDy9mr2ZKhImcyq1PydfgEzU78APpOkMQsIC hvzqVfLWNSLTAVPsHtkD5PXdi5yRkQr6wYD7poWaIvkpsn7EKCY6Tw5V3rsbRuZq
UNJ/cf3c9emzf+dUtcMEcejQ3mynBo4eIGg1EW42bz4q4hSjzQlKcBV0muw5qXhc AGVCq5TH3mctcmwLloCJ4Xr/1q0DsRrYxeeLYxE+UpvvCbVBKgtjBK7zINS7AbcJ
HOxH2iTFhQ7SrvVuK/dM14rYM4B5mSX3nRC1kNmXpS9j3wJDhuwmjHed CYCYKUwGWv1fYKJ+KQQHf75mT3jQ9lWuzOj/YWK4k1EBnYmVGuKKt73lLFxC6h3y
MUnaBZc1KZSyJj0IxfHg/o6qx8NgKOl9XRIQ5g5B30cwpPOskGhEhodbTTY3bPtm
RQ36JvQZngzmkhyhr+MDEV5yUTOShfUiclzQOx26CmLmLHWxOZgXtFZob/oKrvbm
Gen/+7K7YTw6hfY52U7J2FuQRGOyzBXfBYQ=
-----END X509 CRL----- -----END X509 CRL-----
...@@ -21,25 +21,19 @@ class InterProcessSignalTests(unittest.TestCase): ...@@ -21,25 +21,19 @@ class InterProcessSignalTests(unittest.TestCase):
self.got_signals['SIGUSR1'] += 1 self.got_signals['SIGUSR1'] += 1
raise SIGUSR1Exception raise SIGUSR1Exception
def wait_signal(self, child, signame, exc_class=None): def wait_signal(self, child, signame):
try: if child is not None:
if child is not None: # This wait should be interrupted by exc_class
# This wait should be interrupted by exc_class # (if set)
# (if set) child.wait()
child.wait()
timeout = 10.0
timeout = 10.0 deadline = time.monotonic() + timeout
deadline = time.monotonic() + timeout
while time.monotonic() < deadline:
while time.monotonic() < deadline: if self.got_signals[signame]:
if self.got_signals[signame]:
return
signal.pause()
except BaseException as exc:
if exc_class is not None and isinstance(exc, exc_class):
# got the expected exception
return return
raise signal.pause()
self.fail('signal %s not received after %s seconds' self.fail('signal %s not received after %s seconds'
% (signame, timeout)) % (signame, timeout))
...@@ -65,8 +59,9 @@ class InterProcessSignalTests(unittest.TestCase): ...@@ -65,8 +59,9 @@ class InterProcessSignalTests(unittest.TestCase):
self.assertEqual(self.got_signals, {'SIGHUP': 1, 'SIGUSR1': 0, self.assertEqual(self.got_signals, {'SIGHUP': 1, 'SIGUSR1': 0,
'SIGALRM': 0}) 'SIGALRM': 0})
with self.subprocess_send_signal(pid, "SIGUSR1") as child: with self.assertRaises(SIGUSR1Exception):
self.wait_signal(child, 'SIGUSR1', SIGUSR1Exception) with self.subprocess_send_signal(pid, "SIGUSR1") as child:
self.wait_signal(child, 'SIGUSR1')
self.assertEqual(self.got_signals, {'SIGHUP': 1, 'SIGUSR1': 1, self.assertEqual(self.got_signals, {'SIGHUP': 1, 'SIGUSR1': 1,
'SIGALRM': 0}) 'SIGALRM': 0})
...@@ -74,10 +69,14 @@ class InterProcessSignalTests(unittest.TestCase): ...@@ -74,10 +69,14 @@ class InterProcessSignalTests(unittest.TestCase):
# Nothing should happen: SIGUSR2 is ignored # Nothing should happen: SIGUSR2 is ignored
child.wait() child.wait()
signal.alarm(1) try:
self.wait_signal(None, 'SIGALRM', KeyboardInterrupt) with self.assertRaises(KeyboardInterrupt):
self.assertEqual(self.got_signals, {'SIGHUP': 1, 'SIGUSR1': 1, signal.alarm(1)
'SIGALRM': 0}) self.wait_signal(None, 'SIGALRM')
self.assertEqual(self.got_signals, {'SIGHUP': 1, 'SIGUSR1': 1,
'SIGALRM': 0})
finally:
signal.alarm(0)
if __name__ == "__main__": if __name__ == "__main__":
......
-----BEGIN CERTIFICATE-----
MIIDqDCCApKgAwIBAgIBAjALBgkqhkiG9w0BAQswHzELMAkGA1UEBhMCVUsxEDAO
BgNVBAMTB2NvZHktY2EwHhcNMTgwNjE4MTgwMDU4WhcNMjgwNjE0MTgwMDU4WjA7
MQswCQYDVQQGEwJVSzEsMCoGA1UEAxMjY29kZW5vbWljb24tdm0tMi50ZXN0Lmxh
bC5jaXNjby5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC63fGB
J80A9Av1GB0bptslKRIUtJm8EeEu34HkDWbL6AJY0P8WfDtlXjlPaLqFa6sqH6ES
V48prSm1ZUbDSVL8R6BYVYpOlK8/48xk4pGTgRzv69gf5SGtQLwHy8UPBKgjSZoD
5a5k5wJXGswhKFFNqyyxqCvWmMnJWxXTt2XDCiWc4g4YAWi4O4+6SeeHVAV9rV7C
1wxqjzKovVe2uZOHjKEzJbbIU6JBPb6TRfMdRdYOw98n1VXDcKVgdX2DuuqjCzHP
WhU4Tw050M9NaK3eXp4Mh69VuiKoBGOLSOcS8reqHIU46Reg0hqeL8LIL6OhFHIF
j7HR6V1X6F+BfRS/AgMBAAGjgdYwgdMwCQYDVR0TBAIwADAdBgNVHQ4EFgQUOktp
HQjxDXXUg8prleY9jeLKeQ4wTwYDVR0jBEgwRoAUx6zgPygZ0ZErF9sPC4+5e2Io
UU+hI6QhMB8xCzAJBgNVBAYTAlVLMRAwDgYDVQQDEwdjb2R5LWNhggkA1QEAuwb7
2s0wCQYDVR0SBAIwADAuBgNVHREEJzAlgiNjb2Rlbm9taWNvbi12bS0yLnRlc3Qu
bGFsLmNpc2NvLmNvbTAOBgNVHQ8BAf8EBAMCBaAwCwYDVR0fBAQwAjAAMAsGCSqG
SIb3DQEBCwOCAQEAvqantx2yBlM11RoFiCfi+AfSblXPdrIrHvccepV4pYc/yO6p
t1f2dxHQb8rWH3i6cWag/EgIZx+HJQvo0rgPY1BFJsX1WnYf1/znZpkUBGbVmlJr
t/dW1gSkNS6sPsM0Q+7HPgEv8CPDNK5eo7vU2seE0iWOkxSyVUuiCEY9ZVGaLVit
p0C78nZ35Pdv4I+1cosmHl28+es1WI22rrnmdBpH8J1eY6WvUw2xuZHLeNVN0TzV
Q3qq53AaCWuLOD1AjESWuUCxMZTK9DPS4JKXTK8RLyDeqOvJGjsSWp3kL0y3GaQ+
10T1rfkKJub2+m9A9duin1fn6tHc2wSvB7m3DA==
-----END CERTIFICATE-----
...@@ -433,7 +433,10 @@ class FileWrapperTest(unittest.TestCase): ...@@ -433,7 +433,10 @@ class FileWrapperTest(unittest.TestCase):
f = asyncore.file_wrapper(fd) f = asyncore.file_wrapper(fd)
os.close(fd) os.close(fd)
f.close() os.close(f.fd) # file_wrapper dupped fd
with self.assertRaises(OSError):
f.close()
self.assertEqual(f.fd, -1) self.assertEqual(f.fd, -1)
# calling close twice should not fail # calling close twice should not fail
f.close() f.close()
...@@ -502,7 +505,7 @@ class BaseClient(BaseTestHandler): ...@@ -502,7 +505,7 @@ class BaseClient(BaseTestHandler):
class BaseTestAPI: class BaseTestAPI:
def tearDown(self): def tearDown(self):
asyncore.close_all() asyncore.close_all(ignore_all=True)
def loop_waiting_for_flag(self, instance, timeout=5): def loop_waiting_for_flag(self, instance, timeout=5):
timeout = float(timeout) / 100 timeout = float(timeout) / 100
...@@ -755,50 +758,50 @@ class BaseTestAPI: ...@@ -755,50 +758,50 @@ class BaseTestAPI:
def test_set_reuse_addr(self): def test_set_reuse_addr(self):
if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX: if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX:
self.skipTest("Not applicable to AF_UNIX sockets.") self.skipTest("Not applicable to AF_UNIX sockets.")
sock = socket.socket(self.family)
try: with socket.socket(self.family) as sock:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) try:
except OSError: sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
unittest.skip("SO_REUSEADDR not supported on this platform") except OSError:
else: unittest.skip("SO_REUSEADDR not supported on this platform")
# if SO_REUSEADDR succeeded for sock we expect asyncore else:
# to do the same # if SO_REUSEADDR succeeded for sock we expect asyncore
s = asyncore.dispatcher(socket.socket(self.family)) # to do the same
self.assertFalse(s.socket.getsockopt(socket.SOL_SOCKET, s = asyncore.dispatcher(socket.socket(self.family))
socket.SO_REUSEADDR)) self.assertFalse(s.socket.getsockopt(socket.SOL_SOCKET,
s.socket.close() socket.SO_REUSEADDR))
s.create_socket(self.family) s.socket.close()
s.set_reuse_addr() s.create_socket(self.family)
self.assertTrue(s.socket.getsockopt(socket.SOL_SOCKET, s.set_reuse_addr()
socket.SO_REUSEADDR)) self.assertTrue(s.socket.getsockopt(socket.SOL_SOCKET,
finally: socket.SO_REUSEADDR))
sock.close()
@unittest.skipUnless(threading, 'Threading required for this test.') @unittest.skipUnless(threading, 'Threading required for this test.')
@support.reap_threads @support.reap_threads
def test_quick_connect(self): def test_quick_connect(self):
# see: http://bugs.python.org/issue10340 # see: http://bugs.python.org/issue10340
if self.family in (socket.AF_INET, getattr(socket, "AF_INET6", object())): if self.family not in (socket.AF_INET, getattr(socket, "AF_INET6", object())):
server = BaseServer(self.family, self.addr) self.skipTest("test specific to AF_INET and AF_INET6")
t = threading.Thread(target=lambda: asyncore.loop(timeout=0.1,
count=500)) server = BaseServer(self.family, self.addr)
t.start() # run the thread 500 ms: the socket should be connected in 200 ms
def cleanup(): t = threading.Thread(target=lambda: asyncore.loop(timeout=0.1,
t.join(timeout=TIMEOUT) count=5))
if t.is_alive(): t.start()
self.fail("join() timed out") try:
self.addCleanup(cleanup) with socket.socket(self.family, socket.SOCK_STREAM) as s:
s.settimeout(.2)
s = socket.socket(self.family, socket.SOCK_STREAM) s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER,
s.settimeout(.2) struct.pack('ii', 1, 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER,
struct.pack('ii', 1, 0)) try:
try: s.connect(server.address)
s.connect(server.address) except OSError:
except OSError: pass
pass finally:
finally: t.join(timeout=TIMEOUT)
s.close() if t.is_alive():
self.fail("join() timed out")
class TestAPI_UseIPv4Sockets(BaseTestAPI): class TestAPI_UseIPv4Sockets(BaseTestAPI):
family = socket.AF_INET family = socket.AF_INET
......
...@@ -52,6 +52,7 @@ class TestServerThread(threading.Thread): ...@@ -52,6 +52,7 @@ class TestServerThread(threading.Thread):
def stop(self): def stop(self):
self.server.shutdown() self.server.shutdown()
self.join()
class BaseTestCase(unittest.TestCase): class BaseTestCase(unittest.TestCase):
...@@ -371,7 +372,8 @@ class SimpleHTTPServerTestCase(BaseTestCase): ...@@ -371,7 +372,8 @@ class SimpleHTTPServerTestCase(BaseTestCase):
reader.close() reader.close()
return body return body
@support.requires_mac_ver(10, 5) @unittest.skipIf(sys.platform == 'darwin',
'undecodable name cannot always be decoded on macOS')
@unittest.skipIf(sys.platform == 'win32', @unittest.skipIf(sys.platform == 'win32',
'undecodable name cannot be decoded on win32') 'undecodable name cannot be decoded on win32')
@unittest.skipUnless(support.TESTFN_UNDECODABLE, @unittest.skipUnless(support.TESTFN_UNDECODABLE,
......
...@@ -46,28 +46,27 @@ class _TriggerThread(threading.Thread): ...@@ -46,28 +46,27 @@ class _TriggerThread(threading.Thread):
class BlockingTestMixin: class BlockingTestMixin:
def tearDown(self):
self.t = None
def do_blocking_test(self, block_func, block_args, trigger_func, trigger_args): def do_blocking_test(self, block_func, block_args, trigger_func, trigger_args):
self.t = _TriggerThread(trigger_func, trigger_args) thread = _TriggerThread(trigger_func, trigger_args)
self.t.start() thread.start()
self.result = block_func(*block_args) try:
# If block_func returned before our thread made the call, we failed! self.result = block_func(*block_args)
if not self.t.startedEvent.is_set(): # If block_func returned before our thread made the call, we failed!
self.fail("blocking function '%r' appeared not to block" % if not thread.startedEvent.is_set():
block_func) self.fail("blocking function '%r' appeared not to block" %
self.t.join(10) # make sure the thread terminates block_func)
if self.t.is_alive(): return self.result
self.fail("trigger function '%r' appeared to not return" % finally:
trigger_func) thread.join(10) # make sure the thread terminates
return self.result if thread.is_alive():
self.fail("trigger function '%r' appeared to not return" %
trigger_func)
# Call this instead if block_func is supposed to raise an exception. # Call this instead if block_func is supposed to raise an exception.
def do_exceptional_blocking_test(self,block_func, block_args, trigger_func, def do_exceptional_blocking_test(self,block_func, block_args, trigger_func,
trigger_args, expected_exception_class): trigger_args, expected_exception_class):
self.t = _TriggerThread(trigger_func, trigger_args) thread = _TriggerThread(trigger_func, trigger_args)
self.t.start() thread.start()
try: try:
try: try:
block_func(*block_args) block_func(*block_args)
...@@ -77,11 +76,11 @@ class BlockingTestMixin: ...@@ -77,11 +76,11 @@ class BlockingTestMixin:
self.fail("expected exception of kind %r" % self.fail("expected exception of kind %r" %
expected_exception_class) expected_exception_class)
finally: finally:
self.t.join(10) # make sure the thread terminates thread.join(10) # make sure the thread terminates
if self.t.is_alive(): if thread.is_alive():
self.fail("trigger function '%r' appeared to not return" % self.fail("trigger function '%r' appeared to not return" %
trigger_func) trigger_func)
if not self.t.startedEvent.is_set(): if not thread.startedEvent.is_set():
self.fail("trigger thread ended but event never set") self.fail("trigger thread ended but event never set")
...@@ -159,8 +158,11 @@ class BaseQueueTestMixin(BlockingTestMixin): ...@@ -159,8 +158,11 @@ class BaseQueueTestMixin(BlockingTestMixin):
def queue_join_test(self, q): def queue_join_test(self, q):
self.cum = 0 self.cum = 0
threads = []
for i in (0,1): for i in (0,1):
threading.Thread(target=self.worker, args=(q,)).start() thread = threading.Thread(target=self.worker, args=(q,))
thread.start()
threads.append(thread)
for i in range(100): for i in range(100):
q.put(i) q.put(i)
q.join() q.join()
...@@ -169,6 +171,8 @@ class BaseQueueTestMixin(BlockingTestMixin): ...@@ -169,6 +171,8 @@ class BaseQueueTestMixin(BlockingTestMixin):
for i in (0,1): for i in (0,1):
q.put(-1) # instruct the threads to close q.put(-1) # instruct the threads to close
q.join() # verify that you can join twice q.join() # verify that you can join twice
for thread in threads:
thread.join()
def test_queue_task_done(self): def test_queue_task_done(self):
# Test to make sure a queue task completed successfully. # Test to make sure a queue task completed successfully.
......
...@@ -3,11 +3,13 @@ from test import support ...@@ -3,11 +3,13 @@ from test import support
from contextlib import closing from contextlib import closing
import enum import enum
import gc import gc
import os
import pickle import pickle
import random
import select import select
import signal import signal
import socket import socket
import struct import statistics
import subprocess import subprocess
import traceback import traceback
import sys, os, time, errno import sys, os, time, errno
...@@ -370,7 +372,6 @@ class WakeupSocketSignalTests(unittest.TestCase): ...@@ -370,7 +372,6 @@ class WakeupSocketSignalTests(unittest.TestCase):
signal.signal(signum, handler) signal.signal(signum, handler)
read, write = socket.socketpair() read, write = socket.socketpair()
read.setblocking(False)
write.setblocking(False) write.setblocking(False)
signal.set_wakeup_fd(write.fileno()) signal.set_wakeup_fd(write.fileno())
...@@ -615,6 +616,15 @@ class ItimerTest(unittest.TestCase): ...@@ -615,6 +616,15 @@ class ItimerTest(unittest.TestCase):
# and the handler should have been called # and the handler should have been called
self.assertEqual(self.hndl_called, True) self.assertEqual(self.hndl_called, True)
def test_setitimer_tiny(self):
# bpo-30807: C setitimer() takes a microsecond-resolution interval.
# Check that float -> timeval conversion doesn't round
# the interval down to zero, which would disable the timer.
self.itimer = signal.ITIMER_REAL
signal.setitimer(self.itimer, 1e-6)
time.sleep(1)
self.assertEqual(self.hndl_called, True)
class PendingSignalsTests(unittest.TestCase): class PendingSignalsTests(unittest.TestCase):
""" """
...@@ -950,6 +960,135 @@ class PendingSignalsTests(unittest.TestCase): ...@@ -950,6 +960,135 @@ class PendingSignalsTests(unittest.TestCase):
(exitcode, stdout)) (exitcode, stdout))
class StressTest(unittest.TestCase):
"""
Stress signal delivery, especially when a signal arrives in
the middle of recomputing the signal state or executing
previously tripped signal handlers.
"""
def setsig(self, signum, handler):
old_handler = signal.signal(signum, handler)
self.addCleanup(signal.signal, signum, old_handler)
def measure_itimer_resolution(self):
N = 20
times = []
def handler(signum=None, frame=None):
if len(times) < N:
times.append(time.perf_counter())
# 1 µs is the smallest possible timer interval,
# we want to measure what the concrete duration
# will be on this platform
signal.setitimer(signal.ITIMER_REAL, 1e-6)
self.addCleanup(signal.setitimer, signal.ITIMER_REAL, 0)
self.setsig(signal.SIGALRM, handler)
handler()
while len(times) < N:
time.sleep(1e-3)
durations = [times[i+1] - times[i] for i in range(len(times) - 1)]
med = statistics.median(durations)
if support.verbose:
print("detected median itimer() resolution: %.6f s." % (med,))
return med
def decide_itimer_count(self):
# Some systems have poor setitimer() resolution (for example
# measured around 20 ms. on FreeBSD 9), so decide on a reasonable
# number of sequential timers based on that.
reso = self.measure_itimer_resolution()
if reso <= 1e-4:
return 10000
elif reso <= 1e-2:
return 100
else:
self.skipTest("detected itimer resolution (%.3f s.) too high "
"(> 10 ms.) on this platform (or system too busy)"
% (reso,))
@unittest.skipUnless(hasattr(signal, "setitimer"),
"test needs setitimer()")
def test_stress_delivery_dependent(self):
"""
This test uses dependent signal handlers.
"""
N = self.decide_itimer_count()
sigs = []
def first_handler(signum, frame):
# 1e-6 is the minimum non-zero value for `setitimer()`.
# Choose a random delay so as to improve chances of
# triggering a race condition. Ideally the signal is received
# when inside critical signal-handling routines such as
# Py_MakePendingCalls().
signal.setitimer(signal.ITIMER_REAL, 1e-6 + random.random() * 1e-5)
def second_handler(signum=None, frame=None):
sigs.append(signum)
# Here on Linux, SIGPROF > SIGALRM > SIGUSR1. By using both
# ascending and descending sequences (SIGUSR1 then SIGALRM,
# SIGPROF then SIGALRM), we maximize chances of hitting a bug.
self.setsig(signal.SIGPROF, first_handler)
self.setsig(signal.SIGUSR1, first_handler)
self.setsig(signal.SIGALRM, second_handler) # for ITIMER_REAL
expected_sigs = 0
deadline = time.time() + 15.0
while expected_sigs < N:
os.kill(os.getpid(), signal.SIGPROF)
expected_sigs += 1
# Wait for handlers to run to avoid signal coalescing
while len(sigs) < expected_sigs and time.time() < deadline:
time.sleep(1e-5)
os.kill(os.getpid(), signal.SIGUSR1)
expected_sigs += 1
while len(sigs) < expected_sigs and time.time() < deadline:
time.sleep(1e-5)
# All ITIMER_REAL signals should have been delivered to the
# Python handler
self.assertEqual(len(sigs), N, "Some signals were lost")
@unittest.skipUnless(hasattr(signal, "setitimer"),
"test needs setitimer()")
def test_stress_delivery_simultaneous(self):
"""
This test uses simultaneous signal handlers.
"""
N = self.decide_itimer_count()
sigs = []
def handler(signum, frame):
sigs.append(signum)
self.setsig(signal.SIGUSR1, handler)
self.setsig(signal.SIGALRM, handler) # for ITIMER_REAL
expected_sigs = 0
deadline = time.time() + 15.0
while expected_sigs < N:
# Hopefully the SIGALRM will be received somewhere during
# initial processing of SIGUSR1.
signal.setitimer(signal.ITIMER_REAL, 1e-6 + random.random() * 1e-5)
os.kill(os.getpid(), signal.SIGUSR1)
expected_sigs += 2
# Wait for handlers to run to avoid signal coalescing
while len(sigs) < expected_sigs and time.time() < deadline:
time.sleep(1e-5)
# All ITIMER_REAL signals should have been delivered to the
# Python handler
self.assertEqual(len(sigs), N, "Some signals were lost")
def tearDownModule(): def tearDownModule():
support.reap_children() support.reap_children()
......
...@@ -18,6 +18,11 @@ import textwrap ...@@ -18,6 +18,11 @@ import textwrap
import unittest import unittest
from test import support, mock_socket from test import support, mock_socket
from unittest.mock import Mock
HOST = "localhost"
HOSTv4 = "127.0.0.1"
HOSTv6 = "::1"
try: try:
import threading import threading
...@@ -569,6 +574,33 @@ class NonConnectingTests(unittest.TestCase): ...@@ -569,6 +574,33 @@ class NonConnectingTests(unittest.TestCase):
"localhost:bogus") "localhost:bogus")
class DefaultArgumentsTests(unittest.TestCase):
def setUp(self):
self.msg = EmailMessage()
self.msg['From'] = 'Páolo <főo@bar.com>'
self.smtp = smtplib.SMTP()
self.smtp.ehlo = Mock(return_value=(200, 'OK'))
self.smtp.has_extn, self.smtp.sendmail = Mock(), Mock()
def testSendMessage(self):
expected_mail_options = ('SMTPUTF8', 'BODY=8BITMIME')
self.smtp.send_message(self.msg)
self.smtp.send_message(self.msg)
self.assertEqual(self.smtp.sendmail.call_args_list[0][0][3],
expected_mail_options)
self.assertEqual(self.smtp.sendmail.call_args_list[1][0][3],
expected_mail_options)
def testSendMessageWithMailOptions(self):
mail_options = ['STARTTLS']
expected_mail_options = ('STARTTLS', 'SMTPUTF8', 'BODY=8BITMIME')
self.smtp.send_message(self.msg, None, None, mail_options)
self.assertEqual(mail_options, ['STARTTLS'])
self.assertEqual(self.smtp.sendmail.call_args_list[0][0][3],
expected_mail_options)
# test response of client to a non-successful HELO message # test response of client to a non-successful HELO message
@unittest.skipUnless(threading, 'Threading required for this test.') @unittest.skipUnless(threading, 'Threading required for this test.')
class BadHELOServerTests(unittest.TestCase): class BadHELOServerTests(unittest.TestCase):
...@@ -604,7 +636,9 @@ class TooLongLineTests(unittest.TestCase): ...@@ -604,7 +636,9 @@ class TooLongLineTests(unittest.TestCase):
self.sock.settimeout(15) self.sock.settimeout(15)
self.port = support.bind_port(self.sock) self.port = support.bind_port(self.sock)
servargs = (self.evt, self.respdata, self.sock) servargs = (self.evt, self.respdata, self.sock)
threading.Thread(target=server, args=servargs).start() thread = threading.Thread(target=server, args=servargs)
thread.start()
self.addCleanup(thread.join)
self.evt.wait() self.evt.wait()
self.evt.clear() self.evt.clear()
...@@ -733,7 +767,7 @@ class SimSMTPChannel(smtpd.SMTPChannel): ...@@ -733,7 +767,7 @@ class SimSMTPChannel(smtpd.SMTPChannel):
try: try:
user, hashed_pass = logpass.split() user, hashed_pass = logpass.split()
except ValueError as e: except ValueError as e:
self.push('535 Splitting response {!r} into user and password' self.push('535 Splitting response {!r} into user and password '
'failed: {}'.format(logpass, e)) 'failed: {}'.format(logpass, e))
return False return False
valid_hashed_pass = hmac.HMAC( valid_hashed_pass = hmac.HMAC(
...@@ -816,6 +850,7 @@ class SimSMTPServer(smtpd.SMTPServer): ...@@ -816,6 +850,7 @@ class SimSMTPServer(smtpd.SMTPServer):
def __init__(self, *args, **kw): def __init__(self, *args, **kw):
self._extra_features = [] self._extra_features = []
self._addresses = {}
smtpd.SMTPServer.__init__(self, *args, **kw) smtpd.SMTPServer.__init__(self, *args, **kw)
def handle_accepted(self, conn, addr): def handle_accepted(self, conn, addr):
...@@ -824,7 +859,8 @@ class SimSMTPServer(smtpd.SMTPServer): ...@@ -824,7 +859,8 @@ class SimSMTPServer(smtpd.SMTPServer):
decode_data=self._decode_data) decode_data=self._decode_data)
def process_message(self, peer, mailfrom, rcpttos, data): def process_message(self, peer, mailfrom, rcpttos, data):
pass self._addresses['from'] = mailfrom
self._addresses['tos'] = rcpttos
def add_feature(self, feature): def add_feature(self, feature):
self._extra_features.append(feature) self._extra_features.append(feature)
...@@ -1064,6 +1100,34 @@ class SMTPSimTests(unittest.TestCase): ...@@ -1064,6 +1100,34 @@ class SMTPSimTests(unittest.TestCase):
self.assertRaises(UnicodeEncodeError, smtp.sendmail, 'Alice', 'Böb', '') self.assertRaises(UnicodeEncodeError, smtp.sendmail, 'Alice', 'Böb', '')
self.assertRaises(UnicodeEncodeError, smtp.mail, 'Älice') self.assertRaises(UnicodeEncodeError, smtp.mail, 'Älice')
def test_send_message_error_on_non_ascii_addrs_if_no_smtputf8(self):
# This test is located here and not in the SMTPUTF8SimTests
# class because it needs a "regular" SMTP server to work
msg = EmailMessage()
msg['From'] = "Páolo <főo@bar.com>"
msg['To'] = 'Dinsdale'
msg['Subject'] = 'Nudge nudge, wink, wink \u1F609'
smtp = smtplib.SMTP(
HOST, self.port, local_hostname='localhost', timeout=3)
self.addCleanup(smtp.close)
with self.assertRaises(smtplib.SMTPNotSupportedError):
smtp.send_message(msg)
def test_name_field_not_included_in_envelop_addresses(self):
smtp = smtplib.SMTP(
HOST, self.port, local_hostname='localhost', timeout=3
)
self.addCleanup(smtp.close)
message = EmailMessage()
message['From'] = email.utils.formataddr(('Michaël', 'michael@example.com'))
message['To'] = email.utils.formataddr(('René', 'rene@example.com'))
self.assertDictEqual(smtp.send_message(message), {})
self.assertEqual(self.serv._addresses['from'], 'michael@example.com')
self.assertEqual(self.serv._addresses['tos'], ['rene@example.com'])
class SimSMTPUTF8Server(SimSMTPServer): class SimSMTPUTF8Server(SimSMTPServer):
...@@ -1194,17 +1258,6 @@ class SMTPUTF8SimTests(unittest.TestCase): ...@@ -1194,17 +1258,6 @@ class SMTPUTF8SimTests(unittest.TestCase):
self.assertIn('SMTPUTF8', self.serv.last_mail_options) self.assertIn('SMTPUTF8', self.serv.last_mail_options)
self.assertEqual(self.serv.last_rcpt_options, []) self.assertEqual(self.serv.last_rcpt_options, [])
def test_send_message_error_on_non_ascii_addrs_if_no_smtputf8(self):
msg = EmailMessage()
msg['From'] = "Páolo <főo@bar.com>"
msg['To'] = 'Dinsdale'
msg['Subject'] = 'Nudge nudge, wink, wink \u1F609'
smtp = smtplib.SMTP(
HOST, self.port, local_hostname='localhost', timeout=3)
self.addCleanup(smtp.close)
self.assertRaises(smtplib.SMTPNotSupportedError,
smtp.send_message(msg))
EXPECTED_RESPONSE = encode_base64(b'\0psu\0doesnotexist', eol='') EXPECTED_RESPONSE = encode_base64(b'\0psu\0doesnotexist', eol='')
...@@ -1273,18 +1326,5 @@ class SMTPAUTHInitialResponseSimTests(unittest.TestCase): ...@@ -1273,18 +1326,5 @@ class SMTPAUTHInitialResponseSimTests(unittest.TestCase):
self.assertEqual(code, 235) self.assertEqual(code, 235)
@support.reap_threads
def test_main(verbose=None):
support.run_unittest(
BadHELOServerTests,
DebuggingServerTests,
GeneralTests,
NonConnectingTests,
SMTPAUTHInitialResponseSimTests,
SMTPSimTests,
TooLongLineTests,
)
if __name__ == '__main__': if __name__ == '__main__':
test_main() unittest.main()
...@@ -32,6 +32,7 @@ except ImportError: ...@@ -32,6 +32,7 @@ except ImportError:
HOST = support.HOST HOST = support.HOST
MSG = 'Michael Gilfix was here\u1234\r\n'.encode('utf-8') ## test unicode string and carriage return MSG = 'Michael Gilfix was here\u1234\r\n'.encode('utf-8') ## test unicode string and carriage return
MAIN_TIMEOUT = 60.0
try: try:
import _thread as thread import _thread as thread
...@@ -245,6 +246,9 @@ class ThreadableTest: ...@@ -245,6 +246,9 @@ class ThreadableTest:
self.server_ready.set() self.server_ready.set()
def _setUp(self): def _setUp(self):
self.wait_threads = support.wait_threads_exit()
self.wait_threads.__enter__()
self.server_ready = threading.Event() self.server_ready = threading.Event()
self.client_ready = threading.Event() self.client_ready = threading.Event()
self.done = threading.Event() self.done = threading.Event()
...@@ -271,6 +275,7 @@ class ThreadableTest: ...@@ -271,6 +275,7 @@ class ThreadableTest:
def _tearDown(self): def _tearDown(self):
self.__tearDown() self.__tearDown()
self.done.wait() self.done.wait()
self.wait_threads.__exit__(None, None, None)
if self.queue.qsize(): if self.queue.qsize():
exc = self.queue.get() exc = self.queue.get()
...@@ -811,11 +816,6 @@ class GeneralModuleTests(unittest.TestCase): ...@@ -811,11 +816,6 @@ class GeneralModuleTests(unittest.TestCase):
self.fail("Error testing host resolution mechanisms. (fqdn: %s, all: %s)" % (fqhn, repr(all_host_names))) self.fail("Error testing host resolution mechanisms. (fqdn: %s, all: %s)" % (fqhn, repr(all_host_names)))
def test_host_resolution(self): def test_host_resolution(self):
for addr in ['0.1.1.~1', '1+.1.1.1', '::1q', '::1::2',
'1:1:1:1:1:1:1:1:1']:
self.assertRaises(OSError, socket.gethostbyname, addr)
self.assertRaises(OSError, socket.gethostbyaddr, addr)
for addr in [support.HOST, '10.0.0.1', '255.255.255.255']: for addr in [support.HOST, '10.0.0.1', '255.255.255.255']:
self.assertEqual(socket.gethostbyname(addr), addr) self.assertEqual(socket.gethostbyname(addr), addr)
...@@ -824,6 +824,21 @@ class GeneralModuleTests(unittest.TestCase): ...@@ -824,6 +824,21 @@ class GeneralModuleTests(unittest.TestCase):
for host in [support.HOST]: for host in [support.HOST]:
self.assertIn(host, socket.gethostbyaddr(host)[2]) self.assertIn(host, socket.gethostbyaddr(host)[2])
def test_host_resolution_bad_address(self):
# These are all malformed IP addresses and expected not to resolve to
# any result. But some ISPs, e.g. AWS, may successfully resolve these
# IPs.
explanation = (
"resolving an invalid IP address did not raise OSError; "
"can be caused by a broken DNS server"
)
for addr in ['0.1.1.~1', '1+.1.1.1', '::1q', '::1::2',
'1:1:1:1:1:1:1:1:1']:
with self.assertRaises(OSError):
socket.gethostbyname(addr)
with self.assertRaises(OSError, msg=explanation):
socket.gethostbyaddr(addr)
@unittest.skipUnless(hasattr(socket, 'sethostname'), "test needs socket.sethostname()") @unittest.skipUnless(hasattr(socket, 'sethostname'), "test needs socket.sethostname()")
@unittest.skipUnless(hasattr(socket, 'gethostname'), "test needs socket.gethostname()") @unittest.skipUnless(hasattr(socket, 'gethostname'), "test needs socket.gethostname()")
def test_sethostname(self): def test_sethostname(self):
...@@ -904,6 +919,7 @@ class GeneralModuleTests(unittest.TestCase): ...@@ -904,6 +919,7 @@ class GeneralModuleTests(unittest.TestCase):
self.assertEqual(swapped & mask, mask) self.assertEqual(swapped & mask, mask)
self.assertRaises(OverflowError, func, 1<<34) self.assertRaises(OverflowError, func, 1<<34)
@support.cpython_only
def testNtoHErrors(self): def testNtoHErrors(self):
good_values = [ 1, 2, 3, 1, 2, 3 ] good_values = [ 1, 2, 3, 1, 2, 3 ]
bad_values = [ -1, -2, -3, -1, -2, -3 ] bad_values = [ -1, -2, -3, -1, -2, -3 ]
...@@ -1354,7 +1370,7 @@ class GeneralModuleTests(unittest.TestCase): ...@@ -1354,7 +1370,7 @@ class GeneralModuleTests(unittest.TestCase):
socket.gethostbyname(domain) socket.gethostbyname(domain)
socket.gethostbyname_ex(domain) socket.gethostbyname_ex(domain)
socket.getaddrinfo(domain,0,socket.AF_UNSPEC,socket.SOCK_STREAM) socket.getaddrinfo(domain,0,socket.AF_UNSPEC,socket.SOCK_STREAM)
# this may not work if the forward lookup choses the IPv6 address, as that doesn't # this may not work if the forward lookup chooses the IPv6 address, as that doesn't
# have a reverse entry yet # have a reverse entry yet
# socket.gethostbyaddr('испытание.python.org') # socket.gethostbyaddr('испытание.python.org')
...@@ -1766,33 +1782,6 @@ class RDSTest(ThreadedRDSSocketTest): ...@@ -1766,33 +1782,6 @@ class RDSTest(ThreadedRDSSocketTest):
self.data = b'select' self.data = b'select'
self.cli.sendto(self.data, 0, (HOST, self.port)) self.cli.sendto(self.data, 0, (HOST, self.port))
def testCongestion(self):
# wait until the sender is done
self.evt.wait()
def _testCongestion(self):
# test the behavior in case of congestion
self.data = b'fill'
self.cli.setblocking(False)
try:
# try to lower the receiver's socket buffer size
self.cli.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 16384)
except OSError:
pass
with self.assertRaises(OSError) as cm:
try:
# fill the receiver's socket buffer
while True:
self.cli.sendto(self.data, 0, (HOST, self.port))
finally:
# signal the receiver we're done
self.evt.set()
# sendto() should have failed with ENOBUFS
self.assertEqual(cm.exception.errno, errno.ENOBUFS)
# and we should have received a congestion notification through poll
r, w, x = select.select([self.serv], [], [], 3.0)
self.assertIn(self.serv, r)
@unittest.skipUnless(thread, 'Threading required for this test.') @unittest.skipUnless(thread, 'Threading required for this test.')
class BasicTCPTest(SocketConnectedTest): class BasicTCPTest(SocketConnectedTest):
...@@ -2307,9 +2296,18 @@ class SendmsgStreamTests(SendmsgTests): ...@@ -2307,9 +2296,18 @@ class SendmsgStreamTests(SendmsgTests):
def _testSendmsgTimeout(self): def _testSendmsgTimeout(self):
try: try:
self.cli_sock.settimeout(0.03) self.cli_sock.settimeout(0.03)
with self.assertRaises(socket.timeout): try:
while True: while True:
self.sendmsgToServer([b"a"*512]) self.sendmsgToServer([b"a"*512])
except socket.timeout:
pass
except OSError as exc:
if exc.errno != errno.ENOMEM:
raise
# bpo-33937 the test randomly fails on Travis CI with
# "OSError: [Errno 12] Cannot allocate memory"
else:
self.fail("socket.timeout not raised")
finally: finally:
self.misc_event.set() self.misc_event.set()
...@@ -2332,8 +2330,10 @@ class SendmsgStreamTests(SendmsgTests): ...@@ -2332,8 +2330,10 @@ class SendmsgStreamTests(SendmsgTests):
with self.assertRaises(OSError) as cm: with self.assertRaises(OSError) as cm:
while True: while True:
self.sendmsgToServer([b"a"*512], [], socket.MSG_DONTWAIT) self.sendmsgToServer([b"a"*512], [], socket.MSG_DONTWAIT)
# bpo-33937: catch also ENOMEM, the test randomly fails on Travis CI
# with "OSError: [Errno 12] Cannot allocate memory"
self.assertIn(cm.exception.errno, self.assertIn(cm.exception.errno,
(errno.EAGAIN, errno.EWOULDBLOCK)) (errno.EAGAIN, errno.EWOULDBLOCK, errno.ENOMEM))
finally: finally:
self.misc_event.set() self.misc_event.set()
...@@ -2867,10 +2867,11 @@ class SCMRightsTest(SendrecvmsgServerTimeoutBase): ...@@ -2867,10 +2867,11 @@ class SCMRightsTest(SendrecvmsgServerTimeoutBase):
def testFDPassSeparateMinSpace(self): def testFDPassSeparateMinSpace(self):
# Pass two FDs in two separate arrays, receiving them into the # Pass two FDs in two separate arrays, receiving them into the
# minimum space for two arrays. # minimum space for two arrays.
self.checkRecvmsgFDs(2, num_fds = 2
self.checkRecvmsgFDs(num_fds,
self.doRecvmsg(self.serv_sock, len(MSG), self.doRecvmsg(self.serv_sock, len(MSG),
socket.CMSG_SPACE(SIZEOF_INT) + socket.CMSG_SPACE(SIZEOF_INT) +
socket.CMSG_LEN(SIZEOF_INT)), socket.CMSG_LEN(SIZEOF_INT * num_fds)),
maxcmsgs=2, ignoreflags=socket.MSG_CTRUNC) maxcmsgs=2, ignoreflags=socket.MSG_CTRUNC)
@testFDPassSeparateMinSpace.client_skip @testFDPassSeparateMinSpace.client_skip
...@@ -3694,7 +3695,6 @@ class InterruptedTimeoutBase(unittest.TestCase): ...@@ -3694,7 +3695,6 @@ class InterruptedTimeoutBase(unittest.TestCase):
orig_alrm_handler = signal.signal(signal.SIGALRM, orig_alrm_handler = signal.signal(signal.SIGALRM,
lambda signum, frame: 1 / 0) lambda signum, frame: 1 / 0)
self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler) self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler)
self.addCleanup(self.setAlarm, 0)
# Timeout for socket operations # Timeout for socket operations
timeout = 4.0 timeout = 4.0
...@@ -3731,9 +3731,12 @@ class InterruptedRecvTimeoutTest(InterruptedTimeoutBase, UDPTestBase): ...@@ -3731,9 +3731,12 @@ class InterruptedRecvTimeoutTest(InterruptedTimeoutBase, UDPTestBase):
def checkInterruptedRecv(self, func, *args, **kwargs): def checkInterruptedRecv(self, func, *args, **kwargs):
# Check that func(*args, **kwargs) raises # Check that func(*args, **kwargs) raises
# errno of EINTR when interrupted by a signal. # errno of EINTR when interrupted by a signal.
self.setAlarm(self.alarm_time) try:
with self.assertRaises(ZeroDivisionError) as cm: self.setAlarm(self.alarm_time)
func(*args, **kwargs) with self.assertRaises(ZeroDivisionError) as cm:
func(*args, **kwargs)
finally:
self.setAlarm(0)
def testInterruptedRecvTimeout(self): def testInterruptedRecvTimeout(self):
self.checkInterruptedRecv(self.serv.recv, 1024) self.checkInterruptedRecv(self.serv.recv, 1024)
...@@ -3789,10 +3792,13 @@ class InterruptedSendTimeoutTest(InterruptedTimeoutBase, ...@@ -3789,10 +3792,13 @@ class InterruptedSendTimeoutTest(InterruptedTimeoutBase,
# Check that func(*args, **kwargs), run in a loop, raises # Check that func(*args, **kwargs), run in a loop, raises
# OSError with an errno of EINTR when interrupted by a # OSError with an errno of EINTR when interrupted by a
# signal. # signal.
with self.assertRaises(ZeroDivisionError) as cm: try:
while True: with self.assertRaises(ZeroDivisionError) as cm:
self.setAlarm(self.alarm_time) while True:
func(*args, **kwargs) self.setAlarm(self.alarm_time)
func(*args, **kwargs)
finally:
self.setAlarm(0)
# Issue #12958: The following tests have problems on OS X prior to 10.7 # Issue #12958: The following tests have problems on OS X prior to 10.7
@support.requires_mac_ver(10, 7) @support.requires_mac_ver(10, 7)
...@@ -3873,6 +3879,7 @@ class BasicSocketPairTest(SocketPairTest): ...@@ -3873,6 +3879,7 @@ class BasicSocketPairTest(SocketPairTest):
class NonBlockingTCPTests(ThreadedTCPSocketTest): class NonBlockingTCPTests(ThreadedTCPSocketTest):
def __init__(self, methodName='runTest'): def __init__(self, methodName='runTest'):
self.event = threading.Event()
ThreadedTCPSocketTest.__init__(self, methodName=methodName) ThreadedTCPSocketTest.__init__(self, methodName=methodName)
def testSetBlocking(self): def testSetBlocking(self):
...@@ -3947,22 +3954,27 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest): ...@@ -3947,22 +3954,27 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest):
def testAccept(self): def testAccept(self):
# Testing non-blocking accept # Testing non-blocking accept
self.serv.setblocking(0) self.serv.setblocking(0)
try:
conn, addr = self.serv.accept() # connect() didn't start: non-blocking accept() fails
except OSError: with self.assertRaises(BlockingIOError):
pass
else:
self.fail("Error trying to do non-blocking accept.")
read, write, err = select.select([self.serv], [], [])
if self.serv in read:
conn, addr = self.serv.accept() conn, addr = self.serv.accept()
self.assertIsNone(conn.gettimeout())
conn.close() self.event.set()
else:
read, write, err = select.select([self.serv], [], [], MAIN_TIMEOUT)
if self.serv not in read:
self.fail("Error trying to do accept after select.") self.fail("Error trying to do accept after select.")
# connect() completed: non-blocking accept() doesn't block
conn, addr = self.serv.accept()
self.addCleanup(conn.close)
self.assertIsNone(conn.gettimeout())
def _testAccept(self): def _testAccept(self):
time.sleep(0.1) # don't connect before event is set to check
# that non-blocking accept() raises BlockingIOError
self.event.wait()
self.cli.connect((HOST, self.port)) self.cli.connect((HOST, self.port))
def testConnect(self): def testConnect(self):
...@@ -3977,25 +3989,32 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest): ...@@ -3977,25 +3989,32 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest):
def testRecv(self): def testRecv(self):
# Testing non-blocking recv # Testing non-blocking recv
conn, addr = self.serv.accept() conn, addr = self.serv.accept()
self.addCleanup(conn.close)
conn.setblocking(0) conn.setblocking(0)
try:
msg = conn.recv(len(MSG)) # the server didn't send data yet: non-blocking recv() fails
except OSError: with self.assertRaises(BlockingIOError):
pass
else:
self.fail("Error trying to do non-blocking recv.")
read, write, err = select.select([conn], [], [])
if conn in read:
msg = conn.recv(len(MSG)) msg = conn.recv(len(MSG))
conn.close()
self.assertEqual(msg, MSG) self.event.set()
else:
read, write, err = select.select([conn], [], [], MAIN_TIMEOUT)
if conn not in read:
self.fail("Error during select call to non-blocking socket.") self.fail("Error during select call to non-blocking socket.")
# the server sent data yet: non-blocking recv() doesn't block
msg = conn.recv(len(MSG))
self.assertEqual(msg, MSG)
def _testRecv(self): def _testRecv(self):
self.cli.connect((HOST, self.port)) self.cli.connect((HOST, self.port))
time.sleep(0.1)
self.cli.send(MSG) # don't send anything before event is set to check
# that non-blocking recv() raises BlockingIOError
self.event.wait()
# send data: recv() will no longer block
self.cli.sendall(MSG)
@unittest.skipUnless(thread, 'Threading required for this test.') @unittest.skipUnless(thread, 'Threading required for this test.')
class FileObjectClassTestCase(SocketConnectedTest): class FileObjectClassTestCase(SocketConnectedTest):
...@@ -4199,12 +4218,12 @@ class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase): ...@@ -4199,12 +4218,12 @@ class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase):
self.write_file.write(self.write_msg) self.write_file.write(self.write_msg)
self.write_file.flush() self.write_file.flush()
@support.refcount_test
def testMakefileCloseSocketDestroy(self): def testMakefileCloseSocketDestroy(self):
if hasattr(sys, "getrefcount"): refcount_before = sys.getrefcount(self.cli_conn)
refcount_before = sys.getrefcount(self.cli_conn) self.read_file.close()
self.read_file.close() refcount_after = sys.getrefcount(self.cli_conn)
refcount_after = sys.getrefcount(self.cli_conn) self.assertEqual(refcount_before - 1, refcount_after)
self.assertEqual(refcount_before - 1, refcount_after)
def _testMakefileCloseSocketDestroy(self): def _testMakefileCloseSocketDestroy(self):
pass pass
...@@ -4237,7 +4256,7 @@ class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase): ...@@ -4237,7 +4256,7 @@ class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase):
self.write_file.write(self.write_msg) self.write_file.write(self.write_msg)
self.write_file.flush() self.write_file.flush()
self.evt2.set() self.evt2.set()
# Avoid cloding the socket before the server test has finished, # Avoid closing the socket before the server test has finished,
# otherwise system recv() will return 0 instead of EWOULDBLOCK. # otherwise system recv() will return 0 instead of EWOULDBLOCK.
self.serv_finished.wait(5.0) self.serv_finished.wait(5.0)
...@@ -4371,6 +4390,10 @@ class NetworkConnectionNoServer(unittest.TestCase): ...@@ -4371,6 +4390,10 @@ class NetworkConnectionNoServer(unittest.TestCase):
expected_errnos = [ errno.ECONNREFUSED, ] expected_errnos = [ errno.ECONNREFUSED, ]
if hasattr(errno, 'ENETUNREACH'): if hasattr(errno, 'ENETUNREACH'):
expected_errnos.append(errno.ENETUNREACH) expected_errnos.append(errno.ENETUNREACH)
if hasattr(errno, 'EADDRNOTAVAIL'):
# bpo-31910: socket.create_connection() fails randomly
# with EADDRNOTAVAIL on Travis CI
expected_errnos.append(errno.EADDRNOTAVAIL)
self.assertIn(cm.exception.errno, expected_errnos) self.assertIn(cm.exception.errno, expected_errnos)
...@@ -4519,8 +4542,8 @@ class TCPTimeoutTest(SocketTCPTest): ...@@ -4519,8 +4542,8 @@ class TCPTimeoutTest(SocketTCPTest):
raise Alarm raise Alarm
old_alarm = signal.signal(signal.SIGALRM, alarm_handler) old_alarm = signal.signal(signal.SIGALRM, alarm_handler)
try: try:
signal.alarm(2) # POSIX allows alarm to be up to 1 second early
try: try:
signal.alarm(2) # POSIX allows alarm to be up to 1 second early
foo = self.serv.accept() foo = self.serv.accept()
except socket.timeout: except socket.timeout:
self.fail("caught timeout instead of Alarm") self.fail("caught timeout instead of Alarm")
...@@ -4658,6 +4681,10 @@ class TestUnixDomain(unittest.TestCase): ...@@ -4658,6 +4681,10 @@ class TestUnixDomain(unittest.TestCase):
else: else:
raise raise
def testUnbound(self):
# Issue #30205
self.assertIn(self.sock.getsockname(), ('', None))
def testStrAddr(self): def testStrAddr(self):
# Test binding to and retrieving a normal string pathname. # Test binding to and retrieving a normal string pathname.
path = os.path.abspath(support.TESTFN) path = os.path.abspath(support.TESTFN)
...@@ -5345,11 +5372,10 @@ class SendfileUsingSendTest(ThreadedTCPSocketTest): ...@@ -5345,11 +5372,10 @@ class SendfileUsingSendTest(ThreadedTCPSocketTest):
def _testWithTimeoutTriggeredSend(self): def _testWithTimeoutTriggeredSend(self):
address = self.serv.getsockname() address = self.serv.getsockname()
file = open(support.TESTFN, 'rb') with open(support.TESTFN, 'rb') as file:
with socket.create_connection(address, timeout=0.01) as sock, \ with socket.create_connection(address, timeout=0.01) as sock:
file as file: meth = self.meth_from_sock(sock)
meth = self.meth_from_sock(sock) self.assertRaises(socket.timeout, meth, file)
self.assertRaises(socket.timeout, meth, file)
def testWithTimeoutTriggeredSend(self): def testWithTimeoutTriggeredSend(self):
conn = self.accept_conn() conn = self.accept_conn()
...@@ -5409,6 +5435,9 @@ class LinuxKernelCryptoAPI(unittest.TestCase): ...@@ -5409,6 +5435,9 @@ class LinuxKernelCryptoAPI(unittest.TestCase):
else: else:
return sock return sock
# bpo-31705: On kernel older than 4.5, sendto() failed with ENOKEY,
# at least on ppc64le architecture
@support.requires_linux_version(4, 5)
def test_sha256(self): def test_sha256(self):
expected = bytes.fromhex("ba7816bf8f01cfea414140de5dae2223b00361a396" expected = bytes.fromhex("ba7816bf8f01cfea414140de5dae2223b00361a396"
"177a9cb410ff61f20015ad") "177a9cb410ff61f20015ad")
...@@ -5468,7 +5497,7 @@ class LinuxKernelCryptoAPI(unittest.TestCase): ...@@ -5468,7 +5497,7 @@ class LinuxKernelCryptoAPI(unittest.TestCase):
op=socket.ALG_OP_ENCRYPT, iv=iv) op=socket.ALG_OP_ENCRYPT, iv=iv)
enc = op.recv(msglen * multiplier) enc = op.recv(msglen * multiplier)
self.assertEqual(len(enc), msglen * multiplier) self.assertEqual(len(enc), msglen * multiplier)
self.assertTrue(enc[:msglen], ciphertext) self.assertEqual(enc[:msglen], ciphertext)
op, _ = algo.accept() op, _ = algo.accept()
with op: with op:
...@@ -5478,7 +5507,7 @@ class LinuxKernelCryptoAPI(unittest.TestCase): ...@@ -5478,7 +5507,7 @@ class LinuxKernelCryptoAPI(unittest.TestCase):
self.assertEqual(len(dec), msglen * multiplier) self.assertEqual(len(dec), msglen * multiplier)
self.assertEqual(dec, msg * multiplier) self.assertEqual(dec, msg * multiplier)
@support.requires_linux_version(4, 3) # see test_aes_cbc @support.requires_linux_version(4, 9) # see issue29324
def test_aead_aes_gcm(self): def test_aead_aes_gcm(self):
key = bytes.fromhex('c939cc13397c1d37de6ae0e1cb7c423c') key = bytes.fromhex('c939cc13397c1d37de6ae0e1cb7c423c')
iv = bytes.fromhex('b3d8cc017cbb89b39e0f67e2') iv = bytes.fromhex('b3d8cc017cbb89b39e0f67e2')
...@@ -5501,8 +5530,7 @@ class LinuxKernelCryptoAPI(unittest.TestCase): ...@@ -5501,8 +5530,7 @@ class LinuxKernelCryptoAPI(unittest.TestCase):
op.sendmsg_afalg(op=socket.ALG_OP_ENCRYPT, iv=iv, op.sendmsg_afalg(op=socket.ALG_OP_ENCRYPT, iv=iv,
assoclen=assoclen, flags=socket.MSG_MORE) assoclen=assoclen, flags=socket.MSG_MORE)
op.sendall(assoc, socket.MSG_MORE) op.sendall(assoc, socket.MSG_MORE)
op.sendall(plain, socket.MSG_MORE) op.sendall(plain)
op.sendall(b'\x00' * taglen)
res = op.recv(assoclen + len(plain) + taglen) res = op.recv(assoclen + len(plain) + taglen)
self.assertEqual(expected_ct, res[assoclen:-taglen]) self.assertEqual(expected_ct, res[assoclen:-taglen])
self.assertEqual(expected_tag, res[-taglen:]) self.assertEqual(expected_tag, res[-taglen:])
...@@ -5510,7 +5538,7 @@ class LinuxKernelCryptoAPI(unittest.TestCase): ...@@ -5510,7 +5538,7 @@ class LinuxKernelCryptoAPI(unittest.TestCase):
# now with msg # now with msg
op, _ = algo.accept() op, _ = algo.accept()
with op: with op:
msg = assoc + plain + b'\x00' * taglen msg = assoc + plain
op.sendmsg_afalg([msg], op=socket.ALG_OP_ENCRYPT, iv=iv, op.sendmsg_afalg([msg], op=socket.ALG_OP_ENCRYPT, iv=iv,
assoclen=assoclen) assoclen=assoclen)
res = op.recv(assoclen + len(plain) + taglen) res = op.recv(assoclen + len(plain) + taglen)
...@@ -5521,7 +5549,7 @@ class LinuxKernelCryptoAPI(unittest.TestCase): ...@@ -5521,7 +5549,7 @@ class LinuxKernelCryptoAPI(unittest.TestCase):
pack_uint32 = struct.Struct('I').pack pack_uint32 = struct.Struct('I').pack
op, _ = algo.accept() op, _ = algo.accept()
with op: with op:
msg = assoc + plain + b'\x00' * taglen msg = assoc + plain
op.sendmsg( op.sendmsg(
[msg], [msg],
([socket.SOL_ALG, socket.ALG_SET_OP, pack_uint32(socket.ALG_OP_ENCRYPT)], ([socket.SOL_ALG, socket.ALG_SET_OP, pack_uint32(socket.ALG_OP_ENCRYPT)],
...@@ -5529,7 +5557,7 @@ class LinuxKernelCryptoAPI(unittest.TestCase): ...@@ -5529,7 +5557,7 @@ class LinuxKernelCryptoAPI(unittest.TestCase):
[socket.SOL_ALG, socket.ALG_SET_AEAD_ASSOCLEN, pack_uint32(assoclen)], [socket.SOL_ALG, socket.ALG_SET_AEAD_ASSOCLEN, pack_uint32(assoclen)],
) )
) )
res = op.recv(len(msg)) res = op.recv(len(msg) + taglen)
self.assertEqual(expected_ct, res[assoclen:-taglen]) self.assertEqual(expected_ct, res[assoclen:-taglen])
self.assertEqual(expected_tag, res[-taglen:]) self.assertEqual(expected_tag, res[-taglen:])
...@@ -5539,8 +5567,8 @@ class LinuxKernelCryptoAPI(unittest.TestCase): ...@@ -5539,8 +5567,8 @@ class LinuxKernelCryptoAPI(unittest.TestCase):
msg = assoc + expected_ct + expected_tag msg = assoc + expected_ct + expected_tag
op.sendmsg_afalg([msg], op=socket.ALG_OP_DECRYPT, iv=iv, op.sendmsg_afalg([msg], op=socket.ALG_OP_DECRYPT, iv=iv,
assoclen=assoclen) assoclen=assoclen)
res = op.recv(len(msg)) res = op.recv(len(msg) - taglen)
self.assertEqual(plain, res[assoclen:-taglen]) self.assertEqual(plain, res[assoclen:])
@support.requires_linux_version(4, 3) # see test_aes_cbc @support.requires_linux_version(4, 3) # see test_aes_cbc
def test_drbg_pr_sha256(self): def test_drbg_pr_sha256(self):
...@@ -5571,6 +5599,42 @@ class LinuxKernelCryptoAPI(unittest.TestCase): ...@@ -5571,6 +5599,42 @@ class LinuxKernelCryptoAPI(unittest.TestCase):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
sock.sendmsg_afalg(op=socket.ALG_OP_ENCRYPT, assoclen=-1) sock.sendmsg_afalg(op=socket.ALG_OP_ENCRYPT, assoclen=-1)
def test_length_restriction(self):
# bpo-35050, off-by-one error in length check
sock = socket.socket(socket.AF_ALG, socket.SOCK_SEQPACKET, 0)
self.addCleanup(sock.close)
# salg_type[14]
with self.assertRaises(FileNotFoundError):
sock.bind(("t" * 13, "name"))
with self.assertRaisesRegex(ValueError, "type too long"):
sock.bind(("t" * 14, "name"))
# salg_name[64]
with self.assertRaises(FileNotFoundError):
sock.bind(("type", "n" * 63))
with self.assertRaisesRegex(ValueError, "name too long"):
sock.bind(("type", "n" * 64))
@unittest.skipUnless(sys.platform.startswith("win"), "requires Windows")
class TestMSWindowsTCPFlags(unittest.TestCase):
knownTCPFlags = {
# available since long time ago
'TCP_MAXSEG',
'TCP_NODELAY',
# available starting with Windows 10 1607
'TCP_FASTOPEN',
# available starting with Windows 10 1703
'TCP_KEEPCNT',
}
def test_new_tcp_flags(self):
provided = [s for s in dir(socket) if s.startswith('TCP')]
unknown = [s for s in provided if s not in self.knownTCPFlags]
self.assertEqual([], unknown,
"New TCP flags were discovered. See bpo-32394 for more information")
def test_main(): def test_main():
tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest, tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest,
...@@ -5627,6 +5691,7 @@ def test_main(): ...@@ -5627,6 +5691,7 @@ def test_main():
SendfileUsingSendTest, SendfileUsingSendTest,
SendfileUsingSendfileTest, SendfileUsingSendfileTest,
]) ])
tests.append(TestMSWindowsTCPFlags)
thread_info = support.threading_setup() thread_info = support.threading_setup()
support.run_unittest(*tests) support.run_unittest(*tests)
......
...@@ -48,11 +48,11 @@ def receive(sock, n, timeout=20): ...@@ -48,11 +48,11 @@ def receive(sock, n, timeout=20):
if HAVE_UNIX_SOCKETS and HAVE_FORKING: if HAVE_UNIX_SOCKETS and HAVE_FORKING:
class ForkingUnixStreamServer(socketserver.ForkingMixIn, class ForkingUnixStreamServer(socketserver.ForkingMixIn,
socketserver.UnixStreamServer): socketserver.UnixStreamServer):
pass _block_on_close = True
class ForkingUnixDatagramServer(socketserver.ForkingMixIn, class ForkingUnixDatagramServer(socketserver.ForkingMixIn,
socketserver.UnixDatagramServer): socketserver.UnixDatagramServer):
pass _block_on_close = True
@contextlib.contextmanager @contextlib.contextmanager
...@@ -62,10 +62,14 @@ def simple_subprocess(testcase): ...@@ -62,10 +62,14 @@ def simple_subprocess(testcase):
if pid == 0: if pid == 0:
# Don't raise an exception; it would be caught by the test harness. # Don't raise an exception; it would be caught by the test harness.
os._exit(72) os._exit(72)
yield None try:
pid2, status = os.waitpid(pid, 0) yield None
testcase.assertEqual(pid2, pid) except:
testcase.assertEqual(72 << 8, status) raise
finally:
pid2, status = os.waitpid(pid, 0)
testcase.assertEqual(pid2, pid)
testcase.assertEqual(72 << 8, status)
@unittest.skipUnless(threading, 'Threading required for this test.') @unittest.skipUnless(threading, 'Threading required for this test.')
...@@ -101,6 +105,8 @@ class SocketServerTest(unittest.TestCase): ...@@ -101,6 +105,8 @@ class SocketServerTest(unittest.TestCase):
def make_server(self, addr, svrcls, hdlrbase): def make_server(self, addr, svrcls, hdlrbase):
class MyServer(svrcls): class MyServer(svrcls):
_block_on_close = True
def handle_error(self, request, client_address): def handle_error(self, request, client_address):
self.close_request(request) self.close_request(request)
raise raise
...@@ -144,6 +150,10 @@ class SocketServerTest(unittest.TestCase): ...@@ -144,6 +150,10 @@ class SocketServerTest(unittest.TestCase):
t.join() t.join()
server.server_close() server.server_close()
self.assertEqual(-1, server.socket.fileno()) self.assertEqual(-1, server.socket.fileno())
if HAVE_FORKING and isinstance(server, socketserver.ForkingMixIn):
# bpo-31151: Check that ForkingMixIn.server_close() waits until
# all children completed
self.assertFalse(server.active_children)
if verbose: print("done") if verbose: print("done")
def stream_examine(self, proto, addr): def stream_examine(self, proto, addr):
...@@ -292,6 +302,7 @@ class ErrorHandlerTest(unittest.TestCase): ...@@ -292,6 +302,7 @@ class ErrorHandlerTest(unittest.TestCase):
def tearDown(self): def tearDown(self):
test.support.unlink(test.support.TESTFN) test.support.unlink(test.support.TESTFN)
reap_children()
def test_sync_handled(self): def test_sync_handled(self):
BaseErrorTestServer(ValueError) BaseErrorTestServer(ValueError)
...@@ -329,6 +340,8 @@ class ErrorHandlerTest(unittest.TestCase): ...@@ -329,6 +340,8 @@ class ErrorHandlerTest(unittest.TestCase):
class BaseErrorTestServer(socketserver.TCPServer): class BaseErrorTestServer(socketserver.TCPServer):
_block_on_close = True
def __init__(self, exception): def __init__(self, exception):
self.exception = exception self.exception = exception
super().__init__((HOST, 0), BadHandler) super().__init__((HOST, 0), BadHandler)
...@@ -371,10 +384,7 @@ class ThreadingErrorTestServer(socketserver.ThreadingMixIn, ...@@ -371,10 +384,7 @@ class ThreadingErrorTestServer(socketserver.ThreadingMixIn,
if HAVE_FORKING: if HAVE_FORKING:
class ForkingErrorTestServer(socketserver.ForkingMixIn, BaseErrorTestServer): class ForkingErrorTestServer(socketserver.ForkingMixIn, BaseErrorTestServer):
def wait_done(self): _block_on_close = True
[child] = self.active_children
os.waitpid(child, 0)
self.active_children.clear()
class SocketWriterTest(unittest.TestCase): class SocketWriterTest(unittest.TestCase):
......
...@@ -17,7 +17,12 @@ import traceback ...@@ -17,7 +17,12 @@ import traceback
import asyncore import asyncore
import weakref import weakref
import platform import platform
import re
import functools import functools
try:
import ctypes
except ImportError:
ctypes = None
ssl = support.import_module("ssl") ssl = support.import_module("ssl")
...@@ -56,7 +61,6 @@ BYTES_CAPATH = os.fsencode(CAPATH) ...@@ -56,7 +61,6 @@ BYTES_CAPATH = os.fsencode(CAPATH)
CAFILE_NEURONIO = data_file("capath", "4e1295a3.0") CAFILE_NEURONIO = data_file("capath", "4e1295a3.0")
CAFILE_CACERT = data_file("capath", "5ed36f99.0") CAFILE_CACERT = data_file("capath", "5ed36f99.0")
# empty CRL # empty CRL
CRLFILE = data_file("revocation.crl") CRLFILE = data_file("revocation.crl")
...@@ -78,8 +82,9 @@ NONEXISTINGCERT = data_file("XXXnonexisting.pem") ...@@ -78,8 +82,9 @@ NONEXISTINGCERT = data_file("XXXnonexisting.pem")
BADKEY = data_file("badkey.pem") BADKEY = data_file("badkey.pem")
NOKIACERT = data_file("nokia.pem") NOKIACERT = data_file("nokia.pem")
NULLBYTECERT = data_file("nullbytecert.pem") NULLBYTECERT = data_file("nullbytecert.pem")
TALOS_INVALID_CRLDP = data_file("talos-2019-0758.pem")
DHFILE = data_file("dh1024.pem") DHFILE = data_file("ffdh3072.pem")
BYTES_DHFILE = os.fsencode(DHFILE) BYTES_DHFILE = os.fsencode(DHFILE)
# Not defined in all versions of OpenSSL # Not defined in all versions of OpenSSL
...@@ -87,6 +92,7 @@ OP_NO_COMPRESSION = getattr(ssl, "OP_NO_COMPRESSION", 0) ...@@ -87,6 +92,7 @@ OP_NO_COMPRESSION = getattr(ssl, "OP_NO_COMPRESSION", 0)
OP_SINGLE_DH_USE = getattr(ssl, "OP_SINGLE_DH_USE", 0) OP_SINGLE_DH_USE = getattr(ssl, "OP_SINGLE_DH_USE", 0)
OP_SINGLE_ECDH_USE = getattr(ssl, "OP_SINGLE_ECDH_USE", 0) OP_SINGLE_ECDH_USE = getattr(ssl, "OP_SINGLE_ECDH_USE", 0)
OP_CIPHER_SERVER_PREFERENCE = getattr(ssl, "OP_CIPHER_SERVER_PREFERENCE", 0) OP_CIPHER_SERVER_PREFERENCE = getattr(ssl, "OP_CIPHER_SERVER_PREFERENCE", 0)
OP_ENABLE_MIDDLEBOX_COMPAT = getattr(ssl, "OP_ENABLE_MIDDLEBOX_COMPAT", 0)
def handle_error(prefix): def handle_error(prefix):
...@@ -142,6 +148,38 @@ def skip_if_broken_ubuntu_ssl(func): ...@@ -142,6 +148,38 @@ def skip_if_broken_ubuntu_ssl(func):
else: else:
return func return func
def skip_if_openssl_cnf_minprotocol_gt_tls1(func):
"""Skip a test if the OpenSSL config MinProtocol is > TLSv1.
OS distros with an /etc/ssl/openssl.cnf and MinProtocol set often do so to
require TLSv1.2 or higher (Debian Buster). Some of our tests for older
protocol versions will fail under such a config.
Alternative workaround: Run this test in a process with
OPENSSL_CONF=/dev/null in the environment.
"""
@functools.wraps(func)
def f(*args, **kwargs):
openssl_cnf = os.environ.get("OPENSSL_CONF", "/etc/ssl/openssl.cnf")
try:
with open(openssl_cnf, "r") as config:
for line in config:
match = re.match(r"MinProtocol\s*=\s*(TLSv\d+\S*)", line)
if match:
tls_ver = match.group(1)
if tls_ver > "TLSv1":
raise unittest.SkipTest(
"%s has MinProtocol = %s which is > TLSv1." %
(openssl_cnf, tls_ver))
except (EnvironmentError, UnicodeDecodeError) as err:
# no config file found, etc.
if support.verbose:
sys.stdout.write("\n Could not scan %s for MinProtocol: %s\n"
% (openssl_cnf, err))
return func(*args, **kwargs)
return f
needs_sni = unittest.skipUnless(ssl.HAS_SNI, "SNI support needed for this test") needs_sni = unittest.skipUnless(ssl.HAS_SNI, "SNI support needed for this test")
...@@ -174,6 +212,13 @@ class BasicSocketTests(unittest.TestCase): ...@@ -174,6 +212,13 @@ class BasicSocketTests(unittest.TestCase):
ssl.OP_NO_COMPRESSION ssl.OP_NO_COMPRESSION
self.assertIn(ssl.HAS_SNI, {True, False}) self.assertIn(ssl.HAS_SNI, {True, False})
self.assertIn(ssl.HAS_ECDH, {True, False}) self.assertIn(ssl.HAS_ECDH, {True, False})
ssl.OP_NO_SSLv2
ssl.OP_NO_SSLv3
ssl.OP_NO_TLSv1
ssl.OP_NO_TLSv1_3
if ssl.OPENSSL_VERSION_INFO >= (1, 0, 1):
ssl.OP_NO_TLSv1_1
ssl.OP_NO_TLSv1_2
def test_str_for_enums(self): def test_str_for_enums(self):
# Make sure that the PROTOCOL_* constants have enum-like string # Make sure that the PROTOCOL_* constants have enum-like string
...@@ -242,6 +287,8 @@ class BasicSocketTests(unittest.TestCase): ...@@ -242,6 +287,8 @@ class BasicSocketTests(unittest.TestCase):
self.assertNotEqual(child_random, parent_random) self.assertNotEqual(child_random, parent_random)
maxDiff = None
def test_parse_cert(self): def test_parse_cert(self):
# note that this uses an 'unofficial' function in _ssl.c, # note that this uses an 'unofficial' function in _ssl.c,
# provided solely for this test, to exercise the certificate # provided solely for this test, to exercise the certificate
...@@ -256,9 +303,9 @@ class BasicSocketTests(unittest.TestCase): ...@@ -256,9 +303,9 @@ class BasicSocketTests(unittest.TestCase):
(('commonName', 'localhost'),)) (('commonName', 'localhost'),))
) )
# Note the next three asserts will fail if the keys are regenerated # Note the next three asserts will fail if the keys are regenerated
self.assertEqual(p['notAfter'], asn1time('Oct 5 23:01:56 2020 GMT')) self.assertEqual(p['notAfter'], asn1time('Aug 26 14:23:15 2028 GMT'))
self.assertEqual(p['notBefore'], asn1time('Oct 8 23:01:56 2010 GMT')) self.assertEqual(p['notBefore'], asn1time('Aug 29 14:23:15 2018 GMT'))
self.assertEqual(p['serialNumber'], 'D7C7381919AFC24E') self.assertEqual(p['serialNumber'], '98A7CF88C74A32ED')
self.assertEqual(p['subject'], self.assertEqual(p['subject'],
((('countryName', 'XY'),), ((('countryName', 'XY'),),
(('localityName', 'Castle Anthrax'),), (('localityName', 'Castle Anthrax'),),
...@@ -282,6 +329,27 @@ class BasicSocketTests(unittest.TestCase): ...@@ -282,6 +329,27 @@ class BasicSocketTests(unittest.TestCase):
self.assertEqual(p['crlDistributionPoints'], self.assertEqual(p['crlDistributionPoints'],
('http://SVRIntl-G3-crl.verisign.com/SVRIntlG3.crl',)) ('http://SVRIntl-G3-crl.verisign.com/SVRIntlG3.crl',))
def test_parse_cert_CVE_2019_5010(self):
p = ssl._ssl._test_decode_cert(TALOS_INVALID_CRLDP)
if support.verbose:
sys.stdout.write("\n" + pprint.pformat(p) + "\n")
self.assertEqual(
p,
{
'issuer': (
(('countryName', 'UK'),), (('commonName', 'cody-ca'),)),
'notAfter': 'Jun 14 18:00:58 2028 GMT',
'notBefore': 'Jun 18 18:00:58 2018 GMT',
'serialNumber': '02',
'subject': ((('countryName', 'UK'),),
(('commonName',
'codenomicon-vm-2.test.lal.cisco.com'),)),
'subjectAltName': (
('DNS', 'codenomicon-vm-2.test.lal.cisco.com'),),
'version': 3
}
)
def test_parse_cert_CVE_2013_4238(self): def test_parse_cert_CVE_2013_4238(self):
p = ssl._ssl._test_decode_cert(NULLBYTECERT) p = ssl._ssl._test_decode_cert(NULLBYTECERT)
if support.verbose: if support.verbose:
...@@ -397,6 +465,12 @@ class BasicSocketTests(unittest.TestCase): ...@@ -397,6 +465,12 @@ class BasicSocketTests(unittest.TestCase):
self.assertRaises(OSError, ss.recvfrom_into, bytearray(b'x'), 1) self.assertRaises(OSError, ss.recvfrom_into, bytearray(b'x'), 1)
self.assertRaises(OSError, ss.send, b'x') self.assertRaises(OSError, ss.send, b'x')
self.assertRaises(OSError, ss.sendto, b'x', ('0.0.0.0', 0)) self.assertRaises(OSError, ss.sendto, b'x', ('0.0.0.0', 0))
self.assertRaises(NotImplementedError, ss.dup)
self.assertRaises(NotImplementedError, ss.sendmsg,
[b'x'], (), 0, ('0.0.0.0', 0))
self.assertRaises(NotImplementedError, ss.recvmsg, 100)
self.assertRaises(NotImplementedError, ss.recvmsg_into,
[bytearray(100)])
def test_timeout(self): def test_timeout(self):
# Issue #8524: when creating an SSL socket, the timeout of the # Issue #8524: when creating an SSL socket, the timeout of the
...@@ -820,7 +894,7 @@ class BasicSocketTests(unittest.TestCase): ...@@ -820,7 +894,7 @@ class BasicSocketTests(unittest.TestCase):
self.cert_time_ok("Jan 5 09:34:61 2018 GMT", 1515144901) self.cert_time_ok("Jan 5 09:34:61 2018 GMT", 1515144901)
self.cert_time_fail("Jan 5 09:34:62 2018 GMT") # invalid seconds self.cert_time_fail("Jan 5 09:34:62 2018 GMT") # invalid seconds
# no special treatement for the special value: # no special treatment for the special value:
# 99991231235959Z (rfc 5280) # 99991231235959Z (rfc 5280)
self.cert_time_ok("Dec 31 23:59:59 9999 GMT", 253402300799.0) self.cert_time_ok("Dec 31 23:59:59 9999 GMT", 253402300799.0)
...@@ -890,12 +964,13 @@ class ContextTests(unittest.TestCase): ...@@ -890,12 +964,13 @@ class ContextTests(unittest.TestCase):
@skip_if_broken_ubuntu_ssl @skip_if_broken_ubuntu_ssl
def test_options(self): def test_options(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
# OP_ALL | OP_NO_SSLv2 | OP_NO_SSLv3 is the default value # OP_ALL | OP_NO_SSLv2 | OP_NO_SSLv3 is the default value
default = (ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3) default = (ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
# SSLContext also enables these by default # SSLContext also enables these by default
default |= (OP_NO_COMPRESSION | OP_CIPHER_SERVER_PREFERENCE | default |= (OP_NO_COMPRESSION | OP_CIPHER_SERVER_PREFERENCE |
OP_SINGLE_DH_USE | OP_SINGLE_ECDH_USE) OP_SINGLE_DH_USE | OP_SINGLE_ECDH_USE |
OP_ENABLE_MIDDLEBOX_COMPAT)
self.assertEqual(default, ctx.options) self.assertEqual(default, ctx.options)
ctx.options |= ssl.OP_NO_TLSv1 ctx.options |= ssl.OP_NO_TLSv1
self.assertEqual(default | ssl.OP_NO_TLSv1, ctx.options) self.assertEqual(default | ssl.OP_NO_TLSv1, ctx.options)
...@@ -1678,9 +1753,10 @@ class SimpleBackgroundTests(unittest.TestCase): ...@@ -1678,9 +1753,10 @@ class SimpleBackgroundTests(unittest.TestCase):
self.assertEqual(len(ctx.get_ca_certs()), 1) self.assertEqual(len(ctx.get_ca_certs()), 1)
@needs_sni @needs_sni
@unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_2"), "needs TLS 1.2")
def test_context_setget(self): def test_context_setget(self):
# Check that the context of a connected socket can be replaced. # Check that the context of a connected socket can be replaced.
ctx1 = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ctx1 = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
ctx2 = ssl.SSLContext(ssl.PROTOCOL_SSLv23) ctx2 = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
s = socket.socket(socket.AF_INET) s = socket.socket(socket.AF_INET)
with ctx1.wrap_socket(s) as ss: with ctx1.wrap_socket(s) as ss:
...@@ -1738,6 +1814,8 @@ class SimpleBackgroundTests(unittest.TestCase): ...@@ -1738,6 +1814,8 @@ class SimpleBackgroundTests(unittest.TestCase):
sslobj = ctx.wrap_bio(incoming, outgoing, False, 'localhost') sslobj = ctx.wrap_bio(incoming, outgoing, False, 'localhost')
self.assertIs(sslobj._sslobj.owner, sslobj) self.assertIs(sslobj._sslobj.owner, sslobj)
self.assertIsNone(sslobj.cipher()) self.assertIsNone(sslobj.cipher())
# cypthon implementation detail
# self.assertIsNone(sslobj.version())
self.assertIsNotNone(sslobj.shared_ciphers()) self.assertIsNotNone(sslobj.shared_ciphers())
self.assertRaises(ValueError, sslobj.getpeercert) self.assertRaises(ValueError, sslobj.getpeercert)
if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES: if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
...@@ -1745,6 +1823,7 @@ class SimpleBackgroundTests(unittest.TestCase): ...@@ -1745,6 +1823,7 @@ class SimpleBackgroundTests(unittest.TestCase):
self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake) self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
self.assertTrue(sslobj.cipher()) self.assertTrue(sslobj.cipher())
self.assertIsNotNone(sslobj.shared_ciphers()) self.assertIsNotNone(sslobj.shared_ciphers())
self.assertIsNotNone(sslobj.version())
self.assertTrue(sslobj.getpeercert()) self.assertTrue(sslobj.getpeercert())
if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES: if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
self.assertTrue(sslobj.get_channel_binding('tls-unique')) self.assertTrue(sslobj.get_channel_binding('tls-unique'))
...@@ -1795,34 +1874,6 @@ class NetworkedTests(unittest.TestCase): ...@@ -1795,34 +1874,6 @@ class NetworkedTests(unittest.TestCase):
_test_get_server_certificate(self, 'ipv6.google.com', 443) _test_get_server_certificate(self, 'ipv6.google.com', 443)
_test_get_server_certificate_fail(self, 'ipv6.google.com', 443) _test_get_server_certificate_fail(self, 'ipv6.google.com', 443)
def test_algorithms(self):
# Issue #8484: all algorithms should be available when verifying a
# certificate.
# SHA256 was added in OpenSSL 0.9.8
if ssl.OPENSSL_VERSION_INFO < (0, 9, 8, 0, 15):
self.skipTest("SHA256 not available on %r" % ssl.OPENSSL_VERSION)
# sha256.tbs-internet.com needs SNI to use the correct certificate
if not ssl.HAS_SNI:
self.skipTest("SNI needed for this test")
# https://sha2.hboeck.de/ was used until 2011-01-08 (no route to host)
remote = ("sha256.tbs-internet.com", 443)
sha256_cert = os.path.join(os.path.dirname(__file__), "sha256.pem")
with support.transient_internet("sha256.tbs-internet.com"):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
ctx.verify_mode = ssl.CERT_REQUIRED
ctx.load_verify_locations(sha256_cert)
s = ctx.wrap_socket(socket.socket(socket.AF_INET),
server_hostname="sha256.tbs-internet.com")
try:
s.connect(remote)
if support.verbose:
sys.stdout.write("\nCipher with %r is %r\n" %
(remote, s.cipher()))
sys.stdout.write("Certificate is:\n%s\n" %
pprint.pformat(s.getpeercert()))
finally:
s.close()
def _test_get_server_certificate(test, host, port, cert=None): def _test_get_server_certificate(test, host, port, cert=None):
pem = ssl.get_server_certificate((host, port)) pem = ssl.get_server_certificate((host, port))
...@@ -1873,15 +1924,34 @@ if _have_threads: ...@@ -1873,15 +1924,34 @@ if _have_threads:
self.sock, server_side=True) self.sock, server_side=True)
self.server.selected_npn_protocols.append(self.sslconn.selected_npn_protocol()) self.server.selected_npn_protocols.append(self.sslconn.selected_npn_protocol())
self.server.selected_alpn_protocols.append(self.sslconn.selected_alpn_protocol()) self.server.selected_alpn_protocols.append(self.sslconn.selected_alpn_protocol())
except (ssl.SSLError, ConnectionResetError) as e: except (ConnectionResetError, BrokenPipeError) as e:
# We treat ConnectionResetError as though it were an # We treat ConnectionResetError as though it were an
# SSLError - OpenSSL on Ubuntu abruptly closes the # SSLError - OpenSSL on Ubuntu abruptly closes the
# connection when asked to use an unsupported protocol. # connection when asked to use an unsupported protocol.
# #
# BrokenPipeError is raised in TLS 1.3 mode, when OpenSSL
# tries to send session tickets after handshake.
# https://github.com/openssl/openssl/issues/6342
self.server.conn_errors.append(str(e))
if self.server.chatty:
handle_error(
"\n server: bad connection attempt from " + repr(
self.addr) + ":\n")
self.running = False
self.close()
return False
except (ssl.SSLError, OSError) as e:
# OSError may occur with wrong protocols, e.g. both
# sides use PROTOCOL_TLS_SERVER.
#
# XXX Various errors can have happened here, for example # XXX Various errors can have happened here, for example
# a mismatching protocol version, an invalid certificate, # a mismatching protocol version, an invalid certificate,
# or a low-level bug. This should be made more discriminating. # or a low-level bug. This should be made more discriminating.
self.server.conn_errors.append(e) #
# bpo-31323: Store the exception as string to prevent
# a reference leak: server -> conn_errors -> exception
# -> traceback -> self (ConnectionHandler) -> server
self.server.conn_errors.append(str(e))
if self.server.chatty: if self.server.chatty:
handle_error("\n server: bad connection attempt from " + repr(self.addr) + ":\n") handle_error("\n server: bad connection attempt from " + repr(self.addr) + ":\n")
self.running = False self.running = False
...@@ -1970,6 +2040,24 @@ if _have_threads: ...@@ -1970,6 +2040,24 @@ if _have_threads:
sys.stdout.write(" server: read CB tls-unique from client, sending our CB data...\n") sys.stdout.write(" server: read CB tls-unique from client, sending our CB data...\n")
data = self.sslconn.get_channel_binding("tls-unique") data = self.sslconn.get_channel_binding("tls-unique")
self.write(repr(data).encode("us-ascii") + b"\n") self.write(repr(data).encode("us-ascii") + b"\n")
elif stripped == b'PHA':
if support.verbose and self.server.connectionchatty:
sys.stdout.write(
" server: initiating post handshake auth\n")
try:
self.sslconn.verify_client_post_handshake()
except ssl.SSLError as e:
self.write(repr(e).encode("us-ascii") + b"\n")
else:
self.write(b"OK\n")
elif stripped == b'HASCERT':
if self.sslconn.getpeercert() is not None:
self.write(b'TRUE\n')
else:
self.write(b'FALSE\n')
elif stripped == b'GETCERT':
cert = self.sslconn.getpeercert()
self.write(repr(cert).encode("us-ascii") + b"\n")
else: else:
if (support.verbose and if (support.verbose and
self.server.connectionchatty): self.server.connectionchatty):
...@@ -1977,6 +2065,16 @@ if _have_threads: ...@@ -1977,6 +2065,16 @@ if _have_threads:
sys.stdout.write(" server: read %r (%s), sending back %r (%s)...\n" sys.stdout.write(" server: read %r (%s), sending back %r (%s)...\n"
% (msg, ctype, msg.lower(), ctype)) % (msg, ctype, msg.lower(), ctype))
self.write(msg.lower()) self.write(msg.lower())
except ConnectionResetError:
# XXX: OpenSSL 1.1.1 sometimes raises ConnectionResetError
# when connection is not shut down gracefully.
if self.server.chatty and support.verbose:
sys.stdout.write(
" Connection reset by peer: {}\n".format(
self.addr)
)
self.close()
self.running = False
except OSError: except OSError:
if self.server.chatty: if self.server.chatty:
handle_error("Test server failure:\n") handle_error("Test server failure:\n")
...@@ -1996,7 +2094,7 @@ if _have_threads: ...@@ -1996,7 +2094,7 @@ if _have_threads:
else: else:
self.context = ssl.SSLContext(ssl_version self.context = ssl.SSLContext(ssl_version
if ssl_version is not None if ssl_version is not None
else ssl.PROTOCOL_TLSv1) else ssl.PROTOCOL_TLS)
self.context.verify_mode = (certreqs if certreqs is not None self.context.verify_mode = (certreqs if certreqs is not None
else ssl.CERT_NONE) else ssl.CERT_NONE)
if cacerts: if cacerts:
...@@ -2056,6 +2154,11 @@ if _have_threads: ...@@ -2056,6 +2154,11 @@ if _have_threads:
pass pass
except KeyboardInterrupt: except KeyboardInterrupt:
self.stop() self.stop()
except BaseException as e:
if support.verbose and self.chatty:
sys.stdout.write(
' connection handling failed: ' + repr(e) + '\n')
self.sock.close() self.sock.close()
def stop(self): def stop(self):
...@@ -2067,7 +2170,7 @@ if _have_threads: ...@@ -2067,7 +2170,7 @@ if _have_threads:
class EchoServer (asyncore.dispatcher): class EchoServer (asyncore.dispatcher):
class ConnectionHandler (asyncore.dispatcher_with_send): class ConnectionHandler(asyncore.dispatcher_with_send):
def __init__(self, conn, certfile): def __init__(self, conn, certfile):
self.socket = test_wrap_socket(conn, server_side=True, self.socket = test_wrap_socket(conn, server_side=True,
...@@ -2158,6 +2261,8 @@ if _have_threads: ...@@ -2158,6 +2261,8 @@ if _have_threads:
self.join() self.join()
if support.verbose: if support.verbose:
sys.stdout.write(" cleanup: successfully joined.\n") sys.stdout.write(" cleanup: successfully joined.\n")
# make sure that ConnectionHandler is removed from socket_map
asyncore.close_all(ignore_all=True)
def start (self, flag=None): def start (self, flag=None):
self.flag = flag self.flag = flag
...@@ -2468,10 +2573,10 @@ if _have_threads: ...@@ -2468,10 +2573,10 @@ if _have_threads:
connect to it with a wrong client certificate fails. connect to it with a wrong client certificate fails.
""" """
certfile = os.path.join(os.path.dirname(__file__) or os.curdir, certfile = os.path.join(os.path.dirname(__file__) or os.curdir,
"wrongcert.pem") "keycert.pem")
server = ThreadedEchoServer(CERTFILE, server = ThreadedEchoServer(SIGNED_CERTFILE,
certreqs=ssl.CERT_REQUIRED, certreqs=ssl.CERT_REQUIRED,
cacerts=CERTFILE, chatty=False, cacerts=SIGNING_CA, chatty=False,
connectionchatty=False) connectionchatty=False)
with server, \ with server, \
socket.socket() as sock, \ socket.socket() as sock, \
...@@ -2560,6 +2665,7 @@ if _have_threads: ...@@ -2560,6 +2665,7 @@ if _have_threads:
client_options=ssl.OP_NO_TLSv1) client_options=ssl.OP_NO_TLSv1)
@skip_if_broken_ubuntu_ssl @skip_if_broken_ubuntu_ssl
@skip_if_openssl_cnf_minprotocol_gt_tls1
def test_protocol_sslv23(self): def test_protocol_sslv23(self):
"""Connecting to an SSLv23 server with various client options""" """Connecting to an SSLv23 server with various client options"""
if support.verbose: if support.verbose:
...@@ -2637,6 +2743,7 @@ if _have_threads: ...@@ -2637,6 +2743,7 @@ if _have_threads:
@skip_if_broken_ubuntu_ssl @skip_if_broken_ubuntu_ssl
@unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_1"), @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_1"),
"TLS version 1.1 not supported.") "TLS version 1.1 not supported.")
@skip_if_openssl_cnf_minprotocol_gt_tls1
def test_protocol_tlsv1_1(self): def test_protocol_tlsv1_1(self):
"""Connecting to a TLSv1.1 server with various client options. """Connecting to a TLSv1.1 server with various client options.
Testing against older TLS versions.""" Testing against older TLS versions."""
...@@ -2905,16 +3012,25 @@ if _have_threads: ...@@ -2905,16 +3012,25 @@ if _have_threads:
s.send(data) s.send(data)
buffer = bytearray(len(data)) buffer = bytearray(len(data))
self.assertEqual(s.read(-1, buffer), len(data)) self.assertEqual(s.read(-1, buffer), len(data))
self.assertEqual(buffer, data) self.assertEqual(buffer, data) # sendall accepts bytes-like objects
try:
if ctypes is not None:
ubyte = ctypes.c_ubyte * len(data)
byteslike = ubyte.from_buffer_copy(data)
s.sendall(byteslike)
self.assertEqual(s.read(), data)
except:
s.close()
raise
# Make sure sendmsg et al are disallowed to avoid # Make sure sendmsg et al are disallowed to avoid
# inadvertent disclosure of data and/or corruption # inadvertent disclosure of data and/or corruption
# of the encrypted data stream # of the encrypted data stream
self.assertRaises(NotImplementedError, s.dup)
self.assertRaises(NotImplementedError, s.sendmsg, [b"data"]) self.assertRaises(NotImplementedError, s.sendmsg, [b"data"])
self.assertRaises(NotImplementedError, s.recvmsg, 100) self.assertRaises(NotImplementedError, s.recvmsg, 100)
self.assertRaises(NotImplementedError, self.assertRaises(NotImplementedError,
s.recvmsg_into, bytearray(100)) s.recvmsg_into, [bytearray(100)])
s.write(b"over\n") s.write(b"over\n")
self.assertRaises(ValueError, s.recv, -1) self.assertRaises(ValueError, s.recv, -1)
...@@ -3043,7 +3159,7 @@ if _have_threads: ...@@ -3043,7 +3159,7 @@ if _have_threads:
# Block on the accept and wait on the connection to close. # Block on the accept and wait on the connection to close.
evt.set() evt.set()
remote, peer = server.accept() remote, peer = server.accept()
remote.recv(1) remote.send(remote.recv(4))
t = threading.Thread(target=serve) t = threading.Thread(target=serve)
t.start() t.start()
...@@ -3051,6 +3167,8 @@ if _have_threads: ...@@ -3051,6 +3167,8 @@ if _have_threads:
evt.wait() evt.wait()
client = context.wrap_socket(socket.socket()) client = context.wrap_socket(socket.socket())
client.connect((host, port)) client.connect((host, port))
client.send(b'data')
client.recv()
client_addr = client.getsockname() client_addr = client.getsockname()
client.close() client.close()
t.join() t.join()
...@@ -3074,20 +3192,25 @@ if _have_threads: ...@@ -3074,20 +3192,25 @@ if _have_threads:
sock.do_handshake() sock.do_handshake()
self.assertEqual(cm.exception.errno, errno.ENOTCONN) self.assertEqual(cm.exception.errno, errno.ENOTCONN)
def test_default_ciphers(self): def test_no_shared_ciphers(self):
context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) server_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
try: server_context.load_cert_chain(SIGNED_CERTFILE)
# Force a set of weak ciphers on our client context client_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
context.set_ciphers("DES") client_context.verify_mode = ssl.CERT_REQUIRED
except ssl.SSLError: client_context.check_hostname = True
self.skipTest("no DES cipher available")
with ThreadedEchoServer(CERTFILE, # OpenSSL enables all TLS 1.3 ciphers, enforce TLS 1.2 for test
ssl_version=ssl.PROTOCOL_SSLv23, client_context.options |= ssl.OP_NO_TLSv1_3
chatty=False) as server: # Force different suites on client and master
with context.wrap_socket(socket.socket()) as s: client_context.set_ciphers("AES128")
server_context.set_ciphers("AES256")
with ThreadedEchoServer(context=server_context) as server:
with client_context.wrap_socket(
socket.socket(),
server_hostname="localhost") as s:
with self.assertRaises(OSError): with self.assertRaises(OSError):
s.connect((HOST, server.port)) s.connect((HOST, server.port))
self.assertEqual("NO_SHARED_CIPHER", server.conn_errors[0].reason) self.assertIn("no shared cipher", server.conn_errors[0])
def test_version_basic(self): def test_version_basic(self):
""" """
...@@ -3104,12 +3227,33 @@ if _have_threads: ...@@ -3104,12 +3227,33 @@ if _have_threads:
self.assertEqual(s.version(), 'TLSv1') self.assertEqual(s.version(), 'TLSv1')
self.assertIs(s.version(), None) self.assertIs(s.version(), None)
@unittest.skipUnless(ssl.HAS_TLSv1_3,
"test requires TLSv1.3 enabled OpenSSL")
def test_tls1_3(self):
context = ssl.SSLContext(ssl.PROTOCOL_TLS)
context.load_cert_chain(CERTFILE)
# disable all but TLS 1.3
context.options |= (
ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 | ssl.OP_NO_TLSv1_2
)
with ThreadedEchoServer(context=context) as server:
with context.wrap_socket(socket.socket()) as s:
s.connect((HOST, server.port))
self.assertIn(s.cipher()[0], [
'TLS_AES_256_GCM_SHA384',
'TLS_CHACHA20_POLY1305_SHA256',
'TLS_AES_128_GCM_SHA256',
])
@unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL") @unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL")
def test_default_ecdh_curve(self): def test_default_ecdh_curve(self):
# Issue #21015: elliptic curve-based Diffie Hellman key exchange # Issue #21015: elliptic curve-based Diffie Hellman key exchange
# should be enabled by default on SSL contexts. # should be enabled by default on SSL contexts.
context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
context.load_cert_chain(CERTFILE) context.load_cert_chain(CERTFILE)
# TLSv1.3 defaults to PFS key agreement and no longer has KEA in
# cipher name.
context.options |= ssl.OP_NO_TLSv1_3
# Prior to OpenSSL 1.0.0, ECDH ciphers have to be enabled # Prior to OpenSSL 1.0.0, ECDH ciphers have to be enabled
# explicitly using the 'ECCdraft' cipher alias. Otherwise, # explicitly using the 'ECCdraft' cipher alias. Otherwise,
# our default cipher list should prefer ECDH-based ciphers # our default cipher list should prefer ECDH-based ciphers
...@@ -3258,8 +3402,9 @@ if _have_threads: ...@@ -3258,8 +3402,9 @@ if _have_threads:
except ssl.SSLError as e: except ssl.SSLError as e:
stats = e stats = e
if expected is None and IS_OPENSSL_1_1: if (expected is None and IS_OPENSSL_1_1
# OpenSSL 1.1.0 raises handshake error and ssl.OPENSSL_VERSION_INFO < (1, 1, 0, 6)):
# OpenSSL 1.1.0 to 1.1.0e raises handshake error
self.assertIsInstance(stats, ssl.SSLError) self.assertIsInstance(stats, ssl.SSLError)
else: else:
msg = "failed trying %s (s) and %s (c).\n" \ msg = "failed trying %s (s) and %s (c).\n" \
...@@ -3423,19 +3568,25 @@ if _have_threads: ...@@ -3423,19 +3568,25 @@ if _have_threads:
if ssl.OPENSSL_VERSION_INFO >= (1, 0, 2): if ssl.OPENSSL_VERSION_INFO >= (1, 0, 2):
client_context.set_ciphers("AES128:AES256") client_context.set_ciphers("AES128:AES256")
server_context.set_ciphers("AES256") server_context.set_ciphers("AES256")
alg1 = "AES256" expected_algs = [
alg2 = "AES-256" "AES256", "AES-256"
]
else: else:
client_context.set_ciphers("AES:3DES") client_context.set_ciphers("AES:3DES")
server_context.set_ciphers("3DES") server_context.set_ciphers("3DES")
alg1 = "3DES" expected_algs = [
alg2 = "DES-CBC3" "3DES", "DES-CBC3"
]
if ssl.HAS_TLSv1_3:
# TLS 1.3 ciphers are always enabled
expected_algs.extend(["TLS_CHACHA20", "TLS_AES"])
stats = server_params_test(client_context, server_context) stats = server_params_test(client_context, server_context)
ciphers = stats['server_shared_ciphers'][0] ciphers = stats['server_shared_ciphers'][0]
self.assertGreater(len(ciphers), 0) self.assertGreater(len(ciphers), 0)
for name, tls_version, bits in ciphers: for name, tls_version, bits in ciphers:
if not alg1 in name.split("-") and alg2 not in name: if not any(alg in name for alg in expected_algs):
self.fail(name) self.fail(name)
def test_read_write_after_close_raises_valuerror(self): def test_read_write_after_close_raises_valuerror(self):
...@@ -3537,6 +3688,10 @@ if _have_threads: ...@@ -3537,6 +3688,10 @@ if _have_threads:
context2.load_verify_locations(CERTFILE) context2.load_verify_locations(CERTFILE)
context2.load_cert_chain(CERTFILE) context2.load_cert_chain(CERTFILE)
# TODO: session reuse does not work with TLS 1.3
context.options |= ssl.OP_NO_TLSv1_3
context2.options |= ssl.OP_NO_TLSv1_3
server = ThreadedEchoServer(context=context, chatty=False) server = ThreadedEchoServer(context=context, chatty=False)
with server: with server:
with context.wrap_socket(socket.socket()) as s: with context.wrap_socket(socket.socket()) as s:
...@@ -3576,6 +3731,194 @@ if _have_threads: ...@@ -3576,6 +3731,194 @@ if _have_threads:
'Session refers to a different SSLContext.') 'Session refers to a different SSLContext.')
def testing_context():
"""Create context
client_context, server_context, hostname = testing_context()
"""
client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
client_context.load_verify_locations(SIGNING_CA)
server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
server_context.load_cert_chain(SIGNED_CERTFILE)
server_context.load_verify_locations(SIGNING_CA)
return client_context, server_context, 'localhost'
@unittest.skipUnless(ssl.HAS_TLSv1_3, "Test needs TLS 1.3")
class TestPostHandshakeAuth(unittest.TestCase):
def test_pha_setter(self):
protocols = [
ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS_SERVER, ssl.PROTOCOL_TLS_CLIENT
]
for protocol in protocols:
ctx = ssl.SSLContext(protocol)
self.assertEqual(ctx.post_handshake_auth, False)
ctx.post_handshake_auth = True
self.assertEqual(ctx.post_handshake_auth, True)
ctx.verify_mode = ssl.CERT_REQUIRED
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
self.assertEqual(ctx.post_handshake_auth, True)
ctx.post_handshake_auth = False
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
self.assertEqual(ctx.post_handshake_auth, False)
ctx.verify_mode = ssl.CERT_OPTIONAL
ctx.post_handshake_auth = True
self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
self.assertEqual(ctx.post_handshake_auth, True)
def test_pha_required(self):
client_context, server_context, hostname = testing_context()
server_context.post_handshake_auth = True
server_context.verify_mode = ssl.CERT_REQUIRED
client_context.post_handshake_auth = True
client_context.load_cert_chain(SIGNED_CERTFILE)
server = ThreadedEchoServer(context=server_context, chatty=False)
with server:
with client_context.wrap_socket(socket.socket(),
server_hostname=hostname) as s:
s.connect((HOST, server.port))
s.write(b'HASCERT')
self.assertEqual(s.recv(1024), b'FALSE\n')
s.write(b'PHA')
self.assertEqual(s.recv(1024), b'OK\n')
s.write(b'HASCERT')
self.assertEqual(s.recv(1024), b'TRUE\n')
# PHA method just returns true when cert is already available
s.write(b'PHA')
self.assertEqual(s.recv(1024), b'OK\n')
s.write(b'GETCERT')
cert_text = s.recv(4096).decode('us-ascii')
self.assertIn('Python Software Foundation CA', cert_text)
def test_pha_required_nocert(self):
client_context, server_context, hostname = testing_context()
server_context.post_handshake_auth = True
server_context.verify_mode = ssl.CERT_REQUIRED
client_context.post_handshake_auth = True
server = ThreadedEchoServer(context=server_context, chatty=False)
with server:
with client_context.wrap_socket(socket.socket(),
server_hostname=hostname) as s:
s.connect((HOST, server.port))
s.write(b'PHA')
# receive CertificateRequest
self.assertEqual(s.recv(1024), b'OK\n')
# send empty Certificate + Finish
s.write(b'HASCERT')
# receive alert
with self.assertRaisesRegex(
ssl.SSLError,
'tlsv13 alert certificate required'):
s.recv(1024)
def test_pha_optional(self):
if support.verbose:
sys.stdout.write("\n")
client_context, server_context, hostname = testing_context()
server_context.post_handshake_auth = True
server_context.verify_mode = ssl.CERT_REQUIRED
client_context.post_handshake_auth = True
client_context.load_cert_chain(SIGNED_CERTFILE)
# check CERT_OPTIONAL
server_context.verify_mode = ssl.CERT_OPTIONAL
server = ThreadedEchoServer(context=server_context, chatty=False)
with server:
with client_context.wrap_socket(socket.socket(),
server_hostname=hostname) as s:
s.connect((HOST, server.port))
s.write(b'HASCERT')
self.assertEqual(s.recv(1024), b'FALSE\n')
s.write(b'PHA')
self.assertEqual(s.recv(1024), b'OK\n')
s.write(b'HASCERT')
self.assertEqual(s.recv(1024), b'TRUE\n')
def test_pha_optional_nocert(self):
if support.verbose:
sys.stdout.write("\n")
client_context, server_context, hostname = testing_context()
server_context.post_handshake_auth = True
server_context.verify_mode = ssl.CERT_OPTIONAL
client_context.post_handshake_auth = True
server = ThreadedEchoServer(context=server_context, chatty=False)
with server:
with client_context.wrap_socket(socket.socket(),
server_hostname=hostname) as s:
s.connect((HOST, server.port))
s.write(b'HASCERT')
self.assertEqual(s.recv(1024), b'FALSE\n')
s.write(b'PHA')
self.assertEqual(s.recv(1024), b'OK\n')
# optional doens't fail when client does not have a cert
s.write(b'HASCERT')
self.assertEqual(s.recv(1024), b'FALSE\n')
def test_pha_no_pha_client(self):
client_context, server_context, hostname = testing_context()
server_context.post_handshake_auth = True
server_context.verify_mode = ssl.CERT_REQUIRED
client_context.load_cert_chain(SIGNED_CERTFILE)
server = ThreadedEchoServer(context=server_context, chatty=False)
with server:
with client_context.wrap_socket(socket.socket(),
server_hostname=hostname) as s:
s.connect((HOST, server.port))
with self.assertRaisesRegex(ssl.SSLError, 'not server'):
s.verify_client_post_handshake()
s.write(b'PHA')
self.assertIn(b'extension not received', s.recv(1024))
def test_pha_no_pha_server(self):
# server doesn't have PHA enabled, cert is requested in handshake
client_context, server_context, hostname = testing_context()
server_context.verify_mode = ssl.CERT_REQUIRED
client_context.post_handshake_auth = True
client_context.load_cert_chain(SIGNED_CERTFILE)
server = ThreadedEchoServer(context=server_context, chatty=False)
with server:
with client_context.wrap_socket(socket.socket(),
server_hostname=hostname) as s:
s.connect((HOST, server.port))
s.write(b'HASCERT')
self.assertEqual(s.recv(1024), b'TRUE\n')
# PHA doesn't fail if there is already a cert
s.write(b'PHA')
self.assertEqual(s.recv(1024), b'OK\n')
s.write(b'HASCERT')
self.assertEqual(s.recv(1024), b'TRUE\n')
def test_pha_not_tls13(self):
# TLS 1.2
client_context, server_context, hostname = testing_context()
server_context.verify_mode = ssl.CERT_REQUIRED
client_context.options |= ssl.OP_NO_TLSv1_3
client_context.post_handshake_auth = True
client_context.load_cert_chain(SIGNED_CERTFILE)
server = ThreadedEchoServer(context=server_context, chatty=False)
with server:
with client_context.wrap_socket(socket.socket(),
server_hostname=hostname) as s:
s.connect((HOST, server.port))
# PHA fails for TLS != 1.3
s.write(b'PHA')
self.assertIn(b'WRONG_SSL_VERSION', s.recv(1024))
def test_main(verbose=False): def test_main(verbose=False):
if support.verbose: if support.verbose:
import warnings import warnings
...@@ -3628,6 +3971,7 @@ def test_main(verbose=False): ...@@ -3628,6 +3971,7 @@ def test_main(verbose=False):
thread_info = support.threading_setup() thread_info = support.threading_setup()
if thread_info: if thread_info:
tests.append(ThreadedTests) tests.append(ThreadedTests)
tests.append(TestPostHandshakeAuth)
try: try:
support.run_unittest(*tests) support.run_unittest(*tests)
......
...@@ -6,6 +6,7 @@ import sys ...@@ -6,6 +6,7 @@ import sys
import platform import platform
import signal import signal
import io import io
import itertools
import os import os
import errno import errno
import tempfile import tempfile
...@@ -16,17 +17,25 @@ import select ...@@ -16,17 +17,25 @@ import select
import shutil import shutil
import gc import gc
import textwrap import textwrap
from test.support import FakePath
try: try:
import ctypes import ctypes
except ImportError: except ImportError:
ctypes = None ctypes = None
else:
import ctypes.util
try: try:
import threading import threading
except ImportError: except ImportError:
threading = None threading = None
try:
import _testcapi
except ImportError:
_testcapi = None
if support.PGO: if support.PGO:
raise unittest.SkipTest("test is not helpful for PGO") raise unittest.SkipTest("test is not helpful for PGO")
...@@ -42,6 +51,10 @@ if mswindows: ...@@ -42,6 +51,10 @@ if mswindows:
else: else:
SETBINARY = '' SETBINARY = ''
NONEXISTING_CMD = ('nonexisting_i_hope',)
# Ignore errors that indicate the command was not found
NONEXISTING_ERRORS = (FileNotFoundError, NotADirectoryError, PermissionError)
class BaseTestCase(unittest.TestCase): class BaseTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -54,6 +67,8 @@ class BaseTestCase(unittest.TestCase): ...@@ -54,6 +67,8 @@ class BaseTestCase(unittest.TestCase):
inst.wait() inst.wait()
subprocess._cleanup() subprocess._cleanup()
self.assertFalse(subprocess._active, "subprocess._active not empty") self.assertFalse(subprocess._active, "subprocess._active not empty")
self.doCleanups()
support.reap_children()
def assertStderrEqual(self, stderr, expected, msg=None): def assertStderrEqual(self, stderr, expected, msg=None):
# In a debug build, stuff like "[6580 refs]" is printed to stderr at # In a debug build, stuff like "[6580 refs]" is printed to stderr at
...@@ -299,9 +314,9 @@ class ProcessTestCase(BaseTestCase): ...@@ -299,9 +314,9 @@ class ProcessTestCase(BaseTestCase):
# Verify first that the call succeeds without the executable arg. # Verify first that the call succeeds without the executable arg.
pre_args = [sys.executable, "-c"] pre_args = [sys.executable, "-c"]
self._assert_python(pre_args) self._assert_python(pre_args)
self.assertRaises((FileNotFoundError, PermissionError), self.assertRaises(NONEXISTING_ERRORS,
self._assert_python, pre_args, self._assert_python, pre_args,
executable="doesnotexist") executable=NONEXISTING_CMD[0])
@unittest.skipIf(mswindows, "executable argument replaces shell") @unittest.skipIf(mswindows, "executable argument replaces shell")
def test_executable_replaces_shell(self): def test_executable_replaces_shell(self):
...@@ -350,12 +365,7 @@ class ProcessTestCase(BaseTestCase): ...@@ -350,12 +365,7 @@ class ProcessTestCase(BaseTestCase):
def test_cwd_with_pathlike(self): def test_cwd_with_pathlike(self):
temp_dir = tempfile.gettempdir() temp_dir = tempfile.gettempdir()
temp_dir = self._normalize_cwd(temp_dir) temp_dir = self._normalize_cwd(temp_dir)
self._assert_cwd(temp_dir, sys.executable, cwd=FakePath(temp_dir))
class _PathLikeObj:
def __fspath__(self):
return temp_dir
self._assert_cwd(temp_dir, sys.executable, cwd=_PathLikeObj())
@unittest.skipIf(mswindows, "pending resolution of issue #15533") @unittest.skipIf(mswindows, "pending resolution of issue #15533")
def test_cwd_with_relative_arg(self): def test_cwd_with_relative_arg(self):
...@@ -644,6 +654,46 @@ class ProcessTestCase(BaseTestCase): ...@@ -644,6 +654,46 @@ class ProcessTestCase(BaseTestCase):
# environment # environment
b"['__CF_USER_TEXT_ENCODING']")) b"['__CF_USER_TEXT_ENCODING']"))
def test_invalid_cmd(self):
# null character in the command name
cmd = sys.executable + '\0'
with self.assertRaises(ValueError):
subprocess.Popen([cmd, "-c", "pass"])
# null character in the command argument
with self.assertRaises(ValueError):
subprocess.Popen([sys.executable, "-c", "pass#\0"])
def test_invalid_env(self):
# null character in the enviroment variable name
newenv = os.environ.copy()
newenv["FRUIT\0VEGETABLE"] = "cabbage"
with self.assertRaises(ValueError):
subprocess.Popen([sys.executable, "-c", "pass"], env=newenv)
# null character in the enviroment variable value
newenv = os.environ.copy()
newenv["FRUIT"] = "orange\0VEGETABLE=cabbage"
with self.assertRaises(ValueError):
subprocess.Popen([sys.executable, "-c", "pass"], env=newenv)
# equal character in the enviroment variable name
newenv = os.environ.copy()
newenv["FRUIT=ORANGE"] = "lemon"
with self.assertRaises(ValueError):
subprocess.Popen([sys.executable, "-c", "pass"], env=newenv)
# equal character in the enviroment variable value
newenv = os.environ.copy()
newenv["FRUIT"] = "orange=lemon"
with subprocess.Popen([sys.executable, "-c",
'import sys, os;'
'sys.stdout.write(os.getenv("FRUIT"))'],
stdout=subprocess.PIPE,
env=newenv) as p:
stdout, stderr = p.communicate()
self.assertEqual(stdout, b"orange=lemon")
def test_communicate_stdin(self): def test_communicate_stdin(self):
p = subprocess.Popen([sys.executable, "-c", p = subprocess.Popen([sys.executable, "-c",
'import sys;' 'import sys;'
...@@ -1071,10 +1121,11 @@ class ProcessTestCase(BaseTestCase): ...@@ -1071,10 +1121,11 @@ class ProcessTestCase(BaseTestCase):
p.stdin.write(line) # expect that it flushes the line in text mode p.stdin.write(line) # expect that it flushes the line in text mode
os.close(p.stdin.fileno()) # close it without flushing the buffer os.close(p.stdin.fileno()) # close it without flushing the buffer
read_line = p.stdout.readline() read_line = p.stdout.readline()
try: with support.SuppressCrashReport():
p.stdin.close() try:
except OSError: p.stdin.close()
pass except OSError:
pass
p.stdin = None p.stdin = None
self.assertEqual(p.returncode, 0) self.assertEqual(p.returncode, 0)
self.assertEqual(read_line, expected) self.assertEqual(read_line, expected)
...@@ -1098,13 +1149,51 @@ class ProcessTestCase(BaseTestCase): ...@@ -1098,13 +1149,51 @@ class ProcessTestCase(BaseTestCase):
# value for that limit, but Windows has 2048, so we loop # value for that limit, but Windows has 2048, so we loop
# 1024 times (each call leaked two fds). # 1024 times (each call leaked two fds).
for i in range(1024): for i in range(1024):
with self.assertRaises(OSError) as c: with self.assertRaises(NONEXISTING_ERRORS):
subprocess.Popen(['nonexisting_i_hope'], subprocess.Popen(NONEXISTING_CMD,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE) stderr=subprocess.PIPE)
# ignore errors that indicate the command was not found
if c.exception.errno not in (errno.ENOENT, errno.EACCES): def test_nonexisting_with_pipes(self):
raise c.exception # bpo-30121: Popen with pipes must close properly pipes on error.
# Previously, os.close() was called with a Windows handle which is not
# a valid file descriptor.
#
# Run the test in a subprocess to control how the CRT reports errors
# and to get stderr content.
try:
import msvcrt
msvcrt.CrtSetReportMode
except (AttributeError, ImportError):
self.skipTest("need msvcrt.CrtSetReportMode")
code = textwrap.dedent(f"""
import msvcrt
import subprocess
cmd = {NONEXISTING_CMD!r}
for report_type in [msvcrt.CRT_WARN,
msvcrt.CRT_ERROR,
msvcrt.CRT_ASSERT]:
msvcrt.CrtSetReportMode(report_type, msvcrt.CRTDBG_MODE_FILE)
msvcrt.CrtSetReportFile(report_type, msvcrt.CRTDBG_FILE_STDERR)
try:
subprocess.Popen(cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
except OSError:
pass
""")
cmd = [sys.executable, "-c", code]
proc = subprocess.Popen(cmd,
stderr=subprocess.PIPE,
universal_newlines=True)
with proc:
stderr = proc.communicate()[1]
self.assertEqual(stderr, "")
self.assertEqual(proc.returncode, 0)
@unittest.skipIf(threading is None, "threading required") @unittest.skipIf(threading is None, "threading required")
def test_double_close_on_error(self): def test_double_close_on_error(self):
...@@ -1118,7 +1207,7 @@ class ProcessTestCase(BaseTestCase): ...@@ -1118,7 +1207,7 @@ class ProcessTestCase(BaseTestCase):
t.start() t.start()
try: try:
with self.assertRaises(EnvironmentError): with self.assertRaises(EnvironmentError):
subprocess.Popen(['nonexisting_i_hope'], subprocess.Popen(NONEXISTING_CMD,
stdin=subprocess.PIPE, stdin=subprocess.PIPE,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE) stderr=subprocess.PIPE)
...@@ -1282,6 +1371,18 @@ class ProcessTestCase(BaseTestCase): ...@@ -1282,6 +1371,18 @@ class ProcessTestCase(BaseTestCase):
fds_after_exception = os.listdir(fd_directory) fds_after_exception = os.listdir(fd_directory)
self.assertEqual(fds_before_popen, fds_after_exception) self.assertEqual(fds_before_popen, fds_after_exception)
@unittest.skipIf(mswindows, "behavior currently not supported on Windows")
def test_file_not_found_includes_filename(self):
with self.assertRaises(FileNotFoundError) as c:
subprocess.call(['/opt/nonexistent_binary', 'with', 'some', 'args'])
self.assertEqual(c.exception.filename, '/opt/nonexistent_binary')
@unittest.skipIf(mswindows, "behavior currently not supported on Windows")
def test_file_not_found_with_bad_cwd(self):
with self.assertRaises(FileNotFoundError) as c:
subprocess.Popen(['exit', '0'], cwd='/some/nonexistent/directory')
self.assertEqual(c.exception.filename, '/some/nonexistent/directory')
class RunFuncTestCase(BaseTestCase): class RunFuncTestCase(BaseTestCase):
def run_python(self, code, **kwargs): def run_python(self, code, **kwargs):
...@@ -1440,6 +1541,57 @@ class POSIXProcessTestCase(BaseTestCase): ...@@ -1440,6 +1541,57 @@ class POSIXProcessTestCase(BaseTestCase):
else: else:
self.fail("Expected OSError: %s" % desired_exception) self.fail("Expected OSError: %s" % desired_exception)
# We mock the __del__ method for Popen in the next two tests
# because it does cleanup based on the pid returned by fork_exec
# along with issuing a resource warning if it still exists. Since
# we don't actually spawn a process in these tests we can forego
# the destructor. An alternative would be to set _child_created to
# False before the destructor is called but there is no easy way
# to do that
class PopenNoDestructor(subprocess.Popen):
def __del__(self):
pass
@mock.patch("subprocess._posixsubprocess.fork_exec")
def test_exception_errpipe_normal(self, fork_exec):
"""Test error passing done through errpipe_write in the good case"""
def proper_error(*args):
errpipe_write = args[13]
# Write the hex for the error code EISDIR: 'is a directory'
err_code = '{:x}'.format(errno.EISDIR).encode()
os.write(errpipe_write, b"OSError:" + err_code + b":")
return 0
fork_exec.side_effect = proper_error
with mock.patch("subprocess.os.waitpid",
side_effect=ChildProcessError):
with self.assertRaises(IsADirectoryError):
self.PopenNoDestructor(["non_existent_command"])
@mock.patch("subprocess._posixsubprocess.fork_exec")
def test_exception_errpipe_bad_data(self, fork_exec):
"""Test error passing done through errpipe_write where its not
in the expected format"""
error_data = b"\xFF\x00\xDE\xAD"
def bad_error(*args):
errpipe_write = args[13]
# Anything can be in the pipe, no assumptions should
# be made about its encoding, so we'll write some
# arbitrary hex bytes to test it out
os.write(errpipe_write, error_data)
return 0
fork_exec.side_effect = bad_error
with mock.patch("subprocess.os.waitpid",
side_effect=ChildProcessError):
with self.assertRaises(subprocess.SubprocessError) as e:
self.PopenNoDestructor(["non_existent_command"])
self.assertIn(repr(error_data), str(e.exception))
def test_restore_signals(self): def test_restore_signals(self):
# Code coverage for both values of restore_signals to make sure it # Code coverage for both values of restore_signals to make sure it
# at least does not blow up. # at least does not blow up.
...@@ -1958,6 +2110,55 @@ class POSIXProcessTestCase(BaseTestCase): ...@@ -1958,6 +2110,55 @@ class POSIXProcessTestCase(BaseTestCase):
self.check_swap_fds(2, 0, 1) self.check_swap_fds(2, 0, 1)
self.check_swap_fds(2, 1, 0) self.check_swap_fds(2, 1, 0)
def _check_swap_std_fds_with_one_closed(self, from_fds, to_fds):
saved_fds = self._save_fds(range(3))
try:
for from_fd in from_fds:
with tempfile.TemporaryFile() as f:
os.dup2(f.fileno(), from_fd)
fd_to_close = (set(range(3)) - set(from_fds)).pop()
os.close(fd_to_close)
arg_names = ['stdin', 'stdout', 'stderr']
kwargs = {}
for from_fd, to_fd in zip(from_fds, to_fds):
kwargs[arg_names[to_fd]] = from_fd
code = textwrap.dedent(r'''
import os, sys
skipped_fd = int(sys.argv[1])
for fd in range(3):
if fd != skipped_fd:
os.write(fd, str(fd).encode('ascii'))
''')
skipped_fd = (set(range(3)) - set(to_fds)).pop()
rc = subprocess.call([sys.executable, '-c', code, str(skipped_fd)],
**kwargs)
self.assertEqual(rc, 0)
for from_fd, to_fd in zip(from_fds, to_fds):
os.lseek(from_fd, 0, os.SEEK_SET)
read_bytes = os.read(from_fd, 1024)
read_fds = list(map(int, read_bytes.decode('ascii')))
msg = textwrap.dedent(f"""
When testing {from_fds} to {to_fds} redirection,
parent descriptor {from_fd} got redirected
to descriptor(s) {read_fds} instead of descriptor {to_fd}.
""")
self.assertEqual([to_fd], read_fds, msg)
finally:
self._restore_fds(saved_fds)
# Check that subprocess can remap std fds correctly even
# if one of them is closed (#32844).
def test_swap_std_fds_with_one_closed(self):
for from_fds in itertools.combinations(range(3), 2):
for to_fds in itertools.permutations(range(3), 2):
self._check_swap_std_fds_with_one_closed(from_fds, to_fds)
def test_surrogates_error_message(self): def test_surrogates_error_message(self):
def prepare(): def prepare():
raise ValueError("surrogate:\uDCff") raise ValueError("surrogate:\uDCff")
...@@ -2293,6 +2494,36 @@ class POSIXProcessTestCase(BaseTestCase): ...@@ -2293,6 +2494,36 @@ class POSIXProcessTestCase(BaseTestCase):
self.assertEqual(os.get_inheritable(inheritable), True) self.assertEqual(os.get_inheritable(inheritable), True)
self.assertEqual(os.get_inheritable(non_inheritable), False) self.assertEqual(os.get_inheritable(non_inheritable), False)
# bpo-32270: Ensure that descriptors specified in pass_fds
# are inherited even if they are used in redirections.
# Contributed by @izbyshev.
def test_pass_fds_redirected(self):
"""Regression test for https://bugs.python.org/issue32270."""
fd_status = support.findfile("fd_status.py", subdir="subprocessdata")
pass_fds = []
for _ in range(2):
fd = os.open(os.devnull, os.O_RDWR)
self.addCleanup(os.close, fd)
pass_fds.append(fd)
stdout_r, stdout_w = os.pipe()
self.addCleanup(os.close, stdout_r)
self.addCleanup(os.close, stdout_w)
pass_fds.insert(1, stdout_w)
with subprocess.Popen([sys.executable, fd_status],
stdin=pass_fds[0],
stdout=pass_fds[1],
stderr=pass_fds[2],
close_fds=True,
pass_fds=pass_fds):
output = os.read(stdout_r, 1024)
fds = {int(num) for num in output.split(b',')}
self.assertEqual(fds, {0, 1, 2} | frozenset(pass_fds), f"output={output!a}")
def test_stdout_stdin_are_single_inout_fd(self): def test_stdout_stdin_are_single_inout_fd(self):
with io.open(os.devnull, "r+") as inout: with io.open(os.devnull, "r+") as inout:
p = subprocess.Popen([sys.executable, "-c", "import sys; sys.exit(0)"], p = subprocess.Popen([sys.executable, "-c", "import sys; sys.exit(0)"],
...@@ -2386,8 +2617,8 @@ class POSIXProcessTestCase(BaseTestCase): ...@@ -2386,8 +2617,8 @@ class POSIXProcessTestCase(BaseTestCase):
# let some time for the process to exit, and create a new Popen: this # let some time for the process to exit, and create a new Popen: this
# should trigger the wait() of p # should trigger the wait() of p
time.sleep(0.2) time.sleep(0.2)
with self.assertRaises(OSError) as c: with self.assertRaises(OSError):
with subprocess.Popen(['nonexisting_i_hope'], with subprocess.Popen(NONEXISTING_CMD,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE) as proc: stderr=subprocess.PIPE) as proc:
pass pass
...@@ -2434,7 +2665,7 @@ class POSIXProcessTestCase(BaseTestCase): ...@@ -2434,7 +2665,7 @@ class POSIXProcessTestCase(BaseTestCase):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
_posixsubprocess.fork_exec( _posixsubprocess.fork_exec(
args, exe_list, args, exe_list,
True, [], cwd, env_list, True, (), cwd, env_list,
-1, -1, -1, -1, -1, -1, -1, -1,
1, 2, 3, 4, 1, 2, 3, 4,
True, True, func) True, True, func)
...@@ -2446,6 +2677,16 @@ class POSIXProcessTestCase(BaseTestCase): ...@@ -2446,6 +2677,16 @@ class POSIXProcessTestCase(BaseTestCase):
def test_fork_exec_sorted_fd_sanity_check(self): def test_fork_exec_sorted_fd_sanity_check(self):
# Issue #23564: sanity check the fork_exec() fds_to_keep sanity check. # Issue #23564: sanity check the fork_exec() fds_to_keep sanity check.
import _posixsubprocess import _posixsubprocess
class BadInt:
first = True
def __init__(self, value):
self.value = value
def __int__(self):
if self.first:
self.first = False
return self.value
raise ValueError
gc_enabled = gc.isenabled() gc_enabled = gc.isenabled()
try: try:
gc.enable() gc.enable()
...@@ -2456,6 +2697,7 @@ class POSIXProcessTestCase(BaseTestCase): ...@@ -2456,6 +2697,7 @@ class POSIXProcessTestCase(BaseTestCase):
(18, 23, 42, 2**63), # Out of range. (18, 23, 42, 2**63), # Out of range.
(5, 4), # Not sorted. (5, 4), # Not sorted.
(6, 7, 7, 8), # Duplicate. (6, 7, 7, 8), # Duplicate.
(BadInt(1), BadInt(2)),
): ):
with self.assertRaises( with self.assertRaises(
ValueError, ValueError,
...@@ -2517,45 +2759,24 @@ class POSIXProcessTestCase(BaseTestCase): ...@@ -2517,45 +2759,24 @@ class POSIXProcessTestCase(BaseTestCase):
proc.communicate(timeout=999) proc.communicate(timeout=999)
mock_proc_stdin.close.assert_called_once_with() mock_proc_stdin.close.assert_called_once_with()
_libc_file_extensions = { @unittest.skipUnless(_testcapi is not None
'Linux': 'so.6', and hasattr(_testcapi, 'W_STOPCODE'),
'Darwin': 'dylib', 'need _testcapi.W_STOPCODE')
} def test_stopped(self):
@unittest.skipIf(not ctypes, 'ctypes module required.')
@unittest.skipIf(platform.uname()[0] not in _libc_file_extensions,
'Test requires a libc this code can load with ctypes.')
@unittest.skipIf(not sys.executable, 'Test requires sys.executable.')
def test_child_terminated_in_stopped_state(self):
"""Test wait() behavior when waitpid returns WIFSTOPPED; issue29335.""" """Test wait() behavior when waitpid returns WIFSTOPPED; issue29335."""
PTRACE_TRACEME = 0 # From glibc and MacOS (PT_TRACE_ME). args = [sys.executable, '-c', 'pass']
libc_name = 'libc.' + self._libc_file_extensions[platform.uname()[0]] proc = subprocess.Popen(args)
libc = ctypes.CDLL(libc_name)
if not hasattr(libc, 'ptrace'): # Wait until the real process completes to avoid zombie process
raise unittest.SkipTest('ptrace() required.') pid = proc.pid
test_ptrace = subprocess.Popen( pid, status = os.waitpid(pid, 0)
[sys.executable, '-c', """if True: self.assertEqual(status, 0)
import ctypes
libc = ctypes.CDLL({libc_name!r}) status = _testcapi.W_STOPCODE(3)
libc.ptrace({PTRACE_TRACEME}, 0, 0) with mock.patch('subprocess.os.waitpid', return_value=(pid, status)):
""".format(libc_name=libc_name, PTRACE_TRACEME=PTRACE_TRACEME) returncode = proc.wait()
])
if test_ptrace.wait() != 0: self.assertEqual(returncode, -3)
raise unittest.SkipTest('ptrace() failed - unable to test.')
child = subprocess.Popen(
[sys.executable, '-c', """if True:
import ctypes
libc = ctypes.CDLL({libc_name!r})
libc.ptrace({PTRACE_TRACEME}, 0, 0)
libc.printf(ctypes.c_char_p(0xdeadbeef)) # Crash the process.
""".format(libc_name=libc_name, PTRACE_TRACEME=PTRACE_TRACEME)
])
try:
returncode = child.wait()
except Exception as e:
child.kill() # Clean up the hung stopped process.
raise e
self.assertNotEqual(0, returncode)
self.assertLess(returncode, 0) # signal death, likely SIGSEGV.
@unittest.skipUnless(mswindows, "Windows specific tests") @unittest.skipUnless(mswindows, "Windows specific tests")
...@@ -2596,6 +2817,15 @@ class Win32ProcessTestCase(BaseTestCase): ...@@ -2596,6 +2817,15 @@ class Win32ProcessTestCase(BaseTestCase):
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
close_fds=True) close_fds=True)
@support.cpython_only
def test_issue31471(self):
# There shouldn't be an assertion failure in Popen() in case the env
# argument has a bad keys() method.
class BadEnv(dict):
keys = None
with self.assertRaises(TypeError):
subprocess.Popen([sys.executable, "-c", "pass"], env=BadEnv())
def test_close_fds(self): def test_close_fds(self):
# close file descriptors # close file descriptors
rc = subprocess.call([sys.executable, "-c", rc = subprocess.call([sys.executable, "-c",
...@@ -2826,8 +3056,8 @@ class ContextManagerTests(BaseTestCase): ...@@ -2826,8 +3056,8 @@ class ContextManagerTests(BaseTestCase):
self.assertEqual(proc.returncode, 1) self.assertEqual(proc.returncode, 1)
def test_invalid_args(self): def test_invalid_args(self):
with self.assertRaises((FileNotFoundError, PermissionError)) as c: with self.assertRaises(NONEXISTING_ERRORS):
with subprocess.Popen(['nonexisting_i_hope'], with subprocess.Popen(NONEXISTING_CMD,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE) as proc: stderr=subprocess.PIPE) as proc:
pass pass
......
...@@ -398,5 +398,4 @@ class ExpectTests(ExpectAndReadTestCase): ...@@ -398,5 +398,4 @@ class ExpectTests(ExpectAndReadTestCase):
if __name__ == '__main__': if __name__ == '__main__':
import unittest
unittest.main() unittest.main()
...@@ -11,6 +11,7 @@ from test import lock_tests ...@@ -11,6 +11,7 @@ from test import lock_tests
NUMTASKS = 10 NUMTASKS = 10
NUMTRIPS = 3 NUMTRIPS = 3
POLL_SLEEP = 0.010 # seconds = 10 ms
_print_mutex = thread.allocate_lock() _print_mutex = thread.allocate_lock()
...@@ -20,6 +21,7 @@ def verbose_print(arg): ...@@ -20,6 +21,7 @@ def verbose_print(arg):
with _print_mutex: with _print_mutex:
print(arg) print(arg)
class BasicThreadTest(unittest.TestCase): class BasicThreadTest(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -31,6 +33,9 @@ class BasicThreadTest(unittest.TestCase): ...@@ -31,6 +33,9 @@ class BasicThreadTest(unittest.TestCase):
self.running = 0 self.running = 0
self.next_ident = 0 self.next_ident = 0
key = support.threading_setup()
self.addCleanup(support.threading_cleanup, *key)
class ThreadRunningTests(BasicThreadTest): class ThreadRunningTests(BasicThreadTest):
...@@ -54,12 +59,13 @@ class ThreadRunningTests(BasicThreadTest): ...@@ -54,12 +59,13 @@ class ThreadRunningTests(BasicThreadTest):
self.done_mutex.release() self.done_mutex.release()
def test_starting_threads(self): def test_starting_threads(self):
# Basic test for thread creation. with support.wait_threads_exit():
for i in range(NUMTASKS): # Basic test for thread creation.
self.newtask() for i in range(NUMTASKS):
verbose_print("waiting for tasks to complete...") self.newtask()
self.done_mutex.acquire() verbose_print("waiting for tasks to complete...")
verbose_print("all tasks done") self.done_mutex.acquire()
verbose_print("all tasks done")
def test_stack_size(self): def test_stack_size(self):
# Various stack size tests. # Various stack size tests.
...@@ -89,12 +95,13 @@ class ThreadRunningTests(BasicThreadTest): ...@@ -89,12 +95,13 @@ class ThreadRunningTests(BasicThreadTest):
verbose_print("trying stack_size = (%d)" % tss) verbose_print("trying stack_size = (%d)" % tss)
self.next_ident = 0 self.next_ident = 0
self.created = 0 self.created = 0
for i in range(NUMTASKS): with support.wait_threads_exit():
self.newtask() for i in range(NUMTASKS):
self.newtask()
verbose_print("waiting for all tasks to complete") verbose_print("waiting for all tasks to complete")
self.done_mutex.acquire() self.done_mutex.acquire()
verbose_print("all tasks done") verbose_print("all tasks done")
thread.stack_size(0) thread.stack_size(0)
...@@ -104,26 +111,29 @@ class ThreadRunningTests(BasicThreadTest): ...@@ -104,26 +111,29 @@ class ThreadRunningTests(BasicThreadTest):
mut = thread.allocate_lock() mut = thread.allocate_lock()
mut.acquire() mut.acquire()
started = [] started = []
def task(): def task():
started.append(None) started.append(None)
mut.acquire() mut.acquire()
mut.release() mut.release()
thread.start_new_thread(task, ())
while not started: with support.wait_threads_exit():
time.sleep(0.01) thread.start_new_thread(task, ())
self.assertEqual(thread._count(), orig + 1) while not started:
# Allow the task to finish. time.sleep(POLL_SLEEP)
mut.release() self.assertEqual(thread._count(), orig + 1)
# The only reliable way to be sure that the thread ended from the # Allow the task to finish.
# interpreter's point of view is to wait for the function object to be mut.release()
# destroyed. # The only reliable way to be sure that the thread ended from the
done = [] # interpreter's point of view is to wait for the function object to be
wr = weakref.ref(task, lambda _: done.append(None)) # destroyed.
del task done = []
while not done: wr = weakref.ref(task, lambda _: done.append(None))
time.sleep(0.01) del task
support.gc_collect() while not done:
self.assertEqual(thread._count(), orig) time.sleep(POLL_SLEEP)
support.gc_collect()
self.assertEqual(thread._count(), orig)
def test_save_exception_state_on_error(self): def test_save_exception_state_on_error(self):
# See issue #14474 # See issue #14474
...@@ -136,16 +146,14 @@ class ThreadRunningTests(BasicThreadTest): ...@@ -136,16 +146,14 @@ class ThreadRunningTests(BasicThreadTest):
except ValueError: except ValueError:
pass pass
real_write(self, *args) real_write(self, *args)
c = thread._count()
started = thread.allocate_lock() started = thread.allocate_lock()
with support.captured_output("stderr") as stderr: with support.captured_output("stderr") as stderr:
real_write = stderr.write real_write = stderr.write
stderr.write = mywrite stderr.write = mywrite
started.acquire() started.acquire()
thread.start_new_thread(task, ()) with support.wait_threads_exit():
started.acquire() thread.start_new_thread(task, ())
while thread._count() > c: started.acquire()
time.sleep(0.01)
self.assertIn("Traceback", stderr.getvalue()) self.assertIn("Traceback", stderr.getvalue())
...@@ -177,13 +185,14 @@ class Barrier: ...@@ -177,13 +185,14 @@ class Barrier:
class BarrierTest(BasicThreadTest): class BarrierTest(BasicThreadTest):
def test_barrier(self): def test_barrier(self):
self.bar = Barrier(NUMTASKS) with support.wait_threads_exit():
self.running = NUMTASKS self.bar = Barrier(NUMTASKS)
for i in range(NUMTASKS): self.running = NUMTASKS
thread.start_new_thread(self.task2, (i,)) for i in range(NUMTASKS):
verbose_print("waiting for tasks to end") thread.start_new_thread(self.task2, (i,))
self.done_mutex.acquire() verbose_print("waiting for tasks to end")
verbose_print("tasks done") self.done_mutex.acquire()
verbose_print("tasks done")
def task2(self, ident): def task2(self, ident):
for i in range(NUMTRIPS): for i in range(NUMTRIPS):
...@@ -225,28 +234,33 @@ class TestForkInThread(unittest.TestCase): ...@@ -225,28 +234,33 @@ class TestForkInThread(unittest.TestCase):
def setUp(self): def setUp(self):
self.read_fd, self.write_fd = os.pipe() self.read_fd, self.write_fd = os.pipe()
@unittest.skipIf(sys.platform.startswith('win'), @unittest.skipUnless(hasattr(os, 'fork'), 'need os.fork')
"This test is only appropriate for POSIX-like systems.")
@support.reap_threads @support.reap_threads
def test_forkinthread(self): def test_forkinthread(self):
status = "not set"
def thread1(): def thread1():
try: nonlocal status
pid = os.fork() # fork in a thread
except RuntimeError:
os._exit(1) # exit the child
if pid == 0: # child # fork in a thread
pid = os.fork()
if pid == 0:
# child
try: try:
os.close(self.read_fd) os.close(self.read_fd)
os.write(self.write_fd, b"OK") os.write(self.write_fd, b"OK")
finally: finally:
os._exit(0) os._exit(0)
else: # parent else:
# parent
os.close(self.write_fd) os.close(self.write_fd)
pid, status = os.waitpid(pid, 0)
thread.start_new_thread(thread1, ()) with support.wait_threads_exit():
self.assertEqual(os.read(self.read_fd, 2), b"OK", thread.start_new_thread(thread1, ())
"Unable to fork() in thread") self.assertEqual(os.read(self.read_fd, 2), b"OK",
"Unable to fork() in thread")
self.assertEqual(status, 0)
def tearDown(self): def tearDown(self):
try: try:
......
...@@ -125,9 +125,10 @@ class ThreadTests(BaseTestCase): ...@@ -125,9 +125,10 @@ class ThreadTests(BaseTestCase):
done.set() done.set()
done = threading.Event() done = threading.Event()
ident = [] ident = []
_thread.start_new_thread(f, ()) with support.wait_threads_exit():
done.wait() tid = _thread.start_new_thread(f, ())
self.assertIsNotNone(ident[0]) done.wait()
self.assertEqual(ident[0], tid)
# Kill the "immortal" _DummyThread # Kill the "immortal" _DummyThread
del threading._active[ident[0]] del threading._active[ident[0]]
...@@ -165,9 +166,10 @@ class ThreadTests(BaseTestCase): ...@@ -165,9 +166,10 @@ class ThreadTests(BaseTestCase):
mutex = threading.Lock() mutex = threading.Lock()
mutex.acquire() mutex.acquire()
tid = _thread.start_new_thread(f, (mutex,)) with support.wait_threads_exit():
# Wait for the thread to finish. tid = _thread.start_new_thread(f, (mutex,))
mutex.acquire() # Wait for the thread to finish.
mutex.acquire()
self.assertIn(tid, threading._active) self.assertIn(tid, threading._active)
self.assertIsInstance(threading._active[tid], threading._DummyThread) self.assertIsInstance(threading._active[tid], threading._DummyThread)
#Issue 29376 #Issue 29376
...@@ -483,13 +485,15 @@ class ThreadTests(BaseTestCase): ...@@ -483,13 +485,15 @@ class ThreadTests(BaseTestCase):
for i in range(20): for i in range(20):
t = threading.Thread(target=lambda: None) t = threading.Thread(target=lambda: None)
t.start() t.start()
self.addCleanup(t.join)
pid = os.fork() pid = os.fork()
if pid == 0: if pid == 0:
os._exit(1 if t.is_alive() else 0) os._exit(11 if t.is_alive() else 10)
else: else:
t.join()
pid, status = os.waitpid(pid, 0) pid, status = os.waitpid(pid, 0)
self.assertEqual(0, status) self.assertTrue(os.WIFEXITED(status))
self.assertEqual(10, os.WEXITSTATUS(status))
def test_main_thread(self): def test_main_thread(self):
main = threading.main_thread() main = threading.main_thread()
...@@ -553,6 +557,37 @@ class ThreadTests(BaseTestCase): ...@@ -553,6 +557,37 @@ class ThreadTests(BaseTestCase):
self.assertEqual(err, b"") self.assertEqual(err, b"")
self.assertEqual(data, "Thread-1\nTrue\nTrue\n") self.assertEqual(data, "Thread-1\nTrue\nTrue\n")
@test.support.cpython_only
@requires_type_collecting
def test_main_thread_during_shutdown(self):
# bpo-31516: current_thread() should still point to the main thread
# at shutdown
code = """if 1:
import gc, threading
main_thread = threading.current_thread()
assert main_thread is threading.main_thread() # sanity check
class RefCycle:
def __init__(self):
self.cycle = self
def __del__(self):
print("GC:",
threading.current_thread() is main_thread,
threading.main_thread() is main_thread,
threading.enumerate() == [main_thread])
RefCycle()
gc.collect() # sanity check
x = RefCycle()
"""
_, out, err = assert_python_ok("-c", code)
data = out.decode()
self.assertEqual(err, b"")
self.assertEqual(data.splitlines(),
["GC: True True True"] * 2)
def test_tstate_lock(self): def test_tstate_lock(self):
# Test an implementation detail of Thread objects. # Test an implementation detail of Thread objects.
started = _thread.allocate_lock() started = _thread.allocate_lock()
...@@ -586,6 +621,7 @@ class ThreadTests(BaseTestCase): ...@@ -586,6 +621,7 @@ class ThreadTests(BaseTestCase):
self.assertFalse(t.is_alive()) self.assertFalse(t.is_alive())
# And verify the thread disposed of _tstate_lock. # And verify the thread disposed of _tstate_lock.
self.assertIsNone(t._tstate_lock) self.assertIsNone(t._tstate_lock)
t.join()
def test_repr_stopped(self): def test_repr_stopped(self):
# Verify that "stopped" shows up in repr(Thread) appropriately. # Verify that "stopped" shows up in repr(Thread) appropriately.
...@@ -612,6 +648,7 @@ class ThreadTests(BaseTestCase): ...@@ -612,6 +648,7 @@ class ThreadTests(BaseTestCase):
break break
time.sleep(0.01) time.sleep(0.01)
self.assertIn(LOOKING_FOR, repr(t)) # we waited at least 5 seconds self.assertIn(LOOKING_FOR, repr(t)) # we waited at least 5 seconds
t.join()
def test_BoundedSemaphore_limit(self): def test_BoundedSemaphore_limit(self):
# BoundedSemaphore should raise ValueError if released too often. # BoundedSemaphore should raise ValueError if released too often.
...@@ -928,6 +965,7 @@ class ThreadingExceptionTests(BaseTestCase): ...@@ -928,6 +965,7 @@ class ThreadingExceptionTests(BaseTestCase):
thread = threading.Thread() thread = threading.Thread()
thread.start() thread.start()
self.assertRaises(RuntimeError, thread.start) self.assertRaises(RuntimeError, thread.start)
thread.join()
def test_joining_current_thread(self): def test_joining_current_thread(self):
current_thread = threading.current_thread() current_thread = threading.current_thread()
...@@ -941,6 +979,7 @@ class ThreadingExceptionTests(BaseTestCase): ...@@ -941,6 +979,7 @@ class ThreadingExceptionTests(BaseTestCase):
thread = threading.Thread() thread = threading.Thread()
thread.start() thread.start()
self.assertRaises(RuntimeError, setattr, thread, "daemon", True) self.assertRaises(RuntimeError, setattr, thread, "daemon", True)
thread.join()
def test_releasing_unacquired_lock(self): def test_releasing_unacquired_lock(self):
lock = threading.Lock() lock = threading.Lock()
...@@ -1079,6 +1118,8 @@ class ThreadingExceptionTests(BaseTestCase): ...@@ -1079,6 +1118,8 @@ class ThreadingExceptionTests(BaseTestCase):
thread.join() thread.join()
self.assertIsNotNone(thread.exc) self.assertIsNotNone(thread.exc)
self.assertIsInstance(thread.exc, RuntimeError) self.assertIsInstance(thread.exc, RuntimeError)
# explicitly break the reference cycle to not leak a dangling thread
thread.exc = None
class TimerTests(BaseTestCase): class TimerTests(BaseTestCase):
...@@ -1101,6 +1142,8 @@ class TimerTests(BaseTestCase): ...@@ -1101,6 +1142,8 @@ class TimerTests(BaseTestCase):
self.callback_event.wait() self.callback_event.wait()
self.assertEqual(len(self.callback_args), 2) self.assertEqual(len(self.callback_args), 2)
self.assertEqual(self.callback_args, [((), {}), ((), {})]) self.assertEqual(self.callback_args, [((), {}), ((), {})])
timer1.join()
timer2.join()
def _callback_spy(self, *args, **kwargs): def _callback_spy(self, *args, **kwargs):
self.callback_args.append((args[:], kwargs.copy())) self.callback_args.append((args[:], kwargs.copy()))
...@@ -1127,10 +1170,6 @@ class CRLockTests(lock_tests.RLockTests): ...@@ -1127,10 +1170,6 @@ class CRLockTests(lock_tests.RLockTests):
class EventTests(lock_tests.EventTests): class EventTests(lock_tests.EventTests):
eventtype = staticmethod(threading.Event) eventtype = staticmethod(threading.Event)
@unittest.skip("not on gevent")
def test_reset_internal_locks(self):
pass
class ConditionAsRLockTests(lock_tests.RLockTests): class ConditionAsRLockTests(lock_tests.RLockTests):
# Condition uses an RLock by default and exports its API. # Condition uses an RLock by default and exports its API.
locktype = staticmethod(threading.Condition) locktype = staticmethod(threading.Condition)
......
...@@ -16,6 +16,7 @@ except ImportError: ...@@ -16,6 +16,7 @@ except ImportError:
ssl = None ssl = None
import sys import sys
import tempfile import tempfile
import warnings
from nturl2path import url2pathname, pathname2url from nturl2path import url2pathname, pathname2url
from base64 import b64encode from base64 import b64encode
...@@ -206,6 +207,7 @@ class urlopen_FileTests(unittest.TestCase): ...@@ -206,6 +207,7 @@ class urlopen_FileTests(unittest.TestCase):
def test_relativelocalfile(self): def test_relativelocalfile(self):
self.assertRaises(ValueError,urllib.request.urlopen,'./' + self.pathname) self.assertRaises(ValueError,urllib.request.urlopen,'./' + self.pathname)
class ProxyTests(unittest.TestCase): class ProxyTests(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -259,6 +261,7 @@ class ProxyTests(unittest.TestCase): ...@@ -259,6 +261,7 @@ class ProxyTests(unittest.TestCase):
self.assertFalse(bypass('newdomain.com')) # no port self.assertFalse(bypass('newdomain.com')) # no port
self.assertFalse(bypass('newdomain.com:1235')) # wrong port self.assertFalse(bypass('newdomain.com:1235')) # wrong port
class ProxyTests_withOrderedEnv(unittest.TestCase): class ProxyTests_withOrderedEnv(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -294,6 +297,7 @@ class ProxyTests_withOrderedEnv(unittest.TestCase): ...@@ -294,6 +297,7 @@ class ProxyTests_withOrderedEnv(unittest.TestCase):
proxies = urllib.request.getproxies_environment() proxies = urllib.request.getproxies_environment()
self.assertEqual('http://somewhere:3128', proxies['http']) self.assertEqual('http://somewhere:3128', proxies['http'])
class urlopen_HttpTests(unittest.TestCase, FakeHTTPMixin, FakeFTPMixin): class urlopen_HttpTests(unittest.TestCase, FakeHTTPMixin, FakeFTPMixin):
"""Test urlopen() opening a fake http connection.""" """Test urlopen() opening a fake http connection."""
...@@ -326,6 +330,59 @@ class urlopen_HttpTests(unittest.TestCase, FakeHTTPMixin, FakeFTPMixin): ...@@ -326,6 +330,59 @@ class urlopen_HttpTests(unittest.TestCase, FakeHTTPMixin, FakeFTPMixin):
finally: finally:
self.unfakehttp() self.unfakehttp()
@unittest.skipUnless(ssl, "ssl module required")
def test_url_with_control_char_rejected(self):
for char_no in list(range(0, 0x21)) + [0x7f]:
char = chr(char_no)
schemeless_url = f"//localhost:7777/test{char}/"
self.fakehttp(b"HTTP/1.1 200 OK\r\n\r\nHello.")
try:
# We explicitly test urllib.request.urlopen() instead of the top
# level 'def urlopen()' function defined in this... (quite ugly)
# test suite. They use different url opening codepaths. Plain
# urlopen uses FancyURLOpener which goes via a codepath that
# calls urllib.parse.quote() on the URL which makes all of the
# above attempts at injection within the url _path_ safe.
escaped_char_repr = repr(char).replace('\\', r'\\')
InvalidURL = http.client.InvalidURL
with self.assertRaisesRegex(
InvalidURL, f"contain control.*{escaped_char_repr}"):
urllib.request.urlopen(f"http:{schemeless_url}")
with self.assertRaisesRegex(
InvalidURL, f"contain control.*{escaped_char_repr}"):
urllib.request.urlopen(f"https:{schemeless_url}")
# This code path quotes the URL so there is no injection.
resp = urlopen(f"http:{schemeless_url}")
self.assertNotIn(char, resp.geturl())
finally:
self.unfakehttp()
@unittest.skipUnless(ssl, "ssl module required")
def test_url_with_newline_header_injection_rejected(self):
self.fakehttp(b"HTTP/1.1 200 OK\r\n\r\nHello.")
host = "localhost:7777?a=1 HTTP/1.1\r\nX-injected: header\r\nTEST: 123"
schemeless_url = "//" + host + ":8080/test/?test=a"
try:
# We explicitly test urllib.request.urlopen() instead of the top
# level 'def urlopen()' function defined in this... (quite ugly)
# test suite. They use different url opening codepaths. Plain
# urlopen uses FancyURLOpener which goes via a codepath that
# calls urllib.parse.quote() on the URL which makes all of the
# above attempts at injection within the url _path_ safe.
InvalidURL = http.client.InvalidURL
with self.assertRaisesRegex(
InvalidURL, r"contain control.*\\r.*(found at least . .)"):
urllib.request.urlopen(f"http:{schemeless_url}")
with self.assertRaisesRegex(InvalidURL, r"contain control.*\\n"):
urllib.request.urlopen(f"https:{schemeless_url}")
# This code path quotes the URL so there is no injection.
resp = urlopen(f"http:{schemeless_url}")
self.assertNotIn(' ', resp.geturl())
self.assertNotIn('\r', resp.geturl())
self.assertNotIn('\n', resp.geturl())
finally:
self.unfakehttp()
def test_read_0_9(self): def test_read_0_9(self):
# "0.9" response accepted (but not "simple responses" without # "0.9" response accepted (but not "simple responses" without
# a status line) # a status line)
...@@ -432,7 +489,6 @@ Connection: close ...@@ -432,7 +489,6 @@ Connection: close
finally: finally:
self.unfakeftp() self.unfakeftp()
def test_userpass_inurl(self): def test_userpass_inurl(self):
self.fakehttp(b"HTTP/1.0 200 OK\r\n\r\nHello!") self.fakehttp(b"HTTP/1.0 200 OK\r\n\r\nHello!")
try: try:
...@@ -476,6 +532,7 @@ Connection: close ...@@ -476,6 +532,7 @@ Connection: close
"https://localhost", cafile="/nonexistent/path", context=context "https://localhost", cafile="/nonexistent/path", context=context
) )
class urlopen_DataTests(unittest.TestCase): class urlopen_DataTests(unittest.TestCase):
"""Test urlopen() opening a data URL.""" """Test urlopen() opening a data URL."""
...@@ -549,6 +606,7 @@ class urlopen_DataTests(unittest.TestCase): ...@@ -549,6 +606,7 @@ class urlopen_DataTests(unittest.TestCase):
# missing padding character # missing padding character
self.assertRaises(ValueError,urllib.request.urlopen,'data:;base64,Cg=') self.assertRaises(ValueError,urllib.request.urlopen,'data:;base64,Cg=')
class urlretrieve_FileTests(unittest.TestCase): class urlretrieve_FileTests(unittest.TestCase):
"""Test urllib.urlretrieve() on local files""" """Test urllib.urlretrieve() on local files"""
...@@ -1406,6 +1464,23 @@ class URLopener_Tests(unittest.TestCase): ...@@ -1406,6 +1464,23 @@ class URLopener_Tests(unittest.TestCase):
"spam://c:|windows%/:=&?~#+!$,;'@()*[]|/path/"), "spam://c:|windows%/:=&?~#+!$,;'@()*[]|/path/"),
"//c:|windows%/:=&?~#+!$,;'@()*[]|/path/") "//c:|windows%/:=&?~#+!$,;'@()*[]|/path/")
def test_local_file_open(self):
# bpo-35907, CVE-2019-9948: urllib must reject local_file:// scheme
class DummyURLopener(urllib.request.URLopener):
def open_local_file(self, url):
return url
with warnings.catch_warnings(record=True):
warnings.simplefilter("ignore", DeprecationWarning)
for url in ('local_file://example', 'local-file://example'):
self.assertRaises(OSError, urllib.request.urlopen, url)
self.assertRaises(OSError, urllib.request.URLopener().open, url)
self.assertRaises(OSError, urllib.request.URLopener().retrieve, url)
self.assertRaises(OSError, DummyURLopener().open, url)
self.assertRaises(OSError, DummyURLopener().retrieve, url)
# Just commented them out. # Just commented them out.
# Can't really tell why keep failing in windows and sparc. # Can't really tell why keep failing in windows and sparc.
# Everywhere else they work ok, but on those machines, sometimes # Everywhere else they work ok, but on those machines, sometimes
......
...@@ -141,44 +141,55 @@ class RequestHdrsTests(unittest.TestCase): ...@@ -141,44 +141,55 @@ class RequestHdrsTests(unittest.TestCase):
mgr = urllib.request.HTTPPasswordMgr() mgr = urllib.request.HTTPPasswordMgr()
add = mgr.add_password add = mgr.add_password
find_user_pass = mgr.find_user_password find_user_pass = mgr.find_user_password
add("Some Realm", "http://example.com/", "joe", "password") add("Some Realm", "http://example.com/", "joe", "password")
add("Some Realm", "http://example.com/ni", "ni", "ni") add("Some Realm", "http://example.com/ni", "ni", "ni")
add("c", "http://example.com/foo", "foo", "ni")
add("c", "http://example.com/bar", "bar", "nini")
add("b", "http://example.com/", "first", "blah")
add("b", "http://example.com/", "second", "spam")
add("a", "http://example.com", "1", "a")
add("Some Realm", "http://c.example.com:3128", "3", "c") add("Some Realm", "http://c.example.com:3128", "3", "c")
add("Some Realm", "d.example.com", "4", "d") add("Some Realm", "d.example.com", "4", "d")
add("Some Realm", "e.example.com:3128", "5", "e") add("Some Realm", "e.example.com:3128", "5", "e")
# For the same realm, password set the highest path is the winner.
self.assertEqual(find_user_pass("Some Realm", "example.com"), self.assertEqual(find_user_pass("Some Realm", "example.com"),
('joe', 'password')) ('joe', 'password'))
self.assertEqual(find_user_pass("Some Realm", "http://example.com/ni"),
#self.assertEqual(find_user_pass("Some Realm", "http://example.com/ni"), ('joe', 'password'))
# ('ni', 'ni'))
self.assertEqual(find_user_pass("Some Realm", "http://example.com"), self.assertEqual(find_user_pass("Some Realm", "http://example.com"),
('joe', 'password')) ('joe', 'password'))
self.assertEqual(find_user_pass("Some Realm", "http://example.com/"), self.assertEqual(find_user_pass("Some Realm", "http://example.com/"),
('joe', 'password')) ('joe', 'password'))
self.assertEqual( self.assertEqual(find_user_pass("Some Realm",
find_user_pass("Some Realm", "http://example.com/spam"), "http://example.com/spam"),
('joe', 'password')) ('joe', 'password'))
self.assertEqual(
find_user_pass("Some Realm", "http://example.com/spam/spam"), self.assertEqual(find_user_pass("Some Realm",
('joe', 'password')) "http://example.com/spam/spam"),
('joe', 'password'))
# You can have different passwords for different paths.
add("c", "http://example.com/foo", "foo", "ni")
add("c", "http://example.com/bar", "bar", "nini")
self.assertEqual(find_user_pass("c", "http://example.com/foo"), self.assertEqual(find_user_pass("c", "http://example.com/foo"),
('foo', 'ni')) ('foo', 'ni'))
self.assertEqual(find_user_pass("c", "http://example.com/bar"), self.assertEqual(find_user_pass("c", "http://example.com/bar"),
('bar', 'nini')) ('bar', 'nini'))
# For the same path, newer password should be considered.
add("b", "http://example.com/", "first", "blah")
add("b", "http://example.com/", "second", "spam")
self.assertEqual(find_user_pass("b", "http://example.com/"), self.assertEqual(find_user_pass("b", "http://example.com/"),
('second', 'spam')) ('second', 'spam'))
# No special relationship between a.example.com and example.com: # No special relationship between a.example.com and example.com:
add("a", "http://example.com", "1", "a")
self.assertEqual(find_user_pass("a", "http://example.com/"), self.assertEqual(find_user_pass("a", "http://example.com/"),
('1', 'a')) ('1', 'a'))
self.assertEqual(find_user_pass("a", "http://a.example.com/"), self.assertEqual(find_user_pass("a", "http://a.example.com/"),
(None, None)) (None, None))
...@@ -830,7 +841,6 @@ class HandlerTests(unittest.TestCase): ...@@ -830,7 +841,6 @@ class HandlerTests(unittest.TestCase):
for url, ftp in [ for url, ftp in [
("file://ftp.example.com//foo.txt", False), ("file://ftp.example.com//foo.txt", False),
("file://ftp.example.com///foo.txt", False), ("file://ftp.example.com///foo.txt", False),
# XXXX bug: fails with OSError, should be URLError
("file://ftp.example.com/foo.txt", False), ("file://ftp.example.com/foo.txt", False),
("file://somehost//foo/something.txt", False), ("file://somehost//foo/something.txt", False),
("file://localhost//foo/something.txt", False), ("file://localhost//foo/something.txt", False),
...@@ -838,8 +848,7 @@ class HandlerTests(unittest.TestCase): ...@@ -838,8 +848,7 @@ class HandlerTests(unittest.TestCase):
req = Request(url) req = Request(url)
try: try:
h.file_open(req) h.file_open(req)
# XXXX remove OSError when bug fixed except urllib.error.URLError:
except (urllib.error.URLError, OSError):
self.assertFalse(ftp) self.assertFalse(ftp)
else: else:
self.assertIs(o.req, req) self.assertIs(o.req, req)
...@@ -1414,7 +1423,6 @@ class HandlerTests(unittest.TestCase): ...@@ -1414,7 +1423,6 @@ class HandlerTests(unittest.TestCase):
self.assertEqual(req.host, "proxy.example.com:3128") self.assertEqual(req.host, "proxy.example.com:3128")
self.assertEqual(req.get_header("Proxy-authorization"), "FooBar") self.assertEqual(req.get_header("Proxy-authorization"), "FooBar")
# TODO: This should be only for OSX
@unittest.skipUnless(sys.platform == 'darwin', "only relevant for OSX") @unittest.skipUnless(sys.platform == 'darwin', "only relevant for OSX")
def test_osx_proxy_bypass(self): def test_osx_proxy_bypass(self):
bypass = { bypass = {
...@@ -1690,7 +1698,6 @@ class HandlerTests(unittest.TestCase): ...@@ -1690,7 +1698,6 @@ class HandlerTests(unittest.TestCase):
self.assertTrue(conn.fakesock.closed, "Connection not closed") self.assertTrue(conn.fakesock.closed, "Connection not closed")
class MiscTests(unittest.TestCase): class MiscTests(unittest.TestCase):
def opener_has_handler(self, opener, handler_class): def opener_has_handler(self, opener, handler_class):
......
...@@ -289,11 +289,15 @@ class BasicAuthTests(unittest.TestCase): ...@@ -289,11 +289,15 @@ class BasicAuthTests(unittest.TestCase):
def http_server_with_basic_auth_handler(*args, **kwargs): def http_server_with_basic_auth_handler(*args, **kwargs):
return BasicAuthHandler(*args, **kwargs) return BasicAuthHandler(*args, **kwargs)
self.server = LoopbackHttpServerThread(http_server_with_basic_auth_handler) self.server = LoopbackHttpServerThread(http_server_with_basic_auth_handler)
self.addCleanup(self.server.stop) self.addCleanup(self.stop_server)
self.server_url = 'http://127.0.0.1:%s' % self.server.port self.server_url = 'http://127.0.0.1:%s' % self.server.port
self.server.start() self.server.start()
self.server.ready.wait() self.server.ready.wait()
def stop_server(self):
self.server.stop()
self.server = None
def tearDown(self): def tearDown(self):
super(BasicAuthTests, self).tearDown() super(BasicAuthTests, self).tearDown()
...@@ -304,7 +308,7 @@ class BasicAuthTests(unittest.TestCase): ...@@ -304,7 +308,7 @@ class BasicAuthTests(unittest.TestCase):
try: try:
self.assertTrue(urllib.request.urlopen(self.server_url)) self.assertTrue(urllib.request.urlopen(self.server_url))
except urllib.error.HTTPError: except urllib.error.HTTPError:
self.fail("Basic auth failed for the url: %s", self.server_url) self.fail("Basic auth failed for the url: %s" % self.server_url)
def test_basic_auth_httperror(self): def test_basic_auth_httperror(self):
ah = urllib.request.HTTPBasicAuthHandler() ah = urllib.request.HTTPBasicAuthHandler()
...@@ -339,6 +343,7 @@ class ProxyAuthTests(unittest.TestCase): ...@@ -339,6 +343,7 @@ class ProxyAuthTests(unittest.TestCase):
return FakeProxyHandler(self.digest_auth_handler, *args, **kwargs) return FakeProxyHandler(self.digest_auth_handler, *args, **kwargs)
self.server = LoopbackHttpServerThread(create_fake_proxy_handler) self.server = LoopbackHttpServerThread(create_fake_proxy_handler)
self.addCleanup(self.stop_server)
self.server.start() self.server.start()
self.server.ready.wait() self.server.ready.wait()
proxy_url = "http://127.0.0.1:%d" % self.server.port proxy_url = "http://127.0.0.1:%d" % self.server.port
...@@ -347,9 +352,9 @@ class ProxyAuthTests(unittest.TestCase): ...@@ -347,9 +352,9 @@ class ProxyAuthTests(unittest.TestCase):
self.opener = urllib.request.build_opener( self.opener = urllib.request.build_opener(
handler, self.proxy_digest_handler) handler, self.proxy_digest_handler)
def tearDown(self): def stop_server(self):
self.server.stop() self.server.stop()
super(ProxyAuthTests, self).tearDown() self.server = None
def test_proxy_with_bad_password_raises_httperror(self): def test_proxy_with_bad_password_raises_httperror(self):
self.proxy_digest_handler.add_password(self.REALM, self.URL, self.proxy_digest_handler.add_password(self.REALM, self.URL,
...@@ -468,13 +473,17 @@ class TestUrlopen(unittest.TestCase): ...@@ -468,13 +473,17 @@ class TestUrlopen(unittest.TestCase):
f.close() f.close()
return b"".join(l) return b"".join(l)
def stop_server(self):
self.server.stop()
self.server = None
def start_server(self, responses=None): def start_server(self, responses=None):
if responses is None: if responses is None:
responses = [(200, [], b"we don't care")] responses = [(200, [], b"we don't care")]
handler = GetRequestHandler(responses) handler = GetRequestHandler(responses)
self.server = LoopbackHttpServerThread(handler) self.server = LoopbackHttpServerThread(handler)
self.addCleanup(self.server.stop) self.addCleanup(self.stop_server)
self.server.start() self.server.start()
self.server.ready.wait() self.server.ready.wait()
port = self.server.port port = self.server.port
...@@ -589,7 +598,7 @@ class TestUrlopen(unittest.TestCase): ...@@ -589,7 +598,7 @@ class TestUrlopen(unittest.TestCase):
def cb_sni(ssl_sock, server_name, initial_context): def cb_sni(ssl_sock, server_name, initial_context):
nonlocal sni_name nonlocal sni_name
sni_name = server_name sni_name = server_name
context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) context = ssl.SSLContext(ssl.PROTOCOL_TLS)
context.set_servername_callback(cb_sni) context.set_servername_callback(cb_sni)
handler = self.start_https_server(context=context, certfile=CERT_localhost) handler = self.start_https_server(context=context, certfile=CERT_localhost)
context = ssl.create_default_context(cafile=CERT_localhost) context = ssl.create_default_context(cafile=CERT_localhost)
...@@ -664,7 +673,7 @@ def setUpModule(): ...@@ -664,7 +673,7 @@ def setUpModule():
def tearDownModule(): def tearDownModule():
if threads_key: if threads_key:
support.threading_cleanup(threads_key) support.threading_cleanup(*threads_key)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -27,6 +27,13 @@ def _wrap_with_retry_thrice(func, exc): ...@@ -27,6 +27,13 @@ def _wrap_with_retry_thrice(func, exc):
return _retry_thrice(func, exc, *args, **kwargs) return _retry_thrice(func, exc, *args, **kwargs)
return wrapped return wrapped
# bpo-35411: FTP tests of test_urllib2net randomly fail
# with "425 Security: Bad IP connecting" on Travis CI
skip_ftp_test_on_travis = unittest.skipIf('TRAVIS' in os.environ,
'bpo-35411: skip FTP test '
'on Travis CI')
# Connecting to remote hosts is flaky. Make it more robust by retrying # Connecting to remote hosts is flaky. Make it more robust by retrying
# the connection several times. # the connection several times.
_urlopen_with_retry = _wrap_with_retry_thrice(urllib.request.urlopen, _urlopen_with_retry = _wrap_with_retry_thrice(urllib.request.urlopen,
...@@ -95,10 +102,11 @@ class OtherNetworkTests(unittest.TestCase): ...@@ -95,10 +102,11 @@ class OtherNetworkTests(unittest.TestCase):
# XXX The rest of these tests aren't very good -- they don't check much. # XXX The rest of these tests aren't very good -- they don't check much.
# They do sometimes catch some major disasters, though. # They do sometimes catch some major disasters, though.
@skip_ftp_test_on_travis
def test_ftp(self): def test_ftp(self):
urls = [ urls = [
'ftp://ftp.debian.org/debian/README', 'ftp://www.pythontest.net/README',
('ftp://ftp.debian.org/debian/non-existent-file', ('ftp://www.pythontest.net/non-existent-file',
None, urllib.error.URLError), None, urllib.error.URLError),
] ]
self._test_urls(urls, self._extra_handlers()) self._test_urls(urls, self._extra_handlers())
...@@ -177,6 +185,7 @@ class OtherNetworkTests(unittest.TestCase): ...@@ -177,6 +185,7 @@ class OtherNetworkTests(unittest.TestCase):
opener.open(request) opener.open(request)
self.assertEqual(request.get_header('User-agent'),'Test-Agent') self.assertEqual(request.get_header('User-agent'),'Test-Agent')
@unittest.skip('XXX: http://www.imdb.com is gone')
def test_sites_no_connection_close(self): def test_sites_no_connection_close(self):
# Some sites do not send Connection: close header. # Some sites do not send Connection: close header.
# Verify that those work properly. (#issue12576) # Verify that those work properly. (#issue12576)
...@@ -287,8 +296,9 @@ class TimeoutTest(unittest.TestCase): ...@@ -287,8 +296,9 @@ class TimeoutTest(unittest.TestCase):
self.addCleanup(u.close) self.addCleanup(u.close)
self.assertEqual(u.fp.raw._sock.gettimeout(), 120) self.assertEqual(u.fp.raw._sock.gettimeout(), 120)
FTP_HOST = 'ftp://ftp.debian.org/debian/' FTP_HOST = 'ftp://www.pythontest.net/'
@skip_ftp_test_on_travis
def test_ftp_basic(self): def test_ftp_basic(self):
self.assertIsNone(socket.getdefaulttimeout()) self.assertIsNone(socket.getdefaulttimeout())
with support.transient_internet(self.FTP_HOST, timeout=None): with support.transient_internet(self.FTP_HOST, timeout=None):
...@@ -296,6 +306,7 @@ class TimeoutTest(unittest.TestCase): ...@@ -296,6 +306,7 @@ class TimeoutTest(unittest.TestCase):
self.addCleanup(u.close) self.addCleanup(u.close)
self.assertIsNone(u.fp.fp.raw._sock.gettimeout()) self.assertIsNone(u.fp.fp.raw._sock.gettimeout())
@skip_ftp_test_on_travis
def test_ftp_default_timeout(self): def test_ftp_default_timeout(self):
self.assertIsNone(socket.getdefaulttimeout()) self.assertIsNone(socket.getdefaulttimeout())
with support.transient_internet(self.FTP_HOST): with support.transient_internet(self.FTP_HOST):
...@@ -307,6 +318,7 @@ class TimeoutTest(unittest.TestCase): ...@@ -307,6 +318,7 @@ class TimeoutTest(unittest.TestCase):
socket.setdefaulttimeout(None) socket.setdefaulttimeout(None)
self.assertEqual(u.fp.fp.raw._sock.gettimeout(), 60) self.assertEqual(u.fp.fp.raw._sock.gettimeout(), 60)
@skip_ftp_test_on_travis
def test_ftp_no_timeout(self): def test_ftp_no_timeout(self):
self.assertIsNone(socket.getdefaulttimeout()) self.assertIsNone(socket.getdefaulttimeout())
with support.transient_internet(self.FTP_HOST): with support.transient_internet(self.FTP_HOST):
...@@ -318,6 +330,7 @@ class TimeoutTest(unittest.TestCase): ...@@ -318,6 +330,7 @@ class TimeoutTest(unittest.TestCase):
socket.setdefaulttimeout(None) socket.setdefaulttimeout(None)
self.assertIsNone(u.fp.fp.raw._sock.gettimeout()) self.assertIsNone(u.fp.fp.raw._sock.gettimeout())
@skip_ftp_test_on_travis
def test_ftp_timeout(self): def test_ftp_timeout(self):
with support.transient_internet(self.FTP_HOST): with support.transient_internet(self.FTP_HOST):
u = _urlopen_with_retry(self.FTP_HOST, timeout=60) u = _urlopen_with_retry(self.FTP_HOST, timeout=60)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment