Commit 9f0d1621 authored by Tom Niget's avatar Tom Niget

Analyze functions in two passes (signatures, then bodies) to allow for forward use

parent 3a5d7a43
......@@ -38,8 +38,7 @@ class ScoperBlockVisitor(ScoperVisitor):
self.scope.vars[alias.asname or alias.name] = VarDecl(VarKind.LOCAL, thing)
def visit_Module(self, node: ast.Module):
for stmt in node.body:
self.visit(stmt)
self.visit_block(node.body)
def get_type(self, node: ast.expr) -> BaseType:
if type := getattr(node, "type", None):
......@@ -98,13 +97,7 @@ class ScoperBlockVisitor(ScoperVisitor):
ftype.optional_at = 1 + len(node.args.args) - len(node.args.defaults)
for arg, ty in zip(node.args.args, argtypes):
scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty)
for b in node.body:
decls = {}
visitor = ScoperBlockVisitor(scope, decls)
visitor.visit(b)
b.decls = decls
if not scope.has_return:
rtype.return_type.unify(TY_NONE)
self.fdecls.append((node, rtype.return_type))
def visit_ClassDef(self, node: ast.ClassDef):
ctype = UserType(node.name)
......@@ -115,8 +108,7 @@ class ScoperBlockVisitor(ScoperVisitor):
node.inner_scope = scope
node.type = ctype
visitor = ScoperClassVisitor(scope)
for b in node.body:
visitor.visit(b)
visitor.visit_block(node.body)
def visit_If(self, node: ast.If):
scope = self.scope.child(ScopeKind.FUNCTION_INNER)
......@@ -124,13 +116,11 @@ class ScoperBlockVisitor(ScoperVisitor):
self.expr().visit(node.test)
then_scope = scope.child(ScopeKind.FUNCTION_INNER)
then_visitor = ScoperBlockVisitor(then_scope, self.root_decls)
for b in node.body:
then_visitor.visit(b)
then_visitor.visit_block(node.body)
if node.orelse:
else_scope = scope.child(ScopeKind.FUNCTION_INNER)
else_visitor = ScoperBlockVisitor(else_scope, self.root_decls)
for b in node.orelse:
else_visitor.visit(b)
else_visitor.visit_block(node.orelse.body)
def visit_While(self, node: ast.While):
scope = self.scope.child(ScopeKind.FUNCTION_INNER)
......@@ -138,8 +128,7 @@ class ScoperBlockVisitor(ScoperVisitor):
self.expr().visit(node.test)
body_scope = scope.child(ScopeKind.FUNCTION_INNER)
body_visitor = ScoperBlockVisitor(body_scope, self.root_decls)
for b in node.body:
body_visitor.visit(b)
body_visitor.visit_block(node.body)
if node.orelse:
raise NotImplementedError(node.orelse)
......@@ -151,8 +140,7 @@ class ScoperBlockVisitor(ScoperVisitor):
self.expr().visit(node.iter)
body_scope = scope.child(ScopeKind.FUNCTION_INNER)
body_visitor = ScoperBlockVisitor(body_scope, self.root_decls)
for b in node.body:
body_visitor.visit(b)
body_visitor.visit_block(node.body)
if node.orelse:
raise NotImplementedError(node.orelse)
......@@ -183,6 +171,8 @@ class ScoperBlockVisitor(ScoperVisitor):
if isinstance(node, ast.AST):
super().visit(node)
node.scope = self.scope
else:
raise NotImplementedError(node)
def visit_Break(self, node: ast.Break):
pass # TODO: check in loop
# coding: utf-8
import ast
from dataclasses import dataclass
from dataclasses import dataclass, field
from transpiler.phases.typing import FunctionType, ScopeKind, VarDecl, VarKind, TY_NONE
from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.types import PromiseKind, Promise, BaseType
@dataclass
class ScoperClassVisitor(ScoperVisitor):
fdecls: list[(ast.FunctionDef, BaseType)] = field(default_factory=list)
def visit_AnnAssign(self, node: ast.AnnAssign):
assert node.value is None, "Class field should not have a value"
assert node.simple == 1, "Class field should be simple (identifier, not parenthesized)"
......@@ -20,6 +23,9 @@ class ScoperClassVisitor(ScoperVisitor):
argtypes = [self.visit_annotation(arg.annotation) for arg in node.args.args]
argtypes[0].unify(self.scope.obj_type) # self parameter
rtype = self.visit_annotation(node.returns)
inner_rtype = rtype
if node.name != "__init__":
rtype = Promise(rtype, PromiseKind.TASK)
ftype = FunctionType(argtypes, rtype)
self.scope.obj_type.methods[node.name] = ftype
scope = self.scope.child(ScopeKind.FUNCTION)
......@@ -29,10 +35,4 @@ class ScoperClassVisitor(ScoperVisitor):
node.type = ftype
for arg, ty in zip(node.args.args, argtypes):
scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty)
for b in node.body:
decls = {}
visitor = ScoperBlockVisitor(scope, decls)
visitor.visit(b)
b.decls = decls
if not scope.has_return:
rtype.unify(TY_NONE)
self.fdecls.append((node, inner_rtype))
......@@ -4,7 +4,7 @@ from typing import Dict, Optional
from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.scope import Scope, ScopeKind, VarDecl
from transpiler.phases.typing.types import BaseType, TypeVariable
from transpiler.phases.typing.types import BaseType, TypeVariable, TY_NONE
from transpiler.phases.utils import NodeVisitorSeq
PRELUDE = Scope.make_global()
......@@ -18,4 +18,18 @@ class ScoperVisitor(NodeVisitorSeq):
return TypeAnnotationVisitor(self.scope)
def visit_annotation(self, expr: Optional[ast.expr]) -> BaseType:
return self.anno().visit(expr) if expr else TypeVariable()
\ No newline at end of file
return self.anno().visit(expr) if expr else TypeVariable()
def visit_block(self, block: list[ast.AST]):
from transpiler.phases.typing.block import ScoperBlockVisitor
self.fdecls = []
for b in block:
self.visit(b)
for node, rtype in self.fdecls:
for b in node.body:
decls = {}
visitor = ScoperBlockVisitor(node.inner_scope, decls)
visitor.visit(b)
b.decls = decls
if not node.inner_scope.has_return:
rtype.unify(TY_NONE)
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment