Commit f2e83334 authored by Stefan Behnel's avatar Stefan Behnel

Add Cython.Utils to the list of compiled modules and include a faster...

Add Cython.Utils to the list of compiled modules and include a faster @contextmanager for try-finally cases.
parent 47c1d85d
...@@ -2,7 +2,7 @@ import unittest ...@@ -2,7 +2,7 @@ import unittest
from Cython.Utils import ( from Cython.Utils import (
_CACHE_NAME_PATTERN, _build_cache_name, _find_cache_attributes, _CACHE_NAME_PATTERN, _build_cache_name, _find_cache_attributes,
build_hex_version, cached_method, clear_method_caches) build_hex_version, cached_method, clear_method_caches, try_finally_contextmanager)
METHOD_NAME = "cached_next" METHOD_NAME = "cached_next"
CACHE_NAME = _build_cache_name(METHOD_NAME) CACHE_NAME = _build_cache_name(METHOD_NAME)
...@@ -94,3 +94,35 @@ class TestCythonUtils(unittest.TestCase): ...@@ -94,3 +94,35 @@ class TestCythonUtils(unittest.TestCase):
clear_method_caches(obj) clear_method_caches(obj)
self.set_of_names_equal(obj, {names}) self.set_of_names_equal(obj, {names})
def test_try_finally_contextmanager(self):
states = []
@try_finally_contextmanager
def gen(*args, **kwargs):
states.append("enter")
yield (args, kwargs)
states.append("exit")
with gen(1, 2, 3, x=4) as call_args:
assert states == ["enter"]
self.assertEqual(call_args, ((1, 2, 3), {'x': 4}))
assert states == ["enter", "exit"]
class MyException(RuntimeError):
pass
del states[:]
with self.assertRaises(MyException):
with gen(1, 2, y=4) as call_args:
assert states == ["enter"]
self.assertEqual(call_args, ((1, 2), {'y': 4}))
raise MyException("FAIL INSIDE")
assert states == ["enter", "exit"]
del states[:]
with self.assertRaises(StopIteration):
with gen(1, 2, y=4) as call_args:
assert states == ["enter"]
self.assertEqual(call_args, ((1, 2), {'y': 4}))
raise StopIteration("STOP")
assert states == ["enter", "exit"]
cdef class _TryFinallyGeneratorContextManager:
cdef object _gen
""" """
Cython -- Things that don't belong Cython -- Things that don't belong anywhere else in particular
anywhere else in particular
""" """
from __future__ import absolute_import from __future__ import absolute_import
import cython
cython.declare(
basestring=object,
os=object, sys=object, re=object, io=object, codecs=object, glob=object, shutil=object, tempfile=object,
cython_version=object,
_function_caches=list, _parse_file_version=object, _match_file_encoding=object,
)
try: try:
from __builtin__ import basestring from __builtin__ import basestring
except ImportError: except ImportError:
...@@ -23,7 +31,7 @@ import codecs ...@@ -23,7 +31,7 @@ import codecs
import glob import glob
import shutil import shutil
import tempfile import tempfile
from contextlib import contextmanager from functools import wraps
from . import __version__ as cython_version from . import __version__ as cython_version
...@@ -34,6 +42,31 @@ _CACHE_NAME_PATTERN = re.compile(r"^__(.+)_cache$") ...@@ -34,6 +42,31 @@ _CACHE_NAME_PATTERN = re.compile(r"^__(.+)_cache$")
modification_time = os.path.getmtime modification_time = os.path.getmtime
class _TryFinallyGeneratorContextManager(object):
"""
Fast, bare minimum @contextmanager, only for try-finally, not for exception handling.
"""
def __init__(self, gen):
self._gen = gen
def __enter__(self):
return next(self._gen)
def __exit__(self, exc_type, exc_val, exc_tb):
try:
next(self._gen)
except (StopIteration, GeneratorExit):
pass
def try_finally_contextmanager(gen_func):
@wraps(gen_func)
def make_gen(*args, **kwargs):
return _TryFinallyGeneratorContextManager(gen_func(*args, **kwargs))
return make_gen
_function_caches = [] _function_caches = []
...@@ -47,6 +80,7 @@ def cached_function(f): ...@@ -47,6 +80,7 @@ def cached_function(f):
_function_caches.append(cache) _function_caches.append(cache)
uncomputed = object() uncomputed = object()
@wraps(f)
def wrapper(*args): def wrapper(*args):
res = cache.get(args, uncomputed) res = cache.get(args, uncomputed)
if res is uncomputed: if res is uncomputed:
...@@ -443,7 +477,7 @@ def get_cython_cache_dir(): ...@@ -443,7 +477,7 @@ def get_cython_cache_dir():
return os.path.expanduser(os.path.join('~', '.cython')) return os.path.expanduser(os.path.join('~', '.cython'))
@contextmanager @try_finally_contextmanager
def captured_fd(stream=2, encoding=None): def captured_fd(stream=2, encoding=None):
orig_stream = os.dup(stream) # keep copy of original stream orig_stream = os.dup(stream) # keep copy of original stream
try: try:
...@@ -455,15 +489,14 @@ def captured_fd(stream=2, encoding=None): ...@@ -455,15 +489,14 @@ def captured_fd(stream=2, encoding=None):
return _output[0] return _output[0]
os.dup2(temp_file.fileno(), stream) # replace stream by copy of pipe os.dup2(temp_file.fileno(), stream) # replace stream by copy of pipe
try: def get_output():
def get_output(): result = read_output()
result = read_output() return result.decode(encoding) if encoding else result
return result.decode(encoding) if encoding else result
yield get_output
yield get_output # note: @contextlib.contextmanager requires try-finally here
finally: os.dup2(orig_stream, stream) # restore original stream
os.dup2(orig_stream, stream) # restore original stream read_output() # keep the output in case it's used after closing the context manager
read_output() # keep the output in case it's used after closing the context manager
finally: finally:
os.close(orig_stream) os.close(orig_stream)
...@@ -514,23 +547,6 @@ def print_bytes(s, header_text=None, end=b'\n', file=sys.stdout, flush=True): ...@@ -514,23 +547,6 @@ def print_bytes(s, header_text=None, end=b'\n', file=sys.stdout, flush=True):
out.flush() out.flush()
class LazyStr:
def __init__(self, callback):
self.callback = callback
def __str__(self):
return self.callback()
def __repr__(self):
return self.callback()
def __add__(self, right):
return self.callback() + right
def __radd__(self, left):
return left + self.callback()
class OrderedSet(object): class OrderedSet(object):
def __init__(self, elements=()): def __init__(self, elements=()):
self._list = [] self._list = []
......
...@@ -94,6 +94,7 @@ def compile_cython_modules(profile=False, coverage=False, compile_more=False, cy ...@@ -94,6 +94,7 @@ def compile_cython_modules(profile=False, coverage=False, compile_more=False, cy
"Cython.Compiler.FusedNode", "Cython.Compiler.FusedNode",
"Cython.Tempita._tempita", "Cython.Tempita._tempita",
"Cython.StringIOTree", "Cython.StringIOTree",
"Cython.Utils",
] ]
if compile_more: if compile_more:
compiled_modules.extend([ compiled_modules.extend([
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment