Commit eee8a2f3 authored by Tom Niget's avatar Tom Niget

Add type checking to for-loops

parent b3e3696f
...@@ -36,6 +36,8 @@ class str: ...@@ -36,6 +36,8 @@ class str:
def format(self, *args) -> Self: ... def format(self, *args) -> Self: ...
def encode(self, encoding: Self) -> bytes: ... def encode(self, encoding: Self) -> bytes: ...
def __len__(self) -> int: ... def __len__(self) -> int: ...
def __add__(self, other: Self) -> Self: ...
def __mul__(self, other: int) -> Self: ...
class bytes: class bytes:
def decode(self, encoding: str) -> str: ... def decode(self, encoding: str) -> str: ...
......
...@@ -136,8 +136,18 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -136,8 +136,18 @@ class ScoperBlockVisitor(ScoperVisitor):
scope = self.scope.child(ScopeKind.FUNCTION_INNER) scope = self.scope.child(ScopeKind.FUNCTION_INNER)
node.inner_scope = scope node.inner_scope = scope
assert isinstance(node.target, ast.Name) assert isinstance(node.target, ast.Name)
scope.vars[node.target.id] = VarDecl(VarKind.LOCAL, TypeVariable()) var_var = TypeVariable()
self.expr().visit(node.iter) scope.vars[node.target.id] = VarDecl(VarKind.LOCAL, var_var)
seq_type = self.expr().visit(node.iter)
try:
iter_type = seq_type.methods["__iter__"].return_type
except:
raise IncompatibleTypesError(f"{seq_type} is not iterable")
try:
next_type = iter_type.methods["__next__"].return_type
except:
raise IncompatibleTypesError(f"iter({iter_type}) is not an iterator")
var_var.unify(next_type)
body_scope = scope.child(ScopeKind.FUNCTION_INNER) body_scope = scope.child(ScopeKind.FUNCTION_INNER)
body_visitor = ScoperBlockVisitor(body_scope, self.root_decls) body_visitor = ScoperBlockVisitor(body_scope, self.root_decls)
body_visitor.visit_block(node.body) body_visitor.visit_block(node.body)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment