diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c index f2930699c6d3..b6dc6e260334 100644 --- a/net/mptcp/protocol.c +++ b/net/mptcp/protocol.c @@ -1673,6 +1673,37 @@ static void mptcp_set_nospace(struct sock *sk) set_bit(MPTCP_NOSPACE, &mptcp_sk(sk)->flags); } +static int mptcp_sendmsg_fastopen(struct sock *sk, struct sock *ssk, struct msghdr *msg, + size_t len, int *copied_syn) +{ + unsigned int saved_flags = msg->msg_flags; + struct mptcp_sock *msk = mptcp_sk(sk); + int ret; + + lock_sock(ssk); + msg->msg_flags |= MSG_DONTWAIT; + msk->connect_flags = O_NONBLOCK; + msk->is_sendmsg = 1; + ret = tcp_sendmsg_fastopen(ssk, msg, copied_syn, len, NULL); + msk->is_sendmsg = 0; + msg->msg_flags = saved_flags; + release_sock(ssk); + + /* do the blocking bits of inet_stream_connect outside the ssk socket lock */ + if (ret == -EINPROGRESS && !(msg->msg_flags & MSG_DONTWAIT)) { + ret = __inet_stream_connect(sk->sk_socket, msg->msg_name, + msg->msg_namelen, msg->msg_flags, 1); + + /* Keep the same behaviour of plain TCP: zero the copied bytes in + * case of any error, except timeout or signal + */ + if (ret && ret != -EINPROGRESS && ret != -ERESTARTSYS && ret != -EINTR) + *copied_syn = 0; + } + + return ret; +} + static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) { struct mptcp_sock *msk = mptcp_sk(sk); @@ -1693,26 +1724,14 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) ssock = __mptcp_nmpc_socket(msk); if (unlikely(ssock && inet_sk(ssock->sk)->defer_connect)) { - struct sock *ssk = ssock->sk; int copied_syn = 0; - lock_sock(ssk); - - msk->connect_flags = (msg->msg_flags & MSG_DONTWAIT) ? O_NONBLOCK : 0; - msk->is_sendmsg = 1; - ret = tcp_sendmsg_fastopen(ssk, msg, &copied_syn, len, NULL); - msk->is_sendmsg = 0; + ret = mptcp_sendmsg_fastopen(sk, ssock->sk, msg, len, &copied_syn); copied += copied_syn; - if (ret == -EINPROGRESS && copied_syn > 0) { - /* reflect the new state on the MPTCP socket */ - inet_sk_state_store(sk, inet_sk_state_load(ssk)); - release_sock(ssk); + if (ret == -EINPROGRESS && copied_syn > 0) goto out; - } else if (ret) { - release_sock(ssk); + else if (ret) goto do_error; - } - release_sock(ssk); } timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);