Commit c9990ab3 authored by Jason Gunthorpe, committed by Doug Ledford

RDMA/umem: Move all the ODP related stuff out of ucontext and into per_mm

This is the first step to make ODP use the owning_mm that is now part of
struct ib_umem.

Each ODP umem is linked to a single per_mm structure, which, in turn, is
linked to a single mm via the embedded mmu_notifier. This first patch
introduces the structure and reworks everything to use it.

This also needs to introduce tgid into the ib_ucontext_per_mm, as
get_user_pages_remote() requires the originating task for statistics
tracking.
Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
Signed-off-by: Leon Romanovsky <leonro@mellanox.com>
Signed-off-by: Doug Ledford <dledford@redhat.com>
parent 597ecc5a
...@@ -115,34 +115,35 @@ static void ib_umem_notifier_end_account(struct ib_umem_odp *umem_odp) ...@@ -115,34 +115,35 @@ static void ib_umem_notifier_end_account(struct ib_umem_odp *umem_odp)
} }
/* Account for a new mmu notifier in an ib_ucontext. */ /* Account for a new mmu notifier in an ib_ucontext. */
static void ib_ucontext_notifier_start_account(struct ib_ucontext *context) static void
ib_ucontext_notifier_start_account(struct ib_ucontext_per_mm *per_mm)
{ {
atomic_inc(&context->notifier_count); atomic_inc(&per_mm->notifier_count);
} }
/* Account for a terminating mmu notifier in an ib_ucontext. /* Account for a terminating mmu notifier in an ib_ucontext.
* *
* Must be called with the ib_ucontext->umem_rwsem semaphore unlocked, since * Must be called with the ib_ucontext->umem_rwsem semaphore unlocked, since
* the function takes the semaphore itself. */ * the function takes the semaphore itself. */
static void ib_ucontext_notifier_end_account(struct ib_ucontext *context) static void ib_ucontext_notifier_end_account(struct ib_ucontext_per_mm *per_mm)
{ {
int zero_notifiers = atomic_dec_and_test(&context->notifier_count); int zero_notifiers = atomic_dec_and_test(&per_mm->notifier_count);
if (zero_notifiers && if (zero_notifiers &&
!list_empty(&context->no_private_counters)) { !list_empty(&per_mm->no_private_counters)) {
/* No currently running mmu notifiers. Now is the chance to /* No currently running mmu notifiers. Now is the chance to
* add private accounting to all previously added umems. */ * add private accounting to all previously added umems. */
struct ib_umem_odp *odp_data, *next; struct ib_umem_odp *odp_data, *next;
/* Prevent concurrent mmu notifiers from working on the /* Prevent concurrent mmu notifiers from working on the
* no_private_counters list. */ * no_private_counters list. */
down_write(&context->umem_rwsem); down_write(&per_mm->umem_rwsem);
/* Read the notifier_count again, with the umem_rwsem /* Read the notifier_count again, with the umem_rwsem
* semaphore taken for write. */ * semaphore taken for write. */
if (!atomic_read(&context->notifier_count)) { if (!atomic_read(&per_mm->notifier_count)) {
list_for_each_entry_safe(odp_data, next, list_for_each_entry_safe(odp_data, next,
&context->no_private_counters, &per_mm->no_private_counters,
no_private_counters) { no_private_counters) {
mutex_lock(&odp_data->umem_mutex); mutex_lock(&odp_data->umem_mutex);
odp_data->mn_counters_active = true; odp_data->mn_counters_active = true;
...@@ -152,7 +153,7 @@ static void ib_ucontext_notifier_end_account(struct ib_ucontext *context) ...@@ -152,7 +153,7 @@ static void ib_ucontext_notifier_end_account(struct ib_ucontext *context)
} }
} }
up_write(&context->umem_rwsem); up_write(&per_mm->umem_rwsem);
} }
} }
...@@ -179,19 +180,20 @@ static int ib_umem_notifier_release_trampoline(struct ib_umem_odp *umem_odp, ...@@ -179,19 +180,20 @@ static int ib_umem_notifier_release_trampoline(struct ib_umem_odp *umem_odp,
static void ib_umem_notifier_release(struct mmu_notifier *mn, static void ib_umem_notifier_release(struct mmu_notifier *mn,
struct mm_struct *mm) struct mm_struct *mm)
{ {
struct ib_ucontext *context = container_of(mn, struct ib_ucontext, mn); struct ib_ucontext_per_mm *per_mm =
container_of(mn, struct ib_ucontext_per_mm, mn);
if (!context->invalidate_range) if (!per_mm->context->invalidate_range)
return; return;
ib_ucontext_notifier_start_account(context); ib_ucontext_notifier_start_account(per_mm);
down_read(&context->umem_rwsem); down_read(&per_mm->umem_rwsem);
rbt_ib_umem_for_each_in_range(&context->umem_tree, 0, rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, 0,
ULLONG_MAX, ULLONG_MAX,
ib_umem_notifier_release_trampoline, ib_umem_notifier_release_trampoline,
true, true,
NULL); NULL);
up_read(&context->umem_rwsem); up_read(&per_mm->umem_rwsem);
} }
static int invalidate_page_trampoline(struct ib_umem_odp *item, u64 start, static int invalidate_page_trampoline(struct ib_umem_odp *item, u64 start,
...@@ -217,23 +219,24 @@ static int ib_umem_notifier_invalidate_range_start(struct mmu_notifier *mn, ...@@ -217,23 +219,24 @@ static int ib_umem_notifier_invalidate_range_start(struct mmu_notifier *mn,
unsigned long end, unsigned long end,
bool blockable) bool blockable)
{ {
struct ib_ucontext *context = container_of(mn, struct ib_ucontext, mn); struct ib_ucontext_per_mm *per_mm =
container_of(mn, struct ib_ucontext_per_mm, mn);
int ret; int ret;
if (!context->invalidate_range) if (!per_mm->context->invalidate_range)
return 0; return 0;
if (blockable) if (blockable)
down_read(&context->umem_rwsem); down_read(&per_mm->umem_rwsem);
else if (!down_read_trylock(&context->umem_rwsem)) else if (!down_read_trylock(&per_mm->umem_rwsem))
return -EAGAIN; return -EAGAIN;
ib_ucontext_notifier_start_account(context); ib_ucontext_notifier_start_account(per_mm);
ret = rbt_ib_umem_for_each_in_range(&context->umem_tree, start, ret = rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, start,
end, end,
invalidate_range_start_trampoline, invalidate_range_start_trampoline,
blockable, NULL); blockable, NULL);
up_read(&context->umem_rwsem); up_read(&per_mm->umem_rwsem);
return ret; return ret;
} }
...@@ -250,9 +253,10 @@ static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn, ...@@ -250,9 +253,10 @@ static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn,
unsigned long start, unsigned long start,
unsigned long end) unsigned long end)
{ {
struct ib_ucontext *context = container_of(mn, struct ib_ucontext, mn); struct ib_ucontext_per_mm *per_mm =
container_of(mn, struct ib_ucontext_per_mm, mn);
if (!context->invalidate_range) if (!per_mm->context->invalidate_range)
return; return;
/* /*
...@@ -260,12 +264,12 @@ static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn, ...@@ -260,12 +264,12 @@ static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn,
* in ib_umem_notifier_invalidate_range_start so we shouldn't really block * in ib_umem_notifier_invalidate_range_start so we shouldn't really block
* here. But this is ugly and fragile. * here. But this is ugly and fragile.
*/ */
down_read(&context->umem_rwsem); down_read(&per_mm->umem_rwsem);
rbt_ib_umem_for_each_in_range(&context->umem_tree, start, rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, start,
end, end,
invalidate_range_end_trampoline, true, NULL); invalidate_range_end_trampoline, true, NULL);
up_read(&context->umem_rwsem); up_read(&per_mm->umem_rwsem);
ib_ucontext_notifier_end_account(context); ib_ucontext_notifier_end_account(per_mm);
} }
static const struct mmu_notifier_ops ib_umem_notifiers = { static const struct mmu_notifier_ops ib_umem_notifiers = {
...@@ -277,6 +281,7 @@ static const struct mmu_notifier_ops ib_umem_notifiers = { ...@@ -277,6 +281,7 @@ static const struct mmu_notifier_ops ib_umem_notifiers = {
struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context, struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
unsigned long addr, size_t size) unsigned long addr, size_t size)
{ {
struct ib_ucontext_per_mm *per_mm;
struct ib_umem_odp *odp_data; struct ib_umem_odp *odp_data;
struct ib_umem *umem; struct ib_umem *umem;
int pages = size >> PAGE_SHIFT; int pages = size >> PAGE_SHIFT;
...@@ -292,6 +297,7 @@ struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context, ...@@ -292,6 +297,7 @@ struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
umem->page_shift = PAGE_SHIFT; umem->page_shift = PAGE_SHIFT;
umem->writable = 1; umem->writable = 1;
umem->is_odp = 1; umem->is_odp = 1;
odp_data->per_mm = per_mm = &context->per_mm;
mutex_init(&odp_data->umem_mutex); mutex_init(&odp_data->umem_mutex);
init_completion(&odp_data->notifier_completion); init_completion(&odp_data->notifier_completion);
...@@ -310,15 +316,15 @@ struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context, ...@@ -310,15 +316,15 @@ struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
goto out_page_list; goto out_page_list;
} }
down_write(&context->umem_rwsem); down_write(&per_mm->umem_rwsem);
context->odp_mrs_count++; per_mm->odp_mrs_count++;
rbt_ib_umem_insert(&odp_data->interval_tree, &context->umem_tree); rbt_ib_umem_insert(&odp_data->interval_tree, &per_mm->umem_tree);
if (likely(!atomic_read(&context->notifier_count))) if (likely(!atomic_read(&per_mm->notifier_count)))
odp_data->mn_counters_active = true; odp_data->mn_counters_active = true;
else else
list_add(&odp_data->no_private_counters, list_add(&odp_data->no_private_counters,
&context->no_private_counters); &per_mm->no_private_counters);
up_write(&context->umem_rwsem); up_write(&per_mm->umem_rwsem);
return odp_data; return odp_data;
...@@ -334,6 +340,7 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access) ...@@ -334,6 +340,7 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
{ {
struct ib_ucontext *context = umem_odp->umem.context; struct ib_ucontext *context = umem_odp->umem.context;
struct ib_umem *umem = &umem_odp->umem; struct ib_umem *umem = &umem_odp->umem;
struct ib_ucontext_per_mm *per_mm;
int ret_val; int ret_val;
struct pid *our_pid; struct pid *our_pid;
struct mm_struct *mm = get_task_mm(current); struct mm_struct *mm = get_task_mm(current);
...@@ -396,28 +403,30 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access) ...@@ -396,28 +403,30 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
* notification before the "current" task (and MM) is * notification before the "current" task (and MM) is
* destroyed. We use the umem_rwsem semaphore to synchronize. * destroyed. We use the umem_rwsem semaphore to synchronize.
*/ */
down_write(&context->umem_rwsem); umem_odp->per_mm = per_mm = &context->per_mm;
context->odp_mrs_count++;
down_write(&per_mm->umem_rwsem);
per_mm->odp_mrs_count++;
if (likely(ib_umem_start(umem) != ib_umem_end(umem))) if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
rbt_ib_umem_insert(&umem_odp->interval_tree, rbt_ib_umem_insert(&umem_odp->interval_tree,
&context->umem_tree); &per_mm->umem_tree);
if (likely(!atomic_read(&context->notifier_count)) || if (likely(!atomic_read(&per_mm->notifier_count)) ||
context->odp_mrs_count == 1) per_mm->odp_mrs_count == 1)
umem_odp->mn_counters_active = true; umem_odp->mn_counters_active = true;
else else
list_add(&umem_odp->no_private_counters, list_add(&umem_odp->no_private_counters,
&context->no_private_counters); &per_mm->no_private_counters);
downgrade_write(&context->umem_rwsem); downgrade_write(&per_mm->umem_rwsem);
if (context->odp_mrs_count == 1) { if (per_mm->odp_mrs_count == 1) {
/* /*
* Note that at this point, no MMU notifier is running * Note that at this point, no MMU notifier is running
* for this context! * for this per_mm!
*/ */
atomic_set(&context->notifier_count, 0); atomic_set(&per_mm->notifier_count, 0);
INIT_HLIST_NODE(&context->mn.hlist); INIT_HLIST_NODE(&per_mm->mn.hlist);
context->mn.ops = &ib_umem_notifiers; per_mm->mn.ops = &ib_umem_notifiers;
ret_val = mmu_notifier_register(&context->mn, mm); ret_val = mmu_notifier_register(&per_mm->mn, mm);
if (ret_val) { if (ret_val) {
pr_err("Failed to register mmu_notifier %d\n", ret_val); pr_err("Failed to register mmu_notifier %d\n", ret_val);
ret_val = -EBUSY; ret_val = -EBUSY;
...@@ -425,7 +434,7 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access) ...@@ -425,7 +434,7 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
} }
} }
up_read(&context->umem_rwsem); up_read(&per_mm->umem_rwsem);
/* /*
* Note that doing an mmput can cause a notifier for the relevant mm. * Note that doing an mmput can cause a notifier for the relevant mm.
...@@ -437,7 +446,7 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access) ...@@ -437,7 +446,7 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
return 0; return 0;
out_mutex: out_mutex:
up_read(&context->umem_rwsem); up_read(&per_mm->umem_rwsem);
vfree(umem_odp->dma_list); vfree(umem_odp->dma_list);
out_page_list: out_page_list:
vfree(umem_odp->page_list); vfree(umem_odp->page_list);
...@@ -449,7 +458,7 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access) ...@@ -449,7 +458,7 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
void ib_umem_odp_release(struct ib_umem_odp *umem_odp) void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
{ {
struct ib_umem *umem = &umem_odp->umem; struct ib_umem *umem = &umem_odp->umem;
struct ib_ucontext *context = umem->context; struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
/* /*
* Ensure that no more pages are mapped in the umem. * Ensure that no more pages are mapped in the umem.
...@@ -460,11 +469,11 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp) ...@@ -460,11 +469,11 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem), ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem),
ib_umem_end(umem)); ib_umem_end(umem));
down_write(&context->umem_rwsem); down_write(&per_mm->umem_rwsem);
if (likely(ib_umem_start(umem) != ib_umem_end(umem))) if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
rbt_ib_umem_remove(&umem_odp->interval_tree, rbt_ib_umem_remove(&umem_odp->interval_tree,
&context->umem_tree); &per_mm->umem_tree);
context->odp_mrs_count--; per_mm->odp_mrs_count--;
if (!umem_odp->mn_counters_active) { if (!umem_odp->mn_counters_active) {
list_del(&umem_odp->no_private_counters); list_del(&umem_odp->no_private_counters);
complete_all(&umem_odp->notifier_completion); complete_all(&umem_odp->notifier_completion);
...@@ -477,13 +486,13 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp) ...@@ -477,13 +486,13 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
* that since we are doing it atomically, no other user could register * that since we are doing it atomically, no other user could register
* and unregister while we do the check. * and unregister while we do the check.
*/ */
downgrade_write(&context->umem_rwsem); downgrade_write(&per_mm->umem_rwsem);
if (!context->odp_mrs_count) { if (!per_mm->odp_mrs_count) {
struct task_struct *owning_process = NULL; struct task_struct *owning_process = NULL;
struct mm_struct *owning_mm = NULL; struct mm_struct *owning_mm = NULL;
owning_process = get_pid_task(context->tgid, owning_process =
PIDTYPE_PID); get_pid_task(umem_odp->umem.context->tgid, PIDTYPE_PID);
if (owning_process == NULL) if (owning_process == NULL)
/* /*
* The process is already dead, notifier were removed * The process is already dead, notifier were removed
...@@ -498,7 +507,7 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp) ...@@ -498,7 +507,7 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
* removed already. * removed already.
*/ */
goto out_put_task; goto out_put_task;
mmu_notifier_unregister(&context->mn, owning_mm); mmu_notifier_unregister(&per_mm->mn, owning_mm);
mmput(owning_mm); mmput(owning_mm);
...@@ -506,7 +515,7 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp) ...@@ -506,7 +515,7 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
put_task_struct(owning_process); put_task_struct(owning_process);
} }
out: out:
up_read(&context->umem_rwsem); up_read(&per_mm->umem_rwsem);
vfree(umem_odp->dma_list); vfree(umem_odp->dma_list);
vfree(umem_odp->page_list); vfree(umem_odp->page_list);
......
...@@ -124,10 +124,11 @@ ssize_t ib_uverbs_get_context(struct ib_uverbs_file *file, ...@@ -124,10 +124,11 @@ ssize_t ib_uverbs_get_context(struct ib_uverbs_file *file,
ucontext->cleanup_retryable = false; ucontext->cleanup_retryable = false;
#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING #ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
ucontext->umem_tree = RB_ROOT_CACHED; ucontext->per_mm.umem_tree = RB_ROOT_CACHED;
init_rwsem(&ucontext->umem_rwsem); init_rwsem(&ucontext->per_mm.umem_rwsem);
ucontext->odp_mrs_count = 0; ucontext->per_mm.odp_mrs_count = 0;
INIT_LIST_HEAD(&ucontext->no_private_counters); INIT_LIST_HEAD(&ucontext->per_mm.no_private_counters);
ucontext->per_mm.context = ucontext;
if (!(ib_dev->attrs.device_cap_flags & IB_DEVICE_ON_DEMAND_PAGING)) if (!(ib_dev->attrs.device_cap_flags & IB_DEVICE_ON_DEMAND_PAGING))
ucontext->invalidate_range = NULL; ucontext->invalidate_range = NULL;
......
...@@ -61,13 +61,21 @@ static int check_parent(struct ib_umem_odp *odp, ...@@ -61,13 +61,21 @@ static int check_parent(struct ib_umem_odp *odp,
return mr && mr->parent == parent && !odp->dying; return mr && mr->parent == parent && !odp->dying;
} }
struct ib_ucontext_per_mm *mr_to_per_mm(struct mlx5_ib_mr *mr)
{
if (WARN_ON(!mr || !mr->umem || !mr->umem->is_odp))
return NULL;
return to_ib_umem_odp(mr->umem)->per_mm;
}
static struct ib_umem_odp *odp_next(struct ib_umem_odp *odp) static struct ib_umem_odp *odp_next(struct ib_umem_odp *odp)
{ {
struct mlx5_ib_mr *mr = odp->private, *parent = mr->parent; struct mlx5_ib_mr *mr = odp->private, *parent = mr->parent;
struct ib_ucontext *ctx = odp->umem.context; struct ib_ucontext_per_mm *per_mm = odp->per_mm;
struct rb_node *rb; struct rb_node *rb;
down_read(&ctx->umem_rwsem); down_read(&per_mm->umem_rwsem);
while (1) { while (1) {
rb = rb_next(&odp->interval_tree.rb); rb = rb_next(&odp->interval_tree.rb);
if (!rb) if (!rb)
...@@ -79,19 +87,19 @@ static struct ib_umem_odp *odp_next(struct ib_umem_odp *odp) ...@@ -79,19 +87,19 @@ static struct ib_umem_odp *odp_next(struct ib_umem_odp *odp)
not_found: not_found:
odp = NULL; odp = NULL;
end: end:
up_read(&ctx->umem_rwsem); up_read(&per_mm->umem_rwsem);
return odp; return odp;
} }
static struct ib_umem_odp *odp_lookup(struct ib_ucontext *ctx, static struct ib_umem_odp *odp_lookup(u64 start, u64 length,
u64 start, u64 length,
struct mlx5_ib_mr *parent) struct mlx5_ib_mr *parent)
{ {
struct ib_ucontext_per_mm *per_mm = mr_to_per_mm(parent);
struct ib_umem_odp *odp; struct ib_umem_odp *odp;
struct rb_node *rb; struct rb_node *rb;
down_read(&ctx->umem_rwsem); down_read(&per_mm->umem_rwsem);
odp = rbt_ib_umem_lookup(&ctx->umem_tree, start, length); odp = rbt_ib_umem_lookup(&per_mm->umem_tree, start, length);
if (!odp) if (!odp)
goto end; goto end;
...@@ -108,7 +116,7 @@ static struct ib_umem_odp *odp_lookup(struct ib_ucontext *ctx, ...@@ -108,7 +116,7 @@ static struct ib_umem_odp *odp_lookup(struct ib_ucontext *ctx,
not_found: not_found:
odp = NULL; odp = NULL;
end: end:
up_read(&ctx->umem_rwsem); up_read(&per_mm->umem_rwsem);
return odp; return odp;
} }
...@@ -116,7 +124,6 @@ void mlx5_odp_populate_klm(struct mlx5_klm *pklm, size_t offset, ...@@ -116,7 +124,6 @@ void mlx5_odp_populate_klm(struct mlx5_klm *pklm, size_t offset,
size_t nentries, struct mlx5_ib_mr *mr, int flags) size_t nentries, struct mlx5_ib_mr *mr, int flags)
{ {
struct ib_pd *pd = mr->ibmr.pd; struct ib_pd *pd = mr->ibmr.pd;
struct ib_ucontext *ctx = pd->uobject->context;
struct mlx5_ib_dev *dev = to_mdev(pd->device); struct mlx5_ib_dev *dev = to_mdev(pd->device);
struct ib_umem_odp *odp; struct ib_umem_odp *odp;
unsigned long va; unsigned long va;
...@@ -131,7 +138,7 @@ void mlx5_odp_populate_klm(struct mlx5_klm *pklm, size_t offset, ...@@ -131,7 +138,7 @@ void mlx5_odp_populate_klm(struct mlx5_klm *pklm, size_t offset,
return; return;
} }
odp = odp_lookup(ctx, offset * MLX5_IMR_MTT_SIZE, odp = odp_lookup(offset * MLX5_IMR_MTT_SIZE,
nentries * MLX5_IMR_MTT_SIZE, mr); nentries * MLX5_IMR_MTT_SIZE, mr);
for (i = 0; i < nentries; i++, pklm++) { for (i = 0; i < nentries; i++, pklm++) {
...@@ -368,7 +375,6 @@ static struct mlx5_ib_mr *implicit_mr_alloc(struct ib_pd *pd, ...@@ -368,7 +375,6 @@ static struct mlx5_ib_mr *implicit_mr_alloc(struct ib_pd *pd,
static struct ib_umem_odp *implicit_mr_get_data(struct mlx5_ib_mr *mr, static struct ib_umem_odp *implicit_mr_get_data(struct mlx5_ib_mr *mr,
u64 io_virt, size_t bcnt) u64 io_virt, size_t bcnt)
{ {
struct ib_ucontext *ctx = mr->ibmr.pd->uobject->context;
struct mlx5_ib_dev *dev = to_mdev(mr->ibmr.pd->device); struct mlx5_ib_dev *dev = to_mdev(mr->ibmr.pd->device);
struct ib_umem_odp *odp, *result = NULL; struct ib_umem_odp *odp, *result = NULL;
struct ib_umem_odp *odp_mr = to_ib_umem_odp(mr->umem); struct ib_umem_odp *odp_mr = to_ib_umem_odp(mr->umem);
...@@ -377,7 +383,7 @@ static struct ib_umem_odp *implicit_mr_get_data(struct mlx5_ib_mr *mr, ...@@ -377,7 +383,7 @@ static struct ib_umem_odp *implicit_mr_get_data(struct mlx5_ib_mr *mr,
struct mlx5_ib_mr *mtt; struct mlx5_ib_mr *mtt;
mutex_lock(&odp_mr->umem_mutex); mutex_lock(&odp_mr->umem_mutex);
odp = odp_lookup(ctx, addr, 1, mr); odp = odp_lookup(addr, 1, mr);
mlx5_ib_dbg(dev, "io_virt:%llx bcnt:%zx addr:%llx odp:%p\n", mlx5_ib_dbg(dev, "io_virt:%llx bcnt:%zx addr:%llx odp:%p\n",
io_virt, bcnt, addr, odp); io_virt, bcnt, addr, odp);
...@@ -387,7 +393,8 @@ static struct ib_umem_odp *implicit_mr_get_data(struct mlx5_ib_mr *mr, ...@@ -387,7 +393,8 @@ static struct ib_umem_odp *implicit_mr_get_data(struct mlx5_ib_mr *mr,
if (nentries) if (nentries)
nentries++; nentries++;
} else { } else {
odp = ib_alloc_odp_umem(ctx, addr, MLX5_IMR_MTT_SIZE); odp = ib_alloc_odp_umem(odp_mr->umem.context, addr,
MLX5_IMR_MTT_SIZE);
if (IS_ERR(odp)) { if (IS_ERR(odp)) {
mutex_unlock(&odp_mr->umem_mutex); mutex_unlock(&odp_mr->umem_mutex);
return ERR_CAST(odp); return ERR_CAST(odp);
...@@ -486,12 +493,12 @@ static int mr_leaf_free(struct ib_umem_odp *umem_odp, u64 start, u64 end, ...@@ -486,12 +493,12 @@ static int mr_leaf_free(struct ib_umem_odp *umem_odp, u64 start, u64 end,
void mlx5_ib_free_implicit_mr(struct mlx5_ib_mr *imr) void mlx5_ib_free_implicit_mr(struct mlx5_ib_mr *imr)
{ {
struct ib_ucontext *ctx = imr->ibmr.pd->uobject->context; struct ib_ucontext_per_mm *per_mm = mr_to_per_mm(imr);
down_read(&ctx->umem_rwsem); down_read(&per_mm->umem_rwsem);
rbt_ib_umem_for_each_in_range(&ctx->umem_tree, 0, ULLONG_MAX, rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, 0, ULLONG_MAX,
mr_leaf_free, true, imr); mr_leaf_free, true, imr);
up_read(&ctx->umem_rwsem); up_read(&per_mm->umem_rwsem);
wait_event(imr->q_leaf_free, !atomic_read(&imr->num_leaf_free)); wait_event(imr->q_leaf_free, !atomic_read(&imr->num_leaf_free));
} }
......
...@@ -44,6 +44,8 @@ struct umem_odp_node { ...@@ -44,6 +44,8 @@ struct umem_odp_node {
struct ib_umem_odp { struct ib_umem_odp {
struct ib_umem umem; struct ib_umem umem;
struct ib_ucontext_per_mm *per_mm;
/* /*
* An array of the pages included in the on-demand paging umem. * An array of the pages included in the on-demand paging umem.
* Indices of pages that are currently not mapped into the device will * Indices of pages that are currently not mapped into the device will
......
...@@ -1488,6 +1488,25 @@ struct ib_rdmacg_object { ...@@ -1488,6 +1488,25 @@ struct ib_rdmacg_object {
#endif #endif
}; };
#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
/*
 * ODP (on-demand paging) state that was previously kept directly in
 * struct ib_ucontext. Each ODP umem points at one of these through
 * ib_umem_odp->per_mm, and the embedded mmu_notifier ties it to an mm.
 */
struct ib_ucontext_per_mm {
/* Back-pointer to the ucontext that embeds this structure. */
struct ib_ucontext *context;

/* Interval tree of the ODP umems belonging to this per_mm. */
struct rb_root_cached umem_tree;
/*
 * Protects umem_tree, as well as odp_mrs_count and
 * mmu notifiers registration.
 */
struct rw_semaphore umem_rwsem;

/* Notifier registered against the mm; callbacks recover per_mm from it. */
struct mmu_notifier mn;
/*
 * Incremented when an invalidation (or release) notifier starts and
 * decremented when the matching range-end notifier finishes.
 */
atomic_t notifier_count;
/* A list of umems that don't have private mmu notifier counters yet. */
struct list_head no_private_counters;
/*
 * Number of ODP umems currently attached. The mmu notifier is registered
 * when this becomes 1 and unregistered when it drops back to 0.
 */
unsigned int odp_mrs_count;
};
#endif
struct ib_ucontext { struct ib_ucontext {
struct ib_device *device; struct ib_device *device;
struct ib_uverbs_file *ufile; struct ib_uverbs_file *ufile;
...@@ -1502,20 +1521,9 @@ struct ib_ucontext { ...@@ -1502,20 +1521,9 @@ struct ib_ucontext {
struct pid *tgid; struct pid *tgid;
#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING #ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
struct rb_root_cached umem_tree;
/*
* Protects .umem_rbroot and tree, as well as odp_mrs_count and
* mmu notifiers registration.
*/
struct rw_semaphore umem_rwsem;
void (*invalidate_range)(struct ib_umem_odp *umem_odp, void (*invalidate_range)(struct ib_umem_odp *umem_odp,
unsigned long start, unsigned long end); unsigned long start, unsigned long end);
struct ib_ucontext_per_mm per_mm;
struct mmu_notifier mn;
atomic_t notifier_count;
/* A list of umems that don't have private mmu notifier counters yet. */
struct list_head no_private_counters;
int odp_mrs_count;
#endif #endif
struct ib_rdmacg_object cg_obj; struct ib_rdmacg_object cg_obj;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment