Commit 86234dee authored by Mark Florisson's avatar Mark Florisson

Fix num_thread for prange() without parallel() + more error checks

parent c54e21df
......@@ -5915,7 +5915,11 @@ class ParallelStatNode(StatNode, ParallelNode):
self.body.analyse_declarations(env)
if self.kwargs:
self.kwargs = self.kwargs.compile_time_value(env)
try:
self.kwargs = self.kwargs.compile_time_value(env)
except Exception, e:
error(self.kwargs.pos, "Only compile-time values may be "
"supplied as keyword arguments")
else:
self.kwargs = {}
......@@ -5929,6 +5933,17 @@ class ParallelStatNode(StatNode, ParallelNode):
self.body.analyse_expressions(env)
self.analyse_sharing_attributes(env)
if self.num_threads is not None:
if self.parent and self.parent.num_threads is not None:
error(self.pos,
"num_threads already declared in outer section")
elif not isinstance(self.num_threads, (int, long)):
error(self.pos,
"Invalid value for num_threads argument, expected an int")
elif self.num_threads <= 0:
error(self.pos,
"argument to num_threads must be greater than 0")
def analyse_sharing_attributes(self, env):
"""
Analyse the privates for this block and set them in self.privates.
......@@ -6068,11 +6083,8 @@ class ParallelStatNode(StatNode, ParallelNode):
Write self.num_threads if set as the num_threads OpenMP directive
"""
if self.num_threads is not None:
if isinstance(self.num_threads, (int, long)):
code.put(" num_threads(%d)" % (self.num_threads,))
else:
error(self.pos, "Invalid value for num_threads argument, "
"expected an int")
code.put(" num_threads(%d)" % (self.num_threads,))
def declare_closure_privates(self, code):
"""
......@@ -6727,11 +6739,11 @@ class ParallelRangeNode(ParallelStatNode):
if not self.is_parallel:
code.put("#pragma omp for")
self.privatization_insertion_point = code.insertion_point()
# reduction_codepoint = self.parent.privatization_insertion_point
reduction_codepoint = self.parent.privatization_insertion_point
else:
code.put("#pragma omp parallel")
self.privatization_insertion_point = code.insertion_point()
# reduction_codepoint = self.privatization_insertion_point
reduction_codepoint = self.privatization_insertion_point
code.putln("")
code.putln("#endif /* _OPENMP */")
......@@ -6743,11 +6755,6 @@ class ParallelRangeNode(ParallelStatNode):
code.putln("#ifdef _OPENMP")
code.put("#pragma omp for")
# Nested parallelism is not supported, so we can put reductions on the
# for and not on the parallel (but would be valid, but gcc45 bugs on
# the former)
reduction_codepoint = code
for entry, (op, lastprivate) in self.privates.iteritems():
# Don't declare the index variable as a reduction
if op and op in "+*-&^|" and entry != self.target.entry:
......
# mode: error
from cython.parallel cimport parallel, prange
cdef int i
# valid
with nogil, parallel(num_threads=None):
pass
# invalid
with nogil, parallel(num_threads=0):
pass
with nogil, parallel(num_threads=i):
pass
with nogil, parallel(num_threads=2):
for i in prange(10, num_threads=2):
pass
_ERRORS = u"""
e_invalid_num_threads.pyx:12:20: argument to num_threads must be greater than 0
e_invalid_num_threads.pyx:15:20: Invalid value for num_threads argument, expected an int
e_invalid_num_threads.pyx:19:19: num_threads already declared in outer section
"""
......@@ -720,3 +720,15 @@ def test_nogil_cdef_except_clause():
for i in prange(10, nogil=True):
nogil_cdef_except_clause()
nogil_cdef_except_star()
def test_num_threads_compile():
cdef int i
for i in prange(10, nogil=True, num_threads=2):
pass
with nogil, cython.parallel.parallel(num_threads=2):
pass
with nogil, cython.parallel.parallel():
for i in prange(10, num_threads=2):
pass
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment