Commit cd79036c authored by Tom Niget's avatar Tom Niget

Add preliminary support for list comprehensions

parent ff8c2374
......@@ -294,4 +294,28 @@ template <class T> auto begin(std::shared_ptr<T> &obj) { return dotp(obj, begin)
template <class T> auto end(std::shared_ptr<T> &obj) { return dotp(obj, end)(); }
}
template <typename T>
struct AlwaysTrue { // (1)
constexpr bool operator()(const T&) const {
return true;
}
};
template <typename Seq>
struct ValueTypeEx {
using type = decltype(*std::begin(std::declval<Seq&>()));
};
// (2)
template <typename Map, typename Seq, typename Filt = AlwaysTrue<typename ValueTypeEx<Seq>::type>>
auto mapFilter(Map map, Seq seq, Filt filt = Filt()) {
//typedef typename Seq::value_type value_type;
using value_type = typename ValueTypeEx<Seq>::type;
using return_type = decltype(map(std::declval<value_type>()));
std::vector<return_type> result{};
for (auto i : seq | std::views::filter(filt)
| std::views::transform(map)) result.push_back(i);
return typon::PyList(std::move(result));
}
#endif // TYPON_BUILTINS_HPP
......@@ -15,6 +15,7 @@ namespace typon {
template <typename T> class PyList {
public:
using value_type = T;
PyList(std::shared_ptr<std::vector<T>> &&v) : _v(std::move(v)) {}
PyList(std::vector<T> &&v)
: _v(std::move(std::make_shared<std::vector<T>>(std::move(v)))) {}
......
......@@ -17,6 +17,7 @@ class int:
def __init__(self, x: str) -> None: ...
def __lt__(self, other: Self) -> bool: ...
def __gt__(self, other: Self) -> bool: ...
def __mod__(self, other: Self) -> Self: ...
assert int.__add__
......@@ -71,6 +72,7 @@ class list(Generic[U]):
def __len__(self) -> int: ...
def append(self, value: U) -> None: ...
def __contains__(self, item: U) -> bool: ...
def __init__(self, it: Iterator[U]) -> None: ...
assert [1, 2].__iter__()
assert list[int].__iter__
......
import sys
import math
def gàé():
return 1,2,3
if __name__ == "__main__":
if True:
a, b, c = gàé() # abc
\ No newline at end of file
a = [n for n in range(10)]
b = [x for x in a if x % 2 == 0]
c = [y * y for y in b]
print(a, b, c)
\ No newline at end of file
from dataclasses import dataclass
from typing import Any, Callable
from enum import Enum
from itertools import groupby
import operator
import string
@dataclass
class BinOperator:
symbol: str
priority: int
perform: Callable[[float, float], float]
OPERATORS = [
BinOperator("+", 0, operator.add),
BinOperator("-", 0, operator.sub),
BinOperator("*", 1, operator.mul),
BinOperator("/", 1, operator.truediv)
]
ops_by_priority = [list(it) for _, it in groupby(OPERATORS, lambda op: op.priority)]
MAX_PRIORITY = len(ops_by_priority)
ops_syms = [op.symbol for op in OPERATORS]
class TokenType(Enum):
NUMBER = 1
PARENTHESIS = 2
OPERATION = 3
@dataclass
class Token:
type: TokenType
val: Any
def tokenize(inp: str):
tokens = []
index = 0
def skip_spaces():
nonlocal index
while inp[index].isspace():
index += 1
def has():
return index < len(inp)
def peek():
return inp[index]
def read():
nonlocal index
index += 1
return inp[index - 1]
def read_number():
res = ""
while True:
res += read()
if not has() or peek() not in "0123456789.":
break
return Token(TokenType.NUMBER, float(res) if "." in res else int(res))
while has():
skip_spaces()
next = peek()
if next in ops_syms:
tok = Token(TokenType.OPERATION, read())
elif next in "()":
tok = Token(TokenType.PARENTHESIS, read())
elif next in "0123456789.":
tok = read_number()
else:
raise Exception(f"invalid character '{next}'", index)
tokens.append(tok)
return tokens
def parse(tokens):
index = 0
def has():
return index < len(tokens)
def current():
if not has():
raise Exception("expected token, got EOL")
return tokens[index]
def match(type: TokenType, val: Any = None):
return has() and tokens[index].type == type and (val is None or tokens[index].val == val)
def accept(type: TokenType, val: Any = None):
nonlocal index
if match(type, val):
index += 1
return True
return False
def expect(type: TokenType, val: Any = None):
nonlocal index
if match(type, val):
index += 1
return tokens[index - 1]
if not has():
raise Exception(f"expected {type}, got EOL")
else:
raise Exception(f"expected {type}, got {current().type}")
def parse_bin(priority=0):
if priority >= MAX_PRIORITY:
return parse_term()
left = parse_bin(priority + 1)
ops = ops_by_priority[priority]
while has() and current().type == TokenType.OPERATION:
for op in ops:
if accept(TokenType.OPERATION, op.symbol):
right = parse_bin(priority + 1)
left = op.perform(left, right)
break
else:
break
return left
def parse_term():
token = current()
if token.type == TokenType.NUMBER:
return expect(TokenType.NUMBER).val
elif accept(TokenType.PARENTHESIS, "("):
val = parse_expr()
expect(TokenType.PARENTHESIS, ")")
return val
else:
raise Exception(f"expected term, got {token.type}")
def parse_expr():
return parse_bin()
return parse_expr()
if __name__ == "__main__":
while True:
inp = input("> ")
try:
tok = tokenize(inp)
res = parse(tok)
print(res)
except Exception as e:
print(e)
print()
......@@ -110,6 +110,9 @@ class ExpressionVisitor(NodeVisitor):
# yield from self.visit_binary_operation(op, left, right, make_lnd(left, right))
def visit_BoolOp(self, node: ast.BoolOp) -> Iterable[str]:
if len(node.values) == 1:
yield from self.visit(node.values[0])
return
cpp_op = {
ast.And: "&&",
ast.Or: "||"
......@@ -297,3 +300,34 @@ class ExpressionVisitor(NodeVisitor):
# raise NotImplementedError(node)
yield "co_yield"
yield from self.prec("co_yield").visit(node.value)
def visit_ListComp(self, node: ast.ListComp) -> Iterable[str]:
if len(node.generators) != 1:
raise NotImplementedError("Multiple generators not handled yet")
gen: ast.comprehension = node.generators[0]
yield "mapFilter([]("
yield from self.visit(node.input_item_type)
yield from self.visit(gen.target)
yield ") { return "
yield from self.visit(node.elt)
yield "; }, "
yield from self.visit(gen.iter)
if gen.ifs:
yield ", "
yield "[]("
yield from self.visit(node.input_item_type)
yield from self.visit(gen.target)
yield ") { return "
yield from self.visit(gen.ifs_node)
yield "; }"
yield ")"
# iter_type = get_iter(self.visit(gen.iter))
# next_type = get_next(iter_type)
# virt_scope = self.scope.child(ScopeKind.FUNCTION_INNER)
# from transpiler import ScoperBlockVisitor
# visitor = ScoperBlockVisitor(virt_scope)
# visitor.visit_assign_target(gen.target, next_type)
# res_item_type = visitor.expr().visit(node.elt)
# for if_ in gen.ifs:
# visitor.expr().visit(if_)
# return PyList(res_item_type)
\ No newline at end of file
......@@ -63,3 +63,18 @@ class ScoperVisitor(NodeVisitorSeq):
if not node.inner_scope.has_return:
rtype.unify(TY_NONE) # todo: properly indicate missing return
def get_iter(seq_type):
try:
iter_type = seq_type.methods["__iter__"].return_type
except:
from transpiler.phases.typing.exceptions import NotIterableError
raise NotIterableError(seq_type)
return iter_type
def get_next(iter_type):
try:
next_type = iter_type.methods["__next__"].return_type
except:
from transpiler.phases.typing.exceptions import NotIteratorError
raise NotIteratorError(iter_type)
return next_type
\ No newline at end of file
......@@ -4,7 +4,7 @@ import inspect
from typing import List
from transpiler.phases.typing import ScopeKind, VarDecl, VarKind
from transpiler.phases.typing.common import ScoperVisitor
from transpiler.phases.typing.common import ScoperVisitor, get_iter, get_next
from transpiler.phases.typing.types import BaseType, TupleType, TY_STR, TY_BOOL, TY_INT, \
TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict, Promise, PromiseKind, UserType, \
TY_SLICE
......@@ -248,4 +248,20 @@ class ScoperExprVisitor(ScoperVisitor):
return self.visit_function_call(
self.visit_getattr(TypeType(args[0]), f"__{name}__"),
args
)
\ No newline at end of file
)
def visit_ListComp(self, node: ast.ListComp) -> BaseType:
if len(node.generators) != 1:
raise NotImplementedError("Multiple generators not handled yet")
gen: ast.comprehension = node.generators[0]
iter_type = get_iter(self.visit(gen.iter))
node.input_item_type = get_next(iter_type)
virt_scope = self.scope.child(ScopeKind.FUNCTION_INNER)
from transpiler import ScoperBlockVisitor
visitor = ScoperBlockVisitor(virt_scope)
visitor.visit_assign_target(gen.target, node.input_item_type)
node.item_type = visitor.expr().visit(node.elt)
for if_ in gen.ifs:
visitor.expr().visit(if_)
gen.ifs_node = ast.BoolOp(ast.And(), gen.ifs, **linenodata(node))
return PyList(node.item_type)
\ No newline at end of file
......@@ -218,7 +218,7 @@ class TypeOperator(BaseType, ABC):
if other.is_protocol and not self.is_protocol:
return other.unify_internal(self)
if self.is_protocol and not other.is_protocol:
return other.matches_protocol(self)
return other.matches_protocol(self) # TODO: doesn't print the correct type in the error message
if len(self.args) < len(other.args):
return other.unify_internal(self)
assert self.is_protocol == other.is_protocol
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment