diff --git a/include/uapi/linux/mptcp.h b/include/uapi/linux/mptcp.h index f106a3941cdf..9690efedb5fa 100644 --- a/include/uapi/linux/mptcp.h +++ b/include/uapi/linux/mptcp.h @@ -81,6 +81,7 @@ enum { #define MPTCP_PM_ADDR_FLAG_SUBFLOW (1 << 1) #define MPTCP_PM_ADDR_FLAG_BACKUP (1 << 2) #define MPTCP_PM_ADDR_FLAG_FULLMESH (1 << 3) +#define MPTCP_PM_ADDR_FLAG_IMPLICIT (1 << 4) enum { MPTCP_PM_CMD_UNSPEC, diff --git a/net/mptcp/pm_netlink.c b/net/mptcp/pm_netlink.c index 91b77d1162cf..10368a4f1c4a 100644 --- a/net/mptcp/pm_netlink.c +++ b/net/mptcp/pm_netlink.c @@ -877,10 +877,18 @@ static bool address_use_port(struct mptcp_pm_addr_entry *entry) MPTCP_PM_ADDR_FLAG_SIGNAL; } +/* caller must ensure the RCU grace period is already elapsed */ +static void __mptcp_pm_release_addr_entry(struct mptcp_pm_addr_entry *entry) +{ + if (entry->lsk) + sock_release(entry->lsk); + kfree(entry); +} + static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet, struct mptcp_pm_addr_entry *entry) { - struct mptcp_pm_addr_entry *cur; + struct mptcp_pm_addr_entry *cur, *del_entry = NULL; unsigned int addr_max; int ret = -EINVAL; @@ -901,8 +909,22 @@ static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet, list_for_each_entry(cur, &pernet->local_addr_list, list) { if (addresses_equal(&cur->addr, &entry->addr, address_use_port(entry) && - address_use_port(cur))) - goto out; + address_use_port(cur))) { + /* allow replacing the exiting endpoint only if such + * endpoint is an implicit one and the user-space + * did not provide an endpoint id + */ + if (!(cur->flags & MPTCP_PM_ADDR_FLAG_IMPLICIT)) + goto out; + if (entry->addr.id) + goto out; + + pernet->addrs--; + entry->addr.id = cur->addr.id; + list_del_rcu(&cur->list); + del_entry = cur; + break; + } } if (!entry->addr.id) { @@ -938,6 +960,12 @@ find_next: out: spin_unlock_bh(&pernet->lock); + + /* just replaced an existing entry, free it */ + if (del_entry) { + synchronize_rcu(); + __mptcp_pm_release_addr_entry(del_entry); + } return ret; } @@ -1036,7 +1064,7 @@ int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc) entry->addr.id = 0; entry->addr.port = 0; entry->ifindex = 0; - entry->flags = 0; + entry->flags = MPTCP_PM_ADDR_FLAG_IMPLICIT; entry->lsk = NULL; ret = mptcp_pm_nl_append_new_local_addr(pernet, entry); if (ret < 0) @@ -1249,6 +1277,11 @@ static int mptcp_nl_cmd_add_addr(struct sk_buff *skb, struct genl_info *info) return -EINVAL; } + if (addr.flags & MPTCP_PM_ADDR_FLAG_IMPLICIT) { + GENL_SET_ERR_MSG(info, "can't create IMPLICIT endpoint"); + return -EINVAL; + } + entry = kmalloc(sizeof(*entry), GFP_KERNEL); if (!entry) { GENL_SET_ERR_MSG(info, "can't allocate addr"); @@ -1333,11 +1366,12 @@ static bool mptcp_pm_remove_anno_addr(struct mptcp_sock *msk, } static int mptcp_nl_remove_subflow_and_signal_addr(struct net *net, - struct mptcp_addr_info *addr) + const struct mptcp_pm_addr_entry *entry) { - struct mptcp_sock *msk; - long s_slot = 0, s_num = 0; + const struct mptcp_addr_info *addr = &entry->addr; struct mptcp_rm_list list = { .nr = 0 }; + long s_slot = 0, s_num = 0; + struct mptcp_sock *msk; pr_debug("remove_id=%d", addr->id); @@ -1354,7 +1388,8 @@ static int mptcp_nl_remove_subflow_and_signal_addr(struct net *net, lock_sock(sk); remove_subflow = lookup_subflow_by_saddr(&msk->conn_list, addr); - mptcp_pm_remove_anno_addr(msk, addr, remove_subflow); + mptcp_pm_remove_anno_addr(msk, addr, remove_subflow && + !(entry->flags & MPTCP_PM_ADDR_FLAG_IMPLICIT)); if (remove_subflow) mptcp_pm_remove_subflow(msk, &list); release_sock(sk); @@ -1367,14 +1402,6 @@ next: return 0; } -/* caller must ensure the RCU grace period is already elapsed */ -static void __mptcp_pm_release_addr_entry(struct mptcp_pm_addr_entry *entry) -{ - if (entry->lsk) - sock_release(entry->lsk); - kfree(entry); -} - static int mptcp_nl_remove_id_zero_address(struct net *net, struct mptcp_addr_info *addr) { @@ -1451,7 +1478,7 @@ static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info) __clear_bit(entry->addr.id, pernet->id_bitmap); spin_unlock_bh(&pernet->lock); - mptcp_nl_remove_subflow_and_signal_addr(sock_net(skb->sk), &entry->addr); + mptcp_nl_remove_subflow_and_signal_addr(sock_net(skb->sk), entry); synchronize_rcu(); __mptcp_pm_release_addr_entry(entry); diff --git a/tools/testing/selftests/net/mptcp/mptcp_join.sh b/tools/testing/selftests/net/mptcp/mptcp_join.sh index 02bab8a2d5a5..1e2e8dd9f0d6 100755 --- a/tools/testing/selftests/net/mptcp/mptcp_join.sh +++ b/tools/testing/selftests/net/mptcp/mptcp_join.sh @@ -1938,7 +1938,7 @@ backup_tests() run_tests $ns1 $ns2 10.0.1.1 0 0 0 slow backup chk_join_nr "single address, backup" 1 1 1 chk_add_nr 1 1 - chk_prio_nr 1 0 + chk_prio_nr 1 1 # single address with port, backup reset @@ -1948,7 +1948,7 @@ backup_tests() run_tests $ns1 $ns2 10.0.1.1 0 0 0 slow backup chk_join_nr "single address with port, backup" 1 1 1 chk_add_nr 1 1 - chk_prio_nr 1 0 + chk_prio_nr 1 1 } add_addr_ports_tests()