Commit 6a30fecf authored by will-ca's avatar will-ca Committed by GitHub

Make `Shadow.inline()` caching account for language version and compilation environment. (GH-3440)

Closes https://github.com/cython/cython/issues/3419
parent 7cc572f5
...@@ -141,6 +141,10 @@ def _populate_unbound(kwds, unbound_symbols, locals=None, globals=None): ...@@ -141,6 +141,10 @@ def _populate_unbound(kwds, unbound_symbols, locals=None, globals=None):
else: else:
print("Couldn't find %r" % symbol) print("Couldn't find %r" % symbol)
def _inline_key(orig_code, arg_sigs, language_level):
key = orig_code, arg_sigs, sys.version_info, sys.executable, language_level, Cython.__version__
return hashlib.sha1(_unicode(key).encode('utf-8')).hexdigest()
def cython_inline(code, get_type=unsafe_type, def cython_inline(code, get_type=unsafe_type,
lib_dir=os.path.join(get_cython_cache_dir(), 'inline'), lib_dir=os.path.join(get_cython_cache_dir(), 'inline'),
cython_include_dirs=None, cython_compiler_directives=None, cython_include_dirs=None, cython_compiler_directives=None,
...@@ -150,13 +154,20 @@ def cython_inline(code, get_type=unsafe_type, ...@@ -150,13 +154,20 @@ def cython_inline(code, get_type=unsafe_type,
get_type = lambda x: 'object' get_type = lambda x: 'object'
ctx = _create_context(tuple(cython_include_dirs)) if cython_include_dirs else _cython_inline_default_context ctx = _create_context(tuple(cython_include_dirs)) if cython_include_dirs else _cython_inline_default_context
cython_compiler_directives = dict(cython_compiler_directives or {})
if language_level is None and 'language_level' not in cython_compiler_directives:
language_level = '3str'
if language_level is not None:
cython_compiler_directives['language_level'] = language_level
# Fast path if this has been called in this session. # Fast path if this has been called in this session.
_unbound_symbols = _cython_inline_cache.get(code) _unbound_symbols = _cython_inline_cache.get(code)
if _unbound_symbols is not None: if _unbound_symbols is not None:
_populate_unbound(kwds, _unbound_symbols, locals, globals) _populate_unbound(kwds, _unbound_symbols, locals, globals)
args = sorted(kwds.items()) args = sorted(kwds.items())
arg_sigs = tuple([(get_type(value, ctx), arg) for arg, value in args]) arg_sigs = tuple([(get_type(value, ctx), arg) for arg, value in args])
invoke = _cython_inline_cache.get((code, arg_sigs)) key_hash = _inline_key(code, arg_sigs, language_level)
invoke = _cython_inline_cache.get((code, arg_sigs, key_hash))
if invoke is not None: if invoke is not None:
arg_list = [arg[1] for arg in args] arg_list = [arg[1] for arg in args]
return invoke(*arg_list) return invoke(*arg_list)
...@@ -177,12 +188,6 @@ def cython_inline(code, get_type=unsafe_type, ...@@ -177,12 +188,6 @@ def cython_inline(code, get_type=unsafe_type,
# Parsing from strings not fully supported (e.g. cimports). # Parsing from strings not fully supported (e.g. cimports).
print("Could not parse code as a string (to extract unbound symbols).") print("Could not parse code as a string (to extract unbound symbols).")
cython_compiler_directives = dict(cython_compiler_directives or {})
if language_level is None and 'language_level' not in cython_compiler_directives:
language_level = '3str'
if language_level is not None:
cython_compiler_directives['language_level'] = language_level
cimports = [] cimports = []
for name, arg in list(kwds.items()): for name, arg in list(kwds.items()):
if arg is cython_module: if arg is cython_module:
...@@ -190,8 +195,8 @@ def cython_inline(code, get_type=unsafe_type, ...@@ -190,8 +195,8 @@ def cython_inline(code, get_type=unsafe_type,
del kwds[name] del kwds[name]
arg_names = sorted(kwds) arg_names = sorted(kwds)
arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names]) arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names])
key = orig_code, arg_sigs, sys.version_info, sys.executable, language_level, Cython.__version__ key_hash = _inline_key(orig_code, arg_sigs, language_level)
module_name = "_cython_inline_" + hashlib.sha1(_unicode(key).encode('utf-8')).hexdigest() module_name = "_cython_inline_" + key_hash
if module_name in sys.modules: if module_name in sys.modules:
module = sys.modules[module_name] module = sys.modules[module_name]
...@@ -258,7 +263,7 @@ def __invoke(%(params)s): ...@@ -258,7 +263,7 @@ def __invoke(%(params)s):
module = load_dynamic(module_name, module_path) module = load_dynamic(module_name, module_path)
_cython_inline_cache[orig_code, arg_sigs] = module.__invoke _cython_inline_cache[orig_code, arg_sigs, key_hash] = module.__invoke
arg_list = [kwds[arg] for arg in arg_names] arg_list = [kwds[arg] for arg in arg_names]
return module.__invoke(*arg_list) return module.__invoke(*arg_list)
......
...@@ -74,6 +74,18 @@ class TestInline(CythonTest): ...@@ -74,6 +74,18 @@ class TestInline(CythonTest):
6 6
) )
def test_lang_version(self):
# GH-3419. Caching for inline code didn't always respect compiler directives.
inline_divcode = "def f(int a, int b): return a/b"
self.assertEqual(
inline(inline_divcode, language_level=2)['f'](5,2),
2
)
self.assertEqual(
inline(inline_divcode, language_level=3)['f'](5,2),
2.5
)
if has_numpy: if has_numpy:
def test_numpy(self): def test_numpy(self):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment