diff --git a/drivers/infiniband/core/core_priv.h b/drivers/infiniband/core/core_priv.h index 24f2aa2e1b7c..fe5607ebca80 100644 --- a/drivers/infiniband/core/core_priv.h +++ b/drivers/infiniband/core/core_priv.h @@ -279,7 +279,8 @@ static inline void ib_mad_agent_security_change(void) } #endif -struct ib_device *ib_device_get_by_index(u32 ifindex); +struct ib_device *ib_device_get_by_index(const struct net *net, u32 index); + /* RDMA device netlink */ void nldev_init(void); void nldev_exit(void); diff --git a/drivers/infiniband/core/device.c b/drivers/infiniband/core/device.c index 74736ea9b007..e6f82f4d4108 100644 --- a/drivers/infiniband/core/device.c +++ b/drivers/infiniband/core/device.c @@ -250,16 +250,22 @@ static int ib_device_check_mandatory(struct ib_device *device) * Caller must perform ib_device_put() to return the device reference count * when ib_device_get_by_index() returns valid device pointer. */ -struct ib_device *ib_device_get_by_index(u32 index) +struct ib_device *ib_device_get_by_index(const struct net *net, u32 index) { struct ib_device *device; down_read(&devices_rwsem); device = xa_load(&devices, index); if (device) { + if (!rdma_dev_access_netns(device, net)) { + device = NULL; + goto out; + } + if (!ib_device_try_get(device)) device = NULL; } +out: up_read(&devices_rwsem); return device; } @@ -1815,6 +1821,9 @@ int ib_enum_all_devs(nldev_callback nldev_cb, struct sk_buff *skb, down_read(&devices_rwsem); xa_for_each_marked (&devices, index, dev, DEVICE_REGISTERED) { + if (!rdma_dev_access_netns(dev, sock_net(skb->sk))) + continue; + ret = nldev_cb(dev, skb, cb, idx); if (ret) break; diff --git a/drivers/infiniband/core/nldev.c b/drivers/infiniband/core/nldev.c index 11ed58d3fce5..284e5f103fc9 100644 --- a/drivers/infiniband/core/nldev.c +++ b/drivers/infiniband/core/nldev.c @@ -614,7 +614,7 @@ static int nldev_get_doit(struct sk_buff *skb, struct nlmsghdr *nlh, index = nla_get_u32(tb[RDMA_NLDEV_ATTR_DEV_INDEX]); - device = ib_device_get_by_index(index); + device = ib_device_get_by_index(sock_net(skb->sk), index); if (!device) return -EINVAL; @@ -658,7 +658,7 @@ static int nldev_set_doit(struct sk_buff *skb, struct nlmsghdr *nlh, return -EINVAL; index = nla_get_u32(tb[RDMA_NLDEV_ATTR_DEV_INDEX]); - device = ib_device_get_by_index(index); + device = ib_device_get_by_index(sock_net(skb->sk), index); if (!device) return -EINVAL; @@ -706,7 +706,7 @@ static int nldev_get_dumpit(struct sk_buff *skb, struct netlink_callback *cb) { /* * There is no need to take lock, because - * we are relying on ib_core's lists_rwsem + * we are relying on ib_core's locking. */ return ib_enum_all_devs(_nldev_get_dumpit, skb, cb); } @@ -729,7 +729,7 @@ static int nldev_port_get_doit(struct sk_buff *skb, struct nlmsghdr *nlh, return -EINVAL; index = nla_get_u32(tb[RDMA_NLDEV_ATTR_DEV_INDEX]); - device = ib_device_get_by_index(index); + device = ib_device_get_by_index(sock_net(skb->sk), index); if (!device) return -EINVAL; @@ -783,7 +783,7 @@ static int nldev_port_get_dumpit(struct sk_buff *skb, return -EINVAL; ifindex = nla_get_u32(tb[RDMA_NLDEV_ATTR_DEV_INDEX]); - device = ib_device_get_by_index(ifindex); + device = ib_device_get_by_index(sock_net(skb->sk), ifindex); if (!device) return -EINVAL; @@ -838,7 +838,7 @@ static int nldev_res_get_doit(struct sk_buff *skb, struct nlmsghdr *nlh, return -EINVAL; index = nla_get_u32(tb[RDMA_NLDEV_ATTR_DEV_INDEX]); - device = ib_device_get_by_index(index); + device = ib_device_get_by_index(sock_net(skb->sk), index); if (!device) return -EINVAL; @@ -987,7 +987,7 @@ static int res_get_common_doit(struct sk_buff *skb, struct nlmsghdr *nlh, return -EINVAL; index = nla_get_u32(tb[RDMA_NLDEV_ATTR_DEV_INDEX]); - device = ib_device_get_by_index(index); + device = ib_device_get_by_index(sock_net(skb->sk), index); if (!device) return -EINVAL; @@ -1084,7 +1084,7 @@ static int res_get_common_dumpit(struct sk_buff *skb, return -EINVAL; index = nla_get_u32(tb[RDMA_NLDEV_ATTR_DEV_INDEX]); - device = ib_device_get_by_index(index); + device = ib_device_get_by_index(sock_net(skb->sk), index); if (!device) return -EINVAL; @@ -1299,7 +1299,7 @@ static int nldev_dellink(struct sk_buff *skb, struct nlmsghdr *nlh, return -EINVAL; index = nla_get_u32(tb[RDMA_NLDEV_ATTR_DEV_INDEX]); - device = ib_device_get_by_index(index); + device = ib_device_get_by_index(sock_net(skb->sk), index); if (!device) return -EINVAL;