diff --git a/drivers/vfio/pci/vfio_pci.c b/drivers/vfio/pci/vfio_pci.c index fc011e13213b..c9d756b7ee9e 100644 --- a/drivers/vfio/pci/vfio_pci.c +++ b/drivers/vfio/pci/vfio_pci.c @@ -37,6 +37,8 @@ module_param_named(nointxmask, nointxmask, bool, S_IRUGO | S_IWUSR); MODULE_PARM_DESC(nointxmask, "Disable support for PCI 2.3 style INTx masking. If this resolves problems for specific devices, report lspci -vvvxxx to linux-pci@vger.kernel.org so the device can be fixed automatically via the broken_intx_masking flag."); +static DEFINE_MUTEX(driver_lock); + static int vfio_pci_enable(struct vfio_pci_device *vdev) { struct pci_dev *pdev = vdev->pdev; @@ -163,23 +165,29 @@ static void vfio_pci_release(void *device_data) { struct vfio_pci_device *vdev = device_data; - if (atomic_dec_and_test(&vdev->refcnt)) { + mutex_lock(&driver_lock); + + if (!(--vdev->refcnt)) { vfio_spapr_pci_eeh_release(vdev->pdev); vfio_pci_disable(vdev); } + mutex_unlock(&driver_lock); + module_put(THIS_MODULE); } static int vfio_pci_open(void *device_data) { struct vfio_pci_device *vdev = device_data; - int ret; + int ret = 0; if (!try_module_get(THIS_MODULE)) return -ENODEV; - if (atomic_inc_return(&vdev->refcnt) == 1) { + mutex_lock(&driver_lock); + + if (!vdev->refcnt) { ret = vfio_pci_enable(vdev); if (ret) goto error; @@ -190,10 +198,11 @@ static int vfio_pci_open(void *device_data) goto error; } } - - return 0; + vdev->refcnt++; error: - module_put(THIS_MODULE); + mutex_unlock(&driver_lock); + if (ret) + module_put(THIS_MODULE); return ret; } @@ -849,7 +858,6 @@ static int vfio_pci_probe(struct pci_dev *pdev, const struct pci_device_id *id) vdev->irq_type = VFIO_PCI_NUM_IRQS; mutex_init(&vdev->igate); spin_lock_init(&vdev->irqlock); - atomic_set(&vdev->refcnt, 0); ret = vfio_add_group_dev(&pdev->dev, &vfio_pci_ops, vdev); if (ret) { @@ -864,12 +872,15 @@ static void vfio_pci_remove(struct pci_dev *pdev) { struct vfio_pci_device *vdev; - vdev = vfio_del_group_dev(&pdev->dev); - if (!vdev) - return; + mutex_lock(&driver_lock); - iommu_group_put(pdev->dev.iommu_group); - kfree(vdev); + vdev = vfio_del_group_dev(&pdev->dev); + if (vdev) { + iommu_group_put(pdev->dev.iommu_group); + kfree(vdev); + } + + mutex_unlock(&driver_lock); } static pci_ers_result_t vfio_pci_aer_err_detected(struct pci_dev *pdev, diff --git a/drivers/vfio/pci/vfio_pci_private.h b/drivers/vfio/pci/vfio_pci_private.h index 9c6d5d0f3b02..31e7a30196ab 100644 --- a/drivers/vfio/pci/vfio_pci_private.h +++ b/drivers/vfio/pci/vfio_pci_private.h @@ -55,7 +55,7 @@ struct vfio_pci_device { bool bardirty; bool has_vga; struct pci_saved_state *pci_saved_state; - atomic_t refcnt; + int refcnt; struct eventfd_ctx *err_trigger; };