Commit 9c5dca1b authored by Stefan Behnel's avatar Stefan Behnel

re-implement Coroutine.__await__() using a dedicated wrapper type to avoid...

re-implement Coroutine.__await__() using a dedicated wrapper type to avoid having to make a Coroutine half an Iterator
parent ef2f3011
......@@ -7,25 +7,33 @@ static CYTHON_INLINE PyObject* __Pyx_Generator_Yield_From(__pyx_CoroutineObject
static CYTHON_INLINE PyObject* __Pyx_Generator_Yield_From(__pyx_CoroutineObject *gen, PyObject *source) {
PyObject *source_gen, *retval;
source_gen = PyObject_GetIter(source);
if (unlikely(!source_gen)) {
#ifdef __Pyx_Coroutine_USED
#if CYTHON_COMPILING_IN_CPYTHON
// avoid exception instantiation if possible
if (PyErr_Occurred() == PyExc_TypeError
#else
if (PyErr_ExceptionMatches(PyExc_TypeError)
#endif
&& __Pyx_Coroutine_CheckExact(source)) {
PyErr_Clear();
#ifdef __Pyx_Coroutine_USED
if (__Pyx_Coroutine_CheckExact(source)) {
// TODO: this should only happen for types.coroutine()ed generators, but we can't determine that here
source_gen = __Pyx_Coroutine_await(source);
Py_INCREF(source);
source_gen = source;
retval = __Pyx_Generator_Next(source);
} else
#endif
#endif
{
#if CYTHON_COMPILING_IN_CPYTHON
if (likely(Py_TYPE(source)->tp_iter)) {
source_gen = Py_TYPE(source)->tp_iter(source);
if (unlikely(!source_gen))
return NULL;
if (unlikely(!PyIter_Check(source_gen))) {
PyErr_Format(PyExc_TypeError,
"iter() returned non-iterator of type '%.100s'",
Py_TYPE(source_gen)->tp_name);
Py_DECREF(source_gen);
return NULL;
}
} else
#endif
source_gen = PyObject_GetIter(source);
// source_gen is now the iterator, make the first next() call
retval = Py_TYPE(source_gen)->tp_iternext(source_gen);
}
if (likely(retval)) {
gen->yieldfrom = source_gen;
return retval;
......@@ -274,13 +282,14 @@ static int __Pyx_PyGen_FetchStopIterationValue(PyObject **pvalue); /*proto*/
#define __Pyx_Coroutine_USED
static PyTypeObject *__pyx_CoroutineType = 0;
static PyTypeObject *__pyx_CoroutineAwaitType = 0;
#define __Pyx_Coroutine_CheckExact(obj) (Py_TYPE(obj) == __pyx_CoroutineType)
#define __Pyx_Coroutine_New(body, closure, name, qualname) \
__Pyx__Coroutine_New(__pyx_CoroutineType, body, closure, name, qualname)
static int __pyx_Coroutine_init(void); /*proto*/
static PyObject *__Pyx_Coroutine_await(PyObject *coroutine); /*proto*/
static PyObject *__Pyx__Coroutine_await(PyObject *coroutine); /*proto*/
//////////////////// Generator.proto ////////////////////
......@@ -961,6 +970,140 @@ static __pyx_CoroutineObject *__Pyx__Coroutine_New(PyTypeObject* type, __pyx_cor
//@requires: CoroutineBase
//@requires: PatchGeneratorABC
typedef struct {
PyObject_HEAD
PyObject *coroutine;
} __pyx_CoroutineAwaitObject;
static void __Pyx_CoroutineAwait_dealloc(PyObject *self) {
#if CYTHON_COMPILING_IN_CPYTHON
_PyObject_GC_UNTRACK(self);
#else
PyObject_GC_UnTrack(self);
#endif
Py_CLEAR(((__pyx_CoroutineAwaitObject*)self)->coroutine);
PyObject_GC_Del(self);
}
static int __Pyx_CoroutineAwait_traverse(__pyx_CoroutineAwaitObject *self, visitproc visit, void *arg) {
Py_VISIT(self->coroutine);
return 0;
}
static int __Pyx_CoroutineAwait_clear(__pyx_CoroutineAwaitObject *self) {
Py_CLEAR(self->coroutine);
return 0;
}
static PyObject *__Pyx_CoroutineAwait_Next(__pyx_CoroutineAwaitObject *self) {
return __Pyx_Generator_Next(self->coroutine);
}
static PyObject *__Pyx_CoroutineAwait_Send(__pyx_CoroutineAwaitObject *self, PyObject *value) {
return __Pyx_Coroutine_Send(self->coroutine, value);
}
static PyObject *__Pyx_CoroutineAwait_Throw(__pyx_CoroutineAwaitObject *self, PyObject *args) {
return __Pyx_Coroutine_Throw(self->coroutine, args);
}
static PyObject *__Pyx_CoroutineAwait_Close(__pyx_CoroutineAwaitObject *self) {
return __Pyx_Coroutine_Close(self->coroutine);
}
static PyObject *__Pyx_CoroutineAwait_self(PyObject *self) {
Py_INCREF(self);
return self;
}
static PyObject *__Pyx_CoroutineAwait_no_new(CYTHON_UNUSED PyTypeObject *type, CYTHON_UNUSED PyObject *args, CYTHON_UNUSED PyObject *kwargs) {
PyErr_SetString(PyExc_TypeError, "cannot instantiate 'coroutine_await' type, call coroutine.__await__() instead");
return NULL;
}
static PyMethodDef __pyx_CoroutineAwait_methods[] = {
{"send", (PyCFunction) __Pyx_CoroutineAwait_Send, METH_O, 0},
{"throw", (PyCFunction) __Pyx_CoroutineAwait_Throw, METH_VARARGS, 0},
{"close", (PyCFunction) __Pyx_CoroutineAwait_Close, METH_NOARGS, 0},
{0, 0, 0, 0}
};
static PyTypeObject __pyx_CoroutineAwaitType_type = {
PyVarObject_HEAD_INIT(0, 0)
"coroutine_await", /*tp_name*/
sizeof(__pyx_CoroutineAwaitObject), /*tp_basicsize*/
0, /*tp_itemsize*/
(destructor) __Pyx_CoroutineAwait_dealloc,/*tp_dealloc*/
0, /*tp_print*/
0, /*tp_getattr*/
0, /*tp_setattr*/
0, /*tp_as_async resp. tp_compare*/
0, /*tp_repr*/
0, /*tp_as_number*/
0, /*tp_as_sequence*/
0, /*tp_as_mapping*/
0, /*tp_hash*/
0, /*tp_call*/
0, /*tp_str*/
0, /*tp_getattro*/
0, /*tp_setattro*/
0, /*tp_as_buffer*/
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, /*tp_flags*/
PyDoc_STR("A wrapper object implementing __await__ for coroutines."), /*tp_doc*/
(traverseproc) __Pyx_CoroutineAwait_traverse, /*tp_traverse*/
(inquiry) __Pyx_CoroutineAwait_clear, /*tp_clear*/
0, /*tp_richcompare*/
0, /*tp_weaklistoffset*/
__Pyx_CoroutineAwait_self, /*tp_iter*/
(iternextfunc) __Pyx_CoroutineAwait_Next, /*tp_iternext*/
__pyx_CoroutineAwait_methods, /*tp_methods*/
0 , /*tp_members*/
0 , /*tp_getset*/
0, /*tp_base*/
0, /*tp_dict*/
0, /*tp_descr_get*/
0, /*tp_descr_set*/
0, /*tp_dictoffset*/
0, /*tp_init*/
0, /*tp_alloc*/
__Pyx_CoroutineAwait_no_new, /*tp_new*/
0, /*tp_free*/
0, /*tp_is_gc*/
0, /*tp_bases*/
0, /*tp_mro*/
0, /*tp_cache*/
0, /*tp_subclasses*/
0, /*tp_weaklist*/
0, /*tp_del*/
0, /*tp_version_tag*/
#if PY_VERSION_HEX >= 0x030400a1
0, /*tp_finalize*/
#endif
};
static CYTHON_INLINE PyObject *__Pyx__Coroutine_await(PyObject *coroutine) {
#if CYTHON_COMPILING_IN_CPYTHON
__pyx_CoroutineAwaitObject *await = PyObject_GC_New(__pyx_CoroutineAwaitObject, __pyx_CoroutineAwaitType);
#else
__pyx_CoroutineAwaitObject *await = __pyx_CoroutineAwaitType->tp_alloc(__pyx_CoroutineAwaitType);
#endif
if (unlikely(!await)) return NULL;
Py_INCREF(coroutine);
await->coroutine = coroutine;
#if CYTHON_COMPILING_IN_CPYTHON
_PyObject_GC_TRACK(await);
#endif
return (PyObject*)await;
}
static PyObject *__Pyx_Coroutine_await(PyObject *coroutine) {
if (unlikely(!coroutine || !__Pyx_Coroutine_CheckExact(coroutine))) {
PyErr_SetString(PyExc_TypeError, "invalid input, expected coroutine");
return NULL;
}
return __Pyx__Coroutine_await(coroutine);
}
static void __Pyx_Coroutine_check_and_dealloc(PyObject *self) {
__pyx_CoroutineObject *gen = (__pyx_CoroutineObject *) self;
......@@ -1023,11 +1166,6 @@ static PyObject *__Pyx_Coroutine_compare(PyObject *obj, PyObject *other, int op)
}
#endif
static PyObject *__Pyx_Coroutine_await(PyObject *self) {
Py_INCREF(self);
return self;
}
static PyMethodDef __pyx_Coroutine_methods[] = {
{"send", (PyCFunction) __Pyx_Coroutine_Send, METH_O, 0},
{"throw", (PyCFunction) __Pyx_Coroutine_Throw, METH_VARARGS, 0},
......@@ -1083,7 +1221,7 @@ static PyTypeObject __pyx_CoroutineType_type = {
offsetof(__pyx_CoroutineObject, gi_weakreflist), /*tp_weaklistoffset*/
// no tp_iter() as iterator is only available through __await__()
0, /*tp_iter*/
(iternextfunc) __Pyx_Generator_Next, /*tp_iternext*/
0, /*tp_iternext*/
__pyx_Coroutine_methods, /*tp_methods*/
__pyx_Coroutine_memberlist, /*tp_members*/
__pyx_Coroutine_getsets, /*tp_getset*/
......@@ -1118,9 +1256,12 @@ static int __pyx_Coroutine_init(void) {
__pyx_CoroutineType_type.tp_getattro = PyObject_GenericGetAttr;
__pyx_CoroutineType = __Pyx_FetchCommonType(&__pyx_CoroutineType_type);
if (unlikely(!__pyx_CoroutineType)) {
if (unlikely(!__pyx_CoroutineType))
return -1;
__pyx_CoroutineAwaitType = __Pyx_FetchCommonType(&__pyx_CoroutineAwaitType_type);
if (unlikely(!__pyx_CoroutineAwaitType))
return -1;
}
return 0;
}
......@@ -1267,6 +1408,7 @@ static PyObject* __Pyx_Coroutine_patch_module(PyObject* module, const char* py_c
globals = PyDict_New(); if (unlikely(!globals)) goto ignore;
#ifdef __Pyx_Coroutine_USED
if (unlikely(PyDict_SetItemString(globals, "_cython_coroutine_type", (PyObject*)__pyx_CoroutineType) < 0)) goto ignore;
if (unlikely(PyDict_SetItemString(globals, "_cython_coroutine_await_type", (PyObject*)__pyx_CoroutineAwaitType) < 0)) goto ignore;
#endif
#ifdef __Pyx_Generator_USED
if (unlikely(PyDict_SetItemString(globals, "_cython_generator_type", (PyObject*)__pyx_GeneratorType) < 0)) goto ignore;
......
# mode: run
# tag: asyncio
# tag: asyncio, pep492
"""
PYTHON setup.py build_ext -i
PYTHON test_from_import.py
PYTHON test_import.py
PYTHON test_async_def.py
PYTHON test_async_def_future.py
PYTHON test_all.py
"""
......@@ -63,6 +64,25 @@ if ASYNCIO_SUPPORTS_COROUTINE:
runloop(async_def.wait3)
######## test_async_def_future.py ########
import sys
ASYNCIO_SUPPORTS_COROUTINE = sys.version_info[:2] >= (3, 5)
if ASYNCIO_SUPPORTS_COROUTINE:
from async_def_future import await_future
import asyncio
def runloop():
loop = asyncio.get_event_loop()
task, events, expected = await_future(loop)
result = loop.run_until_complete(task())
assert events == expected, events
runloop()
######## test_all.py ########
import sys
......@@ -172,3 +192,27 @@ async def wait3():
await asyncio.sleep(0.01)
counter += 1
return counter
######## async_def_future.pyx ########
import asyncio
def await_future(loop):
events = []
async def worker():
fut = asyncio.Future()
def setval():
events.append('setval')
fut.set_result(123)
events.append('setup')
loop.call_later(0.2, setval)
events.append(await fut)
async def test():
await worker()
expected = ['setup', 'setval', 123]
return test, events, expected
......@@ -73,7 +73,7 @@ class AsyncYield:
def run_async(coro):
#assert coro.__class__ is types.GeneratorType
assert coro.__class__.__name__ in ('coroutine', 'GeneratorWrapper')
assert coro.__class__.__name__ in ('coroutine', 'GeneratorWrapper'), coro.__class__.__name__
buffer = []
result = None
......@@ -226,9 +226,8 @@ class CoroutineTest(unittest.TestCase):
with check():
iter(foo())
# in Cython: not iterable, but an iterator ...
#with check():
# next(foo())
with check():
next(foo())
with silence_coro_gc(), check():
for i in foo():
......@@ -306,6 +305,47 @@ class CoroutineTest(unittest.TestCase):
foo()
gc.collect()
def test_func_10(self):
N = 0
@types_coroutine
def gen():
nonlocal N
try:
a = yield
yield (a ** 2)
except ZeroDivisionError:
N += 100
raise
finally:
N += 1
async def foo():
await gen()
coro = foo()
aw = coro.__await__()
self.assertIs(aw, iter(aw))
next(aw)
self.assertEqual(aw.send(10), 100)
with self.assertRaises(TypeError):
type(aw).send(None, None)
self.assertEqual(N, 0)
aw.close()
self.assertEqual(N, 1)
with self.assertRaises(TypeError):
type(aw).close(None)
coro = foo()
aw = coro.__await__()
next(aw)
with self.assertRaises(ZeroDivisionError):
aw.throw(ZeroDivisionError, None, None)
self.assertEqual(N, 102)
with self.assertRaises(TypeError):
type(aw).throw(None, None, None, None)
def test_await_1(self):
async def foo():
......@@ -460,6 +500,20 @@ class CoroutineTest(unittest.TestCase):
run_async(foo())
def test_await_iterator(self):
async def foo():
return 123
coro = foo()
it = coro.__await__()
self.assertEqual(type(it).__name__, 'coroutine_await')
with self.assertRaisesRegex(TypeError, "cannot instantiate 'coroutine_await' type"):
type(it)() # cannot instantiate
with self.assertRaisesRegex(StopIteration, "123"):
next(it)
def test_with_1(self):
class Manager:
def __init__(self, name):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment