Commit edf04816 authored by Mark Florisson's avatar Mark Florisson

Support fused def functions + lambda + better runtime dispatch

parent d80c2c4c
......@@ -63,7 +63,7 @@ class AutoTestDictTransform(ScopeTrackingTransform):
return node
def visit_FuncDefNode(self, node):
if not node.doc:
if not node.doc or node.fused_py_func:
return node
if not self.cdef_docstrings:
if isinstance(node, CFuncDefNode) and not node.py_func:
......
......@@ -8969,12 +8969,17 @@ static PyObject *%(binding_cfunc)s_call(PyObject *func, PyObject *args, PyObject
static PyObject *%(binding_cfunc)s_get__name__(%(binding_cfunc)s_object *func, void *closure);
static int %(binding_cfunc)s_set__name__(%(binding_cfunc)s_object *func, PyObject *value, void *closure);
static PyObject *%(binding_cfunc)s_get__doc__(%(binding_cfunc)s_object *func, void *closure);
static PyGetSetDef %(binding_cfunc)s_getsets[] = {
{(char *)"__name__",
(getter) %(binding_cfunc)s_get__name__,
(setter) %(binding_cfunc)s_set__name__,
NULL},
{(char *)"__doc__",
(getter) %(binding_cfunc)s_get__doc__,
NULL,
NULL},
{NULL},
};
......@@ -9139,6 +9144,12 @@ static int
return PyDict_SetItemString(func->__dict__, "__name__", value);
}
static PyObject *
%(binding_cfunc)s_get__doc__(%(binding_cfunc)s_object *func, void *closure)
{
return PyUnicode_FromString(func->func.m_ml->ml_doc);
}
static PyObject *
%(binding_cfunc)s_descr_get(PyObject *op, PyObject *obj, PyObject *type)
{
......@@ -9290,11 +9301,13 @@ static PyObject *
binaryfunc meth = (binaryfunc) binding_func->func.m_ml->ml_meth;
func = new_func = meth(binding_func->__signatures__, args);
*/
PyObject *tup = PyTuple_Pack(2, binding_func->__signatures__, args);
PyObject *tup = PyTuple_Pack(3, binding_func->__signatures__, args,
kw == NULL ? Py_None : kw);
if (!tup)
goto __pyx_err;
func = new_func = PyCFunction_Call(func, tup, NULL);
Py_DECREF(tup);
if (!new_func)
goto __pyx_err;
......
......@@ -936,6 +936,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
# Generate struct declaration for an extension type's vtable.
type = entry.type
scope = type.scope
self.specialize_fused_types(scope)
if type.vtabstruct_cname:
code.putln("")
code.putln(
......@@ -1942,7 +1945,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("/*--- Function import code ---*/")
for module in imported_modules:
self.specialize_fused_types(module, env)
self.specialize_fused_types(module)
self.generate_c_function_import_code_for_module(module, env, code)
code.putln("/*--- Execution code ---*/")
......@@ -2160,7 +2163,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
if entry.defined_in_pxd:
self.generate_type_import_code(env, entry.type, entry.pos, code)
def specialize_fused_types(self, pxd_env, impl_env):
def specialize_fused_types(self, pxd_env):
"""
If fused c(p)def functions are defined in an imported pxd, but not
used in this implementation file, we still have fused entries and
......
......@@ -139,7 +139,6 @@ class Node(object):
cf_state = None
def __init__(self, pos, **kw):
self.pos = pos
self.__dict__.update(kw)
......@@ -2039,7 +2038,8 @@ class FusedCFuncDefNode(StatListNode):
node FuncDefNode the original function
nodes [FuncDefNode] list of copies of node with different specific types
py_func DefNode the original python function (in case of a cpdef)
py_func DefNode the fused python function subscriptable from
Python space
"""
def __init__(self, node, env):
......@@ -2048,42 +2048,78 @@ class FusedCFuncDefNode(StatListNode):
self.nodes = []
self.node = node
self.copy_cdefs(env)
is_def = isinstance(self.node, DefNode)
if is_def:
self.copy_def(env)
else:
self.copy_cdef(env)
# Perform some sanity checks. If anything fails, it's a bug
for n in self.nodes:
assert not n.type.is_fused
assert not n.entry.type.is_fused
assert not n.local_scope.return_type.is_fused
if node.return_type.is_fused:
assert not n.return_type.is_fused
if n.cfunc_declarator.optional_arg_count:
if not is_def and n.cfunc_declarator.optional_arg_count:
assert n.type.op_arg_struct
assert n.type.entry
assert node.type.is_fused
node.entry.fused_cfunction = self
if self.py_func:
self.py_func.entry.fused_cfunction = self
for node in self.nodes:
node.py_func.fused_py_func = self.py_func
node.entry.as_variable = self.py_func.entry
if is_def:
node.fused_py_func = self.py_func
else:
node.py_func.fused_py_func = self.py_func
node.entry.as_variable = self.py_func.entry
# Copy the nodes as AnalyseDeclarationsTransform will append
# Copy the nodes as AnalyseDeclarationsTransform will prepend
# self.py_func to self.stats, as we only want specialized
# CFuncDefNodes in self.nodes
self.stats = self.nodes[:]
def copy_cdefs(self, env):
def copy_def(self, env):
"""
Gives a list of fused types and the parent environment, make copies
of the original cdef function.
Create a copy of the original def or lambda function for specialized
versions.
"""
from Cython.Compiler import ParseTreeTransforms
fused_types = [arg.type for arg in self.node.args if arg.type.is_fused]
permutations = PyrexTypes.get_all_specific_permutations(fused_types)
if self.node.entry in env.pyfunc_entries:
env.pyfunc_entries.remove(self.node.entry)
for cname, fused_to_specific in permutations:
copied_node = copy.deepcopy(self.node)
for arg in copied_node.args:
if arg.type.is_fused:
arg.type = arg.type.specialize(fused_to_specific)
copied_node.return_type = self.node.return_type.specialize(
fused_to_specific)
copied_node.analyse_declarations(env)
self.create_new_local_scope(copied_node, env, fused_to_specific)
self.specialize_copied_def(copied_node, cname, self.node.entry,
fused_to_specific, fused_types)
PyrexTypes.specialize_entry(copied_node.entry, cname)
copied_node.entry.used = True
env.entries[copied_node.entry.name] = copied_node.entry
if not self.replace_fused_typechecks(copied_node):
break
self.py_func = self.make_fused_cpdef(self.node, env, is_def=True)
def copy_cdef(self, env):
"""
Create a copy of the original c(p)def function for all specialized
versions.
"""
permutations = self.node.type.get_all_specific_permutations()
# print 'Node %s has %d specializations:' % (self.node.entry.name,
# len(permutations))
......@@ -2120,13 +2156,7 @@ class FusedCFuncDefNode(StatListNode):
type, env, fused_cname=cname)
copied_node.return_type = type.return_type
copied_node.create_local_scope(env)
copied_node.local_scope.fused_to_specific = fused_to_specific
# This is copied from the original function, set it to false to
# stop recursion
copied_node.has_fused_arguments = False
self.nodes.append(copied_node)
self.create_new_local_scope(copied_node, env, fused_to_specific)
# Make the argument types in the CFuncDeclarator specific
for arg in copied_node.cfunc_declarator.args:
......@@ -2135,45 +2165,83 @@ class FusedCFuncDefNode(StatListNode):
type.specialize_entry(entry, cname)
env.cfunc_entries.append(entry)
# If a cpdef, declare all specialized cpdefs
# If a cpdef, declare all specialized cpdefs (this
# also calls analyse_declarations)
copied_node.declare_cpdef_wrapper(env)
if copied_node.py_func:
env.pyfunc_entries.remove(copied_node.py_func.entry)
# copied_node.py_func.self_in_stararg = True
type_strings = [
fused_to_specific[fused_type].typeof_name()
for fused_type in fused_types]
if len(type_strings) == 1:
sigstring = type_strings[0]
else:
sigstring = ', '.join(type_strings)
copied_node.py_func.specialized_signature_string = sigstring
self.specialize_copied_def(
copied_node.py_func, cname, self.node.entry.as_variable,
fused_to_specific, fused_types)
e = copied_node.py_func.entry
e.pymethdef_cname = PyrexTypes.get_fused_cname(
cname, e.pymethdef_cname)
num_errors = Errors.num_errors
transform = ParseTreeTransforms.ReplaceFusedTypeChecks(
copied_node.local_scope)
transform(copied_node)
if Errors.num_errors > num_errors:
if not self.replace_fused_typechecks(copied_node):
break
if orig_py_func:
self.py_func = self.make_fused_cpdef(orig_py_func, env)
self.py_func = self.make_fused_cpdef(orig_py_func, env,
is_def=False)
else:
self.py_func = orig_py_func
def create_new_local_scope(self, node, env, f2s):
"""
Create a new local scope for the copied node and append it to
self.nodes. A new local scope is needed because the arguments with the
fused types are aready in the local scope, and we need the specialized
entries created after analyse_declarations on each specialized version
of the (CFunc)DefNode.
f2s is a dict mapping each fused type to its specialized version
"""
node.create_local_scope(env)
node.local_scope.fused_to_specific = f2s
def make_fused_cpdef(self, orig_py_func, env):
# This is copied from the original function, set it to false to
# stop recursion
node.has_fused_arguments = False
self.nodes.append(node)
def specialize_copied_def(self, node, cname, py_entry, f2s, fused_types):
"""Specialize the copy of a DefNode given the copied node,
the specialization cname and the original DefNode entry"""
type_strings = [f2s[fused_type].typeof_name()
for fused_type in fused_types]
node.specialized_signature_string = ', '.join(type_strings)
node.entry.pymethdef_cname = PyrexTypes.get_fused_cname(
cname, node.entry.pymethdef_cname)
node.entry.doc = py_entry.doc
node.entry.doc_cname = py_entry.doc_cname
def replace_fused_typechecks(self, copied_node):
"""
Branch-prune fused type checks like
if fused_t is int:
...
Returns whether an error was issued and whether we should stop in
in order to prevent a flood of errors.
"""
from Cython.Compiler import ParseTreeTransforms
num_errors = Errors.num_errors
transform = ParseTreeTransforms.ReplaceFusedTypeChecks(
copied_node.local_scope)
transform(copied_node)
if Errors.num_errors > num_errors:
return False
return True
def make_fused_cpdef(self, orig_py_func, env, is_def):
"""
This creates the function that is indexable from Python and does
runtime dispatch based on the argument types.
runtime dispatch based on the argument types. The function gets the
arg tuple and kwargs dict (or None) as arugments from the Binding
Fused Function's tp_call.
"""
from Cython.Compiler import TreeFragment
from Cython.Compiler import ParseTreeTransforms
......@@ -2184,17 +2252,28 @@ class FusedCFuncDefNode(StatListNode):
# list of statements that do the instance checks
body_stmts = []
for i, arg_type in enumerate(self.node.type.args):
arg_type = arg_type.type
args = self.node.args
for i, arg in enumerate(args):
arg_type = arg.type
if arg_type.is_fused and arg_type not in seen_fused_types:
seen_fused_types.add(arg_type)
specialized_types = PyrexTypes.get_specific_types(arg_type)
specialized_types = PyrexTypes.get_specialized_types(arg_type)
# Prefer long over int, etc
specialized_types.sort()
# specialized_types.sort()
seen_py_type_names = cython.set()
first_check = True
body_stmts.append(u"""
if nargs >= %(nextidx)d or '%(argname)s' in kwargs:
if nargs >= %(nextidx)d:
arg = args[%(idx)d]
else:
arg = kwargs['%(argname)s']
""" % {'idx': i, 'nextidx': i + 1, 'argname': arg.name})
all_numeric = True
for specialized_type in specialized_types:
py_type_name = specialized_type.py_type_name()
......@@ -2203,6 +2282,8 @@ class FusedCFuncDefNode(StatListNode):
seen_py_type_names.add(py_type_name)
all_numeric = all_numeric and specialized_type.is_numeric
if first_check:
if_ = 'if'
first_check = False
......@@ -2216,24 +2297,43 @@ class FusedCFuncDefNode(StatListNode):
if py_type_name in ('long', 'unicode', 'bytes'):
instance_check_py_type_name += '_'
tup = (if_, i, instance_check_py_type_name,
tup = (if_, instance_check_py_type_name,
len(seen_fused_types) - 1,
specialized_type.typeof_name())
body_stmts.append(
" %s isinstance(args[%d], %s): "
"dest_sig[%d] = '%s'" % tup)
" %s isinstance(arg, %s): "
"dest_sig[%d] = '%s'" % tup)
if arg.default and all_numeric:
arg.default.analyse_types(env)
ts = specialized_types
if arg.default.type.is_complex:
typelist = [t for t in ts if t.is_complex]
elif arg.default.type.is_float:
typelist = [t for t in ts if t.is_float]
else:
typelist = [t for t in ts if t.is_int]
if typelist:
body_stmts.append(u"""\
else:
dest_sig[%d] = '%s'
""" % (i, typelist[0].typeof_name()))
fmt_dict = {
'body': '\n'.join(body_stmts),
'nargs': len(self.node.type.args),
'nargs': len(args),
'name': orig_py_func.entry.name,
}
fragment = TreeFragment.TreeFragment(u"""
def __pyx_fused_cpdef(signatures, args):
if len(args) < %(nargs)d:
raise TypeError("Invalid number of arguments, expected %(nargs)d, "
"got %%d" %% len(args))
fragment_code = u"""
def __pyx_fused_cpdef(signatures, args, kwargs):
#if len(args) < %(nargs)d:
# raise TypeError("Invalid number of arguments, expected %(nargs)d, "
# "got %%d" %% len(args))
cdef int nargs
nargs = len(args)
import sys
if sys.version_info >= (3, 0):
......@@ -2245,7 +2345,10 @@ def __pyx_fused_cpdef(signatures, args):
unicode_ = unicode
bytes_ = str
dest_sig = [None] * len(args)
dest_sig = [None] * %(nargs)d
if kwargs is None:
kwargs = {}
# instance check body
%(body)s
......@@ -2266,8 +2369,11 @@ def __pyx_fused_cpdef(signatures, args):
raise TypeError("Function call with ambiguous argument types")
else:
return signatures[candidates[0]]
""" % fmt_dict, level='module')
""" % fmt_dict
# print fragment_code
fragment = TreeFragment.TreeFragment(fragment_code, level='module')
# analyse the declarations of our fragment ...
py_func, = fragment.substitute(pos=self.node.pos).stats
......@@ -2283,17 +2389,31 @@ def __pyx_fused_cpdef(signatures, args):
py_func.name = e.name = orig_e.name
e.cname, e.func_cname = orig_e.cname, orig_e.func_cname
e.pymethdef_cname = orig_e.pymethdef_cname
e.doc, e.doc_cname = orig_e.doc, orig_e.doc_cname
# e.signature = TypeSlots.binaryfunc
py_func.doc = orig_py_func.doc
# ... and the symbol table
del env.entries['__pyx_fused_cpdef']
env.entries[e.name].as_variable = e
if is_def:
env.entries[e.name] = e
else:
env.entries[e.name].as_variable = e
env.pyfunc_entries.append(e)
py_func.specialized_cpdefs = [n.py_func for n in self.nodes]
if is_def:
py_func.specialized_cpdefs = self.nodes[:]
else:
py_func.specialized_cpdefs = [n.py_func for n in self.nodes]
return py_func
def generate_function_definitions(self, env, code):
# Ensure the indexable fused function is generated first, so we can
# use its docstring
# self.stats.insert(0, self.stats.pop())
for stat in self.stats:
# print stat.entry, stat.entry.used
if stat.entry.used:
......@@ -2482,6 +2602,7 @@ class DefNode(FuncDefNode):
self.declare_lambda_function(env)
else:
self.declare_pyfunction(env)
self.analyse_signature(env)
self.return_type = self.entry.signature.return_type()
self.create_local_scope(env)
......@@ -2498,6 +2619,10 @@ class DefNode(FuncDefNode):
arg.declarator.analyse(base_type, env)
arg.name = name_declarator.name
arg.type = type
if type.is_fused:
self.has_fused_arguments = True
self.align_argument_type(env, arg)
if name_declarator and name_declarator.cname:
error(self.pos,
......@@ -2707,6 +2832,10 @@ class DefNode(FuncDefNode):
def synthesize_assignment_node(self, env):
import ExprNodes
if self.fused_py_func:
return
genv = env
while genv.is_py_class_scope or genv.is_c_class_scope:
genv = genv.outer_scope
......@@ -2766,22 +2895,29 @@ class DefNode(FuncDefNode):
# If we are the specialized version of the cpdef, we still
# want the prototype for the "fused cpdef", in case we're
# checking to see if our method was overridden in Python
self.fused_py_func.generate_function_header(code, with_pymethdef, proto_only=True)
self.fused_py_func.generate_function_header(
code, with_pymethdef, proto_only=True)
return
if (Options.docstrings and self.entry.doc and
not self.fused_py_func and
not self.entry.scope.is_property_scope and
(not self.entry.is_special or self.entry.wrapperbase_cname)):
# h_code = code.globalstate['h_code']
docstr = self.entry.doc
if docstr.is_unicode:
docstr = docstr.utf8encode()
code.putln(
'static char %s[] = "%s";' % (
self.entry.doc_cname,
split_string_literal(escape_byte_string(docstr))))
if self.entry.is_special:
code.putln(
"struct wrapperbase %s;" % self.entry.wrapperbase_cname)
if with_pymethdef or self.fused_py_func:
code.put(
"static PyMethodDef %s = " %
......@@ -2909,10 +3045,12 @@ class DefNode(FuncDefNode):
else:
func = arg.type.from_py_function
if func:
code.putln("%s = %s(%s); %s" % (
rhs = "%s(%s)" % (func, item)
if arg.type.is_enum:
rhs = arg.type.cast_code(rhs)
code.putln("%s = %s; %s" % (
arg.entry.cname,
func,
item,
rhs,
code.error_goto_if(arg.type.error_condition(arg.entry.cname), arg.pos)))
else:
error(arg.pos, "Cannot convert Python object argument to type '%s'" % arg.type)
......
......@@ -1348,10 +1348,13 @@ if VALUE is not None:
count += 1
""")
fused_function = None
def __call__(self, root):
self.env_stack = [root.scope]
# needed to determine if a cdef var is declared after it's used.
self.seen_vars_stack = []
self.fused_error_funcs = cython.set()
return super(AnalyseDeclarationsTransform, self).__call__(root)
def visit_NameNode(self, node):
......@@ -1399,9 +1402,12 @@ if VALUE is not None:
analyse its children (which are in turn normal functions). If we're a
normal function, just analyse the body of the function.
"""
env = self.env_stack[-1]
self.seen_vars_stack.append(cython.set())
lenv = node.local_scope
node.declare_arguments(lenv)
for var, type_node in node.directive_locals.items():
if not lenv.lookup_here(var): # don't redeclare args
type = type_node.analyse_as_type(lenv)
......@@ -1411,10 +1417,27 @@ if VALUE is not None:
error(type_node.pos, "Not a type")
if node.has_fused_arguments:
node = Nodes.FusedCFuncDefNode(node, self.env_stack[-1])
if self.fused_function:
if self.fused_function not in self.fused_error_funcs:
error(node.pos, "Cannot nest fused functions")
self.fused_error_funcs.add(self.fused_function)
# env.declare_var(node.name, PyrexTypes.py_object_type, node.pos)
node = Nodes.SingleAssignmentNode(
node.pos,
lhs=ExprNodes.NameNode(node.pos, name=node.name),
rhs=ExprNodes.NoneNode(node.pos))
node.analyse_declarations(env)
return node
node = Nodes.FusedCFuncDefNode(node, env)
self.fused_function = node
self.visitchildren(node)
self.fused_function = None
if node.py_func:
node.stats.append(node.py_func)
node.stats.insert(0, node.py_func)
else:
node.body.analyse_declarations(lenv)
......@@ -2082,6 +2105,10 @@ class CreateClosureClasses(CythonTransform):
target_module_scope.check_c_class(func_scope.scope_class)
def visit_LambdaNode(self, node):
if not isinstance(node.def_node, Nodes.DefNode):
# fused function, an error has been previously issued
return node
was_in_lambda = self.in_lambda
self.in_lambda = True
self.create_class_from_scope(node.def_node, self.module_scope, node)
......@@ -2408,7 +2435,7 @@ class ReplaceFusedTypeChecks(VisitorTransform):
error(node.operand2.pos,
"Can only use 'in' or 'not in' on a fused type")
else:
types = PyrexTypes.get_specific_types(type2)
types = PyrexTypes.get_specialized_types(type2)
for specific_type in types:
if type1.same_as(specific_type):
......
......@@ -235,9 +235,16 @@ def public_decl(base_code, dll_linkage):
return base_code
def create_typedef_type(name, base_type, cname, is_external=0):
if base_type.is_complex:
is_fused = base_type.is_fused
if base_type.is_complex or is_fused:
if is_external:
raise ValueError("Complex external typedefs not supported")
if is_fused:
msg = "Fused"
else:
msg = "Complex"
raise ValueError("%s external typedefs not supported" % msg)
return base_type
else:
return CTypedefType(name, base_type, cname, is_external)
......@@ -2123,6 +2130,7 @@ class CFuncType(CType):
result = []
permutations = self.get_all_specific_permutations()
for cname, fused_to_specific in permutations:
new_func_type = self.entry.type.specialize(fused_to_specific)
......@@ -2150,20 +2158,25 @@ class CFuncType(CType):
def specialize_entry(self, entry, cname):
assert not self.is_fused
specialize_entry(entry, cname)
entry.name = get_fused_cname(cname, entry.name)
if entry.is_cmethod:
entry.cname = entry.name
if entry.is_inherited:
entry.cname = StringEncoding.EncodedString(
"%s.%s" % (Naming.obj_base_cname, entry.cname))
else:
entry.cname = get_fused_cname(cname, entry.cname)
def specialize_entry(entry, cname):
"""
Specialize an entry of a copied fused function or method
"""
entry.name = get_fused_cname(cname, entry.name)
if entry.func_cname:
entry.func_cname = get_fused_cname(cname, entry.func_cname)
if entry.is_cmethod:
entry.cname = entry.name
if entry.is_inherited:
entry.cname = StringEncoding.EncodedString(
"%s.%s" % (Naming.obj_base_cname, entry.cname))
else:
entry.cname = get_fused_cname(cname, entry.cname)
if entry.func_cname:
entry.func_cname = get_fused_cname(cname, entry.func_cname)
def get_fused_cname(fused_cname, orig_cname):
"""
......@@ -2177,7 +2190,7 @@ def get_all_specific_permutations(fused_types, id="", f2s=()):
fused_type = fused_types[0]
result = []
for newid, specific_type in enumerate(fused_type.types):
for newid, specific_type in enumerate(sorted(fused_type.types)):
# f2s = dict(f2s, **{ fused_type: specific_type })
f2s = dict(f2s)
f2s.update({ fused_type: specific_type })
......@@ -2195,17 +2208,21 @@ def get_all_specific_permutations(fused_types, id="", f2s=()):
return result
def get_specific_types(type):
def get_specialized_types(type):
"""
Return a list of specialized types sorted in reverse order in accordance
with their preference in runtime fused-type dispatch
"""
assert type.is_fused
if isinstance(type, FusedType):
return type.types
result = []
for cname, f2s in get_all_specific_permutations(type.get_fused_types()):
result.append(type.specialize(f2s))
result = type.types
else:
result = []
for cname, f2s in get_all_specific_permutations(type.get_fused_types()):
result.append(type.specialize(f2s))
return result
return sorted(result)
class CFuncTypeArg(BaseType):
......
......@@ -1509,6 +1509,7 @@ class ClosureScope(LocalScope):
def declare_pyfunction(self, name, pos, allow_redefine=False):
return LocalScope.declare_pyfunction(self, name, pos, allow_redefine, visibility='private')
class StructOrUnionScope(Scope):
# Namespace of a C struct or union.
......
......@@ -272,7 +272,7 @@ try:
except NameError: # Py3
py_long = typedef(int, "long")
py_float = typedef(float, "float")
py_complex = typedef(complex, "complex")
py_complex = typedef(complex, "double complex")
try:
......
# mode: error
cimport cython
def closure(cython.integral i):
def inner(cython.floating f):
pass
def closure2(cython.integral i):
return lambda cython.integral i: i
def closure3(cython.integral i):
def inner():
return lambda cython.floating f: f
_ERRORS = u"""
e_fused_closure.pyx:6:4: Cannot nest fused functions
e_fused_closure.pyx:10:11: Cannot nest fused functions
e_fused_closure.pyx:14:15: Cannot nest fused functions
"""
# mode: run
"""
Test Python def functions without extern types
"""
cy = __import__("cython")
cimport cython
cdef class Base(object):
def __repr__(self):
return type(self).__name__
cdef class ExtClassA(Base):
pass
cdef class ExtClassB(Base):
pass
cdef enum MyEnum:
entry0
entry1
entry2
entry3
entry4
ctypedef fused fused_t:
str
int
long
complex
ExtClassA
ExtClassB
MyEnum
f = 5.6
i = 9
def opt_func(fused_t obj, cython.floating myf = 1.2, cython.integral myi = 7):
"""
Test runtime dispatch, indexing of various kinds and optional arguments
>>> opt_func("spam", f, i)
str object double long
spam 5.60 9 5.60 9
>>> opt_func[str, float, int]("spam", f, i)
str object float int
spam 5.60 9 5.60 9
>>> opt_func["str, double, long"]("spam", f, i)
str object double long
spam 5.60 9 5.60 9
>>> opt_func[str, float, cy.int]("spam", f, i)
str object float int
spam 5.60 9 5.60 9
>>> opt_func(ExtClassA(), f, i)
ExtClassA double long
ExtClassA 5.60 9 5.60 9
>>> opt_func[ExtClassA, float, int](ExtClassA(), f, i)
ExtClassA float int
ExtClassA 5.60 9 5.60 9
>>> opt_func["ExtClassA, double, long"](ExtClassA(), f, i)
ExtClassA double long
ExtClassA 5.60 9 5.60 9
>>> opt_func(ExtClassB(), f, i)
ExtClassB double long
ExtClassB 5.60 9 5.60 9
>>> opt_func[ExtClassB, cy.double, cy.long](ExtClassB(), f, i)
ExtClassB double long
ExtClassB 5.60 9 5.60 9
>>> opt_func(10, f)
long double long
10 5.60 7 5.60 9
>>> opt_func[int, float, int](10, f)
int float int
10 5.60 7 5.60 9
>>> opt_func(10 + 2j, myf = 2.6)
double complex double long
(10+2j) 2.60 7 5.60 9
>>> opt_func[cy.py_complex, float, int](10 + 2j, myf = 2.6)
double complex float int
(10+2j) 2.60 7 5.60 9
>>> opt_func[cy.doublecomplex, cy.float, cy.int](10 + 2j, myf = 2.6)
double complex float int
(10+2j) 2.60 7 5.60 9
>>> opt_func(object(), f)
Traceback (most recent call last):
...
TypeError: Function call with ambiguous argument types
>>> opt_func[ExtClassA, cy.float, long](object(), f)
Traceback (most recent call last):
...
TypeError: Argument 'obj' has incorrect type (expected fused_def.ExtClassA, got object)
"""
print cython.typeof(obj), cython.typeof(myf), cython.typeof(myi)
print obj, "%.2f" % myf, myi, "%.2f" % f, i
def test_opt_func():
"""
>>> test_opt_func()
str object double long
ham 5.60 4 5.60 9
"""
cdef char *s = "ham"
opt_func(s, f, entry4)
def args_kwargs(fused_t obj, cython.floating myf = 1.2, *args, **kwargs):
"""
>>> args_kwargs("foo")
str object double
foo 1.20 5.60 () {}
>>> args_kwargs("eggs", f, 1, 2, [], d={})
str object double
eggs 5.60 5.60 (1, 2, []) {'d': {}}
>>> args_kwargs[str, float]("eggs", f, 1, 2, [], d={})
str object float
eggs 5.60 5.60 (1, 2, []) {'d': {}}
"""
print cython.typeof(obj), cython.typeof(myf)
print obj, "%.2f" % myf, "%.2f" % f, args, kwargs
......@@ -35,6 +35,21 @@ less_simple_t = cython.fused_type(int, float, string_t)
struct_t = cython.fused_type(mystruct_t, myunion_t, MyExt)
builtin_t = cython.fused_type(str, unicode, bytes)
ctypedef fused fusedbunch:
int
long
complex
string_t
ctypedef fused fused1:
short
string_t
cdef fused fused2:
float
double
string_t
cdef struct_t add_simple(struct_t obj, simple_t simple)
cdef less_simple_t add_to_simple(struct_t obj, less_simple_t simple)
cdef public_optional_args(struct_t obj, simple_t simple = *)
......@@ -79,7 +94,17 @@ cdef class TestFusedExtMethods(object):
cpdef cpdef_method(self, cython.integral x, cython.floating y):
return cython.typeof(x), cython.typeof(y)
def def_method(self, fused1 x, fused2 y):
if (fused1 is string_t and fused2 is not string_t or
not fused1 is string_t and fused2 is string_t):
return x, y
else:
return <fused1> x + y
cpdef public_cpdef(cython.integral x, cython.floating y, object_t z):
if cython.integral is int:
pass
return cython.typeof(x), cython.typeof(y), cython.typeof(z)
......@@ -131,7 +156,9 @@ cdef double b = 7.0
cdef double (*func)(TestFusedExtMethods, long, double)
func = obj.method
assert func(obj, a, b) == 15.0
result = func(obj, a, b)
assert result == 15.0, result
func = <double (*)(TestFusedExtMethods, long, double)> obj.method
assert func(obj, x, y) == 11.0
......@@ -200,5 +227,11 @@ ae(myobj.cpdef_method[cy.int, cy.float](10, 10.0), (10, 10.0))
"""
d = {'obj': obj, 'myobj': myobj, 'ae': ae}
# FIXME: uncomment after subclassing CyFunction
#exec s in d
# Test def methods
# ae(obj.def_method(12, 14.9), 26)
# ae(obj.def_method(13, "spam"), (13, "spam"))
# ae(obj.def_method[cy.short, cy.float](13, 16.3), 29)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment