Commit 0f582e92 authored by Tom Niget's avatar Tom Niget

Add support for generators

parent 41659efd
...@@ -25,4 +25,13 @@ class list(Generic[U]): ...@@ -25,4 +25,13 @@ class list(Generic[U]):
assert list[int].first assert list[int].first
class Iterator(Generic[U]):
def __iter__(self) -> Self: ...
def __next__(self) -> U: ...
def next(it: Iterator[U], default: None) -> U: ...
def print(*args) -> None: ... def print(*args) -> None: ...
def range(*args) -> Iterator[int]: ...
\ No newline at end of file
...@@ -65,6 +65,8 @@ class NodeVisitor(UniversalVisitor): ...@@ -65,6 +65,8 @@ class NodeVisitor(UniversalVisitor):
yield "Future" yield "Future"
elif node.kind == PromiseKind.FORKED: elif node.kind == PromiseKind.FORKED:
yield "Forked" yield "Forked"
elif node.kind == PromiseKind.GENERATOR:
yield "Generator"
else: else:
raise NotImplementedError(node) raise NotImplementedError(node)
yield "<" yield "<"
......
...@@ -210,11 +210,13 @@ class ExpressionVisitor(NodeVisitor): ...@@ -210,11 +210,13 @@ class ExpressionVisitor(NodeVisitor):
yield from self.visit(node.orelse) yield from self.visit(node.orelse)
def visit_Yield(self, node: ast.Yield) -> Iterable[str]: def visit_Yield(self, node: ast.Yield) -> Iterable[str]:
if CoroutineMode.GENERATOR in self.generator: #if CoroutineMode.GENERATOR in self.generator:
yield "co_yield" # yield "co_yield"
yield from self.prec("co_yield").visit(node.value) # yield from self.prec("co_yield").visit(node.value)
elif CoroutineMode.FAKE in self.generator: #elif CoroutineMode.FAKE in self.generator:
yield "return" # yield "return"
yield from self.visit(node.value) # yield from self.visit(node.value)
else: #else:
raise NotImplementedError(node) # raise NotImplementedError(node)
yield "co_yield"
yield from self.prec("co_yield").visit(node.value)
...@@ -4,7 +4,7 @@ from pathlib import Path ...@@ -4,7 +4,7 @@ from pathlib import Path
from transpiler.phases.typing.scope import VarKind, VarDecl, ScopeKind from transpiler.phases.typing.scope import VarKind, VarDecl, ScopeKind
from transpiler.phases.typing.stdlib import PRELUDE, StdlibVisitor from transpiler.phases.typing.stdlib import PRELUDE, StdlibVisitor
from transpiler.phases.typing.types import TY_TYPE, TY_INT, TY_STR, TY_BOOL, TY_COMPLEX, TY_NONE, FunctionType, \ from transpiler.phases.typing.types import TY_TYPE, TY_INT, TY_STR, TY_BOOL, TY_COMPLEX, TY_NONE, FunctionType, \
TypeVariable, TY_MODULE, CppType, PyList, TypeType, Forked, Task, Future TypeVariable, TY_MODULE, CppType, PyList, TypeType, Forked, Task, Future, PyIterator
PRELUDE.vars.update({ PRELUDE.vars.update({
# "int": VarDecl(VarKind.LOCAL, TY_TYPE, TY_INT), # "int": VarDecl(VarKind.LOCAL, TY_TYPE, TY_INT),
...@@ -28,6 +28,7 @@ PRELUDE.vars.update({ ...@@ -28,6 +28,7 @@ PRELUDE.vars.update({
"Forked": VarDecl(VarKind.LOCAL, TypeType(Forked)), "Forked": VarDecl(VarKind.LOCAL, TypeType(Forked)),
"Task": VarDecl(VarKind.LOCAL, TypeType(Task)), "Task": VarDecl(VarKind.LOCAL, TypeType(Task)),
"Future": VarDecl(VarKind.LOCAL, TypeType(Future)), "Future": VarDecl(VarKind.LOCAL, TypeType(Future)),
"Iterator": VarDecl(VarKind.LOCAL, TypeType(PyIterator))
}) })
typon_std = Path(__file__).parent.parent.parent.parent / "stdlib" typon_std = Path(__file__).parent.parent.parent.parent / "stdlib"
......
...@@ -7,7 +7,7 @@ from transpiler.phases.typing.common import ScoperVisitor ...@@ -7,7 +7,7 @@ from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.expr import ScoperExprVisitor from transpiler.phases.typing.expr import ScoperExprVisitor
from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind
from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, IncompatibleTypesError, TY_MODULE, \ from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, IncompatibleTypesError, TY_MODULE, \
Promise, TY_NONE, PromiseKind Promise, TY_NONE, PromiseKind, TupleType
@dataclass @dataclass
...@@ -64,6 +64,10 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -64,6 +64,10 @@ class ScoperBlockVisitor(ScoperVisitor):
if self.scope.kind == ScopeKind.FUNCTION_INNER: if self.scope.kind == ScopeKind.FUNCTION_INNER:
self.root_decls[target.id] = VarDecl(VarKind.OUTER_DECL, decl_val) self.root_decls[target.id] = VarDecl(VarKind.OUTER_DECL, decl_val)
return True return True
elif isinstance(target, ast.Tuple):
if not (isinstance(decl_val, TupleType) and len(target.elts) == len(decl_val.args)):
raise IncompatibleTypesError(f"Cannot unpack {decl_val} into {target}")
return any(self.visit_assign_target(t, ty) for t, ty in zip(target.elts, decl_val.args))
else: else:
raise NotImplementedError(target) raise NotImplementedError(target)
...@@ -118,6 +122,19 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -118,6 +122,19 @@ class ScoperBlockVisitor(ScoperVisitor):
if node.orelse: if node.orelse:
raise NotImplementedError(node.orelse) raise NotImplementedError(node.orelse)
def visit_For(self, node: ast.For):
scope = self.scope.child(ScopeKind.FUNCTION_INNER)
node.inner_scope = scope
assert isinstance(node.target, ast.Name)
scope.vars[node.target.id] = VarDecl(VarKind.LOCAL, TypeVariable())
self.expr().visit(node.iter)
body_scope = scope.child(ScopeKind.FUNCTION_INNER)
body_visitor = ScoperBlockVisitor(body_scope, self.root_decls)
for b in node.body:
body_visitor.visit(b)
if node.orelse:
raise NotImplementedError(node.orelse)
def visit_Expr(self, node: ast.Expr): def visit_Expr(self, node: ast.Expr):
self.expr().visit(node.value) self.expr().visit(node.value)
......
...@@ -44,6 +44,17 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -44,6 +44,17 @@ class ScoperExprVisitor(ScoperVisitor):
def visit_Tuple(self, node: ast.Tuple) -> BaseType: def visit_Tuple(self, node: ast.Tuple) -> BaseType:
return TupleType([self.visit(e) for e in node.elts]) return TupleType([self.visit(e) for e in node.elts])
def visit_Yield(self, node: ast.Yield) -> BaseType:
ytype = self.visit(node.value)
ftype = self.scope.function.obj_type.return_type
assert isinstance(ftype, Promise)
assert ftype.kind == PromiseKind.TASK
ftype.kind = PromiseKind.GENERATOR
ftype.return_type.unify(ytype)
return TY_NONE
def visit_Constant(self, node: ast.Constant) -> BaseType: def visit_Constant(self, node: ast.Constant) -> BaseType:
if isinstance(node.value, str): if isinstance(node.value, str):
return TY_STR return TY_STR
...@@ -77,7 +88,7 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -77,7 +88,7 @@ class ScoperExprVisitor(ScoperVisitor):
rtype = self.visit_function_call(ftype, [self.visit(arg) for arg in node.args]) rtype = self.visit_function_call(ftype, [self.visit(arg) for arg in node.args])
actual = rtype actual = rtype
node.is_await = False node.is_await = False
if isinstance(actual, Promise): if isinstance(actual, Promise) and actual.kind != PromiseKind.GENERATOR:
node.is_await = True node.is_await = True
actual = actual.return_type.resolve() actual = actual.return_type.resolve()
...@@ -202,6 +213,3 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -202,6 +213,3 @@ class ScoperExprVisitor(ScoperVisitor):
if then != else_: if then != else_:
raise NotImplementedError("IfExp with different types not handled yet") raise NotImplementedError("IfExp with different types not handled yet")
return then return then
def visit_Yield(self, node: ast.Yield) -> BaseType:
raise NotImplementedError(node)
...@@ -13,6 +13,10 @@ class IncompatibleTypesError(Exception): ...@@ -13,6 +13,10 @@ class IncompatibleTypesError(Exception):
class BaseType(ABC): class BaseType(ABC):
members: Dict[str, "BaseType"] = field(default_factory=dict, init=False) members: Dict[str, "BaseType"] = field(default_factory=dict, init=False)
methods: Dict[str, "FunctionType"] = field(default_factory=dict, init=False) methods: Dict[str, "FunctionType"] = field(default_factory=dict, init=False)
parents: List["BaseType"] = field(default_factory=list, init=False)
def get_parents(self) -> List["BaseType"]:
return self.parents
def resolve(self) -> "BaseType": def resolve(self) -> "BaseType":
return self return self
...@@ -117,6 +121,22 @@ class TypeOperator(BaseType, ABC): ...@@ -117,6 +121,22 @@ class TypeOperator(BaseType, ABC):
def unify_internal(self, other: BaseType): def unify_internal(self, other: BaseType):
if not isinstance(other, TypeOperator): if not isinstance(other, TypeOperator):
raise IncompatibleTypesError() raise IncompatibleTypesError()
if type(self) != type(other):
for parent in other.get_parents():
try:
self.unify(parent)
except IncompatibleTypesError:
pass
else:
return
for parent in self.get_parents():
try:
parent.unify(other)
except IncompatibleTypesError:
pass
else:
return
raise IncompatibleTypesError(f"Cannot unify {self} and {other} with different type and no common parents")
if len(self.args) != len(other.args) and not (self.variadic or other.variadic): if len(self.args) != len(other.args) and not (self.variadic or other.variadic):
raise IncompatibleTypesError(f"Cannot unify {self} and {other} with different number of arguments") raise IncompatibleTypesError(f"Cannot unify {self} and {other} with different number of arguments")
for a, b in zip(self.args, other.args): for a, b in zip(self.args, other.args):
...@@ -241,6 +261,15 @@ class PyDict(TypeOperator): ...@@ -241,6 +261,15 @@ class PyDict(TypeOperator):
def value_type(self): def value_type(self):
return self.args[1] return self.args[1]
class PyIterator(TypeOperator):
def __init__(self, arg: BaseType):
super().__init__([arg], "iter")
@property
def element_type(self):
return self.args[0]
class TupleType(TypeOperator): class TupleType(TypeOperator):
def __init__(self, args: List[BaseType]): def __init__(self, args: List[BaseType]):
...@@ -252,6 +281,7 @@ class PromiseKind(Enum): ...@@ -252,6 +281,7 @@ class PromiseKind(Enum):
JOIN = 1 JOIN = 1
FUTURE = 2 FUTURE = 2
FORKED = 3 FORKED = 3
GENERATOR = 4
class Promise(TypeOperator, ABC): class Promise(TypeOperator, ABC):
...@@ -273,6 +303,11 @@ class Promise(TypeOperator, ABC): ...@@ -273,6 +303,11 @@ class Promise(TypeOperator, ABC):
def __str__(self): def __str__(self):
return f"{self.kind.name.lower()}<{self.return_type}>" return f"{self.kind.name.lower()}<{self.return_type}>"
def get_parents(self) -> List["BaseType"]:
if self.kind == PromiseKind.GENERATOR:
return [PyIterator(self.return_type), *super().get_parents()]
return super().get_parents()
class Forked(Promise): class Forked(Promise):
"""Only use this for type specs""" """Only use this for type specs"""
def __init__(self, ret: BaseType): def __init__(self, ret: BaseType):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment