Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
T
typon
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
Analytics
Analytics
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Commits
Issue Boards
Open sidebar
Tom Niget
typon
Commits
32a6dcfe
Commit
32a6dcfe
authored
Aug 29, 2023
by
Tom Niget
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Merge .members and .methods; fix unification for hierarchy lookup
parent
e2134ee5
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
135 additions
and
75 deletions
+135
-75
trans/tests/a_a_enumtest.py
trans/tests/a_a_enumtest.py
+11
-0
trans/tests/a_calcbasic.py2
trans/tests/a_calcbasic.py2
+0
-0
trans/transpiler/phases/emit_cpp/block.py
trans/transpiler/phases/emit_cpp/block.py
+1
-5
trans/transpiler/phases/emit_cpp/class_.py
trans/transpiler/phases/emit_cpp/class_.py
+4
-4
trans/transpiler/phases/emit_cpp/module.py
trans/transpiler/phases/emit_cpp/module.py
+3
-3
trans/transpiler/phases/emit_cpp/search.py
trans/transpiler/phases/emit_cpp/search.py
+2
-0
trans/transpiler/phases/typing/__init__.py
trans/transpiler/phases/typing/__init__.py
+2
-2
trans/transpiler/phases/typing/annotations.py
trans/transpiler/phases/typing/annotations.py
+1
-1
trans/transpiler/phases/typing/block.py
trans/transpiler/phases/typing/block.py
+9
-6
trans/transpiler/phases/typing/class_.py
trans/transpiler/phases/typing/class_.py
+5
-5
trans/transpiler/phases/typing/common.py
trans/transpiler/phases/typing/common.py
+2
-2
trans/transpiler/phases/typing/expr.py
trans/transpiler/phases/typing/expr.py
+27
-10
trans/transpiler/phases/typing/scope.py
trans/transpiler/phases/typing/scope.py
+1
-5
trans/transpiler/phases/typing/stdlib.py
trans/transpiler/phases/typing/stdlib.py
+4
-3
trans/transpiler/phases/typing/types.py
trans/transpiler/phases/typing/types.py
+63
-29
No files found.
trans/tests/a_a_enumtest.py
0 → 100644
View file @
32a6dcfe
# coding: utf-8
from
enum
import
Enum
class
TokenType
(
Enum
):
NUMBER
=
1
PARENTHESIS
=
2
OPERATION
=
3
if
__name__
==
"__main__"
:
x
=
TokenType
.
NUMBER
\ No newline at end of file
trans/tests/
calcbasic.py
→
trans/tests/
a_calcbasic.py2
View file @
32a6dcfe
File moved
trans/transpiler/phases/emit_cpp/block.py
View file @
32a6dcfe
...
...
@@ -124,11 +124,7 @@ class BlockVisitor(NodeVisitor):
def
visit_ClassDef
(
self
,
node
:
ast
.
ClassDef
):
yield
from
()
def
check
(
self
,
f
):
for
b
in
node
.
body
:
yield
from
self
.
match
(
node
)
has_return
=
next
(
ReturnVisitor
().
check
(
node
),
False
)
has_return
=
ReturnVisitor
().
match
(
node
.
body
)
yield
from
self
.
visit_func_decls
(
node
.
body
,
inner_scope
)
...
...
trans/transpiler/phases/emit_cpp/class_.py
View file @
32a6dcfe
...
...
@@ -29,12 +29,12 @@ class ClassVisitor(NodeVisitor):
yield
"int value;"
yield
"operator int() const { return value; }"
yield
"void py_repr(std::ostream &s) const {"
yield
f's << "
{
node
.
name
}
."
<< value
;'
yield
f's << "
{
node
.
name
}
.";'
yield
"}"
else
:
yield
"void py_repr(std::ostream &s) const {"
yield
"s << '{';"
for
i
,
(
name
,
memb
)
in
enumerate
(
node
.
type
.
member
s
.
items
()):
for
i
,
(
name
,
memb
)
in
enumerate
(
node
.
type
.
field
s
.
items
()):
if
i
!=
0
:
yield
's << ", ";'
yield
f's << "
\
\
"
{
name
}\
\
": ";'
...
...
@@ -63,8 +63,8 @@ class ClassInnerVisitor(NodeVisitor):
scope
:
Scope
def
visit_AnnAssign
(
self
,
node
:
ast
.
AnnAssign
)
->
Iterable
[
str
]:
member
=
self
.
scope
.
obj_type
.
member
s
[
node
.
target
.
id
]
yield
from
self
.
visit
(
member
)
member
=
self
.
scope
.
obj_type
.
field
s
[
node
.
target
.
id
]
yield
from
self
.
visit
(
member
.
type
)
yield
node
.
target
.
id
yield
";"
...
...
trans/transpiler/phases/emit_cpp/module.py
View file @
32a6dcfe
...
...
@@ -26,9 +26,9 @@ class ModuleVisitor(BlockVisitor):
yield
f"namespace py_
{
concrete
}
{{"
yield
f"struct
{
concrete
}
_t {{"
for
name
,
obj
in
alias
.
module_obj
.
member
s
.
items
():
if
obj
.
python_func_used
:
yield
from
self
.
emit_python_func
(
alias
.
name
,
name
,
name
,
obj
)
for
name
,
obj
in
alias
.
module_obj
.
field
s
.
items
():
if
obj
.
type
.
python_func_used
:
yield
from
self
.
emit_python_func
(
alias
.
name
,
name
,
name
,
obj
.
type
)
yield
"} all;"
yield
f"auto& get_all() {{ return all; }}"
...
...
trans/transpiler/phases/emit_cpp/search.py
View file @
32a6dcfe
...
...
@@ -15,4 +15,6 @@ class SearchVisitor(ast.NodeVisitor):
yield
from
self
.
visit
(
value
)
def
match
(
self
,
node
)
->
bool
:
if
type
(
node
)
==
list
:
return
any
(
self
.
match
(
n
)
for
n
in
node
)
return
next
(
self
.
visit
(
node
),
False
)
trans/transpiler/phases/typing/__init__.py
View file @
32a6dcfe
...
...
@@ -5,7 +5,7 @@ from transpiler.phases.typing.scope import VarKind, VarDecl, ScopeKind, Scope
from
transpiler.phases.typing.stdlib
import
PRELUDE
,
StdlibVisitor
from
transpiler.phases.typing.types
import
TY_TYPE
,
TY_INT
,
TY_STR
,
TY_BOOL
,
TY_COMPLEX
,
TY_NONE
,
FunctionType
,
\
TypeVariable
,
CppType
,
PyList
,
TypeType
,
Forked
,
Task
,
Future
,
PyIterator
,
TupleType
,
TypeOperator
,
BaseType
,
\
ModuleType
,
TY_BYTES
,
TY_FLOAT
,
PyDict
,
TY_SLICE
,
TY_OBJECT
,
BuiltinFeature
,
UnionType
ModuleType
,
TY_BYTES
,
TY_FLOAT
,
PyDict
,
TY_SLICE
,
TY_OBJECT
,
BuiltinFeature
,
UnionType
,
MemberDef
PRELUDE
.
vars
.
update
({
# "int": VarDecl(VarKind.LOCAL, TY_TYPE, TY_INT),
...
...
@@ -46,7 +46,7 @@ typon_std = Path(__file__).parent.parent.parent.parent / "stdlib"
def
make_module
(
name
:
str
,
scope
:
Scope
)
->
BaseType
:
ty
=
ModuleType
([],
f"
{
name
}
"
)
for
n
,
v
in
scope
.
vars
.
items
():
ty
.
members
[
n
]
=
v
.
type
ty
.
fields
[
n
]
=
MemberDef
(
v
.
type
,
v
.
val
,
False
)
return
ty
...
...
trans/transpiler/phases/typing/annotations.py
View file @
32a6dcfe
...
...
@@ -57,7 +57,7 @@ class TypeAnnotationVisitor(NodeVisitorSeq):
def
visit_Attribute
(
self
,
node
:
ast
.
Attribute
)
->
BaseType
:
left
=
self
.
visit
(
node
.
value
)
res
=
left
.
members
[
node
.
attr
]
res
=
left
.
fields
[
node
.
attr
].
type
assert
isinstance
(
res
,
TypeType
)
return
res
.
type_object
...
...
trans/transpiler/phases/typing/block.py
View file @
32a6dcfe
...
...
@@ -11,7 +11,8 @@ from transpiler.phases.typing.expr import ScoperExprVisitor, DUNDER
from
transpiler.phases.typing.class_
import
ScoperClassVisitor
from
transpiler.phases.typing.scope
import
VarDecl
,
VarKind
,
ScopeKind
,
Scope
from
transpiler.phases.typing.types
import
BaseType
,
TypeVariable
,
FunctionType
,
\
Promise
,
TY_NONE
,
PromiseKind
,
TupleType
,
UserType
,
TypeType
,
ModuleType
,
BuiltinFeature
,
TY_INT
Promise
,
TY_NONE
,
PromiseKind
,
TupleType
,
UserType
,
TypeType
,
ModuleType
,
BuiltinFeature
,
TY_INT
,
MemberDef
,
\
RuntimeValue
from
transpiler.phases.utils
import
PlainBlock
,
AnnotationName
...
...
@@ -167,7 +168,7 @@ class ScoperBlockVisitor(ScoperVisitor):
init_method
=
ast
.
FunctionDef
(
name
=
"__init__"
,
args
=
ast
.
arguments
(
args
=
[
ast
.
arg
(
arg
=
"self"
),
*
[
ast
.
arg
(
arg
=
n
)
for
n
in
ctype
.
members
]],
args
=
[
ast
.
arg
(
arg
=
"self"
),
*
[
ast
.
arg
(
arg
=
n
)
for
n
in
ctype
.
get_members
()
]],
defaults
=
[],
kw_defaults
=
[],
kwarg
=
None
,
...
...
@@ -179,7 +180,7 @@ class ScoperBlockVisitor(ScoperVisitor):
targets
=
[
ast
.
Attribute
(
value
=
ast
.
Name
(
id
=
"self"
),
attr
=
n
)],
value
=
ast
.
Name
(
id
=
n
),
**
lnd
)
for
n
in
ctype
.
members
)
for
n
in
ctype
.
get_members
()
],
decorator_list
=
[],
returns
=
None
,
...
...
@@ -195,9 +196,11 @@ class ScoperBlockVisitor(ScoperVisitor):
base
=
self
.
expr
().
visit
(
base
)
if
is_builtin
(
base
,
"Enum"
):
ctype
.
parents
.
append
(
TY_INT
)
for
k
in
ctype
.
members
:
ctype
.
members
[
k
]
=
ctype
ctype
.
members
[
"value"
]
=
TY_INT
for
k
,
m
in
ctype
.
fields
.
items
():
m
.
type
=
ctype
m
.
val
=
ast
.
literal_eval
(
m
.
val
)
assert
type
(
m
.
val
)
==
int
ctype
.
fields
[
"value"
]
=
MemberDef
(
TY_INT
)
lnd
=
linenodata
(
node
)
init_method
=
ast
.
FunctionDef
(
name
=
"__init__"
,
...
...
trans/transpiler/phases/typing/class_.py
View file @
32a6dcfe
...
...
@@ -4,7 +4,7 @@ from dataclasses import dataclass, field
from
transpiler.phases.typing
import
FunctionType
,
ScopeKind
,
VarDecl
,
VarKind
,
TY_NONE
from
transpiler.phases.typing.common
import
ScoperVisitor
from
transpiler.phases.typing.types
import
PromiseKind
,
Promise
,
BaseType
from
transpiler.phases.typing.types
import
PromiseKind
,
Promise
,
BaseType
,
MemberDef
@
dataclass
...
...
@@ -15,15 +15,15 @@ class ScoperClassVisitor(ScoperVisitor):
assert
node
.
value
is
None
,
"Class field should not have a value"
assert
node
.
simple
==
1
,
"Class field should be simple (identifier, not parenthesized)"
assert
isinstance
(
node
.
target
,
ast
.
Name
)
self
.
scope
.
obj_type
.
members
[
node
.
target
.
id
]
=
self
.
visit_annotation
(
node
.
annotation
)
self
.
scope
.
obj_type
.
fields
[
node
.
target
.
id
]
=
MemberDef
(
self
.
visit_annotation
(
node
.
annotation
)
)
def
visit_Assign
(
self
,
node
:
ast
.
Assign
):
assert
len
(
node
.
targets
)
==
1
,
"C
lass field should be assigned to only once
"
assert
len
(
node
.
targets
)
==
1
,
"C
an't use destructuring in class static member
"
assert
isinstance
(
node
.
targets
[
0
],
ast
.
Name
)
node
.
is_declare
=
True
valtype
=
self
.
expr
().
visit
(
node
.
value
)
node
.
targets
[
0
].
type
=
valtype
self
.
scope
.
obj_type
.
members
[
node
.
targets
[
0
].
id
]
=
valtype
self
.
scope
.
obj_type
.
fields
[
node
.
targets
[
0
].
id
]
=
MemberDef
(
valtype
,
node
.
value
)
def
visit_FunctionDef
(
self
,
node
:
ast
.
FunctionDef
):
ftype
=
self
.
parse_function
(
node
)
...
...
@@ -32,5 +32,5 @@ class ScoperClassVisitor(ScoperVisitor):
if
node
.
name
!=
"__init__"
:
ftype
.
return_type
=
Promise
(
ftype
.
return_type
,
PromiseKind
.
TASK
)
ftype
.
is_method
=
True
self
.
scope
.
obj_type
.
methods
[
node
.
name
]
=
ftype
self
.
scope
.
obj_type
.
fields
[
node
.
name
]
=
MemberDef
(
ftype
,
node
)
return
(
node
,
inner
)
trans/transpiler/phases/typing/common.py
View file @
32a6dcfe
...
...
@@ -108,7 +108,7 @@ class ScoperVisitor(NodeVisitorSeq):
def
get_iter
(
seq_type
):
try
:
iter_type
=
seq_type
.
methods
[
"__iter__"
]
.
return_type
iter_type
=
seq_type
.
fields
[
"__iter__"
].
type
.
return_type
except
:
from
transpiler.phases.typing.exceptions
import
NotIterableError
raise
NotIterableError
(
seq_type
)
...
...
@@ -116,7 +116,7 @@ def get_iter(seq_type):
def
get_next
(
iter_type
):
try
:
next_type
=
iter_type
.
methods
[
"__next__"
]
.
return_type
next_type
=
iter_type
.
fields
[
"__next__"
].
type
.
return_type
except
:
from
transpiler.phases.typing.exceptions
import
NotIteratorError
raise
NotIteratorError
(
iter_type
)
...
...
trans/transpiler/phases/typing/expr.py
View file @
32a6dcfe
...
...
@@ -174,6 +174,11 @@ class ScoperExprVisitor(ScoperVisitor):
def
visit_getattr
(
self
,
ltype
:
BaseType
,
name
:
str
)
->
BaseType
:
bound
=
True
if
isinstance
(
ltype
,
TypeType
):
# if mdecl := ltype.static_members.get(name):
# attr = mdecl.type
# if getattr(attr, "is_python_func", False):
# attr.python_func_used = True
# return attr
ltype
=
ltype
.
type_object
bound
=
False
if
isinstance
(
ltype
,
abc
.
ABCMeta
):
...
...
@@ -182,16 +187,28 @@ class ScoperExprVisitor(ScoperVisitor):
if
not
all
(
arg
.
annotation
==
BaseType
for
arg
in
args
):
raise
NotImplementedError
(
"I don't know how to handle this type"
)
ltype
=
ltype
(
*
(
TypeVariable
()
for
_
in
args
))
if
attr
:
=
ltype
.
members
.
get
(
name
):
if
getattr
(
attr
,
"is_python_func"
,
False
):
attr
.
python_func_used
=
True
return
attr
if
meth
:
=
ltype
.
methods
.
get
(
name
):
meth
=
meth
.
gen_sub
(
ltype
,
{})
if
bound
:
return
meth
.
remove_self
()
else
:
return
meth
# if mdecl := ltype.members.get(name):
# attr = mdecl.type
# if getattr(attr, "is_python_func", False):
# attr.python_func_used = True
# return attr
# if meth := ltype.methods.get(name):
# meth = meth.gen_sub(ltype, {})
# if bound:
# return meth.remove_self()
# else:
# return meth
if
field
:
=
ltype
.
fields
.
get
(
name
):
ty
=
field
.
type
if
getattr
(
ty
,
"is_python_func"
,
False
):
ty
.
python_func_used
=
True
if
isinstance
(
ty
,
FunctionType
):
ty
=
ty
.
gen_sub
(
ltype
,
{})
if
bound
and
field
.
in_class_def
:
return
ty
.
remove_self
()
return
ty
from
transpiler.phases.typing.exceptions
import
MissingAttributeError
parents
=
ltype
.
iter_hierarchy_recursive
()
next
(
parents
)
...
...
trans/transpiler/phases/typing/scope.py
View file @
32a6dcfe
...
...
@@ -3,7 +3,7 @@ from dataclasses import field, dataclass
from
enum
import
Enum
from
typing
import
Optional
,
Dict
,
List
,
Any
from
transpiler.phases.typing.types
import
BaseType
from
transpiler.phases.typing.types
import
BaseType
,
RuntimeValue
class
VarKind
(
Enum
):
...
...
@@ -23,10 +23,6 @@ class VarType:
pass
class
RuntimeValue
:
pass
@
dataclass
class
VarDecl
:
kind
:
VarKind
...
...
trans/transpiler/phases/typing/stdlib.py
View file @
32a6dcfe
...
...
@@ -8,7 +8,8 @@ from transpiler.phases.typing.annotations import TypeAnnotationVisitor
from
transpiler.phases.typing.common
import
PRELUDE
from
transpiler.phases.typing.expr
import
ScoperExprVisitor
from
transpiler.phases.typing.scope
import
Scope
,
VarDecl
,
VarKind
,
ScopeKind
from
transpiler.phases.typing.types
import
BaseType
,
TypeOperator
,
FunctionType
,
TY_VARARG
,
TypeType
,
TypeVariable
from
transpiler.phases.typing.types
import
BaseType
,
TypeOperator
,
FunctionType
,
TY_VARARG
,
TypeType
,
TypeVariable
,
\
MemberDef
from
transpiler.phases.utils
import
NodeVisitorSeq
...
...
@@ -36,7 +37,7 @@ class StdlibVisitor(NodeVisitorSeq):
if
isinstance
(
self
.
cur_class
.
type_object
,
ABCMeta
):
raise
NotImplementedError
else
:
self
.
cur_class
.
type_object
.
members
[
node
.
target
.
id
]
=
ty
.
gen_sub
(
self
.
cur_class
.
type_object
,
self
.
typevars
)
self
.
cur_class
.
type_object
.
fields
[
node
.
target
.
id
]
=
MemberDef
(
ty
.
gen_sub
(
self
.
cur_class
.
type_object
,
self
.
typevars
)
)
self
.
scope
.
vars
[
node
.
target
.
id
]
=
VarDecl
(
VarKind
.
LOCAL
,
ty
)
def
visit_ImportFrom
(
self
,
node
:
ast
.
ImportFrom
):
...
...
@@ -110,7 +111,7 @@ class StdlibVisitor(NodeVisitorSeq):
if
isinstance
(
self
.
cur_class
.
type_object
,
ABCMeta
):
self
.
cur_class
.
type_object
.
gen_methods
[
node
.
name
]
=
lambda
t
:
ty
.
gen_sub
(
t
,
self
.
typevars
)
else
:
self
.
cur_class
.
type_object
.
methods
[
node
.
name
]
=
ty
.
gen_sub
(
self
.
cur_class
.
type_object
,
self
.
typevars
)
self
.
cur_class
.
type_object
.
fields
[
node
.
name
]
=
MemberDef
(
ty
.
gen_sub
(
self
.
cur_class
.
type_object
,
self
.
typevars
)
)
self
.
scope
.
vars
[
node
.
name
]
=
VarDecl
(
VarKind
.
LOCAL
,
ty
)
def
visit_Assert
(
self
,
node
:
ast
.
Assert
):
...
...
trans/transpiler/phases/typing/types.py
View file @
32a6dcfe
...
...
@@ -13,12 +13,36 @@ def get_default_parents():
return
[
obj
]
return
[]
class
RuntimeValue
:
pass
@
dataclass
class
MemberDef
:
type
:
"BaseType"
val
:
typing
.
Any
=
RuntimeValue
()
in_class_def
:
bool
=
True
@
dataclass
class
UnifyMode
:
search_hierarchy
:
bool
=
True
match_protocol
:
bool
=
True
UnifyMode
.
NORMAL
=
UnifyMode
()
UnifyMode
.
EXACT
=
UnifyMode
(
False
,
False
)
@
dataclass
(
eq
=
False
)
class
BaseType
(
ABC
):
members
:
Dict
[
str
,
"BaseType"
]
=
field
(
default_factory
=
dict
,
init
=
False
)
methods
:
Dict
[
str
,
"FunctionType"
]
=
field
(
default_factory
=
dict
,
init
=
False
)
#members: Dict[str, "MemberDef"] = field(default_factory=dict, init=False)
#methods: Dict[str, "FunctionType"] = field(default_factory=dict, init=False)
fields
:
Dict
[
str
,
"MemberDef"
]
=
field
(
default_factory
=
dict
,
init
=
False
)
parents
:
List
[
"BaseType"
]
=
field
(
default_factory
=
get_default_parents
,
init
=
False
)
typevars
:
List
[
"TypeVariable"
]
=
field
(
default_factory
=
list
,
init
=
False
)
#static_members: Dict[str, "MemberDef"] = field(default_factory=dict, init=False)
def
get_members
(
self
):
return
{
n
:
m
for
n
,
m
in
self
.
fields
.
items
()
if
type
(
m
.
val
)
is
RuntimeValue
}
def
get_parents
(
self
)
->
List
[
"BaseType"
]:
...
...
@@ -41,21 +65,29 @@ class BaseType(ABC):
queue
.
put
(
p
)
def
inherits_from
(
self
,
other
:
"BaseType"
)
->
bool
:
return
other
in
self
.
iter_hierarchy_recursive
()
from
transpiler.exceptions
import
CompileError
for
parent
in
self
.
iter_hierarchy_recursive
():
try
:
parent
.
unify
(
other
,
UnifyMode
.
EXACT
)
except
CompileError
:
pass
else
:
return
True
return
False
def
resolve
(
self
)
->
"BaseType"
:
return
self
@
abstractmethod
def
unify_internal
(
self
,
other
:
"BaseType"
):
def
unify_internal
(
self
,
other
:
"BaseType"
,
mode
:
UnifyMode
):
pass
def
unify
(
self
,
other
:
"BaseType"
):
def
unify
(
self
,
other
:
"BaseType"
,
mode
=
UnifyMode
.
NORMAL
):
a
,
b
=
self
.
resolve
(),
other
.
resolve
()
TB
=
f"unifying
{
highlight
(
a
)
}
and
{
highlight
(
b
)
}
"
if
isinstance
(
b
,
TypeVariable
):
a
,
b
=
b
,
a
a
.
unify_internal
(
b
)
a
.
unify_internal
(
b
,
mode
)
def
contains
(
self
,
other
:
"BaseType"
)
->
bool
:
needle
,
haystack
=
other
.
resolve
(),
self
.
resolve
()
...
...
@@ -86,7 +118,7 @@ class MagicType(BaseType, typing.Generic[T]):
super
().
__init__
()
self
.
val
=
val
def
unify_internal
(
self
,
other
:
"BaseType"
):
def
unify_internal
(
self
,
other
:
"BaseType"
,
mode
:
UnifyMode
):
if
type
(
self
)
!=
type
(
other
)
or
self
.
val
!=
other
.
val
:
from
transpiler.phases.typing.exceptions
import
TypeMismatchError
,
TypeMismatchKind
raise
TypeMismatchError
(
self
,
other
,
TypeMismatchKind
.
DIFFERENT_TYPE
)
...
...
@@ -128,7 +160,7 @@ class TypeVariable(BaseType):
return
self
return
self
.
resolved
.
resolve
()
def
unify_internal
(
self
,
other
:
BaseType
):
def
unify_internal
(
self
,
other
:
BaseType
,
mode
:
UnifyMode
):
if
self
is
not
other
:
if
other
.
contains
(
self
):
from
transpiler.phases.typing.exceptions
import
RecursiveTypeUnificationError
...
...
@@ -178,19 +210,19 @@ class TypeOperator(BaseType, ABC):
if
self
.
name
is
None
:
self
.
name
=
self
.
__class__
.
__name__
for
name
,
factory
in
self
.
gen_methods
.
items
():
self
.
methods
[
name
]
=
factory
(
self
)
self
.
fields
[
name
]
=
MemberDef
(
factory
(
self
)
)
for
gp
in
self
.
gen_parents
:
if
not
isinstance
(
gp
,
BaseType
):
gp
=
gp
(
self
.
args
)
self
.
parents
.
append
(
gp
)
self
.
methods
=
{
**
gp
.
methods
,
**
self
.
metho
ds
}
self
.
fields
=
{
**
gp
.
fields
,
**
self
.
fiel
ds
}
self
.
is_protocol
=
self
.
is_protocol
or
self
.
is_protocol_gen
self
.
_add_default_eq
()
def
_add_default_eq
(
self
):
if
"__eq__"
not
in
self
.
metho
ds
:
if
"__eq__"
not
in
self
.
fiel
ds
:
if
"DEFAULT_EQ"
in
globals
():
self
.
methods
[
"__eq__"
]
=
DEFAULT_EQ
self
.
fields
[
"__eq__"
]
=
MemberDef
(
DEFAULT_EQ
)
def
matches_protocol
(
self
,
protocol
:
"TypeOperator"
):
if
hash
(
protocol
)
in
self
.
match_cache
:
...
...
@@ -199,33 +231,35 @@ class TypeOperator(BaseType, ABC):
try
:
dupl
=
protocol
.
gen_sub
(
self
,
{
v
.
name
:
(
TypeVariable
(
v
.
name
)
if
isinstance
(
v
.
resolve
(),
TypeVariable
)
else
v
)
for
v
in
protocol
.
args
})
self
.
match_cache
.
add
(
hash
(
protocol
))
for
name
,
ty
in
dupl
.
metho
ds
.
items
():
for
name
,
ty
in
dupl
.
fiel
ds
.
items
():
if
name
==
"__eq__"
:
continue
if
name
not
in
self
.
metho
ds
:
if
name
not
in
self
.
fiel
ds
:
raise
ProtocolMismatchError
(
self
,
protocol
,
f"missing method
{
name
}
"
)
corresp
=
self
.
methods
[
name
]
corresp
.
remove_self
().
unify
(
ty
.
remove_self
())
corresp
=
self
.
fields
[
name
].
type
corresp
.
remove_self
().
unify
(
ty
.
type
.
remove_self
())
except
TypeMismatchError
as
e
:
if
hash
(
protocol
)
in
self
.
match_cache
:
self
.
match_cache
.
remove
(
hash
(
protocol
))
raise
ProtocolMismatchError
(
self
,
protocol
,
e
)
def
unify_internal
(
self
,
other
:
BaseType
):
def
unify_internal
(
self
,
other
:
BaseType
,
mode
:
UnifyMode
):
from
transpiler.phases.typing.exceptions
import
TypeMismatchError
,
TypeMismatchKind
# TODO(zdimension): this is really broken... but it would be nice
# if from_node := next(filter(None, (getattr(x, "from_node", None) for x in (other, self))), None):
# TB_NODE = from_node
if
not
isinstance
(
other
,
TypeOperator
):
raise
TypeMismatchError
(
self
,
other
,
TypeMismatchKind
.
DIFFERENT_TYPE
)
if
other
.
is_protocol
and
not
self
.
is_protocol
:
return
other
.
unify_internal
(
self
)
if
self
.
is_protocol
and
not
other
.
is_protocol
:
return
other
.
matches_protocol
(
self
)
# TODO: doesn't print the correct type in the error message
if
mode
.
match_protocol
:
if
other
.
is_protocol
and
not
self
.
is_protocol
:
return
other
.
unify_internal
(
self
,
mode
)
if
self
.
is_protocol
and
not
other
.
is_protocol
:
return
other
.
matches_protocol
(
self
)
# TODO: doesn't print the correct type in the error message
assert
self
.
is_protocol
==
other
.
is_protocol
if
type
(
self
)
!=
type
(
other
):
# and ((TY_NONE not in {self, other}) or isinstance(({self, other} - {TY_NONE}).pop(), UnionType)):
if
self
.
inherits_from
(
other
)
or
other
.
inherits_from
(
self
):
return
if
mode
.
search_hierarchy
:
if
self
.
inherits_from
(
other
)
or
other
.
inherits_from
(
self
):
return
# for parent in other.get_parents():
# try:
# self.unify(parent)
...
...
@@ -242,8 +276,8 @@ class TypeOperator(BaseType, ABC):
# return
raise
TypeMismatchError
(
self
,
other
,
TypeMismatchKind
.
DIFFERENT_TYPE
)
if
len
(
self
.
args
)
<
len
(
other
.
args
):
return
other
.
unify_internal
(
self
)
if
len
(
self
.
args
)
==
0
:
return
other
.
unify_internal
(
self
,
mode
)
if
True
or
len
(
self
.
args
)
==
0
:
# todo: why check len?
if
self
.
name
!=
other
.
name
:
raise
TypeMismatchError
(
self
,
other
,
TypeMismatchKind
.
DIFFERENT_TYPE
)
for
i
,
(
a
,
b
)
in
enumerate
(
zip_longest
(
self
.
args
,
other
.
args
)):
...
...
@@ -292,7 +326,7 @@ class TypeOperator(BaseType, ABC):
for
k
,
v
in
self
.
__dict__
.
items
():
setattr
(
res
,
k
,
v
)
res
.
args
=
[
arg
.
resolve
().
gen_sub
(
this
,
vardict
,
cache
)
for
arg
in
self
.
args
]
res
.
methods
=
{
k
:
v
.
gen_sub
(
this
,
vardict
,
cache
)
for
k
,
v
in
self
.
metho
ds
.
items
()}
res
.
fields
=
{
k
:
dataclasses
.
replace
(
v
,
type
=
v
.
type
.
gen_sub
(
this
,
vardict
,
cache
))
for
k
,
v
in
self
.
fiel
ds
.
items
()}
res
.
parents
=
[
p
.
gen_sub
(
this
,
vardict
,
cache
)
for
p
in
self
.
parents
]
#res.is_protocol = self.is_protocol
return
res
...
...
@@ -466,10 +500,10 @@ class Promise(TypeOperator, ABC):
if
value
==
PromiseKind
.
GENERATOR
:
f_iter
=
FunctionType
([],
self
)
f_iter
.
is_method
=
True
self
.
methods
[
"__iter__"
]
=
f_iter
self
.
fields
[
"__iter__"
]
=
MemberDef
(
f_iter
,
())
f_next
=
FunctionType
([],
self
.
return_type
)
f_next
.
is_method
=
True
self
.
methods
[
"__next__"
]
=
f_next
self
.
fields
[
"__next__"
]
=
MemberDef
(
f_next
,
())
self
.
args
[
1
].
val
=
value
def
__str__
(
self
):
...
...
@@ -506,7 +540,7 @@ class UserType(TypeOperator):
def
__init__
(
self
,
name
:
str
):
super
().
__init__
([],
name
=
name
,
is_reference
=
True
)
def
unify_internal
(
self
,
other
:
"BaseType"
):
def
unify_internal
(
self
,
other
:
"BaseType"
,
mode
:
UnifyMode
):
if
type
(
self
)
!=
type
(
other
):
from
transpiler.phases.typing.exceptions
import
TypeMismatchError
,
TypeMismatchKind
raise
TypeMismatchError
(
self
,
other
,
TypeMismatchKind
.
DIFFERENT_TYPE
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment