Commit 32a6dcfe authored by Tom Niget's avatar Tom Niget

Merge .members and .methods; fix unification for hierarchy lookup

parent e2134ee5
# coding: utf-8
from enum import Enum
class TokenType(Enum):
NUMBER = 1
PARENTHESIS = 2
OPERATION = 3
if __name__ == "__main__":
x = TokenType.NUMBER
\ No newline at end of file
...@@ -124,11 +124,7 @@ class BlockVisitor(NodeVisitor): ...@@ -124,11 +124,7 @@ class BlockVisitor(NodeVisitor):
def visit_ClassDef(self, node: ast.ClassDef): def visit_ClassDef(self, node: ast.ClassDef):
yield from () yield from ()
def check(self, f): has_return = ReturnVisitor().match(node.body)
for b in node.body:
yield from self.match(node)
has_return = next(ReturnVisitor().check(node), False)
yield from self.visit_func_decls(node.body, inner_scope) yield from self.visit_func_decls(node.body, inner_scope)
......
...@@ -29,12 +29,12 @@ class ClassVisitor(NodeVisitor): ...@@ -29,12 +29,12 @@ class ClassVisitor(NodeVisitor):
yield "int value;" yield "int value;"
yield "operator int() const { return value; }" yield "operator int() const { return value; }"
yield "void py_repr(std::ostream &s) const {" yield "void py_repr(std::ostream &s) const {"
yield f's << "{node.name}." << value;' yield f's << "{node.name}.";'
yield "}" yield "}"
else: else:
yield "void py_repr(std::ostream &s) const {" yield "void py_repr(std::ostream &s) const {"
yield "s << '{';" yield "s << '{';"
for i, (name, memb) in enumerate(node.type.members.items()): for i, (name, memb) in enumerate(node.type.fields.items()):
if i != 0: if i != 0:
yield 's << ", ";' yield 's << ", ";'
yield f's << "\\"{name}\\": ";' yield f's << "\\"{name}\\": ";'
...@@ -63,8 +63,8 @@ class ClassInnerVisitor(NodeVisitor): ...@@ -63,8 +63,8 @@ class ClassInnerVisitor(NodeVisitor):
scope: Scope scope: Scope
def visit_AnnAssign(self, node: ast.AnnAssign) -> Iterable[str]: def visit_AnnAssign(self, node: ast.AnnAssign) -> Iterable[str]:
member = self.scope.obj_type.members[node.target.id] member = self.scope.obj_type.fields[node.target.id]
yield from self.visit(member) yield from self.visit(member.type)
yield node.target.id yield node.target.id
yield ";" yield ";"
......
...@@ -26,9 +26,9 @@ class ModuleVisitor(BlockVisitor): ...@@ -26,9 +26,9 @@ class ModuleVisitor(BlockVisitor):
yield f"namespace py_{concrete} {{" yield f"namespace py_{concrete} {{"
yield f"struct {concrete}_t {{" yield f"struct {concrete}_t {{"
for name, obj in alias.module_obj.members.items(): for name, obj in alias.module_obj.fields.items():
if obj.python_func_used: if obj.type.python_func_used:
yield from self.emit_python_func(alias.name, name, name, obj) yield from self.emit_python_func(alias.name, name, name, obj.type)
yield "} all;" yield "} all;"
yield f"auto& get_all() {{ return all; }}" yield f"auto& get_all() {{ return all; }}"
......
...@@ -15,4 +15,6 @@ class SearchVisitor(ast.NodeVisitor): ...@@ -15,4 +15,6 @@ class SearchVisitor(ast.NodeVisitor):
yield from self.visit(value) yield from self.visit(value)
def match(self, node) -> bool: def match(self, node) -> bool:
if type(node) == list:
return any(self.match(n) for n in node)
return next(self.visit(node), False) return next(self.visit(node), False)
...@@ -5,7 +5,7 @@ from transpiler.phases.typing.scope import VarKind, VarDecl, ScopeKind, Scope ...@@ -5,7 +5,7 @@ from transpiler.phases.typing.scope import VarKind, VarDecl, ScopeKind, Scope
from transpiler.phases.typing.stdlib import PRELUDE, StdlibVisitor from transpiler.phases.typing.stdlib import PRELUDE, StdlibVisitor
from transpiler.phases.typing.types import TY_TYPE, TY_INT, TY_STR, TY_BOOL, TY_COMPLEX, TY_NONE, FunctionType, \ from transpiler.phases.typing.types import TY_TYPE, TY_INT, TY_STR, TY_BOOL, TY_COMPLEX, TY_NONE, FunctionType, \
TypeVariable, CppType, PyList, TypeType, Forked, Task, Future, PyIterator, TupleType, TypeOperator, BaseType, \ TypeVariable, CppType, PyList, TypeType, Forked, Task, Future, PyIterator, TupleType, TypeOperator, BaseType, \
ModuleType, TY_BYTES, TY_FLOAT, PyDict, TY_SLICE, TY_OBJECT, BuiltinFeature, UnionType ModuleType, TY_BYTES, TY_FLOAT, PyDict, TY_SLICE, TY_OBJECT, BuiltinFeature, UnionType, MemberDef
PRELUDE.vars.update({ PRELUDE.vars.update({
# "int": VarDecl(VarKind.LOCAL, TY_TYPE, TY_INT), # "int": VarDecl(VarKind.LOCAL, TY_TYPE, TY_INT),
...@@ -46,7 +46,7 @@ typon_std = Path(__file__).parent.parent.parent.parent / "stdlib" ...@@ -46,7 +46,7 @@ typon_std = Path(__file__).parent.parent.parent.parent / "stdlib"
def make_module(name: str, scope: Scope) -> BaseType: def make_module(name: str, scope: Scope) -> BaseType:
ty = ModuleType([], f"{name}") ty = ModuleType([], f"{name}")
for n, v in scope.vars.items(): for n, v in scope.vars.items():
ty.members[n] = v.type ty.fields[n] = MemberDef(v.type, v.val, False)
return ty return ty
......
...@@ -57,7 +57,7 @@ class TypeAnnotationVisitor(NodeVisitorSeq): ...@@ -57,7 +57,7 @@ class TypeAnnotationVisitor(NodeVisitorSeq):
def visit_Attribute(self, node: ast.Attribute) -> BaseType: def visit_Attribute(self, node: ast.Attribute) -> BaseType:
left = self.visit(node.value) left = self.visit(node.value)
res = left.members[node.attr] res = left.fields[node.attr].type
assert isinstance(res, TypeType) assert isinstance(res, TypeType)
return res.type_object return res.type_object
......
...@@ -11,7 +11,8 @@ from transpiler.phases.typing.expr import ScoperExprVisitor, DUNDER ...@@ -11,7 +11,8 @@ from transpiler.phases.typing.expr import ScoperExprVisitor, DUNDER
from transpiler.phases.typing.class_ import ScoperClassVisitor from transpiler.phases.typing.class_ import ScoperClassVisitor
from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind, Scope from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind, Scope
from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, \ from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, \
Promise, TY_NONE, PromiseKind, TupleType, UserType, TypeType, ModuleType, BuiltinFeature, TY_INT Promise, TY_NONE, PromiseKind, TupleType, UserType, TypeType, ModuleType, BuiltinFeature, TY_INT, MemberDef, \
RuntimeValue
from transpiler.phases.utils import PlainBlock, AnnotationName from transpiler.phases.utils import PlainBlock, AnnotationName
...@@ -167,7 +168,7 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -167,7 +168,7 @@ class ScoperBlockVisitor(ScoperVisitor):
init_method = ast.FunctionDef( init_method = ast.FunctionDef(
name="__init__", name="__init__",
args=ast.arguments( args=ast.arguments(
args=[ast.arg(arg="self"), * [ast.arg(arg=n) for n in ctype.members]], args=[ast.arg(arg="self"), * [ast.arg(arg=n) for n in ctype.get_members()]],
defaults=[], defaults=[],
kw_defaults=[], kw_defaults=[],
kwarg=None, kwarg=None,
...@@ -179,7 +180,7 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -179,7 +180,7 @@ class ScoperBlockVisitor(ScoperVisitor):
targets=[ast.Attribute(value=ast.Name(id="self"), attr=n)], targets=[ast.Attribute(value=ast.Name(id="self"), attr=n)],
value=ast.Name(id=n), value=ast.Name(id=n),
**lnd **lnd
) for n in ctype.members ) for n in ctype.get_members()
], ],
decorator_list=[], decorator_list=[],
returns=None, returns=None,
...@@ -195,9 +196,11 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -195,9 +196,11 @@ class ScoperBlockVisitor(ScoperVisitor):
base = self.expr().visit(base) base = self.expr().visit(base)
if is_builtin(base, "Enum"): if is_builtin(base, "Enum"):
ctype.parents.append(TY_INT) ctype.parents.append(TY_INT)
for k in ctype.members: for k, m in ctype.fields.items():
ctype.members[k] = ctype m.type = ctype
ctype.members["value"] = TY_INT m.val = ast.literal_eval(m.val)
assert type(m.val) == int
ctype.fields["value"] = MemberDef(TY_INT)
lnd = linenodata(node) lnd = linenodata(node)
init_method = ast.FunctionDef( init_method = ast.FunctionDef(
name="__init__", name="__init__",
......
...@@ -4,7 +4,7 @@ from dataclasses import dataclass, field ...@@ -4,7 +4,7 @@ from dataclasses import dataclass, field
from transpiler.phases.typing import FunctionType, ScopeKind, VarDecl, VarKind, TY_NONE from transpiler.phases.typing import FunctionType, ScopeKind, VarDecl, VarKind, TY_NONE
from transpiler.phases.typing.common import ScoperVisitor from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.types import PromiseKind, Promise, BaseType from transpiler.phases.typing.types import PromiseKind, Promise, BaseType, MemberDef
@dataclass @dataclass
...@@ -15,15 +15,15 @@ class ScoperClassVisitor(ScoperVisitor): ...@@ -15,15 +15,15 @@ class ScoperClassVisitor(ScoperVisitor):
assert node.value is None, "Class field should not have a value" assert node.value is None, "Class field should not have a value"
assert node.simple == 1, "Class field should be simple (identifier, not parenthesized)" assert node.simple == 1, "Class field should be simple (identifier, not parenthesized)"
assert isinstance(node.target, ast.Name) assert isinstance(node.target, ast.Name)
self.scope.obj_type.members[node.target.id] = self.visit_annotation(node.annotation) self.scope.obj_type.fields[node.target.id] = MemberDef(self.visit_annotation(node.annotation))
def visit_Assign(self, node: ast.Assign): def visit_Assign(self, node: ast.Assign):
assert len(node.targets) == 1, "Class field should be assigned to only once" assert len(node.targets) == 1, "Can't use destructuring in class static member"
assert isinstance(node.targets[0], ast.Name) assert isinstance(node.targets[0], ast.Name)
node.is_declare = True node.is_declare = True
valtype = self.expr().visit(node.value) valtype = self.expr().visit(node.value)
node.targets[0].type = valtype node.targets[0].type = valtype
self.scope.obj_type.members[node.targets[0].id] = valtype self.scope.obj_type.fields[node.targets[0].id] = MemberDef(valtype, node.value)
def visit_FunctionDef(self, node: ast.FunctionDef): def visit_FunctionDef(self, node: ast.FunctionDef):
ftype = self.parse_function(node) ftype = self.parse_function(node)
...@@ -32,5 +32,5 @@ class ScoperClassVisitor(ScoperVisitor): ...@@ -32,5 +32,5 @@ class ScoperClassVisitor(ScoperVisitor):
if node.name != "__init__": if node.name != "__init__":
ftype.return_type = Promise(ftype.return_type, PromiseKind.TASK) ftype.return_type = Promise(ftype.return_type, PromiseKind.TASK)
ftype.is_method = True ftype.is_method = True
self.scope.obj_type.methods[node.name] = ftype self.scope.obj_type.fields[node.name] = MemberDef(ftype, node)
return (node, inner) return (node, inner)
...@@ -108,7 +108,7 @@ class ScoperVisitor(NodeVisitorSeq): ...@@ -108,7 +108,7 @@ class ScoperVisitor(NodeVisitorSeq):
def get_iter(seq_type): def get_iter(seq_type):
try: try:
iter_type = seq_type.methods["__iter__"].return_type iter_type = seq_type.fields["__iter__"].type.return_type
except: except:
from transpiler.phases.typing.exceptions import NotIterableError from transpiler.phases.typing.exceptions import NotIterableError
raise NotIterableError(seq_type) raise NotIterableError(seq_type)
...@@ -116,7 +116,7 @@ def get_iter(seq_type): ...@@ -116,7 +116,7 @@ def get_iter(seq_type):
def get_next(iter_type): def get_next(iter_type):
try: try:
next_type = iter_type.methods["__next__"].return_type next_type = iter_type.fields["__next__"].type.return_type
except: except:
from transpiler.phases.typing.exceptions import NotIteratorError from transpiler.phases.typing.exceptions import NotIteratorError
raise NotIteratorError(iter_type) raise NotIteratorError(iter_type)
......
...@@ -174,6 +174,11 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -174,6 +174,11 @@ class ScoperExprVisitor(ScoperVisitor):
def visit_getattr(self, ltype: BaseType, name: str) -> BaseType: def visit_getattr(self, ltype: BaseType, name: str) -> BaseType:
bound = True bound = True
if isinstance(ltype, TypeType): if isinstance(ltype, TypeType):
# if mdecl := ltype.static_members.get(name):
# attr = mdecl.type
# if getattr(attr, "is_python_func", False):
# attr.python_func_used = True
# return attr
ltype = ltype.type_object ltype = ltype.type_object
bound = False bound = False
if isinstance(ltype, abc.ABCMeta): if isinstance(ltype, abc.ABCMeta):
...@@ -182,16 +187,28 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -182,16 +187,28 @@ class ScoperExprVisitor(ScoperVisitor):
if not all(arg.annotation == BaseType for arg in args): if not all(arg.annotation == BaseType for arg in args):
raise NotImplementedError("I don't know how to handle this type") raise NotImplementedError("I don't know how to handle this type")
ltype = ltype(*(TypeVariable() for _ in args)) ltype = ltype(*(TypeVariable() for _ in args))
if attr := ltype.members.get(name): # if mdecl := ltype.members.get(name):
if getattr(attr, "is_python_func", False): # attr = mdecl.type
attr.python_func_used = True # if getattr(attr, "is_python_func", False):
return attr # attr.python_func_used = True
if meth := ltype.methods.get(name): # return attr
meth = meth.gen_sub(ltype, {}) # if meth := ltype.methods.get(name):
if bound: # meth = meth.gen_sub(ltype, {})
return meth.remove_self() # if bound:
else: # return meth.remove_self()
return meth # else:
# return meth
if field := ltype.fields.get(name):
ty = field.type
if getattr(ty, "is_python_func", False):
ty.python_func_used = True
if isinstance(ty, FunctionType):
ty = ty.gen_sub(ltype, {})
if bound and field.in_class_def:
return ty.remove_self()
return ty
from transpiler.phases.typing.exceptions import MissingAttributeError from transpiler.phases.typing.exceptions import MissingAttributeError
parents = ltype.iter_hierarchy_recursive() parents = ltype.iter_hierarchy_recursive()
next(parents) next(parents)
......
...@@ -3,7 +3,7 @@ from dataclasses import field, dataclass ...@@ -3,7 +3,7 @@ from dataclasses import field, dataclass
from enum import Enum from enum import Enum
from typing import Optional, Dict, List, Any from typing import Optional, Dict, List, Any
from transpiler.phases.typing.types import BaseType from transpiler.phases.typing.types import BaseType, RuntimeValue
class VarKind(Enum): class VarKind(Enum):
...@@ -23,10 +23,6 @@ class VarType: ...@@ -23,10 +23,6 @@ class VarType:
pass pass
class RuntimeValue:
pass
@dataclass @dataclass
class VarDecl: class VarDecl:
kind: VarKind kind: VarKind
......
...@@ -8,7 +8,8 @@ from transpiler.phases.typing.annotations import TypeAnnotationVisitor ...@@ -8,7 +8,8 @@ from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.common import PRELUDE from transpiler.phases.typing.common import PRELUDE
from transpiler.phases.typing.expr import ScoperExprVisitor from transpiler.phases.typing.expr import ScoperExprVisitor
from transpiler.phases.typing.scope import Scope, VarDecl, VarKind, ScopeKind from transpiler.phases.typing.scope import Scope, VarDecl, VarKind, ScopeKind
from transpiler.phases.typing.types import BaseType, TypeOperator, FunctionType, TY_VARARG, TypeType, TypeVariable from transpiler.phases.typing.types import BaseType, TypeOperator, FunctionType, TY_VARARG, TypeType, TypeVariable, \
MemberDef
from transpiler.phases.utils import NodeVisitorSeq from transpiler.phases.utils import NodeVisitorSeq
...@@ -36,7 +37,7 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -36,7 +37,7 @@ class StdlibVisitor(NodeVisitorSeq):
if isinstance(self.cur_class.type_object, ABCMeta): if isinstance(self.cur_class.type_object, ABCMeta):
raise NotImplementedError raise NotImplementedError
else: else:
self.cur_class.type_object.members[node.target.id] = ty.gen_sub(self.cur_class.type_object, self.typevars) self.cur_class.type_object.fields[node.target.id] = MemberDef(ty.gen_sub(self.cur_class.type_object, self.typevars))
self.scope.vars[node.target.id] = VarDecl(VarKind.LOCAL, ty) self.scope.vars[node.target.id] = VarDecl(VarKind.LOCAL, ty)
def visit_ImportFrom(self, node: ast.ImportFrom): def visit_ImportFrom(self, node: ast.ImportFrom):
...@@ -110,7 +111,7 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -110,7 +111,7 @@ class StdlibVisitor(NodeVisitorSeq):
if isinstance(self.cur_class.type_object, ABCMeta): if isinstance(self.cur_class.type_object, ABCMeta):
self.cur_class.type_object.gen_methods[node.name] = lambda t: ty.gen_sub(t, self.typevars) self.cur_class.type_object.gen_methods[node.name] = lambda t: ty.gen_sub(t, self.typevars)
else: else:
self.cur_class.type_object.methods[node.name] = ty.gen_sub(self.cur_class.type_object, self.typevars) self.cur_class.type_object.fields[node.name] = MemberDef(ty.gen_sub(self.cur_class.type_object, self.typevars))
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ty) self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ty)
def visit_Assert(self, node: ast.Assert): def visit_Assert(self, node: ast.Assert):
......
...@@ -13,12 +13,36 @@ def get_default_parents(): ...@@ -13,12 +13,36 @@ def get_default_parents():
return [obj] return [obj]
return [] return []
class RuntimeValue:
pass
@dataclass
class MemberDef:
type: "BaseType"
val: typing.Any = RuntimeValue()
in_class_def: bool = True
@dataclass
class UnifyMode:
search_hierarchy: bool = True
match_protocol: bool = True
UnifyMode.NORMAL = UnifyMode()
UnifyMode.EXACT = UnifyMode(False, False)
@dataclass(eq=False) @dataclass(eq=False)
class BaseType(ABC): class BaseType(ABC):
members: Dict[str, "BaseType"] = field(default_factory=dict, init=False) #members: Dict[str, "MemberDef"] = field(default_factory=dict, init=False)
methods: Dict[str, "FunctionType"] = field(default_factory=dict, init=False) #methods: Dict[str, "FunctionType"] = field(default_factory=dict, init=False)
fields: Dict[str, "MemberDef"] = field(default_factory=dict, init=False)
parents: List["BaseType"] = field(default_factory=get_default_parents, init=False) parents: List["BaseType"] = field(default_factory=get_default_parents, init=False)
typevars: List["TypeVariable"] = field(default_factory=list, init=False) typevars: List["TypeVariable"] = field(default_factory=list, init=False)
#static_members: Dict[str, "MemberDef"] = field(default_factory=dict, init=False)
def get_members(self):
return {n: m for n, m in self.fields.items() if type(m.val) is RuntimeValue}
def get_parents(self) -> List["BaseType"]: def get_parents(self) -> List["BaseType"]:
...@@ -41,21 +65,29 @@ class BaseType(ABC): ...@@ -41,21 +65,29 @@ class BaseType(ABC):
queue.put(p) queue.put(p)
def inherits_from(self, other: "BaseType") -> bool: def inherits_from(self, other: "BaseType") -> bool:
return other in self.iter_hierarchy_recursive() from transpiler.exceptions import CompileError
for parent in self.iter_hierarchy_recursive():
try:
parent.unify(other, UnifyMode.EXACT)
except CompileError:
pass
else:
return True
return False
def resolve(self) -> "BaseType": def resolve(self) -> "BaseType":
return self return self
@abstractmethod @abstractmethod
def unify_internal(self, other: "BaseType"): def unify_internal(self, other: "BaseType", mode: UnifyMode):
pass pass
def unify(self, other: "BaseType"): def unify(self, other: "BaseType", mode = UnifyMode.NORMAL):
a, b = self.resolve(), other.resolve() a, b = self.resolve(), other.resolve()
TB = f"unifying {highlight(a)} and {highlight(b)}" TB = f"unifying {highlight(a)} and {highlight(b)}"
if isinstance(b, TypeVariable): if isinstance(b, TypeVariable):
a, b = b, a a, b = b, a
a.unify_internal(b) a.unify_internal(b, mode)
def contains(self, other: "BaseType") -> bool: def contains(self, other: "BaseType") -> bool:
needle, haystack = other.resolve(), self.resolve() needle, haystack = other.resolve(), self.resolve()
...@@ -86,7 +118,7 @@ class MagicType(BaseType, typing.Generic[T]): ...@@ -86,7 +118,7 @@ class MagicType(BaseType, typing.Generic[T]):
super().__init__() super().__init__()
self.val = val self.val = val
def unify_internal(self, other: "BaseType"): def unify_internal(self, other: "BaseType", mode: UnifyMode):
if type(self) != type(other) or self.val != other.val: if type(self) != type(other) or self.val != other.val:
from transpiler.phases.typing.exceptions import TypeMismatchError, TypeMismatchKind from transpiler.phases.typing.exceptions import TypeMismatchError, TypeMismatchKind
raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE) raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE)
...@@ -128,7 +160,7 @@ class TypeVariable(BaseType): ...@@ -128,7 +160,7 @@ class TypeVariable(BaseType):
return self return self
return self.resolved.resolve() return self.resolved.resolve()
def unify_internal(self, other: BaseType): def unify_internal(self, other: BaseType, mode: UnifyMode):
if self is not other: if self is not other:
if other.contains(self): if other.contains(self):
from transpiler.phases.typing.exceptions import RecursiveTypeUnificationError from transpiler.phases.typing.exceptions import RecursiveTypeUnificationError
...@@ -178,19 +210,19 @@ class TypeOperator(BaseType, ABC): ...@@ -178,19 +210,19 @@ class TypeOperator(BaseType, ABC):
if self.name is None: if self.name is None:
self.name = self.__class__.__name__ self.name = self.__class__.__name__
for name, factory in self.gen_methods.items(): for name, factory in self.gen_methods.items():
self.methods[name] = factory(self) self.fields[name] = MemberDef(factory(self))
for gp in self.gen_parents: for gp in self.gen_parents:
if not isinstance(gp, BaseType): if not isinstance(gp, BaseType):
gp = gp(self.args) gp = gp(self.args)
self.parents.append(gp) self.parents.append(gp)
self.methods = {**gp.methods, **self.methods} self.fields = {**gp.fields, **self.fields}
self.is_protocol = self.is_protocol or self.is_protocol_gen self.is_protocol = self.is_protocol or self.is_protocol_gen
self._add_default_eq() self._add_default_eq()
def _add_default_eq(self): def _add_default_eq(self):
if "__eq__" not in self.methods: if "__eq__" not in self.fields:
if "DEFAULT_EQ" in globals(): if "DEFAULT_EQ" in globals():
self.methods["__eq__"] = DEFAULT_EQ self.fields["__eq__"] = MemberDef(DEFAULT_EQ)
def matches_protocol(self, protocol: "TypeOperator"): def matches_protocol(self, protocol: "TypeOperator"):
if hash(protocol) in self.match_cache: if hash(protocol) in self.match_cache:
...@@ -199,33 +231,35 @@ class TypeOperator(BaseType, ABC): ...@@ -199,33 +231,35 @@ class TypeOperator(BaseType, ABC):
try: try:
dupl = protocol.gen_sub(self, {v.name: (TypeVariable(v.name) if isinstance(v.resolve(), TypeVariable) else v) for v in protocol.args}) dupl = protocol.gen_sub(self, {v.name: (TypeVariable(v.name) if isinstance(v.resolve(), TypeVariable) else v) for v in protocol.args})
self.match_cache.add(hash(protocol)) self.match_cache.add(hash(protocol))
for name, ty in dupl.methods.items(): for name, ty in dupl.fields.items():
if name == "__eq__": if name == "__eq__":
continue continue
if name not in self.methods: if name not in self.fields:
raise ProtocolMismatchError(self, protocol, f"missing method {name}") raise ProtocolMismatchError(self, protocol, f"missing method {name}")
corresp = self.methods[name] corresp = self.fields[name].type
corresp.remove_self().unify(ty.remove_self()) corresp.remove_self().unify(ty.type.remove_self())
except TypeMismatchError as e: except TypeMismatchError as e:
if hash(protocol) in self.match_cache: if hash(protocol) in self.match_cache:
self.match_cache.remove(hash(protocol)) self.match_cache.remove(hash(protocol))
raise ProtocolMismatchError(self, protocol, e) raise ProtocolMismatchError(self, protocol, e)
def unify_internal(self, other: BaseType): def unify_internal(self, other: BaseType, mode: UnifyMode):
from transpiler.phases.typing.exceptions import TypeMismatchError, TypeMismatchKind from transpiler.phases.typing.exceptions import TypeMismatchError, TypeMismatchKind
# TODO(zdimension): this is really broken... but it would be nice # TODO(zdimension): this is really broken... but it would be nice
# if from_node := next(filter(None, (getattr(x, "from_node", None) for x in (other, self))), None): # if from_node := next(filter(None, (getattr(x, "from_node", None) for x in (other, self))), None):
# TB_NODE = from_node # TB_NODE = from_node
if not isinstance(other, TypeOperator): if not isinstance(other, TypeOperator):
raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE) raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE)
if other.is_protocol and not self.is_protocol: if mode.match_protocol:
return other.unify_internal(self) if other.is_protocol and not self.is_protocol:
if self.is_protocol and not other.is_protocol: return other.unify_internal(self, mode)
return other.matches_protocol(self) # TODO: doesn't print the correct type in the error message if self.is_protocol and not other.is_protocol:
return other.matches_protocol(self) # TODO: doesn't print the correct type in the error message
assert self.is_protocol == other.is_protocol assert self.is_protocol == other.is_protocol
if type(self) != type(other): # and ((TY_NONE not in {self, other}) or isinstance(({self, other} - {TY_NONE}).pop(), UnionType)): if type(self) != type(other): # and ((TY_NONE not in {self, other}) or isinstance(({self, other} - {TY_NONE}).pop(), UnionType)):
if self.inherits_from(other) or other.inherits_from(self): if mode.search_hierarchy:
return if self.inherits_from(other) or other.inherits_from(self):
return
# for parent in other.get_parents(): # for parent in other.get_parents():
# try: # try:
# self.unify(parent) # self.unify(parent)
...@@ -242,8 +276,8 @@ class TypeOperator(BaseType, ABC): ...@@ -242,8 +276,8 @@ class TypeOperator(BaseType, ABC):
# return # return
raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE) raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE)
if len(self.args) < len(other.args): if len(self.args) < len(other.args):
return other.unify_internal(self) return other.unify_internal(self, mode)
if len(self.args) == 0: if True or len(self.args) == 0: # todo: why check len?
if self.name != other.name: if self.name != other.name:
raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE) raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE)
for i, (a, b) in enumerate(zip_longest(self.args, other.args)): for i, (a, b) in enumerate(zip_longest(self.args, other.args)):
...@@ -292,7 +326,7 @@ class TypeOperator(BaseType, ABC): ...@@ -292,7 +326,7 @@ class TypeOperator(BaseType, ABC):
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
setattr(res, k, v) setattr(res, k, v)
res.args = [arg.resolve().gen_sub(this, vardict, cache) for arg in self.args] res.args = [arg.resolve().gen_sub(this, vardict, cache) for arg in self.args]
res.methods = {k: v.gen_sub(this, vardict, cache) for k, v in self.methods.items()} res.fields = {k: dataclasses.replace(v, type=v.type.gen_sub(this, vardict, cache)) for k, v in self.fields.items()}
res.parents = [p.gen_sub(this, vardict, cache) for p in self.parents] res.parents = [p.gen_sub(this, vardict, cache) for p in self.parents]
#res.is_protocol = self.is_protocol #res.is_protocol = self.is_protocol
return res return res
...@@ -466,10 +500,10 @@ class Promise(TypeOperator, ABC): ...@@ -466,10 +500,10 @@ class Promise(TypeOperator, ABC):
if value == PromiseKind.GENERATOR: if value == PromiseKind.GENERATOR:
f_iter = FunctionType([], self) f_iter = FunctionType([], self)
f_iter.is_method = True f_iter.is_method = True
self.methods["__iter__"] = f_iter self.fields["__iter__"] = MemberDef(f_iter, ())
f_next = FunctionType([], self.return_type) f_next = FunctionType([], self.return_type)
f_next.is_method = True f_next.is_method = True
self.methods["__next__"] = f_next self.fields["__next__"] = MemberDef(f_next, ())
self.args[1].val = value self.args[1].val = value
def __str__(self): def __str__(self):
...@@ -506,7 +540,7 @@ class UserType(TypeOperator): ...@@ -506,7 +540,7 @@ class UserType(TypeOperator):
def __init__(self, name: str): def __init__(self, name: str):
super().__init__([], name=name, is_reference=True) super().__init__([], name=name, is_reference=True)
def unify_internal(self, other: "BaseType"): def unify_internal(self, other: "BaseType", mode: UnifyMode):
if type(self) != type(other): if type(self) != type(other):
from transpiler.phases.typing.exceptions import TypeMismatchError, TypeMismatchKind from transpiler.phases.typing.exceptions import TypeMismatchError, TypeMismatchKind
raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE) raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment