Commit 807abcb0 authored by Jens Axboe

io_uring: ensure double poll additions work with both request types

The double poll additions were centered around doing POLL_ADD on file
descriptors that use more than one waitqueue (typically one for read,
one for write) when being polled. However, it can also be triggered when
poll is used to drive retry of a request. For that case, we cannot safely
use req->io, as that could already be in use by the request type itself.

Add a second io_poll_iocb pointer to the structure we allocate for
poll-based retry, and ensure each of the two paths uses the right one.

Fixes: 18bceab1 ("io_uring: allow POLL_ADD with double poll_wait() users")
Signed-off-by: Jens Axboe <axboe@kernel.dk>
parent 681fda8d
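
To make the two paths described in the commit message concrete, here is a minimal user-space sketch of the idea behind the change. The struct layouts, opcode values, and printf bodies are simplified stand-ins, not the kernel's real definitions; only the field names and the shape of the calls mirror the diff below.

/*
 * Sketch only: the POLL_ADD path keeps its second poll entry in req->io,
 * while poll-driven retry keeps it in apoll->double_poll, so the removal
 * helper now takes the entry explicitly instead of always reading req->io.
 */
#include <stdio.h>

struct io_poll_iocb { int events; };              /* stand-in for the kernel struct */

struct async_poll {
	struct io_poll_iocb poll;
	struct io_poll_iocb *double_poll;         /* field added by this commit */
};

struct io_kiocb {
	void *io;                                 /* POLL_ADD stores its 2nd entry here */
	struct async_poll *apoll;                 /* allocated for poll-driven retry */
	int opcode;
};

enum { OP_POLL_ADD = 1, OP_OTHER = 2 };           /* arbitrary values for the sketch */

/* the helper now receives the double-poll entry from the caller */
static void io_poll_remove_double(struct io_kiocb *req, void *data)
{
	struct io_poll_iocb *poll = data;

	if (poll)
		printf("req %p: removing double poll entry %p\n",
		       (void *)req, (void *)poll);
}

static void remove_one(struct io_kiocb *req)
{
	if (req->opcode == OP_POLL_ADD)
		io_poll_remove_double(req, req->io);                  /* POLL_ADD path */
	else
		io_poll_remove_double(req, req->apoll->double_poll);  /* retry path */
}

int main(void)
{
	struct io_poll_iocb second = { .events = 1 };
	struct async_poll apoll = { .poll = { .events = 1 }, .double_poll = &second };
	struct io_kiocb poll_add_req = { .io = &second, .apoll = NULL, .opcode = OP_POLL_ADD };
	struct io_kiocb retry_req = { .io = NULL, .apoll = &apoll, .opcode = OP_OTHER };

	remove_one(&poll_add_req);   /* reads req->io */
	remove_one(&retry_req);      /* reads apoll->double_poll, not req->io */
	return 0;
}

Passing the entry explicitly removes the old assumption that req->io always holds the second poll entry; for poll-driven retry, req->io belongs to the request's own async data, which is exactly the conflict the commit message describes.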
@@ -605,6 +605,7 @@ enum {
 
 struct async_poll {
 	struct io_poll_iocb	poll;
+	struct io_poll_iocb	*double_poll;
 	struct io_wq_work	work;
 };
 
@@ -4159,9 +4160,9 @@ static bool io_poll_rewait(struct io_kiocb *req, struct io_poll_iocb *poll)
 	return false;
 }
 
-static void io_poll_remove_double(struct io_kiocb *req)
+static void io_poll_remove_double(struct io_kiocb *req, void *data)
 {
-	struct io_poll_iocb *poll = (struct io_poll_iocb *) req->io;
+	struct io_poll_iocb *poll = data;
 
 	lockdep_assert_held(&req->ctx->completion_lock);
 
@@ -4181,7 +4182,7 @@ static void io_poll_complete(struct io_kiocb *req, __poll_t mask, int error)
 {
 	struct io_ring_ctx *ctx = req->ctx;
 
-	io_poll_remove_double(req);
+	io_poll_remove_double(req, req->io);
 	req->poll.done = true;
 	io_cqring_fill_event(req, error ? error : mangle_poll(mask));
 	io_commit_cqring(ctx);
@@ -4224,21 +4225,21 @@ static int io_poll_double_wake(struct wait_queue_entry *wait, unsigned mode,
 			       int sync, void *key)
 {
 	struct io_kiocb *req = wait->private;
-	struct io_poll_iocb *poll = (struct io_poll_iocb *) req->io;
+	struct io_poll_iocb *poll = req->apoll->double_poll;
 	__poll_t mask = key_to_poll(key);
 
 	/* for instances that support it check for an event match first: */
 	if (mask && !(mask & poll->events))
 		return 0;
 
-	if (req->poll.head) {
+	if (poll && poll->head) {
 		bool done;
 
-		spin_lock(&req->poll.head->lock);
-		done = list_empty(&req->poll.wait.entry);
+		spin_lock(&poll->head->lock);
+		done = list_empty(&poll->wait.entry);
 		if (!done)
-			list_del_init(&req->poll.wait.entry);
-		spin_unlock(&req->poll.head->lock);
+			list_del_init(&poll->wait.entry);
+		spin_unlock(&poll->head->lock);
 		if (!done)
 			__io_async_wake(req, poll, mask, io_poll_task_func);
 	}
@@ -4258,7 +4259,8 @@ static void io_init_poll_iocb(struct io_poll_iocb *poll, __poll_t events,
 }
 
 static void __io_queue_proc(struct io_poll_iocb *poll, struct io_poll_table *pt,
-			    struct wait_queue_head *head)
+			    struct wait_queue_head *head,
+			    struct io_poll_iocb **poll_ptr)
 {
 	struct io_kiocb *req = pt->req;
 
@@ -4269,7 +4271,7 @@ static void __io_queue_proc(struct io_poll_iocb *poll, struct io_poll_table *pt,
 	 */
 	if (unlikely(poll->head)) {
 		/* already have a 2nd entry, fail a third attempt */
-		if (req->io) {
+		if (*poll_ptr) {
 			pt->error = -EINVAL;
 			return;
 		}
@@ -4281,7 +4283,7 @@ static void __io_queue_proc(struct io_poll_iocb *poll, struct io_poll_table *pt,
 		io_init_poll_iocb(poll, req->poll.events, io_poll_double_wake);
 		refcount_inc(&req->refs);
 		poll->wait.private = req;
-		req->io = (void *) poll;
+		*poll_ptr = poll;
 	}
 
 	pt->error = 0;
@@ -4293,8 +4295,9 @@ static void io_async_queue_proc(struct file *file, struct wait_queue_head *head,
 			       struct poll_table_struct *p)
 {
 	struct io_poll_table *pt = container_of(p, struct io_poll_table, pt);
+	struct async_poll *apoll = pt->req->apoll;
 
-	__io_queue_proc(&pt->req->apoll->poll, pt, head);
+	__io_queue_proc(&apoll->poll, pt, head, &apoll->double_poll);
 }
 
 static void io_sq_thread_drop_mm(struct io_ring_ctx *ctx)
@@ -4344,11 +4347,13 @@ static void io_async_task_func(struct callback_head *cb)
 		}
 	}
 
+	io_poll_remove_double(req, apoll->double_poll);
 	spin_unlock_irq(&ctx->completion_lock);
 
 	/* restore ->work in case we need to retry again */
 	if (req->flags & REQ_F_WORK_INITIALIZED)
 		memcpy(&req->work, &apoll->work, sizeof(req->work));
+	kfree(apoll->double_poll);
 	kfree(apoll);
 
 	if (!canceled) {
@@ -4436,7 +4441,6 @@ static bool io_arm_poll_handler(struct io_kiocb *req)
 	struct async_poll *apoll;
 	struct io_poll_table ipt;
 	__poll_t mask, ret;
-	bool had_io;
 
 	if (!req->file || !file_can_poll(req->file))
 		return false;
@@ -4448,11 +4452,11 @@ static bool io_arm_poll_handler(struct io_kiocb *req)
 	apoll = kmalloc(sizeof(*apoll), GFP_ATOMIC);
 	if (unlikely(!apoll))
 		return false;
+	apoll->double_poll = NULL;
 
 	req->flags |= REQ_F_POLLED;
 	if (req->flags & REQ_F_WORK_INITIALIZED)
 		memcpy(&apoll->work, &req->work, sizeof(req->work));
-	had_io = req->io != NULL;
 
 	io_get_req_task(req);
 	req->apoll = apoll;
@@ -4470,13 +4474,11 @@ static bool io_arm_poll_handler(struct io_kiocb *req)
 	ret = __io_arm_poll_handler(req, &apoll->poll, &ipt, mask,
 					io_async_wake);
 	if (ret) {
-		ipt.error = 0;
-		/* only remove double add if we did it here */
-		if (!had_io)
-			io_poll_remove_double(req);
+		io_poll_remove_double(req, apoll->double_poll);
 		spin_unlock_irq(&ctx->completion_lock);
 		if (req->flags & REQ_F_WORK_INITIALIZED)
 			memcpy(&req->work, &apoll->work, sizeof(req->work));
+		kfree(apoll->double_poll);
 		kfree(apoll);
 		return false;
 	}
@@ -4507,11 +4509,13 @@ static bool io_poll_remove_one(struct io_kiocb *req)
 	bool do_complete;
 
 	if (req->opcode == IORING_OP_POLL_ADD) {
-		io_poll_remove_double(req);
+		io_poll_remove_double(req, req->io);
 		do_complete = __io_poll_remove_one(req, &req->poll);
 	} else {
 		struct async_poll *apoll = req->apoll;
 
+		io_poll_remove_double(req, apoll->double_poll);
+
 		/* non-poll requests have submit ref still */
 		do_complete = __io_poll_remove_one(req, &apoll->poll);
 		if (do_complete) {
@@ -4524,6 +4528,7 @@ static bool io_poll_remove_one(struct io_kiocb *req)
 			if (req->flags & REQ_F_WORK_INITIALIZED)
 				memcpy(&req->work, &apoll->work,
 					sizeof(req->work));
+			kfree(apoll->double_poll);
 			kfree(apoll);
 		}
 	}
@@ -4624,7 +4629,7 @@ static void io_poll_queue_proc(struct file *file, struct wait_queue_head *head,
 {
 	struct io_poll_table *pt = container_of(p, struct io_poll_table, pt);
 
-	__io_queue_proc(&pt->req->poll, pt, head);
+	__io_queue_proc(&pt->req->poll, pt, head, (struct io_poll_iocb **) &pt->req->io);
 }
 
 static int io_poll_add_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)