Commit 3537d938 authored by Mark Florisson's avatar Mark Florisson
parents efb0837d 3487c4e0
...@@ -11,6 +11,7 @@ Cython/Runtime/refnanny.c ...@@ -11,6 +11,7 @@ Cython/Runtime/refnanny.c
BUILD/ BUILD/
build/ build/
dist/ dist/
.git/
.gitrev .gitrev
.coverage .coverage
*.orig *.orig
......
...@@ -1444,6 +1444,10 @@ class CCodeWriter(object): ...@@ -1444,6 +1444,10 @@ class CCodeWriter(object):
def put_trace_return(self, retvalue_cname): def put_trace_return(self, retvalue_cname):
self.putln("__Pyx_TraceReturn(%s);" % retvalue_cname) self.putln("__Pyx_TraceReturn(%s);" % retvalue_cname)
def putln_openmp(self, string):
self.putln("#ifdef _OPENMP")
self.putln(string)
self.putln("#endif /* _OPENMP */")
class PyrexCodeWriter(object): class PyrexCodeWriter(object):
# f file output file # f file output file
......
...@@ -2059,6 +2059,64 @@ class RawCNameExprNode(ExprNode): ...@@ -2059,6 +2059,64 @@ class RawCNameExprNode(ExprNode):
pass pass
#-------------------------------------------------------------------
#
# Parallel nodes (cython.parallel.thread(savailable|id))
#
#-------------------------------------------------------------------
class ParallelThreadsAvailableNode(AtomicExprNode):
"""
Note: this is disabled and not a valid directive at this moment
Implements cython.parallel.threadsavailable(). If we are called from the
sequential part of the application, we need to call omp_get_max_threads(),
and in the parallel part we can just call omp_get_num_threads()
"""
type = PyrexTypes.c_int_type
def analyse_types(self, env):
self.is_temp = True
# env.add_include_file("omp.h")
return self.type
def generate_result_code(self, code):
code.putln("#ifdef _OPENMP")
code.putln("if (omp_in_parallel()) %s = omp_get_max_threads();" %
self.temp_code)
code.putln("else %s = omp_get_num_threads();" % self.temp_code)
code.putln("#else")
code.putln("%s = 1;" % self.temp_code)
code.putln("#endif")
def result(self):
return self.temp_code
class ParallelThreadIdNode(AtomicExprNode): #, Nodes.ParallelNode):
"""
Implements cython.parallel.threadid()
"""
type = PyrexTypes.c_int_type
def analyse_types(self, env):
self.is_temp = True
# env.add_include_file("omp.h")
return self.type
def generate_result_code(self, code):
code.putln("#ifdef _OPENMP")
code.putln("%s = omp_get_thread_num();" % self.temp_code)
code.putln("#else")
code.putln("%s = 0;" % self.temp_code)
code.putln("#endif")
def result(self):
return self.temp_code
#------------------------------------------------------------------- #-------------------------------------------------------------------
# #
# Trailer nodes # Trailer nodes
...@@ -3465,8 +3523,11 @@ class AttributeNode(ExprNode): ...@@ -3465,8 +3523,11 @@ class AttributeNode(ExprNode):
needs_none_check = True needs_none_check = True
def as_cython_attribute(self): def as_cython_attribute(self):
if isinstance(self.obj, NameNode) and self.obj.is_cython_module: if (isinstance(self.obj, NameNode) and
self.obj.is_cython_module and not
self.attribute == u"parallel"):
return self.attribute return self.attribute
cy = self.obj.as_cython_attribute() cy = self.obj.as_cython_attribute()
if cy: if cy:
return "%s.%s" % (cy, self.attribute) return "%s.%s" % (cy, self.attribute)
......
...@@ -106,7 +106,7 @@ class Context(object): ...@@ -106,7 +106,7 @@ class Context(object):
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
from ParseTreeTransforms import ExpandInplaceOperators from ParseTreeTransforms import ExpandInplaceOperators, ParallelRangeTransform
from TypeInference import MarkAssignments, MarkOverflowingArithmetic from TypeInference import MarkAssignments, MarkOverflowingArithmetic
from ParseTreeTransforms import AlignFunctionDefinitions, GilCheck from ParseTreeTransforms import AlignFunctionDefinitions, GilCheck
from AnalysedTreeTransforms import AutoTestDictTransform from AnalysedTreeTransforms import AutoTestDictTransform
...@@ -135,6 +135,7 @@ class Context(object): ...@@ -135,6 +135,7 @@ class Context(object):
PostParse(self), PostParse(self),
_specific_post_parse, _specific_post_parse,
InterpretCompilerDirectives(self, self.compiler_directives), InterpretCompilerDirectives(self, self.compiler_directives),
ParallelRangeTransform(self),
MarkClosureVisitor(self), MarkClosureVisitor(self),
_align_function_definitions, _align_function_definitions,
ConstantFolding(), ConstantFolding(),
......
...@@ -756,6 +756,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -756,6 +756,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
else: else:
code.putln('#include "%s"' % byte_decoded_filenname) code.putln('#include "%s"' % byte_decoded_filenname)
code.putln_openmp("#include <omp.h>")
def generate_filename_table(self, code): def generate_filename_table(self, code):
code.putln("") code.putln("")
code.putln("static const char *%s[] = {" % Naming.filetable_cname) code.putln("static const char *%s[] = {" % Naming.filetable_cname)
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# #
# Pyrex - Parse tree nodes # Pyrex - Parse tree nodes
# #
import cython import cython
from cython import set from cython import set
cython.declare(sys=object, os=object, time=object, copy=object, cython.declare(sys=object, os=object, time=object, copy=object,
...@@ -4591,14 +4590,12 @@ class ForInStatNode(LoopNode, StatNode): ...@@ -4591,14 +4590,12 @@ class ForInStatNode(LoopNode, StatNode):
old_loop_labels = code.new_loop_labels() old_loop_labels = code.new_loop_labels()
self.iterator.allocate_counter_temp(code) self.iterator.allocate_counter_temp(code)
self.iterator.generate_evaluation_code(code) self.iterator.generate_evaluation_code(code)
code.putln( code.putln("for (;;) {")
"for (;;) {")
self.item.generate_evaluation_code(code) self.item.generate_evaluation_code(code)
self.target.generate_assignment_code(self.item, code) self.target.generate_assignment_code(self.item, code)
self.body.generate_execution_code(code) self.body.generate_execution_code(code)
code.put_label(code.continue_label) code.put_label(code.continue_label)
code.putln( code.putln("}")
"}")
break_label = code.break_label break_label = code.break_label
code.set_loop_labels(old_loop_labels) code.set_loop_labels(old_loop_labels)
...@@ -5729,6 +5726,487 @@ class FromImportStatNode(StatNode): ...@@ -5729,6 +5726,487 @@ class FromImportStatNode(StatNode):
self.module.free_temps(code) self.module.free_temps(code)
class ParallelNode(Node):
"""
Base class for cython.parallel constructs.
"""
nogil_check = None
class ParallelStatNode(StatNode, ParallelNode):
"""
Base class for 'with cython.parallel.parallel:' and 'for i in prange():'.
assignments { Entry(var) : (var.pos, inplace_operator_or_None) }
assignments to variables in this parallel section
parent parent ParallelStatNode or None
is_parallel indicates whether this is a parallel node
is_parallel is true for:
#pragma omp parallel
#pragma omp parallel for
sections, but NOT for
#pragma omp for
We need this to determine the sharing attributes.
privatization_insertion_point a code insertion point used to make temps
private (esp. the "nsteps" temp)
"""
child_attrs = ['body']
body = None
is_prange = False
def __init__(self, pos, **kwargs):
super(ParallelStatNode, self).__init__(pos, **kwargs)
# All assignments in this scope
self.assignments = kwargs.get('assignments') or {}
# All seen closure cnames and their temporary cnames
self.seen_closure_vars = set()
# Dict of variables that should be declared (first|last|)private or
# reduction { Entry: op }. If op is not None, it's a reduction.
self.privates = {}
def analyse_declarations(self, env):
self.body.analyse_declarations(env)
def analyse_expressions(self, env):
self.body.analyse_expressions(env)
def analyse_sharing_attributes(self, env):
"""
Analyse the privates for this block and set them in self.privates.
This should be called in a post-order fashion during the
analyse_expressions phase
"""
for entry, (pos, op) in self.assignments.iteritems():
if self.is_private(entry):
self.propagate_var_privatization(entry, op)
def is_private(self, entry):
"""
True if this scope should declare the variable private, lastprivate
or reduction.
"""
return (self.is_parallel or
(self.parent and entry not in self.parent.privates))
def propagate_var_privatization(self, entry, op):
"""
Propagate the sharing attributes of a variable. If the privatization is
determined by a parent scope, done propagate further.
If we are a prange, we propagate our sharing attributes outwards to
other pranges. If we are a prange in parallel block and the parallel
block does not determine the variable private, we propagate to the
parent of the parent. Recursion stops at parallel blocks, as they have
no concept of lastprivate or reduction.
So the following cases propagate:
sum is a reduction for all loops:
for i in prange(n):
for j in prange(n):
for k in prange(n):
sum += i * j * k
sum is a reduction for both loops, local_var is private to the
parallel with block:
for i in prange(n):
with parallel:
local_var = ... # private to the parallel
for j in prange(n):
sum += i * j
Nested with parallel blocks are disallowed, because they wouldn't
allow you to propagate lastprivates or reductions:
#pragma omp parallel for lastprivate(i)
for i in prange(n):
sum = 0
#pragma omp parallel private(j, sum)
with parallel:
#pragma omp parallel
with parallel:
#pragma omp for lastprivate(j) reduction(+:sum)
for j in prange(n):
sum += i
# sum and j are well-defined here
# sum and j are undefined here
# sum and j are undefined here
"""
self.privates[entry] = op
if self.is_prange:
if not self.is_parallel and entry not in self.parent.assignments:
# Parent is a parallel with block
parent = self.parent.parent
else:
parent = self.parent
if parent:
parent.propagate_var_privatization(entry, op)
def _allocate_closure_temp(self, code, entry):
"""
Helper function that allocate a temporary for a closure variable that
is assigned to.
"""
if self.parent:
return self.parent._allocate_closure_temp(code, entry)
if entry.cname in self.seen_closure_vars:
return entry.cname
cname = code.funcstate.allocate_temp(entry.type, False)
# Add both the actual cname and the temp cname, as the actual cname
# will be replaced with the temp cname on the entry
self.seen_closure_vars.add(entry.cname)
self.seen_closure_vars.add(cname)
self.modified_entries.append((entry, entry.cname))
code.putln("%s = %s;" % (cname, entry.cname))
entry.cname = cname
def declare_closure_privates(self, code):
"""
Set self.privates to a dict mapping C variable names that are to be
declared (first|last)private or reduction, to the reduction operator.
If the private is not a reduction, the operator is None.
This is used by subclasses.
If a variable is in a scope object, we need to allocate a temp and
assign the value from the temp to the variable in the scope object
after the parallel section. This kind of copying should be done only
in the outermost parallel section.
"""
self.modified_entries = []
for entry, (pos, op) in self.assignments.iteritems():
if entry.from_closure or entry.in_closure:
self._allocate_closure_temp(code, entry)
def release_closure_privates(self, code):
"""
Release any temps used for variables in scope objects. As this is the
outermost parallel block, we don't need to delete the cnames from
self.seen_closure_vars
"""
for entry, original_cname in self.modified_entries:
code.putln("%s = %s;" % (original_cname, entry.cname))
code.funcstate.release_temp(entry.cname)
entry.cname = original_cname
class ParallelWithBlockNode(ParallelStatNode):
"""
This node represents a 'with cython.parallel.parallel:' block
"""
nogil_check = None
def analyse_expressions(self, env):
super(ParallelWithBlockNode, self).analyse_expressions(env)
self.analyse_sharing_attributes(env)
def generate_execution_code(self, code):
self.declare_closure_privates(code)
code.putln("#ifdef _OPENMP")
code.put("#pragma omp parallel ")
if self.privates:
code.put(
'private(%s)' % ', '.join([e.cname for e in self.privates]))
self.privatization_insertion_point = code.insertion_point()
code.putln("")
code.putln("#endif /* _OPENMP */")
code.begin_block()
self.body.generate_execution_code(code)
code.end_block()
self.release_closure_privates(code)
class ParallelRangeNode(ParallelStatNode):
"""
This node represents a 'for i in cython.parallel.prange():' construct.
target NameNode the target iteration variable
else_clause Node or None the else clause of this loop
args tuple the arguments passed to prange()
kwargs DictNode the keyword arguments passed to prange()
(replaced by its compile time value)
is_nogil bool indicates whether this is a nogil prange() node
"""
child_attrs = ['body', 'target', 'else_clause', 'args']
body = target = else_clause = args = None
start = stop = step = None
is_prange = True
is_nogil = False
def analyse_declarations(self, env):
super(ParallelRangeNode, self).analyse_declarations(env)
self.target.analyse_target_declaration(env)
if self.else_clause is not None:
self.else_clause.analyse_declarations(env)
if not self.args or len(self.args) > 3:
error(self.pos, "Invalid number of positional arguments to prange")
return
if len(self.args) == 1:
self.stop, = self.args
elif len(self.args) == 2:
self.start, self.stop = self.args
else:
self.start, self.stop, self.step = self.args
if self.kwargs:
self.kwargs = self.kwargs.compile_time_value(env)
else:
self.kwargs = {}
self.is_nogil = self.kwargs.pop('nogil', False)
self.schedule = self.kwargs.pop('schedule', None)
if hasattr(self.schedule, 'decode'):
self.schedule = self.schedule.decode('ascii')
if self.schedule not in (None, 'static', 'dynamic', 'guided',
'runtime'):
error(self.pos, "Invalid schedule argument to prange: %s" %
(self.schedule,))
for kw in self.kwargs:
error(self.pos, "Invalid keyword argument to prange: %s" % kw)
def analyse_expressions(self, env):
if self.target is None:
error(self.pos, "prange() can only be used as part of a for loop")
return
self.target.analyse_target_types(env)
if not self.target.type.is_numeric:
# Not a valid type, assume one for now anyway
if not self.target.type.is_pyobject:
# nogil_check will catch the is_pyobject case
error(self.target.pos,
"Must be of numeric type, not %s" % self.target.type)
self.index_type = PyrexTypes.c_py_ssize_t_type
else:
self.index_type = self.target.type
# Setup start, stop and step, allocating temps if needed
self.names = 'start', 'stop', 'step'
start_stop_step = self.start, self.stop, self.step
for node, name in zip(start_stop_step, self.names):
if node is not None:
node.analyse_types(env)
if not node.type.is_numeric:
error(node.pos, "%s argument must be numeric or a pointer "
"(perhaps if a numeric literal is too "
"big, use 1000LL)" % name)
if not node.is_literal:
node = node.coerce_to_temp(env)
setattr(self, name, node)
# As we range from 0 to nsteps, computing the index along the
# way, we need a fitting type for 'i' and 'nsteps'
self.index_type = PyrexTypes.widest_numeric_type(
self.index_type, node.type)
super(ParallelRangeNode, self).analyse_expressions(env)
if self.else_clause is not None:
self.else_clause.analyse_expressions(env)
# Although not actually an assignment in this scope, it should be
# treated as such to ensure it is unpacked if a closure temp, and to
# ensure lastprivate behaviour and propagation. If the target index is
# not a NameNode, it won't have an entry, and an error was issued by
# ParallelRangeTransform
if hasattr(self.target, 'entry'):
self.assignments[self.target.entry] = self.target.pos, None
self.analyse_sharing_attributes(env)
def nogil_check(self, env):
names = 'start', 'stop', 'step', 'target'
nodes = self.start, self.stop, self.step, self.target
for name, node in zip(names, nodes):
if node is not None and node.type.is_pyobject:
error(node.pos, "%s may not be a Python object "
"as we don't have the GIL" % name)
def generate_execution_code(self, code):
"""
Generate code in the following steps
1) copy any closure variables determined thread-private
into temporaries
2) allocate temps for start, stop and step
3) generate a loop that calculates the total number of steps,
which then computes the target iteration variable for every step:
for i in prange(start, stop, step):
...
becomes
nsteps = (stop - start) / step;
i = start;
#pragma omp parallel for lastprivate(i)
for (temp = 0; temp < nsteps; temp++) {
i = start + step * temp;
...
}
Note that accumulation of 'i' would have a data dependency
between iterations.
Also, you can't do this
for (i = start; i < stop; i += step)
...
as the '<' operator should become '>' for descending loops.
'for i from x < i < y:' does not suffer from this problem
as the relational operator is known at compile time!
4) release our temps and write back any private closure variables
"""
self.declare_closure_privates(code)
# This can only be a NameNode
target_index_cname = self.target.entry.cname
# This will be used as the dict to format our code strings, holding
# the start, stop , step, temps and target cnames
fmt_dict = {
'target': target_index_cname,
}
# Setup start, stop and step, allocating temps if needed
start_stop_step = self.start, self.stop, self.step
defaults = '0', '0', '1'
for node, name, default in zip(start_stop_step, self.names, defaults):
if node is None:
result = default
elif node.is_literal:
result = node.get_constant_c_result_code()
else:
node.generate_evaluation_code(code)
result = node.result()
fmt_dict[name] = result
fmt_dict['i'] = code.funcstate.allocate_temp(self.index_type, False)
fmt_dict['nsteps'] = code.funcstate.allocate_temp(self.index_type, False)
# TODO: check if the step is 0 and if so, raise an exception in a
# 'with gil' block. For now, just abort
code.putln("if (%(step)s == 0) abort();" % fmt_dict)
# Note: nsteps is private in an outer scope if present
code.putln("%(nsteps)s = (%(stop)s - %(start)s) / %(step)s;" % fmt_dict)
# The target iteration variable might not be initialized, do it only if
# we are executing at least 1 iteration, otherwise we should leave the
# target unaffected. The target iteration variable is firstprivate to
# shut up compiler warnings caused by lastprivate, as the compiler
# erroneously believes that nsteps may be <= 0, leaving the private
# target index uninitialized
code.putln("if (%(nsteps)s > 0)" % fmt_dict)
code.begin_block()
code.putln("%(target)s = 0;" % fmt_dict)
self.generate_loop(code, fmt_dict)
code.end_block()
# And finally, release our privates and write back any closure
# variables
for temp in start_stop_step:
if temp is not None:
temp.generate_disposal_code(code)
temp.free_temps(code)
code.funcstate.release_temp(fmt_dict['i'])
code.funcstate.release_temp(fmt_dict['nsteps'])
self.release_closure_privates(code)
def generate_loop(self, code, fmt_dict):
code.putln("#ifdef _OPENMP")
if not self.is_parallel:
code.put("#pragma omp for")
else:
code.put("#pragma omp parallel for")
for entry, op in self.privates.iteritems():
# Don't declare the index variable as a reduction
if op and op in "+*-&^|" and entry != self.target.entry:
code.put(" reduction(%s:%s)" % (op, entry.cname))
else:
if entry == self.target.entry:
code.put(" firstprivate(%s)" % entry.cname)
code.put(" lastprivate(%s)" % entry.cname)
if self.schedule:
code.put(" schedule(%s)" % self.schedule)
if self.parent:
c = self.parent.privatization_insertion_point
c.put(" private(%(nsteps)s)" % fmt_dict)
self.privatization_insertion_point = code.insertion_point()
code.putln("")
code.putln("#endif /* _OPENMP */")
code.put("for (%(i)s = 0; %(i)s < %(nsteps)s; %(i)s++)" % fmt_dict)
code.begin_block()
code.putln("%(target)s = %(start)s + %(step)s * %(i)s;" % fmt_dict)
self.body.generate_execution_code(code)
code.end_block()
#------------------------------------------------------------------------------------ #------------------------------------------------------------------------------------
# #
......
...@@ -612,9 +612,16 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -612,9 +612,16 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
} }
special_methods = cython.set(['declare', 'union', 'struct', 'typedef', 'sizeof', special_methods = cython.set(['declare', 'union', 'struct', 'typedef', 'sizeof',
'cast', 'pointer', 'compiled', 'NULL']) 'cast', 'pointer', 'compiled', 'NULL', 'parallel'])
special_methods.update(unop_method_nodes.keys()) special_methods.update(unop_method_nodes.keys())
valid_parallel_directives = cython.set([
"parallel",
"prange",
"threadid",
# "threadsavailable",
])
def __init__(self, context, compilation_directive_defaults): def __init__(self, context, compilation_directive_defaults):
super(InterpretCompilerDirectives, self).__init__(context) super(InterpretCompilerDirectives, self).__init__(context)
self.compilation_directive_defaults = {} self.compilation_directive_defaults = {}
...@@ -622,6 +629,7 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -622,6 +629,7 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
self.compilation_directive_defaults[unicode(key)] = copy.deepcopy(value) self.compilation_directive_defaults[unicode(key)] = copy.deepcopy(value)
self.cython_module_names = cython.set() self.cython_module_names = cython.set()
self.directive_names = {} self.directive_names = {}
self.parallel_directives = {}
def check_directive_scope(self, pos, directive, scope): def check_directive_scope(self, pos, directive, scope):
legal_scopes = Options.directive_scopes.get(directive, None) legal_scopes = Options.directive_scopes.get(directive, None)
...@@ -644,6 +652,7 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -644,6 +652,7 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
directives.update(node.directive_comments) directives.update(node.directive_comments)
self.directives = directives self.directives = directives
node.directives = directives node.directives = directives
node.parallel_directives = self.parallel_directives
self.visitchildren(node) self.visitchildren(node)
node.cython_module_names = self.cython_module_names node.cython_module_names = self.cython_module_names
return node return node
...@@ -655,11 +664,35 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -655,11 +664,35 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
name in self.special_methods or name in self.special_methods or
PyrexTypes.parse_basic_type(name)) PyrexTypes.parse_basic_type(name))
def is_parallel_directive(self, full_name, pos):
result = (full_name + ".").startswith("cython.parallel.")
if result:
directive = full_name.rsplit('.', 1)
if len(directive) == 2 and directive[1] == '*':
# star import
for name in self.valid_parallel_directives:
self.parallel_directives[name] = u"cython.parallel.%s" % name
elif (len(directive) != 2 or
directive[1] not in self.valid_parallel_directives):
error(pos, "No such directive: %s" % full_name)
return result
def visit_CImportStatNode(self, node): def visit_CImportStatNode(self, node):
if node.module_name == u"cython": if node.module_name == u"cython":
self.cython_module_names.add(node.as_name or u"cython") self.cython_module_names.add(node.as_name or u"cython")
elif node.module_name.startswith(u"cython."): elif node.module_name.startswith(u"cython."):
if node.as_name: if node.module_name.startswith(u"cython.parallel."):
error(node.pos, node.module_name + " is not a module")
if node.module_name == u"cython.parallel":
if node.as_name and node.as_name != u"cython":
self.parallel_directives[node.as_name] = node.module_name
else:
self.cython_module_names.add(u"cython")
self.parallel_directives[
u"cython.parallel"] = node.module_name
elif node.as_name:
self.directive_names[node.as_name] = node.module_name[7:] self.directive_names[node.as_name] = node.module_name[7:]
else: else:
self.cython_module_names.add(u"cython") self.cython_module_names.add(u"cython")
...@@ -673,19 +706,29 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -673,19 +706,29 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
node.module_name.startswith(u"cython."): node.module_name.startswith(u"cython."):
submodule = (node.module_name + u".")[7:] submodule = (node.module_name + u".")[7:]
newimp = [] newimp = []
for pos, name, as_name, kind in node.imported_names: for pos, name, as_name, kind in node.imported_names:
full_name = submodule + name full_name = submodule + name
if self.is_cython_directive(full_name): qualified_name = u"cython." + full_name
if self.is_parallel_directive(qualified_name, node.pos):
# from cython cimport parallel, or
# from cython.parallel cimport parallel, prange, ...
self.parallel_directives[as_name or name] = qualified_name
elif self.is_cython_directive(full_name):
if as_name is None: if as_name is None:
as_name = full_name as_name = full_name
self.directive_names[as_name] = full_name self.directive_names[as_name] = full_name
if kind is not None: if kind is not None:
self.context.nonfatal_error(PostParseError(pos, self.context.nonfatal_error(PostParseError(pos,
"Compiler directive imports must be plain imports")) "Compiler directive imports must be plain imports"))
else: else:
newimp.append((pos, name, as_name, kind)) newimp.append((pos, name, as_name, kind))
if not newimp: if not newimp:
return None return None
node.imported_names = newimp node.imported_names = newimp
return node return node
...@@ -696,7 +739,10 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -696,7 +739,10 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
newimp = [] newimp = []
for name, name_node in node.items: for name, name_node in node.items:
full_name = submodule + name full_name = submodule + name
if self.is_cython_directive(full_name): qualified_name = u"cython." + full_name
if self.is_parallel_directive(qualified_name, node.pos):
self.parallel_directives[name_node.name] = qualified_name
elif self.is_cython_directive(full_name):
self.directive_names[name_node.name] = full_name self.directive_names[name_node.name] = full_name
else: else:
newimp.append((name, name_node)) newimp.append((name, name_node))
...@@ -706,14 +752,23 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -706,14 +752,23 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
return node return node
def visit_SingleAssignmentNode(self, node): def visit_SingleAssignmentNode(self, node):
if (isinstance(node.rhs, ExprNodes.ImportNode) and if isinstance(node.rhs, ExprNodes.ImportNode):
node.rhs.module_name.value == u'cython'): module_name = node.rhs.module_name.value
is_parallel = (module_name + u".").startswith(u"cython.parallel.")
if module_name != u"cython" and not is_parallel:
return node
module_name = node.rhs.module_name.value
as_name = node.lhs.name
node = Nodes.CImportStatNode(node.pos, node = Nodes.CImportStatNode(node.pos,
module_name = u'cython', module_name = module_name,
as_name = node.lhs.name) as_name = as_name)
self.visit_CImportStatNode(node) node = self.visit_CImportStatNode(node)
else: else:
self.visitchildren(node) self.visitchildren(node)
return node return node
def visit_NameNode(self, node): def visit_NameNode(self, node):
...@@ -897,6 +952,202 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -897,6 +952,202 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
return self.visit_with_directives(node.body, directive_dict) return self.visit_with_directives(node.body, directive_dict)
return self.visit_Node(node) return self.visit_Node(node)
class ParallelRangeTransform(CythonTransform, SkipDeclarations):
"""
Transform cython.parallel stuff. The parallel_directives come from the
module node, set there by InterpretCompilerDirectives.
x = cython.parallel.threadavailable() -> ParallelThreadAvailableNode
with nogil, cython.parallel.parallel: -> ParallelWithBlockNode
print cython.parallel.threadid() -> ParallelThreadIdNode
for i in cython.parallel.prange(...): -> ParallelRangeNode
...
"""
# a list of names, maps 'cython.parallel.prange' in the code to
# ['cython', 'parallel', 'prange']
parallel_directive = None
# Indicates whether a namenode in an expression is the cython module
namenode_is_cython_module = False
# Keep track of whether we are the context manager of a 'with' statement
in_context_manager_section = False
# Keep track of whether we are in a parallel range section
in_prange = False
# One of 'prange' or 'with parallel'. This is used to disallow closely
# nested 'with parallel:' blocks
state = None
directive_to_node = {
u"cython.parallel.parallel": Nodes.ParallelWithBlockNode,
# u"cython.parallel.threadsavailable": ExprNodes.ParallelThreadsAvailableNode,
u"cython.parallel.threadid": ExprNodes.ParallelThreadIdNode,
u"cython.parallel.prange": Nodes.ParallelRangeNode,
}
def node_is_parallel_directive(self, node):
return node.name in self.parallel_directives or node.is_cython_module
def get_directive_class_node(self, node):
"""
Figure out which parallel directive was used and return the associated
Node class.
E.g. for a cython.parallel.prange() call we return ParallelRangeNode
Also disallow break, continue and return in a prange section
"""
if self.namenode_is_cython_module:
directive = '.'.join(self.parallel_directive)
else:
directive = self.parallel_directives[self.parallel_directive[0]]
directive = '%s.%s' % (directive,
'.'.join(self.parallel_directive[1:]))
directive = directive.rstrip('.')
cls = self.directive_to_node.get(directive)
if cls is None:
error(node.pos, "Invalid directive: %s" % directive)
self.namenode_is_cython_module = False
self.parallel_directive = None
return cls
def visit_ModuleNode(self, node):
"""
If any parallel directives were imported, copy them over and visit
the AST
"""
if node.parallel_directives:
self.parallel_directives = node.parallel_directives
self.assignment_stack = []
return self.visit_Node(node)
# No parallel directives were imported, so they can't be used :)
return node
def visit_NameNode(self, node):
if self.node_is_parallel_directive(node):
self.parallel_directive = [node.name]
self.namenode_is_cython_module = node.is_cython_module
return node
def visit_AttributeNode(self, node):
self.visitchildren(node)
if self.parallel_directive:
self.parallel_directive.append(node.attribute)
return node
def visit_CallNode(self, node):
self.visit(node.function)
if not self.parallel_directive:
return node
# We are a parallel directive, replace this node with the
# corresponding ParallelSomethingSomething node
if isinstance(node, ExprNodes.GeneralCallNode):
args = node.positional_args.args
kwargs = node.keyword_args
else:
args = node.args
kwargs = {}
parallel_directive_class = self.get_directive_class_node(node)
if parallel_directive_class:
node = parallel_directive_class(node.pos, args=args, kwargs=kwargs)
return node
def visit_WithStatNode(self, node):
"Rewrite with cython.parallel() blocks"
self.visit(node.manager)
if self.parallel_directive:
parallel_directive_class = self.get_directive_class_node(node)
if not parallel_directive_class:
# There was an error, stop here and now
return None
if self.state == 'parallel with':
error(node.manager.pos,
"Closely nested 'with parallel:' blocks are disallowed")
self.state = 'parallel with'
self.visit(node.body)
self.state = None
newnode = Nodes.ParallelWithBlockNode(node.pos, body=node.body)
else:
newnode = node
self.visit(node.body)
return newnode
def visit_ForInStatNode(self, node):
"Rewrite 'for i in cython.parallel.prange(...):'"
self.visit(node.iterator)
self.visit(node.target)
was_in_prange = self.in_prange
self.in_prange = isinstance(node.iterator.sequence,
Nodes.ParallelRangeNode)
previous_state = self.state
if self.in_prange:
# This will replace the entire ForInStatNode, so copy the
# attributes
parallel_range_node = node.iterator.sequence
parallel_range_node.target = node.target
parallel_range_node.body = node.body
parallel_range_node.else_clause = node.else_clause
node = parallel_range_node
if not isinstance(node.target, ExprNodes.NameNode):
error(node.target.pos,
"Can only iterate over an iteration variable")
self.state = 'prange'
self.visit(node.body)
self.state = previous_state
self.in_prange = was_in_prange
self.visit(node.else_clause)
return node
def ensure_not_in_prange(name):
"Creates error checking functions for break, continue and return"
def visit_method(self, node):
if self.in_prange:
error(node.pos,
name + " not allowed in a parallel range section")
# Do a visit for 'return'
self.visitchildren(node)
return node
return visit_method
visit_BreakStatNode = ensure_not_in_prange("break")
visit_ContinueStatNode = ensure_not_in_prange("continue")
visit_ReturnStatNode = ensure_not_in_prange("return")
def visit(self, node):
"Visit a node that may be None"
if node is not None:
super(ParallelRangeTransform, self).visit(node)
class WithTransform(CythonTransform, SkipDeclarations): class WithTransform(CythonTransform, SkipDeclarations):
def visit_WithStatNode(self, node): def visit_WithStatNode(self, node):
self.visitchildren(node, 'body') self.visitchildren(node, 'body')
...@@ -1678,22 +1929,54 @@ class GilCheck(VisitorTransform): ...@@ -1678,22 +1929,54 @@ class GilCheck(VisitorTransform):
self.env_stack.append(node.local_scope) self.env_stack.append(node.local_scope)
was_nogil = self.nogil was_nogil = self.nogil
self.nogil = node.local_scope.nogil self.nogil = node.local_scope.nogil
if self.nogil and node.nogil_check: if self.nogil and node.nogil_check:
node.nogil_check(node.local_scope) node.nogil_check(node.local_scope)
self.visitchildren(node) self.visitchildren(node)
self.env_stack.pop() self.env_stack.pop()
self.nogil = was_nogil self.nogil = was_nogil
return node return node
def visit_GILStatNode(self, node): def visit_GILStatNode(self, node):
env = self.env_stack[-1] if self.nogil and node.nogil_check:
if self.nogil and node.nogil_check: node.nogil_check() node.nogil_check()
was_nogil = self.nogil was_nogil = self.nogil
self.nogil = (node.state == 'nogil') self.nogil = (node.state == 'nogil')
self.visitchildren(node) self.visitchildren(node)
self.nogil = was_nogil self.nogil = was_nogil
return node return node
def visit_ParallelRangeNode(self, node):
if node.is_nogil:
node.is_nogil = False
node = Nodes.GILStatNode(node.pos, state='nogil', body=node)
return self.visit_GILStatNode(node)
if not self.nogil:
error(node.pos, "prange() can only be used without the GIL")
# Forget about any GIL-related errors that may occur in the body
return None
node.nogil_check(self.env_stack[-1])
self.visitchildren(node)
return node
def visit_ParallelWithBlockNode(self, node):
if not self.nogil:
error(node.pos, "The parallel section may only be used without "
"the GIL")
return None
if node.nogil_check:
# It does not currently implement this, but test for it anyway to
# avoid potential future surprises
node.nogil_check(self.env_stack[-1])
self.visitchildren(node)
return node
def visit_Node(self, node): def visit_Node(self, node):
if self.env_stack and self.nogil and node.nogil_check: if self.env_stack and self.nogil and node.nogil_check:
node.nogil_check(self.env_stack[-1]) node.nogil_check(self.env_stack[-1])
...@@ -1820,8 +2103,7 @@ class TransformBuiltinMethods(EnvTransform): ...@@ -1820,8 +2103,7 @@ class TransformBuiltinMethods(EnvTransform):
class DebugTransform(CythonTransform): class DebugTransform(CythonTransform):
""" """
Create debug information and all functions' visibility to extern in order Write debug information for this Cython module.
to enable debugging.
""" """
def __init__(self, context, options, result): def __init__(self, context, options, result):
......
...@@ -723,6 +723,9 @@ class Scope(object): ...@@ -723,6 +723,9 @@ class Scope(object):
else: else:
return outer.is_cpp() return outer.is_cpp()
def add_include_file(self, filename):
self.outer_scope.add_include_file(filename)
class PreImportScope(Scope): class PreImportScope(Scope):
namespace_cname = Naming.preimport_cname namespace_cname = Naming.preimport_cname
...@@ -1856,8 +1859,6 @@ class CppClassScope(Scope): ...@@ -1856,8 +1859,6 @@ class CppClassScope(Scope):
utility_code = e.utility_code) utility_code = e.utility_code)
return scope return scope
def add_include_file(self, filename):
self.outer_scope.add_include_file(filename)
class PropertyScope(Scope): class PropertyScope(Scope):
# Scope holding the __get__, __set__ and __del__ methods for # Scope holding the __get__, __set__ and __del__ methods for
......
...@@ -4,6 +4,7 @@ from Cython.Compiler import CmdLine ...@@ -4,6 +4,7 @@ from Cython.Compiler import CmdLine
from Cython.TestUtils import TransformTest from Cython.TestUtils import TransformTest
from Cython.Compiler.ParseTreeTransforms import * from Cython.Compiler.ParseTreeTransforms import *
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
from Cython.Compiler import Main
class TestNormalizeTree(TransformTest): class TestNormalizeTree(TransformTest):
...@@ -144,6 +145,62 @@ class TestWithTransform(object): # (TransformTest): # Disabled! ...@@ -144,6 +145,62 @@ class TestWithTransform(object): # (TransformTest): # Disabled!
""", t) """, t)
class TestInterpretCompilerDirectives(TransformTest):
"""
This class tests the parallel directives AST-rewriting and importing.
"""
# Test the parallel directives (c)importing
import_code = u"""
cimport cython.parallel
cimport cython.parallel as par
from cython cimport parallel as par2
from cython cimport parallel
from cython.parallel cimport threadid as tid
from cython.parallel cimport threadavailable as tavail
from cython.parallel cimport prange
"""
expected_directives_dict = {
u'cython.parallel': u'cython.parallel',
u'par': u'cython.parallel',
u'par2': u'cython.parallel',
u'parallel': u'cython.parallel',
u"tid": u"cython.parallel.threadid",
u"tavail": u"cython.parallel.threadavailable",
u"prange": u"cython.parallel.prange",
}
def setUp(self):
super(TestInterpretCompilerDirectives, self).setUp()
compilation_options = Main.CompilationOptions(Main.default_options)
ctx = compilation_options.create_context()
self.pipeline = [
InterpretCompilerDirectives(ctx, ctx.compiler_directives),
]
self.debug_exception_on_error = DebugFlags.debug_exception_on_error
def tearDown(self):
DebugFlags.debug_exception_on_error = self.debug_exception_on_error
def test_parallel_directives_cimports(self):
self.run_pipeline(self.pipeline, self.import_code)
parallel_directives = self.pipeline[0].parallel_directives
self.assertEqual(parallel_directives, self.expected_directives_dict)
def test_parallel_directives_imports(self):
self.run_pipeline(self.pipeline,
self.import_code.replace(u'cimport', u'import'))
parallel_directives = self.pipeline[0].parallel_directives
self.assertEqual(parallel_directives, self.expected_directives_dict)
# TODO: Re-enable once they're more robust. # TODO: Re-enable once they're more robust.
if sys.version_info[:2] >= (2, 5) and False: if sys.version_info[:2] >= (2, 5) and False:
from Cython.Debugger import DebugWriter from Cython.Debugger import DebugWriter
......
...@@ -23,12 +23,24 @@ object_expr = TypedExprNode(py_object_type) ...@@ -23,12 +23,24 @@ object_expr = TypedExprNode(py_object_type)
class MarkAssignments(CythonTransform): class MarkAssignments(CythonTransform):
def mark_assignment(self, lhs, rhs): def __init__(self, context):
super(CythonTransform, self).__init__()
self.context = context
# Track the parallel block scopes (with parallel, for i in prange())
self.parallel_block_stack = []
def mark_assignment(self, lhs, rhs, inplace_op=None):
if isinstance(lhs, (ExprNodes.NameNode, Nodes.PyArgDeclNode)): if isinstance(lhs, (ExprNodes.NameNode, Nodes.PyArgDeclNode)):
if lhs.entry is None: if lhs.entry is None:
# TODO: This shouldn't happen... # TODO: This shouldn't happen...
return return
lhs.entry.assignments.append(rhs) lhs.entry.assignments.append(rhs)
if self.parallel_block_stack:
parallel_node = self.parallel_block_stack[-1]
parallel_node.assignments[lhs.entry] = (lhs.pos, inplace_op)
elif isinstance(lhs, ExprNodes.SequenceNode): elif isinstance(lhs, ExprNodes.SequenceNode):
for arg in lhs.args: for arg in lhs.args:
self.mark_assignment(arg, object_expr) self.mark_assignment(arg, object_expr)
...@@ -48,7 +60,7 @@ class MarkAssignments(CythonTransform): ...@@ -48,7 +60,7 @@ class MarkAssignments(CythonTransform):
return node return node
def visit_InPlaceAssignmentNode(self, node): def visit_InPlaceAssignmentNode(self, node):
self.mark_assignment(node.lhs, node.create_binop_node()) self.mark_assignment(node.lhs, node.create_binop_node(), node.operator)
self.visitchildren(node) self.visitchildren(node)
return node return node
...@@ -127,6 +139,27 @@ class MarkAssignments(CythonTransform): ...@@ -127,6 +139,27 @@ class MarkAssignments(CythonTransform):
self.visitchildren(node) self.visitchildren(node)
return node return node
def visit_ParallelStatNode(self, node):
if self.parallel_block_stack:
node.parent = self.parallel_block_stack[-1]
else:
node.parent = None
if node.is_prange:
if not node.parent:
node.is_parallel = True
else:
node.is_parallel = (node.parent.is_prange or not
node.parent.is_parallel)
else:
node.is_parallel = True
self.parallel_block_stack.append(node)
self.visitchildren(node)
self.parallel_block_stack.pop()
return node
class MarkOverflowingArithmetic(CythonTransform): class MarkOverflowingArithmetic(CythonTransform):
# It may be possible to integrate this with the above for # It may be possible to integrate this with the above for
......
cdef extern from "omp.h":
ctypedef struct omp_lock_t
ctypedef struct omp_nest_lock_t
ctypedef enum omp_sched_t:
omp_sched_static = 1,
omp_sched_dynamic = 2,
omp_sched_guided = 3,
omp_sched_auto = 4
extern void omp_set_num_threads(int)
extern int omp_get_num_threads()
extern int omp_get_max_threads()
extern int omp_get_thread_num()
extern int omp_get_num_procs()
extern int omp_in_parallel()
extern void omp_set_dynamic(int)
extern int omp_get_dynamic()
extern void omp_set_nested(int)
extern int omp_get_nested()
extern void omp_init_lock(omp_lock_t *)
extern void omp_destroy_lock(omp_lock_t *)
extern void omp_set_lock(omp_lock_t *)
extern void omp_unset_lock(omp_lock_t *)
extern int omp_test_lock(omp_lock_t *)
extern void omp_init_nest_lock(omp_nest_lock_t *)
extern void omp_destroy_nest_lock(omp_nest_lock_t *)
extern void omp_set_nest_lock(omp_nest_lock_t *)
extern void omp_unset_nest_lock(omp_nest_lock_t *)
extern int omp_test_nest_lock(omp_nest_lock_t *)
extern double omp_get_wtime()
extern double omp_get_wtick()
void omp_set_schedule(omp_sched_t, int)
void omp_get_schedule(omp_sched_t *, int *)
int omp_get_thread_limit()
void omp_set_max_active_levels(int)
int omp_get_max_active_levels()
int omp_get_level()
int omp_get_ancestor_thread_num(int)
int omp_get_team_size(int)
int omp_get_active_level()
...@@ -277,3 +277,28 @@ for t in int_types + float_types + complex_types + other_types: ...@@ -277,3 +277,28 @@ for t in int_types + float_types + complex_types + other_types:
void = typedef(None) void = typedef(None)
NULL = p_void(0) NULL = p_void(0)
class CythonDotParallel(object):
"""
The cython.parallel module.
"""
__all__ = ['parallel', 'prange', 'threadid']
parallel = nogil
def prange(self, start=0, stop=None, step=1, schedule=None, nogil=False):
if stop is None:
stop = start
start = 0
return range(start, stop, step)
def threadid(self):
return 0
# def threadsavailable(self):
# return 1
import sys
sys.modules['cython.parallel'] = CythonDotParallel()
del sys
\ No newline at end of file
...@@ -18,6 +18,7 @@ Contents: ...@@ -18,6 +18,7 @@ Contents:
limitations limitations
pyrex_differences pyrex_differences
early_binding_for_speed early_binding_for_speed
parallelism
debugging debugging
Indices and tables Indices and tables
......
.. highlight:: cython
.. py:module:: cython.parallel
**********************************
Using Parallelism
**********************************
Cython supports native parallelism through the :py:mod:`cython.parallel`
module. To use this kind of parallelism, the GIL must be released. It
currently supports OpenMP, but later on more backends might be supported.
.. function:: prange([start,] stop[, step], nogil=False, schedule=None)
This function can be used for parallel loops. OpenMP automatically
starts a thread pool and distributes the work according to the schedule
used. ``step`` must not be 0. This function can only be used with the
GIL released. If ``nogil`` is true, the loop will be wrapped in a nogil
section.
Thread-locality and reductions are automatically inferred for variables.
If you assign to a variable, it becomes lastprivate, meaning that the
variable will contain the value from the last iteration. If you use an
inplace operator on a variable, it becomes a reduction, meaning that the
values from the thread-local copies of the variable will be reduced with
the operator and assigned to the original variable after the loop. The
index variable is always lastprivate.
The ``schedule`` is passed to OpenMP and can be one of the following:
+-----------------+------------------------------------------------------+
| Schedule | Description |
+=================+======================================================+
|static | The iteration space is divided into chunks that are |
| | approximately equal in size, and at most one chunk |
| | is distributed to each thread. |
+-----------------+------------------------------------------------------+
|dynamic | The iterations are distributed to threads in the team|
| | as the threads request them, with a chunk size of 1. |
+-----------------+------------------------------------------------------+
|guided | The iterations are distributed to threads in the team|
| | as the threads request them. The size of each chunk |
| | is proportional to the number of unassigned |
| | iterations divided by the number of threads in the |
| | team, decreasing to 1. |
+-----------------+------------------------------------------------------+
|auto | The decision regarding scheduling is delegated to the|
| | compiler and/or runtime system. The programmer gives |
| | the implementation the freedom to choose any possible|
| | mapping of iterations to threads in the team. |
+-----------------+------------------------------------------------------+
|runtime | The schedule and chunk size are taken from the |
| | runtime-scheduling-variable, which can be set through|
| | the ``omp_set_schedule`` function call, or the |
| | ``OMP_SCHEDULE`` environment variable. |
+-----------------+------------------------------------------------------+
The default schedule is implementation defined. For more information consult
the OpenMP specification: [#]_.
Example with a reduction::
from cython.parallel import prange, parallel, threadid
cdef int i
cdef int sum = 0
for i in prange(n, nogil=True):
sum += i
print sum
Example with a shared numpy array::
from cython.parallel import *
def func(np.ndarray[double] x, double alpha):
cdef Py_ssize_t i
for i in prange(x.shape[0]):
x[i] = alpha * x[i]
.. function:: parallel
This directive can be used as part of a ``with`` statement to execute code
sequences in parallel. This is currently useful to setup thread-local
buffers used by a prange. A contained prange will be a worksharing loop
that is not parallel, so any variable assigned to in the parallel section
is also private to the prange. Variables that are private in the parallel
construct are undefined after the parallel block.
Example with thread-local buffers::
from cython.parallel import *
from cython.stdlib cimport abort
cdef Py_ssize_t i, n = 100
cdef int * local_buf
cdef size_t size = 10
with nogil, parallel:
local_buf = malloc(sizeof(int) * size)
if local_buf == NULL:
abort()
# populate our local buffer in a sequential loop
for i in range(size):
local_buf[i] = i * 2
# share the work using the thread-local buffer(s)
for i in prange(n, schedule='guided'):
func(local_buf)
free(local_buf)
Later on sections might be supported in parallel blocks, to distribute
code sections of work among threads.
.. function:: threadid()
Returns the id of the thread. For n threads, the ids will range from 0 to
n.
Compiling
=========
To actually use the OpenMP support, you need to tell the C or C++ compiler to
enable OpenMP. For gcc this can be done as follows in a setup.py::
from distutils.core import setup
from distutils.extension import Extension
from Cython.Distutils import build_ext
ext_module = Extension(
"hello",
["hello.pyx"],
extra_compile_args=['-fopenmp'],
libraries=['gomp'],
)
setup(
name = 'Hello world app',
cmdclass = {'build_ext': build_ext},
ext_modules = [ext_module],
)
.. rubric:: References
.. [#] http://www.openmp.org/mp-documents/spec30.pdf
...@@ -4,6 +4,7 @@ import os ...@@ -4,6 +4,7 @@ import os
import sys import sys
import re import re
import gc import gc
import locale
import codecs import codecs
import shutil import shutil
import time import time
...@@ -11,6 +12,7 @@ import unittest ...@@ -11,6 +12,7 @@ import unittest
import doctest import doctest
import operator import operator
import tempfile import tempfile
import warnings
import traceback import traceback
try: try:
from StringIO import StringIO from StringIO import StringIO
...@@ -54,6 +56,7 @@ CY3_DIR = None ...@@ -54,6 +56,7 @@ CY3_DIR = None
from distutils.dist import Distribution from distutils.dist import Distribution
from distutils.core import Extension from distutils.core import Extension
from distutils.command.build_ext import build_ext as _build_ext from distutils.command.build_ext import build_ext as _build_ext
from distutils import sysconfig
distutils_distro = Distribution() distutils_distro = Distribution()
if sys.platform == 'win32': if sys.platform == 'win32':
...@@ -78,8 +81,83 @@ def update_numpy_extension(ext): ...@@ -78,8 +81,83 @@ def update_numpy_extension(ext):
import numpy import numpy
ext.include_dirs.append(numpy.get_include()) ext.include_dirs.append(numpy.get_include())
def update_openmp_extension(ext):
language = ext.language
if language == 'cpp':
flags = OPENMP_CPP_COMPILER_FLAGS
else:
flags = OPENMP_C_COMPILER_FLAGS
if flags:
compile_flags, link_flags = flags
ext.extra_compile_args.extend(compile_flags.split())
ext.extra_link_args.extend(link_flags.split())
return ext
return EXCLUDE_EXT
def get_openmp_compiler_flags(language):
"""
As of gcc 4.2, it supports OpenMP 2.5. Gcc 4.4 implements 3.0. We don't
(currently) check for other compilers.
returns a two-tuple of (CFLAGS, LDFLAGS) to build the OpenMP extension
"""
if language == 'cpp':
cc = sysconfig.get_config_var('CXX')
else:
cc = sysconfig.get_config_var('CC')
# For some reason, cc can be e.g. 'gcc -pthread'
cc = cc.split()[0]
matcher = re.compile(r"gcc version (\d+\.\d+)").search
try:
import subprocess
except ImportError:
try:
in_, out, err = os.popen(cc + " -v")
except EnvironmentError:
# Be compatible with Python 3
_, e, _ = sys.exc_info()
warnings.warn("Unable to find the %s compiler: %s: %s" %
(language, os.strerror(e.errno), cc))
return None
output = out.read() or err.read()
else:
try:
p = subprocess.Popen([cc, "-v"], stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
except EnvironmentError:
# Be compatible with Python 3
_, e, _ = sys.exc_info()
warnings.warn("Unable to find the %s compiler: %s: %s" %
(language, os.strerror(e.errno), cc))
return None
output = p.stdout.read()
output = output.decode(locale.getpreferredencoding() or 'UTF-8')
compiler_version = matcher(output).group(1)
if compiler_version and compiler_version.split('.') >= ['4', '2']:
return '-fopenmp', '-fopenmp'
locale.setlocale(locale.LC_ALL, '')
OPENMP_C_COMPILER_FLAGS = get_openmp_compiler_flags('c')
OPENMP_CPP_COMPILER_FLAGS = get_openmp_compiler_flags('cpp')
# Return this from the EXT_EXTRAS matcher callback to exclude the extension
EXCLUDE_EXT = object()
EXT_EXTRAS = { EXT_EXTRAS = {
'tag:numpy' : update_numpy_extension, 'tag:numpy' : update_numpy_extension,
'tag:openmp': update_openmp_extension,
} }
# TODO: use tags # TODO: use tags
...@@ -519,13 +597,21 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -519,13 +597,21 @@ class CythonCompileTestCase(unittest.TestCase):
extra_compile_args = ext_compile_flags, extra_compile_args = ext_compile_flags,
**extra_extension_args **extra_extension_args
) )
if self.language == 'cpp':
# Set the language now as the fixer might need it
extension.language = 'c++'
for matcher, fixer in EXT_EXTRAS.items(): for matcher, fixer in EXT_EXTRAS.items():
if isinstance(matcher, str): if isinstance(matcher, str):
del EXT_EXTRAS[matcher] del EXT_EXTRAS[matcher]
matcher = string_selector(matcher) matcher = string_selector(matcher)
EXT_EXTRAS[matcher] = fixer EXT_EXTRAS[matcher] = fixer
if matcher(module, tags): if matcher(module, tags):
extension = fixer(extension) or extension newext = fixer(extension)
if newext is EXCLUDE_EXT:
return
extension = newext or extension
if self.language == 'cpp': if self.language == 'cpp':
extension.language = 'c++' extension.language = 'c++'
build_extension.extensions = [extension] build_extension.extensions = [extension]
...@@ -646,6 +732,7 @@ def run_forked_test(result, run_func, test_name, fork=True): ...@@ -646,6 +732,7 @@ def run_forked_test(result, run_func, test_name, fork=True):
try: try:
cid, result_code = os.waitpid(child_id, 0) cid, result_code = os.waitpid(child_id, 0)
module_name = test_name.split()[-1]
# os.waitpid returns the child's result code in the # os.waitpid returns the child's result code in the
# upper byte of result_code, and the signal it was # upper byte of result_code, and the signal it was
# killed by in the lower byte # killed by in the lower byte
......
# mode: error
cimport cython.parallel.parallel as p
from cython.parallel cimport something
import cython.parallel.parallel as p
from cython.parallel import something
from cython.parallel cimport prange
import cython.parallel
prange(1, 2, 3, schedule='dynamic')
cdef int i
with nogil, cython.parallel.parallel:
for i in prange(10, schedule='invalid_schedule'):
pass
with cython.parallel.parallel:
print "hello world!"
cdef int *x = NULL
with nogil, cython.parallel.parallel:
for j in prange(10):
pass
for x[1] in prange(10):
pass
for x in prange(10):
pass
with cython.parallel.parallel:
pass
_ERRORS = u"""
e_cython_parallel.pyx:3:8: cython.parallel.parallel is not a module
e_cython_parallel.pyx:4:0: No such directive: cython.parallel.something
e_cython_parallel.pyx:6:7: cython.parallel.parallel is not a module
e_cython_parallel.pyx:7:0: No such directive: cython.parallel.something
e_cython_parallel.pyx:13:6: prange() can only be used as part of a for loop
e_cython_parallel.pyx:13:6: prange() can only be used without the GIL
e_cython_parallel.pyx:18:19: Invalid schedule argument to prange: invalid_schedule
e_cython_parallel.pyx:21:5: The parallel section may only be used without the GIL
e_cython_parallel.pyx:27:10: target may not be a Python object as we don't have the GIL
e_cython_parallel.pyx:30:9: Can only iterate over an iteration variable
e_cython_parallel.pyx:33:10: Must be of numeric type, not int *
e_cython_parallel.pyx:36:24: Closely nested 'with parallel:' blocks are disallowed
"""
# tag: numpy
# tag: openmp
cimport cython
from cython.parallel import prange
cimport numpy as np
@cython.boundscheck(False)
def test_parallel_numpy_arrays():
"""
>>> test_parallel_numpy_arrays()
-5
-4
-3
-2
-1
0
1
2
3
4
"""
cdef Py_ssize_t i
cdef np.ndarray[np.int_t] x
try:
import numpy
except ImportError:
for i in range(-5, 5):
print i
return
x = numpy.zeros(10, dtype=numpy.int)
for i in prange(x.shape[0], nogil=True):
x[i] = i - 5
for i in x:
print i
# tag: run
# tag: openmp
cimport cython.parallel
from cython.parallel import prange, threadid
cimport openmp
from libc.stdlib cimport malloc, free
def test_parallel():
"""
>>> test_parallel()
"""
cdef int maxthreads = openmp.omp_get_max_threads()
cdef int *buf = <int *> malloc(sizeof(int) * maxthreads)
if buf == NULL:
raise MemoryError
with nogil, cython.parallel.parallel:
buf[threadid()] = threadid()
for i in range(maxthreads):
assert buf[i] == i
free(buf)
include "sequential_parallel.pyx"
# tag: run
cimport cython.parallel
from cython.parallel import prange, threadid
from libc.stdlib cimport malloc, free, abort
from libc.stdio cimport puts
import sys
try:
from builtins import next # Py3k
except ImportError:
def next(it):
return it.next()
#@cython.test_assert_path_exists(
# "//ParallelWithBlockNode//ParallelRangeNode[@schedule = 'dynamic']",
# "//GILStatNode[@state = 'nogil]//ParallelRangeNode")
def test_prange():
"""
>>> test_prange()
(9, 9, 45, 45)
"""
cdef Py_ssize_t i, j, sum1 = 0, sum2 = 0
with nogil, cython.parallel.parallel:
for i in prange(10, schedule='dynamic'):
sum1 += i
for j in prange(10, nogil=True):
sum2 += j
return i, j, sum1, sum2
def test_descending_prange():
"""
>>> test_descending_prange()
5
"""
cdef int i, start = 5, stop = -5, step = -2
cdef int sum = 0
for i in prange(start, stop, step, nogil=True):
sum += i
return sum
def test_propagation():
"""
>>> test_propagation()
(9, 9, 9, 9, 450, 450)
"""
cdef int i, j, x, y
cdef int sum1 = 0, sum2 = 0
for i in prange(10, nogil=True):
for j in prange(10):
sum1 += i
with nogil, cython.parallel.parallel:
for x in prange(10):
with cython.parallel.parallel:
for y in prange(10):
sum2 += y
return i, j, x, y, sum1, sum2
def test_unsigned_operands():
"""
>>> test_unsigned_operands()
10
"""
cdef int i
cdef int start = -5
cdef unsigned int stop = 5
cdef int step = 1
cdef int steps_taken = 0
for i in prange(start, stop, step, nogil=True):
steps_taken += 1
if steps_taken > 10:
abort()
return steps_taken
def test_reassign_start_stop_step():
"""
>>> test_reassign_start_stop_step()
20
"""
cdef int start = 0, stop = 10, step = 2
cdef int i
cdef int sum = 0
for i in prange(start, stop, step, nogil=True):
start = -2
stop = 2
step = 0
sum += i
return sum
def test_closure_parallel_privates():
"""
>>> test_closure_parallel_privates()
9 9
45 45
0 0 9 9
"""
cdef int x
def test_target():
nonlocal x
for x in prange(10, nogil=True):
pass
return x
print test_target(), x
def test_reduction():
nonlocal x
cdef int i
x = 0
for i in prange(10, nogil=True):
x += i
return x
print test_reduction(), x
def test_generator():
nonlocal x
cdef int i
x = 0
yield x
x = 2
for i in prange(10, nogil=True):
x = i
yield x
g = test_generator()
print next(g), x, next(g), x
def test_pure_mode():
"""
>>> test_pure_mode()
0
1
2
3
4
4
3
2
1
0
0
"""
import Cython.Shadow
pure_parallel = sys.modules['cython.parallel']
for i in pure_parallel.prange(5):
print i
for i in pure_parallel.prange(4, -1, -1, schedule='dynamic', nogil=True):
print i
with pure_parallel.parallel:
print pure_parallel.threadid()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment