Commit d71ad15e authored by Stefan Behnel's avatar Stefan Behnel

allow yield in async def functions (which turns them into async generators)

parent dee2a0cb
......@@ -9411,6 +9411,7 @@ class YieldExprNode(ExprNode):
label_num = 0
is_yield_from = False
is_await = False
in_async_gen = False
expr_keyword = 'yield'
def analyse_types(self, env):
......@@ -9469,6 +9470,9 @@ class YieldExprNode(ExprNode):
code.putln("/* return from generator, yielding value */")
code.putln("%s->resume_label = %d;" % (
Naming.generator_cname, label_num))
if self.in_async_gen and not self.is_await:
code.putln("return __pyx__PyAsyncGenWrapValue(%s);" % Naming.retval_cname)
else:
code.putln("return %s;" % Naming.retval_cname)
code.put_label(label_name)
......
......@@ -2151,7 +2151,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("%s = PyUnicode_FromStringAndSize(\"\", 0); %s" % (
Naming.empty_unicode, code.error_goto_if_null(Naming.empty_unicode, self.pos)))
for ext_type in ('CyFunction', 'FusedFunction', 'Coroutine', 'Generator', 'StopAsyncIteration'):
for ext_type in ('CyFunction', 'FusedFunction', 'Coroutine', 'Generator', 'AsyncGen', 'StopAsyncIteration'):
code.putln("#ifdef __Pyx_%s_USED" % ext_type)
code.put_error_if_neg(self.pos, "__pyx_%s_init()" % ext_type)
code.putln("#endif")
......
......@@ -3969,6 +3969,8 @@ class GeneratorDefNode(DefNode):
is_generator = True
is_coroutine = False
is_asyncgen = False
gen_type_name = 'Generator'
needs_closure = True
child_attrs = DefNode.child_attrs + ["gbody"]
......@@ -3992,7 +3994,7 @@ class GeneratorDefNode(DefNode):
code.putln('{')
code.putln('__pyx_CoroutineObject *gen = __Pyx_%s_New('
'(__pyx_coroutine_body_t) %s, (PyObject *) %s, %s, %s, %s); %s' % (
'Coroutine' if self.is_coroutine else 'Generator',
self.gen_type_name,
body_cname, Naming.cur_scope_cname, name, qualname, module_name,
code.error_goto_if_null('gen', self.pos)))
code.put_decref(Naming.cur_scope_cname, py_object_type)
......@@ -4007,18 +4009,23 @@ class GeneratorDefNode(DefNode):
code.putln('}')
def generate_function_definitions(self, env, code):
env.use_utility_code(UtilityCode.load_cached(
'Coroutine' if self.is_coroutine else 'Generator', "Coroutine.c"))
env.use_utility_code(UtilityCode.load_cached(self.gen_type_name, "Coroutine.c"))
self.gbody.generate_function_header(code, proto=True)
super(GeneratorDefNode, self).generate_function_definitions(env, code)
self.gbody.generate_function_definitions(env, code)
class AsyncDefNode(GeneratorDefNode):
gen_type_name = 'Coroutine'
is_coroutine = True
class AsyncGenNode(AsyncDefNode):
gen_type_name = 'AsyncGen'
is_asyncgen = True
class GeneratorBodyDefNode(DefNode):
# Main code body of a generator implemented as a DefNode.
#
......
......@@ -52,6 +52,8 @@ cdef class YieldNodeCollector(TreeVisitor):
cdef public list yields
cdef public list returns
cdef public bint has_return_value
cdef public bint has_yield
cdef public bint has_await
cdef class MarkClosureVisitor(CythonTransform):
cdef bint needs_closure
......
......@@ -192,7 +192,7 @@ class PostParse(ScopeTrackingTransform):
# unpack a lambda expression into the corresponding DefNode
collector = YieldNodeCollector()
collector.visitchildren(node.result_expr)
if collector.yields or collector.awaits or isinstance(node.result_expr, ExprNodes.YieldExprNode):
if collector.has_yield or collector.has_await or isinstance(node.result_expr, ExprNodes.YieldExprNode):
body = Nodes.ExprStatNode(
node.result_expr.pos, expr=node.result_expr)
else:
......@@ -2457,19 +2457,22 @@ class YieldNodeCollector(TreeVisitor):
def __init__(self):
super(YieldNodeCollector, self).__init__()
self.yields = []
self.awaits = []
self.returns = []
self.has_return_value = False
self.has_yield = False
self.has_await = False
def visit_Node(self, node):
self.visitchildren(node)
def visit_YieldExprNode(self, node):
self.yields.append(node)
self.has_yield = True
self.visitchildren(node)
def visit_AwaitExprNode(self, node):
self.awaits.append(node)
self.yields.append(node)
self.has_await = True
self.visitchildren(node)
def visit_ReturnStatNode(self, node):
......@@ -2513,24 +2516,27 @@ class MarkClosureVisitor(CythonTransform):
collector.visitchildren(node)
if node.is_async_def:
if collector.yields:
error(collector.yields[0].pos, "'yield' not allowed in async coroutines (use 'await')")
yields = collector.awaits
elif collector.yields:
if collector.awaits:
error(collector.yields[0].pos, "'await' not allowed in generators (use 'yield')")
yields = collector.yields
coroutine_type = Nodes.AsyncGenNode if collector.has_yield else Nodes.AsyncDefNode
if collector.has_yield:
for yield_expr in collector.yields:
yield_expr.in_async_gen = True
elif collector.has_await:
found = next(y for y in collector.yields if y.is_await)
error(found.pos, "'await' not allowed in generators (use 'yield')")
return node
elif collector.has_yield:
coroutine_type = Nodes.GeneratorDefNode
else:
return node
for i, yield_expr in enumerate(yields, 1):
for i, yield_expr in enumerate(collector.yields, 1):
yield_expr.label_num = i
for retnode in collector.returns:
retnode.in_generator = True
gbody = Nodes.GeneratorBodyDefNode(
pos=node.pos, name=node.name, body=node.body)
coroutine = (Nodes.AsyncDefNode if node.is_async_def else Nodes.GeneratorDefNode)(
coroutine = coroutine_type(
pos=node.pos, name=node.name, args=node.args,
star_arg=node.star_arg, starstar_arg=node.starstar_arg,
doc=node.doc, decorators=node.decorators,
......
......@@ -21,6 +21,9 @@ static PyTypeObject *__pyx_AsyncGenType = 0;
static PyObject *__Pyx_AsyncGen_ANext(PyObject *o);
static PyObject *__pyx__PyAsyncGenWrapValue(PyObject *val);
static __pyx_CoroutineObject *__Pyx_AsyncGen_New(
__pyx_coroutine_body_t body, PyObject *closure,
PyObject *name, PyObject *qualname, PyObject *module_name) {
......@@ -679,7 +682,8 @@ static PyObject *
__pyx__PyAsyncGenWrapValue(PyObject *val)
{
__pyx__PyAsyncGenWrappedValue *o;
assert(val);
if (unlikely(!val))
return NULL;
if (__Pyx_ag_value_fl_free) {
__Pyx_ag_value_fl_free--;
......@@ -689,11 +693,12 @@ __pyx__PyAsyncGenWrapValue(PyObject *val)
} else {
o = PyObject_New(__pyx__PyAsyncGenWrappedValue, __pyx__PyAsyncGenWrappedValueType);
if (o == NULL) {
Py_DECREF(val);
return NULL;
}
}
o->val = val;
Py_INCREF(val);
// no Py_INCREF(val) - steals reference!
return (PyObject*)o;
}
......
# mode: error
# tag: pep492, async
async def foo():
yield
_ERRORS = """
5:4: 'yield' not allowed in async coroutines (use 'await')
5:4: 'yield' not supported here
"""
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment