diff mbox series

net: Allow any address multicast join for IP sockets

Message ID 20210706011548.2201-2-callum.sinclair@alliedtelesis.co.nz
State New
Headers show
Series net: Allow any address multicast join for IP sockets | expand

Commit Message

Callum Sinclair July 6, 2021, 1:15 a.m. UTC
For an application to receive all multicast packets in a range such as
224.0.0.1 - 239.255.255.255 each multicast IP address has to be joined
explicitly one at a time.

Allow the any address to be passed to the IP_ADD_MEMBERSHIP and
IPV6_ADD_MEMBERSHIP socket option per interface. By joining the any
address the socket will receive all multicast packets that are received
on the interface. This allows any IP socket to be used for IGMP or MLD
snooping.

Signed-off-by: Callum Sinclair <callum.sinclair@alliedtelesis.co.nz>
---
 net/ipv4/igmp.c  | 40 ++++++++++++++++++++++++++++++++--------
 net/ipv6/mcast.c | 20 ++++++++++++++------
 2 files changed, 46 insertions(+), 14 deletions(-)
diff mbox series

Patch

diff --git a/net/ipv4/igmp.c b/net/ipv4/igmp.c
index 6b3c558a4f23..3978c9f2d1c5 100644
--- a/net/ipv4/igmp.c
+++ b/net/ipv4/igmp.c
@@ -1413,6 +1413,25 @@  static void ip_mc_hash_remove(struct in_device *in_dev,
 	*mc_hash = im->next_hash;
 }
 
+static struct ip_mc_list *ip_mc_hash_lookup(struct ip_mc_list __rcu **mc_hash,
+					    __be32 mc_addr)
+{
+	struct ip_mc_list *im;
+	u32 hash;
+
+	if (mc_hash) {
+		hash = hash_32((__force u32)mc_addr, MC_HASH_SZ_LOG);
+		for (im = rcu_dereference(mc_hash[hash]);
+		     im != NULL;
+		     im = rcu_dereference(im->next_hash)) {
+			if (im->multiaddr == mc_addr)
+				break;
+			}
+	}
+
+	return im;
+}
+
 
 /*
  *	A socket has joined a multicast group on device dev.
@@ -2166,7 +2185,7 @@  static int __ip_mc_join_group(struct sock *sk, struct ip_mreqn *imr,
 
 	ASSERT_RTNL();
 
-	if (!ipv4_is_multicast(addr))
+	if (!ipv4_is_multicast(addr) && addr != htonl(INADDR_ANY))
 		return -EINVAL;
 
 	in_dev = ip_mc_find_dev(net, imr);
@@ -2627,6 +2646,11 @@  int ip_mc_sf_allow(struct sock *sk, __be32 loc_addr, __be32 rmt_addr,
 
 	rcu_read_lock();
 	for_each_pmc_rcu(inet, pmc) {
+		if (pmc->multi.imr_multiaddr.s_addr == htonl(INADDR_ANY) &&
+		    pmc->multi.imr_ifindex == dif) {
+			ret = 1;
+			goto unlock;
+		}
 		if (pmc->multi.imr_multiaddr.s_addr == loc_addr &&
 		    (pmc->multi.imr_ifindex == dif ||
 		     (sdif && pmc->multi.imr_ifindex == sdif)))
@@ -2695,18 +2719,18 @@  int ip_check_mc_rcu(struct in_device *in_dev, __be32 mc_addr, __be32 src_addr, u
 
 	mc_hash = rcu_dereference(in_dev->mc_hash);
 	if (mc_hash) {
-		u32 hash = hash_32((__force u32)mc_addr, MC_HASH_SZ_LOG);
-
-		for (im = rcu_dereference(mc_hash[hash]);
-		     im != NULL;
-		     im = rcu_dereference(im->next_hash)) {
-			if (im->multiaddr == mc_addr)
-				break;
+		im = ip_mc_hash_lookup(mc_hash, mc_addr);
+		if (!im) {
+			if (ip_mc_hash_lookup(mc_hash, htonl(INADDR_ANY)))
+				return 1;
 		}
+
 	} else {
 		for_each_pmc_rcu(in_dev, im) {
 			if (im->multiaddr == mc_addr)
 				break;
+			if (im->multiaddr == htonl(INADDR_ANY))
+				return 1;
 		}
 	}
 	if (im && proto == IPPROTO_IGMP) {
diff --git a/net/ipv6/mcast.c b/net/ipv6/mcast.c
index 54ec163fbafa..7acf5b3cb435 100644
--- a/net/ipv6/mcast.c
+++ b/net/ipv6/mcast.c
@@ -177,7 +177,7 @@  static int __ipv6_sock_mc_join(struct sock *sk, int ifindex,
 
 	ASSERT_RTNL();
 
-	if (!ipv6_addr_is_multicast(addr))
+	if (!ipv6_addr_is_multicast(addr) && !ipv6_addr_any(addr))
 		return -EINVAL;
 
 	for_each_pmc_socklock(np, sk, mc_lst) {
@@ -254,7 +254,7 @@  int ipv6_sock_mc_drop(struct sock *sk, int ifindex, const struct in6_addr *addr)
 
 	ASSERT_RTNL();
 
-	if (!ipv6_addr_is_multicast(addr))
+	if (!ipv6_addr_is_multicast(addr) && !ipv6_addr_any(addr))
 		return -EINVAL;
 
 	for (lnk = &np->ipv6_mc_list;
@@ -374,7 +374,7 @@  int ip6_mc_source(int add, int omode, struct sock *sk,
 	source = &((struct sockaddr_in6 *)&pgsr->gsr_source)->sin6_addr;
 	group = &((struct sockaddr_in6 *)&pgsr->gsr_group)->sin6_addr;
 
-	if (!ipv6_addr_is_multicast(group))
+	if (!ipv6_addr_is_multicast(group) && !ipv6_addr_any(group))
 		return -EINVAL;
 
 	idev = ip6_mc_find_dev_rtnl(net, group, pgsr->gsr_interface);
@@ -497,7 +497,7 @@  int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf,
 
 	group = &((struct sockaddr_in6 *)&gsf->gf_group)->sin6_addr;
 
-	if (!ipv6_addr_is_multicast(group))
+	if (!ipv6_addr_is_multicast(group) && !ipv6_addr_any(group))
 		return -EINVAL;
 	if (gsf->gf_fmode != MCAST_INCLUDE &&
 	    gsf->gf_fmode != MCAST_EXCLUDE)
@@ -585,7 +585,7 @@  int ip6_mc_msfget(struct sock *sk, struct group_filter *gsf,
 
 	group = &((struct sockaddr_in6 *)&gsf->gf_group)->sin6_addr;
 
-	if (!ipv6_addr_is_multicast(group))
+	if (!ipv6_addr_is_multicast(group) && !ipv6_addr_any(group))
 		return -EINVAL;
 
 	/* changes to the ipv6_mc_list require the socket lock and
@@ -634,6 +634,10 @@  bool inet6_mc_check(struct sock *sk, const struct in6_addr *mc_addr,
 	for_each_pmc_rcu(np, mc) {
 		if (ipv6_addr_equal(&mc->addr, mc_addr))
 			break;
+		if (ipv6_addr_any(&mc->addr)) {
+			rcu_read_unlock();
+			return rv;
+		}
 	}
 	if (!mc) {
 		rcu_read_unlock();
@@ -1019,8 +1023,12 @@  bool ipv6_chk_mcast_addr(struct net_device *dev, const struct in6_addr *group,
 		for_each_mc_rcu(idev, mc) {
 			if (ipv6_addr_equal(&mc->mca_addr, group))
 				break;
+			if (ipv6_addr_any(&mc->mca_addr)) {
+				rv = true;
+				break;
+			}
 		}
-		if (mc) {
+		if (mc && !ipv6_addr_any(&mc->mca_addr)) {
 			if (src_addr && !ipv6_addr_any(src_addr)) {
 				struct ip6_sf_list *psf;