Commit 6a5d39aa authored by David S. Miller's avatar David S. Miller

Merge git://git.kernel.org/pub/scm/linux/kernel/git/bpf/bpf

Daniel Borkmann says:

====================
pull-request: bpf 2018-08-29

The following pull-request contains BPF updates for your *net* tree.

The main changes are:

1) Fix a build error in sk_reuseport_convert_ctx_access() when
   compiling with clang which cannot resolve hweight_long() at
   build time inside the BUILD_BUG_ON() assertion, from Stefan.

2) Several fixes for BPF sockmap, four of them in getting the
   bpf_msg_pull_data() helper to work, one use after free case
   in bpf_tcp_close() and one refcount leak in bpf_tcp_recvmsg(),
   from Daniel.

3) Another fix for BPF sockmap where we misaccount sk_mem_uncharge()
   in the socket redirect error case from unwinding scatterlist
   twice, from John.
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 53ae914d d65e6c80
...@@ -236,7 +236,7 @@ static int bpf_tcp_init(struct sock *sk) ...@@ -236,7 +236,7 @@ static int bpf_tcp_init(struct sock *sk)
} }
static void smap_release_sock(struct smap_psock *psock, struct sock *sock); static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
static int free_start_sg(struct sock *sk, struct sk_msg_buff *md); static int free_start_sg(struct sock *sk, struct sk_msg_buff *md, bool charge);
static void bpf_tcp_release(struct sock *sk) static void bpf_tcp_release(struct sock *sk)
{ {
...@@ -248,7 +248,7 @@ static void bpf_tcp_release(struct sock *sk) ...@@ -248,7 +248,7 @@ static void bpf_tcp_release(struct sock *sk)
goto out; goto out;
if (psock->cork) { if (psock->cork) {
free_start_sg(psock->sock, psock->cork); free_start_sg(psock->sock, psock->cork, true);
kfree(psock->cork); kfree(psock->cork);
psock->cork = NULL; psock->cork = NULL;
} }
...@@ -330,14 +330,14 @@ static void bpf_tcp_close(struct sock *sk, long timeout) ...@@ -330,14 +330,14 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
close_fun = psock->save_close; close_fun = psock->save_close;
if (psock->cork) { if (psock->cork) {
free_start_sg(psock->sock, psock->cork); free_start_sg(psock->sock, psock->cork, true);
kfree(psock->cork); kfree(psock->cork);
psock->cork = NULL; psock->cork = NULL;
} }
list_for_each_entry_safe(md, mtmp, &psock->ingress, list) { list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
list_del(&md->list); list_del(&md->list);
free_start_sg(psock->sock, md); free_start_sg(psock->sock, md, true);
kfree(md); kfree(md);
} }
...@@ -369,7 +369,7 @@ static void bpf_tcp_close(struct sock *sk, long timeout) ...@@ -369,7 +369,7 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
/* If another thread deleted this object skip deletion. /* If another thread deleted this object skip deletion.
* The refcnt on psock may or may not be zero. * The refcnt on psock may or may not be zero.
*/ */
if (l) { if (l && l == link) {
hlist_del_rcu(&link->hash_node); hlist_del_rcu(&link->hash_node);
smap_release_sock(psock, link->sk); smap_release_sock(psock, link->sk);
free_htab_elem(htab, link); free_htab_elem(htab, link);
...@@ -570,14 +570,16 @@ static void free_bytes_sg(struct sock *sk, int bytes, ...@@ -570,14 +570,16 @@ static void free_bytes_sg(struct sock *sk, int bytes,
md->sg_start = i; md->sg_start = i;
} }
static int free_sg(struct sock *sk, int start, struct sk_msg_buff *md) static int free_sg(struct sock *sk, int start,
struct sk_msg_buff *md, bool charge)
{ {
struct scatterlist *sg = md->sg_data; struct scatterlist *sg = md->sg_data;
int i = start, free = 0; int i = start, free = 0;
while (sg[i].length) { while (sg[i].length) {
free += sg[i].length; free += sg[i].length;
sk_mem_uncharge(sk, sg[i].length); if (charge)
sk_mem_uncharge(sk, sg[i].length);
if (!md->skb) if (!md->skb)
put_page(sg_page(&sg[i])); put_page(sg_page(&sg[i]));
sg[i].length = 0; sg[i].length = 0;
...@@ -594,9 +596,9 @@ static int free_sg(struct sock *sk, int start, struct sk_msg_buff *md) ...@@ -594,9 +596,9 @@ static int free_sg(struct sock *sk, int start, struct sk_msg_buff *md)
return free; return free;
} }
static int free_start_sg(struct sock *sk, struct sk_msg_buff *md) static int free_start_sg(struct sock *sk, struct sk_msg_buff *md, bool charge)
{ {
int free = free_sg(sk, md->sg_start, md); int free = free_sg(sk, md->sg_start, md, charge);
md->sg_start = md->sg_end; md->sg_start = md->sg_end;
return free; return free;
...@@ -604,7 +606,7 @@ static int free_start_sg(struct sock *sk, struct sk_msg_buff *md) ...@@ -604,7 +606,7 @@ static int free_start_sg(struct sock *sk, struct sk_msg_buff *md)
static int free_curr_sg(struct sock *sk, struct sk_msg_buff *md) static int free_curr_sg(struct sock *sk, struct sk_msg_buff *md)
{ {
return free_sg(sk, md->sg_curr, md); return free_sg(sk, md->sg_curr, md, true);
} }
static int bpf_map_msg_verdict(int _rc, struct sk_msg_buff *md) static int bpf_map_msg_verdict(int _rc, struct sk_msg_buff *md)
...@@ -718,7 +720,7 @@ static int bpf_tcp_ingress(struct sock *sk, int apply_bytes, ...@@ -718,7 +720,7 @@ static int bpf_tcp_ingress(struct sock *sk, int apply_bytes,
list_add_tail(&r->list, &psock->ingress); list_add_tail(&r->list, &psock->ingress);
sk->sk_data_ready(sk); sk->sk_data_ready(sk);
} else { } else {
free_start_sg(sk, r); free_start_sg(sk, r, true);
kfree(r); kfree(r);
} }
...@@ -752,14 +754,10 @@ static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send, ...@@ -752,14 +754,10 @@ static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
release_sock(sk); release_sock(sk);
} }
smap_release_sock(psock, sk); smap_release_sock(psock, sk);
if (unlikely(err)) return err;
goto out;
return 0;
out_rcu: out_rcu:
rcu_read_unlock(); rcu_read_unlock();
out: return 0;
free_bytes_sg(NULL, send, md, false);
return err;
} }
static inline void bpf_md_init(struct smap_psock *psock) static inline void bpf_md_init(struct smap_psock *psock)
...@@ -822,7 +820,7 @@ static int bpf_exec_tx_verdict(struct smap_psock *psock, ...@@ -822,7 +820,7 @@ static int bpf_exec_tx_verdict(struct smap_psock *psock,
case __SK_PASS: case __SK_PASS:
err = bpf_tcp_push(sk, send, m, flags, true); err = bpf_tcp_push(sk, send, m, flags, true);
if (unlikely(err)) { if (unlikely(err)) {
*copied -= free_start_sg(sk, m); *copied -= free_start_sg(sk, m, true);
break; break;
} }
...@@ -845,16 +843,17 @@ static int bpf_exec_tx_verdict(struct smap_psock *psock, ...@@ -845,16 +843,17 @@ static int bpf_exec_tx_verdict(struct smap_psock *psock,
lock_sock(sk); lock_sock(sk);
if (unlikely(err < 0)) { if (unlikely(err < 0)) {
free_start_sg(sk, m); int free = free_start_sg(sk, m, false);
psock->sg_size = 0; psock->sg_size = 0;
if (!cork) if (!cork)
*copied -= send; *copied -= free;
} else { } else {
psock->sg_size -= send; psock->sg_size -= send;
} }
if (cork) { if (cork) {
free_start_sg(sk, m); free_start_sg(sk, m, true);
psock->sg_size = 0; psock->sg_size = 0;
kfree(m); kfree(m);
m = NULL; m = NULL;
...@@ -912,6 +911,8 @@ static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, ...@@ -912,6 +911,8 @@ static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
if (unlikely(flags & MSG_ERRQUEUE)) if (unlikely(flags & MSG_ERRQUEUE))
return inet_recv_error(sk, msg, len, addr_len); return inet_recv_error(sk, msg, len, addr_len);
if (!skb_queue_empty(&sk->sk_receive_queue))
return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
rcu_read_lock(); rcu_read_lock();
psock = smap_psock_sk(sk); psock = smap_psock_sk(sk);
...@@ -922,9 +923,6 @@ static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, ...@@ -922,9 +923,6 @@ static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
goto out; goto out;
rcu_read_unlock(); rcu_read_unlock();
if (!skb_queue_empty(&sk->sk_receive_queue))
return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
lock_sock(sk); lock_sock(sk);
bytes_ready: bytes_ready:
while (copied != len) { while (copied != len) {
...@@ -1122,7 +1120,7 @@ static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) ...@@ -1122,7 +1120,7 @@ static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
err = sk_stream_wait_memory(sk, &timeo); err = sk_stream_wait_memory(sk, &timeo);
if (err) { if (err) {
if (m && m != psock->cork) if (m && m != psock->cork)
free_start_sg(sk, m); free_start_sg(sk, m, true);
goto out_err; goto out_err;
} }
} }
...@@ -1581,13 +1579,13 @@ static void smap_gc_work(struct work_struct *w) ...@@ -1581,13 +1579,13 @@ static void smap_gc_work(struct work_struct *w)
bpf_prog_put(psock->bpf_tx_msg); bpf_prog_put(psock->bpf_tx_msg);
if (psock->cork) { if (psock->cork) {
free_start_sg(psock->sock, psock->cork); free_start_sg(psock->sock, psock->cork, true);
kfree(psock->cork); kfree(psock->cork);
} }
list_for_each_entry_safe(md, mtmp, &psock->ingress, list) { list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
list_del(&md->list); list_del(&md->list);
free_start_sg(psock->sock, md); free_start_sg(psock->sock, md, true);
kfree(md); kfree(md);
} }
......
...@@ -2282,14 +2282,21 @@ static const struct bpf_func_proto bpf_msg_cork_bytes_proto = { ...@@ -2282,14 +2282,21 @@ static const struct bpf_func_proto bpf_msg_cork_bytes_proto = {
.arg2_type = ARG_ANYTHING, .arg2_type = ARG_ANYTHING,
}; };
#define sk_msg_iter_var(var) \
do { \
var++; \
if (var == MAX_SKB_FRAGS) \
var = 0; \
} while (0)
BPF_CALL_4(bpf_msg_pull_data, BPF_CALL_4(bpf_msg_pull_data,
struct sk_msg_buff *, msg, u32, start, u32, end, u64, flags) struct sk_msg_buff *, msg, u32, start, u32, end, u64, flags)
{ {
unsigned int len = 0, offset = 0, copy = 0; unsigned int len = 0, offset = 0, copy = 0;
int bytes = end - start, bytes_sg_total;
struct scatterlist *sg = msg->sg_data; struct scatterlist *sg = msg->sg_data;
int first_sg, last_sg, i, shift; int first_sg, last_sg, i, shift;
unsigned char *p, *to, *from; unsigned char *p, *to, *from;
int bytes = end - start;
struct page *page; struct page *page;
if (unlikely(flags || end <= start)) if (unlikely(flags || end <= start))
...@@ -2299,21 +2306,22 @@ BPF_CALL_4(bpf_msg_pull_data, ...@@ -2299,21 +2306,22 @@ BPF_CALL_4(bpf_msg_pull_data,
i = msg->sg_start; i = msg->sg_start;
do { do {
len = sg[i].length; len = sg[i].length;
offset += len;
if (start < offset + len) if (start < offset + len)
break; break;
i++; offset += len;
if (i == MAX_SKB_FRAGS) sk_msg_iter_var(i);
i = 0;
} while (i != msg->sg_end); } while (i != msg->sg_end);
if (unlikely(start >= offset + len)) if (unlikely(start >= offset + len))
return -EINVAL; return -EINVAL;
if (!msg->sg_copy[i] && bytes <= len)
goto out;
first_sg = i; first_sg = i;
/* The start may point into the sg element so we need to also
* account for the headroom.
*/
bytes_sg_total = start - offset + bytes;
if (!msg->sg_copy[i] && bytes_sg_total <= len)
goto out;
/* At this point we need to linearize multiple scatterlist /* At this point we need to linearize multiple scatterlist
* elements or a single shared page. Either way we need to * elements or a single shared page. Either way we need to
...@@ -2327,15 +2335,13 @@ BPF_CALL_4(bpf_msg_pull_data, ...@@ -2327,15 +2335,13 @@ BPF_CALL_4(bpf_msg_pull_data,
*/ */
do { do {
copy += sg[i].length; copy += sg[i].length;
i++; sk_msg_iter_var(i);
if (i == MAX_SKB_FRAGS) if (bytes_sg_total <= copy)
i = 0;
if (bytes < copy)
break; break;
} while (i != msg->sg_end); } while (i != msg->sg_end);
last_sg = i; last_sg = i;
if (unlikely(copy < end - start)) if (unlikely(bytes_sg_total > copy))
return -EINVAL; return -EINVAL;
page = alloc_pages(__GFP_NOWARN | GFP_ATOMIC, get_order(copy)); page = alloc_pages(__GFP_NOWARN | GFP_ATOMIC, get_order(copy));
...@@ -2355,9 +2361,7 @@ BPF_CALL_4(bpf_msg_pull_data, ...@@ -2355,9 +2361,7 @@ BPF_CALL_4(bpf_msg_pull_data,
sg[i].length = 0; sg[i].length = 0;
put_page(sg_page(&sg[i])); put_page(sg_page(&sg[i]));
i++; sk_msg_iter_var(i);
if (i == MAX_SKB_FRAGS)
i = 0;
} while (i != last_sg); } while (i != last_sg);
sg[first_sg].length = copy; sg[first_sg].length = copy;
...@@ -2367,11 +2371,15 @@ BPF_CALL_4(bpf_msg_pull_data, ...@@ -2367,11 +2371,15 @@ BPF_CALL_4(bpf_msg_pull_data,
* had a single entry though we can just replace it and * had a single entry though we can just replace it and
* be done. Otherwise walk the ring and shift the entries. * be done. Otherwise walk the ring and shift the entries.
*/ */
shift = last_sg - first_sg - 1; WARN_ON_ONCE(last_sg == first_sg);
shift = last_sg > first_sg ?
last_sg - first_sg - 1 :
MAX_SKB_FRAGS - first_sg + last_sg - 1;
if (!shift) if (!shift)
goto out; goto out;
i = first_sg + 1; i = first_sg;
sk_msg_iter_var(i);
do { do {
int move_from; int move_from;
...@@ -2388,15 +2396,13 @@ BPF_CALL_4(bpf_msg_pull_data, ...@@ -2388,15 +2396,13 @@ BPF_CALL_4(bpf_msg_pull_data,
sg[move_from].page_link = 0; sg[move_from].page_link = 0;
sg[move_from].offset = 0; sg[move_from].offset = 0;
i++; sk_msg_iter_var(i);
if (i == MAX_SKB_FRAGS)
i = 0;
} while (1); } while (1);
msg->sg_end -= shift; msg->sg_end -= shift;
if (msg->sg_end < 0) if (msg->sg_end < 0)
msg->sg_end += MAX_SKB_FRAGS; msg->sg_end += MAX_SKB_FRAGS;
out: out:
msg->data = sg_virt(&sg[i]) + start - offset; msg->data = sg_virt(&sg[first_sg]) + start - offset;
msg->data_end = msg->data + bytes; msg->data_end = msg->data + bytes;
return 0; return 0;
...@@ -7281,7 +7287,7 @@ static u32 sk_reuseport_convert_ctx_access(enum bpf_access_type type, ...@@ -7281,7 +7287,7 @@ static u32 sk_reuseport_convert_ctx_access(enum bpf_access_type type,
break; break;
case offsetof(struct sk_reuseport_md, ip_protocol): case offsetof(struct sk_reuseport_md, ip_protocol):
BUILD_BUG_ON(hweight_long(SK_FL_PROTO_MASK) != BITS_PER_BYTE); BUILD_BUG_ON(HWEIGHT32(SK_FL_PROTO_MASK) != BITS_PER_BYTE);
SK_REUSEPORT_LOAD_SK_FIELD_SIZE_OFF(__sk_flags_offset, SK_REUSEPORT_LOAD_SK_FIELD_SIZE_OFF(__sk_flags_offset,
BPF_W, 0); BPF_W, 0);
*insn++ = BPF_ALU32_IMM(BPF_AND, si->dst_reg, SK_FL_PROTO_MASK); *insn++ = BPF_ALU32_IMM(BPF_AND, si->dst_reg, SK_FL_PROTO_MASK);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment