Commit fbbd9a3c authored by Denis Bilenko's avatar Denis Bilenko

include tests from stdlib/2.7.3

that way test runner does not fail on ubuntu where this tests are not available
a new test__monkey_patching.py is added which runs all stdlib tests
parent 626f8bd8
-----BEGIN RSA PRIVATE KEY-----
MIICXwIBAAKBgQC8ddrhm+LutBvjYcQlnH21PPIseJ1JVG2HMmN2CmZk2YukO+9L
opdJhTvbGfEj0DQs1IE8M+kTUyOmuKfVrFMKwtVeCJphrAnhoz7TYOuLBSqt7lVH
fhi/VwovESJlaBOp+WMnfhcduPEYHYx/6cnVapIkZnLt30zu2um+DzA9jQIDAQAB
AoGBAK0FZpaKj6WnJZN0RqhhK+ggtBWwBnc0U/ozgKz2j1s3fsShYeiGtW6CK5nU
D1dZ5wzhbGThI7LiOXDvRucc9n7vUgi0alqPQ/PFodPxAN/eEYkmXQ7W2k7zwsDA
IUK0KUhktQbLu8qF/m8qM86ba9y9/9YkXuQbZ3COl5ahTZrhAkEA301P08RKv3KM
oXnGU2UHTuJ1MAD2hOrPxjD4/wxA/39EWG9bZczbJyggB4RHu0I3NOSFjAm3HQm0
ANOu5QK9owJBANgOeLfNNcF4pp+UikRFqxk5hULqRAWzVxVrWe85FlPm0VVmHbb/
loif7mqjU8o1jTd/LM7RD9f2usZyE2psaw8CQQCNLhkpX3KO5kKJmS9N7JMZSc4j
oog58yeYO8BBqKKzpug0LXuQultYv2K4veaIO04iL9VLe5z9S/Q1jaCHBBuXAkEA
z8gjGoi1AOp6PBBLZNsncCvcV/0aC+1se4HxTNo2+duKSDnbq+ljqOM+E7odU+Nq
ewvIWOG//e8fssd0mq3HywJBAJ8l/c8GVmrpFTx8r/nZ2Pyyjt3dH1widooDXYSV
q6Gbf41Llo5sYAtmxdndTLASuHKecacTgZVhy0FryZpLKrU=
-----END RSA PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
Just bad cert data
-----END CERTIFICATE-----
-----BEGIN RSA PRIVATE KEY-----
MIICXwIBAAKBgQC8ddrhm+LutBvjYcQlnH21PPIseJ1JVG2HMmN2CmZk2YukO+9L
opdJhTvbGfEj0DQs1IE8M+kTUyOmuKfVrFMKwtVeCJphrAnhoz7TYOuLBSqt7lVH
fhi/VwovESJlaBOp+WMnfhcduPEYHYx/6cnVapIkZnLt30zu2um+DzA9jQIDAQAB
AoGBAK0FZpaKj6WnJZN0RqhhK+ggtBWwBnc0U/ozgKz2j1s3fsShYeiGtW6CK5nU
D1dZ5wzhbGThI7LiOXDvRucc9n7vUgi0alqPQ/PFodPxAN/eEYkmXQ7W2k7zwsDA
IUK0KUhktQbLu8qF/m8qM86ba9y9/9YkXuQbZ3COl5ahTZrhAkEA301P08RKv3KM
oXnGU2UHTuJ1MAD2hOrPxjD4/wxA/39EWG9bZczbJyggB4RHu0I3NOSFjAm3HQm0
ANOu5QK9owJBANgOeLfNNcF4pp+UikRFqxk5hULqRAWzVxVrWe85FlPm0VVmHbb/
loif7mqjU8o1jTd/LM7RD9f2usZyE2psaw8CQQCNLhkpX3KO5kKJmS9N7JMZSc4j
oog58yeYO8BBqKKzpug0LXuQultYv2K4veaIO04iL9VLe5z9S/Q1jaCHBBuXAkEA
z8gjGoi1AOp6PBBLZNsncCvcV/0aC+1se4HxTNo2+duKSDnbq+ljqOM+E7odU+Nq
ewvIWOG//e8fssd0mq3HywJBAJ8l/c8GVmrpFTx8r/nZ2Pyyjt3dH1widooDXYSV
q6Gbf41Llo5sYAtmxdndTLASuHKecacTgZVhy0FryZpLKrU=
-----END RSA PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
Just bad cert data
-----END CERTIFICATE-----
-----BEGIN RSA PRIVATE KEY-----
Bad Key, though the cert should be OK
-----END RSA PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
MIICpzCCAhCgAwIBAgIJAP+qStv1cIGNMA0GCSqGSIb3DQEBBQUAMIGJMQswCQYD
VQQGEwJVUzERMA8GA1UECBMIRGVsYXdhcmUxEzARBgNVBAcTCldpbG1pbmd0b24x
IzAhBgNVBAoTGlB5dGhvbiBTb2Z0d2FyZSBGb3VuZGF0aW9uMQwwCgYDVQQLEwNT
U0wxHzAdBgNVBAMTFnNvbWVtYWNoaW5lLnB5dGhvbi5vcmcwHhcNMDcwODI3MTY1
NDUwWhcNMTMwMjE2MTY1NDUwWjCBiTELMAkGA1UEBhMCVVMxETAPBgNVBAgTCERl
bGF3YXJlMRMwEQYDVQQHEwpXaWxtaW5ndG9uMSMwIQYDVQQKExpQeXRob24gU29m
dHdhcmUgRm91bmRhdGlvbjEMMAoGA1UECxMDU1NMMR8wHQYDVQQDExZzb21lbWFj
aGluZS5weXRob24ub3JnMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC8ddrh
m+LutBvjYcQlnH21PPIseJ1JVG2HMmN2CmZk2YukO+9LopdJhTvbGfEj0DQs1IE8
M+kTUyOmuKfVrFMKwtVeCJphrAnhoz7TYOuLBSqt7lVHfhi/VwovESJlaBOp+WMn
fhcduPEYHYx/6cnVapIkZnLt30zu2um+DzA9jQIDAQABoxUwEzARBglghkgBhvhC
AQEEBAMCBkAwDQYJKoZIhvcNAQEFBQADgYEAF4Q5BVqmCOLv1n8je/Jw9K669VXb
08hyGzQhkemEBYQd6fzQ9A/1ZzHkJKb1P6yreOLSEh4KcxYPyrLRC1ll8nr5OlCx
CMhKkTnR6qBsdNV0XtdU2+N25hqW+Ma4ZeqsN/iiJVCGNOZGnvQuvCAGWF8+J/f/
iHkC6gGdBJhogs4=
-----END CERTIFICATE-----
-----BEGIN RSA PRIVATE KEY-----
Bad Key, though the cert should be OK
-----END RSA PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
MIICpzCCAhCgAwIBAgIJAP+qStv1cIGNMA0GCSqGSIb3DQEBBQUAMIGJMQswCQYD
VQQGEwJVUzERMA8GA1UECBMIRGVsYXdhcmUxEzARBgNVBAcTCldpbG1pbmd0b24x
IzAhBgNVBAoTGlB5dGhvbiBTb2Z0d2FyZSBGb3VuZGF0aW9uMQwwCgYDVQQLEwNT
U0wxHzAdBgNVBAMTFnNvbWVtYWNoaW5lLnB5dGhvbi5vcmcwHhcNMDcwODI3MTY1
NDUwWhcNMTMwMjE2MTY1NDUwWjCBiTELMAkGA1UEBhMCVVMxETAPBgNVBAgTCERl
bGF3YXJlMRMwEQYDVQQHEwpXaWxtaW5ndG9uMSMwIQYDVQQKExpQeXRob24gU29m
dHdhcmUgRm91bmRhdGlvbjEMMAoGA1UECxMDU1NMMR8wHQYDVQQDExZzb21lbWFj
aGluZS5weXRob24ub3JnMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC8ddrh
m+LutBvjYcQlnH21PPIseJ1JVG2HMmN2CmZk2YukO+9LopdJhTvbGfEj0DQs1IE8
M+kTUyOmuKfVrFMKwtVeCJphrAnhoz7TYOuLBSqt7lVHfhi/VwovESJlaBOp+WMn
fhcduPEYHYx/6cnVapIkZnLt30zu2um+DzA9jQIDAQABoxUwEzARBglghkgBhvhC
AQEEBAMCBkAwDQYJKoZIhvcNAQEFBQADgYEAF4Q5BVqmCOLv1n8je/Jw9K669VXb
08hyGzQhkemEBYQd6fzQ9A/1ZzHkJKb1P6yreOLSEh4KcxYPyrLRC1ll8nr5OlCx
CMhKkTnR6qBsdNV0XtdU2+N25hqW+Ma4ZeqsN/iiJVCGNOZGnvQuvCAGWF8+J/f/
iHkC6gGdBJhogs4=
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIHPTCCBSWgAwIBAgIBADANBgkqhkiG9w0BAQQFADB5MRAwDgYDVQQKEwdSb290
IENBMR4wHAYDVQQLExVodHRwOi8vd3d3LmNhY2VydC5vcmcxIjAgBgNVBAMTGUNB
IENlcnQgU2lnbmluZyBBdXRob3JpdHkxITAfBgkqhkiG9w0BCQEWEnN1cHBvcnRA
Y2FjZXJ0Lm9yZzAeFw0wMzAzMzAxMjI5NDlaFw0zMzAzMjkxMjI5NDlaMHkxEDAO
BgNVBAoTB1Jvb3QgQ0ExHjAcBgNVBAsTFWh0dHA6Ly93d3cuY2FjZXJ0Lm9yZzEi
MCAGA1UEAxMZQ0EgQ2VydCBTaWduaW5nIEF1dGhvcml0eTEhMB8GCSqGSIb3DQEJ
ARYSc3VwcG9ydEBjYWNlcnQub3JnMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIIC
CgKCAgEAziLA4kZ97DYoB1CW8qAzQIxL8TtmPzHlawI229Z89vGIj053NgVBlfkJ
8BLPRoZzYLdufujAWGSuzbCtRRcMY/pnCujW0r8+55jE8Ez64AO7NV1sId6eINm6
zWYyN3L69wj1x81YyY7nDl7qPv4coRQKFWyGhFtkZip6qUtTefWIonvuLwphK42y
fk1WpRPs6tqSnqxEQR5YYGUFZvjARL3LlPdCfgv3ZWiYUQXw8wWRBB0bF4LsyFe7
w2t6iPGwcswlWyCR7BYCEo8y6RcYSNDHBS4CMEK4JZwFaz+qOqfrU0j36NK2B5jc
G8Y0f3/JHIJ6BVgrCFvzOKKrF11myZjXnhCLotLddJr3cQxyYN/Nb5gznZY0dj4k
epKwDpUeb+agRThHqtdB7Uq3EvbXG4OKDy7YCbZZ16oE/9KTfWgu3YtLq1i6L43q
laegw1SJpfvbi1EinbLDvhG+LJGGi5Z4rSDTii8aP8bQUWWHIbEZAWV/RRyH9XzQ
QUxPKZgh/TMfdQwEUfoZd9vUFBzugcMd9Zi3aQaRIt0AUMyBMawSB3s42mhb5ivU
fslfrejrckzzAeVLIL+aplfKkQABi6F1ITe1Yw1nPkZPcCBnzsXWWdsC4PDSy826
YreQQejdIOQpvGQpQsgi3Hia/0PsmBsJUUtaWsJx8cTLc6nloQsCAwEAAaOCAc4w
ggHKMB0GA1UdDgQWBBQWtTIb1Mfz4OaO873SsDrusjkY0TCBowYDVR0jBIGbMIGY
gBQWtTIb1Mfz4OaO873SsDrusjkY0aF9pHsweTEQMA4GA1UEChMHUm9vdCBDQTEe
MBwGA1UECxMVaHR0cDovL3d3dy5jYWNlcnQub3JnMSIwIAYDVQQDExlDQSBDZXJ0
IFNpZ25pbmcgQXV0aG9yaXR5MSEwHwYJKoZIhvcNAQkBFhJzdXBwb3J0QGNhY2Vy
dC5vcmeCAQAwDwYDVR0TAQH/BAUwAwEB/zAyBgNVHR8EKzApMCegJaAjhiFodHRw
czovL3d3dy5jYWNlcnQub3JnL3Jldm9rZS5jcmwwMAYJYIZIAYb4QgEEBCMWIWh0
dHBzOi8vd3d3LmNhY2VydC5vcmcvcmV2b2tlLmNybDA0BglghkgBhvhCAQgEJxYl
aHR0cDovL3d3dy5jYWNlcnQub3JnL2luZGV4LnBocD9pZD0xMDBWBglghkgBhvhC
AQ0ESRZHVG8gZ2V0IHlvdXIgb3duIGNlcnRpZmljYXRlIGZvciBGUkVFIGhlYWQg
b3ZlciB0byBodHRwOi8vd3d3LmNhY2VydC5vcmcwDQYJKoZIhvcNAQEEBQADggIB
ACjH7pyCArpcgBLKNQodgW+JapnM8mgPf6fhjViVPr3yBsOQWqy1YPaZQwGjiHCc
nWKdpIevZ1gNMDY75q1I08t0AoZxPuIrA2jxNGJARjtT6ij0rPtmlVOKTV39O9lg
18p5aTuxZZKmxoGCXJzN600BiqXfEVWqFcofN8CCmHBh22p8lqOOLlQ+TyGpkO/c
gr/c6EWtTZBzCDyUZbAEmXZ/4rzCahWqlwQ3JNgelE5tDlG+1sSPypZt90Pf6DBl
Jzt7u0NDY8RD97LsaMzhGY4i+5jhe1o+ATc7iwiwovOVThrLm82asduycPAtStvY
sONvRUgzEv/+PDIqVPfE94rwiCPCR/5kenHA0R6mY7AHfqQv0wGP3J8rtsYIqQ+T
SCX8Ev2fQtzzxD72V7DX3WnRBnc0CkvSyqD/HMaMyRa+xMwyN2hzXwj7UfdJUzYF
CpUCTPJ5GhD22Dp1nPMd8aINcGeGG7MW9S/lpOt5hvk9C8JzC6WZrG/8Z7jlLwum
GCSNe9FINSkYQKyTYOGWhlC0elnYjyELn8+CkcY7v2vcB5G5l1YjqrZslMZIBjzk
zk6q5PYvCdxTby78dOs6Y5nCpqyJvKeyRKANihDjbPIky/qbn3BHLt4Ui9SyIAmW
omTxJBzcoTWcFbLUvFUufQb1nA5V9FrWk9p2rSVzTMVD
-----END CERTIFICATE-----
-----BEGIN RSA PRIVATE KEY-----
MIICXwIBAAKBgQC8ddrhm+LutBvjYcQlnH21PPIseJ1JVG2HMmN2CmZk2YukO+9L
opdJhTvbGfEj0DQs1IE8M+kTUyOmuKfVrFMKwtVeCJphrAnhoz7TYOuLBSqt7lVH
fhi/VwovESJlaBOp+WMnfhcduPEYHYx/6cnVapIkZnLt30zu2um+DzA9jQIDAQAB
AoGBAK0FZpaKj6WnJZN0RqhhK+ggtBWwBnc0U/ozgKz2j1s3fsShYeiGtW6CK5nU
D1dZ5wzhbGThI7LiOXDvRucc9n7vUgi0alqPQ/PFodPxAN/eEYkmXQ7W2k7zwsDA
IUK0KUhktQbLu8qF/m8qM86ba9y9/9YkXuQbZ3COl5ahTZrhAkEA301P08RKv3KM
oXnGU2UHTuJ1MAD2hOrPxjD4/wxA/39EWG9bZczbJyggB4RHu0I3NOSFjAm3HQm0
ANOu5QK9owJBANgOeLfNNcF4pp+UikRFqxk5hULqRAWzVxVrWe85FlPm0VVmHbb/
loif7mqjU8o1jTd/LM7RD9f2usZyE2psaw8CQQCNLhkpX3KO5kKJmS9N7JMZSc4j
oog58yeYO8BBqKKzpug0LXuQultYv2K4veaIO04iL9VLe5z9S/Q1jaCHBBuXAkEA
z8gjGoi1AOp6PBBLZNsncCvcV/0aC+1se4HxTNo2+duKSDnbq+ljqOM+E7odU+Nq
ewvIWOG//e8fssd0mq3HywJBAJ8l/c8GVmrpFTx8r/nZ2Pyyjt3dH1widooDXYSV
q6Gbf41Llo5sYAtmxdndTLASuHKecacTgZVhy0FryZpLKrU=
-----END RSA PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
MIICpzCCAhCgAwIBAgIJAP+qStv1cIGNMA0GCSqGSIb3DQEBBQUAMIGJMQswCQYD
VQQGEwJVUzERMA8GA1UECBMIRGVsYXdhcmUxEzARBgNVBAcTCldpbG1pbmd0b24x
IzAhBgNVBAoTGlB5dGhvbiBTb2Z0d2FyZSBGb3VuZGF0aW9uMQwwCgYDVQQLEwNT
U0wxHzAdBgNVBAMTFnNvbWVtYWNoaW5lLnB5dGhvbi5vcmcwHhcNMDcwODI3MTY1
NDUwWhcNMTMwMjE2MTY1NDUwWjCBiTELMAkGA1UEBhMCVVMxETAPBgNVBAgTCERl
bGF3YXJlMRMwEQYDVQQHEwpXaWxtaW5ndG9uMSMwIQYDVQQKExpQeXRob24gU29m
dHdhcmUgRm91bmRhdGlvbjEMMAoGA1UECxMDU1NMMR8wHQYDVQQDExZzb21lbWFj
aGluZS5weXRob24ub3JnMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC8ddrh
m+LutBvjYcQlnH21PPIseJ1JVG2HMmN2CmZk2YukO+9LopdJhTvbGfEj0DQs1IE8
M+kTUyOmuKfVrFMKwtVeCJphrAnhoz7TYOuLBSqt7lVHfhi/VwovESJlaBOp+WMn
fhcduPEYHYx/6cnVapIkZnLt30zu2um+DzA9jQIDAQABoxUwEzARBglghkgBhvhC
AQEEBAMCBkAwDQYJKoZIhvcNAQEFBQADgYEAF4Q5BVqmCOLv1n8je/Jw9K669VXb
08hyGzQhkemEBYQd6fzQ9A/1ZzHkJKb1P6yreOLSEh4KcxYPyrLRC1ll8nr5OlCx
CMhKkTnR6qBsdNV0XtdU2+N25hqW+Ma4ZeqsN/iiJVCGNOZGnvQuvCAGWF8+J/f/
iHkC6gGdBJhogs4=
-----END CERTIFICATE-----
"""
Various tests for synchronization primitives.
"""
import sys
import time
from thread import start_new_thread, get_ident
import threading
import unittest
from test import test_support as support
def _wait():
# A crude wait/yield function not relying on synchronization primitives.
time.sleep(0.01)
class Bunch(object):
"""
A bunch of threads.
"""
def __init__(self, f, n, wait_before_exit=False):
"""
Construct a bunch of `n` threads running the same function `f`.
If `wait_before_exit` is True, the threads won't terminate until
do_finish() is called.
"""
self.f = f
self.n = n
self.started = []
self.finished = []
self._can_exit = not wait_before_exit
def task():
tid = get_ident()
self.started.append(tid)
try:
f()
finally:
self.finished.append(tid)
while not self._can_exit:
_wait()
for i in range(n):
start_new_thread(task, ())
def wait_for_started(self):
while len(self.started) < self.n:
_wait()
def wait_for_finished(self):
while len(self.finished) < self.n:
_wait()
def do_finish(self):
self._can_exit = True
class BaseTestCase(unittest.TestCase):
def setUp(self):
self._threads = support.threading_setup()
def tearDown(self):
support.threading_cleanup(*self._threads)
support.reap_children()
class BaseLockTests(BaseTestCase):
"""
Tests for both recursive and non-recursive locks.
"""
def test_constructor(self):
lock = self.locktype()
del lock
def test_acquire_destroy(self):
lock = self.locktype()
lock.acquire()
del lock
def test_acquire_release(self):
lock = self.locktype()
lock.acquire()
lock.release()
del lock
def test_try_acquire(self):
lock = self.locktype()
self.assertTrue(lock.acquire(False))
lock.release()
def test_try_acquire_contended(self):
lock = self.locktype()
lock.acquire()
result = []
def f():
result.append(lock.acquire(False))
Bunch(f, 1).wait_for_finished()
self.assertFalse(result[0])
lock.release()
def test_acquire_contended(self):
lock = self.locktype()
lock.acquire()
N = 5
def f():
lock.acquire()
lock.release()
b = Bunch(f, N)
b.wait_for_started()
_wait()
self.assertEqual(len(b.finished), 0)
lock.release()
b.wait_for_finished()
self.assertEqual(len(b.finished), N)
def test_with(self):
lock = self.locktype()
def f():
lock.acquire()
lock.release()
def _with(err=None):
with lock:
if err is not None:
raise err
_with()
# Check the lock is unacquired
Bunch(f, 1).wait_for_finished()
self.assertRaises(TypeError, _with, TypeError)
# Check the lock is unacquired
Bunch(f, 1).wait_for_finished()
def test_thread_leak(self):
# The lock shouldn't leak a Thread instance when used from a foreign
# (non-threading) thread.
lock = self.locktype()
def f():
lock.acquire()
lock.release()
n = len(threading.enumerate())
# We run many threads in the hope that existing threads ids won't
# be recycled.
Bunch(f, 15).wait_for_finished()
self.assertEqual(n, len(threading.enumerate()))
class LockTests(BaseLockTests):
"""
Tests for non-recursive, weak locks
(which can be acquired and released from different threads).
"""
def test_reacquire(self):
# Lock needs to be released before re-acquiring.
lock = self.locktype()
phase = []
def f():
lock.acquire()
phase.append(None)
lock.acquire()
phase.append(None)
start_new_thread(f, ())
while len(phase) == 0:
_wait()
_wait()
self.assertEqual(len(phase), 1)
lock.release()
while len(phase) == 1:
_wait()
self.assertEqual(len(phase), 2)
def test_different_thread(self):
# Lock can be released from a different thread.
lock = self.locktype()
lock.acquire()
def f():
lock.release()
b = Bunch(f, 1)
b.wait_for_finished()
lock.acquire()
lock.release()
class RLockTests(BaseLockTests):
"""
Tests for recursive locks.
"""
def test_reacquire(self):
lock = self.locktype()
lock.acquire()
lock.acquire()
lock.release()
lock.acquire()
lock.release()
lock.release()
def test_release_unacquired(self):
# Cannot release an unacquired lock
lock = self.locktype()
self.assertRaises(RuntimeError, lock.release)
lock.acquire()
lock.acquire()
lock.release()
lock.acquire()
lock.release()
lock.release()
self.assertRaises(RuntimeError, lock.release)
def test_different_thread(self):
# Cannot release from a different thread
lock = self.locktype()
def f():
lock.acquire()
b = Bunch(f, 1, True)
try:
self.assertRaises(RuntimeError, lock.release)
finally:
b.do_finish()
def test__is_owned(self):
lock = self.locktype()
self.assertFalse(lock._is_owned())
lock.acquire()
self.assertTrue(lock._is_owned())
lock.acquire()
self.assertTrue(lock._is_owned())
result = []
def f():
result.append(lock._is_owned())
Bunch(f, 1).wait_for_finished()
self.assertFalse(result[0])
lock.release()
self.assertTrue(lock._is_owned())
lock.release()
self.assertFalse(lock._is_owned())
class EventTests(BaseTestCase):
"""
Tests for Event objects.
"""
def test_is_set(self):
evt = self.eventtype()
self.assertFalse(evt.is_set())
evt.set()
self.assertTrue(evt.is_set())
evt.set()
self.assertTrue(evt.is_set())
evt.clear()
self.assertFalse(evt.is_set())
evt.clear()
self.assertFalse(evt.is_set())
def _check_notify(self, evt):
# All threads get notified
N = 5
results1 = []
results2 = []
def f():
results1.append(evt.wait())
results2.append(evt.wait())
b = Bunch(f, N)
b.wait_for_started()
_wait()
self.assertEqual(len(results1), 0)
evt.set()
b.wait_for_finished()
self.assertEqual(results1, [True] * N)
self.assertEqual(results2, [True] * N)
def test_notify(self):
evt = self.eventtype()
self._check_notify(evt)
# Another time, after an explicit clear()
evt.set()
evt.clear()
self._check_notify(evt)
def test_timeout(self):
evt = self.eventtype()
results1 = []
results2 = []
N = 5
def f():
results1.append(evt.wait(0.0))
t1 = time.time()
r = evt.wait(0.2)
t2 = time.time()
results2.append((r, t2 - t1))
Bunch(f, N).wait_for_finished()
self.assertEqual(results1, [False] * N)
for r, dt in results2:
self.assertFalse(r)
self.assertTrue(dt >= 0.2, dt)
# The event is set
results1 = []
results2 = []
evt.set()
Bunch(f, N).wait_for_finished()
self.assertEqual(results1, [True] * N)
for r, dt in results2:
self.assertTrue(r)
class ConditionTests(BaseTestCase):
"""
Tests for condition variables.
"""
def test_acquire(self):
cond = self.condtype()
# Be default we have an RLock: the condition can be acquired multiple
# times.
cond.acquire()
cond.acquire()
cond.release()
cond.release()
lock = threading.Lock()
cond = self.condtype(lock)
cond.acquire()
self.assertFalse(lock.acquire(False))
cond.release()
self.assertTrue(lock.acquire(False))
self.assertFalse(cond.acquire(False))
lock.release()
with cond:
self.assertFalse(lock.acquire(False))
def test_unacquired_wait(self):
cond = self.condtype()
self.assertRaises(RuntimeError, cond.wait)
def test_unacquired_notify(self):
cond = self.condtype()
self.assertRaises(RuntimeError, cond.notify)
def _check_notify(self, cond):
N = 5
results1 = []
results2 = []
phase_num = 0
def f():
cond.acquire()
cond.wait()
cond.release()
results1.append(phase_num)
cond.acquire()
cond.wait()
cond.release()
results2.append(phase_num)
b = Bunch(f, N)
b.wait_for_started()
_wait()
self.assertEqual(results1, [])
# Notify 3 threads at first
cond.acquire()
cond.notify(3)
_wait()
phase_num = 1
cond.release()
while len(results1) < 3:
_wait()
self.assertEqual(results1, [1] * 3)
self.assertEqual(results2, [])
# Notify 5 threads: they might be in their first or second wait
cond.acquire()
cond.notify(5)
_wait()
phase_num = 2
cond.release()
while len(results1) + len(results2) < 8:
_wait()
self.assertEqual(results1, [1] * 3 + [2] * 2)
self.assertEqual(results2, [2] * 3)
# Notify all threads: they are all in their second wait
cond.acquire()
cond.notify_all()
_wait()
phase_num = 3
cond.release()
while len(results2) < 5:
_wait()
self.assertEqual(results1, [1] * 3 + [2] * 2)
self.assertEqual(results2, [2] * 3 + [3] * 2)
b.wait_for_finished()
def test_notify(self):
cond = self.condtype()
self._check_notify(cond)
# A second time, to check internal state is still ok.
self._check_notify(cond)
def test_timeout(self):
cond = self.condtype()
results = []
N = 5
def f():
cond.acquire()
t1 = time.time()
cond.wait(0.2)
t2 = time.time()
cond.release()
results.append(t2 - t1)
Bunch(f, N).wait_for_finished()
self.assertEqual(len(results), 5)
for dt in results:
self.assertTrue(dt >= 0.2, dt)
class BaseSemaphoreTests(BaseTestCase):
"""
Common tests for {bounded, unbounded} semaphore objects.
"""
def test_constructor(self):
self.assertRaises(ValueError, self.semtype, value = -1)
self.assertRaises(ValueError, self.semtype, value = -sys.maxint)
def test_acquire(self):
sem = self.semtype(1)
sem.acquire()
sem.release()
sem = self.semtype(2)
sem.acquire()
sem.acquire()
sem.release()
sem.release()
def test_acquire_destroy(self):
sem = self.semtype()
sem.acquire()
del sem
def test_acquire_contended(self):
sem = self.semtype(7)
sem.acquire()
N = 10
results1 = []
results2 = []
phase_num = 0
def f():
sem.acquire()
results1.append(phase_num)
sem.acquire()
results2.append(phase_num)
b = Bunch(f, 10)
b.wait_for_started()
while len(results1) + len(results2) < 6:
_wait()
self.assertEqual(results1 + results2, [0] * 6)
phase_num = 1
for i in range(7):
sem.release()
while len(results1) + len(results2) < 13:
_wait()
self.assertEqual(sorted(results1 + results2), [0] * 6 + [1] * 7)
phase_num = 2
for i in range(6):
sem.release()
while len(results1) + len(results2) < 19:
_wait()
self.assertEqual(sorted(results1 + results2), [0] * 6 + [1] * 7 + [2] * 6)
# The semaphore is still locked
self.assertFalse(sem.acquire(False))
# Final release, to let the last thread finish
sem.release()
b.wait_for_finished()
def test_try_acquire(self):
sem = self.semtype(2)
self.assertTrue(sem.acquire(False))
self.assertTrue(sem.acquire(False))
self.assertFalse(sem.acquire(False))
sem.release()
self.assertTrue(sem.acquire(False))
def test_try_acquire_contended(self):
sem = self.semtype(4)
sem.acquire()
results = []
def f():
results.append(sem.acquire(False))
results.append(sem.acquire(False))
Bunch(f, 5).wait_for_finished()
# There can be a thread switch between acquiring the semaphore and
# appending the result, therefore results will not necessarily be
# ordered.
self.assertEqual(sorted(results), [False] * 7 + [True] * 3 )
def test_default_value(self):
# The default initial value is 1.
sem = self.semtype()
sem.acquire()
def f():
sem.acquire()
sem.release()
b = Bunch(f, 1)
b.wait_for_started()
_wait()
self.assertFalse(b.finished)
sem.release()
b.wait_for_finished()
def test_with(self):
sem = self.semtype(2)
def _with(err=None):
with sem:
self.assertTrue(sem.acquire(False))
sem.release()
with sem:
self.assertFalse(sem.acquire(False))
if err:
raise err
_with()
self.assertTrue(sem.acquire(False))
sem.release()
self.assertRaises(TypeError, _with, TypeError)
self.assertTrue(sem.acquire(False))
sem.release()
class SemaphoreTests(BaseSemaphoreTests):
"""
Tests for unbounded semaphores.
"""
def test_release_unacquired(self):
# Unbounded releases are allowed and increment the semaphore's value
sem = self.semtype(1)
sem.release()
sem.acquire()
sem.acquire()
sem.release()
class BoundedSemaphoreTests(BaseSemaphoreTests):
"""
Tests for bounded semaphores.
"""
def test_release_unacquired(self):
# Cannot go past the initial value
sem = self.semtype()
self.assertRaises(ValueError, sem.release)
sem.acquire()
sem.release()
self.assertRaises(ValueError, sem.release)
# Certificate for projects.developer.nokia.com:443 (see issue 13034)
-----BEGIN CERTIFICATE-----
MIIFLDCCBBSgAwIBAgIQLubqdkCgdc7lAF9NfHlUmjANBgkqhkiG9w0BAQUFADCB
vDELMAkGA1UEBhMCVVMxFzAVBgNVBAoTDlZlcmlTaWduLCBJbmMuMR8wHQYDVQQL
ExZWZXJpU2lnbiBUcnVzdCBOZXR3b3JrMTswOQYDVQQLEzJUZXJtcyBvZiB1c2Ug
YXQgaHR0cHM6Ly93d3cudmVyaXNpZ24uY29tL3JwYSAoYykxMDE2MDQGA1UEAxMt
VmVyaVNpZ24gQ2xhc3MgMyBJbnRlcm5hdGlvbmFsIFNlcnZlciBDQSAtIEczMB4X
DTExMDkyMTAwMDAwMFoXDTEyMDkyMDIzNTk1OVowcTELMAkGA1UEBhMCRkkxDjAM
BgNVBAgTBUVzcG9vMQ4wDAYDVQQHFAVFc3BvbzEOMAwGA1UEChQFTm9raWExCzAJ
BgNVBAsUAkJJMSUwIwYDVQQDFBxwcm9qZWN0cy5kZXZlbG9wZXIubm9raWEuY29t
MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCr92w1bpHYSYxUEx8N/8Iddda2
lYi+aXNtQfV/l2Fw9Ykv3Ipw4nLeGTj18FFlAZgMdPRlgrzF/NNXGw/9l3/qKdow
CypkQf8lLaxb9Ze1E/KKmkRJa48QTOqvo6GqKuTI6HCeGlG1RxDb8YSKcQWLiytn
yj3Wp4MgRQO266xmMQIDAQABo4IB9jCCAfIwQQYDVR0RBDowOIIccHJvamVjdHMu
ZGV2ZWxvcGVyLm5va2lhLmNvbYIYcHJvamVjdHMuZm9ydW0ubm9raWEuY29tMAkG
A1UdEwQCMAAwCwYDVR0PBAQDAgWgMEEGA1UdHwQ6MDgwNqA0oDKGMGh0dHA6Ly9T
VlJJbnRsLUczLWNybC52ZXJpc2lnbi5jb20vU1ZSSW50bEczLmNybDBEBgNVHSAE
PTA7MDkGC2CGSAGG+EUBBxcDMCowKAYIKwYBBQUHAgEWHGh0dHBzOi8vd3d3LnZl
cmlzaWduLmNvbS9ycGEwKAYDVR0lBCEwHwYJYIZIAYb4QgQBBggrBgEFBQcDAQYI
KwYBBQUHAwIwcgYIKwYBBQUHAQEEZjBkMCQGCCsGAQUFBzABhhhodHRwOi8vb2Nz
cC52ZXJpc2lnbi5jb20wPAYIKwYBBQUHMAKGMGh0dHA6Ly9TVlJJbnRsLUczLWFp
YS52ZXJpc2lnbi5jb20vU1ZSSW50bEczLmNlcjBuBggrBgEFBQcBDARiMGChXqBc
MFowWDBWFglpbWFnZS9naWYwITAfMAcGBSsOAwIaBBRLa7kolgYMu9BSOJsprEsH
iyEFGDAmFiRodHRwOi8vbG9nby52ZXJpc2lnbi5jb20vdnNsb2dvMS5naWYwDQYJ
KoZIhvcNAQEFBQADggEBACQuPyIJqXwUyFRWw9x5yDXgMW4zYFopQYOw/ItRY522
O5BsySTh56BWS6mQB07XVfxmYUGAvRQDA5QHpmY8jIlNwSmN3s8RKo+fAtiNRlcL
x/mWSfuMs3D/S6ev3D6+dpEMZtjrhOdctsarMKp8n/hPbwhAbg5hVjpkW5n8vz2y
0KxvvkA1AxpLwpVv7OlK17ttzIHw8bp9HTlHBU5s8bKz4a565V/a5HI0CSEv/+0y
ko4/ghTnZc1CkmUngKKeFMSah/mT/xAh8XnE2l1AazFa8UKuYki1e+ArHaGZc4ix
UYOtiRphwfuYQhRZ7qX9q2MMkCMI65XNK/SaFrAbbG0=
-----END CERTIFICATE-----
# Certificate chain for https://sha256.tbs-internet.com
0 s:/C=FR/postalCode=14000/ST=Calvados/L=CAEN/street=22 rue de Bretagne/O=TBS INTERNET/OU=0002 440443810/OU=sha-256 production/CN=sha256.tbs-internet.com
i:/C=FR/ST=Calvados/L=Caen/O=TBS INTERNET/OU=Terms and Conditions: http://www.tbs-internet.com/CA/repository/OU=TBS INTERNET CA/CN=TBS X509 CA SGC
-----BEGIN CERTIFICATE-----
MIIGXTCCBUWgAwIBAgIRAMmag+ygSAdxZsbyzYjhuW0wDQYJKoZIhvcNAQELBQAw
gcQxCzAJBgNVBAYTAkZSMREwDwYDVQQIEwhDYWx2YWRvczENMAsGA1UEBxMEQ2Fl
bjEVMBMGA1UEChMMVEJTIElOVEVSTkVUMUgwRgYDVQQLEz9UZXJtcyBhbmQgQ29u
ZGl0aW9uczogaHR0cDovL3d3dy50YnMtaW50ZXJuZXQuY29tL0NBL3JlcG9zaXRv
cnkxGDAWBgNVBAsTD1RCUyBJTlRFUk5FVCBDQTEYMBYGA1UEAxMPVEJTIFg1MDkg
Q0EgU0dDMB4XDTEwMDIxODAwMDAwMFoXDTEyMDIxOTIzNTk1OVowgcsxCzAJBgNV
BAYTAkZSMQ4wDAYDVQQREwUxNDAwMDERMA8GA1UECBMIQ2FsdmFkb3MxDTALBgNV
BAcTBENBRU4xGzAZBgNVBAkTEjIyIHJ1ZSBkZSBCcmV0YWduZTEVMBMGA1UEChMM
VEJTIElOVEVSTkVUMRcwFQYDVQQLEw4wMDAyIDQ0MDQ0MzgxMDEbMBkGA1UECxMS
c2hhLTI1NiBwcm9kdWN0aW9uMSAwHgYDVQQDExdzaGEyNTYudGJzLWludGVybmV0
LmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKbuM8VT7f0nntwu
N3F7v9KIBlhKNAxqCrziOXU5iqUt8HrQB3DtHbdmII+CpVUlwlmepsx6G+srEZ9a
MIGAy0nxi5aLb7watkyIdPjJTMvTUBQ/+RPWzt5JtYbbY9BlJ+yci0dctP74f4NU
ISLtlrEjUbf2gTohLrcE01TfmOF6PDEbB5PKDi38cB3NzKfizWfrOaJW6Q1C1qOJ
y4/4jkUREX1UFUIxzx7v62VfjXSGlcjGpBX1fvtABQOSLeE0a6gciDZs1REqroFf
5eXtqYphpTa14Z83ITXMfgg5Nze1VtMnzI9Qx4blYBw4dgQVEuIsYr7FDBOITDzc
VEVXZx0CAwEAAaOCAj8wggI7MB8GA1UdIwQYMBaAFAdEdoWTKLx/bXjSCuv6TEvf
2YIfMB0GA1UdDgQWBBSJKI/AYVI9RQNY0QPIqc8ej2QivTAOBgNVHQ8BAf8EBAMC
BaAwDAYDVR0TAQH/BAIwADA0BgNVHSUELTArBggrBgEFBQcDAQYIKwYBBQUHAwIG
CisGAQQBgjcKAwMGCWCGSAGG+EIEATBMBgNVHSAERTBDMEEGCysGAQQBgOU3AgQB
MDIwMAYIKwYBBQUHAgEWJGh0dHBzOi8vd3d3LnRicy1pbnRlcm5ldC5jb20vQ0Ev
Q1BTNDBtBgNVHR8EZjBkMDKgMKAuhixodHRwOi8vY3JsLnRicy1pbnRlcm5ldC5j
b20vVEJTWDUwOUNBU0dDLmNybDAuoCygKoYoaHR0cDovL2NybC50YnMteDUwOS5j
b20vVEJTWDUwOUNBU0dDLmNybDCBpgYIKwYBBQUHAQEEgZkwgZYwOAYIKwYBBQUH
MAKGLGh0dHA6Ly9jcnQudGJzLWludGVybmV0LmNvbS9UQlNYNTA5Q0FTR0MuY3J0
MDQGCCsGAQUFBzAChihodHRwOi8vY3J0LnRicy14NTA5LmNvbS9UQlNYNTA5Q0FT
R0MuY3J0MCQGCCsGAQUFBzABhhhodHRwOi8vb2NzcC50YnMteDUwOS5jb20wPwYD
VR0RBDgwNoIXc2hhMjU2LnRicy1pbnRlcm5ldC5jb22CG3d3dy5zaGEyNTYudGJz
LWludGVybmV0LmNvbTANBgkqhkiG9w0BAQsFAAOCAQEAA5NL0D4QSqhErhlkdPmz
XtiMvdGL+ZehM4coTRIpasM/Agt36Rc0NzCvnQwKE+wkngg1Gy2qe7Q0E/ziqBtB
fZYzdVgu1zdiL4kTaf+wFKYAFGsFbyeEmXysy+CMwaNoF2vpSjCU1UD56bEnTX/W
fxVZYxtBQUpnu2wOsm8cDZuZRv9XrYgAhGj9Tt6F0aVHSDGn59uwShG1+BVF/uju
SCyPTTjL1oc7YElJUzR/x4mQJYvtQI8gDIDAGEOs7v3R/gKa5EMfbUQUI4C84UbI
Yz09Jdnws/MkC/Hm1BZEqk89u7Hvfv+oHqEb0XaUo0TDfsxE0M1sMdnLb91QNQBm
UQ==
-----END CERTIFICATE-----
1 s:/C=FR/ST=Calvados/L=Caen/O=TBS INTERNET/OU=Terms and Conditions: http://www.tbs-internet.com/CA/repository/OU=TBS INTERNET CA/CN=TBS X509 CA SGC
i:/C=SE/O=AddTrust AB/OU=AddTrust External TTP Network/CN=AddTrust External CA Root
-----BEGIN CERTIFICATE-----
MIIFVjCCBD6gAwIBAgIQXpDZ0ETJMV02WTx3GTnhhTANBgkqhkiG9w0BAQUFADBv
MQswCQYDVQQGEwJTRTEUMBIGA1UEChMLQWRkVHJ1c3QgQUIxJjAkBgNVBAsTHUFk
ZFRydXN0IEV4dGVybmFsIFRUUCBOZXR3b3JrMSIwIAYDVQQDExlBZGRUcnVzdCBF
eHRlcm5hbCBDQSBSb290MB4XDTA1MTIwMTAwMDAwMFoXDTE5MDYyNDE5MDYzMFow
gcQxCzAJBgNVBAYTAkZSMREwDwYDVQQIEwhDYWx2YWRvczENMAsGA1UEBxMEQ2Fl
bjEVMBMGA1UEChMMVEJTIElOVEVSTkVUMUgwRgYDVQQLEz9UZXJtcyBhbmQgQ29u
ZGl0aW9uczogaHR0cDovL3d3dy50YnMtaW50ZXJuZXQuY29tL0NBL3JlcG9zaXRv
cnkxGDAWBgNVBAsTD1RCUyBJTlRFUk5FVCBDQTEYMBYGA1UEAxMPVEJTIFg1MDkg
Q0EgU0dDMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAsgOkO3f7wzN6
rOjg45tR5vjBfzK7qmV9IBxb/QW9EEXxG+E7FNhZqQLtwGBKoSsHTnQqV75wWMk0
9tinWvftBkSpj5sTi/8cbzJfUvTSVYh3Qxv6AVVjMMH/ruLjE6y+4PoaPs8WoYAQ
ts5R4Z1g8c/WnTepLst2x0/Wv7GmuoQi+gXvHU6YrBiu7XkeYhzc95QdviWSJRDk
owhb5K43qhcvjRmBfO/paGlCliDGZp8mHwrI21mwobWpVjTxZRwYO3bd4+TGcI4G
Ie5wmHwE8F7SK1tgSqbBacKjDa93j7txKkfz/Yd2n7TGqOXiHPsJpG655vrKtnXk
9vs1zoDeJQIDAQABo4IBljCCAZIwHQYDVR0OBBYEFAdEdoWTKLx/bXjSCuv6TEvf
2YIfMA4GA1UdDwEB/wQEAwIBBjASBgNVHRMBAf8ECDAGAQH/AgEAMCAGA1UdJQQZ
MBcGCisGAQQBgjcKAwMGCWCGSAGG+EIEATAYBgNVHSAEETAPMA0GCysGAQQBgOU3
AgQBMHsGA1UdHwR0MHIwOKA2oDSGMmh0dHA6Ly9jcmwuY29tb2RvY2EuY29tL0Fk
ZFRydXN0RXh0ZXJuYWxDQVJvb3QuY3JsMDagNKAyhjBodHRwOi8vY3JsLmNvbW9k
by5uZXQvQWRkVHJ1c3RFeHRlcm5hbENBUm9vdC5jcmwwgYAGCCsGAQUFBwEBBHQw
cjA4BggrBgEFBQcwAoYsaHR0cDovL2NydC5jb21vZG9jYS5jb20vQWRkVHJ1c3RV
VE5TR0NDQS5jcnQwNgYIKwYBBQUHMAKGKmh0dHA6Ly9jcnQuY29tb2RvLm5ldC9B
ZGRUcnVzdFVUTlNHQ0NBLmNydDARBglghkgBhvhCAQEEBAMCAgQwDQYJKoZIhvcN
AQEFBQADggEBAK2zEzs+jcIrVK9oDkdDZNvhuBYTdCfpxfFs+OAujW0bIfJAy232
euVsnJm6u/+OrqKudD2tad2BbejLLXhMZViaCmK7D9nrXHx4te5EP8rL19SUVqLY
1pTnv5dhNgEgvA7n5lIzDSYs7yRLsr7HJsYPr6SeYSuZizyX1SNz7ooJ32/F3X98
RB0Mlc/E0OyOrkQ9/y5IrnpnaSora8CnUrV5XNOg+kyCz9edCyx4D5wXYcwZPVWz
8aDqquESrezPyjtfi4WRO4s/VD3HLZvOxzMrWAVYCDG9FxaOhF0QGuuG1F7F3GKV
v6prNyCl016kRl2j1UT+a7gLd8fA25A4C9E=
-----END CERTIFICATE-----
2 s:/C=SE/O=AddTrust AB/OU=AddTrust External TTP Network/CN=AddTrust External CA Root
i:/C=US/ST=UT/L=Salt Lake City/O=The USERTRUST Network/OU=http://www.usertrust.com/CN=UTN - DATACorp SGC
-----BEGIN CERTIFICATE-----
MIIEZjCCA06gAwIBAgIQUSYKkxzif5zDpV954HKugjANBgkqhkiG9w0BAQUFADCB
kzELMAkGA1UEBhMCVVMxCzAJBgNVBAgTAlVUMRcwFQYDVQQHEw5TYWx0IExha2Ug
Q2l0eTEeMBwGA1UEChMVVGhlIFVTRVJUUlVTVCBOZXR3b3JrMSEwHwYDVQQLExho
dHRwOi8vd3d3LnVzZXJ0cnVzdC5jb20xGzAZBgNVBAMTElVUTiAtIERBVEFDb3Jw
IFNHQzAeFw0wNTA2MDcwODA5MTBaFw0xOTA2MjQxOTA2MzBaMG8xCzAJBgNVBAYT
AlNFMRQwEgYDVQQKEwtBZGRUcnVzdCBBQjEmMCQGA1UECxMdQWRkVHJ1c3QgRXh0
ZXJuYWwgVFRQIE5ldHdvcmsxIjAgBgNVBAMTGUFkZFRydXN0IEV4dGVybmFsIENB
IFJvb3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC39xoz5vIABC05
4E5b7R+8bA/Ntfojts7emxEzl6QpTH2Tn71KvJPtAxrjj8/lbVBa1pcplFqAsEl6
2y6V/bjKvzc4LR4+kUGtcFbH8E8/6DKedMrIkFTpxl8PeJ2aQDwOrGGqXhSPnoeh
alDc15pOrwWzpnGUnHGzUGAKxxOdOAeGAqjpqGkmGJCrTLBPI6s6T4TY386f4Wlv
u9dC12tE5Met7m1BX3JacQg3s3llpFmglDf3AC8NwpJy2tA4ctsUqEXEXSp9t7TW
xO6szRNEt8kr3UMAJfphuWlqWCMRt6czj1Z1WfXNKddGtworZbbTQm8Vsrh7++/p
XVPVNFonAgMBAAGjgdgwgdUwHwYDVR0jBBgwFoAUUzLRs89/+uDxoF2FTpLSnkUd
tE8wHQYDVR0OBBYEFK29mHo0tCb3+sQmVO8DveAky1QaMA4GA1UdDwEB/wQEAwIB
BjAPBgNVHRMBAf8EBTADAQH/MBEGCWCGSAGG+EIBAQQEAwIBAjAgBgNVHSUEGTAX
BgorBgEEAYI3CgMDBglghkgBhvhCBAEwPQYDVR0fBDYwNDAyoDCgLoYsaHR0cDov
L2NybC51c2VydHJ1c3QuY29tL1VUTi1EQVRBQ29ycFNHQy5jcmwwDQYJKoZIhvcN
AQEFBQADggEBAMbuUxdoFLJRIh6QWA2U/b3xcOWGLcM2MY9USEbnLQg3vGwKYOEO
rVE04BKT6b64q7gmtOmWPSiPrmQH/uAB7MXjkesYoPF1ftsK5p+R26+udd8jkWjd
FwBaS/9kbHDrARrQkNnHptZt9hPk/7XJ0h4qy7ElQyZ42TCbTg0evmnv3+r+LbPM
+bDdtRTKkdSytaX7ARmjR3mfnYyVhzT4HziS2jamEfpr62vp3EV4FTkG101B5CHI
3C+H0be/SGB1pWLLJN47YaApIKa+xWycxOkKaSLvkTr6Jq/RW0GnOuL4OAdCq8Fb
+M5tug8EPzI0rNwEKNdwMBQmBsTkm5jVz3g=
-----END CERTIFICATE-----
3 s:/C=US/ST=UT/L=Salt Lake City/O=The USERTRUST Network/OU=http://www.usertrust.com/CN=UTN - DATACorp SGC
i:/C=US/ST=UT/L=Salt Lake City/O=The USERTRUST Network/OU=http://www.usertrust.com/CN=UTN - DATACorp SGC
-----BEGIN CERTIFICATE-----
MIIEXjCCA0agAwIBAgIQRL4Mi1AAIbQR0ypoBqmtaTANBgkqhkiG9w0BAQUFADCB
kzELMAkGA1UEBhMCVVMxCzAJBgNVBAgTAlVUMRcwFQYDVQQHEw5TYWx0IExha2Ug
Q2l0eTEeMBwGA1UEChMVVGhlIFVTRVJUUlVTVCBOZXR3b3JrMSEwHwYDVQQLExho
dHRwOi8vd3d3LnVzZXJ0cnVzdC5jb20xGzAZBgNVBAMTElVUTiAtIERBVEFDb3Jw
IFNHQzAeFw05OTA2MjQxODU3MjFaFw0xOTA2MjQxOTA2MzBaMIGTMQswCQYDVQQG
EwJVUzELMAkGA1UECBMCVVQxFzAVBgNVBAcTDlNhbHQgTGFrZSBDaXR5MR4wHAYD
VQQKExVUaGUgVVNFUlRSVVNUIE5ldHdvcmsxITAfBgNVBAsTGGh0dHA6Ly93d3cu
dXNlcnRydXN0LmNvbTEbMBkGA1UEAxMSVVROIC0gREFUQUNvcnAgU0dDMIIBIjAN
BgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA3+5YEKIrblXEjr8uRgnn4AgPLit6
E5Qbvfa2gI5lBZMAHryv4g+OGQ0SR+ysraP6LnD43m77VkIVni5c7yPeIbkFdicZ
D0/Ww5y0vpQZY/KmEQrrU0icvvIpOxboGqBMpsn0GFlowHDyUwDAXlCCpVZvNvlK
4ESGoE1O1kduSUrLZ9emxAW5jh70/P/N5zbgnAVssjMiFdC04MwXwLLA9P4yPykq
lXvY8qdOD1R8oQ2AswkDwf9c3V6aPryuvEeKaq5xyh+xKrhfQgUL7EYw0XILyulW
bfXv33i+Ybqypa4ETLyorGkVl73v67SMvzX41MPRKA5cOp9wGDMgd8SirwIDAQAB
o4GrMIGoMAsGA1UdDwQEAwIBxjAPBgNVHRMBAf8EBTADAQH/MB0GA1UdDgQWBBRT
MtGzz3/64PGgXYVOktKeRR20TzA9BgNVHR8ENjA0MDKgMKAuhixodHRwOi8vY3Js
LnVzZXJ0cnVzdC5jb20vVVROLURBVEFDb3JwU0dDLmNybDAqBgNVHSUEIzAhBggr
BgEFBQcDAQYKKwYBBAGCNwoDAwYJYIZIAYb4QgQBMA0GCSqGSIb3DQEBBQUAA4IB
AQAnNZcAiosovcYzMB4p/OL31ZjUQLtgyr+rFywJNn9Q+kHcrpY6CiM+iVnJowft
Gzet/Hy+UUla3joKVAgWRcKZsYfNjGjgaQPpxE6YsjuMFrMOoAyYUJuTqXAJyCyj
j98C5OBxOvG0I3KgqgHf35g+FFCgMSa9KOlaMCZ1+XtgHI3zzVAmbQQnmt/VDUVH
KWss5nbZqSl9Mt3JNjy9rjXxEZ4du5A/EkdOjtd+D2JzHVImOBwYSf0wdJrE5SIv
2MCN7ZF6TACPcn9d2t0bi0Vr591pl6jFVkwPDPafepE39peC4N1xaf92P2BNPM/3
mfnGV/TJVTl4uix5yaaIK/QI
-----END CERTIFICATE-----
import signal, subprocess, sys
# On Linux this causes os.waitpid to fail with OSError as the OS has already
# reaped our child process. The wait() passing the OSError on to the caller
# and causing us to exit with an error is what we are testing against.
signal.signal(signal.SIGCHLD, signal.SIG_IGN)
subprocess.Popen([sys.executable, '-c', 'print("albatross")']).wait()
import asyncore
import unittest
import select
import os
import socket
import sys
import time
import warnings
import errno
from test import test_support
from test.test_support import TESTFN, run_unittest, unlink
from StringIO import StringIO
try:
import threading
except ImportError:
threading = None
HOST = test_support.HOST
class dummysocket:
def __init__(self):
self.closed = False
def close(self):
self.closed = True
def fileno(self):
return 42
class dummychannel:
def __init__(self):
self.socket = dummysocket()
def close(self):
self.socket.close()
class exitingdummy:
def __init__(self):
pass
def handle_read_event(self):
raise asyncore.ExitNow()
handle_write_event = handle_read_event
handle_close = handle_read_event
handle_expt_event = handle_read_event
class crashingdummy:
def __init__(self):
self.error_handled = False
def handle_read_event(self):
raise Exception()
handle_write_event = handle_read_event
handle_close = handle_read_event
handle_expt_event = handle_read_event
def handle_error(self):
self.error_handled = True
# used when testing senders; just collects what it gets until newline is sent
def capture_server(evt, buf, serv):
try:
serv.listen(5)
conn, addr = serv.accept()
except socket.timeout:
pass
else:
n = 200
while n > 0:
r, w, e = select.select([conn], [], [])
if r:
data = conn.recv(10)
# keep everything except for the newline terminator
buf.write(data.replace('\n', ''))
if '\n' in data:
break
n -= 1
time.sleep(0.01)
conn.close()
finally:
serv.close()
evt.set()
class HelperFunctionTests(unittest.TestCase):
def test_readwriteexc(self):
# Check exception handling behavior of read, write and _exception
# check that ExitNow exceptions in the object handler method
# bubbles all the way up through asyncore read/write/_exception calls
tr1 = exitingdummy()
self.assertRaises(asyncore.ExitNow, asyncore.read, tr1)
self.assertRaises(asyncore.ExitNow, asyncore.write, tr1)
self.assertRaises(asyncore.ExitNow, asyncore._exception, tr1)
# check that an exception other than ExitNow in the object handler
# method causes the handle_error method to get called
tr2 = crashingdummy()
asyncore.read(tr2)
self.assertEqual(tr2.error_handled, True)
tr2 = crashingdummy()
asyncore.write(tr2)
self.assertEqual(tr2.error_handled, True)
tr2 = crashingdummy()
asyncore._exception(tr2)
self.assertEqual(tr2.error_handled, True)
# asyncore.readwrite uses constants in the select module that
# are not present in Windows systems (see this thread:
# http://mail.python.org/pipermail/python-list/2001-October/109973.html)
# These constants should be present as long as poll is available
@unittest.skipUnless(hasattr(select, 'poll'), 'select.poll required')
def test_readwrite(self):
# Check that correct methods are called by readwrite()
attributes = ('read', 'expt', 'write', 'closed', 'error_handled')
expected = (
(select.POLLIN, 'read'),
(select.POLLPRI, 'expt'),
(select.POLLOUT, 'write'),
(select.POLLERR, 'closed'),
(select.POLLHUP, 'closed'),
(select.POLLNVAL, 'closed'),
)
class testobj:
def __init__(self):
self.read = False
self.write = False
self.closed = False
self.expt = False
self.error_handled = False
def handle_read_event(self):
self.read = True
def handle_write_event(self):
self.write = True
def handle_close(self):
self.closed = True
def handle_expt_event(self):
self.expt = True
def handle_error(self):
self.error_handled = True
for flag, expectedattr in expected:
tobj = testobj()
self.assertEqual(getattr(tobj, expectedattr), False)
asyncore.readwrite(tobj, flag)
# Only the attribute modified by the routine we expect to be
# called should be True.
for attr in attributes:
self.assertEqual(getattr(tobj, attr), attr==expectedattr)
# check that ExitNow exceptions in the object handler method
# bubbles all the way up through asyncore readwrite call
tr1 = exitingdummy()
self.assertRaises(asyncore.ExitNow, asyncore.readwrite, tr1, flag)
# check that an exception other than ExitNow in the object handler
# method causes the handle_error method to get called
tr2 = crashingdummy()
self.assertEqual(tr2.error_handled, False)
asyncore.readwrite(tr2, flag)
self.assertEqual(tr2.error_handled, True)
def test_closeall(self):
self.closeall_check(False)
def test_closeall_default(self):
self.closeall_check(True)
def closeall_check(self, usedefault):
# Check that close_all() closes everything in a given map
l = []
testmap = {}
for i in range(10):
c = dummychannel()
l.append(c)
self.assertEqual(c.socket.closed, False)
testmap[i] = c
if usedefault:
socketmap = asyncore.socket_map
try:
asyncore.socket_map = testmap
asyncore.close_all()
finally:
testmap, asyncore.socket_map = asyncore.socket_map, socketmap
else:
asyncore.close_all(testmap)
self.assertEqual(len(testmap), 0)
for c in l:
self.assertEqual(c.socket.closed, True)
def test_compact_traceback(self):
try:
raise Exception("I don't like spam!")
except:
real_t, real_v, real_tb = sys.exc_info()
r = asyncore.compact_traceback()
else:
self.fail("Expected exception")
(f, function, line), t, v, info = r
self.assertEqual(os.path.split(f)[-1], 'test_asyncore.py')
self.assertEqual(function, 'test_compact_traceback')
self.assertEqual(t, real_t)
self.assertEqual(v, real_v)
self.assertEqual(info, '[%s|%s|%s]' % (f, function, line))
class DispatcherTests(unittest.TestCase):
def setUp(self):
pass
def tearDown(self):
asyncore.close_all()
def test_basic(self):
d = asyncore.dispatcher()
self.assertEqual(d.readable(), True)
self.assertEqual(d.writable(), True)
def test_repr(self):
d = asyncore.dispatcher()
self.assertEqual(repr(d), '<asyncore.dispatcher at %#x>' % id(d))
def test_log(self):
d = asyncore.dispatcher()
# capture output of dispatcher.log() (to stderr)
fp = StringIO()
stderr = sys.stderr
l1 = "Lovely spam! Wonderful spam!"
l2 = "I don't like spam!"
try:
sys.stderr = fp
d.log(l1)
d.log(l2)
finally:
sys.stderr = stderr
lines = fp.getvalue().splitlines()
self.assertEqual(lines, ['log: %s' % l1, 'log: %s' % l2])
def test_log_info(self):
d = asyncore.dispatcher()
# capture output of dispatcher.log_info() (to stdout via print)
fp = StringIO()
stdout = sys.stdout
l1 = "Have you got anything without spam?"
l2 = "Why can't she have egg bacon spam and sausage?"
l3 = "THAT'S got spam in it!"
try:
sys.stdout = fp
d.log_info(l1, 'EGGS')
d.log_info(l2)
d.log_info(l3, 'SPAM')
finally:
sys.stdout = stdout
lines = fp.getvalue().splitlines()
expected = ['EGGS: %s' % l1, 'info: %s' % l2, 'SPAM: %s' % l3]
self.assertEqual(lines, expected)
def test_unhandled(self):
d = asyncore.dispatcher()
d.ignore_log_types = ()
# capture output of dispatcher.log_info() (to stdout via print)
fp = StringIO()
stdout = sys.stdout
try:
sys.stdout = fp
d.handle_expt()
d.handle_read()
d.handle_write()
d.handle_connect()
d.handle_accept()
finally:
sys.stdout = stdout
lines = fp.getvalue().splitlines()
expected = ['warning: unhandled incoming priority event',
'warning: unhandled read event',
'warning: unhandled write event',
'warning: unhandled connect event',
'warning: unhandled accept event']
self.assertEqual(lines, expected)
def test_issue_8594(self):
# XXX - this test is supposed to be removed in next major Python
# version
d = asyncore.dispatcher(socket.socket())
# make sure the error message no longer refers to the socket
# object but the dispatcher instance instead
self.assertRaisesRegexp(AttributeError, 'dispatcher instance',
getattr, d, 'foo')
# cheap inheritance with the underlying socket is supposed
# to still work but a DeprecationWarning is expected
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
family = d.family
self.assertEqual(family, socket.AF_INET)
self.assertEqual(len(w), 1)
self.assertTrue(issubclass(w[0].category, DeprecationWarning))
def test_strerror(self):
# refers to bug #8573
err = asyncore._strerror(errno.EPERM)
if hasattr(os, 'strerror'):
self.assertEqual(err, os.strerror(errno.EPERM))
err = asyncore._strerror(-1)
self.assertTrue(err != "")
class dispatcherwithsend_noread(asyncore.dispatcher_with_send):
def readable(self):
return False
def handle_connect(self):
pass
class DispatcherWithSendTests(unittest.TestCase):
usepoll = False
def setUp(self):
pass
def tearDown(self):
asyncore.close_all()
@unittest.skipUnless(threading, 'Threading required for this test.')
@test_support.reap_threads
def test_send(self):
evt = threading.Event()
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(3)
port = test_support.bind_port(sock)
cap = StringIO()
args = (evt, cap, sock)
t = threading.Thread(target=capture_server, args=args)
t.start()
try:
# wait a little longer for the server to initialize (it sometimes
# refuses connections on slow machines without this wait)
time.sleep(0.2)
data = "Suppose there isn't a 16-ton weight?"
d = dispatcherwithsend_noread()
d.create_socket(socket.AF_INET, socket.SOCK_STREAM)
d.connect((HOST, port))
# give time for socket to connect
time.sleep(0.1)
d.send(data)
d.send(data)
d.send('\n')
n = 1000
while d.out_buffer and n > 0:
asyncore.poll()
n -= 1
evt.wait()
self.assertEqual(cap.getvalue(), data*2)
finally:
t.join()
class DispatcherWithSendTests_UsePoll(DispatcherWithSendTests):
usepoll = True
@unittest.skipUnless(hasattr(asyncore, 'file_wrapper'),
'asyncore.file_wrapper required')
class FileWrapperTest(unittest.TestCase):
def setUp(self):
self.d = "It's not dead, it's sleeping!"
with file(TESTFN, 'w') as h:
h.write(self.d)
def tearDown(self):
unlink(TESTFN)
def test_recv(self):
fd = os.open(TESTFN, os.O_RDONLY)
w = asyncore.file_wrapper(fd)
os.close(fd)
self.assertNotEqual(w.fd, fd)
self.assertNotEqual(w.fileno(), fd)
self.assertEqual(w.recv(13), "It's not dead")
self.assertEqual(w.read(6), ", it's")
w.close()
self.assertRaises(OSError, w.read, 1)
def test_send(self):
d1 = "Come again?"
d2 = "I want to buy some cheese."
fd = os.open(TESTFN, os.O_WRONLY | os.O_APPEND)
w = asyncore.file_wrapper(fd)
os.close(fd)
w.write(d1)
w.send(d2)
w.close()
self.assertEqual(file(TESTFN).read(), self.d + d1 + d2)
@unittest.skipUnless(hasattr(asyncore, 'file_dispatcher'),
'asyncore.file_dispatcher required')
def test_dispatcher(self):
fd = os.open(TESTFN, os.O_RDONLY)
data = []
class FileDispatcher(asyncore.file_dispatcher):
def handle_read(self):
data.append(self.recv(29))
s = FileDispatcher(fd)
os.close(fd)
asyncore.loop(timeout=0.01, use_poll=True, count=2)
self.assertEqual(b"".join(data), self.d)
class BaseTestHandler(asyncore.dispatcher):
def __init__(self, sock=None):
asyncore.dispatcher.__init__(self, sock)
self.flag = False
def handle_accept(self):
raise Exception("handle_accept not supposed to be called")
def handle_connect(self):
raise Exception("handle_connect not supposed to be called")
def handle_expt(self):
raise Exception("handle_expt not supposed to be called")
def handle_close(self):
raise Exception("handle_close not supposed to be called")
def handle_error(self):
raise
class TCPServer(asyncore.dispatcher):
"""A server which listens on an address and dispatches the
connection to a handler.
"""
def __init__(self, handler=BaseTestHandler, host=HOST, port=0):
asyncore.dispatcher.__init__(self)
self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
self.set_reuse_addr()
self.bind((host, port))
self.listen(5)
self.handler = handler
@property
def address(self):
return self.socket.getsockname()[:2]
def handle_accept(self):
sock, addr = self.accept()
self.handler(sock)
def handle_error(self):
raise
class BaseClient(BaseTestHandler):
def __init__(self, address):
BaseTestHandler.__init__(self)
self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
self.connect(address)
def handle_connect(self):
pass
class BaseTestAPI(unittest.TestCase):
def tearDown(self):
asyncore.close_all()
def loop_waiting_for_flag(self, instance, timeout=5):
timeout = float(timeout) / 100
count = 100
while asyncore.socket_map and count > 0:
asyncore.loop(timeout=0.01, count=1, use_poll=self.use_poll)
if instance.flag:
return
count -= 1
time.sleep(timeout)
self.fail("flag not set")
def test_handle_connect(self):
# make sure handle_connect is called on connect()
class TestClient(BaseClient):
def handle_connect(self):
self.flag = True
server = TCPServer()
client = TestClient(server.address)
self.loop_waiting_for_flag(client)
def test_handle_accept(self):
# make sure handle_accept() is called when a client connects
class TestListener(BaseTestHandler):
def __init__(self):
BaseTestHandler.__init__(self)
self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
self.bind((HOST, 0))
self.listen(5)
self.address = self.socket.getsockname()[:2]
def handle_accept(self):
self.flag = True
server = TestListener()
client = BaseClient(server.address)
self.loop_waiting_for_flag(server)
def test_handle_read(self):
# make sure handle_read is called on data received
class TestClient(BaseClient):
def handle_read(self):
self.flag = True
class TestHandler(BaseTestHandler):
def __init__(self, conn):
BaseTestHandler.__init__(self, conn)
self.send('x' * 1024)
server = TCPServer(TestHandler)
client = TestClient(server.address)
self.loop_waiting_for_flag(client)
def test_handle_write(self):
# make sure handle_write is called
class TestClient(BaseClient):
def handle_write(self):
self.flag = True
server = TCPServer()
client = TestClient(server.address)
self.loop_waiting_for_flag(client)
def test_handle_close(self):
# make sure handle_close is called when the other end closes
# the connection
class TestClient(BaseClient):
def handle_read(self):
# in order to make handle_close be called we are supposed
# to make at least one recv() call
self.recv(1024)
def handle_close(self):
self.flag = True
self.close()
class TestHandler(BaseTestHandler):
def __init__(self, conn):
BaseTestHandler.__init__(self, conn)
self.close()
server = TCPServer(TestHandler)
client = TestClient(server.address)
self.loop_waiting_for_flag(client)
@unittest.skipIf(sys.platform.startswith("sunos"),
"OOB support is broken on Solaris")
def test_handle_expt(self):
# Make sure handle_expt is called on OOB data received.
# Note: this might fail on some platforms as OOB data is
# tenuously supported and rarely used.
class TestClient(BaseClient):
def handle_expt(self):
self.flag = True
class TestHandler(BaseTestHandler):
def __init__(self, conn):
BaseTestHandler.__init__(self, conn)
self.socket.send(chr(244), socket.MSG_OOB)
server = TCPServer(TestHandler)
client = TestClient(server.address)
self.loop_waiting_for_flag(client)
def test_handle_error(self):
class TestClient(BaseClient):
def handle_write(self):
1.0 / 0
def handle_error(self):
self.flag = True
try:
raise
except ZeroDivisionError:
pass
else:
raise Exception("exception not raised")
server = TCPServer()
client = TestClient(server.address)
self.loop_waiting_for_flag(client)
def test_connection_attributes(self):
server = TCPServer()
client = BaseClient(server.address)
# we start disconnected
self.assertFalse(server.connected)
self.assertTrue(server.accepting)
# this can't be taken for granted across all platforms
#self.assertFalse(client.connected)
self.assertFalse(client.accepting)
# execute some loops so that client connects to server
asyncore.loop(timeout=0.01, use_poll=self.use_poll, count=100)
self.assertFalse(server.connected)
self.assertTrue(server.accepting)
self.assertTrue(client.connected)
self.assertFalse(client.accepting)
# disconnect the client
client.close()
self.assertFalse(server.connected)
self.assertTrue(server.accepting)
self.assertFalse(client.connected)
self.assertFalse(client.accepting)
# stop serving
server.close()
self.assertFalse(server.connected)
self.assertFalse(server.accepting)
def test_create_socket(self):
s = asyncore.dispatcher()
s.create_socket(socket.AF_INET, socket.SOCK_STREAM)
self.assertEqual(s.socket.family, socket.AF_INET)
self.assertEqual(s.socket.type, socket.SOCK_STREAM)
def test_bind(self):
s1 = asyncore.dispatcher()
s1.create_socket(socket.AF_INET, socket.SOCK_STREAM)
s1.bind((HOST, 0))
s1.listen(5)
port = s1.socket.getsockname()[1]
s2 = asyncore.dispatcher()
s2.create_socket(socket.AF_INET, socket.SOCK_STREAM)
# EADDRINUSE indicates the socket was correctly bound
self.assertRaises(socket.error, s2.bind, (HOST, port))
def test_set_reuse_addr(self):
sock = socket.socket()
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
except socket.error:
unittest.skip("SO_REUSEADDR not supported on this platform")
else:
# if SO_REUSEADDR succeeded for sock we expect asyncore
# to do the same
s = asyncore.dispatcher(socket.socket())
self.assertFalse(s.socket.getsockopt(socket.SOL_SOCKET,
socket.SO_REUSEADDR))
s.create_socket(socket.AF_INET, socket.SOCK_STREAM)
s.set_reuse_addr()
self.assertTrue(s.socket.getsockopt(socket.SOL_SOCKET,
socket.SO_REUSEADDR))
finally:
sock.close()
class TestAPI_UseSelect(BaseTestAPI):
use_poll = False
@unittest.skipUnless(hasattr(select, 'poll'), 'select.poll required')
class TestAPI_UsePoll(BaseTestAPI):
use_poll = True
def test_main():
tests = [HelperFunctionTests, DispatcherTests, DispatcherWithSendTests,
DispatcherWithSendTests_UsePoll, TestAPI_UseSelect,
TestAPI_UsePoll, FileWrapperTest]
run_unittest(*tests)
if __name__ == "__main__":
test_main()
# -*- coding: latin-1 -*-
"""Tests for cookielib.py."""
import cookielib
import os
import re
import time
from unittest import TestCase
from test import test_support
class DateTimeTests(TestCase):
def test_time2isoz(self):
from cookielib import time2isoz
base = 1019227000
day = 24*3600
self.assertEqual(time2isoz(base), "2002-04-19 14:36:40Z")
self.assertEqual(time2isoz(base+day), "2002-04-20 14:36:40Z")
self.assertEqual(time2isoz(base+2*day), "2002-04-21 14:36:40Z")
self.assertEqual(time2isoz(base+3*day), "2002-04-22 14:36:40Z")
az = time2isoz()
bz = time2isoz(500000)
for text in (az, bz):
self.assertTrue(re.search(r"^\d{4}-\d\d-\d\d \d\d:\d\d:\d\dZ$", text),
"bad time2isoz format: %s %s" % (az, bz))
def test_http2time(self):
from cookielib import http2time
def parse_date(text):
return time.gmtime(http2time(text))[:6]
self.assertEqual(parse_date("01 Jan 2001"), (2001, 1, 1, 0, 0, 0.0))
# this test will break around year 2070
self.assertEqual(parse_date("03-Feb-20"), (2020, 2, 3, 0, 0, 0.0))
# this test will break around year 2048
self.assertEqual(parse_date("03-Feb-98"), (1998, 2, 3, 0, 0, 0.0))
def test_http2time_formats(self):
from cookielib import http2time, time2isoz
# test http2time for supported dates. Test cases with 2 digit year
# will probably break in year 2044.
tests = [
'Thu, 03 Feb 1994 00:00:00 GMT', # proposed new HTTP format
'Thursday, 03-Feb-94 00:00:00 GMT', # old rfc850 HTTP format
'Thursday, 03-Feb-1994 00:00:00 GMT', # broken rfc850 HTTP format
'03 Feb 1994 00:00:00 GMT', # HTTP format (no weekday)
'03-Feb-94 00:00:00 GMT', # old rfc850 (no weekday)
'03-Feb-1994 00:00:00 GMT', # broken rfc850 (no weekday)
'03-Feb-1994 00:00 GMT', # broken rfc850 (no weekday, no seconds)
'03-Feb-1994 00:00', # broken rfc850 (no weekday, no seconds, no tz)
'03-Feb-94', # old rfc850 HTTP format (no weekday, no time)
'03-Feb-1994', # broken rfc850 HTTP format (no weekday, no time)
'03 Feb 1994', # proposed new HTTP format (no weekday, no time)
# A few tests with extra space at various places
' 03 Feb 1994 0:00 ',
' 03-Feb-1994 ',
]
test_t = 760233600 # assume broken POSIX counting of seconds
result = time2isoz(test_t)
expected = "1994-02-03 00:00:00Z"
self.assertEqual(result, expected,
"%s => '%s' (%s)" % (test_t, result, expected))
for s in tests:
t = http2time(s)
t2 = http2time(s.lower())
t3 = http2time(s.upper())
self.assertTrue(t == t2 == t3 == test_t,
"'%s' => %s, %s, %s (%s)" % (s, t, t2, t3, test_t))
def test_http2time_garbage(self):
from cookielib import http2time
for test in [
'',
'Garbage',
'Mandag 16. September 1996',
'01-00-1980',
'01-13-1980',
'00-01-1980',
'32-01-1980',
'01-01-1980 25:00:00',
'01-01-1980 00:61:00',
'01-01-1980 00:00:62',
]:
self.assertTrue(http2time(test) is None,
"http2time(%s) is not None\n"
"http2time(test) %s" % (test, http2time(test))
)
class HeaderTests(TestCase):
def test_parse_ns_headers_expires(self):
from cookielib import parse_ns_headers
# quotes should be stripped
expected = [[('foo', 'bar'), ('expires', 2209069412L), ('version', '0')]]
for hdr in [
'foo=bar; expires=01 Jan 2040 22:23:32 GMT',
'foo=bar; expires="01 Jan 2040 22:23:32 GMT"',
]:
self.assertEqual(parse_ns_headers([hdr]), expected)
def test_parse_ns_headers_version(self):
from cookielib import parse_ns_headers
# quotes should be stripped
expected = [[('foo', 'bar'), ('version', '1')]]
for hdr in [
'foo=bar; version="1"',
'foo=bar; Version="1"',
]:
self.assertEqual(parse_ns_headers([hdr]), expected)
def test_parse_ns_headers_special_names(self):
# names such as 'expires' are not special in first name=value pair
# of Set-Cookie: header
from cookielib import parse_ns_headers
# Cookie with name 'expires'
hdr = 'expires=01 Jan 2040 22:23:32 GMT'
expected = [[("expires", "01 Jan 2040 22:23:32 GMT"), ("version", "0")]]
self.assertEqual(parse_ns_headers([hdr]), expected)
def test_join_header_words(self):
from cookielib import join_header_words
joined = join_header_words([[("foo", None), ("bar", "baz")]])
self.assertEqual(joined, "foo; bar=baz")
self.assertEqual(join_header_words([[]]), "")
def test_split_header_words(self):
from cookielib import split_header_words
tests = [
("foo", [[("foo", None)]]),
("foo=bar", [[("foo", "bar")]]),
(" foo ", [[("foo", None)]]),
(" foo= ", [[("foo", "")]]),
(" foo=", [[("foo", "")]]),
(" foo= ; ", [[("foo", "")]]),
(" foo= ; bar= baz ", [[("foo", ""), ("bar", "baz")]]),
("foo=bar bar=baz", [[("foo", "bar"), ("bar", "baz")]]),
# doesn't really matter if this next fails, but it works ATM
("foo= bar=baz", [[("foo", "bar=baz")]]),
("foo=bar;bar=baz", [[("foo", "bar"), ("bar", "baz")]]),
('foo bar baz', [[("foo", None), ("bar", None), ("baz", None)]]),
("a, b, c", [[("a", None)], [("b", None)], [("c", None)]]),
(r'foo; bar=baz, spam=, foo="\,\;\"", bar= ',
[[("foo", None), ("bar", "baz")],
[("spam", "")], [("foo", ',;"')], [("bar", "")]]),
]
for arg, expect in tests:
try:
result = split_header_words([arg])
except:
import traceback, StringIO
f = StringIO.StringIO()
traceback.print_exc(None, f)
result = "(error -- traceback follows)\n\n%s" % f.getvalue()
self.assertEqual(result, expect, """
When parsing: '%s'
Expected: '%s'
Got: '%s'
""" % (arg, expect, result))
def test_roundtrip(self):
from cookielib import split_header_words, join_header_words
tests = [
("foo", "foo"),
("foo=bar", "foo=bar"),
(" foo ", "foo"),
("foo=", 'foo=""'),
("foo=bar bar=baz", "foo=bar; bar=baz"),
("foo=bar;bar=baz", "foo=bar; bar=baz"),
('foo bar baz', "foo; bar; baz"),
(r'foo="\"" bar="\\"', r'foo="\""; bar="\\"'),
('foo,,,bar', 'foo, bar'),
('foo=bar,bar=baz', 'foo=bar, bar=baz'),
('text/html; charset=iso-8859-1',
'text/html; charset="iso-8859-1"'),
('foo="bar"; port="80,81"; discard, bar=baz',
'foo=bar; port="80,81"; discard, bar=baz'),
(r'Basic realm="\"foo\\\\bar\""',
r'Basic; realm="\"foo\\\\bar\""')
]
for arg, expect in tests:
input = split_header_words([arg])
res = join_header_words(input)
self.assertEqual(res, expect, """
When parsing: '%s'
Expected: '%s'
Got: '%s'
Input was: '%s'
""" % (arg, expect, res, input))
class FakeResponse:
def __init__(self, headers=[], url=None):
"""
headers: list of RFC822-style 'Key: value' strings
"""
import mimetools, StringIO
f = StringIO.StringIO("\n".join(headers))
self._headers = mimetools.Message(f)
self._url = url
def info(self): return self._headers
def interact_2965(cookiejar, url, *set_cookie_hdrs):
return _interact(cookiejar, url, set_cookie_hdrs, "Set-Cookie2")
def interact_netscape(cookiejar, url, *set_cookie_hdrs):
return _interact(cookiejar, url, set_cookie_hdrs, "Set-Cookie")
def _interact(cookiejar, url, set_cookie_hdrs, hdr_name):
"""Perform a single request / response cycle, returning Cookie: header."""
from urllib2 import Request
req = Request(url)
cookiejar.add_cookie_header(req)
cookie_hdr = req.get_header("Cookie", "")
headers = []
for hdr in set_cookie_hdrs:
headers.append("%s: %s" % (hdr_name, hdr))
res = FakeResponse(headers, url)
cookiejar.extract_cookies(res, req)
return cookie_hdr
class FileCookieJarTests(TestCase):
def test_lwp_valueless_cookie(self):
# cookies with no value should be saved and loaded consistently
from cookielib import LWPCookieJar
filename = test_support.TESTFN
c = LWPCookieJar()
interact_netscape(c, "http://www.acme.com/", 'boo')
self.assertEqual(c._cookies["www.acme.com"]["/"]["boo"].value, None)
try:
c.save(filename, ignore_discard=True)
c = LWPCookieJar()
c.load(filename, ignore_discard=True)
finally:
try: os.unlink(filename)
except OSError: pass
self.assertEqual(c._cookies["www.acme.com"]["/"]["boo"].value, None)
def test_bad_magic(self):
from cookielib import LWPCookieJar, MozillaCookieJar, LoadError
# IOErrors (eg. file doesn't exist) are allowed to propagate
filename = test_support.TESTFN
for cookiejar_class in LWPCookieJar, MozillaCookieJar:
c = cookiejar_class()
try:
c.load(filename="for this test to work, a file with this "
"filename should not exist")
except IOError, exc:
# exactly IOError, not LoadError
self.assertEqual(exc.__class__, IOError)
else:
self.fail("expected IOError for invalid filename")
# Invalid contents of cookies file (eg. bad magic string)
# causes a LoadError.
try:
f = open(filename, "w")
f.write("oops\n")
for cookiejar_class in LWPCookieJar, MozillaCookieJar:
c = cookiejar_class()
self.assertRaises(LoadError, c.load, filename)
finally:
try: os.unlink(filename)
except OSError: pass
class CookieTests(TestCase):
# XXX
# Get rid of string comparisons where not actually testing str / repr.
# .clear() etc.
# IP addresses like 50 (single number, no dot) and domain-matching
# functions (and is_HDN)? See draft RFC 2965 errata.
# Strictness switches
# is_third_party()
# unverifiability / third-party blocking
# Netscape cookies work the same as RFC 2965 with regard to port.
# Set-Cookie with negative max age.
# If turn RFC 2965 handling off, Set-Cookie2 cookies should not clobber
# Set-Cookie cookies.
# Cookie2 should be sent if *any* cookies are not V1 (ie. V0 OR V2 etc.).
# Cookies (V1 and V0) with no expiry date should be set to be discarded.
# RFC 2965 Quoting:
# Should accept unquoted cookie-attribute values? check errata draft.
# Which are required on the way in and out?
# Should always return quoted cookie-attribute values?
# Proper testing of when RFC 2965 clobbers Netscape (waiting for errata).
# Path-match on return (same for V0 and V1).
# RFC 2965 acceptance and returning rules
# Set-Cookie2 without version attribute is rejected.
# Netscape peculiarities list from Ronald Tschalar.
# The first two still need tests, the rest are covered.
## - Quoting: only quotes around the expires value are recognized as such
## (and yes, some folks quote the expires value); quotes around any other
## value are treated as part of the value.
## - White space: white space around names and values is ignored
## - Default path: if no path parameter is given, the path defaults to the
## path in the request-uri up to, but not including, the last '/'. Note
## that this is entirely different from what the spec says.
## - Commas and other delimiters: Netscape just parses until the next ';'.
## This means it will allow commas etc inside values (and yes, both
## commas and equals are commonly appear in the cookie value). This also
## means that if you fold multiple Set-Cookie header fields into one,
## comma-separated list, it'll be a headache to parse (at least my head
## starts hurting everytime I think of that code).
## - Expires: You'll get all sorts of date formats in the expires,
## including emtpy expires attributes ("expires="). Be as flexible as you
## can, and certainly don't expect the weekday to be there; if you can't
## parse it, just ignore it and pretend it's a session cookie.
## - Domain-matching: Netscape uses the 2-dot rule for _all_ domains, not
## just the 7 special TLD's listed in their spec. And folks rely on
## that...
def test_domain_return_ok(self):
# test optimization: .domain_return_ok() should filter out most
# domains in the CookieJar before we try to access them (because that
# may require disk access -- in particular, with MSIECookieJar)
# This is only a rough check for performance reasons, so it's not too
# critical as long as it's sufficiently liberal.
import cookielib, urllib2
pol = cookielib.DefaultCookiePolicy()
for url, domain, ok in [
("http://foo.bar.com/", "blah.com", False),
("http://foo.bar.com/", "rhubarb.blah.com", False),
("http://foo.bar.com/", "rhubarb.foo.bar.com", False),
("http://foo.bar.com/", ".foo.bar.com", True),
("http://foo.bar.com/", "foo.bar.com", True),
("http://foo.bar.com/", ".bar.com", True),
("http://foo.bar.com/", "com", True),
("http://foo.com/", "rhubarb.foo.com", False),
("http://foo.com/", ".foo.com", True),
("http://foo.com/", "foo.com", True),
("http://foo.com/", "com", True),
("http://foo/", "rhubarb.foo", False),
("http://foo/", ".foo", True),
("http://foo/", "foo", True),
("http://foo/", "foo.local", True),
("http://foo/", ".local", True),
]:
request = urllib2.Request(url)
r = pol.domain_return_ok(domain, request)
if ok: self.assertTrue(r)
else: self.assertTrue(not r)
def test_missing_value(self):
from cookielib import MozillaCookieJar, lwp_cookie_str
# missing = sign in Cookie: header is regarded by Mozilla as a missing
# name, and by cookielib as a missing value
filename = test_support.TESTFN
c = MozillaCookieJar(filename)
interact_netscape(c, "http://www.acme.com/", 'eggs')
interact_netscape(c, "http://www.acme.com/", '"spam"; path=/foo/')
cookie = c._cookies["www.acme.com"]["/"]["eggs"]
self.assertTrue(cookie.value is None)
self.assertEqual(cookie.name, "eggs")
cookie = c._cookies["www.acme.com"]['/foo/']['"spam"']
self.assertTrue(cookie.value is None)
self.assertEqual(cookie.name, '"spam"')
self.assertEqual(lwp_cookie_str(cookie), (
r'"spam"; path="/foo/"; domain="www.acme.com"; '
'path_spec; discard; version=0'))
old_str = repr(c)
c.save(ignore_expires=True, ignore_discard=True)
try:
c = MozillaCookieJar(filename)
c.revert(ignore_expires=True, ignore_discard=True)
finally:
os.unlink(c.filename)
# cookies unchanged apart from lost info re. whether path was specified
self.assertEqual(
repr(c),
re.sub("path_specified=%s" % True, "path_specified=%s" % False,
old_str)
)
self.assertEqual(interact_netscape(c, "http://www.acme.com/foo/"),
'"spam"; eggs')
def test_rfc2109_handling(self):
# RFC 2109 cookies are handled as RFC 2965 or Netscape cookies,
# dependent on policy settings
from cookielib import CookieJar, DefaultCookiePolicy
for rfc2109_as_netscape, rfc2965, version in [
# default according to rfc2965 if not explicitly specified
(None, False, 0),
(None, True, 1),
# explicit rfc2109_as_netscape
(False, False, None), # version None here means no cookie stored
(False, True, 1),
(True, False, 0),
(True, True, 0),
]:
policy = DefaultCookiePolicy(
rfc2109_as_netscape=rfc2109_as_netscape,
rfc2965=rfc2965)
c = CookieJar(policy)
interact_netscape(c, "http://www.example.com/", "ni=ni; Version=1")
try:
cookie = c._cookies["www.example.com"]["/"]["ni"]
except KeyError:
self.assertTrue(version is None) # didn't expect a stored cookie
else:
self.assertEqual(cookie.version, version)
# 2965 cookies are unaffected
interact_2965(c, "http://www.example.com/",
"foo=bar; Version=1")
if rfc2965:
cookie2965 = c._cookies["www.example.com"]["/"]["foo"]
self.assertEqual(cookie2965.version, 1)
def test_ns_parser(self):
from cookielib import CookieJar, DEFAULT_HTTP_PORT
c = CookieJar()
interact_netscape(c, "http://www.acme.com/",
'spam=eggs; DoMain=.acme.com; port; blArgh="feep"')
interact_netscape(c, "http://www.acme.com/", 'ni=ni; port=80,8080')
interact_netscape(c, "http://www.acme.com:80/", 'nini=ni')
interact_netscape(c, "http://www.acme.com:80/", 'foo=bar; expires=')
interact_netscape(c, "http://www.acme.com:80/", 'spam=eggs; '
'expires="Foo Bar 25 33:22:11 3022"')
cookie = c._cookies[".acme.com"]["/"]["spam"]
self.assertEqual(cookie.domain, ".acme.com")
self.assertTrue(cookie.domain_specified)
self.assertEqual(cookie.port, DEFAULT_HTTP_PORT)
self.assertTrue(not cookie.port_specified)
# case is preserved
self.assertTrue(cookie.has_nonstandard_attr("blArgh") and
not cookie.has_nonstandard_attr("blargh"))
cookie = c._cookies["www.acme.com"]["/"]["ni"]
self.assertEqual(cookie.domain, "www.acme.com")
self.assertTrue(not cookie.domain_specified)
self.assertEqual(cookie.port, "80,8080")
self.assertTrue(cookie.port_specified)
cookie = c._cookies["www.acme.com"]["/"]["nini"]
self.assertTrue(cookie.port is None)
self.assertTrue(not cookie.port_specified)
# invalid expires should not cause cookie to be dropped
foo = c._cookies["www.acme.com"]["/"]["foo"]
spam = c._cookies["www.acme.com"]["/"]["foo"]
self.assertTrue(foo.expires is None)
self.assertTrue(spam.expires is None)
def test_ns_parser_special_names(self):
# names such as 'expires' are not special in first name=value pair
# of Set-Cookie: header
from cookielib import CookieJar
c = CookieJar()
interact_netscape(c, "http://www.acme.com/", 'expires=eggs')
interact_netscape(c, "http://www.acme.com/", 'version=eggs; spam=eggs')
cookies = c._cookies["www.acme.com"]["/"]
self.assertTrue('expires' in cookies)
self.assertTrue('version' in cookies)
def test_expires(self):
from cookielib import time2netscape, CookieJar
# if expires is in future, keep cookie...
c = CookieJar()
future = time2netscape(time.time()+3600)
interact_netscape(c, "http://www.acme.com/", 'spam="bar"; expires=%s' %
future)
self.assertEqual(len(c), 1)
now = time2netscape(time.time()-1)
# ... and if in past or present, discard it
interact_netscape(c, "http://www.acme.com/", 'foo="eggs"; expires=%s' %
now)
h = interact_netscape(c, "http://www.acme.com/")
self.assertEqual(len(c), 1)
self.assertTrue('spam="bar"' in h and "foo" not in h)
# max-age takes precedence over expires, and zero max-age is request to
# delete both new cookie and any old matching cookie
interact_netscape(c, "http://www.acme.com/", 'eggs="bar"; expires=%s' %
future)
interact_netscape(c, "http://www.acme.com/", 'bar="bar"; expires=%s' %
future)
self.assertEqual(len(c), 3)
interact_netscape(c, "http://www.acme.com/", 'eggs="bar"; '
'expires=%s; max-age=0' % future)
interact_netscape(c, "http://www.acme.com/", 'bar="bar"; '
'max-age=0; expires=%s' % future)
h = interact_netscape(c, "http://www.acme.com/")
self.assertEqual(len(c), 1)
# test expiry at end of session for cookies with no expires attribute
interact_netscape(c, "http://www.rhubarb.net/", 'whum="fizz"')
self.assertEqual(len(c), 2)
c.clear_session_cookies()
self.assertEqual(len(c), 1)
self.assertIn('spam="bar"', h)
# XXX RFC 2965 expiry rules (some apply to V0 too)
def test_default_path(self):
from cookielib import CookieJar, DefaultCookiePolicy
# RFC 2965
pol = DefaultCookiePolicy(rfc2965=True)
c = CookieJar(pol)
interact_2965(c, "http://www.acme.com/", 'spam="bar"; Version="1"')
self.assertIn("/", c._cookies["www.acme.com"])
c = CookieJar(pol)
interact_2965(c, "http://www.acme.com/blah", 'eggs="bar"; Version="1"')
self.assertIn("/", c._cookies["www.acme.com"])
c = CookieJar(pol)
interact_2965(c, "http://www.acme.com/blah/rhubarb",
'eggs="bar"; Version="1"')
self.assertIn("/blah/", c._cookies["www.acme.com"])
c = CookieJar(pol)
interact_2965(c, "http://www.acme.com/blah/rhubarb/",
'eggs="bar"; Version="1"')
self.assertIn("/blah/rhubarb/", c._cookies["www.acme.com"])
# Netscape
c = CookieJar()
interact_netscape(c, "http://www.acme.com/", 'spam="bar"')
self.assertIn("/", c._cookies["www.acme.com"])
c = CookieJar()
interact_netscape(c, "http://www.acme.com/blah", 'eggs="bar"')
self.assertIn("/", c._cookies["www.acme.com"])
c = CookieJar()
interact_netscape(c, "http://www.acme.com/blah/rhubarb", 'eggs="bar"')
self.assertIn("/blah", c._cookies["www.acme.com"])
c = CookieJar()
interact_netscape(c, "http://www.acme.com/blah/rhubarb/", 'eggs="bar"')
self.assertIn("/blah/rhubarb", c._cookies["www.acme.com"])
def test_default_path_with_query(self):
cj = cookielib.CookieJar()
uri = "http://example.com/?spam/eggs"
value = 'eggs="bar"'
interact_netscape(cj, uri, value)
# default path does not include query, so is "/", not "/?spam"
self.assertIn("/", cj._cookies["example.com"])
# cookie is sent back to the same URI
self.assertEqual(interact_netscape(cj, uri), value)
def test_escape_path(self):
from cookielib import escape_path
cases = [
# quoted safe
("/foo%2f/bar", "/foo%2F/bar"),
("/foo%2F/bar", "/foo%2F/bar"),
# quoted %
("/foo%%/bar", "/foo%%/bar"),
# quoted unsafe
("/fo%19o/bar", "/fo%19o/bar"),
("/fo%7do/bar", "/fo%7Do/bar"),
# unquoted safe
("/foo/bar&", "/foo/bar&"),
("/foo//bar", "/foo//bar"),
("\176/foo/bar", "\176/foo/bar"),
# unquoted unsafe
("/foo\031/bar", "/foo%19/bar"),
("/\175foo/bar", "/%7Dfoo/bar"),
# unicode
(u"/foo/bar\uabcd", "/foo/bar%EA%AF%8D"), # UTF-8 encoded
]
for arg, result in cases:
self.assertEqual(escape_path(arg), result)
def test_request_path(self):
from urllib2 import Request
from cookielib import request_path
# with parameters
req = Request("http://www.example.com/rheum/rhaponticum;"
"foo=bar;sing=song?apples=pears&spam=eggs#ni")
self.assertEqual(request_path(req),
"/rheum/rhaponticum;foo=bar;sing=song")
# without parameters
req = Request("http://www.example.com/rheum/rhaponticum?"
"apples=pears&spam=eggs#ni")
self.assertEqual(request_path(req), "/rheum/rhaponticum")
# missing final slash
req = Request("http://www.example.com")
self.assertEqual(request_path(req), "/")
def test_request_port(self):
from urllib2 import Request
from cookielib import request_port, DEFAULT_HTTP_PORT
req = Request("http://www.acme.com:1234/",
headers={"Host": "www.acme.com:4321"})
self.assertEqual(request_port(req), "1234")
req = Request("http://www.acme.com/",
headers={"Host": "www.acme.com:4321"})
self.assertEqual(request_port(req), DEFAULT_HTTP_PORT)
def test_request_host(self):
from urllib2 import Request
from cookielib import request_host
# this request is illegal (RFC2616, 14.2.3)
req = Request("http://1.1.1.1/",
headers={"Host": "www.acme.com:80"})
# libwww-perl wants this response, but that seems wrong (RFC 2616,
# section 5.2, point 1., and RFC 2965 section 1, paragraph 3)
#self.assertEqual(request_host(req), "www.acme.com")
self.assertEqual(request_host(req), "1.1.1.1")
req = Request("http://www.acme.com/",
headers={"Host": "irrelevant.com"})
self.assertEqual(request_host(req), "www.acme.com")
# not actually sure this one is valid Request object, so maybe should
# remove test for no host in url in request_host function?
req = Request("/resource.html",
headers={"Host": "www.acme.com"})
self.assertEqual(request_host(req), "www.acme.com")
# port shouldn't be in request-host
req = Request("http://www.acme.com:2345/resource.html",
headers={"Host": "www.acme.com:5432"})
self.assertEqual(request_host(req), "www.acme.com")
def test_is_HDN(self):
from cookielib import is_HDN
self.assertTrue(is_HDN("foo.bar.com"))
self.assertTrue(is_HDN("1foo2.3bar4.5com"))
self.assertTrue(not is_HDN("192.168.1.1"))
self.assertTrue(not is_HDN(""))
self.assertTrue(not is_HDN("."))
self.assertTrue(not is_HDN(".foo.bar.com"))
self.assertTrue(not is_HDN("..foo"))
self.assertTrue(not is_HDN("foo."))
def test_reach(self):
from cookielib import reach
self.assertEqual(reach("www.acme.com"), ".acme.com")
self.assertEqual(reach("acme.com"), "acme.com")
self.assertEqual(reach("acme.local"), ".local")
self.assertEqual(reach(".local"), ".local")
self.assertEqual(reach(".com"), ".com")
self.assertEqual(reach("."), ".")
self.assertEqual(reach(""), "")
self.assertEqual(reach("192.168.0.1"), "192.168.0.1")
def test_domain_match(self):
from cookielib import domain_match, user_domain_match
self.assertTrue(domain_match("192.168.1.1", "192.168.1.1"))
self.assertTrue(not domain_match("192.168.1.1", ".168.1.1"))
self.assertTrue(domain_match("x.y.com", "x.Y.com"))
self.assertTrue(domain_match("x.y.com", ".Y.com"))
self.assertTrue(not domain_match("x.y.com", "Y.com"))
self.assertTrue(domain_match("a.b.c.com", ".c.com"))
self.assertTrue(not domain_match(".c.com", "a.b.c.com"))
self.assertTrue(domain_match("example.local", ".local"))
self.assertTrue(not domain_match("blah.blah", ""))
self.assertTrue(not domain_match("", ".rhubarb.rhubarb"))
self.assertTrue(domain_match("", ""))
self.assertTrue(user_domain_match("acme.com", "acme.com"))
self.assertTrue(not user_domain_match("acme.com", ".acme.com"))
self.assertTrue(user_domain_match("rhubarb.acme.com", ".acme.com"))
self.assertTrue(user_domain_match("www.rhubarb.acme.com", ".acme.com"))
self.assertTrue(user_domain_match("x.y.com", "x.Y.com"))
self.assertTrue(user_domain_match("x.y.com", ".Y.com"))
self.assertTrue(not user_domain_match("x.y.com", "Y.com"))
self.assertTrue(user_domain_match("y.com", "Y.com"))
self.assertTrue(not user_domain_match(".y.com", "Y.com"))
self.assertTrue(user_domain_match(".y.com", ".Y.com"))
self.assertTrue(user_domain_match("x.y.com", ".com"))
self.assertTrue(not user_domain_match("x.y.com", "com"))
self.assertTrue(not user_domain_match("x.y.com", "m"))
self.assertTrue(not user_domain_match("x.y.com", ".m"))
self.assertTrue(not user_domain_match("x.y.com", ""))
self.assertTrue(not user_domain_match("x.y.com", "."))
self.assertTrue(user_domain_match("192.168.1.1", "192.168.1.1"))
# not both HDNs, so must string-compare equal to match
self.assertTrue(not user_domain_match("192.168.1.1", ".168.1.1"))
self.assertTrue(not user_domain_match("192.168.1.1", "."))
# empty string is a special case
self.assertTrue(not user_domain_match("192.168.1.1", ""))
def test_wrong_domain(self):
# Cookies whose effective request-host name does not domain-match the
# domain are rejected.
# XXX far from complete
from cookielib import CookieJar
c = CookieJar()
interact_2965(c, "http://www.nasty.com/",
'foo=bar; domain=friendly.org; Version="1"')
self.assertEqual(len(c), 0)
def test_strict_domain(self):
# Cookies whose domain is a country-code tld like .co.uk should
# not be set if CookiePolicy.strict_domain is true.
from cookielib import CookieJar, DefaultCookiePolicy
cp = DefaultCookiePolicy(strict_domain=True)
cj = CookieJar(policy=cp)
interact_netscape(cj, "http://example.co.uk/", 'no=problemo')
interact_netscape(cj, "http://example.co.uk/",
'okey=dokey; Domain=.example.co.uk')
self.assertEqual(len(cj), 2)
for pseudo_tld in [".co.uk", ".org.za", ".tx.us", ".name.us"]:
interact_netscape(cj, "http://example.%s/" % pseudo_tld,
'spam=eggs; Domain=.co.uk')
self.assertEqual(len(cj), 2)
def test_two_component_domain_ns(self):
# Netscape: .www.bar.com, www.bar.com, .bar.com, bar.com, no domain
# should all get accepted, as should .acme.com, acme.com and no domain
# for 2-component domains like acme.com.
from cookielib import CookieJar, DefaultCookiePolicy
c = CookieJar()
# two-component V0 domain is OK
interact_netscape(c, "http://foo.net/", 'ns=bar')
self.assertEqual(len(c), 1)
self.assertEqual(c._cookies["foo.net"]["/"]["ns"].value, "bar")
self.assertEqual(interact_netscape(c, "http://foo.net/"), "ns=bar")
# *will* be returned to any other domain (unlike RFC 2965)...
self.assertEqual(interact_netscape(c, "http://www.foo.net/"),
"ns=bar")
# ...unless requested otherwise
pol = DefaultCookiePolicy(
strict_ns_domain=DefaultCookiePolicy.DomainStrictNonDomain)
c.set_policy(pol)
self.assertEqual(interact_netscape(c, "http://www.foo.net/"), "")
# unlike RFC 2965, even explicit two-component domain is OK,
# because .foo.net matches foo.net
interact_netscape(c, "http://foo.net/foo/",
'spam1=eggs; domain=foo.net')
# even if starts with a dot -- in NS rules, .foo.net matches foo.net!
interact_netscape(c, "http://foo.net/foo/bar/",
'spam2=eggs; domain=.foo.net')
self.assertEqual(len(c), 3)
self.assertEqual(c._cookies[".foo.net"]["/foo"]["spam1"].value,
"eggs")
self.assertEqual(c._cookies[".foo.net"]["/foo/bar"]["spam2"].value,
"eggs")
self.assertEqual(interact_netscape(c, "http://foo.net/foo/bar/"),
"spam2=eggs; spam1=eggs; ns=bar")
# top-level domain is too general
interact_netscape(c, "http://foo.net/", 'nini="ni"; domain=.net')
self.assertEqual(len(c), 3)
## # Netscape protocol doesn't allow non-special top level domains (such
## # as co.uk) in the domain attribute unless there are at least three
## # dots in it.
# Oh yes it does! Real implementations don't check this, and real
# cookies (of course) rely on that behaviour.
interact_netscape(c, "http://foo.co.uk", 'nasty=trick; domain=.co.uk')
## self.assertEqual(len(c), 2)
self.assertEqual(len(c), 4)
def test_two_component_domain_rfc2965(self):
from cookielib import CookieJar, DefaultCookiePolicy
pol = DefaultCookiePolicy(rfc2965=True)
c = CookieJar(pol)
# two-component V1 domain is OK
interact_2965(c, "http://foo.net/", 'foo=bar; Version="1"')
self.assertEqual(len(c), 1)
self.assertEqual(c._cookies["foo.net"]["/"]["foo"].value, "bar")
self.assertEqual(interact_2965(c, "http://foo.net/"),
"$Version=1; foo=bar")
# won't be returned to any other domain (because domain was implied)
self.assertEqual(interact_2965(c, "http://www.foo.net/"), "")
# unless domain is given explicitly, because then it must be
# rewritten to start with a dot: foo.net --> .foo.net, which does
# not domain-match foo.net
interact_2965(c, "http://foo.net/foo",
'spam=eggs; domain=foo.net; path=/foo; Version="1"')
self.assertEqual(len(c), 1)
self.assertEqual(interact_2965(c, "http://foo.net/foo"),
"$Version=1; foo=bar")
# explicit foo.net from three-component domain www.foo.net *does* get
# set, because .foo.net domain-matches .foo.net
interact_2965(c, "http://www.foo.net/foo/",
'spam=eggs; domain=foo.net; Version="1"')
self.assertEqual(c._cookies[".foo.net"]["/foo/"]["spam"].value,
"eggs")
self.assertEqual(len(c), 2)
self.assertEqual(interact_2965(c, "http://foo.net/foo/"),
"$Version=1; foo=bar")
self.assertEqual(interact_2965(c, "http://www.foo.net/foo/"),
'$Version=1; spam=eggs; $Domain="foo.net"')
# top-level domain is too general
interact_2965(c, "http://foo.net/",
'ni="ni"; domain=".net"; Version="1"')
self.assertEqual(len(c), 2)
# RFC 2965 doesn't require blocking this
interact_2965(c, "http://foo.co.uk/",
'nasty=trick; domain=.co.uk; Version="1"')
self.assertEqual(len(c), 3)
def test_domain_allow(self):
from cookielib import CookieJar, DefaultCookiePolicy
from urllib2 import Request
c = CookieJar(policy=DefaultCookiePolicy(
blocked_domains=["acme.com"],
allowed_domains=["www.acme.com"]))
req = Request("http://acme.com/")
headers = ["Set-Cookie: CUSTOMER=WILE_E_COYOTE; path=/"]
res = FakeResponse(headers, "http://acme.com/")
c.extract_cookies(res, req)
self.assertEqual(len(c), 0)
req = Request("http://www.acme.com/")
res = FakeResponse(headers, "http://www.acme.com/")
c.extract_cookies(res, req)
self.assertEqual(len(c), 1)
req = Request("http://www.coyote.com/")
res = FakeResponse(headers, "http://www.coyote.com/")
c.extract_cookies(res, req)
self.assertEqual(len(c), 1)
# set a cookie with non-allowed domain...
req = Request("http://www.coyote.com/")
res = FakeResponse(headers, "http://www.coyote.com/")
cookies = c.make_cookies(res, req)
c.set_cookie(cookies[0])
self.assertEqual(len(c), 2)
# ... and check is doesn't get returned
c.add_cookie_header(req)
self.assertTrue(not req.has_header("Cookie"))
def test_domain_block(self):
from cookielib import CookieJar, DefaultCookiePolicy
from urllib2 import Request
pol = DefaultCookiePolicy(
rfc2965=True, blocked_domains=[".acme.com"])
c = CookieJar(policy=pol)
headers = ["Set-Cookie: CUSTOMER=WILE_E_COYOTE; path=/"]
req = Request("http://www.acme.com/")
res = FakeResponse(headers, "http://www.acme.com/")
c.extract_cookies(res, req)
self.assertEqual(len(c), 0)
p = pol.set_blocked_domains(["acme.com"])
c.extract_cookies(res, req)
self.assertEqual(len(c), 1)
c.clear()
req = Request("http://www.roadrunner.net/")
res = FakeResponse(headers, "http://www.roadrunner.net/")
c.extract_cookies(res, req)
self.assertEqual(len(c), 1)
req = Request("http://www.roadrunner.net/")
c.add_cookie_header(req)
self.assertTrue((req.has_header("Cookie") and
req.has_header("Cookie2")))
c.clear()
pol.set_blocked_domains([".acme.com"])
c.extract_cookies(res, req)
self.assertEqual(len(c), 1)
# set a cookie with blocked domain...
req = Request("http://www.acme.com/")
res = FakeResponse(headers, "http://www.acme.com/")
cookies = c.make_cookies(res, req)
c.set_cookie(cookies[0])
self.assertEqual(len(c), 2)
# ... and check is doesn't get returned
c.add_cookie_header(req)
self.assertTrue(not req.has_header("Cookie"))
def test_secure(self):
from cookielib import CookieJar, DefaultCookiePolicy
for ns in True, False:
for whitespace in " ", "":
c = CookieJar()
if ns:
pol = DefaultCookiePolicy(rfc2965=False)
int = interact_netscape
vs = ""
else:
pol = DefaultCookiePolicy(rfc2965=True)
int = interact_2965
vs = "; Version=1"
c.set_policy(pol)
url = "http://www.acme.com/"
int(c, url, "foo1=bar%s%s" % (vs, whitespace))
int(c, url, "foo2=bar%s; secure%s" % (vs, whitespace))
self.assertTrue(
not c._cookies["www.acme.com"]["/"]["foo1"].secure,
"non-secure cookie registered secure")
self.assertTrue(
c._cookies["www.acme.com"]["/"]["foo2"].secure,
"secure cookie registered non-secure")
def test_quote_cookie_value(self):
from cookielib import CookieJar, DefaultCookiePolicy
c = CookieJar(policy=DefaultCookiePolicy(rfc2965=True))
interact_2965(c, "http://www.acme.com/", r'foo=\b"a"r; Version=1')
h = interact_2965(c, "http://www.acme.com/")
self.assertEqual(h, r'$Version=1; foo=\\b\"a\"r')
def test_missing_final_slash(self):
# Missing slash from request URL's abs_path should be assumed present.
from cookielib import CookieJar, DefaultCookiePolicy
from urllib2 import Request
url = "http://www.acme.com"
c = CookieJar(DefaultCookiePolicy(rfc2965=True))
interact_2965(c, url, "foo=bar; Version=1")
req = Request(url)
self.assertEqual(len(c), 1)
c.add_cookie_header(req)
self.assertTrue(req.has_header("Cookie"))
def test_domain_mirror(self):
from cookielib import CookieJar, DefaultCookiePolicy
pol = DefaultCookiePolicy(rfc2965=True)
c = CookieJar(pol)
url = "http://foo.bar.com/"
interact_2965(c, url, "spam=eggs; Version=1")
h = interact_2965(c, url)
self.assertNotIn("Domain", h,
"absent domain returned with domain present")
c = CookieJar(pol)
url = "http://foo.bar.com/"
interact_2965(c, url, 'spam=eggs; Version=1; Domain=.bar.com')
h = interact_2965(c, url)
self.assertIn('$Domain=".bar.com"', h, "domain not returned")
c = CookieJar(pol)
url = "http://foo.bar.com/"
# note missing initial dot in Domain
interact_2965(c, url, 'spam=eggs; Version=1; Domain=bar.com')
h = interact_2965(c, url)
self.assertIn('$Domain="bar.com"', h, "domain not returned")
def test_path_mirror(self):
from cookielib import CookieJar, DefaultCookiePolicy
pol = DefaultCookiePolicy(rfc2965=True)
c = CookieJar(pol)
url = "http://foo.bar.com/"
interact_2965(c, url, "spam=eggs; Version=1")
h = interact_2965(c, url)
self.assertNotIn("Path", h, "absent path returned with path present")
c = CookieJar(pol)
url = "http://foo.bar.com/"
interact_2965(c, url, 'spam=eggs; Version=1; Path=/')
h = interact_2965(c, url)
self.assertIn('$Path="/"', h, "path not returned")
def test_port_mirror(self):
from cookielib import CookieJar, DefaultCookiePolicy
pol = DefaultCookiePolicy(rfc2965=True)
c = CookieJar(pol)
url = "http://foo.bar.com/"
interact_2965(c, url, "spam=eggs; Version=1")
h = interact_2965(c, url)
self.assertNotIn("Port", h, "absent port returned with port present")
c = CookieJar(pol)
url = "http://foo.bar.com/"
interact_2965(c, url, "spam=eggs; Version=1; Port")
h = interact_2965(c, url)
self.assertTrue(re.search("\$Port([^=]|$)", h),
"port with no value not returned with no value")
c = CookieJar(pol)
url = "http://foo.bar.com/"
interact_2965(c, url, 'spam=eggs; Version=1; Port="80"')
h = interact_2965(c, url)
self.assertIn('$Port="80"', h,
"port with single value not returned with single value")
c = CookieJar(pol)
url = "http://foo.bar.com/"
interact_2965(c, url, 'spam=eggs; Version=1; Port="80,8080"')
h = interact_2965(c, url)
self.assertIn('$Port="80,8080"', h,
"port with multiple values not returned with multiple "
"values")
def test_no_return_comment(self):
from cookielib import CookieJar, DefaultCookiePolicy
c = CookieJar(DefaultCookiePolicy(rfc2965=True))
url = "http://foo.bar.com/"
interact_2965(c, url, 'spam=eggs; Version=1; '
'Comment="does anybody read these?"; '
'CommentURL="http://foo.bar.net/comment.html"')
h = interact_2965(c, url)
self.assertTrue(
"Comment" not in h,
"Comment or CommentURL cookie-attributes returned to server")
def test_Cookie_iterator(self):
from cookielib import CookieJar, Cookie, DefaultCookiePolicy
cs = CookieJar(DefaultCookiePolicy(rfc2965=True))
# add some random cookies
interact_2965(cs, "http://blah.spam.org/", 'foo=eggs; Version=1; '
'Comment="does anybody read these?"; '
'CommentURL="http://foo.bar.net/comment.html"')
interact_netscape(cs, "http://www.acme.com/blah/", "spam=bar; secure")
interact_2965(cs, "http://www.acme.com/blah/",
"foo=bar; secure; Version=1")
interact_2965(cs, "http://www.acme.com/blah/",
"foo=bar; path=/; Version=1")
interact_2965(cs, "http://www.sol.no",
r'bang=wallop; version=1; domain=".sol.no"; '
r'port="90,100, 80,8080"; '
r'max-age=100; Comment = "Just kidding! (\"|\\\\) "')
versions = [1, 1, 1, 0, 1]
names = ["bang", "foo", "foo", "spam", "foo"]
domains = [".sol.no", "blah.spam.org", "www.acme.com",
"www.acme.com", "www.acme.com"]
paths = ["/", "/", "/", "/blah", "/blah/"]
for i in range(4):
i = 0
for c in cs:
self.assertIsInstance(c, Cookie)
self.assertEqual(c.version, versions[i])
self.assertEqual(c.name, names[i])
self.assertEqual(c.domain, domains[i])
self.assertEqual(c.path, paths[i])
i = i + 1
def test_parse_ns_headers(self):
from cookielib import parse_ns_headers
# missing domain value (invalid cookie)
self.assertEqual(
parse_ns_headers(["foo=bar; path=/; domain"]),
[[("foo", "bar"),
("path", "/"), ("domain", None), ("version", "0")]]
)
# invalid expires value
self.assertEqual(
parse_ns_headers(["foo=bar; expires=Foo Bar 12 33:22:11 2000"]),
[[("foo", "bar"), ("expires", None), ("version", "0")]]
)
# missing cookie value (valid cookie)
self.assertEqual(
parse_ns_headers(["foo"]),
[[("foo", None), ("version", "0")]]
)
# shouldn't add version if header is empty
self.assertEqual(parse_ns_headers([""]), [])
def test_bad_cookie_header(self):
def cookiejar_from_cookie_headers(headers):
from cookielib import CookieJar
from urllib2 import Request
c = CookieJar()
req = Request("http://www.example.com/")
r = FakeResponse(headers, "http://www.example.com/")
c.extract_cookies(r, req)
return c
# none of these bad headers should cause an exception to be raised
for headers in [
["Set-Cookie: "], # actually, nothing wrong with this
["Set-Cookie2: "], # ditto
# missing domain value
["Set-Cookie2: a=foo; path=/; Version=1; domain"],
# bad max-age
["Set-Cookie: b=foo; max-age=oops"],
# bad version
["Set-Cookie: b=foo; version=spam"],
]:
c = cookiejar_from_cookie_headers(headers)
# these bad cookies shouldn't be set
self.assertEqual(len(c), 0)
# cookie with invalid expires is treated as session cookie
headers = ["Set-Cookie: c=foo; expires=Foo Bar 12 33:22:11 2000"]
c = cookiejar_from_cookie_headers(headers)
cookie = c._cookies["www.example.com"]["/"]["c"]
self.assertTrue(cookie.expires is None)
class LWPCookieTests(TestCase):
# Tests taken from libwww-perl, with a few modifications and additions.
def test_netscape_example_1(self):
from cookielib import CookieJar, DefaultCookiePolicy
from urllib2 import Request
#-------------------------------------------------------------------
# First we check that it works for the original example at
# http://www.netscape.com/newsref/std/cookie_spec.html
# Client requests a document, and receives in the response:
#
# Set-Cookie: CUSTOMER=WILE_E_COYOTE; path=/; expires=Wednesday, 09-Nov-99 23:12:40 GMT
#
# When client requests a URL in path "/" on this server, it sends:
#
# Cookie: CUSTOMER=WILE_E_COYOTE
#
# Client requests a document, and receives in the response:
#
# Set-Cookie: PART_NUMBER=ROCKET_LAUNCHER_0001; path=/
#
# When client requests a URL in path "/" on this server, it sends:
#
# Cookie: CUSTOMER=WILE_E_COYOTE; PART_NUMBER=ROCKET_LAUNCHER_0001
#
# Client receives:
#
# Set-Cookie: SHIPPING=FEDEX; path=/fo
#
# When client requests a URL in path "/" on this server, it sends:
#
# Cookie: CUSTOMER=WILE_E_COYOTE; PART_NUMBER=ROCKET_LAUNCHER_0001
#
# When client requests a URL in path "/foo" on this server, it sends:
#
# Cookie: CUSTOMER=WILE_E_COYOTE; PART_NUMBER=ROCKET_LAUNCHER_0001; SHIPPING=FEDEX
#
# The last Cookie is buggy, because both specifications say that the
# most specific cookie must be sent first. SHIPPING=FEDEX is the
# most specific and should thus be first.
year_plus_one = time.localtime()[0] + 1
headers = []
c = CookieJar(DefaultCookiePolicy(rfc2965 = True))
#req = Request("http://1.1.1.1/",
# headers={"Host": "www.acme.com:80"})
req = Request("http://www.acme.com:80/",
headers={"Host": "www.acme.com:80"})
headers.append(
"Set-Cookie: CUSTOMER=WILE_E_COYOTE; path=/ ; "
"expires=Wednesday, 09-Nov-%d 23:12:40 GMT" % year_plus_one)
res = FakeResponse(headers, "http://www.acme.com/")
c.extract_cookies(res, req)
req = Request("http://www.acme.com/")
c.add_cookie_header(req)
self.assertEqual(req.get_header("Cookie"), "CUSTOMER=WILE_E_COYOTE")
self.assertEqual(req.get_header("Cookie2"), '$Version="1"')
headers.append("Set-Cookie: PART_NUMBER=ROCKET_LAUNCHER_0001; path=/")
res = FakeResponse(headers, "http://www.acme.com/")
c.extract_cookies(res, req)
req = Request("http://www.acme.com/foo/bar")
c.add_cookie_header(req)
h = req.get_header("Cookie")
self.assertIn("PART_NUMBER=ROCKET_LAUNCHER_0001", h)
self.assertIn("CUSTOMER=WILE_E_COYOTE", h)
headers.append('Set-Cookie: SHIPPING=FEDEX; path=/foo')
res = FakeResponse(headers, "http://www.acme.com")
c.extract_cookies(res, req)
req = Request("http://www.acme.com/")
c.add_cookie_header(req)
h = req.get_header("Cookie")
self.assertIn("PART_NUMBER=ROCKET_LAUNCHER_0001", h)
self.assertIn("CUSTOMER=WILE_E_COYOTE", h)
self.assertNotIn("SHIPPING=FEDEX", h)
req = Request("http://www.acme.com/foo/")
c.add_cookie_header(req)
h = req.get_header("Cookie")
self.assertIn("PART_NUMBER=ROCKET_LAUNCHER_0001", h)
self.assertIn("CUSTOMER=WILE_E_COYOTE", h)
self.assertTrue(h.startswith("SHIPPING=FEDEX;"))
def test_netscape_example_2(self):
from cookielib import CookieJar
from urllib2 import Request
# Second Example transaction sequence:
#
# Assume all mappings from above have been cleared.
#
# Client receives:
#
# Set-Cookie: PART_NUMBER=ROCKET_LAUNCHER_0001; path=/
#
# When client requests a URL in path "/" on this server, it sends:
#
# Cookie: PART_NUMBER=ROCKET_LAUNCHER_0001
#
# Client receives:
#
# Set-Cookie: PART_NUMBER=RIDING_ROCKET_0023; path=/ammo
#
# When client requests a URL in path "/ammo" on this server, it sends:
#
# Cookie: PART_NUMBER=RIDING_ROCKET_0023; PART_NUMBER=ROCKET_LAUNCHER_0001
#
# NOTE: There are two name/value pairs named "PART_NUMBER" due to
# the inheritance of the "/" mapping in addition to the "/ammo" mapping.
c = CookieJar()
headers = []
req = Request("http://www.acme.com/")
headers.append("Set-Cookie: PART_NUMBER=ROCKET_LAUNCHER_0001; path=/")
res = FakeResponse(headers, "http://www.acme.com/")
c.extract_cookies(res, req)
req = Request("http://www.acme.com/")
c.add_cookie_header(req)
self.assertEqual(req.get_header("Cookie"),
"PART_NUMBER=ROCKET_LAUNCHER_0001")
headers.append(
"Set-Cookie: PART_NUMBER=RIDING_ROCKET_0023; path=/ammo")
res = FakeResponse(headers, "http://www.acme.com/")
c.extract_cookies(res, req)
req = Request("http://www.acme.com/ammo")
c.add_cookie_header(req)
self.assertTrue(re.search(r"PART_NUMBER=RIDING_ROCKET_0023;\s*"
"PART_NUMBER=ROCKET_LAUNCHER_0001",
req.get_header("Cookie")))
def test_ietf_example_1(self):
from cookielib import CookieJar, DefaultCookiePolicy
#-------------------------------------------------------------------
# Then we test with the examples from draft-ietf-http-state-man-mec-03.txt
#
# 5. EXAMPLES
c = CookieJar(DefaultCookiePolicy(rfc2965=True))
#
# 5.1 Example 1
#
# Most detail of request and response headers has been omitted. Assume
# the user agent has no stored cookies.
#
# 1. User Agent -> Server
#
# POST /acme/login HTTP/1.1
# [form data]
#
# User identifies self via a form.
#
# 2. Server -> User Agent
#
# HTTP/1.1 200 OK
# Set-Cookie2: Customer="WILE_E_COYOTE"; Version="1"; Path="/acme"
#
# Cookie reflects user's identity.
cookie = interact_2965(
c, 'http://www.acme.com/acme/login',
'Customer="WILE_E_COYOTE"; Version="1"; Path="/acme"')
self.assertTrue(not cookie)
#
# 3. User Agent -> Server
#
# POST /acme/pickitem HTTP/1.1
# Cookie: $Version="1"; Customer="WILE_E_COYOTE"; $Path="/acme"
# [form data]
#
# User selects an item for ``shopping basket.''
#
# 4. Server -> User Agent
#
# HTTP/1.1 200 OK
# Set-Cookie2: Part_Number="Rocket_Launcher_0001"; Version="1";
# Path="/acme"
#
# Shopping basket contains an item.
cookie = interact_2965(c, 'http://www.acme.com/acme/pickitem',
'Part_Number="Rocket_Launcher_0001"; '
'Version="1"; Path="/acme"');
self.assertTrue(re.search(
r'^\$Version="?1"?; Customer="?WILE_E_COYOTE"?; \$Path="/acme"$',
cookie))
#
# 5. User Agent -> Server
#
# POST /acme/shipping HTTP/1.1
# Cookie: $Version="1";
# Customer="WILE_E_COYOTE"; $Path="/acme";
# Part_Number="Rocket_Launcher_0001"; $Path="/acme"
# [form data]
#
# User selects shipping method from form.
#
# 6. Server -> User Agent
#
# HTTP/1.1 200 OK
# Set-Cookie2: Shipping="FedEx"; Version="1"; Path="/acme"
#
# New cookie reflects shipping method.
cookie = interact_2965(c, "http://www.acme.com/acme/shipping",
'Shipping="FedEx"; Version="1"; Path="/acme"')
self.assertTrue(re.search(r'^\$Version="?1"?;', cookie))
self.assertTrue(re.search(r'Part_Number="?Rocket_Launcher_0001"?;'
'\s*\$Path="\/acme"', cookie))
self.assertTrue(re.search(r'Customer="?WILE_E_COYOTE"?;\s*\$Path="\/acme"',
cookie))
#
# 7. User Agent -> Server
#
# POST /acme/process HTTP/1.1
# Cookie: $Version="1";
# Customer="WILE_E_COYOTE"; $Path="/acme";
# Part_Number="Rocket_Launcher_0001"; $Path="/acme";
# Shipping="FedEx"; $Path="/acme"
# [form data]
#
# User chooses to process order.
#
# 8. Server -> User Agent
#
# HTTP/1.1 200 OK
#
# Transaction is complete.
cookie = interact_2965(c, "http://www.acme.com/acme/process")
self.assertTrue(
re.search(r'Shipping="?FedEx"?;\s*\$Path="\/acme"', cookie) and
"WILE_E_COYOTE" in cookie)
#
# The user agent makes a series of requests on the origin server, after
# each of which it receives a new cookie. All the cookies have the same
# Path attribute and (default) domain. Because the request URLs all have
# /acme as a prefix, and that matches the Path attribute, each request
# contains all the cookies received so far.
def test_ietf_example_2(self):
from cookielib import CookieJar, DefaultCookiePolicy
# 5.2 Example 2
#
# This example illustrates the effect of the Path attribute. All detail
# of request and response headers has been omitted. Assume the user agent
# has no stored cookies.
c = CookieJar(DefaultCookiePolicy(rfc2965=True))
# Imagine the user agent has received, in response to earlier requests,
# the response headers
#
# Set-Cookie2: Part_Number="Rocket_Launcher_0001"; Version="1";
# Path="/acme"
#
# and
#
# Set-Cookie2: Part_Number="Riding_Rocket_0023"; Version="1";
# Path="/acme/ammo"
interact_2965(
c, "http://www.acme.com/acme/ammo/specific",
'Part_Number="Rocket_Launcher_0001"; Version="1"; Path="/acme"',
'Part_Number="Riding_Rocket_0023"; Version="1"; Path="/acme/ammo"')
# A subsequent request by the user agent to the (same) server for URLs of
# the form /acme/ammo/... would include the following request header:
#
# Cookie: $Version="1";
# Part_Number="Riding_Rocket_0023"; $Path="/acme/ammo";
# Part_Number="Rocket_Launcher_0001"; $Path="/acme"
#
# Note that the NAME=VALUE pair for the cookie with the more specific Path
# attribute, /acme/ammo, comes before the one with the less specific Path
# attribute, /acme. Further note that the same cookie name appears more
# than once.
cookie = interact_2965(c, "http://www.acme.com/acme/ammo/...")
self.assertTrue(
re.search(r"Riding_Rocket_0023.*Rocket_Launcher_0001", cookie))
# A subsequent request by the user agent to the (same) server for a URL of
# the form /acme/parts/ would include the following request header:
#
# Cookie: $Version="1"; Part_Number="Rocket_Launcher_0001"; $Path="/acme"
#
# Here, the second cookie's Path attribute /acme/ammo is not a prefix of
# the request URL, /acme/parts/, so the cookie does not get forwarded to
# the server.
cookie = interact_2965(c, "http://www.acme.com/acme/parts/")
self.assertIn("Rocket_Launcher_0001", cookie)
self.assertNotIn("Riding_Rocket_0023", cookie)
def test_rejection(self):
# Test rejection of Set-Cookie2 responses based on domain, path, port.
from cookielib import DefaultCookiePolicy, LWPCookieJar
pol = DefaultCookiePolicy(rfc2965=True)
c = LWPCookieJar(policy=pol)
max_age = "max-age=3600"
# illegal domain (no embedded dots)
cookie = interact_2965(c, "http://www.acme.com",
'foo=bar; domain=".com"; version=1')
self.assertTrue(not c)
# legal domain
cookie = interact_2965(c, "http://www.acme.com",
'ping=pong; domain="acme.com"; version=1')
self.assertEqual(len(c), 1)
# illegal domain (host prefix "www.a" contains a dot)
cookie = interact_2965(c, "http://www.a.acme.com",
'whiz=bang; domain="acme.com"; version=1')
self.assertEqual(len(c), 1)
# legal domain
cookie = interact_2965(c, "http://www.a.acme.com",
'wow=flutter; domain=".a.acme.com"; version=1')
self.assertEqual(len(c), 2)
# can't partially match an IP-address
cookie = interact_2965(c, "http://125.125.125.125",
'zzzz=ping; domain="125.125.125"; version=1')
self.assertEqual(len(c), 2)
# illegal path (must be prefix of request path)
cookie = interact_2965(c, "http://www.sol.no",
'blah=rhubarb; domain=".sol.no"; path="/foo"; '
'version=1')
self.assertEqual(len(c), 2)
# legal path
cookie = interact_2965(c, "http://www.sol.no/foo/bar",
'bing=bong; domain=".sol.no"; path="/foo"; '
'version=1')
self.assertEqual(len(c), 3)
# illegal port (request-port not in list)
cookie = interact_2965(c, "http://www.sol.no",
'whiz=ffft; domain=".sol.no"; port="90,100"; '
'version=1')
self.assertEqual(len(c), 3)
# legal port
cookie = interact_2965(
c, "http://www.sol.no",
r'bang=wallop; version=1; domain=".sol.no"; '
r'port="90,100, 80,8080"; '
r'max-age=100; Comment = "Just kidding! (\"|\\\\) "')
self.assertEqual(len(c), 4)
# port attribute without any value (current port)
cookie = interact_2965(c, "http://www.sol.no",
'foo9=bar; version=1; domain=".sol.no"; port; '
'max-age=100;')
self.assertEqual(len(c), 5)
# encoded path
# LWP has this test, but unescaping allowed path characters seems
# like a bad idea, so I think this should fail:
## cookie = interact_2965(c, "http://www.sol.no/foo/",
## r'foo8=bar; version=1; path="/%66oo"')
# but this is OK, because '<' is not an allowed HTTP URL path
# character:
cookie = interact_2965(c, "http://www.sol.no/<oo/",
r'foo8=bar; version=1; path="/%3coo"')
self.assertEqual(len(c), 6)
# save and restore
filename = test_support.TESTFN
try:
c.save(filename, ignore_discard=True)
old = repr(c)
c = LWPCookieJar(policy=pol)
c.load(filename, ignore_discard=True)
finally:
try: os.unlink(filename)
except OSError: pass
self.assertEqual(old, repr(c))
def test_url_encoding(self):
# Try some URL encodings of the PATHs.
# (the behaviour here has changed from libwww-perl)
from cookielib import CookieJar, DefaultCookiePolicy
c = CookieJar(DefaultCookiePolicy(rfc2965=True))
interact_2965(c, "http://www.acme.com/foo%2f%25/%3c%3c%0Anew%E5/%E5",
"foo = bar; version = 1")
cookie = interact_2965(
c, "http://www.acme.com/foo%2f%25/<<%0anew/",
'bar=baz; path="/foo/"; version=1');
version_re = re.compile(r'^\$version=\"?1\"?', re.I)
self.assertTrue("foo=bar" in cookie and version_re.search(cookie))
cookie = interact_2965(
c, "http://www.acme.com/foo/%25/<<%0anew/")
self.assertTrue(not cookie)
# unicode URL doesn't raise exception
cookie = interact_2965(c, u"http://www.acme.com/\xfc")
def test_mozilla(self):
# Save / load Mozilla/Netscape cookie file format.
from cookielib import MozillaCookieJar, DefaultCookiePolicy
year_plus_one = time.localtime()[0] + 1
filename = test_support.TESTFN
c = MozillaCookieJar(filename,
policy=DefaultCookiePolicy(rfc2965=True))
interact_2965(c, "http://www.acme.com/",
"foo1=bar; max-age=100; Version=1")
interact_2965(c, "http://www.acme.com/",
'foo2=bar; port="80"; max-age=100; Discard; Version=1')
interact_2965(c, "http://www.acme.com/", "foo3=bar; secure; Version=1")
expires = "expires=09-Nov-%d 23:12:40 GMT" % (year_plus_one,)
interact_netscape(c, "http://www.foo.com/",
"fooa=bar; %s" % expires)
interact_netscape(c, "http://www.foo.com/",
"foob=bar; Domain=.foo.com; %s" % expires)
interact_netscape(c, "http://www.foo.com/",
"fooc=bar; Domain=www.foo.com; %s" % expires)
def save_and_restore(cj, ignore_discard):
try:
cj.save(ignore_discard=ignore_discard)
new_c = MozillaCookieJar(filename,
DefaultCookiePolicy(rfc2965=True))
new_c.load(ignore_discard=ignore_discard)
finally:
try: os.unlink(filename)
except OSError: pass
return new_c
new_c = save_and_restore(c, True)
self.assertEqual(len(new_c), 6) # none discarded
self.assertIn("name='foo1', value='bar'", repr(new_c))
new_c = save_and_restore(c, False)
self.assertEqual(len(new_c), 4) # 2 of them discarded on save
self.assertIn("name='foo1', value='bar'", repr(new_c))
def test_netscape_misc(self):
# Some additional Netscape cookies tests.
from cookielib import CookieJar
from urllib2 import Request
c = CookieJar()
headers = []
req = Request("http://foo.bar.acme.com/foo")
# Netscape allows a host part that contains dots
headers.append("Set-Cookie: Customer=WILE_E_COYOTE; domain=.acme.com")
res = FakeResponse(headers, "http://www.acme.com/foo")
c.extract_cookies(res, req)
# and that the domain is the same as the host without adding a leading
# dot to the domain. Should not quote even if strange chars are used
# in the cookie value.
headers.append("Set-Cookie: PART_NUMBER=3,4; domain=foo.bar.acme.com")
res = FakeResponse(headers, "http://www.acme.com/foo")
c.extract_cookies(res, req)
req = Request("http://foo.bar.acme.com/foo")
c.add_cookie_header(req)
self.assertTrue(
"PART_NUMBER=3,4" in req.get_header("Cookie") and
"Customer=WILE_E_COYOTE" in req.get_header("Cookie"))
def test_intranet_domains_2965(self):
# Test handling of local intranet hostnames without a dot.
from cookielib import CookieJar, DefaultCookiePolicy
c = CookieJar(DefaultCookiePolicy(rfc2965=True))
interact_2965(c, "http://example/",
"foo1=bar; PORT; Discard; Version=1;")
cookie = interact_2965(c, "http://example/",
'foo2=bar; domain=".local"; Version=1')
self.assertIn("foo1=bar", cookie)
interact_2965(c, "http://example/", 'foo3=bar; Version=1')
cookie = interact_2965(c, "http://example/")
self.assertIn("foo2=bar", cookie)
self.assertEqual(len(c), 3)
def test_intranet_domains_ns(self):
from cookielib import CookieJar, DefaultCookiePolicy
c = CookieJar(DefaultCookiePolicy(rfc2965 = False))
interact_netscape(c, "http://example/", "foo1=bar")
cookie = interact_netscape(c, "http://example/",
'foo2=bar; domain=.local')
self.assertEqual(len(c), 2)
self.assertIn("foo1=bar", cookie)
cookie = interact_netscape(c, "http://example/")
self.assertIn("foo2=bar", cookie)
self.assertEqual(len(c), 2)
def test_empty_path(self):
from cookielib import CookieJar, DefaultCookiePolicy
from urllib2 import Request
# Test for empty path
# Broken web-server ORION/1.3.38 returns to the client response like
#
# Set-Cookie: JSESSIONID=ABCDERANDOM123; Path=
#
# ie. with Path set to nothing.
# In this case, extract_cookies() must set cookie to / (root)
c = CookieJar(DefaultCookiePolicy(rfc2965 = True))
headers = []
req = Request("http://www.ants.com/")
headers.append("Set-Cookie: JSESSIONID=ABCDERANDOM123; Path=")
res = FakeResponse(headers, "http://www.ants.com/")
c.extract_cookies(res, req)
req = Request("http://www.ants.com/")
c.add_cookie_header(req)
self.assertEqual(req.get_header("Cookie"),
"JSESSIONID=ABCDERANDOM123")
self.assertEqual(req.get_header("Cookie2"), '$Version="1"')
# missing path in the request URI
req = Request("http://www.ants.com:8080")
c.add_cookie_header(req)
self.assertEqual(req.get_header("Cookie"),
"JSESSIONID=ABCDERANDOM123")
self.assertEqual(req.get_header("Cookie2"), '$Version="1"')
def test_session_cookies(self):
from cookielib import CookieJar
from urllib2 import Request
year_plus_one = time.localtime()[0] + 1
# Check session cookies are deleted properly by
# CookieJar.clear_session_cookies method
req = Request('http://www.perlmeister.com/scripts')
headers = []
headers.append("Set-Cookie: s1=session;Path=/scripts")
headers.append("Set-Cookie: p1=perm; Domain=.perlmeister.com;"
"Path=/;expires=Fri, 02-Feb-%d 23:24:20 GMT" %
year_plus_one)
headers.append("Set-Cookie: p2=perm;Path=/;expires=Fri, "
"02-Feb-%d 23:24:20 GMT" % year_plus_one)
headers.append("Set-Cookie: s2=session;Path=/scripts;"
"Domain=.perlmeister.com")
headers.append('Set-Cookie2: s3=session;Version=1;Discard;Path="/"')
res = FakeResponse(headers, 'http://www.perlmeister.com/scripts')
c = CookieJar()
c.extract_cookies(res, req)
# How many session/permanent cookies do we have?
counter = {"session_after": 0,
"perm_after": 0,
"session_before": 0,
"perm_before": 0}
for cookie in c:
key = "%s_before" % cookie.value
counter[key] = counter[key] + 1
c.clear_session_cookies()
# How many now?
for cookie in c:
key = "%s_after" % cookie.value
counter[key] = counter[key] + 1
self.assertTrue(not (
# a permanent cookie got lost accidentally
counter["perm_after"] != counter["perm_before"] or
# a session cookie hasn't been cleared
counter["session_after"] != 0 or
# we didn't have session cookies in the first place
counter["session_before"] == 0))
def test_main(verbose=None):
test_support.run_unittest(
DateTimeTests,
HeaderTests,
CookieTests,
FileCookieJarTests,
LWPCookieTests,
)
if __name__ == "__main__":
test_main(verbose=True)
"""Test script for ftplib module."""
# Modified by Giampaolo Rodola' to test FTP class, IPv6 and TLS
# environment
import ftplib
import asyncore
import asynchat
import socket
import StringIO
import errno
import os
try:
import ssl
except ImportError:
ssl = None
from unittest import TestCase
from test import test_support
from test.test_support import HOST
threading = test_support.import_module('threading')
# the dummy data returned by server over the data channel when
# RETR, LIST and NLST commands are issued
RETR_DATA = 'abcde12345\r\n' * 1000
LIST_DATA = 'foo\r\nbar\r\n'
NLST_DATA = 'foo\r\nbar\r\n'
class DummyDTPHandler(asynchat.async_chat):
dtp_conn_closed = False
def __init__(self, conn, baseclass):
asynchat.async_chat.__init__(self, conn)
self.baseclass = baseclass
self.baseclass.last_received_data = ''
def handle_read(self):
self.baseclass.last_received_data += self.recv(1024)
def handle_close(self):
# XXX: this method can be called many times in a row for a single
# connection, including in clear-text (non-TLS) mode.
# (behaviour witnessed with test_data_connection)
if not self.dtp_conn_closed:
self.baseclass.push('226 transfer complete')
self.close()
self.dtp_conn_closed = True
def handle_error(self):
raise
class DummyFTPHandler(asynchat.async_chat):
dtp_handler = DummyDTPHandler
def __init__(self, conn):
asynchat.async_chat.__init__(self, conn)
self.set_terminator("\r\n")
self.in_buffer = []
self.dtp = None
self.last_received_cmd = None
self.last_received_data = ''
self.next_response = ''
self.rest = None
self.push('220 welcome')
def collect_incoming_data(self, data):
self.in_buffer.append(data)
def found_terminator(self):
line = ''.join(self.in_buffer)
self.in_buffer = []
if self.next_response:
self.push(self.next_response)
self.next_response = ''
cmd = line.split(' ')[0].lower()
self.last_received_cmd = cmd
space = line.find(' ')
if space != -1:
arg = line[space + 1:]
else:
arg = ""
if hasattr(self, 'cmd_' + cmd):
method = getattr(self, 'cmd_' + cmd)
method(arg)
else:
self.push('550 command "%s" not understood.' %cmd)
def handle_error(self):
raise
def push(self, data):
asynchat.async_chat.push(self, data + '\r\n')
def cmd_port(self, arg):
addr = map(int, arg.split(','))
ip = '%d.%d.%d.%d' %tuple(addr[:4])
port = (addr[4] * 256) + addr[5]
s = socket.create_connection((ip, port), timeout=10)
self.dtp = self.dtp_handler(s, baseclass=self)
self.push('200 active data connection established')
def cmd_pasv(self, arg):
sock = socket.socket()
sock.bind((self.socket.getsockname()[0], 0))
sock.listen(5)
sock.settimeout(10)
ip, port = sock.getsockname()[:2]
ip = ip.replace('.', ',')
p1, p2 = divmod(port, 256)
self.push('227 entering passive mode (%s,%d,%d)' %(ip, p1, p2))
conn, addr = sock.accept()
self.dtp = self.dtp_handler(conn, baseclass=self)
def cmd_eprt(self, arg):
af, ip, port = arg.split(arg[0])[1:-1]
port = int(port)
s = socket.create_connection((ip, port), timeout=10)
self.dtp = self.dtp_handler(s, baseclass=self)
self.push('200 active data connection established')
def cmd_epsv(self, arg):
sock = socket.socket(socket.AF_INET6)
sock.bind((self.socket.getsockname()[0], 0))
sock.listen(5)
sock.settimeout(10)
port = sock.getsockname()[1]
self.push('229 entering extended passive mode (|||%d|)' %port)
conn, addr = sock.accept()
self.dtp = self.dtp_handler(conn, baseclass=self)
def cmd_echo(self, arg):
# sends back the received string (used by the test suite)
self.push(arg)
def cmd_user(self, arg):
self.push('331 username ok')
def cmd_pass(self, arg):
self.push('230 password ok')
def cmd_acct(self, arg):
self.push('230 acct ok')
def cmd_rnfr(self, arg):
self.push('350 rnfr ok')
def cmd_rnto(self, arg):
self.push('250 rnto ok')
def cmd_dele(self, arg):
self.push('250 dele ok')
def cmd_cwd(self, arg):
self.push('250 cwd ok')
def cmd_size(self, arg):
self.push('250 1000')
def cmd_mkd(self, arg):
self.push('257 "%s"' %arg)
def cmd_rmd(self, arg):
self.push('250 rmd ok')
def cmd_pwd(self, arg):
self.push('257 "pwd ok"')
def cmd_type(self, arg):
self.push('200 type ok')
def cmd_quit(self, arg):
self.push('221 quit ok')
self.close()
def cmd_stor(self, arg):
self.push('125 stor ok')
def cmd_rest(self, arg):
self.rest = arg
self.push('350 rest ok')
def cmd_retr(self, arg):
self.push('125 retr ok')
if self.rest is not None:
offset = int(self.rest)
else:
offset = 0
self.dtp.push(RETR_DATA[offset:])
self.dtp.close_when_done()
self.rest = None
def cmd_list(self, arg):
self.push('125 list ok')
self.dtp.push(LIST_DATA)
self.dtp.close_when_done()
def cmd_nlst(self, arg):
self.push('125 nlst ok')
self.dtp.push(NLST_DATA)
self.dtp.close_when_done()
class DummyFTPServer(asyncore.dispatcher, threading.Thread):
handler = DummyFTPHandler
def __init__(self, address, af=socket.AF_INET):
threading.Thread.__init__(self)
asyncore.dispatcher.__init__(self)
self.create_socket(af, socket.SOCK_STREAM)
self.bind(address)
self.listen(5)
self.active = False
self.active_lock = threading.Lock()
self.host, self.port = self.socket.getsockname()[:2]
def start(self):
assert not self.active
self.__flag = threading.Event()
threading.Thread.start(self)
self.__flag.wait()
def run(self):
self.active = True
self.__flag.set()
while self.active and asyncore.socket_map:
self.active_lock.acquire()
asyncore.loop(timeout=0.1, count=1)
self.active_lock.release()
asyncore.close_all(ignore_all=True)
def stop(self):
assert self.active
self.active = False
self.join()
def handle_accept(self):
conn, addr = self.accept()
self.handler = self.handler(conn)
self.close()
def handle_connect(self):
self.close()
handle_read = handle_connect
def writable(self):
return 0
def handle_error(self):
raise
if ssl is not None:
CERTFILE = os.path.join(os.path.dirname(__file__), "keycert.pem")
class SSLConnection(object, asyncore.dispatcher):
"""An asyncore.dispatcher subclass supporting TLS/SSL."""
_ssl_accepting = False
_ssl_closing = False
def secure_connection(self):
self.socket = ssl.wrap_socket(self.socket, suppress_ragged_eofs=False,
certfile=CERTFILE, server_side=True,
do_handshake_on_connect=False,
ssl_version=ssl.PROTOCOL_SSLv23)
self._ssl_accepting = True
def _do_ssl_handshake(self):
try:
self.socket.do_handshake()
except ssl.SSLError, err:
if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
ssl.SSL_ERROR_WANT_WRITE):
return
elif err.args[0] == ssl.SSL_ERROR_EOF:
return self.handle_close()
raise
except socket.error, err:
if err.args[0] == errno.ECONNABORTED:
return self.handle_close()
else:
self._ssl_accepting = False
def _do_ssl_shutdown(self):
self._ssl_closing = True
try:
self.socket = self.socket.unwrap()
except ssl.SSLError, err:
if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
ssl.SSL_ERROR_WANT_WRITE):
return
except socket.error, err:
# Any "socket error" corresponds to a SSL_ERROR_SYSCALL return
# from OpenSSL's SSL_shutdown(), corresponding to a
# closed socket condition. See also:
# http://www.mail-archive.com/openssl-users@openssl.org/msg60710.html
pass
self._ssl_closing = False
super(SSLConnection, self).close()
def handle_read_event(self):
if self._ssl_accepting:
self._do_ssl_handshake()
elif self._ssl_closing:
self._do_ssl_shutdown()
else:
super(SSLConnection, self).handle_read_event()
def handle_write_event(self):
if self._ssl_accepting:
self._do_ssl_handshake()
elif self._ssl_closing:
self._do_ssl_shutdown()
else:
super(SSLConnection, self).handle_write_event()
def send(self, data):
try:
return super(SSLConnection, self).send(data)
except ssl.SSLError, err:
if err.args[0] in (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN,
ssl.SSL_ERROR_WANT_READ,
ssl.SSL_ERROR_WANT_WRITE):
return 0
raise
def recv(self, buffer_size):
try:
return super(SSLConnection, self).recv(buffer_size)
except ssl.SSLError, err:
if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
ssl.SSL_ERROR_WANT_WRITE):
return ''
if err.args[0] in (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN):
self.handle_close()
return ''
raise
def handle_error(self):
raise
def close(self):
if (isinstance(self.socket, ssl.SSLSocket) and
self.socket._sslobj is not None):
self._do_ssl_shutdown()
class DummyTLS_DTPHandler(SSLConnection, DummyDTPHandler):
"""A DummyDTPHandler subclass supporting TLS/SSL."""
def __init__(self, conn, baseclass):
DummyDTPHandler.__init__(self, conn, baseclass)
if self.baseclass.secure_data_channel:
self.secure_connection()
class DummyTLS_FTPHandler(SSLConnection, DummyFTPHandler):
"""A DummyFTPHandler subclass supporting TLS/SSL."""
dtp_handler = DummyTLS_DTPHandler
def __init__(self, conn):
DummyFTPHandler.__init__(self, conn)
self.secure_data_channel = False
def cmd_auth(self, line):
"""Set up secure control channel."""
self.push('234 AUTH TLS successful')
self.secure_connection()
def cmd_pbsz(self, line):
"""Negotiate size of buffer for secure data transfer.
For TLS/SSL the only valid value for the parameter is '0'.
Any other value is accepted but ignored.
"""
self.push('200 PBSZ=0 successful.')
def cmd_prot(self, line):
"""Setup un/secure data channel."""
arg = line.upper()
if arg == 'C':
self.push('200 Protection set to Clear')
self.secure_data_channel = False
elif arg == 'P':
self.push('200 Protection set to Private')
self.secure_data_channel = True
else:
self.push("502 Unrecognized PROT type (use C or P).")
class DummyTLS_FTPServer(DummyFTPServer):
handler = DummyTLS_FTPHandler
class TestFTPClass(TestCase):
def setUp(self):
self.server = DummyFTPServer((HOST, 0))
self.server.start()
self.client = ftplib.FTP(timeout=10)
self.client.connect(self.server.host, self.server.port)
def tearDown(self):
self.client.close()
self.server.stop()
def test_getwelcome(self):
self.assertEqual(self.client.getwelcome(), '220 welcome')
def test_sanitize(self):
self.assertEqual(self.client.sanitize('foo'), repr('foo'))
self.assertEqual(self.client.sanitize('pass 12345'), repr('pass *****'))
self.assertEqual(self.client.sanitize('PASS 12345'), repr('PASS *****'))
def test_exceptions(self):
self.assertRaises(ftplib.error_temp, self.client.sendcmd, 'echo 400')
self.assertRaises(ftplib.error_temp, self.client.sendcmd, 'echo 499')
self.assertRaises(ftplib.error_perm, self.client.sendcmd, 'echo 500')
self.assertRaises(ftplib.error_perm, self.client.sendcmd, 'echo 599')
self.assertRaises(ftplib.error_proto, self.client.sendcmd, 'echo 999')
def test_all_errors(self):
exceptions = (ftplib.error_reply, ftplib.error_temp, ftplib.error_perm,
ftplib.error_proto, ftplib.Error, IOError, EOFError)
for x in exceptions:
try:
raise x('exception not included in all_errors set')
except ftplib.all_errors:
pass
def test_set_pasv(self):
# passive mode is supposed to be enabled by default
self.assertTrue(self.client.passiveserver)
self.client.set_pasv(True)
self.assertTrue(self.client.passiveserver)
self.client.set_pasv(False)
self.assertFalse(self.client.passiveserver)
def test_voidcmd(self):
self.client.voidcmd('echo 200')
self.client.voidcmd('echo 299')
self.assertRaises(ftplib.error_reply, self.client.voidcmd, 'echo 199')
self.assertRaises(ftplib.error_reply, self.client.voidcmd, 'echo 300')
def test_login(self):
self.client.login()
def test_acct(self):
self.client.acct('passwd')
def test_rename(self):
self.client.rename('a', 'b')
self.server.handler.next_response = '200'
self.assertRaises(ftplib.error_reply, self.client.rename, 'a', 'b')
def test_delete(self):
self.client.delete('foo')
self.server.handler.next_response = '199'
self.assertRaises(ftplib.error_reply, self.client.delete, 'foo')
def test_size(self):
self.client.size('foo')
def test_mkd(self):
dir = self.client.mkd('/foo')
self.assertEqual(dir, '/foo')
def test_rmd(self):
self.client.rmd('foo')
def test_pwd(self):
dir = self.client.pwd()
self.assertEqual(dir, 'pwd ok')
def test_quit(self):
self.assertEqual(self.client.quit(), '221 quit ok')
# Ensure the connection gets closed; sock attribute should be None
self.assertEqual(self.client.sock, None)
def test_retrbinary(self):
received = []
self.client.retrbinary('retr', received.append)
self.assertEqual(''.join(received), RETR_DATA)
def test_retrbinary_rest(self):
for rest in (0, 10, 20):
received = []
self.client.retrbinary('retr', received.append, rest=rest)
self.assertEqual(''.join(received), RETR_DATA[rest:],
msg='rest test case %d %d %d' % (rest,
len(''.join(received)),
len(RETR_DATA[rest:])))
def test_retrlines(self):
received = []
self.client.retrlines('retr', received.append)
self.assertEqual(''.join(received), RETR_DATA.replace('\r\n', ''))
def test_storbinary(self):
f = StringIO.StringIO(RETR_DATA)
self.client.storbinary('stor', f)
self.assertEqual(self.server.handler.last_received_data, RETR_DATA)
# test new callback arg
flag = []
f.seek(0)
self.client.storbinary('stor', f, callback=lambda x: flag.append(None))
self.assertTrue(flag)
def test_storbinary_rest(self):
f = StringIO.StringIO(RETR_DATA)
for r in (30, '30'):
f.seek(0)
self.client.storbinary('stor', f, rest=r)
self.assertEqual(self.server.handler.rest, str(r))
def test_storlines(self):
f = StringIO.StringIO(RETR_DATA.replace('\r\n', '\n'))
self.client.storlines('stor', f)
self.assertEqual(self.server.handler.last_received_data, RETR_DATA)
# test new callback arg
flag = []
f.seek(0)
self.client.storlines('stor foo', f, callback=lambda x: flag.append(None))
self.assertTrue(flag)
def test_nlst(self):
self.client.nlst()
self.assertEqual(self.client.nlst(), NLST_DATA.split('\r\n')[:-1])
def test_dir(self):
l = []
self.client.dir(lambda x: l.append(x))
self.assertEqual(''.join(l), LIST_DATA.replace('\r\n', ''))
def test_makeport(self):
self.client.makeport()
# IPv4 is in use, just make sure send_eprt has not been used
self.assertEqual(self.server.handler.last_received_cmd, 'port')
def test_makepasv(self):
host, port = self.client.makepasv()
conn = socket.create_connection((host, port), 10)
conn.close()
# IPv4 is in use, just make sure send_epsv has not been used
self.assertEqual(self.server.handler.last_received_cmd, 'pasv')
class TestIPv6Environment(TestCase):
def setUp(self):
self.server = DummyFTPServer((HOST, 0), af=socket.AF_INET6)
self.server.start()
self.client = ftplib.FTP()
self.client.connect(self.server.host, self.server.port)
def tearDown(self):
self.client.close()
self.server.stop()
def test_af(self):
self.assertEqual(self.client.af, socket.AF_INET6)
def test_makeport(self):
self.client.makeport()
self.assertEqual(self.server.handler.last_received_cmd, 'eprt')
def test_makepasv(self):
host, port = self.client.makepasv()
conn = socket.create_connection((host, port), 10)
conn.close()
self.assertEqual(self.server.handler.last_received_cmd, 'epsv')
def test_transfer(self):
def retr():
received = []
self.client.retrbinary('retr', received.append)
self.assertEqual(''.join(received), RETR_DATA)
self.client.set_pasv(True)
retr()
self.client.set_pasv(False)
retr()
class TestTLS_FTPClassMixin(TestFTPClass):
"""Repeat TestFTPClass tests starting the TLS layer for both control
and data connections first.
"""
def setUp(self):
self.server = DummyTLS_FTPServer((HOST, 0))
self.server.start()
self.client = ftplib.FTP_TLS(timeout=10)
self.client.connect(self.server.host, self.server.port)
# enable TLS
self.client.auth()
self.client.prot_p()
class TestTLS_FTPClass(TestCase):
"""Specific TLS_FTP class tests."""
def setUp(self):
self.server = DummyTLS_FTPServer((HOST, 0))
self.server.start()
self.client = ftplib.FTP_TLS(timeout=10)
self.client.connect(self.server.host, self.server.port)
def tearDown(self):
self.client.close()
self.server.stop()
def test_control_connection(self):
self.assertNotIsInstance(self.client.sock, ssl.SSLSocket)
self.client.auth()
self.assertIsInstance(self.client.sock, ssl.SSLSocket)
def test_data_connection(self):
# clear text
sock = self.client.transfercmd('list')
self.assertNotIsInstance(sock, ssl.SSLSocket)
sock.close()
self.assertEqual(self.client.voidresp(), "226 transfer complete")
# secured, after PROT P
self.client.prot_p()
sock = self.client.transfercmd('list')
self.assertIsInstance(sock, ssl.SSLSocket)
sock.close()
self.assertEqual(self.client.voidresp(), "226 transfer complete")
# PROT C is issued, the connection must be in cleartext again
self.client.prot_c()
sock = self.client.transfercmd('list')
self.assertNotIsInstance(sock, ssl.SSLSocket)
sock.close()
self.assertEqual(self.client.voidresp(), "226 transfer complete")
def test_login(self):
# login() is supposed to implicitly secure the control connection
self.assertNotIsInstance(self.client.sock, ssl.SSLSocket)
self.client.login()
self.assertIsInstance(self.client.sock, ssl.SSLSocket)
# make sure that AUTH TLS doesn't get issued again
self.client.login()
def test_auth_issued_twice(self):
self.client.auth()
self.assertRaises(ValueError, self.client.auth)
def test_auth_ssl(self):
try:
self.client.ssl_version = ssl.PROTOCOL_SSLv3
self.client.auth()
self.assertRaises(ValueError, self.client.auth)
finally:
self.client.ssl_version = ssl.PROTOCOL_TLSv1
class TestTimeouts(TestCase):
def setUp(self):
self.evt = threading.Event()
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.settimeout(10)
self.port = test_support.bind_port(self.sock)
threading.Thread(target=self.server, args=(self.evt,self.sock)).start()
# Wait for the server to be ready.
self.evt.wait()
self.evt.clear()
ftplib.FTP.port = self.port
def tearDown(self):
self.evt.wait()
def server(self, evt, serv):
# This method sets the evt 3 times:
# 1) when the connection is ready to be accepted.
# 2) when it is safe for the caller to close the connection
# 3) when we have closed the socket
serv.listen(5)
# (1) Signal the caller that we are ready to accept the connection.
evt.set()
try:
conn, addr = serv.accept()
except socket.timeout:
pass
else:
conn.send("1 Hola mundo\n")
# (2) Signal the caller that it is safe to close the socket.
evt.set()
conn.close()
finally:
serv.close()
# (3) Signal the caller that we are done.
evt.set()
def testTimeoutDefault(self):
# default -- use global socket timeout
self.assertTrue(socket.getdefaulttimeout() is None)
socket.setdefaulttimeout(30)
try:
ftp = ftplib.FTP("localhost")
finally:
socket.setdefaulttimeout(None)
self.assertEqual(ftp.sock.gettimeout(), 30)
self.evt.wait()
ftp.close()
def testTimeoutNone(self):
# no timeout -- do not use global socket timeout
self.assertTrue(socket.getdefaulttimeout() is None)
socket.setdefaulttimeout(30)
try:
ftp = ftplib.FTP("localhost", timeout=None)
finally:
socket.setdefaulttimeout(None)
self.assertTrue(ftp.sock.gettimeout() is None)
self.evt.wait()
ftp.close()
def testTimeoutValue(self):
# a value
ftp = ftplib.FTP(HOST, timeout=30)
self.assertEqual(ftp.sock.gettimeout(), 30)
self.evt.wait()
ftp.close()
def testTimeoutConnect(self):
ftp = ftplib.FTP()
ftp.connect(HOST, timeout=30)
self.assertEqual(ftp.sock.gettimeout(), 30)
self.evt.wait()
ftp.close()
def testTimeoutDifferentOrder(self):
ftp = ftplib.FTP(timeout=30)
ftp.connect(HOST)
self.assertEqual(ftp.sock.gettimeout(), 30)
self.evt.wait()
ftp.close()
def testTimeoutDirectAccess(self):
ftp = ftplib.FTP()
ftp.timeout = 30
ftp.connect(HOST)
self.assertEqual(ftp.sock.gettimeout(), 30)
self.evt.wait()
ftp.close()
def test_main():
tests = [TestFTPClass, TestTimeouts]
if socket.has_ipv6:
try:
DummyFTPServer((HOST, 0), af=socket.AF_INET6)
except socket.error:
pass
else:
tests.append(TestIPv6Environment)
if ssl is not None:
tests.extend([TestTLS_FTPClassMixin, TestTLS_FTPClass])
thread_info = test_support.threading_setup()
try:
test_support.run_unittest(*tests)
finally:
test_support.threading_cleanup(*thread_info)
if __name__ == '__main__':
test_main()
import httplib
import array
import httplib
import StringIO
import socket
import errno
import unittest
TestCase = unittest.TestCase
from test import test_support
HOST = test_support.HOST
class FakeSocket:
def __init__(self, text, fileclass=StringIO.StringIO):
self.text = text
self.fileclass = fileclass
self.data = ''
def sendall(self, data):
self.data += ''.join(data)
def makefile(self, mode, bufsize=None):
if mode != 'r' and mode != 'rb':
raise httplib.UnimplementedFileMode()
return self.fileclass(self.text)
class EPipeSocket(FakeSocket):
def __init__(self, text, pipe_trigger):
# When sendall() is called with pipe_trigger, raise EPIPE.
FakeSocket.__init__(self, text)
self.pipe_trigger = pipe_trigger
def sendall(self, data):
if self.pipe_trigger in data:
raise socket.error(errno.EPIPE, "gotcha")
self.data += data
def close(self):
pass
class NoEOFStringIO(StringIO.StringIO):
"""Like StringIO, but raises AssertionError on EOF.
This is used below to test that httplib doesn't try to read
more from the underlying file than it should.
"""
def read(self, n=-1):
data = StringIO.StringIO.read(self, n)
if data == '':
raise AssertionError('caller tried to read past EOF')
return data
def readline(self, length=None):
data = StringIO.StringIO.readline(self, length)
if data == '':
raise AssertionError('caller tried to read past EOF')
return data
class HeaderTests(TestCase):
def test_auto_headers(self):
# Some headers are added automatically, but should not be added by
# .request() if they are explicitly set.
class HeaderCountingBuffer(list):
def __init__(self):
self.count = {}
def append(self, item):
kv = item.split(':')
if len(kv) > 1:
# item is a 'Key: Value' header string
lcKey = kv[0].lower()
self.count.setdefault(lcKey, 0)
self.count[lcKey] += 1
list.append(self, item)
for explicit_header in True, False:
for header in 'Content-length', 'Host', 'Accept-encoding':
conn = httplib.HTTPConnection('example.com')
conn.sock = FakeSocket('blahblahblah')
conn._buffer = HeaderCountingBuffer()
body = 'spamspamspam'
headers = {}
if explicit_header:
headers[header] = str(len(body))
conn.request('POST', '/', body, headers)
self.assertEqual(conn._buffer.count[header.lower()], 1)
def test_putheader(self):
conn = httplib.HTTPConnection('example.com')
conn.sock = FakeSocket(None)
conn.putrequest('GET','/')
conn.putheader('Content-length',42)
self.assertTrue('Content-length: 42' in conn._buffer)
def test_ipv6host_header(self):
# Default host header on IPv6 transaction should wrapped by [] if
# its actual IPv6 address
expected = 'GET /foo HTTP/1.1\r\nHost: [2001::]:81\r\n' \
'Accept-Encoding: identity\r\n\r\n'
conn = httplib.HTTPConnection('[2001::]:81')
sock = FakeSocket('')
conn.sock = sock
conn.request('GET', '/foo')
self.assertTrue(sock.data.startswith(expected))
expected = 'GET /foo HTTP/1.1\r\nHost: [2001:102A::]\r\n' \
'Accept-Encoding: identity\r\n\r\n'
conn = httplib.HTTPConnection('[2001:102A::]')
sock = FakeSocket('')
conn.sock = sock
conn.request('GET', '/foo')
self.assertTrue(sock.data.startswith(expected))
class BasicTest(TestCase):
def test_status_lines(self):
# Test HTTP status lines
body = "HTTP/1.1 200 Ok\r\n\r\nText"
sock = FakeSocket(body)
resp = httplib.HTTPResponse(sock)
resp.begin()
self.assertEqual(resp.read(), 'Text')
self.assertTrue(resp.isclosed())
body = "HTTP/1.1 400.100 Not Ok\r\n\r\nText"
sock = FakeSocket(body)
resp = httplib.HTTPResponse(sock)
self.assertRaises(httplib.BadStatusLine, resp.begin)
def test_bad_status_repr(self):
exc = httplib.BadStatusLine('')
self.assertEqual(repr(exc), '''BadStatusLine("\'\'",)''')
def test_partial_reads(self):
# if we have a lenght, the system knows when to close itself
# same behaviour than when we read the whole thing with read()
body = "HTTP/1.1 200 Ok\r\nContent-Length: 4\r\n\r\nText"
sock = FakeSocket(body)
resp = httplib.HTTPResponse(sock)
resp.begin()
self.assertEqual(resp.read(2), 'Te')
self.assertFalse(resp.isclosed())
self.assertEqual(resp.read(2), 'xt')
self.assertTrue(resp.isclosed())
def test_host_port(self):
# Check invalid host_port
# Note that httplib does not accept user:password@ in the host-port.
for hp in ("www.python.org:abc", "user:password@www.python.org"):
self.assertRaises(httplib.InvalidURL, httplib.HTTP, hp)
for hp, h, p in (("[fe80::207:e9ff:fe9b]:8000", "fe80::207:e9ff:fe9b",
8000),
("www.python.org:80", "www.python.org", 80),
("www.python.org", "www.python.org", 80),
("www.python.org:", "www.python.org", 80),
("[fe80::207:e9ff:fe9b]", "fe80::207:e9ff:fe9b", 80)):
http = httplib.HTTP(hp)
c = http._conn
if h != c.host:
self.fail("Host incorrectly parsed: %s != %s" % (h, c.host))
if p != c.port:
self.fail("Port incorrectly parsed: %s != %s" % (p, c.host))
def test_response_headers(self):
# test response with multiple message headers with the same field name.
text = ('HTTP/1.1 200 OK\r\n'
'Set-Cookie: Customer="WILE_E_COYOTE";'
' Version="1"; Path="/acme"\r\n'
'Set-Cookie: Part_Number="Rocket_Launcher_0001"; Version="1";'
' Path="/acme"\r\n'
'\r\n'
'No body\r\n')
hdr = ('Customer="WILE_E_COYOTE"; Version="1"; Path="/acme"'
', '
'Part_Number="Rocket_Launcher_0001"; Version="1"; Path="/acme"')
s = FakeSocket(text)
r = httplib.HTTPResponse(s)
r.begin()
cookies = r.getheader("Set-Cookie")
if cookies != hdr:
self.fail("multiple headers not combined properly")
def test_read_head(self):
# Test that the library doesn't attempt to read any data
# from a HEAD request. (Tickles SF bug #622042.)
sock = FakeSocket(
'HTTP/1.1 200 OK\r\n'
'Content-Length: 14432\r\n'
'\r\n',
NoEOFStringIO)
resp = httplib.HTTPResponse(sock, method="HEAD")
resp.begin()
if resp.read() != "":
self.fail("Did not expect response from HEAD request")
def test_send_file(self):
expected = 'GET /foo HTTP/1.1\r\nHost: example.com\r\n' \
'Accept-Encoding: identity\r\nContent-Length:'
body = open(__file__, 'rb')
conn = httplib.HTTPConnection('example.com')
sock = FakeSocket(body)
conn.sock = sock
conn.request('GET', '/foo', body)
self.assertTrue(sock.data.startswith(expected))
def test_send(self):
expected = 'this is a test this is only a test'
conn = httplib.HTTPConnection('example.com')
sock = FakeSocket(None)
conn.sock = sock
conn.send(expected)
self.assertEqual(expected, sock.data)
sock.data = ''
conn.send(array.array('c', expected))
self.assertEqual(expected, sock.data)
sock.data = ''
conn.send(StringIO.StringIO(expected))
self.assertEqual(expected, sock.data)
def test_chunked(self):
chunked_start = (
'HTTP/1.1 200 OK\r\n'
'Transfer-Encoding: chunked\r\n\r\n'
'a\r\n'
'hello worl\r\n'
'1\r\n'
'd\r\n'
)
sock = FakeSocket(chunked_start + '0\r\n')
resp = httplib.HTTPResponse(sock, method="GET")
resp.begin()
self.assertEqual(resp.read(), 'hello world')
resp.close()
for x in ('', 'foo\r\n'):
sock = FakeSocket(chunked_start + x)
resp = httplib.HTTPResponse(sock, method="GET")
resp.begin()
try:
resp.read()
except httplib.IncompleteRead, i:
self.assertEqual(i.partial, 'hello world')
self.assertEqual(repr(i),'IncompleteRead(11 bytes read)')
self.assertEqual(str(i),'IncompleteRead(11 bytes read)')
else:
self.fail('IncompleteRead expected')
finally:
resp.close()
def test_chunked_head(self):
chunked_start = (
'HTTP/1.1 200 OK\r\n'
'Transfer-Encoding: chunked\r\n\r\n'
'a\r\n'
'hello world\r\n'
'1\r\n'
'd\r\n'
)
sock = FakeSocket(chunked_start + '0\r\n')
resp = httplib.HTTPResponse(sock, method="HEAD")
resp.begin()
self.assertEqual(resp.read(), '')
self.assertEqual(resp.status, 200)
self.assertEqual(resp.reason, 'OK')
self.assertTrue(resp.isclosed())
def test_negative_content_length(self):
sock = FakeSocket('HTTP/1.1 200 OK\r\n'
'Content-Length: -1\r\n\r\nHello\r\n')
resp = httplib.HTTPResponse(sock, method="GET")
resp.begin()
self.assertEqual(resp.read(), 'Hello\r\n')
resp.close()
def test_incomplete_read(self):
sock = FakeSocket('HTTP/1.1 200 OK\r\nContent-Length: 10\r\n\r\nHello\r\n')
resp = httplib.HTTPResponse(sock, method="GET")
resp.begin()
try:
resp.read()
except httplib.IncompleteRead as i:
self.assertEqual(i.partial, 'Hello\r\n')
self.assertEqual(repr(i),
"IncompleteRead(7 bytes read, 3 more expected)")
self.assertEqual(str(i),
"IncompleteRead(7 bytes read, 3 more expected)")
else:
self.fail('IncompleteRead expected')
finally:
resp.close()
def test_epipe(self):
sock = EPipeSocket(
"HTTP/1.0 401 Authorization Required\r\n"
"Content-type: text/html\r\n"
"WWW-Authenticate: Basic realm=\"example\"\r\n",
b"Content-Length")
conn = httplib.HTTPConnection("example.com")
conn.sock = sock
self.assertRaises(socket.error,
lambda: conn.request("PUT", "/url", "body"))
resp = conn.getresponse()
self.assertEqual(401, resp.status)
self.assertEqual("Basic realm=\"example\"",
resp.getheader("www-authenticate"))
def test_filenoattr(self):
# Just test the fileno attribute in the HTTPResponse Object.
body = "HTTP/1.1 200 Ok\r\n\r\nText"
sock = FakeSocket(body)
resp = httplib.HTTPResponse(sock)
self.assertTrue(hasattr(resp,'fileno'),
'HTTPResponse should expose a fileno attribute')
# Test lines overflowing the max line size (_MAXLINE in http.client)
def test_overflowing_status_line(self):
self.skipTest("disabled for HTTP 0.9 support")
body = "HTTP/1.1 200 Ok" + "k" * 65536 + "\r\n"
resp = httplib.HTTPResponse(FakeSocket(body))
self.assertRaises((httplib.LineTooLong, httplib.BadStatusLine), resp.begin)
def test_overflowing_header_line(self):
body = (
'HTTP/1.1 200 OK\r\n'
'X-Foo: bar' + 'r' * 65536 + '\r\n\r\n'
)
resp = httplib.HTTPResponse(FakeSocket(body))
self.assertRaises(httplib.LineTooLong, resp.begin)
def test_overflowing_chunked_line(self):
body = (
'HTTP/1.1 200 OK\r\n'
'Transfer-Encoding: chunked\r\n\r\n'
+ '0' * 65536 + 'a\r\n'
'hello world\r\n'
'0\r\n'
)
resp = httplib.HTTPResponse(FakeSocket(body))
resp.begin()
self.assertRaises(httplib.LineTooLong, resp.read)
class OfflineTest(TestCase):
def test_responses(self):
self.assertEqual(httplib.responses[httplib.NOT_FOUND], "Not Found")
class SourceAddressTest(TestCase):
def setUp(self):
self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.port = test_support.bind_port(self.serv)
self.source_port = test_support.find_unused_port()
self.serv.listen(5)
self.conn = None
def tearDown(self):
if self.conn:
self.conn.close()
self.conn = None
self.serv.close()
self.serv = None
def testHTTPConnectionSourceAddress(self):
self.conn = httplib.HTTPConnection(HOST, self.port,
source_address=('', self.source_port))
self.conn.connect()
self.assertEqual(self.conn.sock.getsockname()[1], self.source_port)
@unittest.skipIf(not hasattr(httplib, 'HTTPSConnection'),
'httplib.HTTPSConnection not defined')
def testHTTPSConnectionSourceAddress(self):
self.conn = httplib.HTTPSConnection(HOST, self.port,
source_address=('', self.source_port))
# We don't test anything here other the constructor not barfing as
# this code doesn't deal with setting up an active running SSL server
# for an ssl_wrapped connect() to actually return from.
class TimeoutTest(TestCase):
PORT = None
def setUp(self):
self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
TimeoutTest.PORT = test_support.bind_port(self.serv)
self.serv.listen(5)
def tearDown(self):
self.serv.close()
self.serv = None
def testTimeoutAttribute(self):
'''This will prove that the timeout gets through
HTTPConnection and into the socket.
'''
# default -- use global socket timeout
self.assertTrue(socket.getdefaulttimeout() is None)
socket.setdefaulttimeout(30)
try:
httpConn = httplib.HTTPConnection(HOST, TimeoutTest.PORT)
httpConn.connect()
finally:
socket.setdefaulttimeout(None)
self.assertEqual(httpConn.sock.gettimeout(), 30)
httpConn.close()
# no timeout -- do not use global socket default
self.assertTrue(socket.getdefaulttimeout() is None)
socket.setdefaulttimeout(30)
try:
httpConn = httplib.HTTPConnection(HOST, TimeoutTest.PORT,
timeout=None)
httpConn.connect()
finally:
socket.setdefaulttimeout(None)
self.assertEqual(httpConn.sock.gettimeout(), None)
httpConn.close()
# a value
httpConn = httplib.HTTPConnection(HOST, TimeoutTest.PORT, timeout=30)
httpConn.connect()
self.assertEqual(httpConn.sock.gettimeout(), 30)
httpConn.close()
class HTTPSTimeoutTest(TestCase):
# XXX Here should be tests for HTTPS, there isn't any right now!
def test_attributes(self):
# simple test to check it's storing it
if hasattr(httplib, 'HTTPSConnection'):
h = httplib.HTTPSConnection(HOST, TimeoutTest.PORT, timeout=30)
self.assertEqual(h.timeout, 30)
@unittest.skipIf(not hasattr(httplib, 'HTTPS'), 'httplib.HTTPS not available')
def test_host_port(self):
# Check invalid host_port
# Note that httplib does not accept user:password@ in the host-port.
for hp in ("www.python.org:abc", "user:password@www.python.org"):
self.assertRaises(httplib.InvalidURL, httplib.HTTP, hp)
for hp, h, p in (("[fe80::207:e9ff:fe9b]:8000", "fe80::207:e9ff:fe9b",
8000),
("pypi.python.org:443", "pypi.python.org", 443),
("pypi.python.org", "pypi.python.org", 443),
("pypi.python.org:", "pypi.python.org", 443),
("[fe80::207:e9ff:fe9b]", "fe80::207:e9ff:fe9b", 443)):
http = httplib.HTTPS(hp)
c = http._conn
if h != c.host:
self.fail("Host incorrectly parsed: %s != %s" % (h, c.host))
if p != c.port:
self.fail("Port incorrectly parsed: %s != %s" % (p, c.host))
def test_main(verbose=None):
test_support.run_unittest(HeaderTests, OfflineTest, BasicTest, TimeoutTest,
HTTPSTimeoutTest, SourceAddressTest)
if __name__ == '__main__':
test_main()
"""Unittests for the various HTTPServer modules.
Written by Cody A.W. Somerville <cody-somerville@ubuntu.com>,
Josip Dzolonga, and Michael Otteneder for the 2007/08 GHOP contest.
"""
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
from SimpleHTTPServer import SimpleHTTPRequestHandler
from CGIHTTPServer import CGIHTTPRequestHandler
import CGIHTTPServer
import os
import sys
import re
import base64
import shutil
import urllib
import httplib
import tempfile
import unittest
from StringIO import StringIO
from test import test_support
threading = test_support.import_module('threading')
class NoLogRequestHandler:
def log_message(self, *args):
# don't write log messages to stderr
pass
class SocketlessRequestHandler(SimpleHTTPRequestHandler):
def __init__(self):
self.get_called = False
self.protocol_version = "HTTP/1.1"
def do_GET(self):
self.get_called = True
self.send_response(200)
self.send_header('Content-Type', 'text/html')
self.end_headers()
self.wfile.write(b'<html><body>Data</body></html>\r\n')
def log_message(self, format, *args):
pass
class TestServerThread(threading.Thread):
def __init__(self, test_object, request_handler):
threading.Thread.__init__(self)
self.request_handler = request_handler
self.test_object = test_object
def run(self):
self.server = HTTPServer(('', 0), self.request_handler)
self.test_object.PORT = self.server.socket.getsockname()[1]
self.test_object.server_started.set()
self.test_object = None
try:
self.server.serve_forever(0.05)
finally:
self.server.server_close()
def stop(self):
self.server.shutdown()
class BaseTestCase(unittest.TestCase):
def setUp(self):
self._threads = test_support.threading_setup()
os.environ = test_support.EnvironmentVarGuard()
self.server_started = threading.Event()
self.thread = TestServerThread(self, self.request_handler)
self.thread.start()
self.server_started.wait()
def tearDown(self):
self.thread.stop()
os.environ.__exit__()
test_support.threading_cleanup(*self._threads)
def request(self, uri, method='GET', body=None, headers={}):
self.connection = httplib.HTTPConnection('localhost', self.PORT)
self.connection.request(method, uri, body, headers)
return self.connection.getresponse()
class BaseHTTPRequestHandlerTestCase(unittest.TestCase):
"""Test the functionality of the BaseHTTPServer focussing on
BaseHTTPRequestHandler.
"""
HTTPResponseMatch = re.compile('HTTP/1.[0-9]+ 200 OK')
def setUp (self):
self.handler = SocketlessRequestHandler()
def send_typical_request(self, message):
input = StringIO(message)
output = StringIO()
self.handler.rfile = input
self.handler.wfile = output
self.handler.handle_one_request()
output.seek(0)
return output.readlines()
def verify_get_called(self):
self.assertTrue(self.handler.get_called)
def verify_expected_headers(self, headers):
for fieldName in 'Server: ', 'Date: ', 'Content-Type: ':
self.assertEqual(sum(h.startswith(fieldName) for h in headers), 1)
def verify_http_server_response(self, response):
match = self.HTTPResponseMatch.search(response)
self.assertTrue(match is not None)
def test_http_1_1(self):
result = self.send_typical_request('GET / HTTP/1.1\r\n\r\n')
self.verify_http_server_response(result[0])
self.verify_expected_headers(result[1:-1])
self.verify_get_called()
self.assertEqual(result[-1], '<html><body>Data</body></html>\r\n')
def test_http_1_0(self):
result = self.send_typical_request('GET / HTTP/1.0\r\n\r\n')
self.verify_http_server_response(result[0])
self.verify_expected_headers(result[1:-1])
self.verify_get_called()
self.assertEqual(result[-1], '<html><body>Data</body></html>\r\n')
def test_http_0_9(self):
result = self.send_typical_request('GET / HTTP/0.9\r\n\r\n')
self.assertEqual(len(result), 1)
self.assertEqual(result[0], '<html><body>Data</body></html>\r\n')
self.verify_get_called()
def test_with_continue_1_0(self):
result = self.send_typical_request('GET / HTTP/1.0\r\nExpect: 100-continue\r\n\r\n')
self.verify_http_server_response(result[0])
self.verify_expected_headers(result[1:-1])
self.verify_get_called()
self.assertEqual(result[-1], '<html><body>Data</body></html>\r\n')
def test_request_length(self):
# Issue #10714: huge request lines are discarded, to avoid Denial
# of Service attacks.
result = self.send_typical_request(b'GET ' + b'x' * 65537)
self.assertEqual(result[0], b'HTTP/1.1 414 Request-URI Too Long\r\n')
self.assertFalse(self.handler.get_called)
class BaseHTTPServerTestCase(BaseTestCase):
class request_handler(NoLogRequestHandler, BaseHTTPRequestHandler):
protocol_version = 'HTTP/1.1'
default_request_version = 'HTTP/1.1'
def do_TEST(self):
self.send_response(204)
self.send_header('Content-Type', 'text/html')
self.send_header('Connection', 'close')
self.end_headers()
def do_KEEP(self):
self.send_response(204)
self.send_header('Content-Type', 'text/html')
self.send_header('Connection', 'keep-alive')
self.end_headers()
def do_KEYERROR(self):
self.send_error(999)
def do_CUSTOM(self):
self.send_response(999)
self.send_header('Content-Type', 'text/html')
self.send_header('Connection', 'close')
self.end_headers()
def setUp(self):
BaseTestCase.setUp(self)
self.con = httplib.HTTPConnection('localhost', self.PORT)
self.con.connect()
def test_command(self):
self.con.request('GET', '/')
res = self.con.getresponse()
self.assertEqual(res.status, 501)
def test_request_line_trimming(self):
self.con._http_vsn_str = 'HTTP/1.1\n'
self.con.putrequest('GET', '/')
self.con.endheaders()
res = self.con.getresponse()
self.assertEqual(res.status, 501)
def test_version_bogus(self):
self.con._http_vsn_str = 'FUBAR'
self.con.putrequest('GET', '/')
self.con.endheaders()
res = self.con.getresponse()
self.assertEqual(res.status, 400)
def test_version_digits(self):
self.con._http_vsn_str = 'HTTP/9.9.9'
self.con.putrequest('GET', '/')
self.con.endheaders()
res = self.con.getresponse()
self.assertEqual(res.status, 400)
def test_version_none_get(self):
self.con._http_vsn_str = ''
self.con.putrequest('GET', '/')
self.con.endheaders()
res = self.con.getresponse()
self.assertEqual(res.status, 501)
def test_version_none(self):
self.con._http_vsn_str = ''
self.con.putrequest('PUT', '/')
self.con.endheaders()
res = self.con.getresponse()
self.assertEqual(res.status, 400)
def test_version_invalid(self):
self.con._http_vsn = 99
self.con._http_vsn_str = 'HTTP/9.9'
self.con.putrequest('GET', '/')
self.con.endheaders()
res = self.con.getresponse()
self.assertEqual(res.status, 505)
def test_send_blank(self):
self.con._http_vsn_str = ''
self.con.putrequest('', '')
self.con.endheaders()
res = self.con.getresponse()
self.assertEqual(res.status, 400)
def test_header_close(self):
self.con.putrequest('GET', '/')
self.con.putheader('Connection', 'close')
self.con.endheaders()
res = self.con.getresponse()
self.assertEqual(res.status, 501)
def test_head_keep_alive(self):
self.con._http_vsn_str = 'HTTP/1.1'
self.con.putrequest('GET', '/')
self.con.putheader('Connection', 'keep-alive')
self.con.endheaders()
res = self.con.getresponse()
self.assertEqual(res.status, 501)
def test_handler(self):
self.con.request('TEST', '/')
res = self.con.getresponse()
self.assertEqual(res.status, 204)
def test_return_header_keep_alive(self):
self.con.request('KEEP', '/')
res = self.con.getresponse()
self.assertEqual(res.getheader('Connection'), 'keep-alive')
self.con.request('TEST', '/')
self.addCleanup(self.con.close)
def test_internal_key_error(self):
self.con.request('KEYERROR', '/')
res = self.con.getresponse()
self.assertEqual(res.status, 999)
def test_return_custom_status(self):
self.con.request('CUSTOM', '/')
res = self.con.getresponse()
self.assertEqual(res.status, 999)
class SimpleHTTPServerTestCase(BaseTestCase):
class request_handler(NoLogRequestHandler, SimpleHTTPRequestHandler):
pass
def setUp(self):
BaseTestCase.setUp(self)
self.cwd = os.getcwd()
basetempdir = tempfile.gettempdir()
os.chdir(basetempdir)
self.data = 'We are the knights who say Ni!'
self.tempdir = tempfile.mkdtemp(dir=basetempdir)
self.tempdir_name = os.path.basename(self.tempdir)
temp = open(os.path.join(self.tempdir, 'test'), 'wb')
temp.write(self.data)
temp.close()
def tearDown(self):
try:
os.chdir(self.cwd)
try:
shutil.rmtree(self.tempdir)
except:
pass
finally:
BaseTestCase.tearDown(self)
def check_status_and_reason(self, response, status, data=None):
body = response.read()
self.assertTrue(response)
self.assertEqual(response.status, status)
self.assertIsNotNone(response.reason)
if data:
self.assertEqual(data, body)
def test_get(self):
#constructs the path relative to the root directory of the HTTPServer
response = self.request(self.tempdir_name + '/test')
self.check_status_and_reason(response, 200, data=self.data)
response = self.request(self.tempdir_name + '/')
self.check_status_and_reason(response, 200)
response = self.request(self.tempdir_name)
self.check_status_and_reason(response, 301)
response = self.request('/ThisDoesNotExist')
self.check_status_and_reason(response, 404)
response = self.request('/' + 'ThisDoesNotExist' + '/')
self.check_status_and_reason(response, 404)
f = open(os.path.join(self.tempdir_name, 'index.html'), 'w')
response = self.request('/' + self.tempdir_name + '/')
self.check_status_and_reason(response, 200)
# chmod() doesn't work as expected on Windows, and filesystem
# permissions are ignored by root on Unix.
if os.name == 'posix' and os.geteuid() != 0:
os.chmod(self.tempdir, 0)
response = self.request(self.tempdir_name + '/')
self.check_status_and_reason(response, 404)
os.chmod(self.tempdir, 0755)
def test_head(self):
response = self.request(
self.tempdir_name + '/test', method='HEAD')
self.check_status_and_reason(response, 200)
self.assertEqual(response.getheader('content-length'),
str(len(self.data)))
self.assertEqual(response.getheader('content-type'),
'application/octet-stream')
def test_invalid_requests(self):
response = self.request('/', method='FOO')
self.check_status_and_reason(response, 501)
# requests must be case sensitive,so this should fail too
response = self.request('/', method='get')
self.check_status_and_reason(response, 501)
response = self.request('/', method='GETs')
self.check_status_and_reason(response, 501)
cgi_file1 = """\
#!%s
print "Content-type: text/html"
print
print "Hello World"
"""
cgi_file2 = """\
#!%s
import cgi
print "Content-type: text/html"
print
form = cgi.FieldStorage()
print "%%s, %%s, %%s" %% (form.getfirst("spam"), form.getfirst("eggs"),
form.getfirst("bacon"))
"""
@unittest.skipIf(hasattr(os, 'geteuid') and os.geteuid() == 0,
"This test can't be run reliably as root (issue #13308).")
class CGIHTTPServerTestCase(BaseTestCase):
class request_handler(NoLogRequestHandler, CGIHTTPRequestHandler):
pass
def setUp(self):
BaseTestCase.setUp(self)
self.parent_dir = tempfile.mkdtemp()
self.cgi_dir = os.path.join(self.parent_dir, 'cgi-bin')
os.mkdir(self.cgi_dir)
# The shebang line should be pure ASCII: use symlink if possible.
# See issue #7668.
if hasattr(os, 'symlink'):
self.pythonexe = os.path.join(self.parent_dir, 'python')
os.symlink(sys.executable, self.pythonexe)
else:
self.pythonexe = sys.executable
self.file1_path = os.path.join(self.cgi_dir, 'file1.py')
with open(self.file1_path, 'w') as file1:
file1.write(cgi_file1 % self.pythonexe)
os.chmod(self.file1_path, 0777)
self.file2_path = os.path.join(self.cgi_dir, 'file2.py')
with open(self.file2_path, 'w') as file2:
file2.write(cgi_file2 % self.pythonexe)
os.chmod(self.file2_path, 0777)
self.cwd = os.getcwd()
os.chdir(self.parent_dir)
def tearDown(self):
try:
os.chdir(self.cwd)
if self.pythonexe != sys.executable:
os.remove(self.pythonexe)
os.remove(self.file1_path)
os.remove(self.file2_path)
os.rmdir(self.cgi_dir)
os.rmdir(self.parent_dir)
finally:
BaseTestCase.tearDown(self)
def test_url_collapse_path_split(self):
test_vectors = {
'': ('/', ''),
'..': IndexError,
'/.//..': IndexError,
'/': ('/', ''),
'//': ('/', ''),
'/\\': ('/', '\\'),
'/.//': ('/', ''),
'cgi-bin/file1.py': ('/cgi-bin', 'file1.py'),
'/cgi-bin/file1.py': ('/cgi-bin', 'file1.py'),
'a': ('/', 'a'),
'/a': ('/', 'a'),
'//a': ('/', 'a'),
'./a': ('/', 'a'),
'./C:/': ('/C:', ''),
'/a/b': ('/a', 'b'),
'/a/b/': ('/a/b', ''),
'/a/b/c/..': ('/a/b', ''),
'/a/b/c/../d': ('/a/b', 'd'),
'/a/b/c/../d/e/../f': ('/a/b/d', 'f'),
'/a/b/c/../d/e/../../f': ('/a/b', 'f'),
'/a/b/c/../d/e/.././././..//f': ('/a/b', 'f'),
'../a/b/c/../d/e/.././././..//f': IndexError,
'/a/b/c/../d/e/../../../f': ('/a', 'f'),
'/a/b/c/../d/e/../../../../f': ('/', 'f'),
'/a/b/c/../d/e/../../../../../f': IndexError,
'/a/b/c/../d/e/../../../../f/..': ('/', ''),
}
for path, expected in test_vectors.iteritems():
if isinstance(expected, type) and issubclass(expected, Exception):
self.assertRaises(expected,
CGIHTTPServer._url_collapse_path_split, path)
else:
actual = CGIHTTPServer._url_collapse_path_split(path)
self.assertEqual(expected, actual,
msg='path = %r\nGot: %r\nWanted: %r' %
(path, actual, expected))
def test_headers_and_content(self):
res = self.request('/cgi-bin/file1.py')
self.assertEqual(('Hello World\n', 'text/html', 200),
(res.read(), res.getheader('Content-type'), res.status))
def test_post(self):
params = urllib.urlencode({'spam' : 1, 'eggs' : 'python', 'bacon' : 123456})
headers = {'Content-type' : 'application/x-www-form-urlencoded'}
res = self.request('/cgi-bin/file2.py', 'POST', params, headers)
self.assertEqual(res.read(), '1, python, 123456\n')
def test_invaliduri(self):
res = self.request('/cgi-bin/invalid')
res.read()
self.assertEqual(res.status, 404)
def test_authorization(self):
headers = {'Authorization' : 'Basic %s' %
base64.b64encode('username:pass')}
res = self.request('/cgi-bin/file1.py', 'GET', headers=headers)
self.assertEqual(('Hello World\n', 'text/html', 200),
(res.read(), res.getheader('Content-type'), res.status))
def test_no_leading_slash(self):
# http://bugs.python.org/issue2254
res = self.request('cgi-bin/file1.py')
self.assertEqual(('Hello World\n', 'text/html', 200),
(res.read(), res.getheader('Content-type'), res.status))
def test_os_environ_is_not_altered(self):
signature = "Test CGI Server"
os.environ['SERVER_SOFTWARE'] = signature
res = self.request('/cgi-bin/file1.py')
self.assertEqual((b'Hello World\n', 'text/html', 200),
(res.read(), res.getheader('Content-type'), res.status))
self.assertEqual(os.environ['SERVER_SOFTWARE'], signature)
class SimpleHTTPRequestHandlerTestCase(unittest.TestCase):
""" Test url parsing """
def setUp(self):
self.translated = os.getcwd()
self.translated = os.path.join(self.translated, 'filename')
self.handler = SocketlessRequestHandler()
def test_query_arguments(self):
path = self.handler.translate_path('/filename')
self.assertEqual(path, self.translated)
path = self.handler.translate_path('/filename?foo=bar')
self.assertEqual(path, self.translated)
path = self.handler.translate_path('/filename?a=b&spam=eggs#zot')
self.assertEqual(path, self.translated)
def test_start_with_double_slash(self):
path = self.handler.translate_path('//filename')
self.assertEqual(path, self.translated)
path = self.handler.translate_path('//filename?foo=bar')
self.assertEqual(path, self.translated)
def test_main(verbose=None):
try:
cwd = os.getcwd()
test_support.run_unittest(BaseHTTPRequestHandlerTestCase,
SimpleHTTPRequestHandlerTestCase,
BaseHTTPServerTestCase,
SimpleHTTPServerTestCase,
CGIHTTPServerTestCase
)
finally:
os.chdir(cwd)
if __name__ == '__main__':
test_main()
# Some simple queue module tests, plus some failure conditions
# to ensure the Queue locks remain stable.
import Queue
import time
import unittest
from test import test_support
threading = test_support.import_module('threading')
QUEUE_SIZE = 5
# A thread to run a function that unclogs a blocked Queue.
class _TriggerThread(threading.Thread):
def __init__(self, fn, args):
self.fn = fn
self.args = args
self.startedEvent = threading.Event()
threading.Thread.__init__(self)
def run(self):
# The sleep isn't necessary, but is intended to give the blocking
# function in the main thread a chance at actually blocking before
# we unclog it. But if the sleep is longer than the timeout-based
# tests wait in their blocking functions, those tests will fail.
# So we give them much longer timeout values compared to the
# sleep here (I aimed at 10 seconds for blocking functions --
# they should never actually wait that long - they should make
# progress as soon as we call self.fn()).
time.sleep(0.1)
self.startedEvent.set()
self.fn(*self.args)
# Execute a function that blocks, and in a separate thread, a function that
# triggers the release. Returns the result of the blocking function. Caution:
# block_func must guarantee to block until trigger_func is called, and
# trigger_func must guarantee to change queue state so that block_func can make
# enough progress to return. In particular, a block_func that just raises an
# exception regardless of whether trigger_func is called will lead to
# timing-dependent sporadic failures, and one of those went rarely seen but
# undiagnosed for years. Now block_func must be unexceptional. If block_func
# is supposed to raise an exception, call do_exceptional_blocking_test()
# instead.
class BlockingTestMixin:
def do_blocking_test(self, block_func, block_args, trigger_func, trigger_args):
self.t = _TriggerThread(trigger_func, trigger_args)
self.t.start()
self.result = block_func(*block_args)
# If block_func returned before our thread made the call, we failed!
if not self.t.startedEvent.is_set():
self.fail("blocking function '%r' appeared not to block" %
block_func)
self.t.join(10) # make sure the thread terminates
if self.t.is_alive():
self.fail("trigger function '%r' appeared to not return" %
trigger_func)
return self.result
# Call this instead if block_func is supposed to raise an exception.
def do_exceptional_blocking_test(self,block_func, block_args, trigger_func,
trigger_args, expected_exception_class):
self.t = _TriggerThread(trigger_func, trigger_args)
self.t.start()
try:
try:
block_func(*block_args)
except expected_exception_class:
raise
else:
self.fail("expected exception of kind %r" %
expected_exception_class)
finally:
self.t.join(10) # make sure the thread terminates
if self.t.is_alive():
self.fail("trigger function '%r' appeared to not return" %
trigger_func)
if not self.t.startedEvent.is_set():
self.fail("trigger thread ended but event never set")
class BaseQueueTest(unittest.TestCase, BlockingTestMixin):
def setUp(self):
self.cum = 0
self.cumlock = threading.Lock()
def simple_queue_test(self, q):
if not q.empty():
raise RuntimeError, "Call this function with an empty queue"
# I guess we better check things actually queue correctly a little :)
q.put(111)
q.put(333)
q.put(222)
target_order = dict(Queue = [111, 333, 222],
LifoQueue = [222, 333, 111],
PriorityQueue = [111, 222, 333])
actual_order = [q.get(), q.get(), q.get()]
self.assertEqual(actual_order, target_order[q.__class__.__name__],
"Didn't seem to queue the correct data!")
for i in range(QUEUE_SIZE-1):
q.put(i)
self.assertTrue(not q.empty(), "Queue should not be empty")
self.assertTrue(not q.full(), "Queue should not be full")
last = 2 * QUEUE_SIZE
full = 3 * 2 * QUEUE_SIZE
q.put(last)
self.assertTrue(q.full(), "Queue should be full")
try:
q.put(full, block=0)
self.fail("Didn't appear to block with a full queue")
except Queue.Full:
pass
try:
q.put(full, timeout=0.01)
self.fail("Didn't appear to time-out with a full queue")
except Queue.Full:
pass
# Test a blocking put
self.do_blocking_test(q.put, (full,), q.get, ())
self.do_blocking_test(q.put, (full, True, 10), q.get, ())
# Empty it
for i in range(QUEUE_SIZE):
q.get()
self.assertTrue(q.empty(), "Queue should be empty")
try:
q.get(block=0)
self.fail("Didn't appear to block with an empty queue")
except Queue.Empty:
pass
try:
q.get(timeout=0.01)
self.fail("Didn't appear to time-out with an empty queue")
except Queue.Empty:
pass
# Test a blocking get
self.do_blocking_test(q.get, (), q.put, ('empty',))
self.do_blocking_test(q.get, (True, 10), q.put, ('empty',))
def worker(self, q):
while True:
x = q.get()
if x is None:
q.task_done()
return
with self.cumlock:
self.cum += x
q.task_done()
def queue_join_test(self, q):
self.cum = 0
for i in (0,1):
threading.Thread(target=self.worker, args=(q,)).start()
for i in xrange(100):
q.put(i)
q.join()
self.assertEqual(self.cum, sum(range(100)),
"q.join() did not block until all tasks were done")
for i in (0,1):
q.put(None) # instruct the threads to close
q.join() # verify that you can join twice
def test_queue_task_done(self):
# Test to make sure a queue task completed successfully.
q = self.type2test()
try:
q.task_done()
except ValueError:
pass
else:
self.fail("Did not detect task count going negative")
def test_queue_join(self):
# Test that a queue join()s successfully, and before anything else
# (done twice for insurance).
q = self.type2test()
self.queue_join_test(q)
self.queue_join_test(q)
try:
q.task_done()
except ValueError:
pass
else:
self.fail("Did not detect task count going negative")
def test_simple_queue(self):
# Do it a couple of times on the same queue.
# Done twice to make sure works with same instance reused.
q = self.type2test(QUEUE_SIZE)
self.simple_queue_test(q)
self.simple_queue_test(q)
class QueueTest(BaseQueueTest):
type2test = Queue.Queue
class LifoQueueTest(BaseQueueTest):
type2test = Queue.LifoQueue
class PriorityQueueTest(BaseQueueTest):
type2test = Queue.PriorityQueue
# A Queue subclass that can provoke failure at a moment's notice :)
class FailingQueueException(Exception):
pass
class FailingQueue(Queue.Queue):
def __init__(self, *args):
self.fail_next_put = False
self.fail_next_get = False
Queue.Queue.__init__(self, *args)
def _put(self, item):
if self.fail_next_put:
self.fail_next_put = False
raise FailingQueueException, "You Lose"
return Queue.Queue._put(self, item)
def _get(self):
if self.fail_next_get:
self.fail_next_get = False
raise FailingQueueException, "You Lose"
return Queue.Queue._get(self)
class FailingQueueTest(unittest.TestCase, BlockingTestMixin):
def failing_queue_test(self, q):
if not q.empty():
raise RuntimeError, "Call this function with an empty queue"
for i in range(QUEUE_SIZE-1):
q.put(i)
# Test a failing non-blocking put.
q.fail_next_put = True
try:
q.put("oops", block=0)
self.fail("The queue didn't fail when it should have")
except FailingQueueException:
pass
q.fail_next_put = True
try:
q.put("oops", timeout=0.1)
self.fail("The queue didn't fail when it should have")
except FailingQueueException:
pass
q.put("last")
self.assertTrue(q.full(), "Queue should be full")
# Test a failing blocking put
q.fail_next_put = True
try:
self.do_blocking_test(q.put, ("full",), q.get, ())
self.fail("The queue didn't fail when it should have")
except FailingQueueException:
pass
# Check the Queue isn't damaged.
# put failed, but get succeeded - re-add
q.put("last")
# Test a failing timeout put
q.fail_next_put = True
try:
self.do_exceptional_blocking_test(q.put, ("full", True, 10), q.get, (),
FailingQueueException)
self.fail("The queue didn't fail when it should have")
except FailingQueueException:
pass
# Check the Queue isn't damaged.
# put failed, but get succeeded - re-add
q.put("last")
self.assertTrue(q.full(), "Queue should be full")
q.get()
self.assertTrue(not q.full(), "Queue should not be full")
q.put("last")
self.assertTrue(q.full(), "Queue should be full")
# Test a blocking put
self.do_blocking_test(q.put, ("full",), q.get, ())
# Empty it
for i in range(QUEUE_SIZE):
q.get()
self.assertTrue(q.empty(), "Queue should be empty")
q.put("first")
q.fail_next_get = True
try:
q.get()
self.fail("The queue didn't fail when it should have")
except FailingQueueException:
pass
self.assertTrue(not q.empty(), "Queue should not be empty")
q.fail_next_get = True
try:
q.get(timeout=0.1)
self.fail("The queue didn't fail when it should have")
except FailingQueueException:
pass
self.assertTrue(not q.empty(), "Queue should not be empty")
q.get()
self.assertTrue(q.empty(), "Queue should be empty")
q.fail_next_get = True
try:
self.do_exceptional_blocking_test(q.get, (), q.put, ('empty',),
FailingQueueException)
self.fail("The queue didn't fail when it should have")
except FailingQueueException:
pass
# put succeeded, but get failed.
self.assertTrue(not q.empty(), "Queue should not be empty")
q.get()
self.assertTrue(q.empty(), "Queue should be empty")
def test_failing_queue(self):
# Test to make sure a queue is functioning correctly.
# Done twice to the same instance.
q = FailingQueue(QUEUE_SIZE)
self.failing_queue_test(q)
self.failing_queue_test(q)
def test_main():
test_support.run_unittest(QueueTest, LifoQueueTest, PriorityQueueTest,
FailingQueueTest)
if __name__ == "__main__":
test_main()
from test import test_support
import unittest
import select
import os
import sys
@unittest.skipIf(sys.platform[:3] in ('win', 'os2', 'riscos'),
"can't easily test on this system")
class SelectTestCase(unittest.TestCase):
class Nope:
pass
class Almost:
def fileno(self):
return 'fileno'
def test_error_conditions(self):
self.assertRaises(TypeError, select.select, 1, 2, 3)
self.assertRaises(TypeError, select.select, [self.Nope()], [], [])
self.assertRaises(TypeError, select.select, [self.Almost()], [], [])
self.assertRaises(TypeError, select.select, [], [], [], "not a number")
def test_returned_list_identity(self):
# See issue #8329
r, w, x = select.select([], [], [], 1)
self.assertIsNot(r, w)
self.assertIsNot(r, x)
self.assertIsNot(w, x)
def test_select(self):
cmd = 'for i in 0 1 2 3 4 5 6 7 8 9; do echo testing...; sleep 1; done'
p = os.popen(cmd, 'r')
for tout in (0, 1, 2, 4, 8, 16) + (None,)*10:
if test_support.verbose:
print 'timeout =', tout
rfd, wfd, xfd = select.select([p], [], [], tout)
if (rfd, wfd, xfd) == ([], [], []):
continue
if (rfd, wfd, xfd) == ([p], [], []):
line = p.readline()
if test_support.verbose:
print repr(line)
if not line:
if test_support.verbose:
print 'EOF'
break
continue
self.fail('Unexpected return values from select():', rfd, wfd, xfd)
p.close()
def test_main():
test_support.run_unittest(SelectTestCase)
test_support.reap_children()
if __name__ == "__main__":
test_main()
import unittest
from test import test_support
from contextlib import closing
import gc
import pickle
import select
import signal
import subprocess
import traceback
import sys, os, time, errno
if sys.platform in ('os2', 'riscos'):
raise unittest.SkipTest("Can't test signal on %s" % sys.platform)
class HandlerBCalled(Exception):
pass
def exit_subprocess():
"""Use os._exit(0) to exit the current subprocess.
Otherwise, the test catches the SystemExit and continues executing
in parallel with the original test, so you wind up with an
exponential number of tests running concurrently.
"""
os._exit(0)
def ignoring_eintr(__func, *args, **kwargs):
try:
return __func(*args, **kwargs)
except EnvironmentError as e:
if e.errno != errno.EINTR:
raise
return None
@unittest.skipIf(sys.platform == "win32", "Not valid on Windows")
class InterProcessSignalTests(unittest.TestCase):
MAX_DURATION = 20 # Entire test should last at most 20 sec.
def setUp(self):
self.using_gc = gc.isenabled()
gc.disable()
def tearDown(self):
if self.using_gc:
gc.enable()
def format_frame(self, frame, limit=None):
return ''.join(traceback.format_stack(frame, limit=limit))
def handlerA(self, signum, frame):
self.a_called = True
if test_support.verbose:
print "handlerA invoked from signal %s at:\n%s" % (
signum, self.format_frame(frame, limit=1))
def handlerB(self, signum, frame):
self.b_called = True
if test_support.verbose:
print "handlerB invoked from signal %s at:\n%s" % (
signum, self.format_frame(frame, limit=1))
raise HandlerBCalled(signum, self.format_frame(frame))
def wait(self, child):
"""Wait for child to finish, ignoring EINTR."""
while True:
try:
child.wait()
return
except OSError as e:
if e.errno != errno.EINTR:
raise
def run_test(self):
# Install handlers. This function runs in a sub-process, so we
# don't worry about re-setting the default handlers.
signal.signal(signal.SIGHUP, self.handlerA)
signal.signal(signal.SIGUSR1, self.handlerB)
signal.signal(signal.SIGUSR2, signal.SIG_IGN)
signal.signal(signal.SIGALRM, signal.default_int_handler)
# Variables the signals will modify:
self.a_called = False
self.b_called = False
# Let the sub-processes know who to send signals to.
pid = os.getpid()
if test_support.verbose:
print "test runner's pid is", pid
child = ignoring_eintr(subprocess.Popen, ['kill', '-HUP', str(pid)])
if child:
self.wait(child)
if not self.a_called:
time.sleep(1) # Give the signal time to be delivered.
self.assertTrue(self.a_called)
self.assertFalse(self.b_called)
self.a_called = False
# Make sure the signal isn't delivered while the previous
# Popen object is being destroyed, because __del__ swallows
# exceptions.
del child
try:
child = subprocess.Popen(['kill', '-USR1', str(pid)])
# This wait should be interrupted by the signal's exception.
self.wait(child)
time.sleep(1) # Give the signal time to be delivered.
self.fail('HandlerBCalled exception not thrown')
except HandlerBCalled:
self.assertTrue(self.b_called)
self.assertFalse(self.a_called)
if test_support.verbose:
print "HandlerBCalled exception caught"
child = ignoring_eintr(subprocess.Popen, ['kill', '-USR2', str(pid)])
if child:
self.wait(child) # Nothing should happen.
try:
signal.alarm(1)
# The race condition in pause doesn't matter in this case,
# since alarm is going to raise a KeyboardException, which
# will skip the call.
signal.pause()
# But if another signal arrives before the alarm, pause
# may return early.
time.sleep(1)
except KeyboardInterrupt:
if test_support.verbose:
print "KeyboardInterrupt (the alarm() went off)"
except:
self.fail("Some other exception woke us from pause: %s" %
traceback.format_exc())
else:
self.fail("pause returned of its own accord, and the signal"
" didn't arrive after another second.")
# Issue 3864. Unknown if this affects earlier versions of freebsd also.
@unittest.skipIf(sys.platform=='freebsd6',
'inter process signals not reliable (do not mix well with threading) '
'on freebsd6')
def test_main(self):
# This function spawns a child process to insulate the main
# test-running process from all the signals. It then
# communicates with that child process over a pipe and
# re-raises information about any exceptions the child
# throws. The real work happens in self.run_test().
os_done_r, os_done_w = os.pipe()
with closing(os.fdopen(os_done_r)) as done_r, \
closing(os.fdopen(os_done_w, 'w')) as done_w:
child = os.fork()
if child == 0:
# In the child process; run the test and report results
# through the pipe.
try:
done_r.close()
# Have to close done_w again here because
# exit_subprocess() will skip the enclosing with block.
with closing(done_w):
try:
self.run_test()
except:
pickle.dump(traceback.format_exc(), done_w)
else:
pickle.dump(None, done_w)
except:
print 'Uh oh, raised from pickle.'
traceback.print_exc()
finally:
exit_subprocess()
done_w.close()
# Block for up to MAX_DURATION seconds for the test to finish.
r, w, x = select.select([done_r], [], [], self.MAX_DURATION)
if done_r in r:
tb = pickle.load(done_r)
if tb:
self.fail(tb)
else:
os.kill(child, signal.SIGKILL)
self.fail('Test deadlocked after %d seconds.' %
self.MAX_DURATION)
@unittest.skipIf(sys.platform == "win32", "Not valid on Windows")
class BasicSignalTests(unittest.TestCase):
def trivial_signal_handler(self, *args):
pass
def test_out_of_range_signal_number_raises_error(self):
self.assertRaises(ValueError, signal.getsignal, 4242)
self.assertRaises(ValueError, signal.signal, 4242,
self.trivial_signal_handler)
def test_setting_signal_handler_to_none_raises_error(self):
self.assertRaises(TypeError, signal.signal,
signal.SIGUSR1, None)
def test_getsignal(self):
hup = signal.signal(signal.SIGHUP, self.trivial_signal_handler)
self.assertEqual(signal.getsignal(signal.SIGHUP),
self.trivial_signal_handler)
signal.signal(signal.SIGHUP, hup)
self.assertEqual(signal.getsignal(signal.SIGHUP), hup)
@unittest.skipUnless(sys.platform == "win32", "Windows specific")
class WindowsSignalTests(unittest.TestCase):
def test_issue9324(self):
# Updated for issue #10003, adding SIGBREAK
handler = lambda x, y: None
for sig in (signal.SIGABRT, signal.SIGBREAK, signal.SIGFPE,
signal.SIGILL, signal.SIGINT, signal.SIGSEGV,
signal.SIGTERM):
# Set and then reset a handler for signals that work on windows
signal.signal(sig, signal.signal(sig, handler))
with self.assertRaises(ValueError):
signal.signal(-1, handler)
with self.assertRaises(ValueError):
signal.signal(7, handler)
@unittest.skipIf(sys.platform == "win32", "Not valid on Windows")
class WakeupSignalTests(unittest.TestCase):
TIMEOUT_FULL = 10
TIMEOUT_HALF = 5
def test_wakeup_fd_early(self):
import select
signal.alarm(1)
before_time = time.time()
# We attempt to get a signal during the sleep,
# before select is called
time.sleep(self.TIMEOUT_FULL)
mid_time = time.time()
self.assertTrue(mid_time - before_time < self.TIMEOUT_HALF)
select.select([self.read], [], [], self.TIMEOUT_FULL)
after_time = time.time()
self.assertTrue(after_time - mid_time < self.TIMEOUT_HALF)
def test_wakeup_fd_during(self):
import select
signal.alarm(1)
before_time = time.time()
# We attempt to get a signal during the select call
self.assertRaises(select.error, select.select,
[self.read], [], [], self.TIMEOUT_FULL)
after_time = time.time()
self.assertTrue(after_time - before_time < self.TIMEOUT_HALF)
def setUp(self):
import fcntl
self.alrm = signal.signal(signal.SIGALRM, lambda x,y:None)
self.read, self.write = os.pipe()
flags = fcntl.fcntl(self.write, fcntl.F_GETFL, 0)
flags = flags | os.O_NONBLOCK
fcntl.fcntl(self.write, fcntl.F_SETFL, flags)
self.old_wakeup = signal.set_wakeup_fd(self.write)
def tearDown(self):
signal.set_wakeup_fd(self.old_wakeup)
os.close(self.read)
os.close(self.write)
signal.signal(signal.SIGALRM, self.alrm)
@unittest.skipIf(sys.platform == "win32", "Not valid on Windows")
class SiginterruptTest(unittest.TestCase):
def setUp(self):
"""Install a no-op signal handler that can be set to allow
interrupts or not, and arrange for the original signal handler to be
re-installed when the test is finished.
"""
self.signum = signal.SIGUSR1
oldhandler = signal.signal(self.signum, lambda x,y: None)
self.addCleanup(signal.signal, self.signum, oldhandler)
def readpipe_interrupted(self):
"""Perform a read during which a signal will arrive. Return True if the
read is interrupted by the signal and raises an exception. Return False
if it returns normally.
"""
# Create a pipe that can be used for the read. Also clean it up
# when the test is over, since nothing else will (but see below for
# the write end).
r, w = os.pipe()
self.addCleanup(os.close, r)
# Create another process which can send a signal to this one to try
# to interrupt the read.
ppid = os.getpid()
pid = os.fork()
if pid == 0:
# Child code: sleep to give the parent enough time to enter the
# read() call (there's a race here, but it's really tricky to
# eliminate it); then signal the parent process. Also, sleep
# again to make it likely that the signal is delivered to the
# parent process before the child exits. If the child exits
# first, the write end of the pipe will be closed and the test
# is invalid.
try:
time.sleep(0.2)
os.kill(ppid, self.signum)
time.sleep(0.2)
finally:
# No matter what, just exit as fast as possible now.
exit_subprocess()
else:
# Parent code.
# Make sure the child is eventually reaped, else it'll be a
# zombie for the rest of the test suite run.
self.addCleanup(os.waitpid, pid, 0)
# Close the write end of the pipe. The child has a copy, so
# it's not really closed until the child exits. We need it to
# close when the child exits so that in the non-interrupt case
# the read eventually completes, otherwise we could just close
# it *after* the test.
os.close(w)
# Try the read and report whether it is interrupted or not to
# the caller.
try:
d = os.read(r, 1)
return False
except OSError, err:
if err.errno != errno.EINTR:
raise
return True
def test_without_siginterrupt(self):
"""If a signal handler is installed and siginterrupt is not called
at all, when that signal arrives, it interrupts a syscall that's in
progress.
"""
i = self.readpipe_interrupted()
self.assertTrue(i)
# Arrival of the signal shouldn't have changed anything.
i = self.readpipe_interrupted()
self.assertTrue(i)
def test_siginterrupt_on(self):
"""If a signal handler is installed and siginterrupt is called with
a true value for the second argument, when that signal arrives, it
interrupts a syscall that's in progress.
"""
signal.siginterrupt(self.signum, 1)
i = self.readpipe_interrupted()
self.assertTrue(i)
# Arrival of the signal shouldn't have changed anything.
i = self.readpipe_interrupted()
self.assertTrue(i)
def test_siginterrupt_off(self):
"""If a signal handler is installed and siginterrupt is called with
a false value for the second argument, when that signal arrives, it
does not interrupt a syscall that's in progress.
"""
signal.siginterrupt(self.signum, 0)
i = self.readpipe_interrupted()
self.assertFalse(i)
# Arrival of the signal shouldn't have changed anything.
i = self.readpipe_interrupted()
self.assertFalse(i)
@unittest.skipIf(sys.platform == "win32", "Not valid on Windows")
class ItimerTest(unittest.TestCase):
def setUp(self):
self.hndl_called = False
self.hndl_count = 0
self.itimer = None
self.old_alarm = signal.signal(signal.SIGALRM, self.sig_alrm)
def tearDown(self):
signal.signal(signal.SIGALRM, self.old_alarm)
if self.itimer is not None: # test_itimer_exc doesn't change this attr
# just ensure that itimer is stopped
signal.setitimer(self.itimer, 0)
def sig_alrm(self, *args):
self.hndl_called = True
if test_support.verbose:
print("SIGALRM handler invoked", args)
def sig_vtalrm(self, *args):
self.hndl_called = True
if self.hndl_count > 3:
# it shouldn't be here, because it should have been disabled.
raise signal.ItimerError("setitimer didn't disable ITIMER_VIRTUAL "
"timer.")
elif self.hndl_count == 3:
# disable ITIMER_VIRTUAL, this function shouldn't be called anymore
signal.setitimer(signal.ITIMER_VIRTUAL, 0)
if test_support.verbose:
print("last SIGVTALRM handler call")
self.hndl_count += 1
if test_support.verbose:
print("SIGVTALRM handler invoked", args)
def sig_prof(self, *args):
self.hndl_called = True
signal.setitimer(signal.ITIMER_PROF, 0)
if test_support.verbose:
print("SIGPROF handler invoked", args)
def test_itimer_exc(self):
# XXX I'm assuming -1 is an invalid itimer, but maybe some platform
# defines it ?
self.assertRaises(signal.ItimerError, signal.setitimer, -1, 0)
# Negative times are treated as zero on some platforms.
if 0:
self.assertRaises(signal.ItimerError,
signal.setitimer, signal.ITIMER_REAL, -1)
def test_itimer_real(self):
self.itimer = signal.ITIMER_REAL
signal.setitimer(self.itimer, 1.0)
if test_support.verbose:
print("\ncall pause()...")
signal.pause()
self.assertEqual(self.hndl_called, True)
# Issue 3864. Unknown if this affects earlier versions of freebsd also.
@unittest.skipIf(sys.platform in ('freebsd6', 'netbsd5'),
'itimer not reliable (does not mix well with threading) on some BSDs.')
def test_itimer_virtual(self):
self.itimer = signal.ITIMER_VIRTUAL
signal.signal(signal.SIGVTALRM, self.sig_vtalrm)
signal.setitimer(self.itimer, 0.3, 0.2)
start_time = time.time()
while time.time() - start_time < 60.0:
# use up some virtual time by doing real work
_ = pow(12345, 67890, 10000019)
if signal.getitimer(self.itimer) == (0.0, 0.0):
break # sig_vtalrm handler stopped this itimer
else: # Issue 8424
self.skipTest("timeout: likely cause: machine too slow or load too "
"high")
# virtual itimer should be (0.0, 0.0) now
self.assertEqual(signal.getitimer(self.itimer), (0.0, 0.0))
# and the handler should have been called
self.assertEqual(self.hndl_called, True)
# Issue 3864. Unknown if this affects earlier versions of freebsd also.
@unittest.skipIf(sys.platform=='freebsd6',
'itimer not reliable (does not mix well with threading) on freebsd6')
def test_itimer_prof(self):
self.itimer = signal.ITIMER_PROF
signal.signal(signal.SIGPROF, self.sig_prof)
signal.setitimer(self.itimer, 0.2, 0.2)
start_time = time.time()
while time.time() - start_time < 60.0:
# do some work
_ = pow(12345, 67890, 10000019)
if signal.getitimer(self.itimer) == (0.0, 0.0):
break # sig_prof handler stopped this itimer
else: # Issue 8424
self.skipTest("timeout: likely cause: machine too slow or load too "
"high")
# profiling itimer should be (0.0, 0.0) now
self.assertEqual(signal.getitimer(self.itimer), (0.0, 0.0))
# and the handler should have been called
self.assertEqual(self.hndl_called, True)
def test_main():
test_support.run_unittest(BasicSignalTests, InterProcessSignalTests,
WakeupSignalTests, SiginterruptTest,
ItimerTest, WindowsSignalTests)
if __name__ == "__main__":
test_main()
import asyncore
import email.utils
import socket
import smtpd
import smtplib
import StringIO
import sys
import time
import select
import unittest
from test import test_support
try:
import threading
except ImportError:
threading = None
HOST = test_support.HOST
def server(evt, buf, serv):
serv.listen(5)
evt.set()
try:
conn, addr = serv.accept()
except socket.timeout:
pass
else:
n = 500
while buf and n > 0:
r, w, e = select.select([], [conn], [])
if w:
sent = conn.send(buf)
buf = buf[sent:]
n -= 1
conn.close()
finally:
serv.close()
evt.set()
@unittest.skipUnless(threading, 'Threading required for this test.')
class GeneralTests(unittest.TestCase):
def setUp(self):
self._threads = test_support.threading_setup()
self.evt = threading.Event()
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.settimeout(15)
self.port = test_support.bind_port(self.sock)
servargs = (self.evt, "220 Hola mundo\n", self.sock)
self.thread = threading.Thread(target=server, args=servargs)
self.thread.start()
self.evt.wait()
self.evt.clear()
def tearDown(self):
self.evt.wait()
self.thread.join()
test_support.threading_cleanup(*self._threads)
def testBasic1(self):
# connects
smtp = smtplib.SMTP(HOST, self.port)
smtp.close()
def testBasic2(self):
# connects, include port in host name
smtp = smtplib.SMTP("%s:%s" % (HOST, self.port))
smtp.close()
def testLocalHostName(self):
# check that supplied local_hostname is used
smtp = smtplib.SMTP(HOST, self.port, local_hostname="testhost")
self.assertEqual(smtp.local_hostname, "testhost")
smtp.close()
def testTimeoutDefault(self):
self.assertTrue(socket.getdefaulttimeout() is None)
socket.setdefaulttimeout(30)
try:
smtp = smtplib.SMTP(HOST, self.port)
finally:
socket.setdefaulttimeout(None)
self.assertEqual(smtp.sock.gettimeout(), 30)
smtp.close()
def testTimeoutNone(self):
self.assertTrue(socket.getdefaulttimeout() is None)
socket.setdefaulttimeout(30)
try:
smtp = smtplib.SMTP(HOST, self.port, timeout=None)
finally:
socket.setdefaulttimeout(None)
self.assertTrue(smtp.sock.gettimeout() is None)
smtp.close()
def testTimeoutValue(self):
smtp = smtplib.SMTP(HOST, self.port, timeout=30)
self.assertEqual(smtp.sock.gettimeout(), 30)
smtp.close()
# Test server thread using the specified SMTP server class
def debugging_server(serv, serv_evt, client_evt):
serv_evt.set()
try:
if hasattr(select, 'poll'):
poll_fun = asyncore.poll2
else:
poll_fun = asyncore.poll
n = 1000
while asyncore.socket_map and n > 0:
poll_fun(0.01, asyncore.socket_map)
# when the client conversation is finished, it will
# set client_evt, and it's then ok to kill the server
if client_evt.is_set():
serv.close()
break
n -= 1
except socket.timeout:
pass
finally:
if not client_evt.is_set():
# allow some time for the client to read the result
time.sleep(0.5)
serv.close()
asyncore.close_all()
serv_evt.set()
MSG_BEGIN = '---------- MESSAGE FOLLOWS ----------\n'
MSG_END = '------------ END MESSAGE ------------\n'
# NOTE: Some SMTP objects in the tests below are created with a non-default
# local_hostname argument to the constructor, since (on some systems) the FQDN
# lookup caused by the default local_hostname sometimes takes so long that the
# test server times out, causing the test to fail.
# Test behavior of smtpd.DebuggingServer
@unittest.skipUnless(threading, 'Threading required for this test.')
class DebuggingServerTests(unittest.TestCase):
def setUp(self):
# temporarily replace sys.stdout to capture DebuggingServer output
self.old_stdout = sys.stdout
self.output = StringIO.StringIO()
sys.stdout = self.output
self._threads = test_support.threading_setup()
self.serv_evt = threading.Event()
self.client_evt = threading.Event()
# Pick a random unused port by passing 0 for the port number
self.serv = smtpd.DebuggingServer((HOST, 0), ('nowhere', -1))
# Keep a note of what port was assigned
self.port = self.serv.socket.getsockname()[1]
serv_args = (self.serv, self.serv_evt, self.client_evt)
self.thread = threading.Thread(target=debugging_server, args=serv_args)
self.thread.start()
# wait until server thread has assigned a port number
self.serv_evt.wait()
self.serv_evt.clear()
def tearDown(self):
# indicate that the client is finished
self.client_evt.set()
# wait for the server thread to terminate
self.serv_evt.wait()
self.thread.join()
test_support.threading_cleanup(*self._threads)
# restore sys.stdout
sys.stdout = self.old_stdout
def testBasic(self):
# connect
smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=3)
smtp.quit()
def testNOOP(self):
smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=3)
expected = (250, 'Ok')
self.assertEqual(smtp.noop(), expected)
smtp.quit()
def testRSET(self):
smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=3)
expected = (250, 'Ok')
self.assertEqual(smtp.rset(), expected)
smtp.quit()
def testNotImplemented(self):
# EHLO isn't implemented in DebuggingServer
smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=3)
expected = (502, 'Error: command "EHLO" not implemented')
self.assertEqual(smtp.ehlo(), expected)
smtp.quit()
def testVRFY(self):
# VRFY isn't implemented in DebuggingServer
smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=3)
expected = (502, 'Error: command "VRFY" not implemented')
self.assertEqual(smtp.vrfy('nobody@nowhere.com'), expected)
self.assertEqual(smtp.verify('nobody@nowhere.com'), expected)
smtp.quit()
def testSecondHELO(self):
# check that a second HELO returns a message that it's a duplicate
# (this behavior is specific to smtpd.SMTPChannel)
smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=3)
smtp.helo()
expected = (503, 'Duplicate HELO/EHLO')
self.assertEqual(smtp.helo(), expected)
smtp.quit()
def testHELP(self):
smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=3)
self.assertEqual(smtp.help(), 'Error: command "HELP" not implemented')
smtp.quit()
def testSend(self):
# connect and send mail
m = 'A test message'
smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=3)
smtp.sendmail('John', 'Sally', m)
# XXX(nnorwitz): this test is flaky and dies with a bad file descriptor
# in asyncore. This sleep might help, but should really be fixed
# properly by using an Event variable.
time.sleep(0.01)
smtp.quit()
self.client_evt.set()
self.serv_evt.wait()
self.output.flush()
mexpect = '%s%s\n%s' % (MSG_BEGIN, m, MSG_END)
self.assertEqual(self.output.getvalue(), mexpect)
class NonConnectingTests(unittest.TestCase):
def testNotConnected(self):
# Test various operations on an unconnected SMTP object that
# should raise exceptions (at present the attempt in SMTP.send
# to reference the nonexistent 'sock' attribute of the SMTP object
# causes an AttributeError)
smtp = smtplib.SMTP()
self.assertRaises(smtplib.SMTPServerDisconnected, smtp.ehlo)
self.assertRaises(smtplib.SMTPServerDisconnected,
smtp.send, 'test msg')
def testNonnumericPort(self):
# check that non-numeric port raises socket.error
self.assertRaises(socket.error, smtplib.SMTP,
"localhost", "bogus")
self.assertRaises(socket.error, smtplib.SMTP,
"localhost:bogus")
# test response of client to a non-successful HELO message
@unittest.skipUnless(threading, 'Threading required for this test.')
class BadHELOServerTests(unittest.TestCase):
def setUp(self):
self.old_stdout = sys.stdout
self.output = StringIO.StringIO()
sys.stdout = self.output
self._threads = test_support.threading_setup()
self.evt = threading.Event()
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.settimeout(15)
self.port = test_support.bind_port(self.sock)
servargs = (self.evt, "199 no hello for you!\n", self.sock)
self.thread = threading.Thread(target=server, args=servargs)
self.thread.start()
self.evt.wait()
self.evt.clear()
def tearDown(self):
self.evt.wait()
self.thread.join()
test_support.threading_cleanup(*self._threads)
sys.stdout = self.old_stdout
def testFailingHELO(self):
self.assertRaises(smtplib.SMTPConnectError, smtplib.SMTP,
HOST, self.port, 'localhost', 3)
sim_users = {'Mr.A@somewhere.com':'John A',
'Ms.B@somewhere.com':'Sally B',
'Mrs.C@somewhereesle.com':'Ruth C',
}
sim_auth = ('Mr.A@somewhere.com', 'somepassword')
sim_cram_md5_challenge = ('PENCeUxFREJoU0NnbmhNWitOMjNGNn'
'dAZWx3b29kLmlubm9zb2Z0LmNvbT4=')
sim_auth_credentials = {
'login': 'TXIuQUBzb21ld2hlcmUuY29t',
'plain': 'AE1yLkFAc29tZXdoZXJlLmNvbQBzb21lcGFzc3dvcmQ=',
'cram-md5': ('TXIUQUBZB21LD2HLCMUUY29TIDG4OWQ0MJ'
'KWZGQ4ODNMNDA4NTGXMDRLZWMYZJDMODG1'),
}
sim_auth_login_password = 'C29TZXBHC3N3B3JK'
sim_lists = {'list-1':['Mr.A@somewhere.com','Mrs.C@somewhereesle.com'],
'list-2':['Ms.B@somewhere.com',],
}
# Simulated SMTP channel & server
class SimSMTPChannel(smtpd.SMTPChannel):
def __init__(self, extra_features, *args, **kw):
self._extrafeatures = ''.join(
[ "250-{0}\r\n".format(x) for x in extra_features ])
smtpd.SMTPChannel.__init__(self, *args, **kw)
def smtp_EHLO(self, arg):
resp = ('250-testhost\r\n'
'250-EXPN\r\n'
'250-SIZE 20000000\r\n'
'250-STARTTLS\r\n'
'250-DELIVERBY\r\n')
resp = resp + self._extrafeatures + '250 HELP'
self.push(resp)
def smtp_VRFY(self, arg):
# For max compatibility smtplib should be sending the raw address.
if arg in sim_users:
self.push('250 %s %s' % (sim_users[arg], smtplib.quoteaddr(arg)))
else:
self.push('550 No such user: %s' % arg)
def smtp_EXPN(self, arg):
list_name = arg.lower()
if list_name in sim_lists:
user_list = sim_lists[list_name]
for n, user_email in enumerate(user_list):
quoted_addr = smtplib.quoteaddr(user_email)
if n < len(user_list) - 1:
self.push('250-%s %s' % (sim_users[user_email], quoted_addr))
else:
self.push('250 %s %s' % (sim_users[user_email], quoted_addr))
else:
self.push('550 No access for you!')
def smtp_AUTH(self, arg):
if arg.strip().lower()=='cram-md5':
self.push('334 {0}'.format(sim_cram_md5_challenge))
return
mech, auth = arg.split()
mech = mech.lower()
if mech not in sim_auth_credentials:
self.push('504 auth type unimplemented')
return
if mech == 'plain' and auth==sim_auth_credentials['plain']:
self.push('235 plain auth ok')
elif mech=='login' and auth==sim_auth_credentials['login']:
self.push('334 Password:')
else:
self.push('550 No access for you!')
def handle_error(self):
raise
class SimSMTPServer(smtpd.SMTPServer):
def __init__(self, *args, **kw):
self._extra_features = []
smtpd.SMTPServer.__init__(self, *args, **kw)
def handle_accept(self):
conn, addr = self.accept()
self._SMTPchannel = SimSMTPChannel(self._extra_features,
self, conn, addr)
def process_message(self, peer, mailfrom, rcpttos, data):
pass
def add_feature(self, feature):
self._extra_features.append(feature)
def handle_error(self):
raise
# Test various SMTP & ESMTP commands/behaviors that require a simulated server
# (i.e., something with more features than DebuggingServer)
@unittest.skipUnless(threading, 'Threading required for this test.')
class SMTPSimTests(unittest.TestCase):
def setUp(self):
self._threads = test_support.threading_setup()
self.serv_evt = threading.Event()
self.client_evt = threading.Event()
# Pick a random unused port by passing 0 for the port number
self.serv = SimSMTPServer((HOST, 0), ('nowhere', -1))
# Keep a note of what port was assigned
self.port = self.serv.socket.getsockname()[1]
serv_args = (self.serv, self.serv_evt, self.client_evt)
self.thread = threading.Thread(target=debugging_server, args=serv_args)
self.thread.start()
# wait until server thread has assigned a port number
self.serv_evt.wait()
self.serv_evt.clear()
def tearDown(self):
# indicate that the client is finished
self.client_evt.set()
# wait for the server thread to terminate
self.serv_evt.wait()
self.thread.join()
test_support.threading_cleanup(*self._threads)
def testBasic(self):
# smoke test
smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=15)
smtp.quit()
def testEHLO(self):
smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=15)
# no features should be present before the EHLO
self.assertEqual(smtp.esmtp_features, {})
# features expected from the test server
expected_features = {'expn':'',
'size': '20000000',
'starttls': '',
'deliverby': '',
'help': '',
}
smtp.ehlo()
self.assertEqual(smtp.esmtp_features, expected_features)
for k in expected_features:
self.assertTrue(smtp.has_extn(k))
self.assertFalse(smtp.has_extn('unsupported-feature'))
smtp.quit()
def testVRFY(self):
smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=15)
for email, name in sim_users.items():
expected_known = (250, '%s %s' % (name, smtplib.quoteaddr(email)))
self.assertEqual(smtp.vrfy(email), expected_known)
u = 'nobody@nowhere.com'
expected_unknown = (550, 'No such user: %s' % u)
self.assertEqual(smtp.vrfy(u), expected_unknown)
smtp.quit()
def testEXPN(self):
smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=15)
for listname, members in sim_lists.items():
users = []
for m in members:
users.append('%s %s' % (sim_users[m], smtplib.quoteaddr(m)))
expected_known = (250, '\n'.join(users))
self.assertEqual(smtp.expn(listname), expected_known)
u = 'PSU-Members-List'
expected_unknown = (550, 'No access for you!')
self.assertEqual(smtp.expn(u), expected_unknown)
smtp.quit()
def testAUTH_PLAIN(self):
self.serv.add_feature("AUTH PLAIN")
smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=15)
expected_auth_ok = (235, b'plain auth ok')
self.assertEqual(smtp.login(sim_auth[0], sim_auth[1]), expected_auth_ok)
# SimSMTPChannel doesn't fully support LOGIN or CRAM-MD5 auth because they
# require a synchronous read to obtain the credentials...so instead smtpd
# sees the credential sent by smtplib's login method as an unknown command,
# which results in smtplib raising an auth error. Fortunately the error
# message contains the encoded credential, so we can partially check that it
# was generated correctly (partially, because the 'word' is uppercased in
# the error message).
def testAUTH_LOGIN(self):
self.serv.add_feature("AUTH LOGIN")
smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=15)
try: smtp.login(sim_auth[0], sim_auth[1])
except smtplib.SMTPAuthenticationError as err:
if sim_auth_login_password not in str(err):
raise "expected encoded password not found in error message"
def testAUTH_CRAM_MD5(self):
self.serv.add_feature("AUTH CRAM-MD5")
smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=15)
try: smtp.login(sim_auth[0], sim_auth[1])
except smtplib.SMTPAuthenticationError as err:
if sim_auth_credentials['cram-md5'] not in str(err):
raise "expected encoded credentials not found in error message"
#TODO: add tests for correct AUTH method fallback now that the
#test infrastructure can support it.
def test_main(verbose=None):
test_support.run_unittest(GeneralTests, DebuggingServerTests,
NonConnectingTests,
BadHELOServerTests, SMTPSimTests)
if __name__ == '__main__':
test_main()
#!/usr/bin/env python
import unittest
from test import test_support
import errno
import socket
import select
import time
import traceback
import Queue
import sys
import os
import array
import contextlib
from weakref import proxy
import signal
import math
def try_address(host, port=0, family=socket.AF_INET):
"""Try to bind a socket on the given host:port and return True
if that has been possible."""
try:
sock = socket.socket(family, socket.SOCK_STREAM)
sock.bind((host, port))
except (socket.error, socket.gaierror):
return False
else:
sock.close()
return True
HOST = test_support.HOST
MSG = b'Michael Gilfix was here\n'
SUPPORTS_IPV6 = socket.has_ipv6 and try_address('::1', family=socket.AF_INET6)
try:
import thread
import threading
except ImportError:
thread = None
threading = None
HOST = test_support.HOST
MSG = 'Michael Gilfix was here\n'
class SocketTCPTest(unittest.TestCase):
def setUp(self):
self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.port = test_support.bind_port(self.serv)
self.serv.listen(1)
def tearDown(self):
self.serv.close()
self.serv = None
class SocketUDPTest(unittest.TestCase):
def setUp(self):
self.serv = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.port = test_support.bind_port(self.serv)
def tearDown(self):
self.serv.close()
self.serv = None
class ThreadableTest:
"""Threadable Test class
The ThreadableTest class makes it easy to create a threaded
client/server pair from an existing unit test. To create a
new threaded class from an existing unit test, use multiple
inheritance:
class NewClass (OldClass, ThreadableTest):
pass
This class defines two new fixture functions with obvious
purposes for overriding:
clientSetUp ()
clientTearDown ()
Any new test functions within the class must then define
tests in pairs, where the test name is preceeded with a
'_' to indicate the client portion of the test. Ex:
def testFoo(self):
# Server portion
def _testFoo(self):
# Client portion
Any exceptions raised by the clients during their tests
are caught and transferred to the main thread to alert
the testing framework.
Note, the server setup function cannot call any blocking
functions that rely on the client thread during setup,
unless serverExplicitReady() is called just before
the blocking call (such as in setting up a client/server
connection and performing the accept() in setUp().
"""
def __init__(self):
# Swap the true setup function
self.__setUp = self.setUp
self.__tearDown = self.tearDown
self.setUp = self._setUp
self.tearDown = self._tearDown
def serverExplicitReady(self):
"""This method allows the server to explicitly indicate that
it wants the client thread to proceed. This is useful if the
server is about to execute a blocking routine that is
dependent upon the client thread during its setup routine."""
self.server_ready.set()
def _setUp(self):
self.server_ready = threading.Event()
self.client_ready = threading.Event()
self.done = threading.Event()
self.queue = Queue.Queue(1)
# Do some munging to start the client test.
methodname = self.id()
i = methodname.rfind('.')
methodname = methodname[i+1:]
test_method = getattr(self, '_' + methodname)
self.client_thread = thread.start_new_thread(
self.clientRun, (test_method,))
self.__setUp()
if not self.server_ready.is_set():
self.server_ready.set()
self.client_ready.wait()
def _tearDown(self):
self.__tearDown()
self.done.wait()
if not self.queue.empty():
msg = self.queue.get()
self.fail(msg)
def clientRun(self, test_func):
self.server_ready.wait()
self.clientSetUp()
self.client_ready.set()
if not callable(test_func):
raise TypeError("test_func must be a callable function.")
try:
test_func()
except Exception, strerror:
self.queue.put(strerror)
self.clientTearDown()
def clientSetUp(self):
raise NotImplementedError("clientSetUp must be implemented.")
def clientTearDown(self):
self.done.set()
thread.exit()
class ThreadedTCPSocketTest(SocketTCPTest, ThreadableTest):
def __init__(self, methodName='runTest'):
SocketTCPTest.__init__(self, methodName=methodName)
ThreadableTest.__init__(self)
def clientSetUp(self):
self.cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
def clientTearDown(self):
self.cli.close()
self.cli = None
ThreadableTest.clientTearDown(self)
class ThreadedUDPSocketTest(SocketUDPTest, ThreadableTest):
def __init__(self, methodName='runTest'):
SocketUDPTest.__init__(self, methodName=methodName)
ThreadableTest.__init__(self)
def clientSetUp(self):
self.cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
def clientTearDown(self):
self.cli.close()
self.cli = None
ThreadableTest.clientTearDown(self)
class SocketConnectedTest(ThreadedTCPSocketTest):
def __init__(self, methodName='runTest'):
ThreadedTCPSocketTest.__init__(self, methodName=methodName)
def setUp(self):
ThreadedTCPSocketTest.setUp(self)
# Indicate explicitly we're ready for the client thread to
# proceed and then perform the blocking call to accept
self.serverExplicitReady()
conn, addr = self.serv.accept()
self.cli_conn = conn
def tearDown(self):
self.cli_conn.close()
self.cli_conn = None
ThreadedTCPSocketTest.tearDown(self)
def clientSetUp(self):
ThreadedTCPSocketTest.clientSetUp(self)
self.cli.connect((HOST, self.port))
self.serv_conn = self.cli
def clientTearDown(self):
self.serv_conn.close()
self.serv_conn = None
ThreadedTCPSocketTest.clientTearDown(self)
class SocketPairTest(unittest.TestCase, ThreadableTest):
def __init__(self, methodName='runTest'):
unittest.TestCase.__init__(self, methodName=methodName)
ThreadableTest.__init__(self)
def setUp(self):
self.serv, self.cli = socket.socketpair()
def tearDown(self):
self.serv.close()
self.serv = None
def clientSetUp(self):
pass
def clientTearDown(self):
self.cli.close()
self.cli = None
ThreadableTest.clientTearDown(self)
#######################################################################
## Begin Tests
class GeneralModuleTests(unittest.TestCase):
def test_weakref(self):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
p = proxy(s)
self.assertEqual(p.fileno(), s.fileno())
s.close()
s = None
try:
p.fileno()
except ReferenceError:
pass
else:
self.fail('Socket proxy still exists')
def testSocketError(self):
# Testing socket module exceptions
def raise_error(*args, **kwargs):
raise socket.error
def raise_herror(*args, **kwargs):
raise socket.herror
def raise_gaierror(*args, **kwargs):
raise socket.gaierror
self.assertRaises(socket.error, raise_error,
"Error raising socket exception.")
self.assertRaises(socket.error, raise_herror,
"Error raising socket exception.")
self.assertRaises(socket.error, raise_gaierror,
"Error raising socket exception.")
def testSendtoErrors(self):
# Testing that sendto doens't masks failures. See #10169.
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.addCleanup(s.close)
s.bind(('', 0))
sockname = s.getsockname()
# 2 args
with self.assertRaises(UnicodeEncodeError):
s.sendto(u'\u2620', sockname)
with self.assertRaises(TypeError) as cm:
s.sendto(5j, sockname)
self.assertIn('not complex', str(cm.exception))
with self.assertRaises(TypeError) as cm:
s.sendto('foo', None)
self.assertIn('not NoneType', str(cm.exception))
# 3 args
with self.assertRaises(UnicodeEncodeError):
s.sendto(u'\u2620', 0, sockname)
with self.assertRaises(TypeError) as cm:
s.sendto(5j, 0, sockname)
self.assertIn('not complex', str(cm.exception))
with self.assertRaises(TypeError) as cm:
s.sendto('foo', 0, None)
self.assertIn('not NoneType', str(cm.exception))
with self.assertRaises(TypeError) as cm:
s.sendto('foo', 'bar', sockname)
self.assertIn('an integer is required', str(cm.exception))
with self.assertRaises(TypeError) as cm:
s.sendto('foo', None, None)
self.assertIn('an integer is required', str(cm.exception))
# wrong number of args
with self.assertRaises(TypeError) as cm:
s.sendto('foo')
self.assertIn('(1 given)', str(cm.exception))
with self.assertRaises(TypeError) as cm:
s.sendto('foo', 0, sockname, 4)
self.assertIn('(4 given)', str(cm.exception))
def testCrucialConstants(self):
# Testing for mission critical constants
socket.AF_INET
socket.SOCK_STREAM
socket.SOCK_DGRAM
socket.SOCK_RAW
socket.SOCK_RDM
socket.SOCK_SEQPACKET
socket.SOL_SOCKET
socket.SO_REUSEADDR
def testHostnameRes(self):
# Testing hostname resolution mechanisms
hostname = socket.gethostname()
try:
ip = socket.gethostbyname(hostname)
except socket.error:
# Probably name lookup wasn't set up right; skip this test
return
self.assertTrue(ip.find('.') >= 0, "Error resolving host to ip.")
try:
hname, aliases, ipaddrs = socket.gethostbyaddr(ip)
except socket.error:
# Probably a similar problem as above; skip this test
return
all_host_names = [hostname, hname] + aliases
fqhn = socket.getfqdn(ip)
if not fqhn in all_host_names:
self.fail("Error testing host resolution mechanisms. (fqdn: %s, all: %s)" % (fqhn, repr(all_host_names)))
def testRefCountGetNameInfo(self):
# Testing reference count for getnameinfo
if hasattr(sys, "getrefcount"):
try:
# On some versions, this loses a reference
orig = sys.getrefcount(__name__)
socket.getnameinfo(__name__,0)
except TypeError:
self.assertEqual(sys.getrefcount(__name__), orig,
"socket.getnameinfo loses a reference")
def testInterpreterCrash(self):
# Making sure getnameinfo doesn't crash the interpreter
try:
# On some versions, this crashes the interpreter.
socket.getnameinfo(('x', 0, 0, 0), 0)
except socket.error:
pass
def testNtoH(self):
# This just checks that htons etc. are their own inverse,
# when looking at the lower 16 or 32 bits.
sizes = {socket.htonl: 32, socket.ntohl: 32,
socket.htons: 16, socket.ntohs: 16}
for func, size in sizes.items():
mask = (1L<<size) - 1
for i in (0, 1, 0xffff, ~0xffff, 2, 0x01234567, 0x76543210):
self.assertEqual(i & mask, func(func(i&mask)) & mask)
swapped = func(mask)
self.assertEqual(swapped & mask, mask)
self.assertRaises(OverflowError, func, 1L<<34)
def testNtoHErrors(self):
good_values = [ 1, 2, 3, 1L, 2L, 3L ]
bad_values = [ -1, -2, -3, -1L, -2L, -3L ]
for k in good_values:
socket.ntohl(k)
socket.ntohs(k)
socket.htonl(k)
socket.htons(k)
for k in bad_values:
self.assertRaises(OverflowError, socket.ntohl, k)
self.assertRaises(OverflowError, socket.ntohs, k)
self.assertRaises(OverflowError, socket.htonl, k)
self.assertRaises(OverflowError, socket.htons, k)
def testGetServBy(self):
eq = self.assertEqual
# Find one service that exists, then check all the related interfaces.
# I've ordered this by protocols that have both a tcp and udp
# protocol, at least for modern Linuxes.
if (sys.platform.startswith('linux') or
sys.platform.startswith('freebsd') or
sys.platform.startswith('netbsd') or
sys.platform == 'darwin'):
# avoid the 'echo' service on this platform, as there is an
# assumption breaking non-standard port/protocol entry
services = ('daytime', 'qotd', 'domain')
else:
services = ('echo', 'daytime', 'domain')
for service in services:
try:
port = socket.getservbyname(service, 'tcp')
break
except socket.error:
pass
else:
raise socket.error
# Try same call with optional protocol omitted
port2 = socket.getservbyname(service)
eq(port, port2)
# Try udp, but don't barf it it doesn't exist
try:
udpport = socket.getservbyname(service, 'udp')
except socket.error:
udpport = None
else:
eq(udpport, port)
# Now make sure the lookup by port returns the same service name
eq(socket.getservbyport(port2), service)
eq(socket.getservbyport(port, 'tcp'), service)
if udpport is not None:
eq(socket.getservbyport(udpport, 'udp'), service)
# Make sure getservbyport does not accept out of range ports.
self.assertRaises(OverflowError, socket.getservbyport, -1)
self.assertRaises(OverflowError, socket.getservbyport, 65536)
def testDefaultTimeout(self):
# Testing default timeout
# The default timeout should initially be None
self.assertEqual(socket.getdefaulttimeout(), None)
s = socket.socket()
self.assertEqual(s.gettimeout(), None)
s.close()
# Set the default timeout to 10, and see if it propagates
socket.setdefaulttimeout(10)
self.assertEqual(socket.getdefaulttimeout(), 10)
s = socket.socket()
self.assertEqual(s.gettimeout(), 10)
s.close()
# Reset the default timeout to None, and see if it propagates
socket.setdefaulttimeout(None)
self.assertEqual(socket.getdefaulttimeout(), None)
s = socket.socket()
self.assertEqual(s.gettimeout(), None)
s.close()
# Check that setting it to an invalid value raises ValueError
self.assertRaises(ValueError, socket.setdefaulttimeout, -1)
# Check that setting it to an invalid type raises TypeError
self.assertRaises(TypeError, socket.setdefaulttimeout, "spam")
def testIPv4_inet_aton_fourbytes(self):
if not hasattr(socket, 'inet_aton'):
return # No inet_aton, nothing to check
# Test that issue1008086 and issue767150 are fixed.
# It must return 4 bytes.
self.assertEqual('\x00'*4, socket.inet_aton('0.0.0.0'))
self.assertEqual('\xff'*4, socket.inet_aton('255.255.255.255'))
def testIPv4toString(self):
if not hasattr(socket, 'inet_pton'):
return # No inet_pton() on this platform
from socket import inet_aton as f, inet_pton, AF_INET
g = lambda a: inet_pton(AF_INET, a)
self.assertEqual('\x00\x00\x00\x00', f('0.0.0.0'))
self.assertEqual('\xff\x00\xff\x00', f('255.0.255.0'))
self.assertEqual('\xaa\xaa\xaa\xaa', f('170.170.170.170'))
self.assertEqual('\x01\x02\x03\x04', f('1.2.3.4'))
self.assertEqual('\xff\xff\xff\xff', f('255.255.255.255'))
self.assertEqual('\x00\x00\x00\x00', g('0.0.0.0'))
self.assertEqual('\xff\x00\xff\x00', g('255.0.255.0'))
self.assertEqual('\xaa\xaa\xaa\xaa', g('170.170.170.170'))
self.assertEqual('\xff\xff\xff\xff', g('255.255.255.255'))
def testIPv6toString(self):
if not hasattr(socket, 'inet_pton'):
return # No inet_pton() on this platform
try:
from socket import inet_pton, AF_INET6, has_ipv6
if not has_ipv6:
return
except ImportError:
return
f = lambda a: inet_pton(AF_INET6, a)
self.assertEqual('\x00' * 16, f('::'))
self.assertEqual('\x00' * 16, f('0::0'))
self.assertEqual('\x00\x01' + '\x00' * 14, f('1::'))
self.assertEqual(
'\x45\xef\x76\xcb\x00\x1a\x56\xef\xaf\xeb\x0b\xac\x19\x24\xae\xae',
f('45ef:76cb:1a:56ef:afeb:bac:1924:aeae')
)
def testStringToIPv4(self):
if not hasattr(socket, 'inet_ntop'):
return # No inet_ntop() on this platform
from socket import inet_ntoa as f, inet_ntop, AF_INET
g = lambda a: inet_ntop(AF_INET, a)
self.assertEqual('1.0.1.0', f('\x01\x00\x01\x00'))
self.assertEqual('170.85.170.85', f('\xaa\x55\xaa\x55'))
self.assertEqual('255.255.255.255', f('\xff\xff\xff\xff'))
self.assertEqual('1.2.3.4', f('\x01\x02\x03\x04'))
self.assertEqual('1.0.1.0', g('\x01\x00\x01\x00'))
self.assertEqual('170.85.170.85', g('\xaa\x55\xaa\x55'))
self.assertEqual('255.255.255.255', g('\xff\xff\xff\xff'))
def testStringToIPv6(self):
if not hasattr(socket, 'inet_ntop'):
return # No inet_ntop() on this platform
try:
from socket import inet_ntop, AF_INET6, has_ipv6
if not has_ipv6:
return
except ImportError:
return
f = lambda a: inet_ntop(AF_INET6, a)
self.assertEqual('::', f('\x00' * 16))
self.assertEqual('::1', f('\x00' * 15 + '\x01'))
self.assertEqual(
'aef:b01:506:1001:ffff:9997:55:170',
f('\x0a\xef\x0b\x01\x05\x06\x10\x01\xff\xff\x99\x97\x00\x55\x01\x70')
)
# XXX The following don't test module-level functionality...
def _get_unused_port(self, bind_address='0.0.0.0'):
"""Use a temporary socket to elicit an unused ephemeral port.
Args:
bind_address: Hostname or IP address to search for a port on.
Returns: A most likely to be unused port.
"""
tempsock = socket.socket()
tempsock.bind((bind_address, 0))
host, port = tempsock.getsockname()
tempsock.close()
return port
def testSockName(self):
# Testing getsockname()
port = self._get_unused_port()
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.addCleanup(sock.close)
sock.bind(("0.0.0.0", port))
name = sock.getsockname()
# XXX(nnorwitz): http://tinyurl.com/os5jz seems to indicate
# it reasonable to get the host's addr in addition to 0.0.0.0.
# At least for eCos. This is required for the S/390 to pass.
try:
my_ip_addr = socket.gethostbyname(socket.gethostname())
except socket.error:
# Probably name lookup wasn't set up right; skip this test
return
self.assertIn(name[0], ("0.0.0.0", my_ip_addr), '%s invalid' % name[0])
self.assertEqual(name[1], port)
def testGetSockOpt(self):
# Testing getsockopt()
# We know a socket should start without reuse==0
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.addCleanup(sock.close)
reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
self.assertFalse(reuse != 0, "initial mode is reuse")
def testSetSockOpt(self):
# Testing setsockopt()
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.addCleanup(sock.close)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)
self.assertFalse(reuse == 0, "failed to set reuse mode")
def testSendAfterClose(self):
# testing send() after close() with timeout
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(1)
sock.close()
self.assertRaises(socket.error, sock.send, "spam")
def testNewAttributes(self):
# testing .family, .type and .protocol
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.assertEqual(sock.family, socket.AF_INET)
self.assertEqual(sock.type, socket.SOCK_STREAM)
self.assertEqual(sock.proto, 0)
sock.close()
def test_getsockaddrarg(self):
host = '0.0.0.0'
port = self._get_unused_port(bind_address=host)
big_port = port + 65536
neg_port = port - 65536
sock = socket.socket()
try:
self.assertRaises(OverflowError, sock.bind, (host, big_port))
self.assertRaises(OverflowError, sock.bind, (host, neg_port))
sock.bind((host, port))
finally:
sock.close()
@unittest.skipUnless(os.name == "nt", "Windows specific")
def test_sock_ioctl(self):
self.assertTrue(hasattr(socket.socket, 'ioctl'))
self.assertTrue(hasattr(socket, 'SIO_RCVALL'))
self.assertTrue(hasattr(socket, 'RCVALL_ON'))
self.assertTrue(hasattr(socket, 'RCVALL_OFF'))
self.assertTrue(hasattr(socket, 'SIO_KEEPALIVE_VALS'))
s = socket.socket()
self.addCleanup(s.close)
self.assertRaises(ValueError, s.ioctl, -1, None)
s.ioctl(socket.SIO_KEEPALIVE_VALS, (1, 100, 100))
def testGetaddrinfo(self):
try:
socket.getaddrinfo('localhost', 80)
except socket.gaierror as err:
if err.errno == socket.EAI_SERVICE:
# see http://bugs.python.org/issue1282647
self.skipTest("buggy libc version")
raise
# len of every sequence is supposed to be == 5
for info in socket.getaddrinfo(HOST, None):
self.assertEqual(len(info), 5)
# host can be a domain name, a string representation of an
# IPv4/v6 address or None
socket.getaddrinfo('localhost', 80)
socket.getaddrinfo('127.0.0.1', 80)
socket.getaddrinfo(None, 80)
if SUPPORTS_IPV6:
socket.getaddrinfo('::1', 80)
# port can be a string service name such as "http", a numeric
# port number or None
socket.getaddrinfo(HOST, "http")
socket.getaddrinfo(HOST, 80)
socket.getaddrinfo(HOST, None)
# test family and socktype filters
infos = socket.getaddrinfo(HOST, None, socket.AF_INET)
for family, _, _, _, _ in infos:
self.assertEqual(family, socket.AF_INET)
infos = socket.getaddrinfo(HOST, None, 0, socket.SOCK_STREAM)
for _, socktype, _, _, _ in infos:
self.assertEqual(socktype, socket.SOCK_STREAM)
# test proto and flags arguments
socket.getaddrinfo(HOST, None, 0, 0, socket.SOL_TCP)
socket.getaddrinfo(HOST, None, 0, 0, 0, socket.AI_PASSIVE)
# a server willing to support both IPv4 and IPv6 will
# usually do this
socket.getaddrinfo(None, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, 0,
socket.AI_PASSIVE)
def check_sendall_interrupted(self, with_timeout):
# socketpair() is not stricly required, but it makes things easier.
if not hasattr(signal, 'alarm') or not hasattr(socket, 'socketpair'):
self.skipTest("signal.alarm and socket.socketpair required for this test")
# Our signal handlers clobber the C errno by calling a math function
# with an invalid domain value.
def ok_handler(*args):
self.assertRaises(ValueError, math.acosh, 0)
def raising_handler(*args):
self.assertRaises(ValueError, math.acosh, 0)
1 // 0
c, s = socket.socketpair()
old_alarm = signal.signal(signal.SIGALRM, raising_handler)
try:
if with_timeout:
# Just above the one second minimum for signal.alarm
c.settimeout(1.5)
with self.assertRaises(ZeroDivisionError):
signal.alarm(1)
c.sendall(b"x" * (1024**2))
if with_timeout:
signal.signal(signal.SIGALRM, ok_handler)
signal.alarm(1)
self.assertRaises(socket.timeout, c.sendall, b"x" * (1024**2))
finally:
signal.signal(signal.SIGALRM, old_alarm)
c.close()
s.close()
def test_sendall_interrupted(self):
self.check_sendall_interrupted(False)
def test_sendall_interrupted_with_timeout(self):
self.check_sendall_interrupted(True)
def testListenBacklog0(self):
srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
srv.bind((HOST, 0))
# backlog = 0
srv.listen(0)
srv.close()
@unittest.skipUnless(SUPPORTS_IPV6, 'IPv6 required for this test.')
def test_flowinfo(self):
self.assertRaises(OverflowError, socket.getnameinfo,
('::1',0, 0xffffffff), 0)
s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
try:
self.assertRaises(OverflowError, s.bind, ('::1', 0, -10))
finally:
s.close()
@unittest.skipUnless(thread, 'Threading required for this test.')
class BasicTCPTest(SocketConnectedTest):
def __init__(self, methodName='runTest'):
SocketConnectedTest.__init__(self, methodName=methodName)
def testRecv(self):
# Testing large receive over TCP
msg = self.cli_conn.recv(1024)
self.assertEqual(msg, MSG)
def _testRecv(self):
self.serv_conn.send(MSG)
def testOverFlowRecv(self):
# Testing receive in chunks over TCP
seg1 = self.cli_conn.recv(len(MSG) - 3)
seg2 = self.cli_conn.recv(1024)
msg = seg1 + seg2
self.assertEqual(msg, MSG)
def _testOverFlowRecv(self):
self.serv_conn.send(MSG)
def testRecvFrom(self):
# Testing large recvfrom() over TCP
msg, addr = self.cli_conn.recvfrom(1024)
self.assertEqual(msg, MSG)
def _testRecvFrom(self):
self.serv_conn.send(MSG)
def testOverFlowRecvFrom(self):
# Testing recvfrom() in chunks over TCP
seg1, addr = self.cli_conn.recvfrom(len(MSG)-3)
seg2, addr = self.cli_conn.recvfrom(1024)
msg = seg1 + seg2
self.assertEqual(msg, MSG)
def _testOverFlowRecvFrom(self):
self.serv_conn.send(MSG)
def testSendAll(self):
# Testing sendall() with a 2048 byte string over TCP
msg = ''
while 1:
read = self.cli_conn.recv(1024)
if not read:
break
msg += read
self.assertEqual(msg, 'f' * 2048)
def _testSendAll(self):
big_chunk = 'f' * 2048
self.serv_conn.sendall(big_chunk)
def testFromFd(self):
# Testing fromfd()
if not hasattr(socket, "fromfd"):
return # On Windows, this doesn't exist
fd = self.cli_conn.fileno()
sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM)
self.addCleanup(sock.close)
msg = sock.recv(1024)
self.assertEqual(msg, MSG)
def _testFromFd(self):
self.serv_conn.send(MSG)
def testDup(self):
# Testing dup()
sock = self.cli_conn.dup()
self.addCleanup(sock.close)
msg = sock.recv(1024)
self.assertEqual(msg, MSG)
def _testDup(self):
self.serv_conn.send(MSG)
def testShutdown(self):
# Testing shutdown()
msg = self.cli_conn.recv(1024)
self.assertEqual(msg, MSG)
# wait for _testShutdown to finish: on OS X, when the server
# closes the connection the client also becomes disconnected,
# and the client's shutdown call will fail. (Issue #4397.)
self.done.wait()
def _testShutdown(self):
self.serv_conn.send(MSG)
self.serv_conn.shutdown(2)
@unittest.skipUnless(thread, 'Threading required for this test.')
class BasicUDPTest(ThreadedUDPSocketTest):
def __init__(self, methodName='runTest'):
ThreadedUDPSocketTest.__init__(self, methodName=methodName)
def testSendtoAndRecv(self):
# Testing sendto() and Recv() over UDP
msg = self.serv.recv(len(MSG))
self.assertEqual(msg, MSG)
def _testSendtoAndRecv(self):
self.cli.sendto(MSG, 0, (HOST, self.port))
def testRecvFrom(self):
# Testing recvfrom() over UDP
msg, addr = self.serv.recvfrom(len(MSG))
self.assertEqual(msg, MSG)
def _testRecvFrom(self):
self.cli.sendto(MSG, 0, (HOST, self.port))
def testRecvFromNegative(self):
# Negative lengths passed to recvfrom should give ValueError.
self.assertRaises(ValueError, self.serv.recvfrom, -1)
def _testRecvFromNegative(self):
self.cli.sendto(MSG, 0, (HOST, self.port))
@unittest.skipUnless(thread, 'Threading required for this test.')
class TCPCloserTest(ThreadedTCPSocketTest):
def testClose(self):
conn, addr = self.serv.accept()
conn.close()
sd = self.cli
read, write, err = select.select([sd], [], [], 1.0)
self.assertEqual(read, [sd])
self.assertEqual(sd.recv(1), '')
def _testClose(self):
self.cli.connect((HOST, self.port))
time.sleep(1.0)
@unittest.skipUnless(thread, 'Threading required for this test.')
class BasicSocketPairTest(SocketPairTest):
def __init__(self, methodName='runTest'):
SocketPairTest.__init__(self, methodName=methodName)
def testRecv(self):
msg = self.serv.recv(1024)
self.assertEqual(msg, MSG)
def _testRecv(self):
self.cli.send(MSG)
def testSend(self):
self.serv.send(MSG)
def _testSend(self):
msg = self.cli.recv(1024)
self.assertEqual(msg, MSG)
@unittest.skipUnless(thread, 'Threading required for this test.')
class NonBlockingTCPTests(ThreadedTCPSocketTest):
def __init__(self, methodName='runTest'):
ThreadedTCPSocketTest.__init__(self, methodName=methodName)
def testSetBlocking(self):
# Testing whether set blocking works
self.serv.setblocking(0)
start = time.time()
try:
self.serv.accept()
except socket.error:
pass
end = time.time()
self.assertTrue((end - start) < 1.0, "Error setting non-blocking mode.")
def _testSetBlocking(self):
pass
def testAccept(self):
# Testing non-blocking accept
self.serv.setblocking(0)
try:
conn, addr = self.serv.accept()
except socket.error:
pass
else:
self.fail("Error trying to do non-blocking accept.")
read, write, err = select.select([self.serv], [], [])
if self.serv in read:
conn, addr = self.serv.accept()
conn.close()
else:
self.fail("Error trying to do accept after select.")
def _testAccept(self):
time.sleep(0.1)
self.cli.connect((HOST, self.port))
def testConnect(self):
# Testing non-blocking connect
conn, addr = self.serv.accept()
conn.close()
def _testConnect(self):
self.cli.settimeout(10)
self.cli.connect((HOST, self.port))
def testRecv(self):
# Testing non-blocking recv
conn, addr = self.serv.accept()
conn.setblocking(0)
try:
msg = conn.recv(len(MSG))
except socket.error:
pass
else:
self.fail("Error trying to do non-blocking recv.")
read, write, err = select.select([conn], [], [])
if conn in read:
msg = conn.recv(len(MSG))
conn.close()
self.assertEqual(msg, MSG)
else:
self.fail("Error during select call to non-blocking socket.")
def _testRecv(self):
self.cli.connect((HOST, self.port))
time.sleep(0.1)
self.cli.send(MSG)
@unittest.skipUnless(thread, 'Threading required for this test.')
class FileObjectClassTestCase(SocketConnectedTest):
bufsize = -1 # Use default buffer size
def __init__(self, methodName='runTest'):
SocketConnectedTest.__init__(self, methodName=methodName)
def setUp(self):
SocketConnectedTest.setUp(self)
self.serv_file = self.cli_conn.makefile('rb', self.bufsize)
def tearDown(self):
self.serv_file.close()
self.assertTrue(self.serv_file.closed)
self.serv_file = None
SocketConnectedTest.tearDown(self)
def clientSetUp(self):
SocketConnectedTest.clientSetUp(self)
self.cli_file = self.serv_conn.makefile('wb')
def clientTearDown(self):
self.cli_file.close()
self.assertTrue(self.cli_file.closed)
self.cli_file = None
SocketConnectedTest.clientTearDown(self)
def testSmallRead(self):
# Performing small file read test
first_seg = self.serv_file.read(len(MSG)-3)
second_seg = self.serv_file.read(3)
msg = first_seg + second_seg
self.assertEqual(msg, MSG)
def _testSmallRead(self):
self.cli_file.write(MSG)
self.cli_file.flush()
def testFullRead(self):
# read until EOF
msg = self.serv_file.read()
self.assertEqual(msg, MSG)
def _testFullRead(self):
self.cli_file.write(MSG)
self.cli_file.close()
def testUnbufferedRead(self):
# Performing unbuffered file read test
buf = ''
while 1:
char = self.serv_file.read(1)
if not char:
break
buf += char
self.assertEqual(buf, MSG)
def _testUnbufferedRead(self):
self.cli_file.write(MSG)
self.cli_file.flush()
def testReadline(self):
# Performing file readline test
line = self.serv_file.readline()
self.assertEqual(line, MSG)
def _testReadline(self):
self.cli_file.write(MSG)
self.cli_file.flush()
def testReadlineAfterRead(self):
a_baloo_is = self.serv_file.read(len("A baloo is"))
self.assertEqual("A baloo is", a_baloo_is)
_a_bear = self.serv_file.read(len(" a bear"))
self.assertEqual(" a bear", _a_bear)
line = self.serv_file.readline()
self.assertEqual("\n", line)
line = self.serv_file.readline()
self.assertEqual("A BALOO IS A BEAR.\n", line)
line = self.serv_file.readline()
self.assertEqual(MSG, line)
def _testReadlineAfterRead(self):
self.cli_file.write("A baloo is a bear\n")
self.cli_file.write("A BALOO IS A BEAR.\n")
self.cli_file.write(MSG)
self.cli_file.flush()
def testReadlineAfterReadNoNewline(self):
end_of_ = self.serv_file.read(len("End Of "))
self.assertEqual("End Of ", end_of_)
line = self.serv_file.readline()
self.assertEqual("Line", line)
def _testReadlineAfterReadNoNewline(self):
self.cli_file.write("End Of Line")
def testClosedAttr(self):
self.assertTrue(not self.serv_file.closed)
def _testClosedAttr(self):
self.assertTrue(not self.cli_file.closed)
class FileObjectInterruptedTestCase(unittest.TestCase):
"""Test that the file object correctly handles EINTR internally."""
class MockSocket(object):
def __init__(self, recv_funcs=()):
# A generator that returns callables that we'll call for each
# call to recv().
self._recv_step = iter(recv_funcs)
def recv(self, size):
return self._recv_step.next()()
@staticmethod
def _raise_eintr():
raise socket.error(errno.EINTR)
def _test_readline(self, size=-1, **kwargs):
mock_sock = self.MockSocket(recv_funcs=[
lambda : "This is the first line\nAnd the sec",
self._raise_eintr,
lambda : "ond line is here\n",
lambda : "",
])
fo = socket._fileobject(mock_sock, **kwargs)
self.assertEqual(fo.readline(size), "This is the first line\n")
self.assertEqual(fo.readline(size), "And the second line is here\n")
def _test_read(self, size=-1, **kwargs):
mock_sock = self.MockSocket(recv_funcs=[
lambda : "This is the first line\nAnd the sec",
self._raise_eintr,
lambda : "ond line is here\n",
lambda : "",
])
fo = socket._fileobject(mock_sock, **kwargs)
self.assertEqual(fo.read(size), "This is the first line\n"
"And the second line is here\n")
def test_default(self):
self._test_readline()
self._test_readline(size=100)
self._test_read()
self._test_read(size=100)
def test_with_1k_buffer(self):
self._test_readline(bufsize=1024)
self._test_readline(size=100, bufsize=1024)
self._test_read(bufsize=1024)
self._test_read(size=100, bufsize=1024)
def _test_readline_no_buffer(self, size=-1):
mock_sock = self.MockSocket(recv_funcs=[
lambda : "aa",
lambda : "\n",
lambda : "BB",
self._raise_eintr,
lambda : "bb",
lambda : "",
])
fo = socket._fileobject(mock_sock, bufsize=0)
self.assertEqual(fo.readline(size), "aa\n")
self.assertEqual(fo.readline(size), "BBbb")
def test_no_buffer(self):
self._test_readline_no_buffer()
self._test_readline_no_buffer(size=4)
self._test_read(bufsize=0)
self._test_read(size=100, bufsize=0)
class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase):
"""Repeat the tests from FileObjectClassTestCase with bufsize==0.
In this case (and in this case only), it should be possible to
create a file object, read a line from it, create another file
object, read another line from it, without loss of data in the
first file object's buffer. Note that httplib relies on this
when reading multiple requests from the same socket."""
bufsize = 0 # Use unbuffered mode
def testUnbufferedReadline(self):
# Read a line, create a new file object, read another line with it
line = self.serv_file.readline() # first line
self.assertEqual(line, "A. " + MSG) # first line
self.serv_file = self.cli_conn.makefile('rb', 0)
line = self.serv_file.readline() # second line
self.assertEqual(line, "B. " + MSG) # second line
def _testUnbufferedReadline(self):
self.cli_file.write("A. " + MSG)
self.cli_file.write("B. " + MSG)
self.cli_file.flush()
class LineBufferedFileObjectClassTestCase(FileObjectClassTestCase):
bufsize = 1 # Default-buffered for reading; line-buffered for writing
class SmallBufferedFileObjectClassTestCase(FileObjectClassTestCase):
bufsize = 2 # Exercise the buffering code
class NetworkConnectionTest(object):
"""Prove network connection."""
def clientSetUp(self):
# We're inherited below by BasicTCPTest2, which also inherits
# BasicTCPTest, which defines self.port referenced below.
self.cli = socket.create_connection((HOST, self.port))
self.serv_conn = self.cli
class BasicTCPTest2(NetworkConnectionTest, BasicTCPTest):
"""Tests that NetworkConnection does not break existing TCP functionality.
"""
class NetworkConnectionNoServer(unittest.TestCase):
class MockSocket(socket.socket):
def connect(self, *args):
raise socket.timeout('timed out')
@contextlib.contextmanager
def mocked_socket_module(self):
"""Return a socket which times out on connect"""
old_socket = socket.socket
socket.socket = self.MockSocket
try:
yield
finally:
socket.socket = old_socket
def test_connect(self):
port = test_support.find_unused_port()
cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.addCleanup(cli.close)
with self.assertRaises(socket.error) as cm:
cli.connect((HOST, port))
self.assertEqual(cm.exception.errno, errno.ECONNREFUSED)
def test_create_connection(self):
# Issue #9792: errors raised by create_connection() should have
# a proper errno attribute.
port = test_support.find_unused_port()
with self.assertRaises(socket.error) as cm:
socket.create_connection((HOST, port))
self.assertEqual(cm.exception.errno, errno.ECONNREFUSED)
def test_create_connection_timeout(self):
# Issue #9792: create_connection() should not recast timeout errors
# as generic socket errors.
with self.mocked_socket_module():
with self.assertRaises(socket.timeout):
socket.create_connection((HOST, 1234))
@unittest.skipUnless(thread, 'Threading required for this test.')
class NetworkConnectionAttributesTest(SocketTCPTest, ThreadableTest):
def __init__(self, methodName='runTest'):
SocketTCPTest.__init__(self, methodName=methodName)
ThreadableTest.__init__(self)
def clientSetUp(self):
self.source_port = test_support.find_unused_port()
def clientTearDown(self):
self.cli.close()
self.cli = None
ThreadableTest.clientTearDown(self)
def _justAccept(self):
conn, addr = self.serv.accept()
conn.close()
testFamily = _justAccept
def _testFamily(self):
self.cli = socket.create_connection((HOST, self.port), timeout=30)
self.addCleanup(self.cli.close)
self.assertEqual(self.cli.family, 2)
testSourceAddress = _justAccept
def _testSourceAddress(self):
self.cli = socket.create_connection((HOST, self.port), timeout=30,
source_address=('', self.source_port))
self.addCleanup(self.cli.close)
self.assertEqual(self.cli.getsockname()[1], self.source_port)
# The port number being used is sufficient to show that the bind()
# call happened.
testTimeoutDefault = _justAccept
def _testTimeoutDefault(self):
# passing no explicit timeout uses socket's global default
self.assertTrue(socket.getdefaulttimeout() is None)
socket.setdefaulttimeout(42)
try:
self.cli = socket.create_connection((HOST, self.port))
self.addCleanup(self.cli.close)
finally:
socket.setdefaulttimeout(None)
self.assertEqual(self.cli.gettimeout(), 42)
testTimeoutNone = _justAccept
def _testTimeoutNone(self):
# None timeout means the same as sock.settimeout(None)
self.assertTrue(socket.getdefaulttimeout() is None)
socket.setdefaulttimeout(30)
try:
self.cli = socket.create_connection((HOST, self.port), timeout=None)
self.addCleanup(self.cli.close)
finally:
socket.setdefaulttimeout(None)
self.assertEqual(self.cli.gettimeout(), None)
testTimeoutValueNamed = _justAccept
def _testTimeoutValueNamed(self):
self.cli = socket.create_connection((HOST, self.port), timeout=30)
self.assertEqual(self.cli.gettimeout(), 30)
testTimeoutValueNonamed = _justAccept
def _testTimeoutValueNonamed(self):
self.cli = socket.create_connection((HOST, self.port), 30)
self.addCleanup(self.cli.close)
self.assertEqual(self.cli.gettimeout(), 30)
@unittest.skipUnless(thread, 'Threading required for this test.')
class NetworkConnectionBehaviourTest(SocketTCPTest, ThreadableTest):
def __init__(self, methodName='runTest'):
SocketTCPTest.__init__(self, methodName=methodName)
ThreadableTest.__init__(self)
def clientSetUp(self):
pass
def clientTearDown(self):
self.cli.close()
self.cli = None
ThreadableTest.clientTearDown(self)
def testInsideTimeout(self):
conn, addr = self.serv.accept()
self.addCleanup(conn.close)
time.sleep(3)
conn.send("done!")
testOutsideTimeout = testInsideTimeout
def _testInsideTimeout(self):
self.cli = sock = socket.create_connection((HOST, self.port))
data = sock.recv(5)
self.assertEqual(data, "done!")
def _testOutsideTimeout(self):
self.cli = sock = socket.create_connection((HOST, self.port), timeout=1)
self.assertRaises(socket.timeout, lambda: sock.recv(5))
class Urllib2FileobjectTest(unittest.TestCase):
# urllib2.HTTPHandler has "borrowed" socket._fileobject, and requires that
# it close the socket if the close c'tor argument is true
def testClose(self):
class MockSocket:
closed = False
def flush(self): pass
def close(self): self.closed = True
# must not close unless we request it: the original use of _fileobject
# by module socket requires that the underlying socket not be closed until
# the _socketobject that created the _fileobject is closed
s = MockSocket()
f = socket._fileobject(s)
f.close()
self.assertTrue(not s.closed)
s = MockSocket()
f = socket._fileobject(s, close=True)
f.close()
self.assertTrue(s.closed)
class TCPTimeoutTest(SocketTCPTest):
def testTCPTimeout(self):
def raise_timeout(*args, **kwargs):
self.serv.settimeout(1.0)
self.serv.accept()
self.assertRaises(socket.timeout, raise_timeout,
"Error generating a timeout exception (TCP)")
def testTimeoutZero(self):
ok = False
try:
self.serv.settimeout(0.0)
foo = self.serv.accept()
except socket.timeout:
self.fail("caught timeout instead of error (TCP)")
except socket.error:
ok = True
except:
self.fail("caught unexpected exception (TCP)")
if not ok:
self.fail("accept() returned success when we did not expect it")
def testInterruptedTimeout(self):
# XXX I don't know how to do this test on MSWindows or any other
# plaform that doesn't support signal.alarm() or os.kill(), though
# the bug should have existed on all platforms.
if not hasattr(signal, "alarm"):
return # can only test on *nix
self.serv.settimeout(5.0) # must be longer than alarm
class Alarm(Exception):
pass
def alarm_handler(signal, frame):
raise Alarm
old_alarm = signal.signal(signal.SIGALRM, alarm_handler)
try:
signal.alarm(2) # POSIX allows alarm to be up to 1 second early
try:
foo = self.serv.accept()
except socket.timeout:
self.fail("caught timeout instead of Alarm")
except Alarm:
pass
except:
self.fail("caught other exception instead of Alarm:"
" %s(%s):\n%s" %
(sys.exc_info()[:2] + (traceback.format_exc(),)))
else:
self.fail("nothing caught")
finally:
signal.alarm(0) # shut off alarm
except Alarm:
self.fail("got Alarm in wrong place")
finally:
# no alarm can be pending. Safe to restore old handler.
signal.signal(signal.SIGALRM, old_alarm)
class UDPTimeoutTest(SocketUDPTest):
def testUDPTimeout(self):
def raise_timeout(*args, **kwargs):
self.serv.settimeout(1.0)
self.serv.recv(1024)
self.assertRaises(socket.timeout, raise_timeout,
"Error generating a timeout exception (UDP)")
def testTimeoutZero(self):
ok = False
try:
self.serv.settimeout(0.0)
foo = self.serv.recv(1024)
except socket.timeout:
self.fail("caught timeout instead of error (UDP)")
except socket.error:
ok = True
except:
self.fail("caught unexpected exception (UDP)")
if not ok:
self.fail("recv() returned success when we did not expect it")
class TestExceptions(unittest.TestCase):
def testExceptionTree(self):
self.assertTrue(issubclass(socket.error, Exception))
self.assertTrue(issubclass(socket.herror, socket.error))
self.assertTrue(issubclass(socket.gaierror, socket.error))
self.assertTrue(issubclass(socket.timeout, socket.error))
class TestLinuxAbstractNamespace(unittest.TestCase):
UNIX_PATH_MAX = 108
def testLinuxAbstractNamespace(self):
address = "\x00python-test-hello\x00\xff"
s1 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
s1.bind(address)
s1.listen(1)
s2 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
s2.connect(s1.getsockname())
s1.accept()
self.assertEqual(s1.getsockname(), address)
self.assertEqual(s2.getpeername(), address)
def testMaxName(self):
address = "\x00" + "h" * (self.UNIX_PATH_MAX - 1)
s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
s.bind(address)
self.assertEqual(s.getsockname(), address)
def testNameOverflow(self):
address = "\x00" + "h" * self.UNIX_PATH_MAX
s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.assertRaises(socket.error, s.bind, address)
@unittest.skipUnless(thread, 'Threading required for this test.')
class BufferIOTest(SocketConnectedTest):
"""
Test the buffer versions of socket.recv() and socket.send().
"""
def __init__(self, methodName='runTest'):
SocketConnectedTest.__init__(self, methodName=methodName)
def testRecvIntoArray(self):
buf = array.array('c', ' '*1024)
nbytes = self.cli_conn.recv_into(buf)
self.assertEqual(nbytes, len(MSG))
msg = buf.tostring()[:len(MSG)]
self.assertEqual(msg, MSG)
def _testRecvIntoArray(self):
with test_support.check_py3k_warnings():
buf = buffer(MSG)
self.serv_conn.send(buf)
def testRecvIntoBytearray(self):
buf = bytearray(1024)
nbytes = self.cli_conn.recv_into(buf)
self.assertEqual(nbytes, len(MSG))
msg = buf[:len(MSG)]
self.assertEqual(msg, MSG)
_testRecvIntoBytearray = _testRecvIntoArray
def testRecvIntoMemoryview(self):
buf = bytearray(1024)
nbytes = self.cli_conn.recv_into(memoryview(buf))
self.assertEqual(nbytes, len(MSG))
msg = buf[:len(MSG)]
self.assertEqual(msg, MSG)
_testRecvIntoMemoryview = _testRecvIntoArray
def testRecvFromIntoArray(self):
buf = array.array('c', ' '*1024)
nbytes, addr = self.cli_conn.recvfrom_into(buf)
self.assertEqual(nbytes, len(MSG))
msg = buf.tostring()[:len(MSG)]
self.assertEqual(msg, MSG)
def _testRecvFromIntoArray(self):
with test_support.check_py3k_warnings():
buf = buffer(MSG)
self.serv_conn.send(buf)
def testRecvFromIntoBytearray(self):
buf = bytearray(1024)
nbytes, addr = self.cli_conn.recvfrom_into(buf)
self.assertEqual(nbytes, len(MSG))
msg = buf[:len(MSG)]
self.assertEqual(msg, MSG)
_testRecvFromIntoBytearray = _testRecvFromIntoArray
def testRecvFromIntoMemoryview(self):
buf = bytearray(1024)
nbytes, addr = self.cli_conn.recvfrom_into(memoryview(buf))
self.assertEqual(nbytes, len(MSG))
msg = buf[:len(MSG)]
self.assertEqual(msg, MSG)
_testRecvFromIntoMemoryview = _testRecvFromIntoArray
TIPC_STYPE = 2000
TIPC_LOWER = 200
TIPC_UPPER = 210
def isTipcAvailable():
"""Check if the TIPC module is loaded
The TIPC module is not loaded automatically on Ubuntu and probably
other Linux distros.
"""
if not hasattr(socket, "AF_TIPC"):
return False
if not os.path.isfile("/proc/modules"):
return False
with open("/proc/modules") as f:
for line in f:
if line.startswith("tipc "):
return True
if test_support.verbose:
print "TIPC module is not loaded, please 'sudo modprobe tipc'"
return False
class TIPCTest (unittest.TestCase):
def testRDM(self):
srv = socket.socket(socket.AF_TIPC, socket.SOCK_RDM)
cli = socket.socket(socket.AF_TIPC, socket.SOCK_RDM)
srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
srvaddr = (socket.TIPC_ADDR_NAMESEQ, TIPC_STYPE,
TIPC_LOWER, TIPC_UPPER)
srv.bind(srvaddr)
sendaddr = (socket.TIPC_ADDR_NAME, TIPC_STYPE,
TIPC_LOWER + (TIPC_UPPER - TIPC_LOWER) / 2, 0)
cli.sendto(MSG, sendaddr)
msg, recvaddr = srv.recvfrom(1024)
self.assertEqual(cli.getsockname(), recvaddr)
self.assertEqual(msg, MSG)
class TIPCThreadableTest (unittest.TestCase, ThreadableTest):
def __init__(self, methodName = 'runTest'):
unittest.TestCase.__init__(self, methodName = methodName)
ThreadableTest.__init__(self)
def setUp(self):
self.srv = socket.socket(socket.AF_TIPC, socket.SOCK_STREAM)
self.srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
srvaddr = (socket.TIPC_ADDR_NAMESEQ, TIPC_STYPE,
TIPC_LOWER, TIPC_UPPER)
self.srv.bind(srvaddr)
self.srv.listen(5)
self.serverExplicitReady()
self.conn, self.connaddr = self.srv.accept()
def clientSetUp(self):
# The is a hittable race between serverExplicitReady() and the
# accept() call; sleep a little while to avoid it, otherwise
# we could get an exception
time.sleep(0.1)
self.cli = socket.socket(socket.AF_TIPC, socket.SOCK_STREAM)
addr = (socket.TIPC_ADDR_NAME, TIPC_STYPE,
TIPC_LOWER + (TIPC_UPPER - TIPC_LOWER) / 2, 0)
self.cli.connect(addr)
self.cliaddr = self.cli.getsockname()
def testStream(self):
msg = self.conn.recv(1024)
self.assertEqual(msg, MSG)
self.assertEqual(self.cliaddr, self.connaddr)
def _testStream(self):
self.cli.send(MSG)
self.cli.close()
def test_main():
tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest,
TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest,
UDPTimeoutTest ]
tests.extend([
NonBlockingTCPTests,
FileObjectClassTestCase,
FileObjectInterruptedTestCase,
UnbufferedFileObjectClassTestCase,
LineBufferedFileObjectClassTestCase,
SmallBufferedFileObjectClassTestCase,
Urllib2FileobjectTest,
NetworkConnectionNoServer,
NetworkConnectionAttributesTest,
NetworkConnectionBehaviourTest,
])
if hasattr(socket, "socketpair"):
tests.append(BasicSocketPairTest)
if sys.platform == 'linux2':
tests.append(TestLinuxAbstractNamespace)
if isTipcAvailable():
tests.append(TIPCTest)
tests.append(TIPCThreadableTest)
thread_info = test_support.threading_setup()
test_support.run_unittest(*tests)
test_support.threading_cleanup(*thread_info)
if __name__ == "__main__":
test_main()
# Test the support for SSL and sockets
import sys
import unittest
from test import test_support
import asyncore
import socket
import select
import time
import gc
import os
import errno
import pprint
import urllib, urlparse
import traceback
import weakref
import functools
import platform
from BaseHTTPServer import HTTPServer
from SimpleHTTPServer import SimpleHTTPRequestHandler
ssl = test_support.import_module("ssl")
HOST = test_support.HOST
CERTFILE = None
SVN_PYTHON_ORG_ROOT_CERT = None
def handle_error(prefix):
exc_format = ' '.join(traceback.format_exception(*sys.exc_info()))
if test_support.verbose:
sys.stdout.write(prefix + exc_format)
class BasicTests(unittest.TestCase):
def test_sslwrap_simple(self):
# A crude test for the legacy API
try:
ssl.sslwrap_simple(socket.socket(socket.AF_INET))
except IOError, e:
if e.errno == 32: # broken pipe when ssl_sock.do_handshake(), this test doesn't care about that
pass
else:
raise
try:
ssl.sslwrap_simple(socket.socket(socket.AF_INET)._sock)
except IOError, e:
if e.errno == 32: # broken pipe when ssl_sock.do_handshake(), this test doesn't care about that
pass
else:
raise
# Issue #9415: Ubuntu hijacks their OpenSSL and forcefully disables SSLv2
def skip_if_broken_ubuntu_ssl(func):
if hasattr(ssl, 'PROTOCOL_SSLv2'):
# We need to access the lower-level wrapper in order to create an
# implicit SSL context without trying to connect or listen.
try:
import _ssl
except ImportError:
# The returned function won't get executed, just ignore the error
pass
@functools.wraps(func)
def f(*args, **kwargs):
try:
s = socket.socket(socket.AF_INET)
_ssl.sslwrap(s._sock, 0, None, None,
ssl.CERT_NONE, ssl.PROTOCOL_SSLv2, None, None)
except ssl.SSLError as e:
if (ssl.OPENSSL_VERSION_INFO == (0, 9, 8, 15, 15) and
platform.linux_distribution() == ('debian', 'squeeze/sid', '')
and 'Invalid SSL protocol variant specified' in str(e)):
raise unittest.SkipTest("Patched Ubuntu OpenSSL breaks behaviour")
return func(*args, **kwargs)
return f
else:
return func
class BasicSocketTests(unittest.TestCase):
def test_constants(self):
#ssl.PROTOCOL_SSLv2
ssl.PROTOCOL_SSLv23
ssl.PROTOCOL_SSLv3
ssl.PROTOCOL_TLSv1
ssl.CERT_NONE
ssl.CERT_OPTIONAL
ssl.CERT_REQUIRED
def test_random(self):
v = ssl.RAND_status()
if test_support.verbose:
sys.stdout.write("\n RAND_status is %d (%s)\n"
% (v, (v and "sufficient randomness") or
"insufficient randomness"))
try:
ssl.RAND_egd(1)
except TypeError:
pass
else:
print "didn't raise TypeError"
ssl.RAND_add("this is a random string", 75.0)
def test_parse_cert(self):
# note that this uses an 'unofficial' function in _ssl.c,
# provided solely for this test, to exercise the certificate
# parsing code
p = ssl._ssl._test_decode_cert(CERTFILE, False)
if test_support.verbose:
sys.stdout.write("\n" + pprint.pformat(p) + "\n")
self.assertEqual(p['subject'],
((('countryName', u'US'),),
(('stateOrProvinceName', u'Delaware'),),
(('localityName', u'Wilmington'),),
(('organizationName', u'Python Software Foundation'),),
(('organizationalUnitName', u'SSL'),),
(('commonName', u'somemachine.python.org'),)),
)
# Issue #13034: the subjectAltName in some certificates
# (notably projects.developer.nokia.com:443) wasn't parsed
p = ssl._ssl._test_decode_cert(NOKIACERT)
if test_support.verbose:
sys.stdout.write("\n" + pprint.pformat(p) + "\n")
self.assertEqual(p['subjectAltName'],
(('DNS', 'projects.developer.nokia.com'),
('DNS', 'projects.forum.nokia.com'))
)
def test_DER_to_PEM(self):
with open(SVN_PYTHON_ORG_ROOT_CERT, 'r') as f:
pem = f.read()
d1 = ssl.PEM_cert_to_DER_cert(pem)
p2 = ssl.DER_cert_to_PEM_cert(d1)
d2 = ssl.PEM_cert_to_DER_cert(p2)
self.assertEqual(d1, d2)
if not p2.startswith(ssl.PEM_HEADER + '\n'):
self.fail("DER-to-PEM didn't include correct header:\n%r\n" % p2)
if not p2.endswith('\n' + ssl.PEM_FOOTER + '\n'):
self.fail("DER-to-PEM didn't include correct footer:\n%r\n" % p2)
def test_openssl_version(self):
n = ssl.OPENSSL_VERSION_NUMBER
t = ssl.OPENSSL_VERSION_INFO
s = ssl.OPENSSL_VERSION
self.assertIsInstance(n, (int, long))
self.assertIsInstance(t, tuple)
self.assertIsInstance(s, str)
# Some sanity checks follow
# >= 0.9
self.assertGreaterEqual(n, 0x900000)
# < 2.0
self.assertLess(n, 0x20000000)
major, minor, fix, patch, status = t
self.assertGreaterEqual(major, 0)
self.assertLess(major, 2)
self.assertGreaterEqual(minor, 0)
self.assertLess(minor, 256)
self.assertGreaterEqual(fix, 0)
self.assertLess(fix, 256)
self.assertGreaterEqual(patch, 0)
self.assertLessEqual(patch, 26)
self.assertGreaterEqual(status, 0)
self.assertLessEqual(status, 15)
# Version string as returned by OpenSSL, the format might change
self.assertTrue(s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, fix)),
(s, t))
def test_ciphers(self):
if not test_support.is_resource_enabled('network'):
return
remote = ("svn.python.org", 443)
with test_support.transient_internet(remote[0]):
s = ssl.wrap_socket(socket.socket(socket.AF_INET),
cert_reqs=ssl.CERT_NONE, ciphers="ALL")
s.connect(remote)
s = ssl.wrap_socket(socket.socket(socket.AF_INET),
cert_reqs=ssl.CERT_NONE, ciphers="DEFAULT")
s.connect(remote)
# Error checking occurs when connecting, because the SSL context
# isn't created before.
s = ssl.wrap_socket(socket.socket(socket.AF_INET),
cert_reqs=ssl.CERT_NONE, ciphers="^$:,;?*'dorothyx")
with self.assertRaisesRegexp(ssl.SSLError, "No cipher can be selected"):
s.connect(remote)
@test_support.cpython_only
def test_refcycle(self):
# Issue #7943: an SSL object doesn't create reference cycles with
# itself.
s = socket.socket(socket.AF_INET)
ss = ssl.wrap_socket(s)
wr = weakref.ref(ss)
del ss
self.assertEqual(wr(), None)
def test_wrapped_unconnected(self):
# The _delegate_methods in socket.py are correctly delegated to by an
# unconnected SSLSocket, so they will raise a socket.error rather than
# something unexpected like TypeError.
s = socket.socket(socket.AF_INET)
ss = ssl.wrap_socket(s)
self.assertRaises(socket.error, ss.recv, 1)
self.assertRaises(socket.error, ss.recv_into, bytearray(b'x'))
self.assertRaises(socket.error, ss.recvfrom, 1)
self.assertRaises(socket.error, ss.recvfrom_into, bytearray(b'x'), 1)
self.assertRaises(socket.error, ss.send, b'x')
self.assertRaises(socket.error, ss.sendto, b'x', ('0.0.0.0', 0))
class NetworkedTests(unittest.TestCase):
def test_connect(self):
with test_support.transient_internet("svn.python.org"):
s = ssl.wrap_socket(socket.socket(socket.AF_INET),
cert_reqs=ssl.CERT_NONE)
s.connect(("svn.python.org", 443))
c = s.getpeercert()
if c:
self.fail("Peer cert %s shouldn't be here!")
s.close()
# this should fail because we have no verification certs
s = ssl.wrap_socket(socket.socket(socket.AF_INET),
cert_reqs=ssl.CERT_REQUIRED)
try:
s.connect(("svn.python.org", 443))
except ssl.SSLError:
pass
finally:
s.close()
# this should succeed because we specify the root cert
s = ssl.wrap_socket(socket.socket(socket.AF_INET),
cert_reqs=ssl.CERT_REQUIRED,
ca_certs=SVN_PYTHON_ORG_ROOT_CERT)
try:
s.connect(("svn.python.org", 443))
finally:
s.close()
def test_connect_ex(self):
# Issue #11326: check connect_ex() implementation
with test_support.transient_internet("svn.python.org"):
s = ssl.wrap_socket(socket.socket(socket.AF_INET),
cert_reqs=ssl.CERT_REQUIRED,
ca_certs=SVN_PYTHON_ORG_ROOT_CERT)
try:
self.assertEqual(0, s.connect_ex(("svn.python.org", 443)))
self.assertTrue(s.getpeercert())
finally:
s.close()
def test_non_blocking_connect_ex(self):
# Issue #11326: non-blocking connect_ex() should allow handshake
# to proceed after the socket gets ready.
with test_support.transient_internet("svn.python.org"):
s = ssl.wrap_socket(socket.socket(socket.AF_INET),
cert_reqs=ssl.CERT_REQUIRED,
ca_certs=SVN_PYTHON_ORG_ROOT_CERT,
do_handshake_on_connect=False)
try:
s.setblocking(False)
rc = s.connect_ex(('svn.python.org', 443))
# EWOULDBLOCK under Windows, EINPROGRESS elsewhere
self.assertIn(rc, (0, errno.EINPROGRESS, errno.EWOULDBLOCK))
# Wait for connect to finish
select.select([], [s], [], 5.0)
# Non-blocking handshake
while True:
try:
s.do_handshake()
break
except ssl.SSLError as err:
if err.args[0] == ssl.SSL_ERROR_WANT_READ:
select.select([s], [], [], 5.0)
elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
select.select([], [s], [], 5.0)
else:
raise
# SSL established
self.assertTrue(s.getpeercert())
finally:
s.close()
@unittest.skipIf(os.name == "nt", "Can't use a socket as a file under Windows")
def test_makefile_close(self):
# Issue #5238: creating a file-like object with makefile() shouldn't
# delay closing the underlying "real socket" (here tested with its
# file descriptor, hence skipping the test under Windows).
with test_support.transient_internet("svn.python.org"):
ss = ssl.wrap_socket(socket.socket(socket.AF_INET))
ss.connect(("svn.python.org", 443))
fd = ss.fileno()
f = ss.makefile()
f.close()
# The fd is still open
os.read(fd, 0)
# Closing the SSL socket should close the fd too
ss.close()
gc.collect()
with self.assertRaises(OSError) as e:
os.read(fd, 0)
self.assertEqual(e.exception.errno, errno.EBADF)
def test_non_blocking_handshake(self):
with test_support.transient_internet("svn.python.org"):
s = socket.socket(socket.AF_INET)
s.connect(("svn.python.org", 443))
s.setblocking(False)
s = ssl.wrap_socket(s,
cert_reqs=ssl.CERT_NONE,
do_handshake_on_connect=False)
count = 0
while True:
try:
count += 1
s.do_handshake()
break
except ssl.SSLError, err:
if err.args[0] == ssl.SSL_ERROR_WANT_READ:
select.select([s], [], [])
elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
select.select([], [s], [])
else:
raise
s.close()
if test_support.verbose:
sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count)
def test_get_server_certificate(self):
with test_support.transient_internet("svn.python.org"):
pem = ssl.get_server_certificate(("svn.python.org", 443))
if not pem:
self.fail("No server certificate on svn.python.org:443!")
try:
pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=CERTFILE)
except ssl.SSLError:
#should fail
pass
else:
self.fail("Got server certificate %s for svn.python.org!" % pem)
pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=SVN_PYTHON_ORG_ROOT_CERT)
if not pem:
self.fail("No server certificate on svn.python.org:443!")
if test_support.verbose:
sys.stdout.write("\nVerified certificate for svn.python.org:443 is\n%s\n" % pem)
def test_algorithms(self):
# Issue #8484: all algorithms should be available when verifying a
# certificate.
# SHA256 was added in OpenSSL 0.9.8
if ssl.OPENSSL_VERSION_INFO < (0, 9, 8, 0, 15):
self.skipTest("SHA256 not available on %r" % ssl.OPENSSL_VERSION)
# NOTE: https://sha256.tbs-internet.com is another possible test host
remote = ("sha256.tbs-internet.com", 443)
sha256_cert = os.path.join(os.path.dirname(__file__), "sha256.pem")
with test_support.transient_internet("sha256.tbs-internet.com"):
s = ssl.wrap_socket(socket.socket(socket.AF_INET),
cert_reqs=ssl.CERT_REQUIRED,
ca_certs=sha256_cert,)
try:
s.connect(remote)
if test_support.verbose:
sys.stdout.write("\nCipher with %r is %r\n" %
(remote, s.cipher()))
sys.stdout.write("Certificate is:\n%s\n" %
pprint.pformat(s.getpeercert()))
finally:
s.close()
try:
import threading
except ImportError:
_have_threads = False
else:
_have_threads = True
class ThreadedEchoServer(threading.Thread):
class ConnectionHandler(threading.Thread):
"""A mildly complicated class, because we want it to work both
with and without the SSL wrapper around the socket connection, so
that we can test the STARTTLS functionality."""
def __init__(self, server, connsock):
self.server = server
self.running = False
self.sock = connsock
self.sock.setblocking(1)
self.sslconn = None
threading.Thread.__init__(self)
self.daemon = True
def show_conn_details(self):
if self.server.certreqs == ssl.CERT_REQUIRED:
cert = self.sslconn.getpeercert()
if test_support.verbose and self.server.chatty:
sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n")
cert_binary = self.sslconn.getpeercert(True)
if test_support.verbose and self.server.chatty:
sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n")
cipher = self.sslconn.cipher()
if test_support.verbose and self.server.chatty:
sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n")
def wrap_conn(self):
try:
self.sslconn = ssl.wrap_socket(self.sock, server_side=True,
certfile=self.server.certificate,
ssl_version=self.server.protocol,
ca_certs=self.server.cacerts,
cert_reqs=self.server.certreqs,
ciphers=self.server.ciphers)
except ssl.SSLError as e:
# XXX Various errors can have happened here, for example
# a mismatching protocol version, an invalid certificate,
# or a low-level bug. This should be made more discriminating.
self.server.conn_errors.append(e)
if self.server.chatty:
handle_error("\n server: bad connection attempt from " +
str(self.sock.getpeername()) + ":\n")
self.close()
self.running = False
self.server.stop()
return False
else:
return True
def read(self):
if self.sslconn:
return self.sslconn.read()
else:
return self.sock.recv(1024)
def write(self, bytes):
if self.sslconn:
return self.sslconn.write(bytes)
else:
return self.sock.send(bytes)
def close(self):
if self.sslconn:
self.sslconn.close()
else:
self.sock._sock.close()
def run(self):
self.running = True
if not self.server.starttls_server:
if isinstance(self.sock, ssl.SSLSocket):
self.sslconn = self.sock
elif not self.wrap_conn():
return
self.show_conn_details()
while self.running:
try:
msg = self.read()
if not msg:
# eof, so quit this handler
self.running = False
self.close()
elif msg.strip() == 'over':
if test_support.verbose and self.server.connectionchatty:
sys.stdout.write(" server: client closed connection\n")
self.close()
return
elif self.server.starttls_server and msg.strip() == 'STARTTLS':
if test_support.verbose and self.server.connectionchatty:
sys.stdout.write(" server: read STARTTLS from client, sending OK...\n")
self.write("OK\n")
if not self.wrap_conn():
return
elif self.server.starttls_server and self.sslconn and msg.strip() == 'ENDTLS':
if test_support.verbose and self.server.connectionchatty:
sys.stdout.write(" server: read ENDTLS from client, sending OK...\n")
self.write("OK\n")
self.sslconn.unwrap()
self.sslconn = None
if test_support.verbose and self.server.connectionchatty:
sys.stdout.write(" server: connection is now unencrypted...\n")
else:
if (test_support.verbose and
self.server.connectionchatty):
ctype = (self.sslconn and "encrypted") or "unencrypted"
sys.stdout.write(" server: read %s (%s), sending back %s (%s)...\n"
% (repr(msg), ctype, repr(msg.lower()), ctype))
self.write(msg.lower())
except ssl.SSLError:
if self.server.chatty:
handle_error("Test server failure:\n")
self.close()
self.running = False
# normally, we'd just stop here, but for the test
# harness, we want to stop the server
self.server.stop()
def __init__(self, certificate, ssl_version=None,
certreqs=None, cacerts=None,
chatty=True, connectionchatty=False, starttls_server=False,
wrap_accepting_socket=False, ciphers=None):
if ssl_version is None:
ssl_version = ssl.PROTOCOL_TLSv1
if certreqs is None:
certreqs = ssl.CERT_NONE
self.certificate = certificate
self.protocol = ssl_version
self.certreqs = certreqs
self.cacerts = cacerts
self.ciphers = ciphers
self.chatty = chatty
self.connectionchatty = connectionchatty
self.starttls_server = starttls_server
self.sock = socket.socket()
self.flag = None
if wrap_accepting_socket:
self.sock = ssl.wrap_socket(self.sock, server_side=True,
certfile=self.certificate,
cert_reqs = self.certreqs,
ca_certs = self.cacerts,
ssl_version = self.protocol,
ciphers = self.ciphers)
if test_support.verbose and self.chatty:
sys.stdout.write(' server: wrapped server socket as %s\n' % str(self.sock))
self.port = test_support.bind_port(self.sock)
self.active = False
self.conn_errors = []
threading.Thread.__init__(self)
self.daemon = True
def __enter__(self):
self.start(threading.Event())
self.flag.wait()
return self
def __exit__(self, *args):
self.stop()
self.join()
def start(self, flag=None):
self.flag = flag
threading.Thread.start(self)
def run(self):
self.sock.settimeout(0.05)
self.sock.listen(5)
self.active = True
if self.flag:
# signal an event
self.flag.set()
while self.active:
try:
newconn, connaddr = self.sock.accept()
if test_support.verbose and self.chatty:
sys.stdout.write(' server: new connection from '
+ str(connaddr) + '\n')
handler = self.ConnectionHandler(self, newconn)
handler.start()
handler.join()
except socket.timeout:
pass
except KeyboardInterrupt:
self.stop()
self.sock.close()
def stop(self):
self.active = False
class AsyncoreEchoServer(threading.Thread):
class EchoServer(asyncore.dispatcher):
class ConnectionHandler(asyncore.dispatcher_with_send):
def __init__(self, conn, certfile):
asyncore.dispatcher_with_send.__init__(self, conn)
self.socket = ssl.wrap_socket(conn, server_side=True,
certfile=certfile,
do_handshake_on_connect=False)
self._ssl_accepting = True
def readable(self):
if isinstance(self.socket, ssl.SSLSocket):
while self.socket.pending() > 0:
self.handle_read_event()
return True
def _do_ssl_handshake(self):
try:
self.socket.do_handshake()
except ssl.SSLError, err:
if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
ssl.SSL_ERROR_WANT_WRITE):
return
elif err.args[0] == ssl.SSL_ERROR_EOF:
return self.handle_close()
raise
except socket.error, err:
if err.args[0] == errno.ECONNABORTED:
return self.handle_close()
else:
self._ssl_accepting = False
def handle_read(self):
if self._ssl_accepting:
self._do_ssl_handshake()
else:
data = self.recv(1024)
if data and data.strip() != 'over':
self.send(data.lower())
def handle_close(self):
self.close()
if test_support.verbose:
sys.stdout.write(" server: closed connection %s\n" % self.socket)
def handle_error(self):
raise
def __init__(self, certfile):
self.certfile = certfile
asyncore.dispatcher.__init__(self)
self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
self.port = test_support.bind_port(self.socket)
self.listen(5)
def handle_accept(self):
sock_obj, addr = self.accept()
if test_support.verbose:
sys.stdout.write(" server: new connection from %s:%s\n" %addr)
self.ConnectionHandler(sock_obj, self.certfile)
def handle_error(self):
raise
def __init__(self, certfile):
self.flag = None
self.active = False
self.server = self.EchoServer(certfile)
self.port = self.server.port
threading.Thread.__init__(self)
self.daemon = True
def __str__(self):
return "<%s %s>" % (self.__class__.__name__, self.server)
def __enter__(self):
self.start(threading.Event())
self.flag.wait()
return self
def __exit__(self, *args):
if test_support.verbose:
sys.stdout.write(" cleanup: stopping server.\n")
self.stop()
if test_support.verbose:
sys.stdout.write(" cleanup: joining server thread.\n")
self.join()
if test_support.verbose:
sys.stdout.write(" cleanup: successfully joined.\n")
def start(self, flag=None):
self.flag = flag
threading.Thread.start(self)
def run(self):
self.active = True
if self.flag:
self.flag.set()
while self.active:
asyncore.loop(0.05)
def stop(self):
self.active = False
self.server.close()
class SocketServerHTTPSServer(threading.Thread):
class HTTPSServer(HTTPServer):
def __init__(self, server_address, RequestHandlerClass, certfile):
HTTPServer.__init__(self, server_address, RequestHandlerClass)
# we assume the certfile contains both private key and certificate
self.certfile = certfile
self.allow_reuse_address = True
def __str__(self):
return ('<%s %s:%s>' %
(self.__class__.__name__,
self.server_name,
self.server_port))
def get_request(self):
# override this to wrap socket with SSL
sock, addr = self.socket.accept()
sslconn = ssl.wrap_socket(sock, server_side=True,
certfile=self.certfile)
return sslconn, addr
class RootedHTTPRequestHandler(SimpleHTTPRequestHandler):
# need to override translate_path to get a known root,
# instead of using os.curdir, since the test could be
# run from anywhere
server_version = "TestHTTPS/1.0"
root = None
def translate_path(self, path):
"""Translate a /-separated PATH to the local filename syntax.
Components that mean special things to the local file system
(e.g. drive or directory names) are ignored. (XXX They should
probably be diagnosed.)
"""
# abandon query parameters
path = urlparse.urlparse(path)[2]
path = os.path.normpath(urllib.unquote(path))
words = path.split('/')
words = filter(None, words)
path = self.root
for word in words:
drive, word = os.path.splitdrive(word)
head, word = os.path.split(word)
if word in self.root: continue
path = os.path.join(path, word)
return path
def log_message(self, format, *args):
# we override this to suppress logging unless "verbose"
if test_support.verbose:
sys.stdout.write(" server (%s:%d %s):\n [%s] %s\n" %
(self.server.server_address,
self.server.server_port,
self.request.cipher(),
self.log_date_time_string(),
format%args))
def __init__(self, certfile):
self.flag = None
self.RootedHTTPRequestHandler.root = os.path.split(CERTFILE)[0]
self.server = self.HTTPSServer(
(HOST, 0), self.RootedHTTPRequestHandler, certfile)
self.port = self.server.server_port
threading.Thread.__init__(self)
self.daemon = True
def __str__(self):
return "<%s %s>" % (self.__class__.__name__, self.server)
def start(self, flag=None):
self.flag = flag
threading.Thread.start(self)
def run(self):
if self.flag:
self.flag.set()
self.server.serve_forever(0.05)
def stop(self):
self.server.shutdown()
def bad_cert_test(certfile):
"""
Launch a server with CERT_REQUIRED, and check that trying to
connect to it with the given client certificate fails.
"""
server = ThreadedEchoServer(CERTFILE,
certreqs=ssl.CERT_REQUIRED,
cacerts=CERTFILE, chatty=False)
with server:
try:
s = ssl.wrap_socket(socket.socket(),
certfile=certfile,
ssl_version=ssl.PROTOCOL_TLSv1)
s.connect((HOST, server.port))
except ssl.SSLError, x:
if test_support.verbose:
sys.stdout.write("\nSSLError is %s\n" % x[1])
except socket.error, x:
if test_support.verbose:
sys.stdout.write("\nsocket.error is %s\n" % x[1])
else:
raise AssertionError("Use of invalid cert should have failed!")
def server_params_test(certfile, protocol, certreqs, cacertsfile,
client_certfile, client_protocol=None, indata="FOO\n",
ciphers=None, chatty=True, connectionchatty=False,
wrap_accepting_socket=False):
"""
Launch a server, connect a client to it and try various reads
and writes.
"""
server = ThreadedEchoServer(certfile,
certreqs=certreqs,
ssl_version=protocol,
cacerts=cacertsfile,
ciphers=ciphers,
chatty=chatty,
connectionchatty=connectionchatty,
wrap_accepting_socket=wrap_accepting_socket)
with server:
# try to connect
if client_protocol is None:
client_protocol = protocol
s = ssl.wrap_socket(socket.socket(),
certfile=client_certfile,
ca_certs=cacertsfile,
ciphers=ciphers,
cert_reqs=certreqs,
ssl_version=client_protocol)
s.connect((HOST, server.port))
for arg in [indata, bytearray(indata), memoryview(indata)]:
if connectionchatty:
if test_support.verbose:
sys.stdout.write(
" client: sending %s...\n" % (repr(arg)))
s.write(arg)
outdata = s.read()
if connectionchatty:
if test_support.verbose:
sys.stdout.write(" client: read %s\n" % repr(outdata))
if outdata != indata.lower():
raise AssertionError(
"bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
% (outdata[:min(len(outdata),20)], len(outdata),
indata[:min(len(indata),20)].lower(), len(indata)))
s.write("over\n")
if connectionchatty:
if test_support.verbose:
sys.stdout.write(" client: closing connection.\n")
s.close()
def try_protocol_combo(server_protocol,
client_protocol,
expect_success,
certsreqs=None):
if certsreqs is None:
certsreqs = ssl.CERT_NONE
certtype = {
ssl.CERT_NONE: "CERT_NONE",
ssl.CERT_OPTIONAL: "CERT_OPTIONAL",
ssl.CERT_REQUIRED: "CERT_REQUIRED",
}[certsreqs]
if test_support.verbose:
formatstr = (expect_success and " %s->%s %s\n") or " {%s->%s} %s\n"
sys.stdout.write(formatstr %
(ssl.get_protocol_name(client_protocol),
ssl.get_protocol_name(server_protocol),
certtype))
try:
# NOTE: we must enable "ALL" ciphers, otherwise an SSLv23 client
# will send an SSLv3 hello (rather than SSLv2) starting from
# OpenSSL 1.0.0 (see issue #8322).
server_params_test(CERTFILE, server_protocol, certsreqs,
CERTFILE, CERTFILE, client_protocol,
ciphers="ALL", chatty=False)
# Protocol mismatch can result in either an SSLError, or a
# "Connection reset by peer" error.
except ssl.SSLError:
if expect_success:
raise
except socket.error as e:
if expect_success or e.errno != errno.ECONNRESET:
raise
else:
if not expect_success:
raise AssertionError(
"Client protocol %s succeeded with server protocol %s!"
% (ssl.get_protocol_name(client_protocol),
ssl.get_protocol_name(server_protocol)))
class ThreadedTests(unittest.TestCase):
def test_rude_shutdown(self):
"""A brutal shutdown of an SSL server should raise an IOError
in the client when attempting handshake.
"""
listener_ready = threading.Event()
listener_gone = threading.Event()
s = socket.socket()
port = test_support.bind_port(s, HOST)
# `listener` runs in a thread. It sits in an accept() until
# the main thread connects. Then it rudely closes the socket,
# and sets Event `listener_gone` to let the main thread know
# the socket is gone.
def listener():
s.listen(5)
listener_ready.set()
s.accept()
s.close()
listener_gone.set()
def connector():
listener_ready.wait()
c = socket.socket()
c.connect((HOST, port))
listener_gone.wait()
try:
ssl_sock = ssl.wrap_socket(c)
except IOError:
pass
else:
self.fail('connecting to closed SSL socket should have failed')
t = threading.Thread(target=listener)
t.start()
try:
connector()
finally:
t.join()
@skip_if_broken_ubuntu_ssl
def test_echo(self):
"""Basic test of an SSL client connecting to a server"""
if test_support.verbose:
sys.stdout.write("\n")
server_params_test(CERTFILE, ssl.PROTOCOL_TLSv1, ssl.CERT_NONE,
CERTFILE, CERTFILE, ssl.PROTOCOL_TLSv1,
chatty=True, connectionchatty=True)
def test_getpeercert(self):
if test_support.verbose:
sys.stdout.write("\n")
s2 = socket.socket()
server = ThreadedEchoServer(CERTFILE,
certreqs=ssl.CERT_NONE,
ssl_version=ssl.PROTOCOL_SSLv23,
cacerts=CERTFILE,
chatty=False)
with server:
s = ssl.wrap_socket(socket.socket(),
certfile=CERTFILE,
ca_certs=CERTFILE,
cert_reqs=ssl.CERT_REQUIRED,
ssl_version=ssl.PROTOCOL_SSLv23)
s.connect((HOST, server.port))
cert = s.getpeercert()
self.assertTrue(cert, "Can't get peer certificate.")
cipher = s.cipher()
if test_support.verbose:
sys.stdout.write(pprint.pformat(cert) + '\n')
sys.stdout.write("Connection cipher is " + str(cipher) + '.\n')
if 'subject' not in cert:
self.fail("No subject field in certificate: %s." %
pprint.pformat(cert))
if ((('organizationName', 'Python Software Foundation'),)
not in cert['subject']):
self.fail(
"Missing or invalid 'organizationName' field in certificate subject; "
"should be 'Python Software Foundation'.")
s.close()
def test_empty_cert(self):
"""Connecting with an empty cert file"""
bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir,
"nullcert.pem"))
def test_malformed_cert(self):
"""Connecting with a badly formatted certificate (syntax error)"""
bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir,
"badcert.pem"))
def test_nonexisting_cert(self):
"""Connecting with a non-existing cert file"""
bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir,
"wrongcert.pem"))
def test_malformed_key(self):
"""Connecting with a badly formatted key (syntax error)"""
bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir,
"badkey.pem"))
@skip_if_broken_ubuntu_ssl
def test_protocol_sslv2(self):
"""Connecting to an SSLv2 server with various client options"""
if test_support.verbose:
sys.stdout.write("\n")
if not hasattr(ssl, 'PROTOCOL_SSLv2'):
self.skipTest("PROTOCOL_SSLv2 needed")
try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True)
try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL)
try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED)
try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, True)
try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False)
try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False)
@skip_if_broken_ubuntu_ssl
def test_protocol_sslv23(self):
"""Connecting to an SSLv23 server with various client options"""
if test_support.verbose:
sys.stdout.write("\n")
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True)
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True)
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True)
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL)
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL)
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL)
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED)
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED)
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED)
@skip_if_broken_ubuntu_ssl
def test_protocol_sslv3(self):
"""Connecting to an SSLv3 server with various client options"""
if test_support.verbose:
sys.stdout.write("\n")
try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True)
try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL)
try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED)
if hasattr(ssl, 'PROTOCOL_SSLv2'):
try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False)
try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False)
@skip_if_broken_ubuntu_ssl
def test_protocol_tlsv1(self):
"""Connecting to a TLSv1 server with various client options"""
if test_support.verbose:
sys.stdout.write("\n")
try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True)
try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL)
try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED)
if hasattr(ssl, 'PROTOCOL_SSLv2'):
try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False)
try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False)
def test_starttls(self):
"""Switching from clear text to encrypted and back again."""
msgs = ("msg 1", "MSG 2", "STARTTLS", "MSG 3", "msg 4", "ENDTLS", "msg 5", "msg 6")
server = ThreadedEchoServer(CERTFILE,
ssl_version=ssl.PROTOCOL_TLSv1,
starttls_server=True,
chatty=True,
connectionchatty=True)
wrapped = False
with server:
s = socket.socket()
s.setblocking(1)
s.connect((HOST, server.port))
if test_support.verbose:
sys.stdout.write("\n")
for indata in msgs:
if test_support.verbose:
sys.stdout.write(
" client: sending %s...\n" % repr(indata))
if wrapped:
conn.write(indata)
outdata = conn.read()
else:
s.send(indata)
outdata = s.recv(1024)
if (indata == "STARTTLS" and
outdata.strip().lower().startswith("ok")):
# STARTTLS ok, switch to secure mode
if test_support.verbose:
sys.stdout.write(
" client: read %s from server, starting TLS...\n"
% repr(outdata))
conn = ssl.wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1)
wrapped = True
elif (indata == "ENDTLS" and
outdata.strip().lower().startswith("ok")):
# ENDTLS ok, switch back to clear text
if test_support.verbose:
sys.stdout.write(
" client: read %s from server, ending TLS...\n"
% repr(outdata))
s = conn.unwrap()
wrapped = False
else:
if test_support.verbose:
sys.stdout.write(
" client: read %s from server\n" % repr(outdata))
if test_support.verbose:
sys.stdout.write(" client: closing connection.\n")
if wrapped:
conn.write("over\n")
else:
s.send("over\n")
s.close()
def test_socketserver(self):
"""Using a SocketServer to create and manage SSL connections."""
server = SocketServerHTTPSServer(CERTFILE)
flag = threading.Event()
server.start(flag)
# wait for it to start
flag.wait()
# try to connect
try:
if test_support.verbose:
sys.stdout.write('\n')
with open(CERTFILE, 'rb') as f:
d1 = f.read()
d2 = ''
# now fetch the same data from the HTTPS server
url = 'https://127.0.0.1:%d/%s' % (
server.port, os.path.split(CERTFILE)[1])
with test_support.check_py3k_warnings():
f = urllib.urlopen(url)
dlen = f.info().getheader("content-length")
if dlen and (int(dlen) > 0):
d2 = f.read(int(dlen))
if test_support.verbose:
sys.stdout.write(
" client: read %d bytes from remote server '%s'\n"
% (len(d2), server))
f.close()
self.assertEqual(d1, d2)
finally:
server.stop()
server.join()
def test_wrapped_accept(self):
"""Check the accept() method on SSL sockets."""
if test_support.verbose:
sys.stdout.write("\n")
server_params_test(CERTFILE, ssl.PROTOCOL_SSLv23, ssl.CERT_REQUIRED,
CERTFILE, CERTFILE, ssl.PROTOCOL_SSLv23,
chatty=True, connectionchatty=True,
wrap_accepting_socket=True)
def test_asyncore_server(self):
"""Check the example asyncore integration."""
indata = "TEST MESSAGE of mixed case\n"
if test_support.verbose:
sys.stdout.write("\n")
server = AsyncoreEchoServer(CERTFILE)
with server:
s = ssl.wrap_socket(socket.socket())
s.connect(('127.0.0.1', server.port))
if test_support.verbose:
sys.stdout.write(
" client: sending %s...\n" % (repr(indata)))
s.write(indata)
outdata = s.read()
if test_support.verbose:
sys.stdout.write(" client: read %s\n" % repr(outdata))
if outdata != indata.lower():
self.fail(
"bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
% (outdata[:min(len(outdata),20)], len(outdata),
indata[:min(len(indata),20)].lower(), len(indata)))
s.write("over\n")
if test_support.verbose:
sys.stdout.write(" client: closing connection.\n")
s.close()
def test_recv_send(self):
"""Test recv(), send() and friends."""
if test_support.verbose:
sys.stdout.write("\n")
server = ThreadedEchoServer(CERTFILE,
certreqs=ssl.CERT_NONE,
ssl_version=ssl.PROTOCOL_TLSv1,
cacerts=CERTFILE,
chatty=True,
connectionchatty=False)
with server:
s = ssl.wrap_socket(socket.socket(),
server_side=False,
certfile=CERTFILE,
ca_certs=CERTFILE,
cert_reqs=ssl.CERT_NONE,
ssl_version=ssl.PROTOCOL_TLSv1)
s.connect((HOST, server.port))
# helper methods for standardising recv* method signatures
def _recv_into():
b = bytearray("\0"*100)
count = s.recv_into(b)
return b[:count]
def _recvfrom_into():
b = bytearray("\0"*100)
count, addr = s.recvfrom_into(b)
return b[:count]
# (name, method, whether to expect success, *args)
send_methods = [
('send', s.send, True, []),
('sendto', s.sendto, False, ["some.address"]),
('sendall', s.sendall, True, []),
]
recv_methods = [
('recv', s.recv, True, []),
('recvfrom', s.recvfrom, False, ["some.address"]),
('recv_into', _recv_into, True, []),
('recvfrom_into', _recvfrom_into, False, []),
]
data_prefix = u"PREFIX_"
for meth_name, send_meth, expect_success, args in send_methods:
indata = data_prefix + meth_name
try:
send_meth(indata.encode('ASCII', 'strict'), *args)
outdata = s.read()
outdata = outdata.decode('ASCII', 'strict')
if outdata != indata.lower():
self.fail(
"While sending with <<%s>> bad data "
"<<%r>> (%d) received; "
"expected <<%r>> (%d)\n" % (
meth_name, outdata[:20], len(outdata),
indata[:20], len(indata)
)
)
except ValueError as e:
if expect_success:
self.fail(
"Failed to send with method <<%s>>; "
"expected to succeed.\n" % (meth_name,)
)
if not str(e).startswith(meth_name):
self.fail(
"Method <<%s>> failed with unexpected "
"exception message: %s\n" % (
meth_name, e
)
)
for meth_name, recv_meth, expect_success, args in recv_methods:
indata = data_prefix + meth_name
try:
s.send(indata.encode('ASCII', 'strict'))
outdata = recv_meth(*args)
outdata = outdata.decode('ASCII', 'strict')
if outdata != indata.lower():
self.fail(
"While receiving with <<%s>> bad data "
"<<%r>> (%d) received; "
"expected <<%r>> (%d)\n" % (
meth_name, outdata[:20], len(outdata),
indata[:20], len(indata)
)
)
except ValueError as e:
if expect_success:
self.fail(
"Failed to receive with method <<%s>>; "
"expected to succeed.\n" % (meth_name,)
)
if not str(e).startswith(meth_name):
self.fail(
"Method <<%s>> failed with unexpected "
"exception message: %s\n" % (
meth_name, e
)
)
# consume data
s.read()
s.write("over\n".encode("ASCII", "strict"))
s.close()
def test_handshake_timeout(self):
# Issue #5103: SSL handshake must respect the socket timeout
server = socket.socket(socket.AF_INET)
host = "127.0.0.1"
port = test_support.bind_port(server)
started = threading.Event()
finish = False
def serve():
server.listen(5)
started.set()
conns = []
while not finish:
r, w, e = select.select([server], [], [], 0.1)
if server in r:
# Let the socket hang around rather than having
# it closed by garbage collection.
conns.append(server.accept()[0])
t = threading.Thread(target=serve)
t.start()
started.wait()
try:
try:
c = socket.socket(socket.AF_INET)
c.settimeout(0.2)
c.connect((host, port))
# Will attempt handshake and time out
self.assertRaisesRegexp(ssl.SSLError, "timed out",
ssl.wrap_socket, c)
finally:
c.close()
try:
c = socket.socket(socket.AF_INET)
c.settimeout(0.2)
c = ssl.wrap_socket(c)
# Will attempt handshake and time out
self.assertRaisesRegexp(ssl.SSLError, "timed out",
c.connect, (host, port))
finally:
c.close()
finally:
finish = True
t.join()
server.close()
def test_default_ciphers(self):
with ThreadedEchoServer(CERTFILE,
ssl_version=ssl.PROTOCOL_SSLv23,
chatty=False) as server:
sock = socket.socket()
try:
# Force a set of weak ciphers on our client socket
try:
s = ssl.wrap_socket(sock,
ssl_version=ssl.PROTOCOL_SSLv23,
ciphers="DES")
except ssl.SSLError:
self.skipTest("no DES cipher available")
with self.assertRaises((OSError, ssl.SSLError)):
s.connect((HOST, server.port))
finally:
sock.close()
self.assertIn("no shared cipher", str(server.conn_errors[0]))
def test_main(verbose=False):
global CERTFILE, SVN_PYTHON_ORG_ROOT_CERT, NOKIACERT
CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir,
"keycert.pem")
SVN_PYTHON_ORG_ROOT_CERT = os.path.join(
os.path.dirname(__file__) or os.curdir,
"https_svn_python_org_root.pem")
NOKIACERT = os.path.join(os.path.dirname(__file__) or os.curdir,
"nokia.pem")
if (not os.path.exists(CERTFILE) or
not os.path.exists(SVN_PYTHON_ORG_ROOT_CERT) or
not os.path.exists(NOKIACERT)):
raise test_support.TestFailed("Can't read certificate files!")
tests = [BasicTests, BasicSocketTests]
if test_support.is_resource_enabled('network'):
tests.append(NetworkedTests)
if _have_threads:
thread_info = test_support.threading_setup()
if thread_info and test_support.is_resource_enabled('network'):
tests.append(ThreadedTests)
try:
test_support.run_unittest(*tests)
finally:
if _have_threads:
test_support.threading_cleanup(*thread_info)
if __name__ == "__main__":
test_main()
import unittest
from test import test_support
import subprocess
import sys
import signal
import os
import errno
import tempfile
import time
import re
import sysconfig
try:
import resource
except ImportError:
resource = None
mswindows = (sys.platform == "win32")
#
# Depends on the following external programs: Python
#
if mswindows:
SETBINARY = ('import msvcrt; msvcrt.setmode(sys.stdout.fileno(), '
'os.O_BINARY);')
else:
SETBINARY = ''
try:
mkstemp = tempfile.mkstemp
except AttributeError:
# tempfile.mkstemp is not available
def mkstemp():
"""Replacement for mkstemp, calling mktemp."""
fname = tempfile.mktemp()
return os.open(fname, os.O_RDWR|os.O_CREAT), fname
class BaseTestCase(unittest.TestCase):
def setUp(self):
# Try to minimize the number of children we have so this test
# doesn't crash on some buildbots (Alphas in particular).
test_support.reap_children()
def tearDown(self):
for inst in subprocess._active:
inst.wait()
subprocess._cleanup()
self.assertFalse(subprocess._active, "subprocess._active not empty")
def assertStderrEqual(self, stderr, expected, msg=None):
# In a debug build, stuff like "[6580 refs]" is printed to stderr at
# shutdown time. That frustrates tests trying to check stderr produced
# from a spawned Python process.
actual = re.sub(r"\[\d+ refs\]\r?\n?$", "", stderr)
self.assertEqual(actual, expected, msg)
class ProcessTestCase(BaseTestCase):
def test_call_seq(self):
# call() function with sequence argument
rc = subprocess.call([sys.executable, "-c",
"import sys; sys.exit(47)"])
self.assertEqual(rc, 47)
def test_check_call_zero(self):
# check_call() function with zero return code
rc = subprocess.check_call([sys.executable, "-c",
"import sys; sys.exit(0)"])
self.assertEqual(rc, 0)
def test_check_call_nonzero(self):
# check_call() function with non-zero return code
with self.assertRaises(subprocess.CalledProcessError) as c:
subprocess.check_call([sys.executable, "-c",
"import sys; sys.exit(47)"])
self.assertEqual(c.exception.returncode, 47)
def test_check_output(self):
# check_output() function with zero return code
output = subprocess.check_output(
[sys.executable, "-c", "print 'BDFL'"])
self.assertIn('BDFL', output)
def test_check_output_nonzero(self):
# check_call() function with non-zero return code
with self.assertRaises(subprocess.CalledProcessError) as c:
subprocess.check_output(
[sys.executable, "-c", "import sys; sys.exit(5)"])
self.assertEqual(c.exception.returncode, 5)
def test_check_output_stderr(self):
# check_output() function stderr redirected to stdout
output = subprocess.check_output(
[sys.executable, "-c", "import sys; sys.stderr.write('BDFL')"],
stderr=subprocess.STDOUT)
self.assertIn('BDFL', output)
def test_check_output_stdout_arg(self):
# check_output() function stderr redirected to stdout
with self.assertRaises(ValueError) as c:
output = subprocess.check_output(
[sys.executable, "-c", "print 'will not be run'"],
stdout=sys.stdout)
self.fail("Expected ValueError when stdout arg supplied.")
self.assertIn('stdout', c.exception.args[0])
def test_call_kwargs(self):
# call() function with keyword args
newenv = os.environ.copy()
newenv["FRUIT"] = "banana"
rc = subprocess.call([sys.executable, "-c",
'import sys, os;'
'sys.exit(os.getenv("FRUIT")=="banana")'],
env=newenv)
self.assertEqual(rc, 1)
def test_invalid_args(self):
# Popen() called with invalid arguments should raise TypeError
# but Popen.__del__ should not complain (issue #12085)
with test_support.captured_stderr() as s:
self.assertRaises(TypeError, subprocess.Popen, invalid_arg_name=1)
argcount = subprocess.Popen.__init__.__code__.co_argcount
too_many_args = [0] * (argcount + 1)
self.assertRaises(TypeError, subprocess.Popen, *too_many_args)
self.assertEqual(s.getvalue(), '')
def test_stdin_none(self):
# .stdin is None when not redirected
p = subprocess.Popen([sys.executable, "-c", 'print "banana"'],
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
self.addCleanup(p.stdout.close)
self.addCleanup(p.stderr.close)
p.wait()
self.assertEqual(p.stdin, None)
def test_stdout_none(self):
# .stdout is None when not redirected
p = subprocess.Popen([sys.executable, "-c",
'print " this bit of output is from a '
'test of stdout in a different '
'process ..."'],
stdin=subprocess.PIPE, stderr=subprocess.PIPE)
self.addCleanup(p.stdin.close)
self.addCleanup(p.stderr.close)
p.wait()
self.assertEqual(p.stdout, None)
def test_stderr_none(self):
# .stderr is None when not redirected
p = subprocess.Popen([sys.executable, "-c", 'print "banana"'],
stdin=subprocess.PIPE, stdout=subprocess.PIPE)
self.addCleanup(p.stdout.close)
self.addCleanup(p.stdin.close)
p.wait()
self.assertEqual(p.stderr, None)
def test_executable_with_cwd(self):
python_dir = os.path.dirname(os.path.realpath(sys.executable))
p = subprocess.Popen(["somethingyoudonthave", "-c",
"import sys; sys.exit(47)"],
executable=sys.executable, cwd=python_dir)
p.wait()
self.assertEqual(p.returncode, 47)
@unittest.skipIf(sysconfig.is_python_build(),
"need an installed Python. See #7774")
def test_executable_without_cwd(self):
# For a normal installation, it should work without 'cwd'
# argument. For test runs in the build directory, see #7774.
p = subprocess.Popen(["somethingyoudonthave", "-c",
"import sys; sys.exit(47)"],
executable=sys.executable)
p.wait()
self.assertEqual(p.returncode, 47)
def test_stdin_pipe(self):
# stdin redirection
p = subprocess.Popen([sys.executable, "-c",
'import sys; sys.exit(sys.stdin.read() == "pear")'],
stdin=subprocess.PIPE)
p.stdin.write("pear")
p.stdin.close()
p.wait()
self.assertEqual(p.returncode, 1)
def test_stdin_filedes(self):
# stdin is set to open file descriptor
tf = tempfile.TemporaryFile()
d = tf.fileno()
os.write(d, "pear")
os.lseek(d, 0, 0)
p = subprocess.Popen([sys.executable, "-c",
'import sys; sys.exit(sys.stdin.read() == "pear")'],
stdin=d)
p.wait()
self.assertEqual(p.returncode, 1)
def test_stdin_fileobj(self):
# stdin is set to open file object
tf = tempfile.TemporaryFile()
tf.write("pear")
tf.seek(0)
p = subprocess.Popen([sys.executable, "-c",
'import sys; sys.exit(sys.stdin.read() == "pear")'],
stdin=tf)
p.wait()
self.assertEqual(p.returncode, 1)
def test_stdout_pipe(self):
# stdout redirection
p = subprocess.Popen([sys.executable, "-c",
'import sys; sys.stdout.write("orange")'],
stdout=subprocess.PIPE)
self.addCleanup(p.stdout.close)
self.assertEqual(p.stdout.read(), "orange")
def test_stdout_filedes(self):
# stdout is set to open file descriptor
tf = tempfile.TemporaryFile()
d = tf.fileno()
p = subprocess.Popen([sys.executable, "-c",
'import sys; sys.stdout.write("orange")'],
stdout=d)
p.wait()
os.lseek(d, 0, 0)
self.assertEqual(os.read(d, 1024), "orange")
def test_stdout_fileobj(self):
# stdout is set to open file object
tf = tempfile.TemporaryFile()
p = subprocess.Popen([sys.executable, "-c",
'import sys; sys.stdout.write("orange")'],
stdout=tf)
p.wait()
tf.seek(0)
self.assertEqual(tf.read(), "orange")
def test_stderr_pipe(self):
# stderr redirection
p = subprocess.Popen([sys.executable, "-c",
'import sys; sys.stderr.write("strawberry")'],
stderr=subprocess.PIPE)
self.addCleanup(p.stderr.close)
self.assertStderrEqual(p.stderr.read(), "strawberry")
def test_stderr_filedes(self):
# stderr is set to open file descriptor
tf = tempfile.TemporaryFile()
d = tf.fileno()
p = subprocess.Popen([sys.executable, "-c",
'import sys; sys.stderr.write("strawberry")'],
stderr=d)
p.wait()
os.lseek(d, 0, 0)
self.assertStderrEqual(os.read(d, 1024), "strawberry")
def test_stderr_fileobj(self):
# stderr is set to open file object
tf = tempfile.TemporaryFile()
p = subprocess.Popen([sys.executable, "-c",
'import sys; sys.stderr.write("strawberry")'],
stderr=tf)
p.wait()
tf.seek(0)
self.assertStderrEqual(tf.read(), "strawberry")
def test_stdout_stderr_pipe(self):
# capture stdout and stderr to the same pipe
p = subprocess.Popen([sys.executable, "-c",
'import sys;'
'sys.stdout.write("apple");'
'sys.stdout.flush();'
'sys.stderr.write("orange")'],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
self.addCleanup(p.stdout.close)
self.assertStderrEqual(p.stdout.read(), "appleorange")
def test_stdout_stderr_file(self):
# capture stdout and stderr to the same open file
tf = tempfile.TemporaryFile()
p = subprocess.Popen([sys.executable, "-c",
'import sys;'
'sys.stdout.write("apple");'
'sys.stdout.flush();'
'sys.stderr.write("orange")'],
stdout=tf,
stderr=tf)
p.wait()
tf.seek(0)
self.assertStderrEqual(tf.read(), "appleorange")
def test_stdout_filedes_of_stdout(self):
# stdout is set to 1 (#1531862).
cmd = r"import sys, os; sys.exit(os.write(sys.stdout.fileno(), '.\n'))"
rc = subprocess.call([sys.executable, "-c", cmd], stdout=1)
self.assertEqual(rc, 2)
def test_cwd(self):
tmpdir = tempfile.gettempdir()
# We cannot use os.path.realpath to canonicalize the path,
# since it doesn't expand Tru64 {memb} strings. See bug 1063571.
cwd = os.getcwd()
os.chdir(tmpdir)
tmpdir = os.getcwd()
os.chdir(cwd)
p = subprocess.Popen([sys.executable, "-c",
'import sys,os;'
'sys.stdout.write(os.getcwd())'],
stdout=subprocess.PIPE,
cwd=tmpdir)
self.addCleanup(p.stdout.close)
normcase = os.path.normcase
self.assertEqual(normcase(p.stdout.read()), normcase(tmpdir))
def test_env(self):
newenv = os.environ.copy()
newenv["FRUIT"] = "orange"
p = subprocess.Popen([sys.executable, "-c",
'import sys,os;'
'sys.stdout.write(os.getenv("FRUIT"))'],
stdout=subprocess.PIPE,
env=newenv)
self.addCleanup(p.stdout.close)
self.assertEqual(p.stdout.read(), "orange")
def test_communicate_stdin(self):
p = subprocess.Popen([sys.executable, "-c",
'import sys;'
'sys.exit(sys.stdin.read() == "pear")'],
stdin=subprocess.PIPE)
p.communicate("pear")
self.assertEqual(p.returncode, 1)
def test_communicate_stdout(self):
p = subprocess.Popen([sys.executable, "-c",
'import sys; sys.stdout.write("pineapple")'],
stdout=subprocess.PIPE)
(stdout, stderr) = p.communicate()
self.assertEqual(stdout, "pineapple")
self.assertEqual(stderr, None)
def test_communicate_stderr(self):
p = subprocess.Popen([sys.executable, "-c",
'import sys; sys.stderr.write("pineapple")'],
stderr=subprocess.PIPE)
(stdout, stderr) = p.communicate()
self.assertEqual(stdout, None)
self.assertStderrEqual(stderr, "pineapple")
def test_communicate(self):
p = subprocess.Popen([sys.executable, "-c",
'import sys,os;'
'sys.stderr.write("pineapple");'
'sys.stdout.write(sys.stdin.read())'],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
self.addCleanup(p.stdout.close)
self.addCleanup(p.stderr.close)
self.addCleanup(p.stdin.close)
(stdout, stderr) = p.communicate("banana")
self.assertEqual(stdout, "banana")
self.assertStderrEqual(stderr, "pineapple")
# This test is Linux specific for simplicity to at least have
# some coverage. It is not a platform specific bug.
@unittest.skipUnless(os.path.isdir('/proc/%d/fd' % os.getpid()),
"Linux specific")
# Test for the fd leak reported in http://bugs.python.org/issue2791.
def test_communicate_pipe_fd_leak(self):
fd_directory = '/proc/%d/fd' % os.getpid()
num_fds_before_popen = len(os.listdir(fd_directory))
p = subprocess.Popen([sys.executable, "-c", "print()"],
stdout=subprocess.PIPE)
p.communicate()
num_fds_after_communicate = len(os.listdir(fd_directory))
del p
num_fds_after_destruction = len(os.listdir(fd_directory))
self.assertEqual(num_fds_before_popen, num_fds_after_destruction)
self.assertEqual(num_fds_before_popen, num_fds_after_communicate)
def test_communicate_returns(self):
# communicate() should return None if no redirection is active
p = subprocess.Popen([sys.executable, "-c",
"import sys; sys.exit(47)"])
(stdout, stderr) = p.communicate()
self.assertEqual(stdout, None)
self.assertEqual(stderr, None)
def test_communicate_pipe_buf(self):
# communicate() with writes larger than pipe_buf
# This test will probably deadlock rather than fail, if
# communicate() does not work properly.
x, y = os.pipe()
if mswindows:
pipe_buf = 512
else:
pipe_buf = os.fpathconf(x, "PC_PIPE_BUF")
os.close(x)
os.close(y)
p = subprocess.Popen([sys.executable, "-c",
'import sys,os;'
'sys.stdout.write(sys.stdin.read(47));'
'sys.stderr.write("xyz"*%d);'
'sys.stdout.write(sys.stdin.read())' % pipe_buf],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
self.addCleanup(p.stdout.close)
self.addCleanup(p.stderr.close)
self.addCleanup(p.stdin.close)
string_to_write = "abc"*pipe_buf
(stdout, stderr) = p.communicate(string_to_write)
self.assertEqual(stdout, string_to_write)
def test_writes_before_communicate(self):
# stdin.write before communicate()
p = subprocess.Popen([sys.executable, "-c",
'import sys,os;'
'sys.stdout.write(sys.stdin.read())'],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
self.addCleanup(p.stdout.close)
self.addCleanup(p.stderr.close)
self.addCleanup(p.stdin.close)
p.stdin.write("banana")
(stdout, stderr) = p.communicate("split")
self.assertEqual(stdout, "bananasplit")
self.assertStderrEqual(stderr, "")
def test_universal_newlines(self):
p = subprocess.Popen([sys.executable, "-c",
'import sys,os;' + SETBINARY +
'sys.stdout.write("line1\\n");'
'sys.stdout.flush();'
'sys.stdout.write("line2\\r");'
'sys.stdout.flush();'
'sys.stdout.write("line3\\r\\n");'
'sys.stdout.flush();'
'sys.stdout.write("line4\\r");'
'sys.stdout.flush();'
'sys.stdout.write("\\nline5");'
'sys.stdout.flush();'
'sys.stdout.write("\\nline6");'],
stdout=subprocess.PIPE,
universal_newlines=1)
self.addCleanup(p.stdout.close)
stdout = p.stdout.read()
if hasattr(file, 'newlines'):
# Interpreter with universal newline support
self.assertEqual(stdout,
"line1\nline2\nline3\nline4\nline5\nline6")
else:
# Interpreter without universal newline support
self.assertEqual(stdout,
"line1\nline2\rline3\r\nline4\r\nline5\nline6")
def test_universal_newlines_communicate(self):
# universal newlines through communicate()
p = subprocess.Popen([sys.executable, "-c",
'import sys,os;' + SETBINARY +
'sys.stdout.write("line1\\n");'
'sys.stdout.flush();'
'sys.stdout.write("line2\\r");'
'sys.stdout.flush();'
'sys.stdout.write("line3\\r\\n");'
'sys.stdout.flush();'
'sys.stdout.write("line4\\r");'
'sys.stdout.flush();'
'sys.stdout.write("\\nline5");'
'sys.stdout.flush();'
'sys.stdout.write("\\nline6");'],
stdout=subprocess.PIPE, stderr=subprocess.PIPE,
universal_newlines=1)
self.addCleanup(p.stdout.close)
self.addCleanup(p.stderr.close)
(stdout, stderr) = p.communicate()
if hasattr(file, 'newlines'):
# Interpreter with universal newline support
self.assertEqual(stdout,
"line1\nline2\nline3\nline4\nline5\nline6")
else:
# Interpreter without universal newline support
self.assertEqual(stdout,
"line1\nline2\rline3\r\nline4\r\nline5\nline6")
def test_no_leaking(self):
# Make sure we leak no resources
if not mswindows:
max_handles = 1026 # too much for most UNIX systems
else:
max_handles = 2050 # too much for (at least some) Windows setups
handles = []
try:
for i in range(max_handles):
try:
handles.append(os.open(test_support.TESTFN,
os.O_WRONLY | os.O_CREAT))
except OSError as e:
if e.errno != errno.EMFILE:
raise
break
else:
self.skipTest("failed to reach the file descriptor limit "
"(tried %d)" % max_handles)
# Close a couple of them (should be enough for a subprocess)
for i in range(10):
os.close(handles.pop())
# Loop creating some subprocesses. If one of them leaks some fds,
# the next loop iteration will fail by reaching the max fd limit.
for i in range(15):
p = subprocess.Popen([sys.executable, "-c",
"import sys;"
"sys.stdout.write(sys.stdin.read())"],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
data = p.communicate(b"lime")[0]
self.assertEqual(data, b"lime")
finally:
for h in handles:
os.close(h)
def test_list2cmdline(self):
self.assertEqual(subprocess.list2cmdline(['a b c', 'd', 'e']),
'"a b c" d e')
self.assertEqual(subprocess.list2cmdline(['ab"c', '\\', 'd']),
'ab\\"c \\ d')
self.assertEqual(subprocess.list2cmdline(['ab"c', ' \\', 'd']),
'ab\\"c " \\\\" d')
self.assertEqual(subprocess.list2cmdline(['a\\\\\\b', 'de fg', 'h']),
'a\\\\\\b "de fg" h')
self.assertEqual(subprocess.list2cmdline(['a\\"b', 'c', 'd']),
'a\\\\\\"b c d')
self.assertEqual(subprocess.list2cmdline(['a\\\\b c', 'd', 'e']),
'"a\\\\b c" d e')
self.assertEqual(subprocess.list2cmdline(['a\\\\b\\ c', 'd', 'e']),
'"a\\\\b\\ c" d e')
self.assertEqual(subprocess.list2cmdline(['ab', '']),
'ab ""')
def test_poll(self):
p = subprocess.Popen([sys.executable,
"-c", "import time; time.sleep(1)"])
count = 0
while p.poll() is None:
time.sleep(0.1)
count += 1
# We expect that the poll loop probably went around about 10 times,
# but, based on system scheduling we can't control, it's possible
# poll() never returned None. It "should be" very rare that it
# didn't go around at least twice.
self.assertGreaterEqual(count, 2)
# Subsequent invocations should just return the returncode
self.assertEqual(p.poll(), 0)
def test_wait(self):
p = subprocess.Popen([sys.executable,
"-c", "import time; time.sleep(2)"])
self.assertEqual(p.wait(), 0)
# Subsequent invocations should just return the returncode
self.assertEqual(p.wait(), 0)
def test_invalid_bufsize(self):
# an invalid type of the bufsize argument should raise
# TypeError.
with self.assertRaises(TypeError):
subprocess.Popen([sys.executable, "-c", "pass"], "orange")
def test_leaking_fds_on_error(self):
# see bug #5179: Popen leaks file descriptors to PIPEs if
# the child fails to execute; this will eventually exhaust
# the maximum number of open fds. 1024 seems a very common
# value for that limit, but Windows has 2048, so we loop
# 1024 times (each call leaked two fds).
for i in range(1024):
# Windows raises IOError. Others raise OSError.
with self.assertRaises(EnvironmentError) as c:
subprocess.Popen(['nonexisting_i_hope'],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
# ignore errors that indicate the command was not found
if c.exception.errno not in (errno.ENOENT, errno.EACCES):
raise c.exception
def test_handles_closed_on_exception(self):
# If CreateProcess exits with an error, ensure the
# duplicate output handles are released
ifhandle, ifname = mkstemp()
ofhandle, ofname = mkstemp()
efhandle, efname = mkstemp()
try:
subprocess.Popen (["*"], stdin=ifhandle, stdout=ofhandle,
stderr=efhandle)
except OSError:
os.close(ifhandle)
os.remove(ifname)
os.close(ofhandle)
os.remove(ofname)
os.close(efhandle)
os.remove(efname)
self.assertFalse(os.path.exists(ifname))
self.assertFalse(os.path.exists(ofname))
self.assertFalse(os.path.exists(efname))
def test_communicate_epipe(self):
# Issue 10963: communicate() should hide EPIPE
p = subprocess.Popen([sys.executable, "-c", 'pass'],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
self.addCleanup(p.stdout.close)
self.addCleanup(p.stderr.close)
self.addCleanup(p.stdin.close)
p.communicate("x" * 2**20)
def test_communicate_epipe_only_stdin(self):
# Issue 10963: communicate() should hide EPIPE
p = subprocess.Popen([sys.executable, "-c", 'pass'],
stdin=subprocess.PIPE)
self.addCleanup(p.stdin.close)
time.sleep(2)
p.communicate("x" * 2**20)
# context manager
class _SuppressCoreFiles(object):
"""Try to prevent core files from being created."""
old_limit = None
def __enter__(self):
"""Try to save previous ulimit, then set it to (0, 0)."""
if resource is not None:
try:
self.old_limit = resource.getrlimit(resource.RLIMIT_CORE)
resource.setrlimit(resource.RLIMIT_CORE, (0, 0))
except (ValueError, resource.error):
pass
if sys.platform == 'darwin':
# Check if the 'Crash Reporter' on OSX was configured
# in 'Developer' mode and warn that it will get triggered
# when it is.
#
# This assumes that this context manager is used in tests
# that might trigger the next manager.
value = subprocess.Popen(['/usr/bin/defaults', 'read',
'com.apple.CrashReporter', 'DialogType'],
stdout=subprocess.PIPE).communicate()[0]
if value.strip() == b'developer':
print "this tests triggers the Crash Reporter, that is intentional"
sys.stdout.flush()
def __exit__(self, *args):
"""Return core file behavior to default."""
if self.old_limit is None:
return
if resource is not None:
try:
resource.setrlimit(resource.RLIMIT_CORE, self.old_limit)
except (ValueError, resource.error):
pass
@unittest.skipUnless(hasattr(signal, 'SIGALRM'),
"Requires signal.SIGALRM")
def test_communicate_eintr(self):
# Issue #12493: communicate() should handle EINTR
def handler(signum, frame):
pass
old_handler = signal.signal(signal.SIGALRM, handler)
self.addCleanup(signal.signal, signal.SIGALRM, old_handler)
# the process is running for 2 seconds
args = [sys.executable, "-c", 'import time; time.sleep(2)']
for stream in ('stdout', 'stderr'):
kw = {stream: subprocess.PIPE}
with subprocess.Popen(args, **kw) as process:
signal.alarm(1)
# communicate() will be interrupted by SIGALRM
process.communicate()
@unittest.skipIf(mswindows, "POSIX specific tests")
class POSIXProcessTestCase(BaseTestCase):
def test_exceptions(self):
# caught & re-raised exceptions
with self.assertRaises(OSError) as c:
p = subprocess.Popen([sys.executable, "-c", ""],
cwd="/this/path/does/not/exist")
# The attribute child_traceback should contain "os.chdir" somewhere.
self.assertIn("os.chdir", c.exception.child_traceback)
def test_run_abort(self):
# returncode handles signal termination
with _SuppressCoreFiles():
p = subprocess.Popen([sys.executable, "-c",
"import os; os.abort()"])
p.wait()
self.assertEqual(-p.returncode, signal.SIGABRT)
def test_preexec(self):
# preexec function
p = subprocess.Popen([sys.executable, "-c",
"import sys, os;"
"sys.stdout.write(os.getenv('FRUIT'))"],
stdout=subprocess.PIPE,
preexec_fn=lambda: os.putenv("FRUIT", "apple"))
self.addCleanup(p.stdout.close)
self.assertEqual(p.stdout.read(), "apple")
def test_args_string(self):
# args is a string
f, fname = mkstemp()
os.write(f, "#!/bin/sh\n")
os.write(f, "exec '%s' -c 'import sys; sys.exit(47)'\n" %
sys.executable)
os.close(f)
os.chmod(fname, 0o700)
p = subprocess.Popen(fname)
p.wait()
os.remove(fname)
self.assertEqual(p.returncode, 47)
def test_invalid_args(self):
# invalid arguments should raise ValueError
self.assertRaises(ValueError, subprocess.call,
[sys.executable, "-c",
"import sys; sys.exit(47)"],
startupinfo=47)
self.assertRaises(ValueError, subprocess.call,
[sys.executable, "-c",
"import sys; sys.exit(47)"],
creationflags=47)
def test_shell_sequence(self):
# Run command through the shell (sequence)
newenv = os.environ.copy()
newenv["FRUIT"] = "apple"
p = subprocess.Popen(["echo $FRUIT"], shell=1,
stdout=subprocess.PIPE,
env=newenv)
self.addCleanup(p.stdout.close)
self.assertEqual(p.stdout.read().strip(), "apple")
def test_shell_string(self):
# Run command through the shell (string)
newenv = os.environ.copy()
newenv["FRUIT"] = "apple"
p = subprocess.Popen("echo $FRUIT", shell=1,
stdout=subprocess.PIPE,
env=newenv)
self.addCleanup(p.stdout.close)
self.assertEqual(p.stdout.read().strip(), "apple")
def test_call_string(self):
# call() function with string argument on UNIX
f, fname = mkstemp()
os.write(f, "#!/bin/sh\n")
os.write(f, "exec '%s' -c 'import sys; sys.exit(47)'\n" %
sys.executable)
os.close(f)
os.chmod(fname, 0700)
rc = subprocess.call(fname)
os.remove(fname)
self.assertEqual(rc, 47)
def test_specific_shell(self):
# Issue #9265: Incorrect name passed as arg[0].
shells = []
for prefix in ['/bin', '/usr/bin/', '/usr/local/bin']:
for name in ['bash', 'ksh']:
sh = os.path.join(prefix, name)
if os.path.isfile(sh):
shells.append(sh)
if not shells: # Will probably work for any shell but csh.
self.skipTest("bash or ksh required for this test")
sh = '/bin/sh'
if os.path.isfile(sh) and not os.path.islink(sh):
# Test will fail if /bin/sh is a symlink to csh.
shells.append(sh)
for sh in shells:
p = subprocess.Popen("echo $0", executable=sh, shell=True,
stdout=subprocess.PIPE)
self.addCleanup(p.stdout.close)
self.assertEqual(p.stdout.read().strip(), sh)
def _kill_process(self, method, *args):
# Do not inherit file handles from the parent.
# It should fix failures on some platforms.
p = subprocess.Popen([sys.executable, "-c", """if 1:
import sys, time
sys.stdout.write('x\\n')
sys.stdout.flush()
time.sleep(30)
"""],
close_fds=True,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
# Wait for the interpreter to be completely initialized before
# sending any signal.
p.stdout.read(1)
getattr(p, method)(*args)
return p
def test_send_signal(self):
p = self._kill_process('send_signal', signal.SIGINT)
_, stderr = p.communicate()
self.assertIn('KeyboardInterrupt', stderr)
self.assertNotEqual(p.wait(), 0)
def test_kill(self):
p = self._kill_process('kill')
_, stderr = p.communicate()
self.assertStderrEqual(stderr, '')
self.assertEqual(p.wait(), -signal.SIGKILL)
def test_terminate(self):
p = self._kill_process('terminate')
_, stderr = p.communicate()
self.assertStderrEqual(stderr, '')
self.assertEqual(p.wait(), -signal.SIGTERM)
def check_close_std_fds(self, fds):
# Issue #9905: test that subprocess pipes still work properly with
# some standard fds closed
stdin = 0
newfds = []
for a in fds:
b = os.dup(a)
newfds.append(b)
if a == 0:
stdin = b
try:
for fd in fds:
os.close(fd)
out, err = subprocess.Popen([sys.executable, "-c",
'import sys;'
'sys.stdout.write("apple");'
'sys.stdout.flush();'
'sys.stderr.write("orange")'],
stdin=stdin,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE).communicate()
err = test_support.strip_python_stderr(err)
self.assertEqual((out, err), (b'apple', b'orange'))
finally:
for b, a in zip(newfds, fds):
os.dup2(b, a)
for b in newfds:
os.close(b)
def test_close_fd_0(self):
self.check_close_std_fds([0])
def test_close_fd_1(self):
self.check_close_std_fds([1])
def test_close_fd_2(self):
self.check_close_std_fds([2])
def test_close_fds_0_1(self):
self.check_close_std_fds([0, 1])
def test_close_fds_0_2(self):
self.check_close_std_fds([0, 2])
def test_close_fds_1_2(self):
self.check_close_std_fds([1, 2])
def test_close_fds_0_1_2(self):
# Issue #10806: test that subprocess pipes still work properly with
# all standard fds closed.
self.check_close_std_fds([0, 1, 2])
def check_swap_fds(self, stdin_no, stdout_no, stderr_no):
# open up some temporary files
temps = [mkstemp() for i in range(3)]
temp_fds = [fd for fd, fname in temps]
try:
# unlink the files -- we won't need to reopen them
for fd, fname in temps:
os.unlink(fname)
# save a copy of the standard file descriptors
saved_fds = [os.dup(fd) for fd in range(3)]
try:
# duplicate the temp files over the standard fd's 0, 1, 2
for fd, temp_fd in enumerate(temp_fds):
os.dup2(temp_fd, fd)
# write some data to what will become stdin, and rewind
os.write(stdin_no, b"STDIN")
os.lseek(stdin_no, 0, 0)
# now use those files in the given order, so that subprocess
# has to rearrange them in the child
p = subprocess.Popen([sys.executable, "-c",
'import sys; got = sys.stdin.read();'
'sys.stdout.write("got %s"%got); sys.stderr.write("err")'],
stdin=stdin_no,
stdout=stdout_no,
stderr=stderr_no)
p.wait()
for fd in temp_fds:
os.lseek(fd, 0, 0)
out = os.read(stdout_no, 1024)
err = test_support.strip_python_stderr(os.read(stderr_no, 1024))
finally:
for std, saved in enumerate(saved_fds):
os.dup2(saved, std)
os.close(saved)
self.assertEqual(out, b"got STDIN")
self.assertEqual(err, b"err")
finally:
for fd in temp_fds:
os.close(fd)
# When duping fds, if there arises a situation where one of the fds is
# either 0, 1 or 2, it is possible that it is overwritten (#12607).
# This tests all combinations of this.
def test_swap_fds(self):
self.check_swap_fds(0, 1, 2)
self.check_swap_fds(0, 2, 1)
self.check_swap_fds(1, 0, 2)
self.check_swap_fds(1, 2, 0)
self.check_swap_fds(2, 0, 1)
self.check_swap_fds(2, 1, 0)
def test_wait_when_sigchild_ignored(self):
# NOTE: sigchild_ignore.py may not be an effective test on all OSes.
sigchild_ignore = test_support.findfile("sigchild_ignore.py",
subdir="subprocessdata")
p = subprocess.Popen([sys.executable, sigchild_ignore],
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = p.communicate()
self.assertEqual(0, p.returncode, "sigchild_ignore.py exited"
" non-zero with this error:\n%s" % stderr)
def test_zombie_fast_process_del(self):
# Issue #12650: on Unix, if Popen.__del__() was called before the
# process exited, it wouldn't be added to subprocess._active, and would
# remain a zombie.
# spawn a Popen, and delete its reference before it exits
p = subprocess.Popen([sys.executable, "-c",
'import sys, time;'
'time.sleep(0.2)'],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
self.addCleanup(p.stdout.close)
self.addCleanup(p.stderr.close)
ident = id(p)
pid = p.pid
del p
# check that p is in the active processes list
self.assertIn(ident, [id(o) for o in subprocess._active])
def test_leak_fast_process_del_killed(self):
# Issue #12650: on Unix, if Popen.__del__() was called before the
# process exited, and the process got killed by a signal, it would never
# be removed from subprocess._active, which triggered a FD and memory
# leak.
# spawn a Popen, delete its reference and kill it
p = subprocess.Popen([sys.executable, "-c",
'import time;'
'time.sleep(3)'],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
self.addCleanup(p.stdout.close)
self.addCleanup(p.stderr.close)
ident = id(p)
pid = p.pid
del p
os.kill(pid, signal.SIGKILL)
# check that p is in the active processes list
self.assertIn(ident, [id(o) for o in subprocess._active])
# let some time for the process to exit, and create a new Popen: this
# should trigger the wait() of p
time.sleep(0.2)
with self.assertRaises(EnvironmentError) as c:
with subprocess.Popen(['nonexisting_i_hope'],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE) as proc:
pass
# p should have been wait()ed on, and removed from the _active list
self.assertRaises(OSError, os.waitpid, pid, 0)
self.assertNotIn(ident, [id(o) for o in subprocess._active])
def test_pipe_cloexec(self):
# Issue 12786: check that the communication pipes' FDs are set CLOEXEC,
# and are not inherited by another child process.
p1 = subprocess.Popen([sys.executable, "-c",
'import os;'
'os.read(0, 1)'
],
stdin=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
p2 = subprocess.Popen([sys.executable, "-c", """if True:
import os, errno, sys
for fd in %r:
try:
os.close(fd)
except OSError as e:
if e.errno != errno.EBADF:
raise
else:
sys.exit(1)
sys.exit(0)
""" % [f.fileno() for f in (p1.stdin, p1.stdout,
p1.stderr)]
],
stdin=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, close_fds=False)
p1.communicate('foo')
_, stderr = p2.communicate()
self.assertEqual(p2.returncode, 0, "Unexpected error: " + repr(stderr))
@unittest.skipUnless(mswindows, "Windows specific tests")
class Win32ProcessTestCase(BaseTestCase):
def test_startupinfo(self):
# startupinfo argument
# We uses hardcoded constants, because we do not want to
# depend on win32all.
STARTF_USESHOWWINDOW = 1
SW_MAXIMIZE = 3
startupinfo = subprocess.STARTUPINFO()
startupinfo.dwFlags = STARTF_USESHOWWINDOW
startupinfo.wShowWindow = SW_MAXIMIZE
# Since Python is a console process, it won't be affected
# by wShowWindow, but the argument should be silently
# ignored
subprocess.call([sys.executable, "-c", "import sys; sys.exit(0)"],
startupinfo=startupinfo)
def test_creationflags(self):
# creationflags argument
CREATE_NEW_CONSOLE = 16
sys.stderr.write(" a DOS box should flash briefly ...\n")
subprocess.call(sys.executable +
' -c "import time; time.sleep(0.25)"',
creationflags=CREATE_NEW_CONSOLE)
def test_invalid_args(self):
# invalid arguments should raise ValueError
self.assertRaises(ValueError, subprocess.call,
[sys.executable, "-c",
"import sys; sys.exit(47)"],
preexec_fn=lambda: 1)
self.assertRaises(ValueError, subprocess.call,
[sys.executable, "-c",
"import sys; sys.exit(47)"],
stdout=subprocess.PIPE,
close_fds=True)
def test_close_fds(self):
# close file descriptors
rc = subprocess.call([sys.executable, "-c",
"import sys; sys.exit(47)"],
close_fds=True)
self.assertEqual(rc, 47)
def test_shell_sequence(self):
# Run command through the shell (sequence)
newenv = os.environ.copy()
newenv["FRUIT"] = "physalis"
p = subprocess.Popen(["set"], shell=1,
stdout=subprocess.PIPE,
env=newenv)
self.addCleanup(p.stdout.close)
self.assertIn("physalis", p.stdout.read())
def test_shell_string(self):
# Run command through the shell (string)
newenv = os.environ.copy()
newenv["FRUIT"] = "physalis"
p = subprocess.Popen("set", shell=1,
stdout=subprocess.PIPE,
env=newenv)
self.addCleanup(p.stdout.close)
self.assertIn("physalis", p.stdout.read())
def test_call_string(self):
# call() function with string argument on Windows
rc = subprocess.call(sys.executable +
' -c "import sys; sys.exit(47)"')
self.assertEqual(rc, 47)
def _kill_process(self, method, *args):
# Some win32 buildbot raises EOFError if stdin is inherited
p = subprocess.Popen([sys.executable, "-c", """if 1:
import sys, time
sys.stdout.write('x\\n')
sys.stdout.flush()
time.sleep(30)
"""],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
self.addCleanup(p.stdout.close)
self.addCleanup(p.stderr.close)
self.addCleanup(p.stdin.close)
# Wait for the interpreter to be completely initialized before
# sending any signal.
p.stdout.read(1)
getattr(p, method)(*args)
_, stderr = p.communicate()
self.assertStderrEqual(stderr, '')
returncode = p.wait()
self.assertNotEqual(returncode, 0)
def test_send_signal(self):
self._kill_process('send_signal', signal.SIGTERM)
def test_kill(self):
self._kill_process('kill')
def test_terminate(self):
self._kill_process('terminate')
@unittest.skipUnless(getattr(subprocess, '_has_poll', False),
"poll system call not supported")
class ProcessTestCaseNoPoll(ProcessTestCase):
def setUp(self):
subprocess._has_poll = False
ProcessTestCase.setUp(self)
def tearDown(self):
subprocess._has_poll = True
ProcessTestCase.tearDown(self)
class HelperFunctionTests(unittest.TestCase):
@unittest.skipIf(mswindows, "errno and EINTR make no sense on windows")
def test_eintr_retry_call(self):
record_calls = []
def fake_os_func(*args):
record_calls.append(args)
if len(record_calls) == 2:
raise OSError(errno.EINTR, "fake interrupted system call")
return tuple(reversed(args))
self.assertEqual((999, 256),
subprocess._eintr_retry_call(fake_os_func, 256, 999))
self.assertEqual([(256, 999)], record_calls)
# This time there will be an EINTR so it will loop once.
self.assertEqual((666,),
subprocess._eintr_retry_call(fake_os_func, 666))
self.assertEqual([(256, 999), (666,), (666,)], record_calls)
@unittest.skipUnless(mswindows, "mswindows only")
class CommandsWithSpaces (BaseTestCase):
def setUp(self):
super(CommandsWithSpaces, self).setUp()
f, fname = mkstemp(".py", "te st")
self.fname = fname.lower ()
os.write(f, b"import sys;"
b"sys.stdout.write('%d %s' % (len(sys.argv), [a.lower () for a in sys.argv]))"
)
os.close(f)
def tearDown(self):
os.remove(self.fname)
super(CommandsWithSpaces, self).tearDown()
def with_spaces(self, *args, **kwargs):
kwargs['stdout'] = subprocess.PIPE
p = subprocess.Popen(*args, **kwargs)
self.addCleanup(p.stdout.close)
self.assertEqual(
p.stdout.read ().decode("mbcs"),
"2 [%r, 'ab cd']" % self.fname
)
def test_shell_string_with_spaces(self):
# call() function with string argument with spaces on Windows
self.with_spaces('"%s" "%s" "%s"' % (sys.executable, self.fname,
"ab cd"), shell=1)
def test_shell_sequence_with_spaces(self):
# call() function with sequence argument with spaces on Windows
self.with_spaces([sys.executable, self.fname, "ab cd"], shell=1)
def test_noshell_string_with_spaces(self):
# call() function with string argument with spaces on Windows
self.with_spaces('"%s" "%s" "%s"' % (sys.executable, self.fname,
"ab cd"))
def test_noshell_sequence_with_spaces(self):
# call() function with sequence argument with spaces on Windows
self.with_spaces([sys.executable, self.fname, "ab cd"])
def test_main():
unit_tests = (ProcessTestCase,
POSIXProcessTestCase,
Win32ProcessTestCase,
ProcessTestCaseNoPoll,
HelperFunctionTests,
CommandsWithSpaces)
test_support.run_unittest(*unit_tests)
test_support.reap_children()
if __name__ == "__main__":
test_main()
import socket
import telnetlib
import time
import Queue
from unittest import TestCase
from test import test_support
threading = test_support.import_module('threading')
HOST = test_support.HOST
EOF_sigil = object()
def server(evt, serv, dataq=None):
""" Open a tcp server in three steps
1) set evt to true to let the parent know we are ready
2) [optional] if is not False, write the list of data from dataq.get()
to the socket.
"""
serv.listen(5)
evt.set()
try:
conn, addr = serv.accept()
if dataq:
data = ''
new_data = dataq.get(True, 0.5)
dataq.task_done()
for item in new_data:
if item == EOF_sigil:
break
if type(item) in [int, float]:
time.sleep(item)
else:
data += item
written = conn.send(data)
data = data[written:]
conn.close()
except socket.timeout:
pass
finally:
serv.close()
class GeneralTests(TestCase):
def setUp(self):
self.evt = threading.Event()
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.settimeout(60) # Safety net. Look issue 11812
self.port = test_support.bind_port(self.sock)
self.thread = threading.Thread(target=server, args=(self.evt,self.sock))
self.thread.setDaemon(True)
self.thread.start()
self.evt.wait()
def tearDown(self):
self.thread.join()
def testBasic(self):
# connects
telnet = telnetlib.Telnet(HOST, self.port)
telnet.sock.close()
def testTimeoutDefault(self):
self.assertTrue(socket.getdefaulttimeout() is None)
socket.setdefaulttimeout(30)
try:
telnet = telnetlib.Telnet(HOST, self.port)
finally:
socket.setdefaulttimeout(None)
self.assertEqual(telnet.sock.gettimeout(), 30)
telnet.sock.close()
def testTimeoutNone(self):
# None, having other default
self.assertTrue(socket.getdefaulttimeout() is None)
socket.setdefaulttimeout(30)
try:
telnet = telnetlib.Telnet(HOST, self.port, timeout=None)
finally:
socket.setdefaulttimeout(None)
self.assertTrue(telnet.sock.gettimeout() is None)
telnet.sock.close()
def testTimeoutValue(self):
telnet = telnetlib.Telnet(HOST, self.port, timeout=30)
self.assertEqual(telnet.sock.gettimeout(), 30)
telnet.sock.close()
def testTimeoutOpen(self):
telnet = telnetlib.Telnet()
telnet.open(HOST, self.port, timeout=30)
self.assertEqual(telnet.sock.gettimeout(), 30)
telnet.sock.close()
def _read_setUp(self):
self.evt = threading.Event()
self.dataq = Queue.Queue()
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.settimeout(10)
self.port = test_support.bind_port(self.sock)
self.thread = threading.Thread(target=server, args=(self.evt,self.sock, self.dataq))
self.thread.start()
self.evt.wait()
def _read_tearDown(self):
self.thread.join()
class ReadTests(TestCase):
setUp = _read_setUp
tearDown = _read_tearDown
# use a similar approach to testing timeouts as test_timeout.py
# these will never pass 100% but make the fuzz big enough that it is rare
block_long = 0.6
block_short = 0.3
def test_read_until_A(self):
"""
read_until(expected, [timeout])
Read until the expected string has been seen, or a timeout is
hit (default is no timeout); may block.
"""
want = ['x' * 10, 'match', 'y' * 10, EOF_sigil]
self.dataq.put(want)
telnet = telnetlib.Telnet(HOST, self.port)
self.dataq.join()
data = telnet.read_until('match')
self.assertEqual(data, ''.join(want[:-2]))
def test_read_until_B(self):
# test the timeout - it does NOT raise socket.timeout
want = ['hello', self.block_long, 'not seen', EOF_sigil]
self.dataq.put(want)
telnet = telnetlib.Telnet(HOST, self.port)
self.dataq.join()
data = telnet.read_until('not seen', self.block_short)
self.assertEqual(data, want[0])
self.assertEqual(telnet.read_all(), 'not seen')
def test_read_all_A(self):
"""
read_all()
Read all data until EOF; may block.
"""
want = ['x' * 500, 'y' * 500, 'z' * 500, EOF_sigil]
self.dataq.put(want)
telnet = telnetlib.Telnet(HOST, self.port)
self.dataq.join()
data = telnet.read_all()
self.assertEqual(data, ''.join(want[:-1]))
return
def _test_blocking(self, func):
self.dataq.put([self.block_long, EOF_sigil])
self.dataq.join()
start = time.time()
data = func()
self.assertTrue(self.block_short <= time.time() - start)
def test_read_all_B(self):
self._test_blocking(telnetlib.Telnet(HOST, self.port).read_all)
def test_read_all_C(self):
self.dataq.put([EOF_sigil])
telnet = telnetlib.Telnet(HOST, self.port)
self.dataq.join()
telnet.read_all()
telnet.read_all() # shouldn't raise
def test_read_some_A(self):
"""
read_some()
Read at least one byte or EOF; may block.
"""
# test 'at least one byte'
want = ['x' * 500, EOF_sigil]
self.dataq.put(want)
telnet = telnetlib.Telnet(HOST, self.port)
self.dataq.join()
data = telnet.read_all()
self.assertTrue(len(data) >= 1)
def test_read_some_B(self):
# test EOF
self.dataq.put([EOF_sigil])
telnet = telnetlib.Telnet(HOST, self.port)
self.dataq.join()
self.assertEqual('', telnet.read_some())
def test_read_some_C(self):
self._test_blocking(telnetlib.Telnet(HOST, self.port).read_some)
def _test_read_any_eager_A(self, func_name):
"""
read_very_eager()
Read all data available already queued or on the socket,
without blocking.
"""
want = [self.block_long, 'x' * 100, 'y' * 100, EOF_sigil]
expects = want[1] + want[2]
self.dataq.put(want)
telnet = telnetlib.Telnet(HOST, self.port)
self.dataq.join()
func = getattr(telnet, func_name)
data = ''
while True:
try:
data += func()
self.assertTrue(expects.startswith(data))
except EOFError:
break
self.assertEqual(expects, data)
def _test_read_any_eager_B(self, func_name):
# test EOF
self.dataq.put([EOF_sigil])
telnet = telnetlib.Telnet(HOST, self.port)
self.dataq.join()
time.sleep(self.block_short)
func = getattr(telnet, func_name)
self.assertRaises(EOFError, func)
# read_eager and read_very_eager make the same gaurantees
# (they behave differently but we only test the gaurantees)
def test_read_very_eager_A(self):
self._test_read_any_eager_A('read_very_eager')
def test_read_very_eager_B(self):
self._test_read_any_eager_B('read_very_eager')
def test_read_eager_A(self):
self._test_read_any_eager_A('read_eager')
def test_read_eager_B(self):
self._test_read_any_eager_B('read_eager')
# NB -- we need to test the IAC block which is mentioned in the docstring
# but not in the module docs
def _test_read_any_lazy_B(self, func_name):
self.dataq.put([EOF_sigil])
telnet = telnetlib.Telnet(HOST, self.port)
self.dataq.join()
func = getattr(telnet, func_name)
telnet.fill_rawq()
self.assertRaises(EOFError, func)
def test_read_lazy_A(self):
want = ['x' * 100, EOF_sigil]
self.dataq.put(want)
telnet = telnetlib.Telnet(HOST, self.port)
self.dataq.join()
time.sleep(self.block_short)
self.assertEqual('', telnet.read_lazy())
data = ''
while True:
try:
read_data = telnet.read_lazy()
data += read_data
if not read_data:
telnet.fill_rawq()
except EOFError:
break
self.assertTrue(want[0].startswith(data))
self.assertEqual(data, want[0])
def test_read_lazy_B(self):
self._test_read_any_lazy_B('read_lazy')
def test_read_very_lazy_A(self):
want = ['x' * 100, EOF_sigil]
self.dataq.put(want)
telnet = telnetlib.Telnet(HOST, self.port)
self.dataq.join()
time.sleep(self.block_short)
self.assertEqual('', telnet.read_very_lazy())
data = ''
while True:
try:
read_data = telnet.read_very_lazy()
except EOFError:
break
data += read_data
if not read_data:
telnet.fill_rawq()
self.assertEqual('', telnet.cookedq)
telnet.process_rawq()
self.assertTrue(want[0].startswith(data))
self.assertEqual(data, want[0])
def test_read_very_lazy_B(self):
self._test_read_any_lazy_B('read_very_lazy')
class nego_collector(object):
def __init__(self, sb_getter=None):
self.seen = ''
self.sb_getter = sb_getter
self.sb_seen = ''
def do_nego(self, sock, cmd, opt):
self.seen += cmd + opt
if cmd == tl.SE and self.sb_getter:
sb_data = self.sb_getter()
self.sb_seen += sb_data
tl = telnetlib
class OptionTests(TestCase):
setUp = _read_setUp
tearDown = _read_tearDown
# RFC 854 commands
cmds = [tl.AO, tl.AYT, tl.BRK, tl.EC, tl.EL, tl.GA, tl.IP, tl.NOP]
def _test_command(self, data):
""" helper for testing IAC + cmd """
self.setUp()
self.dataq.put(data)
telnet = telnetlib.Telnet(HOST, self.port)
self.dataq.join()
nego = nego_collector()
telnet.set_option_negotiation_callback(nego.do_nego)
txt = telnet.read_all()
cmd = nego.seen
self.assertTrue(len(cmd) > 0) # we expect at least one command
self.assertIn(cmd[0], self.cmds)
self.assertEqual(cmd[1], tl.NOOPT)
self.assertEqual(len(''.join(data[:-1])), len(txt + cmd))
nego.sb_getter = None # break the nego => telnet cycle
self.tearDown()
def test_IAC_commands(self):
# reset our setup
self.dataq.put([EOF_sigil])
telnet = telnetlib.Telnet(HOST, self.port)
self.dataq.join()
self.tearDown()
for cmd in self.cmds:
self._test_command(['x' * 100, tl.IAC + cmd, 'y'*100, EOF_sigil])
self._test_command(['x' * 10, tl.IAC + cmd, 'y'*10, EOF_sigil])
self._test_command([tl.IAC + cmd, EOF_sigil])
# all at once
self._test_command([tl.IAC + cmd for (cmd) in self.cmds] + [EOF_sigil])
self.assertEqual('', telnet.read_sb_data())
def test_SB_commands(self):
# RFC 855, subnegotiations portion
send = [tl.IAC + tl.SB + tl.IAC + tl.SE,
tl.IAC + tl.SB + tl.IAC + tl.IAC + tl.IAC + tl.SE,
tl.IAC + tl.SB + tl.IAC + tl.IAC + 'aa' + tl.IAC + tl.SE,
tl.IAC + tl.SB + 'bb' + tl.IAC + tl.IAC + tl.IAC + tl.SE,
tl.IAC + tl.SB + 'cc' + tl.IAC + tl.IAC + 'dd' + tl.IAC + tl.SE,
EOF_sigil,
]
self.dataq.put(send)
telnet = telnetlib.Telnet(HOST, self.port)
self.dataq.join()
nego = nego_collector(telnet.read_sb_data)
telnet.set_option_negotiation_callback(nego.do_nego)
txt = telnet.read_all()
self.assertEqual(txt, '')
want_sb_data = tl.IAC + tl.IAC + 'aabb' + tl.IAC + 'cc' + tl.IAC + 'dd'
self.assertEqual(nego.sb_seen, want_sb_data)
self.assertEqual('', telnet.read_sb_data())
nego.sb_getter = None # break the nego => telnet cycle
def test_main(verbose=None):
test_support.run_unittest(GeneralTests, ReadTests, OptionTests)
if __name__ == '__main__':
test_main()
# Very rudimentary test of threading module
import test.test_support
from test.test_support import verbose
import random
import re
import sys
thread = test.test_support.import_module('thread')
threading = test.test_support.import_module('threading')
import time
import unittest
import weakref
import os
import subprocess
from test import lock_tests
# A trivial mutable counter.
class Counter(object):
def __init__(self):
self.value = 0
def inc(self):
self.value += 1
def dec(self):
self.value -= 1
def get(self):
return self.value
class TestThread(threading.Thread):
def __init__(self, name, testcase, sema, mutex, nrunning):
threading.Thread.__init__(self, name=name)
self.testcase = testcase
self.sema = sema
self.mutex = mutex
self.nrunning = nrunning
def run(self):
delay = random.random() / 10000.0
if verbose:
print 'task %s will run for %.1f usec' % (
self.name, delay * 1e6)
with self.sema:
with self.mutex:
self.nrunning.inc()
if verbose:
print self.nrunning.get(), 'tasks are running'
self.testcase.assertTrue(self.nrunning.get() <= 3)
time.sleep(delay)
if verbose:
print 'task', self.name, 'done'
with self.mutex:
self.nrunning.dec()
self.testcase.assertTrue(self.nrunning.get() >= 0)
if verbose:
print '%s is finished. %d tasks are running' % (
self.name, self.nrunning.get())
class BaseTestCase(unittest.TestCase):
def setUp(self):
self._threads = test.test_support.threading_setup()
def tearDown(self):
test.test_support.threading_cleanup(*self._threads)
test.test_support.reap_children()
class ThreadTests(BaseTestCase):
# Create a bunch of threads, let each do some work, wait until all are
# done.
def test_various_ops(self):
# This takes about n/3 seconds to run (about n/3 clumps of tasks,
# times about 1 second per clump).
NUMTASKS = 10
# no more than 3 of the 10 can run at once
sema = threading.BoundedSemaphore(value=3)
mutex = threading.RLock()
numrunning = Counter()
threads = []
for i in range(NUMTASKS):
t = TestThread("<thread %d>"%i, self, sema, mutex, numrunning)
threads.append(t)
self.assertEqual(t.ident, None)
self.assertTrue(re.match('<TestThread\(.*, initial\)>', repr(t)))
t.start()
if verbose:
print 'waiting for all tasks to complete'
for t in threads:
t.join(NUMTASKS)
self.assertTrue(not t.is_alive())
self.assertNotEqual(t.ident, 0)
self.assertFalse(t.ident is None)
self.assertTrue(re.match('<TestThread\(.*, \w+ -?\d+\)>', repr(t)))
if verbose:
print 'all tasks done'
self.assertEqual(numrunning.get(), 0)
def test_ident_of_no_threading_threads(self):
# The ident still must work for the main thread and dummy threads.
self.assertFalse(threading.currentThread().ident is None)
def f():
ident.append(threading.currentThread().ident)
done.set()
done = threading.Event()
ident = []
thread.start_new_thread(f, ())
done.wait()
self.assertFalse(ident[0] is None)
# Kill the "immortal" _DummyThread
del threading._active[ident[0]]
# run with a small(ish) thread stack size (256kB)
def test_various_ops_small_stack(self):
if verbose:
print 'with 256kB thread stack size...'
try:
threading.stack_size(262144)
except thread.error:
if verbose:
print 'platform does not support changing thread stack size'
return
self.test_various_ops()
threading.stack_size(0)
# run with a large thread stack size (1MB)
def test_various_ops_large_stack(self):
if verbose:
print 'with 1MB thread stack size...'
try:
threading.stack_size(0x100000)
except thread.error:
if verbose:
print 'platform does not support changing thread stack size'
return
self.test_various_ops()
threading.stack_size(0)
def test_foreign_thread(self):
# Check that a "foreign" thread can use the threading module.
def f(mutex):
# Calling current_thread() forces an entry for the foreign
# thread to get made in the threading._active map.
threading.current_thread()
mutex.release()
mutex = threading.Lock()
mutex.acquire()
tid = thread.start_new_thread(f, (mutex,))
# Wait for the thread to finish.
mutex.acquire()
self.assertIn(tid, threading._active)
self.assertIsInstance(threading._active[tid], threading._DummyThread)
del threading._active[tid]
# PyThreadState_SetAsyncExc() is a CPython-only gimmick, not (currently)
# exposed at the Python level. This test relies on ctypes to get at it.
def test_PyThreadState_SetAsyncExc(self):
try:
import ctypes
except ImportError:
if verbose:
print "test_PyThreadState_SetAsyncExc can't import ctypes"
return # can't do anything
set_async_exc = ctypes.pythonapi.PyThreadState_SetAsyncExc
class AsyncExc(Exception):
pass
exception = ctypes.py_object(AsyncExc)
# First check it works when setting the exception from the same thread.
tid = thread.get_ident()
try:
result = set_async_exc(ctypes.c_long(tid), exception)
# The exception is async, so we might have to keep the VM busy until
# it notices.
while True:
pass
except AsyncExc:
pass
else:
# This code is unreachable but it reflects the intent. If we wanted
# to be smarter the above loop wouldn't be infinite.
self.fail("AsyncExc not raised")
try:
self.assertEqual(result, 1) # one thread state modified
except UnboundLocalError:
# The exception was raised too quickly for us to get the result.
pass
# `worker_started` is set by the thread when it's inside a try/except
# block waiting to catch the asynchronously set AsyncExc exception.
# `worker_saw_exception` is set by the thread upon catching that
# exception.
worker_started = threading.Event()
worker_saw_exception = threading.Event()
class Worker(threading.Thread):
def run(self):
self.id = thread.get_ident()
self.finished = False
try:
while True:
worker_started.set()
time.sleep(0.1)
except AsyncExc:
self.finished = True
worker_saw_exception.set()
t = Worker()
t.daemon = True # so if this fails, we don't hang Python at shutdown
t.start()
if verbose:
print " started worker thread"
# Try a thread id that doesn't make sense.
if verbose:
print " trying nonsensical thread id"
result = set_async_exc(ctypes.c_long(-1), exception)
self.assertEqual(result, 0) # no thread states modified
# Now raise an exception in the worker thread.
if verbose:
print " waiting for worker thread to get started"
ret = worker_started.wait()
self.assertTrue(ret)
if verbose:
print " verifying worker hasn't exited"
self.assertTrue(not t.finished)
if verbose:
print " attempting to raise asynch exception in worker"
result = set_async_exc(ctypes.c_long(t.id), exception)
self.assertEqual(result, 1) # one thread state modified
if verbose:
print " waiting for worker to say it caught the exception"
worker_saw_exception.wait(timeout=10)
self.assertTrue(t.finished)
if verbose:
print " all OK -- joining worker"
if t.finished:
t.join()
# else the thread is still running, and we have no way to kill it
def test_limbo_cleanup(self):
# Issue 7481: Failure to start thread should cleanup the limbo map.
def fail_new_thread(*args):
raise thread.error()
_start_new_thread = threading._start_new_thread
threading._start_new_thread = fail_new_thread
try:
t = threading.Thread(target=lambda: None)
self.assertRaises(thread.error, t.start)
self.assertFalse(
t in threading._limbo,
"Failed to cleanup _limbo map on failure of Thread.start().")
finally:
threading._start_new_thread = _start_new_thread
def test_finalize_runnning_thread(self):
# Issue 1402: the PyGILState_Ensure / _Release functions may be called
# very late on python exit: on deallocation of a running thread for
# example.
try:
import ctypes
except ImportError:
if verbose:
print("test_finalize_with_runnning_thread can't import ctypes")
return # can't do anything
rc = subprocess.call([sys.executable, "-c", """if 1:
import ctypes, sys, time, thread
# This lock is used as a simple event variable.
ready = thread.allocate_lock()
ready.acquire()
# Module globals are cleared before __del__ is run
# So we save the functions in class dict
class C:
ensure = ctypes.pythonapi.PyGILState_Ensure
release = ctypes.pythonapi.PyGILState_Release
def __del__(self):
state = self.ensure()
self.release(state)
def waitingThread():
x = C()
ready.release()
time.sleep(100)
thread.start_new_thread(waitingThread, ())
ready.acquire() # Be sure the other thread is waiting.
sys.exit(42)
"""])
self.assertEqual(rc, 42)
def test_finalize_with_trace(self):
# Issue1733757
# Avoid a deadlock when sys.settrace steps into threading._shutdown
p = subprocess.Popen([sys.executable, "-c", """if 1:
import sys, threading
# A deadlock-killer, to prevent the
# testsuite to hang forever
def killer():
import os, time
time.sleep(2)
print 'program blocked; aborting'
os._exit(2)
t = threading.Thread(target=killer)
t.daemon = True
t.start()
# This is the trace function
def func(frame, event, arg):
threading.current_thread()
return func
sys.settrace(func)
"""],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
self.addCleanup(p.stdout.close)
self.addCleanup(p.stderr.close)
stdout, stderr = p.communicate()
rc = p.returncode
self.assertFalse(rc == 2, "interpreted was blocked")
self.assertTrue(rc == 0,
"Unexpected error: " + repr(stderr))
def test_join_nondaemon_on_shutdown(self):
# Issue 1722344
# Raising SystemExit skipped threading._shutdown
p = subprocess.Popen([sys.executable, "-c", """if 1:
import threading
from time import sleep
def child():
sleep(1)
# As a non-daemon thread we SHOULD wake up and nothing
# should be torn down yet
print "Woke up, sleep function is:", sleep
threading.Thread(target=child).start()
raise SystemExit
"""],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
self.addCleanup(p.stdout.close)
self.addCleanup(p.stderr.close)
stdout, stderr = p.communicate()
self.assertEqual(stdout.strip(),
"Woke up, sleep function is: <built-in function sleep>")
stderr = re.sub(r"^\[\d+ refs\]", "", stderr, re.MULTILINE).strip()
self.assertEqual(stderr, "")
def test_enumerate_after_join(self):
# Try hard to trigger #1703448: a thread is still returned in
# threading.enumerate() after it has been join()ed.
enum = threading.enumerate
old_interval = sys.getcheckinterval()
try:
for i in xrange(1, 100):
# Try a couple times at each thread-switching interval
# to get more interleavings.
sys.setcheckinterval(i // 5)
t = threading.Thread(target=lambda: None)
t.start()
t.join()
l = enum()
self.assertNotIn(t, l,
"#1703448 triggered after %d trials: %s" % (i, l))
finally:
sys.setcheckinterval(old_interval)
def test_no_refcycle_through_target(self):
class RunSelfFunction(object):
def __init__(self, should_raise):
# The links in this refcycle from Thread back to self
# should be cleaned up when the thread completes.
self.should_raise = should_raise
self.thread = threading.Thread(target=self._run,
args=(self,),
kwargs={'yet_another':self})
self.thread.start()
def _run(self, other_ref, yet_another):
if self.should_raise:
raise SystemExit
cyclic_object = RunSelfFunction(should_raise=False)
weak_cyclic_object = weakref.ref(cyclic_object)
cyclic_object.thread.join()
del cyclic_object
self.assertEqual(None, weak_cyclic_object(),
msg=('%d references still around' %
sys.getrefcount(weak_cyclic_object())))
raising_cyclic_object = RunSelfFunction(should_raise=True)
weak_raising_cyclic_object = weakref.ref(raising_cyclic_object)
raising_cyclic_object.thread.join()
del raising_cyclic_object
self.assertEqual(None, weak_raising_cyclic_object(),
msg=('%d references still around' %
sys.getrefcount(weak_raising_cyclic_object())))
class ThreadJoinOnShutdown(BaseTestCase):
# Between fork() and exec(), only async-safe functions are allowed (issues
# #12316 and #11870), and fork() from a worker thread is known to trigger
# problems with some operating systems (issue #3863): skip problematic tests
# on platforms known to behave badly.
platforms_to_skip = ('freebsd4', 'freebsd5', 'freebsd6', 'netbsd5',
'os2emx')
def _run_and_join(self, script):
script = """if 1:
import sys, os, time, threading
# a thread, which waits for the main program to terminate
def joiningfunc(mainthread):
mainthread.join()
print 'end of thread'
\n""" + script
p = subprocess.Popen([sys.executable, "-c", script], stdout=subprocess.PIPE)
rc = p.wait()
data = p.stdout.read().replace('\r', '')
p.stdout.close()
self.assertEqual(data, "end of main\nend of thread\n")
self.assertFalse(rc == 2, "interpreter was blocked")
self.assertTrue(rc == 0, "Unexpected error")
def test_1_join_on_shutdown(self):
# The usual case: on exit, wait for a non-daemon thread
script = """if 1:
import os
t = threading.Thread(target=joiningfunc,
args=(threading.current_thread(),))
t.start()
time.sleep(0.1)
print 'end of main'
"""
self._run_and_join(script)
@unittest.skipUnless(hasattr(os, 'fork'), "needs os.fork()")
@unittest.skipIf(sys.platform in platforms_to_skip, "due to known OS bug")
def test_2_join_in_forked_process(self):
# Like the test above, but from a forked interpreter
script = """if 1:
childpid = os.fork()
if childpid != 0:
os.waitpid(childpid, 0)
sys.exit(0)
t = threading.Thread(target=joiningfunc,
args=(threading.current_thread(),))
t.start()
print 'end of main'
"""
self._run_and_join(script)
@unittest.skipUnless(hasattr(os, 'fork'), "needs os.fork()")
@unittest.skipIf(sys.platform in platforms_to_skip, "due to known OS bug")
def test_3_join_in_forked_from_thread(self):
# Like the test above, but fork() was called from a worker thread
# In the forked process, the main Thread object must be marked as stopped.
script = """if 1:
main_thread = threading.current_thread()
def worker():
childpid = os.fork()
if childpid != 0:
os.waitpid(childpid, 0)
sys.exit(0)
t = threading.Thread(target=joiningfunc,
args=(main_thread,))
print 'end of main'
t.start()
t.join() # Should not block: main_thread is already stopped
w = threading.Thread(target=worker)
w.start()
"""
self._run_and_join(script)
def assertScriptHasOutput(self, script, expected_output):
p = subprocess.Popen([sys.executable, "-c", script],
stdout=subprocess.PIPE)
rc = p.wait()
data = p.stdout.read().decode().replace('\r', '')
self.assertEqual(rc, 0, "Unexpected error")
self.assertEqual(data, expected_output)
@unittest.skipUnless(hasattr(os, 'fork'), "needs os.fork()")
@unittest.skipIf(sys.platform in platforms_to_skip, "due to known OS bug")
def test_4_joining_across_fork_in_worker_thread(self):
# There used to be a possible deadlock when forking from a child
# thread. See http://bugs.python.org/issue6643.
# The script takes the following steps:
# - The main thread in the parent process starts a new thread and then
# tries to join it.
# - The join operation acquires the Lock inside the thread's _block
# Condition. (See threading.py:Thread.join().)
# - We stub out the acquire method on the condition to force it to wait
# until the child thread forks. (See LOCK ACQUIRED HERE)
# - The child thread forks. (See LOCK HELD and WORKER THREAD FORKS
# HERE)
# - The main thread of the parent process enters Condition.wait(),
# which releases the lock on the child thread.
# - The child process returns. Without the necessary fix, when the
# main thread of the child process (which used to be the child thread
# in the parent process) attempts to exit, it will try to acquire the
# lock in the Thread._block Condition object and hang, because the
# lock was held across the fork.
script = """if 1:
import os, time, threading
finish_join = False
start_fork = False
def worker():
# Wait until this thread's lock is acquired before forking to
# create the deadlock.
global finish_join
while not start_fork:
time.sleep(0.01)
# LOCK HELD: Main thread holds lock across this call.
childpid = os.fork()
finish_join = True
if childpid != 0:
# Parent process just waits for child.
os.waitpid(childpid, 0)
# Child process should just return.
w = threading.Thread(target=worker)
# Stub out the private condition variable's lock acquire method.
# This acquires the lock and then waits until the child has forked
# before returning, which will release the lock soon after. If
# someone else tries to fix this test case by acquiring this lock
# before forking instead of resetting it, the test case will
# deadlock when it shouldn't.
condition = w._block
orig_acquire = condition.acquire
call_count_lock = threading.Lock()
call_count = 0
def my_acquire():
global call_count
global start_fork
orig_acquire() # LOCK ACQUIRED HERE
start_fork = True
if call_count == 0:
while not finish_join:
time.sleep(0.01) # WORKER THREAD FORKS HERE
with call_count_lock:
call_count += 1
condition.acquire = my_acquire
w.start()
w.join()
print('end of main')
"""
self.assertScriptHasOutput(script, "end of main\n")
@unittest.skipUnless(hasattr(os, 'fork'), "needs os.fork()")
@unittest.skipIf(sys.platform in platforms_to_skip, "due to known OS bug")
def test_5_clear_waiter_locks_to_avoid_crash(self):
# Check that a spawned thread that forks doesn't segfault on certain
# platforms, namely OS X. This used to happen if there was a waiter
# lock in the thread's condition variable's waiters list. Even though
# we know the lock will be held across the fork, it is not safe to
# release locks held across forks on all platforms, so releasing the
# waiter lock caused a segfault on OS X. Furthermore, since locks on
# OS X are (as of this writing) implemented with a mutex + condition
# variable instead of a semaphore, while we know that the Python-level
# lock will be acquired, we can't know if the internal mutex will be
# acquired at the time of the fork.
script = """if True:
import os, time, threading
start_fork = False
def worker():
# Wait until the main thread has attempted to join this thread
# before continuing.
while not start_fork:
time.sleep(0.01)
childpid = os.fork()
if childpid != 0:
# Parent process just waits for child.
(cpid, rc) = os.waitpid(childpid, 0)
assert cpid == childpid
assert rc == 0
print('end of worker thread')
else:
# Child process should just return.
pass
w = threading.Thread(target=worker)
# Stub out the private condition variable's _release_save method.
# This releases the condition's lock and flips the global that
# causes the worker to fork. At this point, the problematic waiter
# lock has been acquired once by the waiter and has been put onto
# the waiters list.
condition = w._block
orig_release_save = condition._release_save
def my_release_save():
global start_fork
orig_release_save()
# Waiter lock held here, condition lock released.
start_fork = True
condition._release_save = my_release_save
w.start()
w.join()
print('end of main thread')
"""
output = "end of worker thread\nend of main thread\n"
self.assertScriptHasOutput(script, output)
@unittest.skipUnless(hasattr(os, 'fork'), "needs os.fork()")
@unittest.skipIf(sys.platform in platforms_to_skip, "due to known OS bug")
def test_reinit_tls_after_fork(self):
# Issue #13817: fork() would deadlock in a multithreaded program with
# the ad-hoc TLS implementation.
def do_fork_and_wait():
# just fork a child process and wait it
pid = os.fork()
if pid > 0:
os.waitpid(pid, 0)
else:
os._exit(0)
# start a bunch of threads that will fork() child processes
threads = []
for i in range(16):
t = threading.Thread(target=do_fork_and_wait)
threads.append(t)
t.start()
for t in threads:
t.join()
class ThreadingExceptionTests(BaseTestCase):
# A RuntimeError should be raised if Thread.start() is called
# multiple times.
def test_start_thread_again(self):
thread = threading.Thread()
thread.start()
self.assertRaises(RuntimeError, thread.start)
def test_joining_current_thread(self):
current_thread = threading.current_thread()
self.assertRaises(RuntimeError, current_thread.join);
def test_joining_inactive_thread(self):
thread = threading.Thread()
self.assertRaises(RuntimeError, thread.join)
def test_daemonize_active_thread(self):
thread = threading.Thread()
thread.start()
self.assertRaises(RuntimeError, setattr, thread, "daemon", True)
class LockTests(lock_tests.LockTests):
locktype = staticmethod(threading.Lock)
class RLockTests(lock_tests.RLockTests):
locktype = staticmethod(threading.RLock)
class EventTests(lock_tests.EventTests):
eventtype = staticmethod(threading.Event)
class ConditionAsRLockTests(lock_tests.RLockTests):
# An Condition uses an RLock by default and exports its API.
locktype = staticmethod(threading.Condition)
class ConditionTests(lock_tests.ConditionTests):
condtype = staticmethod(threading.Condition)
class SemaphoreTests(lock_tests.SemaphoreTests):
semtype = staticmethod(threading.Semaphore)
class BoundedSemaphoreTests(lock_tests.BoundedSemaphoreTests):
semtype = staticmethod(threading.BoundedSemaphore)
@unittest.skipUnless(sys.platform == 'darwin', 'test macosx problem')
def test_recursion_limit(self):
# Issue 9670
# test that excessive recursion within a non-main thread causes
# an exception rather than crashing the interpreter on platforms
# like Mac OS X or FreeBSD which have small default stack sizes
# for threads
script = """if True:
import threading
def recurse():
return recurse()
def outer():
try:
recurse()
except RuntimeError:
pass
w = threading.Thread(target=outer)
w.start()
w.join()
print('end of main thread')
"""
expected_output = "end of main thread\n"
p = subprocess.Popen([sys.executable, "-c", script],
stdout=subprocess.PIPE)
stdout, stderr = p.communicate()
data = stdout.decode().replace('\r', '')
self.assertEqual(p.returncode, 0, "Unexpected error")
self.assertEqual(data, expected_output)
def test_main():
test.test_support.run_unittest(LockTests, RLockTests, EventTests,
ConditionAsRLockTests, ConditionTests,
SemaphoreTests, BoundedSemaphoreTests,
ThreadTests,
ThreadJoinOnShutdown,
ThreadingExceptionTests,
)
if __name__ == "__main__":
test_main()
import unittest
from doctest import DocTestSuite
from test import test_support
import weakref
import gc
# Modules under test
_thread = test_support.import_module('thread')
threading = test_support.import_module('threading')
import _threading_local
class Weak(object):
pass
def target(local, weaklist):
weak = Weak()
local.weak = weak
weaklist.append(weakref.ref(weak))
class BaseLocalTest:
def test_local_refs(self):
self._local_refs(20)
self._local_refs(50)
self._local_refs(100)
def _local_refs(self, n):
local = self._local()
weaklist = []
for i in range(n):
t = threading.Thread(target=target, args=(local, weaklist))
t.start()
t.join()
del t
gc.collect()
self.assertEqual(len(weaklist), n)
# XXX _threading_local keeps the local of the last stopped thread alive.
deadlist = [weak for weak in weaklist if weak() is None]
self.assertIn(len(deadlist), (n-1, n))
# Assignment to the same thread local frees it sometimes (!)
local.someothervar = None
gc.collect()
deadlist = [weak for weak in weaklist if weak() is None]
self.assertIn(len(deadlist), (n-1, n), (n, len(deadlist)))
def test_derived(self):
# Issue 3088: if there is a threads switch inside the __init__
# of a threading.local derived class, the per-thread dictionary
# is created but not correctly set on the object.
# The first member set may be bogus.
import time
class Local(self._local):
def __init__(self):
time.sleep(0.01)
local = Local()
def f(i):
local.x = i
# Simply check that the variable is correctly set
self.assertEqual(local.x, i)
threads= []
for i in range(10):
t = threading.Thread(target=f, args=(i,))
t.start()
threads.append(t)
for t in threads:
t.join()
def test_derived_cycle_dealloc(self):
# http://bugs.python.org/issue6990
class Local(self._local):
pass
locals = None
passed = [False]
e1 = threading.Event()
e2 = threading.Event()
def f():
# 1) Involve Local in a cycle
cycle = [Local()]
cycle.append(cycle)
cycle[0].foo = 'bar'
# 2) GC the cycle (triggers threadmodule.c::local_clear
# before local_dealloc)
del cycle
gc.collect()
e1.set()
e2.wait()
# 4) New Locals should be empty
passed[0] = all(not hasattr(local, 'foo') for local in locals)
t = threading.Thread(target=f)
t.start()
e1.wait()
# 3) New Locals should recycle the original's address. Creating
# them in the thread overwrites the thread state and avoids the
# bug
locals = [Local() for i in range(10)]
e2.set()
t.join()
self.assertTrue(passed[0])
def test_arguments(self):
# Issue 1522237
from thread import _local as local
from _threading_local import local as py_local
for cls in (local, py_local):
class MyLocal(cls):
def __init__(self, *args, **kwargs):
pass
MyLocal(a=1)
MyLocal(1)
self.assertRaises(TypeError, cls, a=1)
self.assertRaises(TypeError, cls, 1)
def _test_one_class(self, c):
self._failed = "No error message set or cleared."
obj = c()
e1 = threading.Event()
e2 = threading.Event()
def f1():
obj.x = 'foo'
obj.y = 'bar'
del obj.y
e1.set()
e2.wait()
def f2():
try:
foo = obj.x
except AttributeError:
# This is expected -- we haven't set obj.x in this thread yet!
self._failed = "" # passed
else:
self._failed = ('Incorrectly got value %r from class %r\n' %
(foo, c))
sys.stderr.write(self._failed)
t1 = threading.Thread(target=f1)
t1.start()
e1.wait()
t2 = threading.Thread(target=f2)
t2.start()
t2.join()
# The test is done; just let t1 know it can exit, and wait for it.
e2.set()
t1.join()
self.assertFalse(self._failed, self._failed)
def test_threading_local(self):
self._test_one_class(self._local)
def test_threading_local_subclass(self):
class LocalSubclass(self._local):
"""To test that subclasses behave properly."""
self._test_one_class(LocalSubclass)
def _test_dict_attribute(self, cls):
obj = cls()
obj.x = 5
self.assertEqual(obj.__dict__, {'x': 5})
with self.assertRaises(AttributeError):
obj.__dict__ = {}
with self.assertRaises(AttributeError):
del obj.__dict__
def test_dict_attribute(self):
self._test_dict_attribute(self._local)
def test_dict_attribute_subclass(self):
class LocalSubclass(self._local):
"""To test that subclasses behave properly."""
self._test_dict_attribute(LocalSubclass)
class ThreadLocalTest(unittest.TestCase, BaseLocalTest):
_local = _thread._local
# Fails for the pure Python implementation
def test_cycle_collection(self):
class X:
pass
x = X()
x.local = self._local()
x.local.x = x
wr = weakref.ref(x)
del x
gc.collect()
self.assertIs(wr(), None)
class PyThreadingLocalTest(unittest.TestCase, BaseLocalTest):
_local = _threading_local.local
def test_main():
suite = unittest.TestSuite()
suite.addTest(DocTestSuite('_threading_local'))
suite.addTest(unittest.makeSuite(ThreadLocalTest))
suite.addTest(unittest.makeSuite(PyThreadingLocalTest))
try:
from thread import _local
except ImportError:
pass
else:
import _threading_local
local_orig = _threading_local.local
def setUp(test):
_threading_local.local = _local
def tearDown(test):
_threading_local.local = local_orig
suite.addTest(DocTestSuite('_threading_local',
setUp=setUp, tearDown=tearDown)
)
test_support.run_unittest(suite)
if __name__ == '__main__':
test_main()
"""Unit tests for socket timeout feature."""
import unittest
from test import test_support
# This requires the 'network' resource as given on the regrtest command line.
skip_expected = not test_support.is_resource_enabled('network')
import time
import socket
class CreationTestCase(unittest.TestCase):
"""Test case for socket.gettimeout() and socket.settimeout()"""
def setUp(self):
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
def tearDown(self):
self.sock.close()
def testObjectCreation(self):
# Test Socket creation
self.assertEqual(self.sock.gettimeout(), None,
"timeout not disabled by default")
def testFloatReturnValue(self):
# Test return value of gettimeout()
self.sock.settimeout(7.345)
self.assertEqual(self.sock.gettimeout(), 7.345)
self.sock.settimeout(3)
self.assertEqual(self.sock.gettimeout(), 3)
self.sock.settimeout(None)
self.assertEqual(self.sock.gettimeout(), None)
def testReturnType(self):
# Test return type of gettimeout()
self.sock.settimeout(1)
self.assertEqual(type(self.sock.gettimeout()), type(1.0))
self.sock.settimeout(3.9)
self.assertEqual(type(self.sock.gettimeout()), type(1.0))
def testTypeCheck(self):
# Test type checking by settimeout()
self.sock.settimeout(0)
self.sock.settimeout(0L)
self.sock.settimeout(0.0)
self.sock.settimeout(None)
self.assertRaises(TypeError, self.sock.settimeout, "")
self.assertRaises(TypeError, self.sock.settimeout, u"")
self.assertRaises(TypeError, self.sock.settimeout, ())
self.assertRaises(TypeError, self.sock.settimeout, [])
self.assertRaises(TypeError, self.sock.settimeout, {})
self.assertRaises(TypeError, self.sock.settimeout, 0j)
def testRangeCheck(self):
# Test range checking by settimeout()
self.assertRaises(ValueError, self.sock.settimeout, -1)
self.assertRaises(ValueError, self.sock.settimeout, -1L)
self.assertRaises(ValueError, self.sock.settimeout, -1.0)
def testTimeoutThenBlocking(self):
# Test settimeout() followed by setblocking()
self.sock.settimeout(10)
self.sock.setblocking(1)
self.assertEqual(self.sock.gettimeout(), None)
self.sock.setblocking(0)
self.assertEqual(self.sock.gettimeout(), 0.0)
self.sock.settimeout(10)
self.sock.setblocking(0)
self.assertEqual(self.sock.gettimeout(), 0.0)
self.sock.setblocking(1)
self.assertEqual(self.sock.gettimeout(), None)
def testBlockingThenTimeout(self):
# Test setblocking() followed by settimeout()
self.sock.setblocking(0)
self.sock.settimeout(1)
self.assertEqual(self.sock.gettimeout(), 1)
self.sock.setblocking(1)
self.sock.settimeout(1)
self.assertEqual(self.sock.gettimeout(), 1)
class TimeoutTestCase(unittest.TestCase):
"""Test case for socket.socket() timeout functions"""
# There are a number of tests here trying to make sure that an operation
# doesn't take too much longer than expected. But competing machine
# activity makes it inevitable that such tests will fail at times.
# When fuzz was at 1.0, I (tim) routinely saw bogus failures on Win2K
# and Win98SE. Boosting it to 2.0 helped a lot, but isn't a real
# solution.
fuzz = 2.0
def setUp(self):
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.addr_remote = ('www.python.org.', 80)
self.localhost = '127.0.0.1'
def tearDown(self):
self.sock.close()
def testConnectTimeout(self):
# Choose a private address that is unlikely to exist to prevent
# failures due to the connect succeeding before the timeout.
# Use a dotted IP address to avoid including the DNS lookup time
# with the connect time. This avoids failing the assertion that
# the timeout occurred fast enough.
addr = ('10.0.0.0', 12345)
# Test connect() timeout
_timeout = 0.001
self.sock.settimeout(_timeout)
_t1 = time.time()
self.assertRaises(socket.error, self.sock.connect, addr)
_t2 = time.time()
_delta = abs(_t1 - _t2)
self.assertTrue(_delta < _timeout + self.fuzz,
"timeout (%g) is more than %g seconds more than expected (%g)"
%(_delta, self.fuzz, _timeout))
def testRecvTimeout(self):
# Test recv() timeout
_timeout = 0.02
with test_support.transient_internet(self.addr_remote[0]):
self.sock.connect(self.addr_remote)
self.sock.settimeout(_timeout)
_t1 = time.time()
self.assertRaises(socket.timeout, self.sock.recv, 1024)
_t2 = time.time()
_delta = abs(_t1 - _t2)
self.assertTrue(_delta < _timeout + self.fuzz,
"timeout (%g) is %g seconds more than expected (%g)"
%(_delta, self.fuzz, _timeout))
def testAcceptTimeout(self):
# Test accept() timeout
_timeout = 2
self.sock.settimeout(_timeout)
# Prevent "Address already in use" socket exceptions
test_support.bind_port(self.sock, self.localhost)
self.sock.listen(5)
_t1 = time.time()
self.assertRaises(socket.error, self.sock.accept)
_t2 = time.time()
_delta = abs(_t1 - _t2)
self.assertTrue(_delta < _timeout + self.fuzz,
"timeout (%g) is %g seconds more than expected (%g)"
%(_delta, self.fuzz, _timeout))
def testRecvfromTimeout(self):
# Test recvfrom() timeout
_timeout = 2
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.sock.settimeout(_timeout)
# Prevent "Address already in use" socket exceptions
test_support.bind_port(self.sock, self.localhost)
_t1 = time.time()
self.assertRaises(socket.error, self.sock.recvfrom, 8192)
_t2 = time.time()
_delta = abs(_t1 - _t2)
self.assertTrue(_delta < _timeout + self.fuzz,
"timeout (%g) is %g seconds more than expected (%g)"
%(_delta, self.fuzz, _timeout))
def testSend(self):
# Test send() timeout
# couldn't figure out how to test it
pass
def testSendto(self):
# Test sendto() timeout
# couldn't figure out how to test it
pass
def testSendall(self):
# Test sendall() timeout
# couldn't figure out how to test it
pass
def test_main():
test_support.requires('network')
test_support.run_unittest(CreationTestCase, TimeoutTestCase)
if __name__ == "__main__":
test_main()
import unittest
from test import test_support
import os
import socket
import StringIO
import urllib2
from urllib2 import Request, OpenerDirector
# XXX
# Request
# CacheFTPHandler (hard to write)
# parse_keqv_list, parse_http_list, HTTPDigestAuthHandler
class TrivialTests(unittest.TestCase):
def test_trivial(self):
# A couple trivial tests
self.assertRaises(ValueError, urllib2.urlopen, 'bogus url')
# XXX Name hacking to get this to work on Windows.
fname = os.path.abspath(urllib2.__file__).replace('\\', '/')
# And more hacking to get it to work on MacOS. This assumes
# urllib.pathname2url works, unfortunately...
if os.name == 'riscos':
import string
fname = os.expand(fname)
fname = fname.translate(string.maketrans("/.", "./"))
if os.name == 'nt':
file_url = "file:///%s" % fname
else:
file_url = "file://%s" % fname
f = urllib2.urlopen(file_url)
buf = f.read()
f.close()
def test_parse_http_list(self):
tests = [('a,b,c', ['a', 'b', 'c']),
('path"o,l"og"i"cal, example', ['path"o,l"og"i"cal', 'example']),
('a, b, "c", "d", "e,f", g, h', ['a', 'b', '"c"', '"d"', '"e,f"', 'g', 'h']),
('a="b\\"c", d="e\\,f", g="h\\\\i"', ['a="b"c"', 'd="e,f"', 'g="h\\i"'])]
for string, list in tests:
self.assertEqual(urllib2.parse_http_list(string), list)
def test_request_headers_dict():
"""
The Request.headers dictionary is not a documented interface. It should
stay that way, because the complete set of headers are only accessible
through the .get_header(), .has_header(), .header_items() interface.
However, .headers pre-dates those methods, and so real code will be using
the dictionary.
The introduction in 2.4 of those methods was a mistake for the same reason:
code that previously saw all (urllib2 user)-provided headers in .headers
now sees only a subset (and the function interface is ugly and incomplete).
A better change would have been to replace .headers dict with a dict
subclass (or UserDict.DictMixin instance?) that preserved the .headers
interface and also provided access to the "unredirected" headers. It's
probably too late to fix that, though.
Check .capitalize() case normalization:
>>> url = "http://example.com"
>>> Request(url, headers={"Spam-eggs": "blah"}).headers["Spam-eggs"]
'blah'
>>> Request(url, headers={"spam-EggS": "blah"}).headers["Spam-eggs"]
'blah'
Currently, Request(url, "Spam-eggs").headers["Spam-Eggs"] raises KeyError,
but that could be changed in future.
"""
def test_request_headers_methods():
"""
Note the case normalization of header names here, to .capitalize()-case.
This should be preserved for backwards-compatibility. (In the HTTP case,
normalization to .title()-case is done by urllib2 before sending headers to
httplib).
>>> url = "http://example.com"
>>> r = Request(url, headers={"Spam-eggs": "blah"})
>>> r.has_header("Spam-eggs")
True
>>> r.header_items()
[('Spam-eggs', 'blah')]
>>> r.add_header("Foo-Bar", "baz")
>>> items = r.header_items()
>>> items.sort()
>>> items
[('Foo-bar', 'baz'), ('Spam-eggs', 'blah')]
Note that e.g. r.has_header("spam-EggS") is currently False, and
r.get_header("spam-EggS") returns None, but that could be changed in
future.
>>> r.has_header("Not-there")
False
>>> print r.get_header("Not-there")
None
>>> r.get_header("Not-there", "default")
'default'
"""
def test_password_manager(self):
"""
>>> mgr = urllib2.HTTPPasswordMgr()
>>> add = mgr.add_password
>>> add("Some Realm", "http://example.com/", "joe", "password")
>>> add("Some Realm", "http://example.com/ni", "ni", "ni")
>>> add("c", "http://example.com/foo", "foo", "ni")
>>> add("c", "http://example.com/bar", "bar", "nini")
>>> add("b", "http://example.com/", "first", "blah")
>>> add("b", "http://example.com/", "second", "spam")
>>> add("a", "http://example.com", "1", "a")
>>> add("Some Realm", "http://c.example.com:3128", "3", "c")
>>> add("Some Realm", "d.example.com", "4", "d")
>>> add("Some Realm", "e.example.com:3128", "5", "e")
>>> mgr.find_user_password("Some Realm", "example.com")
('joe', 'password')
>>> mgr.find_user_password("Some Realm", "http://example.com")
('joe', 'password')
>>> mgr.find_user_password("Some Realm", "http://example.com/")
('joe', 'password')
>>> mgr.find_user_password("Some Realm", "http://example.com/spam")
('joe', 'password')
>>> mgr.find_user_password("Some Realm", "http://example.com/spam/spam")
('joe', 'password')
>>> mgr.find_user_password("c", "http://example.com/foo")
('foo', 'ni')
>>> mgr.find_user_password("c", "http://example.com/bar")
('bar', 'nini')
Actually, this is really undefined ATM
## Currently, we use the highest-level path where more than one match:
## >>> mgr.find_user_password("Some Realm", "http://example.com/ni")
## ('joe', 'password')
Use latest add_password() in case of conflict:
>>> mgr.find_user_password("b", "http://example.com/")
('second', 'spam')
No special relationship between a.example.com and example.com:
>>> mgr.find_user_password("a", "http://example.com/")
('1', 'a')
>>> mgr.find_user_password("a", "http://a.example.com/")
(None, None)
Ports:
>>> mgr.find_user_password("Some Realm", "c.example.com")
(None, None)
>>> mgr.find_user_password("Some Realm", "c.example.com:3128")
('3', 'c')
>>> mgr.find_user_password("Some Realm", "http://c.example.com:3128")
('3', 'c')
>>> mgr.find_user_password("Some Realm", "d.example.com")
('4', 'd')
>>> mgr.find_user_password("Some Realm", "e.example.com:3128")
('5', 'e')
"""
pass
def test_password_manager_default_port(self):
"""
>>> mgr = urllib2.HTTPPasswordMgr()
>>> add = mgr.add_password
The point to note here is that we can't guess the default port if there's
no scheme. This applies to both add_password and find_user_password.
>>> add("f", "http://g.example.com:80", "10", "j")
>>> add("g", "http://h.example.com", "11", "k")
>>> add("h", "i.example.com:80", "12", "l")
>>> add("i", "j.example.com", "13", "m")
>>> mgr.find_user_password("f", "g.example.com:100")
(None, None)
>>> mgr.find_user_password("f", "g.example.com:80")
('10', 'j')
>>> mgr.find_user_password("f", "g.example.com")
(None, None)
>>> mgr.find_user_password("f", "http://g.example.com:100")
(None, None)
>>> mgr.find_user_password("f", "http://g.example.com:80")
('10', 'j')
>>> mgr.find_user_password("f", "http://g.example.com")
('10', 'j')
>>> mgr.find_user_password("g", "h.example.com")
('11', 'k')
>>> mgr.find_user_password("g", "h.example.com:80")
('11', 'k')
>>> mgr.find_user_password("g", "http://h.example.com:80")
('11', 'k')
>>> mgr.find_user_password("h", "i.example.com")
(None, None)
>>> mgr.find_user_password("h", "i.example.com:80")
('12', 'l')
>>> mgr.find_user_password("h", "http://i.example.com:80")
('12', 'l')
>>> mgr.find_user_password("i", "j.example.com")
('13', 'm')
>>> mgr.find_user_password("i", "j.example.com:80")
(None, None)
>>> mgr.find_user_password("i", "http://j.example.com")
('13', 'm')
>>> mgr.find_user_password("i", "http://j.example.com:80")
(None, None)
"""
class MockOpener:
addheaders = []
def open(self, req, data=None,timeout=socket._GLOBAL_DEFAULT_TIMEOUT):
self.req, self.data, self.timeout = req, data, timeout
def error(self, proto, *args):
self.proto, self.args = proto, args
class MockFile:
def read(self, count=None): pass
def readline(self, count=None): pass
def close(self): pass
class MockHeaders(dict):
def getheaders(self, name):
return self.values()
class MockResponse(StringIO.StringIO):
def __init__(self, code, msg, headers, data, url=None):
StringIO.StringIO.__init__(self, data)
self.code, self.msg, self.headers, self.url = code, msg, headers, url
def info(self):
return self.headers
def geturl(self):
return self.url
class MockCookieJar:
def add_cookie_header(self, request):
self.ach_req = request
def extract_cookies(self, response, request):
self.ec_req, self.ec_r = request, response
class FakeMethod:
def __init__(self, meth_name, action, handle):
self.meth_name = meth_name
self.handle = handle
self.action = action
def __call__(self, *args):
return self.handle(self.meth_name, self.action, *args)
class MockHTTPResponse:
def __init__(self, fp, msg, status, reason):
self.fp = fp
self.msg = msg
self.status = status
self.reason = reason
def read(self):
return ''
class MockHTTPClass:
def __init__(self):
self.req_headers = []
self.data = None
self.raise_on_endheaders = False
self._tunnel_headers = {}
def __call__(self, host, timeout=socket._GLOBAL_DEFAULT_TIMEOUT):
self.host = host
self.timeout = timeout
return self
def set_debuglevel(self, level):
self.level = level
def set_tunnel(self, host, port=None, headers=None):
self._tunnel_host = host
self._tunnel_port = port
if headers:
self._tunnel_headers = headers
else:
self._tunnel_headers.clear()
def request(self, method, url, body=None, headers=None):
self.method = method
self.selector = url
if headers is not None:
self.req_headers += headers.items()
self.req_headers.sort()
if body:
self.data = body
if self.raise_on_endheaders:
import socket
raise socket.error()
def getresponse(self):
return MockHTTPResponse(MockFile(), {}, 200, "OK")
def close(self):
pass
class MockHandler:
# useful for testing handler machinery
# see add_ordered_mock_handlers() docstring
handler_order = 500
def __init__(self, methods):
self._define_methods(methods)
def _define_methods(self, methods):
for spec in methods:
if len(spec) == 2: name, action = spec
else: name, action = spec, None
meth = FakeMethod(name, action, self.handle)
setattr(self.__class__, name, meth)
def handle(self, fn_name, action, *args, **kwds):
self.parent.calls.append((self, fn_name, args, kwds))
if action is None:
return None
elif action == "return self":
return self
elif action == "return response":
res = MockResponse(200, "OK", {}, "")
return res
elif action == "return request":
return Request("http://blah/")
elif action.startswith("error"):
code = action[action.rfind(" ")+1:]
try:
code = int(code)
except ValueError:
pass
res = MockResponse(200, "OK", {}, "")
return self.parent.error("http", args[0], res, code, "", {})
elif action == "raise":
raise urllib2.URLError("blah")
assert False
def close(self): pass
def add_parent(self, parent):
self.parent = parent
self.parent.calls = []
def __lt__(self, other):
if not hasattr(other, "handler_order"):
# No handler_order, leave in original order. Yuck.
return True
return self.handler_order < other.handler_order
def add_ordered_mock_handlers(opener, meth_spec):
"""Create MockHandlers and add them to an OpenerDirector.
meth_spec: list of lists of tuples and strings defining methods to define
on handlers. eg:
[["http_error", "ftp_open"], ["http_open"]]
defines methods .http_error() and .ftp_open() on one handler, and
.http_open() on another. These methods just record their arguments and
return None. Using a tuple instead of a string causes the method to
perform some action (see MockHandler.handle()), eg:
[["http_error"], [("http_open", "return request")]]
defines .http_error() on one handler (which simply returns None), and
.http_open() on another handler, which returns a Request object.
"""
handlers = []
count = 0
for meths in meth_spec:
class MockHandlerSubclass(MockHandler): pass
h = MockHandlerSubclass(meths)
h.handler_order += count
h.add_parent(opener)
count = count + 1
handlers.append(h)
opener.add_handler(h)
return handlers
def build_test_opener(*handler_instances):
opener = OpenerDirector()
for h in handler_instances:
opener.add_handler(h)
return opener
class MockHTTPHandler(urllib2.BaseHandler):
# useful for testing redirections and auth
# sends supplied headers and code as first response
# sends 200 OK as second response
def __init__(self, code, headers):
self.code = code
self.headers = headers
self.reset()
def reset(self):
self._count = 0
self.requests = []
def http_open(self, req):
import mimetools, httplib, copy
from StringIO import StringIO
self.requests.append(copy.deepcopy(req))
if self._count == 0:
self._count = self._count + 1
name = httplib.responses[self.code]
msg = mimetools.Message(StringIO(self.headers))
return self.parent.error(
"http", req, MockFile(), self.code, name, msg)
else:
self.req = req
msg = mimetools.Message(StringIO("\r\n\r\n"))
return MockResponse(200, "OK", msg, "", req.get_full_url())
class MockHTTPSHandler(urllib2.AbstractHTTPHandler):
# Useful for testing the Proxy-Authorization request by verifying the
# properties of httpcon
def __init__(self):
urllib2.AbstractHTTPHandler.__init__(self)
self.httpconn = MockHTTPClass()
def https_open(self, req):
return self.do_open(self.httpconn, req)
class MockPasswordManager:
def add_password(self, realm, uri, user, password):
self.realm = realm
self.url = uri
self.user = user
self.password = password
def find_user_password(self, realm, authuri):
self.target_realm = realm
self.target_url = authuri
return self.user, self.password
class OpenerDirectorTests(unittest.TestCase):
def test_add_non_handler(self):
class NonHandler(object):
pass
self.assertRaises(TypeError,
OpenerDirector().add_handler, NonHandler())
def test_badly_named_methods(self):
# test work-around for three methods that accidentally follow the
# naming conventions for handler methods
# (*_open() / *_request() / *_response())
# These used to call the accidentally-named methods, causing a
# TypeError in real code; here, returning self from these mock
# methods would either cause no exception, or AttributeError.
from urllib2 import URLError
o = OpenerDirector()
meth_spec = [
[("do_open", "return self"), ("proxy_open", "return self")],
[("redirect_request", "return self")],
]
handlers = add_ordered_mock_handlers(o, meth_spec)
o.add_handler(urllib2.UnknownHandler())
for scheme in "do", "proxy", "redirect":
self.assertRaises(URLError, o.open, scheme+"://example.com/")
def test_handled(self):
# handler returning non-None means no more handlers will be called
o = OpenerDirector()
meth_spec = [
["http_open", "ftp_open", "http_error_302"],
["ftp_open"],
[("http_open", "return self")],
[("http_open", "return self")],
]
handlers = add_ordered_mock_handlers(o, meth_spec)
req = Request("http://example.com/")
r = o.open(req)
# Second .http_open() gets called, third doesn't, since second returned
# non-None. Handlers without .http_open() never get any methods called
# on them.
# In fact, second mock handler defining .http_open() returns self
# (instead of response), which becomes the OpenerDirector's return
# value.
self.assertEqual(r, handlers[2])
calls = [(handlers[0], "http_open"), (handlers[2], "http_open")]
for expected, got in zip(calls, o.calls):
handler, name, args, kwds = got
self.assertEqual((handler, name), expected)
self.assertEqual(args, (req,))
def test_handler_order(self):
o = OpenerDirector()
handlers = []
for meths, handler_order in [
([("http_open", "return self")], 500),
(["http_open"], 0),
]:
class MockHandlerSubclass(MockHandler): pass
h = MockHandlerSubclass(meths)
h.handler_order = handler_order
handlers.append(h)
o.add_handler(h)
r = o.open("http://example.com/")
# handlers called in reverse order, thanks to their sort order
self.assertEqual(o.calls[0][0], handlers[1])
self.assertEqual(o.calls[1][0], handlers[0])
def test_raise(self):
# raising URLError stops processing of request
o = OpenerDirector()
meth_spec = [
[("http_open", "raise")],
[("http_open", "return self")],
]
handlers = add_ordered_mock_handlers(o, meth_spec)
req = Request("http://example.com/")
self.assertRaises(urllib2.URLError, o.open, req)
self.assertEqual(o.calls, [(handlers[0], "http_open", (req,), {})])
## def test_error(self):
## # XXX this doesn't actually seem to be used in standard library,
## # but should really be tested anyway...
def test_http_error(self):
# XXX http_error_default
# http errors are a special case
o = OpenerDirector()
meth_spec = [
[("http_open", "error 302")],
[("http_error_400", "raise"), "http_open"],
[("http_error_302", "return response"), "http_error_303",
"http_error"],
[("http_error_302")],
]
handlers = add_ordered_mock_handlers(o, meth_spec)
class Unknown:
def __eq__(self, other): return True
req = Request("http://example.com/")
r = o.open(req)
assert len(o.calls) == 2
calls = [(handlers[0], "http_open", (req,)),
(handlers[2], "http_error_302",
(req, Unknown(), 302, "", {}))]
for expected, got in zip(calls, o.calls):
handler, method_name, args = expected
self.assertEqual((handler, method_name), got[:2])
self.assertEqual(args, got[2])
def test_processors(self):
# *_request / *_response methods get called appropriately
o = OpenerDirector()
meth_spec = [
[("http_request", "return request"),
("http_response", "return response")],
[("http_request", "return request"),
("http_response", "return response")],
]
handlers = add_ordered_mock_handlers(o, meth_spec)
req = Request("http://example.com/")
r = o.open(req)
# processor methods are called on *all* handlers that define them,
# not just the first handler that handles the request
calls = [
(handlers[0], "http_request"), (handlers[1], "http_request"),
(handlers[0], "http_response"), (handlers[1], "http_response")]
for i, (handler, name, args, kwds) in enumerate(o.calls):
if i < 2:
# *_request
self.assertEqual((handler, name), calls[i])
self.assertEqual(len(args), 1)
self.assertIsInstance(args[0], Request)
else:
# *_response
self.assertEqual((handler, name), calls[i])
self.assertEqual(len(args), 2)
self.assertIsInstance(args[0], Request)
# response from opener.open is None, because there's no
# handler that defines http_open to handle it
self.assertTrue(args[1] is None or
isinstance(args[1], MockResponse))
def sanepathname2url(path):
import urllib
urlpath = urllib.pathname2url(path)
if os.name == "nt" and urlpath.startswith("///"):
urlpath = urlpath[2:]
# XXX don't ask me about the mac...
return urlpath
class HandlerTests(unittest.TestCase):
def test_ftp(self):
class MockFTPWrapper:
def __init__(self, data): self.data = data
def retrfile(self, filename, filetype):
self.filename, self.filetype = filename, filetype
return StringIO.StringIO(self.data), len(self.data)
def close(self): pass
class NullFTPHandler(urllib2.FTPHandler):
def __init__(self, data): self.data = data
def connect_ftp(self, user, passwd, host, port, dirs,
timeout=socket._GLOBAL_DEFAULT_TIMEOUT):
self.user, self.passwd = user, passwd
self.host, self.port = host, port
self.dirs = dirs
self.ftpwrapper = MockFTPWrapper(self.data)
return self.ftpwrapper
import ftplib
data = "rheum rhaponicum"
h = NullFTPHandler(data)
o = h.parent = MockOpener()
for url, host, port, user, passwd, type_, dirs, filename, mimetype in [
("ftp://localhost/foo/bar/baz.html",
"localhost", ftplib.FTP_PORT, "", "", "I",
["foo", "bar"], "baz.html", "text/html"),
("ftp://parrot@localhost/foo/bar/baz.html",
"localhost", ftplib.FTP_PORT, "parrot", "", "I",
["foo", "bar"], "baz.html", "text/html"),
("ftp://%25parrot@localhost/foo/bar/baz.html",
"localhost", ftplib.FTP_PORT, "%parrot", "", "I",
["foo", "bar"], "baz.html", "text/html"),
("ftp://%2542parrot@localhost/foo/bar/baz.html",
"localhost", ftplib.FTP_PORT, "%42parrot", "", "I",
["foo", "bar"], "baz.html", "text/html"),
("ftp://localhost:80/foo/bar/",
"localhost", 80, "", "", "D",
["foo", "bar"], "", None),
("ftp://localhost/baz.gif;type=a",
"localhost", ftplib.FTP_PORT, "", "", "A",
[], "baz.gif", None), # XXX really this should guess image/gif
]:
req = Request(url)
req.timeout = None
r = h.ftp_open(req)
# ftp authentication not yet implemented by FTPHandler
self.assertEqual(h.user, user)
self.assertEqual(h.passwd, passwd)
self.assertEqual(h.host, socket.gethostbyname(host))
self.assertEqual(h.port, port)
self.assertEqual(h.dirs, dirs)
self.assertEqual(h.ftpwrapper.filename, filename)
self.assertEqual(h.ftpwrapper.filetype, type_)
headers = r.info()
self.assertEqual(headers.get("Content-type"), mimetype)
self.assertEqual(int(headers["Content-length"]), len(data))
def test_file(self):
import rfc822, socket
h = urllib2.FileHandler()
o = h.parent = MockOpener()
TESTFN = test_support.TESTFN
urlpath = sanepathname2url(os.path.abspath(TESTFN))
towrite = "hello, world\n"
urls = [
"file://localhost%s" % urlpath,
"file://%s" % urlpath,
"file://%s%s" % (socket.gethostbyname('localhost'), urlpath),
]
try:
localaddr = socket.gethostbyname(socket.gethostname())
except socket.gaierror:
localaddr = ''
if localaddr:
urls.append("file://%s%s" % (localaddr, urlpath))
for url in urls:
f = open(TESTFN, "wb")
try:
try:
f.write(towrite)
finally:
f.close()
r = h.file_open(Request(url))
try:
data = r.read()
headers = r.info()
respurl = r.geturl()
finally:
r.close()
stats = os.stat(TESTFN)
modified = rfc822.formatdate(stats.st_mtime)
finally:
os.remove(TESTFN)
self.assertEqual(data, towrite)
self.assertEqual(headers["Content-type"], "text/plain")
self.assertEqual(headers["Content-length"], "13")
self.assertEqual(headers["Last-modified"], modified)
self.assertEqual(respurl, url)
for url in [
"file://localhost:80%s" % urlpath,
"file:///file_does_not_exist.txt",
"file://%s:80%s/%s" % (socket.gethostbyname('localhost'),
os.getcwd(), TESTFN),
"file://somerandomhost.ontheinternet.com%s/%s" %
(os.getcwd(), TESTFN),
]:
try:
f = open(TESTFN, "wb")
try:
f.write(towrite)
finally:
f.close()
self.assertRaises(urllib2.URLError,
h.file_open, Request(url))
finally:
os.remove(TESTFN)
h = urllib2.FileHandler()
o = h.parent = MockOpener()
# XXXX why does // mean ftp (and /// mean not ftp!), and where
# is file: scheme specified? I think this is really a bug, and
# what was intended was to distinguish between URLs like:
# file:/blah.txt (a file)
# file://localhost/blah.txt (a file)
# file:///blah.txt (a file)
# file://ftp.example.com/blah.txt (an ftp URL)
for url, ftp in [
("file://ftp.example.com//foo.txt", True),
("file://ftp.example.com///foo.txt", False),
# XXXX bug: fails with OSError, should be URLError
("file://ftp.example.com/foo.txt", False),
("file://somehost//foo/something.txt", True),
("file://localhost//foo/something.txt", False),
]:
req = Request(url)
try:
h.file_open(req)
# XXXX remove OSError when bug fixed
except (urllib2.URLError, OSError):
self.assertTrue(not ftp)
else:
self.assertTrue(o.req is req)
self.assertEqual(req.type, "ftp")
self.assertEqual(req.type == "ftp", ftp)
def test_http(self):
h = urllib2.AbstractHTTPHandler()
o = h.parent = MockOpener()
url = "http://example.com/"
for method, data in [("GET", None), ("POST", "blah")]:
req = Request(url, data, {"Foo": "bar"})
req.timeout = None
req.add_unredirected_header("Spam", "eggs")
http = MockHTTPClass()
r = h.do_open(http, req)
# result attributes
r.read; r.readline # wrapped MockFile methods
r.info; r.geturl # addinfourl methods
r.code, r.msg == 200, "OK" # added from MockHTTPClass.getreply()
hdrs = r.info()
hdrs.get; hdrs.has_key # r.info() gives dict from .getreply()
self.assertEqual(r.geturl(), url)
self.assertEqual(http.host, "example.com")
self.assertEqual(http.level, 0)
self.assertEqual(http.method, method)
self.assertEqual(http.selector, "/")
self.assertEqual(http.req_headers,
[("Connection", "close"),
("Foo", "bar"), ("Spam", "eggs")])
self.assertEqual(http.data, data)
# check socket.error converted to URLError
http.raise_on_endheaders = True
self.assertRaises(urllib2.URLError, h.do_open, http, req)
# check adding of standard headers
o.addheaders = [("Spam", "eggs")]
for data in "", None: # POST, GET
req = Request("http://example.com/", data)
r = MockResponse(200, "OK", {}, "")
newreq = h.do_request_(req)
if data is None: # GET
self.assertNotIn("Content-length", req.unredirected_hdrs)
self.assertNotIn("Content-type", req.unredirected_hdrs)
else: # POST
self.assertEqual(req.unredirected_hdrs["Content-length"], "0")
self.assertEqual(req.unredirected_hdrs["Content-type"],
"application/x-www-form-urlencoded")
# XXX the details of Host could be better tested
self.assertEqual(req.unredirected_hdrs["Host"], "example.com")
self.assertEqual(req.unredirected_hdrs["Spam"], "eggs")
# don't clobber existing headers
req.add_unredirected_header("Content-length", "foo")
req.add_unredirected_header("Content-type", "bar")
req.add_unredirected_header("Host", "baz")
req.add_unredirected_header("Spam", "foo")
newreq = h.do_request_(req)
self.assertEqual(req.unredirected_hdrs["Content-length"], "foo")
self.assertEqual(req.unredirected_hdrs["Content-type"], "bar")
self.assertEqual(req.unredirected_hdrs["Host"], "baz")
self.assertEqual(req.unredirected_hdrs["Spam"], "foo")
def test_http_doubleslash(self):
# Checks that the presence of an unnecessary double slash in a url doesn't break anything
# Previously, a double slash directly after the host could cause incorrect parsing of the url
h = urllib2.AbstractHTTPHandler()
o = h.parent = MockOpener()
data = ""
ds_urls = [
"http://example.com/foo/bar/baz.html",
"http://example.com//foo/bar/baz.html",
"http://example.com/foo//bar/baz.html",
"http://example.com/foo/bar//baz.html",
]
for ds_url in ds_urls:
ds_req = Request(ds_url, data)
# Check whether host is determined correctly if there is no proxy
np_ds_req = h.do_request_(ds_req)
self.assertEqual(np_ds_req.unredirected_hdrs["Host"],"example.com")
# Check whether host is determined correctly if there is a proxy
ds_req.set_proxy("someproxy:3128",None)
p_ds_req = h.do_request_(ds_req)
self.assertEqual(p_ds_req.unredirected_hdrs["Host"],"example.com")
def test_fixpath_in_weirdurls(self):
# Issue4493: urllib2 to supply '/' when to urls where path does not
# start with'/'
h = urllib2.AbstractHTTPHandler()
o = h.parent = MockOpener()
weird_url = 'http://www.python.org?getspam'
req = Request(weird_url)
newreq = h.do_request_(req)
self.assertEqual(newreq.get_host(),'www.python.org')
self.assertEqual(newreq.get_selector(),'/?getspam')
url_without_path = 'http://www.python.org'
req = Request(url_without_path)
newreq = h.do_request_(req)
self.assertEqual(newreq.get_host(),'www.python.org')
self.assertEqual(newreq.get_selector(),'')
def test_errors(self):
h = urllib2.HTTPErrorProcessor()
o = h.parent = MockOpener()
url = "http://example.com/"
req = Request(url)
# all 2xx are passed through
r = MockResponse(200, "OK", {}, "", url)
newr = h.http_response(req, r)
self.assertTrue(r is newr)
self.assertTrue(not hasattr(o, "proto")) # o.error not called
r = MockResponse(202, "Accepted", {}, "", url)
newr = h.http_response(req, r)
self.assertTrue(r is newr)
self.assertTrue(not hasattr(o, "proto")) # o.error not called
r = MockResponse(206, "Partial content", {}, "", url)
newr = h.http_response(req, r)
self.assertTrue(r is newr)
self.assertTrue(not hasattr(o, "proto")) # o.error not called
# anything else calls o.error (and MockOpener returns None, here)
r = MockResponse(502, "Bad gateway", {}, "", url)
self.assertTrue(h.http_response(req, r) is None)
self.assertEqual(o.proto, "http") # o.error called
self.assertEqual(o.args, (req, r, 502, "Bad gateway", {}))
def test_cookies(self):
cj = MockCookieJar()
h = urllib2.HTTPCookieProcessor(cj)
o = h.parent = MockOpener()
req = Request("http://example.com/")
r = MockResponse(200, "OK", {}, "")
newreq = h.http_request(req)
self.assertTrue(cj.ach_req is req is newreq)
self.assertEqual(req.get_origin_req_host(), "example.com")
self.assertTrue(not req.is_unverifiable())
newr = h.http_response(req, r)
self.assertTrue(cj.ec_req is req)
self.assertTrue(cj.ec_r is r is newr)
def test_redirect(self):
from_url = "http://example.com/a.html"
to_url = "http://example.com/b.html"
h = urllib2.HTTPRedirectHandler()
o = h.parent = MockOpener()
# ordinary redirect behaviour
for code in 301, 302, 303, 307:
for data in None, "blah\nblah\n":
method = getattr(h, "http_error_%s" % code)
req = Request(from_url, data)
req.add_header("Nonsense", "viking=withhold")
req.timeout = socket._GLOBAL_DEFAULT_TIMEOUT
if data is not None:
req.add_header("Content-Length", str(len(data)))
req.add_unredirected_header("Spam", "spam")
try:
method(req, MockFile(), code, "Blah",
MockHeaders({"location": to_url}))
except urllib2.HTTPError:
# 307 in response to POST requires user OK
self.assertTrue(code == 307 and data is not None)
self.assertEqual(o.req.get_full_url(), to_url)
try:
self.assertEqual(o.req.get_method(), "GET")
except AttributeError:
self.assertTrue(not o.req.has_data())
# now it's a GET, there should not be headers regarding content
# (possibly dragged from before being a POST)
headers = [x.lower() for x in o.req.headers]
self.assertNotIn("content-length", headers)
self.assertNotIn("content-type", headers)
self.assertEqual(o.req.headers["Nonsense"],
"viking=withhold")
self.assertNotIn("Spam", o.req.headers)
self.assertNotIn("Spam", o.req.unredirected_hdrs)
# loop detection
req = Request(from_url)
req.timeout = socket._GLOBAL_DEFAULT_TIMEOUT
def redirect(h, req, url=to_url):
h.http_error_302(req, MockFile(), 302, "Blah",
MockHeaders({"location": url}))
# Note that the *original* request shares the same record of
# redirections with the sub-requests caused by the redirections.
# detect infinite loop redirect of a URL to itself
req = Request(from_url, origin_req_host="example.com")
count = 0
req.timeout = socket._GLOBAL_DEFAULT_TIMEOUT
try:
while 1:
redirect(h, req, "http://example.com/")
count = count + 1
except urllib2.HTTPError:
# don't stop until max_repeats, because cookies may introduce state
self.assertEqual(count, urllib2.HTTPRedirectHandler.max_repeats)
# detect endless non-repeating chain of redirects
req = Request(from_url, origin_req_host="example.com")
count = 0
req.timeout = socket._GLOBAL_DEFAULT_TIMEOUT
try:
while 1:
redirect(h, req, "http://example.com/%d" % count)
count = count + 1
except urllib2.HTTPError:
self.assertEqual(count,
urllib2.HTTPRedirectHandler.max_redirections)
def test_invalid_redirect(self):
from_url = "http://example.com/a.html"
valid_schemes = ['http', 'https', 'ftp']
invalid_schemes = ['file', 'imap', 'ldap']
schemeless_url = "example.com/b.html"
h = urllib2.HTTPRedirectHandler()
o = h.parent = MockOpener()
req = Request(from_url)
req.timeout = socket._GLOBAL_DEFAULT_TIMEOUT
for scheme in invalid_schemes:
invalid_url = scheme + '://' + schemeless_url
self.assertRaises(urllib2.HTTPError, h.http_error_302,
req, MockFile(), 302, "Security Loophole",
MockHeaders({"location": invalid_url}))
for scheme in valid_schemes:
valid_url = scheme + '://' + schemeless_url
h.http_error_302(req, MockFile(), 302, "That's fine",
MockHeaders({"location": valid_url}))
self.assertEqual(o.req.get_full_url(), valid_url)
def test_cookie_redirect(self):
# cookies shouldn't leak into redirected requests
from cookielib import CookieJar
from test.test_cookielib import interact_netscape
cj = CookieJar()
interact_netscape(cj, "http://www.example.com/", "spam=eggs")
hh = MockHTTPHandler(302, "Location: http://www.cracker.com/\r\n\r\n")
hdeh = urllib2.HTTPDefaultErrorHandler()
hrh = urllib2.HTTPRedirectHandler()
cp = urllib2.HTTPCookieProcessor(cj)
o = build_test_opener(hh, hdeh, hrh, cp)
o.open("http://www.example.com/")
self.assertTrue(not hh.req.has_header("Cookie"))
def test_redirect_fragment(self):
redirected_url = 'http://www.example.com/index.html#OK\r\n\r\n'
hh = MockHTTPHandler(302, 'Location: ' + redirected_url)
hdeh = urllib2.HTTPDefaultErrorHandler()
hrh = urllib2.HTTPRedirectHandler()
o = build_test_opener(hh, hdeh, hrh)
fp = o.open('http://www.example.com')
self.assertEqual(fp.geturl(), redirected_url.strip())
def test_proxy(self):
o = OpenerDirector()
ph = urllib2.ProxyHandler(dict(http="proxy.example.com:3128"))
o.add_handler(ph)
meth_spec = [
[("http_open", "return response")]
]
handlers = add_ordered_mock_handlers(o, meth_spec)
req = Request("http://acme.example.com/")
self.assertEqual(req.get_host(), "acme.example.com")
r = o.open(req)
self.assertEqual(req.get_host(), "proxy.example.com:3128")
self.assertEqual([(handlers[0], "http_open")],
[tup[0:2] for tup in o.calls])
def test_proxy_no_proxy(self):
os.environ['no_proxy'] = 'python.org'
o = OpenerDirector()
ph = urllib2.ProxyHandler(dict(http="proxy.example.com"))
o.add_handler(ph)
req = Request("http://www.perl.org/")
self.assertEqual(req.get_host(), "www.perl.org")
r = o.open(req)
self.assertEqual(req.get_host(), "proxy.example.com")
req = Request("http://www.python.org")
self.assertEqual(req.get_host(), "www.python.org")
r = o.open(req)
self.assertEqual(req.get_host(), "www.python.org")
del os.environ['no_proxy']
def test_proxy_https(self):
o = OpenerDirector()
ph = urllib2.ProxyHandler(dict(https='proxy.example.com:3128'))
o.add_handler(ph)
meth_spec = [
[("https_open","return response")]
]
handlers = add_ordered_mock_handlers(o, meth_spec)
req = Request("https://www.example.com/")
self.assertEqual(req.get_host(), "www.example.com")
r = o.open(req)
self.assertEqual(req.get_host(), "proxy.example.com:3128")
self.assertEqual([(handlers[0], "https_open")],
[tup[0:2] for tup in o.calls])
def test_proxy_https_proxy_authorization(self):
o = OpenerDirector()
ph = urllib2.ProxyHandler(dict(https='proxy.example.com:3128'))
o.add_handler(ph)
https_handler = MockHTTPSHandler()
o.add_handler(https_handler)
req = Request("https://www.example.com/")
req.add_header("Proxy-Authorization","FooBar")
req.add_header("User-Agent","Grail")
self.assertEqual(req.get_host(), "www.example.com")
self.assertIsNone(req._tunnel_host)
r = o.open(req)
# Verify Proxy-Authorization gets tunneled to request.
# httpsconn req_headers do not have the Proxy-Authorization header but
# the req will have.
self.assertNotIn(("Proxy-Authorization","FooBar"),
https_handler.httpconn.req_headers)
self.assertIn(("User-Agent","Grail"),
https_handler.httpconn.req_headers)
self.assertIsNotNone(req._tunnel_host)
self.assertEqual(req.get_host(), "proxy.example.com:3128")
self.assertEqual(req.get_header("Proxy-authorization"),"FooBar")
def test_basic_auth(self, quote_char='"'):
opener = OpenerDirector()
password_manager = MockPasswordManager()
auth_handler = urllib2.HTTPBasicAuthHandler(password_manager)
realm = "ACME Widget Store"
http_handler = MockHTTPHandler(
401, 'WWW-Authenticate: Basic realm=%s%s%s\r\n\r\n' %
(quote_char, realm, quote_char) )
opener.add_handler(auth_handler)
opener.add_handler(http_handler)
self._test_basic_auth(opener, auth_handler, "Authorization",
realm, http_handler, password_manager,
"http://acme.example.com/protected",
"http://acme.example.com/protected",
)
def test_basic_auth_with_single_quoted_realm(self):
self.test_basic_auth(quote_char="'")
def test_proxy_basic_auth(self):
opener = OpenerDirector()
ph = urllib2.ProxyHandler(dict(http="proxy.example.com:3128"))
opener.add_handler(ph)
password_manager = MockPasswordManager()
auth_handler = urllib2.ProxyBasicAuthHandler(password_manager)
realm = "ACME Networks"
http_handler = MockHTTPHandler(
407, 'Proxy-Authenticate: Basic realm="%s"\r\n\r\n' % realm)
opener.add_handler(auth_handler)
opener.add_handler(http_handler)
self._test_basic_auth(opener, auth_handler, "Proxy-authorization",
realm, http_handler, password_manager,
"http://acme.example.com:3128/protected",
"proxy.example.com:3128",
)
def test_basic_and_digest_auth_handlers(self):
# HTTPDigestAuthHandler threw an exception if it couldn't handle a 40*
# response (http://python.org/sf/1479302), where it should instead
# return None to allow another handler (especially
# HTTPBasicAuthHandler) to handle the response.
# Also (http://python.org/sf/14797027, RFC 2617 section 1.2), we must
# try digest first (since it's the strongest auth scheme), so we record
# order of calls here to check digest comes first:
class RecordingOpenerDirector(OpenerDirector):
def __init__(self):
OpenerDirector.__init__(self)
self.recorded = []
def record(self, info):
self.recorded.append(info)
class TestDigestAuthHandler(urllib2.HTTPDigestAuthHandler):
def http_error_401(self, *args, **kwds):
self.parent.record("digest")
urllib2.HTTPDigestAuthHandler.http_error_401(self,
*args, **kwds)
class TestBasicAuthHandler(urllib2.HTTPBasicAuthHandler):
def http_error_401(self, *args, **kwds):
self.parent.record("basic")
urllib2.HTTPBasicAuthHandler.http_error_401(self,
*args, **kwds)
opener = RecordingOpenerDirector()
password_manager = MockPasswordManager()
digest_handler = TestDigestAuthHandler(password_manager)
basic_handler = TestBasicAuthHandler(password_manager)
realm = "ACME Networks"
http_handler = MockHTTPHandler(
401, 'WWW-Authenticate: Basic realm="%s"\r\n\r\n' % realm)
opener.add_handler(basic_handler)
opener.add_handler(digest_handler)
opener.add_handler(http_handler)
# check basic auth isn't blocked by digest handler failing
self._test_basic_auth(opener, basic_handler, "Authorization",
realm, http_handler, password_manager,
"http://acme.example.com/protected",
"http://acme.example.com/protected",
)
# check digest was tried before basic (twice, because
# _test_basic_auth called .open() twice)
self.assertEqual(opener.recorded, ["digest", "basic"]*2)
def _test_basic_auth(self, opener, auth_handler, auth_header,
realm, http_handler, password_manager,
request_url, protected_url):
import base64
user, password = "wile", "coyote"
# .add_password() fed through to password manager
auth_handler.add_password(realm, request_url, user, password)
self.assertEqual(realm, password_manager.realm)
self.assertEqual(request_url, password_manager.url)
self.assertEqual(user, password_manager.user)
self.assertEqual(password, password_manager.password)
r = opener.open(request_url)
# should have asked the password manager for the username/password
self.assertEqual(password_manager.target_realm, realm)
self.assertEqual(password_manager.target_url, protected_url)
# expect one request without authorization, then one with
self.assertEqual(len(http_handler.requests), 2)
self.assertFalse(http_handler.requests[0].has_header(auth_header))
userpass = '%s:%s' % (user, password)
auth_hdr_value = 'Basic '+base64.encodestring(userpass).strip()
self.assertEqual(http_handler.requests[1].get_header(auth_header),
auth_hdr_value)
self.assertEqual(http_handler.requests[1].unredirected_hdrs[auth_header],
auth_hdr_value)
# if the password manager can't find a password, the handler won't
# handle the HTTP auth error
password_manager.user = password_manager.password = None
http_handler.reset()
r = opener.open(request_url)
self.assertEqual(len(http_handler.requests), 1)
self.assertFalse(http_handler.requests[0].has_header(auth_header))
class MiscTests(unittest.TestCase):
def test_build_opener(self):
class MyHTTPHandler(urllib2.HTTPHandler): pass
class FooHandler(urllib2.BaseHandler):
def foo_open(self): pass
class BarHandler(urllib2.BaseHandler):
def bar_open(self): pass
build_opener = urllib2.build_opener
o = build_opener(FooHandler, BarHandler)
self.opener_has_handler(o, FooHandler)
self.opener_has_handler(o, BarHandler)
# can take a mix of classes and instances
o = build_opener(FooHandler, BarHandler())
self.opener_has_handler(o, FooHandler)
self.opener_has_handler(o, BarHandler)
# subclasses of default handlers override default handlers
o = build_opener(MyHTTPHandler)
self.opener_has_handler(o, MyHTTPHandler)
# a particular case of overriding: default handlers can be passed
# in explicitly
o = build_opener()
self.opener_has_handler(o, urllib2.HTTPHandler)
o = build_opener(urllib2.HTTPHandler)
self.opener_has_handler(o, urllib2.HTTPHandler)
o = build_opener(urllib2.HTTPHandler())
self.opener_has_handler(o, urllib2.HTTPHandler)
# Issue2670: multiple handlers sharing the same base class
class MyOtherHTTPHandler(urllib2.HTTPHandler): pass
o = build_opener(MyHTTPHandler, MyOtherHTTPHandler)
self.opener_has_handler(o, MyHTTPHandler)
self.opener_has_handler(o, MyOtherHTTPHandler)
def opener_has_handler(self, opener, handler_class):
for h in opener.handlers:
if h.__class__ == handler_class:
break
else:
self.assertTrue(False)
class RequestTests(unittest.TestCase):
def setUp(self):
self.get = urllib2.Request("http://www.python.org/~jeremy/")
self.post = urllib2.Request("http://www.python.org/~jeremy/",
"data",
headers={"X-Test": "test"})
def test_method(self):
self.assertEqual("POST", self.post.get_method())
self.assertEqual("GET", self.get.get_method())
def test_add_data(self):
self.assertTrue(not self.get.has_data())
self.assertEqual("GET", self.get.get_method())
self.get.add_data("spam")
self.assertTrue(self.get.has_data())
self.assertEqual("POST", self.get.get_method())
def test_get_full_url(self):
self.assertEqual("http://www.python.org/~jeremy/",
self.get.get_full_url())
def test_selector(self):
self.assertEqual("/~jeremy/", self.get.get_selector())
req = urllib2.Request("http://www.python.org/")
self.assertEqual("/", req.get_selector())
def test_get_type(self):
self.assertEqual("http", self.get.get_type())
def test_get_host(self):
self.assertEqual("www.python.org", self.get.get_host())
def test_get_host_unquote(self):
req = urllib2.Request("http://www.%70ython.org/")
self.assertEqual("www.python.org", req.get_host())
def test_proxy(self):
self.assertTrue(not self.get.has_proxy())
self.get.set_proxy("www.perl.org", "http")
self.assertTrue(self.get.has_proxy())
self.assertEqual("www.python.org", self.get.get_origin_req_host())
self.assertEqual("www.perl.org", self.get.get_host())
def test_wrapped_url(self):
req = Request("<URL:http://www.python.org>")
self.assertEqual("www.python.org", req.get_host())
def test_url_fragment(self):
req = Request("http://www.python.org/?qs=query#fragment=true")
self.assertEqual("/?qs=query", req.get_selector())
req = Request("http://www.python.org/#fun=true")
self.assertEqual("/", req.get_selector())
# Issue 11703: geturl() omits fragment in the original URL.
url = 'http://docs.python.org/library/urllib2.html#OK'
req = Request(url)
self.assertEqual(req.get_full_url(), url)
def test_HTTPError_interface():
"""
Issue 13211 reveals that HTTPError didn't implement the URLError
interface even though HTTPError is a subclass of URLError.
>>> err = urllib2.HTTPError(msg='something bad happened', url=None, code=None, hdrs=None, fp=None)
>>> assert hasattr(err, 'reason')
>>> err.reason
'something bad happened'
"""
def test_main(verbose=None):
from test import test_urllib2
test_support.run_doctest(test_urllib2, verbose)
test_support.run_doctest(urllib2, verbose)
tests = (TrivialTests,
OpenerDirectorTests,
HandlerTests,
MiscTests,
RequestTests)
test_support.run_unittest(*tests)
if __name__ == "__main__":
test_main(verbose=True)
#!/usr/bin/env python
import urlparse
import urllib2
import BaseHTTPServer
import unittest
import hashlib
from test import test_support
mimetools = test_support.import_module('mimetools', deprecated=True)
threading = test_support.import_module('threading')
# Loopback http server infrastructure
class LoopbackHttpServer(BaseHTTPServer.HTTPServer):
"""HTTP server w/ a few modifications that make it useful for
loopback testing purposes.
"""
def __init__(self, server_address, RequestHandlerClass):
BaseHTTPServer.HTTPServer.__init__(self,
server_address,
RequestHandlerClass)
# Set the timeout of our listening socket really low so
# that we can stop the server easily.
self.socket.settimeout(1.0)
def get_request(self):
"""BaseHTTPServer method, overridden."""
request, client_address = self.socket.accept()
# It's a loopback connection, so setting the timeout
# really low shouldn't affect anything, but should make
# deadlocks less likely to occur.
request.settimeout(10.0)
return (request, client_address)
class LoopbackHttpServerThread(threading.Thread):
"""Stoppable thread that runs a loopback http server."""
def __init__(self, request_handler):
threading.Thread.__init__(self)
self._stop = False
self.ready = threading.Event()
request_handler.protocol_version = "HTTP/1.0"
self.httpd = LoopbackHttpServer(('127.0.0.1', 0),
request_handler)
#print "Serving HTTP on %s port %s" % (self.httpd.server_name,
# self.httpd.server_port)
self.port = self.httpd.server_port
def stop(self):
"""Stops the webserver if it's currently running."""
# Set the stop flag.
self._stop = True
self.join()
def run(self):
self.ready.set()
while not self._stop:
self.httpd.handle_request()
# Authentication infrastructure
class DigestAuthHandler:
"""Handler for performing digest authentication."""
def __init__(self):
self._request_num = 0
self._nonces = []
self._users = {}
self._realm_name = "Test Realm"
self._qop = "auth"
def set_qop(self, qop):
self._qop = qop
def set_users(self, users):
assert isinstance(users, dict)
self._users = users
def set_realm(self, realm):
self._realm_name = realm
def _generate_nonce(self):
self._request_num += 1
nonce = hashlib.md5(str(self._request_num)).hexdigest()
self._nonces.append(nonce)
return nonce
def _create_auth_dict(self, auth_str):
first_space_index = auth_str.find(" ")
auth_str = auth_str[first_space_index+1:]
parts = auth_str.split(",")
auth_dict = {}
for part in parts:
name, value = part.split("=")
name = name.strip()
if value[0] == '"' and value[-1] == '"':
value = value[1:-1]
else:
value = value.strip()
auth_dict[name] = value
return auth_dict
def _validate_auth(self, auth_dict, password, method, uri):
final_dict = {}
final_dict.update(auth_dict)
final_dict["password"] = password
final_dict["method"] = method
final_dict["uri"] = uri
HA1_str = "%(username)s:%(realm)s:%(password)s" % final_dict
HA1 = hashlib.md5(HA1_str).hexdigest()
HA2_str = "%(method)s:%(uri)s" % final_dict
HA2 = hashlib.md5(HA2_str).hexdigest()
final_dict["HA1"] = HA1
final_dict["HA2"] = HA2
response_str = "%(HA1)s:%(nonce)s:%(nc)s:" \
"%(cnonce)s:%(qop)s:%(HA2)s" % final_dict
response = hashlib.md5(response_str).hexdigest()
return response == auth_dict["response"]
def _return_auth_challenge(self, request_handler):
request_handler.send_response(407, "Proxy Authentication Required")
request_handler.send_header("Content-Type", "text/html")
request_handler.send_header(
'Proxy-Authenticate', 'Digest realm="%s", '
'qop="%s",'
'nonce="%s", ' % \
(self._realm_name, self._qop, self._generate_nonce()))
# XXX: Not sure if we're supposed to add this next header or
# not.
#request_handler.send_header('Connection', 'close')
request_handler.end_headers()
request_handler.wfile.write("Proxy Authentication Required.")
return False
def handle_request(self, request_handler):
"""Performs digest authentication on the given HTTP request
handler. Returns True if authentication was successful, False
otherwise.
If no users have been set, then digest auth is effectively
disabled and this method will always return True.
"""
if len(self._users) == 0:
return True
if 'Proxy-Authorization' not in request_handler.headers:
return self._return_auth_challenge(request_handler)
else:
auth_dict = self._create_auth_dict(
request_handler.headers['Proxy-Authorization']
)
if auth_dict["username"] in self._users:
password = self._users[ auth_dict["username"] ]
else:
return self._return_auth_challenge(request_handler)
if not auth_dict.get("nonce") in self._nonces:
return self._return_auth_challenge(request_handler)
else:
self._nonces.remove(auth_dict["nonce"])
auth_validated = False
# MSIE uses short_path in its validation, but Python's
# urllib2 uses the full path, so we're going to see if
# either of them works here.
for path in [request_handler.path, request_handler.short_path]:
if self._validate_auth(auth_dict,
password,
request_handler.command,
path):
auth_validated = True
if not auth_validated:
return self._return_auth_challenge(request_handler)
return True
# Proxy test infrastructure
class FakeProxyHandler(BaseHTTPServer.BaseHTTPRequestHandler):
"""This is a 'fake proxy' that makes it look like the entire
internet has gone down due to a sudden zombie invasion. It main
utility is in providing us with authentication support for
testing.
"""
def __init__(self, digest_auth_handler, *args, **kwargs):
# This has to be set before calling our parent's __init__(), which will
# try to call do_GET().
self.digest_auth_handler = digest_auth_handler
BaseHTTPServer.BaseHTTPRequestHandler.__init__(self, *args, **kwargs)
def log_message(self, format, *args):
# Uncomment the next line for debugging.
#sys.stderr.write(format % args)
pass
def do_GET(self):
(scm, netloc, path, params, query, fragment) = urlparse.urlparse(
self.path, 'http')
self.short_path = path
if self.digest_auth_handler.handle_request(self):
self.send_response(200, "OK")
self.send_header("Content-Type", "text/html")
self.end_headers()
self.wfile.write("You've reached %s!<BR>" % self.path)
self.wfile.write("Our apologies, but our server is down due to "
"a sudden zombie invasion.")
# Test cases
class BaseTestCase(unittest.TestCase):
def setUp(self):
self._threads = test_support.threading_setup()
def tearDown(self):
test_support.threading_cleanup(*self._threads)
class ProxyAuthTests(BaseTestCase):
URL = "http://localhost"
USER = "tester"
PASSWD = "test123"
REALM = "TestRealm"
def setUp(self):
super(ProxyAuthTests, self).setUp()
self.digest_auth_handler = DigestAuthHandler()
self.digest_auth_handler.set_users({self.USER: self.PASSWD})
self.digest_auth_handler.set_realm(self.REALM)
def create_fake_proxy_handler(*args, **kwargs):
return FakeProxyHandler(self.digest_auth_handler, *args, **kwargs)
self.server = LoopbackHttpServerThread(create_fake_proxy_handler)
self.server.start()
self.server.ready.wait()
proxy_url = "http://127.0.0.1:%d" % self.server.port
handler = urllib2.ProxyHandler({"http" : proxy_url})
self.proxy_digest_handler = urllib2.ProxyDigestAuthHandler()
self.opener = urllib2.build_opener(handler, self.proxy_digest_handler)
def tearDown(self):
self.server.stop()
super(ProxyAuthTests, self).tearDown()
def test_proxy_with_bad_password_raises_httperror(self):
self.proxy_digest_handler.add_password(self.REALM, self.URL,
self.USER, self.PASSWD+"bad")
self.digest_auth_handler.set_qop("auth")
self.assertRaises(urllib2.HTTPError,
self.opener.open,
self.URL)
def test_proxy_with_no_password_raises_httperror(self):
self.digest_auth_handler.set_qop("auth")
self.assertRaises(urllib2.HTTPError,
self.opener.open,
self.URL)
def test_proxy_qop_auth_works(self):
self.proxy_digest_handler.add_password(self.REALM, self.URL,
self.USER, self.PASSWD)
self.digest_auth_handler.set_qop("auth")
result = self.opener.open(self.URL)
while result.read():
pass
result.close()
def test_proxy_qop_auth_int_works_or_throws_urlerror(self):
self.proxy_digest_handler.add_password(self.REALM, self.URL,
self.USER, self.PASSWD)
self.digest_auth_handler.set_qop("auth-int")
try:
result = self.opener.open(self.URL)
except urllib2.URLError:
# It's okay if we don't support auth-int, but we certainly
# shouldn't receive any kind of exception here other than
# a URLError.
result = None
if result:
while result.read():
pass
result.close()
def GetRequestHandler(responses):
class FakeHTTPRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
server_version = "TestHTTP/"
requests = []
headers_received = []
port = 80
def do_GET(self):
body = self.send_head()
if body:
self.wfile.write(body)
def do_POST(self):
content_length = self.headers['Content-Length']
post_data = self.rfile.read(int(content_length))
self.do_GET()
self.requests.append(post_data)
def send_head(self):
FakeHTTPRequestHandler.headers_received = self.headers
self.requests.append(self.path)
response_code, headers, body = responses.pop(0)
self.send_response(response_code)
for (header, value) in headers:
self.send_header(header, value % self.port)
if body:
self.send_header('Content-type', 'text/plain')
self.end_headers()
return body
self.end_headers()
def log_message(self, *args):
pass
return FakeHTTPRequestHandler
class TestUrlopen(BaseTestCase):
"""Tests urllib2.urlopen using the network.
These tests are not exhaustive. Assuming that testing using files does a
good job overall of some of the basic interface features. There are no
tests exercising the optional 'data' and 'proxies' arguments. No tests
for transparent redirection have been written.
"""
def start_server(self, responses):
handler = GetRequestHandler(responses)
self.server = LoopbackHttpServerThread(handler)
self.server.start()
self.server.ready.wait()
port = self.server.port
handler.port = port
return handler
def test_redirection(self):
expected_response = 'We got here...'
responses = [
(302, [('Location', 'http://localhost:%s/somewhere_else')], ''),
(200, [], expected_response)
]
handler = self.start_server(responses)
try:
f = urllib2.urlopen('http://localhost:%s/' % handler.port)
data = f.read()
f.close()
self.assertEqual(data, expected_response)
self.assertEqual(handler.requests, ['/', '/somewhere_else'])
finally:
self.server.stop()
def test_404(self):
expected_response = 'Bad bad bad...'
handler = self.start_server([(404, [], expected_response)])
try:
try:
urllib2.urlopen('http://localhost:%s/weeble' % handler.port)
except urllib2.URLError, f:
pass
else:
self.fail('404 should raise URLError')
data = f.read()
f.close()
self.assertEqual(data, expected_response)
self.assertEqual(handler.requests, ['/weeble'])
finally:
self.server.stop()
def test_200(self):
expected_response = 'pycon 2008...'
handler = self.start_server([(200, [], expected_response)])
try:
f = urllib2.urlopen('http://localhost:%s/bizarre' % handler.port)
data = f.read()
f.close()
self.assertEqual(data, expected_response)
self.assertEqual(handler.requests, ['/bizarre'])
finally:
self.server.stop()
def test_200_with_parameters(self):
expected_response = 'pycon 2008...'
handler = self.start_server([(200, [], expected_response)])
try:
f = urllib2.urlopen('http://localhost:%s/bizarre' % handler.port, 'get=with_feeling')
data = f.read()
f.close()
self.assertEqual(data, expected_response)
self.assertEqual(handler.requests, ['/bizarre', 'get=with_feeling'])
finally:
self.server.stop()
def test_sending_headers(self):
handler = self.start_server([(200, [], "we don't care")])
try:
req = urllib2.Request("http://localhost:%s/" % handler.port,
headers={'Range': 'bytes=20-39'})
urllib2.urlopen(req)
self.assertEqual(handler.headers_received['Range'], 'bytes=20-39')
finally:
self.server.stop()
def test_basic(self):
handler = self.start_server([(200, [], "we don't care")])
try:
open_url = urllib2.urlopen("http://localhost:%s" % handler.port)
for attr in ("read", "close", "info", "geturl"):
self.assertTrue(hasattr(open_url, attr), "object returned from "
"urlopen lacks the %s attribute" % attr)
try:
self.assertTrue(open_url.read(), "calling 'read' failed")
finally:
open_url.close()
finally:
self.server.stop()
def test_info(self):
handler = self.start_server([(200, [], "we don't care")])
try:
open_url = urllib2.urlopen("http://localhost:%s" % handler.port)
info_obj = open_url.info()
self.assertIsInstance(info_obj, mimetools.Message,
"object returned by 'info' is not an "
"instance of mimetools.Message")
self.assertEqual(info_obj.getsubtype(), "plain")
finally:
self.server.stop()
def test_geturl(self):
# Make sure same URL as opened is returned by geturl.
handler = self.start_server([(200, [], "we don't care")])
try:
open_url = urllib2.urlopen("http://localhost:%s" % handler.port)
url = open_url.geturl()
self.assertEqual(url, "http://localhost:%s" % handler.port)
finally:
self.server.stop()
def test_bad_address(self):
# Make sure proper exception is raised when connecting to a bogus
# address.
self.assertRaises(IOError,
# Given that both VeriSign and various ISPs have in
# the past or are presently hijacking various invalid
# domain name requests in an attempt to boost traffic
# to their own sites, finding a domain name to use
# for this test is difficult. RFC2606 leads one to
# believe that '.invalid' should work, but experience
# seemed to indicate otherwise. Single character
# TLDs are likely to remain invalid, so this seems to
# be the best choice. The trailing '.' prevents a
# related problem: The normal DNS resolver appends
# the domain names from the search path if there is
# no '.' the end and, and if one of those domains
# implements a '*' rule a result is returned.
# However, none of this will prevent the test from
# failing if the ISP hijacks all invalid domain
# requests. The real solution would be to be able to
# parameterize the framework with a mock resolver.
urllib2.urlopen, "http://sadflkjsasf.i.nvali.d./")
def test_iteration(self):
expected_response = "pycon 2008..."
handler = self.start_server([(200, [], expected_response)])
try:
data = urllib2.urlopen("http://localhost:%s" % handler.port)
for line in data:
self.assertEqual(line, expected_response)
finally:
self.server.stop()
def ztest_line_iteration(self):
lines = ["We\n", "got\n", "here\n", "verylong " * 8192 + "\n"]
expected_response = "".join(lines)
handler = self.start_server([(200, [], expected_response)])
try:
data = urllib2.urlopen("http://localhost:%s" % handler.port)
for index, line in enumerate(data):
self.assertEqual(line, lines[index],
"Fetched line number %s doesn't match expected:\n"
" Expected length was %s, got %s" %
(index, len(lines[index]), len(line)))
finally:
self.server.stop()
self.assertEqual(index + 1, len(lines))
def test_main():
# We will NOT depend on the network resource flag
# (Lib/test/regrtest.py -u network) since all tests here are only
# localhost. However, if this is a bad rationale, then uncomment
# the next line.
#test_support.requires("network")
test_support.run_unittest(ProxyAuthTests, TestUrlopen)
if __name__ == "__main__":
test_main()
#!/usr/bin/env python
import unittest
from test import test_support
from test.test_urllib2 import sanepathname2url
import socket
import urllib2
import os
import sys
TIMEOUT = 60 # seconds
def _retry_thrice(func, exc, *args, **kwargs):
for i in range(3):
try:
return func(*args, **kwargs)
except exc, last_exc:
continue
except:
raise
raise last_exc
def _wrap_with_retry_thrice(func, exc):
def wrapped(*args, **kwargs):
return _retry_thrice(func, exc, *args, **kwargs)
return wrapped
# Connecting to remote hosts is flaky. Make it more robust by retrying
# the connection several times.
_urlopen_with_retry = _wrap_with_retry_thrice(urllib2.urlopen, urllib2.URLError)
class AuthTests(unittest.TestCase):
"""Tests urllib2 authentication features."""
## Disabled at the moment since there is no page under python.org which
## could be used to HTTP authentication.
#
# def test_basic_auth(self):
# import httplib
#
# test_url = "http://www.python.org/test/test_urllib2/basic_auth"
# test_hostport = "www.python.org"
# test_realm = 'Test Realm'
# test_user = 'test.test_urllib2net'
# test_password = 'blah'
#
# # failure
# try:
# _urlopen_with_retry(test_url)
# except urllib2.HTTPError, exc:
# self.assertEqual(exc.code, 401)
# else:
# self.fail("urlopen() should have failed with 401")
#
# # success
# auth_handler = urllib2.HTTPBasicAuthHandler()
# auth_handler.add_password(test_realm, test_hostport,
# test_user, test_password)
# opener = urllib2.build_opener(auth_handler)
# f = opener.open('http://localhost/')
# response = _urlopen_with_retry("http://www.python.org/")
#
# # The 'userinfo' URL component is deprecated by RFC 3986 for security
# # reasons, let's not implement it! (it's already implemented for proxy
# # specification strings (that is, URLs or authorities specifying a
# # proxy), so we must keep that)
# self.assertRaises(httplib.InvalidURL,
# urllib2.urlopen, "http://evil:thing@example.com")
class CloseSocketTest(unittest.TestCase):
def test_close(self):
import httplib
# calling .close() on urllib2's response objects should close the
# underlying socket
# delve deep into response to fetch socket._socketobject
response = _urlopen_with_retry("http://www.python.org/")
abused_fileobject = response.fp
self.assertTrue(abused_fileobject.__class__ is socket._fileobject)
httpresponse = abused_fileobject._sock
self.assertTrue(httpresponse.__class__ is httplib.HTTPResponse)
fileobject = httpresponse.fp
self.assertTrue(fileobject.__class__ is socket._fileobject)
self.assertTrue(not fileobject.closed)
response.close()
self.assertTrue(fileobject.closed)
class OtherNetworkTests(unittest.TestCase):
def setUp(self):
if 0: # for debugging
import logging
logger = logging.getLogger("test_urllib2net")
logger.addHandler(logging.StreamHandler())
# XXX The rest of these tests aren't very good -- they don't check much.
# They do sometimes catch some major disasters, though.
def test_ftp(self):
urls = [
'ftp://ftp.kernel.org/pub/linux/kernel/README',
'ftp://ftp.kernel.org/pub/linux/kernel/non-existent-file',
#'ftp://ftp.kernel.org/pub/leenox/kernel/test',
'ftp://gatekeeper.research.compaq.com/pub/DEC/SRC'
'/research-reports/00README-Legal-Rules-Regs',
]
self._test_urls(urls, self._extra_handlers())
def test_file(self):
TESTFN = test_support.TESTFN
f = open(TESTFN, 'w')
try:
f.write('hi there\n')
f.close()
urls = [
'file:'+sanepathname2url(os.path.abspath(TESTFN)),
('file:///nonsensename/etc/passwd', None, urllib2.URLError),
]
self._test_urls(urls, self._extra_handlers(), retry=True)
finally:
os.remove(TESTFN)
self.assertRaises(ValueError, urllib2.urlopen,'./relative_path/to/file')
# XXX Following test depends on machine configurations that are internal
# to CNRI. Need to set up a public server with the right authentication
# configuration for test purposes.
## def test_cnri(self):
## if socket.gethostname() == 'bitdiddle':
## localhost = 'bitdiddle.cnri.reston.va.us'
## elif socket.gethostname() == 'bitdiddle.concentric.net':
## localhost = 'localhost'
## else:
## localhost = None
## if localhost is not None:
## urls = [
## 'file://%s/etc/passwd' % localhost,
## 'http://%s/simple/' % localhost,
## 'http://%s/digest/' % localhost,
## 'http://%s/not/found.h' % localhost,
## ]
## bauth = HTTPBasicAuthHandler()
## bauth.add_password('basic_test_realm', localhost, 'jhylton',
## 'password')
## dauth = HTTPDigestAuthHandler()
## dauth.add_password('digest_test_realm', localhost, 'jhylton',
## 'password')
## self._test_urls(urls, self._extra_handlers()+[bauth, dauth])
def test_urlwithfrag(self):
urlwith_frag = "http://docs.python.org/glossary.html#glossary"
with test_support.transient_internet(urlwith_frag):
req = urllib2.Request(urlwith_frag)
res = urllib2.urlopen(req)
self.assertEqual(res.geturl(),
"http://docs.python.org/glossary.html#glossary")
def test_fileno(self):
req = urllib2.Request("http://www.python.org")
opener = urllib2.build_opener()
res = opener.open(req)
try:
res.fileno()
except AttributeError:
self.fail("HTTPResponse object should return a valid fileno")
finally:
res.close()
def test_custom_headers(self):
url = "http://www.example.com"
with test_support.transient_internet(url):
opener = urllib2.build_opener()
request = urllib2.Request(url)
self.assertFalse(request.header_items())
opener.open(request)
self.assertTrue(request.header_items())
self.assertTrue(request.has_header('User-agent'))
request.add_header('User-Agent','Test-Agent')
opener.open(request)
self.assertEqual(request.get_header('User-agent'),'Test-Agent')
def test_sites_no_connection_close(self):
# Some sites do not send Connection: close header.
# Verify that those work properly. (#issue12576)
URL = 'http://www.imdb.com' # No Connection:close
with test_support.transient_internet(URL):
req = urllib2.urlopen(URL)
res = req.read()
self.assertTrue(res)
def _test_urls(self, urls, handlers, retry=True):
import time
import logging
debug = logging.getLogger("test_urllib2").debug
urlopen = urllib2.build_opener(*handlers).open
if retry:
urlopen = _wrap_with_retry_thrice(urlopen, urllib2.URLError)
for url in urls:
if isinstance(url, tuple):
url, req, expected_err = url
else:
req = expected_err = None
with test_support.transient_internet(url):
debug(url)
try:
f = urlopen(url, req, TIMEOUT)
except EnvironmentError as err:
debug(err)
if expected_err:
msg = ("Didn't get expected error(s) %s for %s %s, got %s: %s" %
(expected_err, url, req, type(err), err))
self.assertIsInstance(err, expected_err, msg)
except urllib2.URLError as err:
if isinstance(err[0], socket.timeout):
print >>sys.stderr, "<timeout: %s>" % url
continue
else:
raise
else:
try:
with test_support.transient_internet(url):
buf = f.read()
debug("read %d bytes" % len(buf))
except socket.timeout:
print >>sys.stderr, "<timeout: %s>" % url
f.close()
debug("******** next url coming up...")
time.sleep(0.1)
def _extra_handlers(self):
handlers = []
cfh = urllib2.CacheFTPHandler()
self.addCleanup(cfh.clear_cache)
cfh.setTimeout(1)
handlers.append(cfh)
return handlers
class TimeoutTest(unittest.TestCase):
def test_http_basic(self):
self.assertTrue(socket.getdefaulttimeout() is None)
url = "http://www.python.org"
with test_support.transient_internet(url, timeout=None):
u = _urlopen_with_retry(url)
self.assertTrue(u.fp._sock.fp._sock.gettimeout() is None)
def test_http_default_timeout(self):
self.assertTrue(socket.getdefaulttimeout() is None)
url = "http://www.python.org"
with test_support.transient_internet(url):
socket.setdefaulttimeout(60)
try:
u = _urlopen_with_retry(url)
finally:
socket.setdefaulttimeout(None)
self.assertEqual(u.fp._sock.fp._sock.gettimeout(), 60)
def test_http_no_timeout(self):
self.assertTrue(socket.getdefaulttimeout() is None)
url = "http://www.python.org"
with test_support.transient_internet(url):
socket.setdefaulttimeout(60)
try:
u = _urlopen_with_retry(url, timeout=None)
finally:
socket.setdefaulttimeout(None)
self.assertTrue(u.fp._sock.fp._sock.gettimeout() is None)
def test_http_timeout(self):
url = "http://www.python.org"
with test_support.transient_internet(url):
u = _urlopen_with_retry(url, timeout=120)
self.assertEqual(u.fp._sock.fp._sock.gettimeout(), 120)
FTP_HOST = "ftp://ftp.mirror.nl/pub/gnu/"
def test_ftp_basic(self):
self.assertTrue(socket.getdefaulttimeout() is None)
with test_support.transient_internet(self.FTP_HOST, timeout=None):
u = _urlopen_with_retry(self.FTP_HOST)
self.assertTrue(u.fp.fp._sock.gettimeout() is None)
def test_ftp_default_timeout(self):
self.assertTrue(socket.getdefaulttimeout() is None)
with test_support.transient_internet(self.FTP_HOST):
socket.setdefaulttimeout(60)
try:
u = _urlopen_with_retry(self.FTP_HOST)
finally:
socket.setdefaulttimeout(None)
self.assertEqual(u.fp.fp._sock.gettimeout(), 60)
def test_ftp_no_timeout(self):
self.assertTrue(socket.getdefaulttimeout() is None)
with test_support.transient_internet(self.FTP_HOST):
socket.setdefaulttimeout(60)
try:
u = _urlopen_with_retry(self.FTP_HOST, timeout=None)
finally:
socket.setdefaulttimeout(None)
self.assertTrue(u.fp.fp._sock.gettimeout() is None)
def test_ftp_timeout(self):
with test_support.transient_internet(self.FTP_HOST):
u = _urlopen_with_retry(self.FTP_HOST, timeout=60)
self.assertEqual(u.fp.fp._sock.gettimeout(), 60)
def test_main():
test_support.requires("network")
test_support.run_unittest(AuthTests,
OtherNetworkTests,
CloseSocketTest,
TimeoutTest,
)
if __name__ == "__main__":
test_main()
from __future__ import nested_scopes # Backward compat for 2.1
from unittest import TestCase
from wsgiref.util import setup_testing_defaults
from wsgiref.headers import Headers
from wsgiref.handlers import BaseHandler, BaseCGIHandler
from wsgiref import util
from wsgiref.validate import validator
from wsgiref.simple_server import WSGIServer, WSGIRequestHandler, demo_app
from wsgiref.simple_server import make_server
from StringIO import StringIO
from SocketServer import BaseServer
import os
import re
import sys
from test import test_support
class MockServer(WSGIServer):
"""Non-socket HTTP server"""
def __init__(self, server_address, RequestHandlerClass):
BaseServer.__init__(self, server_address, RequestHandlerClass)
self.server_bind()
def server_bind(self):
host, port = self.server_address
self.server_name = host
self.server_port = port
self.setup_environ()
class MockHandler(WSGIRequestHandler):
"""Non-socket HTTP handler"""
def setup(self):
self.connection = self.request
self.rfile, self.wfile = self.connection
def finish(self):
pass
def hello_app(environ,start_response):
start_response("200 OK", [
('Content-Type','text/plain'),
('Date','Mon, 05 Jun 2006 18:49:54 GMT')
])
return ["Hello, world!"]
def run_amock(app=hello_app, data="GET / HTTP/1.0\n\n"):
server = make_server("", 80, app, MockServer, MockHandler)
inp, out, err, olderr = StringIO(data), StringIO(), StringIO(), sys.stderr
sys.stderr = err
try:
server.finish_request((inp,out), ("127.0.0.1",8888))
finally:
sys.stderr = olderr
return out.getvalue(), err.getvalue()
def compare_generic_iter(make_it,match):
"""Utility to compare a generic 2.1/2.2+ iterator with an iterable
If running under Python 2.2+, this tests the iterator using iter()/next(),
as well as __getitem__. 'make_it' must be a function returning a fresh
iterator to be tested (since this may test the iterator twice)."""
it = make_it()
n = 0
for item in match:
if not it[n]==item: raise AssertionError
n+=1
try:
it[n]
except IndexError:
pass
else:
raise AssertionError("Too many items from __getitem__",it)
try:
iter, StopIteration
except NameError:
pass
else:
# Only test iter mode under 2.2+
it = make_it()
if not iter(it) is it: raise AssertionError
for item in match:
if not it.next()==item: raise AssertionError
try:
it.next()
except StopIteration:
pass
else:
raise AssertionError("Too many items from .next()",it)
class IntegrationTests(TestCase):
def check_hello(self, out, has_length=True):
self.assertEqual(out,
"HTTP/1.0 200 OK\r\n"
"Server: WSGIServer/0.1 Python/"+sys.version.split()[0]+"\r\n"
"Content-Type: text/plain\r\n"
"Date: Mon, 05 Jun 2006 18:49:54 GMT\r\n" +
(has_length and "Content-Length: 13\r\n" or "") +
"\r\n"
"Hello, world!"
)
def test_plain_hello(self):
out, err = run_amock()
self.check_hello(out)
def test_validated_hello(self):
out, err = run_amock(validator(hello_app))
# the middleware doesn't support len(), so content-length isn't there
self.check_hello(out, has_length=False)
def test_simple_validation_error(self):
def bad_app(environ,start_response):
start_response("200 OK", ('Content-Type','text/plain'))
return ["Hello, world!"]
out, err = run_amock(validator(bad_app))
self.assertTrue(out.endswith(
"A server error occurred. Please contact the administrator."
))
self.assertEqual(
err.splitlines()[-2],
"AssertionError: Headers (('Content-Type', 'text/plain')) must"
" be of type list: <type 'tuple'>"
)
class UtilityTests(TestCase):
def checkShift(self,sn_in,pi_in,part,sn_out,pi_out):
env = {'SCRIPT_NAME':sn_in,'PATH_INFO':pi_in}
util.setup_testing_defaults(env)
self.assertEqual(util.shift_path_info(env),part)
self.assertEqual(env['PATH_INFO'],pi_out)
self.assertEqual(env['SCRIPT_NAME'],sn_out)
return env
def checkDefault(self, key, value, alt=None):
# Check defaulting when empty
env = {}
util.setup_testing_defaults(env)
if isinstance(value, StringIO):
self.assertIsInstance(env[key], StringIO)
else:
self.assertEqual(env[key], value)
# Check existing value
env = {key:alt}
util.setup_testing_defaults(env)
self.assertTrue(env[key] is alt)
def checkCrossDefault(self,key,value,**kw):
util.setup_testing_defaults(kw)
self.assertEqual(kw[key],value)
def checkAppURI(self,uri,**kw):
util.setup_testing_defaults(kw)
self.assertEqual(util.application_uri(kw),uri)
def checkReqURI(self,uri,query=1,**kw):
util.setup_testing_defaults(kw)
self.assertEqual(util.request_uri(kw,query),uri)
def checkFW(self,text,size,match):
def make_it(text=text,size=size):
return util.FileWrapper(StringIO(text),size)
compare_generic_iter(make_it,match)
it = make_it()
self.assertFalse(it.filelike.closed)
for item in it:
pass
self.assertFalse(it.filelike.closed)
it.close()
self.assertTrue(it.filelike.closed)
def testSimpleShifts(self):
self.checkShift('','/', '', '/', '')
self.checkShift('','/x', 'x', '/x', '')
self.checkShift('/','', None, '/', '')
self.checkShift('/a','/x/y', 'x', '/a/x', '/y')
self.checkShift('/a','/x/', 'x', '/a/x', '/')
def testNormalizedShifts(self):
self.checkShift('/a/b', '/../y', '..', '/a', '/y')
self.checkShift('', '/../y', '..', '', '/y')
self.checkShift('/a/b', '//y', 'y', '/a/b/y', '')
self.checkShift('/a/b', '//y/', 'y', '/a/b/y', '/')
self.checkShift('/a/b', '/./y', 'y', '/a/b/y', '')
self.checkShift('/a/b', '/./y/', 'y', '/a/b/y', '/')
self.checkShift('/a/b', '///./..//y/.//', '..', '/a', '/y/')
self.checkShift('/a/b', '///', '', '/a/b/', '')
self.checkShift('/a/b', '/.//', '', '/a/b/', '')
self.checkShift('/a/b', '/x//', 'x', '/a/b/x', '/')
self.checkShift('/a/b', '/.', None, '/a/b', '')
def testDefaults(self):
for key, value in [
('SERVER_NAME','127.0.0.1'),
('SERVER_PORT', '80'),
('SERVER_PROTOCOL','HTTP/1.0'),
('HTTP_HOST','127.0.0.1'),
('REQUEST_METHOD','GET'),
('SCRIPT_NAME',''),
('PATH_INFO','/'),
('wsgi.version', (1,0)),
('wsgi.run_once', 0),
('wsgi.multithread', 0),
('wsgi.multiprocess', 0),
('wsgi.input', StringIO("")),
('wsgi.errors', StringIO()),
('wsgi.url_scheme','http'),
]:
self.checkDefault(key,value)
def testCrossDefaults(self):
self.checkCrossDefault('HTTP_HOST',"foo.bar",SERVER_NAME="foo.bar")
self.checkCrossDefault('wsgi.url_scheme',"https",HTTPS="on")
self.checkCrossDefault('wsgi.url_scheme',"https",HTTPS="1")
self.checkCrossDefault('wsgi.url_scheme',"https",HTTPS="yes")
self.checkCrossDefault('wsgi.url_scheme',"http",HTTPS="foo")
self.checkCrossDefault('SERVER_PORT',"80",HTTPS="foo")
self.checkCrossDefault('SERVER_PORT',"443",HTTPS="on")
def testGuessScheme(self):
self.assertEqual(util.guess_scheme({}), "http")
self.assertEqual(util.guess_scheme({'HTTPS':"foo"}), "http")
self.assertEqual(util.guess_scheme({'HTTPS':"on"}), "https")
self.assertEqual(util.guess_scheme({'HTTPS':"yes"}), "https")
self.assertEqual(util.guess_scheme({'HTTPS':"1"}), "https")
def testAppURIs(self):
self.checkAppURI("http://127.0.0.1/")
self.checkAppURI("http://127.0.0.1/spam", SCRIPT_NAME="/spam")
self.checkAppURI("http://spam.example.com:2071/",
HTTP_HOST="spam.example.com:2071", SERVER_PORT="2071")
self.checkAppURI("http://spam.example.com/",
SERVER_NAME="spam.example.com")
self.checkAppURI("http://127.0.0.1/",
HTTP_HOST="127.0.0.1", SERVER_NAME="spam.example.com")
self.checkAppURI("https://127.0.0.1/", HTTPS="on")
self.checkAppURI("http://127.0.0.1:8000/", SERVER_PORT="8000",
HTTP_HOST=None)
def testReqURIs(self):
self.checkReqURI("http://127.0.0.1/")
self.checkReqURI("http://127.0.0.1/spam", SCRIPT_NAME="/spam")
self.checkReqURI("http://127.0.0.1/spammity/spam",
SCRIPT_NAME="/spammity", PATH_INFO="/spam")
self.checkReqURI("http://127.0.0.1/spammity/spam;ham",
SCRIPT_NAME="/spammity", PATH_INFO="/spam;ham")
self.checkReqURI("http://127.0.0.1/spammity/spam;cookie=1234,5678",
SCRIPT_NAME="/spammity", PATH_INFO="/spam;cookie=1234,5678")
self.checkReqURI("http://127.0.0.1/spammity/spam?say=ni",
SCRIPT_NAME="/spammity", PATH_INFO="/spam",QUERY_STRING="say=ni")
self.checkReqURI("http://127.0.0.1/spammity/spam", 0,
SCRIPT_NAME="/spammity", PATH_INFO="/spam",QUERY_STRING="say=ni")
def testFileWrapper(self):
self.checkFW("xyz"*50, 120, ["xyz"*40,"xyz"*10])
def testHopByHop(self):
for hop in (
"Connection Keep-Alive Proxy-Authenticate Proxy-Authorization "
"TE Trailers Transfer-Encoding Upgrade"
).split():
for alt in hop, hop.title(), hop.upper(), hop.lower():
self.assertTrue(util.is_hop_by_hop(alt))
# Not comprehensive, just a few random header names
for hop in (
"Accept Cache-Control Date Pragma Trailer Via Warning"
).split():
for alt in hop, hop.title(), hop.upper(), hop.lower():
self.assertFalse(util.is_hop_by_hop(alt))
class HeaderTests(TestCase):
def testMappingInterface(self):
test = [('x','y')]
self.assertEqual(len(Headers([])),0)
self.assertEqual(len(Headers(test[:])),1)
self.assertEqual(Headers(test[:]).keys(), ['x'])
self.assertEqual(Headers(test[:]).values(), ['y'])
self.assertEqual(Headers(test[:]).items(), test)
self.assertFalse(Headers(test).items() is test) # must be copy!
h=Headers([])
del h['foo'] # should not raise an error
h['Foo'] = 'bar'
for m in h.has_key, h.__contains__, h.get, h.get_all, h.__getitem__:
self.assertTrue(m('foo'))
self.assertTrue(m('Foo'))
self.assertTrue(m('FOO'))
self.assertFalse(m('bar'))
self.assertEqual(h['foo'],'bar')
h['foo'] = 'baz'
self.assertEqual(h['FOO'],'baz')
self.assertEqual(h.get_all('foo'),['baz'])
self.assertEqual(h.get("foo","whee"), "baz")
self.assertEqual(h.get("zoo","whee"), "whee")
self.assertEqual(h.setdefault("foo","whee"), "baz")
self.assertEqual(h.setdefault("zoo","whee"), "whee")
self.assertEqual(h["foo"],"baz")
self.assertEqual(h["zoo"],"whee")
def testRequireList(self):
self.assertRaises(TypeError, Headers, "foo")
def testExtras(self):
h = Headers([])
self.assertEqual(str(h),'\r\n')
h.add_header('foo','bar',baz="spam")
self.assertEqual(h['foo'], 'bar; baz="spam"')
self.assertEqual(str(h),'foo: bar; baz="spam"\r\n\r\n')
h.add_header('Foo','bar',cheese=None)
self.assertEqual(h.get_all('foo'),
['bar; baz="spam"', 'bar; cheese'])
self.assertEqual(str(h),
'foo: bar; baz="spam"\r\n'
'Foo: bar; cheese\r\n'
'\r\n'
)
class ErrorHandler(BaseCGIHandler):
"""Simple handler subclass for testing BaseHandler"""
# BaseHandler records the OS environment at import time, but envvars
# might have been changed later by other tests, which trips up
# HandlerTests.testEnviron().
os_environ = dict(os.environ.items())
def __init__(self,**kw):
setup_testing_defaults(kw)
BaseCGIHandler.__init__(
self, StringIO(''), StringIO(), StringIO(), kw,
multithread=True, multiprocess=True
)
class TestHandler(ErrorHandler):
"""Simple handler subclass for testing BaseHandler, w/error passthru"""
def handle_error(self):
raise # for testing, we want to see what's happening
class HandlerTests(TestCase):
def checkEnvironAttrs(self, handler):
env = handler.environ
for attr in [
'version','multithread','multiprocess','run_once','file_wrapper'
]:
if attr=='file_wrapper' and handler.wsgi_file_wrapper is None:
continue
self.assertEqual(getattr(handler,'wsgi_'+attr),env['wsgi.'+attr])
def checkOSEnviron(self,handler):
empty = {}; setup_testing_defaults(empty)
env = handler.environ
from os import environ
for k,v in environ.items():
if k not in empty:
self.assertEqual(env[k],v)
for k,v in empty.items():
self.assertIn(k, env)
def testEnviron(self):
h = TestHandler(X="Y")
h.setup_environ()
self.checkEnvironAttrs(h)
self.checkOSEnviron(h)
self.assertEqual(h.environ["X"],"Y")
def testCGIEnviron(self):
h = BaseCGIHandler(None,None,None,{})
h.setup_environ()
for key in 'wsgi.url_scheme', 'wsgi.input', 'wsgi.errors':
self.assertIn(key, h.environ)
def testScheme(self):
h=TestHandler(HTTPS="on"); h.setup_environ()
self.assertEqual(h.environ['wsgi.url_scheme'],'https')
h=TestHandler(); h.setup_environ()
self.assertEqual(h.environ['wsgi.url_scheme'],'http')
def testAbstractMethods(self):
h = BaseHandler()
for name in [
'_flush','get_stdin','get_stderr','add_cgi_vars'
]:
self.assertRaises(NotImplementedError, getattr(h,name))
self.assertRaises(NotImplementedError, h._write, "test")
def testContentLength(self):
# Demo one reason iteration is better than write()... ;)
def trivial_app1(e,s):
s('200 OK',[])
return [e['wsgi.url_scheme']]
def trivial_app2(e,s):
s('200 OK',[])(e['wsgi.url_scheme'])
return []
def trivial_app4(e,s):
# Simulate a response to a HEAD request
s('200 OK',[('Content-Length', '12345')])
return []
h = TestHandler()
h.run(trivial_app1)
self.assertEqual(h.stdout.getvalue(),
"Status: 200 OK\r\n"
"Content-Length: 4\r\n"
"\r\n"
"http")
h = TestHandler()
h.run(trivial_app2)
self.assertEqual(h.stdout.getvalue(),
"Status: 200 OK\r\n"
"\r\n"
"http")
h = TestHandler()
h.run(trivial_app4)
self.assertEqual(h.stdout.getvalue(),
b'Status: 200 OK\r\n'
b'Content-Length: 12345\r\n'
b'\r\n')
def testBasicErrorOutput(self):
def non_error_app(e,s):
s('200 OK',[])
return []
def error_app(e,s):
raise AssertionError("This should be caught by handler")
h = ErrorHandler()
h.run(non_error_app)
self.assertEqual(h.stdout.getvalue(),
"Status: 200 OK\r\n"
"Content-Length: 0\r\n"
"\r\n")
self.assertEqual(h.stderr.getvalue(),"")
h = ErrorHandler()
h.run(error_app)
self.assertEqual(h.stdout.getvalue(),
"Status: %s\r\n"
"Content-Type: text/plain\r\n"
"Content-Length: %d\r\n"
"\r\n%s" % (h.error_status,len(h.error_body),h.error_body))
self.assertNotEqual(h.stderr.getvalue().find("AssertionError"), -1)
def testErrorAfterOutput(self):
MSG = "Some output has been sent"
def error_app(e,s):
s("200 OK",[])(MSG)
raise AssertionError("This should be caught by handler")
h = ErrorHandler()
h.run(error_app)
self.assertEqual(h.stdout.getvalue(),
"Status: 200 OK\r\n"
"\r\n"+MSG)
self.assertNotEqual(h.stderr.getvalue().find("AssertionError"), -1)
def testHeaderFormats(self):
def non_error_app(e,s):
s('200 OK',[])
return []
stdpat = (
r"HTTP/%s 200 OK\r\n"
r"Date: \w{3}, [ 0123]\d \w{3} \d{4} \d\d:\d\d:\d\d GMT\r\n"
r"%s" r"Content-Length: 0\r\n" r"\r\n"
)
shortpat = (
"Status: 200 OK\r\n" "Content-Length: 0\r\n" "\r\n"
)
for ssw in "FooBar/1.0", None:
sw = ssw and "Server: %s\r\n" % ssw or ""
for version in "1.0", "1.1":
for proto in "HTTP/0.9", "HTTP/1.0", "HTTP/1.1":
h = TestHandler(SERVER_PROTOCOL=proto)
h.origin_server = False
h.http_version = version
h.server_software = ssw
h.run(non_error_app)
self.assertEqual(shortpat,h.stdout.getvalue())
h = TestHandler(SERVER_PROTOCOL=proto)
h.origin_server = True
h.http_version = version
h.server_software = ssw
h.run(non_error_app)
if proto=="HTTP/0.9":
self.assertEqual(h.stdout.getvalue(),"")
else:
self.assertTrue(
re.match(stdpat%(version,sw), h.stdout.getvalue()),
(stdpat%(version,sw), h.stdout.getvalue())
)
# This epilogue is needed for compatibility with the Python 2.5 regrtest module
def test_main():
test_support.run_unittest(__name__)
if __name__ == "__main__":
test_main()
# the above lines intentionally left blank
-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQC89ZNxjTgWgq7Z1g0tJ65w+k7lNAj5IgjLb155UkUrz0XsHDnH
FlbsVUg2Xtk6+bo2UEYIzN7cIm5ImpmyW/2z0J1IDVDlvR2xJ659xrE0v5c2cB6T
f9lnNTwpSoeK24Nd7Jwq4j9vk95fLrdqsBq0/KVlsCXeixS/CaqqduXfvwIDAQAB
AoGAQFko4uyCgzfxr4Ezb4Mp5pN3Npqny5+Jey3r8EjSAX9Ogn+CNYgoBcdtFgbq
1yif/0sK7ohGBJU9FUCAwrqNBI9ZHB6rcy7dx+gULOmRBGckln1o5S1+smVdmOsW
7zUVLBVByKuNWqTYFlzfVd6s4iiXtAE2iHn3GCyYdlICwrECQQDhMQVxHd3EFbzg
SFmJBTARlZ2GKA3c1g/h9/XbkEPQ9/RwI3vnjJ2RaSnjlfoLl8TOcf0uOGbOEyFe
19RvCLXjAkEA1s+UE5ziF+YVkW3WolDCQ2kQ5WG9+ccfNebfh6b67B7Ln5iG0Sbg
ky9cjsO3jbMJQtlzAQnH1850oRD5Gi51dQJAIbHCDLDZU9Ok1TI+I2BhVuA6F666
lEZ7TeZaJSYq34OaUYUdrwG9OdqwZ9sy9LUav4ESzu2lhEQchCJrKMn23QJAReqs
ZLHUeTjfXkVk7dHhWPWSlUZ6AhmIlA/AQ7Payg2/8wM/JkZEJEPvGVykms9iPUrv
frADRr+hAGe43IewnQJBAJWKZllPgKuEBPwoEldHNS8nRu61D7HzxEzQ2xnfj+Nk
2fgf1MAzzTRsikfGENhVsVWeqOcijWb6g5gsyCmlRpc=
-----END RSA PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
MIICsDCCAhmgAwIBAgIJAOqYOYFJfEEoMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV
BAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBX
aWRnaXRzIFB0eSBMdGQwHhcNMDgwNjI2MTgxNTUyWhcNMDkwNjI2MTgxNTUyWjBF
MQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50
ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKB
gQC89ZNxjTgWgq7Z1g0tJ65w+k7lNAj5IgjLb155UkUrz0XsHDnHFlbsVUg2Xtk6
+bo2UEYIzN7cIm5ImpmyW/2z0J1IDVDlvR2xJ659xrE0v5c2cB6Tf9lnNTwpSoeK
24Nd7Jwq4j9vk95fLrdqsBq0/KVlsCXeixS/CaqqduXfvwIDAQABo4GnMIGkMB0G
A1UdDgQWBBTctMtI3EO9OjLI0x9Zo2ifkwIiNjB1BgNVHSMEbjBsgBTctMtI3EO9
OjLI0x9Zo2ifkwIiNqFJpEcwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgTClNvbWUt
U3RhdGUxITAfBgNVBAoTGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZIIJAOqYOYFJ
fEEoMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADgYEAQwa7jya/DfhaDn7E
usPkpgIX8WCL2B1SqnRTXEZfBPPVq/cUmFGyEVRVATySRuMwi8PXbVcOhXXuocA+
43W+iIsD9pXapCZhhOerCq18TC1dWK98vLUsoK8PMjB6e5H/O8bqojv0EeC+fyCw
eSHj5jpC8iZKjCHBn+mAi4cQ514=
-----END CERTIFICATE-----
from gevent import monkey; monkey.patch_all()
import sys
import os
from patched_tests_setup import disable_tests_in_source
import test.test_support
test.test_support.is_resource_enabled = lambda *args: True
del test.test_support.use_resources
test_filename = sys.argv[1]
del sys.argv[1]
__file__ = os.path.join(os.getcwd(), test_filename)
test_name = os.path.splitext(test_filename)[0]
module_source = open(test_filename).read()
module_source = disable_tests_in_source(module_source, test_name)
module_code = compile(module_source, test_filename, 'exec')
if test_name.startswith('test_urllib2'):
import test
import test_cookielib
import test_urllib2
test.test_urllib2 = test_urllib2
sys.modules['test.test_urllib2'] = test_urllib2
sys.modules['test.test_cookielib'] = test_cookielib
elif test_name == 'test_threading':
import test
import lock_tests
test.lock_tests = lock_tests
exec module_code in globals()
import sys
import os
import glob
import subprocess
import time
def wait(popen, timeout=60):
endtime = time.time() + timeout
try:
while True:
if popen.poll() is not None:
return popen.poll()
time.sleep(0.5)
if time.time() > endtime:
break
finally:
if popen.poll() is None:
sys.stderr.write('\nKilling %s (timed out)\n' % popen.name)
try:
popen.kill()
except OSError:
pass
sys.stderr.write('\n')
return 'TIMEOUT'
version = '%s.%s.%s' % sys.version_info[:3]
if not os.path.exists(version):
sys.exit('Directory %s not found in %s' % (version, os.getcwd()))
os.chdir(version)
class ContainsAll(object):
def __contains__(self, item):
return True
import test_support
test_support.use_resources = ContainsAll()
total = 0
failed = []
tests = set(glob.glob('test_*.py')) - set(['test_support.py'])
tests = sorted(tests)
for test in tests:
total += 1
sys.stderr.write('\nRunning %s\n' % test)
popen = subprocess.Popen([sys.executable, '-u', '-m', 'monkey_test', test])
popen.name = test
if wait(popen):
failed.append(test)
sys.stderr.write('%s/%s tests failed: %s\n' % (len(failed), total, failed))
if failed:
sys.exit(1)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment