diff mbox series

[bpf-next] skmsg: pass psock pointer to ->psock_update_sk_prot()

Message ID 20210407032111.33398-1-xiyou.wangcong@gmail.com
State New
Headers show
Series [bpf-next] skmsg: pass psock pointer to ->psock_update_sk_prot() | expand

Commit Message

Cong Wang April 7, 2021, 3:21 a.m. UTC
From: Cong Wang <cong.wang@bytedance.com>

Using sk_psock() to retrieve psock pointer from sock requires
RCU read lock, but we already get psock pointer before calling
->psock_update_sk_prot() in both cases, so we can just pass it
without bothering sk_psock().

Reported-and-tested-by: syzbot+320a3bc8d80f478c37e4@syzkaller.appspotmail.com
Fixes: 8a59f9d1e3d4 ("sock: Introduce sk->sk_prot->psock_update_sk_prot()")
Cc: John Fastabend <john.fastabend@gmail.com>
Cc: Eric Dumazet <edumazet@google.com>
Cc: Daniel Borkmann <daniel@iogearbox.net>
Cc: Jakub Sitnicki <jakub@cloudflare.com>
Cc: Lorenz Bauer <lmb@cloudflare.com>
Signed-off-by: Cong Wang <cong.wang@bytedance.com>
---
 include/linux/skmsg.h | 5 +++--
 include/net/sock.h    | 5 ++++-
 include/net/tcp.h     | 2 +-
 include/net/udp.h     | 2 +-
 net/core/sock_map.c   | 2 +-
 net/ipv4/tcp_bpf.c    | 3 +--
 net/ipv4/udp_bpf.c    | 3 +--
 7 files changed, 12 insertions(+), 10 deletions(-)

Comments

John Fastabend April 8, 2021, 11:32 p.m. UTC | #1
Cong Wang wrote:
> From: Cong Wang <cong.wang@bytedance.com>

> 

> Using sk_psock() to retrieve psock pointer from sock requires

> RCU read lock, but we already get psock pointer before calling

> ->psock_update_sk_prot() in both cases, so we can just pass it

> without bothering sk_psock().

> 

> Reported-and-tested-by: syzbot+320a3bc8d80f478c37e4@syzkaller.appspotmail.com

> Fixes: 8a59f9d1e3d4 ("sock: Introduce sk->sk_prot->psock_update_sk_prot()")

> Cc: John Fastabend <john.fastabend@gmail.com>

> Cc: Eric Dumazet <edumazet@google.com>

> Cc: Daniel Borkmann <daniel@iogearbox.net>

> Cc: Jakub Sitnicki <jakub@cloudflare.com>

> Cc: Lorenz Bauer <lmb@cloudflare.com>

> Signed-off-by: Cong Wang <cong.wang@bytedance.com>

> ---


Thanks.

Acked-by: John Fastabend <john.fastabend@gmail.com>
Jakub Sitnicki April 12, 2021, 8:05 a.m. UTC | #2
On Wed, Apr 07, 2021 at 05:21 AM CEST, Cong Wang wrote:
> From: Cong Wang <cong.wang@bytedance.com>

>

> Using sk_psock() to retrieve psock pointer from sock requires

> RCU read lock, but we already get psock pointer before calling

> ->psock_update_sk_prot() in both cases, so we can just pass it

> without bothering sk_psock().

>

> Reported-and-tested-by: syzbot+320a3bc8d80f478c37e4@syzkaller.appspotmail.com

> Fixes: 8a59f9d1e3d4 ("sock: Introduce sk->sk_prot->psock_update_sk_prot()")

> Cc: John Fastabend <john.fastabend@gmail.com>

> Cc: Eric Dumazet <edumazet@google.com>

> Cc: Daniel Borkmann <daniel@iogearbox.net>

> Cc: Jakub Sitnicki <jakub@cloudflare.com>

> Cc: Lorenz Bauer <lmb@cloudflare.com>

> Signed-off-by: Cong Wang <cong.wang@bytedance.com>

> ---


We don't necessarily need to pass both sk and psock.  psock has a
backpointer to sk that owns it.

Reviewed-by: Jakub Sitnicki <jakub@cloudflare.com>
patchwork-bot+netdevbpf@kernel.org April 12, 2021, 3:40 p.m. UTC | #3
Hello:

This patch was applied to bpf/bpf-next.git (refs/heads/master):

On Tue,  6 Apr 2021 20:21:11 -0700 you wrote:
> From: Cong Wang <cong.wang@bytedance.com>

> 

> Using sk_psock() to retrieve psock pointer from sock requires

> RCU read lock, but we already get psock pointer before calling

> ->psock_update_sk_prot() in both cases, so we can just pass it

> without bothering sk_psock().

> 

> [...]


Here is the summary with links:
  - [bpf-next] skmsg: pass psock pointer to ->psock_update_sk_prot()
    https://git.kernel.org/bpf/bpf-next/c/51e0158a5432

You are awesome, thank you!
--
Deet-doot-dot, I am a bot.
https://korg.docs.kernel.org/patchwork/pwbot.html
diff mbox series

Patch

diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
index f78e90a04a69..e2fb0a5a101e 100644
--- a/include/linux/skmsg.h
+++ b/include/linux/skmsg.h
@@ -99,7 +99,8 @@  struct sk_psock {
 	void (*saved_close)(struct sock *sk, long timeout);
 	void (*saved_write_space)(struct sock *sk);
 	void (*saved_data_ready)(struct sock *sk);
-	int  (*psock_update_sk_prot)(struct sock *sk, bool restore);
+	int  (*psock_update_sk_prot)(struct sock *sk, struct sk_psock *psock,
+				     bool restore);
 	struct proto			*sk_proto;
 	struct mutex			work_mutex;
 	struct sk_psock_work_state	work_state;
@@ -405,7 +406,7 @@  static inline void sk_psock_restore_proto(struct sock *sk,
 {
 	sk->sk_prot->unhash = psock->saved_unhash;
 	if (psock->psock_update_sk_prot)
-		psock->psock_update_sk_prot(sk, true);
+		psock->psock_update_sk_prot(sk, psock, true);
 }
 
 static inline void sk_psock_set_state(struct sk_psock *psock,
diff --git a/include/net/sock.h b/include/net/sock.h
index 8b4155e756c2..c4bbdcd83f4d 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -1114,6 +1114,7 @@  struct inet_hashinfo;
 struct raw_hashinfo;
 struct smc_hashinfo;
 struct module;
+struct sk_psock;
 
 /*
  * caches using SLAB_TYPESAFE_BY_RCU should let .next pointer from nulls nodes
@@ -1185,7 +1186,9 @@  struct proto {
 	void			(*rehash)(struct sock *sk);
 	int			(*get_port)(struct sock *sk, unsigned short snum);
 #ifdef CONFIG_BPF_SYSCALL
-	int			(*psock_update_sk_prot)(struct sock *sk, bool restore);
+	int			(*psock_update_sk_prot)(struct sock *sk,
+							struct sk_psock *psock,
+							bool restore);
 #endif
 
 	/* Keeping track of sockets in use */
diff --git a/include/net/tcp.h b/include/net/tcp.h
index eaea43afcc97..d05193cb0d99 100644
--- a/include/net/tcp.h
+++ b/include/net/tcp.h
@@ -2215,7 +2215,7 @@  struct sk_psock;
 
 #ifdef CONFIG_BPF_SYSCALL
 struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
-int tcp_bpf_update_proto(struct sock *sk, bool restore);
+int tcp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore);
 void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
 #endif /* CONFIG_BPF_SYSCALL */
 
diff --git a/include/net/udp.h b/include/net/udp.h
index f55aaeef7e91..360df454356c 100644
--- a/include/net/udp.h
+++ b/include/net/udp.h
@@ -543,7 +543,7 @@  static inline void udp_post_segment_fix_csum(struct sk_buff *skb)
 #ifdef CONFIG_BPF_SYSCALL
 struct sk_psock;
 struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
-int udp_bpf_update_proto(struct sock *sk, bool restore);
+int udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore);
 #endif
 
 #endif	/* _UDP_H */
diff --git a/net/core/sock_map.c b/net/core/sock_map.c
index 3d190d22b0d8..f473c51cbc4b 100644
--- a/net/core/sock_map.c
+++ b/net/core/sock_map.c
@@ -188,7 +188,7 @@  static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
 	if (!sk->sk_prot->psock_update_sk_prot)
 		return -EINVAL;
 	psock->psock_update_sk_prot = sk->sk_prot->psock_update_sk_prot;
-	return sk->sk_prot->psock_update_sk_prot(sk, false);
+	return sk->sk_prot->psock_update_sk_prot(sk, psock, false);
 }
 
 static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c
index 3d622a0d0753..4930bc8ab47e 100644
--- a/net/ipv4/tcp_bpf.c
+++ b/net/ipv4/tcp_bpf.c
@@ -499,9 +499,8 @@  static int tcp_bpf_assert_proto_ops(struct proto *ops)
 	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
 }
 
-int tcp_bpf_update_proto(struct sock *sk, bool restore)
+int tcp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
 {
-	struct sk_psock *psock = sk_psock(sk);
 	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
 	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
 
diff --git a/net/ipv4/udp_bpf.c b/net/ipv4/udp_bpf.c
index 4a7e38c5d842..954c4591a6fd 100644
--- a/net/ipv4/udp_bpf.c
+++ b/net/ipv4/udp_bpf.c
@@ -103,10 +103,9 @@  static int __init udp_bpf_v4_build_proto(void)
 }
 core_initcall(udp_bpf_v4_build_proto);
 
-int udp_bpf_update_proto(struct sock *sk, bool restore)
+int udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
 {
 	int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
-	struct sk_psock *psock = sk_psock(sk);
 
 	if (restore) {
 		sk->sk_write_space = psock->saved_write_space;