Commit bd64b1c6 authored by Tom Niget's avatar Tom Niget

Emit forward declarations for free functions

parent 32b26d80
# coding: utf-8
# https://lab.nexedi.com/xavier_thompson/typon-snippets/blob/master/module/mymodule.cpp
def f():
return 1 + g()
def g():
return f()
class T:
def f(self):
return 1 + self.g()
def g(self):
return self.f()
if __name__ == "__main__":
pass
\ No newline at end of file
# coding: utf-8
import ast
import enum
from enum import Flag
from itertools import chain
from typing import Iterable
......@@ -94,6 +95,11 @@ class CoroutineMode(Flag):
TASK = 16 | ASYNC
JOIN = 32 | ASYNC
class FunctionEmissionKind(enum.Enum):
DECLARATION = enum.auto()
DEFINITION = enum.auto()
METHOD = enum.auto()
def join(sep: str, items: Iterable[Iterable[str]]) -> Iterable[str]:
items = iter(items)
try:
......
......@@ -6,7 +6,7 @@ from typing import Iterable, Optional
from transpiler.phases.typing.scope import Scope
from transpiler.phases.typing.types import BaseType, TY_INT, TY_BOOL, TypeVariable, Promise
from transpiler.utils import compare_ast
from transpiler.phases.emit_cpp import NodeVisitor, CoroutineMode, flatmap
from transpiler.phases.emit_cpp import NodeVisitor, CoroutineMode, flatmap, FunctionEmissionKind
from transpiler.phases.emit_cpp.expr import ExpressionVisitor
from transpiler.phases.emit_cpp.search import SearchVisitor
......@@ -26,7 +26,12 @@ class BlockVisitor(NodeVisitor):
yield ";"
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield from self.visit_free_func(node)
def visit_free_func(self, node: ast.FunctionDef, emission: FunctionEmissionKind) -> Iterable[str]:
if getattr(node, "is_main", False):
if emission == FunctionEmissionKind.DECLARATION:
return
# Special case handling for Python's interesting way of defining an entry point.
# I mean, it's not *that* bad, it's just an attempt at retrofitting an "entry point" logic in a scripting
# language that, by essence, uses "the start of the file" as the implicit entry point, since files are
......@@ -44,49 +49,16 @@ class BlockVisitor(NodeVisitor):
yield from FunctionVisitor(self.scope, CoroutineMode.TASK).emit_block(node.scope, block())
return
yield "struct {"
yield from self.visit_func_new(node)
if emission == FunctionEmissionKind.DECLARATION:
yield f"struct {node.name}_inner {{"
yield from self.visit_func_new(node, emission)
if emission == FunctionEmissionKind.DECLARATION:
yield f"}} {node.name};"
return
yield "struct {"
yield from self.visit_func(node, CoroutineMode.FAKE)
class YieldVisitor(SearchVisitor):
def visit_Yield(self, node: ast.Yield) -> bool:
yield CoroutineMode.GENERATOR
def visit_FunctionDef(self, node: ast.FunctionDef):
yield from ()
def visit_ClassDef(self, node: ast.ClassDef):
yield from ()
def visit_Call(self, node: ast.Call):
func = node.func
if compare_ast(func, ast.parse("fork", mode="eval").body):
yield CoroutineMode.JOIN
yield from ()
func_type = YieldVisitor().match(node.body)
if func_type is False:
func_type = CoroutineMode.TASK
yield from self.visit_func(node, func_type)
if func_type == CoroutineMode.GENERATOR:
templ, args, names = self.process_args(node.args)
if templ:
yield "template"
yield templ
yield f"auto operator()"
yield args
yield f"-> typon::Task<decltype(gen({', '.join(names)}))>"
yield "{"
yield f"co_return std::move(gen({', '.join(names)}));"
yield "}"
yield f"}} {node.name};"
def visit_func_new(self, node: ast.FunctionDef, skip_first_arg: bool = False) -> Iterable[str]:
def visit_func_new(self, node: ast.FunctionDef, emission: FunctionEmissionKind, skip_first_arg: bool = False) -> Iterable[str]:
yield from self.visit(node.type.return_type)
if emission == FunctionEmissionKind.DEFINITION:
yield f"{node.name}_inner::"
yield "operator()"
yield "("
args_iter = zip(node.args.args, node.type.parameters)
......@@ -101,6 +73,10 @@ class BlockVisitor(NodeVisitor):
inner_scope = node.inner_scope
if emission == FunctionEmissionKind.DECLARATION:
yield ";"
return
yield "{"
class ReturnVisitor(SearchVisitor):
......@@ -133,123 +109,6 @@ class BlockVisitor(NodeVisitor):
yield "}"
def visit_func(self, node: ast.FunctionDef, generator: CoroutineMode) -> Iterable[str]:
templ, args, names = self.process_args(node.args)
if templ:
yield "template"
yield templ
class ReturnVisitor(SearchVisitor):
def visit_Return(self, node: ast.Return) -> bool:
yield True
def visit_Yield(self, node: ast.Yield) -> bool:
yield True
def visit_FunctionDef(self, node: ast.FunctionDef):
yield from ()
def visit_ClassDef(self, node: ast.ClassDef):
yield from ()
has_return = ReturnVisitor().match(node.body)
if CoroutineMode.SYNC in generator:
if has_return:
yield "auto"
else:
yield "void"
yield "sync"
elif CoroutineMode.GENERATOR in generator:
yield "auto gen"
else:
yield "auto operator()"
yield args
if CoroutineMode.ASYNC in generator:
yield "-> typon::"
if CoroutineMode.TASK in generator:
yield "Task"
elif CoroutineMode.GENERATOR in generator:
yield "Generator"
elif CoroutineMode.JOIN in generator:
yield "Join"
yield f"<decltype(sync({', '.join(names)}))>"
yield "{"
inner_scope = node.inner_scope
for child in node.body:
# Python uses module- and function- level scoping. Blocks, like conditionals and loops, do not form scopes
# on their own. Variables are still accessible in the remainder of the parent function or in the global
# scope if outside a function.
# This is different from C++, where scope is tied to any code block. To emulate this behavior, we need to
# declare all variables in the first inner scope of a function.
# For example,
# ```py
# def f():
# if True:
# x = 1
# print(x)
# ```
# is translated to
# ```cpp
# auto f() {
# decltype(1) x;
# if (true) {
# x = 1;
# }
# print(x);
# }
# ```
# `decltype` allows for proper typing (`auto` can't be used for variables with values later assigned, since
# this would require real type inference, akin to what Rust does).
# This is only done, though, for *nested* blocks of a function. Root-level variables are declared with
# `auto`:
# ```py
# x = 1
# def f():
# y = 2
# ```
# is translated to
# ```cpp
# auto x = 1;
# auto f() {
# auto y = 2;
# }
# ```
from transpiler.phases.emit_cpp.function import FunctionVisitor
child_visitor = FunctionVisitor(inner_scope, generator)
if True:
for name, decl in getattr(child, "decls", {}).items():
#yield f"decltype({' '.join(self.expr().visit(decl.type))}) {name};"
yield from self.visit(decl.type)
yield f" {name};"
yield from child_visitor.visit(child)
else:
# We need to do this in two-passes. This unfortunately breaks our nice generator state-machine architecture.
# Fair enough.
# TODO(zdimension): break this in two visitors
[*child_code] = child_visitor.visit(child)
# Hoist inner variables to the root scope.
for var, decl in child_visitor.scope.vars.items():
if decl.kind == VarKind.LOCAL: # Nested declarations become `decltype` declarations.
if getattr(decl.val[1], "in_await", False):
# TODO(zdimension): really?
yield f"decltype({decl.val[0][9:]}.operator co_await().await_resume()) {var};"
else:
yield f"decltype({decl.val[0]}) {var};"
elif decl.kind in (VarKind.GLOBAL, VarKind.NONLOCAL): # `global` and `nonlocal` just get hoisted as-is.
inner_scope.vars[var] = decl
yield from child_code # Yeet back the child node code.
if CoroutineMode.FAKE in generator:
yield "TYPON_UNREACHABLE();" # So the compiler doesn't complain about missing return statements.
elif CoroutineMode.ASYNC in generator and CoroutineMode.GENERATOR not in generator:
if not has_return:
yield "co_return;"
yield "}"
def visit_lvalue(self, lvalue: ast.expr, declare: bool = False) -> Iterable[str]:
if isinstance(lvalue, ast.Tuple):
yield f"std::tie({', '.join(flatmap(self.visit_lvalue, lvalue.elts))})"
......
......@@ -4,7 +4,7 @@ from typing import Iterable
from dataclasses import dataclass
from transpiler.phases.typing.scope import Scope
from transpiler.phases.emit_cpp import NodeVisitor
from transpiler.phases.emit_cpp import NodeVisitor, FunctionEmissionKind
class ClassVisitor(NodeVisitor):
......@@ -47,7 +47,7 @@ class ClassInnerVisitor(NodeVisitor):
yield "struct {"
yield "type* self;"
from transpiler.phases.emit_cpp.block import BlockVisitor
yield from BlockVisitor(self.scope).visit_func_new(node, True)
yield from BlockVisitor(self.scope).visit_func_new(node, FunctionEmissionKind.METHOD, True)
yield f"}} {node.name} {{ this }};"
@dataclass
......
......@@ -3,7 +3,7 @@ import ast
from typing import Iterable
from transpiler.phases.emit_cpp.block import BlockVisitor
from transpiler.phases.emit_cpp.module import ModuleVisitor
from transpiler.phases.emit_cpp.module import ModuleVisitor, ModuleVisitor2
# noinspection PyPep8Naming
......@@ -17,6 +17,9 @@ class FileVisitor(BlockVisitor):
yield from visitor.includes
yield "namespace PROGRAMNS {"
yield from code
visitor = ModuleVisitor2(self.scope)
code = [line for stmt in node.body for line in visitor.visit(stmt)]
yield from code
yield "}"
yield "int main(int argc, char* argv[]) {"
yield "py_sys::all.argv = PyList<PyStr>(std::vector<PyStr>(argv, argv + argc));"
......
......@@ -3,8 +3,8 @@ import ast
from typing import Iterable
from dataclasses import dataclass, field
from transpiler.phases.emit_cpp import CoroutineMode
from transpiler.phases.typing.scope import Scope
from transpiler.phases.emit_cpp import CoroutineMode, FunctionEmissionKind, NodeVisitor
from transpiler.phases.emit_cpp.block import BlockVisitor
from transpiler.phases.emit_cpp.class_ import ClassVisitor
from transpiler.phases.emit_cpp.function import FunctionVisitor
......@@ -46,3 +46,17 @@ class ModuleVisitor(BlockVisitor):
def visit_ClassDef(self, node: ast.ClassDef) -> Iterable[str]:
yield from ClassVisitor().visit(node)
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield from super().visit_free_func(node, FunctionEmissionKind.DECLARATION)
@dataclass
class ModuleVisitor2(NodeVisitor):
scope: Scope
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield from BlockVisitor(self.scope).visit_free_func(node, FunctionEmissionKind.DEFINITION)
def visit_AST(self, node: ast.AST) -> Iterable[str]:
yield ""
pass
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment