Commit dad76bf7 authored by K. Y. Srinivasan's avatar K. Y. Srinivasan Committed by Greg Kroah-Hartman

Staging: hv: vmbus: Properly deal with de-registering channel callback

Ensure that we correctly handle racing invocations of the channel callback
when the channel is being closed. We do this using the channel's inbound_lock.
A side-effect of this strategy is that we avoid repeatedly picking up this lock
as we drain the inbound ring-buffer.
Signed-off-by: default avatarK. Y. Srinivasan <kys@microsoft.com>
Signed-off-by: default avatarHaiyang Zhang <haiyangz@microsoft.com>
Signed-off-by: default avatarGreg Kroah-Hartman <gregkh@suse.de>
parent 76c39d42
...@@ -513,9 +513,12 @@ void vmbus_close(struct vmbus_channel *channel) ...@@ -513,9 +513,12 @@ void vmbus_close(struct vmbus_channel *channel)
{ {
struct vmbus_channel_close_channel *msg; struct vmbus_channel_close_channel *msg;
int ret; int ret;
unsigned long flags;
/* Stop callback and cancel the timer asap */ /* Stop callback and cancel the timer asap */
spin_lock_irqsave(&channel->inbound_lock, flags);
channel->onchannel_callback = NULL; channel->onchannel_callback = NULL;
spin_unlock_irqrestore(&channel->inbound_lock, flags);
/* Send a closing message */ /* Send a closing message */
...@@ -735,19 +738,15 @@ int vmbus_recvpacket(struct vmbus_channel *channel, void *buffer, ...@@ -735,19 +738,15 @@ int vmbus_recvpacket(struct vmbus_channel *channel, void *buffer,
u32 packetlen; u32 packetlen;
u32 userlen; u32 userlen;
int ret; int ret;
unsigned long flags;
*buffer_actual_len = 0; *buffer_actual_len = 0;
*requestid = 0; *requestid = 0;
spin_lock_irqsave(&channel->inbound_lock, flags);
ret = hv_ringbuffer_peek(&channel->inbound, &desc, ret = hv_ringbuffer_peek(&channel->inbound, &desc,
sizeof(struct vmpacket_descriptor)); sizeof(struct vmpacket_descriptor));
if (ret != 0) { if (ret != 0)
spin_unlock_irqrestore(&channel->inbound_lock, flags);
return 0; return 0;
}
packetlen = desc.len8 << 3; packetlen = desc.len8 << 3;
userlen = packetlen - (desc.offset8 << 3); userlen = packetlen - (desc.offset8 << 3);
...@@ -755,7 +754,6 @@ int vmbus_recvpacket(struct vmbus_channel *channel, void *buffer, ...@@ -755,7 +754,6 @@ int vmbus_recvpacket(struct vmbus_channel *channel, void *buffer,
*buffer_actual_len = userlen; *buffer_actual_len = userlen;
if (userlen > bufferlen) { if (userlen > bufferlen) {
spin_unlock_irqrestore(&channel->inbound_lock, flags);
pr_err("Buffer too small - got %d needs %d\n", pr_err("Buffer too small - got %d needs %d\n",
bufferlen, userlen); bufferlen, userlen);
...@@ -768,7 +766,6 @@ int vmbus_recvpacket(struct vmbus_channel *channel, void *buffer, ...@@ -768,7 +766,6 @@ int vmbus_recvpacket(struct vmbus_channel *channel, void *buffer,
ret = hv_ringbuffer_read(&channel->inbound, buffer, userlen, ret = hv_ringbuffer_read(&channel->inbound, buffer, userlen,
(desc.offset8 << 3)); (desc.offset8 << 3));
spin_unlock_irqrestore(&channel->inbound_lock, flags);
return 0; return 0;
} }
...@@ -785,19 +782,15 @@ int vmbus_recvpacket_raw(struct vmbus_channel *channel, void *buffer, ...@@ -785,19 +782,15 @@ int vmbus_recvpacket_raw(struct vmbus_channel *channel, void *buffer,
u32 packetlen; u32 packetlen;
u32 userlen; u32 userlen;
int ret; int ret;
unsigned long flags;
*buffer_actual_len = 0; *buffer_actual_len = 0;
*requestid = 0; *requestid = 0;
spin_lock_irqsave(&channel->inbound_lock, flags);
ret = hv_ringbuffer_peek(&channel->inbound, &desc, ret = hv_ringbuffer_peek(&channel->inbound, &desc,
sizeof(struct vmpacket_descriptor)); sizeof(struct vmpacket_descriptor));
if (ret != 0) { if (ret != 0)
spin_unlock_irqrestore(&channel->inbound_lock, flags);
return 0; return 0;
}
packetlen = desc.len8 << 3; packetlen = desc.len8 << 3;
...@@ -806,8 +799,6 @@ int vmbus_recvpacket_raw(struct vmbus_channel *channel, void *buffer, ...@@ -806,8 +799,6 @@ int vmbus_recvpacket_raw(struct vmbus_channel *channel, void *buffer,
*buffer_actual_len = packetlen; *buffer_actual_len = packetlen;
if (packetlen > bufferlen) { if (packetlen > bufferlen) {
spin_unlock_irqrestore(&channel->inbound_lock, flags);
pr_err("Buffer too small - needed %d bytes but " pr_err("Buffer too small - needed %d bytes but "
"got space for only %d bytes\n", "got space for only %d bytes\n",
packetlen, bufferlen); packetlen, bufferlen);
...@@ -819,7 +810,6 @@ int vmbus_recvpacket_raw(struct vmbus_channel *channel, void *buffer, ...@@ -819,7 +810,6 @@ int vmbus_recvpacket_raw(struct vmbus_channel *channel, void *buffer,
/* Copy over the entire packet to the user buffer */ /* Copy over the entire packet to the user buffer */
ret = hv_ringbuffer_read(&channel->inbound, buffer, packetlen, 0); ret = hv_ringbuffer_read(&channel->inbound, buffer, packetlen, 0);
spin_unlock_irqrestore(&channel->inbound_lock, flags);
return 0; return 0;
} }
EXPORT_SYMBOL_GPL(vmbus_recvpacket_raw); EXPORT_SYMBOL_GPL(vmbus_recvpacket_raw);
...@@ -215,6 +215,7 @@ struct vmbus_channel *relid2channel(u32 relid) ...@@ -215,6 +215,7 @@ struct vmbus_channel *relid2channel(u32 relid)
static void process_chn_event(u32 relid) static void process_chn_event(u32 relid)
{ {
struct vmbus_channel *channel; struct vmbus_channel *channel;
unsigned long flags;
/* /*
* Find the channel based on this relid and invokes the * Find the channel based on this relid and invokes the
...@@ -222,11 +223,13 @@ static void process_chn_event(u32 relid) ...@@ -222,11 +223,13 @@ static void process_chn_event(u32 relid)
*/ */
channel = relid2channel(relid); channel = relid2channel(relid);
spin_lock_irqsave(&channel->inbound_lock, flags);
if (channel && (channel->onchannel_callback != NULL)) { if (channel && (channel->onchannel_callback != NULL)) {
channel->onchannel_callback(channel->channel_callback_context); channel->onchannel_callback(channel->channel_callback_context);
} else { } else {
pr_err("channel not found for relid - %u\n", relid); pr_err("channel not found for relid - %u\n", relid);
} }
spin_unlock_irqrestore(&channel->inbound_lock, flags);
} }
/* /*
......
...@@ -62,9 +62,7 @@ static struct netvsc_device *get_outbound_net_device(struct hv_device *device) ...@@ -62,9 +62,7 @@ static struct netvsc_device *get_outbound_net_device(struct hv_device *device)
static struct netvsc_device *get_inbound_net_device(struct hv_device *device) static struct netvsc_device *get_inbound_net_device(struct hv_device *device)
{ {
struct netvsc_device *net_device; struct netvsc_device *net_device;
unsigned long flags;
spin_lock_irqsave(&device->channel->inbound_lock, flags);
net_device = device->ext; net_device = device->ext;
if (!net_device) if (!net_device)
...@@ -75,7 +73,6 @@ static struct netvsc_device *get_inbound_net_device(struct hv_device *device) ...@@ -75,7 +73,6 @@ static struct netvsc_device *get_inbound_net_device(struct hv_device *device)
net_device = NULL; net_device = NULL;
get_in_err: get_in_err:
spin_unlock_irqrestore(&device->channel->inbound_lock, flags);
return net_device; return net_device;
} }
......
...@@ -352,9 +352,7 @@ static inline struct storvsc_device *get_in_stor_device( ...@@ -352,9 +352,7 @@ static inline struct storvsc_device *get_in_stor_device(
struct hv_device *device) struct hv_device *device)
{ {
struct storvsc_device *stor_device; struct storvsc_device *stor_device;
unsigned long flags;
spin_lock_irqsave(&device->channel->inbound_lock, flags);
stor_device = (struct storvsc_device *)device->ext; stor_device = (struct storvsc_device *)device->ext;
if (!stor_device) if (!stor_device)
...@@ -370,7 +368,6 @@ static inline struct storvsc_device *get_in_stor_device( ...@@ -370,7 +368,6 @@ static inline struct storvsc_device *get_in_stor_device(
stor_device = NULL; stor_device = NULL;
get_in_err: get_in_err:
spin_unlock_irqrestore(&device->channel->inbound_lock, flags);
return stor_device; return stor_device;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment