Commit 49b06385 authored by Suren Baghdasaryan, committed by Andrew Morton

mm: enable page walking API to lock vmas during the walk

walk_page_range() and friends often operate under write-locked mmap_lock.
With the introduction of vma locks, the vmas have to be locked as well during
such walks to prevent concurrent page faults in these areas.  Add an
additional member to mm_walk_ops to indicate the locking requirements for the
walk.

The change ensures that page walks which prevent concurrent page faults
by write-locking mmap_lock operate correctly after the introduction of
per-vma locks.  With per-vma locks, page faults can be handled under the vma
lock without taking mmap_lock at all, so write-locking mmap_lock alone would
not stop them.  The change ensures vmas are properly locked during such
walks.

A sample issue this solves is do_mbind() performing queue_pages_range()
to queue pages for migration.  Without this change, a page could be
concurrently faulted into the area and left out of migration.
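
For illustration, a minimal sketch of how a walker opts into the new
locking (the example_* names are hypothetical; the hunks below show the
real call sites and the real enum/field definitions):

#include <linux/mm.h>
#include <linux/pagewalk.h>

/* Hypothetical callback; real users do their per-PMD work here. */
static int example_pmd_entry(pmd_t *pmd, unsigned long addr,
                             unsigned long next, struct mm_walk *walk)
{
        return 0;
}

/*
 * PGWALK_WRLOCK: the page walk code write-locks each vma it visits, so
 * page faults handled under the per-vma lock are excluded from the range
 * even though they no longer take mmap_lock.
 */
static const struct mm_walk_ops example_walk_ops = {
        .pmd_entry = example_pmd_entry,
        .walk_lock = PGWALK_WRLOCK,
};

static int example_walk(struct mm_struct *mm, unsigned long start,
                        unsigned long end)
{
        int err;

        mmap_write_lock(mm);    /* mmap_lock is still required for the walk */
        err = walk_page_range(mm, start, end, &example_walk_ops, NULL);
        mmap_write_unlock(mm);
        return err;
}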

Link: https://lkml.kernel.org/r/20230804152724.3090321-2-surenb@google.com
Signed-off-by: Suren Baghdasaryan <surenb@google.com>
Suggested-by: Linus Torvalds <torvalds@linuxfoundation.org>
Suggested-by: Jann Horn <jannh@google.com>
Cc: David Hildenbrand <david@redhat.com>
Cc: Davidlohr Bueso <dave@stgolabs.net>
Cc: Hugh Dickins <hughd@google.com>
Cc: Johannes Weiner <hannes@cmpxchg.org>
Cc: Laurent Dufour <ldufour@linux.ibm.com>
Cc: Liam Howlett <liam.howlett@oracle.com>
Cc: Matthew Wilcox (Oracle) <willy@infradead.org>
Cc: Michal Hocko <mhocko@suse.com>
Cc: Michel Lespinasse <michel@lespinasse.org>
Cc: Peter Xu <peterx@redhat.com>
Cc: Vlastimil Babka <vbabka@suse.cz>
Cc: <stable@vger.kernel.org>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
parent 8b9c1cc0
@@ -145,6 +145,7 @@ static int subpage_walk_pmd_entry(pmd_t *pmd, unsigned long addr,
 static const struct mm_walk_ops subpage_walk_ops = {
         .pmd_entry = subpage_walk_pmd_entry,
+        .walk_lock = PGWALK_WRLOCK_VERIFY,
 };
 
 static void subpage_mark_vma_nohuge(struct mm_struct *mm, unsigned long addr,
...
@@ -102,6 +102,7 @@ static const struct mm_walk_ops pageattr_ops = {
         .pmd_entry = pageattr_pmd_entry,
         .pte_entry = pageattr_pte_entry,
         .pte_hole = pageattr_pte_hole,
+        .walk_lock = PGWALK_RDLOCK,
 };
 
 static int __set_memory(unsigned long addr, int numpages, pgprot_t set_mask,
...
@@ -2514,6 +2514,7 @@ static int thp_split_walk_pmd_entry(pmd_t *pmd, unsigned long addr,
 static const struct mm_walk_ops thp_split_walk_ops = {
         .pmd_entry = thp_split_walk_pmd_entry,
+        .walk_lock = PGWALK_WRLOCK_VERIFY,
 };
 
 static inline void thp_split_mm(struct mm_struct *mm)
@@ -2565,6 +2566,7 @@ static int __zap_zero_pages(pmd_t *pmd, unsigned long start,
 static const struct mm_walk_ops zap_zero_walk_ops = {
         .pmd_entry = __zap_zero_pages,
+        .walk_lock = PGWALK_WRLOCK,
 };
 
 /*
@@ -2655,6 +2657,7 @@ static const struct mm_walk_ops enable_skey_walk_ops = {
         .hugetlb_entry = __s390_enable_skey_hugetlb,
         .pte_entry = __s390_enable_skey_pte,
         .pmd_entry = __s390_enable_skey_pmd,
+        .walk_lock = PGWALK_WRLOCK,
 };
 
 int s390_enable_skey(void)
@@ -2692,6 +2695,7 @@ static int __s390_reset_cmma(pte_t *pte, unsigned long addr,
 static const struct mm_walk_ops reset_cmma_walk_ops = {
         .pte_entry = __s390_reset_cmma,
+        .walk_lock = PGWALK_WRLOCK,
 };
 
 void s390_reset_cmma(struct mm_struct *mm)
@@ -2728,6 +2732,7 @@ static int s390_gather_pages(pte_t *ptep, unsigned long addr,
 static const struct mm_walk_ops gather_pages_ops = {
         .pte_entry = s390_gather_pages,
+        .walk_lock = PGWALK_RDLOCK,
 };
 
 /*
...
@@ -757,12 +757,14 @@ static int smaps_hugetlb_range(pte_t *pte, unsigned long hmask,
 static const struct mm_walk_ops smaps_walk_ops = {
         .pmd_entry = smaps_pte_range,
         .hugetlb_entry = smaps_hugetlb_range,
+        .walk_lock = PGWALK_RDLOCK,
 };
 
 static const struct mm_walk_ops smaps_shmem_walk_ops = {
         .pmd_entry = smaps_pte_range,
         .hugetlb_entry = smaps_hugetlb_range,
         .pte_hole = smaps_pte_hole,
+        .walk_lock = PGWALK_RDLOCK,
 };
 
 /*
@@ -1244,6 +1246,7 @@ static int clear_refs_test_walk(unsigned long start, unsigned long end,
 static const struct mm_walk_ops clear_refs_walk_ops = {
         .pmd_entry = clear_refs_pte_range,
         .test_walk = clear_refs_test_walk,
+        .walk_lock = PGWALK_WRLOCK,
 };
 
 static ssize_t clear_refs_write(struct file *file, const char __user *buf,
@@ -1621,6 +1624,7 @@ static const struct mm_walk_ops pagemap_ops = {
         .pmd_entry = pagemap_pmd_range,
         .pte_hole = pagemap_pte_hole,
         .hugetlb_entry = pagemap_hugetlb_range,
+        .walk_lock = PGWALK_RDLOCK,
 };
 
 /*
@@ -1934,6 +1938,7 @@ static int gather_hugetlb_stats(pte_t *pte, unsigned long hmask,
 static const struct mm_walk_ops show_numa_ops = {
         .hugetlb_entry = gather_hugetlb_stats,
         .pmd_entry = gather_pte_stats,
+        .walk_lock = PGWALK_RDLOCK,
 };
 
 /*
...
@@ -6,6 +6,16 @@
 struct mm_walk;
 
+/* Locking requirement during a page walk. */
+enum page_walk_lock {
+        /* mmap_lock should be locked for read to stabilize the vma tree */
+        PGWALK_RDLOCK = 0,
+        /* vma will be write-locked during the walk */
+        PGWALK_WRLOCK = 1,
+        /* vma is expected to be already write-locked during the walk */
+        PGWALK_WRLOCK_VERIFY = 2,
+};
+
 /**
  * struct mm_walk_ops - callbacks for walk_page_range
  * @pgd_entry: if set, called for each non-empty PGD (top-level) entry
@@ -66,6 +76,7 @@ struct mm_walk_ops {
         int (*pre_vma)(unsigned long start, unsigned long end,
                        struct mm_walk *walk);
         void (*post_vma)(struct mm_walk *walk);
+        enum page_walk_lock walk_lock;
 };
 
 /*
...
@@ -386,6 +386,7 @@ static int damon_mkold_hugetlb_entry(pte_t *pte, unsigned long hmask,
 static const struct mm_walk_ops damon_mkold_ops = {
         .pmd_entry = damon_mkold_pmd_entry,
         .hugetlb_entry = damon_mkold_hugetlb_entry,
+        .walk_lock = PGWALK_RDLOCK,
 };
 
 static void damon_va_mkold(struct mm_struct *mm, unsigned long addr)
@@ -525,6 +526,7 @@ static int damon_young_hugetlb_entry(pte_t *pte, unsigned long hmask,
 static const struct mm_walk_ops damon_young_ops = {
         .pmd_entry = damon_young_pmd_entry,
         .hugetlb_entry = damon_young_hugetlb_entry,
+        .walk_lock = PGWALK_RDLOCK,
 };
 
 static bool damon_va_young(struct mm_struct *mm, unsigned long addr,
...
@@ -562,6 +562,7 @@ static const struct mm_walk_ops hmm_walk_ops = {
         .pte_hole = hmm_vma_walk_hole,
         .hugetlb_entry = hmm_vma_walk_hugetlb_entry,
         .test_walk = hmm_vma_walk_test,
+        .walk_lock = PGWALK_RDLOCK,
 };
 
 /**
...
@@ -455,6 +455,12 @@ static int break_ksm_pmd_entry(pmd_t *pmd, unsigned long addr, unsigned long nex
 static const struct mm_walk_ops break_ksm_ops = {
         .pmd_entry = break_ksm_pmd_entry,
+        .walk_lock = PGWALK_RDLOCK,
+};
+
+static const struct mm_walk_ops break_ksm_lock_vma_ops = {
+        .pmd_entry = break_ksm_pmd_entry,
+        .walk_lock = PGWALK_WRLOCK,
 };
 
 /*
@@ -470,16 +476,17 @@ static const struct mm_walk_ops break_ksm_ops = {
  * of the process that owns 'vma'.  We also do not want to enforce
  * protection keys here anyway.
  */
-static int break_ksm(struct vm_area_struct *vma, unsigned long addr)
+static int break_ksm(struct vm_area_struct *vma, unsigned long addr, bool lock_vma)
 {
         vm_fault_t ret = 0;
+        const struct mm_walk_ops *ops = lock_vma ?
+                                &break_ksm_lock_vma_ops : &break_ksm_ops;
 
         do {
                 int ksm_page;
 
                 cond_resched();
-                ksm_page = walk_page_range_vma(vma, addr, addr + 1,
-                                               &break_ksm_ops, NULL);
+                ksm_page = walk_page_range_vma(vma, addr, addr + 1, ops, NULL);
                 if (WARN_ON_ONCE(ksm_page < 0))
                         return ksm_page;
                 if (!ksm_page)
@@ -565,7 +572,7 @@ static void break_cow(struct ksm_rmap_item *rmap_item)
         mmap_read_lock(mm);
         vma = find_mergeable_vma(mm, addr);
         if (vma)
-                break_ksm(vma, addr);
+                break_ksm(vma, addr, false);
         mmap_read_unlock(mm);
 }
 
@@ -871,7 +878,7 @@ static void remove_trailing_rmap_items(struct ksm_rmap_item **rmap_list)
  * in cmp_and_merge_page on one of the rmap_items we would be removing.
  */
 static int unmerge_ksm_pages(struct vm_area_struct *vma,
-                             unsigned long start, unsigned long end)
+                             unsigned long start, unsigned long end, bool lock_vma)
 {
         unsigned long addr;
         int err = 0;
@@ -882,7 +889,7 @@ static int unmerge_ksm_pages(struct vm_area_struct *vma,
                 if (signal_pending(current))
                         err = -ERESTARTSYS;
                 else
-                        err = break_ksm(vma, addr);
+                        err = break_ksm(vma, addr, lock_vma);
         }
         return err;
 }
@@ -1029,7 +1036,7 @@ static int unmerge_and_remove_all_rmap_items(void)
                         if (!(vma->vm_flags & VM_MERGEABLE) || !vma->anon_vma)
                                 continue;
                         err = unmerge_ksm_pages(vma,
-                                                vma->vm_start, vma->vm_end);
+                                                vma->vm_start, vma->vm_end, false);
                         if (err)
                                 goto error;
                 }
@@ -2530,7 +2537,7 @@ static int __ksm_del_vma(struct vm_area_struct *vma)
                 return 0;
 
         if (vma->anon_vma) {
-                err = unmerge_ksm_pages(vma, vma->vm_start, vma->vm_end);
+                err = unmerge_ksm_pages(vma, vma->vm_start, vma->vm_end, true);
                 if (err)
                         return err;
         }
@@ -2668,7 +2675,7 @@ int ksm_madvise(struct vm_area_struct *vma, unsigned long start,
                 return 0;               /* just ignore the advice */
 
         if (vma->anon_vma) {
-                err = unmerge_ksm_pages(vma, start, end);
+                err = unmerge_ksm_pages(vma, start, end, true);
                 if (err)
                         return err;
         }
...
@@ -233,6 +233,7 @@ static int swapin_walk_pmd_entry(pmd_t *pmd, unsigned long start,
 static const struct mm_walk_ops swapin_walk_ops = {
         .pmd_entry = swapin_walk_pmd_entry,
+        .walk_lock = PGWALK_RDLOCK,
 };
 
 static void shmem_swapin_range(struct vm_area_struct *vma,
@@ -534,6 +535,7 @@ static int madvise_cold_or_pageout_pte_range(pmd_t *pmd,
 static const struct mm_walk_ops cold_walk_ops = {
         .pmd_entry = madvise_cold_or_pageout_pte_range,
+        .walk_lock = PGWALK_RDLOCK,
 };
 
 static void madvise_cold_page_range(struct mmu_gather *tlb,
@@ -757,6 +759,7 @@ static int madvise_free_pte_range(pmd_t *pmd, unsigned long addr,
 static const struct mm_walk_ops madvise_free_walk_ops = {
         .pmd_entry = madvise_free_pte_range,
+        .walk_lock = PGWALK_RDLOCK,
 };
 
 static int madvise_free_single_vma(struct vm_area_struct *vma,
...
@@ -6024,6 +6024,7 @@ static int mem_cgroup_count_precharge_pte_range(pmd_t *pmd,
 static const struct mm_walk_ops precharge_walk_ops = {
         .pmd_entry = mem_cgroup_count_precharge_pte_range,
+        .walk_lock = PGWALK_RDLOCK,
 };
 
 static unsigned long mem_cgroup_count_precharge(struct mm_struct *mm)
@@ -6303,6 +6304,7 @@ static int mem_cgroup_move_charge_pte_range(pmd_t *pmd,
 static const struct mm_walk_ops charge_walk_ops = {
         .pmd_entry = mem_cgroup_move_charge_pte_range,
+        .walk_lock = PGWALK_RDLOCK,
 };
 
 static void mem_cgroup_move_charge(void)
...
@@ -831,6 +831,7 @@ static int hwpoison_hugetlb_range(pte_t *ptep, unsigned long hmask,
 static const struct mm_walk_ops hwp_walk_ops = {
         .pmd_entry = hwpoison_pte_range,
         .hugetlb_entry = hwpoison_hugetlb_range,
+        .walk_lock = PGWALK_RDLOCK,
 };
 
 /*
...
@@ -718,6 +718,14 @@ static const struct mm_walk_ops queue_pages_walk_ops = {
         .hugetlb_entry = queue_folios_hugetlb,
         .pmd_entry = queue_folios_pte_range,
         .test_walk = queue_pages_test_walk,
+        .walk_lock = PGWALK_RDLOCK,
+};
+
+static const struct mm_walk_ops queue_pages_lock_vma_walk_ops = {
+        .hugetlb_entry = queue_folios_hugetlb,
+        .pmd_entry = queue_folios_pte_range,
+        .test_walk = queue_pages_test_walk,
+        .walk_lock = PGWALK_WRLOCK,
 };
 
 /*
@@ -738,7 +746,7 @@ static const struct mm_walk_ops queue_pages_walk_ops = {
 static int
 queue_pages_range(struct mm_struct *mm, unsigned long start, unsigned long end,
                 nodemask_t *nodes, unsigned long flags,
-                struct list_head *pagelist)
+                struct list_head *pagelist, bool lock_vma)
 {
         int err;
         struct queue_pages qp = {
@@ -749,8 +757,10 @@ queue_pages_range(struct mm_struct *mm, unsigned long start, unsigned long end,
                 .end = end,
                 .first = NULL,
         };
+        const struct mm_walk_ops *ops = lock_vma ?
+                        &queue_pages_lock_vma_walk_ops : &queue_pages_walk_ops;
 
-        err = walk_page_range(mm, start, end, &queue_pages_walk_ops, &qp);
+        err = walk_page_range(mm, start, end, ops, &qp);
 
         if (!qp.first)
                 /* whole range in hole */
@@ -1078,7 +1088,7 @@ static int migrate_to_node(struct mm_struct *mm, int source, int dest,
         vma = find_vma(mm, 0);
         VM_BUG_ON(!(flags & (MPOL_MF_MOVE | MPOL_MF_MOVE_ALL)));
         queue_pages_range(mm, vma->vm_start, mm->task_size, &nmask,
-                        flags | MPOL_MF_DISCONTIG_OK, &pagelist);
+                        flags | MPOL_MF_DISCONTIG_OK, &pagelist, false);
 
         if (!list_empty(&pagelist)) {
                 err = migrate_pages(&pagelist, alloc_migration_target, NULL,
@@ -1321,12 +1331,8 @@ static long do_mbind(unsigned long start, unsigned long len,
          * Lock the VMAs before scanning for pages to migrate, to ensure we don't
          * miss a concurrently inserted page.
          */
-        vma_iter_init(&vmi, mm, start);
-        for_each_vma_range(vmi, vma, end)
-                vma_start_write(vma);
 
         ret = queue_pages_range(mm, start, end, nmask,
-                          flags | MPOL_MF_INVERT, &pagelist);
+                          flags | MPOL_MF_INVERT, &pagelist, true);
 
         if (ret < 0) {
                 err = ret;
...
@@ -279,6 +279,7 @@ static int migrate_vma_collect_pmd(pmd_t *pmdp,
 static const struct mm_walk_ops migrate_vma_walk_ops = {
         .pmd_entry = migrate_vma_collect_pmd,
         .pte_hole = migrate_vma_collect_hole,
+        .walk_lock = PGWALK_RDLOCK,
 };
 
 /*
...
@@ -176,6 +176,7 @@ static const struct mm_walk_ops mincore_walk_ops = {
         .pmd_entry = mincore_pte_range,
         .pte_hole = mincore_unmapped_range,
         .hugetlb_entry = mincore_hugetlb,
+        .walk_lock = PGWALK_RDLOCK,
 };
 
 /*
...
@@ -371,6 +371,7 @@ static void mlock_vma_pages_range(struct vm_area_struct *vma,
 {
         static const struct mm_walk_ops mlock_walk_ops = {
                 .pmd_entry = mlock_pte_range,
+                .walk_lock = PGWALK_WRLOCK_VERIFY,
         };
 
         /*
...
@@ -568,6 +568,7 @@ static const struct mm_walk_ops prot_none_walk_ops = {
         .pte_entry = prot_none_pte_entry,
         .hugetlb_entry = prot_none_hugetlb_entry,
         .test_walk = prot_none_test,
+        .walk_lock = PGWALK_WRLOCK,
 };
 
 int
...
@@ -400,6 +400,33 @@ static int __walk_page_range(unsigned long start, unsigned long end,
         return err;
 }
 
+static inline void process_mm_walk_lock(struct mm_struct *mm,
+                                        enum page_walk_lock walk_lock)
+{
+        if (walk_lock == PGWALK_RDLOCK)
+                mmap_assert_locked(mm);
+        else
+                mmap_assert_write_locked(mm);
+}
+
+static inline void process_vma_walk_lock(struct vm_area_struct *vma,
+                                         enum page_walk_lock walk_lock)
+{
+#ifdef CONFIG_PER_VMA_LOCK
+        switch (walk_lock) {
+        case PGWALK_WRLOCK:
+                vma_start_write(vma);
+                break;
+        case PGWALK_WRLOCK_VERIFY:
+                vma_assert_write_locked(vma);
+                break;
+        case PGWALK_RDLOCK:
+                /* PGWALK_RDLOCK is handled by process_mm_walk_lock */
+                break;
+        }
+#endif
+}
+
 /**
  * walk_page_range - walk page table with caller specific callbacks
  * @mm: mm_struct representing the target process of page table walk
@@ -459,7 +486,7 @@ int walk_page_range(struct mm_struct *mm, unsigned long start,
         if (!walk.mm)
                 return -EINVAL;
 
-        mmap_assert_locked(walk.mm);
+        process_mm_walk_lock(walk.mm, ops->walk_lock);
 
         vma = find_vma(walk.mm, start);
         do {
@@ -474,6 +501,7 @@ int walk_page_range(struct mm_struct *mm, unsigned long start,
                         if (ops->pte_hole)
                                 err = ops->pte_hole(start, next, -1, &walk);
                 } else { /* inside vma */
+                        process_vma_walk_lock(vma, ops->walk_lock);
                         walk.vma = vma;
                         next = min(end, vma->vm_end);
                         vma = find_vma(mm, vma->vm_end);
@@ -549,7 +577,8 @@ int walk_page_range_vma(struct vm_area_struct *vma, unsigned long start,
         if (start < vma->vm_start || end > vma->vm_end)
                 return -EINVAL;
 
-        mmap_assert_locked(walk.mm);
+        process_mm_walk_lock(walk.mm, ops->walk_lock);
+        process_vma_walk_lock(vma, ops->walk_lock);
         return __walk_page_range(start, end, &walk);
 }
 
@@ -566,7 +595,8 @@ int walk_page_vma(struct vm_area_struct *vma, const struct mm_walk_ops *ops,
         if (!walk.mm)
                 return -EINVAL;
 
-        mmap_assert_locked(walk.mm);
+        process_mm_walk_lock(walk.mm, ops->walk_lock);
+        process_vma_walk_lock(vma, ops->walk_lock);
         return __walk_page_range(vma->vm_start, vma->vm_end, &walk);
 }
...
@@ -4284,6 +4284,7 @@ static void walk_mm(struct lruvec *lruvec, struct mm_struct *mm, struct lru_gen_
         static const struct mm_walk_ops mm_walk_ops = {
                 .test_walk = should_skip_vma,
                 .p4d_entry = walk_pud_range,
+                .walk_lock = PGWALK_RDLOCK,
         };
 
         int err;
...