Commit 3fccb633 authored by Xavier Thompson's avatar Xavier Thompson

Rethink exception propagation from forks

Previously, an exception thrown from a fork was rethrown:
- a) when the call to fork() returns if the fork completes synchronously
- b) otherwise, at the next explicit call to Sync() if there is one
- c) otherwise, when the call to the enclosing Join coroutine returns

In the case of infinite fork loops, this meant exceptions might never be
propagated.

Now, when an exception is thrown from a fork it's always rethrown when
the call to the enclosing Join coroutine returns. The body of the Join
coroutine just stops executing as soon as possible once a fork signals
an exception. This will be at the call to fork() if the fork completes
synchronously, or at any ensuing call to Sync() or fork() otherwise.

Essentially once an exception is signaled from a parallel fork, the next
call to fork() behaves like Sync() instead of creating a fork, and once
all the parallel forks have completed, execution resumes directly at the
call to the enclosing Join coroutine, where the exception is rethrown.
parent fe0843f5
......@@ -4,9 +4,12 @@
#include <atomic>
#include <coroutine>
#include <cstdint>
#include <type_traits>
#include <typon/defer.hpp>
#include <typon/fork_refcount.hpp>
#include <typon/forked.hpp>
#include <typon/meta.hpp>
#include <typon/result.hpp>
#include <typon/scheduler.hpp>
#include <typon/span.hpp>
......@@ -15,109 +18,13 @@
namespace typon
{
namespace policy
{
struct Bundle
{
void on_final_suspend(std::coroutine_handle<> coroutine) noexcept
{
(void) coroutine;
}
struct OnAwaitable
{
template <typename Promise>
void on_await_suspend(std::coroutine_handle<Promise> coroutine) noexcept
{
(void) coroutine;
}
template <typename Promise>
auto on_await_resume(std::coroutine_handle<Promise> coroutine)
{
using T = typename Promise::value_type;
auto thefts = coroutine.promise()._span->_thefts;
auto rank = coroutine.promise()._rank;
bool ready = (thefts == rank);
if (!ready)
{
coroutine.promise()._span->_children.push_back(coroutine);
}
return Forked<T>(coroutine, ready, nullptr);
}
};
};
struct Refcnt
{
ForkNode _node;
void on_final_suspend(std::coroutine_handle<> coroutine) noexcept
{
(void) coroutine;
_node.decref();
}
struct OnAwaitable
{
template <typename Promise>
void on_await_suspend(std::coroutine_handle<Promise> coroutine) noexcept
{
coroutine.promise()._policy._node._coroutine = coroutine;
}
template <typename Promise>
auto on_await_resume(std::coroutine_handle<Promise> coroutine)
{
using T = typename Promise::value_type;
auto thefts = coroutine.promise()._span->_thefts;
auto rank = coroutine.promise()._rank;
bool ready = (thefts == rank);
auto node = &(coroutine.promise()._policy._node);
return Forked<T>(coroutine, ready, node);
}
};
};
struct Drop
{
void on_final_suspend(std::coroutine_handle<> coroutine) noexcept
{
coroutine.destroy();
}
struct OnAwaitable
{
Span * _span;
Span::u64 _rank;
template <typename Promise>
void on_await_suspend(std::coroutine_handle<Promise> coroutine) noexcept
{
_span = coroutine.promise()._span;
_rank = _span->_thefts;
}
template <typename Promise>
void on_await_resume(std::coroutine_handle<Promise> coroutine)
{
if (_span->_thefts == _rank)
{
Defer defer { [coroutine]() { coroutine.destroy(); } };
coroutine.promise().get();
}
}
};
};
}
template <typename T = void, typename P = policy::Refcnt>
template <typename T = void>
struct [[nodiscard]] Fork
{
struct promise_type;
using u64 = Span::u64;
using Policy = std::conditional_t<std::is_same_v<T, void>, policy::Drop, P>;
static constexpr bool is_void { std::is_same_v<T, void> };
std::coroutine_handle<promise_type> _coroutine;
......@@ -133,9 +40,10 @@ namespace typon
struct promise_type : Result<T>
{
using Refcount = std::conditional_t<is_void, meta::Empty, ForkRefcount>;
Span * _span;
u64 _rank;
[[no_unique_address]] Policy _policy;
[[no_unique_address]] Refcount _refcount;
Fork get_return_object() noexcept
{
......@@ -154,19 +62,35 @@ namespace typon
std::coroutine_handle<> await_suspend(std::coroutine_handle<promise_type> coroutine) noexcept
{
auto span = coroutine.promise()._span;
auto exception = std::move(coroutine.promise()._exception);
if constexpr(is_void)
{
coroutine.destroy();
}
if (Scheduler::pop())
{
if (exception)
{
span->set_sequential_exception(exception);
if ((span->_thefts == 0) || span->notify_sync())
{
return span->_continuation;
}
return std::noop_coroutine();
}
return span->_coroutine;
}
if (auto & exception = coroutine.promise()._exception)
if (exception)
{
span->set_concurrent_exception(exception);
}
if constexpr(!is_void)
{
span->set_exception(exception, coroutine.promise()._rank);
coroutine.promise()._refcount.decref();
}
coroutine.promise()._policy.on_final_suspend(coroutine);
u64 n = span->_n.fetch_sub(1, std::memory_order_acq_rel);
if (n == 1)
if (span->notify_fork())
{
return span->continuation();
return span->fork_continuation();
}
return std::noop_coroutine();
}
......@@ -179,31 +103,51 @@ namespace typon
struct awaitable : std::suspend_always
{
std::coroutine_handle<promise_type> _coroutine;
[[no_unique_address]] Policy::OnAwaitable _policy;
u64 _thefts;
awaitable(std::coroutine_handle<promise_type> coroutine)
: _coroutine(coroutine)
{}
template <typename Promise>
auto await_suspend(std::coroutine_handle<Promise> continuation) noexcept
std::coroutine_handle<> await_suspend(std::coroutine_handle<Promise> continuation) noexcept
{
Span * span = &(continuation.promise()._span);
_coroutine.promise()._span = span;
_coroutine.promise()._rank = span->_thefts;
_policy.on_await_suspend(_coroutine);
_thefts = span->_thefts;
if constexpr(!is_void)
{
_coroutine.promise()._refcount.set(_coroutine);
}
if (_thefts && span->has_concurrent_exception())
{
// Destroy the fork because it will not be run.
_coroutine.destroy();
if (span->notify_sync())
{
return span->_continuation;
}
return std::noop_coroutine();
}
std::coroutine_handle<> on_stack_handle = _coroutine;
Scheduler::push(span);
return on_stack_handle;
}
auto await_resume()
auto await_resume() noexcept
{
return _policy.on_await_resume(_coroutine);
if constexpr(!is_void)
{
auto span = _coroutine.promise()._span;
bool ready = (span->_thefts == _thefts);
return Forked<T>(_coroutine, ready);
}
}
};
auto operator co_await() &&
{
return awaitable { {}, _coroutine, {} };
return awaitable { _coroutine };
}
};
......
#ifndef TYPON_FORK_REFCOUNT_HPP_INCLUDED
#define TYPON_FORK_REFCOUNT_HPP_INCLUDED
#include <atomic>
#include <coroutine>
namespace typon
{
struct ForkRefcount
{
std::coroutine_handle<> _coroutine;
std::atomic<bool> _refcount {true};
void set(std::coroutine_handle<> coroutine) noexcept
{
_coroutine = coroutine;
}
void decref() noexcept
{
if (!_refcount.exchange(false, std::memory_order_acq_rel))
{
_coroutine.destroy();
}
}
};
}
#endif // TYPON_FORK_REFCOUNT_HPP_INCLUDED
......@@ -7,106 +7,123 @@
#include <utility>
#include <typon/defer.hpp>
#include <typon/fork_refcount.hpp>
#include <typon/result.hpp>
namespace typon
{
struct ForkNode
{
std::coroutine_handle<> _coroutine;
std::atomic<bool> _ref {true};
void decref() noexcept
{
if (!_ref.exchange(false, std::memory_order_acq_rel))
{
_coroutine.destroy();
}
}
};
template <typename T>
struct ForkResult
struct Forked
{
using value_type = T;
Result<T> * _result = nullptr;
union
{
T _value;
ForkNode * _node;
ForkRefcount * _refcount;
};
template <typename Promise>
void construct_value(std::coroutine_handle<Promise> coroutine)
Forked(std::coroutine_handle<Promise> coroutine, bool ready)
{
std::construct_at(std::addressof(_value), coroutine.promise().get());
}
void construct_value(ForkResult && other)
{
std::construct_at(std::addressof(_value), std::move(other._value));
if (ready)
{
std::construct_at(std::addressof(_value), coroutine.promise().value());
coroutine.destroy();
}
else
{
_refcount = &(coroutine.promise()._refcount);
_result = &(coroutine.promise());
}
}
T get_value() noexcept
Forked(Forked && other) noexcept(std::is_nothrow_move_constructible_v<T>)
{
return _value;
_result = other._result;
if (_result)
{
_refcount = std::exchange(other._refcount, nullptr);
}
else
{
std::construct_at(std::addressof(_value), std::move(other._value));
}
}
void destroy_value() noexcept
Forked& operator=(Forked && other)
noexcept(std::is_nothrow_move_constructible_v<T>)
{
std::destroy_at(std::addressof(_value));
if (this != &other)
{
Forked old { std::move(*this) };
_result = other._result;
if (_result)
{
_refcount = std::exchange(other._refcount, nullptr);
}
else
{
std::construct_at(std::addressof(_value), std::move(other._value));
}
}
return *this;
}
};
template <typename T>
struct ForkResult<T&>
{
union
{
T * _value;
ForkNode * _node;
};
template <typename Promise>
void construct_value(std::coroutine_handle<Promise> coroutine)
~Forked()
{
_value = std::addressof(coroutine.promise().get());
if (_result)
{
_refcount->decref();
}
else
{
std::destroy_at(std::addressof(_value));
}
}
void construct_value(ForkResult && other)
T get() &
{
_value = other._value;
if (_result)
{
return _result->value();
}
return _value;
}
T& get_value() noexcept
T get() &&
{
return *_value;
if (_result)
{
return _result->value();
}
return std::move(_value);
}
void destroy_value() noexcept {}
};
template <typename T>
struct Forked : ForkResult<T>
struct Forked<T&>
{
using value_type = T;
Result<T> * _result = nullptr;
void * _data;
template <typename Coroutine>
Forked(Coroutine coroutine, bool ready, ForkNode * node)
template <typename Promise>
Forked(std::coroutine_handle<Promise> coroutine, bool ready)
{
if (ready)
{
Defer defer { [&coroutine]() { coroutine.destroy(); } };
this->construct_value(coroutine);
_data = std::addressof(coroutine.promise().value());
coroutine.destroy();
}
else
{
this->_node = node;
_data = &(coroutine.promise()._refcount);
_result = &(coroutine.promise());
}
}
......@@ -114,32 +131,14 @@ namespace typon
Forked(Forked && other) noexcept(std::is_nothrow_move_constructible_v<T>)
{
_result = other._result;
if (_result)
{
this->_node = std::exchange(other._node, nullptr);
}
else
{
this->construct_value(std::move(other));
}
_data = std::exchange(other._data, nullptr);
}
Forked& operator=(Forked && other)
Forked& operator=(Forked other)
noexcept(std::is_nothrow_move_constructible_v<T>)
{
if (this != &other)
{
Forked old { std::move(*this) };
_result = other._result;
if (_result)
{
this->_node = std::exchange(other._node, nullptr);
}
else
{
this->construct_value(std::move(other));
}
}
std::swap(_result, other._result);
std::swap(_data, other._data);
return *this;
}
......@@ -147,24 +146,17 @@ namespace typon
{
if (_result)
{
if (auto node = this->_node)
{
node->decref();
}
}
else
{
this->destroy_value();
reinterpret_cast<ForkRefcount *>(_data)->decref();
}
}
T get()
T& get()
{
if (_result)
{
return _result->get();
return _result->value();
}
return this->get_value();
return *(reinterpret_cast<T *>(_data));
}
};
......@@ -179,8 +171,8 @@ namespace typon
Result<T> * _result;
void * _coroutine;
template <typename Coroutine>
Forked(Coroutine coroutine, bool ready, ForkNode * node)
template <typename Promise>
Forked(std::coroutine_handle<Promise> coroutine, bool ready)
{
_ready = ready;
_result = &(coroutine.promise());
......@@ -191,7 +183,7 @@ namespace typon
}
else
{
_coroutine = node;
_coroutine = &(coroutine.promise()._refcount);
}
}
......@@ -222,14 +214,14 @@ namespace typon
}
else
{
reinterpret_cast<ForkNode *>(_coroutine)->decref();
reinterpret_cast<ForkRefcount *>(_coroutine)->decref();
}
}
}
T get()
{
return _result->get();
return _result->value();
}
};
......
......@@ -4,6 +4,7 @@
#include <coroutine>
#include <utility>
#include <typon/meta.hpp>
#include <typon/result.hpp>
#include <typon/span.hpp>
......@@ -19,9 +20,11 @@ namespace typon
{
struct promise_type;
std::coroutine_handle<promise_type> _coroutine;
using coroutine_type = std::coroutine_handle<promise_type>;
Join(std::coroutine_handle<promise_type> coroutine) noexcept : _coroutine(coroutine) {}
coroutine_type _coroutine;
Join(coroutine_type coroutine) noexcept : _coroutine(coroutine) {}
Join(const Join &) = delete;
Join & operator=(const Join &) = delete;
......@@ -44,7 +47,7 @@ namespace typon
}
}
struct promise_type : Result<T>
struct promise_type : Result<T, meta::Empty>
{
using u64 = Span::u64;
static constexpr u64 UMAX = Span::UMAX;
......@@ -52,12 +55,12 @@ namespace typon
Span _span;
promise_type() noexcept
: _span(std::coroutine_handle<promise_type>::from_promise(*this))
: _span(coroutine_type::from_promise(*this))
{}
Join get_return_object() noexcept
{
return { std::coroutine_handle<promise_type>::from_promise(*this) };
return { coroutine_type::from_promise(*this) };
}
std::suspend_always initial_suspend() noexcept
......@@ -65,6 +68,11 @@ namespace typon
return {};
}
void unhandled_exception() noexcept
{
_span.set_sequential_exception(std::current_exception());
}
template <typename U>
decltype(auto) await_transform(U && expr) noexcept
{
......@@ -79,34 +87,22 @@ namespace typon
bool await_ready() noexcept
{
if (u64 thefts = _span._thefts)
{
u64 n = _span._n.load(std::memory_order_acquire);
if (n - (UMAX - thefts) == 0)
{
return true;
}
return false;
}
return true;
return (_span._thefts == 0);
}
std::coroutine_handle<> await_suspend(std::coroutine_handle<promise_type> coroutine) noexcept
std::coroutine_handle<> await_suspend(coroutine_type coroutine) noexcept
{
u64 thefts = _span._thefts;
u64 n = _span._n.fetch_sub(UMAX - thefts, std::memory_order_acq_rel);
if (n - (UMAX - thefts) == 0)
(void) coroutine;
if (_span.notify_sync())
{
return coroutine;
return _span.sync_continuation();
}
return std::noop_coroutine();
}
void await_resume()
void await_resume() noexcept
{
_span._thefts = 0;
_span._n.store(UMAX, std::memory_order_release);
_span.check_exception();
_span.reset_sync();
}
};
......@@ -117,16 +113,10 @@ namespace typon
{
struct awaitable : std::suspend_always
{
std::coroutine_handle<> await_suspend(std::coroutine_handle<promise_type> coroutine) noexcept
std::coroutine_handle<> await_suspend(coroutine_type coroutine) noexcept
{
Span & span = coroutine.promise()._span;
u64 thefts = span._thefts;
if (thefts == 0)
{
return span._continuation;
}
u64 n = span._n.fetch_sub(UMAX - thefts, std::memory_order_acq_rel);
if (n - (UMAX - thefts) == 0)
if ((span._thefts == 0) || span.notify_sync())
{
return span._continuation;
}
......@@ -142,14 +132,14 @@ namespace typon
{
struct awaitable
{
std::coroutine_handle<promise_type> _coroutine;
coroutine_type _coroutine;
bool await_ready() noexcept
{
return false;
}
std::coroutine_handle<> await_suspend(std::coroutine_handle<> continuation) noexcept
auto await_suspend(std::coroutine_handle<> continuation) noexcept
{
_coroutine.promise()._span._continuation = continuation;
return _coroutine;
......@@ -157,8 +147,8 @@ namespace typon
decltype(auto) await_resume()
{
_coroutine.promise()._span.check_exception();
return _coroutine.promise().get();
_coroutine.promise()._span.propagate_exception();
return _coroutine.promise().value();
}
};
......
#ifndef TYPON_META_HPP_INCLUDED
#define TYPON_META_HPP_INCLUDED
namespace typon::meta
{
struct Empty {};
}
#endif // TYPON_META_HPP_INCLUDED
......@@ -11,13 +11,13 @@
namespace typon
{
template <typename T>
template <typename T, typename E = std::exception_ptr>
struct Result
{
using value_type = T;
bool _valid = false;
std::exception_ptr _exception;
[[no_unique_address]] E _exception;
union
{
T _value;
......@@ -40,37 +40,56 @@ namespace typon
void unhandled_exception() noexcept
{
_exception = std::current_exception();
if constexpr(std::is_assignable_v<E, std::exception_ptr>)
{
_exception = std::current_exception();
}
}
T& get() &
{
if (_exception)
if constexpr(std::is_same_v<E, std::exception_ptr>)
{
std::rethrow_exception(std::exchange(_exception, nullptr));
if (_exception)
{
std::rethrow_exception(std::exchange(_exception, nullptr));
}
}
return _value;
}
T&& get() &&
{
if (_exception)
if constexpr(std::is_same_v<E, std::exception_ptr>)
{
std::rethrow_exception(std::exchange(_exception, nullptr));
if (_exception)
{
std::rethrow_exception(std::exchange(_exception, nullptr));
}
}
return std::move(_value);
}
T& value() & noexcept
{
return _value;
}
T&& value() && noexcept
{
return std::move(_value);
}
};
template <typename T>
struct Result<T&>
template <typename T, typename E>
struct Result<T&, E>
{
using value_type = T&;
T* _value;
std::exception_ptr _exception;
[[no_unique_address]] E _exception;
void return_value(T& expr) noexcept
{
......@@ -79,41 +98,60 @@ namespace typon
void unhandled_exception() noexcept
{
_exception = std::current_exception();
if constexpr(std::is_assignable_v<E, std::exception_ptr>)
{
_exception = std::current_exception();
}
}
T& get() &
{
if (_exception)
if constexpr(std::is_same_v<E, std::exception_ptr>)
{
std::rethrow_exception(std::exchange(_exception, nullptr));
if (_exception)
{
std::rethrow_exception(std::exchange(_exception, nullptr));
}
}
return *_value;
}
T& value() & noexcept
{
return *_value;
}
};
template <>
struct Result<void>
template <typename E>
struct Result<void, E>
{
using value_type = void;
std::exception_ptr _exception;
[[no_unique_address]] E _exception;
void return_void() noexcept {}
void unhandled_exception() noexcept
{
_exception = std::current_exception();
if constexpr(std::is_assignable_v<E, std::exception_ptr>)
{
_exception = std::current_exception();
}
}
void get()
{
if (_exception)
if constexpr(std::is_same_v<E, std::exception_ptr>)
{
std::rethrow_exception(std::exchange(_exception, nullptr));
if (_exception)
{
std::rethrow_exception(std::exchange(_exception, nullptr));
}
}
}
void value() noexcept {}
};
}
......
......@@ -20,19 +20,14 @@ namespace typon
{
using u64 = TheftPoint::u64;
struct Error
{
u64 _rank;
std::exception_ptr _exception;
};
static constexpr u64 UMAX = std::numeric_limits<u64>::max();
std::coroutine_handle<> _continuation;
std::atomic<Error *> _error { nullptr };
std::vector<std::coroutine_handle<>> _children;
std::atomic<bool> _concurrent_error_flag { false };
std::exception_ptr _concurrent_exception;
std::exception_ptr _sequential_exception;
std::atomic<u64> _n = UMAX;
......@@ -40,55 +35,52 @@ namespace typon
: TheftPoint(coroutine)
{}
~Span()
void propagate_exception()
{
if (auto error = _error.load(std::memory_order_relaxed))
if (_sequential_exception)
{
std::rethrow_exception(_sequential_exception);
}
if (_concurrent_exception)
{
delete error;
std::rethrow_exception(_concurrent_exception);
}
clear_children();
}
void clear_children() noexcept
bool has_concurrent_exception() noexcept
{
for (auto & child : _children)
{
child.destroy();
}
_children.clear();
return _concurrent_error_flag.load(std::memory_order_acquire);
}
void check_exception()
void set_concurrent_exception(std::exception_ptr & exception) noexcept
{
if (auto error = _error.load(std::memory_order_relaxed))
if (!_concurrent_error_flag.exchange(true, std::memory_order_acq_rel))
{
_error.store(nullptr, std::memory_order_relaxed);
Defer defer { [error]() { delete error; } };
std::rethrow_exception(error->_exception);
_concurrent_exception = exception;
}
}
void set_exception(std::exception_ptr & exception, u64 rank) noexcept
void set_sequential_exception(std::exception_ptr exception) noexcept
{
auto error = new Error(rank, exception);
Error * expected = nullptr;
while (!_error.compare_exchange_strong(expected, error))
{
if (expected->_rank < rank)
{
delete error;
return;
}
}
if (expected)
{
delete expected;
}
_sequential_exception = std::move(exception);
}
void resume()
bool notify_sync() noexcept
{
_coroutine.resume();
u64 n = _n.fetch_sub(UMAX - _thefts, std::memory_order_acq_rel);
return (n - (UMAX - _thefts) == 0);
}
bool notify_fork() noexcept
{
u64 n = _n.fetch_sub(1, std::memory_order_acq_rel);
return (n == 1);
}
void reset_sync() noexcept
{
_thefts = 0;
_n.store(UMAX, std::memory_order_release);
}
operator std::coroutine_handle<>() noexcept
......@@ -96,9 +88,22 @@ namespace typon
return _coroutine;
}
std::coroutine_handle<> continuation() noexcept
std::coroutine_handle<> fork_continuation() noexcept
{
// It's safe to access _concurrent_exception here
// because this is only called when all strands are done
if (_coroutine.done() || _sequential_exception || _concurrent_exception)
{
return _continuation;
}
return _coroutine;
}
std::coroutine_handle<> sync_continuation() noexcept
{
if (_coroutine.done())
// It's safe to access _concurrent_exception here
// because this is only called when all strands are done
if (_sequential_exception || _concurrent_exception)
{
return _continuation;
}
......
......@@ -26,16 +26,6 @@ namespace typon
}
template <typename Policy, typename Task>
Fork<typename Task::promise_type::value_type, Policy> fork(Task task)
{
// Put the task in a local variable to ensure its destructor will
// be called on co_return instead of only on coroutine destruction.
Task local_task = std::move(task);
co_return co_await std::move(local_task);
}
template <typename Task>
Future<typename Task::promise_type::value_type> future(Task task)
{
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment