Commit 47ce63c9 authored by Mark Florisson's avatar Mark Florisson

Support for fused types in cdef functions

parent b386862c
...@@ -2949,6 +2949,7 @@ class SimpleCallNode(CallNode): ...@@ -2949,6 +2949,7 @@ class SimpleCallNode(CallNode):
else: else:
for arg in self.args: for arg in self.args:
arg.analyse_types(env) arg.analyse_types(env)
if self.self and func_type.args: if self.self and func_type.args:
# Coerce 'self' to the type expected by the method. # Coerce 'self' to the type expected by the method.
self_arg = func_type.args[0] self_arg = func_type.args[0]
...@@ -2965,10 +2966,13 @@ class SimpleCallNode(CallNode): ...@@ -2965,10 +2966,13 @@ class SimpleCallNode(CallNode):
def function_type(self): def function_type(self):
# Return the type of the function being called, coercing a function # Return the type of the function being called, coercing a function
# pointer to a function if necessary. # pointer to a function if necessary. If the function has fused
# arguments, return the specific type.
func_type = self.function.type func_type = self.function.type
if func_type.is_ptr: if func_type.is_ptr:
func_type = func_type.base_type func_type = func_type.base_type
return func_type return func_type
def is_simple(self): def is_simple(self):
...@@ -2982,6 +2986,7 @@ class SimpleCallNode(CallNode): ...@@ -2982,6 +2986,7 @@ class SimpleCallNode(CallNode):
if self.function.type is error_type: if self.function.type is error_type:
self.type = error_type self.type = error_type
return return
if self.function.type.is_cpp_class: if self.function.type.is_cpp_class:
overloaded_entry = self.function.type.scope.lookup("operator()") overloaded_entry = self.function.type.scope.lookup("operator()")
if overloaded_entry is None: if overloaded_entry is None:
...@@ -2992,8 +2997,16 @@ class SimpleCallNode(CallNode): ...@@ -2992,8 +2997,16 @@ class SimpleCallNode(CallNode):
overloaded_entry = self.function.entry overloaded_entry = self.function.entry
else: else:
overloaded_entry = None overloaded_entry = None
if overloaded_entry: if overloaded_entry:
entry = PyrexTypes.best_match(self.args, overloaded_entry.all_alternatives(), self.pos) if overloaded_entry.fused_cfunction:
specific_cdef_funcs = overloaded_entry.fused_cfunction.nodes
alternatives = [n.entry for n in specific_cdef_funcs]
else:
alternatives = overloaded_entry.all_alternatives()
entry = PyrexTypes.best_match(self.args, alternatives, self.pos, env)
if not entry: if not entry:
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
self.result_code = "<error>" self.result_code = "<error>"
...@@ -3130,8 +3143,8 @@ class SimpleCallNode(CallNode): ...@@ -3130,8 +3143,8 @@ class SimpleCallNode(CallNode):
for actual_arg in self.args[len(formal_args):]: for actual_arg in self.args[len(formal_args):]:
arg_list_code.append(actual_arg.result()) arg_list_code.append(actual_arg.result())
result = "%s(%s)" % (self.function.result(),
', '.join(arg_list_code)) result = "%s(%s)" % (self.function.result(), ', '.join(arg_list_code))
return result return result
def generate_result_code(self, code): def generate_result_code(self, code):
......
...@@ -156,13 +156,22 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -156,13 +156,22 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
f.close() f.close()
def generate_public_declaration(self, entry, h_code, i_code): def generate_public_declaration(self, entry, h_code, i_code):
if entry.fused_cfunction:
for cfunction in entry.fused_cfunction.nodes:
self._generate_public_declaration(cfunction.entry,
cfunction.entry.cname, h_code, i_code)
else:
self._generate_public_declaration(entry, entry.cname,
h_code, i_code)
def _generate_public_declaration(self, entry, cname, h_code, i_code):
h_code.putln("%s %s;" % ( h_code.putln("%s %s;" % (
Naming.extern_c_macro, Naming.extern_c_macro,
entry.type.declaration_code( entry.type.declaration_code(
entry.cname, dll_linkage = "DL_IMPORT"))) cname, dll_linkage = "DL_IMPORT")))
if i_code: if i_code:
i_code.putln("cdef extern %s" % i_code.putln("cdef extern %s" %
entry.type.declaration_code(entry.cname, pyrex = 1)) entry.type.declaration_code(cname, pyrex = 1))
def api_name(self, env): def api_name(self, env):
return env.qualified_name.replace(".", "__") return env.qualified_name.replace(".", "__")
...@@ -987,34 +996,43 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -987,34 +996,43 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
def generate_cfunction_predeclarations(self, env, code, definition): def generate_cfunction_predeclarations(self, env, code, definition):
for entry in env.cfunc_entries: for entry in env.cfunc_entries:
if entry.inline_func_in_pxd or (not entry.in_cinclude and (definition if entry.fused_cfunction:
or entry.defined_in_pxd or entry.visibility == 'extern')): for node in entry.fused_cfunction.nodes:
if entry.visibility == 'public': self._generate_cfunction_predeclaration(
storage_class = "%s " % Naming.extern_c_macro code, definition, node.entry)
dll_linkage = "DL_EXPORT" else:
elif entry.visibility == 'extern': self._generate_cfunction_predeclaration(code, definition, entry)
storage_class = "%s " % Naming.extern_c_macro
dll_linkage = "DL_IMPORT"
elif entry.visibility == 'private': def _generate_cfunction_predeclaration(self, code, definition, entry):
storage_class = "static " if entry.inline_func_in_pxd or (not entry.in_cinclude and (definition
dll_linkage = None or entry.defined_in_pxd or entry.visibility == 'extern')):
else: if entry.visibility == 'public':
storage_class = "static " storage_class = "%s " % Naming.extern_c_macro
dll_linkage = None dll_linkage = "DL_EXPORT"
type = entry.type elif entry.visibility == 'extern':
storage_class = "%s " % Naming.extern_c_macro
if not definition and entry.defined_in_pxd: dll_linkage = "DL_IMPORT"
type = CPtrType(type) elif entry.visibility == 'private':
header = type.declaration_code(entry.cname, storage_class = "static "
dll_linkage = dll_linkage) dll_linkage = None
if entry.func_modifiers: else:
modifiers = "%s " % ' '.join(entry.func_modifiers).upper() storage_class = "static "
else: dll_linkage = None
modifiers = '' type = entry.type
code.putln("%s%s%s; /*proto*/" % (
storage_class, if not definition and entry.defined_in_pxd:
modifiers, type = CPtrType(type)
header)) header = type.declaration_code(entry.cname,
dll_linkage = dll_linkage)
if entry.func_modifiers:
modifiers = "%s " % ' '.join(entry.func_modifiers).upper()
else:
modifiers = ''
code.putln("%s%s%s; /*proto*/" % (
storage_class,
modifiers,
header))
def generate_typeobj_definitions(self, env, code): def generate_typeobj_definitions(self, env, code):
full_module_name = env.qualified_name full_module_name = env.qualified_name
......
...@@ -93,6 +93,7 @@ enc_scope_cname = pyrex_prefix + "enc_scope" ...@@ -93,6 +93,7 @@ enc_scope_cname = pyrex_prefix + "enc_scope"
frame_cname = pyrex_prefix + "frame" frame_cname = pyrex_prefix + "frame"
frame_code_cname = pyrex_prefix + "frame_code" frame_code_cname = pyrex_prefix + "frame_code"
binding_cfunc = pyrex_prefix + "binding_PyCFunctionType" binding_cfunc = pyrex_prefix + "binding_PyCFunctionType"
fused_func_prefix = pyrex_prefix + 'fuse_'
genexpr_id_ref = 'genexpr' genexpr_id_ref = 'genexpr'
......
This diff is collapsed.
...@@ -610,8 +610,9 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -610,8 +610,9 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
'operator.comma' : ExprNodes.c_binop_constructor(','), 'operator.comma' : ExprNodes.c_binop_constructor(','),
} }
special_methods = cython.set(['declare', 'union', 'struct', 'typedef', 'sizeof', special_methods = cython.set(['declare', 'union', 'struct', 'typedef',
'cast', 'pointer', 'compiled', 'NULL']) 'sizeof', 'cast', 'pointer', 'compiled',
'NULL', 'fused_type'])
special_methods.update(unop_method_nodes.keys()) special_methods.update(unop_method_nodes.keys())
def __init__(self, context, compilation_directive_defaults): def __init__(self, context, compilation_directive_defaults):
...@@ -896,6 +897,36 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -896,6 +897,36 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
return self.visit_with_directives(node.body, directive_dict) return self.visit_with_directives(node.body, directive_dict)
return self.visit_Node(node) return self.visit_Node(node)
def visit_CTypeDefNode(self, node):
"Don't skip ctypedefs"
self.visitchildren(node)
return node
def visit_FusedTypeNode(self, node):
"""
See if a function call expression in a ctypedef is actually
cython.fused_type()
"""
def err():
error(node.pos, "Can only fuse types with cython.fused_type()")
if len(node.funcname) == 1:
fused_type, = node.funcname
else:
cython_module, fused_type = node.funcname
wrong_module = cython_module not in self.cython_module_names
if wrong_module or fused_type != u'fused_type':
err()
return node
if not self.directive_names.get(fused_type):
err()
return node
class WithTransform(CythonTransform, SkipDeclarations): class WithTransform(CythonTransform, SkipDeclarations):
# EXCINFO is manually set to a variable that contains # EXCINFO is manually set to a variable that contains
...@@ -1115,6 +1146,14 @@ if VALUE is not None: ...@@ -1115,6 +1146,14 @@ if VALUE is not None:
return node return node
def visit_FuncDefNode(self, node): def visit_FuncDefNode(self, node):
"""
Analyse a function and its body, as that hasn't happend yet. Also
analyse the directive_locals set by @cython.locals(). Then, if we are
a function with fused arguments, replace the function (after it has
declared itself in the symbol table!) with a FusedCFuncDefNode, and
analyse its children (which are in turn normal functions). If we're a
normal function, just analyse the body of the function.
"""
self.seen_vars_stack.append(cython.set()) self.seen_vars_stack.append(cython.set())
lenv = node.local_scope lenv = node.local_scope
node.body.analyse_control_flow(lenv) # this will be totally refactored node.body.analyse_control_flow(lenv) # this will be totally refactored
...@@ -1126,10 +1165,16 @@ if VALUE is not None: ...@@ -1126,10 +1165,16 @@ if VALUE is not None:
lenv.declare_var(var, type, type_node.pos) lenv.declare_var(var, type, type_node.pos)
else: else:
error(type_node.pos, "Not a type") error(type_node.pos, "Not a type")
node.body.analyse_declarations(lenv)
self.env_stack.append(lenv) if node.has_fused_arguments:
self.visitchildren(node) node = Nodes.FusedCFuncDefNode(node, self.env_stack[-1])
self.env_stack.pop() self.visitchildren(node)
else:
node.body.analyse_declarations(lenv)
self.env_stack.append(lenv)
self.visitchildren(node)
self.env_stack.pop()
self.seen_vars_stack.pop() self.seen_vars_stack.pop()
return node return node
......
...@@ -2572,6 +2572,24 @@ def p_c_func_or_var_declaration(s, pos, ctx): ...@@ -2572,6 +2572,24 @@ def p_c_func_or_var_declaration(s, pos, ctx):
overridable = ctx.overridable) overridable = ctx.overridable)
return result return result
def p_typelist(s):
"""
parse a list of basic c types as part of a function call, like
cython.fused_type(int, long, double)
"""
types = []
pos = s.position()
while s.sy == 'IDENT':
types.append(p_c_base_type(s))
if s.sy != ',':
if s.sy != ')':
s.expect(',')
break
s.next()
return Nodes.FusedTypeNode(pos, types=types)
def p_ctypedef_statement(s, ctx): def p_ctypedef_statement(s, ctx):
# s.sy == 'ctypedef' # s.sy == 'ctypedef'
pos = s.position() pos = s.position()
...@@ -2588,17 +2606,37 @@ def p_ctypedef_statement(s, ctx): ...@@ -2588,17 +2606,37 @@ def p_ctypedef_statement(s, ctx):
return p_c_enum_definition(s, pos, ctx) return p_c_enum_definition(s, pos, ctx)
else: else:
return p_c_struct_or_union_definition(s, pos, ctx) return p_c_struct_or_union_definition(s, pos, ctx)
elif looking_at_expr(s):
# ctypedef cython.fused_types(int, long) integral
if s.sy == 'IDENT':
funcname = [s.systring]
s.next()
if s.systring == u'.':
s.next()
funcname.append(s.systring)
s.expect('IDENT')
s.expect('(')
base_type = p_typelist(s)
s.expect(')')
# Check if funcname equals cython.fused_types in
# InterpretCompilerDirectives
base_type.funcname = funcname
else:
s.error("Syntax error in ctypedef statement")
else: else:
base_type = p_c_base_type(s, nonempty = 1) base_type = p_c_base_type(s, nonempty = 1)
if base_type.name is None: if base_type.name is None:
s.error("Syntax error in ctypedef statement") s.error("Syntax error in ctypedef statement")
declarator = p_c_declarator(s, ctx, is_type = 1, nonempty = 1)
s.expect_newline("Syntax error in ctypedef statement") declarator = p_c_declarator(s, ctx, is_type = 1, nonempty = 1)
return Nodes.CTypeDefNode( s.expect_newline("Syntax error in ctypedef statement")
pos, base_type = base_type, return Nodes.CTypeDefNode(
declarator = declarator, pos, base_type = base_type,
visibility = visibility, api = api, declarator = declarator,
in_pxd = ctx.level == 'module_pxd') visibility = visibility, api = api,
in_pxd = ctx.level == 'module_pxd')
def p_decorators(s): def p_decorators(s):
decorators = [] decorators = []
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
# Pyrex - Types # Pyrex - Types
# #
import cython
from Code import UtilityCode from Code import UtilityCode
import StringEncoding import StringEncoding
import Naming import Naming
...@@ -12,6 +14,9 @@ class BaseType(object): ...@@ -12,6 +14,9 @@ class BaseType(object):
# #
# Base class for all Pyrex types including pseudo-types. # Base class for all Pyrex types including pseudo-types.
# List of attribute names of any subtypes
subtypes = []
def can_coerce_to_pyobject(self, env): def can_coerce_to_pyobject(self, env):
return False return False
...@@ -27,6 +32,42 @@ class BaseType(object): ...@@ -27,6 +32,42 @@ class BaseType(object):
else: else:
return base_code return base_code
def __deepcopy__(self, memo):
"""
Types never need to be copied, if we do copy, Unfortunate Things
Will Happen!
"""
return self
def get_fused_types(self, result=None, seen=None):
if self.subtypes:
def add_fused_types(types):
for type in types or ():
if type not in seen:
seen.add(type)
result.append(type)
if result is None:
result = []
seen = cython.set()
for attr in self.subtypes:
list_or_subtype = getattr(self, attr)
if isinstance(list_or_subtype, BaseType):
list_or_subtype.get_fused_types(result, seen)
else:
for subtype in list_or_subtype:
subtype.get_fused_types(result, seen)
return result
return None
is_fused = property(get_fused_types, doc="Whether this type or any of its "
"subtypes is a fused type")
class PyrexType(BaseType): class PyrexType(BaseType):
# #
# Base class for all Pyrex types. # Base class for all Pyrex types.
...@@ -195,7 +236,8 @@ class CTypedefType(BaseType): ...@@ -195,7 +236,8 @@ class CTypedefType(BaseType):
to_py_utility_code = None to_py_utility_code = None
from_py_utility_code = None from_py_utility_code = None
subtypes = ['typedef_base_type']
def __init__(self, name, base_type, cname, is_external=0): def __init__(self, name, base_type, cname, is_external=0):
assert not base_type.is_complex assert not base_type.is_complex
...@@ -314,6 +356,9 @@ class BufferType(BaseType): ...@@ -314,6 +356,9 @@ class BufferType(BaseType):
is_buffer = 1 is_buffer = 1
writable = True writable = True
subtypes = ['dtype']
def __init__(self, base, dtype, ndim, mode, negative_indices, cast): def __init__(self, base, dtype, ndim, mode, negative_indices, cast):
self.base = base self.base = base
self.dtype = dtype self.dtype = dtype
...@@ -345,7 +390,7 @@ class PyObjectType(PyrexType): ...@@ -345,7 +390,7 @@ class PyObjectType(PyrexType):
buffer_defaults = None buffer_defaults = None
is_extern = False is_extern = False
is_subclassed = False is_subclassed = False
def __str__(self): def __str__(self):
return "Python object" return "Python object"
...@@ -618,6 +663,45 @@ class CType(PyrexType): ...@@ -618,6 +663,45 @@ class CType(PyrexType):
return 0 return 0
class FusedType(CType):
"""
Represents a Fused Type. All it needs to do is keep track of the types
it aggregates, as it will be replaced with its specific version wherever
needed.
See http://wiki.cython.org/enhancements/fusedtypes
types [CSimpleBaseTypeNode] is the list of types to be fused
name str the name of the ctypedef
"""
is_fused = 1
def __init__(self, types):
self.types = types
def declaration_code(self, entity_code, for_display = 0,
dll_linkage = None, pyrex = 0):
if pyrex or for_display:
return self.name
raise Exception("This may never happen, please report a bug")
def __repr__(self):
return 'FusedType(name=%r)' % self.name
def specialize(self, values):
return values[self]
def get_fused_types(self, result=None, seen=None):
if result is None:
return [self]
if self not in seen:
result.append(self)
seen.add(self)
class CVoidType(CType): class CVoidType(CType):
# #
# C "void" type # C "void" type
...@@ -1531,7 +1615,9 @@ class CArrayType(CType): ...@@ -1531,7 +1615,9 @@ class CArrayType(CType):
# size integer or None Number of elements # size integer or None Number of elements
is_array = 1 is_array = 1
subtypes = ['base_type']
def __init__(self, base_type, size): def __init__(self, base_type, size):
self.base_type = base_type self.base_type = base_type
self.size = size self.size = size
...@@ -1577,6 +1663,8 @@ class CPtrType(CType): ...@@ -1577,6 +1663,8 @@ class CPtrType(CType):
is_ptr = 1 is_ptr = 1
default_value = "0" default_value = "0"
subtypes = ['base_type']
def __init__(self, base_type): def __init__(self, base_type):
self.base_type = base_type self.base_type = base_type
...@@ -1675,7 +1763,9 @@ class CFuncType(CType): ...@@ -1675,7 +1763,9 @@ class CFuncType(CType):
is_cfunction = 1 is_cfunction = 1
original_sig = None original_sig = None
subtypes = ['return_type', 'args']
def __init__(self, return_type, args, has_varargs = 0, def __init__(self, return_type, args, has_varargs = 0,
exception_value = None, exception_check = 0, calling_convention = "", exception_value = None, exception_check = 0, calling_convention = "",
nogil = 0, with_gil = 0, is_overridable = 0, optional_arg_count = 0, nogil = 0, with_gil = 0, is_overridable = 0, optional_arg_count = 0,
...@@ -1691,7 +1781,7 @@ class CFuncType(CType): ...@@ -1691,7 +1781,7 @@ class CFuncType(CType):
self.with_gil = with_gil self.with_gil = with_gil
self.is_overridable = is_overridable self.is_overridable = is_overridable
self.templates = templates self.templates = templates
def __repr__(self): def __repr__(self):
arg_reprs = map(repr, self.args) arg_reprs = map(repr, self.args)
if self.has_varargs: if self.has_varargs:
...@@ -1915,7 +2005,7 @@ class CFuncType(CType): ...@@ -1915,7 +2005,7 @@ class CFuncType(CType):
return self.op_arg_struct.base_type.scope.lookup(arg_name).cname return self.op_arg_struct.base_type.scope.lookup(arg_name).cname
class CFuncTypeArg(object): class CFuncTypeArg(BaseType):
# name string # name string
# cname string # cname string
# type PyrexType # type PyrexType
...@@ -1926,6 +2016,8 @@ class CFuncTypeArg(object): ...@@ -1926,6 +2016,8 @@ class CFuncTypeArg(object):
or_none = False or_none = False
accept_none = True accept_none = True
subtypes = ['type']
def __init__(self, name, type, pos, cname=None): def __init__(self, name, type, pos, cname=None):
self.name = name self.name = name
if cname is not None: if cname is not None:
...@@ -2478,7 +2570,7 @@ def is_promotion(src_type, dst_type): ...@@ -2478,7 +2570,7 @@ def is_promotion(src_type, dst_type):
return src_type.is_float and src_type.rank <= dst_type.rank return src_type.is_float and src_type.rank <= dst_type.rank
return False return False
def best_match(args, functions, pos=None): def best_match(args, functions, pos=None, env=None):
""" """
Given a list args of arguments and a list of functions, choose one Given a list args of arguments and a list of functions, choose one
to call which seems to be the "best" fit for this list of arguments. to call which seems to be the "best" fit for this list of arguments.
...@@ -2546,12 +2638,33 @@ def best_match(args, functions, pos=None): ...@@ -2546,12 +2638,33 @@ def best_match(args, functions, pos=None):
possibilities = [] possibilities = []
bad_types = [] bad_types = []
needed_coercions = {}
for func, func_type in candidates: for func, func_type in candidates:
score = [0,0,0] score = [0,0,0]
for i in range(min(len(args), len(func_type.args))): for i in range(min(len(args), len(func_type.args))):
src_type = args[i].type src_type = args[i].type
dst_type = func_type.args[i].type dst_type = func_type.args[i].type
if dst_type.assignable_from(src_type):
assignable = dst_type.assignable_from(src_type)
# Now take care of normal string literals. So when you call a cdef
# function that takes a char *, the coercion will mean that the
# type will simply become bytes. We need to do this coercion
# manually for overloaded and fused functions
if not assignable and src_type.is_pyobject:
if (src_type.is_builtin_type and src_type.name == 'str' and
dst_type.resolve() is c_char_ptr_type):
c_src_type = c_char_ptr_type
else:
c_src_type = src_type.default_coerced_ctype()
if c_src_type:
assignable = dst_type.assignable_from(c_src_type)
if assignable:
src_type = c_src_type
needed_coercions[func] = i, dst_type
if assignable:
if src_type == dst_type or dst_type.same_as(src_type): if src_type == dst_type or dst_type.same_as(src_type):
pass # score 0 pass # score 0
elif is_promotion(src_type, dst_type): elif is_promotion(src_type, dst_type):
...@@ -2567,18 +2680,28 @@ def best_match(args, functions, pos=None): ...@@ -2567,18 +2680,28 @@ def best_match(args, functions, pos=None):
break break
else: else:
possibilities.append((score, func)) # so we can sort it possibilities.append((score, func)) # so we can sort it
if possibilities: if possibilities:
possibilities.sort() possibilities.sort()
if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]: if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]:
if pos is not None: if pos is not None:
error(pos, "ambiguous overloaded method") error(pos, "ambiguous overloaded method")
return None return None
return possibilities[0][1]
function = possibilities[0][1]
if function in needed_coercions and env:
arg_i, coerce_to_type = needed_coercions[function]
args[arg_i] = args[arg_i].coerce_to(coerce_to_type, env)
return function
if pos is not None: if pos is not None:
if len(bad_types) == 1: if len(bad_types) == 1:
error(pos, bad_types[0][1]) error(pos, bad_types[0][1])
else: else:
error(pos, "no suitable method found") error(pos, "no suitable method found")
return None return None
def widest_numeric_type(type1, type2): def widest_numeric_type(type1, type2):
......
...@@ -176,6 +176,7 @@ class Entry(object): ...@@ -176,6 +176,7 @@ class Entry(object):
buffer_aux = None buffer_aux = None
prev_entry = None prev_entry = None
might_overflow = 0 might_overflow = 0
fused_cfunction = None
def __init__(self, name, cname, type, pos = None, init = None): def __init__(self, name, cname, type, pos = None, init = None):
self.name = name self.name = name
...@@ -241,6 +242,7 @@ class Scope(object): ...@@ -241,6 +242,7 @@ class Scope(object):
scope_prefix = "" scope_prefix = ""
in_cinclude = 0 in_cinclude = 0
nogil = 0 nogil = 0
fused_to_specific = None
def __init__(self, name, outer_scope, parent_scope): def __init__(self, name, outer_scope, parent_scope):
# The outer_scope is the next scope in the lookup chain. # The outer_scope is the next scope in the lookup chain.
...@@ -279,6 +281,9 @@ class Scope(object): ...@@ -279,6 +281,9 @@ class Scope(object):
self.return_type = None self.return_type = None
self.id_counters = {} self.id_counters = {}
def __deepcopy__(self, memo):
return self
def start_branching(self, pos): def start_branching(self, pos):
self.control_flow = self.control_flow.start_branch(pos) self.control_flow = self.control_flow.start_branch(pos)
...@@ -677,6 +682,8 @@ class Scope(object): ...@@ -677,6 +682,8 @@ class Scope(object):
def lookup_type(self, name): def lookup_type(self, name):
entry = self.lookup(name) entry = self.lookup(name)
if entry and entry.is_type: if entry and entry.is_type:
if entry.type.is_fused and self.fused_to_specific:
return entry.type.specialize(self.fused_to_specific)
return entry.type return entry.type
def lookup_operator(self, operator, operands): def lookup_operator(self, operator, operands):
......
...@@ -225,6 +225,30 @@ class typedef(CythonType): ...@@ -225,6 +225,30 @@ class typedef(CythonType):
value = cast(self._basetype, *arg) value = cast(self._basetype, *arg)
return value return value
class _FusedType(CythonType):
def __call__(self, type, value):
return value
def fused_type(*args):
if not args:
raise TypeError("Expected at least one type as argument")
rank = -1
for type in args:
if type not in (py_int, py_long, py_float, py_complex):
break
if type_ordering.index(type) > rank:
result_type = type
else:
return result_type
# Not a simple numeric type, return a fused type instance. The result
# isn't really meant to be used, as we can't keep track of the context in
# pure-mode. Casting won't do anything in this case.
return _FusedType()
py_int = int py_int = int
...@@ -277,3 +301,5 @@ for t in int_types + float_types + complex_types + other_types: ...@@ -277,3 +301,5 @@ for t in int_types + float_types + complex_types + other_types:
void = typedef(None) void = typedef(None)
NULL = p_void(0) NULL = p_void(0)
type_ordering = [py_int, py_long, py_float, py_complex]
\ No newline at end of file
# mode: error
cimport cython
from cython import fused_type
# This is all invalid
ctypedef foo(int) dtype1
ctypedef foo.bar(float) dtype2
ctypedef fused_type(foo) dtype3
dtype4 = cython.typedef(cython.fused_type(int, long, kw=None))
# This is all valid
ctypedef fused_type(int, long, float) dtype5
ctypedef cython.fused_type(int, long) dtype6
_ERRORS = u"""
fused_types.pyx:7:13: Can only fuse types with cython.fused_type()
fused_types.pyx:8:17: Can only fuse types with cython.fused_type()
fused_types.pyx:9:20: 'foo' is not a type identifier
fused_types.pyx:10:23: fused_type does not take keyword arguments
"""
# mode: run
cimport cython
from cpython cimport Py_INCREF
from Cython import Shadow as pure_cython
ctypedef char * string_t
ctypedef cython.fused_type(int, long, float, double, string_t) fused_type1
ctypedef cython.fused_type(string_t) fused_type2
def test_pure():
"""
>>> test_pure()
(10+0j)
"""
mytype = pure_cython.typedef(pure_cython.fused_type(int, long, complex))
print mytype(10)
cdef cdef_func_with_fused_args(fused_type1 x, fused_type1 y, fused_type2 z):
print x, y, z
return x + y
def test_cdef_func_with_fused_args():
"""
>>> test_cdef_func_with_fused_args()
spam ham eggs
spamham
10 20 butter
30
4.2 8.6 bunny
12.8
"""
print cdef_func_with_fused_args('spam', 'ham', 'eggs')
print cdef_func_with_fused_args(10, 20, 'butter')
print cdef_func_with_fused_args(4.2, 8.6, 'bunny')
cdef fused_type1 fused_with_pointer(fused_type1 *array):
for i in range(5):
print array[i]
obj = array[0] + array[1] + array[2] + array[3] + array[4]
# if cython.typeof(fused_type1) is string_t:
Py_INCREF(obj)
return obj
def test_fused_with_pointer():
"""
>>> test_fused_with_pointer()
0
1
2
3
4
10
<BLANKLINE>
0
1
2
3
4
10
<BLANKLINE>
0.0
1.0
2.0
3.0
4.0
10.0
<BLANKLINE>
humpty
dumpty
fall
splatch
breakfast
humptydumptyfallsplatchbreakfast
"""
cdef int int_array[5]
cdef long long_array[5]
cdef float float_array[5]
cdef string_t string_array[5]
cdef char *s1 = "humpty", *s2 = "dumpty", *s3 = "fall", *s4 = "splatch", *s5 = "breakfast"
strings = ["humpty", "dumpty", "fall", "splatch", "breakfast"]
for i in range(5):
int_array[i] = i
long_array[i] = i
float_array[i] = i
s = strings[i]
string_array[i] = s
print fused_with_pointer(int_array)
print
print fused_with_pointer(long_array)
print
print fused_with_pointer(float_array)
print
print fused_with_pointer(string_array)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment