Commit 47283bef authored by Michael S. Tsirkin's avatar Michael S. Tsirkin

vhost: move memory pointer to VQs

commit 2ae76693b8bcabf370b981cd00c36cd41d33fabc
    vhost: replace rcu with mutex
replaced rcu sync for memory accesses with VQ mutex locl/unlock.
This is correct since all accesses are under VQ mutex, but incomplete:
we still do useless rcu lock/unlock operations, someone might copy this
code into some other context where this won't be right.
This use of RCU is also non standard and hard to understand.
Let's copy the pointer to each VQ structure, this way
the access rules become straight-forward, and there's
no need for RCU anymore.
Reported-by: default avatarEric Dumazet <eric.dumazet@gmail.com>
Signed-off-by: default avatarMichael S. Tsirkin <mst@redhat.com>
parent ea16c514
...@@ -374,7 +374,7 @@ static void handle_tx(struct vhost_net *net) ...@@ -374,7 +374,7 @@ static void handle_tx(struct vhost_net *net)
% UIO_MAXIOV == nvq->done_idx)) % UIO_MAXIOV == nvq->done_idx))
break; break;
head = vhost_get_vq_desc(&net->dev, vq, vq->iov, head = vhost_get_vq_desc(vq, vq->iov,
ARRAY_SIZE(vq->iov), ARRAY_SIZE(vq->iov),
&out, &in, &out, &in,
NULL, NULL); NULL, NULL);
...@@ -506,7 +506,7 @@ static int get_rx_bufs(struct vhost_virtqueue *vq, ...@@ -506,7 +506,7 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,
r = -ENOBUFS; r = -ENOBUFS;
goto err; goto err;
} }
r = vhost_get_vq_desc(vq->dev, vq, vq->iov + seg, r = vhost_get_vq_desc(vq, vq->iov + seg,
ARRAY_SIZE(vq->iov) - seg, &out, ARRAY_SIZE(vq->iov) - seg, &out,
&in, log, log_num); &in, log, log_num);
if (unlikely(r < 0)) if (unlikely(r < 0))
......
...@@ -606,7 +606,7 @@ tcm_vhost_do_evt_work(struct vhost_scsi *vs, struct tcm_vhost_evt *evt) ...@@ -606,7 +606,7 @@ tcm_vhost_do_evt_work(struct vhost_scsi *vs, struct tcm_vhost_evt *evt)
again: again:
vhost_disable_notify(&vs->dev, vq); vhost_disable_notify(&vs->dev, vq);
head = vhost_get_vq_desc(&vs->dev, vq, vq->iov, head = vhost_get_vq_desc(vq, vq->iov,
ARRAY_SIZE(vq->iov), &out, &in, ARRAY_SIZE(vq->iov), &out, &in,
NULL, NULL); NULL, NULL);
if (head < 0) { if (head < 0) {
...@@ -945,7 +945,7 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq) ...@@ -945,7 +945,7 @@ vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
vhost_disable_notify(&vs->dev, vq); vhost_disable_notify(&vs->dev, vq);
for (;;) { for (;;) {
head = vhost_get_vq_desc(&vs->dev, vq, vq->iov, head = vhost_get_vq_desc(vq, vq->iov,
ARRAY_SIZE(vq->iov), &out, &in, ARRAY_SIZE(vq->iov), &out, &in,
NULL, NULL); NULL, NULL);
pr_debug("vhost_get_vq_desc: head: %d, out: %u in: %u\n", pr_debug("vhost_get_vq_desc: head: %d, out: %u in: %u\n",
......
...@@ -53,7 +53,7 @@ static void handle_vq(struct vhost_test *n) ...@@ -53,7 +53,7 @@ static void handle_vq(struct vhost_test *n)
vhost_disable_notify(&n->dev, vq); vhost_disable_notify(&n->dev, vq);
for (;;) { for (;;) {
head = vhost_get_vq_desc(&n->dev, vq, vq->iov, head = vhost_get_vq_desc(vq, vq->iov,
ARRAY_SIZE(vq->iov), ARRAY_SIZE(vq->iov),
&out, &in, &out, &in,
NULL, NULL); NULL, NULL);
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
#include <linux/mmu_context.h> #include <linux/mmu_context.h>
#include <linux/miscdevice.h> #include <linux/miscdevice.h>
#include <linux/mutex.h> #include <linux/mutex.h>
#include <linux/rcupdate.h>
#include <linux/poll.h> #include <linux/poll.h>
#include <linux/file.h> #include <linux/file.h>
#include <linux/highmem.h> #include <linux/highmem.h>
...@@ -199,6 +198,7 @@ static void vhost_vq_reset(struct vhost_dev *dev, ...@@ -199,6 +198,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
vq->call_ctx = NULL; vq->call_ctx = NULL;
vq->call = NULL; vq->call = NULL;
vq->log_ctx = NULL; vq->log_ctx = NULL;
vq->memory = NULL;
} }
static int vhost_worker(void *data) static int vhost_worker(void *data)
...@@ -416,11 +416,18 @@ EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare); ...@@ -416,11 +416,18 @@ EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);
/* Caller should have device mutex */ /* Caller should have device mutex */
void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory) void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory)
{ {
int i;
vhost_dev_cleanup(dev, true); vhost_dev_cleanup(dev, true);
/* Restore memory to default empty mapping. */ /* Restore memory to default empty mapping. */
memory->nregions = 0; memory->nregions = 0;
RCU_INIT_POINTER(dev->memory, memory); dev->memory = memory;
/* We don't need VQ locks below since vhost_dev_cleanup makes sure
* VQs aren't running.
*/
for (i = 0; i < dev->nvqs; ++i)
dev->vqs[i]->memory = memory;
} }
EXPORT_SYMBOL_GPL(vhost_dev_reset_owner); EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);
...@@ -463,10 +470,8 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked) ...@@ -463,10 +470,8 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
fput(dev->log_file); fput(dev->log_file);
dev->log_file = NULL; dev->log_file = NULL;
/* No one will access memory at this point */ /* No one will access memory at this point */
kfree(rcu_dereference_protected(dev->memory, kfree(dev->memory);
locked == dev->memory = NULL;
lockdep_is_held(&dev->mutex)));
RCU_INIT_POINTER(dev->memory, NULL);
WARN_ON(!list_empty(&dev->work_list)); WARN_ON(!list_empty(&dev->work_list));
if (dev->worker) { if (dev->worker) {
kthread_stop(dev->worker); kthread_stop(dev->worker);
...@@ -558,11 +563,7 @@ static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num, ...@@ -558,11 +563,7 @@ static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
/* Caller should have device mutex but not vq mutex */ /* Caller should have device mutex but not vq mutex */
int vhost_log_access_ok(struct vhost_dev *dev) int vhost_log_access_ok(struct vhost_dev *dev)
{ {
struct vhost_memory *mp; return memory_access_ok(dev, dev->memory, 1);
mp = rcu_dereference_protected(dev->memory,
lockdep_is_held(&dev->mutex));
return memory_access_ok(dev, mp, 1);
} }
EXPORT_SYMBOL_GPL(vhost_log_access_ok); EXPORT_SYMBOL_GPL(vhost_log_access_ok);
...@@ -571,12 +572,9 @@ EXPORT_SYMBOL_GPL(vhost_log_access_ok); ...@@ -571,12 +572,9 @@ EXPORT_SYMBOL_GPL(vhost_log_access_ok);
static int vq_log_access_ok(struct vhost_virtqueue *vq, static int vq_log_access_ok(struct vhost_virtqueue *vq,
void __user *log_base) void __user *log_base)
{ {
struct vhost_memory *mp;
size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
mp = rcu_dereference_protected(vq->dev->memory, return vq_memory_access_ok(log_base, vq->memory,
lockdep_is_held(&vq->mutex));
return vq_memory_access_ok(log_base, mp,
vhost_has_feature(vq, VHOST_F_LOG_ALL)) && vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
(!vq->log_used || log_access_ok(log_base, vq->log_addr, (!vq->log_used || log_access_ok(log_base, vq->log_addr,
sizeof *vq->used + sizeof *vq->used +
...@@ -619,15 +617,13 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) ...@@ -619,15 +617,13 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
kfree(newmem); kfree(newmem);
return -EFAULT; return -EFAULT;
} }
oldmem = rcu_dereference_protected(d->memory, oldmem = d->memory;
lockdep_is_held(&d->mutex)); d->memory = newmem;
rcu_assign_pointer(d->memory, newmem);
/* All memory accesses are done under some VQ mutex. /* All memory accesses are done under some VQ mutex. */
* So below is a faster equivalent of synchronize_rcu()
*/
for (i = 0; i < d->nvqs; ++i) { for (i = 0; i < d->nvqs; ++i) {
mutex_lock(&d->vqs[i]->mutex); mutex_lock(&d->vqs[i]->mutex);
d->vqs[i]->memory = newmem;
mutex_unlock(&d->vqs[i]->mutex); mutex_unlock(&d->vqs[i]->mutex);
} }
kfree(oldmem); kfree(oldmem);
...@@ -1054,7 +1050,7 @@ int vhost_init_used(struct vhost_virtqueue *vq) ...@@ -1054,7 +1050,7 @@ int vhost_init_used(struct vhost_virtqueue *vq)
} }
EXPORT_SYMBOL_GPL(vhost_init_used); EXPORT_SYMBOL_GPL(vhost_init_used);
static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len, static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
struct iovec iov[], int iov_size) struct iovec iov[], int iov_size)
{ {
const struct vhost_memory_region *reg; const struct vhost_memory_region *reg;
...@@ -1063,9 +1059,7 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len, ...@@ -1063,9 +1059,7 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
u64 s = 0; u64 s = 0;
int ret = 0; int ret = 0;
rcu_read_lock(); mem = vq->memory;
mem = rcu_dereference(dev->memory);
while ((u64)len > s) { while ((u64)len > s) {
u64 size; u64 size;
if (unlikely(ret >= iov_size)) { if (unlikely(ret >= iov_size)) {
...@@ -1087,7 +1081,6 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len, ...@@ -1087,7 +1081,6 @@ static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
++ret; ++ret;
} }
rcu_read_unlock();
return ret; return ret;
} }
...@@ -1112,7 +1105,7 @@ static unsigned next_desc(struct vring_desc *desc) ...@@ -1112,7 +1105,7 @@ static unsigned next_desc(struct vring_desc *desc)
return next; return next;
} }
static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, static int get_indirect(struct vhost_virtqueue *vq,
struct iovec iov[], unsigned int iov_size, struct iovec iov[], unsigned int iov_size,
unsigned int *out_num, unsigned int *in_num, unsigned int *out_num, unsigned int *in_num,
struct vhost_log *log, unsigned int *log_num, struct vhost_log *log, unsigned int *log_num,
...@@ -1131,7 +1124,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, ...@@ -1131,7 +1124,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
return -EINVAL; return -EINVAL;
} }
ret = translate_desc(dev, indirect->addr, indirect->len, vq->indirect, ret = translate_desc(vq, indirect->addr, indirect->len, vq->indirect,
UIO_MAXIOV); UIO_MAXIOV);
if (unlikely(ret < 0)) { if (unlikely(ret < 0)) {
vq_err(vq, "Translation failure %d in indirect.\n", ret); vq_err(vq, "Translation failure %d in indirect.\n", ret);
...@@ -1171,7 +1164,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, ...@@ -1171,7 +1164,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
return -EINVAL; return -EINVAL;
} }
ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count, ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count,
iov_size - iov_count); iov_size - iov_count);
if (unlikely(ret < 0)) { if (unlikely(ret < 0)) {
vq_err(vq, "Translation failure %d indirect idx %d\n", vq_err(vq, "Translation failure %d indirect idx %d\n",
...@@ -1208,7 +1201,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq, ...@@ -1208,7 +1201,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
* This function returns the descriptor number found, or vq->num (which is * This function returns the descriptor number found, or vq->num (which is
* never a valid descriptor number) if none was found. A negative code is * never a valid descriptor number) if none was found. A negative code is
* returned on error. */ * returned on error. */
int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, int vhost_get_vq_desc(struct vhost_virtqueue *vq,
struct iovec iov[], unsigned int iov_size, struct iovec iov[], unsigned int iov_size,
unsigned int *out_num, unsigned int *in_num, unsigned int *out_num, unsigned int *in_num,
struct vhost_log *log, unsigned int *log_num) struct vhost_log *log, unsigned int *log_num)
...@@ -1282,7 +1275,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, ...@@ -1282,7 +1275,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
return -EFAULT; return -EFAULT;
} }
if (desc.flags & VRING_DESC_F_INDIRECT) { if (desc.flags & VRING_DESC_F_INDIRECT) {
ret = get_indirect(dev, vq, iov, iov_size, ret = get_indirect(vq, iov, iov_size,
out_num, in_num, out_num, in_num,
log, log_num, &desc); log, log_num, &desc);
if (unlikely(ret < 0)) { if (unlikely(ret < 0)) {
...@@ -1293,7 +1286,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq, ...@@ -1293,7 +1286,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
continue; continue;
} }
ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count, ret = translate_desc(vq, desc.addr, desc.len, iov + iov_count,
iov_size - iov_count); iov_size - iov_count);
if (unlikely(ret < 0)) { if (unlikely(ret < 0)) {
vq_err(vq, "Translation failure %d descriptor idx %d\n", vq_err(vq, "Translation failure %d descriptor idx %d\n",
......
...@@ -104,6 +104,7 @@ struct vhost_virtqueue { ...@@ -104,6 +104,7 @@ struct vhost_virtqueue {
struct iovec *indirect; struct iovec *indirect;
struct vring_used_elem *heads; struct vring_used_elem *heads;
/* Protected by virtqueue mutex. */ /* Protected by virtqueue mutex. */
struct vhost_memory *memory;
void *private_data; void *private_data;
unsigned acked_features; unsigned acked_features;
/* Log write descriptors */ /* Log write descriptors */
...@@ -112,10 +113,7 @@ struct vhost_virtqueue { ...@@ -112,10 +113,7 @@ struct vhost_virtqueue {
}; };
struct vhost_dev { struct vhost_dev {
/* Readers use RCU to access memory table pointer struct vhost_memory *memory;
* log base pointer and features.
* Writers use mutex below.*/
struct vhost_memory __rcu *memory;
struct mm_struct *mm; struct mm_struct *mm;
struct mutex mutex; struct mutex mutex;
struct vhost_virtqueue **vqs; struct vhost_virtqueue **vqs;
...@@ -140,7 +138,7 @@ long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp); ...@@ -140,7 +138,7 @@ long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp);
int vhost_vq_access_ok(struct vhost_virtqueue *vq); int vhost_vq_access_ok(struct vhost_virtqueue *vq);
int vhost_log_access_ok(struct vhost_dev *); int vhost_log_access_ok(struct vhost_dev *);
int vhost_get_vq_desc(struct vhost_dev *, struct vhost_virtqueue *, int vhost_get_vq_desc(struct vhost_virtqueue *,
struct iovec iov[], unsigned int iov_count, struct iovec iov[], unsigned int iov_count,
unsigned int *out_num, unsigned int *in_num, unsigned int *out_num, unsigned int *in_num,
struct vhost_log *log, unsigned int *log_num); struct vhost_log *log, unsigned int *log_num);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment