Commit f9d31c4c authored by Xin Long's avatar Xin Long Committed by David S. Miller

sctp: hold endpoint before calling cb in sctp_transport_lookup_process

The same fix in commit 5ec7d18d ("sctp: use call_rcu to free endpoint")
is also needed for dumping one asoc and sock after the lookup.

Fixes: 86fdb344 ("sctp: ensure ep is not destroyed before doing the dump")
Signed-off-by: default avatarXin Long <lucien.xin@gmail.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 5b40d10b
...@@ -112,8 +112,7 @@ struct sctp_transport *sctp_transport_get_next(struct net *net, ...@@ -112,8 +112,7 @@ struct sctp_transport *sctp_transport_get_next(struct net *net,
struct rhashtable_iter *iter); struct rhashtable_iter *iter);
struct sctp_transport *sctp_transport_get_idx(struct net *net, struct sctp_transport *sctp_transport_get_idx(struct net *net,
struct rhashtable_iter *iter, int pos); struct rhashtable_iter *iter, int pos);
int sctp_transport_lookup_process(int (*cb)(struct sctp_transport *, void *), int sctp_transport_lookup_process(sctp_callback_t cb, struct net *net,
struct net *net,
const union sctp_addr *laddr, const union sctp_addr *laddr,
const union sctp_addr *paddr, void *p); const union sctp_addr *paddr, void *p);
int sctp_transport_traverse_process(sctp_callback_t cb, sctp_callback_t cb_done, int sctp_transport_traverse_process(sctp_callback_t cb, sctp_callback_t cb_done,
......
...@@ -245,48 +245,44 @@ static size_t inet_assoc_attr_size(struct sctp_association *asoc) ...@@ -245,48 +245,44 @@ static size_t inet_assoc_attr_size(struct sctp_association *asoc)
+ 64; + 64;
} }
static int sctp_tsp_dump_one(struct sctp_transport *tsp, void *p) static int sctp_sock_dump_one(struct sctp_endpoint *ep, struct sctp_transport *tsp, void *p)
{ {
struct sctp_association *assoc = tsp->asoc; struct sctp_association *assoc = tsp->asoc;
struct sock *sk = tsp->asoc->base.sk;
struct sctp_comm_param *commp = p; struct sctp_comm_param *commp = p;
struct sk_buff *in_skb = commp->skb; struct sock *sk = ep->base.sk;
const struct inet_diag_req_v2 *req = commp->r; const struct inet_diag_req_v2 *req = commp->r;
const struct nlmsghdr *nlh = commp->nlh; struct sk_buff *skb = commp->skb;
struct net *net = sock_net(in_skb->sk);
struct sk_buff *rep; struct sk_buff *rep;
int err; int err;
err = sock_diag_check_cookie(sk, req->id.idiag_cookie); err = sock_diag_check_cookie(sk, req->id.idiag_cookie);
if (err) if (err)
goto out; return err;
err = -ENOMEM;
rep = nlmsg_new(inet_assoc_attr_size(assoc), GFP_KERNEL); rep = nlmsg_new(inet_assoc_attr_size(assoc), GFP_KERNEL);
if (!rep) if (!rep)
goto out; return -ENOMEM;
lock_sock(sk); lock_sock(sk);
if (sk != assoc->base.sk) { if (ep != assoc->ep) {
release_sock(sk); err = -EAGAIN;
sk = assoc->base.sk; goto out;
lock_sock(sk);
} }
err = inet_sctp_diag_fill(sk, assoc, rep, req,
sk_user_ns(NETLINK_CB(in_skb).sk), err = inet_sctp_diag_fill(sk, assoc, rep, req, sk_user_ns(NETLINK_CB(skb).sk),
NETLINK_CB(in_skb).portid, NETLINK_CB(skb).portid, commp->nlh->nlmsg_seq, 0,
nlh->nlmsg_seq, 0, nlh, commp->nlh, commp->net_admin);
commp->net_admin);
release_sock(sk);
if (err < 0) { if (err < 0) {
WARN_ON(err == -EMSGSIZE); WARN_ON(err == -EMSGSIZE);
kfree_skb(rep);
goto out; goto out;
} }
release_sock(sk);
err = nlmsg_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).portid); return nlmsg_unicast(sock_net(skb->sk)->diag_nlsk, rep, NETLINK_CB(skb).portid);
out: out:
release_sock(sk);
kfree_skb(rep);
return err; return err;
} }
...@@ -429,15 +425,15 @@ static void sctp_diag_get_info(struct sock *sk, struct inet_diag_msg *r, ...@@ -429,15 +425,15 @@ static void sctp_diag_get_info(struct sock *sk, struct inet_diag_msg *r,
static int sctp_diag_dump_one(struct netlink_callback *cb, static int sctp_diag_dump_one(struct netlink_callback *cb,
const struct inet_diag_req_v2 *req) const struct inet_diag_req_v2 *req)
{ {
struct sk_buff *in_skb = cb->skb; struct sk_buff *skb = cb->skb;
struct net *net = sock_net(in_skb->sk); struct net *net = sock_net(skb->sk);
const struct nlmsghdr *nlh = cb->nlh; const struct nlmsghdr *nlh = cb->nlh;
union sctp_addr laddr, paddr; union sctp_addr laddr, paddr;
struct sctp_comm_param commp = { struct sctp_comm_param commp = {
.skb = in_skb, .skb = skb,
.r = req, .r = req,
.nlh = nlh, .nlh = nlh,
.net_admin = netlink_net_capable(in_skb, CAP_NET_ADMIN), .net_admin = netlink_net_capable(skb, CAP_NET_ADMIN),
}; };
if (req->sdiag_family == AF_INET) { if (req->sdiag_family == AF_INET) {
...@@ -460,7 +456,7 @@ static int sctp_diag_dump_one(struct netlink_callback *cb, ...@@ -460,7 +456,7 @@ static int sctp_diag_dump_one(struct netlink_callback *cb,
paddr.v6.sin6_family = AF_INET6; paddr.v6.sin6_family = AF_INET6;
} }
return sctp_transport_lookup_process(sctp_tsp_dump_one, return sctp_transport_lookup_process(sctp_sock_dump_one,
net, &laddr, &paddr, &commp); net, &laddr, &paddr, &commp);
} }
......
...@@ -5317,23 +5317,31 @@ int sctp_for_each_endpoint(int (*cb)(struct sctp_endpoint *, void *), ...@@ -5317,23 +5317,31 @@ int sctp_for_each_endpoint(int (*cb)(struct sctp_endpoint *, void *),
} }
EXPORT_SYMBOL_GPL(sctp_for_each_endpoint); EXPORT_SYMBOL_GPL(sctp_for_each_endpoint);
int sctp_transport_lookup_process(int (*cb)(struct sctp_transport *, void *), int sctp_transport_lookup_process(sctp_callback_t cb, struct net *net,
struct net *net,
const union sctp_addr *laddr, const union sctp_addr *laddr,
const union sctp_addr *paddr, void *p) const union sctp_addr *paddr, void *p)
{ {
struct sctp_transport *transport; struct sctp_transport *transport;
int err; struct sctp_endpoint *ep;
int err = -ENOENT;
rcu_read_lock(); rcu_read_lock();
transport = sctp_addrs_lookup_transport(net, laddr, paddr); transport = sctp_addrs_lookup_transport(net, laddr, paddr);
if (!transport) {
rcu_read_unlock(); rcu_read_unlock();
if (!transport) return err;
return -ENOENT; }
ep = transport->asoc->ep;
err = cb(transport, p); if (!sctp_endpoint_hold(ep)) { /* asoc can be peeled off */
sctp_transport_put(transport); sctp_transport_put(transport);
rcu_read_unlock();
return err;
}
rcu_read_unlock();
err = cb(ep, transport, p);
sctp_endpoint_put(ep);
sctp_transport_put(transport);
return err; return err;
} }
EXPORT_SYMBOL_GPL(sctp_transport_lookup_process); EXPORT_SYMBOL_GPL(sctp_transport_lookup_process);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment