Commit 47dae27f authored by Tom Niget's avatar Tom Niget

Separate visitors into File, Module and Function

parent 9eb76c50
......@@ -42,7 +42,7 @@ def join(sep: str, items: Iterable[Iterable[str]]) -> Iterable[str]:
def transpile(source):
tree = ast.parse(source)
# print(ast.unparse(tree))
return "\n".join(filter(None, map(str, BlockVisitor(Scope()).visit(tree))))
return "\n".join(filter(None, map(str, FileVisitor(Scope()).visit(tree))))
SYMBOLS = {
......@@ -289,11 +289,13 @@ class ExpressionVisitor(NodeVisitor):
yield " : "
yield from self.visit(node.orelse)
@dataclass
class VarDecl:
kind: VarKind
val: Optional[str]
@dataclass
class Scope:
parent: Optional["Scope"] = None
......@@ -318,7 +320,8 @@ class Scope:
The check does not cross function boundaries; i.e. global variables are not taken into account from inside
functions.
"""
return name in self.vars or (not self.is_function and self.parent is not None and self.parent.exists_local(name))
return name in self.vars or (
not self.is_function and self.parent is not None and self.parent.exists_local(name))
def child(self) -> "Scope":
"""
......@@ -371,35 +374,6 @@ class BlockVisitor(NodeVisitor):
def __init__(self, scope: Scope):
self._scope = scope
def visit_Module(self, node: ast.Module) -> Iterable[str]:
stmt: ast.AST
yield "#include <python/builtins.hpp>"
for stmt in node.body:
yield from self.visit(stmt)
def visit_Expr(self, node: ast.Expr) -> Iterable[str]:
yield from ExpressionVisitor().visit(node.value)
yield ";"
def visit_Import(self, node: ast.Import) -> Iterable[str]:
for alias in node.names:
if alias.name == "typon":
yield ""
else:
yield from self.import_module(alias.name)
yield f'auto& {alias.asname or alias.name} = py_{alias.name}::all;'
def import_module(self, name: str) -> Iterable[str]:
yield f'#include "python/{name}.hpp"'
def visit_ImportFrom(self, node: ast.ImportFrom) -> Iterable[str]:
if node.module == "typon":
yield ""
else:
yield from self.import_module(node.module)
for alias in node.names:
yield f"auto& {alias.asname or alias.name} = py_{node.module}::all.{alias.name};"
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
templ, args = self.process_args(node.args)
if templ:
......@@ -408,7 +382,7 @@ class BlockVisitor(NodeVisitor):
yield f"auto {node.name}"
yield args
yield "{"
inner = BlockVisitor(self._scope.function())
inner = FunctionVisitor(self._scope.function())
for child in node.body:
# Python uses module- and function- level scoping. Blocks, like conditionals and loops, do not form scopes
# on their own. Variables are still accessible in the remainder of the parent function or in the global
......@@ -448,7 +422,7 @@ class BlockVisitor(NodeVisitor):
# auto y = 2;
# }
# ```
child_visitor = BlockVisitor(inner._scope.child())
child_visitor = FunctionVisitor(inner._scope.child())
# We need to do this in two-passes. This unfortunately breaks our nice generator state-machine architecture.
# Fair enough.
......@@ -463,60 +437,12 @@ class BlockVisitor(NodeVisitor):
yield from child_code # Yeet back the child node code.
yield "}"
def visit_Global(self, node: ast.Global) -> Iterable[str]:
for name in map(self.fix_name, node.names):
self._scope.vars[name] = VarDecl(VarKind.GLOBAL, None)
yield ""
def visit_Nonlocal(self, node: ast.Nonlocal) -> Iterable[str]:
for name in map(self.fix_name, node.names):
self._scope.vars[name] = VarDecl(VarKind.NONLOCAL, None)
yield ""
def visit_If(self, node: ast.If) -> Iterable[str]:
if not node.orelse and compare_ast(node.test, ast.parse('__name__ == "__main__"', mode="eval").body):
# Special case handling for Python's interesting way of defining an entry point.
# I mean, it's not *that* bad, it's just an attempt at retrofitting an "entry point" logic in a scripting
# language that, by essence, uses "the start of the file" as the implicit entry point, since files are
# read and executed line-by-line, contrary to usual structured languages that mark a distinction between
# declarations (functions, classes, modules, ...) and code.
# Also, for nitpickers, the C++ standard explicitly allows for omitting a `return` statement in the `main`.
# 0 is returned by default.
yield "int main()"
yield from self.emit_block(node.body)
return
yield "if ("
yield from ExpressionVisitor().visit(node.test)
yield ")"
yield from self.emit_block(node.body)
if node.orelse:
yield "else "
if isinstance(node.orelse, ast.If):
yield from self.visit(node.orelse)
else:
yield from self.emit_block(node.orelse)
def visit_Return(self, node: ast.Return) -> Iterable[str]:
yield "return "
if node.value:
yield from ExpressionVisitor().visit(node.value)
yield ";"
def visit_While(self, node: ast.While) -> Iterable[str]:
yield "while ("
yield from ExpressionVisitor().visit(node.test)
yield ")"
yield from self.emit_block(node.body)
if node.orelse:
raise NotImplementedError(node, "orelse")
def visit_lvalue(self, lvalue: ast.expr, val: Optional[ast.AST] = None) -> Iterable[str]:
if isinstance(lvalue, ast.Tuple):
yield f"std::tie({', '.join(flatmap(ExpressionVisitor().visit, lvalue.elts))})"
elif isinstance(lvalue, ast.Name):
name = self.fix_name(lvalue.id)
#if name not in self._scope.vars:
# if name not in self._scope.vars:
if not self._scope.exists_local(name):
yield self._scope.declare(name, " ".join(ExpressionVisitor().visit(val)) if val else None)
yield name
......@@ -541,6 +467,60 @@ class BlockVisitor(NodeVisitor):
yield from ExpressionVisitor().visit(node.value)
yield ";"
# noinspection PyPep8Naming
class FileVisitor(BlockVisitor):
def visit_Module(self, node: ast.Module) -> Iterable[str]:
stmt: ast.AST
yield "#include <python/builtins.hpp>"
visitor = ModuleVisitor(self._scope)
for stmt in node.body:
yield from visitor.visit(stmt)
# noinspection PyPep8Naming
class ModuleVisitor(BlockVisitor):
def visit_Import(self, node: ast.Import) -> Iterable[str]:
for alias in node.names:
if alias.name == "typon":
yield ""
else:
yield from self.import_module(alias.name)
yield f'auto& {alias.asname or alias.name} = py_{alias.name}::all;'
def import_module(self, name: str) -> Iterable[str]:
yield f'#include "python/{name}.hpp"'
def visit_ImportFrom(self, node: ast.ImportFrom) -> Iterable[str]:
if node.module == "typon":
yield ""
else:
yield from self.import_module(node.module)
for alias in node.names:
yield f"auto& {alias.asname or alias.name} = py_{node.module}::all.{alias.name};"
def visit_If(self, node: ast.If) -> Iterable[str]:
if not node.orelse and compare_ast(node.test, ast.parse('__name__ == "__main__"', mode="eval").body):
# Special case handling for Python's interesting way of defining an entry point.
# I mean, it's not *that* bad, it's just an attempt at retrofitting an "entry point" logic in a scripting
# language that, by essence, uses "the start of the file" as the implicit entry point, since files are
# read and executed line-by-line, contrary to usual structured languages that mark a distinction between
# declarations (functions, classes, modules, ...) and code.
# Also, for nitpickers, the C++ standard explicitly allows for omitting a `return` statement in the `main`.
# 0 is returned by default.
yield "int main()"
yield from FunctionVisitor(self._scope).emit_block(node.body)
return
raise NotImplementedError(node, "global scope if")
# noinspection PyPep8Naming
class FunctionVisitor(BlockVisitor):
def visit_Expr(self, node: ast.Expr) -> Iterable[str]:
yield from ExpressionVisitor().visit(node.value)
yield ";"
def visit_AugAssign(self, node: ast.AugAssign) -> Iterable[str]:
yield from self.visit_lvalue(node.target)
yield SYMBOLS[type(node.op)] + "="
......@@ -557,11 +537,47 @@ class BlockVisitor(NodeVisitor):
if node.orelse:
raise NotImplementedError(node, "orelse")
def block(self) -> "BlockVisitor":
def visit_If(self, node: ast.If) -> Iterable[str]:
yield "if ("
yield from ExpressionVisitor().visit(node.test)
yield ")"
yield from self.emit_block(node.body)
if node.orelse:
yield "else "
if isinstance(node.orelse, ast.If):
yield from self.visit(node.orelse)
else:
yield from self.emit_block(node.orelse)
def visit_Return(self, node: ast.Return) -> Iterable[str]:
yield "return "
if node.value:
yield from ExpressionVisitor().visit(node.value)
yield ";"
def visit_While(self, node: ast.While) -> Iterable[str]:
yield "while ("
yield from ExpressionVisitor().visit(node.test)
yield ")"
yield from self.emit_block(node.body)
if node.orelse:
raise NotImplementedError(node, "orelse")
def visit_Global(self, node: ast.Global) -> Iterable[str]:
for name in map(self.fix_name, node.names):
self._scope.vars[name] = VarDecl(VarKind.GLOBAL, None)
yield ""
def visit_Nonlocal(self, node: ast.Nonlocal) -> Iterable[str]:
for name in map(self.fix_name, node.names):
self._scope.vars[name] = VarDecl(VarKind.NONLOCAL, None)
yield ""
def block(self) -> "FunctionVisitor":
# See the comments in visit_FunctionDef.
# A Python code block does not introduce a new scope, so we create a new `Scope` object that shares the same
# variables as the parent scope.
return BlockVisitor(self._scope.child_share())
return FunctionVisitor(self._scope.child_share())
def emit_block(self, items: List[ast.stmt]) -> Iterable[str]:
yield "{"
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment