Commit fda9ef5d authored by Dmitry Mishin's avatar Dmitry Mishin Committed by David S. Miller

[NET]: Fix sk->sk_filter field access

Function sk_filter() is called from tcp_v{4,6}_rcv() functions with arg
needlock = 0, while socket is not locked at that moment. In order to avoid
this and similar issues in the future, use rcu for sk->sk_filter field read
protection.
Signed-off-by: default avatarDmitry Mishin <dim@openvz.org>
Signed-off-by: default avatarAlexey Kuznetsov <kuznet@ms2.inr.ac.ru>
Signed-off-by: default avatarKirill Korotaev <dev@openvz.org>
parent dc435e6d
...@@ -25,10 +25,10 @@ ...@@ -25,10 +25,10 @@
struct sock_filter /* Filter block */ struct sock_filter /* Filter block */
{ {
__u16 code; /* Actual filter code */ __u16 code; /* Actual filter code */
__u8 jt; /* Jump true */ __u8 jt; /* Jump true */
__u8 jf; /* Jump false */ __u8 jf; /* Jump false */
__u32 k; /* Generic multiuse field */ __u32 k; /* Generic multiuse field */
}; };
struct sock_fprog /* Required for SO_ATTACH_FILTER. */ struct sock_fprog /* Required for SO_ATTACH_FILTER. */
...@@ -41,8 +41,9 @@ struct sock_fprog /* Required for SO_ATTACH_FILTER. */ ...@@ -41,8 +41,9 @@ struct sock_fprog /* Required for SO_ATTACH_FILTER. */
struct sk_filter struct sk_filter
{ {
atomic_t refcnt; atomic_t refcnt;
unsigned int len; /* Number of filter blocks */ unsigned int len; /* Number of filter blocks */
struct sock_filter insns[0]; struct rcu_head rcu;
struct sock_filter insns[0];
}; };
static inline unsigned int sk_filter_len(struct sk_filter *fp) static inline unsigned int sk_filter_len(struct sk_filter *fp)
......
...@@ -862,30 +862,24 @@ extern void sock_init_data(struct socket *sock, struct sock *sk); ...@@ -862,30 +862,24 @@ extern void sock_init_data(struct socket *sock, struct sock *sk);
* *
*/ */
static inline int sk_filter(struct sock *sk, struct sk_buff *skb, int needlock) static inline int sk_filter(struct sock *sk, struct sk_buff *skb)
{ {
int err; int err;
struct sk_filter *filter;
err = security_sock_rcv_skb(sk, skb); err = security_sock_rcv_skb(sk, skb);
if (err) if (err)
return err; return err;
if (sk->sk_filter) { rcu_read_lock_bh();
struct sk_filter *filter; filter = sk->sk_filter;
if (filter) {
if (needlock) unsigned int pkt_len = sk_run_filter(skb, filter->insns,
bh_lock_sock(sk); filter->len);
err = pkt_len ? pskb_trim(skb, pkt_len) : -EPERM;
filter = sk->sk_filter;
if (filter) {
unsigned int pkt_len = sk_run_filter(skb, filter->insns,
filter->len);
err = pkt_len ? pskb_trim(skb, pkt_len) : -EPERM;
}
if (needlock)
bh_unlock_sock(sk);
} }
rcu_read_unlock_bh();
return err; return err;
} }
...@@ -897,6 +891,12 @@ static inline int sk_filter(struct sock *sk, struct sk_buff *skb, int needlock) ...@@ -897,6 +891,12 @@ static inline int sk_filter(struct sock *sk, struct sk_buff *skb, int needlock)
* Remove a filter from a socket and release its resources. * Remove a filter from a socket and release its resources.
*/ */
static inline void sk_filter_rcu_free(struct rcu_head *rcu)
{
struct sk_filter *fp = container_of(rcu, struct sk_filter, rcu);
kfree(fp);
}
static inline void sk_filter_release(struct sock *sk, struct sk_filter *fp) static inline void sk_filter_release(struct sock *sk, struct sk_filter *fp)
{ {
unsigned int size = sk_filter_len(fp); unsigned int size = sk_filter_len(fp);
...@@ -904,7 +904,7 @@ static inline void sk_filter_release(struct sock *sk, struct sk_filter *fp) ...@@ -904,7 +904,7 @@ static inline void sk_filter_release(struct sock *sk, struct sk_filter *fp)
atomic_sub(size, &sk->sk_omem_alloc); atomic_sub(size, &sk->sk_omem_alloc);
if (atomic_dec_and_test(&fp->refcnt)) if (atomic_dec_and_test(&fp->refcnt))
kfree(fp); call_rcu_bh(&fp->rcu, sk_filter_rcu_free);
} }
static inline void sk_filter_charge(struct sock *sk, struct sk_filter *fp) static inline void sk_filter_charge(struct sock *sk, struct sk_filter *fp)
......
...@@ -422,10 +422,10 @@ int sk_attach_filter(struct sock_fprog *fprog, struct sock *sk) ...@@ -422,10 +422,10 @@ int sk_attach_filter(struct sock_fprog *fprog, struct sock *sk)
if (!err) { if (!err) {
struct sk_filter *old_fp; struct sk_filter *old_fp;
spin_lock_bh(&sk->sk_lock.slock); rcu_read_lock_bh();
old_fp = sk->sk_filter; old_fp = rcu_dereference(sk->sk_filter);
sk->sk_filter = fp; rcu_assign_pointer(sk->sk_filter, fp);
spin_unlock_bh(&sk->sk_lock.slock); rcu_read_unlock_bh();
fp = old_fp; fp = old_fp;
} }
......
...@@ -247,11 +247,7 @@ int sock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) ...@@ -247,11 +247,7 @@ int sock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
goto out; goto out;
} }
/* It would be deadlock, if sock_queue_rcv_skb is used err = sk_filter(sk, skb);
with socket lock! We assume that users of this
function are lock free.
*/
err = sk_filter(sk, skb, 1);
if (err) if (err)
goto out; goto out;
...@@ -278,7 +274,7 @@ int sk_receive_skb(struct sock *sk, struct sk_buff *skb) ...@@ -278,7 +274,7 @@ int sk_receive_skb(struct sock *sk, struct sk_buff *skb)
{ {
int rc = NET_RX_SUCCESS; int rc = NET_RX_SUCCESS;
if (sk_filter(sk, skb, 0)) if (sk_filter(sk, skb))
goto discard_and_relse; goto discard_and_relse;
skb->dev = NULL; skb->dev = NULL;
...@@ -606,15 +602,15 @@ int sock_setsockopt(struct socket *sock, int level, int optname, ...@@ -606,15 +602,15 @@ int sock_setsockopt(struct socket *sock, int level, int optname,
break; break;
case SO_DETACH_FILTER: case SO_DETACH_FILTER:
spin_lock_bh(&sk->sk_lock.slock); rcu_read_lock_bh();
filter = sk->sk_filter; filter = rcu_dereference(sk->sk_filter);
if (filter) { if (filter) {
sk->sk_filter = NULL; rcu_assign_pointer(sk->sk_filter, NULL);
spin_unlock_bh(&sk->sk_lock.slock);
sk_filter_release(sk, filter); sk_filter_release(sk, filter);
rcu_read_unlock_bh();
break; break;
} }
spin_unlock_bh(&sk->sk_lock.slock); rcu_read_unlock_bh();
ret = -ENONET; ret = -ENONET;
break; break;
...@@ -884,10 +880,10 @@ void sk_free(struct sock *sk) ...@@ -884,10 +880,10 @@ void sk_free(struct sock *sk)
if (sk->sk_destruct) if (sk->sk_destruct)
sk->sk_destruct(sk); sk->sk_destruct(sk);
filter = sk->sk_filter; filter = rcu_dereference(sk->sk_filter);
if (filter) { if (filter) {
sk_filter_release(sk, filter); sk_filter_release(sk, filter);
sk->sk_filter = NULL; rcu_assign_pointer(sk->sk_filter, NULL);
} }
sock_disable_timestamp(sk); sock_disable_timestamp(sk);
......
...@@ -970,7 +970,7 @@ static int dccp_v6_do_rcv(struct sock *sk, struct sk_buff *skb) ...@@ -970,7 +970,7 @@ static int dccp_v6_do_rcv(struct sock *sk, struct sk_buff *skb)
if (skb->protocol == htons(ETH_P_IP)) if (skb->protocol == htons(ETH_P_IP))
return dccp_v4_do_rcv(sk, skb); return dccp_v4_do_rcv(sk, skb);
if (sk_filter(sk, skb, 0)) if (sk_filter(sk, skb))
goto discard; goto discard;
/* /*
......
...@@ -586,7 +586,7 @@ static __inline__ int dn_queue_skb(struct sock *sk, struct sk_buff *skb, int sig ...@@ -586,7 +586,7 @@ static __inline__ int dn_queue_skb(struct sock *sk, struct sk_buff *skb, int sig
goto out; goto out;
} }
err = sk_filter(sk, skb, 0); err = sk_filter(sk, skb);
if (err) if (err)
goto out; goto out;
......
...@@ -1104,7 +1104,7 @@ int tcp_v4_rcv(struct sk_buff *skb) ...@@ -1104,7 +1104,7 @@ int tcp_v4_rcv(struct sk_buff *skb)
goto discard_and_relse; goto discard_and_relse;
nf_reset(skb); nf_reset(skb);
if (sk_filter(sk, skb, 0)) if (sk_filter(sk, skb))
goto discard_and_relse; goto discard_and_relse;
skb->dev = NULL; skb->dev = NULL;
......
...@@ -1075,7 +1075,7 @@ static int tcp_v6_do_rcv(struct sock *sk, struct sk_buff *skb) ...@@ -1075,7 +1075,7 @@ static int tcp_v6_do_rcv(struct sock *sk, struct sk_buff *skb)
if (skb->protocol == htons(ETH_P_IP)) if (skb->protocol == htons(ETH_P_IP))
return tcp_v4_do_rcv(sk, skb); return tcp_v4_do_rcv(sk, skb);
if (sk_filter(sk, skb, 0)) if (sk_filter(sk, skb))
goto discard; goto discard;
/* /*
...@@ -1232,7 +1232,7 @@ static int tcp_v6_rcv(struct sk_buff **pskb) ...@@ -1232,7 +1232,7 @@ static int tcp_v6_rcv(struct sk_buff **pskb)
if (!xfrm6_policy_check(sk, XFRM_POLICY_IN, skb)) if (!xfrm6_policy_check(sk, XFRM_POLICY_IN, skb))
goto discard_and_relse; goto discard_and_relse;
if (sk_filter(sk, skb, 0)) if (sk_filter(sk, skb))
goto discard_and_relse; goto discard_and_relse;
skb->dev = NULL; skb->dev = NULL;
......
...@@ -427,21 +427,24 @@ static int packet_sendmsg_spkt(struct kiocb *iocb, struct socket *sock, ...@@ -427,21 +427,24 @@ static int packet_sendmsg_spkt(struct kiocb *iocb, struct socket *sock,
} }
#endif #endif
static inline unsigned run_filter(struct sk_buff *skb, struct sock *sk, unsigned res) static inline int run_filter(struct sk_buff *skb, struct sock *sk,
unsigned *snaplen)
{ {
struct sk_filter *filter; struct sk_filter *filter;
int err = 0;
bh_lock_sock(sk); rcu_read_lock_bh();
filter = sk->sk_filter; filter = rcu_dereference(sk->sk_filter);
/* if (filter != NULL) {
* Our caller already checked that filter != NULL but we need to err = sk_run_filter(skb, filter->insns, filter->len);
* verify that under bh_lock_sock() to be safe if (!err)
*/ err = -EPERM;
if (likely(filter != NULL)) else if (*snaplen > err)
res = sk_run_filter(skb, filter->insns, filter->len); *snaplen = err;
bh_unlock_sock(sk); }
rcu_read_unlock_bh();
return res; return err;
} }
/* /*
...@@ -491,13 +494,8 @@ static int packet_rcv(struct sk_buff *skb, struct net_device *dev, struct packet ...@@ -491,13 +494,8 @@ static int packet_rcv(struct sk_buff *skb, struct net_device *dev, struct packet
snaplen = skb->len; snaplen = skb->len;
if (sk->sk_filter) { if (run_filter(skb, sk, &snaplen) < 0)
unsigned res = run_filter(skb, sk, snaplen); goto drop_n_restore;
if (res == 0)
goto drop_n_restore;
if (snaplen > res)
snaplen = res;
}
if (atomic_read(&sk->sk_rmem_alloc) + skb->truesize >= if (atomic_read(&sk->sk_rmem_alloc) + skb->truesize >=
(unsigned)sk->sk_rcvbuf) (unsigned)sk->sk_rcvbuf)
...@@ -593,13 +591,8 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev, struct packe ...@@ -593,13 +591,8 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev, struct packe
snaplen = skb->len; snaplen = skb->len;
if (sk->sk_filter) { if (run_filter(skb, sk, &snaplen) < 0)
unsigned res = run_filter(skb, sk, snaplen); goto drop_n_restore;
if (res == 0)
goto drop_n_restore;
if (snaplen > res)
snaplen = res;
}
if (sk->sk_type == SOCK_DGRAM) { if (sk->sk_type == SOCK_DGRAM) {
macoff = netoff = TPACKET_ALIGN(TPACKET_HDRLEN) + 16; macoff = netoff = TPACKET_ALIGN(TPACKET_HDRLEN) + 16;
......
...@@ -228,7 +228,7 @@ int sctp_rcv(struct sk_buff *skb) ...@@ -228,7 +228,7 @@ int sctp_rcv(struct sk_buff *skb)
goto discard_release; goto discard_release;
nf_reset(skb); nf_reset(skb);
if (sk_filter(sk, skb, 1)) if (sk_filter(sk, skb))
goto discard_release; goto discard_release;
/* Create an SCTP packet structure. */ /* Create an SCTP packet structure. */
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment