Commit 3b3d5c13 authored by Tom Niget's avatar Tom Niget

Fix support for generic parent classes

parent ba47aacb
...@@ -22,6 +22,14 @@ class HasLen: ...@@ -22,6 +22,14 @@ class HasLen:
def len(x: HasLen) -> int: def len(x: HasLen) -> int:
... ...
class Iterator(Generic[U]):
def __iter__(self) -> Self: ...
def __next__(self) -> U: ...
class Iterable(Generic[U]):
def __iter__(self) -> Iterator[U]: ...
class str(HasLen): class str(HasLen):
def find(self, sub: Self) -> int: ... def find(self, sub: Self) -> int: ...
def format(self, *args) -> Self: ... def format(self, *args) -> Self: ...
...@@ -30,7 +38,7 @@ class str(HasLen): ...@@ -30,7 +38,7 @@ class str(HasLen):
class bytes(HasLen): class bytes(HasLen):
def decode(self, encoding: str) -> str: ... def decode(self, encoding: str) -> str: ...
class list(Generic[U], HasLen): class list(Generic[U], HasLen, Iterable[U]):
def __add__(self, other: Self) -> Self: ... def __add__(self, other: Self) -> Self: ...
...@@ -40,16 +48,16 @@ class list(Generic[U], HasLen): ...@@ -40,16 +48,16 @@ class list(Generic[U], HasLen):
def pop(self, index: int = -1) -> U: ... def pop(self, index: int = -1) -> U: ...
assert [1, 2].__iter__()
assert list[int].__iter__
assert(len(["a"])) assert(len(["a"]))
assert [1, 2, 3][1] assert [1, 2, 3][1]
class Iterator(Generic[U]):
def __iter__(self) -> Self: ...
def __next__(self) -> U: ...
def next(it: Iterator[U], default: None) -> U: def next(it: Iterator[U], default: None) -> U:
... ...
......
...@@ -40,7 +40,10 @@ if __name__ == "__main__": ...@@ -40,7 +40,10 @@ if __name__ == "__main__":
print(is_cpp) print(is_cpp)
# TODO: doesn't compile under G++ 12.2, fixed in trunk on March 15 # TODO: doesn't compile under G++ 12.2, fixed in trunk on March 15
# https://gcc.gnu.org/bugzilla/show_bug.cgi?id=98056 # https://gcc.gnu.org/bugzilla/show_bug.cgi?id=98056
sum = 0
for i in range(15):
sum += i
print("C++ " if is_cpp() else "Python", print("C++ " if is_cpp() else "Python",
"res=", 5, ".", True, [4, 5, 6], {7, 8, 9}, [1, 2] + [3, 4], [5, 6] * 3, {1: 7, 9: 3}, 0x55 & 7 == 5, "res=", 5, ".", True, [4, 5, 6], {7, 8, 9}, [1, 2] + [3, 4], [5, 6] * 3, {1: 7, 9: 3}, 0x55 & 7 == 5,
3j) 3j, sum)
print() print()
...@@ -42,15 +42,26 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -42,15 +42,26 @@ class StdlibVisitor(NodeVisitorSeq):
if existing := self.scope.get(node.name): if existing := self.scope.get(node.name):
ty = existing.type ty = existing.type
else: else:
ty = TypeType(TypeOperator([], node.name)) class TheType(TypeOperator):
def __init__(self, *args):
super().__init__(args, node.name)
ty = TypeType(TheType)
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ty) self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ty)
typevars = [] typevars = []
for b in node.bases: for b in node.bases:
if isinstance(b, ast.Subscript) and isinstance(b.value, ast.Name) and b.value.id == "Generic": if isinstance(b, ast.Subscript):
if isinstance(b.slice, ast.Name): if isinstance(b.slice, ast.Name):
typevars = [b.slice.id] sliceval = [b.slice.id]
elif isinstance(b.slice, ast.Tuple): elif isinstance(b.slice, ast.Tuple):
typevars = [n.id for n in b.slice.value.elts] sliceval = [n.id for n in b.slice.value.elts]
if isinstance(b.value, ast.Name) and b.value.id == "Generic":
typevars = sliceval
else:
idxs = [typevars.index(v) for v in sliceval]
parent = self.visit(b.value)
assert isinstance(parent, TypeType)
assert isinstance(ty.type_object, ABCMeta)
ty.type_object.gen_parents.append(lambda selfvars: parent.type_object(*[selfvars[i] for i in idxs]))
else: else:
parent = self.visit(b) parent = self.visit(b)
assert isinstance(parent, TypeType) assert isinstance(parent, TypeType)
...@@ -58,6 +69,8 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -58,6 +69,8 @@ class StdlibVisitor(NodeVisitorSeq):
ty.type_object.gen_parents.append(parent.type_object) ty.type_object.gen_parents.append(parent.type_object)
else: else:
ty.type_object.parents.append(parent.type_object) ty.type_object.parents.append(parent.type_object)
if not typevars and not existing:
ty.type_object = ty.type_object()
cl_scope = self.scope.child(ScopeKind.CLASS) cl_scope = self.scope.child(ScopeKind.CLASS)
visitor = StdlibVisitor(cl_scope, ty) visitor = StdlibVisitor(cl_scope, ty)
for var in typevars: for var in typevars:
......
...@@ -139,7 +139,11 @@ class TypeOperator(BaseType, ABC): ...@@ -139,7 +139,11 @@ class TypeOperator(BaseType, ABC):
self.name = self.__class__.__name__ self.name = self.__class__.__name__
for name, factory in self.gen_methods.items(): for name, factory in self.gen_methods.items():
self.methods[name] = factory(self) self.methods[name] = factory(self)
self.parents.extend(self.gen_parents) for gp in self.gen_parents:
if not isinstance(gp, BaseType):
gp = gp(self.args)
self.parents.append(gp)
self.methods = {**gp.methods, **self.methods}
def unify_internal(self, other: BaseType): def unify_internal(self, other: BaseType):
if not isinstance(other, TypeOperator): if not isinstance(other, TypeOperator):
...@@ -300,6 +304,10 @@ class TypeType(TypeOperator): ...@@ -300,6 +304,10 @@ class TypeType(TypeOperator):
def type_object(self) -> BaseType: def type_object(self) -> BaseType:
return self.args[0] return self.args[0]
@type_object.setter
def type_object(self, value: BaseType):
self.args[0] = value
TY_TYPE = TypeOperator.make_type("type") TY_TYPE = TypeOperator.make_type("type")
TY_INT = TypeOperator.make_type("int") TY_INT = TypeOperator.make_type("int")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment