Commit fe0eeeb3 authored by da-woods's avatar da-woods Committed by GitHub

Implement PEP 572: Named/Assignment Expressions (GH-3691)

Closes https://github.com/cython/cython/issues/2636
parent cf88658e
...@@ -615,6 +615,9 @@ class ExprNode(Node): ...@@ -615,6 +615,9 @@ class ExprNode(Node):
def analyse_target_declaration(self, env): def analyse_target_declaration(self, env):
error(self.pos, "Cannot assign to or delete this") error(self.pos, "Cannot assign to or delete this")
def analyse_assignment_expression_target_declaration(self, env):
error(self.pos, "Cannot use anything except a name in an assignment expression")
# ------------- Expression Analysis ---------------- # ------------- Expression Analysis ----------------
def analyse_const_expression(self, env): def analyse_const_expression(self, env):
...@@ -2083,9 +2086,18 @@ class NameNode(AtomicExprNode): ...@@ -2083,9 +2086,18 @@ class NameNode(AtomicExprNode):
return None return None
def analyse_target_declaration(self, env): def analyse_target_declaration(self, env):
return self._analyse_target_declaration(env, is_assignment_expression=False)
def analyse_assignment_expression_target_declaration(self, env):
return self._analyse_target_declaration(env, is_assignment_expression=True)
def _analyse_target_declaration(self, env, is_assignment_expression):
self.is_target = True self.is_target = True
if not self.entry: if not self.entry:
self.entry = env.lookup_here(self.name) if is_assignment_expression:
self.entry = env.lookup_assignment_expression_target(self.name)
else:
self.entry = env.lookup_here(self.name)
if not self.entry and self.annotation is not None: if not self.entry and self.annotation is not None:
# name : type = ... # name : type = ...
self.declare_from_annotation(env, as_target=True) self.declare_from_annotation(env, as_target=True)
...@@ -2096,7 +2108,10 @@ class NameNode(AtomicExprNode): ...@@ -2096,7 +2108,10 @@ class NameNode(AtomicExprNode):
type = unspecified_type type = unspecified_type
else: else:
type = py_object_type type = py_object_type
self.entry = env.declare_var(self.name, type, self.pos) if is_assignment_expression:
self.entry = env.declare_assignment_expression_target(self.name, type, self.pos)
else:
self.entry = env.declare_var(self.name, type, self.pos)
if self.entry.is_declared_generic: if self.entry.is_declared_generic:
self.result_ctype = py_object_type self.result_ctype = py_object_type
if self.entry.as_module: if self.entry.as_module:
...@@ -13715,17 +13730,17 @@ class ProxyNode(CoercionNode): ...@@ -13715,17 +13730,17 @@ class ProxyNode(CoercionNode):
def __init__(self, arg): def __init__(self, arg):
super(ProxyNode, self).__init__(arg) super(ProxyNode, self).__init__(arg)
self.constant_result = arg.constant_result self.constant_result = arg.constant_result
self._proxy_type() self.update_type_and_entry()
def analyse_types(self, env): def analyse_types(self, env):
self.arg = self.arg.analyse_expressions(env) self.arg = self.arg.analyse_expressions(env)
self._proxy_type() self.update_type_and_entry()
return self return self
def infer_type(self, env): def infer_type(self, env):
return self.arg.infer_type(env) return self.arg.infer_type(env)
def _proxy_type(self): def update_type_and_entry(self):
type = getattr(self.arg, 'type', None) type = getattr(self.arg, 'type', None)
if type: if type:
self.type = type self.type = type
...@@ -13989,3 +14004,102 @@ class AnnotationNode(ExprNode): ...@@ -13989,3 +14004,102 @@ class AnnotationNode(ExprNode):
else: else:
warning(annotation.pos, "Unknown type declaration in annotation, ignoring") warning(annotation.pos, "Unknown type declaration in annotation, ignoring")
return base_type, arg_type return base_type, arg_type
class AssignmentExpressionNode(ExprNode):
"""
Also known as a named expression or the walrus operator
Arguments
lhs - NameNode - not stored directly as an attribute of the node
rhs - ExprNode
Attributes
rhs - ExprNode
assignment - SingleAssignmentNode
"""
# subexprs and child_attrs are intentionally different here, because the assignment is not an expression
subexprs = ["rhs"]
child_attrs = ["rhs", "assignment"] # This order is important for control-flow (i.e. xdecref) to be right
is_temp = False
assignment = None
clone_node = None
def __init__(self, pos, lhs, rhs, **kwds):
super(AssignmentExpressionNode, self).__init__(pos, **kwds)
self.rhs = ProxyNode(rhs)
assign_expr_rhs = CloneNode(self.rhs)
self.assignment = SingleAssignmentNode(
pos, lhs=lhs, rhs=assign_expr_rhs, is_assignment_expression=True)
@property
def type(self):
return self.rhs.type
@property
def target_name(self):
return self.assignment.lhs.name
def infer_type(self, env):
return self.rhs.infer_type(env)
def analyse_declarations(self, env):
self.assignment.analyse_declarations(env)
def analyse_types(self, env):
# we're trying to generate code that looks roughly like:
# __pyx_t_1 = rhs
# lhs = __pyx_t_1
# __pyx_t_1
# (plus any reference counting that's needed)
self.rhs = self.rhs.analyse_types(env)
if not self.rhs.arg.is_temp:
if not self.rhs.arg.is_literal:
# for anything but the simplest cases (where it can be used directly)
# we convert rhs to a temp, because CloneNode requires arg to be a temp
self.rhs.arg = self.rhs.arg.coerce_to_temp(env)
else:
# For literals we can optimize by just using the literal twice
#
# We aren't including `self.rhs.is_name` in this optimization
# because that goes wrong for assignment expressions run in
# parallel. e.g. `(a := b) + (b := a + c)`)
# This is a special case of https://github.com/cython/cython/issues/4146
# TODO - once that's fixed general revisit this code and possibly
# use coerce_to_simple
self.assignment.rhs = copy.copy(self.rhs)
# TODO - there's a missed optimization in the code generation stage
# for self.rhs.arg.is_temp: an incref/decref pair can be removed
# (but needs a general mechanism to do that)
self.assignment = self.assignment.analyse_types(env)
return self
def coerce_to(self, dst_type, env):
if dst_type == self.assignment.rhs.type:
# in this quite common case (for example, when both lhs, and self are being coerced to Python)
# we can optimize the coercion out by sharing it between
# this and the assignment
old_rhs_arg = self.rhs.arg
if isinstance(old_rhs_arg, CoerceToTempNode):
old_rhs_arg = old_rhs_arg.arg
rhs_arg = old_rhs_arg.coerce_to(dst_type, env)
if rhs_arg is not old_rhs_arg:
self.rhs.arg = rhs_arg
self.rhs.update_type_and_entry()
# clean up the old coercion node that the assignment has likely generated
if (isinstance(self.assignment.rhs, CoercionNode)
and not isinstance(self.assignment.rhs, CloneNode)):
self.assignment.rhs = self.assignment.rhs.arg
self.assignment.rhs.type = self.assignment.rhs.arg.type
return self
return super(AssignmentExpressionNode, self).coerce_to(dst_type, env)
def calculate_result_code(self):
return self.rhs.result()
def generate_result_code(self, code):
# we have to do this manually because it isn't a subexpression
self.assignment.generate_execution_code(code)
...@@ -590,7 +590,7 @@ def check_definitions(flow, compiler_directives): ...@@ -590,7 +590,7 @@ def check_definitions(flow, compiler_directives):
if (node.allow_null or entry.from_closure if (node.allow_null or entry.from_closure
or entry.is_pyclass_attr or entry.type.is_error): or entry.is_pyclass_attr or entry.type.is_error):
pass # Can be uninitialized here pass # Can be uninitialized here
elif node.cf_is_null: elif node.cf_is_null and not entry.in_closure:
if entry.error_on_uninitialized or ( if entry.error_on_uninitialized or (
Options.error_on_uninitialized and ( Options.error_on_uninitialized and (
entry.type.is_pyobject or entry.type.is_unspecified)): entry.type.is_pyobject or entry.type.is_unspecified)):
...@@ -604,10 +604,12 @@ def check_definitions(flow, compiler_directives): ...@@ -604,10 +604,12 @@ def check_definitions(flow, compiler_directives):
"local variable '%s' referenced before assignment" "local variable '%s' referenced before assignment"
% entry.name) % entry.name)
elif warn_maybe_uninitialized: elif warn_maybe_uninitialized:
msg = "local variable '%s' might be referenced before assignment" % entry.name
if entry.in_closure:
msg += " (maybe initialized inside a closure)"
messages.warning( messages.warning(
node.pos, node.pos,
"local variable '%s' might be referenced before assignment" msg)
% entry.name)
elif Unknown in node.cf_state: elif Unknown in node.cf_state:
# TODO: better cross-closure analysis to know when inner functions # TODO: better cross-closure analysis to know when inner functions
# are being called before a variable is being set, and when # are being called before a variable is being set, and when
......
...@@ -77,7 +77,7 @@ def make_lexicon(): ...@@ -77,7 +77,7 @@ def make_lexicon():
punct = Any(":,;+-*/|&<>=.%`~^?!@") punct = Any(":,;+-*/|&<>=.%`~^?!@")
diphthong = Str("==", "<>", "!=", "<=", ">=", "<<", ">>", "**", "//", diphthong = Str("==", "<>", "!=", "<=", ">=", "<<", ">>", "**", "//",
"+=", "-=", "*=", "/=", "%=", "|=", "^=", "&=", "+=", "-=", "*=", "/=", "%=", "|=", "^=", "&=",
"<<=", ">>=", "**=", "//=", "->", "@=", "&&", "||") "<<=", ">>=", "**=", "//=", "->", "@=", "&&", "||", ':=')
spaces = Rep1(Any(" \t\f")) spaces = Rep1(Any(" \t\f"))
escaped_newline = Str("\\\n") escaped_newline = Str("\\\n")
lineterm = Eol + Opt(Str("\n")) lineterm = Eol + Opt(Str("\n"))
......
...@@ -23,7 +23,7 @@ from . import PyrexTypes ...@@ -23,7 +23,7 @@ from . import PyrexTypes
from . import TypeSlots from . import TypeSlots
from .PyrexTypes import py_object_type, error_type from .PyrexTypes import py_object_type, error_type
from .Symtab import (ModuleScope, LocalScope, ClosureScope, PropertyScope, from .Symtab import (ModuleScope, LocalScope, ClosureScope, PropertyScope,
StructOrUnionScope, PyClassScope, CppClassScope, TemplateScope, StructOrUnionScope, PyClassScope, CppClassScope, TemplateScope, GeneratorExpressionScope,
CppScopedEnumScope, punycodify_name) CppScopedEnumScope, punycodify_name)
from .Code import UtilityCode from .Code import UtilityCode
from .StringEncoding import EncodedString from .StringEncoding import EncodedString
...@@ -1744,6 +1744,7 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1744,6 +1744,7 @@ class FuncDefNode(StatNode, BlockNode):
needs_outer_scope = False needs_outer_scope = False
pymethdef_required = False pymethdef_required = False
is_generator = False is_generator = False
is_generator_expression = False # this can be True alongside is_generator
is_coroutine = False is_coroutine = False
is_asyncgen = False is_asyncgen = False
is_generator_body = False is_generator_body = False
...@@ -1815,7 +1816,8 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1815,7 +1816,8 @@ class FuncDefNode(StatNode, BlockNode):
while genv.is_py_class_scope or genv.is_c_class_scope: while genv.is_py_class_scope or genv.is_c_class_scope:
genv = genv.outer_scope genv = genv.outer_scope
if self.needs_closure: if self.needs_closure:
lenv = ClosureScope(name=self.entry.name, cls = GeneratorExpressionScope if self.is_generator_expression else ClosureScope
lenv = cls(name=self.entry.name,
outer_scope=genv, outer_scope=genv,
parent_scope=env, parent_scope=env,
scope_name=self.entry.cname) scope_name=self.entry.cname)
...@@ -5748,12 +5750,14 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -5748,12 +5750,14 @@ class SingleAssignmentNode(AssignmentNode):
# rhs ExprNode Right hand side # rhs ExprNode Right hand side
# first bool Is this guaranteed the first assignment to lhs? # first bool Is this guaranteed the first assignment to lhs?
# is_overloaded_assignment bool Is this assignment done via an overloaded operator= # is_overloaded_assignment bool Is this assignment done via an overloaded operator=
# is_assignment_expression bool Internally SingleAssignmentNode is used to implement assignment expressions
# exception_check # exception_check
# exception_value # exception_value
child_attrs = ["lhs", "rhs"] child_attrs = ["lhs", "rhs"]
first = False first = False
is_overloaded_assignment = False is_overloaded_assignment = False
is_assignment_expression = False
declaration_only = False declaration_only = False
def analyse_declarations(self, env): def analyse_declarations(self, env):
...@@ -5838,7 +5842,10 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -5838,7 +5842,10 @@ class SingleAssignmentNode(AssignmentNode):
if self.declaration_only: if self.declaration_only:
return return
else: else:
self.lhs.analyse_target_declaration(env) if self.is_assignment_expression:
self.lhs.analyse_assignment_expression_target_declaration(env)
else:
self.lhs.analyse_target_declaration(env)
def analyse_types(self, env, use_temp=0): def analyse_types(self, env, use_temp=0):
from . import ExprNodes from . import ExprNodes
......
...@@ -183,6 +183,8 @@ class PostParse(ScopeTrackingTransform): ...@@ -183,6 +183,8 @@ class PostParse(ScopeTrackingTransform):
Note: Currently Parsing.py does a lot of interpretation and Note: Currently Parsing.py does a lot of interpretation and
reorganization that can be refactored into this transform reorganization that can be refactored into this transform
if a more pure Abstract Syntax Tree is wanted. if a more pure Abstract Syntax Tree is wanted.
- Some invalid uses of := assignment expressions are detected
""" """
def __init__(self, context): def __init__(self, context):
super(PostParse, self).__init__(context) super(PostParse, self).__init__(context)
...@@ -215,7 +217,9 @@ class PostParse(ScopeTrackingTransform): ...@@ -215,7 +217,9 @@ class PostParse(ScopeTrackingTransform):
node.def_node = Nodes.DefNode( node.def_node = Nodes.DefNode(
node.pos, name=node.name, doc=None, node.pos, name=node.name, doc=None,
args=[], star_arg=None, starstar_arg=None, args=[], star_arg=None, starstar_arg=None,
body=node.loop, is_async_def=collector.has_await) body=node.loop, is_async_def=collector.has_await,
is_generator_expression=True)
_AssignmentExpressionChecker.do_checks(node.loop, scope_is_class=self.scope_type in ("pyclass", "cclass"))
self.visitchildren(node) self.visitchildren(node)
return node return node
...@@ -226,6 +230,7 @@ class PostParse(ScopeTrackingTransform): ...@@ -226,6 +230,7 @@ class PostParse(ScopeTrackingTransform):
collector.visitchildren(node.loop) collector.visitchildren(node.loop)
if collector.has_await: if collector.has_await:
node.has_local_scope = True node.has_local_scope = True
_AssignmentExpressionChecker.do_checks(node.loop, scope_is_class=self.scope_type in ("pyclass", "cclass"))
self.visitchildren(node) self.visitchildren(node)
return node return node
...@@ -378,6 +383,124 @@ class PostParse(ScopeTrackingTransform): ...@@ -378,6 +383,124 @@ class PostParse(ScopeTrackingTransform):
self.visitchildren(node) self.visitchildren(node)
return node return node
class _AssignmentExpressionTargetNameFinder(TreeVisitor):
def __init__(self):
super(_AssignmentExpressionTargetNameFinder, self).__init__()
self.target_names = {}
def find_target_names(self, target):
if target.is_name:
return [target.name]
elif target.is_sequence_constructor:
names = []
for arg in target.args:
names.extend(self.find_target_names(arg))
return names
# other targets are possible, but it isn't necessary to investigate them here
return []
def visit_ForInStatNode(self, node):
self.target_names[node] = tuple(self.find_target_names(node.target))
self.visitchildren(node)
def visit_ComprehensionNode(self, node):
pass # don't recurse into nested comprehensions
def visit_LambdaNode(self, node):
pass # don't recurse into nested lambdas/generator expressions
def visit_Node(self, node):
self.visitchildren(node)
class _AssignmentExpressionChecker(TreeVisitor):
"""
Enforces rules on AssignmentExpressions within generator expressions and comprehensions
"""
def __init__(self, loop_node, scope_is_class):
super(_AssignmentExpressionChecker, self).__init__()
target_name_finder = _AssignmentExpressionTargetNameFinder()
target_name_finder.visit(loop_node)
self.target_names_dict = target_name_finder.target_names
self.in_iterator = False
self.in_nested_generator = False
self.scope_is_class = scope_is_class
self.current_target_names = ()
self.all_target_names = set()
for names in self.target_names_dict.values():
self.all_target_names.update(names)
def _reset_state(self):
old_state = (self.in_iterator, self.in_nested_generator, self.scope_is_class, self.all_target_names, self.current_target_names)
# note: not resetting self.in_iterator here, see visit_LambdaNode() below
self.in_nested_generator = False
self.scope_is_class = False
self.current_target_names = ()
self.all_target_names = set()
return old_state
def _set_state(self, old_state):
self.in_iterator, self.in_nested_generator, self.scope_is_class, self.all_target_names, self.current_target_names = old_state
@classmethod
def do_checks(cls, loop_node, scope_is_class):
checker = cls(loop_node, scope_is_class)
checker.visit(loop_node)
def visit_ForInStatNode(self, node):
if self.in_nested_generator:
self.visitchildren(node) # once nested, don't do anything special
return
current_target_names = self.current_target_names
target_name = self.target_names_dict.get(node, None)
if target_name:
self.current_target_names += target_name
self.in_iterator = True
self.visit(node.iterator)
self.in_iterator = False
self.visitchildren(node, exclude=("iterator",))
self.current_target_names = current_target_names
def visit_AssignmentExpressionNode(self, node):
if self.in_iterator:
error(node.pos, "assignment expression cannot be used in a comprehension iterable expression")
if self.scope_is_class:
error(node.pos, "assignment expression within a comprehension cannot be used in a class body")
if node.target_name in self.current_target_names:
error(node.pos, "assignment expression cannot rebind comprehension iteration variable '%s'" %
node.target_name)
elif node.target_name in self.all_target_names:
error(node.pos, "comprehension inner loop cannot rebind assignment expression target '%s'" %
node.target_name)
def visit_LambdaNode(self, node):
# Don't reset "in_iterator" - an assignment expression in a lambda in an
# iterator is explicitly tested by the Python testcases and banned.
old_state = self._reset_state()
# the lambda node's "def_node" is not set up at this point, so we need to recurse into it explicitly.
self.visit(node.result_expr)
self._set_state(old_state)
def visit_ComprehensionNode(self, node):
in_nested_generator = self.in_nested_generator
self.in_nested_generator = True
self.visitchildren(node)
self.in_nested_generator = in_nested_generator
def visit_GeneratorExpressionNode(self, node):
in_nested_generator = self.in_nested_generator
self.in_nested_generator = True
# def_node isn't set up yet, so we need to visit the loop directly.
self.visit(node.loop)
self.in_nested_generator = in_nested_generator
def visit_Node(self, node):
self.visitchildren(node)
def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence): def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence):
"""Replace rhs items by LetRefNodes if they appear more than once. """Replace rhs items by LetRefNodes if they appear more than once.
...@@ -2269,6 +2392,11 @@ if VALUE is not None: ...@@ -2269,6 +2392,11 @@ if VALUE is not None:
property.doc = entry.doc property.doc = entry.doc
return property return property
def visit_AssignmentExpressionNode(self, node):
self.visitchildren(node)
node.analyse_declarations(self.current_env())
return node
class CalculateQualifiedNamesTransform(EnvTransform): class CalculateQualifiedNamesTransform(EnvTransform):
""" """
...@@ -2806,7 +2934,8 @@ class MarkClosureVisitor(CythonTransform): ...@@ -2806,7 +2934,8 @@ class MarkClosureVisitor(CythonTransform):
star_arg=node.star_arg, starstar_arg=node.starstar_arg, star_arg=node.star_arg, starstar_arg=node.starstar_arg,
doc=node.doc, decorators=node.decorators, doc=node.doc, decorators=node.decorators,
gbody=gbody, lambda_name=node.lambda_name, gbody=gbody, lambda_name=node.lambda_name,
return_type_annotation=node.return_type_annotation) return_type_annotation=node.return_type_annotation,
is_generator_expression=node.is_generator_expression)
return coroutine return coroutine
def visit_CFuncDefNode(self, node): def visit_CFuncDefNode(self, node):
......
...@@ -23,14 +23,15 @@ cdef tuple p_binop_operator(PyrexScanner s) ...@@ -23,14 +23,15 @@ cdef tuple p_binop_operator(PyrexScanner s)
cdef p_binop_expr(PyrexScanner s, ops, p_sub_expr_func p_sub_expr) cdef p_binop_expr(PyrexScanner s, ops, p_sub_expr_func p_sub_expr)
cdef p_lambdef(PyrexScanner s, bint allow_conditional=*) cdef p_lambdef(PyrexScanner s, bint allow_conditional=*)
cdef p_lambdef_nocond(PyrexScanner s) cdef p_lambdef_nocond(PyrexScanner s)
cdef p_test(PyrexScanner s) cdef p_test(PyrexScanner s, bint allow_assignment_expression=*)
cdef p_test_nocond(PyrexScanner s) cdef p_test_nocond(PyrexScanner s, bint allow_assignment_expression=*)
cdef p_walrus_test(PyrexScanner s, bint allow_assignment_expression=*)
cdef p_or_test(PyrexScanner s) cdef p_or_test(PyrexScanner s)
cdef p_rassoc_binop_expr(PyrexScanner s, unicode op, p_sub_expr_func p_subexpr) cdef p_rassoc_binop_expr(PyrexScanner s, unicode op, p_sub_expr_func p_subexpr)
cdef p_and_test(PyrexScanner s) cdef p_and_test(PyrexScanner s)
cdef p_not_test(PyrexScanner s) cdef p_not_test(PyrexScanner s)
cdef p_comparison(PyrexScanner s) cdef p_comparison(PyrexScanner s)
cdef p_test_or_starred_expr(PyrexScanner s) cdef p_test_or_starred_expr(PyrexScanner s, bint is_expression=*)
cdef p_starred_expr(PyrexScanner s) cdef p_starred_expr(PyrexScanner s)
cdef p_cascaded_cmp(PyrexScanner s) cdef p_cascaded_cmp(PyrexScanner s)
cdef p_cmp_op(PyrexScanner s) cdef p_cmp_op(PyrexScanner s)
...@@ -86,7 +87,7 @@ cdef p_simple_expr_list(PyrexScanner s, expr=*) ...@@ -86,7 +87,7 @@ cdef p_simple_expr_list(PyrexScanner s, expr=*)
cdef p_test_or_starred_expr_list(PyrexScanner s, expr=*) cdef p_test_or_starred_expr_list(PyrexScanner s, expr=*)
cdef p_testlist(PyrexScanner s) cdef p_testlist(PyrexScanner s)
cdef p_testlist_star_expr(PyrexScanner s) cdef p_testlist_star_expr(PyrexScanner s)
cdef p_testlist_comp(PyrexScanner s) cdef p_testlist_comp(PyrexScanner s, bint is_expression=*)
cdef p_genexp(PyrexScanner s, expr) cdef p_genexp(PyrexScanner s, expr)
#------------------------------------------------------- #-------------------------------------------------------
......
...@@ -120,9 +120,9 @@ def p_lambdef(s, allow_conditional=True): ...@@ -120,9 +120,9 @@ def p_lambdef(s, allow_conditional=True):
s, terminator=':', annotated=False) s, terminator=':', annotated=False)
s.expect(':') s.expect(':')
if allow_conditional: if allow_conditional:
expr = p_test(s) expr = p_test(s, allow_assignment_expression=False)
else: else:
expr = p_test_nocond(s) expr = p_test_nocond(s, allow_assignment_expression=False)
return ExprNodes.LambdaNode( return ExprNodes.LambdaNode(
pos, args = args, pos, args = args,
star_arg = star_arg, starstar_arg = starstar_arg, star_arg = star_arg, starstar_arg = starstar_arg,
...@@ -135,14 +135,16 @@ def p_lambdef_nocond(s): ...@@ -135,14 +135,16 @@ def p_lambdef_nocond(s):
#test: or_test ['if' or_test 'else' test] | lambdef #test: or_test ['if' or_test 'else' test] | lambdef
def p_test(s): def p_test(s, allow_assignment_expression=True):
if s.sy == 'lambda': if s.sy == 'lambda':
return p_lambdef(s) return p_lambdef(s)
pos = s.position() pos = s.position()
expr = p_or_test(s) expr = p_walrus_test(s, allow_assignment_expression)
if s.sy == 'if': if s.sy == 'if':
s.next() s.next()
test = p_or_test(s) # Assignment expressions are always allowed here
# even if they wouldn't be allowed in the expression as a whole.
test = p_walrus_test(s)
s.expect('else') s.expect('else')
other = p_test(s) other = p_test(s)
return ExprNodes.CondExprNode(pos, test=test, true_val=expr, false_val=other) return ExprNodes.CondExprNode(pos, test=test, true_val=expr, false_val=other)
...@@ -151,11 +153,26 @@ def p_test(s): ...@@ -151,11 +153,26 @@ def p_test(s):
#test_nocond: or_test | lambdef_nocond #test_nocond: or_test | lambdef_nocond
def p_test_nocond(s): def p_test_nocond(s, allow_assignment_expression=True):
if s.sy == 'lambda': if s.sy == 'lambda':
return p_lambdef_nocond(s) return p_lambdef_nocond(s)
else: else:
return p_or_test(s) return p_walrus_test(s, allow_assignment_expression)
# walrurus_test: IDENT := test | or_test
def p_walrus_test(s, allow_assignment_expression=True):
lhs = p_or_test(s)
if s.sy == ':=':
position = s.position()
if not allow_assignment_expression:
s.error("invalid syntax: assignment expression not allowed in this context")
elif not lhs.is_name:
s.error("Left-hand side of assignment expression must be an identifier")
s.next()
rhs = p_test(s)
return ExprNodes.AssignmentExpressionNode(position, lhs=lhs, rhs=rhs)
return lhs
#or_test: and_test ('or' and_test)* #or_test: and_test ('or' and_test)*
...@@ -210,11 +227,11 @@ def p_comparison(s): ...@@ -210,11 +227,11 @@ def p_comparison(s):
n1.cascade = p_cascaded_cmp(s) n1.cascade = p_cascaded_cmp(s)
return n1 return n1
def p_test_or_starred_expr(s): def p_test_or_starred_expr(s, is_expression=False):
if s.sy == '*': if s.sy == '*':
return p_starred_expr(s) return p_starred_expr(s)
else: else:
return p_test(s) return p_test(s, allow_assignment_expression=is_expression)
def p_starred_expr(s): def p_starred_expr(s):
pos = s.position() pos = s.position()
...@@ -497,7 +514,7 @@ def p_call_parse_args(s, allow_genexp=True): ...@@ -497,7 +514,7 @@ def p_call_parse_args(s, allow_genexp=True):
encoded_name = s.context.intern_ustring(arg.name) encoded_name = s.context.intern_ustring(arg.name)
keyword = ExprNodes.IdentifierStringNode( keyword = ExprNodes.IdentifierStringNode(
arg.pos, value=encoded_name) arg.pos, value=encoded_name)
arg = p_test(s) arg = p_test(s, allow_assignment_expression=False)
keyword_args.append((keyword, arg)) keyword_args.append((keyword, arg))
else: else:
if keyword_args: if keyword_args:
...@@ -675,7 +692,7 @@ def p_atom(s): ...@@ -675,7 +692,7 @@ def p_atom(s):
elif s.sy == 'yield': elif s.sy == 'yield':
result = p_yield_expression(s) result = p_yield_expression(s)
else: else:
result = p_testlist_comp(s) result = p_testlist_comp(s, is_expression=True)
s.expect(')') s.expect(')')
return result return result
elif sy == '[': elif sy == '[':
...@@ -1259,7 +1276,7 @@ def p_list_maker(s): ...@@ -1259,7 +1276,7 @@ def p_list_maker(s):
s.expect(']') s.expect(']')
return ExprNodes.ListNode(pos, args=[]) return ExprNodes.ListNode(pos, args=[])
expr = p_test_or_starred_expr(s) expr = p_test_or_starred_expr(s, is_expression=True)
if s.sy in ('for', 'async'): if s.sy in ('for', 'async'):
if expr.is_starred: if expr.is_starred:
s.error("iterable unpacking cannot be used in comprehension") s.error("iterable unpacking cannot be used in comprehension")
...@@ -1459,7 +1476,7 @@ def p_simple_expr_list(s, expr=None): ...@@ -1459,7 +1476,7 @@ def p_simple_expr_list(s, expr=None):
def p_test_or_starred_expr_list(s, expr=None): def p_test_or_starred_expr_list(s, expr=None):
exprs = expr is not None and [expr] or [] exprs = expr is not None and [expr] or []
while s.sy not in expr_terminators: while s.sy not in expr_terminators:
exprs.append(p_test_or_starred_expr(s)) exprs.append(p_test_or_starred_expr(s, is_expression=(expr is not None)))
if s.sy != ',': if s.sy != ',':
break break
s.next() s.next()
...@@ -1492,9 +1509,9 @@ def p_testlist_star_expr(s): ...@@ -1492,9 +1509,9 @@ def p_testlist_star_expr(s):
# testlist_comp: (test|star_expr) ( comp_for | (',' (test|star_expr))* [','] ) # testlist_comp: (test|star_expr) ( comp_for | (',' (test|star_expr))* [','] )
def p_testlist_comp(s): def p_testlist_comp(s, is_expression=False):
pos = s.position() pos = s.position()
expr = p_test_or_starred_expr(s) expr = p_test_or_starred_expr(s, is_expression)
if s.sy == ',': if s.sy == ',':
s.next() s.next()
exprs = p_test_or_starred_expr_list(s, expr) exprs = p_test_or_starred_expr_list(s, expr)
...@@ -3073,11 +3090,11 @@ def p_c_arg_decl(s, ctx, in_pyfunc, cmethod_flag = 0, nonempty = 0, ...@@ -3073,11 +3090,11 @@ def p_c_arg_decl(s, ctx, in_pyfunc, cmethod_flag = 0, nonempty = 0,
default = ExprNodes.NoneNode(pos) default = ExprNodes.NoneNode(pos)
s.next() s.next()
elif 'inline' in ctx.modifiers: elif 'inline' in ctx.modifiers:
default = p_test(s) default = p_test(s, allow_assignment_expression=False)
else: else:
error(pos, "default values cannot be specified in pxd files, use ? or *") error(pos, "default values cannot be specified in pxd files, use ? or *")
else: else:
default = p_test(s) default = p_test(s, allow_assignment_expression=False)
return Nodes.CArgDeclNode(pos, return Nodes.CArgDeclNode(pos,
base_type = base_type, base_type = base_type,
declarator = declarator, declarator = declarator,
...@@ -3955,5 +3972,5 @@ def p_annotation(s): ...@@ -3955,5 +3972,5 @@ def p_annotation(s):
then it is not a bug. then it is not a bug.
""" """
pos = s.position() pos = s.position()
expr = p_test(s) expr = p_test(s, allow_assignment_expression=False)
return ExprNodes.AnnotationNode(pos, expr=expr) return ExprNodes.AnnotationNode(pos, expr=expr)
...@@ -331,6 +331,7 @@ class Scope(object): ...@@ -331,6 +331,7 @@ class Scope(object):
# is_py_class_scope boolean Is a Python class scope # is_py_class_scope boolean Is a Python class scope
# is_c_class_scope boolean Is an extension type scope # is_c_class_scope boolean Is an extension type scope
# is_closure_scope boolean Is a closure scope # is_closure_scope boolean Is a closure scope
# is_generator_expression_scope boolean A subset of closure scope used for generator expressions
# is_passthrough boolean Outer scope is passed directly # is_passthrough boolean Outer scope is passed directly
# is_cpp_class_scope boolean Is a C++ class scope # is_cpp_class_scope boolean Is a C++ class scope
# is_property_scope boolean Is a extension type property scope # is_property_scope boolean Is a extension type property scope
...@@ -347,6 +348,7 @@ class Scope(object): ...@@ -347,6 +348,7 @@ class Scope(object):
is_py_class_scope = 0 is_py_class_scope = 0
is_c_class_scope = 0 is_c_class_scope = 0
is_closure_scope = 0 is_closure_scope = 0
is_generator_expression_scope = 0
is_comprehension_scope = 0 is_comprehension_scope = 0
is_passthrough = 0 is_passthrough = 0
is_cpp_class_scope = 0 is_cpp_class_scope = 0
...@@ -748,6 +750,11 @@ class Scope(object): ...@@ -748,6 +750,11 @@ class Scope(object):
entry.used = 1 entry.used = 1
return entry return entry
def declare_assignment_expression_target(self, name, type, pos):
# In most cases declares the variable as normal.
# For generator expressions and comprehensions the variable is declared in their parent
return self.declare_var(name, type, pos)
def declare_builtin(self, name, pos): def declare_builtin(self, name, pos):
name = self.mangle_class_private_name(name) name = self.mangle_class_private_name(name)
return self.outer_scope.declare_builtin(name, pos) return self.outer_scope.declare_builtin(name, pos)
...@@ -974,6 +981,11 @@ class Scope(object): ...@@ -974,6 +981,11 @@ class Scope(object):
def lookup_here_unmangled(self, name): def lookup_here_unmangled(self, name):
return self.entries.get(name, None) return self.entries.get(name, None)
def lookup_assignment_expression_target(self, name):
# For most cases behaves like "lookup_here".
# However, it does look outwards for comprehension and generator expression scopes
return self.lookup_here(name)
def lookup_target(self, name): def lookup_target(self, name):
# Look up name in this scope only. Declare as Python # Look up name in this scope only. Declare as Python
# variable if not found. # variable if not found.
...@@ -1893,6 +1905,13 @@ class LocalScope(Scope): ...@@ -1893,6 +1905,13 @@ class LocalScope(Scope):
if entry is None or not entry.from_closure: if entry is None or not entry.from_closure:
error(pos, "no binding for nonlocal '%s' found" % name) error(pos, "no binding for nonlocal '%s' found" % name)
def _create_inner_entry_for_closure(self, name, entry):
entry.in_closure = True
inner_entry = InnerEntry(entry, self)
inner_entry.is_variable = True
self.entries[name] = inner_entry
return inner_entry
def lookup(self, name): def lookup(self, name):
# Look up name in this scope or an enclosing one. # Look up name in this scope or an enclosing one.
# Return None if not found. # Return None if not found.
...@@ -1907,11 +1926,7 @@ class LocalScope(Scope): ...@@ -1907,11 +1926,7 @@ class LocalScope(Scope):
raise InternalError("lookup() after scope class created.") raise InternalError("lookup() after scope class created.")
# The actual c fragment for the different scopes differs # The actual c fragment for the different scopes differs
# on the outside and inside, so we make a new entry # on the outside and inside, so we make a new entry
entry.in_closure = True return self._create_inner_entry_for_closure(name, entry)
inner_entry = InnerEntry(entry, self)
inner_entry.is_variable = True
self.entries[name] = inner_entry
return inner_entry
return entry return entry
def mangle_closure_cnames(self, outer_scope_cname): def mangle_closure_cnames(self, outer_scope_cname):
...@@ -1981,6 +1996,10 @@ class ComprehensionScope(Scope): ...@@ -1981,6 +1996,10 @@ class ComprehensionScope(Scope):
self.entries[name] = entry self.entries[name] = entry
return entry return entry
def declare_assignment_expression_target(self, name, type, pos):
# should be declared in the parent scope instead
return self.parent_scope.declare_var(name, type, pos)
def declare_pyfunction(self, name, pos, allow_redefine=False): def declare_pyfunction(self, name, pos, allow_redefine=False):
return self.outer_scope.declare_pyfunction( return self.outer_scope.declare_pyfunction(
name, pos, allow_redefine) name, pos, allow_redefine)
...@@ -1991,6 +2010,12 @@ class ComprehensionScope(Scope): ...@@ -1991,6 +2010,12 @@ class ComprehensionScope(Scope):
def add_lambda_def(self, def_node): def add_lambda_def(self, def_node):
return self.outer_scope.add_lambda_def(def_node) return self.outer_scope.add_lambda_def(def_node)
def lookup_assignment_expression_target(self, name):
entry = self.lookup_here(name)
if not entry:
entry = self.parent_scope.lookup_assignment_expression_target(name)
return entry
class ClosureScope(LocalScope): class ClosureScope(LocalScope):
...@@ -2012,6 +2037,25 @@ class ClosureScope(LocalScope): ...@@ -2012,6 +2037,25 @@ class ClosureScope(LocalScope):
def declare_pyfunction(self, name, pos, allow_redefine=False): def declare_pyfunction(self, name, pos, allow_redefine=False):
return LocalScope.declare_pyfunction(self, name, pos, allow_redefine, visibility='private') return LocalScope.declare_pyfunction(self, name, pos, allow_redefine, visibility='private')
def declare_assignment_expression_target(self, name, type, pos):
return self.declare_var(name, type, pos)
class GeneratorExpressionScope(ClosureScope):
is_generator_expression_scope = True
def declare_assignment_expression_target(self, name, type, pos):
entry = self.parent_scope.declare_var(name, type, pos)
return self._create_inner_entry_for_closure(name, entry)
def lookup_assignment_expression_target(self, name):
entry = self.lookup_here(name)
if not entry:
entry = self.parent_scope.lookup_assignment_expression_target(name)
if entry:
return self._create_inner_entry_for_closure(name, entry)
return entry
class StructOrUnionScope(Scope): class StructOrUnionScope(Scope):
# Namespace of a C struct or union. # Namespace of a C struct or union.
......
# mode: run
# tag: pure3.8
# These are extra tests for the assignment expression/walrus operator/named expression that cover things
# additional to the standard Python test-suite in tests/run/test_named_expressions.pyx
import cython
import sys
@cython.test_assert_path_exists("//PythonCapiCallNode")
def optimized(x):
"""
x*2 is optimized to a PythonCapiCallNode. The test fails unless the CloneNode is kept up-to-date
(in the event that the optimization changes and test_assert_path_exists fails, the thing to do
is to find another case that's similarly optimized - the test isn't specifically interested in
multiplication)
>>> optimized(5)
10
"""
return (x:=x*2)
# FIXME: currently broken; GH-4146
# Changing x in the assignment expression should not affect the value used on the right-hand side
#def order(x):
# """
# >>> order(5)
# 15
# """
# return x+(x:=x*2)
@cython.test_fail_if_path_exists("//CloneNode")
def optimize_literals1():
"""
There's a small optimization for literals to avoid creating unnecessary temps
>>> optimize_literals1()
10
"""
x = 5
return (x := 10)
@cython.test_fail_if_path_exists("//CloneNode")
def optimize_literals2():
"""
There's a small optimization for literals to avoid creating unnecessary temps
Test is in __doc__ (for Py2 string formatting reasons)
"""
x = 5
return (x := u"a string")
@cython.test_fail_if_path_exists("//CloneNode")
def optimize_literals3():
"""
There's a small optimization for literals to avoid creating unnecessary temps
Test is in __doc__ (for Py2 string formatting reasons)
"""
x = 5
return (x := b"a bytes")
@cython.test_fail_if_path_exists("//CloneNode")
def optimize_literals4():
"""
There's a small optimization for literals to avoid creating unnecessary temps
Test is in __doc__ (for Py2 string formatting reasons)
"""
x = 5
return (x := (u"tuple", 1, 1.0, b"stuff"))
if sys.version_info[0] != 2:
__doc__ = """
>>> optimize_literals2()
'a string'
>>> optimize_literals3()
b'a bytes'
>>> optimize_literals4()
('tuple', 1, 1.0, b'stuff')
"""
else:
__doc__ = """
>>> optimize_literals2()
u'a string'
>>> optimize_literals3()
'a bytes'
>>> optimize_literals4()
(u'tuple', 1, 1.0, 'stuff')
"""
@cython.test_fail_if_path_exists("//CoerceToPyTypeNode//AssignmentExpressionNode")
def avoid_extra_coercion(x : cython.double):
"""
The assignment expression and x are both coerced to PyObject - this should happen only once
rather than to both separately
>>> avoid_extra_coercion(5.)
5.0
"""
y : object = "I'm an object"
return (y := x)
async def async_func():
"""
DW doesn't understand async functions well enough to make it a runtime test, but it was causing
a compile-time failure at one point
"""
if variable := 1:
pass
y_global = 6
class InLambdaInClass:
"""
>>> InLambdaInClass.x1
12
>>> InLambdaInClass.x2
[12, 12]
"""
x1 = (lambda y_global: (y_global := y_global + 1) + y_global)(2) + y_global
x2 = [(lambda y_global: (y_global := y_global + 1) + y_global)(2) + y_global for _ in range(2) ]
def in_lambda_in_list_comprehension1():
"""
>>> in_lambda_in_list_comprehension1()
[[0, 2, 4, 6], [0, 2, 4, 6], [0, 2, 4, 6], [0, 2, 4, 6], [0, 2, 4, 6]]
"""
return [ (lambda x: [(x := y) + x for y in range(4)])(x) for x in range(5) ]
def in_lambda_in_list_comprehension2():
"""
>>> in_lambda_in_list_comprehension2()
[[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]]
"""
return [ (lambda z: [(x := y) + z for y in range(4)])(x) for x in range(5) ]
def in_lambda_in_generator_expression1():
"""
>>> in_lambda_in_generator_expression1()
[(0, 2, 4, 6), (0, 2, 4, 6), (0, 2, 4, 6), (0, 2, 4, 6), (0, 2, 4, 6)]
"""
return [ (lambda x: tuple((x := y) + x for y in range(4)))(x) for x in range(5) ]
def in_lambda_in_generator_expression2():
"""
>>> in_lambda_in_generator_expression2()
[(0, 1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5), (3, 4, 5, 6), (4, 5, 6, 7)]
"""
return [ (lambda z: tuple((x := y) + z for y in range(4)))(x) for x in range(5) ]
...@@ -1173,11 +1173,10 @@ non-important content ...@@ -1173,11 +1173,10 @@ non-important content
self.assertEqual(f'{0!=1}', 'True') self.assertEqual(f'{0!=1}', 'True')
self.assertEqual(f'{0<=1}', 'True') self.assertEqual(f'{0<=1}', 'True')
self.assertEqual(f'{0>=1}', 'False') self.assertEqual(f'{0>=1}', 'False')
# Walrus not implemented yet, skip self.assertEqual(f'{(x:="5")}', '5')
# self.assertEqual(f'{(x:="5")}', '5') self.assertEqual(x, '5')
# self.assertEqual(x, '5') self.assertEqual(f'{(x:=5)}', '5')
# self.assertEqual(f'{(x:=5)}', '5') self.assertEqual(x, 5)
# self.assertEqual(x, 5)
self.assertEqual(f'{"="}', '=') self.assertEqual(f'{"="}', '=')
x = 20 x = 20
...@@ -1239,13 +1238,9 @@ non-important content ...@@ -1239,13 +1238,9 @@ non-important content
# spec of '=10'. # spec of '=10'.
self.assertEqual(f'{x:=10}', ' 20') self.assertEqual(f'{x:=10}', ' 20')
# Note to anyone going to enable these: please have a look to the test
# above this one for more walrus cases to enable.
"""
# This is an assignment expression, which requires parens. # This is an assignment expression, which requires parens.
self.assertEqual(f'{(x:=10)}', '10') self.assertEqual(f'{(x:=10)}', '10')
self.assertEqual(x, 10) self.assertEqual(x, 10)
"""
def test_invalid_syntax_error_message(self): def test_invalid_syntax_error_message(self):
# with self.assertRaisesRegex(SyntaxError, "f-string: invalid syntax"): # with self.assertRaisesRegex(SyntaxError, "f-string: invalid syntax"):
......
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment