Commit 76039ac9 authored by Mark Zhang, committed by Jason Gunthorpe

IB/cm: Protect cm_dev, cm_ports and mad_agent with kref and lock

During cm_dev deregistration in cm_remove_one(), the cm_device and
cm_ports will be freed, after that they should not be accessed. The
mad_agent needs to be protected as well.

This patch adds a cm_device kref to protect cm_dev and cm_ports, and a
mad_agent_lock spinlock to protect mad_agent.

Link: https://lore.kernel.org/r/501ba7a2ff203dccd0e6755d3f93329772adce52.1622629024.git.leonro@nvidia.com
Signed-off-by: Mark Zhang <markzhang@nvidia.com>
Signed-off-by: Leon Romanovsky <leonro@nvidia.com>
Signed-off-by: Jason Gunthorpe <jgg@nvidia.com>
parent 7345201c
...@@ -205,7 +205,9 @@ struct cm_port { ...@@ -205,7 +205,9 @@ struct cm_port {
}; };
struct cm_device { struct cm_device {
struct kref kref;
struct list_head list; struct list_head list;
spinlock_t mad_agent_lock;
struct ib_device *ib_device; struct ib_device *ib_device;
u8 ack_delay; u8 ack_delay;
int going_down; int going_down;
...@@ -287,6 +289,22 @@ struct cm_id_private { ...@@ -287,6 +289,22 @@ struct cm_id_private {
struct rdma_ucm_ece ece; struct rdma_ucm_ece ece;
}; };
/*
 * kref release callback for a cm_device: runs once the last reference is
 * dropped.  Frees every per-port structure and then the device itself.
 * Ports are stored at index (port_num - 1); rdma_for_each_port() iterates
 * 1-based port numbers.
 */
static void cm_dev_release(struct kref *kref)
{
	struct cm_device *cm_dev = container_of(kref, struct cm_device, kref);
	u32 port_num;

	rdma_for_each_port(cm_dev->ib_device, port_num)
		kfree(cm_dev->port[port_num - 1]);

	kfree(cm_dev);
}
/*
 * Drop one reference on @cm_dev.  When the count reaches zero,
 * cm_dev_release() frees the ports and the device.  @cm_dev must not be
 * NULL; callers (e.g. cm_set_av_port()) check before calling.
 */
static void cm_device_put(struct cm_device *cm_dev)
{
	kref_put(&cm_dev->kref, cm_dev_release);
}
static void cm_work_handler(struct work_struct *work); static void cm_work_handler(struct work_struct *work);
static inline void cm_deref_id(struct cm_id_private *cm_id_priv) static inline void cm_deref_id(struct cm_id_private *cm_id_priv)
...@@ -301,10 +319,23 @@ static struct ib_mad_send_buf *cm_alloc_msg(struct cm_id_private *cm_id_priv) ...@@ -301,10 +319,23 @@ static struct ib_mad_send_buf *cm_alloc_msg(struct cm_id_private *cm_id_priv)
struct ib_mad_send_buf *m; struct ib_mad_send_buf *m;
struct ib_ah *ah; struct ib_ah *ah;
lockdep_assert_held(&cm_id_priv->lock);
if (!cm_id_priv->av.port)
return ERR_PTR(-EINVAL);
spin_lock(&cm_id_priv->av.port->cm_dev->mad_agent_lock);
mad_agent = cm_id_priv->av.port->mad_agent; mad_agent = cm_id_priv->av.port->mad_agent;
if (!mad_agent) {
m = ERR_PTR(-EINVAL);
goto out;
}
ah = rdma_create_ah(mad_agent->qp->pd, &cm_id_priv->av.ah_attr, 0); ah = rdma_create_ah(mad_agent->qp->pd, &cm_id_priv->av.ah_attr, 0);
if (IS_ERR(ah)) if (IS_ERR(ah)) {
return (void *)ah; m = ERR_CAST(ah);
goto out;
}
m = ib_create_send_mad(mad_agent, cm_id_priv->id.remote_cm_qpn, m = ib_create_send_mad(mad_agent, cm_id_priv->id.remote_cm_qpn,
cm_id_priv->av.pkey_index, cm_id_priv->av.pkey_index,
...@@ -313,7 +344,7 @@ static struct ib_mad_send_buf *cm_alloc_msg(struct cm_id_private *cm_id_priv) ...@@ -313,7 +344,7 @@ static struct ib_mad_send_buf *cm_alloc_msg(struct cm_id_private *cm_id_priv)
IB_MGMT_BASE_VERSION); IB_MGMT_BASE_VERSION);
if (IS_ERR(m)) { if (IS_ERR(m)) {
rdma_destroy_ah(ah, 0); rdma_destroy_ah(ah, 0);
return m; goto out;
} }
/* Timeout set by caller if response is expected. */ /* Timeout set by caller if response is expected. */
...@@ -322,6 +353,9 @@ static struct ib_mad_send_buf *cm_alloc_msg(struct cm_id_private *cm_id_priv) ...@@ -322,6 +353,9 @@ static struct ib_mad_send_buf *cm_alloc_msg(struct cm_id_private *cm_id_priv)
refcount_inc(&cm_id_priv->refcount); refcount_inc(&cm_id_priv->refcount);
m->context[0] = cm_id_priv; m->context[0] = cm_id_priv;
out:
spin_unlock(&cm_id_priv->av.port->cm_dev->mad_agent_lock);
return m; return m;
} }
...@@ -440,10 +474,24 @@ static void cm_set_private_data(struct cm_id_private *cm_id_priv, ...@@ -440,10 +474,24 @@ static void cm_set_private_data(struct cm_id_private *cm_id_priv,
cm_id_priv->private_data_len = private_data_len; cm_id_priv->private_data_len = private_data_len;
} }
/*
 * Retarget @av to @port (@port may be NULL to detach), transferring the
 * cm_device reference that av->port implicitly pins: the previous device
 * loses a ref and the new one gains a ref.  A no-op when the port is
 * unchanged.  NOTE(review): the put happens before the get, preserving the
 * original ordering exactly.
 */
static void cm_set_av_port(struct cm_av *av, struct cm_port *port)
{
	struct cm_port *prev = av->port;

	if (prev == port)
		return;

	av->port = port;

	if (prev)
		cm_device_put(prev->cm_dev);
	if (port)
		kref_get(&port->cm_dev->kref);
}
static void cm_init_av_for_lap(struct cm_port *port, struct ib_wc *wc, static void cm_init_av_for_lap(struct cm_port *port, struct ib_wc *wc,
struct rdma_ah_attr *ah_attr, struct cm_av *av) struct rdma_ah_attr *ah_attr, struct cm_av *av)
{ {
av->port = port; cm_set_av_port(av, port);
av->pkey_index = wc->pkey_index; av->pkey_index = wc->pkey_index;
rdma_move_ah_attr(&av->ah_attr, ah_attr); rdma_move_ah_attr(&av->ah_attr, ah_attr);
} }
...@@ -451,7 +499,7 @@ static void cm_init_av_for_lap(struct cm_port *port, struct ib_wc *wc, ...@@ -451,7 +499,7 @@ static void cm_init_av_for_lap(struct cm_port *port, struct ib_wc *wc,
static int cm_init_av_for_response(struct cm_port *port, struct ib_wc *wc, static int cm_init_av_for_response(struct cm_port *port, struct ib_wc *wc,
struct ib_grh *grh, struct cm_av *av) struct ib_grh *grh, struct cm_av *av)
{ {
av->port = port; cm_set_av_port(av, port);
av->pkey_index = wc->pkey_index; av->pkey_index = wc->pkey_index;
return ib_init_ah_attr_from_wc(port->cm_dev->ib_device, return ib_init_ah_attr_from_wc(port->cm_dev->ib_device,
port->port_num, wc, port->port_num, wc,
...@@ -518,7 +566,7 @@ static int cm_init_av_by_path(struct sa_path_rec *path, ...@@ -518,7 +566,7 @@ static int cm_init_av_by_path(struct sa_path_rec *path,
if (ret) if (ret)
return ret; return ret;
av->port = port; cm_set_av_port(av, port);
/* /*
* av->ah_attr might be initialized based on wc or during * av->ah_attr might be initialized based on wc or during
...@@ -542,7 +590,8 @@ static int cm_init_av_by_path(struct sa_path_rec *path, ...@@ -542,7 +590,8 @@ static int cm_init_av_by_path(struct sa_path_rec *path,
/* Move av created by cm_init_av_by_path(), so av.dgid is not moved */ /* Move av created by cm_init_av_by_path(), so av.dgid is not moved */
static void cm_move_av_from_path(struct cm_av *dest, struct cm_av *src) static void cm_move_av_from_path(struct cm_av *dest, struct cm_av *src)
{ {
dest->port = src->port; cm_set_av_port(dest, src->port);
cm_set_av_port(src, NULL);
dest->pkey_index = src->pkey_index; dest->pkey_index = src->pkey_index;
rdma_move_ah_attr(&dest->ah_attr, &src->ah_attr); rdma_move_ah_attr(&dest->ah_attr, &src->ah_attr);
dest->timeout = src->timeout; dest->timeout = src->timeout;
...@@ -551,6 +600,7 @@ static void cm_move_av_from_path(struct cm_av *dest, struct cm_av *src) ...@@ -551,6 +600,7 @@ static void cm_move_av_from_path(struct cm_av *dest, struct cm_av *src)
static void cm_destroy_av(struct cm_av *av) static void cm_destroy_av(struct cm_av *av)
{ {
rdma_destroy_ah_attr(&av->ah_attr); rdma_destroy_ah_attr(&av->ah_attr);
cm_set_av_port(av, NULL);
} }
static u32 cm_local_id(__be32 local_id) static u32 cm_local_id(__be32 local_id)
...@@ -1275,10 +1325,18 @@ EXPORT_SYMBOL(ib_cm_insert_listen); ...@@ -1275,10 +1325,18 @@ EXPORT_SYMBOL(ib_cm_insert_listen);
static __be64 cm_form_tid(struct cm_id_private *cm_id_priv) static __be64 cm_form_tid(struct cm_id_private *cm_id_priv)
{ {
u64 hi_tid, low_tid; u64 hi_tid = 0, low_tid;
lockdep_assert_held(&cm_id_priv->lock);
low_tid = (u64)cm_id_priv->id.local_id;
if (!cm_id_priv->av.port)
return cpu_to_be64(low_tid);
hi_tid = ((u64) cm_id_priv->av.port->mad_agent->hi_tid) << 32; spin_lock(&cm_id_priv->av.port->cm_dev->mad_agent_lock);
low_tid = (u64)cm_id_priv->id.local_id; if (cm_id_priv->av.port->mad_agent)
hi_tid = ((u64)cm_id_priv->av.port->mad_agent->hi_tid) << 32;
spin_unlock(&cm_id_priv->av.port->cm_dev->mad_agent_lock);
return cpu_to_be64(hi_tid | low_tid); return cpu_to_be64(hi_tid | low_tid);
} }
...@@ -2139,6 +2197,9 @@ static int cm_req_handler(struct cm_work *work) ...@@ -2139,6 +2197,9 @@ static int cm_req_handler(struct cm_work *work)
sa_path_set_dmac(&work->path[0], sa_path_set_dmac(&work->path[0],
cm_id_priv->av.ah_attr.roce.dmac); cm_id_priv->av.ah_attr.roce.dmac);
work->path[0].hop_limit = grh->hop_limit; work->path[0].hop_limit = grh->hop_limit;
/* This destroy call is needed to pair with cm_init_av_for_response */
cm_destroy_av(&cm_id_priv->av);
ret = cm_init_av_by_path(&work->path[0], gid_attr, &cm_id_priv->av); ret = cm_init_av_by_path(&work->path[0], gid_attr, &cm_id_priv->av);
if (ret) { if (ret) {
int err; int err;
...@@ -4090,7 +4151,8 @@ static int cm_init_qp_init_attr(struct cm_id_private *cm_id_priv, ...@@ -4090,7 +4151,8 @@ static int cm_init_qp_init_attr(struct cm_id_private *cm_id_priv,
qp_attr->qp_access_flags |= IB_ACCESS_REMOTE_READ | qp_attr->qp_access_flags |= IB_ACCESS_REMOTE_READ |
IB_ACCESS_REMOTE_ATOMIC; IB_ACCESS_REMOTE_ATOMIC;
qp_attr->pkey_index = cm_id_priv->av.pkey_index; qp_attr->pkey_index = cm_id_priv->av.pkey_index;
qp_attr->port_num = cm_id_priv->av.port->port_num; if (cm_id_priv->av.port)
qp_attr->port_num = cm_id_priv->av.port->port_num;
ret = 0; ret = 0;
break; break;
default: default:
...@@ -4132,7 +4194,8 @@ static int cm_init_qp_rtr_attr(struct cm_id_private *cm_id_priv, ...@@ -4132,7 +4194,8 @@ static int cm_init_qp_rtr_attr(struct cm_id_private *cm_id_priv,
cm_id_priv->responder_resources; cm_id_priv->responder_resources;
qp_attr->min_rnr_timer = 0; qp_attr->min_rnr_timer = 0;
} }
if (rdma_ah_get_dlid(&cm_id_priv->alt_av.ah_attr)) { if (rdma_ah_get_dlid(&cm_id_priv->alt_av.ah_attr) &&
cm_id_priv->alt_av.port) {
*qp_attr_mask |= IB_QP_ALT_PATH; *qp_attr_mask |= IB_QP_ALT_PATH;
qp_attr->alt_port_num = cm_id_priv->alt_av.port->port_num; qp_attr->alt_port_num = cm_id_priv->alt_av.port->port_num;
qp_attr->alt_pkey_index = cm_id_priv->alt_av.pkey_index; qp_attr->alt_pkey_index = cm_id_priv->alt_av.pkey_index;
...@@ -4193,7 +4256,9 @@ static int cm_init_qp_rts_attr(struct cm_id_private *cm_id_priv, ...@@ -4193,7 +4256,9 @@ static int cm_init_qp_rts_attr(struct cm_id_private *cm_id_priv,
} }
} else { } else {
*qp_attr_mask = IB_QP_ALT_PATH | IB_QP_PATH_MIG_STATE; *qp_attr_mask = IB_QP_ALT_PATH | IB_QP_PATH_MIG_STATE;
qp_attr->alt_port_num = cm_id_priv->alt_av.port->port_num; if (cm_id_priv->alt_av.port)
qp_attr->alt_port_num =
cm_id_priv->alt_av.port->port_num;
qp_attr->alt_pkey_index = cm_id_priv->alt_av.pkey_index; qp_attr->alt_pkey_index = cm_id_priv->alt_av.pkey_index;
qp_attr->alt_timeout = cm_id_priv->alt_av.timeout; qp_attr->alt_timeout = cm_id_priv->alt_av.timeout;
qp_attr->alt_ah_attr = cm_id_priv->alt_av.ah_attr; qp_attr->alt_ah_attr = cm_id_priv->alt_av.ah_attr;
...@@ -4311,6 +4376,8 @@ static int cm_add_one(struct ib_device *ib_device) ...@@ -4311,6 +4376,8 @@ static int cm_add_one(struct ib_device *ib_device)
if (!cm_dev) if (!cm_dev)
return -ENOMEM; return -ENOMEM;
kref_init(&cm_dev->kref);
spin_lock_init(&cm_dev->mad_agent_lock);
cm_dev->ib_device = ib_device; cm_dev->ib_device = ib_device;
cm_dev->ack_delay = ib_device->attrs.local_ca_ack_delay; cm_dev->ack_delay = ib_device->attrs.local_ca_ack_delay;
cm_dev->going_down = 0; cm_dev->going_down = 0;
...@@ -4373,7 +4440,6 @@ static int cm_add_one(struct ib_device *ib_device) ...@@ -4373,7 +4440,6 @@ static int cm_add_one(struct ib_device *ib_device)
error1: error1:
port_modify.set_port_cap_mask = 0; port_modify.set_port_cap_mask = 0;
port_modify.clr_port_cap_mask = IB_PORT_CM_SUP; port_modify.clr_port_cap_mask = IB_PORT_CM_SUP;
kfree(port);
while (--i) { while (--i) {
if (!rdma_cap_ib_cm(ib_device, i)) if (!rdma_cap_ib_cm(ib_device, i))
continue; continue;
...@@ -4382,10 +4448,9 @@ static int cm_add_one(struct ib_device *ib_device) ...@@ -4382,10 +4448,9 @@ static int cm_add_one(struct ib_device *ib_device)
ib_modify_port(ib_device, port->port_num, 0, &port_modify); ib_modify_port(ib_device, port->port_num, 0, &port_modify);
ib_unregister_mad_agent(port->mad_agent); ib_unregister_mad_agent(port->mad_agent);
cm_remove_port_fs(port); cm_remove_port_fs(port);
kfree(port);
} }
free: free:
kfree(cm_dev); cm_device_put(cm_dev);
return ret; return ret;
} }
...@@ -4408,10 +4473,13 @@ static void cm_remove_one(struct ib_device *ib_device, void *client_data) ...@@ -4408,10 +4473,13 @@ static void cm_remove_one(struct ib_device *ib_device, void *client_data)
spin_unlock_irq(&cm.lock); spin_unlock_irq(&cm.lock);
rdma_for_each_port (ib_device, i) { rdma_for_each_port (ib_device, i) {
struct ib_mad_agent *mad_agent;
if (!rdma_cap_ib_cm(ib_device, i)) if (!rdma_cap_ib_cm(ib_device, i))
continue; continue;
port = cm_dev->port[i-1]; port = cm_dev->port[i-1];
mad_agent = port->mad_agent;
ib_modify_port(ib_device, port->port_num, 0, &port_modify); ib_modify_port(ib_device, port->port_num, 0, &port_modify);
/* /*
* We flush the queue here after the going_down set, this * We flush the queue here after the going_down set, this
...@@ -4419,12 +4487,18 @@ static void cm_remove_one(struct ib_device *ib_device, void *client_data) ...@@ -4419,12 +4487,18 @@ static void cm_remove_one(struct ib_device *ib_device, void *client_data)
* after that we can call the unregister_mad_agent * after that we can call the unregister_mad_agent
*/ */
flush_workqueue(cm.wq); flush_workqueue(cm.wq);
ib_unregister_mad_agent(port->mad_agent); /*
* The above ensures no call paths from the work are running,
* the remaining paths all take the mad_agent_lock.
*/
spin_lock(&cm_dev->mad_agent_lock);
port->mad_agent = NULL;
spin_unlock(&cm_dev->mad_agent_lock);
ib_unregister_mad_agent(mad_agent);
cm_remove_port_fs(port); cm_remove_port_fs(port);
kfree(port);
} }
kfree(cm_dev); cm_device_put(cm_dev);
} }
static int __init ib_cm_init(void) static int __init ib_cm_init(void)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment