Commit ba47aacb authored by Tom Niget's avatar Tom Niget

Fix optional parameter unification

parent 5f91e8b6
...@@ -92,4 +92,13 @@ class file: ...@@ -92,4 +92,13 @@ class file:
def read(self) -> Task[str]: ... def read(self) -> Task[str]: ...
def close(self) -> Task[None]: ... def close(self) -> Task[None]: ...
def open(filename: str, mode: str) -> Task[file]: ... def open(filename: str, mode: str) -> Task[file]: ...
\ No newline at end of file
def __test_opt(x: int, y: int = 5) -> int:
...
assert __test_opt
assert __test_opt(5)
assert __test_opt(5, 6)
assert not __test_opt(5, 6, 7)
assert not __test_opt()
\ No newline at end of file
...@@ -9,13 +9,8 @@ class IfMainVisitor(ast.NodeVisitor): ...@@ -9,13 +9,8 @@ class IfMainVisitor(ast.NodeVisitor):
for i, stmt in enumerate(node.body): for i, stmt in enumerate(node.body):
if isinstance(stmt, ast.If): if isinstance(stmt, ast.If):
if not stmt.orelse and compare_ast(stmt.test, NAME_MAIN): if not stmt.orelse and compare_ast(stmt.test, NAME_MAIN):
new_node = ast.FunctionDef( new_node = ast.parse("def main(): pass").body[0]
name="main", new_node.body = stmt.body
args=ast.arguments(args=[]),
body=stmt.body,
decorator_list=[],
returns=None
)
new_node.is_main = True new_node.is_main = True
node.body[i] = new_node node.body[i] = new_node
return return
\ No newline at end of file
...@@ -55,7 +55,10 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -55,7 +55,10 @@ class ScoperBlockVisitor(ScoperVisitor):
raise NotImplementedError(node) raise NotImplementedError(node)
target = node.targets[0] target = node.targets[0]
ty = self.get_type(node.value) ty = self.get_type(node.value)
node.is_declare = self.visit_assign_target(target, ty) try:
node.is_declare = self.visit_assign_target(target, ty)
except IncompatibleTypesError as e:
raise IncompatibleTypesError(f"`{ast.unparse(node)}: {e}")
def visit_assign_target(self, target, decl_val: BaseType) -> bool: def visit_assign_target(self, target, decl_val: BaseType) -> bool:
if isinstance(target, ast.Name): if isinstance(target, ast.Name):
...@@ -92,6 +95,7 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -92,6 +95,7 @@ class ScoperBlockVisitor(ScoperVisitor):
scope.function = scope scope.function = scope
node.inner_scope = scope node.inner_scope = scope
node.type = ftype node.type = ftype
ftype.optional_at = 1 + len(node.args.args) - len(node.args.defaults)
for arg, ty in zip(node.args.args, argtypes): for arg, ty in zip(node.args.args, argtypes):
scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty) scope.vars[arg.arg] = VarDecl(VarKind.LOCAL, ty)
for b in node.body: for b in node.body:
......
...@@ -77,6 +77,7 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -77,6 +77,7 @@ class StdlibVisitor(NodeVisitorSeq):
ty.typevars = arg_visitor.typevars ty.typevars = arg_visitor.typevars
if node.args.vararg: if node.args.vararg:
ty.variadic = True ty.variadic = True
ty.optional_at = 1 + len(node.args.args) - len(node.args.defaults)
if self.cur_class: if self.cur_class:
assert isinstance(self.cur_class, TypeType) assert isinstance(self.cur_class, TypeType)
if isinstance(self.cur_class.type_object, ABCMeta): if isinstance(self.cur_class.type_object, ABCMeta):
...@@ -86,7 +87,16 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -86,7 +87,16 @@ class StdlibVisitor(NodeVisitorSeq):
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ty) self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ty)
def visit_Assert(self, node: ast.Assert): def visit_Assert(self, node: ast.Assert):
print("Type of", ast.unparse(node.test), ":=", self.expr().visit(node.test)) if isinstance(node.test, ast.UnaryOp) and isinstance(node.test.op, ast.Not):
oper = node.test.operand
try:
res = self.expr().visit(oper)
except:
print("Type of", ast.unparse(oper), ":=", "INVALID")
else:
raise AssertionError(f"Assertion should fail, got {res} for {ast.unparse(oper)}")
else:
print("Type of", ast.unparse(node.test), ":=", self.expr().visit(node.test))
def visit_Call(self, node: ast.Call) -> BaseType: def visit_Call(self, node: ast.Call) -> BaseType:
ty_op = self.visit(node.func) ty_op = self.visit(node.func)
......
...@@ -144,6 +144,8 @@ class TypeOperator(BaseType, ABC): ...@@ -144,6 +144,8 @@ class TypeOperator(BaseType, ABC):
def unify_internal(self, other: BaseType): def unify_internal(self, other: BaseType):
if not isinstance(other, TypeOperator): if not isinstance(other, TypeOperator):
raise IncompatibleTypesError() raise IncompatibleTypesError()
if len(self.args) < len(other.args):
return other.unify_internal(self)
if type(self) != type(other): if type(self) != type(other):
for parent in other.get_parents(): for parent in other.get_parents():
try: try:
...@@ -161,28 +163,34 @@ class TypeOperator(BaseType, ABC): ...@@ -161,28 +163,34 @@ class TypeOperator(BaseType, ABC):
return return
raise IncompatibleTypesError(f"Cannot unify {self} and {other} with different type and no common parents") raise IncompatibleTypesError(f"Cannot unify {self} and {other} with different type and no common parents")
if len(self.args) != len(other.args): if len(self.args) != len(other.args):
a, b = self, other
a_opt = a.optional_at is not None
b_opt = b.optional_at is not None
if a_opt and b_opt:
raise IncompatibleTypesError(f"This really should never happen")
if b_opt:
a, b = b, a
if a_opt:
# a = f(A, B; C=?, D=?)
# b = g(A, B, ... ?)
# either
# |a| < |b| => b has more args => invalid
# |a| ≥ |b| => b has less args => valid, up to |b|, so normal course of events
x = True
# c'est pété => utiliser le truc de la boucle en bas
# TODO: pas implémenté
if not (self.variadic or other.variadic): if not (self.variadic or other.variadic):
raise IncompatibleTypesError(f"Cannot unify {self} and {other} with different number of arguments") pass
# # a, b = self, other
# # if a.optio
# # a_opt = a.optional_at is not None
# # b_opt = b.optional_at is not None
# # if a_opt and b_opt:
# # raise IncompatibleTypesError(f"This really should never happen")
# # if b_opt:
# # other.unify_internal(self)
# # return
# if a_opt:
# if len(a.args) < len(b.args):
# raise IncompatibleTypesError(f"Cannot unify {self} and {other} with different number of arguments")
# # a = f(A, B; C=?, D=?)
# # b = g(A, B, ... ?)
# # either
# # |a| < |b| => b has more args => invalid
# # |a| ≥ |b| => b has less args => valid, up to |b|, so normal course of events
#
# x = True
#
# # c'est pété => utiliser le truc de la boucle en bas
#
# # TODO: pas implémenté
#
# # if not (self.variadic or other.variadic):
# # raise IncompatibleTypesError(f"Cannot unify {self} and {other} with different number of arguments")
if len(self.args) == 0: if len(self.args) == 0:
if self.name != other.name: if self.name != other.name:
raise IncompatibleTypesError(f"Cannot unify {self} and {other}") raise IncompatibleTypesError(f"Cannot unify {self} and {other}")
...@@ -190,6 +198,12 @@ class TypeOperator(BaseType, ABC): ...@@ -190,6 +198,12 @@ class TypeOperator(BaseType, ABC):
if a is None and self.variadic or b is None and other.variadic: if a is None and self.variadic or b is None and other.variadic:
continue continue
if a is not None and b is None:
if i >= self.optional_at:
continue
else:
raise IncompatibleTypesError(f"Cannot unify {self} and {other}, not enough arguments")
if isinstance(a, BaseType) and isinstance(b, BaseType): if isinstance(a, BaseType) and isinstance(b, BaseType):
a.unify(b) a.unify(b)
else: else:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment