Commit 0803e040 authored by Linus Torvalds's avatar Linus Torvalds

Merge tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost

Pull virtio/vhost updates from Michael Tsirkin:

 - new vsock device support in host and guest

 - platform IOMMU support in host and guest, including compatibility
   quirks for legacy systems.

 - misc fixes and cleanups.

* tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost:
  VSOCK: Use kvfree()
  vhost: split out vringh Kconfig
  vhost: detect 32 bit integer wrap around
  vhost: new device IOTLB API
  vhost: drop vringh dependency
  vhost: convert pre sorted vhost memory array to interval tree
  vhost: introduce vhost memory accessors
  VSOCK: Add Makefile and Kconfig
  VSOCK: Introduce vhost_vsock.ko
  VSOCK: Introduce virtio_transport.ko
  VSOCK: Introduce virtio_vsock_common.ko
  VSOCK: defer sock removal to transports
  VSOCK: transport-specific vsock_transport functions
  vhost: drop vringh dependency
  vop: pull in vhost Kconfig
  virtio: new feature to detect IOMMU device quirk
  balloon: check the number of available pages in leak balloon
  vhost: lockless enqueuing
  vhost: simplify work flushing
parents 80fac0f5 b226acab
...@@ -12419,6 +12419,19 @@ S: Maintained ...@@ -12419,6 +12419,19 @@ S: Maintained
F: drivers/media/v4l2-core/videobuf2-* F: drivers/media/v4l2-core/videobuf2-*
F: include/media/videobuf2-* F: include/media/videobuf2-*
VIRTIO AND VHOST VSOCK DRIVER
M: Stefan Hajnoczi <stefanha@redhat.com>
L: kvm@vger.kernel.org
L: virtualization@lists.linux-foundation.org
L: netdev@vger.kernel.org
S: Maintained
F: include/linux/virtio_vsock.h
F: include/uapi/linux/virtio_vsock.h
F: net/vmw_vsock/virtio_transport_common.c
F: net/vmw_vsock/virtio_transport.c
F: drivers/vhost/vsock.c
F: drivers/vhost/vsock.h
VIRTUAL SERIO DEVICE DRIVER VIRTUAL SERIO DEVICE DRIVER
M: Stephen Chandler Paul <thatslyude@gmail.com> M: Stephen Chandler Paul <thatslyude@gmail.com>
S: Maintained S: Maintained
......
...@@ -138,6 +138,7 @@ obj-$(CONFIG_OF) += of/ ...@@ -138,6 +138,7 @@ obj-$(CONFIG_OF) += of/
obj-$(CONFIG_SSB) += ssb/ obj-$(CONFIG_SSB) += ssb/
obj-$(CONFIG_BCMA) += bcma/ obj-$(CONFIG_BCMA) += bcma/
obj-$(CONFIG_VHOST_RING) += vhost/ obj-$(CONFIG_VHOST_RING) += vhost/
obj-$(CONFIG_VHOST) += vhost/
obj-$(CONFIG_VLYNQ) += vlynq/ obj-$(CONFIG_VLYNQ) += vlynq/
obj-$(CONFIG_STAGING) += staging/ obj-$(CONFIG_STAGING) += staging/
obj-y += platform/ obj-y += platform/
......
...@@ -146,3 +146,7 @@ config VOP ...@@ -146,3 +146,7 @@ config VOP
More information about the Intel MIC family as well as the Linux More information about the Intel MIC family as well as the Linux
OS and tools for MIC to use with this driver are available from OS and tools for MIC to use with this driver are available from
<http://software.intel.com/en-us/mic-developer>. <http://software.intel.com/en-us/mic-developer>.
if VOP
source "drivers/vhost/Kconfig.vringh"
endif
...@@ -52,5 +52,5 @@ config CAIF_VIRTIO ...@@ -52,5 +52,5 @@ config CAIF_VIRTIO
The caif driver for CAIF over Virtio. The caif driver for CAIF over Virtio.
if CAIF_VIRTIO if CAIF_VIRTIO
source "drivers/vhost/Kconfig" source "drivers/vhost/Kconfig.vringh"
endif endif
...@@ -2,7 +2,6 @@ config VHOST_NET ...@@ -2,7 +2,6 @@ config VHOST_NET
tristate "Host kernel accelerator for virtio net" tristate "Host kernel accelerator for virtio net"
depends on NET && EVENTFD && (TUN || !TUN) && (MACVTAP || !MACVTAP) depends on NET && EVENTFD && (TUN || !TUN) && (MACVTAP || !MACVTAP)
select VHOST select VHOST
select VHOST_RING
---help--- ---help---
This kernel module can be loaded in host kernel to accelerate This kernel module can be loaded in host kernel to accelerate
guest networking with virtio_net. Not to be confused with virtio_net guest networking with virtio_net. Not to be confused with virtio_net
...@@ -15,17 +14,24 @@ config VHOST_SCSI ...@@ -15,17 +14,24 @@ config VHOST_SCSI
tristate "VHOST_SCSI TCM fabric driver" tristate "VHOST_SCSI TCM fabric driver"
depends on TARGET_CORE && EVENTFD && m depends on TARGET_CORE && EVENTFD && m
select VHOST select VHOST
select VHOST_RING
default n default n
---help--- ---help---
Say M here to enable the vhost_scsi TCM fabric module Say M here to enable the vhost_scsi TCM fabric module
for use with virtio-scsi guests for use with virtio-scsi guests
config VHOST_RING config VHOST_VSOCK
tristate tristate "vhost virtio-vsock driver"
depends on VSOCKETS && EVENTFD
select VIRTIO_VSOCKETS_COMMON
select VHOST
default n
---help--- ---help---
This option is selected by any driver which needs to access This kernel module can be loaded in the host kernel to provide AF_VSOCK
the host side of a virtio ring. sockets for communicating with guests. The guests must have the
virtio_transport.ko driver loaded to use the virtio-vsock device.
To compile this driver as a module, choose M here: the module will be called
vhost_vsock.
config VHOST config VHOST
tristate tristate
......
config VHOST_RING
tristate
---help---
This option is selected by any driver which needs to access
the host side of a virtio ring.
...@@ -4,5 +4,9 @@ vhost_net-y := net.o ...@@ -4,5 +4,9 @@ vhost_net-y := net.o
obj-$(CONFIG_VHOST_SCSI) += vhost_scsi.o obj-$(CONFIG_VHOST_SCSI) += vhost_scsi.o
vhost_scsi-y := scsi.o vhost_scsi-y := scsi.o
obj-$(CONFIG_VHOST_VSOCK) += vhost_vsock.o
vhost_vsock-y := vsock.o
obj-$(CONFIG_VHOST_RING) += vringh.o obj-$(CONFIG_VHOST_RING) += vringh.o
obj-$(CONFIG_VHOST) += vhost.o obj-$(CONFIG_VHOST) += vhost.o
...@@ -61,7 +61,8 @@ MODULE_PARM_DESC(experimental_zcopytx, "Enable Zero Copy TX;" ...@@ -61,7 +61,8 @@ MODULE_PARM_DESC(experimental_zcopytx, "Enable Zero Copy TX;"
enum { enum {
VHOST_NET_FEATURES = VHOST_FEATURES | VHOST_NET_FEATURES = VHOST_FEATURES |
(1ULL << VHOST_NET_F_VIRTIO_NET_HDR) | (1ULL << VHOST_NET_F_VIRTIO_NET_HDR) |
(1ULL << VIRTIO_NET_F_MRG_RXBUF) (1ULL << VIRTIO_NET_F_MRG_RXBUF) |
(1ULL << VIRTIO_F_IOMMU_PLATFORM)
}; };
enum { enum {
...@@ -377,6 +378,9 @@ static void handle_tx(struct vhost_net *net) ...@@ -377,6 +378,9 @@ static void handle_tx(struct vhost_net *net)
if (!sock) if (!sock)
goto out; goto out;
if (!vq_iotlb_prefetch(vq))
goto out;
vhost_disable_notify(&net->dev, vq); vhost_disable_notify(&net->dev, vq);
hdr_size = nvq->vhost_hlen; hdr_size = nvq->vhost_hlen;
...@@ -652,6 +656,10 @@ static void handle_rx(struct vhost_net *net) ...@@ -652,6 +656,10 @@ static void handle_rx(struct vhost_net *net)
sock = vq->private_data; sock = vq->private_data;
if (!sock) if (!sock)
goto out; goto out;
if (!vq_iotlb_prefetch(vq))
goto out;
vhost_disable_notify(&net->dev, vq); vhost_disable_notify(&net->dev, vq);
vhost_net_disable_vq(net, vq); vhost_net_disable_vq(net, vq);
...@@ -1052,20 +1060,20 @@ static long vhost_net_reset_owner(struct vhost_net *n) ...@@ -1052,20 +1060,20 @@ static long vhost_net_reset_owner(struct vhost_net *n)
struct socket *tx_sock = NULL; struct socket *tx_sock = NULL;
struct socket *rx_sock = NULL; struct socket *rx_sock = NULL;
long err; long err;
struct vhost_memory *memory; struct vhost_umem *umem;
mutex_lock(&n->dev.mutex); mutex_lock(&n->dev.mutex);
err = vhost_dev_check_owner(&n->dev); err = vhost_dev_check_owner(&n->dev);
if (err) if (err)
goto done; goto done;
memory = vhost_dev_reset_owner_prepare(); umem = vhost_dev_reset_owner_prepare();
if (!memory) { if (!umem) {
err = -ENOMEM; err = -ENOMEM;
goto done; goto done;
} }
vhost_net_stop(n, &tx_sock, &rx_sock); vhost_net_stop(n, &tx_sock, &rx_sock);
vhost_net_flush(n); vhost_net_flush(n);
vhost_dev_reset_owner(&n->dev, memory); vhost_dev_reset_owner(&n->dev, umem);
vhost_net_vq_reset(n); vhost_net_vq_reset(n);
done: done:
mutex_unlock(&n->dev.mutex); mutex_unlock(&n->dev.mutex);
...@@ -1096,10 +1104,14 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features) ...@@ -1096,10 +1104,14 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features)
} }
mutex_lock(&n->dev.mutex); mutex_lock(&n->dev.mutex);
if ((features & (1 << VHOST_F_LOG_ALL)) && if ((features & (1 << VHOST_F_LOG_ALL)) &&
!vhost_log_access_ok(&n->dev)) { !vhost_log_access_ok(&n->dev))
mutex_unlock(&n->dev.mutex); goto out_unlock;
return -EFAULT;
if ((features & (1ULL << VIRTIO_F_IOMMU_PLATFORM))) {
if (vhost_init_device_iotlb(&n->dev, true))
goto out_unlock;
} }
for (i = 0; i < VHOST_NET_VQ_MAX; ++i) { for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
mutex_lock(&n->vqs[i].vq.mutex); mutex_lock(&n->vqs[i].vq.mutex);
n->vqs[i].vq.acked_features = features; n->vqs[i].vq.acked_features = features;
...@@ -1109,6 +1121,10 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features) ...@@ -1109,6 +1121,10 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features)
} }
mutex_unlock(&n->dev.mutex); mutex_unlock(&n->dev.mutex);
return 0; return 0;
out_unlock:
mutex_unlock(&n->dev.mutex);
return -EFAULT;
} }
static long vhost_net_set_owner(struct vhost_net *n) static long vhost_net_set_owner(struct vhost_net *n)
...@@ -1182,9 +1198,40 @@ static long vhost_net_compat_ioctl(struct file *f, unsigned int ioctl, ...@@ -1182,9 +1198,40 @@ static long vhost_net_compat_ioctl(struct file *f, unsigned int ioctl,
} }
#endif #endif
static ssize_t vhost_net_chr_read_iter(struct kiocb *iocb, struct iov_iter *to)
{
struct file *file = iocb->ki_filp;
struct vhost_net *n = file->private_data;
struct vhost_dev *dev = &n->dev;
int noblock = file->f_flags & O_NONBLOCK;
return vhost_chr_read_iter(dev, to, noblock);
}
static ssize_t vhost_net_chr_write_iter(struct kiocb *iocb,
struct iov_iter *from)
{
struct file *file = iocb->ki_filp;
struct vhost_net *n = file->private_data;
struct vhost_dev *dev = &n->dev;
return vhost_chr_write_iter(dev, from);
}
static unsigned int vhost_net_chr_poll(struct file *file, poll_table *wait)
{
struct vhost_net *n = file->private_data;
struct vhost_dev *dev = &n->dev;
return vhost_chr_poll(file, dev, wait);
}
static const struct file_operations vhost_net_fops = { static const struct file_operations vhost_net_fops = {
.owner = THIS_MODULE, .owner = THIS_MODULE,
.release = vhost_net_release, .release = vhost_net_release,
.read_iter = vhost_net_chr_read_iter,
.write_iter = vhost_net_chr_write_iter,
.poll = vhost_net_chr_poll,
.unlocked_ioctl = vhost_net_ioctl, .unlocked_ioctl = vhost_net_ioctl,
#ifdef CONFIG_COMPAT #ifdef CONFIG_COMPAT
.compat_ioctl = vhost_net_compat_ioctl, .compat_ioctl = vhost_net_compat_ioctl,
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <linux/cgroup.h> #include <linux/cgroup.h>
#include <linux/module.h> #include <linux/module.h>
#include <linux/sort.h> #include <linux/sort.h>
#include <linux/interval_tree_generic.h>
#include "vhost.h" #include "vhost.h"
...@@ -34,6 +35,10 @@ static ushort max_mem_regions = 64; ...@@ -34,6 +35,10 @@ static ushort max_mem_regions = 64;
module_param(max_mem_regions, ushort, 0444); module_param(max_mem_regions, ushort, 0444);
MODULE_PARM_DESC(max_mem_regions, MODULE_PARM_DESC(max_mem_regions,
"Maximum number of memory regions in memory map. (default: 64)"); "Maximum number of memory regions in memory map. (default: 64)");
static int max_iotlb_entries = 2048;
module_param(max_iotlb_entries, int, 0444);
MODULE_PARM_DESC(max_iotlb_entries,
"Maximum number of iotlb entries. (default: 2048)");
enum { enum {
VHOST_MEMORY_F_LOG = 0x1, VHOST_MEMORY_F_LOG = 0x1,
...@@ -42,6 +47,10 @@ enum { ...@@ -42,6 +47,10 @@ enum {
#define vhost_used_event(vq) ((__virtio16 __user *)&vq->avail->ring[vq->num]) #define vhost_used_event(vq) ((__virtio16 __user *)&vq->avail->ring[vq->num])
#define vhost_avail_event(vq) ((__virtio16 __user *)&vq->used->ring[vq->num]) #define vhost_avail_event(vq) ((__virtio16 __user *)&vq->used->ring[vq->num])
INTERVAL_TREE_DEFINE(struct vhost_umem_node,
rb, __u64, __subtree_last,
START, LAST, , vhost_umem_interval_tree);
#ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY #ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY
static void vhost_disable_cross_endian(struct vhost_virtqueue *vq) static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
{ {
...@@ -131,6 +140,19 @@ static void vhost_reset_is_le(struct vhost_virtqueue *vq) ...@@ -131,6 +140,19 @@ static void vhost_reset_is_le(struct vhost_virtqueue *vq)
vq->is_le = virtio_legacy_is_little_endian(); vq->is_le = virtio_legacy_is_little_endian();
} }
struct vhost_flush_struct {
struct vhost_work work;
struct completion wait_event;
};
static void vhost_flush_work(struct vhost_work *work)
{
struct vhost_flush_struct *s;
s = container_of(work, struct vhost_flush_struct, work);
complete(&s->wait_event);
}
static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh, static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh,
poll_table *pt) poll_table *pt)
{ {
...@@ -155,11 +177,9 @@ static int vhost_poll_wakeup(wait_queue_t *wait, unsigned mode, int sync, ...@@ -155,11 +177,9 @@ static int vhost_poll_wakeup(wait_queue_t *wait, unsigned mode, int sync,
void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn) void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn)
{ {
INIT_LIST_HEAD(&work->node); clear_bit(VHOST_WORK_QUEUED, &work->flags);
work->fn = fn; work->fn = fn;
init_waitqueue_head(&work->done); init_waitqueue_head(&work->done);
work->flushing = 0;
work->queue_seq = work->done_seq = 0;
} }
EXPORT_SYMBOL_GPL(vhost_work_init); EXPORT_SYMBOL_GPL(vhost_work_init);
...@@ -211,31 +231,17 @@ void vhost_poll_stop(struct vhost_poll *poll) ...@@ -211,31 +231,17 @@ void vhost_poll_stop(struct vhost_poll *poll)
} }
EXPORT_SYMBOL_GPL(vhost_poll_stop); EXPORT_SYMBOL_GPL(vhost_poll_stop);
static bool vhost_work_seq_done(struct vhost_dev *dev, struct vhost_work *work,
unsigned seq)
{
int left;
spin_lock_irq(&dev->work_lock);
left = seq - work->done_seq;
spin_unlock_irq(&dev->work_lock);
return left <= 0;
}
void vhost_work_flush(struct vhost_dev *dev, struct vhost_work *work) void vhost_work_flush(struct vhost_dev *dev, struct vhost_work *work)
{ {
unsigned seq; struct vhost_flush_struct flush;
int flushing;
spin_lock_irq(&dev->work_lock); if (dev->worker) {
seq = work->queue_seq; init_completion(&flush.wait_event);
work->flushing++; vhost_work_init(&flush.work, vhost_flush_work);
spin_unlock_irq(&dev->work_lock);
wait_event(work->done, vhost_work_seq_done(dev, work, seq)); vhost_work_queue(dev, &flush.work);
spin_lock_irq(&dev->work_lock); wait_for_completion(&flush.wait_event);
flushing = --work->flushing; }
spin_unlock_irq(&dev->work_lock);
BUG_ON(flushing < 0);
} }
EXPORT_SYMBOL_GPL(vhost_work_flush); EXPORT_SYMBOL_GPL(vhost_work_flush);
...@@ -249,16 +255,16 @@ EXPORT_SYMBOL_GPL(vhost_poll_flush); ...@@ -249,16 +255,16 @@ EXPORT_SYMBOL_GPL(vhost_poll_flush);
void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work) void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work)
{ {
unsigned long flags; if (!dev->worker)
return;
spin_lock_irqsave(&dev->work_lock, flags); if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) {
if (list_empty(&work->node)) { /* We can only add the work to the list after we're
list_add_tail(&work->node, &dev->work_list); * sure it was not in the list.
work->queue_seq++; */
spin_unlock_irqrestore(&dev->work_lock, flags); smp_mb();
llist_add(&work->node, &dev->work_list);
wake_up_process(dev->worker); wake_up_process(dev->worker);
} else {
spin_unlock_irqrestore(&dev->work_lock, flags);
} }
} }
EXPORT_SYMBOL_GPL(vhost_work_queue); EXPORT_SYMBOL_GPL(vhost_work_queue);
...@@ -266,7 +272,7 @@ EXPORT_SYMBOL_GPL(vhost_work_queue); ...@@ -266,7 +272,7 @@ EXPORT_SYMBOL_GPL(vhost_work_queue);
/* A lockless hint for busy polling code to exit the loop */ /* A lockless hint for busy polling code to exit the loop */
bool vhost_has_work(struct vhost_dev *dev) bool vhost_has_work(struct vhost_dev *dev)
{ {
return !list_empty(&dev->work_list); return !llist_empty(&dev->work_list);
} }
EXPORT_SYMBOL_GPL(vhost_has_work); EXPORT_SYMBOL_GPL(vhost_has_work);
...@@ -300,17 +306,18 @@ static void vhost_vq_reset(struct vhost_dev *dev, ...@@ -300,17 +306,18 @@ static void vhost_vq_reset(struct vhost_dev *dev,
vq->call_ctx = NULL; vq->call_ctx = NULL;
vq->call = NULL; vq->call = NULL;
vq->log_ctx = NULL; vq->log_ctx = NULL;
vq->memory = NULL;
vhost_reset_is_le(vq); vhost_reset_is_le(vq);
vhost_disable_cross_endian(vq); vhost_disable_cross_endian(vq);
vq->busyloop_timeout = 0; vq->busyloop_timeout = 0;
vq->umem = NULL;
vq->iotlb = NULL;
} }
static int vhost_worker(void *data) static int vhost_worker(void *data)
{ {
struct vhost_dev *dev = data; struct vhost_dev *dev = data;
struct vhost_work *work = NULL; struct vhost_work *work, *work_next;
unsigned uninitialized_var(seq); struct llist_node *node;
mm_segment_t oldfs = get_fs(); mm_segment_t oldfs = get_fs();
set_fs(USER_DS); set_fs(USER_DS);
...@@ -320,35 +327,25 @@ static int vhost_worker(void *data) ...@@ -320,35 +327,25 @@ static int vhost_worker(void *data)
/* mb paired w/ kthread_stop */ /* mb paired w/ kthread_stop */
set_current_state(TASK_INTERRUPTIBLE); set_current_state(TASK_INTERRUPTIBLE);
spin_lock_irq(&dev->work_lock);
if (work) {
work->done_seq = seq;
if (work->flushing)
wake_up_all(&work->done);
}
if (kthread_should_stop()) { if (kthread_should_stop()) {
spin_unlock_irq(&dev->work_lock);
__set_current_state(TASK_RUNNING); __set_current_state(TASK_RUNNING);
break; break;
} }
if (!list_empty(&dev->work_list)) {
work = list_first_entry(&dev->work_list,
struct vhost_work, node);
list_del_init(&work->node);
seq = work->queue_seq;
} else
work = NULL;
spin_unlock_irq(&dev->work_lock);
if (work) { node = llist_del_all(&dev->work_list);
if (!node)
schedule();
node = llist_reverse_order(node);
/* make sure flag is seen after deletion */
smp_wmb();
llist_for_each_entry_safe(work, work_next, node, node) {
clear_bit(VHOST_WORK_QUEUED, &work->flags);
__set_current_state(TASK_RUNNING); __set_current_state(TASK_RUNNING);
work->fn(work); work->fn(work);
if (need_resched()) if (need_resched())
schedule(); schedule();
} else }
schedule();
} }
unuse_mm(dev->mm); unuse_mm(dev->mm);
set_fs(oldfs); set_fs(oldfs);
...@@ -407,11 +404,16 @@ void vhost_dev_init(struct vhost_dev *dev, ...@@ -407,11 +404,16 @@ void vhost_dev_init(struct vhost_dev *dev,
mutex_init(&dev->mutex); mutex_init(&dev->mutex);
dev->log_ctx = NULL; dev->log_ctx = NULL;
dev->log_file = NULL; dev->log_file = NULL;
dev->memory = NULL; dev->umem = NULL;
dev->iotlb = NULL;
dev->mm = NULL; dev->mm = NULL;
spin_lock_init(&dev->work_lock);
INIT_LIST_HEAD(&dev->work_list);
dev->worker = NULL; dev->worker = NULL;
init_llist_head(&dev->work_list);
init_waitqueue_head(&dev->wait);
INIT_LIST_HEAD(&dev->read_list);
INIT_LIST_HEAD(&dev->pending_list);
spin_lock_init(&dev->iotlb_lock);
for (i = 0; i < dev->nvqs; ++i) { for (i = 0; i < dev->nvqs; ++i) {
vq = dev->vqs[i]; vq = dev->vqs[i];
...@@ -512,27 +514,36 @@ long vhost_dev_set_owner(struct vhost_dev *dev) ...@@ -512,27 +514,36 @@ long vhost_dev_set_owner(struct vhost_dev *dev)
} }
EXPORT_SYMBOL_GPL(vhost_dev_set_owner); EXPORT_SYMBOL_GPL(vhost_dev_set_owner);
struct vhost_memory *vhost_dev_reset_owner_prepare(void) static void *vhost_kvzalloc(unsigned long size)
{
void *n = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);
if (!n)
n = vzalloc(size);
return n;
}
struct vhost_umem *vhost_dev_reset_owner_prepare(void)
{ {
return kmalloc(offsetof(struct vhost_memory, regions), GFP_KERNEL); return vhost_kvzalloc(sizeof(struct vhost_umem));
} }
EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare); EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);
/* Caller should have device mutex */ /* Caller should have device mutex */
void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_memory *memory) void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_umem *umem)
{ {
int i; int i;
vhost_dev_cleanup(dev, true); vhost_dev_cleanup(dev, true);
/* Restore memory to default empty mapping. */ /* Restore memory to default empty mapping. */
memory->nregions = 0; INIT_LIST_HEAD(&umem->umem_list);
dev->memory = memory; dev->umem = umem;
/* We don't need VQ locks below since vhost_dev_cleanup makes sure /* We don't need VQ locks below since vhost_dev_cleanup makes sure
* VQs aren't running. * VQs aren't running.
*/ */
for (i = 0; i < dev->nvqs; ++i) for (i = 0; i < dev->nvqs; ++i)
dev->vqs[i]->memory = memory; dev->vqs[i]->umem = umem;
} }
EXPORT_SYMBOL_GPL(vhost_dev_reset_owner); EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);
...@@ -549,6 +560,47 @@ void vhost_dev_stop(struct vhost_dev *dev) ...@@ -549,6 +560,47 @@ void vhost_dev_stop(struct vhost_dev *dev)
} }
EXPORT_SYMBOL_GPL(vhost_dev_stop); EXPORT_SYMBOL_GPL(vhost_dev_stop);
static void vhost_umem_free(struct vhost_umem *umem,
struct vhost_umem_node *node)
{
vhost_umem_interval_tree_remove(node, &umem->umem_tree);
list_del(&node->link);
kfree(node);
umem->numem--;
}
static void vhost_umem_clean(struct vhost_umem *umem)
{
struct vhost_umem_node *node, *tmp;
if (!umem)
return;
list_for_each_entry_safe(node, tmp, &umem->umem_list, link)
vhost_umem_free(umem, node);
kvfree(umem);
}
static void vhost_clear_msg(struct vhost_dev *dev)
{
struct vhost_msg_node *node, *n;
spin_lock(&dev->iotlb_lock);
list_for_each_entry_safe(node, n, &dev->read_list, node) {
list_del(&node->node);
kfree(node);
}
list_for_each_entry_safe(node, n, &dev->pending_list, node) {
list_del(&node->node);
kfree(node);
}
spin_unlock(&dev->iotlb_lock);
}
/* Caller should have device mutex if and only if locked is set */ /* Caller should have device mutex if and only if locked is set */
void vhost_dev_cleanup(struct vhost_dev *dev, bool locked) void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
{ {
...@@ -575,9 +627,13 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked) ...@@ -575,9 +627,13 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
fput(dev->log_file); fput(dev->log_file);
dev->log_file = NULL; dev->log_file = NULL;
/* No one will access memory at this point */ /* No one will access memory at this point */
kvfree(dev->memory); vhost_umem_clean(dev->umem);
dev->memory = NULL; dev->umem = NULL;
WARN_ON(!list_empty(&dev->work_list)); vhost_umem_clean(dev->iotlb);
dev->iotlb = NULL;
vhost_clear_msg(dev);
wake_up_interruptible_poll(&dev->wait, POLLIN | POLLRDNORM);
WARN_ON(!llist_empty(&dev->work_list));
if (dev->worker) { if (dev->worker) {
kthread_stop(dev->worker); kthread_stop(dev->worker);
dev->worker = NULL; dev->worker = NULL;
...@@ -601,26 +657,34 @@ static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz) ...@@ -601,26 +657,34 @@ static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
(sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8); (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
} }
static bool vhost_overflow(u64 uaddr, u64 size)
{
/* Make sure 64 bit math will not overflow. */
return uaddr > ULONG_MAX || size > ULONG_MAX || uaddr > ULONG_MAX - size;
}
/* Caller should have vq mutex and device mutex. */ /* Caller should have vq mutex and device mutex. */
static int vq_memory_access_ok(void __user *log_base, struct vhost_memory *mem, static int vq_memory_access_ok(void __user *log_base, struct vhost_umem *umem,
int log_all) int log_all)
{ {
int i; struct vhost_umem_node *node;
if (!mem) if (!umem)
return 0; return 0;
for (i = 0; i < mem->nregions; ++i) { list_for_each_entry(node, &umem->umem_list, link) {
struct vhost_memory_region *m = mem->regions + i; unsigned long a = node->userspace_addr;
unsigned long a = m->userspace_addr;
if (m->memory_size > ULONG_MAX) if (vhost_overflow(node->userspace_addr, node->size))
return 0; return 0;
else if (!access_ok(VERIFY_WRITE, (void __user *)a,
m->memory_size))
if (!access_ok(VERIFY_WRITE, (void __user *)a,
node->size))
return 0; return 0;
else if (log_all && !log_access_ok(log_base, else if (log_all && !log_access_ok(log_base,
m->guest_phys_addr, node->start,
m->memory_size)) node->size))
return 0; return 0;
} }
return 1; return 1;
...@@ -628,7 +692,7 @@ static int vq_memory_access_ok(void __user *log_base, struct vhost_memory *mem, ...@@ -628,7 +692,7 @@ static int vq_memory_access_ok(void __user *log_base, struct vhost_memory *mem,
/* Can we switch to this memory table? */ /* Can we switch to this memory table? */
/* Caller should have device mutex but not vq mutex */ /* Caller should have device mutex but not vq mutex */
static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem, static int memory_access_ok(struct vhost_dev *d, struct vhost_umem *umem,
int log_all) int log_all)
{ {
int i; int i;
...@@ -641,7 +705,8 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem, ...@@ -641,7 +705,8 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem,
log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL); log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL);
/* If ring is inactive, will check when it's enabled. */ /* If ring is inactive, will check when it's enabled. */
if (d->vqs[i]->private_data) if (d->vqs[i]->private_data)
ok = vq_memory_access_ok(d->vqs[i]->log_base, mem, log); ok = vq_memory_access_ok(d->vqs[i]->log_base,
umem, log);
else else
ok = 1; ok = 1;
mutex_unlock(&d->vqs[i]->mutex); mutex_unlock(&d->vqs[i]->mutex);
...@@ -651,12 +716,385 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem, ...@@ -651,12 +716,385 @@ static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem,
return 1; return 1;
} }
static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
struct iovec iov[], int iov_size, int access);
static int vhost_copy_to_user(struct vhost_virtqueue *vq, void *to,
const void *from, unsigned size)
{
int ret;
if (!vq->iotlb)
return __copy_to_user(to, from, size);
else {
/* This function should be called after iotlb
* prefetch, which means we're sure that all vq
* could be access through iotlb. So -EAGAIN should
* not happen in this case.
*/
/* TODO: more fast path */
struct iov_iter t;
ret = translate_desc(vq, (u64)(uintptr_t)to, size, vq->iotlb_iov,
ARRAY_SIZE(vq->iotlb_iov),
VHOST_ACCESS_WO);
if (ret < 0)
goto out;
iov_iter_init(&t, WRITE, vq->iotlb_iov, ret, size);
ret = copy_to_iter(from, size, &t);
if (ret == size)
ret = 0;
}
out:
return ret;
}
static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to,
void *from, unsigned size)
{
int ret;
if (!vq->iotlb)
return __copy_from_user(to, from, size);
else {
/* This function should be called after iotlb
* prefetch, which means we're sure that vq
* could be access through iotlb. So -EAGAIN should
* not happen in this case.
*/
/* TODO: more fast path */
struct iov_iter f;
ret = translate_desc(vq, (u64)(uintptr_t)from, size, vq->iotlb_iov,
ARRAY_SIZE(vq->iotlb_iov),
VHOST_ACCESS_RO);
if (ret < 0) {
vq_err(vq, "IOTLB translation failure: uaddr "
"%p size 0x%llx\n", from,
(unsigned long long) size);
goto out;
}
iov_iter_init(&f, READ, vq->iotlb_iov, ret, size);
ret = copy_from_iter(to, size, &f);
if (ret == size)
ret = 0;
}
out:
return ret;
}
static void __user *__vhost_get_user(struct vhost_virtqueue *vq,
void *addr, unsigned size)
{
int ret;
/* This function should be called after iotlb
* prefetch, which means we're sure that vq
* could be access through iotlb. So -EAGAIN should
* not happen in this case.
*/
/* TODO: more fast path */
ret = translate_desc(vq, (u64)(uintptr_t)addr, size, vq->iotlb_iov,
ARRAY_SIZE(vq->iotlb_iov),
VHOST_ACCESS_RO);
if (ret < 0) {
vq_err(vq, "IOTLB translation failure: uaddr "
"%p size 0x%llx\n", addr,
(unsigned long long) size);
return NULL;
}
if (ret != 1 || vq->iotlb_iov[0].iov_len != size) {
vq_err(vq, "Non atomic userspace memory access: uaddr "
"%p size 0x%llx\n", addr,
(unsigned long long) size);
return NULL;
}
return vq->iotlb_iov[0].iov_base;
}
#define vhost_put_user(vq, x, ptr) \
({ \
int ret = -EFAULT; \
if (!vq->iotlb) { \
ret = __put_user(x, ptr); \
} else { \
__typeof__(ptr) to = \
(__typeof__(ptr)) __vhost_get_user(vq, ptr, sizeof(*ptr)); \
if (to != NULL) \
ret = __put_user(x, to); \
else \
ret = -EFAULT; \
} \
ret; \
})
#define vhost_get_user(vq, x, ptr) \
({ \
int ret; \
if (!vq->iotlb) { \
ret = __get_user(x, ptr); \
} else { \
__typeof__(ptr) from = \
(__typeof__(ptr)) __vhost_get_user(vq, ptr, sizeof(*ptr)); \
if (from != NULL) \
ret = __get_user(x, from); \
else \
ret = -EFAULT; \
} \
ret; \
})
static void vhost_dev_lock_vqs(struct vhost_dev *d)
{
int i = 0;
for (i = 0; i < d->nvqs; ++i)
mutex_lock(&d->vqs[i]->mutex);
}
static void vhost_dev_unlock_vqs(struct vhost_dev *d)
{
int i = 0;
for (i = 0; i < d->nvqs; ++i)
mutex_unlock(&d->vqs[i]->mutex);
}
static int vhost_new_umem_range(struct vhost_umem *umem,
u64 start, u64 size, u64 end,
u64 userspace_addr, int perm)
{
struct vhost_umem_node *tmp, *node = kmalloc(sizeof(*node), GFP_ATOMIC);
if (!node)
return -ENOMEM;
if (umem->numem == max_iotlb_entries) {
tmp = list_first_entry(&umem->umem_list, typeof(*tmp), link);
vhost_umem_free(umem, tmp);
}
node->start = start;
node->size = size;
node->last = end;
node->userspace_addr = userspace_addr;
node->perm = perm;
INIT_LIST_HEAD(&node->link);
list_add_tail(&node->link, &umem->umem_list);
vhost_umem_interval_tree_insert(node, &umem->umem_tree);
umem->numem++;
return 0;
}
static void vhost_del_umem_range(struct vhost_umem *umem,
u64 start, u64 end)
{
struct vhost_umem_node *node;
while ((node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
start, end)))
vhost_umem_free(umem, node);
}
static void vhost_iotlb_notify_vq(struct vhost_dev *d,
struct vhost_iotlb_msg *msg)
{
struct vhost_msg_node *node, *n;
spin_lock(&d->iotlb_lock);
list_for_each_entry_safe(node, n, &d->pending_list, node) {
struct vhost_iotlb_msg *vq_msg = &node->msg.iotlb;
if (msg->iova <= vq_msg->iova &&
msg->iova + msg->size - 1 > vq_msg->iova &&
vq_msg->type == VHOST_IOTLB_MISS) {
vhost_poll_queue(&node->vq->poll);
list_del(&node->node);
kfree(node);
}
}
spin_unlock(&d->iotlb_lock);
}
static int umem_access_ok(u64 uaddr, u64 size, int access)
{
unsigned long a = uaddr;
/* Make sure 64 bit math will not overflow. */
if (vhost_overflow(uaddr, size))
return -EFAULT;
if ((access & VHOST_ACCESS_RO) &&
!access_ok(VERIFY_READ, (void __user *)a, size))
return -EFAULT;
if ((access & VHOST_ACCESS_WO) &&
!access_ok(VERIFY_WRITE, (void __user *)a, size))
return -EFAULT;
return 0;
}
int vhost_process_iotlb_msg(struct vhost_dev *dev,
struct vhost_iotlb_msg *msg)
{
int ret = 0;
vhost_dev_lock_vqs(dev);
switch (msg->type) {
case VHOST_IOTLB_UPDATE:
if (!dev->iotlb) {
ret = -EFAULT;
break;
}
if (umem_access_ok(msg->uaddr, msg->size, msg->perm)) {
ret = -EFAULT;
break;
}
if (vhost_new_umem_range(dev->iotlb, msg->iova, msg->size,
msg->iova + msg->size - 1,
msg->uaddr, msg->perm)) {
ret = -ENOMEM;
break;
}
vhost_iotlb_notify_vq(dev, msg);
break;
case VHOST_IOTLB_INVALIDATE:
vhost_del_umem_range(dev->iotlb, msg->iova,
msg->iova + msg->size - 1);
break;
default:
ret = -EINVAL;
break;
}
vhost_dev_unlock_vqs(dev);
return ret;
}
ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
struct iov_iter *from)
{
struct vhost_msg_node node;
unsigned size = sizeof(struct vhost_msg);
size_t ret;
int err;
if (iov_iter_count(from) < size)
return 0;
ret = copy_from_iter(&node.msg, size, from);
if (ret != size)
goto done;
switch (node.msg.type) {
case VHOST_IOTLB_MSG:
err = vhost_process_iotlb_msg(dev, &node.msg.iotlb);
if (err)
ret = err;
break;
default:
ret = -EINVAL;
break;
}
done:
return ret;
}
EXPORT_SYMBOL(vhost_chr_write_iter);
unsigned int vhost_chr_poll(struct file *file, struct vhost_dev *dev,
poll_table *wait)
{
unsigned int mask = 0;
poll_wait(file, &dev->wait, wait);
if (!list_empty(&dev->read_list))
mask |= POLLIN | POLLRDNORM;
return mask;
}
EXPORT_SYMBOL(vhost_chr_poll);
ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to,
int noblock)
{
DEFINE_WAIT(wait);
struct vhost_msg_node *node;
ssize_t ret = 0;
unsigned size = sizeof(struct vhost_msg);
if (iov_iter_count(to) < size)
return 0;
while (1) {
if (!noblock)
prepare_to_wait(&dev->wait, &wait,
TASK_INTERRUPTIBLE);
node = vhost_dequeue_msg(dev, &dev->read_list);
if (node)
break;
if (noblock) {
ret = -EAGAIN;
break;
}
if (signal_pending(current)) {
ret = -ERESTARTSYS;
break;
}
if (!dev->iotlb) {
ret = -EBADFD;
break;
}
schedule();
}
if (!noblock)
finish_wait(&dev->wait, &wait);
if (node) {
ret = copy_to_iter(&node->msg, size, to);
if (ret != size || node->msg.type != VHOST_IOTLB_MISS) {
kfree(node);
return ret;
}
vhost_enqueue_msg(dev, &dev->pending_list, node);
}
return ret;
}
EXPORT_SYMBOL_GPL(vhost_chr_read_iter);
static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access)
{
struct vhost_dev *dev = vq->dev;
struct vhost_msg_node *node;
struct vhost_iotlb_msg *msg;
node = vhost_new_msg(vq, VHOST_IOTLB_MISS);
if (!node)
return -ENOMEM;
msg = &node->msg.iotlb;
msg->type = VHOST_IOTLB_MISS;
msg->iova = iova;
msg->perm = access;
vhost_enqueue_msg(dev, &dev->read_list, node);
return 0;
}
static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num, static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
struct vring_desc __user *desc, struct vring_desc __user *desc,
struct vring_avail __user *avail, struct vring_avail __user *avail,
struct vring_used __user *used) struct vring_used __user *used)
{ {
size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
return access_ok(VERIFY_READ, desc, num * sizeof *desc) && return access_ok(VERIFY_READ, desc, num * sizeof *desc) &&
access_ok(VERIFY_READ, avail, access_ok(VERIFY_READ, avail,
sizeof *avail + num * sizeof *avail->ring + s) && sizeof *avail + num * sizeof *avail->ring + s) &&
...@@ -664,11 +1102,59 @@ static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num, ...@@ -664,11 +1102,59 @@ static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
sizeof *used + num * sizeof *used->ring + s); sizeof *used + num * sizeof *used->ring + s);
} }
static int iotlb_access_ok(struct vhost_virtqueue *vq,
int access, u64 addr, u64 len)
{
const struct vhost_umem_node *node;
struct vhost_umem *umem = vq->iotlb;
u64 s = 0, size;
while (len > s) {
node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
addr,
addr + len - 1);
if (node == NULL || node->start > addr) {
vhost_iotlb_miss(vq, addr, access);
return false;
} else if (!(node->perm & access)) {
/* Report the possible access violation by
* request another translation from userspace.
*/
return false;
}
size = node->size - addr + node->start;
s += size;
addr += size;
}
return true;
}
int vq_iotlb_prefetch(struct vhost_virtqueue *vq)
{
size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
unsigned int num = vq->num;
if (!vq->iotlb)
return 1;
return iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->desc,
num * sizeof *vq->desc) &&
iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->avail,
sizeof *vq->avail +
num * sizeof *vq->avail->ring + s) &&
iotlb_access_ok(vq, VHOST_ACCESS_WO, (u64)(uintptr_t)vq->used,
sizeof *vq->used +
num * sizeof *vq->used->ring + s);
}
EXPORT_SYMBOL_GPL(vq_iotlb_prefetch);
/* Can we log writes? */ /* Can we log writes? */
/* Caller should have device mutex but not vq mutex */ /* Caller should have device mutex but not vq mutex */
int vhost_log_access_ok(struct vhost_dev *dev) int vhost_log_access_ok(struct vhost_dev *dev)
{ {
return memory_access_ok(dev, dev->memory, 1); return memory_access_ok(dev, dev->umem, 1);
} }
EXPORT_SYMBOL_GPL(vhost_log_access_ok); EXPORT_SYMBOL_GPL(vhost_log_access_ok);
...@@ -679,7 +1165,7 @@ static int vq_log_access_ok(struct vhost_virtqueue *vq, ...@@ -679,7 +1165,7 @@ static int vq_log_access_ok(struct vhost_virtqueue *vq,
{ {
size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
return vq_memory_access_ok(log_base, vq->memory, return vq_memory_access_ok(log_base, vq->umem,
vhost_has_feature(vq, VHOST_F_LOG_ALL)) && vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
(!vq->log_used || log_access_ok(log_base, vq->log_addr, (!vq->log_used || log_access_ok(log_base, vq->log_addr,
sizeof *vq->used + sizeof *vq->used +
...@@ -690,33 +1176,36 @@ static int vq_log_access_ok(struct vhost_virtqueue *vq, ...@@ -690,33 +1176,36 @@ static int vq_log_access_ok(struct vhost_virtqueue *vq,
/* Caller should have vq mutex and device mutex */ /* Caller should have vq mutex and device mutex */
int vhost_vq_access_ok(struct vhost_virtqueue *vq) int vhost_vq_access_ok(struct vhost_virtqueue *vq)
{ {
if (vq->iotlb) {
/* When device IOTLB was used, the access validation
* will be validated during prefetching.
*/
return 1;
}
return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used) && return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used) &&
vq_log_access_ok(vq, vq->log_base); vq_log_access_ok(vq, vq->log_base);
} }
EXPORT_SYMBOL_GPL(vhost_vq_access_ok); EXPORT_SYMBOL_GPL(vhost_vq_access_ok);
static int vhost_memory_reg_sort_cmp(const void *p1, const void *p2) static struct vhost_umem *vhost_umem_alloc(void)
{ {
const struct vhost_memory_region *r1 = p1, *r2 = p2; struct vhost_umem *umem = vhost_kvzalloc(sizeof(*umem));
if (r1->guest_phys_addr < r2->guest_phys_addr)
return 1;
if (r1->guest_phys_addr > r2->guest_phys_addr)
return -1;
return 0;
}
static void *vhost_kvzalloc(unsigned long size) if (!umem)
{ return NULL;
void *n = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);
if (!n) umem->umem_tree = RB_ROOT;
n = vzalloc(size); umem->numem = 0;
return n; INIT_LIST_HEAD(&umem->umem_list);
return umem;
} }
static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
{ {
struct vhost_memory mem, *newmem, *oldmem; struct vhost_memory mem, *newmem;
struct vhost_memory_region *region;
struct vhost_umem *newumem, *oldumem;
unsigned long size = offsetof(struct vhost_memory, regions); unsigned long size = offsetof(struct vhost_memory, regions);
int i; int i;
...@@ -736,24 +1225,47 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) ...@@ -736,24 +1225,47 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
kvfree(newmem); kvfree(newmem);
return -EFAULT; return -EFAULT;
} }
sort(newmem->regions, newmem->nregions, sizeof(*newmem->regions),
vhost_memory_reg_sort_cmp, NULL);
if (!memory_access_ok(d, newmem, 0)) { newumem = vhost_umem_alloc();
if (!newumem) {
kvfree(newmem); kvfree(newmem);
return -EFAULT; return -ENOMEM;
}
for (region = newmem->regions;
region < newmem->regions + mem.nregions;
region++) {
if (vhost_new_umem_range(newumem,
region->guest_phys_addr,
region->memory_size,
region->guest_phys_addr +
region->memory_size - 1,
region->userspace_addr,
VHOST_ACCESS_RW))
goto err;
} }
oldmem = d->memory;
d->memory = newmem; if (!memory_access_ok(d, newumem, 0))
goto err;
oldumem = d->umem;
d->umem = newumem;
/* All memory accesses are done under some VQ mutex. */ /* All memory accesses are done under some VQ mutex. */
for (i = 0; i < d->nvqs; ++i) { for (i = 0; i < d->nvqs; ++i) {
mutex_lock(&d->vqs[i]->mutex); mutex_lock(&d->vqs[i]->mutex);
d->vqs[i]->memory = newmem; d->vqs[i]->umem = newumem;
mutex_unlock(&d->vqs[i]->mutex); mutex_unlock(&d->vqs[i]->mutex);
} }
kvfree(oldmem);
kvfree(newmem);
vhost_umem_clean(oldumem);
return 0; return 0;
err:
vhost_umem_clean(newumem);
kvfree(newmem);
return -EFAULT;
} }
long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp) long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp)
...@@ -974,6 +1486,30 @@ long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp) ...@@ -974,6 +1486,30 @@ long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp)
} }
EXPORT_SYMBOL_GPL(vhost_vring_ioctl); EXPORT_SYMBOL_GPL(vhost_vring_ioctl);
int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled)
{
struct vhost_umem *niotlb, *oiotlb;
int i;
niotlb = vhost_umem_alloc();
if (!niotlb)
return -ENOMEM;
oiotlb = d->iotlb;
d->iotlb = niotlb;
for (i = 0; i < d->nvqs; ++i) {
mutex_lock(&d->vqs[i]->mutex);
d->vqs[i]->iotlb = niotlb;
mutex_unlock(&d->vqs[i]->mutex);
}
vhost_umem_clean(oiotlb);
return 0;
}
EXPORT_SYMBOL_GPL(vhost_init_device_iotlb);
/* Caller must have device mutex */ /* Caller must have device mutex */
long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp) long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
{ {
...@@ -1056,28 +1592,6 @@ long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp) ...@@ -1056,28 +1592,6 @@ long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
} }
EXPORT_SYMBOL_GPL(vhost_dev_ioctl); EXPORT_SYMBOL_GPL(vhost_dev_ioctl);
static const struct vhost_memory_region *find_region(struct vhost_memory *mem,
__u64 addr, __u32 len)
{
const struct vhost_memory_region *reg;
int start = 0, end = mem->nregions;
while (start < end) {
int slot = start + (end - start) / 2;
reg = mem->regions + slot;
if (addr >= reg->guest_phys_addr)
end = slot;
else
start = slot + 1;
}
reg = mem->regions + start;
if (addr >= reg->guest_phys_addr &&
reg->guest_phys_addr + reg->memory_size > addr)
return reg;
return NULL;
}
/* TODO: This is really inefficient. We need something like get_user() /* TODO: This is really inefficient. We need something like get_user()
* (instruction directly accesses the data, with an exception table entry * (instruction directly accesses the data, with an exception table entry
* returning -EFAULT). See Documentation/x86/exception-tables.txt. * returning -EFAULT). See Documentation/x86/exception-tables.txt.
...@@ -1156,7 +1670,8 @@ EXPORT_SYMBOL_GPL(vhost_log_write); ...@@ -1156,7 +1670,8 @@ EXPORT_SYMBOL_GPL(vhost_log_write);
static int vhost_update_used_flags(struct vhost_virtqueue *vq) static int vhost_update_used_flags(struct vhost_virtqueue *vq)
{ {
void __user *used; void __user *used;
if (__put_user(cpu_to_vhost16(vq, vq->used_flags), &vq->used->flags) < 0) if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags),
&vq->used->flags) < 0)
return -EFAULT; return -EFAULT;
if (unlikely(vq->log_used)) { if (unlikely(vq->log_used)) {
/* Make sure the flag is seen before log. */ /* Make sure the flag is seen before log. */
...@@ -1174,7 +1689,8 @@ static int vhost_update_used_flags(struct vhost_virtqueue *vq) ...@@ -1174,7 +1689,8 @@ static int vhost_update_used_flags(struct vhost_virtqueue *vq)
static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event) static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event)
{ {
if (__put_user(cpu_to_vhost16(vq, vq->avail_idx), vhost_avail_event(vq))) if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx),
vhost_avail_event(vq)))
return -EFAULT; return -EFAULT;
if (unlikely(vq->log_used)) { if (unlikely(vq->log_used)) {
void __user *used; void __user *used;
...@@ -1208,15 +1724,20 @@ int vhost_vq_init_access(struct vhost_virtqueue *vq) ...@@ -1208,15 +1724,20 @@ int vhost_vq_init_access(struct vhost_virtqueue *vq)
if (r) if (r)
goto err; goto err;
vq->signalled_used_valid = false; vq->signalled_used_valid = false;
if (!access_ok(VERIFY_READ, &vq->used->idx, sizeof vq->used->idx)) { if (!vq->iotlb &&
!access_ok(VERIFY_READ, &vq->used->idx, sizeof vq->used->idx)) {
r = -EFAULT; r = -EFAULT;
goto err; goto err;
} }
r = __get_user(last_used_idx, &vq->used->idx); r = vhost_get_user(vq, last_used_idx, &vq->used->idx);
if (r) if (r) {
vq_err(vq, "Can't access used idx at %p\n",
&vq->used->idx);
goto err; goto err;
}
vq->last_used_idx = vhost16_to_cpu(vq, last_used_idx); vq->last_used_idx = vhost16_to_cpu(vq, last_used_idx);
return 0; return 0;
err: err:
vq->is_le = is_le; vq->is_le = is_le;
return r; return r;
...@@ -1224,36 +1745,48 @@ int vhost_vq_init_access(struct vhost_virtqueue *vq) ...@@ -1224,36 +1745,48 @@ int vhost_vq_init_access(struct vhost_virtqueue *vq)
EXPORT_SYMBOL_GPL(vhost_vq_init_access); EXPORT_SYMBOL_GPL(vhost_vq_init_access);
static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len, static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
struct iovec iov[], int iov_size) struct iovec iov[], int iov_size, int access)
{ {
const struct vhost_memory_region *reg; const struct vhost_umem_node *node;
struct vhost_memory *mem; struct vhost_dev *dev = vq->dev;
struct vhost_umem *umem = dev->iotlb ? dev->iotlb : dev->umem;
struct iovec *_iov; struct iovec *_iov;
u64 s = 0; u64 s = 0;
int ret = 0; int ret = 0;
mem = vq->memory;
while ((u64)len > s) { while ((u64)len > s) {
u64 size; u64 size;
if (unlikely(ret >= iov_size)) { if (unlikely(ret >= iov_size)) {
ret = -ENOBUFS; ret = -ENOBUFS;
break; break;
} }
reg = find_region(mem, addr, len);
if (unlikely(!reg)) { node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
addr, addr + len - 1);
if (node == NULL || node->start > addr) {
if (umem != dev->iotlb) {
ret = -EFAULT; ret = -EFAULT;
break; break;
} }
ret = -EAGAIN;
break;
} else if (!(node->perm & access)) {
ret = -EPERM;
break;
}
_iov = iov + ret; _iov = iov + ret;
size = reg->memory_size - addr + reg->guest_phys_addr; size = node->size - addr + node->start;
_iov->iov_len = min((u64)len - s, size); _iov->iov_len = min((u64)len - s, size);
_iov->iov_base = (void __user *)(unsigned long) _iov->iov_base = (void __user *)(unsigned long)
(reg->userspace_addr + addr - reg->guest_phys_addr); (node->userspace_addr + addr - node->start);
s += size; s += size;
addr += size; addr += size;
++ret; ++ret;
} }
if (ret == -EAGAIN)
vhost_iotlb_miss(vq, addr, access);
return ret; return ret;
} }
...@@ -1288,7 +1821,7 @@ static int get_indirect(struct vhost_virtqueue *vq, ...@@ -1288,7 +1821,7 @@ static int get_indirect(struct vhost_virtqueue *vq,
unsigned int i = 0, count, found = 0; unsigned int i = 0, count, found = 0;
u32 len = vhost32_to_cpu(vq, indirect->len); u32 len = vhost32_to_cpu(vq, indirect->len);
struct iov_iter from; struct iov_iter from;
int ret; int ret, access;
/* Sanity check */ /* Sanity check */
if (unlikely(len % sizeof desc)) { if (unlikely(len % sizeof desc)) {
...@@ -1300,8 +1833,9 @@ static int get_indirect(struct vhost_virtqueue *vq, ...@@ -1300,8 +1833,9 @@ static int get_indirect(struct vhost_virtqueue *vq,
} }
ret = translate_desc(vq, vhost64_to_cpu(vq, indirect->addr), len, vq->indirect, ret = translate_desc(vq, vhost64_to_cpu(vq, indirect->addr), len, vq->indirect,
UIO_MAXIOV); UIO_MAXIOV, VHOST_ACCESS_RO);
if (unlikely(ret < 0)) { if (unlikely(ret < 0)) {
if (ret != -EAGAIN)
vq_err(vq, "Translation failure %d in indirect.\n", ret); vq_err(vq, "Translation failure %d in indirect.\n", ret);
return ret; return ret;
} }
...@@ -1340,16 +1874,22 @@ static int get_indirect(struct vhost_virtqueue *vq, ...@@ -1340,16 +1874,22 @@ static int get_indirect(struct vhost_virtqueue *vq,
return -EINVAL; return -EINVAL;
} }
if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
access = VHOST_ACCESS_WO;
else
access = VHOST_ACCESS_RO;
ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr), ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
vhost32_to_cpu(vq, desc.len), iov + iov_count, vhost32_to_cpu(vq, desc.len), iov + iov_count,
iov_size - iov_count); iov_size - iov_count, access);
if (unlikely(ret < 0)) { if (unlikely(ret < 0)) {
if (ret != -EAGAIN)
vq_err(vq, "Translation failure %d indirect idx %d\n", vq_err(vq, "Translation failure %d indirect idx %d\n",
ret, i); ret, i);
return ret; return ret;
} }
/* If this is an input descriptor, increment that count. */ /* If this is an input descriptor, increment that count. */
if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE)) { if (access == VHOST_ACCESS_WO) {
*in_num += ret; *in_num += ret;
if (unlikely(log)) { if (unlikely(log)) {
log[*log_num].addr = vhost64_to_cpu(vq, desc.addr); log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
...@@ -1388,11 +1928,11 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq, ...@@ -1388,11 +1928,11 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
u16 last_avail_idx; u16 last_avail_idx;
__virtio16 avail_idx; __virtio16 avail_idx;
__virtio16 ring_head; __virtio16 ring_head;
int ret; int ret, access;
/* Check it isn't doing very strange things with descriptor numbers. */ /* Check it isn't doing very strange things with descriptor numbers. */
last_avail_idx = vq->last_avail_idx; last_avail_idx = vq->last_avail_idx;
if (unlikely(__get_user(avail_idx, &vq->avail->idx))) { if (unlikely(vhost_get_user(vq, avail_idx, &vq->avail->idx))) {
vq_err(vq, "Failed to access avail idx at %p\n", vq_err(vq, "Failed to access avail idx at %p\n",
&vq->avail->idx); &vq->avail->idx);
return -EFAULT; return -EFAULT;
...@@ -1414,7 +1954,7 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq, ...@@ -1414,7 +1954,7 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
/* Grab the next descriptor number they're advertising, and increment /* Grab the next descriptor number they're advertising, and increment
* the index we've seen. */ * the index we've seen. */
if (unlikely(__get_user(ring_head, if (unlikely(vhost_get_user(vq, ring_head,
&vq->avail->ring[last_avail_idx & (vq->num - 1)]))) { &vq->avail->ring[last_avail_idx & (vq->num - 1)]))) {
vq_err(vq, "Failed to read head: idx %d address %p\n", vq_err(vq, "Failed to read head: idx %d address %p\n",
last_avail_idx, last_avail_idx,
...@@ -1450,7 +1990,8 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq, ...@@ -1450,7 +1990,8 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
i, vq->num, head); i, vq->num, head);
return -EINVAL; return -EINVAL;
} }
ret = __copy_from_user(&desc, vq->desc + i, sizeof desc); ret = vhost_copy_from_user(vq, &desc, vq->desc + i,
sizeof desc);
if (unlikely(ret)) { if (unlikely(ret)) {
vq_err(vq, "Failed to get descriptor: idx %d addr %p\n", vq_err(vq, "Failed to get descriptor: idx %d addr %p\n",
i, vq->desc + i); i, vq->desc + i);
...@@ -1461,6 +2002,7 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq, ...@@ -1461,6 +2002,7 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
out_num, in_num, out_num, in_num,
log, log_num, &desc); log, log_num, &desc);
if (unlikely(ret < 0)) { if (unlikely(ret < 0)) {
if (ret != -EAGAIN)
vq_err(vq, "Failure detected " vq_err(vq, "Failure detected "
"in indirect descriptor at idx %d\n", i); "in indirect descriptor at idx %d\n", i);
return ret; return ret;
...@@ -1468,15 +2010,20 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq, ...@@ -1468,15 +2010,20 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
continue; continue;
} }
if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
access = VHOST_ACCESS_WO;
else
access = VHOST_ACCESS_RO;
ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr), ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
vhost32_to_cpu(vq, desc.len), iov + iov_count, vhost32_to_cpu(vq, desc.len), iov + iov_count,
iov_size - iov_count); iov_size - iov_count, access);
if (unlikely(ret < 0)) { if (unlikely(ret < 0)) {
if (ret != -EAGAIN)
vq_err(vq, "Translation failure %d descriptor idx %d\n", vq_err(vq, "Translation failure %d descriptor idx %d\n",
ret, i); ret, i);
return ret; return ret;
} }
if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE)) { if (access == VHOST_ACCESS_WO) {
/* If this is an input descriptor, /* If this is an input descriptor,
* increment that count. */ * increment that count. */
*in_num += ret; *in_num += ret;
...@@ -1538,15 +2085,15 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq, ...@@ -1538,15 +2085,15 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq,
start = vq->last_used_idx & (vq->num - 1); start = vq->last_used_idx & (vq->num - 1);
used = vq->used->ring + start; used = vq->used->ring + start;
if (count == 1) { if (count == 1) {
if (__put_user(heads[0].id, &used->id)) { if (vhost_put_user(vq, heads[0].id, &used->id)) {
vq_err(vq, "Failed to write used id"); vq_err(vq, "Failed to write used id");
return -EFAULT; return -EFAULT;
} }
if (__put_user(heads[0].len, &used->len)) { if (vhost_put_user(vq, heads[0].len, &used->len)) {
vq_err(vq, "Failed to write used len"); vq_err(vq, "Failed to write used len");
return -EFAULT; return -EFAULT;
} }
} else if (__copy_to_user(used, heads, count * sizeof *used)) { } else if (vhost_copy_to_user(vq, used, heads, count * sizeof *used)) {
vq_err(vq, "Failed to write used"); vq_err(vq, "Failed to write used");
return -EFAULT; return -EFAULT;
} }
...@@ -1590,7 +2137,8 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads, ...@@ -1590,7 +2137,8 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
/* Make sure buffer is written before we update index. */ /* Make sure buffer is written before we update index. */
smp_wmb(); smp_wmb();
if (__put_user(cpu_to_vhost16(vq, vq->last_used_idx), &vq->used->idx)) { if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx),
&vq->used->idx)) {
vq_err(vq, "Failed to increment used idx"); vq_err(vq, "Failed to increment used idx");
return -EFAULT; return -EFAULT;
} }
...@@ -1622,7 +2170,7 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) ...@@ -1622,7 +2170,7 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) { if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
__virtio16 flags; __virtio16 flags;
if (__get_user(flags, &vq->avail->flags)) { if (vhost_get_user(vq, flags, &vq->avail->flags)) {
vq_err(vq, "Failed to get flags"); vq_err(vq, "Failed to get flags");
return true; return true;
} }
...@@ -1636,7 +2184,7 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) ...@@ -1636,7 +2184,7 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
if (unlikely(!v)) if (unlikely(!v))
return true; return true;
if (__get_user(event, vhost_used_event(vq))) { if (vhost_get_user(vq, event, vhost_used_event(vq))) {
vq_err(vq, "Failed to get used event idx"); vq_err(vq, "Failed to get used event idx");
return true; return true;
} }
...@@ -1678,7 +2226,7 @@ bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq) ...@@ -1678,7 +2226,7 @@ bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq)
__virtio16 avail_idx; __virtio16 avail_idx;
int r; int r;
r = __get_user(avail_idx, &vq->avail->idx); r = vhost_get_user(vq, avail_idx, &vq->avail->idx);
if (r) if (r)
return false; return false;
...@@ -1713,7 +2261,7 @@ bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) ...@@ -1713,7 +2261,7 @@ bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
/* They could have slipped one in as we were doing that: make /* They could have slipped one in as we were doing that: make
* sure it's written, then check again. */ * sure it's written, then check again. */
smp_mb(); smp_mb();
r = __get_user(avail_idx, &vq->avail->idx); r = vhost_get_user(vq, avail_idx, &vq->avail->idx);
if (r) { if (r) {
vq_err(vq, "Failed to check avail idx at %p: %d\n", vq_err(vq, "Failed to check avail idx at %p: %d\n",
&vq->avail->idx, r); &vq->avail->idx, r);
...@@ -1741,6 +2289,47 @@ void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) ...@@ -1741,6 +2289,47 @@ void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
} }
EXPORT_SYMBOL_GPL(vhost_disable_notify); EXPORT_SYMBOL_GPL(vhost_disable_notify);
/* Create a new message. */
struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type)
{
struct vhost_msg_node *node = kmalloc(sizeof *node, GFP_KERNEL);
if (!node)
return NULL;
node->vq = vq;
node->msg.type = type;
return node;
}
EXPORT_SYMBOL_GPL(vhost_new_msg);
void vhost_enqueue_msg(struct vhost_dev *dev, struct list_head *head,
struct vhost_msg_node *node)
{
spin_lock(&dev->iotlb_lock);
list_add_tail(&node->node, head);
spin_unlock(&dev->iotlb_lock);
wake_up_interruptible_poll(&dev->wait, POLLIN | POLLRDNORM);
}
EXPORT_SYMBOL_GPL(vhost_enqueue_msg);
struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev,
struct list_head *head)
{
struct vhost_msg_node *node = NULL;
spin_lock(&dev->iotlb_lock);
if (!list_empty(head)) {
node = list_first_entry(head, struct vhost_msg_node,
node);
list_del(&node->node);
}
spin_unlock(&dev->iotlb_lock);
return node;
}
EXPORT_SYMBOL_GPL(vhost_dequeue_msg);
static int __init vhost_init(void) static int __init vhost_init(void)
{ {
return 0; return 0;
......
...@@ -15,13 +15,15 @@ ...@@ -15,13 +15,15 @@
struct vhost_work; struct vhost_work;
typedef void (*vhost_work_fn_t)(struct vhost_work *work); typedef void (*vhost_work_fn_t)(struct vhost_work *work);
#define VHOST_WORK_QUEUED 1
struct vhost_work { struct vhost_work {
struct list_head node; struct llist_node node;
vhost_work_fn_t fn; vhost_work_fn_t fn;
wait_queue_head_t done; wait_queue_head_t done;
int flushing; int flushing;
unsigned queue_seq; unsigned queue_seq;
unsigned done_seq; unsigned done_seq;
unsigned long flags;
}; };
/* Poll a file (eventfd or socket) */ /* Poll a file (eventfd or socket) */
...@@ -53,6 +55,27 @@ struct vhost_log { ...@@ -53,6 +55,27 @@ struct vhost_log {
u64 len; u64 len;
}; };
#define START(node) ((node)->start)
#define LAST(node) ((node)->last)
struct vhost_umem_node {
struct rb_node rb;
struct list_head link;
__u64 start;
__u64 last;
__u64 size;
__u64 userspace_addr;
__u32 perm;
__u32 flags_padding;
__u64 __subtree_last;
};
struct vhost_umem {
struct rb_root umem_tree;
struct list_head umem_list;
int numem;
};
/* The virtqueue structure describes a queue attached to a device. */ /* The virtqueue structure describes a queue attached to a device. */
struct vhost_virtqueue { struct vhost_virtqueue {
struct vhost_dev *dev; struct vhost_dev *dev;
...@@ -98,10 +121,12 @@ struct vhost_virtqueue { ...@@ -98,10 +121,12 @@ struct vhost_virtqueue {
u64 log_addr; u64 log_addr;
struct iovec iov[UIO_MAXIOV]; struct iovec iov[UIO_MAXIOV];
struct iovec iotlb_iov[64];
struct iovec *indirect; struct iovec *indirect;
struct vring_used_elem *heads; struct vring_used_elem *heads;
/* Protected by virtqueue mutex. */ /* Protected by virtqueue mutex. */
struct vhost_memory *memory; struct vhost_umem *umem;
struct vhost_umem *iotlb;
void *private_data; void *private_data;
u64 acked_features; u64 acked_features;
/* Log write descriptors */ /* Log write descriptors */
...@@ -118,25 +143,35 @@ struct vhost_virtqueue { ...@@ -118,25 +143,35 @@ struct vhost_virtqueue {
u32 busyloop_timeout; u32 busyloop_timeout;
}; };
struct vhost_msg_node {
struct vhost_msg msg;
struct vhost_virtqueue *vq;
struct list_head node;
};
struct vhost_dev { struct vhost_dev {
struct vhost_memory *memory;
struct mm_struct *mm; struct mm_struct *mm;
struct mutex mutex; struct mutex mutex;
struct vhost_virtqueue **vqs; struct vhost_virtqueue **vqs;
int nvqs; int nvqs;
struct file *log_file; struct file *log_file;
struct eventfd_ctx *log_ctx; struct eventfd_ctx *log_ctx;
spinlock_t work_lock; struct llist_head work_list;
struct list_head work_list;
struct task_struct *worker; struct task_struct *worker;
struct vhost_umem *umem;
struct vhost_umem *iotlb;
spinlock_t iotlb_lock;
struct list_head read_list;
struct list_head pending_list;
wait_queue_head_t wait;
}; };
void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs, int nvqs); void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs, int nvqs);
long vhost_dev_set_owner(struct vhost_dev *dev); long vhost_dev_set_owner(struct vhost_dev *dev);
bool vhost_dev_has_owner(struct vhost_dev *dev); bool vhost_dev_has_owner(struct vhost_dev *dev);
long vhost_dev_check_owner(struct vhost_dev *); long vhost_dev_check_owner(struct vhost_dev *);
struct vhost_memory *vhost_dev_reset_owner_prepare(void); struct vhost_umem *vhost_dev_reset_owner_prepare(void);
void vhost_dev_reset_owner(struct vhost_dev *, struct vhost_memory *); void vhost_dev_reset_owner(struct vhost_dev *, struct vhost_umem *);
void vhost_dev_cleanup(struct vhost_dev *, bool locked); void vhost_dev_cleanup(struct vhost_dev *, bool locked);
void vhost_dev_stop(struct vhost_dev *); void vhost_dev_stop(struct vhost_dev *);
long vhost_dev_ioctl(struct vhost_dev *, unsigned int ioctl, void __user *argp); long vhost_dev_ioctl(struct vhost_dev *, unsigned int ioctl, void __user *argp);
...@@ -165,6 +200,21 @@ bool vhost_enable_notify(struct vhost_dev *, struct vhost_virtqueue *); ...@@ -165,6 +200,21 @@ bool vhost_enable_notify(struct vhost_dev *, struct vhost_virtqueue *);
int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log, int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
unsigned int log_num, u64 len); unsigned int log_num, u64 len);
int vq_iotlb_prefetch(struct vhost_virtqueue *vq);
struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type);
void vhost_enqueue_msg(struct vhost_dev *dev,
struct list_head *head,
struct vhost_msg_node *node);
struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev,
struct list_head *head);
unsigned int vhost_chr_poll(struct file *file, struct vhost_dev *dev,
poll_table *wait);
ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to,
int noblock);
ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
struct iov_iter *from);
int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled);
#define vq_err(vq, fmt, ...) do { \ #define vq_err(vq, fmt, ...) do { \
pr_debug(pr_fmt(fmt), ##__VA_ARGS__); \ pr_debug(pr_fmt(fmt), ##__VA_ARGS__); \
......
/*
* vhost transport for vsock
*
* Copyright (C) 2013-2015 Red Hat, Inc.
* Author: Asias He <asias@redhat.com>
* Stefan Hajnoczi <stefanha@redhat.com>
*
* This work is licensed under the terms of the GNU GPL, version 2.
*/
#include <linux/miscdevice.h>
#include <linux/atomic.h>
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/vmalloc.h>
#include <net/sock.h>
#include <linux/virtio_vsock.h>
#include <linux/vhost.h>
#include <net/af_vsock.h>
#include "vhost.h"
#define VHOST_VSOCK_DEFAULT_HOST_CID 2
enum {
VHOST_VSOCK_FEATURES = VHOST_FEATURES,
};
/* Used to track all the vhost_vsock instances on the system. */
static DEFINE_SPINLOCK(vhost_vsock_lock);
static LIST_HEAD(vhost_vsock_list);
struct vhost_vsock {
struct vhost_dev dev;
struct vhost_virtqueue vqs[2];
/* Link to global vhost_vsock_list, protected by vhost_vsock_lock */
struct list_head list;
struct vhost_work send_pkt_work;
spinlock_t send_pkt_list_lock;
struct list_head send_pkt_list; /* host->guest pending packets */
atomic_t queued_replies;
u32 guest_cid;
};
static u32 vhost_transport_get_local_cid(void)
{
return VHOST_VSOCK_DEFAULT_HOST_CID;
}
static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
{
struct vhost_vsock *vsock;
spin_lock_bh(&vhost_vsock_lock);
list_for_each_entry(vsock, &vhost_vsock_list, list) {
u32 other_cid = vsock->guest_cid;
/* Skip instances that have no CID yet */
if (other_cid == 0)
continue;
if (other_cid == guest_cid) {
spin_unlock_bh(&vhost_vsock_lock);
return vsock;
}
}
spin_unlock_bh(&vhost_vsock_lock);
return NULL;
}
static void
vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
struct vhost_virtqueue *vq)
{
struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
bool added = false;
bool restart_tx = false;
mutex_lock(&vq->mutex);
if (!vq->private_data)
goto out;
/* Avoid further vmexits, we're already processing the virtqueue */
vhost_disable_notify(&vsock->dev, vq);
for (;;) {
struct virtio_vsock_pkt *pkt;
struct iov_iter iov_iter;
unsigned out, in;
size_t nbytes;
size_t len;
int head;
spin_lock_bh(&vsock->send_pkt_list_lock);
if (list_empty(&vsock->send_pkt_list)) {
spin_unlock_bh(&vsock->send_pkt_list_lock);
vhost_enable_notify(&vsock->dev, vq);
break;
}
pkt = list_first_entry(&vsock->send_pkt_list,
struct virtio_vsock_pkt, list);
list_del_init(&pkt->list);
spin_unlock_bh(&vsock->send_pkt_list_lock);
head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
&out, &in, NULL, NULL);
if (head < 0) {
spin_lock_bh(&vsock->send_pkt_list_lock);
list_add(&pkt->list, &vsock->send_pkt_list);
spin_unlock_bh(&vsock->send_pkt_list_lock);
break;
}
if (head == vq->num) {
spin_lock_bh(&vsock->send_pkt_list_lock);
list_add(&pkt->list, &vsock->send_pkt_list);
spin_unlock_bh(&vsock->send_pkt_list_lock);
/* We cannot finish yet if more buffers snuck in while
* re-enabling notify.
*/
if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
vhost_disable_notify(&vsock->dev, vq);
continue;
}
break;
}
if (out) {
virtio_transport_free_pkt(pkt);
vq_err(vq, "Expected 0 output buffers, got %u\n", out);
break;
}
len = iov_length(&vq->iov[out], in);
iov_iter_init(&iov_iter, READ, &vq->iov[out], in, len);
nbytes = copy_to_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
if (nbytes != sizeof(pkt->hdr)) {
virtio_transport_free_pkt(pkt);
vq_err(vq, "Faulted on copying pkt hdr\n");
break;
}
nbytes = copy_to_iter(pkt->buf, pkt->len, &iov_iter);
if (nbytes != pkt->len) {
virtio_transport_free_pkt(pkt);
vq_err(vq, "Faulted on copying pkt buf\n");
break;
}
vhost_add_used(vq, head, sizeof(pkt->hdr) + pkt->len);
added = true;
if (pkt->reply) {
int val;
val = atomic_dec_return(&vsock->queued_replies);
/* Do we have resources to resume tx processing? */
if (val + 1 == tx_vq->num)
restart_tx = true;
}
virtio_transport_free_pkt(pkt);
}
if (added)
vhost_signal(&vsock->dev, vq);
out:
mutex_unlock(&vq->mutex);
if (restart_tx)
vhost_poll_queue(&tx_vq->poll);
}
static void vhost_transport_send_pkt_work(struct vhost_work *work)
{
struct vhost_virtqueue *vq;
struct vhost_vsock *vsock;
vsock = container_of(work, struct vhost_vsock, send_pkt_work);
vq = &vsock->vqs[VSOCK_VQ_RX];
vhost_transport_do_send_pkt(vsock, vq);
}
static int
vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt)
{
struct vhost_vsock *vsock;
struct vhost_virtqueue *vq;
int len = pkt->len;
/* Find the vhost_vsock according to guest context id */
vsock = vhost_vsock_get(le64_to_cpu(pkt->hdr.dst_cid));
if (!vsock) {
virtio_transport_free_pkt(pkt);
return -ENODEV;
}
vq = &vsock->vqs[VSOCK_VQ_RX];
if (pkt->reply)
atomic_inc(&vsock->queued_replies);
spin_lock_bh(&vsock->send_pkt_list_lock);
list_add_tail(&pkt->list, &vsock->send_pkt_list);
spin_unlock_bh(&vsock->send_pkt_list_lock);
vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);
return len;
}
static struct virtio_vsock_pkt *
vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq,
unsigned int out, unsigned int in)
{
struct virtio_vsock_pkt *pkt;
struct iov_iter iov_iter;
size_t nbytes;
size_t len;
if (in != 0) {
vq_err(vq, "Expected 0 input buffers, got %u\n", in);
return NULL;
}
pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
if (!pkt)
return NULL;
len = iov_length(vq->iov, out);
iov_iter_init(&iov_iter, WRITE, vq->iov, out, len);
nbytes = copy_from_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
if (nbytes != sizeof(pkt->hdr)) {
vq_err(vq, "Expected %zu bytes for pkt->hdr, got %zu bytes\n",
sizeof(pkt->hdr), nbytes);
kfree(pkt);
return NULL;
}
if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_STREAM)
pkt->len = le32_to_cpu(pkt->hdr.len);
/* No payload */
if (!pkt->len)
return pkt;
/* The pkt is too big */
if (pkt->len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) {
kfree(pkt);
return NULL;
}
pkt->buf = kmalloc(pkt->len, GFP_KERNEL);
if (!pkt->buf) {
kfree(pkt);
return NULL;
}
nbytes = copy_from_iter(pkt->buf, pkt->len, &iov_iter);
if (nbytes != pkt->len) {
vq_err(vq, "Expected %u byte payload, got %zu bytes\n",
pkt->len, nbytes);
virtio_transport_free_pkt(pkt);
return NULL;
}
return pkt;
}
/* Is there space left for replies to rx packets? */
static bool vhost_vsock_more_replies(struct vhost_vsock *vsock)
{
struct vhost_virtqueue *vq = &vsock->vqs[VSOCK_VQ_TX];
int val;
smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */
val = atomic_read(&vsock->queued_replies);
return val < vq->num;
}
static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
{
struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
poll.work);
struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
dev);
struct virtio_vsock_pkt *pkt;
int head;
unsigned int out, in;
bool added = false;
mutex_lock(&vq->mutex);
if (!vq->private_data)
goto out;
vhost_disable_notify(&vsock->dev, vq);
for (;;) {
if (!vhost_vsock_more_replies(vsock)) {
/* Stop tx until the device processes already
* pending replies. Leave tx virtqueue
* callbacks disabled.
*/
goto no_more_replies;
}
head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
&out, &in, NULL, NULL);
if (head < 0)
break;
if (head == vq->num) {
if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
vhost_disable_notify(&vsock->dev, vq);
continue;
}
break;
}
pkt = vhost_vsock_alloc_pkt(vq, out, in);
if (!pkt) {
vq_err(vq, "Faulted on pkt\n");
continue;
}
/* Only accept correctly addressed packets */
if (le64_to_cpu(pkt->hdr.src_cid) == vsock->guest_cid)
virtio_transport_recv_pkt(pkt);
else
virtio_transport_free_pkt(pkt);
vhost_add_used(vq, head, sizeof(pkt->hdr) + pkt->len);
added = true;
}
no_more_replies:
if (added)
vhost_signal(&vsock->dev, vq);
out:
mutex_unlock(&vq->mutex);
}
static void vhost_vsock_handle_rx_kick(struct vhost_work *work)
{
struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
poll.work);
struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
dev);
vhost_transport_do_send_pkt(vsock, vq);
}
static int vhost_vsock_start(struct vhost_vsock *vsock)
{
size_t i;
int ret;
mutex_lock(&vsock->dev.mutex);
ret = vhost_dev_check_owner(&vsock->dev);
if (ret)
goto err;
for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
struct vhost_virtqueue *vq = &vsock->vqs[i];
mutex_lock(&vq->mutex);
if (!vhost_vq_access_ok(vq)) {
ret = -EFAULT;
mutex_unlock(&vq->mutex);
goto err_vq;
}
if (!vq->private_data) {
vq->private_data = vsock;
vhost_vq_init_access(vq);
}
mutex_unlock(&vq->mutex);
}
mutex_unlock(&vsock->dev.mutex);
return 0;
err_vq:
for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
struct vhost_virtqueue *vq = &vsock->vqs[i];
mutex_lock(&vq->mutex);
vq->private_data = NULL;
mutex_unlock(&vq->mutex);
}
err:
mutex_unlock(&vsock->dev.mutex);
return ret;
}
static int vhost_vsock_stop(struct vhost_vsock *vsock)
{
size_t i;
int ret;
mutex_lock(&vsock->dev.mutex);
ret = vhost_dev_check_owner(&vsock->dev);
if (ret)
goto err;
for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
struct vhost_virtqueue *vq = &vsock->vqs[i];
mutex_lock(&vq->mutex);
vq->private_data = NULL;
mutex_unlock(&vq->mutex);
}
err:
mutex_unlock(&vsock->dev.mutex);
return ret;
}
static void vhost_vsock_free(struct vhost_vsock *vsock)
{
kvfree(vsock);
}
static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
{
struct vhost_virtqueue **vqs;
struct vhost_vsock *vsock;
int ret;
/* This struct is large and allocation could fail, fall back to vmalloc
* if there is no other way.
*/
vsock = kzalloc(sizeof(*vsock), GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);
if (!vsock) {
vsock = vmalloc(sizeof(*vsock));
if (!vsock)
return -ENOMEM;
}
vqs = kmalloc_array(ARRAY_SIZE(vsock->vqs), sizeof(*vqs), GFP_KERNEL);
if (!vqs) {
ret = -ENOMEM;
goto out;
}
atomic_set(&vsock->queued_replies, 0);
vqs[VSOCK_VQ_TX] = &vsock->vqs[VSOCK_VQ_TX];
vqs[VSOCK_VQ_RX] = &vsock->vqs[VSOCK_VQ_RX];
vsock->vqs[VSOCK_VQ_TX].handle_kick = vhost_vsock_handle_tx_kick;
vsock->vqs[VSOCK_VQ_RX].handle_kick = vhost_vsock_handle_rx_kick;
vhost_dev_init(&vsock->dev, vqs, ARRAY_SIZE(vsock->vqs));
file->private_data = vsock;
spin_lock_init(&vsock->send_pkt_list_lock);
INIT_LIST_HEAD(&vsock->send_pkt_list);
vhost_work_init(&vsock->send_pkt_work, vhost_transport_send_pkt_work);
spin_lock_bh(&vhost_vsock_lock);
list_add_tail(&vsock->list, &vhost_vsock_list);
spin_unlock_bh(&vhost_vsock_lock);
return 0;
out:
vhost_vsock_free(vsock);
return ret;
}
static void vhost_vsock_flush(struct vhost_vsock *vsock)
{
int i;
for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++)
if (vsock->vqs[i].handle_kick)
vhost_poll_flush(&vsock->vqs[i].poll);
vhost_work_flush(&vsock->dev, &vsock->send_pkt_work);
}
static void vhost_vsock_reset_orphans(struct sock *sk)
{
struct vsock_sock *vsk = vsock_sk(sk);
/* vmci_transport.c doesn't take sk_lock here either. At least we're
* under vsock_table_lock so the sock cannot disappear while we're
* executing.
*/
if (!vhost_vsock_get(vsk->local_addr.svm_cid)) {
sock_set_flag(sk, SOCK_DONE);
vsk->peer_shutdown = SHUTDOWN_MASK;
sk->sk_state = SS_UNCONNECTED;
sk->sk_err = ECONNRESET;
sk->sk_error_report(sk);
}
}
static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
{
struct vhost_vsock *vsock = file->private_data;
spin_lock_bh(&vhost_vsock_lock);
list_del(&vsock->list);
spin_unlock_bh(&vhost_vsock_lock);
/* Iterating over all connections for all CIDs to find orphans is
* inefficient. Room for improvement here. */
vsock_for_each_connected_socket(vhost_vsock_reset_orphans);
vhost_vsock_stop(vsock);
vhost_vsock_flush(vsock);
vhost_dev_stop(&vsock->dev);
spin_lock_bh(&vsock->send_pkt_list_lock);
while (!list_empty(&vsock->send_pkt_list)) {
struct virtio_vsock_pkt *pkt;
pkt = list_first_entry(&vsock->send_pkt_list,
struct virtio_vsock_pkt, list);
list_del_init(&pkt->list);
virtio_transport_free_pkt(pkt);
}
spin_unlock_bh(&vsock->send_pkt_list_lock);
vhost_dev_cleanup(&vsock->dev, false);
kfree(vsock->dev.vqs);
vhost_vsock_free(vsock);
return 0;
}
static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid)
{
struct vhost_vsock *other;
/* Refuse reserved CIDs */
if (guest_cid <= VMADDR_CID_HOST ||
guest_cid == U32_MAX)
return -EINVAL;
/* 64-bit CIDs are not yet supported */
if (guest_cid > U32_MAX)
return -EINVAL;
/* Refuse if CID is already in use */
other = vhost_vsock_get(guest_cid);
if (other && other != vsock)
return -EADDRINUSE;
spin_lock_bh(&vhost_vsock_lock);
vsock->guest_cid = guest_cid;
spin_unlock_bh(&vhost_vsock_lock);
return 0;
}
static int vhost_vsock_set_features(struct vhost_vsock *vsock, u64 features)
{
struct vhost_virtqueue *vq;
int i;
if (features & ~VHOST_VSOCK_FEATURES)
return -EOPNOTSUPP;
mutex_lock(&vsock->dev.mutex);
if ((features & (1 << VHOST_F_LOG_ALL)) &&
!vhost_log_access_ok(&vsock->dev)) {
mutex_unlock(&vsock->dev.mutex);
return -EFAULT;
}
for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
vq = &vsock->vqs[i];
mutex_lock(&vq->mutex);
vq->acked_features = features;
mutex_unlock(&vq->mutex);
}
mutex_unlock(&vsock->dev.mutex);
return 0;
}
static long vhost_vsock_dev_ioctl(struct file *f, unsigned int ioctl,
unsigned long arg)
{
struct vhost_vsock *vsock = f->private_data;
void __user *argp = (void __user *)arg;
u64 guest_cid;
u64 features;
int start;
int r;
switch (ioctl) {
case VHOST_VSOCK_SET_GUEST_CID:
if (copy_from_user(&guest_cid, argp, sizeof(guest_cid)))
return -EFAULT;
return vhost_vsock_set_cid(vsock, guest_cid);
case VHOST_VSOCK_SET_RUNNING:
if (copy_from_user(&start, argp, sizeof(start)))
return -EFAULT;
if (start)
return vhost_vsock_start(vsock);
else
return vhost_vsock_stop(vsock);
case VHOST_GET_FEATURES:
features = VHOST_VSOCK_FEATURES;
if (copy_to_user(argp, &features, sizeof(features)))
return -EFAULT;
return 0;
case VHOST_SET_FEATURES:
if (copy_from_user(&features, argp, sizeof(features)))
return -EFAULT;
return vhost_vsock_set_features(vsock, features);
default:
mutex_lock(&vsock->dev.mutex);
r = vhost_dev_ioctl(&vsock->dev, ioctl, argp);
if (r == -ENOIOCTLCMD)
r = vhost_vring_ioctl(&vsock->dev, ioctl, argp);
else
vhost_vsock_flush(vsock);
mutex_unlock(&vsock->dev.mutex);
return r;
}
}
static const struct file_operations vhost_vsock_fops = {
.owner = THIS_MODULE,
.open = vhost_vsock_dev_open,
.release = vhost_vsock_dev_release,
.llseek = noop_llseek,
.unlocked_ioctl = vhost_vsock_dev_ioctl,
};
static struct miscdevice vhost_vsock_misc = {
.minor = MISC_DYNAMIC_MINOR,
.name = "vhost-vsock",
.fops = &vhost_vsock_fops,
};
static struct virtio_transport vhost_transport = {
.transport = {
.get_local_cid = vhost_transport_get_local_cid,
.init = virtio_transport_do_socket_init,
.destruct = virtio_transport_destruct,
.release = virtio_transport_release,
.connect = virtio_transport_connect,
.shutdown = virtio_transport_shutdown,
.dgram_enqueue = virtio_transport_dgram_enqueue,
.dgram_dequeue = virtio_transport_dgram_dequeue,
.dgram_bind = virtio_transport_dgram_bind,
.dgram_allow = virtio_transport_dgram_allow,
.stream_enqueue = virtio_transport_stream_enqueue,
.stream_dequeue = virtio_transport_stream_dequeue,
.stream_has_data = virtio_transport_stream_has_data,
.stream_has_space = virtio_transport_stream_has_space,
.stream_rcvhiwat = virtio_transport_stream_rcvhiwat,
.stream_is_active = virtio_transport_stream_is_active,
.stream_allow = virtio_transport_stream_allow,
.notify_poll_in = virtio_transport_notify_poll_in,
.notify_poll_out = virtio_transport_notify_poll_out,
.notify_recv_init = virtio_transport_notify_recv_init,
.notify_recv_pre_block = virtio_transport_notify_recv_pre_block,
.notify_recv_pre_dequeue = virtio_transport_notify_recv_pre_dequeue,
.notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
.notify_send_init = virtio_transport_notify_send_init,
.notify_send_pre_block = virtio_transport_notify_send_pre_block,
.notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue,
.notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,
.set_buffer_size = virtio_transport_set_buffer_size,
.set_min_buffer_size = virtio_transport_set_min_buffer_size,
.set_max_buffer_size = virtio_transport_set_max_buffer_size,
.get_buffer_size = virtio_transport_get_buffer_size,
.get_min_buffer_size = virtio_transport_get_min_buffer_size,
.get_max_buffer_size = virtio_transport_get_max_buffer_size,
},
.send_pkt = vhost_transport_send_pkt,
};
static int __init vhost_vsock_init(void)
{
int ret;
ret = vsock_core_init(&vhost_transport.transport);
if (ret < 0)
return ret;
return misc_register(&vhost_vsock_misc);
};
static void __exit vhost_vsock_exit(void)
{
misc_deregister(&vhost_vsock_misc);
vsock_core_exit();
};
module_init(vhost_vsock_init);
module_exit(vhost_vsock_exit);
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("vhost transport for vsock ");
...@@ -207,6 +207,8 @@ static unsigned leak_balloon(struct virtio_balloon *vb, size_t num) ...@@ -207,6 +207,8 @@ static unsigned leak_balloon(struct virtio_balloon *vb, size_t num)
num = min(num, ARRAY_SIZE(vb->pfns)); num = min(num, ARRAY_SIZE(vb->pfns));
mutex_lock(&vb->balloon_lock); mutex_lock(&vb->balloon_lock);
/* We can't release more pages than taken */
num = min(num, (size_t)vb->num_pages);
for (vb->num_pfns = 0; vb->num_pfns < num; for (vb->num_pfns = 0; vb->num_pfns < num;
vb->num_pfns += VIRTIO_BALLOON_PAGES_PER_PAGE) { vb->num_pfns += VIRTIO_BALLOON_PAGES_PER_PAGE) {
page = balloon_page_dequeue(vb_dev_info); page = balloon_page_dequeue(vb_dev_info);
......
...@@ -117,7 +117,10 @@ struct vring_virtqueue { ...@@ -117,7 +117,10 @@ struct vring_virtqueue {
#define to_vvq(_vq) container_of(_vq, struct vring_virtqueue, vq) #define to_vvq(_vq) container_of(_vq, struct vring_virtqueue, vq)
/* /*
* The interaction between virtio and a possible IOMMU is a mess. * Modern virtio devices have feature bits to specify whether they need a
* quirk and bypass the IOMMU. If not there, just use the DMA API.
*
* If there, the interaction between virtio and DMA API is messy.
* *
* On most systems with virtio, physical addresses match bus addresses, * On most systems with virtio, physical addresses match bus addresses,
* and it doesn't particularly matter whether we use the DMA API. * and it doesn't particularly matter whether we use the DMA API.
...@@ -133,10 +136,18 @@ struct vring_virtqueue { ...@@ -133,10 +136,18 @@ struct vring_virtqueue {
* *
* For the time being, we preserve historic behavior and bypass the DMA * For the time being, we preserve historic behavior and bypass the DMA
* API. * API.
*
* TODO: install a per-device DMA ops structure that does the right thing
* taking into account all the above quirks, and use the DMA API
* unconditionally on data path.
*/ */
static bool vring_use_dma_api(struct virtio_device *vdev) static bool vring_use_dma_api(struct virtio_device *vdev)
{ {
if (!virtio_has_iommu_quirk(vdev))
return true;
/* Otherwise, we are left to guess. */
/* /*
* In theory, it's possible to have a buggy QEMU-supposed * In theory, it's possible to have a buggy QEMU-supposed
* emulated Q35 IOMMU and Xen enabled at the same time. On * emulated Q35 IOMMU and Xen enabled at the same time. On
...@@ -1099,6 +1110,8 @@ void vring_transport_features(struct virtio_device *vdev) ...@@ -1099,6 +1110,8 @@ void vring_transport_features(struct virtio_device *vdev)
break; break;
case VIRTIO_F_VERSION_1: case VIRTIO_F_VERSION_1:
break; break;
case VIRTIO_F_IOMMU_PLATFORM:
break;
default: default:
/* We don't understand this bit. */ /* We don't understand this bit. */
__virtio_clear_bit(vdev, i); __virtio_clear_bit(vdev, i);
......
...@@ -149,6 +149,19 @@ static inline bool virtio_has_feature(const struct virtio_device *vdev, ...@@ -149,6 +149,19 @@ static inline bool virtio_has_feature(const struct virtio_device *vdev,
return __virtio_test_bit(vdev, fbit); return __virtio_test_bit(vdev, fbit);
} }
/**
* virtio_has_iommu_quirk - determine whether this device has the iommu quirk
* @vdev: the device
*/
static inline bool virtio_has_iommu_quirk(const struct virtio_device *vdev)
{
/*
* Note the reverse polarity of the quirk feature (compared to most
* other features), this is for compatibility with legacy systems.
*/
return !virtio_has_feature(vdev, VIRTIO_F_IOMMU_PLATFORM);
}
static inline static inline
struct virtqueue *virtio_find_single_vq(struct virtio_device *vdev, struct virtqueue *virtio_find_single_vq(struct virtio_device *vdev,
vq_callback_t *c, const char *n) vq_callback_t *c, const char *n)
......
#ifndef _LINUX_VIRTIO_VSOCK_H
#define _LINUX_VIRTIO_VSOCK_H
#include <uapi/linux/virtio_vsock.h>
#include <linux/socket.h>
#include <net/sock.h>
#include <net/af_vsock.h>
#define VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE 128
#define VIRTIO_VSOCK_DEFAULT_BUF_SIZE (1024 * 256)
#define VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE (1024 * 256)
#define VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE (1024 * 4)
#define VIRTIO_VSOCK_MAX_BUF_SIZE 0xFFFFFFFFUL
#define VIRTIO_VSOCK_MAX_PKT_BUF_SIZE (1024 * 64)
enum {
VSOCK_VQ_RX = 0, /* for host to guest data */
VSOCK_VQ_TX = 1, /* for guest to host data */
VSOCK_VQ_EVENT = 2,
VSOCK_VQ_MAX = 3,
};
/* Per-socket state (accessed via vsk->trans) */
struct virtio_vsock_sock {
struct vsock_sock *vsk;
/* Protected by lock_sock(sk_vsock(trans->vsk)) */
u32 buf_size;
u32 buf_size_min;
u32 buf_size_max;
spinlock_t tx_lock;
spinlock_t rx_lock;
/* Protected by tx_lock */
u32 tx_cnt;
u32 buf_alloc;
u32 peer_fwd_cnt;
u32 peer_buf_alloc;
/* Protected by rx_lock */
u32 fwd_cnt;
u32 rx_bytes;
struct list_head rx_queue;
};
struct virtio_vsock_pkt {
struct virtio_vsock_hdr hdr;
struct work_struct work;
struct list_head list;
void *buf;
u32 len;
u32 off;
bool reply;
};
struct virtio_vsock_pkt_info {
u32 remote_cid, remote_port;
struct msghdr *msg;
u32 pkt_len;
u16 type;
u16 op;
u32 flags;
bool reply;
};
struct virtio_transport {
/* This must be the first field */
struct vsock_transport transport;
/* Takes ownership of the packet */
int (*send_pkt)(struct virtio_vsock_pkt *pkt);
};
ssize_t
virtio_transport_stream_dequeue(struct vsock_sock *vsk,
struct msghdr *msg,
size_t len,
int type);
int
virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
struct msghdr *msg,
size_t len, int flags);
s64 virtio_transport_stream_has_data(struct vsock_sock *vsk);
s64 virtio_transport_stream_has_space(struct vsock_sock *vsk);
int virtio_transport_do_socket_init(struct vsock_sock *vsk,
struct vsock_sock *psk);
u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk);
u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk);
u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk);
void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val);
void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val);
void virtio_transport_set_max_buffer_size(struct vsock_sock *vs, u64 val);
int
virtio_transport_notify_poll_in(struct vsock_sock *vsk,
size_t target,
bool *data_ready_now);
int
virtio_transport_notify_poll_out(struct vsock_sock *vsk,
size_t target,
bool *space_available_now);
int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
size_t target, struct vsock_transport_recv_notify_data *data);
int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
size_t target, struct vsock_transport_recv_notify_data *data);
int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
size_t target, struct vsock_transport_recv_notify_data *data);
int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
size_t target, ssize_t copied, bool data_read,
struct vsock_transport_recv_notify_data *data);
int virtio_transport_notify_send_init(struct vsock_sock *vsk,
struct vsock_transport_send_notify_data *data);
int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
struct vsock_transport_send_notify_data *data);
int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
struct vsock_transport_send_notify_data *data);
int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
ssize_t written, struct vsock_transport_send_notify_data *data);
u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk);
bool virtio_transport_stream_is_active(struct vsock_sock *vsk);
bool virtio_transport_stream_allow(u32 cid, u32 port);
int virtio_transport_dgram_bind(struct vsock_sock *vsk,
struct sockaddr_vm *addr);
bool virtio_transport_dgram_allow(u32 cid, u32 port);
int virtio_transport_connect(struct vsock_sock *vsk);
int virtio_transport_shutdown(struct vsock_sock *vsk, int mode);
void virtio_transport_release(struct vsock_sock *vsk);
ssize_t
virtio_transport_stream_enqueue(struct vsock_sock *vsk,
struct msghdr *msg,
size_t len);
int
virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
struct sockaddr_vm *remote_addr,
struct msghdr *msg,
size_t len);
void virtio_transport_destruct(struct vsock_sock *vsk);
void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt);
void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt);
void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt);
u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 wanted);
void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit);
#endif /* _LINUX_VIRTIO_VSOCK_H */
...@@ -63,6 +63,8 @@ struct vsock_sock { ...@@ -63,6 +63,8 @@ struct vsock_sock {
struct list_head accept_queue; struct list_head accept_queue;
bool rejected; bool rejected;
struct delayed_work dwork; struct delayed_work dwork;
struct delayed_work close_work;
bool close_work_scheduled;
u32 peer_shutdown; u32 peer_shutdown;
bool sent_request; bool sent_request;
bool ignore_connecting_rst; bool ignore_connecting_rst;
...@@ -165,6 +167,9 @@ static inline int vsock_core_init(const struct vsock_transport *t) ...@@ -165,6 +167,9 @@ static inline int vsock_core_init(const struct vsock_transport *t)
} }
void vsock_core_exit(void); void vsock_core_exit(void);
/* The transport may downcast this to access transport-specific functions */
const struct vsock_transport *vsock_core_get_transport(void);
/**** UTILS ****/ /**** UTILS ****/
void vsock_release_pending(struct sock *pending); void vsock_release_pending(struct sock *pending);
...@@ -177,6 +182,7 @@ void vsock_remove_connected(struct vsock_sock *vsk); ...@@ -177,6 +182,7 @@ void vsock_remove_connected(struct vsock_sock *vsk);
struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr); struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr);
struct sock *vsock_find_connected_socket(struct sockaddr_vm *src, struct sock *vsock_find_connected_socket(struct sockaddr_vm *src,
struct sockaddr_vm *dst); struct sockaddr_vm *dst);
void vsock_remove_sock(struct vsock_sock *vsk);
void vsock_for_each_connected_socket(void (*fn)(struct sock *sk)); void vsock_for_each_connected_socket(void (*fn)(struct sock *sk));
#endif /* __AF_VSOCK_H__ */ #endif /* __AF_VSOCK_H__ */
#undef TRACE_SYSTEM
#define TRACE_SYSTEM vsock
#if !defined(_TRACE_VSOCK_VIRTIO_TRANSPORT_COMMON_H) || \
defined(TRACE_HEADER_MULTI_READ)
#define _TRACE_VSOCK_VIRTIO_TRANSPORT_COMMON_H
#include <linux/tracepoint.h>
TRACE_DEFINE_ENUM(VIRTIO_VSOCK_TYPE_STREAM);
#define show_type(val) \
__print_symbolic(val, { VIRTIO_VSOCK_TYPE_STREAM, "STREAM" })
TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_INVALID);
TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_REQUEST);
TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_RESPONSE);
TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_RST);
TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_SHUTDOWN);
TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_RW);
TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_CREDIT_UPDATE);
TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_CREDIT_REQUEST);
#define show_op(val) \
__print_symbolic(val, \
{ VIRTIO_VSOCK_OP_INVALID, "INVALID" }, \
{ VIRTIO_VSOCK_OP_REQUEST, "REQUEST" }, \
{ VIRTIO_VSOCK_OP_RESPONSE, "RESPONSE" }, \
{ VIRTIO_VSOCK_OP_RST, "RST" }, \
{ VIRTIO_VSOCK_OP_SHUTDOWN, "SHUTDOWN" }, \
{ VIRTIO_VSOCK_OP_RW, "RW" }, \
{ VIRTIO_VSOCK_OP_CREDIT_UPDATE, "CREDIT_UPDATE" }, \
{ VIRTIO_VSOCK_OP_CREDIT_REQUEST, "CREDIT_REQUEST" })
TRACE_EVENT(virtio_transport_alloc_pkt,
TP_PROTO(
__u32 src_cid, __u32 src_port,
__u32 dst_cid, __u32 dst_port,
__u32 len,
__u16 type,
__u16 op,
__u32 flags
),
TP_ARGS(
src_cid, src_port,
dst_cid, dst_port,
len,
type,
op,
flags
),
TP_STRUCT__entry(
__field(__u32, src_cid)
__field(__u32, src_port)
__field(__u32, dst_cid)
__field(__u32, dst_port)
__field(__u32, len)
__field(__u16, type)
__field(__u16, op)
__field(__u32, flags)
),
TP_fast_assign(
__entry->src_cid = src_cid;
__entry->src_port = src_port;
__entry->dst_cid = dst_cid;
__entry->dst_port = dst_port;
__entry->len = len;
__entry->type = type;
__entry->op = op;
__entry->flags = flags;
),
TP_printk("%u:%u -> %u:%u len=%u type=%s op=%s flags=%#x",
__entry->src_cid, __entry->src_port,
__entry->dst_cid, __entry->dst_port,
__entry->len,
show_type(__entry->type),
show_op(__entry->op),
__entry->flags)
);
TRACE_EVENT(virtio_transport_recv_pkt,
TP_PROTO(
__u32 src_cid, __u32 src_port,
__u32 dst_cid, __u32 dst_port,
__u32 len,
__u16 type,
__u16 op,
__u32 flags,
__u32 buf_alloc,
__u32 fwd_cnt
),
TP_ARGS(
src_cid, src_port,
dst_cid, dst_port,
len,
type,
op,
flags,
buf_alloc,
fwd_cnt
),
TP_STRUCT__entry(
__field(__u32, src_cid)
__field(__u32, src_port)
__field(__u32, dst_cid)
__field(__u32, dst_port)
__field(__u32, len)
__field(__u16, type)
__field(__u16, op)
__field(__u32, flags)
__field(__u32, buf_alloc)
__field(__u32, fwd_cnt)
),
TP_fast_assign(
__entry->src_cid = src_cid;
__entry->src_port = src_port;
__entry->dst_cid = dst_cid;
__entry->dst_port = dst_port;
__entry->len = len;
__entry->type = type;
__entry->op = op;
__entry->flags = flags;
__entry->buf_alloc = buf_alloc;
__entry->fwd_cnt = fwd_cnt;
),
TP_printk("%u:%u -> %u:%u len=%u type=%s op=%s flags=%#x "
"buf_alloc=%u fwd_cnt=%u",
__entry->src_cid, __entry->src_port,
__entry->dst_cid, __entry->dst_port,
__entry->len,
show_type(__entry->type),
show_op(__entry->op),
__entry->flags,
__entry->buf_alloc,
__entry->fwd_cnt)
);
#endif /* _TRACE_VSOCK_VIRTIO_TRANSPORT_COMMON_H */
#undef TRACE_INCLUDE_FILE
#define TRACE_INCLUDE_FILE vsock_virtio_transport_common
/* This part must be outside protection */
#include <trace/define_trace.h>
...@@ -454,6 +454,7 @@ header-y += virtio_ring.h ...@@ -454,6 +454,7 @@ header-y += virtio_ring.h
header-y += virtio_rng.h header-y += virtio_rng.h
header-y += virtio_scsi.h header-y += virtio_scsi.h
header-y += virtio_types.h header-y += virtio_types.h
header-y += virtio_vsock.h
header-y += vm_sockets.h header-y += vm_sockets.h
header-y += vt.h header-y += vt.h
header-y += vtpm_proxy.h header-y += vtpm_proxy.h
......
...@@ -47,6 +47,32 @@ struct vhost_vring_addr { ...@@ -47,6 +47,32 @@ struct vhost_vring_addr {
__u64 log_guest_addr; __u64 log_guest_addr;
}; };
/* no alignment requirement */
struct vhost_iotlb_msg {
__u64 iova;
__u64 size;
__u64 uaddr;
#define VHOST_ACCESS_RO 0x1
#define VHOST_ACCESS_WO 0x2
#define VHOST_ACCESS_RW 0x3
__u8 perm;
#define VHOST_IOTLB_MISS 1
#define VHOST_IOTLB_UPDATE 2
#define VHOST_IOTLB_INVALIDATE 3
#define VHOST_IOTLB_ACCESS_FAIL 4
__u8 type;
};
#define VHOST_IOTLB_MSG 0x1
struct vhost_msg {
int type;
union {
struct vhost_iotlb_msg iotlb;
__u8 padding[64];
};
};
struct vhost_memory_region { struct vhost_memory_region {
__u64 guest_phys_addr; __u64 guest_phys_addr;
__u64 memory_size; /* bytes */ __u64 memory_size; /* bytes */
...@@ -146,6 +172,8 @@ struct vhost_memory { ...@@ -146,6 +172,8 @@ struct vhost_memory {
#define VHOST_F_LOG_ALL 26 #define VHOST_F_LOG_ALL 26
/* vhost-net should add virtio_net_hdr for RX, and strip for TX packets. */ /* vhost-net should add virtio_net_hdr for RX, and strip for TX packets. */
#define VHOST_NET_F_VIRTIO_NET_HDR 27 #define VHOST_NET_F_VIRTIO_NET_HDR 27
/* Vhost have device IOTLB */
#define VHOST_F_DEVICE_IOTLB 63
/* VHOST_SCSI specific definitions */ /* VHOST_SCSI specific definitions */
...@@ -175,4 +203,9 @@ struct vhost_scsi_target { ...@@ -175,4 +203,9 @@ struct vhost_scsi_target {
#define VHOST_SCSI_SET_EVENTS_MISSED _IOW(VHOST_VIRTIO, 0x43, __u32) #define VHOST_SCSI_SET_EVENTS_MISSED _IOW(VHOST_VIRTIO, 0x43, __u32)
#define VHOST_SCSI_GET_EVENTS_MISSED _IOW(VHOST_VIRTIO, 0x44, __u32) #define VHOST_SCSI_GET_EVENTS_MISSED _IOW(VHOST_VIRTIO, 0x44, __u32)
/* VHOST_VSOCK specific defines */
#define VHOST_VSOCK_SET_GUEST_CID _IOW(VHOST_VIRTIO, 0x60, __u64)
#define VHOST_VSOCK_SET_RUNNING _IOW(VHOST_VIRTIO, 0x61, int)
#endif #endif
...@@ -49,7 +49,7 @@ ...@@ -49,7 +49,7 @@
* transport being used (eg. virtio_ring), the rest are per-device feature * transport being used (eg. virtio_ring), the rest are per-device feature
* bits. */ * bits. */
#define VIRTIO_TRANSPORT_F_START 28 #define VIRTIO_TRANSPORT_F_START 28
#define VIRTIO_TRANSPORT_F_END 33 #define VIRTIO_TRANSPORT_F_END 34
#ifndef VIRTIO_CONFIG_NO_LEGACY #ifndef VIRTIO_CONFIG_NO_LEGACY
/* Do we get callbacks when the ring is completely used, even if we've /* Do we get callbacks when the ring is completely used, even if we've
...@@ -63,4 +63,12 @@ ...@@ -63,4 +63,12 @@
/* v1.0 compliant. */ /* v1.0 compliant. */
#define VIRTIO_F_VERSION_1 32 #define VIRTIO_F_VERSION_1 32
/*
* If clear - device has the IOMMU bypass quirk feature.
* If set - use platform tools to detect the IOMMU.
*
* Note the reverse polarity (compared to most other features),
* this is for compatibility with legacy systems.
*/
#define VIRTIO_F_IOMMU_PLATFORM 33
#endif /* _UAPI_LINUX_VIRTIO_CONFIG_H */ #endif /* _UAPI_LINUX_VIRTIO_CONFIG_H */
...@@ -41,5 +41,6 @@ ...@@ -41,5 +41,6 @@
#define VIRTIO_ID_CAIF 12 /* Virtio caif */ #define VIRTIO_ID_CAIF 12 /* Virtio caif */
#define VIRTIO_ID_GPU 16 /* virtio GPU */ #define VIRTIO_ID_GPU 16 /* virtio GPU */
#define VIRTIO_ID_INPUT 18 /* virtio input */ #define VIRTIO_ID_INPUT 18 /* virtio input */
#define VIRTIO_ID_VSOCK 19 /* virtio vsock transport */
#endif /* _LINUX_VIRTIO_IDS_H */ #endif /* _LINUX_VIRTIO_IDS_H */
/*
* This header, excluding the #ifdef __KERNEL__ part, is BSD licensed so
* anyone can use the definitions to implement compatible drivers/servers:
*
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* 3. Neither the name of IBM nor the names of its contributors
* may be used to endorse or promote products derived from this software
* without specific prior written permission.
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS ``AS IS''
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL IBM OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
* OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
* HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
* LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
* OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
* SUCH DAMAGE.
*
* Copyright (C) Red Hat, Inc., 2013-2015
* Copyright (C) Asias He <asias@redhat.com>, 2013
* Copyright (C) Stefan Hajnoczi <stefanha@redhat.com>, 2015
*/
#ifndef _UAPI_LINUX_VIRTIO_VSOCK_H
#define _UAPI_LINUX_VIRTIO_VOSCK_H
#include <linux/types.h>
#include <linux/virtio_ids.h>
#include <linux/virtio_config.h>
struct virtio_vsock_config {
__le64 guest_cid;
} __attribute__((packed));
enum virtio_vsock_event_id {
VIRTIO_VSOCK_EVENT_TRANSPORT_RESET = 0,
};
struct virtio_vsock_event {
__le32 id;
} __attribute__((packed));
struct virtio_vsock_hdr {
__le64 src_cid;
__le64 dst_cid;
__le32 src_port;
__le32 dst_port;
__le32 len;
__le16 type; /* enum virtio_vsock_type */
__le16 op; /* enum virtio_vsock_op */
__le32 flags;
__le32 buf_alloc;
__le32 fwd_cnt;
} __attribute__((packed));
enum virtio_vsock_type {
VIRTIO_VSOCK_TYPE_STREAM = 1,
};
enum virtio_vsock_op {
VIRTIO_VSOCK_OP_INVALID = 0,
/* Connect operations */
VIRTIO_VSOCK_OP_REQUEST = 1,
VIRTIO_VSOCK_OP_RESPONSE = 2,
VIRTIO_VSOCK_OP_RST = 3,
VIRTIO_VSOCK_OP_SHUTDOWN = 4,
/* To send payload */
VIRTIO_VSOCK_OP_RW = 5,
/* Tell the peer our credit info */
VIRTIO_VSOCK_OP_CREDIT_UPDATE = 6,
/* Request the peer to send the credit info to us */
VIRTIO_VSOCK_OP_CREDIT_REQUEST = 7,
};
/* VIRTIO_VSOCK_OP_SHUTDOWN flags values */
enum virtio_vsock_shutdown {
VIRTIO_VSOCK_SHUTDOWN_RCV = 1,
VIRTIO_VSOCK_SHUTDOWN_SEND = 2,
};
#endif /* _UAPI_LINUX_VIRTIO_VSOCK_H */
...@@ -26,3 +26,23 @@ config VMWARE_VMCI_VSOCKETS ...@@ -26,3 +26,23 @@ config VMWARE_VMCI_VSOCKETS
To compile this driver as a module, choose M here: the module To compile this driver as a module, choose M here: the module
will be called vmw_vsock_vmci_transport. If unsure, say N. will be called vmw_vsock_vmci_transport. If unsure, say N.
config VIRTIO_VSOCKETS
tristate "virtio transport for Virtual Sockets"
depends on VSOCKETS && VIRTIO
select VIRTIO_VSOCKETS_COMMON
help
This module implements a virtio transport for Virtual Sockets.
Enable this transport if your Virtual Machine host supports Virtual
Sockets over virtio.
To compile this driver as a module, choose M here: the module will be
called vmw_vsock_virtio_transport. If unsure, say N.
config VIRTIO_VSOCKETS_COMMON
tristate
help
This option is selected by any driver which needs to access
the virtio_vsock. The module will be called
vmw_vsock_virtio_transport_common.
obj-$(CONFIG_VSOCKETS) += vsock.o obj-$(CONFIG_VSOCKETS) += vsock.o
obj-$(CONFIG_VMWARE_VMCI_VSOCKETS) += vmw_vsock_vmci_transport.o obj-$(CONFIG_VMWARE_VMCI_VSOCKETS) += vmw_vsock_vmci_transport.o
obj-$(CONFIG_VIRTIO_VSOCKETS) += vmw_vsock_virtio_transport.o
obj-$(CONFIG_VIRTIO_VSOCKETS_COMMON) += vmw_vsock_virtio_transport_common.o
vsock-y += af_vsock.o vsock_addr.o vsock-y += af_vsock.o vsock_addr.o
vmw_vsock_vmci_transport-y += vmci_transport.o vmci_transport_notify.o \ vmw_vsock_vmci_transport-y += vmci_transport.o vmci_transport_notify.o \
vmci_transport_notify_qstate.o vmci_transport_notify_qstate.o
vmw_vsock_virtio_transport-y += virtio_transport.o
vmw_vsock_virtio_transport_common-y += virtio_transport_common.o
...@@ -344,6 +344,16 @@ static bool vsock_in_connected_table(struct vsock_sock *vsk) ...@@ -344,6 +344,16 @@ static bool vsock_in_connected_table(struct vsock_sock *vsk)
return ret; return ret;
} }
void vsock_remove_sock(struct vsock_sock *vsk)
{
if (vsock_in_bound_table(vsk))
vsock_remove_bound(vsk);
if (vsock_in_connected_table(vsk))
vsock_remove_connected(vsk);
}
EXPORT_SYMBOL_GPL(vsock_remove_sock);
void vsock_for_each_connected_socket(void (*fn)(struct sock *sk)) void vsock_for_each_connected_socket(void (*fn)(struct sock *sk))
{ {
int i; int i;
...@@ -660,12 +670,6 @@ static void __vsock_release(struct sock *sk) ...@@ -660,12 +670,6 @@ static void __vsock_release(struct sock *sk)
vsk = vsock_sk(sk); vsk = vsock_sk(sk);
pending = NULL; /* Compiler warning. */ pending = NULL; /* Compiler warning. */
if (vsock_in_bound_table(vsk))
vsock_remove_bound(vsk);
if (vsock_in_connected_table(vsk))
vsock_remove_connected(vsk);
transport->release(vsk); transport->release(vsk);
lock_sock(sk); lock_sock(sk);
...@@ -1995,6 +1999,15 @@ void vsock_core_exit(void) ...@@ -1995,6 +1999,15 @@ void vsock_core_exit(void)
} }
EXPORT_SYMBOL_GPL(vsock_core_exit); EXPORT_SYMBOL_GPL(vsock_core_exit);
const struct vsock_transport *vsock_core_get_transport(void)
{
/* vsock_register_mutex not taken since only the transport uses this
* function and only while registered.
*/
return transport;
}
EXPORT_SYMBOL_GPL(vsock_core_get_transport);
MODULE_AUTHOR("VMware, Inc."); MODULE_AUTHOR("VMware, Inc.");
MODULE_DESCRIPTION("VMware Virtual Socket Family"); MODULE_DESCRIPTION("VMware Virtual Socket Family");
MODULE_VERSION("1.0.1.0-k"); MODULE_VERSION("1.0.1.0-k");
......
/*
* virtio transport for vsock
*
* Copyright (C) 2013-2015 Red Hat, Inc.
* Author: Asias He <asias@redhat.com>
* Stefan Hajnoczi <stefanha@redhat.com>
*
* Some of the code is take from Gerd Hoffmann <kraxel@redhat.com>'s
* early virtio-vsock proof-of-concept bits.
*
* This work is licensed under the terms of the GNU GPL, version 2.
*/
#include <linux/spinlock.h>
#include <linux/module.h>
#include <linux/list.h>
#include <linux/atomic.h>
#include <linux/virtio.h>
#include <linux/virtio_ids.h>
#include <linux/virtio_config.h>
#include <linux/virtio_vsock.h>
#include <net/sock.h>
#include <linux/mutex.h>
#include <net/af_vsock.h>
static struct workqueue_struct *virtio_vsock_workqueue;
static struct virtio_vsock *the_virtio_vsock;
static DEFINE_MUTEX(the_virtio_vsock_mutex); /* protects the_virtio_vsock */
struct virtio_vsock {
struct virtio_device *vdev;
struct virtqueue *vqs[VSOCK_VQ_MAX];
/* Virtqueue processing is deferred to a workqueue */
struct work_struct tx_work;
struct work_struct rx_work;
struct work_struct event_work;
/* The following fields are protected by tx_lock. vqs[VSOCK_VQ_TX]
* must be accessed with tx_lock held.
*/
struct mutex tx_lock;
struct work_struct send_pkt_work;
spinlock_t send_pkt_list_lock;
struct list_head send_pkt_list;
atomic_t queued_replies;
/* The following fields are protected by rx_lock. vqs[VSOCK_VQ_RX]
* must be accessed with rx_lock held.
*/
struct mutex rx_lock;
int rx_buf_nr;
int rx_buf_max_nr;
/* The following fields are protected by event_lock.
* vqs[VSOCK_VQ_EVENT] must be accessed with event_lock held.
*/
struct mutex event_lock;
struct virtio_vsock_event event_list[8];
u32 guest_cid;
};
static struct virtio_vsock *virtio_vsock_get(void)
{
return the_virtio_vsock;
}
static u32 virtio_transport_get_local_cid(void)
{
struct virtio_vsock *vsock = virtio_vsock_get();
return vsock->guest_cid;
}
static void
virtio_transport_send_pkt_work(struct work_struct *work)
{
struct virtio_vsock *vsock =
container_of(work, struct virtio_vsock, send_pkt_work);
struct virtqueue *vq;
bool added = false;
bool restart_rx = false;
mutex_lock(&vsock->tx_lock);
vq = vsock->vqs[VSOCK_VQ_TX];
/* Avoid unnecessary interrupts while we're processing the ring */
virtqueue_disable_cb(vq);
for (;;) {
struct virtio_vsock_pkt *pkt;
struct scatterlist hdr, buf, *sgs[2];
int ret, in_sg = 0, out_sg = 0;
bool reply;
spin_lock_bh(&vsock->send_pkt_list_lock);
if (list_empty(&vsock->send_pkt_list)) {
spin_unlock_bh(&vsock->send_pkt_list_lock);
virtqueue_enable_cb(vq);
break;
}
pkt = list_first_entry(&vsock->send_pkt_list,
struct virtio_vsock_pkt, list);
list_del_init(&pkt->list);
spin_unlock_bh(&vsock->send_pkt_list_lock);
reply = pkt->reply;
sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr));
sgs[out_sg++] = &hdr;
if (pkt->buf) {
sg_init_one(&buf, pkt->buf, pkt->len);
sgs[out_sg++] = &buf;
}
ret = virtqueue_add_sgs(vq, sgs, out_sg, in_sg, pkt, GFP_KERNEL);
if (ret < 0) {
spin_lock_bh(&vsock->send_pkt_list_lock);
list_add(&pkt->list, &vsock->send_pkt_list);
spin_unlock_bh(&vsock->send_pkt_list_lock);
if (!virtqueue_enable_cb(vq) && ret == -ENOSPC)
continue; /* retry now that we have more space */
break;
}
if (reply) {
struct virtqueue *rx_vq = vsock->vqs[VSOCK_VQ_RX];
int val;
val = atomic_dec_return(&vsock->queued_replies);
/* Do we now have resources to resume rx processing? */
if (val + 1 == virtqueue_get_vring_size(rx_vq))
restart_rx = true;
}
added = true;
}
if (added)
virtqueue_kick(vq);
mutex_unlock(&vsock->tx_lock);
if (restart_rx)
queue_work(virtio_vsock_workqueue, &vsock->rx_work);
}
static int
virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt)
{
struct virtio_vsock *vsock;
int len = pkt->len;
vsock = virtio_vsock_get();
if (!vsock) {
virtio_transport_free_pkt(pkt);
return -ENODEV;
}
if (pkt->reply)
atomic_inc(&vsock->queued_replies);
spin_lock_bh(&vsock->send_pkt_list_lock);
list_add_tail(&pkt->list, &vsock->send_pkt_list);
spin_unlock_bh(&vsock->send_pkt_list_lock);
queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work);
return len;
}
static void virtio_vsock_rx_fill(struct virtio_vsock *vsock)
{
int buf_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;
struct virtio_vsock_pkt *pkt;
struct scatterlist hdr, buf, *sgs[2];
struct virtqueue *vq;
int ret;
vq = vsock->vqs[VSOCK_VQ_RX];
do {
pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
if (!pkt)
break;
pkt->buf = kmalloc(buf_len, GFP_KERNEL);
if (!pkt->buf) {
virtio_transport_free_pkt(pkt);
break;
}
pkt->len = buf_len;
sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr));
sgs[0] = &hdr;
sg_init_one(&buf, pkt->buf, buf_len);
sgs[1] = &buf;
ret = virtqueue_add_sgs(vq, sgs, 0, 2, pkt, GFP_KERNEL);
if (ret) {
virtio_transport_free_pkt(pkt);
break;
}
vsock->rx_buf_nr++;
} while (vq->num_free);
if (vsock->rx_buf_nr > vsock->rx_buf_max_nr)
vsock->rx_buf_max_nr = vsock->rx_buf_nr;
virtqueue_kick(vq);
}
static void virtio_transport_tx_work(struct work_struct *work)
{
struct virtio_vsock *vsock =
container_of(work, struct virtio_vsock, tx_work);
struct virtqueue *vq;
bool added = false;
vq = vsock->vqs[VSOCK_VQ_TX];
mutex_lock(&vsock->tx_lock);
do {
struct virtio_vsock_pkt *pkt;
unsigned int len;
virtqueue_disable_cb(vq);
while ((pkt = virtqueue_get_buf(vq, &len)) != NULL) {
virtio_transport_free_pkt(pkt);
added = true;
}
} while (!virtqueue_enable_cb(vq));
mutex_unlock(&vsock->tx_lock);
if (added)
queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work);
}
/* Is there space left for replies to rx packets? */
static bool virtio_transport_more_replies(struct virtio_vsock *vsock)
{
struct virtqueue *vq = vsock->vqs[VSOCK_VQ_RX];
int val;
smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */
val = atomic_read(&vsock->queued_replies);
return val < virtqueue_get_vring_size(vq);
}
static void virtio_transport_rx_work(struct work_struct *work)
{
struct virtio_vsock *vsock =
container_of(work, struct virtio_vsock, rx_work);
struct virtqueue *vq;
vq = vsock->vqs[VSOCK_VQ_RX];
mutex_lock(&vsock->rx_lock);
do {
virtqueue_disable_cb(vq);
for (;;) {
struct virtio_vsock_pkt *pkt;
unsigned int len;
if (!virtio_transport_more_replies(vsock)) {
/* Stop rx until the device processes already
* pending replies. Leave rx virtqueue
* callbacks disabled.
*/
goto out;
}
pkt = virtqueue_get_buf(vq, &len);
if (!pkt) {
break;
}
vsock->rx_buf_nr--;
/* Drop short/long packets */
if (unlikely(len < sizeof(pkt->hdr) ||
len > sizeof(pkt->hdr) + pkt->len)) {
virtio_transport_free_pkt(pkt);
continue;
}
pkt->len = len - sizeof(pkt->hdr);
virtio_transport_recv_pkt(pkt);
}
} while (!virtqueue_enable_cb(vq));
out:
if (vsock->rx_buf_nr < vsock->rx_buf_max_nr / 2)
virtio_vsock_rx_fill(vsock);
mutex_unlock(&vsock->rx_lock);
}
/* event_lock must be held */
static int virtio_vsock_event_fill_one(struct virtio_vsock *vsock,
struct virtio_vsock_event *event)
{
struct scatterlist sg;
struct virtqueue *vq;
vq = vsock->vqs[VSOCK_VQ_EVENT];
sg_init_one(&sg, event, sizeof(*event));
return virtqueue_add_inbuf(vq, &sg, 1, event, GFP_KERNEL);
}
/* event_lock must be held */
static void virtio_vsock_event_fill(struct virtio_vsock *vsock)
{
size_t i;
for (i = 0; i < ARRAY_SIZE(vsock->event_list); i++) {
struct virtio_vsock_event *event = &vsock->event_list[i];
virtio_vsock_event_fill_one(vsock, event);
}
virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]);
}
static void virtio_vsock_reset_sock(struct sock *sk)
{
lock_sock(sk);
sk->sk_state = SS_UNCONNECTED;
sk->sk_err = ECONNRESET;
sk->sk_error_report(sk);
release_sock(sk);
}
static void virtio_vsock_update_guest_cid(struct virtio_vsock *vsock)
{
struct virtio_device *vdev = vsock->vdev;
u64 guest_cid;
vdev->config->get(vdev, offsetof(struct virtio_vsock_config, guest_cid),
&guest_cid, sizeof(guest_cid));
vsock->guest_cid = le64_to_cpu(guest_cid);
}
/* event_lock must be held */
static void virtio_vsock_event_handle(struct virtio_vsock *vsock,
struct virtio_vsock_event *event)
{
switch (le32_to_cpu(event->id)) {
case VIRTIO_VSOCK_EVENT_TRANSPORT_RESET:
virtio_vsock_update_guest_cid(vsock);
vsock_for_each_connected_socket(virtio_vsock_reset_sock);
break;
}
}
static void virtio_transport_event_work(struct work_struct *work)
{
struct virtio_vsock *vsock =
container_of(work, struct virtio_vsock, event_work);
struct virtqueue *vq;
vq = vsock->vqs[VSOCK_VQ_EVENT];
mutex_lock(&vsock->event_lock);
do {
struct virtio_vsock_event *event;
unsigned int len;
virtqueue_disable_cb(vq);
while ((event = virtqueue_get_buf(vq, &len)) != NULL) {
if (len == sizeof(*event))
virtio_vsock_event_handle(vsock, event);
virtio_vsock_event_fill_one(vsock, event);
}
} while (!virtqueue_enable_cb(vq));
virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]);
mutex_unlock(&vsock->event_lock);
}
static void virtio_vsock_event_done(struct virtqueue *vq)
{
struct virtio_vsock *vsock = vq->vdev->priv;
if (!vsock)
return;
queue_work(virtio_vsock_workqueue, &vsock->event_work);
}
static void virtio_vsock_tx_done(struct virtqueue *vq)
{
struct virtio_vsock *vsock = vq->vdev->priv;
if (!vsock)
return;
queue_work(virtio_vsock_workqueue, &vsock->tx_work);
}
static void virtio_vsock_rx_done(struct virtqueue *vq)
{
struct virtio_vsock *vsock = vq->vdev->priv;
if (!vsock)
return;
queue_work(virtio_vsock_workqueue, &vsock->rx_work);
}
static struct virtio_transport virtio_transport = {
.transport = {
.get_local_cid = virtio_transport_get_local_cid,
.init = virtio_transport_do_socket_init,
.destruct = virtio_transport_destruct,
.release = virtio_transport_release,
.connect = virtio_transport_connect,
.shutdown = virtio_transport_shutdown,
.dgram_bind = virtio_transport_dgram_bind,
.dgram_dequeue = virtio_transport_dgram_dequeue,
.dgram_enqueue = virtio_transport_dgram_enqueue,
.dgram_allow = virtio_transport_dgram_allow,
.stream_dequeue = virtio_transport_stream_dequeue,
.stream_enqueue = virtio_transport_stream_enqueue,
.stream_has_data = virtio_transport_stream_has_data,
.stream_has_space = virtio_transport_stream_has_space,
.stream_rcvhiwat = virtio_transport_stream_rcvhiwat,
.stream_is_active = virtio_transport_stream_is_active,
.stream_allow = virtio_transport_stream_allow,
.notify_poll_in = virtio_transport_notify_poll_in,
.notify_poll_out = virtio_transport_notify_poll_out,
.notify_recv_init = virtio_transport_notify_recv_init,
.notify_recv_pre_block = virtio_transport_notify_recv_pre_block,
.notify_recv_pre_dequeue = virtio_transport_notify_recv_pre_dequeue,
.notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
.notify_send_init = virtio_transport_notify_send_init,
.notify_send_pre_block = virtio_transport_notify_send_pre_block,
.notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue,
.notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,
.set_buffer_size = virtio_transport_set_buffer_size,
.set_min_buffer_size = virtio_transport_set_min_buffer_size,
.set_max_buffer_size = virtio_transport_set_max_buffer_size,
.get_buffer_size = virtio_transport_get_buffer_size,
.get_min_buffer_size = virtio_transport_get_min_buffer_size,
.get_max_buffer_size = virtio_transport_get_max_buffer_size,
},
.send_pkt = virtio_transport_send_pkt,
};
static int virtio_vsock_probe(struct virtio_device *vdev)
{
vq_callback_t *callbacks[] = {
virtio_vsock_rx_done,
virtio_vsock_tx_done,
virtio_vsock_event_done,
};
static const char * const names[] = {
"rx",
"tx",
"event",
};
struct virtio_vsock *vsock = NULL;
int ret;
ret = mutex_lock_interruptible(&the_virtio_vsock_mutex);
if (ret)
return ret;
/* Only one virtio-vsock device per guest is supported */
if (the_virtio_vsock) {
ret = -EBUSY;
goto out;
}
vsock = kzalloc(sizeof(*vsock), GFP_KERNEL);
if (!vsock) {
ret = -ENOMEM;
goto out;
}
vsock->vdev = vdev;
ret = vsock->vdev->config->find_vqs(vsock->vdev, VSOCK_VQ_MAX,
vsock->vqs, callbacks, names);
if (ret < 0)
goto out;
virtio_vsock_update_guest_cid(vsock);
ret = vsock_core_init(&virtio_transport.transport);
if (ret < 0)
goto out_vqs;
vsock->rx_buf_nr = 0;
vsock->rx_buf_max_nr = 0;
atomic_set(&vsock->queued_replies, 0);
vdev->priv = vsock;
the_virtio_vsock = vsock;
mutex_init(&vsock->tx_lock);
mutex_init(&vsock->rx_lock);
mutex_init(&vsock->event_lock);
spin_lock_init(&vsock->send_pkt_list_lock);
INIT_LIST_HEAD(&vsock->send_pkt_list);
INIT_WORK(&vsock->rx_work, virtio_transport_rx_work);
INIT_WORK(&vsock->tx_work, virtio_transport_tx_work);
INIT_WORK(&vsock->event_work, virtio_transport_event_work);
INIT_WORK(&vsock->send_pkt_work, virtio_transport_send_pkt_work);
mutex_lock(&vsock->rx_lock);
virtio_vsock_rx_fill(vsock);
mutex_unlock(&vsock->rx_lock);
mutex_lock(&vsock->event_lock);
virtio_vsock_event_fill(vsock);
mutex_unlock(&vsock->event_lock);
mutex_unlock(&the_virtio_vsock_mutex);
return 0;
out_vqs:
vsock->vdev->config->del_vqs(vsock->vdev);
out:
kfree(vsock);
mutex_unlock(&the_virtio_vsock_mutex);
return ret;
}
static void virtio_vsock_remove(struct virtio_device *vdev)
{
struct virtio_vsock *vsock = vdev->priv;
struct virtio_vsock_pkt *pkt;
flush_work(&vsock->rx_work);
flush_work(&vsock->tx_work);
flush_work(&vsock->event_work);
flush_work(&vsock->send_pkt_work);
vdev->config->reset(vdev);
mutex_lock(&vsock->rx_lock);
while ((pkt = virtqueue_detach_unused_buf(vsock->vqs[VSOCK_VQ_RX])))
virtio_transport_free_pkt(pkt);
mutex_unlock(&vsock->rx_lock);
mutex_lock(&vsock->tx_lock);
while ((pkt = virtqueue_detach_unused_buf(vsock->vqs[VSOCK_VQ_TX])))
virtio_transport_free_pkt(pkt);
mutex_unlock(&vsock->tx_lock);
spin_lock_bh(&vsock->send_pkt_list_lock);
while (!list_empty(&vsock->send_pkt_list)) {
pkt = list_first_entry(&vsock->send_pkt_list,
struct virtio_vsock_pkt, list);
list_del(&pkt->list);
virtio_transport_free_pkt(pkt);
}
spin_unlock_bh(&vsock->send_pkt_list_lock);
mutex_lock(&the_virtio_vsock_mutex);
the_virtio_vsock = NULL;
vsock_core_exit();
mutex_unlock(&the_virtio_vsock_mutex);
vdev->config->del_vqs(vdev);
kfree(vsock);
}
static struct virtio_device_id id_table[] = {
{ VIRTIO_ID_VSOCK, VIRTIO_DEV_ANY_ID },
{ 0 },
};
static unsigned int features[] = {
};
static struct virtio_driver virtio_vsock_driver = {
.feature_table = features,
.feature_table_size = ARRAY_SIZE(features),
.driver.name = KBUILD_MODNAME,
.driver.owner = THIS_MODULE,
.id_table = id_table,
.probe = virtio_vsock_probe,
.remove = virtio_vsock_remove,
};
static int __init virtio_vsock_init(void)
{
int ret;
virtio_vsock_workqueue = alloc_workqueue("virtio_vsock", 0, 0);
if (!virtio_vsock_workqueue)
return -ENOMEM;
ret = register_virtio_driver(&virtio_vsock_driver);
if (ret)
destroy_workqueue(virtio_vsock_workqueue);
return ret;
}
static void __exit virtio_vsock_exit(void)
{
unregister_virtio_driver(&virtio_vsock_driver);
destroy_workqueue(virtio_vsock_workqueue);
}
module_init(virtio_vsock_init);
module_exit(virtio_vsock_exit);
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("virtio transport for vsock");
MODULE_DEVICE_TABLE(virtio, id_table);
/*
* common code for virtio vsock
*
* Copyright (C) 2013-2015 Red Hat, Inc.
* Author: Asias He <asias@redhat.com>
* Stefan Hajnoczi <stefanha@redhat.com>
*
* This work is licensed under the terms of the GNU GPL, version 2.
*/
#include <linux/spinlock.h>
#include <linux/module.h>
#include <linux/ctype.h>
#include <linux/list.h>
#include <linux/virtio.h>
#include <linux/virtio_ids.h>
#include <linux/virtio_config.h>
#include <linux/virtio_vsock.h>
#include <net/sock.h>
#include <net/af_vsock.h>
#define CREATE_TRACE_POINTS
#include <trace/events/vsock_virtio_transport_common.h>
/* How long to wait for graceful shutdown of a connection */
#define VSOCK_CLOSE_TIMEOUT (8 * HZ)
static const struct virtio_transport *virtio_transport_get_ops(void)
{
const struct vsock_transport *t = vsock_core_get_transport();
return container_of(t, struct virtio_transport, transport);
}
struct virtio_vsock_pkt *
virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
size_t len,
u32 src_cid,
u32 src_port,
u32 dst_cid,
u32 dst_port)
{
struct virtio_vsock_pkt *pkt;
int err;
pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
if (!pkt)
return NULL;
pkt->hdr.type = cpu_to_le16(info->type);
pkt->hdr.op = cpu_to_le16(info->op);
pkt->hdr.src_cid = cpu_to_le64(src_cid);
pkt->hdr.dst_cid = cpu_to_le64(dst_cid);
pkt->hdr.src_port = cpu_to_le32(src_port);
pkt->hdr.dst_port = cpu_to_le32(dst_port);
pkt->hdr.flags = cpu_to_le32(info->flags);
pkt->len = len;
pkt->hdr.len = cpu_to_le32(len);
pkt->reply = info->reply;
if (info->msg && len > 0) {
pkt->buf = kmalloc(len, GFP_KERNEL);
if (!pkt->buf)
goto out_pkt;
err = memcpy_from_msg(pkt->buf, info->msg, len);
if (err)
goto out;
}
trace_virtio_transport_alloc_pkt(src_cid, src_port,
dst_cid, dst_port,
len,
info->type,
info->op,
info->flags);
return pkt;
out:
kfree(pkt->buf);
out_pkt:
kfree(pkt);
return NULL;
}
EXPORT_SYMBOL_GPL(virtio_transport_alloc_pkt);
static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
struct virtio_vsock_pkt_info *info)
{
u32 src_cid, src_port, dst_cid, dst_port;
struct virtio_vsock_sock *vvs;
struct virtio_vsock_pkt *pkt;
u32 pkt_len = info->pkt_len;
src_cid = vm_sockets_get_local_cid();
src_port = vsk->local_addr.svm_port;
if (!info->remote_cid) {
dst_cid = vsk->remote_addr.svm_cid;
dst_port = vsk->remote_addr.svm_port;
} else {
dst_cid = info->remote_cid;
dst_port = info->remote_port;
}
vvs = vsk->trans;
/* we can send less than pkt_len bytes */
if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;
/* virtio_transport_get_credit might return less than pkt_len credit */
pkt_len = virtio_transport_get_credit(vvs, pkt_len);
/* Do not send zero length OP_RW pkt */
if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
return pkt_len;
pkt = virtio_transport_alloc_pkt(info, pkt_len,
src_cid, src_port,
dst_cid, dst_port);
if (!pkt) {
virtio_transport_put_credit(vvs, pkt_len);
return -ENOMEM;
}
virtio_transport_inc_tx_pkt(vvs, pkt);
return virtio_transport_get_ops()->send_pkt(pkt);
}
static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
struct virtio_vsock_pkt *pkt)
{
vvs->rx_bytes += pkt->len;
}
static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
struct virtio_vsock_pkt *pkt)
{
vvs->rx_bytes -= pkt->len;
vvs->fwd_cnt += pkt->len;
}
void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
{
spin_lock_bh(&vvs->tx_lock);
pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
spin_unlock_bh(&vvs->tx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
u32 ret;
spin_lock_bh(&vvs->tx_lock);
ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
if (ret > credit)
ret = credit;
vvs->tx_cnt += ret;
spin_unlock_bh(&vvs->tx_lock);
return ret;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
spin_lock_bh(&vvs->tx_lock);
vvs->tx_cnt -= credit;
spin_unlock_bh(&vvs->tx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
int type,
struct virtio_vsock_hdr *hdr)
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
.type = type,
};
return virtio_transport_send_pkt_info(vsk, &info);
}
static ssize_t
virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
struct msghdr *msg,
size_t len)
{
struct virtio_vsock_sock *vvs = vsk->trans;
struct virtio_vsock_pkt *pkt;
size_t bytes, total = 0;
int err = -EFAULT;
spin_lock_bh(&vvs->rx_lock);
while (total < len && !list_empty(&vvs->rx_queue)) {
pkt = list_first_entry(&vvs->rx_queue,
struct virtio_vsock_pkt, list);
bytes = len - total;
if (bytes > pkt->len - pkt->off)
bytes = pkt->len - pkt->off;
/* sk_lock is held by caller so no one else can dequeue.
* Unlock rx_lock since memcpy_to_msg() may sleep.
*/
spin_unlock_bh(&vvs->rx_lock);
err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
if (err)
goto out;
spin_lock_bh(&vvs->rx_lock);
total += bytes;
pkt->off += bytes;
if (pkt->off == pkt->len) {
virtio_transport_dec_rx_pkt(vvs, pkt);
list_del(&pkt->list);
virtio_transport_free_pkt(pkt);
}
}
spin_unlock_bh(&vvs->rx_lock);
/* Send a credit pkt to peer */
virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
NULL);
return total;
out:
if (total)
err = total;
return err;
}
ssize_t
virtio_transport_stream_dequeue(struct vsock_sock *vsk,
struct msghdr *msg,
size_t len, int flags)
{
if (flags & MSG_PEEK)
return -EOPNOTSUPP;
return virtio_transport_stream_do_dequeue(vsk, msg, len);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
int
virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
struct msghdr *msg,
size_t len, int flags)
{
return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
{
struct virtio_vsock_sock *vvs = vsk->trans;
s64 bytes;
spin_lock_bh(&vvs->rx_lock);
bytes = vvs->rx_bytes;
spin_unlock_bh(&vvs->rx_lock);
return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
static s64 virtio_transport_has_space(struct vsock_sock *vsk)
{
struct virtio_vsock_sock *vvs = vsk->trans;
s64 bytes;
bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
if (bytes < 0)
bytes = 0;
return bytes;
}
s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
{
struct virtio_vsock_sock *vvs = vsk->trans;
s64 bytes;
spin_lock_bh(&vvs->tx_lock);
bytes = virtio_transport_has_space(vsk);
spin_unlock_bh(&vvs->tx_lock);
return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
int virtio_transport_do_socket_init(struct vsock_sock *vsk,
struct vsock_sock *psk)
{
struct virtio_vsock_sock *vvs;
vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
if (!vvs)
return -ENOMEM;
vsk->trans = vvs;
vvs->vsk = vsk;
if (psk) {
struct virtio_vsock_sock *ptrans = psk->trans;
vvs->buf_size = ptrans->buf_size;
vvs->buf_size_min = ptrans->buf_size_min;
vvs->buf_size_max = ptrans->buf_size_max;
vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
} else {
vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
}
vvs->buf_alloc = vvs->buf_size;
spin_lock_init(&vvs->rx_lock);
spin_lock_init(&vvs->tx_lock);
INIT_LIST_HEAD(&vvs->rx_queue);
return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
{
struct virtio_vsock_sock *vvs = vsk->trans;
return vvs->buf_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);
u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
{
struct virtio_vsock_sock *vvs = vsk->trans;
return vvs->buf_size_min;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);
u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
{
struct virtio_vsock_sock *vvs = vsk->trans;
return vvs->buf_size_max;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);
void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
{
struct virtio_vsock_sock *vvs = vsk->trans;
if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
val = VIRTIO_VSOCK_MAX_BUF_SIZE;
if (val < vvs->buf_size_min)
vvs->buf_size_min = val;
if (val > vvs->buf_size_max)
vvs->buf_size_max = val;
vvs->buf_size = val;
vvs->buf_alloc = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);
void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
{
struct virtio_vsock_sock *vvs = vsk->trans;
if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
val = VIRTIO_VSOCK_MAX_BUF_SIZE;
if (val > vvs->buf_size)
vvs->buf_size = val;
vvs->buf_size_min = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);
void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
{
struct virtio_vsock_sock *vvs = vsk->trans;
if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
val = VIRTIO_VSOCK_MAX_BUF_SIZE;
if (val < vvs->buf_size)
vvs->buf_size = val;
vvs->buf_size_max = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);
int
virtio_transport_notify_poll_in(struct vsock_sock *vsk,
size_t target,
bool *data_ready_now)
{
if (vsock_stream_has_data(vsk))
*data_ready_now = true;
else
*data_ready_now = false;
return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
int
virtio_transport_notify_poll_out(struct vsock_sock *vsk,
size_t target,
bool *space_avail_now)
{
s64 free_space;
free_space = vsock_stream_has_space(vsk);
if (free_space > 0)
*space_avail_now = true;
else if (free_space == 0)
*space_avail_now = false;
return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
size_t target, struct vsock_transport_recv_notify_data *data)
{
return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);
int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
size_t target, struct vsock_transport_recv_notify_data *data)
{
return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);
int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
size_t target, struct vsock_transport_recv_notify_data *data)
{
return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);
int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
size_t target, ssize_t copied, bool data_read,
struct vsock_transport_recv_notify_data *data)
{
return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);
int virtio_transport_notify_send_init(struct vsock_sock *vsk,
struct vsock_transport_send_notify_data *data)
{
return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);
int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
struct vsock_transport_send_notify_data *data)
{
return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);
int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
struct vsock_transport_send_notify_data *data)
{
return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);
int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
ssize_t written, struct vsock_transport_send_notify_data *data)
{
return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
{
struct virtio_vsock_sock *vvs = vsk->trans;
return vvs->buf_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
{
return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);
bool virtio_transport_stream_allow(u32 cid, u32 port)
{
return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
int virtio_transport_dgram_bind(struct vsock_sock *vsk,
struct sockaddr_vm *addr)
{
return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
bool virtio_transport_dgram_allow(u32 cid, u32 port)
{
return false;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
int virtio_transport_connect(struct vsock_sock *vsk)
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_REQUEST,
.type = VIRTIO_VSOCK_TYPE_STREAM,
};
return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_connect);
int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_SHUTDOWN,
.type = VIRTIO_VSOCK_TYPE_STREAM,
.flags = (mode & RCV_SHUTDOWN ?
VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
(mode & SEND_SHUTDOWN ?
VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
};
return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
int
virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
struct sockaddr_vm *remote_addr,
struct msghdr *msg,
size_t dgram_len)
{
return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
ssize_t
virtio_transport_stream_enqueue(struct vsock_sock *vsk,
struct msghdr *msg,
size_t len)
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_RW,
.type = VIRTIO_VSOCK_TYPE_STREAM,
.msg = msg,
.pkt_len = len,
};
return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
void virtio_transport_destruct(struct vsock_sock *vsk)
{
struct virtio_vsock_sock *vvs = vsk->trans;
kfree(vvs);
}
EXPORT_SYMBOL_GPL(virtio_transport_destruct);
static int virtio_transport_reset(struct vsock_sock *vsk,
struct virtio_vsock_pkt *pkt)
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_RST,
.type = VIRTIO_VSOCK_TYPE_STREAM,
.reply = !!pkt,
};
/* Send RST only if the original pkt is not a RST pkt */
if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
return 0;
return virtio_transport_send_pkt_info(vsk, &info);
}
/* Normally packets are associated with a socket. There may be no socket if an
* attempt was made to connect to a socket that does not exist.
*/
static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_RST,
.type = le16_to_cpu(pkt->hdr.type),
.reply = true,
};
/* Send RST only if the original pkt is not a RST pkt */
if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
return 0;
pkt = virtio_transport_alloc_pkt(&info, 0,
le32_to_cpu(pkt->hdr.dst_cid),
le32_to_cpu(pkt->hdr.dst_port),
le32_to_cpu(pkt->hdr.src_cid),
le32_to_cpu(pkt->hdr.src_port));
if (!pkt)
return -ENOMEM;
return virtio_transport_get_ops()->send_pkt(pkt);
}
static void virtio_transport_wait_close(struct sock *sk, long timeout)
{
if (timeout) {
DEFINE_WAIT(wait);
do {
prepare_to_wait(sk_sleep(sk), &wait,
TASK_INTERRUPTIBLE);
if (sk_wait_event(sk, &timeout,
sock_flag(sk, SOCK_DONE)))
break;
} while (!signal_pending(current) && timeout);
finish_wait(sk_sleep(sk), &wait);
}
}
static void virtio_transport_do_close(struct vsock_sock *vsk,
bool cancel_timeout)
{
struct sock *sk = sk_vsock(vsk);
sock_set_flag(sk, SOCK_DONE);
vsk->peer_shutdown = SHUTDOWN_MASK;
if (vsock_stream_has_data(vsk) <= 0)
sk->sk_state = SS_DISCONNECTING;
sk->sk_state_change(sk);
if (vsk->close_work_scheduled &&
(!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
vsk->close_work_scheduled = false;
vsock_remove_sock(vsk);
/* Release refcnt obtained when we scheduled the timeout */
sock_put(sk);
}
}
static void virtio_transport_close_timeout(struct work_struct *work)
{
struct vsock_sock *vsk =
container_of(work, struct vsock_sock, close_work.work);
struct sock *sk = sk_vsock(vsk);
sock_hold(sk);
lock_sock(sk);
if (!sock_flag(sk, SOCK_DONE)) {
(void)virtio_transport_reset(vsk, NULL);
virtio_transport_do_close(vsk, false);
}
vsk->close_work_scheduled = false;
release_sock(sk);
sock_put(sk);
}
/* User context, vsk->sk is locked */
static bool virtio_transport_close(struct vsock_sock *vsk)
{
struct sock *sk = &vsk->sk;
if (!(sk->sk_state == SS_CONNECTED ||
sk->sk_state == SS_DISCONNECTING))
return true;
/* Already received SHUTDOWN from peer, reply with RST */
if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
(void)virtio_transport_reset(vsk, NULL);
return true;
}
if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
(void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);
if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
virtio_transport_wait_close(sk, sk->sk_lingertime);
if (sock_flag(sk, SOCK_DONE)) {
return true;
}
sock_hold(sk);
INIT_DELAYED_WORK(&vsk->close_work,
virtio_transport_close_timeout);
vsk->close_work_scheduled = true;
schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
return false;
}
void virtio_transport_release(struct vsock_sock *vsk)
{
struct sock *sk = &vsk->sk;
bool remove_sock = true;
lock_sock(sk);
if (sk->sk_type == SOCK_STREAM)
remove_sock = virtio_transport_close(vsk);
release_sock(sk);
if (remove_sock)
vsock_remove_sock(vsk);
}
EXPORT_SYMBOL_GPL(virtio_transport_release);
static int
virtio_transport_recv_connecting(struct sock *sk,
struct virtio_vsock_pkt *pkt)
{
struct vsock_sock *vsk = vsock_sk(sk);
int err;
int skerr;
switch (le16_to_cpu(pkt->hdr.op)) {
case VIRTIO_VSOCK_OP_RESPONSE:
sk->sk_state = SS_CONNECTED;
sk->sk_socket->state = SS_CONNECTED;
vsock_insert_connected(vsk);
sk->sk_state_change(sk);
break;
case VIRTIO_VSOCK_OP_INVALID:
break;
case VIRTIO_VSOCK_OP_RST:
skerr = ECONNRESET;
err = 0;
goto destroy;
default:
skerr = EPROTO;
err = -EINVAL;
goto destroy;
}
return 0;
destroy:
virtio_transport_reset(vsk, pkt);
sk->sk_state = SS_UNCONNECTED;
sk->sk_err = skerr;
sk->sk_error_report(sk);
return err;
}
static int
virtio_transport_recv_connected(struct sock *sk,
struct virtio_vsock_pkt *pkt)
{
struct vsock_sock *vsk = vsock_sk(sk);
struct virtio_vsock_sock *vvs = vsk->trans;
int err = 0;
switch (le16_to_cpu(pkt->hdr.op)) {
case VIRTIO_VSOCK_OP_RW:
pkt->len = le32_to_cpu(pkt->hdr.len);
pkt->off = 0;
spin_lock_bh(&vvs->rx_lock);
virtio_transport_inc_rx_pkt(vvs, pkt);
list_add_tail(&pkt->list, &vvs->rx_queue);
spin_unlock_bh(&vvs->rx_lock);
sk->sk_data_ready(sk);
return err;
case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
sk->sk_write_space(sk);
break;
case VIRTIO_VSOCK_OP_SHUTDOWN:
if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
vsk->peer_shutdown |= RCV_SHUTDOWN;
if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
vsk->peer_shutdown |= SEND_SHUTDOWN;
if (vsk->peer_shutdown == SHUTDOWN_MASK &&
vsock_stream_has_data(vsk) <= 0)
sk->sk_state = SS_DISCONNECTING;
if (le32_to_cpu(pkt->hdr.flags))
sk->sk_state_change(sk);
break;
case VIRTIO_VSOCK_OP_RST:
virtio_transport_do_close(vsk, true);
break;
default:
err = -EINVAL;
break;
}
virtio_transport_free_pkt(pkt);
return err;
}
static void
virtio_transport_recv_disconnecting(struct sock *sk,
struct virtio_vsock_pkt *pkt)
{
struct vsock_sock *vsk = vsock_sk(sk);
if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
virtio_transport_do_close(vsk, true);
}
static int
virtio_transport_send_response(struct vsock_sock *vsk,
struct virtio_vsock_pkt *pkt)
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_RESPONSE,
.type = VIRTIO_VSOCK_TYPE_STREAM,
.remote_cid = le32_to_cpu(pkt->hdr.src_cid),
.remote_port = le32_to_cpu(pkt->hdr.src_port),
.reply = true,
};
return virtio_transport_send_pkt_info(vsk, &info);
}
/* Handle server socket */
static int
virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
{
struct vsock_sock *vsk = vsock_sk(sk);
struct vsock_sock *vchild;
struct sock *child;
if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
virtio_transport_reset(vsk, pkt);
return -EINVAL;
}
if (sk_acceptq_is_full(sk)) {
virtio_transport_reset(vsk, pkt);
return -ENOMEM;
}
child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
sk->sk_type, 0);
if (!child) {
virtio_transport_reset(vsk, pkt);
return -ENOMEM;
}
sk->sk_ack_backlog++;
lock_sock_nested(child, SINGLE_DEPTH_NESTING);
child->sk_state = SS_CONNECTED;
vchild = vsock_sk(child);
vsock_addr_init(&vchild->local_addr, le32_to_cpu(pkt->hdr.dst_cid),
le32_to_cpu(pkt->hdr.dst_port));
vsock_addr_init(&vchild->remote_addr, le32_to_cpu(pkt->hdr.src_cid),
le32_to_cpu(pkt->hdr.src_port));
vsock_insert_connected(vchild);
vsock_enqueue_accept(sk, child);
virtio_transport_send_response(vchild, pkt);
release_sock(child);
sk->sk_data_ready(sk);
return 0;
}
static bool virtio_transport_space_update(struct sock *sk,
struct virtio_vsock_pkt *pkt)
{
struct vsock_sock *vsk = vsock_sk(sk);
struct virtio_vsock_sock *vvs = vsk->trans;
bool space_available;
/* buf_alloc and fwd_cnt is always included in the hdr */
spin_lock_bh(&vvs->tx_lock);
vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
space_available = virtio_transport_has_space(vsk);
spin_unlock_bh(&vvs->tx_lock);
return space_available;
}
/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
* lock.
*/
void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
{
struct sockaddr_vm src, dst;
struct vsock_sock *vsk;
struct sock *sk;
bool space_available;
vsock_addr_init(&src, le32_to_cpu(pkt->hdr.src_cid),
le32_to_cpu(pkt->hdr.src_port));
vsock_addr_init(&dst, le32_to_cpu(pkt->hdr.dst_cid),
le32_to_cpu(pkt->hdr.dst_port));
trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
dst.svm_cid, dst.svm_port,
le32_to_cpu(pkt->hdr.len),
le16_to_cpu(pkt->hdr.type),
le16_to_cpu(pkt->hdr.op),
le32_to_cpu(pkt->hdr.flags),
le32_to_cpu(pkt->hdr.buf_alloc),
le32_to_cpu(pkt->hdr.fwd_cnt));
if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
(void)virtio_transport_reset_no_sock(pkt);
goto free_pkt;
}
/* The socket must be in connected or bound table
* otherwise send reset back
*/
sk = vsock_find_connected_socket(&src, &dst);
if (!sk) {
sk = vsock_find_bound_socket(&dst);
if (!sk) {
(void)virtio_transport_reset_no_sock(pkt);
goto free_pkt;
}
}
vsk = vsock_sk(sk);
space_available = virtio_transport_space_update(sk, pkt);
lock_sock(sk);
/* Update CID in case it has changed after a transport reset event */
vsk->local_addr.svm_cid = dst.svm_cid;
if (space_available)
sk->sk_write_space(sk);
switch (sk->sk_state) {
case VSOCK_SS_LISTEN:
virtio_transport_recv_listen(sk, pkt);
virtio_transport_free_pkt(pkt);
break;
case SS_CONNECTING:
virtio_transport_recv_connecting(sk, pkt);
virtio_transport_free_pkt(pkt);
break;
case SS_CONNECTED:
virtio_transport_recv_connected(sk, pkt);
break;
case SS_DISCONNECTING:
virtio_transport_recv_disconnecting(sk, pkt);
virtio_transport_free_pkt(pkt);
break;
default:
virtio_transport_free_pkt(pkt);
break;
}
release_sock(sk);
/* Release refcnt obtained when we fetched this socket out of the
* bound or connected list.
*/
sock_put(sk);
return;
free_pkt:
virtio_transport_free_pkt(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
{
kfree(pkt->buf);
kfree(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("common code for virtio vsock");
...@@ -1644,6 +1644,8 @@ static void vmci_transport_destruct(struct vsock_sock *vsk) ...@@ -1644,6 +1644,8 @@ static void vmci_transport_destruct(struct vsock_sock *vsk)
static void vmci_transport_release(struct vsock_sock *vsk) static void vmci_transport_release(struct vsock_sock *vsk)
{ {
vsock_remove_sock(vsk);
if (!vmci_handle_is_invalid(vmci_trans(vsk)->dg_handle)) { if (!vmci_handle_is_invalid(vmci_trans(vsk)->dg_handle)) {
vmci_datagram_destroy_handle(vmci_trans(vsk)->dg_handle); vmci_datagram_destroy_handle(vmci_trans(vsk)->dg_handle);
vmci_trans(vsk)->dg_handle = VMCI_INVALID_HANDLE; vmci_trans(vsk)->dg_handle = VMCI_INVALID_HANDLE;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment