Commit 5aa3bd9b authored by David S. Miller's avatar David S. Miller

Merge branch 'virtio-vsock-seqpacket'

Arseny Krasnov says:

====================
virtio/vsock: introduce SOCK_SEQPACKET support

This patchset implements support of SOCK_SEQPACKET for virtio
transport.
	As SOCK_SEQPACKET guarantees to save record boundaries, so to
do it, new bit for field 'flags' was added: SEQ_EOR. This bit is
set to 1 in last RW packet of message.
	Now as  packets of one socket are not reordered neither on vsock
nor on vhost transport layers, such bit allows to restore original
message on receiver's side. If user's buffer is smaller than message
length, when all out of size data is dropped.
	Maximum length of datagram is limited by 'peer_buf_alloc' value.
	Implementation also supports 'MSG_TRUNC' flags.
	Tests also implemented.

	Thanks to stsp2@yandex.ru for encouragements and initial design
recommendations.
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 57806b28 184039ee
......@@ -31,7 +31,8 @@
enum {
VHOST_VSOCK_FEATURES = VHOST_FEATURES |
(1ULL << VIRTIO_F_ACCESS_PLATFORM)
(1ULL << VIRTIO_F_ACCESS_PLATFORM) |
(1ULL << VIRTIO_VSOCK_F_SEQPACKET)
};
enum {
......@@ -56,6 +57,7 @@ struct vhost_vsock {
atomic_t queued_replies;
u32 guest_cid;
bool seqpacket_allow;
};
static u32 vhost_transport_get_local_cid(void)
......@@ -112,6 +114,7 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
size_t nbytes;
size_t iov_len, payload_len;
int head;
bool restore_flag = false;
spin_lock_bh(&vsock->send_pkt_list_lock);
if (list_empty(&vsock->send_pkt_list)) {
......@@ -168,9 +171,26 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
/* If the packet is greater than the space available in the
* buffer, we split it using multiple buffers.
*/
if (payload_len > iov_len - sizeof(pkt->hdr))
if (payload_len > iov_len - sizeof(pkt->hdr)) {
payload_len = iov_len - sizeof(pkt->hdr);
/* As we are copying pieces of large packet's buffer to
* small rx buffers, headers of packets in rx queue are
* created dynamically and are initialized with header
* of current packet(except length). But in case of
* SOCK_SEQPACKET, we also must clear record delimeter
* bit(VIRTIO_VSOCK_SEQ_EOR). Otherwise, instead of one
* packet with delimeter(which marks end of record),
* there will be sequence of packets with delimeter
* bit set. After initialized header will be copied to
* rx buffer, this bit will be restored.
*/
if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOR) {
pkt->hdr.flags &= ~cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
restore_flag = true;
}
}
/* Set the correct length in the header */
pkt->hdr.len = cpu_to_le32(payload_len);
......@@ -204,6 +224,9 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
* to send it with the next available buffer.
*/
if (pkt->off < pkt->len) {
if (restore_flag)
pkt->hdr.flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
/* We are queueing the same virtio_vsock_pkt to handle
* the remaining bytes, and we want to deliver it
* to monitoring devices in the next iteration.
......@@ -354,7 +377,6 @@ vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq,
return NULL;
}
if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_STREAM)
pkt->len = le32_to_cpu(pkt->hdr.len);
/* No payload */
......@@ -398,6 +420,8 @@ static bool vhost_vsock_more_replies(struct vhost_vsock *vsock)
return val < vq->num;
}
static bool vhost_transport_seqpacket_allow(u32 remote_cid);
static struct virtio_transport vhost_transport = {
.transport = {
.module = THIS_MODULE,
......@@ -424,6 +448,11 @@ static struct virtio_transport vhost_transport = {
.stream_is_active = virtio_transport_stream_is_active,
.stream_allow = virtio_transport_stream_allow,
.seqpacket_dequeue = virtio_transport_seqpacket_dequeue,
.seqpacket_enqueue = virtio_transport_seqpacket_enqueue,
.seqpacket_allow = vhost_transport_seqpacket_allow,
.seqpacket_has_data = virtio_transport_seqpacket_has_data,
.notify_poll_in = virtio_transport_notify_poll_in,
.notify_poll_out = virtio_transport_notify_poll_out,
.notify_recv_init = virtio_transport_notify_recv_init,
......@@ -441,6 +470,22 @@ static struct virtio_transport vhost_transport = {
.send_pkt = vhost_transport_send_pkt,
};
static bool vhost_transport_seqpacket_allow(u32 remote_cid)
{
struct vhost_vsock *vsock;
bool seqpacket_allow = false;
rcu_read_lock();
vsock = vhost_vsock_get(remote_cid);
if (vsock)
seqpacket_allow = vsock->seqpacket_allow;
rcu_read_unlock();
return seqpacket_allow;
}
static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
{
struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
......@@ -785,6 +830,9 @@ static int vhost_vsock_set_features(struct vhost_vsock *vsock, u64 features)
goto err;
}
if (features & (1ULL << VIRTIO_VSOCK_F_SEQPACKET))
vsock->seqpacket_allow = true;
for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
vq = &vsock->vqs[i];
mutex_lock(&vq->mutex);
......
......@@ -36,6 +36,7 @@ struct virtio_vsock_sock {
u32 rx_bytes;
u32 buf_alloc;
struct list_head rx_queue;
u32 msg_count;
};
struct virtio_vsock_pkt {
......@@ -80,8 +81,17 @@ virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
struct msghdr *msg,
size_t len, int flags);
int
virtio_transport_seqpacket_enqueue(struct vsock_sock *vsk,
struct msghdr *msg,
size_t len);
ssize_t
virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk,
struct msghdr *msg,
int flags);
s64 virtio_transport_stream_has_data(struct vsock_sock *vsk);
s64 virtio_transport_stream_has_space(struct vsock_sock *vsk);
u32 virtio_transport_seqpacket_has_data(struct vsock_sock *vsk);
int virtio_transport_do_socket_init(struct vsock_sock *vsk,
struct vsock_sock *psk);
......
......@@ -135,6 +135,14 @@ struct vsock_transport {
bool (*stream_is_active)(struct vsock_sock *);
bool (*stream_allow)(u32 cid, u32 port);
/* SEQ_PACKET. */
ssize_t (*seqpacket_dequeue)(struct vsock_sock *vsk, struct msghdr *msg,
int flags);
int (*seqpacket_enqueue)(struct vsock_sock *vsk, struct msghdr *msg,
size_t len);
bool (*seqpacket_allow)(u32 remote_cid);
u32 (*seqpacket_has_data)(struct vsock_sock *vsk);
/* Notification. */
int (*notify_poll_in)(struct vsock_sock *, size_t, bool *);
int (*notify_poll_out)(struct vsock_sock *, size_t, bool *);
......
......@@ -9,9 +9,12 @@
#include <linux/tracepoint.h>
TRACE_DEFINE_ENUM(VIRTIO_VSOCK_TYPE_STREAM);
TRACE_DEFINE_ENUM(VIRTIO_VSOCK_TYPE_SEQPACKET);
#define show_type(val) \
__print_symbolic(val, { VIRTIO_VSOCK_TYPE_STREAM, "STREAM" })
__print_symbolic(val, \
{ VIRTIO_VSOCK_TYPE_STREAM, "STREAM" }, \
{ VIRTIO_VSOCK_TYPE_SEQPACKET, "SEQPACKET" })
TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_INVALID);
TRACE_DEFINE_ENUM(VIRTIO_VSOCK_OP_REQUEST);
......
......@@ -38,6 +38,9 @@
#include <linux/virtio_ids.h>
#include <linux/virtio_config.h>
/* The feature bitmap for virtio vsock */
#define VIRTIO_VSOCK_F_SEQPACKET 1 /* SOCK_SEQPACKET supported */
struct virtio_vsock_config {
__le64 guest_cid;
} __attribute__((packed));
......@@ -65,6 +68,7 @@ struct virtio_vsock_hdr {
enum virtio_vsock_type {
VIRTIO_VSOCK_TYPE_STREAM = 1,
VIRTIO_VSOCK_TYPE_SEQPACKET = 2,
};
enum virtio_vsock_op {
......@@ -91,4 +95,9 @@ enum virtio_vsock_shutdown {
VIRTIO_VSOCK_SHUTDOWN_SEND = 2,
};
/* VIRTIO_VSOCK_OP_RW flags values */
enum virtio_vsock_rw {
VIRTIO_VSOCK_SEQ_EOR = 1,
};
#endif /* _UAPI_LINUX_VIRTIO_VSOCK_H */
......@@ -415,8 +415,8 @@ static void vsock_deassign_transport(struct vsock_sock *vsk)
/* Assign a transport to a socket and call the .init transport callback.
*
* Note: for stream socket this must be called when vsk->remote_addr is set
* (e.g. during the connect() or when a connection request on a listener
* Note: for connection oriented socket this must be called when vsk->remote_addr
* is set (e.g. during the connect() or when a connection request on a listener
* socket is received).
* The vsk->remote_addr is used to decide which transport to use:
* - remote CID == VMADDR_CID_LOCAL or g2h->local_cid or VMADDR_CID_HOST if
......@@ -452,6 +452,7 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
new_transport = transport_dgram;
break;
case SOCK_STREAM:
case SOCK_SEQPACKET:
if (vsock_use_local_transport(remote_cid))
new_transport = transport_local;
else if (remote_cid <= VMADDR_CID_HOST || !transport_h2g ||
......@@ -469,10 +470,10 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
return 0;
/* transport->release() must be called with sock lock acquired.
* This path can only be taken during vsock_stream_connect(),
* where we have already held the sock lock.
* In the other cases, this function is called on a new socket
* which is not assigned to any transport.
* This path can only be taken during vsock_connect(), where we
* have already held the sock lock. In the other cases, this
* function is called on a new socket which is not assigned to
* any transport.
*/
vsk->transport->release(vsk);
vsock_deassign_transport(vsk);
......@@ -484,6 +485,14 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
if (!new_transport || !try_module_get(new_transport->module))
return -ENODEV;
if (sk->sk_type == SOCK_SEQPACKET) {
if (!new_transport->seqpacket_allow ||
!new_transport->seqpacket_allow(remote_cid)) {
module_put(new_transport->module);
return -ESOCKTNOSUPPORT;
}
}
ret = new_transport->init(vsk, psk);
if (ret) {
module_put(new_transport->module);
......@@ -604,7 +613,7 @@ static void vsock_pending_work(struct work_struct *work)
/**** SOCKET OPERATIONS ****/
static int __vsock_bind_stream(struct vsock_sock *vsk,
static int __vsock_bind_connectible(struct vsock_sock *vsk,
struct sockaddr_vm *addr)
{
static u32 port;
......@@ -649,9 +658,10 @@ static int __vsock_bind_stream(struct vsock_sock *vsk,
vsock_addr_init(&vsk->local_addr, new_addr.svm_cid, new_addr.svm_port);
/* Remove stream sockets from the unbound list and add them to the hash
* table for easy lookup by its address. The unbound list is simply an
* extra entry at the end of the hash table, a trick used by AF_UNIX.
/* Remove connection oriented sockets from the unbound list and add them
* to the hash table for easy lookup by its address. The unbound list
* is simply an extra entry at the end of the hash table, a trick used
* by AF_UNIX.
*/
__vsock_remove_bound(vsk);
__vsock_insert_bound(vsock_bound_sockets(&vsk->local_addr), vsk);
......@@ -684,8 +694,9 @@ static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
switch (sk->sk_socket->type) {
case SOCK_STREAM:
case SOCK_SEQPACKET:
spin_lock_bh(&vsock_table_lock);
retval = __vsock_bind_stream(vsk, addr);
retval = __vsock_bind_connectible(vsk, addr);
spin_unlock_bh(&vsock_table_lock);
break;
......@@ -768,6 +779,11 @@ static struct sock *__vsock_create(struct net *net,
return sk;
}
static bool sock_type_connectible(u16 type)
{
return (type == SOCK_STREAM) || (type == SOCK_SEQPACKET);
}
static void __vsock_release(struct sock *sk, int level)
{
if (sk) {
......@@ -786,7 +802,7 @@ static void __vsock_release(struct sock *sk, int level)
if (vsk->transport)
vsk->transport->release(vsk);
else if (sk->sk_type == SOCK_STREAM)
else if (sock_type_connectible(sk->sk_type))
vsock_remove_sock(vsk);
sock_orphan(sk);
......@@ -844,6 +860,16 @@ s64 vsock_stream_has_data(struct vsock_sock *vsk)
}
EXPORT_SYMBOL_GPL(vsock_stream_has_data);
static s64 vsock_has_data(struct vsock_sock *vsk)
{
struct sock *sk = sk_vsock(vsk);
if (sk->sk_type == SOCK_SEQPACKET)
return vsk->transport->seqpacket_has_data(vsk);
else
return vsock_stream_has_data(vsk);
}
s64 vsock_stream_has_space(struct vsock_sock *vsk)
{
return vsk->transport->stream_has_space(vsk);
......@@ -937,10 +963,10 @@ static int vsock_shutdown(struct socket *sock, int mode)
if ((mode & ~SHUTDOWN_MASK) || !mode)
return -EINVAL;
/* If this is a STREAM socket and it is not connected then bail out
* immediately. If it is a DGRAM socket then we must first kick the
* socket so that it wakes up from any sleeping calls, for example
* recv(), and then afterwards return the error.
/* If this is a connection oriented socket and it is not connected then
* bail out immediately. If it is a DGRAM socket then we must first
* kick the socket so that it wakes up from any sleeping calls, for
* example recv(), and then afterwards return the error.
*/
sk = sock->sk;
......@@ -948,7 +974,7 @@ static int vsock_shutdown(struct socket *sock, int mode)
lock_sock(sk);
if (sock->state == SS_UNCONNECTED) {
err = -ENOTCONN;
if (sk->sk_type == SOCK_STREAM)
if (sock_type_connectible(sk->sk_type))
goto out;
} else {
sock->state = SS_DISCONNECTING;
......@@ -961,7 +987,7 @@ static int vsock_shutdown(struct socket *sock, int mode)
sk->sk_shutdown |= mode;
sk->sk_state_change(sk);
if (sk->sk_type == SOCK_STREAM) {
if (sock_type_connectible(sk->sk_type)) {
sock_reset_flag(sk, SOCK_DONE);
vsock_send_shutdown(sk, mode);
}
......@@ -1016,7 +1042,7 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock,
if (!(sk->sk_shutdown & SEND_SHUTDOWN))
mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND;
} else if (sock->type == SOCK_STREAM) {
} else if (sock_type_connectible(sk->sk_type)) {
const struct vsock_transport *transport;
lock_sock(sk);
......@@ -1263,7 +1289,7 @@ static void vsock_connect_timeout(struct work_struct *work)
sock_put(sk);
}
static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
static int vsock_connect(struct socket *sock, struct sockaddr *addr,
int addr_len, int flags)
{
int err;
......@@ -1414,7 +1440,7 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags,
lock_sock(listener);
if (sock->type != SOCK_STREAM) {
if (!sock_type_connectible(sock->type)) {
err = -EOPNOTSUPP;
goto out;
}
......@@ -1491,7 +1517,7 @@ static int vsock_listen(struct socket *sock, int backlog)
lock_sock(sk);
if (sock->type != SOCK_STREAM) {
if (!sock_type_connectible(sk->sk_type)) {
err = -EOPNOTSUPP;
goto out;
}
......@@ -1535,7 +1561,7 @@ static void vsock_update_buffer_size(struct vsock_sock *vsk,
vsk->buffer_size = val;
}
static int vsock_stream_setsockopt(struct socket *sock,
static int vsock_connectible_setsockopt(struct socket *sock,
int level,
int optname,
sockptr_t optval,
......@@ -1617,7 +1643,7 @@ static int vsock_stream_setsockopt(struct socket *sock,
return err;
}
static int vsock_stream_getsockopt(struct socket *sock,
static int vsock_connectible_getsockopt(struct socket *sock,
int level, int optname,
char __user *optval,
int __user *optlen)
......@@ -1688,7 +1714,7 @@ static int vsock_stream_getsockopt(struct socket *sock,
return 0;
}
static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
static int vsock_connectible_sendmsg(struct socket *sock, struct msghdr *msg,
size_t len)
{
struct sock *sk;
......@@ -1712,7 +1738,9 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
transport = vsk->transport;
/* Callers should not provide a destination with stream sockets. */
/* Callers should not provide a destination with connection oriented
* sockets.
*/
if (msg->msg_namelen) {
err = sk->sk_state == TCP_ESTABLISHED ? -EISCONN : -EOPNOTSUPP;
goto out;
......@@ -1803,9 +1831,13 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
* responsibility to check how many bytes we were able to send.
*/
written = transport->stream_enqueue(
vsk, msg,
len - total_written);
if (sk->sk_type == SOCK_SEQPACKET) {
written = transport->seqpacket_enqueue(vsk,
msg, len - total_written);
} else {
written = transport->stream_enqueue(vsk,
msg, len - total_written);
}
if (written < 0) {
err = -ENOMEM;
goto out_err;
......@@ -1821,154 +1853,130 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
}
out_err:
if (total_written > 0)
if (total_written > 0) {
/* Return number of written bytes only if:
* 1) SOCK_STREAM socket.
* 2) SOCK_SEQPACKET socket when whole buffer is sent.
*/
if (sk->sk_type == SOCK_STREAM || total_written == len)
err = total_written;
}
out:
release_sock(sk);
return err;
}
static int
vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
int flags)
static int vsock_wait_data(struct sock *sk, struct wait_queue_entry *wait,
long timeout,
struct vsock_transport_recv_notify_data *recv_data,
size_t target)
{
struct sock *sk;
struct vsock_sock *vsk;
const struct vsock_transport *transport;
struct vsock_sock *vsk;
s64 data;
int err;
size_t target;
ssize_t copied;
long timeout;
struct vsock_transport_recv_notify_data recv_data;
DEFINE_WAIT(wait);
sk = sock->sk;
vsk = vsock_sk(sk);
err = 0;
lock_sock(sk);
transport = vsk->transport;
if (!transport || sk->sk_state != TCP_ESTABLISHED) {
/* Recvmsg is supposed to return 0 if a peer performs an
* orderly shutdown. Differentiate between that case and when a
* peer has not connected or a local shutdown occurred with the
* SOCK_DONE flag.
*/
if (sock_flag(sk, SOCK_DONE))
err = 0;
else
err = -ENOTCONN;
goto out;
}
if (flags & MSG_OOB) {
err = -EOPNOTSUPP;
goto out;
}
/* We don't check peer_shutdown flag here since peer may actually shut
* down, but there can be data in the queue that a local socket can
* receive.
*/
if (sk->sk_shutdown & RCV_SHUTDOWN) {
err = 0;
goto out;
}
/* It is valid on Linux to pass in a zero-length receive buffer. This
* is not an error. We may as well bail out now.
*/
if (!len) {
err = 0;
goto out;
}
/* We must not copy less than target bytes into the user's buffer
* before returning successfully, so we wait for the consume queue to
* have that much data to consume before dequeueing. Note that this
* makes it impossible to handle cases where target is greater than the
* queue size.
*/
target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
if (target >= transport->stream_rcvhiwat(vsk)) {
err = -ENOMEM;
goto out;
}
timeout = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
copied = 0;
err = transport->notify_recv_init(vsk, target, &recv_data);
if (err < 0)
goto out;
while ((data = vsock_has_data(vsk)) == 0) {
prepare_to_wait(sk_sleep(sk), wait, TASK_INTERRUPTIBLE);
while (1) {
s64 ready;
prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
ready = vsock_stream_has_data(vsk);
if (ready == 0) {
if (sk->sk_err != 0 ||
(sk->sk_shutdown & RCV_SHUTDOWN) ||
(vsk->peer_shutdown & SEND_SHUTDOWN)) {
finish_wait(sk_sleep(sk), &wait);
break;
}
/* Don't wait for non-blocking sockets. */
if (timeout == 0) {
err = -EAGAIN;
finish_wait(sk_sleep(sk), &wait);
break;
}
err = transport->notify_recv_pre_block(
vsk, target, &recv_data);
if (err < 0) {
finish_wait(sk_sleep(sk), &wait);
if (recv_data) {
err = transport->notify_recv_pre_block(vsk, target, recv_data);
if (err < 0)
break;
}
release_sock(sk);
timeout = schedule_timeout(timeout);
lock_sock(sk);
if (signal_pending(current)) {
err = sock_intr_errno(timeout);
finish_wait(sk_sleep(sk), &wait);
break;
} else if (timeout == 0) {
err = -EAGAIN;
finish_wait(sk_sleep(sk), &wait);
break;
}
} else {
ssize_t read;
}
finish_wait(sk_sleep(sk), &wait);
finish_wait(sk_sleep(sk), wait);
if (err)
return err;
if (ready < 0) {
/* Invalid queue pair content. XXX This should
* be changed to a connection reset in a later
* change.
/* Internal transport error when checking for available
* data. XXX This should be changed to a connection
* reset in a later change.
*/
if (data < 0)
return -ENOMEM;
return data;
}
static int __vsock_stream_recvmsg(struct sock *sk, struct msghdr *msg,
size_t len, int flags)
{
struct vsock_transport_recv_notify_data recv_data;
const struct vsock_transport *transport;
struct vsock_sock *vsk;
ssize_t copied;
size_t target;
long timeout;
int err;
DEFINE_WAIT(wait);
vsk = vsock_sk(sk);
transport = vsk->transport;
/* We must not copy less than target bytes into the user's buffer
* before returning successfully, so we wait for the consume queue to
* have that much data to consume before dequeueing. Note that this
* makes it impossible to handle cases where target is greater than the
* queue size.
*/
target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
if (target >= transport->stream_rcvhiwat(vsk)) {
err = -ENOMEM;
goto out;
}
timeout = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
copied = 0;
err = transport->notify_recv_pre_dequeue(
vsk, target, &recv_data);
err = transport->notify_recv_init(vsk, target, &recv_data);
if (err < 0)
goto out;
while (1) {
ssize_t read;
err = vsock_wait_data(sk, &wait, timeout, &recv_data, target);
if (err <= 0)
break;
read = transport->stream_dequeue(
vsk, msg,
len - copied, flags);
err = transport->notify_recv_pre_dequeue(vsk, target,
&recv_data);
if (err < 0)
break;
read = transport->stream_dequeue(vsk, msg, len - copied, flags);
if (read < 0) {
err = -ENOMEM;
break;
......@@ -1976,8 +1984,7 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
copied += read;
err = transport->notify_recv_post_dequeue(
vsk, target, read,
err = transport->notify_recv_post_dequeue(vsk, target, read,
!(flags & MSG_PEEK), &recv_data);
if (err < 0)
goto out;
......@@ -1987,7 +1994,6 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
target -= read;
}
}
if (sk->sk_err)
err = -sk->sk_err;
......@@ -1997,6 +2003,120 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
if (copied > 0)
err = copied;
out:
return err;
}
static int __vsock_seqpacket_recvmsg(struct sock *sk, struct msghdr *msg,
size_t len, int flags)
{
const struct vsock_transport *transport;
struct vsock_sock *vsk;
ssize_t record_len;
long timeout;
int err = 0;
DEFINE_WAIT(wait);
vsk = vsock_sk(sk);
transport = vsk->transport;
timeout = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
err = vsock_wait_data(sk, &wait, timeout, NULL, 0);
if (err <= 0)
goto out;
record_len = transport->seqpacket_dequeue(vsk, msg, flags);
if (record_len < 0) {
err = -ENOMEM;
goto out;
}
if (sk->sk_err) {
err = -sk->sk_err;
} else if (sk->sk_shutdown & RCV_SHUTDOWN) {
err = 0;
} else {
/* User sets MSG_TRUNC, so return real length of
* packet.
*/
if (flags & MSG_TRUNC)
err = record_len;
else
err = len - msg_data_left(msg);
/* Always set MSG_TRUNC if real length of packet is
* bigger than user's buffer.
*/
if (record_len > len)
msg->msg_flags |= MSG_TRUNC;
}
out:
return err;
}
static int
vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
int flags)
{
struct sock *sk;
struct vsock_sock *vsk;
const struct vsock_transport *transport;
int err;
DEFINE_WAIT(wait);
sk = sock->sk;
vsk = vsock_sk(sk);
err = 0;
lock_sock(sk);
transport = vsk->transport;
if (!transport || sk->sk_state != TCP_ESTABLISHED) {
/* Recvmsg is supposed to return 0 if a peer performs an
* orderly shutdown. Differentiate between that case and when a
* peer has not connected or a local shutdown occurred with the
* SOCK_DONE flag.
*/
if (sock_flag(sk, SOCK_DONE))
err = 0;
else
err = -ENOTCONN;
goto out;
}
if (flags & MSG_OOB) {
err = -EOPNOTSUPP;
goto out;
}
/* We don't check peer_shutdown flag here since peer may actually shut
* down, but there can be data in the queue that a local socket can
* receive.
*/
if (sk->sk_shutdown & RCV_SHUTDOWN) {
err = 0;
goto out;
}
/* It is valid on Linux to pass in a zero-length receive buffer. This
* is not an error. We may as well bail out now.
*/
if (!len) {
err = 0;
goto out;
}
if (sk->sk_type == SOCK_STREAM)
err = __vsock_stream_recvmsg(sk, msg, len, flags);
else
err = __vsock_seqpacket_recvmsg(sk, msg, len, flags);
out:
release_sock(sk);
return err;
......@@ -2007,7 +2127,7 @@ static const struct proto_ops vsock_stream_ops = {
.owner = THIS_MODULE,
.release = vsock_release,
.bind = vsock_bind,
.connect = vsock_stream_connect,
.connect = vsock_connect,
.socketpair = sock_no_socketpair,
.accept = vsock_accept,
.getname = vsock_getname,
......@@ -2015,10 +2135,31 @@ static const struct proto_ops vsock_stream_ops = {
.ioctl = sock_no_ioctl,
.listen = vsock_listen,
.shutdown = vsock_shutdown,
.setsockopt = vsock_stream_setsockopt,
.getsockopt = vsock_stream_getsockopt,
.sendmsg = vsock_stream_sendmsg,
.recvmsg = vsock_stream_recvmsg,
.setsockopt = vsock_connectible_setsockopt,
.getsockopt = vsock_connectible_getsockopt,
.sendmsg = vsock_connectible_sendmsg,
.recvmsg = vsock_connectible_recvmsg,
.mmap = sock_no_mmap,
.sendpage = sock_no_sendpage,
};
static const struct proto_ops vsock_seqpacket_ops = {
.family = PF_VSOCK,
.owner = THIS_MODULE,
.release = vsock_release,
.bind = vsock_bind,
.connect = vsock_connect,
.socketpair = sock_no_socketpair,
.accept = vsock_accept,
.getname = vsock_getname,
.poll = vsock_poll,
.ioctl = sock_no_ioctl,
.listen = vsock_listen,
.shutdown = vsock_shutdown,
.setsockopt = vsock_connectible_setsockopt,
.getsockopt = vsock_connectible_getsockopt,
.sendmsg = vsock_connectible_sendmsg,
.recvmsg = vsock_connectible_recvmsg,
.mmap = sock_no_mmap,
.sendpage = sock_no_sendpage,
};
......@@ -2043,6 +2184,9 @@ static int vsock_create(struct net *net, struct socket *sock,
case SOCK_STREAM:
sock->ops = &vsock_stream_ops;
break;
case SOCK_SEQPACKET:
sock->ops = &vsock_seqpacket_ops;
break;
default:
return -ESOCKTNOSUPPORT;
}
......
......@@ -62,6 +62,7 @@ struct virtio_vsock {
struct virtio_vsock_event event_list[8];
u32 guest_cid;
bool seqpacket_allow;
};
static u32 virtio_transport_get_local_cid(void)
......@@ -443,6 +444,8 @@ static void virtio_vsock_rx_done(struct virtqueue *vq)
queue_work(virtio_vsock_workqueue, &vsock->rx_work);
}
static bool virtio_transport_seqpacket_allow(u32 remote_cid);
static struct virtio_transport virtio_transport = {
.transport = {
.module = THIS_MODULE,
......@@ -469,6 +472,11 @@ static struct virtio_transport virtio_transport = {
.stream_is_active = virtio_transport_stream_is_active,
.stream_allow = virtio_transport_stream_allow,
.seqpacket_dequeue = virtio_transport_seqpacket_dequeue,
.seqpacket_enqueue = virtio_transport_seqpacket_enqueue,
.seqpacket_allow = virtio_transport_seqpacket_allow,
.seqpacket_has_data = virtio_transport_seqpacket_has_data,
.notify_poll_in = virtio_transport_notify_poll_in,
.notify_poll_out = virtio_transport_notify_poll_out,
.notify_recv_init = virtio_transport_notify_recv_init,
......@@ -485,6 +493,19 @@ static struct virtio_transport virtio_transport = {
.send_pkt = virtio_transport_send_pkt,
};
static bool virtio_transport_seqpacket_allow(u32 remote_cid)
{
struct virtio_vsock *vsock;
bool seqpacket_allow;
rcu_read_lock();
vsock = rcu_dereference(the_virtio_vsock);
seqpacket_allow = vsock->seqpacket_allow;
rcu_read_unlock();
return seqpacket_allow;
}
static void virtio_transport_rx_work(struct work_struct *work)
{
struct virtio_vsock *vsock =
......@@ -608,10 +629,14 @@ static int virtio_vsock_probe(struct virtio_device *vdev)
vsock->event_run = true;
mutex_unlock(&vsock->event_lock);
if (virtio_has_feature(vdev, VIRTIO_VSOCK_F_SEQPACKET))
vsock->seqpacket_allow = true;
vdev->priv = vsock;
rcu_assign_pointer(the_virtio_vsock, vsock);
mutex_unlock(&the_virtio_vsock_mutex);
return 0;
out:
......@@ -695,6 +720,7 @@ static struct virtio_device_id id_table[] = {
};
static unsigned int features[] = {
VIRTIO_VSOCK_F_SEQPACKET
};
static struct virtio_driver virtio_vsock_driver = {
......
......@@ -74,6 +74,10 @@ virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
err = memcpy_from_msg(pkt->buf, info->msg, len);
if (err)
goto out;
if (msg_data_left(info->msg) == 0 &&
info->type == VIRTIO_VSOCK_TYPE_SEQPACKET)
pkt->hdr.flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
}
trace_virtio_transport_alloc_pkt(src_cid, src_port,
......@@ -165,6 +169,14 @@ void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt)
}
EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);
static u16 virtio_transport_get_type(struct sock *sk)
{
if (sk->sk_type == SOCK_STREAM)
return VIRTIO_VSOCK_TYPE_STREAM;
else
return VIRTIO_VSOCK_TYPE_SEQPACKET;
}
/* This function can only be used on connecting/connected sockets,
* since a socket assigned to a transport is required.
*
......@@ -179,6 +191,8 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
struct virtio_vsock_pkt *pkt;
u32 pkt_len = info->pkt_len;
info->type = virtio_transport_get_type(sk_vsock(vsk));
t_ops = virtio_transport_get_ops(vsk);
if (unlikely(!t_ops))
return -EFAULT;
......@@ -269,13 +283,10 @@ void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
}
EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
int type,
struct virtio_vsock_hdr *hdr)
static int virtio_transport_send_credit_update(struct vsock_sock *vsk)
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
.type = type,
.vsk = vsk,
};
......@@ -383,11 +394,8 @@ virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
* messages, we set the limit to a high value. TODO: experiment
* with different values.
*/
if (free_space < VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) {
virtio_transport_send_credit_update(vsk,
VIRTIO_VSOCK_TYPE_STREAM,
NULL);
}
if (free_space < VIRTIO_VSOCK_MAX_PKT_BUF_SIZE)
virtio_transport_send_credit_update(vsk);
return total;
......@@ -397,6 +405,78 @@ virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
return err;
}
static int virtio_transport_seqpacket_do_dequeue(struct vsock_sock *vsk,
struct msghdr *msg,
int flags)
{
struct virtio_vsock_sock *vvs = vsk->trans;
struct virtio_vsock_pkt *pkt;
int dequeued_len = 0;
size_t user_buf_len = msg_data_left(msg);
bool copy_failed = false;
bool msg_ready = false;
spin_lock_bh(&vvs->rx_lock);
if (vvs->msg_count == 0) {
spin_unlock_bh(&vvs->rx_lock);
return 0;
}
while (!msg_ready) {
pkt = list_first_entry(&vvs->rx_queue, struct virtio_vsock_pkt, list);
if (!copy_failed) {
size_t pkt_len;
size_t bytes_to_copy;
pkt_len = (size_t)le32_to_cpu(pkt->hdr.len);
bytes_to_copy = min(user_buf_len, pkt_len);
if (bytes_to_copy) {
int err;
/* sk_lock is held by caller so no one else can dequeue.
* Unlock rx_lock since memcpy_to_msg() may sleep.
*/
spin_unlock_bh(&vvs->rx_lock);
err = memcpy_to_msg(msg, pkt->buf, bytes_to_copy);
if (err) {
/* Copy of message failed, set flag to skip
* copy path for rest of fragments. Rest of
* fragments will be freed without copy.
*/
copy_failed = true;
dequeued_len = err;
} else {
user_buf_len -= bytes_to_copy;
}
spin_lock_bh(&vvs->rx_lock);
}
if (dequeued_len >= 0)
dequeued_len += pkt_len;
}
if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOR) {
msg_ready = true;
vvs->msg_count--;
}
virtio_transport_dec_rx_pkt(vvs, pkt);
list_del(&pkt->list);
virtio_transport_free_pkt(pkt);
}
spin_unlock_bh(&vvs->rx_lock);
virtio_transport_send_credit_update(vsk);
return dequeued_len;
}
ssize_t
virtio_transport_stream_dequeue(struct vsock_sock *vsk,
struct msghdr *msg,
......@@ -409,6 +489,38 @@ virtio_transport_stream_dequeue(struct vsock_sock *vsk,
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
ssize_t
virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk,
struct msghdr *msg,
int flags)
{
if (flags & MSG_PEEK)
return -EOPNOTSUPP;
return virtio_transport_seqpacket_do_dequeue(vsk, msg, flags);
}
EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_dequeue);
int
virtio_transport_seqpacket_enqueue(struct vsock_sock *vsk,
struct msghdr *msg,
size_t len)
{
struct virtio_vsock_sock *vvs = vsk->trans;
spin_lock_bh(&vvs->tx_lock);
if (len > vvs->peer_buf_alloc) {
spin_unlock_bh(&vvs->tx_lock);
return -EMSGSIZE;
}
spin_unlock_bh(&vvs->tx_lock);
return virtio_transport_stream_enqueue(vsk, msg, len);
}
EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_enqueue);
int
virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
struct msghdr *msg,
......@@ -431,6 +543,19 @@ s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
u32 virtio_transport_seqpacket_has_data(struct vsock_sock *vsk)
{
struct virtio_vsock_sock *vvs = vsk->trans;
u32 msg_count;
spin_lock_bh(&vvs->rx_lock);
msg_count = vvs->msg_count;
spin_unlock_bh(&vvs->rx_lock);
return msg_count;
}
EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_has_data);
static s64 virtio_transport_has_space(struct vsock_sock *vsk)
{
struct virtio_vsock_sock *vvs = vsk->trans;
......@@ -496,8 +621,7 @@ void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val)
vvs->buf_alloc = *val;
virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
NULL);
virtio_transport_send_credit_update(vsk);
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_buffer_size);
......@@ -624,7 +748,6 @@ int virtio_transport_connect(struct vsock_sock *vsk)
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_REQUEST,
.type = VIRTIO_VSOCK_TYPE_STREAM,
.vsk = vsk,
};
......@@ -636,7 +759,6 @@ int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_SHUTDOWN,
.type = VIRTIO_VSOCK_TYPE_STREAM,
.flags = (mode & RCV_SHUTDOWN ?
VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
(mode & SEND_SHUTDOWN ?
......@@ -665,7 +787,6 @@ virtio_transport_stream_enqueue(struct vsock_sock *vsk,
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_RW,
.type = VIRTIO_VSOCK_TYPE_STREAM,
.msg = msg,
.pkt_len = len,
.vsk = vsk,
......@@ -688,7 +809,6 @@ static int virtio_transport_reset(struct vsock_sock *vsk,
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_RST,
.type = VIRTIO_VSOCK_TYPE_STREAM,
.reply = !!pkt,
.vsk = vsk,
};
......@@ -848,7 +968,7 @@ void virtio_transport_release(struct vsock_sock *vsk)
struct sock *sk = &vsk->sk;
bool remove_sock = true;
if (sk->sk_type == SOCK_STREAM)
if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET)
remove_sock = virtio_transport_close(vsk);
if (remove_sock) {
......@@ -912,6 +1032,9 @@ virtio_transport_recv_enqueue(struct vsock_sock *vsk,
goto out;
}
if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOR)
vvs->msg_count++;
/* Try to copy small packets into the buffer of last packet queued,
* to avoid wasting memory queueing the entire buffer with a small
* payload.
......@@ -923,13 +1046,18 @@ virtio_transport_recv_enqueue(struct vsock_sock *vsk,
struct virtio_vsock_pkt, list);
/* If there is space in the last packet queued, we copy the
* new packet in its buffer.
* new packet in its buffer. We avoid this if the last packet
* queued has VIRTIO_VSOCK_SEQ_EOR set, because this is
* delimiter of SEQPACKET record, so 'pkt' is the first packet
* of a new record.
*/
if (pkt->len <= last_pkt->buf_len - last_pkt->len) {
if ((pkt->len <= last_pkt->buf_len - last_pkt->len) &&
!(le32_to_cpu(last_pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOR)) {
memcpy(last_pkt->buf + last_pkt->len, pkt->buf,
pkt->len);
last_pkt->len += pkt->len;
free_pkt = true;
last_pkt->hdr.flags |= pkt->hdr.flags;
goto out;
}
}
......@@ -1000,7 +1128,6 @@ virtio_transport_send_response(struct vsock_sock *vsk,
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_RESPONSE,
.type = VIRTIO_VSOCK_TYPE_STREAM,
.remote_cid = le64_to_cpu(pkt->hdr.src_cid),
.remote_port = le32_to_cpu(pkt->hdr.src_port),
.reply = true,
......@@ -1096,6 +1223,12 @@ virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt,
return 0;
}
static bool virtio_transport_valid_type(u16 type)
{
return (type == VIRTIO_VSOCK_TYPE_STREAM) ||
(type == VIRTIO_VSOCK_TYPE_SEQPACKET);
}
/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
* lock.
*/
......@@ -1121,7 +1254,7 @@ void virtio_transport_recv_pkt(struct virtio_transport *t,
le32_to_cpu(pkt->hdr.buf_alloc),
le32_to_cpu(pkt->hdr.fwd_cnt));
if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
if (!virtio_transport_valid_type(le16_to_cpu(pkt->hdr.type))) {
(void)virtio_transport_reset_no_sock(t, pkt);
goto free_pkt;
}
......@@ -1138,6 +1271,12 @@ void virtio_transport_recv_pkt(struct virtio_transport *t,
}
}
if (virtio_transport_get_type(sk) != le16_to_cpu(pkt->hdr.type)) {
(void)virtio_transport_reset_no_sock(t, pkt);
sock_put(sk);
goto free_pkt;
}
vsk = vsock_sk(sk);
lock_sock(sk);
......
......@@ -63,6 +63,8 @@ static int vsock_loopback_cancel_pkt(struct vsock_sock *vsk)
return 0;
}
static bool vsock_loopback_seqpacket_allow(u32 remote_cid);
static struct virtio_transport loopback_transport = {
.transport = {
.module = THIS_MODULE,
......@@ -89,6 +91,11 @@ static struct virtio_transport loopback_transport = {
.stream_is_active = virtio_transport_stream_is_active,
.stream_allow = virtio_transport_stream_allow,
.seqpacket_dequeue = virtio_transport_seqpacket_dequeue,
.seqpacket_enqueue = virtio_transport_seqpacket_enqueue,
.seqpacket_allow = vsock_loopback_seqpacket_allow,
.seqpacket_has_data = virtio_transport_seqpacket_has_data,
.notify_poll_in = virtio_transport_notify_poll_in,
.notify_poll_out = virtio_transport_notify_poll_out,
.notify_recv_init = virtio_transport_notify_recv_init,
......@@ -105,6 +112,11 @@ static struct virtio_transport loopback_transport = {
.send_pkt = vsock_loopback_send_pkt,
};
static bool vsock_loopback_seqpacket_allow(u32 remote_cid)
{
return true;
}
static void vsock_loopback_work(struct work_struct *work)
{
struct vsock_loopback *vsock =
......
......@@ -84,7 +84,7 @@ void vsock_wait_remote_close(int fd)
}
/* Connect to <cid, port> and return the file descriptor. */
int vsock_stream_connect(unsigned int cid, unsigned int port)
static int vsock_connect(unsigned int cid, unsigned int port, int type)
{
union {
struct sockaddr sa;
......@@ -101,7 +101,7 @@ int vsock_stream_connect(unsigned int cid, unsigned int port)
control_expectln("LISTENING");
fd = socket(AF_VSOCK, SOCK_STREAM, 0);
fd = socket(AF_VSOCK, type, 0);
timeout_begin(TIMEOUT);
do {
......@@ -120,11 +120,21 @@ int vsock_stream_connect(unsigned int cid, unsigned int port)
return fd;
}
int vsock_stream_connect(unsigned int cid, unsigned int port)
{
return vsock_connect(cid, port, SOCK_STREAM);
}
int vsock_seqpacket_connect(unsigned int cid, unsigned int port)
{
return vsock_connect(cid, port, SOCK_SEQPACKET);
}
/* Listen on <cid, port> and return the first incoming connection. The remote
* address is stored to clientaddrp. clientaddrp may be NULL.
*/
int vsock_stream_accept(unsigned int cid, unsigned int port,
struct sockaddr_vm *clientaddrp)
static int vsock_accept(unsigned int cid, unsigned int port,
struct sockaddr_vm *clientaddrp, int type)
{
union {
struct sockaddr sa;
......@@ -145,7 +155,7 @@ int vsock_stream_accept(unsigned int cid, unsigned int port,
int client_fd;
int old_errno;
fd = socket(AF_VSOCK, SOCK_STREAM, 0);
fd = socket(AF_VSOCK, type, 0);
if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
perror("bind");
......@@ -189,6 +199,18 @@ int vsock_stream_accept(unsigned int cid, unsigned int port,
return client_fd;
}
int vsock_stream_accept(unsigned int cid, unsigned int port,
struct sockaddr_vm *clientaddrp)
{
return vsock_accept(cid, port, clientaddrp, SOCK_STREAM);
}
int vsock_seqpacket_accept(unsigned int cid, unsigned int port,
struct sockaddr_vm *clientaddrp)
{
return vsock_accept(cid, port, clientaddrp, SOCK_SEQPACKET);
}
/* Transmit one byte and check the return value.
*
* expected_ret:
......
......@@ -36,8 +36,11 @@ struct test_case {
void init_signals(void);
unsigned int parse_cid(const char *str);
int vsock_stream_connect(unsigned int cid, unsigned int port);
int vsock_seqpacket_connect(unsigned int cid, unsigned int port);
int vsock_stream_accept(unsigned int cid, unsigned int port,
struct sockaddr_vm *clientaddrp);
int vsock_seqpacket_accept(unsigned int cid, unsigned int port,
struct sockaddr_vm *clientaddrp);
void vsock_wait_remote_close(int fd);
void send_byte(int fd, int expected_ret, int flags);
void recv_byte(int fd, int expected_ret, int flags);
......
......@@ -14,6 +14,8 @@
#include <errno.h>
#include <unistd.h>
#include <linux/kernel.h>
#include <sys/types.h>
#include <sys/socket.h>
#include "timeout.h"
#include "control.h"
......@@ -279,6 +281,110 @@ static void test_stream_msg_peek_server(const struct test_opts *opts)
close(fd);
}
#define MESSAGES_CNT 7
static void test_seqpacket_msg_bounds_client(const struct test_opts *opts)
{
int fd;
fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
if (fd < 0) {
perror("connect");
exit(EXIT_FAILURE);
}
/* Send several messages, one with MSG_EOR flag */
for (int i = 0; i < MESSAGES_CNT; i++)
send_byte(fd, 1, 0);
control_writeln("SENDDONE");
close(fd);
}
static void test_seqpacket_msg_bounds_server(const struct test_opts *opts)
{
int fd;
char buf[16];
struct msghdr msg = {0};
struct iovec iov = {0};
fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
if (fd < 0) {
perror("accept");
exit(EXIT_FAILURE);
}
control_expectln("SENDDONE");
iov.iov_base = buf;
iov.iov_len = sizeof(buf);
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
for (int i = 0; i < MESSAGES_CNT; i++) {
if (recvmsg(fd, &msg, 0) != 1) {
perror("message bound violated");
exit(EXIT_FAILURE);
}
}
close(fd);
}
#define MESSAGE_TRUNC_SZ 32
static void test_seqpacket_msg_trunc_client(const struct test_opts *opts)
{
int fd;
char buf[MESSAGE_TRUNC_SZ];
fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
if (fd < 0) {
perror("connect");
exit(EXIT_FAILURE);
}
if (send(fd, buf, sizeof(buf), 0) != sizeof(buf)) {
perror("send failed");
exit(EXIT_FAILURE);
}
control_writeln("SENDDONE");
close(fd);
}
static void test_seqpacket_msg_trunc_server(const struct test_opts *opts)
{
int fd;
char buf[MESSAGE_TRUNC_SZ / 2];
struct msghdr msg = {0};
struct iovec iov = {0};
fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
if (fd < 0) {
perror("accept");
exit(EXIT_FAILURE);
}
control_expectln("SENDDONE");
iov.iov_base = buf;
iov.iov_len = sizeof(buf);
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
ssize_t ret = recvmsg(fd, &msg, MSG_TRUNC);
if (ret != MESSAGE_TRUNC_SZ) {
printf("%zi\n", ret);
perror("MSG_TRUNC doesn't work");
exit(EXIT_FAILURE);
}
if (!(msg.msg_flags & MSG_TRUNC)) {
fprintf(stderr, "MSG_TRUNC expected\n");
exit(EXIT_FAILURE);
}
close(fd);
}
static struct test_case test_cases[] = {
{
.name = "SOCK_STREAM connection reset",
......@@ -309,6 +415,16 @@ static struct test_case test_cases[] = {
.run_client = test_stream_msg_peek_client,
.run_server = test_stream_msg_peek_server,
},
{
.name = "SOCK_SEQPACKET msg bounds",
.run_client = test_seqpacket_msg_bounds_client,
.run_server = test_seqpacket_msg_bounds_server,
},
{
.name = "SOCK_SEQPACKET MSG_TRUNC flag",
.run_client = test_seqpacket_msg_trunc_client,
.run_server = test_seqpacket_msg_trunc_server,
},
{},
};
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment