diff --git a/include/net/inet_hashtables.h b/include/net/inet_hashtables.h index 44a419b9e3d5..520dd894b73d 100644 --- a/include/net/inet_hashtables.h +++ b/include/net/inet_hashtables.h @@ -170,6 +170,16 @@ struct inet_hashinfo { struct inet_listen_hashbucket *lhash2; }; +static inline struct inet_hashinfo *tcp_or_dccp_get_hashinfo(const struct sock *sk) +{ +#if IS_ENABLED(CONFIG_IP_DCCP) + return sk->sk_prot->h.hashinfo ? : + sock_net(sk)->ipv4.tcp_death_row.hashinfo; +#else + return sock_net(sk)->ipv4.tcp_death_row.hashinfo; +#endif +} + static inline struct inet_listen_hashbucket * inet_lhash2_bucket(struct inet_hashinfo *h, u32 hash) { diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c index d3ab1ae32ef5..e2c219382345 100644 --- a/net/ipv4/af_inet.c +++ b/net/ipv4/af_inet.c @@ -1250,7 +1250,7 @@ static int inet_sk_reselect_saddr(struct sock *sk) } prev_addr_hashbucket = - inet_bhashfn_portaddr(sk->sk_prot->h.hashinfo, sk, + inet_bhashfn_portaddr(tcp_or_dccp_get_hashinfo(sk), sk, sock_net(sk), inet->inet_num); inet->inet_saddr = inet->inet_rcv_saddr = new_saddr; diff --git a/net/ipv4/inet_connection_sock.c b/net/ipv4/inet_connection_sock.c index 8e71d65cfad4..ebca860e113f 100644 --- a/net/ipv4/inet_connection_sock.c +++ b/net/ipv4/inet_connection_sock.c @@ -285,7 +285,7 @@ inet_csk_find_open_port(const struct sock *sk, struct inet_bind_bucket **tb_ret, struct inet_bind2_bucket **tb2_ret, struct inet_bind_hashbucket **head2_ret, int *port_ret) { - struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo; + struct inet_hashinfo *hinfo = tcp_or_dccp_get_hashinfo(sk); int i, low, high, attempt_half, port, l3mdev; struct inet_bind_hashbucket *head, *head2; struct net *net = sock_net(sk); @@ -467,8 +467,8 @@ void inet_csk_update_fastreuse(struct inet_bind_bucket *tb, */ int inet_csk_get_port(struct sock *sk, unsigned short snum) { + struct inet_hashinfo *hinfo = tcp_or_dccp_get_hashinfo(sk); bool reuse = sk->sk_reuse && sk->sk_state != TCP_LISTEN; - struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo; bool found_port = false, check_bind_conflict = true; bool bhash_created = false, bhash2_created = false; struct inet_bind_hashbucket *head, *head2; @@ -910,10 +910,9 @@ static bool reqsk_queue_unlink(struct request_sock *req) bool found = false; if (sk_hashed(sk)) { - struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo; - spinlock_t *lock; + struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk); + spinlock_t *lock = inet_ehash_lockp(hashinfo, req->rsk_hash); - lock = inet_ehash_lockp(hashinfo, req->rsk_hash); spin_lock(lock); found = __sk_nulls_del_node_init_rcu(sk); spin_unlock(lock); diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c index 29dce78de179..bdb5427a7a3d 100644 --- a/net/ipv4/inet_hashtables.c +++ b/net/ipv4/inet_hashtables.c @@ -168,7 +168,7 @@ void inet_bind_hash(struct sock *sk, struct inet_bind_bucket *tb, */ static void __inet_put_port(struct sock *sk) { - struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo; + struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk); struct inet_bind_hashbucket *head, *head2; struct net *net = sock_net(sk); struct inet_bind_bucket *tb; @@ -208,7 +208,7 @@ EXPORT_SYMBOL(inet_put_port); int __inet_inherit_port(const struct sock *sk, struct sock *child) { - struct inet_hashinfo *table = sk->sk_prot->h.hashinfo; + struct inet_hashinfo *table = tcp_or_dccp_get_hashinfo(sk); unsigned short port = inet_sk(child)->inet_num; struct inet_bind_hashbucket *head, *head2; bool created_inet_bind_bucket = false; @@ -629,7 +629,7 @@ static bool inet_ehash_lookup_by_sk(struct sock *sk, */ bool inet_ehash_insert(struct sock *sk, struct sock *osk, bool *found_dup_sk) { - struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo; + struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk); struct inet_ehash_bucket *head; struct hlist_nulls_head *list; spinlock_t *lock; @@ -701,7 +701,7 @@ static int inet_reuseport_add_sock(struct sock *sk, int __inet_hash(struct sock *sk, struct sock *osk) { - struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo; + struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk); struct inet_listen_hashbucket *ilb2; int err = 0; @@ -747,7 +747,7 @@ EXPORT_SYMBOL_GPL(inet_hash); void inet_unhash(struct sock *sk) { - struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo; + struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk); if (sk_unhashed(sk)) return; @@ -834,7 +834,7 @@ inet_bind2_bucket_find(const struct inet_bind_hashbucket *head, const struct net struct inet_bind_hashbucket * inet_bhash2_addr_any_hashbucket(const struct sock *sk, const struct net *net, int port) { - struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo; + struct inet_hashinfo *hinfo = tcp_or_dccp_get_hashinfo(sk); u32 hash; #if IS_ENABLED(CONFIG_IPV6) struct in6_addr addr_any = {}; @@ -850,7 +850,7 @@ inet_bhash2_addr_any_hashbucket(const struct sock *sk, const struct net *net, in int inet_bhash2_update_saddr(struct inet_bind_hashbucket *prev_saddr, struct sock *sk) { - struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo; + struct inet_hashinfo *hinfo = tcp_or_dccp_get_hashinfo(sk); struct inet_bind2_bucket *tb2, *new_tb2; int l3mdev = inet_sk_bound_l3mdev(sk); struct inet_bind_hashbucket *head2; diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c index 3930b6a1e0d6..3bb7da95b757 100644 --- a/net/ipv4/tcp_ipv4.c +++ b/net/ipv4/tcp_ipv4.c @@ -3083,7 +3083,7 @@ struct proto tcp_prot = { .slab_flags = SLAB_TYPESAFE_BY_RCU, .twsk_prot = &tcp_timewait_sock_ops, .rsk_prot = &tcp_request_sock_ops, - .h.hashinfo = &tcp_hashinfo, + .h.hashinfo = NULL, .no_autobind = true, .diag_destroy = tcp_abort, }; diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c index eb1da7a63fbb..e0b5f5b4d868 100644 --- a/net/ipv6/tcp_ipv6.c +++ b/net/ipv6/tcp_ipv6.c @@ -2194,7 +2194,7 @@ struct proto tcpv6_prot = { .slab_flags = SLAB_TYPESAFE_BY_RCU, .twsk_prot = &tcp6_timewait_sock_ops, .rsk_prot = &tcp6_request_sock_ops, - .h.hashinfo = &tcp_hashinfo, + .h.hashinfo = NULL, .no_autobind = true, .diag_destroy = tcp_abort, };