@@ -1926,10 +1926,9 @@ svm_range_update_notifier_and_interval_tree(struct mm_struct *mm,
}
static void
-svm_range_handle_list_op(struct svm_range_list *svms, struct svm_range *prange)
+svm_range_handle_list_op(struct svm_range_list *svms, struct svm_range *prange,
+ struct mm_struct *mm)
{
- struct mm_struct *mm = prange->work_item.mm;
-
switch (prange->work_item.op) {
case SVM_OP_NULL:
pr_debug("NULL OP 0x%p prange 0x%p [0x%lx 0x%lx]\n",
@@ -2013,34 +2012,41 @@ static void svm_range_deferred_list_work(struct work_struct *work)
pr_debug("enter svms 0x%p\n", svms);
p = container_of(svms, struct kfd_process, svms);
- /* Avoid mm is gone when inserting mmu notifier */
- mm = get_task_mm(p->lead_thread);
- if (!mm) {
- pr_debug("svms 0x%p process mm gone\n", svms);
- return;
- }
-retry:
- mmap_write_lock(mm);
-
- /* Checking for the need to drain retry faults must be inside
- * mmap write lock to serialize with munmap notifiers.
- */
- if (unlikely(atomic_read(&svms->drain_pagefaults))) {
- mmap_write_unlock(mm);
- svm_range_drain_retry_fault(svms);
- goto retry;
- }
spin_lock(&svms->deferred_list_lock);
while (!list_empty(&svms->deferred_range_list)) {
prange = list_first_entry(&svms->deferred_range_list,
struct svm_range, deferred_list);
- list_del_init(&prange->deferred_list);
spin_unlock(&svms->deferred_list_lock);
pr_debug("prange 0x%p [0x%lx 0x%lx] op %d\n", prange,
prange->start, prange->last, prange->work_item.op);
+ mm = prange->work_item.mm;
+retry:
+ mmap_write_lock(mm);
+
+ /* Checking for the need to drain retry faults must be inside
+ * mmap write lock to serialize with munmap notifiers.
+ */
+ if (unlikely(atomic_read(&svms->drain_pagefaults))) {
+ mmap_write_unlock(mm);
+ svm_range_drain_retry_fault(svms);
+ goto retry;
+ }
+
+ /* Remove from deferred_list must be inside mmap write lock, for
+ * two race cases:
+ * 1. unmap_from_cpu may change work_item.op and add the range
+ * to deferred_list again, cause use after free bug.
+ * 2. svm_range_list_lock_and_flush_work may hold mmap write
+ * lock and continue because deferred_list is empty, but
+ * deferred_list work is actually waiting for mmap lock.
+ */
+ spin_lock(&svms->deferred_list_lock);
+ list_del_init(&prange->deferred_list);
+ spin_unlock(&svms->deferred_list_lock);
+
mutex_lock(&svms->lock);
mutex_lock(&prange->migrate_mutex);
while (!list_empty(&prange->child_list)) {
@@ -2051,19 +2057,20 @@ static void svm_range_deferred_list_work(struct work_struct *work)
pr_debug("child prange 0x%p op %d\n", pchild,
pchild->work_item.op);
list_del_init(&pchild->child_list);
- svm_range_handle_list_op(svms, pchild);
+ svm_range_handle_list_op(svms, pchild, mm);
}
mutex_unlock(&prange->migrate_mutex);
- svm_range_handle_list_op(svms, prange);
+ svm_range_handle_list_op(svms, prange, mm);
mutex_unlock(&svms->lock);
+ mmap_write_unlock(mm);
+
+ /* Pairs with mmget in svm_range_add_list_work */
+ mmput(mm);
spin_lock(&svms->deferred_list_lock);
}
spin_unlock(&svms->deferred_list_lock);
-
- mmap_write_unlock(mm);
- mmput(mm);
pr_debug("exit svms 0x%p\n", svms);
}
@@ -2081,6 +2088,9 @@ svm_range_add_list_work(struct svm_range_list *svms, struct svm_range *prange,
prange->work_item.op = op;
} else {
prange->work_item.op = op;
+
+ /* Pairs with mmput in deferred_list_work */
+ mmget(mm);
prange->work_item.mm = mm;
list_add_tail(&prange->deferred_list,
&prange->svms->deferred_range_list);