Commit 418bcde3 authored by Stefan Behnel

statically unpack string values in sequence assignments

parent 4cfcb3cb
...@@ -200,6 +200,7 @@ class ExprNode(Node): ...@@ -200,6 +200,7 @@ class ExprNode(Node):
# #
is_sequence_constructor = 0 is_sequence_constructor = 0
is_string_literal = 0
is_attribute = 0 is_attribute = 0
saved_subexpr_nodes = None saved_subexpr_nodes = None
...@@ -960,6 +961,7 @@ class BytesNode(ConstNode): ...@@ -960,6 +961,7 @@ class BytesNode(ConstNode):
# #
# value BytesLiteral # value BytesLiteral
is_string_literal = True
# start off as Python 'bytes' to support len() in O(1) # start off as Python 'bytes' to support len() in O(1)
type = bytes_type type = bytes_type
...@@ -1040,6 +1042,7 @@ class UnicodeNode(PyConstNode): ...@@ -1040,6 +1042,7 @@ class UnicodeNode(PyConstNode):
# value EncodedString # value EncodedString
# bytes_value BytesLiteral the literal parsed as bytes string ('-3' unicode literals only) # bytes_value BytesLiteral the literal parsed as bytes string ('-3' unicode literals only)
is_string_literal = True
bytes_value = None bytes_value = None
type = unicode_type type = unicode_type
...@@ -1104,6 +1107,7 @@ class StringNode(PyConstNode): ...@@ -1104,6 +1107,7 @@ class StringNode(PyConstNode):
# is_identifier boolean # is_identifier boolean
type = str_type type = str_type
is_string_literal = True
is_identifier = None is_identifier = None
unicode_value = None unicode_value = None
...@@ -3680,7 +3684,7 @@ class AttributeNode(ExprNode): ...@@ -3680,7 +3684,7 @@ class AttributeNode(ExprNode):
module_scope = self.obj.analyse_as_module(env) module_scope = self.obj.analyse_as_module(env)
if module_scope: if module_scope:
return module_scope.lookup_type(self.attribute) return module_scope.lookup_type(self.attribute)
if not isinstance(self.obj, (UnicodeNode, StringNode, BytesNode)): if not self.obj.is_string_literal:
base_type = self.obj.analyse_as_type(env) base_type = self.obj.analyse_as_type(env)
if base_type and hasattr(base_type, 'scope') and base_type.scope is not None: if base_type and hasattr(base_type, 'scope') and base_type.scope is not None:
return base_type.scope.lookup_type(self.attribute) return base_type.scope.lookup_type(self.attribute)
...@@ -4811,7 +4815,7 @@ class DictNode(ExprNode): ...@@ -4811,7 +4815,7 @@ class DictNode(ExprNode):
for item in self.key_value_pairs: for item in self.key_value_pairs:
if isinstance(item.key, CoerceToPyTypeNode): if isinstance(item.key, CoerceToPyTypeNode):
item.key = item.key.arg item.key = item.key.arg
if not isinstance(item.key, (UnicodeNode, StringNode, BytesNode)): if not item.key.is_string_literal:
error(item.key.pos, "Invalid struct field identifier") error(item.key.pos, "Invalid struct field identifier")
item.key = StringNode(item.key.pos, value="<error>") item.key = StringNode(item.key.pos, value="<error>")
else: else:
...@@ -6695,11 +6699,9 @@ class CmpNode(object): ...@@ -6695,11 +6699,9 @@ class CmpNode(object):
type1_can_be_int = False type1_can_be_int = False
type2_can_be_int = False type2_can_be_int = False
if isinstance(operand1, (StringNode, BytesNode, UnicodeNode)) \ if operand1.is_string_literal and operand1.can_coerce_to_char_literal():
and operand1.can_coerce_to_char_literal():
type1_can_be_int = True type1_can_be_int = True
if isinstance(operand2, (StringNode, BytesNode, UnicodeNode)) \ if operand2.is_string_literal and operand2.can_coerce_to_char_literal():
and operand2.can_coerce_to_char_literal():
type2_can_be_int = True type2_can_be_int = True
if type1.is_int: if type1.is_int:
......
...@@ -284,7 +284,8 @@ class PostParse(ScopeTrackingTransform): ...@@ -284,7 +284,8 @@ class PostParse(ScopeTrackingTransform):
"""Flatten parallel assignments into separate single """Flatten parallel assignments into separate single
assignments or cascaded assignments. assignments or cascaded assignments.
""" """
if sum([ 1 for expr in expr_list if expr.is_sequence_constructor ]) < 2: if sum([ 1 for expr in expr_list
if expr.is_sequence_constructor or expr.is_string_literal ]) < 2:
# no parallel assignments => nothing to do # no parallel assignments => nothing to do
return node return node
...@@ -412,6 +413,17 @@ def sort_common_subsequences(items): ...@@ -412,6 +413,17 @@ def sort_common_subsequences(items):
items[i] = items[i-1] items[i] = items[i-1]
items[new_pos] = item items[new_pos] = item
def unpack_string_to_character_literals(literal):
    """Split a string literal node into one literal node per character.

    Each item obtained by iterating ``literal.value`` is converted back to
    the class of ``literal.value``; a new node of ``literal``'s own class is
    then created at ``literal.pos`` with that item as both its ``value`` and
    its ``constant_result``.

    Returns the new nodes as a list, in iteration order (an empty list for
    an empty value).
    """
    node_class = literal.__class__
    position = literal.pos
    text = literal.value
    as_value = text.__class__
    return [
        node_class(position, value=char_value, constant_result=char_value)
        for char_value in map(as_value, text)
    ]
def flatten_parallel_assignments(input, output): def flatten_parallel_assignments(input, output):
# The input is a list of expression nodes, representing the LHSs # The input is a list of expression nodes, representing the LHSs
# and RHS of one (possibly cascaded) assignment statement. For # and RHS of one (possibly cascaded) assignment statement. For
...@@ -420,13 +432,21 @@ def flatten_parallel_assignments(input, output): ...@@ -420,13 +432,21 @@ def flatten_parallel_assignments(input, output):
# individual elements. This transformation is applied # individual elements. This transformation is applied
# recursively, so that nested structures get matched as well. # recursively, so that nested structures get matched as well.
rhs = input[-1] rhs = input[-1]
if not rhs.is_sequence_constructor or not sum([lhs.is_sequence_constructor for lhs in input[:-1]]): if (not (rhs.is_sequence_constructor or
(rhs.is_string_literal and not (rhs.type.is_string or
rhs.type is Builtin.bytes_type)))
or not sum([lhs.is_sequence_constructor for lhs in input[:-1]])):
output.append(input) output.append(input)
return return
complete_assignments = [] complete_assignments = []
rhs_size = len(rhs.args) if rhs.is_sequence_constructor:
rhs_args = rhs.args
elif rhs.is_string_literal:
rhs_args = unpack_string_to_character_literals(rhs)
rhs_size = len(rhs_args)
lhs_targets = [ [] for _ in xrange(rhs_size) ] lhs_targets = [ [] for _ in xrange(rhs_size) ]
starred_assignments = [] starred_assignments = []
for lhs in input[:-1]: for lhs in input[:-1]:
...@@ -448,7 +468,7 @@ def flatten_parallel_assignments(input, output): ...@@ -448,7 +468,7 @@ def flatten_parallel_assignments(input, output):
continue continue
elif starred_targets: elif starred_targets:
map_starred_assignment(lhs_targets, starred_assignments, map_starred_assignment(lhs_targets, starred_assignments,
lhs.args, rhs.args) lhs.args, rhs_args)
elif lhs_size < rhs_size: elif lhs_size < rhs_size:
error(lhs.pos, "too many values to unpack (expected %d, got %d)" error(lhs.pos, "too many values to unpack (expected %d, got %d)"
% (lhs_size, rhs_size)) % (lhs_size, rhs_size))
...@@ -463,7 +483,7 @@ def flatten_parallel_assignments(input, output): ...@@ -463,7 +483,7 @@ def flatten_parallel_assignments(input, output):
output.append(complete_assignments) output.append(complete_assignments)
# recursively flatten partial assignments # recursively flatten partial assignments
for cascade, rhs in zip(lhs_targets, rhs.args): for cascade, rhs in zip(lhs_targets, rhs_args):
if cascade: if cascade:
cascade.append(rhs) cascade.append(rhs)
flatten_parallel_assignments(cascade, output) flatten_parallel_assignments(cascade, output)
......
# mode: run
# tag: string, unicode, bytes, sequence unpacking, starexpr
def unpack_single_str():
    """
    >>> print(unpack_single_str())
    a
    """
    # Unpack a one-character str literal into a single-element target list.
    a, = 'a'
    return a
def unpack_str():
    """
    >>> a,b = unpack_str()
    >>> print(a)
    a
    >>> print(b)
    b
    """
    # Unpack a two-character str literal into two targets, one per character.
    a,b = 'ab'
    return a,b
def star_unpack_str():
    """
    >>> a,b,c = star_unpack_str()
    >>> print(a)
    a
    >>> type(b) is list
    True
    >>> print(''.join(b))
    bbb
    >>> print(c)
    c
    """
    # Starred unpacking of a str literal: the first and last characters go to
    # a and c, and the starred target b receives the middle ones (the doctest
    # above checks that b is a list).
    a,*b,c = 'abbbc'
    return a,b,c
def unpack_single_unicode():
    """
    >>> print(unpack_single_unicode())
    a
    """
    # Same as the str case, but with a u'' (unicode) literal on the right.
    a, = u'a'
    return a
def unpack_unicode():
    """
    >>> a,b = unpack_unicode()
    >>> print(a)
    a
    >>> print(b)
    b
    """
    # Unpack a two-character unicode literal into two targets.
    a,b = u'ab'
    return a,b
def star_unpack_unicode():
    """
    >>> a,b,c = star_unpack_unicode()
    >>> print(a)
    a
    >>> type(b) is list
    True
    >>> print(''.join(b))
    bbb
    >>> print(c)
    c
    """
    # Starred unpacking of a unicode literal: b collects the characters
    # between the first (a) and the last (c); the doctest above checks that
    # b is a list.
    a,*b,c = u'abbbc'
    return a,b,c
# the following is not supported due to Py2/Py3 bytes differences
## def unpack_single_bytes():
## """
## >>> print(unpack_single_bytes().decode('ASCII'))
## a
## """
## a, = b'a'
## return a
## def unpack_bytes():
## """
## >>> a,b = unpack_bytes()
## >>> print(a.decode('ASCII'))
## a
## >>> print(b.decode('ASCII'))
## b
## """
## a,b = b'ab'
## return a,b
## def star_unpack_bytes():
## """
## >>> a,b,c = star_unpack_bytes()
## >>> print(a.decode('ASCII'))
## a
## >>> type(b) is list
## True
## >>> print(''.join([ch.decode('ASCII') for ch in b]))
## bbb
## >>> print(c.decode('ASCII'))
## c
## """
## a,*b,c = b'abbbc'
## return a,b,c
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment