Commit 545a8e79 authored by stephen hemminger's avatar stephen hemminger Committed by David S. Miller

netvsc: use RCU to protect inner device structure

The netvsc driver has an internal structure (netvsc_device) which
is created when device is opened and released when device is closed.
And also opened/released when MTU or number of channels change.

Since this is referenced in the receive and transmit path, it is
safer to use RCU to protect/prevent use after free problems.
Signed-off-by: default avatarStephen Hemminger <sthemmin@microsoft.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 3071ada4
...@@ -686,7 +686,7 @@ struct net_device_context { ...@@ -686,7 +686,7 @@ struct net_device_context {
/* point back to our device context */ /* point back to our device context */
struct hv_device *device_ctx; struct hv_device *device_ctx;
/* netvsc_device */ /* netvsc_device */
struct netvsc_device *nvdev; struct netvsc_device __rcu *nvdev;
/* reconfigure work */ /* reconfigure work */
struct delayed_work dwork; struct delayed_work dwork;
/* last reconfig time */ /* last reconfig time */
...@@ -780,6 +780,8 @@ struct netvsc_device { ...@@ -780,6 +780,8 @@ struct netvsc_device {
atomic_t open_cnt; atomic_t open_cnt;
struct netvsc_channel chan_table[VRSS_CHANNEL_MAX]; struct netvsc_channel chan_table[VRSS_CHANNEL_MAX];
struct rcu_head rcu;
}; };
static inline struct netvsc_device * static inline struct netvsc_device *
......
...@@ -80,8 +80,10 @@ static struct netvsc_device *alloc_net_device(void) ...@@ -80,8 +80,10 @@ static struct netvsc_device *alloc_net_device(void)
return net_device; return net_device;
} }
static void free_netvsc_device(struct netvsc_device *nvdev) static void free_netvsc_device(struct rcu_head *head)
{ {
struct netvsc_device *nvdev
= container_of(head, struct netvsc_device, rcu);
int i; int i;
for (i = 0; i < VRSS_CHANNEL_MAX; i++) for (i = 0; i < VRSS_CHANNEL_MAX; i++)
...@@ -90,6 +92,10 @@ static void free_netvsc_device(struct netvsc_device *nvdev) ...@@ -90,6 +92,10 @@ static void free_netvsc_device(struct netvsc_device *nvdev)
kfree(nvdev); kfree(nvdev);
} }
static void free_netvsc_device_rcu(struct netvsc_device *nvdev)
{
call_rcu(&nvdev->rcu, free_netvsc_device);
}
static struct netvsc_device *get_outbound_net_device(struct hv_device *device) static struct netvsc_device *get_outbound_net_device(struct hv_device *device)
{ {
...@@ -551,7 +557,7 @@ void netvsc_device_remove(struct hv_device *device) ...@@ -551,7 +557,7 @@ void netvsc_device_remove(struct hv_device *device)
netvsc_disconnect_vsp(device); netvsc_disconnect_vsp(device);
net_device_ctx->nvdev = NULL; RCU_INIT_POINTER(net_device_ctx->nvdev, NULL);
/* /*
* At this point, no one should be accessing net_device * At this point, no one should be accessing net_device
...@@ -566,7 +572,7 @@ void netvsc_device_remove(struct hv_device *device) ...@@ -566,7 +572,7 @@ void netvsc_device_remove(struct hv_device *device)
napi_disable(&net_device->chan_table[i].napi); napi_disable(&net_device->chan_table[i].napi);
/* Release all resources */ /* Release all resources */
free_netvsc_device(net_device); free_netvsc_device_rcu(net_device);
} }
#define RING_AVAIL_PERCENT_HIWATER 20 #define RING_AVAIL_PERCENT_HIWATER 20
...@@ -1322,7 +1328,7 @@ int netvsc_device_add(struct hv_device *device, ...@@ -1322,7 +1328,7 @@ int netvsc_device_add(struct hv_device *device,
*/ */
wmb(); wmb();
net_device_ctx->nvdev = net_device; rcu_assign_pointer(net_device_ctx->nvdev, net_device);
/* Connect with the NetVsp */ /* Connect with the NetVsp */
ret = netvsc_connect_vsp(device); ret = netvsc_connect_vsp(device);
...@@ -1341,7 +1347,7 @@ int netvsc_device_add(struct hv_device *device, ...@@ -1341,7 +1347,7 @@ int netvsc_device_add(struct hv_device *device,
vmbus_close(device->channel); vmbus_close(device->channel);
cleanup: cleanup:
free_netvsc_device(net_device); free_netvsc_device(&net_device->rcu);
return ret; return ret;
} }
...@@ -62,7 +62,7 @@ static void do_set_multicast(struct work_struct *w) ...@@ -62,7 +62,7 @@ static void do_set_multicast(struct work_struct *w)
container_of(w, struct net_device_context, work); container_of(w, struct net_device_context, work);
struct hv_device *device_obj = ndevctx->device_ctx; struct hv_device *device_obj = ndevctx->device_ctx;
struct net_device *ndev = hv_get_drvdata(device_obj); struct net_device *ndev = hv_get_drvdata(device_obj);
struct netvsc_device *nvdev = ndevctx->nvdev; struct netvsc_device *nvdev = rcu_dereference(ndevctx->nvdev);
struct rndis_device *rdev; struct rndis_device *rdev;
if (!nvdev) if (!nvdev)
...@@ -116,7 +116,7 @@ static int netvsc_open(struct net_device *net) ...@@ -116,7 +116,7 @@ static int netvsc_open(struct net_device *net)
static int netvsc_close(struct net_device *net) static int netvsc_close(struct net_device *net)
{ {
struct net_device_context *net_device_ctx = netdev_priv(net); struct net_device_context *net_device_ctx = netdev_priv(net);
struct netvsc_device *nvdev = net_device_ctx->nvdev; struct netvsc_device *nvdev = rtnl_dereference(net_device_ctx->nvdev);
int ret; int ret;
u32 aread, awrite, i, msec = 10, retry = 0, retry_max = 20; u32 aread, awrite, i, msec = 10, retry = 0, retry_max = 20;
struct vmbus_channel *chn; struct vmbus_channel *chn;
...@@ -637,9 +637,9 @@ int netvsc_recv_callback(struct net_device *net, ...@@ -637,9 +637,9 @@ int netvsc_recv_callback(struct net_device *net,
const struct ndis_pkt_8021q_info *vlan) const struct ndis_pkt_8021q_info *vlan)
{ {
struct net_device_context *net_device_ctx = netdev_priv(net); struct net_device_context *net_device_ctx = netdev_priv(net);
struct netvsc_device *net_device = net_device_ctx->nvdev; struct netvsc_device *net_device;
u16 q_idx = channel->offermsg.offer.sub_channel_index; u16 q_idx = channel->offermsg.offer.sub_channel_index;
struct netvsc_channel *nvchan = &net_device->chan_table[q_idx]; struct netvsc_channel *nvchan;
struct net_device *vf_netdev; struct net_device *vf_netdev;
struct sk_buff *skb; struct sk_buff *skb;
struct netvsc_stats *rx_stats; struct netvsc_stats *rx_stats;
...@@ -655,6 +655,11 @@ int netvsc_recv_callback(struct net_device *net, ...@@ -655,6 +655,11 @@ int netvsc_recv_callback(struct net_device *net,
* interface in the guest. * interface in the guest.
*/ */
rcu_read_lock(); rcu_read_lock();
net_device = rcu_dereference(net_device_ctx->nvdev);
if (unlikely(!net_device))
goto drop;
nvchan = &net_device->chan_table[q_idx];
vf_netdev = rcu_dereference(net_device_ctx->vf_netdev); vf_netdev = rcu_dereference(net_device_ctx->vf_netdev);
if (vf_netdev && (vf_netdev->flags & IFF_UP)) if (vf_netdev && (vf_netdev->flags & IFF_UP))
net = vf_netdev; net = vf_netdev;
...@@ -663,6 +668,7 @@ int netvsc_recv_callback(struct net_device *net, ...@@ -663,6 +668,7 @@ int netvsc_recv_callback(struct net_device *net,
skb = netvsc_alloc_recv_skb(net, &nvchan->napi, skb = netvsc_alloc_recv_skb(net, &nvchan->napi,
csum_info, vlan, data, len); csum_info, vlan, data, len);
if (unlikely(!skb)) { if (unlikely(!skb)) {
drop:
++net->stats.rx_dropped; ++net->stats.rx_dropped;
rcu_read_unlock(); rcu_read_unlock();
return NVSP_STAT_FAIL; return NVSP_STAT_FAIL;
...@@ -704,7 +710,7 @@ static void netvsc_get_channels(struct net_device *net, ...@@ -704,7 +710,7 @@ static void netvsc_get_channels(struct net_device *net,
struct ethtool_channels *channel) struct ethtool_channels *channel)
{ {
struct net_device_context *net_device_ctx = netdev_priv(net); struct net_device_context *net_device_ctx = netdev_priv(net);
struct netvsc_device *nvdev = net_device_ctx->nvdev; struct netvsc_device *nvdev = rtnl_dereference(net_device_ctx->nvdev);
if (nvdev) { if (nvdev) {
channel->max_combined = nvdev->max_chn; channel->max_combined = nvdev->max_chn;
...@@ -741,7 +747,7 @@ static int netvsc_set_channels(struct net_device *net, ...@@ -741,7 +747,7 @@ static int netvsc_set_channels(struct net_device *net,
{ {
struct net_device_context *net_device_ctx = netdev_priv(net); struct net_device_context *net_device_ctx = netdev_priv(net);
struct hv_device *dev = net_device_ctx->device_ctx; struct hv_device *dev = net_device_ctx->device_ctx;
struct netvsc_device *nvdev = net_device_ctx->nvdev; struct netvsc_device *nvdev = rtnl_dereference(net_device_ctx->nvdev);
unsigned int count = channels->combined_count; unsigned int count = channels->combined_count;
bool was_running; bool was_running;
int ret; int ret;
...@@ -848,7 +854,7 @@ static int netvsc_set_link_ksettings(struct net_device *dev, ...@@ -848,7 +854,7 @@ static int netvsc_set_link_ksettings(struct net_device *dev,
static int netvsc_change_mtu(struct net_device *ndev, int mtu) static int netvsc_change_mtu(struct net_device *ndev, int mtu)
{ {
struct net_device_context *ndevctx = netdev_priv(ndev); struct net_device_context *ndevctx = netdev_priv(ndev);
struct netvsc_device *nvdev = ndevctx->nvdev; struct netvsc_device *nvdev = rtnl_dereference(ndevctx->nvdev);
struct hv_device *hdev = ndevctx->device_ctx; struct hv_device *hdev = ndevctx->device_ctx;
struct netvsc_device_info device_info; struct netvsc_device_info device_info;
bool was_running; bool was_running;
...@@ -897,7 +903,7 @@ static void netvsc_get_stats64(struct net_device *net, ...@@ -897,7 +903,7 @@ static void netvsc_get_stats64(struct net_device *net,
struct rtnl_link_stats64 *t) struct rtnl_link_stats64 *t)
{ {
struct net_device_context *ndev_ctx = netdev_priv(net); struct net_device_context *ndev_ctx = netdev_priv(net);
struct netvsc_device *nvdev = ndev_ctx->nvdev; struct netvsc_device *nvdev = rcu_dereference(ndev_ctx->nvdev);
int i; int i;
if (!nvdev) if (!nvdev)
...@@ -982,7 +988,10 @@ static const struct { ...@@ -982,7 +988,10 @@ static const struct {
static int netvsc_get_sset_count(struct net_device *dev, int string_set) static int netvsc_get_sset_count(struct net_device *dev, int string_set)
{ {
struct net_device_context *ndc = netdev_priv(dev); struct net_device_context *ndc = netdev_priv(dev);
struct netvsc_device *nvdev = ndc->nvdev; struct netvsc_device *nvdev = rcu_dereference(ndc->nvdev);
if (!nvdev)
return -ENODEV;
switch (string_set) { switch (string_set) {
case ETH_SS_STATS: case ETH_SS_STATS:
...@@ -996,13 +1005,16 @@ static void netvsc_get_ethtool_stats(struct net_device *dev, ...@@ -996,13 +1005,16 @@ static void netvsc_get_ethtool_stats(struct net_device *dev,
struct ethtool_stats *stats, u64 *data) struct ethtool_stats *stats, u64 *data)
{ {
struct net_device_context *ndc = netdev_priv(dev); struct net_device_context *ndc = netdev_priv(dev);
struct netvsc_device *nvdev = ndc->nvdev; struct netvsc_device *nvdev = rcu_dereference(ndc->nvdev);
const void *nds = &ndc->eth_stats; const void *nds = &ndc->eth_stats;
const struct netvsc_stats *qstats; const struct netvsc_stats *qstats;
unsigned int start; unsigned int start;
u64 packets, bytes; u64 packets, bytes;
int i, j; int i, j;
if (!nvdev)
return;
for (i = 0; i < NETVSC_GLOBAL_STATS_LEN; i++) for (i = 0; i < NETVSC_GLOBAL_STATS_LEN; i++)
data[i] = *(unsigned long *)(nds + netvsc_stats[i].offset); data[i] = *(unsigned long *)(nds + netvsc_stats[i].offset);
...@@ -1031,10 +1043,13 @@ static void netvsc_get_ethtool_stats(struct net_device *dev, ...@@ -1031,10 +1043,13 @@ static void netvsc_get_ethtool_stats(struct net_device *dev,
static void netvsc_get_strings(struct net_device *dev, u32 stringset, u8 *data) static void netvsc_get_strings(struct net_device *dev, u32 stringset, u8 *data)
{ {
struct net_device_context *ndc = netdev_priv(dev); struct net_device_context *ndc = netdev_priv(dev);
struct netvsc_device *nvdev = ndc->nvdev; struct netvsc_device *nvdev = rcu_dereference(ndc->nvdev);
u8 *p = data; u8 *p = data;
int i; int i;
if (!nvdev)
return;
switch (stringset) { switch (stringset) {
case ETH_SS_STATS: case ETH_SS_STATS:
for (i = 0; i < ARRAY_SIZE(netvsc_stats); i++) for (i = 0; i < ARRAY_SIZE(netvsc_stats); i++)
...@@ -1086,7 +1101,10 @@ netvsc_get_rxnfc(struct net_device *dev, struct ethtool_rxnfc *info, ...@@ -1086,7 +1101,10 @@ netvsc_get_rxnfc(struct net_device *dev, struct ethtool_rxnfc *info,
u32 *rules) u32 *rules)
{ {
struct net_device_context *ndc = netdev_priv(dev); struct net_device_context *ndc = netdev_priv(dev);
struct netvsc_device *nvdev = ndc->nvdev; struct netvsc_device *nvdev = rcu_dereference(ndc->nvdev);
if (!nvdev)
return -ENODEV;
switch (info->cmd) { switch (info->cmd) {
case ETHTOOL_GRXRINGS: case ETHTOOL_GRXRINGS:
...@@ -1122,10 +1140,13 @@ static int netvsc_get_rxfh(struct net_device *dev, u32 *indir, u8 *key, ...@@ -1122,10 +1140,13 @@ static int netvsc_get_rxfh(struct net_device *dev, u32 *indir, u8 *key,
u8 *hfunc) u8 *hfunc)
{ {
struct net_device_context *ndc = netdev_priv(dev); struct net_device_context *ndc = netdev_priv(dev);
struct netvsc_device *ndev = ndc->nvdev; struct netvsc_device *ndev = rcu_dereference(ndc->nvdev);
struct rndis_device *rndis_dev = ndev->extension; struct rndis_device *rndis_dev = ndev->extension;
int i; int i;
if (!ndev)
return -ENODEV;
if (hfunc) if (hfunc)
*hfunc = ETH_RSS_HASH_TOP; /* Toeplitz */ *hfunc = ETH_RSS_HASH_TOP; /* Toeplitz */
...@@ -1144,10 +1165,13 @@ static int netvsc_set_rxfh(struct net_device *dev, const u32 *indir, ...@@ -1144,10 +1165,13 @@ static int netvsc_set_rxfh(struct net_device *dev, const u32 *indir,
const u8 *key, const u8 hfunc) const u8 *key, const u8 hfunc)
{ {
struct net_device_context *ndc = netdev_priv(dev); struct net_device_context *ndc = netdev_priv(dev);
struct netvsc_device *ndev = ndc->nvdev; struct netvsc_device *ndev = rtnl_dereference(ndc->nvdev);
struct rndis_device *rndis_dev = ndev->extension; struct rndis_device *rndis_dev = ndev->extension;
int i; int i;
if (!ndev)
return -ENODEV;
if (hfunc != ETH_RSS_HASH_NO_CHANGE && hfunc != ETH_RSS_HASH_TOP) if (hfunc != ETH_RSS_HASH_NO_CHANGE && hfunc != ETH_RSS_HASH_TOP)
return -EOPNOTSUPP; return -EOPNOTSUPP;
...@@ -1224,7 +1248,7 @@ static void netvsc_link_change(struct work_struct *w) ...@@ -1224,7 +1248,7 @@ static void netvsc_link_change(struct work_struct *w)
if (ndev_ctx->start_remove) if (ndev_ctx->start_remove)
goto out_unlock; goto out_unlock;
net_device = ndev_ctx->nvdev; net_device = rtnl_dereference(ndev_ctx->nvdev);
rdev = net_device->extension; rdev = net_device->extension;
next_reconfig = ndev_ctx->last_reconfig + LINKCHANGE_INT; next_reconfig = ndev_ctx->last_reconfig + LINKCHANGE_INT;
...@@ -1365,7 +1389,7 @@ static int netvsc_register_vf(struct net_device *vf_netdev) ...@@ -1365,7 +1389,7 @@ static int netvsc_register_vf(struct net_device *vf_netdev)
return NOTIFY_DONE; return NOTIFY_DONE;
net_device_ctx = netdev_priv(ndev); net_device_ctx = netdev_priv(ndev);
netvsc_dev = net_device_ctx->nvdev; netvsc_dev = rtnl_dereference(net_device_ctx->nvdev);
if (!netvsc_dev || rtnl_dereference(net_device_ctx->vf_netdev)) if (!netvsc_dev || rtnl_dereference(net_device_ctx->vf_netdev))
return NOTIFY_DONE; return NOTIFY_DONE;
...@@ -1391,7 +1415,7 @@ static int netvsc_vf_up(struct net_device *vf_netdev) ...@@ -1391,7 +1415,7 @@ static int netvsc_vf_up(struct net_device *vf_netdev)
return NOTIFY_DONE; return NOTIFY_DONE;
net_device_ctx = netdev_priv(ndev); net_device_ctx = netdev_priv(ndev);
netvsc_dev = net_device_ctx->nvdev; netvsc_dev = rtnl_dereference(net_device_ctx->nvdev);
netdev_info(ndev, "VF up: %s\n", vf_netdev->name); netdev_info(ndev, "VF up: %s\n", vf_netdev->name);
...@@ -1425,7 +1449,7 @@ static int netvsc_vf_down(struct net_device *vf_netdev) ...@@ -1425,7 +1449,7 @@ static int netvsc_vf_down(struct net_device *vf_netdev)
return NOTIFY_DONE; return NOTIFY_DONE;
net_device_ctx = netdev_priv(ndev); net_device_ctx = netdev_priv(ndev);
netvsc_dev = net_device_ctx->nvdev; netvsc_dev = rtnl_dereference(net_device_ctx->nvdev);
netdev_info(ndev, "VF down: %s\n", vf_netdev->name); netdev_info(ndev, "VF down: %s\n", vf_netdev->name);
netvsc_switch_datapath(ndev, false); netvsc_switch_datapath(ndev, false);
...@@ -1519,6 +1543,7 @@ static int netvsc_probe(struct hv_device *dev, ...@@ -1519,6 +1543,7 @@ static int netvsc_probe(struct hv_device *dev,
NETIF_F_HW_VLAN_CTAG_TX | NETIF_F_HW_VLAN_CTAG_RX; NETIF_F_HW_VLAN_CTAG_TX | NETIF_F_HW_VLAN_CTAG_RX;
net->vlan_features = net->features; net->vlan_features = net->features;
/* RCU not necessary here, device not registered */
nvdev = net_device_ctx->nvdev; nvdev = net_device_ctx->nvdev;
netif_set_real_num_tx_queues(net, nvdev->num_chn); netif_set_real_num_tx_queues(net, nvdev->num_chn);
netif_set_real_num_rx_queues(net, nvdev->num_chn); netif_set_real_num_rx_queues(net, nvdev->num_chn);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment