Commit 79aab7df authored by Kirill Smelkov's avatar Kirill Smelkov

sync.WorkGroup: Propagate all exception types, not only those derived from Exception

A problem was hit with pytest.fail with raises Failed exception not
being propagated to .wait. As it turned out it was not propagated
because pytest's Failed derives from BaseException, not Exception, and
we were catching only Exception and its children.

Rework the code to propagate all exception types from workers.

Performance change is with noise (it is either a bit faster for one set
of runs, or a bit slower for another set of runs).
parent 94c6160b
...@@ -121,15 +121,18 @@ class WorkGroup(object): ...@@ -121,15 +121,18 @@ class WorkGroup(object):
try: try:
f(g._ctx, *argv, **kw) f(g._ctx, *argv, **kw)
except Exception as exc: except:
_, exc, tb = sys.exc_info()
with g._mu: with g._mu:
if g._err is None: if g._err is None:
# this goroutine is the first failed task # this goroutine is the first failed task
g._err = exc g._err = exc
if six.PY2: if six.PY2:
# py3 has __traceback__ automatically # py3 has __traceback__ automatically
exc.__traceback__ = sys.exc_info()[2] exc.__traceback__ = tb
g._cancel() g._cancel()
exc = None
tb = None
def wait(g): def wait(g):
......
...@@ -25,6 +25,7 @@ from golang import sync, context ...@@ -25,6 +25,7 @@ from golang import sync, context
import time, threading import time, threading
from pytest import raises from pytest import raises
from six.moves import range as xrange from six.moves import range as xrange
import six
def test_once(): def test_once():
once = sync.Once() once = sync.Once()
...@@ -100,6 +101,17 @@ def test_workgroup(): ...@@ -100,6 +101,17 @@ def test_workgroup():
wg.wait() wg.wait()
assert l == [1, 2] assert l == [1, 2]
# WorkGroup must catch/propagate all exception classes.
# Python2 allows to raise old-style classes not derived from BaseException.
# Python3 allows to raise only BaseException derivatives.
if six.PY2:
class MyError:
def __init__(self, *args):
self.args = args
else:
class MyError(BaseException):
pass
# t1=fail, t2=ok, does not look at ctx # t1=fail, t2=ok, does not look at ctx
wg = sync.WorkGroup(ctx) wg = sync.WorkGroup(ctx)
l = [0, 0] l = [0, 0]
...@@ -109,15 +121,15 @@ def test_workgroup(): ...@@ -109,15 +121,15 @@ def test_workgroup():
with mu: with mu:
l[i] = i+1 l[i] = i+1
if i == 0: if i == 0:
raise RuntimeError('aaa') raise MyError('aaa')
def f(ctx, i): def f(ctx, i):
Iam_f = 0 Iam_f = 0
_(ctx, i) _(ctx, i)
wg.go(f, i) wg.go(f, i)
with raises(RuntimeError) as exc: with raises(MyError) as exc:
wg.wait() wg.wait()
assert exc.type is RuntimeError assert exc.type is MyError
assert exc.value.args == ('aaa',) assert exc.value.args == ('aaa',)
assert 'Iam__' in exc.traceback[-1].locals assert 'Iam__' in exc.traceback[-1].locals
assert 'Iam_f' in exc.traceback[-2].locals assert 'Iam_f' in exc.traceback[-2].locals
...@@ -132,18 +144,18 @@ def test_workgroup(): ...@@ -132,18 +144,18 @@ def test_workgroup():
with mu: with mu:
l[i] = i+1 l[i] = i+1
if i == 0: if i == 0:
raise RuntimeError('bbb') raise MyError('bbb')
if i == 1: if i == 1:
ctx.done().recv() ctx.done().recv()
raise ValueError('ccc') # != RuntimeError raise ValueError('ccc') # != MyError
def f(ctx, i): def f(ctx, i):
Iam_f = 0 Iam_f = 0
_(ctx, i) _(ctx, i)
wg.go(f, i) wg.go(f, i)
with raises(RuntimeError) as exc: with raises(MyError) as exc:
wg.wait() wg.wait()
assert exc.type is RuntimeError assert exc.type is MyError
assert exc.value.args == ('bbb',) assert exc.value.args == ('bbb',)
assert 'Iam__' in exc.traceback[-1].locals assert 'Iam__' in exc.traceback[-1].locals
assert 'Iam_f' in exc.traceback[-2].locals assert 'Iam_f' in exc.traceback[-2].locals
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment