Commit d36c7e8f authored by Tom Niget's avatar Tom Niget

Continue work on pretty errors

parent b5c58745
...@@ -4,4 +4,4 @@ dataclasses~=0.6 ...@@ -4,4 +4,4 @@ dataclasses~=0.6
python-dotenv~=1.0.0 python-dotenv~=1.0.0
colorama~=0.4.6 colorama~=0.4.6
numpy~=1.25.1 numpy~=1.25.1
pygments~=2.15.1 colorful~=0.5.5
\ No newline at end of file \ No newline at end of file
def f(x): import sys
import math; x = (math.
abcd)
def f(x: int):
return x return x
if __name__ == "__main__": if __name__ == "__main__":
y = f(f) y = (6).x
\ No newline at end of file \ No newline at end of file
# coding: utf-8 # coding: utf-8
import ast import ast
import builtins import builtins
import importlib
import inspect import inspect
import os import os
...@@ -17,7 +18,7 @@ from transpiler.phases.if_main import IfMainVisitor ...@@ -17,7 +18,7 @@ from transpiler.phases.if_main import IfMainVisitor
from transpiler.phases.typing.block import ScoperBlockVisitor from transpiler.phases.typing.block import ScoperBlockVisitor
from transpiler.phases.typing.scope import Scope from transpiler.phases.typing.scope import Scope
from itertools import islice
import sys import sys
import colorful as cf import colorful as cf
...@@ -48,17 +49,40 @@ def exception_hook(exc_type, exc_value, tb): ...@@ -48,17 +49,40 @@ def exception_hook(exc_type, exc_value, tb):
filename = tb.tb_frame.f_code.co_filename filename = tb.tb_frame.f_code.co_filename
line_no = tb.tb_lineno line_no = tb.tb_lineno
print(cf.red(f"File \"{filename}\", line {line_no}, in {name}"), end="") print(cf.red(f"File \"{filename}\", line {line_no}, in {cf.green(name)}"), end="")
if info := local_vars.get("TB", None): if info := local_vars.get("TB", None):
print(f", while {cf.magenta(info)}") print(f": {cf.magenta(info)}\x1b[24m")
else: else:
print() print()
tb = tb.tb_next tb = tb.tb_next
if last_node is not None: if last_node is not None:
print()
print(f"In file \"{cf.white(last_file)}\", line {last_node.lineno}") print(f"In file \"{cf.white(last_file)}\", line {last_node.lineno}")
print("\t" + highlight(ast.unparse(last_node))) with open(last_file, "r", encoding="utf-8") as f:
code = f.read()
hg = str(highlight(code, True)).replace("\x1b[04m", "").replace("\x1b[24m", "").splitlines()
if last_node.lineno == last_node.end_lineno:
old = hg[last_node.lineno - 1]
start, end = find_indices(old, [last_node.col_offset, last_node.end_col_offset])
hg[last_node.lineno - 1] = old[:start] + "\x1b[4m" + old[start:end] + "\x1b[24m" + old[end:]
else:
old = hg[last_node.lineno - 1]
[start] = find_indices(old, [last_node.col_offset])
hg[last_node.lineno - 1] = old[:start] + "\x1b[4m" + old[start:]
old = hg[last_node.end_lineno - 1]
first_nonspace = len(old) - len(old.lstrip())
[end] = find_indices(old, [last_node.end_col_offset])
hg[last_node.end_lineno - 1] = old[:first_nonspace] + "\x1b[4m" + old[first_nonspace:end] + "\x1b[24m" + old[end:]
CONTEXT_SIZE = 2
start = max(0, last_node.lineno - CONTEXT_SIZE - 1)
offset = start + 1
for i, line in enumerate(hg[start:last_node.end_lineno + CONTEXT_SIZE]):
erroneous = last_node.lineno <= offset + i <= last_node.end_lineno
indicator = cf.white("**>") if erroneous else " "
print(
f"\x1b[24m{indicator}{cf.white}{(offset + i):>4}{cf.red if erroneous else cf.reset} | {cf.reset}{line}\x1b[24m")
print() print()
print(cf.red("Error:"), exc_value) print(cf.red("Error:"), exc_value)
if isinstance(exc_value, CompileError): if isinstance(exc_value, CompileError):
...@@ -66,9 +90,50 @@ def exception_hook(exc_type, exc_value, tb): ...@@ -66,9 +90,50 @@ def exception_hook(exc_type, exc_value, tb):
print(inspect.cleandoc(exc_value.detail(last_node))) print(inspect.cleandoc(exc_value.detail(last_node)))
print() print()
def find_indices(s, indices: list[int]) -> list[int]:
"""
Matches indices to an ANSI-colored string
"""
results = set()
i = 0
j = 0
it = iter(set(indices))
current = next(it)
while i <= len(s):
if i != len(s) and s[i] == "\x1b":
i += 1
while s[i] != "m":
i += 1
i += 1
continue
if j == current:
results.add(i)
try:
current = next(it)
except StopIteration:
break
i += 1
j += 1
assert len(results) == len(indices), (results, indices, s)
return sorted(list(results))
assert find_indices("\x1b[48;5;237mmath.abcd\x1b[37m\x1b[39m\x1b[49m", [0, 9]) == [11, 35], find_indices("\x1b[48;5;237mmath.abcd\x1b[37m\x1b[39m\x1b[49m", [0, 9])
assert find_indices("abcdef", [2, 5]) == [2, 5]
assert find_indices("abc\x1b[32mdef", [2, 5]) == [2, 10], find_indices("abc\x1b[32mdef", [2, 5])
assert find_indices("math.abcd\x1b[37m\x1b[39m", [0, 9]) == [0, 19], find_indices("math.abcd\x1b[37m\x1b[39m", [0, 9])
sys.excepthook = exception_hook sys.excepthook = exception_hook
try:
pydevd = importlib.import_module("_pydevd_bundle.pydevd_breakpoints")
except ImportError:
pass
else:
pydevd._fallback_excepthook = sys.excepthook
pydevd.original_excepthook = sys.excepthook
def transpile(source, name="<module>", path=None): def transpile(source, name="<module>", path=None):
TB = f"transpiling module {cf.white(name)}" TB = f"transpiling module {cf.white(name)}"
......
...@@ -39,7 +39,7 @@ PRELUDE.vars.update({ ...@@ -39,7 +39,7 @@ PRELUDE.vars.update({
typon_std = Path(__file__).parent.parent.parent.parent / "stdlib" typon_std = Path(__file__).parent.parent.parent.parent / "stdlib"
def make_module(name: str, scope: Scope) -> BaseType: def make_module(name: str, scope: Scope) -> BaseType:
ty = ModuleType([], f"module${name}") ty = ModuleType([], f"{name}")
for n, v in scope.vars.items(): for n, v in scope.vars.items():
ty.members[n] = v.type ty.members[n] = v.type
return ty return ty
......
...@@ -32,7 +32,8 @@ class TypeAnnotationVisitor(NodeVisitorSeq): ...@@ -32,7 +32,8 @@ class TypeAnnotationVisitor(NodeVisitorSeq):
return ty.type_object return ty.type_object
return ty return ty
raise NameError(node) from transpiler.phases.typing.exceptions import UnknownNameError
raise UnknownNameError(node)
def visit_Name(self, node: ast.Name) -> BaseType: def visit_Name(self, node: ast.Name) -> BaseType:
return self.visit_str(node.id) return self.visit_str(node.id)
...@@ -59,4 +60,3 @@ class TypeAnnotationVisitor(NodeVisitorSeq): ...@@ -59,4 +60,3 @@ class TypeAnnotationVisitor(NodeVisitorSeq):
res = left.members[node.attr] res = left.members[node.attr]
assert isinstance(res, TypeType) assert isinstance(res, TypeType)
return res.type_object return res.type_object
raise NotImplementedError(ast.unparse(node))
...@@ -58,7 +58,8 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -58,7 +58,8 @@ class ScoperBlockVisitor(ScoperVisitor):
for alias in node.names: for alias in node.names:
thing = module.val.get(alias.name) thing = module.val.get(alias.name)
if not thing: if not thing:
raise NameError(alias.name) from transpiler.phases.typing.exceptions import UnknownModuleMemberError
raise UnknownModuleMemberError(node.module, alias.name)
alias.item_obj = thing alias.item_obj = thing
self.scope.vars[alias.asname or alias.name] = VarDecl(VarKind.LOCAL, thing) self.scope.vars[alias.asname or alias.name] = VarDecl(VarKind.LOCAL, thing)
......
...@@ -35,5 +35,5 @@ class ScoperVisitor(NodeVisitorSeq): ...@@ -35,5 +35,5 @@ class ScoperVisitor(NodeVisitorSeq):
visitor.visit(b) visitor.visit(b)
b.decls = decls b.decls = decls
if not node.inner_scope.has_return: if not node.inner_scope.has_return:
rtype.unify(TY_NONE) rtype.unify(TY_NONE) # todo: properly indicate missing return
import ast import ast
import enum
from dataclasses import dataclass from dataclasses import dataclass
from transpiler.utils import highlight from transpiler.utils import highlight
from transpiler.exceptions import CompileError from transpiler.exceptions import CompileError
from transpiler.phases.typing.types import TypeVariable, BaseType from transpiler.phases.typing.types import TypeVariable, BaseType, TypeOperator
@dataclass @dataclass
...@@ -34,6 +35,7 @@ class UnresolvedTypeVariableError(CompileError): ...@@ -34,6 +35,7 @@ class UnresolvedTypeVariableError(CompileError):
{highlight('def f(x: int):')} {highlight('def f(x: int):')}
""" """
@dataclass @dataclass
class RecursiveTypeUnificationError(CompileError): class RecursiveTypeUnificationError(CompileError):
needle: BaseType needle: BaseType
...@@ -51,4 +53,167 @@ class RecursiveTypeUnificationError(CompileError): ...@@ -51,4 +53,167 @@ class RecursiveTypeUnificationError(CompileError):
In the current case, {highlight(self.haystack)} contains type {highlight(self.needle)}, but an attempt was made to In the current case, {highlight(self.haystack)} contains type {highlight(self.needle)}, but an attempt was made to
unify them. unify them.
"""
@dataclass
class InvalidCallError(CompileError):
callee: BaseType
args: list[BaseType]
def __str__(self) -> str:
return f"Invalid call: {highlight(self.callee)} with arguments {highlight(self.args)}"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This generally indicates a type error in a function call.
For example:
{highlight('def f(x: int): pass')}
{highlight('f("hello")')}
In the current case, {highlight(self.callee)} was called with arguments {highlight(self.args)}, but the function
expects arguments of type {highlight(self.callee.args)}.
"""
class TypeMismatchKind(enum.Enum):
NO_COMMON_PARENT = enum.auto()
DIFFERENT_TYPE = enum.auto()
@dataclass
class TypeMismatchError(CompileError):
expected: BaseType
got: BaseType
reason: TypeMismatchKind
def __str__(self) -> str:
return f"Type mismatch: expected {highlight(self.expected)}, got {highlight(self.got)}"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This generally indicates a type error.
For example:
{highlight('def f(x: int): pass')}
{highlight('f("hello")')}
In the current case, the compiler expected an expression of type {highlight(self.expected)}, but instead got
an expression of type {highlight(self.got)}.
"""
@dataclass
class ArgumentCountMismatchError(CompileError):
func: TypeOperator
arguments: TypeOperator
def __setattr__(self, key, value):
print(key, value)
super().__setattr__(key, value)
def __str__(self) -> str:
fcount = str(len(self.func.args))
if self.func.variadic:
fcount = f"at least {fcount}"
return f"Argument count mismatch: expected {fcount}, got {len(self.arguments.args)}"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This indicates missing or extraneous arguments in a function call or type instantiation.
The called or instantiated signature was {highlight(self.func)}.
Other examples:
{highlight('def f(x: int): pass')}
{highlight('f(1, 2)')}
Here, the function {highlight('f')} expects one argument, but was called with two.
{highlight('x: list[int, str]')}
Here, the type {highlight('list')} expects one argument, but was instantiated with two.
"""
@dataclass
class ProtocolMismatchError(CompileError):
value: BaseType
protocol: BaseType
reason: Exception
def __str__(self) -> str:
return f"Protocol mismatch: {highlight(self.value)} does not implement {highlight(self.protocol)}"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This generally indicates a type error.
For example:
{highlight('def f(x: Iterable[int]): pass')}
{highlight('f("hello")')}
In the current case, the compiler expected an expression whose type implements {highlight(self.protocol)}, but
instead got an expression of type {highlight(self.value)}.
"""
@dataclass
class NotCallableError(CompileError):
value: BaseType
def __str__(self) -> str:
return f"Trying to call a non-function type: {highlight(self.value)}"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This indicates that an attempt was made to call an object that is not a function.
For example:
{highlight('x = 1')}
{highlight('x()')}
"""
@dataclass
class MissingAttributeError(CompileError):
value: BaseType
attribute: str
def __str__(self) -> str:
return f"Missing attribute: {highlight(self.value)} has no attribute {highlight(self.attribute)}"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This indicates that an attempt was made to access an attribute that does not exist.
For example:
{highlight('x = 1')}
{highlight('print(x.y)')}
"""
@dataclass
class UnknownNameError(CompileError):
name: str
def __str__(self) -> str:
return f"Unknown name: {highlight(self.name)}"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This indicates that an attempt was made to access a name that does not exist.
For example:
{highlight('print(abcd)')}
"""
@dataclass
class UnknownModuleMemberError(CompileError):
module: str
name: str
def __str__(self) -> str:
return f"Unknown module member: Module {highlight(self.module)} does not contain {highlight(self.name)}"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This indicates that an attempt was made to import
For example:
{highlight('from math import abcd')}
""" """
\ No newline at end of file
...@@ -74,9 +74,10 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -74,9 +74,10 @@ class ScoperExprVisitor(ScoperVisitor):
def visit_Name(self, node: ast.Name) -> BaseType: def visit_Name(self, node: ast.Name) -> BaseType:
obj = self.scope.get(node.id) obj = self.scope.get(node.id)
if not obj: if not obj:
raise NameError(f"Name {node.id} is not defined") from transpiler.phases.typing.exceptions import UnknownNameError
raise UnknownNameError(node.id)
if isinstance(obj.type, TypeType) and isinstance(obj.type.type_object, TypeVariable): if isinstance(obj.type, TypeType) and isinstance(obj.type.type_object, TypeVariable):
raise NameError(f"Use of type variable") raise NameError(f"Use of type variable") # todo: when does this happen exactly?
if getattr(obj, "is_python_func", False): if getattr(obj, "is_python_func", False):
obj.python_func_used = True obj.python_func_used = True
return obj.type return obj.type
...@@ -93,10 +94,7 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -93,10 +94,7 @@ class ScoperExprVisitor(ScoperVisitor):
ftype = self.visit(node.func) ftype = self.visit(node.func)
if ftype.typevars: if ftype.typevars:
ftype = ftype.gen_sub(None, {v.name: TypeVariable(v.name) for v in ftype.typevars}) ftype = ftype.gen_sub(None, {v.name: TypeVariable(v.name) for v in ftype.typevars})
try: rtype = self.visit_function_call(ftype, [self.visit(arg) for arg in node.args])
rtype = self.visit_function_call(ftype, [self.visit(arg) for arg in node.args])
except IncompatibleTypesError as e:
raise IncompatibleTypesError(f"`{ast.unparse(node)}`: {e}")
actual = rtype actual = rtype
node.is_await = False node.is_await = False
if isinstance(actual, Promise) and actual.kind != PromiseKind.GENERATOR: if isinstance(actual, Promise) and actual.kind != PromiseKind.GENERATOR:
...@@ -115,13 +113,12 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -115,13 +113,12 @@ class ScoperExprVisitor(ScoperVisitor):
init.return_type = ftype.type_object init.return_type = ftype.type_object
return self.visit_function_call(init, arguments) return self.visit_function_call(init, arguments)
if not isinstance(ftype, FunctionType): if not isinstance(ftype, FunctionType):
raise IncompatibleTypesError(f"Cannot call {ftype}") from transpiler.phases.typing.exceptions import NotCallableError
raise NotCallableError(ftype)
#is_generic = any(isinstance(arg, TypeVariable) for arg in ftype.to_list()) #is_generic = any(isinstance(arg, TypeVariable) for arg in ftype.to_list())
equivalent = FunctionType(arguments, ftype.return_type) equivalent = FunctionType(arguments, ftype.return_type)
try: equivalent.is_intermediary = True
ftype.unify(equivalent) ftype.unify(equivalent)
except IncompatibleTypesError as e:
raise IncompatibleTypesError(f"Cannot call {ftype} with ({(', '.join(map(str, arguments)))}): {e}")
return ftype.return_type return ftype.return_type
def visit_Lambda(self, node: ast.Lambda) -> BaseType: def visit_Lambda(self, node: ast.Lambda) -> BaseType:
...@@ -143,17 +140,11 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -143,17 +140,11 @@ class ScoperExprVisitor(ScoperVisitor):
def visit_BinOp(self, node: ast.BinOp) -> BaseType: def visit_BinOp(self, node: ast.BinOp) -> BaseType:
left, right = map(self.visit, (node.left, node.right)) left, right = map(self.visit, (node.left, node.right))
try: return self.make_dunder([left, right], DUNDER[type(node.op)])
return self.make_dunder([left, right], DUNDER[type(node.op)])
except IncompatibleTypesError as e:
raise IncompatibleTypesError(f"{e} in `{ast.unparse(node)}`")
def visit_Attribute(self, node: ast.Attribute) -> BaseType: def visit_Attribute(self, node: ast.Attribute) -> BaseType:
try: ltype = self.visit(node.value)
ltype = self.visit(node.value) return self.visit_getattr(ltype, node.attr)
return self.visit_getattr(ltype, node.attr)
except Exception as e:
raise IncompatibleTypesError(f"{e} in `{ast.unparse(node)}`")
def visit_getattr(self, ltype: BaseType, name: str): def visit_getattr(self, ltype: BaseType, name: str):
bound = True bound = True
...@@ -175,7 +166,8 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -175,7 +166,8 @@ class ScoperExprVisitor(ScoperVisitor):
return meth.remove_self() return meth.remove_self()
else: else:
return meth return meth
raise IncompatibleTypesError(f"Type {ltype} has no attribute {name}") from transpiler.phases.typing.exceptions import MissingAttributeError
raise MissingAttributeError(ltype, name)
def visit_List(self, node: ast.List) -> BaseType: def visit_List(self, node: ast.List) -> BaseType:
if not node.elts: if not node.elts:
...@@ -216,10 +208,7 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -216,10 +208,7 @@ class ScoperExprVisitor(ScoperVisitor):
val = self.visit(node.operand) val = self.visit(node.operand)
if isinstance(node.op, ast.Not): if isinstance(node.op, ast.Not):
return TY_BOOL return TY_BOOL
try: return self.make_dunder([val], DUNDER[type(node.op)])
return self.make_dunder([val], DUNDER[type(node.op)])
except IncompatibleTypesError as e:
raise IncompatibleTypesError(f"{e} in `{ast.unparse(node)}`")
def visit_IfExp(self, node: ast.IfExp) -> BaseType: def visit_IfExp(self, node: ast.IfExp) -> BaseType:
self.visit(node.test) self.visit(node.test)
......
...@@ -5,7 +5,6 @@ from dataclasses import dataclass, field ...@@ -5,7 +5,6 @@ from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from itertools import zip_longest from itertools import zip_longest
from typing import Dict, Optional, List, ClassVar, Callable from typing import Dict, Optional, List, ClassVar, Callable
from transpiler.utils import highlight from transpiler.utils import highlight
...@@ -131,6 +130,7 @@ class TypeOperator(BaseType, ABC): ...@@ -131,6 +130,7 @@ class TypeOperator(BaseType, ABC):
is_protocol_gen: ClassVar[bool] = False is_protocol_gen: ClassVar[bool] = False
match_cache: set["TypeOperator"] = field(default_factory=set, init=False) match_cache: set["TypeOperator"] = field(default_factory=set, init=False)
is_reference: bool = False is_reference: bool = False
is_intermediary: bool = False
@staticmethod @staticmethod
def make_type(name: str): def make_type(name: str):
...@@ -167,9 +167,11 @@ class TypeOperator(BaseType, ABC): ...@@ -167,9 +167,11 @@ class TypeOperator(BaseType, ABC):
corresp.remove_self().unify(ty.remove_self()) corresp.remove_self().unify(ty.remove_self())
except Exception as e: except Exception as e:
self.match_cache.remove(hash(protocol)) self.match_cache.remove(hash(protocol))
raise IncompatibleTypesError(f"Type {self} doesn't implement protocol {protocol}: {e}") from transpiler.phases.typing.exceptions import ProtocolMismatchError
raise ProtocolMismatchError(self, protocol, e)
def unify_internal(self, other: BaseType): def unify_internal(self, other: BaseType):
from transpiler.phases.typing.exceptions import TypeMismatchError, TypeMismatchKind
if not isinstance(other, TypeOperator): if not isinstance(other, TypeOperator):
raise IncompatibleTypesError() raise IncompatibleTypesError()
if other.is_protocol and not self.is_protocol: if other.is_protocol and not self.is_protocol:
...@@ -194,10 +196,10 @@ class TypeOperator(BaseType, ABC): ...@@ -194,10 +196,10 @@ class TypeOperator(BaseType, ABC):
pass pass
else: else:
return return
raise IncompatibleTypesError(f"Cannot unify {self} and {other} with different type and no common parents") raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE)
if len(self.args) == 0: if len(self.args) == 0:
if self.name != other.name: if self.name != other.name:
raise IncompatibleTypesError(f"Cannot unify {self} and {other}") raise TypeMismatchError(self, other, TypeMismatchKind.DIFFERENT_TYPE)
for i, (a, b) in enumerate(zip_longest(self.args, other.args)): for i, (a, b) in enumerate(zip_longest(self.args, other.args)):
if a is None and self.variadic or b is None and other.variadic: if a is None and self.variadic or b is None and other.variadic:
continue continue
...@@ -210,13 +212,14 @@ class TypeOperator(BaseType, ABC): ...@@ -210,13 +212,14 @@ class TypeOperator(BaseType, ABC):
other.args.append(a) other.args.append(a)
continue continue
else: else:
raise IncompatibleTypesError(f"Cannot unify {self} and {other}, not enough arguments") from transpiler.phases.typing.exceptions import ArgumentCountMismatchError
raise ArgumentCountMismatchError(*sorted((self, other), key=lambda x: x.is_intermediary))
if isinstance(a, BaseType) and isinstance(b, BaseType): if isinstance(a, BaseType) and isinstance(b, BaseType):
a.unify(b) a.unify(b)
else: else:
if a != b: if a != b:
raise IncompatibleTypesError(f"Cannot unify {a} and {b}") raise TypeMismatchError(a, b, TypeMismatchKind.DIFFERENT_TYPE)
def contains_internal(self, other: "BaseType") -> bool: def contains_internal(self, other: "BaseType") -> bool:
return any(arg.contains(other) for arg in self.args) return any(arg.contains(other) for arg in self.args)
...@@ -259,6 +262,11 @@ class FunctionType(TypeOperator): ...@@ -259,6 +262,11 @@ class FunctionType(TypeOperator):
is_python_func: bool = False is_python_func: bool = False
python_func_used: bool = False python_func_used: bool = False
def __iter__(self):
x = 5
pass
return iter([str(self)])
def __init__(self, args: List[BaseType], ret: BaseType): def __init__(self, args: List[BaseType], ret: BaseType):
super().__init__([ret, *args]) super().__init__([ret, *args])
......
...@@ -2,11 +2,12 @@ import ast ...@@ -2,11 +2,12 @@ import ast
from dataclasses import dataclass from dataclasses import dataclass
from transpiler.utils import UnsupportedNodeError from transpiler.utils import UnsupportedNodeError, highlight
class NodeVisitorSeq: class NodeVisitorSeq:
def visit(self, node): def visit(self, node):
TB = f"running type analysis on {highlight(node)}"
"""Visit a node.""" """Visit a node."""
if type(node) == list: if type(node) == list:
for n in node: for n in node:
......
...@@ -29,21 +29,22 @@ def highlight(code, full=False): ...@@ -29,21 +29,22 @@ def highlight(code, full=False):
""" """
from transpiler.phases.typing import BaseType from transpiler.phases.typing import BaseType
if isinstance(code, ast.AST): if isinstance(code, ast.AST):
return cf.italic_darkGrey(f"[{type(code).__name__}] ") + highlight(ast.unparse(code)) return cf.italic_grey60(f"[{type(code).__name__}] ") + highlight(ast.unparse(code))
elif isinstance(code, BaseType): elif isinstance(code, BaseType):
return cf.italic_grey50(f"[{type(code).__name__}] ") + highlight(str(code)) return cf.italic_grey60(f"[{type(code).__name__}] ") + highlight(str(code))
from pygments import highlight as pyg_highlight from pygments import highlight as pyg_highlight
from pygments.lexers import PythonLexer from pygments.lexers import get_lexer_by_name
from pygments.formatters import TerminalFormatter from pygments.formatters import TerminalFormatter
items = pyg_highlight(code, PythonLexer(), TerminalFormatter()).replace("\x1b[39;49;00m", "\x1b[39m").splitlines() lexer = get_lexer_by_name("python", stripnl=False)
items = pyg_highlight(code, lexer, TerminalFormatter()).replace("\x1b[39;49;00m", "\x1b[39;24m")
if full: if full:
return "\n".join(items) return items
items = items.splitlines()
res = items[0] res = items[0]
if len(items) > 1: if len(items) > 1:
res += cf.white(" [...]") res += cf.white(" [...]")
#return Back.LIGHTBLACK_EX + Fore.RESET + res + Back.RESET return cf.on_gray25(res)
return cf.on_gray30(res)
def compare_ast(node1: Union[ast.expr, list[ast.expr]], node2: Union[ast.expr, list[ast.expr]]) -> bool: def compare_ast(node1: Union[ast.expr, list[ast.expr]], node2: Union[ast.expr, list[ast.expr]]) -> bool:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment