Commit 10764aa4 authored by Jason Madden's avatar Jason Madden

Rework the way the in-place patcher works.

It now handles multiple imports better and doesn't pollute the namespace of an original module object if it happened to be pre-imported.

Fixes #1526
parent 8911b794
...@@ -74,6 +74,10 @@ Other ...@@ -74,6 +74,10 @@ Other
glibc Linux, it has a hardcoded limitation of only working with file glibc Linux, it has a hardcoded limitation of only working with file
descriptors < 1024). See :issue:`1466` reported by Sam Wong. descriptors < 1024). See :issue:`1466` reported by Sam Wong.
- Make the dnspython resolver work if dns python had been imported
before the gevent resolver was initialized. Reported in
:issue:`1526` by Chris Utz and Josh Zuech.
1.5a3 (2020-01-01) 1.5a3 (2020-01-01)
================== ==================
......
...@@ -18,7 +18,7 @@ from gevent._compat import imp_acquire_lock ...@@ -18,7 +18,7 @@ from gevent._compat import imp_acquire_lock
from gevent._compat import imp_release_lock from gevent._compat import imp_release_lock
from gevent.builtins import __import__ as _import from gevent.builtins import __import__ as g_import
MAPPING = { MAPPING = {
...@@ -41,8 +41,9 @@ _PATCH_PREFIX = '__g_patched_module_' ...@@ -41,8 +41,9 @@ _PATCH_PREFIX = '__g_patched_module_'
class _SysModulesPatcher(object): class _SysModulesPatcher(object):
def __init__(self, importing): def __init__(self, importing, extra_all=lambda mod_name: ()):
self._saved = {} self._saved = {}
self.extra_all = extra_all
self.importing = importing self.importing = importing
self.green_modules = { self.green_modules = {
stdlib_name: importlib.import_module(gevent_name) stdlib_name: importlib.import_module(gevent_name)
...@@ -52,12 +53,14 @@ class _SysModulesPatcher(object): ...@@ -52,12 +53,14 @@ class _SysModulesPatcher(object):
self.orig_imported = frozenset(sys.modules) self.orig_imported = frozenset(sys.modules)
def _save(self): def _save(self):
self.orig_imported = frozenset(sys.modules)
for modname in self.green_modules: for modname in self.green_modules:
self._saved[modname] = sys.modules.get(modname, None) self._saved[modname] = sys.modules.get(modname, None)
self._saved[self.importing] = sys.modules.get(self.importing, None) self._saved[self.importing] = sys.modules.get(self.importing, None)
# Anything we've already patched regains its original name during this # Anything we've already patched regains its original name during this
# process # process; anything imported in the original namespace is temporarily withdrawn.
for mod_name, mod in iteritems(sys.modules): for mod_name, mod in iteritems(sys.modules):
if mod_name.startswith(_PATCH_PREFIX): if mod_name.startswith(_PATCH_PREFIX):
orig_mod_name = mod_name[len(_PATCH_PREFIX):] orig_mod_name = mod_name[len(_PATCH_PREFIX):]
...@@ -101,28 +104,112 @@ class _SysModulesPatcher(object): ...@@ -101,28 +104,112 @@ class _SysModulesPatcher(object):
imp_acquire_lock() imp_acquire_lock()
self._save() self._save()
self._replace() self._replace()
return self
module = None
def __call__(self):
if self.module is None:
self.module = self.import_one(self.importing)
return self
def import_one(self, module_name):
patched_name = _PATCH_PREFIX + module_name
if patched_name in sys.modules:
return sys.modules[patched_name]
assert module_name.startswith(self.importing)
sys.modules.pop(module_name, None)
module = g_import(module_name, {}, {}, module_name.split('.')[:-1])
self.module = module
sys.modules[patched_name] = module
# On Python 3, we could probably do something much nicer with the
# import machinery? Set the __loader__ or __finder__ or something like that?
self._import_all([module])
return module
def _import_all(self, queue):
# Called while monitoring for patch changes.
while queue:
module = queue.pop(0)
for attr_name in tuple(getattr(module, '__all__', ())) + self.extra_all(module.__name__):
try:
getattr(module, attr_name)
except AttributeError:
module_name = module.__name__ + '.' + attr_name
sys.modules.pop(module_name, None)
new_module = g_import(module_name, {}, {}, attr_name)
setattr(module, attr_name, new_module)
queue.append(new_module)
def import_patched(module_name):
def import_patched(module_name, extra_all=lambda mod_name: ()):
""" """
Import *module_name* with gevent monkey-patches active, Import *module_name* with gevent monkey-patches active,
and return the greened module. and return an object holding the greened module as *module*.
Any sub-modules that were imported by the package are also Any sub-modules that were imported by the package are also
saved. saved.
.. versionchanged:: 1.5a4
If the module defines ``__all__``, then each of those
attributes/modules is also imported as part of the same transaction,
recursively. The order of ``__all__`` is respected. Anything passed in
*extra_all* (which must be in the same namespace tree) is also imported.
""" """
patched_name = _PATCH_PREFIX + module_name
if patched_name in sys.modules:
return sys.modules[patched_name]
with cached_platform_architecture():
# Save the current module state, and restore on exit,
# capturing desirable changes in the modules package.
with _SysModulesPatcher(module_name, extra_all) as patcher:
patcher()
return patcher
# Save the current module state, and restore on exit,
# capturing desirable changes in the modules package.
with _SysModulesPatcher(module_name):
sys.modules.pop(module_name, None)
module = _import(module_name, {}, {}, module_name.split('.')[:-1]) class cached_platform_architecture(object):
sys.modules[patched_name] = module """
Context manager that caches ``platform.architecture``.
Some things that load shared libraries (like Cryptodome, via
dnspython) invoke ``platform.architecture()`` for each one. That
in turn wants to fork and run commands , which in turn wants to
call ``threading._after_fork`` if the GIL has been initialized.
All of that means that certain imports done early may wind up
wanting to have the hub initialized potentially much earlier than
before.
Part of the fix is to observe when that happens and delay
initializing parts of gevent until as late as possible (e.g., we
delay importing and creating the resolver until the hub needs it,
unless explicitly configured).
The rest of the fix is to avoid the ``_after_fork`` issues by
first caching the results of platform.architecture before doing
patched imports.
(See events.py for similar issues with platform, and
test__threading_2.py for notes about threading._after_fork if the
GIL has been initialized)
"""
return module _arch_result = None
_orig_arch = None
_platform = None
def __enter__(self):
import platform
self._platform = platform
self._arch_result = platform.architecture()
self._orig_arch = platform.architecture
def arch(*args, **kwargs):
if not args and not kwargs:
return self._arch_result
return self._orig_arch(*args, **kwargs)
platform.architecture = arch
return self
def __exit__(self, *_args):
self._platform.architecture = self._orig_arch
self._platform = None
...@@ -78,9 +78,10 @@ from gevent.resolver import hostname_types ...@@ -78,9 +78,10 @@ from gevent.resolver import hostname_types
from gevent.resolver._hostsfile import HostsFile from gevent.resolver._hostsfile import HostsFile
from gevent.resolver._addresses import is_ipv6_addr from gevent.resolver._addresses import is_ipv6_addr
from gevent.builtins import __import__ as g_import
from gevent._compat import string_types from gevent._compat import string_types
from gevent._compat import iteritems from gevent._compat import iteritems
from gevent._patcher import import_patched
from gevent._config import config from gevent._config import config
__all__ = [ __all__ = [
...@@ -88,62 +89,49 @@ __all__ = [ ...@@ -88,62 +89,49 @@ __all__ = [
] ]
# Import the DNS packages to use the gevent modules, # Import the DNS packages to use the gevent modules,
# even if the system is not monkey-patched. # even if the system is not monkey-patched. If it *is* already
# patched, this imports a second copy under a different name,
# Beginning in dnspython 0.16, note that this imports dns.dnssec, # which is probably not strictly necessary, but matches
# which imports Cryptodome, which wants to load a lot of shared # what we've historically done, and allows configuring the resolvers
# libraries, and to do so it invokes platform.architecture() for each # differently.
# one; that wants to fork and run commands (see events.py for similar
# issues with platform, and test__threading_2.py for notes about
# threading._after_fork if the GIL has been initialized), which in
# turn want to call threading._after_fork if the GIL has been
# initialized; all of that means that this now wants to have the hub
# initialized potentially much earlier than before, although we delay
# importing and creating the resolver until the hub is asked for it.
# We avoid the _after_fork issues by first caching the results of
# platform.architecture before doing the patched imports.
def _patch_dns(): def _patch_dns():
import platform from gevent._patcher import import_patched as importer
result = platform.architecture() # The dns package itself is empty but defines __all__
orig_arch = platform.architecture # we make sure to import all of those things now under the
def arch(*args, **kwargs): # patch. Note this triggers two DeprecationWarnings,
if not args and not kwargs: # one of which we could avoid.
return result extras = {
return orig_arch(*args, **kwargs) 'dns': ('rdata', 'resolver', 'rdtypes'),
platform.architecture = arch 'dns.rdtypes': ('IN', 'ANY', ),
try: 'dns.rdtypes.IN': ('A', 'AAAA',),
top = import_patched('dns') 'dns.rdtypes.ANY': ('SOA',),
for pkg in ('dns', }
'dns.rdtypes', def extra_all(mod_name):
'dns.rdtypes.IN', return extras.get(mod_name, ())
'dns.rdtypes.ANY'): patcher = importer('dns', extra_all)
mod = import_patched(pkg) top = patcher.module
for name in mod.__all__: def _dns_import_patched(name):
setattr(mod, name, import_patched(pkg + '.' + name)) assert name.startswith('dns')
finally: with patcher():
platform.architecture = orig_arch patcher.import_one(name)
return dns
# This module tries to dynamically import classes
# using __import__, and it's important that they match
# the ones we just created, otherwise exceptions won't be caught
# as expected. It uses a one-arg __import__ statement and then
# tries to walk down the sub-modules using getattr, so we can't
# directly use import_patched as-is.
top.rdata.__import__ = _dns_import_patched
return top return top
dns = _patch_dns() dns = _patch_dns()
def _dns_import_patched(name):
assert name.startswith('dns')
import_patched(name)
return dns
# This module tries to dynamically import classes
# using __import__, and it's important that they match
# the ones we just created, otherwise exceptions won't be caught
# as expected. It uses a one-arg __import__ statement and then
# tries to walk down the sub-modules using getattr, so we can't
# directly use import_patched as-is.
dns.rdata.__import__ = _dns_import_patched
resolver = dns.resolver resolver = dns.resolver
dTimeout = dns.resolver.Timeout dTimeout = dns.resolver.Timeout
_exc_clear = getattr(sys, 'exc_clear', lambda: None)
# This is a wrapper for dns.resolver._getaddrinfo with two crucial changes. # This is a wrapper for dns.resolver._getaddrinfo with two crucial changes.
# First, it backports https://github.com/rthalley/dnspython/issues/316 # First, it backports https://github.com/rthalley/dnspython/issues/316
# from version 2.0. This can be dropped when we support only dnspython 2 # from version 2.0. This can be dropped when we support only dnspython 2
...@@ -156,7 +144,9 @@ _exc_clear = getattr(sys, 'exc_clear', lambda: None) ...@@ -156,7 +144,9 @@ _exc_clear = getattr(sys, 'exc_clear', lambda: None)
# lookups that are not super fast. But it does have a habit of leaving # lookups that are not super fast. But it does have a habit of leaving
# exceptions around which can complicate our memleak checks.) # exceptions around which can complicate our memleak checks.)
def _getaddrinfo(host=None, service=None, family=AF_UNSPEC, socktype=0, def _getaddrinfo(host=None, service=None, family=AF_UNSPEC, socktype=0,
proto=0, flags=0, _orig_gai=resolver._getaddrinfo): proto=0, flags=0,
_orig_gai=resolver._getaddrinfo,
_exc_clear=getattr(sys, 'exc_clear', lambda: None)):
if flags & (socket.AI_ADDRCONFIG | socket.AI_V4MAPPED) != 0: if flags & (socket.AI_ADDRCONFIG | socket.AI_V4MAPPED) != 0:
# Not implemented. We raise a gaierror as opposed to a # Not implemented. We raise a gaierror as opposed to a
# NotImplementedError as it helps callers handle errors more # NotImplementedError as it helps callers handle errors more
...@@ -172,7 +162,6 @@ resolver._getaddrinfo = _getaddrinfo ...@@ -172,7 +162,6 @@ resolver._getaddrinfo = _getaddrinfo
HOSTS_TTL = 300.0 HOSTS_TTL = 300.0
class _HostsAnswer(dns.resolver.Answer): class _HostsAnswer(dns.resolver.Answer):
# Answer class for HostsResolver object # Answer class for HostsResolver object
......
...@@ -2,6 +2,10 @@ ...@@ -2,6 +2,10 @@
""" """
Make a package. Make a package.
This file has no other functionality. Individual modules in this package
are used for testing, often being run with 'python -m ...' in individual
test cases (functions).
""" """
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
# -*- coding: utf-8 -*-
"""
Test for issue #1526:
- dnspython is imported first;
- no monkey-patching is done.
"""
from __future__ import print_function
from __future__ import absolute_import
import dns
assert dns
import gevent.socket as socket
socket.getfqdn() # create the resolver
from gevent.resolver.dnspython import dns as gdns
import dns.rdtypes
assert dns is not gdns, (dns, gdns)
assert dns.rdtypes is not gdns.rdtypes
import sys
print(sorted(sys.modules))
# -*- coding: utf-8 -*-
"""
Test for issue #1526:
- dnspython is imported first;
- monkey-patching happens early
"""
from __future__ import print_function, absolute_import
from gevent import monkey
monkey.patch_all()
import dns
assert dns
import socket
import sys
socket.getfqdn()
import gevent.resolver.dnspython
from gevent.resolver.dnspython import dns as gdns
from dns import rdtypes # NOT import dns.rdtypes
assert hasattr(dns, 'rdtypes')
assert gevent.resolver.dnspython.dns is gdns
assert gdns is not dns, (gdns, dns, "id dns", id(dns))
assert gdns.rdtypes is not rdtypes, (gdns.rdtypes, rdtypes)
print(sorted(sys.modules))
# -*- coding: utf-8 -*-
"""
Tests explicitly using the DNS python resolver.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import unittest
import subprocess
import os
from gevent import testing as greentest
class TestDnsPython(unittest.TestCase):
def _run_one(self, mod_name):
cmd = [
sys.executable,
'-m',
'gevent.tests.monkey_package.' + mod_name
]
env = dict(os.environ)
env['GEVENT_RESOLVER'] = 'dnspython'
output = subprocess.check_output(cmd, env=env)
self.assertIn(b'_g_patched_module_dns', output)
self.assertNotIn(b'_g_patched_module_dns.rdtypes', output)
return output
def test_import_dns_no_monkey_patch(self):
self._run_one('issue1526_no_monkey')
def test_import_dns_with_monkey_patch(self):
self._run_one('issue1526_with_monkey')
if __name__ == '__main__':
greentest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment