Commit ddb1a072 authored by Paolo Abeni's avatar Paolo Abeni Committed by David S. Miller

mptcp: move first subflow allocation at mpc access time

In the long run this will simplify the mptcp code and will
allow for more consistent behavior. Move the first subflow
allocation out of the sock->init ops into the __mptcp_nmpc_socket()
helper.

Since the first subflow creation can now happen after the first
setsockopt() we additionally need to invoke mptcp_sockopt_sync()
on it.
Signed-off-by: default avatarPaolo Abeni <pabeni@redhat.com>
Reviewed-by: default avatarMatthieu Baerts <matthieu.baerts@tessares.net>
Signed-off-by: default avatarMatthieu Baerts <matthieu.baerts@tessares.net>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent a2702a07
...@@ -1035,8 +1035,8 @@ static int mptcp_pm_nl_create_listen_socket(struct sock *sk, ...@@ -1035,8 +1035,8 @@ static int mptcp_pm_nl_create_listen_socket(struct sock *sk,
lock_sock(newsk); lock_sock(newsk);
ssock = __mptcp_nmpc_socket(mptcp_sk(newsk)); ssock = __mptcp_nmpc_socket(mptcp_sk(newsk));
release_sock(newsk); release_sock(newsk);
if (!ssock) if (IS_ERR(ssock))
return -EINVAL; return PTR_ERR(ssock);
mptcp_info2sockaddr(&entry->addr, &addr, entry->addr.family); mptcp_info2sockaddr(&entry->addr, &addr, entry->addr.family);
#if IS_ENABLED(CONFIG_MPTCP_IPV6) #if IS_ENABLED(CONFIG_MPTCP_IPV6)
......
...@@ -49,18 +49,6 @@ static void __mptcp_check_send_data_fin(struct sock *sk); ...@@ -49,18 +49,6 @@ static void __mptcp_check_send_data_fin(struct sock *sk);
DEFINE_PER_CPU(struct mptcp_delegated_action, mptcp_delegated_actions); DEFINE_PER_CPU(struct mptcp_delegated_action, mptcp_delegated_actions);
static struct net_device mptcp_napi_dev; static struct net_device mptcp_napi_dev;
/* If msk has an initial subflow socket, and the MP_CAPABLE handshake has not
* completed yet or has failed, return the subflow socket.
* Otherwise return NULL.
*/
struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk)
{
if (!msk->subflow || READ_ONCE(msk->can_ack))
return NULL;
return msk->subflow;
}
/* Returns end sequence number of the receiver's advertised window */ /* Returns end sequence number of the receiver's advertised window */
static u64 mptcp_wnd_end(const struct mptcp_sock *msk) static u64 mptcp_wnd_end(const struct mptcp_sock *msk)
{ {
...@@ -116,6 +104,31 @@ static int __mptcp_socket_create(struct mptcp_sock *msk) ...@@ -116,6 +104,31 @@ static int __mptcp_socket_create(struct mptcp_sock *msk)
return 0; return 0;
} }
/* If the MPC handshake is not started, returns the first subflow,
* eventually allocating it.
*/
struct socket *__mptcp_nmpc_socket(struct mptcp_sock *msk)
{
struct sock *sk = (struct sock *)msk;
int ret;
if (!((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_LISTEN)))
return ERR_PTR(-EINVAL);
if (!msk->subflow) {
if (msk->first)
return ERR_PTR(-EINVAL);
ret = __mptcp_socket_create(msk);
if (ret)
return ERR_PTR(ret);
mptcp_sockopt_sync(msk, msk->first);
}
return msk->subflow;
}
static void mptcp_drop(struct sock *sk, struct sk_buff *skb) static void mptcp_drop(struct sock *sk, struct sk_buff *skb)
{ {
sk_drops_add(sk, skb); sk_drops_add(sk, skb);
...@@ -1667,6 +1680,7 @@ static int mptcp_sendmsg_fastopen(struct sock *sk, struct msghdr *msg, ...@@ -1667,6 +1680,7 @@ static int mptcp_sendmsg_fastopen(struct sock *sk, struct msghdr *msg,
{ {
unsigned int saved_flags = msg->msg_flags; unsigned int saved_flags = msg->msg_flags;
struct mptcp_sock *msk = mptcp_sk(sk); struct mptcp_sock *msk = mptcp_sk(sk);
struct socket *ssock;
struct sock *ssk; struct sock *ssk;
int ret; int ret;
...@@ -1676,8 +1690,11 @@ static int mptcp_sendmsg_fastopen(struct sock *sk, struct msghdr *msg, ...@@ -1676,8 +1690,11 @@ static int mptcp_sendmsg_fastopen(struct sock *sk, struct msghdr *msg,
* Since the defer_connect flag is cleared after the first succsful * Since the defer_connect flag is cleared after the first succsful
* fastopen attempt, no need to check for additional subflow status. * fastopen attempt, no need to check for additional subflow status.
*/ */
if (msg->msg_flags & MSG_FASTOPEN && !__mptcp_nmpc_socket(msk)) if (msg->msg_flags & MSG_FASTOPEN) {
return -EINVAL; ssock = __mptcp_nmpc_socket(msk);
if (IS_ERR(ssock))
return PTR_ERR(ssock);
}
if (!msk->first) if (!msk->first)
return -EINVAL; return -EINVAL;
...@@ -2740,10 +2757,6 @@ static int mptcp_init_sock(struct sock *sk) ...@@ -2740,10 +2757,6 @@ static int mptcp_init_sock(struct sock *sk)
if (unlikely(!net->mib.mptcp_statistics) && !mptcp_mib_alloc(net)) if (unlikely(!net->mib.mptcp_statistics) && !mptcp_mib_alloc(net))
return -ENOMEM; return -ENOMEM;
ret = __mptcp_socket_create(mptcp_sk(sk));
if (ret)
return ret;
set_bit(SOCK_CUSTOM_SOCKOPT, &sk->sk_socket->flags); set_bit(SOCK_CUSTOM_SOCKOPT, &sk->sk_socket->flags);
/* fetch the ca name; do it outside __mptcp_init_sock(), so that clone will /* fetch the ca name; do it outside __mptcp_init_sock(), so that clone will
...@@ -3563,8 +3576,8 @@ static int mptcp_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) ...@@ -3563,8 +3576,8 @@ static int mptcp_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
int err = -EINVAL; int err = -EINVAL;
ssock = __mptcp_nmpc_socket(msk); ssock = __mptcp_nmpc_socket(msk);
if (!ssock) if (IS_ERR(ssock))
return -EINVAL; return PTR_ERR(ssock);
mptcp_token_destroy(msk); mptcp_token_destroy(msk);
inet_sk_state_store(sk, TCP_SYN_SENT); inet_sk_state_store(sk, TCP_SYN_SENT);
...@@ -3652,8 +3665,8 @@ static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) ...@@ -3652,8 +3665,8 @@ static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
lock_sock(sock->sk); lock_sock(sock->sk);
ssock = __mptcp_nmpc_socket(msk); ssock = __mptcp_nmpc_socket(msk);
if (!ssock) { if (IS_ERR(ssock)) {
err = -EINVAL; err = PTR_ERR(ssock);
goto unlock; goto unlock;
} }
...@@ -3689,8 +3702,8 @@ static int mptcp_listen(struct socket *sock, int backlog) ...@@ -3689,8 +3702,8 @@ static int mptcp_listen(struct socket *sock, int backlog)
lock_sock(sk); lock_sock(sk);
ssock = __mptcp_nmpc_socket(msk); ssock = __mptcp_nmpc_socket(msk);
if (!ssock) { if (IS_ERR(ssock)) {
err = -EINVAL; err = PTR_ERR(ssock);
goto unlock; goto unlock;
} }
......
...@@ -627,7 +627,7 @@ void mptcp_close_ssk(struct sock *sk, struct sock *ssk, ...@@ -627,7 +627,7 @@ void mptcp_close_ssk(struct sock *sk, struct sock *ssk,
void __mptcp_subflow_send_ack(struct sock *ssk); void __mptcp_subflow_send_ack(struct sock *ssk);
void mptcp_subflow_reset(struct sock *ssk); void mptcp_subflow_reset(struct sock *ssk);
void mptcp_sock_graft(struct sock *sk, struct socket *parent); void mptcp_sock_graft(struct sock *sk, struct socket *parent);
struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk); struct socket *__mptcp_nmpc_socket(struct mptcp_sock *msk);
bool __mptcp_close(struct sock *sk, long timeout); bool __mptcp_close(struct sock *sk, long timeout);
void mptcp_cancel_work(struct sock *sk); void mptcp_cancel_work(struct sock *sk);
void mptcp_set_owner_r(struct sk_buff *skb, struct sock *sk); void mptcp_set_owner_r(struct sk_buff *skb, struct sock *sk);
......
...@@ -301,9 +301,9 @@ static int mptcp_setsockopt_sol_socket(struct mptcp_sock *msk, int optname, ...@@ -301,9 +301,9 @@ static int mptcp_setsockopt_sol_socket(struct mptcp_sock *msk, int optname,
case SO_BINDTOIFINDEX: case SO_BINDTOIFINDEX:
lock_sock(sk); lock_sock(sk);
ssock = __mptcp_nmpc_socket(msk); ssock = __mptcp_nmpc_socket(msk);
if (!ssock) { if (IS_ERR(ssock)) {
release_sock(sk); release_sock(sk);
return -EINVAL; return PTR_ERR(ssock);
} }
ret = sock_setsockopt(ssock, SOL_SOCKET, optname, optval, optlen); ret = sock_setsockopt(ssock, SOL_SOCKET, optname, optval, optlen);
...@@ -396,9 +396,9 @@ static int mptcp_setsockopt_v6(struct mptcp_sock *msk, int optname, ...@@ -396,9 +396,9 @@ static int mptcp_setsockopt_v6(struct mptcp_sock *msk, int optname,
case IPV6_FREEBIND: case IPV6_FREEBIND:
lock_sock(sk); lock_sock(sk);
ssock = __mptcp_nmpc_socket(msk); ssock = __mptcp_nmpc_socket(msk);
if (!ssock) { if (IS_ERR(ssock)) {
release_sock(sk); release_sock(sk);
return -EINVAL; return PTR_ERR(ssock);
} }
ret = tcp_setsockopt(ssock->sk, SOL_IPV6, optname, optval, optlen); ret = tcp_setsockopt(ssock->sk, SOL_IPV6, optname, optval, optlen);
...@@ -693,9 +693,9 @@ static int mptcp_setsockopt_sol_ip_set_transparent(struct mptcp_sock *msk, int o ...@@ -693,9 +693,9 @@ static int mptcp_setsockopt_sol_ip_set_transparent(struct mptcp_sock *msk, int o
lock_sock(sk); lock_sock(sk);
ssock = __mptcp_nmpc_socket(msk); ssock = __mptcp_nmpc_socket(msk);
if (!ssock) { if (IS_ERR(ssock)) {
release_sock(sk); release_sock(sk);
return -EINVAL; return PTR_ERR(ssock);
} }
issk = inet_sk(ssock->sk); issk = inet_sk(ssock->sk);
...@@ -762,13 +762,15 @@ static int mptcp_setsockopt_first_sf_only(struct mptcp_sock *msk, int level, int ...@@ -762,13 +762,15 @@ static int mptcp_setsockopt_first_sf_only(struct mptcp_sock *msk, int level, int
{ {
struct sock *sk = (struct sock *)msk; struct sock *sk = (struct sock *)msk;
struct socket *sock; struct socket *sock;
int ret = -EINVAL; int ret;
/* Limit to first subflow, before the connection establishment */ /* Limit to first subflow, before the connection establishment */
lock_sock(sk); lock_sock(sk);
sock = __mptcp_nmpc_socket(msk); sock = __mptcp_nmpc_socket(msk);
if (!sock) if (IS_ERR(sock)) {
ret = PTR_ERR(sock);
goto unlock; goto unlock;
}
ret = tcp_setsockopt(sock->sk, level, optname, optval, optlen); ret = tcp_setsockopt(sock->sk, level, optname, optval, optlen);
...@@ -861,7 +863,7 @@ static int mptcp_getsockopt_first_sf_only(struct mptcp_sock *msk, int level, int ...@@ -861,7 +863,7 @@ static int mptcp_getsockopt_first_sf_only(struct mptcp_sock *msk, int level, int
{ {
struct sock *sk = (struct sock *)msk; struct sock *sk = (struct sock *)msk;
struct socket *ssock; struct socket *ssock;
int ret = -EINVAL; int ret;
struct sock *ssk; struct sock *ssk;
lock_sock(sk); lock_sock(sk);
...@@ -872,8 +874,10 @@ static int mptcp_getsockopt_first_sf_only(struct mptcp_sock *msk, int level, int ...@@ -872,8 +874,10 @@ static int mptcp_getsockopt_first_sf_only(struct mptcp_sock *msk, int level, int
} }
ssock = __mptcp_nmpc_socket(msk); ssock = __mptcp_nmpc_socket(msk);
if (!ssock) if (IS_ERR(ssock)) {
ret = PTR_ERR(ssock);
goto out; goto out;
}
ret = tcp_getsockopt(ssock->sk, level, optname, optval, optlen); ret = tcp_getsockopt(ssock->sk, level, optname, optval, optlen);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment