Commit 655b1d94 authored by Tom Niget's avatar Tom Niget

Fix various things, calcbasic works

parent 32a6dcfe
...@@ -66,32 +66,34 @@ class PyObj : public std::shared_ptr<typename RealType<T>::type> { ...@@ -66,32 +66,34 @@ class PyObj : public std::shared_ptr<typename RealType<T>::type> {
public: public:
using inner = typename RealType<T>::type; using inner = typename RealType<T>::type;
PyObj() : std::shared_ptr<inner>() {} template<typename... Args>
PyObj(std::nullptr_t) : std::shared_ptr<inner>(nullptr) {} PyObj(Args&&... args) : std::shared_ptr<inner>(std::make_shared<inner>(std::forward<Args>(args)...)) {}
PyObj(inner *ptr) : std::shared_ptr<inner>(ptr) {}
PyObj(const std::shared_ptr<inner> &ptr) : std::shared_ptr<inner>(ptr) {} PyObj() : std::shared_ptr<inner>() {}
PyObj(std::shared_ptr<inner> &&ptr) : std::shared_ptr<inner>(ptr) {} PyObj(std::nullptr_t) : std::shared_ptr<inner>(nullptr) {}
PyObj(const PyObj &ptr) : std::shared_ptr<inner>(ptr) {} PyObj(inner *ptr) : std::shared_ptr<inner>(ptr) {}
PyObj(PyObj &&ptr) : std::shared_ptr<inner>(ptr) {} PyObj(const std::shared_ptr<inner> &ptr) : std::shared_ptr<inner>(ptr) {}
PyObj(std::shared_ptr<inner> &&ptr) : std::shared_ptr<inner>(ptr) {}
PyObj(const PyObj &ptr) : std::shared_ptr<inner>(ptr) {}
PyObj( PyObj &ptr) : std::shared_ptr<inner>(ptr) {}
PyObj(PyObj &&ptr) : std::shared_ptr<inner>(ptr) {}
PyObj &operator=(const PyObj &ptr) { std::shared_ptr<inner>::operator=(ptr); return *this; } PyObj &operator=(const PyObj &ptr) { std::shared_ptr<inner>::operator=(ptr); return *this; }
PyObj &operator=(PyObj &&ptr) { std::shared_ptr<inner>::operator=(ptr); return *this; } PyObj &operator=(PyObj &&ptr) { std::shared_ptr<inner>::operator=(ptr); return *this; }
PyObj &operator=(std::nullptr_t) { std::shared_ptr<inner>::operator=(nullptr); return *this; } PyObj &operator=(std::nullptr_t) { std::shared_ptr<inner>::operator=(nullptr); return *this; }
PyObj &operator=(inner *ptr) { std::shared_ptr<inner>::operator=(ptr); return *this; } PyObj &operator=(inner *ptr) { std::shared_ptr<inner>::operator=(ptr); return *this; }
PyObj &operator=(const std::shared_ptr<inner> &ptr) { std::shared_ptr<inner>::operator=(ptr); return *this; } PyObj &operator=(const std::shared_ptr<inner> &ptr) { std::shared_ptr<inner>::operator=(ptr); return *this; }
template<typename U> template<typename U>
PyObj(const PyObj<U> &ptr) : std::shared_ptr<inner>(ptr) {} PyObj(const PyObj<U> &ptr) : std::shared_ptr<inner>(ptr) {}
template<typename U> //PyObj(PyObj<U> &&ptr) : std::shared_ptr<inner>(ptr) {}
PyObj(PyObj<U> &&ptr) : std::shared_ptr<inner>(ptr) {}
// using make_shared // using make_shared
template<class U> /*template<class U>
PyObj(U&& other) : std::shared_ptr<inner>(std::make_shared<inner>(other)) {} PyObj(U&& other) : std::shared_ptr<inner>(std::make_shared<inner>(other)) {}*/
/*template<typename... Args>
PyObj(Args&&... args) : std::shared_ptr<inner>(std::forward<Args>(args)...) {}*/
...@@ -124,7 +126,7 @@ public: ...@@ -124,7 +126,7 @@ public:
} }
}; };
template <typename T, typename... Args> auto pyobj(Args &&...args) -> PyObj<T> { template <typename T, typename... Args> auto pyobj(Args &&...args) -> PyObj<typename RealType<T>::type> {
return std::make_shared<typename RealType<T>::type>( return std::make_shared<typename RealType<T>::type>(
std::forward<Args>(args)...); std::forward<Args>(args)...);
} }
......
...@@ -91,7 +91,7 @@ class BlockVisitor(NodeVisitor): ...@@ -91,7 +91,7 @@ class BlockVisitor(NodeVisitor):
else: else:
yield from self.visit(argty) yield from self.visit(argty)
yield arg.arg yield arg.arg
if emission in {FunctionEmissionKind.DECLARATION, FunctionEmissionKind.LAMBDA} and default: if emission in {FunctionEmissionKind.DECLARATION, FunctionEmissionKind.LAMBDA, FunctionEmissionKind.METHOD} and default:
yield " = " yield " = "
yield from self.expr().visit(default) yield from self.expr().visit(default)
yield ")" yield ")"
......
...@@ -34,7 +34,7 @@ class ClassVisitor(NodeVisitor): ...@@ -34,7 +34,7 @@ class ClassVisitor(NodeVisitor):
else: else:
yield "void py_repr(std::ostream &s) const {" yield "void py_repr(std::ostream &s) const {"
yield "s << '{';" yield "s << '{';"
for i, (name, memb) in enumerate(node.type.fields.items()): for i, (name, memb) in enumerate(node.type.get_members().items()):
if i != 0: if i != 0:
yield 's << ", ";' yield 's << ", ";'
yield f's << "\\"{name}\\": ";' yield f's << "\\"{name}\\": ";'
......
...@@ -27,8 +27,9 @@ class ModuleVisitor(BlockVisitor): ...@@ -27,8 +27,9 @@ class ModuleVisitor(BlockVisitor):
yield f"struct {concrete}_t {{" yield f"struct {concrete}_t {{"
for name, obj in alias.module_obj.fields.items(): for name, obj in alias.module_obj.fields.items():
if obj.type.python_func_used: ty = obj.type.resolve()
yield from self.emit_python_func(alias.name, name, name, obj.type) if getattr(ty, "python_func_used", False):
yield from self.emit_python_func(alias.name, name, name, ty)
yield "} all;" yield "} all;"
yield f"auto& get_all() {{ return all; }}" yield f"auto& get_all() {{ return all; }}"
......
...@@ -32,7 +32,9 @@ class ScoperBlockVisitor(ScoperVisitor): ...@@ -32,7 +32,9 @@ class ScoperBlockVisitor(ScoperVisitor):
# copy all functions to mod_scope # copy all functions to mod_scope
for fname, obj in py_mod.__dict__.items(): for fname, obj in py_mod.__dict__.items():
if callable(obj): if callable(obj):
fty = FunctionType([], TypeVariable()) # fty = FunctionType([], TypeVariable())
# fty.is_python_func = True
fty = TypeVariable()
fty.is_python_func = True fty.is_python_func = True
mod_scope.vars[fname] = VarDecl(VarKind.LOCAL, fty) mod_scope.vars[fname] = VarDecl(VarKind.LOCAL, fty)
mod = make_mod_decl(name, mod_scope) mod = make_mod_decl(name, mod_scope)
......
...@@ -7,7 +7,7 @@ from transpiler.phases.typing import ScopeKind, VarDecl, VarKind ...@@ -7,7 +7,7 @@ from transpiler.phases.typing import ScopeKind, VarDecl, VarKind
from transpiler.phases.typing.common import ScoperVisitor, get_iter, get_next from transpiler.phases.typing.common import ScoperVisitor, get_iter, get_next
from transpiler.phases.typing.types import BaseType, TupleType, TY_STR, TY_BOOL, TY_INT, \ from transpiler.phases.typing.types import BaseType, TupleType, TY_STR, TY_BOOL, TY_INT, \
TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict, Promise, PromiseKind, UserType, \ TY_COMPLEX, TY_NONE, FunctionType, PyList, TypeVariable, PySet, TypeType, PyDict, Promise, PromiseKind, UserType, \
TY_SLICE, TY_FLOAT TY_SLICE, TY_FLOAT, RuntimeValue
from transpiler.utils import linenodata from transpiler.utils import linenodata
DUNDER = { DUNDER = {
...@@ -92,11 +92,12 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -92,11 +92,12 @@ class ScoperExprVisitor(ScoperVisitor):
if not obj: if not obj:
from transpiler.phases.typing.exceptions import UnknownNameError from transpiler.phases.typing.exceptions import UnknownNameError
raise UnknownNameError(node.id) raise UnknownNameError(node.id)
if isinstance(obj.type, TypeType) and isinstance(obj.type.type_object, TypeVariable): ty = obj.type.resolve()
if isinstance(ty, TypeType) and isinstance(ty.type_object, TypeVariable):
raise NameError(f"Use of type variable") # todo: when does this happen exactly? raise NameError(f"Use of type variable") # todo: when does this happen exactly?
if getattr(obj, "is_python_func", False): if getattr(ty, "is_python_func", False):
obj.python_func_used = True ty.python_func_used = True
return obj.type return ty
def visit_BoolOp(self, node: ast.BoolOp) -> BaseType: def visit_BoolOp(self, node: ast.BoolOp) -> BaseType:
for value in node.values: for value in node.values:
...@@ -199,12 +200,12 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -199,12 +200,12 @@ class ScoperExprVisitor(ScoperVisitor):
# else: # else:
# return meth # return meth
if field := ltype.fields.get(name): if field := ltype.fields.get(name):
ty = field.type ty = field.type.resolve()
if getattr(ty, "is_python_func", False): if getattr(ty, "is_python_func", False):
ty.python_func_used = True ty.python_func_used = True
if isinstance(ty, FunctionType): if isinstance(ty, FunctionType):
ty = ty.gen_sub(ltype, {}) ty = ty.gen_sub(ltype, {})
if bound and field.in_class_def: if bound and field.in_class_def and type(field.val) != RuntimeValue:
return ty.remove_self() return ty.remove_self()
return ty return ty
......
...@@ -111,7 +111,7 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -111,7 +111,7 @@ class StdlibVisitor(NodeVisitorSeq):
if isinstance(self.cur_class.type_object, ABCMeta): if isinstance(self.cur_class.type_object, ABCMeta):
self.cur_class.type_object.gen_methods[node.name] = lambda t: ty.gen_sub(t, self.typevars) self.cur_class.type_object.gen_methods[node.name] = lambda t: ty.gen_sub(t, self.typevars)
else: else:
self.cur_class.type_object.fields[node.name] = MemberDef(ty.gen_sub(self.cur_class.type_object, self.typevars)) self.cur_class.type_object.fields[node.name] = MemberDef(ty.gen_sub(self.cur_class.type_object, self.typevars), ())
self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ty) self.scope.vars[node.name] = VarDecl(VarKind.LOCAL, ty)
def visit_Assert(self, node: ast.Assert): def visit_Assert(self, node: ast.Assert):
......
...@@ -148,6 +148,18 @@ def next_var_id(): ...@@ -148,6 +148,18 @@ def next_var_id():
class TypeVariable(BaseType): class TypeVariable(BaseType):
name: str = field(default_factory=lambda: next_var_id()) name: str = field(default_factory=lambda: next_var_id())
resolved: Optional[BaseType] = None resolved: Optional[BaseType] = None
patch_attrs: dict = field(default_factory=dict)
def __setattr__(self, key, value):
if "patch_attrs" in self.__dict__ and key not in self.__dict__:
self.patch_attrs[key] = value
else:
super().__setattr__(key, value)
def __getattr__(self, item):
if "patch_attrs" in self.__dict__ and item in self.patch_attrs:
return self.patch_attrs[item]
raise AttributeError(item)
def __str__(self): def __str__(self):
if self.resolved is None: if self.resolved is None:
...@@ -166,6 +178,8 @@ class TypeVariable(BaseType): ...@@ -166,6 +178,8 @@ class TypeVariable(BaseType):
from transpiler.phases.typing.exceptions import RecursiveTypeUnificationError from transpiler.phases.typing.exceptions import RecursiveTypeUnificationError
raise RecursiveTypeUnificationError(self, other) raise RecursiveTypeUnificationError(self, other)
self.resolved = other self.resolved = other
for k, v in self.patch_attrs.items():
setattr(other, k, v)
def contains_internal(self, other: BaseType) -> bool: def contains_internal(self, other: BaseType) -> bool:
return self.resolve() is other.resolve() return self.resolve() is other.resolve()
...@@ -210,7 +224,7 @@ class TypeOperator(BaseType, ABC): ...@@ -210,7 +224,7 @@ class TypeOperator(BaseType, ABC):
if self.name is None: if self.name is None:
self.name = self.__class__.__name__ self.name = self.__class__.__name__
for name, factory in self.gen_methods.items(): for name, factory in self.gen_methods.items():
self.fields[name] = MemberDef(factory(self)) self.fields[name] = MemberDef(factory(self), ())
for gp in self.gen_parents: for gp in self.gen_parents:
if not isinstance(gp, BaseType): if not isinstance(gp, BaseType):
gp = gp(self.args) gp = gp(self.args)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment