Commit 859f8b26 authored by Eric Dumazet, committed by David S. Miller

ipv6: lockless IPV6_FLOWINFO_SEND implementation

np->sndflow reads are racy.

Use one bit from atomic inet->inet_flags instead, so that
IPV6_FLOWINFO_SEND setsockopt() can be lockless.
Signed-off-by: Eric Dumazet <edumazet@google.com>
Reviewed-by: David Ahern <dsahern@kernel.org>
Signed-off-by: David S. Miller <davem@davemloft.net>
parent 6b724bc4
...@@ -243,8 +243,7 @@ struct ipv6_pinfo { ...@@ -243,8 +243,7 @@ struct ipv6_pinfo {
} rxopt; } rxopt;
/* sockopt flags */ /* sockopt flags */
__u8 sndflow:1, __u8 srcprefs:3; /* 001: prefer temporary address
srcprefs:3; /* 001: prefer temporary address
* 010: prefer public address * 010: prefer public address
* 100: prefer care-of address * 100: prefer care-of address
*/ */
......
...@@ -277,6 +277,7 @@ enum { ...@@ -277,6 +277,7 @@ enum {
INET_FLAGS_RECVERR6 = 26, INET_FLAGS_RECVERR6 = 26,
INET_FLAGS_REPFLOW = 27, INET_FLAGS_REPFLOW = 27,
INET_FLAGS_RTALERT_ISOLATE = 28, INET_FLAGS_RTALERT_ISOLATE = 28,
INET_FLAGS_SNDFLOW = 29,
}; };
/* cmsg flags for inet */ /* cmsg flags for inet */
......
...@@ -844,7 +844,7 @@ static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr, ...@@ -844,7 +844,7 @@ static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
memset(&fl6, 0, sizeof(fl6)); memset(&fl6, 0, sizeof(fl6));
if (np->sndflow) { if (inet6_test_bit(SNDFLOW, sk)) {
fl6.flowlabel = usin->sin6_flowinfo & IPV6_FLOWINFO_MASK; fl6.flowlabel = usin->sin6_flowinfo & IPV6_FLOWINFO_MASK;
IP6_ECN_flow_init(fl6.flowlabel); IP6_ECN_flow_init(fl6.flowlabel);
if (fl6.flowlabel & IPV6_FLOWLABEL_MASK) { if (fl6.flowlabel & IPV6_FLOWLABEL_MASK) {
......
...@@ -899,7 +899,6 @@ int ping_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags, ...@@ -899,7 +899,6 @@ int ping_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags,
#if IS_ENABLED(CONFIG_IPV6) #if IS_ENABLED(CONFIG_IPV6)
} else if (family == AF_INET6) { } else if (family == AF_INET6) {
struct ipv6_pinfo *np = inet6_sk(sk);
struct ipv6hdr *ip6 = ipv6_hdr(skb); struct ipv6hdr *ip6 = ipv6_hdr(skb);
DECLARE_SOCKADDR(struct sockaddr_in6 *, sin6, msg->msg_name); DECLARE_SOCKADDR(struct sockaddr_in6 *, sin6, msg->msg_name);
...@@ -908,7 +907,7 @@ int ping_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags, ...@@ -908,7 +907,7 @@ int ping_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags,
sin6->sin6_port = 0; sin6->sin6_port = 0;
sin6->sin6_addr = ip6->saddr; sin6->sin6_addr = ip6->saddr;
sin6->sin6_flowinfo = 0; sin6->sin6_flowinfo = 0;
if (np->sndflow) if (inet6_test_bit(SNDFLOW, sk))
sin6->sin6_flowinfo = ip6_flowinfo(ip6); sin6->sin6_flowinfo = ip6_flowinfo(ip6);
sin6->sin6_scope_id = sin6->sin6_scope_id =
ipv6_iface_scope_id(&sin6->sin6_addr, ipv6_iface_scope_id(&sin6->sin6_addr,
......
...@@ -537,7 +537,7 @@ int inet6_getname(struct socket *sock, struct sockaddr *uaddr, ...@@ -537,7 +537,7 @@ int inet6_getname(struct socket *sock, struct sockaddr *uaddr,
} }
sin->sin6_port = inet->inet_dport; sin->sin6_port = inet->inet_dport;
sin->sin6_addr = sk->sk_v6_daddr; sin->sin6_addr = sk->sk_v6_daddr;
if (np->sndflow) if (inet6_test_bit(SNDFLOW, sk))
sin->sin6_flowinfo = np->flow_label; sin->sin6_flowinfo = np->flow_label;
BPF_CGROUP_RUN_SA_PROG(sk, (struct sockaddr *)sin, BPF_CGROUP_RUN_SA_PROG(sk, (struct sockaddr *)sin,
CGROUP_INET6_GETPEERNAME); CGROUP_INET6_GETPEERNAME);
......
...@@ -80,7 +80,8 @@ int ip6_datagram_dst_update(struct sock *sk, bool fix_sk_saddr) ...@@ -80,7 +80,8 @@ int ip6_datagram_dst_update(struct sock *sk, bool fix_sk_saddr)
struct flowi6 fl6; struct flowi6 fl6;
int err = 0; int err = 0;
if (np->sndflow && (np->flow_label & IPV6_FLOWLABEL_MASK)) { if (inet6_test_bit(SNDFLOW, sk) &&
(np->flow_label & IPV6_FLOWLABEL_MASK)) {
flowlabel = fl6_sock_lookup(sk, np->flow_label); flowlabel = fl6_sock_lookup(sk, np->flow_label);
if (IS_ERR(flowlabel)) if (IS_ERR(flowlabel))
return -EINVAL; return -EINVAL;
...@@ -163,7 +164,7 @@ int __ip6_datagram_connect(struct sock *sk, struct sockaddr *uaddr, ...@@ -163,7 +164,7 @@ int __ip6_datagram_connect(struct sock *sk, struct sockaddr *uaddr,
if (usin->sin6_family != AF_INET6) if (usin->sin6_family != AF_INET6)
return -EAFNOSUPPORT; return -EAFNOSUPPORT;
if (np->sndflow) if (inet6_test_bit(SNDFLOW, sk))
fl6_flowlabel = usin->sin6_flowinfo & IPV6_FLOWINFO_MASK; fl6_flowlabel = usin->sin6_flowinfo & IPV6_FLOWINFO_MASK;
if (ipv6_addr_any(&usin->sin6_addr)) { if (ipv6_addr_any(&usin->sin6_addr)) {
...@@ -491,7 +492,7 @@ int ipv6_recv_error(struct sock *sk, struct msghdr *msg, int len, int *addr_len) ...@@ -491,7 +492,7 @@ int ipv6_recv_error(struct sock *sk, struct msghdr *msg, int len, int *addr_len)
const struct ipv6hdr *ip6h = container_of((struct in6_addr *)(nh + serr->addr_offset), const struct ipv6hdr *ip6h = container_of((struct in6_addr *)(nh + serr->addr_offset),
struct ipv6hdr, daddr); struct ipv6hdr, daddr);
sin->sin6_addr = ip6h->daddr; sin->sin6_addr = ip6h->daddr;
if (np->sndflow) if (inet6_test_bit(SNDFLOW, sk))
sin->sin6_flowinfo = ip6_flowinfo(ip6h); sin->sin6_flowinfo = ip6_flowinfo(ip6h);
sin->sin6_scope_id = sin->sin6_scope_id =
ipv6_iface_scope_id(&sin->sin6_addr, ipv6_iface_scope_id(&sin->sin6_addr,
......
...@@ -500,6 +500,11 @@ int do_ipv6_setsockopt(struct sock *sk, int level, int optname, ...@@ -500,6 +500,11 @@ int do_ipv6_setsockopt(struct sock *sk, int level, int optname,
return -EINVAL; return -EINVAL;
WRITE_ONCE(np->pmtudisc, val); WRITE_ONCE(np->pmtudisc, val);
return 0; return 0;
case IPV6_FLOWINFO_SEND:
if (optlen < sizeof(int))
return -EINVAL;
inet6_assign_bit(SNDFLOW, sk, valbool);
return 0;
} }
if (needs_rtnl) if (needs_rtnl)
rtnl_lock(); rtnl_lock();
...@@ -948,12 +953,6 @@ int do_ipv6_setsockopt(struct sock *sk, int level, int optname, ...@@ -948,12 +953,6 @@ int do_ipv6_setsockopt(struct sock *sk, int level, int optname,
goto e_inval; goto e_inval;
retv = ip6_ra_control(sk, val); retv = ip6_ra_control(sk, val);
break; break;
case IPV6_FLOWINFO_SEND:
if (optlen < sizeof(int))
goto e_inval;
np->sndflow = valbool;
retv = 0;
break;
case IPV6_FLOWLABEL_MGR: case IPV6_FLOWLABEL_MGR:
retv = ipv6_flowlabel_opt(sk, optval, optlen); retv = ipv6_flowlabel_opt(sk, optval, optlen);
break; break;
...@@ -1381,7 +1380,7 @@ int do_ipv6_getsockopt(struct sock *sk, int level, int optname, ...@@ -1381,7 +1380,7 @@ int do_ipv6_getsockopt(struct sock *sk, int level, int optname,
break; break;
case IPV6_FLOWINFO_SEND: case IPV6_FLOWINFO_SEND:
val = np->sndflow; val = inet6_test_bit(SNDFLOW, sk);
break; break;
case IPV6_FLOWLABEL_MGR: case IPV6_FLOWLABEL_MGR:
......
...@@ -89,7 +89,7 @@ static int ping_v6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) ...@@ -89,7 +89,7 @@ static int ping_v6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
return -EAFNOSUPPORT; return -EAFNOSUPPORT;
} }
daddr = &(u->sin6_addr); daddr = &(u->sin6_addr);
if (np->sndflow) if (inet6_test_bit(SNDFLOW, sk))
fl6.flowlabel = u->sin6_flowinfo & IPV6_FLOWINFO_MASK; fl6.flowlabel = u->sin6_flowinfo & IPV6_FLOWINFO_MASK;
if (__ipv6_addr_needs_scope_id(ipv6_addr_type(daddr))) if (__ipv6_addr_needs_scope_id(ipv6_addr_type(daddr)))
oif = u->sin6_scope_id; oif = u->sin6_scope_id;
......
...@@ -795,7 +795,7 @@ static int rawv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) ...@@ -795,7 +795,7 @@ static int rawv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
return -EINVAL; return -EINVAL;
daddr = &sin6->sin6_addr; daddr = &sin6->sin6_addr;
if (np->sndflow) { if (inet6_test_bit(SNDFLOW, sk)) {
fl6.flowlabel = sin6->sin6_flowinfo&IPV6_FLOWINFO_MASK; fl6.flowlabel = sin6->sin6_flowinfo&IPV6_FLOWINFO_MASK;
if (fl6.flowlabel&IPV6_FLOWLABEL_MASK) { if (fl6.flowlabel&IPV6_FLOWLABEL_MASK) {
flowlabel = fl6_sock_lookup(sk, fl6.flowlabel); flowlabel = fl6_sock_lookup(sk, fl6.flowlabel);
......
...@@ -163,7 +163,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr, ...@@ -163,7 +163,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
memset(&fl6, 0, sizeof(fl6)); memset(&fl6, 0, sizeof(fl6));
if (np->sndflow) { if (inet6_test_bit(SNDFLOW, sk)) {
fl6.flowlabel = usin->sin6_flowinfo&IPV6_FLOWINFO_MASK; fl6.flowlabel = usin->sin6_flowinfo&IPV6_FLOWINFO_MASK;
IP6_ECN_flow_init(fl6.flowlabel); IP6_ECN_flow_init(fl6.flowlabel);
if (fl6.flowlabel&IPV6_FLOWLABEL_MASK) { if (fl6.flowlabel&IPV6_FLOWLABEL_MASK) {
......
...@@ -1429,7 +1429,7 @@ int udpv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) ...@@ -1429,7 +1429,7 @@ int udpv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
fl6->fl6_dport = sin6->sin6_port; fl6->fl6_dport = sin6->sin6_port;
daddr = &sin6->sin6_addr; daddr = &sin6->sin6_addr;
if (np->sndflow) { if (inet6_test_bit(SNDFLOW, sk)) {
fl6->flowlabel = sin6->sin6_flowinfo&IPV6_FLOWINFO_MASK; fl6->flowlabel = sin6->sin6_flowinfo&IPV6_FLOWINFO_MASK;
if (fl6->flowlabel & IPV6_FLOWLABEL_MASK) { if (fl6->flowlabel & IPV6_FLOWLABEL_MASK) {
flowlabel = fl6_sock_lookup(sk, fl6->flowlabel); flowlabel = fl6_sock_lookup(sk, fl6->flowlabel);
......
...@@ -431,7 +431,7 @@ static int l2tp_ip6_getname(struct socket *sock, struct sockaddr *uaddr, ...@@ -431,7 +431,7 @@ static int l2tp_ip6_getname(struct socket *sock, struct sockaddr *uaddr,
return -ENOTCONN; return -ENOTCONN;
lsa->l2tp_conn_id = lsk->peer_conn_id; lsa->l2tp_conn_id = lsk->peer_conn_id;
lsa->l2tp_addr = sk->sk_v6_daddr; lsa->l2tp_addr = sk->sk_v6_daddr;
if (np->sndflow) if (inet6_test_bit(SNDFLOW, sk))
lsa->l2tp_flowinfo = np->flow_label; lsa->l2tp_flowinfo = np->flow_label;
} else { } else {
if (ipv6_addr_any(&sk->sk_v6_rcv_saddr)) if (ipv6_addr_any(&sk->sk_v6_rcv_saddr))
...@@ -529,7 +529,7 @@ static int l2tp_ip6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) ...@@ -529,7 +529,7 @@ static int l2tp_ip6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
return -EAFNOSUPPORT; return -EAFNOSUPPORT;
daddr = &lsa->l2tp_addr; daddr = &lsa->l2tp_addr;
if (np->sndflow) { if (inet6_test_bit(SNDFLOW, sk)) {
fl6.flowlabel = lsa->l2tp_flowinfo & IPV6_FLOWINFO_MASK; fl6.flowlabel = lsa->l2tp_flowinfo & IPV6_FLOWINFO_MASK;
if (fl6.flowlabel & IPV6_FLOWLABEL_MASK) { if (fl6.flowlabel & IPV6_FLOWLABEL_MASK) {
flowlabel = fl6_sock_lookup(sk, fl6.flowlabel); flowlabel = fl6_sock_lookup(sk, fl6.flowlabel);
......
...@@ -296,7 +296,8 @@ static void sctp_v6_get_dst(struct sctp_transport *t, union sctp_addr *saddr, ...@@ -296,7 +296,8 @@ static void sctp_v6_get_dst(struct sctp_transport *t, union sctp_addr *saddr,
if (t->flowlabel & SCTP_FLOWLABEL_SET_MASK) if (t->flowlabel & SCTP_FLOWLABEL_SET_MASK)
fl6->flowlabel = htonl(t->flowlabel & SCTP_FLOWLABEL_VAL_MASK); fl6->flowlabel = htonl(t->flowlabel & SCTP_FLOWLABEL_VAL_MASK);
if (np->sndflow && (fl6->flowlabel & IPV6_FLOWLABEL_MASK)) { if (inet6_test_bit(SNDFLOW, sk) &&
(fl6->flowlabel & IPV6_FLOWLABEL_MASK)) {
struct ip6_flowlabel *flowlabel; struct ip6_flowlabel *flowlabel;
flowlabel = fl6_sock_lookup(sk, fl6->flowlabel); flowlabel = fl6_sock_lookup(sk, fl6->flowlabel);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment