Commit 1871431a authored by Tom Niget's avatar Tom Niget

Fix generic substitution for methods

parent 44123167
...@@ -46,7 +46,7 @@ class BaseType(ABC): ...@@ -46,7 +46,7 @@ class BaseType(ABC):
# def clone(self) -> "BaseType": # def clone(self) -> "BaseType":
# pass # pass
def gen_sub(self, this: "BaseType", typevars: Dict[str, "BaseType"]) -> "Self": def gen_sub(self, this: "BaseType", typevars: Dict[str, "BaseType"], cache=None) -> "Self":
return self return self
def to_list(self) -> List["BaseType"]: def to_list(self) -> List["BaseType"]:
...@@ -104,7 +104,7 @@ class TypeVariable(BaseType): ...@@ -104,7 +104,7 @@ class TypeVariable(BaseType):
def contains_internal(self, other: BaseType) -> bool: def contains_internal(self, other: BaseType) -> bool:
return self.resolve() is other.resolve() return self.resolve() is other.resolve()
def gen_sub(self, this: "BaseType", typevars) -> "Self": def gen_sub(self, this: "BaseType", typevars, cache=None) -> "Self":
if match := typevars.get(self.name): if match := typevars.get(self.name):
return match return match
return self return self
...@@ -193,17 +193,22 @@ class TypeOperator(BaseType, ABC): ...@@ -193,17 +193,22 @@ class TypeOperator(BaseType, ABC):
def __hash__(self): def __hash__(self):
return hash((self.name, tuple(self.args))) return hash((self.name, tuple(self.args)))
def gen_sub(self, this: BaseType, typevars) -> "Self": def gen_sub(self, this: BaseType, typevars, cache=None) -> "Self":
cache = cache or {}
if me := cache.get(self):
return me
if len(self.args) == 0: if len(self.args) == 0:
return self return self
res = object.__new__(self.__class__) # todo: ugly... should make a clone() res = object.__new__(self.__class__) # todo: ugly... should make a clone()
cache[self] = res
if isinstance(this, TypeOperator): if isinstance(this, TypeOperator):
vardict = dict(zip(typevars.keys(), this.args)) vardict = dict(zip(typevars.keys(), this.args))
else: else:
vardict = typevars vardict = typevars
for k in dataclasses.fields(self): for k in dataclasses.fields(self):
setattr(res, k.name, getattr(self, k.name)) setattr(res, k.name, getattr(self, k.name))
res.args = [arg.resolve().gen_sub(this, vardict) for arg in self.args] res.args = [arg.resolve().gen_sub(this, vardict, cache) for arg in self.args]
res.methods = {k: v.gen_sub(this, vardict, cache) for k, v in self.methods.items()}
return res return res
def to_list(self) -> List["BaseType"]: def to_list(self) -> List["BaseType"]:
...@@ -269,7 +274,7 @@ TY_NONE = TypeOperator([], "NoneType") ...@@ -269,7 +274,7 @@ TY_NONE = TypeOperator([], "NoneType")
TY_MODULE = TypeOperator([], "module") TY_MODULE = TypeOperator([], "module")
TY_VARARG = TypeOperator([], "vararg") TY_VARARG = TypeOperator([], "vararg")
TY_SELF = TypeOperator([], "Self") TY_SELF = TypeOperator([], "Self")
TY_SELF.gen_sub = lambda this, typevars: this TY_SELF.gen_sub = lambda this, typevars, _: this
class PyList(TypeOperator): class PyList(TypeOperator):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment