diff --git a/drivers/iommu/amd_iommu.c b/drivers/iommu/amd_iommu.c index c7e28a8d25d1..a227e7a9b8b7 100644 --- a/drivers/iommu/amd_iommu.c +++ b/drivers/iommu/amd_iommu.c @@ -501,6 +501,29 @@ static void iommu_uninit_device(struct device *dev) */ } +/* + * Helper function to get the first pte of a large mapping + */ +static u64 *first_pte_l7(u64 *pte, unsigned long *page_size, + unsigned long *count) +{ + unsigned long pte_mask, pg_size, cnt; + u64 *fpte; + + pg_size = PTE_PAGE_SIZE(*pte); + cnt = PAGE_SIZE_PTE_COUNT(pg_size); + pte_mask = ~((cnt << 3) - 1); + fpte = (u64 *)(((unsigned long)pte) & pte_mask); + + if (page_size) + *page_size = pg_size; + + if (count) + *count = cnt; + + return fpte; +} + /**************************************************************************** * * Interrupt handling functions @@ -1567,17 +1590,12 @@ static u64 *fetch_pte(struct protection_domain *domain, *page_size = PTE_LEVEL_PAGE_SIZE(level); } - if (PM_PTE_LEVEL(*pte) == 0x07) { - unsigned long pte_mask; - - /* - * If we have a series of large PTEs, make - * sure to return a pointer to the first one. - */ - *page_size = pte_mask = PTE_PAGE_SIZE(*pte); - pte_mask = ~((PAGE_SIZE_PTE_COUNT(pte_mask) << 3) - 1); - pte = (u64 *)(((unsigned long)pte) & pte_mask); - } + /* + * If we have a series of large PTEs, make + * sure to return a pointer to the first one. + */ + if (PM_PTE_LEVEL(*pte) == PAGE_MODE_7_LEVEL) + pte = first_pte_l7(pte, page_size, NULL); return pte; }