diff mbox series

[v2,net-next] mld: change lockdep annotation for ip6_sf_socklist and ipv6_mc_socklist

Message ID 20210404133823.15509-1-ap420073@gmail.com
State New
Headers show
Series [v2,net-next] mld: change lockdep annotation for ip6_sf_socklist and ipv6_mc_socklist | expand

Commit Message

Taehee Yoo April 4, 2021, 1:38 p.m. UTC
struct ip6_sf_socklist and ipv6_mc_socklist are per-socket MLD data.
These data are protected by rtnl lock, socket lock, and RCU.
So, when these are used, it verifies whether rtnl lock is acquired or not.

ip6_mc_msfget() is called by do_ipv6_getsockopt().
But caller doesn't acquire rtnl lock.
So, when these data are used in the ip6_mc_msfget() lockdep warns about it.
But accessing these is actually safe because socket lock was acquired by
do_ipv6_getsockopt().

So, it changes lockdep annotation from rtnl lock to socket lock.
(rtnl_dereference -> sock_dereference)

Locking graph for mld data is like below:

When writing mld data:
do_ipv6_setsockopt()
    rtnl_lock
    lock_sock
    (mld functions)
        idev->mc_lock(if per-interface mld data is modified)

When reading mld data:
do_ipv6_getsockopt()
    lock_sock
    ip6_mc_msfget()

Splat looks like:
diff mbox series

Patch

=============================
WARNING: suspicious RCU usage
5.12.0-rc4+ #503 Not tainted
-----------------------------
net/ipv6/mcast.c:610 suspicious rcu_dereference_protected() usage!

other info that might help us debug this:

rcu_scheduler_active = 2, debug_locks = 1
1 lock held by mcast-listener-/923:
 #0: ffff888007958a70 (sk_lock-AF_INET6){+.+.}-{0:0}, at:
ipv6_get_msfilter+0xaf/0x190

stack backtrace:
CPU: 1 PID: 923 Comm: mcast-listener- Not tainted 5.12.0-rc4+ #503
Call Trace:
 dump_stack+0xa4/0xe5
 ip6_mc_msfget+0x553/0x6c0
 ? ipv6_sock_mc_join_ssm+0x10/0x10
 ? lockdep_hardirqs_on_prepare+0x3e0/0x3e0
 ? mark_held_locks+0xb7/0x120
 ? lockdep_hardirqs_on_prepare+0x27c/0x3e0
 ? __local_bh_enable_ip+0xa5/0xf0
 ? lock_sock_nested+0x82/0xf0
 ipv6_get_msfilter+0xc3/0x190
 ? compat_ipv6_get_msfilter+0x300/0x300
 ? lock_downgrade+0x690/0x690
 do_ipv6_getsockopt.isra.6.constprop.13+0x1809/0x29e0
 ? do_ipv6_mcast_group_source+0x150/0x150
 ? register_lock_class+0x1750/0x1750
 ? kvm_sched_clock_read+0x14/0x30
 ? sched_clock+0x5/0x10
 ? sched_clock_cpu+0x18/0x170
 ? find_held_lock+0x3a/0x1c0
 ? lock_downgrade+0x690/0x690
 ? ipv6_getsockopt+0xdb/0x1b0
 ipv6_getsockopt+0xdb/0x1b0
[ ... ]

Fixes: 88e2ca308094 ("mld: convert ifmcaddr6 to RCU")
Reported-by: Eric Dumazet <edumazet@google.com>
Signed-off-by: Taehee Yoo <ap420073@gmail.com>
---

v1 -> v2:
 - Do not add new rtnl lock(by Eric Dumazet)
 - Add missing lock_sock().
 - Change headline.

 net/ipv6/mcast.c     | 48 +++++++++++++++++++++-----------------------
 net/tipc/udp_media.c |  2 ++
 2 files changed, 25 insertions(+), 25 deletions(-)

diff --git a/net/ipv6/mcast.c b/net/ipv6/mcast.c
index 49b0cebfdcdc..ff536a158b85 100644
--- a/net/ipv6/mcast.c
+++ b/net/ipv6/mcast.c
@@ -114,10 +114,13 @@  int sysctl_mld_qrv __read_mostly = MLD_QRV_DEFAULT;
 #define mc_dereference(e, idev) \
 	rcu_dereference_protected(e, lockdep_is_held(&(idev)->mc_lock))
 
-#define for_each_pmc_rtnl(np, pmc)				\
-	for (pmc = rtnl_dereference((np)->ipv6_mc_list);	\
+#define sock_dereference(e, sk) \
+	rcu_dereference_protected(e, lockdep_sock_is_held(sk))
+
+#define for_each_pmc_socklock(np, sk, pmc)			\
+	for (pmc = sock_dereference((np)->ipv6_mc_list, sk);	\
 	     pmc;						\
-	     pmc = rtnl_dereference(pmc->next))
+	     pmc = sock_dereference(pmc->next, sk))
 
 #define for_each_pmc_rcu(np, pmc)				\
 	for (pmc = rcu_dereference((np)->ipv6_mc_list);		\
@@ -180,7 +183,7 @@  static int __ipv6_sock_mc_join(struct sock *sk, int ifindex,
 	if (!ipv6_addr_is_multicast(addr))
 		return -EINVAL;
 
-	for_each_pmc_rtnl(np, mc_lst) {
+	for_each_pmc_socklock(np, sk, mc_lst) {
 		if ((ifindex == 0 || mc_lst->ifindex == ifindex) &&
 		    ipv6_addr_equal(&mc_lst->addr, addr))
 			return -EADDRINUSE;
@@ -258,7 +261,7 @@  int ipv6_sock_mc_drop(struct sock *sk, int ifindex, const struct in6_addr *addr)
 		return -EINVAL;
 
 	for (lnk = &np->ipv6_mc_list;
-	     (mc_lst = rtnl_dereference(*lnk)) != NULL;
+	     (mc_lst = sock_dereference(*lnk, sk)) != NULL;
 	      lnk = &mc_lst->next) {
 		if ((ifindex == 0 || mc_lst->ifindex == ifindex) &&
 		    ipv6_addr_equal(&mc_lst->addr, addr)) {
@@ -323,7 +326,7 @@  void __ipv6_sock_mc_close(struct sock *sk)
 
 	ASSERT_RTNL();
 
-	while ((mc_lst = rtnl_dereference(np->ipv6_mc_list)) != NULL) {
+	while ((mc_lst = sock_dereference(np->ipv6_mc_list, sk)) != NULL) {
 		struct net_device *dev;
 
 		np->ipv6_mc_list = mc_lst->next;
@@ -350,8 +353,11 @@  void ipv6_sock_mc_close(struct sock *sk)
 
 	if (!rcu_access_pointer(np->ipv6_mc_list))
 		return;
+
 	rtnl_lock();
+	lock_sock(sk);
 	__ipv6_sock_mc_close(sk);
+	release_sock(sk);
 	rtnl_unlock();
 }
 
@@ -381,7 +387,7 @@  int ip6_mc_source(int add, int omode, struct sock *sk,
 	err = -EADDRNOTAVAIL;
 
 	mutex_lock(&idev->mc_lock);
-	for_each_pmc_rtnl(inet6, pmc) {
+	for_each_pmc_socklock(inet6, sk, pmc) {
 		if (pgsr->gsr_interface && pmc->ifindex != pgsr->gsr_interface)
 			continue;
 		if (ipv6_addr_equal(&pmc->addr, group))
@@ -404,7 +410,7 @@  int ip6_mc_source(int add, int omode, struct sock *sk,
 		pmc->sfmode = omode;
 	}
 
-	psl = rtnl_dereference(pmc->sflist);
+	psl = sock_dereference(pmc->sflist, sk);
 	if (!add) {
 		if (!psl)
 			goto done;	/* err = -EADDRNOTAVAIL */
@@ -511,7 +517,7 @@  int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf,
 		goto done;
 	}
 
-	for_each_pmc_rtnl(inet6, pmc) {
+	for_each_pmc_socklock(inet6, sk, pmc) {
 		if (pmc->ifindex != gsf->gf_interface)
 			continue;
 		if (ipv6_addr_equal(&pmc->addr, group))
@@ -552,7 +558,7 @@  int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf,
 	}
 
 	mutex_lock(&idev->mc_lock);
-	psl = rtnl_dereference(pmc->sflist);
+	psl = sock_dereference(pmc->sflist, sk);
 	if (psl) {
 		ip6_mc_del_src(idev, group, pmc->sfmode,
 			       psl->sl_count, psl->sl_addr, 0);
@@ -574,40 +580,32 @@  int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf,
 int ip6_mc_msfget(struct sock *sk, struct group_filter *gsf,
 		  struct sockaddr_storage __user *p)
 {
-	int err, i, count, copycount;
+	struct ipv6_pinfo *inet6 = inet6_sk(sk);
 	const struct in6_addr *group;
 	struct ipv6_mc_socklist *pmc;
-	struct inet6_dev *idev;
-	struct ipv6_pinfo *inet6 = inet6_sk(sk);
 	struct ip6_sf_socklist *psl;
-	struct net *net = sock_net(sk);
+	int i, count, copycount;
 
 	group = &((struct sockaddr_in6 *)&gsf->gf_group)->sin6_addr;
 
 	if (!ipv6_addr_is_multicast(group))
 		return -EINVAL;
 
-	idev = ip6_mc_find_dev_rtnl(net, group, gsf->gf_interface);
-	if (!idev)
-		return -ENODEV;
-
-	err = -EADDRNOTAVAIL;
 	/* changes to the ipv6_mc_list require the socket lock and
-	 * rtnl lock. We have the socket lock and rcu read lock,
-	 * so reading the list is safe.
+	 * rtnl lock. We have the socket lock, so reading the list is safe.
 	 */
 
-	for_each_pmc_rtnl(inet6, pmc) {
+	for_each_pmc_socklock(inet6, sk, pmc) {
 		if (pmc->ifindex != gsf->gf_interface)
 			continue;
 		if (ipv6_addr_equal(group, &pmc->addr))
 			break;
 	}
 	if (!pmc)		/* must have a prior join */
-		return err;
+		return -EADDRNOTAVAIL;
 
 	gsf->gf_fmode = pmc->sfmode;
-	psl = rtnl_dereference(pmc->sflist);
+	psl = sock_dereference(pmc->sflist, sk);
 	count = psl ? psl->sl_count : 0;
 
 	copycount = count < gsf->gf_numsrc ? count : gsf->gf_numsrc;
@@ -2600,7 +2598,7 @@  static int ip6_mc_leave_src(struct sock *sk, struct ipv6_mc_socklist *iml,
 	struct ip6_sf_socklist *psl;
 	int err;
 
-	psl = rtnl_dereference(iml->sflist);
+	psl = sock_dereference(iml->sflist, sk);
 
 	if (idev)
 		mutex_lock(&idev->mc_lock);
diff --git a/net/tipc/udp_media.c b/net/tipc/udp_media.c
index 21e75e28e86a..e556d2cdc064 100644
--- a/net/tipc/udp_media.c
+++ b/net/tipc/udp_media.c
@@ -414,8 +414,10 @@  static int enable_mcast(struct udp_bearer *ub, struct udp_media_addr *remote)
 		err = ip_mc_join_group(sk, &mreqn);
 #if IS_ENABLED(CONFIG_IPV6)
 	} else {
+		lock_sock(sk);
 		err = ipv6_stub->ipv6_sock_mc_join(sk, ub->ifindex,
 						   &remote->ipv6);
+		release_sock(sk);
 #endif
 	}
 	return err;