Commit 2e315dc0 authored by Ming Lei, committed by Jens Axboe

blk-mq: grab rq->refcount before calling ->fn in blk_mq_tagset_busy_iter

Grab rq->refcount before calling ->fn in blk_mq_tagset_busy_iter(); this
prevents the request from being re-used while ->fn is running. The
approach is the same as the one taken when handling timeouts.

This fixes request use-after-free (UAF) caused by races with completion
or with queue release:

- If a request's reference is grabbed before rq->q is frozen, the queue
  freeze cannot complete until the reference is dropped during
  iteration, so the request cannot be released while ->fn runs.

- If rq->q is frozen before the request is examined,
  refcount_inc_not_zero() returns false and the request is skipped.

One request UAF is still not covered: refcount_inc_not_zero() itself may
read an already-freed request. That case is handled in the next patch.
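
As an aside for readers unfamiliar with the pattern, below is a minimal
userspace sketch of the get/put discipline this patch introduces, written
with C11 atomics instead of the kernel's refcount_t (req_get_not_zero()
and req_put() are illustrative names, not kernel API):

#include <stdatomic.h>
#include <stdbool.h>

struct req {
        atomic_int ref;         /* 1 while the request is in flight */
};

/* Like refcount_inc_not_zero(): take a reference only while the
 * object is still live (ref > 0); fail once it has hit zero. */
static bool req_get_not_zero(struct req *rq)
{
        int old = atomic_load(&rq->ref);

        do {
                if (old == 0)
                        return false;   /* already freed or being freed */
        } while (!atomic_compare_exchange_weak(&rq->ref, &old, old + 1));
        return true;
}

/* Like refcount_dec_and_test(): returns true to the caller that
 * drops the last reference and therefore must free the object. */
static bool req_put(struct req *rq)
{
        return atomic_fetch_sub(&rq->ref, 1) == 1;
}

An iterator that brackets its callback with req_get_not_zero()/req_put()
can never observe the object being freed underneath it; that is exactly
the guarantee bt_iter() and bt_tags_iter() gain below.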
Tested-by: John Garry <john.garry@huawei.com>
Reviewed-by: Christoph Hellwig <hch@lst.de>
Reviewed-by: Bart Van Assche <bvanassche@acm.org>
Signed-off-by: Ming Lei <ming.lei@redhat.com>
Link: https://lore.kernel.org/r/20210511152236.763464-3-ming.lei@redhat.com
Signed-off-by: Jens Axboe <axboe@kernel.dk>
parent 84da7acc
--- a/block/blk-mq-tag.c
+++ b/block/blk-mq-tag.c
@@ -199,6 +199,16 @@ struct bt_iter_data {
 	bool reserved;
 };
 
+static struct request *blk_mq_find_and_get_req(struct blk_mq_tags *tags,
+		unsigned int bitnr)
+{
+	struct request *rq = tags->rqs[bitnr];
+
+	if (!rq || !refcount_inc_not_zero(&rq->ref))
+		return NULL;
+	return rq;
+}
+
 static bool bt_iter(struct sbitmap *bitmap, unsigned int bitnr, void *data)
 {
 	struct bt_iter_data *iter_data = data;
@@ -206,18 +216,22 @@ static bool bt_iter(struct sbitmap *bitmap, unsigned int bitnr, void *data)
 	struct blk_mq_tags *tags = hctx->tags;
 	bool reserved = iter_data->reserved;
 	struct request *rq;
+	bool ret = true;
 
 	if (!reserved)
 		bitnr += tags->nr_reserved_tags;
-	rq = tags->rqs[bitnr];
-
 	/*
 	 * We can hit rq == NULL here, because the tagging functions
 	 * test and set the bit before assigning ->rqs[].
 	 */
-	if (rq && rq->q == hctx->queue && rq->mq_hctx == hctx)
-		return iter_data->fn(hctx, rq, iter_data->data, reserved);
-	return true;
+	rq = blk_mq_find_and_get_req(tags, bitnr);
+	if (!rq)
+		return true;
+
+	if (rq->q == hctx->queue && rq->mq_hctx == hctx)
+		ret = iter_data->fn(hctx, rq, iter_data->data, reserved);
+	blk_mq_put_rq_ref(rq);
+	return ret;
 }
 
 /**
@@ -264,6 +278,8 @@ static bool bt_tags_iter(struct sbitmap *bitmap, unsigned int bitnr, void *data)
 	struct blk_mq_tags *tags = iter_data->tags;
 	bool reserved = iter_data->flags & BT_TAG_ITER_RESERVED;
 	struct request *rq;
+	bool ret = true;
+	bool iter_static_rqs = !!(iter_data->flags & BT_TAG_ITER_STATIC_RQS);
 
 	if (!reserved)
 		bitnr += tags->nr_reserved_tags;
@@ -272,16 +288,19 @@ static bool bt_tags_iter(struct sbitmap *bitmap, unsigned int bitnr, void *data)
 	 * We can hit rq == NULL here, because the tagging functions
 	 * test and set the bit before assigning ->rqs[].
 	 */
-	if (iter_data->flags & BT_TAG_ITER_STATIC_RQS)
+	if (iter_static_rqs)
 		rq = tags->static_rqs[bitnr];
 	else
-		rq = tags->rqs[bitnr];
+		rq = blk_mq_find_and_get_req(tags, bitnr);
 	if (!rq)
 		return true;
-	if ((iter_data->flags & BT_TAG_ITER_STARTED) &&
-	    !blk_mq_request_started(rq))
-		return true;
-	return iter_data->fn(rq, iter_data->data, reserved);
+
+	if (!(iter_data->flags & BT_TAG_ITER_STARTED) ||
+	    blk_mq_request_started(rq))
+		ret = iter_data->fn(rq, iter_data->data, reserved);
+	if (!iter_static_rqs)
+		blk_mq_put_rq_ref(rq);
+	return ret;
 }
 
 /**
@@ -348,6 +367,9 @@ void blk_mq_all_tag_iter(struct blk_mq_tags *tags, busy_tag_iter_fn *fn,
  *		indicates whether or not @rq is a reserved request. Return
  *		true to continue iterating tags, false to stop.
  * @priv:	Will be passed as second argument to @fn.
+ *
+ * We grab one request reference before calling @fn and release it after
+ * @fn returns.
  */
 void blk_mq_tagset_busy_iter(struct blk_mq_tag_set *tagset,
 		busy_tag_iter_fn *fn, void *priv)
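
With the guarantee documented above, a busy_tag_iter_fn may dereference
the request without racing against a concurrent free. A hypothetical
caller could look like the sketch below (my_cancel_rq() and my_tag_set
are made-up names for illustration; nvme_cancel_request() is a real
in-tree example of such a callback):

static bool my_cancel_rq(struct request *rq, void *data, bool reserved)
{
        /* Safe: bt_tags_iter() holds a reference on rq across this call. */
        if (blk_mq_request_started(rq))
                blk_mq_complete_request(rq);
        return true;    /* continue iterating */
}

        blk_mq_tagset_busy_iter(&my_tag_set, my_cancel_rq, NULL);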
--- a/block/blk-mq.c
+++ b/block/blk-mq.c
@@ -909,6 +909,14 @@ static bool blk_mq_req_expired(struct request *rq, unsigned long *next)
 	return false;
 }
 
+void blk_mq_put_rq_ref(struct request *rq)
+{
+	if (is_flush_rq(rq, rq->mq_hctx))
+		rq->end_io(rq, 0);
+	else if (refcount_dec_and_test(&rq->ref))
+		__blk_mq_free_request(rq);
+}
+
 static bool blk_mq_check_expired(struct blk_mq_hw_ctx *hctx,
 		struct request *rq, void *priv, bool reserved)
 {
@@ -942,11 +950,7 @@ static bool blk_mq_check_expired(struct blk_mq_hw_ctx *hctx,
 	if (blk_mq_req_expired(rq, next))
 		blk_mq_rq_timed_out(rq, reserved);
 
-	if (is_flush_rq(rq, hctx))
-		rq->end_io(rq, 0);
-	else if (refcount_dec_and_test(&rq->ref))
-		__blk_mq_free_request(rq);
-
+	blk_mq_put_rq_ref(rq);
 	return true;
 }
 
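
blk_mq_put_rq_ref() is the drop side of the new discipline: every
reference obtained through blk_mq_find_and_get_req() must be released
through it rather than through a bare refcount_dec_and_test(), because
flush requests manage their lifetime through ->end_io() rather than
through __blk_mq_free_request(). A sketch of the pairing (examine() is a
hypothetical stand-in for per-request work, not a call site added by this
patch):

        struct request *rq = blk_mq_find_and_get_req(tags, bitnr);

        if (rq) {
                /* rq cannot be freed or re-used while the reference is held */
                examine(rq);
                blk_mq_put_rq_ref(rq);  /* may run ->end_io() for a flush rq */
        }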
--- a/block/blk-mq.h
+++ b/block/blk-mq.h
@@ -47,6 +47,7 @@ void blk_mq_add_to_requeue_list(struct request *rq, bool at_head,
 void blk_mq_flush_busy_ctxs(struct blk_mq_hw_ctx *hctx, struct list_head *list);
 struct request *blk_mq_dequeue_from_ctx(struct blk_mq_hw_ctx *hctx,
 					struct blk_mq_ctx *start);
+void blk_mq_put_rq_ref(struct request *rq);
 
 /*
  * Internal helpers for allocating/freeing the request map