Commit 8b15b5d4 authored by Stefan Behnel's avatar Stefan Behnel

merge

parents aee166ac af9fb6bc
...@@ -1219,7 +1219,7 @@ class CCodeWriter(object): ...@@ -1219,7 +1219,7 @@ class CCodeWriter(object):
def put_var_decref(self, entry): def put_var_decref(self, entry):
if entry.type.is_pyobject: if entry.type.is_pyobject:
if entry.init_to_none is False: if entry.init_to_none is False: # FIXME: 0 and False are treated differently???
self.putln("__Pyx_XDECREF(%s);" % self.entry_as_pyobject(entry)) self.putln("__Pyx_XDECREF(%s);" % self.entry_as_pyobject(entry))
else: else:
self.putln("__Pyx_DECREF(%s);" % self.entry_as_pyobject(entry)) self.putln("__Pyx_DECREF(%s);" % self.entry_as_pyobject(entry))
......
...@@ -1770,6 +1770,9 @@ class IteratorNode(ExprNode): ...@@ -1770,6 +1770,9 @@ class IteratorNode(ExprNode):
self.type = self.sequence.type self.type = self.sequence.type
else: else:
self.sequence = self.sequence.coerce_to_pyobject(env) self.sequence = self.sequence.coerce_to_pyobject(env)
if self.sequence.type is list_type or \
self.sequence.type is tuple_type:
self.sequence = self.sequence.as_none_safe_node("'NoneType' object is not iterable")
self.is_temp = 1 self.is_temp = 1
gil_message = "Iterating over Python object" gil_message = "Iterating over Python object"
...@@ -1786,28 +1789,22 @@ class IteratorNode(ExprNode): ...@@ -1786,28 +1789,22 @@ class IteratorNode(ExprNode):
raise InternalError("for in carray slice not transformed") raise InternalError("for in carray slice not transformed")
is_builtin_sequence = self.sequence.type is list_type or \ is_builtin_sequence = self.sequence.type is list_type or \
self.sequence.type is tuple_type self.sequence.type is tuple_type
may_be_a_sequence = is_builtin_sequence or not self.sequence.type.is_builtin_type may_be_a_sequence = not self.sequence.type.is_builtin_type
if is_builtin_sequence: if may_be_a_sequence:
code.putln(
"if (likely(%s != Py_None)) {" % self.sequence.py_result())
elif may_be_a_sequence:
code.putln( code.putln(
"if (PyList_CheckExact(%s) || PyTuple_CheckExact(%s)) {" % ( "if (PyList_CheckExact(%s) || PyTuple_CheckExact(%s)) {" % (
self.sequence.py_result(), self.sequence.py_result(),
self.sequence.py_result())) self.sequence.py_result()))
if may_be_a_sequence: if is_builtin_sequence or may_be_a_sequence:
code.putln( code.putln(
"%s = 0; %s = %s; __Pyx_INCREF(%s);" % ( "%s = 0; %s = %s; __Pyx_INCREF(%s);" % (
self.counter_cname, self.counter_cname,
self.result(), self.result(),
self.sequence.py_result(), self.sequence.py_result(),
self.result())) self.result()))
if not is_builtin_sequence:
if may_be_a_sequence:
code.putln("} else {") code.putln("} else {")
if is_builtin_sequence:
code.putln(
'PyErr_SetString(PyExc_TypeError, "\'NoneType\' object is not iterable"); %s' %
code.error_goto(self.pos))
else:
code.putln("%s = -1; %s = PyObject_GetIter(%s); %s" % ( code.putln("%s = -1; %s = PyObject_GetIter(%s); %s" % (
self.counter_cname, self.counter_cname,
self.result(), self.result(),
...@@ -4135,46 +4132,49 @@ class ScopedExprNode(ExprNode): ...@@ -4135,46 +4132,49 @@ class ScopedExprNode(ExprNode):
subexprs = [] subexprs = []
expr_scope = None expr_scope = None
def analyse_types(self, env): # does this node really have a local scope, e.g. does it leak loop
# nothing to do here, the children will be analysed separately # variables or not? non-leaking Py3 behaviour is default, except
# for list comprehensions where the behaviour differs in Py2 and
# Py3 (set in Parsing.py based on parser context)
has_local_scope = True
def init_scope(self, outer_scope, expr_scope=None):
if expr_scope is not None:
self.expr_scope = expr_scope
elif self.has_local_scope:
self.expr_scope = Symtab.GeneratorExpressionScope(outer_scope)
else:
self.expr_scope = None
def analyse_declarations(self, env):
self.init_scope(env)
def analyse_scoped_declarations(self, env):
# this is called with the expr_scope as env
pass pass
def analyse_expressions(self, env): def analyse_types(self, env):
# nothing to do here, the children will be analysed separately # no recursion here, the children will be analysed separately below
pass pass
def analyse_scoped_expressions(self, env): def analyse_scoped_expressions(self, env):
# this is called with the expr_scope as env # this is called with the expr_scope as env
pass pass
def init_scope(self, outer_scope, expr_scope=None):
self.expr_scope = expr_scope
class ComprehensionNode(ScopedExprNode): class ComprehensionNode(ScopedExprNode):
subexprs = ["target"] subexprs = ["target"]
child_attrs = ["loop", "append"] child_attrs = ["loop", "append"]
# leak loop variables or not? non-leaking Py3 behaviour is
# default, except for list comprehensions where the behaviour
# differs in Py2 and Py3 (see Parsing.py)
has_local_scope = True
def infer_type(self, env): def infer_type(self, env):
return self.target.infer_type(env) return self.target.infer_type(env)
def analyse_declarations(self, env): def analyse_declarations(self, env):
self.append.target = self # this is used in the PyList_Append of the inner loop self.append.target = self # this is used in the PyList_Append of the inner loop
self.init_scope(env) self.init_scope(env)
self.loop.analyse_declarations(self.expr_scope or env)
def init_scope(self, outer_scope, expr_scope=None): def analyse_scoped_declarations(self, env):
if expr_scope is not None: self.loop.analyse_declarations(env)
self.expr_scope = expr_scope
elif self.has_local_scope:
self.expr_scope = Symtab.GeneratorExpressionScope(outer_scope)
else:
self.expr_scope = None
def analyse_types(self, env): def analyse_types(self, env):
self.target.analyse_expressions(env) self.target.analyse_expressions(env)
...@@ -4182,9 +4182,6 @@ class ComprehensionNode(ScopedExprNode): ...@@ -4182,9 +4182,6 @@ class ComprehensionNode(ScopedExprNode):
if not self.has_local_scope: if not self.has_local_scope:
self.loop.analyse_expressions(env) self.loop.analyse_expressions(env)
def analyse_expressions(self, env):
self.analyse_types(env)
def analyse_scoped_expressions(self, env): def analyse_scoped_expressions(self, env):
if self.has_local_scope: if self.has_local_scope:
self.loop.analyse_expressions(env) self.loop.analyse_expressions(env)
...@@ -4286,20 +4283,16 @@ class GeneratorExpressionNode(ScopedExprNode): ...@@ -4286,20 +4283,16 @@ class GeneratorExpressionNode(ScopedExprNode):
type = py_object_type type = py_object_type
def analyse_declarations(self, env): def analyse_scoped_declarations(self, env):
self.init_scope(env) self.loop.analyse_declarations(env)
self.loop.analyse_declarations(self.expr_scope)
def init_scope(self, outer_scope, expr_scope=None):
if expr_scope is not None:
self.expr_scope = expr_scope
else:
self.expr_scope = Symtab.GeneratorExpressionScope(outer_scope)
def analyse_types(self, env): def analyse_types(self, env):
if not self.has_local_scope:
self.loop.analyse_expressions(env)
self.is_temp = True self.is_temp = True
def analyse_scoped_expressions(self, env): def analyse_scoped_expressions(self, env):
if self.has_local_scope:
self.loop.analyse_expressions(env) self.loop.analyse_expressions(env)
def may_be_none(self): def may_be_none(self):
...@@ -4320,14 +4313,29 @@ class InlinedGeneratorExpressionNode(GeneratorExpressionNode): ...@@ -4320,14 +4313,29 @@ class InlinedGeneratorExpressionNode(GeneratorExpressionNode):
# orig_func String the name of the builtin function this node replaces # orig_func String the name of the builtin function this node replaces
child_attrs = ["loop"] child_attrs = ["loop"]
loop_analysed = False
def infer_type(self, env):
return self.result_node.infer_type(env)
def analyse_types(self, env): def analyse_types(self, env):
if not self.has_local_scope:
self.loop_analysed = True
self.loop.analyse_expressions(env)
self.type = self.result_node.type self.type = self.result_node.type
self.is_temp = True self.is_temp = True
def analyse_scoped_expressions(self, env):
self.loop_analysed = True
GeneratorExpressionNode.analyse_scoped_expressions(self, env)
def coerce_to(self, dst_type, env): def coerce_to(self, dst_type, env):
if self.orig_func == 'sum' and dst_type.is_numeric: if self.orig_func == 'sum' and dst_type.is_numeric and not self.loop_analysed:
# we can optimise by dropping the aggregation variable into C # We can optimise by dropping the aggregation variable and
# the add operations into C. This can only be done safely
# before analysing the loop body, after that, the result
# reference type will have infected expressions and
# assignments.
self.result_node.type = self.type = dst_type self.result_node.type = self.type = dst_type
return self return self
return GeneratorExpressionNode.coerce_to(self, dst_type, env) return GeneratorExpressionNode.coerce_to(self, dst_type, env)
......
...@@ -1339,14 +1339,26 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1339,14 +1339,26 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
""" """
if len(pos_args) not in (1,2): if len(pos_args) not in (1,2):
return node return node
if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode): if not isinstance(pos_args[0], (ExprNodes.GeneratorExpressionNode,
ExprNodes.ComprehensionNode)):
return node return node
gen_expr_node = pos_args[0] gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop loop_node = gen_expr_node.loop
if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node) yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
if yield_expression is None: if yield_expression is None:
return node return node
else: # ComprehensionNode
yield_stat_node = gen_expr_node.append
yield_expression = yield_stat_node.expr
try:
if not yield_expression.is_literal or not yield_expression.type.is_int:
return node
except AttributeError:
return node # in case we don't have a type yet
# special case: old Py2 backwards compatible "sum([int_const for ...])"
# can safely be unpacked into a genexpr
if len(pos_args) == 1: if len(pos_args) == 1:
start = ExprNodes.IntNode(node.pos, value='0', constant_result=0) start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
...@@ -1375,7 +1387,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1375,7 +1387,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
return ExprNodes.InlinedGeneratorExpressionNode( return ExprNodes.InlinedGeneratorExpressionNode(
gen_expr_node.pos, loop = exec_code, result_node = result_ref, gen_expr_node.pos, loop = exec_code, result_node = result_ref,
expr_scope = gen_expr_node.expr_scope, orig_func = 'sum') expr_scope = gen_expr_node.expr_scope, orig_func = 'sum',
has_local_scope = gen_expr_node.has_local_scope)
def _handle_simple_function_min(self, node, pos_args): def _handle_simple_function_min(self, node, pos_args):
return self._optimise_min_max(node, pos_args, '<') return self._optimise_min_max(node, pos_args, '<')
......
cimport cython
from Cython.Compiler.Visitor cimport ( from Cython.Compiler.Visitor cimport (
CythonTransform, VisitorTransform, CythonTransform, VisitorTransform, TreeVisitor,
ScopeTrackingTransform, EnvTransform) ScopeTrackingTransform, EnvTransform)
#class NameNodeCollector(TreeVisitor): cdef class NameNodeCollector(TreeVisitor):
# cdef list name_nodes cdef list name_nodes
cdef class SkipDeclarations: # (object): cdef class SkipDeclarations: # (object):
pass pass
...@@ -12,6 +14,7 @@ cdef class SkipDeclarations: # (object): ...@@ -12,6 +14,7 @@ cdef class SkipDeclarations: # (object):
cdef class NormalizeTree(CythonTransform): cdef class NormalizeTree(CythonTransform):
cdef bint is_in_statlist cdef bint is_in_statlist
cdef bint is_in_expr cdef bint is_in_expr
cpdef visit_StatNode(self, node, is_listcontainer=*)
cdef class PostParse(ScopeTrackingTransform): cdef class PostParse(ScopeTrackingTransform):
cdef dict specialattribute_handlers cdef dict specialattribute_handlers
...@@ -21,6 +24,7 @@ cdef class PostParse(ScopeTrackingTransform): ...@@ -21,6 +24,7 @@ cdef class PostParse(ScopeTrackingTransform):
#def eliminate_rhs_duplicates(list expr_list_list, list ref_node_sequence) #def eliminate_rhs_duplicates(list expr_list_list, list ref_node_sequence)
#def sort_common_subsequences(list items) #def sort_common_subsequences(list items)
@cython.locals(starred_targets=Py_ssize_t, lhs_size=Py_ssize_t, rhs_size=Py_ssize_t)
cdef flatten_parallel_assignments(list input, list output) cdef flatten_parallel_assignments(list input, list output)
cdef map_starred_assignment(list lhs_targets, list starred_assignments, list lhs_args, list rhs_args) cdef map_starred_assignment(list lhs_targets, list starred_assignments, list lhs_args, list rhs_args)
......
import cython import cython
from cython import set
cython.declare(copy=object, ModuleNode=object, TreeFragment=object, TemplateTransform=object, cython.declare(copy=object, ModuleNode=object, TreeFragment=object, TemplateTransform=object,
EncodedString=object, error=object, warning=object, PyrexTypes=object, Naming=object) EncodedString=object, error=object, warning=object, PyrexTypes=object, Naming=object)
...@@ -26,11 +25,12 @@ class NameNodeCollector(TreeVisitor): ...@@ -26,11 +25,12 @@ class NameNodeCollector(TreeVisitor):
super(NameNodeCollector, self).__init__() super(NameNodeCollector, self).__init__()
self.name_nodes = [] self.name_nodes = []
visit_Node = TreeVisitor.visitchildren
def visit_NameNode(self, node): def visit_NameNode(self, node):
self.name_nodes.append(node) self.name_nodes.append(node)
def visit_Node(self, node):
self._visitchildren(node, None)
class SkipDeclarations(object): class SkipDeclarations(object):
""" """
...@@ -300,7 +300,7 @@ def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence): ...@@ -300,7 +300,7 @@ def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence):
and appends them to ref_node_sequence. The input list is modified and appends them to ref_node_sequence. The input list is modified
in-place. in-place.
""" """
seen_nodes = set() seen_nodes = cython.set()
ref_nodes = {} ref_nodes = {}
def find_duplicates(node): def find_duplicates(node):
if node.is_literal or node.is_name: if node.is_literal or node.is_name:
...@@ -571,16 +571,16 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -571,16 +571,16 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
'operator.comma' : c_binop_constructor(','), 'operator.comma' : c_binop_constructor(','),
} }
special_methods = set(['declare', 'union', 'struct', 'typedef', 'sizeof', special_methods = cython.set(['declare', 'union', 'struct', 'typedef', 'sizeof',
'cast', 'pointer', 'compiled', 'NULL'] 'cast', 'pointer', 'compiled', 'NULL'])
+ list(unop_method_nodes.keys())) special_methods.update(unop_method_nodes.keys())
def __init__(self, context, compilation_directive_defaults): def __init__(self, context, compilation_directive_defaults):
super(InterpretCompilerDirectives, self).__init__(context) super(InterpretCompilerDirectives, self).__init__(context)
self.compilation_directive_defaults = {} self.compilation_directive_defaults = {}
for key, value in compilation_directive_defaults.items(): for key, value in compilation_directive_defaults.items():
self.compilation_directive_defaults[unicode(key)] = value self.compilation_directive_defaults[unicode(key)] = value
self.cython_module_names = set() self.cython_module_names = cython.set()
self.directive_names = {} self.directive_names = {}
def check_directive_scope(self, pos, directive, scope): def check_directive_scope(self, pos, directive, scope):
...@@ -1022,7 +1022,7 @@ property NAME: ...@@ -1022,7 +1022,7 @@ property NAME:
return node return node
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
self.seen_vars_stack.append(set()) self.seen_vars_stack.append(cython.set())
node.analyse_declarations(self.env_stack[-1]) node.analyse_declarations(self.env_stack[-1])
self.visitchildren(node) self.visitchildren(node)
self.seen_vars_stack.pop() self.seen_vars_stack.pop()
...@@ -1054,7 +1054,7 @@ property NAME: ...@@ -1054,7 +1054,7 @@ property NAME:
return node return node
def visit_FuncDefNode(self, node): def visit_FuncDefNode(self, node):
self.seen_vars_stack.append(set()) self.seen_vars_stack.append(cython.set())
lenv = node.local_scope lenv = node.local_scope
node.body.analyse_control_flow(lenv) # this will be totally refactored node.body.analyse_control_flow(lenv) # this will be totally refactored
node.declare_arguments(lenv) node.declare_arguments(lenv)
...@@ -1073,15 +1073,18 @@ property NAME: ...@@ -1073,15 +1073,18 @@ property NAME:
return node return node
def visit_ScopedExprNode(self, node): def visit_ScopedExprNode(self, node):
node.analyse_declarations(self.env_stack[-1]) env = self.env_stack[-1]
node.analyse_declarations(env)
# the node may or may not have a local scope # the node may or may not have a local scope
if node.expr_scope: if node.has_local_scope:
self.seen_vars_stack.append(set(self.seen_vars_stack[-1])) self.seen_vars_stack.append(cython.set(self.seen_vars_stack[-1]))
self.env_stack.append(node.expr_scope) self.env_stack.append(node.expr_scope)
node.analyse_scoped_declarations(node.expr_scope)
self.visitchildren(node) self.visitchildren(node)
self.env_stack.pop() self.env_stack.pop()
self.seen_vars_stack.pop() self.seen_vars_stack.pop()
else: else:
node.analyse_scoped_declarations(env)
self.visitchildren(node) self.visitchildren(node)
return node return node
...@@ -1177,7 +1180,7 @@ class AnalyseExpressionsTransform(CythonTransform): ...@@ -1177,7 +1180,7 @@ class AnalyseExpressionsTransform(CythonTransform):
return node return node
def visit_ScopedExprNode(self, node): def visit_ScopedExprNode(self, node):
if node.expr_scope is not None: if node.has_local_scope:
node.expr_scope.infer_types() node.expr_scope.infer_types()
node.analyse_scoped_expressions(node.expr_scope) node.analyse_scoped_expressions(node.expr_scope)
self.visitchildren(node) self.visitchildren(node)
......
...@@ -141,6 +141,9 @@ class ResultRefNode(AtomicExprNode): ...@@ -141,6 +141,9 @@ class ResultRefNode(AtomicExprNode):
def infer_type(self, env): def infer_type(self, env):
if self.expression is not None: if self.expression is not None:
return self.expression.infer_type(env) return self.expression.infer_type(env)
if self.type is not None:
return self.type
assert False, "cannot infer type of ResultRefNode"
def may_be_none(self): def may_be_none(self):
if not self.type.is_pyobject: if not self.type.is_pyobject:
......
...@@ -149,6 +149,46 @@ def return_typed_sum_squares_start(seq, int start): ...@@ -149,6 +149,46 @@ def return_typed_sum_squares_start(seq, int start):
return <int>sum((i*i for i in seq), start) return <int>sum((i*i for i in seq), start)
@cython.test_assert_path_exists('//ForInStatNode',
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists('//SimpleCallNode')
def return_sum_of_listcomp_consts_start(seq, int start):
"""
>>> sum([1 for i in range(10) if i > 3], -1)
5
>>> return_sum_of_listcomp_consts_start(range(10), -1)
5
>>> print(sum([1 for i in range(10000) if i > 3], 9))
10005
>>> print(return_sum_of_listcomp_consts_start(range(10000), 9))
10005
"""
return sum([1 for i in seq if i > 3], start)
@cython.test_assert_path_exists('//ForInStatNode',
"//InlinedGeneratorExpressionNode",
# the next test is for a deficiency
# (see InlinedGeneratorExpressionNode.coerce_to()),
# hope this breaks one day
"//CoerceFromPyTypeNode//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists('//SimpleCallNode')
def return_typed_sum_of_listcomp_consts_start(seq, int start):
"""
>>> sum([1 for i in range(10) if i > 3], -1)
5
>>> return_typed_sum_of_listcomp_consts_start(range(10), -1)
5
>>> print(sum([1 for i in range(10000) if i > 3], 9))
10005
>>> print(return_typed_sum_of_listcomp_consts_start(range(10000), 9))
10005
"""
return <int>sum([1 for i in seq if i > 3], start)
@cython.test_assert_path_exists( @cython.test_assert_path_exists(
'//ForInStatNode', '//ForInStatNode',
"//InlinedGeneratorExpressionNode") "//InlinedGeneratorExpressionNode")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment