Commit eaa43a06, authored by farah kassabri, committed by Oded Gabbay

accel/habanalabs: Allow single timestamp registration request at a time

Protect against concurrent user requests to register a timestamp
offset (where the driver fills the timestamp when the command submission
has finished executing) to a specific user interrupt ID. The
protection allows only one timestamp registration request to be
handled at a time.

This is needed because the user can decide to re-use a timestamp
offset (register an already registered offset to a different
interrupt ID). This means the request will cause the timestamp node to
move from one interrupt list to another interrupt list. In such a
scenario, without proper protection, we could end up adding the same
node twice to the interrupt wait lists.
Signed-off-by: farah kassabri <fkassabri@habana.ai>
Reviewed-by: Oded Gabbay <ogabbay@kernel.org>
Signed-off-by: Oded Gabbay <ogabbay@kernel.org>
parent 964b1f67
......@@ -31,6 +31,25 @@ enum hl_cs_wait_status {
CS_WAIT_STATUS_GONE
};
/**
 * struct wait_interrupt_data - data used while handling wait/timestamp nodes.
 * @interrupt: user interrupt whose wait list (and wait_list_lock) the
 *             wait/timestamp node is added to.
 * @buf: timestamp buffer looked up from @ts_handle via @mmg; only set when a
 *       timestamp record is being registered.
 * @mmg: memory manager used to look up both @cq_cb (from @cq_handle) and
 *       @buf (from @ts_handle).
 * @cq_cb: CQ counters CB looked up from @cq_handle.
 * @ts_handle: user handle of the timestamp buffer.
 * @ts_offset: offset of the requested record inside the timestamp buffer
 *             (used as an element index, not a byte offset).
 * @cq_handle: user handle of the CQ counters CB.
 * @cq_offset: offset of the counter inside the CQ counters CB, in u64 units.
 * @target_value: CQ counter value that marks completion; the counter is
 *                compared against it with '>='.
 * @intr_timeout_us: wait timeout in microseconds (converted to jiffies by the
 *                   wait ioctl); zero means "do not wait, report busy".
 * @flags: saved IRQ flags for spin_lock_irqsave()/spin_unlock_irqrestore()
 *         on the interrupt's wait_list_lock.
 *
 * The purpose of this struct is to store the needed data for both operations
 * (interrupt wait and timestamp registration) in one variable instead of
 * passing a large number of arguments to functions.
 */
struct wait_interrupt_data {
struct hl_user_interrupt *interrupt;
struct hl_mmap_mem_buf *buf;
struct hl_mem_mgr *mmg;
struct hl_cb *cq_cb;
u64 ts_handle;
u64 ts_offset;
u64 cq_handle;
u64 cq_offset;
u64 target_value;
u64 intr_timeout_us;
unsigned long flags;
};
static void job_wq_completion(struct work_struct *work);
static int _hl_cs_wait_ioctl(struct hl_device *hdev, struct hl_ctx *ctx, u64 timeout_us, u64 seq,
enum hl_cs_wait_status *status, s64 *timestamp);
......@@ -3197,133 +3216,181 @@ static int hl_cs_wait_ioctl(struct hl_fpriv *hpriv, void *data)
return 0;
}
static int ts_buff_get_kernel_ts_record(struct hl_mmap_mem_buf *buf,
struct hl_cb *cq_cb,
u64 ts_offset, u64 cq_offset, u64 target_value,
spinlock_t *wait_list_lock,
struct hl_user_pending_interrupt **pend)
static inline void set_record_cq_info(struct hl_user_pending_interrupt *record,
struct hl_cb *cq_cb, u32 cq_offset, u32 target_value)
{
struct hl_ts_buff *ts_buff = buf->private;
struct hl_user_pending_interrupt *requested_offset_record =
(struct hl_user_pending_interrupt *)ts_buff->kernel_buff_address +
record->ts_reg_info.cq_cb = cq_cb;
record->cq_kernel_addr = (u64 *) cq_cb->kernel_address + cq_offset;
record->cq_target_value = target_value;
}
static int validate_and_get_ts_record(struct device *dev,
struct hl_ts_buff *ts_buff, u64 ts_offset,
struct hl_user_pending_interrupt **req_event_record)
{
struct hl_user_pending_interrupt *ts_cb_last;
*req_event_record = (struct hl_user_pending_interrupt *)ts_buff->kernel_buff_address +
ts_offset;
struct hl_user_pending_interrupt *cb_last =
(struct hl_user_pending_interrupt *)ts_buff->kernel_buff_address +
ts_cb_last = (struct hl_user_pending_interrupt *)ts_buff->kernel_buff_address +
(ts_buff->kernel_buff_size / sizeof(struct hl_user_pending_interrupt));
unsigned long iter_counter = 0;
u64 current_cq_counter;
ktime_t timestamp;
/* Validate ts_offset not exceeding last max */
if (requested_offset_record >= cb_last) {
dev_err(buf->mmg->dev, "Ts offset exceeds max CB offset(0x%llx)\n",
(u64)(uintptr_t)cb_last);
if (*req_event_record >= ts_cb_last) {
dev_err(dev, "Ts offset(%llu) exceeds max CB offset(0x%llx)\n",
ts_offset, (u64)(uintptr_t)ts_cb_last);
return -EINVAL;
}
timestamp = ktime_get();
return 0;
}
start_over:
spin_lock(wait_list_lock);
static int unregister_timestamp_node(struct hl_device *hdev, struct hl_ctx *ctx,
struct hl_mem_mgr *mmg, u64 ts_handle, u64 ts_offset,
struct hl_user_interrupt *interrupt)
{
struct hl_user_pending_interrupt *req_event_record, *pend, *temp_pend;
struct hl_mmap_mem_buf *buff;
struct hl_ts_buff *ts_buff;
bool ts_rec_found = false;
int rc;
/* Unregister only if we didn't reach the target value
* since in this case there will be no handling in irq context
* and then it's safe to delete the node out of the interrupt list
* then re-use it on other interrupt
*/
if (requested_offset_record->ts_reg_info.in_use) {
current_cq_counter = *requested_offset_record->cq_kernel_addr;
if (current_cq_counter < requested_offset_record->cq_target_value) {
list_del(&requested_offset_record->wait_list_node);
spin_unlock(wait_list_lock);
buff = hl_mmap_mem_buf_get(mmg, ts_handle);
if (!buff) {
dev_err(hdev->dev, "invalid TS buff handle!\n");
return -EINVAL;
}
hl_mmap_mem_buf_put(requested_offset_record->ts_reg_info.buf);
hl_cb_put(requested_offset_record->ts_reg_info.cq_cb);
ts_buff = buff->private;
dev_dbg(buf->mmg->dev,
"ts node removed from interrupt list now can re-use\n");
} else {
dev_dbg(buf->mmg->dev,
"ts node in middle of irq handling\n");
rc = validate_and_get_ts_record(hdev->dev, ts_buff, ts_offset, &req_event_record);
if (rc)
goto put_buf;
/*
* Note: we don't use the ts in_use field here; instead we scan the list,
* because we cannot rely on the user to keep the order of register/unregister calls,
* and because there may always be races between the irq and the register/unregister
* calls. It is therefore safer to lock the list and scan it to find the node.
* If the node is found on the list, we mark it as not in use and delete it from the list.
* If it is not there, the node was already handled in the irq before we got into
* this ioctl.
*/
spin_lock(&interrupt->wait_list_lock);
/* irq thread handling in the middle give it time to finish */
spin_unlock(wait_list_lock);
usleep_range(100, 1000);
if (++iter_counter == MAX_TS_ITER_NUM) {
dev_err(buf->mmg->dev,
"Timestamp offset processing reached timeout of %lld ms\n",
ktime_ms_delta(ktime_get(), timestamp));
return -EAGAIN;
list_for_each_entry_safe(pend, temp_pend, &interrupt->wait_list_head, wait_list_node) {
if (pend == req_event_record) {
pend->ts_reg_info.in_use = false;
list_del(&pend->wait_list_node);
ts_rec_found = true;
break;
}
}
goto start_over;
spin_unlock(&interrupt->wait_list_lock);
/* Put refcounts that were taken when we registered the event */
if (ts_rec_found) {
hl_mmap_mem_buf_put(pend->ts_reg_info.buf);
hl_cb_put(pend->ts_reg_info.cq_cb);
}
} else {
/* Fill up the new registration node info */
requested_offset_record->ts_reg_info.buf = buf;
requested_offset_record->ts_reg_info.cq_cb = cq_cb;
requested_offset_record->ts_reg_info.timestamp_kernel_addr =
(u64 *) ts_buff->user_buff_address + ts_offset;
requested_offset_record->cq_kernel_addr =
(u64 *) cq_cb->kernel_address + cq_offset;
requested_offset_record->cq_target_value = target_value;
spin_unlock(wait_list_lock);
put_buf:
hl_mmap_mem_buf_put(buff);
return rc;
}
static int ts_get_and_handle_kernel_record(struct hl_device *hdev, struct hl_ctx *ctx,
struct wait_interrupt_data *data,
struct hl_user_pending_interrupt **pend)
{
struct hl_user_pending_interrupt *req_offset_record;
struct hl_ts_buff *ts_buff = data->buf->private;
int rc;
rc = validate_and_get_ts_record(data->buf->mmg->dev, ts_buff, data->ts_offset,
&req_offset_record);
if (rc)
return rc;
/* If the node is already registered, unregister it first and then re-use it */
if (req_offset_record->ts_reg_info.in_use) {
dev_dbg(data->buf->mmg->dev,
"Requested ts offset(%llx) is in use, unregister first\n",
data->ts_offset);
/*
* The interrupt here can be different from the one the node is currently registered
* on, and we don't want to hold two list locks while doing the unregister. So
* unlock the new interrupt's wait list here and re-acquire the lock once done.
*/
spin_unlock_irqrestore(&data->interrupt->wait_list_lock, data->flags);
unregister_timestamp_node(hdev, ctx, data->mmg, data->ts_handle,
data->ts_offset, req_offset_record->ts_reg_info.interrupt);
spin_lock_irqsave(&data->interrupt->wait_list_lock, data->flags);
}
*pend = requested_offset_record;
/* Fill up the new registration node info and add it to the list */
req_offset_record->ts_reg_info.in_use = true;
req_offset_record->ts_reg_info.buf = data->buf;
req_offset_record->ts_reg_info.timestamp_kernel_addr =
(u64 *) ts_buff->user_buff_address + data->ts_offset;
req_offset_record->ts_reg_info.interrupt = data->interrupt;
set_record_cq_info(req_offset_record, data->cq_cb, data->cq_offset,
data->target_value);
dev_dbg(buf->mmg->dev, "Found available node in TS kernel CB %p\n",
requested_offset_record);
return 0;
*pend = req_offset_record;
return rc;
}
static int _hl_interrupt_wait_ioctl(struct hl_device *hdev, struct hl_ctx *ctx,
struct hl_mem_mgr *cb_mmg, struct hl_mem_mgr *mmg,
u64 timeout_us, u64 cq_counters_handle, u64 cq_counters_offset,
u64 target_value, struct hl_user_interrupt *interrupt,
bool register_ts_record, u64 ts_handle, u64 ts_offset,
struct wait_interrupt_data *data,
bool register_ts_record,
u32 *status, u64 *timestamp)
{
struct hl_user_pending_interrupt *pend;
struct hl_mmap_mem_buf *buf;
struct hl_cb *cq_cb;
unsigned long timeout;
long completion_rc;
int rc = 0;
timeout = hl_usecs64_to_jiffies(timeout_us);
timeout = hl_usecs64_to_jiffies(data->intr_timeout_us);
hl_ctx_get(ctx);
cq_cb = hl_cb_get(cb_mmg, cq_counters_handle);
if (!cq_cb) {
data->cq_cb = hl_cb_get(data->mmg, data->cq_handle);
if (!data->cq_cb) {
rc = -EINVAL;
goto put_ctx;
}
/* Validate the cq offset */
if (((u64 *) cq_cb->kernel_address + cq_counters_offset) >=
((u64 *) cq_cb->kernel_address + (cq_cb->size / sizeof(u64)))) {
if (((u64 *) data->cq_cb->kernel_address + data->cq_offset) >=
((u64 *) data->cq_cb->kernel_address + (data->cq_cb->size / sizeof(u64)))) {
rc = -EINVAL;
goto put_cq_cb;
}
if (register_ts_record) {
dev_dbg(hdev->dev, "Timestamp registration: interrupt id: %u, ts offset: %llu, cq_offset: %llu\n",
interrupt->interrupt_id, ts_offset, cq_counters_offset);
buf = hl_mmap_mem_buf_get(mmg, ts_handle);
if (!buf) {
dev_dbg(hdev->dev, "Timestamp registration: interrupt id: %u, handle: 0x%llx, ts offset: %llu, cq_offset: %llu\n",
data->interrupt->interrupt_id, data->ts_handle,
data->ts_offset, data->cq_offset);
data->buf = hl_mmap_mem_buf_get(data->mmg, data->ts_handle);
if (!data->buf) {
rc = -EINVAL;
goto put_cq_cb;
}
spin_lock_irqsave(&data->interrupt->wait_list_lock, data->flags);
/* get ts buffer record */
rc = ts_buff_get_kernel_ts_record(buf, cq_cb, ts_offset,
cq_counters_offset, target_value,
&interrupt->wait_list_lock, &pend);
if (rc)
rc = ts_get_and_handle_kernel_record(hdev, ctx, data, &pend);
if (rc) {
spin_unlock_irqrestore(&data->interrupt->wait_list_lock, data->flags);
goto put_ts_buff;
}
} else {
pend = kzalloc(sizeof(*pend), GFP_KERNEL);
if (!pend) {
......@@ -3331,19 +3398,22 @@ static int _hl_interrupt_wait_ioctl(struct hl_device *hdev, struct hl_ctx *ctx,
goto put_cq_cb;
}
hl_fence_init(&pend->fence, ULONG_MAX);
pend->cq_kernel_addr = (u64 *) cq_cb->kernel_address + cq_counters_offset;
pend->cq_target_value = target_value;
pend->cq_kernel_addr = (u64 *) data->cq_cb->kernel_address + data->cq_offset;
pend->cq_target_value = data->target_value;
spin_lock_irqsave(&data->interrupt->wait_list_lock, data->flags);
}
spin_lock(&interrupt->wait_list_lock);
/* We check for completion value as interrupt could have been received
* before we added the node to the wait list
* before we add the wait/timestamp node to the wait list.
*/
if (*pend->cq_kernel_addr >= target_value) {
if (register_ts_record)
pend->ts_reg_info.in_use = 0;
spin_unlock(&interrupt->wait_list_lock);
if (*pend->cq_kernel_addr >= data->target_value) {
spin_unlock_irqrestore(&data->interrupt->wait_list_lock, data->flags);
if (register_ts_record) {
dev_dbg(hdev->dev, "Target value already reached release ts record: pend: %p, offset: %llu, interrupt: %u\n",
pend, data->ts_offset, data->interrupt->interrupt_id);
pend->ts_reg_info.in_use = false;
}
*status = HL_WAIT_CS_STATUS_COMPLETED;
......@@ -3354,8 +3424,8 @@ static int _hl_interrupt_wait_ioctl(struct hl_device *hdev, struct hl_ctx *ctx,
pend->fence.timestamp = ktime_get();
goto set_timestamp;
}
} else if (!timeout_us) {
spin_unlock(&interrupt->wait_list_lock);
} else if (!data->intr_timeout_us) {
spin_unlock_irqrestore(&data->interrupt->wait_list_lock, data->flags);
*status = HL_WAIT_CS_STATUS_BUSY;
pend->fence.timestamp = ktime_get();
goto set_timestamp;
......@@ -3366,21 +3436,9 @@ static int _hl_interrupt_wait_ioctl(struct hl_device *hdev, struct hl_ctx *ctx,
* Note that we cannot have sorted list by target value,
* in order to shorten the list pass loop, since
* same list could have nodes for different cq counter handle.
* Note:
* Mark ts buff offset as in use here in the spinlock protection area
* to avoid getting in the re-use section in ts_buff_get_kernel_ts_record
* before adding the node to the list. this scenario might happen when
* multiple threads are racing on same offset and one thread could
* set the ts buff in ts_buff_get_kernel_ts_record then the other thread
* takes over and get to ts_buff_get_kernel_ts_record and then we will try
* to re-use the same ts buff offset, and will try to delete a non existing
* node from the list.
*/
if (register_ts_record)
pend->ts_reg_info.in_use = 1;
list_add_tail(&pend->wait_list_node, &interrupt->wait_list_head);
spin_unlock(&interrupt->wait_list_lock);
*/
list_add_tail(&pend->wait_list_node, &data->interrupt->wait_list_head);
spin_unlock_irqrestore(&data->interrupt->wait_list_lock, data->flags);
if (register_ts_record) {
rc = *status = HL_WAIT_CS_STATUS_COMPLETED;
......@@ -3396,7 +3454,7 @@ static int _hl_interrupt_wait_ioctl(struct hl_device *hdev, struct hl_ctx *ctx,
if (completion_rc == -ERESTARTSYS) {
dev_err_ratelimited(hdev->dev,
"user process got signal while waiting for interrupt ID %d\n",
interrupt->interrupt_id);
data->interrupt->interrupt_id);
rc = -EINTR;
*status = HL_WAIT_CS_STATUS_ABORTED;
} else {
......@@ -3424,23 +3482,23 @@ static int _hl_interrupt_wait_ioctl(struct hl_device *hdev, struct hl_ctx *ctx,
* for ts record, the node will be deleted in the irq handler after
* we reach the target value.
*/
spin_lock(&interrupt->wait_list_lock);
spin_lock_irqsave(&data->interrupt->wait_list_lock, data->flags);
list_del(&pend->wait_list_node);
spin_unlock(&interrupt->wait_list_lock);
spin_unlock_irqrestore(&data->interrupt->wait_list_lock, data->flags);
set_timestamp:
*timestamp = ktime_to_ns(pend->fence.timestamp);
kfree(pend);
hl_cb_put(cq_cb);
hl_cb_put(data->cq_cb);
ts_registration_exit:
hl_ctx_put(ctx);
return rc;
put_ts_buff:
hl_mmap_mem_buf_put(buf);
hl_mmap_mem_buf_put(data->buf);
put_cq_cb:
hl_cb_put(cq_cb);
hl_cb_put(data->cq_cb);
put_ctx:
hl_ctx_put(ctx);
......@@ -3611,19 +3669,41 @@ static int hl_interrupt_wait_ioctl(struct hl_fpriv *hpriv, void *data)
return -EINVAL;
}
if (args->in.flags & HL_WAIT_CS_FLAGS_INTERRUPT_KERNEL_CQ)
rc = _hl_interrupt_wait_ioctl(hdev, hpriv->ctx, &hpriv->mem_mgr, &hpriv->mem_mgr,
args->in.interrupt_timeout_us, args->in.cq_counters_handle,
args->in.cq_counters_offset,
args->in.target, interrupt,
/*
* Allow only one registration at a time. This is needed in order to prevent issues
* while handling the flow of re-using the same offset.
* Since the registration flow is protected only by the interrupt lock, the re-use flow
* might request to move a ts node to another interrupt list, and in such a case we're
* not protected.
*/
if (args->in.flags & HL_WAIT_CS_FLAGS_REGISTER_INTERRUPT)
mutex_lock(&hpriv->ctx->ts_reg_lock);
if (args->in.flags & HL_WAIT_CS_FLAGS_INTERRUPT_KERNEL_CQ) {
struct wait_interrupt_data wait_intr_data = {0};
wait_intr_data.interrupt = interrupt;
wait_intr_data.mmg = &hpriv->mem_mgr;
wait_intr_data.cq_handle = args->in.cq_counters_handle;
wait_intr_data.cq_offset = args->in.cq_counters_offset;
wait_intr_data.ts_handle = args->in.timestamp_handle;
wait_intr_data.ts_offset = args->in.timestamp_offset;
wait_intr_data.target_value = args->in.target;
wait_intr_data.intr_timeout_us = args->in.interrupt_timeout_us;
rc = _hl_interrupt_wait_ioctl(hdev, hpriv->ctx, &wait_intr_data,
!!(args->in.flags & HL_WAIT_CS_FLAGS_REGISTER_INTERRUPT),
args->in.timestamp_handle, args->in.timestamp_offset,
&status, &timestamp);
else
} else {
rc = _hl_interrupt_wait_ioctl_user_addr(hdev, hpriv->ctx,
args->in.interrupt_timeout_us, args->in.addr,
args->in.target, interrupt, &status,
&timestamp);
}
if (args->in.flags & HL_WAIT_CS_FLAGS_REGISTER_INTERRUPT)
mutex_unlock(&hpriv->ctx->ts_reg_lock);
if (rc)
return rc;
......
......@@ -119,6 +119,7 @@ static void hl_ctx_fini(struct hl_ctx *ctx)
hl_vm_ctx_fini(ctx);
hl_asid_free(hdev, ctx->asid);
hl_encaps_sig_mgr_fini(hdev, &ctx->sig_mgr);
mutex_destroy(&ctx->ts_reg_lock);
} else {
dev_dbg(hdev->dev, "closing kernel context\n");
hdev->asic_funcs->ctx_fini(ctx);
......@@ -268,6 +269,8 @@ int hl_ctx_init(struct hl_device *hdev, struct hl_ctx *ctx, bool is_kernel_ctx)
hl_encaps_sig_mgr_init(&ctx->sig_mgr);
mutex_init(&ctx->ts_reg_lock);
dev_dbg(hdev->dev, "create user context, comm=\"%s\", asid=%u\n",
get_task_comm(task_comm, current), ctx->asid);
}
......
......@@ -1144,6 +1144,7 @@ struct timestamp_reg_work_obj {
* @buf: pointer to the timestamp buffer which include both user/kernel buffers.
* relevant only when doing timestamps records registration.
* @cq_cb: pointer to CQ counter CB.
* @interrupt: interrupt on whose wait list the node is hanging.
* @timestamp_kernel_addr: timestamp handle address, where to set timestamp
* relevant only when doing timestamps records
* registration.
......@@ -1155,8 +1156,9 @@ struct timestamp_reg_work_obj {
struct timestamp_reg_info {
struct hl_mmap_mem_buf *buf;
struct hl_cb *cq_cb;
struct hl_user_interrupt *interrupt;
u64 *timestamp_kernel_addr;
u8 in_use;
bool in_use;
};
/**
......@@ -1835,6 +1837,7 @@ struct hl_cs_outcome_store {
* @va_range: holds available virtual addresses for host and dram mappings.
* @mem_hash_lock: protects the mem_hash.
* @hw_block_list_lock: protects the HW block memory list.
* @ts_reg_lock: timestamp registration ioctls lock.
* @debugfs_list: node in debugfs list of contexts.
* @hw_block_mem_list: list of HW block virtual mapped addresses.
* @cs_counters: context command submission counters.
......@@ -1871,6 +1874,7 @@ struct hl_ctx {
struct hl_va_range *va_range[HL_VA_RANGE_TYPE_MAX];
struct mutex mem_hash_lock;
struct mutex hw_block_list_lock;
struct mutex ts_reg_lock;
struct list_head debugfs_list;
struct list_head hw_block_mem_list;
struct hl_cs_counters_atomic cs_counters;
......
......@@ -233,7 +233,8 @@ static void hl_ts_free_objects(struct work_struct *work)
* list to a dedicated workqueue to do the actual put.
*/
static int handle_registration_node(struct hl_device *hdev, struct hl_user_pending_interrupt *pend,
struct list_head **free_list, ktime_t now)
struct list_head **free_list, ktime_t now,
u32 interrupt_id)
{
struct timestamp_reg_free_node *free_node;
u64 timestamp;
......@@ -255,14 +256,12 @@ static int handle_registration_node(struct hl_device *hdev, struct hl_user_pendi
*pend->ts_reg_info.timestamp_kernel_addr = timestamp;
dev_dbg(hdev->dev, "Timestamp is set to ts cb address (%p), ts: 0x%llx\n",
pend->ts_reg_info.timestamp_kernel_addr,
*(u64 *)pend->ts_reg_info.timestamp_kernel_addr);
list_del(&pend->wait_list_node);
dev_dbg(hdev->dev, "Irq handle: Timestamp record (%p) ts cb address (%p), interrupt_id: %u\n",
pend, pend->ts_reg_info.timestamp_kernel_addr, interrupt_id);
/* Mark kernel CB node as free */
pend->ts_reg_info.in_use = 0;
pend->ts_reg_info.in_use = false;
list_del(&pend->wait_list_node);
/* Putting the refcount for ts_buff and cq_cb objects will be handled
* in workqueue context, just add job to free_list.
......@@ -296,13 +295,15 @@ static void handle_user_interrupt(struct hl_device *hdev, struct hl_user_interru
return;
spin_lock(&intr->wait_list_lock);
list_for_each_entry_safe(pend, temp_pend, &intr->wait_list_head, wait_list_node) {
if ((pend->cq_kernel_addr && *(pend->cq_kernel_addr) >= pend->cq_target_value) ||
!pend->cq_kernel_addr) {
if (pend->ts_reg_info.buf) {
if (!reg_node_handle_fail) {
rc = handle_registration_node(hdev, pend,
&ts_reg_free_list_head, intr->timestamp);
&ts_reg_free_list_head, intr->timestamp,
intr->interrupt_id);
if (rc)
reg_node_handle_fail = true;
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment