Commit 3da937ba authored by Tom Niget's avatar Tom Niget

Unify function and method parsing code

parent 014ef7ed
......@@ -11,7 +11,7 @@ from transpiler.phases.typing.expr import ScoperExprVisitor, DUNDER
from transpiler.phases.typing.class_ import ScoperClassVisitor
from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind, Scope
from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, \
Promise, TY_NONE, PromiseKind, TupleType, UserType, TypeType, ModuleType, BuiltinFeature
Promise, TY_NONE, PromiseKind, TupleType, UserType, TypeType, ModuleType, BuiltinFeature, TY_INT
from transpiler.phases.utils import PlainBlock, AnnotationName
......@@ -140,30 +140,12 @@ class ScoperBlockVisitor(ScoperVisitor):
else:
raise NotImplementedError(ast.unparse(target))
def annotate_arg(self, arg: ast.arg) -> BaseType:
if arg.annotation is None:
res = TypeVariable()
arg.annotation = AnnotationName(res)
return res
else:
return self.visit_annotation(arg.annotation)
def visit_FunctionDef(self, node: ast.FunctionDef):
argtypes = [self.annotate_arg(arg) for arg in node.args.args]
rtype = Promise(self.visit_annotation(node.returns), PromiseKind.TASK)
ftype = FunctionType(argtypes, rtype)
ftype = self.parse_function(node)
ftype.return_type = Promise(ftype.return_type, PromiseKind.TASK)
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ftype)
scope = self.scope.child(ScopeKind.FUNCTION)
scope.obj_type = ftype
scope.function = scope
node.inner_scope = scope
node.type = ftype
ftype.optional_at = len(node.args.args) - len(node.args.defaults)
for ty, default in zip(argtypes[ftype.optional_at:], node.args.defaults):
self.expr().visit(default).unify(ty)
for arg, ty in zip(node.args.args, argtypes):
scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty)
self.fdecls.append((node, rtype.return_type))
def visit_ClassDef(self, node: ast.ClassDef):
ctype = UserType(node.name)
......@@ -206,13 +188,43 @@ class ScoperBlockVisitor(ScoperVisitor):
)
_, rtype = visitor.visit_FunctionDef(init_method)
visitor.visit_function_definition(init_method, rtype)
node.body.append(init_method)
else:
raise NotImplementedError(deco)
for base in node.bases:
base = self.expr().visit(base)
if is_builtin(base, "Enum"):
ctype.parents.append(TY_INT)
for k in ctype.members:
ctype.members[k] = ctype
ctype.members["value"] = TY_INT
lnd = linenodata(node)
init_method = ast.FunctionDef(
name="__init__",
args=ast.arguments(
args=[ast.arg(arg="self"), ast.arg(arg="value")],
defaults=[],
kw_defaults=[],
kwarg=None,
kwonlyargs=[],
posonlyargs=[],
),
body=[
ast.Assign(
targets=[ast.Attribute(value=ast.Name(id="self"), attr="value")],
value=ast.Name(id="value"),
**lnd
)
],
decorator_list=[],
returns=None,
type_comment=None,
**lnd
)
_, rtype = visitor.visit_FunctionDef(init_method)
visitor.visit_function_definition(init_method, rtype)
node.body.append(init_method)
ctype.is_enum = True
else:
raise NotImplementedError(base)
......@@ -228,6 +240,8 @@ class ScoperBlockVisitor(ScoperVisitor):
else_visitor = ScoperBlockVisitor(else_scope, self.root_decls)
else_visitor.visit_block(node.orelse)
node.orelse_scope = else_scope
if then_scope.diverges and else_scope.diverges:
self.scope.diverges = True
def visit_While(self, node: ast.While):
scope = self.scope.child(ScopeKind.FUNCTION_INNER)
......@@ -281,7 +295,8 @@ class ScoperBlockVisitor(ScoperVisitor):
assert isinstance(ftype, FunctionType)
vtype = self.expr().visit(node.value) if node.value else TY_NONE
vtype.unify(ftype.return_type.return_type if isinstance(ftype.return_type, Promise) else ftype.return_type)
fct.has_return = True
self.scope.diverges = True
#fct.has_return = True
def visit_Global(self, node: ast.Global):
for name in node.names:
......@@ -348,6 +363,7 @@ class ScoperBlockVisitor(ScoperVisitor):
raise NotImplementedError(node.finalbody)
def visit_Raise(self, node: ast.Raise):
self.scope.diverges = True
if node.exc:
self.expr().visit(node.exc)
if node.cause:
......
......@@ -26,23 +26,11 @@ class ScoperClassVisitor(ScoperVisitor):
self.scope.obj_type.members[node.targets[0].id] = valtype
def visit_FunctionDef(self, node: ast.FunctionDef):
from transpiler.phases.typing.block import ScoperBlockVisitor
# TODO: maybe merge this code with ScoperBlockVisitor.visit_FunctionDef
argtypes = [self.visit_annotation(arg.annotation) for arg in node.args.args]
argtypes[0].unify(self.scope.obj_type) # self parameter
rtype = self.visit_annotation(node.returns)
inner_rtype = rtype
ftype = self.parse_function(node)
ftype.parameters[0].unify(self.scope.obj_type)
inner = ftype.return_type
if node.name != "__init__":
rtype = Promise(rtype, PromiseKind.TASK)
ftype = FunctionType(argtypes, rtype)
ftype.return_type = Promise(ftype.return_type, PromiseKind.TASK)
ftype.is_method = True
self.scope.obj_type.methods[node.name] = ftype
scope = self.scope.child(ScopeKind.FUNCTION)
scope.obj_type = ftype
scope.function = scope
node.inner_scope = scope
node.type = ftype
for arg, ty in zip(node.args.args, argtypes):
scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty)
res = (node, inner_rtype)
self.fdecls.append(res)
return res
return (node, inner)
......@@ -4,9 +4,10 @@ from typing import Dict, Optional
from transpiler.utils import highlight
from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.scope import Scope, ScopeKind, VarDecl
from transpiler.phases.typing.types import BaseType, TypeVariable, TY_NONE, TypeType, BuiltinFeature
from transpiler.phases.utils import NodeVisitorSeq
from transpiler.phases.typing.scope import Scope, ScopeKind, VarDecl, VarKind
from transpiler.phases.typing.types import BaseType, TypeVariable, TY_NONE, TypeType, BuiltinFeature, FunctionType, \
Promise, PromiseKind
from transpiler.phases.utils import NodeVisitorSeq, AnnotationName
PRELUDE = Scope.make_global()
......@@ -28,6 +29,31 @@ class ScoperVisitor(NodeVisitorSeq):
assert not isinstance(res, TypeType)
return res
def annotate_arg(self, arg: ast.arg) -> BaseType:
if arg.annotation is None:
res = TypeVariable()
arg.annotation = AnnotationName(res)
return res
else:
return self.visit_annotation(arg.annotation)
def parse_function(self, node: ast.FunctionDef):
argtypes = [self.annotate_arg(arg) for arg in node.args.args]
rtype = self.visit_annotation(node.returns)
ftype = FunctionType(argtypes, rtype)
scope = self.scope.child(ScopeKind.FUNCTION)
scope.obj_type = ftype
scope.function = scope
node.inner_scope = scope
node.type = ftype
ftype.optional_at = len(node.args.args) - len(node.args.defaults)
for ty, default in zip(argtypes[ftype.optional_at:], node.args.defaults):
self.expr().visit(default).unify(ty)
for arg, ty in zip(node.args.args, argtypes):
scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty)
self.fdecls.append((node, rtype))
return ftype
def visit_block(self, block: list[ast.AST]):
if not block:
return
......@@ -69,11 +95,16 @@ class ScoperVisitor(NodeVisitorSeq):
elif len(visitor.fdecls) == 1:
fnode, frtype = visitor.fdecls[0]
self.visit_function_definition(fnode, frtype)
del node.inner_scope.vars[fnode.name]
#del node.inner_scope.vars[fnode.name]
visitor.visit_assign_target(ast.Name(fnode.name), fnode.type)
b.decls = decls
if not node.inner_scope.has_return:
rtype.unify(TY_NONE) # todo: properly indicate missing return
if not node.inner_scope.diverges and not (isinstance(node.type.return_type, Promise) and node.type.return_type.kind == PromiseKind.GENERATOR):
from transpiler.phases.typing.exceptions import TypeMismatchError
try:
rtype.unify(TY_NONE)
except TypeMismatchError as e:
from transpiler.phases.typing.exceptions import MissingReturnError
raise MissingReturnError(node) from e
def get_iter(seq_type):
try:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment