Commit ff8c2374 authored by Tom Niget's avatar Tom Niget

Add preliminary support for dataclass __init__ generation

parent c94d3e59
# coding: utf-8
dataclass: BuiltinFeature["dataclass"]
\ No newline at end of file
...@@ -5,7 +5,7 @@ from transpiler.phases.typing.scope import VarKind, VarDecl, ScopeKind, Scope ...@@ -5,7 +5,7 @@ from transpiler.phases.typing.scope import VarKind, VarDecl, ScopeKind, Scope
from transpiler.phases.typing.stdlib import PRELUDE, StdlibVisitor from transpiler.phases.typing.stdlib import PRELUDE, StdlibVisitor
from transpiler.phases.typing.types import TY_TYPE, TY_INT, TY_STR, TY_BOOL, TY_COMPLEX, TY_NONE, FunctionType, \ from transpiler.phases.typing.types import TY_TYPE, TY_INT, TY_STR, TY_BOOL, TY_COMPLEX, TY_NONE, FunctionType, \
TypeVariable, CppType, PyList, TypeType, Forked, Task, Future, PyIterator, TupleType, TypeOperator, BaseType, \ TypeVariable, CppType, PyList, TypeType, Forked, Task, Future, PyIterator, TupleType, TypeOperator, BaseType, \
ModuleType, TY_BYTES, TY_FLOAT, PyDict, TY_SLICE, TY_OBJECT ModuleType, TY_BYTES, TY_FLOAT, PyDict, TY_SLICE, TY_OBJECT, BuiltinFeature
PRELUDE.vars.update({ PRELUDE.vars.update({
# "int": VarDecl(VarKind.LOCAL, TY_TYPE, TY_INT), # "int": VarDecl(VarKind.LOCAL, TY_TYPE, TY_INT),
...@@ -36,6 +36,7 @@ PRELUDE.vars.update({ ...@@ -36,6 +36,7 @@ PRELUDE.vars.update({
"tuple": VarDecl(VarKind.LOCAL, TypeType(TupleType)), "tuple": VarDecl(VarKind.LOCAL, TypeType(TupleType)),
"slice": VarDecl(VarKind.LOCAL, TypeType(TY_SLICE)), "slice": VarDecl(VarKind.LOCAL, TypeType(TY_SLICE)),
"object": VarDecl(VarKind.LOCAL, TypeType(TY_OBJECT)), "object": VarDecl(VarKind.LOCAL, TypeType(TY_OBJECT)),
"BuiltinFeature": VarDecl(VarKind.LOCAL, TypeType(BuiltinFeature)),
}) })
typon_std = Path(__file__).parent.parent.parent.parent / "stdlib" typon_std = Path(__file__).parent.parent.parent.parent / "stdlib"
......
...@@ -6,12 +6,12 @@ from dataclasses import dataclass ...@@ -6,12 +6,12 @@ from dataclasses import dataclass
from transpiler.exceptions import CompileError from transpiler.exceptions import CompileError
from transpiler.utils import highlight, linenodata from transpiler.utils import highlight, linenodata
from transpiler.phases.typing import make_mod_decl from transpiler.phases.typing import make_mod_decl
from transpiler.phases.typing.common import ScoperVisitor from transpiler.phases.typing.common import ScoperVisitor, get_iter, get_next
from transpiler.phases.typing.expr import ScoperExprVisitor, DUNDER from transpiler.phases.typing.expr import ScoperExprVisitor, DUNDER
from transpiler.phases.typing.class_ import ScoperClassVisitor from transpiler.phases.typing.class_ import ScoperClassVisitor
from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind, Scope from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind, Scope
from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, \ from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, \
Promise, TY_NONE, PromiseKind, TupleType, UserType, TypeType, ModuleType Promise, TY_NONE, PromiseKind, TupleType, UserType, TypeType, ModuleType, BuiltinFeature
from transpiler.phases.utils import PlainBlock, AnnotationName from transpiler.phases.utils import PlainBlock, AnnotationName
...@@ -178,6 +178,38 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -178,6 +178,38 @@ class ScoperBlockVisitor(ScoperVisitor):
node.type = ctype node.type = ctype
visitor = ScoperClassVisitor(scope, cur_class=cttype) visitor = ScoperClassVisitor(scope, cur_class=cttype)
visitor.visit_block(node.body) visitor.visit_block(node.body)
for deco in node.decorator_list:
deco = self.expr().visit(deco)
if isinstance(deco, BuiltinFeature) and deco.val == "dataclass":
# init_type = FunctionType([cttype, *cttype.members.values()], TypeVariable())
# cttype.methods["__init__"] = init_type
lnd = linenodata(node)
init_method = ast.FunctionDef(
name="__init__",
args=ast.arguments(
args=[ast.arg(arg="self"), * [ast.arg(arg=n) for n in ctype.members]],
defaults=[],
kw_defaults=[],
kwarg=None,
kwonlyargs=[],
posonlyargs=[],
),
body=[
ast.Assign(
targets=[ast.Attribute(value=ast.Name(id="self"), attr=n)],
value=ast.Name(id=n),
**lnd
) for n in ctype.members
],
decorator_list=[],
returns=None,
type_comment=None,
**lnd
)
_, rtype = visitor.visit_FunctionDef(init_method)
visitor.visit_function_definition(init_method, rtype)
else:
raise NotImplementedError(deco)
def visit_If(self, node: ast.If): def visit_If(self, node: ast.If):
scope = self.scope.child(ScopeKind.FUNCTION_INNER) scope = self.scope.child(ScopeKind.FUNCTION_INNER)
...@@ -217,16 +249,8 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -217,16 +249,8 @@ class ScoperBlockVisitor(ScoperVisitor):
var_var = TypeVariable() var_var = TypeVariable()
scope.vars[node.target.id] = VarDecl(VarKind.LOCAL, var_var) scope.vars[node.target.id] = VarDecl(VarKind.LOCAL, var_var)
seq_type = self.expr().visit(node.iter) seq_type = self.expr().visit(node.iter)
try: iter_type = get_iter(seq_type)
iter_type = seq_type.methods["__iter__"].return_type next_type = get_next(iter_type)
except:
from transpiler.phases.typing.exceptions import NotIterableError
raise NotIterableError(seq_type)
try:
next_type = iter_type.methods["__next__"].return_type
except:
from transpiler.phases.typing.exceptions import NotIteratorError
raise NotIteratorError(iter_type)
var_var.unify(next_type) var_var.unify(next_type)
body_scope = scope.child(ScopeKind.FUNCTION_INNER) body_scope = scope.child(ScopeKind.FUNCTION_INNER)
body_visitor = ScoperBlockVisitor(body_scope, self.root_decls) body_visitor = ScoperBlockVisitor(body_scope, self.root_decls)
......
...@@ -35,4 +35,6 @@ class ScoperClassVisitor(ScoperVisitor): ...@@ -35,4 +35,6 @@ class ScoperClassVisitor(ScoperVisitor):
node.type = ftype node.type = ftype
for arg, ty in zip(node.args.args, argtypes): for arg, ty in zip(node.args.args, argtypes):
scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty) scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty)
self.fdecls.append((node, inner_rtype)) res = (node, inner_rtype)
self.fdecls.append(res)
return res
...@@ -95,7 +95,10 @@ class MagicType(BaseType, typing.Generic[T]): ...@@ -95,7 +95,10 @@ class MagicType(BaseType, typing.Generic[T]):
return str(self.val) return str(self.val)
def clone(self) -> "BaseType": def clone(self) -> "BaseType":
return MagicType(self.val) return type(self)(self.val)
class BuiltinFeature(MagicType):
pass
cur_var = 0 cur_var = 0
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment