Commit 4fc031e7 authored by Vitja Makarov's avatar Vitja Makarov

Assignmment based type inference

parent f6b07dae
...@@ -1466,6 +1466,7 @@ class NameNode(AtomicExprNode): ...@@ -1466,6 +1466,7 @@ class NameNode(AtomicExprNode):
cf_is_null = False cf_is_null = False
allow_null = False allow_null = False
nogil = False nogil = False
inferred_type = None
def as_cython_attribute(self): def as_cython_attribute(self):
return self.cython_attribute return self.cython_attribute
...@@ -1474,7 +1475,7 @@ class NameNode(AtomicExprNode): ...@@ -1474,7 +1475,7 @@ class NameNode(AtomicExprNode):
if self.entry is None: if self.entry is None:
self.entry = env.lookup(self.name) self.entry = env.lookup(self.name)
if self.entry is not None and self.entry.type.is_unspecified: if self.entry is not None and self.entry.type.is_unspecified:
return (self.entry,) return (self,)
else: else:
return () return ()
...@@ -1482,6 +1483,8 @@ class NameNode(AtomicExprNode): ...@@ -1482,6 +1483,8 @@ class NameNode(AtomicExprNode):
if self.entry is None: if self.entry is None:
self.entry = env.lookup(self.name) self.entry = env.lookup(self.name)
if self.entry is None or self.entry.type is unspecified_type: if self.entry is None or self.entry.type is unspecified_type:
if self.inferred_type is not None:
return self.inferred_type
return py_object_type return py_object_type
elif (self.entry.type.is_extension_type or self.entry.type.is_builtin_type) and \ elif (self.entry.type.is_extension_type or self.entry.type.is_builtin_type) and \
self.name == self.entry.type.name: self.name == self.entry.type.name:
...@@ -1496,6 +1499,12 @@ class NameNode(AtomicExprNode): ...@@ -1496,6 +1499,12 @@ class NameNode(AtomicExprNode):
# special case: referring to a C function must return its pointer # special case: referring to a C function must return its pointer
return PyrexTypes.CPtrType(self.entry.type) return PyrexTypes.CPtrType(self.entry.type)
else: else:
# If entry is inferred as pyobject it's safe to use local
# NameNode's inferred_type.
if self.entry.type.is_pyobject and self.inferred_type:
# Overflow may happen if integer
if not (self.inferred_type.is_int and self.entry.might_overflow):
return self.inferred_type
return self.entry.type return self.entry.type
def compile_time_value(self, denv): def compile_time_value(self, denv):
......
...@@ -35,6 +35,7 @@ cdef class NameAssignment: ...@@ -35,6 +35,7 @@ cdef class NameAssignment:
cdef public object pos cdef public object pos
cdef public set refs cdef public set refs
cdef public object bit cdef public object bit
cdef public object inferred_type
cdef class AssignmentList: cdef class AssignmentList:
cdef public object bit cdef public object bit
......
...@@ -318,15 +318,23 @@ class NameAssignment(object): ...@@ -318,15 +318,23 @@ class NameAssignment(object):
self.refs = set() self.refs = set()
self.is_arg = False self.is_arg = False
self.is_deletion = False self.is_deletion = False
self.inferred_type = None
def __repr__(self): def __repr__(self):
return '%s(entry=%r)' % (self.__class__.__name__, self.entry) return '%s(entry=%r)' % (self.__class__.__name__, self.entry)
def infer_type(self, scope): def infer_type(self):
return self.rhs.infer_type(scope) self.inferred_type = self.rhs.infer_type(self.entry.scope)
return self.inferred_type
def type_dependencies(self, scope): def type_dependencies(self):
return self.rhs.type_dependencies(scope) return self.rhs.type_dependencies(self.entry.scope)
@property
def type(self):
if not self.entry.type.is_unspecified:
return self.entry.type
return self.inferred_type
class StaticAssignment(NameAssignment): class StaticAssignment(NameAssignment):
...@@ -340,11 +348,11 @@ class StaticAssignment(NameAssignment): ...@@ -340,11 +348,11 @@ class StaticAssignment(NameAssignment):
entry.type, may_be_none=may_be_none, pos=entry.pos) entry.type, may_be_none=may_be_none, pos=entry.pos)
super(StaticAssignment, self).__init__(lhs, lhs, entry) super(StaticAssignment, self).__init__(lhs, lhs, entry)
def infer_type(self, scope): def infer_type(self):
return self.entry.type return self.entry.type
def type_dependencies(self, scope): def type_dependencies(self):
return [] return ()
class Argument(NameAssignment): class Argument(NameAssignment):
...@@ -358,11 +366,12 @@ class NameDeletion(NameAssignment): ...@@ -358,11 +366,12 @@ class NameDeletion(NameAssignment):
NameAssignment.__init__(self, lhs, lhs, entry) NameAssignment.__init__(self, lhs, lhs, entry)
self.is_deletion = True self.is_deletion = True
def infer_type(self, scope): def infer_type(self):
inferred_type = self.rhs.infer_type(scope) inferred_type = self.rhs.infer_type(self.entry.scope)
if (not inferred_type.is_pyobject and if (not inferred_type.is_pyobject and
inferred_type.can_coerce_to_pyobject(scope)): inferred_type.can_coerce_to_pyobject(self.entry.scope)):
return py_object_type return py_object_type
self.inferred_type = inferred_type
return inferred_type return inferred_type
...@@ -409,7 +418,9 @@ class ControlFlowState(list): ...@@ -409,7 +418,9 @@ class ControlFlowState(list):
else: else:
if len(state) == 1: if len(state) == 1:
self.is_single = True self.is_single = True
super(ControlFlowState, self).__init__(state) # XXX: Remove fake_rhs_expr
super(ControlFlowState, self).__init__(
[i for i in state if i.rhs is not fake_rhs_expr])
def one(self): def one(self):
return self[0] return self[0]
......
...@@ -339,8 +339,11 @@ class SimpleAssignmentTypeInferer(object): ...@@ -339,8 +339,11 @@ class SimpleAssignmentTypeInferer(object):
Note: in order to support cross-closure type inference, this must be Note: in order to support cross-closure type inference, this must be
applies to nested scopes in top-down order. applies to nested scopes in top-down order.
""" """
# TODO: Implement a real type inference algorithm. def set_entry_type(self, entry, entry_type):
# (Something more powerful than just extending this one...) entry.type = entry_type
for e in entry.all_entries():
e.type = entry_type
def infer_types(self, scope): def infer_types(self, scope):
enabled = scope.directives['infer_types'] enabled = scope.directives['infer_types']
verbose = scope.directives['infer_types.verbose'] verbose = scope.directives['infer_types.verbose']
...@@ -352,85 +355,126 @@ class SimpleAssignmentTypeInferer(object): ...@@ -352,85 +355,126 @@ class SimpleAssignmentTypeInferer(object):
else: else:
for entry in scope.entries.values(): for entry in scope.entries.values():
if entry.type is unspecified_type: if entry.type is unspecified_type:
entry.type = py_object_type self.set_entry_type(entry, py_object_type)
return return
dependancies_by_entry = {} # entry -> dependancies # Set of assignemnts
entries_by_dependancy = {} # dependancy -> entries assignments = set([])
ready_to_infer = [] assmts_resolved = set([])
dependencies = {}
assmt_to_names = {}
for name, entry in scope.entries.items(): for name, entry in scope.entries.items():
if entry.type is unspecified_type:
all = set()
for assmt in entry.cf_assignments: for assmt in entry.cf_assignments:
all.update(assmt.type_dependencies(entry.scope)) names = assmt.type_dependencies()
if all: assmt_to_names[assmt] = names
dependancies_by_entry[entry] = all assmts = set()
for dep in all: for node in names:
if dep not in entries_by_dependancy: assmts.update(node.cf_state)
entries_by_dependancy[dep] = set([entry]) dependencies[assmt] = assmts
if entry.type is unspecified_type:
assignments.update(entry.cf_assignments)
else: else:
entries_by_dependancy[dep].add(entry) assmts_resolved.update(entry.cf_assignments)
def infer_name_node_type(node):
types = [assmt.inferred_type for assmt in node.cf_state]
if not types:
node_type = py_object_type
else: else:
ready_to_infer.append(entry) node_type = spanning_type(
types, entry.might_overflow, entry.pos)
def resolve_dependancy(dep): node.inferred_type = node_type
if dep in entries_by_dependancy:
for entry in entries_by_dependancy[dep]: def infer_name_node_type_partial(node):
entry_deps = dependancies_by_entry[entry] types = [assmt.inferred_type for assmt in node.cf_state
entry_deps.remove(dep) if assmt.inferred_type is not None]
if not entry_deps and entry != dep: if not types:
del dependancies_by_entry[entry] return
ready_to_infer.append(entry) return spanning_type(types, entry.might_overflow, entry.pos)
# Try to infer things in order... def resolve_assignments(assignments):
resolved = set()
for assmt in assignments:
deps = dependencies[assmt]
# All assignments are resolved
if assmts_resolved.issuperset(deps):
for node in assmt_to_names[assmt]:
infer_name_node_type(node)
# Resolve assmt
inferred_type = assmt.infer_type()
done = False
assmts_resolved.add(assmt)
resolved.add(assmt)
assignments -= resolved
return resolved
def partial_infer(assmt):
partial_types = []
for node in assmt_to_names[assmt]:
partial_type = infer_name_node_type_partial(node)
if partial_type is None:
return False
partial_types.append((node, partial_type))
for node, partial_type in partial_types:
node.inferred_type = partial_type
assmt.infer_type()
return True
partial_assmts = set()
def resolve_partial(assignments):
# try to handle circular references
partials = set()
for assmt in assignments:
partial_types = []
if assmt in partial_assmts:
continue
for node in assmt_to_names[assmt]:
if partial_infer(assmt):
partials.add(assmt)
assmts_resolved.add(assmt)
partial_assmts.update(partials)
return partials
# Infer assignments
while True: while True:
while ready_to_infer: if not resolve_assignments(assignments):
entry = ready_to_infer.pop() if not resolve_partial(assignments):
types = [ break
assmt.rhs.infer_type(scope) inferred = set()
for assmt in entry.cf_assignments # First pass
] for entry in scope.entries.values():
if types and Utils.all(types): if entry.type is not unspecified_type:
entry_type = spanning_type(types, entry.might_overflow, entry.pos) continue
else:
# FIXME: raise a warning?
# print "No assignments", entry.pos, entry
entry_type = py_object_type entry_type = py_object_type
# propagate entry type to all nested scopes if assmts_resolved.issuperset(entry.cf_assignments):
for e in entry.all_entries(): types = [assmt.inferred_type for assmt in entry.cf_assignments]
if e.type is unspecified_type: if types and Utils.all(types):
e.type = entry_type entry_type = spanning_type(
else: types, entry.might_overflow, entry.pos)
# FIXME: can this actually happen? inferred.add(entry)
assert e.type == entry_type, ( self.set_entry_type(entry, entry_type)
'unexpected type mismatch between closures for inferred type %s: %s vs. %s' %
entry_type, e, entry) def reinfer():
if verbose: dirty = False
message(entry.pos, "inferred '%s' to be of type '%s'" % (entry.name, entry.type)) for entry in inferred:
resolve_dependancy(entry) types = [assmt.infer_type()
# Deal with simple circular dependancies...
for entry, deps in dependancies_by_entry.items():
if len(deps) == 1 and deps == set([entry]):
types = [assmt.infer_type(scope)
for assmt in entry.cf_assignments
if assmt.type_dependencies(scope) == ()]
if types:
entry.type = spanning_type(types, entry.might_overflow, entry.pos)
types = [assmt.infer_type(scope)
for assmt in entry.cf_assignments] for assmt in entry.cf_assignments]
entry.type = spanning_type(types, entry.might_overflow, entry.pos) # might be wider... new_type = spanning_type(types, entry.might_overflow, entry.pos)
resolve_dependancy(entry) if new_type != entry.type:
del dependancies_by_entry[entry] self.set_entry_type(entry, new_type)
if ready_to_infer: dirty = True
break return dirty
if not ready_to_infer:
break # types propagation
while reinfer():
pass
# We can't figure out the rest with this algorithm, let them be objects.
for entry in dependancies_by_entry:
entry.type = py_object_type
if verbose: if verbose:
message(entry.pos, "inferred '%s' to be of type '%s' (default)" % (entry.name, entry.type)) for entry in inferred:
message(entry.pos, "inferred '%s' to be of type '%s'" % (
entry.name, entry.type))
def find_spanning_type(type1, type2): def find_spanning_type(type1, type2):
if type1 is type2: if type1 is type2:
......
cimport cython
from cython cimport typeof, infer_types
def test_swap():
"""
>>> test_swap()
"""
a = 0
b = 1
tmp = a
a = b
b = tmp
assert typeof(a) == "long", typeof(a)
assert typeof(b) == "long", typeof(b)
assert typeof(tmp) == "long", typeof(tmp)
def test_object_assmt():
"""
>>> test_object_assmt()
"""
a = 1
b = a
a = "str"
assert typeof(a) == "Python object", typeof(a)
assert typeof(b) == "long", typeof(b)
def test_long_vs_double(cond):
"""
>>> test_long_vs_double(0)
"""
assert typeof(a) == "double", typeof(a)
assert typeof(b) == "double", typeof(b)
assert typeof(c) == "double", typeof(c)
assert typeof(d) == "double", typeof(d)
if cond:
a = 1
b = 2
c = (a + b) / 2
else:
a = 1.0
b = 2.0
d = (a + b) / 2
def test_double_vs_pyobject():
"""
>>> test_double_vs_pyobject()
"""
assert typeof(a) == "Python object", typeof(a)
assert typeof(b) == "Python object", typeof(b)
assert typeof(d) == "double", typeof(d)
a = []
b = []
a = 1.0
b = 2.0
d = (a + b) / 2
def test_python_objects(cond):
"""
>>> test_python_objects(0)
"""
if cond == 1:
a = [1, 2, 3]
o_list = a
elif cond == 2:
a = set([1, 2, 3])
o_set = a
else:
a = {1:1, 2:2, 3:3}
o_dict = a
assert typeof(a) == "Python object", typeof(a)
assert typeof(o_list) == "list object", typeof(o_list)
assert typeof(o_dict) == "dict object", typeof(o_dict)
assert typeof(o_set) == "set object", typeof(o_set)
# CF loops
def test_cf_loop():
"""
>>> test_cf_loop()
"""
cdef int i
a = 0.0
for i in range(3):
a += 1
assert typeof(a) == "double", typeof(a)
def test_cf_loop_intermediate():
"""
>>> test_cf_loop()
"""
cdef int i
a = 0
for i in range(3):
b = a
a = b + 1
assert typeof(a) == "long", typeof(a)
assert typeof(b) == "long", typeof(b)
# Integer overflow
def test_integer_overflow():
"""
>>> test_integer_overflow()
"""
a = 1
b = 2
c = a + b
assert typeof(a) == "Python object", typeof(a)
assert typeof(b) == "Python object", typeof(b)
assert typeof(c) == "Python object", typeof(c)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment