net: devlink: make sure that devlink_try_get() works with valid pointer during xarray iteration

Remove dependency on devlink_mutex during devlinks xarray iteration.

The reason is that devlink_register/unregister() functions taking
devlink_mutex would deadlock during devlink reload operation of devlink
instance which registers/unregisters nested devlink instances.

The devlinks xarray consistency is ensured internally by xarray.
There is a reference taken when working with devlink using
devlink_try_get(). But there is no guarantee that devlink pointer
picked during xarray iteration is not freed before devlink_try_get()
is called.

Make sure that devlink_try_get() works with valid pointer.
Achieve it by:
1) Splitting devlink_put() so the completion is sent only
   after grace period. Completion unblocks the devlink_unregister()
   routine, which is followed-up by devlink_free()
2) During devlinks xa_array iteration, get devlink pointer from xa_array
   holding RCU read lock and taking reference using devlink_try_get()
   before unlock.

Signed-off-by: Jiri Pirko <jiri@nvidia.com>
Reviewed-by: Jakub Kicinski <kuba@kernel.org>
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
This commit is contained in:
Jiri Pirko 2022-07-25 10:29:14 +02:00 committed by Jakub Kicinski
parent 35d099da41
commit 30bab7cdb5
1 changed files with 80 additions and 91 deletions

View File

@ -70,6 +70,7 @@ struct devlink {
u8 reload_failed:1;
refcount_t refcount;
struct completion comp;
struct rcu_head rcu;
char priv[] __aligned(NETDEV_ALIGN);
};
@ -221,8 +222,6 @@ static DEFINE_XARRAY_FLAGS(devlinks, XA_FLAGS_ALLOC);
/* devlink_mutex
*
* An overall lock guarding every operation coming from userspace.
* It also guards devlink devices list and it is taken when
* driver registers/unregisters it.
*/
static DEFINE_MUTEX(devlink_mutex);
@ -232,10 +231,21 @@ struct net *devlink_net(const struct devlink *devlink)
}
EXPORT_SYMBOL_GPL(devlink_net);
static void __devlink_put_rcu(struct rcu_head *head)
{
struct devlink *devlink = container_of(head, struct devlink, rcu);
complete(&devlink->comp);
}
void devlink_put(struct devlink *devlink)
{
if (refcount_dec_and_test(&devlink->refcount))
complete(&devlink->comp);
/* Make sure unregister operation that may await the completion
* is unblocked only after all users are after the end of
* RCU grace period.
*/
call_rcu(&devlink->rcu, __devlink_put_rcu);
}
struct devlink *__must_check devlink_try_get(struct devlink *devlink)
@ -278,12 +288,55 @@ void devl_unlock(struct devlink *devlink)
}
EXPORT_SYMBOL_GPL(devl_unlock);
static struct devlink *
devlinks_xa_find_get(unsigned long *indexp, xa_mark_t filter,
void * (*xa_find_fn)(struct xarray *, unsigned long *,
unsigned long, xa_mark_t))
{
struct devlink *devlink;
rcu_read_lock();
retry:
devlink = xa_find_fn(&devlinks, indexp, ULONG_MAX, DEVLINK_REGISTERED);
if (!devlink)
goto unlock;
/* For a possible retry, the xa_find_after() should be always used */
xa_find_fn = xa_find_after;
if (!devlink_try_get(devlink))
goto retry;
unlock:
rcu_read_unlock();
return devlink;
}
static struct devlink *devlinks_xa_find_get_first(unsigned long *indexp,
xa_mark_t filter)
{
return devlinks_xa_find_get(indexp, filter, xa_find);
}
static struct devlink *devlinks_xa_find_get_next(unsigned long *indexp,
xa_mark_t filter)
{
return devlinks_xa_find_get(indexp, filter, xa_find_after);
}
/* Iterate over devlink pointers which were possible to get reference to.
* devlink_put() needs to be called for each iterated devlink pointer
* in loop body in order to release the reference.
*/
#define devlinks_xa_for_each_get(index, devlink, filter) \
for (index = 0, devlink = devlinks_xa_find_get_first(&index, filter); \
devlink; devlink = devlinks_xa_find_get_next(&index, filter))
#define devlinks_xa_for_each_registered_get(index, devlink) \
devlinks_xa_for_each_get(index, devlink, DEVLINK_REGISTERED)
static struct devlink *devlink_get_from_attrs(struct net *net,
struct nlattr **attrs)
{
struct devlink *devlink;
unsigned long index;
bool found = false;
char *busname;
char *devname;
@ -293,21 +346,15 @@ static struct devlink *devlink_get_from_attrs(struct net *net,
busname = nla_data(attrs[DEVLINK_ATTR_BUS_NAME]);
devname = nla_data(attrs[DEVLINK_ATTR_DEV_NAME]);
lockdep_assert_held(&devlink_mutex);
xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
devlinks_xa_for_each_registered_get(index, devlink) {
if (strcmp(devlink->dev->bus->name, busname) == 0 &&
strcmp(dev_name(devlink->dev), devname) == 0 &&
net_eq(devlink_net(devlink), net)) {
found = true;
break;
}
net_eq(devlink_net(devlink), net))
return devlink;
devlink_put(devlink);
}
if (!found || !devlink_try_get(devlink))
devlink = ERR_PTR(-ENODEV);
return devlink;
return ERR_PTR(-ENODEV);
}
static struct devlink_port *devlink_port_get_by_index(struct devlink *devlink,
@ -1329,10 +1376,7 @@ static int devlink_nl_cmd_rate_get_dumpit(struct sk_buff *msg,
int err = 0;
mutex_lock(&devlink_mutex);
xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
if (!devlink_try_get(devlink))
continue;
devlinks_xa_for_each_registered_get(index, devlink) {
if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
goto retry;
@ -1432,10 +1476,7 @@ static int devlink_nl_cmd_get_dumpit(struct sk_buff *msg,
int err;
mutex_lock(&devlink_mutex);
xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
if (!devlink_try_get(devlink))
continue;
devlinks_xa_for_each_registered_get(index, devlink) {
if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) {
devlink_put(devlink);
continue;
@ -1495,10 +1536,7 @@ static int devlink_nl_cmd_port_get_dumpit(struct sk_buff *msg,
int err;
mutex_lock(&devlink_mutex);
xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
if (!devlink_try_get(devlink))
continue;
devlinks_xa_for_each_registered_get(index, devlink) {
if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
goto retry;
@ -2177,10 +2215,7 @@ static int devlink_nl_cmd_linecard_get_dumpit(struct sk_buff *msg,
int err;
mutex_lock(&devlink_mutex);
xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
if (!devlink_try_get(devlink))
continue;
devlinks_xa_for_each_registered_get(index, devlink) {
if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
goto retry;
@ -2449,10 +2484,7 @@ static int devlink_nl_cmd_sb_get_dumpit(struct sk_buff *msg,
int err;
mutex_lock(&devlink_mutex);
xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
if (!devlink_try_get(devlink))
continue;
devlinks_xa_for_each_registered_get(index, devlink) {
if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
goto retry;
@ -2601,10 +2633,7 @@ static int devlink_nl_cmd_sb_pool_get_dumpit(struct sk_buff *msg,
int err = 0;
mutex_lock(&devlink_mutex);
xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
if (!devlink_try_get(devlink))
continue;
devlinks_xa_for_each_registered_get(index, devlink) {
if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) ||
!devlink->ops->sb_pool_get)
goto retry;
@ -2822,10 +2851,7 @@ static int devlink_nl_cmd_sb_port_pool_get_dumpit(struct sk_buff *msg,
int err = 0;
mutex_lock(&devlink_mutex);
xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
if (!devlink_try_get(devlink))
continue;
devlinks_xa_for_each_registered_get(index, devlink) {
if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) ||
!devlink->ops->sb_port_pool_get)
goto retry;
@ -3071,10 +3097,7 @@ devlink_nl_cmd_sb_tc_pool_bind_get_dumpit(struct sk_buff *msg,
int err = 0;
mutex_lock(&devlink_mutex);
xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
if (!devlink_try_get(devlink))
continue;
devlinks_xa_for_each_registered_get(index, devlink) {
if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) ||
!devlink->ops->sb_tc_pool_bind_get)
goto retry;
@ -5158,10 +5181,7 @@ static int devlink_nl_cmd_param_get_dumpit(struct sk_buff *msg,
int err = 0;
mutex_lock(&devlink_mutex);
xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
if (!devlink_try_get(devlink))
continue;
devlinks_xa_for_each_registered_get(index, devlink) {
if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
goto retry;
@ -5393,10 +5413,7 @@ static int devlink_nl_cmd_port_param_get_dumpit(struct sk_buff *msg,
int err = 0;
mutex_lock(&devlink_mutex);
xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
if (!devlink_try_get(devlink))
continue;
devlinks_xa_for_each_registered_get(index, devlink) {
if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
goto retry;
@ -5977,10 +5994,7 @@ static int devlink_nl_cmd_region_get_dumpit(struct sk_buff *msg,
int err = 0;
mutex_lock(&devlink_mutex);
xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
if (!devlink_try_get(devlink))
continue;
devlinks_xa_for_each_registered_get(index, devlink) {
if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
goto retry;
@ -6511,10 +6525,7 @@ static int devlink_nl_cmd_info_get_dumpit(struct sk_buff *msg,
int err = 0;
mutex_lock(&devlink_mutex);
xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
if (!devlink_try_get(devlink))
continue;
devlinks_xa_for_each_registered_get(index, devlink) {
if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
goto retry;
@ -7691,10 +7702,7 @@ devlink_nl_cmd_health_reporter_get_dumpit(struct sk_buff *msg,
int err;
mutex_lock(&devlink_mutex);
xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
if (!devlink_try_get(devlink))
continue;
devlinks_xa_for_each_registered_get(index, devlink) {
if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
goto retry_rep;
@ -7721,10 +7729,7 @@ retry_rep:
devlink_put(devlink);
}
xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
if (!devlink_try_get(devlink))
continue;
devlinks_xa_for_each_registered_get(index, devlink) {
if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
goto retry_port;
@ -8291,10 +8296,7 @@ static int devlink_nl_cmd_trap_get_dumpit(struct sk_buff *msg,
int err;
mutex_lock(&devlink_mutex);
xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
if (!devlink_try_get(devlink))
continue;
devlinks_xa_for_each_registered_get(index, devlink) {
if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
goto retry;
@ -8518,10 +8520,7 @@ static int devlink_nl_cmd_trap_group_get_dumpit(struct sk_buff *msg,
int err;
mutex_lock(&devlink_mutex);
xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
if (!devlink_try_get(devlink))
continue;
devlinks_xa_for_each_registered_get(index, devlink) {
if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
goto retry;
@ -8832,10 +8831,7 @@ static int devlink_nl_cmd_trap_policer_get_dumpit(struct sk_buff *msg,
int err;
mutex_lock(&devlink_mutex);
xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
if (!devlink_try_get(devlink))
continue;
devlinks_xa_for_each_registered_get(index, devlink) {
if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
goto retry;
@ -9589,10 +9585,8 @@ void devlink_register(struct devlink *devlink)
ASSERT_DEVLINK_NOT_REGISTERED(devlink);
/* Make sure that we are in .probe() routine */
mutex_lock(&devlink_mutex);
xa_set_mark(&devlinks, devlink->index, DEVLINK_REGISTERED);
devlink_notify_register(devlink);
mutex_unlock(&devlink_mutex);
}
EXPORT_SYMBOL_GPL(devlink_register);
@ -9609,10 +9603,8 @@ void devlink_unregister(struct devlink *devlink)
devlink_put(devlink);
wait_for_completion(&devlink->comp);
mutex_lock(&devlink_mutex);
devlink_notify_unregister(devlink);
xa_clear_mark(&devlinks, devlink->index, DEVLINK_REGISTERED);
mutex_unlock(&devlink_mutex);
}
EXPORT_SYMBOL_GPL(devlink_unregister);
@ -12281,10 +12273,7 @@ static void __net_exit devlink_pernet_pre_exit(struct net *net)
* all devlink instances from this namespace into init_net.
*/
mutex_lock(&devlink_mutex);
xa_for_each_marked(&devlinks, index, devlink, DEVLINK_REGISTERED) {
if (!devlink_try_get(devlink))
continue;
devlinks_xa_for_each_registered_get(index, devlink) {
if (!net_eq(devlink_net(devlink), net))
goto retry;