Commit f9c385e0 authored by Stefan Behnel's avatar Stefan Behnel

refactor analyse_types() and friends to work more like a transform by...

refactor analyse_types() and friends to work more like a transform by returning the node or a replacement
parent 927e1a4d
This diff is collapsed.
...@@ -673,22 +673,22 @@ class FusedCFuncDefNode(StatListNode): ...@@ -673,22 +673,22 @@ class FusedCFuncDefNode(StatListNode):
specialization_type.create_declaration_utility_code(env) specialization_type.create_declaration_utility_code(env)
if self.py_func: if self.py_func:
self.__signatures__.analyse_expressions(env) self.__signatures__ = self.__signatures__.analyse_expressions(env)
self.py_func.analyse_expressions(env) self.py_func = self.py_func.analyse_expressions(env)
self.resulting_fused_function.analyse_expressions(env) self.resulting_fused_function = self.resulting_fused_function.analyse_expressions(env)
self.fused_func_assignment.analyse_expressions(env) self.fused_func_assignment = self.fused_func_assignment.analyse_expressions(env)
self.defaults = defaults = [] self.defaults = defaults = []
for arg in self.node.args: for arg in self.node.args:
if arg.default: if arg.default:
arg.default.analyse_expressions(env) arg.default = arg.default.analyse_expressions(env)
defaults.append(ProxyNode(arg.default)) defaults.append(ProxyNode(arg.default))
else: else:
defaults.append(None) defaults.append(None)
for stat in self.stats: for i, stat in enumerate(self.stats):
stat.analyse_expressions(env) stat = self.stats[i] = stat.analyse_expressions(env)
if isinstance(stat, FuncDefNode): if isinstance(stat, FuncDefNode):
for arg, default in zip(stat.args, defaults): for arg, default in zip(stat.args, defaults):
if default is not None: if default is not None:
...@@ -697,7 +697,7 @@ class FusedCFuncDefNode(StatListNode): ...@@ -697,7 +697,7 @@ class FusedCFuncDefNode(StatListNode):
if self.py_func: if self.py_func:
args = [CloneNode(default) for default in defaults if default] args = [CloneNode(default) for default in defaults if default]
self.defaults_tuple = TupleNode(self.pos, args=args) self.defaults_tuple = TupleNode(self.pos, args=args)
self.defaults_tuple.analyse_types(env, skip_children=True) self.defaults_tuple = self.defaults_tuple.analyse_types(env, skip_children=True)
self.defaults_tuple = ProxyNode(self.defaults_tuple) self.defaults_tuple = ProxyNode(self.defaults_tuple)
self.code_object = ProxyNode(self.specialized_pycfuncs[0].code_object) self.code_object = ProxyNode(self.specialized_pycfuncs[0].code_object)
...@@ -705,10 +705,11 @@ class FusedCFuncDefNode(StatListNode): ...@@ -705,10 +705,11 @@ class FusedCFuncDefNode(StatListNode):
fused_func.defaults_tuple = CloneNode(self.defaults_tuple) fused_func.defaults_tuple = CloneNode(self.defaults_tuple)
fused_func.code_object = CloneNode(self.code_object) fused_func.code_object = CloneNode(self.code_object)
for pycfunc in self.specialized_pycfuncs: for i, pycfunc in enumerate(self.specialized_pycfuncs):
pycfunc.code_object = CloneNode(self.code_object) pycfunc.code_object = CloneNode(self.code_object)
pycfunc.analyse_types(env) pycfunc = self.specialized_pycfuncs[i] = pycfunc.analyse_types(env)
pycfunc.defaults_tuple = CloneNode(self.defaults_tuple) pycfunc.defaults_tuple = CloneNode(self.defaults_tuple)
return self
def synthesize_defnodes(self): def synthesize_defnodes(self):
""" """
......
This diff is collapsed.
...@@ -100,7 +100,7 @@ class IterationTransform(Visitor.EnvTransform): ...@@ -100,7 +100,7 @@ class IterationTransform(Visitor.EnvTransform):
iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2), iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2),
body=if_node, body=if_node,
else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0)))) else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))))
for_loop.analyse_expressions(self.current_env()) for_loop = for_loop.analyse_expressions(self.current_env())
for_loop = self.visit(for_loop) for_loop = self.visit(for_loop)
new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop) new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop)
...@@ -704,7 +704,7 @@ class IterationTransform(Visitor.EnvTransform): ...@@ -704,7 +704,7 @@ class IterationTransform(Visitor.EnvTransform):
dict_temp, dict_len_temp.ref(dict_obj.pos), pos_temp, dict_temp, dict_len_temp.ref(dict_obj.pos), pos_temp,
key_target, value_target, tuple_target, key_target, value_target, tuple_target,
is_dict_temp) is_dict_temp)
iter_next_node.analyse_expressions(self.current_env()) iter_next_node = iter_next_node.analyse_expressions(self.current_env())
body.stats[0:0] = [iter_next_node] body.stats[0:0] = [iter_next_node]
if method: if method:
...@@ -1187,7 +1187,7 @@ class SimplifyCalls(Visitor.EnvTransform): ...@@ -1187,7 +1187,7 @@ class SimplifyCalls(Visitor.EnvTransform):
node.pos, node.pos,
function=node.function, function=node.function,
args=args) args=args)
call_node.analyse_types(self.current_env()) call_node = call_node.analyse_types(self.current_env())
if node.type != call_node.type: if node.type != call_node.type:
call_node = call_node.coerce_to( call_node = call_node.coerce_to(
node.type, self.current_env()) node.type, self.current_env())
......
...@@ -1819,20 +1819,20 @@ class AnalyseExpressionsTransform(CythonTransform): ...@@ -1819,20 +1819,20 @@ class AnalyseExpressionsTransform(CythonTransform):
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
node.scope.infer_types() node.scope.infer_types()
node.body.analyse_expressions(node.scope) node.body = node.body.analyse_expressions(node.scope)
self.visitchildren(node) self.visitchildren(node)
return node return node
def visit_FuncDefNode(self, node): def visit_FuncDefNode(self, node):
node.local_scope.infer_types() node.local_scope.infer_types()
node.body.analyse_expressions(node.local_scope) node.body = node.body.analyse_expressions(node.local_scope)
self.visitchildren(node) self.visitchildren(node)
return node return node
def visit_ScopedExprNode(self, node): def visit_ScopedExprNode(self, node):
if node.has_local_scope: if node.has_local_scope:
node.expr_scope.infer_types() node.expr_scope.infer_types()
node.analyse_scoped_expressions(node.expr_scope) node = node.analyse_scoped_expressions(node.expr_scope)
self.visitchildren(node) self.visitchildren(node)
return node return node
......
...@@ -33,9 +33,11 @@ class TempRefNode(AtomicExprNode): ...@@ -33,9 +33,11 @@ class TempRefNode(AtomicExprNode):
def analyse_types(self, env): def analyse_types(self, env):
assert self.type == self.handle.type assert self.type == self.handle.type
return self
def analyse_target_types(self, env): def analyse_target_types(self, env):
assert self.type == self.handle.type assert self.type == self.handle.type
return self
def analyse_target_declaration(self, env): def analyse_target_declaration(self, env):
pass pass
...@@ -104,7 +106,8 @@ class TempsBlockNode(Node): ...@@ -104,7 +106,8 @@ class TempsBlockNode(Node):
self.body.analyse_declarations(env) self.body.analyse_declarations(env)
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.body.analyse_expressions(env) self.body = self.body.analyse_expressions(env)
return self
def generate_function_definitions(self, env, code): def generate_function_definitions(self, env, code):
self.body.generate_function_definitions(env, code) self.body.generate_function_definitions(env, code)
...@@ -149,6 +152,7 @@ class ResultRefNode(AtomicExprNode): ...@@ -149,6 +152,7 @@ class ResultRefNode(AtomicExprNode):
def analyse_types(self, env): def analyse_types(self, env):
if self.expression is not None: if self.expression is not None:
self.type = self.expression.type self.type = self.expression.type
return self
def infer_type(self, env): def infer_type(self, env):
if self.type is not None: if self.type is not None:
...@@ -263,9 +267,10 @@ class EvalWithTempExprNode(ExprNodes.ExprNode, LetNodeMixin): ...@@ -263,9 +267,10 @@ class EvalWithTempExprNode(ExprNodes.ExprNode, LetNodeMixin):
return self.subexpression.result() return self.subexpression.result()
def analyse_types(self, env): def analyse_types(self, env):
self.temp_expression.analyse_types(env) self.temp_expression = self.temp_expression.analyse_types(env)
self.subexpression.analyse_types(env) self.subexpression = self.subexpression.analyse_types(env)
self.type = self.subexpression.type self.type = self.subexpression.type
return self
def free_subexpr_temps(self, code): def free_subexpr_temps(self, code):
self.subexpression.free_temps(code) self.subexpression.free_temps(code)
...@@ -302,8 +307,9 @@ class LetNode(Nodes.StatNode, LetNodeMixin): ...@@ -302,8 +307,9 @@ class LetNode(Nodes.StatNode, LetNodeMixin):
self.body.analyse_declarations(env) self.body.analyse_declarations(env)
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.temp_expression.analyse_expressions(env) self.temp_expression = self.temp_expression.analyse_expressions(env)
self.body.analyse_expressions(env) self.body = self.body.analyse_expressions(env)
return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
self.setup_temp_expr(code) self.setup_temp_expr(code)
...@@ -335,7 +341,8 @@ class TempResultFromStatNode(ExprNodes.ExprNode): ...@@ -335,7 +341,8 @@ class TempResultFromStatNode(ExprNodes.ExprNode):
self.body.analyse_declarations(env) self.body.analyse_declarations(env)
def analyse_types(self, env): def analyse_types(self, env):
self.body.analyse_expressions(env) self.body = self.body.analyse_expressions(env)
return self
def generate_result_code(self, code): def generate_result_code(self, code):
self.result_ref.result_code = self.result() self.result_ref.result_code = self.result()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment