Commit d4e7cd36 authored by Jens Axboe

io_uring: sanitize double poll handling

There's a bit of confusion about the matching pairs of poll vs double poll,
depending on whether the request is a pure poll (IORING_OP_POLL_ADD) or a
poll-driven retry.

Add io_poll_get_double() that returns the double poll waitqueue, if any,
and io_poll_get_single() that returns the original poll waitqueue. With
that, remove the argument to io_poll_remove_double().

Finally, ensure that wait->private is cleared once the double poll handler
has run, so that the remove path knows it's already been seen.

Cc: stable@vger.kernel.org # v5.8
Reported-by: syzbot+7f617d4a9369028b8a2c@syzkaller.appspotmail.com
Fixes: 18bceab1 ("io_uring: allow POLL_ADD with double poll_wait() users")
Signed-off-by: Jens Axboe <axboe@kernel.dk>
parent 227c0c96
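
To make the pairing concrete before reading the diff, here is a minimal user-space sketch (plain C, not the kernel code) of the lookup rule the new io_poll_get_double()/io_poll_get_single() helpers encode: a pure IORING_OP_POLL_ADD keeps its primary poll entry in the request and stashes the double entry in ->io, while a poll-driven retry keeps both inside its async_poll container. The struct layouts and opcode values below are simplified placeholders, not the real io_kiocb.

/*
 * Simplified model of where the single and double poll entries live,
 * depending on whether the request is a pure poll or a poll-driven retry.
 */
#include <stdio.h>

/* placeholder opcode values for this sketch only */
enum { IORING_OP_POLL_ADD = 1, IORING_OP_READV = 2 };

struct poll_entry { const char *name; };

struct async_poll {
	struct poll_entry poll;		/* primary entry for poll-driven retry */
	struct poll_entry *double_poll;	/* optional second waitqueue entry */
};

struct request {
	int opcode;
	struct poll_entry poll;		/* primary entry for pure poll */
	struct poll_entry *io;		/* pure poll stashes the double entry here */
	struct async_poll *apoll;	/* set for poll-driven retry */
};

/* mirrors the idea of io_poll_get_double(): where the second entry lives */
static struct poll_entry *get_double(struct request *req)
{
	if (req->opcode == IORING_OP_POLL_ADD)
		return req->io;
	return req->apoll->double_poll;
}

/* mirrors the idea of io_poll_get_single(): the primary entry */
static struct poll_entry *get_single(struct request *req)
{
	if (req->opcode == IORING_OP_POLL_ADD)
		return &req->poll;
	return &req->apoll->poll;
}

int main(void)
{
	struct poll_entry dbl = { "double" };
	struct async_poll apoll = { .poll = { "apoll primary" }, .double_poll = &dbl };

	struct request pure = { .opcode = IORING_OP_POLL_ADD,
				.poll = { "poll primary" }, .io = &dbl };
	struct request retry = { .opcode = IORING_OP_READV, .apoll = &apoll };

	printf("pure poll:  single=%s double=%s\n",
	       get_single(&pure)->name, get_double(&pure)->name);
	printf("retry poll: single=%s double=%s\n",
	       get_single(&retry)->name, get_double(&retry)->name);
	return 0;
}

With the lookup centralized in the two helpers, every caller can pass just the request and no longer has to pick the right field itself, which is exactly what the diff below does.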
@@ -4649,9 +4649,24 @@ static bool io_poll_rewait(struct io_kiocb *req, struct io_poll_iocb *poll)
 	return false;
 }
 
-static void io_poll_remove_double(struct io_kiocb *req, void *data)
+static struct io_poll_iocb *io_poll_get_double(struct io_kiocb *req)
+{
+	/* pure poll stashes this in ->io, poll driven retry elsewhere */
+	if (req->opcode == IORING_OP_POLL_ADD)
+		return (struct io_poll_iocb *) req->io;
+	return req->apoll->double_poll;
+}
+
+static struct io_poll_iocb *io_poll_get_single(struct io_kiocb *req)
+{
+	if (req->opcode == IORING_OP_POLL_ADD)
+		return &req->poll;
+	return &req->apoll->poll;
+}
+
+static void io_poll_remove_double(struct io_kiocb *req)
 {
-	struct io_poll_iocb *poll = data;
+	struct io_poll_iocb *poll = io_poll_get_double(req);
 
 	lockdep_assert_held(&req->ctx->completion_lock);
@@ -4671,7 +4686,7 @@ static void io_poll_complete(struct io_kiocb *req, __poll_t mask, int error)
 {
 	struct io_ring_ctx *ctx = req->ctx;
 
-	io_poll_remove_double(req, req->io);
+	io_poll_remove_double(req);
 	req->poll.done = true;
 	io_cqring_fill_event(req, error ? error : mangle_poll(mask));
 	io_commit_cqring(ctx);
@@ -4711,7 +4726,7 @@ static int io_poll_double_wake(struct wait_queue_entry *wait, unsigned mode,
 			       int sync, void *key)
 {
 	struct io_kiocb *req = wait->private;
-	struct io_poll_iocb *poll = req->apoll->double_poll;
+	struct io_poll_iocb *poll = io_poll_get_single(req);
 	__poll_t mask = key_to_poll(key);
 
 	/* for instances that support it check for an event match first: */
@@ -4725,6 +4740,8 @@ static int io_poll_double_wake(struct wait_queue_entry *wait, unsigned mode,
 		done = list_empty(&poll->wait.entry);
 		if (!done)
 			list_del_init(&poll->wait.entry);
+		/* make sure double remove sees this as being gone */
+		wait->private = NULL;
 		spin_unlock(&poll->head->lock);
 		if (!done)
 			__io_async_wake(req, poll, mask, io_poll_task_func);
@@ -4808,7 +4825,7 @@ static void io_async_task_func(struct callback_head *cb)
 	if (hash_hashed(&req->hash_node))
 		hash_del(&req->hash_node);
 
-	io_poll_remove_double(req, apoll->double_poll);
+	io_poll_remove_double(req);
 	spin_unlock_irq(&ctx->completion_lock);
 
 	if (!READ_ONCE(apoll->poll.canceled))
@@ -4919,7 +4936,7 @@ static bool io_arm_poll_handler(struct io_kiocb *req)
 	ret = __io_arm_poll_handler(req, &apoll->poll, &ipt, mask,
 					io_async_wake);
 	if (ret || ipt.error) {
-		io_poll_remove_double(req, apoll->double_poll);
+		io_poll_remove_double(req);
 		spin_unlock_irq(&ctx->completion_lock);
 		kfree(apoll->double_poll);
 		kfree(apoll);
@@ -4951,14 +4968,13 @@ static bool io_poll_remove_one(struct io_kiocb *req)
 {
 	bool do_complete;
 
+	io_poll_remove_double(req);
+
 	if (req->opcode == IORING_OP_POLL_ADD) {
-		io_poll_remove_double(req, req->io);
 		do_complete = __io_poll_remove_one(req, &req->poll);
 	} else {
 		struct async_poll *apoll = req->apoll;
 
-		io_poll_remove_double(req, apoll->double_poll);
-
 		/* non-poll requests have submit ref still */
 		do_complete = __io_poll_remove_one(req, &apoll->poll);
 		if (do_complete) {
......
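
The last hunk in io_poll_double_wake() clears wait->private after the entry has been handled, so a later remover can tell it has already been seen. The sketch below (plain user-space C, not the kernel code) models only that clear-then-check idea under simplified assumptions; the names, locking, and the remove-side check are stand-ins, and the kernel's actual remove path differs in detail.

/*
 * Sketch: the wakeup side NULLs the back-pointer under the waitqueue lock,
 * and the removal side treats a NULL back-pointer as "already consumed".
 */
#include <stdio.h>
#include <pthread.h>

struct wait_entry {
	void *private;			/* owning request, or NULL once consumed */
	pthread_mutex_t *head_lock;	/* stands in for poll->head->lock */
};

/* wakeup handler: consume the entry exactly once */
static void double_wake(struct wait_entry *wait)
{
	pthread_mutex_lock(wait->head_lock);
	if (wait->private) {
		printf("waking request %p\n", wait->private);
		/* make sure a later remove sees this as being gone */
		wait->private = NULL;
	}
	pthread_mutex_unlock(wait->head_lock);
}

/* removal path: skip entries the wakeup side already consumed */
static void remove_double(struct wait_entry *wait)
{
	pthread_mutex_lock(wait->head_lock);
	if (wait->private == NULL)
		printf("entry already seen by the wake handler, nothing to do\n");
	else
		printf("detaching request %p\n", wait->private);
	pthread_mutex_unlock(wait->head_lock);
}

int main(void)
{
	pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
	int req;
	struct wait_entry wait = { .private = &req, .head_lock = &lock };

	double_wake(&wait);	/* handler runs first and clears ->private */
	remove_double(&wait);	/* remove now sees it as already gone */
	return 0;
}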