Commit 5061256f authored by Mark Florisson's avatar Mark Florisson

Constant fold fused type check expressions with branch pruning

parent e3d28878
...@@ -29,6 +29,7 @@ from StringEncoding import EncodedString, escape_byte_string, split_string_liter ...@@ -29,6 +29,7 @@ from StringEncoding import EncodedString, escape_byte_string, split_string_liter
import Options import Options
import ControlFlow import ControlFlow
import DebugFlags import DebugFlags
from Cython.Compiler import Errors
absolute_path_length = 0 absolute_path_length = 0
...@@ -942,14 +943,23 @@ class FusedTypeNode(CBaseTypeNode): ...@@ -942,14 +943,23 @@ class FusedTypeNode(CBaseTypeNode):
return self.types[0] return self.types[0]
types = [] types = []
seen = cython.set()
for type in self.types: for type in self.types:
self.add_type(type, types, seen)
return PyrexTypes.FusedType(types)
def add_type(self, type, types, seen):
if type not in seen:
seen.add(type)
if type.is_fused: if type.is_fused:
types.extend(type.types) for specific_type in PyrexTypes.get_specific_types(type):
self.add_type(specific_type, types, seen)
else: else:
types.append(type) types.append(type)
return PyrexTypes.FusedType(types)
class CVarDefNode(StatNode): class CVarDefNode(StatNode):
# C variable definition or forward/extern function declaration. # C variable definition or forward/extern function declaration.
...@@ -2055,8 +2065,13 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2055,8 +2065,13 @@ class FusedCFuncDefNode(StatListNode):
cname = self.node.type.get_specific_cname(cname) cname = self.node.type.get_specific_cname(cname)
copied_node.entry.func_cname = copied_node.entry.cname = cname copied_node.entry.func_cname = copied_node.entry.cname = cname
# TransformBuiltinMethods(copied_node) num_errors = Errors.num_errors
ParseTreeTransforms.ReplaceFusedTypeChecks(copied_node.local_scope)(copied_node) transform = ParseTreeTransforms.ReplaceFusedTypeChecks(
copied_node.local_scope)
transform(copied_node)
if Errors.num_errors > num_errors:
break
class PyArgDeclNode(Node): class PyArgDeclNode(Node):
......
...@@ -2974,8 +2974,12 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): ...@@ -2974,8 +2974,12 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
literal nodes at each step. Non-literal nodes are never merged literal nodes at each step. Non-literal nodes are never merged
into a single node. into a single node.
""" """
check_constant_value_not_set = True
def _calculate_const(self, node): def _calculate_const(self, node):
if node.constant_result is not ExprNodes.constant_value_not_set: if (self.check_constant_value_not_set and
node.constant_result is not ExprNodes.constant_value_not_set):
return return
# make sure we always set the value # make sure we always set the value
......
...@@ -1899,30 +1899,64 @@ class ReplaceFusedTypeChecks(VisitorTransform): ...@@ -1899,30 +1899,64 @@ class ReplaceFusedTypeChecks(VisitorTransform):
self.local_scope = local_scope self.local_scope = local_scope
def visit_IfStatNode(self, node): def visit_IfStatNode(self, node):
if_clauses = node.if_clauses[:] """
Filters out any if clauses with false compile time type check
expression.
"""
from Cython.Compiler import Optimize
self.visitchildren(node) self.visitchildren(node)
if if_clauses != node.if_clauses: if_clauses = []
if node.if_clauses: seen_true = False
return node.if_clauses[0] seen_true_anywhere = False
return node.else_clause
for if_clause in node.if_clauses:
transform = Optimize.ConstantFolding()
transform.check_constant_value_not_set = False
transform(if_clause.condition)
is_const = if_clause.condition.has_constant_result()
const = if_clause.condition.constant_result
if is_const:
if const and not seen_true:
seen_true = True
seen_true_anywhere = True
if_clauses.append(if_clause)
else:
seen_true = False
if_clauses.append(if_clause)
if if_clauses:
node.if_clauses = if_clauses
if seen_true_anywhere:
node.else_clause = None
else:
node = node.else_clause
return node return node
def visit_IfClauseNode(self, node): def visit_PrimaryCmpNode(self, node):
cond = node.condition type1 = node.operand1.analyse_as_type(self.local_scope)
if isinstance(cond, ExprNodes.PrimaryCmpNode): type2 = node.operand2.analyse_as_type(self.local_scope)
type1 = cond.operand1.analyse_as_type(self.local_scope)
type2 = cond.operand2.analyse_as_type(self.local_scope)
if type1 and type2: if type1 and type2:
type1 = self.specialize_type(type1, cond.operand1.pos) false = ExprNodes.BoolNode(node.pos, value=False)
op = cond.operator true = ExprNodes.BoolNode(node.pos, value=True)
type1 = self.specialize_type(type1, node.operand1.pos)
op = node.operator
if op in ('is', 'is not', '==', '!='):
type2 = self.specialize_type(type2, node.operand2.pos)
is_same = type1.same_as(type2)
eq = op in ('is', '==')
if (is_same and eq) or (not is_same and not eq):
return true
if op == 'is':
type2 = self.specialize_type(type2, cond.operand1.pos)
if type1.same_as(type2):
return node.body
elif op in ('in', 'not_in'): elif op in ('in', 'not_in'):
# We have to do an instance check directly, as operand2 # We have to do an instance check directly, as operand2
# needs to be a fused type and not a type with a subtype # needs to be a fused type and not a type with a subtype
...@@ -1930,17 +1964,29 @@ class ReplaceFusedTypeChecks(VisitorTransform): ...@@ -1930,17 +1964,29 @@ class ReplaceFusedTypeChecks(VisitorTransform):
if isinstance(type2, PyrexTypes.CTypedefType): if isinstance(type2, PyrexTypes.CTypedefType):
type2 = type2.typedef_base_type type2 = type2.typedef_base_type
if type1.is_fused or not isinstance(type2, PyrexTypes.FusedType): if type1.is_fused:
error(cond.pos, "Can use 'in' or 'not in' only on a " error(node.operand1.pos, "Type is fused")
"specific and a fused type") elif not type2.is_fused:
elif op == 'in': error(node.operand2.pos,
if type1 in type2.types: "Can only use 'in' or 'not in' on a fused type")
return node.body else:
if not isinstance(type2, PyrexTypes.FusedType):
# Composed fused type, get all specific versions
types = PyrexTypes.get_specific_types(type2)
else: else:
if type1 not in type2.types: types = type2.types
return node.body
return None for specific_type in types:
if type1.same_as(specific_type):
if op == 'in':
return true
else:
return false
if op == 'not_in':
return true
return false
return node return node
......
...@@ -2023,18 +2023,7 @@ class CFuncType(CType): ...@@ -2023,18 +2023,7 @@ class CFuncType(CType):
if fused_types is None: if fused_types is None:
fused_types = self.get_fused_types() fused_types = self.get_fused_types()
fused_type = fused_types[0] return get_all_specific_permutations(fused_types)
for specific_type in fused_type.types:
cname = str(specific_type)
result_fused_to_specific = { fused_type: specific_type }
if len(fused_types) > 1:
it = self.get_all_specific_permutations(fused_types[1:])
for next_cname, fused_to_specific in it:
d = dict(fused_to_specific, **result_fused_to_specific)
yield '%s_%s' % (cname, next_cname), d
else:
yield cname, result_fused_to_specific
def get_all_specific_function_types(self): def get_all_specific_function_types(self):
""" """
...@@ -2042,23 +2031,25 @@ class CFuncType(CType): ...@@ -2042,23 +2031,25 @@ class CFuncType(CType):
""" """
assert self.is_fused assert self.is_fused
result = []
permutations = self.get_all_specific_permutations() permutations = self.get_all_specific_permutations()
for cname, fused_to_specific in permutations: for cname, fused_to_specific in permutations:
new_func_type = self.entry.type.specialize(fused_to_specific) new_func_type = self.entry.type.specialize(fused_to_specific)
new_entry = copy.deepcopy(self.entry) new_entry = copy.deepcopy(self.entry)
new_entry.cname = self.get_specific_cname(cname) new_entry.cname = self.get_specific_cname(cname)
new_entry.type = new_func_type
new_entry.type = new_func_type
new_func_type.entry = new_entry new_func_type.entry = new_entry
yield new_func_type
result.append(new_func_type)
return result
def get_specific_cname(self, fused_cname): def get_specific_cname(self, fused_cname):
""" """
Given the cname for a permutation of fused types, return the cname Given the cname for a permutation of fused types, return the cname
for the corresponding function with specific types. for the corresponding function with specific types.
The fused_cname is usually '_'.join(str(t) for t in specific_types)
""" """
assert self.is_fused assert self.is_fused
return '%s%s%s' % (Naming.fused_func_prefix, return '%s%s%s' % (Naming.fused_func_prefix,
...@@ -2086,6 +2077,31 @@ class CFuncType(CType): ...@@ -2086,6 +2077,31 @@ class CFuncType(CType):
# a normal cdef # a normal cdef
return func(entry, *args, **kwargs) return func(entry, *args, **kwargs)
def get_all_specific_permutations(fused_types, id="0", f2s=()):
fused_type = fused_types[0]
result = []
for newid, specific_type in enumerate(fused_type.types):
f2s = dict(f2s, **{ fused_type: specific_type })
cname = '%s_%s' % (id, newid)
if len(fused_types) > 1:
result.extend(get_all_specific_permutations(
fused_types[1:], cname, f2s))
else:
result.append((cname, f2s))
return result
def get_specific_types(type):
assert type.is_fused
result = []
for cname, f2s in get_all_specific_permutations(type.get_fused_types()):
result.append(type.specialize(f2s))
return result
class CFuncTypeArg(BaseType): class CFuncTypeArg(BaseType):
# name string # name string
......
...@@ -2,8 +2,16 @@ cimport cython ...@@ -2,8 +2,16 @@ cimport cython
cimport check_fused_types_pxd cimport check_fused_types_pxd
ctypedef char *string_t ctypedef char *string_t
ctypedef cython.fused_type(int, long, float, string_t) fused_t ctypedef cython.fused_type(int, long, float, string_t) fused_t
ctypedef cython.fused_type(int, long) other_t ctypedef cython.fused_type(int, long) other_t
ctypedef cython.fused_type(short, short int, short, int) base_t
ctypedef cython.fused_type(float complex, double complex,
int complex, long complex) complex_t
ctypedef base_t **base_t_p_p
ctypedef cython.fused_type(char, base_t_p_p, fused_t, complex_t) composed_t
cdef func(fused_t a, other_t b): cdef func(fused_t a, other_t b):
cdef int int_a cdef int int_a
...@@ -121,3 +129,59 @@ def test_if_then_else_float_int(): ...@@ -121,3 +129,59 @@ def test_if_then_else_float_int():
cdef int y = 1 cdef int y = 1
if_then_else(x, y) if_then_else(x, y)
cdef composed_t composed(composed_t x, composed_t y):
if composed_t in base_t_p_p or composed_t is string_t:
if string_t == composed_t:
print x, y
else:
print x[0][0], y[0][0]
return x
elif composed_t == string_t:
print 'this is never executed'
elif list():
print 'neither is this one'
else:
if composed_t not in complex_t:
print 'not a complex number'
print <int> x, <int> y
else:
print 'it is a complex number'
print x.real, x.imag
return x + y
def test_composed_types():
"""
>>> test_composed_types()
it is a complex number
0.5 0.6
(0.9+0.4j)
<BLANKLINE>
not a complex number
9 10
19
<BLANKLINE>
7 8
<BLANKLINE>
spam eggs
spam
"""
cdef double complex a = 0.5 + 0.6j, b = 0.4 -0.2j, result
cdef int c = 7, d = 8
cdef int *cp = &c, *dp = &d
cdef string_t e = "spam", f = "eggs"
result = composed(a, b)
print result
print
print composed(c + 2, d + 2)
print
composed(&cp, &dp)
print
print composed(e, f)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment