Commit 85f54768 authored by Tom Niget's avatar Tom Niget

Make first things work

parent 54759ed9
Subproject commit 285703a35af4ca1476ce30e2404b73af9d880ec5
Subproject commit 79677d125f915f7c61492d8d1d8cde9fc6a11875
......@@ -79,71 +79,79 @@ class ExpressionVisitor(NodeVisitor):
# yield from self.visit_binary_operation(op, left, right, make_lnd(left, right))
def visit_BoolOp(self, node: ast.BoolOp) -> Iterable[str]:
if len(node.values) == 1:
yield from self.visit(node.values[0])
return
cpp_op = {
ast.And: "&&",
ast.Or: "||"
}[type(node.op)]
with self.prec_ctx(cpp_op):
yield from self.visit_binary_operation(cpp_op, node.values[0], node.values[1], make_lnd(node.values[0], node.values[1]))
for left, right in zip(node.values[1:], node.values[2:]):
yield f" {cpp_op} "
yield from self.visit_binary_operation(cpp_op, left, right, make_lnd(left, right))
raise NotImplementedError()
# if len(node.values) == 1:
# yield from self.visit(node.values[0])
# return
# cpp_op = {
# ast.And: "&&",
# ast.Or: "||"
# }[type(node.op)]
# with self.prec_ctx(cpp_op):
# yield from self.visit_binary_operation(cpp_op, node.values[0], node.values[1], make_lnd(node.values[0], node.values[1]))
# for left, right in zip(node.values[1:], node.values[2:]):
# yield f" {cpp_op} "
# yield from self.visit_binary_operation(cpp_op, left, right, make_lnd(left, right))
def visit_Call(self, node: ast.Call) -> Iterable[str]:
yield "("
yield from self.visit(node.func)
yield ")("
yield from join(", ", map(self.visit, node.args))
yield ")"
#raise NotImplementedError()
# TODO
# if getattr(node, "keywords", None):
# raise NotImplementedError(node, "keywords")
if getattr(node, "starargs", None):
raise NotImplementedError(node, "varargs")
if getattr(node, "kwargs", None):
raise NotImplementedError(node, "kwargs")
func = node.func
if isinstance(func, ast.Attribute):
if sym := DUNDER_SYMBOLS.get(func.attr, None):
if len(node.args) == 1:
yield from self.visit_binary_operation(sym, func.value, node.args[0], linenodata(node))
else:
yield from self.visit_unary_operation(sym, func.value)
return
for name in ("fork", "future"):
if compare_ast(func, ast.parse(name, mode="eval").body):
assert len(node.args) == 1
arg = node.args[0]
assert isinstance(arg, ast.Lambda)
node.is_future = name
vis = self.reset()
vis.generator = CoroutineMode.SYNC
# todo: bad code
if CoroutineMode.ASYNC in self.generator:
yield f"co_await typon::{name}("
yield from vis.visit(arg.body)
yield ")"
return
elif CoroutineMode.FAKE in self.generator:
yield from self.visit(arg.body)
return
if compare_ast(func, ast.parse('sync', mode="eval").body):
if CoroutineMode.ASYNC in self.generator:
yield "co_await typon::Sync()"
elif CoroutineMode.FAKE in self.generator:
yield from ()
return
# TODO: precedence needed?
if CoroutineMode.ASYNC in self.generator and node.is_await:
yield "(" # TODO: temporary
yield "co_await "
node.in_await = True
elif CoroutineMode.FAKE in self.generator:
func = ast.Attribute(value=func, attr="sync", ctx=ast.Load())
yield from self.prec("()").visit(func)
yield "("
yield from join(", ", map(self.reset().visit, node.args))
yield ")"
if CoroutineMode.ASYNC in self.generator and node.is_await:
yield ")"
# if getattr(node, "starargs", None):
# raise NotImplementedError(node, "varargs")
# if getattr(node, "kwargs", None):
# raise NotImplementedError(node, "kwargs")
# func = node.func
# if isinstance(func, ast.Attribute):
# if sym := DUNDER_SYMBOLS.get(func.attr, None):
# if len(node.args) == 1:
# yield from self.visit_binary_operation(sym, func.value, node.args[0], linenodata(node))
# else:
# yield from self.visit_unary_operation(sym, func.value)
# return
# for name in ("fork", "future"):
# if compare_ast(func, ast.parse(name, mode="eval").body):
# assert len(node.args) == 1
# arg = node.args[0]
# assert isinstance(arg, ast.Lambda)
# node.is_future = name
# vis = self.reset()
# vis.generator = CoroutineMode.SYNC
# # todo: bad code
# if CoroutineMode.ASYNC in self.generator:
# yield f"co_await typon::{name}("
# yield from vis.visit(arg.body)
# yield ")"
# return
# elif CoroutineMode.FAKE in self.generator:
# yield from self.visit(arg.body)
# return
# if compare_ast(func, ast.parse('sync', mode="eval").body):
# if CoroutineMode.ASYNC in self.generator:
# yield "co_await typon::Sync()"
# elif CoroutineMode.FAKE in self.generator:
# yield from ()
# return
# # TODO: precedence needed?
# if CoroutineMode.ASYNC in self.generator and node.is_await:
# yield "(" # TODO: temporary
# yield "co_await "
# node.in_await = True
# elif CoroutineMode.FAKE in self.generator:
# func = ast.Attribute(value=func, attr="sync", ctx=ast.Load())
# yield from self.prec("()").visit(func)
# yield "("
# yield from join(", ", map(self.reset().visit, node.args))
# yield ")"
# if CoroutineMode.ASYNC in self.generator and node.is_await:
# yield ")"
def visit_Lambda(self, node: ast.Lambda) -> Iterable[str]:
yield "[]"
......@@ -157,12 +165,15 @@ class ExpressionVisitor(NodeVisitor):
yield "}"
def visit_BinOp(self, node: ast.BinOp) -> Iterable[str]:
raise NotImplementedError()
yield from self.visit_binary_operation(node.op, node.left, node.right, linenodata(node))
def visit_Compare(self, node: ast.Compare) -> Iterable[str]:
raise NotImplementedError()
yield from self.visit_binary_operation(node.ops[0], node.left, node.comparators[0], linenodata(node))
def visit_binary_operation(self, op, left: ast.AST, right: ast.AST, lnd: dict) -> Iterable[str]:
raise NotImplementedError()
# if type(op) == ast.In:
# call = ast.Call(ast.Attribute(right, "__contains__", **lnd), [left], [], **lnd)
# call.is_await = False
......
......@@ -9,7 +9,7 @@ from transpiler.phases.typing.types import CallableInstanceType, BaseType
def emit_function(name: str, func: CallableInstanceType) -> Iterable[str]:
yield f"struct : function {{"
yield f"struct : referencemodel::function {{"
yield "typon::Task<void> operator()("
for arg, ty in zip(func.block_data.node.args.args, func.parameters):
......@@ -17,6 +17,8 @@ def emit_function(name: str, func: CallableInstanceType) -> Iterable[str]:
yield arg
yield ") const {"
yield from BlockVisitor(func.block_data.scope, generator=CoroutineMode.TASK).visit(func.block_data.node.body)
yield "co_return {};"
yield "}"
yield f"}} static constexpr {name} {{}};"
yield f"static_assert(sizeof {name} == 1);"
......@@ -34,6 +36,10 @@ class BlockVisitor(NodeVisitor):
def visit_Pass(self, node: ast.Pass) -> Iterable[str]:
yield ";"
def visit_Expr(self, node: ast.Expr) -> Iterable[str]:
yield from self.expr().visit(node.value)
yield ";"
# def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
# yield from self.visit_free_func(node)
......@@ -66,111 +72,111 @@ class BlockVisitor(NodeVisitor):
# if emission == FunctionEmissionKind.DECLARATION:
# yield f"}} {node.name};"
def visit_func_decls(self, body: list[ast.stmt], inner_scope: Scope, mode = CoroutineMode.ASYNC) -> Iterable[str]:
for child in body:
from transpiler.phases.emit_cpp.function import FunctionVisitor
child_visitor = FunctionVisitor(inner_scope, generator=mode)
for name, decl in getattr(child, "decls", {}).items():
#yield f"decltype({' '.join(self.expr().visit(decl.type))}) {name};"
yield from self.visit(decl.type)
yield f" {name};"
yield from child_visitor.visit(child)
def visit_func_params(self, args: Iterable[tuple[str, BaseType, Optional[ast.expr]]], emission: FunctionEmissionKind) -> Iterable[str]:
for i, (arg, argty, default) in enumerate(args):
if i != 0:
yield ", "
if emission == FunctionEmissionKind.METHOD and i == 0:
yield "Self"
else:
yield from self.visit(argty)
yield arg
if emission in {FunctionEmissionKind.DECLARATION, FunctionEmissionKind.LAMBDA, FunctionEmissionKind.METHOD} and default:
yield " = "
yield from self.expr().visit(default)
def visit_func_new(self, node: ast.FunctionDef, emission: FunctionEmissionKind, skip_first_arg: bool = False) -> Iterable[str]:
if emission == FunctionEmissionKind.LAMBDA:
yield "[&]"
else:
if emission == FunctionEmissionKind.METHOD:
yield "template <typename Self>"
yield from self.visit(node.type.return_type)
if emission == FunctionEmissionKind.DEFINITION:
yield f"{node.name}_inner::"
yield "operator()"
yield "("
padded_defaults = [None] * (len(node.args.args) if node.type.optional_at is None else node.type.optional_at) + node.args.defaults
args_iter = zip(node.args.args, node.type.parameters, padded_defaults)
if skip_first_arg:
next(args_iter)
yield from self.visit_func_params(((arg.arg, argty, default) for arg, argty, default in args_iter), emission)
yield ")"
if emission == FunctionEmissionKind.METHOD:
yield "const"
inner_scope = node.inner_scope
if emission == FunctionEmissionKind.DECLARATION:
yield ";"
return
if emission == FunctionEmissionKind.LAMBDA:
yield "->"
yield from self.visit(node.type.return_type)
yield "{"
class ReturnVisitor(SearchVisitor):
def visit_Return(self, node: ast.Return) -> bool:
yield True
def visit_Yield(self, node: ast.Yield) -> bool:
yield True
def visit_FunctionDef(self, node: ast.FunctionDef):
yield from ()
def visit_ClassDef(self, node: ast.ClassDef):
yield from ()
has_return = ReturnVisitor().match(node.body)
yield from self.visit_func_decls(node.body, inner_scope)
# if not has_return and isinstance(node.type.return_type, Promise):
# yield "co_return;"
yield "}"
def visit_lvalue(self, lvalue: ast.expr, declare: bool | list[bool] = False) -> Iterable[str]:
if isinstance(lvalue, ast.Tuple):
for name, decl, ty in zip(lvalue.elts, declare, lvalue.type.args):
if decl:
yield from self.visit_lvalue(name, True)
yield ";"
yield f"std::tie({', '.join(flatmap(self.visit_lvalue, lvalue.elts))})"
elif isinstance(lvalue, ast.Name):
if lvalue.id == "_":
if not declare:
yield "std::ignore"
return
name = self.fix_name(lvalue.id)
# if name not in self._scope.vars:
# if not self.scope.exists_local(name):
# yield self.scope.declare(name, (" ".join(self.expr().visit(val)), val) if val else None,
# getattr(val, "is_future", False))
if declare:
yield from self.visit(lvalue.type)
yield name
elif isinstance(lvalue, ast.Subscript):
yield from self.expr().visit(lvalue)
elif isinstance(lvalue, ast.Attribute):
yield from self.expr().visit(lvalue)
else:
raise NotImplementedError(lvalue)
# def visit_func_decls(self, body: list[ast.stmt], inner_scope: Scope, mode = CoroutineMode.ASYNC) -> Iterable[str]:
# for child in body:
# from transpiler.phases.emit_cpp.function import FunctionVisitor
# child_visitor = FunctionVisitor(inner_scope, generator=mode)
#
# for name, decl in getattr(child, "decls", {}).items():
# #yield f"decltype({' '.join(self.expr().visit(decl.type))}) {name};"
# yield from self.visit(decl.type)
# yield f" {name};"
# yield from child_visitor.visit(child)
#
# def visit_func_params(self, args: Iterable[tuple[str, BaseType, Optional[ast.expr]]], emission: FunctionEmissionKind) -> Iterable[str]:
# for i, (arg, argty, default) in enumerate(args):
# if i != 0:
# yield ", "
# if emission == FunctionEmissionKind.METHOD and i == 0:
# yield "Self"
# else:
# yield from self.visit(argty)
# yield arg
# if emission in {FunctionEmissionKind.DECLARATION, FunctionEmissionKind.LAMBDA, FunctionEmissionKind.METHOD} and default:
# yield " = "
# yield from self.expr().visit(default)
#
# def visit_func_new(self, node: ast.FunctionDef, emission: FunctionEmissionKind, skip_first_arg: bool = False) -> Iterable[str]:
# if emission == FunctionEmissionKind.LAMBDA:
# yield "[&]"
# else:
# if emission == FunctionEmissionKind.METHOD:
# yield "template <typename Self>"
# yield from self.visit(node.type.return_type)
# if emission == FunctionEmissionKind.DEFINITION:
# yield f"{node.name}_inner::"
# yield "operator()"
# yield "("
# padded_defaults = [None] * (len(node.args.args) if node.type.optional_at is None else node.type.optional_at) + node.args.defaults
# args_iter = zip(node.args.args, node.type.parameters, padded_defaults)
# if skip_first_arg:
# next(args_iter)
# yield from self.visit_func_params(((arg.arg, argty, default) for arg, argty, default in args_iter), emission)
# yield ")"
#
# if emission == FunctionEmissionKind.METHOD:
# yield "const"
#
# inner_scope = node.inner_scope
#
# if emission == FunctionEmissionKind.DECLARATION:
# yield ";"
# return
#
# if emission == FunctionEmissionKind.LAMBDA:
# yield "->"
# yield from self.visit(node.type.return_type)
#
# yield "{"
#
# class ReturnVisitor(SearchVisitor):
# def visit_Return(self, node: ast.Return) -> bool:
# yield True
#
# def visit_Yield(self, node: ast.Yield) -> bool:
# yield True
#
# def visit_FunctionDef(self, node: ast.FunctionDef):
# yield from ()
#
# def visit_ClassDef(self, node: ast.ClassDef):
# yield from ()
#
# has_return = ReturnVisitor().match(node.body)
#
# yield from self.visit_func_decls(node.body, inner_scope)
#
# # if not has_return and isinstance(node.type.return_type, Promise):
# # yield "co_return;"
#
# yield "}"
#
# def visit_lvalue(self, lvalue: ast.expr, declare: bool | list[bool] = False) -> Iterable[str]:
# if isinstance(lvalue, ast.Tuple):
# for name, decl, ty in zip(lvalue.elts, declare, lvalue.type.args):
# if decl:
# yield from self.visit_lvalue(name, True)
# yield ";"
# yield f"std::tie({', '.join(flatmap(self.visit_lvalue, lvalue.elts))})"
# elif isinstance(lvalue, ast.Name):
# if lvalue.id == "_":
# if not declare:
# yield "std::ignore"
# return
# name = self.fix_name(lvalue.id)
# # if name not in self._scope.vars:
# # if not self.scope.exists_local(name):
# # yield self.scope.declare(name, (" ".join(self.expr().visit(val)), val) if val else None,
# # getattr(val, "is_future", False))
# if declare:
# yield from self.visit(lvalue.type)
# yield name
# elif isinstance(lvalue, ast.Subscript):
# yield from self.expr().visit(lvalue)
# elif isinstance(lvalue, ast.Attribute):
# yield from self.expr().visit(lvalue)
# else:
# raise NotImplementedError(lvalue)
def visit_Assign(self, node: ast.Assign) -> Iterable[str]:
if len(node.targets) != 1:
......
......@@ -6,7 +6,7 @@ from typing import Iterable
import transpiler.phases.typing.types as types
from transpiler.phases.typing.exceptions import UnresolvedTypeVariableError
from transpiler.phases.typing.types import BaseType
from transpiler.utils import UnsupportedNodeError
from transpiler.utils import UnsupportedNodeError, highlight
class UniversalVisitor:
......@@ -48,7 +48,8 @@ class NodeVisitor(UniversalVisitor):
def fix_name(self, name: str) -> str:
if name.startswith("__") and name.endswith("__"):
return f"py_{name[2:-2]}"
return MAPPINGS.get(name, name)
return name
#return MAPPINGS.get(name, name)
def visit_BaseType(self, node: BaseType) -> Iterable[str]:
node = node.resolve()
......
......@@ -54,7 +54,7 @@ def transpile(source, name: str, path: Path):
# yield from code
# yield "}"
yield "#else"
yield "typon::Root root() const {"
yield "typon::Root root() {"
yield f"co_await dot(PROGRAMNS::{module.name()}, main)();"
yield "}"
yield "int main(int argc, char* argv[]) {"
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment