Commit 10764aa4 authored by Jason Madden's avatar Jason Madden

Rework the way the in-place patcher works.

It now handles multiple imports better and doesn't pollute the namespace of an original module object if it happened to be pre-imported.

Fixes #1526
parent 8911b794
......@@ -74,6 +74,10 @@ Other
glibc Linux, it has a hardcoded limitation of only working with file
descriptors < 1024). See :issue:`1466` reported by Sam Wong.
- Make the dnspython resolver work if dns python had been imported
before the gevent resolver was initialized. Reported in
:issue:`1526` by Chris Utz and Josh Zuech.
1.5a3 (2020-01-01)
==================
......
......@@ -18,7 +18,7 @@ from gevent._compat import imp_acquire_lock
from gevent._compat import imp_release_lock
from gevent.builtins import __import__ as _import
from gevent.builtins import __import__ as g_import
MAPPING = {
......@@ -41,8 +41,9 @@ _PATCH_PREFIX = '__g_patched_module_'
class _SysModulesPatcher(object):
def __init__(self, importing):
def __init__(self, importing, extra_all=lambda mod_name: ()):
self._saved = {}
self.extra_all = extra_all
self.importing = importing
self.green_modules = {
stdlib_name: importlib.import_module(gevent_name)
......@@ -52,12 +53,14 @@ class _SysModulesPatcher(object):
self.orig_imported = frozenset(sys.modules)
def _save(self):
self.orig_imported = frozenset(sys.modules)
for modname in self.green_modules:
self._saved[modname] = sys.modules.get(modname, None)
self._saved[self.importing] = sys.modules.get(self.importing, None)
# Anything we've already patched regains its original name during this
# process
# process; anything imported in the original namespace is temporarily withdrawn.
for mod_name, mod in iteritems(sys.modules):
if mod_name.startswith(_PATCH_PREFIX):
orig_mod_name = mod_name[len(_PATCH_PREFIX):]
......@@ -101,28 +104,112 @@ class _SysModulesPatcher(object):
imp_acquire_lock()
self._save()
self._replace()
return self
module = None
def __call__(self):
if self.module is None:
self.module = self.import_one(self.importing)
return self
def import_one(self, module_name):
patched_name = _PATCH_PREFIX + module_name
if patched_name in sys.modules:
return sys.modules[patched_name]
assert module_name.startswith(self.importing)
sys.modules.pop(module_name, None)
module = g_import(module_name, {}, {}, module_name.split('.')[:-1])
self.module = module
sys.modules[patched_name] = module
# On Python 3, we could probably do something much nicer with the
# import machinery? Set the __loader__ or __finder__ or something like that?
self._import_all([module])
return module
def _import_all(self, queue):
# Called while monitoring for patch changes.
while queue:
module = queue.pop(0)
for attr_name in tuple(getattr(module, '__all__', ())) + self.extra_all(module.__name__):
try:
getattr(module, attr_name)
except AttributeError:
module_name = module.__name__ + '.' + attr_name
sys.modules.pop(module_name, None)
new_module = g_import(module_name, {}, {}, attr_name)
setattr(module, attr_name, new_module)
queue.append(new_module)
def import_patched(module_name):
def import_patched(module_name, extra_all=lambda mod_name: ()):
"""
Import *module_name* with gevent monkey-patches active,
and return the greened module.
and return an object holding the greened module as *module*.
Any sub-modules that were imported by the package are also
saved.
"""
patched_name = _PATCH_PREFIX + module_name
if patched_name in sys.modules:
return sys.modules[patched_name]
.. versionchanged:: 1.5a4
If the module defines ``__all__``, then each of those
attributes/modules is also imported as part of the same transaction,
recursively. The order of ``__all__`` is respected. Anything passed in
*extra_all* (which must be in the same namespace tree) is also imported.
"""
with cached_platform_architecture():
# Save the current module state, and restore on exit,
# capturing desirable changes in the modules package.
with _SysModulesPatcher(module_name):
sys.modules.pop(module_name, None)
with _SysModulesPatcher(module_name, extra_all) as patcher:
patcher()
return patcher
module = _import(module_name, {}, {}, module_name.split('.')[:-1])
sys.modules[patched_name] = module
return module
class cached_platform_architecture(object):
"""
Context manager that caches ``platform.architecture``.
Some things that load shared libraries (like Cryptodome, via
dnspython) invoke ``platform.architecture()`` for each one. That
in turn wants to fork and run commands , which in turn wants to
call ``threading._after_fork`` if the GIL has been initialized.
All of that means that certain imports done early may wind up
wanting to have the hub initialized potentially much earlier than
before.
Part of the fix is to observe when that happens and delay
initializing parts of gevent until as late as possible (e.g., we
delay importing and creating the resolver until the hub needs it,
unless explicitly configured).
The rest of the fix is to avoid the ``_after_fork`` issues by
first caching the results of platform.architecture before doing
patched imports.
(See events.py for similar issues with platform, and
test__threading_2.py for notes about threading._after_fork if the
GIL has been initialized)
"""
_arch_result = None
_orig_arch = None
_platform = None
def __enter__(self):
import platform
self._platform = platform
self._arch_result = platform.architecture()
self._orig_arch = platform.architecture
def arch(*args, **kwargs):
if not args and not kwargs:
return self._arch_result
return self._orig_arch(*args, **kwargs)
platform.architecture = arch
return self
def __exit__(self, *_args):
self._platform.architecture = self._orig_arch
self._platform = None
......@@ -78,9 +78,10 @@ from gevent.resolver import hostname_types
from gevent.resolver._hostsfile import HostsFile
from gevent.resolver._addresses import is_ipv6_addr
from gevent.builtins import __import__ as g_import
from gevent._compat import string_types
from gevent._compat import iteritems
from gevent._patcher import import_patched
from gevent._config import config
__all__ = [
......@@ -88,62 +89,49 @@ __all__ = [
]
# Import the DNS packages to use the gevent modules,
# even if the system is not monkey-patched.
# Beginning in dnspython 0.16, note that this imports dns.dnssec,
# which imports Cryptodome, which wants to load a lot of shared
# libraries, and to do so it invokes platform.architecture() for each
# one; that wants to fork and run commands (see events.py for similar
# issues with platform, and test__threading_2.py for notes about
# threading._after_fork if the GIL has been initialized), which in
# turn want to call threading._after_fork if the GIL has been
# initialized; all of that means that this now wants to have the hub
# initialized potentially much earlier than before, although we delay
# importing and creating the resolver until the hub is asked for it.
# We avoid the _after_fork issues by first caching the results of
# platform.architecture before doing the patched imports.
def _patch_dns():
import platform
result = platform.architecture()
orig_arch = platform.architecture
def arch(*args, **kwargs):
if not args and not kwargs:
return result
return orig_arch(*args, **kwargs)
platform.architecture = arch
try:
top = import_patched('dns')
for pkg in ('dns',
'dns.rdtypes',
'dns.rdtypes.IN',
'dns.rdtypes.ANY'):
mod = import_patched(pkg)
for name in mod.__all__:
setattr(mod, name, import_patched(pkg + '.' + name))
finally:
platform.architecture = orig_arch
return top
dns = _patch_dns()
# even if the system is not monkey-patched. If it *is* already
# patched, this imports a second copy under a different name,
# which is probably not strictly necessary, but matches
# what we've historically done, and allows configuring the resolvers
# differently.
def _dns_import_patched(name):
def _patch_dns():
from gevent._patcher import import_patched as importer
# The dns package itself is empty but defines __all__
# we make sure to import all of those things now under the
# patch. Note this triggers two DeprecationWarnings,
# one of which we could avoid.
extras = {
'dns': ('rdata', 'resolver', 'rdtypes'),
'dns.rdtypes': ('IN', 'ANY', ),
'dns.rdtypes.IN': ('A', 'AAAA',),
'dns.rdtypes.ANY': ('SOA',),
}
def extra_all(mod_name):
return extras.get(mod_name, ())
patcher = importer('dns', extra_all)
top = patcher.module
def _dns_import_patched(name):
assert name.startswith('dns')
import_patched(name)
with patcher():
patcher.import_one(name)
return dns
# This module tries to dynamically import classes
# using __import__, and it's important that they match
# the ones we just created, otherwise exceptions won't be caught
# as expected. It uses a one-arg __import__ statement and then
# tries to walk down the sub-modules using getattr, so we can't
# directly use import_patched as-is.
dns.rdata.__import__ = _dns_import_patched
# This module tries to dynamically import classes
# using __import__, and it's important that they match
# the ones we just created, otherwise exceptions won't be caught
# as expected. It uses a one-arg __import__ statement and then
# tries to walk down the sub-modules using getattr, so we can't
# directly use import_patched as-is.
top.rdata.__import__ = _dns_import_patched
return top
dns = _patch_dns()
resolver = dns.resolver
dTimeout = dns.resolver.Timeout
_exc_clear = getattr(sys, 'exc_clear', lambda: None)
# This is a wrapper for dns.resolver._getaddrinfo with two crucial changes.
# First, it backports https://github.com/rthalley/dnspython/issues/316
# from version 2.0. This can be dropped when we support only dnspython 2
......@@ -156,7 +144,9 @@ _exc_clear = getattr(sys, 'exc_clear', lambda: None)
# lookups that are not super fast. But it does have a habit of leaving
# exceptions around which can complicate our memleak checks.)
def _getaddrinfo(host=None, service=None, family=AF_UNSPEC, socktype=0,
proto=0, flags=0, _orig_gai=resolver._getaddrinfo):
proto=0, flags=0,
_orig_gai=resolver._getaddrinfo,
_exc_clear=getattr(sys, 'exc_clear', lambda: None)):
if flags & (socket.AI_ADDRCONFIG | socket.AI_V4MAPPED) != 0:
# Not implemented. We raise a gaierror as opposed to a
# NotImplementedError as it helps callers handle errors more
......@@ -172,7 +162,6 @@ resolver._getaddrinfo = _getaddrinfo
HOSTS_TTL = 300.0
class _HostsAnswer(dns.resolver.Answer):
# Answer class for HostsResolver object
......
......@@ -2,6 +2,10 @@
"""
Make a package.
This file has no other functionality. Individual modules in this package
are used for testing, often being run with 'python -m ...' in individual
test cases (functions).
"""
from __future__ import absolute_import
from __future__ import division
......
# -*- coding: utf-8 -*-
"""
Test for issue #1526:
- dnspython is imported first;
- no monkey-patching is done.
"""
from __future__ import print_function
from __future__ import absolute_import
import dns
assert dns
import gevent.socket as socket
socket.getfqdn() # create the resolver
from gevent.resolver.dnspython import dns as gdns
import dns.rdtypes
assert dns is not gdns, (dns, gdns)
assert dns.rdtypes is not gdns.rdtypes
import sys
print(sorted(sys.modules))
# -*- coding: utf-8 -*-
"""
Test for issue #1526:
- dnspython is imported first;
- monkey-patching happens early
"""
from __future__ import print_function, absolute_import
from gevent import monkey
monkey.patch_all()
import dns
assert dns
import socket
import sys
socket.getfqdn()
import gevent.resolver.dnspython
from gevent.resolver.dnspython import dns as gdns
from dns import rdtypes # NOT import dns.rdtypes
assert hasattr(dns, 'rdtypes')
assert gevent.resolver.dnspython.dns is gdns
assert gdns is not dns, (gdns, dns, "id dns", id(dns))
assert gdns.rdtypes is not rdtypes, (gdns.rdtypes, rdtypes)
print(sorted(sys.modules))
# -*- coding: utf-8 -*-
"""
Tests explicitly using the DNS python resolver.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import unittest
import subprocess
import os
from gevent import testing as greentest
class TestDnsPython(unittest.TestCase):
def _run_one(self, mod_name):
cmd = [
sys.executable,
'-m',
'gevent.tests.monkey_package.' + mod_name
]
env = dict(os.environ)
env['GEVENT_RESOLVER'] = 'dnspython'
output = subprocess.check_output(cmd, env=env)
self.assertIn(b'_g_patched_module_dns', output)
self.assertNotIn(b'_g_patched_module_dns.rdtypes', output)
return output
def test_import_dns_no_monkey_patch(self):
self._run_one('issue1526_no_monkey')
def test_import_dns_with_monkey_patch(self):
self._run_one('issue1526_with_monkey')
if __name__ == '__main__':
greentest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment