Commit 81ddb41f authored by Jason Gunthorpe's avatar Jason Gunthorpe

RDMA/cm: Allow ib_send_cm_rej() to be done under lock

The first thing ib_send_cm_rej() does is obtain the lock, so use the usual
unlocked wrapper, locked actor pattern here.

This avoids a sketchy lock/unlock sequence (which could allow state to
change) during cm_destroy_id().

While here simplify some of the logic in the implementation.

Link: https://lore.kernel.org/r/20200310092545.251365-14-leon@kernel.orgSigned-off-by: default avatarLeon Romanovsky <leonro@mellanox.com>
Signed-off-by: default avatarJason Gunthorpe <jgg@mellanox.com>
parent 87cabf3e
...@@ -87,6 +87,10 @@ static int cm_send_dreq_locked(struct cm_id_private *cm_id_priv, ...@@ -87,6 +87,10 @@ static int cm_send_dreq_locked(struct cm_id_private *cm_id_priv,
const void *private_data, u8 private_data_len); const void *private_data, u8 private_data_len);
static int cm_send_drep_locked(struct cm_id_private *cm_id_priv, static int cm_send_drep_locked(struct cm_id_private *cm_id_priv,
void *private_data, u8 private_data_len); void *private_data, u8 private_data_len);
static int cm_send_rej_locked(struct cm_id_private *cm_id_priv,
enum ib_cm_rej_reason reason, void *ari,
u8 ari_length, const void *private_data,
u8 private_data_len);
static struct ib_client cm_client = { static struct ib_client cm_client = {
.name = "cm", .name = "cm",
...@@ -1060,11 +1064,11 @@ static void cm_destroy_id(struct ib_cm_id *cm_id, int err) ...@@ -1060,11 +1064,11 @@ static void cm_destroy_id(struct ib_cm_id *cm_id, int err)
case IB_CM_REQ_SENT: case IB_CM_REQ_SENT:
case IB_CM_MRA_REQ_RCVD: case IB_CM_MRA_REQ_RCVD:
ib_cancel_mad(cm_id_priv->av.port->mad_agent, cm_id_priv->msg); ib_cancel_mad(cm_id_priv->av.port->mad_agent, cm_id_priv->msg);
spin_unlock_irq(&cm_id_priv->lock); cm_send_rej_locked(cm_id_priv, IB_CM_REJ_TIMEOUT,
ib_send_cm_rej(cm_id, IB_CM_REJ_TIMEOUT,
&cm_id_priv->id.device->node_guid, &cm_id_priv->id.device->node_guid,
sizeof cm_id_priv->id.device->node_guid, sizeof(cm_id_priv->id.device->node_guid),
NULL, 0); NULL, 0);
spin_unlock_irq(&cm_id_priv->lock);
break; break;
case IB_CM_REQ_RCVD: case IB_CM_REQ_RCVD:
if (err == -ENOMEM) { if (err == -ENOMEM) {
...@@ -1072,9 +1076,10 @@ static void cm_destroy_id(struct ib_cm_id *cm_id, int err) ...@@ -1072,9 +1076,10 @@ static void cm_destroy_id(struct ib_cm_id *cm_id, int err)
cm_reset_to_idle(cm_id_priv); cm_reset_to_idle(cm_id_priv);
spin_unlock_irq(&cm_id_priv->lock); spin_unlock_irq(&cm_id_priv->lock);
} else { } else {
cm_send_rej_locked(cm_id_priv,
IB_CM_REJ_CONSUMER_DEFINED, NULL, 0,
NULL, 0);
spin_unlock_irq(&cm_id_priv->lock); spin_unlock_irq(&cm_id_priv->lock);
ib_send_cm_rej(cm_id, IB_CM_REJ_CONSUMER_DEFINED,
NULL, 0, NULL, 0);
} }
break; break;
case IB_CM_REP_SENT: case IB_CM_REP_SENT:
...@@ -1084,9 +1089,9 @@ static void cm_destroy_id(struct ib_cm_id *cm_id, int err) ...@@ -1084,9 +1089,9 @@ static void cm_destroy_id(struct ib_cm_id *cm_id, int err)
case IB_CM_MRA_REQ_SENT: case IB_CM_MRA_REQ_SENT:
case IB_CM_REP_RCVD: case IB_CM_REP_RCVD:
case IB_CM_MRA_REP_SENT: case IB_CM_MRA_REP_SENT:
cm_send_rej_locked(cm_id_priv, IB_CM_REJ_CONSUMER_DEFINED, NULL,
0, NULL, 0);
spin_unlock_irq(&cm_id_priv->lock); spin_unlock_irq(&cm_id_priv->lock);
ib_send_cm_rej(cm_id, IB_CM_REJ_CONSUMER_DEFINED,
NULL, 0, NULL, 0);
break; break;
case IB_CM_ESTABLISHED: case IB_CM_ESTABLISHED:
if (cm_id_priv->qp_type == IB_QPT_XRC_TGT) { if (cm_id_priv->qp_type == IB_QPT_XRC_TGT) {
...@@ -2899,65 +2904,72 @@ static int cm_drep_handler(struct cm_work *work) ...@@ -2899,65 +2904,72 @@ static int cm_drep_handler(struct cm_work *work)
return -EINVAL; return -EINVAL;
} }
int ib_send_cm_rej(struct ib_cm_id *cm_id, static int cm_send_rej_locked(struct cm_id_private *cm_id_priv,
enum ib_cm_rej_reason reason, enum ib_cm_rej_reason reason, void *ari,
void *ari, u8 ari_length, const void *private_data,
u8 ari_length,
const void *private_data,
u8 private_data_len) u8 private_data_len)
{ {
struct cm_id_private *cm_id_priv;
struct ib_mad_send_buf *msg; struct ib_mad_send_buf *msg;
unsigned long flags;
int ret; int ret;
lockdep_assert_held(&cm_id_priv->lock);
if ((private_data && private_data_len > IB_CM_REJ_PRIVATE_DATA_SIZE) || if ((private_data && private_data_len > IB_CM_REJ_PRIVATE_DATA_SIZE) ||
(ari && ari_length > IB_CM_REJ_ARI_LENGTH)) (ari && ari_length > IB_CM_REJ_ARI_LENGTH))
return -EINVAL; return -EINVAL;
cm_id_priv = container_of(cm_id, struct cm_id_private, id); switch (cm_id_priv->id.state) {
spin_lock_irqsave(&cm_id_priv->lock, flags);
switch (cm_id->state) {
case IB_CM_REQ_SENT: case IB_CM_REQ_SENT:
case IB_CM_MRA_REQ_RCVD: case IB_CM_MRA_REQ_RCVD:
case IB_CM_REQ_RCVD: case IB_CM_REQ_RCVD:
case IB_CM_MRA_REQ_SENT: case IB_CM_MRA_REQ_SENT:
case IB_CM_REP_RCVD: case IB_CM_REP_RCVD:
case IB_CM_MRA_REP_SENT: case IB_CM_MRA_REP_SENT:
ret = cm_alloc_msg(cm_id_priv, &msg);
if (!ret)
cm_format_rej((struct cm_rej_msg *) msg->mad,
cm_id_priv, reason, ari, ari_length,
private_data, private_data_len);
cm_reset_to_idle(cm_id_priv); cm_reset_to_idle(cm_id_priv);
ret = cm_alloc_msg(cm_id_priv, &msg);
if (ret)
return ret;
cm_format_rej((struct cm_rej_msg *)msg->mad, cm_id_priv, reason,
ari, ari_length, private_data, private_data_len);
break; break;
case IB_CM_REP_SENT: case IB_CM_REP_SENT:
case IB_CM_MRA_REP_RCVD: case IB_CM_MRA_REP_RCVD:
ret = cm_alloc_msg(cm_id_priv, &msg);
if (!ret)
cm_format_rej((struct cm_rej_msg *) msg->mad,
cm_id_priv, reason, ari, ari_length,
private_data, private_data_len);
cm_enter_timewait(cm_id_priv); cm_enter_timewait(cm_id_priv);
ret = cm_alloc_msg(cm_id_priv, &msg);
if (ret)
return ret;
cm_format_rej((struct cm_rej_msg *)msg->mad, cm_id_priv, reason,
ari, ari_length, private_data, private_data_len);
break; break;
default: default:
pr_debug("%s: local_id %d, cm_id->state: %d\n", __func__, pr_debug("%s: local_id %d, cm_id->state: %d\n", __func__,
be32_to_cpu(cm_id_priv->id.local_id), cm_id->state); be32_to_cpu(cm_id_priv->id.local_id),
ret = -EINVAL; cm_id_priv->id.state);
goto out; return -EINVAL;
} }
if (ret)
goto out;
ret = ib_post_send_mad(msg, NULL); ret = ib_post_send_mad(msg, NULL);
if (ret) if (ret) {
cm_free_msg(msg); cm_free_msg(msg);
return ret;
}
out: spin_unlock_irqrestore(&cm_id_priv->lock, flags); return 0;
}
int ib_send_cm_rej(struct ib_cm_id *cm_id, enum ib_cm_rej_reason reason,
void *ari, u8 ari_length, const void *private_data,
u8 private_data_len)
{
struct cm_id_private *cm_id_priv =
container_of(cm_id, struct cm_id_private, id);
unsigned long flags;
int ret;
spin_lock_irqsave(&cm_id_priv->lock, flags);
ret = cm_send_rej_locked(cm_id_priv, reason, ari, ari_length,
private_data, private_data_len);
spin_unlock_irqrestore(&cm_id_priv->lock, flags);
return ret; return ret;
} }
EXPORT_SYMBOL(ib_send_cm_rej); EXPORT_SYMBOL(ib_send_cm_rej);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment