Commit 7344ba03 authored by David S. Miller's avatar David S. Miller

Merge branch 'vhost-skb-leaks'

Wei Xu says:

====================
vhost: fix a few skb leaks

Matthew found a roughly 40% tcp throughput regression with commit
c67df11f(vhost_net: try batch dequing from skb array) as discussed
in the following thread:
https://www.mail-archive.com/netdev@vger.kernel.org/msg187936.html

v4:
- fix zero iov iterator count in tap/tap_do_read()(Jason)
- don't put tun in case of EBADFD(Jason)
- Replace msg->msg_control with new 'skb' when calling tun/tap_do_read()

v3:
- move freeing skb from vhost to tun/tap recvmsg() to not
  confuse the callers.

v2:
- add Matthew as the reporter, thanks matthew.
- moving zero headcount check ahead instead of defer consuming skb
  due to jason and mst's comment.
- add freeing skb in favor of recvmsg() fails.
====================
Acked-by: default avatarMichael S. Tsirkin <mst@redhat.com>
Tested-by: default avatarMatthew Rosato <mjrosato@linux.vnet.ibm.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents fa935ca2 61d78537
...@@ -829,8 +829,11 @@ static ssize_t tap_do_read(struct tap_queue *q, ...@@ -829,8 +829,11 @@ static ssize_t tap_do_read(struct tap_queue *q,
DEFINE_WAIT(wait); DEFINE_WAIT(wait);
ssize_t ret = 0; ssize_t ret = 0;
if (!iov_iter_count(to)) if (!iov_iter_count(to)) {
if (skb)
kfree_skb(skb);
return 0; return 0;
}
if (skb) if (skb)
goto put; goto put;
...@@ -1154,11 +1157,14 @@ static int tap_recvmsg(struct socket *sock, struct msghdr *m, ...@@ -1154,11 +1157,14 @@ static int tap_recvmsg(struct socket *sock, struct msghdr *m,
size_t total_len, int flags) size_t total_len, int flags)
{ {
struct tap_queue *q = container_of(sock, struct tap_queue, sock); struct tap_queue *q = container_of(sock, struct tap_queue, sock);
struct sk_buff *skb = m->msg_control;
int ret; int ret;
if (flags & ~(MSG_DONTWAIT|MSG_TRUNC)) if (flags & ~(MSG_DONTWAIT|MSG_TRUNC)) {
if (skb)
kfree_skb(skb);
return -EINVAL; return -EINVAL;
ret = tap_do_read(q, &m->msg_iter, flags & MSG_DONTWAIT, }
m->msg_control); ret = tap_do_read(q, &m->msg_iter, flags & MSG_DONTWAIT, skb);
if (ret > total_len) { if (ret > total_len) {
m->msg_flags |= MSG_TRUNC; m->msg_flags |= MSG_TRUNC;
ret = flags & MSG_TRUNC ? ret : total_len; ret = flags & MSG_TRUNC ? ret : total_len;
......
...@@ -1952,8 +1952,11 @@ static ssize_t tun_do_read(struct tun_struct *tun, struct tun_file *tfile, ...@@ -1952,8 +1952,11 @@ static ssize_t tun_do_read(struct tun_struct *tun, struct tun_file *tfile,
tun_debug(KERN_INFO, tun, "tun_do_read\n"); tun_debug(KERN_INFO, tun, "tun_do_read\n");
if (!iov_iter_count(to)) if (!iov_iter_count(to)) {
if (skb)
kfree_skb(skb);
return 0; return 0;
}
if (!skb) { if (!skb) {
/* Read frames from ring */ /* Read frames from ring */
...@@ -2069,22 +2072,24 @@ static int tun_recvmsg(struct socket *sock, struct msghdr *m, size_t total_len, ...@@ -2069,22 +2072,24 @@ static int tun_recvmsg(struct socket *sock, struct msghdr *m, size_t total_len,
{ {
struct tun_file *tfile = container_of(sock, struct tun_file, socket); struct tun_file *tfile = container_of(sock, struct tun_file, socket);
struct tun_struct *tun = tun_get(tfile); struct tun_struct *tun = tun_get(tfile);
struct sk_buff *skb = m->msg_control;
int ret; int ret;
if (!tun) if (!tun) {
return -EBADFD; ret = -EBADFD;
goto out_free_skb;
}
if (flags & ~(MSG_DONTWAIT|MSG_TRUNC|MSG_ERRQUEUE)) { if (flags & ~(MSG_DONTWAIT|MSG_TRUNC|MSG_ERRQUEUE)) {
ret = -EINVAL; ret = -EINVAL;
goto out; goto out_put_tun;
} }
if (flags & MSG_ERRQUEUE) { if (flags & MSG_ERRQUEUE) {
ret = sock_recv_errqueue(sock->sk, m, total_len, ret = sock_recv_errqueue(sock->sk, m, total_len,
SOL_PACKET, TUN_TX_TIMESTAMP); SOL_PACKET, TUN_TX_TIMESTAMP);
goto out; goto out;
} }
ret = tun_do_read(tun, tfile, &m->msg_iter, flags & MSG_DONTWAIT, ret = tun_do_read(tun, tfile, &m->msg_iter, flags & MSG_DONTWAIT, skb);
m->msg_control);
if (ret > (ssize_t)total_len) { if (ret > (ssize_t)total_len) {
m->msg_flags |= MSG_TRUNC; m->msg_flags |= MSG_TRUNC;
ret = flags & MSG_TRUNC ? ret : total_len; ret = flags & MSG_TRUNC ? ret : total_len;
...@@ -2092,6 +2097,13 @@ static int tun_recvmsg(struct socket *sock, struct msghdr *m, size_t total_len, ...@@ -2092,6 +2097,13 @@ static int tun_recvmsg(struct socket *sock, struct msghdr *m, size_t total_len,
out: out:
tun_put(tun); tun_put(tun);
return ret; return ret;
out_put_tun:
tun_put(tun);
out_free_skb:
if (skb)
kfree_skb(skb);
return ret;
} }
static int tun_peek_len(struct socket *sock) static int tun_peek_len(struct socket *sock)
......
...@@ -778,16 +778,6 @@ static void handle_rx(struct vhost_net *net) ...@@ -778,16 +778,6 @@ static void handle_rx(struct vhost_net *net)
/* On error, stop handling until the next kick. */ /* On error, stop handling until the next kick. */
if (unlikely(headcount < 0)) if (unlikely(headcount < 0))
goto out; goto out;
if (nvq->rx_array)
msg.msg_control = vhost_net_buf_consume(&nvq->rxq);
/* On overrun, truncate and discard */
if (unlikely(headcount > UIO_MAXIOV)) {
iov_iter_init(&msg.msg_iter, READ, vq->iov, 1, 1);
err = sock->ops->recvmsg(sock, &msg,
1, MSG_DONTWAIT | MSG_TRUNC);
pr_debug("Discarded rx packet: len %zd\n", sock_len);
continue;
}
/* OK, now we need to know about added descriptors. */ /* OK, now we need to know about added descriptors. */
if (!headcount) { if (!headcount) {
if (unlikely(vhost_enable_notify(&net->dev, vq))) { if (unlikely(vhost_enable_notify(&net->dev, vq))) {
...@@ -800,6 +790,16 @@ static void handle_rx(struct vhost_net *net) ...@@ -800,6 +790,16 @@ static void handle_rx(struct vhost_net *net)
* they refilled. */ * they refilled. */
goto out; goto out;
} }
if (nvq->rx_array)
msg.msg_control = vhost_net_buf_consume(&nvq->rxq);
/* On overrun, truncate and discard */
if (unlikely(headcount > UIO_MAXIOV)) {
iov_iter_init(&msg.msg_iter, READ, vq->iov, 1, 1);
err = sock->ops->recvmsg(sock, &msg,
1, MSG_DONTWAIT | MSG_TRUNC);
pr_debug("Discarded rx packet: len %zd\n", sock_len);
continue;
}
/* We don't need to be notified again. */ /* We don't need to be notified again. */
iov_iter_init(&msg.msg_iter, READ, vq->iov, in, vhost_len); iov_iter_init(&msg.msg_iter, READ, vq->iov, in, vhost_len);
fixup = msg.msg_iter; fixup = msg.msg_iter;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment