@@ -183,38 +183,64 @@ static int net_failover_get_stats(struct net_device *dev,
struct rtnl_link_stats64 *stats)
{
struct net_failover_info *nfo_info = netdev_priv(dev);
- struct rtnl_link_stats64 temp;
- struct net_device *slave_dev;
+ struct rtnl_link_stats64 primary_stats;
+ struct rtnl_link_stats64 standby_stats;
+ struct net_device *primary_dev;
+ struct net_device *standby_dev;
int err = 0;
- spin_lock(&nfo_info->stats_lock);
- memcpy(stats, &nfo_info->failover_stats, sizeof(*stats));
-
rcu_read_lock();
- slave_dev = rcu_dereference(nfo_info->primary_dev);
- if (slave_dev) {
- err = dev_get_stats(slave_dev, &temp);
+ primary_dev = rcu_dereference(nfo_info->primary_dev);
+ if (primary_dev)
+ dev_hold(primary_dev);
+
+ standby_dev = rcu_dereference(nfo_info->standby_dev);
+ if (standby_dev)
+ dev_hold(standby_dev);
+
+ rcu_read_unlock();
+
+ /* Don't hold rcu_read_lock while calling dev_get_stats, just a
+ * reference to ensure they won't get unregistered.
+ */
+ if (primary_dev) {
+ err = dev_get_stats(primary_dev, &primary_stats);
if (err)
goto out;
- net_failover_fold_stats(stats, &temp, &nfo_info->primary_stats);
- memcpy(&nfo_info->primary_stats, &temp, sizeof(temp));
}
- slave_dev = rcu_dereference(nfo_info->standby_dev);
- if (slave_dev) {
- err = dev_get_stats(slave_dev, &temp);
+ if (standby_dev) {
+ err = dev_get_stats(standby_dev, &standby_stats);
if (err)
goto out;
- net_failover_fold_stats(stats, &temp, &nfo_info->standby_stats);
- memcpy(&nfo_info->standby_stats, &temp, sizeof(temp));
}
-out:
- rcu_read_unlock();
+ spin_lock(&nfo_info->stats_lock);
+
+ memcpy(stats, &nfo_info->failover_stats, sizeof(*stats));
+
+ if (primary_dev) {
+ net_failover_fold_stats(stats, &primary_stats,
+ &nfo_info->primary_stats);
+ memcpy(&nfo_info->primary_stats, &primary_stats,
+ sizeof(primary_stats));
+ }
+ if (standby_dev) {
+ net_failover_fold_stats(stats, &standby_stats,
+ &nfo_info->standby_stats);
+ memcpy(&nfo_info->standby_stats, &standby_stats,
+ sizeof(standby_stats));
+ }
memcpy(&nfo_info->failover_stats, stats, sizeof(*stats));
+
spin_unlock(&nfo_info->stats_lock);
+out:
+ if (primary_dev)
+ dev_put(primary_dev);
+ if (standby_dev)
+ dev_put(standby_dev);
return err;
}
@@ -728,6 +754,7 @@ static struct failover_ops net_failover_ops = {
struct failover *net_failover_create(struct net_device *standby_dev)
{
struct device *dev = standby_dev->dev.parent;
+ struct net_failover_info *nfo_info;
struct net_device *failover_dev;
struct failover *failover;
int err;
@@ -772,6 +799,9 @@ struct failover *net_failover_create(struct net_device *standby_dev)
failover_dev->min_mtu = standby_dev->min_mtu;
failover_dev->max_mtu = standby_dev->max_mtu;
+ nfo_info = netdev_priv(failover_dev);
+ spin_lock_init(&nfo_info->stats_lock);
+
err = register_netdev(failover_dev);
if (err) {
dev_err(dev, "Unable to register failover_dev!\n");