Commit 64ed1511 authored by Julien Muchembled's avatar Julien Muchembled Committed by Xavier Thompson

Rewrite 'urlretrieve' helper to fix various download-related issues

- Py3: stop using legacy API of urllib.request and
       fix download of http(s) URLs containing user:passwd@
- Py2: avoid OOM when downloading huge files

This is implemented as a method in case we want to make it configurable
via [buildout].
parent 04cd4e83
...@@ -20,3 +20,16 @@ class UserError(Exception): ...@@ -20,3 +20,16 @@ class UserError(Exception):
def __str__(self): def __str__(self):
return " ".join(map(str, self.args)) return " ".join(map(str, self.args))
# Used for Python 2-3 compatibility
if str is bytes:
bytes2str = str2bytes = lambda s: s
def unicode2str(s):
return s.encode('utf-8')
else:
def bytes2str(s):
return s.decode()
def str2bytes(s):
return s.encode()
def unicode2str(s):
return s
...@@ -20,35 +20,17 @@ except ImportError: ...@@ -20,35 +20,17 @@ except ImportError:
try: try:
# Python 3 # Python 3
from urllib.request import urlretrieve from urllib.request import Request, urlopen
from urllib.parse import urlparse from urllib.parse import urlparse, urlunparse
except ImportError: except ImportError:
# Python 2 # Python 2
import base64
from urlparse import urlparse from urlparse import urlparse
from urlparse import urlunparse from urlparse import urlunparse
import urllib2 from urllib2 import Request, urlopen
def urlretrieve(url, tmp_path):
"""Work around Python issue 24599 includig basic auth support
"""
scheme, netloc, path, params, query, frag = urlparse(url)
auth, host = urllib2.splituser(netloc)
if auth:
url = urlunparse((scheme, host, path, params, query, frag))
req = urllib2.Request(url)
base64string = base64.encodestring(auth)[:-1]
basic = "Basic " + base64string
req.add_header("Authorization", basic)
else:
req = urllib2.Request(url)
url_obj = urllib2.urlopen(req)
with open(tmp_path, 'wb') as fp:
fp.write(url_obj.read())
return tmp_path, url_obj.info()
from zc.buildout.easy_install import realpath from zc.buildout.easy_install import realpath
from base64 import b64encode
from contextlib import closing
import logging import logging
import os import os
import os.path import os.path
...@@ -56,6 +38,7 @@ import re ...@@ -56,6 +38,7 @@ import re
import shutil import shutil
import tempfile import tempfile
import zc.buildout import zc.buildout
from . import bytes2str, str2bytes
from .rmtree import rmtree from .rmtree import rmtree
...@@ -223,13 +206,14 @@ class Download(object): ...@@ -223,13 +206,14 @@ class Download(object):
nc.get('signature-certificate-list'), md5sum): nc.get('signature-certificate-list'), md5sum):
# Download from original url if not cached or md5sum doesn't match. # Download from original url if not cached or md5sum doesn't match.
try: try:
tmp_path, headers = urlretrieve(url, tmp_path) tmp_path, headers = self.urlretrieve(url, tmp_path)
except HTTPError: except HTTPError:
if not alternate_url: if not alternate_url:
raise raise
self.logger.info('using alternate URL: %s', alternate_url) self.logger.info('using alternate URL: %s', alternate_url)
download_url = alternate_url download_url = alternate_url
tmp_path, headers = urlretrieve(alternate_url, tmp_path) tmp_path, headers = self.urlretrieve(
alternate_url, tmp_path)
if not check_md5sum(tmp_path, md5sum): if not check_md5sum(tmp_path, md5sum):
raise ChecksumError( raise ChecksumError(
'MD5 checksum mismatch downloading %r' % download_url) 'MD5 checksum mismatch downloading %r' % download_url)
...@@ -282,6 +266,22 @@ class Download(object): ...@@ -282,6 +266,22 @@ class Download(object):
url_host, url_port = parsed[-2:] url_host, url_port = parsed[-2:]
return '%s:%s' % (url_host, url_port) return '%s:%s' % (url_host, url_port)
def urlretrieve(self, url, tmp_path):
parsed_url = urlparse(url)
req = url
if parsed_url.scheme in ('http', 'https'):
auth_host = parsed_url.netloc.rsplit('@', 1)
if len(auth_host) > 1:
auth = auth_host[0]
url = parsed_url._replace(netloc=auth_host[1]).geturl()
req = Request(url)
req.add_header("Authorization",
"Basic " + bytes2str(b64encode(str2bytes(auth))))
with closing(urlopen(req)) as src:
with open(tmp_path, 'wb') as dst:
shutil.copyfileobj(src, dst)
return tmp_path, src.info()
def check_md5sum(path, md5sum): def check_md5sum(path, md5sum):
"""Tell whether the MD5 checksum of the file at path matches. """Tell whether the MD5 checksum of the file at path matches.
......
...@@ -152,6 +152,19 @@ This is a foo text. ...@@ -152,6 +152,19 @@ This is a foo text.
>>> remove(path) >>> remove(path)
HTTP basic authentication:
>>> download = Download()
>>> user_url = server_url.replace('/localhost:', '/%s@localhost:') + 'private/'
>>> path, is_temp = download(user_url % 'foo:' + 'foo:')
>>> is_temp; remove(path)
True
>>> path, is_temp = download(user_url % 'foo:bar' + 'foo:bar')
>>> is_temp; remove(path)
True
>>> download(user_url % 'bar:' + 'foo:')
Traceback (most recent call last):
UserError: Error downloading ...: HTTP Error 403: Forbidden
Downloading using the download cache Downloading using the download cache
------------------------------------ ------------------------------------
......
...@@ -23,6 +23,7 @@ except ImportError: ...@@ -23,6 +23,7 @@ except ImportError:
from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler
from urllib2 import urlopen from urllib2 import urlopen
import base64
import errno import errno
import logging import logging
import os import os
...@@ -364,6 +365,23 @@ class Handler(BaseHTTPRequestHandler): ...@@ -364,6 +365,23 @@ class Handler(BaseHTTPRequestHandler):
self.__server.__log = False self.__server.__log = False
return k() return k()
if self.path.startswith('/private/'):
auth = self.headers.get('Authorization')
if auth and auth.startswith('Basic ') and \
self.path[9:].encode() == base64.b64decode(
self.headers.get('Authorization')[6:]):
return k()
# But not returning 401+WWW-Authenticate, we check that the client
# skips auth challenge, which is not free (in terms of performance)
# and useless for what we support.
self.send_response(403, 'Forbidden')
out = '<html><body>Forbidden</body></html>'.encode()
self.send_header('Content-Length', str(len(out)))
self.send_header('Content-Type', 'text/html')
self.end_headers()
self.wfile.write(out)
return
path = os.path.abspath(os.path.join(self.tree, *self.path.split('/'))) path = os.path.abspath(os.path.join(self.tree, *self.path.split('/')))
if not ( if not (
((path == self.tree) or path.startswith(self.tree+os.path.sep)) ((path == self.tree) or path.startswith(self.tree+os.path.sep))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment