Commit ce06b03e authored by David S. Miller's avatar David S. Miller

packet: Add helpers to register/unregister ->prot_hook

Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent f8bae99e
...@@ -222,6 +222,55 @@ struct packet_skb_cb { ...@@ -222,6 +222,55 @@ struct packet_skb_cb {
#define PACKET_SKB_CB(__skb) ((struct packet_skb_cb *)((__skb)->cb)) #define PACKET_SKB_CB(__skb) ((struct packet_skb_cb *)((__skb)->cb))
static inline struct packet_sock *pkt_sk(struct sock *sk)
{
return (struct packet_sock *)sk;
}
/* register_prot_hook must be invoked with the po->bind_lock held,
* or from a context in which asynchronous accesses to the packet
* socket is not possible (packet_create()).
*/
static void register_prot_hook(struct sock *sk)
{
struct packet_sock *po = pkt_sk(sk);
if (!po->running) {
dev_add_pack(&po->prot_hook);
sock_hold(sk);
po->running = 1;
}
}
/* {,__}unregister_prot_hook() must be invoked with the po->bind_lock
* held. If the sync parameter is true, we will temporarily drop
* the po->bind_lock and do a synchronize_net to make sure no
* asynchronous packet processing paths still refer to the elements
* of po->prot_hook. If the sync parameter is false, it is the
* callers responsibility to take care of this.
*/
static void __unregister_prot_hook(struct sock *sk, bool sync)
{
struct packet_sock *po = pkt_sk(sk);
po->running = 0;
__dev_remove_pack(&po->prot_hook);
__sock_put(sk);
if (sync) {
spin_unlock(&po->bind_lock);
synchronize_net();
spin_lock(&po->bind_lock);
}
}
static void unregister_prot_hook(struct sock *sk, bool sync)
{
struct packet_sock *po = pkt_sk(sk);
if (po->running)
__unregister_prot_hook(sk, sync);
}
static inline __pure struct page *pgv_to_page(void *addr) static inline __pure struct page *pgv_to_page(void *addr)
{ {
if (is_vmalloc_addr(addr)) if (is_vmalloc_addr(addr))
...@@ -324,11 +373,6 @@ static inline void packet_increment_head(struct packet_ring_buffer *buff) ...@@ -324,11 +373,6 @@ static inline void packet_increment_head(struct packet_ring_buffer *buff)
buff->head = buff->head != buff->frame_max ? buff->head+1 : 0; buff->head = buff->head != buff->frame_max ? buff->head+1 : 0;
} }
static inline struct packet_sock *pkt_sk(struct sock *sk)
{
return (struct packet_sock *)sk;
}
static void packet_sock_destruct(struct sock *sk) static void packet_sock_destruct(struct sock *sk)
{ {
skb_queue_purge(&sk->sk_error_queue); skb_queue_purge(&sk->sk_error_queue);
...@@ -1337,15 +1381,7 @@ static int packet_release(struct socket *sock) ...@@ -1337,15 +1381,7 @@ static int packet_release(struct socket *sock)
spin_unlock_bh(&net->packet.sklist_lock); spin_unlock_bh(&net->packet.sklist_lock);
spin_lock(&po->bind_lock); spin_lock(&po->bind_lock);
if (po->running) { unregister_prot_hook(sk, false);
/*
* Remove from protocol table
*/
po->running = 0;
po->num = 0;
__dev_remove_pack(&po->prot_hook);
__sock_put(sk);
}
if (po->prot_hook.dev) { if (po->prot_hook.dev) {
dev_put(po->prot_hook.dev); dev_put(po->prot_hook.dev);
po->prot_hook.dev = NULL; po->prot_hook.dev = NULL;
...@@ -1392,15 +1428,7 @@ static int packet_do_bind(struct sock *sk, struct net_device *dev, __be16 protoc ...@@ -1392,15 +1428,7 @@ static int packet_do_bind(struct sock *sk, struct net_device *dev, __be16 protoc
lock_sock(sk); lock_sock(sk);
spin_lock(&po->bind_lock); spin_lock(&po->bind_lock);
if (po->running) { unregister_prot_hook(sk, true);
__sock_put(sk);
po->running = 0;
po->num = 0;
spin_unlock(&po->bind_lock);
dev_remove_pack(&po->prot_hook);
spin_lock(&po->bind_lock);
}
po->num = protocol; po->num = protocol;
po->prot_hook.type = protocol; po->prot_hook.type = protocol;
if (po->prot_hook.dev) if (po->prot_hook.dev)
...@@ -1413,9 +1441,7 @@ static int packet_do_bind(struct sock *sk, struct net_device *dev, __be16 protoc ...@@ -1413,9 +1441,7 @@ static int packet_do_bind(struct sock *sk, struct net_device *dev, __be16 protoc
goto out_unlock; goto out_unlock;
if (!dev || (dev->flags & IFF_UP)) { if (!dev || (dev->flags & IFF_UP)) {
dev_add_pack(&po->prot_hook); register_prot_hook(sk);
sock_hold(sk);
po->running = 1;
} else { } else {
sk->sk_err = ENETDOWN; sk->sk_err = ENETDOWN;
if (!sock_flag(sk, SOCK_DEAD)) if (!sock_flag(sk, SOCK_DEAD))
...@@ -1542,9 +1568,7 @@ static int packet_create(struct net *net, struct socket *sock, int protocol, ...@@ -1542,9 +1568,7 @@ static int packet_create(struct net *net, struct socket *sock, int protocol,
if (proto) { if (proto) {
po->prot_hook.type = proto; po->prot_hook.type = proto;
dev_add_pack(&po->prot_hook); register_prot_hook(sk);
sock_hold(sk);
po->running = 1;
} }
spin_lock_bh(&net->packet.sklist_lock); spin_lock_bh(&net->packet.sklist_lock);
...@@ -2240,9 +2264,7 @@ static int packet_notifier(struct notifier_block *this, unsigned long msg, void ...@@ -2240,9 +2264,7 @@ static int packet_notifier(struct notifier_block *this, unsigned long msg, void
if (dev->ifindex == po->ifindex) { if (dev->ifindex == po->ifindex) {
spin_lock(&po->bind_lock); spin_lock(&po->bind_lock);
if (po->running) { if (po->running) {
__dev_remove_pack(&po->prot_hook); __unregister_prot_hook(sk, false);
__sock_put(sk);
po->running = 0;
sk->sk_err = ENETDOWN; sk->sk_err = ENETDOWN;
if (!sock_flag(sk, SOCK_DEAD)) if (!sock_flag(sk, SOCK_DEAD))
sk->sk_error_report(sk); sk->sk_error_report(sk);
...@@ -2259,11 +2281,8 @@ static int packet_notifier(struct notifier_block *this, unsigned long msg, void ...@@ -2259,11 +2281,8 @@ static int packet_notifier(struct notifier_block *this, unsigned long msg, void
case NETDEV_UP: case NETDEV_UP:
if (dev->ifindex == po->ifindex) { if (dev->ifindex == po->ifindex) {
spin_lock(&po->bind_lock); spin_lock(&po->bind_lock);
if (po->num && !po->running) { if (po->num)
dev_add_pack(&po->prot_hook); register_prot_hook(sk);
sock_hold(sk);
po->running = 1;
}
spin_unlock(&po->bind_lock); spin_unlock(&po->bind_lock);
} }
break; break;
...@@ -2530,10 +2549,8 @@ static int packet_set_ring(struct sock *sk, struct tpacket_req *req, ...@@ -2530,10 +2549,8 @@ static int packet_set_ring(struct sock *sk, struct tpacket_req *req,
was_running = po->running; was_running = po->running;
num = po->num; num = po->num;
if (was_running) { if (was_running) {
__dev_remove_pack(&po->prot_hook);
po->num = 0; po->num = 0;
po->running = 0; __unregister_prot_hook(sk, false);
__sock_put(sk);
} }
spin_unlock(&po->bind_lock); spin_unlock(&po->bind_lock);
...@@ -2564,11 +2581,9 @@ static int packet_set_ring(struct sock *sk, struct tpacket_req *req, ...@@ -2564,11 +2581,9 @@ static int packet_set_ring(struct sock *sk, struct tpacket_req *req,
mutex_unlock(&po->pg_vec_lock); mutex_unlock(&po->pg_vec_lock);
spin_lock(&po->bind_lock); spin_lock(&po->bind_lock);
if (was_running && !po->running) { if (was_running) {
sock_hold(sk);
po->running = 1;
po->num = num; po->num = num;
dev_add_pack(&po->prot_hook); register_prot_hook(sk);
} }
spin_unlock(&po->bind_lock); spin_unlock(&po->bind_lock);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment