userfaultfd: use vma iterator

Use the vma iterator so that the iterator can be invalidated or updated to
avoid each caller doing so.

Link: https://lkml.kernel.org/r/20230120162650.984577-17-Liam.Howlett@oracle.com
Signed-off-by: Liam R. Howlett <Liam.Howlett@oracle.com>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
This commit is contained in:
Liam R. Howlett 2023-01-20 11:26:17 -05:00 committed by Andrew Morton
parent 27b2670112
commit 11a9b90274
1 changed files with 34 additions and 55 deletions

View File

@ -883,7 +883,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
/* len == 0 means wake all */
struct userfaultfd_wake_range range = { .len = 0, };
unsigned long new_flags;
MA_STATE(mas, &mm->mm_mt, 0, 0);
VMA_ITERATOR(vmi, mm, 0);
WRITE_ONCE(ctx->released, true);
@ -900,7 +900,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
*/
mmap_write_lock(mm);
prev = NULL;
mas_for_each(&mas, vma, ULONG_MAX) {
for_each_vma(vmi, vma) {
cond_resched();
BUG_ON(!!vma->vm_userfaultfd_ctx.ctx ^
!!(vma->vm_flags & __VM_UFFD_FLAGS));
@ -909,13 +909,12 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
continue;
}
new_flags = vma->vm_flags & ~__VM_UFFD_FLAGS;
prev = vma_merge(mm, prev, vma->vm_start, vma->vm_end,
prev = vmi_vma_merge(&vmi, mm, prev, vma->vm_start, vma->vm_end,
new_flags, vma->anon_vma,
vma->vm_file, vma->vm_pgoff,
vma_policy(vma),
NULL_VM_UFFD_CTX, anon_vma_name(vma));
if (prev) {
mas_pause(&mas);
vma = prev;
} else {
prev = vma;
@ -1302,7 +1301,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
bool found;
bool basic_ioctls;
unsigned long start, end, vma_end;
MA_STATE(mas, &mm->mm_mt, 0, 0);
struct vma_iterator vmi;
user_uffdio_register = (struct uffdio_register __user *) arg;
@ -1344,15 +1343,11 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
if (!mmget_not_zero(mm))
goto out;
mmap_write_lock(mm);
mas_set(&mas, start);
vma = mas_find(&mas, ULONG_MAX);
if (!vma)
goto out_unlock;
/* check that there's at least one vma in the range */
ret = -EINVAL;
if (vma->vm_start >= end)
mmap_write_lock(mm);
vma_iter_init(&vmi, mm, start);
vma = vma_find(&vmi, end);
if (!vma)
goto out_unlock;
/*
@ -1371,7 +1366,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
*/
found = false;
basic_ioctls = false;
for (cur = vma; cur; cur = mas_next(&mas, end - 1)) {
cur = vma;
do {
cond_resched();
BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^
@ -1428,16 +1424,14 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
basic_ioctls = true;
found = true;
}
} for_each_vma_range(vmi, cur, end);
BUG_ON(!found);
mas_set(&mas, start);
prev = mas_prev(&mas, 0);
if (prev != vma)
mas_next(&mas, ULONG_MAX);
vma_iter_set(&vmi, start);
prev = vma_prev(&vmi);
ret = 0;
do {
for_each_vma_range(vmi, vma, end) {
cond_resched();
BUG_ON(!vma_can_userfault(vma, vm_flags));
@ -1458,30 +1452,25 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
vma_end = min(end, vma->vm_end);
new_flags = (vma->vm_flags & ~__VM_UFFD_FLAGS) | vm_flags;
prev = vma_merge(mm, prev, start, vma_end, new_flags,
prev = vmi_vma_merge(&vmi, mm, prev, start, vma_end, new_flags,
vma->anon_vma, vma->vm_file, vma->vm_pgoff,
vma_policy(vma),
((struct vm_userfaultfd_ctx){ ctx }),
anon_vma_name(vma));
if (prev) {
/* vma_merge() invalidated the mas */
mas_pause(&mas);
vma = prev;
goto next;
}
if (vma->vm_start < start) {
ret = split_vma(mm, vma, start, 1);
ret = vmi_split_vma(&vmi, mm, vma, start, 1);
if (ret)
break;
/* split_vma() invalidated the mas */
mas_pause(&mas);
}
if (vma->vm_end > end) {
ret = split_vma(mm, vma, end, 0);
ret = vmi_split_vma(&vmi, mm, vma, end, 0);
if (ret)
break;
/* split_vma() invalidated the mas */
mas_pause(&mas);
}
next:
/*
@ -1498,8 +1487,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
skip:
prev = vma;
start = vma->vm_end;
vma = mas_next(&mas, end - 1);
} while (vma);
}
out_unlock:
mmap_write_unlock(mm);
mmput(mm);
@ -1543,7 +1532,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
bool found;
unsigned long start, end, vma_end;
const void __user *buf = (void __user *)arg;
MA_STATE(mas, &mm->mm_mt, 0, 0);
struct vma_iterator vmi;
ret = -EFAULT;
if (copy_from_user(&uffdio_unregister, buf, sizeof(uffdio_unregister)))
@ -1562,14 +1551,10 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
goto out;
mmap_write_lock(mm);
mas_set(&mas, start);
vma = mas_find(&mas, ULONG_MAX);
if (!vma)
goto out_unlock;
/* check that there's at least one vma in the range */
ret = -EINVAL;
if (vma->vm_start >= end)
vma_iter_init(&vmi, mm, start);
vma = vma_find(&vmi, end);
if (!vma)
goto out_unlock;
/*
@ -1587,8 +1572,8 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
* Search for not compatible vmas.
*/
found = false;
ret = -EINVAL;
for (cur = vma; cur; cur = mas_next(&mas, end - 1)) {
cur = vma;
do {
cond_resched();
BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^
@ -1605,16 +1590,13 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
goto out_unlock;
found = true;
}
} for_each_vma_range(vmi, cur, end);
BUG_ON(!found);
mas_set(&mas, start);
prev = mas_prev(&mas, 0);
if (prev != vma)
mas_next(&mas, ULONG_MAX);
vma_iter_set(&vmi, start);
prev = vma_prev(&vmi);
ret = 0;
do {
for_each_vma_range(vmi, vma, end) {
cond_resched();
BUG_ON(!vma_can_userfault(vma, vma->vm_flags));
@ -1650,26 +1632,23 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
uffd_wp_range(mm, vma, start, vma_end - start, false);
new_flags = vma->vm_flags & ~__VM_UFFD_FLAGS;
prev = vma_merge(mm, prev, start, vma_end, new_flags,
prev = vmi_vma_merge(&vmi, mm, prev, start, vma_end, new_flags,
vma->anon_vma, vma->vm_file, vma->vm_pgoff,
vma_policy(vma),
NULL_VM_UFFD_CTX, anon_vma_name(vma));
if (prev) {
vma = prev;
mas_pause(&mas);
goto next;
}
if (vma->vm_start < start) {
ret = split_vma(mm, vma, start, 1);
ret = vmi_split_vma(&vmi, mm, vma, start, 1);
if (ret)
break;
mas_pause(&mas);
}
if (vma->vm_end > end) {
ret = split_vma(mm, vma, end, 0);
ret = vmi_split_vma(&vmi, mm, vma, end, 0);
if (ret)
break;
mas_pause(&mas);
}
next:
/*
@ -1683,8 +1662,8 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
skip:
prev = vma;
start = vma->vm_end;
vma = mas_next(&mas, end - 1);
} while (vma);
}
out_unlock:
mmap_write_unlock(mm);
mmput(mm);