Commit 32a6dcfe authored by Tom Niget's avatar Tom Niget

Merge .members and .methods; fix unification for hierarchy lookup

parent e2134ee5
# coding: utf-8
from enum import Enum
class TokenType(Enum):
NUMBER = 1
PARENTHESIS = 2
OPERATION = 3
if __name__ == "__main__":
x = TokenType.NUMBER
\ No newline at end of file
......@@ -124,11 +124,7 @@ class BlockVisitor(NodeVisitor):
def visit_ClassDef(self, node: ast.ClassDef):
yield from ()
def check(self, f):
for b in node.body:
yield from self.match(node)
has_return = next(ReturnVisitor().check(node), False)
has_return = ReturnVisitor().match(node.body)
yield from self.visit_func_decls(node.body, inner_scope)
......
......@@ -29,12 +29,12 @@ class ClassVisitor(NodeVisitor):
yield "int value;"
yield "operator int() const { return value; }"
yield "void py_repr(std::ostream &s) const {"
yield f's << "{node.name}." << value;'
yield f's << "{node.name}.";'
yield "}"
else:
yield "void py_repr(std::ostream &s) const {"
yield "s << '{';"
for i, (name, memb) in enumerate(node.type.members.items()):
for i, (name, memb) in enumerate(node.type.fields.items()):
if i != 0:
yield 's << ", ";'
yield f's << "\\"{name}\\": ";'
......@@ -63,8 +63,8 @@ class ClassInnerVisitor(NodeVisitor):
scope: Scope
def visit_AnnAssign(self, node: ast.AnnAssign) -> Iterable[str]:
member = self.scope.obj_type.members[node.target.id]
yield from self.visit(member)
member = self.scope.obj_type.fields[node.target.id]
yield from self.visit(member.type)
yield node.target.id
yield ";"
......
......@@ -26,9 +26,9 @@ class ModuleVisitor(BlockVisitor):
yield f"namespace py_{concrete} {{"
yield f"struct {concrete}_t {{"
for name, obj in alias.module_obj.members.items():
if obj.python_func_used:
yield from self.emit_python_func(alias.name, name, name, obj)
for name, obj in alias.module_obj.fields.items():
if obj.type.python_func_used:
yield from self.emit_python_func(alias.name, name, name, obj.type)
yield "} all;"
yield f"auto& get_all() {{ return all; }}"
......
......@@ -15,4 +15,6 @@ class SearchVisitor(ast.NodeVisitor):
yield from self.visit(value)
def match(self, node) -> bool:
if type(node) == list:
return any(self.match(n) for n in node)
return next(self.visit(node), False)
......@@ -5,7 +5,7 @@ from transpiler.phases.typing.scope import VarKind, VarDecl, ScopeKind, Scope
from transpiler.phases.typing.stdlib import PRELUDE, StdlibVisitor
from transpiler.phases.typing.types import TY_TYPE, TY_INT, TY_STR, TY_BOOL, TY_COMPLEX, TY_NONE, FunctionType, \
TypeVariable, CppType, PyList, TypeType, Forked, Task, Future, PyIterator, TupleType, TypeOperator, BaseType, \
ModuleType, TY_BYTES, TY_FLOAT, PyDict, TY_SLICE, TY_OBJECT, BuiltinFeature, UnionType
ModuleType, TY_BYTES, TY_FLOAT, PyDict, TY_SLICE, TY_OBJECT, BuiltinFeature, UnionType, MemberDef
PRELUDE.vars.update({
# "int": VarDecl(VarKind.LOCAL, TY_TYPE, TY_INT),
......@@ -46,7 +46,7 @@ typon_std = Path(__file__).parent.parent.parent.parent / "stdlib"
def make_module(name: str, scope: Scope) -> BaseType:
ty = ModuleType([], f"{name}")
for n, v in scope.vars.items():
ty.members[n] = v.type
ty.fields[n] = MemberDef(v.type, v.val, False)
return ty
......
......@@ -57,7 +57,7 @@ class TypeAnnotationVisitor(NodeVisitorSeq):
def visit_Attribute(self, node: ast.Attribute) -> BaseType:
left = self.visit(node.value)
res = left.members[node.attr]
res = left.fields[node.attr].type
assert isinstance(res, TypeType)
return res.type_object
......
......@@ -11,7 +11,8 @@ from transpiler.phases.typing.expr import ScoperExprVisitor, DUNDER
from transpiler.phases.typing.class_ import ScoperClassVisitor
from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind, Scope
from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, \
Promise, TY_NONE, PromiseKind, TupleType, UserType, TypeType, ModuleType, BuiltinFeature, TY_INT
Promise, TY_NONE, PromiseKind, TupleType, UserType, TypeType, ModuleType, BuiltinFeature, TY_INT, MemberDef, \
RuntimeValue
from transpiler.phases.utils import PlainBlock, AnnotationName
......@@ -167,7 +168,7 @@ class ScoperBlockVisitor(ScoperVisitor):
init_method = ast.FunctionDef(
name="__init__",
args=ast.arguments(
args=[ast.arg(arg="self"), * [ast.arg(arg=n) for n in ctype.members]],
args=[ast.arg(arg="self"), * [ast.arg(arg=n) for n in ctype.get_members()]],
defaults=[],
kw_defaults=[],
kwarg=None,
......@@ -179,7 +180,7 @@ class ScoperBlockVisitor(ScoperVisitor):
targets=[ast.Attribute(value=ast.Name(id="self"), attr=n)],
value=ast.Name(id=n),
**lnd
) for n in ctype.members
) for n in ctype.get_members()
],
decorator_list=[],
returns=None,
......@@ -195,9 +196,11 @@ class ScoperBlockVisitor(ScoperVisitor):
base = self.expr().visit(base)
if is_builtin(base, "Enum"):
ctype.parents.append(TY_INT)
for k in ctype.members:
ctype.members[k] = ctype
ctype.members["value"] = TY_INT
for k, m in ctype.fields.items():
m.type = ctype
m.val = ast.literal_eval(m.val)
assert type(m.val) == int
ctype.fields["value"] = MemberDef(TY_INT)
lnd = linenodata(node)
init_method = ast.FunctionDef(
name="__init__",
......
......@@ -4,7 +4,7 @@ from dataclasses import dataclass, field
from transpiler.phases.typing import FunctionType, ScopeKind, VarDecl, VarKind, TY_NONE
from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.types import PromiseKind, Promise, BaseType
from transpiler.phases.typing.types import PromiseKind, Promise, BaseType, MemberDef
@dataclass
......@@ -15,15 +15,15 @@ class ScoperClassVisitor(ScoperVisitor):
assert node.value is None, "Class field should not have a value"
assert node.simple == 1, "Class field should be simple (identifier, not parenthesized)"
assert isinstance(node.target, ast.Name)
self.scope.obj_type.members[node.target.id] = self.visit_annotation(node.annotation)
self.scope.obj_type.fields[node.target.id] = MemberDef(self.visit_annotation(node.annotation))
def visit_Assign(self, node: ast.Assign):
assert len(node.targets) == 1, "Class field should be assigned to only once"
assert len(node.targets) == 1, "Can't use destructuring in class static member"
assert isinstance(node.targets[0], ast.Name)
node.is_declare = True
valtype = self.expr().visit(node.value)
node.targets[0].type = valtype
self.scope.obj_type.members[node.targets[0].id] = valtype
self.scope.obj_type.fields[node.targets[0].id] = MemberDef(valtype, node.value)
def visit_FunctionDef(self, node: ast.FunctionDef):
ftype = self.parse_function(node)
......@@ -32,5 +32,5 @@ class ScoperClassVisitor(ScoperVisitor):
if node.name != "__init__":
ftype.return_type = Promise(ftype.return_type, PromiseKind.TASK)
ftype.is_method = True
self.scope.obj_type.methods[node.name] = ftype
self.scope.obj_type.fields[node.name] = MemberDef(ftype, node)
return (node, inner)
......@@ -108,7 +108,7 @@ class ScoperVisitor(NodeVisitorSeq):
def get_iter(seq_type):
try:
iter_type = seq_type.methods["__iter__"].return_type
iter_type = seq_type.fields["__iter__"].type.return_type
except:
from transpiler.phases.typing.exceptions import NotIterableError
raise NotIterableError(seq_type)
......@@ -116,7 +116,7 @@ def get_iter(seq_type):
def get_next(iter_type):
try:
next_type = iter_type.methods["__next__"].return_type
next_type = iter_type.fields["__next__"].type.return_type
except:
from transpiler.phases.typing.exceptions import NotIteratorError
raise NotIteratorError(iter_type)
......
......@@ -174,6 +174,11 @@ class ScoperExprVisitor(ScoperVisitor):
def visit_getattr(self, ltype: BaseType, name: str) -> BaseType:
bound = True
if isinstance(ltype, TypeType):
# if mdecl := ltype.static_members.get(name):
# attr = mdecl.type
# if getattr(attr, "is_python_func", False):
# attr.python_func_used = True
# return attr
ltype = ltype.type_object
bound = False
if isinstance(ltype, abc.ABCMeta):
......@@ -182,16 +187,28 @@ class ScoperExprVisitor(ScoperVisitor):
if not all(arg.annotation == BaseType for arg in args):
raise NotImplementedError("I don't know how to handle this type")
ltype = ltype(*(TypeVariable() for _ in args))
if attr := ltype.members.get(name):
if getattr(attr, "is_python_func", False):
attr.python_func_used = True
return attr
if meth := ltype.methods.get(name):
meth = meth.gen_sub(ltype, {})
if bound:
return meth.remove_self()
else:
return meth
# if mdecl := ltype.members.get(name):
# attr = mdecl.type
# if getattr(attr, "is_python_func", False):
# attr.python_func_used = True
# return attr
# if meth := ltype.methods.get(name):
# meth = meth.gen_sub(ltype, {})
# if bound:
# return meth.remove_self()
# else:
# return meth
if field := ltype.fields.get(name):
ty = field.type
if getattr(ty, "is_python_func", False):
ty.python_func_used = True
if isinstance(ty, FunctionType):
ty = ty.gen_sub(ltype, {})
if bound and field.in_class_def:
return ty.remove_self()
return ty
from transpiler.phases.typing.exceptions import MissingAttributeError
parents = ltype.iter_hierarchy_recursive()
next(parents)
......
......@@ -3,7 +3,7 @@ from dataclasses import field, dataclass
from enum import Enum
from typing import Optional, Dict, List, Any
from transpiler.phases.typing.types import BaseType
from transpiler.phases.typing.types import BaseType, RuntimeValue
class VarKind(Enum):
......@@ -23,10 +23,6 @@ class VarType:
pass
class RuntimeValue:
pass
@dataclass
class VarDecl:
kind: VarKind
......
......@@ -8,7 +8,8 @@ from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from transpiler.phases.typing.common import PRELUDE
from transpiler.phases.typing.expr import ScoperExprVisitor
from transpiler.phases.typing.scope import Scope, VarDecl, VarKind, ScopeKind
from transpiler.phases.typing.types import BaseType, TypeOperator, FunctionType, TY_VARARG, TypeType, TypeVariable
from transpiler.phases.typing.types import BaseType, TypeOperator, FunctionType, TY_VARARG, TypeType, TypeVariable, \
MemberDef
from transpiler.phases.utils import NodeVisitorSeq
......@@ -36,7 +37,7 @@ class StdlibVisitor(NodeVisitorSeq):
if isinstance(self.cur_class.type_object, ABCMeta):
raise NotImplementedError
else:
self.cur_class.type_object.members[node.target.id] = ty.gen_sub(self.cur_class.type_object, self.typevars)
self.cur_class.type_object.fields[node.target.id] = MemberDef(ty.gen_sub(self.cur_class.type_object, self.typevars))
self.scope.vars[node.target.id] = VarDecl(VarKind.LOCAL, ty)
def visit_ImportFrom(self, node: ast.ImportFrom):
......@@ -110,7 +111,7 @@ class StdlibVisitor(NodeVisitorSeq):
if isinstance(self.cur_class.type_object, ABCMeta):
self.cur_class.type_object.gen_methods[node.name] = lambda t: ty.gen_sub(t, self.typevars)
else:
self.cur_class.type_object.methods[node.name] = ty.gen_sub(self.cur_class.type_object, self.typevars)
self.cur_class.type_object.fields[node.name] = MemberDef(ty.gen_sub(self.cur_class.type_object, self.typevars))
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ty)
def visit_Assert(self, node: ast.Assert):
......
......@@ -13,12 +13,36 @@ def get_default_parents():
return [obj]
return []
class RuntimeValue:
pass
@dataclass
class MemberDef:
type: "BaseType"
val: typing.Any = RuntimeValue()
in_class_def: bool = True
@dataclass
class UnifyMode:
search_hierarchy: bool = True
match_protocol: bool = True
UnifyMode.NORMAL = UnifyMode()
UnifyMode.EXACT = UnifyMode(False, False)
@dataclass(eq=False)
class BaseType(ABC):
members: Dict[str, "BaseType"] = field(default_factory=dict, init=False)
methods: Dict[str, "FunctionType"] = field(default_factory=dict, init=False)
#members: Dict[str, "MemberDef"] = field(default_factory=dict, init=False)
#methods: Dict[str, "FunctionType"] = field(default_factory=dict, init=False)
fields: Dict[str, "MemberDef"] = field(default_factory=dict, init=False)
parents: List["BaseType"] = field(default_factory=get_default_parents, init=False)
typevars: List["TypeVariable"] = field(default_factory=list, init=False)
#static_members: Dict[str, "MemberDef"] = field(default_factory=dict, init=False)
def get_members(self):
return {n: m for n, m in self.fields.items() if type(m.val) is RuntimeValue}
def get_parents(self) -> List["BaseType"]:
......@@ -41,21 +65,29 @@ class BaseType(ABC):
queue.put(p)
def inherits_from(self, other: "BaseType") -> bool:
return other in self.iter_hierarchy_recursive()
from transpiler.exceptions import CompileError
for parent in self.iter_hierarchy_recursive():
try:
parent.unify(other, UnifyMode.EXACT)
except CompileError:
pass
else:
return True
return False
def resolve(self) -> "BaseType":
return self
@abstractmethod
def unify_internal(self, other: "BaseType"):
def unify_internal(self, other: "BaseType", mode: UnifyMode):
pass
def unify(self, other: "BaseType"):
def unify(self, other: "BaseType", mode = UnifyMode.NORMAL):
a, b = self.resolve(), other.resolve()
TB = f"unifying {highlight(a)} and {highlight(b)}"
if isinstance(b, TypeVariable):
a, b = b, a
a.unify_internal(b)
a.unify_internal(b, mode)
def contains(self, other: "BaseType") -> bool:
needle, haystack = other.resolve(), self.resolve()
......@@ -86,7 +118,7 @@ class MagicType(BaseType, typing.Generic[T]):
super().__init__()
self.val = val
def unify_internal(self, other: "BaseType"):
def unify_internal(self, other: "BaseType", mode: UnifyMode):
if type(self) != type(other) or self.val != other.val:
from transpiler.phases.typing.exceptions import TypeMismatchError, TypeMismatchKind
raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE)
......@@ -128,7 +160,7 @@ class TypeVariable(BaseType):
return self
return self.resolved.resolve()
def unify_internal(self, other: BaseType):
def unify_internal(self, other: BaseType, mode: UnifyMode):
if self is not other:
if other.contains(self):
from transpiler.phases.typing.exceptions import RecursiveTypeUnificationError
......@@ -178,19 +210,19 @@ class TypeOperator(BaseType, ABC):
if self.name is None:
self.name = self.__class__.__name__
for name, factory in self.gen_methods.items():
self.methods[name] = factory(self)
self.fields[name] = MemberDef(factory(self))
for gp in self.gen_parents:
if not isinstance(gp, BaseType):
gp = gp(self.args)
self.parents.append(gp)
self.methods = {**gp.methods, **self.methods}
self.fields = {**gp.fields, **self.fields}
self.is_protocol = self.is_protocol or self.is_protocol_gen
self._add_default_eq()
def _add_default_eq(self):
if "__eq__" not in self.methods:
if "__eq__" not in self.fields:
if "DEFAULT_EQ" in globals():
self.methods["__eq__"] = DEFAULT_EQ
self.fields["__eq__"] = MemberDef(DEFAULT_EQ)
def matches_protocol(self, protocol: "TypeOperator"):
if hash(protocol) in self.match_cache:
......@@ -199,33 +231,35 @@ class TypeOperator(BaseType, ABC):
try:
dupl = protocol.gen_sub(self, {v.name: (TypeVariable(v.name) if isinstance(v.resolve(), TypeVariable) else v) for v in protocol.args})
self.match_cache.add(hash(protocol))
for name, ty in dupl.methods.items():
for name, ty in dupl.fields.items():
if name == "__eq__":
continue
if name not in self.methods:
if name not in self.fields:
raise ProtocolMismatchError(self, protocol, f"missing method {name}")
corresp = self.methods[name]
corresp.remove_self().unify(ty.remove_self())
corresp = self.fields[name].type
corresp.remove_self().unify(ty.type.remove_self())
except TypeMismatchError as e:
if hash(protocol) in self.match_cache:
self.match_cache.remove(hash(protocol))
raise ProtocolMismatchError(self, protocol, e)
def unify_internal(self, other: BaseType):
def unify_internal(self, other: BaseType, mode: UnifyMode):
from transpiler.phases.typing.exceptions import TypeMismatchError, TypeMismatchKind
# TODO(zdimension): this is really broken... but it would be nice
# if from_node := next(filter(None, (getattr(x, "from_node", None) for x in (other, self))), None):
# TB_NODE = from_node
if not isinstance(other, TypeOperator):
raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE)
if other.is_protocol and not self.is_protocol:
return other.unify_internal(self)
if self.is_protocol and not other.is_protocol:
return other.matches_protocol(self) # TODO: doesn't print the correct type in the error message
if mode.match_protocol:
if other.is_protocol and not self.is_protocol:
return other.unify_internal(self, mode)
if self.is_protocol and not other.is_protocol:
return other.matches_protocol(self) # TODO: doesn't print the correct type in the error message
assert self.is_protocol == other.is_protocol
if type(self) != type(other): # and ((TY_NONE not in {self, other}) or isinstance(({self, other} - {TY_NONE}).pop(), UnionType)):
if self.inherits_from(other) or other.inherits_from(self):
return
if mode.search_hierarchy:
if self.inherits_from(other) or other.inherits_from(self):
return
# for parent in other.get_parents():
# try:
# self.unify(parent)
......@@ -242,8 +276,8 @@ class TypeOperator(BaseType, ABC):
# return
raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE)
if len(self.args) < len(other.args):
return other.unify_internal(self)
if len(self.args) == 0:
return other.unify_internal(self, mode)
if True or len(self.args) == 0: # todo: why check len?
if self.name != other.name:
raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE)
for i, (a, b) in enumerate(zip_longest(self.args, other.args)):
......@@ -292,7 +326,7 @@ class TypeOperator(BaseType, ABC):
for k, v in self.__dict__.items():
setattr(res, k, v)
res.args = [arg.resolve().gen_sub(this, vardict, cache) for arg in self.args]
res.methods = {k: v.gen_sub(this, vardict, cache) for k, v in self.methods.items()}
res.fields = {k: dataclasses.replace(v, type=v.type.gen_sub(this, vardict, cache)) for k, v in self.fields.items()}
res.parents = [p.gen_sub(this, vardict, cache) for p in self.parents]
#res.is_protocol = self.is_protocol
return res
......@@ -466,10 +500,10 @@ class Promise(TypeOperator, ABC):
if value == PromiseKind.GENERATOR:
f_iter = FunctionType([], self)
f_iter.is_method = True
self.methods["__iter__"] = f_iter
self.fields["__iter__"] = MemberDef(f_iter, ())
f_next = FunctionType([], self.return_type)
f_next.is_method = True
self.methods["__next__"] = f_next
self.fields["__next__"] = MemberDef(f_next, ())
self.args[1].val = value
def __str__(self):
......@@ -506,7 +540,7 @@ class UserType(TypeOperator):
def __init__(self, name: str):
super().__init__([], name=name, is_reference=True)
def unify_internal(self, other: "BaseType"):
def unify_internal(self, other: "BaseType", mode: UnifyMode):
if type(self) != type(other):
from transpiler.phases.typing.exceptions import TypeMismatchError, TypeMismatchKind
raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment