Commit 8ce64ec2 authored by Tom Niget's avatar Tom Niget

Use dot macro

parent 5234daf1
...@@ -21,4 +21,5 @@ if __name__ == "__main__": ...@@ -21,4 +21,5 @@ if __name__ == "__main__":
print(x.name) print(x.name)
print(x.age) print(x.age)
x.afficher() x.afficher()
y.afficher(x)
...@@ -56,6 +56,8 @@ class BlockVisitor(NodeVisitor): ...@@ -56,6 +56,8 @@ class BlockVisitor(NodeVisitor):
yield f"}} {node.name};" yield f"}} {node.name};"
def visit_func_new(self, node: ast.FunctionDef, emission: FunctionEmissionKind, skip_first_arg: bool = False) -> Iterable[str]: def visit_func_new(self, node: ast.FunctionDef, emission: FunctionEmissionKind, skip_first_arg: bool = False) -> Iterable[str]:
if emission == FunctionEmissionKind.METHOD:
yield "template <typename Self>"
yield from self.visit(node.type.return_type) yield from self.visit(node.type.return_type)
if emission == FunctionEmissionKind.DEFINITION: if emission == FunctionEmissionKind.DEFINITION:
yield f"{node.name}_inner::" yield f"{node.name}_inner::"
...@@ -68,13 +70,19 @@ class BlockVisitor(NodeVisitor): ...@@ -68,13 +70,19 @@ class BlockVisitor(NodeVisitor):
for i, (arg, argty, default) in enumerate(args_iter): for i, (arg, argty, default) in enumerate(args_iter):
if i != 0: if i != 0:
yield ", " yield ", "
yield from self.visit(argty) if emission == FunctionEmissionKind.METHOD and i == 0:
yield "Self"
else:
yield from self.visit(argty)
yield arg.arg yield arg.arg
if emission == FunctionEmissionKind.DECLARATION and default: if emission == FunctionEmissionKind.DECLARATION and default:
yield " = " yield " = "
yield from self.expr().visit(default) yield from self.expr().visit(default)
yield ")" yield ")"
if emission == FunctionEmissionKind.METHOD:
yield "const"
inner_scope = node.inner_scope inner_scope = node.inner_scope
if emission == FunctionEmissionKind.DECLARATION: if emission == FunctionEmissionKind.DECLARATION:
......
...@@ -19,7 +19,7 @@ class ClassVisitor(NodeVisitor): ...@@ -19,7 +19,7 @@ class ClassVisitor(NodeVisitor):
yield from inner.visit(stmt) yield from inner.visit(stmt)
yield "template<typename... T> type(T&&... args) {" yield "template<typename... T> type(T&&... args) {"
yield "__init__(std::forward<T>(args)...);" yield "__init__(this, std::forward<T>(args)...);"
yield "}" yield "}"
yield "type() {}" yield "type() {}"
yield "type(const type&) = delete;" yield "type(const type&) = delete;"
...@@ -62,11 +62,16 @@ class ClassInnerVisitor(NodeVisitor): ...@@ -62,11 +62,16 @@ class ClassInnerVisitor(NodeVisitor):
yield ";" yield ";"
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]: def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield "struct {" # yield "struct {"
yield "type* self;" # yield "type* self;"
# from transpiler.phases.emit_cpp.block import BlockVisitor
# yield from BlockVisitor(self.scope).visit_func_new(node, FunctionEmissionKind.METHOD, True)
# yield f"}} {node.name} {{ this }};"
yield f"struct {node.name}_m_s : function {{"
from transpiler.phases.emit_cpp.block import BlockVisitor from transpiler.phases.emit_cpp.block import BlockVisitor
yield from BlockVisitor(self.scope).visit_func_new(node, FunctionEmissionKind.METHOD, True) yield from BlockVisitor(self.scope).visit_func_new(node, FunctionEmissionKind.METHOD)
yield f"}} {node.name} {{ this }};" yield f"}} static constexpr {node.name} {{}};"
yield ""
@dataclass @dataclass
class ClassOuterVisitor(NodeVisitor): class ClassOuterVisitor(NodeVisitor):
...@@ -77,8 +82,13 @@ class ClassOuterVisitor(NodeVisitor): ...@@ -77,8 +82,13 @@ class ClassOuterVisitor(NodeVisitor):
def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]: def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[str]:
yield "struct {" yield "struct {"
yield "template<typename... T>" yield "template<typename Self, typename... T>"
yield "auto operator()(type& self, T&&... args) {" yield "auto operator()(Self self, T&&... args) {"
yield f"return self.{node.name}(std::forward<T>(args)...);" yield f"return dotp(self, {node.name})(std::forward<T>(args)...);"
yield "}" yield "}"
yield f"}} {node.name};" yield f"}} {node.name};"
yield ""
# yield "struct : function {"
# from transpiler.phases.emit_cpp.block import BlockVisitor
# yield from BlockVisitor(self.scope).visit_func_new(node, FunctionEmissionKind.METHOD)
# yield f"}} static constexpr {node.name} {{}};"
...@@ -3,8 +3,8 @@ import ast ...@@ -3,8 +3,8 @@ import ast
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Iterable from typing import List, Iterable
from transpiler.phases.typing.types import UserType from transpiler.phases.typing.types import UserType, FunctionType
from transpiler.utils import compare_ast from transpiler.utils import compare_ast, linenodata
from transpiler.consts import SYMBOLS, PRECEDENCE_LEVELS from transpiler.consts import SYMBOLS, PRECEDENCE_LEVELS
from transpiler.phases.emit_cpp import CoroutineMode, join, NodeVisitor from transpiler.phases.emit_cpp import CoroutineMode, join, NodeVisitor
from transpiler.phases.typing.scope import Scope, VarKind from transpiler.phases.typing.scope import Scope, VarKind
...@@ -92,13 +92,21 @@ class ExpressionVisitor(NodeVisitor): ...@@ -92,13 +92,21 @@ class ExpressionVisitor(NodeVisitor):
yield res yield res
def visit_Compare(self, node: ast.Compare) -> Iterable[str]: def visit_Compare(self, node: ast.Compare) -> Iterable[str]:
def make_lnd(op1, op2):
return {
"lineno": op1.lineno,
"col_offset": op1.col_offset,
"end_lineno": op2.end_lineno,
"end_col_offset": op2.end_col_offset
}
operands = [node.left, *node.comparators] operands = [node.left, *node.comparators]
with self.prec_ctx("&&"): with self.prec_ctx("&&"):
yield from self.visit_binary_operation(node.ops[0], operands[0], operands[1]) yield from self.visit_binary_operation(node.ops[0], operands[0], operands[1], make_lnd(operands[0], operands[1]))
for (left, right), op in zip(zip(operands[1:], operands[2:]), node.ops[1:]): for (left, right), op in zip(zip(operands[1:], operands[2:]), node.ops[1:]):
# TODO: cleaner code # TODO: cleaner code
yield " && " yield " && "
yield from self.visit_binary_operation(op, left, right) yield from self.visit_binary_operation(op, left, right, make_lnd(left, right))
def visit_Call(self, node: ast.Call) -> Iterable[str]: def visit_Call(self, node: ast.Call) -> Iterable[str]:
# TODO # TODO
...@@ -158,11 +166,11 @@ class ExpressionVisitor(NodeVisitor): ...@@ -158,11 +166,11 @@ class ExpressionVisitor(NodeVisitor):
yield "}" yield "}"
def visit_BinOp(self, node: ast.BinOp) -> Iterable[str]: def visit_BinOp(self, node: ast.BinOp) -> Iterable[str]:
yield from self.visit_binary_operation(node.op, node.left, node.right) yield from self.visit_binary_operation(node.op, node.left, node.right, linenodata(node))
def visit_binary_operation(self, op, left: ast.AST, right: ast.AST) -> Iterable[str]: def visit_binary_operation(self, op, left: ast.AST, right: ast.AST, lnd: dict) -> Iterable[str]:
if type(op) == ast.In: if type(op) == ast.In:
call = ast.Call(ast.Attribute(right, "__contains__"), [left], []) call = ast.Call(ast.Attribute(right, "__contains__", **lnd), [left], [], **lnd)
call.is_await = False call.is_await = False
yield from self.visit_Call(call) yield from self.visit_Call(call)
return return
...@@ -180,12 +188,23 @@ class ExpressionVisitor(NodeVisitor): ...@@ -180,12 +188,23 @@ class ExpressionVisitor(NodeVisitor):
yield ")" yield ")"
def visit_Attribute(self, node: ast.Attribute) -> Iterable[str]: def visit_Attribute(self, node: ast.Attribute) -> Iterable[str]:
yield from self.prec(".").visit(node.value) if isinstance(node.type, FunctionType):
if node.value.type.resolve().is_reference: if node.value.type.resolve().is_reference:
yield "->" yield "dotp"
else:
yield "dot"
yield "("
yield from self.visit(node.value)
yield ", "
yield self.fix_name(node.attr)
yield ")"
else: else:
yield "." yield from self.prec(".").visit(node.value)
yield self.fix_name(node.attr) if node.value.type.resolve().is_reference:
yield "->"
else:
yield "."
yield self.fix_name(node.attr)
def visit_List(self, node: ast.List) -> Iterable[str]: def visit_List(self, node: ast.List) -> Iterable[str]:
if node.elts: if node.elts:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment