Commit a8d3d9a0 authored by Vitja Makarov's avatar Vitja Makarov

Closure optimization

parent d5ceb489
...@@ -4838,10 +4838,14 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin): ...@@ -4838,10 +4838,14 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
class InnerFunctionNode(PyCFunctionNode): class InnerFunctionNode(PyCFunctionNode):
# Special PyCFunctionNode that depends on a closure class # Special PyCFunctionNode that depends on a closure class
# #
binding = True binding = True
needs_self_code = True
def self_result_code(self): def self_result_code(self):
return "((PyObject*)%s)" % Naming.cur_scope_cname if self.needs_self_code:
return "((PyObject*)%s)" % (Naming.cur_scope_cname)
return "NULL"
class LambdaNode(InnerFunctionNode): class LambdaNode(InnerFunctionNode):
# Lambda expression node (only used as a function reference) # Lambda expression node (only used as a function reference)
...@@ -4859,7 +4863,6 @@ class LambdaNode(InnerFunctionNode): ...@@ -4859,7 +4863,6 @@ class LambdaNode(InnerFunctionNode):
name = StringEncoding.EncodedString('<lambda>') name = StringEncoding.EncodedString('<lambda>')
def analyse_declarations(self, env): def analyse_declarations(self, env):
#self.def_node.needs_closure = self.needs_closure
self.def_node.analyse_declarations(env) self.def_node.analyse_declarations(env)
self.pymethdef_cname = self.def_node.entry.pymethdef_cname self.pymethdef_cname = self.def_node.entry.pymethdef_cname
env.add_lambda_def(self.def_node) env.add_lambda_def(self.def_node)
......
...@@ -134,7 +134,6 @@ class Context(object): ...@@ -134,7 +134,6 @@ class Context(object):
WithTransform(self), WithTransform(self),
DecoratorTransform(self), DecoratorTransform(self),
AnalyseDeclarationsTransform(self), AnalyseDeclarationsTransform(self),
CreateClosureClasses(self),
AutoTestDictTransform(self), AutoTestDictTransform(self),
EmbedSignature(self), EmbedSignature(self),
EarlyReplaceBuiltinCalls(self), ## Necessary? EarlyReplaceBuiltinCalls(self), ## Necessary?
...@@ -144,6 +143,7 @@ class Context(object): ...@@ -144,6 +143,7 @@ class Context(object):
IntroduceBufferAuxiliaryVars(self), IntroduceBufferAuxiliaryVars(self),
_check_c_declarations, _check_c_declarations,
AnalyseExpressionsTransform(self), AnalyseExpressionsTransform(self),
CreateClosureClasses(self), ## After all lookups and type inference
ExpandInplaceOperators(self), ExpandInplaceOperators(self),
OptimizeBuiltinCalls(self), ## Necessary? OptimizeBuiltinCalls(self), ## Necessary?
IterationTransform(), IterationTransform(),
......
...@@ -1146,11 +1146,13 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1146,11 +1146,13 @@ class FuncDefNode(StatNode, BlockNode):
# #filename string C name of filename string const # #filename string C name of filename string const
# entry Symtab.Entry # entry Symtab.Entry
# needs_closure boolean Whether or not this function has inner functions/classes/yield # needs_closure boolean Whether or not this function has inner functions/classes/yield
# needs_outer_scope boolean Whether or not this function requires outer scope
# directive_locals { string : NameNode } locals defined by cython.locals(...) # directive_locals { string : NameNode } locals defined by cython.locals(...)
py_func = None py_func = None
assmt = None assmt = None
needs_closure = False needs_closure = False
needs_outer_scope = False
modifiers = [] modifiers = []
def analyse_default_values(self, env): def analyse_default_values(self, env):
...@@ -1198,7 +1200,7 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1198,7 +1200,7 @@ class FuncDefNode(StatNode, BlockNode):
import Buffer import Buffer
lenv = self.local_scope lenv = self.local_scope
if lenv.is_closure_scope: if lenv.is_closure_scope and not lenv.is_passthrough:
outer_scope_cname = "%s->%s" % (Naming.cur_scope_cname, outer_scope_cname = "%s->%s" % (Naming.cur_scope_cname,
Naming.outer_scope_cname) Naming.outer_scope_cname)
else: else:
...@@ -1259,10 +1261,13 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1259,10 +1261,13 @@ class FuncDefNode(StatNode, BlockNode):
cenv = env cenv = env
while cenv.is_py_class_scope or cenv.is_c_class_scope: while cenv.is_py_class_scope or cenv.is_c_class_scope:
cenv = cenv.outer_scope cenv = cenv.outer_scope
if lenv.is_closure_scope: if self.needs_closure:
code.put(lenv.scope_class.type.declaration_code(Naming.cur_scope_cname)) code.put(lenv.scope_class.type.declaration_code(Naming.cur_scope_cname))
code.putln(";") code.putln(";")
elif cenv.is_closure_scope: elif self.needs_outer_scope:
if lenv.is_passthrough:
code.put(lenv.scope_class.type.declaration_code(Naming.cur_scope_cname))
code.putln(";")
code.put(cenv.scope_class.type.declaration_code(Naming.outer_scope_cname)) code.put(cenv.scope_class.type.declaration_code(Naming.outer_scope_cname))
code.putln(";") code.putln(";")
self.generate_argument_declarations(lenv, code) self.generate_argument_declarations(lenv, code)
...@@ -1314,12 +1319,14 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1314,12 +1319,14 @@ class FuncDefNode(StatNode, BlockNode):
code.putln("}") code.putln("}")
code.put_gotref(Naming.cur_scope_cname) code.put_gotref(Naming.cur_scope_cname)
# Note that it is unsafe to decref the scope at this point. # Note that it is unsafe to decref the scope at this point.
if cenv.is_closure_scope: if self.needs_outer_scope:
code.putln("%s = (%s)%s;" % ( code.putln("%s = (%s)%s;" % (
outer_scope_cname, outer_scope_cname,
cenv.scope_class.type.declaration_code(''), cenv.scope_class.type.declaration_code(''),
Naming.self_cname)) Naming.self_cname))
if self.needs_closure: if lenv.is_passthrough:
code.putln("%s = %s;" % (Naming.cur_scope_cname, outer_scope_cname));
elif self.needs_closure:
# inner closures own a reference to their outer parent # inner closures own a reference to their outer parent
code.put_incref(outer_scope_cname, cenv.scope_class.type) code.put_incref(outer_scope_cname, cenv.scope_class.type)
code.put_giveref(outer_scope_cname) code.put_giveref(outer_scope_cname)
......
...@@ -1317,16 +1317,58 @@ class MarkClosureVisitor(CythonTransform): ...@@ -1317,16 +1317,58 @@ class MarkClosureVisitor(CythonTransform):
class CreateClosureClasses(CythonTransform): class CreateClosureClasses(CythonTransform):
# Output closure classes in module scope for all functions # Output closure classes in module scope for all functions
# that need it. # that really need it.
def __init__(self, context):
super(CreateClosureClasses, self).__init__(context)
self.path = []
self.in_lambda = False
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
self.module_scope = node.scope self.module_scope = node.scope
self.visitchildren(node) self.visitchildren(node)
return node return node
def create_class_from_scope(self, node, target_module_scope): def get_scope_use(self, node):
as_name = '%s_%s' % (target_module_scope.next_id(Naming.closure_class_prefix), node.entry.cname) from_closure = []
in_closure = []
for name, entry in node.local_scope.entries.items():
if entry.from_closure:
from_closure.append((name, entry))
elif entry.in_closure and not entry.from_closure:
in_closure.append((name, entry))
return from_closure, in_closure
def create_class_from_scope(self, node, target_module_scope, inner_node=None):
from_closure, in_closure = self.get_scope_use(node)
in_closure.sort()
# Now from the begining
node.needs_closure = False
node.needs_outer_scope = False
func_scope = node.local_scope func_scope = node.local_scope
cscope = node.entry.scope
while cscope.is_py_class_scope or cscope.is_c_class_scope:
cscope = cscope.outer_scope
if not from_closure and self.path:
if not inner_node:
if not node.assmt:
raise InternalError, "DefNode does not have assignment node"
inner_node = node.assmt.rhs
inner_node.needs_self_code = False
node.needs_outer_scope = False
# Simple cases
if not in_closure and not from_closure:
return
elif not in_closure:
func_scope.is_passthrough = True
func_scope.scope_class = cscope.scope_class
node.needs_outer_scope = True
return
as_name = '%s_%s' % (target_module_scope.next_id(Naming.closure_class_prefix), node.entry.cname)
entry = target_module_scope.declare_c_class(name = as_name, entry = target_module_scope.declare_c_class(name = as_name,
pos = node.pos, defining = True, implementing = True) pos = node.pos, defining = True, implementing = True)
...@@ -1335,34 +1377,41 @@ class CreateClosureClasses(CythonTransform): ...@@ -1335,34 +1377,41 @@ class CreateClosureClasses(CythonTransform):
class_scope.is_internal = True class_scope.is_internal = True
class_scope.directives = {'final': True} class_scope.directives = {'final': True}
cscope = node.entry.scope if from_closure:
while cscope.is_py_class_scope or cscope.is_c_class_scope: assert cscope.is_closure_scope
cscope = cscope.outer_scope
if cscope.is_closure_scope:
class_scope.declare_var(pos=node.pos, class_scope.declare_var(pos=node.pos,
name=Naming.outer_scope_cname, # this could conflict? name=Naming.outer_scope_cname,
cname=Naming.outer_scope_cname, cname=Naming.outer_scope_cname,
type=cscope.scope_class.type, type=cscope.scope_class.type,
is_cdef=True) is_cdef=True)
entries = func_scope.entries.items() node.needs_outer_scope = True
entries.sort() for name, entry in in_closure:
for name, entry in entries:
# This is wasteful--we should do this later when we know
# which vars are actually being used inside...
#
# Also, this happens before type inference and type
# analysis, so the entries created here may end up having
# incorrect or at least unspecified types.
class_scope.declare_var(pos=entry.pos, class_scope.declare_var(pos=entry.pos,
name=entry.name, name=entry.name,
cname=entry.cname, cname=entry.cname,
type=entry.type, type=entry.type,
is_cdef=True) is_cdef=True)
node.needs_closure = True
# Do it here because other classes are already checked
target_module_scope.check_c_class(func_scope.scope_class)
def visit_LambdaNode(self, node):
was_in_lambda = self.in_lambda
self.in_lambda = True
self.create_class_from_scope(node.def_node, self.module_scope, node)
self.visitchildren(node)
self.in_lambda = was_in_lambda
return node
def visit_FuncDefNode(self, node): def visit_FuncDefNode(self, node):
if node.needs_closure: if self.in_lambda:
self.visitchildren(node)
return node
if node.needs_closure or self.path:
self.create_class_from_scope(node, self.module_scope) self.create_class_from_scope(node, self.module_scope)
self.path.append(node)
self.visitchildren(node) self.visitchildren(node)
self.path.pop()
return node return node
......
...@@ -211,7 +211,8 @@ class Scope(object): ...@@ -211,7 +211,8 @@ class Scope(object):
# return_type PyrexType or None Return type of function owning scope # return_type PyrexType or None Return type of function owning scope
# is_py_class_scope boolean Is a Python class scope # is_py_class_scope boolean Is a Python class scope
# is_c_class_scope boolean Is an extension type scope # is_c_class_scope boolean Is an extension type scope
# is_closure_scope boolean # is_closure_scope boolean Is a closure scope
# is_passthrough boolean Outer scope is passed directly
# is_cpp_class_scope boolean Is a C++ class scope # is_cpp_class_scope boolean Is a C++ class scope
# is_property_scope boolean Is a extension type property scope # is_property_scope boolean Is a extension type property scope
# scope_prefix string Disambiguator for C names # scope_prefix string Disambiguator for C names
...@@ -228,6 +229,7 @@ class Scope(object): ...@@ -228,6 +229,7 @@ class Scope(object):
is_py_class_scope = 0 is_py_class_scope = 0
is_c_class_scope = 0 is_c_class_scope = 0
is_closure_scope = 0 is_closure_scope = 0
is_passthrough = 0
is_cpp_class_scope = 0 is_cpp_class_scope = 0
is_property_scope = 0 is_property_scope = 0
is_module_scope = 0 is_module_scope = 0
...@@ -1121,7 +1123,30 @@ class ModuleScope(Scope): ...@@ -1121,7 +1123,30 @@ class ModuleScope(Scope):
# Check defined # Check defined
if not entry.type.scope: if not entry.type.scope:
error(entry.pos, "C class '%s' is declared but not defined" % entry.name) error(entry.pos, "C class '%s' is declared but not defined" % entry.name)
def check_c_class(self, entry):
type = entry.type
name = entry.name
visibility = entry.visibility
# Check defined
if not type.scope:
error(entry.pos, "C class '%s' is declared but not defined" % name)
# Generate typeobj_cname
if visibility != 'extern' and not type.typeobj_cname:
type.typeobj_cname = self.mangle(Naming.typeobj_prefix, name)
## Generate typeptr_cname
#type.typeptr_cname = self.mangle(Naming.typeptr_prefix, name)
# Check C methods defined
if type.scope:
for method_entry in type.scope.cfunc_entries:
if not method_entry.is_inherited and not method_entry.func_cname:
error(method_entry.pos, "C method '%s' is declared but not defined" %
method_entry.name)
# Allocate vtable name if necessary
if type.vtabslot_cname:
#print "ModuleScope.check_c_classes: allocating vtable cname for", self ###
type.vtable_cname = self.mangle(Naming.vtable_prefix, entry.name)
def check_c_classes(self): def check_c_classes(self):
# Performs post-analysis checking and finishing up of extension types # Performs post-analysis checking and finishing up of extension types
# being implemented in this module. This is called only for the main # being implemented in this module. This is called only for the main
...@@ -1144,28 +1169,8 @@ class ModuleScope(Scope): ...@@ -1144,28 +1169,8 @@ class ModuleScope(Scope):
print("...entry %s %s" % (entry.name, entry)) print("...entry %s %s" % (entry.name, entry))
print("......type = ", entry.type) print("......type = ", entry.type)
print("......visibility = ", entry.visibility) print("......visibility = ", entry.visibility)
type = entry.type self.check_c_class(entry)
name = entry.name
visibility = entry.visibility
# Check defined
if not type.scope:
error(entry.pos, "C class '%s' is declared but not defined" % name)
# Generate typeobj_cname
if visibility != 'extern' and not type.typeobj_cname:
type.typeobj_cname = self.mangle(Naming.typeobj_prefix, name)
## Generate typeptr_cname
#type.typeptr_cname = self.mangle(Naming.typeptr_prefix, name)
# Check C methods defined
if type.scope:
for method_entry in type.scope.cfunc_entries:
if not method_entry.is_inherited and not method_entry.func_cname:
error(method_entry.pos, "C method '%s' is declared but not defined" %
method_entry.name)
# Allocate vtable name if necessary
if type.vtabslot_cname:
#print "ModuleScope.check_c_classes: allocating vtable cname for", self ###
type.vtable_cname = self.mangle(Naming.vtable_prefix, entry.name)
def check_c_functions(self): def check_c_functions(self):
# Performs post-analysis checking making sure all # Performs post-analysis checking making sure all
# defined c functions are actually implemented. # defined c functions are actually implemented.
...@@ -1253,6 +1258,8 @@ class LocalScope(Scope): ...@@ -1253,6 +1258,8 @@ class LocalScope(Scope):
entry = Scope.lookup(self, name) entry = Scope.lookup(self, name)
if entry is not None: if entry is not None:
if entry.scope is not self and entry.scope.is_closure_scope: if entry.scope is not self and entry.scope.is_closure_scope:
if hasattr(entry.scope, "scope_class"):
raise InternalError, "lookup() after scope class created."
# The actual c fragment for the different scopes differs # The actual c fragment for the different scopes differs
# on the outside and inside, so we make a new entry # on the outside and inside, so we make a new entry
entry.in_closure = True entry.in_closure = True
...@@ -1270,14 +1277,16 @@ class LocalScope(Scope): ...@@ -1270,14 +1277,16 @@ class LocalScope(Scope):
for entry in self.entries.values(): for entry in self.entries.values():
if entry.from_closure: if entry.from_closure:
cname = entry.outer_entry.cname cname = entry.outer_entry.cname
if cname.startswith(Naming.cur_scope_cname): if self.is_passthrough:
cname = cname[len(Naming.cur_scope_cname)+2:] entry.cname = cname
entry.cname = "%s->%s" % (outer_scope_cname, cname) else:
if cname.startswith(Naming.cur_scope_cname):
cname = cname[len(Naming.cur_scope_cname)+2:]
entry.cname = "%s->%s" % (outer_scope_cname, cname)
elif entry.in_closure: elif entry.in_closure:
entry.original_cname = entry.cname entry.original_cname = entry.cname
entry.cname = "%s->%s" % (Naming.cur_scope_cname, entry.cname) entry.cname = "%s->%s" % (Naming.cur_scope_cname, entry.cname)
class GeneratorExpressionScope(LocalScope): class GeneratorExpressionScope(LocalScope):
"""Scope for generator expressions and comprehensions. As opposed """Scope for generator expressions and comprehensions. As opposed
to generators, these can be easily inlined in some cases, so all to generators, these can be easily inlined in some cases, so all
......
...@@ -225,8 +225,6 @@ class SimpleAssignmentTypeInferer(object): ...@@ -225,8 +225,6 @@ class SimpleAssignmentTypeInferer(object):
for entry in scope.entries.values(): for entry in scope.entries.values():
if entry.type is unspecified_type: if entry.type is unspecified_type:
entry.type = py_object_type entry.type = py_object_type
if scope.is_closure_scope:
fix_closure_entries(scope)
return return
dependancies_by_entry = {} # entry -> dependancies dependancies_by_entry = {} # entry -> dependancies
...@@ -288,19 +286,6 @@ class SimpleAssignmentTypeInferer(object): ...@@ -288,19 +286,6 @@ class SimpleAssignmentTypeInferer(object):
entry.type = py_object_type entry.type = py_object_type
if verbose: if verbose:
message(entry.pos, "inferred '%s' to be of type '%s' (default)" % (entry.name, entry.type)) message(entry.pos, "inferred '%s' to be of type '%s' (default)" % (entry.name, entry.type))
#if scope.is_closure_scope:
# fix_closure_entries(scope)
def fix_closure_entries(scope):
"""Temporary work-around to fix field types in the closure class
that were unknown at the time of creation and only determined
during type inference.
"""
closure_entries = scope.scope_class.type.scope.entries
for name, entry in scope.entries.iteritems():
if name in closure_entries:
closure_entry = closure_entries[name]
closure_entry.type = entry.type
def find_spanning_type(type1, type2): def find_spanning_type(type1, type2):
if type1 is type2: if type1 is type2:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment