Commit ff8c2374 authored by Tom Niget's avatar Tom Niget

Add preliminary support for dataclass __init__ generation

parent c94d3e59
# coding: utf-8
dataclass: BuiltinFeature["dataclass"]
\ No newline at end of file
......@@ -5,7 +5,7 @@ from transpiler.phases.typing.scope import VarKind, VarDecl, ScopeKind, Scope
from transpiler.phases.typing.stdlib import PRELUDE, StdlibVisitor
from transpiler.phases.typing.types import TY_TYPE, TY_INT, TY_STR, TY_BOOL, TY_COMPLEX, TY_NONE, FunctionType, \
TypeVariable, CppType, PyList, TypeType, Forked, Task, Future, PyIterator, TupleType, TypeOperator, BaseType, \
ModuleType, TY_BYTES, TY_FLOAT, PyDict, TY_SLICE, TY_OBJECT
ModuleType, TY_BYTES, TY_FLOAT, PyDict, TY_SLICE, TY_OBJECT, BuiltinFeature
PRELUDE.vars.update({
# "int": VarDecl(VarKind.LOCAL, TY_TYPE, TY_INT),
......@@ -36,6 +36,7 @@ PRELUDE.vars.update({
"tuple": VarDecl(VarKind.LOCAL, TypeType(TupleType)),
"slice": VarDecl(VarKind.LOCAL, TypeType(TY_SLICE)),
"object": VarDecl(VarKind.LOCAL, TypeType(TY_OBJECT)),
"BuiltinFeature": VarDecl(VarKind.LOCAL, TypeType(BuiltinFeature)),
})
typon_std = Path(__file__).parent.parent.parent.parent / "stdlib"
......
......@@ -6,12 +6,12 @@ from dataclasses import dataclass
from transpiler.exceptions import CompileError
from transpiler.utils import highlight, linenodata
from transpiler.phases.typing import make_mod_decl
from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.common import ScoperVisitor, get_iter, get_next
from transpiler.phases.typing.expr import ScoperExprVisitor, DUNDER
from transpiler.phases.typing.class_ import ScoperClassVisitor
from transpiler.phases.typing.scope import VarDecl, VarKind, ScopeKind, Scope
from transpiler.phases.typing.types import BaseType, TypeVariable, FunctionType, \
Promise, TY_NONE, PromiseKind, TupleType, UserType, TypeType, ModuleType
Promise, TY_NONE, PromiseKind, TupleType, UserType, TypeType, ModuleType, BuiltinFeature
from transpiler.phases.utils import PlainBlock, AnnotationName
......@@ -178,6 +178,38 @@ class ScoperBlockVisitor(ScoperVisitor):
node.type = ctype
visitor = ScoperClassVisitor(scope, cur_class=cttype)
visitor.visit_block(node.body)
for deco in node.decorator_list:
deco = self.expr().visit(deco)
if isinstance(deco, BuiltinFeature) and deco.val == "dataclass":
# init_type = FunctionType([cttype, *cttype.members.values()], TypeVariable())
# cttype.methods["__init__"] = init_type
lnd = linenodata(node)
init_method = ast.FunctionDef(
name="__init__",
args=ast.arguments(
args=[ast.arg(arg="self"), * [ast.arg(arg=n) for n in ctype.members]],
defaults=[],
kw_defaults=[],
kwarg=None,
kwonlyargs=[],
posonlyargs=[],
),
body=[
ast.Assign(
targets=[ast.Attribute(value=ast.Name(id="self"), attr=n)],
value=ast.Name(id=n),
**lnd
) for n in ctype.members
],
decorator_list=[],
returns=None,
type_comment=None,
**lnd
)
_, rtype = visitor.visit_FunctionDef(init_method)
visitor.visit_function_definition(init_method, rtype)
else:
raise NotImplementedError(deco)
def visit_If(self, node: ast.If):
scope = self.scope.child(ScopeKind.FUNCTION_INNER)
......@@ -217,16 +249,8 @@ class ScoperBlockVisitor(ScoperVisitor):
var_var = TypeVariable()
scope.vars[node.target.id] = VarDecl(VarKind.LOCAL, var_var)
seq_type = self.expr().visit(node.iter)
try:
iter_type = seq_type.methods["__iter__"].return_type
except:
from transpiler.phases.typing.exceptions import NotIterableError
raise NotIterableError(seq_type)
try:
next_type = iter_type.methods["__next__"].return_type
except:
from transpiler.phases.typing.exceptions import NotIteratorError
raise NotIteratorError(iter_type)
iter_type = get_iter(seq_type)
next_type = get_next(iter_type)
var_var.unify(next_type)
body_scope = scope.child(ScopeKind.FUNCTION_INNER)
body_visitor = ScoperBlockVisitor(body_scope, self.root_decls)
......
......@@ -35,4 +35,6 @@ class ScoperClassVisitor(ScoperVisitor):
node.type = ftype
for arg, ty in zip(node.args.args, argtypes):
scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty)
self.fdecls.append((node, inner_rtype))
res = (node, inner_rtype)
self.fdecls.append(res)
return res
......@@ -95,7 +95,10 @@ class MagicType(BaseType, typing.Generic[T]):
return str(self.val)
def clone(self) -> "BaseType":
return MagicType(self.val)
return type(self)(self.val)
class BuiltinFeature(MagicType):
pass
cur_var = 0
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment