Commit fc038410 authored by David S. Miller's avatar David S. Miller

[UDP]: Fix AF-specific references in AF-agnostic code.

__udp_lib_port_inuse() cannot make direct references to
inet_sk(sk)->rcv_saddr as that is ipv4 specific state and
this code is used by ipv6 too.

Use an operations vector to solve this, and this also paves
the way for ipv6 support for non-wild saddr hashing in UDP.
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent a2af421f
...@@ -119,9 +119,16 @@ static inline void udp_lib_close(struct sock *sk, long timeout) ...@@ -119,9 +119,16 @@ static inline void udp_lib_close(struct sock *sk, long timeout)
} }
struct udp_get_port_ops {
int (*saddr_cmp)(const struct sock *sk1, const struct sock *sk2);
int (*saddr_any)(const struct sock *sk);
unsigned int (*hash_port_and_rcv_saddr)(__u16 port,
const struct sock *sk);
};
/* net/ipv4/udp.c */ /* net/ipv4/udp.c */
extern int udp_get_port(struct sock *sk, unsigned short snum, extern int udp_get_port(struct sock *sk, unsigned short snum,
int (*saddr_cmp)(const struct sock *, const struct sock *)); const struct udp_get_port_ops *ops);
extern void udp_err(struct sk_buff *, u32); extern void udp_err(struct sk_buff *, u32);
extern int udp_sendmsg(struct kiocb *iocb, struct sock *sk, extern int udp_sendmsg(struct kiocb *iocb, struct sock *sk,
......
...@@ -120,5 +120,5 @@ static inline __wsum udplite_csum_outgoing(struct sock *sk, struct sk_buff *skb) ...@@ -120,5 +120,5 @@ static inline __wsum udplite_csum_outgoing(struct sock *sk, struct sk_buff *skb)
extern void udplite4_register(void); extern void udplite4_register(void);
extern int udplite_get_port(struct sock *sk, unsigned short snum, extern int udplite_get_port(struct sock *sk, unsigned short snum,
int (*scmp)(const struct sock *, const struct sock *)); const struct udp_get_port_ops *ops);
#endif /* _UDPLITE_H */ #endif /* _UDPLITE_H */
...@@ -118,15 +118,15 @@ static int udp_port_rover; ...@@ -118,15 +118,15 @@ static int udp_port_rover;
* Note about this hash function : * Note about this hash function :
* Typical use is probably daddr = 0, only dport is going to vary hash * Typical use is probably daddr = 0, only dport is going to vary hash
*/ */
static inline unsigned int hash_port_and_addr(__u16 port, __be32 addr) static inline unsigned int udp_hash_port(__u16 port)
{ {
addr ^= addr >> 16; return port;
addr ^= addr >> 8;
return port ^ addr;
} }
static inline int __udp_lib_port_inuse(unsigned int hash, int port, static inline int __udp_lib_port_inuse(unsigned int hash, int port,
__be32 daddr, struct hlist_head udptable[]) const struct sock *this_sk,
struct hlist_head udptable[],
const struct udp_get_port_ops *ops)
{ {
struct sock *sk; struct sock *sk;
struct hlist_node *node; struct hlist_node *node;
...@@ -138,7 +138,10 @@ static inline int __udp_lib_port_inuse(unsigned int hash, int port, ...@@ -138,7 +138,10 @@ static inline int __udp_lib_port_inuse(unsigned int hash, int port,
inet = inet_sk(sk); inet = inet_sk(sk);
if (inet->num != port) if (inet->num != port)
continue; continue;
if (inet->rcv_saddr == daddr) if (this_sk) {
if (ops->saddr_cmp(sk, this_sk))
return 1;
} else if (ops->saddr_any(sk))
return 1; return 1;
} }
return 0; return 0;
...@@ -151,12 +154,11 @@ static inline int __udp_lib_port_inuse(unsigned int hash, int port, ...@@ -151,12 +154,11 @@ static inline int __udp_lib_port_inuse(unsigned int hash, int port,
* @snum: port number to look up * @snum: port number to look up
* @udptable: hash list table, must be of UDP_HTABLE_SIZE * @udptable: hash list table, must be of UDP_HTABLE_SIZE
* @port_rover: pointer to record of last unallocated port * @port_rover: pointer to record of last unallocated port
* @saddr_comp: AF-dependent comparison of bound local IP addresses * @ops: AF-dependent address operations
*/ */
int __udp_lib_get_port(struct sock *sk, unsigned short snum, int __udp_lib_get_port(struct sock *sk, unsigned short snum,
struct hlist_head udptable[], int *port_rover, struct hlist_head udptable[], int *port_rover,
int (*saddr_comp)(const struct sock *sk1, const struct udp_get_port_ops *ops)
const struct sock *sk2 ) )
{ {
struct hlist_node *node; struct hlist_node *node;
struct hlist_head *head; struct hlist_head *head;
...@@ -176,8 +178,7 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum, ...@@ -176,8 +178,7 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum,
for (i = 0; i < UDP_HTABLE_SIZE; i++, result++) { for (i = 0; i < UDP_HTABLE_SIZE; i++, result++) {
int size; int size;
hash = hash_port_and_addr(result, hash = ops->hash_port_and_rcv_saddr(result, sk);
inet_sk(sk)->rcv_saddr);
head = &udptable[hash & (UDP_HTABLE_SIZE - 1)]; head = &udptable[hash & (UDP_HTABLE_SIZE - 1)];
if (hlist_empty(head)) { if (hlist_empty(head)) {
if (result > sysctl_local_port_range[1]) if (result > sysctl_local_port_range[1])
...@@ -203,17 +204,16 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum, ...@@ -203,17 +204,16 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum,
result = sysctl_local_port_range[0] result = sysctl_local_port_range[0]
+ ((result - sysctl_local_port_range[0]) & + ((result - sysctl_local_port_range[0]) &
(UDP_HTABLE_SIZE - 1)); (UDP_HTABLE_SIZE - 1));
hash = hash_port_and_addr(result, 0); hash = udp_hash_port(result);
if (__udp_lib_port_inuse(hash, result, if (__udp_lib_port_inuse(hash, result,
0, udptable)) NULL, udptable, ops))
continue; continue;
if (!inet_sk(sk)->rcv_saddr) if (ops->saddr_any(sk))
break; break;
hash = hash_port_and_addr(result, hash = ops->hash_port_and_rcv_saddr(result, sk);
inet_sk(sk)->rcv_saddr);
if (! __udp_lib_port_inuse(hash, result, if (! __udp_lib_port_inuse(hash, result,
inet_sk(sk)->rcv_saddr, udptable)) sk, udptable, ops))
break; break;
} }
if (i >= (1 << 16) / UDP_HTABLE_SIZE) if (i >= (1 << 16) / UDP_HTABLE_SIZE)
...@@ -221,7 +221,7 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum, ...@@ -221,7 +221,7 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum,
gotit: gotit:
*port_rover = snum = result; *port_rover = snum = result;
} else { } else {
hash = hash_port_and_addr(snum, 0); hash = udp_hash_port(snum);
head = &udptable[hash & (UDP_HTABLE_SIZE - 1)]; head = &udptable[hash & (UDP_HTABLE_SIZE - 1)];
sk_for_each(sk2, node, head) sk_for_each(sk2, node, head)
...@@ -231,12 +231,11 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum, ...@@ -231,12 +231,11 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum,
(!sk2->sk_reuse || !sk->sk_reuse) && (!sk2->sk_reuse || !sk->sk_reuse) &&
(!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if || (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if ||
sk2->sk_bound_dev_if == sk->sk_bound_dev_if) && sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
(*saddr_comp)(sk, sk2)) ops->saddr_cmp(sk, sk2))
goto fail; goto fail;
if (inet_sk(sk)->rcv_saddr) { if (!ops->saddr_any(sk)) {
hash = hash_port_and_addr(snum, hash = ops->hash_port_and_rcv_saddr(snum, sk);
inet_sk(sk)->rcv_saddr);
head = &udptable[hash & (UDP_HTABLE_SIZE - 1)]; head = &udptable[hash & (UDP_HTABLE_SIZE - 1)];
sk_for_each(sk2, node, head) sk_for_each(sk2, node, head)
...@@ -248,7 +247,7 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum, ...@@ -248,7 +247,7 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum,
!sk->sk_bound_dev_if || !sk->sk_bound_dev_if ||
sk2->sk_bound_dev_if == sk2->sk_bound_dev_if ==
sk->sk_bound_dev_if) && sk->sk_bound_dev_if) &&
(*saddr_comp)(sk, sk2)) ops->saddr_cmp(sk, sk2))
goto fail; goto fail;
} }
} }
...@@ -266,12 +265,12 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum, ...@@ -266,12 +265,12 @@ int __udp_lib_get_port(struct sock *sk, unsigned short snum,
} }
int udp_get_port(struct sock *sk, unsigned short snum, int udp_get_port(struct sock *sk, unsigned short snum,
int (*scmp)(const struct sock *, const struct sock *)) const struct udp_get_port_ops *ops)
{ {
return __udp_lib_get_port(sk, snum, udp_hash, &udp_port_rover, scmp); return __udp_lib_get_port(sk, snum, udp_hash, &udp_port_rover, ops);
} }
int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2) static int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2)
{ {
struct inet_sock *inet1 = inet_sk(sk1), *inet2 = inet_sk(sk2); struct inet_sock *inet1 = inet_sk(sk1), *inet2 = inet_sk(sk2);
...@@ -280,9 +279,33 @@ int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2) ...@@ -280,9 +279,33 @@ int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2)
inet1->rcv_saddr == inet2->rcv_saddr )); inet1->rcv_saddr == inet2->rcv_saddr ));
} }
static int ipv4_rcv_saddr_any(const struct sock *sk)
{
return !inet_sk(sk)->rcv_saddr;
}
static inline unsigned int ipv4_hash_port_and_addr(__u16 port, __be32 addr)
{
addr ^= addr >> 16;
addr ^= addr >> 8;
return port ^ addr;
}
static unsigned int ipv4_hash_port_and_rcv_saddr(__u16 port,
const struct sock *sk)
{
return ipv4_hash_port_and_addr(port, inet_sk(sk)->rcv_saddr);
}
const struct udp_get_port_ops udp_ipv4_ops = {
.saddr_cmp = ipv4_rcv_saddr_equal,
.saddr_any = ipv4_rcv_saddr_any,
.hash_port_and_rcv_saddr = ipv4_hash_port_and_rcv_saddr,
};
static inline int udp_v4_get_port(struct sock *sk, unsigned short snum) static inline int udp_v4_get_port(struct sock *sk, unsigned short snum)
{ {
return udp_get_port(sk, snum, ipv4_rcv_saddr_equal); return udp_get_port(sk, snum, &udp_ipv4_ops);
} }
/* UDP is nearly always wildcards out the wazoo, it makes no sense to try /* UDP is nearly always wildcards out the wazoo, it makes no sense to try
...@@ -297,8 +320,8 @@ static struct sock *__udp4_lib_lookup(__be32 saddr, __be16 sport, ...@@ -297,8 +320,8 @@ static struct sock *__udp4_lib_lookup(__be32 saddr, __be16 sport,
unsigned int hash, hashwild; unsigned int hash, hashwild;
int score, best = -1, hport = ntohs(dport); int score, best = -1, hport = ntohs(dport);
hash = hash_port_and_addr(hport, daddr); hash = ipv4_hash_port_and_addr(hport, daddr);
hashwild = hash_port_and_addr(hport, 0); hashwild = udp_hash_port(hport);
read_lock(&udp_hash_lock); read_lock(&udp_hash_lock);
...@@ -1198,8 +1221,8 @@ static int __udp4_lib_mcast_deliver(struct sk_buff *skb, ...@@ -1198,8 +1221,8 @@ static int __udp4_lib_mcast_deliver(struct sk_buff *skb,
struct sock *sk, *skw, *sknext; struct sock *sk, *skw, *sknext;
int dif; int dif;
int hport = ntohs(uh->dest); int hport = ntohs(uh->dest);
unsigned int hash = hash_port_and_addr(hport, daddr); unsigned int hash = ipv4_hash_port_and_addr(hport, daddr);
unsigned int hashwild = hash_port_and_addr(hport, 0); unsigned int hashwild = udp_hash_port(hport);
dif = skb->dev->ifindex; dif = skb->dev->ifindex;
......
...@@ -5,14 +5,14 @@ ...@@ -5,14 +5,14 @@
#include <net/protocol.h> #include <net/protocol.h>
#include <net/inet_common.h> #include <net/inet_common.h>
extern const struct udp_get_port_ops udp_ipv4_ops;
extern int __udp4_lib_rcv(struct sk_buff *, struct hlist_head [], int ); extern int __udp4_lib_rcv(struct sk_buff *, struct hlist_head [], int );
extern void __udp4_lib_err(struct sk_buff *, u32, struct hlist_head []); extern void __udp4_lib_err(struct sk_buff *, u32, struct hlist_head []);
extern int __udp_lib_get_port(struct sock *sk, unsigned short snum, extern int __udp_lib_get_port(struct sock *sk, unsigned short snum,
struct hlist_head udptable[], int *port_rover, struct hlist_head udptable[], int *port_rover,
int (*)(const struct sock*,const struct sock*)); const struct udp_get_port_ops *ops);
extern int ipv4_rcv_saddr_equal(const struct sock *, const struct sock *);
extern int udp_setsockopt(struct sock *sk, int level, int optname, extern int udp_setsockopt(struct sock *sk, int level, int optname,
char __user *optval, int optlen); char __user *optval, int optlen);
......
...@@ -19,14 +19,15 @@ struct hlist_head udplite_hash[UDP_HTABLE_SIZE]; ...@@ -19,14 +19,15 @@ struct hlist_head udplite_hash[UDP_HTABLE_SIZE];
static int udplite_port_rover; static int udplite_port_rover;
int udplite_get_port(struct sock *sk, unsigned short p, int udplite_get_port(struct sock *sk, unsigned short p,
int (*c)(const struct sock *, const struct sock *)) const struct udp_get_port_ops *ops)
{ {
return __udp_lib_get_port(sk, p, udplite_hash, &udplite_port_rover, c); return __udp_lib_get_port(sk, p, udplite_hash,
&udplite_port_rover, ops);
} }
static int udplite_v4_get_port(struct sock *sk, unsigned short snum) static int udplite_v4_get_port(struct sock *sk, unsigned short snum)
{ {
return udplite_get_port(sk, snum, ipv4_rcv_saddr_equal); return udplite_get_port(sk, snum, &udp_ipv4_ops);
} }
static int udplite_rcv(struct sk_buff *skb) static int udplite_rcv(struct sk_buff *skb)
......
...@@ -52,9 +52,28 @@ ...@@ -52,9 +52,28 @@
DEFINE_SNMP_STAT(struct udp_mib, udp_stats_in6) __read_mostly; DEFINE_SNMP_STAT(struct udp_mib, udp_stats_in6) __read_mostly;
static int ipv6_rcv_saddr_any(const struct sock *sk)
{
struct ipv6_pinfo *np = inet6_sk(sk);
return ipv6_addr_any(&np->rcv_saddr);
}
static unsigned int ipv6_hash_port_and_rcv_saddr(__u16 port,
const struct sock *sk)
{
return port;
}
const struct udp_get_port_ops udp_ipv6_ops = {
.saddr_cmp = ipv6_rcv_saddr_equal,
.saddr_any = ipv6_rcv_saddr_any,
.hash_port_and_rcv_saddr = ipv6_hash_port_and_rcv_saddr,
};
static inline int udp_v6_get_port(struct sock *sk, unsigned short snum) static inline int udp_v6_get_port(struct sock *sk, unsigned short snum)
{ {
return udp_get_port(sk, snum, ipv6_rcv_saddr_equal); return udp_get_port(sk, snum, &udp_ipv6_ops);
} }
static struct sock *__udp6_lib_lookup(struct in6_addr *saddr, __be16 sport, static struct sock *__udp6_lib_lookup(struct in6_addr *saddr, __be16 sport,
......
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
#include <net/addrconf.h> #include <net/addrconf.h>
#include <net/inet_common.h> #include <net/inet_common.h>
extern const struct udp_get_port_ops udp_ipv6_ops;
extern int __udp6_lib_rcv(struct sk_buff **, struct hlist_head [], int ); extern int __udp6_lib_rcv(struct sk_buff **, struct hlist_head [], int );
extern void __udp6_lib_err(struct sk_buff *, struct inet6_skb_parm *, extern void __udp6_lib_err(struct sk_buff *, struct inet6_skb_parm *,
int , int , int , __be32 , struct hlist_head []); int , int , int , __be32 , struct hlist_head []);
......
...@@ -37,7 +37,7 @@ static struct inet6_protocol udplitev6_protocol = { ...@@ -37,7 +37,7 @@ static struct inet6_protocol udplitev6_protocol = {
static int udplite_v6_get_port(struct sock *sk, unsigned short snum) static int udplite_v6_get_port(struct sock *sk, unsigned short snum)
{ {
return udplite_get_port(sk, snum, ipv6_rcv_saddr_equal); return udplite_get_port(sk, snum, &udp_ipv6_ops);
} }
struct proto udplitev6_prot = { struct proto udplitev6_prot = {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment