diff --git a/drivers/infiniband/hw/mlx5/odp.c b/drivers/infiniband/hw/mlx5/odp.c
index 01e0f6200631..4ee32964e1dd 100644
--- a/drivers/infiniband/hw/mlx5/odp.c
+++ b/drivers/infiniband/hw/mlx5/odp.c
@@ -1595,10 +1595,12 @@ static void mlx5_ib_prefetch_mr_work(struct work_struct *work)
 	struct prefetch_mr_work *w =
 		container_of(work, struct prefetch_mr_work, work);
 
-	if (w->dev->ib_dev.reg_state == IB_DEV_REGISTERED)
+	if (ib_device_try_get(&w->dev->ib_dev)) {
 		mlx5_ib_prefetch_sg_list(w->dev, w->pf_flags, w->sg_list,
 					 w->num_sge);
-
+		ib_device_put(&w->dev->ib_dev);
+	}
+	put_device(&w->dev->ib_dev.dev);
 	kfree(w);
 }
 
@@ -1617,15 +1619,13 @@ int mlx5_ib_advise_mr_prefetch(struct ib_pd *pd,
 		return mlx5_ib_prefetch_sg_list(dev, pf_flags, sg_list,
 						num_sge);
 
-	if (dev->ib_dev.reg_state != IB_DEV_REGISTERED)
-		return -ENODEV;
-
 	work = kvzalloc(struct_size(work, sg_list, num_sge), GFP_KERNEL);
 	if (!work)
 		return -ENOMEM;
 
 	memcpy(work->sg_list, sg_list, num_sge * sizeof(struct ib_sge));
 
+	get_device(&dev->ib_dev.dev);
 	work->dev = dev;
 	work->pf_flags = pf_flags;
 	work->num_sge = num_sge;