Commit 5d1d527c authored by David S. Miller

Merge branch 'raw-RCU-conversion'

Eric Dumazet says:

====================
raw: RCU conversion

Using rwlock in networking code is extremely risky.
Writers can starve if enough readers are constantly
grabbing the rwlock.

I thought rwlocks were at fault and sent this patch:

https://lkml.org/lkml/2022/6/17/272

But Peter and Linus essentially told me rwlock had to be unfair.

We need to get rid of rwlock in networking stacks.

Without this conversion, the following script triggers soft lockups:

for i in {1..48}
do
 ping -f -n -q 127.0.0.1 &
 sleep 0.1
done

Next step will be to convert ping sockets to RCU as well.
====================
Signed-off-by: David S. Miller <davem@davemloft.net>
parents 8670dc33 0daf07e5
...@@ -20,9 +20,8 @@ ...@@ -20,9 +20,8 @@
extern struct proto raw_prot; extern struct proto raw_prot;
extern struct raw_hashinfo raw_v4_hashinfo; extern struct raw_hashinfo raw_v4_hashinfo;
struct sock *__raw_v4_lookup(struct net *net, struct sock *sk, bool raw_v4_match(struct net *net, struct sock *sk, unsigned short num,
unsigned short num, __be32 raddr, __be32 raddr, __be32 laddr, int dif, int sdif);
__be32 laddr, int dif, int sdif);
int raw_abort(struct sock *sk, int err); int raw_abort(struct sock *sk, int err);
void raw_icmp_error(struct sk_buff *, int, u32); void raw_icmp_error(struct sk_buff *, int, u32);
...@@ -34,9 +33,18 @@ int raw_rcv(struct sock *, struct sk_buff *); ...@@ -34,9 +33,18 @@ int raw_rcv(struct sock *, struct sk_buff *);
struct raw_hashinfo { struct raw_hashinfo {
rwlock_t lock; rwlock_t lock;
struct hlist_head ht[RAW_HTABLE_SIZE]; struct hlist_nulls_head ht[RAW_HTABLE_SIZE];
}; };
static inline void raw_hashinfo_init(struct raw_hashinfo *hashinfo)
{
int i;
rwlock_init(&hashinfo->lock);
for (i = 0; i < RAW_HTABLE_SIZE; i++)
INIT_HLIST_NULLS_HEAD(&hashinfo->ht[i], i);
}
#ifdef CONFIG_PROC_FS #ifdef CONFIG_PROC_FS
int raw_proc_init(void); int raw_proc_init(void);
void raw_proc_exit(void); void raw_proc_exit(void);
......
...@@ -3,11 +3,12 @@ ...@@ -3,11 +3,12 @@
#define _NET_RAWV6_H #define _NET_RAWV6_H
#include <net/protocol.h> #include <net/protocol.h>
#include <net/raw.h>
extern struct raw_hashinfo raw_v6_hashinfo; extern struct raw_hashinfo raw_v6_hashinfo;
struct sock *__raw_v6_lookup(struct net *net, struct sock *sk, bool raw_v6_match(struct net *net, struct sock *sk, unsigned short num,
unsigned short num, const struct in6_addr *loc_addr, const struct in6_addr *loc_addr,
const struct in6_addr *rmt_addr, int dif, int sdif); const struct in6_addr *rmt_addr, int dif, int sdif);
int raw_abort(struct sock *sk, int err); int raw_abort(struct sock *sk, int err);
......
...@@ -1929,6 +1929,8 @@ static int __init inet_init(void) ...@@ -1929,6 +1929,8 @@ static int __init inet_init(void)
sock_skb_cb_check_size(sizeof(struct inet_skb_parm)); sock_skb_cb_check_size(sizeof(struct inet_skb_parm));
raw_hashinfo_init(&raw_v4_hashinfo);
rc = proto_register(&tcp_prot, 1); rc = proto_register(&tcp_prot, 1);
if (rc) if (rc)
goto out; goto out;
......
...@@ -85,20 +85,19 @@ struct raw_frag_vec { ...@@ -85,20 +85,19 @@ struct raw_frag_vec {
int hlen; int hlen;
}; };
struct raw_hashinfo raw_v4_hashinfo = { struct raw_hashinfo raw_v4_hashinfo;
.lock = __RW_LOCK_UNLOCKED(raw_v4_hashinfo.lock),
};
EXPORT_SYMBOL_GPL(raw_v4_hashinfo); EXPORT_SYMBOL_GPL(raw_v4_hashinfo);
int raw_hash_sk(struct sock *sk) int raw_hash_sk(struct sock *sk)
{ {
struct raw_hashinfo *h = sk->sk_prot->h.raw_hash; struct raw_hashinfo *h = sk->sk_prot->h.raw_hash;
struct hlist_head *head; struct hlist_nulls_head *hlist;
head = &h->ht[inet_sk(sk)->inet_num & (RAW_HTABLE_SIZE - 1)]; hlist = &h->ht[inet_sk(sk)->inet_num & (RAW_HTABLE_SIZE - 1)];
write_lock_bh(&h->lock); write_lock_bh(&h->lock);
sk_add_node(sk, head); hlist_nulls_add_head_rcu(&sk->sk_nulls_node, hlist);
sock_set_flag(sk, SOCK_RCU_FREE);
write_unlock_bh(&h->lock); write_unlock_bh(&h->lock);
sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
...@@ -111,30 +110,25 @@ void raw_unhash_sk(struct sock *sk) ...@@ -111,30 +110,25 @@ void raw_unhash_sk(struct sock *sk)
struct raw_hashinfo *h = sk->sk_prot->h.raw_hash; struct raw_hashinfo *h = sk->sk_prot->h.raw_hash;
write_lock_bh(&h->lock); write_lock_bh(&h->lock);
if (sk_del_node_init(sk)) if (__sk_nulls_del_node_init_rcu(sk))
sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1);
write_unlock_bh(&h->lock); write_unlock_bh(&h->lock);
} }
EXPORT_SYMBOL_GPL(raw_unhash_sk); EXPORT_SYMBOL_GPL(raw_unhash_sk);
struct sock *__raw_v4_lookup(struct net *net, struct sock *sk, bool raw_v4_match(struct net *net, struct sock *sk, unsigned short num,
unsigned short num, __be32 raddr, __be32 laddr, __be32 raddr, __be32 laddr, int dif, int sdif)
int dif, int sdif)
{ {
sk_for_each_from(sk) { struct inet_sock *inet = inet_sk(sk);
struct inet_sock *inet = inet_sk(sk);
if (net_eq(sock_net(sk), net) && inet->inet_num == num &&
if (net_eq(sock_net(sk), net) && inet->inet_num == num && !(inet->inet_daddr && inet->inet_daddr != raddr) &&
!(inet->inet_daddr && inet->inet_daddr != raddr) && !(inet->inet_rcv_saddr && inet->inet_rcv_saddr != laddr) &&
!(inet->inet_rcv_saddr && inet->inet_rcv_saddr != laddr) && raw_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif))
raw_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif)) return true;
goto found; /* gotcha */ return false;
}
sk = NULL;
found:
return sk;
} }
EXPORT_SYMBOL_GPL(__raw_v4_lookup); EXPORT_SYMBOL_GPL(raw_v4_match);
/* /*
* 0 - deliver * 0 - deliver
...@@ -168,23 +162,20 @@ static int icmp_filter(const struct sock *sk, const struct sk_buff *skb) ...@@ -168,23 +162,20 @@ static int icmp_filter(const struct sock *sk, const struct sk_buff *skb)
*/ */
static int raw_v4_input(struct sk_buff *skb, const struct iphdr *iph, int hash) static int raw_v4_input(struct sk_buff *skb, const struct iphdr *iph, int hash)
{ {
struct net *net = dev_net(skb->dev);
struct hlist_nulls_head *hlist;
struct hlist_nulls_node *hnode;
int sdif = inet_sdif(skb); int sdif = inet_sdif(skb);
int dif = inet_iif(skb); int dif = inet_iif(skb);
struct sock *sk;
struct hlist_head *head;
int delivered = 0; int delivered = 0;
struct net *net; struct sock *sk;
read_lock(&raw_v4_hashinfo.lock);
head = &raw_v4_hashinfo.ht[hash];
if (hlist_empty(head))
goto out;
net = dev_net(skb->dev);
sk = __raw_v4_lookup(net, __sk_head(head), iph->protocol,
iph->saddr, iph->daddr, dif, sdif);
while (sk) { hlist = &raw_v4_hashinfo.ht[hash];
rcu_read_lock();
hlist_nulls_for_each_entry(sk, hnode, hlist, sk_nulls_node) {
if (!raw_v4_match(net, sk, iph->protocol,
iph->saddr, iph->daddr, dif, sdif))
continue;
delivered = 1; delivered = 1;
if ((iph->protocol != IPPROTO_ICMP || !icmp_filter(sk, skb)) && if ((iph->protocol != IPPROTO_ICMP || !icmp_filter(sk, skb)) &&
ip_mc_sf_allow(sk, iph->daddr, iph->saddr, ip_mc_sf_allow(sk, iph->daddr, iph->saddr,
...@@ -195,31 +186,16 @@ static int raw_v4_input(struct sk_buff *skb, const struct iphdr *iph, int hash) ...@@ -195,31 +186,16 @@ static int raw_v4_input(struct sk_buff *skb, const struct iphdr *iph, int hash)
if (clone) if (clone)
raw_rcv(sk, clone); raw_rcv(sk, clone);
} }
sk = __raw_v4_lookup(net, sk_next(sk), iph->protocol,
iph->saddr, iph->daddr,
dif, sdif);
} }
out: rcu_read_unlock();
read_unlock(&raw_v4_hashinfo.lock);
return delivered; return delivered;
} }
int raw_local_deliver(struct sk_buff *skb, int protocol) int raw_local_deliver(struct sk_buff *skb, int protocol)
{ {
int hash; int hash = protocol & (RAW_HTABLE_SIZE - 1);
struct sock *raw_sk;
hash = protocol & (RAW_HTABLE_SIZE - 1);
raw_sk = sk_head(&raw_v4_hashinfo.ht[hash]);
/* If there maybe a raw socket we must check - if not we
* don't care less
*/
if (raw_sk && !raw_v4_input(skb, ip_hdr(skb), hash))
raw_sk = NULL;
return raw_sk != NULL;
return raw_v4_input(skb, ip_hdr(skb), hash);
} }
static void raw_err(struct sock *sk, struct sk_buff *skb, u32 info) static void raw_err(struct sock *sk, struct sk_buff *skb, u32 info)
...@@ -286,31 +262,27 @@ static void raw_err(struct sock *sk, struct sk_buff *skb, u32 info) ...@@ -286,31 +262,27 @@ static void raw_err(struct sock *sk, struct sk_buff *skb, u32 info)
void raw_icmp_error(struct sk_buff *skb, int protocol, u32 info) void raw_icmp_error(struct sk_buff *skb, int protocol, u32 info)
{ {
int hash; struct net *net = dev_net(skb->dev);;
struct sock *raw_sk; struct hlist_nulls_head *hlist;
struct hlist_nulls_node *hnode;
int dif = skb->dev->ifindex;
int sdif = inet_sdif(skb);
const struct iphdr *iph; const struct iphdr *iph;
struct net *net; struct sock *sk;
int hash;
hash = protocol & (RAW_HTABLE_SIZE - 1); hash = protocol & (RAW_HTABLE_SIZE - 1);
hlist = &raw_v4_hashinfo.ht[hash];
read_lock(&raw_v4_hashinfo.lock); rcu_read_lock();
raw_sk = sk_head(&raw_v4_hashinfo.ht[hash]); hlist_nulls_for_each_entry(sk, hnode, hlist, sk_nulls_node) {
if (raw_sk) {
int dif = skb->dev->ifindex;
int sdif = inet_sdif(skb);
iph = (const struct iphdr *)skb->data; iph = (const struct iphdr *)skb->data;
net = dev_net(skb->dev); if (!raw_v4_match(net, sk, iph->protocol,
iph->saddr, iph->daddr, dif, sdif))
while ((raw_sk = __raw_v4_lookup(net, raw_sk, protocol, continue;
iph->daddr, iph->saddr, raw_err(sk, skb, info);
dif, sdif)) != NULL) {
raw_err(raw_sk, skb, info);
raw_sk = sk_next(raw_sk);
iph = (const struct iphdr *)skb->data;
}
} }
read_unlock(&raw_v4_hashinfo.lock); rcu_read_unlock();
} }
static int raw_rcv_skb(struct sock *sk, struct sk_buff *skb) static int raw_rcv_skb(struct sock *sk, struct sk_buff *skb)
...@@ -971,44 +943,41 @@ struct proto raw_prot = { ...@@ -971,44 +943,41 @@ struct proto raw_prot = {
}; };
#ifdef CONFIG_PROC_FS #ifdef CONFIG_PROC_FS
static struct sock *raw_get_first(struct seq_file *seq) static struct sock *raw_get_first(struct seq_file *seq, int bucket)
{ {
struct sock *sk;
struct raw_hashinfo *h = pde_data(file_inode(seq->file)); struct raw_hashinfo *h = pde_data(file_inode(seq->file));
struct raw_iter_state *state = raw_seq_private(seq); struct raw_iter_state *state = raw_seq_private(seq);
struct hlist_nulls_head *hlist;
struct hlist_nulls_node *hnode;
struct sock *sk;
for (state->bucket = 0; state->bucket < RAW_HTABLE_SIZE; for (state->bucket = bucket; state->bucket < RAW_HTABLE_SIZE;
++state->bucket) { ++state->bucket) {
sk_for_each(sk, &h->ht[state->bucket]) hlist = &h->ht[state->bucket];
hlist_nulls_for_each_entry(sk, hnode, hlist, sk_nulls_node) {
if (sock_net(sk) == seq_file_net(seq)) if (sock_net(sk) == seq_file_net(seq))
goto found; return sk;
}
} }
sk = NULL; return NULL;
found:
return sk;
} }
static struct sock *raw_get_next(struct seq_file *seq, struct sock *sk) static struct sock *raw_get_next(struct seq_file *seq, struct sock *sk)
{ {
struct raw_hashinfo *h = pde_data(file_inode(seq->file));
struct raw_iter_state *state = raw_seq_private(seq); struct raw_iter_state *state = raw_seq_private(seq);
do { do {
sk = sk_next(sk); sk = sk_nulls_next(sk);
try_again:
;
} while (sk && sock_net(sk) != seq_file_net(seq)); } while (sk && sock_net(sk) != seq_file_net(seq));
if (!sk && ++state->bucket < RAW_HTABLE_SIZE) { if (!sk)
sk = sk_head(&h->ht[state->bucket]); return raw_get_first(seq, state->bucket + 1);
goto try_again;
}
return sk; return sk;
} }
static struct sock *raw_get_idx(struct seq_file *seq, loff_t pos) static struct sock *raw_get_idx(struct seq_file *seq, loff_t pos)
{ {
struct sock *sk = raw_get_first(seq); struct sock *sk = raw_get_first(seq, 0);
if (sk) if (sk)
while (pos && (sk = raw_get_next(seq, sk)) != NULL) while (pos && (sk = raw_get_next(seq, sk)) != NULL)
...@@ -1017,11 +986,9 @@ static struct sock *raw_get_idx(struct seq_file *seq, loff_t pos) ...@@ -1017,11 +986,9 @@ static struct sock *raw_get_idx(struct seq_file *seq, loff_t pos)
} }
void *raw_seq_start(struct seq_file *seq, loff_t *pos) void *raw_seq_start(struct seq_file *seq, loff_t *pos)
__acquires(&h->lock) __acquires(RCU)
{ {
struct raw_hashinfo *h = pde_data(file_inode(seq->file)); rcu_read_lock();
read_lock(&h->lock);
return *pos ? raw_get_idx(seq, *pos - 1) : SEQ_START_TOKEN; return *pos ? raw_get_idx(seq, *pos - 1) : SEQ_START_TOKEN;
} }
EXPORT_SYMBOL_GPL(raw_seq_start); EXPORT_SYMBOL_GPL(raw_seq_start);
...@@ -1031,7 +998,7 @@ void *raw_seq_next(struct seq_file *seq, void *v, loff_t *pos) ...@@ -1031,7 +998,7 @@ void *raw_seq_next(struct seq_file *seq, void *v, loff_t *pos)
struct sock *sk; struct sock *sk;
if (v == SEQ_START_TOKEN) if (v == SEQ_START_TOKEN)
sk = raw_get_first(seq); sk = raw_get_first(seq, 0);
else else
sk = raw_get_next(seq, v); sk = raw_get_next(seq, v);
++*pos; ++*pos;
...@@ -1040,11 +1007,9 @@ void *raw_seq_next(struct seq_file *seq, void *v, loff_t *pos) ...@@ -1040,11 +1007,9 @@ void *raw_seq_next(struct seq_file *seq, void *v, loff_t *pos)
EXPORT_SYMBOL_GPL(raw_seq_next); EXPORT_SYMBOL_GPL(raw_seq_next);
void raw_seq_stop(struct seq_file *seq, void *v) void raw_seq_stop(struct seq_file *seq, void *v)
__releases(&h->lock) __releases(RCU)
{ {
struct raw_hashinfo *h = pde_data(file_inode(seq->file)); rcu_read_unlock();
read_unlock(&h->lock);
} }
EXPORT_SYMBOL_GPL(raw_seq_stop); EXPORT_SYMBOL_GPL(raw_seq_stop);
...@@ -1106,6 +1071,7 @@ static __net_initdata struct pernet_operations raw_net_ops = { ...@@ -1106,6 +1071,7 @@ static __net_initdata struct pernet_operations raw_net_ops = {
int __init raw_proc_init(void) int __init raw_proc_init(void)
{ {
return register_pernet_subsys(&raw_net_ops); return register_pernet_subsys(&raw_net_ops);
} }
......
...@@ -34,57 +34,57 @@ raw_get_hashinfo(const struct inet_diag_req_v2 *r) ...@@ -34,57 +34,57 @@ raw_get_hashinfo(const struct inet_diag_req_v2 *r)
* use helper to figure it out. * use helper to figure it out.
*/ */
static struct sock *raw_lookup(struct net *net, struct sock *from, static bool raw_lookup(struct net *net, struct sock *sk,
const struct inet_diag_req_v2 *req) const struct inet_diag_req_v2 *req)
{ {
struct inet_diag_req_raw *r = (void *)req; struct inet_diag_req_raw *r = (void *)req;
struct sock *sk = NULL;
if (r->sdiag_family == AF_INET) if (r->sdiag_family == AF_INET)
sk = __raw_v4_lookup(net, from, r->sdiag_raw_protocol, return raw_v4_match(net, sk, r->sdiag_raw_protocol,
r->id.idiag_dst[0], r->id.idiag_dst[0],
r->id.idiag_src[0], r->id.idiag_src[0],
r->id.idiag_if, 0); r->id.idiag_if, 0);
#if IS_ENABLED(CONFIG_IPV6) #if IS_ENABLED(CONFIG_IPV6)
else else
sk = __raw_v6_lookup(net, from, r->sdiag_raw_protocol, return raw_v6_match(net, sk, r->sdiag_raw_protocol,
(const struct in6_addr *)r->id.idiag_src, (const struct in6_addr *)r->id.idiag_src,
(const struct in6_addr *)r->id.idiag_dst, (const struct in6_addr *)r->id.idiag_dst,
r->id.idiag_if, 0); r->id.idiag_if, 0);
#endif #endif
return sk; return false;
} }
static struct sock *raw_sock_get(struct net *net, const struct inet_diag_req_v2 *r) static struct sock *raw_sock_get(struct net *net, const struct inet_diag_req_v2 *r)
{ {
struct raw_hashinfo *hashinfo = raw_get_hashinfo(r); struct raw_hashinfo *hashinfo = raw_get_hashinfo(r);
struct sock *sk = NULL, *s; struct hlist_nulls_head *hlist;
struct hlist_nulls_node *hnode;
struct sock *sk;
int slot; int slot;
if (IS_ERR(hashinfo)) if (IS_ERR(hashinfo))
return ERR_CAST(hashinfo); return ERR_CAST(hashinfo);
read_lock(&hashinfo->lock); rcu_read_lock();
for (slot = 0; slot < RAW_HTABLE_SIZE; slot++) { for (slot = 0; slot < RAW_HTABLE_SIZE; slot++) {
sk_for_each(s, &hashinfo->ht[slot]) { hlist = &hashinfo->ht[slot];
sk = raw_lookup(net, s, r); hlist_nulls_for_each_entry(sk, hnode, hlist, sk_nulls_node) {
if (sk) { if (raw_lookup(net, sk, r)) {
/* /*
* Grab it and keep until we fill * Grab it and keep until we fill
* diag meaage to be reported, so * diag message to be reported, so
* caller should call sock_put then. * caller should call sock_put then.
* We can do that because we're keeping
* hashinfo->lock here.
*/ */
sock_hold(sk); if (refcount_inc_not_zero(&sk->sk_refcnt))
goto out_unlock; goto out_unlock;
} }
} }
} }
sk = ERR_PTR(-ENOENT);
out_unlock: out_unlock:
read_unlock(&hashinfo->lock); rcu_read_unlock();
return sk ? sk : ERR_PTR(-ENOENT); return sk;
} }
static int raw_diag_dump_one(struct netlink_callback *cb, static int raw_diag_dump_one(struct netlink_callback *cb,
...@@ -142,6 +142,8 @@ static void raw_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, ...@@ -142,6 +142,8 @@ static void raw_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
struct raw_hashinfo *hashinfo = raw_get_hashinfo(r); struct raw_hashinfo *hashinfo = raw_get_hashinfo(r);
struct net *net = sock_net(skb->sk); struct net *net = sock_net(skb->sk);
struct inet_diag_dump_data *cb_data; struct inet_diag_dump_data *cb_data;
struct hlist_nulls_head *hlist;
struct hlist_nulls_node *hnode;
int num, s_num, slot, s_slot; int num, s_num, slot, s_slot;
struct sock *sk = NULL; struct sock *sk = NULL;
struct nlattr *bc; struct nlattr *bc;
...@@ -158,7 +160,8 @@ static void raw_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, ...@@ -158,7 +160,8 @@ static void raw_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
for (slot = s_slot; slot < RAW_HTABLE_SIZE; s_num = 0, slot++) { for (slot = s_slot; slot < RAW_HTABLE_SIZE; s_num = 0, slot++) {
num = 0; num = 0;
sk_for_each(sk, &hashinfo->ht[slot]) { hlist = &hashinfo->ht[slot];
hlist_nulls_for_each_entry(sk, hnode, hlist, sk_nulls_node) {
struct inet_sock *inet = inet_sk(sk); struct inet_sock *inet = inet_sk(sk);
if (!net_eq(sock_net(sk), net)) if (!net_eq(sock_net(sk), net))
......
...@@ -63,6 +63,7 @@ ...@@ -63,6 +63,7 @@
#include <net/compat.h> #include <net/compat.h>
#include <net/xfrm.h> #include <net/xfrm.h>
#include <net/ioam6.h> #include <net/ioam6.h>
#include <net/rawv6.h>
#include <linux/uaccess.h> #include <linux/uaccess.h>
#include <linux/mroute6.h> #include <linux/mroute6.h>
...@@ -1073,6 +1074,8 @@ static int __init inet6_init(void) ...@@ -1073,6 +1074,8 @@ static int __init inet6_init(void)
goto out; goto out;
} }
raw_hashinfo_init(&raw_v6_hashinfo);
err = proto_register(&tcpv6_prot, 1); err = proto_register(&tcpv6_prot, 1);
if (err) if (err)
goto out; goto out;
......
...@@ -61,46 +61,30 @@ ...@@ -61,46 +61,30 @@
#define ICMPV6_HDRLEN 4 /* ICMPv6 header, RFC 4443 Section 2.1 */ #define ICMPV6_HDRLEN 4 /* ICMPv6 header, RFC 4443 Section 2.1 */
struct raw_hashinfo raw_v6_hashinfo = { struct raw_hashinfo raw_v6_hashinfo;
.lock = __RW_LOCK_UNLOCKED(raw_v6_hashinfo.lock),
};
EXPORT_SYMBOL_GPL(raw_v6_hashinfo); EXPORT_SYMBOL_GPL(raw_v6_hashinfo);
struct sock *__raw_v6_lookup(struct net *net, struct sock *sk, bool raw_v6_match(struct net *net, struct sock *sk, unsigned short num,
unsigned short num, const struct in6_addr *loc_addr, const struct in6_addr *loc_addr,
const struct in6_addr *rmt_addr, int dif, int sdif) const struct in6_addr *rmt_addr, int dif, int sdif)
{ {
bool is_multicast = ipv6_addr_is_multicast(loc_addr); if (inet_sk(sk)->inet_num != num ||
!net_eq(sock_net(sk), net) ||
sk_for_each_from(sk) (!ipv6_addr_any(&sk->sk_v6_daddr) &&
if (inet_sk(sk)->inet_num == num) { !ipv6_addr_equal(&sk->sk_v6_daddr, rmt_addr)) ||
!raw_sk_bound_dev_eq(net, sk->sk_bound_dev_if,
if (!net_eq(sock_net(sk), net)) dif, sdif))
continue; return false;
if (!ipv6_addr_any(&sk->sk_v6_daddr) && if (ipv6_addr_any(&sk->sk_v6_rcv_saddr) ||
!ipv6_addr_equal(&sk->sk_v6_daddr, rmt_addr)) ipv6_addr_equal(&sk->sk_v6_rcv_saddr, loc_addr) ||
continue; (ipv6_addr_is_multicast(loc_addr) &&
inet6_mc_check(sk, loc_addr, rmt_addr)))
if (!raw_sk_bound_dev_eq(net, sk->sk_bound_dev_if, return true;
dif, sdif))
continue; return false;
if (!ipv6_addr_any(&sk->sk_v6_rcv_saddr)) {
if (ipv6_addr_equal(&sk->sk_v6_rcv_saddr, loc_addr))
goto found;
if (is_multicast &&
inet6_mc_check(sk, loc_addr, rmt_addr))
goto found;
continue;
}
goto found;
}
sk = NULL;
found:
return sk;
} }
EXPORT_SYMBOL_GPL(__raw_v6_lookup); EXPORT_SYMBOL_GPL(raw_v6_match);
/* /*
* 0 - deliver * 0 - deliver
...@@ -156,31 +140,27 @@ EXPORT_SYMBOL(rawv6_mh_filter_unregister); ...@@ -156,31 +140,27 @@ EXPORT_SYMBOL(rawv6_mh_filter_unregister);
*/ */
static bool ipv6_raw_deliver(struct sk_buff *skb, int nexthdr) static bool ipv6_raw_deliver(struct sk_buff *skb, int nexthdr)
{ {
struct net *net = dev_net(skb->dev);
struct hlist_nulls_head *hlist;
struct hlist_nulls_node *hnode;
const struct in6_addr *saddr; const struct in6_addr *saddr;
const struct in6_addr *daddr; const struct in6_addr *daddr;
struct sock *sk; struct sock *sk;
bool delivered = false; bool delivered = false;
__u8 hash; __u8 hash;
struct net *net;
saddr = &ipv6_hdr(skb)->saddr; saddr = &ipv6_hdr(skb)->saddr;
daddr = saddr + 1; daddr = saddr + 1;
hash = nexthdr & (RAW_HTABLE_SIZE - 1); hash = nexthdr & (RAW_HTABLE_SIZE - 1);
hlist = &raw_v6_hashinfo.ht[hash];
read_lock(&raw_v6_hashinfo.lock); rcu_read_lock();
sk = sk_head(&raw_v6_hashinfo.ht[hash]); hlist_nulls_for_each_entry(sk, hnode, hlist, sk_nulls_node) {
if (!sk)
goto out;
net = dev_net(skb->dev);
sk = __raw_v6_lookup(net, sk, nexthdr, daddr, saddr,
inet6_iif(skb), inet6_sdif(skb));
while (sk) {
int filtered; int filtered;
if (!raw_v6_match(net, sk, nexthdr, daddr, saddr,
inet6_iif(skb), inet6_sdif(skb)))
continue;
delivered = true; delivered = true;
switch (nexthdr) { switch (nexthdr) {
case IPPROTO_ICMPV6: case IPPROTO_ICMPV6:
...@@ -219,23 +199,14 @@ static bool ipv6_raw_deliver(struct sk_buff *skb, int nexthdr) ...@@ -219,23 +199,14 @@ static bool ipv6_raw_deliver(struct sk_buff *skb, int nexthdr)
rawv6_rcv(sk, clone); rawv6_rcv(sk, clone);
} }
} }
sk = __raw_v6_lookup(net, sk_next(sk), nexthdr, daddr, saddr,
inet6_iif(skb), inet6_sdif(skb));
} }
out: rcu_read_unlock();
read_unlock(&raw_v6_hashinfo.lock);
return delivered; return delivered;
} }
bool raw6_local_deliver(struct sk_buff *skb, int nexthdr) bool raw6_local_deliver(struct sk_buff *skb, int nexthdr)
{ {
struct sock *raw_sk; return ipv6_raw_deliver(skb, nexthdr);
raw_sk = sk_head(&raw_v6_hashinfo.ht[nexthdr & (RAW_HTABLE_SIZE - 1)]);
if (raw_sk && !ipv6_raw_deliver(skb, nexthdr))
raw_sk = NULL;
return raw_sk != NULL;
} }
/* This cleans up af_inet6 a bit. -DaveM */ /* This cleans up af_inet6 a bit. -DaveM */
...@@ -361,30 +332,28 @@ static void rawv6_err(struct sock *sk, struct sk_buff *skb, ...@@ -361,30 +332,28 @@ static void rawv6_err(struct sock *sk, struct sk_buff *skb,
void raw6_icmp_error(struct sk_buff *skb, int nexthdr, void raw6_icmp_error(struct sk_buff *skb, int nexthdr,
u8 type, u8 code, int inner_offset, __be32 info) u8 type, u8 code, int inner_offset, __be32 info)
{ {
const struct in6_addr *saddr, *daddr;
struct net *net = dev_net(skb->dev);
struct hlist_nulls_head *hlist;
struct hlist_nulls_node *hnode;
struct sock *sk; struct sock *sk;
int hash; int hash;
const struct in6_addr *saddr, *daddr;
struct net *net;
hash = nexthdr & (RAW_HTABLE_SIZE - 1); hash = nexthdr & (RAW_HTABLE_SIZE - 1);
hlist = &raw_v6_hashinfo.ht[hash];
read_lock(&raw_v6_hashinfo.lock); rcu_read_lock();
sk = sk_head(&raw_v6_hashinfo.ht[hash]); hlist_nulls_for_each_entry(sk, hnode, hlist, sk_nulls_node) {
if (sk) {
/* Note: ipv6_hdr(skb) != skb->data */ /* Note: ipv6_hdr(skb) != skb->data */
const struct ipv6hdr *ip6h = (const struct ipv6hdr *)skb->data; const struct ipv6hdr *ip6h = (const struct ipv6hdr *)skb->data;
saddr = &ip6h->saddr; saddr = &ip6h->saddr;
daddr = &ip6h->daddr; daddr = &ip6h->daddr;
net = dev_net(skb->dev);
while ((sk = __raw_v6_lookup(net, sk, nexthdr, saddr, daddr, if (!raw_v6_match(net, sk, nexthdr, &ip6h->saddr, &ip6h->daddr,
inet6_iif(skb), inet6_iif(skb)))) { inet6_iif(skb), inet6_iif(skb)))
rawv6_err(sk, skb, NULL, type, code, continue;
inner_offset, info); rawv6_err(sk, skb, NULL, type, code, inner_offset, info);
sk = sk_next(sk);
}
} }
read_unlock(&raw_v6_hashinfo.lock); rcu_read_unlock();
} }
static inline int rawv6_rcv_skb(struct sock *sk, struct sk_buff *skb) static inline int rawv6_rcv_skb(struct sock *sk, struct sk_buff *skb)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment