Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
C
cython
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Labels
Merge Requests
0
Merge Requests
0
Analytics
Analytics
Repository
Value Stream
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Commits
Open sidebar
nexedi
cython
Commits
995d565e
Commit
995d565e
authored
Dec 07, 2016
by
Robert Bradshaw
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Enable type inference of template function results.
parent
1b61bc34
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
33 additions
and
8 deletions
+33
-8
Cython/Compiler/ExprNodes.py
Cython/Compiler/ExprNodes.py
+14
-1
Cython/Compiler/PyrexTypes.py
Cython/Compiler/PyrexTypes.py
+2
-3
Cython/Compiler/Symtab.py
Cython/Compiler/Symtab.py
+5
-2
Cython/Compiler/Tests/TestSignatureMatching.py
Cython/Compiler/Tests/TestSignatureMatching.py
+1
-2
Cython/Compiler/TypeInference.py
Cython/Compiler/TypeInference.py
+1
-0
tests/run/cpp_template_functions.pyx
tests/run/cpp_template_functions.pyx
+10
-0
No files found.
Cython/Compiler/ExprNodes.py
View file @
995d565e
...
@@ -4940,6 +4940,7 @@ class CallNode(ExprNode):
...
@@ -4940,6 +4940,7 @@ class CallNode(ExprNode):
may_return_none
=
None
may_return_none
=
None
def
infer_type
(
self
,
env
):
def
infer_type
(
self
,
env
):
# TODO(robertwb): Reduce redundancy with analyse_types.
function
=
self
.
function
function
=
self
.
function
func_type
=
function
.
infer_type
(
env
)
func_type
=
function
.
infer_type
(
env
)
if
isinstance
(
function
,
NewExprNode
):
if
isinstance
(
function
,
NewExprNode
):
...
@@ -4953,6 +4954,17 @@ class CallNode(ExprNode):
...
@@ -4953,6 +4954,17 @@ class CallNode(ExprNode):
if
func_type
.
is_ptr
:
if
func_type
.
is_ptr
:
func_type
=
func_type
.
base_type
func_type
=
func_type
.
base_type
if
func_type
.
is_cfunction
:
if
func_type
.
is_cfunction
:
if
hasattr
(
self
.
function
,
'entry'
):
alternatives
=
self
.
function
.
entry
.
all_alternatives
()
arg_types
=
[
arg
.
infer_type
(
env
)
for
arg
in
self
.
args
]
func_entry
=
PyrexTypes
.
best_match
(
arg_types
,
alternatives
,
self
.
pos
,
env
)
if
func_entry
:
func_type
=
func_entry
.
type
if
func_type
.
is_ptr
:
func_type
=
func_type
.
base_type
return
func_type
.
return_type
return
func_type
.
return_type
return
func_type
.
return_type
elif
func_type
is
type_type
:
elif
func_type
is
type_type
:
if
function
.
is_name
and
function
.
entry
and
function
.
entry
.
type
:
if
function
.
is_name
and
function
.
entry
and
function
.
entry
.
type
:
...
@@ -5173,7 +5185,8 @@ class SimpleCallNode(CallNode):
...
@@ -5173,7 +5185,8 @@ class SimpleCallNode(CallNode):
else
:
else
:
alternatives
=
overloaded_entry
.
all_alternatives
()
alternatives
=
overloaded_entry
.
all_alternatives
()
entry
=
PyrexTypes
.
best_match
(
args
,
alternatives
,
self
.
pos
,
env
)
entry
=
PyrexTypes
.
best_match
(
[
arg
.
type
for
arg
in
args
],
alternatives
,
self
.
pos
,
env
)
if
not
entry
:
if
not
entry
:
self
.
type
=
PyrexTypes
.
error_type
self
.
type
=
PyrexTypes
.
error_type
...
...
Cython/Compiler/PyrexTypes.py
View file @
995d565e
...
@@ -4021,7 +4021,7 @@ def is_promotion(src_type, dst_type):
...
@@ -4021,7 +4021,7 @@ def is_promotion(src_type, dst_type):
return
src_type
.
is_float
and
src_type
.
rank
<=
dst_type
.
rank
return
src_type
.
is_float
and
src_type
.
rank
<=
dst_type
.
rank
return
False
return
False
def
best_match
(
args
,
functions
,
pos
=
None
,
env
=
None
):
def
best_match
(
arg
_type
s
,
functions
,
pos
=
None
,
env
=
None
):
"""
"""
Given a list args of arguments and a list of functions, choose one
Given a list args of arguments and a list of functions, choose one
to call which seems to be the "best" fit for this list of arguments.
to call which seems to be the "best" fit for this list of arguments.
...
@@ -4044,7 +4044,7 @@ def best_match(args, functions, pos=None, env=None):
...
@@ -4044,7 +4044,7 @@ def best_match(args, functions, pos=None, env=None):
is not None, we also generate an error.
is not None, we also generate an error.
"""
"""
# TODO: args should be a list of types, not a list of Nodes.
# TODO: args should be a list of types, not a list of Nodes.
actual_nargs
=
len
(
args
)
actual_nargs
=
len
(
arg
_type
s
)
candidates
=
[]
candidates
=
[]
errors
=
[]
errors
=
[]
...
@@ -4075,7 +4075,6 @@ def best_match(args, functions, pos=None, env=None):
...
@@ -4075,7 +4075,6 @@ def best_match(args, functions, pos=None, env=None):
errors
.
append
((
func
,
error_mesg
))
errors
.
append
((
func
,
error_mesg
))
continue
continue
if
func_type
.
templates
:
if
func_type
.
templates
:
arg_types
=
[
arg
.
type
for
arg
in
args
]
deductions
=
reduce
(
deductions
=
reduce
(
merge_template_deductions
,
merge_template_deductions
,
[
pattern
.
type
.
deduce_template_params
(
actual
)
for
(
pattern
,
actual
)
in
zip
(
func_type
.
args
,
arg_types
)],
[
pattern
.
type
.
deduce_template_params
(
actual
)
for
(
pattern
,
actual
)
in
zip
(
func_type
.
args
,
arg_types
)],
...
...
Cython/Compiler/Symtab.py
View file @
995d565e
...
@@ -846,13 +846,16 @@ class Scope(object):
...
@@ -846,13 +846,16 @@ class Scope(object):
obj_type
=
operands
[
0
].
type
obj_type
=
operands
[
0
].
type
method
=
obj_type
.
scope
.
lookup
(
"operator%s"
%
operator
)
method
=
obj_type
.
scope
.
lookup
(
"operator%s"
%
operator
)
if
method
is
not
None
:
if
method
is
not
None
:
res
=
PyrexTypes
.
best_match
(
operands
[
1
:],
method
.
all_alternatives
())
arg_types
=
[
arg
.
type
for
arg
in
operands
[
1
:]]
res
=
PyrexTypes
.
best_match
([
arg
.
type
for
arg
in
operands
[
1
:]],
method
.
all_alternatives
())
if
res
is
not
None
:
if
res
is
not
None
:
return
res
return
res
function
=
self
.
lookup
(
"operator%s"
%
operator
)
function
=
self
.
lookup
(
"operator%s"
%
operator
)
if
function
is
None
:
if
function
is
None
:
return
None
return
None
return
PyrexTypes
.
best_match
(
operands
,
function
.
all_alternatives
())
return
PyrexTypes
.
best_match
([
arg
.
type
for
arg
in
operands
],
function
.
all_alternatives
())
def
lookup_operator_for_types
(
self
,
pos
,
operator
,
types
):
def
lookup_operator_for_types
(
self
,
pos
,
operator
,
types
):
from
.Nodes
import
Node
from
.Nodes
import
Node
...
...
Cython/Compiler/Tests/TestSignatureMatching.py
View file @
995d565e
...
@@ -16,8 +16,7 @@ class SignatureMatcherTest(unittest.TestCase):
...
@@ -16,8 +16,7 @@ class SignatureMatcherTest(unittest.TestCase):
Test the signature matching algorithm for overloaded signatures.
Test the signature matching algorithm for overloaded signatures.
"""
"""
def
assertMatches
(
self
,
expected_type
,
arg_types
,
functions
):
def
assertMatches
(
self
,
expected_type
,
arg_types
,
functions
):
args
=
[
NameNode
(
None
,
type
=
arg_type
)
for
arg_type
in
arg_types
]
match
=
pt
.
best_match
(
arg_types
,
functions
)
match
=
pt
.
best_match
(
args
,
functions
)
if
expected_type
is
not
None
:
if
expected_type
is
not
None
:
self
.
assertNotEqual
(
None
,
match
)
self
.
assertNotEqual
(
None
,
match
)
self
.
assertEqual
(
expected_type
,
match
.
type
)
self
.
assertEqual
(
expected_type
,
match
.
type
)
...
...
Cython/Compiler/TypeInference.py
View file @
995d565e
...
@@ -475,6 +475,7 @@ class SimpleAssignmentTypeInferer(object):
...
@@ -475,6 +475,7 @@ class SimpleAssignmentTypeInferer(object):
for
assmt
in
entry
.
cf_assignments
]
for
assmt
in
entry
.
cf_assignments
]
new_type
=
spanning_type
(
types
,
entry
.
might_overflow
,
entry
.
pos
,
scope
)
new_type
=
spanning_type
(
types
,
entry
.
might_overflow
,
entry
.
pos
,
scope
)
if
new_type
!=
entry
.
type
:
if
new_type
!=
entry
.
type
:
print
"FOUND"
,
entry
,
entry
.
type
,
new_type
,
type
(
new_type
)
self
.
set_entry_type
(
entry
,
new_type
)
self
.
set_entry_type
(
entry
,
new_type
)
dirty
=
True
dirty
=
True
return
dirty
return
dirty
...
...
tests/run/cpp_template_functions.pyx
View file @
995d565e
# tag: cpp
# tag: cpp
cimport
cython
from
libcpp.pair
cimport
pair
from
libcpp.pair
cimport
pair
cdef
extern
from
"cpp_template_functions_helper.h"
:
cdef
extern
from
"cpp_template_functions_helper.h"
:
...
@@ -89,3 +90,12 @@ def test_deduce_through_pointers(int k):
...
@@ -89,3 +90,12 @@ def test_deduce_through_pointers(int k):
"""
"""
cdef
double
x
=
k
cdef
double
x
=
k
return
pointer_param
(
&
k
)[
0
],
pointer_param
(
&
x
)[
0
]
return
pointer_param
(
&
k
)[
0
],
pointer_param
(
&
x
)[
0
]
def
test_inference
(
int
k
):
"""
>>> test_inference(27)
27
"""
res
=
one_param
(
&
k
)
assert
cython
.
typeof
(
res
)
==
'int *'
,
cython
.
typeof
(
res
)
return
res
[
0
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment