Commit e523e730 authored by Tom Niget's avatar Tom Niget

Handle free generic functions

parent d12286eb
......@@ -34,12 +34,15 @@ class Iterator(Generic[U]):
def __next__(self) -> U: ...
# type: TypeVar("U")
def next(it: Iterator[U], default: None) -> U:
...
# what happens with multiple functions
def identity(x: U) -> U:
...
assert identity(1)
assert identity("a")
def print(*args) -> None: ...
......@@ -47,5 +50,7 @@ def print(*args) -> None: ...
def range(*args) -> Iterator[int]: ...
def rangeb(*args) -> Iterator[bool]: ...
assert [6].__add__
assert [True].__add__
assert next(range(6), None)
assert next(rangeb(6), None)
\ No newline at end of file
import ast
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional, List
from transpiler.phases.typing.scope import Scope
from transpiler.phases.typing.types import BaseType, TY_NONE, TypeType, TY_SELF
from transpiler.phases.typing.types import BaseType, TY_NONE, TypeType, TY_SELF, TypeVariable
from transpiler.phases.utils import NodeVisitorSeq
......@@ -11,12 +11,19 @@ from transpiler.phases.utils import NodeVisitorSeq
class TypeAnnotationVisitor(NodeVisitorSeq):
scope: Scope
cur_class: Optional[TypeType] = None
typevars: List[TypeVariable] = field(default_factory=list)
def visit_str(self, node: str) -> BaseType:
if node in ("Self", "self") and self.cur_class:
return TY_SELF
if existing := self.scope.get(node):
ty = existing.type
if isinstance(ty, TypeVariable):
if existing is not self.scope.vars.get(node, None):
# Type variable from outer scope, so we copy it
ty = TypeVariable(ty.name)
self.scope.declare_local(node, ty) # todo: unneeded?
self.typevars.append(ty)
if isinstance(ty, TypeType):
return ty.type_object
return ty
......
......@@ -85,6 +85,8 @@ class ScoperExprVisitor(ScoperVisitor):
def visit_Call(self, node: ast.Call) -> BaseType:
ftype = self.visit(node.func)
if ftype.typevars:
ftype = ftype.gen_sub(None, {v.name: TypeVariable(v.name) for v in ftype.typevars})
rtype = self.visit_function_call(ftype, [self.visit(arg) for arg in node.args])
actual = rtype
node.is_await = False
......
......@@ -64,6 +64,7 @@ class StdlibVisitor(NodeVisitorSeq):
arg_types = [arg_visitor.visit(arg.annotation or arg.arg) for arg in node.args.args]
ret_type = arg_visitor.visit(node.returns)
ty = FunctionType(arg_types, ret_type)
ty.typevars = arg_visitor.typevars
if node.args.vararg:
ty.variadic = True
if self.cur_class:
......
......@@ -15,6 +15,7 @@ class BaseType(ABC):
members: Dict[str, "BaseType"] = field(default_factory=dict, init=False)
methods: Dict[str, "FunctionType"] = field(default_factory=dict, init=False)
parents: List["BaseType"] = field(default_factory=list, init=False)
typevars: List["TypeVariable"] = field(default_factory=list, init=False)
def get_parents(self) -> List["BaseType"]:
return self.parents
......@@ -40,7 +41,7 @@ class BaseType(ABC):
def contains_internal(self, other: "BaseType") -> bool:
pass
def gen_sub(self, this: "BaseType", typevars) -> "Self":
def gen_sub(self, this: "BaseType", typevars: Dict[str, "BaseType"]) -> "Self":
return self
def to_list(self) -> List["BaseType"]:
......@@ -181,11 +182,11 @@ class TypeOperator(BaseType, ABC):
return hash((self.name, tuple(self.args)))
def gen_sub(self, this: BaseType, typevars) -> "Self":
res = object.__new__(self.__class__)
res = object.__new__(self.__class__) # todo: ugly... should make a clone()
if isinstance(this, TypeOperator):
vardict = dict(zip(typevars.keys(), this.args))
else:
vardict = {}
vardict = typevars
res.args = [arg.resolve().gen_sub(this, vardict) for arg in self.args]
res.name = self.name
res.variadic = self.variadic
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment