Commit bd64b1c6 authored by Tom Niget's avatar Tom Niget

Emit forward declarations for free functions

parent 32b26d80
# coding: utf-8
# https://lab.nexedi.com/xavier_thompson/typon-snippets/blob/master/module/mymodule.cpp
def f():
return 1 + g()
def g():
return f()
class T:
def f(self):
return 1 + self.g()
def g(self):
return self.f()
if __name__ == "__main__":
pass
\ No newline at end of file
# coding: utf-8 # coding: utf-8
import ast import ast
import enum
from enum import Flag from enum import Flag
from itertools import chain from itertools import chain
from typing import Iterable from typing import Iterable
...@@ -94,6 +95,11 @@ class CoroutineMode(Flag): ...@@ -94,6 +95,11 @@ class CoroutineMode(Flag):
TASK = 16 | ASYNC TASK = 16 | ASYNC
JOIN = 32 | ASYNC JOIN = 32 | ASYNC
class FunctionEmissionKind(enum.Enum):
DECLARATION = enum.auto()
DEFINITION = enum.auto()
METHOD = enum.auto()
def join(sep: str, items: Iterable[Iterable[str]]) -> Iterable[str]: def join(sep: str, items: Iterable[Iterable[str]]) -> Iterable[str]:
items = iter(items) items = iter(items)
try: try:
......
...@@ -6,7 +6,7 @@ from typing import Iterable, Optional ...@@ -6,7 +6,7 @@ from typing import Iterable, Optional
from transpiler.phases.typing.scope import Scope from transpiler.phases.typing.scope import Scope
from transpiler.phases.typing.types import BaseType, TY_INT, TY_BOOL, TypeVariable, Promise from transpiler.phases.typing.types import BaseType, TY_INT, TY_BOOL, TypeVariable, Promise
from transpiler.utils import compare_ast from transpiler.utils import compare_ast
from transpiler.phases.emit_cpp import NodeVisitor, CoroutineMode, flatmap from transpiler.phases.emit_cpp import NodeVisitor, CoroutineMode, flatmap, FunctionEmissionKind
from transpiler.phases.emit_cpp.expr import ExpressionVisitor from transpiler.phases.emit_cpp.expr import ExpressionVisitor
from transpiler.phases.emit_cpp.search import SearchVisitor from transpiler.phases.emit_cpp.search import SearchVisitor
...@@ -26,7 +26,12 @@ class BlockVisitor(NodeVisitor): ...@@ -26,7 +26,12 @@ class BlockVisitor(NodeVisitor):
yield ";" yield ";"
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]: def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield from self.visit_free_func(node)
def visit_free_func(self, node: ast.FunctionDef, emission: FunctionEmissionKind) -> Iterable[str]:
if getattr(node, "is_main", False): if getattr(node, "is_main", False):
if emission == FunctionEmissionKind.DECLARATION:
return
# Special case handling for Python's interesting way of defining an entry point. # Special case handling for Python's interesting way of defining an entry point.
# I mean, it's not *that* bad, it's just an attempt at retrofitting an "entry point" logic in a scripting # I mean, it's not *that* bad, it's just an attempt at retrofitting an "entry point" logic in a scripting
# language that, by essence, uses "the start of the file" as the implicit entry point, since files are # language that, by essence, uses "the start of the file" as the implicit entry point, since files are
...@@ -44,49 +49,16 @@ class BlockVisitor(NodeVisitor): ...@@ -44,49 +49,16 @@ class BlockVisitor(NodeVisitor):
yield from FunctionVisitor(self.scope, CoroutineMode.TASK).emit_block(node.scope, block()) yield from FunctionVisitor(self.scope, CoroutineMode.TASK).emit_block(node.scope, block())
return return
yield "struct {" if emission == FunctionEmissionKind.DECLARATION:
yield from self.visit_func_new(node) yield f"struct {node.name}_inner {{"
yield f"}} {node.name};" yield from self.visit_func_new(node, emission)
return if emission == FunctionEmissionKind.DECLARATION:
yield f"}} {node.name};"
yield "struct {"
yield from self.visit_func(node, CoroutineMode.FAKE)
class YieldVisitor(SearchVisitor):
def visit_Yield(self, node: ast.Yield) -> bool:
yield CoroutineMode.GENERATOR
def visit_FunctionDef(self, node: ast.FunctionDef):
yield from ()
def visit_ClassDef(self, node: ast.ClassDef):
yield from ()
def visit_Call(self, node: ast.Call):
func = node.func
if compare_ast(func, ast.parse("fork", mode="eval").body):
yield CoroutineMode.JOIN
yield from ()
func_type = YieldVisitor().match(node.body)
if func_type is False:
func_type = CoroutineMode.TASK
yield from self.visit_func(node, func_type)
if func_type == CoroutineMode.GENERATOR:
templ, args, names = self.process_args(node.args)
if templ:
yield "template"
yield templ
yield f"auto operator()"
yield args
yield f"-> typon::Task<decltype(gen({', '.join(names)}))>"
yield "{"
yield f"co_return std::move(gen({', '.join(names)}));"
yield "}"
yield f"}} {node.name};"
def visit_func_new(self, node: ast.FunctionDef, skip_first_arg: bool = False) -> Iterable[str]: def visit_func_new(self, node: ast.FunctionDef, emission: FunctionEmissionKind, skip_first_arg: bool = False) -> Iterable[str]:
yield from self.visit(node.type.return_type) yield from self.visit(node.type.return_type)
if emission == FunctionEmissionKind.DEFINITION:
yield f"{node.name}_inner::"
yield "operator()" yield "operator()"
yield "(" yield "("
args_iter = zip(node.args.args, node.type.parameters) args_iter = zip(node.args.args, node.type.parameters)
...@@ -101,6 +73,10 @@ class BlockVisitor(NodeVisitor): ...@@ -101,6 +73,10 @@ class BlockVisitor(NodeVisitor):
inner_scope = node.inner_scope inner_scope = node.inner_scope
if emission == FunctionEmissionKind.DECLARATION:
yield ";"
return
yield "{" yield "{"
class ReturnVisitor(SearchVisitor): class ReturnVisitor(SearchVisitor):
...@@ -133,123 +109,6 @@ class BlockVisitor(NodeVisitor): ...@@ -133,123 +109,6 @@ class BlockVisitor(NodeVisitor):
yield "}" yield "}"
def visit_func(self, node: ast.FunctionDef, generator: CoroutineMode) -> Iterable[str]:
templ, args, names = self.process_args(node.args)
if templ:
yield "template"
yield templ
class ReturnVisitor(SearchVisitor):
def visit_Return(self, node: ast.Return) -> bool:
yield True
def visit_Yield(self, node: ast.Yield) -> bool:
yield True
def visit_FunctionDef(self, node: ast.FunctionDef):
yield from ()
def visit_ClassDef(self, node: ast.ClassDef):
yield from ()
has_return = ReturnVisitor().match(node.body)
if CoroutineMode.SYNC in generator:
if has_return:
yield "auto"
else:
yield "void"
yield "sync"
elif CoroutineMode.GENERATOR in generator:
yield "auto gen"
else:
yield "auto operator()"
yield args
if CoroutineMode.ASYNC in generator:
yield "-> typon::"
if CoroutineMode.TASK in generator:
yield "Task"
elif CoroutineMode.GENERATOR in generator:
yield "Generator"
elif CoroutineMode.JOIN in generator:
yield "Join"
yield f"<decltype(sync({', '.join(names)}))>"
yield "{"
inner_scope = node.inner_scope
for child in node.body:
# Python uses module- and function- level scoping. Blocks, like conditionals and loops, do not form scopes
# on their own. Variables are still accessible in the remainder of the parent function or in the global
# scope if outside a function.
# This is different from C++, where scope is tied to any code block. To emulate this behavior, we need to
# declare all variables in the first inner scope of a function.
# For example,
# ```py
# def f():
# if True:
# x = 1
# print(x)
# ```
# is translated to
# ```cpp
# auto f() {
# decltype(1) x;
# if (true) {
# x = 1;
# }
# print(x);
# }
# ```
# `decltype` allows for proper typing (`auto` can't be used for variables with values later assigned, since
# this would require real type inference, akin to what Rust does).
# This is only done, though, for *nested* blocks of a function. Root-level variables are declared with
# `auto`:
# ```py
# x = 1
# def f():
# y = 2
# ```
# is translated to
# ```cpp
# auto x = 1;
# auto f() {
# auto y = 2;
# }
# ```
from transpiler.phases.emit_cpp.function import FunctionVisitor
child_visitor = FunctionVisitor(inner_scope, generator)
if True:
for name, decl in getattr(child, "decls", {}).items():
#yield f"decltype({' '.join(self.expr().visit(decl.type))}) {name};"
yield from self.visit(decl.type)
yield f" {name};"
yield from child_visitor.visit(child)
else:
# We need to do this in two-passes. This unfortunately breaks our nice generator state-machine architecture.
# Fair enough.
# TODO(zdimension): break this in two visitors
[*child_code] = child_visitor.visit(child)
# Hoist inner variables to the root scope.
for var, decl in child_visitor.scope.vars.items():
if decl.kind == VarKind.LOCAL: # Nested declarations become `decltype` declarations.
if getattr(decl.val[1], "in_await", False):
# TODO(zdimension): really?
yield f"decltype({decl.val[0][9:]}.operator co_await().await_resume()) {var};"
else:
yield f"decltype({decl.val[0]}) {var};"
elif decl.kind in (VarKind.GLOBAL, VarKind.NONLOCAL): # `global` and `nonlocal` just get hoisted as-is.
inner_scope.vars[var] = decl
yield from child_code # Yeet back the child node code.
if CoroutineMode.FAKE in generator:
yield "TYPON_UNREACHABLE();" # So the compiler doesn't complain about missing return statements.
elif CoroutineMode.ASYNC in generator and CoroutineMode.GENERATOR not in generator:
if not has_return:
yield "co_return;"
yield "}"
def visit_lvalue(self, lvalue: ast.expr, declare: bool = False) -> Iterable[str]: def visit_lvalue(self, lvalue: ast.expr, declare: bool = False) -> Iterable[str]:
if isinstance(lvalue, ast.Tuple): if isinstance(lvalue, ast.Tuple):
yield f"std::tie({', '.join(flatmap(self.visit_lvalue, lvalue.elts))})" yield f"std::tie({', '.join(flatmap(self.visit_lvalue, lvalue.elts))})"
......
...@@ -4,7 +4,7 @@ from typing import Iterable ...@@ -4,7 +4,7 @@ from typing import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from transpiler.phases.typing.scope import Scope from transpiler.phases.typing.scope import Scope
from transpiler.phases.emit_cpp import NodeVisitor from transpiler.phases.emit_cpp import NodeVisitor, FunctionEmissionKind
class ClassVisitor(NodeVisitor): class ClassVisitor(NodeVisitor):
...@@ -47,7 +47,7 @@ class ClassInnerVisitor(NodeVisitor): ...@@ -47,7 +47,7 @@ class ClassInnerVisitor(NodeVisitor):
yield "struct {" yield "struct {"
yield "type* self;" yield "type* self;"
from transpiler.phases.emit_cpp.block import BlockVisitor from transpiler.phases.emit_cpp.block import BlockVisitor
yield from BlockVisitor(self.scope).visit_func_new(node, True) yield from BlockVisitor(self.scope).visit_func_new(node, FunctionEmissionKind.METHOD, True)
yield f"}} {node.name} {{ this }};" yield f"}} {node.name} {{ this }};"
@dataclass @dataclass
......
...@@ -3,7 +3,7 @@ import ast ...@@ -3,7 +3,7 @@ import ast
from typing import Iterable from typing import Iterable
from transpiler.phases.emit_cpp.block import BlockVisitor from transpiler.phases.emit_cpp.block import BlockVisitor
from transpiler.phases.emit_cpp.module import ModuleVisitor from transpiler.phases.emit_cpp.module import ModuleVisitor, ModuleVisitor2
# noinspection PyPep8Naming # noinspection PyPep8Naming
...@@ -17,6 +17,9 @@ class FileVisitor(BlockVisitor): ...@@ -17,6 +17,9 @@ class FileVisitor(BlockVisitor):
yield from visitor.includes yield from visitor.includes
yield "namespace PROGRAMNS {" yield "namespace PROGRAMNS {"
yield from code yield from code
visitor = ModuleVisitor2(self.scope)
code = [line for stmt in node.body for line in visitor.visit(stmt)]
yield from code
yield "}" yield "}"
yield "int main(int argc, char* argv[]) {" yield "int main(int argc, char* argv[]) {"
yield "py_sys::all.argv = PyList<PyStr>(std::vector<PyStr>(argv, argv + argc));" yield "py_sys::all.argv = PyList<PyStr>(std::vector<PyStr>(argv, argv + argc));"
......
...@@ -3,8 +3,8 @@ import ast ...@@ -3,8 +3,8 @@ import ast
from typing import Iterable from typing import Iterable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from transpiler.phases.typing.scope import Scope
from transpiler.phases.emit_cpp import CoroutineMode from transpiler.phases.emit_cpp import CoroutineMode, FunctionEmissionKind, NodeVisitor
from transpiler.phases.emit_cpp.block import BlockVisitor from transpiler.phases.emit_cpp.block import BlockVisitor
from transpiler.phases.emit_cpp.class_ import ClassVisitor from transpiler.phases.emit_cpp.class_ import ClassVisitor
from transpiler.phases.emit_cpp.function import FunctionVisitor from transpiler.phases.emit_cpp.function import FunctionVisitor
...@@ -46,3 +46,17 @@ class ModuleVisitor(BlockVisitor): ...@@ -46,3 +46,17 @@ class ModuleVisitor(BlockVisitor):
def visit_ClassDef(self, node: ast.ClassDef) -> Iterable[str]: def visit_ClassDef(self, node: ast.ClassDef) -> Iterable[str]:
yield from ClassVisitor().visit(node) yield from ClassVisitor().visit(node)
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield from super().visit_free_func(node, FunctionEmissionKind.DECLARATION)
@dataclass
class ModuleVisitor2(NodeVisitor):
scope: Scope
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield from BlockVisitor(self.scope).visit_free_func(node, FunctionEmissionKind.DEFINITION)
def visit_AST(self, node: ast.AST) -> Iterable[str]:
yield ""
pass
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment