Commit 68ea4de6 authored by Tom Niget's avatar Tom Niget

Continue work on async types handling

parent 8acf1377
......@@ -11,6 +11,7 @@ class int:
def __and__(self, other: Self) -> Self: ...
assert int.__add__
U = TypeVar("U")
......@@ -22,5 +23,6 @@ class list(Generic[U]):
def first(self) -> U: ...
assert list[int].first
def print(*args) -> None: ...
......@@ -2,20 +2,30 @@ from typing import Callable, TypeVar, Generic
T = TypeVar("T")
class Fork(Generic[T]):
class Forked(Generic[T]):
def get(self) -> T: ...
class Task(Generic[T]):
pass
class Future(Generic[T]):
def get(self) -> Task[T]: ...
assert Forked[int].get
def fork(f: Callable[[], T]) -> Fork[T]:
def fork(f: Callable[[], T]) -> Task[Forked[T]]:
# stub
class Res:
get = f
return Res
def future(f: Callable[[], T]) -> T:
def future(f: Callable[[], T]) -> Task[Future[T]]:
# stub
return f()
class Res:
get = f
return Res
def sync() -> None:
......
......@@ -24,7 +24,7 @@ def f(x: int):
return x + 1
def fct(param):
def fct(param: int):
loc = f(456)
global glob
loc = 789
......
......@@ -5,8 +5,8 @@ def fibo(n: int) -> int:
return n
a = future(lambda: fibo(n - 1))
b = future(lambda: fibo(n - 2))
return a + b
return a.get() + b.get()
if __name__ == "__main__":
print(fibo(30)) # should display 832040
\ No newline at end of file
print(fibo(20)) # should display 832040
\ No newline at end of file
......@@ -6,7 +6,7 @@ from typing import Iterable
from transpiler.phases.emit_cpp.consts import MAPPINGS
from transpiler.phases.typing import TypeVariable
from transpiler.phases.typing.types import BaseType, TY_INT, TY_BOOL, ForkResult
from transpiler.phases.typing.types import BaseType, TY_INT, TY_BOOL, TY_NONE, Promise, PromiseKind
from transpiler.utils import UnsupportedNodeError
class UniversalVisitor:
......@@ -68,8 +68,21 @@ class NodeVisitor(UniversalVisitor):
yield "int"
elif node is TY_BOOL:
yield "bool"
elif isinstance(node, ForkResult):
yield "Forked<"
elif node is TY_NONE:
yield "void"
elif isinstance(node, Promise):
yield "typon::"
if node.kind == PromiseKind.TASK:
yield "Task"
elif node.kind == PromiseKind.JOIN:
yield "Join"
elif node.kind == PromiseKind.FUTURE:
yield "Future"
elif node.kind == PromiseKind.FORKED:
yield "Forked"
else:
raise NotImplementedError(node)
yield "<"
yield from self.visit(node.return_type)
yield ">"
elif isinstance(node, TypeVariable):
......
......@@ -42,6 +42,11 @@ class BlockVisitor(NodeVisitor):
yield "int main() { root().call(); }"
return
yield "struct {"
yield from self.visit_func_new(node)
yield f"}} {node.name};"
return
yield "struct {"
yield from self.visit_func(node, CoroutineMode.FAKE)
......@@ -78,6 +83,34 @@ class BlockVisitor(NodeVisitor):
yield "}"
yield f"}} {node.name};"
def visit_func_new(self, node: ast.FunctionDef) -> Iterable[str]:
yield from self.visit(node.type.return_type)
yield "operator()"
yield "("
for i, (arg, argty) in enumerate(zip(node.args.args, node.type.parameters)):
if i != 0:
yield ", "
yield from self.visit(argty)
yield arg.arg
yield ")"
inner_scope = node.inner_scope
yield "{"
for child in node.body:
from transpiler.phases.emit_cpp.function import FunctionVisitor
child_visitor = FunctionVisitor(inner_scope, CoroutineMode.ASYNC)
for name, decl in getattr(child, "decls", {}).items():
#yield f"decltype({' '.join(self.expr().visit(decl.type))}) {name};"
yield from self.visit(decl.type)
yield f" {name};"
yield from child_visitor.visit(child)
yield "}"
def visit_func(self, node: ast.FunctionDef, generator: CoroutineMode) -> Iterable[str]:
templ, args, names = self.process_args(node.args)
if templ:
......
......@@ -4,7 +4,7 @@ from pathlib import Path
from transpiler.phases.typing.scope import VarKind, VarDecl, ScopeKind
from transpiler.phases.typing.stdlib import PRELUDE, StdlibVisitor
from transpiler.phases.typing.types import TY_TYPE, TY_INT, TY_STR, TY_BOOL, TY_COMPLEX, TY_NONE, FunctionType, \
TypeVariable, TY_MODULE, CppType, PyList, TypeType, ForkResult
TypeVariable, TY_MODULE, CppType, PyList, TypeType, Forked, Task, Future
PRELUDE.vars.update({
# "int": VarDecl(VarKind.LOCAL, TY_TYPE, TY_INT),
......@@ -24,8 +24,10 @@ PRELUDE.vars.update({
"Callable": VarDecl(VarKind.LOCAL, FunctionType),
"TypeVar": VarDecl(VarKind.LOCAL, TypeVariable),
"CppType": VarDecl(VarKind.LOCAL, CppType),
"list": VarDecl(VarKind.LOCAL, PyList),
"Fork": VarDecl(VarKind.LOCAL, ForkResult),
"list": VarDecl(VarKind.LOCAL, TypeType(PyList)),
"Forked": VarDecl(VarKind.LOCAL, TypeType(Forked)),
"Task": VarDecl(VarKind.LOCAL, TypeType(Task)),
"Future": VarDecl(VarKind.LOCAL, TypeType(Future)),
})
typon_std = Path(__file__).parent.parent.parent.parent / "stdlib"
......
......@@ -7,7 +7,7 @@ from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.expr import ScoperExprVisitor
from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind
from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, IncompatibleTypesError, TY_MODULE, \
Promise
Promise, TY_NONE, PromiseKind
@dataclass
......@@ -75,13 +75,14 @@ class ScoperBlockVisitor(ScoperVisitor):
def visit_FunctionDef(self, node: ast.FunctionDef):
argtypes = [self.visit_annotation(arg.annotation) for arg in node.args.args]
rtype = Promise(self.visit_annotation(node.returns))
rtype = Promise(self.visit_annotation(node.returns), PromiseKind.TASK)
ftype = FunctionType(argtypes, rtype)
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ftype)
scope = self.scope.child(ScopeKind.FUNCTION)
scope.obj_type = ftype
scope.function = scope
node.inner_scope = scope
node.type = ftype
for arg, ty in zip(node.args.args, argtypes):
scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty)
for b in node.body:
......@@ -89,6 +90,8 @@ class ScoperBlockVisitor(ScoperVisitor):
visitor = ScoperBlockVisitor(scope, decls)
visitor.visit(b)
b.decls = decls
if not scope.has_return:
rtype.return_type.unify(TY_NONE)
def visit_If(self, node: ast.If):
scope = self.scope.child(ScopeKind.FUNCTION_INNER)
......@@ -107,8 +110,9 @@ class ScoperBlockVisitor(ScoperVisitor):
raise IncompatibleTypesError("Return outside function")
ftype = fct.obj_type
assert isinstance(ftype, FunctionType)
vtype = self.expr().visit(node.value) if node.value else None
vtype = self.expr().visit(node.value) if node.value else TY_NONE
vtype.unify(ftype.return_type.return_type if isinstance(ftype.return_type, Promise) else ftype.return_type)
fct.has_return = True
def visit_Global(self, node: ast.Global):
for name in node.names:
......
import abc
import ast
from typing import List
from transpiler.phases.typing import ScopeKind, VarDecl, VarKind
from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.types import IncompatibleTypesError, BaseType, TupleType, TY_STR, TY_BOOL, TY_INT, \
TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict, Promise
TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict, Promise, PromiseKind
DUNDER = {
ast.Eq: "eq",
......@@ -74,10 +75,28 @@ class ScoperExprVisitor(ScoperVisitor):
def visit_Call(self, node: ast.Call) -> BaseType:
ftype = self.visit(node.func)
rtype = self.visit_function_call(ftype, [self.visit(arg) for arg in node.args])
actual = rtype
node.is_await = False
if isinstance(actual, Promise):
node.is_await = True
actual = actual.return_type.resolve()
if isinstance(actual, Promise) and actual.kind == PromiseKind.FORKED \
and isinstance(fty := self.scope.function.obj_type.return_type, Promise):
fty.kind = PromiseKind.JOIN
return actual
if isinstance(rtype, Promise):
node.is_await = True
return rtype.return_type
node.is_await = False
if rtype.kind == PromiseKind.FORKED \
and isinstance(fty := self.scope.function.obj_type.return_type, Promise):
fty.kind = PromiseKind.JOIN
else:
return rtype.return_type
else:
node.is_await = False
return rtype
def visit_function_call(self, ftype: BaseType, arguments: List[BaseType]):
......@@ -165,6 +184,12 @@ class ScoperExprVisitor(ScoperVisitor):
return PyDict(keys[0], values[0])
def visit_Subscript(self, node: ast.Subscript) -> BaseType:
left = self.visit(node.value)
args = node.slice if type(node.slice) == tuple else [node.slice]
if isinstance(left, TypeType) and isinstance(left.type_object, abc.ABCMeta):
# generic
return TypeType(left.type_object(*[self.visit(e).type_object for e in args]))
pass
raise NotImplementedError(node)
def visit_UnaryOp(self, node: ast.UnaryOp) -> BaseType:
......
......@@ -52,6 +52,7 @@ class Scope:
vars: Dict[str, VarDecl] = field(default_factory=dict)
children: List["Scope"] = field(default_factory=list)
obj_type: Optional[BaseType] = None
has_return: bool = False
@staticmethod
def make_global():
......
import ast
import dataclasses
from abc import ABCMeta
from dataclasses import dataclass, field
from typing import Optional, List, Dict
......@@ -38,8 +39,8 @@ class StdlibVisitor(NodeVisitorSeq):
typevars = []
for b in node.bases:
if isinstance(b, ast.Subscript) and isinstance(b.value, ast.Name) and b.value.id == "Generic":
if isinstance(b.slice, ast.Index):
typevars = [b.slice.value.id]
if isinstance(b.slice, ast.Name):
typevars = [b.slice.id]
elif isinstance(b.slice, ast.Tuple):
typevars = [n.id for n in b.slice.value.elts]
if existing := self.scope.get(node.name):
......@@ -54,6 +55,9 @@ class StdlibVisitor(NodeVisitorSeq):
for stmt in node.body:
visitor.visit(stmt)
def visit_Pass(self, node: ast.Pass):
pass
def visit_FunctionDef(self, node: ast.FunctionDef):
arg_visitor = TypeAnnotationVisitor(self.scope.child(ScopeKind.FUNCTION), self.cur_class)
arg_types = [arg_visitor.visit(arg.annotation or arg.arg) for arg in node.args.args]
......@@ -63,12 +67,12 @@ class StdlibVisitor(NodeVisitorSeq):
ty.variadic = True
#arg_types.append(TY_VARARG)
if self.cur_class:
if isinstance(self.cur_class, TypeType):
if isinstance(self.cur_class.type_object, ABCMeta):
self.cur_class.gen_methods[node.name] = lambda t: ty.gen_sub(t, self.typevars)
else:
# ty_inst = FunctionType(arg_types[1:], ret_type)
# self.cur_class.args[0].add_inst_member(node.name, ty_inst)
self.cur_class.type_object.methods[node.name] = ty.gen_sub(self.cur_class.type_object, self.typevars)
else:
self.cur_class.gen_methods[node.name] = lambda t: ty.gen_sub(t, self.typevars)
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ty)
def visit_Assert(self, node: ast.Assert):
......
import typing
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Dict, Optional, List, ClassVar, Callable
from enum import Enum
from typing import Dict, Optional, List, ClassVar, Callable, Any
class IncompatibleTypesError(Exception):
pass
@dataclass
@dataclass(eq=False)
class BaseType(ABC):
members: Dict[str, "BaseType"] = field(default_factory=dict, init=False)
methods: Dict[str, "FunctionType"] = field(default_factory=dict, init=False)
......@@ -36,27 +38,33 @@ class BaseType(ABC):
def gen_sub(self, this: "BaseType", typevars) -> "Self":
return self
def __repr__(self):
return str(self)
def to_list(self) -> List["BaseType"]:
return [self]
class MagicType(BaseType):
T = typing.TypeVar("T")
class MagicType(BaseType, typing.Generic[T]):
val: T
def __init__(self, val: T):
super().__init__()
self.val = val
def unify_internal(self, other: "BaseType"):
if type(self) is not type(other):
if type(self) != type(other) or self.val != other.val:
raise IncompatibleTypesError()
def contains_internal(self, other: "BaseType") -> bool:
return False
def __str__(self):
return str(self.val)
cur_var = 0
@dataclass
@dataclass(eq=False)
class TypeVariable(BaseType):
name: str = field(default_factory=lambda: chr(ord('a') + cur_var))
resolved: Optional[BaseType] = None
......@@ -85,8 +93,10 @@ class TypeVariable(BaseType):
return match
return self
GenMethodFactory = Callable[["BaseType"], "FunctionType"]
@dataclass
class TypeOperator(BaseType, ABC):
args: List[BaseType]
......@@ -94,6 +104,10 @@ class TypeOperator(BaseType, ABC):
variadic: bool = False
gen_methods: ClassVar[Dict[str, GenMethodFactory]] = {}
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls.gen_methods = {}
def __post_init__(self):
if self.name is None:
self.name = self.__class__.__name__
......@@ -106,7 +120,11 @@ class TypeOperator(BaseType, ABC):
if len(self.args) != len(other.args) and not (self.variadic or other.variadic):
raise IncompatibleTypesError(f"Cannot unify {self} and {other} with different number of arguments")
for a, b in zip(self.args, other.args):
a.unify(b)
if isinstance(a, BaseType) and isinstance(b, BaseType):
a.unify(b)
else:
if a != b:
raise IncompatibleTypesError(f"Cannot unify {a} and {b}")
def contains_internal(self, other: "BaseType") -> bool:
return any(arg.contains(other) for arg in self.args)
......@@ -158,6 +176,7 @@ class FunctionType(TypeOperator):
args = "()"
return f"{args} -> {ret}"
class CppType(TypeOperator):
def __init__(self, name: str):
super().__init__([name], name)
......@@ -165,10 +184,12 @@ class CppType(TypeOperator):
def __str__(self):
return self.name
class Union(TypeOperator):
def __init__(self, left: BaseType, right: BaseType):
super().__init__([left, right], "Union")
class TypeType(TypeOperator):
def __init__(self, arg: BaseType):
super().__init__([arg], "Type")
......@@ -189,6 +210,7 @@ TY_VARARG = TypeOperator([], "vararg")
TY_SELF = TypeOperator([], "Self")
TY_SELF.gen_sub = lambda this, typevars: this
class PyList(TypeOperator):
def __init__(self, arg: BaseType):
super().__init__([arg], "list")
......@@ -197,6 +219,7 @@ class PyList(TypeOperator):
def element_type(self):
return self.args[0]
class PySet(TypeOperator):
def __init__(self, arg: BaseType):
super().__init__([arg], "set")
......@@ -205,6 +228,7 @@ class PySet(TypeOperator):
def element_type(self):
return self.args[0]
class PyDict(TypeOperator):
def __init__(self, key: BaseType, value: BaseType):
super().__init__([key, value], "dict")
......@@ -217,22 +241,49 @@ class PyDict(TypeOperator):
def value_type(self):
return self.args[1]
class TupleType(TypeOperator):
def __init__(self, args: List[BaseType]):
super().__init__(args, "tuple")
class ForkResult(TypeOperator):
def __init__(self, args: BaseType):
super().__init__([args], "ForkResult")
class PromiseKind(Enum):
TASK = 0
JOIN = 1
FUTURE = 2
FORKED = 3
class Promise(TypeOperator, ABC):
def __init__(self, ret: BaseType, kind: PromiseKind):
super().__init__([ret, MagicType(kind)])
@property
def return_type(self):
def return_type(self) -> BaseType:
return self.args[0]
class Promise(TypeOperator):
def __init__(self, args: BaseType):
super().__init__([args], "Promise")
@property
def return_type(self):
return self.args[0]
\ No newline at end of file
def kind(self) -> PromiseKind:
return self.args[1].val
@kind.setter
def kind(self, value: PromiseKind):
self.args[1].val = value
def __str__(self):
return f"{self.kind.name.lower()}<{self.return_type}>"
class Forked(Promise):
"""Only use this for type specs"""
def __init__(self, ret: BaseType):
super().__init__(ret, PromiseKind.FORKED)
class Task(Promise):
"""Only use this for type specs"""
def __init__(self, ret: BaseType):
super().__init__(ret, PromiseKind.TASK)
class Future(Promise):
"""Only use this for type specs"""
def __init__(self, ret: BaseType):
super().__init__(ret, PromiseKind.FUTURE)
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment