Commit 3faf1b4e authored by Tom Niget's avatar Tom Niget

Add a few helper functions, rework concepts

parent 35960de1
...@@ -13,20 +13,41 @@ ...@@ -13,20 +13,41 @@
using namespace std::literals; using namespace std::literals;
template<typename T> template<typename T>
concept Streamable = requires(const T &x, std::ostream &s) {
{ s << x } -> std::same_as<std::ostream &>;
};
template<Streamable T>
void print_to(const T &x, std::ostream &s) { void print_to(const T &x, std::ostream &s) {
s << x; s << x;
} }
template<typename T> template<typename T>
concept Printable = requires(const T &x, std::ostream &s) { concept FunctionPointer = std::is_function_v<T>
or std::is_member_function_pointer_v<T>
or std::is_function_v<std::remove_pointer_t<T>>;
template<Streamable T>
requires (FunctionPointer<T>)
void print_to(const T &x, std::ostream &s) {
s << "<function at 0x" << std::hex << (size_t) x << ">";
}
template<typename T>
concept PyPrint = requires(const T &x, std::ostream &s) {
{ x.py_print(s) } -> std::same_as<void>; { x.py_print(s) } -> std::same_as<void>;
}; };
template<Printable T> template<PyPrint T>
void print_to(const T &x, std::ostream &s) { void print_to(const T &x, std::ostream &s) {
x.py_print(s); x.py_print(s);
} }
template<typename T>
concept Printable = requires(const T &x, std::ostream &s) {
{ print_to(x, s) } -> std::same_as<void>;
};
template<typename T> template<typename T>
concept PyIterator = requires(T t) { concept PyIterator = requires(T t) {
{ t.py_next() } -> std::same_as<std::optional<T>>; { t.py_next() } -> std::same_as<std::optional<T>>;
...@@ -56,13 +77,6 @@ void print() { ...@@ -56,13 +77,6 @@ void print() {
std::cout << '\n'; std::cout << '\n';
} }
template<typename T, typename ... Args>
void print(T const &head, Args const &... args) {
print_to(head, std::cout);
(((std::cout << ' '), print_to(args, std::cout)), ...);
std::cout << '\n';
}
bool is_cpp() { bool is_cpp() {
return true; return true;
} }
...@@ -74,4 +88,11 @@ bool is_cpp() { ...@@ -74,4 +88,11 @@ bool is_cpp() {
#include "builtins/set.hpp" #include "builtins/set.hpp"
#include "builtins/str.hpp" #include "builtins/str.hpp"
template<Printable T, Printable ... Args>
void print(T const &head, Args const &... args) {
print_to(head, std::cout);
(((std::cout << ' '), print_to(args, std::cout)), ...);
std::cout << '\n';
}
#endif //TYPON_BUILTINS_HPP #endif //TYPON_BUILTINS_HPP
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include <ostream> #include <ostream>
template<> template<>
void print_to(const bool &x, std::ostream &s) { void print_to<bool>(const bool &x, std::ostream &s) {
s << (x ? "True" : "False"); s << (x ? "True" : "False");
} }
......
...@@ -57,4 +57,9 @@ public: ...@@ -57,4 +57,9 @@ public:
} }
}; };
template<typename T>
PyList<T> list(std::initializer_list<T> &&v) {
return PyList<T>(std::move(v));
}
#endif //TYPON_LIST_HPP #endif //TYPON_LIST_HPP
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#ifndef TYPON_SET_HPP #ifndef TYPON_SET_HPP
#define TYPON_SET_HPP #define TYPON_SET_HPP
#include <unordered_set> #include <unordered_set>
template<typename T> template<typename T>
...@@ -83,4 +84,9 @@ public: ...@@ -83,4 +84,9 @@ public:
} }
}; };
template<typename T>
PySet<T> set(std::initializer_list<T> &&s) {
return PySet<T>(std::move(s));
}
#endif //TYPON_SET_HPP #endif //TYPON_SET_HPP
//
// Created by Tom on 09/03/2023.
//
#ifndef TYPON_SYS_HPP
#define TYPON_SYS_HPP
#include <iostream>
struct sys_t {
static constexpr auto& stdout = std::cout;
} sys;
#endif //TYPON_SYS_HPP
# coding: utf-8 # coding: utf-8
from typon import is_cpp from typon import is_cpp
import sys
test = (2 + 3) * 4
glob = 5 glob = 5
def g():
if True:
if True:
if True:
x = 5
print(x)
def f(x): def f(x):
return x + 1 return x + 1
...@@ -14,13 +22,15 @@ def fct(param): ...@@ -14,13 +22,15 @@ def fct(param):
global glob global glob
loc = 789 loc = 789
glob = 123 glob = 123
a = 5
b = 6 def fct2():
z = f(a + b) * 2 global glob
glob += 5
if __name__ == "__main__": if __name__ == "__main__":
# todo: 0x55 & 7 == 5 # todo: 0x55 & 7 == 5
print(is_cpp)
print("C++ " if is_cpp() else "Python", print("C++ " if is_cpp() else "Python",
"res=", 5, ".", True, [4, 5, 6], {7, 8, 9}, [1, 2] + [3, 4], [5, 6] * 3, {1: 7, 9: 3}, 0x55 & 7 == 5, "res=", 5, ".", True, [4, 5, 6], {7, 8, 9}, [1, 2] + [3, 4], [5, 6] * 3, {1: 7, 9: 3}, 0x55 & 7 == 5,
2 + 3j) 2 + 3j)
......
...@@ -73,6 +73,8 @@ SYMBOLS = { ...@@ -73,6 +73,8 @@ SYMBOLS = {
"""Mapping of Python AST nodes to C++ symbols.""" """Mapping of Python AST nodes to C++ symbols."""
PRECEDENCE = [ PRECEDENCE = [
("()", "[]", ".",),
("unary",),
("*", "/", "%",), ("*", "/", "%",),
("+", "-"), ("+", "-"),
("<<", ">>"), ("<<", ">>"),
...@@ -83,6 +85,8 @@ PRECEDENCE = [ ...@@ -83,6 +85,8 @@ PRECEDENCE = [
("|",), ("|",),
("&&",), ("&&",),
("||",), ("||",),
("?:",),
(",",)
] ]
"""Precedence of C++ operators.""" """Precedence of C++ operators."""
...@@ -136,12 +140,46 @@ class NodeVisitor: ...@@ -136,12 +140,46 @@ class NodeVisitor:
return MAPPINGS.get(name, name) return MAPPINGS.get(name, name)
class PrecedenceContext:
def __init__(self, visitor: "ExpressionVisitor", op: str):
self.visitor = visitor
self.op = op
def __enter__(self):
if self.visitor.precedence[-1:] != [self.op]:
self.visitor.precedence.append(self.op)
def __exit__(self, exc_type, exc_val, exc_tb):
self.visitor.precedence.pop()
# noinspection PyPep8Naming
class ExpressionVisitor(NodeVisitor): class ExpressionVisitor(NodeVisitor):
def __init__(self, precedence: Optional[int] = None): def __init__(self, precedence=None):
self._precedence = precedence self.precedence = precedence or []
def prec_ctx(self, op: str) -> PrecedenceContext:
"""
Creates a context manager that sets the precedence of the next expression.
"""
return PrecedenceContext(self, op)
def prec(self, op: str) -> "ExpressionVisitor":
"""
Sets the precedence of the next expression.
"""
return ExpressionVisitor([op])
def reset(self) -> "ExpressionVisitor":
"""
Resets the precedence stack.
"""
return ExpressionVisitor()
def visit_Tuple(self, node: ast.Tuple) -> Iterable[str]: def visit_Tuple(self, node: ast.Tuple) -> Iterable[str]:
yield f"std::make_tuple({', '.join(flatmap(self.visit, node.elts))})" yield "std::make_tuple("
yield from join(", ", map(self.visit, node.elts))
yield ")"
def visit_Constant(self, node: ast.Constant) -> Iterable[str]: def visit_Constant(self, node: ast.Constant) -> Iterable[str]:
if isinstance(node.value, str): if isinstance(node.value, str):
...@@ -161,8 +199,8 @@ class ExpressionVisitor(NodeVisitor): ...@@ -161,8 +199,8 @@ class ExpressionVisitor(NodeVisitor):
yield self.fix_name(node.id) yield self.fix_name(node.id)
def visit_Compare(self, node: ast.Compare) -> Iterable[str]: def visit_Compare(self, node: ast.Compare) -> Iterable[str]:
# TODO: operator precedence
operands = [node.left, *node.comparators] operands = [node.left, *node.comparators]
with self.prec_ctx("&&"):
yield from self.visit_binary_operation(node.ops[0], operands[0], operands[1]) yield from self.visit_binary_operation(node.ops[0], operands[0], operands[1])
for (left, right), op in zip(zip(operands[1:], operands[2:]), node.ops[1:]): for (left, right), op in zip(zip(operands[1:], operands[2:]), node.ops[1:]):
# TODO: cleaner code # TODO: cleaner code
...@@ -176,9 +214,9 @@ class ExpressionVisitor(NodeVisitor): ...@@ -176,9 +214,9 @@ class ExpressionVisitor(NodeVisitor):
raise NotImplementedError(node, "varargs") raise NotImplementedError(node, "varargs")
if getattr(node, "kwargs", None): if getattr(node, "kwargs", None):
raise NotImplementedError(node, "kwargs") raise NotImplementedError(node, "kwargs")
yield from self.visit(node.func) yield from self.prec("()").visit(node.func)
yield "(" yield "("
yield from join(", ", map(self.visit, node.args)) yield from join(", ", map(self.reset().visit, node.args))
yield ")" yield ")"
def visit_Lambda(self, node: ast.Lambda) -> Iterable[str]: def visit_Lambda(self, node: ast.Lambda) -> Iterable[str]:
...@@ -188,7 +226,7 @@ class ExpressionVisitor(NodeVisitor): ...@@ -188,7 +226,7 @@ class ExpressionVisitor(NodeVisitor):
yield args yield args
yield "{" yield "{"
yield "return" yield "return"
yield from self.visit(node.body) yield from self.reset().visit(node.body)
yield ";" yield ";"
yield "}" yield "}"
...@@ -196,39 +234,38 @@ class ExpressionVisitor(NodeVisitor): ...@@ -196,39 +234,38 @@ class ExpressionVisitor(NodeVisitor):
yield from self.visit_binary_operation(node.op, node.left, node.right) yield from self.visit_binary_operation(node.op, node.left, node.right)
def visit_binary_operation(self, op, left: ast.AST, right: ast.AST) -> Iterable[str]: def visit_binary_operation(self, op, left: ast.AST, right: ast.AST) -> Iterable[str]:
# TODO: precedence
op = SYMBOLS[type(op)] op = SYMBOLS[type(op)]
inner = ExpressionVisitor(PRECEDENCE_LEVELS[op]) prio = self.precedence and PRECEDENCE_LEVELS[self.precedence[-1]] < PRECEDENCE_LEVELS[op]
prio = self._precedence is not None and inner._precedence > self._precedence
if prio: if prio:
yield "(" yield "("
yield from inner.visit(left) with self.prec_ctx(op):
yield from self.visit(left)
yield op yield op
yield from inner.visit(right) yield from self.visit(right)
if prio: if prio:
yield ")" yield ")"
def visit_Attribute(self, node: ast.Attribute) -> Iterable[str]: def visit_Attribute(self, node: ast.Attribute) -> Iterable[str]:
yield from self.visit(node.value) yield from self.prec(".").visit(node.value)
yield "." yield "."
yield node.attr yield node.attr
def visit_List(self, node: ast.List) -> Iterable[str]: def visit_List(self, node: ast.List) -> Iterable[str]:
yield "PyList{" yield "PyList{"
yield from join(", ", map(self.visit, node.elts)) yield from join(", ", map(self.reset().visit, node.elts))
yield "}" yield "}"
def visit_Set(self, node: ast.Set) -> Iterable[str]: def visit_Set(self, node: ast.Set) -> Iterable[str]:
yield "PySet{" yield "PySet{"
yield from join(", ", map(self.visit, node.elts)) yield from join(", ", map(self.reset().visit, node.elts))
yield "}" yield "}"
def visit_Dict(self, node: ast.Dict) -> Iterable[str]: def visit_Dict(self, node: ast.Dict) -> Iterable[str]:
def visit_item(key, value): def visit_item(key, value):
yield "std::pair {" yield "std::pair {"
yield from self.visit(key) yield from self.reset().visit(key)
yield ", " yield ", "
yield from self.visit(value) yield from self.reset().visit(value)
yield "}" yield "}"
yield "PyDict{" yield "PyDict{"
...@@ -236,16 +273,17 @@ class ExpressionVisitor(NodeVisitor): ...@@ -236,16 +273,17 @@ class ExpressionVisitor(NodeVisitor):
yield "}" yield "}"
def visit_Subscript(self, node: ast.Subscript) -> Iterable[str]: def visit_Subscript(self, node: ast.Subscript) -> Iterable[str]:
yield from self.visit(node.value) yield from self.prec("[]").visit(node.value)
yield "[" yield "["
yield from self.visit(node.slice) yield from self.reset().visit(node.slice)
yield "]" yield "]"
def visit_UnaryOp(self, node: ast.UnaryOp) -> Iterable[str]: def visit_UnaryOp(self, node: ast.UnaryOp) -> Iterable[str]:
yield from self.visit(node.op) yield from self.visit(node.op)
yield from self.visit(node.operand) yield from self.prec("unary").visit(node.operand)
def visit_IfExp(self, node: ast.IfExp) -> Iterable[str]: def visit_IfExp(self, node: ast.IfExp) -> Iterable[str]:
with self.prec_ctx("?:"):
yield from self.visit(node.test) yield from self.visit(node.test)
yield " ? " yield " ? "
yield from self.visit(node.body) yield from self.visit(node.body)
...@@ -265,6 +303,7 @@ class Scope: ...@@ -265,6 +303,7 @@ class Scope:
return name in self.vars or (self.parent is not None and self.parent.exists(name)) return name in self.vars or (self.parent is not None and self.parent.exists(name))
# noinspection PyPep8Naming
class BlockVisitor(NodeVisitor): class BlockVisitor(NodeVisitor):
def __init__(self, scope: Scope): def __init__(self, scope: Scope):
self._scope = scope self._scope = scope
...@@ -280,11 +319,12 @@ class BlockVisitor(NodeVisitor): ...@@ -280,11 +319,12 @@ class BlockVisitor(NodeVisitor):
yield ";" yield ";"
def visit_Import(self, node: ast.Import) -> Iterable[str]: def visit_Import(self, node: ast.Import) -> Iterable[str]:
for name in node.names: for alias in node.names:
if name == "typon": if alias.name == "typon":
yield "" yield ""
else: else:
raise NotImplementedError(node) yield f'#include "python/{alias.name}.hpp"'
#raise NotImplementedError(node)
def visit_ImportFrom(self, node: ast.ImportFrom) -> Iterable[str]: def visit_ImportFrom(self, node: ast.ImportFrom) -> Iterable[str]:
if node.module == "typon": if node.module == "typon":
...@@ -377,6 +417,20 @@ class BlockVisitor(NodeVisitor): ...@@ -377,6 +417,20 @@ class BlockVisitor(NodeVisitor):
yield from ExpressionVisitor().visit(node.value) yield from ExpressionVisitor().visit(node.value)
yield ";" yield ";"
def visit_AnnAssign(self, node: ast.AnnAssign) -> Iterable[str]:
if node.value is None:
raise NotImplementedError(node, "empty value")
yield from self.visit_lvalue(node.target)
yield " = "
yield from ExpressionVisitor().visit(node.value)
yield ";"
def visit_AugAssign(self, node: ast.AugAssign) -> Iterable[str]:
yield from self.visit_lvalue(node.target)
yield SYMBOLS[type(node.op)] + "="
yield from ExpressionVisitor().visit(node.value)
yield ";"
def visit_For(self, node: ast.For) -> Iterable[str]: def visit_For(self, node: ast.For) -> Iterable[str]:
if not isinstance(node.target, ast.Name): if not isinstance(node.target, ast.Name):
raise NotImplementedError(node) raise NotImplementedError(node)
......
...@@ -7,4 +7,8 @@ clang = clang_format._get_executable("clang-format") # noqa ...@@ -7,4 +7,8 @@ clang = clang_format._get_executable("clang-format") # noqa
def format_code(code: str) -> str: def format_code(code: str) -> str:
return subprocess.check_output([clang, "-style=LLVM"], input=code.encode("utf-8")).decode("utf-8") return subprocess.check_output([
clang,
"--style=LLVM",
"--assume-filename=main.cpp"
], input=code.encode("utf-8")).decode("utf-8")
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment