Commit 7d7e4eb1 authored by Stefan Behnel's avatar Stefan Behnel

implement PEP 448 for set/dict literals

parent be598992
......@@ -16,7 +16,7 @@ Features added
* Tracing is supported in ``nogil`` functions/sections and module init code.
* PEP 448 (Additional Unpacking Generalizations) was partially implemented
only for function calls.
for function calls, set and dict literals.
* When generators are used in a Cython module and the module imports the
modules "inspect" and/or "asyncio", Cython enables interoperability by
......
......@@ -297,6 +297,7 @@ class ExprNode(Node):
is_sequence_constructor = False
is_dict_literal = False
is_set_literal = False
is_string_literal = False
is_attribute = False
is_subscript = False
......@@ -5554,10 +5555,6 @@ class MergedDictNode(ExprNode):
code.mark_pos(self.pos)
self.allocate_temp_result(code)
if self.reject_duplicates and len(self.keyword_args) > 1:
code.globalstate.use_utility_code(
UtilityCode.load_cached("RaiseDoubleKeywords", "FunctionArguments.c"))
args = iter(self.keyword_args)
item = next(args)
item.generate_evaluation_code(code)
......@@ -5581,14 +5578,16 @@ class MergedDictNode(ExprNode):
code.putln('}')
item.free_temps(code)
helpers = set()
for item in args:
if item.is_dict_literal:
for arg in item.keyword_args:
for arg in item.key_value_pairs:
arg.generate_evaluation_code(code)
if self.reject_duplicates:
code.putln("if (unlikely(PyDict_Contains(%s, %s))) {" % (
self.result(),
arg.key.py_result()))
helpers.add("RaiseDoubleKeywords")
# FIXME: find out function name at runtime!
code.putln('__Pyx_RaiseDoubleKeywordsError("function", %s); %s' % (
arg.key.py_result(),
......@@ -5605,22 +5604,22 @@ class MergedDictNode(ExprNode):
item.generate_evaluation_code(code)
if self.reject_duplicates:
# merge mapping into kwdict one by one as we need to check for duplicates
code.globalstate.use_utility_code(
UtilityCode.load_cached("MergeKeywords", "FunctionArguments.c"))
helpers.add("MergeKeywords")
code.put_error_if_neg(item.pos, "__Pyx_MergeKeywords(%s, %s)" % (self.result(), item.py_result()))
else:
# simple case, just add all entries
code.globalstate.use_utility_code(
UtilityCode.load_cached("RaiseMappingExpected", "FunctionArguments.c"))
helpers.add("RaiseMappingExpected")
code.putln("if (unlikely(PyDict_Update(%s, %s) < 0)) {" % (self.result(), item.py_result()))
code.putln("if (PyErr_ExceptionMatches(PyExc_AttributeError)) __Pyx_RaiseMappingExpected(%s);" % (
code.putln("if (PyErr_ExceptionMatches(PyExc_AttributeError)) __Pyx_RaiseMappingExpectedError(%s);" % (
item.py_result()))
code.putln(code.error_goto(item.pos))
code.putln("}")
code.putln("}")
item.generate_disposal_code(code)
item.free_temps(code)
for helper in helpers:
code.globalstate.use_utility_code(UtilityCode.load_cached(helper, "FunctionArguments.c"))
def annotate(self, code):
for item in self.keyword_args:
item.annotate(code)
......@@ -7228,13 +7227,122 @@ class InlinedGeneratorExpressionNode(ScopedExprNode):
self.loop.generate_execution_code(code)
class SetNode(ExprNode):
# Set constructor.
class MergedSetNode(ExprNode):
"""
Merge a sequence of iterables into a set.
args [SetNode or other ExprNode]
"""
subexprs = ['args']
type = set_type
is_temp = True
gil_message = "Constructing Python set"
subexprs = ['args']
def calculate_constant_result(self):
result = set()
for item in self.args:
if item.is_sequence_constructor and item.mult_factor:
if item.mult_factor.constant_result <= 0:
continue
if item.is_set_literal or item.is_sequence_constructor:
# process items in order
items = (arg.constant_result for arg in item.args)
else:
items = item.constant_result
result.update(items)
self.constant_result = result
def compile_time_value(self, denv):
result = set()
for item in self.args:
if item.is_sequence_constructor and item.mult_factor:
if item.mult_factor.compile_time_value(denv) <= 0:
continue
if item.is_set_literal or item.is_sequence_constructor:
# process items in order
items = (arg.compile_time_value(denv) for arg in item.args)
else:
items = item.compile_time_value(denv)
try:
result.update(items)
except Exception as e:
self.compile_time_value_error(e)
return result
def type_dependencies(self, env):
return ()
def infer_type(self, env):
return set_type
def analyse_types(self, env):
args = [
arg.analyse_types(env).coerce_to_pyobject(env).as_none_safe_node(
# FIXME: CPython's error message starts with the runtime function name
'argument after * must be an iterable, not NoneType')
for arg in self.args
]
if len(args) == 1 and args[0].type is set_type:
# strip this intermediate node and use the bare set
return args[0]
self.args = args
return self
def may_be_none(self):
return False
def generate_evaluation_code(self, code):
code.mark_pos(self.pos)
self.allocate_temp_result(code)
args = iter(self.args)
item = next(args)
item.generate_evaluation_code(code)
if item.is_set_literal:
code.putln("%s = %s;" % (self.result(), item.py_result()))
item.generate_post_assignment_code(code)
else:
code.putln("%s = PySet_New(%s); %s" % (
self.result(),
item.py_result(),
code.error_goto_if_null(self.result(), self.pos)))
code.put_gotref(self.py_result())
item.generate_disposal_code(code)
item.free_temps(code)
for item in args:
if item.is_set_literal or (item.is_sequence_constructor and not item.mult_factor):
for arg in item.args:
arg.generate_evaluation_code(code)
code.put_error_if_neg(arg.pos, "PySet_Add(%s, %s)" % (
self.result(),
arg.py_result()))
arg.generate_disposal_code(code)
arg.free_temps(code)
continue
item.generate_evaluation_code(code)
code.globalstate.use_utility_code(UtilityCode.load_cached("PySet_Update", "Builtins.c"))
code.put_error_if_neg(item.pos, "__Pyx_PySet_Update(%s, %s)" % (
self.result(),
item.py_result()))
item.generate_disposal_code(code)
item.free_temps(code)
def annotate(self, code):
for item in self.args:
item.annotate(code)
class SetNode(ExprNode):
"""
Set constructor.
"""
subexprs = ['args']
type = set_type
is_set_literal = True
gil_message = "Constructing Python set"
def analyse_types(self, env):
......@@ -7388,12 +7496,10 @@ class DictNode(ExprNode):
self.result(),
code.error_goto_if_null(self.result(), self.pos)))
code.put_gotref(self.py_result())
if self.reject_duplicates and len(self.key_value_pairs) > 1:
code.globalstate.use_utility_code(
UtilityCode.load_cached("RaiseDoubleKeywords", "FunctionArguments.c"))
keys_seen = set()
key_type = None
needs_error_helper = False
for item in self.key_value_pairs:
item.generate_evaluation_code(code)
......@@ -7423,6 +7529,7 @@ class DictNode(ExprNode):
code.putln('if (unlikely(PyDict_Contains(%s, %s))) {' % (
self.result(), key.py_result()))
# currently only used in function calls
needs_error_helper = True
code.putln('__Pyx_RaiseDoubleKeywordsError("function", %s); %s' % (
key.py_result(),
code.error_goto(item.pos)))
......@@ -7444,6 +7551,10 @@ class DictNode(ExprNode):
item.generate_disposal_code(code)
item.free_temps(code)
if needs_error_helper:
code.globalstate.use_utility_code(
UtilityCode.load_cached("RaiseDoubleKeywords", "FunctionArguments.c"))
def annotate(self, code):
for item in self.key_value_pairs:
item.annotate(code)
......
......@@ -973,7 +973,12 @@ def p_comp_if(s, body):
body = p_comp_iter(s, body))],
else_clause = None )
#dictmaker: test ':' test (',' test ':' test)* [',']
# since PEP 448:
#dictorsetmaker: ( ((test ':' test | '**' expr)
# (comp_for | (',' (test ':' test | '**' expr))* [','])) |
# ((test | star_expr)
# (comp_for | (',' (test | star_expr))* [','])) )
def p_dict_or_set_maker(s):
# s.sy == '{'
......@@ -981,57 +986,112 @@ def p_dict_or_set_maker(s):
s.next()
if s.sy == '}':
s.next()
return ExprNodes.DictNode(pos, key_value_pairs = [])
item = p_test(s)
if s.sy == ',' or s.sy == '}':
# set literal
values = [item]
while s.sy == ',':
return ExprNodes.DictNode(pos, key_value_pairs=[])
parts = []
target_type = 0
last_was_simple_item = False
while True:
if s.sy in ('*', '**'):
# merged set/dict literal
if target_type == 0:
target_type = 1 if s.sy == '*' else 2 # 'stars'
elif target_type != len(s.sy):
s.error("unexpected %sitem found in %s literal" % (
s.sy, 'set' if target_type == 1 else 'dict'))
s.next()
item = p_test(s)
parts.append(item)
last_was_simple_item = False
else:
item = p_test(s)
if target_type == 0:
target_type = 2 if s.sy == ':' else 1 # dict vs. set
if target_type == 2:
# dict literal
s.expect(':')
key = item
value = p_test(s)
item = ExprNodes.DictItemNode(key.pos, key=key, value=value)
if last_was_simple_item:
parts[-1].append(item)
else:
parts.append([item])
last_was_simple_item = True
if s.sy == ',':
s.next()
if s.sy == '}':
break
values.append( p_test(s) )
s.expect('}')
return ExprNodes.SetNode(pos, args=values)
elif s.sy == 'for':
# set comprehension
append = ExprNodes.ComprehensionAppendNode(
item.pos, expr=item)
loop = p_comp_for(s, append)
s.expect('}')
return ExprNodes.ComprehensionNode(
pos, loop=loop, append=append, type=Builtin.set_type)
elif s.sy == ':':
# dict literal or comprehension
key = item
s.next()
value = p_test(s)
if s.sy == 'for':
# dict comprehension
append = ExprNodes.DictComprehensionAppendNode(
item.pos, key_expr=key, value_expr=value)
else:
break
if s.sy == 'for':
# dict/set comprehension
if len(parts) == 1 and isinstance(parts[0], list) and len(parts[0]) == 1:
item = parts[0][0]
if target_type == 2:
assert isinstance(item, ExprNodes.DictItemNode), type(item)
comprehension_type = Builtin.dict_type
append = ExprNodes.DictComprehensionAppendNode(
item.pos, key_expr=item.key, value_expr=item.value)
else:
comprehension_type = Builtin.set_type
append = ExprNodes.ComprehensionAppendNode(item.pos, expr=item)
loop = p_comp_for(s, append)
s.expect('}')
return ExprNodes.ComprehensionNode(
pos, loop=loop, append=append, type=Builtin.dict_type)
return ExprNodes.ComprehensionNode(pos, loop=loop, append=append, type=comprehension_type)
else:
# dict literal
items = [ExprNodes.DictItemNode(key.pos, key=key, value=value)]
while s.sy == ',':
s.next()
if s.sy == '}':
break
key = p_test(s)
s.expect(':')
value = p_test(s)
items.append(
ExprNodes.DictItemNode(key.pos, key=key, value=value))
s.expect('}')
return ExprNodes.DictNode(pos, key_value_pairs=items)
else:
# raise an error
s.expect('}')
return ExprNodes.DictNode(pos, key_value_pairs = [])
# syntax error, try to find a good error message
if len(parts) == 1 and not isinstance(parts[0], list):
s.error("iterable unpacking cannot be used in comprehension")
else:
# e.g. "{1,2,3 for ..."
s.expect('}')
return ExprNodes.DictNode(pos, key_value_pairs=[])
s.expect('}')
if target_type == 1:
# (merged) set literal
items = []
set_items = []
for part in parts:
if isinstance(part, list):
set_items.extend(part)
elif part.is_set_literal or part.is_sequence_constructor:
# unpack *{1,2,3} and *[1,2,3] in place
set_items.extend(part.args)
else:
if set_items:
items.append(ExprNodes.SetNode(set_items[0].pos, args=set_items))
set_items = []
items.append(part)
if set_items:
items.append(ExprNodes.SetNode(set_items[0].pos, args=set_items))
if len(items) == 1 and items[0].is_set_literal:
return items[0]
return ExprNodes.MergedSetNode(pos, args=items)
else:
# (merged) dict literal
items = []
dict_items = []
for part in parts:
if isinstance(part, list):
dict_items.extend(part)
elif part.is_dict_literal:
# unpack **{...} in place
dict_items.extend(part.key_value_pairs)
else:
if dict_items:
items.append(ExprNodes.DictNode(dict_items[0].pos, key_value_pairs=dict_items))
dict_items = []
items.append(part)
if dict_items:
items.append(ExprNodes.DictNode(dict_items[0].pos, key_value_pairs=dict_items))
if len(items) == 1 and items[0].is_dict_literal:
return items[0]
return ExprNodes.MergedDictNode(pos, keyword_args=items, reject_duplicates=False)
# NOTE: no longer in Py3 :)
def p_backquote_expr(s):
......
......@@ -441,3 +441,33 @@ static CYTHON_INLINE PyObject* __Pyx_PyFrozenSet_New(PyObject* it) {
return PyObject_Call((PyObject*)&PyFrozenSet_Type, $empty_tuple, NULL);
#endif
}
//////////////////// PySet_Update.proto ////////////////////
static CYTHON_INLINE int __Pyx_PySet_Update(PyObject* set, PyObject* it); /*proto*/
//////////////////// PySet_Update ////////////////////
//@requires: ObjectHandling.c::PyObjectCallMethod1
static CYTHON_INLINE int __Pyx_PySet_Update(PyObject* set, PyObject* it) {
PyObject *retval;
#if CYTHON_COMPILING_IN_CPYTHON
if (PyAnySet_Check(it)) {
// fast and safe case: CPython will update our result set and return it
retval = PySet_Type.tp_as_number->nb_inplace_or(set, it);
if (likely(retval == set)) {
Py_DECREF(retval);
return 0;
}
if (unlikely(!retval))
return -1;
// unusual result, fall through to set.update() call below
Py_DECREF(retval);
}
#endif
retval = __Pyx_PyObject_CallMethod1(set, PYIDENT("update"), it);
if (unlikely(!retval)) return -1;
Py_DECREF(retval);
return 0;
}
cimport cython
class Iter(object):
def __init__(self, it=()):
self.it = iter(it)
def __iter__(self):
return self
def __next__(self):
return next(self.it)
next = __next__
class Map(object):
def __init__(self, mapping={}):
self.mapping = mapping
def __iter__(self):
return iter(self.mapping)
def keys(self):
return self.mapping.keys()
def __getitem__(self, key):
return self.mapping[key]
@cython.test_fail_if_path_exists(
"//SetNode//SetNode",
"//MergedSetNode//SetNode",
"//MergedSetNode//MergedSetNode",
)
def unpack_set_literal():
"""
>>> s = unpack_set_literal()
>>> s == set([1, 2, 4, 5]) or s
True
"""
return {*{1, 2, *{4, 5}}}
def unpack_set_simple(it):
"""
>>> s = unpack_set_simple([])
>>> s == set([]) or s
True
>>> s = unpack_set_simple(set())
>>> s == set([]) or s
True
>>> s = unpack_set_simple(Iter())
>>> s == set([]) or s
True
>>> s = unpack_set_simple([1])
>>> s == set([1]) or s
True
>>> s = unpack_set_simple([2, 1])
>>> s == set([1, 2]) or s
True
>>> s = unpack_set_simple((2, 1))
>>> s == set([1, 2]) or s
True
>>> s = unpack_set_simple(set([2, 1]))
>>> s == set([1, 2]) or s
True
>>> s = unpack_set_simple(Iter([2, 1]))
>>> s == set([1, 2]) or s
True
"""
return {*it}
def unpack_set_from_iterable(it):
"""
>>> s = unpack_set_from_iterable([1, 2, 3])
>>> s == set([1, 2, 3]) or s
True
>>> s = unpack_set_from_iterable([1, 2])
>>> s == set([1, 2]) or s
True
>>> s = unpack_set_from_iterable(set([1, 2]))
>>> s == set([1, 2]) or s
True
>>> s = unpack_set_from_iterable(Iter([1, 2]))
>>> s == set([1, 2]) or s
True
>>> s = unpack_set_from_iterable([3])
>>> s == set([1, 2, 3]) or s
True
>>> s = unpack_set_from_iterable(set([3]))
>>> s == set([1, 2, 3]) or s
True
>>> s = unpack_set_from_iterable(Iter([3]))
>>> s == set([1, 2, 3]) or s
True
>>> s = unpack_set_from_iterable([])
>>> s == set([1, 2]) or s
True
>>> s = unpack_set_from_iterable(set([]))
>>> s == set([1, 2]) or s
True
>>> s = unpack_set_from_iterable([])
>>> s == set([1, 2]) or s
True
>>> s = unpack_set_from_iterable((1, 2, 3))
>>> s == set([1, 2, 3]) or s
True
>>> s = unpack_set_from_iterable(set([1, 2, 3]))
>>> s == set([1, 2, 3]) or s
True
>>> s = unpack_set_from_iterable(Iter([1, 2, 3]))
>>> s == set([1, 2, 3]) or s
True
"""
return {1, 2, *it, 1, *{*it, *it}, *it, 2, 1, *it, *it}
@cython.test_fail_if_path_exists(
"//DictNode//DictNode",
"//MergedDictNode//DictNode",
"//MergedDictNode//MergedDictNode",
)
def unpack_dict_literal():
"""
>>> d = unpack_dict_literal()
>>> d == dict(a=1, b=2, c=4, d=5) or d
True
"""
return {**{'a': 1, 'b': 2, **{'c': 4, 'd': 5}}}
def unpack_dict_simple(it):
"""
>>> d = unpack_dict_simple({})
>>> d == {} or d
True
>>> d = unpack_dict_simple([])
>>> d == {} or d
True
>>> d = unpack_dict_simple(set())
>>> d == {} or d
True
>>> d = unpack_dict_simple(Iter())
>>> d == {} or d
True
>>> d = unpack_dict_simple(Map())
>>> d == {} or d
True
>>> d = unpack_dict_simple(dict(a=1))
>>> d == dict(a=1) or d
True
>>> d = unpack_dict_simple(dict(a=1, b=2))
>>> d == dict(a=1, b=2) or d
True
>>> d = unpack_dict_simple(Map(dict(a=1, b=2)))
>>> d == dict(a=1, b=2) or d
True
"""
return {**it}
def unpack_dict_from_iterable(it):
"""
>>> d = unpack_dict_from_iterable(dict(a=1, b=2, c=3))
>>> d == dict(a=1, b=2, c=3) or d
True
>>> d = unpack_dict_from_iterable(dict(a=1, b=2))
>>> d == dict(a=1, b=2) or d
True
>>> d = unpack_dict_from_iterable(Map(dict(a=1, b=2)))
>>> d == dict(a=1, b=2) or d
True
>>> d = unpack_dict_from_iterable(dict(a=3))
>>> d == dict(a=3, b=5) or d
True
>>> d = unpack_dict_from_iterable(Map(dict(a=3)))
>>> d == dict(a=3, b=5) or d
True
>>> d = unpack_dict_from_iterable({})
>>> d == dict(a=4, b=5) or d
True
>>> d = unpack_dict_from_iterable(Map())
>>> d == dict(a=4, b=5) or d
True
>>> d = unpack_dict_from_iterable(Iter())
Traceback (most recent call last):
TypeError: 'Iter' object is not a mapping
>>> d = unpack_dict_from_iterable([])
Traceback (most recent call last):
TypeError: 'list' object is not a mapping
>>> d = unpack_dict_from_iterable(dict(b=2, c=3))
>>> d == dict(a=4, b=2, c=3) or d
True
>>> d = unpack_dict_from_iterable(Map(dict(b=2, c=3)))
>>> d == dict(a=4, b=2, c=3) or d
True
>>> d = unpack_dict_from_iterable(dict(a=2, c=3))
>>> d == dict(a=2, b=5, c=3) or d
True
>>> d = unpack_dict_from_iterable(Map(dict(a=2, c=3)))
>>> d == dict(a=2, b=5, c=3) or d
True
"""
return {'a': 2, 'b': 3, **it, 'a': 1, **{**it, **it}, **it, 'a': 4, 'b': 5, **it, **it}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment