Commit a3e0d41c authored by Jérôme Glisse, committed by Linus Torvalds

mm/hmm: improve driver API to work and wait over a range

A common use case for HMM mirror is a user trying to mirror a range, only
to have it invalidated by some core mm event before they can program the
hardware.  Instead of having the user retry right away to mirror the range,
provide a completion mechanism for them to wait for any active
invalidation affecting the range.

This also changes how hmm_range_snapshot() and hmm_range_fault() work by
not relying on the vma, so that we can drop the mmap_sem when waiting and
look up the vma again on retry.

Link: http://lkml.kernel.org/r/20190403193318.16478-7-jglisse@redhat.com
Signed-off-by: Jérôme Glisse <jglisse@redhat.com>
Reviewed-by: Ralph Campbell <rcampbell@nvidia.com>
Cc: John Hubbard <jhubbard@nvidia.com>
Cc: Dan Williams <dan.j.williams@intel.com>
Cc: Dan Carpenter <dan.carpenter@oracle.com>
Cc: Matthew Wilcox <willy@infradead.org>
Cc: Arnd Bergmann <arnd@arndb.de>
Cc: Balbir Singh <bsingharora@gmail.com>
Cc: Ira Weiny <ira.weiny@intel.com>
Cc: Souptick Joarder <jrdr.linux@gmail.com>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
Signed-off-by: Linus Torvalds <torvalds@linux-foundation.org>
parent 73231612
...@@ -217,17 +217,33 @@ respect in order to keep things properly synchronized. The usage pattern is:: ...@@ -217,17 +217,33 @@ respect in order to keep things properly synchronized. The usage pattern is::
range.flags = ...; range.flags = ...;
range.values = ...; range.values = ...;
range.pfn_shift = ...; range.pfn_shift = ...;
hmm_range_register(&range);
/*
* Just wait for range to be valid, safe to ignore return value as we
* will use the return value of hmm_range_snapshot() below under the
* mmap_sem to ascertain the validity of the range.
*/
hmm_range_wait_until_valid(&range, TIMEOUT_IN_MSEC);
again: again:
down_read(&mm->mmap_sem); down_read(&mm->mmap_sem);
range.vma = ...;
ret = hmm_range_snapshot(&range); ret = hmm_range_snapshot(&range);
if (ret) { if (ret) {
up_read(&mm->mmap_sem); up_read(&mm->mmap_sem);
if (ret == -EAGAIN) {
/*
* No need to check hmm_range_wait_until_valid() return value
* on retry we will get proper error with hmm_range_snapshot()
*/
hmm_range_wait_until_valid(&range, TIMEOUT_IN_MSEC);
goto again;
}
hmm_range_unregister(&range);
return ret; return ret;
} }
take_lock(driver->update); take_lock(driver->update);
if (!hmm_vma_range_done(vma, &range)) { if (!range.valid) {
release_lock(driver->update); release_lock(driver->update);
up_read(&mm->mmap_sem); up_read(&mm->mmap_sem);
goto again; goto again;
...@@ -235,14 +251,15 @@ respect in order to keep things properly synchronized. The usage pattern is:: ...@@ -235,14 +251,15 @@ respect in order to keep things properly synchronized. The usage pattern is::
// Use pfns array content to update device page table // Use pfns array content to update device page table
hmm_range_unregister(&range);
release_lock(driver->update); release_lock(driver->update);
up_read(&mm->mmap_sem); up_read(&mm->mmap_sem);
return 0; return 0;
} }
The driver->update lock is the same lock that the driver takes inside its The driver->update lock is the same lock that the driver takes inside its
update() callback. That lock must be held before hmm_vma_range_done() to avoid update() callback. That lock must be held before checking the range.valid
any race with a concurrent CPU page table update. field to avoid any race with a concurrent CPU page table update.
HMM implements all this on top of the mmu_notifier API because we wanted a HMM implements all this on top of the mmu_notifier API because we wanted a
simpler API and also to be able to perform optimizations later on like doing simpler API and also to be able to perform optimizations later on like doing
......
...@@ -77,8 +77,34 @@ ...@@ -77,8 +77,34 @@
#include <linux/migrate.h> #include <linux/migrate.h>
#include <linux/memremap.h> #include <linux/memremap.h>
#include <linux/completion.h> #include <linux/completion.h>
#include <linux/mmu_notifier.h>
struct hmm;
/*
* struct hmm - HMM per mm struct
*
* @mm: mm struct this HMM struct is bound to
* @kref: reference count embedded in this struct (presumably governs the
*        lifetime of the struct hmm; get/put sites are not visible here —
*        TODO confirm)
* @lock: lock protecting ranges list
* @ranges: list of range being snapshotted
* @mirrors: list of mirrors for this mm
* @mmu_notifier: mmu notifier to track updates to CPU page table
* @mirrors_sem: read/write semaphore protecting the mirrors list
* @wq: wait queue for user waiting on a range invalidation
* @notifiers: count of active mmu notifiers
* @dead: is the mm dead ?
*/
struct hmm {
struct mm_struct *mm;
struct kref kref;
struct mutex lock;
struct list_head ranges;
struct list_head mirrors;
struct mmu_notifier mmu_notifier;
struct rw_semaphore mirrors_sem;
wait_queue_head_t wq;
long notifiers;
bool dead;
};
/* /*
* hmm_pfn_flag_e - HMM flag enums * hmm_pfn_flag_e - HMM flag enums
...@@ -155,6 +181,38 @@ struct hmm_range { ...@@ -155,6 +181,38 @@ struct hmm_range {
bool valid; bool valid;
}; };
/*
 * hmm_range_wait_until_valid() - block until a range becomes valid again
 * @range: range whose pending invalidation we wait on
 * @timeout: upper bound on the wait, in milliseconds
 * Returns: true if the range is valid on return, false otherwise.
 *
 * Bails out immediately, marking the range invalid, when the range has no
 * hmm attached, or when that hmm is flagged dead or has no mm.
 */
static inline bool hmm_range_wait_until_valid(struct hmm_range *range,
					      unsigned long timeout)
{
	/* Nothing to wait for if the backing hmm/mm is gone. */
	if (!range->hmm || range->hmm->dead || !range->hmm->mm) {
		range->valid = false;
		return false;
	}

	/* Fast path: already valid, no need to sleep. */
	if (range->valid)
		return true;

	/* Sleep until the range is valid, the hmm dies, or we time out. */
	wait_event_timeout(range->hmm->wq,
			   range->valid || range->hmm->dead,
			   msecs_to_jiffies(timeout));

	/* The wait result is ignored on purpose; report the current state. */
	return range->valid;
}
/*
 * hmm_range_valid() - report whether a range is currently valid
 * @range: range to query
 * Returns: the range's valid flag (true when valid, false otherwise).
 */
static inline bool hmm_range_valid(struct hmm_range *range)
{
	const bool is_valid = range->valid;

	return is_valid;
}
/* /*
* hmm_pfn_to_page() - return struct page pointed to by a valid HMM pfn * hmm_pfn_to_page() - return struct page pointed to by a valid HMM pfn
* @range: range use to decode HMM pfn value * @range: range use to decode HMM pfn value
...@@ -357,51 +415,66 @@ void hmm_mirror_unregister(struct hmm_mirror *mirror); ...@@ -357,51 +415,66 @@ void hmm_mirror_unregister(struct hmm_mirror *mirror);
/* /*
* To snapshot the CPU page table, call hmm_vma_get_pfns(), then take a device * Please see Documentation/vm/hmm.rst for how to use the range API.
* driver lock that serializes device page table updates, then call
* hmm_vma_range_done(), to check if the snapshot is still valid. The same
* device driver page table update lock must also be used in the
* hmm_mirror_ops.sync_cpu_device_pagetables() callback, so that CPU page
* table invalidation serializes on it.
*
* YOU MUST CALL hmm_vma_range_done() ONCE AND ONLY ONCE EACH TIME YOU CALL
* hmm_range_snapshot() WITHOUT ERROR !
*
* IF YOU DO NOT FOLLOW THE ABOVE RULE THE SNAPSHOT CONTENT MIGHT BE INVALID !
*/ */
int hmm_range_register(struct hmm_range *range,
struct mm_struct *mm,
unsigned long start,
unsigned long end);
void hmm_range_unregister(struct hmm_range *range);
long hmm_range_snapshot(struct hmm_range *range); long hmm_range_snapshot(struct hmm_range *range);
bool hmm_vma_range_done(struct hmm_range *range); long hmm_range_fault(struct hmm_range *range, bool block);
/* /*
* Fault memory on behalf of device driver. Unlike handle_mm_fault(), this will * HMM_RANGE_DEFAULT_TIMEOUT - default timeout (ms) when waiting for a range
* not migrate any device memory back to system memory. The HMM pfn array will
* be updated with the fault result and current snapshot of the CPU page table
* for the range.
*
* The mmap_sem must be taken in read mode before entering and it might be
* dropped by the function if the block argument is false. In that case, the
* function returns -EAGAIN.
*
* Return value does not reflect if the fault was successful for every single
* address or not. Therefore, the caller must to inspect the HMM pfn array to
* determine fault status for each address.
*
* Trying to fault inside an invalid vma will result in -EINVAL.
* *
* See the function description in mm/hmm.c for further documentation. * When waiting for mmu notifiers we need some kind of time out otherwise we
* could potentially wait forever, 1000ms ie 1s sounds like a long time to
* wait already.
*/ */
long hmm_range_fault(struct hmm_range *range, bool block); #define HMM_RANGE_DEFAULT_TIMEOUT 1000
/*
 * Temporary compatibility helper, kept to avoid merge conflicts between
 * trees: samples the range validity, unregisters the range, and returns
 * the validity that was sampled before unregistering.
 */
static inline bool hmm_vma_range_done(struct hmm_range *range)
{
	const bool was_valid = hmm_range_valid(range);

	hmm_range_unregister(range);

	return was_valid;
}
/* This is a temporary helper to avoid merge conflict between trees. */ /* This is a temporary helper to avoid merge conflict between trees. */
static inline int hmm_vma_fault(struct hmm_range *range, bool block) static inline int hmm_vma_fault(struct hmm_range *range, bool block)
{ {
long ret = hmm_range_fault(range, block); long ret;
if (ret == -EBUSY)
ret = -EAGAIN; ret = hmm_range_register(range, range->vma->vm_mm,
else if (ret == -EAGAIN) range->start, range->end);
ret = -EBUSY; if (ret)
return ret < 0 ? ret : 0; return (int)ret;
if (!hmm_range_wait_until_valid(range, HMM_RANGE_DEFAULT_TIMEOUT)) {
/*
* The mmap_sem was taken by the driver; we release it here and
* return -EAGAIN, which corresponds to the mmap_sem having been
* dropped in the old API.
*/
up_read(&range->vma->vm_mm->mmap_sem);
return -EAGAIN;
}
ret = hmm_range_fault(range, block);
if (ret <= 0) {
if (ret == -EBUSY || !ret) {
/* Same as above drop mmap_sem to match old API. */
up_read(&range->vma->vm_mm->mmap_sem);
ret = -EBUSY;
} else if (ret == -EAGAIN)
ret = -EBUSY;
hmm_range_unregister(range);
return ret;
}
return 0;
} }
/* Below are for HMM internal use only! Not to be used by device driver! */ /* Below are for HMM internal use only! Not to be used by device driver! */
......
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment