Commit b3e3696f authored by Tom Niget's avatar Tom Niget

Add initial (WIP) support for protocols

parent 5c8b8830
from typing import Self, TypeVar, Generic from typing import Self, TypeVar, Generic, Protocol
class int: class int:
def __add__(self, other: Self) -> Self: ... def __add__(self, other: Self) -> Self: ...
...@@ -17,29 +17,31 @@ V = TypeVar("V") ...@@ -17,29 +17,31 @@ V = TypeVar("V")
# TODO: really, these should work as interfaces, on a duck-typing basis. it's gonna be a hell of a ride to implement # TODO: really, these should work as interfaces, on a duck-typing basis. it's gonna be a hell of a ride to implement
# unification for this # unification for this
class HasLen: class HasLen(Protocol):
def __len__(self) -> int: ... def __len__(self) -> int: ...
def len(x: HasLen) -> int: def len(x: HasLen) -> int:
... ...
class Iterator(Generic[U]): class Iterator(Protocol[U]):
def __iter__(self) -> Self: ... def __iter__(self) -> Self: ...
def __next__(self) -> U: ... def __next__(self) -> U: ...
class Iterable(Generic[U]): class Iterable(Protocol[U]):
def __iter__(self) -> Iterator[U]: ... def __iter__(self) -> Iterator[U]: ...
class str(HasLen): class str:
def find(self, sub: Self) -> int: ... def find(self, sub: Self) -> int: ...
def format(self, *args) -> Self: ... def format(self, *args) -> Self: ...
def encode(self, encoding: Self) -> bytes: ... def encode(self, encoding: Self) -> bytes: ...
def __len__(self) -> int: ...
class bytes(HasLen): class bytes:
def decode(self, encoding: str) -> str: ... def decode(self, encoding: str) -> str: ...
def __len__(self) -> int: ...
class list(Generic[U], HasLen, Iterable[U]): class list(Generic[U]):
def __add__(self, other: Self) -> Self: ... def __add__(self, other: Self) -> Self: ...
...@@ -48,6 +50,8 @@ class list(Generic[U], HasLen, Iterable[U]): ...@@ -48,6 +50,8 @@ class list(Generic[U], HasLen, Iterable[U]):
def __getitem__(self, index: int) -> U: ... def __getitem__(self, index: int) -> U: ...
def pop(self, index: int = -1) -> U: ... def pop(self, index: int = -1) -> U: ...
def __iter__(self) -> Iterator[U]: ...
def __len__(self) -> int: ...
assert [1, 2].__iter__() assert [1, 2].__iter__()
assert list[int].__iter__ assert list[int].__iter__
......
...@@ -56,6 +56,9 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -56,6 +56,9 @@ class StdlibVisitor(NodeVisitorSeq):
sliceval = [n.id for n in b.slice.value.elts] sliceval = [n.id for n in b.slice.value.elts]
if isinstance(b.value, ast.Name) and b.value.id == "Generic": if isinstance(b.value, ast.Name) and b.value.id == "Generic":
typevars = sliceval typevars = sliceval
elif isinstance(b.value, ast.Name) and b.value.id == "Protocol":
typevars = sliceval
ty.type_object.is_protocol_gen = True
else: else:
idxs = [typevars.index(v) for v in sliceval] idxs = [typevars.index(v) for v in sliceval]
parent = self.visit(b.value) parent = self.visit(b.value)
...@@ -63,12 +66,15 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -63,12 +66,15 @@ class StdlibVisitor(NodeVisitorSeq):
assert isinstance(ty.type_object, ABCMeta) assert isinstance(ty.type_object, ABCMeta)
ty.type_object.gen_parents.append(lambda selfvars: parent.type_object(*[selfvars[i] for i in idxs])) ty.type_object.gen_parents.append(lambda selfvars: parent.type_object(*[selfvars[i] for i in idxs]))
else: else:
parent = self.visit(b) if isinstance(b, ast.Name) and b.id == "Protocol":
assert isinstance(parent, TypeType) ty.type_object.is_protocol_gen = True
if isinstance(ty.type_object, ABCMeta):
ty.type_object.gen_parents.append(parent.type_object)
else: else:
ty.type_object.parents.append(parent.type_object) parent = self.visit(b)
assert isinstance(parent, TypeType)
if isinstance(ty.type_object, ABCMeta):
ty.type_object.gen_parents.append(parent.type_object)
else:
ty.type_object.parents.append(parent.type_object)
if not typevars and not existing: if not typevars and not existing:
ty.type_object = ty.type_object() ty.type_object = ty.type_object()
cl_scope = self.scope.child(ScopeKind.CLASS) cl_scope = self.scope.child(ScopeKind.CLASS)
......
...@@ -121,6 +121,9 @@ class TypeOperator(BaseType, ABC): ...@@ -121,6 +121,9 @@ class TypeOperator(BaseType, ABC):
optional_at: Optional[int] = None optional_at: Optional[int] = None
gen_methods: ClassVar[Dict[str, GenMethodFactory]] = {} gen_methods: ClassVar[Dict[str, GenMethodFactory]] = {}
gen_parents: ClassVar[List[BaseType]] = [] gen_parents: ClassVar[List[BaseType]] = []
is_protocol: bool = False
is_protocol_gen: ClassVar[bool] = False
match_cache: set["TypeOperator"] = field(default_factory=set, init=False)
@staticmethod @staticmethod
def make_type(name: str): def make_type(name: str):
...@@ -144,12 +147,31 @@ class TypeOperator(BaseType, ABC): ...@@ -144,12 +147,31 @@ class TypeOperator(BaseType, ABC):
gp = gp(self.args) gp = gp(self.args)
self.parents.append(gp) self.parents.append(gp)
self.methods = {**gp.methods, **self.methods} self.methods = {**gp.methods, **self.methods}
self.is_protocol = self.is_protocol or self.is_protocol_gen
def matches_protocol(self, protocol: "TypeOperator"):
if hash(protocol) in self.match_cache:
return
try:
dupl = protocol.gen_sub(self, {v.name: (TypeVariable(v.name) if isinstance(v.resolve(), TypeVariable) else v) for v in protocol.args})
self.match_cache.add(hash(protocol))
for name, ty in dupl.methods.items():
corresp = self.methods[name]
corresp.remove_self().unify(ty.remove_self())
except Exception as e:
self.match_cache.remove(hash(protocol))
raise IncompatibleTypesError(f"Type {self} doesn't implement protocol {protocol}: {e}")
def unify_internal(self, other: BaseType): def unify_internal(self, other: BaseType):
if not isinstance(other, TypeOperator): if not isinstance(other, TypeOperator):
raise IncompatibleTypesError() raise IncompatibleTypesError()
if other.is_protocol and not self.is_protocol:
return other.unify_internal(self)
if self.is_protocol and not other.is_protocol:
return other.matches_protocol(self)
if len(self.args) < len(other.args): if len(self.args) < len(other.args):
return other.unify_internal(self) return other.unify_internal(self)
assert self.is_protocol == other.is_protocol
if type(self) != type(other): if type(self) != type(other):
for parent in other.get_parents(): for parent in other.get_parents():
try: try:
...@@ -211,6 +233,7 @@ class TypeOperator(BaseType, ABC): ...@@ -211,6 +233,7 @@ class TypeOperator(BaseType, ABC):
res.args = [arg.resolve().gen_sub(this, vardict, cache) for arg in self.args] res.args = [arg.resolve().gen_sub(this, vardict, cache) for arg in self.args]
res.methods = {k: v.gen_sub(this, vardict, cache) for k, v in self.methods.items()} res.methods = {k: v.gen_sub(this, vardict, cache) for k, v in self.methods.items()}
res.parents = [p.gen_sub(this, vardict, cache) for p in self.parents] res.parents = [p.gen_sub(this, vardict, cache) for p in self.parents]
res.is_protocol = self.is_protocol
return res return res
def to_list(self) -> List["BaseType"]: def to_list(self) -> List["BaseType"]:
...@@ -360,6 +383,9 @@ class Promise(TypeOperator, ABC): ...@@ -360,6 +383,9 @@ class Promise(TypeOperator, ABC):
@kind.setter @kind.setter
def kind(self, value: PromiseKind): def kind(self, value: PromiseKind):
if value == PromiseKind.GENERATOR:
self.methods["__iter__"] = FunctionType([], self)
self.methods["__next__"] = FunctionType([], self.return_type)
self.args[1].val = value self.args[1].val = value
def __str__(self): def __str__(self):
...@@ -367,7 +393,7 @@ class Promise(TypeOperator, ABC): ...@@ -367,7 +393,7 @@ class Promise(TypeOperator, ABC):
def get_parents(self) -> List["BaseType"]: def get_parents(self) -> List["BaseType"]:
if self.kind == PromiseKind.GENERATOR: if self.kind == PromiseKind.GENERATOR:
return [PyIterator(self.return_type), *super().get_parents()] return [*super().get_parents()]
return super().get_parents() return super().get_parents()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment