Commit 24df31f8 authored by David S. Miller's avatar David S. Miller

Merge branch 'vsock-add-multi-transports-support'

Stefano Garzarella says:

====================
vsock: add multi-transports support

Most of the patches are reviewed by Dexuan, Stefan, and Jorgen.
The following patches need reviews:
- [11/15] vsock: add multi-transports support
- [12/15] vsock/vmci: register vmci_transport only when VMCI guest/host
          are active
- [15/15] vhost/vsock: refuse CID assigned to the guest->host transport

RFC: https://patchwork.ozlabs.org/cover/1168442/
v1: https://patchwork.ozlabs.org/cover/1181986/

v1 -> v2:
- Patch 11:
    + vmci_transport: sent reset when vsock_assign_transport() fails
      [Jorgen]
    + fixed loopback in the guests, checking if the remote_addr is the
      same of transport_g2h->get_local_cid()
    + virtio_transport_common: updated space available while creating
      the new child socket during a connection request
- Patch 12:
    + removed 'features' variable in vmci_transport_init() [Stefan]
    + added a flag to register only once the host [Jorgen]
- Added patch 15 to refuse CID assigned to the guest->host transport in
  the vhost_transport

This series adds the multi-transports support to vsock, following
this proposal: https://www.spinics.net/lists/netdev/msg575792.html

With the multi-transports support, we can use VSOCK with nested VMs
(using also different hypervisors) loading both guest->host and
host->guest transports at the same time.
Before this series, vmci_transport supported this behavior but only
using VMware hypervisor on L0, L1, etc.

The first 9 patches are cleanups and preparations, maybe some of
these can go regardless of this series.

Patch 10 changes the hvs_remote_addr_init(). setting the
VMADDR_CID_HOST as remote CID instead of VMADDR_CID_ANY to make
the choice of transport to be used work properly.

Patch 11 adds multi-transports support.

Patch 12 changes a little bit the vmci_transport and the vmci driver
to register the vmci_transport only when there are active host/guest.

Patch 13 prevents the transport modules unloading while sockets are
assigned to them.

Patch 14 fixes an issue in the bind() logic discoverable only with
the new multi-transport support.

Patch 15 refuses CID assigned to the guest->host transport in the
vhost_transport.

I've tested this series with nested KVM (vsock-transport [L0,L1],
virtio-transport[L1,L2]) and with VMware (L0) + KVM (L1)
(vmci-transport [L0,L1], vhost-transport [L1], virtio-transport[L2]).

Dexuan successfully tested the RFC series on HyperV with a Linux guest.
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 798a496b ed8640a9
...@@ -28,6 +28,10 @@ MODULE_PARM_DESC(disable_guest, ...@@ -28,6 +28,10 @@ MODULE_PARM_DESC(disable_guest,
static bool vmci_guest_personality_initialized; static bool vmci_guest_personality_initialized;
static bool vmci_host_personality_initialized; static bool vmci_host_personality_initialized;
static DEFINE_MUTEX(vmci_vsock_mutex); /* protects vmci_vsock_transport_cb */
static vmci_vsock_cb vmci_vsock_transport_cb;
bool vmci_vsock_cb_host_called;
/* /*
* vmci_get_context_id() - Gets the current context ID. * vmci_get_context_id() - Gets the current context ID.
* *
...@@ -45,6 +49,69 @@ u32 vmci_get_context_id(void) ...@@ -45,6 +49,69 @@ u32 vmci_get_context_id(void)
} }
EXPORT_SYMBOL_GPL(vmci_get_context_id); EXPORT_SYMBOL_GPL(vmci_get_context_id);
/*
* vmci_register_vsock_callback() - Register the VSOCK vmci_transport callback.
*
* The callback will be called when the first host or guest becomes active,
* or if they are already active when this function is called.
* To unregister the callback, call this function with NULL parameter.
*
* Returns 0 on success. -EBUSY if a callback is already registered.
*/
int vmci_register_vsock_callback(vmci_vsock_cb callback)
{
int err = 0;
mutex_lock(&vmci_vsock_mutex);
if (vmci_vsock_transport_cb && callback) {
err = -EBUSY;
goto out;
}
vmci_vsock_transport_cb = callback;
if (!vmci_vsock_transport_cb) {
vmci_vsock_cb_host_called = false;
goto out;
}
if (vmci_guest_code_active())
vmci_vsock_transport_cb(false);
if (vmci_host_users() > 0) {
vmci_vsock_cb_host_called = true;
vmci_vsock_transport_cb(true);
}
out:
mutex_unlock(&vmci_vsock_mutex);
return err;
}
EXPORT_SYMBOL_GPL(vmci_register_vsock_callback);
void vmci_call_vsock_callback(bool is_host)
{
mutex_lock(&vmci_vsock_mutex);
if (!vmci_vsock_transport_cb)
goto out;
/* In the host, this function could be called multiple times,
* but we want to register it only once.
*/
if (is_host) {
if (vmci_vsock_cb_host_called)
goto out;
vmci_vsock_cb_host_called = true;
}
vmci_vsock_transport_cb(is_host);
out:
mutex_unlock(&vmci_vsock_mutex);
}
static int __init vmci_drv_init(void) static int __init vmci_drv_init(void)
{ {
int vmci_err; int vmci_err;
......
...@@ -36,10 +36,12 @@ extern struct pci_dev *vmci_pdev; ...@@ -36,10 +36,12 @@ extern struct pci_dev *vmci_pdev;
u32 vmci_get_context_id(void); u32 vmci_get_context_id(void);
int vmci_send_datagram(struct vmci_datagram *dg); int vmci_send_datagram(struct vmci_datagram *dg);
void vmci_call_vsock_callback(bool is_host);
int vmci_host_init(void); int vmci_host_init(void);
void vmci_host_exit(void); void vmci_host_exit(void);
bool vmci_host_code_active(void); bool vmci_host_code_active(void);
int vmci_host_users(void);
int vmci_guest_init(void); int vmci_guest_init(void);
void vmci_guest_exit(void); void vmci_guest_exit(void);
......
...@@ -637,6 +637,8 @@ static int vmci_guest_probe_device(struct pci_dev *pdev, ...@@ -637,6 +637,8 @@ static int vmci_guest_probe_device(struct pci_dev *pdev,
vmci_dev->iobase + VMCI_CONTROL_ADDR); vmci_dev->iobase + VMCI_CONTROL_ADDR);
pci_set_drvdata(pdev, vmci_dev); pci_set_drvdata(pdev, vmci_dev);
vmci_call_vsock_callback(false);
return 0; return 0;
err_free_irq: err_free_irq:
......
...@@ -108,6 +108,11 @@ bool vmci_host_code_active(void) ...@@ -108,6 +108,11 @@ bool vmci_host_code_active(void)
atomic_read(&vmci_host_active_users) > 0); atomic_read(&vmci_host_active_users) > 0);
} }
int vmci_host_users(void)
{
return atomic_read(&vmci_host_active_users);
}
/* /*
* Called on open of /dev/vmci. * Called on open of /dev/vmci.
*/ */
...@@ -338,6 +343,8 @@ static int vmci_host_do_init_context(struct vmci_host_dev *vmci_host_dev, ...@@ -338,6 +343,8 @@ static int vmci_host_do_init_context(struct vmci_host_dev *vmci_host_dev,
vmci_host_dev->ct_type = VMCIOBJ_CONTEXT; vmci_host_dev->ct_type = VMCIOBJ_CONTEXT;
atomic_inc(&vmci_host_active_users); atomic_inc(&vmci_host_active_users);
vmci_call_vsock_callback(true);
retval = 0; retval = 0;
out: out:
......
...@@ -384,6 +384,49 @@ static bool vhost_vsock_more_replies(struct vhost_vsock *vsock) ...@@ -384,6 +384,49 @@ static bool vhost_vsock_more_replies(struct vhost_vsock *vsock)
return val < vq->num; return val < vq->num;
} }
static struct virtio_transport vhost_transport = {
.transport = {
.module = THIS_MODULE,
.get_local_cid = vhost_transport_get_local_cid,
.init = virtio_transport_do_socket_init,
.destruct = virtio_transport_destruct,
.release = virtio_transport_release,
.connect = virtio_transport_connect,
.shutdown = virtio_transport_shutdown,
.cancel_pkt = vhost_transport_cancel_pkt,
.dgram_enqueue = virtio_transport_dgram_enqueue,
.dgram_dequeue = virtio_transport_dgram_dequeue,
.dgram_bind = virtio_transport_dgram_bind,
.dgram_allow = virtio_transport_dgram_allow,
.stream_enqueue = virtio_transport_stream_enqueue,
.stream_dequeue = virtio_transport_stream_dequeue,
.stream_has_data = virtio_transport_stream_has_data,
.stream_has_space = virtio_transport_stream_has_space,
.stream_rcvhiwat = virtio_transport_stream_rcvhiwat,
.stream_is_active = virtio_transport_stream_is_active,
.stream_allow = virtio_transport_stream_allow,
.notify_poll_in = virtio_transport_notify_poll_in,
.notify_poll_out = virtio_transport_notify_poll_out,
.notify_recv_init = virtio_transport_notify_recv_init,
.notify_recv_pre_block = virtio_transport_notify_recv_pre_block,
.notify_recv_pre_dequeue = virtio_transport_notify_recv_pre_dequeue,
.notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
.notify_send_init = virtio_transport_notify_send_init,
.notify_send_pre_block = virtio_transport_notify_send_pre_block,
.notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue,
.notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,
.notify_buffer_size = virtio_transport_notify_buffer_size,
},
.send_pkt = vhost_transport_send_pkt,
};
static void vhost_vsock_handle_tx_kick(struct vhost_work *work) static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
{ {
struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue, struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
...@@ -438,7 +481,7 @@ static void vhost_vsock_handle_tx_kick(struct vhost_work *work) ...@@ -438,7 +481,7 @@ static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
/* Only accept correctly addressed packets */ /* Only accept correctly addressed packets */
if (le64_to_cpu(pkt->hdr.src_cid) == vsock->guest_cid) if (le64_to_cpu(pkt->hdr.src_cid) == vsock->guest_cid)
virtio_transport_recv_pkt(pkt); virtio_transport_recv_pkt(&vhost_transport, pkt);
else else
virtio_transport_free_pkt(pkt); virtio_transport_free_pkt(pkt);
...@@ -675,6 +718,12 @@ static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid) ...@@ -675,6 +718,12 @@ static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid)
if (guest_cid > U32_MAX) if (guest_cid > U32_MAX)
return -EINVAL; return -EINVAL;
/* Refuse if CID is assigned to the guest->host transport (i.e. nested
* VM), to make the loopback work.
*/
if (vsock_find_cid(guest_cid))
return -EADDRINUSE;
/* Refuse if CID is already in use */ /* Refuse if CID is already in use */
mutex_lock(&vhost_vsock_mutex); mutex_lock(&vhost_vsock_mutex);
other = vhost_vsock_get(guest_cid); other = vhost_vsock_get(guest_cid);
...@@ -786,57 +835,12 @@ static struct miscdevice vhost_vsock_misc = { ...@@ -786,57 +835,12 @@ static struct miscdevice vhost_vsock_misc = {
.fops = &vhost_vsock_fops, .fops = &vhost_vsock_fops,
}; };
static struct virtio_transport vhost_transport = {
.transport = {
.get_local_cid = vhost_transport_get_local_cid,
.init = virtio_transport_do_socket_init,
.destruct = virtio_transport_destruct,
.release = virtio_transport_release,
.connect = virtio_transport_connect,
.shutdown = virtio_transport_shutdown,
.cancel_pkt = vhost_transport_cancel_pkt,
.dgram_enqueue = virtio_transport_dgram_enqueue,
.dgram_dequeue = virtio_transport_dgram_dequeue,
.dgram_bind = virtio_transport_dgram_bind,
.dgram_allow = virtio_transport_dgram_allow,
.stream_enqueue = virtio_transport_stream_enqueue,
.stream_dequeue = virtio_transport_stream_dequeue,
.stream_has_data = virtio_transport_stream_has_data,
.stream_has_space = virtio_transport_stream_has_space,
.stream_rcvhiwat = virtio_transport_stream_rcvhiwat,
.stream_is_active = virtio_transport_stream_is_active,
.stream_allow = virtio_transport_stream_allow,
.notify_poll_in = virtio_transport_notify_poll_in,
.notify_poll_out = virtio_transport_notify_poll_out,
.notify_recv_init = virtio_transport_notify_recv_init,
.notify_recv_pre_block = virtio_transport_notify_recv_pre_block,
.notify_recv_pre_dequeue = virtio_transport_notify_recv_pre_dequeue,
.notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
.notify_send_init = virtio_transport_notify_send_init,
.notify_send_pre_block = virtio_transport_notify_send_pre_block,
.notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue,
.notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,
.set_buffer_size = virtio_transport_set_buffer_size,
.set_min_buffer_size = virtio_transport_set_min_buffer_size,
.set_max_buffer_size = virtio_transport_set_max_buffer_size,
.get_buffer_size = virtio_transport_get_buffer_size,
.get_min_buffer_size = virtio_transport_get_min_buffer_size,
.get_max_buffer_size = virtio_transport_get_max_buffer_size,
},
.send_pkt = vhost_transport_send_pkt,
};
static int __init vhost_vsock_init(void) static int __init vhost_vsock_init(void)
{ {
int ret; int ret;
ret = vsock_core_init(&vhost_transport.transport); ret = vsock_core_register(&vhost_transport.transport,
VSOCK_TRANSPORT_F_H2G);
if (ret < 0) if (ret < 0)
return ret; return ret;
return misc_register(&vhost_vsock_misc); return misc_register(&vhost_vsock_misc);
...@@ -845,7 +849,7 @@ static int __init vhost_vsock_init(void) ...@@ -845,7 +849,7 @@ static int __init vhost_vsock_init(void)
static void __exit vhost_vsock_exit(void) static void __exit vhost_vsock_exit(void)
{ {
misc_deregister(&vhost_vsock_misc); misc_deregister(&vhost_vsock_misc);
vsock_core_exit(); vsock_core_unregister(&vhost_transport.transport);
}; };
module_init(vhost_vsock_init); module_init(vhost_vsock_init);
......
...@@ -7,9 +7,6 @@ ...@@ -7,9 +7,6 @@
#include <net/sock.h> #include <net/sock.h>
#include <net/af_vsock.h> #include <net/af_vsock.h>
#define VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE 128
#define VIRTIO_VSOCK_DEFAULT_BUF_SIZE (1024 * 256)
#define VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE (1024 * 256)
#define VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE (1024 * 4) #define VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE (1024 * 4)
#define VIRTIO_VSOCK_MAX_BUF_SIZE 0xFFFFFFFFUL #define VIRTIO_VSOCK_MAX_BUF_SIZE 0xFFFFFFFFUL
#define VIRTIO_VSOCK_MAX_PKT_BUF_SIZE (1024 * 64) #define VIRTIO_VSOCK_MAX_PKT_BUF_SIZE (1024 * 64)
...@@ -25,11 +22,6 @@ enum { ...@@ -25,11 +22,6 @@ enum {
struct virtio_vsock_sock { struct virtio_vsock_sock {
struct vsock_sock *vsk; struct vsock_sock *vsk;
/* Protected by lock_sock(sk_vsock(trans->vsk)) */
u32 buf_size;
u32 buf_size_min;
u32 buf_size_max;
spinlock_t tx_lock; spinlock_t tx_lock;
spinlock_t rx_lock; spinlock_t rx_lock;
...@@ -92,12 +84,6 @@ s64 virtio_transport_stream_has_space(struct vsock_sock *vsk); ...@@ -92,12 +84,6 @@ s64 virtio_transport_stream_has_space(struct vsock_sock *vsk);
int virtio_transport_do_socket_init(struct vsock_sock *vsk, int virtio_transport_do_socket_init(struct vsock_sock *vsk,
struct vsock_sock *psk); struct vsock_sock *psk);
u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk);
u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk);
u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk);
void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val);
void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val);
void virtio_transport_set_max_buffer_size(struct vsock_sock *vs, u64 val);
int int
virtio_transport_notify_poll_in(struct vsock_sock *vsk, virtio_transport_notify_poll_in(struct vsock_sock *vsk,
size_t target, size_t target,
...@@ -124,6 +110,7 @@ int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk, ...@@ -124,6 +110,7 @@ int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
struct vsock_transport_send_notify_data *data); struct vsock_transport_send_notify_data *data);
int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk, int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
ssize_t written, struct vsock_transport_send_notify_data *data); ssize_t written, struct vsock_transport_send_notify_data *data);
void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val);
u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk); u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk);
bool virtio_transport_stream_is_active(struct vsock_sock *vsk); bool virtio_transport_stream_is_active(struct vsock_sock *vsk);
...@@ -150,7 +137,8 @@ virtio_transport_dgram_enqueue(struct vsock_sock *vsk, ...@@ -150,7 +137,8 @@ virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
void virtio_transport_destruct(struct vsock_sock *vsk); void virtio_transport_destruct(struct vsock_sock *vsk);
void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt); void virtio_transport_recv_pkt(struct virtio_transport *t,
struct virtio_vsock_pkt *pkt);
void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt); void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt);
void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt); void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt);
u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 wanted); u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 wanted);
......
/* SPDX-License-Identifier: GPL-2.0-only */
/*
* VMware vSockets Driver
*
* Copyright (C) 2007-2013 VMware, Inc. All rights reserved.
*/
#ifndef _VM_SOCKETS_H
#define _VM_SOCKETS_H
#include <uapi/linux/vm_sockets.h>
int vm_sockets_get_local_cid(void);
#endif /* _VM_SOCKETS_H */
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
struct msghdr; struct msghdr;
typedef void (vmci_device_shutdown_fn) (void *device_registration, typedef void (vmci_device_shutdown_fn) (void *device_registration,
void *user_data); void *user_data);
typedef void (*vmci_vsock_cb) (bool is_host);
int vmci_datagram_create_handle(u32 resource_id, u32 flags, int vmci_datagram_create_handle(u32 resource_id, u32 flags,
vmci_datagram_recv_cb recv_cb, vmci_datagram_recv_cb recv_cb,
...@@ -37,6 +38,7 @@ int vmci_doorbell_destroy(struct vmci_handle handle); ...@@ -37,6 +38,7 @@ int vmci_doorbell_destroy(struct vmci_handle handle);
int vmci_doorbell_notify(struct vmci_handle handle, u32 priv_flags); int vmci_doorbell_notify(struct vmci_handle handle, u32 priv_flags);
u32 vmci_get_context_id(void); u32 vmci_get_context_id(void);
bool vmci_is_context_owner(u32 context_id, kuid_t uid); bool vmci_is_context_owner(u32 context_id, kuid_t uid);
int vmci_register_vsock_callback(vmci_vsock_cb callback);
int vmci_event_subscribe(u32 event, int vmci_event_subscribe(u32 event,
vmci_event_cb callback, void *callback_data, vmci_event_cb callback, void *callback_data,
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#include <linux/kernel.h> #include <linux/kernel.h>
#include <linux/workqueue.h> #include <linux/workqueue.h>
#include <linux/vm_sockets.h> #include <uapi/linux/vm_sockets.h>
#include "vsock_addr.h" #include "vsock_addr.h"
...@@ -27,6 +27,7 @@ extern spinlock_t vsock_table_lock; ...@@ -27,6 +27,7 @@ extern spinlock_t vsock_table_lock;
struct vsock_sock { struct vsock_sock {
/* sk must be the first member. */ /* sk must be the first member. */
struct sock sk; struct sock sk;
const struct vsock_transport *transport;
struct sockaddr_vm local_addr; struct sockaddr_vm local_addr;
struct sockaddr_vm remote_addr; struct sockaddr_vm remote_addr;
/* Links for the global tables of bound and connected sockets. */ /* Links for the global tables of bound and connected sockets. */
...@@ -64,16 +65,18 @@ struct vsock_sock { ...@@ -64,16 +65,18 @@ struct vsock_sock {
bool sent_request; bool sent_request;
bool ignore_connecting_rst; bool ignore_connecting_rst;
/* Protected by lock_sock(sk) */
u64 buffer_size;
u64 buffer_min_size;
u64 buffer_max_size;
/* Private to transport. */ /* Private to transport. */
void *trans; void *trans;
}; };
s64 vsock_stream_has_data(struct vsock_sock *vsk); s64 vsock_stream_has_data(struct vsock_sock *vsk);
s64 vsock_stream_has_space(struct vsock_sock *vsk); s64 vsock_stream_has_space(struct vsock_sock *vsk);
struct sock *__vsock_create(struct net *net, struct sock *vsock_create_connected(struct sock *parent);
struct socket *sock,
struct sock *parent,
gfp_t priority, unsigned short type, int kern);
/**** TRANSPORT ****/ /**** TRANSPORT ****/
...@@ -88,7 +91,17 @@ struct vsock_transport_send_notify_data { ...@@ -88,7 +91,17 @@ struct vsock_transport_send_notify_data {
u64 data2; /* Transport-defined. */ u64 data2; /* Transport-defined. */
}; };
/* Transport features flags */
/* Transport provides host->guest communication */
#define VSOCK_TRANSPORT_F_H2G 0x00000001
/* Transport provides guest->host communication */
#define VSOCK_TRANSPORT_F_G2H 0x00000002
/* Transport provides DGRAM communication */
#define VSOCK_TRANSPORT_F_DGRAM 0x00000004
struct vsock_transport { struct vsock_transport {
struct module *module;
/* Initialize/tear-down socket. */ /* Initialize/tear-down socket. */
int (*init)(struct vsock_sock *, struct vsock_sock *); int (*init)(struct vsock_sock *, struct vsock_sock *);
void (*destruct)(struct vsock_sock *); void (*destruct)(struct vsock_sock *);
...@@ -139,33 +152,23 @@ struct vsock_transport { ...@@ -139,33 +152,23 @@ struct vsock_transport {
struct vsock_transport_send_notify_data *); struct vsock_transport_send_notify_data *);
int (*notify_send_post_enqueue)(struct vsock_sock *, ssize_t, int (*notify_send_post_enqueue)(struct vsock_sock *, ssize_t,
struct vsock_transport_send_notify_data *); struct vsock_transport_send_notify_data *);
/* sk_lock held by the caller */
void (*notify_buffer_size)(struct vsock_sock *, u64 *);
/* Shutdown. */ /* Shutdown. */
int (*shutdown)(struct vsock_sock *, int); int (*shutdown)(struct vsock_sock *, int);
/* Buffer sizes. */
void (*set_buffer_size)(struct vsock_sock *, u64);
void (*set_min_buffer_size)(struct vsock_sock *, u64);
void (*set_max_buffer_size)(struct vsock_sock *, u64);
u64 (*get_buffer_size)(struct vsock_sock *);
u64 (*get_min_buffer_size)(struct vsock_sock *);
u64 (*get_max_buffer_size)(struct vsock_sock *);
/* Addressing. */ /* Addressing. */
u32 (*get_local_cid)(void); u32 (*get_local_cid)(void);
}; };
/**** CORE ****/ /**** CORE ****/
int __vsock_core_init(const struct vsock_transport *t, struct module *owner); int vsock_core_register(const struct vsock_transport *t, int features);
static inline int vsock_core_init(const struct vsock_transport *t) void vsock_core_unregister(const struct vsock_transport *t);
{
return __vsock_core_init(t, THIS_MODULE);
}
void vsock_core_exit(void);
/* The transport may downcast this to access transport-specific functions */ /* The transport may downcast this to access transport-specific functions */
const struct vsock_transport *vsock_core_get_transport(void); const struct vsock_transport *vsock_core_get_transport(struct vsock_sock *vsk);
/**** UTILS ****/ /**** UTILS ****/
...@@ -193,6 +196,8 @@ struct sock *vsock_find_connected_socket(struct sockaddr_vm *src, ...@@ -193,6 +196,8 @@ struct sock *vsock_find_connected_socket(struct sockaddr_vm *src,
struct sockaddr_vm *dst); struct sockaddr_vm *dst);
void vsock_remove_sock(struct vsock_sock *vsk); void vsock_remove_sock(struct vsock_sock *vsk);
void vsock_for_each_connected_socket(void (*fn)(struct sock *sk)); void vsock_for_each_connected_socket(void (*fn)(struct sock *sk));
int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk);
bool vsock_find_cid(unsigned int cid);
/**** TAP ****/ /**** TAP ****/
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#ifndef _VSOCK_ADDR_H_ #ifndef _VSOCK_ADDR_H_
#define _VSOCK_ADDR_H_ #define _VSOCK_ADDR_H_
#include <linux/vm_sockets.h> #include <uapi/linux/vm_sockets.h>
void vsock_addr_init(struct sockaddr_vm *addr, u32 cid, u32 port); void vsock_addr_init(struct sockaddr_vm *addr, u32 cid, u32 port);
int vsock_addr_validate(const struct sockaddr_vm *addr); int vsock_addr_validate(const struct sockaddr_vm *addr);
......
This diff is collapsed.
...@@ -165,6 +165,8 @@ static const guid_t srv_id_template = ...@@ -165,6 +165,8 @@ static const guid_t srv_id_template =
GUID_INIT(0x00000000, 0xfacb, 0x11e6, 0xbd, 0x58, GUID_INIT(0x00000000, 0xfacb, 0x11e6, 0xbd, 0x58,
0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3); 0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3);
static bool hvs_check_transport(struct vsock_sock *vsk);
static bool is_valid_srv_id(const guid_t *id) static bool is_valid_srv_id(const guid_t *id)
{ {
return !memcmp(&id->b[4], &srv_id_template.b[4], sizeof(guid_t) - 4); return !memcmp(&id->b[4], &srv_id_template.b[4], sizeof(guid_t) - 4);
...@@ -188,7 +190,8 @@ static void hvs_remote_addr_init(struct sockaddr_vm *remote, ...@@ -188,7 +190,8 @@ static void hvs_remote_addr_init(struct sockaddr_vm *remote,
static u32 host_ephemeral_port = MIN_HOST_EPHEMERAL_PORT; static u32 host_ephemeral_port = MIN_HOST_EPHEMERAL_PORT;
struct sock *sk; struct sock *sk;
vsock_addr_init(remote, VMADDR_CID_ANY, VMADDR_PORT_ANY); /* Remote peer is always the host */
vsock_addr_init(remote, VMADDR_CID_HOST, VMADDR_PORT_ANY);
while (1) { while (1) {
/* Wrap around ? */ /* Wrap around ? */
...@@ -360,13 +363,24 @@ static void hvs_open_connection(struct vmbus_channel *chan) ...@@ -360,13 +363,24 @@ static void hvs_open_connection(struct vmbus_channel *chan)
if (sk->sk_ack_backlog >= sk->sk_max_ack_backlog) if (sk->sk_ack_backlog >= sk->sk_max_ack_backlog)
goto out; goto out;
new = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL, new = vsock_create_connected(sk);
sk->sk_type, 0);
if (!new) if (!new)
goto out; goto out;
new->sk_state = TCP_SYN_SENT; new->sk_state = TCP_SYN_SENT;
vnew = vsock_sk(new); vnew = vsock_sk(new);
hvs_addr_init(&vnew->local_addr, if_type);
hvs_remote_addr_init(&vnew->remote_addr, &vnew->local_addr);
ret = vsock_assign_transport(vnew, vsock_sk(sk));
/* Transport assigned (looking at remote_addr) must be the
* same where we received the request.
*/
if (ret || !hvs_check_transport(vnew)) {
sock_put(new);
goto out;
}
hvs_new = vnew->trans; hvs_new = vnew->trans;
hvs_new->chan = chan; hvs_new->chan = chan;
} else { } else {
...@@ -430,9 +444,6 @@ static void hvs_open_connection(struct vmbus_channel *chan) ...@@ -430,9 +444,6 @@ static void hvs_open_connection(struct vmbus_channel *chan)
new->sk_state = TCP_ESTABLISHED; new->sk_state = TCP_ESTABLISHED;
sk_acceptq_added(sk); sk_acceptq_added(sk);
hvs_addr_init(&vnew->local_addr, if_type);
hvs_remote_addr_init(&vnew->remote_addr, &vnew->local_addr);
hvs_new->vm_srv_id = *if_type; hvs_new->vm_srv_id = *if_type;
hvs_new->host_srv_id = *if_instance; hvs_new->host_srv_id = *if_instance;
...@@ -845,37 +856,9 @@ int hvs_notify_send_post_enqueue(struct vsock_sock *vsk, ssize_t written, ...@@ -845,37 +856,9 @@ int hvs_notify_send_post_enqueue(struct vsock_sock *vsk, ssize_t written,
return 0; return 0;
} }
static void hvs_set_buffer_size(struct vsock_sock *vsk, u64 val)
{
/* Ignored. */
}
static void hvs_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
{
/* Ignored. */
}
static void hvs_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
{
/* Ignored. */
}
static u64 hvs_get_buffer_size(struct vsock_sock *vsk)
{
return -ENOPROTOOPT;
}
static u64 hvs_get_min_buffer_size(struct vsock_sock *vsk)
{
return -ENOPROTOOPT;
}
static u64 hvs_get_max_buffer_size(struct vsock_sock *vsk)
{
return -ENOPROTOOPT;
}
static struct vsock_transport hvs_transport = { static struct vsock_transport hvs_transport = {
.module = THIS_MODULE,
.get_local_cid = hvs_get_local_cid, .get_local_cid = hvs_get_local_cid,
.init = hvs_sock_init, .init = hvs_sock_init,
...@@ -908,14 +891,13 @@ static struct vsock_transport hvs_transport = { ...@@ -908,14 +891,13 @@ static struct vsock_transport hvs_transport = {
.notify_send_pre_enqueue = hvs_notify_send_pre_enqueue, .notify_send_pre_enqueue = hvs_notify_send_pre_enqueue,
.notify_send_post_enqueue = hvs_notify_send_post_enqueue, .notify_send_post_enqueue = hvs_notify_send_post_enqueue,
.set_buffer_size = hvs_set_buffer_size,
.set_min_buffer_size = hvs_set_min_buffer_size,
.set_max_buffer_size = hvs_set_max_buffer_size,
.get_buffer_size = hvs_get_buffer_size,
.get_min_buffer_size = hvs_get_min_buffer_size,
.get_max_buffer_size = hvs_get_max_buffer_size,
}; };
static bool hvs_check_transport(struct vsock_sock *vsk)
{
return vsk->transport == &hvs_transport;
}
static int hvs_probe(struct hv_device *hdev, static int hvs_probe(struct hv_device *hdev,
const struct hv_vmbus_device_id *dev_id) const struct hv_vmbus_device_id *dev_id)
{ {
...@@ -964,7 +946,7 @@ static int __init hvs_init(void) ...@@ -964,7 +946,7 @@ static int __init hvs_init(void)
if (ret != 0) if (ret != 0)
return ret; return ret;
ret = vsock_core_init(&hvs_transport); ret = vsock_core_register(&hvs_transport, VSOCK_TRANSPORT_F_G2H);
if (ret) { if (ret) {
vmbus_driver_unregister(&hvs_drv); vmbus_driver_unregister(&hvs_drv);
return ret; return ret;
...@@ -975,7 +957,7 @@ static int __init hvs_init(void) ...@@ -975,7 +957,7 @@ static int __init hvs_init(void)
static void __exit hvs_exit(void) static void __exit hvs_exit(void)
{ {
vsock_core_exit(); vsock_core_unregister(&hvs_transport);
vmbus_driver_unregister(&hvs_drv); vmbus_driver_unregister(&hvs_drv);
} }
......
...@@ -86,33 +86,6 @@ static u32 virtio_transport_get_local_cid(void) ...@@ -86,33 +86,6 @@ static u32 virtio_transport_get_local_cid(void)
return ret; return ret;
} }
static void virtio_transport_loopback_work(struct work_struct *work)
{
struct virtio_vsock *vsock =
container_of(work, struct virtio_vsock, loopback_work);
LIST_HEAD(pkts);
spin_lock_bh(&vsock->loopback_list_lock);
list_splice_init(&vsock->loopback_list, &pkts);
spin_unlock_bh(&vsock->loopback_list_lock);
mutex_lock(&vsock->rx_lock);
if (!vsock->rx_run)
goto out;
while (!list_empty(&pkts)) {
struct virtio_vsock_pkt *pkt;
pkt = list_first_entry(&pkts, struct virtio_vsock_pkt, list);
list_del_init(&pkt->list);
virtio_transport_recv_pkt(pkt);
}
out:
mutex_unlock(&vsock->rx_lock);
}
static int virtio_transport_send_pkt_loopback(struct virtio_vsock *vsock, static int virtio_transport_send_pkt_loopback(struct virtio_vsock *vsock,
struct virtio_vsock_pkt *pkt) struct virtio_vsock_pkt *pkt)
{ {
...@@ -370,59 +343,6 @@ static bool virtio_transport_more_replies(struct virtio_vsock *vsock) ...@@ -370,59 +343,6 @@ static bool virtio_transport_more_replies(struct virtio_vsock *vsock)
return val < virtqueue_get_vring_size(vq); return val < virtqueue_get_vring_size(vq);
} }
static void virtio_transport_rx_work(struct work_struct *work)
{
struct virtio_vsock *vsock =
container_of(work, struct virtio_vsock, rx_work);
struct virtqueue *vq;
vq = vsock->vqs[VSOCK_VQ_RX];
mutex_lock(&vsock->rx_lock);
if (!vsock->rx_run)
goto out;
do {
virtqueue_disable_cb(vq);
for (;;) {
struct virtio_vsock_pkt *pkt;
unsigned int len;
if (!virtio_transport_more_replies(vsock)) {
/* Stop rx until the device processes already
* pending replies. Leave rx virtqueue
* callbacks disabled.
*/
goto out;
}
pkt = virtqueue_get_buf(vq, &len);
if (!pkt) {
break;
}
vsock->rx_buf_nr--;
/* Drop short/long packets */
if (unlikely(len < sizeof(pkt->hdr) ||
len > sizeof(pkt->hdr) + pkt->len)) {
virtio_transport_free_pkt(pkt);
continue;
}
pkt->len = len - sizeof(pkt->hdr);
virtio_transport_deliver_tap_pkt(pkt);
virtio_transport_recv_pkt(pkt);
}
} while (!virtqueue_enable_cb(vq));
out:
if (vsock->rx_buf_nr < vsock->rx_buf_max_nr / 2)
virtio_vsock_rx_fill(vsock);
mutex_unlock(&vsock->rx_lock);
}
/* event_lock must be held */ /* event_lock must be held */
static int virtio_vsock_event_fill_one(struct virtio_vsock *vsock, static int virtio_vsock_event_fill_one(struct virtio_vsock *vsock,
struct virtio_vsock_event *event) struct virtio_vsock_event *event)
...@@ -542,6 +462,8 @@ static void virtio_vsock_rx_done(struct virtqueue *vq) ...@@ -542,6 +462,8 @@ static void virtio_vsock_rx_done(struct virtqueue *vq)
static struct virtio_transport virtio_transport = { static struct virtio_transport virtio_transport = {
.transport = { .transport = {
.module = THIS_MODULE,
.get_local_cid = virtio_transport_get_local_cid, .get_local_cid = virtio_transport_get_local_cid,
.init = virtio_transport_do_socket_init, .init = virtio_transport_do_socket_init,
...@@ -574,18 +496,92 @@ static struct virtio_transport virtio_transport = { ...@@ -574,18 +496,92 @@ static struct virtio_transport virtio_transport = {
.notify_send_pre_block = virtio_transport_notify_send_pre_block, .notify_send_pre_block = virtio_transport_notify_send_pre_block,
.notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue, .notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue,
.notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue, .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,
.notify_buffer_size = virtio_transport_notify_buffer_size,
.set_buffer_size = virtio_transport_set_buffer_size,
.set_min_buffer_size = virtio_transport_set_min_buffer_size,
.set_max_buffer_size = virtio_transport_set_max_buffer_size,
.get_buffer_size = virtio_transport_get_buffer_size,
.get_min_buffer_size = virtio_transport_get_min_buffer_size,
.get_max_buffer_size = virtio_transport_get_max_buffer_size,
}, },
.send_pkt = virtio_transport_send_pkt, .send_pkt = virtio_transport_send_pkt,
}; };
static void virtio_transport_loopback_work(struct work_struct *work)
{
struct virtio_vsock *vsock =
container_of(work, struct virtio_vsock, loopback_work);
LIST_HEAD(pkts);
spin_lock_bh(&vsock->loopback_list_lock);
list_splice_init(&vsock->loopback_list, &pkts);
spin_unlock_bh(&vsock->loopback_list_lock);
mutex_lock(&vsock->rx_lock);
if (!vsock->rx_run)
goto out;
while (!list_empty(&pkts)) {
struct virtio_vsock_pkt *pkt;
pkt = list_first_entry(&pkts, struct virtio_vsock_pkt, list);
list_del_init(&pkt->list);
virtio_transport_recv_pkt(&virtio_transport, pkt);
}
out:
mutex_unlock(&vsock->rx_lock);
}
static void virtio_transport_rx_work(struct work_struct *work)
{
struct virtio_vsock *vsock =
container_of(work, struct virtio_vsock, rx_work);
struct virtqueue *vq;
vq = vsock->vqs[VSOCK_VQ_RX];
mutex_lock(&vsock->rx_lock);
if (!vsock->rx_run)
goto out;
do {
virtqueue_disable_cb(vq);
for (;;) {
struct virtio_vsock_pkt *pkt;
unsigned int len;
if (!virtio_transport_more_replies(vsock)) {
/* Stop rx until the device processes already
* pending replies. Leave rx virtqueue
* callbacks disabled.
*/
goto out;
}
pkt = virtqueue_get_buf(vq, &len);
if (!pkt) {
break;
}
vsock->rx_buf_nr--;
/* Drop short/long packets */
if (unlikely(len < sizeof(pkt->hdr) ||
len > sizeof(pkt->hdr) + pkt->len)) {
virtio_transport_free_pkt(pkt);
continue;
}
pkt->len = len - sizeof(pkt->hdr);
virtio_transport_deliver_tap_pkt(pkt);
virtio_transport_recv_pkt(&virtio_transport, pkt);
}
} while (!virtqueue_enable_cb(vq));
out:
if (vsock->rx_buf_nr < vsock->rx_buf_max_nr / 2)
virtio_vsock_rx_fill(vsock);
mutex_unlock(&vsock->rx_lock);
}
static int virtio_vsock_probe(struct virtio_device *vdev) static int virtio_vsock_probe(struct virtio_device *vdev)
{ {
vq_callback_t *callbacks[] = { vq_callback_t *callbacks[] = {
...@@ -776,7 +772,8 @@ static int __init virtio_vsock_init(void) ...@@ -776,7 +772,8 @@ static int __init virtio_vsock_init(void)
if (!virtio_vsock_workqueue) if (!virtio_vsock_workqueue)
return -ENOMEM; return -ENOMEM;
ret = vsock_core_init(&virtio_transport.transport); ret = vsock_core_register(&virtio_transport.transport,
VSOCK_TRANSPORT_F_G2H);
if (ret) if (ret)
goto out_wq; goto out_wq;
...@@ -787,7 +784,7 @@ static int __init virtio_vsock_init(void) ...@@ -787,7 +784,7 @@ static int __init virtio_vsock_init(void)
return 0; return 0;
out_vci: out_vci:
vsock_core_exit(); vsock_core_unregister(&virtio_transport.transport);
out_wq: out_wq:
destroy_workqueue(virtio_vsock_workqueue); destroy_workqueue(virtio_vsock_workqueue);
return ret; return ret;
...@@ -796,7 +793,7 @@ static int __init virtio_vsock_init(void) ...@@ -796,7 +793,7 @@ static int __init virtio_vsock_init(void)
static void __exit virtio_vsock_exit(void) static void __exit virtio_vsock_exit(void)
{ {
unregister_virtio_driver(&virtio_vsock_driver); unregister_virtio_driver(&virtio_vsock_driver);
vsock_core_exit(); vsock_core_unregister(&virtio_transport.transport);
destroy_workqueue(virtio_vsock_workqueue); destroy_workqueue(virtio_vsock_workqueue);
} }
......
...@@ -29,9 +29,10 @@ ...@@ -29,9 +29,10 @@
/* Threshold for detecting small packets to copy */ /* Threshold for detecting small packets to copy */
#define GOOD_COPY_LEN 128 #define GOOD_COPY_LEN 128
static const struct virtio_transport *virtio_transport_get_ops(void) static const struct virtio_transport *
virtio_transport_get_ops(struct vsock_sock *vsk)
{ {
const struct vsock_transport *t = vsock_core_get_transport(); const struct vsock_transport *t = vsock_core_get_transport(vsk);
return container_of(t, struct virtio_transport, transport); return container_of(t, struct virtio_transport, transport);
} }
...@@ -168,7 +169,7 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, ...@@ -168,7 +169,7 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
struct virtio_vsock_pkt *pkt; struct virtio_vsock_pkt *pkt;
u32 pkt_len = info->pkt_len; u32 pkt_len = info->pkt_len;
src_cid = vm_sockets_get_local_cid(); src_cid = virtio_transport_get_ops(vsk)->transport.get_local_cid();
src_port = vsk->local_addr.svm_port; src_port = vsk->local_addr.svm_port;
if (!info->remote_cid) { if (!info->remote_cid) {
dst_cid = vsk->remote_addr.svm_cid; dst_cid = vsk->remote_addr.svm_cid;
...@@ -201,7 +202,7 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, ...@@ -201,7 +202,7 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
virtio_transport_inc_tx_pkt(vvs, pkt); virtio_transport_inc_tx_pkt(vvs, pkt);
return virtio_transport_get_ops()->send_pkt(pkt); return virtio_transport_get_ops(vsk)->send_pkt(pkt);
} }
static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs, static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
...@@ -452,20 +453,16 @@ int virtio_transport_do_socket_init(struct vsock_sock *vsk, ...@@ -452,20 +453,16 @@ int virtio_transport_do_socket_init(struct vsock_sock *vsk,
vsk->trans = vvs; vsk->trans = vvs;
vvs->vsk = vsk; vvs->vsk = vsk;
if (psk) { if (psk && psk->trans) {
struct virtio_vsock_sock *ptrans = psk->trans; struct virtio_vsock_sock *ptrans = psk->trans;
vvs->buf_size = ptrans->buf_size;
vvs->buf_size_min = ptrans->buf_size_min;
vvs->buf_size_max = ptrans->buf_size_max;
vvs->peer_buf_alloc = ptrans->peer_buf_alloc; vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
} else {
vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
} }
vvs->buf_alloc = vvs->buf_size; if (vsk->buffer_size > VIRTIO_VSOCK_MAX_BUF_SIZE)
vsk->buffer_size = VIRTIO_VSOCK_MAX_BUF_SIZE;
vvs->buf_alloc = vsk->buffer_size;
spin_lock_init(&vvs->rx_lock); spin_lock_init(&vvs->rx_lock);
spin_lock_init(&vvs->tx_lock); spin_lock_init(&vvs->tx_lock);
...@@ -475,71 +472,20 @@ int virtio_transport_do_socket_init(struct vsock_sock *vsk, ...@@ -475,71 +472,20 @@ int virtio_transport_do_socket_init(struct vsock_sock *vsk,
} }
EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init); EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk) /* sk_lock held by the caller */
{ void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val)
struct virtio_vsock_sock *vvs = vsk->trans;
return vvs->buf_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);
u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
{ {
struct virtio_vsock_sock *vvs = vsk->trans; struct virtio_vsock_sock *vvs = vsk->trans;
return vvs->buf_size_min; if (*val > VIRTIO_VSOCK_MAX_BUF_SIZE)
} *val = VIRTIO_VSOCK_MAX_BUF_SIZE;
EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);
u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk) vvs->buf_alloc = *val;
{
struct virtio_vsock_sock *vvs = vsk->trans;
return vvs->buf_size_max;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);
void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
{
struct virtio_vsock_sock *vvs = vsk->trans;
if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
val = VIRTIO_VSOCK_MAX_BUF_SIZE;
if (val < vvs->buf_size_min)
vvs->buf_size_min = val;
if (val > vvs->buf_size_max)
vvs->buf_size_max = val;
vvs->buf_size = val;
vvs->buf_alloc = val;
virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM, virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
NULL); NULL);
} }
EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size); EXPORT_SYMBOL_GPL(virtio_transport_notify_buffer_size);
void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
{
struct virtio_vsock_sock *vvs = vsk->trans;
if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
val = VIRTIO_VSOCK_MAX_BUF_SIZE;
if (val > vvs->buf_size)
vvs->buf_size = val;
vvs->buf_size_min = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);
void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
{
struct virtio_vsock_sock *vvs = vsk->trans;
if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
val = VIRTIO_VSOCK_MAX_BUF_SIZE;
if (val < vvs->buf_size)
vvs->buf_size = val;
vvs->buf_size_max = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);
int int
virtio_transport_notify_poll_in(struct vsock_sock *vsk, virtio_transport_notify_poll_in(struct vsock_sock *vsk,
...@@ -631,9 +577,7 @@ EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue); ...@@ -631,9 +577,7 @@ EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk) u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
{ {
struct virtio_vsock_sock *vvs = vsk->trans; return vsk->buffer_size;
return vvs->buf_size;
} }
EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat); EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
...@@ -745,9 +689,9 @@ static int virtio_transport_reset(struct vsock_sock *vsk, ...@@ -745,9 +689,9 @@ static int virtio_transport_reset(struct vsock_sock *vsk,
/* Normally packets are associated with a socket. There may be no socket if an /* Normally packets are associated with a socket. There may be no socket if an
* attempt was made to connect to a socket that does not exist. * attempt was made to connect to a socket that does not exist.
*/ */
static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt) static int virtio_transport_reset_no_sock(const struct virtio_transport *t,
struct virtio_vsock_pkt *pkt)
{ {
const struct virtio_transport *t;
struct virtio_vsock_pkt *reply; struct virtio_vsock_pkt *reply;
struct virtio_vsock_pkt_info info = { struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_RST, .op = VIRTIO_VSOCK_OP_RST,
...@@ -767,7 +711,6 @@ static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt) ...@@ -767,7 +711,6 @@ static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
if (!reply) if (!reply)
return -ENOMEM; return -ENOMEM;
t = virtio_transport_get_ops();
if (!t) { if (!t) {
virtio_transport_free_pkt(reply); virtio_transport_free_pkt(reply);
return -ENOTCONN; return -ENOTCONN;
...@@ -1043,13 +986,39 @@ virtio_transport_send_response(struct vsock_sock *vsk, ...@@ -1043,13 +986,39 @@ virtio_transport_send_response(struct vsock_sock *vsk,
return virtio_transport_send_pkt_info(vsk, &info); return virtio_transport_send_pkt_info(vsk, &info);
} }
static bool virtio_transport_space_update(struct sock *sk,
struct virtio_vsock_pkt *pkt)
{
struct vsock_sock *vsk = vsock_sk(sk);
struct virtio_vsock_sock *vvs = vsk->trans;
bool space_available;
/* Listener sockets are not associated with any transport, so we are
* not able to take the state to see if there is space available in the
* remote peer, but since they are only used to receive requests, we
* can assume that there is always space available in the other peer.
*/
if (!vvs)
return true;
/* buf_alloc and fwd_cnt is always included in the hdr */
spin_lock_bh(&vvs->tx_lock);
vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
space_available = virtio_transport_has_space(vsk);
spin_unlock_bh(&vvs->tx_lock);
return space_available;
}
/* Handle server socket */ /* Handle server socket */
static int static int
virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt) virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt,
struct virtio_transport *t)
{ {
struct vsock_sock *vsk = vsock_sk(sk); struct vsock_sock *vsk = vsock_sk(sk);
struct vsock_sock *vchild; struct vsock_sock *vchild;
struct sock *child; struct sock *child;
int ret;
if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) { if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
virtio_transport_reset(vsk, pkt); virtio_transport_reset(vsk, pkt);
...@@ -1061,8 +1030,7 @@ virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt) ...@@ -1061,8 +1030,7 @@ virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
return -ENOMEM; return -ENOMEM;
} }
child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL, child = vsock_create_connected(sk);
sk->sk_type, 0);
if (!child) { if (!child) {
virtio_transport_reset(vsk, pkt); virtio_transport_reset(vsk, pkt);
return -ENOMEM; return -ENOMEM;
...@@ -1080,6 +1048,20 @@ virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt) ...@@ -1080,6 +1048,20 @@ virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid), vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
le32_to_cpu(pkt->hdr.src_port)); le32_to_cpu(pkt->hdr.src_port));
ret = vsock_assign_transport(vchild, vsk);
/* Transport assigned (looking at remote_addr) must be the same
* where we received the request.
*/
if (ret || vchild->transport != &t->transport) {
release_sock(child);
virtio_transport_reset(vsk, pkt);
sock_put(child);
return ret;
}
if (virtio_transport_space_update(child, pkt))
child->sk_write_space(child);
vsock_insert_connected(vchild); vsock_insert_connected(vchild);
vsock_enqueue_accept(sk, child); vsock_enqueue_accept(sk, child);
virtio_transport_send_response(vchild, pkt); virtio_transport_send_response(vchild, pkt);
...@@ -1090,26 +1072,11 @@ virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt) ...@@ -1090,26 +1072,11 @@ virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
return 0; return 0;
} }
static bool virtio_transport_space_update(struct sock *sk,
struct virtio_vsock_pkt *pkt)
{
struct vsock_sock *vsk = vsock_sk(sk);
struct virtio_vsock_sock *vvs = vsk->trans;
bool space_available;
/* buf_alloc and fwd_cnt is always included in the hdr */
spin_lock_bh(&vvs->tx_lock);
vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
space_available = virtio_transport_has_space(vsk);
spin_unlock_bh(&vvs->tx_lock);
return space_available;
}
/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
* lock. * lock.
*/ */
void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt) void virtio_transport_recv_pkt(struct virtio_transport *t,
struct virtio_vsock_pkt *pkt)
{ {
struct sockaddr_vm src, dst; struct sockaddr_vm src, dst;
struct vsock_sock *vsk; struct vsock_sock *vsk;
...@@ -1131,7 +1098,7 @@ void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt) ...@@ -1131,7 +1098,7 @@ void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
le32_to_cpu(pkt->hdr.fwd_cnt)); le32_to_cpu(pkt->hdr.fwd_cnt));
if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) { if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
(void)virtio_transport_reset_no_sock(pkt); (void)virtio_transport_reset_no_sock(t, pkt);
goto free_pkt; goto free_pkt;
} }
...@@ -1142,7 +1109,7 @@ void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt) ...@@ -1142,7 +1109,7 @@ void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
if (!sk) { if (!sk) {
sk = vsock_find_bound_socket(&dst); sk = vsock_find_bound_socket(&dst);
if (!sk) { if (!sk) {
(void)virtio_transport_reset_no_sock(pkt); (void)virtio_transport_reset_no_sock(t, pkt);
goto free_pkt; goto free_pkt;
} }
} }
...@@ -1161,7 +1128,7 @@ void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt) ...@@ -1161,7 +1128,7 @@ void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
switch (sk->sk_state) { switch (sk->sk_state) {
case TCP_LISTEN: case TCP_LISTEN:
virtio_transport_recv_listen(sk, pkt); virtio_transport_recv_listen(sk, pkt, t);
virtio_transport_free_pkt(pkt); virtio_transport_free_pkt(pkt);
break; break;
case TCP_SYN_SENT: case TCP_SYN_SENT:
...@@ -1179,6 +1146,7 @@ void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt) ...@@ -1179,6 +1146,7 @@ void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
virtio_transport_free_pkt(pkt); virtio_transport_free_pkt(pkt);
break; break;
} }
release_sock(sk); release_sock(sk);
/* Release refcnt obtained when we fetched this socket out of the /* Release refcnt obtained when we fetched this socket out of the
......
...@@ -57,6 +57,7 @@ static bool vmci_transport_old_proto_override(bool *old_pkt_proto); ...@@ -57,6 +57,7 @@ static bool vmci_transport_old_proto_override(bool *old_pkt_proto);
static u16 vmci_transport_new_proto_supported_versions(void); static u16 vmci_transport_new_proto_supported_versions(void);
static bool vmci_transport_proto_to_notify_struct(struct sock *sk, u16 *proto, static bool vmci_transport_proto_to_notify_struct(struct sock *sk, u16 *proto,
bool old_pkt_proto); bool old_pkt_proto);
static bool vmci_check_transport(struct vsock_sock *vsk);
struct vmci_transport_recv_pkt_info { struct vmci_transport_recv_pkt_info {
struct work_struct work; struct work_struct work;
...@@ -74,15 +75,6 @@ static u32 vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID; ...@@ -74,15 +75,6 @@ static u32 vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID;
static int PROTOCOL_OVERRIDE = -1; static int PROTOCOL_OVERRIDE = -1;
#define VMCI_TRANSPORT_DEFAULT_QP_SIZE_MIN 128
#define VMCI_TRANSPORT_DEFAULT_QP_SIZE 262144
#define VMCI_TRANSPORT_DEFAULT_QP_SIZE_MAX 262144
/* The default peer timeout indicates how long we will wait for a peer response
* to a control message.
*/
#define VSOCK_DEFAULT_CONNECT_TIMEOUT (2 * HZ)
/* Helper function to convert from a VMCI error code to a VSock error code. */ /* Helper function to convert from a VMCI error code to a VSock error code. */
static s32 vmci_transport_error_to_vsock_error(s32 vmci_error) static s32 vmci_transport_error_to_vsock_error(s32 vmci_error)
...@@ -1013,8 +1005,7 @@ static int vmci_transport_recv_listen(struct sock *sk, ...@@ -1013,8 +1005,7 @@ static int vmci_transport_recv_listen(struct sock *sk,
return -ECONNREFUSED; return -ECONNREFUSED;
} }
pending = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL, pending = vsock_create_connected(sk);
sk->sk_type, 0);
if (!pending) { if (!pending) {
vmci_transport_send_reset(sk, pkt); vmci_transport_send_reset(sk, pkt);
return -ENOMEM; return -ENOMEM;
...@@ -1027,14 +1018,24 @@ static int vmci_transport_recv_listen(struct sock *sk, ...@@ -1027,14 +1018,24 @@ static int vmci_transport_recv_listen(struct sock *sk,
vsock_addr_init(&vpending->remote_addr, pkt->dg.src.context, vsock_addr_init(&vpending->remote_addr, pkt->dg.src.context,
pkt->src_port); pkt->src_port);
err = vsock_assign_transport(vpending, vsock_sk(sk));
/* Transport assigned (looking at remote_addr) must be the same
* where we received the request.
*/
if (err || !vmci_check_transport(vpending)) {
vmci_transport_send_reset(sk, pkt);
sock_put(pending);
return err;
}
/* If the proposed size fits within our min/max, accept it. Otherwise /* If the proposed size fits within our min/max, accept it. Otherwise
* propose our own size. * propose our own size.
*/ */
if (pkt->u.size >= vmci_trans(vpending)->queue_pair_min_size && if (pkt->u.size >= vpending->buffer_min_size &&
pkt->u.size <= vmci_trans(vpending)->queue_pair_max_size) { pkt->u.size <= vpending->buffer_max_size) {
qp_size = pkt->u.size; qp_size = pkt->u.size;
} else { } else {
qp_size = vmci_trans(vpending)->queue_pair_size; qp_size = vpending->buffer_size;
} }
/* Figure out if we are using old or new requests based on the /* Figure out if we are using old or new requests based on the
...@@ -1103,7 +1104,7 @@ static int vmci_transport_recv_listen(struct sock *sk, ...@@ -1103,7 +1104,7 @@ static int vmci_transport_recv_listen(struct sock *sk,
pending->sk_state = TCP_SYN_SENT; pending->sk_state = TCP_SYN_SENT;
vmci_trans(vpending)->produce_size = vmci_trans(vpending)->produce_size =
vmci_trans(vpending)->consume_size = qp_size; vmci_trans(vpending)->consume_size = qp_size;
vmci_trans(vpending)->queue_pair_size = qp_size; vpending->buffer_size = qp_size;
vmci_trans(vpending)->notify_ops->process_request(pending); vmci_trans(vpending)->notify_ops->process_request(pending);
...@@ -1397,8 +1398,8 @@ static int vmci_transport_recv_connecting_client_negotiate( ...@@ -1397,8 +1398,8 @@ static int vmci_transport_recv_connecting_client_negotiate(
vsk->ignore_connecting_rst = false; vsk->ignore_connecting_rst = false;
/* Verify that we're OK with the proposed queue pair size */ /* Verify that we're OK with the proposed queue pair size */
if (pkt->u.size < vmci_trans(vsk)->queue_pair_min_size || if (pkt->u.size < vsk->buffer_min_size ||
pkt->u.size > vmci_trans(vsk)->queue_pair_max_size) { pkt->u.size > vsk->buffer_max_size) {
err = -EINVAL; err = -EINVAL;
goto destroy; goto destroy;
} }
...@@ -1503,8 +1504,7 @@ vmci_transport_recv_connecting_client_invalid(struct sock *sk, ...@@ -1503,8 +1504,7 @@ vmci_transport_recv_connecting_client_invalid(struct sock *sk,
vsk->sent_request = false; vsk->sent_request = false;
vsk->ignore_connecting_rst = true; vsk->ignore_connecting_rst = true;
err = vmci_transport_send_conn_request( err = vmci_transport_send_conn_request(sk, vsk->buffer_size);
sk, vmci_trans(vsk)->queue_pair_size);
if (err < 0) if (err < 0)
err = vmci_transport_error_to_vsock_error(err); err = vmci_transport_error_to_vsock_error(err);
else else
...@@ -1588,21 +1588,6 @@ static int vmci_transport_socket_init(struct vsock_sock *vsk, ...@@ -1588,21 +1588,6 @@ static int vmci_transport_socket_init(struct vsock_sock *vsk,
INIT_LIST_HEAD(&vmci_trans(vsk)->elem); INIT_LIST_HEAD(&vmci_trans(vsk)->elem);
vmci_trans(vsk)->sk = &vsk->sk; vmci_trans(vsk)->sk = &vsk->sk;
spin_lock_init(&vmci_trans(vsk)->lock); spin_lock_init(&vmci_trans(vsk)->lock);
if (psk) {
vmci_trans(vsk)->queue_pair_size =
vmci_trans(psk)->queue_pair_size;
vmci_trans(vsk)->queue_pair_min_size =
vmci_trans(psk)->queue_pair_min_size;
vmci_trans(vsk)->queue_pair_max_size =
vmci_trans(psk)->queue_pair_max_size;
} else {
vmci_trans(vsk)->queue_pair_size =
VMCI_TRANSPORT_DEFAULT_QP_SIZE;
vmci_trans(vsk)->queue_pair_min_size =
VMCI_TRANSPORT_DEFAULT_QP_SIZE_MIN;
vmci_trans(vsk)->queue_pair_max_size =
VMCI_TRANSPORT_DEFAULT_QP_SIZE_MAX;
}
return 0; return 0;
} }
...@@ -1818,8 +1803,7 @@ static int vmci_transport_connect(struct vsock_sock *vsk) ...@@ -1818,8 +1803,7 @@ static int vmci_transport_connect(struct vsock_sock *vsk)
if (vmci_transport_old_proto_override(&old_pkt_proto) && if (vmci_transport_old_proto_override(&old_pkt_proto) &&
old_pkt_proto) { old_pkt_proto) {
err = vmci_transport_send_conn_request( err = vmci_transport_send_conn_request(sk, vsk->buffer_size);
sk, vmci_trans(vsk)->queue_pair_size);
if (err < 0) { if (err < 0) {
sk->sk_state = TCP_CLOSE; sk->sk_state = TCP_CLOSE;
return err; return err;
...@@ -1827,8 +1811,7 @@ static int vmci_transport_connect(struct vsock_sock *vsk) ...@@ -1827,8 +1811,7 @@ static int vmci_transport_connect(struct vsock_sock *vsk)
} else { } else {
int supported_proto_versions = int supported_proto_versions =
vmci_transport_new_proto_supported_versions(); vmci_transport_new_proto_supported_versions();
err = vmci_transport_send_conn_request2( err = vmci_transport_send_conn_request2(sk, vsk->buffer_size,
sk, vmci_trans(vsk)->queue_pair_size,
supported_proto_versions); supported_proto_versions);
if (err < 0) { if (err < 0) {
sk->sk_state = TCP_CLOSE; sk->sk_state = TCP_CLOSE;
...@@ -1881,46 +1864,6 @@ static bool vmci_transport_stream_is_active(struct vsock_sock *vsk) ...@@ -1881,46 +1864,6 @@ static bool vmci_transport_stream_is_active(struct vsock_sock *vsk)
return !vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle); return !vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle);
} }
static u64 vmci_transport_get_buffer_size(struct vsock_sock *vsk)
{
return vmci_trans(vsk)->queue_pair_size;
}
static u64 vmci_transport_get_min_buffer_size(struct vsock_sock *vsk)
{
return vmci_trans(vsk)->queue_pair_min_size;
}
static u64 vmci_transport_get_max_buffer_size(struct vsock_sock *vsk)
{
return vmci_trans(vsk)->queue_pair_max_size;
}
static void vmci_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
{
if (val < vmci_trans(vsk)->queue_pair_min_size)
vmci_trans(vsk)->queue_pair_min_size = val;
if (val > vmci_trans(vsk)->queue_pair_max_size)
vmci_trans(vsk)->queue_pair_max_size = val;
vmci_trans(vsk)->queue_pair_size = val;
}
static void vmci_transport_set_min_buffer_size(struct vsock_sock *vsk,
u64 val)
{
if (val > vmci_trans(vsk)->queue_pair_size)
vmci_trans(vsk)->queue_pair_size = val;
vmci_trans(vsk)->queue_pair_min_size = val;
}
static void vmci_transport_set_max_buffer_size(struct vsock_sock *vsk,
u64 val)
{
if (val < vmci_trans(vsk)->queue_pair_size)
vmci_trans(vsk)->queue_pair_size = val;
vmci_trans(vsk)->queue_pair_max_size = val;
}
static int vmci_transport_notify_poll_in( static int vmci_transport_notify_poll_in(
struct vsock_sock *vsk, struct vsock_sock *vsk,
size_t target, size_t target,
...@@ -2076,7 +2019,8 @@ static u32 vmci_transport_get_local_cid(void) ...@@ -2076,7 +2019,8 @@ static u32 vmci_transport_get_local_cid(void)
return vmci_get_context_id(); return vmci_get_context_id();
} }
static const struct vsock_transport vmci_transport = { static struct vsock_transport vmci_transport = {
.module = THIS_MODULE,
.init = vmci_transport_socket_init, .init = vmci_transport_socket_init,
.destruct = vmci_transport_destruct, .destruct = vmci_transport_destruct,
.release = vmci_transport_release, .release = vmci_transport_release,
...@@ -2103,15 +2047,26 @@ static const struct vsock_transport vmci_transport = { ...@@ -2103,15 +2047,26 @@ static const struct vsock_transport vmci_transport = {
.notify_send_pre_enqueue = vmci_transport_notify_send_pre_enqueue, .notify_send_pre_enqueue = vmci_transport_notify_send_pre_enqueue,
.notify_send_post_enqueue = vmci_transport_notify_send_post_enqueue, .notify_send_post_enqueue = vmci_transport_notify_send_post_enqueue,
.shutdown = vmci_transport_shutdown, .shutdown = vmci_transport_shutdown,
.set_buffer_size = vmci_transport_set_buffer_size,
.set_min_buffer_size = vmci_transport_set_min_buffer_size,
.set_max_buffer_size = vmci_transport_set_max_buffer_size,
.get_buffer_size = vmci_transport_get_buffer_size,
.get_min_buffer_size = vmci_transport_get_min_buffer_size,
.get_max_buffer_size = vmci_transport_get_max_buffer_size,
.get_local_cid = vmci_transport_get_local_cid, .get_local_cid = vmci_transport_get_local_cid,
}; };
static bool vmci_check_transport(struct vsock_sock *vsk)
{
return vsk->transport == &vmci_transport;
}
void vmci_vsock_transport_cb(bool is_host)
{
int features;
if (is_host)
features = VSOCK_TRANSPORT_F_H2G;
else
features = VSOCK_TRANSPORT_F_G2H;
vsock_core_register(&vmci_transport, features);
}
static int __init vmci_transport_init(void) static int __init vmci_transport_init(void)
{ {
int err; int err;
...@@ -2128,7 +2083,6 @@ static int __init vmci_transport_init(void) ...@@ -2128,7 +2083,6 @@ static int __init vmci_transport_init(void)
pr_err("Unable to create datagram handle. (%d)\n", err); pr_err("Unable to create datagram handle. (%d)\n", err);
return vmci_transport_error_to_vsock_error(err); return vmci_transport_error_to_vsock_error(err);
} }
err = vmci_event_subscribe(VMCI_EVENT_QP_RESUMED, err = vmci_event_subscribe(VMCI_EVENT_QP_RESUMED,
vmci_transport_qp_resumed_cb, vmci_transport_qp_resumed_cb,
NULL, &vmci_transport_qp_resumed_sub_id); NULL, &vmci_transport_qp_resumed_sub_id);
...@@ -2139,12 +2093,21 @@ static int __init vmci_transport_init(void) ...@@ -2139,12 +2093,21 @@ static int __init vmci_transport_init(void)
goto err_destroy_stream_handle; goto err_destroy_stream_handle;
} }
err = vsock_core_init(&vmci_transport); /* Register only with dgram feature, other features (H2G, G2H) will be
* registered when the first host or guest becomes active.
*/
err = vsock_core_register(&vmci_transport, VSOCK_TRANSPORT_F_DGRAM);
if (err < 0) if (err < 0)
goto err_unsubscribe; goto err_unsubscribe;
err = vmci_register_vsock_callback(vmci_vsock_transport_cb);
if (err < 0)
goto err_unregister;
return 0; return 0;
err_unregister:
vsock_core_unregister(&vmci_transport);
err_unsubscribe: err_unsubscribe:
vmci_event_unsubscribe(vmci_transport_qp_resumed_sub_id); vmci_event_unsubscribe(vmci_transport_qp_resumed_sub_id);
err_destroy_stream_handle: err_destroy_stream_handle:
...@@ -2170,7 +2133,8 @@ static void __exit vmci_transport_exit(void) ...@@ -2170,7 +2133,8 @@ static void __exit vmci_transport_exit(void)
vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID; vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID;
} }
vsock_core_exit(); vmci_register_vsock_callback(NULL);
vsock_core_unregister(&vmci_transport);
} }
module_exit(vmci_transport_exit); module_exit(vmci_transport_exit);
......
...@@ -108,9 +108,6 @@ struct vmci_transport { ...@@ -108,9 +108,6 @@ struct vmci_transport {
struct vmci_qp *qpair; struct vmci_qp *qpair;
u64 produce_size; u64 produce_size;
u64 consume_size; u64 consume_size;
u64 queue_pair_size;
u64 queue_pair_min_size;
u64 queue_pair_max_size;
u32 detach_sub_id; u32 detach_sub_id;
union vmci_transport_notify notify; union vmci_transport_notify notify;
const struct vmci_transport_notify_ops *notify_ops; const struct vmci_transport_notify_ops *notify_ops;
......
...@@ -11,7 +11,6 @@ ...@@ -11,7 +11,6 @@
#include <linux/types.h> #include <linux/types.h>
#include <linux/vmw_vmci_defs.h> #include <linux/vmw_vmci_defs.h>
#include <linux/vmw_vmci_api.h> #include <linux/vmw_vmci_api.h>
#include <linux/vm_sockets.h>
#include "vmci_transport.h" #include "vmci_transport.h"
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment