diff --git a/include/net/sock.h b/include/net/sock.h index 5db02546941c..e0517ecc6531 100644 --- a/include/net/sock.h +++ b/include/net/sock.h @@ -323,7 +323,7 @@ struct sk_filter; * @sk_tskey: counter to disambiguate concurrent tstamp requests * @sk_zckey: counter to order MSG_ZEROCOPY notifications * @sk_socket: Identd and reporting IO signals - * @sk_user_data: RPC layer private data + * @sk_user_data: RPC layer private data. Write-protected by @sk_callback_lock. * @sk_frag: cached page frag * @sk_peek_off: current peek_offset value * @sk_send_head: front of stuff to transmit diff --git a/net/l2tp/l2tp_core.c b/net/l2tp/l2tp_core.c index 7499c51b1850..754fdda8a5f5 100644 --- a/net/l2tp/l2tp_core.c +++ b/net/l2tp/l2tp_core.c @@ -1150,8 +1150,10 @@ static void l2tp_tunnel_destruct(struct sock *sk) } /* Remove hooks into tunnel socket */ + write_lock_bh(&sk->sk_callback_lock); sk->sk_destruct = tunnel->old_sk_destruct; sk->sk_user_data = NULL; + write_unlock_bh(&sk->sk_callback_lock); /* Call the original destructor */ if (sk->sk_destruct) @@ -1469,16 +1471,18 @@ int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net, sock = sockfd_lookup(tunnel->fd, &ret); if (!sock) goto err; - - ret = l2tp_validate_socket(sock->sk, net, tunnel->encap); - if (ret < 0) - goto err_sock; } + sk = sock->sk; + write_lock(&sk->sk_callback_lock); + + ret = l2tp_validate_socket(sk, net, tunnel->encap); + if (ret < 0) + goto err_sock; + tunnel->l2tp_net = net; pn = l2tp_pernet(net); - sk = sock->sk; sock_hold(sk); tunnel->sock = sk; @@ -1504,7 +1508,7 @@ int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net, setup_udp_tunnel_sock(net, sock, &udp_cfg); } else { - sk->sk_user_data = tunnel; + rcu_assign_sk_user_data(sk, tunnel); } tunnel->old_sk_destruct = sk->sk_destruct; @@ -1518,6 +1522,7 @@ int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net, if (tunnel->fd >= 0) sockfd_put(sock); + write_unlock(&sk->sk_callback_lock); return 0; err_sock: @@ -1525,6 +1530,8 @@ err_sock: sock_release(sock); else sockfd_put(sock); + + write_unlock(&sk->sk_callback_lock); err: return ret; }