Commit 49ef9c91 authored by Jason Madden's avatar Jason Madden

Another pass for corner cases.

parent 10764aa4
......@@ -42,39 +42,55 @@ _PATCH_PREFIX = '__g_patched_module_'
class _SysModulesPatcher(object):
def __init__(self, importing, extra_all=lambda mod_name: ()):
self._saved = {}
# Permanent state.
self.extra_all = extra_all
self.importing = importing
self.green_modules = {
# green modules, replacing regularly imported modules.
# This begins as the gevent list of modules, and
# then gets extended with green things from the tree we import.
self._green_modules = {
stdlib_name: importlib.import_module(gevent_name)
for gevent_name, stdlib_name
in iteritems(MAPPING)
}
self.orig_imported = frozenset(sys.modules)
## Transient, reset each time we're called.
# The set of things imported before we began.
self._t_modules_to_restore = {}
def _save(self):
self.orig_imported = frozenset(sys.modules)
self._t_modules_to_restore = {}
for modname in self.green_modules:
self._saved[modname] = sys.modules.get(modname, None)
# Copy all the things we know we are going to overwrite.
for modname in self._green_modules:
self._t_modules_to_restore[modname] = sys.modules.get(modname, None)
self._saved[self.importing] = sys.modules.get(self.importing, None)
# Anything we've already patched regains its original name during this
# process; anything imported in the original namespace is temporarily withdrawn.
for mod_name, mod in iteritems(sys.modules):
if mod_name.startswith(_PATCH_PREFIX):
orig_mod_name = mod_name[len(_PATCH_PREFIX):]
self._saved[mod_name] = sys.modules.get(orig_mod_name, None)
self.green_modules[orig_mod_name] = mod
# Copy anything else in the import tree.
for modname, mod in list(iteritems(sys.modules)):
if modname.startswith(self.importing):
self._t_modules_to_restore[modname] = mod
# And remove it. If it had been imported green, it will
# be put right back. Otherwise, it was imported "manually"
# outside this process and isn't green.
del sys.modules[modname]
def _replace(self):
# Cover the target modules so that when you import the module it
# sees only the patched versions
for name, mod in iteritems(self.green_modules):
for name, mod in iteritems(self._green_modules):
sys.modules[name] = mod
def _restore(self):
for modname, mod in iteritems(self._saved):
# Anything from the same package tree we imported this time
# needs to be saved so we can restore it later, and so it doesn't
# leak into the namespace.
for modname, mod in list(iteritems(sys.modules)):
if modname.startswith(self.importing):
self._green_modules[modname] = mod
del sys.modules[modname]
# Now, what we saved at the beginning needs to be restored.
for modname, mod in iteritems(self._t_modules_to_restore):
if mod is not None:
sys.modules[modname] = mod
else:
......@@ -82,38 +98,35 @@ class _SysModulesPatcher(object):
del sys.modules[modname]
except KeyError:
pass
# Anything from the same package tree we imported this time
# needs to be saved so we can restore it later, and so it doesn't
# leak into the namespace.
pkg_prefix = self.importing.split('.', 1)[0]
for modname, mod in list(iteritems(sys.modules)):
if (modname not in self.orig_imported
and modname != self.importing
and not modname.startswith(_PATCH_PREFIX)
and modname.startswith(pkg_prefix)):
sys.modules[_PATCH_PREFIX + modname] = mod
del sys.modules[modname]
def __exit__(self, t, v, tb):
try:
self._restore()
finally:
imp_release_lock()
self._t_modules_to_restore = None
def __enter__(self):
imp_acquire_lock()
self._save()
self._replace()
return self
module = None
def __call__(self):
def __call__(self, after_import_hook):
if self.module is None:
self.module = self.import_one(self.importing)
with self:
self.module = self.import_one(self.importing, after_import_hook)
# Circular reference. Someone must keep a reference to this module alive
# for it to be visible. We record it in sys.modules to be that someone, and
# to aid debugging. In the past, we worked with multiple completely separate
# invocations of `import_patched`, but we no longer do.
self.module.__gevent_patcher__ = self
sys.modules[_PATCH_PREFIX + self.importing] = self.module
return self
def import_one(self, module_name):
def import_one(self, module_name, after_import_hook):
patched_name = _PATCH_PREFIX + module_name
if patched_name in sys.modules:
return sys.modules[patched_name]
......@@ -123,28 +136,31 @@ class _SysModulesPatcher(object):
module = g_import(module_name, {}, {}, module_name.split('.')[:-1])
self.module = module
sys.modules[patched_name] = module
# On Python 3, we could probably do something much nicer with the
# import machinery? Set the __loader__ or __finder__ or something like that?
self._import_all([module])
after_import_hook(module)
return module
def _import_all(self, queue):
# Called while monitoring for patch changes.
while queue:
module = queue.pop(0)
for attr_name in tuple(getattr(module, '__all__', ())) + self.extra_all(module.__name__):
name = module.__name__
mod_all = tuple(getattr(module, '__all__', ())) + self.extra_all(name)
for attr_name in mod_all:
try:
getattr(module, attr_name)
except AttributeError:
module_name = module.__name__ + '.' + attr_name
sys.modules.pop(module_name, None)
new_module = g_import(module_name, {}, {}, attr_name)
setattr(module, attr_name, new_module)
queue.append(new_module)
def import_patched(module_name, extra_all=lambda mod_name: ()):
def import_patched(module_name,
extra_all=lambda mod_name: (),
after_import_hook=lambda module: None):
"""
Import *module_name* with gevent monkey-patches active,
and return an object holding the greened module as *module*.
......@@ -158,13 +174,17 @@ def import_patched(module_name, extra_all=lambda mod_name: ()):
recursively. The order of ``__all__`` is respected. Anything passed in
*extra_all* (which must be in the same namespace tree) is also imported.
.. versionchanged:: 1.5a4
You must now do all patching for a given module tree
with one call to this method, or at least by using the returned
object.
"""
with cached_platform_architecture():
# Save the current module state, and restore on exit,
# capturing desirable changes in the modules package.
with _SysModulesPatcher(module_name, extra_all) as patcher:
patcher()
patcher = _SysModulesPatcher(module_name, extra_all)
patcher(after_import_hook)
return patcher
......
......@@ -105,25 +105,36 @@ def _patch_dns():
'dns': ('rdata', 'resolver', 'rdtypes'),
'dns.rdtypes': ('IN', 'ANY', ),
'dns.rdtypes.IN': ('A', 'AAAA',),
'dns.rdtypes.ANY': ('SOA',),
'dns.rdtypes.ANY': ('SOA', 'PTR'),
}
def extra_all(mod_name):
return extras.get(mod_name, ())
patcher = importer('dns', extra_all)
def after_import_hook(mod):
# Runs while still in the original patching scope.
# The dns.rdata:get_rdata_class() function tries to
# dynamically import modules using __import__ and then walk
# through the attribute tree to find classes in `dns.rdtypes`.
# It is critical that this all matches up, otherwise we can
# get different exception classes that don't get caught.
# We could patch __import__ to do things at runtime, but it's
# easier to enumerate the world and populate the cache now
# before we then switch the names back.
rdata = mod.rdata
get_rdata_class = rdata.get_rdata_class
for rdclass in mod.rdataclass._by_value:
for rdtype in mod.rdatatype._by_value:
get_rdata_class(rdclass, rdtype)
patcher = importer('dns', extra_all, after_import_hook)
top = patcher.module
def _dns_import_patched(name):
assert name.startswith('dns')
with patcher():
patcher.import_one(name)
return dns
# This module tries to dynamically import classes
# using __import__, and it's important that they match
# the ones we just created, otherwise exceptions won't be caught
# as expected. It uses a one-arg __import__ statement and then
# tries to walk down the sub-modules using getattr, so we can't
# directly use import_patched as-is.
top.rdata.__import__ = _dns_import_patched
# Now disable the dynamic imports
def _no_dynamic_imports(name):
raise ValueError(name)
top.rdata.__import__ = _no_dynamic_imports
return top
......
......@@ -20,8 +20,9 @@ socket.getfqdn()
import gevent.resolver.dnspython
from gevent.resolver.dnspython import dns as gdns
from dns import rdtypes # NOT import dns.rdtypes
assert hasattr(dns, 'rdtypes')
assert gevent.resolver.dnspython.dns is gdns
assert gdns is not dns, (gdns, dns, "id dns", id(dns))
assert gdns.rdtypes is not rdtypes, (gdns.rdtypes, rdtypes)
assert hasattr(dns, 'rdtypes')
print(sorted(sys.modules))
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment