From b1ddc670c5d0ec17169305724ba6863832f8fb4a Mon Sep 17 00:00:00 2001
From: Tom Niget <tom.niget@etu.univ-cotedazur.fr>
Date: Mon, 29 May 2023 13:17:53 +0200
Subject: [PATCH] Fix unification checks and type instantiation

---
 trans/stdlib/__init__.py                   | 25 +++++++++++++---
 trans/transpiler/phases/typing/__init__.py |  5 ++--
 trans/transpiler/phases/typing/expr.py     | 11 ++++++--
 trans/transpiler/phases/typing/types.py    | 33 ++++++++++++++++++----
 4 files changed, 60 insertions(+), 14 deletions(-)

diff --git a/trans/stdlib/__init__.py b/trans/stdlib/__init__.py
index 669b6e4..fff21f9 100644
--- a/trans/stdlib/__init__.py
+++ b/trans/stdlib/__init__.py
@@ -1,4 +1,4 @@
-from typing import Self, TypeVar, Generic
+from typing import Self, TypeVar, Generic, Tuple
 
 
 class int:
@@ -13,6 +13,7 @@ class int:
 
 assert int.__add__
 U = TypeVar("U")
+V = TypeVar("V")
 
 
 class list(Generic[U]):
@@ -44,13 +45,29 @@ def identity(x: U) -> U:
 assert identity(1)
 assert identity("a")
 
+def identity_2(x: U, y: V) -> Tuple[U, V]:
+    ...
+
+assert list.__add__
+assert list.__add__([5], [[6][0]])
+assert list[U].__add__
+assert list[U].__add__([1], [2])
+assert list[U].__add__
+assert list[int].__add__
+assert identity_2(1, "a")
+assert lambda x, y: identity_2(x, y)
+assert lambda x: identity_2(x, x)
+
 def print(*args) -> None: ...
 
 
 def range(*args) -> Iterator[int]: ...
-def rangeb(*args) -> Iterator[bool]: ...
 
+assert [].__add__
 assert [6].__add__
 assert [True].__add__
-assert next(range(6), None)
-assert next(rangeb(6), None)
\ No newline at end of file
+assert lambda x: [x].__add__
+
+
+
+assert next(range(6), None)
\ No newline at end of file
diff --git a/trans/transpiler/phases/typing/__init__.py b/trans/transpiler/phases/typing/__init__.py
index f06186c..b41707d 100644
--- a/trans/transpiler/phases/typing/__init__.py
+++ b/trans/transpiler/phases/typing/__init__.py
@@ -4,7 +4,7 @@ from pathlib import Path
 from transpiler.phases.typing.scope import VarKind, VarDecl, ScopeKind
 from transpiler.phases.typing.stdlib import PRELUDE, StdlibVisitor
 from transpiler.phases.typing.types import TY_TYPE, TY_INT, TY_STR, TY_BOOL, TY_COMPLEX, TY_NONE, FunctionType, \
-    TypeVariable, TY_MODULE, CppType, PyList, TypeType, Forked, Task, Future, PyIterator
+    TypeVariable, TY_MODULE, CppType, PyList, TypeType, Forked, Task, Future, PyIterator, TupleType
 
 PRELUDE.vars.update({
     # "int": VarDecl(VarKind.LOCAL, TY_TYPE, TY_INT),
@@ -28,7 +28,8 @@ PRELUDE.vars.update({
     "Forked": VarDecl(VarKind.LOCAL, TypeType(Forked)),
     "Task": VarDecl(VarKind.LOCAL, TypeType(Task)),
     "Future": VarDecl(VarKind.LOCAL, TypeType(Future)),
-    "Iterator": VarDecl(VarKind.LOCAL, TypeType(PyIterator))
+    "Iterator": VarDecl(VarKind.LOCAL, TypeType(PyIterator)),
+    "Tuple": VarDecl(VarKind.LOCAL, TypeType(TupleType)),
 })
 
 typon_std = Path(__file__).parent.parent.parent.parent / "stdlib"
diff --git a/trans/transpiler/phases/typing/expr.py b/trans/transpiler/phases/typing/expr.py
index fba02e7..1ba8f62 100644
--- a/trans/transpiler/phases/typing/expr.py
+++ b/trans/transpiler/phases/typing/expr.py
@@ -1,5 +1,6 @@
 import abc
 import ast
+import inspect
 from typing import List
 
 from transpiler.phases.typing import ScopeKind, VarDecl, VarKind
@@ -108,7 +109,7 @@ class ScoperExprVisitor(ScoperVisitor):
         try:
             ftype.unify(equivalent)
         except IncompatibleTypesError as e:
-            raise IncompatibleTypesError(f"Cannot call {ftype} with {equivalent}: {e}")
+            raise IncompatibleTypesError(f"Cannot call {ftype} with ({(', '.join(map(str, arguments)))}): {e}")
         return ftype.return_type
 
     def visit_Lambda(self, node: ast.Lambda) -> BaseType:
@@ -147,6 +148,12 @@ class ScoperExprVisitor(ScoperVisitor):
         if isinstance(ltype, TypeType):
             ltype = ltype.type_object
             bound = False
+        if isinstance(ltype, abc.ABCMeta):
+            ctor = ltype.__init__
+            args = list(inspect.signature(ctor).parameters.values())[1:]
+            if not all(arg.annotation == BaseType for arg in args):
+                raise NotImplementedError("I don't know how to handle this type")
+            ltype = ltype(*(TypeVariable() for _ in args))
         if attr := ltype.members.get(name):
             return attr
         if meth := ltype.methods.get(name):
@@ -187,7 +194,7 @@ class ScoperExprVisitor(ScoperVisitor):
         args = [self.visit(e) for e in args]
         if isinstance(left, TypeType) and isinstance(left.type_object, abc.ABCMeta):
             # generic
-            return TypeType(left.type_object(*[arg.type_object for arg in args]))
+            return TypeType(left.type_object(*[arg.type_object if isinstance(arg, TypeType) else arg for arg in args]))
             pass
         return self.make_dunder([left, *args], "getitem")
 
diff --git a/trans/transpiler/phases/typing/types.py b/trans/transpiler/phases/typing/types.py
index b47e11a..efb2cc4 100644
--- a/trans/transpiler/phases/typing/types.py
+++ b/trans/transpiler/phases/typing/types.py
@@ -1,3 +1,4 @@
+import dataclasses
 import typing
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
@@ -41,6 +42,10 @@ class BaseType(ABC):
     def contains_internal(self, other: "BaseType") -> bool:
         pass
 
+    # @abstractmethod
+    # def clone(self) -> "BaseType":
+    #     pass
+
     def gen_sub(self, this: "BaseType", typevars: Dict[str, "BaseType"]) -> "Self":
         return self
 
@@ -49,6 +54,8 @@ class BaseType(ABC):
 
 
 T = typing.TypeVar("T")
+
+
 class MagicType(BaseType, typing.Generic[T]):
     val: T
 
@@ -66,6 +73,9 @@ class MagicType(BaseType, typing.Generic[T]):
     def __str__(self):
         return str(self.val)
 
+    def clone(self) -> "BaseType":
+        return MagicType(self.val)
+
 
 cur_var = 0
 
@@ -161,6 +171,9 @@ class TypeOperator(BaseType, ABC):
 
             if not (self.variadic or other.variadic):
                 raise IncompatibleTypesError(f"Cannot unify {self} and {other} with different number of arguments")
+        if len(self.args) == 0:
+            if self.name != other.name:
+                raise IncompatibleTypesError(f"Cannot unify {self} and {other}")
         for i, (a, b) in enumerate(zip_longest(self.args, other.args)):
             if a is None and self.variadic or b is None and other.variadic:
                 continue
@@ -171,7 +184,6 @@ class TypeOperator(BaseType, ABC):
                 if a != b:
                     raise IncompatibleTypesError(f"Cannot unify {a} and {b}")
 
-
     def contains_internal(self, other: "BaseType") -> bool:
         return any(arg.contains(other) for arg in self.args)
 
@@ -182,20 +194,23 @@ class TypeOperator(BaseType, ABC):
         return hash((self.name, tuple(self.args)))
 
     def gen_sub(self, this: BaseType, typevars) -> "Self":
+        if len(self.args) == 0:
+            return self
         res = object.__new__(self.__class__)  # todo: ugly... should make a clone()
         if isinstance(this, TypeOperator):
             vardict = dict(zip(typevars.keys(), this.args))
         else:
             vardict = typevars
+        for k in dataclasses.fields(self):
+            setattr(res, k.name, getattr(self, k.name))
         res.args = [arg.resolve().gen_sub(this, vardict) for arg in self.args]
-        res.name = self.name
-        res.variadic = self.variadic
         return res
 
     def to_list(self) -> List["BaseType"]:
         return [self, *self.args]
 
 
+
 class FunctionType(TypeOperator):
     def __init__(self, args: List[BaseType], ret: BaseType):
         super().__init__([ret, *args])
@@ -287,6 +302,7 @@ class PyDict(TypeOperator):
     def value_type(self):
         return self.args[1]
 
+
 class PyIterator(TypeOperator):
     def __init__(self, arg: BaseType):
         super().__init__([arg], "iter")
@@ -296,9 +312,8 @@ class PyIterator(TypeOperator):
         return self.args[0]
 
 
-
 class TupleType(TypeOperator):
-    def __init__(self, args: List[BaseType]):
+    def __init__(self, *args: List[BaseType]):
         super().__init__(args, "tuple")
 
 
@@ -334,17 +349,23 @@ class Promise(TypeOperator, ABC):
             return [PyIterator(self.return_type), *super().get_parents()]
         return super().get_parents()
 
+
 class Forked(Promise):
     """Only use this for type specs"""
+
     def __init__(self, ret: BaseType):
         super().__init__(ret, PromiseKind.FORKED)
 
+
 class Task(Promise):
     """Only use this for type specs"""
+
     def __init__(self, ret: BaseType):
         super().__init__(ret, PromiseKind.TASK)
 
+
 class Future(Promise):
     """Only use this for type specs"""
+
     def __init__(self, ret: BaseType):
-        super().__init__(ret, PromiseKind.FUTURE)
\ No newline at end of file
+        super().__init__(ret, PromiseKind.FUTURE)
-- 
2.30.9