Commit 8c497751 authored by Jason Madden's avatar Jason Madden

Make gevent.pywsgi stop dealing with chunks when the connection is being upgraded.

Let the application have full control over input and output.

Fixes #1712.
parent f54fa619
Make `gevent.pywsgi` trying to enforce the rules for reading chunked input or
``Content-Length`` terminated input when the connection is being
upgraded, for example to a websocket connection. Likewise, if the
protocol was switched by returning a ``101`` status, stop trying to
automatically chunk the responses.
Reported by Kavindu Santhusa.
......@@ -415,6 +415,9 @@ class WSGIHandler(object):
time_finish = 0 # time.time() when done handling request
headers_sent = False # Have we already sent headers?
response_use_chunked = False # Write with transfer-encoding chunked
# Was the connection upgraded? We shouldn't try to chunk writes in that
# case.
connection_upgraded = False
environ = None # Dict from self.get_environ
application = None # application callable from self.server.application
requestline = None # native str 'GET / HTTP/1.1'
......@@ -486,6 +489,7 @@ class WSGIHandler(object):
pass
self.__dict__.pop('socket', None)
self.__dict__.pop('rfile', None)
self.__dict__.pop('wsgi_input', None)
def _check_http_version(self):
version_str = self.request_version
......@@ -697,10 +701,19 @@ class WSGIHandler(object):
return True # read more requests
def _connection_upgrade_requested(self):
if self.headers.get('Connection', '').lower() == 'upgrade':
return True
if self.headers.get('Upgrade', '').lower() == 'websocket':
return True
return False
def finalize_headers(self):
if self.provided_date is None:
self.response_headers.append((b'Date', format_date_time(time.time())))
self.connection_upgraded = self.code == 101
if self.code not in (304, 204):
# the reply will include message-body; make sure we have either Content-Length or chunked
if self.provided_content_length is None:
......@@ -711,8 +724,11 @@ class WSGIHandler(object):
total_len_str = total_len_str.encode("latin-1")
self.response_headers.append((b'Content-Length', total_len_str))
else:
if self.request_version != 'HTTP/1.0':
self.response_use_chunked = True
self.response_use_chunked = (
not self.connection_upgraded
and self.request_version != 'HTTP/1.0'
)
if self.response_use_chunked:
self.response_headers.append((b'Transfer-Encoding', b'chunked'))
def _sendall(self, data):
......@@ -975,6 +991,7 @@ class WSGIHandler(object):
self.result = None
self.response_use_chunked = False
self.connection_upgraded = False
self.response_length = 0
try:
......@@ -1103,10 +1120,7 @@ class WSGIHandler(object):
# See https://github.com/gevent/gevent/issues/1667 for discussion.
env['SCRIPT_NAME'] = ''
if '?' in self.path:
path, query = self.path.split('?', 1)
else:
path, query = self.path, ''
path, query = self.path.split('?', 1) if '?' in self.path else (self.path, '')
# Note that self.path contains the original str object; if it contains
# encoded escapes, it will NOT match PATH_INFO.
env['PATH_INFO'] = unquote_latin1(path)
......@@ -1134,18 +1148,20 @@ class WSGIHandler(object):
else:
env[key] = value
if env.get('HTTP_EXPECT') == '100-continue':
sock = self.socket
else:
sock = None
sock = self.socket if env.get('HTTP_EXPECT') == '100-continue' else None
chunked = env.get('HTTP_TRANSFER_ENCODING', '').lower() == 'chunked'
# Input refuses to read if the data isn't chunked, and there is no content_length
# provided. For 'Upgrade: Websocket' requests, neither of those things is true.
handling_reads = not self._connection_upgrade_requested()
self.wsgi_input = Input(self.rfile, self.content_length, socket=sock, chunked_input=chunked)
env['wsgi.input'] = self.wsgi_input
env['wsgi.input'] = self.wsgi_input if handling_reads else self.rfile
# This is a non-standard flag indicating that our input stream is
# self-terminated (returns EOF when consumed).
# See https://github.com/gevent/gevent/issues/1308
env['wsgi.input_terminated'] = True
env['wsgi.input_terminated'] = handling_reads
return env
......
......@@ -432,18 +432,35 @@ class TestNoChunks(CommonTestMixin, TestCase):
# when returning a list of strings a shortcut is employed by the server:
# it calculates the content-length and joins all the chunks before sending
validator = None
last_environ = None
def _check_environ(self, input_terminated=True):
if input_terminated:
self.assertTrue(self.last_environ.get('wsgi.input_terminated'))
else:
self.assertFalse(self.last_environ['wsgi.input_terminated'])
def application(self, env, start_response):
self.assertTrue(env.get('wsgi.input_terminated'))
self.last_environ = env
path = env['PATH_INFO']
if path == '/':
start_response('200 OK', [('Content-Type', 'text/plain')])
return [b'hello ', b'world']
if path == '/websocket':
write = start_response('101 Switching Protocols',
[('Content-Type', 'text/plain'),
# Con:close is to make our simple client
# happy; otherwise it wants to read data from the
# body thot's being kept open.
('Connection', 'close')])
write(b'') # Trigger finalizing the headers now.
return [b'upgrading to', b'websocket']
start_response('404 Not Found', [('Content-Type', 'text/plain')])
return [b'not ', b'found']
def test_basic(self):
response, dne_response = super(TestNoChunks, self).test_basic()
self._check_environ()
self.assertFalse(response.chunks)
response.assertHeader('Content-Length', '11')
if dne_response is not None:
......@@ -455,8 +472,28 @@ class TestNoChunks(CommonTestMixin, TestCase):
fd.write(self.format_request(path='/notexist'))
response = read_http(fd, code=404, reason='Not Found', body='not found')
self.assertFalse(response.chunks)
self._check_environ()
response.assertHeader('Content-Length', '9')
class TestConnectionUpgrades(TestNoChunks):
def test_connection_upgrade(self):
with self.makefile() as fd:
fd.write(self.format_request(path='/websocket', Connection='upgrade'))
response = read_http(fd, code=101)
self._check_environ(input_terminated=False)
self.assertFalse(response.chunks)
def test_upgrade_websocket(self):
with self.makefile() as fd:
fd.write(self.format_request(path='/websocket', Upgrade='websocket'))
response = read_http(fd, code=101)
self._check_environ(input_terminated=False)
self.assertFalse(response.chunks)
class TestNoChunks10(TestNoChunks):
HTTP_CLIENT_VERSION = '1.0'
PIPELINE_NOT_SUPPORTED_EXS = (ConnectionClosed,)
......@@ -475,6 +512,7 @@ class TestExplicitContentLength(TestNoChunks): # pylint:disable=too-many-ancesto
# server - it caculates the content-length
def application(self, env, start_response):
self.last_environ = env
self.assertTrue(env.get('wsgi.input_terminated'))
path = env['PATH_INFO']
if path == '/':
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment